feat: add distributer algorithm #459
base: main
Conversation
seed_t = sync_tensor(seed_t, dim=0, group=None)
seed_t = seed_t.chunk(world_size, dim=0)[0]
seed = seed_t.item()
seed -= torch.iinfo(torch.int64).min
Bug: Incorrect seed calculation produces excessively large values
The seed calculation subtracts torch.iinfo(torch.int64).min (which equals -2^63) from the seed, effectively adding 2^63. Since torch.randint already produces non-negative values in [0, 2^63-1), this subtraction results in seed values in [2^63, 2^64-1), which are extremely large. This appears unintentional - the seed is already suitable for manual_seed() without this transformation. The unnecessary arithmetic could cause overflow issues or unexpected behavior with the random number generator.
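For illustration, a minimal sketch of the intended seed synchronization without the offset. The PR's `sync_tensor` and `world_size` helpers are replaced here with a plain `torch.distributed.broadcast`, so the names and call shape below are assumptions rather than the PR's actual code:

```python
import torch
import torch.distributed as dist

def synchronized_seed(group=None) -> int:
    # Draw a non-negative seed in [0, 2**63 - 1); this range is already valid
    # for torch.manual_seed(), so no "seed -= iinfo(int64).min" offset is needed.
    seed_t = torch.randint(0, torch.iinfo(torch.int64).max, (1,), dtype=torch.int64)
    if dist.is_available() and dist.is_initialized():
        # Broadcasting rank 0's draw makes every rank agree on the same seed,
        # which matches the intent of syncing the tensor and taking chunk 0.
        dist.broadcast(seed_t, src=0, group=group)
    return int(seed_t.item())

torch.manual_seed(synchronized_seed())
```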
    torch.Tensor
        The gradient of the output tensor.
    """
    return ring_attention._scaled_dot_product_ring_flash_attention_backward(*args, **kwargs)
Bug: Incomplete backward pass missing saved tensors for gradient computation
The LocalFunc autograd function's backward method is incomplete. The forward method doesn't call ctx.save_for_backward() to save the tensors needed for gradient computation (mesh, query, key, value, output, lse). The backward method only receives gradient outputs via *args and passes them directly to _scaled_dot_product_ring_flash_attention_backward, but this function typically requires the original inputs and outputs to compute input gradients. This would cause training (backward pass) to fail with incorrect arguments or missing data.
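For reference, a minimal sketch of the usual `ctx.save_for_backward()` pattern the comment describes. `ring_fwd` and `ring_bwd` are hypothetical stand-ins for the private ring-attention kernels; their real names and argument order are defined in the PR, not here:

```python
import torch

class RingAttnFunc(torch.autograd.Function):
    """Sketch only: saves the activations in forward so backward can use them."""

    @staticmethod
    def forward(ctx, query, key, value, ring_fwd, ring_bwd):
        out, lse = ring_fwd(query, key, value)
        ctx.ring_bwd = ring_bwd                              # non-tensor state lives on ctx
        ctx.save_for_backward(query, key, value, out, lse)   # tensors needed for gradients
        return out

    @staticmethod
    def backward(ctx, grad_out):
        query, key, value, out, lse = ctx.saved_tensors
        # The backward kernel needs the saved inputs/outputs, not just grad_out.
        dq, dk, dv = ctx.ring_bwd(grad_out, query, key, value, out, lse)
        # One gradient per forward input; the two callables get None.
        return dq, dk, dv, None, None
```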
johannaSommer left a comment:
LGTM! There is one hook missing in smash.py that checks for this ring attention algorithm and spawns the distributed server; otherwise no notes 🌻
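Purely as an illustration of the kind of hook meant here, under the assumption that the smash configuration behaves like a plain dict; the config key, the algorithm name, and `spawn_distributed_server` are hypothetical placeholders, not smash.py's actual API:

```python
import multiprocessing as mp

def spawn_distributed_server(config: dict) -> mp.Process:
    # Hypothetical placeholder for whatever actually launches the server process.
    proc = mp.Process(target=print, args=("distributed server started",), daemon=True)
    proc.start()
    return proc

def maybe_spawn_distributed_server(config: dict) -> None:
    # Only spawn when the ring attention algorithm was selected.
    if config.get("distributer") == "ring_attn":
        spawn_distributed_server(config)
```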
Description
Adding the ring_attn algorithm.
Type of Change
How Has This Been Tested?
I ran the tests
Checklist