[ROCm][MXFP4] Enable FP4 MLA BMM support #30177
base: main
Conversation
- Added support for FP4 BMM in ROCm aiter operations.
- Introduced an environment variable to configure FP4 BMM (a minimal gating sketch follows this list).
- Updated quantization utilities to handle dynamic MXFP4 quantization.
- Enhanced MLACommonBaseImpl to use FP4 BMM when enabled.
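The bullets above outline where the change lands. As a rough illustration of how such an opt-in path is typically gated, here is a minimal sketch; the environment-variable name and helper are placeholders for illustration, not the identifiers introduced by this PR.

```python
import os

import torch

# Placeholder flag name for illustration only; the PR defines its own variable.
_FP4_BMM_ENV = "VLLM_ROCM_USE_FP4_BMM"


def is_fp4_bmm_enabled() -> bool:
    """True only when the user opted in and the running build is ROCm (HIP)."""
    opted_in = os.environ.get(_FP4_BMM_ENV, "0") == "1"
    return opted_in and torch.version.hip is not None
```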
💡 Codex Review
https://github.com/vllm-project/vllm/blob/c2b8f454546069ecf03afb0edcca26c70f736d8b/model_executor/layers/quantization/quark/utils.py#L7-L10
Guard aiter import for non‑ROCm environments
quark/utils.py now imports dynamic_mxfp4_quant at module load (lines 7‑10). This file is imported in CPU/NVIDIA paths such as tests/config/test_config_generation.py and quark/quark.py, where the ROCm-only aiter package is not installed. The top‑level import will raise ModuleNotFoundError before any ROCm gating, breaking non‑ROCm builds and tests. Please gate the import or move it inside the ROCm-specific code path.
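One way to address this, sketched below, is to defer the import until the ROCm path actually needs it; the aiter module path shown here is an assumption, not taken from the PR.

```python
from vllm.platforms import current_platform


def _get_dynamic_mxfp4_quant():
    # Resolve the ROCm-only dependency lazily so that importing
    # quark/utils.py on CPU/NVIDIA builds never touches `aiter`.
    if not current_platform.is_rocm():
        raise RuntimeError("dynamic_mxfp4_quant requires a ROCm build with aiter installed")
    from aiter.ops.triton.quant import dynamic_mxfp4_quant  # assumed module path
    return dynamic_mxfp4_quant
```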
https://github.com/vllm-project/vllm/blob/c2b8f454546069ecf03afb0edcca26c70f736d8b/model_executor/layers/quantization/quark/utils.py#L199-L201
Handle fp16 weights in MXFP4 post‑load quantization
quark_post_load_weights only processes torch.bfloat16 or torch.uint8 weights before returning w_kc, w_s_kc, w_vc, w_s_vc (line 201), so torch.float16 weights leave those variables undefined. When MLACommonImpl.process_weights_after_loading takes the FP4 BMM path on MI300 with fp16 model weights (the default dtype for many checkpoints), this function will raise UnboundLocalError before quantization. Add a float16 branch or a conversion before returning.
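A small sketch of the kind of branch the review asks for, assuming that upcasting float16 to bfloat16 and reusing the existing bfloat16 quantization path is acceptable; the helper name is hypothetical.

```python
import torch


def _prepare_for_mxfp4(weight: torch.Tensor) -> torch.Tensor:
    # quark_post_load_weights only handles torch.bfloat16 (quantize) and
    # torch.uint8 (already packed). Upcast float16 so it takes the bfloat16
    # branch instead of leaving w_kc/w_s_kc/w_vc/w_s_vc unbound.
    if weight.dtype == torch.float16:
        return weight.to(torch.bfloat16)
    return weight
```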
Code Review
This pull request introduces support for FP4 Batched Matrix Multiplication (BMM) in ROCm aiter operations, which is a significant enhancement for performance on compatible hardware. The changes include adding new environment variables, updating quantization utilities for MXFP4, and integrating the FP4 BMM path into the MLA attention backend. While the overall direction is good, my review has identified a few critical correctness issues that need to be addressed. Specifically, there's a bug in the e8m0_to_f32 function's NaN handling, a missing tensor transpose in the new FP4 attention path, and incorrect and duplicated padding logic within the forward method of the MLA implementation. These issues could lead to incorrect computations and need to be fixed before merging.
# If the exponent value was 255 (i.e., 2^(128)), this is a special case usually used to represent NaN or Inf.
# Since this custom format has no mantissa, treat 2^128 as NaN.
x_f32[x_f32 == 128] = float("nan")
The condition to detect NaN values appears to be incorrect. The comment at line 158 indicates that an exponent value of 255 should be treated as NaN. However, the code checks if the computed float value x_f32 is equal to 128, which corresponds to an original exponent value of 134, not 255. This will lead to incorrect conversions of valid numbers to NaN and will fail to handle the special case of exponent 255 as intended. The check should be performed on the input tensor x for the value 255.
Suggested change:
- x_f32[x_f32 == 128] = float("nan")
+ x_f32[x == 255] = float("nan")
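For context, a self-contained sketch of an e8m0-to-float32 decode with the NaN check applied to the raw exponent byte, as the suggestion proposes; this illustrates the fix rather than reproducing the exact function from the PR.

```python
import torch


def e8m0_to_f32(x: torch.Tensor) -> torch.Tensor:
    """Decode e8m0 scales (8-bit exponent, no mantissa, bias 127) to float32."""
    exponent = x.to(torch.int32) - 127  # remove the bias
    x_f32 = torch.pow(2.0, exponent.to(torch.float32))
    # Exponent byte 255 is the reserved NaN encoding, so test the raw input
    # rather than the decoded value (2**128 overflows float32 to inf anyway).
    x_f32[x == 255] = float("nan")
    return x_f32
```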
    prequant=True,
    y_scale=layer._q_scale if fp8_attention else None,
)
#decode_ql_nope = decode_ql_nope.transpose(0, 1)
The output of batched_gemm_a16wfp4, decode_ql_nope, has a shape of (num_heads, num_tokens, kv_lora_rank), which is (N, B, L). However, it is later used in a way that expects a shape of (B, N, L). The other code paths for fp8 and torch.bmm both produce tensors with the (B, N, L) shape. The new fp4 path is missing a transpose operation. The commented-out line here suggests this was intended but was missed. This should be uncommented to ensure correct tensor shapes for subsequent operations.
Suggested change:
- #decode_ql_nope = decode_ql_nope.transpose(0, 1)
+ decode_ql_nope = decode_ql_nope.transpose(0, 1)
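A quick shape check illustrating why the transpose is needed; the dimension sizes are arbitrary.

```python
import torch

num_heads, num_tokens, kv_lora_rank = 16, 4, 512
# The FP4 kernel returns (N, B, L); downstream code expects (B, N, L).
decode_ql_nope = torch.empty(num_heads, num_tokens, kv_lora_rank)
decode_ql_nope = decode_ql_nope.transpose(0, 1)  # returns a view, no copy
assert decode_ql_nope.shape == (num_tokens, num_heads, kv_lora_rank)
```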
decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, decode_q_pe.shape[-1]))
decode_pe_padded.resize_((B, decode_q_pe.shape[1], decode_q_pe.shape[-1]))
decode_pe_padded.copy_(decode_q_pe)
decode_q_pe = decode_pe_padded
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)

if self.q_pad_num_heads is not None and not self.is_aiter_triton_fp4_bmm_enabled:
The condition for padding decode_q_pe incorrectly excludes the fp4 path by checking not self.is_aiter_triton_fp4_bmm_enabled. The fp4 path also uses decode_q_pe and should be padded if self.q_pad_num_heads is not None, just like the other paths. The original code before this PR applied padding to all paths. To fix this and ensure consistent behavior, the condition should be simplified to only check for self.q_pad_num_heads.
Suggested change:
- if self.q_pad_num_heads is not None and not self.is_aiter_triton_fp4_bmm_enabled:
+ if self.q_pad_num_heads is not None:
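Putting the suggestion together with the quoted hunk above, a hedged sketch of the corrected padding as a standalone helper; the function name is hypothetical and the variable names come from the diff.

```python
from typing import Optional

import torch


def pad_q_heads(decode_q_pe: torch.Tensor, q_pad_num_heads: Optional[int]) -> torch.Tensor:
    """Apply head padding regardless of which BMM backend is active."""
    if q_pad_num_heads is None:
        return decode_q_pe
    B = decode_q_pe.shape[0]
    padded = decode_q_pe.new_empty((B, q_pad_num_heads, decode_q_pe.shape[-1]))
    # Keep the padded allocation but expose only the real head count.
    padded.resize_((B, decode_q_pe.shape[1], decode_q_pe.shape[-1]))
    padded.copy_(decode_q_pe)
    return padded
```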
Hi @dllehr-amd, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch. For future commits, the hooks installed by pre-commit install will run automatically.
Purpose
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.