Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions csrc/id_model/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
namespace nvfuser {

TensorIndexer::TensorIndexer(IdModel& id_model) : id_model_(id_model) {
NVF_ERROR(isSupported(id_model.fusion()));

buildLoopIndexMap();

if (isDebugDumpEnabled(DebugDumpOption::IndexingVerbose)) {
Expand Down
1 change: 1 addition & 0 deletions csrc/scheduler/expr_eval_sched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) {
// TODO: remove IndexPutAccumulateOp
if (exprs.front()
->isOneOf<
GatherOp,
ScatterOp,
SdpaFwdOp,
SdpaBwdOp,
Expand Down
10 changes: 10 additions & 0 deletions csrc/scheduler/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ bool checkCanSchedule(Fusion* fusion, SchedulerType scheduler_type) {
return false;
}

// Support of non-exact gather was dropped when the legacy indexer was
// deprecated
if (std::ranges::any_of(
ir_utils::getOpsOfType<GatherOp>(fusion),
[](GatherOp* gather) { return !gather->exactSizes(); })) {
scheduler_debug_utils::canScheduleRejectReason(
scheduler_type, "Non-exact gather ops");
return false;
}

// Fusions with `MatmulOp, LinearOp, MmaOp` can only be accepted by Matmul
// scheduler.
if (scheduler_type != SchedulerType::Matmul &&
Expand Down
8 changes: 5 additions & 3 deletions tests/cpp/test_gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ TEST_F(GatherTest, TakeAlongAxisIntermediateTensorReduction1) {

validateSegmentation(
executor_cache.getMostRecentKernelRuntime(),
{SchedulerType::Reduction, SchedulerType::PointWise});
{SchedulerType::Reduction, SchedulerType::ExprEval});

testValidate(&fusion, outputs, {t0, t1}, __LINE__, __FILE__);
}
Expand Down Expand Up @@ -1128,7 +1128,8 @@ TEST_F(GatherTest, TakeAlongAxisCrossEntropyLoss) {
}

// Test grouped reduction on IterType::GatherScatter
TEST_F(GatherTest, GatherIterGoupedReduction) {
// Disabled: codegen support for non-exact gather has been dropped.
TEST_F(GatherTest, DISABLED_GatherIterGoupedReduction) {
const int max_dim_size = 128;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_i = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
Expand Down Expand Up @@ -1212,7 +1213,8 @@ TEST_F(GatherTest, GatherIterGoupedReduction) {
lparams);
}

TEST_F(GatherTest, SameTvUsedAsLookupAndIndex) {
// Disabled: codegen support for non-exact gather has been dropped.
TEST_F(GatherTest, DISABLED_SameTvUsedAsLookupAndIndex) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
FusionGuard fg(&fusion);
Expand Down