Skip to content

Conversation

@njzjz
Copy link
Member

@njzjz njzjz commented Feb 9, 2026

Use wraps to keep the modules' names, so they won't be FlaxModule, which cannot be regonized. I realized it when implementing #5213.

Summary by CodeRabbit

  • Refactor
    • Improved type safety and metadata handling for JAX module processing.

Use `wraps` to keep the modules' names, so they won't be `FlaxModule`, which cannot be regonized. I realized it when implementing deepmodeling#5213.
Copilot AI review requested due to automatic review settings February 9, 2026 21:48
@gemini-code-assist
Copy link

Summary of Changes

Hello @njzjz, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses an issue where dynamically generated JAX/Flax modules were losing their original class names, making them difficult to identify and debug. By leveraging the functools.wraps decorator, the change ensures that these modules correctly inherit and display the names of their underlying native operations. This improvement enhances the clarity and traceability of module instances, which is particularly beneficial for integration with other features like #5213.

Highlights

  • Import wraps: The wraps decorator from the functools module has been imported to facilitate preserving metadata of wrapped functions/classes.
  • Type Hint Refinement: The type hints for the flax_module function's module parameter and return type have been updated from NativeOP to type[NativeOP] and nnx.Module to type[nnx.Module] respectively, indicating that the function now operates on and returns module types.
  • Preserving Module Names: The @wraps(module, updated=()) decorator has been applied to the dynamically created FlaxModule class within the flax_module function. This ensures that the generated Flax modules retain the original names of the NativeOP classes they wrap, preventing them from being generically named FlaxModule.
Changelog
  • deepmd/jax/common.py
    • Added import for wraps from functools.
    • Modified type annotations for the flax_module function to correctly reflect that it handles and returns module types.
    • Applied the @wraps decorator to the FlaxModule class definition to preserve the original module's name.
Activity
  • No human activity (comments, reviews, etc.) has been recorded on this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@github-actions github-actions bot added the Python label Feb 9, 2026
Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request effectively addresses the issue of JAX modules not retaining their original names by correctly utilizing the functools.wraps decorator. The changes to type hints for the flax_module function are also a significant improvement, enhancing type correctness and clarity. The use of updated=() in the wraps decorator is a good practice when wrapping classes, preventing unintended overwrites of the wrapper's __dict__ while still copying essential metadata like __name__ and __module__ from the original class.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Feb 9, 2026

📝 Walkthrough

Walkthrough

Type annotations for the flax_module function are updated to accept and return type objects rather than instances. A wraps decorator is added to preserve module metadata during transformation.

Changes

Cohort / File(s) Summary
Type Hints & Metadata Preservation
deepmd/jax/common.py
Updated flax_module function signature to clarify parameter and return types (from instance types to type objects). Added functools.wraps import and applied @wraps decorator to preserve original module metadata.

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~5 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 33.33% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately summarizes the main change: using functools.wraps to preserve JAX/Flax module names instead of replacing them with the generic 'FlaxModule' name.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

No actionable comments were generated in the recent review. 🎉


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Updates the JAX backend’s flax_module decorator so wrapped nnx.Module classes retain the original module/class metadata, avoiding generic FlaxModule names that can interfere with recognition/introspection.

Changes:

  • Add functools.wraps(updated=()) to preserve the wrapped class’s __name__/__qualname__/__module__/__doc__.
  • Refine flax_module type annotations and docstring types to reflect class-in/class-out behavior.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 50 to 60
"""Convert a NativeOP to a Flax module.

Parameters
----------
module : NativeOP
module : type[NativeOP]
The NativeOP to convert.

Returns
-------
flax.nnx.Module
type[flax.nnx.Module]
The Flax module.
Copy link

Copilot AI Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring still reads like module is an instance ("The NativeOP to convert"), but the function is used as a class decorator and now expects a class object. Updating the parameter/return descriptions to explicitly say "NativeOP subclass" / "Flax nnx.Module subclass" would avoid confusion for readers.

Copilot uses AI. Check for mistakes.
Comment on lines +78 to 80
@wraps(module, updated=())
class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass):
def __init_subclass__(cls, **kwargs: Any) -> None:
Copy link

Copilot AI Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change fixes a subtle introspection/serialization issue (module names no longer being FlaxModule). There doesn't appear to be a regression test that asserts @flax_module preserves __name__/__qualname__ (and avoids FlaxModule). Adding a small JAX unit test would help prevent future breakage.

Copilot uses AI. Check for mistakes.
Comment on lines 47 to +49
def flax_module(
module: NativeOP,
) -> nnx.Module:
module: type[NativeOP],
) -> type[nnx.Module]:
Copy link

Copilot AI Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type annotation type[nnx.Module] loses the original class type (the returned class is also a subclass of the input module). This makes downstream typing inconsistent (e.g., decorated DPModel classes no longer type as their original DPModel base). Consider using a TypeVar so flax_module is typed as def flax_module(module: type[T]) -> type[T]: ... (and cast the generated class), optionally adding an nnx.Module Protocol/mixin if you want to preserve both facets.

Copilot uses AI. Check for mistakes.
@codecov
Copy link

codecov bot commented Feb 9, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.00%. Comparing base (97d8ded) to head (51985af).

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #5214   +/-   ##
=======================================
  Coverage   82.00%   82.00%           
=======================================
  Files         724      724           
  Lines       73801    73803    +2     
  Branches     3616     3615    -1     
=======================================
+ Hits        60520    60522    +2     
+ Misses      12120    12118    -2     
- Partials     1161     1163    +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@njzjz njzjz requested a review from wanghan-iapcm February 10, 2026 02:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant