Add FFT support via AbstractFFTs interface #713
Conversation
Codecov Report
❌ Patch coverage is
Additional details and impacted files:
@@ Coverage Diff @@
##             main     #713      +/-   ##
==========================================
+ Coverage   80.96%   81.96%   +1.00%
==========================================
  Files          62       63       +1
  Lines        2837     3072     +235
==========================================
+ Hits         2297     2518     +221
- Misses        540      554      +14
Metal Benchmarks
| Benchmark suite | Current: b88d77f | Previous: 239fa4d | Ratio |
|---|---|---|---|
| latency/precompile | 24243436166 ns | 24383716541 ns | 0.99 |
| latency/ttfp | 2293510291 ns | 2324081375 ns | 0.99 |
| latency/import | 1407606104 ns | 1427504083 ns | 0.99 |
| integration/metaldevrt | 835916 ns | 837292 ns | 1.00 |
| integration/byval/slices=1 | 1533708 ns | 1598354 ns | 0.96 |
| integration/byval/slices=3 | 8393604.5 ns | 19021791.5 ns | 0.44 |
| integration/byval/reference | 1533750 ns | 1590708.5 ns | 0.96 |
| integration/byval/slices=2 | 2594562 ns | 2727250 ns | 0.95 |
| kernel/indexing | 584959 ns | 459062.5 ns | 1.27 |
| kernel/indexing_checked | 589916 ns | 463104.5 ns | 1.27 |
| kernel/launch | 11542 ns | 11625 ns | 0.99 |
| kernel/rand | 563000 ns | 526667 ns | 1.07 |
| array/construct | 6000 ns | 5958 ns | 1.01 |
| array/broadcast | 592667 ns | 545375 ns | 1.09 |
| array/random/randn/Float32 | 746750 ns | 886167 ns | 0.84 |
| array/random/randn!/Float32 | 622625 ns | 578875 ns | 1.08 |
| array/random/rand!/Int64 | 556354.5 ns | 539083 ns | 1.03 |
| array/random/rand!/Float32 | 586312.5 ns | 533229.5 ns | 1.10 |
| array/random/rand/Int64 | 753687.5 ns | 887000 ns | 0.85 |
| array/random/rand/Float32 | 652083 ns | 840959 ns | 0.78 |
| array/accumulate/Int64/1d | 1270208 ns | 1292146 ns | 0.98 |
| array/accumulate/Int64/dims=1 | 1807458 ns | 1865375 ns | 0.97 |
| array/accumulate/Int64/dims=2 | 2136250.5 ns | 2215437 ns | 0.96 |
| array/accumulate/Int64/dims=1L | 11722000 ns | 12096125 ns | 0.97 |
| array/accumulate/Int64/dims=2L | 9775667 ns | 10003417 ns | 0.98 |
| array/accumulate/Float32/1d | 1100916.5 ns | 1086042 ns | 1.01 |
| array/accumulate/Float32/dims=1 | 1536875 ns | 1581542 ns | 0.97 |
| array/accumulate/Float32/dims=2 | 1841833 ns | 1998167 ns | 0.92 |
| array/accumulate/Float32/dims=1L | 9771875 ns | 10248396 ns | 0.95 |
| array/accumulate/Float32/dims=2L | 7215521 ns | 7422792 ns | 0.97 |
| array/reductions/reduce/Int64/1d | 1352583 ns | 1312917 ns | 1.03 |
| array/reductions/reduce/Int64/dims=1 | 1073459 ns | 1120125 ns | 0.96 |
| array/reductions/reduce/Int64/dims=2 | 1137625 ns | 1153917 ns | 0.99 |
| array/reductions/reduce/Int64/dims=1L | 2007917 ns | 2041417 ns | 0.98 |
| array/reductions/reduce/Int64/dims=2L | 4211687.5 ns | 3778125 ns | 1.11 |
| array/reductions/reduce/Float32/1d | 1016625 ns | 796167 ns | 1.28 |
| array/reductions/reduce/Float32/dims=1 | 811875 ns | 794000 ns | 1.02 |
| array/reductions/reduce/Float32/dims=2 | 832791 ns | 818562.5 ns | 1.02 |
| array/reductions/reduce/Float32/dims=1L | 1312916 ns | 1329000 ns | 0.99 |
| array/reductions/reduce/Float32/dims=2L | 1796875 ns | 1796708.5 ns | 1.00 |
| array/reductions/mapreduce/Int64/1d | 1546750 ns | 1298666 ns | 1.19 |
| array/reductions/mapreduce/Int64/dims=1 | 1082646 ns | 1086313 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=2 | 1238521 ns | 1122666 ns | 1.10 |
| array/reductions/mapreduce/Int64/dims=1L | 2000354.5 ns | 2025395.5 ns | 0.99 |
| array/reductions/mapreduce/Int64/dims=2L | 3611813 ns | 3647583 ns | 0.99 |
| array/reductions/mapreduce/Float32/1d | 1024583.5 ns | 774083.5 ns | 1.32 |
| array/reductions/mapreduce/Float32/dims=1 | 817750 ns | 791417 ns | 1.03 |
| array/reductions/mapreduce/Float32/dims=2 | 847791.5 ns | 826542 ns | 1.03 |
| array/reductions/mapreduce/Float32/dims=1L | 1306833 ns | 1322667 ns | 0.99 |
| array/reductions/mapreduce/Float32/dims=2L | 1799916 ns | 1817916.5 ns | 0.99 |
| array/private/copyto!/gpu_to_gpu | 644750 ns | 533917 ns | 1.21 |
| array/private/copyto!/cpu_to_gpu | 791291.5 ns | 690271 ns | 1.15 |
| array/private/copyto!/gpu_to_cpu | 785896 ns | 668542 ns | 1.18 |
| array/private/iteration/findall/int | 1614645.5 ns | 1565687.5 ns | 1.03 |
| array/private/iteration/findall/bool | 1423792 ns | 1465333.5 ns | 0.97 |
| array/private/iteration/findfirst/int | 2059208 ns | 2079042 ns | 0.99 |
| array/private/iteration/findfirst/bool | 2027895.5 ns | 2020083 ns | 1.00 |
| array/private/iteration/scalar | 5481104 ns | 2787125 ns | 1.97 |
| array/private/iteration/logical | 2536375 ns | 2599208 ns | 0.98 |
| array/private/iteration/findmin/1d | 2227167 ns | 2265458 ns | 0.98 |
| array/private/iteration/findmin/2d | 1503687 ns | 1528791 ns | 0.98 |
| array/private/copy | 589709 ns | 847041.5 ns | 0.70 |
| array/shared/copyto!/gpu_to_gpu | 83500 ns | 84333 ns | 0.99 |
| array/shared/copyto!/cpu_to_gpu | 81229.5 ns | 83042 ns | 0.98 |
| array/shared/copyto!/gpu_to_cpu | 83459 ns | 83479.5 ns | 1.00 |
| array/shared/iteration/findall/int | 1592166.5 ns | 1558208 ns | 1.02 |
| array/shared/iteration/findall/bool | 1446167 ns | 1470708 ns | 0.98 |
| array/shared/iteration/findfirst/int | 1650583 ns | 1682792 ns | 0.98 |
| array/shared/iteration/findfirst/bool | 1617771 ns | 1644334 ns | 0.98 |
| array/shared/iteration/scalar | 209708 ns | 202000 ns | 1.04 |
| array/shared/iteration/logical | 2463417 ns | 2368458 ns | 1.04 |
| array/shared/iteration/findmin/1d | 1813792 ns | 1845542 ns | 0.98 |
| array/shared/iteration/findmin/2d | 1505541 ns | 1521583 ns | 0.99 |
| array/shared/copy | 260542 ns | 210959 ns | 1.24 |
| array/permutedims/4d | 2364541 ns | 2473375 ns | 0.96 |
| array/permutedims/2d | 1142917 ns | 1178666.5 ns | 0.97 |
| array/permutedims/3d | 1658292 ns | 1780750 ns | 0.93 |
| metal/synchronization/stream | 19416 ns | 19334 ns | 1.00 |
| metal/synchronization/context | 20209 ns | 20000 ns | 1.01 |
This comment was automatically generated by workflow using github-action-benchmark.
Thank you for your PR, I know many people have been waiting for someone to get around to implementing this for a while! This is a great start! Before I fully dive in to review, here are a few changes I'd like you to make to simplify review (and future maintenance) of the code:
Feel free to ask if you have any questions.
Thanks for the feedback! I've addressed all the requested changes:
Ready for review!
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.
diff --git a/lib/mpsgraphs/operations.jl b/lib/mpsgraphs/operations.jl
index b6028253..4abf8a09 100644
--- a/lib/mpsgraphs/operations.jl
+++ b/lib/mpsgraphs/operations.jl
@@ -66,43 +66,53 @@ function identityWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "ide
end
function sliceTensor(graph::MPSGraph, tensor::MPSGraphTensor, dimension::Int, start::Int, length::Int, name = "slice")
- obj = @objc [graph::id{MPSGraph} sliceTensor:tensor::id{MPSGraphTensor}
- dimension:dimension::NSInteger
- start:start::NSInteger
- length:length::NSInteger
- name:name::id{NSString}]::id{MPSGraphTensor}
- MPSGraphTensor(obj)
+ obj = @objc [
+ graph::id{MPSGraph} sliceTensor:tensor::id{MPSGraphTensor}
+ dimension:dimension::NSInteger
+ start:start::NSInteger
+ length:length::NSInteger
+ name:name::id{NSString}
+ ]::id{MPSGraphTensor}
+ return MPSGraphTensor(obj)
end
function concatTensors(graph::MPSGraph, tensors::NSArray, dimension::Int, name = "concat")
- obj = @objc [graph::id{MPSGraph} concatTensors:tensors::id{NSArray}
- dimension:dimension::NSInteger
- name:name::id{NSString}]::id{MPSGraphTensor}
- MPSGraphTensor(obj)
+ obj = @objc [
+ graph::id{MPSGraph} concatTensors:tensors::id{NSArray}
+ dimension:dimension::NSInteger
+ name:name::id{NSString}
+ ]::id{MPSGraphTensor}
+ return MPSGraphTensor(obj)
end
function fastFourierTransformWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, axes::NSArray, descriptor::MPSGraphFFTDescriptor, name = "fft")
- obj = @objc [graph::id{MPSGraph} fastFourierTransformWithTensor:tensor::id{MPSGraphTensor}
- axes:axes::id{NSArray}
- descriptor:descriptor::id{MPSGraphFFTDescriptor}
- name:name::id{NSString}]::id{MPSGraphTensor}
- MPSGraphTensor(obj)
+ obj = @objc [
+ graph::id{MPSGraph} fastFourierTransformWithTensor:tensor::id{MPSGraphTensor}
+ axes:axes::id{NSArray}
+ descriptor:descriptor::id{MPSGraphFFTDescriptor}
+ name:name::id{NSString}
+ ]::id{MPSGraphTensor}
+ return MPSGraphTensor(obj)
end
function realToHermiteanFFTWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, axes::NSArray, descriptor::MPSGraphFFTDescriptor, name = "rfft")
- obj = @objc [graph::id{MPSGraph} realToHermiteanFFTWithTensor:tensor::id{MPSGraphTensor}
- axes:axes::id{NSArray}
- descriptor:descriptor::id{MPSGraphFFTDescriptor}
- name:name::id{NSString}]::id{MPSGraphTensor}
- MPSGraphTensor(obj)
+ obj = @objc [
+ graph::id{MPSGraph} realToHermiteanFFTWithTensor:tensor::id{MPSGraphTensor}
+ axes:axes::id{NSArray}
+ descriptor:descriptor::id{MPSGraphFFTDescriptor}
+ name:name::id{NSString}
+ ]::id{MPSGraphTensor}
+ return MPSGraphTensor(obj)
end
function HermiteanToRealFFTWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, axes::NSArray, descriptor::MPSGraphFFTDescriptor, name = "irfft")
- obj = @objc [graph::id{MPSGraph} HermiteanToRealFFTWithTensor:tensor::id{MPSGraphTensor}
- axes:axes::id{NSArray}
- descriptor:descriptor::id{MPSGraphFFTDescriptor}
- name:name::id{NSString}]::id{MPSGraphTensor}
- MPSGraphTensor(obj)
+ obj = @objc [
+ graph::id{MPSGraph} HermiteanToRealFFTWithTensor:tensor::id{MPSGraphTensor}
+ axes:axes::id{NSArray}
+ descriptor:descriptor::id{MPSGraphFFTDescriptor}
+ name:name::id{NSString}
+ ]::id{MPSGraphTensor}
+ return MPSGraphTensor(obj)
end
""" |
.gitignore (Outdated)
wip.*
dev
.vscode
FFT_REFACTORING_ROADMAP.md
Remove
Can you undo the unrelated formatting changes that were made to this file?
Implements GPU FFT operations for MtlArray using MPSGraph's fastFourierTransformWithTensor. This addresses issue JuliaGPU#270.
Features:
- plan_fft, plan_ifft, plan_bfft for ComplexF32 arrays
- Multi-dimensional FFT support (single axis, multiple axes, all axes)
- FFTW.jl-compatible API via AbstractFFTs.jl interface
- Plan execution via * operator and mul!
Implementation notes:
- Uses MPSGraphFFTDescriptor for forward/inverse control
- Scaling handled manually for multi-axis ifft (not via MPSGraph's scalingMode) to ensure correct normalization across all FFT dimensions
- Axis mapping accounts for Julia's column-major vs Metal's row-major ordering via shape reversal in placeholderTensor
Tested against FFTW.jl with <1e-4 relative tolerance for all operations.
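The manual scaling described in this commit follows the standard AbstractFFTs normalization, which can be checked on the CPU. A minimal sketch of that convention, not code from the PR:

```julia
using FFTW   # CPU reference implementation of fft/ifft/bfft

x = randn(ComplexF32, 8, 8)
dims = (1, 2)
scale = prod(size(x, d) for d in dims)   # 8 * 8 = 64 for a transform over both axes

# bfft is the unnormalized inverse; ifft divides by the product of the lengths of
# the transformed dimensions, which is the scaling the commit applies manually.
@assert ifft(x, dims) ≈ bfft(x, dims) ./ scale
```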
Implements real-to-complex and complex-to-real FFT operations using MPSGraph's realToHermiteanFFTWithTensor and HermiteanToRealFFTWithTensor.
Features:
- plan_rfft for Float32 arrays (output size n÷2+1 in first FFT dimension)
- plan_irfft for ComplexF32 arrays (normalized inverse)
- plan_brfft for ComplexF32 arrays (unnormalized inverse)
- Proper handling of odd output sizes via roundToOddHermitean
- FFTW-compatible dimension conventions
The output size reduction follows FFTW convention: the first transformed dimension is reduced to n÷2+1 for rfft, and irfft requires the original size to be specified.
Tested against FFTW.jl with <1e-4 relative tolerance.
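The output-size convention is the standard FFTW one; a quick CPU illustration (not code from this PR):

```julia
using FFTW

x = randn(Float32, 9, 4)      # odd length in the first transformed dimension
y = rfft(x)                   # first dimension reduced to 9 ÷ 2 + 1 == 5
@assert size(y) == (5, 4)
@assert irfft(y, 9) ≈ x       # irfft needs the original length (9) to reconstruct x
```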
Phase 4: Type generalization for FFT operations.
Changes:
- Add ComplexF16 support for fft/ifft/bfft
- Add Float16 support for rfft/irfft/brfft
- Improve error messages for unsupported types (ComplexF64/Float64)
- Add documentation about supported and unsupported types
- Add FFTComplexTypes and FFTRealTypes type unions
Note: ComplexF64/Float64 are NOT supported by Metal's MPSGraph FFT. Users requiring double precision should use FFTW.jl on the CPU. ComplexF16 results have ~3 decimal digits of precision (expected for Float16).
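Since double precision is unavailable in MPSGraph FFTs, one workaround is to downcast before moving data to the GPU. A hedged sketch, assuming the fft entry point this PR provides for MtlArray:

```julia
using Metal, AbstractFFTs

x64 = randn(ComplexF64, 256)
x32 = ComplexF32.(x64)        # downcast on the CPU first
y = fft(MtlArray(x32))        # single-precision FFT on the GPU
# fft(MtlArray(x64)) would hit the unsupported-type error mentioned above
```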
Phase 5: Verify and test 1D FFT support.
The existing implementation already handles 1D arrays correctly since the multi-dimensional FFT code works for arbitrary dimensions.
Added tests:
- 1D fft correctness vs FFTW
- 1D ifft roundtrip
- 1D rfft correctness vs FFTW
- 1D rfft -> irfft roundtrip
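A sketch of the kind of 1D check described here, comparing GPU results against FFTW on the CPU (the tolerance is an assumption, not necessarily the PR's exact value):

```julia
using Metal, FFTW, Test

x = rand(ComplexF32, 256)
@test Array(fft(MtlArray(x))) ≈ fft(x) rtol = 1e-4            # 1D fft vs FFTW
@test Array(ifft(fft(MtlArray(x)))) ≈ x rtol = 1e-4           # ifft roundtrip

xr = rand(Float32, 256)
@test Array(rfft(MtlArray(xr))) ≈ rfft(xr) rtol = 1e-4        # 1D rfft vs FFTW
@test Array(irfft(rfft(MtlArray(xr)), 256)) ≈ xr rtol = 1e-4  # rfft -> irfft roundtrip
```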
Phase 6: Implement in-place FFT operations.
New types:
- MtlFFTInplacePlan{T,K,N} - plan that modifies input directly
New functions:
- plan_fft!(x, [region]) - create in-place forward FFT plan
- plan_ifft!(x, [region]) - create in-place inverse FFT plan
- plan_bfft!(x, [region]) - create in-place backward FFT plan
The in-place variants are useful for avoiding memory allocation when
the input data is no longer needed after the transform.
Usage:
x = MtlArray(randn(ComplexF32, 64, 64))
plan = plan_fft!(x)
plan * x # x is modified in-place, returns x
Add shift keyword argument to FFT plans and convenience functions that fuses fftshift/ifftshift into the GPU graph execution.
- plan_fft(x; shift=true) fuses fftshift after forward FFT
- plan_ifft(x; shift=true) fuses ifftshift before inverse FFT
- Convenience functions in the Metal.MPSGraphs namespace support the shift kwarg
- Implemented via slice+concat in MPSGraph (no intermediate memory ops)
- Handles both even and odd array sizes correctly
Example usage:
# Via plans
p = plan_fft(x_gpu; shift=true)
result = p * x_gpu
# Via convenience functions
using Metal.MPSGraphs: fft, ifft
result = fft(x_gpu; shift=true)
Allow fft(x; shift=true) syntax by extending AbstractFFTs.fft with a shift keyword argument for MtlArray inputs. This provides a cleaner API without requiring separate imports.
- Add sliceTensor, concatTensors to operations.jl
- Add fastFourierTransformWithTensor to operations.jl
- Add realToHermiteanFFTWithTensor, HermiteanToRealFFTWithTensor to operations.jl
- Remove duplicate wrapper functions from fft.jl
- Update _apply_fftshift_to_tensor to use operations.jl functions
- Remove shift parameter from all plan types and functions
- Remove _apply_fftshift_to_tensor helper
- Simplify _execute_single_axis_fft!
- Remove fftshift tests (14 tests removed)
- Aligns API with CUDA.jl, which has no built-in fftshift
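With the shift keyword gone, centering the spectrum works the same way as with CUDA.jl: compose with AbstractFFTs' fftshift/ifftshift. A sketch, assuming fftshift falls back to the generic GPU array path:

```julia
using Metal, AbstractFFTs

x = MtlArray(randn(ComplexF32, 64, 64))
centered = fftshift(fft(x))            # forward FFT, then move DC to the center
recovered = ifft(ifftshift(centered))  # undo the shift before the inverse transform
```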
- Add Float16 CPU shims for FFTW reference comparisons
- Add tolerance functions based on type precision
- Restructure tests with reusable test functions
- Add @inferred checks for plan creation
- Add comprehensive batched transform tests (1D, 2D, 3D)
- Test both ComplexF16/ComplexF32 and Float16/Float32
- Tests increased from 55 to 117
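The tolerance helpers themselves are not shown in this thread; a plausible sketch of precision-dependent tolerances (the values are assumptions, the PR's actual numbers may differ):

```julia
# Hypothetical helpers keyed on element type for choosing a relative tolerance.
test_rtol(::Type{Float32}) = 1.0f-4
test_rtol(::Type{Float16}) = 1.0f-2    # Float16 carries only ~3 decimal digits
test_rtol(::Type{Complex{T}}) where {T} = test_rtol(T)

test_rtol(ComplexF16)   # returns 0.01f0
```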
…ormatting in operations.jl
40871e3 to b88d77f (Compare)
Thanks for the review! I've addressed both points:
Summary
Add FFT support for MtlArray via the AbstractFFTs.jl interface. This closes #270.
Finally, Metal.jl users can run FFTs on the GPU with the same familiar API they use with FFTW (see Example Usage below).
Performance
Benchmarked on Apple M2 Max against FFTW.jl on the CPU. For large arrays, Metal FFT is up to 42× faster than the CPU.
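The timing table from the PR description is not reproduced here; such a comparison could be run along these lines (a sketch; Metal.@sync for GPU synchronization is assumed, and the 42× figure is the author's measurement, not this snippet's output):

```julia
using Metal, FFTW, BenchmarkTools

x_cpu = randn(ComplexF32, 4096, 4096)
x_gpu = MtlArray(x_cpu)

@btime fft($x_cpu)                # CPU path via FFTW
@btime Metal.@sync fft($x_gpu)    # GPU path via MPSGraph, waiting for completion
```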
Features
Complex FFT:
- fft, ifft, bfft — forward, inverse, and unnormalized inverse
- fft!, ifft!, bfft! — in-place variants
Real FFT:
- rfft, irfft, brfft — real-to-complex and complex-to-real transforms
Full AbstractFFTs.jl compatibility:
- plan_fft, plan_ifft, plan_rfft, etc.
- mul! interface
Supported types:
- ComplexF32, ComplexF16
- Float32, Float16
Implementation
Built on MPSGraph's fastFourierTransformWithTensor, realToHermiteanFFTWithTensor, and HermiteanToRealFFTWithTensor.
Code structure follows CUDA.jl's cufft pattern for maintainability:
- @objc wrappers in operations.jl
- fft.jl
Limitations
- ComplexF64/Float64 not supported — this is a Metal hardware limitation, not a software one. Use FFTW.jl on the CPU for double precision.
Example Usage
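A minimal usage sketch of the API listed above; the calls mirror FFTW's, and the array sizes are arbitrary:

```julia
using Metal, AbstractFFTs

x = MtlArray(rand(ComplexF32, 1024, 1024))

# Direct transforms
y = fft(x)         # forward FFT on the GPU
z = ifft(y)        # normalized inverse; z ≈ x

# Precomputed plans, FFTW-style
p = plan_fft(x)
y2 = p * x

# Real-to-complex
xr = MtlArray(rand(Float32, 1024, 1024))
pr = plan_rfft(xr)
yr = pr * xr       # size (513, 1024): first dimension reduced to n ÷ 2 + 1
```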