
feat(dpmodel): add backend-independent trainer abstraction#5603

Open
njzjz wants to merge 5 commits into deepmodeling:master from njzjz:feat/dpmodel-abstract-trainer-5229

Conversation

@njzjz njzjz commented Jun 28, 2026

Member

Summary

  • Add backend-independent training abstractions under deepmd.dpmodel.train for task/rank normalization, display scheduling, learning-curve output, checkpoint cadence, lifecycle hooks, and shared train entrypoint orchestration.
  • Factor common training-data helpers so single-task training is handled as a one-task collection and multi-task data construction/summary/probability handling is shared where possible.
  • Add a backend-independent finetune rule builder in deepmd.utils.finetune, and reduce the PT, PT-exportable, Paddle, and JAX backend finetune modules to backend-specific checkpoint loading plus shared rule generation.
  • Migrate JAX train entrypoint/trainer onto the shared pipeline and add JAX finetune plus multi-task support on top of the new abstractions.
  • Migrate pt_expt train entrypoint/trainer behavior further onto the shared pipeline, including single-task-as-multi-task normalization, data summaries, checkpoint retention, stat-file parent creation, relative latest checkpoint symlinks, and checkpoint parent creation.
  • Address PR review comments around task-key validation, learning-curve metric ordering, lifecycle cleanup, print_summary fallback behavior, broken __len__ handling, JAX finetune branch/alias validation, numeric-looking JAX task keys, HDF5 stat paths, and pt_expt checkpoint symlinks.
  • Move the new dpmodel trainer/entrypoint tests from source/tests/test_dpmodel_*.py into source/tests/common/dpmodel/.
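
The single-task-as-one-task normalization described above can be pictured roughly as follows. This is a minimal sketch, not the code in this PR: `normalize_tasks` and its arguments are hypothetical names, and only the `"Default"` key is taken from the review discussion.

```python
from collections.abc import Mapping

DEFAULT_TASK_KEY = "Default"  # key name mentioned in the review; everything else is assumed


def normalize_tasks(training_data, probabilities=None):
    """Return (tasks, probabilities), wrapping single-task input as a one-task mapping."""
    if isinstance(training_data, Mapping):
        tasks = dict(training_data)
    else:
        # Single-task input becomes a collection holding exactly one task.
        tasks = {DEFAULT_TASK_KEY: training_data}
    if probabilities is None:
        probabilities = {key: 1.0 for key in tasks}
    total = sum(probabilities[key] for key in tasks)
    if total <= 0:
        raise ValueError("Task probabilities must sum to a positive value.")
    return tasks, {key: probabilities[key] / total for key in tasks}
```

With this shape, downstream code only has to handle the mapping case, and a single-task run is simply a mapping of length one.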

Refs #5229, #5230, #5231

Tests

  • ruff format .
  • ruff check .
  • git diff --check
  • PYTHONPATH=/home/jzzeng/codes/deepmd-kit pytest source/tests/common/dpmodel/test_train_abstract_trainer.py source/tests/common/dpmodel/test_train_entrypoint.py source/tests/common/dpmodel/test_train_data.py source/tests/common/dpmodel/test_training_utils.py source/tests/common/test_finetune_utils.py source/tests/jax/test_training.py source/tests/pt_expt/test_entrypoint.py source/tests/pt_expt/test_multitask.py::TestMultiTaskSeA::test_multitask_finetune source/tests/pt_expt/test_multitask.py::TestMultiTaskSeA::test_multitask_finetune_from_single_task source/tests/pt_expt/test_multitask.py::TestMultiTaskSeA::test_multitask_finetune_no_change_model_params -q (53 passed, 2 subtests passed)
  • PYTHONPATH=/home/jzzeng/codes/deepmd-kit timeout 180 srun --gres=gpu:1 dp --jax train input.json --skip-neighbor-stat --finetune pretrain.jax --use-pretrain-script on a temporary 1-step water finetune smoke; completed on NVIDIA GeForce RTX 5090 and saved ft-model-1.jax.
  • PYTHONPATH=/home/jzzeng/codes/deepmd-kit timeout 180 srun --gres=gpu:1 dp --pt-expt train input.json --skip-neighbor-stat on a temporary 2-step water smoke; completed on NVIDIA GeForce RTX 5090, saved ckpts/pt-model-2.pt, created stats/stat.hdf5, and verified ckpts/pt-model.pt -> pt-model-2.pt with old step checkpoint pruned by max_ckpt_keep=1.

Notes

  • Paddle-specific runtime tests were not run locally because paddle is not installed in this environment.
  • Plain PyTorch backend test collection is blocked in this environment by external deepmd_gnn/CUDA initialization, not by the shared finetune rule builder changes.

Summary by CodeRabbit

  • New Features

    • Unified training flow across backends, with improved support for single-task and multi-task training.
    • Added clearer fine-tuning support, including model-branch selection and better checkpoint reuse.
    • Introduced full-validation tracking with best-checkpoint management and validation logs.
  • Bug Fixes

    • Improved checkpoint handling, including safer save/restore behavior and cleanup of old checkpoints.
    • Fixed task-weight calculation to better handle different dataset types.
    • Enhanced training summaries and metric reporting for more consistent output.

coderabbitai Bot commented Jun 28, 2026

Contributor

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Adds shared training abstractions for entrypoints, trainers, task normalization, and validation. Refactors JAX and pt_expt training flows to use the shared pipeline. Consolidates finetune rule handling into shared utilities. Extends JAX checkpoint serialization for multitask models.

Changes

Shared Training Abstractions and Backend Integrations

Layer, File(s), and Summary

  • Core training contracts
    • File(s): deepmd/dpmodel/train/__init__.py, deepmd/dpmodel/train/data.py, deepmd/dpmodel/train/entrypoint.py, deepmd/dpmodel/train/trainer.py, deepmd/dpmodel/utils/training_utils.py, source/tests/common/dpmodel/test_train_abstract_trainer.py, source/tests/common/dpmodel/test_train_data.py, source/tests/common/dpmodel/test_train_entrypoint.py, source/tests/common/dpmodel/test_training_utils.py
    • Summary: Adds shared task config, entrypoint, trainer, and learning-curve abstractions, plus task-size fallback logic and coverage for the common training pipeline.
  • Validation and checkpoint management
    • File(s): deepmd/dpmodel/train/validation.py, deepmd/pt/train/validation.py, source/tests/pt/test_validation.py, source/tests/pt_expt/test_training.py
    • Summary: Adds shared full-validation machinery and updates PyTorch validation checkpoint naming, cleanup, and validation-data handling.
  • Shared finetune rule builder
    • File(s): deepmd/utils/finetune.py, deepmd/pt/utils/finetune.py, deepmd/pd/utils/finetune.py, deepmd/pt_expt/utils/finetune.py, deepmd/jax/utils/finetune.py, source/tests/common/test_finetune_utils.py
    • Summary: Adds FinetuneRuleBuilder and shared finetune rule helpers, then reduces backend finetune modules to wrappers with tests for branch and alias behavior.
  • JAX training refactor
    • File(s): deepmd/jax/entrypoints/train.py, deepmd/jax/train/trainer.py, deepmd/jax/train/validation.py, deepmd/jax/utils/serialization.py, deepmd/jax/utils/finetune.py, source/tests/jax/test_training.py
    • Summary: Refactors JAX entrypoint/trainer flow onto the shared abstractions, adds JAX full validation, and extends serialization for composite multitask checkpoints.
  • pt_expt training refactor
    • File(s): deepmd/pt_expt/entrypoints/main.py, deepmd/pt_expt/train/training.py, deepmd/pt_expt/utils/finetune.py, source/tests/pt_expt/test_entrypoint.py
    • Summary: Refactors pt_expt entrypoint and trainer onto the shared abstractions, including distributed setup/teardown, task mapping, and checkpoint handling.
  • Backend docs for validation
    • File(s): deepmd/utils/argcheck.py
    • Summary: Updates validating-argument docs to describe full-validation support across backends.

Estimated code review effort: 5 (Critical) | ~120 minutes

Possibly related PRs

Suggested reviewers: OutisLi, wanghan-iapcm

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name | Status | Explanation | Resolution
Docstring Coverage | ⚠️ Warning | Docstring coverage is 52.14% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

Check name | Status | Explanation
Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled.
Title check | ✅ Passed | The title clearly summarizes the main change: adding a backend-independent trainer abstraction for dpmodel.
Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request.

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (1)
source/tests/test_dpmodel_abstract_trainer.py (1)

99-109: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Avoid hard-coding the default task key in this test.

TrainingTaskCollection.single() owns the default key via DEFAULT_TASK_KEY, so indexing with "Default" makes this test fail on an internal rename without any behavior change. Pull the lone task from the collection API instead.

Proposed fix
-    task = tasks["Default"]
+    task = tasks.select()
     task.add_data_requirements()
 
     assert not tasks.is_multitask
     assert tasks.select() is task
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/test_dpmodel_abstract_trainer.py` around lines 99 - 109, The
test is hard-coding the default task name instead of using the collection API.
In test_dpmodel_abstract_trainer.py, update the TrainingTaskCollection.single()
usage so the lone task is retrieved without indexing by "Default", and reference
the collection’s default-task behavior through its API/select() rather than the
literal key. Keep the assertions on task.add_data_requirements(),
tasks.is_multitask, and tasks.select() intact, but bind task from the collection
in a way that won’t break if DEFAULT_TASK_KEY is renamed.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/train/trainer.py`:
- Around line 380-383: Keep the single-task table schema anchored to
train_results in the Trainer formatting path. The header logic and the row
emission in the same train/eval summary should use the exact same key sequence
from train_results, not valid_results, so the validation columns stay aligned
and missing or reordered validation metrics do not break train_results lookups.
Update the loop in the Trainer row-building code to derive the metric order from
train_results and only read matching validation values by key when present.
- Around line 148-161: The task normalization in the trainer constructor
silently overwrites duplicate entries when `tasks` is a sequence, so add an
explicit duplicate-key check before building `self._tasks` in
`Trainer.__init__`. Update the branch that handles non-Mapping `tasks` to
validate that each `task.key` appears only once, and raise a clear `ValueError`
if duplicates are found. Keep the existing key-matching validation and
`_normalize_probabilities` flow intact.

In `@deepmd/jax/train/trainer.py`:
- Around line 188-194: The training setup in
AbstractTrainer.on_train_begin()/train() is missing registration of loss label
requirements before creating the task collection. Update the training flow
around TrainingTaskCollection.single and self.run(tasks) so
DeepmdDataSystem.add_data_requirements() is called with
self.loss.label_requirement (or equivalent task data_requirements) before
batching starts, ensuring get_batch() includes the labels needed by the loss
path.

In `@source/tests/test_dpmodel_abstract_trainer.py`:
- Around line 53-64: The evaluate_training helper is consuming a batch for
inactive multitask entries by falling back to task.training_data.get_batch()
when step_result is None, which advances the cursor during display collection.
Update evaluate_training in test_dpmodel_abstract_trainer.py so it only reads
step_result.payload for the matching task and otherwise returns a non-consuming
placeholder/skip path for inactive tasks; keep the batch-order contract aligned
with the multitask loop in trainer.py and the TrainStepResult/task.key check.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: ae088f30-bef5-4da4-ae02-40f16ee87975

📥 Commits

Reviewing files that changed from the base of the PR and between a9bcbc5 and d100673.

📒 Files selected for processing (4)
  • deepmd/dpmodel/train/__init__.py
  • deepmd/dpmodel/train/trainer.py
  • deepmd/jax/train/trainer.py
  • source/tests/test_dpmodel_abstract_trainer.py

Comment thread deepmd/dpmodel/train/trainer.py
Comment thread deepmd/dpmodel/train/trainer.py
Comment thread deepmd/jax/train/trainer.py Outdated
Comment thread source/tests/test_dpmodel_abstract_trainer.py Outdated
@njzjz njzjz force-pushed the feat/dpmodel-abstract-trainer-5229 branch from d100673 to 6e168b7 Compare June 28, 2026 18:38

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
deepmd/dpmodel/train/trainer.py (2)

422-428: 🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Single-task row schema should anchor to train_results, not valid_results.

format_header iterates train_results (Line 382), but the row loop here iterates valid_results and indexes train_results[key]. If a backend returns validation metrics in a different order or omits a metric, the row desynchronizes from the header and train_results[key] can raise KeyError.

Suggested fix
             if valid_results is not None:
                 assert not self._is_multitask(valid_results)
-                for key in valid_results:
+                for key in train_results:
                     row += (
-                        f"   {float(valid_results[key]):11.2e}"
+                        f"   {float(valid_results.get(key, float('nan'))):11.2e}"
                         f" {float(train_results[key]):11.2e}"
                     )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/train/trainer.py` around lines 422 - 428, The single-task row
assembly in `Trainer.format_*` is anchored to the wrong metric source: the
header follows `train_results`, but this row loop currently iterates
`valid_results` while reading `train_results[key]`, which can desynchronize
columns or fail when validation metrics differ. Update the row-building logic to
use the same key order as `train_results` (matching the existing `format_header`
contract) and only use `valid_results` for the optional validation value lookup,
keeping the schema consistent with the training metrics.
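
A minimal sketch of a row builder that follows this contract is given here for illustration. It is not the trainer's actual method: `format_row` is a hypothetical name, the column widths follow the snippet in the suggested fix, and the layout of the step column and of rows without validation data is an assumption.

```python
import math


def format_row(step, train_results, valid_results=None):
    """Build one single-task learning-curve row in train_results key order."""
    row = f"{step:7d}"
    for key in train_results:
        if valid_results is not None:
            # A metric missing from validation renders as nan instead of raising.
            valid_value = float(valid_results.get(key, math.nan))
            row += f"   {valid_value:11.2e}"
        row += f" {float(train_results[key]):11.2e}"
    return row
```

Because the loop is driven by `train_results`, reordering the validation dictionary cannot shift columns relative to the header.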

148-161: 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Duplicate task keys are still silently dropped when tasks is a sequence.

task_dict = {task.key: task for task in tasks} overwrites earlier entries, so a misconfigured multitask run can lose a task with no error. Validate uniqueness before building _tasks.

Suggested fix
         if isinstance(tasks, Mapping):
             task_dict = dict(tasks)
         else:
-            task_dict = {task.key: task for task in tasks}
+            task_list = list(tasks)
+            task_dict = {task.key: task for task in task_list}
+            if len(task_dict) != len(task_list):
+                raise ValueError("Training task keys must be unique.")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/train/trainer.py` around lines 148 - 161, The task
normalization in Trainer.__init__ still silently overwrites duplicate keys when
tasks is a sequence because task_dict is built directly from task.key; add an
explicit duplicate-key check before assigning to self._tasks. In the branch that
handles non-Mapping tasks, validate that each task.key is unique (raise a
ValueError on duplicates) before constructing the dict, while keeping the
existing key-vs-task.key consistency check and downstream
_normalize_probabilities flow unchanged.
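
One way to surface the duplicated names in the error message, rather than only comparing lengths, is sketched below. `Task` and `build_task_dict` are hypothetical stand-ins for illustration and are not the PR's classes.

```python
from collections import Counter
from collections.abc import Mapping
from dataclasses import dataclass


@dataclass
class Task:
    """Hypothetical stand-in for a training task that exposes a `key`."""

    key: str


def build_task_dict(tasks):
    """Normalize a mapping or sequence of tasks into a dict, rejecting duplicates."""
    if isinstance(tasks, Mapping):
        return dict(tasks)
    task_list = list(tasks)
    counts = Counter(task.key for task in task_list)
    duplicates = sorted(key for key, count in counts.items() if count > 1)
    if duplicates:
        raise ValueError(
            f"Training task keys must be unique; duplicated: {duplicates}"
        )
    return {task.key: task for task in task_list}
```

Materializing the sequence first also keeps the check correct when `tasks` is a one-shot iterator.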
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Line 831: Wire the parsed retention setting into checkpoint cleanup so
`max_ckpt_keep` actually limits saved checkpoints: `TrainerConfig` already
carries the value and `AbstractTrainer.run()`/`save_checkpoint()` currently only
append new `self.save_ckpt-<step>.pt` files. Update the checkpoint-saving flow
to track existing checkpoints and prune the oldest ones after each save, using
`self.max_ckpt_keep` as the cap and keeping the newest checkpoints only.
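
The pruning step could look roughly like the sketch below. It assumes the `<save_ckpt>-<step>.pt` naming quoted in this comment; `prune_checkpoints` is a hypothetical helper, and treating a non-positive cap as "keep everything" is an assumption rather than documented behavior.

```python
import re
from pathlib import Path


def prune_checkpoints(save_ckpt, max_ckpt_keep):
    """Delete all but the newest `max_ckpt_keep` step checkpoints; return removed paths."""
    if max_ckpt_keep <= 0:
        return []  # assumption: a non-positive cap disables pruning
    prefix = Path(save_ckpt)
    pattern = re.compile(re.escape(prefix.name) + r"-(\d+)\.pt")
    found = []
    for path in prefix.parent.iterdir():
        match = pattern.fullmatch(path.name)
        # Skip the "latest" symlink and anything that is not a step checkpoint.
        if match and not path.is_symlink():
            found.append((int(match.group(1)), path))
    found.sort()  # numeric order, so step 10 sorts after step 2
    stale = [path for _, path in found[:-max_ckpt_keep]]
    for path in stale:
        path.unlink()
    return stale
```

Sorting on the parsed step number instead of the file name avoids deleting `-10.pt` before `-2.pt`.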


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 2ee2703c-74d9-48e0-8c5e-ce031e73c19a

📥 Commits

Reviewing files that changed from the base of the PR and between d100673 and 6e168b7.

📒 Files selected for processing (5)
  • deepmd/dpmodel/train/__init__.py
  • deepmd/dpmodel/train/trainer.py
  • deepmd/jax/train/trainer.py
  • deepmd/pt_expt/train/training.py
  • source/tests/test_dpmodel_abstract_trainer.py
✅ Files skipped from review due to trivial changes (1)
  • deepmd/dpmodel/train/__init__.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • source/tests/test_dpmodel_abstract_trainer.py
  • deepmd/jax/train/trainer.py

Comment thread deepmd/pt_expt/train/training.py Outdated
codecov Bot commented Jun 28, 2026


Codecov Report

❌ Patch coverage is 75.58209% with 409 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.85%. Comparing base (73de44b) to head (7632fef).
⚠️ Report is 7 commits behind head on master.

Files with missing lines | Patch % | Lines
deepmd/jax/train/trainer.py | 40.67% | 175 Missing ⚠️
deepmd/dpmodel/train/validation.py | 78.83% | 62 Missing ⚠️
deepmd/jax/train/validation.py | 29.50% | 43 Missing ⚠️
deepmd/jax/entrypoints/train.py | 50.79% | 31 Missing ⚠️
deepmd/dpmodel/train/trainer.py | 91.97% | 24 Missing ⚠️
deepmd/pt_expt/entrypoints/main.py | 84.95% | 17 Missing ⚠️
deepmd/pt_expt/train/training.py | 93.06% | 14 Missing ⚠️
deepmd/pt/train/validation.py | 67.56% | 12 Missing ⚠️
deepmd/utils/finetune.py | 90.38% | 10 Missing ⚠️
deepmd/dpmodel/train/entrypoint.py | 89.70% | 7 Missing ⚠️
... and 3 more
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5603      +/-   ##
==========================================
- Coverage   81.98%   81.85%   -0.14%     
==========================================
  Files         959      966       +7     
  Lines      105430   106745    +1315     
  Branches     4071     4102      +31     
==========================================
+ Hits        86442    87376     +934     
- Misses      17518    17875     +357     
- Partials     1470     1494      +24     


@njzjz njzjz force-pushed the feat/dpmodel-abstract-trainer-5229 branch from 6e168b7 to 78272a7 Compare June 29, 2026 06:35

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/train/entrypoint.py`:
- Around line 79-83: The cleanup guard in the entrypoint flow is starting too
late, so `teardown_run()` will not run if `setup_run()` fails after partially
initializing backend state. Move the `try/finally` in the main training
entrypoint so it wraps `setup_run()` itself, ensuring `teardown_run()` is always
invoked even when setup raises. Use the existing `setup_run()` and
`teardown_run()` flow in this module as the anchor when making the change.
- Line 48: The AbstractTrainEntrypoint contract is too implicit: it currently
has no abstract methods and the hook methods are docstring-only, which triggers
lint issues. Update AbstractTrainEntrypoint so run_training() is marked
abstract, and give the optional hook methods explicit no-op bodies with return
None unless they must be overridden by every backend. Use the class and method
names AbstractTrainEntrypoint and run_training() to locate the contract and hook
definitions.
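
Both entrypoint points can be illustrated together with a small sketch. `SketchTrainEntrypoint` and its `run()` method are hypothetical and deliberately simplified; only the `setup_run`, `teardown_run`, and `run_training` names come from the comments above.

```python
from abc import ABC, abstractmethod


class SketchTrainEntrypoint(ABC):
    """Simplified illustration of the entrypoint contract, not the real class."""

    def setup_run(self) -> None:
        """Optional hook: prepare backend state."""
        return None  # explicit no-op body keeps the hook optional

    def teardown_run(self) -> None:
        """Optional hook: release backend state."""
        return None

    @abstractmethod
    def run_training(self) -> None:
        """Required hook: every backend must implement the training call."""

    def run(self) -> None:
        # setup_run() sits inside the try block, so teardown_run() still runs
        # when setup fails after partially initializing backend state.
        try:
            self.setup_run()
            self.run_training()
        finally:
            self.teardown_run()
```

A consequence of this layout is that `teardown_run()` must tolerate being called after an incomplete setup.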

In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 295-297: Replace the existing assert-based validation in the
multi-task config handling with an explicit runtime check that raises ValueError
when the config is "RANDOM"; this ensures the validation in the main entrypoint
is enforced even under optimized Python runs. Update the check in the entrypoint
logic that handles the multi-task configuration so the rejection of "RANDOM" is
always consistent.
- Around line 257-259: `setup_run` and `teardown_run` need ownership tracking
for the distributed process group: make `setup_run` skip `init_process_group()`
when a default group already exists, and record whether this entrypoint created
the group. Then update `teardown_run` to call `destroy_process_group()` only
when it is cleaning up a group created by `setup_run`, so it does not interfere
with a caller-managed distributed context.
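
The ownership tracking requested for the process group might be sketched as follows. To keep the example runnable without PyTorch, `FakeDist` is a stand-in that mimics only `is_initialized()`, `init_process_group()`, and `destroy_process_group()` from `torch.distributed`; `DistributedRun` and the `"gloo"` default are assumptions made for illustration.

```python
class FakeDist:
    """Stand-in exposing the three torch.distributed calls used below."""

    def __init__(self, initialized=False):
        self.initialized = initialized
        self.destroy_calls = 0

    def is_initialized(self):
        return self.initialized

    def init_process_group(self, backend):
        self.initialized = True

    def destroy_process_group(self):
        self.initialized = False
        self.destroy_calls += 1


class DistributedRun:
    """Hypothetical holder for setup/teardown with group ownership tracking."""

    def __init__(self, dist, backend="gloo"):
        self._dist = dist
        self._backend = backend
        self._owns_group = False

    def setup_run(self):
        # Only create the default group when the caller has not already done so.
        if not self._dist.is_initialized():
            self._dist.init_process_group(backend=self._backend)
            self._owns_group = True

    def teardown_run(self):
        # Destroy only a group that this object created.
        if self._owns_group and self._dist.is_initialized():
            self._dist.destroy_process_group()
        self._owns_group = False
```

In real code the `dist` argument would be the `torch.distributed` module itself.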

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: ec38d095-ecf4-4024-ac9f-7523f8d40e60

📥 Commits

Reviewing files that changed from the base of the PR and between 6e168b7 and 78272a7.

📒 Files selected for processing (11)
  • deepmd/dpmodel/train/__init__.py
  • deepmd/dpmodel/train/entrypoint.py
  • deepmd/dpmodel/train/trainer.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/train/trainer.py
  • deepmd/pt_expt/entrypoints/main.py
  • deepmd/pt_expt/train/training.py
  • source/tests/jax/test_training.py
  • source/tests/pt_expt/test_entrypoint.py
  • source/tests/test_dpmodel_abstract_trainer.py
  • source/tests/test_dpmodel_train_entrypoint.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • deepmd/dpmodel/train/__init__.py
  • source/tests/test_dpmodel_abstract_trainer.py
  • deepmd/jax/train/trainer.py
  • deepmd/pt_expt/train/training.py

Comment thread deepmd/dpmodel/train/entrypoint.py
Comment thread deepmd/dpmodel/train/entrypoint.py Outdated
Comment thread deepmd/pt_expt/entrypoints/main.py
Comment thread deepmd/pt_expt/entrypoints/main.py Outdated
@njzjz njzjz force-pushed the feat/dpmodel-abstract-trainer-5229 branch from 78272a7 to 5bebfdf Compare June 29, 2026 08:27

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
deepmd/jax/utils/serialization.py (1)

189-207: 🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win

Preserve string task keys when converting restored state keys.

convert_str_to_int_key(state) also rewrites digit-only task names under state["models"]. A valid multi-task key like "0" becomes 0, then state_by_model[model_key] fails because model_def_script["model_dict"] still uses "0".

Proposed fix
-        convert_str_to_int_key(state)
-
         model_def_script = data.model_def_script
         if "model_dict" in model_def_script:
             state_by_model = state.get("models", state)
+            if "models" in state:
+                for model_state in state_by_model.values():
+                    convert_str_to_int_key(model_state)
+            else:
+                convert_str_to_int_key(state_by_model)
             model_dict = {"model_dict": {}}
             for model_key, model_params in model_def_script["model_dict"].items():
                 abstract_model = get_model(model_params)
                 graphdef, abstract_state = nnx.split(abstract_model)
                 abstract_state.replace_by_pure_dict(state_by_model[model_key])
                 model = nnx.merge(graphdef, abstract_state)
                 model_dict["model_dict"][model_key] = model.serialize()
         else:
+            convert_str_to_int_key(state)
             abstract_model = get_model(model_def_script)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/jax/utils/serialization.py` around lines 189 - 207, The state
restoration logic in convert_str_to_int_key is too aggressive and converts
digit-only task names inside state["models"], which breaks later lookup in the
model merge path. Update the key conversion so it only normalizes keys that
represent numeric indices where needed, while preserving string task keys used
by model_def_script["model_dict"]; then ensure the state_by_model lookup in the
model merging section continues to use the original task-name strings.
deepmd/pt_expt/train/training.py (1)

1509-1515: 🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win

Make the latest symlink target relative to its own directory.

When save_ckpt includes a directory, e.g. out/model.ckpt, latest.symlink_to("out/model.ckpt-1.pt") creates out/model.ckpt.pt -> out/model.ckpt-1.pt, which resolves as out/out/model.ckpt-1.pt. Restarting from the prefix then follows a broken symlink.

Proposed fix
-        ckpt_path = f"{self.save_ckpt}-{step}.pt"
+        ckpt_path = Path(f"{self.save_ckpt}-{step}.pt")
         torch.save(state, ckpt_path)
         # symlink latest
         latest = Path(f"{self.save_ckpt}.pt")
         if latest.is_symlink() or latest.exists():
             latest.unlink()
-        latest.symlink_to(ckpt_path)
+        latest.symlink_to(ckpt_path.name)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/train/training.py` around lines 1509 - 1515, The `latest`
symlink in `training.py` is created with a target that can become incorrectly
resolved when `save_ckpt` includes a directory. Update the checkpoint-saving
logic in the block that builds `ckpt_path`, `latest =
Path(f"{self.save_ckpt}.pt")`, and calls `latest.symlink_to(...)` so the symlink
target is computed relative to `latest`’s parent directory instead of using the
full path string. Make sure the existing `torch.save` flow and cleanup of any
prior `latest` link remain unchanged.
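
A self-contained sketch of the relative-target approach, with the `torch.save` call left out, is shown here. `link_latest` is a hypothetical helper name; the path handling mirrors the proposed fix.

```python
from pathlib import Path


def link_latest(save_ckpt, step):
    """Point `<save_ckpt>.pt` at `<save_ckpt>-<step>.pt` via a relative target."""
    ckpt_path = Path(f"{save_ckpt}-{step}.pt")
    latest = Path(f"{save_ckpt}.pt")
    if latest.is_symlink() or latest.exists():
        latest.unlink()
    # Use the bare file name: the link lives in the same directory as its
    # target, so the target must be relative to that directory.
    latest.symlink_to(ckpt_path.name)
    return latest
```

Since symlink targets are resolved relative to the directory containing the link, the bare file name resolves correctly for any `save_ckpt` prefix, with or without a directory component.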
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/train/data.py`:
- Around line 120-124: The fallback in _print_summary is too broad because it
catches every TypeError from data.print_summary and retries without prob, which
can hide real failures. Update _print_summary to detect the supported call
signature first (or only retry on an explicit argument-count mismatch) so only
old print_summary implementations use the no-prob path while genuine TypeError
exceptions still propagate.
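
Detecting the supported call signature up front could be done with `inspect.signature`, as in the sketch below. `print_summary_compat` is a hypothetical name, and the `(name, prob)` argument order is assumed from the description above.

```python
import inspect


def print_summary_compat(data, name, prob):
    """Call data.print_summary with prob only when its signature accepts it."""
    method = data.print_summary
    try:
        # bind() checks the arguments against the signature without calling.
        inspect.signature(method).bind(name, prob)
    except TypeError:
        # Older implementations take only the name.
        return method(name)
    # A TypeError raised inside the method body now propagates unchanged.
    return method(name, prob)
```

The `try` block covers only the signature check, so it cannot swallow an error raised by the summary code itself.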

In `@deepmd/dpmodel/train/trainer.py`:
- Around line 575-579: The optional ABC hook methods on Training hooks are
docstring-only and trigger Ruff B027; update the on_train_begin and on_train_end
methods in TrainingTask/Trainer-related code to use explicit no-op bodies by
adding return None so they remain optional without failing lint.
- Around line 472-479: The cleanup guard starts too late in the training setup
path, so resources created by on_train_begin() may not be released if
_open_learning_curve() fails. Move the try/finally in Trainer.train() (or the
surrounding training entry point) so it begins before on_train_begin(tasks), and
keep on_train_end(tasks) in the finally block to ensure backend state is always
cleaned up even when setup throws.

In `@deepmd/dpmodel/utils/training_utils.py`:
- Around line 124-130: The `_training_data_size` helper is swallowing all
`TypeError` from `len(training_data)` and turning broken `__len__`
implementations into a fallback size of 1. Update `_training_data_size` so it
only returns 1 for objects that truly do not support sizing, and allow real
`__len__` failures to propagate instead of masking them; keep the existing
`get_nsystems` path and the `resolve_model_prob()` callers in mind when
adjusting the behavior.
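
A sketch of the narrower fallback follows. `training_data_size` is a hypothetical stand-in for the helper, and checking `get_nsystems` before `__len__` is an assumed ordering.

```python
def training_data_size(training_data):
    """Return a size for task weighting; fall back to 1 only for unsized objects."""
    get_nsystems = getattr(training_data, "get_nsystems", None)
    if callable(get_nsystems):
        return int(get_nsystems())
    if hasattr(type(training_data), "__len__"):
        # Errors raised by a broken __len__ propagate instead of becoming 1.
        return len(training_data)
    return 1
```

Checking for `__len__` on the type mirrors how `len()` itself looks up the method, so an instance attribute named `__len__` does not count as sized.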

In `@deepmd/jax/utils/finetune.py`:
- Around line 50-53: The single-task pretrained checkpoint path in finetune.py
silently ignores non-empty model_branch_from values other than RANDOM, which can
hide typos and fall back to Default unexpectedly. Update the branch-selection
logic in the finetuning flow around single_config_chosen/model_branch_from to
explicitly validate the branch name when from_multitask is false, and raise an
error for any unknown non-empty value instead of proceeding. Keep RANDOM as the
only special case that enables new_fitting, and ensure the check happens before
the code falls back to Default.
- Around line 152-157: The pre-check in finetune setup is rejecting valid
finetune_head aliases because it compares against raw pretrained_keys before
alias resolution. Update the validation around pretrained_key in the finetune
flow so it uses the same alias-aware mapping as _get_finetune_rule_single() via
get_model_dict(), and only raise the ValueError after checking the resolved
model dict entries rather than the unexpanded pretrained_keys list.

In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 167-177: The _ensure_stat_file_path helper currently creates HDF5
stat files with h5py.File before ensuring the parent directory exists, so paths
like stats/model_stat.h5 can fail during trainer setup. Update
_ensure_stat_file_path to create the parent directories first for file targets,
then open the HDF5 file, and keep the existing directory creation path unchanged
for non-HDF5 stat targets. Use the stat_file_path, Path, and h5py.File logic in
this function to apply the fix.

---

Outside diff comments:
In `@deepmd/jax/utils/serialization.py`:
- Around line 189-207: The state restoration logic in convert_str_to_int_key is
too aggressive and converts digit-only task names inside state["models"], which
breaks later lookup in the model merge path. Update the key conversion so it
only normalizes keys that represent numeric indices where needed, while
preserving string task keys used by model_def_script["model_dict"]; then ensure
the state_by_model lookup in the model merging section continues to use the
original task-name strings.

In `@deepmd/pt_expt/train/training.py`:
- Around line 1509-1515: The `latest` symlink in `training.py` is created with a
target that can become incorrectly resolved when `save_ckpt` includes a
directory. Update the checkpoint-saving logic in the block that builds
`ckpt_path`, `latest = Path(f"{self.save_ckpt}.pt")`, and calls
`latest.symlink_to(...)` so the symlink target is computed relative to
`latest`’s parent directory instead of using the full path string. Make sure the
existing `torch.save` flow and cleanup of any prior `latest` link remain
unchanged.
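
A minimal sketch of the relative `latest` link described in the last item above, assuming step checkpoints are written next to the link as `{save_ckpt}-{step}.pt`. The helper name `link_latest` and that naming scheme are illustrative assumptions, and `write_bytes` stands in for `torch.save`; only `latest = Path(f"{save_ckpt}.pt")` and `symlink_to` come from the comment itself.

```python
import os
import tempfile
from pathlib import Path


def link_latest(save_ckpt: str, step: int) -> Path:
    """Write a step checkpoint and point `<save_ckpt>.pt` at it (hypothetical helper)."""
    ckpt_path = Path(f"{save_ckpt}-{step}.pt")  # assumed naming scheme
    latest = Path(f"{save_ckpt}.pt")
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    ckpt_path.write_bytes(b"")  # stand-in for torch.save(...)
    # Drop any previous link, including a dangling one.
    if latest.is_symlink() or latest.exists():
        latest.unlink()
    # The target is expressed relative to latest.parent, not as the full path string.
    latest.symlink_to(os.path.relpath(ckpt_path, start=latest.parent))
    return latest


with tempfile.TemporaryDirectory() as tmp:
    latest = link_latest(os.path.join(tmp, "ckpts", "model.ckpt"), 100)
    link_target = os.readlink(latest)
    expected = Path(tmp, "ckpts", "model.ckpt-100.pt").resolve()
    resolves_to_checkpoint = latest.resolve() == expected
```

Because the stored target is just the sibling file name, the link keeps resolving when `save_ckpt` contains a directory component or when the checkpoint directory is moved as a whole.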

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 78180874-4d87-4c66-bd4e-fd0b4ee961bb

📥 Commits

Reviewing files that changed from the base of the PR and between 78272a7 and 5bebfdf.

📒 Files selected for processing (15)
  • deepmd/dpmodel/train/__init__.py
  • deepmd/dpmodel/train/data.py
  • deepmd/dpmodel/train/entrypoint.py
  • deepmd/dpmodel/train/trainer.py
  • deepmd/dpmodel/utils/training_utils.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/train/trainer.py
  • deepmd/jax/utils/finetune.py
  • deepmd/jax/utils/serialization.py
  • deepmd/pt_expt/entrypoints/main.py
  • deepmd/pt_expt/train/training.py
  • source/tests/jax/test_training.py
  • source/tests/pt_expt/test_entrypoint.py
  • source/tests/test_dpmodel_abstract_trainer.py
  • source/tests/test_dpmodel_train_entrypoint.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/dpmodel/train/__init__.py

Comment thread deepmd/dpmodel/train/data.py Outdated
Comment thread deepmd/dpmodel/train/trainer.py Outdated
Comment thread deepmd/dpmodel/train/trainer.py
Comment thread deepmd/dpmodel/utils/training_utils.py
Comment thread deepmd/jax/utils/finetune.py Outdated
Comment thread deepmd/jax/utils/finetune.py Outdated
Comment thread deepmd/pt_expt/entrypoints/main.py
@njzjz njzjz force-pushed the feat/dpmodel-abstract-trainer-5229 branch from 5bebfdf to 9485193 Compare June 29, 2026 09:27
@njzjz njzjz requested review from iProzd and wanghan-iapcm June 29, 2026 09:30

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/jax/utils/finetune.py`:
- Around line 21-25: The suffix check in _load_model_params is too restrictive
and rejects checkpoint directories or pointer paths that serialize_from_file can
already deserialize. Remove the hard-coded .jax validation in _load_model_params
and let serialize_from_file(finetune_model) handle the input format consistently
with the freeze() and init_model paths, while still extracting
"model_def_script" from the loaded checkpoint data.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: cd8a6c26-b867-4429-892e-f4aa7e5fbc5e

📥 Commits

Reviewing files that changed from the base of the PR and between 5bebfdf and 9485193.

📒 Files selected for processing (22)
  • deepmd/dpmodel/train/__init__.py
  • deepmd/dpmodel/train/data.py
  • deepmd/dpmodel/train/entrypoint.py
  • deepmd/dpmodel/train/trainer.py
  • deepmd/dpmodel/utils/training_utils.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/train/trainer.py
  • deepmd/jax/utils/finetune.py
  • deepmd/jax/utils/serialization.py
  • deepmd/pd/utils/finetune.py
  • deepmd/pt/utils/finetune.py
  • deepmd/pt_expt/entrypoints/main.py
  • deepmd/pt_expt/train/training.py
  • deepmd/pt_expt/utils/finetune.py
  • deepmd/utils/finetune.py
  • source/tests/common/dpmodel/test_train_abstract_trainer.py
  • source/tests/common/dpmodel/test_train_data.py
  • source/tests/common/dpmodel/test_train_entrypoint.py
  • source/tests/common/dpmodel/test_training_utils.py
  • source/tests/common/test_finetune_utils.py
  • source/tests/jax/test_training.py
  • source/tests/pt_expt/test_entrypoint.py
🚧 Files skipped from review as they are similar to previous changes (8)
  • deepmd/dpmodel/utils/training_utils.py
  • deepmd/dpmodel/train/__init__.py
  • deepmd/jax/utils/serialization.py
  • deepmd/dpmodel/train/data.py
  • deepmd/pt_expt/train/training.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/train/trainer.py
  • deepmd/pt_expt/entrypoints/main.py

Comment thread deepmd/jax/utils/finetune.py
@njzjz njzjz linked an issue Jun 29, 2026 that may be closed by this pull request
Comment on lines +388 to +391
for model_key in self.model_keys:
    tx = optax.adam(
        learning_rate=lambda step: self.lr.value(self.start_step + step),
    )

Collaborator

Non-blocking (JAX multi-task LR schedule): this learning_rate=lambda step: self.lr.value(self.start_step + step) was correct for the pre-refactor single-task trainer (one optimizer, optimizer.update called once per global step, so optax's internal schedule count == global step). Inside this per-task loop it no longer holds: each task gets its own nnx.Optimizer, and optax drives the schedule by that optimizer's internal update count, which only advances when the task is sampled (~p_i * num_steps times). So in multi-task each task's LR decays too slowly and never reaches the configured minimum at num_steps.

It's also internally inconsistent: the loss-prefactor LR passed into train_step (self.lr.value(step), global step) and the lcurve LR (learning_rate(step), global step) both use the global step, while the optimizer's applied LR uses the per-task count — they diverge for every multi-task run. This also differs from pt/pt_expt, which advance the scheduler once per global step for all tasks.

Single-task is unaffected (counts coincide), and JAX multi-task is new here so it's not a regression — but it ships wrong. Suggest driving the optimizer schedule by the global step explicitly (so it matches the loss-prefactor/lcurve LR and the other backends) rather than relying on optax's per-optimizer internal count.
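
To make the drift concrete, here is a small standard-library simulation. The exponential schedule, the 0.7/0.3 task probabilities, and the helper names are assumptions for illustration, not the trainer's actual schedule or sampler. It only shows that a schedule indexed by a per-task update count ends at a higher learning rate than one indexed by the global step.

```python
import random


def exp_decay_lr(step, start_lr=1e-3, stop_lr=1e-5, num_steps=1000):
    """Hypothetical exponential schedule that reaches stop_lr at num_steps."""
    rate = (stop_lr / start_lr) ** (1.0 / num_steps)
    return start_lr * rate**step


def sample_task_counts(num_steps=1000, probs=(0.7, 0.3), seed=0):
    """Count how often each task is sampled, i.e. each optimizer's internal count."""
    rng = random.Random(seed)
    counts = [0] * len(probs)
    for _ in range(num_steps):
        task = rng.choices(range(len(probs)), weights=probs)[0]
        counts[task] += 1
    return counts


counts = sample_task_counts()
# LR the schedule yields when indexed by the global step at the end of training.
lr_at_global_end = exp_decay_lr(1000)
# LR each task would apply if the schedule were indexed by its own update count.
lr_at_task_end = [exp_decay_lr(count) for count in counts]
```

Each per-task count is below the global step, so every entry of `lr_at_task_end` stays above the configured minimum that `lr_at_global_end` reaches.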

Member Author

Fixed in njzjz/deepmd-kit@7280c146. The JAX optimizer now separates Adam moment scaling from LR scaling and applies the learning rate passed from the global training step, so multi-task tasks no longer decay by per-task optimizer counts. Added a focused optax test for explicit global LR scaling.
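
The idea of separating moment scaling from LR scaling can be sketched in plain Python on a single scalar parameter. This is not the code from the commit and does not use optax; `AdamDirection` and `apply_update` are hypothetical names. The moment tracker keeps its own count only for bias correction, while the learning rate is supplied by the caller from the global step.

```python
import math


class AdamDirection:
    """Tracks Adam moments and returns a normalized direction, with no LR applied."""

    def __init__(self, b1=0.9, b2=0.999, eps=1e-8):
        self.b1, self.b2, self.eps = b1, b2, eps
        self.count = 0  # used only for bias correction, never for the LR schedule
        self.m = 0.0
        self.v = 0.0

    def direction(self, grad):
        self.count += 1
        self.m = self.b1 * self.m + (1 - self.b1) * grad
        self.v = self.b2 * self.v + (1 - self.b2) * grad * grad
        m_hat = self.m / (1 - self.b1**self.count)
        v_hat = self.v / (1 - self.b2**self.count)
        return m_hat / (math.sqrt(v_hat) + self.eps)


def apply_update(param, grad, state, learning_rate):
    """Scale the Adam direction by an LR chosen by the caller (e.g. from the global step)."""
    return param - learning_rate * state.direction(grad)


state = AdamDirection()
new_param = apply_update(1.0, 0.5, state, learning_rate=0.1)
```

With this split, two tasks that are sampled at different rates still apply the same learning rate at a given global step, because the rate never depends on `state.count`.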

Comment on lines +126 to +127
if hasattr(training_data, "get_nsystems"):
    return int(training_data.get_nsystems())

Collaborator

Non-blocking (test coverage): the get_nsystems branch here is the real-data path (e.g. the wrapped training-data systems), but the new test_training_utils.py only constructs objects exercising the Sized/len() and unsized-fallback branches — none expose get_nsystems, so this first branch is never run. The project convention is to cover every reachable branch of a new helper; a small test passing an object with a get_nsystems() method would close it.

Member Author

Fixed in njzjz/deepmd-kit@7280c146. Added coverage for the get_nsystems() branch and verified that it takes precedence over len for model probability defaults.
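
A sketch of what such a coverage test might look like, assuming a helper with the three branches discussed in this thread. `training_data_size` and `FakeSystems` are hypothetical stand-ins, not the actual `_training_data_size` implementation or the test added in the commit.

```python
from collections.abc import Sized


def training_data_size(training_data) -> int:
    """Hypothetical stand-in for the size helper discussed above."""
    # Real-data path: wrapped training systems report their own count.
    if hasattr(training_data, "get_nsystems"):
        return int(training_data.get_nsystems())
    # Sized objects use len(); an exception raised by a broken __len__ propagates.
    if isinstance(training_data, Sized):
        return len(training_data)
    # Objects with no notion of size count as a single system.
    return 1


class FakeSystems:
    """Test double exposing both get_nsystems() and __len__."""

    def get_nsystems(self) -> int:
        return 3

    def __len__(self) -> int:
        return 7


size_with_get_nsystems = training_data_size(FakeSystems())
size_of_list = training_data_size(["a", "b"])
size_of_unsized = training_data_size(object())
```

Giving the double both methods with different return values is what lets a single assertion confirm that `get_nsystems()` wins over `len()`.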

def _ordered_task_keys(self, results: Mapping[str, Any]) -> list[str]:
    keys = self.task_keys or list(results)
    return [key for key in keys if key in results]

Collaborator

Non-blocking (test coverage): is_chief is the core of the multi-rank abstraction — it gates lcurve open, display, periodic save, and final checkpoint (the is_chief uses below in this loop). But every test in test_train_abstract_trainer.py uses the default RankContext() (rank 0, chief), so the is_chief == False branch — i.e. that a non-chief rank skips all of those — is never exercised. Worth a test that constructs a RankContext(rank=1, world_size=2) and asserts the non-chief rank writes no lcurve/checkpoint.

Member Author

Fixed in njzjz/deepmd-kit@7280c146. Added a non-chief RankContext(rank=1, world_size=2) test that still runs optimizer steps but writes no lcurve and saves no checkpoints.
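
The behaviour under test can be sketched with a toy loop. The `RankContext` below is a simplified assumption (chief means rank 0) and `run_steps` is a hypothetical stand-in for the trainer loop; only the `RankContext(rank=1, world_size=2)` call and the `is_chief` gate come from the discussion.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RankContext:
    """Simplified rank description; assumes rank 0 is the chief."""

    rank: int = 0
    world_size: int = 1

    @property
    def is_chief(self) -> bool:
        return self.rank == 0


def run_steps(rank_context: RankContext, num_steps: int, save_freq: int):
    """Toy loop: every rank steps the optimizer, only the chief records saves."""
    optimizer_steps = 0
    saved_steps = []
    for display_step in range(1, num_steps + 1):
        optimizer_steps += 1  # all ranks take part in the update
        if (
            rank_context.is_chief
            and save_freq > 0
            and display_step % save_freq == 0
        ):
            saved_steps.append(display_step)
    return optimizer_steps, saved_steps


chief_result = run_steps(RankContext(), num_steps=4, save_freq=2)
worker_result = run_steps(RankContext(rank=1, world_size=2), num_steps=4, save_freq=2)
```

The non-chief rank performs the same number of optimizer steps but produces an empty list of saves, which is the property the added test asserts.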

@njzjz njzjz force-pushed the feat/dpmodel-abstract-trainer-5229 branch from 5767485 to 7280c14 Compare June 30, 2026 05:21
@njzjz njzjz requested a review from wanghan-iapcm June 30, 2026 11:27
Comment thread deepmd/jax/train/trainer.py Outdated
Comment on lines +143 to +147
self.models = self._deserialize_models(checkpoint_data)
self.model_def_script = checkpoint_data["model_def_script"]
self.model_params_by_task = self._model_params_by_task(
    self.model_def_script
)

Member

The --init-model path may currently ignore --use-pretrain-script=False.

The entrypoint only replaces config["model"] from the checkpoint when use_pretrain_script is true, but DPTrainer.__init__() later unconditionally overwrites self.model_def_script with checkpoint_data["model_def_script"] for both init_model and restart. So an init_model run without --use-pretrain-script can still save checkpoints using the init checkpoint’s model script instead of the input config.

Could we keep jdata["model"] for init_model unless use_pretrain_script was explicitly requested, and only force checkpoint metadata for restart?

Member Author

Addressed in 7632feff4: JAX init_model now preserves the input jdata["model"] / model_def_script when use_pretrain_script is false. Only restart forces checkpoint metadata and current step. I also added focused tests for both the init-model and restart paths.
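
The resulting rule can be summarized as a small decision function. This is a sketch of the behaviour described in this thread, not the trainer code; `select_model_def_script` and its `mode` argument are hypothetical names.

```python
def select_model_def_script(
    input_script: dict,
    checkpoint_script: dict,
    *,
    mode: str,
    use_pretrain_script: bool = False,
) -> dict:
    """Pick which model script a run should carry into its own checkpoints.

    `mode` is one of "scratch", "init_model", or "restart" (assumed labels).
    """
    if mode == "restart":
        # A restart continues the old run, so checkpoint metadata is authoritative.
        return checkpoint_script
    if mode == "init_model" and use_pretrain_script:
        # Explicit opt-in: adopt the script stored in the initial checkpoint.
        return checkpoint_script
    # Otherwise the user's input configuration is kept.
    return input_script


user_script = {"type_map": ["O", "H"], "source": "input"}
ckpt_script = {"type_map": ["O", "H"], "source": "checkpoint"}

kept = select_model_def_script(user_script, ckpt_script, mode="init_model")
adopted = select_model_def_script(
    user_script, ckpt_script, mode="init_model", use_pretrain_script=True
)
restarted = select_model_def_script(user_script, ckpt_script, mode="restart")
```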

Comment on lines +524 to +532
last_log_time = current_time
last_log_step = display_step

if (
    self.rank_context.is_chief
    and self.trainer_config.save_freq > 0
    and display_step % self.trainer_config.save_freq == 0
):
    self.save_checkpoint(display_step)

Member

Member Author

Addressed in 7632feff4: full validation is now wired at the trainer level. AbstractTrainer has a run_full_validation() hook, shared logging/top-k/best-checkpoint bookkeeping lives in deepmd.dpmodel.train.validation, and pt_expt plus JAX provide backend-specific validators/checkpoint writers. This avoids adding a model capability API for validation.

Copilot AI review requested due to automatic review settings July 1, 2026 05:52
@njzjz

njzjz commented Jul 1, 2026

Member Author

Pushed 7632feff4 with the requested updates.

Changes:

  • Fixed JAX init_model so it preserves the input model script when use_pretrain_script=False; only restart adopts checkpoint metadata/current step.
  • Added trainer-level full validation support for pt_expt and JAX, with shared logging/top-k/best-checkpoint bookkeeping in deepmd.dpmodel.train.validation.
  • Kept full validation out of model capability APIs.
  • Fixed pt_expt full-validation force evaluation by enabling gradients on validation coordinate inputs.

Validation:

  • ruff format .
  • ruff check .
  • git diff --check
  • source/tests/common/dpmodel/test_train_abstract_trainer.py
  • source/tests/pt/test_validation.py
  • source/tests/pt_expt/test_training.py::TestTraining::test_full_validation_loop
  • JAX full-validation tests plus TestJAXTraining::test_train_entrypoint_runs_one_step_from_scratch
  • Live GPU CLI: srun --gres=gpu:1 dp --pt-expt train input.json --skip-neighbor-stat
  • Live GPU CLI: srun --gres=gpu:1 dp --jax train input.json --skip-neighbor-stat

Both live GPU CLI runs produced val.log and best checkpoint artifacts.

@njzjz njzjz requested a review from iProzd July 1, 2026 05:54
@njzjz

njzjz commented Jul 1, 2026

Member Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented Jul 1, 2026

Contributor
✅ Action performed

Review finished.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/train/trainer.py`:
- Around line 230-236: The mapping branch in trainer.py under the probability
handling logic currently validates missing keys but lets extra task probability
keys slip through. Update the probabilities check in the Mapping path to reject
unknown keys as well as missing ones, using self._keys as the only allowed set
before building the numpy array. Keep the fix localized to the probability
normalization/validation block so stale or mistyped entries raise a ValueError
instead of being silently ignored.

In `@deepmd/dpmodel/train/validation.py`:
- Around line 282-326: The validation flow in run() only raises errors on rank
0, so peer JAX processes can keep running after a rank-0 failure. Add a
backend-specific error propagation hook in the validator path around
_raise_if_error and/or the rank-0 try blocks, then implement the JAX version
used by the trainer to broadcast the failure or terminate all processes
consistently. Make sure the hook is invoked for failures in _evaluate,
save_checkpoint/_reconcile_best_checkpoints, and _log_result so non-zero ranks
do not continue silently.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 9857057b-3e07-4295-8a4a-6f044d0b4b1e

📥 Commits

Reviewing files that changed from the base of the PR and between 5bebfdf and 7632fef.

📒 Files selected for processing (28)
  • deepmd/dpmodel/train/__init__.py
  • deepmd/dpmodel/train/data.py
  • deepmd/dpmodel/train/entrypoint.py
  • deepmd/dpmodel/train/trainer.py
  • deepmd/dpmodel/train/validation.py
  • deepmd/dpmodel/utils/training_utils.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/train/trainer.py
  • deepmd/jax/train/validation.py
  • deepmd/jax/utils/finetune.py
  • deepmd/jax/utils/serialization.py
  • deepmd/pd/utils/finetune.py
  • deepmd/pt/train/validation.py
  • deepmd/pt/utils/finetune.py
  • deepmd/pt_expt/entrypoints/main.py
  • deepmd/pt_expt/train/training.py
  • deepmd/pt_expt/utils/finetune.py
  • deepmd/utils/argcheck.py
  • deepmd/utils/finetune.py
  • source/tests/common/dpmodel/test_train_abstract_trainer.py
  • source/tests/common/dpmodel/test_train_data.py
  • source/tests/common/dpmodel/test_train_entrypoint.py
  • source/tests/common/dpmodel/test_training_utils.py
  • source/tests/common/test_finetune_utils.py
  • source/tests/jax/test_training.py
  • source/tests/pt/test_validation.py
  • source/tests/pt_expt/test_entrypoint.py
  • source/tests/pt_expt/test_training.py
💤 Files with no reviewable changes (9)
  • source/tests/common/dpmodel/test_train_data.py
  • source/tests/common/dpmodel/test_training_utils.py
  • source/tests/pt/test_validation.py
  • source/tests/pt_expt/test_training.py
  • source/tests/common/dpmodel/test_train_entrypoint.py
  • source/tests/common/dpmodel/test_train_abstract_trainer.py
  • source/tests/pt_expt/test_entrypoint.py
  • source/tests/common/test_finetune_utils.py
  • source/tests/jax/test_training.py
✅ Files skipped from review due to trivial changes (1)
  • deepmd/utils/argcheck.py
🚧 Files skipped from review as they are similar to previous changes (12)
  • deepmd/dpmodel/utils/training_utils.py
  • deepmd/pt/utils/finetune.py
  • deepmd/dpmodel/train/__init__.py
  • deepmd/pd/utils/finetune.py
  • deepmd/pt_expt/utils/finetune.py
  • deepmd/dpmodel/train/data.py
  • deepmd/jax/utils/serialization.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/utils/finetune.py
  • deepmd/jax/train/trainer.py
  • deepmd/pt_expt/train/training.py
  • deepmd/pt_expt/entrypoints/main.py

Comment on lines +230 to +236
elif isinstance(probabilities, Mapping):
    missing = [key for key in self._keys if key not in probabilities]
    if missing:
        raise ValueError(f"Missing task probabilities for {missing}.")
    prob = np.asarray(
        [probabilities[key] for key in self._keys], dtype=np.float64
    )

Contributor

🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

Reject unknown probability keys.

The mapping path checks missing task probabilities but silently ignores extra keys, so stale or mistyped task weights can pass validation unnoticed.

Suggested fix
         elif isinstance(probabilities, Mapping):
             missing = [key for key in self._keys if key not in probabilities]
             if missing:
                 raise ValueError(f"Missing task probabilities for {missing}.")
+            unknown = [key for key in probabilities if key not in self._tasks]
+            if unknown:
+                raise ValueError(f"Unknown task probabilities for {unknown}.")
             prob = np.asarray(
                 [probabilities[key] for key in self._keys], dtype=np.float64
             )
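
For reference, a standalone sketch of the same check outside the trainer class, using plain lists instead of NumPy. `resolve_task_probabilities` is a hypothetical function, and the normalization step is an assumption about how the weights are used afterwards.

```python
from collections.abc import Mapping


def resolve_task_probabilities(task_keys, probabilities):
    """Validate a task-to-weight mapping and return normalized weights in task order."""
    if not isinstance(probabilities, Mapping):
        raise TypeError("probabilities must be a mapping from task key to weight")
    missing = [key for key in task_keys if key not in probabilities]
    if missing:
        raise ValueError(f"Missing task probabilities for {missing}.")
    # Reject stale or mistyped entries instead of silently ignoring them.
    unknown = [key for key in probabilities if key not in task_keys]
    if unknown:
        raise ValueError(f"Unknown task probabilities for {unknown}.")
    weights = [float(probabilities[key]) for key in task_keys]
    total = sum(weights)
    if total <= 0:
        raise ValueError("Task probabilities must sum to a positive value.")
    return [weight / total for weight in weights]


normalized = resolve_task_probabilities(["water", "metal"], {"water": 1, "metal": 3})
```

A mapping such as `{"water": 1, "metl": 3}` then fails on the missing `metal` key, and `{"water": 1, "metal": 3, "old_task": 1}` fails on the unknown key rather than training with unintended weights.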

Comment on lines +282 to +326
if self.rank == 0:
try:
result = self._evaluate(display_step)
save_path = result.saved_best_path
except Exception as exc:
caught_exception = exc
error_message = (
"Full validation failed during evaluation:\n"
f"{traceback.format_exc()}"
)

self._raise_if_error(error_message, caught_exception)

if save_path is not None and self.rank == 0:
try:
save_checkpoint(Path(save_path), lr=lr, step=step_id)
self._reconcile_best_checkpoints()
except Exception as exc:
caught_exception = exc
error_message = (
"Full validation failed while saving the best checkpoint:\n"
f"{traceback.format_exc()}"
)
else:
error_message = None
caught_exception = None

self._raise_if_error(error_message, caught_exception)

if self.rank == 0:
try:
self._log_result(result)
except Exception as exc:
caught_exception = exc
error_message = (
"Full validation failed while writing logs:\n"
f"{traceback.format_exc()}"
)
else:
error_message = None
caught_exception = None

self._raise_if_error(error_message, caught_exception)

return result if self.rank == 0 else None

Contributor

🩺 Stability & Availability | 🟠 Major | 🏗️ Heavy lift

Propagate rank-0 validation failures to peer JAX processes.

run() catches and raises validation/save/log errors only on rank 0; other ranks call _raise_if_error(None, None) and continue. Since the JAX trainer constructs this validator with rank=int(jax.process_index()), a rank-0 failure can leave peer processes running after validation aborts. Please add an overridable backend error-propagation hook and implement it for JAX, or otherwise terminate all peers consistently.

🧰 Tools
🪛 Ruff (0.15.20)

[warning] 286-286: Do not catch blind exception: Exception (BLE001)
[warning] 299-299: Do not catch blind exception: Exception (BLE001)
[warning] 314-314: Do not catch blind exception: Exception (BLE001)


Copilot AI left a comment

Contributor

Pull request overview

This PR introduces a backend-independent training abstraction layer under deepmd.dpmodel.train, and migrates JAX and pt_expt training/finetune flows to reuse shared orchestration (tasks/ranks normalization, learning-curve output, checkpoint cadence, lifecycle hooks, and full-validation best-checkpoint management). It also consolidates fine-tuning rule generation into a single backend-agnostic builder in deepmd.utils.finetune, leaving backend modules to focus on checkpoint loading and state/tensor copying.

Changes:

  • Add backend-independent trainer/entrypoint/data/full-validation primitives in deepmd.dpmodel.train, with associated test coverage moved/added under source/tests/common/dpmodel/.
  • Centralize fine-tuning rule building in deepmd.utils.finetune and refactor PT/PT-exportable/Paddle/JAX finetune helpers to delegate to the shared builder.
  • Extend full-validation best-checkpoint management to support both file checkpoints (.pt) and directory checkpoints (.jax), and migrate JAX + pt_expt training entrypoints onto the shared pipeline.

Reviewed changes

Copilot reviewed 28 out of 28 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| `source/tests/pt/test_validation.py` | Adds coverage for best-checkpoint reconciliation when checkpoints are directories (e.g., .jax). |
| `source/tests/pt_expt/test_training.py` | Adds an integration-style test asserting pt_expt full validation writes/prunes best checkpoints and logs. |
| `source/tests/pt_expt/test_entrypoint.py` | New tests for pt_expt entrypoint option normalization, process-group ownership, checkpoint links/retention, and stat-file creation. |
| `source/tests/jax/test_training.py` | Adds tests for JAX optimizer LR scaling, finetune loading, state/key normalization, and entrypoint gating/multitask behavior. |
| `source/tests/common/test_finetune_utils.py` | Adds extensive tests for the shared finetune rule builder (single/multitask, aliases, random fitting, immutability). |
| `source/tests/common/dpmodel/test_training_utils.py` | New tests for model probability resolution and size fallbacks. |
| `source/tests/common/dpmodel/test_train_entrypoint.py` | New tests validating the shared entrypoint pipeline sequencing and teardown behavior on failures. |
| `source/tests/common/dpmodel/test_train_data.py` | New tests for shared data-summary printing compatibility/failure propagation. |
| `source/tests/common/dpmodel/test_train_abstract_trainer.py` | New tests for the backend-independent trainer loop, multitask sampling, lcurve formatting, checkpoint cadence, and full-validation ordering. |
| `deepmd/utils/finetune.py` | Introduces FinetuneRuleBuilder and shared finetune-rule construction APIs. |
| `deepmd/utils/argcheck.py` | Updates full-validation docs to indicate support across PT, pt_expt, and JAX. |
| `deepmd/pt/utils/finetune.py` | Refactors PyTorch finetune rules to delegate to shared finetune rule builder. |
| `deepmd/pt/train/validation.py` | Extends best-checkpoint handling to support configurable suffixes and directory checkpoints; improves validation-data iteration. |
| `deepmd/pt_expt/utils/finetune.py` | Refactors pt_expt finetune rules to delegate to shared finetune rule builder with backend-specific errors. |
| `deepmd/pt_expt/train/training.py` | Migrates pt_expt training loop to AbstractTrainer, adds task normalization, full validation hook, checkpoint retention, and relative latest-link handling. |
| `deepmd/pt_expt/entrypoints/main.py` | Migrates pt_expt train entrypoint to the shared entrypoint pipeline and shared data helpers; adds stat-file creation helper. |
| `deepmd/pd/utils/finetune.py` | Refactors Paddle finetune rules to delegate to shared finetune rule builder. |
| `deepmd/jax/utils/serialization.py` | Updates JAX serialization to support multitask state layout and avoid mis-normalizing numeric-looking task keys. |
| `deepmd/jax/utils/finetune.py` | New JAX finetune helper delegating to shared finetune rule builder. |
| `deepmd/jax/train/validation.py` | New JAX full-validation implementation built on backend-independent FullValidatorBase. |
| `deepmd/jax/train/trainer.py` | Migrates JAX trainer to AbstractTrainer, adds multitask support, shared finetune integration, full validation hook, and checkpoint writing changes. |
| `deepmd/jax/entrypoints/train.py` | Migrates JAX train entrypoint to the shared entrypoint pipeline and shared data helpers; adds multitask-aware neighbor-stat update. |
| `deepmd/dpmodel/utils/training_utils.py` | Improves model-probability resolution by handling get_nsystems() and non-sized data sources. |
| `deepmd/dpmodel/train/validation.py` | New backend-independent full-validation base with best-checkpoint management and val.log formatting. |
| `deepmd/dpmodel/train/trainer.py` | New backend-independent trainer abstraction handling task selection, display scheduling, lcurve writing, and checkpoint cadence. |
| `deepmd/dpmodel/train/entrypoint.py` | New backend-independent entrypoint orchestration pipeline used by backend-specific entrypoints. |
| `deepmd/dpmodel/train/data.py` | New shared data/task config normalization utilities and summary-printing compatibility helper. |
| `deepmd/dpmodel/train/__init__.py` | Exposes the new backend-independent training abstractions as a package API. |


Comment on lines 492 to 496
self._lmdb_test_data = LmdbTestData(
lmdb_dataset.lmdb_path,
type_map=list(lmdb_dataset.type_map),
lmdb_path,
type_map=list(type_map or []),
shuffle_test=False,
)
) -> tuple[Any, optax.EmptyState]:
del params
learning_rate = kwargs["learning_rate"]
updates = jax.tree.map(lambda update: -learning_rate * update, updates)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Abstract PyTorch Exportable Training Code into dpmodel

4 participants