-
Notifications
You must be signed in to change notification settings - Fork 590
refactor: consolidate backend logic into array_api.py with generic implementations #5202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Consolidates backend-specific array operations into deepmd/dpmodel/array_api.py by introducing unified xp_* helpers and refactoring call sites to use them, reducing scattered JAX/PyTorch conditionals across the dpmodel code.
Changes:
- Added
xp_sigmoid,xp_setitem_at, and a genericxp_scatter_sumimplementation indeepmd/dpmodel/array_api.py(keeping PyTorch fast paths). - Refactored dpmodel logic to call the centralized helpers (network sigmoid + masked assignment; transform output scatter-sum).
- Removed the now-unneeded JAX-only
scatter_sumimplementation and expanded consistency tests (including array-api-strict coverage).
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
deepmd/dpmodel/array_api.py |
Adds centralized backend helpers (xp_sigmoid, xp_setitem_at) and replaces JAX-only scatter_sum with a generic xp_scatter_sum (PyTorch optimized path retained). |
deepmd/dpmodel/utils/network.py |
Replaces inline JAX conditionals with calls to xp_sigmoid and xp_setitem_at. |
deepmd/dpmodel/model/transform_output.py |
Switches from a JAX-only scatter_sum path to the unified xp_scatter_sum. |
deepmd/jax/common.py |
Removes the old JAX-only scatter_sum helper. |
source/tests/consistent/test_array_api.py |
Adds consistency tests for new helpers and extends array-api-strict test coverage. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #5202 +/- ##
=======================================
Coverage 81.96% 81.97%
=======================================
Files 714 714
Lines 73508 73516 +8
Branches 3615 3616 +1
=======================================
+ Hits 60252 60265 +13
+ Misses 12092 12090 -2
+ Partials 1164 1161 -3 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Refactor backend logic into array_api.py
Analysis Complete
Implementation Complete
xp_sigmoidfunction to array_api.py for backend-specific sigmoid handlingxp_setitem_atfunction to array_api.py for backend-specific array assignmentxp_scatter_sumusing array_api operationsxp_sigmoidfunctionxp_setitem_atfunctionxp_scatter_sumfrom array_api.pyChanges Summary
deepmd/dpmodel/array_api.py - Centralized backend implementations:
torch.sigmoid()and JAX-specificjax.nn.sigmoid()implementations with generic fallback.at[].set()syntax and PyTorch'sclone()for non-mutating behavior, while NumPy uses in-place assignmentxp_take_along_axisandxp_add_athelper functions (merged JAX implementation logic), with PyTorch optimization retaineddeepmd/dpmodel/utils/network.py - Refactored to use centralized functions:
sigmoid_tfunction with call toxp_sigmoidxp_setitem_atdeepmd/dpmodel/model/transform_output.py - Simplified scatter_sum usage:
deepmd.jax.common.scatter_sumxp_scatter_sumfrom array_api.py consistentlydeepmd/jax/common.py - Removed unused code:
scatter_sumfunction (replaced by generic implementation in array_api.py)source/tests/consistent/test_array_api.py - Complete test coverage across all backends:
TestXpSigmoidConsistentTestXpSetitemAtConsistentTestXpScatterSumConsistentTestXpBincountConsistentTestXpBincountWithWeightsConsistentTestXpBincountWithMinlengthConsistentTestXpSetitemAtConsistent(consistent with other tests)Test Results
Benefits
if is_jax_array,if is_torch_array) centralized in array_api.pyfrom deepmd.jax,from torch) outside of array_api.pytorch.sigmoid,torch.scatter_add)Original prompt
💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more Copilot coding agent tips in the docs.