Skip to content

Conversation

Copy link
Contributor

Copilot AI commented Feb 7, 2026

Refactor backend logic into array_api.py

Analysis Complete

  • Identify all JAX-specific code in deepmd/dpmodel/
  • Review existing array_api.py structure
  • Identify test infrastructure

Implementation Complete

  • Add xp_sigmoid function to array_api.py for backend-specific sigmoid handling
  • Add xp_setitem_at function to array_api.py for backend-specific array assignment
  • Implement generic xp_scatter_sum using array_api operations
  • Update network.py to use new xp_sigmoid function
  • Update network.py to use new xp_setitem_at function
  • Update transform_output.py to use xp_scatter_sum from array_api.py
  • Add comprehensive tests for all functions
  • Add array-api-strict tests for complete backend coverage
  • Remove unused JAX-specific scatter_sum implementation
  • Make xp_setitem_at non-mutating for PyTorch (consistent with other xp_* helpers)
  • Run tests to validate changes (20 passed, 7 skipped)
  • Run linters to ensure code quality (all checks passed)

Changes Summary

  1. deepmd/dpmodel/array_api.py - Centralized backend implementations:

    • xp_sigmoid: Added PyTorch-specific torch.sigmoid() and JAX-specific jax.nn.sigmoid() implementations with generic fallback
    • xp_setitem_at: Handles JAX's functional .at[].set() syntax and PyTorch's clone() for non-mutating behavior, while NumPy uses in-place assignment
    • xp_scatter_sum: Implemented generic array_api version using xp_take_along_axis and xp_add_at helper functions (merged JAX implementation logic), with PyTorch optimization retained
    • Now supports NumPy, JAX, PyTorch, and array-api-strict backends
    • All xp_ helpers are now consistently non-mutating for PyTorch and JAX*
  2. deepmd/dpmodel/utils/network.py - Refactored to use centralized functions:

    • Replaced direct JAX conditional in sigmoid_t function with call to xp_sigmoid
    • Replaced JAX conditional for array assignment with call to xp_setitem_at
  3. deepmd/dpmodel/model/transform_output.py - Simplified scatter_sum usage:

    • Removed JAX conditional and direct import of deepmd.jax.common.scatter_sum
    • Now uses xp_scatter_sum from array_api.py consistently
  4. deepmd/jax/common.py - Removed unused code:

    • Removed scatter_sum function (replaced by generic implementation in array_api.py)
    • Function was not imported or used anywhere in the codebase
  5. source/tests/consistent/test_array_api.py - Complete test coverage across all backends:

    • Added array-api-strict tests for TestXpSigmoidConsistent
    • Added array-api-strict tests for TestXpSetitemAtConsistent
    • Added array-api-strict tests for TestXpScatterSumConsistent
    • Added array-api-strict tests for TestXpBincountConsistent
    • Added array-api-strict tests for TestXpBincountWithWeightsConsistent
    • Added array-api-strict tests for TestXpBincountWithMinlengthConsistent
    • Added non-mutating verification for PyTorch in TestXpSetitemAtConsistent (consistent with other tests)
    • All test classes now have consistent coverage across NumPy, PyTorch, JAX, and array-api-strict

Test Results

  • All 27 array_api tests pass (20 passed, 7 skipped for JAX not installed)
  • Complete array-api-strict test coverage for all xp_* functions
  • All PyTorch tests verify non-mutating behavior
  • All network tests pass
  • Ruff linting: all checks passed

Benefits

  • ✅ All backend-specific conditionals (if is_jax_array, if is_torch_array) centralized in array_api.py
  • ✅ No direct backend imports (from deepmd.jax, from torch) outside of array_api.py
  • ✅ Generic array_api implementations support all backends (NumPy, JAX, PyTorch, array-api-strict)
  • ✅ Backend-specific optimizations retained where beneficial (e.g., torch.sigmoid, torch.scatter_add)
  • All xp_ helpers are consistently non-mutating for PyTorch and JAX*, preventing autograd issues
  • ✅ Removed duplicate/unused code from JAX backend
  • ✅ Improved maintainability - future backend changes only need to update array_api.py
  • ✅ Consistent API for backend-specific operations across the codebase
Original prompt

This section details on the original issue you should resolve

<issue_title>[Feature Request] Refactor JAX backend logic in deepmd/dpmodel into array_api.py</issue_title>
<issue_description>### Summary

Refactor all JAX-specific backend code currently spread across files in deepmd/dpmodel/ (except for array_api.py) by consolidating these implementations into deepmd/dpmodel/array_api.py. This will reduce code duplication, improve maintainability, and centralize NUMPY, JAX, and TORCH backend logic in one place.

Detailed Description

Many files in deepmd/dpmodel/ contain implementation blocks like if is_jax_array for JAX backend support. To streamline future backend improvements and reduce maintenance overhead, all such JAX-specific logic should be migrated into array_api.py. This centralization will make backend wrappers easier to maintain and update.

Example JAX references found:

  • deepmd/dpmodel/utils/network.py, ~line 1149: JAX-specific array assignment logic
  • deepmd/dpmodel/model/transform_output.py, ~line 217: JAX scatter_sum usage

After the refactoring, other files should invoke unified backend functions from array_api.py instead of handling JAX conditionals locally.

(Search results are limited, please see GitHub search for full context.)

Further Information, Files, and Links

No response</issue_description>

Comments on the Issue (you are @copilot in this section)


💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more Copilot coding agent tips in the docs.

Copilot AI and others added 3 commits February 7, 2026 21:45
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Copilot AI changed the title [WIP] Refactor JAX backend logic in array_api.py refactor: consolidate JAX backend logic into array_api.py Feb 7, 2026
Copilot AI requested a review from njzjz February 7, 2026 21:52
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Copilot AI changed the title refactor: consolidate JAX backend logic into array_api.py refactor: consolidate backend logic into array_api.py with generic implementations Feb 8, 2026
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Copilot AI requested a review from njzjz February 8, 2026 03:23
@njzjz njzjz marked this pull request as ready for review February 8, 2026 03:24
Copilot AI review requested due to automatic review settings February 8, 2026 03:24
@dosubot dosubot bot added the enhancement label Feb 8, 2026
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Consolidates backend-specific array operations into deepmd/dpmodel/array_api.py by introducing unified xp_* helpers and refactoring call sites to use them, reducing scattered JAX/PyTorch conditionals across the dpmodel code.

Changes:

  • Added xp_sigmoid, xp_setitem_at, and a generic xp_scatter_sum implementation in deepmd/dpmodel/array_api.py (keeping PyTorch fast paths).
  • Refactored dpmodel logic to call the centralized helpers (network sigmoid + masked assignment; transform output scatter-sum).
  • Removed the now-unneeded JAX-only scatter_sum implementation and expanded consistency tests (including array-api-strict coverage).

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
deepmd/dpmodel/array_api.py Adds centralized backend helpers (xp_sigmoid, xp_setitem_at) and replaces JAX-only scatter_sum with a generic xp_scatter_sum (PyTorch optimized path retained).
deepmd/dpmodel/utils/network.py Replaces inline JAX conditionals with calls to xp_sigmoid and xp_setitem_at.
deepmd/dpmodel/model/transform_output.py Switches from a JAX-only scatter_sum path to the unified xp_scatter_sum.
deepmd/jax/common.py Removes the old JAX-only scatter_sum helper.
source/tests/consistent/test_array_api.py Adds consistency tests for new helpers and extends array-api-strict test coverage.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
@codecov
Copy link

codecov bot commented Feb 8, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 81.97%. Comparing base (ef0be63) to head (e7bb076).

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #5202   +/-   ##
=======================================
  Coverage   81.96%   81.97%           
=======================================
  Files         714      714           
  Lines       73508    73516    +8     
  Branches     3615     3616    +1     
=======================================
+ Hits        60252    60265   +13     
+ Misses      12092    12090    -2     
+ Partials     1164     1161    -3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@njzjz njzjz requested review from iProzd and wanghan-iapcm February 8, 2026 09:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature Request] Refactor JAX backend logic in deepmd/dpmodel into array_api.py

2 participants