
Conversation

@Priya2698 (Collaborator)

No description provided.

@Priya2698 Priya2698 changed the title Pm/swizzle ring Swizzle for ring-based overlap Jan 28, 2026
github-actions bot commented Jan 28, 2026

Review updated until commit 4116369

Description

  • Implement Swizzle IR node for ring-based stream overlap

  • Add cyclic shift computation for device-aware stream indexing

  • Create swizzle operation API for TensorView transformations

  • Add comprehensive test for swizzle with parallel types

  • Integrate swizzle handling in multi-device utilities

Changes walkthrough

Relevant files

Enhancement

evaluator.cpp (csrc/host_ir/evaluator.cpp) — Handle swizzled stream axes in evaluator

  • Modified the ShardByStream handler to support swizzled stream axes
  • Added cyclic shift computation: in_index = (out_index + device_id) % num_devices
  • Extract the parallel type from the swizzle definition and compute a device-aware index
  • Updated tensor chunking to use the computed index instead of the raw stream_index
  • +25/-3

ir.cpp (csrc/host_ir/ir.cpp) — Define Swizzle IR node implementation

  • Implement the Swizzle class with input/output IterDomains and a ParallelType
  • Add constructor validation for HostIrContainer registration
  • Provide string representation methods (toString, toInlineString)
  • Define clone and create methods for IR node management
  • +32/-0

ops.cpp (csrc/host_ir/ops.cpp) — Create swizzle operation API

  • Add a swizzle() function taking a TensorView, an axis, and a ParallelType
  • Create a Swizzle IR node connecting input/output IterDomains
  • Modify the TensorView loop domain to incorporate the swizzle transformation
  • Return the modified TensorView for operation chaining
  • +15/-0

utils.cpp (csrc/multidevice/utils.cpp) — Handle swizzle in multi-device utilities

  • Add a Swizzle case to the getProducingLogicalAxis function
  • Traverse through swizzle operations to find the source IterDomain
  • Include swizzle in the unexpected-transform error handling
  • +3/-0

ir.h (csrc/host_ir/ir.h) — Declare Swizzle IR node interface

  • Declare the Swizzle class inheriting from Expr with its full interface
  • Define a constructor taking IterDomains and a ParallelType
  • Provide accessors for the input/output IterDomains and the parallel type
  • Include standard IR node methods (clone, create, string conversion)
  • +36/-0

ops.h (csrc/host_ir/ops.h) — Declare swizzle operation function

  • Declare the swizzle() function in the hir namespace
  • Function signature: TensorView* swizzle(TensorView*, int64_t, ParallelType)
  • +2/-0

Tests

test_multidevice_host_ir.cpp (tests/cpp/test_multidevice_host_ir.cpp) — Add comprehensive swizzle test

  • Add a SwizzleWithParallelType test case for comprehensive validation
  • Set up a HostIrContainer with a device mesh and parallelization
  • Apply swizzle operations on TensorViews with the DIDx parallel type
  • Test ring-based overlap behavior with device-aware indexing
  • Verify the output matches the expected sharded and zeroed tensor
  • +78/-0

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Index Calculation Logic

The cyclic shift calculation index = (index + team_index) % team_size implements the ring-based overlap. It should be validated against the behavior described in the PR title, including edge cases such as team_size == 1 and team_index == 0.

    index = (index + team_index) % team_size;
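To make those edge cases concrete, here is a minimal Python sketch of the same modular arithmetic. This is illustrative only, not nvFuser code; the function name is hypothetical, and only the formula mirrors the PR.

```python
# Illustrative sketch of the cyclic shift used by the swizzled ShardByStream
# handler: in_index = (out_index + team_index) % team_size.
def swizzled_index(out_index: int, team_index: int, team_size: int) -> int:
    return (out_index + team_index) % team_size

# Edge cases called out above:
assert swizzled_index(0, 0, 1) == 0                         # team_size == 1: always chunk 0
assert all(swizzled_index(i, 0, 4) == i for i in range(4))  # team_index == 0: identity
# Device 1 of 4 visits its peers' chunks in ring order:
assert [swizzled_index(i, 1, 4) for i in range(4)] == [1, 2, 3, 0]
```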
Swizzle Operation Safety

The swizzle function modifies the loop domain in place and returns the input TensorView. This mutation pattern should be reviewed for thread safety and consistency with other operations in the codebase.

    in->setLoopDomain(loop_domain);

Transform Traversal Completeness

The addition of swizzle handling in getProducingLogicalAxis should be reviewed to ensure all transform paths are covered and no edge cases are missed in the traversal logic.

    id = swizzle->in();

@Priya2698 Priya2698 marked this pull request as ready for review January 30, 2026 21:28
@Priya2698 Priya2698 requested a review from wujingyue January 30, 2026 21:28

greptile-apps bot (Contributor) commented Jan 30, 2026

Greptile Overview

Greptile Summary

This PR adds swizzle support for ring-based overlap in host IR by introducing a new Swizzle class and operation. The implementation enables a cyclic shift of stream indices across devices using the formula in_index = (out_index + device_id) % num_devices.

Key changes:

  • Added a new hir::Swizzle IR node with input/output IterDomains and a ParallelType attribute
  • Implemented hir::swizzle() to apply the swizzle transformation to TensorView loop domains
  • Extended the ShardByStream handler to compute cyclic index shifts when the stream axis is defined by a swizzle
  • Added swizzle case handling in getProducingLogicalAxis for multidevice utilities
  • Included a comprehensive test validating swizzle with parallel types

Critical issue found:

  • The .chunk() method call in csrc/host_ir/evaluator.cpp:834 passes 3 arguments, but the ATen API only accepts 2 (chunks, dim), causing a compilation failure

Confidence Score: 1/5

  • This PR cannot be merged due to a critical syntax error that will cause a compilation failure
  • The .chunk() call in the evaluator has the wrong number of arguments (3 instead of 2), which will prevent the code from compiling. While the swizzle implementation itself appears sound, this error is blocking.
  • csrc/host_ir/evaluator.cpp requires immediate attention to fix the chunk() call

Important Files Changed

  csrc/host_ir/evaluator.cpp — added swizzle support for ring-based overlap with cyclic shift logic, but the chunk() call has the wrong number of arguments
  csrc/host_ir/ir.h — added the new Swizzle class definition with accessor methods
  csrc/host_ir/ir.cpp — implemented the Swizzle constructor and string methods
  csrc/host_ir/ops.h — added the swizzle function declaration
  csrc/host_ir/ops.cpp — implemented the swizzle function that creates the swizzle operation and updates the loop domain
  csrc/multidevice/utils.cpp — added handling for hir::Swizzle in getProducingLogicalAxis
  tests/cpp/test_multidevice_host_ir.cpp — added a comprehensive test for swizzle with parallel types

Sequence Diagram

sequenceDiagram
    participant Client
    participant HostIrEvaluator
    participant ShardByStream
    participant Swizzle
    participant Tensor

    Client->>HostIrEvaluator: handle(ShardByStream*)
    HostIrEvaluator->>ShardByStream: get out_tv
    HostIrEvaluator->>ShardByStream: find Stream axis in allocation_domain
    HostIrEvaluator->>ShardByStream: evaluate stream_index

    alt stream_id has Swizzle definition
        HostIrEvaluator->>Swizzle: check if definition is Swizzle
        HostIrEvaluator->>Swizzle: get ParallelType
        HostIrEvaluator->>Tensor: get device mesh
        HostIrEvaluator->>Tensor: compute team_index from device_id
        Note over HostIrEvaluator: Apply cyclic shift:<br/>index = (stream_index + team_index) % team_size
    end

    HostIrEvaluator->>Tensor: chunk input tensor
    HostIrEvaluator->>Tensor: extract chunk at computed index
    HostIrEvaluator->>HostIrEvaluator: bind out_tv to result tensor

greptile-apps bot (Contributor) left a comment

7 files reviewed, 1 comment

Comment on lines 832 to 838 (csrc/host_ir/evaluator.cpp):

    at::Tensor out_tensor =
        in_tensor
            .chunk(
                stream_id->extent()->evaluate().as<int64_t>(),
                index,
                getShardedLogicalAxis(out_tv, ParallelType::Stream))
    -       .at(stream_index);
    +       .at(index);

greptile-apps bot (Contributor):

Incorrect number of arguments passed to .chunk(): the method signature is chunk(int64_t chunks, int64_t dim) (2 arguments), but 3 arguments are being passed here.

Suggested change:

    at::Tensor out_tensor =
        in_tensor
            .chunk(
                stream_id->extent()->evaluate().as<int64_t>(),
                getShardedLogicalAxis(out_tv, ParallelType::Stream))
            .at(index);
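For intuition, here is a plain-Python analogue of the corrected call sequence. This is an illustrative sketch, not the ATen API itself: chunk splits a tensor into equal pieces along a dimension, and the handler then selects the piece at the cyclically shifted index.

```python
# Plain-Python analogue of tensor.chunk(chunks, dim) on a 1-D sequence,
# followed by selecting the cyclically shifted chunk. Names are illustrative.
def chunk_1d(values, chunks):
    size = len(values) // chunks  # assumes an evenly divisible length
    return [values[i * size:(i + 1) * size] for i in range(chunks)]

data = list(range(8))
pieces = chunk_1d(data, 4)       # 4 stream chunks: [0,1], [2,3], [4,5], [6,7]
index = (2 + 1) % 4              # stream_index = 2 shifted by team_index = 1
assert pieces[index] == [6, 7]   # the device reads its shifted chunk
```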

wujingyue (Collaborator) left a comment

LGTM otherwise

        }
    };

    class Swizzle : public Expr {

wujingyue (Collaborator):

Suggested change:

    -class Swizzle : public Expr {
    +class Swizzle1D : public Expr {

This should be moved to csrc/ir/internal_nodes.h. It's fine to not use it for kernel generation for now, but the Expr itself is agnostic and will in fact be used by preseg.

        return outputs().at(0)->as<IterDomain>();
    }

    ParallelType pt() const {

wujingyue (Collaborator):

Suggested change:

    -ParallelType pt() const {
    +ParallelType parallelType() const {

    namespace nvfuser::hir {

    TensorView* swizzle(TensorView* in, int64_t axis, ParallelType pt) {

wujingyue (Collaborator):

TensorView::swizzle1d

    mesh.multiDimensionalIndexOf(communicator_->deviceId());
    auto pt_axis = mesh.parallelTypeToAxis(pt);
    int64_t team_index = md_index[pt_axis].item<int64_t>();
    index = (index + team_index) % team_size;

wujingyue (Collaborator):

Note: consider putting + and % into the HostIrContainer in the future. See also Fuser/csrc/swizzle.h, lines 26 to 52 in 77abd29:

    NVF_API std::pair<Val*, Val*> dispatchSwizzle(
        Swizzle2DType type,
        Val* x,
        Val* y,
        Val* maybe_size_x,
        Val* maybe_size_y);
    NVF_API std::pair<Val*, Val*> dispatchSwizzle(
        SwizzleType type,
        Val* x,
        Val* y,
        Val* maybe_size_x,
        Val* maybe_size_y);
    NVF_API std::pair<Val*, Val*> dispatchUnSwizzle(
        Swizzle2DType type,
        Val* x,
        Val* y,
        Val* maybe_size_x,
        Val* maybe_size_y);
    NVF_API std::pair<Val*, Val*> dispatchUnSwizzle(
        SwizzleType type,
        Val* x,
        Val* y,
        Val* maybe_size_x,
        Val* maybe_size_y);
