From bc3f09928e99ccb43007eeb5508d264746b0932f Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sat, 29 Nov 2025 02:50:55 +0000
Subject: [PATCH 1/3] Initial plan

From 80348052bfa50549826c3236c1f31a84293ab6a0 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sat, 29 Nov 2025 02:54:22 +0000
Subject: [PATCH 2/3] Fix E731: Convert lambda assignments to def functions

Co-authored-by: crcrpar <16191443+crcrpar@users.noreply.github.com>
---
 apex/contrib/openfold_triton/layer_norm.py | 19 +++++++++++++------
 apex/contrib/openfold_triton/mha.py        |  9 +++++++--
 examples/imagenet/main_amp.py              |  3 ++-
 pyproject.toml                             |  1 -
 4 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/apex/contrib/openfold_triton/layer_norm.py b/apex/contrib/openfold_triton/layer_norm.py
index 7d9c6242a..881137eca 100644
--- a/apex/contrib/openfold_triton/layer_norm.py
+++ b/apex/contrib/openfold_triton/layer_norm.py
@@ -46,7 +46,9 @@ def forward(ctx, inputs, normalized_shape, weight, bias, eps=1e-05):
         x_mean = torch.empty(M, dtype=torch.float32, device=inputs.device)
         y = torch.empty(inputs.shape, dtype=inputs.dtype, device=inputs.device)
 
-        grid = lambda kwargs: (triton.cdiv(kwargs["M"], kwargs["M_BLOCK"]),)
+        def grid(kwargs):
+            return (triton.cdiv(kwargs["M"], kwargs["M_BLOCK"]),)
+
         if inputs.is_contiguous():
             _layer_norm_forward[grid](
                 x_ptr=inputs,
@@ -96,7 +98,9 @@ def backward(ctx, d_y):
 
         # %% Separated kernels, similar to Inductor.
         # 1. dX.
-        grid = lambda kwargs: (triton.cdiv(kwargs["M"], kwargs["M_BLOCK"]),)
+        def grid(kwargs):
+            return (triton.cdiv(kwargs["M"], kwargs["M_BLOCK"]),)
+
         if inputs.is_contiguous():
             _layer_norm_backward_dx[grid](
                 dy_ptr=d_y,
@@ -134,10 +138,13 @@ def backward(ctx, d_y):
         M_BUFSIZE = _M_BUFSIZE_CACHE.get(key, triton.cdiv(M, PARTIAL_REDUCE_MIN))
         dw_partial_buf = torch.empty([N, M_BUFSIZE], dtype=torch.float32, device=d_y.device)
         db_partial_buf = torch.empty([N, M_BUFSIZE], dtype=torch.float32, device=d_y.device)
-        grid = lambda kwargs: (
-            triton.cdiv(M, kwargs["M_PARTIAL_REDUCE"]),
-            triton.cdiv(N, kwargs["N_BLOCK"]),
-        )
+
+        def grid(kwargs):
+            return (
+                triton.cdiv(M, kwargs["M_PARTIAL_REDUCE"]),
+                triton.cdiv(N, kwargs["N_BLOCK"]),
+            )
+
         if inputs.is_contiguous():
             _layer_norm_backward_dw_db_partial[grid](
                 dy_ptr=d_y,
diff --git a/apex/contrib/openfold_triton/mha.py b/apex/contrib/openfold_triton/mha.py
index 9065b6ca8..7f7fa0f77 100644
--- a/apex/contrib/openfold_triton/mha.py
+++ b/apex/contrib/openfold_triton/mha.py
@@ -158,6 +158,9 @@ def forward(ctx, q, k, v, mask=None, bias=None, inf=1000000000.0, is_training=Tr
         o = torch.empty_like(q)
         Z, H, N_CTX, H_DIM = q.shape
-        grid = lambda META: (triton.cdiv(N_CTX, META["BLOCK_M"]), Z * H)
+
+        def grid(META):
+            return (triton.cdiv(N_CTX, META["BLOCK_M"]), Z * H)
+
         l = torch.empty(
             (q.shape[-4], q.shape[-3], q.shape[-2]),
             device=q.device,
@@ -309,7 +312,9 @@ def backward(ctx, do):
         # grid = lambda META: (Z * H, triton.cdiv(N_CTX, META["BLOCK_N"]))
         # grid = lambda META: (triton.cdiv(N_CTX, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
         #                      Z * H)
-        grid = lambda META: (Z * H,)
+        def grid(META):
+            return (Z * H,)
+
         _bwd_kernel[grid](
             q,
             k,
diff --git a/examples/imagenet/main_amp.py b/examples/imagenet/main_amp.py
index 384e5bda9..c12591281 100644
--- a/examples/imagenet/main_amp.py
+++ b/examples/imagenet/main_amp.py
@@ -205,7 +205,8 @@ def resume():
         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
         val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
 
-    collate_fn = lambda b: fast_collate(b, memory_format)
+    def collate_fn(b):
+        return fast_collate(b, memory_format)
 
     train_loader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
diff --git a/pyproject.toml b/pyproject.toml
index 952f4df66..c041d414b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,6 @@ build-backend = "setuptools.build_meta"
 line-length = 100
 ignore = [
     # Sorted by occurrence count (ascending) - easier to fix first
-    "E731",  # lambda assignment (6 occurrences)
     "E721",  # type comparison should use isinstance (8 occurrences)
     "E741",  # ambiguous variable name (8 occurrences)
     "E712",  # comparison to True/False (9 occurrences)

From de8bfd4dff958cb2639c4d77004abe32ac9163b9 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 29 Nov 2025 02:57:29 +0000
Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 apex/contrib/openfold_triton/mha.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/apex/contrib/openfold_triton/mha.py b/apex/contrib/openfold_triton/mha.py
index 7f7fa0f77..e19df31e9 100644
--- a/apex/contrib/openfold_triton/mha.py
+++ b/apex/contrib/openfold_triton/mha.py
@@ -308,6 +308,7 @@ def backward(ctx, do):
         #     BLOCK_M, BLOCK_N = 128, 64
         BLOCK_M, BLOCK_N, num_warps, num_stages = schedule_triton_mha(list(q.shape), fwd=False)
+
         # grid = lambda META: (triton.cdiv(N_CTX, META["BLOCK_N"]), Z * H)
         # grid = lambda META: (Z * H, triton.cdiv(N_CTX, META["BLOCK_N"]))
         # grid = lambda META: (triton.cdiv(N_CTX, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
         #                      Z * H)
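
Editor's note (not part of the patch series): the conversion above is mechanical,
so a minimal standalone sketch of the E731 pattern may help. The names below
(N, BLOCK, grid_lambda) are illustrative, not taken from apex; only math is
imported.

    import math

    N = 1000

    # Flagged by ruff E731: a lambda bound to a name reports __name__ as
    # "<lambda>", which hurts tracebacks and profiler output, and it cannot
    # carry a docstring or annotations.
    grid_lambda = lambda meta: (math.ceil(N / meta["BLOCK"]),)  # noqa: E731

    # The equivalent def the patches switch to. A Triton launch grid only needs
    # a callable taking the kernel meta-parameters, so behavior is unchanged.
    def grid(meta):
        return (math.ceil(N / meta["BLOCK"]),)

    assert grid({"BLOCK": 128}) == grid_lambda({"BLOCK": 128}) == (8,)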