refact (pt_expt): provide infrastructure for converting dpmodel classes to PyTorch modules. #5204
wanghan-iapcm wants to merge 34 commits into deepmodeling:master
Conversation
Summary of Changes

Hello @wanghan-iapcm, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request establishes a foundational framework for converting dpmodel classes to PyTorch modules.
Code Review
This pull request introduces a significant and well-designed infrastructure for a new "PyTorch Exportable" (pt-expt) backend. The core idea of automatically wrapping dpmodel classes into torch.nn.Modules by detecting attribute types is clever and should improve maintainability. The changes are extensive, touching not only the new backend implementation but also core dpmodel components for better device consistency and the testing framework for validation.
My review has identified two important bug fixes: one related to environment variable handling for parallelism and another related to thread assignment in the PyTorch environment. I've also pointed out an opportunity to reduce code duplication between the pt and the new pt_expt environment setup modules, which would improve long-term maintainability.
Overall, this is a high-quality contribution that lays a solid foundation for making dpmodels compatible with torch.export.
📝 Walkthrough

Introduces a registry-driven dpmodel→PyTorch converter, device-aware array conversion and unified setattr flow, PyTorch wrapper descriptor classes (DescrptSeA, DescrptSeR), exclude-mask wrappers, and PyTorch-compatible network and NetworkCollection wrappers.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant Registry as Registry System
    participant Converter
    participant DPModel as DPModel Object
    participant Wrapper as PT_EXPT Wrapper
    participant TorchModule as torch.nn.Module
    User->>Registry: register_dpmodel_mapping(dpmodel_cls, converter)
    Registry->>Registry: Store mapping
    User->>Wrapper: try_convert_module(dpmodel_obj)
    Wrapper->>Registry: Lookup converter for dpmodel_obj.__class__
    alt Mapping found
        Registry-->>Wrapper: converter callable
        Wrapper->>Converter: converter(dpmodel_obj)
        Converter->>DPModel: Read config/state (ntypes, arrays, etc.)
        Converter->>Wrapper: Instantiate wrapper with converted args
        Wrapper->>TorchModule: Initialize torch.nn.Module base
        Wrapper-->>User: Return pt_expt wrapper instance
    else No mapping
        Registry-->>Wrapper: None
        Wrapper-->>User: Return None
    end
    User->>Wrapper: setattr(wrapper, name, value)
    Wrapper->>Wrapper: dpmodel_setattr(name, value)
    alt value is numpy array
        Wrapper->>Wrapper: to_torch_array(value) -> Tensor
        Wrapper->>TorchModule: register_buffer / set parameter
    else value is dpmodel object
        Wrapper->>Registry: try_convert_module(value)
        Registry-->>Wrapper: converted wrapper or None
        Wrapper->>TorchModule: set attribute to converted module or original
    else other
        Wrapper->>TorchModule: normal setattr
    end
```
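The registry flow in the diagram can be sketched in plain Python. `register_dpmodel_mapping` and `try_convert_module` are names from this PR, but the dict-based registry and the stand-in classes below are illustrative assumptions, not the actual implementation:

```python
from typing import Any, Callable, Optional

# Hypothetical minimal registry mirroring the PR's design:
# dpmodel class -> converter producing a pt_expt wrapper.
_DPMODEL_TO_PT_EXPT: dict[type, Callable[[Any], Any]] = {}


def register_dpmodel_mapping(dpmodel_cls: type, converter: Callable[[Any], Any]) -> None:
    """Register a converter for a dpmodel class."""
    _DPMODEL_TO_PT_EXPT[dpmodel_cls] = converter


def try_convert_module(obj: Any) -> Optional[Any]:
    """Return the converted wrapper, or None when no mapping exists."""
    converter = _DPMODEL_TO_PT_EXPT.get(type(obj))
    return None if converter is None else converter(obj)


# Usage with stand-in classes (no torch dependency in this sketch):
class DPDescriptor:  # stands in for a dpmodel class
    ntypes = 2


class WrappedDescriptor:  # stands in for a pt_expt wrapper module
    def __init__(self, dp: DPDescriptor) -> None:
        self.ntypes = dp.ntypes


register_dpmodel_mapping(DPDescriptor, WrappedDescriptor)
assert isinstance(try_convert_module(DPDescriptor()), WrappedDescriptor)
assert try_convert_module(object()) is None  # unmapped classes fall through
```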
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)

✅ Passed checks (2 passed)
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@deepmd/backend/pt_expt.py`:
- Around lines 37-42: The declared `features` flag in `pt_expt.py` advertises deep_eval, neighbor_stat, and IO, but the corresponding properties `deep_eval`, `neighbor_stat`, `serialize_hook`, and `deserialize_hook` raise `NotImplementedError`, so callers that use `Backend.get_backends_by_feature()` (e.g. `entrypoints/neighbor_stat.py` and `entrypoints/main.py`) crash. To fix, either restrict the ClassVar `features` tuple on the Backend subclass to `Backend.Feature.ENTRY_POINT` until those hooks are implemented, or implement working `deep_eval`, `neighbor_stat`, `serialize_hook`, and `deserialize_hook` methods that return the expected values. Reference the symbols `features`, `deep_eval`, `neighbor_stat`, `serialize_hook`, `deserialize_hook`, and `Backend.get_backends_by_feature()` when making the change.
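The crash mode is reproducible without deepmd. The sketch below uses a hypothetical stand-in `Backend` with an `enum.Flag` feature set; all class and flag names are simplified assumptions, not deepmd's real API:

```python
from enum import Flag, auto
from typing import ClassVar


class Feature(Flag):
    ENTRY_POINT = auto()
    DEEP_EVAL = auto()
    NEIGHBOR_STAT = auto()
    IO = auto()


class Backend:
    features: ClassVar[Feature] = Feature(0)
    _subclasses: ClassVar[list] = []

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        Backend._subclasses.append(cls)

    @classmethod
    def get_backends_by_feature(cls, feature: "Feature") -> list:
        # Selection is by the *declared* flags: a backend that advertises
        # a feature but raises NotImplementedError in the hook would be
        # returned here and crash the caller later.
        return [b for b in cls._subclasses if feature in b.features]


class PTExptBackend(Backend):
    # Advertise only what is actually implemented.
    features = Feature.ENTRY_POINT


assert PTExptBackend in Backend.get_backends_by_feature(Feature.ENTRY_POINT)
assert PTExptBackend not in Backend.get_backends_by_feature(Feature.DEEP_EVAL)
```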
In `@source/tests/consistent/descriptor/test_se_e2_a.py`:
- Around line 546-566: The eval_pt_expt method references env.DEVICE (in
eval_pt_expt) but env is only imported under the INSTALLED_PT guard; import the
correct env used by pt_expt to avoid NameError when INSTALLED_PT_EXPT is true
but INSTALLED_PT is false by adding an import from deepmd.pt_expt.utils.env (or
by importing env inside eval_pt_expt) so eval_pt_expt can always access
env.DEVICE; locate the eval_pt_expt function and update imports or add a local
import to ensure env is defined when this test runs.
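A minimal, self-contained illustration of the guarded-import pitfall and the local-import fix; the `SimpleNamespace` objects are stand-ins for the real `deepmd.pt.utils.env` / `deepmd.pt_expt.utils.env` modules:

```python
from types import SimpleNamespace

INSTALLED_PT = False       # simulate: pt backend absent
INSTALLED_PT_EXPT = True   # pt_expt backend present

if INSTALLED_PT:
    # stand-in for: from deepmd.pt.utils import env
    env = SimpleNamespace(DEVICE="cuda:0")


def eval_pt_expt_broken():
    # NameError when INSTALLED_PT is False: `env` was never bound.
    return env.DEVICE


def eval_pt_expt_fixed():
    # A local import (here a stand-in binding) ensures `env` is always
    # defined on this code path, mirroring an import of
    # deepmd.pt_expt.utils.env inside eval_pt_expt.
    env = SimpleNamespace(DEVICE="cpu")
    return env.DEVICE


try:
    eval_pt_expt_broken()
    raised = False
except NameError:
    raised = True
assert raised
assert eval_pt_expt_fixed() == "cpu"
```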
🧹 Nitpick comments (6)
deepmd/pt_expt/utils/env.py (2)
1-127: Near-complete duplication of `deepmd/pt/utils/env.py`.

This file is virtually identical to `deepmd/pt/utils/env.py`: same constants, same device logic, same precision dicts, same thread guards. The only difference is the comment on line 97. When a constant or guard needs to change, both files must be updated in lockstep. Consider extracting the shared logic into a common helper (e.g., `deepmd/pt_common/env.py` or a shared setup function) that both the `pt` and `pt_expt` env modules call. Each backend module can then layer on its own specifics.
19-20: `import torch` is separated from other top-level imports.

The `torch` import sits after the `deepmd.env` block rather than grouped with `numpy` at line 7. This appears to be an oversight from code movement.

Suggested fix:

```diff
 import numpy as np
+import torch

 from deepmd.common import (
     VALID_PRECISION,
@@ -17,7 +18,6 @@
 )

 log = logging.getLogger(__name__)
-import torch
```

deepmd/pt_expt/utils/network.py (1)
27-37: `TorchArrayParam`: clean custom Parameter subclass.

The `__array__` protocol implementation correctly detaches, moves to CPU, and converts to numpy, which enables seamless interop with dpmodel code that expects numpy arrays. One minor note from static analysis: the `# noqa: PYI034` on line 28 is flagged as unused by Ruff and can be removed.

Suggested fix:

```diff
-    def __new__(  # noqa: PYI034
+    def __new__(
         cls, data: Any = None, requires_grad: bool = True
     ) -> "TorchArrayParam":
```

deepmd/pt_expt/common.py (2)
45-82: `register_dpmodel_mapping` silently overwrites existing entries.

If a dpmodel class is registered twice (e.g., due to module re-import or conflicting registrations), the second call silently replaces the first converter with no warning. This could cause hard-to-debug issues during development. Consider adding a warning or raising on duplicate registration.

Suggested guard:

```diff
     """
+    if dpmodel_cls in _DPMODEL_TO_PT_EXPT:
+        import logging
+        logging.getLogger(__name__).warning(
+            f"Overwriting existing pt_expt mapping for {dpmodel_cls.__name__}"
+        )
     _DPMODEL_TO_PT_EXPT[dpmodel_cls] = converter
```
237-278: `to_torch_array`: `torch.as_tensor` shares memory with numpy arrays; document this intent or use `torch.tensor` if a copy is desired.

`torch.as_tensor` on line 278 shares memory with the source numpy array when possible (same dtype, CPU). If the caller later mutates the original array, the tensor changes too. The dpmodel convention (replace, don't mutate) makes this safe in practice, but it is a latent footgun for users of this utility who may not share that assumption. The docstring could note this behavior.

source/tests/pt/test_env_threads.py (1)
10-34: Near-identical test exists in `source/tests/pt_expt/utils/test_env.py`.

Both tests share the same monkeypatch/reload/assert pattern, differing only in the target env module. Consider extracting a shared helper parameterized by the module path to reduce duplication, though this is minor given the test's brevity.
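The memory-sharing behavior flagged for `to_torch_array` above is easy to demonstrate; this is a standalone sketch, not code from the PR:

```python
import numpy as np
import torch

src = np.zeros(3, dtype=np.float64)
shared = torch.as_tensor(src)  # no copy: same buffer (CPU, matching dtype)
copied = torch.tensor(src)     # always copies

src[0] = 7.0                   # mutate the source array afterwards
assert shared[0].item() == 7.0  # the as_tensor view sees the mutation
assert copied[0].item() == 0.0  # the tensor() copy does not
```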
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3452a2a8c0
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff            @@
##           master    #5204    +/-   ##
========================================
  Coverage   81.99%   82.00%
========================================
  Files         724      724
  Lines       73807    73801        -6
  Branches     3616     3615        -1
========================================
+ Hits        60519    60520        +1
+ Misses      12124    12118        -6
+ Partials     1164     1163        -1
```

☔ View full report in Codecov by Sentry.
Consider after the merge of #5194.
This PR provides infrastructure for automatically wrapping dpmodel classes (array_api_compat-based) as PyTorch modules. The key insight is to detect attributes by their value type rather than by hard-coded names.
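That type-driven dispatch can be sketched without torch. The real wrapper converts numpy arrays to tensors and registers them on the `torch.nn.Module`; the class below is a simplified, hypothetical analogue:

```python
import numpy as np


class AutoWrapping:
    """Simplified analogue of the PR's setattr flow: dispatch on the
    value's type instead of maintaining a hard-coded attribute list."""

    def __setattr__(self, name, value):
        if isinstance(value, np.ndarray):
            # Real code would call to_torch_array(value) and then
            # register_buffer / set a Parameter on the module.
            object.__setattr__(self, name, ("buffer", value))
        else:
            object.__setattr__(self, name, value)


m = AutoWrapping()
m.ntypes = 2               # plain value: stored as-is
m.mean = np.zeros((2, 4))  # array: routed through the conversion path
assert m.ntypes == 2
assert m.mean[0] == "buffer"
```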