Skip to content

feat(dpmodel): graph-native se_atten attention (NeighborGraph PR-D)#5715

Open
wanghan-iapcm wants to merge 10 commits into
deepmodeling:masterfrom
wanghan-iapcm:feat-graph-attn-prD
Open

feat(dpmodel): graph-native se_atten attention (NeighborGraph PR-D)#5715
wanghan-iapcm wants to merge 10 commits into
deepmodeling:masterfrom
wanghan-iapcm:feat-graph-attn-prD

Conversation

@wanghan-iapcm

@wanghan-iapcm wanghan-iapcm commented Jul 2, 2026

Copy link
Copy Markdown
Collaborator

Implements NeighborGraph PR-D: the graph path now supports attn_layer > 0 for dpa1/se_atten, removing the attn_layer=0-only restriction shipped in #5583.

What

  • Segment toolkit: segment_max + numerically-stable, mask-aware segment_softmax (deepmd/dpmodel/utils/neighbor_graph/segment.py), built on the existing xp_maximum_at.
  • center_edge_pairs (neighbor_graph/pairs.py): pairs of edges sharing a center — the edge-pair axis shared with the upcoming angle machinery (PR-E). Segment-based enumeration (a global (E,E) boolean is deliberately avoided: O(N²·nnei²) memory). Two forms: compact eager (dynamic P, carry-all graphs) and shape-static (P = n_center·nnei², pure arange/reshape arithmetic, no nonzero) for the center-major static layout — this keeps the traced/compiled/export path traceable.
  • DescrptBlockSeAtten._graph_attention: op-for-op ragged mirror of GatedAttentionLayer/NeighborGatedAttention — per-center q@kᵀ becomes per-pair q_m·k_n, softmax over keys becomes segment_softmax grouped by the query edge; head_dim QKV slicing, q/k/v normalize, temperature/scaling, smooth shift trick, post-softmax sw and dotr weighting, residual + LayerNorm per layer.
  • edge_env_mat(return_sw=True) exposes the per-edge switch (zeroed on padding) for the smooth branch.
  • uses_graph_lower widened: attention configs (concat tebd, no exclude_types) are now graph-eligible — pt_expt eager/compiled/exported paths route them through the graph lower by default.

Numerical semantics (reviewed decision)

  • Shape-static adapter path (the dense call adapter, from_dense_quartet(compact=False) + static_nnei): bit-exact vs the dense body, rtol 1e-12, full flag matrix (attn_layer 1/2 × dotr × smooth × normalize × temperature, binding AND non-binding sel).
  • Carry-all graphs: exact for non-smooth attention. For smooth_type_embedding=True, the dense branch keeps sel-padding slots in the attention softmax denominator (weight exp(-attnw_shift)), which makes the dense output depend on sel itself (measured up to ~1e-4 with an identical physical neighbor set). The carry-all form drops those phantom terms by design — the sel-independent math. Pinned by a clean-divergence test; route-equivalence fixtures pin smooth_type_embedding=False.
  • se_atten_v2 (tebd_input_mode="strip") remains graph-ineligible (strip mode is a later PR) — pinned by test.

Testing

  • 38 new dpmodel tests (segment toolkit, pairs incl. random-vs-oracle + static-vs-compact equality, attention parity matrix, binding-sel divergence sanity).
  • pt_expt: test_make_fx_graph_attn (graph forward + autograd at attn_layer=2 traces under make_fx, both smooth branches — required since compiled training uses the graph lower); model-level graph-vs-legacy force/virial/atom-virial parity parametrized over attn_layer {0,2}.
  • Local CPU: common/dpmodel 583, consistent dpa1+se_atten_v2 209, pt_expt descriptor/model/utils 701 (2 failures: dpa4 export inductor error pre-existing on upstream/master, and a route-parity fixture fixed in-branch).
  • GPU-validated (Tesla T4, cuda:0): dpmodel suites 38, pt_expt graph-lower/make_fx/consistency 44 (CUDA 1e-10), route-parity 6, attention AOTI export pipeline + dpa1 cross-backend consistency 105 — all passed.

Known limitations

  • Strip-mode (se_atten_v2) attention stays on the dense path.
  • Carry-all smooth attention diverges from dense by design (see above); old behavior reachable via neighbor_graph_method="legacy" / explicit World-1 builders.
  • num_heads == 1 assumed (dpa1 never exposes num_heads); fail-fast otherwise.
  • Compact center_edge_pairs is eager-only (nonzero); traced paths use the shape-static form.
  • 3-body angles (PR-E), jax graph force (PR-F), dpa2/3 MP (PR-G) unchanged.

Summary by CodeRabbit

  • New Features

    • Expanded graph-native attention support for more descriptor attention configurations, including transformer-style graph execution during tracing/export.
    • Added new neighbor-graph toolkit primitives for center-based edge pairing and segment reduction (max/softmax), plus enhanced environment-matrix smoothing switch handling.
    • Added dynamic shape hinting to improve tracing/export stability.
  • Bug Fixes

    • Improved shape-static graph routing and attention behavior for padding/empty-edge cases, improving reliability and parity with dense reference paths.
  • Tests

    • Added coverage for graph attention parity, eligibility, FX traceability, and single-atom/no-edge export scenarios.

Han Wang added 6 commits July 3, 2026 00:10
…ftmax

Built on the existing xp_maximum_at (no new array_api helper needed).
Part of NeighborGraph PR-D (graph-native attention).
Segment-based (global (E,E) boolean deliberately avoided): compact eager
form for carry-all graphs + shape-static nonzero-free form for the
center-major static layout (jit/export/make_fx traceable).
Part of NeighborGraph PR-D; PR-E angles reuse (unordered, no-self).
…r > 0)

DescrptBlockSeAtten.call_graph grows _graph_attention: the dense per-center
(nnei, nnei) attention square becomes the edge-pair axis (center_edge_pairs,
ordered + self-included), softmax over keys becomes segment_softmax grouped
by the query edge. Op-for-op mirror of GatedAttentionLayer.call (head_dim
QKV slicing, normalize q/k/v, temperature/scaling, smooth shift trick,
post-softmax sw and dotr weighting, residual + LayerNorm per layer).

- shape-static adapter path (static_nnei threaded from the dense call
  adapter): bit-exact vs the dense body, rtol 1e-12, full flag matrix
  (attn_layer 1/2 x dotr x smooth x normalize x temperature, binding and
  non-binding sel).
- carry-all (compact) graphs: exact for non-smooth; for smooth the dense
  branch keeps sel-padding slots in the softmax denominator (dense output is
  sel-DEPENDENT, up to ~1e-4) — the carry-all form drops those phantom terms
  by design (user decision 2026-07-03), pinned by a clean-divergence test.
- edge_env_mat(return_sw=True) exposes the per-edge switch (zeroed on
  padding) for the smooth branch.
- uses_graph_lower: attention configs are now graph-eligible (concat tebd,
  no exclude_types still required).
…ial parity

- test_make_fx_graph_attn: graph forward + autograd.grad at attn_layer=2
  traces under make_fx for BOTH smooth branches (the shape-static
  center_edge_pairs form is nonzero-free) — required since pt_expt compiled
  training routes eligible models through the graph lower.
- model-level graph-vs-legacy lower parity now parametrized over
  attn_layer {0, 2} (energy/force/virial/atom_virial, 1e-12 CPU).
- eligibility pins: attention+concat is graph-eligible; se_atten_v2
  (tebd_input_mode='strip') correctly stays dense (strip = later PR;
  the plan's 'se_atten_v2 inherits for free' did not hold).
- linear-model weight tests: pin smooth_type_embedding=False — the standard
  (graph-routed, carry-all) and linear (graph-ineligible, dense) submodels
  otherwise differ by the accepted smooth-attention denominator divergence
  (~1e-6), which is a route artifact, not a weight-combination bug.
- new binding-sel sanity: carry-all graph attention diverges from the
  sel-truncated dense path when sel binds (spec decision deepmodeling#17).
…rity)

neighbor_list=None now takes the carry-all graph default for eligible
attention models; explicit World-1 builders take the legacy dense route.
With smooth attention the two routes differ by design (PR-D), so the
route-equivalence tests pin smooth_type_embedding=False.
@dosubot dosubot Bot added the new feature label Jul 2, 2026
@github-actions github-actions Bot added the Python label Jul 2, 2026
@coderabbitai

coderabbitai Bot commented Jul 2, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 8a8a06d4-3efe-4b32-82c8-399f039ceb31

📥 Commits

Reviewing files that changed from the base of the PR and between edc2fca and 25d0161.

📒 Files selected for processing (2)
  • deepmd/dpmodel/descriptor/dpa1.py
  • source/tests/pt_expt/model/test_dpa1_graph_lower.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • source/tests/pt_expt/model/test_dpa1_graph_lower.py
  • deepmd/dpmodel/descriptor/dpa1.py

📝 Walkthrough

Walkthrough

DPA1 graph-native lowering now supports attention configurations with static neighbor-pair enumeration, new segment reduction helpers, and expanded graph/export test coverage.

Changes

Graph-native attention support

Layer / File(s) Summary
Segment max/softmax reduction primitives
deepmd/dpmodel/utils/neighbor_graph/segment.py, deepmd/dpmodel/utils/neighbor_graph/__init__.py, source/tests/common/dpmodel/test_segment_softmax.py
Adds segment_max and segment_softmax, exports them, and covers reduction behavior, masking, stability, and NumPy/Torch parity.
Center edge-pair enumeration and edge switches
deepmd/dpmodel/utils/neighbor_graph/pairs.py, deepmd/dpmodel/utils/neighbor_graph/env.py, deepmd/dpmodel/utils/neighbor_graph/__init__.py, source/tests/common/dpmodel/test_center_edge_pairs.py, deepmd/dpmodel/array_api.py
Adds center_edge_pairs with compact and shape-static paths, extends edge_env_mat to optionally return smooth switches, and adds a dynamic-size trace hint used by the graph path.
DPA1 graph-native attention forward
deepmd/dpmodel/descriptor/dpa1.py
Expands graph eligibility, threads static_nnei through call_graph, requests smooth-switch values, and computes graph-native attention with center-edge pairs and segment softmax.
Parity, FX, and export coverage
source/tests/common/dpmodel/test_dpa1_graph_attention_parity.py, source/tests/common/dpmodel/test_dpa1_call_graph_block.py, source/tests/pt_expt/descriptor/test_dpa1.py, source/tests/pt_expt/model/test_dpa1_graph_lower.py, source/tests/pt_expt/model/test_linear_model.py, source/tests/pt_expt/utils/test_neighbor_list.py, source/tests/pt_expt/infer/test_graph_deepeval.py, deepmd/pt_expt/entrypoints/main.py
Adds dense-vs-graph parity, FX-trace, export, and PT2 inference coverage, removes the old fail-fast assumption, and updates supporting graph-export text and smooth-type settings.

Estimated code review effort: 4 (Complex) | ~60 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5583: Both PRs extend the same graph-native DPA1 lowering path with attention-related behavior and new neighbor-graph utilities.
  • deepmodeling/deepmd-kit#5604: Both PRs modify deepmd/dpmodel/descriptor/dpa1.py’s graph call path and related graph-export behavior.

Suggested labels: enhancement

Suggested reviewers: OutisLi, iProzd

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly summarizes the main change: adding graph-native attention support for DPA1 se_atten in dpmodel.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (4)
source/tests/common/dpmodel/test_segment_softmax.py (1)

55-65: 🎯 Functional Correctness | 🔵 Trivial | ⚡ Quick win

Add a regression test for masked-entry-larger-than-max.

None of the mask tests here cover a masked entry whose value exceeds the unmasked max in the same segment — the scenario that triggers the NaN-propagation issue flagged in segment.py. Once that's fixed, a test like the one below would guard the regression:

def test_masked_entry_extreme_value_no_nan(self) -> None:
    logits = np.array([1.0, 1e30, 2.0])  # masked entry (idx 1) dwarfs the max
    ids = np.array([0, 0, 0], dtype=np.int64)
    mask = np.array([True, False, True])
    w = segment_softmax(logits, ids, 1, mask=mask)
    assert not np.any(np.isnan(w))
    assert w[1] == 0.0
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/common/dpmodel/test_segment_softmax.py` around lines 55 - 65,
Add a regression test in test_segment_softmax for the
masked-entry-larger-than-max case that currently leads to NaN propagation in
segment_softmax. Extend the existing mask coverage by creating a segment where
the masked element has an extreme value above the unmasked max, then assert the
result contains no NaNs, the masked position is exactly zero, and the unmasked
weights still normalize correctly. Use the existing segment_softmax test pattern
in test_masked_entries_zero to keep the new case consistent.
deepmd/dpmodel/utils/neighbor_graph/pairs.py (1)

92-117: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

dst values are unused in the shape-static path (only its shape matters).

_pairs_shape_static derives query/key edges purely from index arithmetic assuming the center-major layout documented in the module docstring; the actual dst values are never consulted to validate that assumption. This matches the documented contract, but if a caller ever passes a dst/static_nnei combination that doesn't match the assumed layout, this silently produces wrong pairs with no diagnostic. Consider a lightweight assertion (e.g., e_tot % nn == 0) or a debug-mode check that dst is actually constant within each block, to fail fast on a layout mismatch instead of silently mis-pairing.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/utils/neighbor_graph/pairs.py` around lines 92 - 117, The
shape-static path in `_pairs_shape_static` relies on center-major block layout
but never validates that `dst` actually matches that assumption, so a mismatched
`static_nnei`/layout can silently produce wrong pairs. Add a lightweight guard
in `_pairs_shape_static` to fail fast on layout mismatches, such as verifying
`e_tot % nn == 0` and/or checking that `dst` is constant within each `nn` block
in a debug-friendly way. Keep the existing index-arithmetic logic for
`query_edge`, `key_edge`, and `pair_mask`, but ensure the contract is enforced
before returning.
deepmd/dpmodel/descriptor/dpa1.py (2)

1671-1684: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

"Bit-exact" claim needs a caveat for the default (smooth + compact) configuration.

The docstring states this is a "Bit-exact analogue of call" and the "Known limitations" section only lists tebd_input_mode and exclude_types. But per test_block_compact_graph_smooth_clean_divergence in test_dpa1_graph_attention_parity.py, when static_nnei is None (the default, compact/carry-all form) and smooth=True (also the class default), the output deliberately diverges from dense (up to ~1e-4) by design — the carry-all graph drops phantom sel-padding softmax terms that dense keeps. A reader of this docstring/API surface would not learn about this without digging into the test suite. Since smooth_type_embedding defaults to True and static_nnei defaults to None, the "bit-exact" claim is misleading for the descriptor's own default configuration.

Suggest adding a short caveat to the "Known limitations" (or a new "Notes") section referencing this divergence, mirroring what's already documented in the test docstring.

📝 Suggested docstring addition
         Notes
         -----
         Known limitations:
         - ``tebd_input_mode == "concat"`` only (strip mode lands later);
         - ``exclude_types`` is not yet supported and raises (lands in a later PR).
+        - When ``attn_layer > 0``, ``smooth_type_embedding=True`` (the class
+          default) combined with the compact/carry-all form (``static_nnei=None``,
+          also the default) intentionally diverges from the dense reference
+          (up to ~1e-4): the carry-all graph has no sel-padding slots, so it
+          drops the phantom denominator terms the dense smooth branch keeps.
+          Bit-exact parity (1e-12) only holds on the shape-static form
+          (``static_nnei`` set, as used by the dense ``call`` adapter) or when
+          ``smooth_type_embedding=False``.
         """

Also applies to: 1712-1717

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/descriptor/dpa1.py` around lines 1671 - 1684, Update the
call_graph docstring in dpa1.py to add a caveat that the “bit-exact” claim does
not hold for the default smooth + compact/carry-all configuration: when
static_nnei is None and smooth=True, the graph path can intentionally diverge
slightly from dense because it omits phantom sel-padding softmax terms. Add this
to the existing “Known limitations” or a new “Notes” section, and keep the
wording consistent with the behavior exercised by
test_block_compact_graph_smooth_clean_divergence and the related call_graph
documentation block.

1856-1932: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Extract the shared attnw_shift default. GatedAttentionLayer.call also uses 20.0, so pulling this into a shared constant would keep the dense and graph paths aligned if that default ever changes.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/descriptor/dpa1.py` around lines 1856 - 1932, The hardcoded
attention shift value is duplicated in _graph_attention_one_layer and
GatedAttentionLayer.call, so pull the 20.0 default into a shared constant or
class attribute used by both paths. Update the graph attention logic to
reference that shared symbol so the dense and graph implementations stay aligned
if the default changes.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/utils/neighbor_graph/segment.py`:
- Around line 59-89: The masked path in segment_softmax is using raw data for
the exponent shift, which can turn masked large values into inf and then nan
after multiplying by the mask. Update segment_softmax to compute shifted from
data_for_max (the same masked-safe values used for seg_max), and keep the
existing empty/fully-masked guards so exp and denom stay finite. Check the
segment_max/segment_sum flow and the _graph_attention_one_layer caller to ensure
masked attention logits cannot leak into the denominator.

---

Nitpick comments:
In `@deepmd/dpmodel/descriptor/dpa1.py`:
- Around line 1671-1684: Update the call_graph docstring in dpa1.py to add a
caveat that the “bit-exact” claim does not hold for the default smooth +
compact/carry-all configuration: when static_nnei is None and smooth=True, the
graph path can intentionally diverge slightly from dense because it omits
phantom sel-padding softmax terms. Add this to the existing “Known limitations”
or a new “Notes” section, and keep the wording consistent with the behavior
exercised by test_block_compact_graph_smooth_clean_divergence and the related
call_graph documentation block.
- Around line 1856-1932: The hardcoded attention shift value is duplicated in
_graph_attention_one_layer and GatedAttentionLayer.call, so pull the 20.0
default into a shared constant or class attribute used by both paths. Update the
graph attention logic to reference that shared symbol so the dense and graph
implementations stay aligned if the default changes.

In `@deepmd/dpmodel/utils/neighbor_graph/pairs.py`:
- Around line 92-117: The shape-static path in `_pairs_shape_static` relies on
center-major block layout but never validates that `dst` actually matches that
assumption, so a mismatched `static_nnei`/layout can silently produce wrong
pairs. Add a lightweight guard in `_pairs_shape_static` to fail fast on layout
mismatches, such as verifying `e_tot % nn == 0` and/or checking that `dst` is
constant within each `nn` block in a debug-friendly way. Keep the existing
index-arithmetic logic for `query_edge`, `key_edge`, and `pair_mask`, but ensure
the contract is enforced before returning.

In `@source/tests/common/dpmodel/test_segment_softmax.py`:
- Around line 55-65: Add a regression test in test_segment_softmax for the
masked-entry-larger-than-max case that currently leads to NaN propagation in
segment_softmax. Extend the existing mask coverage by creating a segment where
the masked element has an extreme value above the unmasked max, then assert the
result contains no NaNs, the masked position is exactly zero, and the unmasked
weights still normalize correctly. Use the existing segment_softmax test pattern
in test_masked_entries_zero to keep the new case consistent.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 86eede99-fa33-4044-b859-5fe1eb620896

📥 Commits

Reviewing files that changed from the base of the PR and between 55d7e79 and 91784df.

📒 Files selected for processing (13)
  • deepmd/dpmodel/descriptor/dpa1.py
  • deepmd/dpmodel/utils/neighbor_graph/__init__.py
  • deepmd/dpmodel/utils/neighbor_graph/env.py
  • deepmd/dpmodel/utils/neighbor_graph/pairs.py
  • deepmd/dpmodel/utils/neighbor_graph/segment.py
  • source/tests/common/dpmodel/test_center_edge_pairs.py
  • source/tests/common/dpmodel/test_dpa1_call_graph_block.py
  • source/tests/common/dpmodel/test_dpa1_graph_attention_parity.py
  • source/tests/common/dpmodel/test_segment_softmax.py
  • source/tests/pt_expt/descriptor/test_dpa1.py
  • source/tests/pt_expt/model/test_dpa1_graph_lower.py
  • source/tests/pt_expt/model/test_linear_model.py
  • source/tests/pt_expt/utils/test_neighbor_list.py

Comment on lines +59 to +89
def segment_softmax(
data: Array,
segment_ids: Array,
num_segments: int,
mask: Array | None = None,
) -> Array:
"""Softmax over entries sharing a segment id, numerically stable.

Mirrors the dense ``np_softmax`` max-subtraction trick with a PER-SEGMENT
max. ``mask`` (bool, per entry) removes masked entries from the softmax
entirely (zero weight AND excluded from the denominator). Empty or
fully-masked segments produce all-zero weights (no NaN).
"""
xp = array_api_compat.array_namespace(data)
if mask is not None:
# keep masked entries out of the per-segment max: send them to -inf
neg = xp.full_like(data, -xp.inf)
data_for_max = xp.where(mask, data, neg)
else:
data_for_max = data
seg_max = segment_max(data_for_max, segment_ids, num_segments)
# guard -inf (empty / fully-masked segments) so gather doesn't yield inf-inf
seg_max = xp.where(xp.isinf(seg_max), xp.zeros_like(seg_max), seg_max)
shifted = data - xp.take(seg_max, segment_ids, axis=0)
ex = xp.exp(shifted)
if mask is not None:
ex = ex * xp.astype(mask, ex.dtype)
denom = segment_sum(ex, segment_ids, num_segments)
denom_e = xp.take(denom, segment_ids, axis=0)
safe = xp.where(denom_e > 0, denom_e, xp.ones_like(denom_e))
return ex / safe

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Masked-softmax can silently produce NaN for an entire segment.

shifted (line 82) is computed from the raw data, not data_for_max. If a masked entry's value is larger than the per-segment max of the unmasked entries, exp(shifted) overflows to +inf for that entry; the later ex * mask (line 85) then evaluates inf * 0.0 = nan. That NaN is summed into denom (line 86) and gathered back onto every entry sharing the segment id (line 87), so the NaN contaminates the whole segment's softmax output — not just the masked entry. This is exactly the scenario the "numerically stable" masking is supposed to guard against, and it's untested (test_masked_entries_zero / test_all_masked_segment_is_zero_no_nan only use masked values smaller than the unmasked max).

Downstream, dpa1.py's _graph_attention_one_layer calls this with mask=pair_mask on raw attention logits for padding pairs, which are not bounded a priori.

🛡️ Proposed fix
     seg_max = segment_max(data_for_max, segment_ids, num_segments)
     # guard -inf (empty / fully-masked segments) so gather doesn't yield inf-inf
     seg_max = xp.where(xp.isinf(seg_max), xp.zeros_like(seg_max), seg_max)
-    shifted = data - xp.take(seg_max, segment_ids, axis=0)
+    # use data_for_max (already -inf on masked entries) so masked entries
+    # exp() to exactly 0 instead of relying on a post-hoc inf*0 multiply
+    shifted = data_for_max - xp.take(seg_max, segment_ids, axis=0)
     ex = xp.exp(shifted)
     if mask is not None:
         ex = ex * xp.astype(mask, ex.dtype)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def segment_softmax(
data: Array,
segment_ids: Array,
num_segments: int,
mask: Array | None = None,
) -> Array:
"""Softmax over entries sharing a segment id, numerically stable.
Mirrors the dense ``np_softmax`` max-subtraction trick with a PER-SEGMENT
max. ``mask`` (bool, per entry) removes masked entries from the softmax
entirely (zero weight AND excluded from the denominator). Empty or
fully-masked segments produce all-zero weights (no NaN).
"""
xp = array_api_compat.array_namespace(data)
if mask is not None:
# keep masked entries out of the per-segment max: send them to -inf
neg = xp.full_like(data, -xp.inf)
data_for_max = xp.where(mask, data, neg)
else:
data_for_max = data
seg_max = segment_max(data_for_max, segment_ids, num_segments)
# guard -inf (empty / fully-masked segments) so gather doesn't yield inf-inf
seg_max = xp.where(xp.isinf(seg_max), xp.zeros_like(seg_max), seg_max)
shifted = data - xp.take(seg_max, segment_ids, axis=0)
ex = xp.exp(shifted)
if mask is not None:
ex = ex * xp.astype(mask, ex.dtype)
denom = segment_sum(ex, segment_ids, num_segments)
denom_e = xp.take(denom, segment_ids, axis=0)
safe = xp.where(denom_e > 0, denom_e, xp.ones_like(denom_e))
return ex / safe
def segment_softmax(
data: Array,
segment_ids: Array,
num_segments: int,
mask: Array | None = None,
) -> Array:
"""Softmax over entries sharing a segment id, numerically stable.
Mirrors the dense ``np_softmax`` max-subtraction trick with a PER-SEGMENT
max. ``mask`` (bool, per entry) removes masked entries from the softmax
entirely (zero weight AND excluded from the denominator). Empty or
fully-masked segments produce all-zero weights (no NaN).
"""
xp = array_api_compat.array_namespace(data)
if mask is not None:
# keep masked entries out of the per-segment max: send them to -inf
neg = xp.full_like(data, -xp.inf)
data_for_max = xp.where(mask, data, neg)
else:
data_for_max = data
seg_max = segment_max(data_for_max, segment_ids, num_segments)
# guard -inf (empty / fully-masked segments) so gather doesn't yield inf-inf
seg_max = xp.where(xp.isinf(seg_max), xp.zeros_like(seg_max), seg_max)
# use data_for_max (already -inf on masked entries) so masked entries
# exp() to exactly 0 instead of relying on a post-hoc inf*0 multiply
shifted = data_for_max - xp.take(seg_max, segment_ids, axis=0)
ex = xp.exp(shifted)
if mask is not None:
ex = ex * xp.astype(mask, ex.dtype)
denom = segment_sum(ex, segment_ids, num_segments)
denom_e = xp.take(denom, segment_ids, axis=0)
safe = xp.where(denom_e > 0, denom_e, xp.ones_like(denom_e))
return ex / safe
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/utils/neighbor_graph/segment.py` around lines 59 - 89, The
masked path in segment_softmax is using raw data for the exponent shift, which
can turn masked large values into inf and then nan after multiplying by the
mask. Update segment_softmax to compute shifted from data_for_max (the same
masked-safe values used for seg_max), and keep the existing empty/fully-masked
guards so exp and denom stay finite. Check the segment_max/segment_sum flow and
the _graph_attention_one_layer caller to ensure masked attention logits cannot
leak into the denominator.

@codecov

codecov Bot commented Jul 2, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 97.76119% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.17%. Comparing base (55d7e79) to head (91784df).

Files with missing lines Patch % Lines
deepmd/dpmodel/descriptor/dpa1.py 97.95% 1 Missing ⚠️
deepmd/dpmodel/utils/neighbor_graph/env.py 80.00% 1 Missing ⚠️
deepmd/dpmodel/utils/neighbor_graph/pairs.py 98.30% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5715      +/-   ##
==========================================
- Coverage   81.26%   81.17%   -0.10%     
==========================================
  Files         988      989       +1     
  Lines      110876   111007     +131     
  Branches     4234     4232       -2     
==========================================
+ Hits        90103    90106       +3     
- Misses      19247    19378     +131     
+ Partials     1526     1523       -3     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Han Wang added 2 commits July 3, 2026 15:50
…SymInts

The compact (carry-all) pair enumeration used nonzero + tensor-repeat with
Python control flow on their data-dependent sizes, so the attention graph
lower failed torch.export with GuardOnDataDependentSymNode. Register those
sizes as unbacked SymInt sizes (new torch-free xp_hint_dynamic_size shim,
no-op for numpy/jax), take the empty-input fast paths only on concrete int
shapes, build iotas via cumsum(ones)-1 (the array_api_compat arange wrapper
branches on the length in Python), and skip the policy-compression nonzero
when no filter applies (include_self and ordered - the attention default).
Eager numpy/torch results are unchanged.
…> 0)

With the compact pair enumeration unbacked-SymInt-traceable, the carry-all
attention graph lower now exports to a graph-form .pt2 unchanged in ABI
(same 5-tensor NeighborGraph schema, dynamic edge axis) and with carry-all
semantics preserved (no sel truncation, unlike the dense-adapter nlist-form
export). Update the stale freeze-gate message (attention is eligible), add a
symbolic-trace merge gate at attn_layer in {0,2}, parametrize the DeepEval
graph .pt2 fixture over attn_layer (both artifacts: dynamic sizes, PBC and
non-PBC, 1e-10 vs the sel-capped dense reference at non-binding sel), and
add a single-atom zero-real-edge runtime test (the R==0 extreme of the
unbacked sizes).
@wanghan-iapcm

Copy link
Copy Markdown
Collaborator Author

Pushed two additional commits that remove the "attention graph-form export deferred" limitation:

  • feat(dpmodel): make compact center_edge_pairs traceable via unbacked SymInts — the carry-all pair enumeration (nonzero + tensor-repeat) now registers its data-dependent sizes via a new torch-free xp_hint_dynamic_size shim (no-op for numpy/jax), takes empty-input fast paths only on concrete int shapes, builds iotas as cumsum(ones)-1 (the array_api_compat arange wrapper branches on the length in Python), and skips the policy-compression nonzero when no filter applies (the attention default). Eager numpy/torch results are unchanged (bit-identical; full eager suites re-run green).
  • feat(pt_expt): graph-form .pt2 export for dpa1 attention (attn_layer > 0) — with the above, lower_kind="graph" now exports attention models unchanged in ABI (same 5-tensor NeighborGraph schema, dynamic edge axis) and with carry-all semantics preserved (no sel truncation, unlike the dense-adapter nlist-form export). Adds a symbolic-trace merge gate (attn_layer 0/2), parametrizes the graph-.pt2 DeepEval fixture over attn_layer (dynamic sizes, PBC/non-PBC, 1e-10 vs the sel-capped dense reference at non-binding sel), a single-atom zero-edge runtime test, and fixes the stale freeze-gate message.

AOTI parity vs eager carry-all measured at ≤5e-18 across system sizes. Benchmark on a Tesla T4 (fp64, diamond C at experimental density, rcut 6 / sel 180, eager): the graph path is flat ~100 µs/atom (O(N)) and still runs at 4096 atoms where the dense path OOMs; with attention the graph is consistently faster at every size that fits.

Known limitations: relies on torch unbacked-SymInt maturity (validated on 2.10; CPU AOTI); jax.jit of the compact path still needs a static realization (PR-F); C++ gtest of an attention graph .pt2 not added (ABI unchanged from the attn=0 artifact).

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
source/tests/pt_expt/model/test_dpa1_graph_lower.py (1)

240-292: 🎯 Functional Correctness | 🔵 Trivial | ⚡ Quick win

Also assert atom_virial parity.

do_atomic_virial=True is passed on both sides but the resulting atom_virial tensor is never compared, leaving a gap in parity coverage specifically for the new attn_layer=2 graph-attention path this test targets.

✅ Proposed addition
         torch.testing.assert_close(
             out["virial"], ref["energy_derv_c_redu"].reshape(out["virial"].shape), **tol
         )
+        torch.testing.assert_close(
+            out["atom_virial"],
+            ref["energy_derv_c"].reshape(out["atom_virial"].shape),
+            **tol,
+        )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/pt_expt/model/test_dpa1_graph_lower.py` around lines 240 - 292,
The symbolic-trace parity test in test_graph_lower_symbolic_trace already
compares energy, force, and virial, but it omits the atom-level virial output
even though do_atomic_virial=True is used. Update the assertions in
test_graph_lower_symbolic_trace to also compare traced versus reference
atom_virial from forward_lower_graph_exportable and forward_common_lower_graph,
using the same tolerance and reshaping pattern as the other tensor checks if
needed.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@source/tests/pt_expt/model/test_dpa1_graph_lower.py`:
- Around line 240-292: The symbolic-trace parity test in
test_graph_lower_symbolic_trace already compares energy, force, and virial, but
it omits the atom-level virial output even though do_atomic_virial=True is used.
Update the assertions in test_graph_lower_symbolic_trace to also compare traced
versus reference atom_virial from forward_lower_graph_exportable and
forward_common_lower_graph, using the same tolerance and reshaping pattern as
the other tensor checks if needed.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 7850ae92-4cd9-427e-b8b6-a37842539f6d

📥 Commits

Reviewing files that changed from the base of the PR and between 91784df and edc2fca.

📒 Files selected for processing (5)
  • deepmd/dpmodel/array_api.py
  • deepmd/dpmodel/utils/neighbor_graph/pairs.py
  • deepmd/pt_expt/entrypoints/main.py
  • source/tests/pt_expt/infer/test_graph_deepeval.py
  • source/tests/pt_expt/model/test_dpa1_graph_lower.py
✅ Files skipped from review due to trivial changes (1)
  • deepmd/pt_expt/entrypoints/main.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/dpmodel/utils/neighbor_graph/pairs.py

The dense call is wrapped in @cast_precision, but the graph route's only
float input (edge_vec) lives inside the NeighborGraph dataclass where the
decorator cannot see it, so non-global-precision models (e.g. float32)
crashed with a double-vs-float matmul on the graph route while the dense
route worked. Cast edge_vec down to the descriptor precision on entry and
the outputs back to the caller's dtype on exit (differentiable, so the
model-level force autograd is unaffected). Add an fp32 graph-vs-dense route
parity test at attn_layer 0 and 2.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant