
Conversation


@njzjz njzjz commented Feb 9, 2026

Summary by CodeRabbit

  • New Features

    • Added a public adapter to expose DP classes as PyTorch modules.
  • Refactor

    • Switched multiple descriptor, network and utility classes to a decorator-based PyTorch integration for consistent parameter/buffer handling.
  • Breaking Changes

    • Several descriptor forward signatures expanded to accept extended topology and embedding inputs (adjust call sites accordingly).

Copilot AI review requested due to automatic review settings February 9, 2026 21:42
@gemini-code-assist

Summary of Changes

Hello @njzjz, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the pt_expt module by introducing a torch_module decorator. This decorator centralizes the boilerplate code required to integrate NativeOP classes with PyTorch's nn.Module functionality. By applying this decorator to various classes, the change significantly streamlines their definitions, making the codebase more maintainable and easier to understand without altering their core behavior.

Highlights

  • New torch_module Decorator: A new decorator, torch_module, has been introduced in deepmd/pt_expt/common.py. This decorator is designed to simplify the process of converting NativeOP classes into torch.nn.Module compatible classes by handling common initialization and method overriding logic.
  • Code Simplification and Boilerplate Reduction: The torch_module decorator abstracts away the repetitive __init__, __call__, and __setattr__ methods that were previously required for torch.nn.Module integration. This significantly reduces code duplication and improves readability.
  • Refactoring of Existing Classes: Several classes across deepmd/pt_expt/descriptor and deepmd/pt_expt/utils (including DescrptSeA, DescrptSeR, AtomExcludeMask, PairExcludeMask, NativeLayer, NativeNet, and NetworkCollection) have been refactored to utilize the new @torch_module decorator, resulting in cleaner and more concise class definitions (a before/after sketch follows this list).
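
For orientation, here is a minimal before/after sketch of the wrapper pattern the decorator replaces. The names (DescrptSeADP, dpmodel_setattr, torch_module) follow the PR description; the exact method bodies in the repository may differ.

# Before: each wrapper repeated the torch.nn.Module plumbing by hand (illustrative)
class DescrptSeA(DescrptSeADP, torch.nn.Module):
    def __init__(self, *args, **kwargs):
        torch.nn.Module.__init__(self)
        DescrptSeADP.__init__(self, *args, **kwargs)

    def __call__(self, *args, **kwargs):
        return torch.nn.Module.__call__(self, *args, **kwargs)

    def __setattr__(self, name, value):
        handled, value = dpmodel_setattr(self, name, value)
        if not handled:
            super().__setattr__(name, value)

# After: the decorator injects the same __init__/__call__/__setattr__ hooks
@torch_module
class DescrptSeA(DescrptSeADP):
    pass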
Changelog
  • deepmd/pt_expt/common.py
    • Added wraps import from functools.
    • Introduced the torch_module decorator, which converts a NativeOP class into a torch.nn.Module by handling __init__, __call__, and __setattr__.
  • deepmd/pt_expt/descriptor/se_e2_a.py
    • Removed Any import from typing.
    • Replaced dpmodel_setattr import with torch_module.
    • Applied the @torch_module decorator to DescrptSeA, removing its explicit __init__, __call__, and __setattr__ methods.
  • deepmd/pt_expt/descriptor/se_r.py
    • Removed Any import from typing.
    • Replaced dpmodel_setattr import with torch_module.
    • Applied the @torch_module decorator to DescrptSeR, removing its explicit __init__, __call__, and __setattr__ methods.
  • deepmd/pt_expt/utils/exclude_mask.py
    • Removed Any import from typing and torch import.
    • Replaced dpmodel_setattr import with torch_module.
    • Applied the @torch_module decorator to AtomExcludeMask and PairExcludeMask, removing their explicit __init__ and __setattr__ methods.
  • deepmd/pt_expt/utils/network.py
    • Added torch_module import from deepmd/pt_expt/common.
    • Applied the @torch_module decorator to NativeLayer, NativeNet, and NetworkCollection, removing their explicit __init__ and __call__ methods where applicable.
Activity
  • No human activity has been recorded on this pull request yet.

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a @torch_module decorator to reduce boilerplate code for wrapping dpmodel classes as torch.nn.Modules. The refactoring is applied to several descriptor and utility classes, which significantly simplifies the codebase. My review focuses on a potential runtime issue in the new decorator and some inconsistencies in its application. Specifically, the decorator unconditionally adds a __call__ method which can cause errors for non-callable modules, and some refactored classes still redundantly inherit from torch.nn.Module.

Comment on lines 298 to 335
def torch_module(
    module: type[NativeOP],
) -> type[torch.nn.Module]:
    """Convert a NativeOP to a torch.nn.Module.

    Parameters
    ----------
    module : NativeOP
        The NativeOP to convert.

    Returns
    -------
    torch.nn.Module
        The torch.nn.Module.

    Examples
    --------
    >>> @torch_module
    ... class MyModule(NativeOP):
    ...     pass
    """

    @wraps(module, updated=())
    class TorchModule(module, torch.nn.Module):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            torch.nn.Module.__init__(self)
            module.__init__(self, *args, **kwargs)

        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
            return torch.nn.Module.__call__(self, *args, **kwargs)

        def __setattr__(self, name: str, value: Any) -> None:
            handled, value = dpmodel_setattr(self, name, value)
            if not handled:
                super().__setattr__(name, value)

    return TorchModule


high

The torch_module decorator is a great idea to reduce boilerplate. However, it has a couple of issues:

  1. The type hint module: type[NativeOP] is too restrictive, as this decorator is also used on classes that do not inherit from NativeOP (e.g., AtomExcludeMask, NetworkCollection). It should be module: type.
  2. The __call__ method is defined unconditionally. This will cause a NotImplementedError at runtime if an instance of a decorated class without a forward method is called. This affects classes like AtomExcludeMask and NetworkCollection which are not meant to be callable modules.

To fix this, __call__ should only be defined if the decorated class has a forward method.

def torch_module(
    module: type,
) -> type[torch.nn.Module]:
    """Convert a NativeOP to a torch.nn.Module.

    Parameters
    ----------
    module : type
        The class to convert.

    Returns
    -------
    torch.nn.Module
        The torch.nn.Module.

    Examples
    --------
    >>> @torch_module
    ... class MyModule(NativeOP):
    ...     pass
    """

    @wraps(module, updated=())
    class TorchModule(module, torch.nn.Module):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            torch.nn.Module.__init__(self)
            module.__init__(self, *args, **kwargs)

        if hasattr(module, "forward"):

            def __call__(self, *args: Any, **kwargs: Any) -> Any:
                # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
                return torch.nn.Module.__call__(self, *args, **kwargs)

        def __setattr__(self, name: str, value: Any) -> None:
            handled, value = dpmodel_setattr(self, name, value)
            if not handled:
                super().__setattr__(name, value)

    return TorchModule

@dosubot dosubot bot added the enhancement label Feb 9, 2026

Copilot AI left a comment


Pull request overview

This PR introduces a @torch_module decorator to reduce boilerplate in pt_expt wrappers that adapt dpmodel classes into torch.nn.Modules for export/tracing.

Changes:

  • Added torch_module decorator in deepmd/pt_expt/common.py to centralize torch.nn.Module initialization, __call__, and __setattr__ handling.
  • Refactored multiple pt_expt wrappers (networks, exclude masks, descriptors) to use @torch_module instead of repeating wrapper logic.
  • Removed now-redundant imports and wrapper boilerplate in affected modules.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
deepmd/pt_expt/utils/network.py Applies @torch_module to network wrappers to remove duplicated torch.nn.Module boilerplate.
deepmd/pt_expt/utils/exclude_mask.py Applies @torch_module to exclude mask wrappers and removes duplicated __setattr__ logic/imports.
deepmd/pt_expt/descriptor/se_r.py Applies @torch_module to descriptor wrapper and removes duplicated boilerplate.
deepmd/pt_expt/descriptor/se_e2_a.py Applies @torch_module to descriptor wrapper and removes duplicated boilerplate.
deepmd/pt_expt/common.py Adds the torch_module decorator implementation.


njzjz added a commit to njzjz/deepmd-kit that referenced this pull request Feb 9, 2026
Use `wraps` to keep the modules' names, so they won't be `FlaxModule`, which cannot be recognized. I realized it when implementing deepmodeling#5213.

coderabbitai bot commented Feb 9, 2026

📝 Walkthrough

Walkthrough

Adds a new public API torch_module in deepmd/pt_expt/common.py that converts NativeOP-based classes into torch.nn.Module subclasses; refactors multiple classes to use @torch_module instead of manual torch.nn.Module multiple inheritance, explicit init/call/setattr boilerplate, and dpmodel_setattr wiring.

Changes

Cohort / File(s) Summary
Core Decorator Utility
deepmd/pt_expt/common.py
Add torch_module(module: type[NativeOP]) -> type[torch.nn.Module] and internal TorchModule implementation that composes NativeOP with torch.nn.Module, routes __call__ to ensure forward runs for export/tracing, and delegates __setattr__ to dpmodel_setattr with fallback.
Descriptor Classes
deepmd/pt_expt/descriptor/se_e2_a.py, deepmd/pt_expt/descriptor/se_r.py
Replace dpmodel_setattr usage with @torch_module decorator; remove explicit __init__, __call__, __setattr__; expand forward signatures to accept extended_coord, extended_atype, nlist, and optional extended_atype_embd, mapping, type_embedding, and return a 5-tuple including sw (a sketch of the expanded signature follows this table).
Exclude Mask Utilities
deepmd/pt_expt/utils/exclude_mask.py
Apply @torch_module to AtomExcludeMask and PairExcludeMask; drop explicit torch.nn.Module base, custom __init__ and __setattr__, simplify classes to minimal definitions delegating behavior to decorator.
Network Classes
deepmd/pt_expt/utils/network.py
Apply @torch_module to NativeLayer, NativeNet, and NetworkCollection; remove explicit torch.nn.Module initialization and __call__ wrappers; rely on decorator to register parameters/buffers and manage module lifecycle while preserving forward-call delegation.
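
A rough sketch of the expanded descriptor forward signature summarized in the table above; the parameter names come from the summary, while the Optional defaults and the exact contents of the returned 5-tuple are assumptions.

from typing import Optional

import torch

class DescrptSeA:  # base classes and the @torch_module decorator omitted for brevity
    def forward(
        self,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        nlist: torch.Tensor,
        extended_atype_embd: Optional[torch.Tensor] = None,
        mapping: Optional[torch.Tensor] = None,
        type_embedding: Optional[torch.Tensor] = None,
    ):
        # Delegates to the dpmodel call(); per the summary the result is a
        # 5-tuple that includes sw (the remaining elements are assumed here).
        ...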

Sequence Diagram(s)

sequenceDiagram
    participant Caller as User code / Caller
    participant Decorator as torch_module (factory)
    participant DPClass as NativeOP / DP base
    participant Torch as torch.nn.Module
    participant dpattr as dpmodel_setattr

    Note over Decorator,DPClass: Decorator builds combined class (TorchModule)
    Caller->>Decorator: apply `@torch_module` to DPClass
    Decorator->>DPClass: create subclass combining DPClass + Torch
    Decorator->>dpattr: integrate attribute hook
    Caller->>DPClass: instantiate combined module
    DPClass->>Torch: initialize torch.nn.Module base
    DPClass->>DPClass: initialize DP base
    Caller->>DPClass: call(module)(inputs)
    DPClass->>Torch: route __call__ -> forward
    Torch->>DPClass: forward runs (may call DPClass.call)
    DPClass->>dpattr: attribute sets delegate to dpmodel_setattr (fallback to setattr)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~40 minutes

Possibly related PRs

Suggested reviewers

  • wanghan-iapcm
  • iProzd
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 21.43% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main change: introducing a new @torch_module decorator to simplify module implementation across multiple files by replacing boilerplate code.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pt_expt/utils/network.py (1)

41-73: ⚠️ Potential issue | 🔴 Critical

The decorator's dpmodel_setattr intercepts numpy arrays before NativeLayer.__setattr__ can apply trainable-aware logic, preventing trainable parameters from being created with gradients.

During NativeLayerDP.__init__, when self.w = np.array(...) is assigned:

  1. TorchModule.__setattr__ is invoked (via decorator)
  2. dpmodel_setattr sees the numpy array and _buffers exists, so it registers w as a buffer and returns handled=True (line 213 of common.py)
  3. Because handled=True, the decorator's __setattr__ skips super().__setattr__() and never reaches NativeLayer.__setattr__
  4. NativeLayer.__setattr__ (lines 43–70) with its special trainable-aware logic (lines 54–63) is completely bypassed

Result: When trainable=True, w/b/idt become buffers without gradients instead of TorchArrayParam parameters with requires_grad=True, breaking gradient computation for trainable layers.
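
A stripped-down reproduction of this interception, using stand-in classes rather than the actual deepmd ones: once the most-derived __setattr__ reports a value as handled, the intermediate class's trainable-aware __setattr__ never runs.

import numpy as np

class Layer:
    # Stand-in for the pt_expt NativeLayer wrapper with trainable-aware __setattr__.
    def __setattr__(self, name, value):
        print("Layer.__setattr__ saw", name)  # trainable handling would live here
        super().__setattr__(name, value)

class Decorated(Layer):
    # Stand-in for the class produced by the torch_module decorator.
    def __setattr__(self, name, value):
        if isinstance(value, np.ndarray):
            # Mimics dpmodel_setattr returning handled=True: register and stop.
            object.__setattr__(self, name, value)
            return
        super().__setattr__(name, value)

obj = Decorated()
obj.w = np.zeros(3)  # Layer.__setattr__ never runs -> trainable branch bypassed
obj.note = "meta"    # non-array values still fall through to Layer.__setattr__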

🤖 Fix all issues with AI agents
In `@deepmd/pt_expt/common.py`:
- Around line 298-335: The decorator torch_module currently forces
torch.nn.Module.__call__ and thus requires a forward() method, making decorated
classes without forward() (e.g., AtomExcludeMask, PairExcludeMask,
NetworkCollection) non-functional; modify the TorchModule produced by
torch_module to define a default forward(self, *args, **kwargs) that, if the
instance has a call method (i.e., is NativeOP-like), delegates to
self.call(*args, **kwargs), otherwise raises the existing NotImplementedError or
leaves behavior unchanged; adjust the decorator logic to detect NativeOP via
isinstance(self, NativeOP) or hasattr(self, "call") and ensure dpmodel_setattr
and other overrides remain intact, and for purely utility classes remove the
`@torch_module` usage or add explicit forward implementations.
🧹 Nitpick comments (3)
deepmd/pt_expt/common.py (1)

320-321: Redundant torch.nn.Module in consumer class bases when using @torch_module.

Several consumer classes (e.g., DescrptSeA(DescrptSeADP, torch.nn.Module)) list torch.nn.Module explicitly in their bases AND apply @torch_module, which also injects torch.nn.Module as a base of the generated TorchModule. Python's C3 linearization handles the duplicate safely, but it's misleading — it suggests the class is already a torch.nn.Module without the decorator.

Consider documenting this expectation in the docstring (that torch.nn.Module in the original class bases is acceptable/expected), or note in the consumer files that the explicit base is optional with @torch_module.
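
A tiny standalone check of the C3 claim, with a simplified stand-in for the decorator: listing torch.nn.Module both in the original bases and via the wrapper still linearizes cleanly, with the Module base appearing only once in the MRO.

import torch

def torch_module_like(cls):
    # Simplified stand-in: adds torch.nn.Module as an extra base, like the decorator.
    class Wrapped(cls, torch.nn.Module):
        pass
    return Wrapped

@torch_module_like
class Redundant(torch.nn.Module):  # explicit base duplicated by the wrapper
    pass

print([c.__name__ for c in Redundant.__mro__])
# ['Wrapped', 'Redundant', 'Module', 'object'] -- no duplicate base, no MRO error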

deepmd/pt_expt/descriptor/se_e2_a.py (1)

14-17: Redundant torch.nn.Module in class bases.

@torch_module already injects torch.nn.Module into the generated class hierarchy. The explicit torch.nn.Module in class DescrptSeA(DescrptSeADP, torch.nn.Module) is harmless (C3 deduplicates it) but misleading — it suggests the class is a Module independent of the decorator.

Optional cleanup
 @BaseDescriptor.register("se_e2_a_expt")
 @BaseDescriptor.register("se_a_expt")
 @torch_module
-class DescrptSeA(DescrptSeADP, torch.nn.Module):
+class DescrptSeA(DescrptSeADP):
deepmd/pt_expt/descriptor/se_r.py (1)

14-17: Same redundant torch.nn.Module base as se_e2_a.py.

Optional cleanup
 @BaseDescriptor.register("se_e2_r_expt")
 @BaseDescriptor.register("se_r_expt")
 @torch_module
-class DescrptSeR(DescrptSeRDP, torch.nn.Module):
+class DescrptSeR(DescrptSeRDP):

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <njzjz@qq.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pt_expt/utils/network.py (1)

41-73: ⚠️ Potential issue | 🔴 Critical

Critical: dpmodel_setattr intercepts w/b/idt numpy arrays during initialization, registering them as non-trainable buffers and bypassing NativeLayer.__setattr__'s trainable check.

The decorator's TorchModule.__setattr__ (common.py:330-333) calls dpmodel_setattr first, which detects numpy arrays (common.py:207) and registers them as buffers before NativeLayer.__setattr__ (line 43) can execute. For trainable layers:

  1. During NativeLayerDP.__init__, self.w = numpy_array triggers TorchModule.__setattr__
  2. dpmodel_setattr sees isinstance(value, np.ndarray) and _buffers in __dict__ → calls register_buffer → returns (True, tensor)
  3. handled=True causes TorchModule.__setattr__ to skip super().__setattr__(), so NativeLayer.__setattr__ never runs
  4. The trainable check on line 54 is unreachable
  5. Weights end up as non-trainable buffers instead of TorchArrayParam parameters

The existing test only verifies the None-clearing case (setting w=None after init), not initial registration, so this bug is undetected. Trainable layers will not have gradient flow through their weights.

🤖 Fix all issues with AI agents
In `@deepmd/pt_expt/common.py`:
- Around line 320-324: The TorchModule wrapper currently calls
torch.nn.Module.__init__ and then module.__init__, causing
torch.nn.Module.__init__ to run twice via MRO when subclasses (e.g.,
NetworkCollectionDP / NetworkCollection.__init__) call super().__init__(), which
resets internal registries; fix by ensuring torch.nn.Module.__init__ is invoked
exactly once in the wrapper (either call torch.nn.Module.__init__(self) before
calling module.__init__(self, ...) or omit the explicit call and let
module.__init__ call it), updating the TorchModule class definition accordingly
so that module.__init__ is not followed by a second torch.nn.Module.__init__
invocation.
- Around line 298-335: The decorator torch_module currently types the input as
type[NativeOP] but is used on classes like AtomExcludeMask and PairExcludeMask
that do not inherit NativeOP and also lack a forward() implementation required
by torch.nn.Module.__call__; update torch_module to accept any class (e.g.,
type[Any]) or tighten callers to only pass NativeOP subclasses, and ensure
classes AtomExcludeMask and PairExcludeMask either inherit NativeOP or implement
a forward(self, *args, **kwargs) method that raises NotImplementedError or
provides appropriate behavior; locate torch_module, AtomExcludeMask,
PairExcludeMask, NativeOP, and dpmodel_setattr in the diff to make the change
and keep the existing initialization order and dpmodel_setattr handling.

Comment on lines +298 to +335
def torch_module(
    module: type[NativeOP],
) -> type[torch.nn.Module]:
    """Convert a NativeOP to a torch.nn.Module.

    Parameters
    ----------
    module : type[NativeOP]
        The NativeOP to convert.

    Returns
    -------
    type[torch.nn.Module]
        The torch.nn.Module.

    Examples
    --------
    >>> @torch_module
    ... class MyModule(NativeOP):
    ...     pass
    """

    @wraps(module, updated=())
    class TorchModule(module, torch.nn.Module):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            torch.nn.Module.__init__(self)
            module.__init__(self, *args, **kwargs)

        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
            return torch.nn.Module.__call__(self, *args, **kwargs)

        def __setattr__(self, name: str, value: Any) -> None:
            handled, value = dpmodel_setattr(self, name, value)
            if not handled:
                super().__setattr__(name, value)

    return TorchModule

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if there are any exact-type checks (type(x) is SomeClass) for decorated classes
echo "=== Check for exact type checks on decorated classes ==="
rg -nP 'type\(.+\)\s+is\s+(DescrptSeA|DescrptSeR|NativeLayer|NativeNet|NetworkCollection|AtomExcludeMask|PairExcludeMask)\b' --type=py
echo ""
echo "=== Also check for __class__ comparisons ==="
rg -nP '__class__\s*==\s+(DescrptSeA|DescrptSeR|NativeLayer|NativeNet|NetworkCollection)' --type=py

Repository: deepmodeling/deepmd-kit

Length of output: 169


🏁 Script executed:

#!/bin/bash
# Find where torch_module decorator is used
echo "=== Uses of `@torch_module` decorator ==="
rg -n '@torch_module' --type=py -A 2

Repository: deepmodeling/deepmd-kit

Length of output: 1692


🏁 Script executed:

#!/bin/bash
# Check if decorated classes define forward() method
echo "=== Checking if decorated classes define forward() ==="
rg -n 'def forward\(' deepmd/pt_expt/descriptor/ deepmd/pt_expt/fitting/ deepmd/pt_expt/network.py --type=py -B 5 | head -100

Repository: deepmodeling/deepmd-kit

Length of output: 1019


🏁 Script executed:

#!/bin/bash
# Verify the import of wraps and the decorator implementation
echo "=== Verify functools import and wraps usage ==="
head -30 deepmd/pt_expt/common.py | rg -n 'wraps|functools'
echo ""
echo "=== Full torch_module decorator implementation ==="
sed -n '298,335p' deepmd/pt_expt/common.py

Repository: deepmodeling/deepmd-kit

Length of output: 1287


🏁 Script executed:

#!/bin/bash
# Check if AtomExcludeMask and PairExcludeMask define forward() method
echo "=== Check AtomExcludeMask and PairExcludeMask in exclude_mask.py ==="
cat deepmd/pt_expt/utils/exclude_mask.py

Repository: deepmodeling/deepmd-kit

Length of output: 842


🏁 Script executed:

#!/bin/bash
# Also check if the parent classes define forward() or if it's optional
echo "=== Check parent classes of exclude mask classes ==="
rg -n 'class AtomExcludeMaskDP|class PairExcludeMaskDP' --type=py -A 10

Repository: deepmodeling/deepmd-kit

Length of output: 119


🏁 Script executed:

#!/bin/bash
# Check what methods the parent dpmodel classes might define
echo "=== Checking NativeLayer and NativeNet more thoroughly ==="
sed -n '41,73p' deepmd/pt_expt/utils/network.py
echo ""
sed -n '76,92p' deepmd/pt_expt/utils/network.py

Repository: deepmodeling/deepmd-kit

Length of output: 1967


🏁 Script executed:

#!/bin/bash
# Check if NativeNet and NetworkCollection have forward() method defined
echo "=== Checking if NativeNet has forward() ==="
rg -A 50 'class NativeNet' deepmd/pt_expt/utils/network.py | head -60

Repository: deepmodeling/deepmd-kit

Length of output: 1531


🏁 Script executed:

#!/bin/bash
# Check if NetworkCollection has forward() method defined
echo "=== Checking if NetworkCollection has forward() ==="
rg -A 50 'class NetworkCollection' deepmd/pt_expt/utils/network.py | head -60

Repository: deepmodeling/deepmd-kit

Length of output: 1051


🏁 Script executed:

#!/bin/bash
# Check if forward() is called anywhere on decorated classes
echo "=== Check for forward() calls on NetworkCollection ==="
rg -n 'NetworkCollection.*\.forward\(' --type=py
echo ""
echo "=== Check if NetworkCollection is called directly ==="
rg -n '\([^)]*NetworkCollection[^)]*\)\(' --type=py | head -20

Repository: deepmodeling/deepmd-kit

Length of output: 176


🏁 Script executed:

#!/bin/bash
# Check parent classes for forward() method to see if it's inherited
echo "=== Check dpmodel parent classes for forward() ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '.')
try:
    from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
    from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
    print("AtomExcludeMaskDP has forward():", hasattr(AtomExcludeMaskDP, 'forward'))
    print("PairExcludeMaskDP has forward():", hasattr(PairExcludeMaskDP, 'forward'))
except Exception as e:
    print(f"Error: {e}")
EOF

Repository: deepmodeling/deepmd-kit

Length of output: 159


🏁 Script executed:

#!/bin/bash
# Check if NetworkCollection or its parent is actually called in a way that needs forward()
echo "=== Check usage of decorated classes ==="
rg -n 'NetworkCollection\(' --type=py | head -10
echo ""
echo "=== Check if these are module outputs used in torch.nn.Module context ==="
rg -n '\.forward\(|__call__\(' deepmd/pt_expt/utils/exclude_mask.py deepmd/pt_expt/utils/network.py | head -20

Repository: deepmodeling/deepmd-kit

Length of output: 1056


🏁 Script executed:

#!/bin/bash
# Check if the parent dpmodel classes are NativeOP or have a call() method
echo "=== Check NativeOP definition and call() method ==="
rg -n 'class NativeOP|def call\(' deepmd/dpmodel/common.py | head -30
echo ""
echo "=== Look for NativeOP imports ==="
rg -n 'from.*NativeOP|import.*NativeOP' deepmd/pt_expt/common.py

Repository: deepmodeling/deepmd-kit

Length of output: 238


🏁 Script executed:

#!/bin/bash
# Check the exclude_mask parent classes to see if they inherit from NativeOP
echo "=== Check parent classes of AtomExcludeMask ==="
rg -B5 'class AtomExcludeMask' deepmd/dpmodel/utils/exclude_mask.py 2>/dev/null || echo "File not found in dpmodel"
echo ""
echo "=== Check if AtomExcludeMask/PairExcludeMask are actually subclasses of NativeOP ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '.')
try:
    # Check what NativeOP is
    from deepmd.pt_expt.common import NativeOP
    print(f"NativeOP: {NativeOP}")
    print(f"NativeOP is: {type(NativeOP)}")
except ImportError as e:
    print(f"Cannot import: {e}")
EOF

Repository: deepmodeling/deepmd-kit

Length of output: 312


🏁 Script executed:

#!/bin/bash
# Check if NetworkCollection, AtomExcludeMask, PairExcludeMask might be called in a torch context
echo "=== Check if NetworkCollection instances are used with __call__ ==="
rg -n 'networks\(|collection\(' deepmd/pt_expt/ --type=py | head -20
echo ""
echo "=== Check test files to see how decorated classes are used ==="
find deepmd/pt_expt -name '*test*.py' -type f | head -5

Repository: deepmodeling/deepmd-kit

Length of output: 198


🏁 Script executed:

#!/bin/bash
# Check the dpmodel parent classes more carefully
echo "=== Check AtomExcludeMaskDP and PairExcludeMaskDP inheritance ==="
head -50 deepmd/dpmodel/utils/exclude_mask.py

Repository: deepmodeling/deepmd-kit

Length of output: 1343


🏁 Script executed:

#!/bin/bash
# Check if NetworkCollectionDP or its parents define forward() or call()
echo "=== Check NetworkCollectionDP definition ==="
rg -A 20 'class NetworkCollectionDP' deepmd/dpmodel/utils/network.py | head -30
echo ""
echo "=== Check if these parent classes have call() method ==="
rg -n 'def call\(' deepmd/dpmodel/utils/

Repository: deepmodeling/deepmd-kit

Length of output: 676


🏁 Script executed:

#!/bin/bash
# Check if the parent classes of decorated classes actually inherit from NativeOP
echo "=== Check what AtomExcludeMaskDP inherits from ==="
grep -A 1 'class AtomExcludeMask' deepmd/dpmodel/utils/exclude_mask.py | head -3
echo ""
echo "=== Check NetworkCollectionDP inheritance ==="
grep 'class NetworkCollectionDP' deepmd/dpmodel/utils/network.py
echo ""
echo "=== Verify NativeOP definition to understand what classes should inherit from it ==="
sed -n '93,120p' deepmd/dpmodel/common.py

Repository: deepmodeling/deepmd-kit

Length of output: 1003


🏁 Script executed:

#!/bin/bash
# Check if classes that don't define forward() are actually expected to be called via torch
echo "=== Check imports and usage context for AtomExcludeMask ==="
rg -n 'from.*exclude_mask import|import.*exclude_mask' deepmd/pt_expt/ --type=py
echo ""
echo "=== Check how NetworkCollection is instantiated ==="
grep -r 'NetworkCollection(' deepmd/pt_expt/ --include='*.py' | head -5

Repository: deepmodeling/deepmd-kit

Length of output: 565


🏁 Script executed:

#!/bin/bash
# Check if there are any tests that use the decorated exclude_mask classes
echo "=== Check exclude_mask test files ==="
find . -path './deepmd/pt_expt*' -name '*exclude_mask*' -o -path './deepmd/pt_expt*' -name '*test*.py' | grep -i exclude

Repository: deepmodeling/deepmd-kit

Length of output: 143


🏁 Script executed:

#!/bin/bash
# Search for NetworkCollectionDP more broadly
echo "=== Search for NetworkCollectionDP ==="
rg -n 'NetworkCollectionDP' --type=py
echo ""
echo "=== Check the full network.py file for NetworkCollectionDP ==="
cat deepmd/dpmodel/utils/network.py | grep -n 'class.*Collection'

Repository: deepmodeling/deepmd-kit

Length of output: 1331


🏁 Script executed:

#!/bin/bash
# Check if these parent classes (AtomExcludeMaskDP, PairExcludeMaskDP) inherit from NativeOP
echo "=== Full AtomExcludeMask and PairExcludeMask definitions ==="
cat deepmd/dpmodel/utils/exclude_mask.py

Repository: deepmodeling/deepmd-kit

Length of output: 4547


Type annotation allows non-NativeOP classes, and AtomExcludeMask/PairExcludeMask lack forward() method required by torch.nn.Module.

The decorator's __call__ method routes through torch.nn.Module.__call__, which invokes forward(). However:

  1. Type mismatch: The parameter is annotated module: type[NativeOP], but AtomExcludeMask and PairExcludeMask inherit from plain dpmodel classes that do not inherit from NativeOP. This violates the documented contract.

  2. Missing forward() implementation: While DescrptSeA, DescrptSeR, NativeLayer, and NativeNet all define forward(), neither AtomExcludeMask nor PairExcludeMask do. Calling these instances via __call__ will fail with NotImplementedError or AttributeError.

The initialization order (torch.nn.Module first, then dpmodel) is correct and ensures _parameters/_buffers/_modules dicts exist before dpmodel_setattr runs. The @wraps(module, updated=()) properly preserves metadata. No exact type checks (type(x) is ClassName) were found in the codebase, so the isinstance/inheritance trade-off is safe.
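
A minimal check of the failure mode described here, with a dummy class in place of AtomExcludeMask: routing __call__ through torch.nn.Module on a class with no forward() raises as soon as the instance is called.

import torch

class Plain:
    # Stand-in for a dpmodel utility class that has no forward().
    def __init__(self, n):
        self.n = n

class Adapted(Plain, torch.nn.Module):
    def __init__(self, *args, **kwargs):
        torch.nn.Module.__init__(self)
        Plain.__init__(self, *args, **kwargs)

    def __call__(self, *args, **kwargs):
        return torch.nn.Module.__call__(self, *args, **kwargs)

obj = Adapted(3)   # construction and attribute access still work
try:
    obj()          # no forward() anywhere in the MRO
except NotImplementedError as exc:
    print("calling the adapted instance fails:", exc)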

🤖 Prompt for AI Agents
In `@deepmd/pt_expt/common.py` around lines 298 - 335, The decorator torch_module
currently types the input as type[NativeOP] but is used on classes like
AtomExcludeMask and PairExcludeMask that do not inherit NativeOP and also lack a
forward() implementation required by torch.nn.Module.__call__; update
torch_module to accept any class (e.g., type[Any]) or tighten callers to only
pass NativeOP subclasses, and ensure classes AtomExcludeMask and PairExcludeMask
either inherit NativeOP or implement a forward(self, *args, **kwargs) method
that raises NotImplementedError or provides appropriate behavior; locate
torch_module, AtomExcludeMask, PairExcludeMask, NativeOP, and dpmodel_setattr in
the diff to make the change and keep the existing initialization order and
dpmodel_setattr handling.

Comment on lines +320 to +324
    @wraps(module, updated=())
    class TorchModule(module, torch.nn.Module):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            torch.nn.Module.__init__(self)
            module.__init__(self, *args, **kwargs)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, let's check the context around lines 320-324 in deepmd/pt_expt/common.py
echo "=== Context around lines 320-324 in common.py ==="
cat -n deepmd/pt_expt/common.py | sed -n '315,330p'

Repository: deepmodeling/deepmd-kit

Length of output: 822


🏁 Script executed:

#!/bin/bash
# Check if NetworkCollection or decorated classes call super().__init__()
echo "=== Check NetworkCollection and related classes ==="
find deepmd/pt_expt -name "*.py" | head -20

Repository: deepmodeling/deepmd-kit

Length of output: 454


🏁 Script executed:

#!/bin/bash
# Find the network.py file if it exists
echo "=== Looking for network-related files ==="
fd -t f "network" deepmd/pt_expt/

Repository: deepmodeling/deepmd-kit

Length of output: 140


🏁 Script executed:

#!/bin/bash
# Check NetworkCollection class and its __init__ method
echo "=== Looking for NetworkCollection class ==="
rg -n "class NetworkCollection" deepmd/pt_expt/utils/network.py -A 30

Repository: deepmodeling/deepmd-kit

Length of output: 1158


🏁 Script executed:

#!/bin/bash
# Check for any super().__init__() calls in the decorated classes
echo "=== Check for super().__init__() calls in pt_expt modules ==="
rg -n "super\(\)\.__init__\(\)" deepmd/pt_expt/ -B 5 -A 5

Repository: deepmodeling/deepmd-kit

Length of output: 128


🏁 Script executed:

#!/bin/bash
# Let's inspect torch.nn.Module.__init__ to see what it does
python3 << 'EOF'
import torch
import inspect

# Get the source code of torch.nn.Module.__init__
try:
    source = inspect.getsource(torch.nn.Module.__init__)
    print("=== torch.nn.Module.__init__ source ===")
    print(source[:1000])  # First 1000 chars
except:
    print("Could not get source, checking what it sets")

# Create a test instance and see what attributes are set
print("\n=== Testing torch.nn.Module.__init__ behavior ===")
m = torch.nn.Module()
print(f"_parameters: {type(m._parameters)}, content: {m._parameters}")
print(f"_buffers: {type(m._buffers)}, content: {m._buffers}")
print(f"_modules: {type(m._modules)}, content: {m._modules}")

# Now test calling __init__ twice
print("\n=== Testing double __init__ call ===")
m2 = torch.nn.Module()
m2.custom_attr = "test"
print(f"Before second __init__: custom_attr = {m2.custom_attr}")
torch.nn.Module.__init__(m2)
print(f"After second __init__: custom_attr = {m2.custom_attr}")
print(f"_parameters after second __init__: {m2._parameters}")
EOF

Repository: deepmodeling/deepmd-kit

Length of output: 184


🏁 Script executed:

#!/bin/bash
# Find NetworkCollectionDP
echo "=== Looking for NetworkCollectionDP class ==="
rg -n "class NetworkCollectionDP" deepmd/ -A 15

Repository: deepmodeling/deepmd-kit

Length of output: 112


🏁 Script executed:

#!/bin/bash
# Let's also check imports to understand the base classes
echo "=== Check imports in network.py ==="
head -50 deepmd/pt_expt/utils/network.py

Repository: deepmodeling/deepmd-kit

Length of output: 1566


🌐 Web query:

torch.nn.Module.__init__ behavior what does it initialize reset state

💡 Result:

torch.nn.Module.__init__() does not initialize model weights. It initializes the bookkeeping/state containers that make PyTorch “see” your parameters, buffers, and submodules.

Concretely, Module.__init__() sets up (at least):

  • self.training = True (default training mode)
  • empty registries like self._parameters, self._buffers, self._modules (and hook/state_dict hook dicts, plus _non_persistent_buffers_set) so that later assignments like self.linear = nn.Linear(...) get registered properly. [1]

Implications:

  • You must call super().__init__() before assigning submodules/parameters/buffers, otherwise registration fails (PyTorch explicitly documents this requirement). [2]
  • Calling Module.__init__() again on an existing module effectively resets those registries (you can “lose” previously registered submodules/parameters from the module’s tracking), but it doesn’t “reinitialize weights” in the usual layer-specific sense. [1][2]
  • Weight initialization happens in the specific layer constructors (e.g., nn.Linear.__init__), not in nn.Module.__init__. [3]

Sources:
[1] https://glaringlee.github.io/_modules/torch/nn/modules/module.html
[2] https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html
[3] https://docs.pytorch.org/docs/stable/notes/modules.html


Address double torch.nn.Module.__init__ calls in the decorator pattern.

The decorator explicitly calls torch.nn.Module.__init__ on line 323, then calls module.__init__ on line 324. When NetworkCollectionDP.__init__ (and classes that inherit from it) call super().__init__() at the end, the MRO routes this back to torch.nn.Module.__init__, resulting in a second call. This resets _parameters, _buffers, and _modules dicts, causing any state registered in the decorated class's __init__ before the super() call to be lost.

For example, in NetworkCollection.__init__ (line 104), the super().__init__() call after setting _module_networks will reset those registries. While _module_networks itself is not a registered parameter/buffer/submodule, the pattern is fragile: any future logic that registers state before calling super().__init__() will have it silently discarded. Consider calling torch.nn.Module.__init__ only once, either before or after module.__init__, rather than both.
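
A small demonstration of the reset behavior, independent of the deepmd classes: a second torch.nn.Module.__init__ call wipes the registries populated by the first, which is exactly what a stray super().__init__() routed back through the MRO would do.

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(2, 2)  # registered in self._modules

net = Net()
print(list(net._modules))      # ['lin']
torch.nn.Module.__init__(net)  # simulates the duplicated base-class init
print(list(net._modules))      # [] -- the previously registered submodule is gone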

🤖 Prompt for AI Agents
In `@deepmd/pt_expt/common.py` around lines 320 - 324, The TorchModule wrapper
currently calls torch.nn.Module.__init__ and then module.__init__, causing
torch.nn.Module.__init__ to run twice via MRO when subclasses (e.g.,
NetworkCollectionDP / NetworkCollection.__init__) call super().__init__(), which
resets internal registries; fix by ensuring torch.nn.Module.__init__ is invoked
exactly once in the wrapper (either call torch.nn.Module.__init__(self) before
calling module.__init__(self, ...) or omit the explicit call and let
module.__init__ call it), updating the TorchModule class definition accordingly
so that module.__init__ is not followed by a second torch.nn.Module.__init__
invocation.


codecov bot commented Feb 9, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 81.97%. Comparing base (97d8ded) to head (d32681a).

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5213      +/-   ##
==========================================
- Coverage   82.00%   81.97%   -0.03%     
==========================================
  Files         724      724              
  Lines       73801    73779      -22     
  Branches     3616     3615       -1     
==========================================
- Hits        60520    60484      -36     
- Misses      12120    12132      +12     
- Partials     1161     1163       +2     

☔ View full report in Codecov by Sentry.

@njzjz njzjz requested a review from wanghan-iapcm February 10, 2026 02:58