@wujingyue
Collaborator

No description provided.

@github-actions

github-actions bot commented Jan 29, 2026

Review updated until commit 7928dec

Description

  • Add 3D sharding support for AlphaFold3 triangle updates with multi-device testing

  • Relax broadcast constraints to support non-1D device meshes

  • Enhance debugging with propagation logging and transform visualization

  • Reorganize code includes and add missing printer specializations

Changes walkthrough

Relevant files
Tests
test_alphafold3.py
AlphaFold3 triangle updates test with 3D sharding               

tests/python/multidevice/test_alphafold3.py

  • Add comprehensive test for AlphaFold3 triangle updates
  • Test both incoming and outgoing direction variants
  • Implement 3D device mesh sharding (dp_size × cp_size × cp_size)
  • Include layer normalization and gating components
  • +222/-0 
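The `dp_size × cp_size × cp_size` layout above can be illustrated with a small sketch. It assumes a row-major assignment of flat ranks to mesh coordinates and even divisibility of the batch and token dimensions. The helper names `mesh_coords` and `local_shape` are hypothetical and are not taken from the test file.

```python
def mesh_coords(rank, dp_size, cp_size):
    """Map a flat rank to (dp, cp_row, cp_col) coordinates.

    Assumption: ranks are laid out row-major over a
    dp_size x cp_size x cp_size mesh.
    """
    assert 0 <= rank < dp_size * cp_size * cp_size
    dp, rem = divmod(rank, cp_size * cp_size)
    cp_row, cp_col = divmod(rem, cp_size)
    return dp, cp_row, cp_col


def local_shape(batch, n_tokens, c_z, dp_size, cp_size):
    """Per-rank shape of a [batch, n_tokens, n_tokens, c_z] pair tensor.

    The batch axis is split across dp, and the two token axes are each
    split across one cp axis. The channel axis is not sharded.
    """
    assert batch % dp_size == 0 and n_tokens % cp_size == 0
    return (batch // dp_size, n_tokens // cp_size, n_tokens // cp_size, c_z)


if __name__ == "__main__":
    dp_size, cp_size = 2, 2
    for rank in range(dp_size * cp_size * cp_size):
        print(rank, mesh_coords(rank, dp_size, cp_size))
    print(local_shape(batch=2, n_tokens=8, c_z=16, dp_size=2, cp_size=2))
```

The tuple returned by `local_shape` has the same form as the `(batch_per_rank, n_tokens_per_rank, n_tokens_per_rank, c_z)` shape that the test asserts on.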
Bug fix
lower_to_communication.cpp
Relax broadcast mesh constraints

csrc/host_ir/lower_to_communication.cpp

  • Remove 1D mesh rank constraints for broadcast operations
  • Enhance error messages with tensor string representations
  • Allow non-1D sender and receiver meshes
  • +2/-10   
exact_mapped_extent_substitution.cpp
Fix exact mapping pass

csrc/preseg_passes/exact_mapped_extent_substitution.cpp

  • Remove logical_domain_map.h include
  • Enable self-mapping in IdModel construction
  • Add debug logging for exact graph state
  • +2/-2     
Enhancement
propagation.cpp
Add propagation debugging

csrc/multidevice/propagation.cpp

  • Add debug logging for sharding propagation operations
  • Include tensor references and propagation direction
  • Add missing base.h include
  • +7/-0     
propagate_shardings.cpp
Add transform debugging output

csrc/preseg_passes/propagate_shardings.cpp

  • Add debug logging for fusion transforms after propagation
  • Remove unnecessary multidevice/utils.h include
  • Print complete transform state for debugging
  • +7/-1     
utils.cpp
Add PropagateDirection printer

csrc/scheduler/utils.cpp

  • Implement operator<< for PropagateDirection enum
  • Enable printing of propagation direction for debugging
  • +16/-0   
utils.h
Fix PropagateDirection declarations

csrc/scheduler/utils.h

  • Add forward declaration of PropagateDirection operator<<
  • Fix include ordering for base.h
  • +3/-1     
Formatting
base.h
Reorganize printer specializations

csrc/base.h

  • Reorganize SPECIALIZE_PRINTER macro definitions
  • Remove unused include
  • Add missing printer specializations
  • +13/-12 

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Error Message Quality

The error messages for mesh rank validation were shortened from detailed messages to concise ones. Verify that the new messages still provide enough context to debug mesh rank mismatches in production.

    NVF_ERROR_EQ(sender_mesh.rank(), 1, "sender: ", input_tv->toString());
    NVF_ERROR_EQ(receiver_mesh.rank(), 1, "receiver: ", output_tv->toString());

Test Coverage Completeness

The test validates only the output shape and does not verify the numerical correctness of the triangle update computation. Consider checking that the sharded computation produces results equivalent to a non-sharded reference implementation.

    assert z_out.shape == (batch_per_rank, n_tokens_per_rank, n_tokens_per_rank, c_z)
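One way to act on this suggestion is to compare each device's output block against the matching slice of an unsharded reference. The sketch below is a simplified stand-in, not the PR's test: it uses plain Python lists, a single channel, and omits layer normalization and gating. The function names are hypothetical, and each `(r, c)` block stands in for one device of a `cp_size × cp_size` mesh.

```python
import math
import random


def triangle_update(a, b, direction):
    """Unsharded reference for one channel of the triangle product.

    outgoing: z[i][j] = sum_k a[i][k] * b[j][k]
    incoming: z[i][j] = sum_k a[k][i] * b[k][j]
    """
    n = len(a)
    if direction == "outgoing":
        return [[sum(a[i][k] * b[j][k] for k in range(n)) for j in range(n)]
                for i in range(n)]
    return [[sum(a[k][i] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


def triangle_update_sharded(a, b, direction, cp_size):
    """Compute the same product one (row, col) output block at a time.

    Assumption: each block owner can read the full contraction axis k,
    as it could after gathering the inputs it needs.
    """
    n = len(a)
    assert n % cp_size == 0
    blk = n // cp_size
    z = [[0.0] * n for _ in range(n)]
    for r in range(cp_size):
        for c in range(cp_size):
            for i in range(r * blk, (r + 1) * blk):
                for j in range(c * blk, (c + 1) * blk):
                    if direction == "outgoing":
                        z[i][j] = sum(a[i][k] * b[j][k] for k in range(n))
                    else:
                        z[i][j] = sum(a[k][i] * b[k][j] for k in range(n))
    return z


def allclose(x, y, tol=1e-9):
    """Element-wise comparison of two nested lists with a tolerance."""
    return all(math.isclose(p, q, rel_tol=tol, abs_tol=tol)
               for row_x, row_y in zip(x, y) for p, q in zip(row_x, row_y))


random.seed(0)
n = 4
a = [[random.random() for _ in range(n)] for _ in range(n)]
b = [[random.random() for _ in range(n)] for _ in range(n)]
for direction in ("outgoing", "incoming"):
    ref = triangle_update(a, b, direction)
    out = triangle_update_sharded(a, b, direction, cp_size=2)
    assert allclose(out, ref), direction
```

In a PyTorch-based test, the same comparison would more likely use `torch.testing.assert_close` between each rank's local output and the corresponding slice of the reference tensor.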

ID Model Configuration Change

The IdModel constructor is now called with allow_self_mapping=true, which changes how the exact graph is built. This could affect the precision of extent substitution; confirm that the change is intentional and does not introduce incorrect optimizations.

    IdModel id_model(
        fusion, /*build_graphs=*/false, /*allow_self_mapping=*/true);

Otherwise I couldn't even locate the Expr in a printed fusion.