[ROCm][MXFP4] Enable FP4 MLA BMM support #30177
@@ -4,8 +4,10 @@
from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Any

import torch
import regex as re
from torch import nn
from aiter.ops.triton.quant import dynamic_mxfp4_quant


def deep_compare(dict1: Any, dict2: Any) -> bool:
@@ -103,3 +105,97 @@
        elif target == value:
            return True
    return False


# utility for tensor dims > 2 cases
def b_dynamic_mxfp4_quant(x):
    h, b, d = x.shape
    x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
    return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)
    # return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)
Collaborator
NITS: let's clean up these
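For reference, a quick shape sketch of the helper above (illustrative only; the shapes follow directly from the view calls, i.e. two fp4 values packed per byte and one e8m0 scale per 32 elements):

    # x: (h, b, d) bf16 -> packed: (h, b, d // 2) uint8, scales: (h, b, d // 32) uint8
    w = torch.randn(16, 512, 128, dtype=torch.bfloat16, device="cuda")
    w_packed, w_scales = b_dynamic_mxfp4_quant(w)
    assert w_packed.shape == (16, 512, 64) and w_scales.shape == (16, 512, 4)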


def mxfp4_to_f32(x, is_threed):
    # 2 because we pack fp4 in uint8.
    x = x.repeat_interleave(2, dim=-1)
    if is_threed:
        x[..., ::2] = x[..., ::2] & 0xF
        x[..., 1::2] = x[..., 1::2] >> 4
    else:
        x[:, ::2] = x[:, ::2] & 0xF
        x[:, 1::2] = x[:, 1::2] >> 4

    mxfp4_list = [
        0.0,
        0.5,
        1.0,
        1.5,
        2.0,
        3.0,
        4.0,
        6.0,
        -0.0,
        -0.5,
        -1.0,
        -1.5,
        -2.0,
        -3.0,
        -4.0,
        -6.0,
    ]
    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device="cuda")
    return mxfp4_in_f32[x.long()]
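To make the packing convention concrete, a small worked example (my reading of the code above, not an authoritative spec): each uint8 holds two fp4 values, low nibble first. For a byte 0x92, the low nibble 0x2 indexes 1.0 in the table and the high nibble 0x9 indexes -0.5:

    packed = torch.tensor([[0x92]], dtype=torch.uint8, device="cuda")
    vals = mxfp4_to_f32(packed, is_threed=False)  # -> tensor([[1.0, -0.5]], device='cuda:0')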


def e8m0_to_f32(x):
    # Convert the input tensor `x` (assumed to be in e8m0 format) to float32.
    # e8m0 is a custom 8-bit floating point format with 8 bits for exponent, 0 for mantissa.
    # This means the value is essentially 2^(exponent - 127), similar to how IEEE-754 stores floats.

    # Convert x to float32 for computation, and compute the power of 2 by subtracting the bias (127).
    x_f32 = 2 ** ((x.to(torch.float32)) - 127)

    # If the exponent value was 255 (i.e., 2^(128)), this is a special case usually used to represent NaN or Inf.
    # Since this custom format has no mantissa, treat 2^128 as NaN.
    x_f32[x_f32 == 128] = float("nan")
Contributor
The condition to detect NaN values appears to be incorrect. The comment at line 158 indicates that an exponent value of 255 should be treated as NaN. However, the code checks if the computed float value x_f32 equals 128, instead of checking whether the raw exponent x is 255.
Suggested change
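(A minimal sketch of the likely intended fix, assuming the reserved value should be detected on the raw e8m0 input rather than on the converted float:)

    # Detect the reserved exponent (255) on the raw e8m0 input, not on the converted value.
    x_f32[x == 255] = float("nan")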

    return x_f32

def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
    if "mxfp4" in quant_format:
        # When dtype is bf16, the flow is to dynamically quantize the bf16 tensor to uint8:
        # process w_kc (bf16) first to get w_kc (uint8) and w_s_kc (uint8),
        # then repeat the same procedure for w_vc to get w_vc (uint8) and w_s_vc (uint8).
        if w.dtype == torch.bfloat16:
            # w_kc, w_vc = w.split(
            #     [self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
            w_kc = w_kc.transpose(-2, -1)
            w_s_kc = w_s_kc.transpose(-2, -1)
            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
        elif w.dtype == torch.uint8:  # static quant for mxfp4
            # When dtype is uint8, w has already been quantized to mxfp4 format,
            # but we still must separate it into w_kc and w_vc.
            # The quantized tensor is only half the original size and the scale tensor is
            # 1/32 of it, so transposing the packed data directly would not be correct.
            # Upcast to fp32 first so w can be split into w_kc and w_vc and the following
            # transposes behave correctly, then do the mxfp4 quant again.
            w = mxfp4_to_f32(w, True).to(torch.bfloat16)
            w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
            w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
            w = w * w_scales
            w_kc, w_vc = w.unflatten(
                0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim))
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
            w_kc = w_kc.transpose(-2, -1)
            w_s_kc = w_s_kc.transpose(-2, -1)
            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
            w_s_vc = w_s_vc.contiguous().transpose(1, 2)

    return w_kc, w_s_kc, w_vc, w_s_vc
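For orientation, a shape walk-through of the bf16 path (symbolic; H, P, V, L stand for num_heads, qk_nope_head_dim, v_head_dim, and kv_lora_rank, and the caller is assumed to pass the kv_b_proj weight of shape (H * (P + V), L)):

    # w:     (H * (P + V), L)  -> unflatten(0, ...) -> (H, P + V, L)
    # w_kc:  (H, P, L),  w_vc: (H, V, L)             # split on dim=1
    # b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1)):
    #   packed fp4: (H, L, P // 2), e8m0 scales: (H, L, P // 32), both transposed back afterwards
    # b_dynamic_mxfp4_quant(w_vc):
    #   packed fp4: (H, V, L // 2), e8m0 scales: (H, V, L // 32)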
@@ -1139,6 +1139,7 @@
        self.indexer = indexer
        self.q_pad_num_heads = q_pad_num_heads
        self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
        self.is_aiter_triton_fp4_bmm_enabled = rocm_aiter_ops.is_fp4bmm_enabled()

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        def get_layer_weight(layer):
@@ -1237,31 +1238,54 @@
    def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)

        if self.is_aiter_triton_fp8_bmm_enabled:
        if self.is_aiter_triton_fp4_bmm_enabled:
            from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4

            #print(f'>>> x pre (up_proj) {x.shape}')
Collaborator
NITS: let's clean up these
            out = out.view(-1, self.num_heads, self.v_head_dim)
            # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
            x = rocm_aiter_ops.triton_fp8_bmm(
                x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
            x = x.view(-1, self.num_heads, self.kv_lora_rank)
            x = x.transpose(0, 1)

            #print(f'>>> x {x.shape}, attn_bmm_output {attn_bmm_output.shape}, self.W_V {self.W_V.shape}')
Collaborator
NITS: let's clean up these
            out = batched_gemm_a16wfp4(
                x,
                self.W_V,
                self.W_V_scale,
                y=out,
                transpose_bm=True,
                prequant=True,
                y_scale=None,
            )
            #print(f'>>> x before transpose {x.shape}')
Collaborator
NITS: let's clean up these
            out = out.view(-1, self.num_heads * self.v_head_dim)
            x = out
        else:
            # Convert from (B, N * V) to (N, B, V)
            out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
            x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)

            if self.is_aiter_triton_fp8_bmm_enabled:
                out = out.view(-1, self.num_heads, self.v_head_dim)
                # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
                x = rocm_aiter_ops.triton_fp8_bmm(
                    x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
                )
            else:
                # Convert from (B, N * V) to (N, B, V)
                out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)

                # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
                torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"
                # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
                torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"

                # Convert from (N, B, V) to (B, N * V)
                out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
                # Convert from (N, B, V) to (B, N * V)
                out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)

                # Adjust output buffer shape back to the original (B, N * V)
                N, B, V = out.shape
                out.resize_((B, N * V))
                out.copy_(out_new)  # Copy result

                # Adjust output buffer shape back to the original (B, N * V)
                N, B, V = out.shape
                out.resize_((B, N * V))
                out.copy_(out_new)  # Copy result
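For readers unfamiliar with the a16wfp4 kernel, here is a reference-only sketch of the numerics it is expected to implement (this is an assumption, not the AITER implementation; the real batched_gemm_a16wfp4 fuses the dequant and may use a different weight layout): a bf16 activation is multiplied by an mxfp4 weight that is dequantized on the fly with its e8m0 scales, reusing the helpers added in quark/utils.py above.

    def batched_a16_wfp4_reference(x_bf16, w_fp4_packed, w_scales_e8m0):
        # x_bf16:        (N, B, L) bf16 activations
        # w_fp4_packed:  (N, L, V // 2) uint8, two fp4 values packed per byte along the last dim
        # w_scales_e8m0: (N, L, V // 32) uint8, one e8m0 scale per group of 32 fp4 values
        w = mxfp4_to_f32(w_fp4_packed, True).to(torch.bfloat16)                         # (N, L, V)
        s = e8m0_to_f32(w_scales_e8m0.repeat_interleave(32, dim=-1)).to(torch.bfloat16)
        return torch.bmm(x_bf16, w * s)                                                 # (N, B, V)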


class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class

@@ -1574,29 +1598,42 @@
            return dequant_weights.T
        return layer.weight

        # we currently do not have quantized bmm's which are needed for
        # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
        # the bmm's in 16-bit, the extra memory overhead of this is fairly low
        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj)

        # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
        if self.is_aiter_triton_fp4_bmm_enabled:
            from vllm.model_executor.layers.quantization.quark.utils import quark_post_load_weights

            self.W_K, self.W_K_scale, W_V, self.W_V_scale = (
                quark_post_load_weights(self, kv_b_proj_weight, "mxfp4"))
            self.W_V = W_V.contiguous().transpose(1, 2)

            self.W_K = self.W_K.transpose(-2, -1).contiguous()
            self.W_K_scale = self.W_K_scale.transpose(-2, -1).contiguous()
            self.W_V = self.W_V.transpose(-2, -1).contiguous()
            self.W_V_scale = self.W_V_scale.transpose(-2, -1).contiguous()
            return

        # If kv_b_proj_weight is not being quantized to mxfp4, take the default path,
        # which is to dequantize and transpose kv_b_proj_weight.
        kv_b_proj_weight = kv_b_proj_weight.T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
        ), (
            f"{kv_b_proj_weight.shape=}, "
            f"{self.kv_lora_rank=}, "
            f"{self.num_heads=}, "
            f"{self.qk_nope_head_dim=}, "
            f"{self.v_head_dim=}"
        )
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
                f"{kv_b_proj_weight.shape=}, "
                f"{self.kv_lora_rank=}, "
                f"{self.num_heads=}, "
                f"{self.qk_nope_head_dim=}, "
                f"{self.v_head_dim=}")
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )

        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )

        if self.is_aiter_triton_fp8_bmm_enabled:
            W_K = W_UK.transpose(0, 1)  # 16 512 128
@@ -1987,8 +2024,7 @@
            decode_q_nope, decode_q_pe = decode_q.split(
                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
            )

            # Convert from (B, N, P) to (N, B, P)
            # Convert from (B, N, P) to (N, B, P)
            decode_q_nope = decode_q_nope.transpose(0, 1)

            if self.q_pad_num_heads is not None:

@@ -1998,7 +2034,26 @@
                decode_pe_padded.copy_(decode_q_pe)
                decode_q_pe = decode_pe_padded

            if self.is_aiter_triton_fp8_bmm_enabled:
            if self.is_aiter_triton_fp4_bmm_enabled:
                from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4
Collaborator
It seems this kernel does not exist in the AITER commit that is currently used in the
|
|
||
| #x = x.view(-1, self.num_heads, self.kv_lora_rank) | ||
|
Collaborator
NITS: let's clean up these
                x = decode_q_nope.transpose(0, 1)
                decode_ql_nope = None
                #print(f'>>> x {x.shape}, q_nope_out {q_nope_out.shape}, self.W_K {self.W_K.shape}')
Collaborator
NITS: let's clean up these
|
|
||
| decode_ql_nope = batched_gemm_a16wfp4( | ||
| x, | ||
| self.W_K, | ||
| self.W_K_scale, | ||
| y=decode_ql_nope, | ||
| transpose_bm=True, | ||
| prequant=True, | ||
| y_scale=layer._q_scale if fp8_attention else None, | ||
| ) | ||
| elif self.is_aiter_triton_fp8_bmm_enabled: | ||
| # Convert from (B, N, P) to (N, B, P) | ||
| decode_q_nope = decode_q_nope.transpose(0, 1) | ||
| # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) | ||
| decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm( | ||
| decode_q_nope, | ||
|
|
@@ -2026,7 +2081,7 @@
                # Convert from (N, B, L) to (B, N, L)
                decode_ql_nope = decode_ql_nope.transpose(0, 1)

            if fp8_attention:
            if fp8_attention and not self.is_aiter_triton_fp4_bmm_enabled:
                ql_nope_shape = decode_ql_nope.shape
                decode_ql_nope, _ = ops.scaled_fp8_quant(
                    decode_ql_nope.reshape(
Let's not add new flags, since they are all linear ops (GEMM related). Please consolidate VLLM_ROCM_USE_AITER_FP8BMM and VLLM_ROCM_USE_AITER_FP4BMM into using VLLM_ROCM_USE_AITER_LINEAR to control. We can consolidate VLLM_ROCM_USE_AITER_FP4_ASM_GEMM into VLLM_ROCM_USE_AITER_LINEAR in another PR.

Let's infer the conditions based on model weights and model config, and store those evaluated conditions as class properties rather than depending on the flags.
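(A rough illustration of that suggestion, with hypothetical names; whether a single helper like this exists and exactly what it should check is an open design question for this PR:)

    # Hypothetical sketch: decide the MLA BMM path once from the loaded kv_b_proj weights /
    # quant config and cache the result, instead of reading per-kernel env flags.
    def _select_mla_bmm_path(self) -> str:
        quant_desc = str(getattr(self.kv_b_proj, "quant_method", ""))
        if "mxfp4" in quant_desc or self.kv_b_proj.weight.dtype == torch.uint8:
            return "fp4_bmm"
        if self.kv_b_proj.weight.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
            return "fp8_bmm"
        return "bf16_bmm"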
Please provide the perf gain of enabling FP4 MLA BMM support, and also the lm_eval scores for FP4 MLA BMM and FP8 MLA BMM, since we are consolidating the logic of both FP4 MLA BMM and FP8 MLA BMM in this PR.