refact(pt_expt): add decorator to simplify the module #5213
base: master
Conversation
Summary of Changes

Hello @njzjz, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request refactors the `pt_expt` modules around a new `@torch_module` decorator.
Code Review
This pull request introduces a @torch_module decorator to reduce boilerplate code for wrapping dpmodel classes as torch.nn.Modules. The refactoring is applied to several descriptor and utility classes, which significantly simplifies the codebase. My review focuses on a potential runtime issue in the new decorator and some inconsistencies in its application. Specifically, the decorator unconditionally adds a __call__ method which can cause errors for non-callable modules, and some refactored classes still redundantly inherit from torch.nn.Module.
```python
def torch_module(
    module: type[NativeOP],
) -> type[torch.nn.Module]:
    """Convert a NativeOP to a torch.nn.Module.

    Parameters
    ----------
    module : NativeOP
        The NativeOP to convert.

    Returns
    -------
    torch.nn.Module
        The torch.nn.Module.

    Examples
    --------
    >>> @torch_module
    ... class MyModule(NativeOP):
    ...     pass
    """

    @wraps(module, updated=())
    class TorchModule(module, torch.nn.Module):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            torch.nn.Module.__init__(self)
            module.__init__(self, *args, **kwargs)

        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
            return torch.nn.Module.__call__(self, *args, **kwargs)

        def __setattr__(self, name: str, value: Any) -> None:
            handled, value = dpmodel_setattr(self, name, value)
            if not handled:
                super().__setattr__(name, value)

    return TorchModule
```
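For orientation, here is a minimal, self-contained sketch of how this style of class decorator behaves. It omits `dpmodel_setattr` and uses a hypothetical `DummyOP` class and a simplified `torch_module_sketch` helper, so it is an illustration of the pattern rather than the PR's exact implementation.

```python
# Minimal sketch of the decorator pattern (hypothetical names, simplified logic).
from functools import wraps
from typing import Any

import torch


def torch_module_sketch(module: type) -> type[torch.nn.Module]:
    """Wrap a plain Python class so its instances are also torch.nn.Module."""

    @wraps(module, updated=())  # keep the wrapped class's __name__/__doc__
    class TorchModule(module, torch.nn.Module):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            torch.nn.Module.__init__(self)  # set up _parameters/_buffers/_modules first
            module.__init__(self, *args, **kwargs)

    return TorchModule


@torch_module_sketch
class DummyOP:
    def __init__(self, scale: float) -> None:
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.scale * x


m = DummyOP(2.0)
assert isinstance(m, torch.nn.Module)
assert m.__class__.__name__ == "DummyOP"  # wraps() preserved the name
print(m(torch.ones(3)))  # nn.Module.__call__ routes to forward(): tensor([2., 2., 2.])
```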
The torch_module decorator is a great idea to reduce boilerplate. However, it has a couple of issues:

- The type hint `module: type[NativeOP]` is too restrictive, as this decorator is also used on classes that do not inherit from `NativeOP` (e.g., `AtomExcludeMask`, `NetworkCollection`). It should be `module: type`.
- The `__call__` method is defined unconditionally. This will cause a `NotImplementedError` at runtime if an instance of a decorated class without a `forward` method is called. This affects classes like `AtomExcludeMask` and `NetworkCollection`, which are not meant to be callable modules.

To fix this, `__call__` should only be defined if the decorated class has a `forward` method.
```python
def torch_module(
    module: type,
) -> type[torch.nn.Module]:
    """Convert a NativeOP to a torch.nn.Module.

    Parameters
    ----------
    module : type
        The class to convert.

    Returns
    -------
    torch.nn.Module
        The torch.nn.Module.

    Examples
    --------
    >>> @torch_module
    ... class MyModule(NativeOP):
    ...     pass
    """

    @wraps(module, updated=())
    class TorchModule(module, torch.nn.Module):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            torch.nn.Module.__init__(self)
            module.__init__(self, *args, **kwargs)

        if hasattr(module, "forward"):

            def __call__(self, *args: Any, **kwargs: Any) -> Any:
                # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
                return torch.nn.Module.__call__(self, *args, **kwargs)

        def __setattr__(self, name: str, value: Any) -> None:
            handled, value = dpmodel_setattr(self, name, value)
            if not handled:
                super().__setattr__(name, value)

    return TorchModule
```
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR introduces a @torch_module decorator to reduce boilerplate in pt_expt wrappers that adapt dpmodel classes into torch.nn.Modules for export/tracing.
Changes:
- Added the `torch_module` decorator in `deepmd/pt_expt/common.py` to centralize `torch.nn.Module` initialization, `__call__`, and `__setattr__` handling.
- Refactored multiple `pt_expt` wrappers (networks, exclude masks, descriptors) to use `@torch_module` instead of repeating wrapper logic.
- Removed now-redundant imports and wrapper boilerplate in affected modules.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
Summary per file:
| File | Description |
|---|---|
| deepmd/pt_expt/utils/network.py | Applies @torch_module to network wrappers to remove duplicated torch.nn.Module boilerplate. |
| deepmd/pt_expt/utils/exclude_mask.py | Applies @torch_module to exclude mask wrappers and removes duplicated __setattr__ logic/imports. |
| deepmd/pt_expt/descriptor/se_r.py | Applies @torch_module to descriptor wrapper and removes duplicated boilerplate. |
| deepmd/pt_expt/descriptor/se_e2_a.py | Applies @torch_module to descriptor wrapper and removes duplicated boilerplate. |
| deepmd/pt_expt/common.py | Adds the torch_module decorator implementation. |
Use `wraps` to keep the modules' names, so they won't be `FlaxModule`, which cannot be recognized. I realized it when implementing deepmodeling#5213.
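A small illustration of the point about `wraps`: the `DescrptToy` class and the two helper decorators below are hypothetical, used only to show the difference in the generated class's `__name__`.

```python
# Hypothetical classes illustrating why functools.wraps(..., updated=()) matters
# when generating wrapper classes: without it, every wrapper has the same name.
from functools import wraps


def wrap_plain(cls: type) -> type:
    class Wrapper(cls):
        pass

    return Wrapper


def wrap_with_wraps(cls: type) -> type:
    @wraps(cls, updated=())  # copy __name__, __qualname__, __doc__, __module__
    class Wrapper(cls):
        pass

    return Wrapper


class DescrptToy:
    """A toy class standing in for a dpmodel descriptor."""


print(wrap_plain(DescrptToy).__name__)       # "Wrapper" -- original name lost
print(wrap_with_wraps(DescrptToy).__name__)  # "DescrptToy" -- name preserved
```

Passing `updated=()` is what makes `wraps` usable on a class: the default would try to update the wrapper's `__dict__`, which is a read-only mapping proxy for classes.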
📝 Walkthrough: adds a new public `torch_module` API in `deepmd/pt_expt/common.py` and applies it across the `pt_expt` wrappers.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller as User code / Caller
    participant Decorator as torch_module (factory)
    participant DPClass as NativeOP / DP base
    participant Torch as torch.nn.Module
    participant dpattr as dpmodel_setattr
    Note over Decorator,DPClass: Decorator builds combined class (TorchModule)
    Caller->>Decorator: apply @torch_module to DPClass
    Decorator->>DPClass: create subclass combining DPClass + Torch
    Decorator->>dpattr: integrate attribute hook
    Caller->>DPClass: instantiate combined module
    DPClass->>Torch: initialize torch.nn.Module base
    DPClass->>DPClass: initialize DP base
    Caller->>DPClass: call(module)(inputs)
    DPClass->>Torch: route __call__ -> forward
    Torch->>DPClass: forward runs (may call DPClass.call)
    DPClass->>dpattr: attribute sets delegate to dpmodel_setattr (fallback to setattr)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~40 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (warning)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt_expt/utils/network.py (1)
41-73: ⚠️ Potential issue | 🔴 Critical

The decorator's `dpmodel_setattr` intercepts numpy arrays before `NativeLayer.__setattr__` can apply trainable-aware logic, preventing trainable parameters from being created with gradients.

During `NativeLayerDP.__init__`, when `self.w = np.array(...)` is assigned:

- `TorchModule.__setattr__` is invoked (via the decorator)
- `dpmodel_setattr` sees the numpy array and `_buffers` exists, so it registers it as a buffer and returns `handled=True` (line 213 of `common.py`)
- Because `handled=True`, the decorator's `__setattr__` skips `super().__setattr__()` and never reaches `NativeLayer.__setattr__`
- `NativeLayer.__setattr__` (lines 43–70), with its special trainable-aware logic (lines 54–63), is completely bypassed

Result: when `trainable=True`, `w`/`b`/`idt` become buffers without gradients instead of `TorchArrayParam` parameters with `requires_grad=True`, breaking gradient computation for trainable layers. A minimal illustration of this short-circuit follows below.
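To make the short-circuit concrete, here is a sketch under stand-in names; the simplified handler below is not the repository's `dpmodel_setattr` or `NativeLayer`, only an illustration of the ordering problem.

```python
# Illustration of the ordering problem (stand-in names, simplified logic).
# A wrapper __setattr__ that "handles" numpy arrays itself short-circuits the MRO
# and skips the trainable-aware __setattr__ defined on an intermediate class.
import numpy as np
import torch


class TrainableAwareLayer(torch.nn.Module):
    """Stand-in for NativeLayer: wants ndarray attrs to become Parameters when trainable."""

    trainable = True

    def __setattr__(self, name, value):
        if isinstance(value, np.ndarray) and self.trainable:
            value = torch.nn.Parameter(torch.from_numpy(value))  # requires_grad=True
        super().__setattr__(name, value)


class Wrapper(TrainableAwareLayer):
    """Stand-in for the decorator-generated TorchModule."""

    def __setattr__(self, name, value):
        # Simplified stand-in for dpmodel_setattr: buffers every ndarray it sees.
        if isinstance(value, np.ndarray) and "_buffers" in self.__dict__:
            self.register_buffer(name, torch.from_numpy(value))
            return  # handled=True: TrainableAwareLayer.__setattr__ never runs
        super().__setattr__(name, value)


m = Wrapper()
m.w = np.ones((2, 2))
print(type(m.w), m.w.requires_grad)  # plain buffer tensor, requires_grad == False
```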
🤖 Fix all issues with AI agents
In `@deepmd/pt_expt/common.py`:
- Around line 298-335: The decorator torch_module currently forces
torch.nn.Module.__call__ and thus requires a forward() method, making decorated
classes without forward() (e.g., AtomExcludeMask, PairExcludeMask,
NetworkCollection) non-functional; modify the TorchModule produced by
torch_module to define a default forward(self, *args, **kwargs) that, if the
instance has a call method (i.e., is NativeOP-like), delegates to
self.call(*args, **kwargs), otherwise raises the existing NotImplementedError or
leaves behavior unchanged; adjust the decorator logic to detect NativeOP via
isinstance(self, NativeOP) or hasattr(self, "call") and ensure dpmodel_setattr
and other overrides remain intact, and for purely utility classes remove the
`@torch_module` usage or add explicit forward implementations.
🧹 Nitpick comments (3)
deepmd/pt_expt/common.py (1)
320-321: Redundant `torch.nn.Module` in consumer class bases when using `@torch_module`.

Several consumer classes (e.g., `DescrptSeA(DescrptSeADP, torch.nn.Module)`) list `torch.nn.Module` explicitly in their bases AND apply `@torch_module`, which also injects `torch.nn.Module` as a base of the generated `TorchModule`. Python's C3 linearization handles the duplicate safely (a small illustration follows below), but it's misleading: it suggests the class is already a `torch.nn.Module` without the decorator.

Consider documenting this expectation in the docstring (that `torch.nn.Module` in the original class bases is acceptable/expected), or note in the consumer files that the explicit base is optional with `@torch_module`.
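A self-contained illustration of the C3 behaviour being referenced; the class names here are placeholders, not the repository's descriptors.

```python
# Placeholder classes showing that a base listed both directly and via a
# generated subclass resolves to a single MRO entry (C3 linearization).
import torch


class DescrptDP:
    """Stand-in for a dpmodel descriptor base class."""


# Explicit torch.nn.Module base, as in the current consumer classes ...
class Consumer(DescrptDP, torch.nn.Module):
    pass


# ... plus a decorator-style wrapper that injects torch.nn.Module again.
class Wrapped(Consumer, torch.nn.Module):
    pass


print([c.__name__ for c in Wrapped.__mro__])
# ['Wrapped', 'Consumer', 'DescrptDP', 'Module', 'object'] -- Module appears once
```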
deepmd/pt_expt/descriptor/se_e2_a.py (1)

14-17: Redundant `torch.nn.Module` in class bases.

`@torch_module` already injects `torch.nn.Module` into the generated class hierarchy. The explicit `torch.nn.Module` in `class DescrptSeA(DescrptSeADP, torch.nn.Module)` is harmless (C3 deduplicates it) but misleading: it suggests the class is a `Module` independent of the decorator.

Optional cleanup:

```diff
 @BaseDescriptor.register("se_e2_a_expt")
 @BaseDescriptor.register("se_a_expt")
 @torch_module
-class DescrptSeA(DescrptSeADP, torch.nn.Module):
+class DescrptSeA(DescrptSeADP):
```

deepmd/pt_expt/descriptor/se_r.py (1)

14-17: Same redundant `torch.nn.Module` base as `se_e2_a.py`.

Optional cleanup:

```diff
 @BaseDescriptor.register("se_e2_r_expt")
 @BaseDescriptor.register("se_r_expt")
 @torch_module
-class DescrptSeR(DescrptSeRDP, torch.nn.Module):
+class DescrptSeR(DescrptSeRDP):
```
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Jinzhe Zeng <njzjz@qq.com>
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt_expt/utils/network.py (1)
41-73: ⚠️ Potential issue | 🔴 Critical

Critical: `dpmodel_setattr` intercepts `w`/`b`/`idt` numpy arrays during initialization, registering them as non-trainable buffers and bypassing `NativeLayer.__setattr__`'s trainable check.

The decorator's `TorchModule.__setattr__` (common.py:330-333) calls `dpmodel_setattr` first, which detects numpy arrays (common.py:207) and registers them as buffers before `NativeLayer.__setattr__` (line 43) can execute. For trainable layers:

- During `NativeLayerDP.__init__`, `self.w = numpy_array` triggers `TorchModule.__setattr__`
- `dpmodel_setattr` sees `isinstance(value, np.ndarray)` and `_buffers` in `__dict__` → calls `register_buffer` → returns `(True, tensor)`
- `handled=True` causes `TorchModule.__setattr__` to skip `super().__setattr__()`, so `NativeLayer.__setattr__` never runs
- The `trainable` check on line 54 is unreachable
- Weights end up as non-trainable buffers instead of `TorchArrayParam` parameters

The existing test only verifies the None-clearing case (setting w=None after init), not initial registration, so this bug is undetected. Trainable layers will not have gradient flow through their weights. A minimal contrast between buffer and parameter registration is sketched below.
🤖 Fix all issues with AI agents
In `@deepmd/pt_expt/common.py`:
- Around line 320-324: The TorchModule wrapper currently calls
torch.nn.Module.__init__ and then module.__init__, causing
torch.nn.Module.__init__ to run twice via MRO when subclasses (e.g.,
NetworkCollectionDP / NetworkCollection.__init__) call super().__init__(), which
resets internal registries; fix by ensuring torch.nn.Module.__init__ is invoked
exactly once in the wrapper (either call torch.nn.Module.__init__(self) before
calling module.__init__(self, ...) or omit the explicit call and let
module.__init__ call it), updating the TorchModule class definition accordingly
so that module.__init__ is not followed by a second torch.nn.Module.__init__
invocation.
- Around line 298-335: The decorator torch_module currently types the input as
type[NativeOP] but is used on classes like AtomExcludeMask and PairExcludeMask
that do not inherit NativeOP and also lack a forward() implementation required
by torch.nn.Module.__call__; update torch_module to accept any class (e.g.,
type[Any]) or tighten callers to only pass NativeOP subclasses, and ensure
classes AtomExcludeMask and PairExcludeMask either inherit NativeOP or implement
a forward(self, *args, **kwargs) method that raises NotImplementedError or
provides appropriate behavior; locate torch_module, AtomExcludeMask,
PairExcludeMask, NativeOP, and dpmodel_setattr in the diff to make the change
and keep the existing initialization order and dpmodel_setattr handling.
```python
def torch_module(
    module: type[NativeOP],
) -> type[torch.nn.Module]:
    """Convert a NativeOP to a torch.nn.Module.

    Parameters
    ----------
    module : type[NativeOP]
        The NativeOP to convert.

    Returns
    -------
    type[torch.nn.Module]
        The torch.nn.Module.

    Examples
    --------
    >>> @torch_module
    ... class MyModule(NativeOP):
    ...     pass
    """

    @wraps(module, updated=())
    class TorchModule(module, torch.nn.Module):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            torch.nn.Module.__init__(self)
            module.__init__(self, *args, **kwargs)

        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
            return torch.nn.Module.__call__(self, *args, **kwargs)

        def __setattr__(self, name: str, value: Any) -> None:
            handled, value = dpmodel_setattr(self, name, value)
            if not handled:
                super().__setattr__(name, value)

    return TorchModule
```
🧩 Analysis chain
(Verification scripts run against deepmodeling/deepmd-kit: searched for exact type checks (`type(x) is SomeClass`) and `__class__` comparisons on the decorated classes, listed all `@torch_module` usages, checked which decorated classes and their dpmodel parents define `forward()` or `call()`, inspected `deepmd/pt_expt/utils/exclude_mask.py` and `deepmd/pt_expt/utils/network.py`, and confirmed the `functools.wraps` import and the decorator implementation in `deepmd/pt_expt/common.py`. The raw script outputs are omitted here.)
Type annotation allows non-NativeOP classes, and AtomExcludeMask/PairExcludeMask lack the forward() method required by torch.nn.Module.

The decorator's __call__ method routes through torch.nn.Module.__call__, which invokes forward(). However:

- Type mismatch: The parameter is annotated `module: type[NativeOP]`, but `AtomExcludeMask` and `PairExcludeMask` inherit from plain dpmodel classes that do not inherit from `NativeOP`. This violates the documented contract.
- Missing forward() implementation: While `DescrptSeA`, `DescrptSeR`, `NativeLayer`, and `NativeNet` all define `forward()`, neither `AtomExcludeMask` nor `PairExcludeMask` do. Calling these instances via `__call__` will fail with `NotImplementedError` or `AttributeError`.
The initialization order (torch.nn.Module first, then dpmodel) is correct and ensures _parameters/_buffers/_modules dicts exist before dpmodel_setattr runs. The @wraps(module, updated=()) properly preserves metadata. No exact type checks (type(x) is ClassName) were found in the codebase, so the isinstance/inheritance trade-off is safe.
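A short sketch of the failure mode and one possible guard, using only generic torch behaviour; the `NoForward` and `DelegatingForward` classes are hypothetical stand-ins, not the repository's utility classes.

```python
# An nn.Module subclass without forward() raises when called, which is the
# failure mode flagged for AtomExcludeMask/PairExcludeMask-style wrappers.
import torch


class NoForward(torch.nn.Module):
    pass


class DelegatingForward(torch.nn.Module):
    """One possible fix pattern: a default forward() that delegates to call()."""

    def call(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

    def forward(self, *args, **kwargs):
        return self.call(*args, **kwargs)


try:
    NoForward()(torch.ones(2))
except NotImplementedError as exc:  # raised by torch.nn.Module's default forward
    print("calling without forward():", type(exc).__name__)

print(DelegatingForward()(torch.ones(2)))  # tensor([2., 2.])
```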
🤖 Prompt for AI Agents
In `@deepmd/pt_expt/common.py` around lines 298 - 335, The decorator torch_module
currently types the input as type[NativeOP] but is used on classes like
AtomExcludeMask and PairExcludeMask that do not inherit NativeOP and also lack a
forward() implementation required by torch.nn.Module.__call__; update
torch_module to accept any class (e.g., type[Any]) or tighten callers to only
pass NativeOP subclasses, and ensure classes AtomExcludeMask and PairExcludeMask
either inherit NativeOP or implement a forward(self, *args, **kwargs) method
that raises NotImplementedError or provides appropriate behavior; locate
torch_module, AtomExcludeMask, PairExcludeMask, NativeOP, and dpmodel_setattr in
the diff to make the change and keep the existing initialization order and
dpmodel_setattr handling.
```python
    @wraps(module, updated=())
    class TorchModule(module, torch.nn.Module):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            torch.nn.Module.__init__(self)
            module.__init__(self, *args, **kwargs)
```
🧩 Analysis chain
(Verification scripts run against deepmodeling/deepmd-kit: inspected the decorator context around lines 315-330 of `deepmd/pt_expt/common.py`, searched for `super().__init__()` calls in `pt_expt` modules, examined the `NetworkCollection` wrapper in `deepmd/pt_expt/utils/network.py`, and probed `torch.nn.Module.__init__` behaviour, including the effect of calling it twice. The raw script outputs are omitted here.)
🌐 Web query:
torch.nn.Module.__init__ behavior what does it initialize reset state
💡 Result:
torch.nn.Module.__init__() does not initialize model weights. It initializes the bookkeeping/state containers that make PyTorch “see” your parameters, buffers, and submodules.
Concretely, Module.__init__() sets up (at least):
- `self.training = True` (default training mode)
- empty registries like `self._parameters`, `self._buffers`, `self._modules` (and hook/state_dict hook dicts, plus `_non_persistent_buffers_set`) so that later assignments like `self.linear = nn.Linear(...)` get registered properly. [1]

Implications:

- You must call `super().__init__()` before assigning submodules/parameters/buffers, otherwise registration fails (PyTorch explicitly documents this requirement). [2]
- Calling `Module.__init__()` again on an existing module effectively resets those registries (you can "lose" previously registered submodules/parameters from the module's tracking), but it doesn't "reinitialize weights" in the usual layer-specific sense. [1][2]
- Weight initialization happens in the specific layer constructors (e.g., `nn.Linear.__init__`), not in `nn.Module.__init__`. [3]
Sources:
[1] https://glaringlee.github.io/_modules/torch/nn/modules/module.html
[2] https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html
[3] https://docs.pytorch.org/docs/stable/notes/modules.html
Address double torch.nn.Module.__init__ calls in the decorator pattern.
The decorator explicitly calls torch.nn.Module.__init__ on line 323, then calls module.__init__ on line 324. When NetworkCollectionDP.__init__ (and classes that inherit from it) call super().__init__() at the end, the MRO routes this back to torch.nn.Module.__init__, resulting in a second call. This resets _parameters, _buffers, and _modules dicts, causing any state registered in the decorated class's __init__ before the super() call to be lost.
For example, in NetworkCollection.__init__ (line 104), the super().__init__() call after setting _module_networks will reset those registries. While _module_networks itself is not a registered parameter/buffer/submodule, the pattern is fragile: any future logic that registers state before calling super().__init__() will have it silently discarded. Consider calling torch.nn.Module.__init__ only once, either before or after module.__init__, rather than both.
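A minimal demonstration of the reset behaviour using plain torch; the `Child` class below is hypothetical and only mimics the double-initialization pattern described above.

```python
# Demonstration that a second torch.nn.Module.__init__ call re-creates the
# bookkeeping registries, dropping anything registered before it.
import torch


class Child(torch.nn.Module):
    def __init__(self) -> None:
        torch.nn.Module.__init__(self)          # wrapper-style explicit init
        self.linear = torch.nn.Linear(2, 2)      # registered in _modules
        super().__init__()                       # second init, as routed via the MRO
        print("submodules after second __init__:", dict(self._modules))


Child()
# submodules after second __init__: {} -- the registered Linear was dropped
```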
🤖 Prompt for AI Agents
In `@deepmd/pt_expt/common.py` around lines 320 - 324, The TorchModule wrapper
currently calls torch.nn.Module.__init__ and then module.__init__, causing
torch.nn.Module.__init__ to run twice via MRO when subclasses (e.g.,
NetworkCollectionDP / NetworkCollection.__init__) call super().__init__(), which
resets internal registries; fix by ensuring torch.nn.Module.__init__ is invoked
exactly once in the wrapper (either call torch.nn.Module.__init__(self) before
calling module.__init__(self, ...) or omit the explicit call and let
module.__init__ call it), updating the TorchModule class definition accordingly
so that module.__init__ is not followed by a second torch.nn.Module.__init__
invocation.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##           master    #5213      +/-   ##
==========================================
- Coverage   82.00%   81.97%   -0.03%
==========================================
  Files         724      724
  Lines       73801    73779      -22
  Branches     3616     3615       -1
==========================================
- Hits        60520    60484      -36
- Misses      12120    12132      +12
- Partials     1161     1163       +2
```

☔ View full report in Codecov by Sentry.
Summary by CodeRabbit
New Features
Refactor
Breaking Changes