
Conversation

@Aishwarya-Tonpe commented on Aug 28, 2025

This PR adds support for deterministic training and reproducible logging to all PyTorch model benchmarks in SuperBench (BERT, GPT2, LLaMA, LSTM, CNN, Mixtral).

Deterministic mode: Makes model runs repeatable by fixing random seeds, turning off TF32, and forcing deterministic math operations (a minimal sketch of what this involves is shown after the option list below).
Log generation: Saves key metrics such as per-step loss and activation mean during training.
Log comparison: Lets you compare a new run against a previous one to check that the metrics match (a rough illustration of the comparison is sketched after the Usage steps below).
New command-line options:

--enable-determinism
--generate-log : boolean flag; when enabled, stores the comparison metrics (loss and activation mean) in the results file
--compare-log : path of the JSON results file against which the current run is compared
--check-frequency
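
For context, here is a minimal sketch of what deterministic mode typically involves in PyTorch; the function name and the exact seed/env-var choices are illustrative, not necessarily the PR's implementation:

```python
import os
import random

import numpy as np
import torch


def enable_determinism(seed: int = 42):
    """Illustrative setup for reproducible PyTorch runs (not this PR's exact code)."""
    # Fix all random seeds so weight init and data shuffling repeat across runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Turn off TF32 so matmuls/convolutions run in full fp32 precision.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    # Force deterministic kernels; cuBLAS needs a workspace config for this.
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)
```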

Changes -

Updated pytorch_base.py to handle deterministic settings, logging, and comparisons.
Added a new example script: pytorch_deterministic_example.py
Added a test file: test_pytorch_determinism_all.py to verify everything works as expected.

Usage -

Run with --enable-determinism --generate-log to create a reference log.
Run again with --compare-log to check if the new run matches the reference.
Make sure all parameters stay the same between runs.
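
As a rough illustration of what the offline comparison does (the file layout and metric key names here are assumptions, not the PR's actual schema):

```python
import json
import math


def compare_logs(reference_path: str, current_metrics: dict, rel_tol: float = 0.0) -> bool:
    """Compare per-step metrics of the current run against a reference results file.

    ``current_metrics`` is assumed to map metric names (e.g. 'loss_step_0') to
    floats; the real results.json layout may differ.
    """
    with open(reference_path) as f:
        reference = json.load(f)

    mismatches = []
    for name, value in current_metrics.items():
        ref_value = reference.get(name)
        # With rel_tol=0 this is an exact match; raise rel_tol to tolerate drift.
        if ref_value is None or not math.isclose(value, ref_value, rel_tol=rel_tol, abs_tol=0.0):
            mismatches.append((name, ref_value, value))

    for name, ref_value, value in mismatches:
        print(f'Mismatch in {name}: reference={ref_value}, current={value}')
    return not mismatches
```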

- Add _enable_deterministic_training() method to set all necessary seeds
- Add --deterministic and --random_seed command line arguments
- Integrate deterministic training in _create_model() and _generate_dataset()
- Add comprehensive unit tests for deterministic functionality
- Tests validate parameter parsing, functionality, and regression scenarios
- All tests pass and integrate with existing SuperBench test suite
…pass check_frequency to _is_finished in train/infer; add test capturing checksum log; stabilize fp32 loss path and small-dims determinism tests
…oss BERT/GPT2/CNN/LSTM/Mixtral; per-step fp32 loss logging; checksum logs; tests updated to strict/soft determinism pattern; add strict determinism CI guidance
…rings; fix GPT-2 params; soft vs strict checks stabilized
…sum tests with BERT pattern, improve docstrings and skip logic.
…/CNN/BERT/Mixtral with periodic fingerprints, per-step loss capture, TF32 off, SDPA math kernel; add model_log_utils; update examples and tests, add env gating for cuBLAS.
@Aishwarya-Tonpe requested a review from a team as a code owner on August 28, 2025 17:41

@Aishwarya-Tonpe please read the following Contributor License Agreement (CLA). If you agree with the CLA, please reply with the following information.

@microsoft-github-policy-service agree [company="{your company}"]

Options:

  • (default - no company specified) I have sole ownership of intellectual property rights to my Submissions and I am not making Submissions in the course of work for my employer.
@microsoft-github-policy-service agree
  • (when company given) I am making Submissions in the course of work for my employer (or my employer has intellectual property rights in my Submissions by contract or applicable law). I have permission from my employer to make Submissions and enter into this Agreement on behalf of my employer. By signing below, the defined term “You” includes me and my employer.
@microsoft-github-policy-service agree company="Microsoft"

Contributor License Agreement

@microsoft-github-policy-service agree company="Microsoft"

root and others added 29 commits December 8, 2025 22:21
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…es not need to be set explicitly before running the benchmarks

@abuccts (Member) left a comment

The metadata and compare log functions still seem to be unnecessary.

  • For the compare-log function, it just checks whether the loss etc. in each step are equal, which is a special case of result analysis. I think you can re-use the current result analysis module and write some yaml configs to perform this comparison, rather than writing new code to do it during the online benchmark run. Besides, there are several scenarios that the current compare-log function cannot cover:

    1. In large-scale training, the all-reduce usually produces accumulated errors due to different reduction orders among runs, so tolerating a range of differences is necessary in analysis/comparison, which can be easily configured in the yaml configs of the result analysis module.
    2. In validation, the results may need to be compared either to a baseline or to the results of other nodes. The current compare-log only performs a one-to-one comparison against pre-defined results, and cannot compare loss between different nodes in one run.
  • For metadata, all settings should already be included in the benchmark config. When users compare loss results from two runs, they should guarantee the configs used are the same, just as when comparing performance results. You may also write the necessary metadata into metrics so that result analysis can compare it as well.

Currently, all benchmarks in SuperBench only record the related metrics during each run in the benchmark module, the runner then collects all metrics after each run in the runner module, and analysis/comparison is performed offline after all benchmarks finish in the result analysis module.

Therefore, it would be better for determinism support in model benchmarks to follow the same process:

  1. Write the necessary results (e.g., loss, metadata) into metrics for each rank in the pytorch benchmark during each run (a rough sketch follows below).
  2. Rely on the existing results collection process in the runner module to collect results from each rank, rather than doing ad-hoc all-reduce/all-gather in the benchmark.
  3. Rely on the existing result analysis module to compare the results offline. If any comparison function is missing, it would be better to support it generally in result analysis so that determinism in micro-benchmarks can also re-use it in the future.

Besides, please fix the unit tests accordingly.
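
For illustration only, step 1 above might look roughly like the following; the result object and its add_metric call are hypothetical stand-ins for whatever metric-recording API the benchmark base class actually provides:

```python
def record_determinism_metrics(result, rank, step_losses, activation_means):
    """Record per-rank, per-step determinism data as ordinary metrics.

    ``result.add_metric`` is a placeholder, not SuperBench's real API; the point
    is that each rank only writes its own values, and the runner plus the result
    analysis module handle collection and comparison offline.
    """
    for step, loss in enumerate(step_losses):
        result.add_metric(f'loss_rank{rank}_step{step}', float(loss))
    for step, act_mean in enumerate(activation_means):
        result.add_metric(f'activation_mean_rank{rank}_step{step}', float(act_mean))
```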

Comment on lines +41 to +44
- `--enable-determinism`: Enables deterministic computation for reproducible results.
- `--deterministic_seed <seed>`: Sets the seed for reproducibility.
- `--generate_log` : Boolean flag that stores comparison metrics in the results file
- `--compare_log <results_file_path>`: Specifies the path to the reference file for comparison.

unify them to use either underscore or dash?

Comment on lines +229 to +230
def _save_consolidated_deterministic_results(self):
"""Gather deterministic data from all ranks and save to results-summary (rank 0 only).

All results from all ranks will be aggregated to the control node by the runner after the benchmarks finish, so I don't think this function is necessary.

Loads the reference results.json file and compares deterministic metrics
(loss, activation mean) per-rank to verify reproducibility.
"""
import torch.distributed as dist

why not import at the beginning

# Synchronize failure status across all ranks in distributed mode
if self._args.distributed_impl == DistributedImpl.DDP:
# Convert failure status to tensor for all_reduce
import torch

torch is already imported in this file; why import it again?

Comment on lines +310 to +311
failure_tensor = torch.tensor([1 if has_failure else 0], dtype=torch.int32, device='cuda')
dist.all_reduce(failure_tensor, op=dist.ReduceOp.MAX)

will this work for cpu mode?
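
(For reference, a device-agnostic variant of the failure sync might look like the sketch below; this assumes the process group backend can be gloo for CPU runs and is not code from this PR.)

```python
import torch
import torch.distributed as dist


def sync_failure_status(has_failure: bool) -> bool:
    """Return True on every rank if any rank reported a failure."""
    if not (dist.is_available() and dist.is_initialized()):
        return has_failure
    # NCCL requires CUDA tensors; gloo works with CPU tensors.
    device = 'cuda' if dist.get_backend() == 'nccl' else 'cpu'
    failure = torch.tensor([1 if has_failure else 0], dtype=torch.int32, device=device)
    dist.all_reduce(failure, op=dist.ReduceOp.MAX)
    return bool(failure.item())
```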

Comment on lines -194 to +198
if self._is_finished(curr_step, end, check_frequency):
return duration
if self._is_finished(curr_step, end):
return duration, self._finalize_periodic_logging(periodic)

This will change the behavior when the run is stopped by duration rather than by step number.
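
(For reference, one way to keep both stop conditions while still gating on check_frequency is sketched below; the attribute names num_steps, duration, and _start_time are assumptions about the base class, not the file's actual fields.)

```python
def _is_finished(self, curr_step, end_time, check_frequency=1):
    """Stop on step count or elapsed time, checked only at check_frequency boundaries.

    Gating on check_frequency keeps every rank stopping (and logging its
    periodic fingerprint) at exactly the same step. The attribute names used
    below are illustrative assumptions, not the benchmark's real fields.
    """
    if check_frequency > 1 and curr_step % check_frequency != 0:
        return False
    reached_steps = self._args.num_steps > 0 and curr_step >= self._args.num_steps
    reached_duration = self._args.duration > 0 and (end_time - self._start_time) >= self._args.duration
    return reached_steps or reached_duration
```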

Comment on lines -120 to +124
if self._is_finished(curr_step, end, check_frequency):
return duration
if self._is_finished(curr_step, end):
return duration, self._finalize_periodic_logging(periodic)

Same here: this will change the behavior when the run is stopped by duration rather than by step number.


Return:
The step-time list of every training step.
A tuple of (step_times_ms, info) of every training step.

missing one space in indent

Comment on lines -188 to +192
if self._is_finished(curr_step, end, check_frequency):
return duration
if self._is_finished(curr_step, end):
return duration, self._finalize_periodic_logging(periodic)

same

end = self._timer()
curr_step += 1
if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond.

why remove this

Labels: benchmarks (SuperBench Benchmarks), model-benchmarks (Model Benchmark Test for SuperBench Benchmarks)
