Summary
Refactor the JAX-specific backend code that is currently spread across files in `deepmd/dpmodel/` (other than `array_api.py`) by consolidating it into `deepmd/dpmodel/array_api.py`. This will reduce code duplication, improve maintainability, and keep the NumPy, JAX, and PyTorch backend logic in one place.
Detailed Description
Many files in `deepmd/dpmodel/` contain conditional branches such as `if is_jax_array(...)` to support the JAX backend. To streamline future backend work and reduce maintenance overhead, all such JAX-specific logic should be moved into `array_api.py`, so that the backend wrappers are maintained and updated in a single place.
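A sketch of how one such branch might be absorbed into a shared helper, assuming the branch guards an indexed assignment. `xp_set_at` and `_is_jax_array` are hypothetical names used only for illustration. Because JAX arrays are immutable, the helper gives every backend the same out-of-place semantics: JAX uses the functional `.at[index].set(value)` update, and other backends copy the array before assigning.

```python
import sys

import numpy as np


def _is_jax_array(x) -> bool:
    """Hypothetical helper: True if ``x`` is a JAX array, without importing JAX."""
    jax = sys.modules.get("jax")
    return jax is not None and isinstance(x, jax.Array)


def xp_set_at(x, index, value):
    """Hypothetical helper: return a copy of ``x`` with ``x[index] = value`` applied."""
    if _is_jax_array(x):
        # JAX arrays cannot be mutated; .at[...].set returns a new array.
        return x.at[index].set(value)
    # NumPy arrays expose .copy(); torch tensors expose .clone().
    out = x.clone() if hasattr(x, "clone") else x.copy()
    out[index] = value
    return out


a = np.zeros(4)
b = xp_set_at(a, 1, 5.0)
# a is left unchanged; b holds the update.
```

Returning a new array costs a copy on NumPy and PyTorch; in exchange, call sites no longer need their own `if is_jax_array` branch and behave the same on every backend.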
Examples of JAX-specific code found:
- `deepmd/dpmodel/utils/network.py`, ~line 1149: JAX-specific array assignment logic
- `deepmd/dpmodel/model/transform_output.py`, ~line 217: JAX `scatter_sum` usage
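For the scatter-sum case, here is a minimal sketch of a centralized helper, assuming a one-dimensional index that targets the first axis. `xp_scatter_sum_1d` is a hypothetical name, not the signature used in `transform_output.py`. NumPy needs `np.add.at` because `out[index] += src` does not accumulate repeated indices, while JAX's `.at[index].add(src)` does accumulate them. A PyTorch branch is left out of this sketch.

```python
import sys

import numpy as np


def _is_jax_array(x) -> bool:
    """Hypothetical helper: True if ``x`` is a JAX array, without importing JAX."""
    jax = sys.modules.get("jax")
    return jax is not None and isinstance(x, jax.Array)


def xp_scatter_sum_1d(index, src, num_segments):
    """Hypothetical helper: out[index[i]] += src[i] along the first axis.

    Repeated entries in ``index`` accumulate into the same output row.
    """
    shape = (num_segments,) + tuple(src.shape[1:])
    if _is_jax_array(src):
        import jax.numpy as jnp  # imported lazily: only reached for JAX inputs

        return jnp.zeros(shape, dtype=src.dtype).at[index].add(src)
    out = np.zeros(shape, dtype=src.dtype)
    # Unbuffered in-place add, so duplicate indices are all counted.
    np.add.at(out, index, src)
    return out
```

For example, `xp_scatter_sum_1d(np.array([0, 2, 0]), np.array([1.0, 2.0, 3.0]), 3)` returns `array([4., 0., 2.])`.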
After the refactoring, other files should call the unified backend functions in `array_api.py` instead of handling JAX conditionals locally.
(These search results are limited; see GitHub search for the full context.)