Conversation

@KaanKesginLW (Contributor) commented Dec 3, 2025

Summary

Add FFT support for MtlArray via the AbstractFFTs.jl interface. This closes #270.

Finally, Metal.jl users can run FFTs on the GPU with the same familiar API they use with FFTW:

using Metal, AbstractFFTs

x = MtlArray(randn(ComplexF32, 2048, 2048))
y = fft(x)  # Just works!

Performance

Benchmarked on Apple M2 Max against FFTW.jl on CPU:

Size        CPU (FFTW)   GPU (Metal)   Speedup
512×512     4.3ms        4.2ms         1.0×
1024×1024   20ms         5ms           4×
2048×2048   102ms        6.7ms         15×
4096×4096   455ms        11ms          42×

For large arrays, Metal FFT is up to 42× faster than CPU.
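
The PR does not include the benchmark script; a rough way to reproduce numbers like these is sketched below. It assumes Metal.@sync blocks until the GPU has finished, analogous to CUDA.@sync.

using Metal, FFTW, BenchmarkTools

n = 2048
x_cpu = randn(ComplexF32, n, n)
x_gpu = MtlArray(x_cpu)

@btime fft($x_cpu);              # CPU baseline via FFTW
@btime Metal.@sync fft($x_gpu);  # GPU; synchronize so the whole transform is timed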

Features

Complex FFT:

  • fft, ifft, bfft — forward, inverse, and unnormalized inverse
  • fft!, ifft!, bfft! — in-place variants
  • Works on 1D, 2D, 3D, and N-D arrays

Real FFT:

  • rfft, irfft, brfft — real-to-complex and complex-to-real transforms
  • Proper handling of odd output sizes

Full AbstractFFTs.jl compatibility:

  • plan_fft, plan_ifft, plan_rfft, etc.
  • mul! interface
  • Transform along any dimension or subset of dimensions
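
As an illustration of the plan, mul!, and dims items above, a preallocated-output workflow might look like this (a sketch of the generic AbstractFFTs pattern, not code from this PR):

using Metal, AbstractFFTs, LinearAlgebra

x = MtlArray(randn(ComplexF32, 1024, 1024))
p = plan_fft(x, 2)   # plan a transform along the second dimension only
y = similar(x)       # preallocate the output buffer once
mul!(y, p, x)        # write the transform of x into y without allocating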

Supported types:

  • ComplexF32, ComplexF16
  • Float32, Float16

Implementation

Built on MPSGraph's fastFourierTransformWithTensor, realToHermiteanFFTWithTensor, and HermiteanToRealFFTWithTensor.

Code structure follows CUDA.jl's cufft pattern for maintainability:

  • Low-level @objc wrappers in operations.jl
  • High-level AbstractFFTs interface in fft.jl
  • Test suite adapted from CUDA.jl (117 tests)

Limitations

ComplexF64/Float64 not supported — this is a Metal hardware limitation, not a software one. Use FFTW.jl on CPU for double precision.
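
If double precision is required, one possible workaround is to move the data to the host and widen it there, for example:

using Metal, FFTW

x = MtlArray(randn(Float32, 4096))
y = fft(Float64.(Array(x)))   # copy to the CPU, widen to Float64, let FFTW handle it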

Example Usage

using Metal
using AbstractFFTs: fft, ifft, rfft, irfft, plan_fft

# Complex FFT
x = MtlArray(randn(ComplexF32, 1024, 1024))
y = fft(x)
z = ifft(y)  # z ≈ x

# Real FFT  
r = MtlArray(randn(Float32, 1024, 1024))
c = rfft(r)           # Real → Complex
r2 = irfft(c, 1024)   # Complex → Real, r2 ≈ r

# FFT along specific dimensions
y = fft(x, 1)         # First dimension only
y = fft(x, (1, 2))    # Batched transform

# Plan reuse
p = plan_fft(x)
y1 = p * x
y2 = p * another_x    # Same plan, different data

codecov bot commented Dec 3, 2025

Codecov Report

❌ Patch coverage is 94.04255% with 14 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.96%. Comparing base (239fa4d) to head (b88d77f).

Files with missing lines       Patch %   Lines
lib/mpsgraphs/fft.jl           96.36%    8 Missing ⚠️
lib/mpsgraphs/operations.jl    60.00%    6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #713      +/-   ##
==========================================
+ Coverage   80.96%   81.96%   +1.00%     
==========================================
  Files          62       63       +1     
  Lines        2837     3072     +235     
==========================================
+ Hits         2297     2518     +221     
- Misses        540      554      +14     


@github-actions bot left a comment

Metal Benchmarks

Details
Benchmark suite Current: b88d77f Previous: 239fa4d Ratio
latency/precompile 24243436166 ns 24383716541 ns 0.99
latency/ttfp 2293510291 ns 2324081375 ns 0.99
latency/import 1407606104 ns 1427504083 ns 0.99
integration/metaldevrt 835916 ns 837292 ns 1.00
integration/byval/slices=1 1533708 ns 1598354 ns 0.96
integration/byval/slices=3 8393604.5 ns 19021791.5 ns 0.44
integration/byval/reference 1533750 ns 1590708.5 ns 0.96
integration/byval/slices=2 2594562 ns 2727250 ns 0.95
kernel/indexing 584959 ns 459062.5 ns 1.27
kernel/indexing_checked 589916 ns 463104.5 ns 1.27
kernel/launch 11542 ns 11625 ns 0.99
kernel/rand 563000 ns 526667 ns 1.07
array/construct 6000 ns 5958 ns 1.01
array/broadcast 592667 ns 545375 ns 1.09
array/random/randn/Float32 746750 ns 886167 ns 0.84
array/random/randn!/Float32 622625 ns 578875 ns 1.08
array/random/rand!/Int64 556354.5 ns 539083 ns 1.03
array/random/rand!/Float32 586312.5 ns 533229.5 ns 1.10
array/random/rand/Int64 753687.5 ns 887000 ns 0.85
array/random/rand/Float32 652083 ns 840959 ns 0.78
array/accumulate/Int64/1d 1270208 ns 1292146 ns 0.98
array/accumulate/Int64/dims=1 1807458 ns 1865375 ns 0.97
array/accumulate/Int64/dims=2 2136250.5 ns 2215437 ns 0.96
array/accumulate/Int64/dims=1L 11722000 ns 12096125 ns 0.97
array/accumulate/Int64/dims=2L 9775667 ns 10003417 ns 0.98
array/accumulate/Float32/1d 1100916.5 ns 1086042 ns 1.01
array/accumulate/Float32/dims=1 1536875 ns 1581542 ns 0.97
array/accumulate/Float32/dims=2 1841833 ns 1998167 ns 0.92
array/accumulate/Float32/dims=1L 9771875 ns 10248396 ns 0.95
array/accumulate/Float32/dims=2L 7215521 ns 7422792 ns 0.97
array/reductions/reduce/Int64/1d 1352583 ns 1312917 ns 1.03
array/reductions/reduce/Int64/dims=1 1073459 ns 1120125 ns 0.96
array/reductions/reduce/Int64/dims=2 1137625 ns 1153917 ns 0.99
array/reductions/reduce/Int64/dims=1L 2007917 ns 2041417 ns 0.98
array/reductions/reduce/Int64/dims=2L 4211687.5 ns 3778125 ns 1.11
array/reductions/reduce/Float32/1d 1016625 ns 796167 ns 1.28
array/reductions/reduce/Float32/dims=1 811875 ns 794000 ns 1.02
array/reductions/reduce/Float32/dims=2 832791 ns 818562.5 ns 1.02
array/reductions/reduce/Float32/dims=1L 1312916 ns 1329000 ns 0.99
array/reductions/reduce/Float32/dims=2L 1796875 ns 1796708.5 ns 1.00
array/reductions/mapreduce/Int64/1d 1546750 ns 1298666 ns 1.19
array/reductions/mapreduce/Int64/dims=1 1082646 ns 1086313 ns 1.00
array/reductions/mapreduce/Int64/dims=2 1238521 ns 1122666 ns 1.10
array/reductions/mapreduce/Int64/dims=1L 2000354.5 ns 2025395.5 ns 0.99
array/reductions/mapreduce/Int64/dims=2L 3611813 ns 3647583 ns 0.99
array/reductions/mapreduce/Float32/1d 1024583.5 ns 774083.5 ns 1.32
array/reductions/mapreduce/Float32/dims=1 817750 ns 791417 ns 1.03
array/reductions/mapreduce/Float32/dims=2 847791.5 ns 826542 ns 1.03
array/reductions/mapreduce/Float32/dims=1L 1306833 ns 1322667 ns 0.99
array/reductions/mapreduce/Float32/dims=2L 1799916 ns 1817916.5 ns 0.99
array/private/copyto!/gpu_to_gpu 644750 ns 533917 ns 1.21
array/private/copyto!/cpu_to_gpu 791291.5 ns 690271 ns 1.15
array/private/copyto!/gpu_to_cpu 785896 ns 668542 ns 1.18
array/private/iteration/findall/int 1614645.5 ns 1565687.5 ns 1.03
array/private/iteration/findall/bool 1423792 ns 1465333.5 ns 0.97
array/private/iteration/findfirst/int 2059208 ns 2079042 ns 0.99
array/private/iteration/findfirst/bool 2027895.5 ns 2020083 ns 1.00
array/private/iteration/scalar 5481104 ns 2787125 ns 1.97
array/private/iteration/logical 2536375 ns 2599208 ns 0.98
array/private/iteration/findmin/1d 2227167 ns 2265458 ns 0.98
array/private/iteration/findmin/2d 1503687 ns 1528791 ns 0.98
array/private/copy 589709 ns 847041.5 ns 0.70
array/shared/copyto!/gpu_to_gpu 83500 ns 84333 ns 0.99
array/shared/copyto!/cpu_to_gpu 81229.5 ns 83042 ns 0.98
array/shared/copyto!/gpu_to_cpu 83459 ns 83479.5 ns 1.00
array/shared/iteration/findall/int 1592166.5 ns 1558208 ns 1.02
array/shared/iteration/findall/bool 1446167 ns 1470708 ns 0.98
array/shared/iteration/findfirst/int 1650583 ns 1682792 ns 0.98
array/shared/iteration/findfirst/bool 1617771 ns 1644334 ns 0.98
array/shared/iteration/scalar 209708 ns 202000 ns 1.04
array/shared/iteration/logical 2463417 ns 2368458 ns 1.04
array/shared/iteration/findmin/1d 1813792 ns 1845542 ns 0.98
array/shared/iteration/findmin/2d 1505541 ns 1521583 ns 0.99
array/shared/copy 260542 ns 210959 ns 1.24
array/permutedims/4d 2364541 ns 2473375 ns 0.96
array/permutedims/2d 1142917 ns 1178666.5 ns 0.97
array/permutedims/3d 1658292 ns 1780750 ns 0.93
metal/synchronization/stream 19416 ns 19334 ns 1.00
metal/synchronization/context 20209 ns 20000 ns 1.01

This comment was automatically generated by workflow using github-action-benchmark.

@christiangnrd (Member) commented:

Thank you for your PR, I know many people have been waiting for someone to get around to implementing this for a while!

This is a great start! Before I fully dive into the review, here are a few changes I'd like you to make to simplify review (and future maintenance) of the code:

  • All @objc calls should be wrapped so that everything else can be done with Julia code. Graph operations can go in lib/mpsgraphs/operations.jl, and constructors can go at the top of the fft.jl file.
  • If you could structure the implementation more closely after CUDA.jl's implementation, it would make it quicker for me to review.
  • Similarly to the last point, for tests, please adapt the CUDA fft tests here, so that we can take advantage of their regression tests. Feel free to add tests from the current suite that the CUDA tests don't cover.

Feel free to ask if you have any questions.

@KaanKesginLW (Contributor, Author) commented:

Thanks for the feedback! I've addressed all the requested changes:

  1. Wrapped @objc calls — Graph operations (fastFourierTransformWithTensor, realToHermiteanFFTWithTensor, HermiteanToRealFFTWithTensor, sliceTensor, concatTensors) moved to operations.jl. Constructor kept at top of fft.jl.

  2. Aligned with CUDA.jl structure — Removed the fused fftshift feature to match CUDA.jl's API. Simplified the implementation.

  3. Adapted CUDA.jl tests — Added Float16 shims, tolerance functions, reusable test functions, @inferred checks, and batched transform tests. Test count: 117.

Ready for review!

@github-actions bot commented Dec 4, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Suggested changes:
diff --git a/lib/mpsgraphs/operations.jl b/lib/mpsgraphs/operations.jl
index b6028253..4abf8a09 100644
--- a/lib/mpsgraphs/operations.jl
+++ b/lib/mpsgraphs/operations.jl
@@ -66,43 +66,53 @@ function identityWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "ide
 end
 
 function sliceTensor(graph::MPSGraph, tensor::MPSGraphTensor, dimension::Int, start::Int, length::Int, name = "slice")
-    obj = @objc [graph::id{MPSGraph} sliceTensor:tensor::id{MPSGraphTensor}
-                                dimension:dimension::NSInteger
-                                start:start::NSInteger
-                                length:length::NSInteger
-                                name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} sliceTensor:tensor::id{MPSGraphTensor}
+        dimension:dimension::NSInteger
+        start:start::NSInteger
+        length:length::NSInteger
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function concatTensors(graph::MPSGraph, tensors::NSArray, dimension::Int, name = "concat")
-    obj = @objc [graph::id{MPSGraph} concatTensors:tensors::id{NSArray}
-                                dimension:dimension::NSInteger
-                                name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} concatTensors:tensors::id{NSArray}
+        dimension:dimension::NSInteger
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function fastFourierTransformWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, axes::NSArray, descriptor::MPSGraphFFTDescriptor, name = "fft")
-    obj = @objc [graph::id{MPSGraph} fastFourierTransformWithTensor:tensor::id{MPSGraphTensor}
-                                axes:axes::id{NSArray}
-                                descriptor:descriptor::id{MPSGraphFFTDescriptor}
-                                name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} fastFourierTransformWithTensor:tensor::id{MPSGraphTensor}
+        axes:axes::id{NSArray}
+        descriptor:descriptor::id{MPSGraphFFTDescriptor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function realToHermiteanFFTWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, axes::NSArray, descriptor::MPSGraphFFTDescriptor, name = "rfft")
-    obj = @objc [graph::id{MPSGraph} realToHermiteanFFTWithTensor:tensor::id{MPSGraphTensor}
-                                axes:axes::id{NSArray}
-                                descriptor:descriptor::id{MPSGraphFFTDescriptor}
-                                name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} realToHermiteanFFTWithTensor:tensor::id{MPSGraphTensor}
+        axes:axes::id{NSArray}
+        descriptor:descriptor::id{MPSGraphFFTDescriptor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function HermiteanToRealFFTWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, axes::NSArray, descriptor::MPSGraphFFTDescriptor, name = "irfft")
-    obj = @objc [graph::id{MPSGraph} HermiteanToRealFFTWithTensor:tensor::id{MPSGraphTensor}
-                                axes:axes::id{NSArray}
-                                descriptor:descriptor::id{MPSGraphFFTDescriptor}
-                                name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} HermiteanToRealFFTWithTensor:tensor::id{MPSGraphTensor}
+        axes:axes::id{NSArray}
+        descriptor:descriptor::id{MPSGraphFFTDescriptor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 """

.gitignore Outdated
wip.*
dev
.vscode
FFT_REFACTORING_ROADMAP.md

Review comment (Member):

Remove

Review comment (Member):

Can you undo the unrelated formatting changes that were made to this file?

Commits

Implements GPU FFT operations for MtlArray using MPSGraph's
fastFourierTransformWithTensor. This addresses issue JuliaGPU#270.

Features:
- plan_fft, plan_ifft, plan_bfft for ComplexF32 arrays
- Multi-dimensional FFT support (single axis, multiple axes, all axes)
- FFTW.jl-compatible API via AbstractFFTs.jl interface
- Plan execution via * operator and mul!

Implementation notes:
- Uses MPSGraphFFTDescriptor for forward/inverse control
- Scaling handled manually for multi-axis ifft (not via MPSGraph's
  scalingMode) to ensure correct normalization across all FFT dimensions
- Axis mapping accounts for Julia's column-major vs Metal's row-major
  ordering via shape reversal in placeholderTensor

Tested against FFTW.jl with <1e-4 relative tolerance for all operations.
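
For reference, the manual scaling described here matches the standard relationship between bfft and ifft; a quick sanity check (a sketch, not the internal plan code) would be:

using Metal, AbstractFFTs

x = MtlArray(randn(ComplexF32, 64, 64))
# ifft is bfft divided by the number of transformed elements; when all
# dimensions are transformed, that count is length(x).
Array(ifft(x)) ≈ Array(bfft(x)) ./ length(x)
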
Implements real-to-complex and complex-to-real FFT operations using
MPSGraph's realToHermiteanFFTWithTensor and HermiteanToRealFFTWithTensor.

Features:
- plan_rfft for Float32 arrays (output size n÷2+1 in first FFT dimension)
- plan_irfft for ComplexF32 arrays (normalized inverse)
- plan_brfft for ComplexF32 arrays (unnormalized inverse)
- Proper handling of odd output sizes via roundToOddHermitean
- FFTW-compatible dimension conventions

The output size reduction follows FFTW convention: the first transformed
dimension is reduced to n÷2+1 for rfft, and irfft requires the original
size to be specified.

Tested against FFTW.jl with <1e-4 relative tolerance.
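
To make the size convention concrete, a check along these lines (illustrative sizes and dims) shows the n÷2+1 reduction:

using Metal, AbstractFFTs

r = MtlArray(randn(Float32, 1024, 8))
c = rfft(r, 1)            # real-to-complex transform along the first dimension
size(c)                   # (513, 8): the transformed dimension shrinks to 1024÷2 + 1
r2 = irfft(c, 1024, 1)    # the original length must be passed back explicitly
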
Phase 4: Type generalization for FFT operations.

Changes:
- Add ComplexF16 support for fft/ifft/bfft
- Add Float16 support for rfft/irfft/brfft
- Improve error messages for unsupported types (ComplexF64/Float64)
- Add documentation about supported and unsupported types
- Add FFTComplexTypes and FFTRealTypes type unions

Note: ComplexF64/Float64 are NOT supported by Metal's MPSGraph FFT.
Users requiring double precision should use FFTW.jl on CPU.

ComplexF16 results have ~3 decimal digits precision (expected for Float16).

Phase 5: Verify and test 1D FFT support.

The existing implementation already handles 1D arrays correctly since
the multi-dimensional FFT code works for arbitrary dimensions.

Added tests:
- 1D fft correctness vs FFTW
- 1D ifft roundtrip
- 1D rfft correctness vs FFTW
- 1D rfft -> irfft roundtrip

Phase 6: Implement in-place FFT operations.

New types:
- MtlFFTInplacePlan{T,K,N} - plan that modifies input directly

New functions:
- plan_fft!(x, [region]) - create in-place forward FFT plan
- plan_ifft!(x, [region]) - create in-place inverse FFT plan
- plan_bfft!(x, [region]) - create in-place backward FFT plan

The in-place variants are useful for avoiding memory allocation when
the input data is no longer needed after the transform.

Usage:
    x = MtlArray(randn(ComplexF32, 64, 64))
    plan = plan_fft!(x)
    plan * x  # x is modified in-place, returns x

Add shift keyword argument to FFT plans and convenience functions
that fuses fftshift/ifftshift into the GPU graph execution.

- plan_fft(x; shift=true) fuses fftshift after forward FFT
- plan_ifft(x; shift=true) fuses ifftshift before inverse FFT
- Convenience functions in Metal.MPSGraphs namespace support shift kwarg
- Implemented via slice+concat in MPSGraph (no intermediate memory ops)
- Handles both even and odd array sizes correctly

Example usage:
  # Via plans
  p = plan_fft(x_gpu; shift=true)
  result = p * x_gpu

  # Via convenience functions
  using Metal.MPSGraphs: fft, ifft
  result = fft(x_gpu; shift=true)

Allow fft(x; shift=true) syntax by extending AbstractFFTs.fft with
a shift keyword argument for MtlArray inputs. This provides a cleaner
API without requiring separate imports.

- Add sliceTensor, concatTensors to operations.jl
- Add fastFourierTransformWithTensor to operations.jl
- Add realToHermiteanFFTWithTensor, HermiteanToRealFFTWithTensor to operations.jl
- Remove duplicate wrapper functions from fft.jl
- Update _apply_fftshift_to_tensor to use operations.jl functions

- Remove shift parameter from all plan types and functions
- Remove _apply_fftshift_to_tensor helper
- Simplify _execute_single_axis_fft!
- Remove fftshift tests (14 tests removed)
- Aligns API with CUDA.jl which has no built-in fftshift

- Add Float16 CPU shims for FFTW reference comparisons
- Add tolerance functions based on type precision
- Restructure tests with reusable test functions
- Add @inferred checks for plan creation
- Add comprehensive batched transform tests (1D, 2D, 3D)
- Test both ComplexF16/ComplexF32 and Float16/Float32
- Tests increased from 55 to 117
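
For context, a precision-based tolerance helper of the kind described here might look roughly like this (hypothetical names and values, not the actual test code):

# Relative tolerance scaled to the element type's precision (illustrative values).
test_rtol(::Type{Float16}) = 1.0e-2
test_rtol(::Type{Float32}) = 1.0e-5
test_rtol(::Type{Complex{T}}) where {T} = test_rtol(T)

# Comparisons against a CPU reference would then look roughly like:
#     @test Array(y_gpu) ≈ y_cpu rtol = test_rtol(eltype(y_cpu))
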
@KaanKesginLW (Contributor, Author) commented:

Thanks for the review! I've addressed both points:

  1. Removed the .gitignore change
  2. Restored original formatting in operations.jl - now only the new FFT functions are added without any formatting changes to existing code
