@ksugama commented on Feb 3, 2026

This PR addresses FIXME by flattening parameter tensors on the accelerator instead of the CPU during ZeRO stage 1 and 2 initialization. This should alleviate CPU contention, with the caveat that the optimization only applies when there is enough VRAM to allocate a full copy of the parameter buffers.
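The decision above can be sketched as follows. This is a minimal illustration, not DeepSpeed's actual code: `flatten_params` and its VRAM check are hypothetical names, and NumPy arrays stand in for accelerator tensors so the control flow is easy to follow.

```python
import numpy as np

def flatten_params(params, free_vram_bytes):
    """Return (flat_buffer, where) for a list of parameter arrays.

    Illustrative sketch: flatten on the accelerator only when there is
    enough free VRAM to hold a full copy of the parameter buffers.
    """
    needed = sum(p.nbytes for p in params)
    if free_vram_bytes >= needed:
        # Enough VRAM for a full copy: on a real accelerator this
        # concatenation would run on-device, avoiding CPU contention
        # during ZeRO stage 1/2 initialization.
        return np.concatenate([p.ravel() for p in params]), "device"
    # Otherwise fall back to flattening on the CPU.
    return np.concatenate([p.ravel() for p in params]), "cpu"
```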


If necessary, this optimization can be extended to allow a tiered system that trades off VRAM usage against performance, which might look like the following:

if enough VRAM for 2x model_size:
    naive flatten
else if enough VRAM for model_size / N:
    distributed flatten across N devices
else:
    flatten on CPU

The distributed flatten would involve each device flattening a portion of the parameters and then performing an all-gather to assemble the full flattened model. See FIXME for the original discussion.
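The distributed tier could be sketched as below. All names (`shard_bounds`, `local_shard`, `distributed_flatten`) are hypothetical, and a plain loop over ranks plus a concatenation simulates the N devices and the all-gather; real code would use device tensors and a collective such as `torch.distributed.all_gather`.

```python
import numpy as np

def shard_bounds(total, world_size, rank):
    # Even split of the flat buffer; the last rank absorbs the
    # remainder (an illustrative partitioning policy).
    base = total // world_size
    start = rank * base
    end = total if rank == world_size - 1 else start + base
    return start, end

def local_shard(params, start, end):
    """Flatten only the slice [start, end) of the global flat buffer."""
    out = np.empty(end - start, dtype=params[0].dtype)
    offset = 0   # global offset of the current param in the flat buffer
    write = 0    # write position in the local shard
    for p in params:
        flat = p.ravel()
        lo, hi = max(start, offset), min(end, offset + flat.size)
        if lo < hi:
            out[write:write + (hi - lo)] = flat[lo - offset:hi - offset]
            write += hi - lo
        offset += flat.size
    return out

def distributed_flatten(params, world_size):
    total = sum(p.size for p in params)
    # Each rank flattens its own portion; concatenating the shards
    # simulates the all-gather that assembles the full flattened model.
    shards = [local_shard(params, *shard_bounds(total, world_size, r))
              for r in range(world_size)]
    return np.concatenate(shards)
```

Each rank only needs VRAM for roughly model_size / N during the flatten, at the cost of one collective to reassemble the full buffer.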

Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
@ksugama force-pushed the flatten-tensor-gpu branch from a07a21b to 293fbab on February 3, 2026 at 17:19
@ksugama changed the title from "Z1/2 Flatten Parameters on device" to "Z1/2 init: flatten params on device" on Feb 3, 2026