diff --git a/.gitignore b/.gitignore index f2e689db..ac0faa38 100644 --- a/.gitignore +++ b/.gitignore @@ -95,6 +95,12 @@ __pycache__ .pytype +# CMake related stuff +CMakeFiles +CMakeCache.txt +Makefile +cmake_install.cmake + # Stuff auto-generated by MLIR->Kokkos emitter examples/lapis*/ examples/*.mlir @@ -103,6 +109,8 @@ examples/*.tns tests/lapis/ tests/*.mlir tests/*.tns +tests/Dialect/Testing +tests/Testing # Other test artifacts tests/Dialect/*/Output @@ -152,3 +160,7 @@ examples/mala_batch/forward_snap.hpp examples/mala_batch/forward_snap_lowered.mlir examples/mala_batch/build/ +# Temporary files produced by maxpool +examples/maxpool_nchw_teamlevel/maxpool.cpp +examples/maxpool_nchw_teamlevel/maxpool.hpp +examples/maxpool_nchw_teamlevel/maxpool_lowered.mlir diff --git a/examples/mlir-dumps/batched-gemv-axpy-base.mlir b/examples/mlir-dumps/batched-gemv-axpy-base.mlir new file mode 100644 index 00000000..992f52fb --- /dev/null +++ b/examples/mlir-dumps/batched-gemv-axpy-base.mlir @@ -0,0 +1,34 @@ +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map3 = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @gemv(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %1 = arith.mulf %in, %in_0 : f64 + %2 = arith.addf %out, %1 : f64 + linalg.yield %2 : f64 + } -> tensor + return %0 : tensor + } + func.func @axpy(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %1 = arith.addf %in, %in_0 : f64 + linalg.yield %1 : f64 + } -> tensor + return %0 : tensor + } + func.func @main(%batch_size : index, %n : index) -> tensor { + %0 = tensor.empty(%batch_size, %n, %n) : tensor + %1 = tensor.empty(%batch_size, %n) : tensor + %2 = tensor.empty(%batch_size, %n) : tensor + %3 = tensor.empty(%batch_size, %n) : tensor + %4 = call @gemv(%0, %1, %3) { fuse_with = "axpy" } : (tensor, tensor, tensor) -> tensor + %5 = tensor.empty(%batch_size, %n) : tensor + %6 = call @axpy(%4, %2, %5) { fuse_with = "gemv" } : (tensor, tensor, tensor) -> tensor + return %6 : tensor + } +} + diff --git a/examples/mlir-dumps/batched-gemv-axpy-non-dynamic.mlir b/examples/mlir-dumps/batched-gemv-axpy-non-dynamic.mlir new file mode 100644 index 00000000..94b660fa --- /dev/null +++ b/examples/mlir-dumps/batched-gemv-axpy-non-dynamic.mlir @@ -0,0 +1,35 @@ +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map3 = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @gemv(%arg0: tensor<10x10x10xf64>, %arg1: tensor<10x10xf64>, %arg2: + tensor<10x10xf64>) -> tensor<10x10xf64> attributes { noinline } { + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<10x10x10xf64>, tensor<10x10xf64>) outs(%arg2 : tensor<10x10xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %1 = arith.mulf %in, %in_0 : f64 + %2 = arith.addf %out, %1 : f64 + linalg.yield %2 : f64 + } -> tensor<10x10xf64> + return %0 : tensor<10x10xf64> + } + func.func @axpy(%arg0: 
tensor<10x10xf64>, %arg1: tensor<10x10xf64>, %arg2: + tensor<10x10xf64>) -> tensor<10x10xf64> attributes { noinline } { + %0 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<10x10xf64>, tensor<10x10xf64>) outs(%arg2 : tensor<10x10xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %1 = arith.addf %in, %in_0 : f64 + linalg.yield %1 : f64 + } -> tensor<10x10xf64> + return %0 : tensor<10x10xf64> + } + func.func @main() -> tensor<10x10xf64> { + %0 = tensor.empty() : tensor<10x10x10xf64> + %1 = tensor.empty() : tensor<10x10xf64> + %2 = tensor.empty() : tensor<10x10xf64> + %3 = tensor.empty() : tensor<10x10xf64> + %5 = tensor.empty() : tensor<10x10xf64> + %4 = call @gemv(%0, %1, %3) { fuse_with = "axpy" } : (tensor<10x10x10xf64>, tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> + %6 = call @axpy(%4, %2, %5) { fuse_with = "gemv" } : (tensor<10x10xf64>, tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> + return %6 : tensor<10x10xf64> + } +} diff --git a/examples/mlir-dumps/cg-iteration-unfused.mlir b/examples/mlir-dumps/cg-iteration-unfused.mlir new file mode 100644 index 00000000..e67bf96c --- /dev/null +++ b/examples/mlir-dumps/cg-iteration-unfused.mlir @@ -0,0 +1,105 @@ +module { + func.func @gemv(%arg0: memref<128x128xf64>, %arg1: memref<128xf64>, %arg2: memref<128xf64>) -> memref<128xf64> { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + scf.parallel (%arg3) = (%c0) to (%c128) step (%c1) { + scf.for %arg4 = %c0 to %c128 step %c1 { + %0 = memref.load %arg0[%arg3, %arg4] : memref<128x128xf64> + %1 = memref.load %arg1[%arg4] : memref<128xf64> + %2 = memref.load %arg2[%arg3] : memref<128xf64> + %3 = arith.mulf %0, %1 : f64 + %4 = arith.addf %2, %3 : f64 + memref.store %4, %arg2[%arg3] : memref<128xf64> + } + scf.reduce + } + return %arg2 : memref<128xf64> + } + func.func @xpy(%arg0: memref<128xf64>, %arg1: memref<128xf64>, %arg2: memref<128xf64>) -> memref<128xf64> { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + scf.parallel (%arg3) = (%c0) to (%c128) step (%c1) { + %0 = memref.load %arg0[%arg3] : memref<128xf64> + %1 = memref.load %arg1[%arg3] : memref<128xf64> + %2 = arith.addf %0, %1 : f64 + memref.store %2, %arg2[%arg3] : memref<128xf64> + scf.reduce + } + return %arg2 : memref<128xf64> + } + func.func @dot(%arg0: memref<128xf64>, %arg1: memref<128xf64>, %arg2: memref) -> memref { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + scf.for %arg3 = %c0 to %c128 step %c1 { + %0 = memref.load %arg0[%arg3] : memref<128xf64> + %1 = memref.load %arg1[%arg3] : memref<128xf64> + %2 = memref.load %arg2[] : memref + %3 = arith.mulf %0, %1 : f64 + %4 = arith.addf %2, %3 : f64 + memref.store %4, %arg2[] : memref + } + return %arg2 : memref + } + func.func @dscal(%arg0: memref, %arg1: memref<128xf64>, %arg2: memref<128xf64>) -> memref<128xf64> { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + scf.parallel (%arg3) = (%c0) to (%c128) step (%c1) { + %0 = memref.load %arg0[] : memref + %1 = memref.load %arg1[%arg3] : memref<128xf64> + %2 = arith.mulf %0, %1 : f64 + memref.store %2, %arg2[%arg3] : memref<128xf64> + scf.reduce + } + return %arg2 : memref<128xf64> + } + func.func @div(%arg0: memref, %arg1: memref, %arg2: memref) -> memref { + %0 = memref.load %arg0[] : memref + %1 = memref.load 
%arg1[] : memref + %2 = arith.divf %0, %1 : f64 + memref.store %2, %arg2[] : memref + return %arg2 : memref + } + func.func @neg(%arg0: memref, %arg1: memref) -> memref { + %0 = memref.load %arg0[] : memref + %1 = arith.negf %0 : f64 + memref.store %1, %arg1[] : memref + return %arg1 : memref + } + func.func @main(%arg0: memref<128x128xf64>, %arg1: memref<128xf64>, %arg2: memref<128xf64>, %arg3: memref<128xf64>) -> memref<128xf64> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<128xf64> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<128xf64> + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<128xf64> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref + %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<128xf64> + %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<128xf64> + %alloc_7 = memref.alloc() {alignment = 64 : i64} : memref + %alloc_8 = memref.alloc() {alignment = 64 : i64} : memref<128xf64> + %alloc_9 = memref.alloc() {alignment = 64 : i64} : memref<128xf64> + %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref + %alloc_11 = memref.alloc() {alignment = 64 : i64} : memref + + + %0 = call @dot(%arg2, %arg2, %alloc) {fuse_with = "gemv"} : (memref<128xf64>, memref<128xf64>, memref) -> memref + %1 = call @gemv(%arg0, %arg3, %alloc_0) {fuse_with = "dot"} : (memref<128x128xf64>, memref<128xf64>, memref<128xf64>) -> memref<128xf64> + %2 = call @dot(%1, %arg3, %alloc_1) : (memref<128xf64>, memref<128xf64>, memref) -> memref + + %3 = call @div(%0, %2, %alloc_10) {fuse_with = ""} : (memref, memref, memref) -> memref + %4 = call @dscal(%3, %arg3, %alloc_2) {fuse_with = ""} : (memref, memref<128xf64>, memref<128xf64>) -> memref<128xf64> + %5 = call @xpy(%arg1, %4, %alloc_3) {fuse_with = ""} : (memref<128xf64>, memref<128xf64>, memref<128xf64>) -> memref<128xf64> + %6 = call @neg(%3, %alloc_4) {fuse_with = ""} : (memref, memref) -> memref + %7 = call @dscal(%6, %1, %alloc_5) {fuse_with = ""} : (memref, memref<128xf64>, memref<128xf64>) -> memref<128xf64> + %8 = call @xpy(%arg2, %7, %alloc_6) {fuse_with = ""} : (memref<128xf64>, memref<128xf64>, memref<128xf64>) -> memref<128xf64> + %9 = call @dot(%8, %8, %alloc_7) {fuse_with = ""} : (memref<128xf64>, memref<128xf64>, memref) -> memref + %10 = call @div(%9, %0, %alloc_11) {fuse_with = ""} : (memref, memref, memref) -> memref + %11 = call @dscal(%10, %arg3, %alloc_8) {fuse_with = ""} : (memref, memref<128xf64>, memref<128xf64>) -> memref<128xf64> + %12 = call @xpy(%8, %11, %alloc_9) {fuse_with = ""} : (memref<128xf64>, memref<128xf64>, memref<128xf64>) -> memref<128xf64> + return %12 : memref<128xf64> + } +} + diff --git a/examples/mlir-dumps/matvec-dot.mlir b/examples/mlir-dumps/matvec-dot.mlir new file mode 100644 index 00000000..5e29ff30 --- /dev/null +++ b/examples/mlir-dumps/matvec-dot.mlir @@ -0,0 +1,28 @@ +#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 32, crdWidth = 32 }> + +module { + func.func private @spmv(%A: tensor, %x: tensor, %ydst: tensor) -> tensor { + %y = linalg.matvec ins(%A, %x: tensor, tensor) outs(%ydst : tensor) -> tensor + return %y : tensor + } + + func.func @dot(%x : tensor, %y : tensor, %res : tensor) -> + tensor attributes { noinline } { + %dot = linalg.dot ins(%x, %y: tensor, tensor) + outs(%res: tensor) -> tensor + return %dot: tensor + } + + func.func @main(%A : tensor, %x : tensor, %y : 
tensor) + -> f64 { + %0 = func.call @spmv(%A, %x, %y) { noinline, fuse_with = "dot" } : + (tensor, tensor, tensor) -> tensor + + %dot_res = tensor.empty() : tensor + %1 = func.call @dot(%0, %x, %dot_res) { noinline, fuse_with = "spmv" } : + (tensor, tensor, tensor) -> tensor + + %ret = tensor.extract %1[] : tensor + return %ret : f64 + } +} diff --git a/examples/mlir-dumps/mmv.mlir b/examples/mlir-dumps/mmv.mlir new file mode 100644 index 00000000..4d447d80 --- /dev/null +++ b/examples/mlir-dumps/mmv.mlir @@ -0,0 +1,59 @@ +module { + func.func private @matmul( + %a: tensor<4096x4096xf64>, + %b: tensor<4096x4096xf64>, + %c_out: tensor<4096x4096xf64> + ) -> tensor<4096x4096xf64> { + %c = linalg.matmul ins(%a, %b: tensor<4096x4096xf64>, tensor<4096x4096xf64>) + outs(%c_out: tensor<4096x4096xf64>) -> tensor<4096x4096xf64> + return %c : tensor<4096x4096xf64> + } + + func.func private @matvec( + %a: tensor<4096x4096xf64>, + %x: tensor<4096xf64>, + %y_out: tensor<4096xf64> + ) -> tensor<4096xf64> { + %y = linalg.matvec ins(%a, %x: tensor<4096x4096xf64>, tensor<4096xf64>) + outs(%y_out: tensor<4096xf64>) -> tensor<4096xf64> + + return %y : tensor<4096xf64> + } + + func.func @matmul_into_matvec( + %a: tensor<4096x4096xf64>, + %b: tensor<4096x4096xf64>, + %x: tensor<4096xf64> + ) -> tensor<4096xf64> { + + %c_init = tensor.empty() : tensor<4096x4096xf64> + %c = func.call @matmul(%a, %b, %c_init) { fuse_with = "matvec" } + : (tensor<4096x4096xf64>, tensor<4096x4096xf64>, tensor<4096x4096xf64>) + -> tensor<4096x4096xf64> + + %y_init = tensor.empty() : tensor<4096xf64> + %y_out = func.call @matvec(%c, %x, %y_init) { fuse_with = "matmul" } + : (tensor<4096x4096xf64>, tensor<4096xf64>, tensor<4096xf64>) + -> tensor<4096xf64> + + return %y_out : tensor<4096xf64> + } + + func.func @matvec_into_matvec( + %a: tensor<4096x4096xf64>, + %b: tensor<4096x4096xf64>, + %x: tensor<4096xf64> + ) -> tensor<4096xf64> { + %bx_init = tensor.empty() : tensor<4096xf64> + %bx = func.call @matvec(%b, %x, %bx_init) { fuse_with = "matvec" } + : (tensor<4096x4096xf64>, tensor<4096xf64>, tensor<4096xf64>) + -> tensor<4096xf64> + + %y_init = tensor.empty() : tensor<4096xf64> + %y_out = func.call @matvec(%a, %bx, %y_init) { fuse_with = "matvec" } + : (tensor<4096x4096xf64>, tensor<4096xf64>, tensor<4096xf64>) + -> tensor<4096xf64> + + return %y_out : tensor<4096xf64> + } +} diff --git a/examples/mlir-dumps/pcg.mlir b/examples/mlir-dumps/pcg.mlir new file mode 100644 index 00000000..7e097a2b --- /dev/null +++ b/examples/mlir-dumps/pcg.mlir @@ -0,0 +1,97 @@ +#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 32, crdWidth = 32 }> +#idmap = affine_map<(d0) -> (d0)> +module { + func.func private @spmv(%A: tensor, %x: tensor, %ydst: tensor) -> tensor { + %y = linalg.matvec ins(%A, %x: tensor, tensor) outs(%ydst : tensor) -> tensor + return %y : tensor + } + + func.func private @dot(%x: tensor, %y: tensor) -> f64 { + %0 = tensor.empty() : tensor + %dot = linalg.dot ins(%x, %y : tensor,tensor) outs(%0: tensor) -> tensor + %6 = tensor.extract %dot[] : tensor + return %6: f64 + } + + func.func private @axpby(%a: f64, %x: tensor, %b: f64, %y: tensor, %dst: tensor) -> tensor { + %1 = linalg.generic {indexing_maps = [#idmap, #idmap, #idmap], iterator_types = ["parallel"]} ins(%x, %y: tensor, tensor) outs(%dst : tensor) { + ^bb0(%inx: f64, %iny: f64, %out: f64): + %4 = arith.mulf %inx, %a: f64 + %5 = arith.mulf %iny, %b: f64 + %6 = arith.addf %4, %5: f64 + linalg.yield %6 : f64 + } -> tensor + return 
%1 : tensor + } + + func.func private @mult(%x: tensor, %y: tensor, %dst: tensor) -> tensor { + %1 = linalg.generic {indexing_maps = [#idmap, #idmap, #idmap], iterator_types = ["parallel"]} ins(%x, %y: tensor, tensor) outs(%dst : tensor) { + ^bb0(%inx: f64, %iny: f64, %out: f64): + %2 = arith.mulf %inx, %iny: f64 + linalg.yield %2 : f64 + } -> tensor + return %1 : tensor + } + + // CG solve with diagonal preconditioner + // Returns: x, numiter, resnorm + func.func @pcg(%A: tensor, %b: tensor, %dinv: tensor, %tol: f64, %maxiter: index) -> (tensor, index, f64) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %n = tensor.dim %b, %c0 : tensor + %f0 = arith.constant 0.0 : f64 + %f1 = arith.constant 1.0 : f64 + %fm1 = arith.constant -1.0 : f64 + + // Preallocate some intermediate tensors for dst-passing style + %buf0 = tensor.empty(%n) : tensor + %buf1 = tensor.empty(%n) : tensor + %buf2 = tensor.empty(%n) : tensor + + // Assume initial guess x0 = 0 + // Then r0 = b - A*x0 = b + %r0 = linalg.copy ins(%b : tensor) outs(%buf0 : tensor) -> tensor + %z0 = func.call @mult(%r0, %dinv, %buf1) : (tensor, tensor, tensor) -> tensor + %p0 = linalg.copy ins(%z0 : tensor) outs(%buf2 : tensor) -> tensor + %x0 = tensor.splat %f0[%n] : tensor + %Apbuf = tensor.empty(%n) : tensor + %rr0 = func.call @dot(%r0, %r0) : (tensor, tensor) -> f64 + %initres = math.sqrt %rr0 : f64 + + %x, %p, %z, %r, %final_relres, %rz, %iters = scf.while (%xiter = %x0, %piter = %p0, %ziter = %z0, %riter = %r0, %rziter = %f0, %i = %c1) : (tensor, tensor, tensor, tensor, f64, index) -> (tensor, tensor, tensor, tensor, f64, f64, index) + { + %Ap = func.call @spmv(%A, %piter, %Apbuf) { fuse_with = "dot" } : (tensor, tensor, tensor) -> tensor + %pAp = func.call @dot(%Ap, %piter) { fuse_with = "spmv" } : (tensor, tensor) -> f64 + %rz = func.call @dot(%riter, %ziter) : (tensor, tensor) -> f64 + %alpha = arith.divf %rz, %pAp : f64 + %malpha = arith.negf %alpha : f64 + + // Update x and r + %xnext = func.call @axpby(%f1, %xiter, %alpha, %piter, %xiter) { fuse_with = "axpby" } : (f64, tensor, f64, tensor, tensor) -> tensor + %rnext = func.call @axpby(%f1, %riter, %malpha, %Ap, %riter) { fuse_with = "axpby" } : (f64, tensor, f64, tensor, tensor) -> tensor + + // Test against tolerance and + %rr = func.call @dot(%rnext, %rnext) : (tensor, tensor) -> f64 + %rnorm = math.sqrt %rr : f64 + %relres = arith.divf %rnorm, %initres : f64 + %not_converged = arith.cmpf ogt, %relres, %tol : f64 + + // we have already completed an iteration, which is why i is intially 1 + %below_maxiter = arith.cmpi ne, %i, %maxiter : index + %continue = arith.andi %not_converged, %below_maxiter : i1 + + scf.condition(%continue) %xnext, %piter, %ziter, %rnext, %relres, %rz, %i: tensor, tensor, tensor, tensor, f64, f64, index + } + do { + ^bb0(%xiter : tensor, %piter : tensor, %ziter : tensor, %riter : tensor, %unused : f64, %rziter : f64, %i : index): + %znext = func.call @mult(%riter, %dinv, %ziter) : (tensor, tensor, tensor) -> tensor + %rznext = func.call @dot(%riter, %znext) : (tensor, tensor) -> f64 + %beta = arith.divf %rznext, %rziter : f64 + %pnext = func.call @axpby(%f1, %znext, %beta, %piter, %piter) : (f64, tensor, f64, tensor, tensor) -> tensor + %inext = arith.addi %i, %c1 : index + scf.yield %xiter, %pnext, %znext, %riter, %rznext, %inext : tensor, tensor, tensor, tensor, f64, index + } + return %x, %iters, %final_relres : tensor, index, f64 + } +} + diff --git a/examples/mlir-dumps/tp-dg-stiffness-dynamic-size.mlir 
b/examples/mlir-dumps/tp-dg-stiffness-dynamic-size.mlir new file mode 100644 index 00000000..c92b272a --- /dev/null +++ b/examples/mlir-dumps/tp-dg-stiffness-dynamic-size.mlir @@ -0,0 +1,58 @@ +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2, d3)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d4)> +#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)> +#map6 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)> +module { + func.func @compute_reference_dx(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %1 = arith.mulf %in_0, %in : f64 + %2 = arith.addf %out, %1 : f64 + linalg.yield %2 : f64 + } -> tensor + return %0 : tensor + } + + func.func @compute_reference_dy(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = linalg.generic {indexing_maps = [#map3, #map4, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %1 = arith.mulf %in_0, %in : f64 + %2 = arith.addf %out, %1 : f64 + linalg.yield %2 : f64 + } -> tensor + return %0 : tensor + } + + func.func @compute_reference_dz(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = linalg.generic {indexing_maps = [#map5, #map6, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %1 = arith.mulf %in_0, %in : f64 + %2 = arith.addf %out, %1 : f64 + linalg.yield %2 : f64 + } -> tensor + return %0 : tensor + } + + func.func @main(%arg0: index, %arg1: index) -> (tensor, + tensor, tensor) attributes {llvm.emit_c_interface} { + %0 = tensor.empty(%arg0, %arg1, %arg1, %arg1) : tensor + %1 = tensor.empty(%arg1, %arg1) : tensor + %2 = tensor.empty(%arg0, %arg1, %arg1, %arg1) : tensor + %3 = tensor.empty(%arg0, %arg1, %arg1, %arg1) : tensor + %4 = tensor.empty(%arg0, %arg1, %arg1, %arg1) : tensor + + %5 = call @compute_reference_dx(%0, %1, %2) { fuse_with = + "compute_reference_dy, compute_reference_dz" } : (tensor, tensor, tensor) -> tensor + %6 = call @compute_reference_dy(%0, %1, %3) { fuse_with = + "compute_reference_dx, compute_reference_dz" } : (tensor, tensor, tensor) -> tensor + %7 = call @compute_reference_dz(%0, %1, %4) { fuse_with = + "compute_reference_dx, compute_reference_dy" }: (tensor, tensor, tensor) -> tensor + + return %5, %6, %7 : + tensor, tensor, tensor + } +} + diff --git a/examples/mlir-dumps/tp-dg-stiffness-static-size.mlir b/examples/mlir-dumps/tp-dg-stiffness-static-size.mlir new file mode 100644 index 00000000..38b7206b --- /dev/null +++ b/examples/mlir-dumps/tp-dg-stiffness-static-size.mlir @@ -0,0 +1,58 @@ +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2, d3)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d4)> +#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)> +#map6 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)> 
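+// The three generics below share the output map #map2 and reduce over d4:
+// compute_reference_dx contracts the 2-D operand (#map1) against axis 1 of the
+// rank-4 field, dy (#map4) against axis 2, and dz (#map6) against axis 3
+// (0-based axes).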
+module { + func.func @compute_reference_dx(%arg0: tensor<100x100x100x100xf64>, %arg1: tensor<100x100xf64>, %arg2: tensor<100x100x100x100xf64>) -> tensor<100x100x100x100xf64> { + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<100x100x100x100xf64>, tensor<100x100xf64>) outs(%arg2 : tensor<100x100x100x100xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %1 = arith.mulf %in_0, %in : f64 + %2 = arith.addf %out, %1 : f64 + linalg.yield %2 : f64 + } -> tensor<100x100x100x100xf64> + return %0 : tensor<100x100x100x100xf64> + } + + func.func @compute_reference_dy(%arg0: tensor<100x100x100x100xf64>, %arg1: tensor<100x100xf64>, %arg2: tensor<100x100x100x100xf64>) -> tensor<100x100x100x100xf64> { + %0 = linalg.generic {indexing_maps = [#map3, #map4, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<100x100x100x100xf64>, tensor<100x100xf64>) outs(%arg2 : tensor<100x100x100x100xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %1 = arith.mulf %in_0, %in : f64 + %2 = arith.addf %out, %1 : f64 + linalg.yield %2 : f64 + } -> tensor<100x100x100x100xf64> + return %0 : tensor<100x100x100x100xf64> + } + + func.func @compute_reference_dz(%arg0: tensor<100x100x100x100xf64>, %arg1: tensor<100x100xf64>, %arg2: tensor<100x100x100x100xf64>) -> tensor<100x100x100x100xf64> { + %0 = linalg.generic {indexing_maps = [#map5, #map6, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<100x100x100x100xf64>, tensor<100x100xf64>) outs(%arg2 : tensor<100x100x100x100xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %1 = arith.mulf %in_0, %in : f64 + %2 = arith.addf %out, %1 : f64 + linalg.yield %2 : f64 + } -> tensor<100x100x100x100xf64> + return %0 : tensor<100x100x100x100xf64> + } + + func.func @main(%arg0: index, %arg1: index) -> (tensor<100x100x100x100xf64>, + tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64>) attributes {llvm.emit_c_interface} { + %0 = tensor.empty() : tensor<100x100x100x100xf64> + %1 = tensor.empty() : tensor<100x100xf64> + %2 = tensor.empty() : tensor<100x100x100x100xf64> + %3 = tensor.empty() : tensor<100x100x100x100xf64> + %4 = tensor.empty() : tensor<100x100x100x100xf64> + + %5 = call @compute_reference_dx(%0, %1, %2) { fuse_with = + "compute_reference_dy, compute_reference_dz" } : (tensor<100x100x100x100xf64>, tensor<100x100xf64>, tensor<100x100x100x100xf64>) -> tensor<100x100x100x100xf64> + %6 = call @compute_reference_dy(%0, %1, %3) { fuse_with = + "compute_reference_dx, compute_reference_dz" } : (tensor<100x100x100x100xf64>, tensor<100x100xf64>, tensor<100x100x100x100xf64>) -> tensor<100x100x100x100xf64> + %7 = call @compute_reference_dz(%0, %1, %4) { fuse_with = + "compute_reference_dx, compute_reference_dy" }: (tensor<100x100x100x100xf64>, tensor<100x100xf64>, tensor<100x100x100x100xf64>) -> tensor<100x100x100x100xf64> + + return %5, %6, %7 : + tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64> + } +} + diff --git a/examples/pcg_solve.py b/examples/pcg_solve.py index d0bbcf4d..40189684 100644 --- a/examples/pcg_solve.py +++ b/examples/pcg_solve.py @@ -56,10 +56,12 @@ %f0 = arith.constant 0.0 : f64 %f1 = arith.constant 1.0 : f64 %fm1 = arith.constant -1.0 : f64 + // Preallocate some intermediate tensors for dst-passing style %buf0 = tensor.empty(%n) : tensor %buf1 = tensor.empty(%n) : tensor %buf2 = 
tensor.empty(%n) : tensor + // Assume initial guess x0 = 0 // Then r0 = b - A*x0 = b %r0 = linalg.copy ins(%b : tensor) outs(%buf0 : tensor) -> tensor @@ -69,26 +71,34 @@ %Apbuf = tensor.empty(%n) : tensor %rr0 = func.call @dot(%r0, %r0) : (tensor, tensor) -> f64 %initres = math.sqrt %rr0 : f64 + %x, %p, %z, %r, %final_relres, %rz, %iters = scf.while (%xiter = %x0, %piter = %p0, %ziter = %z0, %riter = %r0, %rziter = %f0, %i = %c1) : (tensor, tensor, tensor, tensor, f64, index) -> (tensor, tensor, tensor, tensor, f64, f64, index) { - %Ap = func.call @spmv(%A, %piter, %Apbuf) : (tensor, tensor, tensor) -> tensor - %pAp = func.call @dot(%piter, %Ap) : (tensor, tensor) -> f64 + %Ap = func.call @spmv(%A, %piter, %Apbuf) { fuse_with = "dot" } : (tensor, tensor, tensor) -> tensor + %pAp = func.call @dot(%piter, %Ap) { fuse_with = "spmv" } : (tensor, tensor) -> f64 %rz = func.call @dot(%riter, %ziter) : (tensor, tensor) -> f64 %alpha = arith.divf %rz, %pAp : f64 %malpha = arith.negf %alpha : f64 + // Update x and r - %xnext = func.call @axpby(%f1, %xiter, %alpha, %piter, %xiter) : (f64, tensor, f64, tensor, tensor) -> tensor - %rnext = func.call @axpby(%f1, %riter, %malpha, %Ap, %riter) : (f64, tensor, f64, tensor, tensor) -> tensor + %xnext = func.call @axpby(%f1, %xiter, %alpha, %piter, %xiter) { fuse_with + = "axpby" } : (f64, tensor, f64, tensor, tensor) -> tensor + %rnext = func.call @axpby(%f1, %riter, %malpha, %Ap, %riter) { fuse_with = + "axpby" } : (f64, tensor, f64, tensor, tensor) -> tensor + // Test against tolerance and %rr = func.call @dot(%rnext, %rnext) : (tensor, tensor) -> f64 %rnorm = math.sqrt %rr : f64 %relres = arith.divf %rnorm, %initres : f64 %not_converged = arith.cmpf ogt, %relres, %tol : f64 + // we have already completed an iteration, which is why i is intially 1 %below_maxiter = arith.cmpi ne, %i, %maxiter : index %continue = arith.andi %not_converged, %below_maxiter : i1 + scf.condition(%continue) %xnext, %piter, %ziter, %rnext, %relres, %rz, %i: tensor, tensor, tensor, tensor, f64, f64, index - } do { + } + do { ^bb0(%xiter : tensor, %piter : tensor, %ziter : tensor, %riter : tensor, %unused : f64, %rziter : f64, %i : index): %znext = func.call @mult(%riter, %dinv, %ziter) : (tensor, tensor, tensor) -> tensor %rznext = func.call @dot(%riter, %znext) : (tensor, tensor) -> f64 @@ -121,7 +131,7 @@ def main(): reltol = 1e-10 maxiter = 40 - backend = KokkosBackend.KokkosBackend(decompose_tensors=True) + backend = KokkosBackend.KokkosBackend(decompose_tensors=True, dump_mlir=True) module_kokkos = backend.compile(moduleText) print("x exact solution (first 10 elements):", xgold[:10]) diff --git a/mlir/include/CMakeLists.txt b/mlir/include/CMakeLists.txt index f351ec59..753e8e7a 100644 --- a/mlir/include/CMakeLists.txt +++ b/mlir/include/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(lapis) +add_subdirectory(Transform) diff --git a/mlir/include/Transform/CMakeLists.txt b/mlir/include/Transform/CMakeLists.txt new file mode 100644 index 00000000..b7ffe0dc --- /dev/null +++ b/mlir/include/Transform/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Kernel) diff --git a/mlir/include/Transform/Kernel/CMakeLists.txt b/mlir/include/Transform/Kernel/CMakeLists.txt new file mode 100644 index 00000000..08235ec0 --- /dev/null +++ b/mlir/include/Transform/Kernel/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS KernelPasses.td) +mlir_tablegen(KernelPasses.h.inc -gen-pass-decls) +add_public_tablegen_target(MLIRKernelPassesIncGen) diff --git 
a/mlir/include/Transform/Kernel/KernelFusionDriver.h b/mlir/include/Transform/Kernel/KernelFusionDriver.h
new file mode 100644
index 00000000..3ebc8909
--- /dev/null
+++ b/mlir/include/Transform/Kernel/KernelFusionDriver.h
@@ -0,0 +1,19 @@
+#ifndef KERNEL_FUSION_DRIVER_H
+#define KERNEL_FUSION_DRIVER_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+
+namespace mlir {
+namespace kernel {
+
+#define GEN_PASS_DECL_KERNELFUSIONDRIVER
+#include "Transform/Kernel/KernelPasses.h.inc"
+
+}
+}
+
+
+#endif
diff --git a/mlir/include/Transform/Kernel/KernelFusionPass.h b/mlir/include/Transform/Kernel/KernelFusionPass.h
new file mode 100644
index 00000000..fe7c8c55
--- /dev/null
+++ b/mlir/include/Transform/Kernel/KernelFusionPass.h
@@ -0,0 +1,15 @@
+#ifndef KERNEL_FUSION_PASS_H
+#define KERNEL_FUSION_PASS_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace kernel {
+
+#define GEN_PASS_DECL_KERNELFUSIONPASS
+#include "Transform/Kernel/KernelPasses.h.inc"
+
+}
+}
+
+#endif
diff --git a/mlir/include/Transform/Kernel/KernelPasses.h b/mlir/include/Transform/Kernel/KernelPasses.h
new file mode 100644
index 00000000..52d28689
--- /dev/null
+++ b/mlir/include/Transform/Kernel/KernelPasses.h
@@ -0,0 +1,29 @@
+#ifndef FUNC_CUSTOM_PASSES_H
+#define FUNC_CUSTOM_PASSES_H
+
+#include "Transform/Kernel/KernelFusionPass.h"
+#include "Transform/Kernel/KernelFusionDriver.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+
+namespace mlir {
+namespace kernel {
+
+#define GEN_PASS_DECL_FUSEDKERNELINLININGPASS
+std::unique_ptr<Pass> createFusedKernelInliningPass();
+
+#define GEN_PASS_DECL_KERNELDOMAINFUSIONPASS
+std::unique_ptr<Pass> createKernelDomainFusionPass();
+
+#define GEN_PASS_DECL_LINALGGENERICREORDERINGPASS
+std::unique_ptr<Pass> createLinalgGenericReorderingPass();
+
+#define GEN_PASS_REGISTRATION
+#include "Transform/Kernel/KernelPasses.h.inc"
+
+}
+}
+
+#endif
diff --git a/mlir/include/Transform/Kernel/KernelPasses.td b/mlir/include/Transform/Kernel/KernelPasses.td
new file mode 100644
index 00000000..ba1d0cff
--- /dev/null
+++ b/mlir/include/Transform/Kernel/KernelPasses.td
@@ -0,0 +1,78 @@
+#ifndef KERNEL_FUSION_PASS
+#define KERNEL_FUSION_PASS
+
+include "mlir/Pass/PassBase.td"
+
+def KernelFusionPass: Pass<"kernel-fusion-pass"> {
+  let summary = "Fuses related subkernels into a single kernel.";
+  let description = [{
+    Fuses kernels by examining the arguments and results of kernel calls.
+    Multiple calls to the same kernel are treated as unique instantiations of a
+    kernel.
+
+    Once related calls are identified, they are grouped into fusion sets. Each
+    subkernel call in a set is then moved into a new fused kernel. Afterward, a
+    custom inlining pass is run *only* over the fused kernels, and the unused
+    kernel definitions are removed.
+
+    Once this pass is finished, the new fused kernels are ready for optimization
+    and further lowering.
+  }];
+}
+
+def KernelFusionDriver: Pass<"drive-kernel-fusion"> {
+  let summary = "Drive a kernel fusion pass from unfused to optimized";
+  let description = [{
+    Runs all relevant passes to fuse, lower, and run a default optimization pass
+    over a program containing computational kernels.
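+
+    Programs opt in to fusion by tagging func.call sites with a `fuse_with`
+    attribute naming the kernel(s) to fuse with; both calls must name each
+    other. For example (adapted from examples/mlir-dumps/mmv.mlir added in
+    this change, trailing types elided):
+
+      %c = func.call @matmul(%a, %b, %c_init) { fuse_with = "matvec" } : ...
+      %y_out = func.call @matvec(%c, %x, %y_init) { fuse_with = "matmul" } : ...
+
+    In the LAPIS pipeline, this driver is enabled with the `kernel-fusion`
+    option of LapisCompilerOptions.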
+  }];
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "LLVM::LLVMDialect",
+    "sparse_tensor::SparseTensorDialect",
+    "bufferization::BufferizationDialect"
+  ];
+}
+
+def FusedKernelInliningPass: Pass<"fused-kernel-inlining-pass"> {
+  let summary = "Run a custom inlining pass over fused kernels";
+  let constructor = "createFusedKernelInliningPass()";
+  let description = [{
+    Runs a restricted version of the MLIR inliner. During the KernelFusion pass,
+    subkernel calls are tagged with an "inline" attribute. This inlining pass
+    *only* inlines calls with this attribute and *always* inlines them.
+
+    This pass is meant to run directly after KernelFusion. However, it will work
+    on any calls tagged with "inline", so it can be used outside of the context
+    of kernel fusion.
+  }];
+}
+
+def KernelDomainFusion: Pass<"fused-kernel-domain-fusion-pass"> {
+  let summary = "Fuse the parallel loops of subkernels in a fused kernel";
+  let constructor = "createKernelDomainFusionPass()";
+  let description = [{
+    Fuses the domains (i.e. the parallel loops) of subkernels in a fused kernel.
+    Uses a custom version of the existing SCF parallel loop fusion pass. This
+    custom pass *does not* consider aliasing. Aliasing is explicitly prohibited
+    for now.
+  }];
+}
+
+def LinalgGenericReorderingPass: Pass<"reorder-linalg-generics"> {
+  let summary = "Use the Einstein summation convention to minimize cost";
+  let constructor = "createLinalgGenericReorderingPass()";
+  let description = [{
+    Interprets appropriate linalg.GenericOps as einsums, determines a
+    minimum-cost contraction order of operands, and creates new
+    linalg.GenericOps based on that minimum-cost contraction order. Requires
+    that bodies of all linalg.GenericOps contain only a single multiplication
+    over parallel axes, and a single addition over reduction axes with a single
+    return value.
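+
+    As a rough illustration: for the chain einsum "ij,jk,k->i" over n x n
+    matrices A, B and an n-vector x, the contraction order (A*B)*x costs
+    O(n^3) multiply-adds while the reordered A*(B*x) costs O(n^2), so the
+    pass would pick the latter order (cf. examples/mlir-dumps/mmv.mlir in
+    this change).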
+  }];
+  let dependentDialects = [
+    "arith::ArithDialect"
+  ];
+}
+
+#endif
diff --git a/mlir/include/lapis/Dialect/Kokkos/Pipelines/Passes.h b/mlir/include/lapis/Dialect/Kokkos/Pipelines/Passes.h
index 23496be7..d24080d6 100644
--- a/mlir/include/lapis/Dialect/Kokkos/Pipelines/Passes.h
+++ b/mlir/include/lapis/Dialect/Kokkos/Pipelines/Passes.h
@@ -59,6 +59,17 @@ struct LapisCompilerOptions
       *this, "decompose-sparse-tensors",
       desc("Decompose sparse tensors into memrefs (default off)"),
       init(false)};
+  PassOptions::Option<bool> kernel_fusion{
+      *this, "kernel-fusion",
+      desc("Attempt to fuse kernels based on kernel call fuse_with attributes"),
+      init(false)};
+
+  PassOptions::Option<bool> reorder_linalg_generics{
+      *this, "reorder-linalg-generics",
+      desc("Reorder tensor contraction-like generic operations to minimize "
+           "asymptotic computational cost"),
+      init(false)};
+
 #ifdef LAPIS_ENABLE_PART_TENSOR
   PassOptions::Option partTensorBackend{
       *this, "pt-backend",
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
index 0f42e553..16d60e65 100644
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -2,3 +2,4 @@ add_subdirectory(Dialect)
 add_subdirectory(ExecutionEngine)
 add_subdirectory(CAPI)
 add_subdirectory(Target)
+add_subdirectory(Transform)
diff --git a/mlir/lib/Dialect/Kokkos/Pipelines/KokkosPipelines.cpp b/mlir/lib/Dialect/Kokkos/Pipelines/KokkosPipelines.cpp
index 59987751..e7068156 100644
--- a/mlir/lib/Dialect/Kokkos/Pipelines/KokkosPipelines.cpp
+++ b/mlir/lib/Dialect/Kokkos/Pipelines/KokkosPipelines.cpp
@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "Transform/Kernel/KernelFusionDriver.h"
+#include "Transform/Kernel/KernelPasses.h"
 #include "lapis/LAPIS_config.h"
 #include "lapis/Dialect/Kokkos/Pipelines/Passes.h"
 #include "mlir/Conversion/Passes.h"
@@ -55,6 +57,11 @@ void mlir::kokkos::buildSparseKokkosCompiler(
   // Rewrite named linalg ops into generic ops and apply fusion.
   pm.addNestedPass(createLinalgGeneralizeNamedOpsPass());
 
+  if (options.kernel_fusion)
+    pm.addPass(kernel::createKernelFusionDriver());
+  if (options.reorder_linalg_generics)
+    pm.addPass(kernel::createLinalgGenericReorderingPass());
+
   // Remove compile-time unit extent dimensions from linalg ops.
   // For example, a 3D loop over (N, M, 1) will be rewritten to 2D loop over (N, M).
// This does not affect tensor types, at least in function parameter/return types, @@ -134,6 +141,7 @@ void mlir::kokkos::buildSparseKokkosCompiler( pm.addNestedPass(createDenseLinalgToParallelLoopsPass()); // The built-in lowering will take care of any remaining linalg ops pm.addNestedPass(createConvertLinalgToParallelLoopsPass()); + // pm.addNestedPass(kernel::createKernelDomainFusionPass()); // pm.addNestedPass(arith::createArithExpandOpsPass()); pm.addPass(memref::createExpandStridedMetadataPass()); diff --git a/mlir/lib/Transform/CMakeLists.txt b/mlir/lib/Transform/CMakeLists.txt new file mode 100644 index 00000000..b7ffe0dc --- /dev/null +++ b/mlir/lib/Transform/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Kernel) diff --git a/mlir/lib/Transform/Kernel/CMakeLists.txt b/mlir/lib/Transform/Kernel/CMakeLists.txt new file mode 100644 index 00000000..a570a0ab --- /dev/null +++ b/mlir/lib/Transform/Kernel/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_library(MLIRKernelPasses + KernelFusionPass.cpp + KernelFusionDriver.cpp + KernelDomainFusion.cpp + FusedKernelInliningPass.cpp + LinalgGenericReordering.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/Transform/Kernel + + DEPENDS + MLIRKernelPassesIncGen + MLIRFuncInlinerExtension + + LINK_LIBS PUBLIC +) diff --git a/mlir/lib/Transform/Kernel/FusedKernelInliningPass.cpp b/mlir/lib/Transform/Kernel/FusedKernelInliningPass.cpp new file mode 100644 index 00000000..6378cc4a --- /dev/null +++ b/mlir/lib/Transform/Kernel/FusedKernelInliningPass.cpp @@ -0,0 +1,135 @@ +/* Modified version of InlinerPass. Uses a custom profitability callback + * function so that the inliner only operates on Operations that are tagged with + * an "inline" attribute. Everything else, aside from some slight renaming, is + * entirely the same as the built-in MLIR inliner. + * + * Exists so that we can still benefit from upstream MLIR changes without + * worrying too much about conflicts. + * + * Ideally, a PR into upstream MLIR would be created that adds a stub in the + * inliner that takes a profitability callback function on pass creation. This + * way, many different custom inliners can be run over the same program. 
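+ *
+ * In this patch the "inline" attribute is attached by KernelFusionPass to each
+ * subkernel call it moves into a fused kernel, while the call to the fused
+ * kernel itself is tagged "noinline" so it is not inlined back into its
+ * caller (see KernelFusionPass.cpp).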
+*/ + +#include "mlir/Analysis/CallGraph.h" +#include "mlir/Transforms/Inliner.h" +#include "mlir/Transforms/Passes.h" + +#include "Transform/Kernel/KernelPasses.h" + +namespace mlir { + +static void defaultInlinerOptPipeline(OpPassManager &pm) { + pm.addPass(createCanonicalizerPass()); +} + +#define GEN_PASS_DEF_INLINER +#include "mlir/Transforms/Passes.h.inc" + +namespace kernel { + +using PipelineTy = std::function; +using OpPipelinesTy = llvm::StringMap; + +#define GEN_PASS_DEF_FUSEDKERNELINLININGPASS +#include "Transform/Kernel/KernelPasses.h.inc" + +class FusedKernelInliningPass : public mlir::impl::InlinerBase { +public: + FusedKernelInliningPass(); + FusedKernelInliningPass(const FusedKernelInliningPass &) = default; + FusedKernelInliningPass(PipelineTy defaultPipeline); + FusedKernelInliningPass(PipelineTy defaultPipeline, OpPipelinesTy opPipelines); + + void runOnOperation() override; + + static LogicalResult runPipelineHelper(Pass &pass, OpPassManager &pipeline, + Operation *op) { + return mlir::cast(pass).runPipeline(pipeline, op); + } + +private: + LogicalResult initializeOptions( + StringRef options, + function_ref errorHandler) override; + + InlinerConfig config; +}; + +// constructor definitions +FusedKernelInliningPass::FusedKernelInliningPass() + : FusedKernelInliningPass(defaultInlinerOptPipeline) {} +FusedKernelInliningPass::FusedKernelInliningPass(PipelineTy defaultPipelineArg) + : FusedKernelInliningPass(std::move(defaultPipelineArg), OpPipelinesTy{}) {} +FusedKernelInliningPass::FusedKernelInliningPass(PipelineTy defaultPipeline, + OpPipelinesTy opPipelines) + : config(std::move(defaultPipeline), maxInliningIterations) { + if (opPipelines.empty()) + return; + + for (auto &it : opPipelines) + opPipelineList.addValue(it.second); + config.setOpPipelines(std::move(opPipelines)); +} + +// adapted cost model function; only inline kernels that are tagged for inlining +static bool isProfitableToInline(const Inliner::ResolvedCall &resolvedCall) { + return resolvedCall.call->hasAttr("inline") || + !resolvedCall.call->hasAttr("noinline"); +} + +void FusedKernelInliningPass::runOnOperation() { + CallGraph &cg = getAnalysis(); + + Operation *op = getOperation(); + if (!op->hasTrait()) { + op->emitError() << " was scheduled to be run under the inliner, but does " + << "define a symbol table"; + return signalPassFailure(); + } + + auto profitabilityCb = [=](const Inliner::ResolvedCall &resolvedCall) { + return isProfitableToInline(resolvedCall); + }; + + Inliner inliner(op, cg, *this, getAnalysisManager(), runPipelineHelper, + config, profitabilityCb); + + if(failed(inliner.doInlining())) + return signalPassFailure(); +} + +LogicalResult FusedKernelInliningPass::initializeOptions( + StringRef options, + function_ref errorHandler) { + if(failed(Pass::initializeOptions(options, errorHandler))) + return failure(); + + if (!defaultPipelineStr.empty()) { + std::string defaultPipelineCopy = defaultPipelineStr; + config.setDefaultPipeline([=](OpPassManager &pm) { + (void)parsePassPipeline(defaultPipelineCopy, pm); + }); + } + else if (defaultPipelineStr.getNumOccurrences()) { + config.setDefaultPipeline(nullptr); + } + + OpPipelinesTy pipelines; + for (OpPassManager pipeline : opPipelineList) + if (!pipeline.empty()) + pipelines.try_emplace(pipeline.getOpAnchorName(), pipeline); + config.setOpPipelines(std::move(pipelines)); + + config.setMaxInliningIterations(maxInliningIterations); + + return success(); +} + +std::unique_ptr createFusedKernelInliningPass() { + return std::make_unique(); 
+} + +} // namespace kernel +} // namespace mlir + diff --git a/mlir/lib/Transform/Kernel/KernelDomainFusion.cpp b/mlir/lib/Transform/Kernel/KernelDomainFusion.cpp new file mode 100644 index 00000000..a2112918 --- /dev/null +++ b/mlir/lib/Transform/Kernel/KernelDomainFusion.cpp @@ -0,0 +1,223 @@ +/* An adjusted version of SCF parallel loop fusion. + * + * The largest difference is that this version *does not* check for aliasing. + * Additionally, the input expected is different from what is expected in + * vanilla SCF parallel loop fusion. + */ + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" + +#include "Transform/Kernel/KernelPasses.h" + +namespace mlir { + +#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION +#include "mlir/Dialect/SCF/Transforms/Passes.h.inc" +namespace kernel { + +using ParallelOp = scf::ParallelOp; +using ReduceOp = scf::ReduceOp; + +std::optional> getConstBounds(OperandRange bounds) { + return getConstantIntValues(getAsOpFoldResult(SmallVector(bounds))); +} + +static bool haveNoReadsAfterWriteExceptSameIndex( + ParallelOp firstPloop, ParallelOp secondPloop, + const IRMapping &firstToSecondPloopIndices) { + DenseMap> bufferStores; + SmallVector bufferStoresVec; + firstPloop.getBody()->walk([&](memref::StoreOp store) { + bufferStores[store.getMemRef()].push_back(store.getIndices()); + bufferStoresVec.emplace_back(store.getMemRef()); + }); + auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { + Value loadMem = load.getMemRef(); + auto *memrefDef = loadMem.getDefiningOp(); + if (memrefDef && memrefDef->getBlock() == load->getBlock()) + return WalkResult::interrupt(); + + auto write = bufferStores.find(loadMem); + if (write == bufferStores.end()) + return WalkResult::advance(); + + if (!write->second.size()) + return WalkResult::interrupt(); + + auto storeIndices = write->second.front(); + + for (const auto &othStoreIndices : write->second) { + if (othStoreIndices != storeIndices) + return WalkResult::interrupt(); + } + + auto loadIndices = load.getIndices(); + if (storeIndices.size() != loadIndices.size()) + return WalkResult::interrupt(); + for (int i = 0, e = storeIndices.size(); i < e; ++i) { + if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != + loadIndices[i]) { + auto *storeIndexDefOp = storeIndices[i].getDefiningOp(); + auto *loadIndexDefOp = loadIndices[i].getDefiningOp(); + if (storeIndexDefOp && loadIndexDefOp) { + if (!isMemoryEffectFree(storeIndexDefOp)) + return WalkResult::interrupt(); + if (!isMemoryEffectFree(loadIndexDefOp)) + return WalkResult::interrupt(); + if (!OperationEquivalence::isEquivalentTo( + storeIndexDefOp, loadIndexDefOp, + [&](Value storeIndex, Value loadIndex) { + if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) != + firstToSecondPloopIndices.lookupOrDefault(loadIndex)) + return failure(); + else + return success(); + }, + /*markEquivalent=*/nullptr, + OperationEquivalence::Flags::IgnoreLocations)) { + return WalkResult::interrupt(); + } + } else + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + return !walkResult.wasInterrupted(); +} + +static LogicalResult verifyDependencies(ParallelOp firstPloop, + ParallelOp secondPloop) { + IRMapping firstToSecondPloopIndices; + firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(), + secondPloop.getBody()->getArguments()); + if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop, + firstToSecondPloopIndices)) + return failure(); + + IRMapping 
secondToFirstPloopIndices; + secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), + firstPloop.getBody()->getArguments()); + return success(haveNoReadsAfterWriteExceptSameIndex( + secondPloop, firstPloop, secondToFirstPloopIndices)); +} + +static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop) { + return succeeded(verifyDependencies(firstPloop, secondPloop)); +} + +static bool fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, + OpBuilder builder) { + + if (!isFusionLegal(firstPloop, secondPloop)) + return false; + + DominanceInfo dom; + for (Operation *user : firstPloop->getUsers()) + if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ true)) + return false; + + ValueRange inits1 = firstPloop.getInitVals(); + ValueRange inits2 = secondPloop.getInitVals(); + + SmallVector newInitVars(inits1.begin(), inits1.end()); + newInitVars.append(inits2.begin(), inits2.end()); + + IRRewriter b(builder); + b.setInsertionPoint(secondPloop); + auto newSecondPloop = b.create( + secondPloop.getLoc(), secondPloop.getLowerBound(), + secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); + + Block *block1 = firstPloop.getBody(); + Block *block2 = secondPloop.getBody(); + Block *newBlock = newSecondPloop.getBody(); + auto term1 = cast(block1->getTerminator()); + auto term2 = cast(block2->getTerminator()); + + b.inlineBlockBefore(block2, newBlock, newBlock->begin(), + newBlock->getArguments()); + b.inlineBlockBefore(block1, newBlock, newBlock->begin(), + newBlock->getArguments()); + + ValueRange results = newSecondPloop.getResults(); + if (!results.empty()) { + b.setInsertionPointToEnd(newBlock); + + ValueRange reduceArgs1 = term1.getOperands(); + ValueRange reduceArgs2 = term2.getOperands(); + SmallVector newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); + newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); + + auto newReduceOp = b.create(term2.getLoc(), newReduceArgs); + + for (auto &&[i, reg] : llvm::enumerate(llvm::concat( + term1.getReductions(), term2.getReductions()))) { + Block &oldRedBlock = reg.front(); + Block &newRedBlock = newReduceOp.getReductions()[i].front(); + b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(), + newRedBlock.getArguments()); + } + + firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); + secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); + } + term1->erase(); + term2->erase(); + firstPloop.erase(); + secondPloop.erase(); + secondPloop = newSecondPloop; + + return true; +} + +void naivelyFuseParallelOps(Region ®ion) { + OpBuilder b(region); + SmallVector, 1> ploopChains; + for (auto &block : region) { + ploopChains.clear(); + ploopChains.push_back({}); + + bool noSideEffects = true; + for (auto &op : block) { + if (auto ploop = dyn_cast(op)) { + if (noSideEffects) { + ploopChains.back().push_back(ploop); + } else { + ploopChains.push_back({ploop}); + noSideEffects = true; + } + continue; + } + } + for (MutableArrayRef ploops : ploopChains) { + for (int i = 0, e = ploops.size(); i + 1 < e; ++i) + fuseIfLegal(ploops[i], ploops[i + 1], b); + } + } +} + +#define GEN_PASS_DEF_KERNELDOMAINFUSION +#include "Transform/Kernel/KernelPasses.h.inc" + +struct KernelDomainFusion + : public mlir::impl::SCFParallelLoopFusionBase { + + void runOnOperation() override { + getOperation()->walk([&](Operation *child) { + for (Region ®ion : child->getRegions()) + naivelyFuseParallelOps(region); + }); + + return; + } +}; + +std::unique_ptr createKernelDomainFusionPass() { + 
return std::make_unique(); +} + +} +} diff --git a/mlir/lib/Transform/Kernel/KernelFusionDriver.cpp b/mlir/lib/Transform/Kernel/KernelFusionDriver.cpp new file mode 100644 index 00000000..2716de8d --- /dev/null +++ b/mlir/lib/Transform/Kernel/KernelFusionDriver.cpp @@ -0,0 +1,38 @@ +/* Top level routine for automatic kernel fusion. Creates "fusion sets", i.e. + * sets of subkernels that are to be fused together. Major steps: + */ + +#include "Transform/Kernel/KernelPasses.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace kernel { + +#define GEN_PASS_DEF_KERNELFUSIONDRIVER +#include "Transform/Kernel/KernelPasses.h.inc" + +struct KernelFusionDriver : impl::KernelFusionDriverBase { + using KernelFusionDriverBase::KernelFusionDriverBase; + + void runOnOperation() override { + mlir::ModuleOp module = dyn_cast(getOperation()); + OpPassManager driveKernelFusionPass; + + driveKernelFusionPass.addPass(createKernelFusionPass()); + driveKernelFusionPass.addPass(createFusedKernelInliningPass()); + driveKernelFusionPass.addPass(createLinalgGeneralizeNamedOpsPass()); + + // run the pipeline + if (failed(runPipeline(driveKernelFusionPass, module))) + return signalPassFailure(); + + } +}; +} // namespace kernel +} // namespace mlir diff --git a/mlir/lib/Transform/Kernel/KernelFusionPass.cpp b/mlir/lib/Transform/Kernel/KernelFusionPass.cpp new file mode 100644 index 00000000..44ecdb8c --- /dev/null +++ b/mlir/lib/Transform/Kernel/KernelFusionPass.cpp @@ -0,0 +1,434 @@ +#include +#include +#include + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" + +namespace mlir { +namespace kernel { + +using AllocTensorOp = bufferization::AllocTensorOp; +using EmptyOp = tensor::EmptyOp; +using FuncOp = func::FuncOp; +using CallOp = func::CallOp; +using ReturnOp = func::ReturnOp; +using LinalgOp = linalg::LinalgOp; +using FuncVec = std::vector; +using CallMap = std::map; + +using FusionSet = std::vector; +using FusionSetVector = std::vector; +using ArgsToIdxTy = std::vector>; + +CallMap getCallMap(FuncOp func) { + SmallVector calls; + + func.walk([&calls](func::CallOp call) { + calls.push_back(call); + }); + + CallMap callMap; + for (auto keyCall : calls) { + callMap[keyCall] = FuncVec(); + for (auto valueCall : calls) { + if (keyCall == valueCall) + continue; + + bool valueCallMapped = false; + + for (auto keyOp : keyCall.getOperands()) { + if (valueCallMapped) + break; + for (auto valueOp : valueCall.getOperands()) { + if (keyOp == valueOp) { + callMap[keyCall].push_back(valueCall); + valueCallMapped = true; + break; + } + } + } + + for (auto keyResult : keyCall.getResults()) { + if (valueCallMapped) + break; + for (auto valueOp : valueCall.getOperands()) { + if (keyResult == valueOp) { + callMap[keyCall].push_back(valueCall); + valueCallMapped = true; + break; + } + } + } + } + } + + return callMap; +} + +bool 
parallelIterationSpacesMatch(ModuleOp module, CallOp firstCall, + CallOp secondCall) { + FuncOp firstCallee = dyn_cast(SymbolTable::lookupNearestSymbolFrom( + module, firstCall.getCallableForCallee().get())); + FuncOp secondCallee = dyn_cast(SymbolTable::lookupNearestSymbolFrom( + module, secondCall.getCallableForCallee().get())); + + IRMapping firstCalleeToCall; + IRMapping secondCalleeToCall; + Block &firstBody = firstCallee.getFunctionBody().front(); + Block &secondBody = secondCallee.getFunctionBody().front(); + firstCalleeToCall.map(firstCallee.getArguments(), firstCall.getOperands()); + secondCalleeToCall.map(secondCallee.getArguments(), secondCall.getOperands()); + + SmallVector firstLAOps(firstBody.getOps()); + SmallVector secondLAOps(secondBody.getOps()); + + for (LinalgOp firstLAOp : firstLAOps) { + SmallVector firstParDims; + firstLAOp.getParallelDims(firstParDims); + for (LinalgOp secondLAOp : secondLAOps) { + SmallVector secondParDims; + secondLAOp.getParallelDims(secondParDims); + + bool firstIncludesSecond = std::includes( + firstParDims.begin(), firstParDims.end(), + secondParDims.begin(), secondParDims.end() + ); + + if (!firstIncludesSecond) return false; + } + } + + return true; +} + +bool markedForFusion(CallOp keyKernel, CallOp valKernel) { + if (!keyKernel->getAttr("fuse_with") || !valKernel->getAttr("fuse_with")) + return false; + + SmallVector keyKernelFuseWithStrings; + StringRef keyKernelFuseWithString = + dyn_cast(keyKernel->getAttr("fuse_with")).strref(); + keyKernelFuseWithString.split(keyKernelFuseWithStrings, ","); + + SmallVector valKernelFuseWithStrings; + StringRef valKernelFuseWithString = + dyn_cast(valKernel->getAttr("fuse_with")).strref(); + valKernelFuseWithString.split(valKernelFuseWithStrings, ","); + + bool fuseWithFlag1 = false; + for (StringRef fuseWithString : keyKernelFuseWithStrings) { + if (valKernel.getCallee() == fuseWithString.trim(' ')) { + fuseWithFlag1 = true; + break; + } + } + + bool fuseWithFlag2 = false; + for (StringRef fuseWithString : valKernelFuseWithStrings) { + if (keyKernel.getCallee() == fuseWithString.trim(' ')) { + fuseWithFlag2 = true; + break; + } + } + + return fuseWithFlag1 && fuseWithFlag2; +} + +FusionSetVector createFusionSets(mlir::ModuleOp module, FuncOp func, + CallMap callMap) { + std::deque kernelsToFuse; + std::set kernelSet; + FusionSetVector fusionSets; + + func.walk([&kernelsToFuse, &kernelSet](func::CallOp call) { + if (kernelSet.find(call) == kernelSet.end()) { + kernelSet.insert(call); + kernelsToFuse.push_back(call); + } + }); + + int fusionSetIndex = 0; + while (!kernelsToFuse.empty()) { + auto kernelToFuse = kernelsToFuse.front(); + kernelsToFuse.pop_front(); + + auto kernelCheck = kernelSet.find(kernelToFuse); + if (kernelCheck == kernelSet.end()) + continue; + + kernelSet.extract(kernelToFuse); + + if (fusionSetIndex != int(fusionSets.size() - 1)) + fusionSets.push_back(std::vector()); + fusionSets[fusionSetIndex].push_back(kernelToFuse); + + for (auto val : callMap[kernelToFuse]) { + bool fusionLegal = + parallelIterationSpacesMatch(module, kernelToFuse, val); + if (fusionLegal && markedForFusion(kernelToFuse, val)) { + kernelSet.extract(val); + fusionSets[fusionSetIndex].push_back(val); + fusionLegal = false; + } + } + + if (fusionSets[fusionSetIndex].size() == 1) { + fusionSets[fusionSetIndex].pop_back(); + continue; + } + + fusionSetIndex += 1; + } + + return fusionSets; +} + +void buildArgsToIndexMap(FusionSet &fusionSet, CallOp kernel, + SmallVector &newArgs, + SmallVector &newResults, + 
DenseMap &argsToIndexMap, + int &fusedKernelArgIndex) { + for (auto arg : kernel.getOperands()) { + auto argCheck = std::find(newArgs.begin(), newArgs.end(), arg); + if (argCheck != newArgs.end()) + continue; + + auto producer = arg.getDefiningOp(); + if (producer) { + auto producerCheck = + std::find(fusionSet.begin(), fusionSet.end(), producer); + if (producerCheck != fusionSet.end()) + continue; + } + + newArgs.push_back(arg); + argsToIndexMap[arg] = fusedKernelArgIndex; + fusedKernelArgIndex++; + } +} + +bool checkResultUserIsInFusionSet(FusionSet &fusionSet, Value result) { + bool userInFusionSet = false; + for (auto user : result.getUsers()) { + userInFusionSet |= ( + std::find(fusionSet.begin(), fusionSet.end(), user) != fusionSet.end() + ); + } + + return userInFusionSet; +} + +bool checkResultUserIsNotInFusionSet(FusionSet &fusionSet, Value result) { + bool userNotInFusionSet = false; + for (auto user : result.getUsers()) { + userNotInFusionSet |= ( + std::find(fusionSet.begin(), fusionSet.end(), user) == fusionSet.end() + ); + } + + return userNotInFusionSet; +} + +void buildResultsToIndexMap(FusionSet &fusionSet, CallOp kernel, + SmallVector &newResults, + DenseMap &resultsToIndexMap, + int &fusedKernelResultIndex) { + for (auto res : kernel.getResults()) { + auto resCheck = std::find(newResults.begin(), newResults.end(), res); + if (resCheck != newResults.end()) + continue; + + bool userInFusionSet = checkResultUserIsInFusionSet(fusionSet, res); + bool userNotInFusionSet = checkResultUserIsNotInFusionSet(fusionSet, res); + + if (userInFusionSet && !userNotInFusionSet) + continue; + + newResults.push_back(res); + resultsToIndexMap[res] = fusedKernelResultIndex; + fusedKernelResultIndex++; + } +} + +FuncOp buildFusedKernelOp(OpBuilder &builder, ModuleOp &module, + FusionSet &fusionSet, SmallVector &newArgs, + SmallVector &newResults, + int &fusedKernelCounter) { + TypeRange newArgTypes(newArgs); + TypeRange newResultTypes(newResults); + FunctionType fusedKernelType = + builder.getFunctionType(newArgTypes, newResultTypes); + + builder.setInsertionPointToStart(module.getBody()); + + std::string fusedKernelName = ""; + for (auto kernel : fusionSet) + fusedKernelName += (kernel.getCallee() + "_").str(); + fusedKernelName += std::to_string(fusedKernelCounter); + + FuncOp fusedKernelOp = builder.create( + module.getLoc(), fusedKernelName, fusedKernelType); + + fusedKernelOp.addEntryBlock(); + + return fusedKernelOp; +} + +void insertSubkernelCallsIntoFusedKernel( + OpBuilder &builder, ModuleOp module, FusionSet &fusionSet, + FuncOp fusedKernelOp, DenseMap &argsToIndexMap, + DenseMap &newCallsToOldCallsMap) { + DenseMap callsToIntermediateValues(fusionSet.size()); + + builder.setInsertionPointToStart(&fusedKernelOp.front()); + for (CallOp kernel : fusionSet) { + FuncOp callee = dyn_cast(SymbolTable::lookupNearestSymbolFrom( + module, kernel.getCallableForCallee().get())); + + if (!callee.isPrivate()) callee.setPrivate(); + + SmallVector args; + for (auto arg : kernel.getOperands()) { + args.push_back(fusedKernelOp.getArgument(argsToIndexMap[arg])); + } + + CallOp newCallHandle = + builder.create(fusedKernelOp.getLoc(), callee, args); + newCallsToOldCallsMap[newCallHandle] = kernel; + + for (auto result : kernel.getResults()) { + for (auto &use : result.getUses()) { + if (auto consumer = dyn_cast(use.getOwner())) { + auto userCheck = + std::find(fusionSet.begin(), fusionSet.end(), consumer); + if (userCheck != fusionSet.end()) { + int argIndex = use.getOperandNumber(); + auto newArg = 
newCallHandle.getResult(result.getResultNumber()); + callsToIntermediateValues[consumer].push_back({newArg, argIndex}); + } + } + } + } + + if (!callsToIntermediateValues[kernel].empty()) { + for (auto argIndexPair : callsToIntermediateValues[kernel]) { + newCallHandle.setOperand(argIndexPair.second, argIndexPair.first); + } + } + + newCallHandle.getOperation()->setAttr("inline", + builder.getStringAttr("true")); + } +} + +void insertReturnOpToFusedKernelOp( + OpBuilder &builder, FusionSet &fusionSet, FuncOp fusedKernelOp, + DenseMap &resultsToIndexMap, + DenseMap &newCallsToOldCallsMap) { + SmallVector returnOperands(resultsToIndexMap.size()); + for (auto newCall : fusedKernelOp.getOps()) { + CallOp oldCall = newCallsToOldCallsMap[newCall]; + for (Value res : oldCall.getResults()) { + if (resultsToIndexMap.find(res) != resultsToIndexMap.end()) { + int resIndex = llvm::find(oldCall.getResults(), res) - + oldCall.getResults().begin(); + returnOperands[resultsToIndexMap[res]] = newCall.getResult(resIndex); + } + } + } + + builder.create(fusedKernelOp.getLoc(), ValueRange(returnOperands)); +} + +void buildFusedKernelCallAndReplaceSubkernelUses( + OpBuilder &builder, FusionSet &fusionSet, FuncOp func, + FuncOp fusedKernelOp, SmallVector newArgs, + DenseMap &resultsToIndexMap) { + + builder.setInsertionPoint(*fusionSet.rbegin()); + CallOp fusedKernelCallHandle = + builder.create(func.getLoc(), fusedKernelOp, newArgs); + fusedKernelCallHandle->setAttr("noinline", builder.getUnitAttr()); + + for (auto kernelCall : fusionSet) { + for (auto result : kernelCall.getResults()) { + auto newResult = + fusedKernelCallHandle.getResult(resultsToIndexMap[result]); + result.replaceAllUsesWith(newResult); + kernelCall.erase(); + } + } +} + +void fuseKernels(ModuleOp module, FuncOp func) { + OpBuilder builder(module.getContext()); + CallMap callMap = getCallMap(func); + FusionSetVector fusionSets = createFusionSets(module, func, callMap); + + int fusedKernelCounter = 0; + for (auto fusionSet : fusionSets) { + if (fusionSet.empty()) + continue; + + SmallVector newArgs; + SmallVector newResults; + SmallVector intermediates; + DenseMap argsToIndexMap; + DenseMap resultsToIndexMap; + DenseMap newCallsToOldCallsMap; + int fusedKernelArgIndex = 0; + int fusedKernelResultIndex = 0; + for (auto kernel : fusionSet) { + buildArgsToIndexMap(fusionSet, kernel, newArgs, newResults, + argsToIndexMap, fusedKernelArgIndex); + + buildResultsToIndexMap(fusionSet, kernel, newResults, resultsToIndexMap, + fusedKernelResultIndex); + } + + FuncOp fusedKernelOp = buildFusedKernelOp( + builder, module, fusionSet, newArgs, newResults, fusedKernelCounter); + + insertSubkernelCallsIntoFusedKernel(builder, module, fusionSet, + fusedKernelOp, argsToIndexMap, + newCallsToOldCallsMap); + + insertReturnOpToFusedKernelOp(builder, fusionSet, fusedKernelOp, + resultsToIndexMap, newCallsToOldCallsMap); + + buildFusedKernelCallAndReplaceSubkernelUses( + builder, fusionSet, func, fusedKernelOp, newArgs, resultsToIndexMap); + + } +} + +#define GEN_PASS_DEF_KERNELFUSIONPASS +#include "Transform/Kernel/KernelPasses.h.inc" + +struct KernelFusionPass : impl::KernelFusionPassBase { + using KernelFusionPassBase::KernelFusionPassBase; + + void runOnOperation() override { + ModuleOp module = dyn_cast(getOperation()); + module.walk([&](FuncOp func) { + if (!func.isPrivate()) + fuseKernels(module, func); + }); + } // end runOnOperation +}; +} // namespace kernel +} // namespace mlir diff --git a/mlir/lib/Transform/Kernel/LinalgGenericReordering.cpp 
b/mlir/lib/Transform/Kernel/LinalgGenericReordering.cpp new file mode 100644 index 00000000..1ddf113a --- /dev/null +++ b/mlir/lib/Transform/Kernel/LinalgGenericReordering.cpp @@ -0,0 +1,801 @@ +#include +#include +#include +#include + +#include "mlir-c/IR.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include "Transform/Kernel/KernelPasses.h" + +namespace mlir { +namespace kernel { + +// {{{ generic -> einsum + +typedef struct EinsumArg { + std::string spec; + SmallVector shape; + int argIndex = -1; + + std::string stringify(); + + void print(llvm::raw_fd_ostream &stream) { stream << stringify(); } + void print() { print(llvm::errs()); } +} EinsumArg; + +std::string EinsumArg::stringify() { + std::string s = ""; + + s += (spec + ": "); + s += ("("); + for (auto it = shape.begin(); it != shape.end(); it++) { + s += std::to_string(*it); + if (it != (shape.end() - 1)) + s += ", "; + } + s += ")"; + + return s; +} + +typedef struct EinsumSpecification { + std::vector reductionDims; + std::vector parallelDims; + std::vector allDims; + + std::map temporaries; + + std::vector contractPath; + + linalg::LinalgOp definingOp; + + SmallVector inputs; + EinsumArg output; + + std::string stringify(); + + void print(llvm::raw_ostream &stream); + void print(); + + void setDimTypes(); + + linalg::LinalgOp getDefiningOp() { return definingOp; } + +} EinsumSpecification; + +std::string EinsumSpecification::stringify() { + std::string spec = ""; + for (auto input : inputs) { + spec += input.spec; + if (input.spec == (*(inputs.end() - 1)).spec) + spec += "->"; + else + spec += ","; + } + spec += output.spec; + + return spec; +} + +void EinsumSpecification::print(llvm::raw_ostream &stream) { + stream << stringify(); +} + +void EinsumSpecification::print() { print(llvm::errs()); } + +typedef struct FusedEinsum { + std::vector containedEinsums; + std::vector> contractPath; + + EinsumSpecification einsum; + + std::string stringify() { return einsum.stringify(); }; + + void print() { print(llvm::errs()); } + void print(llvm::raw_fd_ostream &stream) { einsum.print(stream); } + +} FusedEinsum; + +void printGenericAsEinsum(linalg::LinalgOp generic, llvm::raw_ostream &stream) { + for (auto idx_map : generic.getIndexingMapsArray()) { + stream << "("; + for (auto res : idx_map.getResults()) { + stream << res; + if (res != *(idx_map.getResults().end() - 1)) + stream << ", "; + } + stream << ")"; + + if (idx_map == *(generic.getIndexingMapsArray().end() - 2)) + stream << " -> "; + else if (idx_map != *(generic.getIndexingMapsArray().end() - 1)) + stream << ", "; + else + stream << "\n"; + } +} + +void printGenericAsEinsum(linalg::LinalgOp generic) { + printGenericAsEinsum(generic, llvm::errs()); +} + +bool isConvertibleToEinsum(linalg::GenericOp generic) { + if (llvm::range_size(generic.getResults()) > 1) return false; + + auto &block = generic.getRegion().front(); + for (auto &op : block) { + if ( + !isa(op) && + !isa(op) && + !isa(op) + ) { + op.dump(); + return false; + } + } + + auto addOps = block.getOps(); + auto mulOps = block.getOps(); + if (llvm::range_size(addOps) != 1) return false; + if 
(llvm::range_size(mulOps) != 1) return false; + + arith::AddFOp addition = *addOps.begin(); + auto accCheck = llvm::find( + addition.getOperands(), + block.getArguments().back() + ); + if (accCheck == addition.getOperands().end()) return false; + + arith::MulFOp mul = *mulOps.begin(); + auto mulCheck = llvm::find( + mul.getOperands(), + block.getArguments().back() + ); + if (mulCheck != mul.getOperands().end()) return false; + + return true; +} + +SmallVector generateEinsumArgsFromGeneric(linalg::LinalgOp generic) { + DenseMap affineToIndex; + std::string chars = "zyxwvutsrqponmlkjihgfedcba"; + + SmallVector args; + for (size_t i = 0; i < generic.getIndexingMapsArray().size(); ++i) { + EinsumArg einsum; + auto idxMap = generic.getIndexingMapsArray()[i]; + + std::string input = ""; + for (auto idx : idxMap.getResults()) { + if (affineToIndex.find(idx) == affineToIndex.end()) { + affineToIndex[idx] = chars.back(); + chars.pop_back(); + } + + input += affineToIndex[idx]; + } + + einsum.spec = input; + + auto genericArg = generic.getOpOperandsMatchingBBargs()[i]->get(); + if (genericArg && isa(genericArg)) { + einsum.argIndex = dyn_cast(genericArg).getArgNumber(); + } + + args.push_back(einsum); + } + + return args; +} + + +bool isProfitableToReorderGenerics(std::vector einsums) { + std::set costs; + for (EinsumSpecification einsum : einsums) { + + double cost = 1.0; + std::set seenIndices; + for (EinsumArg input : einsum.inputs) { + for (auto specIdx : llvm::enumerate(input.spec)) { + bool indexSeen = ( + std::find(seenIndices.begin(), seenIndices.end(), specIdx.value()) + != + seenIndices.end() + ); + + if (!indexSeen) { + cost *= input.shape[specIdx.index()]; + seenIndices.insert(specIdx.value()); + } + } + } + + costs.insert(cost); + } + + return (costs.size() != 1); +} + +void EinsumSpecification::setDimTypes() { + std::set inputIndices; + for (auto input : inputs) { + for (char c : input.spec) { + if (inputIndices.find(c) == inputIndices.end()) + inputIndices.insert(c); + } + } + + std::set outputIndices; + for (char c : output.spec) { + if (outputIndices.find(c) == outputIndices.end()) + outputIndices.insert(c); + } + + std::set_intersection( + inputIndices.begin(), inputIndices.end(), + outputIndices.begin(), outputIndices.end(), + std::back_inserter(parallelDims) + ); + + std::set_difference( + inputIndices.begin(), inputIndices.end(), + outputIndices.begin(), outputIndices.end(), + std::back_inserter(reductionDims) + ); + + std::set_union( + inputIndices.begin(), inputIndices.end(), + outputIndices.begin(), outputIndices.end(), + std::back_inserter(allDims) + ); +} + +EinsumSpecification genericToEinsumSpec(linalg::LinalgOp generic) { + EinsumSpecification einsum; + + SmallVector args = generateEinsumArgsFromGeneric(generic); + einsum.output = *(args.end() - 1); + einsum.inputs.insert(einsum.inputs.begin(), args.begin(), args.end() - 1); + + // FIXME: ensure linalg.generic is expressible by an einsum + unsigned int counter = 0; + for (auto arg : generic.getOpOperandsMatchingBBargs()) { + if (auto op = arg->get().getDefiningOp()) { + if (isa(op)) { + einsum.temporaries.insert( + std::pair(counter, op)); + } + } + + ++counter; + } + + for (size_t i = 0; i < (generic.getDpsInputOperands().size()); ++i) { + auto op = generic.getDpsInputOperands()[i]; + if (TensorType t = dyn_cast(op->get().getType())) { + einsum.inputs[i].shape.insert(einsum.inputs[i].shape.begin(), + t.getShape().begin(), t.getShape().end()); + } else { + llvm::errs() << "Could not determine shape of inputs 
arguments\n"; + abort(); + } + } + + if (!(generic.getDpsInits().size() == 1)) { + llvm::errs() << "Only a single output operand is supported\n"; + abort(); + } + + auto op = generic.getDpsInitOperand(0); + if (TensorType t = dyn_cast(op->get().getType())) { + einsum.output.shape.insert(einsum.output.shape.begin(), + t.getShape().begin(), t.getShape().end()); + } + + einsum.definingOp = generic; + einsum.setDimTypes(); + + return einsum; +} + +FusedEinsum fuseEinsums(std::vector einsums) { + FusedEinsum fusedEinsum; + + std::vector fused_einsums; + for (auto outer = einsums.rbegin(); outer != einsums.rend(); ++outer) { + std::string availChars = "zyxwvutsrqponmlkjihgfedcba"; + for (auto input : outer->inputs) { + for (auto idx : input.spec) { + auto pos = availChars.find(idx); + if (pos != std::string::npos) + availChars.erase(pos, 1); + } + } + + for (auto inner = outer + 1; inner != einsums.rend(); ++inner) { + for (auto inputToEinsum : outer->temporaries) { + if (inputToEinsum.second == inner->getDefiningOp()) { // fusable + EinsumArg outerInput = outer->inputs[inputToEinsum.first]; + + if (outerInput.shape != inner->output.shape) { + llvm::errs() << "ERROR: Shape mismatch\n"; + abort(); + } + + std::map innerOutToOuterIn; + for (size_t i = 0; i < inner->output.spec.size(); ++i) + innerOutToOuterIn.insert(std::pair( + inner->output.spec[i], outerInput.spec[i])); + + std::string inputIndices = ""; + for (auto input : inner->inputs) { + for (auto idx : input.spec) { + auto pos = inputIndices.find(idx); + if (pos == std::string::npos) + inputIndices += idx; + } + } + + for (auto p : innerOutToOuterIn) { + auto pos = inputIndices.find(p.first); + if (pos != std::string::npos) + inputIndices.erase(pos, 1); + } + + for (auto idx : inputIndices) { + auto p = innerOutToOuterIn.find(idx); + if (p == innerOutToOuterIn.end()) { + char newVal = *availChars.begin(); + availChars.erase(0, 1); + innerOutToOuterIn.insert(std::pair(idx, newVal)); + } else { + char newVal = *availChars.begin(); + availChars.erase(0, 1); + innerOutToOuterIn[p->first] = newVal; + } + } + + SmallVector newInputs; + for (auto s : inner->inputs) { + std::string innerSpec; + for (auto idx : s.spec) { + innerSpec += innerOutToOuterIn[idx]; + } + + EinsumArg newInput; + newInput.spec = innerSpec; + newInput.shape = s.shape; + newInput.argIndex = s.argIndex; + newInputs.push_back(newInput); + } + + for (auto input : outer->inputs) { + if (input.spec != outerInput.spec) { + EinsumArg newInput; + newInput.spec = input.spec; + newInput.shape = input.shape; + newInput.argIndex = input.argIndex; + newInputs.push_back(newInput); + } + } + + EinsumSpecification fusedSpec; + fusedSpec.inputs = newInputs; + fusedSpec.output = outer->output; + + fusedEinsum.einsum = fusedSpec; + } + } + } + } + + fusedEinsum.einsum.setDimTypes(); + + return fusedEinsum; +} + +typedef struct EinsumSequence { + std::vector sequence; + SmallVector> contractPath; +} EinsumSequence; + +// }}} + +// {{{ einsum -> generic + +bool buildGenericsFromEinsums(func::FuncOp func, EinsumSequence optimalOrder) { + if (llvm::range_size(func.getOps()) == 0) return false; + + OpBuilder builder(func); + builder.setInsertionPoint(func); + + mlir::MLIRContext *ctx = func.getContext(); + + FunctionType funcType = func.getFunctionType(); + + if (funcType.getNumResults() > 1) { + llvm::errs() << "Only a single return value is supported at this time. 
"; + llvm::errs() << "Generics will not be reordered.\n"; + return false; + } + + func::FuncOp newFunc = builder.create( + func.getLoc(), builder.getStringAttr(func.getName().str() + "_reordered"), + funcType); + newFunc.addEntryBlock(); + + RankedTensorType returnType = + dyn_cast(funcType.getResult(0)); + + auto elementType = returnType.getElementType(); + + SmallVector argList; + for (BlockArgument arg : newFunc.getBody().getArguments()) + argList.push_back(Value(arg)); + + size_t einsumCounter = 0; + for (EinsumSpecification einsum : optimalOrder.sequence) { + + // get result shape and type + SmallVector typeVector; + SmallVector outputShape; + for (int shapeComponent : einsum.output.shape) + outputShape.push_back((int64_t)shapeComponent); + typeVector.push_back(mlir::RankedTensorType::get(outputShape, elementType)); + + // get inputs and outputs + SmallVector inputs; + for (int idx : optimalOrder.contractPath[einsumCounter]) + inputs.push_back(argList[idx]); + + SmallVector outputs; + builder.setInsertionPointToEnd(&newFunc.getBody().front()); + outputs.push_back(builder.create( + newFunc.getLoc(), outputShape, elementType)); + + // get indexing maps + std::map allAffineDims; + for (auto [idx, dim] : llvm::enumerate(einsum.allDims)) { + allAffineDims.insert( + std::pair(dim, mlir::getAffineDimExpr(idx, ctx))); + } + + SmallVector indexingMaps; + for (EinsumArg input : einsum.inputs) { + std::vector indexingMapOutputDims; + for (char dim : input.spec) + indexingMapOutputDims.push_back(allAffineDims[dim]); + + indexingMaps.push_back( + AffineMap::get(einsum.allDims.size(), 0, indexingMapOutputDims, ctx)); + } + + std::vector indexingMapOutputDims; + for (char dim : einsum.output.spec) + indexingMapOutputDims.push_back(allAffineDims[dim]); + indexingMaps.push_back( + AffineMap::get(einsum.allDims.size(), 0, indexingMapOutputDims, ctx)); + + // get iterator types + SmallVector iteratorTypes; + for (auto [dim, affineDim] : allAffineDims) { + if (std::find(einsum.parallelDims.begin(), einsum.parallelDims.end(), + dim) != einsum.parallelDims.end()) + iteratorTypes.push_back(utils::IteratorType::parallel); + else if (std::find(einsum.reductionDims.begin(), + einsum.reductionDims.end(), + dim) != einsum.reductionDims.end()) + iteratorTypes.push_back(utils::IteratorType::reduction); + } + + SmallVector attributes = {}; + + linalg::GenericOp generic = builder.create( + newFunc.getLoc(), + TypeRange(typeVector), + ValueRange(inputs), + ValueRange(outputs), + indexingMaps, + ArrayRef(iteratorTypes), + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + Value mul = nestedBuilder.create( + nestedLoc, args[0], args[1]); + Value add = nestedBuilder.create( + nestedLoc, args[2], mul); + nestedBuilder.create(nestedLoc, add); + }, + attributes + ); + + Value result = generic.getResult(0); + int resultStoreIndex = + *std::min_element(optimalOrder.contractPath[einsumCounter].begin(), + optimalOrder.contractPath[einsumCounter].end()); + + for (int argIdx : optimalOrder.contractPath[einsumCounter]) { + argList.erase(std::find(argList.begin(), argList.end(), argList[argIdx])); + } + + argList.insert(argList.begin() + resultStoreIndex, result); + + if (einsumCounter == (optimalOrder.sequence.size() - 1)) { + builder.setInsertionPointToEnd(&newFunc.getBody().front()); + builder.create(newFunc.getLoc(), result); + } + + ++einsumCounter; + } + + // replace call to old function with call to new, optimized function + ModuleOp module = dyn_cast(func->getParentOp()); + module.walk([&](Operation 
*op) { + if (auto call = dyn_cast(op)) { + if (call.getCallee() == func.getName()) { + call.setCallee(newFunc.getName()); + } + } + }); + + return true; +} + +// }}} + +// {{{ optimizers + +typedef struct BruteForceOptimizer { + FusedEinsum unoptimizedEinsum; + EinsumSequence optimizedEinsumSequence; + + BruteForceOptimizer(FusedEinsum originalEinsum) + : unoptimizedEinsum(originalEinsum) {} + + void optimize(); + +} BruteForceOptimizer; + +SmallVector getResultShape(EinsumArg iarg, EinsumArg jarg, + std::string resultIndices) { + SmallVector shape; + for (char resultIdx : resultIndices) { + auto ipos = iarg.spec.find(resultIdx); + if (ipos != std::string::npos) { + shape.push_back(iarg.shape[ipos]); + continue; // avoid double-counting parallel axes + } + + auto jpos = iarg.spec.find(resultIdx); + if (jpos != std::string::npos) + shape.push_back(iarg.shape[ipos]); + } + + return shape; +} + +std::string getResultIndices(EinsumArg iarg, EinsumArg jarg, + std::vector sharedIndices, + std::vector reductionDims) { + + std::vector sharedReductionIndices; + std::set_intersection(reductionDims.begin(), reductionDims.end(), + sharedIndices.begin(), sharedIndices.end(), + std::back_inserter(sharedReductionIndices)); + + std::set allIndices; + allIndices.insert(iarg.spec.begin(), iarg.spec.end()); + allIndices.insert(jarg.spec.begin(), jarg.spec.end()); + for (char sharedReductionIndex : sharedReductionIndices) { + if (allIndices.find(sharedReductionIndex) != allIndices.end()) + allIndices.erase(sharedReductionIndex); + } + + std::string spec(allIndices.begin(), allIndices.end()); + return spec; +} + +bool checkLegalContraction(EinsumArg iarg, EinsumArg jarg, + std::vector sharedIndices, + std::vector reductionDims) { + + std::vector reducedIndices; + std::set_intersection( + sharedIndices.begin(), sharedIndices.end(), + reductionDims.begin(), reductionDims.end(), + std::back_inserter(reducedIndices) + ); + + if (reducedIndices.size() == 0) return false; + + for (char idx : reducedIndices) { + auto ipos = iarg.spec.find(idx); + auto jpos = jarg.spec.find(idx); + + if ((ipos != std::string::npos) && (jpos != std::string::npos)) { + if (iarg.shape[ipos] != jarg.shape[jpos]) return false; + } + if ((ipos != std::string::npos && jpos == std::string::npos)) return false; + if ((ipos == std::string::npos && jpos != std::string::npos)) return false; + } + + return true; + +} +std::vector getSharedIndices(EinsumArg iarg, EinsumArg jarg) { + std::set iindices(iarg.spec.begin(), iarg.spec.end()); + std::set jindices(jarg.spec.begin(), jarg.spec.end()); + + std::vector sharedIndices; + std::set_intersection( + iindices.begin(), iindices.end(), + jindices.begin(), jindices.end(), + std::back_inserter(sharedIndices) + ); + + return sharedIndices; +} + +double estimateCost(EinsumArg iarg, EinsumArg jarg, + std::vector sharedIndices) { + double cost = 1.0; + for (char sharedIdx : sharedIndices) + cost *= iarg.shape[iarg.spec.find(sharedIdx)]; + + std::set allIndices(iarg.spec.begin(), iarg.spec.end()); + allIndices.insert(jarg.spec.begin(), jarg.spec.end()); + + std::vector disjointIndices; + std::set_difference( + allIndices.begin(), allIndices.end(), + sharedIndices.begin(), sharedIndices.end(), + std::back_inserter(disjointIndices) + ); + + for (char idx : disjointIndices) { + if (iarg.spec.find(idx) != std::string::npos) + cost *= iarg.shape[iarg.spec.find(idx)]; + else if (jarg.spec.find(idx) != std::string::npos) + cost *= jarg.shape[jarg.spec.find(idx)]; + } + + return cost; +} + +void 
BruteForceOptimizer::optimize() { + EinsumSpecification einsum = unoptimizedEinsum.einsum; + SmallVector inputs = einsum.inputs; + EinsumArg output = einsum.output; + + double bestCost; + while (inputs.size() > 1) { + bestCost = std::numeric_limits::max(); + EinsumArg smallestTemporary; + + int imin = -1; + int jmin = -1; + for (size_t i = 0; i < inputs.size(); ++i) { + for (size_t j = i + 1; j < inputs.size(); ++j) { + EinsumArg iarg = inputs[i]; + EinsumArg jarg = inputs[j]; + + std::vector sharedIndices = getSharedIndices(iarg, jarg); + + if (!checkLegalContraction(iarg, jarg, sharedIndices, + einsum.reductionDims)) + continue; + + std::string resultIndices = + getResultIndices(iarg, jarg, sharedIndices, einsum.reductionDims); + + SmallVector resultShape = + getResultShape(iarg, jarg, resultIndices); + + double estimatedCost = estimateCost(iarg, jarg, sharedIndices); + if (estimatedCost <= bestCost) { + bestCost = estimatedCost; + imin = i; + jmin = j; + + smallestTemporary.spec = resultIndices; + smallestTemporary.shape = resultShape; + } + } + } + + EinsumSpecification einsumPart; + einsumPart.inputs.push_back(inputs[imin]); + einsumPart.inputs.push_back(inputs[jmin]); + + SmallVector path; + for (EinsumArg input : einsumPart.inputs) { + if (input.argIndex != -1) + path.push_back(input.argIndex); + } + + optimizedEinsumSequence.contractPath.push_back(path); + smallestTemporary.argIndex = *std::min_element(path.begin(), path.end()); + + if (imin > jmin) { + inputs.erase(inputs.begin() + imin); + inputs.erase(inputs.begin() + jmin); + } + else { + inputs.erase(inputs.begin() + jmin); + inputs.erase(inputs.begin() + imin); + } + inputs.push_back(smallestTemporary); + + einsumPart.output = smallestTemporary; + einsumPart.setDimTypes(); + optimizedEinsumSequence.sequence.push_back(einsumPart); + } +} + +// }}} + +// {{{ actual pass implementation + +#define GEN_PASS_DEF_LINALGGENERICREORDERINGPASS +#include "Transform/Kernel/KernelPasses.h.inc" + +struct LinalgGenericReordering + : impl::LinalgGenericReorderingPassBase { + using LinalgGenericReorderingPassBase::LinalgGenericReorderingPassBase; + + EinsumSequence + getOptimalContractionOrder(std::vector einsums) { + FusedEinsum fused = fuseEinsums(einsums); + BruteForceOptimizer optimizer(fused); + optimizer.optimize(); + + return optimizer.optimizedEinsumSequence; + } + + bool reorderGenerics(func::FuncOp func) { + for (linalg::GenericOp generic : func.getOps()) { + if (!isConvertibleToEinsum(generic)) return false; + } + + std::vector einsums; + for (linalg::LinalgOp laOp : func.getOps()) + einsums.push_back(genericToEinsumSpec(laOp)); + + if (!isProfitableToReorderGenerics(einsums)) return false; + + EinsumSequence optimalOrder = + getOptimalContractionOrder(einsums); + return buildGenericsFromEinsums(func, optimalOrder); + } + + void runOnOperation() override { + mlir::ModuleOp module = dyn_cast(getOperation()); + + // reorder linalg.generics in each fused kernel + for (func::FuncOp f : + llvm::make_early_inc_range(module.getOps())) { + if (reorderGenerics(f)) + f.erase(); + } + } + +}; + +std::unique_ptr createLinalgGenericReorderingPass() { + return std::make_unique(); +} + +// }}} + +} // namespace kernel +} // namespace mlir diff --git a/mlir/tools/lapis-opt/CMakeLists.txt b/mlir/tools/lapis-opt/CMakeLists.txt index 769d6bda..1c7c2304 100644 --- a/mlir/tools/lapis-opt/CMakeLists.txt +++ b/mlir/tools/lapis-opt/CMakeLists.txt @@ -14,6 +14,7 @@ target_link_libraries(lapis-opt MLIRCastInterfaces MLIRDialect MLIROptLib + MLIRKernelPasses 
MLIRParser MLIRPass MLIRTransforms diff --git a/mlir/tools/lapis-opt/lapis-opt.cpp b/mlir/tools/lapis-opt/lapis-opt.cpp index f7fb4457..1a6de921 100644 --- a/mlir/tools/lapis-opt/lapis-opt.cpp +++ b/mlir/tools/lapis-opt/lapis-opt.cpp @@ -54,6 +54,8 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "Transform/Kernel/KernelPasses.h" + using namespace mlir; int main(int argc, char **argv) { @@ -63,7 +65,7 @@ int main(int argc, char **argv) { DialectRegistry registry; registry.insert< #ifdef LAPIS_ENABLE_PART_TENSOR - mlir::part_tensor::PartTensorDialect, + mlir::part_tensor::PartTensorDialect, #endif mlir::LLVM::LLVMDialect, mlir::vector::VectorDialect, mlir::bufferization::BufferizationDialect, mlir::linalg::LinalgDialect, @@ -94,6 +96,10 @@ int main(int argc, char **argv) { tensor::registerValueBoundsOpInterfaceExternalModels(registry); vector::registerBufferizableOpInterfaceExternalModels(registry); + kernel::registerKernelFusionDriver(); + kernel::registerLinalgGenericReorderingPass(); + kernel::registerKernelDomainFusionPass(); + LLVM::registerInlinerInterface(registry); func::registerAllExtensions(registry); diff --git a/tests/Dialect/CMakeLists.txt b/tests/Dialect/CMakeLists.txt index ca41de30..81bc1059 100644 --- a/tests/Dialect/CMakeLists.txt +++ b/tests/Dialect/CMakeLists.txt @@ -3,4 +3,5 @@ enable_testing() configure_file(lit.site.cfg.py.in lit.site.cfg.py @ONLY) add_test(NAME Kokkos_Dialect COMMAND lit -v Kokkos) +add_test(NAME Transformations COMMAND lit -v Transforms) #add_test(NAME PartTensor_Dialect COMMAND lit -v PartTensor) diff --git a/tests/Dialect/Transforms/batched-ax-plus-y.mlir b/tests/Dialect/Transforms/batched-ax-plus-y.mlir new file mode 100644 index 00000000..879a2628 --- /dev/null +++ b/tests/Dialect/Transforms/batched-ax-plus-y.mlir @@ -0,0 +1,35 @@ +// RUN: %lapis-opt %s --drive-kernel-fusion | diff %s.gold - + +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map3 = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @gemv(%arg0: tensor<10x10x10xf64>, %arg1: tensor<10x10xf64>, %arg2: tensor<10x10xf64>) -> tensor<10x10xf64> { + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<10x10x10xf64>, tensor<10x10xf64>) outs(%arg2 : tensor<10x10xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %1 = arith.mulf %in, %in_0 : f64 + %2 = arith.addf %out, %1 : f64 + linalg.yield %2 : f64 + } -> tensor<10x10xf64> + return %0 : tensor<10x10xf64> + } + func.func @axpy(%arg0: tensor<10x10xf64>, %arg1: tensor<10x10xf64>, %arg2: tensor<10x10xf64>) -> tensor<10x10xf64> { + %0 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<10x10xf64>, tensor<10x10xf64>) outs(%arg2 : tensor<10x10xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %1 = arith.addf %in, %in_0 : f64 + linalg.yield %1 : f64 + } -> tensor<10x10xf64> + return %0 : tensor<10x10xf64> + } + func.func @main() -> tensor<10x10xf64> { + %0 = tensor.empty() : tensor<10x10x10xf64> + %1 = tensor.empty() : tensor<10x10xf64> + %2 = tensor.empty() : tensor<10x10xf64> + %3 = tensor.empty() : tensor<10x10xf64> + %4 = call @gemv(%0, %1, %3) { fuse_with = "axpy" } : (tensor<10x10x10xf64>, tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> + %5 = tensor.empty() : tensor<10x10xf64> + %6 = call @axpy(%4, %2, %5) { fuse_with = 
"gemv" } : (tensor<10x10xf64>, tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> + return %6 : tensor<10x10xf64> + } +} diff --git a/tests/Dialect/Transforms/batched-ax-plus-y.mlir.gold b/tests/Dialect/Transforms/batched-ax-plus-y.mlir.gold new file mode 100644 index 00000000..03eba615 --- /dev/null +++ b/tests/Dialect/Transforms/batched-ax-plus-y.mlir.gold @@ -0,0 +1,30 @@ +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map3 = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @gemv_axpy_0(%arg0: tensor<10x10x10xf64>, %arg1: tensor<10x10xf64>, %arg2: tensor<10x10xf64>, %arg3: tensor<10x10xf64>, %arg4: tensor<10x10xf64>) -> tensor<10x10xf64> { + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<10x10x10xf64>, tensor<10x10xf64>) outs(%arg2 : tensor<10x10xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %2 = arith.mulf %in, %in_0 : f64 + %3 = arith.addf %out, %2 : f64 + linalg.yield %3 : f64 + } -> tensor<10x10xf64> + %1 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%0, %arg3 : tensor<10x10xf64>, tensor<10x10xf64>) outs(%arg4 : tensor<10x10xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %2 = arith.addf %in, %in_0 : f64 + linalg.yield %2 : f64 + } -> tensor<10x10xf64> + return %1 : tensor<10x10xf64> + } + func.func @main() -> tensor<10x10xf64> { + %0 = tensor.empty() : tensor<10x10x10xf64> + %1 = tensor.empty() : tensor<10x10xf64> + %2 = tensor.empty() : tensor<10x10xf64> + %3 = tensor.empty() : tensor<10x10xf64> + %4 = tensor.empty() : tensor<10x10xf64> + %5 = call @gemv_axpy_0(%0, %1, %3, %2, %4) {noinline} : (tensor<10x10x10xf64>, tensor<10x10xf64>, tensor<10x10xf64>, tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> + return %5 : tensor<10x10xf64> + } +} + diff --git a/tests/Dialect/Transforms/input-dependence-test.mlir b/tests/Dialect/Transforms/input-dependence-test.mlir new file mode 100644 index 00000000..9cb41eb0 --- /dev/null +++ b/tests/Dialect/Transforms/input-dependence-test.mlir @@ -0,0 +1,60 @@ +// RUN: %lapis-opt %s --drive-kernel-fusion | diff %s.gold - + +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2, d3)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d4)> +#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)> +#map6 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)> +module { + func.func @compute_reference_dx(%arg0: tensor<100x100x100x100xf64>, %arg1: tensor<100x100xf64>, %arg2: tensor<100x100x100x100xf64>) -> tensor<100x100x100x100xf64> { + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<100x100x100x100xf64>, tensor<100x100xf64>) outs(%arg2 : tensor<100x100x100x100xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %1 = arith.mulf %in_0, %in : f64 + %2 = arith.addf %out, %1 : f64 + linalg.yield %2 : f64 + } -> tensor<100x100x100x100xf64> + return %0 : tensor<100x100x100x100xf64> + } + + func.func @compute_reference_dy(%arg0: tensor<100x100x100x100xf64>, %arg1: tensor<100x100xf64>, %arg2: tensor<100x100x100x100xf64>) -> tensor<100x100x100x100xf64> { + %0 = linalg.generic {indexing_maps = [#map3, #map4, 
#map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<100x100x100x100xf64>, tensor<100x100xf64>) outs(%arg2 : tensor<100x100x100x100xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %1 = arith.mulf %in_0, %in : f64 + %2 = arith.addf %out, %1 : f64 + linalg.yield %2 : f64 + } -> tensor<100x100x100x100xf64> + return %0 : tensor<100x100x100x100xf64> + } + + func.func @compute_reference_dz(%arg0: tensor<100x100x100x100xf64>, %arg1: tensor<100x100xf64>, %arg2: tensor<100x100x100x100xf64>) -> tensor<100x100x100x100xf64> { + %0 = linalg.generic {indexing_maps = [#map5, #map6, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<100x100x100x100xf64>, tensor<100x100xf64>) outs(%arg2 : tensor<100x100x100x100xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %1 = arith.mulf %in_0, %in : f64 + %2 = arith.addf %out, %1 : f64 + linalg.yield %2 : f64 + } -> tensor<100x100x100x100xf64> + return %0 : tensor<100x100x100x100xf64> + } + + func.func @main(%arg0: index, %arg1: index) -> (tensor<100x100x100x100xf64>, + tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64>) attributes {llvm.emit_c_interface} { + %0 = tensor.empty() : tensor<100x100x100x100xf64> + %1 = tensor.empty() : tensor<100x100xf64> + %2 = tensor.empty() : tensor<100x100x100x100xf64> + %3 = tensor.empty() : tensor<100x100x100x100xf64> + %4 = tensor.empty() : tensor<100x100x100x100xf64> + + %5 = call @compute_reference_dx(%0, %1, %2) { fuse_with = + "compute_reference_dy, compute_reference_dz" } : (tensor<100x100x100x100xf64>, tensor<100x100xf64>, tensor<100x100x100x100xf64>) -> tensor<100x100x100x100xf64> + %6 = call @compute_reference_dy(%0, %1, %3) { fuse_with = + "compute_reference_dx, compute_reference_dz" } : (tensor<100x100x100x100xf64>, tensor<100x100xf64>, tensor<100x100x100x100xf64>) -> tensor<100x100x100x100xf64> + %7 = call @compute_reference_dz(%0, %1, %4) { fuse_with = + "compute_reference_dx, compute_reference_dy" }: (tensor<100x100x100x100xf64>, tensor<100x100xf64>, tensor<100x100x100x100xf64>) -> tensor<100x100x100x100xf64> + + return %5, %6, %7 : + tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64> + } +} + diff --git a/tests/Dialect/Transforms/input-dependence-test.mlir.gold b/tests/Dialect/Transforms/input-dependence-test.mlir.gold new file mode 100644 index 00000000..e9646374 --- /dev/null +++ b/tests/Dialect/Transforms/input-dependence-test.mlir.gold @@ -0,0 +1,40 @@ +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2, d3)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d4)> +#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)> +#map6 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)> +module { + func.func @compute_reference_dx_compute_reference_dy_compute_reference_dz_0(%arg0: tensor<100x100x100x100xf64>, %arg1: tensor<100x100xf64>, %arg2: tensor<100x100x100x100xf64>, %arg3: tensor<100x100x100x100xf64>, %arg4: tensor<100x100x100x100xf64>) -> (tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64>) { + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<100x100x100x100xf64>, tensor<100x100xf64>) outs(%arg2 : tensor<100x100x100x100xf64>) 
{ + ^bb0(%in: f64, %in_0: f64, %out: f64): + %3 = arith.mulf %in_0, %in : f64 + %4 = arith.addf %out, %3 : f64 + linalg.yield %4 : f64 + } -> tensor<100x100x100x100xf64> + %1 = linalg.generic {indexing_maps = [#map3, #map4, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<100x100x100x100xf64>, tensor<100x100xf64>) outs(%arg3 : tensor<100x100x100x100xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %3 = arith.mulf %in_0, %in : f64 + %4 = arith.addf %out, %3 : f64 + linalg.yield %4 : f64 + } -> tensor<100x100x100x100xf64> + %2 = linalg.generic {indexing_maps = [#map5, #map6, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<100x100x100x100xf64>, tensor<100x100xf64>) outs(%arg4 : tensor<100x100x100x100xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %3 = arith.mulf %in_0, %in : f64 + %4 = arith.addf %out, %3 : f64 + linalg.yield %4 : f64 + } -> tensor<100x100x100x100xf64> + return %0, %1, %2 : tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64> + } + func.func @main(%arg0: index, %arg1: index) -> (tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64>) attributes {llvm.emit_c_interface} { + %0 = tensor.empty() : tensor<100x100x100x100xf64> + %1 = tensor.empty() : tensor<100x100xf64> + %2 = tensor.empty() : tensor<100x100x100x100xf64> + %3 = tensor.empty() : tensor<100x100x100x100xf64> + %4 = tensor.empty() : tensor<100x100x100x100xf64> + %5:3 = call @compute_reference_dx_compute_reference_dy_compute_reference_dz_0(%0, %1, %2, %3, %4) {noinline} : (tensor<100x100x100x100xf64>, tensor<100x100xf64>, tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64>) -> (tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64>) + return %5#0, %5#1, %5#2 : tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64>, tensor<100x100x100x100xf64> + } +} + diff --git a/tests/Dialect/Transforms/multiple-entrypoints.mlir b/tests/Dialect/Transforms/multiple-entrypoints.mlir new file mode 100644 index 00000000..df57e628 --- /dev/null +++ b/tests/Dialect/Transforms/multiple-entrypoints.mlir @@ -0,0 +1,60 @@ +// RUN: %lapis-opt %s --drive-kernel-fusion | diff %s.gold - +module { + func.func private @matmul( + %a: tensor<4096x4096xf64>, + %b: tensor<4096x4096xf64>, + %c_out: tensor<4096x4096xf64> + ) -> tensor<4096x4096xf64> { + %c = linalg.matmul ins(%a, %b: tensor<4096x4096xf64>, tensor<4096x4096xf64>) + outs(%c_out: tensor<4096x4096xf64>) -> tensor<4096x4096xf64> + return %c : tensor<4096x4096xf64> + } + + func.func private @matvec( + %a: tensor<4096x4096xf64>, + %x: tensor<4096xf64>, + %y_out: tensor<4096xf64> + ) -> tensor<4096xf64> { + %y = linalg.matvec ins(%a, %x: tensor<4096x4096xf64>, tensor<4096xf64>) + outs(%y_out: tensor<4096xf64>) -> tensor<4096xf64> + + return %y : tensor<4096xf64> + } + + func.func @matmul_into_matvec( + %a: tensor<4096x4096xf64>, + %b: tensor<4096x4096xf64>, + %x: tensor<4096xf64> + ) -> tensor<4096xf64> { + + %c_init = tensor.empty() : tensor<4096x4096xf64> + %c = func.call @matmul(%a, %b, %c_init) { fuse_with = "matvec" } + : (tensor<4096x4096xf64>, tensor<4096x4096xf64>, tensor<4096x4096xf64>) + -> tensor<4096x4096xf64> + + %y_init = tensor.empty() : tensor<4096xf64> + %y_out = func.call @matvec(%c, %x, %y_init) { fuse_with = "matmul" } + : (tensor<4096x4096xf64>, tensor<4096xf64>, tensor<4096xf64>) + -> tensor<4096xf64> + + return 
%y_out : tensor<4096xf64> + } + + func.func @matvec_into_matvec( + %a: tensor<4096x4096xf64>, + %b: tensor<4096x4096xf64>, + %x: tensor<4096xf64> + ) -> tensor<4096xf64> { + %bx_init = tensor.empty() : tensor<4096xf64> + %bx = func.call @matvec(%b, %x, %bx_init) { fuse_with = "matvec" } + : (tensor<4096x4096xf64>, tensor<4096xf64>, tensor<4096xf64>) + -> tensor<4096xf64> + + %y_init = tensor.empty() : tensor<4096xf64> + %y_out = func.call @matvec(%a, %bx, %y_init) { fuse_with = "matvec" } + : (tensor<4096x4096xf64>, tensor<4096xf64>, tensor<4096xf64>) + -> tensor<4096xf64> + + return %y_out : tensor<4096xf64> + } +} diff --git a/tests/Dialect/Transforms/multiple-entrypoints.mlir.gold b/tests/Dialect/Transforms/multiple-entrypoints.mlir.gold new file mode 100644 index 00000000..9d34716e --- /dev/null +++ b/tests/Dialect/Transforms/multiple-entrypoints.mlir.gold @@ -0,0 +1,51 @@ +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d1)> +#map2 = affine_map<(d0, d1) -> (d0)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map4 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map5 = affine_map<(d0, d1, d2) -> (d0, d1)> +module { + func.func @matvec_matvec_0(%arg0: tensor<4096x4096xf64>, %arg1: tensor<4096xf64>, %arg2: tensor<4096xf64>, %arg3: tensor<4096x4096xf64>, %arg4: tensor<4096xf64>) -> tensor<4096xf64> { + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%arg0, %arg1 : tensor<4096x4096xf64>, tensor<4096xf64>) outs(%arg2 : tensor<4096xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %2 = arith.mulf %in, %in_0 : f64 + %3 = arith.addf %out, %2 : f64 + linalg.yield %3 : f64 + } -> tensor<4096xf64> + %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%arg3, %0 : tensor<4096x4096xf64>, tensor<4096xf64>) outs(%arg4 : tensor<4096xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %2 = arith.mulf %in, %in_0 : f64 + %3 = arith.addf %out, %2 : f64 + linalg.yield %3 : f64 + } -> tensor<4096xf64> + return %1 : tensor<4096xf64> + } + func.func @matmul_matvec_0(%arg0: tensor<4096x4096xf64>, %arg1: tensor<4096x4096xf64>, %arg2: tensor<4096x4096xf64>, %arg3: tensor<4096xf64>, %arg4: tensor<4096xf64>) -> tensor<4096xf64> { + %0 = linalg.generic {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<4096x4096xf64>, tensor<4096x4096xf64>) outs(%arg2 : tensor<4096x4096xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %2 = arith.mulf %in, %in_0 : f64 + %3 = arith.addf %out, %2 : f64 + linalg.yield %3 : f64 + } -> tensor<4096x4096xf64> + %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%0, %arg3 : tensor<4096x4096xf64>, tensor<4096xf64>) outs(%arg4 : tensor<4096xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %2 = arith.mulf %in, %in_0 : f64 + %3 = arith.addf %out, %2 : f64 + linalg.yield %3 : f64 + } -> tensor<4096xf64> + return %1 : tensor<4096xf64> + } + func.func @matmul_into_matvec(%arg0: tensor<4096x4096xf64>, %arg1: tensor<4096x4096xf64>, %arg2: tensor<4096xf64>) -> tensor<4096xf64> { + %0 = tensor.empty() : tensor<4096x4096xf64> + %1 = tensor.empty() : tensor<4096xf64> + %2 = call @matmul_matvec_0(%arg0, %arg1, %0, %arg2, %1) {noinline} : (tensor<4096x4096xf64>, tensor<4096x4096xf64>, tensor<4096x4096xf64>, tensor<4096xf64>, tensor<4096xf64>) -> tensor<4096xf64> + return %2 : tensor<4096xf64> + } + func.func @matvec_into_matvec(%arg0: 
tensor<4096x4096xf64>, %arg1: tensor<4096x4096xf64>, %arg2: tensor<4096xf64>) -> tensor<4096xf64> { + %0 = tensor.empty() : tensor<4096xf64> + %1 = tensor.empty() : tensor<4096xf64> + %2 = call @matvec_matvec_0(%arg1, %arg2, %0, %arg0, %1) {noinline} : (tensor<4096x4096xf64>, tensor<4096xf64>, tensor<4096xf64>, tensor<4096x4096xf64>, tensor<4096xf64>) -> tensor<4096xf64> + return %2 : tensor<4096xf64> + } +} + diff --git a/tests/Dialect/Transforms/pcg-fuse-spmv-dot-axpby-axpby.mlir b/tests/Dialect/Transforms/pcg-fuse-spmv-dot-axpby-axpby.mlir new file mode 100644 index 00000000..e87cdf96 --- /dev/null +++ b/tests/Dialect/Transforms/pcg-fuse-spmv-dot-axpby-axpby.mlir @@ -0,0 +1,98 @@ +// RUN: %lapis-opt %s --drive-kernel-fusion | diff %s.gold - +#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 32, crdWidth = 32 }> +#idmap = affine_map<(d0) -> (d0)> +module { + func.func private @spmv(%A: tensor, %x: tensor, %ydst: tensor) -> tensor { + %y = linalg.matvec ins(%A, %x: tensor, tensor) outs(%ydst : tensor) -> tensor + return %y : tensor + } + + func.func private @dot(%x: tensor, %y: tensor) -> f64 { + %0 = tensor.empty() : tensor + %dot = linalg.dot ins(%x, %y : tensor,tensor) outs(%0: tensor) -> tensor + %6 = tensor.extract %dot[] : tensor + return %6: f64 + } + + func.func private @axpby(%a: f64, %x: tensor, %b: f64, %y: tensor, %dst: tensor) -> tensor { + %1 = linalg.generic {indexing_maps = [#idmap, #idmap, #idmap], iterator_types = ["parallel"]} ins(%x, %y: tensor, tensor) outs(%dst : tensor) { + ^bb0(%inx: f64, %iny: f64, %out: f64): + %4 = arith.mulf %inx, %a: f64 + %5 = arith.mulf %iny, %b: f64 + %6 = arith.addf %4, %5: f64 + linalg.yield %6 : f64 + } -> tensor + return %1 : tensor + } + + func.func private @mult(%x: tensor, %y: tensor, %dst: tensor) -> tensor { + %1 = linalg.generic {indexing_maps = [#idmap, #idmap, #idmap], iterator_types = ["parallel"]} ins(%x, %y: tensor, tensor) outs(%dst : tensor) { + ^bb0(%inx: f64, %iny: f64, %out: f64): + %2 = arith.mulf %inx, %iny: f64 + linalg.yield %2 : f64 + } -> tensor + return %1 : tensor + } + + // CG solve with diagonal preconditioner + // Returns: x, numiter, resnorm + func.func @pcg(%A: tensor, %b: tensor, %dinv: tensor, %tol: f64, %maxiter: index) -> (tensor, index, f64) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %n = tensor.dim %b, %c0 : tensor + %f0 = arith.constant 0.0 : f64 + %f1 = arith.constant 1.0 : f64 + %fm1 = arith.constant -1.0 : f64 + + // Preallocate some intermediate tensors for dst-passing style + %buf0 = tensor.empty(%n) : tensor + %buf1 = tensor.empty(%n) : tensor + %buf2 = tensor.empty(%n) : tensor + + // Assume initial guess x0 = 0 + // Then r0 = b - A*x0 = b + %r0 = linalg.copy ins(%b : tensor) outs(%buf0 : tensor) -> tensor + %z0 = func.call @mult(%r0, %dinv, %buf1) : (tensor, tensor, tensor) -> tensor + %p0 = linalg.copy ins(%z0 : tensor) outs(%buf2 : tensor) -> tensor + %x0 = tensor.splat %f0[%n] : tensor + %Apbuf = tensor.empty(%n) : tensor + %rr0 = func.call @dot(%r0, %r0) : (tensor, tensor) -> f64 + %initres = math.sqrt %rr0 : f64 + + %x, %p, %z, %r, %final_relres, %rz, %iters = scf.while (%xiter = %x0, %piter = %p0, %ziter = %z0, %riter = %r0, %rziter = %f0, %i = %c1) : (tensor, tensor, tensor, tensor, f64, index) -> (tensor, tensor, tensor, tensor, f64, f64, index) + { + %Ap = func.call @spmv(%A, %piter, %Apbuf) { fuse_with = "dot" } : (tensor, tensor, tensor) -> tensor + %pAp = func.call @dot(%Ap, %piter) { fuse_with = "spmv" } : 
(tensor, tensor) -> f64 + %rz = func.call @dot(%riter, %ziter) : (tensor, tensor) -> f64 + %alpha = arith.divf %rz, %pAp : f64 + %malpha = arith.negf %alpha : f64 + + // Update x and r + %xnext = func.call @axpby(%f1, %xiter, %alpha, %piter, %xiter) { fuse_with = "axpby" } : (f64, tensor, f64, tensor, tensor) -> tensor + %rnext = func.call @axpby(%f1, %riter, %malpha, %Ap, %riter) { fuse_with = "axpby" } : (f64, tensor, f64, tensor, tensor) -> tensor + + // Test against tolerance and + %rr = func.call @dot(%rnext, %rnext) : (tensor, tensor) -> f64 + %rnorm = math.sqrt %rr : f64 + %relres = arith.divf %rnorm, %initres : f64 + %not_converged = arith.cmpf ogt, %relres, %tol : f64 + + // we have already completed an iteration, which is why i is intially 1 + %below_maxiter = arith.cmpi ne, %i, %maxiter : index + %continue = arith.andi %not_converged, %below_maxiter : i1 + + scf.condition(%continue) %xnext, %piter, %ziter, %rnext, %relres, %rz, %i: tensor, tensor, tensor, tensor, f64, f64, index + } + do { + ^bb0(%xiter : tensor, %piter : tensor, %ziter : tensor, %riter : tensor, %unused : f64, %rziter : f64, %i : index): + %znext = func.call @mult(%riter, %dinv, %ziter) : (tensor, tensor, tensor) -> tensor + %rznext = func.call @dot(%riter, %znext) : (tensor, tensor) -> f64 + %beta = arith.divf %rznext, %rziter : f64 + %pnext = func.call @axpby(%f1, %znext, %beta, %piter, %piter) : (f64, tensor, f64, tensor, tensor) -> tensor + %inext = arith.addi %i, %c1 : index + scf.yield %xiter, %pnext, %znext, %riter, %rznext, %inext : tensor, tensor, tensor, tensor, f64, index + } + return %x, %iters, %final_relres : tensor, index, f64 + } +} + diff --git a/tests/Dialect/Transforms/pcg-fuse-spmv-dot-axpby-axpby.mlir.gold b/tests/Dialect/Transforms/pcg-fuse-spmv-dot-axpby-axpby.mlir.gold new file mode 100644 index 00000000..cb2d082d --- /dev/null +++ b/tests/Dialect/Transforms/pcg-fuse-spmv-dot-axpby-axpby.mlir.gold @@ -0,0 +1,130 @@ +#map = affine_map<(d0) -> (d0)> +#map1 = affine_map<(d0, d1) -> (d0, d1)> +#map2 = affine_map<(d0, d1) -> (d1)> +#map3 = affine_map<(d0, d1) -> (d0)> +#map4 = affine_map<(d0) -> ()> +#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 32, crdWidth = 32 }> +module { + func.func @axpby_axpby_0(%arg0: f64, %arg1: tensor, %arg2: f64, %arg3: tensor, %arg4: tensor, %arg5: f64, %arg6: tensor) -> (tensor, tensor) { + %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg1, %arg3 : tensor, tensor) outs(%arg1 : tensor) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %2 = arith.mulf %in, %arg0 : f64 + %3 = arith.mulf %in_0, %arg2 : f64 + %4 = arith.addf %2, %3 : f64 + linalg.yield %4 : f64 + } -> tensor + %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg4, %arg6 : tensor, tensor) outs(%arg4 : tensor) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %2 = arith.mulf %in, %arg0 : f64 + %3 = arith.mulf %in_0, %arg5 : f64 + %4 = arith.addf %2, %3 : f64 + linalg.yield %4 : f64 + } -> tensor + return %0, %1 : tensor, tensor + } + func.func @spmv_dot_0(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, f64) { + %0 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "reduction"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %3 = arith.mulf %in, %in_0 : f64 + %4 = arith.addf %out, %3 : f64 + linalg.yield %4 : f64 + } -> tensor + %1 = tensor.empty() : tensor + %2 = 
linalg.generic {indexing_maps = [#map, #map, #map4], iterator_types = ["reduction"]} ins(%0, %arg1 : tensor, tensor) outs(%1 : tensor) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %3 = arith.mulf %in, %in_0 : f64 + %4 = arith.addf %out, %3 : f64 + linalg.yield %4 : f64 + } -> tensor + %extracted = tensor.extract %2[] : tensor + return %0, %extracted : tensor, f64 + } + func.func @pcg(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: f64, %arg4: index) -> (tensor, index, f64) { + %cst = arith.constant 1.000000e+00 : f64 + %cst_0 = arith.constant 0.000000e+00 : f64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim = tensor.dim %arg1, %c0 : tensor + %0 = tensor.empty(%dim) : tensor + %1 = tensor.empty(%dim) : tensor + %2 = tensor.empty(%dim) : tensor + %3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg1 : tensor) outs(%0 : tensor) { + ^bb0(%in: f64, %out: f64): + linalg.yield %in : f64 + } -> tensor + %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%3, %arg2 : tensor, tensor) outs(%1 : tensor) { + ^bb0(%in: f64, %in_1: f64, %out: f64): + %11 = arith.mulf %in, %in_1 : f64 + linalg.yield %11 : f64 + } -> tensor + %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%4 : tensor) outs(%2 : tensor) { + ^bb0(%in: f64, %out: f64): + linalg.yield %in : f64 + } -> tensor + %splat = tensor.splat %cst_0[%dim] : tensor + %6 = tensor.empty(%dim) : tensor + %7 = tensor.empty() : tensor + %8 = linalg.generic {indexing_maps = [#map, #map, #map4], iterator_types = ["reduction"]} ins(%3, %3 : tensor, tensor) outs(%7 : tensor) { + ^bb0(%in: f64, %in_1: f64, %out: f64): + %11 = arith.mulf %in, %in_1 : f64 + %12 = arith.addf %out, %11 : f64 + linalg.yield %12 : f64 + } -> tensor + %extracted = tensor.extract %8[] : tensor + %9 = math.sqrt %extracted : f64 + %10:7 = scf.while (%arg5 = %splat, %arg6 = %5, %arg7 = %4, %arg8 = %3, %arg9 = %c1) : (tensor, tensor, tensor, tensor, index) -> (tensor, tensor, tensor, tensor, f64, f64, index) { + %11:2 = func.call @spmv_dot_0(%arg0, %arg6, %6) {noinline} : (tensor, tensor, tensor) -> (tensor, f64) + %12 = tensor.empty() : tensor + %13 = linalg.generic {indexing_maps = [#map, #map, #map4], iterator_types = ["reduction"]} ins(%arg8, %arg7 : tensor, tensor) outs(%12 : tensor) { + ^bb0(%in: f64, %in_3: f64, %out: f64): + %24 = arith.mulf %in, %in_3 : f64 + %25 = arith.addf %out, %24 : f64 + linalg.yield %25 : f64 + } -> tensor + %extracted_1 = tensor.extract %13[] : tensor + %14 = arith.divf %extracted_1, %11#1 : f64 + %15 = arith.negf %14 : f64 + %16:2 = func.call @axpby_axpby_0(%cst, %arg5, %14, %arg6, %arg8, %15, %11#0) {noinline} : (f64, tensor, f64, tensor, tensor, f64, tensor) -> (tensor, tensor) + %17 = tensor.empty() : tensor + %18 = linalg.generic {indexing_maps = [#map, #map, #map4], iterator_types = ["reduction"]} ins(%16#1, %16#1 : tensor, tensor) outs(%17 : tensor) { + ^bb0(%in: f64, %in_3: f64, %out: f64): + %24 = arith.mulf %in, %in_3 : f64 + %25 = arith.addf %out, %24 : f64 + linalg.yield %25 : f64 + } -> tensor + %extracted_2 = tensor.extract %18[] : tensor + %19 = math.sqrt %extracted_2 : f64 + %20 = arith.divf %19, %9 : f64 + %21 = arith.cmpf ogt, %20, %arg3 : f64 + %22 = arith.cmpi ne, %arg9, %arg4 : index + %23 = arith.andi %21, %22 : i1 + scf.condition(%23) %16#0, %arg6, %arg7, %16#1, %20, %extracted_1, %arg9 : tensor, tensor, tensor, tensor, f64, f64, index + } do { + ^bb0(%arg5: tensor, %arg6: tensor, %arg7: 
tensor, %arg8: tensor, %arg9: f64, %arg10: f64, %arg11: index): + %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg8, %arg2 : tensor, tensor) outs(%arg7 : tensor) { + ^bb0(%in: f64, %in_2: f64, %out: f64): + %17 = arith.mulf %in, %in_2 : f64 + linalg.yield %17 : f64 + } -> tensor + %12 = tensor.empty() : tensor + %13 = linalg.generic {indexing_maps = [#map, #map, #map4], iterator_types = ["reduction"]} ins(%arg8, %11 : tensor, tensor) outs(%12 : tensor) { + ^bb0(%in: f64, %in_2: f64, %out: f64): + %17 = arith.mulf %in, %in_2 : f64 + %18 = arith.addf %out, %17 : f64 + linalg.yield %18 : f64 + } -> tensor + %extracted_1 = tensor.extract %13[] : tensor + %14 = arith.divf %extracted_1, %arg10 : f64 + %15 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%11, %arg6 : tensor, tensor) outs(%arg6 : tensor) { + ^bb0(%in: f64, %in_2: f64, %out: f64): + %17 = arith.mulf %in_2, %14 : f64 + %18 = arith.addf %in, %17 : f64 + linalg.yield %18 : f64 + } -> tensor + %16 = arith.addi %arg11, %c1 : index + scf.yield %arg5, %15, %11, %arg8, %16 : tensor, tensor, tensor, tensor, index + } + return %10#0, %10#6, %10#4 : tensor, index, f64 + } +} +
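
A short note on the fusion driver's contract in KernelFusion.cpp above, with a sketch of the marking rule (the extracted diff drops template arguments, so the sketch spells them out; it is illustrative, not the pass code verbatim, and assumes the mlir, mlir::func, and llvm namespaces are visible): two calls land in the same fusion set only when fusion is legal, i.e. every linalg op in the first callee has a parallel-dimension set that includes that of every linalg op in the second, and when both calls carry a fuse_with string attribute whose comma-separated list names the other call's callee, as the tests above show. The fused outer call is then emitted with a noinline unit attribute, and the sub-kernel calls inside the fused body are tagged inline = "true".

// Sketch of the symmetric marking rule enforced by markedForFusion():
// each call's comma-separated "fuse_with" list must name the other callee.
static bool mutuallyMarkedForFusion(func::CallOp a, func::CallOp b) {
  auto listContains = [](func::CallOp call, llvm::StringRef calleeName) {
    Attribute raw = call->getAttr("fuse_with");
    if (!raw)
      return false;
    auto attr = dyn_cast<StringAttr>(raw);
    if (!attr)
      return false;
    llvm::SmallVector<llvm::StringRef> names;
    attr.strref().split(names, ',');
    return llvm::any_of(names, [&](llvm::StringRef n) {
      return n.trim(' ') == calleeName;
    });
  };
  return listContains(a, b.getCallee()) && listContains(b, a.getCallee());
}

The tests exercise this end to end through lapis-opt with --drive-kernel-fusion (registered in lapis-opt.cpp above) and compare the output against the checked-in .gold files via diff.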
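
The fused kernel's signature is assembled from the calls in the fusion set. The sketch below states the boundary rule as assumed from buildArgsToIndexMap() and buildResultsToIndexMap(); the helper names and signatures are illustrative, not taken from the patch.

// Assumed boundary rule: a value becomes an argument of the fused kernel when
// its producer lies outside the fusion set (or it is a block argument), and a
// result when at least one of its users lies outside the set.
static bool crossesIntoFusedKernel(llvm::ArrayRef<func::CallOp> fusionSet,
                                   Value operand) {
  Operation *producer = operand.getDefiningOp();
  bool producedInside =
      producer && llvm::any_of(fusionSet, [&](func::CallOp c) {
        return c.getOperation() == producer;
      });
  return !producedInside; // outside producers and block arguments are passed in
}

static bool escapesFusedKernel(llvm::ArrayRef<func::CallOp> fusionSet,
                               Value result) {
  return llvm::any_of(result.getUsers(), [&](Operation *user) {
    return llvm::none_of(fusionSet, [&](func::CallOp c) {
      return c.getOperation() == user;
    });
  });
}

Values produced and consumed entirely inside the set never appear in the fused signature; insertSubkernelCallsIntoFusedKernel() rewires those uses to the results of the newly created sub-calls instead.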
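
For LinalgGenericReordering.cpp, the brute-force optimizer ranks candidate pairwise contractions by a simple size product, mirroring estimateCost(): contracting two einsum operands is charged the product of the extents of every index appearing in either operand, with shared indices counted once. The standalone snippet below restates that model in plain C++ with no MLIR dependency; the helper name pairwiseContractionCost is hypothetical.

#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Cost of contracting two einsum operands, e.g. "il" with "lj":
// multiply the extent of each distinct index that appears in either spec.
double pairwiseContractionCost(const std::string &specA,
                               const std::vector<int64_t> &shapeA,
                               const std::string &specB,
                               const std::vector<int64_t> &shapeB) {
  std::map<char, int64_t> extents;
  for (size_t i = 0; i < specA.size(); ++i)
    extents[specA[i]] = shapeA[i];
  for (size_t i = 0; i < specB.size(); ++i)
    extents[specB[i]] = shapeB[i];

  double cost = 1.0;
  for (const auto &entry : extents)
    cost *= static_cast<double>(entry.second);
  return cost;
}

As a worked example, for A in R^{10x40}, B in R^{40x20}, and x in R^{20}, the order A*(B*x) is charged 40*20 + 10*40 = 1,200 ("lj" with "j", then "il" with "l"), whereas (A*B)*x is charged 10*40*20 + 10*20 = 8,200, so the optimizer keeps the matvec-first sequence and buildGenericsFromEinsums() emits the generics in that order.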