Skip to content

feat: JAX training #4782

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 33 commits into
base: devel
Choose a base branch
from
Draft

Conversation

njzjz
Copy link
Member

@njzjz njzjz commented Jun 5, 2025

Summary by CodeRabbit

  • New Features

    • Introduced JAX backend entry points for training and model freezing, including command-line interfaces.
    • Added support for Hessian loss computation in energy loss with configurable prefactors and RMSE reporting.
    • Implemented a new JAX-based training framework for DeePMD models with checkpointing, mixed precision, and detailed logging.
    • Added a command-line option to include Hessian output during model freezing.
    • Extended model serialization and inference to support Hessian mode and related outputs.
  • Enhancements

    • Improved RMSE calculation and display for energy and force losses.
    • Added output statistics computation for energy fitting models.
    • Generalized learning rate scheduling to support alternative numerical libraries.
    • Enhanced environment matrix statistics handling and data preparation.
  • Other Changes

    • Updated serialization to include current training step and support Hessian enabling.
    • Added license identifiers to new modules.

njzjz added 20 commits May 25, 2025 12:53
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
@njzjz njzjz added this to the v3.2.0 milestone Jun 5, 2025
@njzjz njzjz linked an issue Jun 5, 2025 that may be closed by this pull request
@github-actions github-actions bot added the Python label Jun 5, 2025
valid_data = None

# get training info
stop_batch = jdata["training"]["numb_steps"]

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable stop_batch is not used.
if (
origin_type_map is not None and not origin_type_map
): # get the type_map from data if not provided
origin_type_map = get_data(

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable origin_type_map is not used.
)
jdata_cpy = jdata.copy()
type_map = jdata["model"].get("type_map")
train_data = get_data(

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable train_data is not used.
Copy link
Contributor

coderabbitai bot commented Jun 5, 2025

📝 Walkthrough

Walkthrough

This update introduces a JAX-based training and entrypoint framework for DeePMD-Kit, including new modules for command-line handling, model freezing, and training orchestration. It adds or modifies several utilities and loss functions, enhances serialization with step tracking and optional Hessian support, and generalizes learning rate scheduling. Minor internal changes and new attributes are also introduced in descriptors and statistics utilities.

Changes

File(s) Change Summary
Backend Entry Point Update
deepmd/backend/jax.py
Updated JAXBackend.entry_point_hook to return the actual entry point (main) instead of raising NotImplementedError.
Descriptor Attribute Addition
deepmd/dpmodel/descriptor/dpa1.py
Added ndescrpt attribute to DescrptBlockSeAtten as self.nnei * 4.
Energy Fitting Enhancements
deepmd/dpmodel/fitting/ener_fitting.py
Added compute_output_stats and _compute_output_stats methods to EnergyFittingNet for output statistics computation.
Energy Loss Refinements and Hessian Support
deepmd/dpmodel/loss/ener.py
Refined RMSE calculation for energy/force loss; added EnergyHessianLoss class for Hessian loss support with dynamic prefactors and label requirements.
Environment Matrix Statistics Update
deepmd/dpmodel/utils/env_mat_stat.py
Adjusted unpacking in EnvMatStatSe.iter to ignore natoms, updated reshaping logic for coordinates, atom types, and box.
Learning Rate Generalization
deepmd/dpmodel/utils/learning_rate.py
Generalized LearningRateExp.value to accept a numerical library parameter (xp), defaulting to NumPy, replacing direct NumPy calls.
JAX Entrypoints Initialization
deepmd/jax/entrypoints/__init__.py
Added license identifier file, no functional code.
JAX Training Initialization
deepmd/jax/train/__init__.py
Added license identifier file, no functional code.
Model Freezing Entrypoint
deepmd/jax/entrypoints/freeze.py
Added freeze function to convert checkpoint data into a serialized output file, with optional Hessian inclusion.
Main JAX Entrypoint
deepmd/jax/entrypoints/main.py
Added main function as CLI entry point, dispatching to train or freeze commands, with error handling and logging setup.
Training Entrypoint Script
deepmd/jax/entrypoints/train.py
Added JAX-based training entrypoint, with SummaryPrinter class, train function for model setup and execution, and update_sel helper.
JAX Trainer Implementation
deepmd/jax/train/trainer.py
Introduced DPTrainer class for model training, checkpointing, and reporting; added prepare_input utility function for data preparation.
Serialization Enhancements
deepmd/jax/utils/serialization.py
Extended deserialize_to_file with optional hessian flag enabling Hessian mode; included current_step in serialized data under "@variables".
Inference Model Output and Hessian Support
deepmd/jax/infer/deep_eval.py
Added support for Hessian output variable category; updated output shape handling; added method to check Hessian mode from model script.
JAX-to-TF Serialization Update
deepmd/jax/jax2tf/serialization.py
Extended deserialize_to_file to support optional hessian flag enabling Hessian mode on deserialization.
HLO Model Output Definition Update
deepmd/jax/model/hlo.py
Added new output variable definition "energy_hessian"; modified output selection to use it when Hessian mode is enabled.
CLI Freeze Command Update
deepmd/main.py
Added --hessian boolean flag to the "freeze" subcommand to enable Hessian inclusion in model output.

Sequence Diagram(s)

sequenceDiagram
    participant CLI_User
    participant MainEntrypoint
    participant TrainEntrypoint
    participant FreezeEntrypoint
    participant DPTrainer
    participant SerializationUtils

    CLI_User->>MainEntrypoint: main(args)
    MainEntrypoint->>MainEntrypoint: parse_args(args)
    alt command == "train"
        MainEntrypoint->>TrainEntrypoint: train(**args)
        TrainEntrypoint->>DPTrainer: DPTrainer(jdata, ...)
        DPTrainer->>DPTrainer: train(train_data, valid_data)
        DPTrainer->>SerializationUtils: save_checkpoint(...)
    else command == "freeze"
        MainEntrypoint->>FreezeEntrypoint: freeze(checkpoint_folder, output)
        FreezeEntrypoint->>SerializationUtils: serialize_from_file(folder)
        FreezeEntrypoint->>SerializationUtils: deserialize_to_file(output, hessian)
    end
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~40 minutes

Possibly related PRs

  • feat(jax): support neural networks #4156: Initial implementation of JAXBackend class with entry_point_hook raising NotImplementedError, related as this PR completes the implementation by returning the actual entry point.
  • feat(jax): checkpoint I/O #4236: Updates to JAXBackend class serialization and backend features; related by modifying the same backend class but different properties.

Suggested labels

Core, OP, Docs

Suggested reviewers

  • wanghan-iapcm
  • iProzd
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

🧹 Nitpick comments (10)
deepmd/dpmodel/utils/learning_rate.py (1)

48-52: Update return type annotation for backend generalization.

The addition of the xp parameter to support different array backends (like JAX) is excellent for the framework's extensibility. However, the return type annotation should be updated to reflect the actual return type.

-    def value(self, step, xp=np) -> np.float64:
+    def value(self, step, xp=np):

Or if you want to be more specific:

-    def value(self, step, xp=np) -> np.float64:
+    def value(self, step, xp=np) -> Union[np.float64, Any]:
deepmd/jax/entrypoints/freeze.py (1)

12-36: LGTM! Consider adding output validation and improved error handling.

The freeze function implementation is well-structured with proper checkpoint handling logic. The keyword-only argument design and integration with serialization utilities are excellent.

Consider these minor enhancements:

def freeze(
    *,
    checkpoint_folder: str,
    output: str,
    **kwargs,
) -> None:
    """Freeze the graph in supplied folder.

    Parameters
    ----------
    checkpoint_folder : str
        location of either the folder with checkpoint or the checkpoint prefix
    output : str
-        output file name
+        output file name (supported formats: .jax, .hlo, .savedmodel)
    **kwargs
        other arguments
    """
+    # Validate output format
+    supported_formats = ['.jax', '.hlo', '.savedmodel']
+    if not any(output.endswith(fmt) for fmt in supported_formats):
+        raise ValueError(f"Unsupported output format. Supported: {supported_formats}")
+    
    if (Path(checkpoint_folder) / "checkpoint").is_file():
        checkpoint_meta = Path(checkpoint_folder) / "checkpoint"
        checkpoint_folder = checkpoint_meta.read_text().strip()
    if Path(checkpoint_folder).is_dir():
-        data = serialize_from_file(checkpoint_folder)
-        deserialize_to_file(output, data)
+        try:
+            data = serialize_from_file(checkpoint_folder)
+            deserialize_to_file(output, data)
+        except Exception as e:
+            raise RuntimeError(f"Failed to freeze checkpoint: {e}") from e
    else:
        raise FileNotFoundError(f"Checkpoint {checkpoint_folder} does not exist.")
deepmd/dpmodel/fitting/ener_fitting.py (1)

118-124: Simplify nested loops for better readability.

The triple-nested loops can be simplified using list comprehensions or numpy operations.

Consider refactoring the nested loops:

-        sys_ener = []
-        for ss in range(len(data)):
-            sys_data = []
-            for ii in range(len(data[ss])):
-                for jj in range(len(data[ss][ii])):
-                    sys_data.append(data[ss][ii][jj])
-            sys_data = np.concatenate(sys_data)
-            sys_ener.append(np.average(sys_data))
+        sys_ener = []
+        for system_data in data:
+            # Flatten all batches and frames for this system
+            sys_data = np.concatenate([frame for batch in system_data for frame in batch])
+            sys_ener.append(np.average(sys_data))

Similarly for the mixed_type branch:

-                tmp_tynatom = []
-                for ii in range(len(data[ss])):
-                    for jj in range(len(data[ss][ii])):
-                        tmp_tynatom.append(data[ss][ii][jj].astype(np.float64))
-                tmp_tynatom = np.average(np.array(tmp_tynatom), axis=0)
+                # Flatten all batches and frames, then compute average
+                tmp_tynatom = np.average(
+                    np.array([frame.astype(np.float64) 
+                             for batch in data[ss] 
+                             for frame in batch]), 
+                    axis=0
+                )

Also applies to: 130-136

deepmd/dpmodel/loss/ener.py (1)

457-457: Clarify the ndof comment for Hessian data.

The comment # 9=3*3 --> 3N*3N=ndof*natoms*natoms is confusing since ndof is set to 1, not 9.

Update the comment to be clearer:

-                    ndof=1,  # 9=3*3 --> 3N*3N=ndof*natoms*natoms
+                    ndof=1,  # Hessian has shape (natoms, 3, natoms, 3), flattened per atom
deepmd/jax/train/trainer.py (3)

134-134: Remove unused instance variables.

The following instance variables are assigned but never used in the class:

  • self.numb_fparam (line 134)
  • self.frz_model (line 142)
  • self.ckpt_meta (line 143)
  • self.model_type (line 144)

Consider removing these unused variables to improve code clarity.

Also applies to: 142-145


283-284: Simplify .get() calls by removing redundant None default.

Apply this diff:

-                fparam=jax_data.get("fparam", None),
-                aparam=jax_data.get("aparam", None),
+                fparam=jax_data.get("fparam"),
+                aparam=jax_data.get("aparam"),

And similarly for lines 333-334:

-                            fparam=jax_valid_data.get("fparam", None),
-                            aparam=jax_valid_data.get("aparam", None),
+                            fparam=jax_valid_data.get("fparam"),
+                            aparam=jax_valid_data.get("aparam"),

Also applies to: 333-334

🧰 Tools
🪛 Ruff (0.11.9)

283-283: Use jax_data.get("fparam") instead of jax_data.get("fparam", None)

Replace jax_data.get("fparam", None) with jax_data.get("fparam")

(SIM910)


284-284: Use jax_data.get("aparam") instead of jax_data.get("aparam", None)

Replace jax_data.get("aparam", None) with jax_data.get("aparam")

(SIM910)


396-396: Simplify dictionary key iteration.

Apply this diff:

-            for k in valid_results.keys():
+            for k in valid_results:

And for line 401:

-            for k in train_results.keys():
+            for k in train_results:

Also applies to: 401-401

🧰 Tools
🪛 Ruff (0.11.9)

396-396: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

deepmd/jax/entrypoints/train.py (3)

144-147: Simplify type map assignment with ternary operator.

Apply this diff:

-    if len(type_map) == 0:
-        ipt_type_map = None
-    else:
-        ipt_type_map = type_map
+    ipt_type_map = None if len(type_map) == 0 else type_map
🧰 Tools
🪛 Ruff (0.11.9)

144-147: Use ternary operator ipt_type_map = None if len(type_map) == 0 else type_map instead of if-else-block

Replace if-else-block with ipt_type_map = None if len(type_map) == 0 else type_map

(SIM108)


172-172: Remove unused variable stop_batch.

The variable stop_batch is assigned but never used.

Apply this diff:

-    stop_batch = jdata["training"]["numb_steps"]
🧰 Tools
🪛 Ruff (0.11.9)

172-172: Local variable stop_batch is assigned to but never used

Remove assignment to unused variable stop_batch

(F841)


201-204: Address the OOM issue in neighbor statistics calculation.

The commented code indicates an out-of-memory issue that needs to be resolved. This functionality appears to be important for updating the model's selection parameters based on neighbor statistics.

Would you like me to help investigate the OOM issue and propose a memory-efficient solution for computing neighbor statistics? I could open a new issue to track this task.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9ef43fa and 15bb506.

📒 Files selected for processing (13)
  • deepmd/backend/jax.py (1 hunks)
  • deepmd/dpmodel/descriptor/dpa1.py (1 hunks)
  • deepmd/dpmodel/fitting/ener_fitting.py (3 hunks)
  • deepmd/dpmodel/loss/ener.py (3 hunks)
  • deepmd/dpmodel/utils/env_mat_stat.py (1 hunks)
  • deepmd/dpmodel/utils/learning_rate.py (1 hunks)
  • deepmd/jax/entrypoints/__init__.py (1 hunks)
  • deepmd/jax/entrypoints/freeze.py (1 hunks)
  • deepmd/jax/entrypoints/main.py (1 hunks)
  • deepmd/jax/entrypoints/train.py (1 hunks)
  • deepmd/jax/train/__init__.py (1 hunks)
  • deepmd/jax/train/trainer.py (1 hunks)
  • deepmd/jax/utils/serialization.py (2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (3)
deepmd/backend/jax.py (1)
deepmd/jax/entrypoints/main.py (1)
  • main (32-67)
deepmd/dpmodel/fitting/ener_fitting.py (2)
deepmd/utils/out_stat.py (1)
  • compute_stats_from_redu (15-86)
deepmd/tf/fit/ener.py (2)
  • compute_output_stats (257-273)
  • _compute_output_stats (275-321)
deepmd/jax/entrypoints/main.py (3)
deepmd/backend/suffix.py (1)
  • format_model_suffix (17-75)
deepmd/jax/entrypoints/freeze.py (1)
  • freeze (12-36)
deepmd/loggers/loggers.py (1)
  • set_log_handles (146-278)
🪛 Ruff (0.11.9)
deepmd/jax/entrypoints/train.py

144-147: Use ternary operator ipt_type_map = None if len(type_map) == 0 else type_map instead of if-else-block

Replace if-else-block with ipt_type_map = None if len(type_map) == 0 else type_map

(SIM108)


172-172: Local variable stop_batch is assigned to but never used

Remove assignment to unused variable stop_batch

(F841)


195-195: Local variable train_data is assigned to but never used

Remove assignment to unused variable train_data

(F841)

deepmd/jax/train/trainer.py

269-269: Use a context manager for opening files

(SIM115)


283-283: Use jax_data.get("fparam") instead of jax_data.get("fparam", None)

Replace jax_data.get("fparam", None) with jax_data.get("fparam")

(SIM910)


284-284: Use jax_data.get("aparam") instead of jax_data.get("aparam", None)

Replace jax_data.get("aparam", None) with jax_data.get("aparam")

(SIM910)


333-333: Use jax_valid_data.get("fparam") instead of jax_valid_data.get("fparam", None)

Replace jax_valid_data.get("fparam", None) with jax_valid_data.get("fparam")

(SIM910)


334-334: Use jax_valid_data.get("aparam") instead of jax_valid_data.get("aparam", None)

Replace jax_valid_data.get("aparam", None) with jax_valid_data.get("aparam")

(SIM910)


396-396: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


401-401: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

⏰ Context from checks skipped due to timeout of 90000ms (29)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Analyze (python)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Test C++ (false)
🔇 Additional comments (11)
deepmd/jax/train/__init__.py (1)

1-1: LGTM! Clean package initialization.

The license header is properly included and this follows standard Python package conventions.

deepmd/jax/entrypoints/__init__.py (1)

1-1: LGTM! Clean package initialization.

The license header is properly included and this follows standard Python package conventions.

deepmd/dpmodel/descriptor/dpa1.py (1)

784-784: LGTM! Descriptor count calculation is logical.

The calculation self.ndescrpt = self.nnei * 4 correctly computes the total descriptor elements, where each neighbor contributes 4 values (consistent with the wanted_shape = (self.ntypes, self.nnei, 4) on line 779).

deepmd/backend/jax.py (1)

63-67: LGTM! Proper implementation of the entry point hook.

The change correctly implements the JAX backend's entry point by importing and returning the main function from the new entrypoints module, replacing the previous NotImplementedError. This properly connects the backend to the CLI infrastructure.

deepmd/jax/utils/serialization.py (1)

180-180: LGTM! Excellent checkpoint step tracking implementation.

The extraction of current_step from model_def_script and its inclusion in the @variables dictionary provides proper support for training resumption. The default value of 0 and the use of .pop() to avoid duplication are both appropriate design choices.

Also applies to: 191-193

deepmd/dpmodel/utils/env_mat_stat.py (1)

122-130:

✅ Verification successful

Verify natoms removal doesn't affect other code.

The explicit reshaping operations improve data handling clarity and the removal of natoms unpacking simplifies the code. However, ensure that natoms isn't required elsewhere in the codebase.


🏁 Script executed:

#!/bin/bash
# Description: Check if natoms is used elsewhere in the codebase that might depend on this method.
# Expected: Find any usage of natoms that could be affected by this change.

# Search for natoms usage in environment matrix statistics contexts
rg -A 5 -B 5 "natoms.*system\[" --type py
rg -A 5 -B 5 "system.*natoms" --type py

Length of output: 28436


🏁 Script executed:

#!/bin/bash
# Locate DP model env_mat_stat file and inspect natoms usage
FILE=$(find deepmd/dpmodel -type f -name env_mat_stat.py)
echo "Found file: $FILE"
if [ -z "$FILE" ]; then
  echo "env_mat_stat.py not found in deepmd/dpmodel"
  exit 0
fi

echo -e "\n---- grep natoms occurrences ----"
grep -n "natoms" "$FILE" || echo "No natoms references found"

echo -e "\n---- Context around lines 100-160 ----"
sed -n '100,160p' "$FILE"

Length of output: 2702


No missing natoms references in deepmd/dpmodel/utils/env_mat_stat.py; removal is safe.

A search in deepmd/dpmodel/utils/env_mat_stat.py shows no remaining uses of natoms around the updated lines or elsewhere in that file. The reshaping logic now only relies on coord.shape[0], and no downstream code in this module expects natoms.

– deepmd/dpmodel/utils/env_mat_stat.py: lines 122–130 no longer mention natoms, and no grep hits for that identifier in this file.

All clear to approve.

deepmd/jax/entrypoints/main.py (1)

1-31: Well-structured imports and module setup.

The imports are properly organized and the module docstring clearly describes the purpose.

deepmd/dpmodel/fitting/ener_fitting.py (3)

9-9: Appropriate imports for statistics computation.

The numpy and compute_stats_from_redu imports are necessary for the new output statistics methods.

Also applies to: 22-25


96-112: Well-documented public interface for computing output statistics.

The method provides a clean interface that delegates to the private implementation method.


144-144:

❓ Verification inconclusive

Validate array shape before slicing.

The code slices sys_tynatom[:, 2:] without validating the array has at least 3 columns. This could cause an IndexError if the data structure is unexpected.

Consider adding a validation check:

         sys_tynatom = np.array(sys_tynatom)
         sys_tynatom = np.reshape(sys_tynatom, [nsys, -1])
+        if sys_tynatom.shape[1] < 3:
+            raise ValueError(f"Expected natoms_vec to have at least 3 columns, got {sys_tynatom.shape[1]}")
         sys_tynatom = sys_tynatom[:, 2:]

Let me verify the expected structure of natoms_vec:


🏁 Script executed:

#!/bin/bash
# Search for natoms_vec usage patterns to understand expected structure
rg -A 5 -B 5 'natoms_vec.*\[:.*2' --type py

Length of output: 45


#!/bin/bash

Display context around the slicing to understand sys_tynatom’s origin and expected shape

rg -n 'sys_tynatom' -C5 deepmd/dpmodel/fitting/ener_fitting.py


🏁 Script executed:

#!/bin/bash
# Show surrounding lines where sys_tynatom is defined and sliced
rg -n 'sys_tynatom' -C5 deepmd/dpmodel/fitting/ener_fitting.py

Length of output: 1913


#!/bin/bash

Find definitions or usage of natoms_vec and real_natoms_vec across the codebase

rg -n '"natoms_vec"' -C5 --type py
rg -n '"real_natoms_vec"' -C5 --type py


🏁 Script executed:

#!/bin/bash
# Show context where all_stat is defined or populated in ener_fitting.py
rg -n 'all_stat' -C5 deepmd/dpmodel/fitting/ener_fitting.py

Length of output: 2317


🏁 Script executed:

#!/bin/bash
# Search for definitions and usage of natoms_vec, real_natoms_vec, and make_stat_input
rg -n '"natoms_vec"' -C5 --type py
rg -n '"real_natoms_vec"' -C5 --type py
rg -n 'def make_stat_input' -C5 --type py

Length of output: 72872


Validate sys_tynatom shape before slicing

Ensure sys_tynatom has at least three columns before applying sys_tynatom[:, 2:]. While NumPy slicing won’t throw an IndexError for out‐of‐bounds slices, it will silently return an empty array if shape[1] < 3, which will break downstream logic in compute_stats_from_redu. Add a guard immediately after the reshape:

--- a/deepmd/dpmodel/fitting/ener_fitting.py
+++ b/deepmd/dpmodel/fitting/ener_fitting.py
@@ -142,6 +142,10 @@
         sys_tynatom = np.reshape(sys_tynatom, [nsys, -1])
+        if sys_tynatom.shape[1] < 3:
+            raise ValueError(
+                f"compute_output_stats requires at least 3 columns in sys_tynatom, got {sys_tynatom.shape[1]}"
+            )
         sys_tynatom = sys_tynatom[:, 2:]

• File: deepmd/dpmodel/fitting/ener_fitting.py
• Context: lines 142–144, inside _compute_output_stats

Please verify that all_stat["natoms_vec"] and all_stat["real_natoms_vec"] always yield ≥3 columns so this slice is valid.

deepmd/dpmodel/loss/ener.py (1)

180-182: Correct RMSE calculation implementation.

The changes properly compute the root mean square error by applying sqrt to the mean squared loss, providing accurate error metrics.

Also applies to: 194-196

Comment on lines +64 to +67
elif args.command is None:
pass
else:
raise RuntimeError(f"unknown command {args.command}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Improve command handling for better error reporting.

The current implementation silently passes when no command is provided, which might hide configuration issues. Additionally, the error message for unknown commands could be more helpful.

Apply this diff to improve error handling:

-    elif args.command is None:
-        pass
+    elif args.command is None:
+        raise RuntimeError("No command specified. Available commands: train, freeze")
     else:
-        raise RuntimeError(f"unknown command {args.command}")
+        raise RuntimeError(
+            f"Unknown command '{args.command}'. Available commands: train, freeze"
+        )
🤖 Prompt for AI Agents
In deepmd/jax/entrypoints/main.py around lines 64 to 67, replace the silent pass
when args.command is None with a clear error message indicating that no command
was provided. Also, enhance the RuntimeError message for unknown commands to
suggest checking available commands or usage. This improves error reporting by
explicitly handling missing commands and providing more informative feedback for
unknown commands.

# In this situation, we directly use these assigned energies instead of computing stats.
# This will make the loss decrease quickly
assigned_atom_ener = np.array(
[ee if ee is not None else np.nan for ee in self.atom_ener_v]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix undefined attribute reference.

The code references self.atom_ener_v which is not defined in the class. This will cause an AttributeError at runtime.

Based on the similar implementation in deepmd/tf/fit/ener.py, this should likely be self.atom_ener:

-                [ee if ee is not None else np.nan for ee in self.atom_ener_v]
+                [ee if ee is not None else np.nan for ee in self.atom_ener]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
[ee if ee is not None else np.nan for ee in self.atom_ener_v]
[ee if ee is not None else np.nan for ee in self.atom_ener]
🤖 Prompt for AI Agents
In deepmd/dpmodel/fitting/ener_fitting.py at line 150, the code references an
undefined attribute self.atom_ener_v, causing an AttributeError. Replace
self.atom_ener_v with self.atom_ener to match the correct attribute name used in
the class, following the pattern from deepmd/tf/fit/ener.py.

Other keyword arguments.
"""
EnergyLoss.__init__(self, **kwargs)
self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix logical condition for enabling Hessian loss.

The condition uses and which requires both prefactors to be non-zero. This is likely too restrictive - the Hessian loss should be enabled if either prefactor is non-zero.

Apply this fix:

-        self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
+        self.has_h = start_pref_h != 0.0 or limit_pref_h != 0.0
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
self.has_h = start_pref_h != 0.0 or limit_pref_h != 0.0
🤖 Prompt for AI Agents
In deepmd/dpmodel/loss/ener.py at line 409, the condition for enabling Hessian
loss uses 'and' to check if both start_pref_h and limit_pref_h are non-zero,
which is too restrictive. Change the logical operator from 'and' to 'or' so that
the Hessian loss is enabled if either start_pref_h or limit_pref_h is non-zero.

and "energy_derv_r_derv_r" in model_dict
and "hessian" in label_dict
):
find_hessian = label_dict.get("find_hessian", 0.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Verify find_hessian data type consistency.

The code uses label_dict.get("find_hessian", 0.0) which returns a scalar default, but other find_* variables in the parent class appear to be arrays from the labeled data.

Let me verify the expected data type for find_* variables:


🏁 Script executed:

#!/bin/bash
# Search for find_* variable usage patterns in loss calculations
rg -B 2 -A 2 'find_[a-z_]+\s*=' deepmd/dpmodel/loss/ --type py

Length of output: 1419


🏁 Script executed:

#!/bin/bash
# Search for DataRequirementItem definitions that might indicate find_* data structure
ast-grep --pattern 'DataRequirementItem(
  "find_$_",
  $$$
)'

Length of output: 64


Ensure find_hessian is accessed like the other find_* flags

The other find_* variables are pulled directly from label_dict (no default), so using .get(..., 0.0) here is inconsistent and may silently disable Hessian loss even when "hessian" is present. Please update this so that missing flags are surfaced (or, if you really intend an optional flag, default to a matching array of ones).

• File deepmd/dpmodel/loss/ener.py, line 434

-    find_hessian = label_dict.get("find_hessian", 0.0)
+    find_hessian = label_dict["find_hessian"]

If you do need a default, consider instead:

find_hessian = label_dict.get(
    "find_hessian",
    xp.ones(label_dict["hessian"].shape[0])  # match per-frame flag shape
)
🤖 Prompt for AI Agents
In deepmd/dpmodel/loss/ener.py at line 434, the assignment of find_hessian uses
label_dict.get with a scalar default 0.0, which is inconsistent with other
find_* variables that do not use defaults and are arrays. To fix this, remove
the default value so that missing keys raise an error or, if a default is
necessary, set it to an array of ones matching the shape of
label_dict["hessian"]. This ensures data type consistency and proper handling of
the find_hessian flag.

optimizer.update(grads)

start_time = time.time()
disp_file_fp = open(self.disp_file, "w")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Use context manager for file operations.

The file disp_file_fp should be opened using a context manager to ensure proper cleanup even if an exception occurs.

Apply this diff:

-        disp_file_fp = open(self.disp_file, "w")
-        for step in range(self.start_step, self.num_steps):
+        with open(self.disp_file, "w") as disp_file_fp:
+            for step in range(self.start_step, self.num_steps):
             # ... (training loop)
-
-        disp_file_fp.close()

Note: You'll need to indent the entire training loop inside the context manager.

Also applies to: 382-382

🧰 Tools
🪛 Ruff (0.11.9)

269-269: Use a context manager for opening files

(SIM115)

🤖 Prompt for AI Agents
In deepmd/jax/train/trainer.py at lines 269 and 382, the file disp_file_fp is
opened without a context manager, risking resource leaks if exceptions occur.
Refactor the code to open disp_file_fp using a with statement as a context
manager, and indent the entire training loop or relevant code block inside this
with block to ensure the file is properly closed after use.

Comment on lines +320 to +323
valid_batch_data = valid_data.get_batch()
jax_valid_data = {
kk: jnp.asarray(vv) for kk, vv in valid_batch_data.items()
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Apply consistent data conversion for validation data.

The validation data conversion should check for keys starting with "find_" similar to the training data conversion.

Apply this diff:

                     jax_valid_data = {
-                        kk: jnp.asarray(vv) for kk, vv in valid_batch_data.items()
+                        kk: jnp.asarray(vv) if not kk.startswith("find_") else bool(vv.item())
+                        for kk, vv in valid_batch_data.items()
                     }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
valid_batch_data = valid_data.get_batch()
jax_valid_data = {
kk: jnp.asarray(vv) for kk, vv in valid_batch_data.items()
}
valid_batch_data = valid_data.get_batch()
jax_valid_data = {
kk: jnp.asarray(vv) if not kk.startswith("find_") else bool(vv.item())
for kk, vv in valid_batch_data.items()
}
🤖 Prompt for AI Agents
In deepmd/jax/train/trainer.py around lines 320 to 323, the validation data
conversion to JAX arrays should be consistent with the training data conversion
by checking if keys start with "find_". Update the dictionary comprehension to
convert values to jnp.asarray only for keys starting with "find_", leaving other
keys unchanged.

Copy link

codecov bot commented Jun 5, 2025

Codecov Report

Attention: Patch coverage is 5.76369% with 327 lines in your changes missing coverage. Please review.

Project coverage is 84.40%. Comparing base (265d094) to head (15bb506).
Report is 3 commits behind head on devel.

Files with missing lines Patch % Lines
deepmd/jax/train/trainer.py 0.00% 168 Missing ⚠️
deepmd/jax/entrypoints/train.py 0.00% 68 Missing ⚠️
deepmd/dpmodel/fitting/ener_fitting.py 10.52% 34 Missing ⚠️
deepmd/dpmodel/loss/ener.py 24.13% 22 Missing ⚠️
deepmd/jax/entrypoints/main.py 0.00% 22 Missing ⚠️
deepmd/jax/entrypoints/freeze.py 0.00% 10 Missing ⚠️
deepmd/backend/jax.py 0.00% 2 Missing ⚠️
deepmd/jax/utils/serialization.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4782      +/-   ##
==========================================
- Coverage   84.79%   84.40%   -0.40%     
==========================================
  Files         698      702       +4     
  Lines       67775    68126     +351     
  Branches     3544     3542       -2     
==========================================
+ Hits        57472    57499      +27     
- Misses       9169     9494     +325     
+ Partials     1134     1133       -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

njzjz added 3 commits June 10, 2025 19:46
1. For dpmodel, pt, and pd, pass the trainable parameter to the layer (not actually used in this PR).
2. For JAX, support the `trainable` parameter in the layer.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
njzjz and others added 10 commits June 10, 2025 20:51
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
@@ -22,7 +22,7 @@
)


def deserialize_to_file(model_file: str, data: dict) -> None:
def deserialize_to_file(model_file: str, data: dict, hessian: bool = False) -> None:

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (4)
deepmd/jax/train/trainer.py (2)

274-274: Use context manager for file operations.

The file disp_file_fp should be opened using a context manager to ensure proper cleanup even if an exception occurs.

Apply this diff:

-        disp_file_fp = open(self.disp_file, "w")
-        for step in range(self.start_step, self.num_steps):
+        with open(self.disp_file, "w") as disp_file_fp:
+            for step in range(self.start_step, self.num_steps):
             # ... (training loop content should be indented)
-        disp_file_fp.close()

Note: You'll need to indent the entire training loop inside the context manager.


326-328: Apply consistent data conversion for validation data.

The validation data conversion should check for keys starting with "find_" similar to the training data conversion.

Apply this diff:

                     jax_valid_data = {
-                        kk: jnp.asarray(vv) for kk, vv in valid_batch_data.items()
+                        kk: jnp.asarray(vv) if not kk.startswith("find_") else bool(vv.item())
+                        for kk, vv in valid_batch_data.items()
                     }
deepmd/dpmodel/loss/ener.py (2)

409-409: Fix logical condition for enabling Hessian loss.

The condition uses and which requires both prefactors to be non-zero. This is likely too restrictive - the Hessian loss should be enabled if either prefactor is non-zero.

Apply this fix:

-        self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
+        self.has_h = start_pref_h != 0.0 or limit_pref_h != 0.0

434-434: Ensure find_hessian is accessed like the other find_* flags.

The other find_* variables are pulled directly from label_dict (no default), so using .get(..., 0.0) here is inconsistent and may silently disable Hessian loss even when "hessian" is present.

Apply this fix:

-            find_hessian = label_dict.get("find_hessian", 0.0)
+            find_hessian = label_dict["find_hessian"]
🧹 Nitpick comments (4)
deepmd/jax/train/trainer.py (4)

288-289: Simplify .get() calls.

The explicit None default is unnecessary since it's the default value for .get().

Apply this diff:

-                fparam=jax_data.get("fparam", None),
-                aparam=jax_data.get("aparam", None),
+                fparam=jax_data.get("fparam"),
+                aparam=jax_data.get("aparam"),

338-339: Simplify .get() calls.

The explicit None default is unnecessary since it's the default value for .get().

Apply this diff:

-                            fparam=jax_valid_data.get("fparam", None),
-                            aparam=jax_valid_data.get("aparam", None),
+                            fparam=jax_valid_data.get("fparam"),
+                            aparam=jax_valid_data.get("aparam"),

401-401: Simplify dictionary iteration.

Using dict.keys() is unnecessary when iterating over dictionary keys.

Apply this diff:

-            for k in valid_results.keys():
+            for k in valid_results:

406-406: Simplify dictionary iteration.

Using dict.keys() is unnecessary when iterating over dictionary keys.

Apply this diff:

-            for k in train_results.keys():
+            for k in train_results:
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 15bb506 and 208b648.

📒 Files selected for processing (10)
  • deepmd/dpmodel/descriptor/dpa1.py (1 hunks)
  • deepmd/dpmodel/loss/ener.py (3 hunks)
  • deepmd/dpmodel/utils/env_mat_stat.py (1 hunks)
  • deepmd/jax/entrypoints/freeze.py (1 hunks)
  • deepmd/jax/infer/deep_eval.py (2 hunks)
  • deepmd/jax/jax2tf/serialization.py (2 hunks)
  • deepmd/jax/model/hlo.py (2 hunks)
  • deepmd/jax/train/trainer.py (1 hunks)
  • deepmd/jax/utils/serialization.py (5 hunks)
  • deepmd/main.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • deepmd/dpmodel/utils/env_mat_stat.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • deepmd/dpmodel/descriptor/dpa1.py
  • deepmd/jax/utils/serialization.py
🧰 Additional context used
🧠 Learnings (4)
📓 Common learnings
Learnt from: njzjz
PR: deepmodeling/deepmd-kit#4284
File: deepmd/jax/__init__.py:8-8
Timestamp: 2024-10-30T20:08:12.531Z
Learning: In the DeepMD project, entry points like `deepmd.jax` may be registered in external projects, so their absence in the local configuration files is acceptable.
deepmd/jax/model/hlo.py (5)

Learnt from: 1azyking
PR: #4169
File: deepmd/pt/loss/ener_hess.py:341-348
Timestamp: 2024-10-08T15:32:11.479Z
Learning: In deepmd/pt/loss/ener_hess.py, the label uses the key "atom_ener" intentionally to maintain consistency with the forked version.

Learnt from: 1azyking
PR: #4169
File: deepmd/pt/loss/ener_hess.py:341-348
Timestamp: 2024-10-05T03:11:02.922Z
Learning: In deepmd/pt/loss/ener_hess.py, the label uses the key "atom_ener" intentionally to maintain consistency with the forked version.

Learnt from: 1azyking
PR: #4169
File: deepmd/utils/argcheck.py:1982-2117
Timestamp: 2024-10-05T03:06:02.372Z
Learning: The loss_ener_hess and loss_ener functions should remain separate to avoid confusion, despite code duplication.

Learnt from: 1azyking
PR: #4169
File: deepmd/pt/model/model/ener_hess_model.py:127-144
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Training with energy Hessian is not required for lower-level operations as of now.

Learnt from: 1azyking
PR: #4169
File: deepmd/pt/model/model/ener_hess_model.py:127-144
Timestamp: 2024-10-05T02:46:06.925Z
Learning: Training with energy Hessian is not required for lower-level operations as of now.

deepmd/dpmodel/loss/ener.py (5)

Learnt from: 1azyking
PR: #4169
File: deepmd/pt/loss/ener_hess.py:341-348
Timestamp: 2024-10-08T15:32:11.479Z
Learning: In deepmd/pt/loss/ener_hess.py, the label uses the key "atom_ener" intentionally to maintain consistency with the forked version.

Learnt from: 1azyking
PR: #4169
File: deepmd/pt/loss/ener_hess.py:341-348
Timestamp: 2024-10-05T03:11:02.922Z
Learning: In deepmd/pt/loss/ener_hess.py, the label uses the key "atom_ener" intentionally to maintain consistency with the forked version.

Learnt from: 1azyking
PR: #4169
File: deepmd/utils/argcheck.py:1982-2117
Timestamp: 2024-10-05T03:06:02.372Z
Learning: The loss_ener_hess and loss_ener functions should remain separate to avoid confusion, despite code duplication.

Learnt from: 1azyking
PR: #4169
File: deepmd/pt/model/model/ener_hess_model.py:127-144
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Training with energy Hessian is not required for lower-level operations as of now.

Learnt from: 1azyking
PR: #4169
File: deepmd/pt/model/model/ener_hess_model.py:127-144
Timestamp: 2024-10-05T02:46:06.925Z
Learning: Training with energy Hessian is not required for lower-level operations as of now.

deepmd/jax/train/trainer.py (2)

Learnt from: 1azyking
PR: #4169
File: examples/hessian/single-task/input.json:3-59
Timestamp: 2024-10-05T02:57:02.517Z
Learning: In the deepmd-kit project, having different rcut values for repinit and repformer is acceptable when they align with existing configurations, such as in examples/water/dpa2/input_torch.json.

Learnt from: 1azyking
PR: #4169
File: examples/hessian/single-task/input.json:3-59
Timestamp: 2024-10-08T15:32:11.479Z
Learning: In the deepmd-kit project, having different rcut values for repinit and repformer is acceptable when they align with existing configurations, such as in examples/water/dpa2/input_torch.json.

🪛 Ruff (0.12.2)
deepmd/jax/train/trainer.py

274-274: Use a context manager for opening files

(SIM115)


288-288: Use jax_data.get("fparam") instead of jax_data.get("fparam", None)

Replace jax_data.get("fparam", None) with jax_data.get("fparam")

(SIM910)


289-289: Use jax_data.get("aparam") instead of jax_data.get("aparam", None)

Replace jax_data.get("aparam", None) with jax_data.get("aparam")

(SIM910)


338-338: Use jax_valid_data.get("fparam") instead of jax_valid_data.get("fparam", None)

Replace jax_valid_data.get("fparam", None) with jax_valid_data.get("fparam")

(SIM910)


339-339: Use jax_valid_data.get("aparam") instead of jax_valid_data.get("aparam", None)

Replace jax_valid_data.get("aparam", None) with jax_valid_data.get("aparam")

(SIM910)


401-401: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


406-406: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

🔇 Additional comments (19)
deepmd/main.py (1)

338-343: LGTM! Clean CLI integration for Hessian support.

The addition of the --hessian flag to the freeze subcommand is well-implemented:

  • Follows argparse conventions with appropriate action and default
  • Clear help text explaining the functionality
  • Integrates properly with the broader Hessian support in the JAX backend
deepmd/jax/jax2tf/serialization.py (3)

24-24: LGTM! Well-designed function signature extension.

The addition of the hessian parameter with a default value of False maintains backward compatibility while enabling the new functionality.


33-34: LGTM! Clear documentation.

The parameter documentation is concise and accurately describes the functionality.


39-41: LGTM! Correct Hessian mode enablement.

The conditional logic properly:

  • Enables Hessian mode on the model instance
  • Updates the model definition script with the Hessian flag
  • Only executes when the hessian parameter is True
deepmd/jax/model/hlo.py (2)

34-41: LGTM! Properly structured Hessian output definition.

The energy_hessian output definition correctly:

  • Mirrors the standard energy definition structure
  • Adds the r_hessian=True attribute to enable Hessian computations
  • Uses consistent naming and shape specifications

178-191: LGTM! Smart conditional output selection.

The modified model_output_def method elegantly handles Hessian mode by:

  • Checking the hessian_mode flag in the model definition script
  • Only applying the transformation for energy output type
  • Using string formatting to select the appropriate output definition

The logic correctly falls back to the standard output when Hessian mode is disabled or for non-energy outputs.

deepmd/jax/entrypoints/freeze.py (2)

12-18: LGTM! Well-designed function signature.

The function design is excellent:

  • Uses keyword-only arguments (*,) which prevents positional argument misuse
  • Provides sensible default for hessian parameter
  • Accepts **kwargs for extensibility

32-39: LGTM! Robust checkpoint resolution and error handling.

The implementation correctly:

  • Handles checkpoint file indirection by reading the checkpoint metadata
  • Validates directory existence before processing
  • Provides clear error messages with the actual checkpoint path
  • Passes the hessian flag through to the serialization layer
deepmd/jax/infer/deep_eval.py (3)

281-281: LGTM! Consistent integration of Hessian output category.

The addition of OutputVariableCategory.DERV_R_DERV_R to the non-atomic output categories correctly enables Hessian matrix computation in inference.


414-416: LGTM! Correct Hessian matrix shape calculation.

The shape [nframes, 3 * natoms, 3 * natoms] is mathematically correct for the Hessian matrix, representing the second derivatives with respect to all atomic coordinates.


424-426: LGTM! Clean Hessian mode detection.

The get_has_hessian method correctly:

  • Retrieves the model definition script
  • Checks for the hessian_mode key with appropriate default
  • Returns a boolean indicating Hessian support
deepmd/jax/train/trainer.py (3)

1-63: LGTM! Well-organized imports.

The imports are comprehensive and well-structured, covering all necessary components for JAX-based training including models, loss functions, utilities, and logging.


65-145: LGTM! Well-structured initialization.

The initialization properly handles different model loading scenarios and configures training parameters. The learning rate setup is clean and extensible for future learning rate types.


430-465: Excellent utility function implementation.

The prepare_input function is well-designed with proper keyword-only arguments and handles all necessary coordinate transformations for DeepMD models, including normalization, ghost atom extension, and neighbor list building.

deepmd/dpmodel/loss/ener.py (5)

180-182: Good improvement to loss reporting.

Changing from MSE to RMSE reporting makes the loss metrics more interpretable while keeping the underlying MSE-based loss computation for optimization. The implementation is mathematically correct.


194-196: Consistent RMSE reporting for force loss.

The force loss reporting is now consistent with energy loss, using RMSE instead of MSE for better interpretability.


390-413: Well-designed Hessian loss extension.

The EnergyHessianLoss class properly extends the base EnergyLoss class with configurable Hessian prefactors and appropriate parameter handling.


414-448: Solid Hessian loss computation implementation.

The Hessian loss computation correctly handles prefactor interpolation, conditional loss addition, and RMSE reporting. The logic for checking model and label dictionary keys before computation is appropriate.


449-464: Proper label requirements extension.

The label_requirement property correctly extends the parent class requirements with Hessian data requirements when Hessian loss is enabled.

@@ -419,4 +419,8 @@ def _get_output_shape(self, odef, nframes, natoms):

def get_model_def_script(self) -> dict:
"""Get model definition script."""
return json.loads(self.dp.get_model_def_script())
return self.dp.get_model_def_script()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Verify impact of removing JSON parsing in get_model_def_script.

The change removes JSON parsing and returns the raw model definition script. This could affect downstream code that expects a parsed dictionary.


🏁 Script executed:

#!/bin/bash
# Search for usages of get_model_def_script to verify they can handle raw script
rg -A 5 "get_model_def_script\(\)" --type py

Length of output: 6125


Action Required: Restore JSON parsing in JAX infer’s get_model_def_script

The call to self.dp.get_model_def_script() now returns a raw JSON string, but downstream code in this module (e.g. get_has_hessian) still does .get(...). We need to re-introduce json.loads so callers continue to receive a dictionary.

Please update deepmd/jax/infer/deep_eval.py:

  • Add import json at the top (if not already present).
  • Change the method on line 422 from:
    return self.dp.get_model_def_script()
    to:
    return json.loads(self.dp.get_model_def_script())

This ensures that:

  • get_model_def_script() returns a dict again, matching the expectations of:
    model_def_script = self.get_model_def_script()
    return model_def_script.get("hessian_mode", False)
  • Any other callers in this file relying on dict-style access remain functional.
🤖 Prompt for AI Agents
In deepmd/jax/infer/deep_eval.py at line 422, the method get_model_def_script
currently returns a raw JSON string, but downstream code expects a dictionary.
To fix this, add "import json" at the top of the file if not already present,
then change the return statement on line 422 to return
json.loads(self.dp.get_model_def_script()) so that the method returns a parsed
dictionary as expected.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature Request] JAX training
1 participant