diff --git a/mamba_ssm/ops/triton/k_activations.py b/mamba_ssm/ops/triton/k_activations.py index 79fa2cc6..6ac41179 100644 --- a/mamba_ssm/ops/triton/k_activations.py +++ b/mamba_ssm/ops/triton/k_activations.py @@ -5,16 +5,18 @@ import triton import triton.language as tl +from mamba_ssm.utils.determinism import autotune_configs + @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_N': 32}), triton.Config({'BLOCK_N': 64}), triton.Config({'BLOCK_N': 128}), triton.Config({'BLOCK_N': 256}), triton.Config({'BLOCK_N': 512}), triton.Config({'BLOCK_N': 1024}), - ], + ]), key=['ncols'], ) @triton.jit @@ -61,14 +63,14 @@ def _swiglu_fwd(xy, out=None): @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_N': 32}), triton.Config({'BLOCK_N': 64}), triton.Config({'BLOCK_N': 128}), triton.Config({'BLOCK_N': 256}), triton.Config({'BLOCK_N': 512}), triton.Config({'BLOCK_N': 1024}), - ], + ]), key=['ncols'], ) @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None}) diff --git a/mamba_ssm/ops/triton/layer_norm.py b/mamba_ssm/ops/triton/layer_norm.py index 200b415a..3e61f298 100755 --- a/mamba_ssm/ops/triton/layer_norm.py +++ b/mamba_ssm/ops/triton/layer_norm.py @@ -16,6 +16,8 @@ import triton import triton.language as tl +from mamba_ssm.utils.determinism import autotune_configs + def layer_norm_ref( x, @@ -167,7 +169,7 @@ def config_prune(configs): pruned_configs_autotune = config_prune(configs_autotune) @triton.autotune( - configs = pruned_configs_autotune, + configs=autotune_configs(pruned_configs_autotune), key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) @@ -419,7 +421,7 @@ def _layer_norm_fwd( @triton.autotune( - configs=pruned_configs_autotune, + configs=autotune_configs(pruned_configs_autotune), key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) diff --git a/mamba_ssm/ops/triton/ssd_bmm.py b/mamba_ssm/ops/triton/ssd_bmm.py index 48fd4f06..20f619f2 100644 --- a/mamba_ssm/ops/triton/ssd_bmm.py +++ b/mamba_ssm/ops/triton/ssd_bmm.py @@ -12,13 +12,15 @@ from einops import rearrange, repeat +from mamba_ssm.utils.determinism import autotune_configs + def init_to_zero(names): return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), @@ -28,7 +30,7 @@ def init_to_zero(names): triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), - ], + ]), key=['chunk_size', 'K', 'IS_CAUSAL'], ) @triton.jit @@ -92,7 +94,7 @@ def _bmm_chunk_fwd_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, 
num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), @@ -102,7 +104,7 @@ def _bmm_chunk_fwd_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2), - ], + ]), key=['chunk_size', 'K'], ) @triton.jit diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py index 95907806..df946fa1 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_scan.py +++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py @@ -15,6 +15,12 @@ from einops import rearrange, repeat from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd +from mamba_ssm.utils.determinism import ( + alloc_tile_workspace, + finalize_tile_workspace, + use_deterministic_mode, + autotune_configs, +) TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') @@ -24,7 +30,7 @@ def init_to_zero(names): @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), @@ -36,7 +42,7 @@ def init_to_zero(names): triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), - ], + ]), key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], ) @triton.jit @@ -175,14 +181,14 @@ def _chunk_scan_fwd_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ # triton.Config({'BLOCK_SIZE_N': 256}, num_stages=4, num_warps=4), # triton.Config({'BLOCK_SIZE_N': 128}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=8), triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=8), - ], + ]), key=['chunk_size', 'hdim', 'dstate'], ) @triton.jit @@ -331,12 +337,12 @@ def _chunk_scan_fwd_kernel_wip( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32}), triton.Config({'BLOCK_SIZE_M': 64}), triton.Config({'BLOCK_SIZE_M': 128}), triton.Config({'BLOCK_SIZE_M': 256}), - ], + ]), key=["chunk_size", "hdim"], ) @triton.jit @@ -426,7 +432,7 @@ def _chunk_scan_bwd_dz_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), @@ -436,7 +442,7 @@ def _chunk_scan_bwd_dz_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 
'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), - ], + ]), key=['hdim', 'dstate', 'chunk_size'], ) @triton.jit @@ -508,7 +514,7 @@ def _chunk_scan_bwd_dstates_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), @@ -517,7 +523,7 @@ def _chunk_scan_bwd_dstates_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - ], + ]), key=['chunk_size', 'dstate', 'hdim'], ) @triton.jit @@ -535,10 +541,11 @@ def _chunk_scan_bwd_dc_kernel( stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, stride_dc_batch, stride_dc_seqlen, stride_dc_split, stride_dc_group, stride_dc_dstate, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile, # Meta-parameters HAS_DDA_CS: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, + DETERMINISTIC_REDUCTION: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1) @@ -556,7 +563,7 @@ def _chunk_scan_bwd_dc_kernel( dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head if HAS_DDA_CS: C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + pid_g * stride_C_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_n * stride_ddA_tile if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen @@ -591,7 +598,10 @@ def _chunk_scan_bwd_dc_kernel( dc *= scale[:, None] if HAS_DDA_CS: ddA_cs = tl.sum(dc * c, axis=1) - tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + if DETERMINISTIC_REDUCTION: + tl.store(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + else: + tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) acc += dc dout_ptrs += stride_dout_head prev_states_ptrs += stride_prev_states_head @@ -608,8 +618,13 @@ def _chunk_scan_bwd_dc_kernel( tl.store(dc_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) +_CHUNK_SCAN_BWD_DC_MIN_BLOCK_N = min( + cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_scan_bwd_dc_kernel.configs +) + + @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 
'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), @@ -619,7 +634,7 @@ def _chunk_scan_bwd_dc_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - ], + ]), key=['chunk_size', 'hdim'], ) @triton.jit @@ -638,11 +653,12 @@ def _chunk_scan_bwd_dx_kernel( stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_D_head, stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, - stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, stride_ddt_tile, # stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_hdim, stride_dD_csize, # Meta-parameters HAS_D: tl.constexpr, D_HAS_HDIM: tl.constexpr, + DETERMINISTIC_REDUCTION: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1) @@ -656,7 +672,7 @@ def _chunk_scan_bwd_dx_kernel( cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + pid_n * stride_ddt_tile dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head # if HAS_D: # dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize @@ -715,7 +731,10 @@ def _chunk_scan_bwd_dx_kernel( x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) ddt = tl.sum(acc * x, axis=1) ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize - tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + if DETERMINISTIC_REDUCTION: + tl.store(ddt_ptrs, ddt, mask=offs_m < chunk_size) + else: + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) # if HAS_D: # dout_new_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize + offs_n[None, :] * stride_dout_hdim) @@ -724,9 +743,14 @@ def _chunk_scan_bwd_dx_kernel( # tl.store(dD_ptr + offs_n * stride_dD_hdim, dD, mask=offs_n < N) +_CHUNK_SCAN_BWD_DX_MIN_BLOCK_N = min( + cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_scan_bwd_dx_kernel.configs +) + + # Disabling HAS_DDA_CS for now since it's much slower @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), @@ -741,7 +765,7 @@ def _chunk_scan_bwd_dx_kernel( # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), # triton.Config({'BLOCK_SIZE_M': 64}, 
num_stages=4, num_warps=8), # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), - ], + ]), key=['chunk_size', 'hdim'], ) # @triton.heuristics({"BLOCK_SIZE_N": lambda args: max(triton.next_power_of_2(args["chunk_size"]), 16)}) @@ -853,12 +877,12 @@ def _chunk_scan_bwd_dcb_kernel( # Not numerically stable and should not be used. Leaving here for reference. @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32}), triton.Config({'BLOCK_SIZE_M': 64}), triton.Config({'BLOCK_SIZE_M': 128}), triton.Config({'BLOCK_SIZE_M': 256}), - ], + ]), key=["chunk_size", "hdim"], ) @triton.jit @@ -931,7 +955,7 @@ def _chunk_scan_bwd_ddAcs_unstable_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), @@ -948,7 +972,7 @@ def _chunk_scan_bwd_ddAcs_unstable_kernel( triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), - ], + ]), key=['chunk_size', 'hdim'], ) @triton.jit @@ -1057,7 +1081,7 @@ def _chunk_scan_bwd_ddAcs_stable_kernel_old( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), @@ -1067,7 +1091,7 @@ def _chunk_scan_bwd_ddAcs_stable_kernel_old( triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), - ], + ]), key=['chunk_size', 'hdim'], ) @triton.jit @@ -1160,14 +1184,14 @@ def _chunk_scan_bwd_ddAcs_stable_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - ], + ]), key=['chunk_size', 'dstate', 'hdim'], ) @triton.jit @@ -1433,15 +1457,30 @@ def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngrou assert dout.shape == (batch, seqlen, nheads, headdim) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) + deterministic = use_deterministic_mode() if C is not None: assert C.shape == (batch, seqlen, ngroups, dstate) C_strides = (C.stride(0), C.stride(1), C.stride(2), C.stride(3)) - ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) - 
ddA_cumsum_prev_strides = (ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3)) + tile_count = math.ceil(dstate / _CHUNK_SCAN_BWD_DC_MIN_BLOCK_N) + ddA_cumsum_prev, stride_ddA_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + dout.device, + deterministic, + zero_init=True, + ) + ddA_cumsum_prev_strides = ( + ddA_cumsum_prev.stride(0), + ddA_cumsum_prev.stride(2), + ddA_cumsum_prev.stride(1), + ddA_cumsum_prev.stride(3), + ) else: C_strides = (0, 0, 0, 0) ddA_cumsum_prev = None ddA_cumsum_prev_strides = (0, 0, 0, 0) + stride_ddA_tile = 0 nheads_ngroups_ratio = nheads // ngroups sm_count = torch.cuda.get_device_properties(dout.device).multi_processor_count nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) @@ -1460,12 +1499,15 @@ def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngrou dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), dC.stride(0), dC.stride(1), dC.stride(2), dC.stride(3), dC.stride(4), - *ddA_cumsum_prev_strides, + *ddA_cumsum_prev_strides, stride_ddA_tile, HAS_DDA_CS=ddA_cumsum_prev is not None, HAS_SEQ_IDX=seq_idx is not None, + DETERMINISTIC_REDUCTION=deterministic, BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), ) dC = dC.sum(2) + if ddA_cumsum_prev is not None: + ddA_cumsum_prev = finalize_tile_workspace(ddA_cumsum_prev, deterministic) return dC if C is None else (dC, ddA_cumsum_prev) @@ -1535,7 +1577,16 @@ def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None): # else: # dD = None dx = torch.empty_like(x) - ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + deterministic = use_deterministic_mode() + tile_count = math.ceil(headdim / _CHUNK_SCAN_BWD_DX_MIN_BLOCK_N) + ddt, stride_ddt_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + dout.device, + deterministic, + zero_init=True, + ) grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): @@ -1550,16 +1601,18 @@ def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None): dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), D.stride(0) if D is not None else 0, dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), - ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), stride_ddt_tile, # dD.stride(1) if dD is not None else 0, dD.stride(2) if dD is not None else 0, dD.stride(3) if dD is not None else 0, dD.stride(4) if dD is not None else 0, dD.stride(0) if dD is not None else 0, D is not None, D.dim() == 2 if D is not None else True, + DETERMINISTIC_REDUCTION=deterministic, ) # if D is not None: # BLOCK_SIZE_actual = _chunk_scan_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] # n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual # dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) - return dx, ddt.to(dtype=dt.dtype) + ddt = finalize_tile_workspace(ddt, deterministic) + return dx, ddt def _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtract_ddtdt=True): diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py index 
633c66e8..c85f5616 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_state.py +++ b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -13,13 +13,19 @@ from einops import rearrange, repeat from mamba_ssm.ops.triton.softplus import softplus +from mamba_ssm.utils.determinism import ( + alloc_tile_workspace, + finalize_tile_workspace, + use_deterministic_mode, + autotune_configs, +) def init_to_zero(names): return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_H': 1}), triton.Config({'BLOCK_SIZE_H': 2}), triton.Config({'BLOCK_SIZE_H': 4}), @@ -27,7 +33,7 @@ def init_to_zero(names): triton.Config({'BLOCK_SIZE_H': 16}), triton.Config({'BLOCK_SIZE_H': 32}), triton.Config({'BLOCK_SIZE_H': 64}), - ], + ]), key=['chunk_size', 'nheads'], ) @triton.jit @@ -81,7 +87,7 @@ def _chunk_cumsum_fwd_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), @@ -89,7 +95,7 @@ def _chunk_cumsum_fwd_kernel( triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - ], + ]), key=['chunk_size', 'nheads'], ) @triton.jit @@ -107,12 +113,13 @@ def _chunk_cumsum_bwd_kernel( stride_A_head, stride_dt_bias_head, stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head, - stride_dA_head, - stride_ddt_bias_head, + stride_dA_batch, stride_dA_chunk, stride_dA_head, + stride_ddt_bias_batch, stride_ddt_bias_chunk, stride_ddt_bias_head, # Meta-parameters DT_SOFTPLUS: tl.constexpr, HAS_DT_BIAS: tl.constexpr, BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, + DETERMINISTIC_REDUCTION: tl.constexpr, ): pid_b = tl.program_id(axis=0) pid_c = tl.program_id(axis=1) @@ -153,14 +160,22 @@ def _chunk_cumsum_bwd_kernel( ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt) tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) dA = tl.sum(ddA * dt, axis=1) - tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads) + dA_ptr += pid_b * stride_dA_batch + pid_c * stride_dA_chunk + if DETERMINISTIC_REDUCTION: + tl.store(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads) + else: + tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads) if HAS_DT_BIAS: ddt_bias = tl.sum(ddt, axis=1) - tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads) + ddt_bias_ptr += pid_b * stride_ddt_bias_batch + pid_c * stride_ddt_bias_chunk + if DETERMINISTIC_REDUCTION: + tl.store(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads) + else: + tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads) @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), @@ -170,7 +185,7 @@ 
def _chunk_cumsum_bwd_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), - ], + ]), key=['hdim', 'dstate', 'chunk_size'], ) @triton.jit @@ -254,7 +269,7 @@ def _chunk_state_fwd_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), @@ -264,7 +279,7 @@ def _chunk_state_fwd_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), - ], + ]), key=['chunk_size', 'hdim', 'dstate'], ) @triton.jit @@ -282,9 +297,10 @@ def _chunk_state_bwd_dx_kernel( stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, - stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, stride_ddt_tile, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile, # Meta-parameters + DETERMINISTIC_REDUCTION: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, ): @@ -299,8 +315,8 @@ def _chunk_state_bwd_dx_kernel( b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + pid_n * stride_ddt_tile + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_n * stride_ddA_tile dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) @@ -341,12 +357,18 @@ def _chunk_state_bwd_dx_kernel( x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) ddt = tl.sum(acc * x, axis=1) ddt_ptrs = 
ddt_ptr + offs_m * stride_ddt_csize - tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + if DETERMINISTIC_REDUCTION: + tl.store(ddt_ptrs, ddt, mask=offs_m < chunk_size) + else: + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) ddA_cs = -(ddt * dt_m) ddA_cs_last = -tl.sum(ddA_cs) ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize - tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) - tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last) + if DETERMINISTIC_REDUCTION: + tl.store(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + else: + tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last) dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty) dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head @@ -354,8 +376,13 @@ def _chunk_state_bwd_dx_kernel( tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) +_CHUNK_STATE_BWD_DX_MIN_BLOCK_N = min( + cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_state_bwd_dx_kernel.configs +) + + @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), @@ -364,7 +391,7 @@ def _chunk_state_bwd_dx_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - ], + ]), key=['chunk_size', 'dstate', 'hdim'], ) @triton.jit @@ -383,10 +410,11 @@ def _chunk_state_bwd_db_kernel( stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, stride_db_batch, stride_db_seqlen, stride_db_split, stride_db_group, stride_db_dstate, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile, # Meta-parameters HAS_DDA_CS: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, + DETERMINISTIC_REDUCTION: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1) @@ -405,7 +433,7 @@ def _chunk_state_bwd_db_kernel( dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head if HAS_DDA_CS: b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_n * stride_ddA_tile if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * 
stride_seq_idx_seqlen @@ -446,7 +474,10 @@ def _chunk_state_bwd_db_kernel( if HAS_DDA_CS: # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum ddA_cs = tl.sum(db * b, axis=1) - tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + if DETERMINISTIC_REDUCTION: + tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + else: + tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) acc += db x_ptrs += stride_x_head dstates_ptrs += stride_states_head @@ -466,8 +497,13 @@ def _chunk_state_bwd_db_kernel( tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) +_CHUNK_STATE_BWD_DB_MIN_BLOCK_N = min( + cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_state_bwd_db_kernel.configs +) + + @triton.autotune( - configs=[ + configs=autotune_configs([ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), @@ -485,7 +521,7 @@ def _chunk_state_bwd_db_kernel( triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - ], + ]), key=['chunk_size', 'hdim', 'dstate'], ) @triton.jit @@ -503,9 +539,10 @@ def _chunk_state_bwd_ddAcs_stable_kernel( stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile, # Meta-parameters HAS_SEQ_IDX: tl.constexpr, + DETERMINISTIC_REDUCTION: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, ): @@ -520,7 +557,7 @@ def _chunk_state_bwd_ddAcs_stable_kernel( b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_n * stride_ddA_tile dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen @@ -574,12 +611,19 @@ def _chunk_state_bwd_ddAcs_stable_kernel( # ddA_cs = tl.cumsum(ddt * dt_m) ddA_cs = ddt * dt_m ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize - # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, 
mask=offs_m < chunk_size) - tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + if DETERMINISTIC_REDUCTION: + tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + else: + tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + + +_CHUNK_STATE_BWD_DDACS_MIN_BLOCK_N = min( + cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_state_bwd_ddAcs_stable_kernel.configs +) @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), @@ -589,7 +633,7 @@ def _chunk_state_bwd_ddAcs_stable_kernel( triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), - ], + ]), key=['hdim', 'dstate', 'chunk_size'], ) @triton.jit @@ -703,20 +747,46 @@ def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_l assert ddA.shape == (batch, nheads, nchunks, chunk_size) assert ddt_out.shape == (batch, nheads, nchunks, chunk_size) assert A.shape == (nheads,) + deterministic = use_deterministic_mode() if dt_bias is not None: assert dt_bias.shape == (nheads,) - ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32) + if deterministic: + ddt_bias_workspace = torch.zeros( + batch, nchunks, nheads, device=dt.device, dtype=torch.float32 + ) + ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32) + stride_ddt_bias_batch = ddt_bias_workspace.stride(0) + stride_ddt_bias_chunk = ddt_bias_workspace.stride(1) + else: + ddt_bias_workspace = ddt_bias = torch.empty_like( + dt_bias, dtype=torch.float32 + ) + stride_ddt_bias_batch = 0 + stride_ddt_bias_chunk = 0 else: ddt_bias = None + ddt_bias_workspace = None + stride_ddt_bias_batch = 0 + stride_ddt_bias_chunk = 0 if ddt is not None: assert ddt.shape == dt.shape else: ddt = torch.empty_like(dt) dA = torch.empty_like(A, dtype=torch.float32) + if deterministic: + dA_workspace = torch.zeros( + batch, nchunks, nheads, device=dt.device, dtype=torch.float32 + ) + stride_dA_batch = dA_workspace.stride(0) + stride_dA_chunk = dA_workspace.stride(1) + else: + dA_workspace = dA + stride_dA_batch = 0 + stride_dA_chunk = 0 grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) with torch.cuda.device(dt.device.index): _chunk_cumsum_bwd_kernel[grid_chunk_cs]( - ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias, + ddA, ddt_out, dt, A, dt_bias, ddt, dA_workspace, ddt_bias_workspace if ddt_bias is not None else None, batch, seqlen, nheads, chunk_size, dt_limit[0], dt_limit[1], ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3), @@ -725,12 +795,17 @@ def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_l A.stride(0), dt_bias.stride(0) if dt_bias is not None else 0, ddt.stride(0), ddt.stride(1), ddt.stride(2), - dA.stride(0), - ddt_bias.stride(0) if ddt_bias is not None else 0, + stride_dA_batch, stride_dA_chunk, dA_workspace.stride(-1), + stride_ddt_bias_batch, stride_ddt_bias_chunk, (ddt_bias_workspace.stride(-1) if ddt_bias is not None else 0), 
dt_softplus, HAS_DT_BIAS=dt_bias is not None, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + DETERMINISTIC_REDUCTION=deterministic, ) + if deterministic: + dA.copy_(dA_workspace.sum(dim=(0, 1))) + if ddt_bias is not None: + ddt_bias.copy_(ddt_bias_workspace.sum(dim=(0, 1))) return ddt, dA, ddt_bias @@ -780,8 +855,24 @@ def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None): assert dx.shape == x.shape else: dx = torch.empty_like(x) - ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) - ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32) + deterministic = use_deterministic_mode() + tile_count = math.ceil(headdim / _CHUNK_STATE_BWD_DX_MIN_BLOCK_N) + ddt, stride_ddt_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + dt.device, + deterministic, + zero_init=True, + ) + ddA_cumsum, stride_ddA_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + dA_cumsum.device, + deterministic, + zero_init=True, + ) grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): @@ -795,11 +886,17 @@ def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None): dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), - ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), - ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), stride_ddt_tile, + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), stride_ddA_tile, + DETERMINISTIC_REDUCTION=deterministic, BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), ) - return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype) + ddt = finalize_tile_workspace(ddt, deterministic) + ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic) + if deterministic: + # Match `_chunk_state_bwd_dx_kernel` atomic path (`tl.atomic_add(..., ddA_cs_last)` into last element). 
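+        # (The atomic path adds ddA_cs row-wise plus -sum(ddA_cs) into the last column per program;
+        #  after summing the tile axis here, subtracting the full row-sum from the last element is equivalent.)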
+ ddA_cumsum[..., -1] -= ddA_cumsum.sum(dim=-1) + return dx, ddt, ddA_cumsum def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1): @@ -811,16 +908,31 @@ def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) + deterministic = use_deterministic_mode() if B is not None: assert B.shape == (batch, seqlen, ngroups, dstate) B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3)) # Use torch.empty since the Triton kernel will call init_to_zero - ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) - ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3)) + tile_count = math.ceil(dstate / _CHUNK_STATE_BWD_DB_MIN_BLOCK_N) + ddA_cumsum, stride_ddA_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + x.device, + deterministic, + zero_init=True, + ) + ddA_cumsum_strides = ( + ddA_cumsum.stride(0), + ddA_cumsum.stride(2), + ddA_cumsum.stride(1), + ddA_cumsum.stride(3), + ) else: B_strides = (0, 0, 0, 0) ddA_cumsum = None ddA_cumsum_strides = (0, 0, 0, 0) + stride_ddA_tile = 0 nheads_ngroups_ratio = nheads // ngroups sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) @@ -840,13 +952,15 @@ def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), dB.stride(4), - *ddA_cumsum_strides, + *ddA_cumsum_strides, stride_ddA_tile, HAS_DDA_CS=ddA_cumsum is not None, HAS_SEQ_IDX=seq_idx is not None, + DETERMINISTIC_REDUCTION=deterministic, BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), ) dB = dB.sum(2) if ddA_cumsum is not None: + ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic) # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute # to the state of the chunk. 
# torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:]) @@ -867,7 +981,16 @@ def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None): if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) # Use torch.empty since the Triton kernel will call init_to_zero - ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) + deterministic = use_deterministic_mode() + tile_count = math.ceil(headdim / _CHUNK_STATE_BWD_DDACS_MIN_BLOCK_N) + ddA_cumsum, stride_ddA_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + x.device, + deterministic, + zero_init=True, + ) grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): @@ -881,11 +1004,13 @@ def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None): dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), stride_ddA_tile, HAS_SEQ_IDX=seq_idx is not None, + DETERMINISTIC_REDUCTION=deterministic, BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16), BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), ) + ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic) torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:]) return ddA_cumsum diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index bbf4ecf8..ce9fbc71 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -42,6 +42,12 @@ from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd from mamba_ssm.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd +from mamba_ssm.utils.determinism import ( + alloc_tile_workspace, + autotune_configs, + finalize_tile_workspace, + use_deterministic_mode, +) TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') @@ -58,17 +64,17 @@ def rearrange_and_update_stride(tensor, pattern=None, dim=2): @triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - 
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - ], + configs=autotune_configs([ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), + ]), key=['chunk_size', 'hdim', 'dstate'], ) @triton.jit @@ -91,7 +97,7 @@ def _chunk_scan_chunk_state_bwd_dx_kernel( stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate, stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, - stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, stride_ddt_tile, stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, # Meta-parameters HAS_D: tl.constexpr, @@ -100,6 +106,7 @@ def _chunk_scan_chunk_state_bwd_dx_kernel( BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, IS_TRITON_22: tl.constexpr, + DETERMINISTIC_REDUCTION: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch @@ -112,7 +119,7 @@ def _chunk_scan_chunk_state_bwd_dx_kernel( cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + pid_n * stride_ddt_tile dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head @@ 
-223,10 +230,21 @@ def _chunk_scan_chunk_state_bwd_dx_kernel( tl.store(dD_ptrs, dD, mask=offs_n < hdim) else: dD = tl.sum(dout_res * x) - tl.store(dD_ptr, dD) + if DETERMINISTIC_REDUCTION: + tl.store(dD_ptr + pid_n * stride_dD_hdim, dD) + else: + tl.atomic_add(dD_ptr, dD) ddt = tl.sum(acc * x, axis=1) ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize - tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + if DETERMINISTIC_REDUCTION: + tl.store(ddt_ptrs, ddt, mask=offs_m < chunk_size) + else: + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + + +_CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N = min( + cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_scan_chunk_state_bwd_dx_kernel.configs +) def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None): @@ -242,21 +260,37 @@ def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=Non assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) + deterministic = use_deterministic_mode() if D is not None: assert D.shape == (nheads, headdim) or D.shape == (nheads,) assert D.stride(-1) == 1 BLOCK_SIZE_min = 32 - dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, - headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) + pid_m_tiles = triton.cdiv(chunk_size, BLOCK_SIZE_min) + pid_n_tiles = math.ceil(headdim / _CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N) + if D.dim() == 2: + dD_hdim = headdim + elif deterministic: + dD_hdim = pid_n_tiles + else: + dD_hdim = 1 + dD = torch.zeros(pid_m_tiles, batch, nchunks, nheads, dD_hdim, device=D.device, dtype=torch.float32) + dD_strides = (dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) else: dD = None - dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) - if D is not None else (0, 0, 0, 0, 0)) + dD_strides = (0, 0, 0, 0, 0) if dx is None: dx = torch.empty_like(x) else: assert dx.shape == x.shape - ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + tile_count = math.ceil(headdim / _CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N) + ddt, stride_ddt_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + dout.device, + deterministic, + zero_init=True, + ) grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): @@ -274,21 +308,24 @@ def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=Non B.stride(0), B.stride(1), B.stride(2), B.stride(3), dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), - ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), stride_ddt_tile, dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], D is not None, D.dim() == 2 if D is not None else True, HAS_SEQ_IDX=seq_idx is not None, BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), - IS_TRITON_22=TRITON_22 + IS_TRITON_22=TRITON_22, + DETERMINISTIC_REDUCTION=deterministic, ) if D is not None: BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual - dD = 
dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)) if D.dim() == 1: - dD = rearrange(dD, "h 1 -> h") - return dx, ddt.to(dtype=dt.dtype), dD + dD = dD.sum(dim=-1) + dD = dD.to(dtype=D.dtype) + ddt = finalize_tile_workspace(ddt, deterministic) + return dx, ddt, dD def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): diff --git a/mamba_ssm/ops/triton/ssd_state_passing.py b/mamba_ssm/ops/triton/ssd_state_passing.py index 63863b82..d6aa53c9 100644 --- a/mamba_ssm/ops/triton/ssd_state_passing.py +++ b/mamba_ssm/ops/triton/ssd_state_passing.py @@ -12,16 +12,18 @@ from einops import rearrange, repeat +from mamba_ssm.utils.determinism import autotune_configs + @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE': 64}), triton.Config({'BLOCK_SIZE': 128}), triton.Config({'BLOCK_SIZE': 256}), triton.Config({'BLOCK_SIZE': 512}), triton.Config({'BLOCK_SIZE': 1024}), triton.Config({'BLOCK_SIZE': 2048}), - ], + ]), key=['dim'], ) @triton.jit @@ -86,14 +88,14 @@ def _state_passing_fwd_kernel( @triton.autotune( - configs=[ + configs=autotune_configs([ triton.Config({'BLOCK_SIZE': 64}), triton.Config({'BLOCK_SIZE': 128}), triton.Config({'BLOCK_SIZE': 256}), triton.Config({'BLOCK_SIZE': 512}), triton.Config({'BLOCK_SIZE': 1024}), triton.Config({'BLOCK_SIZE': 2048}), - ], + ]), key=['dim'], ) @triton.jit diff --git a/mamba_ssm/utils/determinism.py b/mamba_ssm/utils/determinism.py new file mode 100644 index 00000000..c6066f80 --- /dev/null +++ b/mamba_ssm/utils/determinism.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +import os +import warnings +from packaging import version + +import torch + +try: + import triton + TRITON_VERSION = version.parse(triton.__version__) +except ImportError: + TRITON_VERSION = version.parse("0.0.0") + +TRITON_HAS_CACHE_RESULTS = TRITON_VERSION >= version.parse("3.4.0") +_autotune_warning_issued = False + +_deterministic_override = None + + +def use_deterministic_mode(): + if _deterministic_override is not None: + return _deterministic_override + env = os.environ.get('MAMBA_DETERMINISTIC') + if env: + return env[0] == '1' + return torch.are_deterministic_algorithms_enabled() + + +def set_deterministic_mode(value): + global _deterministic_override + _deterministic_override = value + + +def _estimate_config_cost(cfg): + """Estimate shared memory cost of a config. Lower is cheaper.""" + block_product = 1 + for key, val in cfg.kwargs.items(): + if key.startswith('BLOCK_SIZE_'): + block_product *= val + return block_product * (getattr(cfg, 'num_stages', 1) or 1) + + +def _filter_configs_by_block_sizes(configs): + """Filter configs by TRITON_AUTOTUNE_BLOCK_SIZE_* env vars.""" + env_filters = {} + for suffix in ('M', 'N', 'K', 'DSTATE'): + env_val = os.environ.get(f"TRITON_AUTOTUNE_BLOCK_SIZE_{suffix}") + if env_val is not None: + env_filters[f'BLOCK_SIZE_{suffix}'] = int(env_val) + if not env_filters: + return None + matching = configs + for key, target in env_filters.items(): + matching = [c for c in matching if c.kwargs.get(key) == target] + return matching[:1] if matching else None + + +def autotune_configs(configs): + """Select autotune configs for deterministic mode. + + Uses cached autotuning (TRITON_CACHE_AUTOTUNING=1) if Triton >= 3.4.0, + otherwise auto-selects the cheapest config by block size * stages. 
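+    If any TRITON_AUTOTUNE_BLOCK_SIZE_{M,N,K,DSTATE} env var is set, the first config
+    matching those block sizes is preferred over the cheapest-config fallback.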
diff --git a/tests/benchmark_determinism_kernels.py b/tests/benchmark_determinism_kernels.py
new file mode 100644
index 00000000..3897018d
--- /dev/null
+++ b/tests/benchmark_determinism_kernels.py
@@ -0,0 +1,146 @@
+#!/usr/bin/env python
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+
+import gc
+import math
+
+import torch
+from triton.testing import do_bench
+
+from mamba_ssm.utils.determinism import set_deterministic_mode
+
+MODEL_PRESETS = {
+    "small": {"nheads": 32, "headdim": 64, "dstate": 64, "ngroups": 1},
+    "nemotronh-56b": {"nheads": 256, "headdim": 64, "dstate": 256, "ngroups": 8},
+}
+
+
+def _reset_peak_memory() -> None:
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats()
+    torch.cuda.synchronize()
+
+
+def _peak_memory_mb(fn, *, warmup: int = 3) -> float:
+    for _ in range(warmup):
+        fn()
+    torch.cuda.synchronize()
+    _reset_peak_memory()
+    fn()
+    torch.cuda.synchronize()
+    return torch.cuda.max_memory_allocated() / (1024 * 1024)
+
+
+def make_tensors(*, batch: int, seqlen: int, nheads: int, headdim: int, dstate: int, ngroups: int, chunk_size: int,
+                 dtype: torch.dtype = torch.bfloat16) -> dict[str, torch.Tensor]:
+    device = "cuda"
+    nchunks = math.ceil(seqlen / chunk_size)
+    return {
+        "x": torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype),
+        "B": torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype),
+        "C": torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype),
+        "dt": torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32),
+        "dA_cumsum": torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32),
+        "dstates": torch.randn(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float32),
+        "dout": torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype),
+        "ddA": torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32),
+        "ddt_out": torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32),
+        "dt_raw": torch.randn(batch, seqlen, nheads, device=device, dtype=dtype),
+        "A": torch.randn(nheads, device=device, dtype=torch.float32) * -1,
+        "dt_bias": torch.randn(nheads, device=device, dtype=torch.float32),
+        "prev_states": torch.randn(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float32),
+        "cb": torch.randn(batch, nchunks, ngroups, chunk_size, chunk_size, device=device, dtype=dtype),
+    }
+
+
+def get_benchmarks(t: dict[str, torch.Tensor], *, ngroups: int):
+    from mamba_ssm.ops.triton.ssd_chunk_state import (
+        _chunk_cumsum_bwd,
+        _chunk_state_bwd_db,
+        _chunk_state_bwd_ddAcs_stable,
+        _chunk_state_bwd_dx,
+    )
+    from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dx
+    from mamba_ssm.ops.triton.ssd_combined import _chunk_scan_chunk_state_bwd_dx
+
+    x = t["x"].contiguous()
+    B = t["B"].contiguous()
+    C = t["C"].contiguous()
+    dout = t["dout"].contiguous()
+    dstates = t["dstates"].contiguous()
+
+    return [
+        ("chunk_cumsum_bwd", lambda: _chunk_cumsum_bwd(t["ddA"], t["ddt_out"], t["dt_raw"], t["A"], dt_bias=t["dt_bias"], dt_softplus=True)),
+        ("chunk_state_bwd_dx", lambda: _chunk_state_bwd_dx(B, x, t["dt"], t["dA_cumsum"], dstates)),
+        ("chunk_state_bwd_db", lambda: _chunk_state_bwd_db(x, t["dt"], t["dA_cumsum"], dstates, B=B, ngroups=ngroups)),
+        ("chunk_state_bwd_ddAcs", lambda: _chunk_state_bwd_ddAcs_stable(B, x, t["dt"], t["dA_cumsum"], dstates)),
+        ("chunk_scan_bwd_dC", lambda: _chunk_scan_bwd_dC(t["prev_states"], t["dA_cumsum"], dout, C=C, ngroups=ngroups)),
+        ("chunk_scan_bwd_dx", lambda: _chunk_scan_bwd_dx(t["cb"], x, t["dt"], t["dA_cumsum"], dout)),
+        ("combined_bwd_dx", lambda: _chunk_scan_chunk_state_bwd_dx(x, t["dt"], t["dA_cumsum"], B, t["cb"], dout, dstates)),
+    ]
+
+
+def _run_one(fn, *, deterministic: bool, warmup: int, rep: int):
+    set_deterministic_mode(deterministic)
+    ms = do_bench(fn, warmup=warmup, rep=rep, return_mode="median")
+    peak_mb = _peak_memory_mb(fn, warmup=1)
+    return ms, peak_mb
+
+
+def main() -> None:
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Benchmark determinism overhead for key Triton backward kernels")
+    parser.add_argument("--preset", choices=sorted(MODEL_PRESETS.keys()), default="small")
+    parser.add_argument("--warmup", type=int, default=25)
+    parser.add_argument("--rep", type=int, default=100)
+    parser.add_argument("--batch", type=int, default=4)
+    parser.add_argument("--seqlen", type=int, default=2048)
+    parser.add_argument("--chunk-size", type=int, default=256)
+    args = parser.parse_args()
+
+    if not torch.cuda.is_available():
+        raise SystemExit("CUDA not available")
+
+    p = MODEL_PRESETS[args.preset]
+    tensors = make_tensors(
+        batch=args.batch,
+        seqlen=args.seqlen,
+        nheads=p["nheads"],
+        headdim=p["headdim"],
+        dstate=p["dstate"],
+        ngroups=p["ngroups"],
+        chunk_size=args.chunk_size,
+    )
+    benches = get_benchmarks(tensors, ngroups=p["ngroups"])
+
+    print(f"GPU: {torch.cuda.get_device_name(0)}")
+    print(f"preset={args.preset} batch={args.batch} seqlen={args.seqlen} chunk_size={args.chunk_size}")
+    print(f"{'kernel':<20} {'ms':>9} {'det_ms':>9} {'ms_%':>6} {'MB':>9} {'det_MB':>9} {'MB_%':>6}")
+
+    rows = []
+    try:
+        for name, fn in benches:
+            ms, mb = _run_one(fn, deterministic=False, warmup=args.warmup, rep=args.rep)
+            det_ms, det_mb = _run_one(fn, deterministic=True, warmup=args.warmup, rep=args.rep)
+            ms_pct = (det_ms / ms - 1.0) * 100.0
+            mb_pct = (det_mb / mb - 1.0) * 100.0 if mb else 0.0
+            rows.append((name, ms, det_ms, ms_pct, mb, det_mb, mb_pct))
+            print(f"{name:<20} {ms:>9.3f} {det_ms:>9.3f} {ms_pct:>+6.0f}% {mb:>9.1f} {det_mb:>9.1f} {mb_pct:>+6.0f}%")
+    finally:
+        set_deterministic_mode(None)
+
+    total_ms = sum(r[1] for r in rows)
+    total_det_ms = sum(r[2] for r in rows)
+    max_mb = max(r[4] for r in rows) if rows else 0.0
+    max_det_mb = max(r[5] for r in rows) if rows else 0.0
+    total_pct = (total_det_ms / total_ms - 1.0) * 100.0 if total_ms else 0.0
+    max_mb_pct = (max_det_mb / max_mb - 1.0) * 100.0 if max_mb else 0.0
+    print(f"{'TOTAL/MAX':<20} {total_ms:>9.3f} {total_det_ms:>9.3f} {total_pct:>+6.0f}% {max_mb:>9.1f} {max_det_mb:>9.1f} {max_mb_pct:>+6.0f}%")
+
+
+if __name__ == "__main__":
+    main()
+
+
diff --git a/tests/test_determinism.py b/tests/test_determinism.py
new file mode 100644
index 00000000..b516fb8a
--- /dev/null
+++ b/tests/test_determinism.py
@@ -0,0 +1,342 @@
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+
+import os
+
+import pytest
+import torch
+
+
+def _set_deterministic(enabled: bool) -> None:
+    torch.use_deterministic_algorithms(enabled)
+    if enabled:
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+
+
+def _set_seeds(seed: int) -> None:
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
+    return (a.float() - b.float()).abs().max().item()
+
+
+def _make_inputs(
+    *,
+    seed: int,
+    headdim: int,
+    dstate: int,
+    chunk_size: int = 256,
+    ngroups: int = 1,
+    dtype: torch.dtype = torch.bfloat16,
+    d_has_hdim: bool = False,
+) -> dict[str, torch.Tensor]:
+    import math
+
+    _set_seeds(seed)
+    device = "cuda"
+
+    batch = 2
+    seqlen = 2048
+    nheads = 8
+    nchunks = math.ceil(seqlen / chunk_size)
+
+    x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
+    dout = torch.randn_like(x)
+    dt = torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32)
+    dA_cumsum = torch.randn_like(dt)
+    cb = torch.randn(batch, nchunks, ngroups, chunk_size, chunk_size, device=device, dtype=dtype)
+
+    B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype).contiguous()
+    C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype).contiguous()
+    dstates = torch.randn(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float32)
+    prev_states = torch.randn_like(dstates)
+
+    ddA = torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32)
+    ddt_out = torch.randn_like(ddA)
+    dt_raw = torch.randn(batch, seqlen, nheads, device=device, dtype=dtype)
+    A = (torch.randn(nheads, device=device, dtype=torch.float32) * -1.0).contiguous()
+    dt_bias = torch.randn(nheads, device=device, dtype=torch.float32).contiguous()
+    # D shape: (nheads, headdim) when d_has_hdim=True, else (nheads,)
+    if d_has_hdim:
+        D = torch.randn(nheads, headdim, device=device, dtype=torch.float32)
+    else:
+        D = torch.randn(nheads, device=device, dtype=torch.float32)
+
+    return {
+        "x": x,
+        "dout": dout,
+        "dt": dt,
+        "dA_cumsum": dA_cumsum,
+        "cb": cb,
+        "B": B,
+        "C": C,
+        "dstates": dstates,
+        "prev_states": prev_states,
+        "ddA": ddA,
+        "ddt_out": ddt_out,
+        "dt_raw": dt_raw,
+        "A": A,
+        "dt_bias": dt_bias,
+        "D": D,
+    }
+
+
+def _run_case_outputs(
+    *,
+    case: str,
+    deterministic: bool,
+    seed: int,
+    headdim: int = 64,
+    dstate: int = 64,
+    chunk_size: int = 256,
+    ngroups: int = 1,
+    dtype: torch.dtype = torch.bfloat16,
+    d_has_hdim: bool = False,
+) -> dict[str, torch.Tensor]:
+    _set_deterministic(deterministic)
+    t = _make_inputs(
+        seed=seed,
+        headdim=headdim,
+        dstate=dstate,
+        chunk_size=chunk_size,
+        ngroups=ngroups,
+        dtype=dtype,
+        d_has_hdim=d_has_hdim,
+    )
+
+    if case == "chunk_scan_bwd_dx":
+        from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dx
+        dx, ddt = _chunk_scan_bwd_dx(t["cb"], t["x"], t["dt"], t["dA_cumsum"], t["dout"])
+        out = {"dx": dx, "ddt": ddt}
+    elif case == "chunk_scan_bwd_dC":
+        from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC
+        dC, ddA_prev = _chunk_scan_bwd_dC(t["prev_states"], t["dA_cumsum"], t["dout"], C=t["C"], ngroups=1)
+        out = {"dC": dC, "ddA_cumsum_prev": ddA_prev}
+    elif case == "chunk_state_bwd_dx":
+        from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_dx
+        dx, ddt, ddA = _chunk_state_bwd_dx(t["B"], t["x"], t["dt"], t["dA_cumsum"], t["dstates"])
+        out = {"dx": dx, "ddt": ddt, "ddA_cumsum": ddA}
+    elif case == "chunk_state_bwd_db":
+        from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_db
+        dB, ddA = _chunk_state_bwd_db(t["x"], t["dt"], t["dA_cumsum"], t["dstates"], B=t["B"], ngroups=1)
+        out = {"dB": dB, "ddA_cumsum": ddA}
+    elif case == "chunk_state_bwd_ddAcs_stable":
+        from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
+        ddA = _chunk_state_bwd_ddAcs_stable(t["B"], t["x"], t["dt"], t["dA_cumsum"], t["dstates"])
+        out = {"ddA_cumsum": ddA}
+    elif case == "chunk_cumsum_bwd":
+        from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_bwd
+        ddt, dA, ddt_bias = _chunk_cumsum_bwd(t["ddA"], t["ddt_out"], t["dt_raw"], t["A"], dt_bias=t["dt_bias"], dt_softplus=True)
+        out = {"ddt": ddt, "dA": dA, "ddt_bias": ddt_bias}
+    elif case.startswith("combined_bwd_dx"):
+        from mamba_ssm.ops.triton.ssd_combined import _chunk_scan_chunk_state_bwd_dx
+        dx, ddt, dD = _chunk_scan_chunk_state_bwd_dx(t["x"], t["dt"], t["dA_cumsum"], t["B"], t["cb"], t["dout"], t["dstates"], D=t["D"])
+        out = {"dx": dx, "ddt": ddt, "dD": dD}
+    else:
+        raise AssertionError(f"Unknown case: {case}")
+
+    torch.cuda.synchronize()
+    return {k: v.detach().clone().float() for k, v in out.items() if v is not None}
+
+
+_KERNEL_CASES = [
+    "chunk_scan_bwd_dx",
+    "chunk_scan_bwd_dC",
+    "chunk_state_bwd_dx",
+    "chunk_state_bwd_db",
+    "chunk_state_bwd_ddAcs_stable",
+    "chunk_cumsum_bwd",
+]
+
+_COMBINED_CASES = [
+    ("combined_bwd_dx", False),
+    ("combined_bwd_dx_d_hdim", True),
+]
+
+_HEADDIMS = [64, 128]
+_DSTATES = [64]
+
+
+def _kernel_is_reproducible(case: str, headdim: int, dstate: int, d_has_hdim: bool = False):
+    runs = 5
+    outs = [
+        _run_case_outputs(case=case, deterministic=True, seed=123, headdim=headdim, dstate=dstate, d_has_hdim=d_has_hdim)
+        for _ in range(runs)
+    ]
+    ref = outs[0]
+    for i in range(1, runs):
+        for k in ref:
+            assert _max_abs_diff(ref[k], outs[i][k]) == 0.0, f"{case} output {k} differs (headdim={headdim}, dstate={dstate})"
+
+
+def _kernel_close_to_default(case: str, headdim: int, dstate: int, d_has_hdim: bool = False):
+    atol = rtol = 1e-2
+    det = _run_case_outputs(case=case, deterministic=True, seed=123, headdim=headdim, dstate=dstate, d_has_hdim=d_has_hdim)
+    for _ in range(3):
+        default = _run_case_outputs(case=case, deterministic=False, seed=123, headdim=headdim, dstate=dstate, d_has_hdim=d_has_hdim)
+        for k in det:
+            assert torch.allclose(default[k], det[k], atol=atol, rtol=rtol), f"{case} output {k} not close (headdim={headdim}, dstate={dstate})"
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+@pytest.mark.parametrize("dstate", _DSTATES)
+@pytest.mark.parametrize("headdim", _HEADDIMS)
+@pytest.mark.parametrize("case", _KERNEL_CASES)
+def test_kernel_reproducible(case: str, headdim: int, dstate: int):
+    _kernel_is_reproducible(case, headdim, dstate)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+@pytest.mark.parametrize("dstate", _DSTATES)
+@pytest.mark.parametrize("headdim", _HEADDIMS)
+@pytest.mark.parametrize("case,d_has_hdim", _COMBINED_CASES)
+def test_combined_kernel_reproducible(case: str, d_has_hdim: bool, headdim: int, dstate: int):
+    _kernel_is_reproducible(case, headdim, dstate, d_has_hdim)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+@pytest.mark.parametrize("dstate", _DSTATES)
+@pytest.mark.parametrize("headdim", _HEADDIMS)
+@pytest.mark.parametrize("case", _KERNEL_CASES)
+def test_kernel_close_to_default(case: str, headdim: int, dstate: int):
+    _kernel_close_to_default(case, headdim, dstate)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+@pytest.mark.parametrize("dstate", _DSTATES)
+@pytest.mark.parametrize("headdim", _HEADDIMS)
+@pytest.mark.parametrize("case,d_has_hdim", _COMBINED_CASES)
+def test_combined_kernel_close_to_default(case: str, d_has_hdim: bool, headdim: int, dstate: int):
+    _kernel_close_to_default(case, headdim, dstate, d_has_hdim)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+def test_default_mode_is_not_reproducible():
+    from mamba_ssm.modules.mamba2 import Mamba2
+
+    device = "cuda"
+    dtype = torch.bfloat16
+    seed = 123
+    runs = 20
+    batch = 4
+    seqlen = 4096
+
+    _set_seeds(seed)
+    model = Mamba2(
+        d_model=256, d_state=64, headdim=64, expand=2, d_conv=4, chunk_size=256,
+        use_mem_eff_path=True, device=device, dtype=dtype,
+    ).train()
+    x_data = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype)
+
+    def _run() -> dict[str, torch.Tensor]:
+        _set_deterministic(False)
+        model.zero_grad(set_to_none=True)
+        x = x_data.clone().requires_grad_(True)
+        y = model(x)
+        (y.float().square().mean()).backward()
+        torch.cuda.synchronize()
+        grads = {"input": x.grad.detach().float().clone()}
+        for name, p in model.named_parameters():
+            if p.grad is not None:
+                grads[name] = p.grad.detach().float().clone()
+        return grads
+
+    _run()  # warmup
+    ref = _run()
+    observed_diff = False
+    for _ in range(runs - 1):
+        g = _run()
+        for k in ref:
+            if _max_abs_diff(ref[k], g[k]) != 0.0:
+                observed_diff = True
+                break
+        if observed_diff:
+            break
+
+    if not observed_diff:
+        pytest.xfail(
+            f"Did not observe nondeterminism in default mode after {runs} runs. "
+            "This GPU may have deterministic atomic behavior at these shapes."
+        )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+def test_mamba2_fwd_bwd_deterministic_reproducible():
+    from mamba_ssm.modules.mamba2 import Mamba2
+
+    device = "cuda"
+    dtype = torch.bfloat16
+    seed = 123
+    runs = 5
+    batch = 2
+    seqlen = 2048
+    headdim = 64
+
+    _set_seeds(seed)
+    _set_deterministic(True)
+
+    model = Mamba2(
+        d_model=headdim, d_state=16, headdim=headdim, expand=2, d_conv=4, chunk_size=16,
+        use_mem_eff_path=True, device=device, dtype=dtype,
+    ).train()
+    x_data = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype)
+
+    def _run() -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+        model.zero_grad(set_to_none=True)
+        x = x_data.clone().requires_grad_(True)
+        y = model(x)
+        (y.float().square().mean()).backward()
+        torch.cuda.synchronize()
+        grads: dict[str, torch.Tensor] = {"input": x.grad.detach().float().clone()}
+        for name, p in model.named_parameters():
+            if p.grad is not None:
+                grads[name] = p.grad.detach().float().clone()
+        return y.detach().float().clone(), grads
+
+    _run()  # warmup
+    y0, g0 = _run()
+    for _ in range(runs - 1):
+        y, g = _run()
+        assert _max_abs_diff(y0, y) == 0.0
+        assert g.keys() == g0.keys()
+        for k in g0:
+            assert _max_abs_diff(g0[k], g[k]) == 0.0, f"Mamba2 grad {k} differs"
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+def test_mamba2_fwd_bwd_deterministic_close_to_default():
+    from mamba_ssm.modules.mamba2 import Mamba2
+
+    device = "cuda"
+    dtype = torch.bfloat16
+    seed = 123
+    batch = 2
+    seqlen = 2048
+    headdim = 64
+    atol = rtol = 1e-2
+
+    def _run(deterministic: bool) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+        torch.use_deterministic_algorithms(deterministic, warn_only=True)
+        _set_seeds(seed)
+        model = Mamba2(
+            d_model=headdim * 4, d_state=32, headdim=headdim, expand=2, d_conv=4, chunk_size=64,
+            use_mem_eff_path=True, device=device, dtype=dtype,
+        ).train()
+        x = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype).requires_grad_(True)
+        y = model(x)
+        (y.float().square().mean()).backward()
+        torch.cuda.synchronize()
+        grads: dict[str, torch.Tensor] = {"input": x.grad.detach().float().clone()}
+        for name, p in model.named_parameters():
+            if p.grad is not None:
+                grads[name] = p.grad.detach().float().clone()
+        return y.detach().float().clone(), grads
+
+    _run(False)  # warmup
+    y_default, g_default = _run(False)
+    y_det, g_det = _run(True)
+
+    assert torch.allclose(y_default, y_det, atol=atol, rtol=rtol), "Mamba2 output differs"
+    for k in g_default:
+        assert torch.allclose(g_default[k], g_det[k], atol=atol, rtol=rtol), f"Mamba2 grad {k} not close"
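The new coverage can be exercised directly on a CUDA machine with this patch installed: "pytest tests/test_determinism.py" runs the per-kernel reproducibility and closeness checks plus the end-to-end Mamba2 tests above, and "python tests/benchmark_determinism_kernels.py --preset nemotronh-56b" reports the per-kernel time and peak-memory overhead of deterministic mode using the flags defined in the benchmark script.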