Skip to content
10 changes: 6 additions & 4 deletions mamba_ssm/ops/triton/k_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@
import triton
import triton.language as tl

from mamba_ssm.utils.determinism import autotune_configs


@triton.autotune(
configs=[
configs=autotune_configs([
triton.Config({'BLOCK_N': 32}),
triton.Config({'BLOCK_N': 64}),
triton.Config({'BLOCK_N': 128}),
triton.Config({'BLOCK_N': 256}),
triton.Config({'BLOCK_N': 512}),
triton.Config({'BLOCK_N': 1024}),
],
]),
key=['ncols'],
)
@triton.jit
Expand Down Expand Up @@ -61,14 +63,14 @@ def _swiglu_fwd(xy, out=None):


@triton.autotune(
configs=[
configs=autotune_configs([
triton.Config({'BLOCK_N': 32}),
triton.Config({'BLOCK_N': 64}),
triton.Config({'BLOCK_N': 128}),
triton.Config({'BLOCK_N': 256}),
triton.Config({'BLOCK_N': 512}),
triton.Config({'BLOCK_N': 1024}),
],
]),
key=['ncols'],
)
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
Expand Down
6 changes: 4 additions & 2 deletions mamba_ssm/ops/triton/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import triton
import triton.language as tl

from mamba_ssm.utils.determinism import autotune_configs


def layer_norm_ref(
x,
Expand Down Expand Up @@ -167,7 +169,7 @@ def config_prune(configs):
pruned_configs_autotune = config_prune(configs_autotune)

@triton.autotune(
configs = pruned_configs_autotune,
configs=autotune_configs(pruned_configs_autotune),
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
Expand Down Expand Up @@ -419,7 +421,7 @@ def _layer_norm_fwd(


@triton.autotune(
configs=pruned_configs_autotune,
configs=autotune_configs(pruned_configs_autotune),
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
Expand Down
10 changes: 6 additions & 4 deletions mamba_ssm/ops/triton/ssd_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@

from einops import rearrange, repeat

from mamba_ssm.utils.determinism import autotune_configs


def init_to_zero(names):
    """Build a callable that zeroes selected entries of an argument mapping.

    Args:
        names: Iterable of keys to look up in the mapping handed to the
            returned callable.

    Returns:
        A one-argument function. Given a mapping ``nargs``, it calls
        ``.zero_()`` on ``nargs[name]`` for every ``name`` in ``names`` whose
        entry is not ``None`` and returns the list of those call results, in
        the order of ``names``.
    """
    # NOTE(review): presumably meant to be used as a Triton config pre-hook;
    # no call site is visible in this view -- confirm against callers.
    def _zero_named(nargs):
        zeroed = []
        for name in names:
            entry = nargs[name]
            # Entries explicitly set to None are skipped, not zeroed.
            if entry is not None:
                zeroed.append(entry.zero_())
        return zeroed

    return _zero_named


@triton.autotune(
configs=[
configs=autotune_configs([
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
Expand All @@ -28,7 +30,7 @@ def init_to_zero(names):
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
],
]),
key=['chunk_size', 'K', 'IS_CAUSAL'],
)
@triton.jit
Expand Down Expand Up @@ -92,7 +94,7 @@ def _bmm_chunk_fwd_kernel(


@triton.autotune(
configs=[
configs=autotune_configs([
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
Expand All @@ -102,7 +104,7 @@ def _bmm_chunk_fwd_kernel(
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),
],
]),
key=['chunk_size', 'K'],
)
@triton.jit
Expand Down
Loading