refactor(pt): refactor training code #5216
base: master
Conversation
Summary of Changes
Hello @iProzd, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly refactors the PyTorch training code into a highly modular and extensible system. The primary goal is to improve the organization and maintainability of the training pipeline by breaking down monolithic logic into specialized components for data management, checkpointing, logging, and optimizer handling. The new design also introduces a flexible hook system for custom callbacks and ensures backward compatibility with the previous trainer implementation, facilitating a smooth transition to the new architecture.
Highlights
Changelog
|
| latest = Path(checkpoint_file.read_text().strip()) | ||
| if latest.exists(): | ||
| return latest | ||
| except Exception: |
Check notice (Code scanning / CodeQL): Empty except
|
|
||
| def __next__(self) -> dict[str, Any]: | ||
| """Get next batch.""" | ||
| ... |
Check notice (Code scanning / CodeQL): Statement has no effect
|
|
||
| def add_data_requirement(self, requirement: Any) -> None: | ||
| """Add data requirements for labels.""" | ||
| ... |
Check notice (Code scanning / CodeQL): Statement has no effect
|
|
||
| def preload_and_modify_all_data_torch(self) -> None: | ||
| """Preload and apply modifiers to data.""" | ||
| ... |
Check notice (Code scanning / CodeQL): Statement has no effect
|
|
||
| def print_summary(self, name: str, weights: Any = None) -> None: | ||
| """Print dataset summary.""" | ||
| ... |
Check notice (Code scanning / CodeQL): Statement has no effect
| @property | ||
| def systems(self) -> list[Any]: | ||
| """Get list of systems/datasets.""" | ||
| ... |
Check notice (Code scanning / CodeQL): Statement has no effect
| ) | ||
|
|
||
| # Compute statistics | ||
| finetune_has_new_type_key = ( |
Check notice (Code scanning / CodeQL): Unused local variable
Code Review
This pull request introduces a significant and well-executed refactoring of the PyTorch training code. The new modular architecture, which separates concerns into components like DataManager, CheckpointManager, HookManager, and TrainingLoopFactory, is a major improvement for maintainability, extensibility, and clarity. The use of modern design patterns is commendable. My review has identified a few issues, including a potential bug in the LKF training loop, an incomplete implementation for multi-task parameter sharing, and some opportunities for code cleanup and performance optimization. Overall, this is an excellent contribution to the codebase.
| _, loss, more_loss = module.loss[task_key]( | ||
| {}, | ||
| fake_model, | ||
| label_dict, | ||
| natoms, | ||
| learning_rate=pref_lr, | ||
| ) |
The loss function module.loss[task_key] is called with an empty dictionary {} as the first argument for input_dict. The loss function's signature expects input_dict, and it might require information from it (e.g., flags indicating which properties are present). Passing an empty dictionary is likely a bug and could lead to incorrect loss calculation or runtime errors.
| _, loss, more_loss = module.loss[task_key]( | |
| {}, | |
| fake_model, | |
| label_dict, | |
| natoms, | |
| learning_rate=pref_lr, | |
| ) | |
| _, loss, more_loss = module.loss[task_key]( | |
| input_dict, | |
| fake_model, | |
| label_dict, | |
| natoms, | |
| learning_rate=pref_lr, | |
| ) |
| if use_legacy: | ||
| trainer = training.Trainer( | ||
| config, | ||
| train_data, | ||
| stat_file_path=stat_file_path, | ||
| validation_data=validation_data, | ||
| init_model=init_model, | ||
| restart_model=restart_model, | ||
| finetune_model=finetune_model, | ||
| force_load=force_load, | ||
| shared_links=shared_links, | ||
| finetune_links=finetune_links, | ||
| init_frz_model=init_frz_model, | ||
| ) | ||
| else: | ||
| trainer = NewTrainer( | ||
| config, | ||
| train_data, | ||
| stat_file_path=stat_file_path, | ||
| validation_data=validation_data, | ||
| init_model=init_model, | ||
| restart_model=restart_model, | ||
| finetune_model=finetune_model, | ||
| force_load=force_load, | ||
| shared_links=shared_links, | ||
| finetune_links=finetune_links, | ||
| init_frz_model=init_frz_model, | ||
| ) |
This block of code for creating the trainer instance is duplicated for the use_legacy and else branches. This can be refactored to reduce code duplication and improve maintainability by selecting the class first and then instantiating it once.
trainer_class = training.Trainer if use_legacy else NewTrainer
trainer = trainer_class(
config,
train_data,
stat_file_path=stat_file_path,
validation_data=validation_data,
init_model=init_model,
restart_model=restart_model,
finetune_model=finetune_model,
force_load=force_load,
shared_links=shared_links,
finetune_links=finetune_links,
init_frz_model=init_frz_model,
)| else 0, | ||
| drop_last=False, | ||
| collate_fn=lambda batch: batch, # prevent extra conversion | ||
| pin_memory=False, # Batch processor handles device transfer |
For optimal performance when training on a GPU, pin_memory should be set to True. This allows the DataLoader to place tensors in pinned (page-locked) memory, which enables faster, asynchronous data transfer to the GPU when using non_blocking=True in the .to() call, as is done in BatchProcessor.
| pin_memory=False, # Batch processor handles device transfer | |
| pin_memory=True, # Enable asynchronous data transfer |
| model_prob = np.zeros(len(self.model_keys), dtype=np.float32) | ||
| for ii, model_key in enumerate(self.model_keys): | ||
| # Get training data size for this model | ||
| if hasattr(self, "data_manager") and self.data_manager: | ||
| # Try to get from data_manager | ||
| pass | ||
| # Use uniform probability for now | ||
| model_prob[ii] = 1.0 |
The logic for computing model_prob for multi-task parameter sharing appears to be incomplete, as it currently falls back to a uniform probability. For correct statistical merging of shared parameters, these probabilities should ideally be based on the size of the respective datasets for each task. The pass statement and the comment "Use uniform probability for now" suggest this is a placeholder.
| model_prob = np.zeros(len(self.model_keys), dtype=np.float32) | |
| for ii, model_key in enumerate(self.model_keys): | |
| # Get training data size for this model | |
| if hasattr(self, "data_manager") and self.data_manager: | |
| # Try to get from data_manager | |
| pass | |
| # Use uniform probability for now | |
| model_prob[ii] = 1.0 | |
| model_prob = np.zeros(len(self.model_keys), dtype=np.float32) | |
| for ii, model_key in enumerate(self.model_keys): | |
| # Get training data size for this model | |
| if hasattr(self, "data_manager") and self.data_manager: | |
| # Base probability on the number of batches in each training loader | |
| model_prob[ii] = float(len(self.data_manager.training_loaders[model_key])) | |
| else: | |
| # Fallback to uniform probability if data_manager is not yet available | |
| model_prob[ii] = 1.0 |
Pull request overview
This PR introduces a new modular PyTorch training system (Trainer + component managers/factories) and wires it into the dp --pt train entrypoint, with accompanying unit and CLI-level tests.
Changes:
- Added a new modular deepmd.pt.train.trainer.Trainer and supporting components (data manager/loader abstraction, optimizer factory, training loop strategies, checkpoint manager, hooks, logger, typed config).
- Updated the PT entrypoint to optionally use the new trainer (use_legacy=False by default).
- Added a new PT test suite validating config parsing, factories/hooks/logger utilities, and basic end-to-end trainer construction/data access.
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 15 comments.
Show a summary per file
| File | Description |
|---|---|
| source/tests/pt/test_new_training.py | Adds unit + end-to-end CLI tests for the new modular trainer components. |
| deepmd/pt/train/training_loop.py | Introduces optimizer-specific training loop implementations and a factory. |
| deepmd/pt/train/trainer.py | New modular Trainer orchestrating training, validation, hooks, logging, and checkpointing. |
| deepmd/pt/train/optimizer_factory.py | Strategy-based optimizer/scheduler factory. |
| deepmd/pt/train/logger.py | New training logger + loss accumulator implementation. |
| deepmd/pt/train/hooks.py | Hook system (priority-based) + built-in hooks (TensorBoard, timing, early stopping). |
| deepmd/pt/train/data_manager.py | Unifies single/multi-task batch iteration and abstracts data loader backends. |
| deepmd/pt/train/data_loader.py | Defines abstract loader interface and DpLoaderSet adapter + batch processing. |
| deepmd/pt/train/config.py | Adds typed config dataclasses and parsing/validation logic. |
| deepmd/pt/train/checkpoint_manager.py | Centralizes checkpoint save/load and retention management. |
| deepmd/pt/train/__init__.py | Exposes the new modular training API while keeping the legacy trainer available. |
| deepmd/pt/entrypoints/main.py | Adds new trainer selection switch (use_legacy) and instantiates the new trainer. |
| module.train_infos["step"] = step | ||
|
|
||
| # Prepare checkpoint path | ||
| save_path = Path(self.config.save_ckpt + f"-{step + 1}.pt") |
Copilot AI · Feb 11, 2026
Checkpoint filenames are off-by-one: save() receives step (already 1-based in the new Trainer), but the path suffix uses step + 1. This will create checkpoints like model.ckpt-2.pt for the first saved step and makes the saved step metadata inconsistent with the filename. Consider making save() treat step as the display step and remove the + 1, or pass a 0-based step consistently from the caller.
| save_path = Path(self.config.save_ckpt + f"-{step + 1}.pt") | |
| save_path = Path(self.config.save_ckpt + f"-{step}.pt") |
| line += f" {valid_res.get(k, 0.0):11.2e} {train_res.get(k, 0.0):11.2e}" | ||
| else: | ||
| for k in sorted(train_res.keys()): | ||
| line += f" {train_res.get(k, 0.0):11.2e}" | ||
| else: | ||
| train_keys = sorted(train_results.keys()) | ||
| if valid_results and isinstance(valid_results, dict): | ||
| for k in train_keys: | ||
| line += f" {valid_results.get(k, 0.0):11.2e} {train_results.get(k, 0.0):11.2e}" | ||
| else: | ||
| for k in train_keys: | ||
| line += f" {train_results.get(k, 0.0):11.2e}" |
Copilot AI · Feb 11, 2026
_print_to_file() formats metric values with float format specs (:11.2e) but pulls them via .get(k, 0.0). In practice, loss modules return scalar torch.Tensors for metrics (e.g., rmse_*), which will raise TypeError: unsupported format string passed to Tensor.__format__. Convert scalar tensors to Python floats before formatting (e.g., float(v) / v.item()), and prefer nan rather than 0.0 for missing validation values to match the header note and legacy behavior.
| line += f" {valid_res.get(k, 0.0):11.2e} {train_res.get(k, 0.0):11.2e}" | |
| else: | |
| for k in sorted(train_res.keys()): | |
| line += f" {train_res.get(k, 0.0):11.2e}" | |
| else: | |
| train_keys = sorted(train_results.keys()) | |
| if valid_results and isinstance(valid_results, dict): | |
| for k in train_keys: | |
| line += f" {valid_results.get(k, 0.0):11.2e} {train_results.get(k, 0.0):11.2e}" | |
| else: | |
| for k in train_keys: | |
| line += f" {train_results.get(k, 0.0):11.2e}" | |
| valid_val = valid_res.get(k, float("nan")) | |
| try: | |
| valid_val = float(valid_val) | |
| except (TypeError, ValueError): | |
| valid_val = float("nan") | |
| train_val = train_res.get(k, 0.0) | |
| try: | |
| train_val = float(train_val) | |
| except (TypeError, ValueError): | |
| train_val = float("nan") | |
| line += f" {valid_val:11.2e} {train_val:11.2e}" | |
| else: | |
| for k in sorted(train_res.keys()): | |
| train_val = train_res.get(k, 0.0) | |
| try: | |
| train_val = float(train_val) | |
| except (TypeError, ValueError): | |
| train_val = float("nan") | |
| line += f" {train_val:11.2e}" | |
| else: | |
| train_keys = sorted(train_results.keys()) | |
| if valid_results and isinstance(valid_results, dict): | |
| for k in train_keys: | |
| valid_val = valid_results.get(k, float("nan")) | |
| try: | |
| valid_val = float(valid_val) | |
| except (TypeError, ValueError): | |
| valid_val = float("nan") | |
| train_val = train_results.get(k, 0.0) | |
| try: | |
| train_val = float(train_val) | |
| except (TypeError, ValueError): | |
| train_val = float("nan") | |
| line += f" {valid_val:11.2e} {train_val:11.2e}" | |
| else: | |
| for k in train_keys: | |
| train_val = train_results.get(k, 0.0) | |
| try: | |
| train_val = float(train_val) | |
| except (TypeError, ValueError): | |
| train_val = float("nan") | |
| line += f" {train_val:11.2e}" |
| for key, value in more_loss.items(): | ||
| if "l2_" in key: | ||
| continue | ||
| if not isinstance(value, (int, float)): | ||
| continue |
Copilot AI · Feb 11, 2026
LossAccumulator.update() ignores non-(int, float) values, but training losses/metrics are typically scalar torch.Tensors (see TaskLoss.display_if_exist() and loss implementations). This will drop all metrics when disp_avg is enabled, resulting in empty logs. Accept scalar tensors here and convert them to floats for accumulation.
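A minimal sketch of an update() that also accepts scalar tensors, as suggested above; the accumulator's internal fields shown here are assumptions for illustration, not the PR's actual implementation.

```python
import torch


class LossAccumulatorSketch:
    """Illustrative accumulator; field names are assumptions."""

    def __init__(self) -> None:
        self.sums: dict[str, float] = {}
        self.counts: dict[str, int] = {}

    def update(self, more_loss: dict) -> None:
        for key, value in more_loss.items():
            if "l2_" in key:
                continue
            # Accept plain numbers as well as 0-dim / single-element tensors.
            if isinstance(value, torch.Tensor):
                if value.numel() != 1:
                    continue
                value = float(value)
            elif not isinstance(value, (int, float)):
                continue
            self.sums[key] = self.sums.get(key, 0.0) + value
            self.counts[key] = self.counts.get(key, 0) + 1
```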
| train_results[task_key] = { | ||
| k: v for k, v in step_result.more_loss.items() if "l2_" not in k | ||
| } | ||
| else: | ||
| train_results = { | ||
| k: v for k, v in step_result.more_loss.items() if "l2_" not in k | ||
| } |
Copilot AI · Feb 11, 2026
When disp_avg is disabled, train_results is built directly from step_result.more_loss values without converting scalar torch.Tensors to floats. TrainingLogger._print_to_file() uses float formatting and will fail on tensors. Convert scalar tensors to Python floats when constructing train_results (and similarly for validation results) to keep logging working and consistent with the legacy trainer.
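A hedged sketch of the coercion such a fix could use; the helper name and the commented call site are illustrative, not the PR's code.

```python
import torch


def to_display_dict(more_loss: dict) -> dict[str, float]:
    """Drop l2_* entries and coerce scalar tensors to Python floats."""
    out: dict[str, float] = {}
    for k, v in more_loss.items():
        if "l2_" in k:
            continue
        if isinstance(v, torch.Tensor):
            if v.numel() != 1:
                continue
            v = float(v)
        if isinstance(v, (int, float)):
            out[k] = float(v)
    return out


# e.g. train_results[task_key] = to_display_dict(step_result.more_loss)
```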
| """Run validation on all tasks.""" | ||
| self.hook_manager.on_validation_begin(0, {}) | ||
|
|
Copilot AI · Feb 11, 2026
_run_validation() always calls on_validation_begin(0, ...) with step=0, regardless of the current training step. Hooks that rely on step indexing (e.g., TensorBoard/early stopping) will record incorrect x-axes. Pass the current step through from _log_and_validate() and use it here.
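A minimal sketch of threading the current step through to the validation hooks; the method and attribute names are modeled on this review and are otherwise assumptions about the new Trainer.

```python
class TrainerSketch:
    """Illustrative fragment; attributes are assumed to exist on the real Trainer."""

    def _log_and_validate(self, step: int) -> None:
        # Pass the real training step instead of a hard-coded 0.
        valid_results = self._run_validation(step)
        self.logger.log_step(step, valid_results)

    def _run_validation(self, step: int) -> dict:
        """Run validation on all tasks at the given step."""
        self.hook_manager.on_validation_begin(step, {})
        results = {key: self._validate_task(key) for key in self.model_keys}
        self.hook_manager.on_validation_end(step, results)
        return results
```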
| if "l2_" not in k and isinstance(v, (int, float)): | ||
| results[k] = results.get(k, 0.0) + v * natoms | ||
|
|
Copilot AI · Feb 11, 2026
Validation aggregation drops tensor metrics: _validate_task() only accumulates values that are (int, float), but loss modules typically return scalar torch.Tensors in more_loss. This will make valid_results empty even when validation is available. Include scalar tensors (convert via float(v) / v.item()) in the aggregation.
| if "l2_" not in k and isinstance(v, (int, float)): | |
| results[k] = results.get(k, 0.0) + v * natoms | |
| if "l2_" in k: | |
| continue | |
| # Accept numeric scalars and scalar tensors | |
| if isinstance(v, (int, float)): | |
| val = float(v) | |
| elif isinstance(v, torch.Tensor) and v.numel() == 1: | |
| # Convert scalar tensor to Python float | |
| val = float(v) | |
| else: | |
| continue | |
| results[k] = results.get(k, 0.0) + val * natoms |
| shared_links: dict[str, Any] | None = None, | ||
| finetune_links: dict[str, Any] | None = None, | ||
| use_legacy: bool = False, | ||
| ) -> training.Trainer: |
Copilot AI · Feb 11, 2026
The return type annotation of get_trainer() is training.Trainer, but the function can now return the new modular deepmd.pt.train.trainer.Trainer when use_legacy=False. Update the annotation to a union/protocol (or a common base type) to keep typing accurate for downstream users.
| ) -> training.Trainer: | |
| ) -> training.Trainer | NewTrainer: |
| # Compute statistics | ||
| finetune_has_new_type_key = ( | ||
| self.finetune_links[key].get_has_new_type() | ||
| if self.is_finetune and self.finetune_links | ||
| else False | ||
| ) | ||
|
|
Copilot AI · Feb 11, 2026
Variable finetune_has_new_type_key is not used.
| # Compute statistics | |
| finetune_has_new_type_key = ( | |
| self.finetune_links[key].get_has_new_type() | |
| if self.is_finetune and self.finetune_links | |
| else False | |
| ) |
| This allows gradual migration from DpLoaderSet to new implementations. | ||
| """ | ||
|
|
||
| def __iter__(self) -> Iterator[dict[str, Any]]: |
Copilot AI · Feb 11, 2026
Iter method of iterator DataLoaderInterface does not return self.
| def __iter__(self) -> Iterator[dict[str, Any]]: | |
| def __iter__(self) -> DataLoaderInterface: |
| except Exception: | ||
| pass |
Copilot AI · Feb 11, 2026
'except' clause does nothing but pass and there is no explanatory comment.
| except Exception: | |
| pass | |
| except Exception as e: | |
| log.warning( | |
| "Failed to read or parse 'checkpoint' file; " | |
| "falling back to scanning for latest checkpoint: %s", | |
| e, | |
| ) |
📝 Walkthrough
This PR introduces a comprehensive modular refactoring of DeepMD's PyTorch training system, adding dedicated components for configuration validation, data management, checkpointing, optimization strategies, training hooks, and logging. The main entry point now supports switching between legacy and new trainer implementations via a use_legacy flag.
Changes
Sequence Diagram
sequenceDiagram
participant Main as Main Trainer
participant DM as DataManager
participant OF as OptimizerFactory
participant CM as CheckpointManager
participant TL as TrainingLoop
participant HM as HookManager
participant Logger as TrainingLogger
Main->>DM: get_train_batch(task_key)
DM-->>Main: (inputs, labels, log_dict)
Main->>OF: create_optimizer(params, config, lr_config)
OF-->>Main: optimizer
Main->>OF: create_scheduler(optimizer, warmup_steps, ...)
OF-->>Main: lr_scheduler
Main->>CM: load(checkpoint_path)
CM-->>Main: start_step, metadata
Main->>HM: on_train_begin()
HM-->>Main: executed all hooks
loop Training Steps
Main->>TL: step(inputs, labels, cur_lr, task_key)
TL-->>Main: TrainingStepResult(loss, predictions, lr)
Main->>Logger: log_step(step, losses, lr)
Logger-->>Main: logged to file/console
Main->>HM: on_step_end(step, logs)
HM-->>Main: executed all hooks
alt Validation Step
Main->>DM: get_valid_batch(task_key)
DM-->>Main: (valid_inputs, valid_labels, log_dict)
Main->>HM: on_validation_begin(step)
HM-->>Main: executed hooks
Main->>HM: on_validation_end(step, val_logs)
HM-->>Main: executed hooks
end
alt Checkpoint Step
Main->>CM: save(step, wrapper, optimizer, lr)
CM-->>Main: checkpoint_path
Main->>HM: on_save_checkpoint(step, path)
HM-->>Main: executed hooks
end
end
Main->>Logger: log_summary()
Logger-->>Main: logged summary
Main->>HM: on_train_end()
HM-->>Main: executed all hooks
Main->>CM: save_final(step, wrapper)
CM-->>Main: final_checkpoint_path
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 inconclusive)
✅ Passed checks (2 passed)
Actionable comments posted: 19
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt/entrypoints/main.py (1)
101-111: ⚠️ Potential issue | 🟡 Minor
The return type annotation -> training.Trainer is inaccurate when use_legacy=False. When use_legacy=False (the default), the function returns a NewTrainer (from deepmd.pt.train.trainer), not a training.Trainer (from deepmd.pt.train.training). These are distinct classes. Consider using a Union type or a common protocol/base class.
Proposed fix
-) -> training.Trainer:
+) -> training.Trainer | NewTrainer:
Or better, define a common Protocol that both trainers implement (e.g., requiring run() and model attributes).
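A possible shape for such a protocol is sketched below; the run()/model surface comes from this comment, everything else is an assumption about what both trainers expose.

```python
from typing import Protocol

import torch


class TrainerProtocol(Protocol):
    """Shared surface of the legacy trainer and the new modular Trainer (assumed)."""

    model: torch.nn.Module

    def run(self) -> None:
        """Execute the full training loop."""
        ...


# get_trainer() could then be annotated as `-> TrainerProtocol`, covering
# both training.Trainer and the new modular Trainer without a union type.
```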
🤖 Fix all issues with AI agents
In `@deepmd/pt/train/checkpoint_manager.py`:
- Around line 151-171: The early return in _cleanup_old_checkpoints uses
self._saved_checkpoints but the cleanup logic re-discovers files with glob,
causing stale or missing cleanup; remove the early-return guard and instead
compute checkpoint_files first (using
Path(".").glob(f"{self.config.save_ckpt}*.pt") with the existing is_symlink and
name.startswith filters), sort by st_mtime, then while len(checkpoint_files) >
self.config.max_ckpt_keep unlink oldest files and log; this makes cleanup
consistent whether the manager was just created or persistent files exist and
keeps the existing filtering based on self.config.save_ckpt and max_ckpt_keep.
- Around line 138-139: The call to symlink_prefix_files uses save_path.stem
which strips directory components, causing glob(old_prefix + ".*") to miss files
when self.config.save_ckpt includes a directory; replace the stem-based prefix
with full-path prefixes (without extensions) so both prefixes include the same
directory. Concretely, pass str(save_path.with_suffix("")) (or equivalent
full-path prefix derived from save_path) and
str(Path(self.config.save_ckpt).with_suffix("")) to symlink_prefix_files instead
of save_path.stem and self.config.save_ckpt alone, and make the same change
where symlink_prefix_files is called in save_final to ensure glob operates
against the correct directory.
- Around line 216-218: The state_dict extraction in checkpoint_manager (both in
the block using state_dict = checkpoint.get("model", checkpoint) followed by an
immediate if "model" in checkpoint override, and the identical pattern in
load_for_finetune) is redundant; replace the two-step logic with a single
canonical extraction (e.g., use state_dict = checkpoint.get("model", checkpoint)
or use an if/else that sets state_dict = checkpoint["model"] else checkpoint) so
the fallback is actually used and the value is not overwritten immediately —
update the occurrences referencing state_dict and checkpoint in the methods load
(the block around state_dict = ...) and load_for_finetune accordingly.
- Around line 338-345: The code currently uses a bare except in the
checkpoint_file read block which silently swallows errors; change the except to
"except Exception as e" and log the exception (e.g., logger.exception or
logger.error with the exception info) when reading checkpoint_file.read_text()
or parsing latest, referencing the checkpoint_file and latest
variables/Path(...) usage in this block; ensure a module logger
(logging.getLogger(__name__)) is available or imported so operators can see
corruption/permission/invalid-path errors instead of them being ignored.
In `@deepmd/pt/train/config.py`:
- Around line 206-214: The code currently raises a ValueError when some
model_keys are missing from optim_dict (in the optim_dict block using
OptimizerConfig.from_dict) but silently falls back to lr_params when keys are
missing from learning_rate_dict; make the behavior consistent by validating
learning_rate_dict the same way: after building learning_rate_dict for keys in
model_keys, compute missing_keys = [k for k in model_keys if k not in
learning_rate_dict] and if missing_keys raise a ValueError (or optionally emit a
warning per project policy) instead of silently using lr_params; update the code
near the learning_rate_dict handling (the block that references lr_params and
learning_rate_dict) to perform this check.
- Around line 200-202: Replace the assert-based validation with an explicit
exception raise so the check always runs: in the code that currently uses
`assert num_steps - computed_warmup_steps > 0 or computed_warmup_steps == 0,
("Warm up steps must be less than total training steps!")` (the variables
`num_steps` and `computed_warmup_steps`), change it to an if statement that
raises `ValueError("Warm up steps must be less than total training steps!")`
when the condition fails; keep the same message and logic but use `ValueError`
for proper production-time validation.
In `@deepmd/pt/train/data_manager.py`:
- Around line 193-216: get_valid_numb_batch currently returns 1 when no
validation loader exists which causes an unnecessary validation iteration;
update get_valid_numb_batch (and its use of _get_valid_loader) to return 0 when
loader is None and also return 0 when len(loader) cannot be determined (replace
the fallback 1 with 0) so the trainer will skip validation loops for missing
loaders; ensure callers like get_valid_batch still handle zero batches
correctly.
- Around line 144-162: The code calls next(...) directly on AbstractDataLoader
instances (in get_train_batch), but AbstractDataLoader only defines __iter__, so
change DataManager to store iterators (e.g., create a training_iterators
structure by calling iter() on each loader when loaders are initialized) and
then call next(...) on those iterators (use
next(self.training_iterators[task_key]) or next(self.training_iterators) inside
get_train_batch); this keeps DpLoaderSetAdapter working and avoids requiring
__next__ on the ABC.
In `@deepmd/pt/train/optimizer_factory.py`:
- Around line 155-162: The post-warmup branch divides by lr_schedule.start_lr
which can be zero; add a guard to check if lr_schedule.start_lr == 0 and raise a
clear ValueError (or return a documented fallback) instead of performing the
division. Locate the division using lr_schedule.value(...)/lr_schedule.start_lr
in the function (the post-warmup branch shown) and replace it with a check that
either raises ValueError("lr_schedule.start_lr is zero; please set a non-zero
start_lr") or handles a defined fallback, and apply the same fix to the other
occurrence around the 200-208 region so both divisions are protected. Ensure the
error message references lr_schedule.start_lr so users know which config to fix.
- Around line 151-164: Extract the duplicated warmup_linear closure into a
single shared factory function named _make_warmup_lr_lambda(warmup_steps,
warmup_start_factor, lr_schedule, start_step=0) that returns the warmup_linear
callable (preserving the same logic and docstring), then replace the inline
warmup_linear closures in AdamStrategy.create_scheduler,
AdamWStrategy.create_scheduler, AdaMuonStrategy.create_scheduler and
HybridMuonStrategy.create_scheduler with calls to
torch.optim.lr_scheduler.LambdaLR(..., lr_lambda=_make_warmup_lr_lambda(...)) so
each scheduler construction becomes a one-liner using the new factory.
- Around line 226-231: The LKFOptimizer instantiation in optimizer_factory.py
currently hardcodes kalman_lambda=0.98 and kalman_nu=0.99870; update this by
adding configurable fields (e.g., kalman_lambda and kalman_nu) to
OptimizerConfig and use those fields when constructing LKFOptimizer (or, if you
prefer not to change the public config, extract the magic numbers into named
constants with docstrings and reference those constants in the LKFOptimizer(...)
call). Modify OptimizerConfig (config.py) to include the new fields with
sensible defaults and update any config parsing/validation, then replace the
literal values in the LKFOptimizer(...) call to reference
OptimizerConfig.kalman_lambda and OptimizerConfig.kalman_nu (or the new named
constants) so the parameters are documented and configurable.
In `@deepmd/pt/train/trainer.py`:
- Around line 237-240: The multitask detection is inconsistent: update the
assignment to self.is_multitask in trainer.py to match the rest of the codebase
by setting it solely based on the presence of "model_dict" in model_params
(i.e., self.is_multitask = "model_dict" in model_params) rather than checking
len(self.model_keys) > 1; adjust any dependent logic if necessary to rely on
self.model_keys for key counts but keep self.is_multitask as the simple presence
check.
- Around line 685-694: The _init_distributed method currently calls
torch.cuda.set_device(LOCAL_RANK) unconditionally which fails on CPU-only
backends; modify _init_distributed (and the DDP wrap) to first check
torch.cuda.is_available() before calling torch.cuda.set_device or passing
CUDA-specific arguments to DDP—i.e., only call torch.cuda.set_device(LOCAL_RANK)
and use device_ids=[LOCAL_RANK], output_device=LOCAL_RANK when
torch.cuda.is_available() is True; otherwise skip set_device and instantiate
DDP(self.wrapper, find_unused_parameters=True) (no CUDA-specific
device_ids/output_device) so it works with gloo/CPU-only setups.
- Around line 327-373: The _create_loss function mutates the input loss_params
dict (e.g., adding "starter_learning_rate", "ntypes", "tensor_size", etc.), so
make a local deep copy at the top of _create_loss (e.g., copy =
deepcopy(loss_params)) and use that copy for all modifications and to construct
the loss instances (EnergyStdLoss, EnergyHessianStdLoss, EnergySpinLoss,
DenoiseLoss, DOSLoss, TensorLoss, PropertyLoss, and
TaskLoss.get_class_by_type(...).get_loss). This prevents changing the original
config passed by the caller while preserving all current key assignments and
behavior.
- Around line 413-418: The local variable finetune_has_new_type_key is computed
but never used; remove the dead assignment or if you intend to keep it for
future use, rename it to _finetune_has_new_type_key to signal intentionally
unused. Update the block around the expression that references
self.finetune_links.get_has_new_type(), which is currently guarded by
self.is_finetune and self.finetune_links, by either deleting that whole
assignment line or renaming the variable as described; ensure no other code
relies on finetune_has_new_type_key before removing or renaming.
- Around line 662-683: The loop that builds model_prob for self.model_keys is a
no-op and always assigns uniform probability; replace the placeholder with logic
that queries self.data_manager for each model_key's training dataset size (e.g.,
via a method/property analogous to the legacy trainer) to compute per-model
weights proportional to data sizes, falling back to uniform weights if
data_manager is missing or a model's size is zero; normalize model_prob to sum
to 1 and then build model_key_prob_map as before, ensuring compatibility with
self.wrapper.share_params and matching the legacy trainer's probability
computation semantics.
- Around line 1118-1121: The final checkpoint currently calls
self._save_checkpoint(self.config.num_steps - 1, 0.0) which hardcodes lr=0.0 and
may double-save; change _finalize_training to use the actual last learning rate
tracked by the training loop (e.g., maintain self._last_lr updated each step in
the training loop) and call self._save_checkpoint(self.config.num_steps - 1,
self._last_lr) instead, and add a guard in _finalize_training (or check in
_save_checkpoint) to skip saving if a checkpoint for step self.config.num_steps
- 1 was already written (or if the stored last checkpoint step equals
config.num_steps - 1) to avoid redundant saves.
In `@deepmd/pt/train/training_loop.py`:
- Around line 226-238: The _compute_prefactors function can divide by zero when
start_pref_e or start_pref_f is 0; update it to guard against zeros by computing
ratio = step / max(1, self.num_steps) and for each prefactor use geometric
interpolation as now when start_pref_* != 0, but fall back to linear
interpolation pref = start + (limit - start) * ratio when start_pref_* == 0
(apply this logic for start_pref_e/limit_pref_e and start_pref_f/limit_pref_f in
_compute_prefactors).
- Around line 248-252: The code incorrectly reads step from
self.optimizer.state.get("step", 0) (optimizer.state is per-parameter, so this
always yields 0) causing _compute_prefactors to never advance; fix by passing
the current training step into the training step method instead: update the
signature of BaseTrainingLoop.step (and any overrides like
LKFEnergyTrainingLoop.step) to accept a step: int parameter, propagate the
caller's step (Trainer) into that call, and replace uses of
self.optimizer.state.get("step", 0) with the passed-in step when calling
_compute_prefactors.
🧹 Nitpick comments (19)
deepmd/pt/train/training_loop.py (1)
217-224: Magic numbers24(kp) and6(kq) are duplicated across two classes.These KF parameters appear in both
LKFEnergyTrainingLoop(lines 221-222) andLKFDenoiseTrainingLoop(lines 303-304) with no explanation or configurability. Extract them as class-level constants or accept them from configuration.♻️ Proposed refactor
+# Default Kalman Filter parameters +_DEFAULT_KF_KP = 24 +_DEFAULT_KF_KQ = 6 + class LKFEnergyTrainingLoop(BaseTrainingLoop): ... self.kf_wrapper = KFOptimizerWrapper( wrapper, optimizer, - 24, # kp - 6, # kq + _DEFAULT_KF_KP, + _DEFAULT_KF_KQ, dist.is_available() and dist.is_initialized(), )Apply the same change to
LKFDenoiseTrainingLoop.Also applies to: 300-306
deepmd/pt/train/__init__.py (1)
67-72:LKFDenoiseTrainingLoopis not re-exported, unlike its siblingLKFEnergyTrainingLoop.
LKFEnergyTrainingLoopis imported and listed in__all__, butLKFDenoiseTrainingLoop(also a public class intraining_loop.py) is neither imported nor listed. If both are part of the public API, add the missing export for consistency.♻️ Proposed fix
from deepmd.pt.train.training_loop import ( AdamTrainingLoop, BaseTrainingLoop, + LKFDenoiseTrainingLoop, LKFEnergyTrainingLoop, TrainingLoopFactory, )And add
"LKFDenoiseTrainingLoop"to__all__.deepmd/pt/train/logger.py (1)
47-82: File handle opened in__init__may leak ifclose()is never called.The
TrainingLoggeropens a file on construction (line 79) but cleanup depends on explicitclose()or context-manager usage. IfTrainer.__init__fails after creating the logger (but beforerun()is invoked), the handle leaks. Consider adding a__del__safety net.♻️ Proposed safety net
+ def __del__(self) -> None: + """Ensure file handle is closed on garbage collection.""" + self.close()deepmd/pt/train/data_loader.py (2)
295-308:DpLoaderSetAdapteris both iterable and iterator (__iter__returnsself).This means only one iteration can be active at a time. If
__iter__is called again, it resets the internal_iterator, potentially losing the position of any in-progress iteration. This is fine for the current single-consumer training loop, but worth noting as a limitation.
409-426:DataLoaderFactory._implementationsis a mutableClassVar—register()mutates global state.
register()modifies a class-level dict shared by all instances and imports. In multi-process or testing scenarios, registered implementations persist across test cases. This is acceptable for a plugin registry pattern but consider documenting the global-state semantics.deepmd/pt/train/data_manager.py (1)
86-101:training_loadersandvalidation_loadershave different types depending onis_multitask.In multitask mode, these are
dict[str, AbstractDataLoader]; in single-task mode, they'reAbstractDataLoader/AbstractDataLoader | None. This makes downstream code requireisinstancechecks oris_multitaskguards everywhere. Consider always using a dict (e.g.,{"Default": loader}) for a uniform interface.source/tests/pt/test_new_training.py (1)
308-444: End-to-end tests verify creation and initialization but not training execution.The
TestEndToEndCLItests confirm the trainer initializes correctly and can fetch data, but there's no test that callstrainer.run()even for a single step. Consider adding a minimal test that runs 1-2 training steps to verify the full loop (forward, backward, optimizer step, logging, checkpointing).💡 Example test sketch
def test_trainer_runs_one_step(self): """Test that trainer can execute at least one training step.""" config = copy.deepcopy(self.config) config["training"]["numb_steps"] = 1 config["training"]["save_freq"] = 1 trainer = get_trainer(config) # Should not raise trainer.run()deepmd/pt/entrypoints/main.py (1)
385-394: Add CLI flag to fall back to legacy trainer during transition period.The new (refactored) trainer is now the default with no option to use the legacy trainer via the CLI. Users cannot fall back to the legacy trainer during the transition period. Add a
--use-legacy-trainerflag toparser_trainto allow users to opt in to the legacy trainer behavior.deepmd/pt/train/checkpoint_manager.py (3)
106-109: DDP unwrapping should not depend ondist.is_initialized().A module can be wrapped in
DistributedDataParallel(and thus have a.moduleattribute) even when checked outside a distributed context, e.g., in certain test setups or afterdisthas been destroyed. The current guard means you'd serialize the DDP wrapper's state dict (withmodule.prefixed keys) in those edge cases. The same pattern is repeated insave_final(Lines 380–383),load(Lines 211–214), andload_for_finetune(Lines 264–267).A simpler, more robust pattern:
Suggested change
- module = wrapper - if dist.is_available() and dist.is_initialized(): - if hasattr(wrapper, "module"): - module = wrapper.module + module = getattr(wrapper, "module", wrapper)
354-406:save_finalduplicates most ofsave— consider consolidation.
save_finalrepeats the DDP unwrapping, train_infos update, checkpoint creation, torch.save, symlink, and checkpoint-file write logic fromsave, differing only in the absence of the optimizer state and the naming scheme. This violates DRY and increases the risk of divergence (the off-by-one naming issue is one such case already).Consider having
save_finaldelegate tosavewith an optionalinclude_optimizer=Falseflag, or extracting the shared logic into a private helper.
120-124:deepcopyis unnecessary and memory-inefficient for optimizer state saving
optimizer.state_dict()returns a new dictionary, but its values (including state tensors) are references to the optimizer's internal state. Since the code only replaces scalarlrvalues (line 124) and doesn't mutate the state tensors themselves, thedeepcopyunnecessarily copies all optimizer state tensors (momentum buffers, Adam's second moments, etc.). For large models, this doubles peak memory usage during checkpoint saves. Save directly with the state dict returned bystate_dict().Suggested change
- optim_state = deepcopy(optimizer.state_dict()) + optim_state = optimizer.state_dict()deepmd/pt/train/config.py (1)
34-71: Default values are duplicated between field declarations andfrom_dict.Every dataclass in this file (e.g.,
OptimizerConfig,LearningRateConfig,DisplayConfig,CheckpointConfig) specifies default values twice — once in the field declarations and again in the.get()calls withinfrom_dict. If one changes without the other, they'll silently diverge.Consider using the field defaults directly:
Example for OptimizerConfig
`@classmethod` def from_dict(cls, params: dict[str, Any]) -> OptimizerConfig: """Create OptimizerConfig from dictionary.""" + fields = {f.name: f.default for f in cls.__dataclass_fields__.values()} return cls( - opt_type=params.get("opt_type", "Adam"), - weight_decay=params.get("weight_decay", 0.001), - ... + **{k: params.get(k, v) for k, v in fields.items()} )This applies to all four smaller config classes.
TrainingConfig.from_dictis different since it performs transformations.deepmd/pt/train/hooks.py (5)
54-67:TrainingHookinheritsABCbut has no abstract methods — consider droppingABC.Since all hook methods are optional overrides (none are
@abstractmethod), inheriting fromABCis misleading — it suggests subclasses must implement something. A plain base class communicates the intent (all methods optional) more clearly. Ruff also flags this (B024).
411-421:TimingHook.on_step_endmeasures wall time of the entire step plus all higher-priority hooks.Since hooks execute in priority order and
TimingHookhasLOWpriority, by the timeon_step_endfires, all other hooks'on_step_endcalls have already completed. Thelast_step_timeis set at the end of this method, so the measured interval includes other hooks'on_step_endoverhead plus the gap betweenon_step_beginandon_step_endof the next step. This is probably fine for a rough ETA, but worth documenting that it measures wall time inclusive of hook overhead.Also,
step_times.pop(0)on a list is O(n). For a fixed-size rolling window,collections.deque(maxlen=100)would be more efficient and idiomatic.Suggested change for deque
+from collections import deque + class TimingHook(TrainingHook): ... def __init__(self) -> None: - self.step_times: list[float] = [] + self.step_times: deque[float] = deque(maxlen=100) self.start_time: float | None = None self.last_step_time: float | None = None def on_step_end(self, step, logs=None): import time if self.last_step_time is not None: step_time = time.time() - self.last_step_time self.step_times.append(step_time) - if len(self.step_times) > 100: - self.step_times.pop(0) self.last_step_time = time.time()
490-504:EarlyStoppingHookstoresis_betteras a lambda — not picklable.If any part of the system needs to serialize hooks (e.g., for distributed training, checkpointing hook state), the lambda stored in
self.is_betterwill fail to pickle. This can be avoided by usingoperator.lt/operator.gtor a simple method dispatch onself.mode.Suggested change
- if mode == "min": - self.is_better = lambda current, best: current < best - self.best_value = float("inf") - elif mode == "max": - self.is_better = lambda current, best: current > best - self.best_value = float("-inf") + if mode == "min": + self.best_value = float("inf") + elif mode == "max": + self.best_value = float("-inf") else: raise ValueError(f"mode must be 'min' or 'max', got {mode}") + + def is_better(self, current: float, best: float) -> bool: + if self.mode == "min": + return current < best + return current > best
347-368:TensorBoardHook.on_step_enditerates all log keys after already checkinglossandlr.The loop on Lines 366–368 re-checks and skips
"loss"and"lr", but iflogscontains many keys, this does redundant string comparisons. More importantly, non-numeric values (e.g., tensors, strings) are silently dropped by theisinstanceguard, which is fine, but torchTensorscalars (0-dim) won't match(int, float)and will be silently ignored.Consider adding
torch.Tensorto the isinstance check or calling.item()for scalar tensors:Suggested addition
for key, value in logs.items(): if key not in ["loss", "lr"] and isinstance(value, (int, float)): self.writer.add_scalar(f"train/{key}", value, display_step) + elif key not in ["loss", "lr"] and isinstance(value, torch.Tensor) and value.ndim == 0: + self.writer.add_scalar(f"train/{key}", value.item(), display_step)
382-386:TensorBoardHook.on_train_enddoes not flush before closing.
SummaryWriter.close()should flush pending events. While the default implementation ofclose()callsflush()internally in most TensorBoard versions, explicitly flushing before close is a common best practice to avoid data loss if the writer implementation changes.deepmd/pt/train/optimizer_factory.py (2)
135-139: Simplify thefusedternary.
False if DEVICE.type == "cpu" else Trueis equivalent toDEVICE.type != "cpu".Suggested fix
- fused=False if DEVICE.type == "cpu" else True, + fused=(DEVICE.type != "cpu"),Also applies to: 180-184
496-497: Global mutable singleton_default_factoryis a hidden shared-state risk.The module-level
_default_factory = OptimizerFactory()is created at import time. If any caller uses_default_factory.register(...)to add a custom strategy, it affects all subsequent users of the module-levelcreate_optimizer/create_schedulerfunctions — including across tests. This could lead to flaky tests or unexpected behavior in long-running processes.Consider documenting this explicitly, or providing a
get_default_factory()function that makes the shared state visible and resettable.
| # Update symlinks | ||
| symlink_prefix_files(save_path.stem, self.config.save_ckpt) |
symlink_prefix_files expects string prefixes — save_path.stem drops the directory.
save_path.stem returns just the filename without extension (e.g., "model.ckpt-1001"), while self.config.save_ckpt is "model.ckpt". The symlink_prefix_files function calls glob(old_prefix + ".*"), which glob-matches from the current working directory. This works only if the checkpoint is saved in the CWD. If save_ckpt ever contains a directory component (e.g., "checkpoints/model.ckpt"), the stem-based prefix and the config prefix will diverge and the symlink won't be created. The same concern applies to Line 401 in save_final.
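A minimal sketch of keeping the directory component on the old prefix, assuming symlink_prefix_files(old_prefix, new_prefix) globs old_prefix + ".*"; the standalone helper and its parameters are illustrative only.

```python
from pathlib import Path


def update_symlinks(save_path: Path, save_ckpt: str, symlink_prefix_files) -> None:
    # "checkpoints/model.ckpt-1001.pt" -> "checkpoints/model.ckpt-1001":
    # with_suffix("") strips only the trailing ".pt" and keeps the directory,
    # while save_ckpt (e.g. "checkpoints/model.ckpt") is already a full prefix.
    symlink_prefix_files(str(save_path.with_suffix("")), save_ckpt)
```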
| def _cleanup_old_checkpoints(self) -> None: | ||
| """Remove old checkpoints keeping only max_ckpt_keep most recent.""" | ||
| if len(self._saved_checkpoints) <= self.config.max_ckpt_keep: | ||
| return | ||
|
|
||
| # Sort by modification time | ||
| checkpoint_files = [ | ||
| f | ||
| for f in Path(".").glob(f"{self.config.save_ckpt}*.pt") | ||
| if not f.is_symlink() and f.name.startswith(self.config.save_ckpt) | ||
| ] | ||
| checkpoint_files.sort(key=lambda x: x.stat().st_mtime) | ||
|
|
||
| # Remove oldest | ||
| while len(checkpoint_files) > self.config.max_ckpt_keep: | ||
| old_file = checkpoint_files.pop(0) | ||
| try: | ||
| old_file.unlink() | ||
| log.debug(f"Removed old checkpoint: {old_file}") | ||
| except OSError as e: | ||
| log.warning(f"Failed to remove old checkpoint {old_file}: {e}") |
Cleanup logic has a disconnect with _saved_checkpoints tracking.
The early return on Line 153 checks self._saved_checkpoints, but the actual cleanup on Lines 157–171 re-discovers checkpoints via glob, ignoring _saved_checkpoints entirely. This means:
- If the process restarts (fresh
CheckpointManager),_saved_checkpointsis empty and cleanup never triggers even if many old files exist on disk. - The glob
f"{self.config.save_ckpt}*.pt"may match files not managed by this run.
Consider either consistently using _saved_checkpoints for cleanup or always using the glob-based discovery (without the early-return guard).
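A hedged sketch of the second option (always using glob-based discovery, with no early-return guard), written as a standalone function for illustration; the real method lives on CheckpointManager and reads these values from self.config.

```python
import logging
from pathlib import Path

log = logging.getLogger(__name__)


def cleanup_old_checkpoints(save_ckpt: str, max_ckpt_keep: int) -> None:
    """Keep only the newest max_ckpt_keep checkpoints on disk."""
    # The glob already restricts matches to the save_ckpt prefix, so no
    # in-memory bookkeeping or early return is needed.
    checkpoint_files = [
        f for f in Path(".").glob(f"{save_ckpt}*.pt") if not f.is_symlink()
    ]
    checkpoint_files.sort(key=lambda p: p.stat().st_mtime)
    while len(checkpoint_files) > max_ckpt_keep:
        old_file = checkpoint_files.pop(0)
        try:
            old_file.unlink()
            log.debug("Removed old checkpoint: %s", old_file)
        except OSError as e:
            log.warning("Failed to remove old checkpoint %s: %s", old_file, e)
```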
| state_dict = checkpoint.get("model", checkpoint) | ||
| if "model" in checkpoint: | ||
| state_dict = checkpoint["model"] |
Redundant state_dict extraction.
Lines 216–218 extract state_dict twice — Line 216 uses .get("model", checkpoint) as a fallback and Line 217-218 immediately overwrite it if "model" key exists. The fallback on Line 216 is never reachable when Line 217 takes effect. The same pattern appears in load_for_finetune (Lines 269–271).
Simplify to
- state_dict = checkpoint.get("model", checkpoint)
- if "model" in checkpoint:
- state_dict = checkpoint["model"]
+ state_dict = checkpoint.get("model", checkpoint)📝 Committable suggestion
| state_dict = checkpoint.get("model", checkpoint) | |
| if "model" in checkpoint: | |
| state_dict = checkpoint["model"] | |
| state_dict = checkpoint.get("model", checkpoint) |
| checkpoint_file = Path("checkpoint") | ||
| if checkpoint_file.exists(): | ||
| try: | ||
| latest = Path(checkpoint_file.read_text().strip()) | ||
| if latest.exists(): | ||
| return latest | ||
| except Exception: | ||
| pass |
Silent except: pass swallows all errors when reading the checkpoint file.
If the checkpoint file exists but is corrupted, has permission issues, or contains an invalid path, this is silently ignored. At minimum, log the exception so operators can diagnose recovery failures. This aligns with the Ruff S110 hint.
Suggested fix
- except Exception:
- pass
+ except Exception:
+ log.warning("Failed to read checkpoint file", exc_info=True)📝 Committable suggestion
| checkpoint_file = Path("checkpoint") | |
| if checkpoint_file.exists(): | |
| try: | |
| latest = Path(checkpoint_file.read_text().strip()) | |
| if latest.exists(): | |
| return latest | |
| except Exception: | |
| pass | |
| checkpoint_file = Path("checkpoint") | |
| if checkpoint_file.exists(): | |
| try: | |
| latest = Path(checkpoint_file.read_text().strip()) | |
| if latest.exists(): | |
| return latest | |
| except Exception: | |
| log.warning("Failed to read checkpoint file", exc_info=True) |
🧰 Tools
🪛 Ruff (0.15.0)
[error] 344-345: try-except-pass detected, consider logging the exception
(S110)
[warning] 344-344: Do not catch blind exception: Exception
(BLE001)
| assert num_steps - computed_warmup_steps > 0 or computed_warmup_steps == 0, ( | ||
| "Warm up steps must be less than total training steps!" | ||
| ) |
Use ValueError instead of assert for input validation.
assert statements are stripped when Python runs with -O (optimized mode), so this check would silently disappear in production. Since this validates user-provided configuration, it should use ValueError to ensure it's always enforced.
Suggested fix
- assert num_steps - computed_warmup_steps > 0 or computed_warmup_steps == 0, (
-     "Warm up steps must be less than total training steps!"
- )
+ if computed_warmup_steps != 0 and num_steps - computed_warmup_steps <= 0:
+     raise ValueError(
+         "Warm up steps must be less than total training steps!"
+     )
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if computed_warmup_steps != 0 and num_steps - computed_warmup_steps <= 0:
    raise ValueError(
        "Warm up steps must be less than total training steps!"
    )
🤖 Prompt for AI Agents
In `@deepmd/pt/train/config.py` around lines 200 - 202, Replace the assert-based
validation with an explicit exception raise so the check always runs: in the
code that currently uses `assert num_steps - computed_warmup_steps > 0 or
computed_warmup_steps == 0, ("Warm up steps must be less than total training
steps!")` (the variables `num_steps` and `computed_warmup_steps`), change it to
an if statement that raises `ValueError("Warm up steps must be less than total
training steps!")` when the condition fails; keep the same message and logic but
use `ValueError` for proper production-time validation.
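As background on why the ValueError matters, here is a small illustrative script (not part of the PR) showing that assert-based validation disappears under python -O:

# demo_assert_strip.py -- illustrative only.
def check_warmup(num_steps: int, warmup_steps: int) -> int:
    # Compiled away when the interpreter runs with -O, so invalid input
    # passes through silently in optimized mode.
    assert warmup_steps == 0 or num_steps - warmup_steps > 0, (
        "Warm up steps must be less than total training steps!"
    )
    return warmup_steps


if __name__ == "__main__":
    # "python demo_assert_strip.py" raises AssertionError;
    # "python -O demo_assert_strip.py" prints 200.
    print(check_warmup(100, 200))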
# Compute model probabilities
model_prob = np.zeros(len(self.model_keys), dtype=np.float32)
for ii, model_key in enumerate(self.model_keys):
    # Get training data size for this model
    if hasattr(self, "data_manager") and self.data_manager:
        # Try to get from data_manager
        pass
    # Use uniform probability for now
    model_prob[ii] = 1.0

model_prob = model_prob / np.sum(model_prob)
model_key_prob_map = dict(zip(self.model_keys, model_prob))

# Call share_params
self.wrapper.share_params(
    shared_links,
    resume=(self.is_restart and not self.finetune_update_stat)
    or self.rank != 0,
    model_key_prob_map=model_key_prob_map,
    data_stat_protect=data_stat_protect_values[0],
)
Multi-task model probability computation is a no-op — always produces uniform distribution.
Lines 665-671: The loop body unconditionally sets model_prob[ii] = 1.0 with a pass inside the if hasattr(...) branch, meaning the data-size-based probability logic was never implemented. The legacy trainer computes probabilities from actual data sizes. This will affect multi-task training balance.
🐛 Proposed fix sketch
model_prob = np.zeros(len(self.model_keys), dtype=np.float32)
for ii, model_key in enumerate(self.model_keys):
- # Get training data size for this model
- if hasattr(self, "data_manager") and self.data_manager:
- # Try to get from data_manager
- pass
- # Use uniform probability for now
- model_prob[ii] = 1.0
+ if hasattr(self, "data_manager") and self.data_manager:
+ loader = self.data_manager.training_loaders.get(model_key)
+ if loader is not None:
+ model_prob[ii] = float(len(loader))
+ else:
+ model_prob[ii] = 1.0
+ else:
+            model_prob[ii] = 1.0
Verify against the legacy trainer's probability computation to ensure parity.
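A self-contained sketch of the size-weighted scheme described above (compute_model_prob and data_sizes are illustrative names, not the data_manager API):

import numpy as np


def compute_model_prob(model_keys: list, data_sizes: dict = None) -> dict:
    """Weight each model key by its training-set size; fall back to uniform."""
    if data_sizes:
        weights = np.array(
            [float(data_sizes.get(key, 0)) for key in model_keys], dtype=np.float64
        )
        if weights.sum() <= 0:
            weights = np.ones(len(model_keys), dtype=np.float64)
    else:
        weights = np.ones(len(model_keys), dtype=np.float64)
    weights = weights / weights.sum()
    return dict(zip(model_keys, weights, strict=True))


# Example: 3000 water frames vs 1000 metal frames -> 0.75 / 0.25 split.
print(compute_model_prob(["water", "metal"], {"water": 3000, "metal": 1000}))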
🧰 Tools
🪛 Ruff (0.15.0)
[warning] 665-665: Loop control variable model_key not used within loop body
Rename unused model_key to _model_key
(B007)
[warning] 674-674: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
🤖 Prompt for AI Agents
In `@deepmd/pt/train/trainer.py` around lines 662 - 683, The loop that builds
model_prob for self.model_keys is a no-op and always assigns uniform
probability; replace the placeholder with logic that queries self.data_manager
for each model_key's training dataset size (e.g., via a method/property
analogous to the legacy trainer) to compute per-model weights proportional to
data sizes, falling back to uniform weights if data_manager is missing or a
model's size is zero; normalize model_prob to sum to 1 and then build
model_key_prob_map as before, ensuring compatibility with
self.wrapper.share_params and matching the legacy trainer's probability
computation semantics.
def _init_distributed(self) -> None:
    """Initialize distributed training."""
    if dist.is_available() and dist.is_initialized():
        torch.cuda.set_device(LOCAL_RANK)
        self.wrapper = DDP(
            self.wrapper,
            device_ids=[LOCAL_RANK],
            find_unused_parameters=True,
            output_device=LOCAL_RANK,
        )
torch.cuda.set_device called unconditionally — fails on CPU-only systems.
Line 688: When dist.is_initialized() is True but the backend is gloo (CPU-only), torch.cuda.set_device(LOCAL_RANK) will raise an error if no CUDA devices are available. Guard with a CUDA availability check.
🐛 Proposed fix
 def _init_distributed(self) -> None:
     """Initialize distributed training."""
     if dist.is_available() and dist.is_initialized():
-        torch.cuda.set_device(LOCAL_RANK)
+        if torch.cuda.is_available():
+            torch.cuda.set_device(LOCAL_RANK)
         self.wrapper = DDP(
             self.wrapper,
             device_ids=[LOCAL_RANK],
In `@deepmd/pt/train/trainer.py` around lines 685 - 694, The _init_distributed
method currently calls torch.cuda.set_device(LOCAL_RANK) unconditionally which
fails on CPU-only backends; modify _init_distributed (and the DDP wrap) to first
check torch.cuda.is_available() before calling torch.cuda.set_device or passing
CUDA-specific arguments to DDP—i.e., only call torch.cuda.set_device(LOCAL_RANK)
and use device_ids=[LOCAL_RANK], output_device=LOCAL_RANK when
torch.cuda.is_available() is True; otherwise skip set_device and instantiate
DDP(self.wrapper, find_unused_parameters=True) (no CUDA-specific
device_ids/output_device) so it works with gloo/CPU-only setups.
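A runnable sketch of the CPU-aware wrapping described in the prompt (init_distributed, wrapper, and local_rank are illustrative names standing in for the trainer's attributes):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def init_distributed(wrapper: torch.nn.Module, local_rank: int) -> torch.nn.Module:
    """Wrap the model in DDP without assuming CUDA devices exist."""
    if not (dist.is_available() and dist.is_initialized()):
        return wrapper
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        return DDP(
            wrapper,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
        )
    # gloo / CPU-only backend: omit the CUDA-specific arguments.
    return DDP(wrapper, find_unused_parameters=True)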
def _finalize_training(self, total_time: float, timed_steps: int) -> None:
    """Finalize training and cleanup."""
    # Save final checkpoint
    self._save_checkpoint(self.config.num_steps - 1, 0.0)
Final checkpoint saved with lr=0.0 instead of actual learning rate.
Line 1121: self._save_checkpoint(self.config.num_steps - 1, 0.0) hardcodes lr=0.0. This misrepresents the actual learning rate at the end of training and could affect resume behavior. Additionally, if the last step already triggered a checkpoint save (line 915-916), this creates a redundant save.
🐛 Proposed fix
Track the last learning rate from the training loop and use it here, and guard against double-saving:
 def _finalize_training(self, total_time: float, timed_steps: int) -> None:
     """Finalize training and cleanup."""
-    # Save final checkpoint
-    self._save_checkpoint(self.config.num_steps - 1, 0.0)
+    # Save final checkpoint if not already saved
+    last_step = self.config.num_steps
+    if last_step % self.config.checkpoint.save_freq != 0:
+        self._save_checkpoint(self.config.num_steps - 1, self.lr_schedule.value(self.config.num_steps - 1))
In `@deepmd/pt/train/trainer.py` around lines 1118 - 1121, The final checkpoint
currently calls self._save_checkpoint(self.config.num_steps - 1, 0.0) which
hardcodes lr=0.0 and may double-save; change _finalize_training to use the
actual last learning rate tracked by the training loop (e.g., maintain
self._last_lr updated each step in the training loop) and call
self._save_checkpoint(self.config.num_steps - 1, self._last_lr) instead, and add
a guard in _finalize_training (or check in _save_checkpoint) to skip saving if a
checkpoint for step self.config.num_steps - 1 was already written (or if the
stored last checkpoint step equals config.num_steps - 1) to avoid redundant
saves.
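A compact sketch of the last-lr bookkeeping the prompt suggests; the mixin, attribute names, and print placeholder below are illustrative only:

class FinalCheckpointMixin:
    """Tracks the last lr and last saved step to avoid a redundant final save."""

    def _on_step_end(self, step: int, cur_lr: float) -> None:
        # Called once per training step by the loop.
        self._last_lr = cur_lr

    def _save_checkpoint(self, step: int, lr: float) -> None:
        self._last_saved_step = step
        print(f"checkpoint at step {step}, lr {lr:.2e}")  # placeholder for the real save

    def _finalize_training(self, num_steps: int) -> None:
        final_step = num_steps - 1
        # Skip the save if the training loop already checkpointed the final step.
        if getattr(self, "_last_saved_step", None) != final_step:
            self._save_checkpoint(final_step, getattr(self, "_last_lr", 0.0))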
def _compute_prefactors(self, step: int) -> tuple[float, float]:
    """Compute energy and force prefactors for current step."""
    start_pref_e = self.opt_param["kf_start_pref_e"]
    limit_pref_e = self.opt_param["kf_limit_pref_e"]
    start_pref_f = self.opt_param["kf_start_pref_f"]
    limit_pref_f = self.opt_param["kf_limit_pref_f"]

    ratio = step / self.num_steps

    pref_e = start_pref_e * (limit_pref_e / start_pref_e) ** ratio
    pref_f = start_pref_f * (limit_pref_f / start_pref_f) ** ratio

    return pref_e, pref_f
Division by zero when start_pref_e or start_pref_f is 0.
Line 235: (limit_pref_e / start_pref_e) ** ratio will raise ZeroDivisionError if start_pref_e is 0, and similarly for start_pref_f on line 236. The legacy trainer guards against zero prefactors.
🐛 Proposed fix
 def _compute_prefactors(self, step: int) -> tuple[float, float]:
     """Compute energy and force prefactors for current step."""
     start_pref_e = self.opt_param["kf_start_pref_e"]
     limit_pref_e = self.opt_param["kf_limit_pref_e"]
     start_pref_f = self.opt_param["kf_start_pref_f"]
     limit_pref_f = self.opt_param["kf_limit_pref_f"]
     ratio = step / self.num_steps
-    pref_e = start_pref_e * (limit_pref_e / start_pref_e) ** ratio
-    pref_f = start_pref_f * (limit_pref_f / start_pref_f) ** ratio
+    if start_pref_e == 0.0:
+        pref_e = limit_pref_e
+    else:
+        pref_e = start_pref_e * (limit_pref_e / start_pref_e) ** ratio
+    if start_pref_f == 0.0:
+        pref_f = limit_pref_f
+    else:
+        pref_f = start_pref_f * (limit_pref_f / start_pref_f) ** ratio
+    return pref_e, pref_f
In `@deepmd/pt/train/training_loop.py` around lines 226 - 238, The
_compute_prefactors function can divide by zero when start_pref_e or
start_pref_f is 0; update it to guard against zeros by computing ratio = step /
max(1, self.num_steps) and for each prefactor use geometric interpolation as now
when start_pref_* != 0, but fall back to linear interpolation pref = start +
(limit - start) * ratio when start_pref_* == 0 (apply this logic for
start_pref_e/limit_pref_e and start_pref_f/limit_pref_f in _compute_prefactors).
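A standalone sketch of the zero-safe schedule described in the prompt (a free function with a linear fallback; the numbers in the example are placeholders, not DeePMD defaults):

def compute_prefactor(start: float, limit: float, step: int, num_steps: int) -> float:
    """Geometric interpolation from start to limit; linear when start is zero."""
    ratio = step / max(1, num_steps)
    if start == 0.0:
        # (limit / 0) ** ratio is undefined; interpolate linearly instead.
        return start + (limit - start) * ratio
    return start * (limit / start) ** ratio


# Example: halfway through training, a prefactor ramping 0.02 -> 1.0.
print(compute_prefactor(0.02, 1.0, 500, 1000))  # ~0.141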
| """Execute LKF training step.""" | ||
| # Compute prefactors | ||
| step = self.optimizer.state.get("step", 0) | ||
| pref_e, pref_f = self._compute_prefactors(step) | ||
|
|
Step retrieval from self.optimizer.state may not work as expected.
Line 250: self.optimizer.state.get("step", 0) — optimizer.state is a defaultdict keyed by parameter tensors, not a flat dict with a "step" key. This will always return 0 for standard PyTorch optimizers and likely also for the LKF optimizer, causing the prefactor schedule to never advance.
The step should be passed into the step() method or tracked externally. Consider accepting step as a parameter, since the caller (Trainer) already knows the current training step.
🐛 Proposed approach: pass step as argument
Update BaseTrainingLoop.step to accept step: int or have LKFEnergyTrainingLoop track the step count internally:
 class BaseTrainingLoop(ABC):
     @abstractmethod
     def step(
         self,
         input_dict: dict[str, torch.Tensor],
         label_dict: dict[str, torch.Tensor],
         cur_lr: float,
         pref_lr: float,
         task_key: str = "Default",
+        global_step: int = 0,
     ) -> TrainingStepResult:
In `@deepmd/pt/train/training_loop.py` around lines 248 - 252, The code
incorrectly reads step from self.optimizer.state.get("step", 0) (optimizer.state
is per-parameter, so this always yields 0) causing _compute_prefactors to never
advance; fix by passing the current training step into the training step method
instead: update the signature of BaseTrainingLoop.step (and any overrides like
LKFEnergyTrainingLoop.step) to accept a step: int parameter, propagate the
caller's step (Trainer) into that call, and replace uses of
self.optimizer.state.get("step", 0) with the passed-in step when calling
_compute_prefactors.
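A minimal sketch of the caller-supplied step wiring (class and parameter names are illustrative, and the numeric prefactor bounds are placeholders rather than the PR's defaults):

class TrainingLoopSketch:
    """Receives the trainer's global step instead of reading optimizer.state."""

    def __init__(self, num_steps: int) -> None:
        self.num_steps = num_steps

    def _compute_prefactors(self, step: int) -> tuple:
        ratio = step / max(1, self.num_steps)
        pref_e = 0.02 * (1.0 / 0.02) ** ratio
        pref_f = 1000.0 * (1.0 / 1000.0) ** ratio
        return pref_e, pref_f

    def step(self, global_step: int = 0) -> tuple:
        # The prefactor schedule now advances with the caller's counter.
        return self._compute_prefactors(global_step)


loop = TrainingLoopSketch(num_steps=1000)
for global_step in range(3):
    pref_e, pref_f = loop.step(global_step=global_step)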
Summary by CodeRabbit