Conversation

Contributor

@minettekaum minettekaum commented Dec 12, 2025

Description

Adding the ring_attn algorithm.

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

I ran the tests

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@cursor cursor bot left a comment

Comment "@cursor review" or "bugbot run" to trigger another review on this PR.

seed_t = sync_tensor(seed_t, dim=0, group=None)
seed_t = seed_t.chunk(world_size, dim=0)[0]
seed = seed_t.item()
seed -= torch.iinfo(torch.int64).min

Bug: Incorrect seed calculation produces excessively large values

The seed calculation subtracts torch.iinfo(torch.int64).min (which equals -2^63) from the seed, effectively adding 2^63. Since torch.randint already produces non-negative values in [0, 2^63 - 1), the subtraction shifts the seed into [2^63, 2^64 - 1), which is far larger than necessary. This appears unintentional: the seed is already suitable for manual_seed() without the transformation, and the extra arithmetic could cause overflow or unexpected behavior in the random number generator.
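
For illustration, here is a minimal sketch of a shared-seed setup without that offset. This is not the PR's code: sync_tensor is replaced by a plain torch.distributed broadcast, and the helper name make_shared_seed is made up.

import torch
import torch.distributed as dist

def make_shared_seed(group=None):
    # Hypothetical helper: every rank draws a candidate seed, then rank 0's
    # value is broadcast so the whole group agrees on it.
    seed_t = torch.randint(0, torch.iinfo(torch.int64).max, (1,), dtype=torch.int64)
    dist.broadcast(seed_t, src=0, group=group)
    # torch.randint already returns a non-negative value, so it can be passed
    # to torch.manual_seed directly; no shift by torch.iinfo(torch.int64).min is needed.
    return seed_t.item()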

torch.Tensor
The gradient of the output tensor.
"""
return ring_attention._scaled_dot_product_ring_flash_attention_backward(*args, **kwargs)

Bug: Incomplete backward pass missing saved tensors for gradient computation

The LocalFunc autograd function's backward method is incomplete. The forward method never calls ctx.save_for_backward() to store the tensors needed for gradient computation (mesh, query, key, value, output, lse). The backward method only receives the gradient outputs via *args and passes them straight to _scaled_dot_product_ring_flash_attention_backward, but that function typically requires the original inputs and outputs to compute the input gradients. As written, the backward pass would fail during training due to incorrect arguments or missing data.
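
To illustrate the pattern the comment is pointing at (not the PR's LocalFunc or the ring-attention kernels), here is a self-contained torch.autograd.Function that saves its inputs in forward and reads them back in backward; the op itself is a simple stand-in.

import torch

class ScaledDotSketch(torch.autograd.Function):
    # Minimal example of the ctx.save_for_backward / ctx.saved_tensors pattern;
    # the operation (query @ key^T) is a placeholder, not ring flash attention.
    @staticmethod
    def forward(ctx, query, key):
        scores = query @ key.transpose(-2, -1)
        # Save the inputs the backward pass needs to form the gradients.
        ctx.save_for_backward(query, key)
        return scores

    @staticmethod
    def backward(ctx, grad_scores):
        query, key = ctx.saved_tensors
        grad_query = grad_scores @ key
        grad_key = grad_scores.transpose(-2, -1) @ query
        return grad_query, grad_key

q = torch.randn(2, 4, 8, requires_grad=True)
k = torch.randn(2, 4, 8, requires_grad=True)
ScaledDotSketch.apply(q, k).sum().backward()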

Member

@johannaSommer johannaSommer left a comment

LGTM! There is one hook missing in smash.py that checks for this ring attention algorithm and spawns the distributed server, otherwise no notes 🌻
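
A rough, hypothetical shape for such a hook, purely to illustrate the suggestion: the function names, the "distributer" config key, and the use of torch.multiprocessing below are assumptions, not the project's actual smash.py API.

import torch.multiprocessing as mp

def _distributed_server(rank: int, world_size: int) -> None:
    # Placeholder entry point; a real server would initialize the process
    # group and coordinate the ring-attention workers.
    pass

def maybe_spawn_distributed_server(config: dict, world_size: int) -> None:
    # Only spawn the server when the ring_attn distributer is selected.
    if config.get("distributer") == "ring_attn":
        mp.spawn(_distributed_server, args=(world_size,), nprocs=world_size)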

@minettekaum changed the title from Feat/distributer-algorithm: adding ring_attn to Feat: distributer-algorithm on Dec 19, 2025
@minettekaum changed the title from Feat: distributer-algorithm to Feat: add distributer algorithm on Dec 19, 2025
@minettekaum changed the title from Feat: add distributer algorithm to feat: add distributer algorithm on Dec 19, 2025