diff --git a/docs/deep_tile_matmul_design.md b/docs/deep_tile_matmul_design.md
new file mode 100644
index 000000000..4e48d82de
--- /dev/null
+++ b/docs/deep_tile_matmul_design.md
@@ -0,0 +1,594 @@
+# DOC: Deep-Tiled Matmul
+
+## Introduction
+
+Tiling and parallelization are important for the performance of computation-intensive workloads (matmul, convolution, etc.). Modern hardware is often equipped with multiple cores and multiple levels of cache, each with different characteristics in terms of size, latency, and bandwidth. To achieve good performance, the generated code must exploit the parallelism of the underlying hardware and minimize cache misses. The goal of this document is to provide a design overview of the deep-tiled matmul in the graph compiler and to summarize the current state of the related support in the MLIR community.
+
+## Current Situation in the MLIR Community
+
+As noted above, tiling and parallelization are two important optimization techniques used in compilers to improve the performance of the generated code for workloads such as matmul and convolution. A code template can express complex optimizations (e.g., nontrivial memory copy/reuse to maximize hardware efficiency) that are hard to capture in a single unified compiler pass.
+
+Upstream MLIR already provides some support for tiling and parallelization. The `Linalg` dialect provides a tiling interface to support tiling optimization. In addition, to better represent the concept of a schedule, MLIR introduces the `Transform` dialect, which declares the `schedule` in IR form, separate from the `payload` IR.
+
+This section introduces the current status of the tiling interface, the `Transform` dialect, and the hardware abstraction layer in the MLIR community, and points out what is still missing in upstream MLIR.
+
+### Tiling Interface and the Related Pass
+
+MLIR provides the tiling interface shown below to support simple tiling optimizations.
+
+The tiling interface is a set of methods that an operation can implement to provide information about its iteration space and how it can be tiled. The tiling interface is used by the tiling pass to generate a tiled implementation of the operation. For example, it can transform an operation like:
+
+```MLIR
+%0 = linalg.generic ins(%in) outs(%out) {indexing_maps = [affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ : tensor -> tensor
+```
+
+into:
+
+```MLIR
+%1 = scf.for %iv = %c0 to %dim_0 step %c4 iter_args(%arg3 = %out) -> (tensor) {
+ %2 = tensor.extract_slice %in[%iv] [%c4] [1] : tensor to tensor
+ %3 = tensor.extract_slice %out[%iv] [%c4] [1] : tensor to tensor
+ %4 = linalg.generic ins(%2) outs(%3) ["parallel"] : tensor -> tensor
+ %5 = tensor.insert_slice %4, %arg3 : tensor
+ scf.yield %5
+}
+```
+
+The tiling interface further provides several functions like `tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)` to tile an op that implements the tiling interface, where `SCFTilingOptions` contains the loop type (scf::For or scf::Forall), the interchange vector, the mapping vector, and the tile sizes. Through this function, the user can easily generate a tiled implementation of the operation along the parallel axes.
+
+```c++
+class TilingInterface : public ::mlir::OpInterface {
+public:
+  using ::mlir::OpInterface::OpInterface;
+  template
+  struct Trait : public detail::TilingInterfaceTrait {};
+  /// Returns a list of iterator types that describe the number of loops.
+ SmallVector getLoopIteratorTypes(); + /// Returns a list of ranges that describe the loop bounds and + /// step for the loops of the operation. + SmallVector getIterationDomain(OpBuilder & b); + /// Method to generate the tiled implementation of an operation. + /// + /// The iteration space of the operation is returned by + /// `getIterationDomain`. The caller provides the information of the + /// tile within this iteration space whose implementation the + /// caller needs. + /// - `offsets` provides the offset of the tile in the coordinate system + /// of the original iteration space, i.e., if an iteration space + /// dimension had non-zero offset, it must be included in the offset + /// provided here (as opposed to zero-based offset "relative" to the + /// iteration space). + /// - `sizes` provides the size of the tile. + /// + /// The method returns the operation that is the tiled + /// implementation. + FailureOr getTiledImplementation(OpBuilder & b, ArrayRef offsets, ArrayRef sizes); + /// Method to return the position of the result tile computed by the tiled operation. + /// + /// Specifies what tile of the result of the original tensor is computed + /// by the tiled implementation. Expects the same `offsets` and `sizes` as + /// used to obtain the tiled implementation of the operation. + LogicalResult getResultTilePosition(OpBuilder & b, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes, SmallVector & resultOffsets, SmallVector & resultSizes); + /// Method to generate the code that produces a tile of the result. + /// + /// Generates the IR that computes the tile of a result of the + /// operation. The `offsets` and `sizes` describe the tile of + /// the output required. This is different from + /// `getTiledImplementation` which generates the tiled + /// implementation of the operation given a tile of the + /// iteration space. This method generates a tiled + /// implementation of the operation based on the tile of the + /// result required. This method enables fusion by using tile + /// and fuse. The method returns failure if the operation can't be + /// tiled to generate the result tile. In practical terms this + /// implies it cannot be tiled and fused with its consumers. + /// + /// - `offsets` provides the offset of the tile in the coordinate system + /// of the original iteration space, i.e., if an iteration space + /// dimension had non-zero offset, it must be included in the offset + /// provided here (as opposed to zero-based offset "relative" to the + /// iteration space). + /// - `sizes` provides the size of the tile. + FailureOr generateResultTileValue(OpBuilder & b, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes); + /// Generates the scalar implementation of the operation. + /// + /// Given the list `ivs` that represent points in the iteration space + /// (as specified by `getIterationDomain()`) returns the scalar operations + /// that represent the computation at that point in the iteration space. + /// This method is typically used as the "exit path", i.e. once all + /// transformations are done, this method can be used to lower to scalar + /// code that can then be lowered to LLVM or SPIR-V dialects. + LogicalResult generateScalarImplementation(OpBuilder & b, Location loc, ValueRange ivs); +}; + +struct SCFTilingOptions { + /// Computation function that returns the tile sizes for each operation. + /// Delayed construction of constant tile sizes should occur to interoperate + /// with folding. 
+ SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr; + + /// The interchange vector to reorder the tiled loops. + SmallVector interchangeVector = {}; + + /// Specify which loop construct to use for tile and fuse. + enum class LoopType { ForOp, ForallOp }; + LoopType loopType = LoopType::ForOp; + + /// Specify mapping of loops to devices. This is only respected when the loop + /// constructs support such a mapping (like `scf.forall`). Will be ignored + /// when using loop constructs that dont support such a mapping (like + /// `scf.for`) + SmallVector mappingVector = {}; +}; +FailureOr tileUsingSCF(RewriterBase &rewriter, + TilingInterface op, + const SCFTilingOptions &options); + +/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying +/// tiling by `numThreads`. +/// If non-empty, the `mapping` is added as an attribute to the +/// resulting `scf.forall`. +/// Zero tile sizes indicate that the dimension is not tiled, and can be +/// thought of as tiling by the full size of data. It is the user's +/// responsibility to ensure that `numThreads` is a valid tiling specification +/// (i.e. that only tiles parallel dimensions, e.g. in the Linalg case). +struct ForallTilingResult { + Operation *tileOp; + Operation *tiledOp; +}; +FailureOr tileToForallOp(RewriterBase &builder, + TilingInterface op, + ArrayRef numThreads, + std::optional mapping); + +/// Same as `tileToForallOp`, but calculate the number of threads +/// required using the given tileSizes. +FailureOr +tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op, + ArrayRef tileSizes, + std::optional mapping); +``` + +The above tiling interface only supports the tiling on the parallel axis. But in a workload like matmul, it is often required to do a tiling on the reduction axis for better performance considering the size of available memory/cache, computation intensity, cache communication, etc. 
+ +```MLIR +%red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor) + outs(%out : tensor) { + ^bb0(%arg7: f32, %arg9: f32): + %1 = arith.addf %arg7, %arg9 : f32 + linalg.yield %1 : f32 + } -> tensor +``` + +into: + +```MLIR +%0 = tensor.empty(%dim_1) : tensor +%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor +%2 = scf.for %arg2 = %c0 to %dim_0 step %c5 iter_args(%arg3 = %1) -> (tensor) { + %extracted_slice = tensor.extract_slice %1[0, 0] [%dim, 5] [1, 1] : tensor to tensor + %extracted_slice_2 = tensor.extract_slice %arg0[0, %arg2] [%dim, 5] [1, 1] : tensor to tensor + %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%extracted_slice_2 : tensor) + outs(%extracted_slice : tensor) { + ^bb0(%in: f32, %out: f32): + %5 = arith.addf %in, %out : f32 + linalg.yield %5 : f32 + } -> tensor + %dim_3 = tensor.dim %1, %c0 : tensor + %inserted_slice = tensor.insert_slice %4 into %arg3[0, 0] [%dim_3, 5] [1, 1] : tensor into tensor + scf.yield %inserted_slice : tensor +} +%3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%2 : tensor) + outs(%arg1 : tensor) { + ^bb0(%in: f32, %out: f32): + %4 = arith.addf %in, %out : f32 + linalg.yield %4 : f32 + } -> tensor +``` + +To support this kind of tiling, the MLIR also provide a `PartialReductionOpInterface` based on TilingInterface. The `PartialReductionOpInterface` is an interface with a set of methods that provide information about its partial reduction and how it can be tiled. Based on the `PartialReductionOpInterface`, it further provides a function `tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef tileSize)` to support tile an op inherited the `PartialReductionOpInterface`, where the `tileSize` is the tile size for the reduction axis. + +```c++ +class PartialReductionOpInterface : public ::mlir::OpInterface { +public: + using ::mlir::OpInterface::OpInterface; + template + struct Trait : public detail::PartialReductionOpInterfaceTrait {}; + /// Method to generate a tensor initalized with the identity value of the + /// operation reduction. The tensor shape is equal to operation result + /// shape with new dimension for each non zero tile size. + FailureOr generateInitialTensorForPartialReduction(OpBuilder & b, Location loc, ArrayRef sizes, ArrayRef reductionDim); + /// Method to generate a tiled version of the operation where the tiled + /// reduction dimension are converted to parallel dimensions with a size + /// less or equal to the tile size. This is meant to be used with + /// `mergeReductions` method which will combine the partial reductions. + Operation*tileToPartialReduction(OpBuilder & b, Location loc, ValueRange init, ArrayRef offsets, ArrayRef sizes, ArrayRef reductionDims); + /// Method to merge partial reductions for an operation that has been + /// tiled along the reduction dimensions. This will only apply the + /// reduction the operation. + Operation*mergeReductions(OpBuilder & b, Location loc, ValueRange partialReduce, ArrayRef reductionDim); + +/// Method to tile a reduction and generate a parallel op within a serial loop. +/// Each of the partial reductions are calculated in parallel. 
Then after the +/// loop all the partial reduction are merged into a final reduction. +/// For example for the following sequence +/// +/// ```mlir +/// %0 = linalg.generic %in ["parallel", "reduction"] +/// : tensor<7x9xf32> -> tensor<7xf32> +/// ``` +/// +/// into: +/// +/// ```mlir +/// %0 = linalg.fill ... : tensor<7x4xf32> +/// %1 = scf.for ... iter_args(%arg0 = %0) +/// %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32> +/// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32> +/// %4 = linalg.generic %2, %3 ["parallel", "parallel"] +/// : tensor<7x?xf32> -> tensor<7x?xf32> +/// %5 = tensor.insert_slice %3, %0[0, 0] : tensor<7x4xf32> +/// } +/// %6 = linalg.generic %1 ["parallel", "reduction"] +/// : tensor<7x4xf32> -> tensor<7xf32> +/// ``` +FailureOr +tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, + ArrayRef tileSize); + +/// Method to tile a reduction to parallel iterations computing partial +/// reductions. After the loop all the partial reduction are merged into a final +/// reduction. For example for the following sequence +/// +/// ```mlir +/// %0 = linalg.generic %in ["parallel", "reduction"] +/// : tensor<7x9xf32> -> tensor<7xf32> +/// ``` +/// +/// into: +/// +/// ```mlir +/// %0 = linalg.fill ... : tensor<7x4xf32> +/// %1 = scf.forall (%iv) in (%c4) shared_outs(%arg0 = %0) +/// -> (tensor<7x4xf32>) { +/// %2 = tensor.extract_slice %arg3 : tensor<7x4xf32> to tensor<7xf32> +/// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32> +/// %4 = linalg.generic %2, %3 ["parallel", "reduction"] +/// : tensor<7x?xf32> -> tensor<7xf32> +/// %5 = tensor.insert_slice %3, %arg0[0, %iv] : tensor<7x4xf32> +/// } +/// %6 = linalg.generic %1 ["parallel", "reduction"] +/// : tensor<7x +FailureOr +tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, + ArrayRef numThreads, + ArrayRef tileSizes = {}, + std::optional mapping = std::nullopt); +}; +``` + +### Hardware Abstraction Layer(HAL) + +To achieve the best performance, a good schedule requires the hardware information as a reference. Hardware information like cache size, thread number, etc. is often needed to generate the best schedule. Hardware Abstraction Layer(HAL) is a layer of software that provides a hardware-independent interface to the underlying hardware. The mainstream dl compiler or performance library has a way to get the hardware information to guide the schedule like [IREE](https://iree.dev/developers/design-docs/design-roadmap/#hal-hardware-abstraction-layer-and-multi-architecture-executables), [TVM](https://tvm.apache.org/docs/arch/device_target_interactions.html#tvm-target-specific-overview), [onednn](https://github.com/oneapi-src/oneDNN), etc. However, the MLIR doesn't have such a hardware abstraction layer(HAL) to provide the hardware information. + +## Deep-Tiled Matmul Introduction + +This section will introduce the concept of the deep-tiled matmul optimization(nested matmul/managed_matmul in graph compiler v1) and how it could improve the performance. + +Deep-tiled matmul originally is a [matmul code template](https://github.com/oneapi-src/oneDNN/blob/main/src/graph/backend/graph_compiler/core/src/ops/templates/managed_matmul_core.cpp) in the [onednn graph compiler v1](https://arxiv.org/ftp/arxiv/papers/2301/2301.01333.pdf) with well-tuned default parameters to deliver good performance in the e2e model. 
The basic idea of the deep-tiled matmul is to partition the iteration space of the matmul into 9 loops, as shown in the pseudocode below. The outermost 3 loops (`MThreads, NThreads, KThreads`) partition the iteration space of the matmul according to the number of threads, which balances the workload distribution among the threads and minimizes the cache synchronization/communication overhead. The middle 3 loops (`MBlock, NBlock, KBlock`) partition the iteration space of the matmul and control the loop order according to the L2 cache size of the CPU, which improves the data locality of the generated code. The innermost 3 loops (`innermostMBlock, innermostNBlock, innermostKBlock`) partition the iteration space of the matmul and control the loop order according to the L1 cache size of the CPU, which further improves data locality. At this level, the matmul is converted to the micro-kernel call [*brgemm*](https://arxiv.org/pdf/2104.05755.pdf), which is a highly optimized vectorized kernel (applying optimizations like unrolling, operation interleaving, prefetching, non-temporal load/store, a carefully tuned memory access pattern, and handcrafted register allocation). Though the tiling strategy above is described in terms of a CPU model, it can be easily extended to other hardware like GPUs and FPGAs (`global/shared memory`, `L1/L2 cache size`, `execution model (threads, warp, block, grid, etc.)`).
+
+```c++
+parameter M, N, K, MBlock, NBlock, KBlock, MThreads, NThreads, KThreads, innermostMBlock, innermostNBlock, innermostKBlock
+tensor A, B, C
+tempC = create_tensor for C -> tensor([KThreads, M, N])
+parallel_for([PM, PN, PK]: [MThreads, NThreads, KThreads]) {
+  ASlice = extract_slice from A -> tensor([MOuterBlock, KOuterBlock])
+  BSlice = extract_slice from B -> tensor([KOuterBlock, NOuterBlock])
+  CSlice = extract_slice from C -> tensor([MOuterBlock, NOuterBlock])
+  MNumBlock = MOuterBlock / MBlock
+  NNumBlock = NOuterBlock / NBlock
+  KNumBlock = KOuterBlock / KBlock
+  for([om, on, ok]: [MNumBlock, NNumBlock, KNumBlock]) {
+    ASlice2 = extract_slice from ASlice -> tensor([MBlock, KBlock])
+    BSlice2 = extract_slice from BSlice -> tensor([KBlock, NBlock])
+    CSlice2 = extract_slice from CSlice -> tensor([1, MBlock, NBlock])
+    MNumInnerBlock = MBlock / innermostMBlock
+    NNumInnerBlock = NBlock / innermostNBlock
+    KNumInnerBlock = KBlock / innermostKBlock
+    for([im, in]: [MNumInnerBlock, NNumInnerBlock]) {
+      ASlice3 = extract_slice from ASlice2 -> tensor([innermostMBlock, KBlock])
+      BSlice3 = extract_slice from BSlice2 -> tensor([KBlock, innermostNBlock])
+      CSlice3 = extract_slice from CSlice2 -> tensor([innermostMBlock, innermostNBlock])
+      if(ok == 0) {
+        init CSlice3 with 0 (could use init_brgemm when it is available)
+      }
+      brgemm(bs=KNumInnerBlock, M=innermostMBlock, N=innermostNBlock, K=innermostKBlock,
+A=ASlice3, B=BSlice3, C=CSlice3, onlyUpdate=(ok!=0));
+    }
+  }
+}
+C = final_reduce(tempC) -> [M, N]
+```
+
+## Proposal
+
+This section presents a proposal based on [Option 4](#option-4---outer-loop-based-on-tiling-interface--inner-loop-through-a-predefined-template-with-ir-builder) above to implement the deep-tiled matmul in the graph compiler v2. According to the discussion above, option 4 can deliver high performance while maximally reusing the existing work in MLIR, which minimizes the difficulty of acceptance by the community.
In the meantime, future optimizations like `loop reorder` and `axis split` can be easily added by changing the parameters. So this is the recommended approach in this document, and the details are introduced in the following sections.
+
+### Position
+
+> The transformation control infrastructure provided by this dialect is positioned roughly between rewrite patterns and passes. A transformation that is executed by a transform operation is likely to be sufficiently complex to require at least a set of patterns to be implemented. It is also expected to be more focused than a pass: a pass typically applies identical transformations everywhere in the IR, a transform dialect-controlled transformation would apply to a small subset of operations selected, e.g., by a pattern-matching operation or generated by a previous transformation. It is discouraged, although technically possible, to run a pass pipeline as part of the transform op implementation. *From [MLIR documentation](https://mlir.llvm.org/docs/Dialects/Transform/)*
+
+As the MLIR documentation notes, the scope order from largest to smallest is `pass > Transform dialect > rewrite patterns`. The deep-tiled matmul only applies to the operations `matmul` and `batch_matmul`, so it is better to implement it as a rewrite pattern. To better meet upstream needs, it could also be wrapped into an operation of the `Transform` dialect so that it can become part of a `Transform` schedule.
+
+In the graph compiler v2, this could be further wrapped into a pass `deepTilingRewriteForContractionOperation`, which could also host other deep-tiling rewrite patterns in the future (`paddedConvolution`, `reduceLoweringConvolution`, `depthwiseConvolution`, etc). This pass is expected to be executed after the `padding/layout propagation`-related passes and before the fusion-related passes. The layout-related passes convert the input/output tensors to the required blocked layout to achieve better performance, and the fusion-related passes may depend on the tiled matmul's `insert_slice`/`extract_slice` as anchors for fusion.
+
+```MLIR
+...
+layout propagation related pass(pack/unpack, pad, propagation, etc)
+
+deepTilingRewriteForContractionOperation(deep-tiled matmul, deep-tiled padded conv, conv1x1, depthwise conv, etc)
+
+fusion related pass
+...
+```
+
+Besides, this rewrite pattern is expected to be a part of the linalg dialect, similar to the existing rewrite `ConvertConv2DToImg2Col` in MLIR. In graph compiler v2, it could be a part of `linalgX` before being upstreamed.
+
+### Outer Loop Generation
+
+For outer loop generation, we generate the loops step by step according to the parameters/config (`outermost loops for multicore -> loops for the L2 cache -> loops for the L1 cache`). This part is implemented based on the tiling interface and its related utility functions, which maximally reuses the existing work in MLIR and reduces the maintenance burden. Besides, the tiling utilities expose an `interchange` parameter (e.g., `SCFTilingOptions::interchangeVector`), which makes it easy to change the loop order according to the workload characteristics. This approach can also be easily reused by other operations like `convolution` and `depthwise convolution`, because they have a similar structure in this part.
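+
+For illustration, the sketch below shows how the tiling options for the middle (L2-level) loops might be populated, based on the `SCFTilingOptions` fields quoted earlier. This is a minimal sketch rather than the actual pass code: the tile sizes, the interchange order, the `cfg`/`matmul` handles, and the exact callback signature (which varies across MLIR revisions) are assumptions for illustration only.
+
+```c++
+// Sketch only: configure the middle (L2-cache-level) loops with an
+// interchanged loop order. Tile sizes come from a MatmulConfig-like struct
+// (see the Config/Schedule section); the values and order are placeholders.
+scf::SCFTilingOptions tileOption;
+tileOption.tileSizeComputationFunction = [&](OpBuilder &b, Operation *op) {
+  // MBlock/NBlock/KBlock are the L2-level tile sizes from the config.
+  return SmallVector<OpFoldResult>{b.getIndexAttr(cfg.MBlock),
+                                   b.getIndexAttr(cfg.NBlock),
+                                   b.getIndexAttr(cfg.KBlock)};
+};
+tileOption.interchangeVector = {1, 0, 2}; // e.g. iterate N blocks before M blocks
+tileOption.loopType = scf::SCFTilingOptions::LoopType::ForOp;
+FailureOr<scf::SCFTilingResult> tiled = scf::tileUsingSCF(
+    rewriter, cast<TilingInterface>(matmul.getOperation()), tileOption);
+```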
+
+The expected implementation in pseudocode is as follows:
+
+```c++
+// generate the outermost parallel loops with MThreads, NThreads
+linalg::tileToForallOp(rewriter, cast<TilingInterface>(matmul), {MThreads, NThreads});
+// generate the outer reduction loop with KThreads
+linalg::tileReductionUsingForall(rewriter, cast<PartialReductionOpInterface>(matmul), KThreads, tileSizes);
+// generate the middle three loops (MBlock, NBlock, KBlock)
+scf::tileUsingSCF(rewriter, cast<TilingInterface>(matmul), tileOption);
+// generate the inner loops (innerMostMBlock, innerMostNBlock, innerMostKBlock)
+scf::tileUsingSCF(rewriter, cast<TilingInterface>(matmul), tileOption);
+```
+
+As mentioned in [Current Situation in the MLIR Community](#current-situation-in-the-mlir-community), there are still gaps in current MLIR, such as the lack of a balance211-style workload distribution for shapes that do not divide evenly and the inefficient placement of partial-K threads for CPUs. These should be addressed in future work.
+
+### Inner Loop Body Generation
+
+Compared to outer loop generation, the inner loop body generation is often op-specific. For example, the `squeeze stride` optimization for convolution doesn't make any sense for `matmul`. Besides, this part is usually more complex than the outer loops (it may involve tail processing and non-trivial memory copy/initialization) and is hard to handle with a single unified pass. So it is better to implement it as a predefined template through the IR builder, which keeps the code flexible. We could also add builder/utility helpers to make it more readable.
+
+Below is the expected pseudocode of the inner loop body for the deep-tiled matmul in the graph compiler v2.
+
+```c++
+A = tensor.extract_slice
+B = tensor.extract_slice
+C = tensor.extract_slice
+D3 = scf.if(ok == 0) {
+  D1 = init_brgemm(A,B,C) tensor<...>, tensor<...>, tensor<...> -> tensor<...>
+} else {
+  D2 = brgemm(A,B,C) tensor<...>, tensor<...>, tensor<...> -> tensor<...>
+} -> tensor<...>
+tensor.insert_slice D3
+```
+
+The inner loop body converts the `matmul` to a `batch_reduce_gemm`, which is finally lowered to the microkernel [`brgemm`](https://github.com/oneapi-src/oneDNN/pull/1852) call.
+
+### Config/Schedule
+
+```c++
+struct MatmulConfig {
+  int MThreads, NThreads, KThreads;
+  int MBlock, NBlock, KBlock;
+  int innerMostMBlock, innerMostNBlock, innerMostKBlock;
+  int loopOrder;
+};
+```
+
+The above is the expected config for the deep-tiled matmul. `MThreads, NThreads, KThreads` partition the iteration space of the matmul according to the number of threads. `MBlock, NBlock, KBlock` partition the iteration space of the matmul and control the loop order according to the L2 cache size of the CPU. `innerMostMBlock, innerMostNBlock, innerMostKBlock` partition the iteration space of the matmul and control the loop order according to the L1 cache size of the CPU. The `loopOrder` controls the loop/iteration order according to the workload characteristics.
+
+A default heuristic config corresponding to these items will be tuned for performance.
+
+1. For `MThreads, NThreads, KThreads`, we should rely on the available threads, the memory required for the input/output/temporary buffers, and the L2/L3 cache sizes to build a cost model that maximizes workload balance and thread utilization and minimizes cache synchronization. The number of threads on the K axis should be set carefully, as it hurts performance in most cases (it only helps when K is large but M and N are small).
+2. 
For `MBlock, NBlock, KBlock`, the L2 cache size and the memory required per core are needed to build a cost model that minimizes L2 cache misses.
+3. For `innerMostMBlock, innerMostNBlock, innerMostKBlock`, we need to know the L1 cache size, the number of available registers, and the vector/matrix-unit (AMX-like) length to decide the innermost block sizes so that hardware efficiency is maximized. Besides, if we convert the brgemm to an external library function call, the cost of the function call also needs to be considered. When M/N/K is not divisible by the vector length, we usually either choose a factor of M/N/K as the innermost block size or pack/unpack in advance to make it divisible (a tradeoff between reducing memory copies and maximizing hardware efficiency).
+4. The `loopOrder` is mainly related to the workload characteristics (data, weight, output size), the cache sizes, and where the actual data/weight is located (L1/L2/L3/memory). It affects the memory visit order and therefore the cache data locality.
+
+The description above shows what should be considered from the horizontal view (`[M/N/K]Threads`, `[M/N/K]Block`, `innermost[M/N/K]Block`, `loop order`) of the config. However, in the vertical view (`MThreads, MBlock, innermostMBlock`, `N...`, `K...`), the items are interdependent, which also affects performance, so the order in which they are decided matters. The breakdown of how to decide them is as follows; a minimal sketch of this decision order is given after the note below.
+
+1. Firstly, we need to decide `innerMost[M/N/K]Block`, which bounds the maximum hardware efficiency we can achieve, especially for machines with a specialized matrix computation unit (AMX-like). For example, if the physical matrix-tile size is 16x64 and we choose an innermost block size of 8x32, the theoretical efficiency is a quarter of the maximum. Even for vector instruction sets like `avx512` and `avx2`, the innermost block still matters because it must align with the vector length (64/32/...). So the innermost block has the highest priority.
+2. After the innermost blocks are decided, the input and output matrices are divided into `[M/N/K]NumBlock` blocks with block size `innermost[M/N/K]Block`. Then we decide how `[M/N/K]Threads` should distribute these blocks so that the best workload balance, compute intensity, and cache utilization are achieved.
+3. After step 2, the number of innermost blocks per thread is known. Then we decide `[M/N/K]Block` to further partition the iteration space of the matmul so that the L2 cache misses in a single core are minimized. These should be multiples of `innermost[M/N/K]Block`.
+4. After the above steps, all tile sizes are decided and we have enough information about where the data is located (L1/L2/L3 and their sizes). The `loopOrder` can then be decided to maximize data locality/data reuse. It determines the order of these loops (`pmpnpkomonokiminik`, `pnpmpkokonominimik`, etc., where `p` is the outermost parallel loop, `o` is the middle outer loop, and `i` is the innermost loop).
+
+**Note**: In the graph compiler v1, we also consider the impact of the previous matmul, as it decides where the output of the previous matmul is located (e.g., the 3rd core's L2 cache or the 4th core's). This could also be further enhanced in the future.
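+
+To make the decision order above concrete, the following is a minimal, hypothetical sketch of how the heuristic could be structured. Every helper here (`pickInnermostBlocks`, `pickThreadDistribution`, `pickL2Blocks`, `pickLoopOrder`) and the `HardwareInfo` struct are illustrative placeholders derived from the steps above, not existing APIs.
+
+```c++
+// Hypothetical sketch of the config decision order described above.
+struct HardwareInfo { // placeholder for the HAL-provided information
+  uint32_t numThreads, maxVectorWidth, l1CacheSize, l2CacheSize, l3CacheSize;
+};
+
+MatmulConfig decideDefaultConfig(uint32_t M, uint32_t N, uint32_t K,
+                                 const HardwareInfo &hw) {
+  MatmulConfig cfg;
+  // 1. Innermost blocks first: match the vector/matrix-unit shape (e.g. an AMX
+  //    tile or the AVX-512 vector length) so peak efficiency stays reachable.
+  std::tie(cfg.innerMostMBlock, cfg.innerMostNBlock, cfg.innerMostKBlock) =
+      pickInnermostBlocks(M, N, K, hw.maxVectorWidth, hw.l1CacheSize);
+  // 2. Distribute the resulting [M/N/K]NumBlock blocks across threads for
+  //    workload balance, compute intensity and cache utilization.
+  std::tie(cfg.MThreads, cfg.NThreads, cfg.KThreads) = pickThreadDistribution(
+      M / cfg.innerMostMBlock, N / cfg.innerMostNBlock, K / cfg.innerMostKBlock,
+      hw.numThreads);
+  // 3. Per-thread L2-level blocks, kept as multiples of the innermost blocks.
+  std::tie(cfg.MBlock, cfg.NBlock, cfg.KBlock) =
+      pickL2Blocks(cfg, hw.l2CacheSize);
+  // 4. Finally pick the loop order once the data placement is known.
+  cfg.loopOrder = pickLoopOrder(cfg, hw);
+  return cfg;
+}
+```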
+ +The heuristic default config will be implemented as an [analysis pass](https://mlir.llvm.org/docs/PassManagement/#analysis-management). In this way, the heuristic is maximally isolated from the real IR transformation and easier to be accepted by the upstream community(who want to separate the heuristics from passes as much as possible). By the way, other passes like layout/padding propagation could also know which tile size is preferable by the matmul and will not have a dependence cycle among these passes. + +All choices above need to be under the guidance of HAL. But the HAL support(multi-level cache size, machine kind, available threads, register vector length) is not fully ready in the MLIR now. So there is a risk here to tune a good performance for general. + +### Expected IR Change + +Below is a matmul example(`M=256, K=128, N=512`) of the expected IR change after applying the deep-tiled matmul rewrite pattern(with config `MThreads=2, NThreads=2, KThreads=1, MBlock=128, NBlock=256, KBlock=128, innerMostMBlock=32, innerMostKBlock=32, loopOrder=0`). + +```MLIR +%0 = linalg.matmul ins(%cst_0, %cst_1 : tensor<256x128xf32>, tensor<128x512xf32>) outs(%cst_2 : tensor<256x512xf32>) -> tensor<256x512xf32> +``` + +into: + +```MLIR +%0 = scf.forall (%arg0, %arg1) in (2, 2) shared_outs(%arg2 = %cst_3) -> (tensor<256x512xf32>) { + %1 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %2 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + %3 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %4 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + %5 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %6 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + %extracted_slice = tensor.extract_slice %cst_1[%3, 0] [128, 128] [1, 1] : tensor<256x128xf32> to tensor<128x128xf32> + %extracted_slice_5 = tensor.extract_slice %cst_2[0, %4] [128, 256] [1, 1] : tensor<128x512xf32> to tensor<128x256xf32> + %extracted_slice_6 = tensor.extract_slice %arg2[%5, %6] [128, 256] [1, 1] : tensor<256x512xf32> to tensor<128x256xf32> + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c128_7 = arith.constant 128 : index + %7 = scf.for %arg3 = %c0 to %c128 step %c128_7 iter_args(%arg4 = %extracted_slice_6) -> (tensor<128x256xf32>) { + %c0_8 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c256_9 = arith.constant 256 : index + %10 = scf.for %arg5 = %c0_8 to %c256 step %c256_9 iter_args(%arg6 = %arg4) -> (tensor<128x256xf32>) { + %c0_10 = arith.constant 0 : index + %c128_11 = arith.constant 128 : index + %c128_12 = arith.constant 128 : index + %11 = scf.for %arg7 = %c0_10 to %c128_11 step %c128_12 iter_args(%arg8 = %arg6) -> (tensor<128x256xf32>) { + %extracted_slice_13 = tensor.extract_slice %extracted_slice[%arg3, %arg7] [128, 128] [1, 1] : tensor<128x128xf32> to tensor<128x128xf32> + %extracted_slice_14 = tensor.extract_slice %extracted_slice_5[%arg7, %arg5] [128, 256] [1, 1] : tensor<128x256xf32> to tensor<128x256xf32> + %extracted_slice_15 = tensor.extract_slice %arg8[%arg3, %arg5] [128, 256] [1, 1] : tensor<128x256xf32> to tensor<128x256xf32> + %c0_16 = arith.constant 0 : index + %c128_17 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %12 = scf.for %arg9 = %c0_16 to %c128_17 step %c32 iter_args(%arg10 = %extracted_slice_15) -> (tensor<128x256xf32>) { + %c0_18 = arith.constant 0 : index + %c256_19 = arith.constant 256 : index + %c32_20 = arith.constant 32 : index + %13 = scf.for %arg11 = %c0_18 to %c256_19 step %c32_20 iter_args(%arg12 = 
%arg10) -> (tensor<128x256xf32>) { + %c0_21 = arith.constant 0 : index + %c128_22 = arith.constant 128 : index + %c128_23 = arith.constant 128 : index + %14 = scf.for %arg13 = %c0_21 to %c128_22 step %c128_23 iter_args(%arg14 = %arg12) -> (tensor<128x256xf32>) { + %extracted_slice_24 = tensor.extract_slice %extracted_slice_13[%arg9, %arg13] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32> + %extracted_slice_25 = tensor.extract_slice %extracted_slice_14[%arg13, %arg11] [128, 32] [1, 1] : tensor<128x256xf32> to tensor<128x32xf32> + %extracted_slice_26 = tensor.extract_slice %arg14[%arg9, %arg11] [32, 32] [1, 1] : tensor<128x256xf32> to tensor<32x32xf32> + %expanded = tensor.expand_shape %extracted_slice_24 [[0, 1], [2]] : tensor<32x128xf32> into tensor<1x32x128xf32> + %expanded_27 = tensor.expand_shape %extracted_slice_25 [[0, 1], [2]] : tensor<128x32xf32> into tensor<1x128x32xf32> + %15 = linalg.batch_reduce_matmul ins(%expanded, %expanded_27 : tensor<1x32x128xf32>, tensor<1x128x32xf32>) outs(%extracted_slice_26 : tensor<32x32xf32>) -> tensor<32x32xf32> + %inserted_slice_28 = tensor.insert_slice %15 into %arg14[%arg9, %arg11] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<128x256xf32> + scf.yield %inserted_slice_28 : tensor<128x256xf32> + } + scf.yield %14 : tensor<128x256xf32> + } + scf.yield %13 : tensor<128x256xf32> + } + %inserted_slice = tensor.insert_slice %12 into %arg8[%arg3, %arg5] [128, 256] [1, 1] : tensor<128x256xf32> into tensor<128x256xf32> + scf.yield %inserted_slice : tensor<128x256xf32> + } + scf.yield %11 : tensor<128x256xf32> + } + scf.yield %10 : tensor<128x256xf32> + } + %8 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %9 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + scf.forall.in_parallel { + tensor.parallel_insert_slice %7 into %arg2[%8, %9] [128, 256] [1, 1] : tensor<128x256xf32> into tensor<256x512xf32> + } +} +``` + +When the `KThreads=2`, there will be partial reduction in the loop + +```MLIR +%0 = scf.forall (%arg0, %arg1) in (2, 2) shared_outs(%arg2 = %cst_3) -> (tensor<256x512xf32>) { + %1 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %2 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + %3 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %4 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + %5 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %6 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + %extracted_slice = tensor.extract_slice %cst_1[%3, 0] [128, 128] [1, 1] : tensor<256x128xf32> to tensor<128x128xf32> + %extracted_slice_5 = tensor.extract_slice %cst_2[0, %4] [128, 256] [1, 1] : tensor<128x512xf32> to tensor<128x256xf32> + %extracted_slice_6 = tensor.extract_slice %arg2[%5, %6] [128, 256] [1, 1] : tensor<256x512xf32> to tensor<128x256xf32> + %c0 = arith.constant 0 : index + %c0_7 = arith.constant 0 : index + %c2_8 = arith.constant 2 : index + %7 = tensor.empty() : tensor<128x256x2xf32> + %cst_9 = arith.constant 0.000000e+00 : f32 + %8 = linalg.fill ins(%cst_9 : f32) outs(%7 : tensor<128x256x2xf32>) -> tensor<128x256x2xf32> + %c2_10 = arith.constant 2 : index + %9 = scf.forall (%arg3) in (2) shared_outs(%arg4 = %8) -> (tensor<128x256x2xf32>) { + %13 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3) + %extracted_slice_11 = tensor.extract_slice %arg4[0, 0, %arg3] [128, 256, 1] [1, 1, 1] : tensor<128x256x2xf32> to tensor<128x256xf32> + %c0_12 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c128_13 = arith.constant 128 : index + %14 = affine.apply affine_map<()[s0, s1] -> 
(s0 * s1)>()[%arg3, %c128_13] + %15 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%14, %c0_12] + %16 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%c2_10, %c128_13] + %17 = scf.for %arg5 = %15 to %c128 step %16 iter_args(%arg6 = %extracted_slice_11) -> (tensor<128x256xf32>) { + %extracted_slice_14 = tensor.extract_slice %extracted_slice[0, %arg5] [128, 128] [1, 1] : tensor<128x128xf32> to tensor<128x128xf32> + %extracted_slice_15 = tensor.extract_slice %extracted_slice_5[%arg5, 0] [128, 256] [1, 1] : tensor<128x256xf32> to tensor<128x256xf32> + %extracted_slice_16 = tensor.extract_slice %arg6[0, 0] [128, 256] [1, 1] : tensor<128x256xf32> to tensor<128x256xf32> + %c0_17 = arith.constant 0 : index + %c128_18 = arith.constant 128 : index + %c128_19 = arith.constant 128 : index + %18 = scf.for %arg7 = %c0_17 to %c128_18 step %c128_19 iter_args(%arg8 = %extracted_slice_16) -> (tensor<128x256xf32>) { + %c0_20 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c256_21 = arith.constant 256 : index + %19 = scf.for %arg9 = %c0_20 to %c256 step %c256_21 iter_args(%arg10 = %arg8) -> (tensor<128x256xf32>) { + %c0_22 = arith.constant 0 : index + %c128_23 = arith.constant 128 : index + %c64 = arith.constant 64 : index + %20 = scf.for %arg11 = %c0_22 to %c128_23 step %c64 iter_args(%arg12 = %arg10) -> (tensor<128x256xf32>) { + %extracted_slice_24 = tensor.extract_slice %extracted_slice_14[%arg7, %arg11] [128, 64] [1, 1] : tensor<128x128xf32> to tensor<128x64xf32> + %extracted_slice_25 = tensor.extract_slice %extracted_slice_15[%arg11, %arg9] [64, 256] [1, 1] : tensor<128x256xf32> to tensor<64x256xf32> + %extracted_slice_26 = tensor.extract_slice %arg12[%arg7, %arg9] [128, 256] [1, 1] : tensor<128x256xf32> to tensor<128x256xf32> + %c0_27 = arith.constant 0 : index + %c128_28 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %21 = scf.for %arg13 = %c0_27 to %c128_28 step %c32 iter_args(%arg14 = %extracted_slice_26) -> (tensor<128x256xf32>) { + %c0_30 = arith.constant 0 : index + %c256_31 = arith.constant 256 : index + %c32_32 = arith.constant 32 : index + %22 = scf.for %arg15 = %c0_30 to %c256_31 step %c32_32 iter_args(%arg16 = %arg14) -> (tensor<128x256xf32>) { + %c0_33 = arith.constant 0 : index + %c64_34 = arith.constant 64 : index + %c64_35 = arith.constant 64 : index + %23 = scf.for %arg17 = %c0_33 to %c64_34 step %c64_35 iter_args(%arg18 = %arg16) -> (tensor<128x256xf32>) { + %extracted_slice_36 = tensor.extract_slice %extracted_slice_24[%arg13, %arg17] [32, 64] [1, 1] : tensor<128x64xf32> to tensor<32x64xf32> + %extracted_slice_37 = tensor.extract_slice %extracted_slice_25[%arg17, %arg15] [64, 32] [1, 1] : tensor<64x256xf32> to tensor<64x32xf32> + %extracted_slice_38 = tensor.extract_slice %arg18[%arg13, %arg15] [32, 32] [1, 1] : tensor<128x256xf32> to tensor<32x32xf32> + %expanded = tensor.expand_shape %extracted_slice_36 [[0, 1], [2]] : tensor<32x64xf32> into tensor<1x32x64xf32> + %expanded_39 = tensor.expand_shape %extracted_slice_37 [[0, 1], [2]] : tensor<64x32xf32> into tensor<1x64x32xf32> + %24 = linalg.batch_reduce_matmul ins(%expanded, %expanded_39 : tensor<1x32x64xf32>, tensor<1x64x32xf32>) outs(%extracted_slice_38 : tensor<32x32xf32>) -> tensor<32x32xf32> + %inserted_slice_40 = tensor.insert_slice %24 into %arg18[%arg13, %arg15] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<128x256xf32> + scf.yield %inserted_slice_40 : tensor<128x256xf32> + } + scf.yield %23 : tensor<128x256xf32> + } + scf.yield %22 : tensor<128x256xf32> + } + 
%inserted_slice_29 = tensor.insert_slice %21 into %arg12[%arg7, %arg9] [128, 256] [1, 1] : tensor<128x256xf32> into tensor<128x256xf32> + scf.yield %inserted_slice_29 : tensor<128x256xf32> + } + scf.yield %20 : tensor<128x256xf32> + } + scf.yield %19 : tensor<128x256xf32> + } + %inserted_slice = tensor.insert_slice %18 into %arg6[0, 0] [128, 256] [1, 1] : tensor<128x256xf32> into tensor<128x256xf32> + scf.yield %inserted_slice : tensor<128x256xf32> + } + scf.forall.in_parallel { + tensor.parallel_insert_slice %17 into %arg4[0, 0, %arg3] [128, 256, 1] [1, 1, 1] : tensor<128x256xf32> into tensor<128x256x2xf32> + } + } + %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%9 : tensor<128x256x2xf32>) outs(%extracted_slice_6 : tensor<128x256xf32>) { + ^bb0(%in: f32, %out: f32): + %13 = arith.addf %in, %out : f32 + linalg.yield %13 : f32 + } -> tensor<128x256xf32> + %11 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %12 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + scf.forall.in_parallel { + tensor.parallel_insert_slice %10 into %arg2[%11, %12] [128, 256] [1, 1] : tensor<128x256xf32> into tensor<256x512xf32> + } +} +``` \ No newline at end of file diff --git a/include/gc/Analysis/MatmulConfigAnalysis.h b/include/gc/Analysis/MatmulConfigAnalysis.h new file mode 100644 index 000000000..7bd1bb4f0 --- /dev/null +++ b/include/gc/Analysis/MatmulConfigAnalysis.h @@ -0,0 +1,112 @@ +//===-- MatmulConfigAnalysis.h - the analysis for matmul config -*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H +#define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H + +#include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" + +namespace mlir { +namespace gc { + +using namespace mlir; + +// The configuration for matmul tiling +// TODO: support batch matmul +struct MatmulConfig { + // The number of threads distributed to M, N, K + uint32_t MThreads, NThreads, KThreads; + // The outer block size for M, N, K which will be used to decide the loop tile + // size in single thread + uint32_t MBlock, NBlock, KBlock; + // The innermost block size for M, N, K which will be directly converted to + // brgemm. 
+ uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock; +}; + +enum DimType { Batch, M, N, K }; + +// Extract the index of the given DimType in the DimType list +inline SmallVector extractDimTypeIdx(ArrayRef tyList, + DimType ty) { + SmallVector idxList; + for (auto [idx, type] : llvm::enumerate(tyList)) { + if (type == ty) { + idxList.push_back(idx); + } + } + return idxList; +} + +// Get the operand dim type for every operand for the given linalg op +inline FailureOr>> +getOprandDimType(linalg::LinalgOp &linalgOp) { + // TODO: replace the linalgx op with generic op + if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::M, DimType::K}, + SmallVector{DimType::K, DimType::N}, + SmallVector{DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::M, DimType::K}, + SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, + DimType::K}, + SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::M, DimType::K, DimType::M, DimType::K}, + SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, + DimType::K}, + SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::Batch, DimType::M, DimType::K}, + SmallVector{DimType::Batch, DimType::K, DimType::N}, + SmallVector{DimType::Batch, DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::K, DimType::M}, + SmallVector{DimType::K, DimType::N}, + SmallVector{DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::M, DimType::K}, + SmallVector{DimType::N, DimType::K}, + SmallVector{DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::Batch, DimType::K, DimType::M}, + SmallVector{DimType::Batch, DimType::K, DimType::N}, + SmallVector{DimType::Batch, DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::Batch, DimType::M, DimType::K}, + SmallVector{DimType::Batch, DimType::N, DimType::K}, + SmallVector{DimType::Batch, DimType::M, DimType::N}}; + } + return failure(); +} + +// The analysis to extract the matmul configuration from the given linalg op +struct MatmulConfigAnalysis { +public: + explicit MatmulConfigAnalysis(Operation *root); + MatmulConfig getConfig() { return config; } + +private: + MatmulConfig config; +}; + +} // namespace gc +} // namespace mlir + +#endif \ No newline at end of file diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 9d75ac2e9..d5330851b 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -11,12 +11,6 @@ include "mlir/Pass/PassBase.td" -def TileLinalgNamed : Pass<"tile-named-linalg", "func::FuncOp"> { - let summary = "Tile linalg named operations."; - let dependentDialects = - ["linalg::LinalgDialect", "scf::SCFDialect", "tensor::TensorDialect"]; -} - #ifdef GC_HAS_ONEDNN_DIALECT def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> { let summary = @@ -71,6 +65,18 @@ def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion", "Decide if enable cost model to control iterative fusion.">, ListOption<"defaultTileSize", "default-tile-size", "std::string", "Set default TileSize for the certain type of op, saying `matmul:{32,32}`">, + ]; +} +def DeepTileContractionNamedOp + : 
Pass<"deep-tile-contraction-named-op", "func::FuncOp"> { + let summary = "Tile linalg contraction named operation deeply"; + let description = + [{The pass tries to tile the linalg contraction named op deeply.}]; + let dependentDialects = [ + "func::FuncDialect", + "arith::ArithDialect", + "tensor::TensorDialect", + "linalg::LinalgDialect", ]; } @@ -87,4 +93,17 @@ def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> { ]; } +def SinkOpIntoInnerLoop : Pass<"sink-op-into-inner-loop"> { + let summary = "Sink operations into inner loops"; + let description = [{The pass tries to sink operations into inner loops as deep as possible to maximize the chance for outer loop optimization. + }]; + let dependentDialects = []; +} + +def MergeNestedForall : Pass<"merge-nested-forall"> { + let summary = "Merge nested scf.forall operations"; + let description = [{The pass tries to merge nested forall operations.}]; + let dependentDialects = ["scf::SCFDialect"]; +} + #endif // GC_DIALECT_GC_PASSES diff --git a/lib/gc/Analysis/CMakeLists.txt b/lib/gc/Analysis/CMakeLists.txt index c1c34ea50..d7160f350 100644 --- a/lib/gc/Analysis/CMakeLists.txt +++ b/lib/gc/Analysis/CMakeLists.txt @@ -4,6 +4,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS gc_add_mlir_library(GcAnalysis TargetDescriptionAnalysis.cpp + MatmulConfigAnalysis.cpp DEPENDS GraphCompilerPassIncGen @@ -12,4 +13,4 @@ gc_add_mlir_library(GcAnalysis ${mlir_dialect_libs} ${MLIR_LINK_COMPONENTS} GcInterface - ) \ No newline at end of file +) diff --git a/lib/gc/Analysis/MatmulConfigAnalysis.cpp b/lib/gc/Analysis/MatmulConfigAnalysis.cpp new file mode 100644 index 000000000..b31e0933e --- /dev/null +++ b/lib/gc/Analysis/MatmulConfigAnalysis.cpp @@ -0,0 +1,450 @@ +//===-- MatmulConfigAnalysis.cpp - Analysis for matmul config ---*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/Analysis/MatmulConfigAnalysis.h" +#include "gc/Analysis/TargetDescriptionAnalysis.h" +#include +#include + +namespace mlir { +namespace gc { + +#define DEBUG_TYPE "matmul-config-analysis" + +llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, + const MatmulConfig &config) { + + ss << "MThreads: " << config.MThreads << ", NThreads: " << config.NThreads + << ", KThreads: " << config.KThreads << ", MBlock: " << config.MBlock + << ", NBlock: " << config.NBlock << ", KBlock: " << config.KBlock + << ", innerMostMBlock: " << config.innerMostMBlock + << ", innerMostNBlock: " << config.innerMostNBlock + << ", innerMostKBlock: " << config.innerMostKBlock; + return ss; +} + +template +static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, + std::vector array) { + ss << "["; + llvm::interleaveComma(array, ss); + ss << "]"; + return ss; +} + +// generate the candidate for the block size(factor of `num`, pow of 2 which is +// less than `num`) +std::vector +getCandidate(uint32_t num, uint32_t floor, + uint32_t ceil = std::numeric_limits::max()) { + // factor + std::vector candidates; + uint32_t upperbound = std::min(num, ceil); + for (uint32_t i = floor; i <= upperbound; i++) + if (num % i == 0) + candidates.push_back(i); + + // the pow of 2 + uint32_t candidate = 1U; + while (candidate < floor) + candidate *= 2; + while (candidate <= upperbound) { + candidates.push_back(candidate); + candidate *= 2; + } + // remove duplicate candidates + std::sort(candidates.begin(), candidates.end()); + candidates.erase(std::unique(candidates.begin(), candidates.end()), + candidates.end()); + return candidates; +} + +// check if the threads are valid +bool validateThreads(ArrayRef threads, + CPUTargetDescriptionAnalysis &sysDesc) { + uint32_t numThreads = sysDesc.getNumThreads(); + uint32_t actualThreads = 1U; + for (uint32_t t : threads) + actualThreads *= t; + return actualThreads == numThreads; +} + +// calculate the cost of the hardware efficiency(whether the vector register is +// fully utilized) +double vectorRegEfficiencyCost(linalg::LinalgOp &linalgOp, + ArrayRef shape, + const MatmulConfig &config, + CPUTargetDescriptionAnalysis &sysDesc) { + size_t dtypeSize = DataLayout().getTypeSizeInBits( + ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); + size_t maxVectorWidth = sysDesc.getMaxVectorWidth() / dtypeSize; + // TODO: take matrix register like amx into account + double cost = (maxVectorWidth - config.innerMostMBlock % maxVectorWidth) % + maxVectorWidth * 1.0 / config.innerMostMBlock + + (maxVectorWidth - config.innerMostKBlock % maxVectorWidth) % + maxVectorWidth * 1.0 / config.innerMostKBlock + + (maxVectorWidth - config.innerMostNBlock % maxVectorWidth) % + maxVectorWidth * 1.0 / config.innerMostNBlock; + return cost; +} + +// calculate the cost of the workload balance +double workloadBalancedCost(linalg::LinalgOp &linalgOp, + ArrayRef shape, + const MatmulConfig &config, + CPUTargetDescriptionAnalysis &sysDesc) { + if (shape.size() < 3) { + // Has an invalid shape + return 0; + } + uint32_t M = shape[0], N = shape[1], K = shape[2]; + uint32_t MTaskNum = llvm::divideCeil(M, config.MBlock); + uint32_t NTaskNum = llvm::divideCeil(N, config.NBlock); + uint32_t KTaskNum = llvm::divideCeil(K, config.KBlock); + double cost = (MTaskNum % config.MThreads) * 1.0 / MTaskNum + + (NTaskNum % config.NThreads) * 1.0 / NTaskNum + + 
(KTaskNum % config.KThreads) * 1.0 / KTaskNum; + if (MTaskNum < config.MThreads || NTaskNum < config.NThreads || + KTaskNum < config.KThreads) { + double threadNotFulllyUtilizedPenalty = 10.0; + cost *= threadNotFulllyUtilizedPenalty; + } + return cost; +} + +// calculate the cost of the memory consumption on the thread +double memoryConsumptionOnThreadCost(linalg::LinalgOp &linalgOp, + ArrayRef shape, + const MatmulConfig &config, + CPUTargetDescriptionAnalysis &sysDesc) { + if (shape.size() < 3) { + // Has an invalid shape + return 0; + } + uint32_t M = shape[0], N = shape[1], K = shape[2]; + size_t dtypeSize = DataLayout().getTypeSize( + ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); + // if use K split, there will be one more final reduce and break the post + // fusion + double KSplitPenalty = 8.0 * dtypeSize; + double memoryConsumptionPerThread = + M * K * 1.0 / config.MThreads / config.KThreads + + K * N * 1.0 / config.KThreads / config.NThreads + + M * N * ((config.KThreads - 1) * KSplitPenalty + 1.0) / config.MThreads / + config.NThreads; + return memoryConsumptionPerThread; +} + +// calculate the cost of the computation intensity on the L2 cache +double computationIntensityOnL2Cache(linalg::LinalgOp &linalgOp, + ArrayRef shape, + const MatmulConfig &config, + CPUTargetDescriptionAnalysis &sysDesc) { + double fullLoadRatio = 0.7; + uint32_t L2Cache = sysDesc.getCacheSize(2); + size_t dtypeSize = DataLayout().getTypeSize( + ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); + uint32_t outOfCachePenalty = 1024; + double FLOPS = 2.0 * config.MBlock * config.NBlock * config.KBlock; + double memoryConsumption = config.MBlock * config.NBlock + + config.NBlock * config.KBlock + + config.MBlock * config.KBlock; + double computationIntensity = FLOPS / memoryConsumption; + if (memoryConsumption * dtypeSize > L2Cache * fullLoadRatio) + computationIntensity /= outOfCachePenalty; + return 1 / computationIntensity; +} + +using CostModelFn = std::function shape, MatmulConfig cfg, + CPUTargetDescriptionAnalysis &sysDesc)>; + +// filter the config by the cost model +std::vector +filterConfigByCostModel(ArrayRef configs, + linalg::LinalgOp &linalgOp, ArrayRef shape, + CPUTargetDescriptionAnalysis &sysDesc, + const CostModelFn &costModel, float preserveRatio = 0.5, + float threshold = -1) { + std::vector result; + std::vector costs; + std::vector idx; + for (auto &&[i, config] : llvm::enumerate(configs)) { + costs.push_back(costModel(linalgOp, shape, config, sysDesc)); + idx.push_back(i); + } + std::stable_sort(idx.begin(), idx.end(), [&costs](size_t i1, size_t i2) { + return costs[i1] < costs[i2]; + }); + double thresholdCost = costs[idx[(size_t)(preserveRatio * configs.size())]]; + thresholdCost = + threshold < thresholdCost && threshold > 0 ? 
threshold : thresholdCost; + for (const auto &i : idx) + if (costs[i] <= thresholdCost) + result.push_back(configs[i]); + + LLVM_DEBUG(llvm::dbgs() << "thresholdCost is: " << thresholdCost + << "\nbest with cost: " << costs[idx[0]] << "\n" + << configs[idx[0]] << "\n worst with cost: " + << costs[idx[configs.size() - 1]] << "\n" + << configs[idx[configs.size() - 1]] << "\n"); + if (result.empty()) + result = configs; + return result; +} + +// prepare the config candidates +std::vector +prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc, + ArrayRef shape, + ArrayRef givenInnermostBlock) { + if (shape.size() < 3) { + LLVM_DEBUG(llvm::dbgs() + << "The shape is invalid, no candidate is generated\n"); + return {}; + } + std::vector configs; + uint32_t threads = sysDesc.getNumThreads(); + std::vector MThreadsCandidates = + getCandidate((uint32_t)threads, 1U); + std::vector NThreadsCandidates = + getCandidate((uint32_t)threads, 1U); + std::vector KThreadsCandidates = + getCandidate((uint32_t)threads, 1U); + uint32_t noSmallBlockNeedThreshold = 8 * 8U; + std::vector MBlockCandidates = getCandidate( + (uint32_t)shape[0], shape[0] >= noSmallBlockNeedThreshold ? 8U : 1U, + (uint32_t)shape[0]); + std::vector NBlockCandidates = + getCandidate((uint32_t)shape[1], + shape[1] >= noSmallBlockNeedThreshold ? 8U : 1U, shape[1]); + std::vector KBlockCandidates = + getCandidate((uint32_t)shape[2], + shape[2] >= noSmallBlockNeedThreshold ? 8U : 1U, shape[2]); + std::vector innerMostMBlockCandidates = + givenInnermostBlock[0] != 0 && givenInnermostBlock.size() == 3 + ? std::vector{givenInnermostBlock[0]} + : getCandidate((uint32_t)shape[0], + shape[0] >= noSmallBlockNeedThreshold ? 8U : 1U, 256U); + std::vector innerMostNBlockCandidates = + givenInnermostBlock[1] != 0 && givenInnermostBlock.size() == 3 + ? std::vector{givenInnermostBlock[1]} + : getCandidate((uint32_t)shape[1], + shape[1] >= noSmallBlockNeedThreshold ? 8U : 1U, 256U); + std::vector innerMostKBlockCandidates = + givenInnermostBlock[2] != 0 && givenInnermostBlock.size() == 3 + ? std::vector{givenInnermostBlock[2]} + : getCandidate((uint32_t)shape[2], + shape[2] >= noSmallBlockNeedThreshold ? 8U : 1U, 256U); + + // TODO: improve via multi threading or add more constraints to restrict the + // candidate size + for (uint32_t MThreads : MThreadsCandidates) { + for (uint32_t NThreads : NThreadsCandidates) { + for (uint32_t KThreads : KThreadsCandidates) { + if (!validateThreads({MThreads, NThreads, KThreads}, sysDesc)) + continue; + for (uint32_t MBlock : MBlockCandidates) { + for (uint32_t innerMostMBlock : innerMostMBlockCandidates) { + if (MBlock % innerMostMBlock != 0 || + shape[0] % innerMostMBlock != 0) + continue; + for (uint32_t NBlock : NBlockCandidates) { + for (uint32_t innerMostNBlock : innerMostNBlockCandidates) { + if (NBlock % innerMostNBlock != 0 || + shape[1] % innerMostNBlock != 0) + continue; + for (uint32_t KBlock : KBlockCandidates) { + for (uint32_t innerMostKBlock : innerMostKBlockCandidates) { + if (KBlock % innerMostKBlock != 0 || + shape[2] % innerMostKBlock != 0) + continue; + MatmulConfig config{ + MThreads, NThreads, KThreads, + MBlock, NBlock, KBlock, + innerMostMBlock, innerMostNBlock, innerMostKBlock}; + configs.push_back(config); + } + } + } + } + } + } + } + } + } + LLVM_DEBUG( + llvm::dbgs() << "Finish generating candidates. 
ConfigCandidates size: " + << configs.size() << "\n"); + return configs; +} + +bool validateConfig(const MatmulConfig &cfg) { + if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 || + cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 || + cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 || + cfg.innerMostKBlock <= 0) + return false; + if (cfg.MBlock % cfg.innerMostMBlock != 0 || + cfg.NBlock % cfg.innerMostNBlock != 0 || + cfg.KBlock % cfg.innerMostKBlock != 0) + return false; + return true; +} + +// read the config from the attributes for tuning +bool readConfigFromAttrs(MatmulConfig &config, ArrayRef attrs) { + size_t cfgItemCnt = 0; + for (const auto &attr : attrs) { + if (attr.getName() == "KBlock") { + config.KBlock = cast(attr.getValue()).getInt(); + cfgItemCnt++; + } else if (attr.getName() == "KThreads") { + config.KThreads = cast(attr.getValue()).getInt(); + cfgItemCnt++; + } else if (attr.getName() == "NBlock") { + config.NBlock = cast(attr.getValue()).getInt(); + cfgItemCnt++; + } else if (attr.getName() == "NThreads") { + config.NThreads = cast(attr.getValue()).getInt(); + cfgItemCnt++; + } else if (attr.getName() == "MBlock") { + config.MBlock = cast(attr.getValue()).getInt(); + cfgItemCnt++; + } else if (attr.getName() == "MThreads") { + config.MThreads = cast(attr.getValue()).getInt(); + cfgItemCnt++; + } else if (attr.getName() == "innermostMBlock") { + config.innerMostMBlock = cast(attr.getValue()).getInt(); + cfgItemCnt++; + } else if (attr.getName() == "innermostNBlock") { + config.innerMostNBlock = cast(attr.getValue()).getInt(); + cfgItemCnt++; + } else if (attr.getName() == "innermostKBlock") { + config.innerMostKBlock = cast(attr.getValue()).getInt(); + cfgItemCnt++; + } + } + if (validateConfig(config)) { + return cfgItemCnt == 9; + } else { + LLVM_DEBUG(llvm::dbgs() << "The predefined config is invalid\n"); + return false; + } +} + +// Analyze the workload and system description to generate the default config +// Factor to consider: +// thread utilization +// computation intensity +// cache locality +// memory requirements +// computation unit efficiency +// padding/pack cost +// workload balance +// communication +// previous matmul +MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { + if (auto linalgOp = dyn_cast(root)) { + CPUTargetDescriptionAnalysis sysDesc(root); + SmallVector> oprandDimType = + *getOprandDimType(linalgOp); + // get the origin M,N,K size + SmallVector MDimTypeIdx = + extractDimTypeIdx(oprandDimType[0], DimType::M); + SmallVector KDimTypeIdx = + extractDimTypeIdx(oprandDimType[1], DimType::K); + SmallVector NDimTypeIdx = + extractDimTypeIdx(oprandDimType[1], DimType::N); + uint32_t M = 1U, N = 1U, K = 1U; + for (auto &&[s, dimType] : + llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(0)), + oprandDimType[0])) + if (dimType == DimType::M) + M *= s; + for (auto &&[s, dimType] : + llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(1)), + oprandDimType[1])) { + if (dimType == DimType::N) + N *= s; + else if (dimType == DimType::K) + K *= s; + } + + // innermost Block, if the layout is blockied layout, the innermost block + // will derived from the layout directly + uint32_t defaultBlock = 32; + config.innerMostMBlock = M % defaultBlock == 0 ? defaultBlock : M; + config.innerMostNBlock = N % defaultBlock == 0 ? defaultBlock : N; + config.innerMostKBlock = K % defaultBlock == 0 ? 
defaultBlock : K; + SmallVector givenInnermostBlock; + if (MDimTypeIdx.size() > 1) { + config.innerMostMBlock = 1; + for (auto &&[i, d] : llvm::enumerate(MDimTypeIdx)) + if (i != 0) + config.innerMostMBlock *= + linalgOp.getShape(linalgOp.getDpsInputOperand(0))[d]; + givenInnermostBlock.push_back(config.innerMostMBlock); + } else { + givenInnermostBlock.push_back(0); + } + if (NDimTypeIdx.size() > 1) { + config.innerMostNBlock = 1; + for (auto &&[i, d] : llvm::enumerate(NDimTypeIdx)) + if (i != 0) + config.innerMostNBlock *= + linalgOp.getShape(linalgOp.getDpsInputOperand(1))[d]; + givenInnermostBlock.push_back(config.innerMostNBlock); + } else { + givenInnermostBlock.push_back(0); + } + if (KDimTypeIdx.size() > 1) { + config.innerMostKBlock = 1; + for (auto &&[i, d] : llvm::enumerate(KDimTypeIdx)) + if (i != 0) + config.innerMostKBlock *= + linalgOp.getShape(linalgOp.getDpsInputOperand(1))[d]; + givenInnermostBlock.push_back(config.innerMostKBlock); + } else { + givenInnermostBlock.push_back(0); + } + + LLVM_DEBUG(llvm::dbgs() + << "M: " << M << ", N: " << N << ", K: " << K << "\n"); + + // try to read the config from the attributes + SmallVector attrs(linalgOp->getAttrs()); + bool hasPredefinedConfig = readConfigFromAttrs(config, attrs); + + // if there is a given config, skip the cost model + if (!hasPredefinedConfig) { + LLVM_DEBUG(llvm::dbgs() << "No predefined config\n"); + // TODO: Could add a weight or priority for cost model + SmallVector> costModelList = + {{workloadBalancedCost, "workloadBalancedCost", 1}, + {vectorRegEfficiencyCost, "vectorRegEfficiencyCost ", -1}, + {computationIntensityOnL2Cache, "computationIntensityOnL2Cache", -1}, + {memoryConsumptionOnThreadCost, "memoryConsumptionOnThreadCost", + -1}}; + SmallVector shape = {M, N, K}; + std::vector configCandidates = + prepareConfigCandidates(root, sysDesc, shape, givenInnermostBlock); + for (auto &&[fn, name, threshold] : costModelList) + configCandidates = filterConfigByCostModel( + configCandidates, linalgOp, shape, sysDesc, fn, 0.5, threshold); + if (!configCandidates.empty()) + config = configCandidates[0]; + } + + LLVM_DEBUG(llvm::dbgs() + << "Final config\nNumThreads: " << sysDesc.getNumThreads() + << ", MatmulConfig: " << config << "\n"); + } +} +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index d240f28c1..705e257d7 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -12,10 +12,13 @@ get_property(mlir_conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) gc_add_mlir_library(GcPasses OneDNNGraphToLinalg.cpp Pipeline.cpp - TileNamed.cpp IterativeTilingAndFusion.cpp TilingUsingInterfaceX.cpp VerifyTargetDescription.cpp + DeepTileContractionNamedOp.cpp + TilingUtil.cpp + SinkOpIntoInnerLoop.cpp + MergeNestedForall.cpp DEPENDS GraphCompilerPassIncGen diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp new file mode 100644 index 000000000..30d0e022f --- /dev/null +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -0,0 +1,1014 @@ +//===-- DeepTileContractionNamedOp.cpp - tile named op deeply ---*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "./TilingUtil.hpp" +#include "gc/Analysis/MatmulConfigAnalysis.h" +#include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "gc/Transforms/Passes.h" + +#include + +#include + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_DEEPTILECONTRACTIONNAMEDOP +#include "gc/Transforms/Passes.h.inc" + +namespace { + +// Util function to tensor view a ranked tensor to another ranked tensor without +// change the data layout +static Value +tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType, + Value value, + ArrayRef permutation = SmallVector{}) { + Value result, currentValue = value; + Location loc = currentValue.getLoc(); + RankedTensorType inTensorType = + cast(currentValue.getType()); + ArrayRef inShape = inTensorType.getShape(); + ArrayRef outShape = outTensorType.getShape(); + mlir::Type tensorElementType = inTensorType.getElementType(); + + // Check if the input and output tensor have the same shape + if (inShape == outShape) + return currentValue; + + if (outShape.size() < inShape.size()) { + SmallVector reassocIndices; + uint64_t outIdx = 0UL, inIdx = 0UL; + while (inIdx < inShape.size() && outIdx < outShape.size()) { + ReassociationIndices firstEntry; + int64_t remaining = outShape[outIdx++]; + if (remaining == 1) { + firstEntry.push_back(inIdx++); + reassocIndices.push_back(firstEntry); + continue; + } + while (remaining > 1) { + remaining /= inShape[inIdx]; + firstEntry.push_back(inIdx++); + } + reassocIndices.push_back(firstEntry); + } + result = rewriter.create( + loc, outTensorType, currentValue, reassocIndices); + } else if (outShape.size() > inShape.size()) { + SmallVector reassocIndices; + uint64_t outIdx = 0UL, inIdx = 0UL; + while (outIdx < outShape.size() && inIdx < inShape.size()) { + ReassociationIndices firstEntry; + int64_t remaining = inShape[inIdx++]; + if (remaining == 1) { + firstEntry.push_back(outIdx++); + reassocIndices.push_back(firstEntry); + continue; + } + while (remaining > 1) { + remaining /= outShape[outIdx]; + firstEntry.push_back(outIdx++); + } + reassocIndices.push_back(firstEntry); + } + result = rewriter.create( + loc, outTensorType, currentValue, reassocIndices); + } else { + result = rewriter.create(loc, outTensorType, currentValue); + } + + // Transpose the tensor if permutation is not empty + if (!permutation.empty()) { + SmallVector transposeShape; + for (int64_t idx : permutation) + transposeShape.push_back(outShape[idx]); + Operation *initOp = rewriter.create(loc, transposeShape, + tensorElementType); + Operation *transposeOp = rewriter.create( + loc, result, initOp->getResult(0), permutation); + result = transposeOp->getResult(0); + } + return result; +} + +// Check if the loop is dummy loop(has only one iteration) +bool isDummyLoop(LoopLikeOpInterface loop) { + std::optional tripCount = mlir::constantTripCount( + *loop.getSingleLowerBound(), *loop.getSingleUpperBound(), + *loop.getSingleStep()); + if (tripCount) + return *tripCount == 1; + return false; +} + +// Build the linalg region for a linalg op +static void buildLinalgRegion(Operation *op, bool createTemporaryOp = false) { + 
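+  // Build the payload block: one scalar argument (the operand's element
+  // type) per operand. For a temporary op the region simply yields the
+  // trailing init arguments; otherwise the stock linalg.matmul region
+  // builder is reused to emit the multiply-accumulate body.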
SmallVector argTypes; + SmallVector argLocs; + for (const Value &opOperand : op->getOperands()) { + argTypes.push_back(getElementTypeOrSelf(opOperand.getType())); + argLocs.push_back(opOperand.getLoc()); + } + size_t initSize = op->getResults().size(); + ImplicitLocOpBuilder b(op->getLoc(), op->getContext()); + Region ®ion = op->getRegion(0); + Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs); + b.setInsertionPointToStart(body); + if (createTemporaryOp) { + unsigned argNum = body->getNumArguments(); + SmallVector vals; + for (size_t i = initSize; i > 0; i--) + vals.push_back(body->getArgument(argNum - i)); + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToEnd(body); + Location loc = b.getUnknownLoc(); + b.create(loc, ValueRange(vals)); + } else { + linalg::LinalgDialect *dialect = + static_cast(op->getDialect()); + linalg::LinalgDialect::RegionBuilderFunType fun = + dialect->getRegionBuilder("linalg.matmul"); + fun(b, *body, op->getAttrs()); + } +} + +// Check if the linalgOp need to be legalized to f32 accumulation type +static bool needToLegalizeDtype(linalg::LinalgOp linalgOp) { + mlir::Type dataType = + dyn_cast(linalgOp.getDpsInputs()[0].getType()) + .getElementType(); + mlir::Type resultType = + dyn_cast(linalgOp.getDpsInits()[0].getType()) + .getElementType(); + return (dataType.isBF16() || dataType.isF16()) && dataType == resultType; +} + +struct DtypeLegalizeResult { + Operation *linalgOp = nullptr; + Operation *castOp = nullptr; +}; + +// Split a low precision matmul(bf16xbf16->bf16) to a combination +// matmul(bf16xbf16->f32) + cast(f32->bf16) +// if needFurtherFuse=true, a middle temporary linalgOp(bf16xbf16->(f32,bf16)) +// will be created +static FailureOr +matmulDtypeLegalize(RewriterBase &rewriter, Operation *op, + bool needCopyInit = true, bool needFurtherFuse = false) { + linalg::LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + + Location loc = linalgOp->getLoc(); + DtypeLegalizeResult result; + + if (needToLegalizeDtype(linalgOp)) { + rewriter.setInsertionPoint(linalgOp); + IRMapping mapping; + Operation *initOp = linalgOp.getDpsInits()[0].getDefiningOp(); + Value initValue = initOp->getResult(0); + ShapedType initType = cast(initValue.getType()); + ArrayRef tensorShape = initType.getShape(); + SmallVector mixedShape; + for (auto &&[i, t] : llvm::enumerate(tensorShape)) { + if (initType.isDynamicDim(i)) { + Value val = rewriter.create(loc, initValue, i); + mixedShape.push_back(val); + } else { + mixedShape.push_back(getAsIndexOpFoldResult(rewriter.getContext(), t)); + } + } + Operation *currentOp; + + currentOp = rewriter.create( + loc, mixedShape, Float32Type::get(op->getContext())); + if (needCopyInit) + currentOp = rewriter.create(loc, initOp->getResult(0), + currentOp->getResult(0)); + SmallVector newOperands = linalgOp->getOperands(); + Value oldInit = newOperands.back(); + newOperands.back() = currentOp->getResult(0); + + SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); + indexingMaps.push_back(indexingMaps.back()); + SmallVector attrs(linalgOp->getAttrs()); + SmallVector types = {currentOp->getResult(0).getType()}; + if (needFurtherFuse) { + NamedAttribute segmentSize = rewriter.getNamedAttr( + "operandSegmentSizes", rewriter.getDenseI32ArrayAttr({2, 2})); + for (auto &attr : attrs) { + if (attr.getName() == "indexing_maps") + attr.setValue(rewriter.getAffineMapArrayAttr(indexingMaps)); + if (attr.getName() == "operandSegmentSizes") + attr.setValue(segmentSize.getValue()); + } + 
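+      // Keep the original low-precision init as an extra output so the
+      // temporary op yields both the f32 partial result and the original
+      // destination; this is what allows the trailing cast to be fused
+      // into the inner loop later.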
types.push_back(oldInit.getType()); + newOperands.push_back(oldInit); + } + OperationState state(loc, linalgOp->getName(), newOperands, types, attrs); + state.addRegion(); + currentOp = rewriter.create(state); + buildLinalgRegion(currentOp, needFurtherFuse); + linalg::CopyOp castOp = rewriter.create( + loc, currentOp->getResult(0), initOp->getResult(0)); + result.linalgOp = currentOp; + result.castOp = castOp; + } + + return result; +} + +// Find the parent fill op of a value and will penetrate pack/pad ops +static Operation *findParentFillOp(Value val) { + SmallVector skipOpList = {"tensor.pack", "tensor.pad"}; + Operation *currentOp = val.getDefiningOp(); + while (currentOp && + llvm::find(skipOpList, currentOp->getName().getStringRef()) != + skipOpList.end() && + !isa(currentOp)) { + currentOp = currentOp->getOperand(0).getDefiningOp(); + } + if (currentOp && isa(currentOp)) + return currentOp; + return nullptr; +} + +// Get the parallel dims of a matmul op +static void getMatmulParallelDims(linalg::LinalgOp linalgOp, + unsigned operandIdx, + SmallVectorImpl &dims) { + AffineMap map = + linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(operandIdx)); + SmallVector iteratorTypes = + linalgOp.getIteratorTypesArray(); + + ArrayRef results = map.getResults(); + for (const AffineExpr &dim : results) { + AffineDimExpr dimExpr = dyn_cast(dim); + if (dimExpr && iteratorTypes[dimExpr.getPosition()] == + mlir::utils::IteratorType::parallel) + dims.push_back(dimExpr.getPosition()); + } +} + +// set the dynamic size to static size for ExtractSliceOp according to the tile +// config +static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter, + Operation *op, bool isExtract, + SmallVector size, + int shrinDimNum = 0) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + if (auto extractSlice = dyn_cast(op)) { + SmallVector mixedOffsets = extractSlice.getMixedOffsets(); + SmallVector mixedSizes = extractSlice.getMixedSizes(); + SmallVector mixedStrides = extractSlice.getMixedStrides(); + auto targetTensor = mlir::RankedTensorType::get( + SmallVector(size.begin() + shrinDimNum, size.end()), + extractSlice.getResult().getType().getElementType()); + for (auto &&[i, s] : llvm::enumerate(size)) + mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), s); + Operation *newExtractSliceOp = rewriter.create( + extractSlice->getLoc(), extractSlice.getSource(), mixedOffsets, + mixedSizes, mixedStrides); + if (shrinDimNum > 0) { + rewriter.setInsertionPointAfter(newExtractSliceOp); + Value viewResult = tensorViewRankedTensor( + rewriter, targetTensor, newExtractSliceOp->getResult(0)); + rewriter.replaceOp(extractSlice, viewResult); + } else { + rewriter.replaceOp(extractSlice, newExtractSliceOp); + } + } +} + +// set the dynamic size to static size for InsertSliceOp according to the tile +// config +static void setStaticSizeForInsertSliceOp(RewriterBase &rewriter, Operation *op, + Value source, + SmallVector size) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + if (auto insertSlice = dyn_cast(op)) { + SmallVector mixedOffsets = insertSlice.getMixedOffsets(); + SmallVector mixedSizes = insertSlice.getMixedSizes(); + SmallVector mixedStrides = insertSlice.getMixedStrides(); + for (auto &&[i, s] : llvm::enumerate(size)) + mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), s); + auto targetTensor = mlir::RankedTensorType::get( + size, insertSlice.getDest().getType().getElementType()); + Value viewResult = 
tensorViewRankedTensor(rewriter, targetTensor, source); + rewriter.replaceOpWithNewOp( + insertSlice, viewResult, insertSlice.getDest(), mixedOffsets, + mixedSizes, mixedStrides); + } +} + +using InnermostFullResultCallBackFn = std::function( + RewriterBase &rewriter, Location loc, linalg::LinalgOp linalgop)>; + +using FinalReduceCallBackFn = std::function( + RewriterBase &rewriter, Location loc, + linalg::ForallReductionTilingResult result)>; + +struct OuterLoopGenerationOption { + enum LoopType { ForOp, ForallOp }; + SmallVector> nestedTileSizes; + SmallVector loopType; + SmallVector> loopDim; + SmallVector innermostFullResultCallBacks; + SmallVector finalReduceCallBacks; + bool isPartialResult = false; +}; + +struct OuterLoopGenerationResult { + /// Tiled operations that are generated during tiling. The order does not + /// matter except the last op. The replacements are expected to be the results + /// of the last op. + SmallVector tiledOps; + /// The `scf.for` operations that iterate over the tiles. + SmallVector loops; + SmallVector reductionLoops; +}; + +// Generate outer loop for a linalg op +static FailureOr +generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, + const OuterLoopGenerationOption &option) { + OuterLoopGenerationResult result; + SmallVector> nestedTileSizes = option.nestedTileSizes; + SmallVector loopType = option.loopType; + SmallVector> loopDim = option.loopDim; + SmallVector iteratorTypes = + linalgOp.getIteratorTypesArray(); + + if (loopType.size() != loopDim.size() || + loopDim.size() != nestedTileSizes.size()) + return b.notifyMatchFailure( + linalgOp, + "loopType, loopDim and nestedTileSizes should have the same size"); + + if (linalgOp.hasPureBufferSemantics()) + return b.notifyMatchFailure( + linalgOp, "currentOp should not has pure buffer semantics"); + linalg::LinalgOp currentOp = linalgOp; + + bool hasFullResult = !option.isPartialResult; + for (auto &&[i, loopType] : llvm::enumerate(loopType)) { + ArrayRef currentDim = loopDim[i]; + ArrayRef currentTileSize = nestedTileSizes[i]; + if (loopType == OuterLoopGenerationOption::LoopType::ForOp) { + for (auto &&[d, tile] : llvm::zip(currentDim, currentTileSize)) { + scf::SCFTilingOptions tileOption; + SmallVector TileSizes( + currentOp.getNumLoops(), getAsIndexOpFoldResult(b.getContext(), 0)); + TileSizes[d] = getAsIndexOpFoldResult(b.getContext(), tile); + tileOption.setTileSizes(TileSizes); + tileOption.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(currentOp); + if (iteratorTypes[d] == mlir::utils::IteratorType::reduction && + tile != 0 && hasFullResult) { + for (const auto &fn : option.innermostFullResultCallBacks) { + FailureOr result = + fn(b, currentOp->getLoc(), currentOp); + if (succeeded(result)) + currentOp = *result; + } + hasFullResult = false; + } + FailureOr tilingResult = scf::tileUsingSCF( + b, cast(currentOp.getOperation()), tileOption); + if (failed(tilingResult)) + return failure(); + + if (!isDummyLoop(tilingResult->loops.back())) { + b.replaceOp(currentOp, tilingResult->replacements); + currentOp = dyn_cast(tilingResult->tiledOps.back()); + if (iteratorTypes[d] == mlir::utils::IteratorType::reduction) + result.reductionLoops.push_back(tilingResult->loops.back()); + result.loops.push_back(tilingResult->loops.back()); + } + } + } else if (loopType == OuterLoopGenerationOption::LoopType::ForallOp) { + SmallVector tileSizes( + currentOp.getNumLoops(), getAsIndexOpFoldResult(b.getContext(), 0)); + SmallVector threads( + 
currentOp.getNumLoops(), getAsIndexOpFoldResult(b.getContext(), 0)); + SmallVector reductionDims; + SmallVector loopRanges = + cast(currentOp.getOperation()).getIterationDomain(b); + currentOp.getReductionDims(reductionDims); + bool tileOnReduction = false; + for (auto &&[d, tile] : llvm::zip(currentDim, currentTileSize)) { + if (llvm::find(reductionDims, d) != reductionDims.end() && tile != 0 && + (!getConstantIntValue(loopRanges[d].size) || + tile != + static_cast(*getConstantIntValue(loopRanges[d].size)))) + tileOnReduction = true; + if (llvm::find(reductionDims, d) != reductionDims.end() && + !dyn_cast(currentOp.getOperation())) { + tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), 0); + tileOnReduction = false; + } else { + tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), tile); + } + } + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(currentOp); + if (tileOnReduction) { + for (auto &&[idx, tile] : llvm::enumerate(tileSizes)) + if (isConstantIntValue(tile, 0) && + llvm::find(reductionDims, idx) != reductionDims.end()) + tileSizes[idx] = loopRanges[idx].size; + SmallVector newParallelDims; + for (auto iter : llvm::enumerate(reductionDims)) + newParallelDims.push_back( + getAsIndexOpFoldResult(b.getContext(), iter.index())); + FailureOr tilingResult = + linalgX::tileReductionUsingForall( + b, cast(currentOp.getOperation()), + {}, tileSizes, newParallelDims, std::nullopt); + if (failed(tilingResult) && + llvm::hasSingleElement(tilingResult->parallelTiledOps)) + return failure(); + currentOp = + dyn_cast(tilingResult->parallelTiledOps.back()); + if (!tilingResult->mergeOps.empty()) { + for (const auto &fn : option.finalReduceCallBacks) { + FailureOr result = + fn(b, currentOp->getLoc(), *tilingResult); + if (succeeded(result)) + currentOp = *result; + } + } + } else { + scf::SCFTilingOptions tileOption; + tileOption.setTileSizes(tileSizes); + tileOption.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); + FailureOr tilingResult = scf::tileUsingSCF( + b, cast(currentOp.getOperation()), tileOption); + if (failed(tilingResult)) + return failure(); + b.replaceOp(currentOp, tilingResult->replacements); + currentOp = dyn_cast(tilingResult->tiledOps.back()); + } + } + } + result.tiledOps.emplace_back(currentOp); + return result; +} + +// Turn a OpFoldResult into a Value +static Value turnOpFoldResultIntoValue(RewriterBase &rewriter, Location loc, + OpFoldResult result) { + if (auto value = dyn_cast(result)) + return value; + if (auto attr = dyn_cast(result)) { + if (auto val = dyn_cast(attr)) { + if (val.getType().isIndex()) + return rewriter.create(loc, val.getInt()); + else + return rewriter.create(loc, val.getInt(), + val.getType()); + } + } + return Value(); +} + +/* +matmul(A, B) -> C +----------------> +forall([PM, PN, PK]: [MThreads, NThreads, KThreads]) { + CSlice = [KThreads, PM * MOuterBlock: (PM + 1) * MOuterBlock, + PN * NOuterBlock: (PN + 1) * NOuterBlock] + ASlice = A[PM * MOuterBlock: (PM + 1) * MOuterBlock, PK * KOuterBlock * (PK ++ 1) * KOuterBlock] + BSlice = B[PK * KOuterBlock * (PK + 1) * KOuterBlock, PN * +NOuterBlock: (PN + 1) * NOuterBlock] CSlice2 = CSlice[PK, PM * MOuterBlock: (PM ++ 1) * MOuterBlock, PN * NOuterBlock: (PN + 1) * NOuterBlock] + MNumBlock = MOuterBlock / MBlock + NNumBlock = NOuterBlock / NBlock + KNumBlock = KOuterBlock / KBlovk + for([om, on, ok]: [MNumBlock, NNumBlock, KNumBlock]) { + ASlice2 = ASlice[om * MBlock: (om + 1) * MBlock, ok * KBlock: (ok + 1) * +KBlock] + BSlice2 = BSlice[0, ok * KBlock: (ok + 1) * KBlock, on * 
NBlock: (on + +1) * NBlock] + CSlice3 = CSlice2[0, om * MBlock: (om + 1) * MBlock, on * NBlock: +(on + 1) * NBlock] (init with 0 when ok == 0) + MNumInnerBlock = MBlock / iim_block_ + ... + for([im, in]: [MNumInnerBlock, NNumInnerBlock]) { + ASlice3 = ASlice2[im * iim_block_: (im + 1) * iim_block_, :] + BSlice3 = BSlice2[0, im * iim_block_: (im + 1) * iim_block_, :] + CSlice4 = CSlice3[0, im * iim_block_: (im + 1) * iim_block_, in * +iin_block_: (in + 1) * iin_block_] (init with 0 when ok == 0) + brgemm(bs=KNumInnerBlock, M=iim_block_, N=iin_block_, K=iik_block, +A=ASlice3, B=BSlice3, C=CSlice4, onlyUpdate=(ok!=0)); + } + } +} +C = final_reduce(CSlice) +*/ +struct DeepTileMatmul : public OpInterfaceRewritePattern { + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + + static FailureOr + outerLoopGeneration(RewriterBase &rewriter, linalg::LinalgOp linalgOp, + gc::MatmulConfig cfg, bool hasFillOp) { + SmallVector KDimPos, MDimPos, NDimPos; + linalgOp.getReductionDims(KDimPos); + getMatmulParallelDims(linalgOp, 0, MDimPos); + getMatmulParallelDims(linalgOp, 1, NDimPos); + OuterLoopGenerationOption option; + + SmallVector iteratorTypes = + linalgOp.getIteratorTypesArray(); + SmallVector loopRange = + cast(linalgOp.getOperation()) + .getIterationDomain(rewriter); + size_t KFirstDim = *getConstantIntValue(loopRange[KDimPos[0]].size); + size_t MFirstDim = *getConstantIntValue(loopRange[MDimPos[0]].size); + size_t NFirstDim = *getConstantIntValue(loopRange[NDimPos[0]].size); + + size_t KParallelBlockSize = + cfg.KThreads == 1 + ? 0 + : (KDimPos.size() > 1 + ? llvm::divideCeil(KFirstDim, cfg.KThreads) + : llvm::divideCeil(llvm::divideCeil(KFirstDim, cfg.KBlock), + cfg.KThreads) * + cfg.KBlock); + size_t MParallelBlockSize = + MDimPos.size() > 1 + ? llvm::divideCeil(MFirstDim, cfg.MThreads) + : llvm::divideCeil(llvm::divideCeil(MFirstDim, cfg.MBlock), + cfg.MThreads) * + cfg.MBlock; + size_t NParallelBlockSize = + NDimPos.size() > 1 + ? llvm::divideCeil(NFirstDim, cfg.NThreads) + : llvm::divideCeil(llvm::divideCeil(NFirstDim, cfg.NBlock), + cfg.NThreads) * + cfg.NBlock; + size_t KOuterBlockSize = KDimPos.size() > 1 + ? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1 + : cfg.KBlock; + size_t MOuterBlockSize = MDimPos.size() > 1 + ? (cfg.MBlock - 1) / cfg.innerMostMBlock + 1 + : cfg.MBlock; + size_t NOuterBlockSize = NDimPos.size() > 1 + ? 
(cfg.NBlock - 1) / cfg.innerMostNBlock + 1 + : cfg.NBlock; + + // Outer loop tile size + for (auto &&[tile, dim] : + llvm::zip(SmallVector{KParallelBlockSize, MParallelBlockSize, + NParallelBlockSize}, + SmallVector{KDimPos[0], MDimPos[0], NDimPos[0]})) { + option.nestedTileSizes.emplace_back(SmallVector{tile}); + option.loopType.emplace_back( + OuterLoopGenerationOption::LoopType::ForallOp); + option.loopDim.emplace_back(SmallVector{dim}); + } + + // Middle loop tile size + for (auto &&[tile, dim] : + llvm::zip(SmallVector{MOuterBlockSize, NOuterBlockSize, + KOuterBlockSize}, + SmallVector{MDimPos[0], NDimPos[0], KDimPos[0]})) { + option.nestedTileSizes.emplace_back(SmallVector{tile}); + option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); + option.loopDim.emplace_back(SmallVector{dim}); + } + if (llvm::hasSingleElement(KDimPos)) { + option.nestedTileSizes.emplace_back(SmallVector{cfg.KBlock}); + option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); + option.loopDim.emplace_back(SmallVector{KDimPos.back()}); + } + // Inner loop tile size + if (llvm::hasSingleElement(MDimPos)) { + option.nestedTileSizes.emplace_back( + SmallVector{cfg.innerMostMBlock}); + option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); + option.loopDim.emplace_back(SmallVector{MDimPos.back()}); + } + if (llvm::hasSingleElement(NDimPos)) { + option.nestedTileSizes.emplace_back( + SmallVector{cfg.innerMostNBlock}); + option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); + option.loopDim.emplace_back(SmallVector{NDimPos.back()}); + } + for (size_t dim = 0UL; dim < linalgOp.getNumLoops(); ++dim) { + if (dim != MDimPos.back() && dim != NDimPos.back() && + iteratorTypes[dim] != mlir::utils::IteratorType::reduction) { + option.nestedTileSizes.emplace_back(SmallVector{1}); + option.loopType.emplace_back( + OuterLoopGenerationOption::LoopType::ForOp); + option.loopDim.emplace_back(SmallVector{dim}); + } + } + + // cast the low precision matmul to f32 when partial accumulation(result not + // full) is needed + auto lowPrecisionCast = + [&](RewriterBase &rewriter, Location loc, + linalg::LinalgOp linalgop) -> FailureOr { + FailureOr legalizedResult = matmulDtypeLegalize( + rewriter, linalgop.getOperation(), !hasFillOp, true); + if (succeeded(legalizedResult) && legalizedResult->castOp && + legalizedResult->linalgOp) { + Operation *linalgOp = legalizedResult->linalgOp; + rewriter.replaceOp(linalgop, + linalgOp->getResult(linalgOp->getNumResults() - 1)); + return dyn_cast(linalgOp); + } + return failure(); + }; + option.innermostFullResultCallBacks.push_back(lowPrecisionCast); + + if (hasFillOp) { + auto removeReduncantFill = + [&](RewriterBase &rewriter, Location loc, + const linalg::ForallReductionTilingResult &result) + -> FailureOr { + ArrayRef initValue = result.initialValues; + if (llvm::hasSingleElement(initValue) && + isa(initValue[0].getDefiningOp())) + rewriter.replaceOp(initValue[0].getDefiningOp(), + dyn_cast( + initValue[0].getDefiningOp()) + .getDpsInits()[0]); + return dyn_cast(result.parallelTiledOps.back()); + }; + option.finalReduceCallBacks.push_back(removeReduncantFill); + } + + return generateOuterLoop(rewriter, linalgOp, option); + } + + struct innerBodyGenerationOption { + Operation *fillOp; + bool needLowPrecisionCast; + SmallVector KLoopHandles; + }; + + LogicalResult innerBodyGeneration(RewriterBase &rewriter, + linalg::LinalgOp originOp, + linalg::LinalgOp currentOp, + innerBodyGenerationOption &option) const { + Location 
loc = currentOp->getLoc(); + FailureOr>> operandDimTypes = + getOprandDimType(originOp); + MatmulConfig cfg = + MatmulConfigAnalysis(originOp.getOperation()).getConfig(); + ArrayRef AShape = + originOp.getShape(originOp.getDpsInputOperand(0)); + ArrayRef BShape = + originOp.getShape(originOp.getDpsInputOperand(1)); + ArrayRef CShape = originOp.getShape(originOp.getDpsInitOperand(0)); + + if (failed(operandDimTypes)) + return failure(); + + size_t MDimNum = std::count_if((*operandDimTypes)[0].begin(), + (*operandDimTypes)[0].end(), + [](DimType d) { return d == DimType::M; }); + size_t NDimNum = std::count_if((*operandDimTypes)[1].begin(), + (*operandDimTypes)[1].end(), + [](DimType d) { return d == DimType::N; }); + // TODO: support plain in/block out format + // Calculate the innermost block size according to the config + SmallVector AInnermostDims, BInnermostDims, CInnermostDims; + bool firstM = true, firstK = true, firstN = true; + if (MDimNum > 1) { + for (auto &&[idx, iter] : llvm::enumerate((*operandDimTypes)[0])) { + if (iter == DimType::M && firstM) { + AInnermostDims.push_back(1); + firstM = false; + } else if (iter == DimType::Batch) { + AInnermostDims.push_back(1); + } else if (iter == DimType::K && firstK) { + AInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock); + firstK = false; + } else { + AInnermostDims.push_back(AShape[idx]); + } + } + firstM = true; + firstN = true; + for (auto &&[idx, iter] : llvm::enumerate((*operandDimTypes)[2])) { + if (iter == DimType::M && firstM) { + CInnermostDims.push_back(1); + firstM = false; + } else if (iter == DimType::Batch) { + CInnermostDims.push_back(1); + } else if (iter == DimType::N && firstN) { + CInnermostDims.push_back(1); + firstN = false; + } else { + CInnermostDims.push_back(CShape[idx]); + } + } + } else { + AInnermostDims = SmallVector{cfg.innerMostMBlock, + cfg.KBlock / cfg.innerMostKBlock * + cfg.innerMostKBlock}; + CInnermostDims = + SmallVector{cfg.innerMostMBlock, cfg.innerMostNBlock}; + } + + if (NDimNum > 1) { + firstN = true; + firstK = true; + for (auto &&[idx, iter] : llvm::enumerate((*operandDimTypes)[1])) { + if (iter == DimType::N && firstN) { + BInnermostDims.push_back(1); + firstN = false; + } else if (iter == DimType::Batch) { + BInnermostDims.push_back(1); + } else if (iter == DimType::K && firstK) { + BInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock); + firstK = false; + } else { + BInnermostDims.push_back(BShape[idx]); + } + } + } else { + BInnermostDims = SmallVector{cfg.KBlock / cfg.innerMostKBlock * + cfg.innerMostKBlock, + cfg.innerMostNBlock}; + } + + // Get the data/wei/dst data type + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(currentOp); + mlir::Type dataType = + dyn_cast(currentOp.getDpsInputs()[0].getType()) + .getElementType(); + mlir::Type weightType = + dyn_cast(currentOp.getDpsInputs()[1].getType()) + .getElementType(); + mlir::Type resultType = + dyn_cast(currentOp.getDpsInits()[0].getType()) + .getElementType(); + + // update the extractSlice to static size, replace it with + // useBlockedLayout when + setStaticSizeForExtractSliceOp(rewriter, + currentOp.getDpsInputs()[1].getDefiningOp(), + true, BInnermostDims, NDimNum > 1); + setStaticSizeForExtractSliceOp(rewriter, + currentOp.getDpsInputs()[0].getDefiningOp(), + true, AInnermostDims, MDimNum > 1); + for (const Value &init : currentOp.getDpsInits()) { + setStaticSizeForExtractSliceOp(rewriter, init.getDefiningOp(), true, + CInnermostDims, MDimNum > 1 ? 
2 : 0); + } + + // View the tensor to brgemm required format + Value dataOprand = tensorViewRankedTensor( + rewriter, + mlir::RankedTensorType::get( + MDimNum > 1 ? SmallVector(AInnermostDims.begin() + 1, + AInnermostDims.end()) + : SmallVector{cfg.innerMostMBlock, + cfg.KBlock / cfg.innerMostKBlock, + cfg.innerMostKBlock}, + dataType), + currentOp.getDpsInputs()[0], + MDimNum == 1 ? SmallVector{1, 0, 2} : SmallVector{}); + Value weightOprand = tensorViewRankedTensor( + rewriter, + mlir::RankedTensorType::get( + NDimNum > 1 ? SmallVector(BInnermostDims.begin() + 1, + BInnermostDims.end()) + : SmallVector{cfg.KBlock / cfg.innerMostKBlock, + cfg.innerMostKBlock, + cfg.innerMostNBlock}, + weightType), + currentOp.getDpsInputs()[1]); + Value resultOprand = tensorViewRankedTensor( + rewriter, + mlir::RankedTensorType::get( + SmallVector(CInnermostDims.begin() + (MDimNum > 1 ? 2 : 0), + CInnermostDims.end()), + resultType), + currentOp.getDpsInits()[0]); + // Create the brgemm op and replace the origin linalg op + linalg::LinalgOp matmul; + if (dyn_cast(weightOprand.getType()).getShape().size() == + 3) + matmul = rewriter.create( + loc, resultOprand.getType(), ValueRange{dataOprand, weightOprand}, + resultOprand); + else + matmul = rewriter.create( + loc, resultOprand.getType(), ValueRange{dataOprand, weightOprand}, + resultOprand); + Value result = matmul.getOperation()->getResult(0); + + // Insert the result back to the original tensor + for (Operation *user : currentOp->getResult(0).getUsers()) + setStaticSizeForInsertSliceOp(rewriter, user, result, CInnermostDims); + + if (option.needLowPrecisionCast) { + // fuse the low precision cast to the innermost body + rewriter.setInsertionPointAfter(currentOp); + Value cond; + for (LoopLikeOpInterface &loop : option.KLoopHandles) { + Value induceVar = turnOpFoldResultIntoValue( + rewriter, loc, *loop.getSingleInductionVar()); + Value upBound = turnOpFoldResultIntoValue(rewriter, loc, + *loop.getSingleUpperBound()); + Value step = + turnOpFoldResultIntoValue(rewriter, loc, *loop.getSingleStep()); + Value currentCond = + rewriter.create(loc, induceVar, step); + currentCond = rewriter.create( + loc, arith::CmpIPredicate::sge, currentCond, upBound); + cond = cond ? rewriter.create(loc, cond, currentCond) + : currentCond; + } + scf::IfOp ifOp = rewriter.create( + loc, TypeRange{currentOp.getDpsInits().back().getType()}, + cond ? 
cond : rewriter.create(loc, true, 1), + true); + { + OpBuilder::InsertionGuard guard(rewriter); + Region ®ion = ifOp.getThenRegion(); + rewriter.setInsertionPointToStart(®ion.back()); + linalg::CopyOp castOp = rewriter.create( + loc, matmul->getResult(0), currentOp.getDpsInits().back()); + rewriter.create(loc, castOp->getResult(0)); + } + { + OpBuilder::InsertionGuard guard(rewriter); + Region ®ion = ifOp.getElseRegion(); + rewriter.setInsertionPointToStart(®ion.back()); + rewriter.create(loc, currentOp.getDpsInits().back()); + } + // set static size for the insertSliceOp of copyOp + for (Operation *user : currentOp->getResult(1).getUsers()) + setStaticSizeForInsertSliceOp(rewriter, user, ifOp->getResult(0), + CInnermostDims); + rewriter.replaceOp(currentOp, {matmul->getResult(0), ifOp->getResult(0)}); + } else { + rewriter.replaceOp(currentOp, matmul->getResult(0)); + } + currentOp = matmul; + + // Fuse the fill op to the innermost body + if (auto fillOp = llvm::dyn_cast_or_null(option.fillOp)) { + Value fillValue = fillOp.getDpsInputs()[0]; + if (cfg.KThreads <= 1) + // if use k slicing, the fill op is still need to be kept for the reduce + // init + rewriter.replaceUsesWithIf(fillOp.getResult(0), fillOp.getDpsInits()[0], + [&](OpOperand &operand) { + return isa( + operand.getOwner()); + }); + + rewriter.setInsertionPointAfter(currentOp); + Value cond; + arith::ConstantIndexOp zeroConst = + rewriter.create(loc, 0); + for (LoopLikeOpInterface &loop : option.KLoopHandles) { + Value induceVar = loop.getLoopRegions().front()->front().getArgument(0); + Value currentCond = rewriter.create( + loc, arith::CmpIPredicate::eq, induceVar, zeroConst); + cond = cond ? rewriter.create(loc, cond, currentCond) + : currentCond; + } + scf::IfOp ifOp = rewriter.create( + loc, TypeRange{currentOp.getDpsInits()[0].getType()}, + cond ? cond : rewriter.create(loc, true, 1), + true); + { + OpBuilder::InsertionGuard guard(rewriter); + Region ®ion = ifOp.getThenRegion(); + rewriter.setInsertionPointToStart(®ion.back()); + linalg::FillOp fillOp = rewriter.create( + loc, fillValue, currentOp.getDpsInits()[0]); + IRMapping mapping; + mapping.map(currentOp.getDpsInits()[0], fillOp.getResult(0)); + Operation *res = rewriter.clone(*(currentOp.getOperation()), mapping); + rewriter.create(loc, res->getResult(0)); + } + { + OpBuilder::InsertionGuard guard(rewriter); + Region ®ion = ifOp.getElseRegion(); + rewriter.setInsertionPointToStart(®ion.back()); + Operation *res = rewriter.clone(*(currentOp.getOperation())); + rewriter.create(loc, res->getResult(0)); + } + rewriter.replaceOp(currentOp, ifOp); + } + return success(); + } + + bool checkLinalgMatmulType(linalg::LinalgOp linalgOp) const { + return llvm::isa(linalgOp) || + llvm::isa(linalgOp) || + llvm::isa(linalgOp) || + llvm::isa(linalgOp) || + llvm::isa(linalgOp); + } + + LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, + PatternRewriter &rewriter) const override { + if (!checkLinalgMatmulType(linalgOp)) + return failure(); + if (linalgOp.hasPureBufferSemantics()) + return failure(); + + if (linalgOp.getOperation()->getParentOfType() || + !linalgOp || linalgOp.getNumDpsInputs() != 2) + return failure(); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(linalgOp); + linalg::LinalgOp originOp = + dyn_cast(*rewriter.clone(*(linalgOp.getOperation()))); + Operation *fillOp = findParentFillOp(linalgOp.getDpsInits()[0]); + + // Step 1. 
Split matmul(bf16xbf16->bf16) to matmul(bf16xbf16->f32) + + // cast(f32->bf16) if K slicing is needed + MatmulConfig cfg = + MatmulConfigAnalysis(originOp.getOperation()).getConfig(); + linalgOp = *linalg::generalizeNamedOp(rewriter, linalgOp); + bool needLowPrecisionCast = needToLegalizeDtype(linalgOp); + if (cfg.KThreads > 1) { + FailureOr result = + matmulDtypeLegalize(rewriter, linalgOp.getOperation()); + if (succeeded(result) && result->castOp && result->linalgOp) { + rewriter.replaceOp(linalgOp, result->castOp); + linalgOp = dyn_cast(result->linalgOp); + } else { + return failure(); + } + needLowPrecisionCast = false; + } + + // Step 2. Outer loop generation + FailureOr outerLoopResult = outerLoopGeneration( + rewriter, linalgOp, cfg, fillOp && isa(fillOp)); + if (failed(outerLoopResult)) + return failure(); + linalgOp = dyn_cast(outerLoopResult->tiledOps.back()); + + // Step 3 generate inner loop body, convert the linalg.generic to brgemm + innerBodyGenerationOption option = innerBodyGenerationOption{ + fillOp, needLowPrecisionCast, outerLoopResult->reductionLoops}; + + if (failed(innerBodyGeneration(rewriter, originOp, linalgOp, option))) + return failure(); + rewriter.eraseOp(originOp); + return success(); + } +}; + +struct DeepTileContractionNamedOp + : public impl::DeepTileContractionNamedOpBase { +public: + void runOnOperation() final { + MLIRContext &ctx = getContext(); + IRRewriter rewriter(&ctx); + RewritePatternSet patterns(&ctx); + + patterns.add(patterns.getContext()); + linalg::populateLinalgTilingCanonicalizationPatterns(patterns); + linalg::ControlDropUnitDims options; + options.rankReductionStrategy = + linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice; + linalg::populateFoldUnitExtentDimsPatterns(patterns, options); + tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); + + for (Dialect *dialect : ctx.getLoadedDialects()) + dialect->getCanonicalizationPatterns(patterns); + for (RegisteredOperationName op : ctx.getRegisteredOperations()) + op.getCanonicalizationPatterns(patterns, &ctx); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/MergeNestedForall.cpp b/lib/gc/Transforms/MergeNestedForall.cpp new file mode 100644 index 000000000..07eb5ffbf --- /dev/null +++ b/lib/gc/Transforms/MergeNestedForall.cpp @@ -0,0 +1,93 @@ +//===-- MergeNestedForall.cpp - Merge nested scf.forall op ------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_MERGENESTEDFORALL +#include "gc/Transforms/Passes.h.inc" + +namespace { + +struct MergeNestedForallLoops : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForallOp op, + PatternRewriter &rewriter) const override { + Block &outerBody = *op.getBody(); + if (!llvm::hasSingleElement(outerBody.without_terminator())) + return failure(); + + scf::ForallOp innerOp = dyn_cast(outerBody.front()); + if (!innerOp) + return failure(); + + for (auto val : outerBody.getArguments()) + if (llvm::is_contained(innerOp.getDynamicLowerBound(), val) || + llvm::is_contained(innerOp.getDynamicUpperBound(), val) || + llvm::is_contained(innerOp.getDynamicStep(), val)) + return failure(); + + // Reductions are not supported yet. + if (!op.getInits().empty() || !innerOp.getInits().empty()) + return failure(); + + auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/, + ValueRange iterVals) { + Block &innerBody = *innerOp.getBody(); + assert(iterVals.size() == + (outerBody.getNumArguments() + innerBody.getNumArguments())); + IRMapping mapping; + mapping.map(outerBody.getArguments(), + iterVals.take_front(outerBody.getNumArguments())); + mapping.map(innerBody.getArguments(), + iterVals.take_back(innerBody.getNumArguments())); + for (Operation &op : innerBody) + builder.clone(op, mapping); + }; + + auto concatValues = [](const auto &first, const auto &second) { + SmallVector ret; + ret.reserve(first.size() + second.size()); + ret.assign(first.begin(), first.end()); + ret.append(second.begin(), second.end()); + return ret; + }; + + auto newLowerBounds = + concatValues(op.getMixedLowerBound(), innerOp.getMixedLowerBound()); + auto newUpperBounds = + concatValues(op.getMixedUpperBound(), innerOp.getMixedUpperBound()); + auto newSteps = concatValues(op.getMixedStep(), innerOp.getMixedStep()); + rewriter.replaceOpWithNewOp( + op, newLowerBounds, newUpperBounds, newSteps, ValueRange{}, + std::nullopt, bodyBuilder); + return success(); + } +}; + +struct MergeNestedForall + : public impl::MergeNestedForallBase { +public: + void runOnOperation() final { + auto &ctx = getContext(); + RewritePatternSet patterns(&ctx); + + patterns.add(patterns.getContext()); + + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace +} // namespace gc +} // namespace mlir diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 459e77fa8..f198c6c75 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -34,6 +34,11 @@ namespace mlir::gc { +void populateCleanUpPasses(mlir::OpPassManager &pm) { + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); +} + // linalg + linalgX + tensor void populateFrontendPasses(mlir::OpPassManager &pm) { #ifdef GC_HAS_ONEDNN_DIALECT @@ -46,14 +51,20 @@ void populateTensorPasses(mlir::OpPassManager &pm) { // todo: padding propagation pass // todo: layout propagation pass // todo: tensor constant propagation pass - // todo: linalg.matmul lowering to (scf.loop + linalg.brgemm) pass + // linalg.matmul lowering to (scf.loop + linalg.brgemm) pass + pm.addNestedPass(createDeepTileContractionNamedOp()); + // Fine-grain fusion pass 
pm.addNestedPass(createIterativeTilingAndFusion()); + // todo: fine-grain fusion pass // todo: lower linalg to arith/math on virtual vector pass // REMOVE this pass after the above passes are added. Currently we add this // pass to make the pipeline work properly pm.addNestedPass(createLinalgGeneralizeNamedOpsPass()); + pm.addPass(createLoopInvariantCodeMotionPass()); + pm.addPass(createControlFlowSinkPass()); + populateCleanUpPasses(pm); } // scf + arith + math + vector + tensor + linalg.brgemm @@ -72,6 +83,7 @@ void populateVectorPasses(mlir::OpPassManager &pm) { // oneDNN graph spec pm.addNestedPass(arith::createArithExpandOpsPass()); // todo: lower to physical vector pass, device dependent pass + populateCleanUpPasses(pm); } // scf + arith + math + vector + memref + linalg.brgemm @@ -91,6 +103,7 @@ void populateBufferizationPasses(mlir::OpPassManager &pm) { pm.addNestedPass(bufferization::createBufferLoopHoistingPass()); pm.addNestedPass(bufferization::createBufferDeallocationPass()); pm.addPass(createBufferizationToMemRefPass()); + populateCleanUpPasses(pm); } // scf + arith + math + vector + memref + func/microkernel @@ -107,7 +120,15 @@ void populateMicroKernelPasses(mlir::OpPassManager &pm) { void populateCPURuntimePasses(mlir::OpPassManager &pm) { // todo: flatten nested parallel pass to support coarse-grain usion // remove this pass after we add FlattenNestedParallel + pm.addPass(createSinkOpIntoInnerLoop()); + pm.addPass(createMergeNestedForall()); + pm.addPass(createLoopInvariantCodeMotionPass()); + pm.addPass(createControlFlowSinkPass()); + pm.addPass(createForallToParallelLoopPass()); + pm.addPass(createParallelLoopFusionPass()); + pm.addPass(createLoopInvariantCodeMotionPass()); pm.addPass(createConvertSCFToOpenMPPass()); + populateCleanUpPasses(pm); } void populateLoweringToLLVMPasses(mlir::OpPassManager &pm) { @@ -149,7 +170,7 @@ void populateCPUPipeline(mlir::OpPassManager &pm) { pm.addNestedPass(createConvertLinalgToParallelLoopsPass()); populateMicroKernelPasses(pm); populateCPURuntimePasses(pm); - // // back-end, llvm dialect + // back-end, llvm dialect populateLLVMPasses(pm); } diff --git a/lib/gc/Transforms/SinkOpIntoInnerLoop.cpp b/lib/gc/Transforms/SinkOpIntoInnerLoop.cpp new file mode 100644 index 000000000..965a26392 --- /dev/null +++ b/lib/gc/Transforms/SinkOpIntoInnerLoop.cpp @@ -0,0 +1,50 @@ +//===-- SinkOpIntoInnerLoop.cpp - sink op to inner if possible --*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/Passes.h" + +#include "mlir/IR/Dominance.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/ControlFlowSinkUtils.h" + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_SINKOPINTOINNERLOOP +#include "gc/Transforms/Passes.h.inc" + +namespace { + +struct SinkOpIntoInnerLoop + : public impl::SinkOpIntoInnerLoopBase { +public: + void runOnOperation() final { + auto &domInfo = getAnalysis(); + getOperation()->walk([&](LoopLikeOpInterface loop) { + SmallVector regionsToSink; + // Get the regions are that known to be executed at most once. + for (auto &it : loop->getRegions()) + regionsToSink.push_back(&it); + // Sink side-effect free operations. 
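+        // Only memory-effect-free ops are candidates for sinking, so ops
+        // with side effects keep their original position and order.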
+ controlFlowSink( + regionsToSink, domInfo, + [](Operation *op, Region *) { return isMemoryEffectFree(op); }, + [](Operation *op, Region *region) { + // Move the operation to the beginning of the region's entry block. + // This guarantees the preservation of SSA dominance of all of the + // operation's uses are in the region. + op->moveBefore(®ion->front(), region->front().begin()); + }); + }); + } +}; + +} // namespace +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/TileNamed.cpp b/lib/gc/Transforms/TileNamed.cpp deleted file mode 100644 index 43348685d..000000000 --- a/lib/gc/Transforms/TileNamed.cpp +++ /dev/null @@ -1,49 +0,0 @@ -//===-- TileNamed.cpp - Tile Named Linalg Ops -------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "gc/Transforms/Passes.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; - -namespace mlir { -namespace gc { -#define GEN_PASS_DEF_TILELINALGNAMED -#include "gc/Transforms/Passes.h.inc" -} // namespace gc -} // namespace mlir - -namespace { -class TileLinalg : public mlir::gc::impl::TileLinalgNamedBase { - - void runOnOperation() override { - auto *ctx = &getContext(); - IRRewriter rewriter(ctx); - - llvm::SmallVector to_tile; - for (Operation &o : getOperation()->getRegion(0).front().getOperations()) { - if (isa(o)) { - to_tile.push_back(&o); - } - } - - for (Operation *o : to_tile) { - llvm::errs() << "func op body to tile: " << *o << "\n"; - } - } -}; - -} // namespace diff --git a/lib/gc/Transforms/TilingUtil.cpp b/lib/gc/Transforms/TilingUtil.cpp new file mode 100644 index 000000000..25d94938c --- /dev/null +++ b/lib/gc/Transforms/TilingUtil.cpp @@ -0,0 +1,748 @@ +//===-- TilingUtil.cpp - Implementation of linalg Tiling --------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/TilingInterface.h" +#include +#include + +namespace mlir { +#define GEN_PASS_DEF_LINALGTILINGPASS +#include "mlir/Dialect/Linalg/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::affine; +using namespace mlir::linalg; +using namespace mlir::scf; + +#define DEBUG_TYPE "linalg-tiling" + +namespace mlir { +namespace linalgX { + +struct LinalgOpPartialReductionInterface { + static FailureOr> generateInitialTensorForPartialReduction( + Operation *op, OpBuilder &b, Location loc, ArrayRef sizes, + ArrayRef reductionDims, ArrayRef newParallelDims) { + auto linalgOp = cast(op); + OpBuilder::InsertionGuard guard(b); + + if (newParallelDims.empty()) + newParallelDims = reductionDims; + + if (linalgOp.hasPureBufferSemantics()) + return op->emitOpError("expected operation to have tensor semantics"); + // Insert the new parallel dimension based on the index of the reduction + // loops. This could be controlled by user for more flexibility. + SmallVector inits; + for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e; + ++initIdx) { + SmallVector combinerOps; + if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) || + combinerOps.size() != 1) + return op->emitOpError("Failed to anaysis the reduction operation."); + + Operation *reductionOp = combinerOps[0]; + std::optional identity = arith::getNeutralElement(reductionOp); + if (!identity.has_value()) + return op->emitOpError( + "Failed to get an identity value for the reduction operation."); + + ArrayRef oldShape = + linalgOp.getShape(linalgOp.getDpsInitOperand(0)); + + // Extend tile size vector to the rank of the output tensor. + SmallVector tileSizeVector = + getValueOrCreateConstantIndexOp(b, loc, sizes); + if (tileSizeVector.size() < oldShape.size()) { + auto zero = b.create(loc, 0); + tileSizeVector.append(oldShape.size() - tileSizeVector.size(), zero); + } + + // Calculate the new shape, we insert the new dimensions based on the + // index of the reduction dimensions. 
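+      // Each newly inserted parallel dimension takes the tile size of the
+      // matching reduction dimension; dynamic extents are read back from the
+      // original init operand so the new init tensor can be sized correctly.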
+      SmallVector<int64_t> newOutputShape;
+      SmallVector<Value> dynamicDims;
+      int64_t currReductionDims = 0;
+      DenseSet<int64_t> newParallelDimsSet(newParallelDims.begin(),
+                                           newParallelDims.end());
+      for (int64_t idx :
+           llvm::seq<int64_t>(0, oldShape.size() + newParallelDims.size())) {
+        if (newParallelDimsSet.contains(idx)) {
+          dispatchIndexOpFoldResults(sizes[reductionDims[currReductionDims]],
+                                     dynamicDims, newOutputShape);
+          currReductionDims++;
+          continue;
+        }
+        int64_t oldIdx = idx - currReductionDims;
+        int64_t dim = oldShape[oldIdx];
+        newOutputShape.push_back(dim);
+        if (ShapedType::isDynamic(dim))
+          dynamicDims.push_back(b.create<tensor::DimOp>(
+              loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx));
+      }
+      Value emptyTensor = b.create<tensor::EmptyOp>(
+          loc, newOutputShape, linalgOp.getRegionOutputArgs()[0].getType(),
+          dynamicDims);
+      Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
+      auto identityTensor =
+          b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
+      inits.push_back(identityTensor.getResult(0));
+    }
+    return inits;
+  }
+
+  static Operation *tileToPartialReduction(Operation *op, OpBuilder &b,
+                                           Location loc, ValueRange init,
+                                           ArrayRef<OpFoldResult> offsets,
+                                           ArrayRef<OpFoldResult> sizes,
+                                           ArrayRef<int> reductionDims) {
+    OpBuilder::InsertionGuard guard(b);
+    auto linalgOp = cast<LinalgOp>(op);
+
+    AffineMap oldOutputMap =
+        linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
+    SmallVector<AffineExpr> outputExpr(oldOutputMap.getNumResults() +
+                                       reductionDims.size());
+
+    for (int idx : reductionDims)
+      outputExpr[idx] = b.getAffineDimExpr(idx);
+    int currExpr = 0;
+    for (int idx : llvm::seq<int>(0, outputExpr.size())) {
+      if (outputExpr[idx])
+        continue;
+      outputExpr[idx] = oldOutputMap.getResult(currExpr++);
+    }
+
+    // Step 1: Extract a slice of the input operands.
+    SmallVector<Value> valuesToTile = linalgOp.getDpsInputs();
+    SmallVector<Value> tiledOperands = makeTiledShapes(
+        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+
+    // Step 2: Extract the accumulator operands.
+    SmallVector<OpFoldResult> strides(offsets.size(), b.getIndexAttr(1));
+    SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
+    // TODO: use SubsetExtractOpInterface once it is available.
+    Value out = b.create<tensor::ExtractSliceOp>(loc, init[0], outOffsets,
+                                                 sizes, strides);
+
+    // Step 3: Create a generic op where the reduction dimensions are replaced
+    // by a parallel dimension of the size of reduction.
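+    // The reduction dimensions become parallel iterators and the output map
+    // gains one result per reduction dimension, so each thread writes its
+    // own slice of the partial-result tensor instead of accumulating into a
+    // shared destination.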
+ SmallVector newIteratorTypes = + linalgOp.getIteratorTypesArray(); + for (int dim : reductionDims) + newIteratorTypes[dim] = utils::IteratorType::parallel; + SmallVector newMaps = linalgOp.getIndexingMapsArray(); + newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr, + linalgOp.getContext()); + auto genericOp = + b.create(loc, TypeRange({out.getType()}), tiledOperands, + ValueRange({out}), newMaps, newIteratorTypes); + IRMapping mapping; + op->getRegion(0).cloneInto(&genericOp.getRegion(), + genericOp.getRegion().begin(), mapping); + return genericOp.getOperation(); + } + + static Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc, + ValueRange partialReduce, + ArrayRef reductionDims) { + auto linalgOp = cast(op); + SmallVector reductionDimsInt64(reductionDims.begin(), + reductionDims.end()); + SmallVector combinerOps; + matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps); + Operation *reductionOp = combinerOps[0]; + + auto reduction = b.create( + loc, ValueRange({partialReduce[0]}), + ValueRange({linalgOp.getDpsInits()[0]}), reductionDimsInt64, + [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) { + Operation *clonedReductionOp = b.clone(*reductionOp); + clonedReductionOp->setOperand(0, inputs[0]); + clonedReductionOp->setOperand(1, inputs[1]); + b.create(loc, clonedReductionOp->getResult(0)); + }); + return reduction.getOperation(); + } +}; + +std::tuple, LoopIndexToRangeIndexMap> +makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map, + ArrayRef allShapeSizes, + ArrayRef allTileSizes) { + assert(allTileSizes.size() == map.getNumResults()); + // Apply `map` to get shape sizes in loop order. + SmallVector shapeSizes = + makeComposedFoldedMultiResultAffineApply(b, loc, map, allShapeSizes); + SmallVector tileSizes(allTileSizes.begin(), allTileSizes.end()); + + // Traverse the tile sizes, which are in loop order, erase zeros everywhere. + LoopIndexToRangeIndexMap loopIndexToRangeIndex; + for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) { + if (getConstantIntValue(tileSizes[idx - zerosCount]) == + static_cast(0)) { + shapeSizes.erase(shapeSizes.begin() + idx - zerosCount); + tileSizes.erase(tileSizes.begin() + idx - zerosCount); + ++zerosCount; + continue; + } + loopIndexToRangeIndex[idx] = idx - zerosCount; + } + + // Create a new range with the applied tile sizes. + SmallVector res; + for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) + res.push_back(Range{b.getIndexAttr(0), shapeSizes[idx], tileSizes[idx]}); + return std::make_tuple(res, loopIndexToRangeIndex); +} + +void transformIndexOps(RewriterBase &b, LinalgOp op, + SmallVectorImpl &ivs, + const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) { + SmallVector allIvs(op.getNumLoops(), nullptr); + for (auto en : enumerate(allIvs)) { + auto rangeIndex = loopIndexToRangeIndex.find(en.index()); + if (rangeIndex == loopIndexToRangeIndex.end()) + continue; + en.value() = ivs[rangeIndex->second]; + } + offsetIndices(b, op, getAsOpFoldResult(allIvs)); +} + +/// Returns true if the maximum tile offset `tileSize * numThreads-1` is less +/// than `iterationSize`. 
+static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, + OpFoldResult numThreads, + OpFoldResult iterationSize) { + std::optional tileSizeConst = getConstantIntValue(tileSize); + std::optional numThreadsConst = getConstantIntValue(numThreads); + std::optional iterSizeConst = getConstantIntValue(iterationSize); + if (!tileSizeConst || !numThreadsConst || !iterSizeConst) + return false; + return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst; +} + +/// Build an `affine_max` of all the `vals`. +static OpFoldResult buildMax(OpBuilder &b, Location loc, + ArrayRef vals) { + return affine::makeComposedFoldedAffineMax( + b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()), + vals); +} + +/// Build an `affine_min` of all the `vals`. +static OpFoldResult buildMin(OpBuilder &b, Location loc, + ArrayRef vals) { + return affine::makeComposedFoldedAffineMin( + b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()), + vals); +} + +/// Fill out the `tiledOffsets` and `tiledSizes` to be used to tile to a given +/// number of threads. +static void calculateTileOffsetsAndSizes( + RewriterBase &b, Location loc, scf::ForallOp forallOp, + ArrayRef numThreads, SmallVector loopRanges, + bool omitTileOffsetBoundsCheck, + std::optional> nominalTileSizes, + SmallVector &tiledOffsets, + SmallVector &tiledSizes) { + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToStart(forallOp.getBody(0)); + + SmallVector threadIds = forallOp.getInductionVars(); + SmallVector nonZeroNumThreads = + llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { + return !isConstantIntValue(ofr, 0); + })); + int64_t nLoops = loopRanges.size(); + tiledOffsets.reserve(nLoops); + tiledSizes.reserve(nLoops); + for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) { + bool overflow = loopIdx >= numThreads.size(); + bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0); + // Degenerate case: take the whole domain. + if (overflow || isZero) { + tiledOffsets.push_back(loopRanges[loopIdx].offset); + tiledSizes.push_back(loopRanges[loopIdx].size); + continue; + } + + // Tiled case: compute the offset and size. + AffineExpr i, j, m, n, o; + bindDims(b.getContext(), i, j); + bindSymbols(b.getContext(), m, n, o); + OpFoldResult size = loopRanges[loopIdx].size; + OpFoldResult offset = loopRanges[loopIdx].offset; + OpFoldResult threadId = threadIds[threadIdIdx]; + // Symbolic fixed max size per thread. + // TODO: floor + 0/1 depending on case for better load-balancing. + OpFoldResult tileSizePerThread = + nominalTileSizes.has_value() + ? (*nominalTileSizes)[loopIdx] + : makeComposedFoldedAffineApply( + b, loc, m.ceilDiv(n), + ArrayRef{size, nonZeroNumThreads[threadIdIdx]}); + // Dynamic offset shifted by threadId * maxSizePerThread. + OpFoldResult offsetPerThread = makeComposedFoldedAffineApply( + b, loc, i + j * m, {offset, threadId, tileSizePerThread}); + // Dynamic upper-bound depending on the threadId. + OpFoldResult residualTileSize = makeComposedFoldedAffineApply( + b, loc, i + j * m - n, + {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size}); + if (!isConstantIntValue(residualTileSize, 0)) { + OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply( + b, loc, -i + m, {offsetPerThread, size}); + tileSizePerThread = + buildMin(b, loc, {sizeMinusOffsetPerThread, tileSizePerThread}); + } + + tiledOffsets.push_back(offsetPerThread); + // TODO: if tileSizePerThread <= 0 early exit. 
+ if (!omitTileOffsetBoundsCheck && + !canOmitTileOffsetInBoundsCheck(tileSizePerThread, + nonZeroNumThreads[threadIdIdx], size)) + tileSizePerThread = + buildMax(b, loc, {b.getIndexAttr(0), tileSizePerThread}); + + tiledSizes.push_back(tileSizePerThread); + ++threadIdIdx; + } +} + +template +static FailureOr +tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef tileSizes, + const LinalgTilingOptions &options) { + OpBuilder::InsertionGuard g(b); + + auto nLoops = op.getNumLoops(); + // Initial tile sizes may be too big, only take the first nLoops. + tileSizes = tileSizes.take_front(nLoops); + + if (llvm::all_of(tileSizes, [](OpFoldResult ofr) { + return getConstantIntValue(ofr) == static_cast(0); + })) { + TiledLinalgOp tiledOp; + tiledOp.op = cast(b.clone(*op.getOperation())); + tiledOp.tensorResults.assign(tiledOp.op->result_begin(), + tiledOp.op->result_end()); + return tiledOp; + } + + // 1. Build the tiled loop ranges. + SmallVector allShapeSizes = + op.createFlatListOfOperandDims(b, op.getLoc()); + AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap(); + if (!shapeSizesToLoopsMap) + return failure(); + + auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges( + b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes); + + SmallVector iteratorTypes; + for (const auto &attr : enumerate(op.getIteratorTypesArray())) { + if (loopIndexToRangeIndex.count(attr.index())) + iteratorTypes.push_back(attr.value()); + } + // If interchangeVector is empty, use the identity. Build the permutation map + // otherwise. + auto invPermutationMap = + AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext()); + if (!options.interchangeVector.empty()) { + // Based on the pruned iterations (due to zero tile size), recompute the + // interchange vector. + SmallVector interchangeVector; + interchangeVector.reserve(options.interchangeVector.size()); + for (auto pos : options.interchangeVector) { + auto it = loopIndexToRangeIndex.find(pos); + if (it == loopIndexToRangeIndex.end()) + continue; + interchangeVector.push_back(it->second); + } + // Interchange vector is guaranteed to be a permutation, + // `inversePermutation` must succeed. + invPermutationMap = inversePermutation( + AffineMap::getPermutationMap(interchangeVector, b.getContext())); + assert(invPermutationMap); + SmallVector permutation(interchangeVector.begin(), + interchangeVector.end()); + applyPermutationToVector(loopRanges, permutation); + applyPermutationToVector(iteratorTypes, permutation); + } + + // Handle distribution. Create a vector of the same size of loops that are to + // be tiled. + SmallVector procInfo; + if (options.distribution) { + procInfo.resize( + iteratorTypes.size(), + linalg::ProcInfo{nullptr, nullptr, linalg::DistributionMethod::None}); + // Collect loop ranges of tiled loops, loops that are parallel. + SmallVector parallelLoopRanges; + for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) { + if (!isParallelIterator(iteratorType.value())) + break; + parallelLoopRanges.push_back(loopRanges[iteratorType.index()]); + } + auto returnedProcInfo = + options.distribution->procInfo(b, op.getLoc(), parallelLoopRanges); + unsigned procIdIdx = 0; + // Update the distribution information for the loops. + for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) { + if (!isParallelIterator(iteratorType.value())) + break; + procInfo[iteratorType.index()] = returnedProcInfo[procIdIdx++]; + } + } + + // 2. Create the tiled loops. 
+ LinalgOp res = op; + SmallVector ivs, tensorResults; + auto tiledLoopBodyBuilder = + [&](OpBuilder &builder, Location loc, ValueRange localIvs, + ValueRange operandValuesToUse) -> scf::ValueVector { + ivs.assign(localIvs.begin(), localIvs.end()); + + // When an `interchangeVector` is present, it has been applied to the + // loop ranges and the iterator types. Apply its inverse to the + // resulting loop `ivs` to match the op definition. + SmallVector interchangedIvs; + if (!options.interchangeVector.empty()) { + for (AffineExpr result : invPermutationMap.getResults()) + interchangedIvs.push_back( + ivs[cast(result).getPosition()]); + } else { + interchangedIvs.assign(ivs.begin(), ivs.end()); + } + + // Tile the `operandValuesToUse` that either match the `op` operands + // themselves or the tile loop arguments forwarding them. + assert(operandValuesToUse.size() == + static_cast(op->getNumOperands()) && + "expect the number of operands and inputs and outputs to match"); + SmallVector valuesToTile = operandValuesToUse; + SmallVector sizeBounds = + makeComposedFoldedMultiResultAffineApply(b, loc, shapeSizesToLoopsMap, + allShapeSizes); + SmallVector tiledOperands = makeTiledShapes( + b, loc, op, valuesToTile, getAsOpFoldResult(interchangedIvs), tileSizes, + sizeBounds, + /*omitPartialTileCheck=*/false); + + SmallVector resultTensorTypes = + getTensorOutputTypes(op, tiledOperands); + res = clone(b, op, resultTensorTypes, tiledOperands); + tensorResults = + insertSlicesBack(builder, loc, op, tiledOperands, res->getResults()); + return scf::ValueVector(tensorResults.begin(), tensorResults.end()); + }; + GenerateLoopNest::doit(b, op.getLoc(), loopRanges, op, iteratorTypes, + tiledLoopBodyBuilder, procInfo); + + // 3. Transform IndexOp results w.r.t. the tiling. + linalg::transformIndexOps(b, res, ivs, loopIndexToRangeIndex); + + // 4. Gather the newly created loops and return them with the new op. + SmallVector loops; + loops.reserve(ivs.size()); + for (auto iv : ivs) { + if (isa(iv)) { + loops.push_back(cast(iv).getOwner()->getParentOp()); + assert(loops.back() && "no owner found for induction variable!"); + } else { + // TODO: Instead of doing this, try to recover the ops used instead of the + // loop. + loops.push_back(nullptr); + } + } + + // 5. Get the tensor results from the outermost loop if available. Otherwise + // use the previously captured `tensorResults`. + Operation *outermostLoop = nullptr; + for (Operation *loop : loops) + if ((outermostLoop = loop)) + break; + + return TiledLinalgOp{ + res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults}; +} + +FailureOr tileReductionUsingForall( + RewriterBase &b, PartialReductionOpInterface op, + ArrayRef threadNums, ArrayRef tileSizes, + ArrayRef newParallelDims, std::optional mapping) { + Location loc = op.getLoc(); + OpBuilder::InsertionGuard g(b); + + // Ops implementing PartialReductionOpInterface are expected to implement + // TilingInterface. + // TODO: proper core mechanism to tie interfaces together. + auto tilingInterfaceOp = cast(op.getOperation()); + + // Ops implementing PartialReductionOpInterface are not necessarily expected + // to implement TilingInterface.. This cast is unsafe atm. + // TODO: proper core mechanism to tie interfaces together. + // TODO: this function requires a pair of interfaces .. + auto destinationStyleOp = + dyn_cast(op.getOperation()); + if (!destinationStyleOp) + return b.notifyMatchFailure(op, "not a destination style op"); + + // Actually this only work for Linalg ops atm. 
+  auto linalgOp = dyn_cast(op.getOperation());
+  if (!linalgOp)
+    return b.notifyMatchFailure(op, "not a linalg op");
+
+  SmallVector iterationDomain = tilingInterfaceOp.getIterationDomain(b);
+  if (op->getNumResults() != 1)
+    return b.notifyMatchFailure(
+        op, "don't support ops with multiple results for now");
+
+  SmallVector iterators =
+      tilingInterfaceOp.getLoopIteratorTypes();
+  SmallVector redDims;
+  for (auto [idx, iteratorType] :
+       llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
+    if (iteratorType == utils::IteratorType::reduction)
+      redDims.push_back(idx);
+  }
+
+  SmallVector numThreads(threadNums.begin(), threadNums.end());
+  if (numThreads.empty()) {
+    SmallVector loopRanges = tilingInterfaceOp.getIterationDomain(b);
+    unsigned nLoops = loopRanges.size();
+    numThreads.reserve(nLoops);
+    AffineExpr s0, s1;
+    bindSymbols(b.getContext(), s0, s1);
+    AffineExpr divExpr = s0.ceilDiv(s1);
+    for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
+      OpFoldResult numTiles = std::get<0>(it);
+      if (!isConstantIntValue(numTiles, 0))
+        numTiles = makeComposedFoldedAffineApply(
+            b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
+      numThreads.push_back(numTiles);
+    }
+  }
+
+  if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
+    return b.notifyMatchFailure(op, "if tile sizes are present they must have "
+                                    "as many elements as the number of threads");
+
+  if ((unsigned)redDims.front() >= numThreads.size())
+    return b.notifyMatchFailure(
+        op, "reduction dimension must be mapped to threads");
+  SmallVector constantNewParallelDims;
+  for (auto dim : newParallelDims) {
+    if (getConstantIntValue(dim) == std::nullopt)
+      return b.notifyMatchFailure(
+          op, "Expected new parallel dims to be constant integers.");
+    constantNewParallelDims.push_back(*getConstantIntValue(dim));
+  }
+  if (newParallelDims.empty())
+    constantNewParallelDims = redDims;
+  if (constantNewParallelDims.size() != redDims.size())
+    return b.notifyMatchFailure(
+        op, "reduction dimension must be mapped to new parallel dims");
+  // 1. Create the initial tensor value.
+  FailureOr> maybeInitTensors =
+      LinalgOpPartialReductionInterface::
+          generateInitialTensorForPartialReduction(
+              op, b, loc, numThreads, redDims, constantNewParallelDims);
+  if (failed(maybeInitTensors))
+    return b.notifyMatchFailure(
+        op, "Failed to create initial tensors for partial reduction");
+  SmallVector &initTensors = maybeInitTensors.value();
+
+  // Gather destination tensors.
+  SmallVector dest;
+  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
+    return b.notifyMatchFailure(op, "failed to get destination tensors");
+
+  Operation *tiledOp = nullptr;
+  SmallVector nonZeroNumThreads =
+      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
+        return !isConstantIntValue(ofr, 0);
+      }));
+  SmallVector materializedNonZeroNumThreads =
+      getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
+  // 2. Create the ForallOp with an empty region.
+  scf::ForallOp forallOp = b.create(
+      loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
+      mapping);
+  // 3. Calculate the tile offsets and sizes for the subsequent loop that will
+  // be nested under `forallOp`.
+ SmallVector tiledOffsets, tiledSizes; + std::optional> nominalTileSizes = std::nullopt; + if (!tileSizes.empty() && threadNums.empty()) { + nominalTileSizes = tileSizes; + } + calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain, + /*omitTileOffsetBoundsCheck =*/false, + /*nominalTileSizes=*/nominalTileSizes, + tiledOffsets, tiledSizes); + // 4. Clone the tileable op and update its destination operands to use the + // output bbArgs of the ForallOp. + SmallVector tilingResults; + ArrayRef destBbArgs = forallOp.getRegionIterArgs(); + { + // 4.a. RAII guard, inserting within forallOp, before terminator. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(forallOp.getTerminator()); + + SmallVector tiledDpsInitOperands; + for (Value initOperand : destinationStyleOp.getDpsInits()) { + auto *it = llvm::find(dest, initOperand); + assert(it != dest.end() && "dest operand not found in dest"); + unsigned destNum = std::distance(dest.begin(), it); + auto dest = destBbArgs[destNum]; + auto destShape = cast(dest.getType()).getShape(); + SmallVector strides(destShape.size(), b.getIndexAttr(1)); + SmallVector outOffsets(destShape.size(), b.getIndexAttr(0)); + SmallVector sizes(destShape.size(), b.getIndexAttr(0)); + for (const auto &iteratorType : + llvm::enumerate(cast(dest.getType()).getShape())) { + sizes[iteratorType.index()] = + getAsIndexOpFoldResult(b.getContext(), iteratorType.value()); + if (llvm::find(constantNewParallelDims, iteratorType.index()) != + constantNewParallelDims.end()) { + sizes[iteratorType.index()] = b.getIndexAttr(1); + } + } + + auto nonZeroDimIdx = 0; + auto currentReductionIdx = 0; + for (const auto &iteratorType : llvm::enumerate(numThreads)) { + if (!isConstantIntValue(iteratorType.value(), 0)) { + if (llvm::find(redDims, iteratorType.index()) != redDims.end()) { + outOffsets[constantNewParallelDims[currentReductionIdx++]] = + forallOp.getInductionVars()[nonZeroDimIdx]; + } + nonZeroDimIdx++; + } + } + // TODO: use SubsetExtractOpInterface once it is available. + tiledDpsInitOperands.push_back(b.create( + loc, cast(initOperand.getType()), dest, outOffsets, + sizes, strides)); + } + + // 4.b. Clone the op and update init operands. + // We cannot use a IRMapping here because it can replace + // different OpOperands with the same value. + Operation *clonedOp = b.clone(*op.getOperation()); + b.modifyOpInPlace(clonedOp, [&]() { + for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal( + cast(clonedOp).getDpsInitsMutable(), + tiledDpsInitOperands)) { + initOperandPtr.set(tiledInitValue); + } + }); + // 5. Tile the cloned op and delete the clone. 
+ if (tileSizes.empty() || threadNums.empty()) { + FailureOr tilingResult = + cast(clonedOp).getTiledImplementation( + b, tiledOffsets, tiledSizes); + if (failed(tilingResult)) + return clonedOp->emitError("Failed to tile op: "); + if (tilingResult->tiledOps.size() != 1) { + return clonedOp->emitError("expected a single produced tiled op, got ") + << tilingResult->tiledOps.size(); + } + tiledOp = tilingResult->tiledOps.front(); + tilingResults = tilingResult->tiledValues; + } else { + LinalgTilingOptions options; + FailureOr maybeTiled = tileLinalgOpImpl( + b, cast(clonedOp), tileSizes, options); + if (failed(maybeTiled)) + return b.notifyMatchFailure(op, "failed tileLinalgOpImpl"); + + SmallVector ids = forallOp.getInductionVars(); + mapLoopToProcessorIds(cast(maybeTiled->loops.back()), ids, + materializedNonZeroNumThreads); + if (maybeTiled->loops.size() != 1) { + return clonedOp->emitError("expected a single produced loop"); + } + tiledOp = maybeTiled->op; + tilingResults = maybeTiled->loops.front()->getResults(); + } + + b.eraseOp(clonedOp); + } + + // 6. Insert the partial reductions back into a new tensor. + for (auto [index, result, bbArg] : llvm::zip( + llvm::seq(0, dest.size()), tilingResults, destBbArgs)) { + // 6.a. Partial subset information is inserted just before the terminator. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(forallOp.getTerminator()); + + SmallVector resultOffsets, resultSizes; + if (failed(tilingInterfaceOp.getResultTilePosition( + b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes))) + return op->emitOpError("output offsets couldn't be calculated"); + SmallVector resultOffsetsRank, resultSizesRank; + uint64_t offIdx = 0; + int64_t nonZeroDimIdx = 0; + SmallVector reductionInductionVars; + for (auto i = 0UL; i < numThreads.size(); ++i) { + if (llvm::find(constantNewParallelDims, i) != + constantNewParallelDims.end()) { + resultOffsetsRank.push_back(b.getIndexAttr(1)); + resultSizesRank.push_back(b.getIndexAttr(1)); + } else if (offIdx < resultOffsets.size()) { + resultOffsetsRank.push_back(resultOffsets[offIdx]); + resultSizesRank.push_back(resultSizes[offIdx++]); + } + if (llvm::find(redDims, i) != redDims.end()) { + reductionInductionVars.push_back( + forallOp.getInductionVars()[nonZeroDimIdx]); + } + if (!isConstantIntValue(numThreads[i], 0)) { + nonZeroDimIdx++; + } + } + for (auto [parallelDims, redVar] : + llvm::zip(constantNewParallelDims, reductionInductionVars)) { + resultOffsetsRank[parallelDims] = redVar; + resultSizesRank[parallelDims] = b.getIndexAttr(1); + } + SmallVector strides(resultSizesRank.size(), + b.getIndexAttr(1)); + + // 6.b. Parallel insertions are inserted at the end of the combining + // terminator. + b.setInsertionPointToEnd(forallOp.getTerminator().getBody()); + b.create( + loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides); + } + // 7. Merge the partial reductions. + Operation *mergeOp = nullptr; + b.setInsertionPointAfter(forallOp); + mergeOp = linalgX::LinalgOpPartialReductionInterface::mergeReductions( + op, b, loc, forallOp->getResults(), constantNewParallelDims); + b.replaceOp(op, mergeOp->getResults()); + // 8. Return. 
+  ForallReductionTilingResult results;
+  results.initialValues = initTensors;
+  results.loops = forallOp;
+  results.parallelTiledOps = SmallVector{tiledOp};
+  results.mergeOps = SmallVector{mergeOp};
+  return results;
+}
+
+} // namespace linalgX
+} // namespace mlir
\ No newline at end of file
diff --git a/lib/gc/Transforms/TilingUtil.hpp b/lib/gc/Transforms/TilingUtil.hpp
new file mode 100644
index 000000000..42460d374
--- /dev/null
+++ b/lib/gc/Transforms/TilingUtil.hpp
@@ -0,0 +1,28 @@
+//===-- TilingUtil.hpp - Tile op using tiling interface ---------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEMPORARY_TILEUSINGINTERFACE_X_H
+#define TEMPORARY_TILEUSINGINTERFACE_X_H
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include <optional>
+namespace mlir {
+namespace linalgX {
+
+// An enhancement of the upstream pass to support tiling reductions for
+// MKmk-like cases (with multiple reduction iterators).
+FailureOr<linalg::ForallReductionTilingResult> tileReductionUsingForall(
+    RewriterBase &b, PartialReductionOpInterface op,
+    ArrayRef<OpFoldResult> threadNums, ArrayRef<OpFoldResult> tileSizes,
+    ArrayRef<OpFoldResult> newParallelDims, std::optional<ArrayAttr> mapping);
+} // namespace linalgX
+} // namespace mlir
+
+#endif
\ No newline at end of file
diff --git a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir
new file mode 100644
index 000000000..7cb39e2b3
--- /dev/null
+++ b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir
@@ -0,0 +1,155 @@
+// RUN: gc-opt --split-input-file --deep-tile-contraction-named-op %s | FileCheck %s
+
+// -----
+
+/// CHECK-LABEL: @matmul_2Dx2D_f32
+func.func @matmul_2Dx2D_f32(%arg0: tensor<4096x4096xf32>, %arg1: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> {
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<4096x4096xf32>
+  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
+  // CHECK: scf.forall {{.*}} (0) to (4096) step (1024) {{.*}} (tensor<4096x4096xf32>) {
+  // CHECK: tensor.extract_slice {{.*}} [1024, 4096] [1, 1]
+  // CHECK: scf.forall {{.*}} (0) to (4096) step (2048) {{.*}} (tensor<1024x4096xf32>)
+  // CHECK: tensor.extract_slice {{.*}} [1024, 2048] [1, 1]
+  // CHECK: scf.for
+  // CHECK: tensor.extract_slice {{.*}} [256, 2048] [1, 1]
+  // CHECK: scf.for
+  // CHECK: tensor.extract_slice {{.*}} [256, 256] [1, 1]
+  // CHECK: scf.for
+  // CHECK: scf.for
+  // CHECK: tensor.extract_slice {{.*}} [32, 256] [1, 1]
+  // CHECK: tensor.extract_slice {{.*}} [32, 256] [1, 1]
+  // CHECK: scf.for
+  // CHECK: tensor.extract_slice {{.*}} [256, 32] [1, 1]
+  // CHECK: tensor.extract_slice {{.*}} [32, 32] [1, 1]
+  // CHECK: linalg.transpose {{.*}} permutation = [1, 0, 2]
+  // CHECK: tensor.expand_shape {{.*}} output_shape [8, 32, 32] : tensor<256x32xf32> into tensor<8x32x32xf32>
+  // CHECK: scf.if
+  // CHECK: linalg.fill
+  // CHECK: linalg.batch_reduce_matmul
+  // CHECK: else
+  // CHECK: linalg.batch_reduce_matmul
+  // CHECK: tensor.insert_slice {{.*}} [32, 256] [1, 1]
+  %2 = linalg.matmul {MThreads = 4 : i32, NThreads = 2 : i32, KThreads = 1 : i32, MBlock = 256 : i32, NBlock = 256 : i32, KBlock = 256 : i32, innermostMBlock = 32 : i32,
innermostNBlock = 32 : i32, innermostKBlock = 32 : i32 } ins(%arg0, %arg1 : tensor<4096x4096xf32>, tensor<4096x4096xf32>) outs(%1 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> + return %2 : tensor<4096x4096xf32> +} + +// ----- + +/// CHECK-LABEL: @matmul_4Dx4D_bf16 +func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<128x128x32x32xbf16> { + %cst_0 = arith.constant 0.000000e+00 : bf16 + // CHECK: tensor.empty() : tensor<128x128x32x32xbf16> + %0 = tensor.empty() : tensor<128x128x32x32xbf16> + // CHECK-NOT: linalg.fill + %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> + // CHECK: scf.forall {{.*}} (0) to (128) step (8) {{.*}} (tensor<128x128x32x32xbf16>) + // CHECK: tensor.extract_slice {{.*}} [8, 128, 32, 32] [1, 1, 1, 1] + // CHECK: scf.forall {{.*}} (0) to (128) step (64) {{.*}} (tensor<8x128x32x32xbf16>) + // CHECK: tensor.extract_slice {{.*}} [8, 64, 32, 32] [1, 1, 1, 1] + // CHECK: scf.for + // CHECK: tensor.extract_slice {{.*}} [8, 8, 32, 32] [1, 1, 1, 1] + // CHECK: tensor.empty() : tensor<8x8x32x32xf32> + // CHECK: scf.for + // CHECK: scf.for + // CHECK: tensor.extract_slice {{.*}} [1, 8, 32, 32] [1, 1, 1, 1] + // CHECK: tensor.extract_slice {{.*}} [1, 8, 32, 32] [1, 1, 1, 1] + // CHECK: scf.for + // CHECK: tensor.collapse_shape {{.*}} tensor<1x8x32x32xbf16> into tensor<8x32x32xbf16> + // CHECK: tensor.extract_slice {{.*}} [1, 8, 16, 32, 2] [1, 1, 1, 1, 1] + // CHECK: tensor.collapse_shape {{.*}} tensor<1x8x16x32x2xbf16> into tensor<8x16x32x2xbf16> + // CHECK: tensor.extract_slice {{.*}} [1, 1, 32, 32] [1, 1, 1, 1] + // CHECK: tensor.collapse_shape {{.*}} tensor<1x1x32x32xf32> into tensor<32x32xf32> + // CHECK: tensor.extract_slice {{.*}} [1, 1, 32, 32] [1, 1, 1, 1] + // CHECK: tensor.collapse_shape {{.*}} tensor<1x1x32x32xbf16> into tensor<32x32xbf16> + // CHECK: scf.if + // CHECK: linalg.fill + // CHECK: linalgx.batch_reduce_matmul_vnni + // CHECK: else + // CHECK: linalgx.batch_reduce_matmul_vnni + // CHECK: scf.if + // CHECK: linalg.copy + // CHECK: else + %2 = linalgx.mm4d_vnni {MThreads = 16 : i32, NThreads = 2 : i32, KThreads = 1 : i32, MBlock = 256 : i32, NBlock = 256 : i32, KBlock = 256 : i32,innermostMBlock = 32 : i32, innermostNBlock = 32 : i32, innermostKBlock = 32 : i32 } ins(%arg0, %arg1 : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> + return %2 : tensor<128x128x32x32xbf16> +} + +// ----- + +/// CHECK-LABEL: @matmul_2Dx4D_bf16 +func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<4096x4096xbf16> { + %cst_0 = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<4096x4096xbf16> + %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> + // CHECK: scf.forall {{.*}} (2) {{.*}} (tensor<2x1x1x4096x4096xf32>) + // CHECK: tensor.extract_slice {{.*}} [1, 1, 1, 4096, 4096] [1, 1, 1, 1, 1] + // CHECK: scf.forall {{.*}} (0) to (4096) step (256) {{.*}} (tensor<4096x4096xf32>) + // CHECK: tensor.extract_slice {{.*}} [256, 4096] [1, 1] + // CHECK: scf.forall {{.*}} (0) to (128) step (64) {{.*}} (tensor<256x4096xf32>) + // CHECK: tensor.extract_slice {{.*}} [256, 2048] [1, 1] + // CHECK: scf.for + // CHECK: tensor.extract_slice {{.*}} [256, 256] [1, 1] + // CHECK: scf.for + // CHECK: scf.for + // CHECK: tensor.extract_slice {{.*}} [32, 256] [1, 1] + // CHECK: tensor.extract_slice 
{{.*}} [32, 256] [1, 1] + // CHECK: scf.for + // CHECK: tensor.extract_slice {{.*}} [1, 8, 16, 32, 2] [1, 1, 1, 1, 1] + // CHECK: tensor.collapse_shape {{.*}} tensor<1x8x16x32x2xbf16> into tensor<8x16x32x2xbf16> + // CHECK: tensor.extract_slice {{.*}} [32, 32] [1, 1] + // CHECK: linalg.transpose {{.*}} permutation = [1, 0, 2] + // CHECK: scf.if + // CHECK: linalg.fill + // CHECK: linalgx.batch_reduce_matmul_vnni + // CHECK: else + // CHECK: linalgx.batch_reduce_matmul_vnni + // CHECK: scf.forall.in_parallel + // CHECK: scf.forall.in_parallel + // CHECK: scf.forall.in_parallel + // CHECK: linalg.reduce {{.*}} dimensions = [0, 1, 2] + // CHECK: linalg.copy + %2 = linalgx.mm2d_vnni {MThreads = 32 : i32, NThreads = 2 : i32, KThreads = 2 : i32, MBlock = 256 : i32, NBlock = 256 : i32, KBlock = 256 : i32,innermostMBlock = 32 : i32, innermostNBlock = 32 : i32, innermostKBlock = 32 : i32 } ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> + return %2 : tensor<4096x4096xbf16> +} + +// ----- + +module attributes { + dlti.target_system_spec = #dlti.target_system_spec< + "CPU": #dlti.target_device_spec< + #dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>, + #dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>, + #dlti.dl_entry<"L3_cache_size_in_bytes", 110100480 : i32>, + #dlti.dl_entry<"num_threads", 56 : i32>, + #dlti.dl_entry<"max_vector_width", 512 : i32>> + >} { + /// CHECK-LABEL: @matmul_2Dx4D_bf16_with_dlti +func.func @matmul_2Dx4D_bf16_with_dlti(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<4096x4096xbf16> { + %cst_0 = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<4096x4096xbf16> + %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> + // CHECK: scf.forall + // CHECK: tensor.extract_slice + // CHECK: scf.forall + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice + // CHECK: linalg.transpose + // CHECK: scf.if + // CHECK: linalg.fill + // CHECK: linalgx.batch_reduce_matmul_vnni + // CHECK: else + // CHECK: linalgx.batch_reduce_matmul_vnni + // CHECK: scf.forall.in_parallel + // CHECK: scf.forall.in_parallel + %2 = linalgx.mm2d_vnni ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> + return %2 : tensor<4096x4096xbf16> +} + +} diff --git a/test/mlir/test/gc/Transforms/mergeNestedForall.mlir b/test/mlir/test/gc/Transforms/mergeNestedForall.mlir new file mode 100644 index 000000000..d878739c8 --- /dev/null +++ b/test/mlir/test/gc/Transforms/mergeNestedForall.mlir @@ -0,0 +1,93 @@ +// RUN: gc-opt --split-input-file --merge-nested-forall %s | FileCheck %s + +// ----- + +#map = affine_map<(d0) -> (d0 * 1024)> +#map1 = affine_map<(d0) -> (d0 * 2048)> +#map2 = affine_map<(d0)[s0, s1] -> (d0 * 2048 + s0 + s1)> +#map3 = affine_map<(d0)[s0, s1] -> (d0 * 1024 + s0 + s1)> +module { + func.func @matmul_2Dx2D_f32(%arg0: memref<4096x4096xf32>, %arg1: memref<4096x4096xf32>, %arg2: memref<4096x4096xf32>) { + // CHECK: scf.forall {{.*}} (4, 2) + scf.forall (%arg3) in (4) { + scf.forall (%arg4) in (2) { + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %alloc = memref.alloc() {alignment = 64 
: i64} : memref<8x32x32xf32> + scf.for %arg5 = %c0 to %c1024 step %c256 { + %c2048 = arith.constant 2048 : index + scf.for %arg6 = %c0 to %c2048 step %c256 { + %c4096 = arith.constant 4096 : index + scf.for %arg7 = %c0 to %c4096 step %c256 { + %c32 = arith.constant 32 : index + scf.for %arg8 = %c0 to %c256 step %c32 { + scf.for %arg9 = %c0 to %c256 step %c32 { + %0 = affine.apply #map(%arg3) + %1 = affine.apply #map1(%arg4) + %subview = memref.subview %arg2[%0, 0] [1024, 4096] [1, 1] : memref<4096x4096xf32> to memref<1024x4096xf32, strided<[4096, 1], offset: ?>> + %subview_0 = memref.subview %subview[0, %1] [1024, 2048] [1, 1] : memref<1024x4096xf32, strided<[4096, 1], offset: ?>> to memref<1024x2048xf32, strided<[4096, 1], offset: ?>> + %subview_1 = memref.subview %subview_0[%arg5, 0] [256, 2048] [1, 1] : memref<1024x2048xf32, strided<[4096, 1], offset: ?>> to memref<256x2048xf32, strided<[4096, 1], offset: ?>> + %subview_2 = memref.subview %subview_1[0, %arg6] [256, 256] [1, 1] : memref<256x2048xf32, strided<[4096, 1], offset: ?>> to memref<256x256xf32, strided<[4096, 1], offset: ?>> + %subview_3 = memref.subview %subview_2[%arg8, 0] [32, 256] [1, 1] : memref<256x256xf32, strided<[4096, 1], offset: ?>> to memref<32x256xf32, strided<[4096, 1], offset: ?>> + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %2 = arith.cmpi eq, %arg7, %c0 : index + %3 = affine.apply #map2(%arg4)[%arg9, %arg6] + %subview_4 = memref.subview %arg1[%arg7, %3] [256, 32] [1, 1] : memref<4096x4096xf32> to memref<256x32xf32, strided<[4096, 1], offset: ?>> + %subview_5 = memref.subview %subview_3[0, %arg9] [32, 32] [1, 1] : memref<32x256xf32, strided<[4096, 1], offset: ?>> to memref<32x32xf32, strided<[4096, 1], offset: ?>> + scf.parallel (%arg10, %arg11, %arg12) = (%c0, %c0, %c0) to (%c8, %c32, %c32) step (%c1, %c1, %c1) { + %4 = affine.apply #map3(%arg3)[%arg8, %arg5] + %subview_6 = memref.subview %arg0[%4, %arg7] [32, 256] [1, 1] : memref<4096x4096xf32> to memref<32x256xf32, strided<[4096, 1], offset: ?>> + %expand_shape_7 = memref.expand_shape %subview_6 [[0], [1, 2]] output_shape [32, 8, 32] : memref<32x256xf32, strided<[4096, 1], offset: ?>> into memref<32x8x32xf32, strided<[4096, 32, 1], offset: ?>> + %5 = memref.load %expand_shape_7[%arg11, %arg10, %arg12] : memref<32x8x32xf32, strided<[4096, 32, 1], offset: ?>> + memref.store %5, %alloc[%arg10, %arg11, %arg12] : memref<8x32x32xf32> + scf.reduce + } + %expand_shape = memref.expand_shape %subview_4 [[0, 1], [2]] output_shape [8, 32, 32] : memref<256x32xf32, strided<[4096, 1], offset: ?>> into memref<8x32x32xf32, strided<[131072, 4096, 1], offset: ?>> + scf.if %2 { + scf.parallel (%arg10, %arg11) = (%c0, %c0) to (%c32, %c32) step (%c1, %c1) { + %cst = arith.constant 0.000000e+00 : f32 + memref.store %cst, %subview_5[%arg10, %arg11] : memref<32x32xf32, strided<[4096, 1], offset: ?>> + scf.reduce + } + scf.for %arg10 = %c0 to %c8 step %c1 { + scf.parallel (%arg11, %arg12) = (%c0, %c0) to (%c32, %c32) step (%c1, %c1) { + scf.for %arg13 = %c0 to %c32 step %c1 { + %4 = memref.load %alloc[%arg10, %arg11, %arg13] : memref<8x32x32xf32> + %5 = memref.load %expand_shape[%arg10, %arg13, %arg12] : memref<8x32x32xf32, strided<[131072, 4096, 1], offset: ?>> + %6 = memref.load %subview_5[%arg11, %arg12] : memref<32x32xf32, strided<[4096, 1], offset: ?>> + %7 = arith.mulf %4, %5 : f32 + %8 = arith.addf %6, %7 : f32 + memref.store %8, %subview_5[%arg11, %arg12] : memref<32x32xf32, strided<[4096, 1], offset: ?>> + } + scf.reduce + } + } + } else { + scf.for 
%arg10 = %c0 to %c8 step %c1 { + scf.parallel (%arg11, %arg12) = (%c0, %c0) to (%c32, %c32) step (%c1, %c1) { + scf.for %arg13 = %c0 to %c32 step %c1 { + %4 = memref.load %alloc[%arg10, %arg11, %arg13] : memref<8x32x32xf32> + %5 = memref.load %expand_shape[%arg10, %arg13, %arg12] : memref<8x32x32xf32, strided<[131072, 4096, 1], offset: ?>> + %6 = memref.load %subview_5[%arg11, %arg12] : memref<32x32xf32, strided<[4096, 1], offset: ?>> + %7 = arith.mulf %4, %5 : f32 + %8 = arith.addf %6, %7 : f32 + memref.store %8, %subview_5[%arg11, %arg12] : memref<32x32xf32, strided<[4096, 1], offset: ?>> + } + scf.reduce + } + } + } + } + } + } + } + } + memref.dealloc %alloc : memref<8x32x32xf32> + } + } + return + } +} + diff --git a/test/mlir/test/gc/Transforms/sinkOpIntoInnerLoop.mlir b/test/mlir/test/gc/Transforms/sinkOpIntoInnerLoop.mlir new file mode 100644 index 000000000..908d08883 --- /dev/null +++ b/test/mlir/test/gc/Transforms/sinkOpIntoInnerLoop.mlir @@ -0,0 +1,46 @@ +// RUN: gc-opt --split-input-file --sink-op-into-inner-loop %s | FileCheck %s + +func.func @matmul_2Dx2D_f32(%arg0: memref<4096x4096xf32>, %arg1: memref<4096x4096xf32>, %arg2: memref<4096x4096xf32>) { + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + %c4096 = arith.constant 4096 : index + %c32 = arith.constant 32 : index + // CHECK: scf.forall + // CHECK-NOT: affine.apply + // CHECK-NOT: memref.subview + // CHECK-NEXT: scf.forall + scf.forall (%arg3) in (4) { + %0 = affine.apply affine_map<(d0) -> (d0 * 1024)>(%arg3) + %subview = memref.subview %arg2[%0, 0] [1024, 4096] [1, 1] : memref<4096x4096xf32> to memref<1024x4096xf32, strided<[4096, 1], offset: ?>> + scf.forall (%arg4) in (2) { + %1 = affine.apply affine_map<(d0) -> (d0 * 2048)>(%arg4) + %subview_0 = memref.subview %subview[0, %1] [1024, 2048] [1, 1] : memref<1024x4096xf32, strided<[4096, 1], offset: ?>> to memref<1024x2048xf32, strided<[4096, 1], offset: ?>> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32xf32> + scf.for %arg5 = %c0 to %c1024 step %c256 { + %subview_1 = memref.subview %subview_0[%arg5, 0] [256, 2048] [1, 1] : memref<1024x2048xf32, strided<[4096, 1], offset: ?>> to memref<256x2048xf32, strided<[4096, 1], offset: ?>> + scf.for %arg6 = %c0 to %c2048 step %c256 { + %subview_2 = memref.subview %subview_1[0, %arg6] [256, 256] [1, 1] : memref<256x2048xf32, strided<[4096, 1], offset: ?>> to memref<256x256xf32, strided<[4096, 1], offset: ?>> + scf.for %arg7 = %c0 to %c4096 step %c256 { + %2 = arith.cmpi eq, %arg7, %c0 : index + scf.for %arg8 = %c0 to %c256 step %c32 { + %3 = affine.apply affine_map<(d0)[s0, s1] -> (d0 * 1024 + s0 + s1)>(%arg3)[%arg8, %arg5] + %subview_3 = memref.subview %arg0[%3, %arg7] [32, 256] [1, 1] : memref<4096x4096xf32> to memref<32x256xf32, strided<[4096, 1], offset: ?>> + %subview_4 = memref.subview %subview_2[%arg8, 0] [32, 256] [1, 1] : memref<256x256xf32, strided<[4096, 1], offset: ?>> to memref<32x256xf32, strided<[4096, 1], offset: ?>> + %expand_shape = memref.expand_shape %subview_3 [[0], [1, 2]] output_shape [32, 8, 32] : memref<32x256xf32, strided<[4096, 1], offset: ?>> into memref<32x8x32xf32, strided<[4096, 32, 1], offset: ?>> + scf.for %arg9 = %c0 to %c256 step %c32 { + + } + } + } + } + } + memref.dealloc %alloc : memref<8x32x32xf32> + } + } + return +} \ No newline at end of file diff --git 
a/test/mlir/unittests/Analysis/CMakeLists.txt b/test/mlir/unittests/Analysis/CMakeLists.txt index d78877afe..ed253bfdf 100644 --- a/test/mlir/unittests/Analysis/CMakeLists.txt +++ b/test/mlir/unittests/Analysis/CMakeLists.txt @@ -3,5 +3,6 @@ add_mlir_unittest(GCAnalysisTests ) target_link_libraries(GCAnalysisTests PRIVATE + GcPasses GcAnalysis GcJitWrapper)