diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index 24eab78c14fc..710382668d17 100644
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -142,11 +142,11 @@ def _build_checkpoint_conversion_mapping():
     if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
         mapping["legacy"] += [
             WeightRenaming(
-                source_patterns="weight_g",
+                source_patterns=r"weight_g$",
                 target_patterns="parametrizations.weight.original0",
             ),
             WeightRenaming(
-                source_patterns="weight_v",
+                source_patterns=r"weight_v$",
                 target_patterns="parametrizations.weight.original1",
             ),
         ]
diff --git a/src/transformers/integrations/fp_quant.py b/src/transformers/integrations/fp_quant.py
index ccf933796165..af7821786d6c 100644
--- a/src/transformers/integrations/fp_quant.py
+++ b/src/transformers/integrations/fp_quant.py
@@ -13,6 +13,10 @@
 # limitations under the License.
 "FP-Quant integration file"

+from typing import Optional
+
+import torch
+
 from ..utils import (
     is_fp_quant_available,
 )
@@ -24,6 +28,94 @@ from transformers.utils.quantization_config import FPQuantConfig
+from ..core_model_loading import ConversionOps
+from ..quantizers.quantizers_utils import get_module_from_name
+
+
+class FpQuantQuantize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self,
+        input_dict: dict[str, torch.Tensor],
+        model: Optional[torch.nn.Module] = None,
+        missing_keys: Optional[set[str]] = None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        target_key, value = tuple(input_dict.items())[0]
+        value = value[0]
+        # Loading master weights or an unquantized checkpoint
+        weight = torch.nn.Parameter(value)
+        module, _ = get_module_from_name(model, target_key)
+        module.weight = weight
+
+        # Let module.pre_forward() handle the quantization and set None where necessary.
+        # This operation quantizes the weights internally.
+        with torch.cuda.device(value.device):
+            module.pre_forward()
+
+        prefix_target_key = target_key.rsplit(".", 1)[0]
+
+        # These keys are created inside module.pre_forward(), so they are no longer missing and we remove them
+        missing_keys.discard(target_key)
+        missing_keys.discard(f"{prefix_target_key}.backward_hadamard_matrix")
+        missing_keys.discard(f"{prefix_target_key}.forward_hadamard_matrix")
+        missing_keys.discard(f"{prefix_target_key}.act_global_scale")
+        missing_keys.discard(f"{prefix_target_key}.weight_global_scale")
+        missing_keys.discard(f"{prefix_target_key}.qweight")
+        missing_keys.discard(f"{prefix_target_key}.scales")
+        missing_keys.discard(f"{prefix_target_key}.dqweight")
+        return {}
+
+
+class FpQuantDeserialize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self,
+        input_dict: dict[str, torch.Tensor],
+        model: Optional[torch.nn.Module] = None,
+        full_layer_name: Optional[str] = None,
+        missing_keys: Optional[set[str]] = None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        target_key, value = tuple(input_dict.items())[0]
+        value = value[0] if isinstance(value, list) else value
+        module, _ = get_module_from_name(model, target_key)
+        # The module holds either:
+        # * `weight` when `store_master_weights=True`
+        # * `qweight` and `scales` when `store_master_weights=False` and `pseudoquantization=False`
+        # * `dqweight` when `store_master_weights=False` and `pseudoquantization=True`
+        if target_key == ".qweight":
+            # Loading a real quantized checkpoint without master weights
+            qweight = torch.nn.Parameter(
+                value,
+                requires_grad=False,
+            )
+
+            return {
+                ".qweight": qweight,
+                # Because of the way the FPQuantLinear module is designed, these parameters are expected
+                # in the model even though they are not used, so we set them to empty zero tensors
+                ".weight": torch.nn.Parameter(torch.zeros(0)),
+                ".dqweight": torch.nn.Parameter(torch.zeros(0)),
+            }
+
+        if target_key == ".dqweight":
+            # Loading a pseudo-quantized checkpoint without master weights
+            dqweight = torch.nn.Parameter(value)
+
+            return {
+                ".dqweight": dqweight,
+                # Because of the way the FPQuantLinear module is designed, these parameters are expected
+                # in the model even though they are not used, so we set them to empty zero tensors
+                ".weight": torch.nn.Parameter(torch.zeros(0)),
+                ".qweight": torch.nn.Parameter(torch.zeros(0)),
+                ".scales": torch.nn.Parameter(torch.zeros(0)),
+            }
+
 def adapt_fp_quant_config(config: FPQuantConfig):
     if config.forward_dtype == "mxfp4":
diff --git a/src/transformers/quantizers/quantizer_fp_quant.py b/src/transformers/quantizers/quantizer_fp_quant.py
index f9d66986a2b4..8b21b8a16694 100644
--- a/src/transformers/quantizers/quantizer_fp_quant.py
+++ b/src/transformers/quantizers/quantizer_fp_quant.py
@@ -120,3 +120,31 @@ def is_trainable(self, model: Optional["PreTrainedModel"] = None):

     def is_serializable(self, **kwargs):
         return True
+
+    def get_quantize_ops(self):
+        from ..integrations.fp_quant import FpQuantQuantize
+
+        return FpQuantQuantize(self)
+
+    def get_weight_conversions(self):
+        from ..core_model_loading import WeightConverter
+        from ..integrations.fp_quant import FpQuantDeserialize
+
+        if self.pre_quantized:
+            if self.quantization_config.pseudoquantization:
+                return [
+                    WeightConverter(
+                        source_patterns=[".dqweight"],
+                        target_patterns=".dqweight",
+                        operations=[FpQuantDeserialize(self)],
+                    ),
+                ]
+            else:
+                return [
+                    WeightConverter(
+                        source_patterns=[".qweight"],
+                        target_patterns=".qweight",
+                        operations=[FpQuantDeserialize(self)],
+                    ),
+                ]
+        return []
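
For context, a minimal usage sketch of the two loading paths this patch wires up. It is not part of the patch: the checkpoint ids are hypothetical, and the FPQuantConfig flags shown are only those referenced in the code above (pseudoquantization, store_master_weights), so treat the exact arguments as assumptions.

# Sketch, assuming FPQuantConfig is exported from the top-level transformers namespace
# and accepts the flags referenced in the patch above; model ids are hypothetical.
from transformers import AutoModelForCausalLM, FPQuantConfig

# Path 1: quantize an unquantized checkpoint on the fly. Each FPQuantLinear weight goes
# through FpQuantQuantize.convert(), which calls module.pre_forward() to build qweight,
# scales, the Hadamard matrices and the global scales, then drops them from missing_keys.
model = AutoModelForCausalLM.from_pretrained(
    "org/some-model",  # hypothetical unquantized checkpoint
    quantization_config=FPQuantConfig(pseudoquantization=False),
    device_map="cuda",
)

# Path 2: load an already-quantized checkpoint. get_weight_conversions() routes ".qweight"
# (or ".dqweight" when pseudoquantization=True) through FpQuantDeserialize, which fills the
# unused weight/dqweight/scales slots with empty placeholder parameters.
quantized = AutoModelForCausalLM.from_pretrained(
    "org/some-model-fpquant",  # hypothetical pre-quantized checkpoint
    device_map="cuda",
)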