diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py
index c7c6cff99b..c7375119e2 100644
--- a/deepmd/pt_expt/common.py
+++ b/deepmd/pt_expt/common.py
@@ -17,6 +17,9 @@
 from collections.abc import (
     Callable,
 )
+from functools import (
+    wraps,
+)
 from typing import (
     Any,
     overload,
@@ -292,6 +295,46 @@ def to_torch_array(array: Any) -> torch.Tensor | None:
     return torch.as_tensor(array, device=env.DEVICE)
 
 
+def torch_module(
+    module: type[NativeOP],
+) -> type[torch.nn.Module]:
+    """Convert a NativeOP to a torch.nn.Module.
+
+    Parameters
+    ----------
+    module : type[NativeOP]
+        The NativeOP to convert.
+
+    Returns
+    -------
+    type[torch.nn.Module]
+        The torch.nn.Module.
+
+    Examples
+    --------
+    >>> @torch_module
+    ... class MyModule(NativeOP):
+    ...     pass
+    """
+
+    @wraps(module, updated=())
+    class TorchModule(module, torch.nn.Module):
+        def __init__(self, *args: Any, **kwargs: Any) -> None:
+            torch.nn.Module.__init__(self)
+            module.__init__(self, *args, **kwargs)
+
+        def __call__(self, *args: Any, **kwargs: Any) -> Any:
+            # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
+            return torch.nn.Module.__call__(self, *args, **kwargs)
+
+        def __setattr__(self, name: str, value: Any) -> None:
+            handled, value = dpmodel_setattr(self, name, value)
+            if not handled:
+                super().__setattr__(name, value)
+
+    return TorchModule
+
+
 # Import utils to trigger dpmodel→pt_expt converter registrations
 # This must happen after the functions above are defined to avoid circular imports
 def _ensure_registrations() -> None:
diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py
index 1ccb4d2dda..09deceb877 100644
--- a/deepmd/pt_expt/descriptor/se_e2_a.py
+++ b/deepmd/pt_expt/descriptor/se_e2_a.py
@@ -1,13 +1,10 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-from typing import (
-    Any,
-)
 
 import torch
 
 from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
 from deepmd.pt_expt.common import (
-    dpmodel_setattr,
+    torch_module,
 )
 from deepmd.pt_expt.descriptor.base_descriptor import (
     BaseDescriptor,
@@ -16,20 +13,8 @@
 
 @BaseDescriptor.register("se_e2_a_expt")
 @BaseDescriptor.register("se_a_expt")
-class DescrptSeA(DescrptSeADP, torch.nn.Module):
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        torch.nn.Module.__init__(self)
-        DescrptSeADP.__init__(self, *args, **kwargs)
-
-    def __call__(self, *args: Any, **kwargs: Any) -> Any:
-        # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
-        return torch.nn.Module.__call__(self, *args, **kwargs)
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        handled, value = dpmodel_setattr(self, name, value)
-        if not handled:
-            super().__setattr__(name, value)
-
+@torch_module
+class DescrptSeA(DescrptSeADP):
     def forward(
         self,
         extended_coord: torch.Tensor,
diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py
index 7a406fb499..a6e719e2ea 100644
--- a/deepmd/pt_expt/descriptor/se_r.py
+++ b/deepmd/pt_expt/descriptor/se_r.py
@@ -1,13 +1,10 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-from typing import (
-    Any,
-)
 
 import torch
 
 from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP
 from deepmd.pt_expt.common import (
-    dpmodel_setattr,
+    torch_module,
 )
 from deepmd.pt_expt.descriptor.base_descriptor import (
     BaseDescriptor,
@@ -16,20 +13,8 @@
 
 @BaseDescriptor.register("se_e2_r_expt")
 @BaseDescriptor.register("se_r_expt")
-class DescrptSeR(DescrptSeRDP, torch.nn.Module):
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        torch.nn.Module.__init__(self)
-        DescrptSeRDP.__init__(self, *args, **kwargs)
-
-    def __call__(self, *args: Any, **kwargs: Any) -> Any:
-        # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
-        return torch.nn.Module.__call__(self, *args, **kwargs)
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        handled, value = dpmodel_setattr(self, name, value)
-        if not handled:
-            super().__setattr__(name, value)
-
+@torch_module
+class DescrptSeR(DescrptSeRDP):
     def forward(
         self,
         extended_coord: torch.Tensor,
diff --git a/deepmd/pt_expt/utils/exclude_mask.py b/deepmd/pt_expt/utils/exclude_mask.py
index 4060b8c446..2df6d9e9a7 100644
--- a/deepmd/pt_expt/utils/exclude_mask.py
+++ b/deepmd/pt_expt/utils/exclude_mask.py
@@ -1,27 +1,17 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-from typing import (
-    Any,
-)
 
-import torch
 
 from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
 from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
 from deepmd.pt_expt.common import (
-    dpmodel_setattr,
     register_dpmodel_mapping,
+    torch_module,
 )
 
 
-class AtomExcludeMask(AtomExcludeMaskDP, torch.nn.Module):
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        torch.nn.Module.__init__(self)
-        AtomExcludeMaskDP.__init__(self, *args, **kwargs)
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        handled, value = dpmodel_setattr(self, name, value)
-        if not handled:
-            super().__setattr__(name, value)
+@torch_module
+class AtomExcludeMask(AtomExcludeMaskDP):
+    pass
 
 
 register_dpmodel_mapping(
@@ -30,15 +20,9 @@ def __setattr__(self, name: str, value: Any) -> None:
 )
 
 
-class PairExcludeMask(PairExcludeMaskDP, torch.nn.Module):
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        torch.nn.Module.__init__(self)
-        PairExcludeMaskDP.__init__(self, *args, **kwargs)
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        handled, value = dpmodel_setattr(self, name, value)
-        if not handled:
-            super().__setattr__(name, value)
+@torch_module
+class PairExcludeMask(PairExcludeMaskDP):
+    pass
 
 
 register_dpmodel_mapping(
diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py
index 84d0024a85..9a18f607ff 100644
--- a/deepmd/pt_expt/utils/network.py
+++ b/deepmd/pt_expt/utils/network.py
@@ -21,6 +21,7 @@
 from deepmd.pt_expt.common import (
     register_dpmodel_mapping,
     to_torch_array,
+    torch_module,
 )
 
 
@@ -37,14 +38,8 @@ def __array__(self, dtype: Any | None = None) -> np.ndarray:
         return arr.astype(dtype)
 
 
-class NativeLayer(NativeLayerDP, torch.nn.Module):
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        torch.nn.Module.__init__(self)
-        NativeLayerDP.__init__(self, *args, **kwargs)
-
-    def __call__(self, *args: Any, **kwargs: Any) -> Any:
-        return torch.nn.Module.__call__(self, *args, **kwargs)
-
+@torch_module
+class NativeLayer(NativeLayerDP):
     def __setattr__(self, name: str, value: Any) -> None:
         if name in {"w", "b", "idt"} and "_parameters" in self.__dict__:
             val = to_torch_array(value)
@@ -78,15 +73,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.call(x)
 
 
-class NativeNet(make_multilayer_network(NativeLayer, NativeOP), torch.nn.Module):
+@torch_module
+class NativeNet(make_multilayer_network(NativeLayer, NativeOP)):
     def __init__(self, layers: list[dict] | None = None) -> None:
-        torch.nn.Module.__init__(self)
         super().__init__(layers)
         self.layers = torch.nn.ModuleList(self.layers)
 
-    def __call__(self, *args: Any, **kwargs: Any) -> Any:
-        return torch.nn.Module.__call__(self, *args, **kwargs)
-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.call(x)
 
@@ -99,7 +91,8 @@ class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)):
     pass
 
 
-class NetworkCollection(NetworkCollectionDP, torch.nn.Module):
+@torch_module
+class NetworkCollection(NetworkCollectionDP):
     NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = {
         "network": NativeNet,
         "embedding_network": EmbeddingNet,
@@ -107,7 +100,6 @@ class NetworkCollection(NetworkCollectionDP, torch.nn.Module):
     }
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
-        torch.nn.Module.__init__(self)
         self._module_networks = torch.nn.ModuleDict()
         super().__init__(*args, **kwargs)
 
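
Reviewer note: the sketch below is a minimal, self-contained illustration of the pattern the new torch_module decorator implements, namely deriving from both the wrapped class and torch.nn.Module and routing calls through torch.nn.Module.__call__ so export/tracing sees forward(). It is not part of the patch; PlainOp, Scale, and torch_module_sketch are hypothetical stand-in names, not deepmd classes, and the dpmodel_setattr attribute handling of the real decorator is omitted here.

# Standalone sketch of the torch_module pattern (hypothetical names, not deepmd code).
from functools import wraps
from typing import Any

import torch


class PlainOp:
    """Stand-in for a NativeOP-style class exposing call()."""

    def call(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


def torch_module_sketch(module: type) -> type:
    @wraps(module, updated=())
    class TorchModule(module, torch.nn.Module):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            # Initialize torch.nn.Module first so _parameters/_buffers exist
            # before the wrapped class starts assigning attributes.
            torch.nn.Module.__init__(self)
            module.__init__(self, *args, **kwargs)

        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            # Route calls through torch.nn.Module.__call__ so hooks and
            # tracing/export drive forward().
            return torch.nn.Module.__call__(self, *args, **kwargs)

    return TorchModule


@torch_module_sketch
class Scale(PlainOp):
    def __init__(self, factor: float) -> None:
        self.factor = factor

    def call(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.call(x)


op = Scale(2.0)
assert isinstance(op, torch.nn.Module)
print(op(torch.ones(3)))  # tensor([2., 2., 2.]), dispatched via forward()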