-
Notifications
You must be signed in to change notification settings - Fork 590
feat(pt_expt): add fitting for energy #5218
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Summary of ChangesHello @wanghan-iapcm, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the Highlights
Changelog
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces energy fitting capabilities for the pt_expt backend by adding EnergyFittingNet and InvarFitting modules. The changes also include making the base GeneralFitting class more backend-agnostic using array_api_compat and adding comprehensive tests for the new functionality. The implementation is solid and aligns well with the existing design of the pt_expt backend. I have one suggestion regarding code duplication in a new test class to enhance maintainability.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CodeQL found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.
📝 WalkthroughWalkthroughAdds a PyTorch-experimental fitting backend (EnergyFittingNet, InvarFitting), fixes device/dtype handling in dpmodel fitting utilities, updates several descriptor registration keys (removing "_expt" suffixes), and adds extensive PT_EXPT-focused tests and package exports. Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant EnergyFittingNet
participant NetworkCollection
participant EnergyFittingNetDP
participant torchModule
User->>EnergyFittingNet: __init__(args, kwargs)
EnergyFittingNet->>EnergyFittingNetDP: super().__init__(...)
EnergyFittingNetDP->>NetworkCollection: self.nets.serialize()
NetworkCollection-->>EnergyFittingNet: serialized_data
EnergyFittingNet->>NetworkCollection: NetworkCollection.deserialize(serialized_data)
NetworkCollection-->>EnergyFittingNet: pt_expt_nets
EnergyFittingNet->>EnergyFittingNet: self.nets = pt_expt_nets
User->>EnergyFittingNet: __call__(descriptor, atype, ...)
EnergyFittingNet->>torchModule: torch.nn.Module.__call__(...)
torchModule->>EnergyFittingNet: forward(...)
EnergyFittingNet->>EnergyFittingNetDP: call(descriptor, atype, ...)
EnergyFittingNetDP-->>EnergyFittingNet: results (tensors)
EnergyFittingNet-->>User: dict[str, torch.Tensor]
Estimated code review effort🎯 4 (Complex) | ⏱️ ~50 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
No actionable comments were generated in the recent review. 🎉 🧹 Recent nitpick comments
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@source/tests/pt_expt/fitting/test_fitting_invar_fitting.py`:
- Around line 168-201: The assertIn checks are never executed because they are
placed inside the with self.assertRaises(ValueError) blocks after the call that
raises; move each self.assertIn(...) outside its corresponding with block and
use the captured context.exception (e.g. context.exception) to assert the error
message for the calls to ifn0 (first block using dd[0][:,:,:-2], second block
using dd[0] with ifp, third block using dd[0] with iap); specifically, keep the
call that raises inside with self.assertRaises(...) as context and then
immediately after the with block run self.assertIn("expected substring",
str(context.exception)) for each case so the message checks execute.
🧹 Nitpick comments (1)
source/tests/pt_expt/fitting/test_fitting_stat.py (1)
47-72:_brute_fparam_ptand_brute_aparam_ptare near-duplicates.These two helpers differ only in the dictionary key (
"fparam"vs"aparam"). Consider merging into a single function with akeyparameter.♻️ Suggested consolidation
-def _brute_fparam_pt(data, ndim): - adata = [ii["fparam"] for ii in data] - all_data = [] - for ii in adata: - tmp = np.reshape(ii, [-1, ndim]) - if len(all_data) == 0: - all_data = np.array(tmp) - else: - all_data = np.concatenate((all_data, tmp), axis=0) - avg = np.average(all_data, axis=0) - std = np.std(all_data, axis=0) - return avg, std - - -def _brute_aparam_pt(data, ndim): - adata = [ii["aparam"] for ii in data] - all_data = [] - for ii in adata: - tmp = np.reshape(ii, [-1, ndim]) - if len(all_data) == 0: - all_data = np.array(tmp) - else: - all_data = np.concatenate((all_data, tmp), axis=0) - avg = np.average(all_data, axis=0) - std = np.std(all_data, axis=0) - return avg, std +def _brute_param_pt(data, ndim, key): + adata = [ii[key] for ii in data] + all_data = [] + for ii in adata: + tmp = np.reshape(ii, [-1, ndim]) + if len(all_data) == 0: + all_data = np.array(tmp) + else: + all_data = np.concatenate((all_data, tmp), axis=0) + avg = np.average(all_data, axis=0) + std = np.std(all_data, axis=0) + return avg, std
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5218 +/- ##
==========================================
+ Coverage 82.03% 82.06% +0.03%
==========================================
Files 728 732 +4
Lines 73922 73974 +52
Branches 3615 3615
==========================================
+ Hits 60640 60709 +69
+ Misses 12119 12100 -19
- Partials 1163 1165 +2 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Summary by CodeRabbit
Bug Fixes
New Features
Tests