From 6272da412c64e1708b8d62211a6eb2ae9c9c355c Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Wed, 14 Jan 2026 07:38:44 -0800 Subject: [PATCH 01/11] add getUniformWarpId --- csrc/codegen.cpp | 5 +++++ csrc/device_lower/pass/allocation.cpp | 23 +++++++++++++++++++++++ csrc/device_lower/pass/index.cpp | 7 +++++++ csrc/device_lower/pass/index.h | 1 + csrc/dispatch.h | 1 + csrc/kernel_ir.cpp | 19 +++++++++++++++++++ csrc/kernel_ir.h | 20 ++++++++++++++++++++ runtime/warp.cu | 11 +++++++++++ 8 files changed, 87 insertions(+) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index e12ed646540..16b0c9326c3 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -4442,6 +4442,11 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << call << ";\n"; } + void handle(const kir::UniformWarpIdInit* init) final { + indent() << gen(init->out()) << " = " + << genCall("warp::getUniformWarpId", ArgumentBuilder()) << ";\n"; + } + void handle(const kir::MBarrierInvalidate* inval) final { auto call = genCall( "mbarrier::inval", ArgumentBuilder().arg(genInline(inval->mbarrier()))); diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index 689dc912b16..bc1a08af005 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -1231,6 +1231,24 @@ class AllocationInserter : public kir::ExprMutator { return alloc_expr; } + // insert a scalar register variable for uniform warp id + void insertUniformWarpId(Expr* expr) { + // allocate uniform_warp_id + Val* uniform_warp_id = IrBuilder::create(DataType::UInt32); + kir::Allocate* uniform_warp_id_alloc = IrBuilder::create( + uniform_warp_id, + MemoryType::Local, + FusionGuard::getCurFusion()->oneVal()); + registerInsertBefore(expr, uniform_warp_id_alloc, nullptr); + + // initialize uniform_warp_id + auto uniform_warp_id_init = + IrBuilder::create(uniform_warp_id); + Expr* pred_uniform_warp_id_init = uniform_warp_id_init->withPredicate( + IrBuilder::create(PredicateType::ElectSync)); + registerInsertBefore(expr, pred_uniform_warp_id_init, nullptr); + } + // Insert cluster reduction mbarrier allocation and initialization at the // beginning of the kernel for the first top-level expression void insertClusterReductionMBarrier(Expr* expr) { @@ -1679,6 +1697,11 @@ class AllocationInserter : public kir::ExprMutator { AllocationInserter(const std::vector& exprs) : gpu_lower_(GpuLower::current()) { + // For warp specialized kernel, insert uniform warp id at the top-level + // scope. 
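+    // As a sketch (names as defined in this patch), the inserted IR lowers
+    // to a kernel preamble like:
+    //   uint32_t uniform_warp_id;
+    //   uniform_warp_id = warp::getUniformWarpId();
+    // with the initialization guarded by the ElectSync predicate attached in
+    // insertUniformWarpId().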
+ if (GpuLower::current()->circularBufferInfo().hasWarpSpecialized()) { + insertUniformWarpId(exprs.at(0)); + } // insert cluster reduction mbarrier at top-level scope if (GpuLower::current()->clusterReductionCount() >= 1) { insertClusterReductionMBarrier(exprs.at(0)); diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 817d79dc8cc..bb74ebf6c41 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1571,6 +1571,13 @@ void IndexLowering::handle(const kir::MBarrierInvalidate* minval) { GpuLower::current()->propagateExprInfo(minval, minval_indexed); } +void IndexLowering::handle(const kir::UniformWarpIdInit* uwid_init) { + // UniformWarpIdInit operates on scalar Vals, no indexing needed + // Just pass through as-is + pushBack(IrBuilder::create(uwid_init->out())); + GpuLower::current()->propagateExprInfo(uwid_init, back()); +} + void IndexLowering::handle(const kir::MBarrierArrive* arrive_transaction) { NVF_ERROR( arrive_transaction->mbarrier()->isA(), diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index b47a9f9b36a..d842a468978 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -88,6 +88,7 @@ class IndexLowering : private OptOutConstDispatch { void handle(const kir::Continue*) final; void handle(const kir::Return*) final; void handle(const kir::MBarrierInit*) final; + void handle(const kir::UniformWarpIdInit*) final; void handle(const kir::MBarrierInvalidate*) final; void handle(const kir::MBarrierArrive*) final; void handle(const kir::MBarrierArriveExpectTx*) final; diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 822ababb149..4fc120c3acd 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -139,6 +139,7 @@ class Val; f(Continue); \ f(Return); \ f(MBarrierInit); \ + f(UniformWarpIdInit); \ f(MBarrierInvalidate); \ f(MBarrierArrive); \ f(MBarrierArriveExpectTx); \ diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index e28fc12c292..3b932326ab1 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -1075,6 +1075,25 @@ std::string MBarrierInit::toInlineString(int indent_size) const { NVFUSER_DEFINE_CLONE_AND_CREATE(MBarrierInit) +UniformWarpIdInit::UniformWarpIdInit(IrBuilderPasskey passkey, Val* out) + : Expr(passkey) { + NVF_ERROR(passkey.ir_container_ != nullptr); + NVF_CHECK(out->dtype() == DataType::UInt32); + addOutput(out); +} + +std::string UniformWarpIdInit::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << out()->toString() << " = UniformWarpIdInit()\n"; + return ss.str(); +} + +std::string UniformWarpIdInit::toInlineString(int indent_size) const { + NVF_CHECK(false, "UniformWarpIdInit can not be printed inline"); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(UniformWarpIdInit) + MBarrierInvalidate::MBarrierInvalidate(IrBuilderPasskey passkey, Val* mbarrier) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h index b98aa1d0d79..42c771af4e9 100644 --- a/csrc/kernel_ir.h +++ b/csrc/kernel_ir.h @@ -58,6 +58,7 @@ class AsyncWait; class AsyncCommit; class InitMagicZero; class UpdateMagicZero; +class UniformWarpIdInit; class IfThenElse; class GridReduction; class GroupedGridReduction; @@ -847,6 +848,25 @@ class MBarrierInit final : public Expr { } }; +class UniformWarpIdInit final : public Expr { + public: + using Expr::Expr; + explicit UniformWarpIdInit(IrBuilderPasskey passkey, Val* out); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() 
const override { + return "UniformWarpIdInit"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + + Val* out() const { + return output(0); + } +}; + class MBarrierInvalidate final : public Expr { public: using Expr::Expr; diff --git a/runtime/warp.cu b/runtime/warp.cu index 35ecbb88a41..838e0115637 100644 --- a/runtime/warp.cu +++ b/runtime/warp.cu @@ -7,6 +7,17 @@ // clang-format on namespace warp { +// Compute uniform warp id that is guaranteed to be the same for all threads in +// a warp. +// __shfl_sync helps PTXAS prove that every thread in the warp has the same +// uniform warp id. +__device__ __forceinline__ uint32_t getUniformWarpId() { + const unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x + + threadIdx.z * blockDim.x * blockDim.y; + const unsigned int warp_id = tid / 32; + return __shfl_sync(0xFFFFFFFF, warp_id, 0); +} + template __device__ __forceinline__ T shfl_xor(T var, int laneMask, int width = 32) { return __shfl_xor_sync(0xffffffff, var, laneMask, width); From 3b30835e6adf0e097bd5ef64838827f2355e8590 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Wed, 14 Jan 2026 08:42:49 -0800 Subject: [PATCH 02/11] use uniform warp id --- csrc/codegen.cpp | 2 +- csrc/device_lower/lower2device.h | 14 +++++++ csrc/device_lower/pass/allocation.cpp | 9 +++-- csrc/device_lower/pass/index.cpp | 6 +-- csrc/device_lower/pass/index.h | 2 +- csrc/dispatch.h | 2 +- csrc/ir/utils.cpp | 5 +++ csrc/kernel_ir.cpp | 14 ++++--- csrc/kernel_ir.h | 8 ++-- csrc/predicate_compute.cpp | 56 +++++++++------------------ 10 files changed, 61 insertions(+), 57 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 16b0c9326c3..abced9371f7 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -4442,7 +4442,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << call << ";\n"; } - void handle(const kir::UniformWarpIdInit* init) final { + void handle(const kir::UniformWarpId* init) final { indent() << gen(init->out()) << " = " << genCall("warp::getUniformWarpId", ArgumentBuilder()) << ";\n"; } diff --git a/csrc/device_lower/lower2device.h b/csrc/device_lower/lower2device.h index 8f60f69894d..1e78d2b1b01 100644 --- a/csrc/device_lower/lower2device.h +++ b/csrc/device_lower/lower2device.h @@ -318,6 +318,16 @@ class GpuLower : public NonCopyable { cluster_reduction_mbarrier_tensor_ = mbarrier; } + //! Get the uniform warp id scalar + Val* uniformWarpId() const { + return uniform_warp_id_; + } + + //! Set the uniform warp id scalar + void setUniformWarpId(Val* warp_id) { + uniform_warp_id_ = warp_id; + } + //! Define an alias for consumer as producer. //! //! If producer is already aliased, we chase the alias. If there are tensors @@ -434,6 +444,10 @@ class GpuLower : public NonCopyable { // The shared cluster reduction mbarrier tensor allocated during allocation // pass TensorView* cluster_reduction_mbarrier_tensor_ = nullptr; + + // The uniform warp id scalar allocated during allocation pass for warp + // specialized kernels + Val* uniform_warp_id_ = nullptr; }; #define NVFUSER_LOWER_VALIDATE(cond, ...) 
\ diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index bc1a08af005..1fce173bf12 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -1243,10 +1243,11 @@ class AllocationInserter : public kir::ExprMutator { // initialize uniform_warp_id auto uniform_warp_id_init = - IrBuilder::create(uniform_warp_id); - Expr* pred_uniform_warp_id_init = uniform_warp_id_init->withPredicate( - IrBuilder::create(PredicateType::ElectSync)); - registerInsertBefore(expr, pred_uniform_warp_id_init, nullptr); + IrBuilder::create(uniform_warp_id); + registerInsertBefore(expr, uniform_warp_id_init, nullptr); + + // Store the uniform_warp_id in GpuLower for later use in predicates + GpuLower::current()->setUniformWarpId(uniform_warp_id); } // Insert cluster reduction mbarrier allocation and initialization at the diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index bb74ebf6c41..36c4e792b06 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1571,10 +1571,10 @@ void IndexLowering::handle(const kir::MBarrierInvalidate* minval) { GpuLower::current()->propagateExprInfo(minval, minval_indexed); } -void IndexLowering::handle(const kir::UniformWarpIdInit* uwid_init) { - // UniformWarpIdInit operates on scalar Vals, no indexing needed +void IndexLowering::handle(const kir::UniformWarpId* uwid_init) { + // UniformWarpId operates on scalar Vals, no indexing needed // Just pass through as-is - pushBack(IrBuilder::create(uwid_init->out())); + pushBack(IrBuilder::create(uwid_init->out())); GpuLower::current()->propagateExprInfo(uwid_init, back()); } diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index d842a468978..7643a254a1e 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -88,7 +88,7 @@ class IndexLowering : private OptOutConstDispatch { void handle(const kir::Continue*) final; void handle(const kir::Return*) final; void handle(const kir::MBarrierInit*) final; - void handle(const kir::UniformWarpIdInit*) final; + void handle(const kir::UniformWarpId*) final; void handle(const kir::MBarrierInvalidate*) final; void handle(const kir::MBarrierArrive*) final; void handle(const kir::MBarrierArriveExpectTx*) final; diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 4fc120c3acd..2b0ebd5489b 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -139,7 +139,7 @@ class Val; f(Continue); \ f(Return); \ f(MBarrierInit); \ - f(UniformWarpIdInit); \ + f(UniformWarpId); \ f(MBarrierInvalidate); \ f(MBarrierArrive); \ f(MBarrierArriveExpectTx); \ diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 9418b137688..b2c3b25255b 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1417,6 +1417,11 @@ bool isFunctional(const Val* v) { if (dynamic_cast(def)) { return false; } + // UniformWarpId is a runtime operation that should not be evaluated + // at compile time or constant-folded by SimplifyingIrBuilder + if (dynamic_cast(def)) { + return false; + } return std::all_of(def->inputs().begin(), def->inputs().end(), isFunctional); } diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index 3b932326ab1..01f61ee1f7f 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -1075,24 +1075,26 @@ std::string MBarrierInit::toInlineString(int indent_size) const { NVFUSER_DEFINE_CLONE_AND_CREATE(MBarrierInit) -UniformWarpIdInit::UniformWarpIdInit(IrBuilderPasskey passkey, Val* out) 
+UniformWarpId::UniformWarpId(IrBuilderPasskey passkey, Val* out) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_CHECK(out->dtype() == DataType::UInt32); addOutput(out); } -std::string UniformWarpIdInit::toString(int indent_size) const { +std::string UniformWarpId::toString(int indent_size) const { std::stringstream ss; - indent(ss, indent_size) << out()->toString() << " = UniformWarpIdInit()\n"; + indent(ss, indent_size) << out()->toString() << " = UniformWarpId()\n"; return ss.str(); } -std::string UniformWarpIdInit::toInlineString(int indent_size) const { - NVF_CHECK(false, "UniformWarpIdInit can not be printed inline"); +// uniform warp id is used in predicate, Predicate::toString() uses its +// toInlineString() +std::string UniformWarpId::toInlineString(int indent_size) const { + return std::string(getOpString()) + "()"; } -NVFUSER_DEFINE_CLONE_AND_CREATE(UniformWarpIdInit) +NVFUSER_DEFINE_CLONE_AND_CREATE(UniformWarpId) MBarrierInvalidate::MBarrierInvalidate(IrBuilderPasskey passkey, Val* mbarrier) : Expr(passkey) { diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h index 42c771af4e9..db2d4db9291 100644 --- a/csrc/kernel_ir.h +++ b/csrc/kernel_ir.h @@ -58,7 +58,7 @@ class AsyncWait; class AsyncCommit; class InitMagicZero; class UpdateMagicZero; -class UniformWarpIdInit; +class UniformWarpId; class IfThenElse; class GridReduction; class GroupedGridReduction; @@ -848,15 +848,15 @@ class MBarrierInit final : public Expr { } }; -class UniformWarpIdInit final : public Expr { +class UniformWarpId final : public Expr { public: using Expr::Expr; - explicit UniformWarpIdInit(IrBuilderPasskey passkey, Val* out); + explicit UniformWarpId(IrBuilderPasskey passkey, Val* out); NVFUSER_DECLARE_CLONE_AND_CREATE const char* getOpString() const override { - return "UniformWarpIdInit"; + return "UniformWarpId"; } std::string toString(int indent_size = 0) const override; diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index 532f2ff0d6e..27fb2e01078 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -498,9 +498,20 @@ Val* createElectSyncExpr() { // warp collective. // TODO If TIDx is known at compile-time, generate custom mask. Val* selectFirstWarpElectSyncPredicate(bool is_warp_collective) { - Val* warp_size = IrBuilder::create(32L, PrimDataType::UInt64); - Val* select_first_warp = IrBuilder::ltExpr( - NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size); + Val* select_first_warp = nullptr; + + // Try to use the uniform warp id if available (for warp specialized kernels) + Val* uniform_warp_id = GpuLower::current()->uniformWarpId(); + if (uniform_warp_id != nullptr) { + // Use uniform warp id: check if warp_id == 0 + Val* zero = IrBuilder::create(0L, PrimDataType::UInt32); + select_first_warp = IrBuilder::eqExpr(uniform_warp_id, zero); + } else { + // Fallback to original: threadIdx.x < 32 + Val* warp_size = IrBuilder::create(32L, PrimDataType::UInt64); + select_first_warp = IrBuilder::ltExpr( + NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size); + } // Short-Circuit: TMA Store is a warp-collective, so ElectSync is not // necessary. @@ -508,8 +519,8 @@ Val* selectFirstWarpElectSyncPredicate(bool is_warp_collective) { return select_first_warp; } - return SimplifyingIrBuilder::logicalAndExpr( - createElectSyncExpr(), select_first_warp); + Val* elect_sync = createElectSyncExpr(); + return SimplifyingIrBuilder::logicalAndExpr(elect_sync, select_first_warp); } // Get linear index for AsyncWarp Group. Then, select first warp. 
Finally, use @@ -670,45 +681,16 @@ Val* createMultipleExpressionElectSync( kir::Predicate* pred, const std::vector& loops) { NVF_ERROR(pred->expr() == nullptr); - - Val* zero = IrBuilder::create(0L, PrimDataType::UInt64); - const ParallelDimensionMap& pdim_map = - GpuLower::current()->info().parallelDimensionMap(); - - // Determine if warp specialized tma load expression. - ParallelType async_warp_on = ParallelType::Serial; auto async_warp_loop_it = std::find_if(loops.begin(), loops.end(), [](kir::ForLoop* fl) { return fl->circularBufferLoopStage() == CircularBufferLoopStage::AsyncWarp; }); if (async_warp_loop_it != loops.end()) { - auto circular_buffer_type = std::get( - GpuLower::current() - ->circularBufferInfo() - .getCircularBufferOptionsFor((*async_warp_loop_it)->iter_domain()) - .type); - async_warp_on = circular_buffer_type.on; - } - - // Short-circuit: If we are in a async warp, then the warp-dispatching - // IfThenElse already selects on `async_warp_on`, so we should not - // generate predicates for it here. - if (async_warp_loop_it == loops.end()) { - Val* conditional = async_warp_on == ParallelType::TIDx - ? pred->fusion()->trueVal() - : selectFirstWarpElectSyncPredicate(/*is_warp_collective=*/false); - for (ParallelType pt : {ParallelType::TIDy, ParallelType::TIDz}) { - if (pdim_map.has(pt) && async_warp_on != pt) { - conditional = SimplifyingIrBuilder::logicalAndExpr( - conditional, - IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero)); - } - } - return conditional; + return createElectSyncPredicateAsync(); + } else { + return selectFirstWarpElectSyncPredicate(/*is_warp_collective=*/false); } - - return createElectSyncPredicateAsync(); } } // namespace From 35f84c855633bbdddae41827bb5375505cb2191d Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Wed, 14 Jan 2026 08:49:50 -0800 Subject: [PATCH 03/11] clean --- csrc/predicate_compute.cpp | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index 27fb2e01078..6da37022fbf 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -494,33 +494,25 @@ Val* createElectSyncExpr() { return elect_sync_val; } -// Select first warp of threads along TIDx axis and use ptx::elect_sync if not -// warp collective. -// TODO If TIDx is known at compile-time, generate custom mask. +// Select the first warp or one thread of the first warp in the block Val* selectFirstWarpElectSyncPredicate(bool is_warp_collective) { Val* select_first_warp = nullptr; - // Try to use the uniform warp id if available (for warp specialized kernels) Val* uniform_warp_id = GpuLower::current()->uniformWarpId(); - if (uniform_warp_id != nullptr) { - // Use uniform warp id: check if warp_id == 0 - Val* zero = IrBuilder::create(0L, PrimDataType::UInt32); - select_first_warp = IrBuilder::eqExpr(uniform_warp_id, zero); - } else { - // Fallback to original: threadIdx.x < 32 - Val* warp_size = IrBuilder::create(32L, PrimDataType::UInt64); - select_first_warp = IrBuilder::ltExpr( - NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size); - } + NVF_ERROR(uniform_warp_id != nullptr); + + Val* zero = IrBuilder::create(0L, PrimDataType::UInt32); + select_first_warp = IrBuilder::eqExpr(uniform_warp_id, zero); // Short-Circuit: TMA Store is a warp-collective, so ElectSync is not - // necessary. + // necessary. Just select the first warp. 
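+  //
+  // Sketch of the two predicate shapes this function returns:
+  //   warp-collective op (e.g. TMA store): uniform_warp_id == 0U
+  //   single-thread op:                    uniform_warp_id == 0U && elect-sync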
if (is_warp_collective) { return select_first_warp; } - Val* elect_sync = createElectSyncExpr(); - return SimplifyingIrBuilder::logicalAndExpr(elect_sync, select_first_warp); + // Select one thread of the first warp + return SimplifyingIrBuilder::logicalAndExpr( + createElectSyncExpr(), select_first_warp); } // Get linear index for AsyncWarp Group. Then, select first warp. Finally, use From 0dda543248e1fa837af8bc373c7976a2b187be4f Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Wed, 14 Jan 2026 10:00:12 -0800 Subject: [PATCH 04/11] wip --- csrc/device_lower/lower2device.h | 14 +++ csrc/device_lower/pass/circular_buffer.cpp | 31 ++++--- csrc/parallel_dimension_map.cpp | 40 +++++++++ csrc/parallel_dimension_map.h | 10 +++ csrc/predicate_compute.cpp | 86 +++++++++++-------- .../test_combined_inner_outer_reduction.cpp | 2 +- 6 files changed, 132 insertions(+), 51 deletions(-) diff --git a/csrc/device_lower/lower2device.h b/csrc/device_lower/lower2device.h index 1e78d2b1b01..2ee40719ca2 100644 --- a/csrc/device_lower/lower2device.h +++ b/csrc/device_lower/lower2device.h @@ -328,6 +328,16 @@ class GpuLower : public NonCopyable { uniform_warp_id_ = warp_id; } + //! Get the number of compute warps + Val* numComputeWarps() const { + return num_compute_warps_; + } + + //! Set the number of compute warps + void setNumComputeWarps(Val* num_warps) { + num_compute_warps_ = num_warps; + } + //! Define an alias for consumer as producer. //! //! If producer is already aliased, we chase the alias. If there are tensors @@ -448,6 +458,10 @@ class GpuLower : public NonCopyable { // The uniform warp id scalar allocated during allocation pass for warp // specialized kernels Val* uniform_warp_id_ = nullptr; + + // The number of compute warps calculated during circular buffer pass for warp + // specialized kernels + Val* num_compute_warps_ = nullptr; }; #define NVFUSER_LOWER_VALIDATE(cond, ...) 
\
diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp
index 67022363f1c..bdf81a26a3c 100644
--- a/csrc/device_lower/pass/circular_buffer.cpp
+++ b/csrc/device_lower/pass/circular_buffer.cpp
@@ -1566,21 +1566,23 @@ class WarpSpecializedCircularBufferInserter : private kir::ExprMutator {
   }
 
   // Create predicate for warp-specialized IfThenElse:
-  // kir::Predicate is thread_axis >= block_dim_axis - padded_value
+  // kir::Predicate is UniformWarpId() >= num_compute_warps
   kir::Predicate* getAsyncWarpPredicate(const CircularBufferOptions& options) {
-    ParallelType warp_specialize_on =
-        std::get<WarpSpecialized>(options.type).on;
-    int64_t warp_specialization_pad =
-        GpuLower::current()
-            ->info()
-            .parallelDimensionMap()
-            .getWarpSpecializationPaddedVal(warp_specialize_on);
-    Val* raw = GpuLower::current()->info().parallelDimensionMap().get(
-        warp_specialize_on);
-    Val* raw_minus_pad = SimplifyingIrBuilder::subExpr(
-        raw, IrBuilder::create<Val>(warp_specialization_pad, DataType::Index));
-    return IrBuilder::create<kir::Predicate>(IrBuilder::geExpr(
-        NamedScalar::getParallelIndex(warp_specialize_on), raw_minus_pad));
+    // Get the number of compute warps using ParallelDimensionMap.
+    // This works correctly for warp specialization on TIDx, TIDy, or TIDz.
+    const ParallelDimensionMap& pdim_map =
+        GpuLower::current()->info().parallelDimensionMap();
+    Val* num_compute_warps = pdim_map.getNumComputeWarps();
+
+    // Save num_compute_warps in GpuLower for reuse in
+    // createElectSyncPredicateAsync
+    GpuLower::current()->setNumComputeWarps(num_compute_warps);
+
+    // Use UniformWarpId() instead of threadIdx.x
+    Val* uniform_warp_id =
+        GpuLower::current()->uniformWarpId();
+    NVF_ERROR(uniform_warp_id != nullptr, "UniformWarpId must be initialized");
+    return IrBuilder::create<kir::Predicate>(
+        IrBuilder::geExpr(uniform_warp_id, num_compute_warps));
   }
 
   void insertTmaWarpSpecialized(
@@ -2172,3 +2174,4 @@ std::vector<Expr*> CircularBufferPass::run(const std::vector<Expr*>& exprs) {
 }
 
 } // namespace nvfuser
+
diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp
index 3ee6ca64b39..5beaf962cd4 100644
--- a/csrc/parallel_dimension_map.cpp
+++ b/csrc/parallel_dimension_map.cpp
@@ -323,6 +323,46 @@ int64_t ParallelDimensionMap::getWarpSpecializationPaddedVal(
   return warp_specialized_padding_value_.value();
 }
 
+Val* ParallelDimensionMap::getNumComputeWarps() const {
+  NVF_ERROR(
+      hasWarpSpecialization(),
+      "getNumComputeWarps() should only be called for warp specialized "
+      "kernels");
+
+  // Calculate the total number of compute threads:
+  // If warp specialized on TIDx: (bdimx - pad) * bdimy * bdimz
+  // If warp specialized on TIDy: bdimx * (bdimy - pad) * bdimz
+  // If warp specialized on TIDz: bdimx * bdimy * (bdimz - pad)
+  // Then divide by 32 to get the number of warps
+
+  Val* num_compute_threads = FusionGuard::getCurFusion()->oneVal();
+  ParallelType ws_pt = warp_specialized_parallel_type_.value();
+
+  for (auto pt : kParallelTypeTIDs) {
+    Val* dim = nullptr;
+    if (pt == ws_pt) {
+      // For the warp specialized dimension, use getRawCompute which subtracts
+      // the pad
+      dim = getRawCompute(pt);
+    } else {
+      // For other dimensions, use the raw dimension
+      dim = getRaw(pt);
+    }
+
+    if (dim == nullptr) {
+      continue;
+    }
+    num_compute_threads =
+        SimplifyingIrBuilder::mulExpr(num_compute_threads, dim);
+  }
+
+  // Divide by 32 to get the number of warps
+  Val* num_compute_warps = SimplifyingIrBuilder::divExpr(
+      num_compute_threads, IrBuilder::create<Val>(32L, DataType::Index));
+
+  return num_compute_warps;
+}
+
 bool
ParallelDimensionMap::canUseElectSyncInAsyncWarp() const { // short-circuit: skip if warp specialization is not enabled if (!hasWarpSpecialization()) { diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index 00f4165e298..77af1ed2e88 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -78,6 +78,16 @@ class ParallelDimensionMap { //! Assign linear index to each thread of CTA. Assume (TDZ, TDY, TDX) order. Val* getLinearThreadIndexAsync() const; + //! Get the number of compute warps for warp specialized kernels. + //! This computes the total number of compute threads across all dimensions + //! (TIDx, TIDy, TIDz), using the compute dimension (minus padding) for the + //! warp specialized dimension, then divides by 32 to get the number of warps. + //! Examples: + //! - If warp specialized on TIDx: (bdimx - pad) * bdimy * bdimz / 32 + //! - If warp specialized on TIDy: bdimx * (bdimy - pad) * bdimz / 32 + //! - If warp specialized on TIDz: bdimx * bdimy * (bdimz - pad) / 32 + Val* getNumComputeWarps() const; + //! Get if the kernel uses warp specialization bool hasWarpSpecialization() const { return warp_specialized_parallel_type_.has_value(); diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index 6da37022fbf..d6713e1a5d0 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -494,54 +494,68 @@ Val* createElectSyncExpr() { return elect_sync_val; } -// Select the first warp or one thread of the first warp in the block -Val* selectFirstWarpElectSyncPredicate(bool is_warp_collective) { +Val* selectWarpIndex(uint32_t warp_index) { Val* select_first_warp = nullptr; Val* uniform_warp_id = GpuLower::current()->uniformWarpId(); NVF_ERROR(uniform_warp_id != nullptr); - Val* zero = IrBuilder::create(0L, PrimDataType::UInt32); + Val* zero = IrBuilder::create(warp_index, PrimDataType::UInt32); select_first_warp = IrBuilder::eqExpr(uniform_warp_id, zero); + return select_first_warp; +} - // Short-Circuit: TMA Store is a warp-collective, so ElectSync is not - // necessary. Just select the first warp. - if (is_warp_collective) { - return select_first_warp; - } - - // Select one thread of the first warp - return SimplifyingIrBuilder::logicalAndExpr( - createElectSyncExpr(), select_first_warp); +// Select the first warp or one thread of the first warp in the block +Val* selectWarpIdxElectSyncPredicate( + uint32_t warp_index, + bool is_warp_collective) { + return is_warp_collective + ? selectWarpIndex(warp_index) + : SimplifyingIrBuilder::logicalAndExpr( + selectWarpIndex(warp_index), createElectSyncExpr()); } // Get linear index for AsyncWarp Group. Then, select first warp. Finally, use // ptx::elect_sync if not warp collective. // TODO If TIDx is known at compile-time, generate custom mask. 
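 // Worked example (illustrative numbers): with bdim = (128, 5, 1), warp
 // specialization on TIDy, and a pad of 1, getNumComputeWarps() gives
 // 128 * (5 - 1) / 32 = 16, so the first async warp has uniform warp id 16.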
Val* createElectSyncPredicateAsync() { - Val* zero = IrBuilder::create(0L, PrimDataType::UInt64); - Val* warp_size = IrBuilder::create(32L, PrimDataType::UInt64); - - const ParallelDimensionMap& pdim_map = - GpuLower::current()->info().parallelDimensionMap(); - Val* async_warp_thread_index = pdim_map.getLinearThreadIndexAsync(); - Val* warp_id = - SimplifyingIrBuilder::divExpr(async_warp_thread_index, warp_size); - // TODO Only select first warp now - Val* select_warp = SimplifyingIrBuilder::eqExpr(warp_id, zero); - - // Use elect-sync if available - if (pdim_map.canUseElectSyncInAsyncWarp()) { - return SimplifyingIrBuilder::logicalAndExpr( - select_warp, createElectSyncExpr()); + Val* num_compute_warps = GpuLower::current()->numComputeWarps(); + NVF_ERROR( + num_compute_warps != nullptr, "NumComputeWarps must be initialized"); + // Convert Val to uint32_t for warp index + // Since num_compute_warps represents the number of compute warps, + // the first async warp starts at index num_compute_warps + uint32_t warp_index = 0; + if (num_compute_warps->isConstScalar()) { + warp_index = + static_cast(num_compute_warps->evaluate().as()); } - - // Warp Specialized ParallelType is ThreadIdx.x and it contains less than 32 - // threads, so manually select first thread in warp. - Val* thread_id = - SimplifyingIrBuilder::modExpr(async_warp_thread_index, warp_size); - Val* select_thread = SimplifyingIrBuilder::eqExpr(thread_id, zero); - return SimplifyingIrBuilder::logicalAndExpr(select_warp, select_thread); + return selectWarpIdxElectSyncPredicate( + warp_index, /*is_warp_collective=*/false); + // Val* zero = IrBuilder::create(0L, PrimDataType::UInt64); + // Val* warp_size = IrBuilder::create(32L, PrimDataType::UInt64); + + // const ParallelDimensionMap& pdim_map = + // GpuLower::current()->info().parallelDimensionMap(); + // Val* async_warp_thread_index = pdim_map.getLinearThreadIndexAsync(); + // Val* warp_id = + // SimplifyingIrBuilder::divExpr(async_warp_thread_index, warp_size); + // // TODO Only select first warp now + // Val* select_warp = SimplifyingIrBuilder::eqExpr(warp_id, zero); + + // // Use elect-sync if available + // if (pdim_map.canUseElectSyncInAsyncWarp()) { + // return SimplifyingIrBuilder::logicalAndExpr( + // select_warp, createElectSyncExpr()); + // } + + // // Warp Specialized ParallelType is ThreadIdx.x and it contains less than + // 32 + // // threads, so manually select first thread in warp. 
+ // Val* thread_id = + // SimplifyingIrBuilder::modExpr(async_warp_thread_index, warp_size); + // Val* select_thread = SimplifyingIrBuilder::eqExpr(thread_id, zero); + // return SimplifyingIrBuilder::logicalAndExpr(select_warp, select_thread); } Val* createElectSyncPredicate(kir::Predicate* pred, bool is_async_warp) { @@ -589,7 +603,7 @@ Val* createElectSyncPredicate(kir::Predicate* pred, bool is_async_warp) { if (is_async_warp) { return createElectSyncPredicateAsync(); } - return selectFirstWarpElectSyncPredicate(is_tma_store); + return selectWarpIdxElectSyncPredicate(0, is_tma_store); } Val* createSingleExpressionElectSync( @@ -681,7 +695,7 @@ Val* createMultipleExpressionElectSync( if (async_warp_loop_it != loops.end()) { return createElectSyncPredicateAsync(); } else { - return selectFirstWarpElectSyncPredicate(/*is_warp_collective=*/false); + return selectWarpIdxElectSyncPredicate(0, /*is_warp_collective=*/false); } } diff --git a/tests/cpp/test_combined_inner_outer_reduction.cpp b/tests/cpp/test_combined_inner_outer_reduction.cpp index e9a684b1cfe..e26ecce0b74 100644 --- a/tests/cpp/test_combined_inner_outer_reduction.cpp +++ b/tests/cpp/test_combined_inner_outer_reduction.cpp @@ -1299,7 +1299,7 @@ TEST_P(TmaWarpSpecializedTest, LayerNormBackward) { auto TmaWarpSpecializedTestParams() { std::vector values; int64_t dim0 = 2048; - for (int64_t dim1 = 1024; dim1 <= 8192; dim1 += 256) { + for (int64_t dim1 = 1024; dim1 <= 32768; dim1 += 256) { for (bool contig : {true, false}) { // to save test time if (dim1 != 1024 && !contig) { From c403dc7bbfb008fe75e32f3e846f67bd2a4b71b2 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Wed, 14 Jan 2026 10:56:04 -0800 Subject: [PATCH 05/11] use getNumComputeWarps --- csrc/codegen.cpp | 2 +- csrc/device_lower/lower2device.h | 14 -------------- csrc/device_lower/pass/circular_buffer.cpp | 8 +------- csrc/predicate_compute.cpp | 4 +++- 4 files changed, 5 insertions(+), 23 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index abced9371f7..d61bc336b00 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -3757,7 +3757,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { ArgumentBuilder template_args; template_args.arg(kernel_->paddedParallelDimensions().is_tidx_single_warp); - template_args.arg(isAligned()); + template_args.arg(has_warp_specialized_ ? false : isAligned()); template_args.arg(num_grouped_iterations); template_args.arg(reduction_scheduler_utils::getComputeBdimx( warp_specialized_on_, lparams_.bdimx())); diff --git a/csrc/device_lower/lower2device.h b/csrc/device_lower/lower2device.h index 2ee40719ca2..1e78d2b1b01 100644 --- a/csrc/device_lower/lower2device.h +++ b/csrc/device_lower/lower2device.h @@ -328,16 +328,6 @@ class GpuLower : public NonCopyable { uniform_warp_id_ = warp_id; } - //! Get the number of compute warps - Val* numComputeWarps() const { - return num_compute_warps_; - } - - //! Set the number of compute warps - void setNumComputeWarps(Val* num_warps) { - num_compute_warps_ = num_warps; - } - //! Define an alias for consumer as producer. //! //! If producer is already aliased, we chase the alias. If there are tensors @@ -458,10 +448,6 @@ class GpuLower : public NonCopyable { // The uniform warp id scalar allocated during allocation pass for warp // specialized kernels Val* uniform_warp_id_ = nullptr; - - // The number of compute warps calculated during circular buffer pass for warp - // specialized kernels - Val* num_compute_warps_ = nullptr; }; #define NVFUSER_LOWER_VALIDATE(cond, ...) 
\
diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp
index bdf81a26a3c..d4f1261f645 100644
--- a/csrc/device_lower/pass/circular_buffer.cpp
+++ b/csrc/device_lower/pass/circular_buffer.cpp
@@ -1574,12 +1574,7 @@ class WarpSpecializedCircularBufferInserter : private kir::ExprMutator {
         GpuLower::current()->info().parallelDimensionMap();
     Val* num_compute_warps = pdim_map.getNumComputeWarps();
 
-    // Save num_compute_warps in GpuLower for reuse in
-    // createElectSyncPredicateAsync
-    GpuLower::current()->setNumComputeWarps(num_compute_warps);
-
-    // Use UniformWarpId() instead of threadIdx.x
-    Val* uniform_warp_id =
-        GpuLower::current()->uniformWarpId();
+    Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
     NVF_ERROR(uniform_warp_id != nullptr, "UniformWarpId must be initialized");
     return IrBuilder::create<kir::Predicate>(
         IrBuilder::geExpr(uniform_warp_id, num_compute_warps));
@@ -2174,4 +2169,3 @@ std::vector<Expr*> CircularBufferPass::run(const std::vector<Expr*>& exprs) {
 }
 
 } // namespace nvfuser
-
diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp
index d6713e1a5d0..ba3a77de59d 100644
--- a/csrc/predicate_compute.cpp
+++ b/csrc/predicate_compute.cpp
@@ -519,7 +519,9 @@ Val* selectWarpIdxElectSyncPredicate(
 // ptx::elect_sync if not warp collective.
 // TODO If TIDx is known at compile-time, generate custom mask.
 Val* createElectSyncPredicateAsync() {
-  Val* num_compute_warps = GpuLower::current()->numComputeWarps();
+  const ParallelDimensionMap& pdim_map =
+      GpuLower::current()->info().parallelDimensionMap();
+  Val* num_compute_warps = pdim_map.getNumComputeWarps();
   NVF_ERROR(
       num_compute_warps != nullptr, "NumComputeWarps must be initialized");
   // Convert Val to uint32_t for warp index
From 5d4d0d0281a776829a6b22465c07904221605d07 Mon Sep 17 00:00:00 2001
From: Liqiang Lu
Date: Wed, 14 Jan 2026 11:21:00 -0800
Subject: [PATCH 06/11] clean

---
 csrc/device_lower/pass/allocation.cpp      |  3 +-
 csrc/device_lower/pass/circular_buffer.cpp |  3 +-
 csrc/parallel_dimension_map.cpp            | 66 ++--------------
 csrc/parallel_dimension_map.h              |  6 +-
 csrc/predicate_compute.cpp                 | 50 +++-------
 5 files changed, 20 insertions(+), 108 deletions(-)

diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp
index 1fce173bf12..63d80fddb00 100644
--- a/csrc/device_lower/pass/allocation.cpp
+++ b/csrc/device_lower/pass/allocation.cpp
@@ -976,7 +976,8 @@ Expr* initializeCircularBufferMbarrier(
       GpuLower::current()
           ->info()
          .parallelDimensionMap()
-          .getNumComputeThreadsEachBlock());
+          .getNumComputeThreadsEachBlock(
+              /*only_count_same_compute_warp_groups=*/true));
 }
 
 // Initialize mbarrier for each circular buffer stage.
Use the thread diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index d4f1261f645..1902128511d 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -2016,7 +2016,8 @@ kir::ForLoop* HopperPingPongMbarriers::initializePingPongMbarrier() { GpuLower::current() ->info() .parallelDimensionMap() - .getNumComputeThreadsEachBlock()); + .getNumComputeThreadsEachBlock( + /*only_count_same_compute_warp_groups=*/true)); kir::TensorIndex* ping_pong_mbarrier_index = IrBuilder::create(mbarriers_, loop->index()); kir::MBarrierInit* ping_pong_mbarrier_init = diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 5beaf962cd4..80b21a24fa1 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -257,12 +257,13 @@ Val* ParallelDimensionMap::getRawAsync(ParallelType pt) const { return getRaw(pt); } -Val* ParallelDimensionMap::getNumComputeThreadsEachBlock() const { +Val* ParallelDimensionMap::getNumComputeThreadsEachBlock( + bool only_count_same_compute_warp_groups) const { Val* num_threads = FusionGuard::getCurFusion()->oneVal(); for (auto pt : kParallelTypeTIDs) { // Skip warp specialized ParallelType if the are computation warp groups // are independent. - if (isWarpSpecialized(pt) && + if (only_count_same_compute_warp_groups && isWarpSpecialized(pt) && GpuLower::current() ->circularBufferInfo() .hasIndependentComputeWarpGroups()) { @@ -277,39 +278,6 @@ Val* ParallelDimensionMap::getNumComputeThreadsEachBlock() const { return num_threads; } -// For warp-specialization, the CTA is padded so the AsyncWarp contains 128 -// threads. This function maps the AsyncWarp CTA to a linear index from -// [0, 128). It is used to divide AsyncWarp into four independent warps. -Val* ParallelDimensionMap::getLinearThreadIndexAsync() const { - Val* index = GpuLower::current()->kernel()->zeroVal(); - Val* extent = GpuLower::current()->kernel()->oneVal(); - - for (auto pt : kParallelTypeTIDs) { - // For warp-specialization, an axis is padded so the AsyncWarp contains - // 128 threads. - Val* extent_for_pdim = getRawAsync(pt); - // short-circuit: extent_for_pdim is not used in kernel. - if (extent_for_pdim == nullptr) { - continue; - } - // short-circuit: extent_for_pdim is trivial. - if (extent_for_pdim->isConstScalar() && - extent_for_pdim->evaluate().as() == 1) { - continue; - } - Val* pt_index = NamedScalar::getParallelIndex(pt); - // Map the padded parallel index to [0, padded_value] range, so the linear - // index will be in range of [0, 128). 
- if (isWarpSpecialized(pt)) { - pt_index = SimplifyingIrBuilder::subExpr(pt_index, getRawCompute(pt)); - } - index = SimplifyingIrBuilder::addExpr( - index, SimplifyingIrBuilder::mulExpr(pt_index, extent)); - extent = SimplifyingIrBuilder::mulExpr(extent, extent_for_pdim); - } - return index; -} - int64_t ParallelDimensionMap::getWarpSpecializationPaddedVal( ParallelType pt) const { NVF_ERROR(isWarpSpecialized(pt), "Can't find ParallelType: ", pt); @@ -329,32 +297,8 @@ Val* ParallelDimensionMap::getNumComputeWarps() const { "getNumComputeWarps() should only be called for warp specialized " "kernels"); - // Calculate the total number of compute threads: - // If warp specialized on TIDx: (bdimx - pad) * bdimy * bdimz - // If warp specialized on TIDy: bdimx * (bdimy - pad) * bdimz - // If warp specialized on TIDz: bdimx * bdimy * (bdimz - pad) - // Then divide by 32 to get the number of warps - - Val* num_compute_threads = FusionGuard::getCurFusion()->oneVal(); - ParallelType ws_pt = warp_specialized_parallel_type_.value(); - - for (auto pt : kParallelTypeTIDs) { - Val* dim = nullptr; - if (pt == ws_pt) { - // For the warp specialized dimension, use getRawCompute which subtracts - // the pad - dim = getRawCompute(pt); - } else { - // For other dimensions, use the raw dimension - dim = getRaw(pt); - } - - if (dim == nullptr) { - continue; - } - num_compute_threads = - SimplifyingIrBuilder::mulExpr(num_compute_threads, dim); - } + Val* num_compute_threads = getNumComputeThreadsEachBlock( + /*only_count_same_compute_warp_groups=*/false); // Divide by 32 to get the number of warps Val* num_compute_warps = SimplifyingIrBuilder::divExpr( diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index 77af1ed2e88..d08393d2fdc 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -73,10 +73,8 @@ class ParallelDimensionMap { //! And this function will return (32 * 16) because the extra one for TIDy is //! introduced by warp specialization and only used for loading circular //! buffer tensors. - Val* getNumComputeThreadsEachBlock() const; - - //! Assign linear index to each thread of CTA. Assume (TDZ, TDY, TDX) order. - Val* getLinearThreadIndexAsync() const; + Val* getNumComputeThreadsEachBlock( + bool only_count_same_compute_warp_groups) const; //! Get the number of compute warps for warp specialized kernels. //! This computes the total number of compute threads across all dimensions diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index ba3a77de59d..9957e24c076 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -495,14 +495,11 @@ Val* createElectSyncExpr() { } Val* selectWarpIndex(uint32_t warp_index) { - Val* select_first_warp = nullptr; - Val* uniform_warp_id = GpuLower::current()->uniformWarpId(); NVF_ERROR(uniform_warp_id != nullptr); - - Val* zero = IrBuilder::create(warp_index, PrimDataType::UInt32); - select_first_warp = IrBuilder::eqExpr(uniform_warp_id, zero); - return select_first_warp; + Val* target_warp_index = + IrBuilder::create(warp_index, PrimDataType::UInt32); + return IrBuilder::eqExpr(uniform_warp_id, target_warp_index); } // Select the first warp or one thread of the first warp in the block @@ -515,49 +512,20 @@ Val* selectWarpIdxElectSyncPredicate( selectWarpIndex(warp_index), createElectSyncExpr()); } -// Get linear index for AsyncWarp Group. Then, select first warp. Finally, use -// ptx::elect_sync if not warp collective. -// TODO If TIDx is known at compile-time, generate custom mask. 
+// Since num_compute_warps represents the number of compute warps, +// the first async warp starts at index num_compute_warps Val* createElectSyncPredicateAsync() { const ParallelDimensionMap& pdim_map = GpuLower::current()->info().parallelDimensionMap(); Val* num_compute_warps = pdim_map.getNumComputeWarps(); NVF_ERROR( num_compute_warps != nullptr, "NumComputeWarps must be initialized"); - // Convert Val to uint32_t for warp index - // Since num_compute_warps represents the number of compute warps, - // the first async warp starts at index num_compute_warps - uint32_t warp_index = 0; - if (num_compute_warps->isConstScalar()) { - warp_index = - static_cast(num_compute_warps->evaluate().as()); - } + NVF_ERROR( + num_compute_warps->isConstScalar(), "NumComputeWarps must be a constant"); + uint32_t warp_index = + static_cast(num_compute_warps->evaluate().as()); return selectWarpIdxElectSyncPredicate( warp_index, /*is_warp_collective=*/false); - // Val* zero = IrBuilder::create(0L, PrimDataType::UInt64); - // Val* warp_size = IrBuilder::create(32L, PrimDataType::UInt64); - - // const ParallelDimensionMap& pdim_map = - // GpuLower::current()->info().parallelDimensionMap(); - // Val* async_warp_thread_index = pdim_map.getLinearThreadIndexAsync(); - // Val* warp_id = - // SimplifyingIrBuilder::divExpr(async_warp_thread_index, warp_size); - // // TODO Only select first warp now - // Val* select_warp = SimplifyingIrBuilder::eqExpr(warp_id, zero); - - // // Use elect-sync if available - // if (pdim_map.canUseElectSyncInAsyncWarp()) { - // return SimplifyingIrBuilder::logicalAndExpr( - // select_warp, createElectSyncExpr()); - // } - - // // Warp Specialized ParallelType is ThreadIdx.x and it contains less than - // 32 - // // threads, so manually select first thread in warp. - // Val* thread_id = - // SimplifyingIrBuilder::modExpr(async_warp_thread_index, warp_size); - // Val* select_thread = SimplifyingIrBuilder::eqExpr(thread_id, zero); - // return SimplifyingIrBuilder::logicalAndExpr(select_warp, select_thread); } Val* createElectSyncPredicate(kir::Predicate* pred, bool is_async_warp) { From dcf233db0d3aafe01c27414118be56f9c13753a0 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Wed, 14 Jan 2026 11:35:39 -0800 Subject: [PATCH 07/11] clea --- csrc/device_lower/pass/allocation.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index 63d80fddb00..3116a02697c 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -1699,9 +1699,10 @@ class AllocationInserter : public kir::ExprMutator { AllocationInserter(const std::vector& exprs) : gpu_lower_(GpuLower::current()) { - // For warp specialized kernel, insert uniform warp id at the top-level - // scope. - if (GpuLower::current()->circularBufferInfo().hasWarpSpecialized()) { + // Insert uniform warp id at the top-level scope if warp specialization or + // cluster reduction is used. 
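+    // After this pass the top-level scope begins with, roughly (sketch):
+    //   uint32_t uniform_warp_id;
+    //   uniform_warp_id = warp::getUniformWarpId();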
+ if (GpuLower::current()->circularBufferInfo().hasWarpSpecialized() || + GpuLower::current()->clusterReductionCount() >= 1) { insertUniformWarpId(exprs.at(0)); } // insert cluster reduction mbarrier at top-level scope From ded26939e91885ad09bed914f843222ae7039528 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Wed, 14 Jan 2026 11:52:40 -0800 Subject: [PATCH 08/11] directly compute warp id --- csrc/codegen.cpp | 5 --- csrc/device_lower/pass/allocation.cpp | 51 ++++++++++++++++------ csrc/device_lower/pass/circular_buffer.cpp | 5 ++- csrc/device_lower/pass/index.cpp | 7 --- csrc/device_lower/pass/index.h | 1 - csrc/dispatch.h | 1 - csrc/ir/utils.cpp | 5 --- csrc/kernel_ir.cpp | 21 --------- csrc/kernel_ir.h | 20 --------- runtime/warp.cu | 11 ----- 10 files changed, 40 insertions(+), 87 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index d61bc336b00..af4ee21e694 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -4442,11 +4442,6 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << call << ";\n"; } - void handle(const kir::UniformWarpId* init) final { - indent() << gen(init->out()) << " = " - << genCall("warp::getUniformWarpId", ArgumentBuilder()) << ";\n"; - } - void handle(const kir::MBarrierInvalidate* inval) final { auto call = genCall( "mbarrier::inval", ArgumentBuilder().arg(genInline(inval->mbarrier()))); diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index 3116a02697c..2385142007e 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -1234,20 +1234,43 @@ class AllocationInserter : public kir::ExprMutator { // insert a scalar register variable for uniform warp id void insertUniformWarpId(Expr* expr) { - // allocate uniform_warp_id - Val* uniform_warp_id = IrBuilder::create(DataType::UInt32); - kir::Allocate* uniform_warp_id_alloc = IrBuilder::create( - uniform_warp_id, - MemoryType::Local, - FusionGuard::getCurFusion()->oneVal()); - registerInsertBefore(expr, uniform_warp_id_alloc, nullptr); - - // initialize uniform_warp_id - auto uniform_warp_id_init = - IrBuilder::create(uniform_warp_id); - registerInsertBefore(expr, uniform_warp_id_init, nullptr); - - // Store the uniform_warp_id in GpuLower for later use in predicates + // Compute flat thread id: tid = threadIdx.x + threadIdx.y * blockDim.x + + // threadIdx.z * blockDim.x * blockDim.y + const auto& pdim = GpuLower::current()->info().parallelDimensionMap(); + Val* tid = FusionGuard::getCurFusion()->zeroVal(); + Val* bdimx = pdim.getRaw(ParallelType::TIDx); + Val* bdimy = pdim.getRaw(ParallelType::TIDy); + Val* bdimz = pdim.getRaw(ParallelType::TIDz); + + if (bdimx != nullptr) { + tid = NamedScalar::getParallelIndex(ParallelType::TIDx); + } + if (bdimy != nullptr) { + Val* tidy = NamedScalar::getParallelIndex(ParallelType::TIDy); + if (bdimx != nullptr) { + tidy = SimplifyingIrBuilder::mulExpr(tidy, bdimx); + } + tid = SimplifyingIrBuilder::addExpr(tid, tidy); + } + if (bdimz != nullptr) { + Val* tidz = NamedScalar::getParallelIndex(ParallelType::TIDz); + if (bdimy != nullptr) { + tidz = SimplifyingIrBuilder::mulExpr(tidz, bdimy); + } + if (bdimx != nullptr) { + tidz = SimplifyingIrBuilder::mulExpr(tidz, bdimx); + } + tid = SimplifyingIrBuilder::addExpr(tid, tidz); + } + + // Compute warp_id = tid / 32 + Val* warp_size = IrBuilder::create(32L, DataType::Index); + Val* warp_id = SimplifyingIrBuilder::divExpr(tid, warp_size); + + // Cast to UInt32 for use in predicates and store in GpuLower + Val* 
uniform_warp_id = + SimplifyingIrBuilder::maybeCastExpr(DataType::UInt32, warp_id); + GpuLower::current()->setUniformWarpId(uniform_warp_id); } diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index 1902128511d..5e77fb6f32f 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -1566,7 +1566,7 @@ class WarpSpecializedCircularBufferInserter : private kir::ExprMutator { } // Create predicate for warp-specialized IfThenElse: - // kir::Predicate is UniformWarpId() >= num_compute_warps + // kir::Predicate is warp_id >= num_compute_warps kir::Predicate* getAsyncWarpPredicate(const CircularBufferOptions& options) { // Get the number of compute warps using ParallelDimensionMap // This works correctly for warp specialization on TIDx, TIDy, or TIDz @@ -1575,7 +1575,8 @@ class WarpSpecializedCircularBufferInserter : private kir::ExprMutator { Val* num_compute_warps = pdim_map.getNumComputeWarps(); Val* uniform_warp_id = GpuLower::current()->uniformWarpId(); - NVF_ERROR(uniform_warp_id != nullptr, "UniformWarpId must be initialized"); + NVF_ERROR( + uniform_warp_id != nullptr, "uniform_warp_id must be initialized"); return IrBuilder::create( IrBuilder::geExpr(uniform_warp_id, num_compute_warps)); } diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 36c4e792b06..817d79dc8cc 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1571,13 +1571,6 @@ void IndexLowering::handle(const kir::MBarrierInvalidate* minval) { GpuLower::current()->propagateExprInfo(minval, minval_indexed); } -void IndexLowering::handle(const kir::UniformWarpId* uwid_init) { - // UniformWarpId operates on scalar Vals, no indexing needed - // Just pass through as-is - pushBack(IrBuilder::create(uwid_init->out())); - GpuLower::current()->propagateExprInfo(uwid_init, back()); -} - void IndexLowering::handle(const kir::MBarrierArrive* arrive_transaction) { NVF_ERROR( arrive_transaction->mbarrier()->isA(), diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index 7643a254a1e..b47a9f9b36a 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -88,7 +88,6 @@ class IndexLowering : private OptOutConstDispatch { void handle(const kir::Continue*) final; void handle(const kir::Return*) final; void handle(const kir::MBarrierInit*) final; - void handle(const kir::UniformWarpId*) final; void handle(const kir::MBarrierInvalidate*) final; void handle(const kir::MBarrierArrive*) final; void handle(const kir::MBarrierArriveExpectTx*) final; diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 2b0ebd5489b..822ababb149 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -139,7 +139,6 @@ class Val; f(Continue); \ f(Return); \ f(MBarrierInit); \ - f(UniformWarpId); \ f(MBarrierInvalidate); \ f(MBarrierArrive); \ f(MBarrierArriveExpectTx); \ diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index b2c3b25255b..9418b137688 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1417,11 +1417,6 @@ bool isFunctional(const Val* v) { if (dynamic_cast(def)) { return false; } - // UniformWarpId is a runtime operation that should not be evaluated - // at compile time or constant-folded by SimplifyingIrBuilder - if (dynamic_cast(def)) { - return false; - } return std::all_of(def->inputs().begin(), def->inputs().end(), isFunctional); } diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index 01f61ee1f7f..e28fc12c292 100644 --- 
a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -1075,27 +1075,6 @@ std::string MBarrierInit::toInlineString(int indent_size) const { NVFUSER_DEFINE_CLONE_AND_CREATE(MBarrierInit) -UniformWarpId::UniformWarpId(IrBuilderPasskey passkey, Val* out) - : Expr(passkey) { - NVF_ERROR(passkey.ir_container_ != nullptr); - NVF_CHECK(out->dtype() == DataType::UInt32); - addOutput(out); -} - -std::string UniformWarpId::toString(int indent_size) const { - std::stringstream ss; - indent(ss, indent_size) << out()->toString() << " = UniformWarpId()\n"; - return ss.str(); -} - -// uniform warp id is used in predicate, Predicate::toString() uses its -// toInlineString() -std::string UniformWarpId::toInlineString(int indent_size) const { - return std::string(getOpString()) + "()"; -} - -NVFUSER_DEFINE_CLONE_AND_CREATE(UniformWarpId) - MBarrierInvalidate::MBarrierInvalidate(IrBuilderPasskey passkey, Val* mbarrier) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h index db2d4db9291..b98aa1d0d79 100644 --- a/csrc/kernel_ir.h +++ b/csrc/kernel_ir.h @@ -58,7 +58,6 @@ class AsyncWait; class AsyncCommit; class InitMagicZero; class UpdateMagicZero; -class UniformWarpId; class IfThenElse; class GridReduction; class GroupedGridReduction; @@ -848,25 +847,6 @@ class MBarrierInit final : public Expr { } }; -class UniformWarpId final : public Expr { - public: - using Expr::Expr; - explicit UniformWarpId(IrBuilderPasskey passkey, Val* out); - - NVFUSER_DECLARE_CLONE_AND_CREATE - - const char* getOpString() const override { - return "UniformWarpId"; - } - - std::string toString(int indent_size = 0) const override; - std::string toInlineString(int indent_size = 0) const override; - - Val* out() const { - return output(0); - } -}; - class MBarrierInvalidate final : public Expr { public: using Expr::Expr; diff --git a/runtime/warp.cu b/runtime/warp.cu index 838e0115637..35ecbb88a41 100644 --- a/runtime/warp.cu +++ b/runtime/warp.cu @@ -7,17 +7,6 @@ // clang-format on namespace warp { -// Compute uniform warp id that is guaranteed to be the same for all threads in -// a warp. -// __shfl_sync helps PTXAS prove that every thread in the warp has the same -// uniform warp id. 
-__device__ __forceinline__ uint32_t getUniformWarpId() { - const unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x + - threadIdx.z * blockDim.x * blockDim.y; - const unsigned int warp_id = tid / 32; - return __shfl_sync(0xFFFFFFFF, warp_id, 0); -} - template __device__ __forceinline__ T shfl_xor(T var, int laneMask, int width = 32) { return __shfl_xor_sync(0xffffffff, var, laneMask, width); From adebdd03037798062f72a194ca603b0938a12e17 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Wed, 14 Jan 2026 19:35:22 -0800 Subject: [PATCH 09/11] only for warp specialized --- csrc/device_lower/pass/allocation.cpp | 11 +- csrc/parallel_dimension_map.cpp | 33 ++++++ csrc/parallel_dimension_map.h | 5 + csrc/predicate_compute.cpp | 150 ++++++++++++++++++++------ 4 files changed, 162 insertions(+), 37 deletions(-) diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index 2385142007e..c355f5cefd4 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -1232,8 +1232,7 @@ class AllocationInserter : public kir::ExprMutator { return alloc_expr; } - // insert a scalar register variable for uniform warp id - void insertUniformWarpId(Expr* expr) { + void computeUniformWarpId(Expr* expr) { // Compute flat thread id: tid = threadIdx.x + threadIdx.y * blockDim.x + // threadIdx.z * blockDim.x * blockDim.y const auto& pdim = GpuLower::current()->info().parallelDimensionMap(); @@ -1722,11 +1721,9 @@ class AllocationInserter : public kir::ExprMutator { AllocationInserter(const std::vector& exprs) : gpu_lower_(GpuLower::current()) { - // Insert uniform warp id at the top-level scope if warp specialization or - // cluster reduction is used. - if (GpuLower::current()->circularBufferInfo().hasWarpSpecialized() || - GpuLower::current()->clusterReductionCount() >= 1) { - insertUniformWarpId(exprs.at(0)); + // compute uniform warp id if warp specialization is enabled + if (GpuLower::current()->circularBufferInfo().hasWarpSpecialized()) { + computeUniformWarpId(exprs.at(0)); } // insert cluster reduction mbarrier at top-level scope if (GpuLower::current()->clusterReductionCount() >= 1) { diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 80b21a24fa1..565bc580304 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -307,6 +307,39 @@ Val* ParallelDimensionMap::getNumComputeWarps() const { return num_compute_warps; } +// For warp-specialization, the CTA is padded so the AsyncWarp contains 128 +// threads. This function maps the AsyncWarp CTA to a linear index from +// [0, 128). It is used to divide AsyncWarp into four independent warps. +Val* ParallelDimensionMap::getLinearThreadIndexAsync() const { + Val* index = GpuLower::current()->kernel()->zeroVal(); + Val* extent = GpuLower::current()->kernel()->oneVal(); + + for (auto pt : kParallelTypeTIDs) { + // For warp-specialization, an axis is padded so the AsyncWarp contains + // 128 threads. + Val* extent_for_pdim = getRawAsync(pt); + // short-circuit: extent_for_pdim is not used in kernel. + if (extent_for_pdim == nullptr) { + continue; + } + // short-circuit: extent_for_pdim is trivial. + if (extent_for_pdim->isConstScalar() && + extent_for_pdim->evaluate().as() == 1) { + continue; + } + Val* pt_index = NamedScalar::getParallelIndex(pt); + // Map the padded parallel index to [0, padded_value] range, so the linear + // index will be in range of [0, 128). 
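+  // Example (illustrative): with bdim = (128, 5, 1) warp specialized on TIDy
+  // and a pad of 1, the async threads have threadIdx.y == 4 and get
+  //   index = threadIdx.x + (threadIdx.y - 4) * 128
+  // which covers [0, 128) as intended.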
+    if (isWarpSpecialized(pt)) {
+      pt_index = SimplifyingIrBuilder::subExpr(pt_index, getRawCompute(pt));
+    }
+    index = SimplifyingIrBuilder::addExpr(
+        index, SimplifyingIrBuilder::mulExpr(pt_index, extent));
+    extent = SimplifyingIrBuilder::mulExpr(extent, extent_for_pdim);
+  }
+  return index;
+}
+
 bool ParallelDimensionMap::canUseElectSyncInAsyncWarp() const {
   // short-circuit: skip if warp specialization is not enabled
   if (!hasWarpSpecialization()) {
diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h
index d08393d2fdc..ba4123682ae 100644
--- a/csrc/parallel_dimension_map.h
+++ b/csrc/parallel_dimension_map.h
@@ -86,6 +86,11 @@ class ParallelDimensionMap {
   //! - If warp specialized on TIDz: bdimx * bdimy * (bdimz - pad) / 32
   Val* getNumComputeWarps() const;
 
+  //! For warp-specialization, the CTA is padded so the AsyncWarp group has
+  //! 128 threads. Maps each AsyncWarp thread to a linear index in [0, 128),
+  //! used to divide the AsyncWarp group into four independent warps.
+  Val* getLinearThreadIndexAsync() const;
+
   //! Get if the kernel uses warp specialization
   bool hasWarpSpecialization() const {
     return warp_specialized_parallel_type_.has_value();
diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp
index 9957e24c076..17abf48b68d 100644
--- a/csrc/predicate_compute.cpp
+++ b/csrc/predicate_compute.cpp
@@ -494,38 +494,84 @@ Val* createElectSyncExpr() {
   return elect_sync_val;
 }
 
-Val* selectWarpIndex(uint32_t warp_index) {
-  Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
-  NVF_ERROR(uniform_warp_id != nullptr);
-  Val* target_warp_index =
-      IrBuilder::create<Val>(warp_index, PrimDataType::UInt32);
-  return IrBuilder::eqExpr(uniform_warp_id, target_warp_index);
-}
+// Select first warp of threads along TIDx axis and use ptx::elect_sync if not
+// warp collective.
+// TODO If TIDx is known at compile-time, generate custom mask.
+Val* selectFirstWarpElectSyncPredicate(bool is_warp_collective) {
+  // Use new approach with uniform warp ID when warp specialization is enabled
+  if (GpuLower::current()->circularBufferInfo().hasWarpSpecialized()) {
+    Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
+    NVF_ERROR(uniform_warp_id != nullptr);
+    Val* target_warp_index = IrBuilder::create<Val>(0u, PrimDataType::UInt32);
+    Val* select_warp = IrBuilder::eqExpr(uniform_warp_id, target_warp_index);
+    if (is_warp_collective) {
+      return select_warp;
+    }
+    return SimplifyingIrBuilder::logicalAndExpr(
+        select_warp, createElectSyncExpr());
+  }
 
-// Select the first warp or one thread of the first warp in the block
-Val* selectWarpIdxElectSyncPredicate(
-    uint32_t warp_index,
-    bool is_warp_collective) {
-  return is_warp_collective
-      ? selectWarpIndex(warp_index)
-      : SimplifyingIrBuilder::logicalAndExpr(
-            selectWarpIndex(warp_index), createElectSyncExpr());
+  Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);
+  Val* select_first_warp = IrBuilder::ltExpr(
+      NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size);
+
+  // Short-Circuit: TMA Store is a warp-collective, so ElectSync is not
+  // necessary.
+  if (is_warp_collective) {
+    return select_first_warp;
+  }
+
+  return SimplifyingIrBuilder::logicalAndExpr(
+      createElectSyncExpr(), select_first_warp);
 }
 
-// Since num_compute_warps represents the number of compute warps,
-// the first async warp starts at index num_compute_warps
+// Get linear index for AsyncWarp Group. Then, select first warp. Finally, use
+// ptx::elect_sync if not warp collective.
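// [Editorial note.] On the warp-specialized path below, the lowered predicate
// is effectively
//   uniform_warp_id == num_compute_warps && ptx::elect_sync(~0u)
// i.e. select the first async warp, then elect a single lane within it. The
// non-specialized fallback instead derives warp/thread ids from
// getLinearThreadIndexAsync(), as the code after this comment shows.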
+// TODO If TIDx is known at compile-time, generate custom mask.
 Val* createElectSyncPredicateAsync() {
   const ParallelDimensionMap& pdim_map =
       GpuLower::current()->info().parallelDimensionMap();
-  Val* num_compute_warps = pdim_map.getNumComputeWarps();
-  NVF_ERROR(
-      num_compute_warps != nullptr, "NumComputeWarps must be initialized");
-  NVF_ERROR(
-      num_compute_warps->isConstScalar(), "NumComputeWarps must be a constant");
-  uint32_t warp_index =
-      static_cast<uint32_t>(num_compute_warps->evaluate().as<int64_t>());
-  return selectWarpIdxElectSyncPredicate(
-      warp_index, /*is_warp_collective=*/false);
+
+  // Use new approach with uniform warp ID when warp specialization is enabled
+  if (GpuLower::current()->circularBufferInfo().hasWarpSpecialized()) {
+    Val* num_compute_warps = pdim_map.getNumComputeWarps();
+    NVF_ERROR(
+        num_compute_warps != nullptr, "NumComputeWarps must be initialized");
+    NVF_ERROR(
+        num_compute_warps->isConstScalar(),
+        "NumComputeWarps must be a constant");
+    uint32_t warp_index =
+        static_cast<uint32_t>(num_compute_warps->evaluate().as<int64_t>());
+    Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
+    NVF_ERROR(uniform_warp_id != nullptr);
+    Val* target_warp_index =
+        IrBuilder::create<Val>(warp_index, PrimDataType::UInt32);
+    Val* select_warp = IrBuilder::eqExpr(uniform_warp_id, target_warp_index);
+    return SimplifyingIrBuilder::logicalAndExpr(
+        select_warp, createElectSyncExpr());
+  }
+
+  Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt64);
+  Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);
+
+  Val* async_warp_thread_index = pdim_map.getLinearThreadIndexAsync();
+  Val* warp_id =
+      SimplifyingIrBuilder::divExpr(async_warp_thread_index, warp_size);
+  // TODO: currently only the first warp is selected.
+  Val* select_warp = SimplifyingIrBuilder::eqExpr(warp_id, zero);
+
+  // Use elect-sync if available
+  if (pdim_map.canUseElectSyncInAsyncWarp()) {
+    return SimplifyingIrBuilder::logicalAndExpr(
+        select_warp, createElectSyncExpr());
+  }
+
+  // The warp-specialized ParallelType is threadIdx.x and it contains fewer
+  // than 32 threads, so manually select the first thread of the warp.
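// [Editorial example, hypothetical sizes.] If the warp-specialized TIDx
// extent is, say, 16, the async group does not own a full 32-thread hardware
// warp, so elect_sync is not provably safe here and the fallback below
// reduces to
//   (async_index / 32 == 0) && (async_index % 32 == 0)
// which is true only for linear async thread 0.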
+  Val* thread_id =
+      SimplifyingIrBuilder::modExpr(async_warp_thread_index, warp_size);
+  Val* select_thread = SimplifyingIrBuilder::eqExpr(thread_id, zero);
+  return SimplifyingIrBuilder::logicalAndExpr(select_warp, select_thread);
 }
 
 Val* createElectSyncPredicate(kir::Predicate* pred, bool is_async_warp) {
@@ -573,7 +619,7 @@ Val* createElectSyncPredicate(kir::Predicate* pred, bool is_async_warp) {
   if (is_async_warp) {
     return createElectSyncPredicateAsync();
   }
-  return selectWarpIdxElectSyncPredicate(0, is_tma_store);
+  return selectFirstWarpElectSyncPredicate(is_tma_store);
 }
 
 Val* createSingleExpressionElectSync(
@@ -657,16 +703,60 @@ Val* createMultipleExpressionElectSync(
     kir::Predicate* pred,
     const std::vector<kir::ForLoop*>& loops) {
   NVF_ERROR(pred->expr() == nullptr);
+
   auto async_warp_loop_it =
       std::find_if(loops.begin(), loops.end(), [](kir::ForLoop* fl) {
         return fl->circularBufferLoopStage() ==
             CircularBufferLoopStage::AsyncWarp;
       });
+
+  // Use new approach with uniform warp ID when warp specialization is enabled
+  if (GpuLower::current()->circularBufferInfo().hasWarpSpecialized()) {
+    if (async_warp_loop_it != loops.end()) {
+      return createElectSyncPredicateAsync();
+    } else {
+      Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
+      NVF_ERROR(uniform_warp_id != nullptr);
+      Val* target_warp_index = IrBuilder::create<Val>(0u, PrimDataType::UInt32);
+      Val* select_warp = IrBuilder::eqExpr(uniform_warp_id, target_warp_index);
+      return SimplifyingIrBuilder::logicalAndExpr(
+          select_warp, createElectSyncExpr());
+    }
+  }
+
+  Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt64);
+  const ParallelDimensionMap& pdim_map =
+      GpuLower::current()->info().parallelDimensionMap();
+
+  // Determine if this is a warp-specialized TMA load expression.
+  ParallelType async_warp_on = ParallelType::Serial;
   if (async_warp_loop_it != loops.end()) {
-    return createElectSyncPredicateAsync();
-  } else {
-    return selectWarpIdxElectSyncPredicate(0, /*is_warp_collective=*/false);
+    auto circular_buffer_type = std::get<WarpSpecialized>(
+        GpuLower::current()
+            ->circularBufferInfo()
+            .getCircularBufferOptionsFor((*async_warp_loop_it)->iter_domain())
+            .type);
+    async_warp_on = circular_buffer_type.on;
+  }
+
+  // Short-circuit: If we are in an async warp, then the warp-dispatching
+  // IfThenElse already selects on `async_warp_on`, so we should not
+  // generate predicates for it here.
+  if (async_warp_loop_it == loops.end()) {
+    Val* conditional = async_warp_on == ParallelType::TIDx
+        ? pred->fusion()->trueVal()
+        : selectFirstWarpElectSyncPredicate(/*is_warp_collective=*/false);
+    for (ParallelType pt : {ParallelType::TIDy, ParallelType::TIDz}) {
+      if (pdim_map.has(pt) && async_warp_on != pt) {
+        conditional = SimplifyingIrBuilder::logicalAndExpr(
+            conditional,
+            IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero));
+      }
+    }
+    return conditional;
+  }
+
+  return createElectSyncPredicateAsync();
 }
 
 } // namespace

From 459df7ec45366f4403081a8364caeb44671f0d93 Mon Sep 17 00:00:00 2001
From: Liqiang Lu
Date: Thu, 15 Jan 2026 06:05:01 -0800
Subject: [PATCH 10/11] clean

---
 tests/cpp/test_combined_inner_outer_reduction.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/cpp/test_combined_inner_outer_reduction.cpp b/tests/cpp/test_combined_inner_outer_reduction.cpp
index e26ecce0b74..e9a684b1cfe 100644
--- a/tests/cpp/test_combined_inner_outer_reduction.cpp
+++ b/tests/cpp/test_combined_inner_outer_reduction.cpp
@@ -1299,7 +1299,7 @@ TEST_P(TmaWarpSpecializedTest, LayerNormBackward) {
 auto TmaWarpSpecializedTestParams() {
   std::vector values;
   int64_t dim0 = 2048;
-  for (int64_t dim1 = 1024; dim1 <= 32768; dim1 += 256) {
+  for (int64_t dim1 = 1024; dim1 <= 8192; dim1 += 256) {
     for (bool contig : {true, false}) {
       // to save test time
       if (dim1 != 1024 && !contig) {

From 122ac43355a26cdf55d9258258c2032bc63857a0 Mon Sep 17 00:00:00 2001
From: Liqiang Lu
Date: Thu, 15 Jan 2026 09:15:19 -0800
Subject: [PATCH 11/11] fix

---
 csrc/device_lower/pass/allocation.cpp      |  7 +++--
 csrc/device_lower/pass/circular_buffer.cpp | 34 +++++++++++++++-------
 csrc/parallel_dimension_map.cpp            | 27 ++++++++++++++++-
 csrc/parallel_dimension_map.h              | 18 +++++++++++-
 csrc/predicate_compute.cpp                 | 21 ++++++-------
 5 files changed, 81 insertions(+), 26 deletions(-)

diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp
index c355f5cefd4..8145e2ba9d1 100644
--- a/csrc/device_lower/pass/allocation.cpp
+++ b/csrc/device_lower/pass/allocation.cpp
@@ -1721,8 +1721,11 @@ class AllocationInserter : public kir::ExprMutator {
 
   AllocationInserter(const std::vector<Expr*>& exprs)
       : gpu_lower_(GpuLower::current()) {
-    // compute uniform warp id if warp specialization is enabled
-    if (GpuLower::current()->circularBufferInfo().hasWarpSpecialized()) {
+    // Warp-id-based predicates (e.g., warp_id >= threshold) only work when
+    // async/compute warps have consecutive warp IDs.
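// [Editorial note.] computeUniformWarpId() (the PATCH 09 rename of
// insertUniformWarpId) materializes warp_id = flat_tid / 32 once at the
// top-level scope, so the warp-dispatch IfThenElse and the elect-sync
// predicates can all compare a single uniform register against constants
// such as num_compute_warps.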
+    if (gpu_lower_->info()
+            .parallelDimensionMap()
+            .canUseWarpIdBasedPredicate()) {
       computeUniformWarpId(exprs.at(0));
     }
     // insert cluster reduction mbarrier at top-level scope
diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp
index 5e77fb6f32f..eaa579eb615 100644
--- a/csrc/device_lower/pass/circular_buffer.cpp
+++ b/csrc/device_lower/pass/circular_buffer.cpp
@@ -1566,19 +1566,34 @@ class WarpSpecializedCircularBufferInserter : private kir::ExprMutator {
   }
 
   // Create predicate for warp-specialized IfThenElse:
-  // kir::Predicate is warp_id >= num_compute_warps
+  // If uniform warp ID is available, use warp-ID-based predicate (warp_id >=
+  // num_compute_warps)
   kir::Predicate* getAsyncWarpPredicate(const CircularBufferOptions& options) {
-    // Get the number of compute warps using ParallelDimensionMap
-    // This works correctly for warp specialization on TIDx, TIDy, or TIDz
     const ParallelDimensionMap& pdim_map =
         GpuLower::current()->info().parallelDimensionMap();
-    Val* num_compute_warps = pdim_map.getNumComputeWarps();
     Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
-    NVF_ERROR(
-        uniform_warp_id != nullptr, "uniform_warp_id must be initialized");
-    return IrBuilder::create<kir::Predicate>(
-        IrBuilder::geExpr(uniform_warp_id, num_compute_warps));
+    if (uniform_warp_id != nullptr) {
+      // Use uniform warp ID approach: async warps have warp_id >=
+      // num_compute_warps
+      Val* num_compute_warps = pdim_map.getNumComputeWarps();
+      NVF_ERROR(
+          num_compute_warps != nullptr,
+          "num_compute_warps must be initialized");
+      return IrBuilder::create<kir::Predicate>(
+          IrBuilder::geExpr(uniform_warp_id, num_compute_warps));
+    }
+
+    // Fallback: use parallel index comparison
+    ParallelType warp_specialize_on =
+        std::get<WarpSpecialized>(options.type).on;
+    int64_t warp_specialization_pad =
+        pdim_map.getWarpSpecializationPaddedVal(warp_specialize_on);
+    Val* raw = pdim_map.get(warp_specialize_on);
+    Val* raw_minus_pad = SimplifyingIrBuilder::subExpr(
+        raw, IrBuilder::create<Val>(warp_specialization_pad, DataType::Index));
+    return IrBuilder::create<kir::Predicate>(IrBuilder::geExpr(
+        NamedScalar::getParallelIndex(warp_specialize_on), raw_minus_pad));
   }
 
   void insertTmaWarpSpecialized(
@@ -2017,8 +2032,7 @@ kir::ForLoop* HopperPingPongMbarriers::initializePingPongMbarrier() {
       GpuLower::current()
           ->info()
          .parallelDimensionMap()
-          .getNumComputeThreadsEachBlock(
-              /*only_count_same_compute_warp_groups=*/true));
+          .getNumComputeThreadsEachBlock(true));
   kir::TensorIndex* ping_pong_mbarrier_index =
       IrBuilder::create<kir::TensorIndex>(mbarriers_, loop->index());
   kir::MBarrierInit* ping_pong_mbarrier_init = IrBuilder::create(
diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp
index 565bc580304..6194e799b89 100644
--- a/csrc/parallel_dimension_map.cpp
+++ b/csrc/parallel_dimension_map.cpp
@@ -152,7 +152,7 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() {
   exact_types_.erase(ParallelType::TIDx);
 }
 
-int64_t ParallelDimensionMap::getThreadCountInDim(ParallelType pt) {
+int64_t ParallelDimensionMap::getThreadCountInDim(ParallelType pt) const {
   if (!dim_map_.contains(pt)) {
     return 1;
   }
@@ -361,6 +361,31 @@ bool ParallelDimensionMap::canUseElectSyncInAsyncWarp() const {
   return false;
 }
 
+bool ParallelDimensionMap::canUseWarpIdBasedPredicate() const {
+  if (!hasWarpSpecialization()) {
+    return false;
+  }
+
+  // For consecutive warp IDs, all dimensions after the warp-specialized
+  // dimension must be 1. Otherwise outer dimensions create gaps in warp IDs.
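// [Editorial walk-through of the gap case, mirroring the header example.]
// With CTA (32, 6, 2) specialized on TIDy, bdimx = 32 makes the flat warp id
// tidy + tidz * 6. If the async rows are tidy in {4, 5}, async warps get ids
// {4, 5} for tidz == 0 but {10, 11} for tidz == 1, while compute warps own
// {6..9}; a predicate like warp_id >= 4 would misclassify them, hence the
// check below.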
+  NVF_ERROR(warp_specialized_parallel_type_.has_value());
+  ParallelType ws_pt = warp_specialized_parallel_type_.value();
+
+  bool found_ws_pt = false;
+  for (ParallelType pt : kParallelTypeTIDs) {
+    if (pt == ws_pt) {
+      found_ws_pt = true;
+    } else if (found_ws_pt) {
+      int64_t thread_count = getThreadCountInDim(pt);
+      if (thread_count == -1 || thread_count > 1) {
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
 std::string ParallelDimensionMap::toString() const {
   std::stringstream ss;
   for (auto pt : kParallelTypeThreads) {
diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h
index ba4123682ae..b6cdeb6e582 100644
--- a/csrc/parallel_dimension_map.h
+++ b/csrc/parallel_dimension_map.h
@@ -109,10 +109,26 @@ class ParallelDimensionMap {
   // elect-sync cannot be used.
   bool canUseElectSyncInAsyncWarp() const;
 
+  //! Check if warp-id-based predicates can be used for warp specialization.
+  //! Warp-id-based predicates (e.g., warp_id >= N) only work when the
+  //! warp-specialized dimension produces consecutive warp IDs. This requires
+  //! that the warp-specialized dimension is the outermost non-trivial one,
+  //! meaning ALL dimensions after it must be 1.
+  //!
+  //! Example: warp specialized on TIDy with CTA (32, 6, 2):
+  //!   TIDz=2 after TIDy causes non-consecutive warps (FAILS)
+  //! Example: warp specialized on TIDz with CTA (32, 4, 3):
+  //!   No dimensions after TIDz -> consecutive warps (WORKS)
+  //!
+  //! Returns true only if:
+  //!   - Warp specialization is used, AND
+  //!   - All dimensions after the warp-specialized dimension are 1
+  bool canUseWarpIdBasedPredicate() const;
+
  private:
   //! Get number of threads for ParallelType axis
   //! Not used: 1, Const: n, Dynamic: -1
-  int64_t getThreadCountInDim(ParallelType pt);
+  int64_t getThreadCountInDim(ParallelType pt) const;
 
   //! TIDx may need to be marked as non-exact as it may be padded to a
   //! multiple of the warp size.
diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp
index 17abf48b68d..29432cd36c7 100644
--- a/csrc/predicate_compute.cpp
+++ b/csrc/predicate_compute.cpp
@@ -498,10 +498,9 @@ Val* createElectSyncExpr() {
   return elect_sync_val;
 }
 
 // Select first warp of threads along TIDx axis and use ptx::elect_sync if not
 // warp collective.
 // TODO If TIDx is known at compile-time, generate custom mask.
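// [Editorial note.] After this patch, when uniform_warp_id is unavailable the
// function below lowers to roughly
//   threadIdx.x < 32 && ptx::elect_sync(~0u)
// for non-collective ops, and to just threadIdx.x < 32 for warp-collective
// TMA stores, where every lane of the first warp must participate.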
 Val* selectFirstWarpElectSyncPredicate(bool is_warp_collective) {
-  // Use new approach with uniform warp ID when warp specialization is enabled
-  if (GpuLower::current()->circularBufferInfo().hasWarpSpecialized()) {
-    Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
-    NVF_ERROR(uniform_warp_id != nullptr);
+  // If uniform warp ID is available, use warp-ID-based predicate
+  Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
+  if (uniform_warp_id != nullptr) {
     Val* target_warp_index = IrBuilder::create<Val>(0u, PrimDataType::UInt32);
     Val* select_warp = IrBuilder::eqExpr(uniform_warp_id, target_warp_index);
     if (is_warp_collective) {
       return select_warp;
@@ -532,8 +531,9 @@ Val* createElectSyncPredicateAsync() {
   const ParallelDimensionMap& pdim_map =
       GpuLower::current()->info().parallelDimensionMap();
 
-  // Use new approach with uniform warp ID when warp specialization is enabled
-  if (GpuLower::current()->circularBufferInfo().hasWarpSpecialized()) {
+  // If uniform warp ID is available, use warp-ID-based predicate
+  Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
+  if (uniform_warp_id != nullptr) {
     Val* num_compute_warps = pdim_map.getNumComputeWarps();
     NVF_ERROR(
         num_compute_warps != nullptr, "NumComputeWarps must be initialized");
@@ -542,8 +542,6 @@ Val* createElectSyncPredicateAsync() {
         "NumComputeWarps must be a constant");
     uint32_t warp_index =
         static_cast<uint32_t>(num_compute_warps->evaluate().as<int64_t>());
-    Val* target_warp_index =
-        IrBuilder::create<Val>(warp_index, PrimDataType::UInt32);
-    Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
-    NVF_ERROR(uniform_warp_id != nullptr);
+    Val* target_warp_index =
+        IrBuilder::create<Val>(warp_index, PrimDataType::UInt32);
     Val* select_warp = IrBuilder::eqExpr(uniform_warp_id, target_warp_index);
@@ -710,13 +708,12 @@ Val* createMultipleExpressionElectSync(
       CircularBufferLoopStage::AsyncWarp;
     });
 
-  // Use new approach with uniform warp ID when warp specialization is enabled
-  if (GpuLower::current()->circularBufferInfo().hasWarpSpecialized()) {
+  // If uniform warp ID is available, use warp-ID-based predicate
+  Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
+  if (uniform_warp_id != nullptr) {
     if (async_warp_loop_it != loops.end()) {
       return createElectSyncPredicateAsync();
     } else {
-      Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
-      NVF_ERROR(uniform_warp_id != nullptr);
       Val* target_warp_index = IrBuilder::create<Val>(0u, PrimDataType::UInt32);
       Val* select_warp = IrBuilder::eqExpr(uniform_warp_id, target_warp_index);
       return SimplifyingIrBuilder::logicalAndExpr(