refact(dpmodel,pt_expt): fitting net#5207

Open
wanghan-iapcm wants to merge 37 commits into deepmodeling:master from wanghan-iapcm:refact-fitting-net

Conversation


@wanghan-iapcm wanghan-iapcm commented Feb 8, 2026

FittingNet Refactoring: Factory Function to Concrete Class

Summary

This refactoring converts FittingNet from a factory-generated dynamic class to a concrete class in the dpmodel backend, following the same pattern as the EmbeddingNet refactoring. This enables the auto-detection registry mechanism in pt_expt to work seamlessly with FittingNet.

This PR builds on #5194 and #5204 and should be reviewed after them.

Motivation

Before: FittingNet was created by a factory function make_fitting_network(EmbeddingNet, NativeNet, NativeLayer), producing a dynamically-typed class. This caused:

  1. Cannot be registered: Dynamic classes can't be imported or registered at module import time in the pt_expt registry
  2. Type matching fails: Each call to make_fitting_network creates a new class type, so registry lookup by type fails

After: FittingNet is now a concrete class that can be registered in the pt_expt auto-conversion registry.
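The type-matching failure can be reproduced with a toy factory (illustrative stand-ins only; the real factory is make_fitting_network in deepmd/dpmodel/utils/network.py):

```python
def make_net(base):
    # Stand-in for a factory like make_fitting_network: each call
    # defines and returns a brand-new class object.
    class Net(base):
        pass
    return Net


class Base:
    pass


NetA = make_net(Base)
NetB = make_net(Base)

# The two classes are distinct objects, so a registry keyed by type can
# only match instances of the exact class used at registration time.
registry = {NetA: "converter"}
assert NetA is not NetB
assert type(NetA()) in registry
assert type(NetB()) not in registry  # lookup fails for the second factory product
```

A concrete, importable class avoids this: there is exactly one class object to register, so every instance matches.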

Changes

1. dpmodel: Concrete FittingNet class

File: deepmd/dpmodel/utils/network.py

  • Created concrete FittingNet(EmbeddingNet) class
  • Moved constructor logic from factory into __init__
  • Fixed deserialize to use type(obj.layers[0]) instead of hardcoding T_Network.__init__(obj, layers), allowing the pt_expt subclass to preserve its converted torch layers
  • Kept the make_fitting_network factory for backward compatibility (for the pt/pd backends)
class FittingNet(EmbeddingNet):
    """The fitting network."""

    def __init__(self, in_dim, out_dim, neuron=[24, 48, 96],
                 activation_function="tanh", resnet_dt=False,
                 precision=DEFAULT_PRECISION, bias_out=True,
                 seed=None, trainable=True):
        # Handle trainable parameter
        if trainable is None:
            trainable = [True] * (len(neuron) + 1)
        elif isinstance(trainable, bool):
            trainable = [trainable] * (len(neuron) + 1)

        # Initialize embedding layers via parent
        super().__init__(
            in_dim, neuron=neuron,
            activation_function=activation_function,
            resnet_dt=resnet_dt, precision=precision,
            seed=seed, trainable=trainable[:-1]
        )

        # Add output layer
        i_in = neuron[-1] if len(neuron) > 0 else in_dim
        self.layers.append(
            NativeLayer(
                i_in, out_dim, bias=bias_out,
                use_timestep=False, activation_function=None,
                resnet=False, precision=precision,
                seed=child_seed(seed, len(neuron)),
                trainable=trainable[-1]
            )
        )
        self.out_dim = out_dim
        self.bias_out = bias_out

    @classmethod
    def deserialize(cls, data):
        data = data.copy()
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        data.pop("@class", None)
        layers = data.pop("layers")
        obj = cls(**data)
        # Use type(obj.layers[0]) to respect subclass layer types
        layer_type = type(obj.layers[0])
        obj.layers = type(obj.layers)(
            [layer_type.deserialize(layer) for layer in layers]
        )
        return obj
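A minimal sketch (toy classes, not the real deepmd API) of why inferring the layer type from obj.layers[0] lets a subclass keep its own layer class through a deserialization round trip:

```python
class NativeLayer:
    def __init__(self, w=0.0):
        self.w = w

    def serialize(self):
        return {"w": self.w}

    @classmethod
    def deserialize(cls, data):
        return cls(data["w"])


class Net:
    layer_cls = NativeLayer  # subclasses override this

    def __init__(self):
        # __init__ populates layers with the class's own layer type
        self.layers = [self.layer_cls()]

    @classmethod
    def deserialize(cls, data):
        obj = cls()
        # Infer the layer type from what __init__ created instead of
        # hardcoding NativeLayer, so subclass layers survive the round trip.
        layer_type = type(obj.layers[0])
        obj.layers = [layer_type.deserialize(d) for d in data["layers"]]
        return obj


class TorchLayer(NativeLayer):  # stand-in for the pt_expt torch layer
    pass


class TorchNet(Net):
    layer_cls = TorchLayer


net = TorchNet.deserialize({"layers": [{"w": 1.5}]})
assert type(net.layers[0]) is TorchLayer  # not downgraded to NativeLayer
```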

2. pt_expt: Wrapper and registration

File: deepmd/pt_expt/utils/network.py

  • Added import: from deepmd.dpmodel.utils.network import FittingNet as FittingNetDP
  • Created FittingNet(FittingNetDP, torch.nn.Module) wrapper
  • Converts dpmodel layers to pt_expt NativeLayer (torch modules) in __init__
  • Registered in auto-conversion registry
from deepmd.dpmodel.utils.network import FittingNet as FittingNetDP

class FittingNet(FittingNetDP, torch.nn.Module):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        torch.nn.Module.__init__(self)
        FittingNetDP.__init__(self, *args, **kwargs)
        # Convert dpmodel layers to pt_expt NativeLayer
        self.layers = torch.nn.ModuleList(
            [NativeLayer.deserialize(layer.serialize()) for layer in self.layers]
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return torch.nn.Module.__call__(self, *args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.call(x)

register_dpmodel_mapping(
    FittingNetDP,
    lambda v: FittingNet.deserialize(v.serialize()),
)
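The registry mechanism itself can be sketched with stand-ins (the real register_dpmodel_mapping and try_convert_module live in deepmd/pt_expt/common.py; this toy version only illustrates the type-keyed lookup):

```python
_DPMODEL_TO_PT = {}


def register_dpmodel_mapping(dp_cls, converter):
    _DPMODEL_TO_PT[dp_cls] = converter


def try_convert_module(obj):
    # Type-based lookup: this only works because FittingNetDP is a single
    # concrete, importable class rather than a fresh factory product per call.
    converter = _DPMODEL_TO_PT.get(type(obj))
    return converter(obj) if converter is not None else obj


class FittingNetDP:  # toy stand-in for the dpmodel class
    def serialize(self):
        return {"out_dim": 1}


class FittingNetPT:  # toy stand-in for the pt_expt wrapper
    @classmethod
    def deserialize(cls, data):
        obj = cls()
        obj.out_dim = data["out_dim"]
        return obj


register_dpmodel_mapping(FittingNetDP, lambda v: FittingNetPT.deserialize(v.serialize()))

converted = try_convert_module(FittingNetDP())
assert isinstance(converted, FittingNetPT)
```

Objects with no registered type pass through unchanged, which is what makes the conversion safe to apply opportunistically during attribute assignment.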

Tests

dpmodel tests

File: source/tests/common/dpmodel/test_network.py

Added to TestFittingNet class:

  1. test_fitting_net: Original roundtrip serialization test (already existed)
  2. test_is_concrete_class: Verifies FittingNet is now a concrete class, not factory output
  3. test_forward_pass: Tests dpmodel forward pass produces correct output shapes (single and batch)
  4. test_trainable_parameter_variants: Tests different trainable configurations (all trainable, all frozen, mixed)
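The variants exercised by these tests correspond to the normalization at the top of __init__, roughly as follows (the helper name is illustrative, not part of the codebase):

```python
def normalize_trainable(trainable, n_neuron):
    """Expand None/bool to one flag per layer (hidden layers plus the output layer)."""
    n_layers = n_neuron + 1
    if trainable is None:
        return [True] * n_layers
    if isinstance(trainable, bool):
        return [trainable] * n_layers
    return list(trainable)  # already a per-layer list


assert normalize_trainable(None, 3) == [True, True, True, True]      # all trainable
assert normalize_trainable(False, 2) == [False, False, False]        # all frozen
assert normalize_trainable([True, False, True], 2) == [True, False, True]  # mixed
```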

pt_expt integration tests

File: source/tests/pt_expt/utils/test_network.py

Created TestFittingNetRefactor test suite with 4 tests:

  1. test_pt_expt_fitting_net_wraps_dpmodel: Verifies pt_expt wrapper inherits correctly and converts layers
  2. test_pt_expt_fitting_net_forward: Tests pt_expt forward pass returns torch.Tensor with correct shape
  3. test_serialization_round_trip_pt_expt: Tests pt_expt serialize/deserialize round-trip
  4. test_registry_converts_dpmodel_to_pt_expt: Tests try_convert_module auto-converts dpmodel to pt_expt

Verification

All tests pass:

# dpmodel network tests (includes new FittingNet tests)
python -m pytest source/tests/common/dpmodel/test_network.py -v
# 19 passed in 0.56s (was 16, added 3 FittingNet tests)

# dpmodel FittingNet tests specifically
python -m pytest source/tests/common/dpmodel/test_network.py::TestFittingNet -v
# 4 passed in 0.44s

# pt_expt network tests (EmbeddingNet + FittingNet)
python -m pytest source/tests/pt_expt/utils/test_network.py -v
# 14 passed in 0.45s

# Descriptor tests (verify refactoring doesn't break existing code)
python -m pytest source/tests/pt_expt/descriptor/ -v
# 8 passed in 5.43s

Benefits

  1. Type-based auto-detection: FittingNet now works with the registry mechanism
  2. Consistency: Same pattern as EmbeddingNet and other dpmodel classes
  3. Maintainability: Single source of truth for FittingNet in dpmodel
  4. Future-proof: Any dpmodel FittingNet instances can be auto-converted to pt_expt

Backward Compatibility

  • Serialization format unchanged (version 1)
  • All existing tests pass
  • make_fitting_network factory kept for pt/pd backends
  • No changes to public API

Files Changed

Modified

  • deepmd/dpmodel/utils/network.py: Concrete FittingNet class + deserialize fix
  • deepmd/pt_expt/utils/network.py: FittingNet wrapper + registration
  • source/tests/common/dpmodel/test_network.py: Added dpmodel FittingNet tests (3 new tests)
  • source/tests/pt_expt/utils/test_network.py: Added pt_expt integration tests (4 new tests)

Pattern

This refactoring follows the exact same pattern as EMBEDDING_NET_REFACTOR.md:

  1. Convert factory-generated class to concrete class in dpmodel
  2. Fix deserialize to use type(obj.layers[0])
  3. Create pt_expt wrapper with layer conversion in __init__
  4. Register with register_dpmodel_mapping
  5. Add comprehensive tests
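Steps 1, 3, and 4 above can be condensed into one sketch; ModuleBase stands in for torch.nn.Module so the example stays self-contained, and the names are illustrative only:

```python
class DPNet:  # step 1: concrete dpmodel class
    def __init__(self, scale):
        self.scale = scale

    def call(self, x):
        return [v * self.scale for v in x]


class ModuleBase:  # stand-in for torch.nn.Module
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class WrappedNet(DPNet, ModuleBase):  # step 3: wrapper with dual inheritance
    def forward(self, x):
        # route the module-style forward through the dpmodel computation
        return self.call(x)


# step 4: register a converter keyed on the concrete dpmodel type
registry = {DPNet: lambda dp: WrappedNet(dp.scale)}

dp = DPNet(2)
wrapped = registry[type(dp)](dp)
assert isinstance(wrapped, WrappedNet)
assert wrapped([1, 2, 3]) == [2, 4, 6]
```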

Summary by CodeRabbit

  • New Features

    • Added PyTorch experimental descriptor implementations for multi-device support
    • Introduced unified PyTorch integration layer for automatic dpmodel-to-PyTorch conversions
  • Bug Fixes

    • Fixed implicit failures when statistics unavailable; now raises explicit errors
    • Improved device propagation for array operations across all descriptors
  • Refactor

    • Restructured network classes (EmbeddingNet, FittingNet) with explicit initialization and serialization methods
    • Enhanced array operations to support multi-device scenarios with proper device inference
  • Tests

    • Added comprehensive unit tests for PyTorch experimental descriptors
    • Extended test coverage for serialization consistency and cross-backend equivalence

@wanghan-iapcm wanghan-iapcm requested a review from njzjz February 8, 2026 14:42
@github-actions github-actions bot added the Python label Feb 8, 2026
@gemini-code-assist

Summary of Changes

Hello @wanghan-iapcm, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the FittingNet class, transforming it from a dynamically generated entity into a concrete class. This change is crucial for enabling FittingNet to be properly registered and utilized within the pt_expt (PyTorch exportable) backend, aligning it with the architectural patterns established by previous refactorings like EmbeddingNet. The update enhances the system's ability to automatically detect and convert modules, leading to a more consistent and maintainable codebase. It also includes critical adjustments for array device handling and thread management to ensure robust operation across different PyTorch environments, all validated by a comprehensive suite of new tests.

Highlights

  • FittingNet Refactoring: Converted the FittingNet class from a factory-generated dynamic class to a concrete class within the deepmd.dpmodel backend. This change resolves issues with dynamic classes not being registrable and failing type matching in the pt_expt registry.
  • PyTorch Exportable (pt_expt) Integration: Introduced a FittingNet wrapper in deepmd/pt_expt/utils/network.py that inherits from the new concrete dpmodel.FittingNet and torch.nn.Module. This wrapper converts dpmodel layers to pt_expt NativeLayer (torch modules) and registers FittingNet for auto-conversion in the pt_expt backend.
  • Array Handling and Device Compatibility: Updated dpmodel/common.py to ensure arrays are moved to the CPU device when converting to NumPy for compatibility. Several descriptor files (dpa1.py, repflows.py, repformers.py, se_e2_a.py, se_r.py, se_t.py, se_t_tebd.py) were modified to explicitly specify the device for array creation (xp.zeros, xp.asarray) and to cast indices to int64 for PyTorch backend compatibility.
  • Thread Management Improvements: Added robust error handling to torch.set_num_interop_threads and torch.set_num_threads calls in deepmd/pt/utils/env.py and deepmd/pt_expt/utils/env.py to prevent RuntimeError if these functions are called multiple times or after threads have been created.
  • Comprehensive Testing: Added extensive new test cases for FittingNet in both dpmodel and pt_expt to verify concrete class behavior, forward pass, trainable parameter variants, serialization, cross-backend consistency, and proper registry conversion. New test files were also added for pt_expt descriptors and utilities.
Changelog
  • deepmd/backend/pt_expt.py
    • Added new file for PyTorch exportable backend definition.
  • deepmd/dpmodel/common.py
    • Modified to_numpy_array to move arrays to CPU device for numpy compatibility.
  • deepmd/dpmodel/descriptor/descriptor.py
    • Imported array_api_compat and updated extend_descrpt_stat to use xp.zeros and xp.concat with explicit device specification.
  • deepmd/dpmodel/descriptor/dpa1.py
    • Updated compute_input_stats to specify device for xp.asarray and cast idx to int64 for PyTorch compatibility.
  • deepmd/dpmodel/descriptor/repflows.py
    • Updated compute_input_stats to specify device for xp.asarray.
  • deepmd/dpmodel/descriptor/repformers.py
    • Updated compute_input_stats to specify device for xp.asarray.
  • deepmd/dpmodel/descriptor/se_e2_a.py
    • Updated compute_input_stats and call to specify device for xp.asarray and xp.zeros.
  • deepmd/dpmodel/descriptor/se_r.py
    • Updated compute_input_stats and call to specify device for xp.asarray and xp.zeros.
  • deepmd/dpmodel/descriptor/se_t.py
    • Updated compute_input_stats and call to specify device for xp.asarray and xp.zeros.
  • deepmd/dpmodel/descriptor/se_t_tebd.py
    • Updated compute_input_stats and call to specify device for xp.asarray, xp.zeros and cast idx to int64.
  • deepmd/dpmodel/utils/network.py
    • Refactored EmbeddingNet and FittingNet from factory functions to concrete classes.
    • Moved constructor logic to __init__.
    • Fixed deserialize to use type(obj.layers[0]) to preserve subclass layer types.
  • deepmd/dpmodel/utils/type_embed.py
    • Updated call and change_type_map to specify device for xp.eye, xp.zeros, xp.concat, and xp.asarray.
  • deepmd/env.py
    • Corrected environment variable lookup from TF_INTRA_OP_PARALLELISM_THREADS to TF_INTER_OP_PARALLELISM_THREADS.
  • deepmd/pt/utils/env.py
    • Added guards to torch.set_num_interop_threads and torch.set_num_threads to prevent RuntimeError if called multiple times.
  • deepmd/pt_expt/__init__.py
    • Added new file.
  • deepmd/pt_expt/common.py
    • Added new file defining dpmodel_setattr, register_dpmodel_mapping, and try_convert_module for automatic attribute conversion.
  • deepmd/pt_expt/descriptor/__init__.py
    • Added new file to import and register descriptor converters.
  • deepmd/pt_expt/descriptor/base_descriptor.py
    • Added new file defining BaseDescriptor for pt_expt.
  • deepmd/pt_expt/descriptor/se_e2_a.py
    • Added new file for DescrptSeA wrapper for pt_expt.
  • deepmd/pt_expt/descriptor/se_r.py
    • Added new file for DescrptSeR wrapper for pt_expt.
  • deepmd/pt_expt/descriptor/se_t.py
    • Added new file for DescrptSeT wrapper for pt_expt.
  • deepmd/pt_expt/descriptor/se_t_tebd.py
    • Added new file for DescrptSeTTebd wrapper for pt_expt.
  • deepmd/pt_expt/descriptor/se_t_tebd_block.py
    • Added new file for DescrptBlockSeTTebd wrapper for pt_expt.
  • deepmd/pt_expt/utils/__init__.py
    • Added new file to export pt_expt utilities.
  • deepmd/pt_expt/utils/env.py
    • Added new file for pt_expt environment settings, including thread guards.
  • deepmd/pt_expt/utils/exclude_mask.py
    • Added new file for AtomExcludeMask and PairExcludeMask wrappers for pt_expt.
  • deepmd/pt_expt/utils/network.py
    • Added new file for NativeLayer, NativeNet, EmbeddingNet, FittingNet, and NetworkCollection wrappers for pt_expt.
  • deepmd/pt_expt/utils/type_embed.py
    • Added new file for TypeEmbedNet wrapper for pt_expt.
  • pyproject.toml
    • Updated banned-module-level-imports and runtime-evaluated-base-classes to include deepmd.pt_expt.
  • source/tests/common/dpmodel/test_network.py
    • Added test_is_concrete_class, test_forward_pass, and test_trainable_parameter_variants for both EmbeddingNet and FittingNet.
  • source/tests/consistent/common.py
    • Added INSTALLED_PT_EXPT and pt_expt_class to CommonTest.
    • Added eval_pt_expt and related logic for pt_expt backend consistency tests.
  • source/tests/consistent/descriptor/common.py
    • Updated eval_pt_expt_descriptor to handle pt_expt descriptor evaluation.
  • source/tests/consistent/descriptor/test_se_e2_a.py
    • Added DescrptSeAPTExpt and updated skip_pt_expt and eval_pt_expt for se_e2_a descriptor.
  • source/tests/consistent/descriptor/test_se_r.py
    • Added DescrptSeRPTExpt and updated skip_pt_expt and eval_pt_expt for se_r descriptor.
  • source/tests/consistent/descriptor/test_se_t.py
    • Added DescrptSeTPTExpt and updated skip_pt_expt and eval_pt_expt for se_t descriptor.
  • source/tests/consistent/descriptor/test_se_t_tebd.py
    • Added DescrptSeTTebdPTExpt and updated skip_pt_expt and eval_pt_expt for se_t_tebd descriptor.
  • source/tests/pt/test_env_threads.py
    • Added new test file to verify thread setting guards.
  • source/tests/pt_expt/__init__.py
    • Added new file.
  • source/tests/pt_expt/conftest.py
    • Added new file for pt_expt test configuration.
  • source/tests/pt_expt/descriptor/__init__.py
    • Added new file.
  • source/tests/pt_expt/descriptor/test_se_e2_a.py
    • Added new test file for pt_expt DescrptSeA consistency and exportability.
  • source/tests/pt_expt/descriptor/test_se_r.py
    • Added new test file for pt_expt DescrptSeR consistency and exportability.
  • source/tests/pt_expt/descriptor/test_se_t.py
    • Added new test file for pt_expt DescrptSeT consistency and exportability.
  • source/tests/pt_expt/descriptor/test_se_t_tebd.py
    • Added new test file for pt_expt DescrptSeTTebd consistency and exportability.
  • source/tests/pt_expt/model/__init__.py
    • Added new file.
  • source/tests/pt_expt/utils/__init__.py
    • Added new file.
  • source/tests/pt_expt/utils/test_common.py
    • Added new test file for to_torch_array utility.
  • source/tests/pt_expt/utils/test_env.py
    • Added new test file for pt_expt thread setting guards.
  • source/tests/pt_expt/utils/test_exclusion_mask.py
    • Added new test file for pt_expt AtomExcludeMask and PairExcludeMask.
  • source/tests/pt_expt/utils/test_network.py
    • Added new test file for pt_expt EmbeddingNet and FittingNet refactoring, including serialization, cross-backend consistency, registry conversion, and trainable parameter handling.
Activity
  • The pull request initiates a significant refactoring of the FittingNet class, transitioning it from a dynamically generated class to a concrete one. This change is essential for its integration with the pt_expt backend's auto-detection registry.
  • This refactoring builds upon previous work, specifically the EmbeddingNet refactoring, ensuring a consistent architectural pattern across the codebase.
  • Extensive new test cases have been introduced to thoroughly validate the new FittingNet implementation, covering its behavior as a concrete class, its forward pass, handling of trainable parameters, serialization, and its interaction with the pt_expt registry.
  • Backward compatibility is maintained for existing pt and pd backends by retaining the original make_fitting_network factory function.
  • Several minor but important fixes related to array device handling and thread management in PyTorch environments have been incorporated to enhance robustness.

@wanghan-iapcm wanghan-iapcm changed the title refact(pt_expt): fitting net refact(dpmodel,pt_expt): fitting net Feb 8, 2026
@dosubot dosubot bot added the enhancement label Feb 8, 2026
Code scanning / CodeQL check notices (in test code):

  • Unused local variable: smooth
  • Unused local variable: concat_output_tebd
  • Unused local variable: precision
  • Unused local variable: use_econf_tebd
  • Unused local variable: use_tebd_bias

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request is a significant and well-executed refactoring. It successfully converts FittingNet to a concrete class, mirroring the previous EmbeddingNet refactoring, which is a great step towards a more consistent and maintainable codebase. The introduction of the pt_expt backend with its automatic wrapping mechanism is a clever design that will simplify future development. The changes to make the code more backend-agnostic using array_api_compat and ensuring device consistency are also excellent. The added tests are comprehensive and provide good coverage for the new features. I've found a couple of minor areas for improvement, which I've detailed in the comments.

dtype=des["dstd"].dtype,
device=array_api_compat.device(des["dstd"]),
)
xp = array_api_compat.array_namespace(des["davg"])


medium

This re-definition of xp is redundant if the else block on line 174 is executed, as xp is already defined on line 177. Consider defining xp once before the if-else block to improve code clarity and avoid redundancy.

Comment on lines +1054 to +1055
else:
pass


medium

The else: pass statement is redundant and can be removed for conciseness.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e263270ca2


class PyTorchExportableBackend(Backend):
"""PyTorch exportable backend."""

name = "PyTorch Exportable"


P1 Badge Use a registered key as backend canonical name

Setting name to "PyTorch Exportable" breaks CLI backend selection because main_parser() canonicalizes --backend/--pt-expt to backend.name.lower(), yielding "pytorch exportable", but deepmd/main.py only accepts registered backend keys (e.g. pt-expt, pytorch-exportable). As a result, invoking dp --pt-expt ... or dp --backend pt-expt ... fails early with Unknown backend before dispatch.


type[DeepEvalBackend]
The Deep Eval backend of the backend.
"""
raise NotImplementedError


P1 Badge Stop advertising unsupported deep-eval and IO hooks

This backend is marked with DEEP_EVAL/NEIGHBOR_STAT/IO and .pte suffix support, so dispatcher paths can select it (for example .pte model detection in DeepEval), but deep_eval, neighbor_stat, and serialization hooks still raise NotImplementedError. That makes backend-selected inference/inspection flows fail at runtime instead of gracefully reporting unsupported capabilities.


obj = cls(**data)
# Reinitialize layers from serialized data, using the same layer type
# that __init__ created (respects subclass overrides via MRO).
layer_type = type(obj.layers[0])


P2 Badge Handle empty embedding layers in deserialize

EmbeddingNet.deserialize() now indexes obj.layers[0] to infer layer_type, but EmbeddingNet.__init__ allows neuron=[], which creates an empty layer list. Deserializing such a serialized network now raises IndexError before loading layers, whereas the previous implementation handled empty-layer round trips by reinitializing directly from serialized layers.



coderabbitai bot commented Feb 8, 2026

📝 Walkthrough

Walkthrough

Introduces a PyTorch exportable backend (pt_expt) with device-aware array operations in core descriptors. Refactors network factories into explicit classes with serialization support. Adds wrappers converting dpmodel components to PyTorch modules via registry-based mapping.

Changes

Cohort / File(s) Summary
Descriptor Device-Aware Refactoring
deepmd/dpmodel/descriptor/descriptor.py, deepmd/dpmodel/descriptor/dpa1.py, deepmd/dpmodel/descriptor/se_t.py, deepmd/dpmodel/descriptor/se_t_tebd.py
Migrate NumPy operations to array_api_compat with explicit device/dtype inference. Add defensive get_stats guards raising RuntimeError when statistics uncomputed. Cast indices to int64 for backend compatibility.
Dpmodel Network Class Refactoring
deepmd/dpmodel/utils/network.py
Convert factory-based EmbeddingNet and FittingNet aliases to explicit subclasses (EmbeddingNet extends NativeNet; FittingNet extends EmbeddingNet) with concrete init, serialize/deserialize methods, and public attributes (in_dim, neuron, activation_function, resnet_dt, out_dim, bias_out).
Dpmodel Type Embedding Device-Aware Updates
deepmd/dpmodel/utils/type_embed.py
Replace NumPy operations and in-place mutations with array_api_compat-aware non-in-place expressions. Propagate device/dtype explicitly for eye matrices, padding, and weight updates using xp.eye, xp.concat, xp.asarray.
PT Expt Common Infrastructure
deepmd/pt_expt/common.py
Introduce registry-based conversion system mapping dpmodel classes to pt_expt PyTorch modules. Add utilities: register_dpmodel_mapping, try_convert_module, dpmodel_setattr for synchronized buffer/parameter management, to_torch_array for tensor normalization.
PT Expt Descriptor Wrappers
deepmd/pt_expt/descriptor/__init__.py, deepmd/pt_expt/descriptor/se_e2_a.py, deepmd/pt_expt/descriptor/se_r.py, deepmd/pt_expt/descriptor/se_t.py, deepmd/pt_expt/descriptor/se_t_tebd.py, deepmd/pt_expt/descriptor/se_t_tebd_block.py
Create PyTorch module wrappers (DescrptSeA, DescrptSeR, DescrptSeT, DescrptSeTTebd, DescrptBlockSeTTebd) extending DP model descriptors with dual inheritance, forward method routing to dpmodel.call, and dpmodel_setattr attribute handling. Register converter mappings for serialization.
PT Expt Utility Modules
deepmd/pt_expt/utils/__init__.py, deepmd/pt_expt/utils/network.py, deepmd/pt_expt/utils/type_embed.py, deepmd/pt_expt/utils/exclude_mask.py
Create PyTorch wrappers for networks (TorchArrayParam, NativeLayer, NativeNet, EmbeddingNet, FittingNet, LayerNorm), NetworkCollection with NETWORK_TYPE_MAP, and exclude masks (AtomExcludeMask, PairExcludeMask, TypeEmbedNet). Register dpmodel-to-pt_expt mappings; synchronize PyTorch state via dpmodel_setattr.
Dpmodel Network Tests
source/tests/common/dpmodel/test_network.py
Add unit tests validating EmbeddingNet and FittingNet concrete class identity, forward shapes, trainable configurations (all/none/mixed), and serialization/deserialization consistency of trainable flags.
PT Descriptor Consistency Tests
source/tests/consistent/descriptor/test_se_t.py, source/tests/consistent/descriptor/test_se_t_tebd.py
Extend test infrastructure with PT Expt support: add pt_expt_class attribute, skip_pt_expt property, eval_pt_expt method for parallel PT and PT Expt descriptor evaluation.
PT Expt Descriptor Unit Tests
source/tests/pt_expt/descriptor/test_se_t.py, source/tests/pt_expt/descriptor/test_se_t_tebd.py
Validate PT Expt DescrptSeT and DescrptSeTTebd consistency across precision variants, serialization round-trips, numerical correctness, and torch.export.export exportability.
PT Expt Network Tests
source/tests/pt_expt/utils/test_network.py
Comprehensive testing of NativeLayer, EmbeddingNet, and FittingNet wrappers: buffer/parameter management, serialization, registry conversion, cross-backend equivalence, auto-conversion via dpmodel_setattr, and trainable configuration preservation.

Sequence Diagram(s)

sequenceDiagram
    participant DP as DPModel Descriptor
    participant Reg as Registry (common.py)
    participant Conv as Converter Function
    participant PT as PT Expt Module
    participant PyTorch as PyTorch Framework

    DP->>Reg: register_dpmodel_mapping(DP_class, converter)
    Note over Reg: Store mapping: DP_class → converter
    
    DP->>Reg: try_convert_module(dp_instance)
    Reg->>Reg: Check if dp_instance type in registry
    Reg->>Conv: Found converter, call it
    Conv->>PT: Create PT_expt wrapper with dual inheritance
    PT->>PT: Initialize DP base + torch.nn.Module
    PT->>PT: Register buffers/parameters via dpmodel_setattr
    Reg-->>Conv: PT_expt module
    Conv-->>Reg: Converted module
    Reg-->>DP: PT Expt module returned
    
    DP->>PT: Assign to attribute (via dpmodel_setattr)
    PT->>PyTorch: Synchronize with PyTorch state_dict
    PyTorch-->>PT: Module registered for export/device movement
sequenceDiagram
    participant App as Application
    participant PT as PT Expt Descriptor
    participant DP as DPModel Compute
    participant XP as Array API Backend

    App->>PT: forward(extended_coord, extended_atype, nlist, ...)
    PT->>PT: __call__ routes through nn.Module.__call__
    PT->>DP: self.call(extended_coord, extended_atype, nlist, mapping)
    DP->>XP: xp.asarray(array, device=target_device, dtype=target_dtype)
    XP-->>DP: Device-aware array
    DP->>XP: xp.take_along_axis(..., indices.astype(xp.int64))
    XP-->>DP: Gathered values
    DP-->>PT: (descrpt, rot_mat, g2, h2, sw)
    PT-->>App: Return 5-tuple of torch.Tensor

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

  • feat(jax/array-api): se_t_tebd #4288: Modifies descriptor se_t_tebd with array-API backend adoption (xp operations, device-aware handling, index casting) aligning with this PR's device-aware refactoring across descriptors.
  • feat(jax/array-api): dpa1 #4160: Updates dpa1.py, type_embed.py and related descriptor/utils modules migrating NumPy to array-api-compatible backends and device-aware operations, directly overlapping with dpmodel changes here.
  • feat: new backend pytorch exportable. #5194: Implements parallel PyTorch-exportable backend (pt_expt) with matching descriptor wrappers and dpmodel device-aware refactoring, covering the same functional scope and file overlap.

Suggested labels

new feature

Suggested reviewers

  • iProzd
  • njzjz
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage — ⚠️ Warning: docstring coverage is 34.24%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check — ✅ Passed: check skipped; CodeRabbit's high-level summary is enabled.
  • Title check — ✅ Passed: the title "refact(dpmodel,pt_expt): fitting net" directly summarizes the main change: refactoring FittingNet across both dpmodel and pt_expt backends from a factory-generated class to a concrete class.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/dpmodel/descriptor/dpa1.py (1)

921-927: ⚠️ Potential issue | 🟡 Minor

Initialize self.stats = None in DescrptBlockSeAtten.__init__ to prevent AttributeError.

self.stats is only assigned in compute_input_stats (line 909) and is never initialized in __init__. If get_stats is called before compute_input_stats, the check if self.stats is None: on line 923 will raise AttributeError instead of the intended RuntimeError.

Proposed fix

Add self.stats = None in DescrptBlockSeAtten.__init__ (after line 845):

         self.orig_sel = self.sel
+        self.stats = None
🤖 Fix all issues with AI agents
In `@deepmd/backend/pt_expt.py`:
- Around line 37-42: The class-level features tuple declares
Backend.Feature.DEEP_EVAL, NEIGHBOR_STAT, and IO but the corresponding
properties deep_eval, neighbor_stat, serialize_hook, and deserialize_hook raise
NotImplementedError; either remove those three flags from the features variable
so the backend is not advertised for those entrypoints, or implement the missing
methods (deep_eval(), neighbor_stat(), serialize_hook()/deserialize_hook()) to
provide the expected behavior; update the class-level features tuple (features)
and/or implement the named properties/methods (deep_eval, neighbor_stat,
serialize_hook, deserialize_hook) accordingly so get_backends_by_feature and
convert_backend no longer hit NotImplementedError.

In `@deepmd/dpmodel/utils/network.py`:
- Around line 889-896: EmbeddingNet.deserialize raises IndexError when
obj.layers is empty because it accesses obj.layers[0]; change the logic in
EmbeddingNet.deserialize to first check if obj.layers and layers are non-empty
before reading type(obj.layers[0]). If both are non-empty, keep the existing
behavior (layer_type = type(obj.layers[0]) and reconstruct obj.layers =
type(obj.layers)([layer_type.deserialize(layer) for layer in layers]));
otherwise, set obj.layers = type(obj.layers)([]) or simply use
type(obj.layers)([layer.deserialize(...) for layer in layers]) without indexing
when layers contains serialized entries, and when both are empty leave
obj.layers as an empty sequence. Ensure you update the code path in
EmbeddingNet.deserialize (the block that defines layer_type and assigns
obj.layers) to implement this guard.

In `@source/tests/consistent/descriptor/test_se_e2_a.py`:
- Around line 546-566: The test method eval_pt_expt uses env.DEVICE which is
only imported when INSTALLED_PT is true, causing a potential NameError when
INSTALLED_PT_EXPT is true but INSTALLED_PT is false; change the test to use a
guarded common symbol (e.g. PT_DEVICE) or add a local fallback: import or
reference PT_DEVICE from the shared descriptor test helper (which is available
when INSTALLED_PT or INSTALLED_PT_EXPT is true) or resolve DEVICE at the top of
this test module with a conditional that falls back to a safe default before
calling eval_pt_expt, and replace direct uses of env.DEVICE in eval_pt_expt with
PT_DEVICE (or the fallback symbol) to remove the hidden dependency on env and
INSTALLED_PT.
🧹 Nitpick comments (9)
source/tests/pt/test_env_threads.py (1)

21-28: Consider using pytest's caplog fixture instead of manually patching logging.Logger.warning.

The manual monkey-patch of logging.Logger.warning works but is fragile and non-idiomatic. pytest's built-in caplog fixture is designed for this:

def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None:
    ...
    with caplog.at_level(logging.WARNING):
        importlib.reload(env)
    assert any("Could not set torch interop threads" in r.message for r in caplog.records)
    assert any("Could not set torch intra threads" in r.message for r in caplog.records)
deepmd/dpmodel/utils/network.py (1)

816-826: Mutable default argument for neuron.

Both EmbeddingNet.__init__ (line 819) and FittingNet.__init__ (line 1042) use a mutable list as a default argument (neuron: list[int] = [24, 48, 96]). This is a well-known Python gotcha (Ruff B006). It's a pre-existing pattern inherited from the factory functions (lines 714, 930), so this may be deferred, but worth noting.

♻️ Idiomatic fix (applied to both classes)
     def __init__(
         self,
         in_dim: int,
-        neuron: list[int] = [24, 48, 96],
+        neuron: list[int] | None = None,
         ...
     ) -> None:
+        if neuron is None:
+            neuron = [24, 48, 96]

Also applies to: 1038-1048
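The gotcha behind B006 can be shown with a hypothetical two-line helper (names are illustrative, not from the deepmd codebase): the default list is created once at function definition time and shared across every call that omits the argument.

```python
def append_shared(item, bucket=[]):  # B006: one list shared by every call
    bucket.append(item)
    return bucket


def append_fresh(item, bucket=None):  # idiomatic fix: allocate per call
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket


print(append_shared(1))  # [1]
print(append_shared(2))  # [1, 2] -- state leaked from the previous call
print(append_fresh(2))   # [2] -- fresh list each call
```

EmbeddingNet only reads `neuron`, so the shared default is benign today — the fix is defensive, guarding against a future mutation.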

deepmd/pt_expt/descriptor/__init__.py (1)

3-3: Unused noqa directive (Ruff RUF100).

Ruff reports that F401 is not enabled, so the # noqa: F401 directive is unnecessary. That said, it does document the intent that this is a side-effect import. You could remove it to silence the linter or suppress RUF100 instead — either way is fine.

deepmd/pt_expt/utils/type_embed.py (1)

14-15: Side-effect import is necessary for registration ordering — good.

The comment clearly explains why network must be imported before TypeEmbedNet is used. Same Ruff RUF100 noqa nit as noted in __init__.py — the directive is technically unnecessary if F401 isn't enabled, but it documents intent.

source/tests/pt_expt/descriptor/test_se_t.py (1)

65-65: Nit: prefix unused unpacked variables with _.

Static analysis flags gr1 (Line 65) and gr2 (Line 85) as unused. Since se_t returns None for these, prefixing with _ silences the warning.

-            rd1, gr1, _, _, sw1 = dd1(
+            rd1, _gr1, _, _, sw1 = dd1(
-            rd2, gr2, _, _, sw2 = dd2.call(
+            rd2, _gr2, _, _, sw2 = dd2.call(
source/tests/pt_expt/utils/test_network.py (1)

286-404: TestFittingNetRefactor has less coverage than TestEmbeddingNetRefactor.

TestEmbeddingNetRefactor includes test_cross_backend_consistency, test_deserialize_preserves_layer_type, test_auto_conversion_in_setattr, and test_trainable_parameter_handling — none of which have counterparts here. Consider adding at least test_cross_backend_consistency and test_trainable_parameter_handling for FittingNet to ensure parity.

Also, FittingNet is imported locally in each method while EmbeddingNet is imported at the top of the file — consider hoisting the FittingNet import to the top for consistency.

deepmd/pt_expt/utils/env.py (1)

19-20: Minor: import torch placed after logger setup.

The torch import on line 20 sits between the logger initialization and the code that uses it (line 22 onwards). Consider moving it up to the import block (after line 7) for standard import ordering.

deepmd/pt_expt/common.py (1)

45-82: No duplicate-registration guard.

register_dpmodel_mapping silently overwrites if called with the same dpmodel_cls twice. This is probably fine for the current use case (registrations happen at module import time), but a debug-level log or assertion on duplicate keys could help catch accidental re-registrations during development.

deepmd/pt_expt/utils/network.py (1)

27-37: Remove unused noqa directive.

Ruff reports PYI034 is not enabled, so the # noqa: PYI034 suppression on line 28 is unnecessary.

🧹 Proposed fix
-    def __new__(  # noqa: PYI034
+    def __new__(
         cls, data: Any = None, requires_grad: bool = True
     ) -> "TorchArrayParam":

Comment on lines 37 to 42
    features: ClassVar[Backend.Feature] = (
        Backend.Feature.ENTRY_POINT
        | Backend.Feature.DEEP_EVAL
        | Backend.Feature.NEIGHBOR_STAT
        | Backend.Feature.IO
    )
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cd /tmp/repo && cat -n deepmd/backend/pt_expt.py | head -120

Repository: deepmodeling/deepmd-kit

Length of output: 126


🏁 Script executed:

# Check how Backend.Feature flags are used to gate method calls
rg -n "features\s*&" --type=py -A2 -B2 | head -100

Repository: deepmodeling/deepmd-kit

Length of output: 352


🏁 Script executed:

# Search for places where deep_eval, neighbor_stat, serialize_hook, deserialize_hook are called
rg -n "\.deep_eval\(|\.neighbor_stat\(|\.serialize_hook\(|\.deserialize_hook\(" --type=py -B3 | head -150

Repository: deepmodeling/deepmd-kit

Length of output: 350


🏁 Script executed:

# Check other backend implementations to see if similar patterns exist
fd "backend.*\.py$" --type=f | grep -E "(pt|cpu|gpu|tensorflow)" | head -5

Repository: deepmodeling/deepmd-kit

Length of output: 103


🏁 Script executed:

cat -n deepmd/backend/backend.py | sed -n '80,105p'

Repository: deepmodeling/deepmd-kit

Length of output: 946


🏁 Script executed:

cat -n deepmd/backend/pt_expt.py

Repository: deepmodeling/deepmd-kit

Length of output: 3606


🏁 Script executed:

cat -n deepmd/utils/update_sel.py | sed -n '110,130p'

Repository: deepmodeling/deepmd-kit

Length of output: 833


🏁 Script executed:

# Check if pt backend implements neighbor_stat
rg -n "def neighbor_stat" deepmd/backend/pt.py -A5

Repository: deepmodeling/deepmd-kit

Length of output: 127


🏁 Script executed:

# Check the get_backends and how backends are selected
rg -n "get_backends\|Backend.get" deepmd/backend/backend.py -A3 -B1 | head -80

Repository: deepmodeling/deepmd-kit

Length of output: 49


🏁 Script executed:

# Find all backend implementations
fd "backend.*\.py$" deepmd/backend --type f

Repository: deepmodeling/deepmd-kit

Length of output: 92


🏁 Script executed:

# Search for get_backends_by_feature usage
rg -n "get_backends_by_feature" --type=py -B2 -A5

Repository: deepmodeling/deepmd-kit

Length of output: 2044


🏁 Script executed:

# Search for calls to deep_eval, neighbor_stat that might gate on features
rg -n "backend\.deep_eval|backend\.neighbor_stat|backend\.serialize_hook|backend\.deserialize_hook" --type=py -B3 -A1

Repository: deepmodeling/deepmd-kit

Length of output: 1401


🏁 Script executed:

# Check if serialize_hook and deserialize_hook are actually used anywhere
rg -n "serialize_hook|deserialize_hook" --type=py | grep -v "def serialize_hook\|def deserialize_hook\|raise NotImplementedError" | head -20

Repository: deepmodeling/deepmd-kit

Length of output: 410


🏁 Script executed:

# Check the neighbor_stat entrypoint command
cat -n deepmd/entrypoints/neighbor_stat.py | sed -n '50,85p'

Repository: deepmodeling/deepmd-kit

Length of output: 1345


🏁 Script executed:

# Check if there's a deep_eval entrypoint
rg -n "get_backends_by_feature.*DEEP_EVAL" --type=py -B2 -A8

Repository: deepmodeling/deepmd-kit

Length of output: 49


🏁 Script executed:

# Check what happens when a .pte file is detected
cat -n deepmd/backend/backend.py | sed -n '96,120p'

Repository: deepmodeling/deepmd-kit

Length of output: 1120


🏁 Script executed:

# Verify what suffixes pt-expt has and if they're used for backend detection
rg -n "\.pte|suffixes" deepmd/backend/pt_expt.py

Repository: deepmodeling/deepmd-kit

Length of output: 156


Remove unimplemented feature declarations or implement the methods.

The pt-expt backend declares DEEP_EVAL, NEIGHBOR_STAT, and IO features, but the corresponding properties (deep_eval, neighbor_stat, serialize_hook, deserialize_hook) all raise NotImplementedError. This creates real runtime failures:

  • The neighbor_stat entrypoint (deepmd/entrypoints/neighbor_stat.py:75-80) calls get_backends_by_feature(Backend.Feature.NEIGHBOR_STAT), which includes pt-expt. If a user specifies --backend pt-expt, the call to backend_obj.neighbor_stat will crash.
  • The convert_backend entrypoint (deepmd/entrypoints/convert_backend.py:28-29) detects backends by file suffix. Since pt-expt has suffix .pte, loading or saving .pte files will crash when accessing serialize_hook or deserialize_hook.

Either remove these feature flags until the methods are implemented, or provide implementations.

🤖 Prompt for AI Agents
In `@deepmd/backend/pt_expt.py` around lines 37 - 42, The class-level features
tuple declares Backend.Feature.DEEP_EVAL, NEIGHBOR_STAT, and IO but the
corresponding properties deep_eval, neighbor_stat, serialize_hook, and
deserialize_hook raise NotImplementedError; either remove those three flags from
the features variable so the backend is not advertised for those entrypoints, or
implement the missing methods (deep_eval(), neighbor_stat(),
serialize_hook()/deserialize_hook()) to provide the expected behavior; update
the class-level features tuple (features) and/or implement the named
properties/methods (deep_eval, neighbor_stat, serialize_hook, deserialize_hook)
accordingly so get_backends_by_feature and convert_backend no longer hit
NotImplementedError.

Comment on lines +889 to +896
        obj = cls(**data)
        # Reinitialize layers from serialized data, using the same layer type
        # that __init__ created (respects subclass overrides via MRO).
        layer_type = type(obj.layers[0])
        obj.layers = type(obj.layers)(
            [layer_type.deserialize(layer) for layer in layers]
        )
        return obj
⚠️ Potential issue | 🟡 Minor

IndexError when neuron is empty in EmbeddingNet.deserialize.

If neuron=[], obj.layers will be an empty list after cls(**data), so type(obj.layers[0]) on line 892 raises IndexError. The same pattern in FittingNet.deserialize is safe because the output layer always exists.

🐛 Proposed fix: guard against empty layers
         obj = cls(**data)
         # Reinitialize layers from serialized data, using the same layer type
         # that __init__ created (respects subclass overrides via MRO).
-        layer_type = type(obj.layers[0])
-        obj.layers = type(obj.layers)(
-            [layer_type.deserialize(layer) for layer in layers]
-        )
+        if obj.layers:
+            layer_type = type(obj.layers[0])
+            obj.layers = type(obj.layers)(
+                [layer_type.deserialize(layer) for layer in layers]
+            )
         return obj
🤖 Prompt for AI Agents
In `@deepmd/dpmodel/utils/network.py` around lines 889 - 896,
EmbeddingNet.deserialize raises IndexError when obj.layers is empty because it
accesses obj.layers[0]; change the logic in EmbeddingNet.deserialize to first
check if obj.layers and layers are non-empty before reading type(obj.layers[0]).
If both are non-empty, keep the existing behavior (layer_type =
type(obj.layers[0]) and reconstruct obj.layers =
type(obj.layers)([layer_type.deserialize(layer) for layer in layers]));
otherwise, set obj.layers = type(obj.layers)([]) or simply use
type(obj.layers)([layer.deserialize(...) for layer in layers]) without indexing
when layers contains serialized entries, and when both are empty leave
obj.layers as an empty sequence. Ensure you update the code path in
EmbeddingNet.deserialize (the block that defines layer_type and assigns
obj.layers) to implement this guard.
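The pattern under discussion can be sketched in isolation, with the empty-layers guard folded in. `Layer`, `Net`, `TorchLayer`, and `TorchNet` are illustrative stand-ins for the dpmodel/pt_expt classes:

```python
class Layer:
    """Stand-in for NativeLayer."""

    def __init__(self, w=0):
        self.w = w

    def serialize(self):
        return {"w": self.w}

    @classmethod
    def deserialize(cls, data):
        return cls(**data)


class Net:
    """Stand-in for EmbeddingNet; layer_cls models the subclass-chosen layer type."""

    layer_cls = Layer

    def __init__(self, n_layers=2):
        self.layers = [self.layer_cls(i) for i in range(n_layers)]

    def serialize(self):
        return {
            "n_layers": len(self.layers),
            "layers": [layer.serialize() for layer in self.layers],
        }

    @classmethod
    def deserialize(cls, data):
        data = dict(data)
        layers = data.pop("layers")
        obj = cls(**data)
        if obj.layers:  # guard: an empty net keeps its empty layer list
            layer_type = type(obj.layers[0])  # respects subclass overrides
            obj.layers = [layer_type.deserialize(ld) for ld in layers]
        return obj


class TorchLayer(Layer):
    """Stand-in for a pt_expt torch-backed layer."""


class TorchNet(Net):
    layer_cls = TorchLayer
```

Round-tripping a `TorchNet` yields `TorchLayer` instances (the property the refactor relies on), and deserializing an empty net no longer indexes past the end.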

Comment on lines 546 to 566
    def eval_pt_expt(self, pt_expt_obj: Any) -> Any:
        pt_expt_obj.compute_input_stats(
            [
                {
                    "r0": None,
                    "coord": torch.from_numpy(self.coords)
                    .reshape(-1, self.natoms[0], 3)
                    .to(env.DEVICE),
                    "atype": torch.from_numpy(self.atype.reshape(1, -1)).to(env.DEVICE),
                    "box": torch.from_numpy(self.box.reshape(1, 3, 3)).to(env.DEVICE),
                    "natoms": self.natoms[0],
                }
            ]
        )
        return self.eval_pt_expt_descriptor(
            pt_expt_obj,
            self.natoms,
            self.coords,
            self.atype,
            self.box,
        )
⚠️ Potential issue | 🟡 Minor

eval_pt_expt in TestSeAStat depends on env.DEVICE which is only imported under INSTALLED_PT.

env is imported at line 33 inside if INSTALLED_PT:, but eval_pt_expt (line 553) references env.DEVICE. If INSTALLED_PT_EXPT is True while INSTALLED_PT is False, this will raise a NameError. In practice this is unlikely since both backends require torch, but it's a latent coupling.

Consider importing PT_DEVICE from the common descriptor test helper (which guards on INSTALLED_PT or INSTALLED_PT_EXPT) or adding a local fallback:

 if INSTALLED_PT_EXPT:
     from deepmd.pt_expt.descriptor.se_e2_a import DescrptSeA as DescrptSeAPTExpt
+    if not INSTALLED_PT:
+        from deepmd.pt.utils import env
 else:
     DescrptSeAPTExpt = None
🤖 Prompt for AI Agents
In `@source/tests/consistent/descriptor/test_se_e2_a.py` around lines 546 - 566,
The test method eval_pt_expt uses env.DEVICE which is only imported when
INSTALLED_PT is true, causing a potential NameError when INSTALLED_PT_EXPT is
true but INSTALLED_PT is false; change the test to use a guarded common symbol
(e.g. PT_DEVICE) or add a local fallback: import or reference PT_DEVICE from the
shared descriptor test helper (which is available when INSTALLED_PT or
INSTALLED_PT_EXPT is true) or resolve DEVICE at the top of this test module with
a conditional that falls back to a safe default before calling eval_pt_expt, and
replace direct uses of env.DEVICE in eval_pt_expt with PT_DEVICE (or the
fallback symbol) to remove the hidden dependency on env and INSTALLED_PT.
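One way to express the fallback is to resolve the device symbol once at module import with a safe default, instead of referencing `env.DEVICE` conditionally inside the test method. The `"cpu"` default here is an assumption:

```python
try:
    # Present only when the pt backend (and torch) is importable.
    from deepmd.pt.utils.env import DEVICE as PT_DEVICE
except ImportError:
    PT_DEVICE = "cpu"  # safe default when only pt_expt is available
```

Either branch leaves `PT_DEVICE` defined, so `eval_pt_expt` can use it unconditionally and the hidden `INSTALLED_PT` coupling disappears.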

@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@deepmd/dpmodel/utils/type_embed.py`:
- Around line 210-222: The random array generation uses
np.random.default_rng().random(..., dtype=PRECISION_DICT[self.precision]) which
fails for float16/half; change the generation to always request an ndarray with
a supported NumPy dtype (e.g., np.float32 or np.float64) and then cast via
xp.asarray(..., dtype=first_layer_matrix.dtype,
device=array_api_compat.device(first_layer_matrix)) before concatenation—update
the code paths involving extend_type_params, PRECISION_DICT/self.precision,
xp.asarray, first_layer_matrix, and xp.concat accordingly; optionally, if
reproducibility is required, create or accept a seeded RNG instance instead of
calling np.random.default_rng() inline.
🧹 Nitpick comments (2)
deepmd/dpmodel/utils/network.py (2)

813-850: Mutable default argument for neuron.

Ruff B006 flags neuron: list[int] = [24, 48, 96]. While the list isn't mutated here (a new list is built each time), the same pattern exists in the old factory function, so this is pre-existing. Consider switching to None with an internal default if you want to clean it up.

♻️ Suggested fix
     def __init__(
         self,
         in_dim: int,
-        neuron: list[int] = [24, 48, 96],
+        neuron: list[int] | None = None,
         activation_function: str = "tanh",
         resnet_dt: bool = False,
         precision: str = DEFAULT_PRECISION,
         seed: int | list[int] | None = None,
         bias: bool = True,
         trainable: bool | list[bool] = True,
     ) -> None:
+        if neuron is None:
+            neuron = [24, 48, 96]
         layers = []

1035-1078: Same mutable default neuron (B006) — same optional cleanup applies.

Same Ruff B006 as EmbeddingNet. The else: pass on lines 1051-1052 is a no-op and could be dropped for clarity, but this mirrors the existing factory code.

Comment on lines +210 to 222
        # Create random params with same dtype and device as first_layer_matrix
        extend_type_params = np.random.default_rng().random(
            [len(type_map), first_layer_matrix.shape[-1]],
            dtype=PRECISION_DICT[self.precision],
        )
        extend_type_params = xp.asarray(
            extend_type_params,
            dtype=first_layer_matrix.dtype,
            device=array_api_compat.device(first_layer_matrix),
        )
-       first_layer_matrix = np.concatenate(
+       first_layer_matrix = xp.concat(
            [first_layer_matrix, extend_type_params], axis=0
        )
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find PRECISION_DICT definition
rg "PRECISION_DICT\s*=" --type py -A 5

Repository: deepmodeling/deepmd-kit

Length of output: 3275


🏁 Script executed:

#!/bin/bash
# Check imports in type_embed.py
head -50 deepmd/dpmodel/utils/type_embed.py | grep -E "^import|^from"

Repository: deepmodeling/deepmd-kit

Length of output: 319


🏁 Script executed:

#!/bin/bash
# Check full imports from deepmd.dpmodel.common
sed -n '1,50p' deepmd/dpmodel/utils/type_embed.py | grep -A 10 "from deepmd.dpmodel.common import"

Repository: deepmodeling/deepmd-kit

Length of output: 306


🌐 Web query:

numpy random.default_rng().random() dtype support float16

💡 Result:

np.random.default_rng().random(..., dtype=...) does not support float16. The dtype argument for Generator.random is documented as supporting only float64 and float32 (native byte order). [1], [2]

Workaround: generate float32 (or float64) and cast:

rng = np.random.default_rng()
x16 = rng.random(size=..., dtype=np.float32).astype(np.float16)

(NumPy won’t generate float16 uniforms natively via random.) [1], [5]

Sources:
[1] NumPy docs: numpy.random.Generator.random (v1.20)
[2] NumPy docs: numpy.random.Generator.random (v2.4.dev)
[5] Stack Overflow discussion noting only float32/float64 are accepted


🏁 Script executed:

#!/bin/bash
# Check the class/method containing this code and how self.precision is used
sed -n '180,230p' deepmd/dpmodel/utils/type_embed.py

Repository: deepmodeling/deepmd-kit

Length of output: 2385


🏁 Script executed:

#!/bin/bash
# Search for float16 usage in tests or with TypeEmbedding
rg "precision.*float16|float16.*precision" deepmd/dpmodel --type py -B 2 -A 2

Repository: deepmodeling/deepmd-kit

Length of output: 728


Fix dtype handling in random generation to avoid runtime errors with float16 precision.

  1. np.random.default_rng().random(..., dtype=...) only supports float32 and float64. Passing PRECISION_DICT[self.precision] will raise a TypeError if precision is "float16" or "half" since those map to np.float16, which is unsupported. Generate as float32/float64 and cast if needed, or use first_layer_matrix.dtype directly (which the code already converts to afterward).

  2. np.random.default_rng() creates an unseeded RNG on every call. If reproducibility is needed during type-map changes, consider passing a seed or using a seeded generator instance.
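The fix for point 1 looks like this in isolation, seeded only to show the reproducibility option from point 2:

```python
import numpy as np

rng = np.random.default_rng(seed=0)  # seeded generator for reproducibility
# Generator.random only accepts float32/float64; generate then cast.
x16 = rng.random((3, 4), dtype=np.float32).astype(np.float16)
```

The subsequent `xp.asarray(..., dtype=first_layer_matrix.dtype, ...)` conversion in the diff would make the intermediate cast a no-op when the target precision is float32/float64, so generating in float32 costs nothing in the common case.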

🤖 Prompt for AI Agents
In `@deepmd/dpmodel/utils/type_embed.py` around lines 210 - 222, The random array
generation uses np.random.default_rng().random(...,
dtype=PRECISION_DICT[self.precision]) which fails for float16/half; change the
generation to always request an ndarray with a supported NumPy dtype (e.g.,
np.float32 or np.float64) and then cast via xp.asarray(...,
dtype=first_layer_matrix.dtype,
device=array_api_compat.device(first_layer_matrix)) before concatenation—update
the code paths involving extend_type_params, PRECISION_DICT/self.precision,
xp.asarray, first_layer_matrix, and xp.concat accordingly; optionally, if
reproducibility is required, create or accept a seeded RNG instance instead of
calling np.random.default_rng() inline.

@codecov
codecov bot commented Feb 8, 2026

Codecov Report

❌ Patch coverage is 98.19820% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.03%. Comparing base (5c2ca51) to head (ad83d98).

Files with missing lines:
  • deepmd/pt_expt/common.py — 92.85% patch coverage, 2 lines missing ⚠️
  • deepmd/pt_expt/utils/type_embed.py — 88.88% patch coverage, 2 lines missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5207      +/-   ##
==========================================
+ Coverage   81.99%   82.03%   +0.03%     
==========================================
  Files         724      728       +4     
  Lines       73807    73943     +136     
  Branches     3616     3615       -1     
==========================================
+ Hits        60519    60659     +140     
+ Misses      12124    12121       -3     
+ Partials     1164     1163       -1     

☔ View full report in Codecov by Sentry.
