43 changes: 43 additions & 0 deletions deepmd/pt_expt/common.py
@@ -17,6 +17,9 @@
from collections.abc import (
    Callable,
)
from functools import (
    wraps,
)
from typing import (
    Any,
    overload,
@@ -292,6 +295,46 @@ def to_torch_array(array: Any) -> torch.Tensor | None:
    return torch.as_tensor(array, device=env.DEVICE)


def torch_module(
    module: type[NativeOP],
) -> type[torch.nn.Module]:
    """Convert a NativeOP to a torch.nn.Module.

    Parameters
    ----------
    module : type[NativeOP]
        The NativeOP to convert.

    Returns
    -------
    type[torch.nn.Module]
        The torch.nn.Module.

    Examples
    --------
    >>> @torch_module
    ... class MyModule(NativeOP):
    ...     pass
    """

    @wraps(module, updated=())
    class TorchModule(module, torch.nn.Module):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            torch.nn.Module.__init__(self)
            module.__init__(self, *args, **kwargs)
Comment on lines +320 to +324

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, let's check the context around lines 320-324 in deepmd/pt_expt/common.py
echo "=== Context around lines 320-324 in common.py ==="
cat -n deepmd/pt_expt/common.py | sed -n '315,330p'


🏁 Script executed:

#!/bin/bash
# Check if NetworkCollection or decorated classes call super().__init__()
echo "=== Check NetworkCollection and related classes ==="
find deepmd/pt_expt -name "*.py" | head -20


🏁 Script executed:

#!/bin/bash
# Find the network.py file if it exists
echo "=== Looking for network-related files ==="
fd -t f "network" deepmd/pt_expt/


🏁 Script executed:

#!/bin/bash
# Check NetworkCollection class and its __init__ method
echo "=== Looking for NetworkCollection class ==="
rg -n "class NetworkCollection" deepmd/pt_expt/utils/network.py -A 30


🏁 Script executed:

#!/bin/bash
# Check for any super().__init__() calls in the decorated classes
echo "=== Check for super().__init__() calls in pt_expt modules ==="
rg -n "super\(\)\.__init__\(\)" deepmd/pt_expt/ -B 5 -A 5


🏁 Script executed:

#!/bin/bash
# Let's inspect torch.nn.Module.__init__ to see what it does
python3 << 'EOF'
import torch
import inspect

# Get the source code of torch.nn.Module.__init__
try:
    source = inspect.getsource(torch.nn.Module.__init__)
    print("=== torch.nn.Module.__init__ source ===")
    print(source[:1000])  # First 1000 chars
except:
    print("Could not get source, checking what it sets")

# Create a test instance and see what attributes are set
print("\n=== Testing torch.nn.Module.__init__ behavior ===")
m = torch.nn.Module()
print(f"_parameters: {type(m._parameters)}, content: {m._parameters}")
print(f"_buffers: {type(m._buffers)}, content: {m._buffers}")
print(f"_modules: {type(m._modules)}, content: {m._modules}")

# Now test calling __init__ twice
print("\n=== Testing double __init__ call ===")
m2 = torch.nn.Module()
m2.custom_attr = "test"
print(f"Before second __init__: custom_attr = {m2.custom_attr}")
torch.nn.Module.__init__(m2)
print(f"After second __init__: custom_attr = {m2.custom_attr}")
print(f"_parameters after second __init__: {m2._parameters}")
EOF


🏁 Script executed:

#!/bin/bash
# Find NetworkCollectionDP
echo "=== Looking for NetworkCollectionDP class ==="
rg -n "class NetworkCollectionDP" deepmd/ -A 15


🏁 Script executed:

#!/bin/bash
# Let's also check imports to understand the base classes
echo "=== Check imports in network.py ==="
head -50 deepmd/pt_expt/utils/network.py


🌐 Web query:

torch.nn.Module.__init__ behavior what does it initialize reset state

💡 Result:

torch.nn.Module.__init__() does not initialize model weights. It initializes the bookkeeping/state containers that make PyTorch “see” your parameters, buffers, and submodules.

Concretely, Module.__init__() sets up (at least):

  • self.training = True (default training mode)
  • empty registries like self._parameters, self._buffers, self._modules (and hook/state_dict hook dicts, plus _non_persistent_buffers_set) so that later assignments like self.linear = nn.Linear(...) get registered properly. [1]

Implications:

  • You must call super().__init__() before assigning submodules/parameters/buffers, otherwise registration fails (PyTorch explicitly documents this requirement). [2]
  • Calling Module.__init__() again on an existing module effectively resets those registries (you can “lose” previously registered submodules/parameters from the module’s tracking), but it doesn’t “reinitialize weights” in the usual layer-specific sense. [1][2]
  • Weight initialization happens in the specific layer constructors (e.g., nn.Linear.__init__), not in nn.Module.__init__. [3]

Sources:
[1] https://glaringlee.github.io/_modules/torch/nn/modules/module.html
[2] https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html
[3] https://docs.pytorch.org/docs/stable/notes/modules.html
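
A quick check of the bookkeeping claim above (plain PyTorch, independent of this PR):

import torch

m = torch.nn.Module()
print(m.training)          # True -- default training mode set by Module.__init__
print(len(m._parameters))  # 0 -- empty parameter registry
print(len(m._buffers))     # 0 -- empty buffer registry
print(len(m._modules))     # 0 -- empty submodule registry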


Address double torch.nn.Module.__init__ calls in the decorator pattern.

The decorator explicitly calls torch.nn.Module.__init__ on line 323, then calls module.__init__ on line 324. When NetworkCollectionDP.__init__ (and classes that inherit from it) call super().__init__() at the end, the MRO routes this back to torch.nn.Module.__init__, resulting in a second call. This resets _parameters, _buffers, and _modules dicts, causing any state registered in the decorated class's __init__ before the super() call to be lost.

For example, in NetworkCollection.__init__ (line 104), the super().__init__() call after setting _module_networks will reset those registries. While _module_networks itself is not a registered parameter/buffer/submodule, the pattern is fragile: any future logic that registers state before calling super().__init__() will have it silently discarded. Consider calling torch.nn.Module.__init__ only once, either before or after module.__init__, rather than both.
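
A minimal sketch of the hazard (hypothetical classes mirroring the wrapper pattern, not PR code):

import torch

class Inner:
    def __init__(self) -> None:
        self.net = torch.nn.Linear(2, 2)  # registered in _modules via Module.__setattr__
        super().__init__()                # MRO reaches torch.nn.Module.__init__ again

class Wrapped(Inner, torch.nn.Module):
    def __init__(self) -> None:
        torch.nn.Module.__init__(self)    # first call: registries created
        Inner.__init__(self)              # chain ends in a second Module.__init__

w = Wrapped()
print(list(w._modules))  # [] -- the Linear registered in Inner.__init__ was wiped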

🤖 Prompt for AI Agents
In `deepmd/pt_expt/common.py` around lines 320 - 324, the TorchModule wrapper
calls torch.nn.Module.__init__ and then module.__init__, so
torch.nn.Module.__init__ runs a second time via the MRO whenever a subclass
(e.g., NetworkCollectionDP / NetworkCollection.__init__) calls
super().__init__(), resetting the internal registries. Fix this by ensuring
torch.nn.Module.__init__ is invoked exactly once: either keep the explicit
torch.nn.Module.__init__(self) call before module.__init__(self, ...) and make
sure the wrapped chain does not reach it again, or drop the explicit call and
let module.__init__ reach it through super(). Update the TorchModule class
definition accordingly so that module.__init__ is not followed by a second
torch.nn.Module.__init__ invocation.
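
One hypothetical way to guarantee a single initialization is to route cooperative super().__init__() calls through a guard class instead of letting them reach torch.nn.Module.__init__ a second time. A sketch under that assumption, not the merged fix (__call__/__setattr__ handling elided):

from functools import wraps
from typing import Any

import torch

class _ModuleInitOnce(torch.nn.Module):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Create the registries only if no earlier call has done so already.
        if "_parameters" not in self.__dict__:
            torch.nn.Module.__init__(self)

def torch_module(module: type) -> type[torch.nn.Module]:
    @wraps(module, updated=())
    class TorchModule(module, _ModuleInitOnce):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            _ModuleInitOnce.__init__(self)          # registries exist from here on
            module.__init__(self, *args, **kwargs)  # a trailing super().__init__()
                                                    # now hits the guard, not a reset

    return TorchModule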


        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
            return torch.nn.Module.__call__(self, *args, **kwargs)

        def __setattr__(self, name: str, value: Any) -> None:
            handled, value = dpmodel_setattr(self, name, value)
            if not handled:
                super().__setattr__(name, value)

    return TorchModule
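
A usage sketch of the decorator above (Scale is a hypothetical NativeOP, not a class from this PR; it assumes dpmodel_setattr passes plain Python attributes through unhandled):

import torch

from deepmd.dpmodel.common import NativeOP
from deepmd.pt_expt.common import torch_module

@torch_module
class Scale(NativeOP):
    def __init__(self, factor: float) -> None:
        self.factor = factor

    def call(self, x):
        # Framework-agnostic implementation used by NativeOP
        return x * self.factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.call(x)

s = Scale(2.0)
print(isinstance(s, torch.nn.Module))  # True -- the decorator mixes the base class in
print(s(torch.ones(3)))                # tensor([2., 2., 2.]) via Module.__call__ -> forward()
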
Comment on lines 298 to 335


high

The torch_module decorator is a great idea to reduce boilerplate. However, it has a couple of issues:

  1. The type hint module: type[NativeOP] is too restrictive, as this decorator is also used on classes that do not inherit from NativeOP (e.g., AtomExcludeMask, NetworkCollection). It should be module: type.
  2. The __call__ method is defined unconditionally. This will cause a NotImplementedError at runtime if an instance of a decorated class without a forward method is called. This affects classes like AtomExcludeMask and NetworkCollection which are not meant to be callable modules.

To fix this, __call__ should only be defined if the decorated class has a forward method.

def torch_module(
    module: type,
) -> type[torch.nn.Module]:
    """Convert a NativeOP to a torch.nn.Module.

    Parameters
    ----------
    module : type
        The class to convert.

    Returns
    -------
    type[torch.nn.Module]
        The converted torch.nn.Module subclass.

    Examples
    --------
    >>> @torch_module
    ... class MyModule(NativeOP):
    ...     pass
    """

    @wraps(module, updated=())
    class TorchModule(module, torch.nn.Module):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            torch.nn.Module.__init__(self)
            module.__init__(self, *args, **kwargs)

        if hasattr(module, "forward"):

            def __call__(self, *args: Any, **kwargs: Any) -> Any:
                # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
                return torch.nn.Module.__call__(self, *args, **kwargs)

        def __setattr__(self, name: str, value: Any) -> None:
            handled, value = dpmodel_setattr(self, name, value)
            if not handled:
                super().__setattr__(name, value)

    return TorchModule

Comment on lines +298 to +335

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if there are any exact-type checks (type(x) is SomeClass) for decorated classes
echo "=== Check for exact type checks on decorated classes ==="
rg -nP 'type\(.+\)\s+is\s+(DescrptSeA|DescrptSeR|NativeLayer|NativeNet|NetworkCollection|AtomExcludeMask|PairExcludeMask)\b' --type=py
echo ""
echo "=== Also check for __class__ comparisons ==="
rg -nP '__class__\s*==\s+(DescrptSeA|DescrptSeR|NativeLayer|NativeNet|NetworkCollection)' --type=py


🏁 Script executed:

#!/bin/bash
# Find where torch_module decorator is used
echo "=== Uses of `@torch_module` decorator ==="
rg -n '@torch_module' --type=py -A 2


🏁 Script executed:

#!/bin/bash
# Check if decorated classes define forward() method
echo "=== Checking if decorated classes define forward() ==="
rg -n 'def forward\(' deepmd/pt_expt/descriptor/ deepmd/pt_expt/fitting/ deepmd/pt_expt/network.py --type=py -B 5 | head -100


🏁 Script executed:

#!/bin/bash
# Verify the import of wraps and the decorator implementation
echo "=== Verify functools import and wraps usage ==="
head -30 deepmd/pt_expt/common.py | rg -n 'wraps|functools'
echo ""
echo "=== Full torch_module decorator implementation ==="
sed -n '298,335p' deepmd/pt_expt/common.py


🏁 Script executed:

#!/bin/bash
# Check if AtomExcludeMask and PairExcludeMask define forward() method
echo "=== Check AtomExcludeMask and PairExcludeMask in exclude_mask.py ==="
cat deepmd/pt_expt/utils/exclude_mask.py


🏁 Script executed:

#!/bin/bash
# Also check if the parent classes define forward() or if it's optional
echo "=== Check parent classes of exclude mask classes ==="
rg -n 'class AtomExcludeMaskDP|class PairExcludeMaskDP' --type=py -A 10


🏁 Script executed:

#!/bin/bash
# Check what methods the parent dpmodel classes might define
echo "=== Checking NativeLayer and NativeNet more thoroughly ==="
sed -n '41,73p' deepmd/pt_expt/utils/network.py
echo ""
sed -n '76,92p' deepmd/pt_expt/utils/network.py


🏁 Script executed:

#!/bin/bash
# Check if NativeNet and NetworkCollection have forward() method defined
echo "=== Checking if NativeNet has forward() ==="
rg -A 50 'class NativeNet' deepmd/pt_expt/utils/network.py | head -60


🏁 Script executed:

#!/bin/bash
# Check if NetworkCollection has forward() method defined
echo "=== Checking if NetworkCollection has forward() ==="
rg -A 50 'class NetworkCollection' deepmd/pt_expt/utils/network.py | head -60


🏁 Script executed:

#!/bin/bash
# Check if forward() is called anywhere on decorated classes
echo "=== Check for forward() calls on NetworkCollection ==="
rg -n 'NetworkCollection.*\.forward\(' --type=py
echo ""
echo "=== Check if NetworkCollection is called directly ==="
rg -n '\([^)]*NetworkCollection[^)]*\)\(' --type=py | head -20


🏁 Script executed:

#!/bin/bash
# Check parent classes for forward() method to see if it's inherited
echo "=== Check dpmodel parent classes for forward() ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '.')
try:
    from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
    from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
    print("AtomExcludeMaskDP has forward():", hasattr(AtomExcludeMaskDP, 'forward'))
    print("PairExcludeMaskDP has forward():", hasattr(PairExcludeMaskDP, 'forward'))
except Exception as e:
    print(f"Error: {e}")
EOF


🏁 Script executed:

#!/bin/bash
# Check if NetworkCollection or its parent is actually called in a way that needs forward()
echo "=== Check usage of decorated classes ==="
rg -n 'NetworkCollection\(' --type=py | head -10
echo ""
echo "=== Check if these are module outputs used in torch.nn.Module context ==="
rg -n '\.forward\(|__call__\(' deepmd/pt_expt/utils/exclude_mask.py deepmd/pt_expt/utils/network.py | head -20


🏁 Script executed:

#!/bin/bash
# Check if the parent dpmodel classes are NativeOP or have a call() method
echo "=== Check NativeOP definition and call() method ==="
rg -n 'class NativeOP|def call\(' deepmd/dpmodel/common.py | head -30
echo ""
echo "=== Look for NativeOP imports ==="
rg -n 'from.*NativeOP|import.*NativeOP' deepmd/pt_expt/common.py


🏁 Script executed:

#!/bin/bash
# Check the exclude_mask parent classes to see if they inherit from NativeOP
echo "=== Check parent classes of AtomExcludeMask ==="
rg -B5 'class AtomExcludeMask' deepmd/dpmodel/utils/exclude_mask.py 2>/dev/null || echo "File not found in dpmodel"
echo ""
echo "=== Check if AtomExcludeMask/PairExcludeMask are actually subclasses of NativeOP ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '.')
try:
    # Check what NativeOP is
    from deepmd.pt_expt.common import NativeOP
    print(f"NativeOP: {NativeOP}")
    print(f"NativeOP is: {type(NativeOP)}")
except ImportError as e:
    print(f"Cannot import: {e}")
EOF


🏁 Script executed:

#!/bin/bash
# Check if NetworkCollection, AtomExcludeMask, PairExcludeMask might be called in a torch context
echo "=== Check if NetworkCollection instances are used with __call__ ==="
rg -n 'networks\(|collection\(' deepmd/pt_expt/ --type=py | head -20
echo ""
echo "=== Check test files to see how decorated classes are used ==="
find deepmd/pt_expt -name '*test*.py' -type f | head -5


🏁 Script executed:

#!/bin/bash
# Check the dpmodel parent classes more carefully
echo "=== Check AtomExcludeMaskDP and PairExcludeMaskDP inheritance ==="
head -50 deepmd/dpmodel/utils/exclude_mask.py


🏁 Script executed:

#!/bin/bash
# Check if NetworkCollectionDP or its parents define forward() or call()
echo "=== Check NetworkCollectionDP definition ==="
rg -A 20 'class NetworkCollectionDP' deepmd/dpmodel/utils/network.py | head -30
echo ""
echo "=== Check if these parent classes have call() method ==="
rg -n 'def call\(' deepmd/dpmodel/utils/


🏁 Script executed:

#!/bin/bash
# Check if the parent classes of decorated classes actually inherit from NativeOP
echo "=== Check what AtomExcludeMaskDP inherits from ==="
grep -A 1 'class AtomExcludeMask' deepmd/dpmodel/utils/exclude_mask.py | head -3
echo ""
echo "=== Check NetworkCollectionDP inheritance ==="
grep 'class NetworkCollectionDP' deepmd/dpmodel/utils/network.py
echo ""
echo "=== Verify NativeOP definition to understand what classes should inherit from it ==="
sed -n '93,120p' deepmd/dpmodel/common.py


🏁 Script executed:

#!/bin/bash
# Check if classes that don't define forward() are actually expected to be called via torch
echo "=== Check imports and usage context for AtomExcludeMask ==="
rg -n 'from.*exclude_mask import|import.*exclude_mask' deepmd/pt_expt/ --type=py
echo ""
echo "=== Check how NetworkCollection is instantiated ==="
grep -r 'NetworkCollection(' deepmd/pt_expt/ --include='*.py' | head -5


🏁 Script executed:

#!/bin/bash
# Check if there are any tests that use the decorated exclude_mask classes
echo "=== Check exclude_mask test files ==="
find . -path './deepmd/pt_expt*' -name '*exclude_mask*' -o -path './deepmd/pt_expt*' -name '*test*.py' | grep -i exclude


🏁 Script executed:

#!/bin/bash
# Search for NetworkCollectionDP more broadly
echo "=== Search for NetworkCollectionDP ==="
rg -n 'NetworkCollectionDP' --type=py
echo ""
echo "=== Check the full network.py file for NetworkCollectionDP ==="
cat deepmd/dpmodel/utils/network.py | grep -n 'class.*Collection'


🏁 Script executed:

#!/bin/bash
# Check if these parent classes (AtomExcludeMaskDP, PairExcludeMaskDP) inherit from NativeOP
echo "=== Full AtomExcludeMask and PairExcludeMask definitions ==="
cat deepmd/dpmodel/utils/exclude_mask.py


The decorator is applied to non-NativeOP classes despite its type[NativeOP] annotation, and AtomExcludeMask/PairExcludeMask lack the forward() method required by torch.nn.Module.

The decorator's __call__ method routes through torch.nn.Module.__call__, which invokes forward(). However:

  1. Type mismatch: The parameter is annotated module: type[NativeOP], but AtomExcludeMask and PairExcludeMask inherit from plain dpmodel classes that do not inherit from NativeOP. This violates the documented contract.

  2. Missing forward() implementation: While DescrptSeA, DescrptSeR, NativeLayer, and NativeNet all define forward(), neither AtomExcludeMask nor PairExcludeMask do. Calling these instances via __call__ will fail with NotImplementedError or AttributeError.

The initialization order (torch.nn.Module first, then dpmodel) is correct and ensures _parameters/_buffers/_modules dicts exist before dpmodel_setattr runs. The @wraps(module, updated=()) properly preserves metadata. No exact type checks (type(x) is ClassName) were found in the codebase, so the isinstance/inheritance trade-off is safe.

🤖 Prompt for AI Agents
In `deepmd/pt_expt/common.py` around lines 298 - 335, the torch_module
decorator types its input as type[NativeOP] but is used on classes like
AtomExcludeMask and PairExcludeMask that do not inherit NativeOP and also lack
the forward() implementation required by torch.nn.Module.__call__. Update
torch_module to accept any class (e.g., type[Any]), or tighten callers to pass
only NativeOP subclasses, and ensure AtomExcludeMask and PairExcludeMask either
inherit NativeOP or implement a forward(self, *args, **kwargs) method that
raises NotImplementedError or provides appropriate behavior. Locate
torch_module, AtomExcludeMask, PairExcludeMask, NativeOP, and dpmodel_setattr
in the diff to make the change, and keep the existing initialization order and
dpmodel_setattr handling.
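
A sketch of the stub option named above (hypothetical, assuming the mask wrappers should stay non-callable):

from typing import Any

@torch_module
class AtomExcludeMask(AtomExcludeMaskDP):
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # The mask is a lookup helper, not a callable module; fail loudly here
        # rather than falling through to torch.nn.Module's default forward.
        raise NotImplementedError("AtomExcludeMask is not meant to be called")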



# Import utils to trigger dpmodel→pt_expt converter registrations
# This must happen after the functions above are defined to avoid circular imports
def _ensure_registrations() -> None:
21 changes: 3 additions & 18 deletions deepmd/pt_expt/descriptor/se_e2_a.py
@@ -1,13 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Any,
)

import torch

from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
from deepmd.pt_expt.common import (
    dpmodel_setattr,
    torch_module,
)
from deepmd.pt_expt.descriptor.base_descriptor import (
    BaseDescriptor,
@@ -16,20 +13,8 @@

@BaseDescriptor.register("se_e2_a_expt")
@BaseDescriptor.register("se_a_expt")
class DescrptSeA(DescrptSeADP, torch.nn.Module):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        torch.nn.Module.__init__(self)
        DescrptSeADP.__init__(self, *args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
        return torch.nn.Module.__call__(self, *args, **kwargs)

    def __setattr__(self, name: str, value: Any) -> None:
        handled, value = dpmodel_setattr(self, name, value)
        if not handled:
            super().__setattr__(name, value)

@torch_module
class DescrptSeA(DescrptSeADP):
    def forward(
        self,
        extended_coord: torch.Tensor,
21 changes: 3 additions & 18 deletions deepmd/pt_expt/descriptor/se_r.py
@@ -1,13 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Any,
)

import torch

from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP
from deepmd.pt_expt.common import (
    dpmodel_setattr,
    torch_module,
)
from deepmd.pt_expt.descriptor.base_descriptor import (
    BaseDescriptor,
@@ -16,20 +13,8 @@

@BaseDescriptor.register("se_e2_r_expt")
@BaseDescriptor.register("se_r_expt")
class DescrptSeR(DescrptSeRDP, torch.nn.Module):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        torch.nn.Module.__init__(self)
        DescrptSeRDP.__init__(self, *args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
        return torch.nn.Module.__call__(self, *args, **kwargs)

    def __setattr__(self, name: str, value: Any) -> None:
        handled, value = dpmodel_setattr(self, name, value)
        if not handled:
            super().__setattr__(name, value)

@torch_module
class DescrptSeR(DescrptSeRDP):
    def forward(
        self,
        extended_coord: torch.Tensor,
30 changes: 7 additions & 23 deletions deepmd/pt_expt/utils/exclude_mask.py
@@ -1,27 +1,17 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Any,
)

import torch

from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
from deepmd.pt_expt.common import (
    dpmodel_setattr,
    register_dpmodel_mapping,
    torch_module,
)


class AtomExcludeMask(AtomExcludeMaskDP, torch.nn.Module):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        torch.nn.Module.__init__(self)
        AtomExcludeMaskDP.__init__(self, *args, **kwargs)

    def __setattr__(self, name: str, value: Any) -> None:
        handled, value = dpmodel_setattr(self, name, value)
        if not handled:
            super().__setattr__(name, value)
@torch_module
class AtomExcludeMask(AtomExcludeMaskDP):
    pass


register_dpmodel_mapping(
@@ -30,15 +20,9 @@ def __setattr__(self, name: str, value: Any) -> None:
)


class PairExcludeMask(PairExcludeMaskDP, torch.nn.Module):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        torch.nn.Module.__init__(self)
        PairExcludeMaskDP.__init__(self, *args, **kwargs)

    def __setattr__(self, name: str, value: Any) -> None:
        handled, value = dpmodel_setattr(self, name, value)
        if not handled:
            super().__setattr__(name, value)
@torch_module
class PairExcludeMask(PairExcludeMaskDP):
    pass


register_dpmodel_mapping(
22 changes: 7 additions & 15 deletions deepmd/pt_expt/utils/network.py
@@ -21,6 +21,7 @@
from deepmd.pt_expt.common import (
    register_dpmodel_mapping,
    to_torch_array,
    torch_module,
)


@@ -37,14 +38,8 @@ def __array__(self, dtype: Any | None = None) -> np.ndarray:
        return arr.astype(dtype)


class NativeLayer(NativeLayerDP, torch.nn.Module):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        torch.nn.Module.__init__(self)
        NativeLayerDP.__init__(self, *args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return torch.nn.Module.__call__(self, *args, **kwargs)

@torch_module
class NativeLayer(NativeLayerDP):
    def __setattr__(self, name: str, value: Any) -> None:
        if name in {"w", "b", "idt"} and "_parameters" in self.__dict__:
            val = to_torch_array(value)
@@ -78,15 +73,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.call(x)


class NativeNet(make_multilayer_network(NativeLayer, NativeOP), torch.nn.Module):
@torch_module
class NativeNet(make_multilayer_network(NativeLayer, NativeOP)):
    def __init__(self, layers: list[dict] | None = None) -> None:
        torch.nn.Module.__init__(self)
        super().__init__(layers)
        self.layers = torch.nn.ModuleList(self.layers)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return torch.nn.Module.__call__(self, *args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.call(x)

@@ -99,15 +91,15 @@ class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)):
    pass


class NetworkCollection(NetworkCollectionDP, torch.nn.Module):
@torch_module
class NetworkCollection(NetworkCollectionDP):
    NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = {
        "network": NativeNet,
        "embedding_network": EmbeddingNet,
        "fitting_network": FittingNet,
    }

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        torch.nn.Module.__init__(self)
        self._module_networks = torch.nn.ModuleDict()
        super().__init__(*args, **kwargs)
