43 commits
1ec8d24
add kernel fusion passes
a-alveyblanc Nov 6, 2024
dbf9397
add mlir dumps for kernel fusion examples
a-alveyblanc Nov 6, 2024
d115b87
add CG example
a-alveyblanc Nov 6, 2024
e0f6b30
Merge branch 'main' of github.com:sandialabs/LAPIS into add-kernel-fu…
a-alveyblanc Dec 4, 2024
bf9b0b7
Merge branch 'main' of github.com:sandialabs/LAPIS into add-kernel-fu…
a-alveyblanc Dec 12, 2024
6ada7b8
merge
a-alveyblanc Jan 22, 2025
cb0cf6d
Merge branch 'main' of github.com:sandialabs/LAPIS into add-kernel-fu…
a-alveyblanc Jan 29, 2025
564e041
changes
a-alveyblanc May 7, 2025
26973e3
merge
a-alveyblanc May 7, 2025
534f466
generic -> einsum -> reordered generics transformation
a-alveyblanc Jun 12, 2025
448357f
Merge branch 'main' of github.com:sandialabs/LAPIS into add-kernel-fu…
a-alveyblanc Jun 12, 2025
5f5b72e
clean up einsum analysis utils
a-alveyblanc Jun 23, 2025
033ca38
add einsum analysis unit tests
a-alveyblanc Jun 23, 2025
fa2dfd0
improve einsum analysis compatibility checks
a-alveyblanc Jun 24, 2025
f6fb899
add more tests for transformations
a-alveyblanc Jun 24, 2025
0a7ca5f
move linalg generic reordering to its own pass
a-alveyblanc Jun 25, 2025
15322d8
move linalg generic reordering to its own pass
a-alveyblanc Jun 25, 2025
017999c
remove Utils.cpp file since it's no longer needed
a-alveyblanc Jun 25, 2025
96c234f
actually make linalg generic reordering an independent pass
a-alveyblanc Jun 25, 2025
065dfb6
update fusion legality check
a-alveyblanc Jul 1, 2025
f6a05af
all tests passing
a-alveyblanc Jul 8, 2025
c3c27de
PCG mlir dump, refactor kernel fusion pass
a-alveyblanc Jul 11, 2025
95ee6d5
more cleanups + refactor fused kernel return value computation
a-alveyblanc Jul 15, 2025
74198ac
regenerate gold tests for kernel fusion
a-alveyblanc Jul 15, 2025
95a8fa4
change around what to fuse in pcg example
a-alveyblanc Jul 15, 2025
5bf2723
run fusion over non-private functions
a-alveyblanc Jul 16, 2025
586bc4f
update producer/consumer mapping inside of fused kernels
a-alveyblanc Jul 17, 2025
7c46e0a
run a pre-check on einsum analysis to see if it would be profitable t…
a-alveyblanc Jul 22, 2025
ad21ab2
updated tests
a-alveyblanc Jul 23, 2025
a83c5bb
Merge branch 'main' of github.com:sandialabs/LAPIS into add-kernel-fu…
a-alveyblanc Jul 23, 2025
41fefe5
remove FIXME
a-alveyblanc Jul 24, 2025
5b707cc
disable linalg reordering for now
a-alveyblanc Jul 30, 2025
0d0ce43
Merge branch 'main' of github.com:sandialabs/LAPIS into add-kernel-fu…
a-alveyblanc Aug 4, 2025
bea1cde
bugfixes
a-alveyblanc Aug 14, 2025
8fbdc0f
merge conflicts
a-alveyblanc Aug 14, 2025
630e548
Add pass options
a-alveyblanc Oct 8, 2025
f837e44
resolve merge conflict
a-alveyblanc Oct 14, 2025
bf81805
add CMakeFiles, CMakeCache, Makefile to .gitignore
a-alveyblanc Oct 14, 2025
aa30e47
remove CMakeFiles, CMakeCache, Makefile from benchmarks directory
a-alveyblanc Oct 14, 2025
1cd69ed
remove CMakeCache, Makefile
a-alveyblanc Oct 14, 2025
171693d
more removals
a-alveyblanc Oct 14, 2025
de7179e
remove cmake_install.cmake
a-alveyblanc Oct 14, 2025
6446779
more cleanups
a-alveyblanc Oct 14, 2025
12 changes: 12 additions & 0 deletions .gitignore
@@ -95,6 +95,12 @@ __pycache__

.pytype

# CMake related stuff
CMakeFiles
CMakeCache.txt
Makefile
cmake_install.cmake

# Stuff auto-generated by MLIR->Kokkos emitter
examples/lapis*/
examples/*.mlir
@@ -103,6 +109,8 @@ examples/*.tns
tests/lapis/
tests/*.mlir
tests/*.tns
tests/Dialect/Testing
tests/Testing

# Other test artifacts
tests/Dialect/*/Output
@@ -152,3 +160,7 @@ examples/mala_batch/forward_snap.hpp
examples/mala_batch/forward_snap_lowered.mlir
examples/mala_batch/build/

# Temporary files produced by maxpool
examples/maxpool_nchw_teamlevel/maxpool.cpp
examples/maxpool_nchw_teamlevel/maxpool.hpp
examples/maxpool_nchw_teamlevel/maxpool_lowered.mlir
34 changes: 34 additions & 0 deletions examples/mlir-dumps/batched-gemv-axpy-base.mlir
@@ -0,0 +1,34 @@
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @gemv(%arg0: tensor<?x?x?xf64>, %arg1: tensor<?x?xf64>, %arg2: tensor<?x?xf64>) -> tensor<?x?xf64> {
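// Batched GEMV: for each batch index d0, y[d0, d1] += A[d0, d1, d2] * x[d0, d2].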
%0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?xf64>, tensor<?x?xf64>) outs(%arg2 : tensor<?x?xf64>) {
^bb0(%in: f64, %in_0: f64, %out: f64):
%1 = arith.mulf %in, %in_0 : f64
%2 = arith.addf %out, %1 : f64
linalg.yield %2 : f64
} -> tensor<?x?xf64>
return %0 : tensor<?x?xf64>
}
func.func @axpy(%arg0: tensor<?x?xf64>, %arg1: tensor<?x?xf64>, %arg2: tensor<?x?xf64>) -> tensor<?x?xf64> {
%0 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf64>, tensor<?x?xf64>) outs(%arg2 : tensor<?x?xf64>) {
^bb0(%in: f64, %in_0: f64, %out: f64):
%1 = arith.addf %in, %in_0 : f64
linalg.yield %1 : f64
} -> tensor<?x?xf64>
return %0 : tensor<?x?xf64>
}
func.func @main(%batch_size : index, %n : index) -> tensor<?x?xf64> {
%0 = tensor.empty(%batch_size, %n, %n) : tensor<?x?x?xf64>
%1 = tensor.empty(%batch_size, %n) : tensor<?x?xf64>
%2 = tensor.empty(%batch_size, %n) : tensor<?x?xf64>
%3 = tensor.empty(%batch_size, %n) : tensor<?x?xf64>
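// The reciprocal fuse_with attributes on the two calls below mark the
// gemv -> axpy producer/consumer pair as a candidate for the kernel fusion pass.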
%4 = call @gemv(%0, %1, %3) { fuse_with = "axpy" } : (tensor<?x?x?xf64>, tensor<?x?xf64>, tensor<?x?xf64>) -> tensor<?x?xf64>
%5 = tensor.empty(%batch_size, %n) : tensor<?x?xf64>
%6 = call @axpy(%4, %2, %5) { fuse_with = "gemv" } : (tensor<?x?xf64>, tensor<?x?xf64>, tensor<?x?xf64>) -> tensor<?x?xf64>
return %6 : tensor<?x?xf64>
}
}

35 changes: 35 additions & 0 deletions examples/mlir-dumps/batched-gemv-axpy-non-dynamic.mlir
@@ -0,0 +1,35 @@
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
module {
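// noinline keeps each kernel outlined as a separate function so the fusion
// pass can match the annotated call sites instead of inlining them away.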
func.func @gemv(%arg0: tensor<10x10x10xf64>, %arg1: tensor<10x10xf64>, %arg2:
tensor<10x10xf64>) -> tensor<10x10xf64> attributes { noinline } {
%0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<10x10x10xf64>, tensor<10x10xf64>) outs(%arg2 : tensor<10x10xf64>) {
^bb0(%in: f64, %in_0: f64, %out: f64):
%1 = arith.mulf %in, %in_0 : f64
%2 = arith.addf %out, %1 : f64
linalg.yield %2 : f64
} -> tensor<10x10xf64>
return %0 : tensor<10x10xf64>
}
func.func @axpy(%arg0: tensor<10x10xf64>, %arg1: tensor<10x10xf64>, %arg2:
tensor<10x10xf64>) -> tensor<10x10xf64> attributes { noinline } {
%0 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<10x10xf64>, tensor<10x10xf64>) outs(%arg2 : tensor<10x10xf64>) {
^bb0(%in: f64, %in_0: f64, %out: f64):
%1 = arith.addf %in, %in_0 : f64
linalg.yield %1 : f64
} -> tensor<10x10xf64>
return %0 : tensor<10x10xf64>
}
func.func @main() -> tensor<10x10xf64> {
%0 = tensor.empty() : tensor<10x10x10xf64>
%1 = tensor.empty() : tensor<10x10xf64>
%2 = tensor.empty() : tensor<10x10xf64>
%3 = tensor.empty() : tensor<10x10xf64>
%5 = tensor.empty() : tensor<10x10xf64>
%4 = call @gemv(%0, %1, %3) { fuse_with = "axpy" } : (tensor<10x10x10xf64>, tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
%6 = call @axpy(%4, %2, %5) { fuse_with = "gemv" } : (tensor<10x10xf64>, tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
return %6 : tensor<10x10xf64>
}
}
105 changes: 105 additions & 0 deletions examples/mlir-dumps/cg-iteration-unfused.mlir
@@ -0,0 +1,105 @@
module {
func.func @gemv(%arg0: memref<128x128xf64>, %arg1: memref<128xf64>, %arg2: memref<128xf64>) -> memref<128xf64> {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
scf.parallel (%arg3) = (%c0) to (%c128) step (%c1) {
scf.for %arg4 = %c0 to %c128 step %c1 {
%0 = memref.load %arg0[%arg3, %arg4] : memref<128x128xf64>
%1 = memref.load %arg1[%arg4] : memref<128xf64>
%2 = memref.load %arg2[%arg3] : memref<128xf64>
%3 = arith.mulf %0, %1 : f64
%4 = arith.addf %2, %3 : f64
memref.store %4, %arg2[%arg3] : memref<128xf64>
}
scf.reduce
}
return %arg2 : memref<128xf64>
}
func.func @xpy(%arg0: memref<128xf64>, %arg1: memref<128xf64>, %arg2: memref<128xf64>) -> memref<128xf64> {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
scf.parallel (%arg3) = (%c0) to (%c128) step (%c1) {
%0 = memref.load %arg0[%arg3] : memref<128xf64>
%1 = memref.load %arg1[%arg3] : memref<128xf64>
%2 = arith.addf %0, %1 : f64
memref.store %2, %arg2[%arg3] : memref<128xf64>
scf.reduce
}
return %arg2 : memref<128xf64>
}
func.func @dot(%arg0: memref<128xf64>, %arg1: memref<128xf64>, %arg2: memref<f64>) -> memref<f64> {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
scf.for %arg3 = %c0 to %c128 step %c1 {
%0 = memref.load %arg0[%arg3] : memref<128xf64>
%1 = memref.load %arg1[%arg3] : memref<128xf64>
%2 = memref.load %arg2[] : memref<f64>
%3 = arith.mulf %0, %1 : f64
%4 = arith.addf %2, %3 : f64
memref.store %4, %arg2[] : memref<f64>
}
return %arg2 : memref<f64>
}
func.func @dscal(%arg0: memref<f64>, %arg1: memref<128xf64>, %arg2: memref<128xf64>) -> memref<128xf64> {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
scf.parallel (%arg3) = (%c0) to (%c128) step (%c1) {
%0 = memref.load %arg0[] : memref<f64>
%1 = memref.load %arg1[%arg3] : memref<128xf64>
%2 = arith.mulf %0, %1 : f64
memref.store %2, %arg2[%arg3] : memref<128xf64>
scf.reduce
}
return %arg2 : memref<128xf64>
}
func.func @div(%arg0: memref<f64>, %arg1: memref<f64>, %arg2: memref<f64>) -> memref<f64> {
%0 = memref.load %arg0[] : memref<f64>
%1 = memref.load %arg1[] : memref<f64>
%2 = arith.divf %0, %1 : f64
memref.store %2, %arg2[] : memref<f64>
return %arg2 : memref<f64>
}
func.func @neg(%arg0: memref<f64>, %arg1: memref<f64>) -> memref<f64> {
%0 = memref.load %arg0[] : memref<f64>
%1 = arith.negf %0 : f64
memref.store %1, %arg1[] : memref<f64>
return %arg1 : memref<f64>
}
func.func @main(%arg0: memref<128x128xf64>, %arg1: memref<128xf64>, %arg2: memref<128xf64>, %arg3: memref<128xf64>) -> memref<128xf64> {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<f64>
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<128xf64>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<f64>
%alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<128xf64>
%alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<128xf64>
%alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<f64>
%alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<128xf64>
%alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<128xf64>
%alloc_7 = memref.alloc() {alignment = 64 : i64} : memref<f64>
%alloc_8 = memref.alloc() {alignment = 64 : i64} : memref<128xf64>
%alloc_9 = memref.alloc() {alignment = 64 : i64} : memref<128xf64>
%alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<f64>
%alloc_11 = memref.alloc() {alignment = 64 : i64} : memref<f64>
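// One unfused CG iteration: alpha = (r.r)/(p.Ap), x += alpha*p,
// r -= alpha*Ap, beta = (r'.r')/(r.r), p = r' + beta*p.
// Calls with matching fuse_with attributes are fusion candidates;
// an empty fuse_with leaves the call unfused.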


%0 = call @dot(%arg2, %arg2, %alloc) {fuse_with = "gemv"} : (memref<128xf64>, memref<128xf64>, memref<f64>) -> memref<f64>
%1 = call @gemv(%arg0, %arg3, %alloc_0) {fuse_with = "dot"} : (memref<128x128xf64>, memref<128xf64>, memref<128xf64>) -> memref<128xf64>
%2 = call @dot(%1, %arg3, %alloc_1) : (memref<128xf64>, memref<128xf64>, memref<f64>) -> memref<f64>

%3 = call @div(%0, %2, %alloc_10) {fuse_with = ""} : (memref<f64>, memref<f64>, memref<f64>) -> memref<f64>
%4 = call @dscal(%3, %arg3, %alloc_2) {fuse_with = ""} : (memref<f64>, memref<128xf64>, memref<128xf64>) -> memref<128xf64>
%5 = call @xpy(%arg1, %4, %alloc_3) {fuse_with = ""} : (memref<128xf64>, memref<128xf64>, memref<128xf64>) -> memref<128xf64>
%6 = call @neg(%3, %alloc_4) {fuse_with = ""} : (memref<f64>, memref<f64>) -> memref<f64>
%7 = call @dscal(%6, %1, %alloc_5) {fuse_with = ""} : (memref<f64>, memref<128xf64>, memref<128xf64>) -> memref<128xf64>
%8 = call @xpy(%arg2, %7, %alloc_6) {fuse_with = ""} : (memref<128xf64>, memref<128xf64>, memref<128xf64>) -> memref<128xf64>
%9 = call @dot(%8, %8, %alloc_7) {fuse_with = ""} : (memref<128xf64>, memref<128xf64>, memref<f64>) -> memref<f64>
%10 = call @div(%9, %0, %alloc_11) {fuse_with = ""} : (memref<f64>, memref<f64>, memref<f64>) -> memref<f64>
%11 = call @dscal(%10, %arg3, %alloc_8) {fuse_with = ""} : (memref<f64>, memref<128xf64>, memref<128xf64>) -> memref<128xf64>
%12 = call @xpy(%8, %11, %alloc_9) {fuse_with = ""} : (memref<128xf64>, memref<128xf64>, memref<128xf64>) -> memref<128xf64>
return %12 : memref<128xf64>
}
}

28 changes: 28 additions & 0 deletions examples/mlir-dumps/matvec-dot.mlir
@@ -0,0 +1,28 @@
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 32, crdWidth = 32 }>

module {
func.func private @spmv(%A: tensor<?x?xf64, #sparse>, %x: tensor<?xf64>, %ydst: tensor<?xf64>) -> tensor<?xf64> {
%y = linalg.matvec ins(%A, %x: tensor<?x?xf64, #sparse>, tensor<?xf64>) outs(%ydst : tensor<?xf64>) -> tensor<?xf64>
return %y : tensor<?xf64>
}

func.func @dot(%x : tensor<?xf64>, %y : tensor<?xf64>, %res : tensor<f64>) ->
tensor<f64> attributes { noinline } {
%dot = linalg.dot ins(%x, %y: tensor<?xf64>, tensor<?xf64>)
outs(%res: tensor<f64>) -> tensor<f64>
return %dot: tensor<f64>
}

func.func @main(%A : tensor<?x?xf64, #sparse>, %x : tensor<?xf64>, %y : tensor<?xf64>)
-> f64 {
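// spmv and dot carry reciprocal fuse_with attributes, requesting that the
// spmv result feed the dot directly instead of being materialized first.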
%0 = func.call @spmv(%A, %x, %y) { noinline, fuse_with = "dot" } :
(tensor<?x?xf64, #sparse>, tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>

%dot_res = tensor.empty() : tensor<f64>
%1 = func.call @dot(%0, %x, %dot_res) { noinline, fuse_with = "spmv" } :
(tensor<?xf64>, tensor<?xf64>, tensor<f64>) -> tensor<f64>

%ret = tensor.extract %1[] : tensor<f64>
return %ret : f64
}
}
59 changes: 59 additions & 0 deletions examples/mlir-dumps/mmv.mlir
@@ -0,0 +1,59 @@
module {
func.func private @matmul(
%a: tensor<4096x4096xf64>,
%b: tensor<4096x4096xf64>,
%c_out: tensor<4096x4096xf64>
) -> tensor<4096x4096xf64> {
%c = linalg.matmul ins(%a, %b: tensor<4096x4096xf64>, tensor<4096x4096xf64>)
outs(%c_out: tensor<4096x4096xf64>) -> tensor<4096x4096xf64>
return %c : tensor<4096x4096xf64>
}

func.func private @matvec(
%a: tensor<4096x4096xf64>,
%x: tensor<4096xf64>,
%y_out: tensor<4096xf64>
) -> tensor<4096xf64> {
%y = linalg.matvec ins(%a, %x: tensor<4096x4096xf64>, tensor<4096xf64>)
outs(%y_out: tensor<4096xf64>) -> tensor<4096xf64>

return %y : tensor<4096xf64>
}

func.func @matmul_into_matvec(
%a: tensor<4096x4096xf64>,
%b: tensor<4096x4096xf64>,
%x: tensor<4096xf64>
) -> tensor<4096xf64> {
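// Fuse a matmul producer into a matvec consumer: y = (A*B)*x.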

%c_init = tensor.empty() : tensor<4096x4096xf64>
%c = func.call @matmul(%a, %b, %c_init) { fuse_with = "matvec" }
: (tensor<4096x4096xf64>, tensor<4096x4096xf64>, tensor<4096x4096xf64>)
-> tensor<4096x4096xf64>

%y_init = tensor.empty() : tensor<4096xf64>
%y_out = func.call @matvec(%c, %x, %y_init) { fuse_with = "matmul" }
: (tensor<4096x4096xf64>, tensor<4096xf64>, tensor<4096xf64>)
-> tensor<4096xf64>

return %y_out : tensor<4096xf64>
}

func.func @matvec_into_matvec(
%a: tensor<4096x4096xf64>,
%b: tensor<4096x4096xf64>,
%x: tensor<4096xf64>
) -> tensor<4096xf64> {
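// Fuse two chained matvec calls: y = A*(B*x).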
%bx_init = tensor.empty() : tensor<4096xf64>
%bx = func.call @matvec(%b, %x, %bx_init) { fuse_with = "matvec" }
: (tensor<4096x4096xf64>, tensor<4096xf64>, tensor<4096xf64>)
-> tensor<4096xf64>

%y_init = tensor.empty() : tensor<4096xf64>
%y_out = func.call @matvec(%a, %bx, %y_init) { fuse_with = "matvec" }
: (tensor<4096x4096xf64>, tensor<4096xf64>, tensor<4096xf64>)
-> tensor<4096xf64>

return %y_out : tensor<4096xf64>
}
}
97 changes: 97 additions & 0 deletions examples/mlir-dumps/pcg.mlir
@@ -0,0 +1,97 @@
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 32, crdWidth = 32 }>
#idmap = affine_map<(d0) -> (d0)>
module {
func.func private @spmv(%A: tensor<?x?xf64, #sparse>, %x: tensor<?xf64>, %ydst: tensor<?xf64>) -> tensor<?xf64> {
%y = linalg.matvec ins(%A, %x: tensor<?x?xf64, #sparse>, tensor<?xf64>) outs(%ydst : tensor<?xf64>) -> tensor<?xf64>
return %y : tensor<?xf64>
}

func.func private @dot(%x: tensor<?xf64>, %y: tensor<?xf64>) -> f64 {
%0 = tensor.empty() : tensor<f64>
%dot = linalg.dot ins(%x, %y : tensor<?xf64>,tensor<?xf64>) outs(%0: tensor<f64>) -> tensor<f64>
%6 = tensor.extract %dot[] : tensor<f64>
return %6: f64
}

func.func private @axpby(%a: f64, %x: tensor<?xf64>, %b: f64, %y: tensor<?xf64>, %dst: tensor<?xf64>) -> tensor<?xf64> {
%1 = linalg.generic {indexing_maps = [#idmap, #idmap, #idmap], iterator_types = ["parallel"]} ins(%x, %y: tensor<?xf64>, tensor<?xf64>) outs(%dst : tensor<?xf64>) {
^bb0(%inx: f64, %iny: f64, %out: f64):
%4 = arith.mulf %inx, %a: f64
%5 = arith.mulf %iny, %b: f64
%6 = arith.addf %4, %5: f64
linalg.yield %6 : f64
} -> tensor<?xf64>
return %1 : tensor<?xf64>
}

func.func private @mult(%x: tensor<?xf64>, %y: tensor<?xf64>, %dst: tensor<?xf64>) -> tensor<?xf64> {
%1 = linalg.generic {indexing_maps = [#idmap, #idmap, #idmap], iterator_types = ["parallel"]} ins(%x, %y: tensor<?xf64>, tensor<?xf64>) outs(%dst : tensor<?xf64>) {
^bb0(%inx: f64, %iny: f64, %out: f64):
%2 = arith.mulf %inx, %iny: f64
linalg.yield %2 : f64
} -> tensor<?xf64>
return %1 : tensor<?xf64>
}

// CG solve with diagonal preconditioner
// Returns: x, numiter, resnorm
func.func @pcg(%A: tensor<?x?xf64, #sparse>, %b: tensor<?xf64>, %dinv: tensor<?xf64>, %tol: f64, %maxiter: index) -> (tensor<?xf64>, index, f64) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%n = tensor.dim %b, %c0 : tensor<?xf64>
%f0 = arith.constant 0.0 : f64
%f1 = arith.constant 1.0 : f64
%fm1 = arith.constant -1.0 : f64

// Preallocate some intermediate tensors for dst-passing style
%buf0 = tensor.empty(%n) : tensor<?xf64>
%buf1 = tensor.empty(%n) : tensor<?xf64>
%buf2 = tensor.empty(%n) : tensor<?xf64>

// Assume initial guess x0 = 0
// Then r0 = b - A*x0 = b
%r0 = linalg.copy ins(%b : tensor<?xf64>) outs(%buf0 : tensor<?xf64>) -> tensor<?xf64>
%z0 = func.call @mult(%r0, %dinv, %buf1) : (tensor<?xf64>, tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
%p0 = linalg.copy ins(%z0 : tensor<?xf64>) outs(%buf2 : tensor<?xf64>) -> tensor<?xf64>
%x0 = tensor.splat %f0[%n] : tensor<?xf64>
%Apbuf = tensor.empty(%n) : tensor<?xf64>
%rr0 = func.call @dot(%r0, %r0) : (tensor<?xf64>, tensor<?xf64>) -> f64
%initres = math.sqrt %rr0 : f64

%x, %p, %z, %r, %final_relres, %rz, %iters = scf.while (%xiter = %x0, %piter = %p0, %ziter = %z0, %riter = %r0, %rziter = %f0, %i = %c1) : (tensor<?xf64>, tensor<?xf64>, tensor<?xf64>, tensor<?xf64>, f64, index) -> (tensor<?xf64>, tensor<?xf64>, tensor<?xf64>, tensor<?xf64>, f64, f64, index)
{
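// The spmv and the following dot (p.Ap) are marked as a fusion pair.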
%Ap = func.call @spmv(%A, %piter, %Apbuf) { fuse_with = "dot" } : (tensor<?x?xf64, #sparse>, tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
%pAp = func.call @dot(%Ap, %piter) { fuse_with = "spmv" } : (tensor<?xf64>, tensor<?xf64>) -> f64
%rz = func.call @dot(%riter, %ziter) : (tensor<?xf64>, tensor<?xf64>) -> f64
%alpha = arith.divf %rz, %pAp : f64
%malpha = arith.negf %alpha : f64

// Update x and r
%xnext = func.call @axpby(%f1, %xiter, %alpha, %piter, %xiter) { fuse_with = "axpby" } : (f64, tensor<?xf64>, f64, tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
%rnext = func.call @axpby(%f1, %riter, %malpha, %Ap, %riter) { fuse_with = "axpby" } : (f64, tensor<?xf64>, f64, tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>

// Test against tolerance and maximum iteration count
%rr = func.call @dot(%rnext, %rnext) : (tensor<?xf64>, tensor<?xf64>) -> f64
%rnorm = math.sqrt %rr : f64
%relres = arith.divf %rnorm, %initres : f64
%not_converged = arith.cmpf ogt, %relres, %tol : f64

// We have already completed an iteration, which is why i is initially 1
%below_maxiter = arith.cmpi ne, %i, %maxiter : index
%continue = arith.andi %not_converged, %below_maxiter : i1

scf.condition(%continue) %xnext, %piter, %ziter, %rnext, %relres, %rz, %i: tensor<?xf64>, tensor<?xf64>, tensor<?xf64>, tensor<?xf64>, f64, f64, index
}
do {
^bb0(%xiter : tensor<?xf64>, %piter : tensor<?xf64>, %ziter : tensor<?xf64>, %riter : tensor<?xf64>, %unused : f64, %rziter : f64, %i : index):
%znext = func.call @mult(%riter, %dinv, %ziter) : (tensor<?xf64>, tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
%rznext = func.call @dot(%riter, %znext) : (tensor<?xf64>, tensor<?xf64>) -> f64
%beta = arith.divf %rznext, %rziter : f64
%pnext = func.call @axpby(%f1, %znext, %beta, %piter, %piter) : (f64, tensor<?xf64>, f64, tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
%inext = arith.addi %i, %c1 : index
scf.yield %xiter, %pnext, %znext, %riter, %rznext, %inext : tensor<?xf64>, tensor<?xf64>, tensor<?xf64>, tensor<?xf64>, f64, index
}
return %x, %iters, %final_relres : tensor<?xf64>, index, f64
}
}
