Add llmcompressor fp8 kv-cache quant (per-tensor and per-attn_head) #30141
base: main
Conversation
…attn-head) Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Documentation preview: https://vllm--30141.org.readthedocs.build/en/30141/
💡 Codex Review
Here are some automated review suggestions for this pull request.
Hi @eldarkurtic, the pre-commit checks have failed. Please run:

```bash
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch. For future commits, …
Code Review

This pull request introduces support for per-attention-head FP8 KV cache quantization, primarily for use with Flash Attention and llm-compressor. The changes:

- modify the CUDA kernels to handle both per-tensor and per-attention-head scaling factors, add a new kernel for per-channel static FP8 quantization, and update `reshape_and_cache_flash_kernel` to use a `kv_scale_stride` for flexible scale access;
- update the Python `_custom_ops` to support per-channel scales, and refactor the attention layer initialization logic to correctly set KV cache quantization attributes and query quantization group shapes based on the loaded llm-compressor configuration;
- extend the model weight loading utilities to remap `q_scale` and `zero_point` parameters, and add a new `CompressedTensorsConfig` method to process llm-compressor loaded scales, including reducing `q_scale` for Flash Attention and repeating it for `QuantFP8` operations;
- add unit tests covering per-attention-head scaling in `reshape_and_cache_flash` and per-channel static FP8 quantization;
- significantly update the FP8 KV Cache documentation to detail per-attention-head quantization and the various scale calibration approaches, with an example using llm-compressor.

A review comment highlighted brittle logic in `CompressedTensorsConfig.from_config` for filtering attention quantization config groups, suggesting a more robust check for empty target lists that iterates over all targets.
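A rough Python-level reference of the `kv_scale_stride` mechanism from the list above (hedged: plain PyTorch math illustrating the indexing, not the CUDA kernel). A stride of 0 makes every head read the same per-tensor scale, while a stride of 1 gives each head its own entry. The config-group filtering excerpt that the review comment refers to follows right after.

```python
import torch

def quantize_k_reference(key: torch.Tensor, k_scale: torch.Tensor,
                         kv_scale_stride: int) -> torch.Tensor:
    # key: [num_tokens, num_heads, head_size]; k_scale holds either a single
    # per-tensor value or one value per attention head.
    num_tokens, num_heads, head_size = key.shape
    out = torch.empty_like(key, dtype=torch.float8_e4m3fn)
    for h in range(num_heads):
        # stride 0 -> always k_scale[0] (per-tensor); stride 1 -> k_scale[h].
        scale = k_scale[h * kv_scale_stride]
        out[:, h, :] = (key[:, h, :].float() / scale).to(torch.float8_e4m3fn)
    return out

key = torch.randn(2, 4, 8)
per_tensor = quantize_k_reference(key, torch.tensor([0.02]), kv_scale_stride=0)
per_head = quantize_k_reference(key, torch.full((4,), 0.02), kv_scale_stride=1)
```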
| config["config_groups"] = { | ||
| k: v | ||
| for k, v in config["config_groups"].items() | ||
| if "Attention" not in v["targets"][0] | ||
| } |
The logic to filter out attention quantization config groups is brittle. It assumes that v["targets"] is a non-empty list and only checks the first element (v["targets"][0]). If v["targets"] is an empty list, this will raise an IndexError, causing a crash when loading a model with such a configuration. A more robust implementation should handle empty target lists and check all targets for the 'Attention' keyword.
| config["config_groups"] = { | |
| k: v | |
| for k, v in config["config_groups"].items() | |
| if "Attention" not in v["targets"][0] | |
| } | |
| config["config_groups"] = { | |
| k: v | |
| for k, v in config["config_groups"].items() | |
| if not any("Attention" in t for t in v.get("targets", [])) | |
| } |
This is a temporary and hacky way to drop config groups whose targets are LlamaAttention/Qwen3Attention/etc. in llm-compressor checkpoints, because this config doesn't have any meaning for vLLM. My assumption here is that targets will never be an empty list.
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
This pull request has merge conflicts that must be resolved before it can be merged.
```python
assert not static or group_shape == GroupShape.PER_TENSOR, (
    "Only per-tensor scales supported for static quantization."
)
assert group_shape in {
```
Dynamic quant doesn't support per-channel
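A minimal sketch of the constraint this comment describes, using stand-in types rather than vLLM's actual `GroupShape`: per-channel scales are accepted only for static quantization, while dynamic quantization stays per-tensor.

```python
# Stand-in enum; vLLM's real GroupShape constants and the exact placement of
# this check inside QuantFP8 may differ.
from enum import Enum, auto

class GroupShape(Enum):
    PER_TENSOR = auto()
    PER_CHANNEL = auto()

def validate_group_shape(static: bool, group_shape: GroupShape) -> None:
    # Per-channel scales only make sense when they were calibrated offline,
    # i.e. for static quantization; dynamic quantization computes a scale on
    # the fly and remains per-tensor only.
    if group_shape is GroupShape.PER_CHANNEL and not static:
        raise ValueError("Dynamic quantization only supports per-tensor scales.")

validate_group_shape(static=True, group_shape=GroupShape.PER_CHANNEL)   # ok
validate_group_shape(static=False, group_shape=GroupShape.PER_TENSOR)   # ok
```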
```python
else:  # strategy == "attn_head"
    from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl

    assert (
```
Add a property indicating whether the attention backend supports per-channel scales (default false); see CG support or query quant support for similar examples.
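A rough sketch of the capability flag this comment suggests, using stand-in classes; the attribute name `supports_per_channel_kv_scales` is hypothetical and merely modeled on how other backend capabilities are advertised.

```python
# Stand-in classes, not vLLM's real AttentionImpl/FlashAttentionImpl.
class AttentionImpl:
    # Default: the backend does not understand per-attention-head
    # (per-channel) KV-cache scales.
    supports_per_channel_kv_scales: bool = False

class FlashAttentionImpl(AttentionImpl):
    # The flash-attention backend targeted by this PR would opt in.
    supports_per_channel_kv_scales: bool = True

def check_scale_layout(impl: AttentionImpl, per_head_requested: bool) -> str:
    # Callers branch on the flag instead of importing a concrete backend
    # class, which is what the hardcoded check in the excerpt above does.
    if per_head_requested and not impl.supports_per_channel_kv_scales:
        raise ValueError("This attention backend only supports per-tensor KV scales.")
    return "attn_head" if per_head_requested else "tensor"

assert check_scale_layout(FlashAttentionImpl(), per_head_requested=True) == "attn_head"
```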
TLDR: this PR adds support to load and run `llm-compressor` models with FP8 KV-cache and attention quantization. In addition to the standard "per-tensor" quantization, it adds support for "per-attention-head" quantization.

Summary

For `llm-compressor`:

2.1 for query quantization:

- expanded `QuantFP8` to support per-channel static quantization (queries are of the shape `num_tokens x hidden_size`, so we expand the per-attention-head scales to accommodate per-channel scaling; a sketch of this expansion follows the list)
- expanded the `static_scaled_fp8_quant` kernel to work with an array of scales

2.2 for kv-quantization:

- expanded the `reshape_and_cache_flash` kernel by adding support for an array of `k/v_scale`
- refactored `vllm/attention/layer.py`, as so far it has been hardcoded for the `calculate_kv_scales` pathway only
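A minimal, runnable sketch of the per-attention-head to per-channel scale expansion mentioned in 2.1, written as plain PyTorch reference math rather than the actual `QuantFP8` / `static_scaled_fp8_quant` code; the tensor names and sizes here are illustrative assumptions.

```python
import torch

num_tokens, num_heads, head_size = 4, 8, 128
hidden_size = num_heads * head_size

# Queries as they reach the attention layer: [num_tokens, hidden_size].
query = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16)

# One static scale per attention head (e.g. calibrated offline): [num_heads].
q_scale_per_head = torch.rand(num_heads, dtype=torch.float32) + 0.5

# Repeat each head's scale head_size times so the scale vector lines up with
# the flattened hidden dimension: [hidden_size]. This is what lets a
# per-channel static quantization kernel consume per-attention-head scales.
q_scale_per_channel = q_scale_per_head.repeat_interleave(head_size)

# Per-channel static FP8 quantization, reference math only: divide each
# column by its scale and clamp/cast to the FP8 range.
fp8 = torch.finfo(torch.float8_e4m3fn)
q_fp8 = (query.float() / q_scale_per_channel).clamp(fp8.min, fp8.max)
q_fp8 = q_fp8.to(torch.float8_e4m3fn)

# Dequantizing recovers an approximation of the original queries.
q_deq = q_fp8.float() * q_scale_per_channel
```

With purely per-tensor scales the same layout works with a single value repeated across all channels, which is why the per-channel pathway can serve both schemes once the scales are expanded.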
Tests

To confirm that the existing pathway with `calculate_kv_scales=True/False` isn't affected, I ran GSM8k evals on the following models: `Llama-2-7b-chat-hf` (MHA), `Llama-3.1-8B-Instruct` (GQA), and `Qwen/Qwen3-8B` (different model family), and got the same results before and after the PR. They are as follows:

(GSM8k result tables for `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-3.1-8B-Instruct`, and `Qwen/Qwen3-8B` not reproduced here)

And to confirm that the new support for `llm-compressor` models with both `per-tensor` and `per-attention-head` scales is working correctly, I ran the same models from above with both configurations and observed the expected results:

(GSM8k result tables for `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-3.1-8B-Instruct`, and `Qwen/Qwen3-8B` not reproduced here)

The `llm-compressor` code to produce these models is available in the diffs of `docs/features/quantization/quantized_kvcache.md`. I haven't done any tuning of the calibration parameters for these tests, just ran the defaults, so better results are expected with better tuning.

I've also verified that the changes work with both the `LLM` class and `vllm serve`.
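For orientation, the kind of `llm-compressor` recipe that the existing quantized_kvcache doc uses for static FP8 KV-cache scales looks roughly like the sketch below; the `strategy: attn_head` variant is inferred from the strategy string checked in the code excerpt earlier in this thread, and the authoritative recipe is the one added to `docs/features/quantization/quantized_kvcache.md` in this PR.

```python
# Hedged sketch based on the pre-existing quantized_kvcache doc; the exact
# per-attention-head recipe should be taken from this PR's doc diff.
# Recent llm-compressor releases expose `oneshot` at the top level; older
# ones used `from llmcompressor.transformers import oneshot`.
from llmcompressor import oneshot  # noqa: F401

recipe = """
quant_stage:
    quant_modifiers:
        QuantizationModifier:
            kv_cache_scheme:
                num_bits: 8
                type: float
                strategy: tensor   # per-tensor scales; "attn_head" for per-attention-head
                dynamic: false
                symmetric: true
"""

# Model, dataset, and calibration arguments are omitted here; see the doc for
# a complete example:
# oneshot(model=model, dataset=ds, recipe=recipe, ...)
```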
Note: some model implementations have remapping of scales guarded by something like `if "scale" in name`, so I had to expand it to `if "scale" in name or "zero_point" in name:` to support loading of zero_points, which are present in `llm-compressor` checkpoints.