diff --git a/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml b/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml
index 527666fff756..b855f26dcd38 100644
--- a/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml
+++ b/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml
@@ -81,3 +81,7 @@ config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias"
+
+# GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --paged-attention --num-pages 6 --page-size 8192"
diff --git a/mlir/test/e2e/AttentionSchedule.toml b/mlir/test/e2e/AttentionSchedule.toml
index 5b685218e659..169c34a50aad 100644
--- a/mlir/test/e2e/AttentionSchedule.toml
+++ b/mlir/test/e2e/AttentionSchedule.toml
@@ -116,3 +116,7 @@ config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias -schedule_version 2"
+
+# GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias -schedule_version 2 --paged-attention --num-pages 6 --page-size 8192"
diff --git a/mlir/test/e2e/PrAttentionBF16.toml b/mlir/test/e2e/PrAttentionBF16.toml
index 6de1b8f50afc..f92f7b71acb4 100644
--- a/mlir/test/e2e/PrAttentionBF16.toml
+++ b/mlir/test/e2e/PrAttentionBF16.toml
@@ -119,3 +119,7 @@ config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias"
+
+# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --paged-attention --num-pages 6 --page-size 8192"
diff --git a/mlir/test/e2e/PrAttentionDirectToLDS.toml b/mlir/test/e2e/PrAttentionDirectToLDS.toml
index 58b8861f4c76..97191acf6219 100644
--- a/mlir/test/e2e/PrAttentionDirectToLDS.toml
+++ b/mlir/test/e2e/PrAttentionDirectToLDS.toml
@@ -26,3 +26,7 @@ config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 --prefix_offset=16,14,12 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias"
+
+# GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 --prefix_offset=16,14,12 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --paged-attention --num-pages 6 --page-size 8192"
diff --git a/mlir/test/e2e/PrAttentionF16.toml b/mlir/test/e2e/PrAttentionF16.toml
index 1ceae277283a..1935cc629859 100644
--- a/mlir/test/e2e/PrAttentionF16.toml
+++ b/mlir/test/e2e/PrAttentionF16.toml
@@ -119,3 +119,7 @@ config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias"
+
+# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --paged-attention --num-pages 6 --page-size 8192"
diff --git a/mlir/test/e2e/PrAttentionF32.toml b/mlir/test/e2e/PrAttentionF32.toml
index b9e051119733..6a71cf19bab7 100644
--- a/mlir/test/e2e/PrAttentionF32.toml
+++ b/mlir/test/e2e/PrAttentionF32.toml
@@ -91,3 +91,7 @@ config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias"
+
+# GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --paged-attention --num-pages 6 --page-size 8192"
diff --git a/mlir/test/e2e/PrAttentionI8.toml b/mlir/test/e2e/PrAttentionI8.toml
index eb1b7c251544..5844422c1197 100644
--- a/mlir/test/e2e/PrAttentionI8.toml
+++ b/mlir/test/e2e/PrAttentionI8.toml
@@ -96,3 +96,7 @@ config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias"
+
+# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --paged-attention --num-pages 6 --page-size 8192"
diff --git a/mlir/test/e2e/PrAttentionSchedule.toml b/mlir/test/e2e/PrAttentionSchedule.toml
index 2465abecaa4e..f98ff70caf75 100644
--- a/mlir/test/e2e/PrAttentionSchedule.toml
+++ b/mlir/test/e2e/PrAttentionSchedule.toml
@@ -28,3 +28,7 @@ config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --schedule_version 2"
+
+# GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --schedule_version 2 --paged-attention --num-pages 6 --page-size 8192"
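Note on the new flag values: for the paged cache to hold exactly the dense K/V data of these tests, `--num-pages` and `--page-size` must satisfy `numPages * pageSize = numHeadsKV * seqLenK * headDim`, which is the same relation rocmlir-gen uses below to recover `seqLenK`. A minimal sketch of that arithmetic for the configs above (illustrative only, not part of the tool):

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the seqLenK recovery in createGpuAttentionKernel:
//   totalElements = numPages * pageSize = numHeadsKV * seqLenK * headDim
int64_t recoverSeqLenK(int64_t numPages, int64_t pageSize, int64_t numHeadsKV,
                       int64_t headDim) {
  return (numPages * pageSize) / (numHeadsKV * headDim);
}

int main() {
  // From the TOML configs: --num-pages 6 --page-size 8192,
  // -num_heads_kv 2, -head_dim_qk 64  ->  matches -seq_len_k 384
  assert(recoverSeqLenK(6, 8192, 2, 64) == 384);
  return 0;
}
```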
diff --git a/mlir/test/rocmlir-gen/paged-attention-kernel.mlir b/mlir/test/rocmlir-gen/paged-attention-kernel.mlir
new file mode 100644
index 000000000000..f556693c0e85
--- /dev/null
+++ b/mlir/test/rocmlir-gen/paged-attention-kernel.mlir
@@ -0,0 +1,147 @@
+// RUN: rocmlir-gen --arch gfx90a:sramecc+:xnack- --operation attention -seq_len_q 1024 -seq_len_k 1024 -head_dim_qk 32 -head_dim_v 32 -t f16 -pv --apply-bufferization-pipeline=false --paged-attention --num-pages 32 --page-size 1024 | rocmlir-opt | FileCheck %s --enable-var-scope --check-prefixes=CHECK
+
+// CHECK: module attributes {mhal.arch = "[[$ARCH:.*]]"}
+
+// Verify paged attention kernel signature:
+// - Q input: memref<32768xf16> (flattened [1, 1024, 32])
+// - K page table: memref<32xi64> (page pointers, numPages = 32)
+// - V page table: memref<32xi64> (page pointers, numPages = 32)
+// - Output: memref<32768xf16> (flattened [1, 1024, 32])
+// CHECK-LABEL: func.func @rock_attention
+// CHECK-SAME: (%[[queriesRaw:.*0]]: memref<32768xf16>,
+// CHECK-SAME: %[[keysPageTable:.*1]]: memref<32xi64>,
+// CHECK-SAME: %[[valuesPageTable:.*2]]: memref<32xi64>,
+// CHECK-SAME: %[[outputRaw:.*3]]: memref<32768xf16>)
+// CHECK-SAME: attributes {kernel, mhal.arch = "[[$ARCH]]"}
+
+// Transform Q to [G, seq_q, head_qk]
+// CHECK-NEXT: %[[queries:.*]] = rock.transform %[[queriesRaw]] {{.*}} : memref<32768xf16> to memref<1x1024x32xf16>
+
+// Transform K page table to [batch, numPages, 1]
+// CHECK-NEXT: %[[keysPageTableTransformed:.*]] = rock.transform %[[keysPageTable]] {{.*}} : memref<32xi64> to memref<1x32x1xi64>
+
+// Transform V page table to [batch, numPages, 1]
+// CHECK-NEXT: %[[valuesPageTableTransformed:.*]] = rock.transform %[[valuesPageTable]] {{.*}} : memref<32xi64> to memref<1x32x1xi64>
+
+// Transform output to [G, seq_q, head_v]
+// CHECK-NEXT: %[[output:.*]] = rock.transform %[[outputRaw]] {{.*}} : memref<32768xf16> to memref<1x1024x32xf16>
+
+// rock.deref: dereference K page table to get actual K data
+// CHECK-NEXT: %[[keyDeref:.*]] = rock.deref %[[keysPageTableTransformed]] : memref<1x32x1xi64> -> memref<1x32x1024xf16>
+
+// rock.deref: dereference V page table to get actual V data
+// CHECK-NEXT: %[[valueDeref:.*]] = rock.deref %[[valuesPageTableTransformed]] : memref<1x32x1xi64> -> memref<1x32x1024xf16>
+
+// Transform deref'd K to intermediate shapes for attention GEMM
+// CHECK-NEXT: %[[keyTransform1:.*]] = rock.transform %[[keyDeref]] {{.*}} : memref<1x32x1024xf16> to memref<1x32768xf16>
+// CHECK-NEXT: %[[keyTransform2:.*]] = rock.transform %[[keyTransform1]] {{.*}} : memref<1x32768xf16> to memref<1x1x1024x32xf16>
+// CHECK-NEXT: %[[keys:.*]] = rock.transform %[[keyTransform2]] {{.*}} : memref<1x1x1024x32xf16> to memref<1x32x1024xf16>
+
+// Transform deref'd V to intermediate shapes for attention GEMM
+// CHECK-NEXT: %[[valueTransform1:.*]] = rock.transform %[[valueDeref]] {{.*}} : memref<1x32x1024xf16> to memref<1x32768xf16>
+// CHECK-NEXT: %[[valueTransform2:.*]] = rock.transform %[[valueTransform1]] {{.*}} : memref<1x32768xf16> to memref<1x1x1024x32xf16>
+// CHECK-NEXT: %[[values:.*]] = rock.transform %[[valueTransform2]] {{.*}} : memref<1x1x1024x32xf16> to memref<1x1024x32xf16>
+
+// Verify rock.attention op with keyAddresses and valueAddresses attributes
+// CHECK-NEXT: rock.attention
+// CHECK-NEXT: qk = %[[queries]] * %[[keys]]
+// CHECK-NEXT: keyAddresses = (%[[keyDeref]] : memref<1x32x1024xf16>)
+// CHECK-NEXT: valueAddresses = (%[[valueDeref]] : memref<1x32x1024xf16>)
+// CHECK: %[[output]] = softmax(qk) * %[[values]]
+// CHECK: return
+
+// =============================================================================
+// CPU host function validation for paged attention
+// =============================================================================
+
+// CHECK-LABEL: func.func @host_naive_attention
+// CHECK-SAME: (%[[hostQ:.*0]]: memref<32768xf16>,
+// CHECK-SAME: %[[hostK:.*1]]: memref<32768xf16>,
+// CHECK-SAME: %[[hostV:.*2]]: memref<32768xf16>,
+// CHECK-SAME: %[[hostOut:.*3]]: memref<32768xf16>)
+
+// Convert Q memref to tensor and reshape to [1, 1024, 32]
+// CHECK: bufferization.to_tensor %[[hostQ]]
+// CHECK: tosa.reshape {{.*}} : (tensor<32768xf16>, !tosa.shape<3>) -> tensor<1x1024x32xf16>
+
+// Reshape K to [1, 32, 1024] for Q*K^T matmul
+// CHECK: bufferization.to_tensor %[[hostK]]
+// CHECK: tosa.reshape {{.*}} : (tensor<32768xf16>, !tosa.shape<3>) -> tensor<1x32x1024xf16>
+
+// Reshape V to [1, 1024, 32]
+// CHECK: bufferization.to_tensor %[[hostV]]
+// CHECK: tosa.reshape {{.*}} : (tensor<32768xf16>, !tosa.shape<3>) -> tensor<1x1024x32xf16>
+
+// First matmul: Q * K^T -> [1, 1024, 1024]
+// CHECK: tosa.matmul {{.*}} {acc_type = f32} : (tensor<1x1024x32xf16>, tensor<1x32x1024xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x1024x1024xf16>
+
+// Softmax: cast to f32, reduce_max, sub, exp, reduce_sum, reciprocal, mul, cast back
+// CHECK: tosa.cast {{.*}} : (tensor<1x1024x1024xf16>) -> tensor<1x1024x1024xf32>
+// CHECK: tosa.reduce_max {{.*}} : (tensor<1x1024x1024xf32>) -> tensor<1x1024x1xf32>
+// CHECK: tosa.sub {{.*}} : (tensor<1x1024x1024xf32>, tensor<1x1024x1xf32>) -> tensor<1x1024x1024xf32>
+// CHECK: tosa.exp {{.*}} : (tensor<1x1024x1024xf32>) -> tensor<1x1024x1024xf32>
+// CHECK: tosa.reduce_sum {{.*}} : (tensor<1x1024x1024xf32>) -> tensor<1x1024x1xf32>
+// CHECK: tosa.reciprocal {{.*}} : (tensor<1x1024x1xf32>) -> tensor<1x1024x1xf32>
+// CHECK: tosa.mul {{.*}} : (tensor<1x1024x1024xf32>, tensor<1x1024x1xf32>, tensor<1xi8>) -> tensor<1x1024x1024xf32>
+// CHECK: tosa.cast {{.*}} : (tensor<1x1024x1024xf32>) -> tensor<1x1024x1024xf16>
+
+// Second matmul: softmax(Q*K^T) * V -> [1, 1024, 32]
+// CHECK: tosa.matmul {{.*}} {acc_type = f32} : (tensor<1x1024x1024xf16>, tensor<1x1024x32xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x1024x32xf16>
+
+// Reshape output and copy to result
+// CHECK: tosa.reshape {{.*}} : (tensor<1x1024x32xf16>, !tosa.shape<1>) -> tensor<32768xf16>
+// CHECK: bufferization.to_buffer
+// CHECK: memref.copy
+// CHECK: return
+
+// ----
+
+// Test paged attention with GQA (grouped query attention)
+// RUN: rocmlir-gen --arch gfx90a:sramecc+:xnack- --operation attention -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1024 -seq_len_k 1024 -head_dim_qk 32 -head_dim_v 32 -t f16 -pv --apply-bufferization-pipeline=false --paged-attention --num-pages 64 --page-size 1024 | rocmlir-opt | FileCheck %s --enable-var-scope --check-prefixes=CHECK_GQA
+
+// CHECK_GQA: module attributes {mhal.arch = "[[$ARCH:.*]]"}
+
+// Verify GQA paged attention kernel signature:
+// - Q input: memref<131072xf16> (flattened [4, 1024, 32])
+// - K page table: memref<64xi64> (page pointers for 2 heads)
+// - V page table: memref<64xi64> (page pointers for 2 heads)
+// - Output: memref<131072xf16> (flattened [4, 1024, 32])
+// CHECK_GQA-LABEL: func.func @rock_attention
+// CHECK_GQA-SAME: (%[[queriesRaw:.*0]]: memref<131072xf16>,
+// CHECK_GQA-SAME: %[[keysPageTable:.*1]]: memref<64xi64>,
+// CHECK_GQA-SAME: %[[valuesPageTable:.*2]]: memref<64xi64>,
+// CHECK_GQA-SAME: %[[outputRaw:.*3]]: memref<131072xf16>)
+// CHECK_GQA-SAME: attributes {kernel, mhal.arch = "[[$ARCH]]"}
+
+// Transform Q to [G, seq_q, head_qk] with G = num_heads_q = 4
+// CHECK_GQA-NEXT: %[[queries:.*]] = rock.transform %[[queriesRaw]] {{.*}} : memref<131072xf16> to memref<4x1024x32xf16>
+
+// Transform K page table
+// CHECK_GQA-NEXT: %[[keysPageTableTransformed:.*]] = rock.transform %[[keysPageTable]] {{.*}} : memref<64xi64> to memref<1x64x1xi64>
+
+// Transform V page table
+// CHECK_GQA-NEXT: %[[valuesPageTableTransformed:.*]] = rock.transform %[[valuesPageTable]] {{.*}} : memref<64xi64> to memref<1x64x1xi64>
+
+// Transform output
+// CHECK_GQA-NEXT: %[[output:.*]] = rock.transform %[[outputRaw]] {{.*}} : memref<131072xf16> to memref<4x1024x32xf16>
+
+// rock.deref K
+// CHECK_GQA-NEXT: %[[keyDeref:.*]] = rock.deref %[[keysPageTableTransformed]] : memref<1x64x1xi64> -> memref<1x64x1024xf16>
+
+// rock.deref V
+// CHECK_GQA-NEXT: %[[valueDeref:.*]] = rock.deref %[[valuesPageTableTransformed]] : memref<1x64x1xi64> -> memref<1x64x1024xf16>
+
+// K transforms to [G, head_dim_qk, seq_k] with G = num_heads_kv = 2
+// CHECK_GQA: %[[keys:.*]] = rock.transform %{{.*}} {{.*}} to memref<2x32x1024xf16>
+
+// V transforms to [G, seq_k, head_dim_v] with G = num_heads_kv = 2
+// CHECK_GQA: %[[values:.*]] = rock.transform %{{.*}} {{.*}} to memref<2x1024x32xf16>
+
+// Verify rock.attention op with GQA and paged attention
+// CHECK_GQA: rock.attention
+// CHECK_GQA-NEXT: qk = %[[queries]] * %[[keys]]
+// CHECK_GQA-NEXT: keyAddresses = (%[[keyDeref]] : memref<1x64x1024xf16>)
+// CHECK_GQA-NEXT: valueAddresses = (%[[valueDeref]] : memref<1x64x1024xf16>)
+// CHECK_GQA: %[[output]] = softmax(qk) * %[[values]]
+// CHECK_GQA-NEXT: numHeadsKV = 2 : i32, numHeadsQ = 4 : i32
+// CHECK_GQA: return
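For intuition, each i64 entry of the page tables above holds the device address of one page of K or V data, and `rock.deref` turns the table into an indexable `[batch, numPages, pageSize]` view of that data. A C++ analogue of the indexing this implies (purely illustrative; not how the compiler implements it):

```cpp
#include <cstdint>
#include <vector>

// Illustrative only: pageTable[b][p] holds the address of page p for batch b
// (as an i64); the deref'd view's element (b, p, e) is element e of that page.
// Only valid if the stored addresses actually point at live pages.
template <typename T>
T derefElement(const std::vector<std::vector<int64_t>> &pageTable, int64_t b,
               int64_t p, int64_t e) {
  const T *page = reinterpret_cast<const T *>(pageTable[b][p]);
  return page[e];
}
```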
diff --git a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp
index 0dc3423813fd..8c95da3e2ded 100644
--- a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp
+++ b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp
@@ -701,6 +701,23 @@ static llvm::cl::opt<std::string>
         llvm::cl::desc("Data type for softmax (attention)"),
         llvm::cl::init("f32"));
 
+// Paged attention options
+static llvm::cl::opt<bool> pagedAttention(
+    "paged-attention",
+    llvm::cl::desc("Enable paged attention mode with page table inputs"),
+    llvm::cl::init(false));
+
+static llvm::cl::opt<int64_t>
+    pageSize("page-size",
+             llvm::cl::desc("Number of elements per page for paged attention"),
+             llvm::cl::value_desc("positive integer"));
+
+static llvm::cl::opt<int64_t> numPages(
+    "num-pages",
+    llvm::cl::desc("Number of pages for paged attention (required when "
+                   "--paged-attention is set)"),
+    llvm::cl::value_desc("positive integer"), llvm::cl::init(-1));
+
 //////////////////////////////////////////////////////////////////////////
 ////  Host Generator options
 //////////////////////////////////////////////////////////////////////////
@@ -1318,6 +1335,17 @@ static LogicalResult detectMissingArguments() {
                    << "If split-kv > 1 (flash decoding), we need to return LSE\n";
       return failure();
     }
+
+    if (pagedAttention) {
+      if (numPages <= 0) {
+        llvm::errs() << "Paged attention requires --num-pages to be set\n";
+        return failure();
+      }
+      if (pageSize <= 0) {
+        llvm::errs() << "Paged attention requires --page-size > 0\n";
+        return failure();
+      }
+    }
   }
 
   if (operation == rock::KernelType::Attention ||
@@ -2668,7 +2696,8 @@ static func::FuncOp createGpuGemmKernel(ModuleOp module,
 }
 
 static void getAttentionTypes(SmallVectorImpl<Type> &result,
-                              ArrayRef<Type> elemTypes) {
+                              ArrayRef<Type> elemTypes,
+                              bool forValidation = false) {
   SmallVector<int64_t> qDims{groupSize * numHeadsQ, sequenceLengthQ, headDimQK};
   SmallVector<int64_t> transposedQDims{groupSize * numHeadsQ, headDimQK,
                                        sequenceLengthQ};
@@ -2684,6 +2713,9 @@ static void getAttentionTypes(SmallVectorImpl<Type> &result,
   SmallVector<int64_t> transposedODims{groupSize * numHeadsQ * splitKV,
                                        headDimV, sequenceLengthQ};
 
+  // Page table dimensions for paged attention
+  SmallVector<int64_t> pageTableDims{groupSize, numPages, 1};
+
   bool isQuantized =
       elemTypes[0] == IntegerType::get(elemTypes[0].getContext(), 8);
@@ -2706,12 +2738,23 @@ static void getAttentionTypes(SmallVectorImpl<Type> &result,
   // output type = bias type
   const size_t outputIndex = biasIndex;
 
-  MemRefType qType = MemRefType::get(transposeQ ? transposedQDims : qDims,
-                                     elemTypes[qIndex]),
-             kType = MemRefType::get(transposeK ? kDims : transposedKDims,
-                                     elemTypes[kIndex]),
-             vType = MemRefType::get(transposeV ? transposedVDims : vDims,
-                                     elemTypes[vIndex]);
+  MLIRContext *ctx = elemTypes[0].getContext();
+  MemRefType qType =
+      MemRefType::get(transposeQ ? transposedQDims : qDims, elemTypes[qIndex]);
+  MemRefType kType, vType;
+  if (pagedAttention && !forValidation) {
+    // For the paged attention GPU kernel, K and V are page tables with i64
+    // addresses.
+    kType = MemRefType::get(pageTableDims, IntegerType::get(ctx, 64));
+    vType = MemRefType::get(pageTableDims, IntegerType::get(ctx, 64));
+  } else {
+    // For regular attention OR paged attention validation (which uses regular
+    // K/V).
+    kType = MemRefType::get(transposeK ? kDims : transposedKDims,
+                            elemTypes[kIndex]);
+    vType = MemRefType::get(transposeV ? transposedVDims : vDims,
+                            elemTypes[vIndex]);
+  }
 
   result.push_back(qType);
   result.push_back(kType);
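To make the `forValidation` split concrete: for the first RUN line of the new test, the GPU kernel's K/V arguments shrink to one i64 per page while the CPU validation function keeps dense buffers. A small size check (values taken from the FileCheck expectations above; the helper name is illustrative):

```cpp
#include <cassert>
#include <cstdint>

// pageTableDims = {groupSize, numPages, 1}: one i64 address per page.
int64_t pageTableLen(int64_t groupSize, int64_t numPages) {
  return groupSize * numPages;
}

int main() {
  // First RUN line of paged-attention-kernel.mlir: g = 1, --num-pages 32
  assert(pageTableLen(1, 32) == 32);       // K/V kernel arg: memref<32xi64>
  // Dense validation K/V: 1 * 1024 * 32 elements
  assert(int64_t{1} * 1024 * 32 == 32768); // memref<32768xf16>
  return 0;
}
```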
@@ -2766,22 +2809,34 @@
 static void
 getAttentionDimNames(SmallVectorImpl<SmallVector<StringRef>> &result,
-                     ArrayRef<Type> elementTypes) {
+                     ArrayRef<Type> elementTypes, bool forValidation = false) {
   result.reserve(elementTypes.size());
   constexpr StringLiteral gName = "g", seqQName = "seq_q", seqKName = "seq_k",
-                          headQKName = "head_qk", headVName = "head_v";
+                          headQKName = "head_qk", headVName = "head_v",
+                          batchName = "batch", numPagesName = "num_pages",
+                          oneName = "one";
   if (transposeQ)
     result.emplace_back(SmallVector<StringRef>{gName, headQKName, seqQName});
   else
     result.emplace_back(SmallVector<StringRef>{gName, seqQName, headQKName});
-  if (transposeK)
-    result.emplace_back(SmallVector<StringRef>{gName, seqKName, headQKName});
-  else
-    result.emplace_back(SmallVector<StringRef>{gName, headQKName, seqKName});
-  if (transposeV)
-    result.emplace_back(SmallVector<StringRef>{gName, headVName, seqKName});
-  else
-    result.emplace_back(SmallVector<StringRef>{gName, seqKName, headVName});
+
+  // K and V dimension names differ for paged attention (but not for
+  // validation).
+  if (pagedAttention && !forValidation) {
+    // Page table shape: [batch, numPages, 1]
+    result.emplace_back(
+        SmallVector<StringRef>{batchName, numPagesName, oneName});
+    result.emplace_back(
+        SmallVector<StringRef>{batchName, numPagesName, oneName});
+  } else {
+    if (transposeK)
+      result.emplace_back(SmallVector<StringRef>{gName, seqKName, headQKName});
+    else
+      result.emplace_back(SmallVector<StringRef>{gName, headQKName, seqKName});
+    if (transposeV)
+      result.emplace_back(SmallVector<StringRef>{gName, headVName, seqKName});
+    else
+      result.emplace_back(SmallVector<StringRef>{gName, seqKName, headVName});
+  }
   bool isQuantized = elementTypes[0].isInteger(8);
   if (isQuantized) {
     result.emplace_back(SmallVector<StringRef>{gName, seqQName, seqKName});
@@ -3271,6 +3326,90 @@ static Value broadcastBatchTensorRock(OpBuilder builder, Location loc,
   return rock::TransformOp::create(builder, loc, tensorBroadcast, mergerAttr);
 }
 
+// Transform paged attention deref output to K matrix shape.
+// Input: [batch, numPages, pageSize]
+// Output (transposeK=true):  [G_kv, seqK, headDimQK], where G_kv = batch * numHeadsKV
+// Output (transposeK=false): [G_kv, headDimQK, seqK]
+static Value createPagedDerefToKTransforms(
+    OpBuilder &builder, Location loc, Value derefOutput, int64_t numHeadsKVVal,
+    int64_t seqLenKVal, int64_t headDimQKVal, bool transposeKVal) {
+  SmallVector<StringRef> startNames = {"batch", "numPages", "pageSize"};
+  ArrayRef<int64_t> inpShape =
+      cast<ShapedType>(derefOutput.getType()).getShape();
+
+  // Step 1: Merge [batch, numPages, pageSize] -> [batch, total]
+  rock::BottomUpTMBuilder mergeB(builder, startNames, inpShape);
+  mergeB.passThrough({"batch"}, {0}, {"batch"});
+  mergeB.merge("total", 1, {"numPages", "pageSize"});
+  auto mergeAttr = mergeB.get();
+  Value merged =
+      rock::TransformOp::create(builder, loc, derefOutput, mergeAttr);
+
+  // Step 2: Unmerge [batch, total] -> [batch, numHeadsKV, seqK, headDimQK]
+  auto unmergeB = rock::BottomUpTMBuilder::above(mergeB, mergeAttr);
+  unmergeB.passThrough({"batch"}, {0}, {"batch"});
+  unmergeB.unmerge({"numHeadsKV", "seqK", "headDimQK"}, {1, 2, 3}, "total",
+                   {numHeadsKVVal, seqLenKVal, headDimQKVal});
+  auto unmergeAttr = unmergeB.get();
+  Value unmerged = rock::TransformOp::create(builder, loc, merged, unmergeAttr);
+
+  // Step 3: Merge [batch, numHeadsKV] -> [G_kv] and handle transpose
+  auto finalB = rock::BottomUpTMBuilder::above(unmergeB, unmergeAttr);
+  finalB.merge("G", 0, {"batch", "numHeadsKV"});
+
+  if (transposeKVal) {
+    // transposeK=true means NOT transposed layout: [G, seqK, headDimQK]
+    finalB.passThrough({"seqK", "headDimQK"}, {1, 2}, {"seqK", "headDimQK"});
+  } else {
+    // transposeK=false means transposed layout: [G, headDimQK, seqK]
+    finalB.passThrough({"headDimQK", "seqK"}, {1, 2}, {"headDimQK", "seqK"});
+  }
+  auto finalAttr = finalB.get();
+  return rock::TransformOp::create(builder, loc, unmerged, finalAttr);
+}
+
+// Transform paged attention deref output to V matrix shape.
+// Input: [batch, numPages, pageSize]
+// Output (transposeV=true):  [G_kv, headDimV, seqK]
+// Output (transposeV=false): [G_kv, seqK, headDimV]
+static Value createPagedDerefToVTransforms(
+    OpBuilder &builder, Location loc, Value derefOutput, int64_t numHeadsKVVal,
+    int64_t seqLenKVal, int64_t headDimVVal, bool transposeVVal) {
+  SmallVector<StringRef> startNames = {"batch", "numPages", "pageSize"};
+  ArrayRef<int64_t> inpShape =
+      cast<ShapedType>(derefOutput.getType()).getShape();
+
+  // Step 1: Merge [batch, numPages, pageSize] -> [batch, total]
+  rock::BottomUpTMBuilder mergeB(builder, startNames, inpShape);
+  mergeB.passThrough({"batch"}, {0}, {"batch"});
+  mergeB.merge("total", 1, {"numPages", "pageSize"});
+  auto mergeAttr = mergeB.get();
+  Value merged =
+      rock::TransformOp::create(builder, loc, derefOutput, mergeAttr);
+
+  // Step 2: Unmerge [batch, total] -> [batch, numHeadsKV, seqK, headDimV]
+  auto unmergeB = rock::BottomUpTMBuilder::above(mergeB, mergeAttr);
+  unmergeB.passThrough({"batch"}, {0}, {"batch"});
+  unmergeB.unmerge({"numHeadsKV", "seqK", "headDimV"}, {1, 2, 3}, "total",
+                   {numHeadsKVVal, seqLenKVal, headDimVVal});
+  auto unmergeAttr = unmergeB.get();
+  Value unmerged = rock::TransformOp::create(builder, loc, merged, unmergeAttr);
+
+  // Step 3: Merge [batch, numHeadsKV] -> [G_kv] and handle transpose
+  auto finalB = rock::BottomUpTMBuilder::above(unmergeB, unmergeAttr);
+  finalB.merge("G", 0, {"batch", "numHeadsKV"});
+
+  if (transposeVVal) {
+    // transposeV=true means transposed layout: [G, headDimV, seqK]
+    finalB.passThrough({"headDimV", "seqK"}, {1, 2}, {"headDimV", "seqK"});
+  } else {
+    // transposeV=false means NOT transposed layout: [G, seqK, headDimV]
+    finalB.passThrough({"seqK", "headDimV"}, {1, 2}, {"seqK", "headDimV"});
+  }
+  auto finalAttr = finalB.get();
+  return rock::TransformOp::create(builder, loc, unmerged, finalAttr);
+}
+
 static void setScheduleVersion(MLIRContext *ctx, func::FuncOp func) {
   if (gemmScheduleVersion.getValue() != GEMMScheduleVersion::V1)
     func->setAttr(rock::ScheduleVersionAttr::getMnemonic(),
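The merge/unmerge/merge chain in the two helpers above composes into a closed-form index map from the attention-facing (g, s, h) coordinate to a flat offset in the deref'd cache. A sketch of that composed map for K in the `[G, seqK, headDimQK]` orientation (layout as defined by the unmerge step; function name is illustrative):

```cpp
#include <cassert>
#include <cstdint>

// g splits into (batch, headKV) with headKV fastest (per the final merge);
// within a batch, the unmerge orders dims as {numHeadsKV, seqK, headDimQK}.
int64_t pagedKFlatOffset(int64_t g, int64_t s, int64_t h, int64_t numHeadsKV,
                         int64_t seqK, int64_t headDimQK) {
  int64_t batch = g / numHeadsKV;
  int64_t headKV = g % numHeadsKV;
  int64_t withinBatch = (headKV * seqK + s) * headDimQK + h;
  return batch * (numHeadsKV * seqK * headDimQK) + withinBatch;
}

int main() {
  // g=3 with numHeadsKV=2 -> batch=1, headKV=1; seqK=4, headDimQK=2
  assert(pagedKFlatOffset(3, 2, 1, 2, 4, 2) ==
         1 * (2 * 4 * 2) + (1 * 4 + 2) * 2 + 1);
  return 0;
}
```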
@@ -3331,6 +3470,44 @@ static func::FuncOp createGpuAttentionKernel(ModuleOp module,
   Value keys = unflattenedArgs[1];
   Value values = unflattenedArgs[2];
 
+  // For paged attention, keys and values are page tables.
+  // We need to create rock.deref ops and apply transforms.
+  Value keyAddresses = nullptr;
+  Value valueAddresses = nullptr;
+  if (pagedAttention) {
+    // Keys and values are page tables [batch, numPages, 1] with i64 addresses.
+    // Create rock.deref ops to get [batch, numPages, pageSize] memrefs.
+    Type elemTypeK = params.types[1];
+    Type elemTypeV = params.types[2];
+    MemRefType keyDerefOutputType =
+        MemRefType::get({groupSize, numPages, pageSize}, elemTypeK);
+    MemRefType valueDerefOutputType =
+        MemRefType::get({groupSize, numPages, pageSize}, elemTypeV);
+
+    // rock.deref: input page table -> output data memref
+    Value keyDeref =
+        rock::DerefOp::create(builder, loc, keyDerefOutputType, keys);
+    Value valueDeref =
+        rock::DerefOp::create(builder, loc, valueDerefOutputType, values);
+
+    // keyAddresses/valueAddresses are the raw deref outputs
+    // [batch, numPages, pageSize]. These are used by the attention op for
+    // block-level loading.
+    keyAddresses = keyDeref;
+    valueAddresses = valueDeref;
+
+    // Compute seqLenK from the paged cache dimensions:
+    //   totalElements = numPages * pageSize = numHeadsKV * seqLenK * headDimQK
+    int64_t totalElements = numPages * pageSize;
+    int64_t denominator = numHeadsKV * headDimQK;
+    int64_t seqLenKVal = totalElements / denominator;
+
+    // Transform the deref outputs to attention's expected K/V shapes
+    keys = createPagedDerefToKTransforms(builder, loc, keyDeref, numHeadsKV,
+                                         seqLenKVal, headDimQK, transposeK);
+    values = createPagedDerefToVTransforms(builder, loc, valueDeref,
+                                           numHeadsKV, seqLenKVal, headDimV,
+                                           transposeV);
+  }
+
   Value quantBias;
   Value quantScale;
   Value scale;
@@ -3376,11 +3553,9 @@ static func::FuncOp createGpuAttentionKernel(ModuleOp module,
       TypeAttr::get(typeFromString(softmaxDataType.getValue(), ctx));
   auto attention = rock::AttentionOp::create(
       builder, loc, TypeRange{}, queries, keys, values, elemwiseInputs,
-      currentSeqLenTensor, prefixOffsetTensor,
-      /*keyAddresses=*/nullptr, /*valueAddresses=*/nullptr,
-      output, lse, numHeadsQ,
-      numHeadsKV, transposeQ, transposeK, transposeV, transposeO, actualCausal,
-      splitKV,
+      currentSeqLenTensor, prefixOffsetTensor, keyAddresses, valueAddresses,
+      output, lse, numHeadsQ, numHeadsKV, transposeQ, transposeK, transposeV,
+      transposeO, actualCausal, splitKV,
       rock::GemmFeaturesAttr::get(builder.getContext(), params.features),
       storeMethod, softmaxType,
       /*params0=*/nullptr, /*params1=*/nullptr,
@@ -4133,7 +4308,8 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
   bool isQuantized = params.types[0] == IntegerType::get(ctx, 8);
 
   SmallVector<Type> argTypes;
-  getAttentionTypes(argTypes, params.types);
+  // For validation, always use regular K/V types (not page tables)
+  getAttentionTypes(argTypes, params.types, /*forValidation=*/true);
 
   SmallVector<Type> flatArgTypes =
       llvm::map_to_vector(argTypes, rock::getFlattenedType);
@@ -4928,6 +5104,103 @@ static void insertValidationCalls(const GenParams &genParams, OpBuilder &b,
     }
   }
 }
 
+// Generates a deterministic shuffled permutation for page indices.
+static SmallVector<int64_t> generatePageShuffle(int64_t totalPages,
+                                                int64_t seed) {
+  SmallVector<int64_t> perm(totalPages);
+  // Initialize with the identity permutation
+  for (int64_t i = 0; i < totalPages; ++i)
+    perm[i] = i;
+
+  // Fisher-Yates shuffle with a deterministic LCG-based random source.
+  // LCG parameters (same as glibc): a=1103515245, c=12345, m=2^31
+  uint64_t state = static_cast<uint64_t>(seed) ^ 0xDEADBEEF;
+  auto nextRand = [&state]() {
+    state = (state * 1103515245ULL + 12345ULL) & 0x7FFFFFFF;
+    return state;
+  };
+
+  for (int64_t i = totalPages - 1; i > 0; --i) {
+    int64_t j = static_cast<int64_t>(nextRand() % static_cast<uint64_t>(i + 1));
+    std::swap(perm[i], perm[j]);
+  }
+
+  return perm;
+}
+
+// Allocates the GPU cache buffer and populates the page table with GPU
+// addresses. This generates code that:
+// 1. Allocates GPU memory for the cache
+// 2. Copies the CPU cache data to the GPU (data is pre-shuffled on the CPU
+//    side)
+// 3. Extracts the GPU base pointer
+// 4. Fills each entry in the page table with shuffled addresses
+static Value populatePagedAttentionPageTableWithGpuCache(
+    OpBuilder &b, Location loc, ModuleOp module, Value cpuCache,
+    Value pageTable, int64_t batchSize, int64_t numPagesVal,
+    int64_t pageSizeVal, Type elemType, const SmallVector<int64_t> &shuffle) {
+  MLIRContext *ctx = b.getContext();
+
+  // Get the element size in bytes
+  int64_t elemSizeBytes = elemType.getIntOrFloatBitWidth() / 8;
+
+  // Allocate GPU memory for the cache
+  MemRefType cacheType = cast<MemRefType>(cpuCache.getType());
+  auto tokenType = gpu::AsyncTokenType::get(ctx);
+
+  // gpu.wait to get the initial token
+  auto waitOp = gpu::WaitOp::create(b, loc, tokenType, ValueRange{});
+  Value initToken = waitOp.getAsyncToken();
+
+  // gpu.alloc for the cache buffer
+  auto gpuAllocOp =
+      gpu::AllocOp::create(b, loc, cacheType, tokenType, ValueRange{initToken},
+                           ValueRange{}, ValueRange{});
+  Value gpuCache = gpuAllocOp.getMemref();
+  Value allocToken = gpuAllocOp.getAsyncToken();
+
+  // gpu.memcpy from the CPU cache to the GPU cache
+  auto memcpyOp = gpu::MemcpyOp::create(
+      b, loc, tokenType, ValueRange{allocToken}, gpuCache, cpuCache);
+  Value copyToken = memcpyOp.getAsyncToken();
+
+  // gpu.wait to ensure the copy completes before extracting the pointer
+  gpu::WaitOp::create(b, loc, TypeRange{}, ValueRange{copyToken});
+
+  // Extract the base pointer from the GPU cache as an index
+  Value baseAddr =
+      memref::ExtractAlignedPointerAsIndexOp::create(b, loc, gpuCache);
+
+  // Convert to i64
+  Value baseAddrI64 =
+      arith::IndexCastOp::create(b, loc, b.getI64Type(), baseAddr);
+
+  // Constants
+  Value pageSizeBytes = arith::ConstantOp::create(
+      b, loc, b.getI64IntegerAttr(pageSizeVal * elemSizeBytes));
+
+  // The page table is flattened to 1D. Logical page L maps to physical slot
+  // shuffle[L].
+  for (int64_t batch = 0; batch < batchSize; ++batch) {
+    for (int64_t page = 0; page < numPagesVal; ++page) {
+      int64_t logicalIdx = batch * numPagesVal + page;
+      int64_t physicalSlot = shuffle[logicalIdx];
+
+      // Compute the GPU memory offset for the physical slot
+      Value physicalOffset =
+          arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(physicalSlot));
+      Value offset =
+          arith::MulIOp::create(b, loc, physicalOffset, pageSizeBytes);
+      Value addr = arith::AddIOp::create(b, loc, baseAddrI64, offset);
+
+      // Store in the page table at the logical index
+      Value logicalIdxVal = arith::ConstantIndexOp::create(b, loc, logicalIdx);
+      memref::StoreOp::create(b, loc, addr, pageTable, logicalIdxVal);
+    }
+  }
+
+  return gpuCache;
+}
+
 static LogicalResult populateHostHarnessLogic(
     ModuleOp module, const SmallVector<KernelIF> &kernels,
     const SmallVector<KernelIF> &roots, const GenParams &genParams) {
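Because `generatePageShuffle` is driven by a fixed LCG rather than `std::shuffle`, the permutation is reproducible for a given seed, and the Fisher-Yates pass guarantees a bijection over page slots. A standalone restatement of that logic with both properties checked (illustrative; not a new API in the tool):

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Mirrors generatePageShuffle: Fisher-Yates driven by a glibc-style LCG.
std::vector<int64_t> pageShuffle(int64_t totalPages, int64_t seed) {
  std::vector<int64_t> perm(totalPages);
  std::iota(perm.begin(), perm.end(), 0); // identity permutation
  uint64_t state = static_cast<uint64_t>(seed) ^ 0xDEADBEEF;
  auto next = [&state]() {
    state = (state * 1103515245ULL + 12345ULL) & 0x7FFFFFFF;
    return state;
  };
  for (int64_t i = totalPages - 1; i > 0; --i)
    std::swap(perm[i], perm[next() % static_cast<uint64_t>(i + 1)]);
  return perm;
}

int main() {
  auto a = pageShuffle(6, 42), b = pageShuffle(6, 42);
  assert(a == b); // deterministic for a fixed seed
  std::vector<bool> seen(6, false);
  for (int64_t p : a)
    seen[p] = true; // bijection: every physical slot is hit exactly once
  for (bool s : seen)
    assert(s);
  return 0;
}
```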
@@ -5030,6 +5303,134 @@ static LogicalResult populateHostHarnessLogic(
 
   SmallVector<Value> localVars;
   SmallVector<Value> valVars;
+
+  // For paged attention: track cache buffers separately (not passed to the
+  // kernel)
+  Value keyCacheBuffer = nullptr;
+  Value valueCacheBuffer = nullptr;
+  Value keyCacheCPU = nullptr;
+  Value valueCacheCPU = nullptr;
+  Value keyCacheShuffled = nullptr;
+  Value valueCacheShuffled = nullptr;
+  SmallVector<int64_t> keyShuffle;
+  SmallVector<int64_t> valueShuffle;
+  const int64_t kParamIdx = 1;
+  const int64_t vParamIdx = 2;
+
+  // If paged attention, pre-allocate CPU cache buffers and fill them with data
+  if (isAttention && pagedAttention) {
+    // Cache shape: [groupSize * numPages * pageSize], flattened
+    Type keyCacheElemType = genParams.types[1];
+    Type valueCacheElemType = genParams.types[2];
+    int64_t cacheSize = groupSize * numPages * pageSize;
+    int64_t totalPages = groupSize * numPages;
+    MemRefType keyCacheType = MemRefType::get({cacheSize}, keyCacheElemType);
+    MemRefType valueCacheType =
+        MemRefType::get({cacheSize}, valueCacheElemType);
+
+    // Allocate CPU cache buffers in logical order (for validation)
+    keyCacheCPU = memref::AllocOp::create(b, loc, keyCacheType);
+    valueCacheCPU = memref::AllocOp::create(b, loc, valueCacheType);
+
+    // Fill the CPU cache buffers with random/pattern data (logical order)
+    if (!isRandom) {
+      SmallVector<float> kPattern = getTensorInitPattern(keyCacheElemType, 1);
+      SmallVector<float> vPattern = getTensorInitPattern(valueCacheElemType, 2);
+      if (failed(populateTensorFillLogic(b, loc, kPattern, keyCacheElemType,
+                                         keyCacheCPU)))
+        return failure();
+      if (failed(populateTensorFillLogic(b, loc, vPattern, valueCacheElemType,
+                                         valueCacheCPU)))
+        return failure();
+    } else {
+      if (failed(populateRandomTensorFillLogic(b, loc, module,
+                                               keyCacheElemType, keyCacheCPU,
+                                               1, false)))
+        return failure();
+      if (failed(populateRandomTensorFillLogic(
+              b, loc, module, valueCacheElemType, valueCacheCPU, 2, false)))
+        return failure();
+    }
+
+    // Generate shuffled page mappings (different seeds for K and V)
+    int64_t seed = getRandomSeed();
+    keyShuffle = generatePageShuffle(totalPages, seed);
+    valueShuffle = generatePageShuffle(totalPages, seed + 12345);
+
+    // Allocate the shuffled cache buffers
+    keyCacheShuffled = memref::AllocOp::create(b, loc, keyCacheType);
+    valueCacheShuffled = memref::AllocOp::create(b, loc, valueCacheType);
+
+    // Create a lookup table in memory for the shuffle mapping.
+    // This allows us to use runtime loops instead of unrolling at compile
+    // time.
+    auto createShuffleLUT = [&](const SmallVector<int64_t> &shuffle) -> Value {
+      MemRefType lutType = MemRefType::get({totalPages}, b.getIndexType());
+      Value lut = memref::AllocOp::create(b, loc, lutType);
+      for (int64_t i = 0; i < totalPages; ++i) {
+        Value idx = arith::ConstantIndexOp::create(b, loc, i);
+        Value physSlot = arith::ConstantIndexOp::create(b, loc, shuffle[i]);
+        memref::StoreOp::create(b, loc, physSlot, lut, idx);
+      }
+      return lut;
+    };
+
+    Value keyLUT = createShuffleLUT(keyShuffle);
+    Value valueLUT = createShuffleLUT(valueShuffle);
+
+    // Copy data from logical order to shuffled order using runtime loops.
+    // For each logical page, copy pageSize elements to the shuffled physical
+    // slot.
+    auto emitShuffledCopy = [&](Value logicalCache, Value shuffledCache,
+                                Value shuffleLUT) {
+      Value zero = arith::ConstantIndexOp::create(b, loc, 0);
+      Value one = arith::ConstantIndexOp::create(b, loc, 1);
+      Value totalPagesVal = arith::ConstantIndexOp::create(b, loc, totalPages);
+      Value pageSizeVal = arith::ConstantIndexOp::create(b, loc, pageSize);
+
+      // Outer loop over pages
+      scf::ForOp::create(
+          b, loc, zero, totalPagesVal, one, ValueRange{},
+          [&](OpBuilder &pageBuilder, Location pageLoc, Value logicalPage,
+              ValueRange) {
+            // Look up the physical slot for this logical page
+            Value physicalSlot = memref::LoadOp::create(
+                pageBuilder, pageLoc, shuffleLUT, logicalPage);
+
+            // Compute the base indices
+            Value srcBase = arith::MulIOp::create(pageBuilder, pageLoc,
+                                                  logicalPage, pageSizeVal);
+            Value dstBase = arith::MulIOp::create(pageBuilder, pageLoc,
                                                   physicalSlot, pageSizeVal);
+
+            // Inner loop over elements within the page
+            scf::ForOp::create(
+                pageBuilder, pageLoc, zero, pageSizeVal, one, ValueRange{},
+                [&](OpBuilder &elemBuilder, Location elemLoc, Value elemIdx,
+                    ValueRange) {
+                  Value srcIdx = arith::AddIOp::create(elemBuilder, elemLoc,
+                                                       srcBase, elemIdx);
+                  Value dstIdx = arith::AddIOp::create(elemBuilder, elemLoc,
+                                                       dstBase, elemIdx);
+                  Value val = memref::LoadOp::create(elemBuilder, elemLoc,
+                                                     logicalCache, srcIdx);
+                  memref::StoreOp::create(elemBuilder, elemLoc, val,
+                                          shuffledCache, dstIdx);
+                  scf::YieldOp::create(elemBuilder, elemLoc);
+                });
+            scf::YieldOp::create(pageBuilder, pageLoc);
+          });
+    };
+
+    emitShuffledCopy(keyCacheCPU, keyCacheShuffled, keyLUT);
+    emitShuffledCopy(valueCacheCPU, valueCacheShuffled, valueLUT);
+
+    // Deallocate the LUTs after use
+    memref::DeallocOp::create(b, loc, keyLUT);
+    memref::DeallocOp::create(b, loc, valueLUT);
+
+    // The shuffled buffers will be used for the GPU, the logical buffers for
+    // validation
+    keyCacheBuffer = keyCacheCPU;
+    valueCacheBuffer = valueCacheCPU;
+  }
+
   // Calculate expected indices for currentSeqLen and prefixOffset tensors.
   // The layout is: ..., currentSeqLen?, prefixOffset?, LSE?, Output
   // We need to count backwards from the end.
@@ -5076,6 +5477,24 @@ static LogicalResult populateHostHarnessLogic(
       } else if (!prefixOffset.empty() && isAttention &&
                  static_cast<int64_t>(idx) == expectedPrefixOffsetIdx) {
         fillWithI32Values(prefixOffset);
+      } else if (isAttention && pagedAttention &&
+                 (static_cast<int64_t>(idx) == kParamIdx ||
+                  static_cast<int64_t>(idx) == vParamIdx)) {
+        // For paged attention K/V: these are page tables [batch, numPages, 1].
+        // Allocate the GPU cache and fill the page table with shuffled GPU
+        // addresses.
+        bool isK = (static_cast<int64_t>(idx) == kParamIdx);
+        Value shuffledCache = isK ? keyCacheShuffled : valueCacheShuffled;
+        const SmallVector<int64_t> &shuffle = isK ? keyShuffle : valueShuffle;
+        Type cacheElemType = isK ? genParams.types[1] : genParams.types[2];
+        Value gpuCache = populatePagedAttentionPageTableWithGpuCache(
+            b, loc, module, shuffledCache, lvar, groupSize, numPages, pageSize,
+            cacheElemType, shuffle);
+        // Store the GPU cache to keep it alive during kernel execution
+        if (isK) {
+          keyCacheBuffer = gpuCache;
+        } else {
+          valueCacheBuffer = gpuCache;
+        }
       } else if (!isRandom) {
         bool zeroInit = llvm::is_contained(outIndices, idx) && isSplitK;
         SmallVector<float> zeroPattern = {0.0f};
@@ -5108,11 +5527,101 @@ static LogicalResult populateHostHarnessLogic(
         valElemType = elemType;
       }
 
-      auto valType = MemRefType::get(paramMRType.getShape(), valElemType);
-      auto vvar = memref::AllocOp::create(b, loc, valType);
-      valVars.push_back(vvar);
+      // For paged attention K/V (indices 1 and 2), create regular K/V tensors
+      // for validation instead of copying the page tables
+      if (isAttention && pagedAttention &&
+          (static_cast<int64_t>(idx) == kParamIdx ||
+           static_cast<int64_t>(idx) == vParamIdx)) {
+        // The CPU attention function expects FLATTENED (1D) arguments.
+        // Calculate the K/V tensor's total size.
+        bool isK = (static_cast<int64_t>(idx) == kParamIdx);
+        Type cacheElemType = isK ? genParams.types[1] : genParams.types[2];
+        int64_t G_kv = groupSize * numHeadsKV;
+        int64_t seqK, headDim;
+        if (isK) {
+          headDim = headDimQK;
+        } else {
+          headDim = headDimV;
+        }
+        seqK = (numPages * pageSize) / (numHeadsKV * headDim);
+        int64_t totalElements = G_kv * seqK * headDim;
+
+        // Create a flattened K/V tensor (1D) to match the function signature
+        MemRefType flatKvType = MemRefType::get({totalElements}, cacheElemType);
+        auto vvar = memref::AllocOp::create(b, loc, flatKvType);
+        valVars.push_back(vvar);
+
+        Value cacheBuffer = isK ? keyCacheCPU : valueCacheCPU;
+
+        // Cache layout (from the deref unmerge):
+        //   storage[g*seqK*headDim + s*headDim + h]
+        // The validation layout depends on the transpose flags:
+        //   transposeK=false: [G, headDim, seqK] ->
+        //     storage[g*headDim*seqK + h*seqK + s]
+        //   transposeK=true:  [G, seqK, headDim] ->
+        //     storage[g*seqK*headDim + s*headDim + h]
+        // A transpose is needed when the layouts differ: K with !transposeK,
+        // V with transposeV.
+        bool needsTranspose = (isK && !transposeK) || (!isK && transposeV);
+
+        if (needsTranspose) {
+          // Emit a transpose copy using a helper that iterates over validation
+          // indices and computes the corresponding cache index:
+          //   vvar[g*headDim*seqK + h*seqK + s] =
+          //       cache[g*seqK*headDim + s*headDim + h]
+          Value zero = arith::ConstantIndexOp::create(b, loc, 0);
+          Value one = arith::ConstantIndexOp::create(b, loc, 1);
+          Value totalSize =
+              arith::ConstantIndexOp::create(b, loc, totalElements);
+          Value headDimVal = arith::ConstantIndexOp::create(b, loc, headDim);
+          Value seqKVal = arith::ConstantIndexOp::create(b, loc, seqK);
+          Value headDimxSeqK =
+              arith::ConstantIndexOp::create(b, loc, headDim * seqK);
+          Value seqKxHeadDim =
+              arith::ConstantIndexOp::create(b, loc, seqK * headDim);
+
+          // Single loop over all validation indices
+          scf::ForOp::create(
+              b, loc, zero, totalSize, one, ValueRange{},
+              [&](OpBuilder &loopBuilder, Location loopLoc, Value valIdx,
+                  ValueRange) {
+                // Decompose valIdx into (g, h, s) based on the validation
+                // layout: valIdx = g * (headDim * seqK) + h * seqK + s
+                Value g = arith::DivUIOp::create(loopBuilder, loopLoc, valIdx,
+                                                 headDimxSeqK);
+                Value remainder = arith::RemUIOp::create(loopBuilder, loopLoc,
+                                                         valIdx, headDimxSeqK);
+                Value h = arith::DivUIOp::create(loopBuilder, loopLoc,
+                                                 remainder, seqKVal);
+                Value s = arith::RemUIOp::create(loopBuilder, loopLoc,
+                                                 remainder, seqKVal);
+
+                // Compute the cache index:
+                //   g * (seqK * headDim) + s * headDim + h
+                Value cacheIdx = arith::MulIOp::create(loopBuilder, loopLoc, g,
+                                                       seqKxHeadDim);
+                Value tmp =
+                    arith::MulIOp::create(loopBuilder, loopLoc, s, headDimVal);
+                cacheIdx =
+                    arith::AddIOp::create(loopBuilder, loopLoc, cacheIdx, tmp);
+                cacheIdx =
+                    arith::AddIOp::create(loopBuilder, loopLoc, cacheIdx, h);
+
+                // Copy the element
+                Value elem = memref::LoadOp::create(loopBuilder, loopLoc,
+                                                    cacheBuffer, cacheIdx);
+                memref::StoreOp::create(loopBuilder, loopLoc, elem, vvar,
+                                        valIdx);
+                scf::YieldOp::create(loopBuilder, loopLoc);
+              });
+        } else {
+          // The layouts match; do a direct copy
+          memref::CopyOp::create(b, loc, cacheBuffer, vvar);
+        }
+      } else {
+        auto valType = MemRefType::get(paramMRType.getShape(), valElemType);
+        auto vvar = memref::AllocOp::create(b, loc, valType);
+        valVars.push_back(vvar);
 
-      emitMemcpy(b, lvar, vvar);
+        emitMemcpy(b, lvar, vvar);
+      }
     }
   }
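The transpose copy above is only correct if the (g, h, s) decomposition of each validation index maps onto a unique cache index. A small self-contained check of that bijection under the same layout assumptions (illustrative only):

```cpp
#include <cassert>
#include <cstdint>

// Validation layout [G, headDim, seqK]: valIdx = g*headDim*seqK + h*seqK + s
// Cache layout [G, seqK, headDim]: cacheIdx = g*seqK*headDim + s*headDim + h
int64_t cacheIdxFromValIdx(int64_t valIdx, int64_t headDim, int64_t seqK) {
  int64_t g = valIdx / (headDim * seqK);
  int64_t rem = valIdx % (headDim * seqK);
  int64_t h = rem / seqK;
  int64_t s = rem % seqK;
  return g * (seqK * headDim) + s * headDim + h;
}

int main() {
  const int64_t G = 2, headDim = 4, seqK = 3;
  // Every cache element must be visited exactly once across all valIdx.
  int64_t seen[G * headDim * seqK] = {};
  for (int64_t v = 0; v < G * headDim * seqK; ++v)
    ++seen[cacheIdxFromValIdx(v, headDim, seqK)];
  for (int64_t c : seen)
    assert(c == 1);
  return 0;
}
```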
@@ -5183,6 +5692,43 @@ static LogicalResult populateHostHarnessLogic(
     memref::DeallocOp::create(b, loc, lvar);
   }
 
+  // Deallocate the paged attention GPU and CPU cache buffers
+  if (isAttention && pagedAttention) {
+    // Deallocate the GPU cache buffers
+    if (keyCacheBuffer) {
+      auto tokenType = gpu::AsyncTokenType::get(context);
+      auto waitOp = gpu::WaitOp::create(b, loc, tokenType, ValueRange{});
+      Value token = waitOp.getAsyncToken();
+      auto deallocOp = gpu::DeallocOp::create(
+          b, loc, tokenType, ValueRange{token}, keyCacheBuffer);
+      gpu::WaitOp::create(b, loc, TypeRange{},
+                          ValueRange{deallocOp.getAsyncToken()});
+    }
+    if (valueCacheBuffer) {
+      auto tokenType = gpu::AsyncTokenType::get(context);
+      auto waitOp = gpu::WaitOp::create(b, loc, tokenType, ValueRange{});
+      Value token = waitOp.getAsyncToken();
+      auto deallocOp = gpu::DeallocOp::create(
+          b, loc, tokenType, ValueRange{token}, valueCacheBuffer);
+      gpu::WaitOp::create(b, loc, TypeRange{},
+                          ValueRange{deallocOp.getAsyncToken()});
+    }
+    // Deallocate the CPU cache buffers (logical order - for validation)
+    if (keyCacheCPU) {
+      memref::DeallocOp::create(b, loc, keyCacheCPU);
+    }
+    if (valueCacheCPU) {
+      memref::DeallocOp::create(b, loc, valueCacheCPU);
+    }
+    // Deallocate the shuffled CPU cache buffers (used for the GPU copy)
+    if (keyCacheShuffled) {
+      memref::DeallocOp::create(b, loc, keyCacheShuffled);
+    }
+    if (valueCacheShuffled) {
+      memref::DeallocOp::create(b, loc, valueCacheShuffled);
+    }
+  }
+
   func::ReturnOp::create(b, loc, ValueRange{});
 
   // Set of kernels