diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
index 130dd0c25a880..8929c6b7cf6e4 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -4,6 +4,7 @@
 #include "contrib_ops/webgpu/bert/attention.h"
 #include "contrib_ops/cpu/bert/multihead_attention_helper.h"
+#include "contrib_ops/webgpu/bert/flash_attention.h"
 #include "contrib_ops/webgpu/bert/multihead_attention.h"
 #include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
 #include "core/providers/webgpu/webgpu_supported_types.h"
@@ -736,6 +737,19 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context)
   // Compute Q, K, V from input, weights, and bias
   ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q, &K, &V));
 
+  // Check if we can use flash attention
+  // For Attention operator, we need to create present_key and present_value tensors for flash attention
+  // even though they are not exposed as outputs
+  TensorShapeVector present_kv_shape({parameters.batch_size_, parameters.num_heads_,
+                                      parameters.total_sequence_length_, parameters.head_size_});
+  Tensor present_key = context.CreateGPUTensor(input->DataType(), present_kv_shape);
+  Tensor present_value = context.CreateGPUTensor(input->DataType(), present_kv_shape);
+
+  if (CanApplyFlashAttention(nullptr, &present_key, &present_value, parameters, context)) {
+    return ApplyFlashAttention(&Q, &K, &V, attention_bias, output, nullptr, &present_key, nullptr, &present_value,
+                               parameters, context, nullptr);
+  }
+
   // Apply the actual attention computation
   return ApplyAttention(&Q, &K, &V, attention_bias, nullptr, nullptr, output,
                         /* present_key */ nullptr, /* present_value */ nullptr, /* output_qk */ nullptr, parameters, context, nullptr, nullptr, -1);
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
index 2a67dfdb07912..47a223f1bed28 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -76,7 +76,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
   } else {
     shader.MainFunctionBody() << "  let total_seq_length = uniforms.total_sequence_length;\n";
   }
-  shader.MainFunctionBody() << "let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n";
+  shader.MainFunctionBody() << "  let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n";
+  if (past_present_share_buffer_) {
+    shader.MainFunctionBody() << "  let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n";
+  } else {
+    shader.MainFunctionBody() << "  let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, sequence_id, head_size_id)") << ";\n";
+  }
 
   // Add indirect dispatch logic for thread 0
   if (prepare_indirect_dispatch_) {
@@ -93,8 +98,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
   if (has_past_) {
     const auto& past_key = shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
     shader.AddInput("past_value", ShaderUsage::UseUniform);
-    shader.MainFunctionBody() << "let present_offset = global_idx;"
-                              << "if (sequence_id < past_sequence_length) {\n"
+    shader.MainFunctionBody() << "if (sequence_id < past_sequence_length) {\n"
                               << "  let pastOffset = " << past_key.IndicesToOffset("past_key_indices_t(batch, num_head_id, sequence_id, head_size_id)") << ";\n"
                               << "  " << present_key.SetByOffset("present_offset", "past_key[pastOffset]") << ";\n"
                               << "  " << present_value.SetByOffset("present_offset", "past_value[pastOffset]") << ";\n"
@@ -104,8 +108,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
                               << "  " << present_value.SetByOffset("present_offset", "value[offset]") << ";\n"
                               << "}";
   } else {
-    shader.MainFunctionBody() << "  let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n"
-                              << "  let offset = " << key.IndicesToOffset(kv_BNSH_ ? "key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : "key_indices_t(batch, sequence_id, num_head_id, head_size_id)") << ";\n"
+    shader.MainFunctionBody() << "  let offset = " << key.IndicesToOffset(kv_BNSH_ ? "key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : "key_indices_t(batch, sequence_id, num_head_id, head_size_id)") << ";\n"
                               << "  " << present_key.SetByOffset("present_offset", "key[offset]") << ";\n"
                               << "  " << present_value.SetByOffset("present_offset", "value[offset]") << ";\n";
   }
@@ -134,10 +137,10 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
   // Determine if we need to prepare indirect dispatch
   bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
   bool use_seqlen_k = (seqlen_k != nullptr);
-
-  CopyKVCacheProgram program{"CopyKVCache", has_past, parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH,
+  bool kv_BNSH = parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH || parameters.qkv_format_ == Q_K_V_BNSH;
+  CopyKVCacheProgram program{"CopyKVCache", has_past, kv_BNSH, parameters.past_present_share_buffer_,
                              prepare_indirect_dispatch, use_seqlen_k};
-  if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) {
+  if (kv_BNSH) {
     program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
                        {V, ProgramTensorMetadataDependency::TypeAndRank, components}});
   } else {
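For reference, present_key and present_value are BNSH tensors ([batch, num_heads, total_sequence_length, head_size]), so the present_offset emitted above is the plain row-major flattening, shifted by past_sequence_length when past and present share a buffer. A minimal sketch of the same index math (hypothetical helper, not code from this patch):

```cpp
#include <cstdint>

// Row-major flattening of [batch, num_heads, seq_len, head_size] (BNSH),
// mirroring what present_key.IndicesToOffset(...) expands to in WGSL.
constexpr uint32_t BnshOffset(uint32_t b, uint32_t n, uint32_t s, uint32_t h,
                              uint32_t num_heads, uint32_t seq_len, uint32_t head_size) {
  return ((b * num_heads + n) * seq_len + s) * head_size + h;
}

// With past_present_share_buffer_, incoming KV rows land after the past:
//   s = past_sequence_length + sequence_id
// otherwise the present buffer starts fresh and s = sequence_id.
```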
@@ -207,6 +210,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
                                 WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_),
                                 WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_),
                                 WGSL_TEMPLATE_PARAMETER(prefer_subgroupshuffle, !is_nvidia_),
+                                WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_),
                                 WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_),
                                 WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_),
                                 WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_));
@@ -256,10 +260,20 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
                      {metadata, ProgramTensorMetadataDependency::Rank, 2}});
 
   const uint32_t vectorized_head_size = parameters.head_size_ / components;
+
+  // Get attention bias dimensions for broadcasting
+  uint32_t attn_bias_dim0 = 1;
+  uint32_t attn_bias_dim1 = 1;
+  if (has_attention_bias) {
+    const auto& bias_shape = attention_bias->Shape();
+    attn_bias_dim0 = static_cast<uint32_t>(bias_shape[0]);
+    attn_bias_dim1 = static_cast<uint32_t>(bias_shape[1]);
+  }
+
   if (use_indirect_dispatch) {
     program.SetIndirectDispatchTensor(indirect_buffer);
   } else {
-    program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile);
+    program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_total_seq_length_tile);
   }
   program.SetWorkgroupSize(64)
       .CacheHint(tile_size, has_attention_bias, use_indirect_dispatch)
@@ -269,7 +283,10 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
                             present_sequence_length,
                             {static_cast<uint32_t>(parameters.n_reps)},
                             {num_present_sequence_length_tile},
-                            {static_cast<uint32_t>(parameters.num_heads_)}});
+                            {static_cast<uint32_t>(parameters.num_heads_)},
+                            {static_cast<uint32_t>(parameters.batch_size_)},
+                            {attn_bias_dim0},
+                            {attn_bias_dim1}});
 
   return context.RunProgram(program);
 }
@@ -313,11 +330,12 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
                      {qk, ProgramTensorMetadataDependency::TypeAndRank},
                      {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
   program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}});  // [B, N, split_k, head_size]
+  const uint32_t batch_heads = static_cast<uint32_t>(parameters.batch_size_ * parameters.num_heads_);
   if (use_indirect_dispatch) {
     program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None})
         .SetIndirectDispatchTensor(indirect_buffer);
   } else {
-    program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile);
+    program.SetDispatchGroupSize(batch_heads * num_total_seq_length_tile);
   }
   program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch)
       .SetWorkgroupSize(64)
@@ -326,7 +344,7 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
                             present_sequence_length,
                             {static_cast<uint32_t>(parameters.n_reps)},
                             num_present_sequence_length_tile,
-                            {static_cast<uint32_t>(parameters.num_heads_)}});
+                            {batch_heads}});
 
   return context.RunProgram(program);
 }
@@ -363,14 +381,15 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&
   }
   program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}});
   const uint32_t num_head_size_tile = static_cast<uint32_t>((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size);
-  program.SetDispatchGroupSize(parameters.num_heads_ * num_head_size_tile)
+  const uint32_t batch_heads = static_cast<uint32_t>(parameters.batch_size_ * parameters.num_heads_);
+  program.SetDispatchGroupSize(batch_heads * num_head_size_tile)
       .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch)
       .SetWorkgroupSize(tile_size * tile_size)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.v_head_size_ / components)},
                             num_total_seq_length_tile,
                             num_present_sequence_length_tile,
                             {num_head_size_tile},
-                            {static_cast<uint32_t>(parameters.num_heads_)}});
+                            {batch_heads}});
 
   return context.RunProgram(program);
 }
@@ -429,6 +448,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
   bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
   bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
+  bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH;
   FlashAttentionProgram program{"FlashAttention",
                                 has_attention_bias,
                                 is_qualcomm,
@@ -437,6 +457,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
                                 parameters.num_heads_,
                                 parameters.is_unidirectional_,
                                 is_nvidia,
+                                q_BNSH,
                                 use_seqlen_k};
   program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
                      {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4},
@@ -451,15 +472,28 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
                                                 : parameters.scale_;
 
   const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
-  program.SetDispatchGroupSize(parameters.num_heads_ * num_seq_tile)
+
+  // Get attention bias dimensions for broadcasting
+  uint32_t attn_bias_dim0 = 1;
+  uint32_t attn_bias_dim1 = 1;
+  if (has_attention_bias) {
+    const auto& bias_shape = attention_bias->Shape();
+    attn_bias_dim0 = static_cast<uint32_t>(bias_shape[0]);
+    attn_bias_dim1 = static_cast<uint32_t>(bias_shape[1]);
+  }
+
+  program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile)
       .SetWorkgroupSize(tile_size)
-      .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, use_seqlen_k)
+      .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, q_BNSH, use_seqlen_k)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                             {static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<uint32_t>(present_sequence_length)},
+                            {static_cast<uint32_t>(parameters.batch_size_)},
                             {static_cast<uint32_t>(parameters.n_reps)},
                             {alpha},
-                            {num_seq_tile}});
+                            {num_seq_tile},
+                            {attn_bias_dim0},
+                            {attn_bias_dim1}});
 
   return context.RunProgram(program);
 }
@@ -500,8 +534,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
 bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
                             const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
-  return parameters.batch_size_ == 1 &&
-         !parameters.is_packed_qkv_ &&
+  return !parameters.is_packed_qkv_ &&
          parameters.head_size_ == parameters.v_head_size_ &&
          bias == nullptr &&
         context.HasFeature(wgpu::FeatureName::Subgroups) &&
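The new attn_bias_dim0/attn_bias_dim1 uniforms give the kernels numpy-style broadcasting over the first two dims of an attention bias shaped [d0, d1, S, T], where d0 is 1 or batch_size and d1 is 1 or num_heads; together with the batched dispatch this is what lets CanApplyFlashAttention drop the old batch_size_ == 1 restriction. A sketch of the lookup the WGSL select() lines implement (hypothetical standalone helper, assuming that shape):

```cpp
#include <cstdint>

// Broadcast-aware flat index into attention_bias[d0, d1, new_seq_len, total_seq_len].
// A dimension of size 1 is broadcast by clamping its index to 0, exactly like
// the WGSL `select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0)` lines.
uint32_t BiasOffset(uint32_t batch_idx, uint32_t head_idx, uint32_t q_idx, uint32_t k_idx,
                    uint32_t d0, uint32_t d1, uint32_t new_seq_len, uint32_t total_seq_len) {
  const uint32_t b = (batch_idx >= d0) ? 0u : batch_idx;  // d0 == 1 -> always batch 0
  const uint32_t n = (head_idx >= d1) ? 0u : head_idx;    // d1 == 1 -> always head 0
  return ((b * d1 + n) * new_seq_len + q_idx) * total_seq_len + k_idx;
}
```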
        use_seqlen_k_(use_seqlen_k) {
   }
 
@@ -90,9 +93,12 @@
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"total_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"batch_size", ProgramUniformVariableDataType::Uint32},
                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
                                           {"alpha", ProgramUniformVariableDataType::Float32},
-                                          {"num_seq_tile", ProgramUniformVariableDataType::Uint32});
+                                          {"num_seq_tile", ProgramUniformVariableDataType::Uint32},
+                                          {"attn_bias_dim0", ProgramUniformVariableDataType::Uint32},
+                                          {"attn_bias_dim1", ProgramUniformVariableDataType::Uint32});
 
  private:
   bool has_attention_bias_;
@@ -102,6 +108,7 @@
   int qkv_num_heads_;
   bool is_unidirectional_;
   bool is_nvidia_;
+  bool q_BNSH_;
   bool use_seqlen_k_;
 };
 
@@ -120,7 +127,10 @@ class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecod
                                           {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
                                           {"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
-                                          {"num_heads", ProgramUniformVariableDataType::Uint32});
+                                          {"num_heads", ProgramUniformVariableDataType::Uint32},
+                                          {"batch_size", ProgramUniformVariableDataType::Uint32},
+                                          {"attn_bias_dim0", ProgramUniformVariableDataType::Uint32},
+                                          {"attn_bias_dim1", ProgramUniformVariableDataType::Uint32});
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template
@@ ... @@
 var<workgroup> v_tile : array<array<q_value_t, head_size_vec>, max_k_step>;
 
 // Private memory per lane.
 var<private> q_tile : array<q_value_t, head_size_vec>;
 
-fn loadq(q_idx_global : u32, head_idx : u32, alpha : q_element_t) {
-  // Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA
-  // This is the layout if TransferBSDToBNSH has not been run.
-  let offset = q_idx_global * (head_size_vec)*num_heads + head_size_vec * head_idx;
-  // Stored as BNSH - which is what webgpu uses after TransferBSDToBNSH has been run.
-  // let offset = head_idx * uniforms.new_sequence_length * head_size_vec + q_idx_global * head_size_vec;
+fn loadq(batch_idx : u32, q_idx_global : u32, head_idx : u32, alpha : q_element_t) {
+#if q_BNSH
+  // Stored as BNSH - float16[batch_size, num_heads, sequence_length, head_size]
+  let offset = batch_idx * num_heads * uniforms.new_sequence_length * head_size_vec +
+               head_idx * uniforms.new_sequence_length * head_size_vec +
+               q_idx_global * head_size_vec;
+#else
+  // Stored as BSNH - float16[batch_size, sequence_length, num_heads, head_size]
+  let offset = batch_idx * uniforms.new_sequence_length * head_size_vec * num_heads +
+               q_idx_global * head_size_vec * num_heads +
+               head_idx * head_size_vec;
+#endif
   for (var idx : u32 = 0; idx < head_size_vec; idx++) {
     q_tile[idx] = q[idx + offset] * alpha;
   }
 }
 
-fn loadk(k_start : u32, head_idx : u32, local_idx : u32, k_step : u32) {
+fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32) {
   // Stored as float16[batch_size,num_heads,present_sequence_length,96]
-  let offset = head_idx * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec;
+  let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec +
+               k_start * head_size_vec;
   for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) {
     let slot = u32(idx / head_size_vec);
     let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < get_total_sequence_length());
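The q_BNSH template parameter selects between the two Q layouts loadq must handle; the two offsets differ only in which strides multiply the sequence and head indices. A worked comparison in vectorized elements (hypothetical helpers; N = num_heads, S = new_sequence_length, H = head_size_vec):

```cpp
#include <cstdint>

// Q stored as BNSH [batch, num_heads, seq, head] - after TransferBSDToBNSH.
constexpr uint32_t QRowOffsetBNSH(uint32_t b, uint32_t n, uint32_t s,
                                  uint32_t N, uint32_t S, uint32_t H) {
  return ((b * N + n) * S + s) * H;
}

// Q stored as BSNH [batch, seq, num_heads, head] - raw MHA input, no transpose.
constexpr uint32_t QRowOffsetBSNH(uint32_t b, uint32_t n, uint32_t s,
                                  uint32_t N, uint32_t S, uint32_t H) {
  return ((b * S + s) * N + n) * H;
}
```

Both expand to exactly the sums loadq computes under #if q_BNSH / #else above.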
@@ -63,9 +71,10 @@ fn loadk(k_start : u32, head_idx : u32, local_idx : u32, k_step : u32) {
   }
 }
 
-fn loadv(v_start : u32, head_idx : u32, local_idx : u32, v_step : u32) {
+fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32) {
   // Stored as float16[batch_size,num_heads,present_sequence_length,96]
-  let offset = head_idx * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec;
+  let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec +
+               v_start * head_size_vec;
   for (var idx : u32 = local_idx; idx < head_size_vec * v_step; idx += workgroup_size_x) {
     let slot = u32(idx / head_size_vec);
     let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < get_total_sequence_length());
@@ -83,9 +92,10 @@ var<workgroup> o_tile_r : array<array<q_value_t, half_head_size_vec>, workgroup_
 // Private memory per lane.
 var<private> o_tile : array<q_value_t, half_head_size_vec>;
 
-fn writeo(o_idx_global : u32, head_idx : u32, local_idx : u32) {
+fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32, local_idx : u32) {
   // Stored as float16[batch_size,sequence_length,3072]
-  let offset = o_idx_global * num_heads * head_size_vec + head_idx * head_size_vec;
+  let offset = batch_idx * uniforms.new_sequence_length * num_heads * head_size_vec +
+               o_idx_global * num_heads * head_size_vec + head_idx * head_size_vec;
   for (var idx : u32 = 0; idx < half_head_size_vec; idx++) {
     output[offset + idx] = o_tile[idx];
     output[offset + idx + half_head_size_vec] = o_tile_r[local_idx][idx];
@@ -94,9 +104,10 @@ fn writeo(o_idx_global : u32, head_idx : u32, local_idx : u32) {
 #else
 // Private memory per lane.
 var<private> o_tile : array<q_value_t, head_size_vec>;
 
-fn writeo(o_idx_global : u32, head_idx : u32) {
+fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32) {
   // Stored as float16[batch_size,sequence_length,3072]
-  let offset = o_idx_global * num_heads * head_size_vec + head_idx * head_size_vec;
+  let offset = batch_idx * uniforms.new_sequence_length * num_heads * head_size_vec +
+               o_idx_global * num_heads * head_size_vec + head_idx * head_size_vec;
   for (var idx : u32 = 0; idx < head_size_vec; idx++) {
     output[offset + idx] = o_tile[idx];
   }
@@ -104,12 +115,17 @@ fn writeo(o_idx_global : u32, head_idx : u32) {
 #endif
 
 #if has_attention_bias
-fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4<q_element_t> {
+fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4<q_element_t> {
   // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length]
-  if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= get_total_sequence_length()) {
+  if (k_idx_global >= get_total_sequence_length()) {
     return vec4<q_element_t>(0);
   }
-  let offset_base = head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length();
+  // Handle broadcasting: if dimension size is 1, use index 0
+  let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0);
+  let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1);
+
+  let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * get_total_sequence_length() +
+                    bias_head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length();
   let offset = offset_base + k_idx_global;
   let offset_max = offset_base + get_total_sequence_length();
   let c1 = q_element_t(attention_bias[min(offset, offset_max)]);
@@ -119,7 +135,7 @@ fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx : u32) ->
   return vec4<q_element_t>(c1, c2, c3, c4);
 }
 #else
-fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4<q_element_t> {
+fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4<q_element_t> {
   return vec4<q_element_t>(0);
 }
 #endif
@@ -141,15 +157,21 @@ fn fetchVTile(k_idx: u32, vec_idx: u32, v_val: q_value_t) -> q_value_t {
 }
 
 $MAIN {
-  let head_idx = u32(workgroup_idx / uniforms.num_seq_tile);
+  let batch_head_idx = u32(workgroup_idx / uniforms.num_seq_tile);
+  let head_idx = batch_head_idx % num_heads;
+  let batch_idx = batch_head_idx / num_heads;
   let capped_sg_id = min(sg_id, max_k_step - 1u);
   let capped_sg_size = min(sg_size, max_k_step);
 
+  if (batch_idx >= uniforms.batch_size) {
+    return;
+  }
+
   // Load Q
   let q_idx_global = (workgroup_idx % uniforms.num_seq_tile) * workgroup_size_x + local_idx;
   let valid_q = q_idx_global < uniforms.new_sequence_length;
   if (valid_q) {
-    loadq(q_idx_global, head_idx, q_element_t(uniforms.alpha));
+    loadq(batch_idx, q_idx_global, head_idx, q_element_t(uniforms.alpha));
   }
 
   var previous_max : q_element_t = min_value;
@@ -170,8 +192,8 @@ $MAIN {
   for (var k_start = 0u; k_start < loop_bound; k_start += capped_sg_size) {
     workgroupBarrier();
-    loadk(k_start, head_idx / uniforms.n_reps, local_idx, capped_sg_size);
-    loadv(k_start, head_idx / uniforms.n_reps, local_idx, capped_sg_size);
+    loadk(k_start, batch_head_idx, local_idx, capped_sg_size);
+    loadv(k_start, batch_head_idx, local_idx, capped_sg_size);
     workgroupBarrier();
 
     // Compute QKt
@@ -229,11 +251,11 @@ $MAIN {
         qk_2[3] += dot(q_own, fetchKTile(7, i, k_local));
       }
     }
-    qk_1 = qk_1 + loadAttentionBias(q_idx_global, k_start, head_idx);
-    qk_2 = qk_2 + loadAttentionBias(q_idx_global, k_start + 4, head_idx);
+    qk_1 = qk_1 + loadAttentionBias(batch_idx, q_idx_global, k_start, head_idx);
+    qk_2 = qk_2 + loadAttentionBias(batch_idx, q_idx_global, k_start + 4, head_idx);
     if (sg_size > 8) {
-      qk_3 = qk_3 + loadAttentionBias(q_idx_global, k_start + 8, head_idx);
-      qk_4 = qk_4 + loadAttentionBias(q_idx_global, k_start + 12, head_idx);
+      qk_3 = qk_3 + loadAttentionBias(batch_idx, q_idx_global, k_start + 8, head_idx);
+      qk_4 = qk_4 + loadAttentionBias(batch_idx, q_idx_global, k_start + 12, head_idx);
     }
 
     // Neuter qk values where K is out of bounds.
@@ -360,7 +382,7 @@ $MAIN {
   }
 
   if (valid_q) {
-    writeo(q_idx_global, head_idx, local_idx);
+    writeo(batch_idx, q_idx_global, head_idx, local_idx);
   }
 #else
   if (sg_size > 8) {
@@ -409,7 +431,7 @@ $MAIN {
   }
 
   if (valid_q) {
-    writeo(q_idx_global, head_idx);
+    writeo(batch_idx, q_idx_global, head_idx);
   }
 #endif
 }  // MAIN
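Dispatch is now batch_size * num_heads * num_seq_tile workgroups, and $MAIN recovers (batch_idx, head_idx) from the combined batch_head_idx. One subtlety worth spelling out: loadk/loadv divide the combined index by n_reps to find the shared KV head, which is only valid because num_heads is an exact multiple of n_reps, so the batch term survives the integer division intact. A quick self-contained check with hypothetical values:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t num_kv_heads = 4;
  const uint32_t n_reps = 3;                         // GQA: query heads per KV head
  const uint32_t num_heads = num_kv_heads * n_reps;  // invariant the shader relies on
  for (uint32_t batch = 0; batch < 2; ++batch) {
    for (uint32_t head = 0; head < num_heads; ++head) {
      const uint32_t batch_head_idx = batch * num_heads + head;
      // batch * num_heads is an exact multiple of n_reps, so dividing the
      // combined index cleanly yields the combined (batch, kv_head) index.
      assert(batch_head_idx / n_reps == batch * num_kv_heads + head / n_reps);
    }
  }
  return 0;
}
```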
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template
index ac9a157492007..e7944231f342e 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template
@@ -35,12 +35,22 @@ var<workgroup> inner_qk_values: array<array<q_element_t, tile_size_k_vec>, tile_
 var<workgroup> tile_qk: array<q_element_t, tile_size>;
 
 #if has_attention_bias
-  fn loadAttentionBias(idx: u32) -> q_element_t
+  fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t
   {
-    return attention_bias[idx];
+    // Handle broadcasting: if dimension size is 1, use index 0
+    let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0);
+    let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1);
+
+    // Calculate flat offset with broadcasting applied
+    // attention_bias shape: [attn_bias_dim0, attn_bias_dim1, new_seq_length, total_seq_length]
+    // For decode, new_seq_length is 1, so we can simplify:
+    let offset = bias_batch_idx * uniforms.attn_bias_dim1 * total_seq_length +
+                 bias_head_idx * total_seq_length +
+                 k_idx;
+    return attention_bias[offset];
   }
 #else
-  fn loadAttentionBias(idx: u32) -> q_element_t
+  fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t
   {
     return q_element_t(0);
   }
@@ -56,9 +66,14 @@ $MAIN {
 #endif
   let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
   let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
-  let head_idx = u32(workgroup_idx / num_total_seq_length_tile);
-  let q_offset = head_idx * uniforms.head_size_vec;
-  let present_offset = u32(head_idx / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.head_size_vec;
+  let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile);
+  let head_idx = batch_head_idx % uniforms.num_heads;
+  let batch_idx = batch_head_idx / uniforms.num_heads;
+  if (batch_idx >= uniforms.batch_size) {
+    return;
+  }
+  let q_offset = batch_idx * uniforms.num_heads * uniforms.head_size_vec + head_idx * uniforms.head_size_vec;
+  let present_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec;
   for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) {
     if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) {
       tile_q[local_idx] = q[q_offset + k + local_idx];
@@ -75,22 +90,18 @@ $MAIN {
     workgroupBarrier();
   }
 
-  if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length && head_idx < uniforms.num_heads) {
+  if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) {
     var sum = q_element_t(0);
     for (var i = 0u; i < tile_size_k_vec; i++) {
       sum += inner_qk_values[local_idx][i];
     }
-    sum = sum + loadAttentionBias(head_idx * total_sequence_length + total_seq_offset + local_idx);
+    sum = sum + loadAttentionBias(batch_idx, head_idx, 0u, total_seq_offset + local_idx, total_sequence_length);
     tile_qk[local_idx] = sum;
-    output[head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx] = sum;
+    output[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx] = sum;
   }
 
   workgroupBarrier();
 
-  if (head_idx >= uniforms.num_heads) {
-    return;
-  }
-
   if (local_idx == 0u) {
     // Calculate the max and sum in current split.
     var l_max = f32(-3.4028234663852886e+38f);
@@ -101,7 +112,7 @@ $MAIN {
     for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
       l_sum += exp(f32(tile_qk[i]) - l_max);
     }
-    let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile;
+    let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile;
     metadata[meta_offset] = metadata_value_t(l_max, l_sum);
   }
 }
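The decode templates implement a split-k (flash-decoding) softmax: each total-sequence tile writes its local (max, sum) pair into metadata at batch_head_idx * num_present_sequence_length_tile + tile, and the split-Vx pass folds those partials into global statistics before normalizing. The merge rule, sketched on the CPU (hypothetical helper, same arithmetic as the g_max/g_sum loops below):

```cpp
#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>

// Merge per-tile softmax stats: m_i is a tile's max logit, s_i its sum of
// exp(logit - m_i). Rescaling each partial sum by exp(m_i - g_max) makes all
// tiles share the global max, so g_sum is the true softmax denominator.
std::pair<float, float> MergeSoftmaxStats(const std::vector<std::pair<float, float>>& tiles) {
  float g_max = -3.4028234663852886e+38f;  // same lowest-float sentinel as the WGSL
  for (const auto& t : tiles) g_max = std::max(g_max, t.first);
  float g_sum = 0.0f;
  for (const auto& t : tiles) g_sum += std::exp(t.first - g_max) * t.second;
  return {g_max, g_sum};
}
```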
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template
index a113e96130985..8139477172b03 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template
@@ -48,30 +48,31 @@ $MAIN {
 #endif
   let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
   let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
-  let head_idx = u32(workgroup_idx / num_total_seq_length_tile);
-  let present_offset = u32(head_idx / uniforms.n_reps) * head_size_vec * uniforms.present_sequence_length;
+  let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile);
+  if (batch_head_idx >= uniforms.batch_heads) {
+    return;
+  }
+  let present_offset = u32(batch_head_idx / uniforms.n_reps) * head_size_vec * uniforms.present_sequence_length;
 
   // Calculate the global max and sum in qk.
-  if (head_idx < uniforms.num_heads)
+  var g_max = f32(-3.4028234663852886e+38f);
+  var g_sum = f32(0);
+  for (var i = 0u; i < num_total_seq_length_tile; i++)
   {
-    var g_max = f32(-3.4028234663852886e+38f);
-    var g_sum = f32(0);
-    for (var i = 0u; i < num_total_seq_length_tile; i++)
-    {
-      let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i;
-      g_max = max(g_max, metadata[meta_offset].x);
-    }
-    for (var i = 0u; i < num_total_seq_length_tile; i++)
-    {
-      let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i;
-      let m_value = metadata[meta_offset];
-      g_sum += exp(m_value.x - g_max) * m_value.y;
-    }
+    let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i;
+    g_max = max(g_max, metadata[meta_offset].x);
+  }
+  for (var i = 0u; i < num_total_seq_length_tile; i++)
+  {
+    let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i;
+    let m_value = metadata[meta_offset];
+    g_sum += exp(m_value.x - g_max) * m_value.y;
+  }
   if (total_seq_offset + local_idx < total_sequence_length) {
-      tile_qk[local_idx] = present_value_element_t(exp(f32(qk[head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum);
+    tile_qk[local_idx] = present_value_element_t(exp(f32(qk[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum);
   }
-  }
+
   for (var k: u32 = 0u; k < head_size_vec; k += tile_size_k_vec) {
     var value = present_value_value_t(0);
     qkv_values[local_row][local_col] = present_value_value_t(0);
@@ -96,12 +97,8 @@ $MAIN {
     workgroupBarrier();
   }
 
-  if (head_idx >= uniforms.num_heads) {
-    return;
-  }
-
   for (var i = local_idx; i < head_size_vec; i += workgroup_size_x) {
-    let out_offset = head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % num_total_seq_length_tile) * head_size_vec + i;
+    let out_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % num_total_seq_length_tile) * head_size_vec + i;
     out_split_vx[out_offset] = tile_output[i];
   }
 }
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template
index 22f18655307de..f909a87724da6 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template
@@ -20,8 +20,11 @@ var<workgroup> tile_input: array<array<output_value_t, tile_size>, tile_size>;
 
 $MAIN {
   let head_size_offset = (workgroup_idx % uniforms.num_head_size_tile) * tile_size;
-  let head_idx = u32(workgroup_idx / uniforms.num_head_size_tile);
-  let in_offset = head_idx * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec;
+  let batch_head_idx = u32(workgroup_idx / uniforms.num_head_size_tile);
+  if (batch_head_idx >= uniforms.batch_heads) {
+    return;
+  }
+  let in_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec;
   var value = output_value_t(0);
   let local_row = u32(local_idx / tile_size);
   let local_col = local_idx % tile_size;
@@ -43,16 +46,12 @@ $MAIN {
   tile_input[local_row][local_col] = value;
   workgroupBarrier();
 
-  if (head_idx >= uniforms.num_heads) {
-    return;
-  }
-
   if (local_idx < tile_size && head_size_offset + local_idx < uniforms.head_size_vec) {
     value = output_value_t(0);
     for (var i = 0u; i < tile_size; i++) {
       value += tile_input[i][local_idx];
     }
-    let output_id = head_idx * uniforms.head_size_vec + head_size_offset + local_idx;
+    let output_id = batch_head_idx * uniforms.head_size_vec + head_size_offset + local_idx;
     output[output_id] = value;
   }
 }
diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
index 416a895e61745..7ca61008be83f 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
@@ -287,7 +287,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
     // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking
     WebgpuAttentionParameters temp_params = parameters;
     temp_params.is_packed_qkv_ = false;
-    will_use_flash_attention = CanApplyFlashAttention(attention_bias, present_key, present_value, temp_params, context);
+    will_use_flash_attention = CanApplyFlashAttention(nullptr, present_key, present_value, temp_params, context);
   }
 
   if (parameters.is_packed_qkv_ && do_rotary_) {
diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc
index 8ef229b9ef69f..f2996da4bd29e 100644
--- a/onnxruntime/core/providers/webgpu/program_manager.cc
+++ b/onnxruntime/core/providers/webgpu/program_manager.cc
@@ -119,6 +119,7 @@ Status ProgramManager::Build(const ProgramBase& program,
 
   wgpu::ShaderModuleDescriptor descriptor{};
   descriptor.nextInChain = &wgsl_source;
+  descriptor.label = program.Name().c_str();
 
   auto shader_module = device.CreateShaderModule(&descriptor);