From d87e6d7b36784ff4a5135ea082c8105ff8eb2b8e Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 10:30:14 -0800 Subject: [PATCH 01/47] Initial introduction of RaggedIterDomain --- CMakeLists.txt | 1 + csrc/dispatch.h | 1 + csrc/ir/graphviz.cpp | 16 +++ csrc/ir/graphviz.h | 1 + csrc/ir/internal_base_nodes.cpp | 78 ++++++++++++ csrc/ir/internal_base_nodes.h | 47 ++++++++ csrc/mutator.cpp | 25 ++++ csrc/type.cpp | 2 + csrc/type.h | 1 + tests/cpp/test_ragged_iter_domain.cpp | 166 ++++++++++++++++++++++++++ 10 files changed, 338 insertions(+) create mode 100644 tests/cpp/test_ragged_iter_domain.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 5c40ab5a615..8a5170b2e8c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -991,6 +991,7 @@ list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_polymorphic_value.cpp ${NVFUSER_ROOT}/tests/cpp/test_predicate_elimination.cpp ${NVFUSER_ROOT}/tests/cpp/test_preseg_passes.cpp + ${NVFUSER_ROOT}/tests/cpp/test_ragged_iter_domain.cpp ${NVFUSER_ROOT}/tests/cpp/test_reduction.cpp ${NVFUSER_ROOT}/tests/cpp/test_reduction_pointwise.cpp ${NVFUSER_ROOT}/tests/cpp/test_remove_bcast_squeeze.cpp diff --git a/csrc/dispatch.h b/csrc/dispatch.h index c2f235f8aab..f5d0c6a10f9 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -62,6 +62,7 @@ class Val; #define DISPATCH_FOR_ALL_VALS(f) \ f(IterDomain); \ + f(RaggedIterDomain); \ f(TensorDomain); \ f(TensorView); \ f(NamedScalar); diff --git a/csrc/ir/graphviz.cpp b/csrc/ir/graphviz.cpp index 7cbd23f7dd3..306abe32784 100644 --- a/csrc/ir/graphviz.cpp +++ b/csrc/ir/graphviz.cpp @@ -68,6 +68,14 @@ class IrNodeLabel final : private OptInConstDispatch { label_ << ")"; } + void handle(const RaggedIterDomain* id) override { + label_ << "Ragged" << id->getIterType(); + label_ << id->getParallelType(); + label_ << "(extents="; + label_ << IrNodeLabel::gen(id->extents()); + label_ << ")"; + } + void handle(const Split* split) override { label_ << "Split(inner=" << 
(split->innerSplit() ? "true" : "false") << ", factor=" << IrNodeLabel::gen(split->factor()) << ")"; @@ -356,6 +364,14 @@ void IrGraphGenerator::handle(const IterDomain* id) { addArc(id->extent(), id, "[color=gray]"); } +void IrGraphGenerator::handle(const RaggedIterDomain* id) { + graph_def_ << " " << getid(id) << " [label=\"" << IrNodeLabel::gen(id) + << "\", shape=cds, color=orange, fontsize=10];\n"; + + // Add arc from extents tensor to the ragged dimension + addArc(id->extents(), id, "[color=orange]"); +} + void IrGraphGenerator::handle(const Val* s) { printValue(s, IrNodeLabel::gen(s, detail_level_)); } diff --git a/csrc/ir/graphviz.h b/csrc/ir/graphviz.h index 49c0991044d..788f533b608 100644 --- a/csrc/ir/graphviz.h +++ b/csrc/ir/graphviz.h @@ -80,6 +80,7 @@ class IrGraphGenerator : private OptInConstDispatch { void handle(const TensorDomain*) override; void handle(const TensorView*) override; void handle(const IterDomain*) override; + void handle(const RaggedIterDomain*) override; void handle(const Val*) override; void handle(const NamedScalar*) override; diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index bb7cdba891c..e88a6d24ded 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -787,6 +787,84 @@ void validateLoopDomain( } // namespace +RaggedIterDomain::RaggedIterDomain( + IrBuilderPasskey passkey, + TensorView* extents, + IterType iter_type, + ParallelType parallel_type) + : IterDomain( + passkey, + /*start=*/passkey.ir_container_->zeroVal(), + /*extent=*/passkey.ir_container_->oneVal(), // Placeholder + /*expanded_extent=*/nullptr, + /*stop_offset=*/nullptr, + parallel_type, + iter_type, + /*is_rfactor_domain=*/false, + /*is_padded_dimension=*/false, + /*is_clustered_blocks=*/false, + /*padded_to_size=*/std::nullopt), + extents_(extents) { + // Extents must be non-null + NVF_ERROR( + extents_ != nullptr, "RaggedIterDomain requires non-null extents tensor"); + + // Extents must have 
integer dtype + NVF_ERROR_EQ( + extents_->dtype(), + DataType::Index, + "RaggedIterDomain extents must have index type, got ", + extents_->dtype()); + + // Only IterType::Iteration is supported at this moment + NVF_ERROR_EQ( + iter_type, + IterType::Iteration, + "Only IterType::Iteration is supported: ", + iter_type); +} + +RaggedIterDomain::RaggedIterDomain( + const RaggedIterDomain* src, + IrCloner* ir_cloner) + : IterDomain(src, ir_cloner), extents_(ir_cloner->clone(src->extents_)) {} + +NVFUSER_DEFINE_CLONE(RaggedIterDomain) + +bool RaggedIterDomain::sameAs(const Statement* other) const { + if (this == other) { + return true; + } + + if (!other->isA()) { + return false; + } + + auto other_ragged = other->as(); + + // Compare parent IterDomain properties + if (!IterDomain::sameAs(other)) { + return false; + } + + // Compare extents tensor + return extents_->sameAs(other_ragged->extents_); +} + +std::string RaggedIterDomain::toString(int indent_size) const { + std::stringstream ss; + ss << "iRagged{"; + ss << "extents=" << extents_->toString(); + ss << ", iter_type=" << getIterType(); + ss << ", parallel_type=" << getParallelType(); + ss << "}"; + return ss.str(); +} + +std::string RaggedIterDomain::toInlineString(int indent_size) const { + return toString(indent_size); +} + TensorDomain::TensorDomain( IrBuilderPasskey passkey, std::vector logical_domain, diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index a41466a34a7..a39bac1d00c 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -25,6 +25,8 @@ namespace nvfuser { // Friends for direct access to split class TensorDomain; class IterDomain; +class RaggedIterDomain; +class TensorView; class ReplayTransformations; class IndexReferenceReplay; class ViewTransform; @@ -418,6 +420,51 @@ class NVF_API IterDomain : public Val { std::optional padded_to_size_ = std::nullopt; }; +//! RaggedIterDomain represents a dimension with variable extents +//! 
(ragged/jagged dimension). Used for PyTorch nested tensors. +//! Unlike IterDomain, the extent varies per component +//! and is stored as a TensorView rather than a single Val. +//! +//! Key properties: +//! - extents_: TensorView containing extent for each component (1D, 2D, or N-D) +//! - Uniform execution properties: ParallelType, IterType apply to all +//! components +class NVF_API RaggedIterDomain : public IterDomain { + public: + //! \param extents TensorView containing component extents (must be integer + //! type) + //! \param iter_type Iteration type (Iteration, Reduction, etc.) + //! Only Iteration is allowed ATM. + //! \param parallel_type Parallelization strategy (applies + //! uniformly) + RaggedIterDomain( + IrBuilderPasskey passkey, + TensorView* extents, + IterType iter_type = IterType::Iteration, + ParallelType parallel_type = ParallelType::Serial); + + //! Cloning constructor for IR cloning + RaggedIterDomain(const RaggedIterDomain* src, IrCloner* ir_cloner); + + NVFUSER_DECLARE_CLONE + + bool sameAs(const Statement* other) const override; + + std::string toString(int indent_size = 0) const override; + + std::string toInlineString(int indent_size = 0) const override; + + //! Accessor for the extents tensor + TensorView* extents() const { + return extents_; + } + + private: + //! Extent tensor containing all component extents + //! Can be 1D, 2D, or N-D depending on nesting structure + TensorView* extents_ = nullptr; +}; + //! TensorDomain holds a vector of IterDomains. It holds an IterDomain for every //! logical axis in its associated tensor. TensorDomain does not directly hold //! 
the Tensor it is associated with, and in theory could be associated with diff --git a/csrc/mutator.cpp b/csrc/mutator.cpp index 5d586e303ac..6106a853ba2 100644 --- a/csrc/mutator.cpp +++ b/csrc/mutator.cpp @@ -134,6 +134,31 @@ void OptOutMutator::mutate(IterDomain* id) { } } +void OptOutMutator::mutate(RaggedIterDomain* id) { + // Mutate the extents TensorView + auto mutated_extents = maybeMutated(id->extents()); + + // Check if anything changed + if (mutated_extents->sameAs(id->extents())) { + return; + } + + // Create a new RaggedIterDomain with mutated extents + auto new_id = IrBuilder::createInContainer( + id->container(), + mutated_extents->as(), + id->getIterType(), + id->getParallelType()); + + // Register the mutation + registerMutation(id, new_id); + + // Preserve definition if it exists + if (Expr* def = id->definition()) { + mutateExprOutputsOnly(def); + } +} + void OptOutMutator::mutate(TensorDomain* td) { bool mutated = false; diff --git a/csrc/type.cpp b/csrc/type.cpp index 02ea6a9cd5a..8d8eea0f62b 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -322,6 +322,8 @@ static const char* val_type2string(ValType t) { return "TensorDomain"; case ValType::IterDomain: return "IterDomain"; + case ValType::RaggedIterDomain: + return "RaggedIterDomain"; case ValType::Others: return "Scalar"; case ValType::NamedScalar: diff --git a/csrc/type.h b/csrc/type.h index b011976fe83..9e91909c09d 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -45,6 +45,7 @@ namespace nvfuser { enum class ValType { TensorDomain, IterDomain, + RaggedIterDomain, TensorView, NamedScalar, Predicate, diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp new file mode 100644 index 00000000000..7c3596152c4 --- /dev/null +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -0,0 +1,166 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include +#include +#include +#include +#include + +namespace nvfuser { + +using RaggedIterDomainTest = NVFuserTest; + +// Basic construction of RaggedIterDomain +TEST_F(RaggedIterDomainTest, Construction) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a TensorView to use as extents + // This represents component sizes [3, 5, 2] + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + // Create RaggedIterDomain + auto ragged_id = IrBuilder::create( + extents, IterType::Iteration, ParallelType::Serial); + + // Verify properties + EXPECT_NE(ragged_id, nullptr); + EXPECT_TRUE(ragged_id->isA()); + EXPECT_EQ(ragged_id->getIterType(), IterType::Iteration); + EXPECT_EQ(ragged_id->getParallelType(), ParallelType::Serial); + EXPECT_EQ(ragged_id->extents(), extents); + EXPECT_FALSE(ragged_id->isRFactorProduct()); + + // Verify extent is not null (it's the sum of extents) + EXPECT_NE(ragged_id->extent(), nullptr); +} + +// RaggedIterDomain with parallelization +TEST_F(RaggedIterDomainTest, Parallelization) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + // Create with TIDx parallelization + auto ragged_parallel = IrBuilder::create( + extents, IterType::Iteration, ParallelType::TIDx); + + EXPECT_EQ(ragged_parallel->getParallelType(), ParallelType::TIDx); + EXPECT_TRUE(ragged_parallel->isThreadDim()); + + // Test that parallelize method works (inherited from IterDomain) + ragged_parallel->parallelize(ParallelType::TIDy); + EXPECT_EQ(ragged_parallel->getParallelType(), ParallelType::TIDy); +} + +// sameAs comparison +TEST_F(RaggedIterDomainTest, SameAsComparison) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto extents1 = makeSymbolicTensor(1, DataType::Index); + auto extents2 = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents1); + 
fusion.addInput(extents2); + + auto ragged1 = IrBuilder::create( + extents1, IterType::Iteration, ParallelType::Serial); + + auto ragged3 = IrBuilder::create( + extents2, // Different extents + IterType::Iteration, + ParallelType::Serial); + + // Same object + EXPECT_TRUE(ragged1->sameAs(ragged1)); + + // Different extents + EXPECT_FALSE(ragged1->sameAs(ragged3)); + + // RaggedIterDomain vs regular IterDomain + auto regular_id = + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) + .build(); + EXPECT_FALSE(ragged1->sameAs(regular_id)); +} + +// Printing/toString +TEST_F(RaggedIterDomainTest, Printing) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + auto ragged_id = IrBuilder::create( + extents, IterType::Iteration, ParallelType::TIDx); + + // Print it + std::string str = ragged_id->toString(); + + // Verify output contains expected elements + EXPECT_NE(str.find("iRagged"), std::string::npos); + EXPECT_NE(str.find("extents"), std::string::npos); + + // Also test toInlineString + std::string inline_str = ragged_id->toInlineString(); + EXPECT_FALSE(inline_str.empty()); +} + +// Multi-dimensional extents tensor +TEST_F(RaggedIterDomainTest, MultiDimensionalExtents) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create 2D extents tensor for nested ragged structure + auto extents_2d = makeSymbolicTensor(2, DataType::Index); + fusion.addInput(extents_2d); + + auto ragged_nested = IrBuilder::create( + extents_2d, IterType::Iteration, ParallelType::Serial); + + EXPECT_NE(ragged_nested, nullptr); + EXPECT_EQ(ragged_nested->extents(), extents_2d); +} + +// Validation - null extents should fail +TEST_F(RaggedIterDomainTest, ValidationNullExtents) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Attempt to create with null extents should throw + EXPECT_THROW( + IrBuilder::create( + nullptr, // null extents + IterType::Iteration, + 
ParallelType::Serial), + nvfuser::nvfError); +} + +// Validation - non-integer extents should fail +TEST_F(RaggedIterDomainTest, ValidationNonIntegerExtents) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create float extents (should fail) + auto float_extents = makeSymbolicTensor(1, DataType::Float); + fusion.addInput(float_extents); + + EXPECT_THROW( + IrBuilder::create( + float_extents, IterType::Iteration, ParallelType::Serial), + nvfuser::nvfError); +} + +} // namespace nvfuser From f16fc4d1f92bdcdf12dbb2fe723af12f921f29a5 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 10:42:54 -0800 Subject: [PATCH 02/47] cleanup --- csrc/ir/internal_base_nodes.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index e88a6d24ded..b9ff9b02681 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -853,10 +853,11 @@ bool RaggedIterDomain::sameAs(const Statement* other) const { std::string RaggedIterDomain::toString(int indent_size) const { std::stringstream ss; - ss << "iRagged{"; + ss << getIterType(); + ss << getParallelType(); + ss << name(); + ss << "Ragged{"; ss << "extents=" << extents_->toString(); - ss << ", iter_type=" << getIterType(); - ss << ", parallel_type=" << getParallelType(); ss << "}"; return ss.str(); } From 23d55f15df8b041271ab202929e213b950d2e0a3 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 11:40:11 -0800 Subject: [PATCH 03/47] fix --- tests/cpp/test_ragged_iter_domain.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 7c3596152c4..249d31afed3 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -110,7 +110,7 @@ TEST_F(RaggedIterDomainTest, Printing) { std::string str = ragged_id->toString(); // Verify output contains expected elements - 
EXPECT_NE(str.find("iRagged"), std::string::npos); + EXPECT_NE(str.find("Ragged"), std::string::npos); EXPECT_NE(str.find("extents"), std::string::npos); // Also test toInlineString From 8392332ab5316fb58add8eae53e7700460e5a7dc Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 11:48:56 -0800 Subject: [PATCH 04/47] fix --- csrc/ir/internal_base_nodes.cpp | 30 +++++++++++++++++++++++++++++- csrc/ir/internal_base_nodes.h | 17 +++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index b9ff9b02681..abb5db26d2c 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -130,7 +130,34 @@ IterDomain::IterDomain( bool is_padded_dimension, bool is_clustered_blocks, std::optional padded_to_size) - : Val(passkey, ValType::IterDomain), + : IterDomain( + passkey, + ValType::IterDomain, + start, + extent, + expanded_extent, + stop_offset, + parallel_type, + iter_type, + is_rfactor_domain, + is_padded_dimension, + is_clustered_blocks, + padded_to_size) {} + +IterDomain::IterDomain( + IrBuilderPasskey passkey, + ValType vtype, + Val* start, + Val* extent, + Val* expanded_extent, + Val* stop_offset, + ParallelType parallel_type, + IterType iter_type, + bool is_rfactor_domain, + bool is_padded_dimension, + bool is_clustered_blocks, + std::optional padded_to_size) + : Val(passkey, vtype), start_(start), extent_(extent), expanded_extent_(expanded_extent), @@ -794,6 +821,7 @@ RaggedIterDomain::RaggedIterDomain( ParallelType parallel_type) : IterDomain( passkey, + ValType::RaggedIterDomain, /*start=*/passkey.ir_container_->zeroVal(), /*extent=*/passkey.ir_container_->oneVal(), // Placeholder /*expanded_extent=*/nullptr, diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index a39bac1d00c..fb4fd8651ce 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -391,6 +391,23 @@ class NVF_API IterDomain : 
public Val { friend TensorDomain; friend ReplayTransformations; friend IndexReferenceReplay; + friend RaggedIterDomain; + + //! Protected constructor for derived classes (e.g., RaggedIterDomain) + //! that need to override the ValType + IterDomain( + IrBuilderPasskey passkey, + ValType vtype, + Val* start, + Val* extent, + Val* expanded_extent, + Val* stop_offset, + ParallelType parallel_type, + IterType iter_type, + bool is_rfactor_domain, + bool is_padded_dimension, + bool is_clustered_blocks, + std::optional padded_to_size); private: //! Valid range is defined as [start:-stop_offset] From 787dfecff93345a0d1959bd83aa764d4e2514f2a Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 11:54:06 -0800 Subject: [PATCH 05/47] unit test --- tests/cpp/test_ragged_iter_domain.cpp | 31 +++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 249d31afed3..856bbfea2aa 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -163,4 +163,35 @@ TEST_F(RaggedIterDomainTest, ValidationNonIntegerExtents) { nvfuser::nvfError); } +// ValType test - ensure RaggedIterDomain has correct ValType +TEST_F(RaggedIterDomainTest, ValType) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + auto ragged_id = IrBuilder::create( + extents, IterType::Iteration, ParallelType::Serial); + + // Verify ValType is RaggedIterDomain, not IterDomain + EXPECT_EQ(ragged_id->vtype(), ValType::RaggedIterDomain); + EXPECT_NE(ragged_id->vtype(), ValType::IterDomain); + + // Verify getValType also returns the correct type + EXPECT_TRUE(ragged_id->getValType().has_value()); + EXPECT_EQ(ragged_id->getValType().value(), ValType::RaggedIterDomain); + + // Compare with a regular IterDomain + auto regular_id = + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(10L, 
DataType::Index)) + .build(); + EXPECT_EQ(regular_id->vtype(), ValType::IterDomain); + EXPECT_NE(regular_id->vtype(), ValType::RaggedIterDomain); + + // Verify they have different types + EXPECT_NE(ragged_id->vtype(), regular_id->vtype()); +} + } // namespace nvfuser From a0b40a39559affa87e93f461922e5a2aaecd8974 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 11:57:44 -0800 Subject: [PATCH 06/47] cleanup --- tests/cpp/test_ragged_iter_domain.cpp | 47 +++++++++------------------ 1 file changed, 15 insertions(+), 32 deletions(-) diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 856bbfea2aa..7274fcbb36b 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -23,7 +23,7 @@ TEST_F(RaggedIterDomainTest, Construction) { FusionGuard fg(&fusion); // Create a TensorView to use as extents - // This represents component sizes [3, 5, 2] + // This represents component sizes such as [3, 5, 2] auto extents = makeSymbolicTensor(1, DataType::Index); fusion.addInput(extents); @@ -41,6 +41,20 @@ TEST_F(RaggedIterDomainTest, Construction) { // Verify extent is not null (it's the sum of extents) EXPECT_NE(ragged_id->extent(), nullptr); + + // Verify ValType is RaggedIterDomain, not IterDomain + EXPECT_EQ(ragged_id->vtype(), ValType::RaggedIterDomain); + EXPECT_NE(ragged_id->vtype(), ValType::IterDomain); + EXPECT_TRUE(ragged_id->getValType().has_value()); + EXPECT_EQ(ragged_id->getValType().value(), ValType::RaggedIterDomain); + + // Compare with a regular IterDomain to ensure different types + auto regular_id = + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) + .build(); + EXPECT_EQ(regular_id->vtype(), ValType::IterDomain); + EXPECT_NE(ragged_id->vtype(), regular_id->vtype()); } // RaggedIterDomain with parallelization @@ -163,35 +177,4 @@ TEST_F(RaggedIterDomainTest, ValidationNonIntegerExtents) { nvfuser::nvfError); } -// ValType test - 
ensure RaggedIterDomain has correct ValType -TEST_F(RaggedIterDomainTest, ValType) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto extents = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(extents); - - auto ragged_id = IrBuilder::create( - extents, IterType::Iteration, ParallelType::Serial); - - // Verify ValType is RaggedIterDomain, not IterDomain - EXPECT_EQ(ragged_id->vtype(), ValType::RaggedIterDomain); - EXPECT_NE(ragged_id->vtype(), ValType::IterDomain); - - // Verify getValType also returns the correct type - EXPECT_TRUE(ragged_id->getValType().has_value()); - EXPECT_EQ(ragged_id->getValType().value(), ValType::RaggedIterDomain); - - // Compare with a regular IterDomain - auto regular_id = - IterDomainBuilder( - fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) - .build(); - EXPECT_EQ(regular_id->vtype(), ValType::IterDomain); - EXPECT_NE(regular_id->vtype(), ValType::RaggedIterDomain); - - // Verify they have different types - EXPECT_NE(ragged_id->vtype(), regular_id->vtype()); -} - } // namespace nvfuser From dbdd917ee07ec787949399f694cceb8742717511 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 13:05:35 -0800 Subject: [PATCH 07/47] Fix IterVisitor --- csrc/iter_visitor.cpp | 9 +++++++++ tests/cpp/test_ragged_iter_domain.cpp | 24 ++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 2386a4729b2..0c8c208417e 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -69,6 +69,15 @@ class MemberStatements : public OptOutDispatch { next_stmts_.push_back(stmt->stopOffset()); } + void handle(RaggedIterDomain* stmt) final { + // Visit the standard IterDomain fields + next_stmts_.push_back(stmt->start()); + next_stmts_.push_back(stmt->extent()); + next_stmts_.push_back(stmt->stopOffset()); + // Visit the extents TensorView (ragged-specific field) + next_stmts_.push_back(stmt->extents()); + } + void handle(TensorDomain* stmt) final { for 
(const std::vector* dom : stmt->allDomains()) { next_stmts_.insert(next_stmts_.end(), dom->begin(), dom->end()); diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 7274fcbb36b..032a1d154c0 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -177,4 +177,28 @@ TEST_F(RaggedIterDomainTest, ValidationNonIntegerExtents) { nvfuser::nvfError); } +// IterVisitor test - ensure graph traversal visits extents field +TEST_F(RaggedIterDomainTest, IterVisitor) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create extents TensorView + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + // Create RaggedIterDomain + auto ragged_id = IrBuilder::create( + extents, IterType::Iteration, ParallelType::Serial); + + // Collect all statements reachable from the RaggedIterDomain + // Use traverse_members=true to visit member fields + std::vector from_vals = {ragged_id}; + auto all_stmts = StmtSort::getStmtsTo(from_vals, /*traverse_members=*/true); + + // Verify the extents TensorView is visited (this is the critical check) + EXPECT_TRUE( + std::find(all_stmts.begin(), all_stmts.end(), extents) != all_stmts.end()) + << "IterVisitor should traverse the extents_ field of RaggedIterDomain"; +} + } // namespace nvfuser From cdbd81e46bb57610da0326345cf6ae09f68e90ac Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 13:43:28 -0800 Subject: [PATCH 08/47] cleanup --- tests/cpp/test_ragged_iter_domain.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 032a1d154c0..4002f854947 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -191,11 +191,10 @@ TEST_F(RaggedIterDomainTest, IterVisitor) { extents, IterType::Iteration, ParallelType::Serial); // Collect all statements reachable from the RaggedIterDomain - 
// Use traverse_members=true to visit member fields std::vector from_vals = {ragged_id}; auto all_stmts = StmtSort::getStmtsTo(from_vals, /*traverse_members=*/true); - // Verify the extents TensorView is visited (this is the critical check) + // Verify the extents TensorView is visited EXPECT_TRUE( std::find(all_stmts.begin(), all_stmts.end(), extents) != all_stmts.end()) << "IterVisitor should traverse the extents_ field of RaggedIterDomain"; From d4c8d7f72bb07ce9815d9c43aabbea6d8add1e92 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 11:15:03 -0800 Subject: [PATCH 09/47] WIP: partition --- csrc/ir/internal_base_nodes.cpp | 90 +++++++++++++++++++++++++++ csrc/ir/internal_base_nodes.h | 14 +++++ tests/cpp/test_ragged_iter_domain.cpp | 88 ++++++++++++++++++++++++-- 3 files changed, 188 insertions(+), 4 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index abb5db26d2c..006adad536f 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -894,6 +895,95 @@ std::string RaggedIterDomain::toInlineString(int indent_size) const { return toString(indent_size); } +std::pair RaggedIterDomain::partition( + IterDomain* in, + TensorView* offsets) { + NVF_ERROR(in != nullptr, "partition: input IterDomain is null"); + + NVF_ERROR( + !in->isA(), + "partition: input is already RaggedIterDomain, cannot partition again"); + + NVF_ERROR_EQ(in->getParallelType(), ParallelType::Serial, + "Partitioning of parallelized IterDomain not supported: ", + in->toString()); + + NVF_ERROR(offsets != nullptr, "partition: offsets tensor is null"); + + NVF_ERROR( + offsets->dtype() == DataType::Index, + "partition: offsets must have Index type, got ", + offsets->dtype()); + + const auto& offsets_domain = offsets->getLogicalDomain(); + NVF_ERROR( + !offsets_domain.empty(), + "partition: offsets tensor must have at least one 
dimension"); + + auto container = in->container(); + + // Compute extents from offsets: extents[i] = offsets[i+1] - offsets[i] + // Slice along the last dimension of the offsets tensor + // offsets_left = offsets[..., :-1] (all but last element in last dim) + // offsets_right = offsets[..., 1:] (all but first element in last dim) + + const auto last_dim = offsets_domain.size() - 1; + auto offsets_len = offsets_domain[last_dim]->extent(); + + auto zero = container->zeroVal(DataType::Index); + auto one = container->oneVal(DataType::Index); + auto len_minus_one = sub(offsets_len, one); + + // Build slice ranges for all dimensions + // For all dimensions except the last, use full range (:) + // For the last dimension, use [:-1] for left and [1:] for right + std::vector left_ranges; + std::vector right_ranges; + + for (const auto i : arange(offsets_domain.size())) { + if (i < last_dim) { + // Full range for non-last dimensions + Slice s; + s.start = zero; + s.stop = offsets_domain[i]->extent(); + left_ranges.push_back(s); + right_ranges.push_back(s); + } else { + // Last dimension: left uses [:-1], right uses [1:] + Slice left_s; + left_s.start = zero; + left_s.stop = len_minus_one; + left_ranges.push_back(left_s); + + Slice right_s; + right_s.start = one; + right_s.stop = offsets_len; + right_ranges.push_back(right_s); + } + } + + auto offsets_left = slice(offsets, left_ranges); + auto offsets_right = slice(offsets, right_ranges); + + // Compute extents: extents = offsets_right - offsets_left + auto extents = sub(offsets_right, offsets_left); + + // Create batch IterDomain + // Batch extent = number of components = len(offsets) - 1 + auto batch_extent = len_minus_one; + auto batch_id = IterDomainBuilder(zero, batch_extent) + .parallel_type(ParallelType::Serial) + .iter_type(IterType::Iteration) + .build(); + + // Create RaggedIterDomain with computed extents + auto ragged_id = IrBuilder::create( + extents, in->getIterType()); + + // Return pair + return {batch_id, 
ragged_id}; +} + TensorDomain::TensorDomain( IrBuilderPasskey passkey, std::vector logical_domain, diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index fb4fd8651ce..59f099ec517 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -471,6 +471,20 @@ class NVF_API RaggedIterDomain : public IterDomain { std::string toInlineString(int indent_size = 0) const override; + //! Partition an IterDomain into batch and ragged dimensions + //! Creates a batch IterDomain and a RaggedIterDomain based on offsets + //! + //! \param in Input IterDomain to partition (must be regular IterDomain) + //! \param offsets Offset tensor defining partition boundaries + //! Shape: [num_components + 1], values: [0, off1, off2, ..., total] + //! Extents are computed as: extents[i] = offsets[i+1] - offsets[i] + //! \return Pair of (batch_id, ragged_id) + //! batch_id: IterDomain with extent = num_components + //! ragged_id: RaggedIterDomain with extents computed from offsets + static std::pair partition( + IterDomain* in, + TensorView* offsets); + //! 
Accessor for the extents tensor TensorView* extents() const { return extents_; diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 4002f854947..b537f6c95a8 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -126,10 +126,6 @@ TEST_F(RaggedIterDomainTest, Printing) { // Verify output contains expected elements EXPECT_NE(str.find("Ragged"), std::string::npos); EXPECT_NE(str.find("extents"), std::string::npos); - - // Also test toInlineString - std::string inline_str = ragged_id->toInlineString(); - EXPECT_FALSE(inline_str.empty()); } // Multi-dimensional extents tensor @@ -200,4 +196,88 @@ TEST_F(RaggedIterDomainTest, IterVisitor) { << "IterVisitor should traverse the extents_ field of RaggedIterDomain"; } +// Partition operation - basic test +TEST_F(RaggedIterDomainTest, PartitionBasic) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create input IterDomain + auto input_id = IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) + .build(); + + // Create a symbolic offset tensor + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Partition the IterDomain + auto [batch_id, ragged_id] = RaggedIterDomain::partition(input_id, offsets); + + // Verify batch IterDomain + EXPECT_NE(batch_id, nullptr); + EXPECT_TRUE(batch_id->isA()); + EXPECT_FALSE(batch_id->isA()); + EXPECT_EQ(batch_id->getIterType(), IterType::Iteration); + + // Verify RaggedIterDomain + EXPECT_NE(ragged_id, nullptr); + EXPECT_TRUE(ragged_id->isA()); + EXPECT_EQ(ragged_id->getIterType(), IterType::Iteration); + EXPECT_NE(ragged_id->extents(), nullptr); +} + +// Partition operation - multi-dimensional offsets +TEST_F(RaggedIterDomainTest, PartitionMultiDimensional) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto input_id = IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(100L, DataType::Index)) + .build(); + + // Create 2D offsets 
tensor for nested ragged structure + auto offsets_2d = makeSymbolicTensor(2, DataType::Index); + fusion.addInput(offsets_2d); + + // Partition should work with multi-dimensional offsets + auto [batch_id, ragged_id] = RaggedIterDomain::partition(input_id, offsets_2d); + + EXPECT_NE(batch_id, nullptr); + EXPECT_NE(ragged_id, nullptr); + EXPECT_TRUE(ragged_id->isA()); + EXPECT_NE(ragged_id->extents(), nullptr); +} + +// Partition operation - validation tests +TEST_F(RaggedIterDomainTest, PartitionValidation) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto input_id = IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) + .build(); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Test 1: Null input should fail + EXPECT_THROW(RaggedIterDomain::partition(nullptr, offsets), nvfuser::nvfError); + + // Test 2: Null offsets should fail + EXPECT_THROW(RaggedIterDomain::partition(input_id, nullptr), nvfuser::nvfError); + + // Test 3: Non-Index offsets should fail + auto float_offsets = makeSymbolicTensor(1, DataType::Float); + fusion.addInput(float_offsets); + EXPECT_THROW( + RaggedIterDomain::partition(input_id, float_offsets), nvfuser::nvfError); + + // Test 4: Cannot partition RaggedIterDomain + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + auto ragged_id = IrBuilder::create( + extents, IterType::Iteration, ParallelType::Serial); + EXPECT_THROW(RaggedIterDomain::partition(ragged_id, offsets), nvfuser::nvfError); +} + } // namespace nvfuser From 9575a13b09e1b5f275588df5a956c25e987c43ab Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 18:26:22 -0800 Subject: [PATCH 10/47] Partition expr --- csrc/dispatch.h | 1 + csrc/ir/internal_base_nodes.cpp | 31 ++++++---- csrc/ir/internal_base_nodes.h | 8 +-- csrc/ir/internal_nodes.cpp | 33 ++++++++++ csrc/ir/internal_nodes.h | 44 ++++++++++++++ tests/cpp/test_ragged_iter_domain.cpp | 87 
++++++++++++++++++--------- 6 files changed, 159 insertions(+), 45 deletions(-) diff --git a/csrc/dispatch.h b/csrc/dispatch.h index d14d5257fe0..822ababb149 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -115,6 +115,7 @@ class Val; f(TopKOp); \ f(ScanOp); \ f(Merge); \ + f(Partition); \ f(Swizzle); \ f(Swizzle2D); \ f(Resize); \ diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 006adad536f..c95014bdadb 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -904,9 +904,11 @@ std::pair RaggedIterDomain::partition( !in->isA(), "partition: input is already RaggedIterDomain, cannot partition again"); - NVF_ERROR_EQ(in->getParallelType(), ParallelType::Serial, - "Partitioning of parallelized IterDomain not supported: ", - in->toString()); + NVF_ERROR_EQ( + in->getParallelType(), + ParallelType::Serial, + "Partitioning of parallelized IterDomain not supported: ", + in->toString()); NVF_ERROR(offsets != nullptr, "partition: offsets tensor is null"); @@ -968,20 +970,23 @@ std::pair RaggedIterDomain::partition( // Compute extents: extents = offsets_right - offsets_left auto extents = sub(offsets_right, offsets_left); - // Create batch IterDomain - // Batch extent = number of components = len(offsets) - 1 - auto batch_extent = len_minus_one; - auto batch_id = IterDomainBuilder(zero, batch_extent) - .parallel_type(ParallelType::Serial) - .iter_type(IterType::Iteration) - .build(); + // Create component IterDomain + // Component extent = number of components = len(offsets) - 1 + auto component_extent = len_minus_one; + auto component_id = IterDomainBuilder(zero, component_extent) + .parallel_type(ParallelType::Serial) + .iter_type(IterType::Iteration) + .build(); // Create RaggedIterDomain with computed extents - auto ragged_id = IrBuilder::create( - extents, in->getIterType()); + auto ragged_id = + IrBuilder::create(extents, in->getIterType()); + + // Create the Partition expr to represent this 
transformation + IrBuilder::create(component_id, ragged_id, in, extents); // Return pair - return {batch_id, ragged_id}; + return {component_id, ragged_id}; } TensorDomain::TensorDomain( diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 59f099ec517..2a2e85d0458 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -471,15 +471,15 @@ class NVF_API RaggedIterDomain : public IterDomain { std::string toInlineString(int indent_size = 0) const override; - //! Partition an IterDomain into batch and ragged dimensions - //! Creates a batch IterDomain and a RaggedIterDomain based on offsets + //! Partition an IterDomain into component and ragged dimensions + //! Creates a component IterDomain and a RaggedIterDomain based on offsets //! //! \param in Input IterDomain to partition (must be regular IterDomain) //! \param offsets Offset tensor defining partition boundaries //! Shape: [num_components + 1], values: [0, off1, off2, ..., total] //! Extents are computed as: extents[i] = offsets[i+1] - offsets[i] - //! \return Pair of (batch_id, ragged_id) - //! batch_id: IterDomain with extent = num_components + //! \return Pair of (component_id, ragged_id) + //! component_id: IterDomain with extent = num_components //! ragged_id: RaggedIterDomain with extents computed from offsets static std::pair partition( IterDomain* in, diff --git a/csrc/ir/internal_nodes.cpp b/csrc/ir/internal_nodes.cpp index 2f0f39afaa5..bdbe706a3cf 100644 --- a/csrc/ir/internal_nodes.cpp +++ b/csrc/ir/internal_nodes.cpp @@ -2609,6 +2609,39 @@ std::string Merge::toInlineString(int indent_size) const { NVFUSER_DEFINE_CLONE_AND_CREATE(Merge) +Partition::Partition( + IrBuilderPasskey passkey, + IterDomain* component, + RaggedIterDomain* ragged, + IterDomain* in, + TensorView* extents) + : Expr(passkey) { + addOutput(component); + addOutput(ragged); + addInput(in); + // Should the extents tensor be an input rather than an attribute? 
+ addAttribute(extents); +} + +std::string Partition::toString(int indent_size) const { + std::stringstream ss; + ss << "Partition: "; + ss << in()->toString(); + ss << " by extents " << extents()->toString(); + ss << " -> component: "; + ss << component()->toString(); + ss << ", ragged: "; + ss << ragged()->toString(); + ss << "\n"; + return ss.str(); +} + +std::string Partition::toInlineString(int indent_size) const { + NVF_CHECK(false, "Partition can not be printed inline"); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(Partition) + Swizzle::Swizzle( IrBuilderPasskey passkey, IterDomain* out_x, diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index fea0f565082..9393dc3016b 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -1801,6 +1801,50 @@ class NVF_API Merge : public Expr { } }; +//! Partition an IterDomain into component and ragged dimensions +//! Creates a component IterDomain and a RaggedIterDomain based on extents +//! tensor The extents tensor contains the extent for each component +class NVF_API Partition : public Expr { + public: + using Expr::Expr; + + Partition( + IrBuilderPasskey, + IterDomain* component, + RaggedIterDomain* ragged, + IterDomain* in, + TensorView* extents); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() const override { + return "Partition"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + + //! Component dimension output (extent = num_components) + IterDomain* component() const { + return output(0)->as(); + } + + //! Ragged dimension output (variable extents per component) + RaggedIterDomain* ragged() const { + return output(1)->as(); + } + + //! Input IterDomain being partitioned + IterDomain* in() const { + return input(0)->as(); + } + + //! 
Extents tensor containing extent for each component + TensorView* extents() const { + return attributeVal(0)->as(); + } +}; + class Swizzle : public Expr { public: using Expr::Expr; diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index b537f6c95a8..a9d4c79c67e 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -32,7 +32,7 @@ TEST_F(RaggedIterDomainTest, Construction) { extents, IterType::Iteration, ParallelType::Serial); // Verify properties - EXPECT_NE(ragged_id, nullptr); + EXPECT_TRUE(ragged_id != nullptr); EXPECT_TRUE(ragged_id->isA()); EXPECT_EQ(ragged_id->getIterType(), IterType::Iteration); EXPECT_EQ(ragged_id->getParallelType(), ParallelType::Serial); @@ -40,7 +40,7 @@ TEST_F(RaggedIterDomainTest, Construction) { EXPECT_FALSE(ragged_id->isRFactorProduct()); // Verify extent is not null (it's the sum of extents) - EXPECT_NE(ragged_id->extent(), nullptr); + EXPECT_TRUE(ragged_id->extent() != nullptr); // Verify ValType is RaggedIterDomain, not IterDomain EXPECT_EQ(ragged_id->vtype(), ValType::RaggedIterDomain); @@ -140,7 +140,7 @@ TEST_F(RaggedIterDomainTest, MultiDimensionalExtents) { auto ragged_nested = IrBuilder::create( extents_2d, IterType::Iteration, ParallelType::Serial); - EXPECT_NE(ragged_nested, nullptr); + EXPECT_TRUE(ragged_nested != nullptr); EXPECT_EQ(ragged_nested->extents(), extents_2d); } @@ -202,28 +202,53 @@ TEST_F(RaggedIterDomainTest, PartitionBasic) { FusionGuard fg(&fusion); // Create input IterDomain - auto input_id = IterDomainBuilder( - fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) - .build(); + auto input_id = + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(-1, DataType::Index)) + .build(); // Create a symbolic offset tensor auto offsets = makeSymbolicTensor(1, DataType::Index); fusion.addInput(offsets); // Partition the IterDomain - auto [batch_id, ragged_id] = RaggedIterDomain::partition(input_id, offsets); + auto 
[component_id, ragged_id] = + RaggedIterDomain::partition(input_id, offsets); - // Verify batch IterDomain - EXPECT_NE(batch_id, nullptr); - EXPECT_TRUE(batch_id->isA()); - EXPECT_FALSE(batch_id->isA()); - EXPECT_EQ(batch_id->getIterType(), IterType::Iteration); + // Verify component IterDomain + EXPECT_TRUE(component_id != nullptr); + EXPECT_TRUE(component_id->isA()); + EXPECT_FALSE(component_id->isA()); // Verify RaggedIterDomain - EXPECT_NE(ragged_id, nullptr); + EXPECT_TRUE(ragged_id != nullptr); EXPECT_TRUE(ragged_id->isA()); - EXPECT_EQ(ragged_id->getIterType(), IterType::Iteration); - EXPECT_NE(ragged_id->extents(), nullptr); + EXPECT_TRUE(ragged_id->extents() != nullptr); + + // Verify that a Partition expr was created + EXPECT_TRUE(component_id->definition() != nullptr); + EXPECT_TRUE(component_id->definition()->isA()); + + // Both outputs should have the same definition (the Partition expr) + EXPECT_EQ(component_id->definition(), ragged_id->definition()); + + // Verify the Partition expr structure + auto partition_expr = component_id->definition()->as(); + EXPECT_EQ(partition_expr->component(), component_id); + EXPECT_EQ(partition_expr->ragged(), ragged_id); + EXPECT_EQ(partition_expr->in(), input_id); + EXPECT_EQ(partition_expr->extents(), ragged_id->extents()); + + // Verify the expr has correct inputs and outputs + EXPECT_EQ(partition_expr->inputs().size(), 1); + EXPECT_EQ(partition_expr->outputs().size(), 2); + EXPECT_EQ(partition_expr->input(0), input_id); + EXPECT_EQ(partition_expr->output(0), component_id); + EXPECT_EQ(partition_expr->output(1), ragged_id); + + // Test toString + std::string str = partition_expr->toString(); + EXPECT_TRUE(str.find("Partition") != std::string::npos); } // Partition operation - multi-dimensional offsets @@ -231,21 +256,23 @@ TEST_F(RaggedIterDomainTest, PartitionMultiDimensional) { Fusion fusion; FusionGuard fg(&fusion); - auto input_id = IterDomainBuilder( - fusion.zeroVal(), IrBuilder::create(100L, 
DataType::Index)) - .build(); + auto input_id = + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(100L, DataType::Index)) + .build(); // Create 2D offsets tensor for nested ragged structure auto offsets_2d = makeSymbolicTensor(2, DataType::Index); fusion.addInput(offsets_2d); // Partition should work with multi-dimensional offsets - auto [batch_id, ragged_id] = RaggedIterDomain::partition(input_id, offsets_2d); + auto [component_id, ragged_id] = + RaggedIterDomain::partition(input_id, offsets_2d); - EXPECT_NE(batch_id, nullptr); - EXPECT_NE(ragged_id, nullptr); + EXPECT_TRUE(component_id != nullptr); + EXPECT_TRUE(ragged_id != nullptr); EXPECT_TRUE(ragged_id->isA()); - EXPECT_NE(ragged_id->extents(), nullptr); + EXPECT_TRUE(ragged_id->extents() != nullptr); } // Partition operation - validation tests @@ -253,18 +280,21 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) { Fusion fusion; FusionGuard fg(&fusion); - auto input_id = IterDomainBuilder( - fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) - .build(); + auto input_id = + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) + .build(); auto offsets = makeSymbolicTensor(1, DataType::Index); fusion.addInput(offsets); // Test 1: Null input should fail - EXPECT_THROW(RaggedIterDomain::partition(nullptr, offsets), nvfuser::nvfError); + EXPECT_THROW( + RaggedIterDomain::partition(nullptr, offsets), nvfuser::nvfError); // Test 2: Null offsets should fail - EXPECT_THROW(RaggedIterDomain::partition(input_id, nullptr), nvfuser::nvfError); + EXPECT_THROW( + RaggedIterDomain::partition(input_id, nullptr), nvfuser::nvfError); // Test 3: Non-Index offsets should fail auto float_offsets = makeSymbolicTensor(1, DataType::Float); @@ -277,7 +307,8 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) { fusion.addInput(extents); auto ragged_id = IrBuilder::create( extents, IterType::Iteration, ParallelType::Serial); - EXPECT_THROW(RaggedIterDomain::partition(ragged_id, offsets), 
nvfuser::nvfError); + EXPECT_THROW( + RaggedIterDomain::partition(ragged_id, offsets), nvfuser::nvfError); } } // namespace nvfuser From a054ae0c89000f4aa9f7dc2cc0497ee3b12e71dd Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 20:37:57 -0800 Subject: [PATCH 11/47] TensorView::partition --- csrc/ir/interface_nodes.h | 9 ++++++ csrc/ir/internal_base_nodes.cpp | 16 ++++++++++ csrc/ir/internal_base_nodes.h | 3 ++ csrc/tensor_view.cpp | 45 +++++++++++++++++++++++++++ tests/cpp/test_ragged_iter_domain.cpp | 34 ++++++++++++++++++++ 5 files changed, 107 insertions(+) diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index e9172080640..ea0236a1a7b 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -619,6 +619,15 @@ class NVF_API TensorView : public Val { return merge(axis, axis + 1); } + // Partition "axis" into component and ragged dimensions based on offsets + // The offsets tensor defines partition boundaries where: + // Shape: [num_components + 1], values: [0, off1, off2, ..., total] + // Extents are computed as: extents[i] = offsets[i+1] - offsets[i] + // Returns this TensorView with the axis replaced by component and ragged dims + // e.g. partition(0, offsets) on tv[id{N}] results in: + // tv[id{num_components}, ragged_id{extents}] + TensorView* partition(int64_t axis, TensorView* offsets); + // Flatten the axis from `from` to `to` into a single axis. // Both `from` and `to` are inclusive. TensorView* flatten(int64_t from = 0, int64_t to = -1); diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index c95014bdadb..095eb208484 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -1593,6 +1593,22 @@ void TensorDomain::merge(int64_t axis_o, int64_t axis_i) { loop_domain_.insert(loop_domain_.begin() + td_outer_pos, merged_id); } +// Partition "axis" into component and ragged dimensions. Follow the +// pattern of TensorDomain::split. 
+void TensorDomain::partition(int64_t axis, TensorView* offsets) { + NVF_ERROR(nDims() > 0, "Tried to do partition on a 0-dim domain"); + axis = wrapDim(axis); + + IterDomain* id = this->axis(axis); + + auto [component_id, ragged_id] = RaggedIterDomain::partition(id, offsets); + + // Remove the original axis and insert component and ragged dimensions + loop_domain_.erase(loop_domain_.begin() + axis); + loop_domain_.insert(loop_domain_.begin() + axis, ragged_id); + loop_domain_.insert(loop_domain_.begin() + axis, component_id); +} + // Reorder axes according to map[old_pos] = new_pos void TensorDomain::reorder( const std::unordered_map& old2new_) { diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 2a2e85d0458..5888f71b989 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -790,6 +790,9 @@ class NVF_API TensorDomain : public Val { // axis is by default placed at original position axis_o void merge(int64_t axis_o, int64_t axis_i); + // Partition axis into component and ragged dimensions based on offsets + void partition(int64_t axis, TensorView* offsets); + // Reorder axes according to map[old_pos] = new_pos void reorder(const std::unordered_map& old2new); diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index 7e0d39c02bd..eabe2004f0c 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -561,6 +561,51 @@ TensorView* TensorView::merge(int64_t axis_o, int64_t axis_i) { return this; } +// Partition "axis" into component and ragged dimensions based on +// offsets. Follow the pattern of TensorView::split. +TensorView* TensorView::partition(int64_t axis, TensorView* offsets) { + NVF_ERROR( + nDims() > 0, + "Tried to do partition on a 0-dim TensorView. ", + "Tensor: ", + toString()); + + axis = wrapDim(axis); + + NVF_CHECK( + axis >= getMaxComputePosition(), + "Cannot partition axis within compute at position. Axis = ", + axis, + " computePosition = ", + getMaxComputePosition(), + ". 
Tensor: ", + toString()); + + NVF_CHECK( + axis >= getMaybeMaxProducerPosition(), + "Cannot partition axis within max producer position. Axis = ", + axis, + " maxProducerPosition = ", + getMaybeMaxProducerPosition(), + ". Tensor: ", + toString()); + + NVF_CHECK( + this->axis(axis)->getParallelType() == ParallelType::Serial, + "Partitioning an axis (", + this->axis(axis)->toString(), + ") of non-Serial parallel type is not supported at this time." + " Parallelization strategy must be set after calling partition: ", + toString()); + + if (offsets->dtype() != DataType::Index) { + offsets = castOp(DataType::Index, offsets); + } + + domain()->partition(axis, offsets); + return this; +} + TensorView* TensorView::resize( int64_t axis, Val* left_expansion, diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index a9d4c79c67e..3571d3c11a2 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -311,4 +311,38 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) { RaggedIterDomain::partition(ragged_id, offsets), nvfuser::nvfError); } +// TensorView::partition operation +TEST_F(RaggedIterDomainTest, TensorViewPartition) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a 2D TensorView + auto tv0 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(tv0); + + // Create offsets tensor + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Partition the first axis + tv0->partition(0, offsets); + + // Verify the tensor now has 3 dimensions: [component, ragged, original_dim1] + EXPECT_EQ(tv0->nDims(), 3); + + // First axis should be a regular IterDomain (component) + EXPECT_TRUE(tv0->axis(0)->isA()); + + // Second axis should be a RaggedIterDomain + EXPECT_TRUE(tv0->axis(1)->isA()); + + // Third axis should be the original second dimension + EXPECT_TRUE(tv0->axis(2)->isA()); + + // Verify both partition outputs have the same definition + 
EXPECT_TRUE(tv0->axis(0)->definition() != nullptr); + EXPECT_TRUE(tv0->axis(0)->definition()->isA()); + EXPECT_EQ(tv0->axis(0)->definition(), tv0->axis(1)->definition()); +} + } // namespace nvfuser From 69dbe0fd19c374a6cb3db0a0956a4c308cd1f9aa Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 21:05:58 -0800 Subject: [PATCH 12/47] cleanup --- csrc/ir/internal_base_nodes.cpp | 63 ++++++++++----------------- tests/cpp/test_ragged_iter_domain.cpp | 32 +++----------- 2 files changed, 30 insertions(+), 65 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 095eb208484..51f855dcf80 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -912,60 +912,43 @@ std::pair RaggedIterDomain::partition( NVF_ERROR(offsets != nullptr, "partition: offsets tensor is null"); - NVF_ERROR( - offsets->dtype() == DataType::Index, + NVF_ERROR_EQ( + offsets->dtype(), + DataType::Index, "partition: offsets must have Index type, got ", offsets->dtype()); const auto& offsets_domain = offsets->getLogicalDomain(); - NVF_ERROR( - !offsets_domain.empty(), - "partition: offsets tensor must have at least one dimension"); + NVF_ERROR_EQ( + offsets_domain.size(), + 1, + "partition: offsets tensor must be 1D, got ", + offsets_domain.size(), + "D tensor. 
Multi-dimensional offsets not yet supported."); auto container = in->container(); // Compute extents from offsets: extents[i] = offsets[i+1] - offsets[i] - // Slice along the last dimension of the offsets tensor - // offsets_left = offsets[..., :-1] (all but last element in last dim) - // offsets_right = offsets[..., 1:] (all but first element in last dim) + // offsets_left = offsets[:-1] (all but last element) + // offsets_right = offsets[1:] (all but first element) - const auto last_dim = offsets_domain.size() - 1; - auto offsets_len = offsets_domain[last_dim]->extent(); + auto offsets_len = offsets_domain[0]->extent(); auto zero = container->zeroVal(DataType::Index); auto one = container->oneVal(DataType::Index); auto len_minus_one = sub(offsets_len, one); - // Build slice ranges for all dimensions - // For all dimensions except the last, use full range (:) - // For the last dimension, use [:-1] for left and [1:] for right - std::vector left_ranges; - std::vector right_ranges; - - for (const auto i : arange(offsets_domain.size())) { - if (i < last_dim) { - // Full range for non-last dimensions - Slice s; - s.start = zero; - s.stop = offsets_domain[i]->extent(); - left_ranges.push_back(s); - right_ranges.push_back(s); - } else { - // Last dimension: left uses [:-1], right uses [1:] - Slice left_s; - left_s.start = zero; - left_s.stop = len_minus_one; - left_ranges.push_back(left_s); - - Slice right_s; - right_s.start = one; - right_s.stop = offsets_len; - right_ranges.push_back(right_s); - } - } - - auto offsets_left = slice(offsets, left_ranges); - auto offsets_right = slice(offsets, right_ranges); + // Slice offsets[:-1] + Slice left_slice; + left_slice.start = zero; + left_slice.stop = len_minus_one; + auto offsets_left = slice(offsets, {left_slice}); + + // Slice offsets[1:] + Slice right_slice; + right_slice.start = one; + right_slice.stop = offsets_len; + auto offsets_right = slice(offsets, {right_slice}); // Compute extents: extents = offsets_right - 
offsets_left auto extents = sub(offsets_right, offsets_left); diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 3571d3c11a2..c04e1f00e31 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -251,30 +251,6 @@ TEST_F(RaggedIterDomainTest, PartitionBasic) { EXPECT_TRUE(str.find("Partition") != std::string::npos); } -// Partition operation - multi-dimensional offsets -TEST_F(RaggedIterDomainTest, PartitionMultiDimensional) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto input_id = - IterDomainBuilder( - fusion.zeroVal(), IrBuilder::create(100L, DataType::Index)) - .build(); - - // Create 2D offsets tensor for nested ragged structure - auto offsets_2d = makeSymbolicTensor(2, DataType::Index); - fusion.addInput(offsets_2d); - - // Partition should work with multi-dimensional offsets - auto [component_id, ragged_id] = - RaggedIterDomain::partition(input_id, offsets_2d); - - EXPECT_TRUE(component_id != nullptr); - EXPECT_TRUE(ragged_id != nullptr); - EXPECT_TRUE(ragged_id->isA()); - EXPECT_TRUE(ragged_id->extents() != nullptr); -} - // Partition operation - validation tests TEST_F(RaggedIterDomainTest, PartitionValidation) { Fusion fusion; @@ -302,7 +278,13 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) { EXPECT_THROW( RaggedIterDomain::partition(input_id, float_offsets), nvfuser::nvfError); - // Test 4: Cannot partition RaggedIterDomain + // Test 4: Multi-dimensional offsets should fail + auto offsets_2d = makeSymbolicTensor(2, DataType::Index); + fusion.addInput(offsets_2d); + EXPECT_THROW( + RaggedIterDomain::partition(input_id, offsets_2d), nvfuser::nvfError); + + // Test 5: Cannot partition RaggedIterDomain auto extents = makeSymbolicTensor(1, DataType::Index); fusion.addInput(extents); auto ragged_id = IrBuilder::create( From 2348dde73b40b6d306bf9324334176a217616e5c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 21:44:44 -0800 Subject: [PATCH 
13/47] cleanup --- csrc/ir/internal_base_nodes.cpp | 8 ++++++++ csrc/ir/internal_base_nodes.h | 4 +++- tests/cpp/test_ragged_iter_domain.cpp | 11 ++++++++++- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 4187ef96e2a..cc068fbfca2 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -910,6 +910,14 @@ std::pair RaggedIterDomain::partition( "Partitioning of parallelized IterDomain not supported: ", in->toString()); + NVF_ERROR_EQ( + in->getIterType(), + IterType::Iteration, + "partition: only IterType::Iteration is supported, got ", + in->getIterType(), + " for IterDomain: ", + in->toString()); + NVF_ERROR(offsets != nullptr, "partition: offsets tensor is null"); NVF_ERROR_EQ( diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 421c056be82..0187c408bd7 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -480,12 +480,14 @@ class NVF_API RaggedIterDomain : public IterDomain { //! Creates a component IterDomain and a RaggedIterDomain based on offsets //! //! \param in Input IterDomain to partition (must be regular IterDomain) - //! \param offsets Offset tensor defining partition boundaries + //! \param offsets Offset tensor defining partition boundaries (must be 1D) //! Shape: [num_components + 1], values: [0, off1, off2, ..., total] //! Extents are computed as: extents[i] = offsets[i+1] - offsets[i] //! \return Pair of (component_id, ragged_id) //! component_id: IterDomain with extent = num_components //! ragged_id: RaggedIterDomain with extents computed from offsets + //! + //! 
TODO: Support multi-dimensional offsets for nested ragged structures static std::pair partition( IterDomain* in, TensorView* offsets); diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index b343636a56f..8d16615bd64 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -288,7 +288,16 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) { EXPECT_THROW( RaggedIterDomain::partition(input_id, offsets_2d), nvfuser::nvfError); - // Test 5: Cannot partition RaggedIterDomain + // Test 5: Non-Iteration IterType should fail + auto reduction_id = + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) + .iter_type(IterType::Reduction) + .build(); + EXPECT_THROW( + RaggedIterDomain::partition(reduction_id, offsets), nvfuser::nvfError); + + // Test 6: Cannot partition RaggedIterDomain auto extents = makeSymbolicTensor(1, DataType::Index); fusion.addInput(extents); auto ragged_id = IrBuilder::create( From 7090b9c2bedacfd19f0cbf25f370829e4ae0f45b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 12 Dec 2025 23:39:07 -0800 Subject: [PATCH 14/47] WIP: asNested --- csrc/ops/alias.cpp | 65 ++++++++++ csrc/ops/alias.h | 23 ++++ tests/cpp/test_ragged_iter_domain.cpp | 171 ++++++++++++++++++++++++++ 3 files changed, 259 insertions(+) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 45c0f52a603..cfe2cbc5ad1 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1268,4 +1268,69 @@ TensorView* repeat( return out_tv; } +TensorView* asNested( + TensorView* data, + TensorView* offsets, + int64_t ragged_dim) { + // Basic null checks + NVF_ERROR(data != nullptr, "asNested: data tensor is null"); + NVF_ERROR(offsets != nullptr, "asNested: offsets tensor is null"); + + // Only 1D offset tensors are currently supported + NVF_CHECK( + offsets->nDims() == 1, + "asNested currently only supports 1D offset tensors, got ", + offsets->nDims(), + "D"); + + // Get the 
logical domain of the input, excluding reductions + auto inp_logical = TensorDomain::noReductions(data->getLogicalDomain()); + + // Clone the logical domain to create the root domain for output + std::vector root_domain; + root_domain.reserve(inp_logical.size()); + for (auto* id : inp_logical) { + root_domain.push_back(id->cloneWithoutRFactor()); + } + + // Partition the specified dimension in root domain + // This replaces one IterDomain with (component_id, ragged_id) + auto [component_id, ragged_id] = + RaggedIterDomain::partition(root_domain.at(ragged_dim), offsets); + + // Build the logical domain: replace ragged_dim with component and ragged + std::vector logical_domain; + logical_domain.reserve(root_domain.size() + 1); // One extra for the split + + for (const auto i : arange(root_domain.size())) { + if (static_cast(i) == ragged_dim) { + // Replace with component and ragged dimensions + logical_domain.push_back(component_id); + logical_domain.push_back(ragged_id); + } else { + logical_domain.push_back(root_domain.at(i)); + } + } + + // Create the output TensorView with the partitioned structure + auto* out = IrBuilder::create( + IrBuilder::create( + root_domain, + logical_domain, + logical_domain, + TensorDomain::getContiguityFilledWith(logical_domain, true)), + data->getDataType().value()); + + // Create a Partition expression to represent this transformation + // The Partition Expr outputs the component_id and ragged_id, and sets up + // the definitions for those IterDomains + IrBuilder::create(component_id, ragged_id, root_domain.at(ragged_dim), offsets); + + // Set the output TensorView's definition - this should be done via LoadStoreOp + // since we're creating an alias view + IrBuilder::create(LoadStoreOpType::Set, out, data); + + return out; +} + } // namespace nvfuser diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h index 951b98b2a12..feea2924daa 100644 --- a/csrc/ops/alias.h +++ b/csrc/ops/alias.h @@ -156,6 +156,29 @@ NVF_API TensorView* slice( const 
std::vector& starts, const std::vector& stops); +//! Create a nested tensor view from a data tensor and offsets. +//! This is a convenience wrapper around TensorView::partition(). +//! +//! The function partitions the specified dimension of the data tensor into +//! a component dimension and a ragged dimension based on the provided offsets. +//! +//! \param data Input tensor to be converted to nested representation +//! \param offsets Offset tensor defining partition boundaries +//! Shape: [num_components + 1], values: [0, off1, off2, ..., total] +//! Extents are computed as: extents[i] = offsets[i+1] - offsets[i] +//! \param ragged_dim Dimension to partition into nested structure (0-indexed) +//! \return TensorView with a RaggedIterDomain at the specified dimension +//! +//! Example: +//! data shape: [10, ...] +//! offsets: [0, 3, 8, 10] +//! ragged_dim: 0 +//! Result: nested tensor with 3 components of sizes [3, 5, 2] +NVF_API TensorView* asNested( + TensorView* data, + TensorView* offsets, + int64_t ragged_dim); + // Splits `in`'s dimension `dim` into `chunks` chunks. All but the last chunk // will be of size `ceil(dim_size/chunks)`. 
Unlike `torch.chunk` which returns // only positive-size chunks and therefore may return fewer than `chunks` of diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 8d16615bd64..8f97c099026 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -340,4 +340,175 @@ TEST_F(RaggedIterDomainTest, TensorViewPartition) { EXPECT_EQ(tv0->axis(0)->definition(), tv0->axis(1)->definition()); } +// asNested basic functionality +TEST_F(RaggedIterDomainTest, AsNestedBasic) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a 2D TensorView [10, 20] + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + // Create offsets tensor [num_components + 1] + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor from dimension 0 + auto nested = asNested(data, offsets, 0); + + // Verify the output is a new TensorView + EXPECT_NE(nested, nullptr); + EXPECT_NE(nested, data); + EXPECT_TRUE(nested->isA()); + + // Verify nested tensor has 3 dimensions: [component, ragged, original_dim1] + EXPECT_EQ(nested->nDims(), 3); + + // First axis should be a regular IterDomain (component) + EXPECT_TRUE(nested->axis(0)->isA()); + EXPECT_FALSE(nested->axis(0)->isA()); + + // Second axis should be a RaggedIterDomain + EXPECT_TRUE(nested->axis(1)->isA()); + + // Third axis should be the original second dimension + EXPECT_TRUE(nested->axis(2)->isA()); + EXPECT_FALSE(nested->axis(2)->isA()); + + // Verify the definition exists (LoadStoreOp for aliasing) + EXPECT_TRUE(nested->definition() != nullptr); + EXPECT_TRUE(nested->definition()->isA()); + + // Verify the component and ragged IterDomains have Partition as their definition + EXPECT_TRUE(nested->axis(0)->definition() != nullptr); + EXPECT_TRUE(nested->axis(0)->definition()->isA()); + EXPECT_EQ(nested->axis(0)->definition(), nested->axis(1)->definition()); +} + +// asNested on 
different dimensions +TEST_F(RaggedIterDomainTest, AsNestedDifferentDimension) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a 3D TensorView [10, 20, 30] + auto data = makeSymbolicTensor(3, DataType::Float); + fusion.addInput(data); + + // Create offsets tensor + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Partition dimension 1 (middle dimension) + auto nested = asNested(data, offsets, 1); + + // Verify dimensions: [dim0, component, ragged, dim2] + EXPECT_EQ(nested->nDims(), 4); + + // First axis is original dim0 + EXPECT_TRUE(nested->axis(0)->isA()); + EXPECT_FALSE(nested->axis(0)->isA()); + + // Second axis is component + EXPECT_TRUE(nested->axis(1)->isA()); + EXPECT_FALSE(nested->axis(1)->isA()); + + // Third axis is ragged + EXPECT_TRUE(nested->axis(2)->isA()); + + // Fourth axis is original dim2 + EXPECT_TRUE(nested->axis(3)->isA()); + EXPECT_FALSE(nested->axis(3)->isA()); +} + +// asNested with 1D tensor +TEST_F(RaggedIterDomainTest, AsNested1DTensor) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a 1D TensorView [10] + auto data = makeSymbolicTensor(1, DataType::Float); + fusion.addInput(data); + + // Create offsets tensor + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor from the only dimension + auto nested = asNested(data, offsets, 0); + + // Verify dimensions: [component, ragged] + EXPECT_EQ(nested->nDims(), 2); + + // First axis is component + EXPECT_TRUE(nested->axis(0)->isA()); + EXPECT_FALSE(nested->axis(0)->isA()); + + // Second axis is ragged + EXPECT_TRUE(nested->axis(1)->isA()); +} + +// asNested validation - null data +TEST_F(RaggedIterDomainTest, AsNestedValidationNullData) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Null data should throw + EXPECT_THROW(asNested(nullptr, offsets, 0), nvfuser::nvfError); +} + +// 
asNested validation - null offsets +TEST_F(RaggedIterDomainTest, AsNestedValidationNullOffsets) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + // Null offsets should throw + EXPECT_THROW(asNested(data, nullptr, 0), nvfuser::nvfError); +} + +// asNested validation - multi-dimensional offsets (not yet supported) +TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimOffsets) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + // 2D offsets should fail (only 1D supported currently) + auto offsets_2d = makeSymbolicTensor(2, DataType::Index); + fusion.addInput(offsets_2d); + + EXPECT_THROW(asNested(data, offsets_2d, 0), nvfuser::nvfError); +} + +// asNested preserves data type +TEST_F(RaggedIterDomainTest, AsNestedPreservesDataType) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Test with different data types + auto data_float = makeSymbolicTensor(2, DataType::Float); + auto data_double = makeSymbolicTensor(2, DataType::Double); + auto data_int = makeSymbolicTensor(2, DataType::Int); + fusion.addInput(data_float); + fusion.addInput(data_double); + fusion.addInput(data_int); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + auto nested_float = asNested(data_float, offsets, 0); + auto nested_double = asNested(data_double, offsets, 0); + auto nested_int = asNested(data_int, offsets, 0); + + EXPECT_EQ(nested_float->dtype(), DataType::Float); + EXPECT_EQ(nested_double->dtype(), DataType::Double); + EXPECT_EQ(nested_int->dtype(), DataType::Int); +} + } // namespace nvfuser From b07e285ab40e39da8cd4c701ca6045625aff47f4 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sat, 13 Dec 2025 00:04:39 -0800 Subject: [PATCH 15/47] cleanup --- csrc/ops/alias.cpp | 14 +++---- tests/cpp/test_ragged_iter_domain.cpp | 53 ++++++--------------------- 2 files changed, 19 
insertions(+), 48 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index cfe2cbc5ad1..84dbd99b589 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1277,11 +1277,10 @@ TensorView* asNested( NVF_ERROR(offsets != nullptr, "asNested: offsets tensor is null"); // Only 1D offset tensors are currently supported - NVF_CHECK( - offsets->nDims() == 1, - "asNested currently only supports 1D offset tensors, got ", + NVF_ERROR_EQ( offsets->nDims(), - "D"); + 1, + "asNested currently only supports 1D offset tensors"); // Get the logical domain of the input, excluding reductions auto inp_logical = TensorDomain::noReductions(data->getLogicalDomain()); @@ -1324,10 +1323,11 @@ TensorView* asNested( // Create a Partition expression to represent this transformation // The Partition Expr outputs the component_id and ragged_id, and sets up // the definitions for those IterDomains - IrBuilder::create(component_id, ragged_id, root_domain.at(ragged_dim), offsets); + IrBuilder::create( + component_id, ragged_id, root_domain.at(ragged_dim), offsets); - // Set the output TensorView's definition - this should be done via LoadStoreOp - // since we're creating an alias view + // Set the output TensorView's definition - this should be done via + // LoadStoreOp since we're creating an alias view IrBuilder::create(LoadStoreOpType::Set, out, data); return out; diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 8f97c099026..f7eaac14c2e 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -345,19 +345,19 @@ TEST_F(RaggedIterDomainTest, AsNestedBasic) { Fusion fusion; FusionGuard fg(&fusion); - // Create a 2D TensorView [10, 20] auto data = makeSymbolicTensor(2, DataType::Float); fusion.addInput(data); - // Create offsets tensor [num_components + 1] auto offsets = makeSymbolicTensor(1, DataType::Index); fusion.addInput(offsets); // Create nested tensor from dimension 0 auto nested = 
asNested(data, offsets, 0); + fusion.addOutput(nested); + // Verify the output is a new TensorView - EXPECT_NE(nested, nullptr); + EXPECT_TRUE(nested != nullptr); EXPECT_NE(nested, data); EXPECT_TRUE(nested->isA()); @@ -365,21 +365,21 @@ TEST_F(RaggedIterDomainTest, AsNestedBasic) { EXPECT_EQ(nested->nDims(), 3); // First axis should be a regular IterDomain (component) - EXPECT_TRUE(nested->axis(0)->isA()); + EXPECT_TRUE(nested->axis(0)->isStrictlyA()); EXPECT_FALSE(nested->axis(0)->isA()); // Second axis should be a RaggedIterDomain EXPECT_TRUE(nested->axis(1)->isA()); // Third axis should be the original second dimension - EXPECT_TRUE(nested->axis(2)->isA()); - EXPECT_FALSE(nested->axis(2)->isA()); + EXPECT_TRUE(nested->axis(2)->isStrictlyA()); // Verify the definition exists (LoadStoreOp for aliasing) EXPECT_TRUE(nested->definition() != nullptr); EXPECT_TRUE(nested->definition()->isA()); - // Verify the component and ragged IterDomains have Partition as their definition + // Verify the component and ragged IterDomains have Partition as their + // definition EXPECT_TRUE(nested->axis(0)->definition() != nullptr); EXPECT_TRUE(nested->axis(0)->definition()->isA()); EXPECT_EQ(nested->axis(0)->definition(), nested->axis(1)->definition()); @@ -390,11 +390,9 @@ TEST_F(RaggedIterDomainTest, AsNestedDifferentDimension) { Fusion fusion; FusionGuard fg(&fusion); - // Create a 3D TensorView [10, 20, 30] auto data = makeSymbolicTensor(3, DataType::Float); fusion.addInput(data); - // Create offsets tensor auto offsets = makeSymbolicTensor(1, DataType::Index); fusion.addInput(offsets); @@ -405,19 +403,16 @@ TEST_F(RaggedIterDomainTest, AsNestedDifferentDimension) { EXPECT_EQ(nested->nDims(), 4); // First axis is original dim0 - EXPECT_TRUE(nested->axis(0)->isA()); - EXPECT_FALSE(nested->axis(0)->isA()); + EXPECT_TRUE(nested->axis(0)->isStrictlyA()); // Second axis is component - EXPECT_TRUE(nested->axis(1)->isA()); - EXPECT_FALSE(nested->axis(1)->isA()); + 
EXPECT_TRUE(nested->axis(1)->isStrictlyA()); // Third axis is ragged EXPECT_TRUE(nested->axis(2)->isA()); // Fourth axis is original dim2 EXPECT_TRUE(nested->axis(3)->isA()); - EXPECT_FALSE(nested->axis(3)->isA()); } // asNested with 1D tensor @@ -436,12 +431,13 @@ TEST_F(RaggedIterDomainTest, AsNested1DTensor) { // Create nested tensor from the only dimension auto nested = asNested(data, offsets, 0); + fusion.addOutput(nested); + // Verify dimensions: [component, ragged] EXPECT_EQ(nested->nDims(), 2); // First axis is component - EXPECT_TRUE(nested->axis(0)->isA()); - EXPECT_FALSE(nested->axis(0)->isA()); + EXPECT_TRUE(nested->axis(0)->isStrictlyA()); // Second axis is ragged EXPECT_TRUE(nested->axis(1)->isA()); @@ -486,29 +482,4 @@ TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimOffsets) { EXPECT_THROW(asNested(data, offsets_2d, 0), nvfuser::nvfError); } -// asNested preserves data type -TEST_F(RaggedIterDomainTest, AsNestedPreservesDataType) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Test with different data types - auto data_float = makeSymbolicTensor(2, DataType::Float); - auto data_double = makeSymbolicTensor(2, DataType::Double); - auto data_int = makeSymbolicTensor(2, DataType::Int); - fusion.addInput(data_float); - fusion.addInput(data_double); - fusion.addInput(data_int); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - auto nested_float = asNested(data_float, offsets, 0); - auto nested_double = asNested(data_double, offsets, 0); - auto nested_int = asNested(data_int, offsets, 0); - - EXPECT_EQ(nested_float->dtype(), DataType::Float); - EXPECT_EQ(nested_double->dtype(), DataType::Double); - EXPECT_EQ(nested_int->dtype(), DataType::Int); -} - } // namespace nvfuser From a2c504baa3d0360c92d5bb50a65e3011c237f630 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 15 Dec 2025 13:58:52 -0800 Subject: [PATCH 16/47] asNested --- csrc/ops/alias.cpp | 12 +++--------- csrc/ops/alias.h | 45 
++++++++++++++++++++++----------------------- 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 84dbd99b589..30c3406633e 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1272,7 +1272,6 @@ TensorView* asNested( TensorView* data, TensorView* offsets, int64_t ragged_dim) { - // Basic null checks NVF_ERROR(data != nullptr, "asNested: data tensor is null"); NVF_ERROR(offsets != nullptr, "asNested: offsets tensor is null"); @@ -1320,14 +1319,9 @@ TensorView* asNested( TensorDomain::getContiguityFilledWith(logical_domain, true)), data->getDataType().value()); - // Create a Partition expression to represent this transformation - // The Partition Expr outputs the component_id and ragged_id, and sets up - // the definitions for those IterDomains - IrBuilder::create( - component_id, ragged_id, root_domain.at(ragged_dim), offsets); - - // Set the output TensorView's definition - this should be done via - // LoadStoreOp since we're creating an alias view + // For now, just use LoadStoreOp to represent the nesting + // operation. Does it make more sense to have a specific TensorView + // op like ReshapeOp? IrBuilder::create(LoadStoreOpType::Set, out, data); return out; diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h index feea2924daa..f3bf769dd71 100644 --- a/csrc/ops/alias.h +++ b/csrc/ops/alias.h @@ -156,29 +156,6 @@ NVF_API TensorView* slice( const std::vector& starts, const std::vector& stops); -//! Create a nested tensor view from a data tensor and offsets. -//! This is a convenience wrapper around TensorView::partition(). -//! -//! The function partitions the specified dimension of the data tensor into -//! a component dimension and a ragged dimension based on the provided offsets. -//! -//! \param data Input tensor to be converted to nested representation -//! \param offsets Offset tensor defining partition boundaries -//! Shape: [num_components + 1], values: [0, off1, off2, ..., total] -//! 
Extents are computed as: extents[i] = offsets[i+1] - offsets[i] -//! \param ragged_dim Dimension to partition into nested structure (0-indexed) -//! \return TensorView with a RaggedIterDomain at the specified dimension -//! -//! Example: -//! data shape: [10, ...] -//! offsets: [0, 3, 8, 10] -//! ragged_dim: 0 -//! Result: nested tensor with 3 components of sizes [3, 5, 2] -NVF_API TensorView* asNested( - TensorView* data, - TensorView* offsets, - int64_t ragged_dim); - // Splits `in`'s dimension `dim` into `chunks` chunks. All but the last chunk // will be of size `ceil(dim_size/chunks)`. Unlike `torch.chunk` which returns // only positive-size chunks and therefore may return fewer than `chunks` of @@ -220,4 +197,26 @@ NVF_API TensorView* repeat( TensorView* inp, const std::vector& repeat_times); +//! Create a nested tensor view from a data tensor and offsets. +//! +//! The function partitions the specified dimension of the data tensor into +//! a component dimension and a ragged dimension based on the provided offsets. +//! +//! \param data Input tensor to be converted to nested representation +//! \param offsets Offset tensor defining partition boundaries +//! Shape: [num_components + 1], values: [0, off1, off2, ..., total] +//! Extents are computed as: extents[i] = offsets[i+1] - offsets[i] +//! \param ragged_dim Dimension to partition into nested structure +//! \return TensorView with a RaggedIterDomain at the specified dimension +//! +//! Example: +//! data shape: [10, ...] +//! offsets: [0, 3, 8, 10] +//! ragged_dim: 0 +//! Result: nested tensor with 3 components. [3, [3, 5, 2], ...] 
+NVF_API TensorView* asNested( + TensorView* data, + TensorView* offsets, + int64_t ragged_dim); + } // namespace nvfuser From b1d8cf40a0a2fbe9725b68bf08f2e6ea55b9981e Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 15 Dec 2025 14:11:21 -0800 Subject: [PATCH 17/47] warpdim --- csrc/ops/alias.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 30c3406633e..e32aa5e6b9c 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1291,6 +1291,8 @@ TensorView* asNested( root_domain.push_back(id->cloneWithoutRFactor()); } + ragged_dim = wrapDim(ragged_dim, std::ssize(inp_logical)); + // Partition the specified dimension in root domain // This replaces one IterDomain with (component_id, ragged_id) auto [component_id, ragged_id] = From 201c1480ac75dec581570276937d7d1f00513e28 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Dec 2025 11:31:27 -0800 Subject: [PATCH 18/47] Make sure RaggedIterDomain is propagated to output tensors --- csrc/ir/internal_base_nodes.cpp | 7 +++ csrc/ir/internal_base_nodes.h | 2 + csrc/ops/utils.cpp | 29 +++++++++++ tests/cpp/test_ragged_iter_domain.cpp | 71 ++++++++++++++++----------- 4 files changed, 80 insertions(+), 29 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index cc068fbfca2..9d9984a3d11 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -1496,6 +1496,13 @@ bool TensorDomain::hasVectorize() const { }); } +bool TensorDomain::hasRaggedIterDomain() const { + return std::any_of( + logical().begin(), logical().end(), [](IterDomain* logical_id) { + return logical_id->isA(); + }); +} + std::optional TensorDomain::getReductionAxis() const { auto it = std::find_if( loop_domain_.begin(), loop_domain_.end(), [](const auto& id) { diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 0187c408bd7..cfade4ebcba 100644 --- a/csrc/ir/internal_base_nodes.h +++ 
b/csrc/ir/internal_base_nodes.h @@ -642,6 +642,8 @@ class NVF_API TensorDomain : public Val { bool hasSymbolicAxis() const; + bool hasRaggedIterDomain() const; + std::optional getReductionAxis() const; // The input logical domain. The root domain of a consumer should equal the diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 2db0b424d55..80aa95c1fe2 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -315,6 +315,28 @@ std::vector mapLinearOpIterDomains( return mapping; } +RaggedIterDomain* newOutputRaggedIterDomain( + const std::vector& input_ids, + const std::optional force_iter_type) { + NVF_ERROR( + std::ranges::all_of( + input_ids, + [](IterDomain* input_id) { + return input_id->isA(); + }), + "All input iter domains must be RaggedIterDomain"); + + NVF_ERROR(!input_ids.empty()); + RaggedIterDomain* ref_input_id = input_ids.front()->as(); + + NVF_ERROR(!force_iter_type.has_value(), "forced iter type not considered"); + + return IrBuilder::create( + ref_input_id->extents(), + ref_input_id->getIterType(), + ref_input_id->getParallelType()); +} + // Adding these pragmas since gcc-12.2.1 // incorrectly reports a warning with the use of evaluate #if defined(__GNUC__) && !defined(__clang__) @@ -324,6 +346,13 @@ std::vector mapLinearOpIterDomains( IterDomain* newOutputIterDomain( const std::vector& input_ids, const std::optional force_iter_type) { + NVF_ERROR(!input_ids.empty()); + + // If any input ID is a RaggedIterDomain, the output should also be ragged + if (input_ids.front()->isA()) { + return newOutputRaggedIterDomain(input_ids, force_iter_type); + } + // For the start and stop offsets, take the maximum of input axes. // For now, the offsets of both start and stop are always integer // constant, so we can statically compute them. 
It is unclear diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index f7eaac14c2e..3bd7127f78d 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -354,35 +354,48 @@ TEST_F(RaggedIterDomainTest, AsNestedBasic) { // Create nested tensor from dimension 0 auto nested = asNested(data, offsets, 0); - fusion.addOutput(nested); - - // Verify the output is a new TensorView - EXPECT_TRUE(nested != nullptr); - EXPECT_NE(nested, data); - EXPECT_TRUE(nested->isA()); - - // Verify nested tensor has 3 dimensions: [component, ragged, original_dim1] - EXPECT_EQ(nested->nDims(), 3); - - // First axis should be a regular IterDomain (component) - EXPECT_TRUE(nested->axis(0)->isStrictlyA()); - EXPECT_FALSE(nested->axis(0)->isA()); - - // Second axis should be a RaggedIterDomain - EXPECT_TRUE(nested->axis(1)->isA()); - - // Third axis should be the original second dimension - EXPECT_TRUE(nested->axis(2)->isStrictlyA()); - - // Verify the definition exists (LoadStoreOp for aliasing) - EXPECT_TRUE(nested->definition() != nullptr); - EXPECT_TRUE(nested->definition()->isA()); - - // Verify the component and ragged IterDomains have Partition as their - // definition - EXPECT_TRUE(nested->axis(0)->definition() != nullptr); - EXPECT_TRUE(nested->axis(0)->definition()->isA()); - EXPECT_EQ(nested->axis(0)->definition(), nested->axis(1)->definition()); + // This should still be a nested tensor + auto copy_of_nested = set(nested); + + fusion.addOutput(copy_of_nested); + + for (auto nested_tv : {nested, copy_of_nested}) { + // Verify the output is a new TensorView + EXPECT_TRUE(nested_tv != nullptr); + EXPECT_NE(nested_tv, data); + EXPECT_TRUE(nested_tv->isA()); + + // Verify nested_tv tensor has 3 dimensions: [component, ragged, + // original_dim1] + EXPECT_EQ(nested_tv->nDims(), 3); + + // First axis should be a regular IterDomain (component) + EXPECT_TRUE(nested_tv->axis(0)->isStrictlyA()); + 
EXPECT_FALSE(nested_tv->axis(0)->isA()); + + // Second axis should be a RaggedIterDomain + EXPECT_TRUE(nested_tv->axis(1)->isA()); + + // Third axis should be the original second dimension + EXPECT_TRUE(nested_tv->axis(2)->isStrictlyA()); + + if (nested_tv == nested) { + // Verify the definition exists (LoadStoreOp for aliasing) + EXPECT_TRUE(nested_tv->definition() != nullptr); + EXPECT_TRUE(nested_tv->definition()->isA()); + + // Verify the component and ragged IterDomains have Partition as their + // definition + EXPECT_TRUE(nested_tv->axis(0)->definition() != nullptr); + EXPECT_TRUE(nested_tv->axis(0)->definition()->isA()); + EXPECT_EQ( + nested_tv->axis(0)->definition(), nested_tv->axis(1)->definition()); + } else { + // The copy of the original nested tensor does not inherit the Partition + // op + EXPECT_TRUE(nested_tv->axis(0)->definition() == nullptr); + } + } } // asNested on different dimensions From 9e0b161b9adf62d62172f78a60184c9fd8ae4327 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Dec 2025 13:37:32 -0800 Subject: [PATCH 19/47] Extend ops to be aware with RaggediterDomain --- csrc/ir/internal_base_nodes.cpp | 14 ++ csrc/ir/internal_base_nodes.h | 5 +- csrc/ops/alias.cpp | 3 +- csrc/ops/utils.cpp | 17 +- csrc/ops/utils.h | 6 + tests/cpp/test_ragged_iter_domain.cpp | 298 ++++++++++++++++++++++++++ 6 files changed, 334 insertions(+), 9 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 9d9984a3d11..3b7c31a89e9 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -895,6 +895,20 @@ std::string RaggedIterDomain::toString(int indent_size) const { return toInlineString(indent_size); } +IterDomain* RaggedIterDomain::cloneWithoutRFactor(bool map_with_original) { + // Create a new RaggedIterDomain with the same extents and properties + auto cloned = IrBuilder::create( + extents_, getIterType(), getParallelType()); + + // Optionally map the clone with the original in 
the Exact graph + if (map_with_original) { + // TODO: Implement mapping if needed + NVF_THROW("Not implemented"); + } + + return cloned; +} + std::pair RaggedIterDomain::partition( IterDomain* in, TensorView* offsets) { diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index cfade4ebcba..d56d4d21470 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -122,7 +122,7 @@ class NVF_API IterDomain : public Val { //! //! When map_with_original is true, the clone of the original is //! mapped in the Exact graph. - IterDomain* cloneWithoutRFactor(bool map_with_original = false); + virtual IterDomain* cloneWithoutRFactor(bool map_with_original = false); //! Clone a vector domains static std::vector clone( @@ -492,6 +492,9 @@ class NVF_API RaggedIterDomain : public IterDomain { IterDomain* in, TensorView* offsets); + //! Override cloneWithoutRFactor to preserve RaggedIterDomain type + IterDomain* cloneWithoutRFactor(bool map_with_original = false) override; + private: //! Extent tensor containing all component extents //! 
Can be 1D, 2D, or N-D depending on nesting structure diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index e32aa5e6b9c..870a1c186a3 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1038,8 +1038,7 @@ TensorView* broadcast( .iter_type(IterType::Broadcast) .build()); } else { - out_domain.push_back( - IterDomainBuilder(inp_domain[iinp]).resetSchedulingParams().build()); + out_domain.push_back(inp_domain[iinp]->cloneWithoutRFactor()); iinp++; } ibdim++; diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 80aa95c1fe2..2623f4c68be 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -316,8 +316,7 @@ std::vector mapLinearOpIterDomains( } RaggedIterDomain* newOutputRaggedIterDomain( - const std::vector& input_ids, - const std::optional force_iter_type) { + const std::vector& input_ids) { NVF_ERROR( std::ranges::all_of( input_ids, @@ -329,8 +328,6 @@ RaggedIterDomain* newOutputRaggedIterDomain( NVF_ERROR(!input_ids.empty()); RaggedIterDomain* ref_input_id = input_ids.front()->as(); - NVF_ERROR(!force_iter_type.has_value(), "forced iter type not considered"); - return IrBuilder::create( ref_input_id->extents(), ref_input_id->getIterType(), @@ -349,8 +346,16 @@ IterDomain* newOutputIterDomain( NVF_ERROR(!input_ids.empty()); // If any input ID is a RaggedIterDomain, the output should also be ragged - if (input_ids.front()->isA()) { - return newOutputRaggedIterDomain(input_ids, force_iter_type); + bool has_ragged = + std::any_of(input_ids.begin(), input_ids.end(), [](IterDomain* id) { + return id->isA(); + }); + + if (has_ragged) { + NVF_ERROR( + !force_iter_type.has_value(), + "force_iter_type not supported for RaggedIterDomain"); + return newOutputRaggedIterDomain(input_ids); } // For the start and stop offsets, take the maximum of input axes. 
diff --git a/csrc/ops/utils.h b/csrc/ops/utils.h index 44a98242a4d..3ceadc4aa6a 100644 --- a/csrc/ops/utils.h +++ b/csrc/ops/utils.h @@ -92,6 +92,12 @@ std::vector mapLinearOpIterDomains( size_t out_size, bool k_bcast); +// Creates an output RaggedIterDomain from input RaggedIterDomains at the same +// dimension position. All inputs must be RaggedIterDomain. Uses the extents, +// IterType, and ParallelType from the first input. +RaggedIterDomain* newOutputRaggedIterDomain( + const std::vector& input_ids); + // Takes a vector of aligned input iterdomains to create the output iterdomain. // This is used if the input iterdomains are not trivially mapped to the output // iterdomains. For eg: MatmulOp. If given, the forced_iter_type argument will diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 3bd7127f78d..ac3c3ef1f35 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -495,4 +495,302 @@ TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimOffsets) { EXPECT_THROW(asNested(data, offsets_2d, 0), nvfuser::nvfError); } +// Test binary operations with nested tensors +TEST_F(RaggedIterDomainTest, BinaryOpWithNestedTensors) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create two 2D input tensors + auto data1 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data1); + + auto data2 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data2); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensors from both inputs + auto nested1 = asNested(data1, offsets, 0); + auto nested2 = asNested(data2, offsets, 0); + + // Perform binary operation: add + auto result = add(nested1, nested2); + + fusion.addOutput(result); + + // Verify the result has 3 dimensions: [component, ragged, original_dim1] + EXPECT_EQ(result->nDims(), 3); + + // First axis should be a regular IterDomain (component) + 
EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_FALSE(result->axis(0)->isA()); + + // Second axis should be a RaggedIterDomain + EXPECT_TRUE(result->axis(1)->isA()); + + // Third axis should be the original second dimension + EXPECT_TRUE(result->axis(2)->isStrictlyA()); +} + +// Test binary operation with mixed inputs (one ragged, one not) - should error +TEST_F(RaggedIterDomainTest, BinaryOpMixedInputsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data1 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data1); + + auto data2 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data2); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor from first input only + auto nested1 = asNested(data1, offsets, 0); + + // Try to add nested tensor with non-nested tensor + // This should fail because one is ragged and one is not + EXPECT_THROW(add(nested1, data2), nvfuser::nvfError); +} + +// Test binary operation with different offsets +TEST_F(RaggedIterDomainTest, BinaryOpDifferentRaggedStructures) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data1 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data1); + + auto data2 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data2); + + auto offsets1 = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets1); + + auto offsets2 = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets2); + + // Create nested tensors with different offset tensors + auto nested1 = asNested(data1, offsets1, 0); + auto nested2 = asNested(data2, offsets2, 0); + + // This would be an error if, for example, the values of the offset + // tensors are not equivalent, but, like normal tensors, we assume + // that is indeed the case. 
+ auto result = add(nested1, nested2); + fusion.addOutput(result); + + EXPECT_TRUE(result->axis(1)->isA()); +} + +// Test unary operations with nested tensors +TEST_F(RaggedIterDomainTest, UnaryOpWithNestedTensors) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor + auto nested = asNested(data, offsets, 0); + + // Perform unary operation: neg + auto result = neg(nested); + + fusion.addOutput(result); + + // Verify the result preserves RaggedIterDomain structure + EXPECT_EQ(result->nDims(), 3); + EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_TRUE(result->axis(1)->isA()); + EXPECT_TRUE(result->axis(2)->isStrictlyA()); +} + +// Test broadcast with nested tensors +TEST_F(RaggedIterDomainTest, BroadcastWithNestedTensors) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + auto result = broadcast(nested, {false, false, false, true}); + + fusion.addOutput(result); + + // Result should be: [component, ragged, dim1, broadcast_dim] + EXPECT_EQ(result->nDims(), 4); + EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_TRUE(result->axis(1)->isA()); + EXPECT_TRUE(result->axis(2)->isStrictlyA()); + EXPECT_TRUE(result->axis(3)->isStrictlyA()); + EXPECT_TRUE(result->axis(3)->isBroadcast()); +} + +// Test squeeze on non-ragged dimension +TEST_F(RaggedIterDomainTest, SqueezeNonRaggedDim) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested 
tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // First broadcast to add a dimension: [component, ragged, dim1, 1] + auto broadcasted = broadcast(nested, {false, false, false, true}); + + // Then squeeze the broadcast dimension (dimension index 3) + auto result = squeeze(broadcasted, {3}); + + fusion.addOutput(result); + + // Result should be: [component, ragged, dim1] + EXPECT_EQ(result->nDims(), 3); + EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_TRUE(result->axis(1)->isA()); + EXPECT_TRUE(result->axis(2)->isStrictlyA()); +} + +// Test unsqueeze with nested tensors +TEST_F(RaggedIterDomainTest, UnsqueezeWithNestedTensors) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Unsqueeze to add dimension at the end + auto result = unsqueeze(nested, -1); + + fusion.addOutput(result); + + // Result should be: [component, ragged, dim1, 1] + EXPECT_EQ(result->nDims(), 4); + EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_TRUE(result->axis(1)->isA()); + EXPECT_TRUE(result->axis(2)->isStrictlyA()); + EXPECT_TRUE(result->axis(3)->isStrictlyA()); +} + +// Test permute/transpose with nested tensors +TEST_F(RaggedIterDomainTest, PermuteWithNestedTensors) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Permute dimensions: swap ragged and dim1 + auto result = permute(nested, {0, 2, 1}); + + fusion.addOutput(result); + + // Result should be: [component, dim1, ragged] + EXPECT_EQ(result->nDims(), 3); + 
EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_TRUE(result->axis(1)->isStrictlyA()); + EXPECT_TRUE(result->axis(2)->isA()); +} + +// Test reduction on non-ragged dimension +TEST_F(RaggedIterDomainTest, ReductionOnNonRaggedDim) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Reduce along the last dimension (non-ragged) + auto result = sum(nested, {2}); + + fusion.addOutput(result); + + // Result should be: [component, ragged] + // Debug: print the result structure + std::cout << "ReductionOnNonRaggedDim result dimensions: " << result->nDims() + << std::endl; + for (auto i : c10::irange(result->nDims())) { + std::cout << " axis " << i << ": " + << (result->axis(i)->isA() ? "RaggedIterDomain" + : "IterDomain") + << std::endl; + } + + EXPECT_EQ(result->nDims(), 2); + EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_TRUE(result->axis(1)->isA()); +} + +// Test reduction on ragged dimension +TEST_F(RaggedIterDomainTest, ReductionOnRaggedDim) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Reduce along the ragged dimension (axis 1) + auto result = sum(nested, {1}); + + fusion.addOutput(result); + + // Result should be: [component, dim1] + // Both should be regular IterDomains (ragged dimension is reduced away) + // Debug: print the result structure + std::cout << "ReductionOnRaggedDim result dimensions: " << result->nDims() + << std::endl; + for (auto i : c10::irange(result->nDims())) { + std::cout << " axis " << i << ": " + << 
(result->axis(i)->isA() ? "RaggedIterDomain" + : "IterDomain") + << std::endl; + } + + EXPECT_EQ(result->nDims(), 2); + EXPECT_TRUE(result->axis(0)->isStrictlyA()); + EXPECT_FALSE(result->axis(0)->isA()); + EXPECT_TRUE(result->axis(1)->isStrictlyA()); + EXPECT_FALSE(result->axis(1)->isA()); +} + } // namespace nvfuser From 60a2dd51b3e5e5321abc4cffa1b6f58c34c12cb3 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Dec 2025 14:15:04 -0800 Subject: [PATCH 20/47] RaggedIterDomain and reduction --- csrc/ops/arith.cpp | 26 +++++++--- tests/cpp/test_ragged_iter_domain.cpp | 68 ++++++++++++++------------- 2 files changed, 55 insertions(+), 39 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 10e4f4007b8..f6dceef7f1d 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -1209,6 +1209,14 @@ TensorView* newForReduction( " of tensor ", tv); } + NVF_CHECK( + !id->isA(), + "Cannot reduce a RaggedIterDomain. Reduction of ragged dimensions is " + "not supported. " + "Tried to reduce ID = ", + id, + " of tensor ", + tv); new_id = IterDomainBuilder(id) // If the domain is being reduced, but it's coming in as an // expanded extent, we need to realize the expand. 
@@ -1217,12 +1225,18 @@ TensorView* newForReduction( .iter_type(IterType::Reduction) .build(); } else { - new_id = IterDomainBuilder(id) - .extent(id->extent()) - .resetSchedulingParams() - .parallel_type(id->getParallelType()) - .iter_type(id->getIterType()) - .build(); + // For non-reduced dimensions, preserve RaggedIterDomain if present + if (id->isA()) { + // Cast away const since cloneWithoutRFactor is not const + new_id = const_cast(id)->cloneWithoutRFactor(); + } else { + new_id = IterDomainBuilder(id) + .extent(id->extent()) + .resetSchedulingParams() + .parallel_type(id->getParallelType()) + .iter_type(id->getIterType()) + .build(); + } } new_domain.push_back(new_id); } diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index ac3c3ef1f35..2fa9cf04d2d 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -740,23 +740,17 @@ TEST_F(RaggedIterDomainTest, ReductionOnNonRaggedDim) { fusion.addOutput(result); // Result should be: [component, ragged] - // Debug: print the result structure - std::cout << "ReductionOnNonRaggedDim result dimensions: " << result->nDims() - << std::endl; - for (auto i : c10::irange(result->nDims())) { - std::cout << " axis " << i << ": " - << (result->axis(i)->isA() ? 
"RaggedIterDomain" - : "IterDomain") - << std::endl; - } + // Get non-reduction dimensions + auto non_reduction_domain = + TensorDomain::noReductions(result->getLogicalDomain()); - EXPECT_EQ(result->nDims(), 2); - EXPECT_TRUE(result->axis(0)->isStrictlyA()); - EXPECT_TRUE(result->axis(1)->isA()); + EXPECT_EQ(non_reduction_domain.size(), 2); + EXPECT_TRUE(non_reduction_domain[0]->isStrictlyA()); + EXPECT_TRUE(non_reduction_domain[1]->isA()); } -// Test reduction on ragged dimension -TEST_F(RaggedIterDomainTest, ReductionOnRaggedDim) { +// Test reduction on ragged dimension - should error +TEST_F(RaggedIterDomainTest, ReductionOnRaggedDimError) { Fusion fusion; FusionGuard fg(&fusion); @@ -769,28 +763,36 @@ TEST_F(RaggedIterDomainTest, ReductionOnRaggedDim) { // Create nested tensor: [component, ragged, dim1] auto nested = asNested(data, offsets, 0); - // Reduce along the ragged dimension (axis 1) - auto result = sum(nested, {1}); + // Try to reduce along the ragged dimension (axis 1) + // This should throw an error because reducing RaggedIterDomain is not allowed + EXPECT_THROW(sum(nested, {1}), nvfuser::nvfError); +} - fusion.addOutput(result); +// Test reduction on component dimension - should error (TODO) +TEST_F(RaggedIterDomainTest, ReductionOnComponentDimError) { + GTEST_SKIP() << "TODO: Implement validation to prevent reduction of " + "component dimension. " + << "Currently there is no explicit marking of which IterDomains " + "are component dimensions, " + << "so this validation cannot be implemented yet."; - // Result should be: [component, dim1] - // Both should be regular IterDomains (ragged dimension is reduced away) - // Debug: print the result structure - std::cout << "ReductionOnRaggedDim result dimensions: " << result->nDims() - << std::endl; - for (auto i : c10::irange(result->nDims())) { - std::cout << " axis " << i << ": " - << (result->axis(i)->isA() ? 
"RaggedIterDomain" - : "IterDomain") - << std::endl; - } + Fusion fusion; + FusionGuard fg(&fusion); - EXPECT_EQ(result->nDims(), 2); - EXPECT_TRUE(result->axis(0)->isStrictlyA()); - EXPECT_FALSE(result->axis(0)->isA()); - EXPECT_TRUE(result->axis(1)->isStrictlyA()); - EXPECT_FALSE(result->axis(1)->isA()); + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to reduce along the component dimension (axis 0) + // This should throw an error because reducing component dimensions is not + // allowed The component dimension defines the batch structure of the ragged + // tensor, and reducing it would destroy the ragged structure + EXPECT_THROW(sum(nested, {0}), nvfuser::nvfError); } } // namespace nvfuser From 566d63de9021698a38032f1f47af4ac178673204 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Dec 2025 19:23:55 -0800 Subject: [PATCH 21/47] WIP --- csrc/ir/internal_base_nodes.cpp | 90 ++++++++++++++++++++++++++++++++- csrc/ir/internal_base_nodes.h | 6 +++ csrc/ops/alias.cpp | 50 +++++++++++++++++- csrc/ops/indexing.cpp | 50 ++++++++++++++++++ 4 files changed, 193 insertions(+), 3 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 3b7c31a89e9..8897ac94e8e 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -48,7 +48,11 @@ IterDomainBuilder::IterDomainBuilder(const IterDomain* id) is_rfactor_domain_(id->isRFactorProduct()), is_padded_dimension_(id->hasPaddingToMultipleOfWarp()), is_clustered_dimension_(id->isClusteredBlockDim()), - padded_to_size_(id->getMaybeSizeAfterPadding()) {} + padded_to_size_(id->getMaybeSizeAfterPadding()) { + if (id->isA()) { + ragged_extents_ = id->as()->extents(); + } +} IterDomainBuilder& IterDomainBuilder::resetSchedulingParams() { 
parallel_type_ = ParallelType::Serial; @@ -116,7 +120,13 @@ IterDomain* IterDomainBuilder::build() const { NVF_ERROR( start_ != nullptr && extent_ != nullptr, "Start and extent are required to build an iter domain."); - return IrBuilder::createInContainer(start_->container(), *this); + + if (ragged_extents_ != nullptr) { + return IrBuilder::createInContainer( + start_->container(), *this); + } else { + return IrBuilder::createInContainer(start_->container(), *this); + } } IterDomain::IterDomain( @@ -604,6 +614,11 @@ IterDomain* IterDomain::resize( "Non-zero stop offset not considered: ", in->toString()); + NVF_CHECK( + !in->isA(), + "Resizing RaggedIterDomain is not supported: ", + in->toString()); + // The overall extent is (in_extent + left_expansion + // right_expansion). This can be simplified for a slice op as // the right expansion should look like (slice_end_offset - @@ -815,6 +830,77 @@ void validateLoopDomain( } // namespace +RaggedIterDomain::RaggedIterDomain( + IrBuilderPasskey passkey, + const IterDomainBuilder& args) + : IterDomain( + passkey, + ValType::RaggedIterDomain, + args.start_, + args.extent_, + args.expanded_extent_, + args.stop_offset_, + args.parallel_type_, + args.iter_type_, + args.is_rfactor_domain_, + args.is_padded_dimension_, + args.is_clustered_dimension_, + args.padded_to_size_), + extents_(args.ragged_extents_) { + // Extents must be non-null + NVF_ERROR( + extents_ != nullptr, "RaggedIterDomain requires non-null extents tensor"); + + // Extents must have integer dtype + NVF_ERROR_EQ( + extents_->dtype(), + DataType::Index, + "RaggedIterDomain extents must have index type, got ", + extents_->dtype()); + + // Only IterType::Iteration is supported at this moment + NVF_ERROR_EQ( + iter_type_, + IterType::Iteration, + "Only IterType::Iteration is supported: ", + iter_type_); + + // RaggedIterDomain has specific requirements on member values + NVF_ERROR( + start_->isZeroInt(), + "RaggedIterDomain start must be zero, got: ", + 
start_->toInlineString()); + + NVF_ERROR( + extent_->isOneInt(), + "RaggedIterDomain extent must be one (placeholder), got: ", + extent_->toInlineString()); + + NVF_ERROR( + expanded_extent_ == nullptr, + "RaggedIterDomain does not support expanded_extent"); + + NVF_ERROR( + stop_offset_ == nullptr || stop_offset_->isZeroInt(), + "RaggedIterDomain stop_offset must be nullptr or zero, got: ", + stop_offset_ ? stop_offset_->toInlineString() : "nullptr"); + + NVF_ERROR( + !is_rfactor_domain_, "RaggedIterDomain does not support rfactor domains"); + + NVF_ERROR( + !is_padded_dimension_, + "RaggedIterDomain does not support padded dimensions"); + + NVF_ERROR( + !is_clustered_dimension_, + "RaggedIterDomain does not support clustered dimensions"); + + NVF_ERROR( + !padded_to_size_.has_value(), + "RaggedIterDomain does not support padded_to_size"); +} + RaggedIterDomain::RaggedIterDomain( IrBuilderPasskey passkey, TensorView* extents, diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index d56d4d21470..1f3e6658e01 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -61,6 +61,7 @@ class IterDomainBuilder { IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain); IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension); IterDomainBuilder& padded_to_size(std::optional _padded_to_size); + IterDomainBuilder& ragged_extents(TensorView* _ragged_extents); IterDomain* build() const; @@ -79,6 +80,9 @@ class IterDomainBuilder { bool is_padded_dimension_ = false; bool is_clustered_dimension_ = false; std::optional padded_to_size_ = std::nullopt; + + // For RaggedIterDomain: stores the extents tensor + TensorView* ragged_extents_ = nullptr; }; //! Simply a representation of an annotated 1D iterable from start to extent. @@ -448,6 +452,8 @@ class NVF_API IterDomain : public Val { //! 
components class NVF_API RaggedIterDomain : public IterDomain { public: + RaggedIterDomain(IrBuilderPasskey passkey, const IterDomainBuilder& args); + //! \param extents TensorView containing component extents (must be integer //! type) //! \param iter_type Iteration type (Iteration, Reduction, etc.) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 870a1c186a3..16298a04d37 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -42,6 +42,13 @@ TensorView* segment_set(TensorView* tv) { TensorView* view(TensorView* x, DataType dtype) { NVF_ERROR(x != nullptr, "Input is invalid."); + + NVF_CHECK( + !x->domain()->hasRaggedIterDomain(), + "View operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + x->toString()); + if (x->getDataType() == dtype) { return x; } @@ -142,6 +149,12 @@ TensorView* reshape(TensorView* inp_tv, const std::vector& new_sizes) { "Unsupported input tensor to reshape as its axes may be partial: ", inp_tv->toString()); + NVF_CHECK( + !inp_tv->domain()->hasRaggedIterDomain(), + "Reshape operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + inp_tv->toString()); + auto static_reshape_output = tryStaticReshape(inp_tv, inp_dom, new_sizes); if (static_reshape_output) { return static_reshape_output; @@ -239,6 +252,12 @@ TensorView* flatten(TensorView* x, int64_t start_dim, int64_t end_dim) { end_dim); NVF_CHECK(start_dim <= end_dim, "start_dim must be <= end_dim"); + NVF_CHECK( + !x->domain()->hasRaggedIterDomain(), + "Flatten operation is not supported for tensors with RaggedIterDomain. 
" + "Input tensor: ", + x->toString()); + if (start_dim == end_dim) { return x; } @@ -518,6 +537,11 @@ TensorView* pad( const std::vector& pad_widths, Val* value, std::optional iter_type_opt) { + NVF_CHECK( + !inp->domain()->hasRaggedIterDomain(), + "Padding a tensor with a RaggedIterDomain not supported: ", + inp->toString()); + DataType dt = inp->getDataType().value(); if (!value) { // Create a zero of the appropriate type @@ -623,6 +647,13 @@ TensorView* cat( std::optional iter_type_opt, bool manual_padding) { NVF_CHECK(!inputs.empty(), "No input tensor given"); + NVF_CHECK( + std::ranges ::none_of( + inputs, + [](TensorView* inp_tv) { + return inp_tv->domain()->hasRaggedIterDomain(); + }), + "Concatenating a tensor with a RaggedIterDomain not supported"); const auto dtype = inputs.at(0)->getDataType().value(); @@ -783,7 +814,12 @@ TensorView* slice( NVF_CHECK_EQ( ndims, std::ssize(ranges), - "The range vector must have the same number of Slice descriptors.") + "The range vector must have the same number of Slice descriptors."); + + NVF_CHECK( + !inp->domain()->hasRaggedIterDomain(), + "Slicing a tensor with a RaggedIterDomain not supported: ", + inp->toString()); ExpressionEvaluator expr_eval; @@ -1058,6 +1094,12 @@ TensorView* broadcast( TensorView* expand(TensorView* inp, const std::vector& expanded_sizes) { auto inp_domain = TensorDomain::noReductions(inp->getLogicalDomain()); + NVF_CHECK( + !inp->domain()->hasRaggedIterDomain(), + "Expand operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + inp->toString()); + NVF_CHECK_GE(expanded_sizes.size(), inp_domain.size()); inp = ops::maybe_broadcast_inner_to_rank(inp, expanded_sizes.size()); @@ -1180,6 +1222,12 @@ TensorView* expand_as(TensorView* inp, TensorView* other) { TensorView* repeat( TensorView* inp_tv, const std::vector& repeat_times) { + NVF_CHECK( + !inp_tv->domain()->hasRaggedIterDomain(), + "Repeat operation is not supported for tensors with RaggedIterDomain. 
" + "Input tensor: ", + inp_tv->toString()); + const auto ndims = TensorDomain::noReductions(inp_tv->getLogicalDomain()).size(); diff --git a/csrc/ops/indexing.cpp b/csrc/ops/indexing.cpp index fb2f1b3feda..a28ca67f72b 100644 --- a/csrc/ops/indexing.cpp +++ b/csrc/ops/indexing.cpp @@ -24,6 +24,12 @@ TensorView* select(TensorView* tv, int64_t dim, Val* index) { auto dom = TensorDomain::noReductions(tv->getLogicalDomain()); NVF_CHECK(!dom.empty(), "select can not be applied to 0d tensor."); + NVF_CHECK( + !tv->domain()->hasRaggedIterDomain(), + "Select operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + tv->toString()); + std::vector new_root; new_root.reserve(dom.size() - 1); dim = wrapDim(dim, (int64_t)dom.size()); @@ -46,6 +52,18 @@ TensorView* indexSelect( TensorView* lookup_tv, int64_t dim, TensorView* index_tv) { + NVF_CHECK( + !lookup_tv->domain()->hasRaggedIterDomain(), + "IndexSelect operation is not supported for tensors with RaggedIterDomain. " + "Input tensor (lookup_tv): ", + lookup_tv->toString()); + + NVF_CHECK( + !index_tv->domain()->hasRaggedIterDomain(), + "IndexSelect operation is not supported for tensors with RaggedIterDomain. " + "Index tensor (index_tv): ", + index_tv->toString()); + DataType dtype = lookup_tv->getDataType().value(); NVF_CHECK( dtype != DataType::Null, "Invalid datatype provided for new value."); @@ -131,6 +149,18 @@ TensorView* indexPutAccumulate( // torch.gather TensorView* gather(TensorView* inp, int64_t dim, TensorView* index) { + NVF_CHECK( + !inp->domain()->hasRaggedIterDomain(), + "Gather operation is not supported for tensors with RaggedIterDomain. " + "Input tensor (inp): ", + inp->toString()); + + NVF_CHECK( + !index->domain()->hasRaggedIterDomain(), + "Gather operation is not supported for tensors with RaggedIterDomain. 
" + "Index tensor (index): ", + index->toString()); + auto inp_domain = TensorDomain::noReductions(inp->getLogicalDomain()); auto idx_domain = TensorDomain::noReductions(index->getLogicalDomain()); NVF_CHECK( @@ -168,6 +198,26 @@ TensorView* scatter( TensorView* index, Val* src, std::optional accumulate_op) { + NVF_CHECK( + !self->domain()->hasRaggedIterDomain(), + "Scatter operation is not supported for tensors with RaggedIterDomain. " + "Input tensor (self): ", + self->toString()); + + NVF_CHECK( + !index->domain()->hasRaggedIterDomain(), + "Scatter operation is not supported for tensors with RaggedIterDomain. " + "Index tensor (index): ", + index->toString()); + + if (src->isA()) { + NVF_CHECK( + !src->as()->domain()->hasRaggedIterDomain(), + "Scatter operation is not supported for tensors with RaggedIterDomain. " + "Source tensor (src): ", + src->toString()); + } + auto self_dom = TensorDomain::noReductions(self->getLogicalDomain()); auto idx_dom = TensorDomain::noReductions(index->getLogicalDomain()); From 144b206c988c824ab1bfbcf926d53e9bcc0c85f5 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Dec 2025 21:25:37 -0800 Subject: [PATCH 22/47] WIP --- csrc/ops/arith.cpp | 75 ++++++ tests/cpp/test_ragged_iter_domain.cpp | 336 ++++++++++++++++++++++++++ 2 files changed, 411 insertions(+) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index f6dceef7f1d..ce228a1d281 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -276,6 +276,12 @@ TensorView* randn_like(TensorView* tv, Val* philox_seed, Val* philox_offset) { isFloatingPointType(tv->dtype()), "input must have floating point type, but got ", tv->dtype()); + NVF_CHECK( + !tv->domain()->hasRaggedIterDomain(), + "randn_like operation is not supported for tensors with " + "RaggedIterDomain. " + "Input tensor: ", + tv->toString()); // Create a new output TV manually so that we carry over IterTypes, instead // of inferring them from the shape as we would if we used randn(). 
TensorView* out = ops::newOutputTV({tv}, tv->dtype()); @@ -303,6 +309,11 @@ TensorView* rand_like(TensorView* tv, Val* philox_seed, Val* philox_offset) { isFloatingPointType(tv->dtype()), "input must have floating point type, but got ", tv->dtype()); + NVF_CHECK( + !tv->domain()->hasRaggedIterDomain(), + "rand_like operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + tv->toString()); // Create a new output TV manually so that we carry over IterTypes, instead // of inferring them from the shape as we would if we used rand(). TensorView* out = ops::newOutputTV({tv}, tv->dtype()); @@ -339,6 +350,11 @@ TensorView* full( } TensorView* full_like(TensorView* tv, Val* fill_value, DataType dtype) { + NVF_CHECK( + !tv->domain()->hasRaggedIterDomain(), + "full_like operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + tv->toString()); fill_value = maybeCastOp(dtype, fill_value); TensorView* out = ops::newOutputTV({tv}, dtype); IrBuilder::create(out, fill_value); @@ -1575,6 +1591,31 @@ WelfordResult WelfordRaw( TensorView* init_avg, TensorView* init_var, Val* init_N) { + NVF_CHECK( + !tv->domain()->hasRaggedIterDomain(), + "WelfordRaw operation is not supported for tensors with " + "RaggedIterDomain. " + "Input tensor (tv): ", + tv->toString()); + + if (init_avg != nullptr) { + NVF_CHECK( + !init_avg->domain()->hasRaggedIterDomain(), + "WelfordRaw operation is not supported for tensors with " + "RaggedIterDomain. " + "Initial average tensor (init_avg): ", + init_avg->toString()); + } + + if (init_var != nullptr) { + NVF_CHECK( + !init_var->domain()->hasRaggedIterDomain(), + "WelfordRaw operation is not supported for tensors with " + "RaggedIterDomain. 
" + "Initial variance tensor (init_var): ", + init_var->toString()); + } + NVF_CHECK( TensorDomain::sameAs(tv->getLogicalDomain(), tv->getLoopDomain()), "Reducing a tensor once it's gone under transformations is not permitted " @@ -1645,6 +1686,28 @@ WelfordResult Welford( TensorView* init_avg, TensorView* init_var, Val* init_N) { + NVF_CHECK( + !tv->domain()->hasRaggedIterDomain(), + "Welford operation is not supported for tensors with RaggedIterDomain. " + "Input tensor (tv): ", + tv->toString()); + + if (init_avg != nullptr) { + NVF_CHECK( + !init_avg->domain()->hasRaggedIterDomain(), + "Welford operation is not supported for tensors with RaggedIterDomain. " + "Initial average tensor (init_avg): ", + init_avg->toString()); + } + + if (init_var != nullptr) { + NVF_CHECK( + !init_var->domain()->hasRaggedIterDomain(), + "Welford operation is not supported for tensors with RaggedIterDomain. " + "Initial variance tensor (init_var): ", + init_var->toString()); + } + NVF_CHECK( TensorDomain::sameAs(tv->getLogicalDomain(), tv->getLoopDomain()), "Reducing a tensor once it's gone under transformations is not permitted " @@ -1991,6 +2054,12 @@ TensorView* clamp(TensorView* in, Val* min_val, Val* max_val) { // sum_to operator TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { + NVF_CHECK( + !in->domain()->hasRaggedIterDomain(), + "sum_to operation is not supported for tensors with RaggedIterDomain. " + "Input tensor: ", + in->toString()); + const auto& logical = TensorDomain::noReductions(in->getLogicalDomain()); NVF_CHECK( @@ -2038,6 +2107,12 @@ TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { } TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { + NVF_CHECK( + !in->domain()->hasRaggedIterDomain(), + "sum_to operation is not supported for tensors with RaggedIterDomain. 
" + "Input tensor: ", + in->toString()); + const auto& logical = TensorDomain::noReductions(in->getLogicalDomain()); NVF_CHECK( diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 2fa9cf04d2d..a8e7deb4be6 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -768,6 +768,342 @@ TEST_F(RaggedIterDomainTest, ReductionOnRaggedDimError) { EXPECT_THROW(sum(nested, {1}), nvfuser::nvfError); } +// Test reshape with nested tensors - should error +TEST_F(RaggedIterDomainTest, ReshapeWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to reshape - this should throw an error because reshape is not + // supported for tensors with RaggedIterDomain + std::vector new_shape = { + IrBuilder::create(-1L, DataType::Index), nested->axis(2)->extent()}; + EXPECT_THROW(reshape(nested, new_shape), nvfuser::nvfError); +} + +// Test flatten with nested tensors - should error +TEST_F(RaggedIterDomainTest, FlattenWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to flatten - this should throw an error because flatten is not + // supported for tensors with RaggedIterDomain + EXPECT_THROW(flatten(nested, 0, 2), nvfuser::nvfError); +} + +// Test transpose with nested tensors +TEST_F(RaggedIterDomainTest, TransposeWithNestedTensors) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, 
DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Transpose ragged and dim1 dimensions + auto result = transpose(nested, 1, 2); + + fusion.addOutput(result); + + // Expected: [component, dim1, ragged] + // Should preserve RaggedIterDomain type + auto non_reduction_domain = + TensorDomain::noReductions(result->getLogicalDomain()); + + EXPECT_EQ(non_reduction_domain.size(), 3); + EXPECT_TRUE(non_reduction_domain[0]->isStrictlyA()); + EXPECT_TRUE(non_reduction_domain[1]->isStrictlyA()); + EXPECT_TRUE(non_reduction_domain[2]->isA()); +} + +// Test slice on ragged dimension - should error +TEST_F(RaggedIterDomainTest, SliceRaggedDimensionError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to slice the ragged dimension (axis 1) + // This should error because resize on RaggedIterDomain is not allowed + EXPECT_THROW( + slice( + nested, + {{fusion.zeroVal(), fusion.oneVal()}, + {fusion.zeroVal(), fusion.oneVal()}, + {fusion.zeroVal(), nested->axis(2)->extent()}}), + nvfuser::nvfError); +} + +// Test cat on ragged dimension - should error +TEST_F(RaggedIterDomainTest, CatRaggedDimensionError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data1 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data1); + + auto data2 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data2); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensors with same structure + auto nested1 = asNested(data1, offsets, 0); + auto nested2 = asNested(data2, 
offsets, 0); + + // Try to concatenate along ragged dimension (axis 1) + // This should error because cat would need to resize RaggedIterDomain + EXPECT_THROW(cat({nested1, nested2}, 1), nvfuser::nvfError); +} + +// Test cat on non-ragged dimension - currently also errors +TEST_F(RaggedIterDomainTest, CatNonRaggedDimensionError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data1 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data1); + + auto data2 = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data2); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensors with same structure + auto nested1 = asNested(data1, offsets, 0); + auto nested2 = asNested(data2, offsets, 0); + + // Try to concatenate along non-ragged dimension (axis 2) + // Currently cat rejects all tensors with RaggedIterDomain for safety + // In the future, this could be supported if concatenating along non-ragged + // dimensions + EXPECT_THROW(cat({nested1, nested2}, 2), nvfuser::nvfError); +} + +// Test expand with nested tensors - should error +TEST_F(RaggedIterDomainTest, ExpandWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to expand a broadcast dimension - should error + auto broadcasted = broadcast(nested, {false, false, false, true}); + EXPECT_THROW( + expand( + broadcasted, + {nested->axis(0)->extent(), + nested->axis(1)->extent(), + nested->axis(2)->extent(), + IrBuilder::create(5L, DataType::Index)}), + nvfuser::nvfError); +} + +// Test pad on ragged dimension - should error +TEST_F(RaggedIterDomainTest, PadRaggedDimensionError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = 
makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to pad the ragged dimension (axis 1) + // This should error because pad uses resize on RaggedIterDomain + std::vector pad_widths = { + fusion.zeroVal(), + fusion.zeroVal(), // component: no padding + fusion.oneVal(), + fusion.oneVal(), // ragged: PADDING - should error + fusion.zeroVal(), + fusion.zeroVal() // dim1: no padding + }; + + EXPECT_THROW(pad(nested, pad_widths), nvfuser::nvfError); +} + +// Test select with nested tensors - should error +TEST_F(RaggedIterDomainTest, SelectWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to select from a non-ragged dimension - should error + EXPECT_THROW(select(nested, 0, fusion.zeroVal()), nvfuser::nvfError); +} + +// Test gather with nested tensors - should error +TEST_F(RaggedIterDomainTest, GatherWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + auto index = makeSymbolicTensor(3, DataType::Index); + fusion.addInput(index); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to gather from nested tensor - should error + EXPECT_THROW(gather(nested, 2, index), nvfuser::nvfError); +} + +// Test view operations with nested tensors - should error +TEST_F(RaggedIterDomainTest, ViewWithNestedTensorsError) { + Fusion fusion; + 
FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to change dtype via view - should error + EXPECT_THROW(view(nested, DataType::Half), nvfuser::nvfError); +} + +// Test select (indexing) with nested tensors - should error +TEST_F(RaggedIterDomainTest, SelectIndexingWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to select from component dimension - should error + EXPECT_THROW(select(nested, 0, fusion.zeroVal()), nvfuser::nvfError); +} + +// Test index_select with nested tensors - should error +TEST_F(RaggedIterDomainTest, IndexSelectWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + auto indices = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(indices); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to index select from non-ragged dimension - should error + EXPECT_THROW(indexSelect(nested, 2, indices), nvfuser::nvfError); +} + +// Test scatter with nested tensors - should error +TEST_F(RaggedIterDomainTest, ScatterWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + auto src = 
makeSymbolicTensor(3, DataType::Float); + fusion.addInput(src); + + auto indices = makeSymbolicTensor(3, DataType::Index); + fusion.addInput(indices); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to scatter into nested tensor - should error + EXPECT_THROW(scatter(nested, 2, indices, src), nvfuser::nvfError); +} + +// Test repeat with nested tensors - should error +TEST_F(RaggedIterDomainTest, RepeatWithNestedTensorsError) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor: [component, ragged, dim1] + auto nested = asNested(data, offsets, 0); + + // Try to repeat along non-ragged dimension - should error + std::vector repeats = {1, 1, 2}; + EXPECT_THROW(repeat(nested, repeats), nvfuser::nvfError); +} + // Test reduction on component dimension - should error (TODO) TEST_F(RaggedIterDomainTest, ReductionOnComponentDimError) { GTEST_SKIP() << "TODO: Implement validation to prevent reduction of " From e2efe752bf4e77049509d2da1b9a9027c9798a45 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Dec 2025 22:47:29 -0800 Subject: [PATCH 23/47] cleanup --- csrc/ir/internal_base_nodes.cpp | 5 - csrc/ops/alias.cpp | 23 +- csrc/ops/arith.cpp | 93 +------ csrc/ops/utils.cpp | 3 + tests/cpp/test_ragged_iter_domain.cpp | 359 +++++++------------------- 5 files changed, 117 insertions(+), 366 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 8897ac94e8e..15336e269ee 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -614,11 +614,6 @@ IterDomain* IterDomain::resize( "Non-zero stop offset not considered: ", in->toString()); - NVF_CHECK( - !in->isA(), - "Resizing RaggedIterDomain is not supported: ", - in->toString()); - // The overall extent is 
(in_extent + left_expansion + // right_expansion). This can be simplified for a slice op as // the right expansion should look like (slice_end_offset - diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 16298a04d37..ea87ae1d73c 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -42,13 +42,6 @@ TensorView* segment_set(TensorView* tv) { TensorView* view(TensorView* x, DataType dtype) { NVF_ERROR(x != nullptr, "Input is invalid."); - - NVF_CHECK( - !x->domain()->hasRaggedIterDomain(), - "View operation is not supported for tensors with RaggedIterDomain. " - "Input tensor: ", - x->toString()); - if (x->getDataType() == dtype) { return x; } @@ -538,8 +531,8 @@ TensorView* pad( Val* value, std::optional iter_type_opt) { NVF_CHECK( - !inp->domain()->hasRaggedIterDomain(), - "Padding a tensor with a RaggedIterDomain not supported: ", + inp->domain()->hasRaggedIterDomain(), + "Padding a tensor with RaggedIterDomain not supported: ", inp->toString()); DataType dt = inp->getDataType().value(); @@ -647,13 +640,14 @@ TensorView* cat( std::optional iter_type_opt, bool manual_padding) { NVF_CHECK(!inputs.empty(), "No input tensor given"); + NVF_CHECK( - std::ranges ::none_of( + std::ranges::none_of( inputs, [](TensorView* inp_tv) { return inp_tv->domain()->hasRaggedIterDomain(); }), - "Concatenating a tensor with a RaggedIterDomain not supported"); + "Concat with a tensor with RaggedIterDomain not supported"); const auto dtype = inputs.at(0)->getDataType().value(); @@ -818,7 +812,7 @@ TensorView* slice( NVF_CHECK( !inp->domain()->hasRaggedIterDomain(), - "Slicing a tensor with a RaggedIterDomain not supported: ", + "Slicing a tensor with RaggedIterDomain not supported: ", inp->toString()); ExpressionEvaluator expr_eval; @@ -1328,6 +1322,11 @@ TensorView* asNested( 1, "asNested currently only supports 1D offset tensors"); + NVF_CHECK( + !data->domain()->hasRaggedIterDomain(), + "Multiple level of nesting is not supported: ", + data->toString()); + // Get the 
logical domain of the input, excluding reductions auto inp_logical = TensorDomain::noReductions(data->getLogicalDomain()); diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index ce228a1d281..89d6dfc3a43 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -276,12 +276,6 @@ TensorView* randn_like(TensorView* tv, Val* philox_seed, Val* philox_offset) { isFloatingPointType(tv->dtype()), "input must have floating point type, but got ", tv->dtype()); - NVF_CHECK( - !tv->domain()->hasRaggedIterDomain(), - "randn_like operation is not supported for tensors with " - "RaggedIterDomain. " - "Input tensor: ", - tv->toString()); // Create a new output TV manually so that we carry over IterTypes, instead // of inferring them from the shape as we would if we used randn(). TensorView* out = ops::newOutputTV({tv}, tv->dtype()); @@ -309,11 +303,6 @@ TensorView* rand_like(TensorView* tv, Val* philox_seed, Val* philox_offset) { isFloatingPointType(tv->dtype()), "input must have floating point type, but got ", tv->dtype()); - NVF_CHECK( - !tv->domain()->hasRaggedIterDomain(), - "rand_like operation is not supported for tensors with RaggedIterDomain. " - "Input tensor: ", - tv->toString()); // Create a new output TV manually so that we carry over IterTypes, instead // of inferring them from the shape as we would if we used rand(). TensorView* out = ops::newOutputTV({tv}, tv->dtype()); @@ -350,11 +339,6 @@ TensorView* full( } TensorView* full_like(TensorView* tv, Val* fill_value, DataType dtype) { - NVF_CHECK( - !tv->domain()->hasRaggedIterDomain(), - "full_like operation is not supported for tensors with RaggedIterDomain. 
" - "Input tensor: ", - tv->toString()); fill_value = maybeCastOp(dtype, fill_value); TensorView* out = ops::newOutputTV({tv}, dtype); IrBuilder::create(out, fill_value); @@ -1241,18 +1225,12 @@ TensorView* newForReduction( .iter_type(IterType::Reduction) .build(); } else { - // For non-reduced dimensions, preserve RaggedIterDomain if present - if (id->isA()) { - // Cast away const since cloneWithoutRFactor is not const - new_id = const_cast(id)->cloneWithoutRFactor(); - } else { - new_id = IterDomainBuilder(id) - .extent(id->extent()) - .resetSchedulingParams() - .parallel_type(id->getParallelType()) - .iter_type(id->getIterType()) - .build(); - } + new_id = IterDomainBuilder(id) + .extent(id->extent()) + .resetSchedulingParams() + .parallel_type(id->getParallelType()) + .iter_type(id->getIterType()) + .build(); } new_domain.push_back(new_id); } @@ -1591,31 +1569,6 @@ WelfordResult WelfordRaw( TensorView* init_avg, TensorView* init_var, Val* init_N) { - NVF_CHECK( - !tv->domain()->hasRaggedIterDomain(), - "WelfordRaw operation is not supported for tensors with " - "RaggedIterDomain. " - "Input tensor (tv): ", - tv->toString()); - - if (init_avg != nullptr) { - NVF_CHECK( - !init_avg->domain()->hasRaggedIterDomain(), - "WelfordRaw operation is not supported for tensors with " - "RaggedIterDomain. " - "Initial average tensor (init_avg): ", - init_avg->toString()); - } - - if (init_var != nullptr) { - NVF_CHECK( - !init_var->domain()->hasRaggedIterDomain(), - "WelfordRaw operation is not supported for tensors with " - "RaggedIterDomain. 
" - "Initial variance tensor (init_var): ", - init_var->toString()); - } - NVF_CHECK( TensorDomain::sameAs(tv->getLogicalDomain(), tv->getLoopDomain()), "Reducing a tensor once it's gone under transformations is not permitted " @@ -1686,28 +1639,6 @@ WelfordResult Welford( TensorView* init_avg, TensorView* init_var, Val* init_N) { - NVF_CHECK( - !tv->domain()->hasRaggedIterDomain(), - "Welford operation is not supported for tensors with RaggedIterDomain. " - "Input tensor (tv): ", - tv->toString()); - - if (init_avg != nullptr) { - NVF_CHECK( - !init_avg->domain()->hasRaggedIterDomain(), - "Welford operation is not supported for tensors with RaggedIterDomain. " - "Initial average tensor (init_avg): ", - init_avg->toString()); - } - - if (init_var != nullptr) { - NVF_CHECK( - !init_var->domain()->hasRaggedIterDomain(), - "Welford operation is not supported for tensors with RaggedIterDomain. " - "Initial variance tensor (init_var): ", - init_var->toString()); - } - NVF_CHECK( TensorDomain::sameAs(tv->getLogicalDomain(), tv->getLoopDomain()), "Reducing a tensor once it's gone under transformations is not permitted " @@ -2054,12 +1985,6 @@ TensorView* clamp(TensorView* in, Val* min_val, Val* max_val) { // sum_to operator TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { - NVF_CHECK( - !in->domain()->hasRaggedIterDomain(), - "sum_to operation is not supported for tensors with RaggedIterDomain. " - "Input tensor: ", - in->toString()); - const auto& logical = TensorDomain::noReductions(in->getLogicalDomain()); NVF_CHECK( @@ -2107,12 +2032,6 @@ TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { } TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { - NVF_CHECK( - !in->domain()->hasRaggedIterDomain(), - "sum_to operation is not supported for tensors with RaggedIterDomain. 
" - "Input tensor: ", - in->toString()); - const auto& logical = TensorDomain::noReductions(in->getLogicalDomain()); NVF_CHECK( diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 2623f4c68be..be50385528c 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -326,6 +326,9 @@ RaggedIterDomain* newOutputRaggedIterDomain( "All input iter domains must be RaggedIterDomain"); NVF_ERROR(!input_ids.empty()); + + // Just using the first ragged ID as all input IDs are assumed to be + // equivalent RaggedIterDomain* ref_input_id = input_ids.front()->as(); return IrBuilder::create( diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index a8e7deb4be6..b8406303c01 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -354,48 +354,36 @@ TEST_F(RaggedIterDomainTest, AsNestedBasic) { // Create nested tensor from dimension 0 auto nested = asNested(data, offsets, 0); - // This should still be a nested tensor - auto copy_of_nested = set(nested); + fusion.addOutput(nested); - fusion.addOutput(copy_of_nested); + // Verify the output is a new TensorView + EXPECT_TRUE(nested != nullptr); + EXPECT_NE(nested, data); + EXPECT_TRUE(nested->isA()); + + // Verify nested tensor has 3 dimensions: [component, ragged, + // original_dim1] + EXPECT_EQ(nested->nDims(), 3); + + // First axis should be a regular IterDomain (component) + EXPECT_TRUE(nested->axis(0)->isStrictlyA()); + EXPECT_FALSE(nested->axis(0)->isA()); - for (auto nested_tv : {nested, copy_of_nested}) { - // Verify the output is a new TensorView - EXPECT_TRUE(nested_tv != nullptr); - EXPECT_NE(nested_tv, data); - EXPECT_TRUE(nested_tv->isA()); - - // Verify nested_tv tensor has 3 dimensions: [component, ragged, - // original_dim1] - EXPECT_EQ(nested_tv->nDims(), 3); - - // First axis should be a regular IterDomain (component) - EXPECT_TRUE(nested_tv->axis(0)->isStrictlyA()); - EXPECT_FALSE(nested_tv->axis(0)->isA()); - - // Second axis 
should be a RaggedIterDomain - EXPECT_TRUE(nested_tv->axis(1)->isA()); - - // Third axis should be the original second dimension - EXPECT_TRUE(nested_tv->axis(2)->isStrictlyA()); - - if (nested_tv == nested) { - // Verify the definition exists (LoadStoreOp for aliasing) - EXPECT_TRUE(nested_tv->definition() != nullptr); - EXPECT_TRUE(nested_tv->definition()->isA()); - - // Verify the component and ragged IterDomains have Partition as their - // definition - EXPECT_TRUE(nested_tv->axis(0)->definition() != nullptr); - EXPECT_TRUE(nested_tv->axis(0)->definition()->isA()); - EXPECT_EQ( - nested_tv->axis(0)->definition(), nested_tv->axis(1)->definition()); - } else { - // The copy of the original nested tensor does not inherit the Partition - // op - EXPECT_TRUE(nested_tv->axis(0)->definition() == nullptr); - } - } + // Second axis should be a RaggedIterDomain + EXPECT_TRUE(nested->axis(1)->isA()); + + // Third axis should be the original second dimension + EXPECT_TRUE(nested->axis(2)->isStrictlyA()); + + // Verify the definition exists (LoadStoreOp for aliasing) + EXPECT_TRUE(nested->definition() != nullptr); + EXPECT_TRUE(nested->definition()->isA()); + + // Verify the component and ragged IterDomains have Partition as their + // definition + EXPECT_TRUE(nested->axis(0)->definition() != nullptr); + EXPECT_TRUE(nested->axis(0)->definition()->isA()); + EXPECT_EQ(nested->axis(0)->definition(), nested->axis(1)->definition()); } // asNested on different dimensions @@ -495,6 +483,48 @@ TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimOffsets) { EXPECT_THROW(asNested(data, offsets_2d, 0), nvfuser::nvfError); } +TEST_F(RaggedIterDomainTest, LoadStoreWithNestedTensor) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto offsets = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(offsets); + + // Create nested tensor from dimension 0 + auto nested = asNested(data, offsets, 0); + + 
// This should still be a nested tensor + auto copy_of_nested = set(nested); + + fusion.addOutput(copy_of_nested); + + // Verify the output is a new TensorView + EXPECT_TRUE(copy_of_nested != nullptr); + EXPECT_NE(copy_of_nested, data); + EXPECT_TRUE(copy_of_nested->isA()); + + // Verify copy_of_nested tensor has 3 dimensions: [component, ragged, + // original_dim1] + EXPECT_EQ(copy_of_nested->nDims(), 3); + + // First axis should be a regular IterDomain (component) + EXPECT_TRUE(copy_of_nested->axis(0)->isStrictlyA()); + EXPECT_FALSE(copy_of_nested->axis(0)->isA()); + + // Second axis should be a RaggedIterDomain + EXPECT_TRUE(copy_of_nested->axis(1)->isA()); + + // Third axis should be the original second dimension + EXPECT_TRUE(copy_of_nested->axis(2)->isStrictlyA()); + + // The copy of the original nested tensor does not inherit the + // Partition op + EXPECT_TRUE(copy_of_nested->axis(0)->definition() == nullptr); +} + // Test binary operations with nested tensors TEST_F(RaggedIterDomainTest, BinaryOpWithNestedTensors) { Fusion fusion; @@ -514,7 +544,7 @@ TEST_F(RaggedIterDomainTest, BinaryOpWithNestedTensors) { auto nested1 = asNested(data1, offsets, 0); auto nested2 = asNested(data2, offsets, 0); - // Perform binary operation: add + // Perform binary operation. The result should be a nested tensor auto result = add(nested1, nested2); fusion.addOutput(result); @@ -577,8 +607,8 @@ TEST_F(RaggedIterDomainTest, BinaryOpDifferentRaggedStructures) { auto nested2 = asNested(data2, offsets2, 0); // This would be an error if, for example, the values of the offset - // tensors are not equivalent, but, like normal tensors, we assume - // that is indeed the case. 
+ // tensors are not equivalent, but, like binary ops with normal + // tensors, we assume their shapes are indeed compatible auto result = add(nested1, nested2); fusion.addOutput(result); @@ -768,8 +798,14 @@ TEST_F(RaggedIterDomainTest, ReductionOnRaggedDimError) { EXPECT_THROW(sum(nested, {1}), nvfuser::nvfError); } -// Test reshape with nested tensors - should error -TEST_F(RaggedIterDomainTest, ReshapeWithNestedTensorsError) { +// Test reduction on component dimension - should error (TODO) +TEST_F(RaggedIterDomainTest, ReductionOnComponentDimError) { + GTEST_SKIP() << "TODO: Implement validation to prevent reduction of " + "component dimension. " + << "Currently there is no explicit marking of which IterDomains " + "are component dimensions, " + << "so this validation cannot be implemented yet."; + Fusion fusion; FusionGuard fg(&fusion); @@ -782,15 +818,15 @@ TEST_F(RaggedIterDomainTest, ReshapeWithNestedTensorsError) { // Create nested tensor: [component, ragged, dim1] auto nested = asNested(data, offsets, 0); - // Try to reshape - this should throw an error because reshape is not - // supported for tensors with RaggedIterDomain - std::vector new_shape = { - IrBuilder::create(-1L, DataType::Index), nested->axis(2)->extent()}; - EXPECT_THROW(reshape(nested, new_shape), nvfuser::nvfError); + // Try to reduce along the component dimension (axis 0). + // This should throw an error because reducing component dimensions is not + // allowed. The component dimension defines the batch structure of the ragged + // tensor, and reducing it would destroy the ragged structure. 
+ EXPECT_THROW(sum(nested, {0}), nvfuser::nvfError); } -// Test flatten with nested tensors - should error -TEST_F(RaggedIterDomainTest, FlattenWithNestedTensorsError) { +// Test reshape with nested tensors - should error +TEST_F(RaggedIterDomainTest, ReshapeWithNestedTensorsError) { Fusion fusion; FusionGuard fg(&fusion); @@ -803,13 +839,15 @@ TEST_F(RaggedIterDomainTest, FlattenWithNestedTensorsError) 
{ // Create nested tensor: [component, ragged, dim1] auto nested = asNested(data, offsets, 0); - // Try to flatten - this should throw an error because flatten is not + // Try to reshape - this should throw an error because reshape is not // supported for tensors with RaggedIterDomain - EXPECT_THROW(flatten(nested, 0, 2), nvfuser::nvfError); + std::vector new_shape = { + IrBuilder::create(-1L, DataType::Index), nested->axis(2)->extent()}; + EXPECT_THROW(reshape(nested, new_shape), nvfuser::nvfError); } -// Test transpose with nested tensors -TEST_F(RaggedIterDomainTest, TransposeWithNestedTensors) { +// Test flatten with nested tensors - should error +TEST_F(RaggedIterDomainTest, FlattenWithNestedTensorsError) { Fusion fusion; FusionGuard fg(&fusion); @@ -822,20 +860,9 @@ TEST_F(RaggedIterDomainTest, TransposeWithNestedTensors) { // Create nested tensor: [component, ragged, dim1] auto nested = asNested(data, offsets, 0); - // Transpose ragged and dim1 dimensions - auto result = transpose(nested, 1, 2); - - fusion.addOutput(result); - - // Expected: [component, dim1, ragged] - // Should preserve RaggedIterDomain type - auto non_reduction_domain = - TensorDomain::noReductions(result->getLogicalDomain()); - - EXPECT_EQ(non_reduction_domain.size(), 3); - EXPECT_TRUE(non_reduction_domain[0]->isStrictlyA()); - EXPECT_TRUE(non_reduction_domain[1]->isStrictlyA()); - EXPECT_TRUE(non_reduction_domain[2]->isA()); + // Try to flatten - this should throw an error because flatten is not + // supported for tensors with RaggedIterDomain + EXPECT_THROW(flatten(nested, 0, 2), nvfuser::nvfError); } // Test slice on ragged dimension - should error @@ -911,32 +938,6 @@ TEST_F(RaggedIterDomainTest, CatNonRaggedDimensionError) { EXPECT_THROW(cat({nested1, nested2}, 2), nvfuser::nvfError); } -// Test expand with nested tensors - should error -TEST_F(RaggedIterDomainTest, ExpandWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, 
DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to expand a broadcast dimension - should error - auto broadcasted = broadcast(nested, {false, false, false, true}); - EXPECT_THROW( - expand( - broadcasted, - {nested->axis(0)->extent(), - nested->axis(1)->extent(), - nested->axis(2)->extent(), - IrBuilder::create(5L, DataType::Index)}), - nvfuser::nvfError); -} - // Test pad on ragged dimension - should error TEST_F(RaggedIterDomainTest, PadRaggedDimensionError) { Fusion fusion; @@ -965,170 +966,4 @@ TEST_F(RaggedIterDomainTest, PadRaggedDimensionError) { EXPECT_THROW(pad(nested, pad_widths), nvfuser::nvfError); } -// Test select with nested tensors - should error -TEST_F(RaggedIterDomainTest, SelectWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to select from a non-ragged dimension - should error - EXPECT_THROW(select(nested, 0, fusion.zeroVal()), nvfuser::nvfError); -} - -// Test gather with nested tensors - should error -TEST_F(RaggedIterDomainTest, GatherWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - auto index = makeSymbolicTensor(3, DataType::Index); - fusion.addInput(index); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to gather from nested tensor - should error - EXPECT_THROW(gather(nested, 2, index), nvfuser::nvfError); 
-} - -// Test view operations with nested tensors - should error -TEST_F(RaggedIterDomainTest, ViewWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to change dtype via view - should error - EXPECT_THROW(view(nested, DataType::Half), nvfuser::nvfError); -} - -// Test select (indexing) with nested tensors - should error -TEST_F(RaggedIterDomainTest, SelectIndexingWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to select from component dimension - should error - EXPECT_THROW(select(nested, 0, fusion.zeroVal()), nvfuser::nvfError); -} - -// Test index_select with nested tensors - should error -TEST_F(RaggedIterDomainTest, IndexSelectWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - auto indices = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(indices); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to index select from non-ragged dimension - should error - EXPECT_THROW(indexSelect(nested, 2, indices), nvfuser::nvfError); -} - -// Test scatter with nested tensors - should error -TEST_F(RaggedIterDomainTest, ScatterWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, 
DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - auto src = makeSymbolicTensor(3, DataType::Float); - fusion.addInput(src); - - auto indices = makeSymbolicTensor(3, DataType::Index); - fusion.addInput(indices); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to scatter into nested tensor - should error - EXPECT_THROW(scatter(nested, 2, indices, src), nvfuser::nvfError); -} - -// Test repeat with nested tensors - should error -TEST_F(RaggedIterDomainTest, RepeatWithNestedTensorsError) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to repeat along non-ragged dimension - should error - std::vector repeats = {1, 1, 2}; - EXPECT_THROW(repeat(nested, repeats), nvfuser::nvfError); -} - -// Test reduction on component dimension - should error (TODO) -TEST_F(RaggedIterDomainTest, ReductionOnComponentDimError) { - GTEST_SKIP() << "TODO: Implement validation to prevent reduction of " - "component dimension. 
" - << "Currently there is no explicit marking of which IterDomains " - "are component dimensions, " - << "so this validation cannot be implemented yet."; - - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); - - // Create nested tensor: [component, ragged, dim1] - auto nested = asNested(data, offsets, 0); - - // Try to reduce along the component dimension (axis 0) - // This should throw an error because reducing component dimensions is not - // allowed The component dimension defines the batch structure of the ragged - // tensor, and reducing it would destroy the ragged structure - EXPECT_THROW(sum(nested, {0}), nvfuser::nvfError); -} - } // namespace nvfuser From 0b68d6b4517cafb745aa00c9cd81bd8111720a22 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Dec 2025 23:01:33 -0800 Subject: [PATCH 24/47] cleanup --- csrc/ir/internal_base_nodes.cpp | 46 +++++++++++++++++++-------------- csrc/ir/internal_base_nodes.h | 22 ++++++++-------- csrc/ops/alias.cpp | 2 +- csrc/ops/indexing.cpp | 6 +++-- 4 files changed, 42 insertions(+), 34 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 15336e269ee..7ffb35739a8 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -30,8 +30,8 @@ namespace nvfuser { -IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent) - : start_(_start), extent_(_extent) { +IterDomainBuilder::IterDomainBuilder(Val* start, Val* extent) + : start_(start), extent_(extent) { NVF_ERROR( start_ != nullptr && extent_ != nullptr, "Start and extent are required to build an iter domain."); @@ -67,52 +67,58 @@ IterDomainBuilder& IterDomainBuilder::resetRfactor() { return is_rfactor_domain(false); } -IterDomainBuilder& IterDomainBuilder::start(Val* _start) { - start_ = _start; +IterDomainBuilder& 
IterDomainBuilder::start(Val* start) { + start_ = start; return *this; } -IterDomainBuilder& IterDomainBuilder::extent(Val* _extent) { - extent_ = _extent; +IterDomainBuilder& IterDomainBuilder::extent(Val* extent) { + extent_ = extent; return *this; } -IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* _expanded_extent) { - expanded_extent_ = _expanded_extent; +IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* expanded_extent) { + expanded_extent_ = expanded_extent; return *this; } -IterDomainBuilder& IterDomainBuilder::stop_offset(Val* _stop_offset) { - stop_offset_ = _stop_offset; +IterDomainBuilder& IterDomainBuilder::stop_offset(Val* stop_offset) { + stop_offset_ = stop_offset; return *this; } IterDomainBuilder& IterDomainBuilder::parallel_type( - ParallelType _parallel_type) { - parallel_type_ = _parallel_type; + ParallelType parallel_type) { + parallel_type_ = parallel_type; return *this; } -IterDomainBuilder& IterDomainBuilder::iter_type(IterType _iter_type) { - iter_type_ = _iter_type; +IterDomainBuilder& IterDomainBuilder::iter_type(IterType iter_type) { + iter_type_ = iter_type; return *this; } IterDomainBuilder& IterDomainBuilder::is_rfactor_domain( - bool _is_rfactor_domain) { - is_rfactor_domain_ = _is_rfactor_domain; + bool is_rfactor_domain) { + is_rfactor_domain_ = is_rfactor_domain; return *this; } IterDomainBuilder& IterDomainBuilder::is_padded_dimension( - bool _is_padded_dimension) { - is_padded_dimension_ = _is_padded_dimension; + bool is_padded_dimension) { + is_padded_dimension_ = is_padded_dimension; return *this; } IterDomainBuilder& IterDomainBuilder::padded_to_size( - std::optional _padded_to_size) { - padded_to_size_ = _padded_to_size; + std::optional padded_to_size) { + padded_to_size_ = padded_to_size; + return *this; +} + +IterDomainBuilder& IterDomainBuilder::ragged_extents( + TensorView* ragged_extents) { + ragged_extents_ = ragged_extents; return *this; } diff --git a/csrc/ir/internal_base_nodes.h 
b/csrc/ir/internal_base_nodes.h index 1f3e6658e01..84a2e7686be 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -40,7 +40,7 @@ struct AnalyzeViewResult; class IterDomainBuilder { public: // Match legacy constructor - IterDomainBuilder(Val* _start, Val* _extent); + IterDomainBuilder(Val* start, Val* extent); // Grab all the parameters from id to set the IterDomainBuilder IterDomainBuilder(const IterDomain* id); @@ -52,16 +52,16 @@ class IterDomainBuilder { // Resets is_rfactor_domain IterDomainBuilder& resetRfactor(); - IterDomainBuilder& start(Val* _start); - IterDomainBuilder& extent(Val* _extent); - IterDomainBuilder& expanded_extent(Val* _expanded_extent); - IterDomainBuilder& stop_offset(Val* _stop_offset); - IterDomainBuilder& parallel_type(ParallelType _parallel_type); - IterDomainBuilder& iter_type(IterType _iter_type); - IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain); - IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension); - IterDomainBuilder& padded_to_size(std::optional _padded_to_size); - IterDomainBuilder& ragged_extents(TensorView* _ragged_extents); + IterDomainBuilder& start(Val* start); + IterDomainBuilder& extent(Val* extent); + IterDomainBuilder& expanded_extent(Val* expanded_extent); + IterDomainBuilder& stop_offset(Val* stop_offset); + IterDomainBuilder& parallel_type(ParallelType parallel_type); + IterDomainBuilder& iter_type(IterType iter_type); + IterDomainBuilder& is_rfactor_domain(bool is_rfactor_domain); + IterDomainBuilder& is_padded_dimension(bool is_padded_dimension); + IterDomainBuilder& padded_to_size(std::optional padded_to_size); + IterDomainBuilder& ragged_extents(TensorView* ragged_extents); IterDomain* build() const; diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index ea87ae1d73c..4a3609f8b28 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -531,7 +531,7 @@ TensorView* pad( Val* value, std::optional iter_type_opt) { NVF_CHECK( - 
inp->domain()->hasRaggedIterDomain(), + !inp->domain()->hasRaggedIterDomain(), "Padding a tensor with RaggedIterDomain not supported: ", inp->toString()); diff --git a/csrc/ops/indexing.cpp b/csrc/ops/indexing.cpp index a28ca67f72b..b8975c9ac66 100644 --- a/csrc/ops/indexing.cpp +++ b/csrc/ops/indexing.cpp @@ -54,13 +54,15 @@ TensorView* indexSelect( TensorView* index_tv) { NVF_CHECK( !lookup_tv->domain()->hasRaggedIterDomain(), - "IndexSelect operation is not supported for tensors with RaggedIterDomain. " + "IndexSelect operation is not supported for tensors with " + "RaggedIterDomain. " "Input tensor (lookup_tv): ", lookup_tv->toString()); NVF_CHECK( !index_tv->domain()->hasRaggedIterDomain(), - "IndexSelect operation is not supported for tensors with RaggedIterDomain. " + "IndexSelect operation is not supported for tensors with " + "RaggedIterDomain. " "Index tensor (index_tv): ", index_tv->toString()); From 8a73bb2f76729c9b5973365a20a8c7829f4bb1bd Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 18 Dec 2025 10:00:40 -0800 Subject: [PATCH 25/47] cleanup --- csrc/ir/interface_nodes.h | 11 +++--- csrc/ir/internal_base_nodes.cpp | 54 ++++++++------------------- csrc/ir/internal_base_nodes.h | 19 +++++----- csrc/ir/internal_nodes.cpp | 10 ++++- csrc/tensor_view.cpp | 10 ++--- tests/cpp/test_ragged_iter_domain.cpp | 50 ++++++++++++------------- 6 files changed, 69 insertions(+), 85 deletions(-) diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index ea0236a1a7b..0128c937ed4 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -619,14 +619,13 @@ class NVF_API TensorView : public Val { return merge(axis, axis + 1); } - // Partition "axis" into component and ragged dimensions based on offsets - // The offsets tensor defines partition boundaries where: - // Shape: [num_components + 1], values: [0, off1, off2, ..., total] - // Extents are computed as: extents[i] = offsets[i+1] - offsets[i] + // Partition "axis" into component 
and ragged dimensions based on extents + // The extents tensor directly specifies the size of each component: + // Shape: [num_components], values: [extent0, extent1, ..., extent(n-1)] // Returns this TensorView with the axis replaced by component and ragged dims - // e.g. partition(0, offsets) on tv[id{N}] results in: + // e.g. partition(0, extents) on tv[id{N}] results in: // tv[id{num_components}, ragged_id{extents}] - TensorView* partition(int64_t axis, TensorView* offsets); + TensorView* partition(int64_t axis, TensorView* extents); // Flatten the axis from `from` to `to` into a single axis. // Both `from` and `to` are inclusive. diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index cc068fbfca2..292c00a8740 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -897,7 +897,7 @@ std::string RaggedIterDomain::toString(int indent_size) const { std::pair RaggedIterDomain::partition( IterDomain* in, - TensorView* offsets) { + TensorView* extents) { NVF_ERROR(in != nullptr, "partition: input IterDomain is null"); NVF_ERROR( @@ -918,52 +918,28 @@ std::pair RaggedIterDomain::partition( " for IterDomain: ", in->toString()); - NVF_ERROR(offsets != nullptr, "partition: offsets tensor is null"); + NVF_ERROR(extents != nullptr, "partition: extents tensor is null"); NVF_ERROR_EQ( - offsets->dtype(), + extents->dtype(), DataType::Index, - "partition: offsets must have Index type, got ", - offsets->dtype()); + "partition: extents must have Index type, got ", + extents->dtype()); - const auto& offsets_domain = offsets->getLogicalDomain(); + const auto& extents_domain = extents->getLogicalDomain(); NVF_ERROR_EQ( - offsets_domain.size(), + extents_domain.size(), 1, - "partition: offsets tensor must be 1D, got ", - offsets_domain.size(), - "D tensor. Multi-dimensional offsets not yet supported."); + "partition: extents tensor must be 1D, got ", + extents_domain.size(), + "D tensor. 
Multi-dimensional extents not yet supported."); auto container = in->container(); - // Compute extents from offsets: extents[i] = offsets[i+1] - offsets[i] - // offsets_left = offsets[:-1] (all but last element) - // offsets_right = offsets[1:] (all but first element) - - auto offsets_len = offsets_domain[0]->extent(); - - auto zero = container->zeroVal(DataType::Index); - auto one = container->oneVal(DataType::Index); - auto len_minus_one = sub(offsets_len, one); - - // Slice offsets[:-1] - Slice left_slice; - left_slice.start = zero; - left_slice.stop = len_minus_one; - auto offsets_left = slice(offsets, {left_slice}); - - // Slice offsets[1:] - Slice right_slice; - right_slice.start = one; - right_slice.stop = offsets_len; - auto offsets_right = slice(offsets, {right_slice}); - - // Compute extents: extents = offsets_right - offsets_left - auto extents = sub(offsets_right, offsets_left); - // Create component IterDomain - // Component extent = number of components = len(offsets) - 1 - auto component_extent = len_minus_one; + // Component extent = number of components = length of extents tensor + auto zero = container->zeroVal(DataType::Index); + auto component_extent = extents_domain.at(0)->extent(); auto component_id = IterDomainBuilder(zero, component_extent) .parallel_type(ParallelType::Serial) .iter_type(IterType::Iteration) @@ -1583,13 +1559,13 @@ void TensorDomain::merge(int64_t axis_o, int64_t axis_i) { // Partition "axis" into component and ragged dimensions. Follow the // pattern of TensorDomain::split. 
-void TensorDomain::partition(int64_t axis, TensorView* offsets) { +void TensorDomain::partition(int64_t axis, TensorView* extents) { NVF_ERROR(nDims() > 0, "Tried to do partition on a 0-dim domain"); axis = wrapDim(axis); IterDomain* id = this->axis(axis); - auto [component_id, ragged_id] = RaggedIterDomain::partition(id, offsets); + auto [component_id, ragged_id] = RaggedIterDomain::partition(id, extents); // Remove the original axis and insert component and ragged dimensions loop_domain_.erase(loop_domain_.begin() + axis); diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 0187c408bd7..c5fad115ba3 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -477,20 +477,21 @@ class NVF_API RaggedIterDomain : public IterDomain { } //! Partition an IterDomain into component and ragged dimensions - //! Creates a component IterDomain and a RaggedIterDomain based on offsets + //! Creates a component IterDomain and a RaggedIterDomain based on extents //! //! \param in Input IterDomain to partition (must be regular IterDomain) - //! \param offsets Offset tensor defining partition boundaries (must be 1D) - //! Shape: [num_components + 1], values: [0, off1, off2, ..., total] - //! Extents are computed as: extents[i] = offsets[i+1] - offsets[i] + //! \param extents Extents tensor defining the size of each component (must be + //! 1D) + //! Shape: [num_components], values: [extent0, extent1, ..., + //! extent(n-1)] //! \return Pair of (component_id, ragged_id) //! component_id: IterDomain with extent = num_components - //! ragged_id: RaggedIterDomain with extents computed from offsets + //! ragged_id: RaggedIterDomain with the provided extents //! - //! TODO: Support multi-dimensional offsets for nested ragged structures + //! TODO: Support multi-dimensional extents for nested ragged structures static std::pair partition( IterDomain* in, - TensorView* offsets); + TensorView* extents); private: //! 
Extent tensor containing all component extents @@ -792,8 +793,8 @@ class NVF_API TensorDomain : public Val { // axis is by default placed at original position axis_o void merge(int64_t axis_o, int64_t axis_i); - // Partition axis into component and ragged dimensions based on offsets - void partition(int64_t axis, TensorView* offsets); + // Partition axis into component and ragged dimensions based on extents + void partition(int64_t axis, TensorView* extents); // Reorder axes according to map[old_pos] = new_pos void reorder(const std::unordered_map& old2new); diff --git a/csrc/ir/internal_nodes.cpp b/csrc/ir/internal_nodes.cpp index 6ebc9271b02..219d1669b48 100644 --- a/csrc/ir/internal_nodes.cpp +++ b/csrc/ir/internal_nodes.cpp @@ -2622,7 +2622,15 @@ Partition::Partition( addOutput(component); addOutput(ragged); addInput(in); - // Should the extents tensor be an input rather than an attribute? + // Note: extents is held as an attribute rather than an input, + // despite it's a TensorView. Inputs and outputs in the existing + // IterDomain exprs are always IterDomains. Intuitively, they + // transform input iteration spaces into output iteration spaces in + // some way. Since the extents tensor itself is not transformed in the + // Partition expr, it doesn't seem to be considered as an input. Note that in + // Split, the split factor is an attribute. However, that said, none + // of the existing exprs has tensors as attributes, which makes this + // choice less certain with possible implications. addAttribute(extents); } diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index eabe2004f0c..fc10d3db3a9 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -562,8 +562,8 @@ TensorView* TensorView::merge(int64_t axis_o, int64_t axis_i) { } // Partition "axis" into component and ragged dimensions based on -// offsets. Follow the pattern of TensorView::split. -TensorView* TensorView::partition(int64_t axis, TensorView* offsets) { +// extents. 
Follow the pattern of TensorView::split. +TensorView* TensorView::partition(int64_t axis, TensorView* extents) { NVF_ERROR( nDims() > 0, "Tried to do partition on a 0-dim TensorView. ", @@ -598,11 +598,11 @@ TensorView* TensorView::partition(int64_t axis, TensorView* offsets) { " Parallelization strategy must be set after calling partition: ", toString()); - if (offsets->dtype() != DataType::Index) { - offsets = castOp(DataType::Index, offsets); + if (extents->dtype() != DataType::Index) { + extents = castOp(DataType::Index, extents); } - domain()->partition(axis, offsets); + domain()->partition(axis, extents); return this; } diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 8d16615bd64..ecbbf10d2f2 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -211,13 +211,13 @@ TEST_F(RaggedIterDomainTest, PartitionBasic) { fusion.zeroVal(), IrBuilder::create(-1, DataType::Index)) .build(); - // Create a symbolic offset tensor - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); + // Create a symbolic extents tensor + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); // Partition the IterDomain auto [component_id, ragged_id] = - RaggedIterDomain::partition(input_id, offsets); + RaggedIterDomain::partition(input_id, extents); // Verify component IterDomain EXPECT_TRUE(component_id != nullptr); @@ -265,28 +265,28 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) { fusion.zeroVal(), IrBuilder::create(10L, DataType::Index)) .build(); - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); // Test 1: Null input should fail EXPECT_THROW( - RaggedIterDomain::partition(nullptr, offsets), nvfuser::nvfError); + RaggedIterDomain::partition(nullptr, extents), nvfuser::nvfError); - // Test 2: Null offsets should fail 
+ // Test 2: Null extents should fail EXPECT_THROW( RaggedIterDomain::partition(input_id, nullptr), nvfuser::nvfError); - // Test 3: Non-Index offsets should fail - auto float_offsets = makeSymbolicTensor(1, DataType::Float); - fusion.addInput(float_offsets); + // Test 3: Non-Index extents should fail + auto float_extents = makeSymbolicTensor(1, DataType::Float); + fusion.addInput(float_extents); EXPECT_THROW( - RaggedIterDomain::partition(input_id, float_offsets), nvfuser::nvfError); + RaggedIterDomain::partition(input_id, float_extents), nvfuser::nvfError); - // Test 4: Multi-dimensional offsets should fail - auto offsets_2d = makeSymbolicTensor(2, DataType::Index); - fusion.addInput(offsets_2d); + // Test 4: Multi-dimensional extents should fail + auto extents_2d = makeSymbolicTensor(2, DataType::Index); + fusion.addInput(extents_2d); EXPECT_THROW( - RaggedIterDomain::partition(input_id, offsets_2d), nvfuser::nvfError); + RaggedIterDomain::partition(input_id, extents_2d), nvfuser::nvfError); // Test 5: Non-Iteration IterType should fail auto reduction_id = @@ -295,15 +295,15 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) { .iter_type(IterType::Reduction) .build(); EXPECT_THROW( - RaggedIterDomain::partition(reduction_id, offsets), nvfuser::nvfError); + RaggedIterDomain::partition(reduction_id, extents), nvfuser::nvfError); // Test 6: Cannot partition RaggedIterDomain - auto extents = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(extents); + auto extents2 = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents2); auto ragged_id = IrBuilder::create( - extents, IterType::Iteration, ParallelType::Serial); + extents2, IterType::Iteration, ParallelType::Serial); EXPECT_THROW( - RaggedIterDomain::partition(ragged_id, offsets), nvfuser::nvfError); + RaggedIterDomain::partition(ragged_id, extents), nvfuser::nvfError); } // TensorView::partition operation @@ -315,12 +315,12 @@ TEST_F(RaggedIterDomainTest, TensorViewPartition) { auto tv0 = 
makeSymbolicTensor(2, DataType::Float); fusion.addInput(tv0); - // Create offsets tensor - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); + // Create extents tensor + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); // Partition the first axis - tv0->partition(0, offsets); + tv0->partition(0, extents); // Verify the tensor now has 3 dimensions: [component, ragged, original_dim1] EXPECT_EQ(tv0->nDims(), 3); From f215f079f1fa8ccbcf425ddf8470f8cf42cf3566 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 18 Dec 2025 14:01:20 -0800 Subject: [PATCH 26/47] Use extents as a parameter --- csrc/ops/alias.cpp | 12 ++++---- csrc/ops/alias.h | 13 ++++---- tests/cpp/test_ragged_iter_domain.cpp | 44 +++++++++++++-------------- 3 files changed, 34 insertions(+), 35 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index e32aa5e6b9c..16e71e7ecdd 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1270,16 +1270,16 @@ TensorView* repeat( TensorView* asNested( TensorView* data, - TensorView* offsets, + TensorView* extents, int64_t ragged_dim) { NVF_ERROR(data != nullptr, "asNested: data tensor is null"); - NVF_ERROR(offsets != nullptr, "asNested: offsets tensor is null"); + NVF_ERROR(extents != nullptr, "asNested: extents tensor is null"); - // Only 1D offset tensors are currently supported + // Only 1D extents tensors are currently supported NVF_ERROR_EQ( - offsets->nDims(), + extents->nDims(), 1, - "asNested currently only supports 1D offset tensors"); + "asNested currently only supports 1D extents tensors"); // Get the logical domain of the input, excluding reductions auto inp_logical = TensorDomain::noReductions(data->getLogicalDomain()); @@ -1296,7 +1296,7 @@ TensorView* asNested( // Partition the specified dimension in root domain // This replaces one IterDomain with (component_id, ragged_id) auto [component_id, ragged_id] = - 
RaggedIterDomain::partition(root_domain.at(ragged_dim), offsets); + RaggedIterDomain::partition(root_domain.at(ragged_dim), extents); // Build the logical domain: replace ragged_dim with component and ragged std::vector logical_domain; diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h index f3bf769dd71..5963e99df66 100644 --- a/csrc/ops/alias.h +++ b/csrc/ops/alias.h @@ -197,26 +197,25 @@ NVF_API TensorView* repeat( TensorView* inp, const std::vector& repeat_times); -//! Create a nested tensor view from a data tensor and offsets. +//! Create a nested tensor view from a data tensor and extents. //! //! The function partitions the specified dimension of the data tensor into -//! a component dimension and a ragged dimension based on the provided offsets. +//! a component dimension and a ragged dimension based on the provided extents. //! //! \param data Input tensor to be converted to nested representation -//! \param offsets Offset tensor defining partition boundaries -//! Shape: [num_components + 1], values: [0, off1, off2, ..., total] -//! Extents are computed as: extents[i] = offsets[i+1] - offsets[i] +//! \param extents Extents tensor defining the size of each component +//! Shape: [num_components], values: [extent0, extent1, ..., extent(n-1)] //! \param ragged_dim Dimension to partition into nested structure //! \return TensorView with a RaggedIterDomain at the specified dimension //! //! Example: //! data shape: [10, ...] -//! offsets: [0, 3, 8, 10] +//! extents: [3, 5, 2] //! ragged_dim: 0 //! Result: nested tensor with 3 components. [3, [3, 5, 2], ...] 
NVF_API TensorView* asNested( TensorView* data, - TensorView* offsets, + TensorView* extents, int64_t ragged_dim); } // namespace nvfuser diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 536a723d0cf..cc6b65078c1 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -348,11 +348,11 @@ TEST_F(RaggedIterDomainTest, AsNestedBasic) { auto data = makeSymbolicTensor(2, DataType::Float); fusion.addInput(data); - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); // Create nested tensor from dimension 0 - auto nested = asNested(data, offsets, 0); + auto nested = asNested(data, extents, 0); fusion.addOutput(nested); @@ -393,11 +393,11 @@ TEST_F(RaggedIterDomainTest, AsNestedDifferentDimension) { auto data = makeSymbolicTensor(3, DataType::Float); fusion.addInput(data); - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); // Partition dimension 1 (middle dimension) - auto nested = asNested(data, offsets, 1); + auto nested = asNested(data, extents, 1); // Verify dimensions: [dim0, component, ragged, dim2] EXPECT_EQ(nested->nDims(), 4); @@ -424,12 +424,12 @@ TEST_F(RaggedIterDomainTest, AsNested1DTensor) { auto data = makeSymbolicTensor(1, DataType::Float); fusion.addInput(data); - // Create offsets tensor - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); + // Create extents tensor + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); // Create nested tensor from the only dimension - auto nested = asNested(data, offsets, 0); + auto nested = asNested(data, extents, 0); fusion.addOutput(nested); @@ -448,38 +448,38 @@ TEST_F(RaggedIterDomainTest, AsNestedValidationNullData) { Fusion fusion; 
FusionGuard fg(&fusion); - auto offsets = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(offsets); + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); // Null data should throw - EXPECT_THROW(asNested(nullptr, offsets, 0), nvfuser::nvfError); + EXPECT_THROW(asNested(nullptr, extents, 0), nvfuser::nvfError); } -// asNested validation - null offsets -TEST_F(RaggedIterDomainTest, AsNestedValidationNullOffsets) { +// asNested validation - null extents +TEST_F(RaggedIterDomainTest, AsNestedValidationNullExtents) { Fusion fusion; FusionGuard fg(&fusion); auto data = makeSymbolicTensor(2, DataType::Float); fusion.addInput(data); - // Null offsets should throw + // Null extents should throw EXPECT_THROW(asNested(data, nullptr, 0), nvfuser::nvfError); } -// asNested validation - multi-dimensional offsets (not yet supported) -TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimOffsets) { +// asNested validation - multi-dimensional extents (not yet supported) +TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimExtents) { Fusion fusion; FusionGuard fg(&fusion); auto data = makeSymbolicTensor(2, DataType::Float); fusion.addInput(data); - // 2D offsets should fail (only 1D supported currently) - auto offsets_2d = makeSymbolicTensor(2, DataType::Index); - fusion.addInput(offsets_2d); + // 2D extents should fail (only 1D supported currently) + auto extents_2d = makeSymbolicTensor(2, DataType::Index); + fusion.addInput(extents_2d); - EXPECT_THROW(asNested(data, offsets_2d, 0), nvfuser::nvfError); + EXPECT_THROW(asNested(data, extents_2d, 0), nvfuser::nvfError); } } // namespace nvfuser From c3aebec4231fd4a52d3208cbe9c5b0e22e4fad79 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 18 Dec 2025 16:33:05 -0800 Subject: [PATCH 27/47] combine --- csrc/ir/internal_base_nodes.cpp | 76 +++++++++++++++++++++ csrc/ir/internal_base_nodes.h | 16 +++++ csrc/ir/internal_nodes.cpp | 27 ++++++++ csrc/ir/internal_nodes.h | 38 +++++++++++ 
tests/cpp/test_ragged_iter_domain.cpp | 95 +++++++++++++++++++++++++++ 5 files changed, 252 insertions(+) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 698a047db6f..d29122f4898 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -1054,6 +1054,82 @@ std::pair RaggedIterDomain::partition( return {component_id, ragged_id}; } +IterDomain* RaggedIterDomain::combine( + IterDomain* component, + RaggedIterDomain* ragged) { + NVF_ERROR(component != nullptr, "combine: component IterDomain is null"); + NVF_ERROR(ragged != nullptr, "combine: ragged IterDomain is null"); + + NVF_ERROR( + !component->isA(), + "combine: component must be a regular IterDomain, got RaggedIterDomain: ", + component->toString()); + + // Validate that component and ragged have compatible properties + NVF_ERROR_EQ( + component->getParallelType(), + ParallelType::Serial, + "Combining parallelized IterDomain not supported: ", + component->toString()); + + NVF_ERROR_EQ( + ragged->getParallelType(), + ParallelType::Serial, + "Combining parallelized RaggedIterDomain not supported: ", + ragged->toString()); + + NVF_ERROR_EQ( + component->getIterType(), + IterType::Iteration, + "combine: only IterType::Iteration is supported for component, got ", + component->getIterType(), + " for IterDomain: ", + component->toString()); + + NVF_ERROR_EQ( + ragged->getIterType(), + IterType::Iteration, + "combine: only IterType::Iteration is supported for ragged, got ", + ragged->getIterType(), + " for RaggedIterDomain: ", + ragged->toString()); + + // The combined extent is the sum of all extents in the ragged dimension + // For a 1D extents tensor [e0, e1, ..., en-1], the total is sum(extents) + TensorView* extents_tv = ragged->extents(); + NVF_ERROR(extents_tv != nullptr, "combine: ragged extents tensor is null"); + + // It is still assumed the extents tensor is just 1D + NVF_ERROR_EQ( + std::ssize(extents_tv->getLogicalDomain()), + 1, + "Unexpected 
rank of extent tensor: ", + extents_tv->toString()); + + auto container = component->container(); + auto zero = container->zeroVal(DataType::Index); + + // Create a symbolic extent for the combined IterDomain + // This represents the sum of all ragged extents, i.e., + // sum(extents_tv, {0}). We could use the sum output as the extent + // but we would need to extract the scalar value out of the 0-dim + // tensor. For now, we leave it as a symbolic Val. + Val* combined_extent = + IrBuilder::createInContainer(container, DataType::Index); + + // Create the combined IterDomain with the symbolic extent + IterDomain* combined_id = IterDomainBuilder(zero, combined_extent) + .parallel_type(ParallelType::Serial) + .iter_type(IterType::Iteration) + .build(); + + // Create the Combine expression linking component + ragged -> combined + IrBuilder::createInContainer( + container, combined_id, component, ragged); + + return combined_id; +} + TensorDomain::TensorDomain( IrBuilderPasskey passkey, std::vector logical_domain, diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 9d40fcfbf1e..2d1872563e5 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -499,6 +499,22 @@ class NVF_API RaggedIterDomain : public IterDomain { IterDomain* in, TensorView* extents); + //! Combine a component IterDomain with a RaggedIterDomain to flatten + //! This is the inverse of partition, creating a regular IterDomain + //! + //! \param component Component IterDomain (extent = num_components) + //! \param ragged RaggedIterDomain with variable extents per component + //! \return Regular IterDomain with extent = sum of all component extents + //! + //! This operation flattens the ragged structure back into a single dimension. + //! Example: component extent=3, ragged extents=[127, 0, 198] + //! -> output extent = 325 (= 127 + 0 + 198) + //! + //! Note: We use "combine" instead of "merge" to differentiate from the + //! 
regular IterDomain::merge operation which only works with regular + //! IterDomains. + static IterDomain* combine(IterDomain* component, RaggedIterDomain* ragged); + //! Override cloneWithoutRFactor to preserve RaggedIterDomain type IterDomain* cloneWithoutRFactor(bool map_with_original = false) override; diff --git a/csrc/ir/internal_nodes.cpp b/csrc/ir/internal_nodes.cpp index 219d1669b48..116eadad676 100644 --- a/csrc/ir/internal_nodes.cpp +++ b/csrc/ir/internal_nodes.cpp @@ -2653,6 +2653,33 @@ std::string Partition::toInlineString(int indent_size) const { NVFUSER_DEFINE_CLONE_AND_CREATE(Partition) +Combine::Combine( + IrBuilderPasskey passkey, + IterDomain* out, + IterDomain* component, + RaggedIterDomain* ragged) + : Expr(passkey) { + addOutput(out); + addInput(component); + addInput(ragged); +} + +std::string Combine::toString(int indent_size) const { + std::stringstream ss; + ss << "Combine: "; + ss << "component: " << component()->toString(); + ss << " + ragged: " << ragged()->toString(); + ss << " -> " << out()->toString(); + ss << "\n"; + return ss.str(); +} + +std::string Combine::toInlineString(int indent_size) const { + NVF_CHECK(false, "Combine can not be printed inline"); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(Combine) + Swizzle::Swizzle( IrBuilderPasskey passkey, IterDomain* out_x, diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index 9393dc3016b..863304ce1be 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -1845,6 +1845,44 @@ class NVF_API Partition : public Expr { } }; +//! Combine a component IterDomain with a RaggedIterDomain to flatten +//! This is the inverse of Partition, merging component and ragged dimensions +//! 
into a single regular IterDomain +class NVF_API Combine : public Expr { + public: + using Expr::Expr; + + Combine( + IrBuilderPasskey, + IterDomain* out, + IterDomain* component, + RaggedIterDomain* ragged); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() const override { + return "Combine"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + + //! Output IterDomain (combined/flattened dimension) + IterDomain* out() const { + return output(0)->as(); + } + + //! Component dimension input (extent = num_components) + IterDomain* component() const { + return input(0)->as(); + } + + //! Ragged dimension input (variable extents per component) + RaggedIterDomain* ragged() const { + return input(1)->as(); + } +}; + class Swizzle : public Expr { public: using Expr::Expr; diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index bd3c88fe748..bdcf0bf834e 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -340,6 +340,101 @@ TEST_F(RaggedIterDomainTest, TensorViewPartition) { EXPECT_EQ(tv0->axis(0)->definition(), tv0->axis(1)->definition()); } +// Test combining component and ragged IterDomains (inverse of partition) +TEST_F(RaggedIterDomainTest, CombineBasic) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create extents tensor + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + // Create a regular IterDomain to partition + auto orig_id = IterDomainBuilder( + fusion.zeroVal(DataType::Index), IrBuilder::create(325L, DataType::Index)) + .build(); + + // Partition into component and ragged + auto [component_id, ragged_id] = RaggedIterDomain::partition(orig_id, extents); + + // Verify partition worked + EXPECT_NE(component_id, nullptr); + EXPECT_NE(ragged_id, nullptr); + EXPECT_TRUE(component_id->isA()); + EXPECT_TRUE(ragged_id->isA()); + + // Now combine 
them back + auto combined_id = RaggedIterDomain::combine(component_id, ragged_id); + + // Verify combine worked + EXPECT_NE(combined_id, nullptr); + EXPECT_TRUE(combined_id->isA()); + EXPECT_FALSE(combined_id->isA()); + + // Verify the combine has a definition (Combine expr) + EXPECT_NE(combined_id->definition(), nullptr); + EXPECT_TRUE(combined_id->definition()->isA()); + + // Verify the Combine expression has correct inputs + auto combine_expr = combined_id->definition()->as(); + EXPECT_EQ(combine_expr->component(), component_id); + EXPECT_EQ(combine_expr->ragged(), ragged_id); + EXPECT_EQ(combine_expr->out(), combined_id); +} + +// Test combine validation: null component +TEST_F(RaggedIterDomainTest, CombineValidationNullComponent) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + auto ragged_id = + IrBuilder::create(extents, IterType::Iteration); + + // Should fail with null component + EXPECT_THROW( + RaggedIterDomain::combine(nullptr, ragged_id), + nvfuser::nvfError); +} + +// Test combine validation: null ragged +TEST_F(RaggedIterDomainTest, CombineValidationNullRagged) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto component_id = IterDomainBuilder( + fusion.zeroVal(DataType::Index), IrBuilder::create(3L, DataType::Index)) + .build(); + + // Should fail with null ragged + EXPECT_THROW( + RaggedIterDomain::combine(component_id, nullptr), + nvfuser::nvfError); +} + +// Test combine validation: component is RaggedIterDomain +TEST_F(RaggedIterDomainTest, CombineValidationComponentIsRagged) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto extents1 = makeSymbolicTensor(1, DataType::Index); + auto extents2 = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents1); + fusion.addInput(extents2); + + auto ragged_id1 = + IrBuilder::create(extents1, IterType::Iteration); + auto ragged_id2 = + IrBuilder::create(extents2, IterType::Iteration); + + // Should fail 
when component is also RaggedIterDomain + EXPECT_THROW( + RaggedIterDomain::combine(ragged_id1, ragged_id2), + nvfuser::nvfError); +} + // asNested basic functionality TEST_F(RaggedIterDomainTest, AsNestedBasic) { Fusion fusion; From a22bb1fa005c37f737c287a1ea6aa855f5190c5d Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 18 Dec 2025 20:57:32 -0800 Subject: [PATCH 28/47] Add tests --- tests/cpp/test_ragged_iter_domain.cpp | 157 ++++++++++++++++++++++++-- 1 file changed, 148 insertions(+), 9 deletions(-) diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index bdcf0bf834e..b34de886d65 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -351,11 +351,13 @@ TEST_F(RaggedIterDomainTest, CombineBasic) { // Create a regular IterDomain to partition auto orig_id = IterDomainBuilder( - fusion.zeroVal(DataType::Index), IrBuilder::create(325L, DataType::Index)) + fusion.zeroVal(DataType::Index), + IrBuilder::create(325L, DataType::Index)) .build(); // Partition into component and ragged - auto [component_id, ragged_id] = RaggedIterDomain::partition(orig_id, extents); + auto [component_id, ragged_id] = + RaggedIterDomain::partition(orig_id, extents); // Verify partition worked EXPECT_NE(component_id, nullptr); @@ -395,8 +397,7 @@ TEST_F(RaggedIterDomainTest, CombineValidationNullComponent) { // Should fail with null component EXPECT_THROW( - RaggedIterDomain::combine(nullptr, ragged_id), - nvfuser::nvfError); + RaggedIterDomain::combine(nullptr, ragged_id), nvfuser::nvfError); } // Test combine validation: null ragged @@ -405,13 +406,13 @@ TEST_F(RaggedIterDomainTest, CombineValidationNullRagged) { FusionGuard fg(&fusion); auto component_id = IterDomainBuilder( - fusion.zeroVal(DataType::Index), IrBuilder::create(3L, DataType::Index)) + fusion.zeroVal(DataType::Index), + IrBuilder::create(3L, DataType::Index)) .build(); // Should fail with null ragged EXPECT_THROW( - 
RaggedIterDomain::combine(component_id, nullptr), - nvfuser::nvfError); + RaggedIterDomain::combine(component_id, nullptr), nvfuser::nvfError); } // Test combine validation: component is RaggedIterDomain @@ -431,8 +432,7 @@ TEST_F(RaggedIterDomainTest, CombineValidationComponentIsRagged) { // Should fail when component is also RaggedIterDomain EXPECT_THROW( - RaggedIterDomain::combine(ragged_id1, ragged_id2), - nvfuser::nvfError); + RaggedIterDomain::combine(ragged_id1, ragged_id2), nvfuser::nvfError); } // asNested basic functionality @@ -481,6 +481,145 @@ TEST_F(RaggedIterDomainTest, AsNestedBasic) { EXPECT_EQ(nested->axis(0)->definition(), nested->axis(1)->definition()); } +// Test combining nested tensor back to normal tensor +TEST_F(RaggedIterDomainTest, AsNestedThenCombine) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + // Create nested tensor from dimension 0 + auto nested = asNested(data, extents, 0); + + // Verify nested tensor has 3 dimensions: [component, ragged, original_dim1] + EXPECT_EQ(nested->nDims(), 3); + EXPECT_TRUE(nested->axis(0)->isStrictlyA()); + EXPECT_TRUE(nested->axis(1)->isA()); + + // Get the component and ragged IterDomains + auto component_id = nested->axis(0); + auto ragged_id = nested->axis(1)->as(); + + // Combine them back into a normal IterDomain + auto combined_id = RaggedIterDomain::combine(component_id, ragged_id); + + // Verify the combined IterDomain is a regular IterDomain, not ragged + EXPECT_NE(combined_id, nullptr); + EXPECT_TRUE(combined_id->isStrictlyA()); + EXPECT_FALSE(combined_id->isA()); + + // Verify the combined IterDomain has a Combine definition + EXPECT_NE(combined_id->definition(), nullptr); + EXPECT_TRUE(combined_id->definition()->isA()); + + // Verify the Combine expression has correct inputs + auto combine_expr = 
combined_id->definition()->as(); + EXPECT_EQ(combine_expr->component(), component_id); + EXPECT_EQ(combine_expr->ragged(), ragged_id); + EXPECT_EQ(combine_expr->out(), combined_id); + + // Verify that the component came from the same Partition as the ragged + EXPECT_NE(component_id->definition(), nullptr); + EXPECT_TRUE(component_id->definition()->isA()); + EXPECT_EQ(component_id->definition(), ragged_id->definition()); + + auto partition_expr = component_id->definition()->as(); + EXPECT_EQ(partition_expr->component(), component_id); + EXPECT_EQ(partition_expr->ragged(), ragged_id); +} + +// Test combining nested tensor back to normal tensor after set operation +TEST_F(RaggedIterDomainTest, AsNestedThenSetThenCombine) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + // Create nested tensor from dimension 0 + auto nested = asNested(data, extents, 0); + + // Insert a set operation after asNested + auto nested_copy = set(nested); + + // Verify nested_copy tensor has 3 dimensions: [component, ragged, + // original_dim1] + EXPECT_EQ(nested_copy->nDims(), 3); + EXPECT_TRUE(nested_copy->axis(0)->isStrictlyA()); + EXPECT_TRUE(nested_copy->axis(1)->isA()); + + // Get the component and ragged IterDomains from the copy + auto component_id = nested_copy->axis(0); + auto ragged_id = nested_copy->axis(1)->as(); + + // Combine them back into a normal IterDomain. Even though + // component_id and ragged_id are not directly produced by a + // partition, this should succeed. See the next test + // (AsNestedThenSetThenCombineInvalidComponent) for a failing example. 
+ auto combined_id = RaggedIterDomain::combine(component_id, ragged_id); + + // Verify the combined IterDomain is a regular IterDomain, not ragged + EXPECT_NE(combined_id, nullptr); + EXPECT_TRUE(combined_id->isStrictlyA()); + EXPECT_FALSE(combined_id->isA()); + + // Verify the combined IterDomain has a Combine definition + EXPECT_NE(combined_id->definition(), nullptr); + EXPECT_TRUE(combined_id->definition()->isA()); + + // Verify the Combine expression has correct inputs + auto combine_expr = combined_id->definition()->as(); + EXPECT_EQ(combine_expr->component(), component_id); + EXPECT_EQ(combine_expr->ragged(), ragged_id); + EXPECT_EQ(combine_expr->out(), combined_id); +} + +// Test combining with invalid component (not from same partition) - should +// error +TEST_F(RaggedIterDomainTest, AsNestedThenSetThenCombineInvalidComponent) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto data = makeSymbolicTensor(2, DataType::Float); + fusion.addInput(data); + + auto extents = makeSymbolicTensor(1, DataType::Index); + fusion.addInput(extents); + + // Create nested tensor from dimension 0 + auto nested = asNested(data, extents, 0); + + // Insert a set operation after asNested + auto nested_copy = set(nested); + + // Verify nested_copy tensor has 3 dimensions: [component, ragged, + // original_dim1] + EXPECT_EQ(nested_copy->nDims(), 3); + EXPECT_TRUE(nested_copy->axis(0)->isStrictlyA()); + EXPECT_TRUE(nested_copy->axis(1)->isA()); + + // Get the ragged IterDomain from the copy + auto ragged_id = nested_copy->axis(1)->as(); + + // Use an INVALID component: the third axis instead of the first + // This is NOT the component from the partition, it's the original second + // dimension + auto invalid_component_id = nested_copy->axis(2); + + // Try to combine with the wrong component - this should fail + // The component must be from the same Partition as the ragged IterDomain + EXPECT_THROW( + RaggedIterDomain::combine(invalid_component_id, ragged_id), + 
nvfuser::nvfError); +} + // asNested on different dimensions TEST_F(RaggedIterDomainTest, AsNestedDifferentDimension) { Fusion fusion; From f521c3856136ab864e823781cd70d1e537354e9c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 19 Dec 2025 10:21:55 -0800 Subject: [PATCH 29/47] WIP --- csrc/ir/internal_base_nodes.cpp | 44 ++++++++++++++++++--- csrc/ir/internal_base_nodes.h | 12 +++++- csrc/ops/utils.cpp | 70 ++++++++++++++++++++++++--------- csrc/ops/utils.h | 3 +- 4 files changed, 103 insertions(+), 26 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index d29122f4898..0729536cb37 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -906,7 +906,8 @@ RaggedIterDomain::RaggedIterDomain( IrBuilderPasskey passkey, TensorView* extents, IterType iter_type, - ParallelType parallel_type) + ParallelType parallel_type, + IterDomain* component) : IterDomain( passkey, ValType::RaggedIterDomain, @@ -920,7 +921,8 @@ RaggedIterDomain::RaggedIterDomain( /*is_padded_dimension=*/false, /*is_clustered_blocks=*/false, /*padded_to_size=*/std::nullopt), - extents_(extents) { + extents_(extents), + component_(component) { // Extents must be non-null NVF_ERROR( extents_ != nullptr, "RaggedIterDomain requires non-null extents tensor"); @@ -943,7 +945,9 @@ RaggedIterDomain::RaggedIterDomain( RaggedIterDomain::RaggedIterDomain( const RaggedIterDomain* src, IrCloner* ir_cloner) - : IterDomain(src, ir_cloner), extents_(ir_cloner->clone(src->extents_)) {} + : IterDomain(src, ir_cloner), + extents_(ir_cloner->clone(src->extents_)), + component_(ir_cloner->clone(src->component_)) {} NVFUSER_DEFINE_CLONE(RaggedIterDomain) @@ -964,7 +968,18 @@ bool RaggedIterDomain::sameAs(const Statement* other) const { } // Compare extents tensor - return extents_->sameAs(other_ragged->extents_); + if (!extents_->sameAs(other_ragged->extents_)) { + return false; + } + + // Compare component pointers + if (component_ == 
other_ragged->component_) { + return true; // Same pointer (including both null) + } + if (component_ && other_ragged->component_) { + return component_->sameAs(other_ragged->component_); + } + return false; // One is null, other is not } std::string RaggedIterDomain::toInlineString(int indent_size) const { @@ -1046,8 +1061,8 @@ std::pair RaggedIterDomain::partition( .iter_type(IterType::Iteration) .build(); - auto ragged_id = - IrBuilder::create(extents, in->getIterType()); + auto ragged_id = IrBuilder::create( + extents, in->getIterType(), ParallelType::Serial, component_id); IrBuilder::create(component_id, ragged_id, in, extents); @@ -1094,6 +1109,23 @@ IterDomain* RaggedIterDomain::combine( " for RaggedIterDomain: ", ragged->toString()); + // Validate that component matches the ragged's stored component + IterDomain* expected_component = ragged->component(); + + NVF_ERROR( + expected_component != nullptr, + "combine: ragged IterDomain does not have an associated component. ", + "RaggedIterDomain: ", + ragged->toString()); + + NVF_ERROR( + component == expected_component, + "combine: component does not match the ragged's paired component. ", + "Provided component: ", + component->toString(), + ", Expected component: ", + expected_component->toString()); + // The combined extent is the sum of all extents in the ragged dimension // For a 1D extents tensor [e0, e1, ..., en-1], the total is sum(extents) TensorView* extents_tv = ragged->extents(); diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 2d1872563e5..ed41ee32654 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -460,11 +460,13 @@ class NVF_API RaggedIterDomain : public IterDomain { //! Only Iteration is allowed ATM. //! \param parallel_type Parallelization strategy (applies //! uniformly) + //! 
\param component Optional paired component IterDomain from Partition RaggedIterDomain( IrBuilderPasskey passkey, TensorView* extents, IterType iter_type = IterType::Iteration, - ParallelType parallel_type = ParallelType::Serial); + ParallelType parallel_type = ParallelType::Serial, + IterDomain* component = nullptr); //! Cloning constructor for IR cloning RaggedIterDomain(const RaggedIterDomain* src, IrCloner* ir_cloner); @@ -482,6 +484,11 @@ class NVF_API RaggedIterDomain : public IterDomain { return extents_; } + //! Accessor for the paired component IterDomain + IterDomain* component() const { + return component_; + } + //! Partition an IterDomain into component and ragged dimensions //! Creates a component IterDomain and a RaggedIterDomain based on extents //! @@ -522,6 +529,9 @@ class NVF_API RaggedIterDomain : public IterDomain { //! Extent tensor containing all component extents //! Can be 1D, 2D, or N-D depending on nesting structure TensorView* extents_ = nullptr; + + //! Paired component IterDomain from Partition + IterDomain* component_ = nullptr; }; //! TensorDomain holds a vector of IterDomains. 
It holds an IterDomain for every diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index be50385528c..a3dfbed5a46 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -316,7 +316,8 @@ std::vector mapLinearOpIterDomains( } RaggedIterDomain* newOutputRaggedIterDomain( - const std::vector& input_ids) { + const std::vector& input_ids, + const std::unordered_map& p2c_map) { NVF_ERROR( std::ranges::all_of( input_ids, @@ -329,12 +330,19 @@ RaggedIterDomain* newOutputRaggedIterDomain( // Just using the first ragged ID as all input IDs are assumed to be // equivalent - RaggedIterDomain* ref_input_id = input_ids.front()->as(); + auto ref_input_id = input_ids.front()->as(); + + auto component_id_it = p2c_map.find(ref_input_id->component()); + NVF_ERROR( + component_id_it != p2c_map.end(), + "No p2c mapping found for component ID of ", + ref_input_id->component()->toString()); return IrBuilder::create( ref_input_id->extents(), ref_input_id->getIterType(), - ref_input_id->getParallelType()); + ref_input_id->getParallelType(), + component_id_it->second); } // Adding these pragmas since gcc-12.2.1 @@ -348,18 +356,12 @@ IterDomain* newOutputIterDomain( const std::optional force_iter_type) { NVF_ERROR(!input_ids.empty()); - // If any input ID is a RaggedIterDomain, the output should also be ragged - bool has_ragged = - std::any_of(input_ids.begin(), input_ids.end(), [](IterDomain* id) { - return id->isA(); - }); - - if (has_ragged) { - NVF_ERROR( - !force_iter_type.has_value(), - "force_iter_type not supported for RaggedIterDomain"); - return newOutputRaggedIterDomain(input_ids); - } + NVF_ERROR( + std::none_of( + input_ids.begin(), + input_ids.end(), + [](IterDomain* id) { return id->isA(); }), + "RaggedIterDomain should use newOutputRaggedIterDomain"); // For the start and stop offsets, take the maximum of input axes. 
// For now, the offsets of both start and stop are always integer @@ -471,15 +473,47 @@ std::vector newOutputDomain(const std::vector& vals) { std::vector out_domain( TensorDomain::noReductions(tvs[0]->getLogicalDomain()).size(), nullptr); - for (const auto dim_i : arange(out_domain.size())) { + const auto is_nested = tvs.front()->domain()->hasRaggedIterDomain(); + std::unordered_map p2c_map; + + auto get_ith_id = [](const std::vector& domain, + int64_t i) -> IterDomain* { + auto no_reduction_domain = domain | TensorDomain::kNoReductions; + auto id_it = std::ranges::next(no_reduction_domain.begin(), i); + NVF_ERROR(id_it != no_reduction_domain.end()); + return *id_it; + }; + + std::vector ragged_id_offsets; + + for (const auto dim_i : arange(std::ssize(out_domain))) { + if (get_ith_id(tvs.front()->getLogicalDomain(), dim_i) + ->isA()) { + ragged_id_offsets.push_back(dim_i); + continue; + } std::vector input_ids; input_ids.reserve(tvs.size()); for (auto* tv : tvs) { - auto dom = TensorDomain::noReductions(tv->getLogicalDomain()); - input_ids.emplace_back(dom[dim_i]); + input_ids.emplace_back(get_ith_id(tv->getLogicalDomain(), dim_i)); } out_domain[dim_i] = newOutputIterDomain(input_ids); + if (is_nested) { + for (auto input_id : input_ids) { + p2c_map.emplace(input_id, out_domain[dim_i]); + } + } } + + for (const auto dim_i : ragged_id_offsets) { + std::vector input_ids; + input_ids.reserve(tvs.size()); + for (auto* tv : tvs) { + input_ids.emplace_back(get_ith_id(tv->getLogicalDomain(), dim_i)); + } + out_domain[dim_i] = newOutputRaggedIterDomain(input_ids, p2c_map); + } + return out_domain; } diff --git a/csrc/ops/utils.h b/csrc/ops/utils.h index 3ceadc4aa6a..ccb7ff2351b 100644 --- a/csrc/ops/utils.h +++ b/csrc/ops/utils.h @@ -96,7 +96,8 @@ std::vector mapLinearOpIterDomains( // dimension position. All inputs must be RaggedIterDomain. Uses the extents, // IterType, and ParallelType from the first input. 
RaggedIterDomain* newOutputRaggedIterDomain( - const std::vector& input_ids); + const std::vector& input_ids, + const std::unordered_map& p2c_map); // Takes a vector of aligned input iterdomains to create the output iterdomain. // This is used if the input iterdomains are not trivially mapped to the output From 8d0d9cbf48b64f18a341acd4ab6d84e38d9db9ad Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 19 Dec 2025 14:46:15 -0800 Subject: [PATCH 30/47] don't hold component ID in RaggedIterDomain --- csrc/dispatch.h | 1 + csrc/ir/internal_base_nodes.cpp | 64 +++++++++++------------- csrc/ir/internal_base_nodes.h | 12 +---- csrc/ops/utils.cpp | 70 +++++++-------------------- csrc/ops/utils.h | 3 +- tests/cpp/test_ragged_iter_domain.cpp | 14 +++--- 6 files changed, 57 insertions(+), 107 deletions(-) diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 822ababb149..cee1fa911e2 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -116,6 +116,7 @@ class Val; f(ScanOp); \ f(Merge); \ f(Partition); \ + f(Combine); \ f(Swizzle); \ f(Swizzle2D); \ f(Resize); \ diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 0729536cb37..2da29879773 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -906,8 +906,7 @@ RaggedIterDomain::RaggedIterDomain( IrBuilderPasskey passkey, TensorView* extents, IterType iter_type, - ParallelType parallel_type, - IterDomain* component) + ParallelType parallel_type) : IterDomain( passkey, ValType::RaggedIterDomain, @@ -921,8 +920,7 @@ RaggedIterDomain::RaggedIterDomain( /*is_padded_dimension=*/false, /*is_clustered_blocks=*/false, /*padded_to_size=*/std::nullopt), - extents_(extents), - component_(component) { + extents_(extents) { // Extents must be non-null NVF_ERROR( extents_ != nullptr, "RaggedIterDomain requires non-null extents tensor"); @@ -945,9 +943,7 @@ RaggedIterDomain::RaggedIterDomain( RaggedIterDomain::RaggedIterDomain( const RaggedIterDomain* src, IrCloner* ir_cloner) 
- : IterDomain(src, ir_cloner), - extents_(ir_cloner->clone(src->extents_)), - component_(ir_cloner->clone(src->component_)) {} + : IterDomain(src, ir_cloner), extents_(ir_cloner->clone(src->extents_)) {} NVFUSER_DEFINE_CLONE(RaggedIterDomain) @@ -968,18 +964,7 @@ bool RaggedIterDomain::sameAs(const Statement* other) const { } // Compare extents tensor - if (!extents_->sameAs(other_ragged->extents_)) { - return false; - } - - // Compare component pointers - if (component_ == other_ragged->component_) { - return true; // Same pointer (including both null) - } - if (component_ && other_ragged->component_) { - return component_->sameAs(other_ragged->component_); - } - return false; // One is null, other is not + return extents_->sameAs(other_ragged->extents_); } std::string RaggedIterDomain::toInlineString(int indent_size) const { @@ -1061,8 +1046,8 @@ std::pair RaggedIterDomain::partition( .iter_type(IterType::Iteration) .build(); - auto ragged_id = IrBuilder::create( - extents, in->getIterType(), ParallelType::Serial, component_id); + auto ragged_id = + IrBuilder::create(extents, in->getIterType()); IrBuilder::create(component_id, ragged_id, in, extents); @@ -1109,22 +1094,29 @@ IterDomain* RaggedIterDomain::combine( " for RaggedIterDomain: ", ragged->toString()); - // Validate that component matches the ragged's stored component - IterDomain* expected_component = ragged->component(); + // Validate component-ragged pairing when Partition definition is available + // (Option 3: Best-effort validation) + // Only validate when the RaggedIterDomain has a direct Partition definition. + // After propagation (e.g., set() operations), the definition may be nullptr, + // in which case we trust the user to provide the correct component. 
+ if (ragged->definition() != nullptr && + ragged->definition()->isA()) { + auto* partition = ragged->definition()->as(); + IterDomain* expected_component = partition->component(); - NVF_ERROR( - expected_component != nullptr, - "combine: ragged IterDomain does not have an associated component. ", - "RaggedIterDomain: ", - ragged->toString()); - - NVF_ERROR( - component == expected_component, - "combine: component does not match the ragged's paired component. ", - "Provided component: ", - component->toString(), - ", Expected component: ", - expected_component->toString()); + NVF_ERROR( + component == expected_component, + "combine: component mismatch. The provided component does not match ", + "the component from the Partition that created this " + "RaggedIterDomain.\n", + " Provided component: ", + component->toString(), + "\n", + " Expected component: ", + expected_component->toString()); + } + // If no Partition definition (after set, in segmented fusion, or external + // input), trust the user and proceed without validation // The combined extent is the sum of all extents in the ragged dimension // For a 1D extents tensor [e0, e1, ..., en-1], the total is sum(extents) diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index ed41ee32654..2d1872563e5 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -460,13 +460,11 @@ class NVF_API RaggedIterDomain : public IterDomain { //! Only Iteration is allowed ATM. //! \param parallel_type Parallelization strategy (applies //! uniformly) - //! \param component Optional paired component IterDomain from Partition RaggedIterDomain( IrBuilderPasskey passkey, TensorView* extents, IterType iter_type = IterType::Iteration, - ParallelType parallel_type = ParallelType::Serial, - IterDomain* component = nullptr); + ParallelType parallel_type = ParallelType::Serial); //! 
Cloning constructor for IR cloning RaggedIterDomain(const RaggedIterDomain* src, IrCloner* ir_cloner); @@ -484,11 +482,6 @@ class NVF_API RaggedIterDomain : public IterDomain { return extents_; } - //! Accessor for the paired component IterDomain - IterDomain* component() const { - return component_; - } - //! Partition an IterDomain into component and ragged dimensions //! Creates a component IterDomain and a RaggedIterDomain based on extents //! @@ -529,9 +522,6 @@ class NVF_API RaggedIterDomain : public IterDomain { //! Extent tensor containing all component extents //! Can be 1D, 2D, or N-D depending on nesting structure TensorView* extents_ = nullptr; - - //! Paired component IterDomain from Partition - IterDomain* component_ = nullptr; }; //! TensorDomain holds a vector of IterDomains. It holds an IterDomain for every diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index a3dfbed5a46..be50385528c 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -316,8 +316,7 @@ std::vector mapLinearOpIterDomains( } RaggedIterDomain* newOutputRaggedIterDomain( - const std::vector& input_ids, - const std::unordered_map& p2c_map) { + const std::vector& input_ids) { NVF_ERROR( std::ranges::all_of( input_ids, @@ -330,19 +329,12 @@ RaggedIterDomain* newOutputRaggedIterDomain( // Just using the first ragged ID as all input IDs are assumed to be // equivalent - auto ref_input_id = input_ids.front()->as(); - - auto component_id_it = p2c_map.find(ref_input_id->component()); - NVF_ERROR( - component_id_it != p2c_map.end(), - "No p2c mapping found for component ID of ", - ref_input_id->component()->toString()); + RaggedIterDomain* ref_input_id = input_ids.front()->as(); return IrBuilder::create( ref_input_id->extents(), ref_input_id->getIterType(), - ref_input_id->getParallelType(), - component_id_it->second); + ref_input_id->getParallelType()); } // Adding these pragmas since gcc-12.2.1 @@ -356,12 +348,18 @@ IterDomain* newOutputIterDomain( const std::optional 
force_iter_type) { NVF_ERROR(!input_ids.empty()); - NVF_ERROR( - std::none_of( - input_ids.begin(), - input_ids.end(), - [](IterDomain* id) { return id->isA(); }), - "RaggedIterDomain should use newOutputRaggedIterDomain"); + // If any input ID is a RaggedIterDomain, the output should also be ragged + bool has_ragged = + std::any_of(input_ids.begin(), input_ids.end(), [](IterDomain* id) { + return id->isA(); + }); + + if (has_ragged) { + NVF_ERROR( + !force_iter_type.has_value(), + "force_iter_type not supported for RaggedIterDomain"); + return newOutputRaggedIterDomain(input_ids); + } // For the start and stop offsets, take the maximum of input axes. // For now, the offsets of both start and stop are always integer @@ -473,47 +471,15 @@ std::vector newOutputDomain(const std::vector& vals) { std::vector out_domain( TensorDomain::noReductions(tvs[0]->getLogicalDomain()).size(), nullptr); - const auto is_nested = tvs.front()->domain()->hasRaggedIterDomain(); - std::unordered_map p2c_map; - - auto get_ith_id = [](const std::vector& domain, - int64_t i) -> IterDomain* { - auto no_reduction_domain = domain | TensorDomain::kNoReductions; - auto id_it = std::ranges::next(no_reduction_domain.begin(), i); - NVF_ERROR(id_it != no_reduction_domain.end()); - return *id_it; - }; - - std::vector ragged_id_offsets; - - for (const auto dim_i : arange(std::ssize(out_domain))) { - if (get_ith_id(tvs.front()->getLogicalDomain(), dim_i) - ->isA()) { - ragged_id_offsets.push_back(dim_i); - continue; - } + for (const auto dim_i : arange(out_domain.size())) { std::vector input_ids; input_ids.reserve(tvs.size()); for (auto* tv : tvs) { - input_ids.emplace_back(get_ith_id(tv->getLogicalDomain(), dim_i)); + auto dom = TensorDomain::noReductions(tv->getLogicalDomain()); + input_ids.emplace_back(dom[dim_i]); } out_domain[dim_i] = newOutputIterDomain(input_ids); - if (is_nested) { - for (auto input_id : input_ids) { - p2c_map.emplace(input_id, out_domain[dim_i]); - } - } } - - for (const auto 
dim_i : ragged_id_offsets) { - std::vector input_ids; - input_ids.reserve(tvs.size()); - for (auto* tv : tvs) { - input_ids.emplace_back(get_ith_id(tv->getLogicalDomain(), dim_i)); - } - out_domain[dim_i] = newOutputRaggedIterDomain(input_ids, p2c_map); - } - return out_domain; } diff --git a/csrc/ops/utils.h b/csrc/ops/utils.h index ccb7ff2351b..3ceadc4aa6a 100644 --- a/csrc/ops/utils.h +++ b/csrc/ops/utils.h @@ -96,8 +96,7 @@ std::vector mapLinearOpIterDomains( // dimension position. All inputs must be RaggedIterDomain. Uses the extents, // IterType, and ParallelType from the first input. RaggedIterDomain* newOutputRaggedIterDomain( - const std::vector& input_ids, - const std::unordered_map& p2c_map); + const std::vector& input_ids); // Takes a vector of aligned input iterdomains to create the output iterdomain. // This is used if the input iterdomains are not trivially mapped to the output diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index b34de886d65..f8142bf4c9c 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -582,7 +582,10 @@ TEST_F(RaggedIterDomainTest, AsNestedThenSetThenCombine) { } // Test combining with invalid component (not from same partition) - should -// error +// succeed when combining after a set operation. +// With Option 3 validation strategy, this does NOT throw an error +// because after set(), the RaggedIterDomain loses its Partition definition +// and validation is skipped (trusts the user) TEST_F(RaggedIterDomainTest, AsNestedThenSetThenCombineInvalidComponent) { Fusion fusion; FusionGuard fg(&fusion); @@ -613,11 +616,10 @@ TEST_F(RaggedIterDomainTest, AsNestedThenSetThenCombineInvalidComponent) { // dimension auto invalid_component_id = nested_copy->axis(2); - // Try to combine with the wrong component - this should fail - // The component must be from the same Partition as the ragged IterDomain - EXPECT_THROW( - 
RaggedIterDomain::combine(invalid_component_id, ragged_id), - nvfuser::nvfError); + // With Option 3: After set(), the RaggedIterDomain no longer has a + // Partition definition, so validation is skipped and the operation succeeds. + // The user is responsible for providing the correct component. + EXPECT_NO_THROW(RaggedIterDomain::combine(invalid_component_id, ragged_id)); } // asNested on different dimensions From 67aac1b3622198b42adaddf9329fc002efffe63c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 19 Dec 2025 15:04:06 -0800 Subject: [PATCH 31/47] Add design doc --- csrc/ir/internal_base_nodes.cpp | 2 +- .../ragged_iter_domain_combine_design_doc.md | 348 ++++++++++++++++++ 2 files changed, 349 insertions(+), 1 deletion(-) create mode 100644 doc/dev/ragged_iter_domain_combine_design_doc.md diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 2da29879773..db995ec6a93 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -1095,7 +1095,7 @@ IterDomain* RaggedIterDomain::combine( ragged->toString()); // Validate component-ragged pairing when Partition definition is available - // (Option 3: Best-effort validation) + // (Option 3 of doc/dev/ragged_iter_domain_combine_design_doc.md). // Only validate when the RaggedIterDomain has a direct Partition definition. // After propagation (e.g., set() operations), the definition may be nullptr, // in which case we trust the user to provide the correct component. 
diff --git a/doc/dev/ragged_iter_domain_combine_design_doc.md b/doc/dev/ragged_iter_domain_combine_design_doc.md new file mode 100644 index 00000000000..f3671c19377 --- /dev/null +++ b/doc/dev/ragged_iter_domain_combine_design_doc.md @@ -0,0 +1,348 @@ +# Design Document: Component IterDomain Tracking for RaggedIterDomain + +## Problem Statement + +When calling `RaggedIterDomain::combine(component, ragged)`, we need to validate that the `component` IterDomain is the correct one that was originally paired with `ragged` during the Partition operation that created it. + +### The Challenge + +The naive approach of checking `ragged->definition()` for a Partition expression fails because: + +1. **Tensor-level operations break the definition chain**: Operations like `set()` create new TensorViews with new IterDomains +2. **IterDomains are propagated without definitions**: The new IterDomains are clones/descendants but don't have the original Partition as their definition +3. **The pairing information is lost**: After propagation, there's no explicit link between the ragged IterDomain and its paired component + +### Concrete Example + +```cpp +// tv0: [i0] - regular tensor +auto nested = asNested(tv0, extents, 0); // Creates Partition expression +// nested->axis(0): component IterDomain +// nested->axis(1): RaggedIterDomain with Partition definition + +auto tv1 = set(nested); // Propagates IterDomains +// tv1->axis(1) is a RaggedIterDomain, but it's a clone without Partition definition + +combine(nested->axis(0), tv1->axis(1)); // How do we validate? +``` + +## Design Alternatives + +### Option 1: Store Component Pointer in RaggedIterDomain + +**Approach**: Add a `component_` member variable to `RaggedIterDomain` that points to the paired component IterDomain. 
+ +**Implementation**: +```cpp +class RaggedIterDomain : public IterDomain { + private: + TensorView* extents_ = nullptr; + IterDomain* component_ = nullptr; // NEW: paired component + + public: + IterDomain* component() const { return component_; } + void setComponent(IterDomain* component) { component_ = component; } +}; +``` + +**How It Works**: +1. When `Partition` creates a RaggedIterDomain, it sets the component pointer +2. When IterDomains are cloned (e.g., during `set()`), the component pointer is cloned/mapped too +3. In `combine()`, validate that the provided component matches `ragged->component()` + +**Pros**: +- ✅ Simple and direct solution +- ✅ Component pointer is automatically preserved during cloning +- ✅ Fast O(1) lookup - no graph traversal needed +- ✅ Follows existing pattern (similar to how `extents_` is stored) +- ✅ Self-documenting - makes the pairing explicit in the data structure + +**Cons**: +- ❌ **CRITICAL: Dependency ordering not guaranteed** - Since `component_` is not an input to RaggedIterDomain's definition, IR graph traversal (during lowering, cloning, replay) has no guarantee that the component IterDomain will be visited/cloned before the ragged IterDomain. This can lead to: + - Dangling pointers during cloning (trying to remap component before it's cloned) + - Incorrect mappings in IrCloner when component hasn't been processed yet + - Failures in topological traversal algorithms that expect dependencies to be explicit +- ❌ **Fragile during replacements** - When IterDomains are replaced (e.g., via `replaceAllUsesWith()`), the component pointer in ragged doesn't get updated automatically. Would require special-case handling throughout the codebase to maintain this hidden dependency. +- ❌ **Strong implicit coupling** - Creates a dependency that's not reflected in the IR graph structure, making the IR harder to reason about and maintain. 
Optimization passes and transformations that don't know about this hidden link could break the invariant. +- ❌ Component pointer could become stale if component IterDomain is replaced/transformed + +**Why This Is Problematic**: + +The fundamental issue is that this approach tries to store a relationship that *should* be part of the IR graph structure as an *out-of-band* pointer. nvFuser's IR infrastructure is designed around explicit dependency edges (inputs/outputs of expressions). Adding a pointer that doesn't follow these edges creates a parallel tracking mechanism that must be manually maintained across all IR operations: + +1. **IrCloner** would need to special-case the component pointer remapping, but it can't guarantee ordering +2. **replaceAllUsesWith()** and similar operations wouldn't know to update the component pointer +3. **Replay** operations that transform IterDomains wouldn't propagate the component link correctly +4. **Serialization/deserialization** would need special handling for this out-of-band pointer + +--- + +### Option 2: Traverse IR Graph to Find Original Partition + +**Approach**: Walk backward through the IterDomain definition chain to find the original Partition expression. + +**Implementation**: +```cpp +IterDomain* findOriginalComponent(RaggedIterDomain* ragged) { + // Traverse backward through set operations, clones, etc. + auto* current = ragged; + while (current != nullptr) { + if (current->definition() && current->definition()->isA<Partition>()) { + return current->definition()->as<Partition>()->component(); + } + // Follow the chain backward (e.g., through set operations) + current = getSourceIterDomain(current); + } + return nullptr; // No Partition found +} +``` + +**How It Works**: +1. Start from the given RaggedIterDomain +2. Traverse backward through the IR graph following definition chains +3. Find the original Partition expression +4. 
Extract the component from that Partition + +**Pros**: +- ✅ No additional memory overhead +- ✅ No new state to maintain +- ✅ Always finds the "true" original component by traversing the IR +- ✅ Component pointer can't become stale (computed on demand) + +**Cons**: +- ❌ **CRITICAL: Fusion segmentation breaks traversal** - When a fusion is segmented (split into multiple kernels), each segment contains only a subset of the full IR graph. A segment may contain a RaggedIterDomain that needs to be combined, but the original Partition expression that created it may be in a different segment. Traversal cannot cross segment boundaries, making it impossible to find the original component. +- ❌ **CRITICAL: External ragged tensors have no Partition** - When RaggedIterDomain support is extended in the future to accept ragged tensors from PyTorch as fusion inputs, these would arrive as RaggedIterDomains without any Partition expression in the nvFuser IR. There would be nothing to traverse back to, yet we still need to know the component for validation. +- ❌ **Unreliable chain traversal** - Even when a Partition exists in the same segment, the definition chain can be broken or complex: + - Operations like `set()` intentionally break the definition chain + - Multiple paths back through different transformations + - Split/merge operations on the path complicate tracking +- ❌ Requires IR graph traversal - O(n) where n is chain depth +- ❌ Complex implementation - need to handle all propagation patterns (set, replay, clone, etc.) +- ❌ Performance cost on every combine() call + +**Why This Is Problematic**: + +This approach assumes the Partition expression is always reachable, but there are fundamental scenarios where it isn't: + +1. **Segmented Fusions**: nvFuser segments complex fusions into multiple kernels. Each segment is scheduled and lowered independently. 
A RaggedIterDomain in segment N may have been created by a Partition in segment M, but segment boundaries are opaque - you can't traverse across them. + +2. **Future External Inputs**: When RaggedIterDomain support is extended to accept ragged tensors from PyTorch as fusion inputs, these RaggedIterDomains will have no corresponding nvFuser Partition expression. They represent already-partitioned data from outside nvFuser. + +3. **Definition Chain Breaks**: Even within a segment, operations like `set()` intentionally create new IterDomains without definitions, breaking the chain. + +The fundamental flaw is assuming component information can be recovered from the IR graph structure, when in reality the information may not exist in the graph at all. + +--- + +### Option 3: Track Component in Partition Expression Only + +**Approach**: Only validate when a direct Partition definition exists, otherwise trust the user. + +**Implementation**: +```cpp +void combine(IterDomain* component, RaggedIterDomain* ragged) { + // Only validate if we can find a Partition + if (ragged->definition() && ragged->definition()->isA<Partition>()) { + auto* partition = ragged->definition()->as<Partition>(); + NVF_ERROR(component == partition->component(), + "Component doesn't match partition"); + } + // Otherwise, no validation - trust the user + + // Proceed with combine... +} +``` + +**How It Works**: +1. Check if ragged has a Partition definition +2. If yes, validate the component +3. 
If no, skip validation and trust that the user provided the correct component + +**Pros**: +- ✅ Minimal implementation - no new infrastructure +- ✅ No memory overhead +- ✅ Simple to understand +- ✅ Validation when possible, permissive when not + +**Cons**: +- ⚠️ Validation is incomplete - only validates when Partition definition is directly available +- ⚠️ After propagation operations (set, segmentation), relies on user correctness + +**Why This Is Actually Reasonable**: + +This approach aligns with how nvFuser handles other operations: +- **Arithmetic operations** (add, mul, etc.) assume inputs have matching shapes - they don't validate +- **User responsibility**: If users call `combine(component, ragged)`, we trust they're providing the correct component +- **Validation where possible**: When we CAN validate (Partition definition exists), we do +- **Fail-fast when detectable**: Catches obvious errors early in the fusion definition +- **Pragmatic**: Acknowledges that complete validation isn't feasible given segmentation and external inputs + +The key insight is that `combine()` is a user-facing API. Users are expected to know which component pairs with which ragged domain, just as they're expected to know when tensor shapes are compatible for arithmetic operations. + +--- + +### Option 4: Store Component Pairing in TensorDomain + +**Approach**: Store component-ragged pairings in TensorDomain rather than in RaggedIterDomain itself. + +**Implementation**: +```cpp +// In TensorDomain +class TensorDomain { + private: + std::vector<IterDomain*> logical_domain_; + // Other domain vectors... 
+ + // NEW: Track ragged-component pairings for IterDomains in this TensorDomain + struct RaggedComponentPair { + RaggedIterDomain* ragged; + IterDomain* component; + }; + std::vector<RaggedComponentPair> ragged_component_pairs_; + + public: + // Get the component for a ragged IterDomain in this TensorDomain + IterDomain* getComponentFor(RaggedIterDomain* ragged) const; + + // Register a ragged-component pairing (called when creating from Partition) + void registerRaggedComponentPair(RaggedIterDomain* ragged, IterDomain* component); +}; +``` + +**How It Works**: +1. When Partition creates a TensorView with ragged and component IterDomains, register the pairing in the TensorDomain +2. The pairing is stored alongside the IterDomains themselves, ensuring both ragged and component are in `allIds()` +3. When tensor operations (like `set()`) propagate TensorDomains, they also propagate the pairing information +4. In `combine()`, look up the component from the TensorView's TensorDomain + +**Pros**: +- ✅ **Looser coupling**: The relationship is stored in TensorDomain, not in RaggedIterDomain itself +- ✅ **Follows containment**: TensorDomain already owns and manages its IterDomains, so it's natural to manage their relationships +- ✅ **Explicit in domain operations**: Operations that propagate TensorDomain can explicitly propagate pairings +- ✅ **Validates across propagation**: Works even after `set()` if the pairing is propagated correctly +- ✅ **Both IDs guaranteed present**: Since both must be in `allIds()`, dependency ordering is less problematic + +**Cons**: +- ❌ **Propagation must be explicit**: Every operation that creates/clones TensorDomain must handle pairing propagation +- ❌ **More complex than Option 3**: Requires changes to TensorDomain and all operations that manipulate it +- ❌ **Still has propagation challenges**: Operations like replay, resize, or transformations need to update pairings +- ❌ **Segmentation issues remain**: After fusion segmentation, TensorDomain in one segment 
may not have the original pairing information + +**Key Challenge**: + +The main implementation challenge is ensuring pairing propagation through all tensor operations: +- `set()`: Must copy pairings from input TensorDomain to output +- `view/reshape`: Must map pairings through transformations +- Replay operations: Must track how ragged and component are transformed +- Cloning: Must clone pairings along with IterDomains + +**Why This Is Better Than Option 1**: + +Unlike storing the pointer in RaggedIterDomain: +- TensorDomain already manages relationships between IterDomains (root→logical→allocation mappings) +- Both ragged and component are explicitly part of the domain, reducing implicit dependencies +- The coupling is at the TensorDomain level, not at the individual IterDomain level + +**Why This May Not Be Worth It**: + +While architecturally cleaner than Option 1, it's still significantly more complex than Option 3: +- Requires modifying TensorDomain and many tensor operations +- Still doesn't solve segmentation (segments may not preserve original TensorDomain) +- Adds complexity for validation that may not be critical (users can track pairings) + +If Option 3's "trust the user" approach is sufficient, Option 4's additional complexity may not be justified. + +--- + +## Analysis Summary + +### Why Options 1 & 2 Are Not Viable + +**Option 1 (Stored Pointer)**: Fundamentally flawed due to dependency ordering. The component pointer would be an out-of-band dependency not reflected in the IR graph. IR traversal algorithms follow explicit input/output edges, with no guarantee that component will be processed before ragged during cloning/lowering/replay. Violates nvFuser's design principle of explicit dependency edges. + +**Option 2 (IR Traversal)**: Fails in two critical scenarios: +1. **Fusion Segmentation**: Partition expression may be in a different segment, unreachable via traversal +2. 
**Future External Inputs**: When RaggedIterDomain support is extended to accept ragged tensors from PyTorch as fusion inputs, these will have no nvFuser Partition expression to traverse to + +These aren't edge cases - they're fundamental use cases that must be supported. + +### Why Option 3 Is The Pragmatic Choice + +**Option 3** aligns with nvFuser's design philosophy: like arithmetic operations that assume shape compatibility, `combine()` trusts users to provide correct inputs. It validates when Partition definition exists but otherwise relies on user correctness. Simple to implement, handles all use cases (propagation, segmentation, external inputs), and acknowledges that complete validation is architecturally infeasible. + +**Option 4 (TensorDomain Pairing)** is architecturally cleaner than Option 1 (looser coupling) but requires extensive changes to TensorDomain operations and still has segmentation issues. Could be a future enhancement if user errors become problematic, but Option 3's simplicity is preferred for now. + +## Recommendation + +### Proposed Solution: **Option 3 - Validate When Partition Definition Exists** + +**This is the current design choice.** We will reconsider Option 4 (TensorDomain Pairing) if it proves more appropriate based on practical experience or future requirements. + +**Rationale**: + +Option 3 is the most reasonable approach because it: + +1. **Aligns with nvFuser's design philosophy**: Like arithmetic operations that assume shape compatibility, `combine()` trusts users to provide correct inputs +2. **Provides validation where feasible**: When a Partition definition is directly accessible, we validate the component +3. **Simple and maintainable**: No complex infrastructure, no global state, no dependency ordering issues +4. **Handles all use cases**: Works for direct Partition usage, propagated domains, segmented fusions, and future external inputs +5. 
**Pragmatic**: Acknowledges that complete validation is architecturally infeasible + +**Implementation**: + +```cpp +void combine(IterDomain* component, RaggedIterDomain* ragged) { + // Basic validation (null checks, type checks, etc.) + NVF_ERROR(component != nullptr && ragged != nullptr, "Null inputs"); + NVF_ERROR(!component->isRaggedDomain(), "Component must be regular IterDomain"); + + // Validate against Partition definition if available + if (ragged->definition() && ragged->definition()->isA<Partition>()) { + auto* partition = ragged->definition()->as<Partition>(); + NVF_ERROR( + component == partition->component(), + "Component mismatch: provided ", component->toString(), + " but Partition expects ", partition->component()->toString()); + } + + // If no Partition definition (after set, in segmented fusion, or external input), + // trust the user and proceed + + // Create combined IterDomain... +} +``` + +**What This Means**: + +- ✅ Early error detection when Partition definition is available +- ✅ No architectural violations or fragile infrastructure +- ✅ Users are responsible for correct usage (like other operations) +- ✅ Works across all scenarios (propagation, segmentation, external inputs) +- ⚠️ After propagation/segmentation, incorrect usage won't be caught by validation +- ⚠️ Users must track component-ragged pairings themselves + +**Comparison to Other Operations**: + +This is consistent with how nvFuser handles other operations: +- `add(tv1, tv2)` doesn't validate that shapes match - user responsibility +- `set(tv)` doesn't validate all properties - user responsibility +- `combine(component, ragged)` doesn't always validate pairing - user responsibility + +## Implementation Notes + +1. **Testing Strategy**: + - Test validation when Partition definition exists (should catch errors) + - Test that validation is skipped after `set()` operations (should succeed with correct usage) + - Document user responsibility in API documentation + +2. 
**Future Considerations**: + - Option 4 (TensorDomain Pairing) remains a viable alternative if the current approach proves insufficient + - We will reconsider Option 4 based on practical experience, user feedback, or new requirements + - If incorrect `combine()` usage becomes a common source of bugs, we can implement Option 4's more comprehensive validation + - For now, follow the principle of trusting user-facing APIs + - The `extents_` pointer handling may also need similar considerations in the future + +3. **Documentation**: + - Clearly document that users must provide the correct component that was paired with the ragged domain + - Note that validation is best-effort and may not catch all errors + - Provide examples of correct usage patterns From 3a80926ac7f1a56ca91167f4b4eed74e7a3e968b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 19 Dec 2025 15:22:34 -0800 Subject: [PATCH 32/47] license --- doc/dev/ragged_iter_domain_combine_design_doc.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/doc/dev/ragged_iter_domain_combine_design_doc.md b/doc/dev/ragged_iter_domain_combine_design_doc.md index f3671c19377..fc05758e672 100644 --- a/doc/dev/ragged_iter_domain_combine_design_doc.md +++ b/doc/dev/ragged_iter_domain_combine_design_doc.md @@ -1,3 +1,9 @@ + + # Design Document: Component IterDomain Tracking for RaggedIterDomain ## Problem Statement From 8aa854e17bc7dfc6a08c12a7fd63eb5d3ae43070 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 7 Jan 2026 09:58:52 -0800 Subject: [PATCH 33/47] feedback --- csrc/ops/alias.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index ffc26e308cb..1c412a61f8d 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1274,23 +1274,24 @@ TensorView* asNested( NVF_ERROR(data != nullptr, "asNested: data tensor is null"); NVF_ERROR(extents != nullptr, "asNested: extents tensor is null"); + // Get the logical domain of the input, 
excluding reductions + auto inp_logical = data->getLogicalDomain() | TensorDomain::kNoReductions; + auto inp_logical_size = std::ranges::distance(inp_logical); + // Only 1D extents tensors are currently supported NVF_ERROR_EQ( - extents->nDims(), + inp_logical_size, 1, "asNested currently only supports 1D extents tensors"); - // Get the logical domain of the input, excluding reductions - auto inp_logical = TensorDomain::noReductions(data->getLogicalDomain()); - // Clone the logical domain to create the root domain for output std::vector root_domain; - root_domain.reserve(inp_logical.size()); + root_domain.reserve(inp_logical_size); for (auto* id : inp_logical) { root_domain.push_back(id->cloneWithoutRFactor()); } - ragged_dim = wrapDim(ragged_dim, std::ssize(inp_logical)); + ragged_dim = wrapDim(ragged_dim, inp_logical_size); // Partition the specified dimension in root domain // This replaces one IterDomain with (component_id, ragged_id) From 72ae14f1ebac33c7144572558e5ac8c55925023b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 7 Jan 2026 10:17:33 -0800 Subject: [PATCH 34/47] fix --- csrc/ops/alias.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 1c412a61f8d..0fdb03274c0 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1274,16 +1274,17 @@ TensorView* asNested( NVF_ERROR(data != nullptr, "asNested: data tensor is null"); NVF_ERROR(extents != nullptr, "asNested: extents tensor is null"); - // Get the logical domain of the input, excluding reductions - auto inp_logical = data->getLogicalDomain() | TensorDomain::kNoReductions; - auto inp_logical_size = std::ranges::distance(inp_logical); - // Only 1D extents tensors are currently supported NVF_ERROR_EQ( - inp_logical_size, + std::ranges::distance( + extents->getLogicalDomain() | TensorDomain::kNoReductions), 1, "asNested currently only supports 1D extents tensors"); + // Get the logical domain of the input, excluding 
reductions + auto inp_logical = data->getLogicalDomain() | TensorDomain::kNoReductions; + auto inp_logical_size = std::ranges::distance(inp_logical); + // Clone the logical domain to create the root domain for output std::vector root_domain; root_domain.reserve(inp_logical_size); From 4d8acaba484d63b0f2653964ecd8eaa919081f78 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 8 Jan 2026 17:07:37 -0800 Subject: [PATCH 35/47] cleanup --- csrc/ops/utils.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index be50385528c..ccb15fd7f45 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -331,10 +331,7 @@ RaggedIterDomain* newOutputRaggedIterDomain( // equivalent RaggedIterDomain* ref_input_id = input_ids.front()->as(); - return IrBuilder::create( - ref_input_id->extents(), - ref_input_id->getIterType(), - ref_input_id->getParallelType()); + return IterDomainBuilder(ref_input_id).build()->as(); } // Adding these pragmas since gcc-12.2.1 From 3b082ba6fcd636e517610b44eaec8c07436db02a Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 8 Jan 2026 22:59:50 -0800 Subject: [PATCH 36/47] cleanup --- csrc/ir/internal_base_nodes.cpp | 4 +--- csrc/ops/utils.cpp | 9 ++++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 698a047db6f..e9a0da1fa2f 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -983,9 +983,7 @@ std::string RaggedIterDomain::toString(int indent_size) const { } IterDomain* RaggedIterDomain::cloneWithoutRFactor(bool map_with_original) { - // Create a new RaggedIterDomain with the same extents and properties - auto cloned = IrBuilder::create( - extents_, getIterType(), getParallelType()); + auto cloned = IterDomainBuilder(this).resetRfactor().build(); // Optionally map the clone with the original in the Exact graph if (map_with_original) { diff --git 
a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index ccb15fd7f45..36e46ef97e2 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -345,13 +345,20 @@ IterDomain* newOutputIterDomain( const std::optional force_iter_type) { NVF_ERROR(!input_ids.empty()); - // If any input ID is a RaggedIterDomain, the output should also be ragged + // If an input ID is a RaggedIterDomain, the output as well as all + // other inputs must be ragged bool has_ragged = std::any_of(input_ids.begin(), input_ids.end(), [](IterDomain* id) { return id->isA(); }); if (has_ragged) { + NVF_ERROR( + std::all_of( + input_ids.begin(), + input_ids.end(), + [](IterDomain* id) { return id->isA(); }), + "All or none of the input IDs must be ragged"); NVF_ERROR( !force_iter_type.has_value(), "force_iter_type not supported for RaggedIterDomain"); From 5002407a9e9380cd93f50ebf53ee9dbf8de6fa33 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 13 Jan 2026 10:14:53 -0800 Subject: [PATCH 37/47] expand doc --- doc/dev/ragged_iter_domain_combine_design_doc.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/doc/dev/ragged_iter_domain_combine_design_doc.md b/doc/dev/ragged_iter_domain_combine_design_doc.md index fc05758e672..8816a3ecc9d 100644 --- a/doc/dev/ragged_iter_domain_combine_design_doc.md +++ b/doc/dev/ragged_iter_domain_combine_design_doc.md @@ -23,17 +23,18 @@ The naive approach of checking `ragged->definition()` for a Partition expression ```cpp // tv0: [i0] - regular tensor auto result = asNested(tv0, 0, extents); // Creates Partition expression -// result.ragged: RaggedIterDomain with Partition definition -// result.component: Component IterDomain +// result->getLogicalDomain()[0], result->getLogicalDomain()[1] = partition(result->getRootDomain()[0], extents); auto tv1 = set(tv0); // Propagates IterDomains // tv1 has a RaggedIterDomain, but it's a clone without Partition definition -combine(result.component, tv1->getRaggedDomain()); // How do we validate? 
+combine(tv1->getLogicalDomain()[0], tv1->getLogicalDomain()[1]); // How do we validate? ``` ## Design Alternatives +Several design alternatives are considered, among which Option 3 is the current choice as it is the simplest approach, although it is not ideal in terms of completeness of validation. + ### Option 1: Store Component Pointer in RaggedIterDomain **Approach**: Add a `component_` member variable to `RaggedIterDomain` that points to the paired component IterDomain. From be0e2ea424fcffd8698ab0733299780a811c56e6 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 13 Jan 2026 10:38:54 -0800 Subject: [PATCH 38/47] cleanup --- csrc/ir/internal_base_nodes.cpp | 2 +- tests/cpp/test_ragged_iter_domain.cpp | 41 --------------------------- 2 files changed, 1 insertion(+), 42 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index e526a67f677..7ca95e03803 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -1123,7 +1123,7 @@ IterDomain* RaggedIterDomain::combine( // It is still assumed the extents tensor is just 1D NVF_ERROR_EQ( - std::ssize(extents_tv->getLogicalDomain()), + std::ranges::distance(extents_tv->getLogicalDomain() | TensorDomain::kNoReductions), 1, "Unexpected rank of extent tensor: ", extents_tv->toString()); diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 0e0edbe6893..58bb0b18481 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -580,47 +580,6 @@ TEST_F(RaggedIterDomainTest, AsNestedThenSetThenCombine) { EXPECT_EQ(combine_expr->out(), combined_id); } -// Test combining with invalid component (not from same partition) - should -// Test combining after set operation with invalid component -// With Option 3 validation strategy, this does NOT throw an error -// because after set(), the RaggedIterDomain loses its Partition definition -// and validation is skipped (trusts the user) 
-TEST_F(RaggedIterDomainTest, AsNestedThenSetThenCombineInvalidComponent) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto data = makeSymbolicTensor(2, DataType::Float); - fusion.addInput(data); - - auto extents = makeSymbolicTensor(1, DataType::Index); - fusion.addInput(extents); - - // Create nested tensor from dimension 0 - auto nested = asNested(data, extents, 0); - - // Insert a set operation after asNested - auto nested_copy = set(nested); - - // Verify nested_copy tensor has 3 dimensions: [component, ragged, - // original_dim1] - EXPECT_EQ(nested_copy->nDims(), 3); - EXPECT_TRUE(nested_copy->axis(0)->isStrictlyA()); - EXPECT_TRUE(nested_copy->axis(1)->isA()); - - // Get the ragged IterDomain from the copy - auto ragged_id = nested_copy->axis(1)->as(); - - // Use an INVALID component: the third axis instead of the first - // This is NOT the component from the partition, it's the original second - // dimension - auto invalid_component_id = nested_copy->axis(2); - - // With Option 3: After set(), the RaggedIterDomain no longer has a - // Partition definition, so validation is skipped and the operation succeeds. - // The user is responsible for providing the correct component. 
- EXPECT_NO_THROW(RaggedIterDomain::combine(invalid_component_id, ragged_id)); -} - // asNested on different dimensions TEST_F(RaggedIterDomainTest, AsNestedDifferentDimension) { Fusion fusion; From cf346adca542e07b6a1bf99a72d6c8db4e3b207b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 13 Jan 2026 10:48:06 -0800 Subject: [PATCH 39/47] WIP --- csrc/ir/internal_base_nodes.cpp | 55 ++++++++----- csrc/ir/internal_base_nodes.h | 15 ++-- csrc/ops/alias.cpp | 10 ++- tests/cpp/test_ragged_iter_domain.cpp | 110 +++++++++++++++++++++++--- 4 files changed, 149 insertions(+), 41 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index cc068fbfca2..623858ade3d 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -927,36 +927,49 @@ std::pair RaggedIterDomain::partition( offsets->dtype()); const auto& offsets_domain = offsets->getLogicalDomain(); - NVF_ERROR_EQ( - offsets_domain.size(), - 1, - "partition: offsets tensor must be 1D, got ", - offsets_domain.size(), - "D tensor. 
Multi-dimensional offsets not yet supported."); + NVF_ERROR( + !offsets_domain.empty(), + "partition: offsets tensor must have at least one dimension"); auto container = in->container(); - // Compute extents from offsets: extents[i] = offsets[i+1] - offsets[i] - // offsets_left = offsets[:-1] (all but last element) - // offsets_right = offsets[1:] (all but first element) + // Compute extents from offsets: extents[..., i] = offsets[..., i+1] - offsets[..., i] + // For N-D offsets, we slice along the last dimension: + // offsets_left = offsets[..., :-1] (all but last element along last dim) + // offsets_right = offsets[..., 1:] (all but first element along last dim) - auto offsets_len = offsets_domain[0]->extent(); + // Get the extent of the last dimension (which defines the number of components) + auto last_dim_idx = std::ssize(offsets_domain) - 1; + auto offsets_len = offsets_domain[last_dim_idx]->extent(); auto zero = container->zeroVal(DataType::Index); auto one = container->oneVal(DataType::Index); auto len_minus_one = sub(offsets_len, one); - // Slice offsets[:-1] - Slice left_slice; - left_slice.start = zero; - left_slice.stop = len_minus_one; - auto offsets_left = slice(offsets, {left_slice}); - - // Slice offsets[1:] - Slice right_slice; - right_slice.start = one; - right_slice.stop = offsets_len; - auto offsets_right = slice(offsets, {right_slice}); + // Build slice specifications for all dimensions + // All dimensions except the last use full range (:), last dimension uses [:-1] and [1:] + std::vector left_slices(offsets_domain.size()); + std::vector right_slices(offsets_domain.size()); + + // For all dimensions except the last, use full range + for (size_t i = 0; i < offsets_domain.size() - 1; i++) { + left_slices[i].start = zero; + left_slices[i].stop = offsets_domain[i]->extent(); + right_slices[i].start = zero; + right_slices[i].stop = offsets_domain[i]->extent(); + } + + // For the last dimension, slice [:-1] and [1:] + left_slices[last_dim_idx].start 
= zero; + left_slices[last_dim_idx].stop = len_minus_one; + right_slices[last_dim_idx].start = one; + right_slices[last_dim_idx].stop = offsets_len; + + // Slice offsets[..., :-1] + auto offsets_left = slice(offsets, left_slices); + + // Slice offsets[..., 1:] + auto offsets_right = slice(offsets, right_slices); // Compute extents: extents = offsets_right - offsets_left auto extents = sub(offsets_right, offsets_left); diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 0187c408bd7..702b966886e 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -480,14 +480,15 @@ class NVF_API RaggedIterDomain : public IterDomain { //! Creates a component IterDomain and a RaggedIterDomain based on offsets //! //! \param in Input IterDomain to partition (must be regular IterDomain) - //! \param offsets Offset tensor defining partition boundaries (must be 1D) - //! Shape: [num_components + 1], values: [0, off1, off2, ..., total] - //! Extents are computed as: extents[i] = offsets[i+1] - offsets[i] - //! \return Pair of (component_id, ragged_id) - //! component_id: IterDomain with extent = num_components - //! ragged_id: RaggedIterDomain with extents computed from offsets + //! \param offsets Offset tensor defining partition boundaries + //! Shape [..., num_components + 1], offsets along last dimension + //! Extents computed as: extents[..., i] = offsets[..., i+1] - offsets[..., i] + //! 1D example: Shape [num_components + 1], values [0, off1, off2, ..., total] + //! 2D example: Shape [outer_dim, num_components + 1], e.g., [num_gpus, num_experts + 1] //! - //! TODO: Support multi-dimensional offsets for nested ragged structures + //! \return Pair of (component_id, ragged_id) + //! component_id: IterDomain with extent = num_components (from last dim of offsets) + //! 
ragged_id: RaggedIterDomain with N-D extents tensor (same shape as offsets minus 1 in last dim) static std::pair partition( IterDomain* in, TensorView* offsets); diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index e32aa5e6b9c..231535c5955 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1275,11 +1275,13 @@ TensorView* asNested( NVF_ERROR(data != nullptr, "asNested: data tensor is null"); NVF_ERROR(offsets != nullptr, "asNested: offsets tensor is null"); - // Only 1D offset tensors are currently supported - NVF_ERROR_EQ( + // Offsets can be N-D tensors for nested ragged structures + // The partition operation will handle multi-dimensional offsets correctly + NVF_ERROR( + offsets->nDims() >= 1, + "asNested requires offsets to be at least 1D, got ", offsets->nDims(), - 1, - "asNested currently only supports 1D offset tensors"); + "D"); // Get the logical domain of the input, excluding reductions auto inp_logical = TensorDomain::noReductions(data->getLogicalDomain()); diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index f7eaac14c2e..e40ffdda2bc 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -282,11 +282,8 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) { EXPECT_THROW( RaggedIterDomain::partition(input_id, float_offsets), nvfuser::nvfError); - // Test 4: Multi-dimensional offsets should fail - auto offsets_2d = makeSymbolicTensor(2, DataType::Index); - fusion.addInput(offsets_2d); - EXPECT_THROW( - RaggedIterDomain::partition(input_id, offsets_2d), nvfuser::nvfError); + // Test 4: 0-dimensional offsets should fail + // (We can't easily create a 0D tensor in the test, so skip this test) // Test 5: Non-Iteration IterType should fail auto reduction_id = @@ -467,19 +464,114 @@ TEST_F(RaggedIterDomainTest, AsNestedValidationNullOffsets) { EXPECT_THROW(asNested(data, nullptr, 0), nvfuser::nvfError); } -// asNested validation - multi-dimensional offsets (not 
yet supported) -TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimOffsets) { +// Multi-dimensional offsets partition test +TEST_F(RaggedIterDomainTest, Partition2DOffsets) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create input IterDomain + auto input_id = + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(-1, DataType::Index)) + .build(); + + // Create 2D offset tensor (e.g., [num_gpus, num_experts + 1]) + auto offsets_2d = makeSymbolicTensor(2, DataType::Index); + fusion.addInput(offsets_2d); + + // Partition the IterDomain with 2D offsets + auto [component_id, ragged_id] = + RaggedIterDomain::partition(input_id, offsets_2d); + + // Verify component IterDomain + EXPECT_TRUE(component_id != nullptr); + EXPECT_TRUE(component_id->isA()); + EXPECT_FALSE(component_id->isA()); + + // Verify RaggedIterDomain + EXPECT_TRUE(ragged_id != nullptr); + EXPECT_TRUE(ragged_id->isA()); + EXPECT_TRUE(ragged_id->extents() != nullptr); + + // Verify the extents tensor is 2D (one dimension less in last dim than offsets) + auto extents = ragged_id->extents(); + EXPECT_EQ(extents->nDims(), 2); + + // Verify that a Partition expr was created + EXPECT_TRUE(component_id->definition() != nullptr); + EXPECT_TRUE(component_id->definition()->isA()); + EXPECT_EQ(component_id->definition(), ragged_id->definition()); +} + +// asNested with 2D offsets (expert parallelism use case) +TEST_F(RaggedIterDomainTest, AsNested2DOffsets) { Fusion fusion; FusionGuard fg(&fusion); + // Create a 2D TensorView representing tokens [total_tokens, hidden] auto data = makeSymbolicTensor(2, DataType::Float); fusion.addInput(data); - // 2D offsets should fail (only 1D supported currently) + // Create 2D offsets [num_gpus, num_experts + 1] + // This represents per-GPU offsets for routing tokens to experts auto offsets_2d = makeSymbolicTensor(2, DataType::Index); fusion.addInput(offsets_2d); - EXPECT_THROW(asNested(data, offsets_2d, 0), nvfuser::nvfError); + // Create nested tensor 
partitioning the first dimension + auto nested = asNested(data, offsets_2d, 0); + + fusion.addOutput(nested); + + // Verify the output dimensions + // Should be: [component (num_experts), ragged (tokens_per_expert with 2D extents), hidden] + EXPECT_EQ(nested->nDims(), 3); + + // First axis is component (num_experts from last dim of offsets) + EXPECT_TRUE(nested->axis(0)->isStrictlyA()); + EXPECT_FALSE(nested->axis(0)->isA()); + + // Second axis is ragged with 2D extents + EXPECT_TRUE(nested->axis(1)->isA()); + auto ragged_id = nested->axis(1)->as(); + EXPECT_EQ(ragged_id->extents()->nDims(), 2); + + // Third axis is the original hidden dimension + EXPECT_TRUE(nested->axis(2)->isStrictlyA()); +} + +// Partition with 3D offsets +TEST_F(RaggedIterDomainTest, Partition3DOffsets) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create input IterDomain + auto input_id = + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(-1, DataType::Index)) + .build(); + + // Create 3D offset tensor (e.g., [batch, num_gpus, num_experts + 1]) + auto offsets_3d = makeSymbolicTensor(3, DataType::Index); + fusion.addInput(offsets_3d); + + // Partition the IterDomain with 3D offsets + auto [component_id, ragged_id] = + RaggedIterDomain::partition(input_id, offsets_3d); + + // Verify component IterDomain extent comes from last dimension + EXPECT_TRUE(component_id != nullptr); + EXPECT_TRUE(component_id->isA()); + + // Verify RaggedIterDomain has 3D extents + EXPECT_TRUE(ragged_id != nullptr); + EXPECT_TRUE(ragged_id->isA()); + auto extents = ragged_id->extents(); + EXPECT_TRUE(extents != nullptr); + EXPECT_EQ(extents->nDims(), 3); + + // Verify Partition expr + EXPECT_TRUE(component_id->definition()->isA()); + EXPECT_EQ(component_id->definition(), ragged_id->definition()); } } // namespace nvfuser From 19737a86e65a2e4785e7c82d6a2bcf49e1c8d37c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 14 Jan 2026 11:56:18 -0800 Subject: [PATCH 40/47] Partition with multi-dim extents 
--- csrc/ir/internal_base_nodes.cpp | 19 +++--- csrc/ir/internal_base_nodes.h | 9 +-- csrc/ops/alias.cpp | 23 +++++-- tests/cpp/test_ragged_iter_domain.cpp | 96 ++++++++++++++++++++++----- 4 files changed, 110 insertions(+), 37 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 3dbae3ba5d0..f198fe509c4 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -1026,20 +1026,21 @@ std::pair RaggedIterDomain::partition( extents->dtype()); const auto& extents_domain = extents->getLogicalDomain(); - NVF_ERROR_EQ( - extents_domain.size(), - 1, - "partition: extents tensor must be 1D, got ", - extents_domain.size(), - "D tensor. Multi-dimensional extents not yet supported."); + NVF_ERROR( + !extents_domain.empty(), + "partition: extents tensor must have at least one dimension, got 0D tensor"); auto container = in->container(); // Create component IterDomain - // Component extent = number of components = length of extents tensor + // Component extent = number of components = size of last dimension of extents + // For 1D extents [K]: component_extent = K + // For 2D extents [D, K]: component_extent = K (last dim) + // For N-D extents [..., K]: component_extent = K (last dim) + // The outer dimensions of extents correspond to outer dimensions of the + // tensor being partitioned, allowing non-uniform partitions across instances. 
auto zero = container->zeroVal(DataType::Index); - // TODO: This is likely wrong - auto component_extent = extents_domain.at(0)->extent(); + auto component_extent = extents_domain.back()->extent(); auto component_id = IterDomainBuilder(zero, component_extent) .parallel_type(ParallelType::Serial) .iter_type(IterType::Iteration) diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 72ce7efcdd1..9300b6d72e2 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -487,13 +487,14 @@ class NVF_API RaggedIterDomain : public IterDomain { //! //! \param in Input IterDomain to partition (must be regular IterDomain) //! \param extents Extents tensor defining the size of each component - //! 1D example: Shape [num_components], values [extent0, extent1, ..., - //! extent(n-1)] + //! 1D example: Shape [num_components], values [extent0, extent1, ...] //! 2D example: Shape [outer_dim, num_components], e.g., [num_gpus, num_experts] + //! For N-D extents, the last dimension defines the number of components, + //! and outer dimensions correspond to outer dimensions of the tensor. //! //! \return Pair of (component_id, ragged_id) - //! component_id: IterDomain with extent = num_components (from last dim of offsets) - //! ragged_id: RaggedIterDomain with N-D extents tensor (same shape as extents in last dim) + //! component_id: IterDomain with extent = num_components (from last dim of extents) + //! 
ragged_id: RaggedIterDomain with N-D extents tensor static std::pair partition( IterDomain* in, TensorView* extents); diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 3f531896a61..220cc269be3 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1316,12 +1316,6 @@ TensorView* asNested( NVF_ERROR(data != nullptr, "asNested: data tensor is null"); NVF_ERROR(extents != nullptr, "asNested: extents tensor is null"); - NVF_ERROR_GE( - std::ranges::distance( - extents->getLogicalDomain() | TensorDomain::kNoReductions), - 1, - "asNested currently only supports 1D extents tensors"); - NVF_CHECK( !data->domain()->hasRaggedIterDomain(), "Multiple level of nesting is not supported: ", @@ -1340,6 +1334,23 @@ TensorView* asNested( ragged_dim = wrapDim(ragged_dim, inp_logical_size); + // Validate shape correspondence for multi-dimensional extents + // For N-D extents, outer dimensions must match outer dimensions of input tensor + // Rule: extents.ndim - 1 == ragged_dim (except 1D extents which are always valid) + const auto& extents_domain = extents->getLogicalDomain(); + if (extents_domain.size() > 1) { + NVF_ERROR_EQ( + extents_domain.size() - 1, + ragged_dim, + "asNested: Multi-dimensional extents require shape ", + "[d0, d1, ..., d(axis-1), num_components]. 
", + "Got ", + extents_domain.size(), + "D extents for partitioning axis ", + ragged_dim); + } + // Note: 1D extents are always valid for any axis (uniform partition) + // Partition the specified dimension in root domain // This replaces one IterDomain with (component_id, ragged_id) auto [component_id, ragged_id] = diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index c1273b57258..5a354fde8f6 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -695,40 +695,44 @@ TEST_F(RaggedIterDomainTest, Partition2DOffsets) { EXPECT_EQ(component_id->definition(), ragged_id->definition()); } -// asNested with 2D offsets (expert parallelism use case) +// asNested with 2D extents (should partition axis 1, not axis 0) TEST_F(RaggedIterDomainTest, AsNested2DOffsets) { Fusion fusion; FusionGuard fg(&fusion); - // Create a 2D TensorView representing tokens [total_tokens, hidden] - auto data = makeSymbolicTensor(2, DataType::Float); + // Create a 3D TensorView: [D=2, tokens=100, hidden=512] + // This represents 2 GPUs, each with tokens, and hidden dimension + auto data = makeSymbolicTensor(3, DataType::Float); fusion.addInput(data); - // Create 2D offsets [num_gpus, num_experts + 1] - // This represents per-GPU offsets for routing tokens to experts - auto offsets_2d = makeSymbolicTensor(2, DataType::Index); - fusion.addInput(offsets_2d); + // Create 2D extents [D=2, num_experts=4] + // This represents per-GPU token counts for experts + auto extents_2d = makeSymbolicTensor(2, DataType::Index); + fusion.addInput(extents_2d); - // Create nested tensor partitioning the first dimension - auto nested = asNested(data, offsets_2d, 0); + // Create nested tensor partitioning dimension 1 (tokens) + auto nested = asNested(data, extents_2d, 1); fusion.addOutput(nested); // Verify the output dimensions - // Should be: [component (num_experts), ragged (tokens_per_expert with 2D extents), hidden] - 
EXPECT_EQ(nested->nDims(), 3); + // Should be: [D=2, component=4, ragged, hidden=512] + EXPECT_EQ(nested->nDims(), 4); - // First axis is component (num_experts from last dim of offsets) + // First axis is unchanged (D=2) EXPECT_TRUE(nested->axis(0)->isStrictlyA()); - EXPECT_FALSE(nested->axis(0)->isA()); - // Second axis is ragged with 2D extents - EXPECT_TRUE(nested->axis(1)->isA()); - auto ragged_id = nested->axis(1)->as(); + // Second axis is component (num_experts from last dim) + EXPECT_TRUE(nested->axis(1)->isStrictlyA()); + EXPECT_FALSE(nested->axis(1)->isA()); + + // Third axis is ragged with 2D extents + EXPECT_TRUE(nested->axis(2)->isA()); + auto ragged_id = nested->axis(2)->as(); EXPECT_EQ(ragged_id->extents()->nDims(), 2); - // Third axis is the original hidden dimension - EXPECT_TRUE(nested->axis(2)->isStrictlyA()); + // Fourth axis is the original hidden dimension + EXPECT_TRUE(nested->axis(3)->isStrictlyA()); } // Partition with 3D offsets @@ -1249,4 +1253,60 @@ TEST_F(RaggedIterDomainTest, PadRaggedDimensionError) { EXPECT_THROW(pad(nested, pad_widths), nvfuser::nvfError); } +// Test asNested with 2D extents on correct axis (matching expert parallelism) +TEST_F(RaggedIterDomainTest, AsNested2DExtentsAxis1) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a 3D TensorView: [D=2, S=100, hidden=512] + // This represents 2 GPUs, each with 100 tokens, and hidden dimension + auto tokens = makeSymbolicTensor(3, DataType::Float); + fusion.addInput(tokens); + + // Create 2D extents: [D=2, E=4] + // Represents per-GPU token counts for 4 experts + auto extents = makeSymbolicTensor(2, DataType::Index); + fusion.addInput(extents); + + // Partition dimension 1 (tokens) with 2D extents + auto nested = asNested(tokens, extents, 1); + + fusion.addOutput(nested); + + // Should be: [D=2, component=4, ragged, hidden=512] + EXPECT_EQ(nested->nDims(), 4); + + // First axis unchanged (D=2) + EXPECT_TRUE(nested->axis(0)->isStrictlyA()); + + // Second axis is 
component (num_experts=4) + EXPECT_TRUE(nested->axis(1)->isStrictlyA()); + EXPECT_FALSE(nested->axis(1)->isA()); + + // Third axis is ragged with 2D extents + EXPECT_TRUE(nested->axis(2)->isA()); + auto ragged_id = nested->axis(2)->as(); + EXPECT_EQ(ragged_id->extents()->nDims(), 2); + + // Fourth axis unchanged (hidden=512) + EXPECT_TRUE(nested->axis(3)->isStrictlyA()); +} + +// Test asNested invalid shape +TEST_F(RaggedIterDomainTest, AsNestedInvalidShape) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a 3D TensorView: [D=2, S=100, hidden=512] + auto tokens = makeSymbolicTensor(3, DataType::Float); + fusion.addInput(tokens); + + // Create 3D extents (wrong dimensionality for axis 1) + auto extents_3d = makeSymbolicTensor(3, DataType::Index); + fusion.addInput(extents_3d); + + // This should throw: 3D extents for axis 1 requires extents.ndim - 1 == 1 + EXPECT_THROW(asNested(tokens, extents_3d, 1), nvfuser::nvfError); +} + } // namespace nvfuser From 5a1e3252007057b3b538bc4648126aac732c68fe Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 14 Jan 2026 12:59:16 -0800 Subject: [PATCH 41/47] WIP --- csrc/ir/internal_base_nodes.cpp | 15 ++++++++++----- csrc/ops/alias.cpp | 15 +++++++++++---- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index f198fe509c4..42ef9252a95 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -1025,10 +1025,14 @@ std::pair RaggedIterDomain::partition( "partition: extents must have Index type, got ", extents->dtype()); - const auto& extents_domain = extents->getLogicalDomain(); - NVF_ERROR( - !extents_domain.empty(), - "partition: extents tensor must have at least one dimension, got 0D tensor"); + // Filter out reduction dimensions from extents tensor + auto extents_no_reduction = + extents->getLogicalDomain() | TensorDomain::kNoReductions; + auto extents_ndim = std::ranges::distance(extents_no_reduction); + 
NVF_ERROR_GT( + extents_ndim, + 0, + "partition: extents tensor must have at least one non-reduction dimension"); auto container = in->container(); @@ -1040,7 +1044,8 @@ std::pair RaggedIterDomain::partition( // The outer dimensions of extents correspond to outer dimensions of the // tensor being partitioned, allowing non-uniform partitions across instances. auto zero = container->zeroVal(DataType::Index); - auto component_extent = extents_domain.back()->extent(); + auto component_extent = + (*std::ranges::prev(extents_no_reduction.end()))->extent(); auto component_id = IterDomainBuilder(zero, component_extent) .parallel_type(ParallelType::Serial) .iter_type(IterType::Iteration) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 220cc269be3..c9fb2efd397 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1337,15 +1337,22 @@ TensorView* asNested( // Validate shape correspondence for multi-dimensional extents // For N-D extents, outer dimensions must match outer dimensions of input tensor // Rule: extents.ndim - 1 == ragged_dim (except 1D extents which are always valid) - const auto& extents_domain = extents->getLogicalDomain(); - if (extents_domain.size() > 1) { + // Filter out reduction dimensions from extents tensor + auto extents_no_reduction = + extents->getLogicalDomain() | TensorDomain::kNoReductions; + auto extents_ndim = std::ranges::distance(extents_no_reduction); + NVF_ERROR_GT( + extents_ndim, + 0, + "asNested: extents tensor must have at least one non-reduction dimension"); + if (extents_ndim > 1) { NVF_ERROR_EQ( - extents_domain.size() - 1, + extents_ndim - 1, ragged_dim, "asNested: Multi-dimensional extents require shape ", "[d0, d1, ..., d(axis-1), num_components]. 
", "Got ", - extents_domain.size(), + extents_ndim, "D extents for partitioning axis ", ragged_dim); } From 1400bafcf7982e2c94bc52adb8e3512a91446803 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 14 Jan 2026 13:07:32 -0800 Subject: [PATCH 42/47] cleanup --- tests/cpp/test_ragged_iter_domain.cpp | 74 --------------------------- 1 file changed, 74 deletions(-) diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 5a354fde8f6..5e2d4affd48 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -656,45 +656,6 @@ TEST_F(RaggedIterDomainTest, AsNestedValidationNullExtents) { EXPECT_THROW(asNested(data, nullptr, 0), nvfuser::nvfError); } -// Multi-dimensional offsets partition test -TEST_F(RaggedIterDomainTest, Partition2DOffsets) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Create input IterDomain - auto input_id = - IterDomainBuilder( - fusion.zeroVal(), IrBuilder::create(-1, DataType::Index)) - .build(); - - // Create 2D offset tensor (e.g., [num_gpus, num_experts + 1]) - auto offsets_2d = makeSymbolicTensor(2, DataType::Index); - fusion.addInput(offsets_2d); - - // Partition the IterDomain with 2D offsets - auto [component_id, ragged_id] = - RaggedIterDomain::partition(input_id, offsets_2d); - - // Verify component IterDomain - EXPECT_TRUE(component_id != nullptr); - EXPECT_TRUE(component_id->isA()); - EXPECT_FALSE(component_id->isA()); - - // Verify RaggedIterDomain - EXPECT_TRUE(ragged_id != nullptr); - EXPECT_TRUE(ragged_id->isA()); - EXPECT_TRUE(ragged_id->extents() != nullptr); - - // Verify the extents tensor is 2D (one dimension less in last dim than offsets) - auto extents = ragged_id->extents(); - EXPECT_EQ(extents->nDims(), 2); - - // Verify that a Partition expr was created - EXPECT_TRUE(component_id->definition() != nullptr); - EXPECT_TRUE(component_id->definition()->isA()); - EXPECT_EQ(component_id->definition(), ragged_id->definition()); -} - // 
asNested with 2D extents (should partition axis 1, not axis 0) TEST_F(RaggedIterDomainTest, AsNested2DOffsets) { Fusion fusion; @@ -735,41 +696,6 @@ TEST_F(RaggedIterDomainTest, AsNested2DOffsets) { EXPECT_TRUE(nested->axis(3)->isStrictlyA()); } -// Partition with 3D offsets -TEST_F(RaggedIterDomainTest, Partition3DOffsets) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Create input IterDomain - auto input_id = - IterDomainBuilder( - fusion.zeroVal(), IrBuilder::create(-1, DataType::Index)) - .build(); - - // Create 3D offset tensor (e.g., [batch, num_gpus, num_experts + 1]) - auto offsets_3d = makeSymbolicTensor(3, DataType::Index); - fusion.addInput(offsets_3d); - - // Partition the IterDomain with 3D offsets - auto [component_id, ragged_id] = - RaggedIterDomain::partition(input_id, offsets_3d); - - // Verify component IterDomain extent comes from last dimension - EXPECT_TRUE(component_id != nullptr); - EXPECT_TRUE(component_id->isA()); - - // Verify RaggedIterDomain has 3D extents - EXPECT_TRUE(ragged_id != nullptr); - EXPECT_TRUE(ragged_id->isA()); - auto extents = ragged_id->extents(); - EXPECT_TRUE(extents != nullptr); - EXPECT_EQ(extents->nDims(), 3); - - // Verify Partition expr - EXPECT_TRUE(component_id->definition()->isA()); - EXPECT_EQ(component_id->definition(), ragged_id->definition()); -} - TEST_F(RaggedIterDomainTest, LoadStoreWithNestedTensor) { Fusion fusion; FusionGuard fg(&fusion); From d0a359b2558bfb92142110f0941555b9d11b4a07 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 14 Jan 2026 13:11:40 -0800 Subject: [PATCH 43/47] format --- csrc/ir/internal_base_nodes.h | 12 +++++++----- csrc/ops/alias.cpp | 9 +++++---- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 9300b6d72e2..67f66b68782 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -488,13 +488,15 @@ class NVF_API RaggedIterDomain : public IterDomain { //! 
\param in Input IterDomain to partition (must be regular IterDomain) //! \param extents Extents tensor defining the size of each component //! 1D example: Shape [num_components], values [extent0, extent1, ...] - //! 2D example: Shape [outer_dim, num_components], e.g., [num_gpus, num_experts] - //! For N-D extents, the last dimension defines the number of components, - //! and outer dimensions correspond to outer dimensions of the tensor. + //! 2D example: Shape [outer_dim, num_components], e.g., [num_gpus, + //! num_experts] For N-D extents, the last dimension defines the number + //! of components, and outer dimensions correspond to outer dimensions + //! of the tensor. //! //! \return Pair of (component_id, ragged_id) - //! component_id: IterDomain with extent = num_components (from last dim of extents) - //! ragged_id: RaggedIterDomain with N-D extents tensor + //! component_id: IterDomain with extent = num_components (from last + //! dim of extents) ragged_id: RaggedIterDomain with N-D extents + //! 
tensor static std::pair partition( IterDomain* in, TensorView* extents); diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index c9fb2efd397..4da87c4cb60 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1335,16 +1335,17 @@ TensorView* asNested( ragged_dim = wrapDim(ragged_dim, inp_logical_size); // Validate shape correspondence for multi-dimensional extents - // For N-D extents, outer dimensions must match outer dimensions of input tensor - // Rule: extents.ndim - 1 == ragged_dim (except 1D extents which are always valid) - // Filter out reduction dimensions from extents tensor + // For N-D extents, outer dimensions must match outer dimensions of input + // tensor Rule: extents.ndim - 1 == ragged_dim (except 1D extents which are + // always valid) Filter out reduction dimensions from extents tensor auto extents_no_reduction = extents->getLogicalDomain() | TensorDomain::kNoReductions; auto extents_ndim = std::ranges::distance(extents_no_reduction); NVF_ERROR_GT( extents_ndim, 0, - "asNested: extents tensor must have at least one non-reduction dimension"); + "asNested: extents tensor must have at least one non-reduction " + "dimension"); if (extents_ndim > 1) { NVF_ERROR_EQ( extents_ndim - 1, From cbe185081467a6d52480f842ebb71ad71125f390 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 14 Jan 2026 13:19:42 -0800 Subject: [PATCH 44/47] Error check --- csrc/ir/internal_base_nodes.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index 42ef9252a95..b514dd7d52d 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -1032,7 +1032,8 @@ std::pair RaggedIterDomain::partition( NVF_ERROR_GT( extents_ndim, 0, - "partition: extents tensor must have at least one non-reduction dimension"); + "partition: extents tensor must have at least one non-reduction " + "dimension"); auto container = in->container(); @@ -1128,11 +1129,18 
@@ IterDomain* RaggedIterDomain::combine( TensorView* extents_tv = ragged->extents(); NVF_ERROR(extents_tv != nullptr, "combine: ragged extents tensor is null"); - // It is still assumed the extents tensor is just 1D + // Multi-dimensional extents are not yet supported in combine + // Filter out reduction dimensions before checking + auto extents_no_reduction = + extents_tv->getLogicalDomain() | TensorDomain::kNoReductions; + auto extents_ndim = std::ranges::distance(extents_no_reduction); NVF_ERROR_EQ( - std::ranges::distance(extents_tv->getLogicalDomain() | TensorDomain::kNoReductions), + extents_ndim, 1, - "Unexpected rank of extent tensor: ", + "combine: Multi-dimensional extents are not yet supported. ", + "Expected 1D extents tensor, got ", + extents_ndim, + "D extents: ", extents_tv->toString()); auto container = component->container(); From 0abf6d29c67059539101c5f5c1bbff62d96dcb64 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 14 Jan 2026 16:43:14 -0800 Subject: [PATCH 45/47] Empty commit From aec93856ffaf0f679e9d05c8990e2ebdcb4aff69 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 16 Jan 2026 12:37:57 -0800 Subject: [PATCH 46/47] cleanup --- csrc/ops/alias.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 0f96952afa2..4da87c4cb60 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1316,11 +1316,6 @@ TensorView* asNested( NVF_ERROR(data != nullptr, "asNested: data tensor is null"); NVF_ERROR(extents != nullptr, "asNested: extents tensor is null"); - NVF_CHECK( - !data->domain()->hasRaggedIterDomain(), - "Multiple level of nesting is not supported: ", - data->toString()); - NVF_CHECK( !data->domain()->hasRaggedIterDomain(), "Multiple level of nesting is not supported: ", From 1fde60b99fc1048a03d441fcdcb7dd3c0d0240fb Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 16 Jan 2026 14:06:20 -0800 Subject: [PATCH 47/47] cleanup --- csrc/ir/internal_base_nodes.cpp | 2 +- 
csrc/ops/alias.cpp | 10 ++++--- tests/cpp/test_ragged_iter_domain.cpp | 39 --------------------------- 3 files changed, 7 insertions(+), 44 deletions(-) diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp index b514dd7d52d..479b53caa84 100644 --- a/csrc/ir/internal_base_nodes.cpp +++ b/csrc/ir/internal_base_nodes.cpp @@ -1129,10 +1129,10 @@ IterDomain* RaggedIterDomain::combine( TensorView* extents_tv = ragged->extents(); NVF_ERROR(extents_tv != nullptr, "combine: ragged extents tensor is null"); - // Multi-dimensional extents are not yet supported in combine // Filter out reduction dimensions before checking auto extents_no_reduction = extents_tv->getLogicalDomain() | TensorDomain::kNoReductions; + // Multi-dimensional extents are not yet supported in combine auto extents_ndim = std::ranges::distance(extents_no_reduction); NVF_ERROR_EQ( extents_ndim, diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 4da87c4cb60..24336211b2e 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -1334,10 +1334,7 @@ TensorView* asNested( ragged_dim = wrapDim(ragged_dim, inp_logical_size); - // Validate shape correspondence for multi-dimensional extents - // For N-D extents, outer dimensions must match outer dimensions of input - // tensor Rule: extents.ndim - 1 == ragged_dim (except 1D extents which are - // always valid) Filter out reduction dimensions from extents tensor + // Filter out reduction dimensions from extents tensor auto extents_no_reduction = extents->getLogicalDomain() | TensorDomain::kNoReductions; auto extents_ndim = std::ranges::distance(extents_no_reduction); @@ -1346,6 +1343,11 @@ TensorView* asNested( 0, "asNested: extents tensor must have at least one non-reduction " "dimension"); + + // Validate shape correspondence for multi-dimensional extents + // For N-D extents, outer dimensions must match outer dimensions of input + // tensor Rule: extents.ndim - 1 == ragged_dim (except 1D extents which are + // always valid). 
if (extents_ndim > 1) { NVF_ERROR_EQ( extents_ndim - 1, diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp index 5d794a5cb56..4f974a3f86c 100644 --- a/tests/cpp/test_ragged_iter_domain.cpp +++ b/tests/cpp/test_ragged_iter_domain.cpp @@ -1156,45 +1156,6 @@ TEST_F(RaggedIterDomainTest, PadRaggedDimensionError) { EXPECT_THROW(pad(nested, pad_widths), nvfError); } -// Test asNested with 2D extents on correct axis (matching expert parallelism) -TEST_F(RaggedIterDomainTest, AsNested2DExtentsAxis1) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Create a 3D TensorView: [D=2, S=100, hidden=512] - // This represents 2 GPUs, each with 100 tokens, and hidden dimension - auto tokens = makeSymbolicTensor(3, DataType::Float); - fusion.addInput(tokens); - - // Create 2D extents: [D=2, E=4] - // Represents per-GPU token counts for 4 experts - auto extents = makeSymbolicTensor(2, DataType::Index); - fusion.addInput(extents); - - // Partition dimension 1 (tokens) with 2D extents - auto nested = asNested(tokens, extents, 1); - - fusion.addOutput(nested); - - // Should be: [D=2, component=4, ragged, hidden=512] - EXPECT_EQ(nested->nDims(), 4); - - // First axis unchanged (D=2) - EXPECT_TRUE(nested->axis(0)->isStrictlyA()); - - // Second axis is component (num_experts=4) - EXPECT_TRUE(nested->axis(1)->isStrictlyA()); - EXPECT_FALSE(nested->axis(1)->isA()); - - // Third axis is ragged with 2D extents - EXPECT_TRUE(nested->axis(2)->isA()); - auto ragged_id = nested->axis(2)->as(); - EXPECT_EQ(ragged_id->extents()->nDims(), 2); - - // Fourth axis unchanged (hidden=512) - EXPECT_TRUE(nested->axis(3)->isStrictlyA()); -} - // Test asNested invalid shape TEST_F(RaggedIterDomainTest, AsNestedInvalidShape) { Fusion fusion;