
@paul-gibbons

An attempt at enabling E2E deterministic training runs for hybrid models.

Atomic add non-determinism

The sources of non-determinism are atomic operations in the mamba2 backward Triton kernels, as well as in causal-conv1d. MR88 is being submitted to the causal-conv1d repo in parallel.

This MR modifies the Triton kernels to provide a deterministic path that avoids atomics.
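The root cause can be shown without Triton at all: atomic adds commit partial sums in whatever order threads happen to finish, and floating-point addition is not associative. A minimal plain-Python sketch (values chosen purely for illustration):

```python
# Floating-point addition is not associative, which is why unordered
# atomic adds can change a reduced gradient bitwise from run to run.
a, b, c = 1e16, 1.0, -1e16

left_to_right = (a + b) + c  # the 1.0 is absorbed: 1e16 + 1.0 rounds to 1e16
reordered = (a + c) + b      # cancellation happens first, so the 1.0 survives

print(left_to_right)  # 0.0
print(reordered)      # 1.0
```

A deterministic path must therefore fix the reduction order (e.g. by having each block write to its own slice and reducing afterwards), rather than relying on atomics.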

Autotuner / kernel-shape non-determinism

An additional source of non-determinism is the use of tl.cumsum in the Triton kernels: tl.cumsum produces different outputs for BLOCK_SIZE_H=1 vs. BLOCK_SIZE_H>1. See triton-lang/triton#3017.

The current implementation doesn't use triton.autotune's cache_results, which risks introducing non-determinism via different block sizes being chosen for tl.cumsum across runs. The latest NGC torch images also don't appear to ship a Triton version new enough to support this flag anyway.
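The block-size sensitivity can be mimicked in plain Python. This is a sketch of the effect, not of tl.cumsum's actual implementation: a prefix sum computed block-by-block with a carried offset groups additions differently from the element-at-a-time path, so the two can disagree bitwise. The function names and input values below are illustrative only.

```python
def cumsum_sequential(xs):
    """Inclusive prefix sum, one element at a time (the BLOCK_SIZE=1 analogue)."""
    out, acc = [], 0.0
    for x in xs:
        acc += x
        out.append(acc)
    return out

def cumsum_blocked(xs, block):
    """Inclusive prefix sum computed per block with a carried offset
    (the BLOCK_SIZE>1 analogue): additions are grouped differently,
    so results can differ bitwise from the sequential path."""
    out, carry = [], 0.0
    for i in range(0, len(xs), block):
        local, acc = [], 0.0
        for x in xs[i:i + block]:
            acc += x
            local.append(acc)
        out.extend(carry + v for v in local)
        carry += acc
    return out

# Values chosen so the grouping difference is visible in float64.
xs = [1e16, 1.0, -1e16, 1.0]
print(cumsum_sequential(xs)[-1])  # 1.0
print(cumsum_blocked(xs, 2)[-1])  # 0.0
```

Either grouping is a valid cumsum; determinism requires that the same block size be chosen every run, which is why an uncached autotuner is a hazard here.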

Per-kernel breakdown

```sh
for S in 2048 4096 16384; do
  python -u tests/benchmark_determinism_kernels.py \
    --preset nemotronh-56b --seqlen "$S" --warmup 50 --rep 300
done
```

seqlen=2048

| kernel | ms | det_ms | time overhead | MB | det_MB | mem overhead |
| --- | --- | --- | --- | --- | --- | --- |
| chunk_cumsum_bwd | 0.031 | 0.045 | +46% | 1672.0 | 1672.1 | +0% |
| chunk_state_bwd_dx | 0.798 | 0.884 | +11% | 1940.0 | 1964.0 | +1% |
| chunk_state_bwd_db | 0.812 | 0.847 | +4% | 1804.0 | 1860.0 | +3% |
| chunk_state_bwd_ddAcs | 0.620 | 0.641 | +3% | 1691.9 | 1708.0 | +1% |
| chunk_scan_bwd_dC | 0.794 | 0.823 | +4% | 1804.0 | 1860.0 | +3% |
| chunk_scan_bwd_dx | 0.620 | 0.667 | +8% | 1932.0 | 1948.0 | +1% |
| combined_bwd_dx | 0.981 | 1.027 | +5% | 1932.0 | 1948.0 | +1% |

seqlen=4096

| kernel | ms | det_ms | time overhead | MB | det_MB | mem overhead |
| --- | --- | --- | --- | --- | --- | --- |
| chunk_cumsum_bwd | 0.055 | 0.068 | +22% | 3344.0 | 3344.1 | +0% |
| chunk_state_bwd_dx | 1.590 | 1.777 | +12% | 3880.0 | 3928.0 | +1% |
| chunk_state_bwd_db | 1.636 | 1.768 | +8% | 3608.0 | 3720.0 | +3% |
| chunk_state_bwd_ddAcs | 1.183 | 1.312 | +11% | 3384.0 | 3416.0 | +1% |
| chunk_scan_bwd_dC | 1.605 | 1.748 | +9% | 3608.0 | 3720.0 | +3% |
| chunk_scan_bwd_dx | 1.262 | 1.354 | +7% | 3864.0 | 3896.0 | +1% |
| combined_bwd_dx | 1.943 | 2.030 | +4% | 3864.0 | 3896.0 | +1% |

seqlen=16384

| kernel | ms | det_ms | time overhead | MB | det_MB | mem overhead |
| --- | --- | --- | --- | --- | --- | --- |
| chunk_cumsum_bwd | 0.166 | 0.176 | +6% | 13376.0 | 13376.5 | +0% |
| chunk_state_bwd_dx | 6.471 | 7.431 | +15% | 15520.0 | 15712.0 | +1% |
| chunk_state_bwd_db | 6.665 | 7.056 | +6% | 14432.0 | 14880.0 | +3% |
| chunk_state_bwd_ddAcs | 4.577 | 5.450 | +19% | 13536.0 | 13664.0 | +1% |
| chunk_scan_bwd_dC | 6.504 | 6.924 | +6% | 14432.0 | 14880.0 | +3% |
| chunk_scan_bwd_dx | 5.831 | 6.472 | +11% | 15456.0 | 15584.0 | +1% |
| combined_bwd_dx | 10.032 | 10.462 | +4% | 15456.0 | 15584.0 | +1% |

…te correctness issue

Signed-off-by: Paul Gibbons <pgibbons@nvidia.com>
…inism tests passing after fix to _chunk_state_bwd_dx_kernel

Signed-off-by: Paul Gibbons <pgibbons@nvidia.com>
…when triton != 3.4.0

Signed-off-by: Paul Gibbons <pgibbons@nvidia.com>
…ndition, correctness tests pass when headdim/BLOCK_SIZE_N == 1, fail when > 1

Signed-off-by: Paul Gibbons <pgibbons@nvidia.com>
… when triton autotune cache is not available

Signed-off-by: Paul Gibbons <pgibbons@nvidia.com>
…im > BLOCK_SIZE_N

Signed-off-by: Paul Gibbons <pgibbons@nvidia.com>