Skip to content

Conversation

@rogeryoungh
Copy link

What does this PR do?

This PR adds MiniMax-M2 model to Hugging Face Transformers from MiniMaxAI.

Relevant Links:

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @Cyrilvallez

xuebi added 7 commits October 31, 2025 14:17
Signed-off-by: xuebi <xuebi@minimaxi.com>
Signed-off-by: xuebi <xuebi@minimaxi.com>
Signed-off-by: xuebi <xuebi@minimaxi.com>
Signed-off-by: xuebi <xuebi@minimaxi.com>
Signed-off-by: xuebi <xuebi@minimaxi.com>
Signed-off-by: xuebi <xuebi@minimaxi.com>
Signed-off-by: xuebi <xuebi@minimaxi.com>
Copy link
Contributor

@molbap molbap left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very clean integration, no particular comments. Thank you! cc @Cyrilvallez for core review

rogeryoungh and others added 2 commits November 6, 2025 12:42
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Signed-off-by: xuebi <xuebi@minimaxi.com>
@molbap
Copy link
Contributor

molbap commented Nov 13, 2025

Failing tests on the hub are unrelated timeouts. Thanks for the update!

@molbap molbap requested a review from ArthurZucker November 13, 2025 14:37
@Qubitium
Copy link
Contributor

Qubitium commented Nov 19, 2025

@ArthurZucker lets gets this quickly merged. 2 weeks is infinity in the current model wars cycle.

Copy link
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some comments but looks already very nice, good job! Some things have changed on main so you will likely need to add our new weights converter to this model as well.

Imo, the integration tests seem suspicious to me and might not have been checked?

Comment on lines 232 to 251
class MiniMaxM2Experts(MixtralExperts):
pass


class MiniMaxM2SparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, config):
nn.Module.__init__(self)
self.top_k = config.num_experts_per_tok
self.jitter_noise = config.router_jitter_noise
self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
self.experts = MiniMaxM2Experts(config)
self.register_buffer("e_score_correction_bias", torch.zeros(config.num_local_experts))

def route_tokens_to_experts(self, router_logits):
routing_weights = torch.nn.functional.sigmoid(router_logits.float())
scores_for_choice = routing_weights + self.e_score_correction_bias
_, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False)
top_k_weights = routing_weights.gather(1, top_k_index)
top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
return top_k_index, top_k_weights.to(router_logits.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We recently changed how we build MoE, can you add this to our weight converter and adjust respectively; see

mapping["phimoe"] = mapping["mixtral"].copy()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This possibly also affects the tp plan

Copy link
Contributor

@Qubitium Qubitium Nov 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vasqu @rogeryoungh I am not minimax staff but still have concerns myself.

Since HF Transformer changes broke this PR in terms of compat, what is exaclty is

def _build_checkpoint_conversion_mapping(): doing? What is WeightConverter?

I see no method level documentation and only internal code comments which looks like list generation to assist in concat/fusing of MoE modules? But how everything is glued together is still unclear.

Minimax internally does not use HF transformers stack and I think they would have the same question as I do. What exactly is the latest v5 changes regarding MoE? Is there an overview doc to make this clear?

@ArthurZucker Can we get a global (even in alpha form) doc/readme on the MoE arch changes in v5 so model makers has a clear eagle eye view of the v5 target when it comes to MoE?

Copy link
Contributor

@vasqu vasqu Nov 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried to answer here a bit #42028 (review)

  • _build_checkpoint_conversion_mapping is providing a mapping for model_type -> weight renaming+conversion
  • These operations can be split into 2 types
    • Renaming, should be self-explanatory --> rename weights
    • WeightConverter --> change the original weights in some way (spread out in core_model_loading
      • Chunk --> split the weights along a dimension into x chunks, e.g. useful for qkv splitting into q, k, and v
      • Concatenate --> essentially the reverse of chunk, e.g. fusing q, k, v into qkv
      • MergeModulelist --> merge a module list of tensors into one big tensor along a new dimension, e.g. when a module list should be one big parameter instead of a list (useful for us to merge MoE parameters and add hf kernels on top for more efficient inference etc)
      • PermuteForRope --> more complicated but essentially related to how weights are rotated, e.g. original llama needs this for hf RoPE conversion

You can stack these operations as well but we continue to change/add these.

The standard for MoE currently is to have fused experts into one big parameter instead of a modulelist. We use these to enable kernels, e.g.

@use_kernel_forward_from_hub("Llama4TextMoe")

This is for efficiency reasons in essence.

@slow
@require_torch_accelerator
def test_small_model_logits(self):
model_id = "hf-internal-testing/MiniMaxM2-tiny"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@molbap have you uploaded a model there? That's definitely the way to go but we should double-check and probably adjust numbers

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haven't! good catch

@vasqu
Copy link
Contributor

vasqu commented Nov 19, 2025

@Qubitium sorry for the delays, will be taking over for the most part. Crunching some important features for v5 so some PRs got lost under it, apologies

@rogeryoungh
Copy link
Author

I've encountered a strange problem:

>>> model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "../transformers/src/transformers/models/auto/auto_factory.py", line 373, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/modeling_utils.py", line 278, in _wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/modeling_utils.py", line 3967, in from_pretrained
    device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/integrations/accelerate.py", line 407, in _get_device_map
    device_map = infer_auto_device_map(
                 ^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/integrations/accelerate.py", line 736, in infer_auto_device_map
    current_buffer_size = compute_module_total_buffer_size(module, hf_quantizer)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/integrations/accelerate.py", line 260, in compute_module_total_buffer_size
    module_sizes, _ = compute_module_sizes(model, hf_quantizer, buffers_only=True)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/integrations/accelerate.py", line 240, in compute_module_sizes
    dtype_size = hf_quantizer.param_element_size(model, name)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/quantizers/base.py", line 182, in param_element_size
    return model.get_parameter_or_buffer(param_name).element_size()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1964, in __getattr__
    raise AttributeError(
AttributeError: 'MiniMaxM2DecoderLayer' object has no attribute 'get_parameter_or_buffer'

I found a solution #42342, which involves adding a get_parameter_or_buffer function to the MiniMaxM2RotaryEmbedding. However, for MiniMax-M2, e_score_correction_bias also needs to be changed to nn.Parameter.

@@ -106,7 +106,10 @@ class MiniMaxM2SparseMoeBlock(nn.Module):
         self.jitter_noise = config.router_jitter_noise
         self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
         self.experts = MiniMaxM2Experts(config)
-        self.register_buffer("e_score_correction_bias", torch.zeros(config.num_local_experts))
+        self.e_score_correction_bias = nn.Parameter(
+            torch.zeros(config.num_local_experts),
+            requires_grad=False,
+        )
 
     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -168,6 +171,20 @@ class MiniMaxM2RotaryEmbedding(nn.Module):
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = inv_freq
 
+    # Add a compatibility method so callers expecting PreTrainedModel-like API don't crash.
+    def get_parameter_or_buffer(self, name: str):
+        # Prefer direct attribute access (parameters and buffers are attributes)
+        if hasattr(self, name):
+            return getattr(self, name)
+        # Fallback: search named parameters and buffers (non-recursive to keep semantics)
+        for n, p in self.named_parameters(recurse=False):
+            if n == name:
+                return p
+        for n, b in self.named_buffers(recurse=False):
+            if n == name:
+                return b
+        raise AttributeError(f"{self.__class__.__name__} has no parameter or buffer named '{name}'")
+
     @staticmethod
     def compute_default_rope_parameters(
         config: Optional[MiniMaxM2Config] = None,

@MekkCyber
Copy link
Contributor

MekkCyber commented Nov 24, 2025

Hi @rogeryoungh, we are taking care of this in : https://github.com/huggingface/transformers/pull/42289/files#diff-7f40070336f6d7b1ffe08e654cdf930080e3cbd4dbbcbee2996fabe4ffc1c2b3, this should fix the issue without adding a get_parameter_or_buffer

Copy link
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some comments which are mostly revolving around the weights name/convention. See #41580 for the respective PR that introduced this

The essence is that we now support on the fly renaming and weight conversion (similar to vllm). You only have to specify your operations in

def _build_checkpoint_conversion_mapping():

In your case, this will look very similar to Mixtral I assume, i.e. block_sparse->mlp + fusing the MoE modulelists to parameters. Maybe additional mlp renaming for the base MLP to inherit from LlamaMLP. Please let us know if it doesn't work as expected, it's still WIP in a way and we work on providing better docs for this (cc @ArthurZucker).

I know this is a lot at once so apologies.



@require_torch
class MiniMaxM2IntegrationTest(unittest.TestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you provide us a slimmed down model version, i.e. ideally running on A10 GPUs (24GB)? So that we can have some integration tests? I can move this model/copy it to our internal repos

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote a script to generate small models, but I encountered a problem: AutoModelForCausalLM.from_config does not follow the quantization_config settings, so the generated models do not have FP8Linear layers.
https://github.com/rogeryoungh/MiniMaxM2TinyModelGenerator/blob/main/generator.py

Copy link
Contributor

@vasqu vasqu Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would initialize random weights, no? I think we can just cut the original safetensors and load only a partial subset of the original instead (e.g. 8 layers), wdyt? This should keep the fp8 weights

Otherwise, even if it is saved in bf16 for example, we should be able to quantize to fp8 on the fly. Not ideal but the goal here is to have a working version.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Native fp8 weights for models initialized from config cc @MekkCyber @SunMarc

@rogeryoungh
Copy link
Author

I verified that PR #42289 resolved the previous issue.

Now I've encountered a new problem: when loading weights, I get the following error. Can this be resolved using WeightRenaming?

Key                                                       | Status     | Details
----------------------------------------------------------+------------+--------
model.layers.{0...61}.mlp.experts.down_proj_scale_inv     | UNEXPECTED |        
model.layers.{0...61}.mlp.experts.gate_up_proj_scale_inv  | UNEXPECTED |        
model.layers.{0...61}.mlp.experts.gate_up_proj_scales_inv | MISSING    |        
model.layers.{0...61}.mlp.experts.down_proj_scales_inv    | MISSING    |        

Notes:
- UNEXPECTED    :can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING       :those params were newly initialized because missing form the checkpoint. Consider training on your downstream task.

@vasqu
Copy link
Contributor

vasqu commented Nov 26, 2025

I'm a bit surprised that you get

model.layers.{0...61}.mlp.experts.gate_up_proj_scales_inv | MISSING    |        
model.layers.{0...61}.mlp.experts.down_proj_scales_inv    | MISSING    |     

Don't you have

self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))

This seems weird to me, how do you load the model? But yes, that can work as well but I'm surprised about the naming because it isn't in the modeling file at all - can you double check?

@rogeryoungh
Copy link
Author

rogeryoungh commented Nov 26, 2025

image

This was a part of FP8Linear.

if self.weight.element_size() == 1:
scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
self.weight_scale_inv = nn.Parameter(
torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
)
else:
self.register_parameter("weight_scale_inv", None)

@molbap
Copy link
Contributor

molbap commented Nov 26, 2025

Could it be linked with what we discussed yesterday @MekkCyber , just in case?

@vasqu
Copy link
Contributor

vasqu commented Nov 26, 2025

#42434 should fix your issue @rogeryoungh

Sorry didn't notice we have native fp8 in this model 🤦

@MekkCyber
Copy link
Contributor

Hi @rogeryoungh ! Can you try with the pr @vasqu shared please 🤗 ! Thank you

Copy link
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My last few important comments, looks very good already! Sorry about all the issues during the PR with fp8 and the buffer 😓



@require_torch
class MiniMaxM2IntegrationTest(unittest.TestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^

@rogeryoungh
Copy link
Author

When syncing branches, I encountered some issues related to FP8Linear:

Key                                                      | Status     |  | 
---------------------------------------------------------+------------+--+-
model.layers.{0...61}.mlp.experts.gate_up_proj_scale_inv | UNEXPECTED |  | 
model.layers.{0...61}.mlp.experts.down_proj_scale_inv    | UNEXPECTED |  | 

Notes:
- UNEXPECTED    :can be ignored when loading from different task/architecture; not ok if you expect identical arch.
Some parameters are on the meta device because they were offloaded to the cpu.

I also encountered a strange error, it seems that FP8Linear is not working correctly here?

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "../transformers/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/generation/utils.py", line 2684, in generate
    result = decoding_method(
             ^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/generation/utils.py", line 2877, in _sample
    outputs = self._prefill(input_ids, generation_config, model_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/generation/utils.py", line 3853, in _prefill
    return self(**model_inputs, return_dict=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/.venv/lib/python3.12/site-packages/accelerate/hooks.py", line 175, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/utils/generic.py", line 764, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/models/minimax_m2/modeling_minimax_m2.py", line 671, in forward
    outputs: MoeModelOutputWithPast = self.model(
                                      ^^^^^^^^^^^
  File "../transformers/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/utils/generic.py", line 919, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/models/minimax_m2/modeling_minimax_m2.py", line 508, in forward
    hidden_states = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "../transformers/src/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/.venv/lib/python3.12/site-packages/accelerate/hooks.py", line 175, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/models/minimax_m2/modeling_minimax_m2.py", line 412, in forward
    hidden_states = self.mlp(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/.venv/lib/python3.12/site-packages/accelerate/hooks.py", line 175, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/models/minimax_m2/modeling_minimax_m2.py", line 124, in forward
    hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/.venv/lib/python3.12/site-packages/accelerate/hooks.py", line 175, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/models/minimax_m2/modeling_minimax_m2.py", line 80, in forward
    gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Float8_e4m3fn

@MekkCyber
Copy link
Contributor

Sorry about that @rogeryoungh ! can you check with this pr: #42654, it's a small fix

Copy link
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks left a few comments, assuming that the other PR fixed the fp8 issue

It looks very ready, let's just focus on adding integration tests then we're good to go imo



@require_torch
class MiniMaxM2IntegrationTest(unittest.TestCase):
Copy link
Contributor

@vasqu vasqu Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would initialize random weights, no? I think we can just cut the original safetensors and load only a partial subset of the original instead (e.g. 8 layers), wdyt? This should keep the fp8 weights

Otherwise, even if it is saved in bf16 for example, we should be able to quantize to fp8 on the fly. Not ideal but the goal here is to have a working version.



@require_torch
class MiniMaxM2IntegrationTest(unittest.TestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Native fp8 weights for models initialized from config cc @MekkCyber @SunMarc

@rogeryoungh
Copy link
Author

rogeryoungh commented Dec 8, 2025

I used the following code to load the model from the configuration file:

# ... modify default config and save
config = AutoConfig.from_pretrained(
    save_folder,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_config(config)

# full code: https://github.com/rogeryoungh/MiniMaxM2TinyModelGenerator/blob/main/generator.py

But I received the following error:

Traceback (most recent call last):
  File "../MiniMaxM2TinyModelGenerator/generator.py", line 49, in <module>
    model = AutoModelForCausalLM.from_config(config)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/models/auto/auto_factory.py", line 237, in from_config
    return model_class._from_config(config, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/modeling_utils.py", line 250, in _wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/modeling_utils.py", line 1438, in _from_config
    model = cls(config, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/models/minimax_m2/modeling_minimax_m2.py", line 617, in __init__
    self.model = MiniMaxM2Model(config)
                 ^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/models/minimax_m2/modeling_minimax_m2.py", line 458, in __init__
    self.rotary_emb = MiniMaxM2RotaryEmbedding(config=config)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/models/minimax_m2/modeling_minimax_m2.py", line 160, in __init__
    self.rope_type = self.config.rope_parameters["rope_type"]
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../transformers/src/transformers/configuration_utils.py", line 198, in __getattribute__
    return super().__getattribute__(key)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'MiniMaxM2Config' object has no attribute 'rope_parameters'

@vasqu
Copy link
Contributor

vasqu commented Dec 8, 2025

@rogeryoungh I think the auto mapping is getting confused here, you should rather directly use MiniMaxM2Config from your current PR. At least that seems to work somewhat. Note that this was used for BC loading but not saving so you might need to check if some values are unwanted, e.g. rope_scaling will also live beside rope_parameters, layernorm_full_attention_beta, etc. --> We should have a config that is only having the relevant fields (if that's ok) for us.

Also the automap should not consist in the config unless you want to use remote code and need to point the respective files (which shouldn't be the case anymore after merging this).

Last side note: I think the way you load the generation config it will load all our values (including those that don't divert from defaults). Imo, you can keep the generation config purely as is and just move the raw json.

Edit: If it helps you can also do something along

from transformers import MiniMaxM2Config, MiniMaxM2Config

config = MiniMaxM2Config.from_pretrained("MiniMaxAI/MiniMax-M2") # ideally just MiniMaxM2Config() as those should have the respective values already 
config.num_hidden_layers = 7  # check the size, something up to 7-8B should be ok
model = MiniMaxM2Config.from_pretrained(
    "MiniMaxAI/MiniMax-M2",
    config=config,  # we cut the model into x layers, e.g. here 7
    device_map="auto",
    dtype="auto",
    # ... (quants if you need etc)
)
model.save_pretrained("Small-MiniMax-M2")

@github-actions
Copy link
Contributor

github-actions bot commented Dec 9, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, minimax_m2

@rogeryoungh
Copy link
Author

According to my tests, from_pretrained does not reduce the number of layers in the model, indicating that the configuration is not being used.

I found a solution: generate the model twice.

  1. Load the model using from_config, which saves the model with the correct structure but without the correct FP8Linear layers.
  2. Use from_pretrained again, which allows me to save the correct quantized results.

My complete code is here: https://github.com/rogeryoungh/MiniMaxM2TinyModelGenerator/blob/main/generator.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants