From f5141c4b1b94c260d886ce3957e8985b9f69222f Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Wed, 29 Jan 2025 00:28:48 -0400 Subject: [PATCH 1/8] Add `nextafter` intrinsic --- src/device/intrinsics/math.jl | 5 +++++ test/device/intrinsics.jl | 20 +++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index b33942390..e7ccd20b8 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -294,6 +294,11 @@ end @device_override Base.trunc(x::Float32) = ccall("extern air.trunc.f32", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.trunc(x::Float16) = ccall("extern air.trunc.f16", llvmcall, Float16, (Float16,), x) +@static if Metal.is_macos(v"14") + @device_function nextafter(x::Float32, y::Float32) = ccall("extern air.nextafter.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) + @device_function nextafter(x::Float16, y::Float16) = ccall("extern air.nextafter.f16", llvmcall, Float16, (Float16, Float16), x, y) +end + # hypot without use of double # # taken from Cosmopolitan Libc diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl index b71f62b14..d1bf45288 100644 --- a/test/device/intrinsics.jl +++ b/test/device/intrinsics.jl @@ -159,7 +159,6 @@ MATH_INTR_FUNCS_2_ARG = [ # frexp, # T frexp(T x, Ti &exponent) # ldexp, # T ldexp(T x, Ti k) # modf, # T modf(T x, T &intval) - # nextafter, # T nextafter(T x, T y) # Metal 3.1+ hypot, # NOT MSL but tested the same ] @@ -353,6 +352,25 @@ end vec = Array(expm1.(buffer)) @test vec ≈ expm1.(arr) end + + + let # nextafter + if Metal.is_macos(v"14") + N = 4 + function nextafter_test(X, y) + idx = thread_position_in_grid_1d() + X[idx] = Metal.nextafter(X[idx], y) + return nothing + end + arr = rand(T, N) + buffer = MtlArray(arr) + Metal.@sync @metal threads = N nextafter_test(buffer, typemax(T)) + @test Array(buffer) == nextfloat.(arr) + + Metal.@sync @metal threads = N nextafter_test(buffer, typemin(T)) + @test Array(buffer) == arr + end + end end end From ef4370710325474706ff6f622f91ecd8b5d217a8 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 6 Feb 2025 21:43:05 -0400 Subject: [PATCH 2/8] Better version? --- src/Metal.jl | 2 +- src/device/intrinsics/math.jl | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/Metal.jl b/src/Metal.jl index 7eb057e4f..28b38c559 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -27,12 +27,12 @@ include("device/utils.jl") include("device/pointer.jl") include("device/array.jl") include("device/runtime.jl") +include("device/intrinsics/version.jl") include("device/intrinsics/arguments.jl") include("device/intrinsics/math.jl") include("device/intrinsics/synchronization.jl") include("device/intrinsics/memory.jl") include("device/intrinsics/simd.jl") -include("device/intrinsics/version.jl") include("device/intrinsics/atomics.jl") include("device/quirks.jl") diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index e7ccd20b8..35009c382 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -294,9 +294,19 @@ end @device_override Base.trunc(x::Float32) = ccall("extern air.trunc.f32", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.trunc(x::Float16) = ccall("extern air.trunc.f16", llvmcall, Float16, (Float16,), x) -@static if Metal.is_macos(v"14") - @device_function nextafter(x::Float32, y::Float32) = ccall("extern air.nextafter.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) - @device_function nextafter(x::Float16, y::Float16) = ccall("extern air.nextafter.f16", llvmcall, Float16, (Float16, Float16), x, y) +@device_function function nextafter(x::Float32, y::Float32) + if metal_version() >= sv"3.1" # macOS 14+ + ccall("extern air.nextafter.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) + else + error() + end +end +@device_function function nextafter(x::Float16, y::Float16) + if metal_version() >= sv"3.1" # macOS 14+ + ccall("extern air.nextafter.f16", llvmcall, Float16, (Float16, Float16), x, y) + else + error() + end end # hypot without use of double From d1a8729700152def12cd41c2668d86ef44b8bfea Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 6 Feb 2025 22:32:13 -0400 Subject: [PATCH 3/8] Better version?? --- src/device/intrinsics/math.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index 35009c382..5079d7eb0 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -298,14 +298,14 @@ end if metal_version() >= sv"3.1" # macOS 14+ ccall("extern air.nextafter.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) else - error() + reinterpret(Float32, reinterpret(UInt32, x) + sign(y-x)) end end @device_function function nextafter(x::Float16, y::Float16) if metal_version() >= sv"3.1" # macOS 14+ ccall("extern air.nextafter.f16", llvmcall, Float16, (Float16, Float16), x, y) else - error() + reinterpret(Float16, reinterpret(UInt16, x) + sign(y-x)) end end From 78947d9877621988639d378392f0d67653fbbb91 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Fri, 7 Feb 2025 10:19:20 -0400 Subject: [PATCH 4/8] Working code --- src/device/intrinsics/math.jl | 4 ++-- test/device/intrinsics.jl | 26 ++++++++++++-------------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index 5079d7eb0..19a49dfc7 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -298,14 +298,14 @@ end if metal_version() >= sv"3.1" # macOS 14+ ccall("extern air.nextafter.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) else - reinterpret(Float32, reinterpret(UInt32, x) + sign(y-x)) + nextfloat(x, unsafe_trunc(Int32, sign(y - x))) end end @device_function function nextafter(x::Float16, y::Float16) if metal_version() >= sv"3.1" # macOS 14+ ccall("extern air.nextafter.f16", llvmcall, Float16, (Float16, Float16), x, y) else - reinterpret(Float16, reinterpret(UInt16, x) + sign(y-x)) + nextfloat(x, unsafe_trunc(Int16, sign(y - x))) end end diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl index d1bf45288..7a1af6497 100644 --- a/test/device/intrinsics.jl +++ b/test/device/intrinsics.jl @@ -355,21 +355,19 @@ end let # nextafter - if Metal.is_macos(v"14") - N = 4 - function nextafter_test(X, y) - idx = thread_position_in_grid_1d() - X[idx] = Metal.nextafter(X[idx], y) - return nothing - end - arr = rand(T, N) - buffer = MtlArray(arr) - Metal.@sync @metal threads = N nextafter_test(buffer, typemax(T)) - @test Array(buffer) == nextfloat.(arr) - - Metal.@sync @metal threads = N nextafter_test(buffer, typemin(T)) - @test Array(buffer) == arr + N = 4 + function nextafter_test(X, y) + idx = thread_position_in_grid_1d() + X[idx] = Metal.nextafter(X[idx], y) + return nothing end + arr = rand(T, N) + buffer = MtlArray(arr) + Metal.@sync @metal threads = N nextafter_test(buffer, typemax(T)) + @test Array(buffer) == nextfloat.(arr) + + Metal.@sync @metal threads = N nextafter_test(buffer, typemin(T)) + @test Array(buffer) == arr end end end From f4fb3763c69a7e2a1c6193ed7364753b1617105a Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Sat, 8 Feb 2025 12:28:48 -0400 Subject: [PATCH 5/8] Sneak in some typo fixes --- src/compiler/execution.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl index 58b2c7f29..54a089a3a 100644 --- a/src/compiler/execution.jl +++ b/src/compiler/execution.jl @@ -18,9 +18,9 @@ certain extent arguments will be converted and managed automatically using `mtlc Finally, a call to `mtlcall` is performed, creating a command buffer in the current global command queue then committing it. -There is one supported keyword argument that influences the behavior of `@metal`: +There are a few keyword arguments that influence the behavior of `@metal`: -- `launch`: whether to launch this kernel, defaults to `true`. If `false` the returned +- `launch`: whether to launch this kernel, defaults to `true`. If `false`, the returned kernel object should be launched by calling it and passing arguments again. - `name`: the name of the kernel in the generated code. Defaults to an automatically- generated name. From 535400331716906f79849a51eb8fe255e700ee6e Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Sat, 8 Feb 2025 12:29:54 -0400 Subject: [PATCH 6/8] Test both versions of intrinsic --- test/device/intrinsics.jl | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl index 7a1af6497..843dc826f 100644 --- a/test/device/intrinsics.jl +++ b/test/device/intrinsics.jl @@ -362,12 +362,20 @@ end return nothing end arr = rand(T, N) - buffer = MtlArray(arr) - Metal.@sync @metal threads = N nextafter_test(buffer, typemax(T)) - @test Array(buffer) == nextfloat.(arr) - Metal.@sync @metal threads = N nextafter_test(buffer, typemin(T)) - @test Array(buffer) == arr + # test the intrinsic + buffer1 = MtlArray(arr) + Metal.@sync @metal threads = N nextafter_test(buffer1, typemax(T)) + @test Array(buffer1) == nextfloat.(arr) + Metal.@sync @metal threads = N nextafter_test(buffer1, typemin(T)) + @test Array(buffer1) == arr + + # test for metal < 3.1 + buffer2 = MtlArray(arr) + Metal.@sync @metal threads = N metal = v"3.0" nextafter_test(buffer2, typemax(T)) + @test Array(buffer2) == nextfloat.(arr) + Metal.@sync @metal threads = N metal = v"3.0" nextafter_test(buffer2, typemin(T)) + @test Array(buffer2) == arr end end end From f05282b64b8bdf84be441fdf36a80f4117797c62 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Mon, 10 Feb 2025 16:08:58 -0400 Subject: [PATCH 7/8] Add tests --- test/device/intrinsics.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl index 843dc826f..98b22149f 100644 --- a/test/device/intrinsics.jl +++ b/test/device/intrinsics.jl @@ -376,6 +376,17 @@ end @test Array(buffer2) == nextfloat.(arr) Metal.@sync @metal threads = N metal = v"3.0" nextafter_test(buffer2, typemin(T)) @test Array(buffer2) == arr + + # Check the code is generated as expected + outval = T(0) + function nextafter_out_test() + Metal.nextafter(outval, outval) + return + end + ir = sprint(io->(@device_code_llvm io=io dump_module=true @metal nextafter_out_test())) + @test occursin(Regex("@air\\.nextafter\\.f$(8*sizeof(T))"), ir) + ir = sprint(io->(@device_code_llvm io=io dump_module=true @metal metal = v"3.0" nextafter_out_test())) + @test occursin(Regex("@air\\.sign\\.f$(8*sizeof(T))"), ir) end end end From 78ba579c71d059ba6e4e0fb8f9c7734a8b56c93d Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Mon, 10 Feb 2025 21:27:12 -0400 Subject: [PATCH 8/8] Only test nextafter intrinsic when available. Tests should still pass when run on macOS 13 --- test/device/intrinsics.jl | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl index 98b22149f..ecd5e44d5 100644 --- a/test/device/intrinsics.jl +++ b/test/device/intrinsics.jl @@ -355,20 +355,33 @@ end let # nextafter - N = 4 function nextafter_test(X, y) idx = thread_position_in_grid_1d() X[idx] = Metal.nextafter(X[idx], y) return nothing end + + # Check the code is generated as expected + outval = T(0) + function nextafter_out_test() + Metal.nextafter(outval, outval) + return + end + + N = 4 arr = rand(T, N) - # test the intrinsic - buffer1 = MtlArray(arr) - Metal.@sync @metal threads = N nextafter_test(buffer1, typemax(T)) - @test Array(buffer1) == nextfloat.(arr) - Metal.@sync @metal threads = N nextafter_test(buffer1, typemin(T)) - @test Array(buffer1) == arr + # test the intrinsic (macOS >= v14) + if metal_support() >= v"3.1" + buffer1 = MtlArray(arr) + Metal.@sync @metal threads = N nextafter_test(buffer1, typemax(T)) + @test Array(buffer1) == nextfloat.(arr) + Metal.@sync @metal threads = N nextafter_test(buffer1, typemin(T)) + @test Array(buffer1) == arr + + ir = sprint(io->(@device_code_llvm io=io dump_module=true @metal nextafter_out_test())) + @test occursin(Regex("@air\\.nextafter\\.f$(8*sizeof(T))"), ir) + end # test for metal < 3.1 buffer2 = MtlArray(arr) @@ -377,14 +390,6 @@ end Metal.@sync @metal threads = N metal = v"3.0" nextafter_test(buffer2, typemin(T)) @test Array(buffer2) == arr - # Check the code is generated as expected - outval = T(0) - function nextafter_out_test() - Metal.nextafter(outval, outval) - return - end - ir = sprint(io->(@device_code_llvm io=io dump_module=true @metal nextafter_out_test())) - @test occursin(Regex("@air\\.nextafter\\.f$(8*sizeof(T))"), ir) ir = sprint(io->(@device_code_llvm io=io dump_module=true @metal metal = v"3.0" nextafter_out_test())) @test occursin(Regex("@air\\.sign\\.f$(8*sizeof(T))"), ir) end