8 changes: 8 additions & 0 deletions vllm/_aiter_ops.py
@@ -94,6 +94,7 @@ def _rocm_aiter_fused_moe_impl(
activation = ActivationType(activation_method)
quant_type = QuantType(quant_method)


return fused_moe(
hidden_states,
w1,
@@ -481,6 +482,7 @@ class rocm_aiter_ops:
_MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
_FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
_TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
_MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
@@ -550,6 +552,12 @@ def is_triton_unified_attn_enabled(cls) -> bool:
def is_fp8bmm_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._FP8BMM_ENABLED

@classmethod
@if_aiter_supported
def is_fp4bmm_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and current_platform.supports_mx()

@classmethod
@if_aiter_supported
def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -117,6 +117,7 @@
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_FP4BMM: bool = True
Collaborator:
Let's not add new flags, since these are all linear (GEMM-related) ops.
Please consolidate VLLM_ROCM_USE_AITER_FP8BMM and VLLM_ROCM_USE_AITER_FP4BMM so that they are both controlled by VLLM_ROCM_USE_AITER_LINEAR.

We can consolidate VLLM_ROCM_USE_AITER_FP4_ASM_GEMM into VLLM_ROCM_USE_AITER_LINEAR in another PR.
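A rough, self-contained sketch of what consolidating the FP8/FP4 BMM toggles under VLLM_ROCM_USE_AITER_LINEAR could look like. The flag name comes from this comment; the class, helper, and defaults below are placeholders for discussion, not vLLM's actual implementation:

import os


def _env_flag(name: str, default: str = "True") -> bool:
    return os.getenv(name, default).lower() in ("true", "1")


class AiterLinearFlags:
    # Placeholders standing in for rocm_aiter_ops / current_platform so the
    # sketch runs on its own.
    _AITER_ENABLED = _env_flag("VLLM_ROCM_USE_AITER", "False")
    _LINEAR_ENABLED = _env_flag("VLLM_ROCM_USE_AITER_LINEAR")

    @staticmethod
    def _supports_mx() -> bool:
        return True  # stand-in for current_platform.supports_mx()

    @classmethod
    def is_fp8bmm_enabled(cls) -> bool:
        # FP8 BMM keys off the single consolidated linear-ops flag.
        return cls._AITER_ENABLED and cls._LINEAR_ENABLED

    @classmethod
    def is_fp4bmm_enabled(cls) -> bool:
        # FP4 BMM additionally requires MX (fp4) hardware support.
        return cls._AITER_ENABLED and cls._LINEAR_ENABLED and cls._supports_mx()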

Collaborator:
Let's infer the conditions based on model weights, model config, and store those evaluated conditions as class properties rather than depending on the flags.

Collaborator:
Please provide the perf gain from enabling FP4 MLA BMM support, and also the lm_eval scores for FP4 MLA BMM and FP8 MLA BMM, since we are consolidating the logic of both FP4 MLA BMM and FP8 MLA BMM in this PR.

VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False
VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True
@@ -981,6 +982,11 @@ def get_vllm_port() -> int | None:
"VLLM_ROCM_USE_AITER_FP8BMM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1")
),
# Whether to use the aiter Triton fp4 bmm kernel.
# Enabled by default; only used if the layers are unquantized.
"VLLM_ROCM_USE_AITER_FP4BMM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_FP4BMM", "True").lower() in ("true", "1")
),
# Use AITER triton unified attention for V1 attention
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower()
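The new VLLM_ROCM_USE_AITER_FP4BMM entry follows the same string-to-bool convention as the other toggles in this dict; a tiny illustration (the helper name here is made up):

def _parses_as_enabled(value: str) -> bool:
    # Anything other than "true" / "1" (case-insensitive) disables the kernel.
    return value.lower() in ("true", "1")


assert _parses_as_enabled("True") and _parses_as_enabled("1")
assert not _parses_as_enabled("0") and not _parses_as_enabled("off")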
98 changes: 97 additions & 1 deletion vllm/model_executor/layers/quantization/quark/utils.py
@@ -4,8 +4,10 @@
from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Any

import torch
import regex as re
from torch import nn
from aiter.ops.triton.quant import dynamic_mxfp4_quant


def deep_compare(dict1: Any, dict2: Any) -> bool:
@@ -103,3 +105,97 @@
elif target == value:
return True
return False

# Utility for tensors with more than 2 dims: flatten, quantize the last dim, and restore the shape.
def b_dynamic_mxfp4_quant(x):
h, b, d = x.shape
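# dynamic_mxfp4_quant returns fp4 data packed two-per-uint8 plus one scale per
# 32-element block, hence the d // 2 and d // 32 views below.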
x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)
#return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)
Collaborator:
NITS: let's clean up these



def mxfp4_to_f32(x, is_threed):
# Each uint8 packs two fp4 values, so expand the last dim by 2 before unpacking.
x = x.repeat_interleave(2, dim=-1)
if is_threed:
x[..., ::2] = x[..., ::2] & 0xF
x[..., 1::2] = x[..., 1::2] >> 4
else:
x[:, ::2] = x[:, ::2] & 0xF
x[:, 1::2] = x[:, 1::2] >> 4

mxfp4_list = [
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
-0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
]
mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device="cuda")
return mxfp4_in_f32[x.long()]
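
For intuition, a self-contained, CPU-only illustration of the nibble unpacking performed above (the packed bytes are made up for the example; the real helper keeps its lookup table on "cuda"):

import torch

MXFP4_VALUES = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=torch.float32,
)
packed = torch.tensor([[0x21, 0xA3]], dtype=torch.uint8)  # two bytes -> four fp4 values
x = packed.repeat_interleave(2, dim=-1)
x[:, ::2] = x[:, ::2] & 0xF   # low nibble decodes first
x[:, 1::2] = x[:, 1::2] >> 4  # then the high nibble
print(MXFP4_VALUES[x.long()])  # tensor([[ 0.5000,  1.0000,  1.5000, -0.5000]])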


def e8m0_to_f32(x):
# Convert the input tensor `x` (assumed to be in e8m0 format) to float32.
# e8m0 is a custom 8-bit floating point format with 8 bits for exponent, 0 for mantissa.
# This means the value is essentially 2^(exponent - 127), similar to how IEEE-754 stores floats.

Ruff E501 (GitHub Actions / pre-commit): vllm/model_executor/layers/quantization/quark/utils.py:152:89 Line too long (92 > 88)
Ruff E501 (GitHub Actions / pre-commit): vllm/model_executor/layers/quantization/quark/utils.py:153:89 Line too long (100 > 88)
# Convert x to float32 for computation, and compute the power of 2 by subtracting the bias (127).
x_f32 = 2 ** ((x.to(torch.float32)) - 127)

Ruff E501 (GitHub Actions / pre-commit): vllm/model_executor/layers/quantization/quark/utils.py:155:89 Line too long (101 > 88)

# If the exponent value was 255 (i.e., 2^(128)), this is a special case usually used to represent NaN or Inf.
# Since this custom format has no mantissa, treat 2^128 as NaN.

Ruff E501 (GitHub Actions / pre-commit): vllm/model_executor/layers/quantization/quark/utils.py:158:89 Line too long (113 > 88)
x_f32[x_f32 == 128] = float("nan")
Contributor (critical):

The condition to detect NaN values appears to be incorrect. The comment at line 158 indicates that an exponent value of 255 should be treated as NaN. However, the code checks if the computed float value x_f32 is equal to 128, which corresponds to an original exponent value of 134, not 255. This will lead to incorrect conversions of valid numbers to NaN and will fail to handle the special case of exponent 255 as intended. The check should be performed on the input tensor x for the value 255.

Suggested change:
-    x_f32[x_f32 == 128] = float("nan")
+    x_f32[x == 255] = float("nan")

return x_f32
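
For reference, a small self-contained example of the e8m0 interpretation described in these comments (value = 2 ** (exponent - 127), with 255 reserved for NaN as the suggested fix above implies; the exponent values are made up):

import torch

e = torch.tensor([0, 126, 127, 130, 254, 255], dtype=torch.uint8)
vals = 2.0 ** (e.to(torch.float32) - 127)
vals[e == 255] = float("nan")  # NaN encoding, per the suggested fix above
print(vals)
# tensor([5.8775e-39, 5.0000e-01, 1.0000e+00, 8.0000e+00, 1.7014e+38, nan])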

def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
if "mxfp4" in quant_format:
# When the dtype is bf16, dynamically quantize the bf16 tensor to packed uint8:
# do w_kc (bf16) first to get w_kc (uint8) and its scales w_s_kc (uint8),
# then repeat the same procedure for w_vc to get w_vc (uint8) and w_s_vc (uint8).

Ruff E501 (GitHub Actions / pre-commit): vllm/model_executor/layers/quantization/quark/utils.py:165:89 Line too long (100 > 88)
if w.dtype == torch.bfloat16:

Ruff E501 (GitHub Actions / pre-commit): vllm/model_executor/layers/quantization/quark/utils.py:167:89 Line too long (89 > 88)
# w_kc, w_vc = w.split(
# [self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
w_kc, w_vc = w.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
w_kc = w_kc.transpose(-2, -1)
w_s_kc = w_s_kc.transpose(-2, -1)
w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
w_s_vc = w_s_vc.contiguous().transpose(1, 2)
elif w.dtype == torch.uint8: # static quant for mxfp4
# When the dtype is uint8, w has already been quantized to mxfp4 format,
# but it still must be split into w_kc and w_vc.
# The packed tensor is only half the original size and the scale tensor is
# 1/32 of it, so transposing the packed data directly would be incorrect.
# Upcast to fp32 first so that splitting w into w_kc and w_vc and the
# subsequent transposes behave correctly, then re-quantize to mxfp4.
w = mxfp4_to_f32(w, True).to(torch.bfloat16)
w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
w = w * w_scales
w_kc, w_vc = w.unflatten(
0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim))
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
w_kc = w_kc.transpose(-2, -1)
w_s_kc = w_s_kc.transpose(-2, -1)
w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
w_s_vc = w_s_vc.contiguous().transpose(1, 2)

return w_kc, w_s_kc, w_vc, w_s_vc
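
To make the tensor bookkeeping in the bf16 branch easier to follow, here is a hedged, aiter-free shape sketch (the head count and dims are illustrative, not taken from any particular checkpoint):

import torch

num_heads, qk_nope, v_head, kv_lora_rank = 16, 128, 128, 512
w = torch.randn(num_heads * (qk_nope + v_head), kv_lora_rank, dtype=torch.bfloat16)

w_kc, w_vc = w.unflatten(0, (-1, qk_nope + v_head)).split([qk_nope, v_head], dim=1)
print(w_kc.shape, w_vc.shape)  # torch.Size([16, 128, 512]) torch.Size([16, 128, 512])

# b_dynamic_mxfp4_quant(w_vc) would then pack the last dim two-per-uint8 and
# emit one scale per 32 elements:
#   data:   (16, 128, 512 // 2)  == (16, 128, 256)
#   scales: (16, 128, 512 // 32) == (16, 128, 16)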
123 changes: 89 additions & 34 deletions vllm/v1/attention/backends/mla/common.py
@@ -1139,6 +1139,7 @@
self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
self.is_aiter_triton_fp4_bmm_enabled = rocm_aiter_ops.is_fp4bmm_enabled()

def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
@@ -1237,31 +1238,54 @@

def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)

if self.is_aiter_triton_fp8_bmm_enabled:
if self.is_aiter_triton_fp4_bmm_enabled:
from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4

#print(f'>>> x pre (up_proj) {x.shape}')
Collaborator:
NITS: let's clean up these

out = out.view(-1, self.num_heads, self.v_head_dim)
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
x = rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
x = x.view(-1, self.num_heads, self.kv_lora_rank)
x = x.transpose(0, 1)

#print(f'>>> x {x.shape}, attn_bmm_output {attn_bmm_output.shape}, self.W_V {self.W_V.shape}')
Collaborator:
NITS: let's clean up these

out = batched_gemm_a16wfp4(
x,
self.W_V,
self.W_V_scale,
y=out,
transpose_bm=True,
prequant=True,
y_scale=None,
)

Ruff E501 (GitHub Actions / pre-commit): vllm/v1/attention/backends/mla/common.py:1258:89 Line too long (106 > 88)
#print(f'>>> x before transpose {x.shape}')
Collaborator:
NITS: let's clean up these

out = out.view(-1, self.num_heads * self.v_head_dim)
x = out
else:
# Convert from (B, N * V) to (N, B, V)
out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)

if self.is_aiter_triton_fp8_bmm_enabled:
out = out.view(-1, self.num_heads, self.v_head_dim)
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
x = rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
)
else:
# Convert from (B, N * V) to (N, B, V)
out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)

# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"

# Convert from (N, B, V) to (B, N * V)
out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
# Convert from (N, B, V) to (B, N * V)

Ruff E501 (GitHub Actions / pre-commit): vllm/v1/attention/backends/mla/common.py:1278:89 Line too long (90 > 88)
out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)

# Adjust output buffer shape back to the original (B, N * V)
N, B, V = out.shape
out.resize_((B, N * V))
out.copy_(out_new) # Copy result

# Adjust output buffer shape back to the original (B, N * V)
N, B, V = out.shape
out.resize_((B, N * V))
out.copy_(out_new) # Copy result
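
For readers less familiar with the buffer-reuse tricks above, a minimal torch-only sketch of the default (unquantized) branch's shape flow, with made-up dimensions and without the in-place output buffer:

import torch

B, N, L, V = 4, 16, 512, 128  # illustrative: batch, heads, kv_lora_rank, v_head_dim
x = torch.randn(B, N, L)      # per-head attention output in the latent space
W_UV = torch.randn(N, L, V)   # per-head up-projection weights

x_nbl = x.transpose(0, 1)                        # (B, N, L) -> (N, B, L)
out_nbv = torch.bmm(x_nbl, W_UV)                 # (N, B, L) x (N, L, V) -> (N, B, V)
out = out_nbv.transpose(0, 1).reshape(B, N * V)  # back to (B, N * V)
print(out.shape)  # torch.Size([4, 2048])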


class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):

Ruff E501 (GitHub Actions / pre-commit): vllm/v1/attention/backends/mla/common.py:1288:89 Line too long (91 > 88)
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
@@ -1574,29 +1598,42 @@
return dequant_weights.T
return layer.weight

# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj)

# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
if self.is_aiter_triton_fp4_bmm_enabled:
from vllm.model_executor.layers.quantization.quark.utils import quark_post_load_weights

self.W_K, self.W_K_scale, W_V, self.W_V_scale = (
quark_post_load_weights(self, kv_b_proj_weight, "mxfp4"))
self.W_V = W_V.contiguous().transpose(1, 2)

self.W_K = self.W_K.transpose(-2, -1).contiguous()
self.W_K_scale = self.W_K_scale.transpose(-2, -1).contiguous()
self.W_V = self.W_V.transpose(-2, -1).contiguous()
self.W_V_scale = self.W_V_scale.transpose(-2, -1).contiguous()
return

# If kv_b_proj_weight is not being quantized to mxfp4, take the default path,
# which is to dequantize and transpose kv_b_proj_weight.
kv_b_proj_weight = kv_b_proj_weight.T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}"
)
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}")
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)

W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)

if self.is_aiter_triton_fp8_bmm_enabled:
W_K = W_UK.transpose(0, 1) # 16 512 128
@@ -1987,8 +2024,7 @@
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)

# Convert from (B, N, P) to (N, B, P)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)

if self.q_pad_num_heads is not None:
@@ -1998,7 +2034,26 @@
decode_pe_padded.copy_(decode_q_pe)
decode_q_pe = decode_pe_padded

if self.is_aiter_triton_fp8_bmm_enabled:
if self.is_aiter_triton_fp4_bmm_enabled:
from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4
Collaborator:
It seems this kernel does not exist in the AITER commit that is currently used in the docker/Dockerfile.rocm_base (https://github.com/ROCm/aiter/tree/59bd8ff2c8c3dc1c6caa990a68055528657a1506/aiter/ops/triton). So, we will only be able to merge after we upgrade the AITER version.


#x = x.view(-1, self.num_heads, self.kv_lora_rank)
Collaborator:
NITS: let's clean up these

x = decode_q_nope.transpose(0, 1)
decode_ql_nope = None
#print(f'>>> x {x.shape}, q_nope_out {q_nope_out.shape}, self.W_K {self.W_K.shape}')
Collaborator:
NITS: let's clean up these


decode_ql_nope = batched_gemm_a16wfp4(
x,
self.W_K,
self.W_K_scale,
y=decode_ql_nope,
transpose_bm=True,
prequant=True,
y_scale=layer._q_scale if fp8_attention else None,
)
elif self.is_aiter_triton_fp8_bmm_enabled:

Ruff E501 (GitHub Actions / pre-commit): vllm/v1/attention/backends/mla/common.py:2054:89 Line too long (100 > 88)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
decode_q_nope,
@@ -2026,7 +2081,7 @@
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)

if fp8_attention:
if fp8_attention and not self.is_aiter_triton_fp4_bmm_enabled:
ql_nope_shape = decode_ql_nope.shape
decode_ql_nope, _ = ops.scaled_fp8_quant(
decode_ql_nope.reshape(