4 changes: 2 additions & 2 deletions src/transformers/conversion_mapping.py
@@ -142,11 +142,11 @@ def _build_checkpoint_conversion_mapping():
    if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
        mapping["legacy"] += [
            WeightRenaming(
-               source_patterns="weight_g",
+               source_patterns=r"weight_g$",
                target_patterns="parametrizations.weight.original0",
            ),
            WeightRenaming(
-               source_patterns="weight_v",
+               source_patterns=r"weight_v$",
                target_patterns="parametrizations.weight.original1",
            ),
        ]
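Reviewer note on the regex change above: FP-Quant introduces parameter names such as `weight_global_scale` (discarded from the missing keys in the quantize op below), which a bare substring pattern like `weight_g` would also match. A minimal sketch of the effect of the `$` anchor, assuming the renaming patterns are matched `re.search`-style (an assumption, not the documented WeightRenaming contract):

# Illustrative only: why the anchored pattern leaves FP-Quant keys alone.
import re

keys = [
    "conv.weight_g",                      # legacy weight-norm parameter -> should be renamed
    "mlp.down_proj.weight_global_scale",  # FP-Quant parameter -> should be left untouched
]
for key in keys:
    print(key, "->", "renamed" if re.search(r"weight_g$", key) else "kept")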
92 changes: 92 additions & 0 deletions src/transformers/integrations/fp_quant.py
@@ -13,6 +13,10 @@
# limitations under the License.
"FP-Quant integration file"

from typing import Optional

import torch

from ..utils import (
    is_fp_quant_available,
)
@@ -24,6 +28,94 @@

from transformers.utils.quantization_config import FPQuantConfig

from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import get_module_from_name


class FpQuantQuantize(ConversionOps):
    """Quantize full-precision weights on the fly by delegating to the module's `pre_forward()`."""

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(
        self,
        input_dict: dict[str, torch.Tensor],
        model: Optional[torch.nn.Module] = None,
        missing_keys: Optional[set[str]] = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        target_key, value = tuple(input_dict.items())[0]
        value = value[0]
        # Loading master weights or an unquantized checkpoint
        weight = torch.nn.Parameter(value)
        module, _ = get_module_from_name(model, target_key)
        module.weight = weight

        # Let pre_forward() handle the quantization and set tensors to None where necessary;
        # this quantizes the weights internally.
        with torch.cuda.device(value.device):
            module.pre_forward()

        prefix_target_key = target_key.rsplit(".", 1)[0]

        # These keys are populated inside module.pre_forward(), so they are not actually missing;
        # drop them from the missing-keys set.
        missing_keys.discard(target_key)
        missing_keys.discard(f"{prefix_target_key}.backward_hadamard_matrix")
        missing_keys.discard(f"{prefix_target_key}.forward_hadamard_matrix")
        missing_keys.discard(f"{prefix_target_key}.act_global_scale")
        missing_keys.discard(f"{prefix_target_key}.weight_global_scale")
        missing_keys.discard(f"{prefix_target_key}.qweight")
        missing_keys.discard(f"{prefix_target_key}.scales")
        missing_keys.discard(f"{prefix_target_key}.dqweight")
        return {}
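For context on the `ConversionOps` contract used here: `input_dict` carries a single target key mapped to the list of tensors gathered for it, and the empty return value means nothing is written back, because `pre_forward()` already materialized the quantized buffers on the module. A rough usage sketch with illustrative names (`hf_quantizer`, `model`, `bf16_weight` are placeholders, and the weight tensor is assumed to live on a CUDA device):

# Hedged sketch, not actual loader code.
op = FpQuantQuantize(hf_quantizer)
missing = {
    "model.layers.0.mlp.down_proj.qweight",
    "model.layers.0.mlp.down_proj.scales",
}
out = op.convert(
    {"model.layers.0.mlp.down_proj.weight": [bf16_weight]},  # unquantized checkpoint tensor
    model=model,
    missing_keys=missing,
)
assert out == {}          # nothing to write back
assert len(missing) == 0  # pre_forward() produced the quantized buffers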


class FpQuantDeserialize(ConversionOps):
    """Load already-quantized FP-Quant checkpoints (real or pseudo-quantized) into FPQuantLinear modules."""

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(
        self,
        input_dict: dict[str, torch.Tensor],
        model: Optional[torch.nn.Module] = None,
        full_layer_name: Optional[str] = None,
        missing_keys: Optional[set[str]] = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        target_key, value = tuple(input_dict.items())[0]
        value = value[0] if isinstance(value, list) else value
        module, _ = get_module_from_name(model, target_key)
        # The module holds either:
        # * `weight` when `store_master_weights=True`
        # * `qweight` and `scales` when `store_master_weights=False` and `pseudoquantization=False`
        # * `dqweight` when `store_master_weights=False` and `pseudoquantization=True`
        if target_key == ".qweight":
            # Loading a real quantized checkpoint without master weights
            qweight = torch.nn.Parameter(
                value,
                requires_grad=False,
            )

            return {
                ".qweight": qweight,
                # The FPQuantLinear module expects these parameters to exist even though they are
                # unused here, so we fill them with empty placeholders.
                ".weight": torch.nn.Parameter(torch.zeros(0)),
                ".dqweight": torch.nn.Parameter(torch.zeros(0)),
            }

        if target_key == ".dqweight":
            # Loading a pseudo-quantized checkpoint without master weights
            dqweight = torch.nn.Parameter(value)

            return {
                ".dqweight": dqweight,
                # The FPQuantLinear module expects these parameters to exist even though they are
                # unused here, so we fill them with empty placeholders.
                ".weight": torch.nn.Parameter(torch.zeros(0)),
                ".qweight": torch.nn.Parameter(torch.zeros(0)),
                ".scales": torch.nn.Parameter(torch.zeros(0)),
            }
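A quick summary of what the deserialize op returns per checkpoint layout; the relative keys are resolved against the layer prefix by the `WeightConverter` patterns registered in the quantizer below (`hf_quantizer`, `model`, and the tensors here are illustrative placeholders):

# Hedged sketch, assuming `model` is built with FPQuantLinear layers.
op = FpQuantDeserialize(hf_quantizer)

# Real-quantized checkpoint (store_master_weights=False, pseudoquantization=False):
out = op.convert({".qweight": [qweight_tensor]}, model=model)
sorted(out)  # ['.dqweight', '.qweight', '.weight'] -- only qweight carries data

# Pseudo-quantized checkpoint (store_master_weights=False, pseudoquantization=True):
out = op.convert({".dqweight": [dqweight_tensor]}, model=model)
sorted(out)  # ['.dqweight', '.qweight', '.scales', '.weight'] -- only dqweight carries data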


def adapt_fp_quant_config(config: FPQuantConfig):
if config.forward_dtype == "mxfp4":
38 changes: 38 additions & 0 deletions src/transformers/quantizers/quantizer_fp_quant.py
@@ -120,3 +120,41 @@ def is_trainable(self, model: Optional["PreTrainedModel"] = None):

    def is_serializable(self, **kwargs):
        return True

    def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
        from fp_quant import FPQuantLinear

        module, tensor_name = get_module_from_name(model, param_name)
        # Weight tensors of FPQuantLinear modules go through the quantizer, whether the checkpoint
        # is unquantized (`weight`) or already quantized (`qweight` / `dqweight`).
        return isinstance(module, FPQuantLinear) and tensor_name in ("weight", "qweight", "dqweight")

    def get_quantize_ops(self):
        from ..integrations.fp_quant import FpQuantQuantize

        return FpQuantQuantize(self)

    def get_weight_conversions(self):
        from ..core_model_loading import WeightConverter
        from ..integrations.fp_quant import FpQuantDeserialize

        if self.pre_quantized:
            if self.quantization_config.pseudoquantization:
                return [
                    WeightConverter(
                        source_patterns=[".dqweight"],
                        target_patterns=".dqweight",
                        operations=[FpQuantDeserialize(self)],
                    ),
                ]
            else:
                return [
                    WeightConverter(
                        source_patterns=[".qweight"],
                        target_patterns=".qweight",
                        operations=[FpQuantDeserialize(self)],
                    ),
                ]
        return []
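How the three hooks fit together at load time; the loader-side calls here are paraphrased for illustration and are not the exact transformers loading API (`quantizer`, `state_dict`, `model`, `missing_keys` are assumed names):

# Hedged sketch: checkpoint routing for FP-Quant.
if quantizer.pre_quantized:
    # Pre-quantized checkpoints: `.qweight` / `.dqweight` tensors are routed through
    # FpQuantDeserialize via the WeightConverter patterns above.
    converters = quantizer.get_weight_conversions()
else:
    # Unquantized checkpoints: full-precision weights of FPQuantLinear modules are
    # quantized on the fly by FpQuantQuantize.
    quantize_op = quantizer.get_quantize_ops()
    for name, tensor in state_dict.items():
        if quantizer.param_needs_quantization(model, name):
            quantize_op.convert({name: [tensor]}, model=model, missing_keys=missing_keys)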