feat: JAX training #4782
base: devel
Conversation
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
…to jax_training
```python
    valid_data = None

    # get training info
    stop_batch = jdata["training"]["numb_steps"]
```

Check notice — Code scanning / CodeQL: Unused local variable (Note)

```python
    if (
        origin_type_map is not None and not origin_type_map
    ):  # get the type_map from data if not provided
        origin_type_map = get_data(
```

Check notice — Code scanning / CodeQL: Unused local variable (Note)

```python
    )
    jdata_cpy = jdata.copy()
    type_map = jdata["model"].get("type_map")
    train_data = get_data(
```

Check notice — Code scanning / CodeQL: Unused local variable (Note)
📝 Walkthrough

This update introduces a JAX-based training and entrypoint framework for DeePMD-Kit, including new modules for command-line handling, model freezing, and training orchestration. It adds or modifies several utilities and loss functions, enhances serialization with step tracking and optional Hessian support, and generalizes learning rate scheduling. Minor internal changes and new attributes are also introduced in descriptors and statistics utilities.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant CLI_User
    participant MainEntrypoint
    participant TrainEntrypoint
    participant FreezeEntrypoint
    participant DPTrainer
    participant SerializationUtils
    CLI_User->>MainEntrypoint: main(args)
    MainEntrypoint->>MainEntrypoint: parse_args(args)
    alt command == "train"
        MainEntrypoint->>TrainEntrypoint: train(**args)
        TrainEntrypoint->>DPTrainer: DPTrainer(jdata, ...)
        DPTrainer->>DPTrainer: train(train_data, valid_data)
        DPTrainer->>SerializationUtils: save_checkpoint(...)
    else command == "freeze"
        MainEntrypoint->>FreezeEntrypoint: freeze(checkpoint_folder, output)
        FreezeEntrypoint->>SerializationUtils: serialize_from_file(folder)
        FreezeEntrypoint->>SerializationUtils: deserialize_to_file(output, hessian)
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~40 minutes
Actionable comments posted: 6
🧹 Nitpick comments (10)
deepmd/dpmodel/utils/learning_rate.py (1)

48-52: Update return type annotation for backend generalization.

The addition of the `xp` parameter to support different array backends (like JAX) is excellent for the framework's extensibility. However, the return type annotation should be updated to reflect the actual return type.

```diff
-    def value(self, step, xp=np) -> np.float64:
+    def value(self, step, xp=np):
```

Or, if you want to be more specific:

```diff
-    def value(self, step, xp=np) -> np.float64:
+    def value(self, step, xp=np) -> Union[np.float64, Any]:
```

deepmd/jax/entrypoints/freeze.py (1)
12-36: LGTM! Consider adding output validation and improved error handling.

The freeze function implementation is well-structured with proper checkpoint handling logic. The keyword-only argument design and integration with serialization utilities are excellent.

Consider these minor enhancements:

```diff
 def freeze(
     *,
     checkpoint_folder: str,
     output: str,
     **kwargs,
 ) -> None:
     """Freeze the graph in supplied folder.

     Parameters
     ----------
     checkpoint_folder : str
         location of either the folder with checkpoint or the checkpoint prefix
     output : str
-        output file name
+        output file name (supported formats: .jax, .hlo, .savedmodel)
     **kwargs
         other arguments
     """
+    # Validate output format
+    supported_formats = ['.jax', '.hlo', '.savedmodel']
+    if not any(output.endswith(fmt) for fmt in supported_formats):
+        raise ValueError(f"Unsupported output format. Supported: {supported_formats}")
+
     if (Path(checkpoint_folder) / "checkpoint").is_file():
         checkpoint_meta = Path(checkpoint_folder) / "checkpoint"
         checkpoint_folder = checkpoint_meta.read_text().strip()
     if Path(checkpoint_folder).is_dir():
-        data = serialize_from_file(checkpoint_folder)
-        deserialize_to_file(output, data)
+        try:
+            data = serialize_from_file(checkpoint_folder)
+            deserialize_to_file(output, data)
+        except Exception as e:
+            raise RuntimeError(f"Failed to freeze checkpoint: {e}") from e
     else:
         raise FileNotFoundError(f"Checkpoint {checkpoint_folder} does not exist.")
```

deepmd/dpmodel/fitting/ener_fitting.py (1)
118-124: Simplify nested loops for better readability.

The triple-nested loops can be simplified using list comprehensions or numpy operations.

Consider refactoring the nested loops:

```diff
-        sys_ener = []
-        for ss in range(len(data)):
-            sys_data = []
-            for ii in range(len(data[ss])):
-                for jj in range(len(data[ss][ii])):
-                    sys_data.append(data[ss][ii][jj])
-            sys_data = np.concatenate(sys_data)
-            sys_ener.append(np.average(sys_data))
+        sys_ener = []
+        for system_data in data:
+            # Flatten all batches and frames for this system
+            sys_data = np.concatenate([frame for batch in system_data for frame in batch])
+            sys_ener.append(np.average(sys_data))
```

Similarly for the mixed_type branch:

```diff
-            tmp_tynatom = []
-            for ii in range(len(data[ss])):
-                for jj in range(len(data[ss][ii])):
-                    tmp_tynatom.append(data[ss][ii][jj].astype(np.float64))
-            tmp_tynatom = np.average(np.array(tmp_tynatom), axis=0)
+            # Flatten all batches and frames, then compute average
+            tmp_tynatom = np.average(
+                np.array([frame.astype(np.float64)
+                          for batch in data[ss]
+                          for frame in batch]),
+                axis=0,
+            )
```

Also applies to: 130-136
deepmd/dpmodel/loss/ener.py (1)
457-457: Clarify the ndof comment for Hessian data.

The comment `# 9=3*3 --> 3N*3N=ndof*natoms*natoms` is confusing since ndof is set to 1, not 9.

Update the comment to be clearer:

```diff
-            ndof=1,  # 9=3*3 --> 3N*3N=ndof*natoms*natoms
+            ndof=1,  # Hessian has shape (natoms, 3, natoms, 3), flattened per atom
```

deepmd/jax/train/trainer.py (3)
134-134: Remove unused instance variables.

The following instance variables are assigned but never used in the class:

- `self.numb_fparam` (line 134)
- `self.frz_model` (line 142)
- `self.ckpt_meta` (line 143)
- `self.model_type` (line 144)

Consider removing these unused variables to improve code clarity.

Also applies to: 142-145

283-284: Simplify .get() calls by removing redundant None default.

Apply this diff:

```diff
-                fparam=jax_data.get("fparam", None),
-                aparam=jax_data.get("aparam", None),
+                fparam=jax_data.get("fparam"),
+                aparam=jax_data.get("aparam"),
```

And similarly for lines 333-334:

```diff
-                fparam=jax_valid_data.get("fparam", None),
-                aparam=jax_valid_data.get("aparam", None),
+                fparam=jax_valid_data.get("fparam"),
+                aparam=jax_valid_data.get("aparam"),
```

Also applies to: 333-334
🧰 Tools
🪛 Ruff (0.11.9)

283-283: Use `jax_data.get("fparam")` instead of `jax_data.get("fparam", None)` (SIM910)

284-284: Use `jax_data.get("aparam")` instead of `jax_data.get("aparam", None)` (SIM910)
396-396: Simplify dictionary key iteration.

Apply this diff:

```diff
-        for k in valid_results.keys():
+        for k in valid_results:
```

And for line 401:

```diff
-            for k in train_results.keys():
+            for k in train_results:
```

Also applies to: 401-401
🧰 Tools
🪛 Ruff (0.11.9)

396-396: Use `key in dict` instead of `key in dict.keys()`; remove `.keys()` (SIM118)
deepmd/jax/entrypoints/train.py (3)
144-147: Simplify type map assignment with ternary operator.

Apply this diff:

```diff
-    if len(type_map) == 0:
-        ipt_type_map = None
-    else:
-        ipt_type_map = type_map
+    ipt_type_map = None if len(type_map) == 0 else type_map
```

🧰 Tools
🪛 Ruff (0.11.9)

144-147: Use ternary operator `ipt_type_map = None if len(type_map) == 0 else type_map` instead of `if`-`else`-block (SIM108)
172-172: Remove unused variable `stop_batch`.

The variable `stop_batch` is assigned but never used.

Apply this diff:

```diff
-    stop_batch = jdata["training"]["numb_steps"]
```

🧰 Tools
🪛 Ruff (0.11.9)

172-172: Local variable `stop_batch` is assigned to but never used (F841)
201-204: Address the OOM issue in neighbor statistics calculation.

The commented code indicates an out-of-memory issue that needs to be resolved. This functionality appears to be important for updating the model's selection parameters based on neighbor statistics.

Would you like me to help investigate the OOM issue and propose a memory-efficient solution for computing neighbor statistics? I could open a new issue to track this task.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
- deepmd/backend/jax.py (1 hunks)
- deepmd/dpmodel/descriptor/dpa1.py (1 hunks)
- deepmd/dpmodel/fitting/ener_fitting.py (3 hunks)
- deepmd/dpmodel/loss/ener.py (3 hunks)
- deepmd/dpmodel/utils/env_mat_stat.py (1 hunks)
- deepmd/dpmodel/utils/learning_rate.py (1 hunks)
- deepmd/jax/entrypoints/__init__.py (1 hunks)
- deepmd/jax/entrypoints/freeze.py (1 hunks)
- deepmd/jax/entrypoints/main.py (1 hunks)
- deepmd/jax/entrypoints/train.py (1 hunks)
- deepmd/jax/train/__init__.py (1 hunks)
- deepmd/jax/train/trainer.py (1 hunks)
- deepmd/jax/utils/serialization.py (2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (3)
deepmd/backend/jax.py (1)
- deepmd/jax/entrypoints/main.py: `main` (32-67)

deepmd/dpmodel/fitting/ener_fitting.py (2)
- deepmd/utils/out_stat.py: `compute_stats_from_redu` (15-86)
- deepmd/tf/fit/ener.py: `compute_output_stats` (257-273), `_compute_output_stats` (275-321)

deepmd/jax/entrypoints/main.py (3)
- deepmd/backend/suffix.py: `format_model_suffix` (17-75)
- deepmd/jax/entrypoints/freeze.py: `freeze` (12-36)
- deepmd/loggers/loggers.py: `set_log_handles` (146-278)
🪛 Ruff (0.11.9)
deepmd/jax/entrypoints/train.py

144-147: Use ternary operator `ipt_type_map = None if len(type_map) == 0 else type_map` instead of `if`-`else`-block (SIM108)

172-172: Local variable `stop_batch` is assigned to but never used (F841)

195-195: Local variable `train_data` is assigned to but never used (F841)

deepmd/jax/train/trainer.py

269-269: Use a context manager for opening files (SIM115)

283-283: Use `jax_data.get("fparam")` instead of `jax_data.get("fparam", None)` (SIM910)

284-284: Use `jax_data.get("aparam")` instead of `jax_data.get("aparam", None)` (SIM910)

333-333: Use `jax_valid_data.get("fparam")` instead of `jax_valid_data.get("fparam", None)` (SIM910)

334-334: Use `jax_valid_data.get("aparam")` instead of `jax_valid_data.get("aparam", None)` (SIM910)

396-396: Use `key in dict` instead of `key in dict.keys()` (SIM118)

401-401: Use `key in dict` instead of `key in dict.keys()` (SIM118)
⏰ Context from checks skipped due to timeout of 90000ms (29)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Analyze (python)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Test C++ (true)
- GitHub Check: Test C++ (false)
🔇 Additional comments (11)
deepmd/jax/train/__init__.py (1)

1-1: LGTM! Clean package initialization.

The license header is properly included and this follows standard Python package conventions.

deepmd/jax/entrypoints/__init__.py (1)

1-1: LGTM! Clean package initialization.

The license header is properly included and this follows standard Python package conventions.

deepmd/dpmodel/descriptor/dpa1.py (1)

784-784: LGTM! Descriptor count calculation is logical.

The calculation `self.ndescrpt = self.nnei * 4` correctly computes the total descriptor elements, where each neighbor contributes 4 values (consistent with the `wanted_shape = (self.ntypes, self.nnei, 4)` on line 779).

deepmd/backend/jax.py (1)

63-67: LGTM! Proper implementation of the entry point hook.

The change correctly implements the JAX backend's entry point by importing and returning the `main` function from the new entrypoints module, replacing the previous `NotImplementedError`. This properly connects the backend to the CLI infrastructure.

deepmd/jax/utils/serialization.py (1)

180-180: LGTM! Excellent checkpoint step tracking implementation.

The extraction of `current_step` from `model_def_script` and its inclusion in the `@variables` dictionary provides proper support for training resumption. The default value of 0 and the use of `.pop()` to avoid duplication are both appropriate design choices.

Also applies to: 191-193
deepmd/dpmodel/utils/env_mat_stat.py (1)
122-130: ✅ Verification successful

Verify natoms removal doesn't affect other code.

The explicit reshaping operations improve data handling clarity and the removal of `natoms` unpacking simplifies the code. However, ensure that `natoms` isn't required elsewhere in the codebase.

🏁 Scripts executed:

```shell
#!/bin/bash
# Check if natoms is used elsewhere in the codebase that might depend on this method.
rg -A 5 -B 5 "natoms.*system\[" --type py
rg -A 5 -B 5 "system.*natoms" --type py
```

```shell
#!/bin/bash
# Locate the dpmodel env_mat_stat.py and inspect natoms usage
FILE=$(find deepmd/dpmodel -type f -name env_mat_stat.py)
echo "Found file: $FILE"
grep -n "natoms" "$FILE" || echo "No natoms references found"
sed -n '100,160p' "$FILE"
```

Result: no missing `natoms` references in `deepmd/dpmodel/utils/env_mat_stat.py`; the removal is safe. Lines 122-130 no longer mention `natoms`, the reshaping logic relies only on `coord.shape[0]`, and no downstream code in this module expects `natoms`. All clear to approve.
deepmd/jax/entrypoints/main.py (1)
1-31: Well-structured imports and module setup.

The imports are properly organized and the module docstring clearly describes the purpose.
deepmd/dpmodel/fitting/ener_fitting.py (3)
9-9: Appropriate imports for statistics computation.

The numpy and `compute_stats_from_redu` imports are necessary for the new output statistics methods.

Also applies to: 22-25

96-112: Well-documented public interface for computing output statistics.

The method provides a clean interface that delegates to the private implementation method.

144-144: ❓ Verification inconclusive

Validate array shape before slicing.

The code slices `sys_tynatom[:, 2:]` without validating the array has at least 3 columns. This could cause an IndexError if the data structure is unexpected.

Consider adding a validation check:

```diff
         sys_tynatom = np.array(sys_tynatom)
         sys_tynatom = np.reshape(sys_tynatom, [nsys, -1])
+        if sys_tynatom.shape[1] < 3:
+            raise ValueError(f"Expected natoms_vec to have at least 3 columns, got {sys_tynatom.shape[1]}")
         sys_tynatom = sys_tynatom[:, 2:]
```

Let me verify the expected structure of `natoms_vec`:
🏁 Scripts executed:

```shell
#!/bin/bash
# Search for natoms_vec usage patterns to understand expected structure
rg -A 5 -B 5 'natoms_vec.*\[:.*2' --type py
```

```shell
#!/bin/bash
# Show surrounding lines where sys_tynatom is defined and sliced
rg -n 'sys_tynatom' -C5 deepmd/dpmodel/fitting/ener_fitting.py
```

```shell
#!/bin/bash
# Show context where all_stat is defined or populated in ener_fitting.py
rg -n 'all_stat' -C5 deepmd/dpmodel/fitting/ener_fitting.py
```

```shell
#!/bin/bash
# Search for definitions and usage of natoms_vec, real_natoms_vec, and make_stat_input
rg -n '"natoms_vec"' -C5 --type py
rg -n '"real_natoms_vec"' -C5 --type py
rg -n 'def make_stat_input' -C5 --type py
```
Validate `sys_tynatom` shape before slicing.

Ensure `sys_tynatom` has at least three columns before applying `sys_tynatom[:, 2:]`. While NumPy slicing won't throw an `IndexError` for out-of-bounds slices, it will silently return an empty array if `shape[1] < 3`, which will break downstream logic in `compute_stats_from_redu`. Add a guard immediately after the reshape:

```diff
--- a/deepmd/dpmodel/fitting/ener_fitting.py
+++ b/deepmd/dpmodel/fitting/ener_fitting.py
@@ -142,6 +142,10 @@
         sys_tynatom = np.reshape(sys_tynatom, [nsys, -1])
+        if sys_tynatom.shape[1] < 3:
+            raise ValueError(
+                f"compute_output_stats requires at least 3 columns in sys_tynatom, got {sys_tynatom.shape[1]}"
+            )
         sys_tynatom = sys_tynatom[:, 2:]
```

- File: `deepmd/dpmodel/fitting/ener_fitting.py`
- Context: lines 142-144, inside `_compute_output_stats`

Please verify that `all_stat["natoms_vec"]` and `all_stat["real_natoms_vec"]` always yield ≥3 columns so this slice is valid.

deepmd/dpmodel/loss/ener.py (1)
180-182
: Correct RMSE calculation implementation.The changes properly compute the root mean square error by applying
sqrt
to the mean squared loss, providing accurate error metrics.Also applies to: 194-196
```python
    elif args.command is None:
        pass
    else:
        raise RuntimeError(f"unknown command {args.command}")
```
🛠️ Refactor suggestion
Improve command handling for better error reporting.
The current implementation silently passes when no command is provided, which might hide configuration issues. Additionally, the error message for unknown commands could be more helpful.
Apply this diff to improve error handling:
```diff
-    elif args.command is None:
-        pass
+    elif args.command is None:
+        raise RuntimeError("No command specified. Available commands: train, freeze")
     else:
-        raise RuntimeError(f"unknown command {args.command}")
+        raise RuntimeError(
+            f"Unknown command '{args.command}'. Available commands: train, freeze"
+        )
```
🤖 Prompt for AI Agents
In deepmd/jax/entrypoints/main.py around lines 64 to 67, replace the silent pass
when args.command is None with a clear error message indicating that no command
was provided. Also, enhance the RuntimeError message for unknown commands to
suggest checking available commands or usage. This improves error reporting by
explicitly handling missing commands and providing more informative feedback for
unknown commands.
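The suggested behavior can be reduced to a standalone sketch; the `dispatch` helper and `handlers` mapping are illustrative, not the actual deepmd CLI:

```python
KNOWN_COMMANDS = ("train", "freeze")

def dispatch(command, handlers):
    """Dispatch a CLI command, failing loudly on missing or unknown commands."""
    if command is None:
        raise RuntimeError(
            f"No command specified. Available commands: {', '.join(KNOWN_COMMANDS)}"
        )
    if command not in handlers:
        raise RuntimeError(
            f"Unknown command '{command}'. Available commands: {', '.join(KNOWN_COMMANDS)}"
        )
    return handlers[command]()
```

Raising on `None` instead of silently passing surfaces misconfigured argument parsers immediately, and naming the available commands in both error paths gives the user an actionable next step.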
```python
        # In this situation, we directly use these assigned energies instead of computing stats.
        # This will make the loss decrease quickly
        assigned_atom_ener = np.array(
            [ee if ee is not None else np.nan for ee in self.atom_ener_v]
```
Fix undefined attribute reference.

The code references `self.atom_ener_v`, which is not defined in the class. This will cause an `AttributeError` at runtime.

Based on the similar implementation in `deepmd/tf/fit/ener.py`, this should likely be `self.atom_ener`:

```diff
-            [ee if ee is not None else np.nan for ee in self.atom_ener_v]
+            [ee if ee is not None else np.nan for ee in self.atom_ener]
```
🤖 Prompt for AI Agents
In deepmd/dpmodel/fitting/ener_fitting.py at line 150, the code references an
undefined attribute self.atom_ener_v, causing an AttributeError. Replace
self.atom_ener_v with self.atom_ener to match the correct attribute name used in
the class, following the pattern from deepmd/tf/fit/ener.py.
```python
            Other keyword arguments.
        """
        EnergyLoss.__init__(self, **kwargs)
        self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
```
Fix logical condition for enabling Hessian loss.

The condition uses `and`, which requires both prefactors to be non-zero. This is likely too restrictive - the Hessian loss should be enabled if either prefactor is non-zero.

Apply this fix:

```diff
-        self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
+        self.has_h = start_pref_h != 0.0 or limit_pref_h != 0.0
```
🤖 Prompt for AI Agents
In deepmd/dpmodel/loss/ener.py at line 409, the condition for enabling Hessian
loss uses 'and' to check if both start_pref_h and limit_pref_h are non-zero,
which is too restrictive. Change the logical operator from 'and' to 'or' so that
the Hessian loss is enabled if either start_pref_h or limit_pref_h is non-zero.
```python
                and "energy_derv_r_derv_r" in model_dict
                and "hessian" in label_dict
            ):
                find_hessian = label_dict.get("find_hessian", 0.0)
```
💡 Verification agent
🧩 Analysis chain
Verify find_hessian data type consistency.

The code uses `label_dict.get("find_hessian", 0.0)`, which returns a scalar default, but other `find_*` variables in the parent class appear to be arrays from the labeled data.

Let me verify the expected data type for `find_*` variables:

```shell
#!/bin/bash
# Search for find_* variable usage patterns in loss calculations
rg -B 2 -A 2 'find_[a-z_]+\s*=' deepmd/dpmodel/loss/ --type py
```
```shell
#!/bin/bash
# Search for DataRequirementItem definitions that might indicate find_* data structure
ast-grep --pattern 'DataRequirementItem(
    "find_$_",
    $$$
)'
```
Ensure `find_hessian` is accessed like the other `find_*` flags.

The other `find_*` variables are pulled directly from `label_dict` (no default), so using `.get(..., 0.0)` here is inconsistent and may silently disable Hessian loss even when `"hessian"` is present. Please update this so that missing flags are surfaced (or, if you really intend an optional flag, default to a matching array of ones).

- File `deepmd/dpmodel/loss/ener.py`, line 434

```diff
-                find_hessian = label_dict.get("find_hessian", 0.0)
+                find_hessian = label_dict["find_hessian"]
```

If you do need a default, consider instead:

```python
find_hessian = label_dict.get(
    "find_hessian",
    xp.ones(label_dict["hessian"].shape[0])  # match per-frame flag shape
)
```
🤖 Prompt for AI Agents
In deepmd/dpmodel/loss/ener.py at line 434, the assignment of find_hessian uses
label_dict.get with a scalar default 0.0, which is inconsistent with other
find_* variables that do not use defaults and are arrays. To fix this, remove
the default value so that missing keys raise an error or, if a default is
necessary, set it to an array of ones matching the shape of
label_dict["hessian"]. This ensures data type consistency and proper handling of
the find_hessian flag.
```python
            optimizer.update(grads)

        start_time = time.time()
        disp_file_fp = open(self.disp_file, "w")
```
Use context manager for file operations.
The file disp_file_fp
should be opened using a context manager to ensure proper cleanup even if an exception occurs.
Apply this diff:

```diff
-        disp_file_fp = open(self.disp_file, "w")
-        for step in range(self.start_step, self.num_steps):
+        with open(self.disp_file, "w") as disp_file_fp:
+            for step in range(self.start_step, self.num_steps):
                 # ... (training loop)
-
-        disp_file_fp.close()
```

Note: You'll need to indent the entire training loop inside the context manager.
Also applies to: 382-382
🧰 Tools
🪛 Ruff (0.11.9)

269-269: Use a context manager for opening files (SIM115)
🤖 Prompt for AI Agents
In deepmd/jax/train/trainer.py at lines 269 and 382, the file disp_file_fp is
opened without a context manager, risking resource leaks if exceptions occur.
Refactor the code to open disp_file_fp using a with statement as a context
manager, and indent the entire training loop or relevant code block inside this
with block to ensure the file is properly closed after use.
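The shape of the suggested refactor, reduced to a standalone sketch; the file name and loop body are illustrative, not the trainer's actual logic:

```python
def run_training_loop(disp_file, num_steps):
    """Write per-step display output; the context manager guarantees the
    file is closed even if an exception is raised inside the loop."""
    with open(disp_file, "w") as disp_file_fp:
        for step in range(num_steps):
            disp_file_fp.write(f"step {step}\n")
    # disp_file_fp is closed here; no explicit .close() is needed,
    # and the file is also closed if the loop body raises.
```

Compared with a bare `open()`/`close()` pair, this cannot leak the file handle when the loop raises mid-training, and it flushes buffered output on exit.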
```python
                valid_batch_data = valid_data.get_batch()
                jax_valid_data = {
                    kk: jnp.asarray(vv) for kk, vv in valid_batch_data.items()
                }
```
Apply consistent data conversion for validation data.
The validation data conversion should check for keys starting with "find_" similar to the training data conversion.
Apply this diff:
Apply this diff:

```diff
                 jax_valid_data = {
-                    kk: jnp.asarray(vv) for kk, vv in valid_batch_data.items()
+                    kk: jnp.asarray(vv) if not kk.startswith("find_") else bool(vv.item())
+                    for kk, vv in valid_batch_data.items()
                 }
```
🤖 Prompt for AI Agents
In deepmd/jax/train/trainer.py around lines 320 to 323, the validation data
conversion to JAX arrays should be consistent with the training data conversion
by checking if keys start with "find_". Update the dictionary comprehension to
convert values to jnp.asarray only for keys starting with "find_", leaving other
keys unchanged.
Codecov Report

Attention: Patch coverage is

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##            devel    #4782      +/-   ##
==========================================
- Coverage   84.79%   84.40%   -0.40%
==========================================
  Files         698      702       +4
  Lines       67775    68126     +351
  Branches     3544     3542       -2
==========================================
+ Hits        57472    57499      +27
- Misses       9169     9494     +325
+ Partials     1134     1133       -1
```

☔ View full report in Codecov by Sentry.
1. For dpmodel, pt, and pd, pass the trainable parameter to the layer (not actually used in this PR). 2. For JAX, support the `trainable` parameter in the layer. Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
for more information, see https://pre-commit.ci
```diff
@@ -22,7 +22,7 @@
 )


-def deserialize_to_file(model_file: str, data: dict) -> None:
+def deserialize_to_file(model_file: str, data: dict, hessian: bool = False) -> None:
```
Check notice — Code scanning / CodeQL: Explicit returns mixed with implicit (fall through) returns (Note)
Actionable comments posted: 1
♻️ Duplicate comments (4)
deepmd/jax/train/trainer.py (2)
274-274
: Use context manager for file operations.The file
disp_file_fp
should be opened using a context manager to ensure proper cleanup even if an exception occurs.Apply this diff:
- disp_file_fp = open(self.disp_file, "w") - for step in range(self.start_step, self.num_steps): + with open(self.disp_file, "w") as disp_file_fp: + for step in range(self.start_step, self.num_steps): # ... (training loop content should be indented) - disp_file_fp.close()Note: You'll need to indent the entire training loop inside the context manager.
326-328: Apply consistent data conversion for validation data.

The validation data conversion should check for keys starting with "find_" similar to the training data conversion. Apply this diff:

```diff
 jax_valid_data = {
-    kk: jnp.asarray(vv) for kk, vv in valid_batch_data.items()
+    kk: jnp.asarray(vv) if not kk.startswith("find_") else bool(vv.item())
+    for kk, vv in valid_batch_data.items()
 }
```

deepmd/dpmodel/loss/ener.py (2)
409-409: Fix logical condition for enabling Hessian loss.

The condition uses `and`, which requires both prefactors to be non-zero. This is likely too restrictive: the Hessian loss should be enabled if either prefactor is non-zero. Apply this fix:

```diff
-        self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
+        self.has_h = start_pref_h != 0.0 or limit_pref_h != 0.0
```
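A minimal check makes the difference concrete. The prefactor names mirror the ones in the diff; the scenario of a loss term that starts at a finite weight and decays to zero is the schedule `and` would wrongly drop:

```python
def enabled_with_and(start_pref_h, limit_pref_h):
    return start_pref_h != 0.0 and limit_pref_h != 0.0

def enabled_with_or(start_pref_h, limit_pref_h):
    return start_pref_h != 0.0 or limit_pref_h != 0.0

# a Hessian loss that starts at weight 1.0 and decays to 0.0 is a
# legitimate schedule, but `and` silently disables it
print(enabled_with_and(1.0, 0.0))  # False: loss wrongly dropped
print(enabled_with_or(1.0, 0.0))   # True: loss kept

# both versions agree that an all-zero schedule should be off
assert enabled_with_or(0.0, 0.0) is False
```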
434-434: Ensure `find_hessian` is accessed like the other `find_*` flags.

The other `find_*` variables are pulled directly from `label_dict` (no default), so using `.get(..., 0.0)` here is inconsistent and may silently disable Hessian loss even when `"hessian"` is present. Apply this fix:

```diff
-        find_hessian = label_dict.get("find_hessian", 0.0)
+        find_hessian = label_dict["find_hessian"]
```
🧹 Nitpick comments (4)

deepmd/jax/train/trainer.py (4)

288-289: Simplify `.get()` calls. The explicit `None` default is unnecessary since it is already the default for `.get()`.

```diff
- fparam=jax_data.get("fparam", None),
- aparam=jax_data.get("aparam", None),
+ fparam=jax_data.get("fparam"),
+ aparam=jax_data.get("aparam"),
```

338-339: Simplify `.get()` calls. The same cleanup applies to the validation path.

```diff
- fparam=jax_valid_data.get("fparam", None),
- aparam=jax_valid_data.get("aparam", None),
+ fparam=jax_valid_data.get("fparam"),
+ aparam=jax_valid_data.get("aparam"),
```

401-401: Simplify dictionary iteration. Using `dict.keys()` is unnecessary when iterating over a dictionary.

```diff
- for k in valid_results.keys():
+ for k in valid_results:
```

406-406: Simplify dictionary iteration. Same cleanup for the training results.

```diff
- for k in train_results.keys():
+ for k in train_results:
```
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (10)

- deepmd/dpmodel/descriptor/dpa1.py (1 hunks)
- deepmd/dpmodel/loss/ener.py (3 hunks)
- deepmd/dpmodel/utils/env_mat_stat.py (1 hunks)
- deepmd/jax/entrypoints/freeze.py (1 hunks)
- deepmd/jax/infer/deep_eval.py (2 hunks)
- deepmd/jax/jax2tf/serialization.py (2 hunks)
- deepmd/jax/model/hlo.py (2 hunks)
- deepmd/jax/train/trainer.py (1 hunks)
- deepmd/jax/utils/serialization.py (5 hunks)
- deepmd/main.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- deepmd/dpmodel/utils/env_mat_stat.py
🚧 Files skipped from review as they are similar to previous changes (2)
- deepmd/dpmodel/descriptor/dpa1.py
- deepmd/jax/utils/serialization.py
🧰 Additional context used

🧠 Learnings (4)

📓 Common learnings

- Learnt from: njzjz (PR deepmodeling/deepmd-kit#4284, File deepmd/jax/__init__.py:8-8, 2024-10-30T20:08:12.531Z): In the DeepMD project, entry points like `deepmd.jax` may be registered in external projects, so their absence in the local configuration files is acceptable.

deepmd/jax/model/hlo.py (5)

- Learnt from: 1azyking (PR #4169, File deepmd/pt/loss/ener_hess.py:341-348, 2024-10-08T15:32:11.479Z): In `deepmd/pt/loss/ener_hess.py`, the `label` uses the key `"atom_ener"` intentionally to maintain consistency with the forked version.
- Learnt from: 1azyking (PR #4169, File deepmd/pt/loss/ener_hess.py:341-348, 2024-10-05T03:11:02.922Z): In `deepmd/pt/loss/ener_hess.py`, the `label` uses the key `"atom_ener"` intentionally to maintain consistency with the forked version.
- Learnt from: 1azyking (PR #4169, File deepmd/utils/argcheck.py:1982-2117, 2024-10-05T03:06:02.372Z): The `loss_ener_hess` and `loss_ener` functions should remain separate to avoid confusion, despite code duplication.
- Learnt from: 1azyking (PR #4169, File deepmd/pt/model/model/ener_hess_model.py:127-144, 2024-10-08T15:32:11.479Z): Training with energy Hessian is not required for lower-level operations as of now.
- Learnt from: 1azyking (PR #4169, File deepmd/pt/model/model/ener_hess_model.py:127-144, 2024-10-05T02:46:06.925Z): Training with energy Hessian is not required for lower-level operations as of now.

deepmd/dpmodel/loss/ener.py (5)

- Learnt from: 1azyking (PR #4169, File deepmd/pt/loss/ener_hess.py:341-348, 2024-10-08T15:32:11.479Z): In `deepmd/pt/loss/ener_hess.py`, the `label` uses the key `"atom_ener"` intentionally to maintain consistency with the forked version.
- Learnt from: 1azyking (PR #4169, File deepmd/pt/loss/ener_hess.py:341-348, 2024-10-05T03:11:02.922Z): In `deepmd/pt/loss/ener_hess.py`, the `label` uses the key `"atom_ener"` intentionally to maintain consistency with the forked version.
- Learnt from: 1azyking (PR #4169, File deepmd/utils/argcheck.py:1982-2117, 2024-10-05T03:06:02.372Z): The `loss_ener_hess` and `loss_ener` functions should remain separate to avoid confusion, despite code duplication.
- Learnt from: 1azyking (PR #4169, File deepmd/pt/model/model/ener_hess_model.py:127-144, 2024-10-08T15:32:11.479Z): Training with energy Hessian is not required for lower-level operations as of now.
- Learnt from: 1azyking (PR #4169, File deepmd/pt/model/model/ener_hess_model.py:127-144, 2024-10-05T02:46:06.925Z): Training with energy Hessian is not required for lower-level operations as of now.

deepmd/jax/train/trainer.py (2)

- Learnt from: 1azyking (PR #4169, File examples/hessian/single-task/input.json:3-59, 2024-10-05T02:57:02.517Z): In the deepmd-kit project, having different `rcut` values for `repinit` and `repformer` is acceptable when they align with existing configurations, such as in `examples/water/dpa2/input_torch.json`.
- Learnt from: 1azyking (PR #4169, File examples/hessian/single-task/input.json:3-59, 2024-10-08T15:32:11.479Z): In the deepmd-kit project, having different `rcut` values for `repinit` and `repformer` is acceptable when they align with existing configurations, such as in `examples/water/dpa2/input_torch.json`.
🪛 Ruff (0.12.2)

deepmd/jax/train/trainer.py

- 274-274: Use a context manager for opening files (SIM115)
- 288-288: Use `jax_data.get("fparam")` instead of `jax_data.get("fparam", None)` (SIM910)
- 289-289: Use `jax_data.get("aparam")` instead of `jax_data.get("aparam", None)` (SIM910)
- 338-338: Use `jax_valid_data.get("fparam")` instead of `jax_valid_data.get("fparam", None)` (SIM910)
- 339-339: Use `jax_valid_data.get("aparam")` instead of `jax_valid_data.get("aparam", None)` (SIM910)
- 401-401: Use `key in dict` instead of `key in dict.keys()`; remove `.keys()` (SIM118)
- 406-406: Use `key in dict` instead of `key in dict.keys()`; remove `.keys()` (SIM118)
🔇 Additional comments (19)

deepmd/main.py (1)

338-343: LGTM! Clean CLI integration for Hessian support.

The addition of the `--hessian` flag to the freeze subcommand is well-implemented:

- Follows argparse conventions with appropriate action and default
- Clear help text explaining the functionality
- Integrates properly with the broader Hessian support in the JAX backend
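The flag wiring described above can be sketched with a toy parser. The subcommand and flag names mirror the review's description, but this is an illustration, not DeePMD-kit's actual `main.py`:

```python
import argparse

parser = argparse.ArgumentParser(prog="dp")
subparsers = parser.add_subparsers(dest="command")
freeze = subparsers.add_parser("freeze", help="freeze a trained model")
freeze.add_argument(
    "--hessian",
    action="store_true",
    default=False,
    help="enable the Hessian output in the frozen model",
)

args = parser.parse_args(["freeze", "--hessian"])
print(args.command, args.hessian)  # freeze True

# the flag defaults to False when omitted, preserving old behavior
assert parser.parse_args(["freeze"]).hessian is False
```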
deepmd/jax/jax2tf/serialization.py (3)

24-24: LGTM! Well-designed function signature extension.

The addition of the `hessian` parameter with a default value of `False` maintains backward compatibility while enabling the new functionality.

33-34: LGTM! Clear documentation.

The parameter documentation is concise and accurately describes the functionality.

39-41: LGTM! Correct Hessian mode enablement.

The conditional logic properly:

- Enables Hessian mode on the model instance
- Updates the model definition script with the Hessian flag
- Only executes when the hessian parameter is True
deepmd/jax/model/hlo.py (2)

34-41: LGTM! Properly structured Hessian output definition.

The `energy_hessian` output definition correctly:

- Mirrors the standard `energy` definition structure
- Adds the `r_hessian=True` attribute to enable Hessian computations
- Uses consistent naming and shape specifications

178-191: LGTM! Smart conditional output selection.

The modified `model_output_def` method elegantly handles Hessian mode by:

- Checking the `hessian_mode` flag in the model definition script
- Only applying the transformation for the `energy` output type
- Using string formatting to select the appropriate output definition

The logic correctly falls back to the standard output when Hessian mode is disabled or for non-energy outputs.
deepmd/jax/entrypoints/freeze.py (2)

12-18: LGTM! Well-designed function signature.

The function design is excellent:

- Uses keyword-only arguments (`*,`), which prevents positional argument misuse
- Provides a sensible default for the `hessian` parameter
- Accepts `**kwargs` for extensibility

32-39: LGTM! Robust checkpoint resolution and error handling.

The implementation correctly:

- Handles checkpoint file indirection by reading the checkpoint metadata
- Validates directory existence before processing
- Provides clear error messages with the actual checkpoint path
- Passes the hessian flag through to the serialization layer
deepmd/jax/infer/deep_eval.py (3)

281-281: LGTM! Consistent integration of Hessian output category.

The addition of `OutputVariableCategory.DERV_R_DERV_R` to the non-atomic output categories correctly enables Hessian matrix computation in inference.

414-416: LGTM! Correct Hessian matrix shape calculation.

The shape `[nframes, 3 * natoms, 3 * natoms]` is mathematically correct for the Hessian matrix, representing the second derivatives with respect to all atomic coordinates.
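The shape can be sanity-checked on a toy energy surface. The central-difference Hessian below is an illustration only (the JAX backend differentiates the model directly rather than using finite differences):

```python
import numpy as np

def energy(coords):
    # toy quadratic energy: E = sum_i |r_i|^2, whose exact Hessian is 2 * I
    return float(np.sum(coords**2))

def numerical_hessian(f, coords, eps=1e-4):
    x0 = coords.ravel()
    n = x0.size  # n == 3 * natoms
    h = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            # central second difference for d^2 E / dx_i dx_j
            def shifted(si, sj):
                x = x0.copy()
                x[i] += si * eps
                x[j] += sj * eps
                return f(x.reshape(coords.shape))
            h[i, j] = (
                shifted(1, 1) - shifted(1, -1) - shifted(-1, 1) + shifted(-1, -1)
            ) / (4.0 * eps**2)
    return h

natoms = 4
coords = np.random.default_rng(0).normal(size=(natoms, 3))
hess = numerical_hessian(energy, coords)
print(hess.shape)  # one frame gives a (3 * natoms, 3 * natoms) matrix
```

For a batch of frames, stacking one such matrix per frame yields exactly the `[nframes, 3 * natoms, 3 * natoms]` shape the review verifies.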
424-426: LGTM! Clean Hessian mode detection.

The `get_has_hessian` method correctly:

- Retrieves the model definition script
- Checks for the `hessian_mode` key with an appropriate default
- Returns a boolean indicating Hessian support
deepmd/jax/train/trainer.py (3)

1-63: LGTM! Well-organized imports.

The imports are comprehensive and well-structured, covering all necessary components for JAX-based training, including models, loss functions, utilities, and logging.

65-145: LGTM! Well-structured initialization.

The initialization properly handles different model loading scenarios and configures training parameters. The learning rate setup is clean and extensible for future learning rate types.

430-465: Excellent utility function implementation.

The `prepare_input` function is well-designed, with proper keyword-only arguments, and handles all necessary coordinate transformations for DeepMD models, including normalization, ghost atom extension, and neighbor list building.

deepmd/dpmodel/loss/ener.py (5)
180-182: Good improvement to loss reporting.

Changing from MSE to RMSE reporting makes the loss metrics more interpretable while keeping the underlying MSE-based loss computation for optimization. The implementation is mathematically correct.

194-196: Consistent RMSE reporting for force loss.

The force loss reporting is now consistent with energy loss, using RMSE instead of MSE for better interpretability.
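The relationship between the optimized quantity and the reported one is simple. This is a generic illustration, not the loss module's actual code:

```python
import numpy as np

pred = np.array([1.0, 2.0, 3.0])
label = np.array([1.5, 1.5, 3.5])

# MSE is what the optimizer minimizes; RMSE is what gets logged,
# because it has the same units as the underlying quantity
mse = float(np.mean((pred - label) ** 2))
rmse = float(np.sqrt(mse))
print(mse, rmse)  # 0.25 0.5
```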
390-413: Well-designed Hessian loss extension.

The `EnergyHessianLoss` class properly extends the base `EnergyLoss` class with configurable Hessian prefactors and appropriate parameter handling.

414-448: Solid Hessian loss computation implementation.

The Hessian loss computation correctly handles prefactor interpolation, conditional loss addition, and RMSE reporting. The logic for checking model and label dictionary keys before computation is appropriate.

449-464: Proper label requirements extension.

The `label_requirement` property correctly extends the parent class requirements with Hessian data requirements when Hessian loss is enabled.
```diff
@@ -419,4 +419,8 @@ def _get_output_shape(self, odef, nframes, natoms):

     def get_model_def_script(self) -> dict:
         """Get model definition script."""
-        return json.loads(self.dp.get_model_def_script())
+        return self.dp.get_model_def_script()
```
💡 Verification agent

🧩 Analysis chain

Verify impact of removing JSON parsing in get_model_def_script.

The change removes JSON parsing and returns the raw model definition script. This could affect downstream code that expects a parsed dictionary.

🏁 Script executed:

```shell
#!/bin/bash
# Search for usages of get_model_def_script to verify they can handle raw script
rg -A 5 "get_model_def_script\(\)" --type py
```

Length of output: 6125

Action Required: Restore JSON parsing in JAX infer’s get_model_def_script

The call to `self.dp.get_model_def_script()` now returns a raw JSON string, but downstream code in this module (e.g. `get_has_hessian`) still does `.get(...)`. We need to re-introduce `json.loads` so callers continue to receive a dictionary.

Please update `deepmd/jax/infer/deep_eval.py`:

- Add `import json` at the top (if not already present).
- Change the method on line 422 from:

  ```python
  return self.dp.get_model_def_script()
  ```

  to:

  ```python
  return json.loads(self.dp.get_model_def_script())
  ```

This ensures that:

- `get_model_def_script()` returns a `dict` again, matching the expectations of:

  ```python
  model_def_script = self.get_model_def_script()
  return model_def_script.get("hessian_mode", False)
  ```

- Any other callers in this file relying on dict-style access remain functional.
🤖 Prompt for AI Agents
In deepmd/jax/infer/deep_eval.py at line 422, the method get_model_def_script
currently returns a raw JSON string, but downstream code expects a dictionary.
To fix this, add "import json" at the top of the file if not already present,
then change the return statement on line 422 to return
json.loads(self.dp.get_model_def_script()) so that the method returns a parsed
dictionary as expected.
Summary by CodeRabbit
New Features
Enhancements
Other Changes