diff --git a/examples/bspmm.py b/examples/bspmm.py index 0f51b24c..e101e47f 100644 --- a/examples/bspmm.py +++ b/examples/bspmm.py @@ -100,7 +100,7 @@ def main(): backend = KokkosBackend.KokkosBackend(decompose_tensors=True, parallel_strategy=par, index_instance=instance, num_instances=len(parStrats)) instance += 1 module_kokkos = backend.compile(moduleText) - C_kokkos = module_kokkos.pte_local_bspmm(rowptrs, colinds, values, ((m, n, b), (m+1, nnz, nnz*b)), B) + C_kokkos = module_kokkos.pte_local_bspmm(rowptrs, colinds, values, ((m, n, b), (m+1, nnz, nnz*b)), B).asnumpy() # For debugging: print the CSRV formatted matrix # module_kokkos.print_csrv(rowptrs, colinds, values, ((m, n, b), (m+1, nnz, nnz*b))) if np.allclose(C_gold, C_kokkos): diff --git a/examples/csrv_softmax.py b/examples/csrv_softmax.py index 31e459f7..8647aba3 100644 --- a/examples/csrv_softmax.py +++ b/examples/csrv_softmax.py @@ -155,7 +155,7 @@ def main(): # For debugging: print the CSRV formatted matrix #module_kokkos.print_csrv(A) result = module_kokkos.pte_softmax(A) - resultDense = module_kokkos.csrv_to_dense(result) + resultDense = module_kokkos.csrv_to_dense(result).asnumpy() print("Result (converted to dense): ") print(resultDense) if checkResult is None: diff --git a/examples/gemm_no_alloc.py b/examples/gemm_no_alloc.py index e0be6150..24cbcb1a 100644 --- a/examples/gemm_no_alloc.py +++ b/examples/gemm_no_alloc.py @@ -32,9 +32,14 @@ def main(): ckokkos = torch.zeros((m, n)) backend = KokkosBackend.KokkosBackend(dump_mlir=False) - k_backend = backend.compile(module) + should_compile = True + if should_compile: + k_backend = backend.compile(module) + else: + import lapis_package.lapis_package as k_backend print("a*b from kokkos") + print(f"{type(a)=} {type(b)=} {type(ckokkos)=}") k_backend.forward(a, b, ckokkos) print(ckokkos) diff --git a/examples/hitting_times.py b/examples/hitting_times.py index d8074a2d..4c5a0712 100644 --- a/examples/hitting_times.py +++ b/examples/hitting_times.py @@ -212,6 +212,7 @@ def main(): (mom1, mom2) = module_kokkos.mht(A.indptr, A.indices, A.data, ((n, n), (n + 1, nnz, nnz)), mask, D, 0.99, 1e-10, 20) # Normalize 2nd moment mom2_norm = module_kokkos.normalize_mom2(mom1, mom2) + mom1, mom2_norm = mom1.asnumpy(), mom2_norm.asnumpy() print("1st moment:", mom1) print("2nd moment:", mom2_norm) mom1_gold = [9.9999999999999957e+01, 9.9999999999999957e+01, 1.9947875961498600e+01, 0, 1.8542105566249749e+01, 1.9909372734041774e+01, 9.9999999999999957e+01, 1.9970947033662114e+01, 2.0196661010469668e+01, 1.4624345260746775e+01, 1.8252877036565582e+01, 1.9213516388393590e+01] diff --git a/examples/issue76.py b/examples/issue76.py index e9476c48..e2841329 100644 --- a/examples/issue76.py +++ b/examples/issue76.py @@ -69,7 +69,7 @@ def main(): module_kokkos.print_dcsr(A_dcsr) [result, rank, nnz] = module_kokkos.column_sums(A_dcsr) # Convert the sparse vector result to dense to check the output - result_dense = module_kokkos.sparse_vec_to_dense(result) + result_dense = module_kokkos.sparse_vec_to_dense(result).asnumpy() print("Results: ", rank, nnz, result_dense) print("Correct result: ", 1, gold_nnz, gold) diff --git a/examples/multiple_results.py b/examples/multiple_results.py index fae3957a..bccc899a 100644 --- a/examples/multiple_results.py +++ b/examples/multiple_results.py @@ -7,10 +7,11 @@ module { func.func @plus_norm(%x: tensor, %y: tensor) -> (tensor, f64) { %c0 = arith.constant 0 : index + %f0 = arith.constant 0.0 : f64 %n = tensor.dim %y, %c0 : tensor - %alloc1 = tensor.empty (%n) : tensor + %alloc1 = tensor.splat %f0[%n] : tensor %x_plus_y = linalg.add ins(%x, %y : tensor, tensor) outs(%alloc1: tensor) -> tensor - %alloc2 = tensor.empty () : tensor + %alloc2 = tensor.splat %f0 : tensor %sumTensor = linalg.dot ins(%x_plus_y, %x_plus_y : tensor, tensor) outs(%alloc2: tensor) -> tensor %sum = tensor.extract %sumTensor[] : tensor %norm = math.sqrt %sum : f64 @@ -29,7 +30,8 @@ def main(): print("x:", x) print("y:", y) - (x_plus_y, norm) = module_kokkos.plus_norm(x, y) + x_plus_y, norm = module_kokkos.plus_norm(x, y) + x_plus_y = x_plus_y.asnumpy() print("x + y:", x_plus_y) print("norm(x+y):", norm) diff --git a/examples/pcg_solve.py b/examples/pcg_solve.py index a29697a0..139e4a23 100644 --- a/examples/pcg_solve.py +++ b/examples/pcg_solve.py @@ -53,16 +53,16 @@ %f1 = arith.constant 1.0 : f64 %fm1 = arith.constant -1.0 : f64 // Preallocate some intermediate tensors for dst-passing style - %buf0 = tensor.empty(%n) : tensor - %buf1 = tensor.empty(%n) : tensor - %buf2 = tensor.empty(%n) : tensor + %buf0 = tensor.splat %f0[%n] : tensor + %buf1 = tensor.splat %f0[%n] : tensor + %buf2 = tensor.splat %f0[%n] : tensor // Assume initial guess x0 = 0 // Then r0 = b - A*x0 = b %r0 = linalg.copy ins(%b : tensor) outs(%buf0 : tensor) -> tensor %z0 = func.call @mult(%r0, %dinv, %buf1) : (tensor, tensor, tensor) -> tensor %p0 = linalg.copy ins(%z0 : tensor) outs(%buf2 : tensor) -> tensor %x0 = tensor.splat %f0[%n] : tensor - %Apbuf = tensor.empty(%n) : tensor + %Apbuf = tensor.splat %f0[%n] : tensor %rr0 = func.call @dot(%r0, %r0) : (tensor, tensor) -> f64 %initres = math.sqrt %rr0 : f64 %x, %p, %z, %r, %final_relres, %rz, %iters = scf.while (%xiter = %x0, %piter = %p0, %ziter = %z0, %riter = %r0, %rziter = %f0, %i = %c1) : (tensor, tensor, tensor, tensor, f64, index) -> (tensor, tensor, tensor, tensor, f64, f64, index) @@ -118,11 +118,16 @@ def main(): maxiter = 40 backend = KokkosBackend.KokkosBackend(decompose_tensors=True) - module_kokkos = backend.compile(moduleText) + should_compile = True + if should_compile: + module_kokkos = backend.compile(moduleText) + else: + import lapis_package.lapis_package as module_kokkos print("x exact solution (first 10 elements):", xgold[:10]) (x, numiter, relres) = module_kokkos.pcg(A.indptr, A.indices, A.data, ((n, n), (n + 1, nnz, nnz)), b, dinv, reltol, maxiter) + x = x.asnumpy() print("Ran CG for", numiter, "iterations and achieved relative residual norm", relres) print("x vector from LAPIS (first 10 elements):", x[:10]) diff --git a/examples/scalar_args.py b/examples/scalar_args.py index db7ee0a0..48e92813 100644 --- a/examples/scalar_args.py +++ b/examples/scalar_args.py @@ -37,8 +37,12 @@ def main(): xcorrect[i] += alpha backend = KokkosBackend.KokkosBackend(decompose_tensors=True) - module_kokkos = backend.compile(moduleText) - x = module_kokkos.f(A, alpha, k, False) + should_compile = True + if should_compile: + module_kokkos = backend.compile(moduleText) + else: + import lapis_package.lapis_package as module_kokkos + x = module_kokkos.f(A, alpha, k, False).asnumpy() print("Result 1:", x) if not np.allclose(x, xcorrect): @@ -49,7 +53,7 @@ def main(): for i in range(8): xcorrect[i] -= alpha - x = module_kokkos.f(A, alpha, k, True) + x = module_kokkos.f(A, alpha, k, True).asnumpy() print("Result 2:", x) if not np.allclose(x, xcorrect): diff --git a/examples/sparse_axpy.py b/examples/sparse_axpy.py index 1e2710d5..0c8f9ac3 100644 --- a/examples/sparse_axpy.py +++ b/examples/sparse_axpy.py @@ -85,9 +85,7 @@ def check_axpy(module_kokkos, v1_pos, v1_inds, v1_vals, v2_pos, v2_inds, v2_vals print("Result crd:", result_crd) print("Result val:", result_val) print("Test case result:") - # Alternate way to print sparse vector, using sparse_tensor.print op - #module_kokkos.print_sparse_vec(result) - result = module_kokkos.sparse_to_dense(result) + result = module_kokkos.sparse_to_dense(result).asnumpy() if correct_nnz != actual_nnz: print("Failed: result nonzero count incorrect") return False diff --git a/examples/spmv_noalloc.py b/examples/spmv_noalloc.py index 20f2c33d..56b99b58 100644 --- a/examples/spmv_noalloc.py +++ b/examples/spmv_noalloc.py @@ -23,7 +23,11 @@ def main(): ykokkos = np.zeros((m), dtype=np.double) backend = KokkosBackend.KokkosBackend(decompose_tensors=True) - module_kokkos = backend.compile(moduleText) + should_compile = True + if should_compile: + module_kokkos = backend.compile(moduleText) + else: + import lapis_package.lapis_package as module_kokkos print("y = Ax from kokkos:") module_kokkos.spmv(rowptrs, colinds, values, ((m, n), (len(rowptrs), len(colinds), len(values))), x, ykokkos) diff --git a/mlir/lib/Target/KokkosCpp/LAPISSupport.hpp b/mlir/lib/Target/KokkosCpp/LAPISSupport.hpp index e69a02d1..a803bd9b 100644 --- a/mlir/lib/Target/KokkosCpp/LAPISSupport.hpp +++ b/mlir/lib/Target/KokkosCpp/LAPISSupport.hpp @@ -4,8 +4,12 @@ #include #include +struct StridedMemRefTypeBase +{ +}; + template -struct StridedMemRefType { +struct StridedMemRefType : public StridedMemRefTypeBase { T *basePtr; T *data; int64_t offset; @@ -15,24 +19,6 @@ struct StridedMemRefType { namespace LAPIS { - using TeamPolicy = Kokkos::TeamPolicy<>; - using TeamMember = typename TeamPolicy::member_type; - - template - StridedMemRefType viewToStridedMemref(const V& v) - { - StridedMemRefType smr; - smr.basePtr = v.data(); - smr.data = v.data(); - smr.offset = 0; - for(int i = 0; i < int(V::rank); i++) - { - smr.sizes[i] = v.extent(i); - smr.strides[i] = v.stride(i); - } - return smr; - } - template V stridedMemrefToView(const StridedMemRefType& smr) { @@ -90,33 +76,39 @@ namespace LAPIS return V(&smr.data[smr.offset], layout); } - // KeepAlive structure keeps a reference to Kokkos::Views which - // are returned to Python. Since it's difficult to transfer ownership of a - // Kokkos::View's memory to numpy, we just have the Kokkos::View maintain ownership - // and return an unmanaged numpy array to Python. - // - // All these views will be deallocated during lapis_finalize to avoid leaking. - // The downside is that if a function is called many times, - // all its results are kept in memory at the same time. - struct KeepAlive - { - virtual ~KeepAlive() {} - }; - - template - struct KeepAliveT : public KeepAlive + struct PythonParameterBase { - // Make a shallow-copy of val - KeepAliveT(const T& val) : p(new T(val)) {} - std::unique_ptr p; + enum WrapperType : int32_t { + EMPTY_TYPE = 0, + STRIDED_MEMREF_TYPE = 1, + DUALVIEW_TYPE = 2 + }; + + WrapperType wrapper_type; + int32_t rank; + + union { + struct StridedMemRefTypeBase* smr; + struct DualViewBase* view; + }; }; - static std::vector> alives; + using TeamPolicy = Kokkos::TeamPolicy<>; + using TeamMember = typename TeamPolicy::member_type; - template - void keepAlive(const T& val) + template + StridedMemRefType viewToStridedMemref(const V& v) { - alives.emplace_back(new KeepAliveT(val)); + StridedMemRefType smr; + smr.basePtr = v.data(); + smr.data = v.data(); + smr.offset = 0; + for(int i = 0; i < int(V::rank); i++) + { + smr.sizes[i] = v.extent(i); + smr.strides[i] = v.stride(i); + } + return smr; } // DualView design @@ -130,30 +122,37 @@ namespace LAPIS // - Assume that any DualView's parent is contiguous, and can be deep-copied between h and d // - All DualViews with the same parent share the parent's modify flags // - // DualViewBase can also "keepAliveHost" to keep its host view alive until lapis_finalize is called. - // This is used to safely return host views to python for numpy arrays to alias. - - struct DualViewBase + struct DualViewImplBase { - virtual ~DualViewBase() {} + enum AliasStatus + { + ALIAS_STATUS_UNKNOWN = 0, + HOST_IS_ALIAS = 1, + DEVICE_IS_ALIAS = 2, + NEITHER_IS_ALIAS = 3 + }; + + virtual ~DualViewImplBase() {} virtual void syncHost() = 0; virtual void syncDevice() = 0; - virtual void keepAliveHost() = 0; + virtual void toStridedMemRef(StridedMemRefTypeBase* vp_out) = 0; bool modified_host = false; bool modified_device = false; - std::shared_ptr parent; + std::shared_ptr parent; + AliasStatus alias_status; - void setParent(const std::shared_ptr& parent_) + void setParent(const std::shared_ptr& parent_) { this->parent = parent_; } }; template - struct DualViewImpl : public DualViewBase + struct DualViewImpl : public DualViewImplBase { using HostView = Kokkos::View; using DeviceView = Kokkos::View; + using HostMemRefType = StridedMemRefType; static constexpr bool deviceAccessesHost = Kokkos::SpaceAccessibility::accessible; static constexpr bool hostAccessesDevice = Kokkos::SpaceAccessibility::accessible; @@ -202,9 +201,11 @@ namespace LAPIS modified_device = true; if constexpr(deviceAccessesHost) { host_view = HostView(v.data(), v.layout()); + alias_status = AliasStatus::HOST_IS_ALIAS; } else { host_view = HostView(Kokkos::view_alloc(Kokkos::WithoutInitializing, v.label() + "_host"), v.layout()); + alias_status = AliasStatus::NEITHER_IS_ALIAS; } device_view = v; } @@ -212,9 +213,11 @@ namespace LAPIS modified_host = true; if constexpr(deviceAccessesHost) { device_view = DeviceView(v.data(), v.layout()); + alias_status = AliasStatus::DEVICE_IS_ALIAS; } else { device_view = DeviceView(Kokkos::view_alloc(Kokkos::WithoutInitializing, v.label() + "_dev"), v.layout()); + alias_status = AliasStatus::NEITHER_IS_ALIAS; } host_view = v; } @@ -281,15 +284,6 @@ namespace LAPIS } } - void keepAliveHost() override - { - // keep the parent's host view alive. - // It is assumed to be either managed, - // or unmanaged but references memory (e.g. from numpy) - // with a longer lifetime that any result from the current LAPIS function. - keepAlive(host_view); - } - void deallocate() { device_view = DeviceView(); host_view = HostView(); @@ -303,16 +297,28 @@ namespace LAPIS return device_view.stride(dim); } + void toStridedMemRef(StridedMemRefTypeBase* out) { + syncHost(); + *static_cast(out) = viewToStridedMemref(host_view); + } + DeviceView device_view; HostView host_view; }; + struct DualViewBase + { + virtual void toStridedMemRef(StridedMemRefTypeBase* out) = 0; + virtual ~DualViewBase() {} + }; + template - struct DualView + struct DualView : public DualViewBase { using ImplType = DualViewImpl; using DeviceView = typename ImplType::DeviceView; using HostView = typename ImplType::HostView; + using HostMemRefType = typename ImplType::HostMemRefType; std::shared_ptr impl; bool syncHostWhenDestroyed = false; @@ -332,6 +338,11 @@ namespace LAPIS impl->setParent(impl); } + void toStridedMemRef(StridedMemRefTypeBase* out) + { + impl->toStridedMemRef(out); + } + template DualView(const V& v) { static_assert(std::is_same_v, @@ -349,12 +360,12 @@ namespace LAPIS impl->setParent(parent.impl->parent); } - ~DualView() { + virtual ~DualView() { if(!impl) return; if(syncHostWhenDestroyed) syncHost(); - DualViewBase* parent = impl->parent.get(); + DualViewImplBase* parent = impl->parent.get(); impl.reset(); - // All DualViewBases keep a shared reference to themselves, so + // All DualViewImplBases keep a shared reference to themselves, so // parent always keeps a shared_ptr to itself. This would normally // prevent the parent destructor ever being called. // @@ -416,15 +427,63 @@ namespace LAPIS return impl->stride(dim); } - void keepAliveHost() const { - impl->parent->keepAliveHost(); - } - void syncHostOnDestroy() { syncHostWhenDestroyed = true; } }; + template + struct PythonParameter : public PythonParameterBase + { + DV toView() { + switch(wrapper_type) + { + case STRIDED_MEMREF_TYPE: + return stridedMemrefToView(*static_cast(smr)); + break; + + case DUALVIEW_TYPE: + return *dynamic_cast(view); + break; + + default: + assert(false); + + // In case asserts are turned off, initialize to nullptr to make it easier to debug + DV* ret = nullptr; + return *ret; + }; + } + + PythonParameter(const DV& dv) + { + wrapper_type = DUALVIEW_TYPE; + rank = DV::HostView::rank; + view = new DV(dv); + } + + PythonParameter(const PythonParameter& other) + { + wrapper_type = other.wrapper_type; + rank = other.rank; + if(wrapper_type == DUALVIEW_TYPE) { + view = new DV(other.view); + }else if(wrapper_type == STRIDED_MEMREF_TYPE) { + smr = new typename DV::HostMemRefType(static_cast(other.smr)); + } + } + + ~PythonParameter() + { + if(wrapper_type == DUALVIEW_TYPE) + { + delete static_cast(view); + }else if(wrapper_type == STRIDED_MEMREF_TYPE) { + delete static_cast(smr); + } + } + }; + inline int threadParallelVectorLength(int par) { if (par < 1) return 1; diff --git a/mlir/lib/Target/KokkosCpp/LAPISSupportFormatted.hpp b/mlir/lib/Target/KokkosCpp/LAPISSupportFormatted.hpp index 3a1ea8e3..a53f16e1 100644 --- a/mlir/lib/Target/KokkosCpp/LAPISSupportFormatted.hpp +++ b/mlir/lib/Target/KokkosCpp/LAPISSupportFormatted.hpp @@ -4,8 +4,12 @@ "#include \n" "#include \n" "\n" +"struct StridedMemRefTypeBase\n" +"{\n" +"};\n" +"\n" "template \n" -"struct StridedMemRefType {\n" +"struct StridedMemRefType : public StridedMemRefTypeBase {\n" " T *basePtr;\n" " T *data;\n" " int64_t offset;\n" @@ -15,24 +19,6 @@ "\n" "namespace LAPIS\n" "{\n" -" using TeamPolicy = Kokkos::TeamPolicy<>;\n" -" using TeamMember = typename TeamPolicy::member_type;\n" -"\n" -" template\n" -" StridedMemRefType viewToStridedMemref(const V& v)\n" -" {\n" -" StridedMemRefType smr;\n" -" smr.basePtr = v.data();\n" -" smr.data = v.data();\n" -" smr.offset = 0;\n" -" for(int i = 0; i < int(V::rank); i++)\n" -" {\n" -" smr.sizes[i] = v.extent(i);\n" -" smr.strides[i] = v.stride(i);\n" -" }\n" -" return smr;\n" -" }\n" -"\n" " template\n" " V stridedMemrefToView(const StridedMemRefType& smr)\n" " {\n" @@ -90,33 +76,39 @@ " return V(&smr.data[smr.offset], layout);\n" " }\n" "\n" -" // KeepAlive structure keeps a reference to Kokkos::Views which\n" -" // are returned to Python. Since it\'s difficult to transfer ownership of a\n" -" // Kokkos::View\'s memory to numpy, we just have the Kokkos::View maintain ownership\n" -" // and return an unmanaged numpy array to Python.\n" -" //\n" -" // All these views will be deallocated during lapis_finalize to avoid leaking.\n" -" // The downside is that if a function is called many times,\n" -" // all its results are kept in memory at the same time.\n" -" struct KeepAlive\n" -" {\n" -" virtual ~KeepAlive() {}\n" -" };\n" -"\n" -" template\n" -" struct KeepAliveT : public KeepAlive\n" +" struct PythonParameterBase\n" " {\n" -" // Make a shallow-copy of val\n" -" KeepAliveT(const T& val) : p(new T(val)) {}\n" -" std::unique_ptr p;\n" +" enum WrapperType : int32_t {\n" +" EMPTY_TYPE = 0,\n" +" STRIDED_MEMREF_TYPE = 1,\n" +" DUALVIEW_TYPE = 2\n" +" };\n" +"\n" +" WrapperType wrapper_type;\n" +" int32_t rank;\n" +"\n" +" union {\n" +" struct StridedMemRefTypeBase* smr;\n" +" struct DualViewBase* view;\n" +" };\n" " };\n" "\n" -" static std::vector> alives;\n" +" using TeamPolicy = Kokkos::TeamPolicy<>;\n" +" using TeamMember = typename TeamPolicy::member_type;\n" "\n" -" template\n" -" void keepAlive(const T& val)\n" +" template\n" +" StridedMemRefType viewToStridedMemref(const V& v)\n" " {\n" -" alives.emplace_back(new KeepAliveT(val));\n" +" StridedMemRefType smr;\n" +" smr.basePtr = v.data();\n" +" smr.data = v.data();\n" +" smr.offset = 0;\n" +" for(int i = 0; i < int(V::rank); i++)\n" +" {\n" +" smr.sizes[i] = v.extent(i);\n" +" smr.strides[i] = v.stride(i);\n" +" }\n" +" return smr;\n" " }\n" "\n" " // DualView design\n" @@ -130,30 +122,37 @@ " // - Assume that any DualView\'s parent is contiguous, and can be deep-copied between h and d\n" " // - All DualViews with the same parent share the parent\'s modify flags\n" " //\n" -" // DualViewBase can also \"keepAliveHost\" to keep its host view alive until lapis_finalize is called.\n" -" // This is used to safely return host views to python for numpy arrays to alias.\n" -"\n" -" struct DualViewBase\n" +" struct DualViewImplBase\n" " {\n" -" virtual ~DualViewBase() {}\n" +" enum AliasStatus\n" +" {\n" +" ALIAS_STATUS_UNKNOWN = 0,\n" +" HOST_IS_ALIAS = 1,\n" +" DEVICE_IS_ALIAS = 2,\n" +" NEITHER_IS_ALIAS = 3\n" +" };\n" +"\n" +" virtual ~DualViewImplBase() {}\n" " virtual void syncHost() = 0;\n" " virtual void syncDevice() = 0;\n" -" virtual void keepAliveHost() = 0;\n" +" virtual void toStridedMemRef(StridedMemRefTypeBase* vp_out) = 0;\n" " bool modified_host = false;\n" " bool modified_device = false;\n" -" std::shared_ptr parent;\n" +" std::shared_ptr parent;\n" +" AliasStatus alias_status;\n" "\n" -" void setParent(const std::shared_ptr& parent_)\n" +" void setParent(const std::shared_ptr& parent_)\n" " {\n" " this->parent = parent_;\n" " }\n" " };\n" "\n" " template\n" -" struct DualViewImpl : public DualViewBase\n" +" struct DualViewImpl : public DualViewImplBase\n" " {\n" " using HostView = Kokkos::View;\n" " using DeviceView = Kokkos::View;\n" +" using HostMemRefType = StridedMemRefType;\n" "\n" " static constexpr bool deviceAccessesHost = Kokkos::SpaceAccessibility::accessible;\n" " static constexpr bool hostAccessesDevice = Kokkos::SpaceAccessibility::accessible;\n" @@ -202,9 +201,11 @@ " modified_device = true;\n" " if constexpr(deviceAccessesHost) {\n" " host_view = HostView(v.data(), v.layout());\n" +" alias_status = AliasStatus::HOST_IS_ALIAS;\n" " }\n" " else {\n" " host_view = HostView(Kokkos::view_alloc(Kokkos::WithoutInitializing, v.label() + \"_host\"), v.layout());\n" +" alias_status = AliasStatus::NEITHER_IS_ALIAS;\n" " }\n" " device_view = v;\n" " }\n" @@ -212,9 +213,11 @@ " modified_host = true;\n" " if constexpr(deviceAccessesHost) {\n" " device_view = DeviceView(v.data(), v.layout());\n" +" alias_status = AliasStatus::DEVICE_IS_ALIAS;\n" " }\n" " else {\n" " device_view = DeviceView(Kokkos::view_alloc(Kokkos::WithoutInitializing, v.label() + \"_dev\"), v.layout());\n" +" alias_status = AliasStatus::NEITHER_IS_ALIAS;\n" " }\n" " host_view = v;\n" " }\n" @@ -281,15 +284,6 @@ " }\n" " }\n" "\n" -" void keepAliveHost() override\n" -" {\n" -" // keep the parent\'s host view alive.\n" -" // It is assumed to be either managed,\n" -" // or unmanaged but references memory (e.g. from numpy)\n" -" // with a longer lifetime that any result from the current LAPIS function.\n" -" keepAlive(host_view);\n" -" }\n" -"\n" " void deallocate() {\n" " device_view = DeviceView();\n" " host_view = HostView();\n" @@ -303,16 +297,28 @@ " return device_view.stride(dim);\n" " }\n" "\n" +" void toStridedMemRef(StridedMemRefTypeBase* out) {\n" +" syncHost();\n" +" *static_cast(out) = viewToStridedMemref(host_view);\n" +" }\n" +"\n" " DeviceView device_view;\n" " HostView host_view;\n" " };\n" "\n" +" struct DualViewBase\n" +" {\n" +" virtual void toStridedMemRef(StridedMemRefTypeBase* out) = 0;\n" +" virtual ~DualViewBase() {}\n" +" };\n" +"\n" " template\n" -" struct DualView\n" +" struct DualView : public DualViewBase\n" " {\n" " using ImplType = DualViewImpl;\n" " using DeviceView = typename ImplType::DeviceView;\n" " using HostView = typename ImplType::HostView;\n" +" using HostMemRefType = typename ImplType::HostMemRefType;\n" "\n" " std::shared_ptr impl;\n" " bool syncHostWhenDestroyed = false;\n" @@ -332,6 +338,11 @@ " impl->setParent(impl);\n" " }\n" "\n" +" void toStridedMemRef(StridedMemRefTypeBase* out)\n" +" {\n" +" impl->toStridedMemRef(out);\n" +" }\n" +"\n" " template\n" " DualView(const V& v) {\n" " static_assert(std::is_same_v,\n" @@ -349,12 +360,12 @@ " impl->setParent(parent.impl->parent);\n" " }\n" "\n" -" ~DualView() {\n" +" virtual ~DualView() {\n" " if(!impl) return;\n" " if(syncHostWhenDestroyed) syncHost();\n" -" DualViewBase* parent = impl->parent.get();\n" +" DualViewImplBase* parent = impl->parent.get();\n" " impl.reset();\n" -" // All DualViewBases keep a shared reference to themselves, so\n" +" // All DualViewImplBases keep a shared reference to themselves, so\n" " // parent always keeps a shared_ptr to itself. This would normally\n" " // prevent the parent destructor ever being called.\n" " //\n" @@ -416,15 +427,63 @@ " return impl->stride(dim);\n" " }\n" "\n" -" void keepAliveHost() const {\n" -" impl->parent->keepAliveHost();\n" -" }\n" -"\n" " void syncHostOnDestroy() {\n" " syncHostWhenDestroyed = true;\n" " }\n" " };\n" "\n" +" template\n" +" struct PythonParameter : public PythonParameterBase\n" +" {\n" +" DV toView() {\n" +" switch(wrapper_type)\n" +" {\n" +" case STRIDED_MEMREF_TYPE:\n" +" return stridedMemrefToView(*static_cast(smr));\n" +" break;\n" +"\n" +" case DUALVIEW_TYPE:\n" +" return *dynamic_cast(view);\n" +" break;\n" +"\n" +" default:\n" +" assert(false);\n" +"\n" +" // In case asserts are turned off, initialize to nullptr to make it easier to debug\n" +" DV* ret = nullptr;\n" +" return *ret;\n" +" };\n" +" }\n" +"\n" +" PythonParameter(const DV& dv)\n" +" {\n" +" wrapper_type = DUALVIEW_TYPE;\n" +" rank = DV::HostView::rank;\n" +" view = new DV(dv);\n" +" }\n" +"\n" +" PythonParameter(const PythonParameter& other)\n" +" {\n" +" wrapper_type = other.wrapper_type;\n" +" rank = other.rank;\n" +" if(wrapper_type == DUALVIEW_TYPE) {\n" +" view = new DV(other.view);\n" +" }else if(wrapper_type == STRIDED_MEMREF_TYPE) {\n" +" smr = new typename DV::HostMemRefType(static_cast(other.smr));\n" +" }\n" +" }\n" +"\n" +" ~PythonParameter()\n" +" {\n" +" if(wrapper_type == DUALVIEW_TYPE)\n" +" {\n" +" delete static_cast(view);\n" +" }else if(wrapper_type == STRIDED_MEMREF_TYPE) {\n" +" delete static_cast(smr);\n" +" }\n" +" }\n" +" };\n" +"\n" " inline int threadParallelVectorLength(int par) {\n" " if (par < 1)\n" " return 1;\n" diff --git a/mlir/lib/Target/KokkosCpp/TranslateToKokkosCpp.cpp b/mlir/lib/Target/KokkosCpp/TranslateToKokkosCpp.cpp index 02463a86..31927cba 100644 --- a/mlir/lib/Target/KokkosCpp/TranslateToKokkosCpp.cpp +++ b/mlir/lib/Target/KokkosCpp/TranslateToKokkosCpp.cpp @@ -2421,13 +2421,11 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F auto retType = ftype.getResult(i); if(auto memrefType = dyn_cast(retType)) { - emitter << "StridedMemRefType<"; - if (failed(emitter.emitType(loc, memrefType.getElementType()))) - return func.emitError("Failed to emit result type as StridedMemRefType"); - emitter << ", " << memrefType.getShape().size() << ">** ret" << i; - } - else - { + emitter << "LAPIS::PythonParameter<"; + if (failed(emitter.emitMemrefType(loc, memrefType, kokkos::MemorySpace::DualView))) + return func.emitError("Failed to emit result type as DualView"); + emitter << ">** ret" << i; + }else{ //Assuming it is a scalar primitive if(failed(emitter.emitType(loc, retType))) return func.emitError("Failed to emit non-memref result type"); @@ -2442,13 +2440,14 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F auto paramType = ftype.getInput(i); if(auto memrefType = dyn_cast(paramType)) { - emitter << "StridedMemRefType<"; - if (failed(emitter.emitType(loc, memrefType.getElementType()))) - return func.emitError("Failed to emit param type as StridedMemRefType"); - emitter << ", " << memrefType.getShape().size() << ">* param" << i; + emitter << "LAPIS::PythonParameter<"; + if (failed(emitter.emitMemrefType(loc, memrefType, kokkos::MemorySpace::DualView))) + return func.emitError("Failed to emit param type as DualView"); + emitter << ">* param" << i << "_wrapper"; } else { + //TODO: Handle structs appropriately bool isStruct = isa(paramType); // Structs are passed by const reference if(isStruct) { @@ -2467,10 +2466,14 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F emitter.indent(); //FOR DEBUGGING THE EMITTED CODE: //If uncommented, the following 3 lines make the generated function pause to let user attach a debugger - //emitter << "std::cout << \"Starting MLIR function on process \" << getpid() << '\\n';\n"; - //emitter << "std::cout << \"Optionally attach debugger now, then press to continue: \";\n"; - //emitter << "std::cin.get();\n"; - //Construct an unmanaged, LayoutRight Kokkos::View for each memref input parameter. + //os << "std::cout << \"Starting MLIR function on process \" << getpid() << '\\n';\n"; + //os << "std::cout << \"Optionally attach debugger now, then press to continue: \";\n"; + //os << "std::cin.get();\n"; + //Wrap each parameter in a PythonParameter wrapper. If the parameter is a + //numpy array, the functions that use the parameters will create an unmanaged + //Kokkos::view. If the parameter was already a PythonParameter wrapper, it + //will be passed through. + // //Note: stridedMemrefToView with LayoutRight will check the strides at runtime, //and the python wrapper will use numpy.require to deep-copy the data to the correct //layout if it's not already. @@ -2480,10 +2483,7 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F auto memrefType = dyn_cast(paramType); if(memrefType) { - emitter << "auto param" << i << "_smr = LAPIS::stridedMemrefToView<"; - if(failed(emitter.emitMemrefType(loc, memrefType, kokkos::MemorySpace::Host))) - return func.emitError("Failed to emit memref type as host view"); - emitter << ">(*param" << i << ");\n"; + emitter << "auto param" << i << " = param" << i << "_wrapper->toView();\n"; } } // Emit the call @@ -2498,7 +2498,7 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F auto memrefType = dyn_cast(paramType); if(memrefType) { - emitter << "param" << i << "_smr"; + emitter << "param" << i; } else { @@ -2515,28 +2515,12 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F auto memrefType = dyn_cast(retType); if(memrefType) { - if(numResults == size_t(1)) - emitter << "results.syncHost();\n"; - else - emitter << "std::get<" << i << ">(results).syncHost();\n"; - emitter << "**ret" << i << " = LAPIS::viewToStridedMemref("; + emitter << "new (*ret" << i << ") LAPIS::PythonParameter("; if(numResults == size_t(1)) emitter << "results"; else emitter << "std::get<" << i << ">(results)"; - emitter << ".host_view());\n"; - // Keep the host view alive until lapis_finalize() is called. - // Otherwise it would be deallocated as soon as this function returns. - std::string resultExpr; - if(numResults == size_t(1)) - resultExpr = "results"; - else - resultExpr = "std::get<" + std::to_string(i) + ">(results)"; - // If host and device memory alias each other, one of the views will be an unmanaged - // shallow copy of the other. Keep both host and device alive in this case. - emitter << "LAPIS::keepAlive(" << resultExpr << ".host_view());\n"; - emitter << "if(" << resultExpr << ".host_view().data() == " << resultExpr << ".device_view().data()) \n"; - emitter << " LAPIS::keepAlive(" << resultExpr << ".device_view());\n"; + emitter << ");\n"; } else { @@ -2625,8 +2609,7 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F // NOTE: numpy.zeros(shape, dtype=...) already defaults to LayoutRight (and probably most other functions) // so in practice this shouldn't usually trigger a deep-copy. auto& py_os = emitter.py_ostream(); - //NOTE: this function is a member of the module's class, but py_os is already indented to write methods. - py_os << "def " << funcName << "(self, "; + py_os << "def " << funcName << "("; for(size_t i = 0; i < numParams; i++) { if(i != 0) @@ -2703,7 +2686,7 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F std::string numpyDType = getNumpyType(memrefType.getElementType()); if(!numpyDType.size()) return func.emitError("Could not determine corresponding numpy type for memref element type"); - py_os << "param" << i << " = numpy.require(param" << i << ", dtype=" << numpyDType << ", requirements=['C'])\n"; + py_os << "param" << i << " = wrap_array_parameter(param" << i << ", dtype=" << numpyDType << ")\n"; } else if(auto structType = dyn_cast(paramType)) { // Expect this parameter to be a tuple with the correct structure. Flatten it to a numpy array. @@ -2718,7 +2701,7 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F int flatIdx = 0; genStructFlatten("param" + std::to_string(i), "param_flat" + std::to_string(i), flatIdx, structType); // Replace original param with flattened version, as we don't need original anymore - py_os << "param" << i << " = param_flat" << i << "\n"; + py_os << "param" << i << " = wrap_array_parameter(param_flat" << i << ", dtype=" << numpyDType << ")\n"; } else { // Ensure scalars have the correct type. @@ -2737,7 +2720,7 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F if(auto memrefType = dyn_cast(retType)) { int rank = memrefType.hasRank() ? memrefType.getShape().size() : 1; - py_os << "ret" << i << " = ctypes.pointer(ctypes.pointer(rt.make_nd_memref_descriptor(" << rank << ", " << getCtypesType(memrefType.getElementType()) << ")()))\n"; + py_os << "ret" << i << " = ParameterWrapper.empty(" << getCtypesType(memrefType.getElementType()) << ")\n"; } else if(isa(retType)) { @@ -2752,7 +2735,7 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F std::string numpyDType = getNumpyType(elem); if(!numpyDType.size()) return func.emitError("Could not determine corresponding numpy type for result scalar type"); - py_os << "ret" << i << " = numpy.zeros(" << size << ", dtype=" << numpyDType << ")\n"; + py_os << "ret" << i << " = ParameterWrapper.empty(" << numpyDType << ")\n"; } else { @@ -2764,7 +2747,7 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F } } // Generate the native call. It always returns void. - py_os << "self.libHandle.py_" << funcName << "("; + py_os << "libHandle.py_" << funcName << "("; // Outputs go first for(size_t i = 0; i < numResults; i++) { @@ -2778,11 +2761,15 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F } else if(isa(retType)) { - py_os << "ret" << i; + py_os << "ctypes.pointer(ctypes.pointer(ret" << i << "))"; + } + else if(isa(retType)) + { + py_os << "ret" << i << ".asnumpy().ctypes.data_as(ctypes.c_void_p)"; } else { - // numpy array, flattened struct or scalar + // scalar py_os << "ret" << i << ".ctypes.data_as(ctypes.c_void_p)"; } } @@ -2800,12 +2787,12 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F else if(isa(paramType)) { //Numpy array (or a scalar from a numpy array) - py_os << "ctypes.pointer(rt.get_ranked_memref_descriptor(param" << i << "))"; + py_os << "ctypes.pointer(param" << i << ")"; } else if(isa(paramType)) { //Structs are flattened to 1D Numpy arrays - py_os << "param" << i << ".ctypes.data_as(ctypes.c_void_p)"; + py_os << "param" << i << ".asnumpy().ctypes.data_as(ctypes.c_void_p)"; } else { //Scalar @@ -2814,7 +2801,7 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F } py_os << ")\n"; // Finally, generate the return statement. - // Note that in Python, a 1-elem tuple is equivalent to scalar. + // Note that we return a scalar if a single result is returned. if(numResults) { py_os << "return ("; @@ -2829,7 +2816,7 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F } else if(isa(retType)) { - py_os << "rt.ranked_memref_to_numpy(ret" << i << "[0])"; + py_os << "ret" << i; } else if(auto structType = dyn_cast(retType)) { int idx = 0; @@ -2841,7 +2828,7 @@ static LogicalResult printFunctionDeviceLevel(KokkosCppEmitter &emitter, func::F py_os << "ret" << i << "[0]"; } } - py_os << ")\n"; + py_os << ")\n\n"; } py_os.unindent(); return success(); @@ -4101,17 +4088,21 @@ LogicalResult KokkosCppEmitter::emitOperation(Operation &op, bool trailingSemico LogicalResult KokkosCppEmitter::emitInitAndFinalize(bool finalizeKokkos = true) { + // Declare the init/finalize in decl file selectDeclCppStream(); *this << "extern \"C\" void lapis_initialize();\n"; *this << "extern \"C\" void lapis_finalize();\n"; + if(!emittingTeamLevel()) { + *this << "extern \"C\" void getHostData(StridedMemRefTypeBase* out, LAPIS::PythonParameterBase* in);\n"; + *this << "extern \"C\" void freeDualView(LAPIS::DualViewBase* handle);\n"; + } selectMainCppStream(); - *this << "extern \"C\" void lapis_initialize() {\n"; + *this << "extern \"C\" void lapis_initialize()\n"; + *this << "{\n"; indent(); - // lapis_initialize is never responsible for initializing Kokkos if we're - // emitting team-level. if(!emittingTeamLevel()) { - *this << "if (!Kokkos::is_initialized()) Kokkos::initialize();\n"; + *this << "if (!Kokkos::is_initialized()) Kokkos::initialize();\n"; } //For each global view, that is not unused: // - allocate it @@ -4191,14 +4182,29 @@ LogicalResult KokkosCppEmitter::emitInitAndFinalize(bool finalizeKokkos = true) *this << "();\n"; } } + + // Free views returned to Python + if(finalizeKokkos) { + *this << "Kokkos::finalize();\n"; + } + this->unindent(); + *this << "}\n\n"; + if(!emittingTeamLevel()) { - // Free views returned to Python - *this << "LAPIS::alives.clear();\n"; - if(finalizeKokkos) - *this << "Kokkos::finalize();\n"; + *this << "extern \"C\" void getHostData(StridedMemRefTypeBase* out, LAPIS::PythonParameterBase* in)\n"; + *this << "{\n"; + *this << " assert(in->wrapper_type == LAPIS::PythonParameterBase::DUALVIEW_TYPE);\n"; + *this << " in->view->toStridedMemRef(out);\n"; + *this << "}\n"; + *this << "\n"; + + *this << "extern \"C\" void freeDualView(LAPIS::DualViewBase* handle)\n"; + *this << "{\n"; + *this << " delete handle;\n"; + *this << "}\n"; + *this << "\n"; } - unindent(); - *this << "}\n"; + return success(); } @@ -4217,22 +4223,126 @@ void KokkosCppEmitter::emitCppBoilerplate() void KokkosCppEmitter::emitPythonBoilerplate() { + *py_os << "import atexit\n"; *py_os << "import ctypes\n"; + *py_os << "import enum\n"; + *py_os << "import functools\n"; + *py_os << "import os.path\n"; + *py_os << "import sys\n"; + *py_os << "import types\n"; + *py_os << "import weakref\n"; + *py_os << "\n"; *py_os << "import numpy\n"; *py_os << "from mlir import runtime as rt\n"; - *py_os << "class LAPISModule:\n"; - *py_os << " def __init__(self, libPath):\n"; - //*py_os << " print('Hello from LAPISModule.__init__!')\n"; - *py_os << " self.libHandle = ctypes.CDLL(libPath)\n"; - // Do all initialization immediately - //*py_os << " print('Initializing module.')\n"; - *py_os << " self.libHandle.lapis_initialize()\n"; - //*py_os << " print('Done initializing module.')\n"; + *py_os << "import os.path\n"; + *py_os << "\n"; + *py_os << "dirpath = os.path.dirname(os.path.abspath(__file__))\n"; + *py_os << "modname = __name__.rsplit('.', 1)[-1]\n"; + *py_os << "modpath = os.path.join(dirpath, \"build\", f\"lib{modname}_module.so\")\n"; + *py_os << "if not os.path.isfile(modpath):\n"; + *py_os << " modpath = os.path.join(dirpath, \"build\", f\"lib{modname}_module.dylib\")\n"; + *py_os << "libHandle = ctypes.CDLL(modpath)\n"; + *py_os << "libHandle.lapis_initialize()\n"; + *py_os << "\n"; + *py_os << "class ParameterWrapperType(enum.Enum):\n"; + *py_os << " EMPTY_TYPE = 0\n"; + *py_os << " STRIDED_MEMREF_TYPE = 1\n"; + *py_os << " DUALVIEW_TYPE = 2\n"; + *py_os << "\n"; + *py_os << "class ParameterWrapper(ctypes.Structure):\n"; + *py_os << " _fields_ = [\n"; + *py_os << " ('wrapper_type', ctypes.c_int32),\n"; + *py_os << " ('rank', ctypes.c_int32),\n"; + *py_os << " ('ptr', ctypes.c_void_p),\n"; + *py_os << " ]\n"; + *py_os << "\n"; + *py_os << " _needs_dealloc = weakref.WeakSet()\n"; + *py_os << "\n"; + *py_os << " @classmethod\n"; + *py_os << " def build(cls, wrapper_type, ptr, dtype, rank, base=None):\n"; + *py_os << " ret = cls()\n"; + *py_os << " ret.wrapper_type = wrapper_type.value\n"; + *py_os << " ret.ptr = ctypes.cast(ptr, ctypes.c_void_p)\n"; + *py_os << " ret.rank = rank\n"; + *py_os << " cls._needs_dealloc.add(ret)\n"; + *py_os << " ret.base = base #ties lifespan of base to this object\n"; + *py_os << " ret._ctype = numpy.ctypeslib.as_ctypes_type(dtype)\n"; + *py_os << " return ret\n"; + *py_os << "\n"; + *py_os << " @classmethod\n"; + *py_os << " def empty(cls, dtype, rank=0):\n"; + *py_os << " ret = cls()\n"; + *py_os << " ret.wrapper_type = ParameterWrapperType.EMPTY_TYPE.value\n"; + *py_os << " ret.ptr = ctypes.c_void_p(0)\n"; + *py_os << " ret.rank = rank\n"; + *py_os << " cls._needs_dealloc.add(ret)\n"; + *py_os << " ret._ctype = numpy.ctypeslib.as_ctypes_type(dtype)\n"; + *py_os << " return ret\n"; + *py_os << "\n"; + *py_os << " def asmemref(self):\n"; + *py_os << " ret_type = rt.make_nd_memref_descriptor(self.rank, self._ctype)\n"; + *py_os << " if self.wrapper_type == ParameterWrapperType.STRIDED_MEMREF_TYPE.value:\n"; + *py_os << " ret = ctypes.cast(self.ptr, ctypes.POINTER(ret_type)).contents\n"; + *py_os << " elif self.wrapper_type == ParameterWrapperType.DUALVIEW_TYPE.value:\n"; + *py_os << " ret = ret_type()\n"; + *py_os << " libHandle.getHostData(ctypes.pointer(ret), ctypes.pointer(self))\n"; + *py_os << " ret.base = self # ties lifespan of this object to strided memref ret\n"; + *py_os << " return ret\n"; + *py_os << "\n"; + *py_os << " def asctypes(self):\n"; + *py_os << " smr = self.asmemref()\n"; + *py_os << " size = sum((size-1) for size in smr.shape) + smr.offset\n"; + *py_os << " buffer_type = self._ctype * size\n"; + *py_os << " ret = ctypes.cast(smr.aligned, ctypes.POINTER(buffer_type)).contents\n"; + *py_os << " ret.base = self # ties lifespan of this object to ctypes array ret\n"; + *py_os << " return ret\n"; + *py_os << "\n"; + *py_os << " def asnumpy(self):\n"; + *py_os << " smr = self.asmemref()\n"; + *py_os << " carray = self.asctypes()\n"; + *py_os << " # numpy ties lifespan of carray to numpy arrays created by frombuffer\n"; + *py_os << " obj = numpy.frombuffer(carray, dtype=self._ctype, offset=smr.offset * ctypes.sizeof(self._ctype))\n"; + *py_os << " ret = numpy.lib.stride_tricks.as_strided(\n"; + *py_os << " obj[smr.offset:],\n"; + *py_os << " shape=numpy.ctypeslib.as_array(smr.shape),\n"; + *py_os << " strides=numpy.ctypeslib.as_array(smr.strides) * obj.itemsize\n"; + *py_os << " )\n"; + *py_os << " return ret\n"; + *py_os << "\n"; + *py_os << " def __hash__(self):\n"; + *py_os << " return id(self)\n"; + *py_os << "\n"; + *py_os << " def _dealloc(self):\n"; + *py_os << " if self.wrapper_type == ParameterWrapperType.DUALVIEW_TYPE.value:\n"; + *py_os << " libHandle.freeDualView(ctypes.c_void_p(self.ptr))\n"; + *py_os << " self.wrapper_type = ParameterWrapperType.EMPTY_TYPE.value\n"; + *py_os << " self.ptr = ctypes.c_void_p()\n"; + *py_os << " self.rank = 0\n"; + *py_os << "\n"; *py_os << " def __del__(self):\n"; - *py_os << " self.libHandle.lapis_finalize()\n"; - //From here, only function wrappers are emitted. - //These are class members so indent all of them now. - py_os->indent(); + *py_os << " self._dealloc()\n"; + *py_os << "\n"; + *py_os << "def finalize():\n"; + *py_os << " for ref in ParameterWrapper._needs_dealloc:\n"; + *py_os << " ref._dealloc()\n"; + *py_os << " libHandle.lapis_finalize()\n"; + *py_os << "atexit.register(finalize)\n"; + *py_os << "\n"; + *py_os << "def wrap_array_parameter(param, dtype):\n"; + *py_os << " if isinstance(param, numpy.ndarray):\n"; + *py_os << " param = numpy.require(param, dtype=dtype, requirements=['C'])\n"; + *py_os << " ptr = ctypes.pointer(rt.get_ranked_memref_descriptor(param))\n"; + *py_os << " return ParameterWrapper.build(ParameterWrapperType.STRIDED_MEMREF_TYPE, ptr, dtype, param.ndim, base=param)\n"; + *py_os << " elif str(type(param)) == \"\": # Compare type string to avoid importing torch unnecessarily\n"; + *py_os << " if param.device.type == 'cpu':\n"; + *py_os << " return wrap_array_parameter(param.numpy(), dtype)\n"; + *py_os << " else:\n"; + *py_os << " # TODO: Do we want to allow direct references to toch managed GPU memory?\n"; + *py_os << " return wrap_array_parameter(param.cpu().numpy(), dtype)\n"; + *py_os << " else:\n"; + *py_os << " return param\n"; + *py_os << "\n"; + } LogicalResult KokkosCppEmitter::emitType(Location loc, Type type, bool forSparseRuntime) { @@ -4636,7 +4746,7 @@ LogicalResult kokkos::translateToKokkosCppTeamLevel(Operation *op, raw_ostream* if(failed(emitter.emitOperation(*op, /*trailingSemicolon=*/false))) return failure(); // Emit the init and finalize function definitions. - if (failed(emitter.emitInitAndFinalize())) + if (failed(emitter.emitInitAndFinalize(false))) return failure(); if(header_os) { *header_os << "#ifndef LAPIS_MODULE_H\n"; diff --git a/python/lapis/KokkosBackend.py b/python/lapis/KokkosBackend.py index add63511..0fbb7a52 100644 --- a/python/lapis/KokkosBackend.py +++ b/python/lapis/KokkosBackend.py @@ -61,10 +61,7 @@ def compile_kokkos_to_native(self, moduleRoot, linkSparseSupportLib): buildOut = subprocess.run(['make'], cwd=buildDir, shell=True) sys.path.insert(0, moduleRoot) lapis = __import__(self.package_name) - if os.path.isfile(buildDir + "/lib" + self.package_name + "_module.so"): - return lapis.LAPISModule(buildDir + "/lib" + self.package_name + "_module.so") - if os.path.isfile(buildDir + "/lib" + self.package_name + "_module.dylib"): - return lapis.LAPISModule(buildDir + "/lib" + self.package_name + "_module.dylib") + return lapis def run_cli(self, app, flags, stdin): appAbsolute = which(app)