From b73d490a2ba7b38c6218cf338644d8bce492115b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 10 Feb 2026 05:42:06 +0800 Subject: [PATCH 1/2] refact(pt_expt): add decorator to simplify the module --- deepmd/pt_expt/common.py | 43 ++++++++++++++++++++++++++++ deepmd/pt_expt/descriptor/se_e2_a.py | 19 ++---------- deepmd/pt_expt/descriptor/se_r.py | 19 ++---------- deepmd/pt_expt/utils/exclude_mask.py | 30 +++++-------------- deepmd/pt_expt/utils/network.py | 26 +++++------------ 5 files changed, 61 insertions(+), 76 deletions(-) diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index c7c6cff99b..be01e7771e 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -17,6 +17,9 @@ from collections.abc import ( Callable, ) +from functools import ( + wraps, +) from typing import ( Any, overload, @@ -292,6 +295,46 @@ def to_torch_array(array: Any) -> torch.Tensor | None: return torch.as_tensor(array, device=env.DEVICE) +def torch_module( + module: type[NativeOP], +) -> type[torch.nn.Module]: + """Convert a NativeOP to a torch.nn.Module. + + Parameters + ---------- + module : NativeOP + The NativeOP to convert. + + Returns + ------- + torch.nn.Module + The torch.nn.Module. + + Examples + -------- + >>> @torch_module + ... class MyModule(NativeOP): + ... pass + """ + + @wraps(module, updated=()) + class TorchModule(module, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + module.__init__(self, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + return TorchModule + + # Import utils to trigger dpmodel→pt_expt converter registrations # This must happen after the functions above are defined to avoid circular imports def _ensure_registrations() -> None: diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 1ccb4d2dda..d545dd60e6 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -1,13 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) import torch from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP from deepmd.pt_expt.common import ( - dpmodel_setattr, + torch_module, ) from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, @@ -16,20 +13,8 @@ @BaseDescriptor.register("se_e2_a_expt") @BaseDescriptor.register("se_a_expt") +@torch_module class DescrptSeA(DescrptSeADP, torch.nn.Module): - def __init__(self, *args: Any, **kwargs: Any) -> None: - torch.nn.Module.__init__(self) - DescrptSeADP.__init__(self, *args, **kwargs) - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. - return torch.nn.Module.__call__(self, *args, **kwargs) - - def __setattr__(self, name: str, value: Any) -> None: - handled, value = dpmodel_setattr(self, name, value) - if not handled: - super().__setattr__(name, value) - def forward( self, extended_coord: torch.Tensor, diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index 7a406fb499..ff355072d9 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -1,13 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) import torch from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP from deepmd.pt_expt.common import ( - dpmodel_setattr, + torch_module, ) from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, @@ -16,20 +13,8 @@ @BaseDescriptor.register("se_e2_r_expt") @BaseDescriptor.register("se_r_expt") +@torch_module class DescrptSeR(DescrptSeRDP, torch.nn.Module): - def __init__(self, *args: Any, **kwargs: Any) -> None: - torch.nn.Module.__init__(self) - DescrptSeRDP.__init__(self, *args, **kwargs) - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. - return torch.nn.Module.__call__(self, *args, **kwargs) - - def __setattr__(self, name: str, value: Any) -> None: - handled, value = dpmodel_setattr(self, name, value) - if not handled: - super().__setattr__(name, value) - def forward( self, extended_coord: torch.Tensor, diff --git a/deepmd/pt_expt/utils/exclude_mask.py b/deepmd/pt_expt/utils/exclude_mask.py index 4060b8c446..2df6d9e9a7 100644 --- a/deepmd/pt_expt/utils/exclude_mask.py +++ b/deepmd/pt_expt/utils/exclude_mask.py @@ -1,27 +1,17 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) -import torch from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP from deepmd.pt_expt.common import ( - dpmodel_setattr, register_dpmodel_mapping, + torch_module, ) -class AtomExcludeMask(AtomExcludeMaskDP, torch.nn.Module): - def __init__(self, *args: Any, **kwargs: Any) -> None: - torch.nn.Module.__init__(self) - AtomExcludeMaskDP.__init__(self, *args, **kwargs) - - def __setattr__(self, name: str, value: Any) -> None: - handled, value = dpmodel_setattr(self, name, value) - if not handled: - super().__setattr__(name, value) +@torch_module +class AtomExcludeMask(AtomExcludeMaskDP): + pass register_dpmodel_mapping( @@ -30,15 +20,9 @@ def __setattr__(self, name: str, value: Any) -> None: ) -class PairExcludeMask(PairExcludeMaskDP, torch.nn.Module): - def __init__(self, *args: Any, **kwargs: Any) -> None: - torch.nn.Module.__init__(self) - PairExcludeMaskDP.__init__(self, *args, **kwargs) - - def __setattr__(self, name: str, value: Any) -> None: - handled, value = dpmodel_setattr(self, name, value) - if not handled: - super().__setattr__(name, value) +@torch_module +class PairExcludeMask(PairExcludeMaskDP): + pass register_dpmodel_mapping( diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 84d0024a85..44085d8a25 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -21,6 +21,7 @@ from deepmd.pt_expt.common import ( register_dpmodel_mapping, to_torch_array, + torch_module, ) @@ -37,14 +38,8 @@ def __array__(self, dtype: Any | None = None) -> np.ndarray: return arr.astype(dtype) -class NativeLayer(NativeLayerDP, torch.nn.Module): - def __init__(self, *args: Any, **kwargs: Any) -> None: - torch.nn.Module.__init__(self) - NativeLayerDP.__init__(self, *args, **kwargs) - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - return torch.nn.Module.__call__(self, *args, **kwargs) - +@torch_module +class NativeLayer(NativeLayerDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: val = to_torch_array(value) @@ -78,15 +73,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.call(x) -class NativeNet(make_multilayer_network(NativeLayer, NativeOP), torch.nn.Module): - def __init__(self, layers: list[dict] | None = None) -> None: - torch.nn.Module.__init__(self) - super().__init__(layers) - self.layers = torch.nn.ModuleList(self.layers) - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - return torch.nn.Module.__call__(self, *args, **kwargs) - +@torch_module +class NativeNet(make_multilayer_network(NativeLayer, NativeOP)): def forward(self, x: torch.Tensor) -> torch.Tensor: return self.call(x) @@ -99,7 +87,8 @@ class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): pass -class NetworkCollection(NetworkCollectionDP, torch.nn.Module): +@torch_module +class NetworkCollection(NetworkCollectionDP): NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = { "network": NativeNet, "embedding_network": EmbeddingNet, @@ -107,7 +96,6 @@ class NetworkCollection(NetworkCollectionDP, torch.nn.Module): } def __init__(self, *args: Any, **kwargs: Any) -> None: - torch.nn.Module.__init__(self) self._module_networks = torch.nn.ModuleDict() super().__init__(*args, **kwargs) From d32681a93e1b0e1ea5cb93db0b1f16cd4dc7dbe3 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 10 Feb 2026 05:53:32 +0800 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Jinzhe Zeng --- deepmd/pt_expt/common.py | 4 ++-- deepmd/pt_expt/descriptor/se_e2_a.py | 2 +- deepmd/pt_expt/descriptor/se_r.py | 2 +- deepmd/pt_expt/utils/network.py | 4 ++++ 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index be01e7771e..c7375119e2 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -302,12 +302,12 @@ def torch_module( Parameters ---------- - module : NativeOP + module : type[NativeOP] The NativeOP to convert. Returns ------- - torch.nn.Module + type[torch.nn.Module] The torch.nn.Module. Examples diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index d545dd60e6..09deceb877 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -14,7 +14,7 @@ @BaseDescriptor.register("se_e2_a_expt") @BaseDescriptor.register("se_a_expt") @torch_module -class DescrptSeA(DescrptSeADP, torch.nn.Module): +class DescrptSeA(DescrptSeADP): def forward( self, extended_coord: torch.Tensor, diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index ff355072d9..a6e719e2ea 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -14,7 +14,7 @@ @BaseDescriptor.register("se_e2_r_expt") @BaseDescriptor.register("se_r_expt") @torch_module -class DescrptSeR(DescrptSeRDP, torch.nn.Module): +class DescrptSeR(DescrptSeRDP): def forward( self, extended_coord: torch.Tensor, diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 44085d8a25..9a18f607ff 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -75,6 +75,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @torch_module class NativeNet(make_multilayer_network(NativeLayer, NativeOP)): + def __init__(self, layers: list[dict] | None = None) -> None: + super().__init__(layers) + self.layers = torch.nn.ModuleList(self.layers) + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.call(x)