-
Notifications
You must be signed in to change notification settings - Fork 590
refact(pt_expt): add decorator to simplify the module #5213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,9 @@ | |
| from collections.abc import ( | ||
| Callable, | ||
| ) | ||
| from functools import ( | ||
| wraps, | ||
| ) | ||
| from typing import ( | ||
| Any, | ||
| overload, | ||
|
|
@@ -292,6 +295,46 @@ def to_torch_array(array: Any) -> torch.Tensor | None: | |
| return torch.as_tensor(array, device=env.DEVICE) | ||
|
|
||
|
|
||
| def torch_module( | ||
| module: type[NativeOP], | ||
| ) -> type[torch.nn.Module]: | ||
| """Convert a NativeOP to a torch.nn.Module. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| module : type[NativeOP] | ||
| The NativeOP to convert. | ||
|
|
||
| Returns | ||
| ------- | ||
| type[torch.nn.Module] | ||
| The torch.nn.Module. | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> @torch_module | ||
| ... class MyModule(NativeOP): | ||
| ... pass | ||
| """ | ||
|
|
||
| @wraps(module, updated=()) | ||
| class TorchModule(module, torch.nn.Module): | ||
| def __init__(self, *args: Any, **kwargs: Any) -> None: | ||
| torch.nn.Module.__init__(self) | ||
| module.__init__(self, *args, **kwargs) | ||
|
Comment on lines
+320
to
+324
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# First, let's check the context around lines 320-324 in deepmd/pt_expt/common.py
echo "=== Context around lines 320-324 in common.py ==="
cat -n deepmd/pt_expt/common.py | sed -n '315,330p'Repository: deepmodeling/deepmd-kit Length of output: 822 🏁 Script executed: #!/bin/bash
# Check if NetworkCollection or decorated classes call super().__init__()
echo "=== Check NetworkCollection and related classes ==="
find deepmd/pt_expt -name "*.py" | head -20Repository: deepmodeling/deepmd-kit Length of output: 454 🏁 Script executed: #!/bin/bash
# Find the network.py file if it exists
echo "=== Looking for network-related files ==="
fd -t f "network" deepmd/pt_expt/Repository: deepmodeling/deepmd-kit Length of output: 140 🏁 Script executed: #!/bin/bash
# Check NetworkCollection class and its __init__ method
echo "=== Looking for NetworkCollection class ==="
rg -n "class NetworkCollection" deepmd/pt_expt/utils/network.py -A 30Repository: deepmodeling/deepmd-kit Length of output: 1158 🏁 Script executed: #!/bin/bash
# Check for any super().__init__() calls in the decorated classes
echo "=== Check for super().__init__() calls in pt_expt modules ==="
rg -n "super\(\)\.__init__\(\)" deepmd/pt_expt/ -B 5 -A 5Repository: deepmodeling/deepmd-kit Length of output: 128 🏁 Script executed: #!/bin/bash
# Let's inspect torch.nn.Module.__init__ to see what it does
python3 << 'EOF'
import torch
import inspect
# Get the source code of torch.nn.Module.__init__
try:
source = inspect.getsource(torch.nn.Module.__init__)
print("=== torch.nn.Module.__init__ source ===")
print(source[:1000]) # First 1000 chars
except:
print("Could not get source, checking what it sets")
# Create a test instance and see what attributes are set
print("\n=== Testing torch.nn.Module.__init__ behavior ===")
m = torch.nn.Module()
print(f"_parameters: {type(m._parameters)}, content: {m._parameters}")
print(f"_buffers: {type(m._buffers)}, content: {m._buffers}")
print(f"_modules: {type(m._modules)}, content: {m._modules}")
# Now test calling __init__ twice
print("\n=== Testing double __init__ call ===")
m2 = torch.nn.Module()
m2.custom_attr = "test"
print(f"Before second __init__: custom_attr = {m2.custom_attr}")
torch.nn.Module.__init__(m2)
print(f"After second __init__: custom_attr = {m2.custom_attr}")
print(f"_parameters after second __init__: {m2._parameters}")
EOFRepository: deepmodeling/deepmd-kit Length of output: 184 🏁 Script executed: #!/bin/bash
# Find NetworkCollectionDP
echo "=== Looking for NetworkCollectionDP class ==="
rg -n "class NetworkCollectionDP" deepmd/ -A 15Repository: deepmodeling/deepmd-kit Length of output: 112 🏁 Script executed: #!/bin/bash
# Let's also check imports to understand the base classes
echo "=== Check imports in network.py ==="
head -50 deepmd/pt_expt/utils/network.pyRepository: deepmodeling/deepmd-kit Length of output: 1566 🌐 Web query:
💡 Result:
Concretely,
Implications:
Sources: Address double The decorator explicitly calls For example, in 🤖 Prompt for AI Agents |
||
|
|
||
| def __call__(self, *args: Any, **kwargs: Any) -> Any: | ||
| # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. | ||
| return torch.nn.Module.__call__(self, *args, **kwargs) | ||
|
|
||
| def __setattr__(self, name: str, value: Any) -> None: | ||
| handled, value = dpmodel_setattr(self, name, value) | ||
| if not handled: | ||
| super().__setattr__(name, value) | ||
|
|
||
| return TorchModule | ||
|
Comment on lines
298
to
335
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
To fix this, def torch_module(
module: type,
) -> type[torch.nn.Module]:
"""Convert a NativeOP to a torch.nn.Module.
Parameters
----------
module : type
The class to convert.
Returns
-------
torch.nn.Module
The torch.nn.Module.
Examples
--------
>>> @torch_module
... class MyModule(NativeOP):
... pass
"""
@wraps(module, updated=())
class TorchModule(module, torch.nn.Module):
def __init__(self, *args: Any, **kwargs: Any) -> None:
torch.nn.Module.__init__(self)
module.__init__(self, *args, **kwargs)
if hasattr(module, "forward"):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
# Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
return torch.nn.Module.__call__(self, *args, **kwargs)
def __setattr__(self, name: str, value: Any) -> None:
handled, value = dpmodel_setattr(self, name, value)
if not handled:
super().__setattr__(name, value)
return TorchModule
Comment on lines
+298
to
+335
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Check if there are any exact-type checks (type(x) is SomeClass) for decorated classes
echo "=== Check for exact type checks on decorated classes ==="
rg -nP 'type\(.+\)\s+is\s+(DescrptSeA|DescrptSeR|NativeLayer|NativeNet|NetworkCollection|AtomExcludeMask|PairExcludeMask)\b' --type=py
echo ""
echo "=== Also check for __class__ comparisons ==="
rg -nP '__class__\s*==\s+(DescrptSeA|DescrptSeR|NativeLayer|NativeNet|NetworkCollection)' --type=pyRepository: deepmodeling/deepmd-kit Length of output: 169 🏁 Script executed: #!/bin/bash
# Find where torch_module decorator is used
echo "=== Uses of `@torch_module` decorator ==="
rg -n '@torch_module' --type=py -A 2Repository: deepmodeling/deepmd-kit Length of output: 1692 🏁 Script executed: #!/bin/bash
# Check if decorated classes define forward() method
echo "=== Checking if decorated classes define forward() ==="
rg -n 'def forward\(' deepmd/pt_expt/descriptor/ deepmd/pt_expt/fitting/ deepmd/pt_expt/network.py --type=py -B 5 | head -100Repository: deepmodeling/deepmd-kit Length of output: 1019 🏁 Script executed: #!/bin/bash
# Verify the import of wraps and the decorator implementation
echo "=== Verify functools import and wraps usage ==="
head -30 deepmd/pt_expt/common.py | rg -n 'wraps|functools'
echo ""
echo "=== Full torch_module decorator implementation ==="
sed -n '298,335p' deepmd/pt_expt/common.pyRepository: deepmodeling/deepmd-kit Length of output: 1287 🏁 Script executed: #!/bin/bash
# Check if AtomExcludeMask and PairExcludeMask define forward() method
echo "=== Check AtomExcludeMask and PairExcludeMask in exclude_mask.py ==="
cat deepmd/pt_expt/utils/exclude_mask.pyRepository: deepmodeling/deepmd-kit Length of output: 842 🏁 Script executed: #!/bin/bash
# Also check if the parent classes define forward() or if it's optional
echo "=== Check parent classes of exclude mask classes ==="
rg -n 'class AtomExcludeMaskDP|class PairExcludeMaskDP' --type=py -A 10Repository: deepmodeling/deepmd-kit Length of output: 119 🏁 Script executed: #!/bin/bash
# Check what methods the parent dpmodel classes might define
echo "=== Checking NativeLayer and NativeNet more thoroughly ==="
sed -n '41,73p' deepmd/pt_expt/utils/network.py
echo ""
sed -n '76,92p' deepmd/pt_expt/utils/network.pyRepository: deepmodeling/deepmd-kit Length of output: 1967 🏁 Script executed: #!/bin/bash
# Check if NativeNet and NetworkCollection have forward() method defined
echo "=== Checking if NativeNet has forward() ==="
rg -A 50 'class NativeNet' deepmd/pt_expt/utils/network.py | head -60Repository: deepmodeling/deepmd-kit Length of output: 1531 🏁 Script executed: #!/bin/bash
# Check if NetworkCollection has forward() method defined
echo "=== Checking if NetworkCollection has forward() ==="
rg -A 50 'class NetworkCollection' deepmd/pt_expt/utils/network.py | head -60Repository: deepmodeling/deepmd-kit Length of output: 1051 🏁 Script executed: #!/bin/bash
# Check if forward() is called anywhere on decorated classes
echo "=== Check for forward() calls on NetworkCollection ==="
rg -n 'NetworkCollection.*\.forward\(' --type=py
echo ""
echo "=== Check if NetworkCollection is called directly ==="
rg -n '\([^)]*NetworkCollection[^)]*\)\(' --type=py | head -20Repository: deepmodeling/deepmd-kit Length of output: 176 🏁 Script executed: #!/bin/bash
# Check parent classes for forward() method to see if it's inherited
echo "=== Check dpmodel parent classes for forward() ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '.')
try:
from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
print("AtomExcludeMaskDP has forward():", hasattr(AtomExcludeMaskDP, 'forward'))
print("PairExcludeMaskDP has forward():", hasattr(PairExcludeMaskDP, 'forward'))
except Exception as e:
print(f"Error: {e}")
EOFRepository: deepmodeling/deepmd-kit Length of output: 159 🏁 Script executed: #!/bin/bash
# Check if NetworkCollection or its parent is actually called in a way that needs forward()
echo "=== Check usage of decorated classes ==="
rg -n 'NetworkCollection\(' --type=py | head -10
echo ""
echo "=== Check if these are module outputs used in torch.nn.Module context ==="
rg -n '\.forward\(|__call__\(' deepmd/pt_expt/utils/exclude_mask.py deepmd/pt_expt/utils/network.py | head -20Repository: deepmodeling/deepmd-kit Length of output: 1056 🏁 Script executed: #!/bin/bash
# Check if the parent dpmodel classes are NativeOP or have a call() method
echo "=== Check NativeOP definition and call() method ==="
rg -n 'class NativeOP|def call\(' deepmd/dpmodel/common.py | head -30
echo ""
echo "=== Look for NativeOP imports ==="
rg -n 'from.*NativeOP|import.*NativeOP' deepmd/pt_expt/common.pyRepository: deepmodeling/deepmd-kit Length of output: 238 🏁 Script executed: #!/bin/bash
# Check the exclude_mask parent classes to see if they inherit from NativeOP
echo "=== Check parent classes of AtomExcludeMask ==="
rg -B5 'class AtomExcludeMask' deepmd/dpmodel/utils/exclude_mask.py 2>/dev/null || echo "File not found in dpmodel"
echo ""
echo "=== Check if AtomExcludeMask/PairExcludeMask are actually subclasses of NativeOP ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '.')
try:
# Check what NativeOP is
from deepmd.pt_expt.common import NativeOP
print(f"NativeOP: {NativeOP}")
print(f"NativeOP is: {type(NativeOP)}")
except ImportError as e:
print(f"Cannot import: {e}")
EOFRepository: deepmodeling/deepmd-kit Length of output: 312 🏁 Script executed: #!/bin/bash
# Check if NetworkCollection, AtomExcludeMask, PairExcludeMask might be called in a torch context
echo "=== Check if NetworkCollection instances are used with __call__ ==="
rg -n 'networks\(|collection\(' deepmd/pt_expt/ --type=py | head -20
echo ""
echo "=== Check test files to see how decorated classes are used ==="
find deepmd/pt_expt -name '*test*.py' -type f | head -5Repository: deepmodeling/deepmd-kit Length of output: 198 🏁 Script executed: #!/bin/bash
# Check the dpmodel parent classes more carefully
echo "=== Check AtomExcludeMaskDP and PairExcludeMaskDP inheritance ==="
head -50 deepmd/dpmodel/utils/exclude_mask.pyRepository: deepmodeling/deepmd-kit Length of output: 1343 🏁 Script executed: #!/bin/bash
# Check if NetworkCollectionDP or its parents define forward() or call()
echo "=== Check NetworkCollectionDP definition ==="
rg -A 20 'class NetworkCollectionDP' deepmd/dpmodel/utils/network.py | head -30
echo ""
echo "=== Check if these parent classes have call() method ==="
rg -n 'def call\(' deepmd/dpmodel/utils/Repository: deepmodeling/deepmd-kit Length of output: 676 🏁 Script executed: #!/bin/bash
# Check if the parent classes of decorated classes actually inherit from NativeOP
echo "=== Check what AtomExcludeMaskDP inherits from ==="
grep -A 1 'class AtomExcludeMask' deepmd/dpmodel/utils/exclude_mask.py | head -3
echo ""
echo "=== Check NetworkCollectionDP inheritance ==="
grep 'class NetworkCollectionDP' deepmd/dpmodel/utils/network.py
echo ""
echo "=== Verify NativeOP definition to understand what classes should inherit from it ==="
sed -n '93,120p' deepmd/dpmodel/common.pyRepository: deepmodeling/deepmd-kit Length of output: 1003 🏁 Script executed: #!/bin/bash
# Check if classes that don't define forward() are actually expected to be called via torch
echo "=== Check imports and usage context for AtomExcludeMask ==="
rg -n 'from.*exclude_mask import|import.*exclude_mask' deepmd/pt_expt/ --type=py
echo ""
echo "=== Check how NetworkCollection is instantiated ==="
grep -r 'NetworkCollection(' deepmd/pt_expt/ --include='*.py' | head -5Repository: deepmodeling/deepmd-kit Length of output: 565 🏁 Script executed: #!/bin/bash
# Check if there are any tests that use the decorated exclude_mask classes
echo "=== Check exclude_mask test files ==="
find . -path './deepmd/pt_expt*' -name '*exclude_mask*' -o -path './deepmd/pt_expt*' -name '*test*.py' | grep -i excludeRepository: deepmodeling/deepmd-kit Length of output: 143 🏁 Script executed: #!/bin/bash
# Search for NetworkCollectionDP more broadly
echo "=== Search for NetworkCollectionDP ==="
rg -n 'NetworkCollectionDP' --type=py
echo ""
echo "=== Check the full network.py file for NetworkCollectionDP ==="
cat deepmd/dpmodel/utils/network.py | grep -n 'class.*Collection'Repository: deepmodeling/deepmd-kit Length of output: 1331 🏁 Script executed: #!/bin/bash
# Check if these parent classes (AtomExcludeMaskDP, PairExcludeMaskDP) inherit from NativeOP
echo "=== Full AtomExcludeMask and PairExcludeMask definitions ==="
cat deepmd/dpmodel/utils/exclude_mask.pyRepository: deepmodeling/deepmd-kit Length of output: 4547 Type annotation allows non-NativeOP classes, and AtomExcludeMask/PairExcludeMask lack forward() method required by torch.nn.Module. The decorator's
The initialization order (torch.nn.Module first, then dpmodel) is correct and ensures 🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| # Import utils to trigger dpmodel→pt_expt converter registrations | ||
| # This must happen after the functions above are defined to avoid circular imports | ||
| def _ensure_registrations() -> None: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.