From 3be653f1aae3c29166721e0742fabebf766b3ff2 Mon Sep 17 00:00:00 2001
From: jinmanx
Date: Wed, 3 Dec 2025 02:09:23 -0800
Subject: [PATCH 1/2] fix attention bug

---
 gpt_oss/triton/attention.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/gpt_oss/triton/attention.py b/gpt_oss/triton/attention.py
index bf689055..07d179c7 100644
--- a/gpt_oss/triton/attention.py
+++ b/gpt_oss/triton/attention.py
@@ -62,6 +62,7 @@ def _attn_fwd(
         lo, hi = tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
     else:
         lo, hi = start_q, start_q + (start_m + 1) * BLOCK_M
+    hi = tl.minimum(hi, N_KV_CTX)
 
     for start_n in range(lo, hi, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
@@ -181,6 +182,7 @@ def attention_ref(
     pos_keys = torch.arange(num_keys, device=query.device)
     pos_queries = torch.arange(num_queries, device=query.device) + start_q
     mask = pos_keys[None, :] > pos_queries[:, None]
+    mask = mask | (pos_keys[None, :] < start_q)
     mask = mask.float().masked_fill(mask, float("-inf"))
 
     if sliding_window:
@@ -211,7 +213,7 @@ def attention_ref(
 @pytest.mark.parametrize("head_dim", [64])
 @pytest.mark.parametrize("sm_scale", [0.125])
 @pytest.mark.parametrize("sliding_window", [None, 128])
-@pytest.mark.parametrize("start_q", [0, 5])
+@pytest.mark.parametrize("start_q", [0, 64])
 def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_value_groups, head_dim, sm_scale, sliding_window, start_q):
     if num_queries > num_keys:
         pytest.skip("too many queries")
@@ -226,4 +228,4 @@ def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_valu
     o1 = attention(q, k, v, sinks, sm_scale, sliding_window, start_q)
     o2 = attention_ref(q, k, v, sinks, sm_scale, sliding_window, start_q)
 
-    torch.testing.assert_close(o1, o2)
+    torch.testing.assert_close(o1, o2)
\ No newline at end of file

From d6c5ccb7f37577522d687ce282aea56f7310ea8d Mon Sep 17 00:00:00 2001
From: jinmanx
Date: Wed, 3 Dec 2025 02:51:54 -0800
Subject: [PATCH 2/2] fix bug

---
 gpt_oss/triton/attention.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/gpt_oss/triton/attention.py b/gpt_oss/triton/attention.py
index 07d179c7..018b59d0 100644
--- a/gpt_oss/triton/attention.py
+++ b/gpt_oss/triton/attention.py
@@ -59,9 +59,9 @@ def _attn_fwd(
     q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
 
     if BANDWIDTH:
-        lo, hi = tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
+        lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
     else:
-        lo, hi = start_q, start_q + (start_m + 1) * BLOCK_M
+        lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
     hi = tl.minimum(hi, N_KV_CTX)
 
     for start_n in range(lo, hi, BLOCK_N):
@@ -182,7 +182,6 @@ def attention_ref(
     pos_keys = torch.arange(num_keys, device=query.device)
     pos_queries = torch.arange(num_queries, device=query.device) + start_q
     mask = pos_keys[None, :] > pos_queries[:, None]
-    mask = mask | (pos_keys[None, :] < start_q)
     mask = mask.float().masked_fill(mask, float("-inf"))
 
     if sliding_window:
@@ -213,7 +212,7 @@ def attention_ref(
 @pytest.mark.parametrize("head_dim", [64])
 @pytest.mark.parametrize("sm_scale", [0.125])
 @pytest.mark.parametrize("sliding_window", [None, 128])
-@pytest.mark.parametrize("start_q", [0, 64])
+@pytest.mark.parametrize("start_q", [0, 5])
 def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_value_groups, head_dim, sm_scale, sliding_window, start_q):
     if num_queries > num_keys:
         pytest.skip("too many queries")