
[Feature Request] Refactor JAX backend logic in deepmd/dpmodel into array_api.py #5201

@njzjz

Description

Summary

Refactor all JAX-specific backend code currently spread across files in deepmd/dpmodel/ (other than array_api.py) by consolidating it into deepmd/dpmodel/array_api.py. This will reduce code duplication, improve maintainability, and keep the NumPy, JAX, and PyTorch backend logic in one place.

Detailed Description

Many files in deepmd/dpmodel/ contain conditional blocks such as if is_jax_array(...) to support the JAX backend. To streamline future backend work and reduce maintenance overhead, all such JAX-specific logic should be moved into array_api.py, so that the backend wrappers are maintained and updated in a single place.
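As a rough sketch of what one centralized wrapper might look like, consider the JAX-specific array assignment in network.py. JAX arrays are immutable, so x[idx] = value has to be written as x.at[idx].set(value), which returns a new array, while NumPy and PyTorch arrays can be assigned in place. The helper name xp_setitem is hypothetical, and the local is_jax_array below is a minimal stand-in for array_api_compat.is_jax_array so the example runs without JAX installed.

```python
import sys
from typing import Any

import numpy as np


def is_jax_array(x: Any) -> bool:
    # Minimal stand-in for array_api_compat.is_jax_array: only look for JAX
    # if it has already been imported, so this works without JAX installed.
    if "jax" not in sys.modules:
        return False
    import jax

    return isinstance(x, jax.Array)


def xp_setitem(x: Any, index: Any, value: Any) -> Any:
    """Hypothetical backend-agnostic version of x[index] = value."""
    if is_jax_array(x):
        # JAX arrays are immutable: .at[...].set(...) returns an updated copy.
        return x.at[index].set(value)
    # NumPy (and PyTorch) arrays can be updated in place.
    x[index] = value
    # Return the array so callers can treat both branches the same way.
    return x


a = xp_setitem(np.zeros(4), slice(1, 3), 1.0)
print(a)  # [0. 1. 1. 0.]
```

Callers always use the return value, so the same call site works whether the backend mutates in place or returns a copy.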

Examples of JAX-specific code found:

  • deepmd/dpmodel/utils/network.py, ~line 1149: JAX-specific array assignment logic
  • deepmd/dpmodel/model/transform_output.py, ~line 217: JAX scatter_sum usage
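The scatter_sum case could be wrapped in the same way. This sketch assumes a simple segment-sum semantics (out[index[i]] += src[i] along the first axis), which may not match the exact signature used in transform_output.py; the function name scatter_sum_axis0 is hypothetical. In JAX, out.at[index].add(src) accumulates repeated indices; the NumPy equivalent is the unbuffered np.add.at. A PyTorch branch (for example via torch.Tensor.index_add) is omitted for brevity.

```python
import sys
from typing import Any

import numpy as np


def is_jax_array(x: Any) -> bool:
    # Minimal stand-in for array_api_compat.is_jax_array.
    if "jax" not in sys.modules:
        return False
    import jax

    return isinstance(x, jax.Array)


def scatter_sum_axis0(src: Any, index: Any, num_segments: int) -> Any:
    """Hypothetical helper: out[index[i]] += src[i] along the first axis."""
    out_shape = (num_segments, *src.shape[1:])
    if is_jax_array(src):
        import jax.numpy as jnp

        out = jnp.zeros(out_shape, dtype=src.dtype)
        # Functional scatter-add; repeated indices are accumulated.
        return out.at[index].add(src)
    out = np.zeros(out_shape, dtype=src.dtype)
    # Unbuffered in-place add: unlike out[index] += src, repeated
    # indices are all accumulated.
    np.add.at(out, index, src)
    return out


src = np.array([1.0, 2.0, 3.0, 4.0])
index = np.array([0, 2, 0, 1])
print(scatter_sum_axis0(src, index, 3))  # [4. 4. 2.]
```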

After the refactoring, other files should invoke unified backend functions from array_api.py instead of handling JAX conditionals locally.

(These search results are limited; see GitHub code search for the full context.)
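Because the search results are incomplete, a small local script could list every remaining JAX-specific line outside array_api.py, both to scope the refactor and to confirm afterwards that no local JAX conditionals remain. The regular expression is a heuristic assumption (it can also flag comments or unrelated .at[ uses), and the function name find_local_jax_logic is hypothetical.

```python
import re
from pathlib import Path

# Heuristic patterns for JAX-specific code; adjust as needed.
JAX_PATTERN = re.compile(r"\bis_jax_array\b|\bjax\.|\bjnp\.|\.at\[")


def find_local_jax_logic(package_root, allowed=("array_api.py",)):
    """Return (path, line_number, line) for JAX-specific lines outside allowed files.

    allowed holds paths relative to package_root that may contain JAX code.
    """
    root = Path(package_root)
    hits = []
    for path in sorted(root.rglob("*.py")):
        # Skip the module that is supposed to own the backend logic.
        if path.relative_to(root).as_posix() in allowed:
            continue
        text = path.read_text(encoding="utf-8")
        for lineno, line in enumerate(text.splitlines(), start=1):
            if JAX_PATTERN.search(line):
                hits.append((path.as_posix(), lineno, line.strip()))
    return hits
```

Run from the repository root, find_local_jax_logic("deepmd/dpmodel") returns (path, line number, line) tuples that can be printed for review or checked for emptiness in CI once the refactor is done.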

Further Information, Files, and Links

No response
