-
Notifications
You must be signed in to change notification settings - Fork 46
feat: Add RoPE Embedding challenge (Medium) #136
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@23silicon Can you make the spec the correct format, add more functional tests, and add JAX starter code |
| import jax | ||
| import jax.numpy as jnp | ||
|
|
||
| #Q, cos, sin are tensors on the GPU |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need space between # and Q
|
|
||
| #Q, cos, sin are tensors on the GPU | ||
| @jax.jit | ||
| def solve(Q: array, cos: array, sin: array, output: array, M: int, D: int) -> jax.Array: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
jax.Array, not array
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no output. It's returned
Description
This PR adds a new Medium difficulty challenge: Rotary Positional Embedding (RoPE).
RoPE is a critical component in modern LLMs (Llama 2/3, Mistral, etc.). I noticed that Unsloth recently added optimized RoPE kernels to their main repo, so I thought it would be a great addition to LeetGPU for users to practice implementing this fundamental operation.
Challenge Details
challenges/medium/61_rope_embeddingchallenge.py.Verification
challenge.pyreference implementation verified against standard formula.