diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp
index e12ed646540..af4ee21e694 100644
--- a/csrc/codegen.cpp
+++ b/csrc/codegen.cpp
@@ -3757,7 +3757,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
 
     ArgumentBuilder template_args;
     template_args.arg(kernel_->paddedParallelDimensions().is_tidx_single_warp);
-    template_args.arg(isAligned());
+    template_args.arg(has_warp_specialized_ ? false : isAligned());
     template_args.arg(num_grouped_iterations);
     template_args.arg(reduction_scheduler_utils::getComputeBdimx(
         warp_specialized_on_, lparams_.bdimx()));
diff --git a/csrc/device_lower/lower2device.h b/csrc/device_lower/lower2device.h
index 8f60f69894d..1e78d2b1b01 100644
--- a/csrc/device_lower/lower2device.h
+++ b/csrc/device_lower/lower2device.h
@@ -318,6 +318,16 @@ class GpuLower : public NonCopyable {
     cluster_reduction_mbarrier_tensor_ = mbarrier;
   }
 
+  //! Get the uniform warp id scalar
+  Val* uniformWarpId() const {
+    return uniform_warp_id_;
+  }
+
+  //! Set the uniform warp id scalar
+  void setUniformWarpId(Val* warp_id) {
+    uniform_warp_id_ = warp_id;
+  }
+
   //! Define an alias for consumer as producer.
   //!
   //! If producer is already aliased, we chase the alias. If there are tensors
@@ -434,6 +444,10 @@ class GpuLower : public NonCopyable {
   // The shared cluster reduction mbarrier tensor allocated during allocation
   // pass
   TensorView* cluster_reduction_mbarrier_tensor_ = nullptr;
+
+  // The uniform warp id scalar allocated during allocation pass for warp
+  // specialized kernels
+  Val* uniform_warp_id_ = nullptr;
 };
 
 #define NVFUSER_LOWER_VALIDATE(cond, ...) \
diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp
index 689dc912b16..8145e2ba9d1 100644
--- a/csrc/device_lower/pass/allocation.cpp
+++ b/csrc/device_lower/pass/allocation.cpp
@@ -976,7 +976,8 @@ Expr* initializeCircularBufferMbarrier(
       GpuLower::current()
           ->info()
           .parallelDimensionMap()
-          .getNumComputeThreadsEachBlock());
+          .getNumComputeThreadsEachBlock(
+              /*only_count_same_compute_warp_groups=*/true));
 }
 
 // Initialize mbarrier for each circular buffer stage. Use the thread
@@ -1231,6 +1232,47 @@ class AllocationInserter : public kir::ExprMutator {
     return alloc_expr;
   }
 
+  void computeUniformWarpId(Expr* expr) {
+    // Compute flat thread id: tid = threadIdx.x + threadIdx.y * blockDim.x +
+    // threadIdx.z * blockDim.x * blockDim.y
+    const auto& pdim = GpuLower::current()->info().parallelDimensionMap();
+    Val* tid = FusionGuard::getCurFusion()->zeroVal();
+    Val* bdimx = pdim.getRaw(ParallelType::TIDx);
+    Val* bdimy = pdim.getRaw(ParallelType::TIDy);
+    Val* bdimz = pdim.getRaw(ParallelType::TIDz);
+
+    if (bdimx != nullptr) {
+      tid = NamedScalar::getParallelIndex(ParallelType::TIDx);
+    }
+    if (bdimy != nullptr) {
+      Val* tidy = NamedScalar::getParallelIndex(ParallelType::TIDy);
+      if (bdimx != nullptr) {
+        tidy = SimplifyingIrBuilder::mulExpr(tidy, bdimx);
+      }
+      tid = SimplifyingIrBuilder::addExpr(tid, tidy);
+    }
+    if (bdimz != nullptr) {
+      Val* tidz = NamedScalar::getParallelIndex(ParallelType::TIDz);
+      if (bdimy != nullptr) {
+        tidz = SimplifyingIrBuilder::mulExpr(tidz, bdimy);
+      }
+      if (bdimx != nullptr) {
+        tidz = SimplifyingIrBuilder::mulExpr(tidz, bdimx);
+      }
+      tid = SimplifyingIrBuilder::addExpr(tid, tidz);
+    }
+
+    // Compute warp_id = tid / 32
+    Val* warp_size = IrBuilder::create<Val>(32L, DataType::Index);
+    Val* warp_id = SimplifyingIrBuilder::divExpr(tid, warp_size);
+
+    // Cast to UInt32 for use in predicates and store in GpuLower
+    Val* uniform_warp_id =
+        SimplifyingIrBuilder::maybeCastExpr(DataType::UInt32, warp_id);
+
+    GpuLower::current()->setUniformWarpId(uniform_warp_id);
+  }
+
   // Insert cluster reduction mbarrier allocation and initialization at the
   // beginning of the kernel for the first top-level expression
   void insertClusterReductionMBarrier(Expr* expr) {
@@ -1679,6 +1721,13 @@ class AllocationInserter : public kir::ExprMutator {
 
   AllocationInserter(const std::vector<Expr*>& exprs)
       : gpu_lower_(GpuLower::current()) {
+    // Warp-id-based predicates (e.g., warp_id >= threshold) only work when
+    // async/compute warps have consecutive warp IDs.
+    if (gpu_lower_->info()
+            .parallelDimensionMap()
+            .canUseWarpIdBasedPredicate()) {
+      computeUniformWarpId(exprs.at(0));
+    }
     // insert cluster reduction mbarrier at top-level scope
     if (GpuLower::current()->clusterReductionCount() >= 1) {
       insertClusterReductionMBarrier(exprs.at(0));
diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp
index 67022363f1c..eaa579eb615 100644
--- a/csrc/device_lower/pass/circular_buffer.cpp
+++ b/csrc/device_lower/pass/circular_buffer.cpp
@@ -1566,17 +1566,30 @@ class WarpSpecializedCircularBufferInserter : private kir::ExprMutator {
   }
 
   // Create predicate for warp-specialized IfThenElse:
-  // kir::Predicate is thread_axis >= block_dim_axis - padded_value
+  // If uniform warp ID is available, use warp-ID-based predicate (warp_id >=
+  // num_compute_warps)
   kir::Predicate* getAsyncWarpPredicate(const CircularBufferOptions& options) {
+    const ParallelDimensionMap& pdim_map =
+        GpuLower::current()->info().parallelDimensionMap();
+
+    Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
+    if (uniform_warp_id != nullptr) {
+      // Use uniform warp ID approach: async warps have warp_id >=
+      // num_compute_warps
+      Val* num_compute_warps = pdim_map.getNumComputeWarps();
+      NVF_ERROR(
+          num_compute_warps != nullptr,
+          "num_compute_warps must be initialized");
+      return IrBuilder::create<kir::Predicate>(
+          IrBuilder::geExpr(uniform_warp_id, num_compute_warps));
+    }
+
+    // Fallback: use parallel index comparison
     ParallelType warp_specialize_on =
         std::get<WarpSpecialized>(options.type).on;
     int64_t warp_specialization_pad =
-        GpuLower::current()
-            ->info()
-            .parallelDimensionMap()
-            .getWarpSpecializationPaddedVal(warp_specialize_on);
-    Val* raw = GpuLower::current()->info().parallelDimensionMap().get(
-        warp_specialize_on);
+        pdim_map.getWarpSpecializationPaddedVal(warp_specialize_on);
+    Val* raw = pdim_map.get(warp_specialize_on);
     Val* raw_minus_pad = SimplifyingIrBuilder::subExpr(
         raw, IrBuilder::create<Val>(warp_specialization_pad, DataType::Index));
     return IrBuilder::create<kir::Predicate>(IrBuilder::geExpr(
@@ -2019,7 +2032,7 @@ kir::ForLoop* HopperPingPongMbarriers::initializePingPongMbarrier() {
       GpuLower::current()
          ->info()
          .parallelDimensionMap()
-          .getNumComputeThreadsEachBlock());
+          .getNumComputeThreadsEachBlock(true));
   kir::TensorIndex* ping_pong_mbarrier_index =
       IrBuilder::create<kir::TensorIndex>(mbarriers_, loop->index());
   kir::MBarrierInit* ping_pong_mbarrier_init =
diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp
index 3ee6ca64b39..6194e799b89 100644
--- a/csrc/parallel_dimension_map.cpp
+++ b/csrc/parallel_dimension_map.cpp
@@ -152,7 +152,7 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() {
   exact_types_.erase(ParallelType::TIDx);
 }
 
-int64_t ParallelDimensionMap::getThreadCountInDim(ParallelType pt) {
+int64_t ParallelDimensionMap::getThreadCountInDim(ParallelType pt) const {
   if (!dim_map_.contains(pt)) {
     return 1;
   }
@@ -257,12 +257,13 @@ Val* ParallelDimensionMap::getRawAsync(ParallelType pt) const {
   return getRaw(pt);
 }
 
-Val* ParallelDimensionMap::getNumComputeThreadsEachBlock() const {
+Val* ParallelDimensionMap::getNumComputeThreadsEachBlock(
+    bool only_count_same_compute_warp_groups) const {
   Val* num_threads = FusionGuard::getCurFusion()->oneVal();
   for (auto pt : kParallelTypeTIDs) {
     // Skip warp specialized ParallelType if the are computation warp groups
     // are independent.
-    if (isWarpSpecialized(pt) &&
+    if (only_count_same_compute_warp_groups && isWarpSpecialized(pt) &&
         GpuLower::current()
             ->circularBufferInfo()
             .hasIndependentComputeWarpGroups()) {
@@ -277,6 +278,35 @@ Val* ParallelDimensionMap::getNumComputeThreadsEachBlock() const {
   return num_threads;
 }
 
+int64_t ParallelDimensionMap::getWarpSpecializationPaddedVal(
+    ParallelType pt) const {
+  NVF_ERROR(isWarpSpecialized(pt), "Can't find ParallelType: ", pt);
+  if (!warp_specialized_parallel_type_.has_value()) {
+    return 1;
+  }
+  NVF_ERROR(
+      warp_specialized_parallel_type_.value() == pt,
+      "Can't find padded val for: ",
+      pt);
+  return warp_specialized_padding_value_.value();
+}
+
+Val* ParallelDimensionMap::getNumComputeWarps() const {
+  NVF_ERROR(
+      hasWarpSpecialization(),
+      "getNumComputeWarps() should only be called for warp specialized "
+      "kernels");
+
+  Val* num_compute_threads = getNumComputeThreadsEachBlock(
+      /*only_count_same_compute_warp_groups=*/false);
+
+  // Divide by 32 to get the number of warps
+  Val* num_compute_warps = SimplifyingIrBuilder::divExpr(
+      num_compute_threads, IrBuilder::create<Val>(32L, DataType::Index));
+
+  return num_compute_warps;
+}
+
 // For warp-specialization, the CTA is padded so the AsyncWarp contains 128
 // threads. This function maps the AsyncWarp CTA to a linear index from
 // [0, 128). It is used to divide AsyncWarp into four independent warps.
@@ -310,19 +340,6 @@ Val* ParallelDimensionMap::getLinearThreadIndexAsync() const {
   return index;
 }
 
-int64_t ParallelDimensionMap::getWarpSpecializationPaddedVal(
-    ParallelType pt) const {
-  NVF_ERROR(isWarpSpecialized(pt), "Can't find ParallelType: ", pt);
-  if (!warp_specialized_parallel_type_.has_value()) {
-    return 1;
-  }
-  NVF_ERROR(
-      warp_specialized_parallel_type_.value() == pt,
-      "Can't find padded val for: ",
-      pt);
-  return warp_specialized_padding_value_.value();
-}
-
 bool ParallelDimensionMap::canUseElectSyncInAsyncWarp() const {
   // short-circuit: skip if warp specialization is not enabled
   if (!hasWarpSpecialization()) {
@@ -344,6 +361,31 @@ bool ParallelDimensionMap::canUseElectSyncInAsyncWarp() const {
   return false;
 }
 
+bool ParallelDimensionMap::canUseWarpIdBasedPredicate() const {
+  if (!hasWarpSpecialization()) {
+    return false;
+  }
+
+  // For consecutive warp IDs, all dimensions after the warp-specialized
+  // dimension must be 1. Otherwise outer dimensions create gaps in warp IDs.
+  NVF_ERROR(warp_specialized_parallel_type_.has_value());
+  ParallelType ws_pt = warp_specialized_parallel_type_.value();
+
+  bool found_ws_pt = false;
+  for (ParallelType pt : kParallelTypeTIDs) {
+    if (pt == ws_pt) {
+      found_ws_pt = true;
+    } else if (found_ws_pt) {
+      int64_t thread_count = getThreadCountInDim(pt);
+      if (thread_count == -1 || thread_count > 1) {
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
 std::string ParallelDimensionMap::toString() const {
   std::stringstream ss;
   for (auto pt : kParallelTypeThreads) {
diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h
index 00f4165e298..b6cdeb6e582 100644
--- a/csrc/parallel_dimension_map.h
+++ b/csrc/parallel_dimension_map.h
@@ -73,9 +73,22 @@ class ParallelDimensionMap {
   //! And this function will return (32 * 16) because the extra one for TIDy is
   //! introduced by warp specialization and only used for loading circular
   //! buffer tensors.
-  Val* getNumComputeThreadsEachBlock() const;
-
-  //! Assign linear index to each thread of CTA. Assume (TDZ, TDY, TDX) order.
+  Val* getNumComputeThreadsEachBlock(
+      bool only_count_same_compute_warp_groups) const;
+
+  //! Get the number of compute warps for warp specialized kernels.
+  //! This computes the total number of compute threads across all dimensions
+  //! (TIDx, TIDy, TIDz), using the compute dimension (minus padding) for the
+  //! warp specialized dimension, then divides by 32 to get the number of warps.
+  //! Examples:
+  //! - If warp specialized on TIDx: (bdimx - pad) * bdimy * bdimz / 32
+  //! - If warp specialized on TIDy: bdimx * (bdimy - pad) * bdimz / 32
+  //! - If warp specialized on TIDz: bdimx * bdimy * (bdimz - pad) / 32
+  Val* getNumComputeWarps() const;
+
+  //! For warp-specialization, the CTA is padded so the AsyncWarp contains 128
+  //! threads. This function maps the AsyncWarp CTA to a linear index from
+  //! [0, 128). It is used to divide AsyncWarp into four independent warps.
   Val* getLinearThreadIndexAsync() const;
 
   //! Get if the kernel uses warp specialization
@@ -96,10 +109,26 @@ class ParallelDimensionMap {
   // elect-sync cannot be used.
   bool canUseElectSyncInAsyncWarp() const;
 
+  //! Check if warp-id-based predicates can be used for warp specialization.
+  //! Warp-id-based predicates (e.g., warp_id >= N) only work when the
+  //! warp-specialized dimension produces consecutive warp IDs. This requires
+  //! that the warp-specialized dimension is the outermost dimension > 1,
+  //! meaning ALL dimensions after it must be 1.
+  //!
+  //! Example: warp specialized on TIDy with CTA (32, 6, 2):
+  //!   TIDz=2 after TIDy causes non-consecutive warps (FAILS)
+  //! Example: warp specialized on TIDz with CTA (32, 4, 3):
+  //!   No dimensions after TIDz -> consecutive warps (WORKS)
+  //!
+  //! Returns true only when warp specialization is used and all dimensions
+  //! after the warp-specialized dimension are 1; returns false for kernels
+  //! without warp specialization.
+  bool canUseWarpIdBasedPredicate() const;
+
  private:
   //! Get number of threads for ParallelType axis
   //! Not used: 1, Const: n, Dynamic: -1
-  int64_t getThreadCountInDim(ParallelType pt);
+  int64_t getThreadCountInDim(ParallelType pt) const;
 
   //! TIDx may need to be marked as non-exact as it may be padded to a
   //! multiple of the warp size.
diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp
index 532f2ff0d6e..29432cd36c7 100644
--- a/csrc/predicate_compute.cpp
+++ b/csrc/predicate_compute.cpp
@@ -498,6 +498,18 @@ Val* createElectSyncExpr() {
 // warp collective.
 // TODO If TIDx is known at compile-time, generate custom mask.
 Val* selectFirstWarpElectSyncPredicate(bool is_warp_collective) {
+  // If uniform warp ID is available, use warp-ID-based predicate
+  Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
+  if (uniform_warp_id != nullptr) {
+    Val* target_warp_index = IrBuilder::create<Val>(0u, PrimDataType::UInt32);
+    Val* select_warp = IrBuilder::eqExpr(uniform_warp_id, target_warp_index);
+    if (is_warp_collective) {
+      return select_warp;
+    }
+    return SimplifyingIrBuilder::logicalAndExpr(
+        select_warp, createElectSyncExpr());
+  }
+
   Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);
   Val* select_first_warp = IrBuilder::ltExpr(
       NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size);
@@ -516,11 +528,30 @@
 // ptx::elect_sync if not warp collective.
 // TODO If TIDx is known at compile-time, generate custom mask.
 Val* createElectSyncPredicateAsync() {
+  const ParallelDimensionMap& pdim_map =
+      GpuLower::current()->info().parallelDimensionMap();
+
+  // If uniform warp ID is available, use warp-ID-based predicate
+  Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
+  if (uniform_warp_id != nullptr) {
+    Val* num_compute_warps = pdim_map.getNumComputeWarps();
+    NVF_ERROR(
+        num_compute_warps != nullptr, "NumComputeWarps must be initialized");
+    NVF_ERROR(
+        num_compute_warps->isConstScalar(),
+        "NumComputeWarps must be a constant");
+    uint32_t warp_index =
+        static_cast<uint32_t>(num_compute_warps->evaluate().as<int64_t>());
+    Val* target_warp_index =
+        IrBuilder::create<Val>(warp_index, PrimDataType::UInt32);
+    Val* select_warp = IrBuilder::eqExpr(uniform_warp_id, target_warp_index);
+    return SimplifyingIrBuilder::logicalAndExpr(
+        select_warp, createElectSyncExpr());
+  }
+
   Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt64);
   Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);
-  const ParallelDimensionMap& pdim_map =
-      GpuLower::current()->info().parallelDimensionMap();
   Val* async_warp_thread_index = pdim_map.getLinearThreadIndexAsync();
   Val* warp_id =
       SimplifyingIrBuilder::divExpr(async_warp_thread_index, warp_size);
@@ -671,17 +702,31 @@ Val* createMultipleExpressionElectSync(
     const std::vector<kir::ForLoop*>& loops) {
   NVF_ERROR(pred->expr() == nullptr);
 
+  auto async_warp_loop_it =
+      std::find_if(loops.begin(), loops.end(), [](kir::ForLoop* fl) {
+        return fl->circularBufferLoopStage() ==
+            CircularBufferLoopStage::AsyncWarp;
+      });
+
+  // If uniform warp ID is available, use warp-ID-based predicate
+  Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
+  if (uniform_warp_id != nullptr) {
+    if (async_warp_loop_it != loops.end()) {
+      return createElectSyncPredicateAsync();
+    } else {
+      Val* target_warp_index = IrBuilder::create<Val>(0u, PrimDataType::UInt32);
+      Val* select_warp = IrBuilder::eqExpr(uniform_warp_id, target_warp_index);
+      return SimplifyingIrBuilder::logicalAndExpr(
+          select_warp, createElectSyncExpr());
+    }
+  }
+
   Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt64);
   const ParallelDimensionMap& pdim_map =
       GpuLower::current()->info().parallelDimensionMap();
 
   // Determine if warp specialized tma load expression.
   ParallelType async_warp_on = ParallelType::Serial;
-  auto async_warp_loop_it =
-      std::find_if(loops.begin(), loops.end(), [](kir::ForLoop* fl) {
-        return fl->circularBufferLoopStage() ==
-            CircularBufferLoopStage::AsyncWarp;
-      });
   if (async_warp_loop_it != loops.end()) {
     auto circular_buffer_type = std::get<WarpSpecialized>(
         GpuLower::current()
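Reviewer note (not part of the patch): a minimal standalone CUDA sketch of the predicate structure this change lowers to, under an assumed configuration of warp specialization on TIDy with a CTA of (128, 3, 1) and one padded warp group reserved for async loads. The kernel name and `kNumComputeWarps` are illustrative, not nvFuser identifiers.

```cuda
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Illustrative constant only: with a (128, 3, 1) CTA warp specialized on
// TIDy, 256 compute threads / 32 = 8 compute warps, so async warps are
// exactly those with warp_id >= 8.
constexpr uint32_t kNumComputeWarps = 8;

__global__ void warpSpecializedSketch() {
  // Flat thread id, mirroring computeUniformWarpId():
  //   tid = threadIdx.x + threadIdx.y * blockDim.x
  //       + threadIdx.z * blockDim.x * blockDim.y
  const uint32_t tid = threadIdx.x + threadIdx.y * blockDim.x +
      threadIdx.z * blockDim.x * blockDim.y;

  // Warp-uniform because consecutive flat thread ids share a warp. This only
  // holds when every dimension after the warp-specialized one is 1, which is
  // what canUseWarpIdBasedPredicate() checks.
  const uint32_t warp_id = tid / 32;

  if (warp_id >= kNumComputeWarps) {
    // Async branch, analogous to getAsyncWarpPredicate(). nvFuser elects a
    // single issuing thread with ptx::elect_sync; lane 0 stands in for it
    // in this sketch.
    if (warp_id == kNumComputeWarps && tid % 32 == 0) {
      printf("async warp leader, warp_id = %u\n", warp_id);
    }
  } else {
    // Compute branch. Single-thread work such as mbarrier init would be
    // guarded like selectFirstWarpElectSyncPredicate(): warp_id == 0 plus an
    // elected lane.
    if (warp_id == 0 && tid % 32 == 0) {
      printf("compute warp 0 leader\n");
    }
  }
}

int main() {
  warpSpecializedSketch<<<1, dim3(128, 3, 1)>>>();
  cudaDeviceSynchronize();
  return 0;
}
```

The property the sketch relies on is that `warp_id` is identical for all 32 threads of a warp, so `warp_id >= kNumComputeWarps` and `warp_id == 0` are warp-uniform branches, unlike thread-index comparisons against a padded block dimension.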