From bb1d32b555d49df6059e5ec10a2111be35a78454 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Tue, 19 Aug 2025 12:16:49 +0300 Subject: [PATCH 01/11] Add failing mamba2 prefill chunking unittest Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- tests/kernels/mamba/test_mamba_ssm_ssd.py | 216 +++++++++++++++++++++- 1 file changed, 209 insertions(+), 7 deletions(-) diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index 2c554baaff76..e2e87b5fd85e 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -115,7 +115,8 @@ def generate_continuous_batched_examples(example_lens_by_batch, n_heads, d_head, itype, - device='cuda'): + device='cuda', + return_ref=True): # this function generates a random examples of certain length # and then cut according to "example_lens_by_batch" and feed @@ -125,11 +126,13 @@ def generate_continuous_batched_examples(example_lens_by_batch, A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads, d_head, itype) - Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), - A * dt, - B, - C, - block_len=full_length // 4) + if return_ref: + Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), + A * dt, + B, + C, + block_len=full_length // + 4) # internal function that outputs a cont batch of examples # given a tuple of lengths for each example in the batch @@ -179,7 +182,8 @@ def end_boundary(n: int): IND_S = [x % full_length for x in IND_E] IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] - yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)], + yield ([Y_min[s, IND_S[s]:IND_E[s]] + for s in range(num_examples)] if return_ref else None, cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) @@ -324,3 +328,201 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, if clear: states[i].fill_(0.) 
exhausted[i] = False + + +@pytest.mark.parametrize("chunk_size", [8, 128]) +@pytest.mark.parametrize("max_seqlen", [16, 270]) +def test_mamba_chunk_scan_cont_batch_prefill_chunking(max_seqlen, chunk_size): + # This test can have larger error for longer sequences + if max_seqlen > 256: + atol, rtol = 1e-2, 5e-3 + else: + atol, rtol = 5e-3, 5e-3 + + batch_size = 4 + n_heads = 16 + d_head = 64 + itype = torch.float32 + + last_taken = {} + exhausted = {} + device = "cuda" + current_platform.seed_everything(0) + per_example_seqlens = torch.randint(1, + max_seqlen + 1, (batch_size, ), + dtype=torch.int32, + device=device) + per_example_seqlens[0] = max_seqlen + _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next( + generate_continuous_batched_examples( + [tuple(per_example_seqlens.tolist())], + batch_size, + max_seqlen, + last_taken, + exhausted, + n_heads, + d_head, + itype, + device, + return_ref=False)) + + ## full seqlen computation + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + cu_seqlens, chunk_size, cu_seqlens[-1]) + Y_ref = torch.empty_like(X) + state_ref = mamba_chunk_scan_combined( + X, + dt, + A, + B, + C, + chunk_size, + D=None, + cu_seqlens=cu_seqlens, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_varlen_states=True, + initial_states=None, + out=Y_ref, + ) + + ## chunked seqlen computation + # first chunk + chunked_seqlens = per_example_seqlens // 2 + chunked_cu_seqlens = torch.cat([ + torch.tensor([0], device=device), + torch.cumsum(chunked_seqlens, dim=0) + ], + dim=0) + chunked_seq_idx = torch.repeat_interleave( + torch.arange(len(chunked_seqlens), device=device), + chunked_seqlens, + output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32) + chunked_input_seq_len = chunked_cu_seqlens[-1] + X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...] + dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...] + B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...] + C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...] + for i in range(batch_size): + # fmt: off + chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501 + + X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501 + dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501 + B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501 + C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] 
= chunk_f(C, i) # noqa: E501 + # fmt: on + + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1]) + Y_partial = torch.empty_like(X_chunked) + partial_state = mamba_chunk_scan_combined( + X_chunked, + dt_chunked, + A, + B_chunked, + C_chunked, + chunk_size, + D=None, + cu_seqlens=chunked_cu_seqlens, + seq_idx=chunked_seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_varlen_states=True, + initial_states=None, + out=Y_partial, + ) + + # remaining chunk + remaining_chunked_seqlens = per_example_seqlens - chunked_seqlens + remaining_chunked_cu_seqlens = torch.cat([ + torch.tensor([0], device=device), + torch.cumsum(remaining_chunked_seqlens, dim=0) + ], + dim=0) + remaining_chunked_seq_idx = torch.repeat_interleave( + torch.arange(len(remaining_chunked_seqlens), device=device), + remaining_chunked_seqlens, + output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to( + torch.int32) + remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1] + # fmt: off + remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + for i in range(batch_size): + remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501 + + remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501 + remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501 + remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501 + remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] 
= remaining_chunk_f(C, i) # noqa: E501
+
+    # assert input chunking is correct
+    concat_chunk_f = lambda pt1, pt2, i: torch.cat([
+        pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
+        pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
+    ],
+    dim=1)
+    concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(batch_size)], dim=1) # noqa: E501
+    # fmt: on
+
+    assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
+    assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt)
+    assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
+    assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)
+
+    chunk_indices, chunk_offsets = \
+        _query_start_loc_to_chunk_indices_offsets(
+            remaining_chunked_cu_seqlens,
+            chunk_size,
+            remaining_chunked_cu_seqlens[-1])
+
+    Y_chunked = torch.empty_like(remaining_X_chunked)
+    state_chunked = mamba_chunk_scan_combined(
+        remaining_X_chunked,
+        remaining_dt_chunked,
+        A,
+        remaining_B_chunked,
+        remaining_C_chunked,
+        chunk_size,
+        D=None,
+        cu_seqlens=remaining_chunked_cu_seqlens,
+        seq_idx=remaining_chunked_seq_idx,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
+        return_varlen_states=True,
+        initial_states=partial_state,
+        out=Y_chunked,
+    )
+    Y = concat_batch_f(Y_partial, Y_chunked)
+
+    # kernel chunked is same as kernel overall
+    for i in range(batch_size):
+        Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
+        Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
+        torch.testing.assert_close(
+            Y_seq[:, :chunked_seqlens[i], ...],
+            Y_ref_seq[:, :chunked_seqlens[i], ...],
+            atol=atol,
+            rtol=rtol,
+            msg=lambda x: f"seq{i} output part1 " + x) # noqa: B023
+        torch.testing.assert_close(
+            Y_seq[:, chunked_seqlens[i]:, ...],
+            Y_ref_seq[:, chunked_seqlens[i]:, ...],
+            atol=atol,
+            rtol=rtol,
+            msg=lambda x: f"seq{i} output part2 " + x) # noqa: B023
+
+        state_seq = state_chunked[i]
+        state_seq_ref = state_ref[i]
+        torch.testing.assert_close(
+            state_seq,
+            state_seq_ref,
+            atol=atol,
+            rtol=rtol,
+            msg=lambda x: f"seq{i} state " + x) # noqa: B023

From 5ce3ce7a13cd55d9adca309fa47aec547863593a Mon Sep 17 00:00:00 2001
From: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Date: Tue, 19 Aug 2025 14:56:59 +0300
Subject: [PATCH 02/11] Fix chunked prefill + varlen batching bugs in mamba2 triton kernels

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
---
 .../layers/mamba/ops/ssd_chunk_scan.py    |  3 +
 .../layers/mamba/ops/ssd_combined.py      |  9 ++-
 .../layers/mamba/ops/ssd_state_passing.py | 71 +++++++++++++++----
 3 files changed, 66 insertions(+), 17 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
index 365139e237c6..fb8350e191c9 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
@@ -289,6 +289,9 @@ def _chunk_scan_fwd_kernel(
 
         # get the cs at the offset boundary
         # - c_off == 0 is a passthrough
+        # - We need dA_cs at the boundary, defined by c_off - no need
+        #   to increase pointer by pid_m (it is a constant offset,
+        #   i.e.
the same for all blocks) dA_cs_m_boundary = tl.load( dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)), diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index d0b3e9e5235b..fcc5c905bf77 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -106,21 +106,24 @@ def _mamba_chunk_scan_combined_fwd(x, # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) # - for handling chunked prefill, this requires i) initial_states - # ii) seq_idx and iii) is_cont_batched to be all specified. + # ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified. # - When a new seq_idx is detected, we will stop passing the prev_state # and switch accordingly to the init_state corresponding to the new seq_idx. + # - We will also make sure that the dA_cumsum is taken only from the start of the + # sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries) # - this will ensure that states will be updated with the rightmost flushed seq_idx # of the previous chunk. This implies that the first chunk of states is either 0 # or equal to init_states of the first example. states, final_states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), - dA_cumsum[:, :, :, -1], + dA_cumsum, initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=state_dtype if state_dtype is not None else C.dtype, - is_cont_batched=cu_seqlens is not None) + is_cont_batched=cu_seqlens is not None, + chunk_offsets=chunk_offsets) states, final_states = (rearrange(t, "... (p n) -> ... 
p n", n=dstate) for t in [states, final_states]) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index a28fc9ffad71..5a238f97d568 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -31,6 +31,8 @@ def _state_passing_fwd_kernel( dA_cs_ptr, initstates_ptr, seq_idx_ptr, + chunk_offsets_ptr, + chunk_meta_num, # Matrix dimensions dim, nchunks, @@ -51,6 +53,7 @@ def _state_passing_fwd_kernel( stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, + stride_dA_cs_csize, stride_initstates_batch, stride_initstates_head, stride_initstates_dim, @@ -66,7 +69,8 @@ def _state_passing_fwd_kernel( pid_h = tl.program_id(axis=2) pid_m = tl.program_id(axis=0) states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head - dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + ( + chunk_size - 1) * stride_dA_cs_csize out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head if HAS_INITSTATES: @@ -96,15 +100,20 @@ def _state_passing_fwd_kernel( tl.store(out_ptrs, states, mask=offs_m < dim) out_ptrs += stride_out_chunk seq_idx = 0 + logical_chunk_idx = 0 for c in range(nchunks): new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale = tl.exp(dA_cs) + scale_mask = True if HAS_SEQ_IDX: + # sequence index at the start of the current chunk + seq_idx = tl.load(seq_idx_ptr + min(c * chunk_size, seqlen) * + stride_seq_idx_seqlen) + # - the seq to pass forward is the one that is flushed to the right # boundary. - # - that is given by seq_idx_new below. + # - that is given by seq_idx_new below: the sequence index at the end of the chunk. 
seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) @@ -120,10 +129,33 @@ def _state_passing_fwd_kernel( states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + # - we need to consider the cumsum only of the last sequence in the chunk + # - find its starting position (given by c_off of the logical chunk index) + # - and subtract the cumsum just before that position from the total cumsum + # - first, update the logical chunk index (add the number of sequences in the current physical chunk): + logical_chunk_idx += (seq_idx_new - seq_idx) + # - load the chunk offset: + c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx, + mask=logical_chunk_idx > -1 and + (logical_chunk_idx + 1) < chunk_meta_num, + other=0) + # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything + if c_off > 0: + # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset + dA_cs_boundary = tl.load( + dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize + + (c_off - 1) * stride_dA_cs_csize, + mask=(c_off - 1) > -1 and c_off < chunk_size, + other=0.0) + dA_cs -= dA_cs_boundary + + # - increment logical chunk index for every physical chunk + logical_chunk_idx += 1 else: - scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + scale_mask = seq_idx_new == seq_idx - seq_idx = seq_idx_new + scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) states = scale * states + new_states if c < nchunks - 1: tl.store(out_ptrs, states, mask=offs_m < dim) @@ -136,28 +168,36 @@ def _state_passing_fwd_kernel( def _state_passing_fwd( states, - dA_chunk_cumsum, + dA_cumsum, initial_states=None, seq_idx=None, chunk_size=None, out_dtype=None, is_cont_batched=False, + chunk_offsets=None, ): batch, nchunks, nheads, dim = states.shape - assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if chunk_size is None: + chunk_size = dA_cumsum.shape[-1] + else: + assert chunk_size == dA_cumsum.shape[-1] + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) if initial_states is not None: if is_cont_batched: # - if cu_seqlens is provided, then the initial states # are used for continuous batching. In which case we # require seq_idx to be provided - assert seq_idx is not None, "" + assert seq_idx is not None, "seq_idx must be provided for continuous batching" + # - we also need chunk_offsets to be provided, to account + # for computation of dA_cumsum from the start of the + # sequence + assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching" else: # - this is the regular batching case, where initial # states are used are for each example of the batch. 
assert initial_states.shape == (batch, nheads, dim) if seq_idx is not None: - assert chunk_size is not None seqlen = seq_idx.shape[-1] assert seq_idx.shape == (batch, seqlen) out_dtype = states.dtype if out_dtype is None else out_dtype @@ -173,13 +213,15 @@ def _state_passing_fwd( states, out, final_states, - dA_chunk_cumsum, + dA_cumsum, initial_states, seq_idx, + chunk_offsets, + len(chunk_offsets) if chunk_offsets is not None else 0, dim, nchunks, seqlen if seq_idx is not None else 0, - chunk_size if seq_idx is not None else 0, + chunk_size, states.stride(0), states.stride(1), states.stride(2), @@ -191,9 +233,10 @@ def _state_passing_fwd( final_states.stride(0), final_states.stride(1), final_states.stride(2), - dA_chunk_cumsum.stride(0), - dA_chunk_cumsum.stride(2), - dA_chunk_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2)) if initial_states is not None else (0, 0, 0)), From 7cee118967c41c01bd752acb7eb5c4d6ee7090bf Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Wed, 20 Aug 2025 20:48:52 +0300 Subject: [PATCH 03/11] refactor test for readability Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- tests/kernels/mamba/test_mamba_ssm_ssd.py | 53 +++++++++++------------ 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index e2e87b5fd85e..642418146c8e 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -330,41 +330,38 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, exhausted[i] = False -@pytest.mark.parametrize("chunk_size", [8, 128]) -@pytest.mark.parametrize("max_seqlen", [16, 270]) -def test_mamba_chunk_scan_cont_batch_prefill_chunking(max_seqlen, chunk_size): +@pytest.mark.parametrize("chunk_size", [8, 256]) +@pytest.mark.parametrize("seqlens", [ + (16, 2, 8, 13), + (270, 88, 212, 203), +]) +def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): + max_seqlen = max(seqlens) # This test can have larger error for longer sequences if max_seqlen > 256: atol, rtol = 1e-2, 5e-3 else: atol, rtol = 5e-3, 5e-3 - batch_size = 4 + num_sequences = len(seqlens) n_heads = 16 d_head = 64 itype = torch.float32 last_taken = {} exhausted = {} - device = "cuda" - current_platform.seed_everything(0) - per_example_seqlens = torch.randint(1, - max_seqlen + 1, (batch_size, ), - dtype=torch.int32, - device=device) - per_example_seqlens[0] = max_seqlen _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next( - generate_continuous_batched_examples( - [tuple(per_example_seqlens.tolist())], - batch_size, - max_seqlen, - last_taken, - exhausted, - n_heads, - d_head, - itype, - device, - return_ref=False)) + generate_continuous_batched_examples([seqlens], + num_sequences, + max_seqlen, + last_taken, + exhausted, + n_heads, + d_head, + itype, + return_ref=False)) + seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device) + device = X.device ## full seqlen computation chunk_indices, chunk_offsets = \ @@ -390,7 +387,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(max_seqlen, chunk_size): ## chunked seqlen computation # first chunk - chunked_seqlens = per_example_seqlens // 2 + chunked_seqlens = seqlens // 2 chunked_cu_seqlens = torch.cat([ torch.tensor([0], device=device), 
torch.cumsum(chunked_seqlens, dim=0) @@ -405,7 +402,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(max_seqlen, chunk_size): dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...] B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...] C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...] - for i in range(batch_size): + for i in range(num_sequences): # fmt: off chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501 @@ -437,7 +434,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(max_seqlen, chunk_size): ) # remaining chunk - remaining_chunked_seqlens = per_example_seqlens - chunked_seqlens + remaining_chunked_seqlens = seqlens - chunked_seqlens remaining_chunked_cu_seqlens = torch.cat([ torch.tensor([0], device=device), torch.cumsum(remaining_chunked_seqlens, dim=0) @@ -454,7 +451,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(max_seqlen, chunk_size): remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 - for i in range(batch_size): + for i in range(num_sequences): remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501 remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501 @@ -468,7 +465,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(max_seqlen, chunk_size): pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...], ], dim=1) - concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(batch_size)], dim=1) # noqa: E501 + concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) # noqa: E501 # fmt: on assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X) @@ -502,7 +499,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(max_seqlen, chunk_size): Y = concat_batch_f(Y_partial, Y_chunked) # kernel chunked is same as kernel overall - for i in range(batch_size): + for i in range(num_sequences): Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...] Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...] 
torch.testing.assert_close( From 4b18938f5effb699ac85eb7e993b2a97da17422a Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Wed, 20 Aug 2025 20:50:12 +0300 Subject: [PATCH 04/11] Add another failing test case Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- tests/kernels/mamba/test_mamba_ssm_ssd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index 642418146c8e..833f00fbc0b3 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -334,6 +334,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, @pytest.mark.parametrize("seqlens", [ (16, 2, 8, 13), (270, 88, 212, 203), + (16, 20), ]) def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): max_seqlen = max(seqlens) From 1fff3d7466b28619fec5973b9446d8e9b4e67a90 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Wed, 20 Aug 2025 20:52:52 +0300 Subject: [PATCH 05/11] fix the failing test case: more careful sequence index handling (+refactor to change names for better readability) Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- .../layers/mamba/ops/ssd_state_passing.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 5a238f97d568..945bcabb7daf 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -99,7 +99,7 @@ def _state_passing_fwd_kernel( tl.store(out_ptrs, states, mask=offs_m < dim) out_ptrs += stride_out_chunk - seq_idx = 0 + prev_seq_idx_chunk_end = 0 logical_chunk_idx = 0 for c in range(nchunks): new_states = tl.load(states_ptrs, mask=offs_m < dim, @@ -107,23 +107,18 @@ def _state_passing_fwd_kernel( dA_cs = tl.load(dA_cs_ptr).to(tl.float32) scale_mask = True if HAS_SEQ_IDX: - # sequence index at the start of the current chunk - seq_idx = tl.load(seq_idx_ptr + min(c * chunk_size, seqlen) * - stride_seq_idx_seqlen) - # - the seq to pass forward is the one that is flushed to the right # boundary. - # - that is given by seq_idx_new below: the sequence index at the end of the chunk. - seq_idx_new = tl.load(seq_idx_ptr + - (min((c + 1) * chunk_size, seqlen) - 1) * - stride_seq_idx_seqlen) + # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. + seq_idx_chunk_end = tl.load(seq_idx_ptr + (min( + (c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) if HAS_INITSTATES: - if IS_CONT_BATCHED and seq_idx != seq_idx_new: + if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end: # this means in the current chunk the rightmost flushed seq # has changed. 
# - so we do not propagate the state from previous chunk # - but rather we load that sequence's init state - initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch + initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch # - update state with seq_idx_new's init state states = tl.load(initstates_ptrs, @@ -134,7 +129,11 @@ def _state_passing_fwd_kernel( # - find its starting position (given by c_off of the logical chunk index) # - and subtract the cumsum just before that position from the total cumsum # - first, update the logical chunk index (add the number of sequences in the current physical chunk): - logical_chunk_idx += (seq_idx_new - seq_idx) + # sequence index at the start of the current chunk + seq_idx_chunk_start = tl.load(seq_idx_ptr + + min(c * chunk_size, seqlen) * + stride_seq_idx_seqlen) + logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start # - load the chunk offset: c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx, mask=logical_chunk_idx > -1 and @@ -153,7 +152,8 @@ def _state_passing_fwd_kernel( # - increment logical chunk index for every physical chunk logical_chunk_idx += 1 else: - scale_mask = seq_idx_new == seq_idx + scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end + prev_seq_idx_chunk_end = seq_idx_chunk_end scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) states = scale * states + new_states From a2101a7ac565d6434d40d5f4348582fb21077b59 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Wed, 20 Aug 2025 21:29:02 +0300 Subject: [PATCH 06/11] Add docstring to somewhat cryptic function Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- vllm/v1/attention/backends/mamba2_attn.py | 55 +++++++++++++++++++++-- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index ace078e2b27c..817ef17cf87d 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -16,9 +16,58 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec -def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, - chunk_size: int, - total_seqlens: int): +def _query_start_loc_to_chunk_indices_offsets( + query_start_loc: torch.Tensor, chunk_size: int, + total_seqlens: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + query_start_loc (torch.Tensor): 1D tensor of cumulative sequence + lengths, shape (num_seqs + 1,). + The first element should be 0. Each entry represents the starting + index of a sequence in the flattened token array. + chunk_size (int): The size of each physical mamba chunk + (number of tokens per chunk). + total_seqlens (int): The total number of tokens in the batch. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - chunk_indices (torch.Tensor): 1D tensor of indices + indicating the physical chunk for each logical chunk. + - chunk_offsets (torch.Tensor): 1D tensor of offsets + indicating the starting index of each logical chunk within + its physical chunk. + + This function computes the chunk indices and offsets for the given + query_start_loc and chunk_size. Both are tensors of integers with length N, + where N is the number of logical (pseudo) chunks. + A logical chunk is a sequence of tokens that are all part of the same + sequence and are all in the same physical mamba chunk. 
+ In other words, a logical chunk changes every time we cross a sequence + boundary or a physical mamba chunk boundary. + Logical chunks are needed to handle batched requests with initial states + (see _state_passing_fwd and _chunk_scan_fwd). + The chunk_indices tensor contains the index of the physical chunk for each + logical chunk. + The chunk_offsets tensor contains the offset (AKA starting index) of the + logical chunk in the physical chunk. + + Example: + query_start_loc = [0, 5, 10] + chunk_size = 8 + total_seqlens = 10 + -> chunk_indices = [0, 1, 0] + -> chunk_offsets = [0, 5, 0] + + In this example, we have 2 sequences, each with 5 tokens. The physical + chunk size is 8 tokens. + We have three logical chunks: + - the first logical chunk starts at token 0 in the first physical chunk + and contains all 5 tokens from the first sequence + - the second logical chunk starts at token 5 in the first physical chunk + and contains first 3 tokens from the second sequence + - the third logical chunk starts at token 0 in the second physical chunk + and contains the remaining 2 tokens from the second sequence + """ cu_seqlens = query_start_loc[1:] # remove prepended 0 From 7d7bf565c43eafa28fb341a236c307ac839b8e16 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Wed, 20 Aug 2025 22:37:01 +0300 Subject: [PATCH 07/11] mypy typehint Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- tests/kernels/mamba/test_mamba_ssm_ssd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index 833f00fbc0b3..1bccc9d08db5 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -349,8 +349,10 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): d_head = 64 itype = torch.float32 - last_taken = {} - exhausted = {} + # hold state during the cutting process so we know if an + # example has been exhausted and needs to cycle + last_taken: dict = {} # map: eg -> pointer to last taken sample + exhausted: dict = {} # map: eg -> boolean indicating example is exhausted _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next( generate_continuous_batched_examples([seqlens], num_sequences, From 6ff01bbd8e90c59b83f8613279baa41fc84873ed Mon Sep 17 00:00:00 2001 From: tomeras91 <57313761+tomeras91@users.noreply.github.com> Date: Wed, 20 Aug 2025 22:38:31 +0300 Subject: [PATCH 08/11] fix masking when loading chunk offset Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> --- vllm/model_executor/layers/mamba/ops/ssd_state_passing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 945bcabb7daf..d61c3a8cdbe9 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -136,8 +136,7 @@ def _state_passing_fwd_kernel( logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start # - load the chunk offset: c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx, - mask=logical_chunk_idx > -1 and - (logical_chunk_idx + 1) < chunk_meta_num, + mask=logical_chunk_idx < chunk_meta_num, other=0) # - if offset is 0, then the sequence starts at the beginning of the chunk, 
and we don't need to subtract anything
         if c_off > 0:

From 2bfe36b66d8c7c2c54491cbdcf59e09cbedb77f9 Mon Sep 17 00:00:00 2001
From: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Date: Mon, 1 Sep 2025 02:15:48 +0300
Subject: [PATCH 09/11] fix example in docstring

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
---
 vllm/v1/attention/backends/mamba2_attn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py
index 817ef17cf87d..2960ca7e4402 100644
--- a/vllm/v1/attention/backends/mamba2_attn.py
+++ b/vllm/v1/attention/backends/mamba2_attn.py
@@ -55,7 +55,7 @@ def _query_start_loc_to_chunk_indices_offsets(
         query_start_loc = [0, 5, 10]
         chunk_size = 8
         total_seqlens = 10
-        -> chunk_indices = [0, 1, 0]
+        -> chunk_indices = [0, 0, 1]
         -> chunk_offsets = [0, 5, 0]
 
     In this example, we have 2 sequences, each with 5 tokens. The physical

From 5ad41fcb8ba2dcf7b5f10866fabf70f69939b2a8 Mon Sep 17 00:00:00 2001
From: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Date: Mon, 1 Sep 2025 09:51:14 +0300
Subject: [PATCH 10/11] rename parameter and add documentation

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
---
 tests/kernels/mamba/test_mamba_ssm_ssd.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py
index 1bccc9d08db5..f1afcfe95467 100644
--- a/tests/kernels/mamba/test_mamba_ssm_ssd.py
+++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -116,17 +116,20 @@ def generate_continuous_batched_examples(example_lens_by_batch,
                                          d_head,
                                          itype,
                                          device='cuda',
-                                         return_ref=True):
+                                         return_naive_ref=True):
 
     # this function generates a random examples of certain length
     # and then cut according to "example_lens_by_batch" and feed
-    # them in continuous batches to the kernels
+    # them in continuous batches to the kernels.
+    # If return_naive_ref=True, the naive torch implementation
+    # ssd_minimal_discrete will be used to compute and return
+    # reference output.
# generate the full-length example
     A, dt, X, B, C = generate_random_inputs(num_examples, full_length,
                                             n_heads, d_head, itype)
 
-    if return_ref:
+    if return_naive_ref:
         Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
                                                       A * dt,
                                                       B,
@@ -183,7 +186,7 @@ def end_boundary(n: int):
         IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
 
         yield ([Y_min[s, IND_S[s]:IND_E[s]]
-                for s in range(num_examples)] if return_ref else None,
+                for s in range(num_examples)] if return_naive_ref else None,
                cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
 
 
@@ -362,7 +365,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
                                               n_heads,
                                               d_head,
                                               itype,
-                                              return_ref=False))
+                                              return_naive_ref=False))
     seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device)
     device = X.device
 

From 39bca2d7035272ff64c84fd18a6f00b094b6fb6e Mon Sep 17 00:00:00 2001
From: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Date: Mon, 1 Sep 2025 10:03:34 +0300
Subject: [PATCH 11/11] Add docstring to test

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
---
 tests/kernels/mamba/test_mamba_ssm_ssd.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py
index f1afcfe95467..1ce7f9d85e87 100644
--- a/tests/kernels/mamba/test_mamba_ssm_ssd.py
+++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -340,6 +340,18 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
     (16, 20),
 ])
 def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
+
+    # This test verifies the correctness of the chunked prefill implementation
+    # in the mamba2 ssd kernels, by comparing the concatenation (in the
+    # sequence dimension) of chunked results with the full-sequence result.
+    # It differs from test_mamba_chunk_scan_cont_batch in two ways:
+    # 1. It does not use the naive torch implementation (ssd_minimal_discrete)
+    #    to get reference outputs. Instead, it compares chunked kernel outputs
+    #    to full-sequence kernel outputs. This is the most straightforward way
+    #    to assert chunked prefill correctness.
+    # 2. It focuses on cases where sequences change in the middle of mamba
+    #    chunks, and not necessarily on chunk boundaries.
+
     max_seqlen = max(seqlens)
     # This test can have larger error for longer sequences
     if max_seqlen > 256:
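
Note on the logical-chunk decomposition (PATCH 06/11, corrected in PATCH 09/11): the docstring describes it only in prose. The sketch below is a minimal, framework-free illustration of that decomposition; it is not the vLLM implementation (which is vectorized in torch inside _query_start_loc_to_chunk_indices_offsets), and the function name and loop structure here are illustrative only. It reproduces the corrected docstring example.

def logical_chunk_indices_offsets(query_start_loc, chunk_size, total_seqlens):
    # A new logical chunk starts at every physical chunk boundary, plus at
    # every sequence start that falls strictly inside a physical chunk.
    seq_starts = query_start_loc[1:-1]  # interior sequence boundaries
    chunk_indices, chunk_offsets = [], []
    for phys_start in range(0, total_seqlens, chunk_size):
        phys_idx = phys_start // chunk_size
        starts = [phys_start] + [
            s for s in seq_starts if phys_start < s < phys_start + chunk_size
        ]
        for s in starts:
            chunk_indices.append(phys_idx)        # physical chunk holding this logical chunk
            chunk_offsets.append(s - phys_start)  # start of the logical chunk within it
    return chunk_indices, chunk_offsets

# Matches the corrected docstring example: two sequences of 5 tokens,
# physical chunk size 8.
assert logical_chunk_indices_offsets([0, 5, 10], 8, 10) == ([0, 0, 1], [0, 5, 0])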
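Note on the boundary correction in _state_passing_fwd (PATCH 02/11 and PATCH 05/11): the kernel now receives the full dA_cumsum rather than only its value at chunk boundaries, so that when the rightmost sequence of a physical chunk starts mid-chunk at offset c_off, the carried state is scaled only by the decay accumulated within that sequence. The following is a small sketch of that scale computation with illustrative names and a plain torch tensor; it is not the Triton kernel itself.

import torch

def boundary_corrected_scale(dA_cumsum_chunk: torch.Tensor, c_off: int) -> torch.Tensor:
    # dA_cumsum_chunk: cumulative sum of dA over one physical chunk, shape (chunk_size,)
    # c_off: position inside the chunk where the chunk's last sequence starts
    dA_end = dA_cumsum_chunk[-1]  # cumsum at the chunk end
    # cumsum just before the sequence start; zero when the sequence starts at
    # the beginning of the chunk (c_off == 0 is a passthrough)
    dA_before = dA_cumsum_chunk[c_off - 1] if c_off > 0 \
        else dA_cumsum_chunk.new_zeros(())
    # decay applied to the state carried into the next chunk
    return torch.exp(dA_end - dA_before)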