Conversation

Contributor

eldarkurtic commented Dec 5, 2025

TLDR: this PR adds support for loading and running llm-compressor models with FP8 KV-cache and attention quantization. In addition to the standard "per-tensor" quantization, it adds support for "per-attention-head" quantization.

Summary

  1. Enable the existing pathway of "per-tensor" FP8 quantization of the KV-cache (and queries), with scales calibrated through llm-compressor.
  2. The Flash Attention v3 backend supports "finer-grained" scales, i.e. one scale per attention head. This was previously unsupported; this PR enables it in the following way:

2.1 for query quantization:

  • expand QuantFP8 to support per-channel static quantization (queries have shape num_tokens x hidden_size, so the per-attention-head scales are expanded into per-channel scales; see the sketch after this list)
  • this is enabled by extending the static_scaled_fp8_quant kernel to accept an array of scales

2.2 for KV quantization:

  • extend the reshape_and_cache_flash kernel to accept arrays of k_scale/v_scale values
  • this covers both cache layouts (NHD and HND)

  3. Reorganize the KV-cache handling in vllm/attention/layer.py, which until now was hardcoded for the calculate_kv_scales pathway only.
  4. Expand tests to cover all kernel-related updates.
  5. Update the documentation on the newly supported KV-cache and attention quantization technique in vLLM.
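To make the per-channel expansion in 2.1 concrete, here is a minimal PyTorch sketch (not the actual QuantFP8/static_scaled_fp8_quant code; shapes and scale values are made up, and it assumes each attention head owns a contiguous block of head_size channels in the flattened query):

import torch

num_tokens, num_heads, head_size = 4, 32, 128
hidden_size = num_heads * head_size

query = torch.randn(num_tokens, hidden_size)
q_scale_per_head = torch.rand(num_heads) * 0.5 + 0.5  # one static scale per attention head

# Each head's scale applies to a contiguous block of head_size channels, so
# repeat_interleave turns the [num_heads] scales into [hidden_size] per-channel scales.
q_scale_per_channel = q_scale_per_head.repeat_interleave(head_size)

# Static per-channel FP8 quantization: divide by the broadcast scales and clamp
# to the FP8 E4M3 representable range before casting.
fp8 = torch.finfo(torch.float8_e4m3fn)
q_fp8 = (query / q_scale_per_channel).clamp(fp8.min, fp8.max).to(torch.float8_e4m3fn)
print(q_fp8.shape)  # torch.Size([4, 4096])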

Tests

To confirm that the existing pathway with calculate_kv_scales=True/False isn't affected, I ran GSM8k evals on the following models: Llama-2-7b-chat-hf (MHA), Llama-3.1-8B-Instruct (GQA), and Qwen/Qwen3-8B (a different model family), and got the same results before and after the PR. They are as follows:

  • Model = meta-llama/Llama-2-7b-chat-hf
unquantized baseline
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.2509|±  |0.0119|
|     |       |strict-match    |     5|exact_match|↑  |0.1933|±  |0.0109|

calculate_kv_scales=False
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.2563|±  |0.0120|
|     |       |strict-match    |     5|exact_match|↑  |0.1911|±  |0.0108|

calculate_kv_scales=True
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.0053|±  | 0.002|
|     |       |strict-match    |     5|exact_match|↑  |0.0000|±  | 0.000|
  • Model = meta-llama/Llama-3.1-8B-Instruct
unquantized baseline
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8522|±  |0.0098|
|     |       |strict-match    |     5|exact_match|↑  |0.8249|±  |0.0105|


calculate_kv_scales=False
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8302|±  |0.0103|
|     |       |strict-match    |     5|exact_match|↑  |0.8036|±  |0.0109|


calculate_kv_scales=True
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8324|±  |0.0103|
|     |       |strict-match    |     5|exact_match|↑  |0.8127|±  |0.0107|
  • Model = Qwen/Qwen3-8B
unquantized baseline
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8810|±  |0.0089|
|     |       |strict-match    |     5|exact_match|↑  |0.8764|±  |0.0091|

calculate_kv_scales=True
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8787|±  |0.0090|
|     |       |strict-match    |     5|exact_match|↑  |0.8734|±  |0.0092|


calculate_kv_scales=False
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8741|±  |0.0091|
|     |       |strict-match    |     5|exact_match|↑  |0.8688|±  |0.0093|

To confirm that the new support for llm-compressor models with both per-tensor and per-attention-head scales works correctly, I ran the same models as above with both configurations and observed the expected results:

  • Model = meta-llama/Llama-2-7b-chat-hf
unquantized baseline
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.2509|±  |0.0119|
|     |       |strict-match    |     5|exact_match|↑  |0.1933|±  |0.0109|

per-tensor
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.2555|±  |0.0120|
|     |       |strict-match    |     5|exact_match|↑  |0.1835|±  |0.0107|

per-attn-head
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.2441|±  |0.0118|
|     |       |strict-match    |     5|exact_match|↑  |0.1736|±  |0.0104|
  • Model = meta-llama/Llama-3.1-8B-Instruct
unquantized baseline
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8522|±  |0.0098|
|     |       |strict-match    |     5|exact_match|↑  |0.8249|±  |0.0105|


per-tensor
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8362|±  |0.0102|
|     |       |strict-match    |     5|exact_match|↑  |0.8021|±  |0.0110|

per-attn-head
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8393|±  |0.0101|
|     |       |strict-match    |     5|exact_match|↑  |0.8112|±  |0.0108|
  • Model = Qwen/Qwen3-8B
unquantized baseline
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8810|±  |0.0089|
|     |       |strict-match    |     5|exact_match|↑  |0.8764|±  |0.0091|

per-tensor
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8810|±  |0.0089|
|     |       |strict-match    |     5|exact_match|↑  |0.8757|±  |0.0091|

per-attn-head
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8787|±  |0.0090|
|     |       |strict-match    |     5|exact_match|↑  |0.8757|±  |0.0091|

The llm-compressor code used to produce these models is available in the diff of docs/features/quantization/quantized_kvcache.md. I haven't tuned the calibration parameters for these tests; I just ran the defaults, so better results can be expected with tuning (an example recipe is sketched below).
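For reference, here is a minimal llm-compressor sketch of the per-tensor variant. This is not the exact recipe from the docs diff; it assumes the QuantizationModifier kv_cache_scheme API, and the model id, calibration dataset, and output directory are placeholders:

from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

recipe = QuantizationModifier(
    targets="Linear",
    scheme="FP8_DYNAMIC",
    ignore=["lm_head"],
    kv_cache_scheme={
        "num_bits": 8,
        "type": "float",
        "strategy": "tensor",  # per-tensor KV scales; this PR also adds a per-attention-head path
        "dynamic": False,
        "symmetric": True,
    },
)

# One-shot calibration on a small dataset to collect the static scales.
oneshot(
    model=model,
    dataset="open_platypus",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)

SAVE_DIR = "Llama-3.1-8B-Instruct-FP8-KV"  # placeholder
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)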
I've also verified that the changes work with both the LLM class and vllm serve.
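For example, an offline sanity check along these lines (the checkpoint path is a placeholder; kv_cache_dtype="fp8" enables the FP8 KV cache):

from vllm import LLM, SamplingParams

# Load an llm-compressor checkpoint whose calibrated FP8 KV-cache/attention
# scales were produced as sketched above.
llm = LLM(model="Llama-3.1-8B-Instruct-FP8-KV", kv_cache_dtype="fp8")
outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)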

Note: some model implementations guard the remapping of scales with something like if "scale" in name, so I had to expand it to if "scale" in name or "zero_point" in name: to support loading the zero_point tensors present in llm-compressor checkpoints.
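A sketch of the pattern inside a model's load_weights loop (maybe_remap_kv_scale_name is vLLM's existing remapping helper; the surrounding code differs per model and is abbreviated here):

for name, loaded_weight in weights:  # existing loading loop
    if "scale" in name or "zero_point" in name:  # previously: if "scale" in name
        name = maybe_remap_kv_scale_name(name, params_dict)
        if name is None:
            continue
    ...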

…attn-head)

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>

mergify bot commented Dec 5, 2025

Documentation preview: https://vllm--30141.org.readthedocs.build/en/30141/

mergify bot added the documentation, llama, speculative-decoding, and v1 labels Dec 5, 2025

chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.



mergify bot commented Dec 5, 2025

Hi @eldarkurtic, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Contributor

gemini-code-assist bot left a comment


Code Review

This pull request introduces support for per-attention-head FP8 KV cache quantization, primarily for use with Flash Attention and llm-compressor. The changes involve modifying CUDA kernels to handle both per-tensor and per-attention-head scaling factors, adding a new kernel for per-channel static FP8 quantization, and updating the reshape_and_cache_flash_kernel to use a kv_scale_stride for flexible scale access. The Python _custom_ops are updated to support per-channel scales, and the attention layer initialization logic is refactored to correctly set KV cache quantization attributes and query quantization group shapes based on the loaded llm-compressor configuration. Model weight loading utilities are extended to remap q_scale and zero_point parameters, and a new CompressedTensorsConfig method is added to process llm-compressor loaded scales, including reducing q_scale for Flash Attention and repeating it for QuantFP8 operations. Comprehensive unit tests were added to cover per-attention-head scaling in reshape_and_cache_flash and per-channel static FP8 quantization. The documentation for FP8 KV Cache has been significantly updated to detail per-attention-head quantization, various scale calibration approaches, and provides an example using llm-compressor. A review comment highlighted a brittle logic in CompressedTensorsConfig.from_config for filtering attention quantization config groups, suggesting a more robust check for empty target lists and iterating through all targets.

Comment on lines +185 to +189
config["config_groups"] = {
k: v
for k, v in config["config_groups"].items()
if "Attention" not in v["targets"][0]
}
Contributor


critical

The logic to filter out attention quantization config groups is brittle. It assumes that v["targets"] is a non-empty list and only checks the first element (v["targets"][0]). If v["targets"] is an empty list, this will raise an IndexError, causing a crash when loading a model with such a configuration. A more robust implementation should handle empty target lists and check all targets for the 'Attention' keyword.

Suggested change
-config["config_groups"] = {
-    k: v
-    for k, v in config["config_groups"].items()
-    if "Attention" not in v["targets"][0]
-}
+config["config_groups"] = {
+    k: v
+    for k, v in config["config_groups"].items()
+    if not any("Attention" in t for t in v.get("targets", []))
+}

Contributor Author


This is a temporary and hacky way to drop config groups whose targets are LlamaAttention/Qwen3Attention/etc. in llm-compressor checkpoints, because this config has no meaning for vLLM. My assumption here is that targets will never be an empty list.

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>

mergify bot commented Dec 5, 2025

Hi @eldarkurtic, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>

mergify bot commented Dec 5, 2025

Hi @eldarkurtic, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>

mergify bot commented Dec 5, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @eldarkurtic.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Dec 5, 2025
assert not static or group_shape == GroupShape.PER_TENSOR, (
    "Only per-tensor scales supported for static quantization."
)
assert group_shape in {
Collaborator


Dynamic quant doesn't support per-channel

else:  # strategy == "attn_head"
    from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl

    assert (
Collaborator


Add a property indicating whether the attention backend supports per-channel scales (default False); see CG support or query quant support for a similar pattern.
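A hypothetical sketch of that suggestion (the flag name and helper are illustrative, not existing vLLM API):

class AttentionImpl:
    # Capability flag, default False; concrete backends override it.
    supports_per_channel_quant_scales: bool = False

class FlashAttentionImpl(AttentionImpl):
    supports_per_channel_quant_scales = True

def check_attn_head_scales_supported(impl: AttentionImpl) -> None:
    # The quantization config can check the flag instead of
    # isinstance(impl, FlashAttentionImpl).
    assert impl.supports_per_channel_quant_scales, (
        "per-attention-head scales require a backend with per-channel scale support"
    )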
