diff --git a/.github/workflows/build-llvm.yml b/.github/workflows/build-llvm.yml
index dbfe8443c..730fce043 100644
--- a/.github/workflows/build-llvm.yml
+++ b/.github/workflows/build-llvm.yml
@@ -32,7 +32,7 @@ jobs:
       - uses: actions/checkout@v4
         with:
-          repository: Menooker/mlir-extensions
+          repository: intel/mlir-extensions
           ref: ${{ env.IMEX_HASH }}
           path: mlir-extensions
         if: ${{ matrix.build-type == 'IMEX' }}
diff --git a/cmake/imex-version.txt b/cmake/imex-version.txt
index 678e84fdf..d4bd81ca0 100644
--- a/cmake/imex-version.txt
+++ b/cmake/imex-version.txt
@@ -1 +1 @@
-ee459724294e165e360e1de72ad3b217eb9b6206
\ No newline at end of file
+6c2e414a953b9a118bce6adac21cf9d42630e674
\ No newline at end of file
diff --git a/cmake/imex.cmake b/cmake/imex.cmake
index 13cd496b1..c698945c4 100644
--- a/cmake/imex.cmake
+++ b/cmake/imex.cmake
@@ -14,7 +14,7 @@ if (NOT DEFINED IMEX_INCLUDES)
   # TODO: Change to main https://github.com/intel/mlir-extensions when all the
   # required functionality is merged.
-  gc_fetch_content(imex "${IMEX_HASH}" https://github.com/Menooker/mlir-extensions
+  gc_fetch_content(imex "${IMEX_HASH}" https://github.com/intel/mlir-extensions
       SET IMEX_CHECK_LLVM_VERSION=ON IMEX_ENABLE_L0_RUNTIME=0
   )
diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
index bbc0dc56c..8309e106c 100644
--- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
+++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
@@ -1394,6 +1394,92 @@ struct ConvertNamedEltwiseToXeGPU : public OpRewritePattern<LinalgOpTy> {
   LinalgToXeGPUOptions options;
 };
 
+// Create an XeGPU kernel out of a memory fill operation.
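+// The kernel splats the fill value into SIMD-sized vectors and stores them
+// into 2D sub-tiles of the output buffer. For example, filling a 32x32 f16
+// memref uses 8x32 sub-tiles (256 elements each) and is expected to lower,
+// roughly, to four copies of:
+//   %tile = xegpu.create_nd_tdesc %out[%i, 0]
+//       : memref<32x32xf16> -> !xegpu.tensor_desc<8x32xf16>
+//   %flat = vector.splat %scalar : vector<256xf16>
+//   %vec  = vector.shape_cast %flat : vector<256xf16> to vector<8x32xf16>
+//   xegpu.store_nd %vec, %tile ...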
+LogicalResult createMemoryFillKernel(linalg::LinalgOp linalgOp,
+                                     PatternRewriter &rewriter) {
+  Location loc = linalgOp.getLoc();
+  auto ctx = linalgOp.getContext();
+
+  auto scalar = linalgOp.getDpsInputs()[0];
+  auto output = linalgOp.getDpsInits()[0];
+  auto outputType = cast<ShapedType>(output.getType());
+  auto outputShape = outputType.getShape();
+
+  // Extract SIMD sized sub-tiles: whole rows, capped at maxSizeSIMD elements.
+  int maxSizeSIMD = 256;
+  int64_t subTileCols = outputShape[1];
+  int64_t subTileRows = std::min(outputShape[0], maxSizeSIMD / subTileCols);
+
+  // Output descriptors for later stores.
+  SmallVector<Value> outputTiles = createDescriptorTiles(
+      rewriter, loc, output, outputShape, {0, 0}, {subTileRows, subTileCols});
+
+  SmallVector<Value> results;
+  for (size_t i = 0; i < outputTiles.size(); i++) {
+    // Splat the scalar into a flat vector and reshape it to the sub-tile.
+    auto flatType = VectorType::get({subTileRows * subTileCols},
+                                    outputType.getElementType());
+    auto tileType = VectorType::get({subTileRows, subTileCols},
+                                    outputType.getElementType());
+    Value vec = rewriter.create<vector::SplatOp>(loc, flatType, scalar);
+    Value res = rewriter.create<vector::ShapeCastOp>(loc, tileType, vec);
+
+    if (!res)
+      return failure();
+
+    results.push_back(res);
+  }
+
+  // Store results.
+  auto writeCacheHint =
+      xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::WRITE_BACK);
+  for (size_t i = 0; i < outputTiles.size(); i++) {
+    rewriter.create<xegpu::StoreNdOp>(loc, results[i], outputTiles[i],
+                                      /*l1_hint=*/writeCacheHint,
+                                      /*l2_hint=*/writeCacheHint,
+                                      /*l3_hint=*/writeCacheHint);
+  }
+
+  rewriter.eraseOp(linalgOp);
+
+  return success();
+}
+
+// Convert a named fill operation to an XeGPU kernel.
+template <typename LinalgOpTy>
+struct ConvertMemoryFillToXeGPU : public OpRewritePattern<LinalgOpTy> {
+  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
+
+  ConvertMemoryFillToXeGPU(MLIRContext *ctx, LinalgToXeGPUOptions options)
+      : OpRewritePattern<LinalgOpTy>(ctx), options(options) {}
+
+  LogicalResult matchAndRewrite(LinalgOpTy linalgOp,
+                                PatternRewriter &rewriter) const override {
+    if (!linalgOp.hasPureBufferSemantics()) {
+      return rewriter.notifyMatchFailure(
+          linalgOp, "Linalg fill to GPU expects memref type");
+    }
+    if (linalgOp.hasDynamicShape()) {
+      return rewriter.notifyMatchFailure(
+          linalgOp, "Expect static shape when mapping to GPU");
+    }
+
+    // The fill value must be a scalar; the destination must be a valid memref.
+    auto isInputValid =
+        success(linalgOp.isScalar(linalgOp.getDpsInputOperand(0)));
+    if (failed(isInputValid))
+      return isInputValid;
+
+    auto isOutputValid =
+        isValidMemrefOperand(linalgOp, linalgOp.getDpsInits()[0], rewriter);
+    if (failed(isOutputValid))
+      return isOutputValid;
+
+    return createMemoryFillKernel(linalgOp, rewriter);
+  }
+
+private:
+  LinalgToXeGPUOptions options;
+};
+
 // TODO: Finalize BRGEMM support and register the pattern.
 void populateLinalgGemmToXeGPUPatterns(RewritePatternSet &patterns,
                                        LinalgToXeGPUOptions options) {
@@ -1418,6 +1504,12 @@ void populateLinalgEltwiseToXeGPUPatterns(RewritePatternSet &patterns,
                                           options);
 }
 
+void populateLinalgMemoryFillToXeGPUPatterns(RewritePatternSet &patterns,
+                                             LinalgToXeGPUOptions options) {
+  patterns.add<ConvertMemoryFillToXeGPU<linalg::FillOp>>(patterns.getContext(),
+                                                         options);
+}
+
 struct LinalgToXeGPU : public gc::impl::LinalgToXeGPUBase<LinalgToXeGPU> {
   using LinalgToXeGPUBase::LinalgToXeGPUBase;
 
@@ -1429,6 +1521,11 @@ struct LinalgToXeGPU : public gc::impl::LinalgToXeGPUBase<LinalgToXeGPU> {
     populateLinalgGemmToXeGPUPatterns(gemmPatterns, options);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(gemmPatterns));
 
+    // Convert memory fill ops.
+    RewritePatternSet fillPatterns(&getContext());
+    populateLinalgMemoryFillToXeGPUPatterns(fillPatterns, options);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(fillPatterns));
+
     // Convert other remaining ops.
     RewritePatternSet patterns(&getContext());
     populateLinalgEltwiseToXeGPUPatterns(patterns, options);
diff --git a/scripts/compile.sh b/scripts/compile.sh
index 08ad73b20..67d688f09 100755
--- a/scripts/compile.sh
+++ b/scripts/compile.sh
@@ -120,7 +120,7 @@ build_llvm() {
     local mlir_ext_dir="$EXTERNALS_DIR/mlir-extensions"
     if ! [ -d "$mlir_ext_dir" ]; then
         cd "$EXTERNALS_DIR"
-        git clone https://github.com/Menooker/mlir-extensions.git
+        git clone https://github.com/intel/mlir-extensions.git
        cd "$mlir_ext_dir"
     else
         cd "$mlir_ext_dir"
diff --git a/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir b/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir
new file mode 100644
index 000000000..b6d31092a
--- /dev/null
+++ b/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir
@@ -0,0 +1,57 @@
+// RUN: gc-opt %s --pass-pipeline='builtin.module(func.func(iterative-tiling-and-fusion{use-cost-model=0 default-tile-size=matmul:{16,16}}),eliminate-empty-tensors,empty-tensor-to-alloc-tensor,one-shot-bufferize{bufferize-function-boundaries=1 function-boundary-type-conversion=identity-layout-map},drop-equivalent-buffer-results,func.func(finalizing-bufferize),canonicalize,cse,drop-equivalent-buffer-results,expand-realloc,canonicalize,ownership-based-buffer-deallocation,canonicalize,buffer-deallocation-simplification,bufferization-lower-deallocations,cse,canonicalize,convert-bufferization-to-memref,func.func(scf-forall-to-parallel),func.func(linalg-to-xegpu{stages=1 dpas-tile=8,16,16 k-tile=16}),xegpu-fold-alias-ops,func.func(convert-linalg-to-parallel-loops),func.func(gpu-map-parallel-loops),func.func(convert-parallel-loops-to-gpu),func.func(insert-gpu-allocs),gpu-kernel-outlining,canonicalize,set-spirv-capabilities{client-api=opencl},gpu.module(set-spirv-abi-attrs{client-api=opencl}),lower-affine,imex-vector-linearize,gpu.module(convert-xegpu-to-vc),reconcile-unrealized-casts,bf16-to-gpu,gpu.module(convert-func-to-spirv),gpu.module(convert-vector-to-spirv),imex-convert-gpu-to-spirv,spirv.module(spirv-lower-abi-attrs,spirv-update-vce),func.func(llvm-request-c-wrappers),serialize-spirv,convert-vector-to-scf,convert-gpu-to-gpux,convert-scf-to-cf,convert-cf-to-llvm,convert-vector-to-llvm,convert-index-to-llvm,convert-arith-to-llvm,convert-func-to-llvm,convert-math-to-llvm,convert-gpux-to-llvm,convert-index-to-llvm,expand-strided-metadata,lower-affine,finalize-memref-to-llvm,reconcile-unrealized-casts)' \
+// RUN: | gc-cpu-runner -e main --entry-point-result=void \
+// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime | FileCheck %s
+
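+// A two-layer f16 MLP: each layer zero-fills its matmul accumulator with
+// linalg.fill, runs a matmul, adds a bias, and applies ReLU via linalg.max
+// with a zero tensor. The first output column is printed and checked below.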
[ -d "$mlir_ext_dir" ]; then cd "$EXTERNALS_DIR" - git clone https://github.com/Menooker/mlir-extensions.git + git clone https://github.com/intel/mlir-extensions.git cd "$mlir_ext_dir" else cd "$mlir_ext_dir" diff --git a/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir b/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir new file mode 100644 index 000000000..b6d31092a --- /dev/null +++ b/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir @@ -0,0 +1,57 @@ +// RUN: gc-opt %s --pass-pipeline='builtin.module(func.func(iterative-tiling-and-fusion{use-cost-model=0 default-tile-size=matmul:{16,16}}),eliminate-empty-tensors,empty-tensor-to-alloc-tensor,one-shot-bufferize{bufferize-function-boundaries=1 function-boundary-type-conversion=identity-layout-map},drop-equivalent-buffer-results,func.func(finalizing-bufferize),canonicalize,cse,drop-equivalent-buffer-results,expand-realloc,canonicalize,ownership-based-buffer-deallocation,canonicalize,buffer-deallocation-simplification,bufferization-lower-deallocations,cse,canonicalize,convert-bufferization-to-memref,func.func(scf-forall-to-parallel),func.func(linalg-to-xegpu{stages=1 dpas-tile=8,16,16 k-tile=16}),xegpu-fold-alias-ops,func.func(convert-linalg-to-parallel-loops),func.func(gpu-map-parallel-loops),func.func(convert-parallel-loops-to-gpu),func.func(insert-gpu-allocs),gpu-kernel-outlining,canonicalize,set-spirv-capabilities{client-api=opencl},gpu.module(set-spirv-abi-attrs{client-api=opencl}),lower-affine,imex-vector-linearize,gpu.module(convert-xegpu-to-vc),reconcile-unrealized-casts,bf16-to-gpu,gpu.module(convert-func-to-spirv),gpu.module(convert-vector-to-spirv),imex-convert-gpu-to-spirv,spirv.module(spirv-lower-abi-attrs,spirv-update-vce),func.func(llvm-request-c-wrappers),serialize-spirv,convert-vector-to-scf,convert-gpu-to-gpux,convert-scf-to-cf,convert-cf-to-llvm,convert-vector-to-llvm,convert-index-to-llvm,convert-arith-to-llvm,convert-func-to-llvm,convert-math-to-llvm,convert-gpux-to-llvm,convert-index-to-llvm,expand-strided-metadata,lower-affine,finalize-memref-to-llvm,reconcile-unrealized-casts)' \ +// RUN: | gc-cpu-runner -e main --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime | FileCheck %s + +module { + func.func @linalg_mlp(%arg0: tensor<32x4096xf16>, %arg1: tensor<4096x4096xf16>, %arg2 : tensor<32x4096xf16>, + %arg3: tensor<4096x4096xf16>, %arg4 : tensor<32x4096xf16>) { + %cst = arith.constant 0.000000e+00 : f16 + %0 = tensor.empty() : tensor<32x4096xf16> + %1 = linalg.fill ins(%cst : f16) outs(%0 : tensor<32x4096xf16>) -> tensor<32x4096xf16> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<32x4096xf16>, tensor<4096x4096xf16>) + outs(%1 : tensor<32x4096xf16>) -> (tensor<32x4096xf16>) + %3 = tensor.empty() : tensor<32x4096xf16> + %4 = linalg.add ins(%arg2, %2 : tensor<32x4096xf16>, tensor<32x4096xf16>) + outs(%3 : tensor<32x4096xf16>) -> tensor<32x4096xf16> + %5 = arith.constant dense<0.000000e+00> : tensor<32x4096xf16> + %6 = tensor.empty() : tensor<32x4096xf16> + %7 = linalg.max ins(%5, %4 : tensor<32x4096xf16>, tensor<32x4096xf16>) + outs(%6 : tensor<32x4096xf16>) -> tensor<32x4096xf16> + + %8 = tensor.empty() : tensor<32x4096xf16> + %9 = linalg.fill ins(%cst : f16) outs(%8 : tensor<32x4096xf16>) -> tensor<32x4096xf16> + %10 = linalg.matmul ins(%7, %arg3 : tensor<32x4096xf16>, tensor<4096x4096xf16>) + outs(%9 : tensor<32x4096xf16>) -> (tensor<32x4096xf16>) + %11 = tensor.empty() : tensor<32x4096xf16> + %12 = linalg.add 
+// CHECK: Unranked Memref base@{{(0x)?[-0-9a-fA-F]*}}
+// CHECK-SAME: rank = 1 offset = 0 sizes = [32] strides = [4096] data =
+// CHECK-NEXT: [17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625]