
Conversation

naoyam (Collaborator) commented Jan 31, 2026

No description provided.

naoyam (Collaborator, Author) commented Jan 31, 2026

!test

github-actions bot commented Jan 31, 2026

Review updated until commit adb947e

Description

  • Remove multiple Ampere-specific matmul tests from test_matmul.cpp

  • Disable LdMatrix and StMatrix tests in test_memory.cpp due to missing alternate loop domain

  • Modify tensor indexer logic in index_compute.cpp to disable LdMatrix/StMatrix optimizations

  • Change tensor indexer assertion behavior from permissive to strict

Changes walkthrough

Relevant files
Tests

test_matmul.cpp: Remove multiple Ampere matmul tests
(tests/cpp/test_matmul.cpp)

  • Removed MatmulMatmulAmpere test (lines 667-1338)
  • Removed MatmulSoftmaxMatmulAmpere test (lines 1339-1377)
  • Removed AmpereMatmulTNCpAsync test (lines 1378-2082)
  • Removed AmpereStridedBatchedMatmulTN test (lines 2083-2568)
  • Removed AmpereViewMatmulTN test (lines 2569-2856)
  • +0/-1151

test_memory.cpp: Disable LdMatrix and StMatrix tests
(tests/cpp/test_memory.cpp)

  • Disabled LdMatrixTest.Regular by adding the DISABLED_ prefix
  • Disabled StMatrixTest.Regular by adding the DISABLED_ prefix
  • Disabled LdMatrixTest.Transpose by adding the DISABLED_ prefix
  • Added comments explaining that the tests are disabled due to the missing
    alternate loop domain (see the note below)
  • +6/-3
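
For reference, GoogleTest treats any test whose name begins with the DISABLED_ prefix as disabled: it still compiles, but is skipped at runtime unless the binary is run with --gtest_also_run_disabled_tests. A minimal, generic sketch of the pattern (placeholder test body, not code from test_memory.cpp):

    #include <gtest/gtest.h>

    // Compiled but skipped by default; pass --gtest_also_run_disabled_tests
    // to the test binary to execute it anyway.
    TEST(LdMatrixExample, DISABLED_Regular) {
      SUCCEED();
    }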

Enhancement

index_compute.cpp: Modify tensor indexer logic and assertions
(csrc/index_compute.cpp)

  • Hardcoded is_producer_ldmatrix_op to false (line 1941)
  • Hardcoded is_producer_stmatrix_op_with_no_alloc_domain to false (line 1947)
  • Changed the tensor indexer assertion from false to true (line 1976)
  • +3/-1

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Functionality Disabled

Lines 1944 and 1951 forcibly set is_producer_ldmatrix_op and is_producer_stmatrix_op_with_no_alloc_domain to false, which appears to disable LdMatrix and StMatrix support in TensorIndexer. This needs justification and should not be merged without proper explanation of why these operations are being disabled.

    is_producer_ldmatrix_op = false;
    bool is_producer_stmatrix_op_with_no_alloc_domain =
        producer->definition() != nullptr &&
        producer->definition()->isA<LoadStoreOp>() &&
        producer->definition()->as<LoadStoreOp>()->opType() ==
            LoadStoreOpType::StMatrix &&
        !producer->hasAllocation();
    is_producer_stmatrix_op_with_no_alloc_domain = false;
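
If the intent is a temporary disable, one less error-prone way to express it is to fold a single named switch into the initializers instead of overwriting the computed values afterwards. A minimal sketch, reusing only identifiers visible in the snippet above and assuming the LdMatrix condition mirrors the StMatrix one; the kEnableLdStMatrixIndexing constant is hypothetical and not part of this patch:

    // Hypothetical toggle local to this indexing path; not the actual patch.
    constexpr bool kEnableLdStMatrixIndexing = false;

    // Same scope as the snippet above: `producer` is the producer tensor.
    bool is_producer_ldmatrix_op = kEnableLdStMatrixIndexing &&
        producer->definition() != nullptr &&
        producer->definition()->isA<LoadStoreOp>() &&
        producer->definition()->as<LoadStoreOp>()->opType() ==
            LoadStoreOpType::LdMatrix;

    bool is_producer_stmatrix_op_with_no_alloc_domain =
        kEnableLdStMatrixIndexing && producer->definition() != nullptr &&
        producer->definition()->isA<LoadStoreOp>() &&
        producer->definition()->as<LoadStoreOp>()->opType() ==
            LoadStoreOpType::StMatrix &&
        !producer->hasAllocation();

This keeps the full conditions visible while making the temporary nature of the disable explicit at a single point.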

Test Cases Removed

Multiple substantial test cases have been removed, including complex matmul fusion tests, softmax-matmul tests, and various Ampere-specific tests. These tests cover important functionality; their removal should be justified, or they should be preserved in some form.

    TEST_P(MatmulTestWithLayout, TuringMatmul) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
    
      // Keep multiples of 8 to keep vectorizable.
      int M = 504, N = 136, K = 248;
    
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);
    
      auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
      auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);
    
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
      tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
      auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
    
      fusion.addOutput(tv2);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 128, 32);
      gemm_tile.warp_tile = GemmTile(64, 64, 32);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 4};
      mparams.mma_macro = MmaMacro::Turing_16_8_16;
      mparams.tile_sizes = gemm_tile;
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      auto inputs = matmulAtInput3DTuring(M, N, K, layout);
    
      KernelExecutor ke;
      NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
          7, 5, ke.compile(&fusion, {inputs.first, inputs.second}));
      ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
      auto cg_outputs = ke.run({inputs.first, inputs.second});
      auto tref = atMatmul(
          inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
      NVF_CHECK(at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.0001, 0.0001));
    }
    
    // Test an end-to-end matmul case with swizzled smem
    // data layout.
    
    // Matmul test on Ampere using ldmatrix.x4 to load operands
    TEST_P(MatmulTestWithLayout, AmpereMatmulLargeLoad) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
    
      REQUIRE_DEVICE_SMEM_SIZE(98384, 0);
      // Keep multiples of 8 to keep vectorizable.
      int M = 504, N = 136, K = 248;
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);
    
      auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
      auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);
    
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
      tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
      auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
    
      fusion.addOutput(tv2);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 128, 64);
      gemm_tile.warp_tile = GemmTile(64, 64, 64);
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 4};
      mparams.mma_macro = MmaMacro::Ampere_16_16_16;
      mparams.tile_sizes = gemm_tile;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = true;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 3;
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      auto inputs = matmulAtInput3DTuring(M, N, K, layout);
    
      KernelExecutor ke;
      NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
          8,
          0,
          ke.compile(
              &fusion,
              {inputs.first, inputs.second},
              LaunchParams(),
              matmul_cparams));
      ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
      auto cg_outputs = ke.run({inputs.first, inputs.second});
      auto tref = atMatmul(
          inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
      NVF_CHECK(at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.0001, 0.0001));
    }
    
    // Matmul test for Turing MMA: across supported layouts
    TEST_P(MatmulTestWithLayout, TuringMatmulLargeLoad) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
    
      // Keep multiples of 8 to keep vectorizable.
      int M = 504, N = 136, K = 248;
    
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);
    
      auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
      auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);
    
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
      tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
      auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
    
      fusion.addOutput(tv2);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 128, 32);
      gemm_tile.warp_tile = GemmTile(64, 64, 32);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 4};
      mparams.mma_macro = MmaMacro::Turing_16_16_16;
      mparams.tile_sizes = gemm_tile;
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      auto inputs = matmulAtInput3DTuring(M, N, K, layout);
    
      KernelExecutor ke;
      NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
          7,
          5,
          ke.compile(
              &fusion,
              {inputs.first, inputs.second},
              LaunchParams(),
              matmul_cparams));
      ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
      auto cg_outputs = ke.run({inputs.first, inputs.second});
      auto tref = atMatmul(
          inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
      NVF_CHECK(at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.0001, 0.0001));
    }
    
    // Tile layout check for symmetric 4-warp recipes
    TEST_P(MatmulTestWithLayout, AmpereMatmulTileCheck4warp) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
    
      REQUIRE_DEVICE_SMEM_SIZE(98384, 0);
      // Keep multiples of 8 to keep vectorizable.
      int M = 504, N = 136, K = 248;
      // Symmetric tile with 16x16x16 macro,
      //  supports mn_size of multiple of 32,
      //  and k size multiple of 16.
      for (int mn_size : {32, 64, 96, 128, 160, 192}) {
        for (int k_size : {32, 48, 64}) {
          Fusion fusion;
          FusionGuard fg(&fusion);
    
          auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);
    
          auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
          auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);
    
          fusion.addInput(tv0);
          fusion.addInput(tv1);
    
          tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
          tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
          auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
    
          fusion.addOutput(tv2);
    
          MatMulTileOptions gemm_tile;
          gemm_tile.cta_tile = GemmTile(mn_size, mn_size, k_size);
          gemm_tile.warp_tile = GemmTile(mn_size / 2, mn_size / 2, k_size);
    
          MatmulParams mparams;
          mparams.supported_vec_size = {8, 8, 4};
          mparams.mma_macro = MmaMacro::Ampere_16_16_16;
          mparams.tile_sizes = gemm_tile;
          mparams.async_gmem_load_operands = true;
          mparams.circular_buffer_options.circular_buffer_smem_write = true;
          mma_utils::MmaDataTypes data_types = {
              DataType::Half, DataType::Half, DataType::Float};
          std::tie(mparams.use_smem_epilogue, mparams.promote_prologue_smem_reuse) =
              mma_utils::generateSharedMemoryEpilogueHeuristics(
                  gemm_tile,
                  mparams.circular_buffer_options.smem_circular_buffer_stage,
                  data_types,
                  true,
                  true);
          SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
              ->schedule(&fusion, &mparams);
    
          auto inputs = matmulAtInput3DTuring(M, N, K, layout);
    
          KernelExecutor ke;
          NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
              8,
              0,
              ke.compile(
                  &fusion,
                  {inputs.first, inputs.second},
                  LaunchParams(),
                  matmul_cparams));
          EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
          auto cg_outputs = ke.run({inputs.first, inputs.second});
          ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
              ke.compiledKernel()->kernel()));
          auto tref = atMatmul(
              inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
          NVF_CHECK(
              at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.0001, 0.0001),
              "error :",
              (cg_outputs[0].as<at::Tensor>() - tref).abs().max(),
              "tile dim:",
              mn_size,
              " ",
              k_size);
        }
      }
    }
    
    TEST_P(MatmulTestWithLayout, AmpereMatmulTileCheck8warp) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
    
      REQUIRE_DEVICE_SMEM_SIZE(98384, 0);
      // Keep multiples of 8 to keep vectorizable.
      int M = 504, N = 136, K = 248;
      // Asymmetric tile with 16x16x16 macro,
      for (int m_size : {256}) {
        for (int n_size : {32, 64, 96, 128}) {
          for (int k_size : {32, 48, 64}) {
            Fusion fusion;
            FusionGuard fg(&fusion);
    
            auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);
    
            auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
            auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);
    
            fusion.addInput(tv0);
            fusion.addInput(tv1);
    
            tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
            tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
            auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
    
            fusion.addOutput(tv2);
    
            MatMulTileOptions gemm_tile;
            gemm_tile.cta_tile = GemmTile(m_size, n_size, k_size);
            gemm_tile.warp_tile = GemmTile(m_size / 4, n_size / 2, k_size);
    
            MatmulParams mparams;
            mparams.supported_vec_size = {8, 8, 4};
            mparams.mma_macro = MmaMacro::Ampere_16_16_16;
            mparams.tile_sizes = gemm_tile;
            mparams.async_gmem_load_operands = true;
            mparams.circular_buffer_options.circular_buffer_smem_write = true;
            mparams.circular_buffer_options.circular_buffer_smem_read = true;
            mparams.circular_buffer_options.smem_circular_buffer_stage = 2;
            mma_utils::MmaDataTypes data_types = {
                DataType::Half, DataType::Half, DataType::Float};
            std::tie(
                mparams.use_smem_epilogue, mparams.promote_prologue_smem_reuse) =
                mma_utils::generateSharedMemoryEpilogueHeuristics(
                    gemm_tile,
                    mparams.circular_buffer_options.smem_circular_buffer_stage,
                    data_types);
    
            SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
                ->schedule(&fusion, &mparams);
    
            auto inputs = matmulAtInput3DTuring(M, N, K, layout);
    
            KernelExecutor ke;
            NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
                8,
                0,
                ke.compile(
                    &fusion,
                    {inputs.first, inputs.second},
                    LaunchParams(),
                    matmul_cparams));
            ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
            ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
                ke.compiledKernel()->kernel()));
            auto cg_outputs = ke.run({inputs.first, inputs.second});
            auto tref = atMatmul(
                inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
            NVF_CHECK(
                at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.0001, 0.0001));
          }
        }
      }
    }
    
    TEST_P(MatmulTestWithLayout, AmpereMatmulTileCheck6warp) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
    
      REQUIRE_DEVICE_SMEM_SIZE(98384, 0);
      // Keep multiples of 8 to keep vectorizable.
      int M = 504, N = 136, K = 248;
      for (int k_size : {32, 48, 64}) {
        Fusion fusion;
        FusionGuard fg(&fusion);
    
        auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);
    
        auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
        auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);
    
        fusion.addInput(tv0);
        fusion.addInput(tv1);
    
        tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
        tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
        auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
    
        fusion.addOutput(tv2);
    
        MatMulTileOptions gemm_tile;
        // 2 warp by 3 warp
        gemm_tile.cta_tile = GemmTile(192, 128, k_size);
        gemm_tile.warp_tile = GemmTile(64, 64, k_size);
    
        MatmulParams mparams;
        mparams.supported_vec_size = {8, 8, 4};
        mparams.mma_macro = MmaMacro::Ampere_16_16_16;
        mparams.tile_sizes = gemm_tile;
        mparams.async_gmem_load_operands = true;
        mparams.circular_buffer_options.circular_buffer_smem_write = true;
        mparams.circular_buffer_options.circular_buffer_smem_read = true;
        mparams.circular_buffer_options.smem_circular_buffer_stage = 2;
        mma_utils::MmaDataTypes data_types = {
            DataType::Half, DataType::Half, DataType::Float};
        std::tie(mparams.use_smem_epilogue, mparams.promote_prologue_smem_reuse) =
            mma_utils::generateSharedMemoryEpilogueHeuristics(
                gemm_tile,
                mparams.circular_buffer_options.smem_circular_buffer_stage,
                data_types);
        SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
            ->schedule(&fusion, &mparams);
    
        auto inputs = matmulAtInput3DTuring(M, N, K, layout);
    
        KernelExecutor ke;
        NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
            8,
            0,
            ke.compile(
                &fusion,
                {inputs.first, inputs.second},
                LaunchParams(),
                matmul_cparams));
        ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
        ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
            ke.compiledKernel()->kernel()));
        auto cg_outputs = ke.run({inputs.first, inputs.second});
        auto tref = atMatmul(
            inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
        NVF_CHECK(
            at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.0001, 0.0001));
      }
    }
    
    // Matmul test on Ampere using ldmatrix.x4 to load operands
    TEST_P(MatmulTestWithLayout, AmpereMatmulLargeLoadLargeK) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
    
      // Keep multiples of 8 to keep vectorizable.
      int M = 504, N = 136, K = 2048;
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);
    
      auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
      auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);
    
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
      tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
      auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
    
      fusion.addOutput(tv2);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 128, 64);
      gemm_tile.warp_tile = GemmTile(64, 64, 64);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 4};
      mparams.mma_macro = MmaMacro::Ampere_16_16_16;
      mparams.tile_sizes = gemm_tile;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = true;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 3;
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      auto inputs = matmulAtInput3DTuring(M, N, K, layout);
    
      KernelExecutor ke;
      NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
          8,
          0,
          ke.compile(
              &fusion,
              {inputs.first, inputs.second},
              LaunchParams(),
              matmul_cparams));
      ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
      auto cg_outputs = ke.run({inputs.first, inputs.second});
      auto tref = atMatmul(
          inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
      NVF_CHECK(at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.001, 0.001));
    }
    
    // Matmul test for Ampere MMA: across supported layouts
    TEST_P(MatmulTestWithLayout, AmpereSplitKLikeStridedBatchedMatmul) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
    
      // Keep multiples of 8 to keep vectorizable.
      int B = 2, M = 504, N = 136, K = 248;
    
      Fusion fusion;
      FusionGuard fg(&fusion);
      auto tv0 = makeContigTensor(3, DataType::Half);
      auto tv1 = makeContigTensor(3, DataType::Half);
    
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
      tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
      auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
    
      fusion.addOutput(tv2);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 128, 32);
      gemm_tile.warp_tile = GemmTile(64, 64, 32);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 4};
      mparams.mma_macro = MmaMacro::Ampere_16_8_16;
      mparams.tile_sizes = gemm_tile;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = true;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B);
      auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B);
    
      KernelExecutor ke;
      NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
          8, 0, ke.compile(&fusion, {t0, t1}, LaunchParams(), matmul_cparams));
      ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
      auto cg_outputs = ke.run({t0, t1});
      auto tref = splitkLikeAtMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout);
      NVF_CHECK(at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.0001, 0.0001));
    }
    
    TEST_P(MatmulTestWithLayout, AmpereMatmulSmemEpilogue) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0);
      constexpr bool ignore_occupancy_drop = true;
      // Keep multiples of 8 to keep vectorizable.
      int M = 4096, N = 4096, K = 4096;
      // This tests num_stages=0, which should be treated identically to
      // num_stages=1. It is put here to exercise this path to ensure we don't
      // crash in generateSharedMemoryEpilogueHeuristics.
      // See https://github.com/NVIDIA/Fuser/pull/1917 for more info
      for (int num_stages : {0, 2}) {
        Fusion fusion;
        FusionGuard fg(&fusion);
    
        auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);
    
        auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
        auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);
    
        fusion.addInput(tv0);
        fusion.addInput(tv1);
    
        tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
        tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
        auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
    
        fusion.addOutput(tv2);
    
        // The settings of cta_tile, warp_tile, and smem_circular_buffer_stage
        // have been purposefully selected to produce a constant occupancy of 25%.
        // This allows us to effectively evaluate the influence of the
        // use_smem_epilogue parameter on performance, since changing its value to
        // either true or false will not affect the occupancy rate.
        MatMulTileOptions gemm_tile;
        gemm_tile.cta_tile = GemmTile(64, 128, 32);
        gemm_tile.warp_tile = GemmTile(32, 32, 32);
    
        MatmulParams mparams;
        mparams.supported_vec_size = {8, 8, 4};
        mparams.mma_macro = MmaMacro::Ampere_16_8_16;
        mparams.tile_sizes = gemm_tile;
        mparams.async_gmem_load_operands = true;
        mparams.circular_buffer_options.circular_buffer_smem_write = num_stages > 1;
        mparams.circular_buffer_options.circular_buffer_smem_read = num_stages > 1;
        mparams.circular_buffer_options.smem_circular_buffer_stage = num_stages;
        mma_utils::MmaDataTypes data_types = {
            DataType::Half, DataType::Half, DataType::Float};
        std::tie(mparams.use_smem_epilogue, mparams.promote_prologue_smem_reuse) =
            mma_utils::generateSharedMemoryEpilogueHeuristics(
                gemm_tile,
                mparams.circular_buffer_options.smem_circular_buffer_stage,
                data_types,
                ignore_occupancy_drop);
        SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
            ->schedule(&fusion, &mparams);
    
        // If use_smem_epilogue is true, there should be 3 shared memory tensors 2
        // for prologue and 1 for epilogue.
        int num_shared_mem_tensors = 0;
        int expected_num_shared_mem_tensors = mparams.use_smem_epilogue ? 3 : 2;
        for (const auto& tv : fusion.allTvs()) {
          if (tv->getMemoryType() == MemoryType::Shared) {
            num_shared_mem_tensors++;
          }
        }
        NVF_CHECK(
            num_shared_mem_tensors == expected_num_shared_mem_tensors,
            "Number of shared memory tensors doesn't match!",
            "Expected: ",
            expected_num_shared_mem_tensors,
            ", Got: ",
            num_shared_mem_tensors);
    
        at::manual_seed(0);
        auto inputs = matmulAtInput3DTuring(M, N, K, layout);
    
        KernelExecutor ke;
        NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
            8,
            0,
            ke.compile(
                &fusion,
                {inputs.first, inputs.second},
                LaunchParams(),
                matmul_cparams));
        auto cg_outputs = ke.run({inputs.first, inputs.second});
        auto tref = atMatmul(
            inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
    
        // check bank conflicts
        ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
        ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
            ke.compiledKernel()->kernel()));
        // (0.001, 0.001) passed on local A100 but failed on CI A100
        NVF_CHECK(
            at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.01, 0.01),
            "Result validation failed. Max diff: ",
            (cg_outputs[0].as<at::Tensor>() - tref).abs().max());
    
        if (!mparams.use_smem_epilogue) {
          GTEST_SKIP() << "Test conducted without utilizing shared memory epilogue "
                          "due to the device's constrained shared memory capacity.";
        }
    
        // Check that smem is allocated as expected.
        // There are three cases that are determined by the current device in
        // mma_utils::generateSharedMemoryEpilogueHeuristics:
        //   - !use_smem_epilogue : A + B (this test is skipped in this case)
        //   - use_smem_epilogue && !promote_prologue_smem_reuse : A + B + C
        //   - use_smem_epilogue && promote_prologue_smem_reuse : max(A + B, C)
        auto smem_allocs =
            ke.compiledKernel()->kernel()->summary().dynamic_smem_allocations;
        NVF_CHECK(smem_allocs.size() == 3);
        if (mparams.promote_prologue_smem_reuse) {
          // Check prologue shared memory re-use
          // smem_allocs = {A, B, C} where C is the epilogue buffer
          // since A and B have no further uses, we should be able to reuse both
          // of them, implying that the address of C is zero. In this case, B will
          // also be allocated at address 0 with A stacked above it at position
          // 8192.
          EXPECT_EQ(
              smem_allocs.at(0)->address()->evaluate(),
              // Assuming B numel times size(dtype) is a multiple of 16 so that
              // this address is aligned
              smem_allocs.at(1)->size()->evaluate() *
                  dataTypeSizeByte(smem_allocs.at(1)->buffer()->dtype()));
          EXPECT_EQ(smem_allocs.at(1)->address()->evaluate(), 0L);
          EXPECT_EQ(smem_allocs.at(2)->address()->evaluate(), 0L);
        } else {
          // Prologue shared memory is not re-used. In this case, memory should
          // stack in C, B, A order.
          EXPECT_EQ(
              smem_allocs.at(0)->address()->evaluate(),
              // Assuming for B and C that numel times size(dtype) is a multiple
              // of 16 so that this address is aligned
              smem_allocs.at(1)->size()->evaluate() *
                      dataTypeSizeByte(smem_allocs.at(1)->buffer()->dtype()) +
                  smem_allocs.at(2)->size()->evaluate() *
                      dataTypeSizeByte(smem_allocs.at(2)->buffer()->dtype()));
          EXPECT_EQ(
              smem_allocs.at(1)->address()->evaluate(),
              smem_allocs.at(2)->size()->evaluate() *
                  dataTypeSizeByte(smem_allocs.at(2)->buffer()->dtype()));
          EXPECT_EQ(smem_allocs.at(2)->address()->evaluate(), 0L);
        }
      }
    }
    
    // On A100, this problem is able to make use of smem epilogue but only if we
    // promote use.
    // See https://github.com/NVIDIA/Fuser/pull/1834
    TEST_F(MatmulTest, AmpereMatmulSmemEpiloguePromotionRequiredA100) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0);
      // Keep multiples of 8 to keep vectorizable.
      int M = 4096, N = 4096, K = 4096;
    
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto layout = MmaLayout::TN;
    
      auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);
    
      auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
      auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);
    
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
      tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
      auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
    
      fusion.addOutput(tv2);
    
      // The settings of cta_tile, warp_tile, and smem_circular_buffer_stage have
      // been purposefully selected to produce a constant occupancy of 25%. This
      // allows us to effectively evaluate the influence of the use_smem_epilogue
      // parameter on performance, since changing its value to either true or
      // false will not affect the occupancy rate.
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(64, 96, 64);
      gemm_tile.warp_tile = GemmTile(16, 32, 64);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 4};
      mparams.mma_macro = MmaMacro::Ampere_16_8_16;
      mparams.tile_sizes = gemm_tile;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = true;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 6;
      mma_utils::MmaDataTypes data_types = {
          DataType::Half, DataType::Half, DataType::Float};
      std::tie(mparams.use_smem_epilogue, mparams.promote_prologue_smem_reuse) =
          mma_utils::generateSharedMemoryEpilogueHeuristics(
              gemm_tile,
              mparams.circular_buffer_options.smem_circular_buffer_stage,
              data_types,
              /*ignore_occupancy_drop=*/false);
    
      if (deviceMajorMinorCheck(8, 0)) {
        // Test that we promote smem reuse on A100. This might differ on devices
        // with different amounts of smem.
        ASSERT_TRUE(mparams.promote_prologue_smem_reuse);
      }
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      // KernelExecutor::compile would fail otherwise.
      SKIP_IF_INSUFFICIENT_SMEM(&mparams, data_types);
    
      at::manual_seed(0);
      auto inputs = matmulAtInput3DTuring(M, N, K, layout);
    
      KernelExecutor ke;
      NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
          8,
          0,
          ke.compile(
              &fusion,
              {inputs.first, inputs.second},
              LaunchParams(),
              matmul_cparams));
      auto cg_outputs = ke.run({inputs.first, inputs.second});
      auto tref = atMatmul(
          inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
    
      // check bank conflicts
      ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
      // (0.001, 0.001) passed on local A100 but failed on CI A100
      NVF_CHECK(
          at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.01, 0.01),
          "Result validation failed. Max diff: ",
          (cg_outputs[0].as<at::Tensor>() - tref).abs().max());
    
      if (!mparams.use_smem_epilogue) {
        GTEST_SKIP() << "Test conducted without utilizing shared memory epilogue "
                        "due to the device's constrained shared memory capacity.";
      }
      if (!mparams.promote_prologue_smem_reuse) {
        GTEST_SKIP() << "Test conducted with shared memory epilogue but without "
                        "promoting prologue smem re-use.";
      }
    }
    
    TEST_P(MatmulTestWithLayout, AmpereMatmulSmemEpilogueCast) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0);
      constexpr bool ignore_occupancy_drop = true;
      // Keep multiples of 8 to keep vectorizable.
      int M = 4096, N = 4096, K = 4096;
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);
    
      auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
      auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);
    
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
      tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
      auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
      auto tv3 = castOp(DataType::Half, tv2);
    
      fusion.addOutput(tv3);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 128, 32);
      gemm_tile.warp_tile = GemmTile(64, 64, 32);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 8};
      mparams.mma_macro = MmaMacro::Ampere_16_8_16;
      mparams.tile_sizes = gemm_tile;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = true;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
      mma_utils::MmaDataTypes data_types = {
          DataType::Half, DataType::Half, DataType::Float};
      std::tie(mparams.use_smem_epilogue, mparams.promote_prologue_smem_reuse) =
          mma_utils::generateSharedMemoryEpilogueHeuristics(
              gemm_tile,
              mparams.circular_buffer_options.smem_circular_buffer_stage,
              data_types,
              ignore_occupancy_drop);
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      // If use_smem_epilogue is true, there should be 3 shared memory tensors 2
      // for prologue and 1 for epilogue.
      int num_shared_mem_tensors = 0;
      int expected_num_shared_mem_tensors = mparams.use_smem_epilogue ? 3 : 2;
      for (const auto& tv : fusion.allTvs()) {
        if (tv->getMemoryType() == MemoryType::Shared) {
          num_shared_mem_tensors++;
        }
      }
      NVF_CHECK(
          num_shared_mem_tensors == expected_num_shared_mem_tensors,
          "Number of shared memory tensors doesn't match!",
          "Expected: ",
          expected_num_shared_mem_tensors,
          ", Got: ",
          num_shared_mem_tensors);
    
      at::manual_seed(0);
      auto inputs = matmulAtInput3DTuring(M, N, K, layout);
    
      KernelExecutor ke;
      NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
          8,
          0,
          ke.compile(
              &fusion,
              {inputs.first, inputs.second},
              LaunchParams(),
              matmul_cparams));
      auto cg_outputs = ke.run({inputs.first, inputs.second});
      auto tref = atMatmul(
          inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
      tref = tref.to(at::kHalf);
      // check bank conflicts
      ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
      // (0.001, 0.001) passed on local A100 but failed on CI A100
      NVF_CHECK(
          at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.01, 0.01),
          "Result validation failed. Max diff: ",
          (cg_outputs[0].as<at::Tensor>() - tref).abs().max());
    
      if (!mparams.use_smem_epilogue) {
        GTEST_SKIP() << "Test conducted without utilizing shared memory epilogue "
                        "due to the device's constrained shared memory capacity.";
      }
    }
    
    TEST_P(MatmulTestWithLayout, AmpereMatmulSmemEpilogueRelu) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0);
      constexpr bool ignore_occupancy_drop = true;
      // Keep multiples of 8 to keep vectorizable.
      int M = 4096, N = 4096, K = 4096;
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);
    
      auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
      auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);
    
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
      tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
      auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
      auto tv3 = relu(tv2);
    
      fusion.addOutput(tv3);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 128, 32);
      gemm_tile.warp_tile = GemmTile(64, 64, 32);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 4};
      mparams.mma_macro = MmaMacro::Ampere_16_8_16;
      mparams.tile_sizes = gemm_tile;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = true;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
      mma_utils::MmaDataTypes data_types = {
          DataType::Half, DataType::Half, DataType::Float};
      std::tie(mparams.use_smem_epilogue, mparams.promote_prologue_smem_reuse) =
          mma_utils::generateSharedMemoryEpilogueHeuristics(
              gemm_tile,
              mparams.circular_buffer_options.smem_circular_buffer_stage,
              data_types,
              ignore_occupancy_drop);
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      // If use_smem_epilogue is true, there should be 3 shared memory tensors 2
      // for prologue and 1 for epilogue.
      int num_shared_mem_tensors = 0;
      int expected_num_shared_mem_tensors = mparams.use_smem_epilogue ? 3 : 2;
      for (const auto& tv : fusion.allTvs()) {
        if (tv->getMemoryType() == MemoryType::Shared) {
          num_shared_mem_tensors++;
        }
      }
      NVF_CHECK(
          num_shared_mem_tensors == expected_num_shared_mem_tensors,
          "Number of shared memory tensors doesn't match!",
          "Expected: ",
          expected_num_shared_mem_tensors,
          ", Got: ",
          num_shared_mem_tensors);
    
      at::manual_seed(0);
      auto inputs = matmulAtInput3DTuring(M, N, K, layout);
    
      KernelExecutor ke;
      NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
          8,
          0,
          ke.compile(
              &fusion,
              {inputs.first, inputs.second},
              LaunchParams(),
              matmul_cparams));
      auto cg_outputs = ke.run({inputs.first, inputs.second});
      auto t2 = atMatmul(
          inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
      auto tref = at::relu(t2).to(at::kFloat);
    
      // check bank conflicts
      ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
      // (0.001, 0.001) passed on local A100 but failed on CI A100
      NVF_CHECK(
          at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.01, 0.01),
          "Result validation failed. Max diff: ",
          (cg_outputs[0].as<at::Tensor>() - tref).abs().max());
    
      if (!mparams.use_smem_epilogue) {
        GTEST_SKIP() << "Test conducted without utilizing shared memory epilogue "
                        "due to the device's constrained shared memory capacity.";
      }
    }
    
    // Test the matmul scheduler's single-kernel split-K support
    TEST_P(MatmulTestWithLayout, FusionAmpereMatmulSplitK_CUDA) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0);
    
      // Keep multiples of 8 to keep vectorizable.
      int M = 504, N = 136, K = 8096;
    
      for (int splitk_factor : {2}) {
        for (int use_smem_epilogue : {false, true}) {
          Fusion fusion;
          FusionGuard fg(&fusion);
          auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);
    
          auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
          auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);
    
          fusion.addInput(tv0);
          fusion.addInput(tv1);
    
          tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
          tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
          auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
    
          fusion.addOutput(tv2);
    
          MatMulTileOptions gemm_tile;
          gemm_tile.cta_tile = GemmTile(128, 128, 32);
          gemm_tile.warp_tile = GemmTile(64, 64, 32);
    
          MatmulParams mparams;
          mparams.supported_vec_size = {8, 8, 4};
          mparams.mma_macro = MmaMacro::Ampere_16_8_16;
          mparams.tile_sizes = gemm_tile;
          mparams.splitk_factor = splitk_factor;
          if (use_smem_epilogue) {
            std::tie(
                mparams.use_smem_epilogue, mparams.promote_prologue_smem_reuse) =
                mma_utils::generateSharedMemoryEpilogueHeuristics(
                    gemm_tile,
                    1,
                    {DataType::Half, DataType::Half, DataType::Float},
                    true,
                    true,
                    true);
            if (!mparams.use_smem_epilogue) {
              std::cout << "Skipping smem epilogue due to shared memory "
                           "constraints on this device"
                        << std::endl;
              continue;
            }
            mparams.promote_prologue_smem_reuse = true;
          }
    
          SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
              ->schedule(&fusion, &mparams);
    
          auto inputs = matmulAtInput3DTuring(M, N, K, layout);
    
          KernelExecutor ke;
          NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
              7, 5, ke.compile(&fusion, {inputs.first, inputs.second}));
          EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
          auto cg_outputs = ke.run({inputs.first, inputs.second});
          ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
              ke.compiledKernel()->kernel()));
          auto tref = atMatmul(
              inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
    
          // Relax tolerance for larger sum due to large K
          NVF_CHECK(at::allclose(
              cg_outputs[0].as<at::Tensor>(), tref, 1e-6 * K, 1e-6 * K));
        }
      }
    }
    
    // Test splitk with bias epilogue
    TEST_P(MatmulTestWithLayout, FusionAmpereMatmulSplitKBias_CUDA) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0);
    
      // Keep multiples of 8 to keep vectorizable.
      int M = 504, N = 136, K = 8096;
    
      for (int splitk_factor : {2}) {
        for (int use_smem_epilogue : {false, true}) {
          Fusion fusion;
          FusionGuard fg(&fusion);
          auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);
    
          auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
          auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);
          auto tv2 = makeContigTensor(1, DataType::Half);
    
          fusion.addInput(tv0);
          fusion.addInput(tv1);
          fusion.addInput(tv2);
    
          tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
          tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
          auto tv3 = fusedMultiplySum(tv0, tv1, {-1});
          auto tv4 = broadcast(tv2, {false, true});
          auto tv5 = add(tv3, tv4); // bias
    
          fusion.addOutput(tv5);
    
          MatMulTileOptions gemm_tile;
          gemm_tile.cta_tile = GemmTile(128, 128, 32);
          gemm_tile.warp_tile = GemmTile(64, 64, 32);
    
          MatmulParams mparams;
          mparams.supported_vec_size = {8, 8, 4};
          mparams.mma_macro = MmaMacro::Ampere_16_8_16;
          mparams.tile_sizes = gemm_tile;
          mparams.splitk_factor = splitk_factor;
          mparams.use_smem_epilogue = use_smem_epilogue;
          mparams.promote_prologue_smem_reuse = true;
    
          SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
              ->schedule(&fusion, &mparams);
    
          auto [t0, t1] = matmulAtInput3DTuring(M, N, K, layout);
          at::Tensor bias = at::randn({M}, t0.options());
    
          KernelExecutor ke;
          NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
              7, 5, ke.compile(&fusion, {t0, t1, bias}));
          EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
          auto cg_outputs = ke.run({t0, t1, bias});
          ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
              ke.compiledKernel()->kernel()));
          auto tref = atBiasEpilogue(
              atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout), bias);
    
          // Relax tolerance for larger sum due to large K
          NVF_CHECK(at::allclose(
              cg_outputs[0].as<at::Tensor>(), tref, 1e-6 * K, 1e-6 * K));
        }
      }
    }
    
    // Same as above but has a batch dimension and splitk
    TEST_P(MatmulTestWithLayout, AmpereMatmulBatchSplitK) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0);
    
      // Keep multiples of 8 to keep vectorizable.
      int B = 2, M = 504, N = 136, K = 2048;
    
      for (int splitk_factor : {2}) {
        for (int use_smem_epilogue : {false, true}) {
          Fusion fusion;
          FusionGuard fg(&fusion);
          auto tv0 = makeContigTensor(3, DataType::Half);
          auto tv1 = makeContigTensor(3, DataType::Half);
    
          fusion.addInput(tv0);
          fusion.addInput(tv1);
    
          tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
          tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
          auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
    
          fusion.addOutput(tv2);
    
          MatMulTileOptions gemm_tile;
          gemm_tile.cta_tile = GemmTile(128, 128, 32);
          gemm_tile.warp_tile = GemmTile(64, 64, 32);
    
          MatmulParams mparams;
          mparams.supported_vec_size = {8, 8, 4};
          mparams.mma_macro = MmaMacro::Ampere_16_8_16;
          mparams.tile_sizes = gemm_tile;
          mparams.splitk_factor = splitk_factor;
          mparams.use_smem_epilogue = use_smem_epilogue;
          mparams.promote_prologue_smem_reuse = true;
    
          SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
              ->schedule(&fusion, &mparams);
    
          auto t0 =
              matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B);
          auto t1 =
              matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B);
    
          KernelExecutor ke;
          NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(7, 5, ke.compile(&fusion, {t0, t1}));
          ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
          ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
              ke.compiledKernel()->kernel()));
          auto cg_outputs = ke.run({t0, t1});
          auto tref = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout);
    
          // Relax tolerance for larger sum due to large K
          NVF_CHECK(at::allclose(
              cg_outputs[0].as<at::Tensor>(), tref, 1e-6 * K, 1e-6 * K));
        }
      }
    }
    
    // Test batch splitk with bias epilogue
    TEST_P(MatmulTestWithLayout, AmpereMatmulBatchSplitKBias) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0);
    
      // Keep multiples of 8 to keep vectorizable.
      int B = 2, M = 504, N = 136, K = 2048;
    
      for (int splitk_factor : {2}) {
        for (int use_smem_epilogue : {false, true}) {
          Fusion fusion;
          FusionGuard fg(&fusion);
          auto tv0 = makeContigTensor(3, DataType::Half);
          auto tv1 = makeContigTensor(3, DataType::Half);
          auto tv2 = makeContigTensor(1, DataType::Half);
    
          fusion.addInput(tv0);
          fusion.addInput(tv1);
          fusion.addInput(tv2);
    
          tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
          tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
          auto tv3 = fusedMultiplySum(tv0, tv1, {-1});
          auto tv4 = broadcast(tv2, {true, false, true});
          auto tv5 = add(tv3, tv4);
    
          fusion.addOutput(tv5);
    
          MatMulTileOptions gemm_tile;
          gemm_tile.cta_tile = GemmTile(128, 128, 32);
          gemm_tile.warp_tile = GemmTile(64, 64, 32);
    
          MatmulParams mparams;
          mparams.supported_vec_size = {8, 8, 4};
          mparams.mma_macro = MmaMacro::Ampere_16_8_16;
          mparams.tile_sizes = gemm_tile;
          mparams.splitk_factor = splitk_factor;
          mparams.use_smem_epilogue = use_smem_epilogue;
          mparams.promote_prologue_smem_reuse = true;
    
          SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
              ->schedule(&fusion, &mparams);
    
          auto t0 =
              matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B);
          auto t1 =
              matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B);
          at::Tensor bias = at::randn({M}, t0.options());
    
          KernelExecutor ke;
          NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
              7, 5, ke.compile(&fusion, {t0, t1, bias}));
          ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
          ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
              ke.compiledKernel()->kernel()));
          auto cg_outputs = ke.run({t0, t1, bias});
          auto tref = atBiasEpilogue(
              atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout), bias);
    
          // Relax tolerance for larger sum due to large K
          NVF_CHECK(at::allclose(
              cg_outputs[0].as<at::Tensor>(), tref, 1e-6 * K, 1e-6 * K));
        }
      }
    }
    
    // Avoid lowering error https://github.com/NVIDIA/Fuser/issues/1808
    TEST_F(MatmulTest, ReproIssue1808) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0);
    
      // Keep multiples of 8 to keep vectorizable.
      int M = 504, N = 136, K = 248;
    
      auto layout = MmaLayout::TN;
    
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);
    
      auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
      auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);
    
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
      tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
      auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
    
      fusion.addOutput(tv2);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(160, 144, 16);
      gemm_tile.warp_tile = GemmTile(80, 24, 16);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 4};
      mparams.mma_macro = MmaMacro::Ampere_16_8_16;
      mparams.tile_sizes = gemm_tile;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = true;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      auto inputs = matmulAtInput3DTuring(M, N, K, layout);
    
      KernelExecutor ke;
      NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
          8,
          0,
          ke.compile(
              &fusion,
              {inputs.first, inputs.second},
              LaunchParams(),
              matmul_cparams));
      ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
      auto cg_outputs = ke.run({inputs.first, inputs.second});
      auto tref = atMatmul(
          inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
      NVF_CHECK(at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.0001, 0.0001));
    }
    
    // Test matmul with sizes that are not divisible by 8 and with misaligned inputs
    TEST_P(MatmulTestWithLayout, MisalignedVectorization) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0);
      for (bool add_2d_bias : {false, true}) {
        for (bool downcast_output : {false, true}) {
          for (const auto& [M, N, K, alignA, alignB, alignBias] : std::vector<
                   std::
                       tuple<int64_t, int64_t, int64_t, int64_t, int64_t, int64_t>>{
                   {504, 136, 248, 8, 8, 8}, // all fully vectorizable in all
                                             // layouts
                   {504, 136, 249, 8, 8, 8}, // odd K, operands not vectorizable
                                             // in TN. output fully vectorizable
                   {504, 137, 248, 8, 8, 8}, // A fully vectorizable, B fully
                                             // vectorizable unless transposed,
                                             // output not vectorizable
                   {505, 136, 248, 8, 8, 8}, // B fully vectorizable, A
                                             // vectorizable unless transposed,
                                             // output fully vectorizable
                   {505, 137, 248, 8, 8, 8}, // none vectorizable
    
                   // Cases with vectorizable strides but misaligned base pointers
                   {504, 136, 248, 2, 8, 8}, // A not vectorizable due to offset
                   {504, 136, 248, 8, 2, 8}, // B not vectorizable due to offset
                   {504, 136, 248, 8, 8, 2}, // epilogue not vectorizable due to
                   // offset
               }) {
            const auto maybeUnalign = [](const at::Tensor& t, int64_t offset) {
              if (offset == 16 / t.element_size()) {
                // Already fully aligned
                return t;
              }
              return at::pad(t.ravel(), {{0, offset}})
                  .index({at::indexing::Slice(offset, t.numel() + offset, 1)})
                  .view({t.size(0), t.size(1)});
            };
    
            auto t0 = maybeUnalign(
                matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K),
                alignA);
            auto t1 = maybeUnalign(
                matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K),
                alignB);
    
            auto tref = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout);
    
            KernelArgumentHolder inputs = {t0, t1};
    
            if (add_2d_bias) {
              const auto options =
                  at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
              auto bias = maybeUnalign(at::randn({M, N}, options), alignBias);
              tref = tref + bias;
              inputs.push(bias);
            }
    
            if (downcast_output) {
              tref = tref.to(at::kHalf);
            }
    
            auto fusion = std::make_unique<Fusion>();
            FusionGuard fg(fusion.get());
    
            auto tv0 = makeContigTensor(2, DataType::Half);
            auto tv1 = makeContigTensor(2, DataType::Half);
            fusion->addInput(tv0);
            fusion->addInput(tv1);

    Tests Disabled

    The LdMatrix and StMatrix tests are disabled with comments noting that the alternate loop domain is missing. This points to an incomplete implementation; the tests should either be fixed or come with a clear plan for re-enabling them.
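
    For context, a minimal googletest sketch (not part of this PR, assuming only a standard gtest main): the DISABLED_ prefix used below keeps a test compiled but skipped by default, and it can still be run on demand with --gtest_also_run_disabled_tests.

    #include <gtest/gtest.h>

    // Skipped by default because of the DISABLED_ prefix; run explicitly with
    //   --gtest_also_run_disabled_tests --gtest_filter='DisabledConventionExample.*'
    TEST(DisabledConventionExample, DISABLED_Placeholder) {
      SUCCEED(); // Stand-in body; the real tests below exercise LdMatrix/StMatrix.
    }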

    // Disabled as the alternate loop domain is missing
    TEST_P(LdMatrixTest, DISABLED_Regular) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto macro = std::get<0>(GetParam());
      auto operand = std::get<1>(GetParam());
    
      bool is_a = operand == MmaOperand::A;
    
      int size1 = (is_a ? getM(macro) : getN(macro));
    
      auto tv0 = makeConcreteTensor({size1, getK(macro)}, DataType::Half);
      fusion.addInput(tv0);
      auto tv1 = set(tv0);
      tv1->setMemoryType(MemoryType::Shared);
      auto tv2 = set(tv1);
      tv2->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::LdMatrix);
      auto tv3 = set(tv2);
      fusion.addOutput(tv3);
    
      tv2->applyMmaSwizzle(operand);
      tv3->applyMmaSwizzle(operand);
    
      tv3->merge(0);
      if (is_a) {
        tv3->merge(0);
      }
      tv3->axis(0)->parallelize(ParallelType::TIDx);
    
      auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
      auto t0 = at::randn({size1, getK(macro)}, options);
    
      KernelExecutor ke;
      ke.compile(&fusion, {t0}, LaunchParams(), matmul_cparams);
      auto cg_outputs = ke.run({t0});
    
      testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__);
    }
    
    // We get shapes M and N from the MmaMacro. The vector of ints holds the
    // tile_m and tile_n factors (8x8, 16x8, and 16x16).
    using StMatrixTestParams = std::tuple<MmaMacro, std::vector<int>, DataType>;
    
    class StMatrixTest : public NVFuserFixtureParamTest<StMatrixTestParams> {
     protected:
      void SetUp() override {
        if (cudaArchGuardShouldSkip(9, 0)) {
          GTEST_SKIP() << "skipping tests on pre-Hopper GPUs";
        }
        NVFuserTest::SetUp();
        EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
      }
    };
    
    // Disabled as the alternate loop domain is missing
    TEST_P(StMatrixTest, DISABLED_Regular) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto macro = std::get<0>(GetParam());
      auto tile_sizes = std::get<1>(GetParam());
      auto dtype = std::get<2>(GetParam());
      auto sizeM = getM(macro);
      auto sizeN = getN(macro);
      int64_t tile_m = tile_sizes.at(0);
      int64_t tile_n = tile_sizes.at(1);
    
      if (sizeM % tile_m || sizeN % tile_n) {
        GTEST_SKIP() << "Fractional tiling is not supported/tested";
      }
    
      fusion.manage("ldst_matrix_m_tile", tile_m);
      fusion.manage("ldst_matrix_n_tile", tile_n);
      fusion.manage("ldst_matrix_m_smem", sizeM);
      fusion.manage("ldst_matrix_n_smem", sizeN);
    
      auto tv0 = makeContigConcreteTensor({sizeM, sizeN}, dtype);
      fusion.addInput(tv0);
      // tv0 (global) -> tv1 (registers)
      auto tv1 = set(tv0);
      // tv1 (register) -> tv2 (shared)
      auto tv2 = set(tv1);
      tv2->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::StMatrix);
      tv2->setMemoryType(MemoryType::Shared);
      // tv2 (shared) -> tv3(global)
      auto tv3 = set(tv2);
      fusion.addOutput(tv3);
    
      tv0->merge(0);
      tv0->split(0, 32);
      tv0->axis(1)->parallelize(ParallelType::TIDx);
    
      // TODO Set alternate loop domain here once idModel support
      // MmaInputSmemSwizzle::None
    
      for (auto tv : {tv1, tv2}) {
        auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
            tv->getLoopDomain());
        tv->setLoopDomain(s.as<IterDomain*>());
      }
      tv1->setAllocationDomain(tv1->getLoopDomain(), true);
    
      mma_utils::scheduleLdStMatrixForMmaOutput(tv2, tile_m, tile_n);
    
      tv2->axis(-1)->parallelize(ParallelType::Vectorize);
    
      tv3->merge(0);
      tv3->split(0, 32);
      tv3->axis(1)->parallelize(ParallelType::TIDx);
    
      auto options =
          at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
      auto t0 = at::randn({sizeM, sizeN}, options);
    
      KernelExecutor ke;
      ke.compile(&fusion, {t0}, LaunchParams(), matmul_cparams);
      auto cg_outputs = ke.run({t0});
    
      testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__);
    }
    
    std::string testNameStMatrixTest(
        const testing::TestParamInfo<StMatrixTestParams>& info) {
      std::ostringstream os;
      auto macro = std::get<0>(info.param);
      auto tile_sizes = std::get<1>(info.param);
      auto dtype = std::get<2>(info.param);
      auto sizeM = getM(macro);
      auto sizeN = getN(macro);
      auto tile_m = tile_sizes.at(0);
      auto tile_n = tile_sizes.at(1);
    
      os << "m_" << sizeM << "_n_" << sizeN << "_tile_m_" << tile_m << "_tile_n_"
         << tile_n << "_" << mma_utils::dtypeToChar(dtype);
      return os.str();
    }
    
    INSTANTIATE_TEST_SUITE_P(
        ,
        StMatrixTest,
        testing::Combine(
            testing::ValuesIn(kAllHopperMacros),
            testing::Values(
                // tile_m, tile_n
                std::vector<int>{16, 8},
                std::vector<int>{16, 16}),
            testing::Values(DataType::Half, DataType::BFloat16)),
        testNameStMatrixTest);
    
    // Disabled as the alternate loop domain is missing
    TEST_P(LdMatrixTest, DISABLED_Transpose) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto macro = std::get<0>(GetParam());

    Test failures

    • (Medium, 1) Thunder–eager scalar mismatch in NanoGPT autograd nvFuser CUDA test

      Test name (H100): thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32

    @naoyam
    Copy link
    Collaborator Author

    naoyam commented Jan 31, 2026

    !test
