diff --git a/csrc/base.h b/csrc/base.h index fd8bbe8a678..9255dc42c0d 100644 --- a/csrc/base.h +++ b/csrc/base.h @@ -281,24 +281,26 @@ SPECIALIZE_PRINTER(VoidStar); SPECIALIZE_PRINTER(uint32_t); SPECIALIZE_PRINTER(int64_t); SPECIALIZE_PRINTER(uint64_t); -SPECIALIZE_PRINTER(DataType); -SPECIALIZE_PRINTER(MemoryType); -SPECIALIZE_PRINTER(UnaryOpType); + SPECIALIZE_PRINTER(BinaryOpType); -SPECIALIZE_PRINTER(TernaryOpType); -SPECIALIZE_PRINTER(LoadStoreOpType); SPECIALIZE_PRINTER(CircularBufferLoopStage); -SPECIALIZE_PRINTER(tma::TensorMapInterleave); -SPECIALIZE_PRINTER(tma::TensorMapL2Promotion); -SPECIALIZE_PRINTER(tma::TensorMapFloatOOBFill); +SPECIALIZE_PRINTER(DataType); +SPECIALIZE_PRINTER(LoadStoreOpType); +SPECIALIZE_PRINTER(MemoryType); SPECIALIZE_PRINTER(MmaInputSmemSwizzle); -SPECIALIZE_PRINTER(SwizzleType); +SPECIALIZE_PRINTER(ParallelType); SPECIALIZE_PRINTER(SwizzleMode); +SPECIALIZE_PRINTER(SwizzleType); +SPECIALIZE_PRINTER(TernaryOpType); +SPECIALIZE_PRINTER(UnaryOpType); +SPECIALIZE_PRINTER(std::optional); +SPECIALIZE_PRINTER(std::vector); SPECIALIZE_PRINTER(std::vector); SPECIALIZE_PRINTER(std::vector); -SPECIALIZE_PRINTER(std::vector); SPECIALIZE_PRINTER(std::vector); -SPECIALIZE_PRINTER(std::optional); +SPECIALIZE_PRINTER(tma::TensorMapFloatOOBFill); +SPECIALIZE_PRINTER(tma::TensorMapInterleave); +SPECIALIZE_PRINTER(tma::TensorMapL2Promotion); #undef SPECIALIZE_PRINTER diff --git a/csrc/multidevice/propagation.cpp b/csrc/multidevice/propagation.cpp index dcec94afcf5..2e372b63f34 100644 --- a/csrc/multidevice/propagation.cpp +++ b/csrc/multidevice/propagation.cpp @@ -12,6 +12,7 @@ #include #include +#include "base.h" #include "ir/interface_nodes.h" #include "ir/internal_base_nodes.h" #include "ir/internal_nodes.h" @@ -255,6 +256,12 @@ void shardLoopLike( TensorView* target, const std::unordered_set& selected_parallel_types, PropagateDirection direction) { + if (isDebugDumpEnabled(DebugDumpOption::TransformPropagator)) { + debug() << "Propagating shardings from " << ref->toString() << " to " + << target->toString() << " in " << direction << " for " + << toDelimitedString(selected_parallel_types) << std::endl; + } + std::unordered_set device_or_stream_ids; const std::unordered_map ref2target = getRef2TargetMap(ref, target, direction); diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index d0f326f992d..b6867d238c9 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -13,7 +13,6 @@ #include "ir/iostream.h" #include "ir/utils.h" #include "multidevice/propagation.h" -#include "multidevice/utils.h" #include "scheduler/utils.h" namespace nvfuser::preseg_passes { diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index e0bc874dcbb..9da73acb46c 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -43,6 +43,22 @@ #include #include +namespace nvfuser { + +std::ostream& operator<<(std::ostream& os, PropagateDirection direction) { + switch (direction) { + case PropagateDirection::kForward: + os << "Forward"; + break; + case PropagateDirection::kBackward: + os << "Backward"; + break; + } + return os; +} + +} // namespace nvfuser + namespace nvfuser::scheduler_utils { // Minimal PTX code for a no-op kernel, used for occupancy queries diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index 305cbfcdf67..b1eb8d50c33 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -29,6 +29,8 @@ class HeuristicDataCache; //! BoundedDirectionalTransformPropagator. enum class PropagateDirection { kBackward = 0, kForward }; +std::ostream& operator<<(std::ostream& os, PropagateDirection direction); + namespace scheduler_utils { // Assume any only half of the register file is available to spend on buffers,