feat(dpmodel): add backend-independent trainer abstraction#5603
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds shared training abstractions for entrypoints, trainers, task normalization, and validation. Refactors JAX and pt_expt training flows to use the shared pipeline. Consolidates finetune rule handling into shared utilities. Extends JAX checkpoint serialization for multitask models. ChangesShared Training Abstractions and Backend Integrations
Estimated code review effort: 5 (Critical) | ~120 minutes Possibly related PRs
Suggested reviewers: 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 4
🧹 Nitpick comments (1)
source/tests/test_dpmodel_abstract_trainer.py (1)
99-109: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winAvoid hard-coding the default task key in this test.
TrainingTaskCollection.single()owns the default key viaDEFAULT_TASK_KEY, so indexing with"Default"makes this test fail on an internal rename without any behavior change. Pull the lone task from the collection API instead.Proposed fix
- task = tasks["Default"] + task = tasks.select() task.add_data_requirements() assert not tasks.is_multitask assert tasks.select() is task🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/test_dpmodel_abstract_trainer.py` around lines 99 - 109, The test is hard-coding the default task name instead of using the collection API. In test_dpmodel_abstract_trainer.py, update the TrainingTaskCollection.single() usage so the lone task is retrieved without indexing by "Default", and reference the collection’s default-task behavior through its API/select() rather than the literal key. Keep the assertions on task.add_data_requirements(), tasks.is_multitask, and tasks.select() intact, but bind task from the collection in a way that won’t break if DEFAULT_TASK_KEY is renamed.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/train/trainer.py`:
- Around line 380-383: Keep the single-task table schema anchored to
train_results in the Trainer formatting path. The header logic and the row
emission in the same train/eval summary should use the exact same key sequence
from train_results, not valid_results, so the validation columns stay aligned
and missing or reordered validation metrics do not break train_results lookups.
Update the loop in the Trainer row-building code to derive the metric order from
train_results and only read matching validation values by key when present.
- Around line 148-161: The task normalization in the trainer constructor
silently overwrites duplicate entries when `tasks` is a sequence, so add an
explicit duplicate-key check before building `self._tasks` in
`Trainer.__init__`. Update the branch that handles non-Mapping `tasks` to
validate that each `task.key` appears only once, and raise a clear `ValueError`
if duplicates are found. Keep the existing key-matching validation and
`_normalize_probabilities` flow intact.
In `@deepmd/jax/train/trainer.py`:
- Around line 188-194: The training setup in
AbstractTrainer.on_train_begin()/train() is missing registration of loss label
requirements before creating the task collection. Update the training flow
around TrainingTaskCollection.single and self.run(tasks) so
DeepmdDataSystem.add_data_requirements() is called with
self.loss.label_requirement (or equivalent task data_requirements) before
batching starts, ensuring get_batch() includes the labels needed by the loss
path.
In `@source/tests/test_dpmodel_abstract_trainer.py`:
- Around line 53-64: The evaluate_training helper is consuming a batch for
inactive multitask entries by falling back to task.training_data.get_batch()
when step_result is None, which advances the cursor during display collection.
Update evaluate_training in test_dpmodel_abstract_trainer.py so it only reads
step_result.payload for the matching task and otherwise returns a non-consuming
placeholder/skip path for inactive tasks; keep the batch-order contract aligned
with the multitask loop in trainer.py and the TrainStepResult/task.key check.
---
Nitpick comments:
In `@source/tests/test_dpmodel_abstract_trainer.py`:
- Around line 99-109: The test is hard-coding the default task name instead of
using the collection API. In test_dpmodel_abstract_trainer.py, update the
TrainingTaskCollection.single() usage so the lone task is retrieved without
indexing by "Default", and reference the collection’s default-task behavior
through its API/select() rather than the literal key. Keep the assertions on
task.add_data_requirements(), tasks.is_multitask, and tasks.select() intact, but
bind task from the collection in a way that won’t break if DEFAULT_TASK_KEY is
renamed.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: ae088f30-bef5-4da4-ae02-40f16ee87975
📒 Files selected for processing (4)
deepmd/dpmodel/train/__init__.pydeepmd/dpmodel/train/trainer.pydeepmd/jax/train/trainer.pysource/tests/test_dpmodel_abstract_trainer.py
d100673 to
6e168b7
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (2)
deepmd/dpmodel/train/trainer.py (2)
422-428: 🩺 Stability & Availability | 🟠 Major | ⚡ Quick winSingle-task row schema should anchor to
train_results, notvalid_results.
format_headeriteratestrain_results(Line 382), but the row loop here iteratesvalid_resultsand indexestrain_results[key]. If a backend returns validation metrics in a different order or omits a metric, the row desynchronizes from the header andtrain_results[key]can raiseKeyError.Suggested fix
if valid_results is not None: assert not self._is_multitask(valid_results) - for key in valid_results: + for key in train_results: row += ( - f" {float(valid_results[key]):11.2e}" + f" {float(valid_results.get(key, float('nan'))):11.2e}" f" {float(train_results[key]):11.2e}" )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/dpmodel/train/trainer.py` around lines 422 - 428, The single-task row assembly in `Trainer.format_*` is anchored to the wrong metric source: the header follows `train_results`, but this row loop currently iterates `valid_results` while reading `train_results[key]`, which can desynchronize columns or fail when validation metrics differ. Update the row-building logic to use the same key order as `train_results` (matching the existing `format_header` contract) and only use `valid_results` for the optional validation value lookup, keeping the schema consistent with the training metrics.
148-161: 🎯 Functional Correctness | 🟠 Major | ⚡ Quick winDuplicate task keys are still silently dropped when
tasksis a sequence.
task_dict = {task.key: task for task in tasks}overwrites earlier entries, so a misconfigured multitask run can lose a task with no error. Validate uniqueness before building_tasks.Suggested fix
if isinstance(tasks, Mapping): task_dict = dict(tasks) else: - task_dict = {task.key: task for task in tasks} + task_list = list(tasks) + task_dict = {task.key: task for task in task_list} + if len(task_dict) != len(task_list): + raise ValueError("Training task keys must be unique.")🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/dpmodel/train/trainer.py` around lines 148 - 161, The task normalization in Trainer.__init__ still silently overwrites duplicate keys when tasks is a sequence because task_dict is built directly from task.key; add an explicit duplicate-key check before assigning to self._tasks. In the branch that handles non-Mapping tasks, validate that each task.key is unique (raise a ValueError on duplicates) before constructing the dict, while keeping the existing key-vs-task.key consistency check and downstream _normalize_probabilities flow unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Line 831: Wire the parsed retention setting into checkpoint cleanup so
`max_ckpt_keep` actually limits saved checkpoints: `TrainerConfig` already
carries the value and `AbstractTrainer.run()`/`save_checkpoint()` currently only
append new `self.save_ckpt-<step>.pt` files. Update the checkpoint-saving flow
to track existing checkpoints and prune the oldest ones after each save, using
`self.max_ckpt_keep` as the cap and keeping the newest checkpoints only.
---
Duplicate comments:
In `@deepmd/dpmodel/train/trainer.py`:
- Around line 422-428: The single-task row assembly in `Trainer.format_*` is
anchored to the wrong metric source: the header follows `train_results`, but
this row loop currently iterates `valid_results` while reading
`train_results[key]`, which can desynchronize columns or fail when validation
metrics differ. Update the row-building logic to use the same key order as
`train_results` (matching the existing `format_header` contract) and only use
`valid_results` for the optional validation value lookup, keeping the schema
consistent with the training metrics.
- Around line 148-161: The task normalization in Trainer.__init__ still silently
overwrites duplicate keys when tasks is a sequence because task_dict is built
directly from task.key; add an explicit duplicate-key check before assigning to
self._tasks. In the branch that handles non-Mapping tasks, validate that each
task.key is unique (raise a ValueError on duplicates) before constructing the
dict, while keeping the existing key-vs-task.key consistency check and
downstream _normalize_probabilities flow unchanged.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 2ee2703c-74d9-48e0-8c5e-ce031e73c19a
📒 Files selected for processing (5)
deepmd/dpmodel/train/__init__.pydeepmd/dpmodel/train/trainer.pydeepmd/jax/train/trainer.pydeepmd/pt_expt/train/training.pysource/tests/test_dpmodel_abstract_trainer.py
✅ Files skipped from review due to trivial changes (1)
- deepmd/dpmodel/train/init.py
🚧 Files skipped from review as they are similar to previous changes (2)
- source/tests/test_dpmodel_abstract_trainer.py
- deepmd/jax/train/trainer.py
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #5603 +/- ##
==========================================
- Coverage 81.98% 81.85% -0.14%
==========================================
Files 959 966 +7
Lines 105430 106745 +1315
Branches 4071 4102 +31
==========================================
+ Hits 86442 87376 +934
- Misses 17518 17875 +357
- Partials 1470 1494 +24 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
6e168b7 to
78272a7
Compare
There was a problem hiding this comment.
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/train/entrypoint.py`:
- Around line 79-83: The cleanup guard in the entrypoint flow is starting too
late, so `teardown_run()` will not run if `setup_run()` fails after partially
initializing backend state. Move the `try/finally` in the main training
entrypoint so it wraps `setup_run()` itself, ensuring `teardown_run()` is always
invoked even when setup raises. Use the existing `setup_run()` and
`teardown_run()` flow in this module as the anchor when making the change.
- Line 48: The AbstractTrainEntrypoint contract is too implicit: it currently
has no abstract methods and the hook methods are docstring-only, which triggers
lint issues. Update AbstractTrainEntrypoint so run_training() is marked
abstract, and give the optional hook methods explicit no-op bodies with return
None unless they must be overridden by every backend. Use the class and method
names AbstractTrainEntrypoint and run_training() to locate the contract and hook
definitions.
In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 295-297: Replace the existing assert-based validation in the
multi-task config handling with an explicit runtime check that raises ValueError
when the config is "RANDOM"; this ensures the validation in the main entrypoint
is enforced even under optimized Python runs. Update the check in the entrypoint
logic that handles the multi-task configuration so the rejection of "RANDOM" is
always consistent.
- Around line 257-259: `setup_run` and `teardown_run` need ownership tracking
for the distributed process group: make `setup_run` skip `init_process_group()`
when a default group already exists, and record whether this entrypoint created
the group. Then update `teardown_run` to call `destroy_process_group()` only
when it is cleaning up a group created by `setup_run`, so it does not interfere
with a caller-managed distributed context.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: ec38d095-ecf4-4024-ac9f-7523f8d40e60
📒 Files selected for processing (11)
deepmd/dpmodel/train/__init__.pydeepmd/dpmodel/train/entrypoint.pydeepmd/dpmodel/train/trainer.pydeepmd/jax/entrypoints/train.pydeepmd/jax/train/trainer.pydeepmd/pt_expt/entrypoints/main.pydeepmd/pt_expt/train/training.pysource/tests/jax/test_training.pysource/tests/pt_expt/test_entrypoint.pysource/tests/test_dpmodel_abstract_trainer.pysource/tests/test_dpmodel_train_entrypoint.py
🚧 Files skipped from review as they are similar to previous changes (4)
- deepmd/dpmodel/train/init.py
- source/tests/test_dpmodel_abstract_trainer.py
- deepmd/jax/train/trainer.py
- deepmd/pt_expt/train/training.py
78272a7 to
5bebfdf
Compare
There was a problem hiding this comment.
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
deepmd/jax/utils/serialization.py (1)
189-207: 🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick winPreserve string task keys when converting restored state keys.
convert_str_to_int_key(state)also rewrites digit-only task names understate["models"]. A valid multi-task key like"0"becomes0, thenstate_by_model[model_key]fails becausemodel_def_script["model_dict"]still uses"0".Proposed fix
- convert_str_to_int_key(state) - model_def_script = data.model_def_script if "model_dict" in model_def_script: state_by_model = state.get("models", state) + if "models" in state: + for model_state in state_by_model.values(): + convert_str_to_int_key(model_state) + else: + convert_str_to_int_key(state_by_model) model_dict = {"model_dict": {}} for model_key, model_params in model_def_script["model_dict"].items(): abstract_model = get_model(model_params) graphdef, abstract_state = nnx.split(abstract_model) abstract_state.replace_by_pure_dict(state_by_model[model_key]) model = nnx.merge(graphdef, abstract_state) model_dict["model_dict"][model_key] = model.serialize() else: + convert_str_to_int_key(state) abstract_model = get_model(model_def_script)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/jax/utils/serialization.py` around lines 189 - 207, The state restoration logic in convert_str_to_int_key is too aggressive and converts digit-only task names inside state["models"], which breaks later lookup in the model merge path. Update the key conversion so it only normalizes keys that represent numeric indices where needed, while preserving string task keys used by model_def_script["model_dict"]; then ensure the state_by_model lookup in the model merging section continues to use the original task-name strings.deepmd/pt_expt/train/training.py (1)
1509-1515: 🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick winMake the
latestsymlink target relative to its own directory.When
save_ckptincludes a directory, e.g.out/model.ckpt,latest.symlink_to("out/model.ckpt-1.pt")createsout/model.ckpt.pt -> out/model.ckpt-1.pt, which resolves asout/out/model.ckpt-1.pt. Restarting from the prefix then follows a broken symlink.Proposed fix
- ckpt_path = f"{self.save_ckpt}-{step}.pt" + ckpt_path = Path(f"{self.save_ckpt}-{step}.pt") torch.save(state, ckpt_path) # symlink latest latest = Path(f"{self.save_ckpt}.pt") if latest.is_symlink() or latest.exists(): latest.unlink() - latest.symlink_to(ckpt_path) + latest.symlink_to(ckpt_path.name)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt_expt/train/training.py` around lines 1509 - 1515, The `latest` symlink in `training.py` is created with a target that can become incorrectly resolved when `save_ckpt` includes a directory. Update the checkpoint-saving logic in the block that builds `ckpt_path`, `latest = Path(f"{self.save_ckpt}.pt")`, and calls `latest.symlink_to(...)` so the symlink target is computed relative to `latest`’s parent directory instead of using the full path string. Make sure the existing `torch.save` flow and cleanup of any prior `latest` link remain unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/train/data.py`:
- Around line 120-124: The fallback in _print_summary is too broad because it
catches every TypeError from data.print_summary and retries without prob, which
can hide real failures. Update _print_summary to detect the supported call
signature first (or only retry on an explicit argument-count mismatch) so only
old print_summary implementations use the no-prob path while genuine TypeError
exceptions still propagate.
In `@deepmd/dpmodel/train/trainer.py`:
- Around line 575-579: The optional ABC hook methods on Training hooks are
docstring-only and trigger Ruff B027; update the on_train_begin and on_train_end
methods in TrainingTask/Trainer-related code to use explicit no-op bodies by
adding return None so they remain optional without failing lint.
- Around line 472-479: The cleanup guard starts too late in the training setup
path, so resources created by on_train_begin() may not be released if
_open_learning_curve() fails. Move the try/finally in Trainer.train() (or the
surrounding training entry point) so it begins before on_train_begin(tasks), and
keep on_train_end(tasks) in the finally block to ensure backend state is always
cleaned up even when setup throws.
In `@deepmd/dpmodel/utils/training_utils.py`:
- Around line 124-130: The `_training_data_size` helper is swallowing all
`TypeError` from `len(training_data)` and turning broken `__len__`
implementations into a fallback size of 1. Update `_training_data_size` so it
only returns 1 for objects that truly do not support sizing, and allow real
`__len__` failures to propagate instead of masking them; keep the existing
`get_nsystems` path and the `resolve_model_prob()` callers in mind when
adjusting the behavior.
In `@deepmd/jax/utils/finetune.py`:
- Around line 50-53: The single-task pretrained checkpoint path in finetune.py
silently ignores non-empty model_branch_from values other than RANDOM, which can
hide typos and fall back to Default unexpectedly. Update the branch-selection
logic in the finetuning flow around single_config_chosen/model_branch_from to
explicitly validate the branch name when from_multitask is false, and raise an
error for any unknown non-empty value instead of proceeding. Keep RANDOM as the
only special case that enables new_fitting, and ensure the check happens before
the code falls back to Default.
- Around line 152-157: The pre-check in finetune setup is rejecting valid
finetune_head aliases because it compares against raw pretrained_keys before
alias resolution. Update the validation around pretrained_key in the finetune
flow so it uses the same alias-aware mapping as _get_finetune_rule_single() via
get_model_dict(), and only raise the ValueError after checking the resolved
model dict entries rather than the unexpanded pretrained_keys list.
In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 167-177: The _ensure_stat_file_path helper currently creates HDF5
stat files with h5py.File before ensuring the parent directory exists, so paths
like stats/model_stat.h5 can fail during trainer setup. Update
_ensure_stat_file_path to create the parent directories first for file targets,
then open the HDF5 file, and keep the existing directory creation path unchanged
for non-HDF5 stat targets. Use the stat_file_path, Path, and h5py.File logic in
this function to apply the fix.
---
Outside diff comments:
In `@deepmd/jax/utils/serialization.py`:
- Around line 189-207: The state restoration logic in convert_str_to_int_key is
too aggressive and converts digit-only task names inside state["models"], which
breaks later lookup in the model merge path. Update the key conversion so it
only normalizes keys that represent numeric indices where needed, while
preserving string task keys used by model_def_script["model_dict"]; then ensure
the state_by_model lookup in the model merging section continues to use the
original task-name strings.
In `@deepmd/pt_expt/train/training.py`:
- Around line 1509-1515: The `latest` symlink in `training.py` is created with a
target that can become incorrectly resolved when `save_ckpt` includes a
directory. Update the checkpoint-saving logic in the block that builds
`ckpt_path`, `latest = Path(f"{self.save_ckpt}.pt")`, and calls
`latest.symlink_to(...)` so the symlink target is computed relative to
`latest`’s parent directory instead of using the full path string. Make sure the
existing `torch.save` flow and cleanup of any prior `latest` link remain
unchanged.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 78180874-4d87-4c66-bd4e-fd0b4ee961bb
📒 Files selected for processing (15)
deepmd/dpmodel/train/__init__.pydeepmd/dpmodel/train/data.pydeepmd/dpmodel/train/entrypoint.pydeepmd/dpmodel/train/trainer.pydeepmd/dpmodel/utils/training_utils.pydeepmd/jax/entrypoints/train.pydeepmd/jax/train/trainer.pydeepmd/jax/utils/finetune.pydeepmd/jax/utils/serialization.pydeepmd/pt_expt/entrypoints/main.pydeepmd/pt_expt/train/training.pysource/tests/jax/test_training.pysource/tests/pt_expt/test_entrypoint.pysource/tests/test_dpmodel_abstract_trainer.pysource/tests/test_dpmodel_train_entrypoint.py
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/dpmodel/train/init.py
5bebfdf to
9485193
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/jax/utils/finetune.py`:
- Around line 21-25: The suffix check in _load_model_params is too restrictive
and rejects checkpoint directories or pointer paths that serialize_from_file can
already deserialize. Remove the hard-coded .jax validation in _load_model_params
and let serialize_from_file(finetune_model) handle the input format consistently
with the freeze() and init_model paths, while still extracting
"model_def_script" from the loaded checkpoint data.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: cd8a6c26-b867-4429-892e-f4aa7e5fbc5e
📒 Files selected for processing (22)
deepmd/dpmodel/train/__init__.pydeepmd/dpmodel/train/data.pydeepmd/dpmodel/train/entrypoint.pydeepmd/dpmodel/train/trainer.pydeepmd/dpmodel/utils/training_utils.pydeepmd/jax/entrypoints/train.pydeepmd/jax/train/trainer.pydeepmd/jax/utils/finetune.pydeepmd/jax/utils/serialization.pydeepmd/pd/utils/finetune.pydeepmd/pt/utils/finetune.pydeepmd/pt_expt/entrypoints/main.pydeepmd/pt_expt/train/training.pydeepmd/pt_expt/utils/finetune.pydeepmd/utils/finetune.pysource/tests/common/dpmodel/test_train_abstract_trainer.pysource/tests/common/dpmodel/test_train_data.pysource/tests/common/dpmodel/test_train_entrypoint.pysource/tests/common/dpmodel/test_training_utils.pysource/tests/common/test_finetune_utils.pysource/tests/jax/test_training.pysource/tests/pt_expt/test_entrypoint.py
🚧 Files skipped from review as they are similar to previous changes (8)
- deepmd/dpmodel/utils/training_utils.py
- deepmd/dpmodel/train/init.py
- deepmd/jax/utils/serialization.py
- deepmd/dpmodel/train/data.py
- deepmd/pt_expt/train/training.py
- deepmd/jax/entrypoints/train.py
- deepmd/jax/train/trainer.py
- deepmd/pt_expt/entrypoints/main.py
| for model_key in self.model_keys: | ||
| tx = optax.adam( | ||
| learning_rate=lambda step: self.lr.value(self.start_step + step), | ||
| ) |
There was a problem hiding this comment.
Non-blocking (JAX multi-task LR schedule): this learning_rate=lambda step: self.lr.value(self.start_step + step) was correct for the pre-refactor single-task trainer (one optimizer, optimizer.update called once per global step, so optax's internal schedule count == global step). Inside this per-task loop it no longer holds: each task gets its own nnx.Optimizer, and optax drives the schedule by that optimizer's internal update count, which only advances when the task is sampled (~p_i * num_steps times). So in multi-task each task's LR decays too slowly and never reaches the configured minimum at num_steps.
It's also internally inconsistent: the loss-prefactor LR passed into train_step (self.lr.value(step), global step) and the lcurve LR (learning_rate(step), global step) both use the global step, while the optimizer's applied LR uses the per-task count — they diverge for every multi-task run. This also differs from pt/pt_expt, which advance the scheduler once per global step for all tasks.
Single-task is unaffected (counts coincide), and JAX multi-task is new here so it's not a regression — but it ships wrong. Suggest driving the optimizer schedule by the global step explicitly (so it matches the loss-prefactor/lcurve LR and the other backends) rather than relying on optax's per-optimizer internal count.
There was a problem hiding this comment.
Fixed in njzjz/deepmd-kit@7280c146. The JAX optimizer now separates Adam moment scaling from LR scaling and applies the learning rate passed from the global training step, so multi-task tasks no longer decay by per-task optimizer counts. Added a focused optax test for explicit global LR scaling.
| if hasattr(training_data, "get_nsystems"): | ||
| return int(training_data.get_nsystems()) |
There was a problem hiding this comment.
Non-blocking (test coverage): the get_nsystems branch here is the real-data path (e.g. the wrapped training-data systems), but the new test_training_utils.py only constructs objects exercising the Sized/len() and unsized-fallback branches — none expose get_nsystems, so this first branch is never run. The project convention is to cover every reachable branch of a new helper; a small test passing an object with a get_nsystems() method would close it.
There was a problem hiding this comment.
Fixed in njzjz/deepmd-kit@7280c146. Added coverage for the get_nsystems() branch and verified that it takes precedence over len for model probability defaults.
| def _ordered_task_keys(self, results: Mapping[str, Any]) -> list[str]: | ||
| keys = self.task_keys or list(results) | ||
| return [key for key in keys if key in results] | ||
|
|
There was a problem hiding this comment.
Non-blocking (test coverage): is_chief is the core of the multi-rank abstraction — it gates lcurve open, display, periodic save, and final checkpoint (the is_chief uses below in this loop). But every test in test_train_abstract_trainer.py uses the default RankContext() (rank 0, chief), so the is_chief == False branch — i.e. that a non-chief rank skips all of those — is never exercised. Worth a test that constructs a RankContext(rank=1, world_size=2) and asserts the non-chief rank writes no lcurve/checkpoint.
There was a problem hiding this comment.
Fixed in njzjz/deepmd-kit@7280c146. Added a non-chief RankContext(rank=1, world_size=2) test that still runs optimizer steps but writes no lcurve and saves no checkpoints.
5767485 to
7280c14
Compare
| self.models = self._deserialize_models(checkpoint_data) | ||
| self.model_def_script = checkpoint_data["model_def_script"] | ||
| self.model_params_by_task = self._model_params_by_task( | ||
| self.model_def_script | ||
| ) |
There was a problem hiding this comment.
--init-model path currently may ignore --use-pretrain-script=False.
The entrypoint only replaces config["model"] from the checkpoint when use_pretrain_script is true, but DPTrainer.__init__() later unconditionally overwrites self.model_def_script with checkpoint_data["model_def_script"] for both init_model and restart. So an init_model run without --use-pretrain-script can still save checkpoints using the init checkpoint’s model script instead of the input config.
Could we keep jdata["model"] for init_model unless use_pretrain_script was explicitly requested, and only force checkpoint metadata for restart?
There was a problem hiding this comment.
Addressed in 7632feff4: JAX init_model now preserves the input jdata["model"] / model_def_script when use_pretrain_script is false. Only restart forces checkpoint metadata and current step. I also added focused tests for both the init-model and restart paths.
| last_log_time = current_time | ||
| last_log_step = display_step | ||
|
|
||
| if ( | ||
| self.rank_context.is_chief | ||
| and self.trainer_config.save_freq > 0 | ||
| and display_step % self.trainer_config.save_freq == 0 | ||
| ): | ||
| self.save_checkpoint(display_step) |
There was a problem hiding this comment.
Is it possible to support full validation here? See https://github.com/deepmodeling/deepmd-kit/blob/master/deepmd/pt/train/training.py#L1815
There was a problem hiding this comment.
Addressed in 7632feff4: full validation is now wired at the trainer level. AbstractTrainer has a run_full_validation() hook, shared logging/top-k/best-checkpoint bookkeeping lives in deepmd.dpmodel.train.validation, and pt_expt plus JAX provide backend-specific validators/checkpoint writers. This avoids adding a model capability API for validation.
|
Pushed Changes:
Validation:
Both live GPU CLI runs produced |
|
@coderabbitai review |
✅ Action performedReview finished.
|
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/train/trainer.py`:
- Around line 230-236: The mapping branch in trainer.py under the probability
handling logic currently validates missing keys but lets extra task probability
keys slip through. Update the probabilities check in the Mapping path to reject
unknown keys as well as missing ones, using self._keys as the only allowed set
before building the numpy array. Keep the fix localized to the probability
normalization/validation block so stale or mistyped entries raise a ValueError
instead of being silently ignored.
In `@deepmd/dpmodel/train/validation.py`:
- Around line 282-326: The validation flow in run() only raises errors on rank
0, so peer JAX processes can keep running after a rank-0 failure. Add a
backend-specific error propagation hook in the validator path around
_raise_if_error and/or the rank-0 try blocks, then implement the JAX version
used by the trainer to broadcast the failure or terminate all processes
consistently. Make sure the hook is invoked for failures in _evaluate,
save_checkpoint/_reconcile_best_checkpoints, and _log_result so non-zero ranks
do not continue silently.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 9857057b-3e07-4295-8a4a-6f044d0b4b1e
📒 Files selected for processing (28)
deepmd/dpmodel/train/__init__.pydeepmd/dpmodel/train/data.pydeepmd/dpmodel/train/entrypoint.pydeepmd/dpmodel/train/trainer.pydeepmd/dpmodel/train/validation.pydeepmd/dpmodel/utils/training_utils.pydeepmd/jax/entrypoints/train.pydeepmd/jax/train/trainer.pydeepmd/jax/train/validation.pydeepmd/jax/utils/finetune.pydeepmd/jax/utils/serialization.pydeepmd/pd/utils/finetune.pydeepmd/pt/train/validation.pydeepmd/pt/utils/finetune.pydeepmd/pt_expt/entrypoints/main.pydeepmd/pt_expt/train/training.pydeepmd/pt_expt/utils/finetune.pydeepmd/utils/argcheck.pydeepmd/utils/finetune.pysource/tests/common/dpmodel/test_train_abstract_trainer.pysource/tests/common/dpmodel/test_train_data.pysource/tests/common/dpmodel/test_train_entrypoint.pysource/tests/common/dpmodel/test_training_utils.pysource/tests/common/test_finetune_utils.pysource/tests/jax/test_training.pysource/tests/pt/test_validation.pysource/tests/pt_expt/test_entrypoint.pysource/tests/pt_expt/test_training.py
💤 Files with no reviewable changes (9)
- source/tests/common/dpmodel/test_train_data.py
- source/tests/common/dpmodel/test_training_utils.py
- source/tests/pt/test_validation.py
- source/tests/pt_expt/test_training.py
- source/tests/common/dpmodel/test_train_entrypoint.py
- source/tests/common/dpmodel/test_train_abstract_trainer.py
- source/tests/pt_expt/test_entrypoint.py
- source/tests/common/test_finetune_utils.py
- source/tests/jax/test_training.py
✅ Files skipped from review due to trivial changes (1)
- deepmd/utils/argcheck.py
🚧 Files skipped from review as they are similar to previous changes (12)
- deepmd/dpmodel/utils/training_utils.py
- deepmd/pt/utils/finetune.py
- deepmd/dpmodel/train/init.py
- deepmd/pd/utils/finetune.py
- deepmd/pt_expt/utils/finetune.py
- deepmd/dpmodel/train/data.py
- deepmd/jax/utils/serialization.py
- deepmd/jax/entrypoints/train.py
- deepmd/utils/finetune.py
- deepmd/jax/train/trainer.py
- deepmd/pt_expt/train/training.py
- deepmd/pt_expt/entrypoints/main.py
| elif isinstance(probabilities, Mapping): | ||
| missing = [key for key in self._keys if key not in probabilities] | ||
| if missing: | ||
| raise ValueError(f"Missing task probabilities for {missing}.") | ||
| prob = np.asarray( | ||
| [probabilities[key] for key in self._keys], dtype=np.float64 | ||
| ) |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win
Reject unknown probability keys.
The mapping path checks missing task probabilities but silently ignores extra keys, so stale or mistyped task weights can pass validation unnoticed.
Suggested fix
elif isinstance(probabilities, Mapping):
missing = [key for key in self._keys if key not in probabilities]
if missing:
raise ValueError(f"Missing task probabilities for {missing}.")
+ unknown = [key for key in probabilities if key not in self._tasks]
+ if unknown:
+ raise ValueError(f"Unknown task probabilities for {unknown}.")
prob = np.asarray(
[probabilities[key] for key in self._keys], dtype=np.float64
)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| elif isinstance(probabilities, Mapping): | |
| missing = [key for key in self._keys if key not in probabilities] | |
| if missing: | |
| raise ValueError(f"Missing task probabilities for {missing}.") | |
| prob = np.asarray( | |
| [probabilities[key] for key in self._keys], dtype=np.float64 | |
| ) | |
| elif isinstance(probabilities, Mapping): | |
| missing = [key for key in self._keys if key not in probabilities] | |
| if missing: | |
| raise ValueError(f"Missing task probabilities for {missing}.") | |
| unknown = [key for key in probabilities if key not in self._tasks] | |
| if unknown: | |
| raise ValueError(f"Unknown task probabilities for {unknown}.") | |
| prob = np.asarray( | |
| [probabilities[key] for key in self._keys], dtype=np.float64 | |
| ) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/dpmodel/train/trainer.py` around lines 230 - 236, The mapping branch
in trainer.py under the probability handling logic currently validates missing
keys but lets extra task probability keys slip through. Update the probabilities
check in the Mapping path to reject unknown keys as well as missing ones, using
self._keys as the only allowed set before building the numpy array. Keep the fix
localized to the probability normalization/validation block so stale or mistyped
entries raise a ValueError instead of being silently ignored.
| if self.rank == 0: | ||
| try: | ||
| result = self._evaluate(display_step) | ||
| save_path = result.saved_best_path | ||
| except Exception as exc: | ||
| caught_exception = exc | ||
| error_message = ( | ||
| "Full validation failed during evaluation:\n" | ||
| f"{traceback.format_exc()}" | ||
| ) | ||
|
|
||
| self._raise_if_error(error_message, caught_exception) | ||
|
|
||
| if save_path is not None and self.rank == 0: | ||
| try: | ||
| save_checkpoint(Path(save_path), lr=lr, step=step_id) | ||
| self._reconcile_best_checkpoints() | ||
| except Exception as exc: | ||
| caught_exception = exc | ||
| error_message = ( | ||
| "Full validation failed while saving the best checkpoint:\n" | ||
| f"{traceback.format_exc()}" | ||
| ) | ||
| else: | ||
| error_message = None | ||
| caught_exception = None | ||
|
|
||
| self._raise_if_error(error_message, caught_exception) | ||
|
|
||
| if self.rank == 0: | ||
| try: | ||
| self._log_result(result) | ||
| except Exception as exc: | ||
| caught_exception = exc | ||
| error_message = ( | ||
| "Full validation failed while writing logs:\n" | ||
| f"{traceback.format_exc()}" | ||
| ) | ||
| else: | ||
| error_message = None | ||
| caught_exception = None | ||
|
|
||
| self._raise_if_error(error_message, caught_exception) | ||
|
|
||
| return result if self.rank == 0 else None |
There was a problem hiding this comment.
🩺 Stability & Availability | 🟠 Major | 🏗️ Heavy lift
Propagate rank-0 validation failures to peer JAX processes.
run() catches and raises validation/save/log errors only on rank 0; other ranks call _raise_if_error(None, None) and continue. Since the JAX trainer constructs this validator with rank=int(jax.process_index()), a rank-0 failure can leave peer processes running after validation aborts. Please add an overridable backend error-propagation hook and implement it for JAX, or otherwise terminate all peers consistently.
🧰 Tools
🪛 Ruff (0.15.20)
[warning] 286-286: Do not catch blind exception: Exception
(BLE001)
[warning] 299-299: Do not catch blind exception: Exception
(BLE001)
[warning] 314-314: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/dpmodel/train/validation.py` around lines 282 - 326, The validation
flow in run() only raises errors on rank 0, so peer JAX processes can keep
running after a rank-0 failure. Add a backend-specific error propagation hook in
the validator path around _raise_if_error and/or the rank-0 try blocks, then
implement the JAX version used by the trainer to broadcast the failure or
terminate all processes consistently. Make sure the hook is invoked for failures
in _evaluate, save_checkpoint/_reconcile_best_checkpoints, and _log_result so
non-zero ranks do not continue silently.
There was a problem hiding this comment.
Pull request overview
This PR introduces a backend-independent training abstraction layer under deepmd.dpmodel.train, and migrates JAX and pt_expt training/finetune flows to reuse shared orchestration (tasks/ranks normalization, learning-curve output, checkpoint cadence, lifecycle hooks, and full-validation best-checkpoint management). It also consolidates fine-tuning rule generation into a single backend-agnostic builder in deepmd.utils.finetune, leaving backend modules to focus on checkpoint loading and state/tensor copying.
Changes:
- Add backend-independent trainer/entrypoint/data/full-validation primitives in
deepmd.dpmodel.train, with associated test coverage moved/added undersource/tests/common/dpmodel/. - Centralize fine-tuning rule building in
deepmd.utils.finetuneand refactor PT/PT-exportable/Paddle/JAX finetune helpers to delegate to the shared builder. - Extend full-validation best-checkpoint management to support both file checkpoints (
.pt) and directory checkpoints (.jax), and migrate JAX +pt_expttraining entrypoints onto the shared pipeline.
Reviewed changes
Copilot reviewed 28 out of 28 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| source/tests/pt/test_validation.py | Adds coverage for best-checkpoint reconciliation when checkpoints are directories (e.g., .jax). |
| source/tests/pt_expt/test_training.py | Adds an integration-style test asserting pt_expt full validation writes/prunes best checkpoints and logs. |
| source/tests/pt_expt/test_entrypoint.py | New tests for pt_expt entrypoint option normalization, process-group ownership, checkpoint links/retention, and stat-file creation. |
| source/tests/jax/test_training.py | Adds tests for JAX optimizer LR scaling, finetune loading, state/key normalization, and entrypoint gating/multitask behavior. |
| source/tests/common/test_finetune_utils.py | Adds extensive tests for the shared finetune rule builder (single/multitask, aliases, random fitting, immutability). |
| source/tests/common/dpmodel/test_training_utils.py | New tests for model probability resolution and size fallbacks. |
| source/tests/common/dpmodel/test_train_entrypoint.py | New tests validating the shared entrypoint pipeline sequencing and teardown behavior on failures. |
| source/tests/common/dpmodel/test_train_data.py | New tests for shared data-summary printing compatibility/failure propagation. |
| source/tests/common/dpmodel/test_train_abstract_trainer.py | New tests for the backend-independent trainer loop, multitask sampling, lcurve formatting, checkpoint cadence, and full-validation ordering. |
| deepmd/utils/finetune.py | Introduces FinetuneRuleBuilder and shared finetune-rule construction APIs. |
| deepmd/utils/argcheck.py | Updates full-validation docs to indicate support across PT, pt_expt, and JAX. |
| deepmd/pt/utils/finetune.py | Refactors PyTorch finetune rules to delegate to shared finetune rule builder. |
| deepmd/pt/train/validation.py | Extends best-checkpoint handling to support configurable suffixes and directory checkpoints; improves validation-data iteration. |
| deepmd/pt_expt/utils/finetune.py | Refactors pt_expt finetune rules to delegate to shared finetune rule builder with backend-specific errors. |
| deepmd/pt_expt/train/training.py | Migrates pt_expt training loop to AbstractTrainer, adds task normalization, full validation hook, checkpoint retention, and relative latest-link handling. |
| deepmd/pt_expt/entrypoints/main.py | Migrates pt_expt train entrypoint to the shared entrypoint pipeline and shared data helpers; adds stat-file creation helper. |
| deepmd/pd/utils/finetune.py | Refactors Paddle finetune rules to delegate to shared finetune rule builder. |
| deepmd/jax/utils/serialization.py | Updates JAX serialization to support multitask state layout and avoid mis-normalizing numeric-looking task keys. |
| deepmd/jax/utils/finetune.py | New JAX finetune helper delegating to shared finetune rule builder. |
| deepmd/jax/train/validation.py | New JAX full-validation implementation built on backend-independent FullValidatorBase. |
| deepmd/jax/train/trainer.py | Migrates JAX trainer to AbstractTrainer, adds multitask support, shared finetune integration, full validation hook, and checkpoint writing changes. |
| deepmd/jax/entrypoints/train.py | Migrates JAX train entrypoint to the shared entrypoint pipeline and shared data helpers; adds multitask-aware neighbor-stat update. |
| deepmd/dpmodel/utils/training_utils.py | Improves model-probability resolution by handling get_nsystems() and non-sized data sources. |
| deepmd/dpmodel/train/validation.py | New backend-independent full-validation base with best-checkpoint management and val.log formatting. |
| deepmd/dpmodel/train/trainer.py | New backend-independent trainer abstraction handling task selection, display scheduling, lcurve writing, and checkpoint cadence. |
| deepmd/dpmodel/train/entrypoint.py | New backend-independent entrypoint orchestration pipeline used by backend-specific entrypoints. |
| deepmd/dpmodel/train/data.py | New shared data/task config normalization utilities and summary-printing compatibility helper. |
| deepmd/dpmodel/train/init.py | Exposes the new backend-independent training abstractions as a package API. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| self._lmdb_test_data = LmdbTestData( | ||
| lmdb_dataset.lmdb_path, | ||
| type_map=list(lmdb_dataset.type_map), | ||
| lmdb_path, | ||
| type_map=list(type_map or []), | ||
| shuffle_test=False, | ||
| ) |
| ) -> tuple[Any, optax.EmptyState]: | ||
| del params | ||
| learning_rate = kwargs["learning_rate"] | ||
| updates = jax.tree.map(lambda update: -learning_rate * update, updates) |
Summary
deepmd.dpmodel.trainfor task/rank normalization, display scheduling, learning-curve output, checkpoint cadence, lifecycle hooks, and shared train entrypoint orchestration.deepmd.utils.finetune, and reduce the PT, PT-exportable, Paddle, and JAX backend finetune modules to backend-specific checkpoint loading plus shared rule generation.pt_expttrain entrypoint/trainer behavior further onto the shared pipeline, including single-task-as-multi-task normalization, data summaries, checkpoint retention, stat-file parent creation, relative latest checkpoint symlinks, and checkpoint parent creation.print_summaryfallback behavior, broken__len__handling, JAX finetune branch/alias validation, numeric-looking JAX task keys, HDF5 stat paths, andpt_exptcheckpoint symlinks.source/tests/test_dpmodel_*.pyintosource/tests/common/dpmodel/.Refs #5229, #5230, #5231
Tests
ruff format .ruff check .git diff --checkPYTHONPATH=/home/jzzeng/codes/deepmd-kit pytest source/tests/common/dpmodel/test_train_abstract_trainer.py source/tests/common/dpmodel/test_train_entrypoint.py source/tests/common/dpmodel/test_train_data.py source/tests/common/dpmodel/test_training_utils.py source/tests/common/test_finetune_utils.py source/tests/jax/test_training.py source/tests/pt_expt/test_entrypoint.py source/tests/pt_expt/test_multitask.py::TestMultiTaskSeA::test_multitask_finetune source/tests/pt_expt/test_multitask.py::TestMultiTaskSeA::test_multitask_finetune_from_single_task source/tests/pt_expt/test_multitask.py::TestMultiTaskSeA::test_multitask_finetune_no_change_model_params -q(53 passed, 2 subtests passed)PYTHONPATH=/home/jzzeng/codes/deepmd-kit timeout 180 srun --gres=gpu:1 dp --jax train input.json --skip-neighbor-stat --finetune pretrain.jax --use-pretrain-scripton a temporary 1-step water finetune smoke; completed on NVIDIA GeForce RTX 5090 and savedft-model-1.jax.PYTHONPATH=/home/jzzeng/codes/deepmd-kit timeout 180 srun --gres=gpu:1 dp --pt-expt train input.json --skip-neighbor-staton a temporary 2-step water smoke; completed on NVIDIA GeForce RTX 5090, savedckpts/pt-model-2.pt, createdstats/stat.hdf5, and verifiedckpts/pt-model.pt -> pt-model-2.ptwith old step checkpoint pruned bymax_ckpt_keep=1.Notes
paddleis not installed in this environment.deepmd_gnn/CUDA initialization, not by the shared finetune rule builder changes.Summary by CodeRabbit
New Features
Bug Fixes