Swizzle for ring-based overlap #5889
base: main
Conversation

No description provided.
Review updated until commit 4116369

Description

Relevant files:
- Enhancement
- Tests
PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 PR contains tests

⚡ Recommended focus areas for review
Index Calculation Logic
`index = (index + team_index) % team_size` implements the ring-based overlap. This should be validated against the expected behavior described in the PR title, making sure it correctly handles edge cases such as `team_size == 1` or `team_index == 0`.
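To make those edge cases concrete, here is a minimal standalone sketch of the cyclic shift (illustrative only, not the PR's code; `ringIndex` is a made-up name):

```cpp
#include <cassert>
#include <cstdint>

// Illustrative stand-in for the PR's `index = (index + team_index) % team_size`.
int64_t ringIndex(int64_t index, int64_t team_index, int64_t team_size) {
  assert(team_size > 0);
  return (index + team_index) % team_size;
}

int main() {
  assert(ringIndex(0, 0, 1) == 0); // team_size == 1: the shift is a no-op
  assert(ringIndex(2, 0, 4) == 2); // team_index == 0: identity permutation
  assert(ringIndex(3, 1, 4) == 0); // general case: wraps around the ring
  return 0;
}
```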
Greptile Overview

Greptile Summary

This PR adds swizzle support for ring-based overlap in host IR by introducing a new `Swizzle` expression.

Key changes:

- Adds a `Swizzle` expression (`class Swizzle : public Expr`) whose output is an `IterDomain` and which carries a `ParallelType`.
- Adds a `TensorView* swizzle(TensorView* in, int64_t axis, ParallelType pt)` helper in `nvfuser::hir`.
- `HostIrEvaluator::handle(ShardByStream*)` detects a `Swizzle`-defined Stream axis and applies `index = (stream_index + team_index) % team_size` before chunking.

Critical issue found: `.chunk()` is called with three arguments, but `at::Tensor::chunk` takes two (`chunks`, `dim`); see the inline comment below.
Confidence Score: 1/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Client
    participant HostIrEvaluator
    participant ShardByStream
    participant Swizzle
    participant Tensor
    Client->>HostIrEvaluator: handle(ShardByStream*)
    HostIrEvaluator->>ShardByStream: get out_tv
    HostIrEvaluator->>ShardByStream: find Stream axis in allocation_domain
    HostIrEvaluator->>ShardByStream: evaluate stream_index
    alt stream_id has Swizzle definition
        HostIrEvaluator->>Swizzle: check if definition is Swizzle
        HostIrEvaluator->>Swizzle: get ParallelType
        HostIrEvaluator->>Tensor: get device mesh
        HostIrEvaluator->>Tensor: compute team_index from device_id
        Note over HostIrEvaluator: Apply cyclic shift:<br/>index = (stream_index + team_index) % team_size
    end
    HostIrEvaluator->>Tensor: chunk input tensor
    HostIrEvaluator->>Tensor: extract chunk at computed index
    HostIrEvaluator->>HostIrEvaluator: bind out_tv to result tensor
```
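For intuition (not from the PR): the cyclic shift makes each device start on its own shard and then walk the ring, so no two devices request the same chunk at the same step, which is what enables the overlap. A tiny sketch that prints the per-step chunk order:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // A device with team_index d visits chunk (s + d) % team_size at stream
  // step s; each column below is a permutation of the devices' chunks.
  const int64_t team_size = 4;
  for (int64_t d = 0; d < team_size; ++d) {
    std::printf("device %lld:", static_cast<long long>(d));
    for (int64_t s = 0; s < team_size; ++s) {
      std::printf(" %lld", static_cast<long long>((s + d) % team_size));
    }
    std::printf("\n");
  }
  return 0;
}
```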
7 files reviewed, 1 comment
```diff
 at::Tensor out_tensor =
     in_tensor
         .chunk(
             stream_id->extent()->evaluate().as<int64_t>(),
+            index,
             getShardedLogicalAxis(out_tv, ParallelType::Stream))
-        .at(stream_index);
+        .at(index);
```
Incorrect number of arguments passed to `.chunk()`: the method signature is `chunk(int chunks, int dim)` (2 args), but 3 arguments are being passed here.
Suggested change:

```diff
 at::Tensor out_tensor =
     in_tensor
         .chunk(
             stream_id->extent()->evaluate().as<int64_t>(),
-            index,
             getShardedLogicalAxis(out_tv, ParallelType::Stream))
         .at(index);
```
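As a sanity check on the suggested fix, here is a hedged standalone example of ATen's two-argument `chunk(chunks, dim)` followed by `.at(index)` (illustrative values, not from the PR):

```cpp
#include <ATen/ATen.h>
#include <vector>

int main() {
  // chunk(chunks, dim) splits the tensor into `chunks` views along `dim`;
  // .at(index) then selects one chunk, mirroring the fixed code above.
  at::Tensor t = at::arange(12).reshape({4, 3});
  std::vector<at::Tensor> parts = t.chunk(/*chunks=*/4, /*dim=*/0);
  at::Tensor chunk2 = parts.at(2); // the third 1x3 slice
  (void)chunk2;
  return 0;
}
```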
wujingyue left a comment
LGTM otherwise
```cpp
  }
};

class Swizzle : public Expr {
```
Suggested change:

```diff
-class Swizzle : public Expr {
+class Swizzle1D : public Expr {
```
This should be moved to csrc/ir/internal_nodes.h. It's fine not to use it for kernel generation for now, but the Expr itself is agnostic and will in fact be used by preseg.
```cpp
    return outputs().at(0)->as<IterDomain>();
  }

  ParallelType pt() const {
```
Suggested change:

```diff
-  ParallelType pt() const {
+  ParallelType parallelType() const {
```
```cpp
namespace nvfuser::hir {

TensorView* swizzle(TensorView* in, int64_t axis, ParallelType pt) {
```
TensorView::swizzle1d
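Reading this terse comment as suggesting a member-function spelling (my interpretation, not stated in the thread), here is a hypothetical, standalone illustration of the two spellings, with nvFuser's types stubbed out:

```cpp
#include <cstdint>

// Stub types standing in for nvFuser's; illustration only.
enum class ParallelType { Stream };
struct TensorView {
  // Member spelling the comment appears to suggest:
  TensorView* swizzle1d(int64_t axis, ParallelType pt) {
    (void)axis;
    (void)pt;
    return this;
  }
};

// Free-function spelling, as in this PR:
TensorView* swizzle(TensorView* in, int64_t axis, ParallelType pt) {
  return in->swizzle1d(axis, pt);
}

int main() {
  TensorView tv;
  TensorView* a = swizzle(&tv, 0, ParallelType::Stream); // free function
  TensorView* b = tv.swizzle1d(0, ParallelType::Stream); // member function
  (void)a;
  (void)b;
  return 0;
}
```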
```cpp
      mesh.multiDimensionalIndexOf(communicator_->deviceId());
  auto pt_axis = mesh.parallelTypeToAxis(pt);
  int64_t team_index = md_index[pt_axis].item<int64_t>();
  index = (index + team_index) % team_size;
```
Note: consider putting the `+` and `%` into the HostIrContainer in the future. See also

Lines 26 to 52 in 77abd29
```cpp
NVF_API std::pair<Val*, Val*> dispatchSwizzle(
    Swizzle2DType type,
    Val* x,
    Val* y,
    Val* maybe_size_x,
    Val* maybe_size_y);
NVF_API std::pair<Val*, Val*> dispatchSwizzle(
    SwizzleType type,
    Val* x,
    Val* y,
    Val* maybe_size_x,
    Val* maybe_size_y);
NVF_API std::pair<Val*, Val*> dispatchUnSwizzle(
    Swizzle2DType type,
    Val* x,
    Val* y,
    Val* maybe_size_x,
    Val* maybe_size_y);
NVF_API std::pair<Val*, Val*> dispatchUnSwizzle(
    SwizzleType type,
    Val* x,
    Val* y,
    Val* maybe_size_x,
    Val* maybe_size_y);
```
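A rough sketch of what moving the `+` and `%` into the container could look like, building the shift as IR values instead of evaluating it in C++. `IrBuilder::addExpr`/`IrBuilder::modExpr` are assumptions here and may not match the actual API:

```cpp
// Hedged fragment, not standalone: assumes nvFuser's IR headers and that
// IrBuilder::addExpr / IrBuilder::modExpr exist with these signatures.
Val* buildRingIndex(Val* stream_index, Val* team_index, Val* team_size) {
  // (stream_index + team_index) % team_size, expressed as IR nodes so the
  // arithmetic lives in the HostIrContainer rather than in the evaluator.
  Val* shifted = IrBuilder::addExpr(stream_index, team_index);
  return IrBuilder::modExpr(shifted, team_size);
}
```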