From 20ddb822efe7436957f09712d0f7260b898d8c57 Mon Sep 17 00:00:00 2001 From: Paul Gibbons Date: Thu, 11 Dec 2025 16:06:56 -0800 Subject: [PATCH 1/7] move gitlab -> github branch --- mamba_ssm/modules/mamba2.py | 5 + mamba_ssm/modules/mamba2_simple.py | 5 + mamba_ssm/ops/triton/ssd_chunk_scan.py | 76 ++++++++-- mamba_ssm/ops/triton/ssd_chunk_state.py | 186 ++++++++++++++++++++---- mamba_ssm/ops/triton/ssd_combined.py | 39 ++++- mamba_ssm/utils/determinism.py | 41 ++++++ tests/benchmark_determinism_kernels.py | 146 +++++++++++++++++++ tests/test_determinism.py | 120 +++++++++++++++ 8 files changed, 567 insertions(+), 51 deletions(-) create mode 100644 mamba_ssm/utils/determinism.py create mode 100644 tests/benchmark_determinism_kernels.py create mode 100644 tests/test_determinism.py diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index 36b16d471..a3d567c6a 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -30,6 +30,7 @@ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined +from mamba_ssm.utils.determinism import set_deterministic_mode from huggingface_hub import PyTorchModelHubMixin @@ -58,6 +59,7 @@ def __init__( # Fused kernel and sharding options chunk_size=256, use_mem_eff_path=True, + deterministic=None, layer_idx=None, # Absorb kwarg for general module process_group=None, sequence_parallel=True, @@ -90,6 +92,7 @@ def __init__( self.activation = "silu" self.chunk_size = chunk_size self.use_mem_eff_path = use_mem_eff_path + self.deterministic = deterministic self.layer_idx = layer_idx # Order: [z, x, B, C, dt] @@ -159,6 +162,8 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param (in case batch is small). 
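For context, a minimal usage sketch of the new constructor flag (not part of the patch itself; hyperparameter values are illustrative, and a CUDA device plus the usual mamba_ssm/causal-conv1d dependencies are assumed):

    import torch
    from mamba_ssm.modules.mamba2 import Mamba2

    # deterministic=True makes every forward call set_deterministic_mode(True),
    # so the backward Triton kernels take the non-atomic reduction path.
    layer = Mamba2(d_model=256, headdim=64, deterministic=True,
                   device="cuda", dtype=torch.bfloat16)
    x = torch.randn(2, 1024, 256, device="cuda", dtype=torch.bfloat16)
    layer(x).sum().backward()

With deterministic=None (the default) the global mode is left untouched; later patches in this series drop the per-module flag in favour of a process-wide switch.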
Returns: same shape as u """ + if self.deterministic is not None: + set_deterministic_mode(self.deterministic) seqlen_og = seqlen if seqlen is None: batch, seqlen, dim = u.shape diff --git a/mamba_ssm/modules/mamba2_simple.py b/mamba_ssm/modules/mamba2_simple.py index 77a6af28e..290d58a63 100644 --- a/mamba_ssm/modules/mamba2_simple.py +++ b/mamba_ssm/modules/mamba2_simple.py @@ -19,6 +19,7 @@ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined +from mamba_ssm.utils.determinism import set_deterministic_mode class Mamba2Simple(nn.Module): @@ -43,6 +44,7 @@ def __init__( # Fused kernel and sharding options chunk_size=256, use_mem_eff_path=True, + deterministic=None, layer_idx=None, # Absorb kwarg for general module device=None, dtype=None, @@ -64,6 +66,7 @@ def __init__( self.activation = activation self.chunk_size = chunk_size self.use_mem_eff_path = use_mem_eff_path + self.deterministic = deterministic self.layer_idx = layer_idx # Order: [z, x, B, C, dt] @@ -126,6 +129,8 @@ def forward(self, u, seq_idx=None): u: (B, L, D) Returns: same shape as u """ + if self.deterministic is not None: + set_deterministic_mode(self.deterministic) batch, seqlen, dim = u.shape zxbcdt = self.in_proj(u) # (B, L, d_in_proj) diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py index 959078061..20f76eed3 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_scan.py +++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py @@ -15,6 +15,11 @@ from einops import rearrange, repeat from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd +from mamba_ssm.utils.determinism import ( + alloc_tile_workspace, + finalize_tile_workspace, + use_deterministic_mode, +) TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') @@ -535,10 +540,11 @@ def _chunk_scan_bwd_dc_kernel( stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, stride_dc_batch, stride_dc_seqlen, stride_dc_split, stride_dc_group, stride_dc_dstate, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile, # Meta-parameters HAS_DDA_CS: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, + DETERMINISTIC_REDUCTION: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1) @@ -556,7 +562,7 @@ def _chunk_scan_bwd_dc_kernel( dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head if HAS_DDA_CS: C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + pid_g * stride_C_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_n * stride_ddA_tile if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen @@ -591,7 +597,10 @@ def _chunk_scan_bwd_dc_kernel( dc *= scale[:, None] if HAS_DDA_CS: ddA_cs = tl.sum(dc * c, axis=1) - tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + if DETERMINISTIC_REDUCTION: + tl.store(ddA_cumsum_ptrs, 
ddA_cs, mask=offs_m < chunk_size) + else: + tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) acc += dc dout_ptrs += stride_dout_head prev_states_ptrs += stride_prev_states_head @@ -608,6 +617,11 @@ def _chunk_scan_bwd_dc_kernel( tl.store(dc_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) +_CHUNK_SCAN_BWD_DC_MIN_BLOCK_N = min( + cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_scan_bwd_dc_kernel.configs +) + + @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), @@ -638,11 +652,12 @@ def _chunk_scan_bwd_dx_kernel( stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_D_head, stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, - stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, stride_ddt_tile, # stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_hdim, stride_dD_csize, # Meta-parameters HAS_D: tl.constexpr, D_HAS_HDIM: tl.constexpr, + DETERMINISTIC_REDUCTION: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1) @@ -656,7 +671,7 @@ def _chunk_scan_bwd_dx_kernel( cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + pid_n * stride_ddt_tile dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head # if HAS_D: # dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize @@ -715,7 +730,10 @@ def _chunk_scan_bwd_dx_kernel( x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) ddt = tl.sum(acc * x, axis=1) ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize - tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + if DETERMINISTIC_REDUCTION: + tl.store(ddt_ptrs, ddt, mask=offs_m < chunk_size) + else: + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) # if HAS_D: # dout_new_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize + offs_n[None, :] * stride_dout_hdim) @@ -724,6 +742,11 @@ def _chunk_scan_bwd_dx_kernel( # tl.store(dD_ptr + offs_n * stride_dD_hdim, dD, mask=offs_n < N) +_CHUNK_SCAN_BWD_DX_MIN_BLOCK_N = min( + cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_scan_bwd_dx_kernel.configs +) + + # Disabling HAS_DDA_CS for now since it's much slower @triton.autotune( configs=[ @@ -1433,15 +1456,30 @@ def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngrou assert dout.shape == (batch, seqlen, nheads, headdim) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) + deterministic = use_deterministic_mode() if C is not None: assert C.shape == (batch, seqlen, ngroups, dstate) C_strides = (C.stride(0), C.stride(1), C.stride(2), C.stride(3)) - ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) - ddA_cumsum_prev_strides = (ddA_cumsum_prev.stride(0), 
ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3)) + tile_count = math.ceil(dstate / _CHUNK_SCAN_BWD_DC_MIN_BLOCK_N) + ddA_cumsum_prev, stride_ddA_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + dout.device, + deterministic, + zero_init=True, + ) + ddA_cumsum_prev_strides = ( + ddA_cumsum_prev.stride(0), + ddA_cumsum_prev.stride(2), + ddA_cumsum_prev.stride(1), + ddA_cumsum_prev.stride(3), + ) else: C_strides = (0, 0, 0, 0) ddA_cumsum_prev = None ddA_cumsum_prev_strides = (0, 0, 0, 0) + stride_ddA_tile = 0 nheads_ngroups_ratio = nheads // ngroups sm_count = torch.cuda.get_device_properties(dout.device).multi_processor_count nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) @@ -1460,12 +1498,15 @@ def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngrou dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), dC.stride(0), dC.stride(1), dC.stride(2), dC.stride(3), dC.stride(4), - *ddA_cumsum_prev_strides, + *ddA_cumsum_prev_strides, stride_ddA_tile, HAS_DDA_CS=ddA_cumsum_prev is not None, HAS_SEQ_IDX=seq_idx is not None, + DETERMINISTIC_REDUCTION=deterministic, BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), ) dC = dC.sum(2) + if ddA_cumsum_prev is not None: + ddA_cumsum_prev = finalize_tile_workspace(ddA_cumsum_prev, deterministic) return dC if C is None else (dC, ddA_cumsum_prev) @@ -1535,7 +1576,16 @@ def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None): # else: # dD = None dx = torch.empty_like(x) - ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + deterministic = use_deterministic_mode() + tile_count = math.ceil(headdim / _CHUNK_SCAN_BWD_DX_MIN_BLOCK_N) + ddt, stride_ddt_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + dout.device, + deterministic, + zero_init=True, + ) grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): @@ -1550,16 +1600,18 @@ def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None): dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), D.stride(0) if D is not None else 0, dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), - ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), stride_ddt_tile, # dD.stride(1) if dD is not None else 0, dD.stride(2) if dD is not None else 0, dD.stride(3) if dD is not None else 0, dD.stride(4) if dD is not None else 0, dD.stride(0) if dD is not None else 0, D is not None, D.dim() == 2 if D is not None else True, + DETERMINISTIC_REDUCTION=deterministic, ) # if D is not None: # BLOCK_SIZE_actual = _chunk_scan_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] # n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual # dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) - return dx, ddt.to(dtype=dt.dtype) + ddt = finalize_tile_workspace(ddt, deterministic, target_dtype=dt.dtype) + return dx, ddt def _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtract_ddtdt=True): diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py index 633c66e82..9ff8b949a 100644 --- 
a/mamba_ssm/ops/triton/ssd_chunk_state.py +++ b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -13,6 +13,11 @@ from einops import rearrange, repeat from mamba_ssm.ops.triton.softplus import softplus +from mamba_ssm.utils.determinism import ( + alloc_tile_workspace, + finalize_tile_workspace, + use_deterministic_mode, +) def init_to_zero(names): @@ -107,12 +112,13 @@ def _chunk_cumsum_bwd_kernel( stride_A_head, stride_dt_bias_head, stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head, - stride_dA_head, - stride_ddt_bias_head, + stride_dA_batch, stride_dA_chunk, stride_dA_head, + stride_ddt_bias_batch, stride_ddt_bias_chunk, stride_ddt_bias_head, # Meta-parameters DT_SOFTPLUS: tl.constexpr, HAS_DT_BIAS: tl.constexpr, BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, + DETERMINISTIC_REDUCTION: tl.constexpr, ): pid_b = tl.program_id(axis=0) pid_c = tl.program_id(axis=1) @@ -153,10 +159,18 @@ def _chunk_cumsum_bwd_kernel( ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt) tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) dA = tl.sum(ddA * dt, axis=1) - tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads) + dA_ptr += pid_b * stride_dA_batch + pid_c * stride_dA_chunk + if DETERMINISTIC_REDUCTION: + tl.store(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads) + else: + tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads) if HAS_DT_BIAS: ddt_bias = tl.sum(ddt, axis=1) - tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads) + ddt_bias_ptr += pid_b * stride_ddt_bias_batch + pid_c * stride_ddt_bias_chunk + if DETERMINISTIC_REDUCTION: + tl.store(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads) + else: + tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads) @triton.autotune( @@ -282,9 +296,10 @@ def _chunk_state_bwd_dx_kernel( stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, - stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, stride_ddt_tile, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile, # Meta-parameters + DETERMINISTIC_REDUCTION: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, ): @@ -299,8 +314,8 @@ def _chunk_state_bwd_dx_kernel( b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + pid_n * stride_ddt_tile + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_n * stride_ddA_tile dA_cumsum_ptr += pid_b * 
stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) @@ -341,12 +356,19 @@ def _chunk_state_bwd_dx_kernel( x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) ddt = tl.sum(acc * x, axis=1) ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize - tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + if DETERMINISTIC_REDUCTION: + tl.store(ddt_ptrs, ddt, mask=offs_m < chunk_size) + else: + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) ddA_cs = -(ddt * dt_m) ddA_cs_last = -tl.sum(ddA_cs) ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize - tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) - tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last) + if DETERMINISTIC_REDUCTION: + tl.store(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + tl.store(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last) + else: + tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last) dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty) dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head @@ -354,6 +376,11 @@ def _chunk_state_bwd_dx_kernel( tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) +_CHUNK_STATE_BWD_DX_MIN_BLOCK_N = min( + cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_state_bwd_dx_kernel.configs +) + + @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), @@ -383,10 +410,11 @@ def _chunk_state_bwd_db_kernel( stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, stride_db_batch, stride_db_seqlen, stride_db_split, stride_db_group, stride_db_dstate, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile, # Meta-parameters HAS_DDA_CS: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, + DETERMINISTIC_REDUCTION: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1) @@ -405,7 +433,7 @@ def _chunk_state_bwd_db_kernel( dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head if HAS_DDA_CS: b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_n * stride_ddA_tile if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen @@ -446,7 +474,10 @@ def _chunk_state_bwd_db_kernel( if HAS_DDA_CS: # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. 
the exclusive reverse cumsum ddA_cs = tl.sum(db * b, axis=1) - tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + if DETERMINISTIC_REDUCTION: + tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + else: + tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) acc += db x_ptrs += stride_x_head dstates_ptrs += stride_states_head @@ -466,6 +497,11 @@ def _chunk_state_bwd_db_kernel( tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) +_CHUNK_STATE_BWD_DB_MIN_BLOCK_N = min( + cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_state_bwd_db_kernel.configs +) + + @triton.autotune( configs=[ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), @@ -503,9 +539,10 @@ def _chunk_state_bwd_ddAcs_stable_kernel( stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile, # Meta-parameters HAS_SEQ_IDX: tl.constexpr, + DETERMINISTIC_REDUCTION: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, ): @@ -520,7 +557,7 @@ def _chunk_state_bwd_ddAcs_stable_kernel( b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_n * stride_ddA_tile dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen @@ -574,8 +611,15 @@ def _chunk_state_bwd_ddAcs_stable_kernel( # ddA_cs = tl.cumsum(ddt * dt_m) ddA_cs = ddt * dt_m ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize - # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) - tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + if DETERMINISTIC_REDUCTION: + tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + else: + tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + + +_CHUNK_STATE_BWD_DDACS_MIN_BLOCK_N = min( + cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_state_bwd_ddAcs_stable_kernel.configs +) @triton.autotune( @@ -703,20 +747,46 @@ def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_l assert ddA.shape == (batch, nheads, nchunks, chunk_size) assert ddt_out.shape == (batch, nheads, nchunks, chunk_size) assert A.shape == (nheads,) + deterministic = use_deterministic_mode() if dt_bias is not None: assert dt_bias.shape == (nheads,) - ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32) + if deterministic: + ddt_bias_workspace = torch.zeros( + batch, nchunks, 
nheads, device=dt.device, dtype=torch.float32 + ) + ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32) + stride_ddt_bias_batch = ddt_bias_workspace.stride(0) + stride_ddt_bias_chunk = ddt_bias_workspace.stride(1) + else: + ddt_bias_workspace = ddt_bias = torch.empty_like( + dt_bias, dtype=torch.float32 + ) + stride_ddt_bias_batch = 0 + stride_ddt_bias_chunk = 0 else: ddt_bias = None + ddt_bias_workspace = None + stride_ddt_bias_batch = 0 + stride_ddt_bias_chunk = 0 if ddt is not None: assert ddt.shape == dt.shape else: ddt = torch.empty_like(dt) dA = torch.empty_like(A, dtype=torch.float32) + if deterministic: + dA_workspace = torch.zeros( + batch, nchunks, nheads, device=dt.device, dtype=torch.float32 + ) + stride_dA_batch = dA_workspace.stride(0) + stride_dA_chunk = dA_workspace.stride(1) + else: + dA_workspace = dA + stride_dA_batch = 0 + stride_dA_chunk = 0 grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) with torch.cuda.device(dt.device.index): _chunk_cumsum_bwd_kernel[grid_chunk_cs]( - ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias, + ddA, ddt_out, dt, A, dt_bias, ddt, dA_workspace, ddt_bias_workspace if ddt_bias is not None else None, batch, seqlen, nheads, chunk_size, dt_limit[0], dt_limit[1], ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3), @@ -725,12 +795,17 @@ def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_l A.stride(0), dt_bias.stride(0) if dt_bias is not None else 0, ddt.stride(0), ddt.stride(1), ddt.stride(2), - dA.stride(0), - ddt_bias.stride(0) if ddt_bias is not None else 0, + stride_dA_batch, stride_dA_chunk, dA_workspace.stride(-1), + stride_ddt_bias_batch, stride_ddt_bias_chunk, (ddt_bias_workspace.stride(-1) if ddt_bias is not None else 0), dt_softplus, HAS_DT_BIAS=dt_bias is not None, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + DETERMINISTIC_REDUCTION=deterministic, ) + if deterministic: + dA.copy_(dA_workspace.sum(dim=(0, 1))) + if ddt_bias is not None: + ddt_bias.copy_(ddt_bias_workspace.sum(dim=(0, 1))) return ddt, dA, ddt_bias @@ -780,8 +855,24 @@ def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None): assert dx.shape == x.shape else: dx = torch.empty_like(x) - ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) - ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32) + deterministic = use_deterministic_mode() + tile_count = math.ceil(headdim / _CHUNK_STATE_BWD_DX_MIN_BLOCK_N) + ddt, stride_ddt_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + dt.device, + deterministic, + zero_init=True, + ) + ddA_cumsum, stride_ddA_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + dA_cumsum.device, + deterministic, + zero_init=True, + ) grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): @@ -795,11 +886,14 @@ def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None): dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), - ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), - ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), + ddt.stride(0), ddt.stride(2), 
ddt.stride(1), ddt.stride(3), stride_ddt_tile, + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), stride_ddA_tile, + DETERMINISTIC_REDUCTION=deterministic, BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), ) - return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype) + ddt = finalize_tile_workspace(ddt, deterministic, target_dtype=dt.dtype) + ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic, target_dtype=dA_cumsum.dtype) + return dx, ddt, ddA_cumsum def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1): @@ -811,16 +905,31 @@ def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) + deterministic = use_deterministic_mode() if B is not None: assert B.shape == (batch, seqlen, ngroups, dstate) B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3)) # Use torch.empty since the Triton kernel will call init_to_zero - ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) - ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3)) + tile_count = math.ceil(dstate / _CHUNK_STATE_BWD_DB_MIN_BLOCK_N) + ddA_cumsum, stride_ddA_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + x.device, + deterministic, + zero_init=True, + ) + ddA_cumsum_strides = ( + ddA_cumsum.stride(0), + ddA_cumsum.stride(2), + ddA_cumsum.stride(1), + ddA_cumsum.stride(3), + ) else: B_strides = (0, 0, 0, 0) ddA_cumsum = None ddA_cumsum_strides = (0, 0, 0, 0) + stride_ddA_tile = 0 nheads_ngroups_ratio = nheads // ngroups sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) @@ -840,13 +949,15 @@ def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), dB.stride(4), - *ddA_cumsum_strides, + *ddA_cumsum_strides, stride_ddA_tile, HAS_DDA_CS=ddA_cumsum is not None, HAS_SEQ_IDX=seq_idx is not None, + DETERMINISTIC_REDUCTION=deterministic, BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), ) dB = dB.sum(2) if ddA_cumsum is not None: + ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic) # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute # to the state of the chunk. 
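All of these wrappers use the same mechanism: when determinism is requested, the tl.atomic_add targets are replaced by a zero-initialised workspace carrying one extra trailing slice per BLOCK_SIZE_N tile (indexed by pid_n inside the kernels), and the slices are summed on the host in a fixed order. A host-side sketch of that round trip with the helpers from mamba_ssm/utils/determinism.py (shapes and tile count illustrative, CUDA assumed):

    import torch
    from mamba_ssm.utils.determinism import alloc_tile_workspace, finalize_tile_workspace

    batch, nheads, nchunks, chunk_size = 2, 8, 4, 256
    n_tiles = 3  # e.g. ceil(headdim / smallest BLOCK_SIZE_N among the autotune configs)

    ddt, tile_stride = alloc_tile_workspace(
        (batch, nheads, nchunks, chunk_size), n_tiles, torch.float32, "cuda",
        True,            # deterministic: allocate (..., n_tiles) instead of (...)
        zero_init=True,  # slices the launched grid never touches must stay zero
    )
    # Each Triton program writes its partial sums at offset pid_n * tile_stride.
    ddt = finalize_tile_workspace(ddt, True)  # ordered sum over the tile dim -> (2, 8, 4, 256)

Because the final reduction is a single, ordered torch.sum, the result no longer depends on the scheduling order of concurrent atomics.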
# torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:]) @@ -867,7 +978,16 @@ def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None): if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) # Use torch.empty since the Triton kernel will call init_to_zero - ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) + deterministic = use_deterministic_mode() + tile_count = math.ceil(headdim / _CHUNK_STATE_BWD_DDACS_MIN_BLOCK_N) + ddA_cumsum, stride_ddA_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + x.device, + deterministic, + zero_init=True, + ) grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): @@ -881,11 +1001,13 @@ def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None): dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), stride_ddA_tile, HAS_SEQ_IDX=seq_idx is not None, + DETERMINISTIC_REDUCTION=deterministic, BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16), BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), ) + ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic, target_dtype=ddA_cumsum.dtype) torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:]) return ddA_cumsum diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index bbf4ecf84..86db29296 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -42,6 +42,11 @@ from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd from mamba_ssm.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd +from mamba_ssm.utils.determinism import ( + alloc_tile_workspace, + finalize_tile_workspace, + use_deterministic_mode, +) TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') @@ -91,7 +96,7 @@ def _chunk_scan_chunk_state_bwd_dx_kernel( stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate, stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, - stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, stride_ddt_tile, stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, # Meta-parameters HAS_D: tl.constexpr, @@ -100,6 +105,7 @@ def _chunk_scan_chunk_state_bwd_dx_kernel( BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, IS_TRITON_22: tl.constexpr, + DETERMINISTIC_REDUCTION: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch @@ -112,7 +118,7 @@ def _chunk_scan_chunk_state_bwd_dx_kernel( cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head dout_ptr += pid_b * stride_dout_batch + pid_c * 
chunk_size * stride_dout_seqlen + pid_h * stride_dout_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + pid_n * stride_ddt_tile dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head @@ -226,7 +232,15 @@ def _chunk_scan_chunk_state_bwd_dx_kernel( tl.store(dD_ptr, dD) ddt = tl.sum(acc * x, axis=1) ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize - tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + if DETERMINISTIC_REDUCTION: + tl.store(ddt_ptrs, ddt, mask=offs_m < chunk_size) + else: + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + + +_CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N = min( + cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_scan_chunk_state_bwd_dx_kernel.configs +) def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None): @@ -256,7 +270,16 @@ def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=Non dx = torch.empty_like(x) else: assert dx.shape == x.shape - ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + deterministic = use_deterministic_mode() + tile_count = math.ceil(headdim / _CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N) + ddt, stride_ddt_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + dout.device, + deterministic, + zero_init=True, + ) grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): @@ -274,13 +297,14 @@ def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=Non B.stride(0), B.stride(1), B.stride(2), B.stride(3), dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), - ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), stride_ddt_tile, dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], D is not None, D.dim() == 2 if D is not None else True, HAS_SEQ_IDX=seq_idx is not None, BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), - IS_TRITON_22=TRITON_22 + IS_TRITON_22=TRITON_22, + DETERMINISTIC_REDUCTION=deterministic, ) if D is not None: BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] @@ -288,7 +312,8 @@ def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=Non dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) if D.dim() == 1: dD = rearrange(dD, "h 1 -> h") - return dx, ddt.to(dtype=dt.dtype), dD + ddt = finalize_tile_workspace(ddt, deterministic, target_dtype=dt.dtype) + return dx, ddt, dD def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): diff --git a/mamba_ssm/utils/determinism.py b/mamba_ssm/utils/determinism.py new file mode 100644 index 
000000000..874b7a466 --- /dev/null +++ b/mamba_ssm/utils/determinism.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +import os +import torch + +_deterministic_override = None + + +def use_deterministic_mode(): + if _deterministic_override is not None: + return _deterministic_override + val = os.environ.get('MAMBA_DETERMINISTIC', '').lower() + return val in ('1', 'true', 'yes') + + +def set_deterministic_mode(value): + global _deterministic_override + _deterministic_override = value + + +def alloc_tile_workspace(base_shape, tile_dim, dtype, device, deterministic, *, zero_init=True): + """Allocate buffer for deterministic per-program reductions.""" + if base_shape is None: + return None, 0 + if deterministic: + factory = torch.zeros if zero_init else torch.empty + tensor = factory(*base_shape, tile_dim, device=device, dtype=dtype) + return tensor, tensor.stride(-1) + tensor = torch.empty(*base_shape, device=device, dtype=dtype) + return tensor, 0 + + +def finalize_tile_workspace(tensor, deterministic, *, target_dtype=torch.float32): + """Collapse extra tile dimension (if needed) and optionally cast.""" + if tensor is None: + return None + if deterministic: + tensor = tensor.sum(dim=-1) + if target_dtype is not None and tensor.dtype != target_dtype: + tensor = tensor.to(target_dtype) + return tensor diff --git a/tests/benchmark_determinism_kernels.py b/tests/benchmark_determinism_kernels.py new file mode 100644 index 000000000..3897018db --- /dev/null +++ b/tests/benchmark_determinism_kernels.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python +# Copyright (c) 2024, Tri Dao, Albert Gu. + +import gc +import math + +import torch +from triton.testing import do_bench + +from mamba_ssm.utils.determinism import set_deterministic_mode + +MODEL_PRESETS = { + "small": {"nheads": 32, "headdim": 64, "dstate": 64, "ngroups": 1}, + "nemotronh-56b": {"nheads": 256, "headdim": 64, "dstate": 256, "ngroups": 8}, +} + + +def _reset_peak_memory() -> None: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + +def _peak_memory_mb(fn, *, warmup: int = 3) -> float: + for _ in range(warmup): + fn() + torch.cuda.synchronize() + _reset_peak_memory() + fn() + torch.cuda.synchronize() + return torch.cuda.max_memory_allocated() / (1024 * 1024) + + +def make_tensors(*, batch: int, seqlen: int, nheads: int, headdim: int, dstate: int, ngroups: int, chunk_size: int, + dtype: torch.dtype = torch.bfloat16) -> dict[str, torch.Tensor]: + device = "cuda" + nchunks = math.ceil(seqlen / chunk_size) + return { + "x": torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype), + "B": torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype), + "C": torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype), + "dt": torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32), + "dA_cumsum": torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32), + "dstates": torch.randn(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float32), + "dout": torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype), + "ddA": torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32), + "ddt_out": torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32), + "dt_raw": torch.randn(batch, seqlen, nheads, device=device, dtype=dtype), + "A": torch.randn(nheads, device=device, dtype=torch.float32) * -1, + "dt_bias": 
torch.randn(nheads, device=device, dtype=torch.float32), + "prev_states": torch.randn(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float32), + "cb": torch.randn(batch, nchunks, ngroups, chunk_size, chunk_size, device=device, dtype=dtype), + } + + +def get_benchmarks(t: dict[str, torch.Tensor], *, ngroups: int): + from mamba_ssm.ops.triton.ssd_chunk_state import ( + _chunk_cumsum_bwd, + _chunk_state_bwd_db, + _chunk_state_bwd_ddAcs_stable, + _chunk_state_bwd_dx, + ) + from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dx + from mamba_ssm.ops.triton.ssd_combined import _chunk_scan_chunk_state_bwd_dx + + x = t["x"].contiguous() + B = t["B"].contiguous() + C = t["C"].contiguous() + dout = t["dout"].contiguous() + dstates = t["dstates"].contiguous() + + return [ + ("chunk_cumsum_bwd", lambda: _chunk_cumsum_bwd(t["ddA"], t["ddt_out"], t["dt_raw"], t["A"], dt_bias=t["dt_bias"], dt_softplus=True)), + ("chunk_state_bwd_dx", lambda: _chunk_state_bwd_dx(B, x, t["dt"], t["dA_cumsum"], dstates)), + ("chunk_state_bwd_db", lambda: _chunk_state_bwd_db(x, t["dt"], t["dA_cumsum"], dstates, B=B, ngroups=ngroups)), + ("chunk_state_bwd_ddAcs", lambda: _chunk_state_bwd_ddAcs_stable(B, x, t["dt"], t["dA_cumsum"], dstates)), + ("chunk_scan_bwd_dC", lambda: _chunk_scan_bwd_dC(t["prev_states"], t["dA_cumsum"], dout, C=C, ngroups=ngroups)), + ("chunk_scan_bwd_dx", lambda: _chunk_scan_bwd_dx(t["cb"], x, t["dt"], t["dA_cumsum"], dout)), + ("combined_bwd_dx", lambda: _chunk_scan_chunk_state_bwd_dx(x, t["dt"], t["dA_cumsum"], B, t["cb"], dout, dstates)), + ] + + +def _run_one(fn, *, deterministic: bool, warmup: int, rep: int): + set_deterministic_mode(deterministic) + ms = do_bench(fn, warmup=warmup, rep=rep, return_mode="median") + peak_mb = _peak_memory_mb(fn, warmup=1) + return ms, peak_mb + + +def main() -> None: + import argparse + + parser = argparse.ArgumentParser(description="Benchmark determinism overhead for key Triton backward kernels") + parser.add_argument("--preset", choices=sorted(MODEL_PRESETS.keys()), default="small") + parser.add_argument("--warmup", type=int, default=25) + parser.add_argument("--rep", type=int, default=100) + parser.add_argument("--batch", type=int, default=4) + parser.add_argument("--seqlen", type=int, default=2048) + parser.add_argument("--chunk-size", type=int, default=256) + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise SystemExit("CUDA not available") + + p = MODEL_PRESETS[args.preset] + tensors = make_tensors( + batch=args.batch, + seqlen=args.seqlen, + nheads=p["nheads"], + headdim=p["headdim"], + dstate=p["dstate"], + ngroups=p["ngroups"], + chunk_size=args.chunk_size, + ) + benches = get_benchmarks(tensors, ngroups=p["ngroups"]) + + print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f"preset={args.preset} batch={args.batch} seqlen={args.seqlen} chunk_size={args.chunk_size}") + print(f"{'kernel':<20} {'ms':>9} {'det_ms':>9} {'ms_%':>6} {'MB':>9} {'det_MB':>9} {'MB_%':>6}") + + rows = [] + try: + for name, fn in benches: + ms, mb = _run_one(fn, deterministic=False, warmup=args.warmup, rep=args.rep) + det_ms, det_mb = _run_one(fn, deterministic=True, warmup=args.warmup, rep=args.rep) + ms_pct = (det_ms / ms - 1.0) * 100.0 + mb_pct = (det_mb / mb - 1.0) * 100.0 if mb else 0.0 + rows.append((name, ms, det_ms, ms_pct, mb, det_mb, mb_pct)) + print(f"{name:<20} {ms:>9.3f} {det_ms:>9.3f} {ms_pct:>+6.0f}% {mb:>9.1f} {det_mb:>9.1f} {mb_pct:>+6.0f}%") + finally: + set_deterministic_mode(None) + + 
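A typical invocation of this harness, using the presets and flags defined above (example command only; adjust sizes to the GPU at hand):

    python tests/benchmark_determinism_kernels.py --preset nemotronh-56b --batch 4 --seqlen 4096

Each kernel is measured twice, once with set_deterministic_mode(False) and once with set_deterministic_mode(True), so the ms_% and MB_% columns report the time and peak-memory overhead of the deterministic path.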
total_ms = sum(r[1] for r in rows) + total_det_ms = sum(r[2] for r in rows) + max_mb = max(r[4] for r in rows) if rows else 0.0 + max_det_mb = max(r[5] for r in rows) if rows else 0.0 + total_pct = (total_det_ms / total_ms - 1.0) * 100.0 if total_ms else 0.0 + max_mb_pct = (max_det_mb / max_mb - 1.0) * 100.0 if max_mb else 0.0 + print(f"{'TOTAL/MAX':<20} {total_ms:>9.3f} {total_det_ms:>9.3f} {total_pct:>+6.0f}% {max_mb:>9.1f} {max_det_mb:>9.1f} {max_mb_pct:>+6.0f}%") + + +if __name__ == "__main__": + main() + + diff --git a/tests/test_determinism.py b/tests/test_determinism.py new file mode 100644 index 000000000..3975657f9 --- /dev/null +++ b/tests/test_determinism.py @@ -0,0 +1,120 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +import os + +import pytest +import torch + +from mamba_ssm.utils.determinism import set_deterministic_mode + +MODEL_PRESETS = { + "small": {"d_model": 256, "headdim": 64, "d_state": 64, "ngroups": 1}, + "nemotronh-56b": {"d_model": 8192, "headdim": 64, "d_state": 256, "ngroups": 8}, +} + + +def _configure(deterministic: bool) -> None: + os.environ["MAMBA_DETERMINISTIC"] = "1" if deterministic else "0" + os.environ["CAUSAL_CONV1D_DETERMINISTIC"] = "1" if deterministic else "0" + if deterministic: + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + set_deterministic_mode(deterministic) + + +def _set_seeds(seed: int) -> None: + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float: + return (a.float() - b.float()).abs().max().item() + + +def _run_mamba2_backward(*, cfg: dict, seed: int) -> dict[str, torch.Tensor]: + from mamba_ssm.modules.mamba2 import Mamba2 + + _set_seeds(seed) + model = Mamba2( + d_model=cfg["d_model"], + d_state=cfg["d_state"], + d_conv=4, + expand=2, + headdim=cfg["headdim"], + ngroups=cfg["ngroups"], + use_mem_eff_path=True, + device="cuda", + dtype=torch.bfloat16, + ) + x = torch.randn(cfg["batch"], cfg["seqlen"], cfg["d_model"], device="cuda", dtype=torch.bfloat16) + out = model(x) + out.sum().backward() + return {name: p.grad.detach().clone() for name, p in model.named_parameters() if p.grad is not None} + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_deterministic_backend_reproducible_small(): + cfg = {**MODEL_PRESETS["small"], "batch": 2, "seqlen": 2048} + _configure(True) + grads0 = _run_mamba2_backward(cfg=cfg, seed=123) + grads1 = _run_mamba2_backward(cfg=cfg, seed=123) + for k in grads0: + assert _max_abs_diff(grads0[k], grads1[k]) == 0.0 + + +def main() -> int: + import argparse + import subprocess + import sys + + parser = argparse.ArgumentParser(description="Mamba2 determinism check (manual)") + parser.add_argument("--preset", choices=sorted(MODEL_PRESETS.keys()), default="small") + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--seqlen", type=int, default=2048) + parser.add_argument("--runs", type=int, default=5) + parser.add_argument("--mode", choices=["det", "default", "both"], default="det") + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise SystemExit("CUDA not available") + + if args.mode == "both": + # Run each mode in a fresh process to avoid environment / library init leakage. 
+ base = [ + sys.executable, __file__, + "--preset", args.preset, + "--batch", str(args.batch), + "--seqlen", str(args.seqlen), + "--runs", str(args.runs), + ] + subprocess.check_call(base + ["--mode", "default"]) + subprocess.check_call(base + ["--mode", "det"]) + return 0 + + deterministic = args.mode == "det" + _configure(deterministic) + + cfg = {**MODEL_PRESETS[args.preset], "batch": args.batch, "seqlen": args.seqlen} + grads = [_run_mamba2_backward(cfg=cfg, seed=123) for _ in range(args.runs)] + + max_diff = 0.0 + max_name = None + for name in grads[0]: + for i in range(1, args.runs): + diff = _max_abs_diff(grads[0][name], grads[i][name]) + if diff > max_diff: + max_diff = diff + max_name = name + + print(f"mode={args.mode} max_grad_diff={max_diff:.6e} max_param={max_name}") + if args.mode == "default" and max_diff == 0.0: + print( + "note: default path can appear deterministic for some small configs; " + "try --preset nemotronh-56b --seqlen 16384 (or larger batch/seqlen) to reproduce nondeterminism." + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) + + From ea7dfe72fc32ca526161e3cd99497db48ec02f96 Mon Sep 17 00:00:00 2001 From: Paul Gibbons Date: Thu, 11 Dec 2025 20:53:03 -0800 Subject: [PATCH 2/7] improve determinism tests + add correctness tests + fix ssd_chunk_state correctness issue Signed-off-by: Paul Gibbons --- mamba_ssm/modules/mamba2.py | 5 - mamba_ssm/modules/mamba2_simple.py | 5 - mamba_ssm/ops/triton/ssd_chunk_state.py | 3 +- tests/test_determinism.py | 216 +++++++++++++++--------- 4 files changed, 134 insertions(+), 95 deletions(-) diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index a3d567c6a..36b16d471 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -30,7 +30,6 @@ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined -from mamba_ssm.utils.determinism import set_deterministic_mode from huggingface_hub import PyTorchModelHubMixin @@ -59,7 +58,6 @@ def __init__( # Fused kernel and sharding options chunk_size=256, use_mem_eff_path=True, - deterministic=None, layer_idx=None, # Absorb kwarg for general module process_group=None, sequence_parallel=True, @@ -92,7 +90,6 @@ def __init__( self.activation = "silu" self.chunk_size = chunk_size self.use_mem_eff_path = use_mem_eff_path - self.deterministic = deterministic self.layer_idx = layer_idx # Order: [z, x, B, C, dt] @@ -162,8 +159,6 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param (in case batch is small). 
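With the per-module deterministic flag removed in this patch, the opt-in becomes process-wide instead; a minimal sketch of the replacement workflow (either line is sufficient on its own):

    import os
    from mamba_ssm.utils.determinism import set_deterministic_mode

    os.environ["MAMBA_DETERMINISTIC"] = "1"   # picked up by use_deterministic_mode()
    set_deterministic_mode(True)              # or: explicit override from Python

The third patch below additionally lets use_deterministic_mode() fall back to torch.are_deterministic_algorithms_enabled().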
Returns: same shape as u """ - if self.deterministic is not None: - set_deterministic_mode(self.deterministic) seqlen_og = seqlen if seqlen is None: batch, seqlen, dim = u.shape diff --git a/mamba_ssm/modules/mamba2_simple.py b/mamba_ssm/modules/mamba2_simple.py index 290d58a63..77a6af28e 100644 --- a/mamba_ssm/modules/mamba2_simple.py +++ b/mamba_ssm/modules/mamba2_simple.py @@ -19,7 +19,6 @@ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined -from mamba_ssm.utils.determinism import set_deterministic_mode class Mamba2Simple(nn.Module): @@ -44,7 +43,6 @@ def __init__( # Fused kernel and sharding options chunk_size=256, use_mem_eff_path=True, - deterministic=None, layer_idx=None, # Absorb kwarg for general module device=None, dtype=None, @@ -66,7 +64,6 @@ def __init__( self.activation = activation self.chunk_size = chunk_size self.use_mem_eff_path = use_mem_eff_path - self.deterministic = deterministic self.layer_idx = layer_idx # Order: [z, x, B, C, dt] @@ -129,8 +126,6 @@ def forward(self, u, seq_idx=None): u: (B, L, D) Returns: same shape as u """ - if self.deterministic is not None: - set_deterministic_mode(self.deterministic) batch, seqlen, dim = u.shape zxbcdt = self.in_proj(u) # (B, L, d_in_proj) diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py index 9ff8b949a..0f2b78c0c 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_state.py +++ b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -364,8 +364,9 @@ def _chunk_state_bwd_dx_kernel( ddA_cs_last = -tl.sum(ddA_cs) ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize if DETERMINISTIC_REDUCTION: + # Preserve atomic semantics by adding ddA_cs_last into the last element + ddA_cs = tl.where(offs_m == (chunk_size - 1), ddA_cs + ddA_cs_last, ddA_cs) tl.store(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) - tl.store(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last) else: tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last) diff --git a/tests/test_determinism.py b/tests/test_determinism.py index 3975657f9..885bc1112 100644 --- a/tests/test_determinism.py +++ b/tests/test_determinism.py @@ -7,15 +7,9 @@ from mamba_ssm.utils.determinism import set_deterministic_mode -MODEL_PRESETS = { - "small": {"d_model": 256, "headdim": 64, "d_state": 64, "ngroups": 1}, - "nemotronh-56b": {"d_model": 8192, "headdim": 64, "d_state": 256, "ngroups": 8}, -} - def _configure(deterministic: bool) -> None: os.environ["MAMBA_DETERMINISTIC"] = "1" if deterministic else "0" - os.environ["CAUSAL_CONV1D_DETERMINISTIC"] = "1" if deterministic else "0" if deterministic: os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" set_deterministic_mode(deterministic) @@ -30,91 +24,145 @@ def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float: return (a.float() - b.float()).abs().max().item() -def _run_mamba2_backward(*, cfg: dict, seed: int) -> dict[str, torch.Tensor]: - from mamba_ssm.modules.mamba2 import Mamba2 +def _make_inputs(*, seed: int, headdim: int, dstate: int, scale: float = 1.0) -> dict[str, torch.Tensor]: + """Inputs for determinism-enabled backward kernels.""" + import math _set_seeds(seed) - model = Mamba2( - d_model=cfg["d_model"], - d_state=cfg["d_state"], - d_conv=4, - expand=2, - headdim=cfg["headdim"], - ngroups=cfg["ngroups"], - use_mem_eff_path=True, - device="cuda", - dtype=torch.bfloat16, - ) - 
x = torch.randn(cfg["batch"], cfg["seqlen"], cfg["d_model"], device="cuda", dtype=torch.bfloat16) - out = model(x) - out.sum().backward() - return {name: p.grad.detach().clone() for name, p in model.named_parameters() if p.grad is not None} - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -def test_deterministic_backend_reproducible_small(): - cfg = {**MODEL_PRESETS["small"], "batch": 2, "seqlen": 2048} - _configure(True) - grads0 = _run_mamba2_backward(cfg=cfg, seed=123) - grads1 = _run_mamba2_backward(cfg=cfg, seed=123) - for k in grads0: - assert _max_abs_diff(grads0[k], grads1[k]) == 0.0 - - -def main() -> int: - import argparse - import subprocess - import sys - - parser = argparse.ArgumentParser(description="Mamba2 determinism check (manual)") - parser.add_argument("--preset", choices=sorted(MODEL_PRESETS.keys()), default="small") - parser.add_argument("--batch", type=int, default=2) - parser.add_argument("--seqlen", type=int, default=2048) - parser.add_argument("--runs", type=int, default=5) - parser.add_argument("--mode", choices=["det", "default", "both"], default="det") - args = parser.parse_args() - - if not torch.cuda.is_available(): - raise SystemExit("CUDA not available") - - if args.mode == "both": - # Run each mode in a fresh process to avoid environment / library init leakage. - base = [ - sys.executable, __file__, - "--preset", args.preset, - "--batch", str(args.batch), - "--seqlen", str(args.seqlen), - "--runs", str(args.runs), - ] - subprocess.check_call(base + ["--mode", "default"]) - subprocess.check_call(base + ["--mode", "det"]) - return 0 - - deterministic = args.mode == "det" + device = "cuda" + + batch = 2 + seqlen = 2048 + nheads = 8 + ngroups = 1 + chunk_size = 256 + nchunks = math.ceil(seqlen / chunk_size) + + x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16) * scale + dout = torch.randn_like(x) * scale + dt = torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32) * scale + dA_cumsum = torch.randn_like(dt) * scale + cb = torch.randn(batch, nchunks, ngroups, chunk_size, chunk_size, device=device, dtype=torch.bfloat16) * scale + + B = (torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=torch.bfloat16) * scale).contiguous() + C = (torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=torch.bfloat16) * scale).contiguous() + dstates = torch.randn(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float32) * scale + prev_states = torch.randn_like(dstates) * scale + + ddA = torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32) * scale + ddt_out = torch.randn_like(ddA) * scale + dt_raw = torch.randn(batch, seqlen, nheads, device=device, dtype=torch.bfloat16) * scale + A = (torch.randn(nheads, device=device, dtype=torch.float32) * -1.0).contiguous() + dt_bias = (torch.randn(nheads, device=device, dtype=torch.float32) * scale).contiguous() + + return { + "x": x, + "dout": dout, + "dt": dt, + "dA_cumsum": dA_cumsum, + "cb": cb, + "B": B, + "C": C, + "dstates": dstates, + "prev_states": prev_states, + "ddA": ddA, + "ddt_out": ddt_out, + "dt_raw": dt_raw, + "A": A, + "dt_bias": dt_bias, + } + + +def _run_case_outputs(*, case: str, deterministic: bool, seed: int, scale: float = 1.0) -> dict[str, torch.Tensor]: + """Run one kernel wrapper and return named outputs (as fp32).""" _configure(deterministic) + if case in ("chunk_scan_bwd_dC", "chunk_state_bwd_db"): + headdim = 256 + else: + headdim = 384 + dstate = 
384 + t = _make_inputs(seed=seed, headdim=headdim, dstate=dstate, scale=scale) + + if case == "chunk_scan_bwd_dx": + from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dx + dx, ddt = _chunk_scan_bwd_dx(t["cb"], t["x"], t["dt"], t["dA_cumsum"], t["dout"]) + out = {"dx": dx, "ddt": ddt} + elif case == "chunk_scan_bwd_dC": + from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC + dC, ddA_prev = _chunk_scan_bwd_dC(t["prev_states"], t["dA_cumsum"], t["dout"], C=t["C"], ngroups=1) + out = {"dC": dC, "ddA_cumsum_prev": ddA_prev} + elif case == "chunk_state_bwd_dx": + from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_dx + dx, ddt, ddA = _chunk_state_bwd_dx(t["B"], t["x"], t["dt"], t["dA_cumsum"], t["dstates"]) + out = {"dx": dx, "ddt": ddt, "ddA_cumsum": ddA} + elif case == "chunk_state_bwd_db": + from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_db + dB, ddA = _chunk_state_bwd_db(t["x"], t["dt"], t["dA_cumsum"], t["dstates"], B=t["B"], ngroups=1) + out = {"dB": dB, "ddA_cumsum": ddA} + elif case == "chunk_state_bwd_ddAcs_stable": + from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable + ddA = _chunk_state_bwd_ddAcs_stable(t["B"], t["x"], t["dt"], t["dA_cumsum"], t["dstates"]) + out = {"ddA_cumsum": ddA} + elif case == "chunk_cumsum_bwd": + from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_bwd + ddt, dA, ddt_bias = _chunk_cumsum_bwd(t["ddA"], t["ddt_out"], t["dt_raw"], t["A"], dt_bias=t["dt_bias"], dt_softplus=True) + out = {"ddt": ddt, "dA": dA, "ddt_bias": ddt_bias} + elif case == "combined_bwd_dx": + from mamba_ssm.ops.triton.ssd_combined import _chunk_scan_chunk_state_bwd_dx + dx, ddt, _ = _chunk_scan_chunk_state_bwd_dx(t["x"], t["dt"], t["dA_cumsum"], t["B"], t["cb"], t["dout"], t["dstates"]) + out = {"dx": dx, "ddt": ddt} + else: + raise AssertionError(f"Unknown case: {case}") + + torch.cuda.synchronize() + return {k: v.detach().clone().float() for k, v in out.items() if v is not None} + + +_CASES = [ + "chunk_scan_bwd_dx", + "chunk_scan_bwd_dC", + "chunk_state_bwd_dx", + "chunk_state_bwd_db", + "chunk_state_bwd_ddAcs_stable", + "chunk_cumsum_bwd", + "combined_bwd_dx", +] - cfg = {**MODEL_PRESETS[args.preset], "batch": args.batch, "seqlen": args.seqlen} - grads = [_run_mamba2_backward(cfg=cfg, seed=123) for _ in range(args.runs)] - max_diff = 0.0 - max_name = None - for name in grads[0]: - for i in range(1, args.runs): - diff = _max_abs_diff(grads[0][name], grads[i][name]) - if diff > max_diff: - max_diff = diff - max_name = name +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("case", _CASES) +def test_all_determinism_enabled_kernels_reproducible(case: str): + _run_case_outputs(case=case, deterministic=True, seed=123) - print(f"mode={args.mode} max_grad_diff={max_diff:.6e} max_param={max_name}") - if args.mode == "default" and max_diff == 0.0: - print( - "note: default path can appear deterministic for some small configs; " - "try --preset nemotronh-56b --seqlen 16384 (or larger batch/seqlen) to reproduce nondeterminism." 
- ) - return 0 + runs = 5 + outs = [_run_case_outputs(case=case, deterministic=True, seed=123) for _ in range(runs)] + ref = outs[0] + for i in range(1, runs): + for k in ref: + assert _max_abs_diff(ref[k], outs[i][k]) == 0.0, f"{case} output {k} differs" -if __name__ == "__main__": - raise SystemExit(main()) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_default_mode_is_not_reproducible_for_atomics_path(): + runs = 20 + outs = [_run_case_outputs(case="chunk_scan_bwd_dx", deterministic=False, seed=123) for _ in range(runs)] + ref = outs[0]["ddt"] + observed = any(_max_abs_diff(ref, outs[i]["ddt"]) != 0.0 for i in range(1, runs)) + if not observed: + pytest.skip("Did not observe nondeterminism in default mode (may be GPU/Triton dependent).") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("case", _CASES) +def test_all_determinism_enabled_kernels_close_to_default(case: str): + scale = 1e-2 if case == "chunk_state_bwd_dx" else 1.0 + atol = 1e-3 if scale != 1.0 else 1e-2 + rtol = 1e-3 if scale != 1.0 else 1e-2 + _run_case_outputs(case=case, deterministic=True, seed=123, scale=scale) + _run_case_outputs(case=case, deterministic=False, seed=123, scale=scale) + + det = _run_case_outputs(case=case, deterministic=True, seed=123, scale=scale) + for _ in range(3): + default = _run_case_outputs(case=case, deterministic=False, seed=123, scale=scale) + for k in det: + assert torch.allclose(default[k], det[k], atol=atol, rtol=rtol), f"{case} output {k} not close" From 04ebb650f9b7dc0ec21965f3f4977ab381817b7b Mon Sep 17 00:00:00 2001 From: Paul Gibbons Date: Mon, 15 Dec 2025 16:48:23 -0800 Subject: [PATCH 3/7] use torch determinism algo api, refresh determinism tests, all determinism tests passing after fix to _chunk_state_bwd_dx_kernel Signed-off-by: Paul Gibbons --- mamba_ssm/ops/triton/ssd_chunk_state.py | 5 +- mamba_ssm/utils/determinism.py | 6 +- tests/test_determinism.py | 80 +++++++++++++++++++------ 3 files changed, 70 insertions(+), 21 deletions(-) diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py index 0f2b78c0c..59696d07f 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_state.py +++ b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -364,8 +364,6 @@ def _chunk_state_bwd_dx_kernel( ddA_cs_last = -tl.sum(ddA_cs) ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize if DETERMINISTIC_REDUCTION: - # Preserve atomic semantics by adding ddA_cs_last into the last element - ddA_cs = tl.where(offs_m == (chunk_size - 1), ddA_cs + ddA_cs_last, ddA_cs) tl.store(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) else: tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) @@ -894,6 +892,9 @@ def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None): ) ddt = finalize_tile_workspace(ddt, deterministic, target_dtype=dt.dtype) ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic, target_dtype=dA_cumsum.dtype) + if deterministic: + # Match `_chunk_state_bwd_dx_kernel` atomic path (`tl.atomic_add(..., ddA_cs_last)` into last element). 
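+        # In the atomic path every program folds its own -tl.sum(ddA_cs) into the last element, so the
+        # accumulated correction equals the full chunk-wide sum; with the per-tile workspace already
+        # reduced above, the equivalent correction is applied once here on the host side.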
+ ddA_cumsum[..., -1] -= ddA_cumsum.sum(dim=-1) return dx, ddt, ddA_cumsum diff --git a/mamba_ssm/utils/determinism.py b/mamba_ssm/utils/determinism.py index 874b7a466..cffc90de8 100644 --- a/mamba_ssm/utils/determinism.py +++ b/mamba_ssm/utils/determinism.py @@ -9,8 +9,10 @@ def use_deterministic_mode(): if _deterministic_override is not None: return _deterministic_override - val = os.environ.get('MAMBA_DETERMINISTIC', '').lower() - return val in ('1', 'true', 'yes') + env = os.environ.get('MAMBA_DETERMINISTIC') + if env: + return env[0] == '1' + return torch.are_deterministic_algorithms_enabled() def set_deterministic_mode(value): diff --git a/tests/test_determinism.py b/tests/test_determinism.py index 885bc1112..7348dbe19 100644 --- a/tests/test_determinism.py +++ b/tests/test_determinism.py @@ -5,14 +5,11 @@ import pytest import torch -from mamba_ssm.utils.determinism import set_deterministic_mode - -def _configure(deterministic: bool) -> None: - os.environ["MAMBA_DETERMINISTIC"] = "1" if deterministic else "0" - if deterministic: +def _set_deterministic(enabled: bool) -> None: + torch.use_deterministic_algorithms(enabled) + if enabled: os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - set_deterministic_mode(deterministic) def _set_seeds(seed: int) -> None: @@ -75,7 +72,7 @@ def _make_inputs(*, seed: int, headdim: int, dstate: int, scale: float = 1.0) -> def _run_case_outputs(*, case: str, deterministic: bool, seed: int, scale: float = 1.0) -> dict[str, torch.Tensor]: """Run one kernel wrapper and return named outputs (as fp32).""" - _configure(deterministic) + _set_deterministic(deterministic) if case in ("chunk_scan_bwd_dC", "chunk_state_bwd_db"): headdim = 256 else: @@ -132,8 +129,6 @@ def _run_case_outputs(*, case: str, deterministic: bool, seed: int, scale: float @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") @pytest.mark.parametrize("case", _CASES) def test_all_determinism_enabled_kernels_reproducible(case: str): - _run_case_outputs(case=case, deterministic=True, seed=123) - runs = 5 outs = [_run_case_outputs(case=case, deterministic=True, seed=123) for _ in range(runs)] ref = outs[0] @@ -144,25 +139,76 @@ def test_all_determinism_enabled_kernels_reproducible(case: str): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") def test_default_mode_is_not_reproducible_for_atomics_path(): - runs = 20 + runs = 50 outs = [_run_case_outputs(case="chunk_scan_bwd_dx", deterministic=False, seed=123) for _ in range(runs)] ref = outs[0]["ddt"] observed = any(_max_abs_diff(ref, outs[i]["ddt"]) != 0.0 for i in range(1, runs)) if not observed: - pytest.skip("Did not observe nondeterminism in default mode (may be GPU/Triton dependent).") + pytest.xfail( + "Did not observe nondeterminism in default mode after " + f"{runs} runs. If you expect nondeterminism on this GPU, increase " + "the run count and/or adjust shapes to increase atomic contention." 
+ ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") @pytest.mark.parametrize("case", _CASES) def test_all_determinism_enabled_kernels_close_to_default(case: str): - scale = 1e-2 if case == "chunk_state_bwd_dx" else 1.0 - atol = 1e-3 if scale != 1.0 else 1e-2 - rtol = 1e-3 if scale != 1.0 else 1e-2 - _run_case_outputs(case=case, deterministic=True, seed=123, scale=scale) - _run_case_outputs(case=case, deterministic=False, seed=123, scale=scale) - + scale = 1.0 + atol = 1e-2 + rtol = atol det = _run_case_outputs(case=case, deterministic=True, seed=123, scale=scale) for _ in range(3): default = _run_case_outputs(case=case, deterministic=False, seed=123, scale=scale) for k in det: assert torch.allclose(default[k], det[k], atol=atol, rtol=rtol), f"{case} output {k} not close" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_mamba2_fwd_bwd_deterministic_mode_is_reproducible(): + from mamba_ssm.modules.mamba2 import Mamba2 + + device = "cuda" + dtype = torch.bfloat16 + seed = 123 + runs = 5 + scale = 1.0 + batch = 2 + seqlen = 2048 + + _set_seeds(seed) + _set_deterministic(True) + + model = Mamba2( + d_model=256, + d_state=384, + headdim=128, + expand=2, + d_conv=4, + chunk_size=256, + use_mem_eff_path=True, + device=device, + dtype=dtype, + ).train() + x_data = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype) * scale + + def _run() -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + model.zero_grad(set_to_none=True) + x = x_data.clone().requires_grad_(True) + y = model(x) + (y.float().square().mean()).backward() + torch.cuda.synchronize() + grads: dict[str, torch.Tensor] = {"input": x.grad.detach().float().clone()} + for name, p in model.named_parameters(): + if p.grad is not None: + grads[name] = p.grad.detach().float().clone() + return y.detach().float().clone(), grads + + _run() # warmup + y0, g0 = _run() + for _ in range(runs - 1): + y, g = _run() + assert _max_abs_diff(y0, y) == 0.0 + assert g.keys() == g0.keys() + for k in g0: + assert _max_abs_diff(g0[k], g[k]) == 0.0, f"Mamba2 grad {k} differs" From 9457dda0a2f204541f1745bbb5c3dd34c50201eb Mon Sep 17 00:00:00 2001 From: Paul Gibbons Date: Tue, 16 Dec 2025 00:31:21 -0800 Subject: [PATCH 4/7] add triton cache_results check, fallback to fixed autoconfig wrapper when triton != 3.4.0 Signed-off-by: Paul Gibbons --- mamba_ssm/ops/triton/k_activations.py | 10 +- mamba_ssm/ops/triton/layer_norm.py | 6 +- mamba_ssm/ops/triton/ssd_bmm.py | 10 +- mamba_ssm/ops/triton/ssd_chunk_scan.py | 45 +++--- mamba_ssm/ops/triton/ssd_chunk_state.py | 29 ++-- mamba_ssm/ops/triton/ssd_combined.py | 5 +- mamba_ssm/ops/triton/ssd_state_passing.py | 10 +- mamba_ssm/utils/determinism.py | 37 +++++ tests/test_determinism.py | 186 ++++++++++++++++------ 9 files changed, 237 insertions(+), 101 deletions(-) diff --git a/mamba_ssm/ops/triton/k_activations.py b/mamba_ssm/ops/triton/k_activations.py index 79fa2cc67..6ac41179b 100644 --- a/mamba_ssm/ops/triton/k_activations.py +++ b/mamba_ssm/ops/triton/k_activations.py @@ -5,16 +5,18 @@ import triton import triton.language as tl +from mamba_ssm.utils.determinism import autotune_configs + @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_N': 32}), triton.Config({'BLOCK_N': 64}), triton.Config({'BLOCK_N': 128}), triton.Config({'BLOCK_N': 256}), triton.Config({'BLOCK_N': 512}), triton.Config({'BLOCK_N': 1024}), - ], + ]), key=['ncols'], ) @triton.jit @@ -61,14 +63,14 @@ def _swiglu_fwd(xy, out=None): 
@triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_N': 32}), triton.Config({'BLOCK_N': 64}), triton.Config({'BLOCK_N': 128}), triton.Config({'BLOCK_N': 256}), triton.Config({'BLOCK_N': 512}), triton.Config({'BLOCK_N': 1024}), - ], + ]), key=['ncols'], ) @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None}) diff --git a/mamba_ssm/ops/triton/layer_norm.py b/mamba_ssm/ops/triton/layer_norm.py index 200b415a2..3e61f298d 100755 --- a/mamba_ssm/ops/triton/layer_norm.py +++ b/mamba_ssm/ops/triton/layer_norm.py @@ -16,6 +16,8 @@ import triton import triton.language as tl +from mamba_ssm.utils.determinism import autotune_configs + def layer_norm_ref( x, @@ -167,7 +169,7 @@ def config_prune(configs): pruned_configs_autotune = config_prune(configs_autotune) @triton.autotune( - configs = pruned_configs_autotune, + configs=autotune_configs(pruned_configs_autotune), key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) @@ -419,7 +421,7 @@ def _layer_norm_fwd( @triton.autotune( - configs=pruned_configs_autotune, + configs=autotune_configs(pruned_configs_autotune), key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) diff --git a/mamba_ssm/ops/triton/ssd_bmm.py b/mamba_ssm/ops/triton/ssd_bmm.py index 48fd4f063..20f619f23 100644 --- a/mamba_ssm/ops/triton/ssd_bmm.py +++ b/mamba_ssm/ops/triton/ssd_bmm.py @@ -12,13 +12,15 @@ from einops import rearrange, repeat +from mamba_ssm.utils.determinism import autotune_configs + def init_to_zero(names): return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), @@ -28,7 +30,7 @@ def init_to_zero(names): triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), - ], + ]), key=['chunk_size', 'K', 'IS_CAUSAL'], ) @triton.jit @@ -92,7 +94,7 @@ def _bmm_chunk_fwd_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), @@ -102,7 +104,7 @@ def _bmm_chunk_fwd_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2), - ], + ]), key=['chunk_size', 'K'], ) @triton.jit diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py index 20f76eed3..96f9d8907 100644 
--- a/mamba_ssm/ops/triton/ssd_chunk_scan.py +++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py @@ -19,6 +19,7 @@ alloc_tile_workspace, finalize_tile_workspace, use_deterministic_mode, + autotune_configs, ) TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') @@ -29,7 +30,7 @@ def init_to_zero(names): @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), @@ -41,7 +42,7 @@ def init_to_zero(names): triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), - ], + ]), key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], ) @triton.jit @@ -180,14 +181,14 @@ def _chunk_scan_fwd_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ # triton.Config({'BLOCK_SIZE_N': 256}, num_stages=4, num_warps=4), # triton.Config({'BLOCK_SIZE_N': 128}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=8), triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=8), - ], + ]), key=['chunk_size', 'hdim', 'dstate'], ) @triton.jit @@ -336,12 +337,12 @@ def _chunk_scan_fwd_kernel_wip( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32}), triton.Config({'BLOCK_SIZE_M': 64}), triton.Config({'BLOCK_SIZE_M': 128}), triton.Config({'BLOCK_SIZE_M': 256}), - ], + ]), key=["chunk_size", "hdim"], ) @triton.jit @@ -431,7 +432,7 @@ def _chunk_scan_bwd_dz_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), @@ -441,7 +442,7 @@ def _chunk_scan_bwd_dz_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), - ], + ]), key=['hdim', 'dstate', 'chunk_size'], ) @triton.jit @@ -513,7 +514,7 @@ def _chunk_scan_bwd_dstates_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), @@ -522,7 +523,7 @@ def _chunk_scan_bwd_dstates_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), 
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - ], + ]), key=['chunk_size', 'dstate', 'hdim'], ) @triton.jit @@ -623,7 +624,7 @@ def _chunk_scan_bwd_dc_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), @@ -633,7 +634,7 @@ def _chunk_scan_bwd_dc_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - ], + ]), key=['chunk_size', 'hdim'], ) @triton.jit @@ -749,7 +750,7 @@ def _chunk_scan_bwd_dx_kernel( # Disabling HAS_DDA_CS for now since it's much slower @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), @@ -764,7 +765,7 @@ def _chunk_scan_bwd_dx_kernel( # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), - ], + ]), key=['chunk_size', 'hdim'], ) # @triton.heuristics({"BLOCK_SIZE_N": lambda args: max(triton.next_power_of_2(args["chunk_size"]), 16)}) @@ -876,12 +877,12 @@ def _chunk_scan_bwd_dcb_kernel( # Not numerically stable and should not be used. Leaving here for reference. 
@triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32}), triton.Config({'BLOCK_SIZE_M': 64}), triton.Config({'BLOCK_SIZE_M': 128}), triton.Config({'BLOCK_SIZE_M': 256}), - ], + ]), key=["chunk_size", "hdim"], ) @triton.jit @@ -954,7 +955,7 @@ def _chunk_scan_bwd_ddAcs_unstable_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), @@ -971,7 +972,7 @@ def _chunk_scan_bwd_ddAcs_unstable_kernel( triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), - ], + ]), key=['chunk_size', 'hdim'], ) @triton.jit @@ -1080,7 +1081,7 @@ def _chunk_scan_bwd_ddAcs_stable_kernel_old( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), @@ -1090,7 +1091,7 @@ def _chunk_scan_bwd_ddAcs_stable_kernel_old( triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), - ], + ]), key=['chunk_size', 'hdim'], ) @triton.jit @@ -1183,14 +1184,14 @@ def _chunk_scan_bwd_ddAcs_stable_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - ], + ]), key=['chunk_size', 'dstate', 'hdim'], ) @triton.jit diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py index 59696d07f..88ababdd4 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_state.py +++ b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -17,6 +17,7 @@ alloc_tile_workspace, finalize_tile_workspace, use_deterministic_mode, + autotune_configs, ) @@ -24,7 +25,7 @@ def init_to_zero(names): return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_H': 1}), triton.Config({'BLOCK_SIZE_H': 2}), triton.Config({'BLOCK_SIZE_H': 4}), @@ -32,7 +33,7 @@ def init_to_zero(names): triton.Config({'BLOCK_SIZE_H': 16}), triton.Config({'BLOCK_SIZE_H': 32}), triton.Config({'BLOCK_SIZE_H': 64}), - ], + ]), key=['chunk_size', 'nheads'], ) @triton.jit @@ -86,7 +87,7 @@ def _chunk_cumsum_fwd_kernel( 
@triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), @@ -94,7 +95,7 @@ def _chunk_cumsum_fwd_kernel( triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - ], + ]), key=['chunk_size', 'nheads'], ) @triton.jit @@ -174,7 +175,7 @@ def _chunk_cumsum_bwd_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), @@ -184,7 +185,7 @@ def _chunk_cumsum_bwd_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), - ], + ]), key=['hdim', 'dstate', 'chunk_size'], ) @triton.jit @@ -268,7 +269,7 @@ def _chunk_state_fwd_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), @@ -278,7 +279,7 @@ def _chunk_state_fwd_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), - ], + ]), key=['chunk_size', 'hdim', 'dstate'], ) @triton.jit @@ -381,7 +382,7 @@ def _chunk_state_bwd_dx_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), @@ -390,7 +391,7 @@ def _chunk_state_bwd_dx_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 
'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - ], + ]), key=['chunk_size', 'dstate', 'hdim'], ) @triton.jit @@ -502,7 +503,7 @@ def _chunk_state_bwd_db_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), @@ -520,7 +521,7 @@ def _chunk_state_bwd_db_kernel( triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - ], + ]), key=['chunk_size', 'hdim', 'dstate'], ) @triton.jit @@ -622,7 +623,7 @@ def _chunk_state_bwd_ddAcs_stable_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), @@ -632,7 +633,7 @@ def _chunk_state_bwd_ddAcs_stable_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), - ], + ]), key=['hdim', 'dstate', 'chunk_size'], ) @triton.jit diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index 86db29296..8e7157286 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -44,6 +44,7 @@ from mamba_ssm.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd from mamba_ssm.utils.determinism import ( alloc_tile_workspace, + autotune_configs, finalize_tile_workspace, use_deterministic_mode, ) @@ -63,7 +64,7 @@ def rearrange_and_update_stride(tensor, pattern=None, dim=2): @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), @@ -73,7 +74,7 @@ def rearrange_and_update_stride(tensor, pattern=None, dim=2): triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, 
pre_hook=init_to_zero(["ddt_ptr"])), - ], + ]), key=['chunk_size', 'hdim', 'dstate'], ) @triton.jit diff --git a/mamba_ssm/ops/triton/ssd_state_passing.py b/mamba_ssm/ops/triton/ssd_state_passing.py index 63863b823..d6aa53c96 100644 --- a/mamba_ssm/ops/triton/ssd_state_passing.py +++ b/mamba_ssm/ops/triton/ssd_state_passing.py @@ -12,16 +12,18 @@ from einops import rearrange, repeat +from mamba_ssm.utils.determinism import autotune_configs + @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE': 64}), triton.Config({'BLOCK_SIZE': 128}), triton.Config({'BLOCK_SIZE': 256}), triton.Config({'BLOCK_SIZE': 512}), triton.Config({'BLOCK_SIZE': 1024}), triton.Config({'BLOCK_SIZE': 2048}), - ], + ]), key=['dim'], ) @triton.jit @@ -86,14 +88,14 @@ def _state_passing_fwd_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE': 64}), triton.Config({'BLOCK_SIZE': 128}), triton.Config({'BLOCK_SIZE': 256}), triton.Config({'BLOCK_SIZE': 512}), triton.Config({'BLOCK_SIZE': 1024}), triton.Config({'BLOCK_SIZE': 2048}), - ], + ]), key=['dim'], ) @triton.jit diff --git a/mamba_ssm/utils/determinism.py b/mamba_ssm/utils/determinism.py index cffc90de8..b9dcb52f9 100644 --- a/mamba_ssm/utils/determinism.py +++ b/mamba_ssm/utils/determinism.py @@ -1,8 +1,20 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. import os +import warnings +from packaging import version + import torch +try: + import triton + TRITON_VERSION = version.parse(triton.__version__) +except ImportError: + TRITON_VERSION = version.parse("0.0.0") + +TRITON_HAS_CACHE_RESULTS = TRITON_VERSION >= version.parse("3.4.0") +_autotune_warning_issued = False + _deterministic_override = None @@ -20,6 +32,31 @@ def set_deterministic_mode(value): _deterministic_override = value +def autotune_configs(configs): + """Wrap autotune configs for determinism. Uses cached autotuning if available, + otherwise selects single config via TRITON_AUTOTUNE_CONFIG_INDEX (default: last).""" + if not configs or not use_deterministic_mode(): + return configs + + if TRITON_HAS_CACHE_RESULTS and os.environ.get("TRITON_CACHE_AUTOTUNING") == "1": + return configs + + global _autotune_warning_issued + if not _autotune_warning_issued: + _autotune_warning_issued = True + if TRITON_HAS_CACHE_RESULTS: + msg = "Deterministic mode: set TRITON_CACHE_AUTOTUNING=1 for cached autotuning." + else: + msg = "Deterministic mode: upgrade to Triton >= 3.4.0 for cached autotuning." 
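+        # Warn only once per process; later calls go straight to the fixed-config selection below.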
+ warnings.warn(msg) + + idx = int(os.environ.get("TRITON_AUTOTUNE_CONFIG_INDEX", "-1")) + if idx < 0: + idx += len(configs) + idx = max(0, min(idx, len(configs) - 1)) + return configs[idx:idx + 1] + + def alloc_tile_workspace(base_shape, tile_dim, dtype, device, deterministic, *, zero_init=True): """Allocate buffer for deterministic per-program reductions.""" if base_shape is None: diff --git a/tests/test_determinism.py b/tests/test_determinism.py index 7348dbe19..39913f87c 100644 --- a/tests/test_determinism.py +++ b/tests/test_determinism.py @@ -21,7 +21,7 @@ def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float: return (a.float() - b.float()).abs().max().item() -def _make_inputs(*, seed: int, headdim: int, dstate: int, scale: float = 1.0) -> dict[str, torch.Tensor]: +def _make_inputs(*, seed: int, headdim: int, dstate: int) -> dict[str, torch.Tensor]: """Inputs for determinism-enabled backward kernels.""" import math @@ -35,22 +35,23 @@ def _make_inputs(*, seed: int, headdim: int, dstate: int, scale: float = 1.0) -> chunk_size = 256 nchunks = math.ceil(seqlen / chunk_size) - x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16) * scale - dout = torch.randn_like(x) * scale - dt = torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32) * scale - dA_cumsum = torch.randn_like(dt) * scale - cb = torch.randn(batch, nchunks, ngroups, chunk_size, chunk_size, device=device, dtype=torch.bfloat16) * scale + x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16) + dout = torch.randn_like(x) + dt = torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32) + dA_cumsum = torch.randn_like(dt) + cb = torch.randn(batch, nchunks, ngroups, chunk_size, chunk_size, device=device, dtype=torch.bfloat16) - B = (torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=torch.bfloat16) * scale).contiguous() - C = (torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=torch.bfloat16) * scale).contiguous() - dstates = torch.randn(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float32) * scale - prev_states = torch.randn_like(dstates) * scale + B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=torch.bfloat16).contiguous() + C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=torch.bfloat16).contiguous() + dstates = torch.randn(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float32) + prev_states = torch.randn_like(dstates) - ddA = torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32) * scale - ddt_out = torch.randn_like(ddA) * scale - dt_raw = torch.randn(batch, seqlen, nheads, device=device, dtype=torch.bfloat16) * scale + ddA = torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32) + ddt_out = torch.randn_like(ddA) + dt_raw = torch.randn(batch, seqlen, nheads, device=device, dtype=torch.bfloat16) A = (torch.randn(nheads, device=device, dtype=torch.float32) * -1.0).contiguous() - dt_bias = (torch.randn(nheads, device=device, dtype=torch.float32) * scale).contiguous() + dt_bias = torch.randn(nheads, device=device, dtype=torch.float32).contiguous() + D = torch.randn(nheads, device=device, dtype=torch.float32) return { "x": x, @@ -67,18 +68,16 @@ def _make_inputs(*, seed: int, headdim: int, dstate: int, scale: float = 1.0) -> "dt_raw": dt_raw, "A": A, "dt_bias": dt_bias, + "D": D, } -def _run_case_outputs(*, case: str, deterministic: bool, seed: 
int, scale: float = 1.0) -> dict[str, torch.Tensor]: +def _run_case_outputs( + *, case: str, deterministic: bool, seed: int, headdim: int = 64, +) -> dict[str, torch.Tensor]: """Run one kernel wrapper and return named outputs (as fp32).""" _set_deterministic(deterministic) - if case in ("chunk_scan_bwd_dC", "chunk_state_bwd_db"): - headdim = 256 - else: - headdim = 384 - dstate = 384 - t = _make_inputs(seed=seed, headdim=headdim, dstate=dstate, scale=scale) + t = _make_inputs(seed=seed, headdim=headdim, dstate=64) if case == "chunk_scan_bwd_dx": from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dx @@ -106,8 +105,8 @@ def _run_case_outputs(*, case: str, deterministic: bool, seed: int, scale: float out = {"ddt": ddt, "dA": dA, "ddt_bias": ddt_bias} elif case == "combined_bwd_dx": from mamba_ssm.ops.triton.ssd_combined import _chunk_scan_chunk_state_bwd_dx - dx, ddt, _ = _chunk_scan_chunk_state_bwd_dx(t["x"], t["dt"], t["dA_cumsum"], t["B"], t["cb"], t["dout"], t["dstates"]) - out = {"dx": dx, "ddt": ddt} + dx, ddt, dD = _chunk_scan_chunk_state_bwd_dx(t["x"], t["dt"], t["dA_cumsum"], t["B"], t["cb"], t["dout"], t["dstates"], D=t["D"]) + out = {"dx": dx, "ddt": ddt, "dD": dD} else: raise AssertionError(f"Unknown case: {case}") @@ -125,54 +124,97 @@ def _run_case_outputs(*, case: str, deterministic: bool, seed: int, scale: float "combined_bwd_dx", ] - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("headdim", [32, 64]) @pytest.mark.parametrize("case", _CASES) -def test_all_determinism_enabled_kernels_reproducible(case: str): +def test_all_determinism_enabled_kernels_reproducible(case: str, headdim: int): runs = 5 - outs = [_run_case_outputs(case=case, deterministic=True, seed=123) for _ in range(runs)] + outs = [_run_case_outputs(case=case, deterministic=True, seed=123, headdim=headdim) for _ in range(runs)] ref = outs[0] for i in range(1, runs): for k in ref: - assert _max_abs_diff(ref[k], outs[i][k]) == 0.0, f"{case} output {k} differs" + assert _max_abs_diff(ref[k], outs[i][k]) == 0.0, f"{case} output {k} differs (headdim={headdim})" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -def test_default_mode_is_not_reproducible_for_atomics_path(): - runs = 50 - outs = [_run_case_outputs(case="chunk_scan_bwd_dx", deterministic=False, seed=123) for _ in range(runs)] - ref = outs[0]["ddt"] - observed = any(_max_abs_diff(ref, outs[i]["ddt"]) != 0.0 for i in range(1, runs)) - if not observed: +def test_default_mode_is_not_reproducible(): + from mamba_ssm.modules.mamba2 import Mamba2 + + device = "cuda" + dtype = torch.bfloat16 + seed = 123 + runs = 20 + batch = 4 + seqlen = 4096 + + _set_seeds(seed) + model = Mamba2( + d_model=256, + d_state=64, + headdim=64, + expand=2, + d_conv=4, + chunk_size=256, + use_mem_eff_path=True, + device=device, + dtype=dtype, + ).train() + x_data = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype) + + def _run() -> dict[str, torch.Tensor]: + _set_deterministic(False) + model.zero_grad(set_to_none=True) + x = x_data.clone().requires_grad_(True) + y = model(x) + (y.float().square().mean()).backward() + torch.cuda.synchronize() + grads = {"input": x.grad.detach().float().clone()} + for name, p in model.named_parameters(): + if p.grad is not None: + grads[name] = p.grad.detach().float().clone() + return grads + + _run() # warmup + ref = _run() + observed_diff = False + for _ in range(runs - 1): + g = _run() + for k in ref: + if _max_abs_diff(ref[k], g[k]) != 0.0: 
+ observed_diff = True + break + if observed_diff: + break + + if not observed_diff: pytest.xfail( - "Did not observe nondeterminism in default mode after " - f"{runs} runs. If you expect nondeterminism on this GPU, increase " - "the run count and/or adjust shapes to increase atomic contention." + f"Did not observe nondeterminism in default mode after {runs} runs. " + "This GPU may have deterministic atomic behavior at these shapes." ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("headdim", [32, 64]) @pytest.mark.parametrize("case", _CASES) -def test_all_determinism_enabled_kernels_close_to_default(case: str): - scale = 1.0 +def test_all_determinism_enabled_kernels_close_to_default(case: str, headdim: int): atol = 1e-2 rtol = atol - det = _run_case_outputs(case=case, deterministic=True, seed=123, scale=scale) + det = _run_case_outputs(case=case, deterministic=True, seed=123, headdim=headdim) for _ in range(3): - default = _run_case_outputs(case=case, deterministic=False, seed=123, scale=scale) + default = _run_case_outputs(case=case, deterministic=False, seed=123, headdim=headdim) for k in det: - assert torch.allclose(default[k], det[k], atol=atol, rtol=rtol), f"{case} output {k} not close" + assert torch.allclose(default[k], det[k], atol=atol, rtol=rtol), f"{case} output {k} not close (headdim={headdim})" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -def test_mamba2_fwd_bwd_deterministic_mode_is_reproducible(): +@pytest.mark.parametrize("headdim", [32, 128]) +def test_mamba2_fwd_bwd_deterministic_mode_is_reproducible(headdim: int): from mamba_ssm.modules.mamba2 import Mamba2 device = "cuda" dtype = torch.bfloat16 seed = 123 runs = 5 - scale = 1.0 batch = 2 seqlen = 2048 @@ -180,17 +222,17 @@ def test_mamba2_fwd_bwd_deterministic_mode_is_reproducible(): _set_deterministic(True) model = Mamba2( - d_model=256, - d_state=384, - headdim=128, + d_model=headdim, + d_state=16, + headdim=headdim, expand=2, d_conv=4, - chunk_size=256, + chunk_size=16, use_mem_eff_path=True, device=device, dtype=dtype, ).train() - x_data = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype) * scale + x_data = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype) def _run() -> tuple[torch.Tensor, dict[str, torch.Tensor]]: model.zero_grad(set_to_none=True) @@ -211,4 +253,50 @@ def _run() -> tuple[torch.Tensor, dict[str, torch.Tensor]]: assert _max_abs_diff(y0, y) == 0.0 assert g.keys() == g0.keys() for k in g0: - assert _max_abs_diff(g0[k], g[k]) == 0.0, f"Mamba2 grad {k} differs" + assert _max_abs_diff(g0[k], g[k]) == 0.0, f"Mamba2 grad {k} differs (headdim={headdim})" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("headdim", [32, 64]) +def test_mamba2_fwd_bwd_deterministic_close_to_default(headdim: int): + from mamba_ssm.modules.mamba2 import Mamba2 + + device = "cuda" + dtype = torch.bfloat16 + seed = 123 + batch = 2 + seqlen = 2048 + atol = 1e-2 + rtol = 1e-2 + + def _run(deterministic: bool) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + torch.use_deterministic_algorithms(deterministic, warn_only=True) + _set_seeds(seed) + model = Mamba2( + d_model=headdim * 4, + d_state=32, + headdim=headdim, + expand=2, + d_conv=4, + chunk_size=64, + use_mem_eff_path=True, + device=device, + dtype=dtype, + ).train() + x = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype).requires_grad_(True) + y = model(x) + 
(y.float().square().mean()).backward() + torch.cuda.synchronize() + grads: dict[str, torch.Tensor] = {"input": x.grad.detach().float().clone()} + for name, p in model.named_parameters(): + if p.grad is not None: + grads[name] = p.grad.detach().float().clone() + return y.detach().float().clone(), grads + + _run(False) # warmup + y_default, g_default = _run(False) + y_det, g_det = _run(True) + + assert torch.allclose(y_default, y_det, atol=atol, rtol=rtol), "Mamba2 output differs" + for k in g_default: + assert torch.allclose(g_default[k], g_det[k], atol=atol, rtol=rtol), f"Mamba2 grad {k} not close (headdim={headdim})" From 4ffdde96e2b7ebda749980d6203d8e3cc3c53cc9 Mon Sep 17 00:00:00 2001 From: Paul Gibbons Date: Tue, 16 Dec 2025 13:42:01 -0800 Subject: [PATCH 5/7] debug: add TRITON_AUTOTUNE_BLOCK_SIZE_N env var to isolate dD race condition, correctness tests pass when headdim/BLOCK_SIZE_N == 1, fail when > 1 Signed-off-by: Paul Gibbons --- mamba_ssm/utils/determinism.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mamba_ssm/utils/determinism.py b/mamba_ssm/utils/determinism.py index b9dcb52f9..3ce9323a4 100644 --- a/mamba_ssm/utils/determinism.py +++ b/mamba_ssm/utils/determinism.py @@ -34,7 +34,7 @@ def set_deterministic_mode(value): def autotune_configs(configs): """Wrap autotune configs for determinism. Uses cached autotuning if available, - otherwise selects single config via TRITON_AUTOTUNE_CONFIG_INDEX (default: last).""" + otherwise selects single config via TRITON_AUTOTUNE_BLOCK_SIZE_N or TRITON_AUTOTUNE_CONFIG_INDEX.""" if not configs or not use_deterministic_mode(): return configs @@ -50,6 +50,13 @@ def autotune_configs(configs): msg = "Deterministic mode: upgrade to Triton >= 3.4.0 for cached autotuning." warnings.warn(msg) + block_size_n = os.environ.get("TRITON_AUTOTUNE_BLOCK_SIZE_N") + if block_size_n is not None: + target_n = int(block_size_n) + matching = [c for c in configs if c.kwargs.get('BLOCK_SIZE_N') == target_n] + if matching: + return matching[:1] + idx = int(os.environ.get("TRITON_AUTOTUNE_CONFIG_INDEX", "-1")) if idx < 0: idx += len(configs) From 5b2ef9c82ea8c146143cbb15e85dbdf4eaae7bf5 Mon Sep 17 00:00:00 2001 From: Paul Gibbons Date: Wed, 17 Dec 2025 13:21:20 -0800 Subject: [PATCH 6/7] refactor determinism test + select autotune config based on heuristic when triton autotune cache is not available Signed-off-by: Paul Gibbons --- mamba_ssm/utils/determinism.py | 63 +++++++----- tests/test_determinism.py | 180 ++++++++++++++++++++------------- 2 files changed, 146 insertions(+), 97 deletions(-) diff --git a/mamba_ssm/utils/determinism.py b/mamba_ssm/utils/determinism.py index 3ce9323a4..c6066f80d 100644 --- a/mamba_ssm/utils/determinism.py +++ b/mamba_ssm/utils/determinism.py @@ -32,36 +32,49 @@ def set_deterministic_mode(value): _deterministic_override = value +def _estimate_config_cost(cfg): + """Estimate shared memory cost of a config. 
Lower is cheaper.""" + block_product = 1 + for key, val in cfg.kwargs.items(): + if key.startswith('BLOCK_SIZE_'): + block_product *= val + return block_product * (getattr(cfg, 'num_stages', 1) or 1) + + +def _filter_configs_by_block_sizes(configs): + """Filter configs by TRITON_AUTOTUNE_BLOCK_SIZE_* env vars.""" + env_filters = {} + for suffix in ('M', 'N', 'K', 'DSTATE'): + env_val = os.environ.get(f"TRITON_AUTOTUNE_BLOCK_SIZE_{suffix}") + if env_val is not None: + env_filters[f'BLOCK_SIZE_{suffix}'] = int(env_val) + if not env_filters: + return None + matching = configs + for key, target in env_filters.items(): + matching = [c for c in matching if c.kwargs.get(key) == target] + return matching[:1] if matching else None + + def autotune_configs(configs): - """Wrap autotune configs for determinism. Uses cached autotuning if available, - otherwise selects single config via TRITON_AUTOTUNE_BLOCK_SIZE_N or TRITON_AUTOTUNE_CONFIG_INDEX.""" + """Select autotune configs for deterministic mode. + + Uses cached autotuning (TRITON_CACHE_AUTOTUNING=1) if Triton >= 3.4.0, + otherwise auto-selects the cheapest config by block size * stages. + """ if not configs or not use_deterministic_mode(): return configs - if TRITON_HAS_CACHE_RESULTS and os.environ.get("TRITON_CACHE_AUTOTUNING") == "1": return configs - global _autotune_warning_issued if not _autotune_warning_issued: _autotune_warning_issued = True - if TRITON_HAS_CACHE_RESULTS: - msg = "Deterministic mode: set TRITON_CACHE_AUTOTUNING=1 for cached autotuning." - else: - msg = "Deterministic mode: upgrade to Triton >= 3.4.0 for cached autotuning." + msg = "Deterministic mode: set TRITON_CACHE_AUTOTUNING=1 for cached autotuning." if TRITON_HAS_CACHE_RESULTS else "Deterministic mode: upgrade to Triton >= 3.4.0 for cached autotuning." 
warnings.warn(msg) - - block_size_n = os.environ.get("TRITON_AUTOTUNE_BLOCK_SIZE_N") - if block_size_n is not None: - target_n = int(block_size_n) - matching = [c for c in configs if c.kwargs.get('BLOCK_SIZE_N') == target_n] - if matching: - return matching[:1] - - idx = int(os.environ.get("TRITON_AUTOTUNE_CONFIG_INDEX", "-1")) - if idx < 0: - idx += len(configs) - idx = max(0, min(idx, len(configs) - 1)) - return configs[idx:idx + 1] + filtered = _filter_configs_by_block_sizes(configs) + if filtered: + return filtered + return [min(configs, key=_estimate_config_cost)] def alloc_tile_workspace(base_shape, tile_dim, dtype, device, deterministic, *, zero_init=True): @@ -72,16 +85,12 @@ def alloc_tile_workspace(base_shape, tile_dim, dtype, device, deterministic, *, factory = torch.zeros if zero_init else torch.empty tensor = factory(*base_shape, tile_dim, device=device, dtype=dtype) return tensor, tensor.stride(-1) - tensor = torch.empty(*base_shape, device=device, dtype=dtype) - return tensor, 0 + return torch.empty(*base_shape, device=device, dtype=dtype), 0 -def finalize_tile_workspace(tensor, deterministic, *, target_dtype=torch.float32): - """Collapse extra tile dimension (if needed) and optionally cast.""" +def finalize_tile_workspace(tensor, deterministic): if tensor is None: return None if deterministic: tensor = tensor.sum(dim=-1) - if target_dtype is not None and tensor.dtype != target_dtype: - tensor = tensor.to(target_dtype) return tensor diff --git a/tests/test_determinism.py b/tests/test_determinism.py index 39913f87c..b516fb8ac 100644 --- a/tests/test_determinism.py +++ b/tests/test_determinism.py @@ -21,8 +21,16 @@ def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float: return (a.float() - b.float()).abs().max().item() -def _make_inputs(*, seed: int, headdim: int, dstate: int) -> dict[str, torch.Tensor]: - """Inputs for determinism-enabled backward kernels.""" +def _make_inputs( + *, + seed: int, + headdim: int, + dstate: int, + chunk_size: int = 256, + ngroups: int = 1, + dtype: torch.dtype = torch.bfloat16, + d_has_hdim: bool = False, +) -> dict[str, torch.Tensor]: import math _set_seeds(seed) @@ -31,27 +39,29 @@ def _make_inputs(*, seed: int, headdim: int, dstate: int) -> dict[str, torch.Ten batch = 2 seqlen = 2048 nheads = 8 - ngroups = 1 - chunk_size = 256 nchunks = math.ceil(seqlen / chunk_size) - x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16) + x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype) dout = torch.randn_like(x) dt = torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32) dA_cumsum = torch.randn_like(dt) - cb = torch.randn(batch, nchunks, ngroups, chunk_size, chunk_size, device=device, dtype=torch.bfloat16) + cb = torch.randn(batch, nchunks, ngroups, chunk_size, chunk_size, device=device, dtype=dtype) - B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=torch.bfloat16).contiguous() - C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=torch.bfloat16).contiguous() + B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype).contiguous() + C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype).contiguous() dstates = torch.randn(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float32) prev_states = torch.randn_like(dstates) ddA = torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32) ddt_out = torch.randn_like(ddA) - dt_raw = torch.randn(batch, 
seqlen, nheads, device=device, dtype=torch.bfloat16) + dt_raw = torch.randn(batch, seqlen, nheads, device=device, dtype=dtype) A = (torch.randn(nheads, device=device, dtype=torch.float32) * -1.0).contiguous() dt_bias = torch.randn(nheads, device=device, dtype=torch.float32).contiguous() - D = torch.randn(nheads, device=device, dtype=torch.float32) + # D shape: (nheads, headdim) when d_has_hdim=True, else (nheads,) + if d_has_hdim: + D = torch.randn(nheads, headdim, device=device, dtype=torch.float32) + else: + D = torch.randn(nheads, device=device, dtype=torch.float32) return { "x": x, @@ -73,11 +83,27 @@ def _make_inputs(*, seed: int, headdim: int, dstate: int) -> dict[str, torch.Ten def _run_case_outputs( - *, case: str, deterministic: bool, seed: int, headdim: int = 64, + *, + case: str, + deterministic: bool, + seed: int, + headdim: int = 64, + dstate: int = 64, + chunk_size: int = 256, + ngroups: int = 1, + dtype: torch.dtype = torch.bfloat16, + d_has_hdim: bool = False, ) -> dict[str, torch.Tensor]: - """Run one kernel wrapper and return named outputs (as fp32).""" _set_deterministic(deterministic) - t = _make_inputs(seed=seed, headdim=headdim, dstate=64) + t = _make_inputs( + seed=seed, + headdim=headdim, + dstate=dstate, + chunk_size=chunk_size, + ngroups=ngroups, + dtype=dtype, + d_has_hdim=d_has_hdim, + ) if case == "chunk_scan_bwd_dx": from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dx @@ -103,7 +129,7 @@ def _run_case_outputs( from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_bwd ddt, dA, ddt_bias = _chunk_cumsum_bwd(t["ddA"], t["ddt_out"], t["dt_raw"], t["A"], dt_bias=t["dt_bias"], dt_softplus=True) out = {"ddt": ddt, "dA": dA, "ddt_bias": ddt_bias} - elif case == "combined_bwd_dx": + elif case.startswith("combined_bwd_dx"): from mamba_ssm.ops.triton.ssd_combined import _chunk_scan_chunk_state_bwd_dx dx, ddt, dD = _chunk_scan_chunk_state_bwd_dx(t["x"], t["dt"], t["dA_cumsum"], t["B"], t["cb"], t["dout"], t["dstates"], D=t["D"]) out = {"dx": dx, "ddt": ddt, "dD": dD} @@ -114,26 +140,75 @@ def _run_case_outputs( return {k: v.detach().clone().float() for k, v in out.items() if v is not None} -_CASES = [ +_KERNEL_CASES = [ "chunk_scan_bwd_dx", "chunk_scan_bwd_dC", "chunk_state_bwd_dx", "chunk_state_bwd_db", "chunk_state_bwd_ddAcs_stable", "chunk_cumsum_bwd", - "combined_bwd_dx", ] -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -@pytest.mark.parametrize("headdim", [32, 64]) -@pytest.mark.parametrize("case", _CASES) -def test_all_determinism_enabled_kernels_reproducible(case: str, headdim: int): +_COMBINED_CASES = [ + ("combined_bwd_dx", False), + ("combined_bwd_dx_d_hdim", True), +] + +_HEADDIMS = [64, 128] +_DSTATES = [64] + + +def _kernel_is_reproducible(case: str, headdim: int, dstate: int, d_has_hdim: bool = False): runs = 5 - outs = [_run_case_outputs(case=case, deterministic=True, seed=123, headdim=headdim) for _ in range(runs)] + outs = [ + _run_case_outputs(case=case, deterministic=True, seed=123, headdim=headdim, dstate=dstate, d_has_hdim=d_has_hdim) + for _ in range(runs) + ] ref = outs[0] for i in range(1, runs): for k in ref: - assert _max_abs_diff(ref[k], outs[i][k]) == 0.0, f"{case} output {k} differs (headdim={headdim})" + assert _max_abs_diff(ref[k], outs[i][k]) == 0.0, f"{case} output {k} differs (headdim={headdim}, dstate={dstate})" + + +def _kernel_close_to_default(case: str, headdim: int, dstate: int, d_has_hdim: bool = False): + atol = rtol = 1e-2 + det = _run_case_outputs(case=case, deterministic=True, 
seed=123, headdim=headdim, dstate=dstate, d_has_hdim=d_has_hdim) + for _ in range(3): + default = _run_case_outputs(case=case, deterministic=False, seed=123, headdim=headdim, dstate=dstate, d_has_hdim=d_has_hdim) + for k in det: + assert torch.allclose(default[k], det[k], atol=atol, rtol=rtol), f"{case} output {k} not close (headdim={headdim}, dstate={dstate})" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("dstate", _DSTATES) +@pytest.mark.parametrize("headdim", _HEADDIMS) +@pytest.mark.parametrize("case", _KERNEL_CASES) +def test_kernel_reproducible(case: str, headdim: int, dstate: int): + _kernel_is_reproducible(case, headdim, dstate) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("dstate", _DSTATES) +@pytest.mark.parametrize("headdim", _HEADDIMS) +@pytest.mark.parametrize("case,d_has_hdim", _COMBINED_CASES) +def test_combined_kernel_reproducible(case: str, d_has_hdim: bool, headdim: int, dstate: int): + _kernel_is_reproducible(case, headdim, dstate, d_has_hdim) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("dstate", _DSTATES) +@pytest.mark.parametrize("headdim", _HEADDIMS) +@pytest.mark.parametrize("case", _KERNEL_CASES) +def test_kernel_close_to_default(case: str, headdim: int, dstate: int): + _kernel_close_to_default(case, headdim, dstate) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("dstate", _DSTATES) +@pytest.mark.parametrize("headdim", _HEADDIMS) +@pytest.mark.parametrize("case,d_has_hdim", _COMBINED_CASES) +def test_combined_kernel_close_to_default(case: str, d_has_hdim: bool, headdim: int, dstate: int): + _kernel_close_to_default(case, headdim, dstate, d_has_hdim) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") @@ -149,15 +224,8 @@ def test_default_mode_is_not_reproducible(): _set_seeds(seed) model = Mamba2( - d_model=256, - d_state=64, - headdim=64, - expand=2, - d_conv=4, - chunk_size=256, - use_mem_eff_path=True, - device=device, - dtype=dtype, + d_model=256, d_state=64, headdim=64, expand=2, d_conv=4, chunk_size=256, + use_mem_eff_path=True, device=device, dtype=dtype, ).train() x_data = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype) @@ -194,21 +262,7 @@ def _run() -> dict[str, torch.Tensor]: @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -@pytest.mark.parametrize("headdim", [32, 64]) -@pytest.mark.parametrize("case", _CASES) -def test_all_determinism_enabled_kernels_close_to_default(case: str, headdim: int): - atol = 1e-2 - rtol = atol - det = _run_case_outputs(case=case, deterministic=True, seed=123, headdim=headdim) - for _ in range(3): - default = _run_case_outputs(case=case, deterministic=False, seed=123, headdim=headdim) - for k in det: - assert torch.allclose(default[k], det[k], atol=atol, rtol=rtol), f"{case} output {k} not close (headdim={headdim})" - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -@pytest.mark.parametrize("headdim", [32, 128]) -def test_mamba2_fwd_bwd_deterministic_mode_is_reproducible(headdim: int): +def test_mamba2_fwd_bwd_deterministic_reproducible(): from mamba_ssm.modules.mamba2 import Mamba2 device = "cuda" @@ -217,20 +271,14 @@ def test_mamba2_fwd_bwd_deterministic_mode_is_reproducible(headdim: int): runs = 5 batch = 2 seqlen = 2048 + headdim = 64 _set_seeds(seed) _set_deterministic(True) model = 
Mamba2( - d_model=headdim, - d_state=16, - headdim=headdim, - expand=2, - d_conv=4, - chunk_size=16, - use_mem_eff_path=True, - device=device, - dtype=dtype, + d_model=headdim, d_state=16, headdim=headdim, expand=2, d_conv=4, chunk_size=16, + use_mem_eff_path=True, device=device, dtype=dtype, ).train() x_data = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype) @@ -253,12 +301,11 @@ def _run() -> tuple[torch.Tensor, dict[str, torch.Tensor]]: assert _max_abs_diff(y0, y) == 0.0 assert g.keys() == g0.keys() for k in g0: - assert _max_abs_diff(g0[k], g[k]) == 0.0, f"Mamba2 grad {k} differs (headdim={headdim})" + assert _max_abs_diff(g0[k], g[k]) == 0.0, f"Mamba2 grad {k} differs" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -@pytest.mark.parametrize("headdim", [32, 64]) -def test_mamba2_fwd_bwd_deterministic_close_to_default(headdim: int): +def test_mamba2_fwd_bwd_deterministic_close_to_default(): from mamba_ssm.modules.mamba2 import Mamba2 device = "cuda" @@ -266,22 +313,15 @@ def test_mamba2_fwd_bwd_deterministic_close_to_default(headdim: int): seed = 123 batch = 2 seqlen = 2048 - atol = 1e-2 - rtol = 1e-2 + headdim = 64 + atol = rtol = 1e-2 def _run(deterministic: bool) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: torch.use_deterministic_algorithms(deterministic, warn_only=True) _set_seeds(seed) model = Mamba2( - d_model=headdim * 4, - d_state=32, - headdim=headdim, - expand=2, - d_conv=4, - chunk_size=64, - use_mem_eff_path=True, - device=device, - dtype=dtype, + d_model=headdim * 4, d_state=32, headdim=headdim, expand=2, d_conv=4, chunk_size=64, + use_mem_eff_path=True, device=device, dtype=dtype, ).train() x = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype).requires_grad_(True) y = model(x) @@ -299,4 +339,4 @@ def _run(deterministic: bool) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: assert torch.allclose(y_default, y_det, atol=atol, rtol=rtol), "Mamba2 output differs" for k in g_default: - assert torch.allclose(g_default[k], g_det[k], atol=atol, rtol=rtol), f"Mamba2 grad {k} not close (headdim={headdim})" + assert torch.allclose(g_default[k], g_det[k], atol=atol, rtol=rtol), f"Mamba2 grad {k} not close" From 30b5d1ccf2f284e2cdcf3bb3ecae9aa23c2e27a2 Mon Sep 17 00:00:00 2001 From: Paul Gibbons Date: Wed, 17 Dec 2025 13:24:45 -0800 Subject: [PATCH 7/7] address dD gradient race in _chunk_scan_chunk_state_bwd_dx when headdim > BLOCK_SIZE_N Signed-off-by: Paul Gibbons --- mamba_ssm/ops/triton/ssd_chunk_scan.py | 2 +- mamba_ssm/ops/triton/ssd_chunk_state.py | 6 ++-- mamba_ssm/ops/triton/ssd_combined.py | 47 +++++++++++++++---------- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py index 96f9d8907..df946fa19 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_scan.py +++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py @@ -1611,7 +1611,7 @@ def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None): # BLOCK_SIZE_actual = _chunk_scan_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] # n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual # dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) - ddt = finalize_tile_workspace(ddt, deterministic, target_dtype=dt.dtype) + ddt = finalize_tile_workspace(ddt, deterministic) return dx, ddt diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py index 88ababdd4..c85f5616a 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_state.py +++ 
b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -891,8 +891,8 @@ def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None): DETERMINISTIC_REDUCTION=deterministic, BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), ) - ddt = finalize_tile_workspace(ddt, deterministic, target_dtype=dt.dtype) - ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic, target_dtype=dA_cumsum.dtype) + ddt = finalize_tile_workspace(ddt, deterministic) + ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic) if deterministic: # Match `_chunk_state_bwd_dx_kernel` atomic path (`tl.atomic_add(..., ddA_cs_last)` into last element). ddA_cumsum[..., -1] -= ddA_cumsum.sum(dim=-1) @@ -1010,7 +1010,7 @@ def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None): BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16), BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), ) - ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic, target_dtype=ddA_cumsum.dtype) + ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic) torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:]) return ddA_cumsum diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index 8e7157286..ce9fbc713 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -65,15 +65,15 @@ def rearrange_and_update_stride(tensor, pattern=None, dim=2): @triton.autotune( configs=autotune_configs([ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + 
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), ]), key=['chunk_size', 'hdim', 'dstate'], ) @@ -230,7 +230,10 @@ def _chunk_scan_chunk_state_bwd_dx_kernel( tl.store(dD_ptrs, dD, mask=offs_n < hdim) else: dD = tl.sum(dout_res * x) - tl.store(dD_ptr, dD) + if DETERMINISTIC_REDUCTION: + tl.store(dD_ptr + pid_n * stride_dD_hdim, dD) + else: + tl.atomic_add(dD_ptr, dD) ddt = tl.sum(acc * x, axis=1) ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize if DETERMINISTIC_REDUCTION: @@ -257,21 +260,28 @@ def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=Non assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) + deterministic = use_deterministic_mode() if D is not None: assert D.shape == (nheads, headdim) or D.shape == (nheads,) assert D.stride(-1) == 1 BLOCK_SIZE_min = 32 - dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, - headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) + pid_m_tiles = triton.cdiv(chunk_size, BLOCK_SIZE_min) + pid_n_tiles = math.ceil(headdim / _CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N) + if D.dim() == 2: + dD_hdim = headdim + elif deterministic: + dD_hdim = pid_n_tiles + else: + dD_hdim = 1 + dD = torch.zeros(pid_m_tiles, batch, nchunks, nheads, dD_hdim, device=D.device, dtype=torch.float32) + dD_strides = (dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) else: dD = None - dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) - if D is not None else (0, 0, 0, 0, 0)) + dD_strides = (0, 0, 0, 0, 0) if dx is None: dx = torch.empty_like(x) else: assert dx.shape == x.shape - deterministic = use_deterministic_mode() tile_count = math.ceil(headdim / _CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N) ddt, stride_ddt_tile = alloc_tile_workspace( (batch, nheads, nchunks, chunk_size), @@ -310,10 +320,11 @@ def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=Non if D is not None: BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual - dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)) if D.dim() == 1: - dD = rearrange(dD, "h 1 -> h") - ddt = finalize_tile_workspace(ddt, deterministic, target_dtype=dt.dtype) + dD = dD.sum(dim=-1) + dD = dD.to(dtype=D.dtype) + ddt = finalize_tile_workspace(ddt, deterministic) return dx, ddt, dD
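
Note (illustrative sketch, not part of the patch): the deterministic-reduction pattern this series applies replaces `tl.atomic_add` into a shared accumulator with a per-tile `tl.store` into a private workspace slot, followed by a fixed-order sum on the host (the `alloc_tile_workspace` / `finalize_tile_workspace` helpers above). The plain-PyTorch sketch below shows why that removes run-to-run variation; the names `reduce_with_atomics`, `reduce_with_workspace`, and `n_tiles` are hypothetical and exist only for this illustration.

    import torch

    def reduce_with_atomics(partials: torch.Tensor) -> torch.Tensor:
        # Emulates atomic accumulation: tiles land in an arbitrary order, and
        # floating-point addition is not associative, so results can vary per run.
        out = torch.zeros(partials.shape[1:], dtype=partials.dtype, device=partials.device)
        for tile in torch.randperm(partials.shape[0]):
            out += partials[tile]
        return out

    def reduce_with_workspace(partials: torch.Tensor) -> torch.Tensor:
        # Deterministic path: each tile already owns its own slot in the workspace,
        # and one fixed-order sum collapses the tile dimension afterwards.
        return partials.sum(dim=0)

    if __name__ == "__main__":
        torch.manual_seed(0)
        n_tiles, nheads, headdim = 8, 4, 64  # hypothetical tile/head sizes
        partials = torch.randn(n_tiles, nheads, headdim, dtype=torch.float32)
        # Atomic-style accumulation may differ slightly between invocations:
        print((reduce_with_atomics(partials) - reduce_with_atomics(partials)).abs().max())
        # Workspace reduction is bitwise identical across runs:
        assert torch.equal(reduce_with_workspace(partials), reduce_with_workspace(partials))

The same idea motivates the dD change in this final patch: with headdim > BLOCK_SIZE_N, several pid_n programs previously raced on one dD slot via atomic adds, whereas the deterministic path gives each pid_n tile its own column and sums them once on the host.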