
Improve sharding propagation for triangle updates outgoing #5901

@wujingyue

Description


Repro: #5890

$ mpirun -np 1 -x NVFUSER_DUMP=pre_segmenter_logging pytest tests/python/multidevice/test_alphafold3.py -k outgoing --only-mpi -vs

The code of interest:

match direction:
    case Direction.OUTGOING:
        # z_out = einsum("bikc,bjkc->bijc", a, b)
        a = fd.ops.permute(a, [0, 3, 1, 2])  # [b, c, i, k]
        b = fd.ops.permute(b, [0, 3, 2, 1])  # [b, c, k, j]
    case Direction.INCOMING:
        # z_out = einsum("bkic,bkjc->bijc", a, b)
        a = fd.ops.permute(a, [0, 3, 2, 1])  # [b, c, i, k]
        b = fd.ops.permute(b, [0, 3, 1, 2])  # [b, c, k, j]
z = fd.ops.matmul(a, b)  # [b, c, i, j]
z = fd.ops.permute(z, [0, 2, 3, 1])  # [b, i, j, c]
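
As a sanity check (not part of the issue itself), the permute-matmul-permute rewrite above is equivalent to the einsum it replaces. A quick NumPy sketch, with arbitrary sizes and names chosen purely for illustration:

import numpy as np

B, I, J, K, C = 2, 3, 4, 5, 6

# Outgoing: z_out = einsum("bikc,bjkc->bijc", a, b)
a = np.random.rand(B, I, K, C)
b = np.random.rand(B, J, K, C)
ref = np.einsum("bikc,bjkc->bijc", a, b)
z = np.matmul(a.transpose(0, 3, 1, 2), b.transpose(0, 3, 2, 1))  # [b, c, i, j]
assert np.allclose(ref, z.transpose(0, 2, 3, 1))                 # back to [b, i, j, c]

# Incoming: z_out = einsum("bkic,bkjc->bijc", a, b)
a = np.random.rand(B, K, I, C)
b = np.random.rand(B, K, J, C)
ref = np.einsum("bkic,bkjc->bijc", a, b)
z = np.matmul(a.transpose(0, 3, 2, 1), b.transpose(0, 3, 1, 2))  # [b, c, i, j]
assert np.allclose(ref, z.transpose(0, 2, 3, 1))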

[image: outgoing mode]

The current heuristic for forward propagation is to prefer the second input (usually the weight). As a result, j in the einsum output is sharded by DIDy, not DIDx. This breaks the backward propagation from z_in (i sharded by DIDy, j by DIDx) to the einsum output, because z_in wants j to be sharded by DIDx.

By the way, this is not a problem for the "incoming" mode: following the current heuristic, the einsum output does have j sharded by DIDx.

[image: incoming mode]
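
To spell out the conflict, here is a minimal, hypothetical sketch of the "prefer the second input" heuristic applied to the j dimension only. The sharding of z_in (i on DIDy, j on DIDx) comes from the text above; the j sharding assumed for the second matmul input in each mode is inferred from the observed propagation results, and all names are made up for illustration:

DIDx, DIDy = "DIDx", "DIDy"

# Backward propagation wants the einsum output to match z_in, which shards
# j on DIDx (and i on DIDy).
z_in_j = DIDx

def propagate_j(second_input_j):
    # Current forward-propagation heuristic: the matmul output copies its
    # sharding from the second input (usually the weight).
    return second_input_j

outgoing_j = propagate_j(DIDy)  # outgoing: the second input shards j on DIDy
incoming_j = propagate_j(DIDx)  # incoming: the second input shards j on DIDx

assert outgoing_j != z_in_j  # conflict: backward prop from z_in breaks
assert incoming_j == z_in_j  # consistent: incoming mode is fine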

cc @DejunL
