Repro: #5890
$ mpirun -np 1 -x NVFUSER_DUMP=pre_segmenter_logging pytest tests/python/multidevice/test_alphafold3.py -k outgoing --only-mpi -vs
The code of interest, from Fuser/tests/python/direct/test_alphafold3.py, lines 123 to 133 at commit 77abd29:
match direction:
    case Direction.OUTGOING:
        # z_out = einsum("bikc,bjkc->bijc", a, b)
        a = fd.ops.permute(a, [0, 3, 1, 2])  # [b, c, i, k]
        b = fd.ops.permute(b, [0, 3, 2, 1])  # [b, c, k, j]
    case Direction.INCOMING:
        # z_out = einsum("bkic,bkjc->bijc", a, b)
        a = fd.ops.permute(a, [0, 3, 2, 1])  # [b, c, i, k]
        b = fd.ops.permute(b, [0, 3, 1, 2])  # [b, c, k, j]
z = fd.ops.matmul(a, b)  # [b, c, i, j]
z = fd.ops.permute(z, [0, 2, 3, 1])  # [b, i, j, c]
The current heuristic for forward propagation is to prefer the second input (usually the weight). Therefore, in the einsum output, j is sharded by DIDy, not DIDx. This breaks the backward propagation from z_in (i sharded by DIDy, j by DIDx) to the einsum output, because z_in wants j to be sharded by DIDx instead.
By the way, this is not a problem for "incoming" mode. Following the current heuristic, the einsum output does have j sharded on DIDx.
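To make the mismatch concrete, here is a minimal sketch in plain Python (not the nvFuser sharding API) of the "prefer the second input" heuristic. The function name `propagate_from_second_input` and the dict-based sharding representation are hypothetical, and the per-dimension shardings assumed for the second input `b` are chosen only to be consistent with the behavior described above:

```python
# Minimal, standalone sketch (plain Python dicts, not the nvFuser sharding API)
# of the "prefer the second input" forward-propagation heuristic. The input
# shardings below are assumptions; only the conclusions about the output's j
# axis come from the issue description.
def propagate_from_second_input(b_dims, b_sharding, out_dims):
    """Each output dim inherits whatever mesh axis it has on the second input."""
    return {d: b_sharding[d] for d in out_dims if d in b_sharding}

# The downstream z_in = [b, i, j, c] is sharded i -> DIDy, j -> DIDx.
z_in_sharding = {"i": "DIDy", "j": "DIDx"}

# OUTGOING: z_out = einsum("bikc,bjkc->bijc", a, b).
# Assume b's j axis is sharded on DIDy; propagation then puts j on DIDy in z_out.
out_outgoing = propagate_from_second_input("bjkc", {"j": "DIDy", "k": "DIDx"}, "bijc")
assert out_outgoing["j"] != z_in_sharding["j"]  # mismatch: propagation to z_in breaks

# INCOMING: z_out = einsum("bkic,bkjc->bijc", a, b).
# Assume b's j axis is sharded on DIDx; propagation then puts j on DIDx in z_out.
out_incoming = propagate_from_second_input("bkjc", {"k": "DIDy", "j": "DIDx"}, "bijc")
assert out_incoming["j"] == z_in_sharding["j"]  # consistent with z_in
```

Under these assumed input shardings, only the outgoing mode produces a j sharding that conflicts with z_in, matching the observation above.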
cc @DejunL