
Commit 1abfd9e

kylesayrs and dsikka authored
[model_free_ptq] NVFP4A16 (#1988)
## Purpose ##

* Support NVFP4A16 for `model_free_ptq`

```bash
llmcompressor.reindex_fused_weights \
    unsloth/Kimi-K2-Thinking-BF16 \
    Kimi-K2-Thinking-BF16-reindexed \
    --num_workers=10
```

```python
model_free_ptq(
    model_stub="Kimi-K2-Thinking-BF16-reindexed",
    save_directory="Kimi-K2-Thinking-BF16-NVFP4A16",
    scheme="NVFP4A16",
    ignore=[
        "re:.*gate$",
        "lm_head",
        "re:.*kv_a_proj_with_mqa$",
        "re:.*q_a_proj$",
        "model.embed_tokens",
    ],
    max_workers=15,
    device="cuda:0",
)
```

## Changes ##

* Restructure files
  * Move `validate_scheme` to `validate.py`
  * Move `find_safetensors_index_path`, `find_config_path`, `find_safetensors_index_file` to `helpers.py`
  * Move `process_file` to `process.py`
* Break `calibrate_weights` into `calibrate_global_scale` and `calibrate_scale_zp`
* Add extra utility functions
  * `match_names_set_eager`
  * `invert_mapping`
* Add microscale/fused module utility functions
  * `is_microscale_scheme`
  * `get_fused_names`
* Add `process_file_microscale_scheme` to separate the FP4 lifecycle from the regular lifecycle (this script should be very trustworthy; by separating the functions, an FP8 user does not have to trust anything about FP4)
* Add `llmcompressor.reindex_fused_weights` script, which reindexes a model's weights so that fused modules are in the same files
* Fix [bug](https://github.com/vllm-project/llm-compressor/pull/1988/files#diff-8d11f284a49f6c4e559617aaf7750f3437a074cd526ee94dbefe86866f250a42R80-R82) where safetensors index metadata was not being saved correctly

## Testing ##

* Add NVFP4A16 to `test_model_free_ptq_matches_oneshot`
* Regression tested large Mistral model e2e with FP8_BLOCK
* Tested large Mistral model e2e with NVFP4A16

## Mistral 3 ##

This branch was used to quantize Mistral 3:

1. Quantize to W4A16

```python3
from llmcompressor import model_free_ptq

model_free_ptq(
    "mistralai/Mistral-Large-3-675B-Instruct-2512",
    "Mistral-Large-3-675B-Instruct-2512-FP8_BLOCK",
    scheme="NVFP4_A16",
    ignore=[
        "tok_embeddings",  # embeddings
        "re:patch_merger.*",  # patch merger
        "re:vision_encoder.*",  # vision tower
        "re:vision_language_adapter.*",  # vision adapter
        "re:.*attention$",  # sensitive to quantization
        "re:.*gate$",  # sensitive to quantization
        "output",  # lm head
    ],
    max_workers=10,  # 10 = 52Gb
    device="cuda:0",
)
```

2. Update ignore list to use vLLM checkpoint format

```
[
    "model.embed_tokens",
    "re:patch_merger.*",
    "re:vision_encoder.*",
    "re:vision_language_adapter.*",
    "lm_head",
    "re:.*self_attn.*",
    "re:.*gate$"
]
```

3. Add observers to vLLM model definition and run for 100 samples from ultrachat
4. Save model checkpoint, making sure to reduce values from shards

For more information on how observers were added to vLLM, please reach out to @kylesayrs

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
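For context on the safetensors index fix mentioned above: Hugging Face sharded checkpoints carry a `model.safetensors.index.json` whose `metadata.total_size` and `weight_map` must be regenerated after compression. The sketch below is illustrative only; `write_safetensors_index` is a hypothetical helper, not this PR's `update_safetensors_index`, and it simply assumes the conventional index layout.

```python
# Illustrative sketch of the conventional safetensors index layout;
# the helper name is hypothetical and this is not the PR's implementation.
import json
import os


def write_safetensors_index(
    save_directory: str, total_size: int, weight_map: dict[str, str]
) -> None:
    index = {
        # total byte size of all tensors, summed across shards
        "metadata": {"total_size": total_size},
        # tensor name -> shard file name containing that tensor
        "weight_map": weight_map,
    }
    path = os.path.join(save_directory, "model.safetensors.index.json")
    with open(path, "w") as f:
        json.dump(index, f, indent=2, sort_keys=True)
```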
1 parent 4982529 commit 1abfd9e

File tree

14 files changed: +626 −143 lines

examples/model_free_ptq/README.md

Lines changed: 50 additions & 0 deletions
@@ -13,3 +13,53 @@
 In `kimi_k2_thinking_fp8_block.py`, we call `model_free_ptq` by providing a `scheme` and `ignore` list, similar to how we provide recipes to `oneshot` calls. In the case of Kimi-K2 Thinking, we apply the `FP8_BLOCK` scheme and ignore layers that are incompatible with a block_size of 128 (specifically, `kv_a_proj_with_mqa` and `q_a_proj`).
 
 In contrast to `oneshot`, we expect the model stub or pathway string to be directly passed in, as opposed to first being loaded through transformers. Once complete, the model is compressed using compressed-tensors and saved to `SAVE_DIR`.
+
+To get started, simply call `model_free_ptq` with your desired model stub and save directory:
+```python
+model_free_ptq(
+    model_stub="unsloth/Kimi-K2-Thinking-BF16",
+    save_directory="Kimi-K2-Thinking-FP8-BLOCK",
+    scheme="FP8_BLOCK",
+    ignore=[
+        "re:.*gate$",
+        "lm_head",
+        "re:.*kv_a_proj_with_mqa$",
+        "re:.*q_a_proj$",
+        "model.embed_tokens",
+    ],
+    max_workers=15,
+    device="cuda:0",
+)
+```
+
+# Quantizing models to NVFP4A16/MXFP4A16
+
+Using `model_free_ptq` to quantize models with microscale schemes (NVFP4/MXFP4) is the same as quantizing models with non-microscale schemes, except for one additional step: the safetensors in the model files must be reindexed to ensure that fused modules (qkv, gate_up) end up in the same safetensors files, which allows `model_free_ptq` to fuse global scales.
+
+First, apply `llmcompressor.reindex_fused_weights` from the command line entrypoint:
+```bash
+llmcompressor.reindex_fused_weights \
+    unsloth/Kimi-K2-Thinking-BF16 \
+    Kimi-K2-Thinking-BF16-reindexed \
+    --num_workers=10
+```
+
+Then, call `model_free_ptq` on the reindexed files:
+```python
+model_free_ptq(
+    model_stub="Kimi-K2-Thinking-BF16-reindexed",
+    save_directory="Kimi-K2-Thinking-BF16-NVFP4A16",
+    scheme="NVFP4A16",
+    ignore=[
+        "re:.*gate$",
+        "lm_head",
+        "re:.*kv_a_proj_with_mqa$",
+        "re:.*q_a_proj$",
+        "model.embed_tokens",
+    ],
+    max_workers=15,
+    device="cuda:0",
+)
+```
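To make the reindexing requirement concrete, the sketch below illustrates conceptually (not as llm-compressor's actual implementation) why fused projections need to share one global scale under NVFP4: vLLM merges q/k/v and gate/up at load time, and the merged tensor can only carry a single FP32 global scale, so that scale must be computed over the combined weights. The `nvfp4_global_scale` helper and the scaling convention are assumptions for illustration.

```python
# Conceptual sketch only -- not llm-compressor's implementation.
# Assumes the common NVFP4 convention: one FP32 global scale per tensor,
# sized so per-block FP8 (E4M3) scales and FP4 (E2M1) values stay in range.
import torch

FP4_E2M1_MAX = 6.0
FP8_E4M3_MAX = 448.0


def nvfp4_global_scale(weight: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: larger dynamic range -> smaller global scale
    return (FP8_E4M3_MAX * FP4_E2M1_MAX) / weight.abs().amax()


q, k, v = (torch.randn(256, 128) for _ in range(3))

# Quantized separately, each projection would get its own global scale ...
separate = [nvfp4_global_scale(w) for w in (q, k, v)]

# ... but vLLM fuses q/k/v into one qkv_proj at load time, which carries only
# one global scale. The fused scale must cover the combined range, which is
# why model_free_ptq needs all three weights available in the same file.
fused = nvfp4_global_scale(torch.cat([q, k, v], dim=0))
assert fused <= min(separate)
```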

examples/model_free_ptq/kimi_k2_thinking_fp8_block.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from llmcompressor import model_free_ptq
 
 MODEL_ID = "unsloth/Kimi-K2-Thinking-BF16"
-SAVE_DIR = "Kimi-K2-Thinking-FP8-Block"
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-BLOCK"
 
 # Apply FP8-Block to the model
 # Once quantized, the model is saved
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+"""
+NOTE: Please run the following script before using `model_free_ptq`
+
+This script is used to reindex the safetensors files of a model such that all fused
+modules (gate_up, qkv) are in the same safetensors file. This is required by
+model_free_ptq for microscale schemes (NVFP4A16, MXFP4A16)
+
+llmcompressor.reindex_fused_weights \
+    unsloth/Kimi-K2-Thinking-BF16 \
+    Kimi-K2-Thinking-BF16-reindexed \
+    --num_workers=10
+"""
+
+from llmcompressor import model_free_ptq
+
+MODEL_ID = "unsloth/Kimi-K2-Thinking-BF16"
+REINDEX_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-reindexed"
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4A16"
+
+# See above notice pertaining to safetensors reindexing
+# After running `llmcompressor.reindex_fused_weights`,
+# use `model_free_ptq` to apply NVFP4A16 quantization
+model_free_ptq(
+    model_stub=REINDEX_DIR,
+    save_directory=SAVE_DIR,
+    scheme="NVFP4A16",
+    ignore=[
+        "re:.*gate$",
+        "lm_head",
+        "re:.*kv_a_proj_with_mqa$",
+        "re:.*q_a_proj$",
+        "model.embed_tokens",
+    ],
+    max_workers=15,
+    device="cuda:0",
+)

setup.py

Lines changed: 1 addition & 0 deletions
@@ -188,6 +188,7 @@ def localversion_func(version: ScmVersion) -> str:
     entry_points={
         "console_scripts": [
             "llmcompressor.trace=llmcompressor.transformers.tracing.debug:main",
+            "llmcompressor.reindex_fused_weights=llmcompressor.entrypoints.model_free.reindex_fused_weights:main",
         ]
     },
     python_requires=">=3.10",

src/llmcompressor/entrypoints/model_free/__init__.py

Lines changed: 20 additions & 62 deletions
@@ -7,27 +7,28 @@
 import torch
 import tqdm
 from compressed_tensors.quantization import QuantizationScheme
-from compressed_tensors.utils.match import _match_name
 from loguru import logger
-from safetensors.torch import load_file, save_file
 
-from llmcompressor.entrypoints.model_free.helpers import (
-    gpu_if_available,
-    validate_scheme,
-)
-from llmcompressor.entrypoints.model_free.lifecycle import (
-    calibrate_weights,
-    compress_module,
-    initialize_quantized_linear,
+from llmcompressor.entrypoints.model_free.helpers import gpu_if_available
+from llmcompressor.entrypoints.model_free.microscale import (
+    is_microscale_scheme,
 )
 from llmcompressor.entrypoints.model_free.model_utils import (
     get_checkpoint_files,
     is_weights_file,
 )
+from llmcompressor.entrypoints.model_free.process import (
+    process_file,
+    process_file_microscale_scheme,
+)
 from llmcompressor.entrypoints.model_free.save_utils import (
     update_config,
     update_safetensors_index,
 )
+from llmcompressor.entrypoints.model_free.validate import (
+    validate_safetensors_index,
+    validate_scheme,
+)
 
 __all__ = ["model_free_ptq"]
 
@@ -55,20 +56,24 @@ def model_free_ptq(
     model_files = get_checkpoint_files(model_stub)
     scheme_name, scheme = validate_scheme(scheme)
     device = gpu_if_available(device)
+    validate_safetensors_index(model_files, scheme)
 
     # 0. collect safetensors files, copy files
     jobs = []
-    for file_path, resolved_path in model_files:
+    job_fn = (
+        process_file
+        if not is_microscale_scheme(scheme)
+        else process_file_microscale_scheme
+    )
+    for file_path, resolved_path in model_files.items():
         save_path = Path(save_directory) / file_path
 
         if file_path.endswith("safetensors"):
-            jobs.append(
-                (_process_file, resolved_path, save_path, scheme, ignore, device)
-            )
+            jobs.append((job_fn, resolved_path, save_path, scheme, ignore, device))
 
         else:
            if is_weights_file(file_path):
-                logger.warning(f"Skipping weights file {file_path}")
+                logger.warning(f"Skip processing for weights file {file_path}")
            save_path.parent.mkdir(parents=True, exist_ok=True)
            logger.info(f"Copying {file_path} {save_path}")
            shutil.copyfile(resolved_path, save_path)
@@ -89,50 +94,3 @@ def model_free_ptq(
     # 5. update config and safetensors index
     update_config(save_directory, scheme_name, scheme, ignore)
     update_safetensors_index(save_directory, total_size, weight_map)
-
-
-def _process_file(
-    file_path: str | os.PathLike,
-    save_path: str | os.PathLike,
-    scheme: QuantizationScheme,
-    ignore: str | list[str],
-    device: str | torch.device,
-) -> tuple[int, dict[str, str]]:
-    """
-    Quantize and compress tensors in a given safetensors file
-
-    :param file_path: safetensors file to process
-    :param save_path: save path of file with quantized weights
-    :param scheme: quantization scheme to apply to tensors
-    :param ignore: modules to ignore. Modules ending with "norm" are automatically
-        ignored
-    :param device: device used to quantize and compress weights
-    """
-    tensors = load_file(file_path)
-
-    for name in list(tensors.keys()):
-        module_name, param_name = name.rsplit(".", 1)
-        is_linear_weight = param_name == "weight" and not module_name.endswith("norm")
-        is_ignored = any(_match_name(module_name, ign) for ign in ignore)
-        if not is_linear_weight or is_ignored:
-            continue
-
-        # 1. initialize module with qparams (on device)
-        module = initialize_quantized_linear(tensors[name], scheme, device)
-
-        # 2. calibrate weight qparams
-        calibrate_weights(module)
-
-        # 3. compress module using qparams
-        compress_module(module)
-
-        # 4. save compressed data (on cpu)
-        del tensors[name]
-        prefix = module_name + "."
-        for key, value in module.state_dict(prefix=prefix).items():
-            tensors[key] = value.to("cpu")
-
-    save_file(tensors, save_path)
-    total_size = sum(tensor.nbytes for tensor in tensors.values())
-    weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
-    return total_size, weight_map
src/llmcompressor/entrypoints/model_free/helpers.py

Lines changed: 81 additions & 50 deletions
@@ -1,51 +1,25 @@
-from typing import Optional
+import os
+from collections import defaultdict
+from typing import Mapping, TypeVar
 
 import torch
-from compressed_tensors.quantization import QuantizationScheme, preset_name_to_scheme
-from compressed_tensors.utils import getattr_chain
 from compressed_tensors.utils.match import _match_name
 from loguru import logger
+from transformers.file_utils import CONFIG_NAME
 
-__all__ = ["validate_scheme", "gpu_if_available", "is_match_name"]
+__all__ = [
+    "gpu_if_available",
+    "find_safetensors_index_path",
+    "find_config_path",
+    "find_safetensors_index_file",
+    "match_names_set_eager",
+    "MatchedNamesSet",
+    "invert_mapping",
+]
 
-
-def validate_scheme(scheme: QuantizationScheme) -> tuple[str, QuantizationScheme]:
-    # treat strings as preset schemes
-    if isinstance(scheme, str):
-        scheme_name, scheme = scheme, preset_name_to_scheme(scheme, [])
-    else:
-        scheme_name = "config_group_0"
-
-    # weight quantization must be provided
-    if scheme.weights is None:
-        raise ValueError(
-            "Must provide a weights quanitization scheme to perform weights-only PTQ"
-        )
-
-    # activation quantization must be dynamic
-    input_dynamic = getattr_chain(scheme, "input_activations.dynamic", True)
-    output_dynamic = getattr_chain(scheme, "output_activations.dynamic", True)
-    if input_dynamic is not True or output_dynamic is not True:
-        raise ValueError(
-            "Model Free PTQ cannot calibrate activations. "
-            "Please use `oneshot` instead."
-        )
-
-    # override with static observers
-    # Remove after https://github.com/vllm-project/compressed-tensors/pull/489
-    if scheme.weights.observer in ("minmax", "mse"):
-        new_observer = f"static_{scheme.weights.observer}"
-        logger.warning(
-            f"Scheme uses {scheme.weights.observer} weight observer. "
-            f"Using {new_observer} instead"
-        )
-        scheme.weights.observer = new_observer
-
-    # target all modules; filter by ignore list
-    # technically this should be "re:.*", but vllm's
-    # ct moe layer has a hard coded check for "Linear"
-    scheme.targets = ["Linear"]
-    return scheme_name, scheme
+KeyType = TypeVar("K")
+ValueType = TypeVar("V")
+MatchedNamesSet = dict[str, str | None]
 
 
 def gpu_if_available(device: torch.device | str | None) -> torch.device:
@@ -63,13 +37,70 @@ def gpu_if_available(device: torch.device | str | None) -> torch.device:
     return torch.device("cpu")
 
 
-def is_match_name(
-    name: str, targets: list[str], ignore: Optional[str | list[str]] = None
-) -> bool:
-    targets = targets if isinstance(targets, list) else [targets]
-    ignore = ignore if isinstance(ignore, list) else [ignore]
+def find_safetensors_index_path(save_directory: str | os.PathLike) -> str | None:
+    for file_name in os.listdir(save_directory):
+        if file_name.endswith("safetensors.index.json"):
+            return os.path.join(save_directory, file_name)
+
+    return None
+
+
+def find_config_path(save_directory: str | os.PathLike) -> str | None:
+    for file_name in os.listdir(save_directory):
+        if file_name in (CONFIG_NAME, "params.json"):
+            return os.path.join(save_directory, file_name)
+
+    return None
+
+
+def find_safetensors_index_file(model_files: dict[str, str]) -> str | None:
+    for file_path, resolved_path in model_files.items():
+        if file_path.endswith("safetensors.index.json"):
+            return resolved_path
+
+    return None
+
+
+def match_names_set_eager(
+    names: set[str] | list[str],
+    targets: set[str] | list[str],
+    return_unmatched: bool = True,
+) -> list[MatchedNamesSet] | tuple[list[MatchedNamesSet], MatchedNamesSet]:
+    matched_sets = []
+    matches = dict.fromkeys(targets, None)
+
+    for name in names:
+        # match until we get a full set
+        for target in targets:
+            if _match_name(name, target):
+                if matches[target] is None:
+                    matches[target] = name
+                else:
+                    # matched target twice without completing a set
+                    raise ValueError(
+                        f"Matched a {target} twice before "
+                        f"completing set ({matches[target]}, {name})"
+                    )
+
+        # once we have a full set, yield and reset
+        if all((matches[target] is not None for target in targets)):
+            matched_sets.append(matches)
+            matches = dict.fromkeys(targets, None)
+
+    unmatched_set = matches if any((v is not None for v in matches.values())) else None
+
+    if return_unmatched:
+        return matched_sets, unmatched_set
+    else:
+        return matched_sets
+
+
+def invert_mapping(
+    mapping: Mapping[KeyType, ValueType],
+) -> dict[ValueType, list[KeyType]]:
+    inverse = defaultdict(list)
 
-    matches_target = any(_match_name(name, target) for target in targets)
-    matches_ignore = any(_match_name(name, ign) for ign in ignore)
+    for key, value in mapping.items():
+        inverse[value].append(key)
 
-    return matches_target and not matches_ignore
+    return inverse
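A short usage sketch for the two new helpers above (illustrative only; the tensor names are made up, and the import path assumes these land in `llmcompressor.entrypoints.model_free.helpers` as the imports earlier in the diff indicate). `match_names_set_eager` greedily groups names into complete target sets and reports any incomplete remainder, while `invert_mapping` turns a weight_map-style dict into file -> [tensor names].

```python
# Illustrative usage of the helpers added above (tensor names are made up).
from llmcompressor.entrypoints.model_free.helpers import (
    invert_mapping,
    match_names_set_eager,
)

names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.0.self_attn.v_proj.weight",
    "model.layers.1.self_attn.q_proj.weight",
]
targets = ["re:.*q_proj.*", "re:.*k_proj.*", "re:.*v_proj.*"]

# layer 0 forms a complete q/k/v set; layer 1's lone q_proj is left unmatched
matched_sets, unmatched = match_names_set_eager(names, targets)
assert len(matched_sets) == 1 and unmatched["re:.*q_proj.*"] is not None

# invert a weight_map (tensor name -> file) into file -> [tensor names]
weight_map = {
    "a.weight": "model-00001.safetensors",
    "b.weight": "model-00001.safetensors",
}
assert invert_mapping(weight_map) == {
    "model-00001.safetensors": ["a.weight", "b.weight"]
}
```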
