@jpr-snl jpr-snl commented Jul 24, 2025

Implements a DualView class in Python that ties Kokkos View lifetimes to Python objects. This should eliminate what is effectively a large memory leak, in which every LAPIS return value was kept alive until the Python interpreter exited. It should also significantly reduce Host <-> Device communication, since LAPIS results can now stay resident in their Kokkos memory space via DualView.

It also fixes several more uninitialized-value problems stemming from the use of LLVM tensor.empty() in the examples.
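
The lifetime and residency ideas described above can be illustrated with a small toy model. This is only a sketch under stated assumptions, not the LAPIS implementation: `ToyDualView`, `live_allocations`, `transfers`, and `host()` are hypothetical names, and plain Python lists stand in for Kokkos memory. It shows an allocation being released when its Python wrapper is collected, and a host copy being made only when it is first requested.

```python
import weakref

# Hypothetical stand-ins: no Kokkos or LAPIS code is involved in this sketch.
live_allocations = set()   # names of simulated device allocations still alive
transfers = []             # one entry per simulated device-to-host copy


def _release(name):
    """Simulated deallocation of a device buffer."""
    live_allocations.discard(name)


class ToyDualView:
    """Toy wrapper whose simulated device buffer lives as long as the object."""

    def __init__(self, name, device_data):
        self.name = name
        self._device = list(device_data)  # pretend this lives on the device
        self._host = None                 # no host copy until it is requested
        live_allocations.add(name)
        # The callback must not reference self, or the object would never die.
        self._finalizer = weakref.finalize(self, _release, name)

    def host(self):
        """Return host data, copying from the 'device' only on first use."""
        if self._host is None:
            transfers.append(self.name)
            self._host = list(self._device)
        return self._host

    def deallocate(self):
        """Release early; calling the finalizer more than once is harmless."""
        self._finalizer()


a = ToyDualView("a", [1, 2, 3])
b = ToyDualView("b", [4, 5, 6])
a.host()   # first access triggers one simulated copy
a.host()   # second access reuses the existing host copy
del b      # never read on the host, so it is released without any copy
```

In this toy model a result that is never read on the host costs no transfer at all, and its memory is reclaimed once the last Python reference goes away rather than at interpreter exit.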

jpr-snl and others added 6 commits May 16, 2025 14:35
@jpr-snl jpr-snl changed the title WIP: Dual view in python Dual view in python Jul 25, 2025
@jpr-snl jpr-snl changed the title Dual view in python Dual Views managed by Python Jul 25, 2025

jpr-snl commented Jul 25, 2025

I was unable to test these torch examples with the new setup because I don't have torch_mlir installed in my environment:

resnet18_static
twomodules
gemm
batched_gemm
spmv
spmv_sparsetensor
resnet18_dynamic
sparse_matadd
matadd


jpr-snl commented Jul 25, 2025

It's also worth noting that this is a breaking change for any existing code, so we should do a minor version bump. To access LAPIS results, you now have to call .asnumpy() or .asmemref() on the result object (unless the result is a scalar).
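
For existing callers, one way to absorb this change is a small unwrapping helper. The following is a minimal sketch, assuming only what is stated above (result objects expose `.asnumpy()`, scalars are returned directly). `unwrap` and `FakeResult` are hypothetical names introduced for illustration, and `FakeResult` returns a list so that the sketch runs without NumPy or LAPIS installed.

```python
def unwrap(result):
    """Return host data for a result object; pass scalars through unchanged."""
    asnumpy = getattr(result, "asnumpy", None)
    return asnumpy() if callable(asnumpy) else result


class FakeResult:
    """Hypothetical stand-in for the new result object (not the LAPIS class)."""

    def __init__(self, data):
        self._data = data

    def asnumpy(self):
        # The real method presumably returns a NumPy array; a list is used
        # here only to keep the sketch free of third-party dependencies.
        return list(self._data)


dense = unwrap(FakeResult([1.0, 2.0]))  # goes through .asnumpy()
scalar = unwrap(3.5)                    # scalars need no conversion
```

Calling `.asnumpy()` directly at each call site, as in the patch further down this thread, is equally valid; the helper is only convenient when the same code has to handle both scalar and array results.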

jpr-snl and others added 8 commits July 31, 2025 13:01
Merge branch 'main' into DualViewMerge
Signed-off-by: Brian Kelley <bmkelle@sandia.gov>
This change still applies since DualView::deallocate
still resets impl.

Signed-off-by: Brian Kelley <bmkelle@sandia.gov>
Merge in recent main changes into DualViewInPython
@brian-kelley
Collaborator

@jpr-snl I added .asnumpy() to dense results, and everything looks good on the torch/mpact-based tests (patch for these below). Now the only failures are in tests that write their results to pre-existing NumPy arrays; they pass on OpenMP but fail on CUDA.

spmm_dcsr_opdsl
spmv_noalloc
gemm_no_alloc
batched_gemm_no_alloc

Patch to update remaining examples:

diff --git a/examples/batched_gemm.py b/examples/batched_gemm.py
index 5477ab9..ab82bca 100644
--- a/examples/batched_gemm.py
+++ b/examples/batched_gemm.py
@@ -35,7 +35,7 @@ def main():
     k_backend = backend.compile(mlir_module)
 
     print("a*b from kokkos (showing slice [0,:,:] only)")
-    ckokkos = k_backend.forward(a, b)
+    ckokkos = k_backend.forward(a, b).asnumpy()
     print(ckokkos[0, :, :])
 
     print("a*b from pytorch (showing slice [0,:,:] only)")
diff --git a/examples/gemm.py b/examples/gemm.py
index 6865519..362026f 100644
--- a/examples/gemm.py
+++ b/examples/gemm.py
@@ -29,7 +29,7 @@ def main():
     k_backend = backend.compile(mlir_module)
 
     print("a*b from kokkos")
-    print(k_backend.forward(a, b))
+    print(k_backend.forward(a, b).asnumpy())
 
     print("a*b from pytorch")
     print(m.forward(a, b).numpy())
diff --git a/examples/matadd.py b/examples/matadd.py
index f3bb248..1fdb2d7 100644
--- a/examples/matadd.py
+++ b/examples/matadd.py
@@ -34,7 +34,7 @@ def main():
     print(sumTorch)
 
     print("a+b from kokkos")
-    sumKokkos = k_backend.forward(a, b)
+    sumKokkos = k_backend.forward(a, b).asnumpy()
     print(sumKokkos)
 
     sys.exit(0 if allclose(sumTorch, sumKokkos) else 1)
diff --git a/examples/maxpool_nchw.py b/examples/maxpool_nchw.py
index 8ea2786..07ab3ab 100644
--- a/examples/maxpool_nchw.py
+++ b/examples/maxpool_nchw.py
@@ -32,7 +32,7 @@ def main():
     k_backend = backend.compile(mlir_module)
 
     mpTorch = m(T).numpy()
-    mpKokkos = k_backend.forward(T)
+    mpKokkos = k_backend.forward(T).asnumpy()
 
     if allclose(mpTorch, mpKokkos):
         print("Success, results match with torch")
diff --git a/examples/resnet18_dynamic.py b/examples/resnet18_dynamic.py
index 2ec491b..0df9a93 100644
--- a/examples/resnet18_dynamic.py
+++ b/examples/resnet18_dynamic.py
@@ -41,7 +41,7 @@ def load_multiple_images(paths):
 
 def predictions(torch_func, kokkos_func, images, labels):
     torch_pred = torch_func(images)
-    kokkos_pred = torch.from_numpy(kokkos_func(images.numpy()))
+    kokkos_pred = torch.from_numpy(kokkos_func(images.numpy()).asnumpy())
     success = True
     for i in range(3):
         print("Image", i, "top 3 predictions:")
diff --git a/examples/resnet18_static.py b/examples/resnet18_static.py
index 9d5d794..2e4300b 100644
--- a/examples/resnet18_static.py
+++ b/examples/resnet18_static.py
@@ -39,7 +39,7 @@ def predictions(torch_func, kokkos_func, img, labels):
     pred1 = top3_possibilities(torch_func(img), labels)
     print("PyTorch prediction")
     print(pred1)
-    pred2 = top3_possibilities(torch.from_numpy(kokkos_func(img.numpy())), labels)
+    pred2 = top3_possibilities(torch.from_numpy(kokkos_func(img.numpy()).asnumpy()), labels)
     print("LAPIS prediction")
     print(pred2)
     # Return success if top class is correct, and its probability is close to torch's
diff --git a/examples/spmv.py b/examples/spmv.py
index dc3fca9..05473e0 100644
--- a/examples/spmv.py
+++ b/examples/spmv.py
@@ -44,7 +44,7 @@ def main():
     ytorch = module_torch.forward(A, x).numpy()
     print(ytorch)
     print("y = Ax from kokkos:")
-    ykokkos = module_kokkos.lapis_main(rowptrs, colinds, values, ((m, n), (len(rowptrs), len(colinds), len(values))), x.numpy())
+    ykokkos = module_kokkos.lapis_main(rowptrs, colinds, values, ((m, n), (len(rowptrs), len(colinds), len(values))), x.numpy()).asnumpy()
     print(ykokkos)
     sys.exit(0 if np.allclose(ytorch, ykokkos) else 1)
 
diff --git a/examples/spmv_sparsetensor.py b/examples/spmv_sparsetensor.py
index 646b550..df8082b 100644
--- a/examples/spmv_sparsetensor.py
+++ b/examples/spmv_sparsetensor.py
@@ -48,7 +48,7 @@ def main():
     ytorch = module_torch.forward(A, x).numpy()
     print(ytorch)
     print("y = Ax from kokkos:")
-    ykokkos = module_kokkos.lapis_main(Asp, x.numpy())
+    ykokkos = module_kokkos.lapis_main(Asp, x.numpy()).asnumpy()
     print(ykokkos)
     sys.exit(0 if np.allclose(ytorch, ykokkos) else 1)
