
Conversation

@iProzd
Collaborator

@iProzd iProzd commented Feb 11, 2026

Summary by CodeRabbit

  • New Features
    • Introduces a modular PyTorch training system with pluggable loops, extensible hooks (TensorBoard, timing, early stopping), and unified logging.
    • Adds robust checkpointing with resume and fine-tuning workflows.
    • Enables multi-task training and configurable data loaders/adapters.
    • Provides an optimizer factory (e.g., Adam/AdamW/LKF/AdaMuon/HybridMuon) with LR scheduling.
    • Allows switching between legacy and new trainer via a flag.
  • Documentation
    • Adds package-level documentation and consolidated public API.
  • Tests
    • Adds comprehensive unit and integration tests across configs, hooks, optimizers, logging, and end-to-end training.

Copilot AI review requested due to automatic review settings February 11, 2026 15:03
@iProzd iProzd marked this pull request as draft February 11, 2026 15:03
@gemini-code-assist
Contributor

Summary of Changes

Hello @iProzd, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the PyTorch training code into a highly modular and extensible system. The primary goal is to improve the organization and maintainability of the training pipeline by breaking down monolithic logic into specialized components for data management, checkpointing, logging, and optimizer handling. The new design also introduces a flexible hook system for custom callbacks and ensures backward compatibility with the previous trainer implementation, facilitating a smooth transition to the new architecture.

Highlights

  • Modular Training System: Introduced a new, modular PyTorch training system that decouples core functionalities into specialized components.
  • Component-Based Design: Refactored training logic into distinct components such as CheckpointManager, DataManager, HookManager, OptimizerFactory, and TrainingLoop for improved organization and maintainability.
  • Extensibility via Hooks: Implemented a flexible hook system, allowing users to inject custom logic at various points during the training process without modifying core code (a minimal hook sketch follows this list).
  • Centralized Configuration Management: Centralized training configuration using Python dataclasses (TrainingConfig, OptimizerConfig, LearningRateConfig, etc.) for clear and validated settings.
  • Backward Compatibility: Ensured backward compatibility with the legacy trainer by introducing a use_legacy flag and aliasing the old Trainer as LegacyTrainer during the transition.
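
To make the hook extension point concrete, here is a minimal sketch of a custom callback under this design. The class and method names (TrainingHook, HookManager, on_step_end) come from the components listed in this PR; the import path, registration method name, and constructor arguments are assumptions rather than the PR's exact API.

```python
from deepmd.pt.train import HookManager, TrainingHook  # import path assumed from the new __init__.py exports


class LossPrintHook(TrainingHook):
    """Hypothetical hook that prints the running loss every `every` steps."""

    def __init__(self, every: int = 100) -> None:
        self.every = every

    def on_step_end(self, step: int, logs: dict | None = None) -> None:
        # `logs` is assumed to carry the per-step metrics dict passed by the Trainer.
        if logs and step % self.every == 0:
            print(f"step {step}: loss={logs.get('loss')}")


manager = HookManager()
manager.register(LossPrintHook(every=50))  # registration method name is an assumption
```
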
Changelog
  • deepmd/pt/entrypoints/main.py
    • Imported NewTrainer from the refactored training module.
    • Added a use_legacy boolean parameter to the get_trainer function.
    • Modified get_trainer to conditionally instantiate either the new Trainer or the legacy training.Trainer based on the use_legacy flag.
  • deepmd/pt/train/__init__.py
    • Expanded module imports to expose new modular training components, including CheckpointManager, TrainingConfig, DataManager, HookManager, TrainingLogger, OptimizerFactory, Trainer, and TrainingLoopFactory.
    • Aliased the old Trainer as LegacyTrainer to maintain backward compatibility during the transition.
  • deepmd/pt/train/checkpoint_manager.py
    • Added a new module that manages model checkpoints, including saving, loading for resume/finetune, automatic cleanup of old checkpoints, and symlink management.
  • deepmd/pt/train/config.py
    • Added a new module defining dataclasses for structured training configuration, such as OptimizerConfig, LearningRateConfig, DisplayConfig, CheckpointConfig, and the overarching TrainingConfig.
  • deepmd/pt/train/data_loader.py
    • Added a new module introducing an abstract DataLoader interface and a DpLoaderSetAdapter to make existing DpLoaderSet compatible with the new system, enabling future high-performance data loading implementations.
  • deepmd/pt/train/data_manager.py
    • Added a new module that manages training and validation data, handling data loading and batch iteration for both single-task and multi-task scenarios using the new AbstractDataLoader interface.
  • deepmd/pt/train/hooks.py
    • Added a new module implementing a flexible hook system for extensible training callbacks, including TrainingHook base class, HookManager, TensorBoardHook, and TimingHook.
  • deepmd/pt/train/logger.py
    • Added a new module providing utilities for formatted training log output to console and file, and a LossAccumulator for averaging metrics over multiple steps.
  • deepmd/pt/train/optimizer_factory.py
    • Added a new module that acts as a factory for creating optimizers and learning rate schedulers using a strategy pattern, supporting various optimizer types like Adam, AdamW, LKF, AdaMuon, and HybridMuon.
  • deepmd/pt/train/trainer.py
    • Added the new Trainer class, which orchestrates the entire training process by composing various modular components like data management, checkpointing, hooks, and training loops.
  • deepmd/pt/train/training_loop.py
    • Added a new module defining abstract and concrete training loops for different optimizer types, such as AdamTrainingLoop and LKFEnergyTrainingLoop, handling the core training step logic.
  • source/tests/pt/test_new_training.py
    • Added unit tests to verify the functionality and integration of the new modular training system components, including configuration parsing, optimizer factory, hooks, and loss accumulation.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

latest = Path(checkpoint_file.read_text().strip())
if latest.exists():
return latest
except Exception:

Check notice

Code scanning / CodeQL

Empty except Note

'except' clause does nothing but pass and there is no explanatory comment.

def __next__(self) -> dict[str, Any]:
"""Get next batch."""
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

def add_data_requirement(self, requirement: Any) -> None:
"""Add data requirements for labels."""
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

def preload_and_modify_all_data_torch(self) -> None:
"""Preload and apply modifiers to data."""
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

def print_summary(self, name: str, weights: Any = None) -> None:
"""Print dataset summary."""
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
@property
def systems(self) -> list[Any]:
"""Get list of systems/datasets."""
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
)

# Compute statistics
finetune_has_new_type_key = (

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable finetune_has_new_type_key is not used.
@dosubot dosubot bot added the enhancement label Feb 11, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant and well-executed refactoring of the PyTorch training code. The new modular architecture, which separates concerns into components like DataManager, CheckpointManager, HookManager, and TrainingLoopFactory, is a major improvement for maintainability, extensibility, and clarity. The use of modern design patterns is commendable. My review has identified a few issues, including a potential bug in the LKF training loop, an incomplete implementation for multi-task parameter sharing, and some opportunities for code cleanup and performance optimization. Overall, this is an excellent contribution to the codebase.

Comment on lines +271 to +277
_, loss, more_loss = module.loss[task_key](
{},
fake_model,
label_dict,
natoms,
learning_rate=pref_lr,
)
Contributor


Severity: high

The loss function module.loss[task_key] is called with an empty dictionary {} as the first argument for input_dict. The loss function's signature expects input_dict, and it might require information from it (e.g., flags indicating which properties are present). Passing an empty dictionary is likely a bug and could lead to incorrect loss calculation or runtime errors.

Suggested change
-_, loss, more_loss = module.loss[task_key](
-{},
-fake_model,
-label_dict,
-natoms,
-learning_rate=pref_lr,
-)
+_, loss, more_loss = module.loss[task_key](
+input_dict,
+fake_model,
+label_dict,
+natoms,
+learning_rate=pref_lr,
+)

Comment on lines +205 to +232
if use_legacy:
trainer = training.Trainer(
config,
train_data,
stat_file_path=stat_file_path,
validation_data=validation_data,
init_model=init_model,
restart_model=restart_model,
finetune_model=finetune_model,
force_load=force_load,
shared_links=shared_links,
finetune_links=finetune_links,
init_frz_model=init_frz_model,
)
else:
trainer = NewTrainer(
config,
train_data,
stat_file_path=stat_file_path,
validation_data=validation_data,
init_model=init_model,
restart_model=restart_model,
finetune_model=finetune_model,
force_load=force_load,
shared_links=shared_links,
finetune_links=finetune_links,
init_frz_model=init_frz_model,
)
Contributor


Severity: medium

This block of code for creating the trainer instance is duplicated for the use_legacy and else branches. This can be refactored to reduce code duplication and improve maintainability by selecting the class first and then instantiating it once.

    trainer_class = training.Trainer if use_legacy else NewTrainer
    trainer = trainer_class(
        config,
        train_data,
        stat_file_path=stat_file_path,
        validation_data=validation_data,
        init_model=init_model,
        restart_model=restart_model,
        finetune_model=finetune_model,
        force_load=force_load,
        shared_links=shared_links,
        finetune_links=finetune_links,
        init_frz_model=init_frz_model,
    )

else 0,
drop_last=False,
collate_fn=lambda batch: batch, # prevent extra conversion
pin_memory=False, # Batch processor handles device transfer
Contributor


Severity: medium

For optimal performance when training on a GPU, pin_memory should be set to True. This allows the DataLoader to place tensors in pinned (page-locked) memory, which enables faster, asynchronous data transfer to the GPU when using non_blocking=True in the .to() call, as is done in BatchProcessor.

Suggested change
-pin_memory=False, # Batch processor handles device transfer
+pin_memory=True, # Enable asynchronous data transfer

Comment on lines +664 to +671
model_prob = np.zeros(len(self.model_keys), dtype=np.float32)
for ii, model_key in enumerate(self.model_keys):
# Get training data size for this model
if hasattr(self, "data_manager") and self.data_manager:
# Try to get from data_manager
pass
# Use uniform probability for now
model_prob[ii] = 1.0
Contributor


Severity: medium

The logic for computing model_prob for multi-task parameter sharing appears to be incomplete, as it currently falls back to a uniform probability. For correct statistical merging of shared parameters, these probabilities should ideally be based on the size of the respective datasets for each task. The pass statement and the comment "Use uniform probability for now" suggest this is a placeholder.

Suggested change
-model_prob = np.zeros(len(self.model_keys), dtype=np.float32)
-for ii, model_key in enumerate(self.model_keys):
-# Get training data size for this model
-if hasattr(self, "data_manager") and self.data_manager:
-# Try to get from data_manager
-pass
-# Use uniform probability for now
-model_prob[ii] = 1.0
+model_prob = np.zeros(len(self.model_keys), dtype=np.float32)
+for ii, model_key in enumerate(self.model_keys):
+# Get training data size for this model
+if hasattr(self, "data_manager") and self.data_manager:
+# Base probability on the number of batches in each training loader
+model_prob[ii] = float(len(self.data_manager.training_loaders[model_key]))
+else:
+# Fallback to uniform probability if data_manager is not yet available
+model_prob[ii] = 1.0

Contributor

Copilot AI left a comment


Pull request overview

This PR introduces a new modular PyTorch training system (Trainer + component managers/factories) and wires it into the dp --pt train entrypoint, with accompanying unit and CLI-level tests.

Changes:

  • Added a new modular deepmd.pt.train.trainer.Trainer and supporting components (data manager/loader abstraction, optimizer factory, training loop strategies, checkpoint manager, hooks, logger, typed config).
  • Updated the PT entrypoint to optionally use the new trainer (use_legacy=False by default).
  • Added a new PT test suite validating config parsing, factories/hooks/logger utilities, and basic end-to-end trainer construction/data access.

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 15 comments.

Show a summary per file
| File | Description |
| --- | --- |
| source/tests/pt/test_new_training.py | Adds unit + end-to-end CLI tests for the new modular trainer components. |
| deepmd/pt/train/training_loop.py | Introduces optimizer-specific training loop implementations and a factory. |
| deepmd/pt/train/trainer.py | New modular Trainer orchestrating training, validation, hooks, logging, and checkpointing. |
| deepmd/pt/train/optimizer_factory.py | Strategy-based optimizer/scheduler factory. |
| deepmd/pt/train/logger.py | New training logger + loss accumulator implementation. |
| deepmd/pt/train/hooks.py | Hook system (priority-based) + built-in hooks (TensorBoard, timing, early stopping). |
| deepmd/pt/train/data_manager.py | Unifies single/multi-task batch iteration and abstracts data loader backends. |
| deepmd/pt/train/data_loader.py | Defines abstract loader interface and DpLoaderSet adapter + batch processing. |
| deepmd/pt/train/config.py | Adds typed config dataclasses and parsing/validation logic. |
| deepmd/pt/train/checkpoint_manager.py | Centralizes checkpoint save/load and retention management. |
| deepmd/pt/train/__init__.py | Exposes new modular training API while keeping legacy trainer available. |
| deepmd/pt/entrypoints/main.py | Adds new trainer selection switch (use_legacy) and instantiates the new trainer. |


module.train_infos["step"] = step

# Prepare checkpoint path
save_path = Path(self.config.save_ckpt + f"-{step + 1}.pt")

Copilot AI Feb 11, 2026


Checkpoint filenames are off-by-one: save() receives step (already 1-based in the new Trainer), but the path suffix uses step + 1. This will create checkpoints like model.ckpt-2.pt for the first saved step and makes the saved step metadata inconsistent with the filename. Consider making save() treat step as the display step and remove the + 1, or pass a 0-based step consistently from the caller.

Suggested change
-save_path = Path(self.config.save_ckpt + f"-{step + 1}.pt")
+save_path = Path(self.config.save_ckpt + f"-{step}.pt")

Comment on lines +272 to +283
line += f" {valid_res.get(k, 0.0):11.2e} {train_res.get(k, 0.0):11.2e}"
else:
for k in sorted(train_res.keys()):
line += f" {train_res.get(k, 0.0):11.2e}"
else:
train_keys = sorted(train_results.keys())
if valid_results and isinstance(valid_results, dict):
for k in train_keys:
line += f" {valid_results.get(k, 0.0):11.2e} {train_results.get(k, 0.0):11.2e}"
else:
for k in train_keys:
line += f" {train_results.get(k, 0.0):11.2e}"

Copilot AI Feb 11, 2026


_print_to_file() formats metric values with float format specs (:11.2e) but pulls them via .get(k, 0.0). In practice, loss modules return scalar torch.Tensors for metrics (e.g., rmse_*), which will raise TypeError: unsupported format string passed to Tensor.__format__. Convert scalar tensors to Python floats before formatting (e.g., float(v) / v.item()), and prefer nan rather than 0.0 for missing validation values to match the header note and legacy behavior.

Suggested change
-line += f" {valid_res.get(k, 0.0):11.2e} {train_res.get(k, 0.0):11.2e}"
-else:
-for k in sorted(train_res.keys()):
-line += f" {train_res.get(k, 0.0):11.2e}"
-else:
-train_keys = sorted(train_results.keys())
-if valid_results and isinstance(valid_results, dict):
-for k in train_keys:
-line += f" {valid_results.get(k, 0.0):11.2e} {train_results.get(k, 0.0):11.2e}"
-else:
-for k in train_keys:
-line += f" {train_results.get(k, 0.0):11.2e}"
+valid_val = valid_res.get(k, float("nan"))
+try:
+valid_val = float(valid_val)
+except (TypeError, ValueError):
+valid_val = float("nan")
+train_val = train_res.get(k, 0.0)
+try:
+train_val = float(train_val)
+except (TypeError, ValueError):
+train_val = float("nan")
+line += f" {valid_val:11.2e} {train_val:11.2e}"
+else:
+for k in sorted(train_res.keys()):
+train_val = train_res.get(k, 0.0)
+try:
+train_val = float(train_val)
+except (TypeError, ValueError):
+train_val = float("nan")
+line += f" {train_val:11.2e}"
+else:
+train_keys = sorted(train_results.keys())
+if valid_results and isinstance(valid_results, dict):
+for k in train_keys:
+valid_val = valid_results.get(k, float("nan"))
+try:
+valid_val = float(valid_val)
+except (TypeError, ValueError):
+valid_val = float("nan")
+train_val = train_results.get(k, 0.0)
+try:
+train_val = float(train_val)
+except (TypeError, ValueError):
+train_val = float("nan")
+line += f" {valid_val:11.2e} {train_val:11.2e}"
+else:
+for k in train_keys:
+train_val = train_results.get(k, 0.0)
+try:
+train_val = float(train_val)
+except (TypeError, ValueError):
+train_val = float("nan")
+line += f" {train_val:11.2e}"

Comment on lines +378 to +382
for key, value in more_loss.items():
if "l2_" in key:
continue
if not isinstance(value, (int, float)):
continue

Copilot AI Feb 11, 2026


LossAccumulator.update() ignores non-(int, float) values, but training losses/metrics are typically scalar torch.Tensors (see TaskLoss.display_if_exist() and loss implementations). This will drop all metrics when disp_avg is enabled, resulting in empty logs. Accept scalar tensors here and convert them to floats for accumulation.
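
A minimal, self-contained sketch of the suggested conversion; `more_loss` and its contents are hypothetical stand-ins for the metrics dict seen by the accumulator loop quoted above:

```python
import torch

# Example metrics dict: scalar tensors and plain floats mixed, as loss modules typically return.
more_loss = {"rmse_e": torch.tensor(0.12), "rmse_f": 0.34, "l2_energy_loss": torch.tensor(1.0)}

for key, value in more_loss.items():
    if "l2_" in key:
        continue
    if isinstance(value, torch.Tensor) and value.numel() == 1:
        value = value.item()  # unwrap scalar tensors to Python floats
    if not isinstance(value, (int, float)):
        continue
    print(key, value)  # accumulate `value` here exactly as the existing code does
```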

Comment on lines +1011 to +1017
train_results[task_key] = {
k: v for k, v in step_result.more_loss.items() if "l2_" not in k
}
else:
train_results = {
k: v for k, v in step_result.more_loss.items() if "l2_" not in k
}

Copilot AI Feb 11, 2026


When disp_avg is disabled, train_results is built directly from step_result.more_loss values without converting scalar torch.Tensors to floats. TrainingLogger._print_to_file() uses float formatting and will fail on tensors. Convert scalar tensors to Python floats when constructing train_results (and similarly for validation results) to keep logging working and consistent with the legacy trainer.

Comment on lines +1053 to +1055
"""Run validation on all tasks."""
self.hook_manager.on_validation_begin(0, {})


Copilot AI Feb 11, 2026


_run_validation() always calls on_validation_begin(0, ...) with step=0, regardless of the current training step. Hooks that rely on step indexing (e.g., TensorBoard/early stopping) will record incorrect x-axes. Pass the current step through from _log_and_validate() and use it here.

Comment on lines +1097 to +1099
if "l2_" not in k and isinstance(v, (int, float)):
results[k] = results.get(k, 0.0) + v * natoms


Copilot AI Feb 11, 2026


Validation aggregation drops tensor metrics: _validate_task() only accumulates values that are (int, float), but loss modules typically return scalar torch.Tensors in more_loss. This will make valid_results empty even when validation is available. Include scalar tensors (convert via float(v) / v.item()) in the aggregation.

Suggested change
-if "l2_" not in k and isinstance(v, (int, float)):
-results[k] = results.get(k, 0.0) + v * natoms
+if "l2_" in k:
+continue
+# Accept numeric scalars and scalar tensors
+if isinstance(v, (int, float)):
+val = float(v)
+elif isinstance(v, torch.Tensor) and v.numel() == 1:
+# Convert scalar tensor to Python float
+val = float(v)
+else:
+continue
+results[k] = results.get(k, 0.0) + val * natoms

shared_links: dict[str, Any] | None = None,
finetune_links: dict[str, Any] | None = None,
use_legacy: bool = False,
) -> training.Trainer:

Copilot AI Feb 11, 2026


The return type annotation of get_trainer() is training.Trainer, but the function can now return the new modular deepmd.pt.train.trainer.Trainer when use_legacy=False. Update the annotation to a union/protocol (or a common base type) to keep typing accurate for downstream users.

Suggested change
-) -> training.Trainer:
+) -> training.Trainer | NewTrainer:

Comment on lines +413 to +419
# Compute statistics
finetune_has_new_type_key = (
self.finetune_links[key].get_has_new_type()
if self.is_finetune and self.finetune_links
else False
)


Copilot AI Feb 11, 2026


Variable finetune_has_new_type_key is not used.

Suggested change
# Compute statistics
finetune_has_new_type_key = (
self.finetune_links[key].get_has_new_type()
if self.is_finetune and self.finetune_links
else False
)

This allows gradual migration from DpLoaderSet to new implementations.
"""

def __iter__(self) -> Iterator[dict[str, Any]]:

Copilot AI Feb 11, 2026


__iter__ method of iterator DataLoaderInterface does not return self.

Suggested change
-def __iter__(self) -> Iterator[dict[str, Any]]:
+def __iter__(self) -> DataLoaderInterface:

Comment on lines +344 to +345
except Exception:
pass

Copilot AI Feb 11, 2026


'except' clause does nothing but pass and there is no explanatory comment.

Suggested change
-except Exception:
-pass
+except Exception as e:
+log.warning(
+"Failed to read or parse 'checkpoint' file; "
+"falling back to scanning for latest checkpoint: %s",
+e,
+)

@coderabbitai
Contributor

coderabbitai bot commented Feb 11, 2026

📝 Walkthrough


This PR introduces a comprehensive modular refactoring of DeepMD's PyTorch training system, adding dedicated components for configuration validation, data management, checkpointing, optimization strategies, training hooks, and logging. The main entry point now supports switching between legacy and new trainer implementations via a use_legacy parameter.
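
From a caller's point of view, the switch described above would look roughly like the sketch below. `config` stands for the parsed training configuration; other positional/keyword arguments of get_trainer are omitted, and the config-only call follows the usage shown in the new test suite.

```python
from deepmd.pt.entrypoints.main import get_trainer

# Default: the new modular trainer (use_legacy=False).
trainer = get_trainer(config)

# Opt back into the previous implementation during the transition.
legacy_trainer = get_trainer(config, use_legacy=True)
```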

Changes

  • Entry Point (deepmd/pt/entrypoints/main.py): Added use_legacy parameter to get_trainer() to conditionally instantiate either NewTrainer or legacy training.Trainer; includes new import alias for Trainer from deepmd.pt.train.trainer.
  • Training Package Initialization (deepmd/pt/train/__init__.py): Added comprehensive module docstring and re-exports modular training components including CheckpointManager, Config classes, DataManager, Hooks, Logger, OptimizerFactory, Trainer variants, and TrainingLoop factories via new __all__ list; includes LegacyTrainer alias for backward compatibility.
  • Configuration Management (deepmd/pt/train/config.py): Introduced dataclasses for OptimizerConfig, LearningRateConfig, DisplayConfig, CheckpointConfig, and TrainingConfig with factory methods (from_dict), validation logic for warmup steps and num_steps, and multi-task support with per-task optimizer/learning-rate configurations.
  • Data Loading Abstraction (deepmd/pt/train/data_loader.py): Added DataLoaderInterface protocol, BatchProcessor for unified batch handling, AbstractDataLoader base class, DpLoaderSetAdapter wrapper for legacy DpLoaderSet, and DataLoaderFactory with registry pattern; supports device transfer, input/label splitting, and backend-agnostic iteration.
  • Data Management (deepmd/pt/train/data_manager.py): Introduced DataManager to centralize data loading for single-task and multi-task training; handles DpLoaderSet and AbstractDataLoader inputs, provides get_train_batch/get_valid_batch with optional task_key, and includes factory method create_from_dploader_set for legacy compatibility.
  • Checkpointing System (deepmd/pt/train/checkpoint_manager.py): New CheckpointManager handles model/optimizer state saving with DDP unwrapping, checkpoint rotation (max_ckpt_keep), loading with strict/non-strict modes, fine-tuning support (load_for_finetune), and utility methods for checkpoint resumption; includes symlink management and comprehensive logging.
  • Hook System (deepmd/pt/train/hooks.py): Introduced extensible hook framework with HookPriority enum, TrainingHook abstract base, HookManager for registration/dispatch with error resilience, and concrete hooks: TensorBoardHook (metrics logging), TimingHook (step timing/ETA), and EarlyStoppingHook (monitoring with patience).
  • Training Logging & Loss Tracking (deepmd/pt/train/logger.py): Added TrainingLogger for file and console logging with multi-task support, header formatting, and context manager interface; includes LossAccumulator for tracking and averaging step losses across single-task and multi-task scenarios with automatic l2_ key filtering.
  • Optimizer Factory (deepmd/pt/train/optimizer_factory.py): Introduced OptimizerFactory with strategy pattern; defines OptimizerStrategy base class and concrete strategies for Adam, AdamW, LKF, AdaMuon, and HybridMuon; factory supports scheduler creation, optimizer registration, and convenience module-level functions.
  • Training Loops (deepmd/pt/train/training_loop.py): Added BaseTrainingLoop abstract class with zero_grad and step methods; concrete implementations include AdamTrainingLoop (standard backprop), LKFEnergyTrainingLoop (Kalman Filter with prefactor scheduling), LKFDenoiseTrainingLoop (specialized denoise path); TrainingLoopFactory selects loop by optimizer and loss type.
  • Main Trainer (deepmd/pt/train/trainer.py): Introduced comprehensive Trainer class orchestrating end-to-end training with composition of TrainingConfig, DataManager, OptimizerFactory, CheckpointManager, TrainingLoop, HookManager, and TrainingLogger; supports single/multi-task, fine-tuning, DDP distribution, model bias adjustment, and dynamic task selection in training loop.
  • Test Suite (source/tests/pt/test_new_training.py): Comprehensive test module covering OptimizerConfig, LearningRateConfig, TrainingConfig (single/multi-task, warmup logic), OptimizerFactory, hooks (priority, execution, error handling), LossAccumulator, DisplayConfig, CheckpointConfig, and end-to-end CLI integration tests for trainer creation and component initialization.

Sequence Diagram

sequenceDiagram
    participant Main as Main Trainer
    participant DM as DataManager
    participant OF as OptimizerFactory
    participant CM as CheckpointManager
    participant TL as TrainingLoop
    participant HM as HookManager
    participant Logger as TrainingLogger

    Main->>DM: get_train_batch(task_key)
    DM-->>Main: (inputs, labels, log_dict)
    
    Main->>OF: create_optimizer(params, config, lr_config)
    OF-->>Main: optimizer
    
    Main->>OF: create_scheduler(optimizer, warmup_steps, ...)
    OF-->>Main: lr_scheduler
    
    Main->>CM: load(checkpoint_path)
    CM-->>Main: start_step, metadata
    
    Main->>HM: on_train_begin()
    HM-->>Main: executed all hooks
    
    loop Training Steps
        Main->>TL: step(inputs, labels, cur_lr, task_key)
        TL-->>Main: TrainingStepResult(loss, predictions, lr)
        
        Main->>Logger: log_step(step, losses, lr)
        Logger-->>Main: logged to file/console
        
        Main->>HM: on_step_end(step, logs)
        HM-->>Main: executed all hooks
        
        alt Validation Step
            Main->>DM: get_valid_batch(task_key)
            DM-->>Main: (valid_inputs, valid_labels, log_dict)
            
            Main->>HM: on_validation_begin(step)
            HM-->>Main: executed hooks
            
            Main->>HM: on_validation_end(step, val_logs)
            HM-->>Main: executed hooks
        end
        
        alt Checkpoint Step
            Main->>CM: save(step, wrapper, optimizer, lr)
            CM-->>Main: checkpoint_path
            
            Main->>HM: on_save_checkpoint(step, path)
            HM-->>Main: executed hooks
        end
    end
    
    Main->>Logger: log_summary()
    Logger-->>Main: logged summary
    
    Main->>HM: on_train_end()
    HM-->>Main: executed all hooks
    
    Main->>CM: save_final(step, wrapper)
    CM-->>Main: final_checkpoint_path

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

  • #5130: Adds AdaMuon optimizer implementation that is integrated via OptimizerFactory strategy pattern in this PR.
  • #5149: Adds HybridMuon optimizer implementation that is integrated via OptimizerFactory strategy pattern in this PR.
  • #4849: Modifies legacy Trainer finetuning logic in deepmd.pt.train.training, which is now aliased as LegacyTrainer and selectable via use_legacy flag.

Suggested labels

Python, new feature

Suggested reviewers

  • wanghan-iapcm
  • njzjz
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 inconclusive)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Title check | ❓ Inconclusive | The title 'refactor(pt): refactor training code' is vague and uses non-descriptive terms that don't convey meaningful information about the specific changes made. | Replace with a more specific title that highlights the main architectural change, e.g., 'refactor(pt): modularize training system with config, data manager, optimizer factory, and hooks' or 'refactor(pt): introduce modular training components for single and multi-task support'. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 93.67% which is sufficient. The required threshold is 80.00%. |


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 19

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pt/entrypoints/main.py (1)

101-111: ⚠️ Potential issue | 🟡 Minor

Return type annotation -> training.Trainer is inaccurate when use_legacy=False.

When use_legacy=False (the default), the function returns a NewTrainer (from deepmd.pt.train.trainer), not a training.Trainer (from deepmd.pt.train.training). These are distinct classes. Consider using a Union type or a common protocol/base class.

Proposed fix
-) -> training.Trainer:
+) -> training.Trainer | NewTrainer:

Or better, define a common Protocol that both trainers implement (e.g., requiring run() and model attributes).

🤖 Fix all issues with AI agents
In `@deepmd/pt/train/checkpoint_manager.py`:
- Around line 151-171: The early return in _cleanup_old_checkpoints uses
self._saved_checkpoints but the cleanup logic re-discovers files with glob,
causing stale or missing cleanup; remove the early-return guard and instead
compute checkpoint_files first (using
Path(".").glob(f"{self.config.save_ckpt}*.pt") with the existing is_symlink and
name.startswith filters), sort by st_mtime, then while len(checkpoint_files) >
self.config.max_ckpt_keep unlink oldest files and log; this makes cleanup
consistent whether the manager was just created or persistent files exist and
keeps the existing filtering based on self.config.save_ckpt and max_ckpt_keep.
- Around line 138-139: The call to symlink_prefix_files uses save_path.stem
which strips directory components, causing glob(old_prefix + ".*") to miss files
when self.config.save_ckpt includes a directory; replace the stem-based prefix
with full-path prefixes (without extensions) so both prefixes include the same
directory. Concretely, pass str(save_path.with_suffix("")) (or equivalent
full-path prefix derived from save_path) and
str(Path(self.config.save_ckpt).with_suffix("")) to symlink_prefix_files instead
of save_path.stem and self.config.save_ckpt alone, and make the same change
where symlink_prefix_files is called in save_final to ensure glob operates
against the correct directory.
- Around line 216-218: The state_dict extraction in checkpoint_manager (both in
the block using state_dict = checkpoint.get("model", checkpoint) followed by an
immediate if "model" in checkpoint override, and the identical pattern in
load_for_finetune) is redundant; replace the two-step logic with a single
canonical extraction (e.g., use state_dict = checkpoint.get("model", checkpoint)
or use an if/else that sets state_dict = checkpoint["model"] else checkpoint) so
the fallback is actually used and the value is not overwritten immediately —
update the occurrences referencing state_dict and checkpoint in the methods load
(the block around state_dict = ...) and load_for_finetune accordingly.
- Around line 338-345: The code currently uses a bare except in the
checkpoint_file read block which silently swallows errors; change the except to
"except Exception as e" and log the exception (e.g., logger.exception or
logger.error with the exception info) when reading checkpoint_file.read_text()
or parsing latest, referencing the checkpoint_file and latest
variables/Path(...) usage in this block; ensure a module logger
(logging.getLogger(__name__)) is available or imported so operators can see
corruption/permission/invalid-path errors instead of them being ignored.
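
A sketch of the logged fallback described in the item above; the helper name is hypothetical and the variable names follow the flagged snippet:

```python
import logging
from pathlib import Path

log = logging.getLogger(__name__)


def _resolve_latest_checkpoint(checkpoint_file: Path) -> Path | None:
    """Read the 'checkpoint' pointer file; return None so the caller falls back to scanning."""
    try:
        latest = Path(checkpoint_file.read_text().strip())
        if latest.exists():
            return latest
    except Exception as e:
        log.warning(
            "Failed to read or parse %s; falling back to scanning for the latest checkpoint: %s",
            checkpoint_file,
            e,
        )
    return None
```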

In `@deepmd/pt/train/config.py`:
- Around line 206-214: The code currently raises a ValueError when some
model_keys are missing from optim_dict (in the optim_dict block using
OptimizerConfig.from_dict) but silently falls back to lr_params when keys are
missing from learning_rate_dict; make the behavior consistent by validating
learning_rate_dict the same way: after building learning_rate_dict for keys in
model_keys, compute missing_keys = [k for k in model_keys if k not in
learning_rate_dict] and if missing_keys raise a ValueError (or optionally emit a
warning per project policy) instead of silently using lr_params; update the code
near the learning_rate_dict handling (the block that references lr_params and
learning_rate_dict) to perform this check.
- Around line 200-202: Replace the assert-based validation with an explicit
exception raise so the check always runs: in the code that currently uses
`assert num_steps - computed_warmup_steps > 0 or computed_warmup_steps == 0,
("Warm up steps must be less than total training steps!")` (the variables
`num_steps` and `computed_warmup_steps`), change it to an if statement that
raises `ValueError("Warm up steps must be less than total training steps!")`
when the condition fails; keep the same message and logic but use `ValueError`
for proper production-time validation.
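
A sketch of the explicit check suggested in the item above; the variable names are taken from the prompt and the surrounding parsing code is omitted:

```python
def validate_warmup(num_steps: int, computed_warmup_steps: int) -> None:
    """Raise instead of assert so the check survives `python -O`."""
    if computed_warmup_steps != 0 and num_steps - computed_warmup_steps <= 0:
        raise ValueError("Warm up steps must be less than total training steps!")
```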

In `@deepmd/pt/train/data_manager.py`:
- Around line 193-216: get_valid_numb_batch currently returns 1 when no
validation loader exists which causes an unnecessary validation iteration;
update get_valid_numb_batch (and its use of _get_valid_loader) to return 0 when
loader is None and also return 0 when len(loader) cannot be determined (replace
the fallback 1 with 0) so the trainer will skip validation loops for missing
loaders; ensure callers like get_valid_batch still handle zero batches
correctly.
- Around line 144-162: The code calls next(...) directly on AbstractDataLoader
instances (in get_train_batch), but AbstractDataLoader only defines __iter__, so
change DataManager to store iterators (e.g., create a training_iterators
structure by calling iter() on each loader when loaders are initialized) and
then call next(...) on those iterators (use
next(self.training_iterators[task_key]) or next(self.training_iterators) inside
get_train_batch); this keeps DpLoaderSetAdapter working and avoids requiring
__next__ on the ABC.
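
A sketch of the iterator bookkeeping proposed in the item above; the epoch-restart behavior and abridged class shape are assumptions, not the PR's exact implementation:

```python
from collections.abc import Iterator
from typing import Any


class DataManagerSketch:
    """Abridged: only the train-batch path is shown."""

    def __init__(self, training_loaders: dict[str, Any]) -> None:
        self.training_loaders = training_loaders
        # Call next() on iterators, never on the AbstractDataLoader objects themselves.
        self.training_iterators: dict[str, Iterator[dict[str, Any]]] = {
            key: iter(loader) for key, loader in training_loaders.items()
        }

    def get_train_batch(self, task_key: str = "Default") -> dict[str, Any]:
        try:
            return next(self.training_iterators[task_key])
        except StopIteration:
            # Restart the epoch for this task and retry once.
            self.training_iterators[task_key] = iter(self.training_loaders[task_key])
            return next(self.training_iterators[task_key])
```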

In `@deepmd/pt/train/optimizer_factory.py`:
- Around line 155-162: The post-warmup branch divides by lr_schedule.start_lr
which can be zero; add a guard to check if lr_schedule.start_lr == 0 and raise a
clear ValueError (or return a documented fallback) instead of performing the
division. Locate the division using lr_schedule.value(...)/lr_schedule.start_lr
in the function (the post-warmup branch shown) and replace it with a check that
either raises ValueError("lr_schedule.start_lr is zero; please set a non-zero
start_lr") or handles a defined fallback, and apply the same fix to the other
occurrence around the 200-208 region so both divisions are protected. Ensure the
error message references lr_schedule.start_lr so users know which config to fix.
- Around line 151-164: Extract the duplicated warmup_linear closure into a
single shared factory function named _make_warmup_lr_lambda(warmup_steps,
warmup_start_factor, lr_schedule, start_step=0) that returns the warmup_linear
callable (preserving the same logic and docstring), then replace the inline
warmup_linear closures in AdamStrategy.create_scheduler,
AdamWStrategy.create_scheduler, AdaMuonStrategy.create_scheduler and
HybridMuonStrategy.create_scheduler with calls to
torch.optim.lr_scheduler.LambdaLR(..., lr_lambda=_make_warmup_lr_lambda(...)) so
each scheduler construction becomes a one-liner using the new factory.
- Around line 226-231: The LKFOptimizer instantiation in optimizer_factory.py
currently hardcodes kalman_lambda=0.98 and kalman_nu=0.99870; update this by
adding configurable fields (e.g., kalman_lambda and kalman_nu) to
OptimizerConfig and use those fields when constructing LKFOptimizer (or, if you
prefer not to change the public config, extract the magic numbers into named
constants with docstrings and reference those constants in the LKFOptimizer(...)
call). Modify OptimizerConfig (config.py) to include the new fields with
sensible defaults and update any config parsing/validation, then replace the
literal values in the LKFOptimizer(...) call to reference
OptimizerConfig.kalman_lambda and OptimizerConfig.kalman_nu (or the new named
constants) so the parameters are documented and configurable.
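
A sketch of the configurable Kalman parameters from the item above; the field names follow the prompt, and the LKFOptimizer call signature in the trailing comment is an assumption:

```python
from dataclasses import dataclass


@dataclass
class OptimizerConfigSketch:
    """Abridged: only the fields relevant to the LKF strategy are shown."""

    opt_type: str = "LKF"
    kalman_lambda: float = 0.98  # forgetting factor, previously hard-coded in the strategy
    kalman_nu: float = 0.99870  # decay rate, previously hard-coded in the strategy


# Inside the LKF strategy the literals would then be replaced with, e.g.:
# optimizer = LKFOptimizer(params, config.kalman_lambda, config.kalman_nu, ...)
```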

In `@deepmd/pt/train/trainer.py`:
- Around line 237-240: The multitask detection is inconsistent: update the
assignment to self.is_multitask in trainer.py to match the rest of the codebase
by setting it solely based on the presence of "model_dict" in model_params
(i.e., self.is_multitask = "model_dict" in model_params) rather than checking
len(self.model_keys) > 1; adjust any dependent logic if necessary to rely on
self.model_keys for key counts but keep self.is_multitask as the simple presence
check.
- Around line 685-694: The _init_distributed method currently calls
torch.cuda.set_device(LOCAL_RANK) unconditionally which fails on CPU-only
backends; modify _init_distributed (and the DDP wrap) to first check
torch.cuda.is_available() before calling torch.cuda.set_device or passing
CUDA-specific arguments to DDP—i.e., only call torch.cuda.set_device(LOCAL_RANK)
and use device_ids=[LOCAL_RANK], output_device=LOCAL_RANK when
torch.cuda.is_available() is True; otherwise skip set_device and instantiate
DDP(self.wrapper, find_unused_parameters=True) (no CUDA-specific
device_ids/output_device) so it works with gloo/CPU-only setups.
- Around line 327-373: The _create_loss function mutates the input loss_params
dict (e.g., adding "starter_learning_rate", "ntypes", "tensor_size", etc.), so
make a local deep copy at the top of _create_loss (e.g., copy =
deepcopy(loss_params)) and use that copy for all modifications and to construct
the loss instances (EnergyStdLoss, EnergyHessianStdLoss, EnergySpinLoss,
DenoiseLoss, DOSLoss, TensorLoss, PropertyLoss, and
TaskLoss.get_class_by_type(...).get_loss). This prevents changing the original
config passed by the caller while preserving all current key assignments and
behavior.
- Around line 413-418: The local variable finetune_has_new_type_key is computed
but never used; remove the dead assignment or if you intend to keep it for
future use, rename it to _finetune_has_new_type_key to signal intentionally
unused. Update the block around the expression that references
self.finetune_links.get_has_new_type(), which is currently guarded by
self.is_finetune and self.finetune_links, by either deleting that whole
assignment line or renaming the variable as described; ensure no other code
relies on finetune_has_new_type_key before removing or renaming.
- Around line 662-683: The loop that builds model_prob for self.model_keys is a
no-op and always assigns uniform probability; replace the placeholder with logic
that queries self.data_manager for each model_key's training dataset size (e.g.,
via a method/property analogous to the legacy trainer) to compute per-model
weights proportional to data sizes, falling back to uniform weights if
data_manager is missing or a model's size is zero; normalize model_prob to sum
to 1 and then build model_key_prob_map as before, ensuring compatibility with
self.wrapper.share_params and matching the legacy trainer's probability
computation semantics.
- Around line 1118-1121: The final checkpoint currently calls
self._save_checkpoint(self.config.num_steps - 1, 0.0) which hardcodes lr=0.0 and
may double-save; change _finalize_training to use the actual last learning rate
tracked by the training loop (e.g., maintain self._last_lr updated each step in
the training loop) and call self._save_checkpoint(self.config.num_steps - 1,
self._last_lr) instead, and add a guard in _finalize_training (or check in
_save_checkpoint) to skip saving if a checkpoint for step self.config.num_steps
- 1 was already written (or if the stored last checkpoint step equals
config.num_steps - 1) to avoid redundant saves.
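
A sketch of the last-step bookkeeping suggested in the final item; it assumes the training loop records self._last_lr and self._last_saved_step as steps and checkpoints complete:

```python
class FinalSaveSketch:
    """Abridged trainer fragment: guard the final save and reuse the last observed lr."""

    def _finalize_training(self) -> None:
        final_step = self.config.num_steps - 1
        already_saved = getattr(self, "_last_saved_step", None) == final_step
        if not already_saved:
            self._save_checkpoint(final_step, getattr(self, "_last_lr", 0.0))
```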

In `@deepmd/pt/train/training_loop.py`:
- Around line 226-238: The _compute_prefactors function can divide by zero when
start_pref_e or start_pref_f is 0; update it to guard against zeros by computing
ratio = step / max(1, self.num_steps) and for each prefactor use geometric
interpolation as now when start_pref_* != 0, but fall back to linear
interpolation pref = start + (limit - start) * ratio when start_pref_* == 0
(apply this logic for start_pref_e/limit_pref_e and start_pref_f/limit_pref_f in
_compute_prefactors).
- Around line 248-252: The code incorrectly reads step from
self.optimizer.state.get("step", 0) (optimizer.state is per-parameter, so this
always yields 0) causing _compute_prefactors to never advance; fix by passing
the current training step into the training step method instead: update the
signature of BaseTrainingLoop.step (and any overrides like
LKFEnergyTrainingLoop.step) to accept a step: int parameter, propagate the
caller's step (Trainer) into that call, and replace uses of
self.optimizer.state.get("step", 0) with the passed-in step when calling
_compute_prefactors.
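
A sketch combining the two items above: the Trainer passes the current step into the loop, and the prefactor schedule falls back to linear interpolation when a start prefactor is zero. The step() call signature mirrors the sequence diagram earlier in this review; the geometric form of the existing interpolation is an assumption.

```python
class LKFEnergyLoopSketch:
    """Abridged: only prefactor scheduling and step threading are shown."""

    def __init__(self, num_steps: int, start_pref_e: float, limit_pref_e: float,
                 start_pref_f: float, limit_pref_f: float) -> None:
        self.num_steps = num_steps
        self.start_pref_e, self.limit_pref_e = start_pref_e, limit_pref_e
        self.start_pref_f, self.limit_pref_f = start_pref_f, limit_pref_f

    @staticmethod
    def _interp(start: float, limit: float, ratio: float) -> float:
        if start != 0.0:
            return start * (limit / start) ** ratio  # geometric, as in the existing code (assumed)
        return start + (limit - start) * ratio  # linear fallback avoids division by zero

    def _compute_prefactors(self, cur_step: int) -> tuple[float, float]:
        ratio = cur_step / max(1, self.num_steps)
        return (
            self._interp(self.start_pref_e, self.limit_pref_e, ratio),
            self._interp(self.start_pref_f, self.limit_pref_f, ratio),
        )

    def step(self, inputs, labels, cur_lr, task_key, cur_step: int):
        # `cur_step` now comes from the Trainer instead of optimizer.state.
        pref_e, pref_f = self._compute_prefactors(cur_step)
        ...
```
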
🧹 Nitpick comments (19)
deepmd/pt/train/training_loop.py (1)

217-224: Magic numbers 24 (kp) and 6 (kq) are duplicated across two classes.

These KF parameters appear in both LKFEnergyTrainingLoop (lines 221-222) and LKFDenoiseTrainingLoop (lines 303-304) with no explanation or configurability. Extract them as class-level constants or accept them from configuration.

♻️ Proposed refactor
+# Default Kalman Filter parameters
+_DEFAULT_KF_KP = 24
+_DEFAULT_KF_KQ = 6
+
 class LKFEnergyTrainingLoop(BaseTrainingLoop):
     ...
         self.kf_wrapper = KFOptimizerWrapper(
             wrapper,
             optimizer,
-            24,  # kp
-            6,  # kq
+            _DEFAULT_KF_KP,
+            _DEFAULT_KF_KQ,
             dist.is_available() and dist.is_initialized(),
         )

Apply the same change to LKFDenoiseTrainingLoop.

Also applies to: 300-306

deepmd/pt/train/__init__.py (1)

67-72: LKFDenoiseTrainingLoop is not re-exported, unlike its sibling LKFEnergyTrainingLoop.

LKFEnergyTrainingLoop is imported and listed in __all__, but LKFDenoiseTrainingLoop (also a public class in training_loop.py) is neither imported nor listed. If both are part of the public API, add the missing export for consistency.

♻️ Proposed fix
 from deepmd.pt.train.training_loop import (
     AdamTrainingLoop,
     BaseTrainingLoop,
+    LKFDenoiseTrainingLoop,
     LKFEnergyTrainingLoop,
     TrainingLoopFactory,
 )

And add "LKFDenoiseTrainingLoop" to __all__.

deepmd/pt/train/logger.py (1)

47-82: File handle opened in __init__ may leak if close() is never called.

The TrainingLogger opens a file on construction (line 79) but cleanup depends on explicit close() or context-manager usage. If Trainer.__init__ fails after creating the logger (but before run() is invoked), the handle leaks. Consider adding a __del__ safety net.

♻️ Proposed safety net
+    def __del__(self) -> None:
+        """Ensure file handle is closed on garbage collection."""
+        self.close()
deepmd/pt/train/data_loader.py (2)

295-308: DpLoaderSetAdapter is both iterable and iterator (__iter__ returns self).

This means only one iteration can be active at a time. If __iter__ is called again, it resets the internal _iterator, potentially losing the position of any in-progress iteration. This is fine for the current single-consumer training loop, but worth noting as a limitation.


409-426: DataLoaderFactory._implementations is a mutable ClassVar; register() mutates global state.

register() modifies a class-level dict shared by all instances and imports. In multi-process or testing scenarios, registered implementations persist across test cases. This is acceptable for a plugin registry pattern but consider documenting the global-state semantics.

deepmd/pt/train/data_manager.py (1)

86-101: training_loaders and validation_loaders have different types depending on is_multitask.

In multitask mode, these are dict[str, AbstractDataLoader]; in single-task mode, they're AbstractDataLoader / AbstractDataLoader | None. This makes downstream code require isinstance checks or is_multitask guards everywhere. Consider always using a dict (e.g., {"Default": loader}) for a uniform interface.

source/tests/pt/test_new_training.py (1)

308-444: End-to-end tests verify creation and initialization but not training execution.

The TestEndToEndCLI tests confirm the trainer initializes correctly and can fetch data, but there's no test that calls trainer.run() even for a single step. Consider adding a minimal test that runs 1-2 training steps to verify the full loop (forward, backward, optimizer step, logging, checkpointing).

💡 Example test sketch
def test_trainer_runs_one_step(self):
    """Test that trainer can execute at least one training step."""
    config = copy.deepcopy(self.config)
    config["training"]["numb_steps"] = 1
    config["training"]["save_freq"] = 1
    trainer = get_trainer(config)
    # Should not raise
    trainer.run()
deepmd/pt/entrypoints/main.py (1)

385-394: Add CLI flag to fall back to legacy trainer during transition period.

The new (refactored) trainer is now the default with no option to use the legacy trainer via the CLI. Users cannot fall back to the legacy trainer during the transition period. Add a --use-legacy-trainer flag to parser_train to allow users to opt in to the legacy trainer behavior.
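
A hedged sketch of what such a flag could look like; parser_train here is a stand-in subparser, since the real dp argument parser lives elsewhere in deepmd:

```python
import argparse

parser = argparse.ArgumentParser(prog="dp")
subparsers = parser.add_subparsers(dest="command")
parser_train = subparsers.add_parser("train")  # stand-in for the real train subparser
parser_train.add_argument(
    "--use-legacy-trainer",
    action="store_true",
    help="Fall back to the previous PyTorch trainer implementation during the transition.",
)

args = parser.parse_args(["train", "--use-legacy-trainer"])
# The entrypoint would then forward it, e.g.:
# get_trainer(config, ..., use_legacy=args.use_legacy_trainer)
```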

deepmd/pt/train/checkpoint_manager.py (3)

106-109: DDP unwrapping should not depend on dist.is_initialized().

A module can be wrapped in DistributedDataParallel (and thus have a .module attribute) even when checked outside a distributed context, e.g., in certain test setups or after dist has been destroyed. The current guard means you'd serialize the DDP wrapper's state dict (with module. prefixed keys) in those edge cases. The same pattern is repeated in save_final (Lines 380–383), load (Lines 211–214), and load_for_finetune (Lines 264–267).

A simpler, more robust pattern:

Suggested change
-        module = wrapper
-        if dist.is_available() and dist.is_initialized():
-            if hasattr(wrapper, "module"):
-                module = wrapper.module
+        module = getattr(wrapper, "module", wrapper)

354-406: save_final duplicates most of save — consider consolidation.

save_final repeats the DDP unwrapping, train_infos update, checkpoint creation, torch.save, symlink, and checkpoint-file write logic from save, differing only in the absence of the optimizer state and the naming scheme. This violates DRY and increases the risk of divergence (the off-by-one naming issue is one such case already).

Consider having save_final delegate to save with an optional include_optimizer=False flag, or extracting the shared logic into a private helper.
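
A sketch of the consolidation; the save()/save_final() argument lists follow the sequence diagram earlier in this review, and the body is abridged:

```python
class CheckpointManagerSketch:
    """Abridged: shows only how save_final() could delegate to save()."""

    def save(self, step, wrapper, optimizer=None, lr=None, include_optimizer: bool = True):
        module = getattr(wrapper, "module", wrapper)  # unwrap DDP if present
        checkpoint = {"model": module.state_dict(), "step": step, "lr": lr}
        if include_optimizer and optimizer is not None:
            checkpoint["optimizer"] = optimizer.state_dict()
        # ... torch.save(checkpoint, path), symlink update, and pointer-file write go here ...
        return checkpoint

    def save_final(self, step, wrapper):
        # Reuse the main path; only the optimizer state and naming differ.
        return self.save(step, wrapper, optimizer=None, include_optimizer=False)
```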


120-124: deepcopy is unnecessary and memory-inefficient for optimizer state saving

optimizer.state_dict() returns a new dictionary, but its values (including state tensors) are references to the optimizer's internal state. Since the code only replaces scalar lr values (line 124) and doesn't mutate the state tensors themselves, the deepcopy unnecessarily copies all optimizer state tensors (momentum buffers, Adam's second moments, etc.). For large models, this doubles peak memory usage during checkpoint saves. Save directly with the state dict returned by state_dict().

Suggested change
-        optim_state = deepcopy(optimizer.state_dict())
+        optim_state = optimizer.state_dict()
deepmd/pt/train/config.py (1)

34-71: Default values are duplicated between field declarations and from_dict.

Every dataclass in this file (e.g., OptimizerConfig, LearningRateConfig, DisplayConfig, CheckpointConfig) specifies default values twice — once in the field declarations and again in the .get() calls within from_dict. If one changes without the other, they'll silently diverge.

Consider using the field defaults directly:

Example for OptimizerConfig
     `@classmethod`
     def from_dict(cls, params: dict[str, Any]) -> OptimizerConfig:
         """Create OptimizerConfig from dictionary."""
+        fields = {f.name: f.default for f in cls.__dataclass_fields__.values()}
         return cls(
-            opt_type=params.get("opt_type", "Adam"),
-            weight_decay=params.get("weight_decay", 0.001),
-            ...
+            **{k: params.get(k, v) for k, v in fields.items()}
         )

This applies to all four smaller config classes. TrainingConfig.from_dict is different since it performs transformations.

deepmd/pt/train/hooks.py (5)

54-67: TrainingHook inherits ABC but has no abstract methods — consider dropping ABC.

Since all hook methods are optional overrides (none are @abstractmethod), inheriting from ABC is misleading — it suggests subclasses must implement something. A plain base class communicates the intent (all methods optional) more clearly. Ruff also flags this (B024).


411-421: TimingHook.on_step_end measures wall time of the entire step plus all higher-priority hooks.

Since hooks execute in priority order and TimingHook has LOW priority, by the time on_step_end fires, all other hooks' on_step_end calls have already completed. The last_step_time is set at the end of this method, so the measured interval includes other hooks' on_step_end overhead plus the gap between on_step_begin and on_step_end of the next step. This is probably fine for a rough ETA, but worth documenting that it measures wall time inclusive of hook overhead.

Also, step_times.pop(0) on a list is O(n). For a fixed-size rolling window, collections.deque(maxlen=100) would be more efficient and idiomatic.

Suggested change for deque
+from collections import deque
+
 class TimingHook(TrainingHook):
     ...
     def __init__(self) -> None:
-        self.step_times: list[float] = []
+        self.step_times: deque[float] = deque(maxlen=100)
         self.start_time: float | None = None
         self.last_step_time: float | None = None

     def on_step_end(self, step, logs=None):
         import time
         if self.last_step_time is not None:
             step_time = time.time() - self.last_step_time
             self.step_times.append(step_time)
-            if len(self.step_times) > 100:
-                self.step_times.pop(0)
         self.last_step_time = time.time()

490-504: EarlyStoppingHook stores is_better as a lambda — not picklable.

If any part of the system needs to serialize hooks (e.g., for distributed training, checkpointing hook state), the lambda stored in self.is_better will fail to pickle. This can be avoided by using operator.lt / operator.gt or a simple method dispatch on self.mode.

Suggested change
-        if mode == "min":
-            self.is_better = lambda current, best: current < best
-            self.best_value = float("inf")
-        elif mode == "max":
-            self.is_better = lambda current, best: current > best
-            self.best_value = float("-inf")
+        if mode == "min":
+            self.best_value = float("inf")
+        elif mode == "max":
+            self.best_value = float("-inf")
         else:
             raise ValueError(f"mode must be 'min' or 'max', got {mode}")
+
+    def is_better(self, current: float, best: float) -> bool:
+        if self.mode == "min":
+            return current < best
+        return current > best

347-368: TensorBoardHook.on_step_end iterates all log keys after already checking loss and lr.

The loop on Lines 366–368 re-checks and skips "loss" and "lr", so logs with many keys pay for redundant string comparisons. More importantly, the isinstance guard silently drops non-numeric values (e.g., strings), which is fine, but 0-dim torch Tensor scalars also fail the (int, float) check and are silently ignored.

Consider adding torch.Tensor to the isinstance check or calling .item() for scalar tensors:

Suggested addition
         for key, value in logs.items():
             if key not in ["loss", "lr"] and isinstance(value, (int, float)):
                 self.writer.add_scalar(f"train/{key}", value, display_step)
+            elif key not in ["loss", "lr"] and isinstance(value, torch.Tensor) and value.ndim == 0:
+                self.writer.add_scalar(f"train/{key}", value.item(), display_step)

382-386: TensorBoardHook.on_train_end does not flush before closing.

SummaryWriter.close() should flush pending events. While the default implementation of close() calls flush() internally in most TensorBoard versions, explicitly flushing before close is a common best practice to avoid data loss if the writer implementation changes.
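
A small sketch of the explicit flush-then-close pattern (assuming the hook keeps the SummaryWriter on self.writer, as elsewhere in this review; the signature is assumed):

    def on_train_end(self, logs=None) -> None:
        if self.writer is not None:
            self.writer.flush()  # push any buffered events to disk first
            self.writer.close()
            self.writer = None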

deepmd/pt/train/optimizer_factory.py (2)

135-139: Simplify the fused ternary.

False if DEVICE.type == "cpu" else True is equivalent to DEVICE.type != "cpu".

Suggested fix
-            fused=False if DEVICE.type == "cpu" else True,
+            fused=(DEVICE.type != "cpu"),

Also applies to: 180-184


496-497: Global mutable singleton _default_factory is a hidden shared-state risk.

The module-level _default_factory = OptimizerFactory() is created at import time. If any caller uses _default_factory.register(...) to add a custom strategy, it affects all subsequent users of the module-level create_optimizer / create_scheduler functions — including across tests. This could lead to flaky tests or unexpected behavior in long-running processes.

Consider documenting this explicitly, or providing a get_default_factory() function that makes the shared state visible and resettable.
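
One way to make the shared state visible and resettable, as a sketch (get_default_factory and reset_default_factory are hypothetical names, not part of the current module):

_default_factory: OptimizerFactory | None = None

def get_default_factory() -> OptimizerFactory:
    """Return the process-wide factory, creating it lazily on first use."""
    global _default_factory
    if _default_factory is None:
        _default_factory = OptimizerFactory()
    return _default_factory

def reset_default_factory() -> None:
    """Drop any custom registrations, e.g. between tests."""
    global _default_factory
    _default_factory = None

The module-level create_optimizer / create_scheduler helpers would then call get_default_factory() instead of touching a bare global, and tests that register custom strategies could call reset_default_factory() in teardown.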

Comment on lines +138 to +139
# Update symlinks
symlink_prefix_files(save_path.stem, self.config.save_ckpt)

⚠️ Potential issue | 🟡 Minor

symlink_prefix_files expects string prefixes — save_path.stem drops the directory.

save_path.stem returns just the filename without extension (e.g., "model.ckpt-1001"), while self.config.save_ckpt is "model.ckpt". The symlink_prefix_files function calls glob(old_prefix + ".*"), which glob-matches from the current working directory. This works only if the checkpoint is saved in the CWD. If save_ckpt ever contains a directory component (e.g., "checkpoints/model.ckpt"), the stem-based prefix and the config prefix will diverge and the symlink won't be created. The same concern applies to Line 401 in save_final.
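
Sketch of one possible fix that keeps the directory component on the saved-path prefix (assuming save_path carries a single extension such as .pt; whether the second argument also needs adjusting depends on how save_ckpt is defined):

-        symlink_prefix_files(save_path.stem, self.config.save_ckpt)
+        # keep the directory so the glob inside symlink_prefix_files runs against the right location
+        symlink_prefix_files(str(save_path.with_suffix("")), self.config.save_ckpt)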

🤖 Prompt for AI Agents
In `@deepmd/pt/train/checkpoint_manager.py` around lines 138 - 139, The call to
symlink_prefix_files uses save_path.stem which strips directory components,
causing glob(old_prefix + ".*") to miss files when self.config.save_ckpt
includes a directory; replace the stem-based prefix with full-path prefixes
(without extensions) so both prefixes include the same directory. Concretely,
pass str(save_path.with_suffix("")) (or equivalent full-path prefix derived from
save_path) and str(Path(self.config.save_ckpt).with_suffix("")) to
symlink_prefix_files instead of save_path.stem and self.config.save_ckpt alone,
and make the same change where symlink_prefix_files is called in save_final to
ensure glob operates against the correct directory.

Comment on lines +151 to +171
def _cleanup_old_checkpoints(self) -> None:
"""Remove old checkpoints keeping only max_ckpt_keep most recent."""
if len(self._saved_checkpoints) <= self.config.max_ckpt_keep:
return

# Sort by modification time
checkpoint_files = [
f
for f in Path(".").glob(f"{self.config.save_ckpt}*.pt")
if not f.is_symlink() and f.name.startswith(self.config.save_ckpt)
]
checkpoint_files.sort(key=lambda x: x.stat().st_mtime)

# Remove oldest
while len(checkpoint_files) > self.config.max_ckpt_keep:
old_file = checkpoint_files.pop(0)
try:
old_file.unlink()
log.debug(f"Removed old checkpoint: {old_file}")
except OSError as e:
log.warning(f"Failed to remove old checkpoint {old_file}: {e}")

⚠️ Potential issue | 🟠 Major

Cleanup logic has a disconnect with _saved_checkpoints tracking.

The early return on Line 153 checks self._saved_checkpoints, but the actual cleanup on Lines 157–171 re-discovers checkpoints via glob, ignoring _saved_checkpoints entirely. This means:

  1. If the process restarts (fresh CheckpointManager), _saved_checkpoints is empty and cleanup never triggers even if many old files exist on disk.
  2. The glob f"{self.config.save_ckpt}*.pt" may match files not managed by this run.

Consider either consistently using _saved_checkpoints for cleanup or always using the glob-based discovery (without the early-return guard).
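
A rough sketch of the glob-only variant, using the names from the snippet quoted above (self.config.save_ckpt, self.config.max_ckpt_keep, log):

    def _cleanup_old_checkpoints(self) -> None:
        """Remove old checkpoints, keeping only the max_ckpt_keep most recent files on disk."""
        checkpoint_files = sorted(
            (
                f
                for f in Path(".").glob(f"{self.config.save_ckpt}*.pt")
                if not f.is_symlink()
            ),
            key=lambda f: f.stat().st_mtime,
        )
        num_to_remove = len(checkpoint_files) - self.config.max_ckpt_keep
        for old_file in checkpoint_files[: max(num_to_remove, 0)]:
            try:
                old_file.unlink()
                log.debug(f"Removed old checkpoint: {old_file}")
            except OSError as e:
                log.warning(f"Failed to remove old checkpoint {old_file}: {e}")

Because the list is rebuilt from disk on every call, cleanup behaves the same whether the CheckpointManager is fresh after a restart or has been running all along.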

🤖 Prompt for AI Agents
In `@deepmd/pt/train/checkpoint_manager.py` around lines 151 - 171, The early
return in _cleanup_old_checkpoints uses self._saved_checkpoints but the cleanup
logic re-discovers files with glob, causing stale or missing cleanup; remove the
early-return guard and instead compute checkpoint_files first (using
Path(".").glob(f"{self.config.save_ckpt}*.pt") with the existing is_symlink and
name.startswith filters), sort by st_mtime, then while len(checkpoint_files) >
self.config.max_ckpt_keep unlink oldest files and log; this makes cleanup
consistent whether the manager was just created or persistent files exist and
keeps the existing filtering based on self.config.save_ckpt and max_ckpt_keep.

Comment on lines +216 to +218
state_dict = checkpoint.get("model", checkpoint)
if "model" in checkpoint:
state_dict = checkpoint["model"]

⚠️ Potential issue | 🟡 Minor

Redundant state_dict extraction.

Lines 216–218 extract state_dict twice: Line 216 uses .get("model", checkpoint) as a fallback, and Lines 217–218 immediately overwrite it when the "model" key exists. The if block is redundant, since .get() already covers both cases. The same pattern appears in load_for_finetune (Lines 269–271).

Simplify to
-            state_dict = checkpoint.get("model", checkpoint)
-            if "model" in checkpoint:
-                state_dict = checkpoint["model"]
+            state_dict = checkpoint.get("model", checkpoint)
🤖 Prompt for AI Agents
In `@deepmd/pt/train/checkpoint_manager.py` around lines 216 - 218, The state_dict
extraction in checkpoint_manager (both in the block using state_dict =
checkpoint.get("model", checkpoint) followed by an immediate if "model" in
checkpoint override, and the identical pattern in load_for_finetune) is
redundant; replace the two-step logic with a single canonical extraction (e.g.,
use state_dict = checkpoint.get("model", checkpoint) or use an if/else that sets
state_dict = checkpoint["model"] else checkpoint) so the fallback is actually
used and the value is not overwritten immediately — update the occurrences
referencing state_dict and checkpoint in the methods load (the block around
state_dict = ...) and load_for_finetune accordingly.

Comment on lines +338 to +345
checkpoint_file = Path("checkpoint")
if checkpoint_file.exists():
try:
latest = Path(checkpoint_file.read_text().strip())
if latest.exists():
return latest
except Exception:
pass

⚠️ Potential issue | 🟡 Minor

Silent except: pass swallows all errors when reading the checkpoint file.

If the checkpoint file exists but is corrupted, has permission issues, or contains an invalid path, this is silently ignored. At minimum, log the exception so operators can diagnose recovery failures. This aligns with the Ruff S110 hint.

Suggested fix
-            except Exception:
-                pass
+            except Exception:
+                log.warning("Failed to read checkpoint file", exc_info=True)
🧰 Tools
🪛 Ruff (0.15.0)

[error] 344-345: try-except-pass detected, consider logging the exception

(S110)


[warning] 344-344: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
In `@deepmd/pt/train/checkpoint_manager.py` around lines 338 - 345, The code
currently uses a bare except in the checkpoint_file read block which silently
swallows errors; change the except to "except Exception as e" and log the
exception (e.g., logger.exception or logger.error with the exception info) when
reading checkpoint_file.read_text() or parsing latest, referencing the
checkpoint_file and latest variables/Path(...) usage in this block; ensure a
module logger (logging.getLogger(__name__)) is available or imported so
operators can see corruption/permission/invalid-path errors instead of them
being ignored.

Comment on lines +200 to +202
assert num_steps - computed_warmup_steps > 0 or computed_warmup_steps == 0, (
"Warm up steps must be less than total training steps!"
)

⚠️ Potential issue | 🟠 Major

Use ValueError instead of assert for input validation.

assert statements are stripped when Python runs with -O (optimized mode), so this check would silently disappear in production. Since this validates user-provided configuration, it should use ValueError to ensure it's always enforced.

Suggested fix
-        assert num_steps - computed_warmup_steps > 0 or computed_warmup_steps == 0, (
-            "Warm up steps must be less than total training steps!"
-        )
+        if computed_warmup_steps != 0 and num_steps - computed_warmup_steps <= 0:
+            raise ValueError(
+                "Warm up steps must be less than total training steps!"
+            )
🤖 Prompt for AI Agents
In `@deepmd/pt/train/config.py` around lines 200 - 202, Replace the assert-based
validation with an explicit exception raise so the check always runs: in the
code that currently uses `assert num_steps - computed_warmup_steps > 0 or
computed_warmup_steps == 0, ("Warm up steps must be less than total training
steps!")` (the variables `num_steps` and `computed_warmup_steps`), change it to
an if statement that raises `ValueError("Warm up steps must be less than total
training steps!")` when the condition fails; keep the same message and logic but
use `ValueError` for proper production-time validation.

Comment on lines +662 to +683

# Compute model probabilities
model_prob = np.zeros(len(self.model_keys), dtype=np.float32)
for ii, model_key in enumerate(self.model_keys):
# Get training data size for this model
if hasattr(self, "data_manager") and self.data_manager:
# Try to get from data_manager
pass
# Use uniform probability for now
model_prob[ii] = 1.0

model_prob = model_prob / np.sum(model_prob)
model_key_prob_map = dict(zip(self.model_keys, model_prob))

# Call share_params
self.wrapper.share_params(
shared_links,
resume=(self.is_restart and not self.finetune_update_stat)
or self.rank != 0,
model_key_prob_map=model_key_prob_map,
data_stat_protect=data_stat_protect_values[0],
)

⚠️ Potential issue | 🟠 Major

Multi-task model probability computation is a no-op — always produces uniform distribution.

Lines 665-671: The loop body unconditionally sets model_prob[ii] = 1.0 with a pass inside the if hasattr(...) branch, meaning the data-size-based probability logic was never implemented. The legacy trainer computes probabilities from actual data sizes. This will affect multi-task training balance.

🐛 Proposed fix sketch
         model_prob = np.zeros(len(self.model_keys), dtype=np.float32)
         for ii, model_key in enumerate(self.model_keys):
-            # Get training data size for this model
-            if hasattr(self, "data_manager") and self.data_manager:
-                # Try to get from data_manager
-                pass
-            # Use uniform probability for now
-            model_prob[ii] = 1.0
+            if hasattr(self, "data_manager") and self.data_manager:
+                loader = self.data_manager.training_loaders.get(model_key)
+                if loader is not None:
+                    model_prob[ii] = float(len(loader))
+                else:
+                    model_prob[ii] = 1.0
+            else:
+                model_prob[ii] = 1.0

Verify against the legacy trainer's probability computation to ensure parity.

🧰 Tools
🪛 Ruff (0.15.0)

[warning] 665-665: Loop control variable model_key not used within loop body

Rename unused model_key to _model_key

(B007)


[warning] 674-674: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

🤖 Prompt for AI Agents
In `@deepmd/pt/train/trainer.py` around lines 662 - 683, The loop that builds
model_prob for self.model_keys is a no-op and always assigns uniform
probability; replace the placeholder with logic that queries self.data_manager
for each model_key's training dataset size (e.g., via a method/property
analogous to the legacy trainer) to compute per-model weights proportional to
data sizes, falling back to uniform weights if data_manager is missing or a
model's size is zero; normalize model_prob to sum to 1 and then build
model_key_prob_map as before, ensuring compatibility with
self.wrapper.share_params and matching the legacy trainer's probability
computation semantics.

Comment on lines +685 to +694
def _init_distributed(self) -> None:
"""Initialize distributed training."""
if dist.is_available() and dist.is_initialized():
torch.cuda.set_device(LOCAL_RANK)
self.wrapper = DDP(
self.wrapper,
device_ids=[LOCAL_RANK],
find_unused_parameters=True,
output_device=LOCAL_RANK,
)

⚠️ Potential issue | 🟠 Major

torch.cuda.set_device called unconditionally — fails on CPU-only systems.

Line 688: When dist.is_initialized() is True but the backend is gloo (CPU-only), torch.cuda.set_device(LOCAL_RANK) will raise an error if no CUDA devices are available. Guard with a CUDA availability check.

🐛 Proposed fix
     def _init_distributed(self) -> None:
         """Initialize distributed training."""
         if dist.is_available() and dist.is_initialized():
-            torch.cuda.set_device(LOCAL_RANK)
+            if torch.cuda.is_available():
+                torch.cuda.set_device(LOCAL_RANK)
             self.wrapper = DDP(
                 self.wrapper,
                 device_ids=[LOCAL_RANK],
🤖 Prompt for AI Agents
In `@deepmd/pt/train/trainer.py` around lines 685 - 694, The _init_distributed
method currently calls torch.cuda.set_device(LOCAL_RANK) unconditionally which
fails on CPU-only backends; modify _init_distributed (and the DDP wrap) to first
check torch.cuda.is_available() before calling torch.cuda.set_device or passing
CUDA-specific arguments to DDP—i.e., only call torch.cuda.set_device(LOCAL_RANK)
and use device_ids=[LOCAL_RANK], output_device=LOCAL_RANK when
torch.cuda.is_available() is True; otherwise skip set_device and instantiate
DDP(self.wrapper, find_unused_parameters=True) (no CUDA-specific
device_ids/output_device) so it works with gloo/CPU-only setups.

Comment on lines +1118 to +1121
def _finalize_training(self, total_time: float, timed_steps: int) -> None:
"""Finalize training and cleanup."""
# Save final checkpoint
self._save_checkpoint(self.config.num_steps - 1, 0.0)

⚠️ Potential issue | 🟡 Minor

Final checkpoint saved with lr=0.0 instead of actual learning rate.

Line 1121: self._save_checkpoint(self.config.num_steps - 1, 0.0) hardcodes lr=0.0. This misrepresents the actual learning rate at the end of training and could affect resume behavior. Additionally, if the last step already triggered a checkpoint save (line 915-916), this creates a redundant save.

🐛 Proposed fix

Track the last learning rate from the training loop and use it here, and guard against double-saving:

     def _finalize_training(self, total_time: float, timed_steps: int) -> None:
         """Finalize training and cleanup."""
-        # Save final checkpoint
-        self._save_checkpoint(self.config.num_steps - 1, 0.0)
+        # Save final checkpoint if not already saved
+        last_step = self.config.num_steps
+        if last_step % self.config.checkpoint.save_freq != 0:
+            self._save_checkpoint(self.config.num_steps - 1, self.lr_schedule.value(self.config.num_steps - 1))
🤖 Prompt for AI Agents
In `@deepmd/pt/train/trainer.py` around lines 1118 - 1121, The final checkpoint
currently calls self._save_checkpoint(self.config.num_steps - 1, 0.0) which
hardcodes lr=0.0 and may double-save; change _finalize_training to use the
actual last learning rate tracked by the training loop (e.g., maintain
self._last_lr updated each step in the training loop) and call
self._save_checkpoint(self.config.num_steps - 1, self._last_lr) instead, and add
a guard in _finalize_training (or check in _save_checkpoint) to skip saving if a
checkpoint for step self.config.num_steps - 1 was already written (or if the
stored last checkpoint step equals config.num_steps - 1) to avoid redundant
saves.

Comment on lines +226 to +238
def _compute_prefactors(self, step: int) -> tuple[float, float]:
"""Compute energy and force prefactors for current step."""
start_pref_e = self.opt_param["kf_start_pref_e"]
limit_pref_e = self.opt_param["kf_limit_pref_e"]
start_pref_f = self.opt_param["kf_start_pref_f"]
limit_pref_f = self.opt_param["kf_limit_pref_f"]

ratio = step / self.num_steps

pref_e = start_pref_e * (limit_pref_e / start_pref_e) ** ratio
pref_f = start_pref_f * (limit_pref_f / start_pref_f) ** ratio

return pref_e, pref_f

⚠️ Potential issue | 🔴 Critical

Division by zero when start_pref_e or start_pref_f is 0.

Line 235: (limit_pref_e / start_pref_e) ** ratio will raise ZeroDivisionError if start_pref_e is 0, and similarly for start_pref_f on line 236. The legacy trainer guards against zero prefactors.

🐛 Proposed fix
     def _compute_prefactors(self, step: int) -> tuple[float, float]:
         """Compute energy and force prefactors for current step."""
         start_pref_e = self.opt_param["kf_start_pref_e"]
         limit_pref_e = self.opt_param["kf_limit_pref_e"]
         start_pref_f = self.opt_param["kf_start_pref_f"]
         limit_pref_f = self.opt_param["kf_limit_pref_f"]

         ratio = step / self.num_steps

-        pref_e = start_pref_e * (limit_pref_e / start_pref_e) ** ratio
-        pref_f = start_pref_f * (limit_pref_f / start_pref_f) ** ratio
+        if start_pref_e == 0.0:
+            pref_e = limit_pref_e
+        else:
+            pref_e = start_pref_e * (limit_pref_e / start_pref_e) ** ratio
+        if start_pref_f == 0.0:
+            pref_f = limit_pref_f
+        else:
+            pref_f = start_pref_f * (limit_pref_f / start_pref_f) ** ratio

         return pref_e, pref_f
🤖 Prompt for AI Agents
In `@deepmd/pt/train/training_loop.py` around lines 226 - 238, The
_compute_prefactors function can divide by zero when start_pref_e or
start_pref_f is 0; update it to guard against zeros by computing ratio = step /
max(1, self.num_steps) and for each prefactor use geometric interpolation as now
when start_pref_* != 0, but fall back to linear interpolation pref = start +
(limit - start) * ratio when start_pref_* == 0 (apply this logic for
start_pref_e/limit_pref_e and start_pref_f/limit_pref_f in _compute_prefactors).

Comment on lines +248 to +252
"""Execute LKF training step."""
# Compute prefactors
step = self.optimizer.state.get("step", 0)
pref_e, pref_f = self._compute_prefactors(step)


⚠️ Potential issue | 🔴 Critical

Step retrieval from self.optimizer.state may not work as expected.

Line 250: self.optimizer.state.get("step", 0) does not read a global step, because optimizer.state is a defaultdict keyed by parameter tensors, not a flat dict with a "step" key. This will always return 0 for standard PyTorch optimizers and likely also for the LKF optimizer, causing the prefactor schedule to never advance.

The step should be passed into the step() method or tracked externally. Consider accepting step as a parameter, since the caller (Trainer) already knows the current training step.

🐛 Proposed approach: pass step as argument

Update BaseTrainingLoop.step to accept step: int or have LKFEnergyTrainingLoop track the step count internally:

 class BaseTrainingLoop(ABC):
     @abstractmethod
     def step(
         self,
         input_dict: dict[str, torch.Tensor],
         label_dict: dict[str, torch.Tensor],
         cur_lr: float,
         pref_lr: float,
         task_key: str = "Default",
+        global_step: int = 0,
     ) -> TrainingStepResult:
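
Alternatively, a standalone sketch of the internal-counter option (illustrative only, not tied to the actual class hierarchy):

class StepCounter:
    """Minimal counter a training loop can own to drive step-based schedules."""

    def __init__(self) -> None:
        self._step = 0

    def next(self) -> int:
        """Return the current step index, then advance."""
        current = self._step
        self._step += 1
        return current

LKFEnergyTrainingLoop could hold one of these and pass counter.next() into _compute_prefactors, instead of probing optimizer.state.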
🤖 Prompt for AI Agents
In `@deepmd/pt/train/training_loop.py` around lines 248 - 252, The code
incorrectly reads step from self.optimizer.state.get("step", 0) (optimizer.state
is per-parameter, so this always yields 0) causing _compute_prefactors to never
advance; fix by passing the current training step into the training step method
instead: update the signature of BaseTrainingLoop.step (and any overrides like
LKFEnergyTrainingLoop.step) to accept a step: int parameter, propagate the
caller's step (Trainer) into that call, and replace uses of
self.optimizer.state.get("step", 0) with the passed-in step when calling
_compute_prefactors.
