From 23a6c82bed9ea4821d68bf6dec800d1ae4a5186c Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Tue, 9 Sep 2025 15:50:31 +0000 Subject: [PATCH 1/2] oneAPI-aware MPI --- Project.toml | 4 ++++ ext/OneAPIExt.jl | 27 +++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) create mode 100644 ext/OneAPIExt.jl diff --git a/Project.toml b/Project.toml index 7e3d66252..ac4123773 100644 --- a/Project.toml +++ b/Project.toml @@ -34,15 +34,19 @@ Requires = "~0.5, 1.0" Serialization = "1" Sockets = "1" julia = "1.6" +oneAPI = "2.1" [extensions] AMDGPUExt = "AMDGPU" CUDAExt = "CUDA" +OneAPIExt = "oneAPI" [extras] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" diff --git a/ext/OneAPIExt.jl b/ext/OneAPIExt.jl new file mode 100644 index 000000000..6891aaf43 --- /dev/null +++ b/ext/OneAPIExt.jl @@ -0,0 +1,27 @@ +module OneAPIExt + +import MPI +isdefined(Base, :get_extension) ? (import oneAPI) : (import ..oneAPI) +import MPI: MPIPtr, Buffer, Datatype + +function Base.cconvert(::Type{MPIPtr}, A::oneAPI.oneArray{T}) where T + A +end + +function Base.unsafe_convert(::Type{MPIPtr}, X::oneAPI.oneArray{T}) where T + reinterpret(MPIPtr, Base.unsafe_convert(oneAPI.ZePtr{T}, X)) +end + +# only need to define this for strided arrays: all others can be handled by generic machinery +function Base.unsafe_convert(::Type{MPIPtr}, V::SubArray{T,N,P,I,true}) where {T,N,P<:oneAPI.oneArray,I} + X = parent(V) + pX = Base.unsafe_convert(oneAPI.ZePtr{T}, X) + pV = pX + ((V.offset1 + V.stride1) - first(LinearIndices(X)))*sizeof(T) + return reinterpret(MPIPtr, pV) +end + +function Buffer(arr::oneAPI.oneArray) + Buffer(arr, Cint(length(arr)), Datatype(eltype(arr))) +end + +end # OneAPIExt From 477e30d05e53cf65df5802529b6d8ecf53679e1c Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Thu, 4 Dec 2025 11:01:20 -0600 Subject: [PATCH 2/2] Add has_oneapi() --- src/environment.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/environment.jl b/src/environment.jl index 70502173b..ea3be9005 100644 --- a/src/environment.jl +++ b/src/environment.jl @@ -374,6 +374,25 @@ function has_rocm() end end +""" + MPI.has_oneapi() + +Check if the MPI implementation is known to have oneAPI support. + +This can be overridden by setting the `JULIA_MPI_HAS_ONEAPI` environment variable to `true` +or `false`. + +See also [`MPI.has_cuda`](@ref) and [`MPI.has_rocm`](@ref) for CUDA and ROCm support. +""" +function has_oneapi() + flag = get(ENV, "JULIA_MPI_HAS_ONEAPI", nothing) + if flag === nothing + return false + else + return parse(Bool, flag) + end +end + """ MPI.has_gpu()