Skip to content

Conversation

@wanghan-iapcm
Copy link
Collaborator

@wanghan-iapcm wanghan-iapcm commented Feb 12, 2026

Summary by CodeRabbit

  • Bug Fixes

    • Improved tensor/device and dtype handling for consistent behavior across NumPy and PyTorch backends.
    • Fixed deserialization when layer collections are empty to avoid errors.
  • New Features

    • Added experimental PyTorch fitting wrappers: EnergyFittingNet and InvarFitting for tensor-based workflows and export/tracing.
    • Renamed descriptor registration keys to streamlined identifiers.
    • Package exports updated to surface new fitting entry points.
  • Tests

    • Added extensive tests for energy/invariant fitting, statistics computation, and PyTorch export compatibility.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @wanghan-iapcm, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the pt_expt backend by integrating energy fitting functionalities. It introduces dedicated PyTorch experimental classes for energy and invariant fitting, ensuring seamless operation with existing DeepMD models. The changes also improve backend compatibility for tensor operations and refine descriptor naming conventions, all supported by a robust suite of new unit tests to guarantee correctness and stability.

Highlights

  • PyTorch Experimental (pt_expt) Energy Fitting: Introduced new energy fitting capabilities within the pt_expt backend by implementing EnergyFittingNet and InvarFitting classes. These classes adapt the existing dpmodel fitting functionalities to the experimental PyTorch environment, enabling energy calculations and related operations.
  • Backend Compatibility and Tensor Handling: Updated general_fitting.py to use array_api_compat for handling tensor operations (fparam_avg, aparam_avg, etc.) consistently across different array backends (NumPy and PyTorch). This ensures that tensors are created and manipulated on the correct device and with the appropriate data types.
  • Descriptor Naming Refinement: Refactored the registration names for several descriptors (se_e2_a, se_r, se_t, se_t_tebd) in the pt_expt module by removing the _expt suffix, indicating a more stable or integrated status for these descriptors.
  • Comprehensive Unit Testing: Added extensive unit tests for the new pt_expt fitting modules, covering self-consistency of serialization/deserialization, correct handling of input parameters (fparam, aparam), mask functionality, exception handling, and compatibility with torch.export for model tracing and deployment. New tests also validate the compute_input_stats method for statistical calculations.
Changelog
  • deepmd/dpmodel/fitting/general_fitting.py
    • Updated tensor assignment for fparam_avg, fparam_inv_std, aparam_avg, and aparam_inv_std to use array_api_compat.asarray for improved backend compatibility.
    • Added device specification to xp.zeros call to ensure tensors are created on the correct device.
  • deepmd/dpmodel/utils/network.py
    • Added a conditional check to FittingNet.deserialize to handle cases where obj.layers might be empty during deserialization.
  • deepmd/pt_expt/descriptor/se_e2_a.py
    • Renamed descriptor registration from se_e2_a_expt to se_e2_a and se_a_expt to se_a.
  • deepmd/pt_expt/descriptor/se_r.py
    • Renamed descriptor registration from se_e2_r_expt to se_e2_r and se_r_expt to se_r.
  • deepmd/pt_expt/descriptor/se_t.py
    • Renamed descriptor registration from se_e3_expt to se_e3, se_at_expt to se_at, and se_a_3be_expt to se_a_3be.
  • deepmd/pt_expt/descriptor/se_t_tebd.py
    • Renamed descriptor registration from se_e3_tebd_expt to se_e3_tebd.
  • deepmd/pt_expt/fitting/init.py
    • Added __init__.py file to define the pt_expt.fitting module.
    • Exported BaseFitting, EnergyFittingNet, and InvarFitting.
  • deepmd/pt_expt/fitting/base_fitting.py
    • Added base_fitting.py to define BaseFitting for the pt_expt backend, leveraging torch.Tensor.
  • deepmd/pt_expt/fitting/ener_fitting.py
    • Added ener_fitting.py to implement EnergyFittingNet for pt_expt, inheriting from dpmodel and torch.nn.Module.
    • Implemented __init__, __call__, __setattr__, and forward methods for PyTorch compatibility.
    • Registered EnergyFittingNet for deserialization mapping from dpmodel.
  • deepmd/pt_expt/fitting/invar_fitting.py
    • Added invar_fitting.py to implement InvarFitting for pt_expt, inheriting from dpmodel and torch.nn.Module.
    • Implemented __init__, __call__, __setattr__, and forward methods for PyTorch compatibility.
    • Registered InvarFitting for deserialization mapping from dpmodel.
  • source/tests/consistent/fitting/test_ener.py
    • Imported INSTALLED_PT_EXPT and EnerFittingPTExpt for experimental PyTorch backend testing.
    • Added skip_pt_expt property to conditionally skip tests for the pt_expt backend.
    • Added pt_expt_class attribute to reference the experimental PyTorch fitting class.
    • Implemented eval_pt_expt method to evaluate pt_expt energy fitting models.
    • Added TestEnerStat class to test energy fitting statistics, including compute_input_stats for pt_expt.
  • source/tests/pt_expt/fitting/init.py
    • Added an empty __init__.py file to define the pt_expt.fitting test module.
  • source/tests/pt_expt/fitting/test_fitting_ener_fitting.py
    • Added new unit tests for EnergyFittingNet in pt_expt, covering self-consistency, serialization, and torch.export functionality.
  • source/tests/pt_expt/fitting/test_fitting_invar_fitting.py
    • Added new unit tests for InvarFitting in pt_expt, covering self-consistency, mask behavior, exception handling, property get/set, and torch.export functionality.
  • source/tests/pt_expt/fitting/test_fitting_stat.py
    • Added new unit tests for EnergyFittingNet's compute_input_stats method in pt_expt, verifying statistical calculations.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@dosubot dosubot bot added the new feature label Feb 12, 2026
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces energy fitting capabilities for the pt_expt backend by adding EnergyFittingNet and InvarFitting modules. The changes also include making the base GeneralFitting class more backend-agnostic using array_api_compat and adding comprehensive tests for the new functionality. The implementation is solid and aligns well with the existing design of the pt_expt backend. I have one suggestion regarding code duplication in a new test class to enhance maintainability.

Copy link
Contributor

@github-advanced-security github-advanced-security bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CodeQL found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Feb 12, 2026

📝 Walkthrough

Walkthrough

Adds a PyTorch-experimental fitting backend (EnergyFittingNet, InvarFitting), fixes device/dtype handling in dpmodel fitting utilities, updates several descriptor registration keys (removing "_expt" suffixes), and adds extensive PT_EXPT-focused tests and package exports.

Changes

Cohort / File(s) Summary
DPModel fitting utilities
deepmd/dpmodel/fitting/general_fitting.py, deepmd/dpmodel/utils/network.py
Use array_api_compat (xp.asarray, device(...)) for fparam/aparam conversions and output placement; add guard for empty layers in FittingNet.deserialize. Review device/dtype conversions and serialization edge-case.
PT_EXPT descriptor registrations
deepmd/pt_expt/descriptor/se_e2_a.py, deepmd/pt_expt/descriptor/se_r.py, deepmd/pt_expt/descriptor/se_t.py, deepmd/pt_expt/descriptor/se_t_tebd.py
Remove _expt suffixes from BaseDescriptor.register keys (canonicalize registration identifiers).
PT_EXPT fitting implementation
deepmd/pt_expt/fitting/__init__.py, deepmd/pt_expt/fitting/base_fitting.py, deepmd/pt_expt/fitting/ener_fitting.py, deepmd/pt_expt/fitting/invar_fitting.py
Add PT_EXPT wrappers EnergyFittingNet and InvarFitting (subclass DP models + torch.nn.Module), convert NetworkCollection via serialize/deserialize, override __call__/__setattr__, and register dpmodel mappings and package exports. Pay attention to module initialization and attribute routing (dpmodel_setattr).
Tests — PT_EXPT and integration
source/tests/pt_expt/fitting/..., source/tests/consistent/fitting/test_ener.py, source/tests/common/dpmodel/test_fitting_invar_fitting.py
Add comprehensive PT_EXPT tests (serialization, torch.export, stats, masking, exceptions) and extend existing energy tests to exercise PT_EXPT. Review test device/precision guards and CI impact due to large test additions.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant EnergyFittingNet
    participant NetworkCollection
    participant EnergyFittingNetDP
    participant torchModule

    User->>EnergyFittingNet: __init__(args, kwargs)
    EnergyFittingNet->>EnergyFittingNetDP: super().__init__(...)
    EnergyFittingNetDP->>NetworkCollection: self.nets.serialize()
    NetworkCollection-->>EnergyFittingNet: serialized_data
    EnergyFittingNet->>NetworkCollection: NetworkCollection.deserialize(serialized_data)
    NetworkCollection-->>EnergyFittingNet: pt_expt_nets
    EnergyFittingNet->>EnergyFittingNet: self.nets = pt_expt_nets

    User->>EnergyFittingNet: __call__(descriptor, atype, ...)
    EnergyFittingNet->>torchModule: torch.nn.Module.__call__(...)
    torchModule->>EnergyFittingNet: forward(...)
    EnergyFittingNet->>EnergyFittingNetDP: call(descriptor, atype, ...)
    EnergyFittingNetDP-->>EnergyFittingNet: results (tensors)
    EnergyFittingNet-->>User: dict[str, torch.Tensor]
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

Suggested reviewers

  • iProzd
  • njzjz
  • anyangml
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 22.03% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main objective of the PR: adding fitting functionality for energy in the pt_expt backend, which is the central theme across all file changes.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

No actionable comments were generated in the recent review. 🎉

🧹 Recent nitpick comments
source/tests/pt_expt/fitting/test_fitting_invar_fitting.py (1)

273-316: Consider adding a torch.export test with aparam.

You have export tests for the no-param and fparam cases, but no export test with aparam. Since aparam has a different shape (per-atom vs. per-frame), it exercises a distinct code path in the forward signature and may surface export issues that fparam alone would not.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@source/tests/pt_expt/fitting/test_fitting_invar_fitting.py`:
- Around line 168-201: The assertIn checks are never executed because they are
placed inside the with self.assertRaises(ValueError) blocks after the call that
raises; move each self.assertIn(...) outside its corresponding with block and
use the captured context.exception (e.g. context.exception) to assert the error
message for the calls to ifn0 (first block using dd[0][:,:,:-2], second block
using dd[0] with ifp, third block using dd[0] with iap); specifically, keep the
call that raises inside with self.assertRaises(...) as context and then
immediately after the with block run self.assertIn("expected substring",
str(context.exception)) for each case so the message checks execute.
🧹 Nitpick comments (1)
source/tests/pt_expt/fitting/test_fitting_stat.py (1)

47-72: _brute_fparam_pt and _brute_aparam_pt are near-duplicates.

These two helpers differ only in the dictionary key ("fparam" vs "aparam"). Consider merging into a single function with a key parameter.

♻️ Suggested consolidation
-def _brute_fparam_pt(data, ndim):
-    adata = [ii["fparam"] for ii in data]
-    all_data = []
-    for ii in adata:
-        tmp = np.reshape(ii, [-1, ndim])
-        if len(all_data) == 0:
-            all_data = np.array(tmp)
-        else:
-            all_data = np.concatenate((all_data, tmp), axis=0)
-    avg = np.average(all_data, axis=0)
-    std = np.std(all_data, axis=0)
-    return avg, std
-
-
-def _brute_aparam_pt(data, ndim):
-    adata = [ii["aparam"] for ii in data]
-    all_data = []
-    for ii in adata:
-        tmp = np.reshape(ii, [-1, ndim])
-        if len(all_data) == 0:
-            all_data = np.array(tmp)
-        else:
-            all_data = np.concatenate((all_data, tmp), axis=0)
-    avg = np.average(all_data, axis=0)
-    std = np.std(all_data, axis=0)
-    return avg, std
+def _brute_param_pt(data, ndim, key):
+    adata = [ii[key] for ii in data]
+    all_data = []
+    for ii in adata:
+        tmp = np.reshape(ii, [-1, ndim])
+        if len(all_data) == 0:
+            all_data = np.array(tmp)
+        else:
+            all_data = np.concatenate((all_data, tmp), axis=0)
+    avg = np.average(all_data, axis=0)
+    std = np.std(all_data, axis=0)
+    return avg, std

@codecov
Copy link

codecov bot commented Feb 12, 2026

Codecov Report

❌ Patch coverage is 98.50746% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 82.06%. Comparing base (2901448) to head (d057ca1).

Files with missing lines Patch % Lines
deepmd/dpmodel/utils/network.py 75.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5218      +/-   ##
==========================================
+ Coverage   82.03%   82.06%   +0.03%     
==========================================
  Files         728      732       +4     
  Lines       73922    73974      +52     
  Branches     3615     3615              
==========================================
+ Hits        60640    60709      +69     
+ Misses      12119    12100      -19     
- Partials     1163     1165       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Feb 12, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants