-
Notifications
You must be signed in to change notification settings - Fork 48
Cache MPSGraph instances for matmul to reduce overhead #722
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -30,20 +30,71 @@ else | |||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @autoreleasepool function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, | ||||||||||||||||||||||
| alpha::Number, beta::Number, | ||||||||||||||||||||||
| transpose_a, transpose_b) where {Tc, Tab, Na, Nb} | ||||||||||||||||||||||
| graph = MPSGraph() | ||||||||||||||||||||||
| #= | ||||||||||||||||||||||
| MPSGraph caching infrastructure. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| placeA = placeholderTensor(graph, size(a), Tab) | ||||||||||||||||||||||
| placeB = placeholderTensor(graph, size(b), Tab) | ||||||||||||||||||||||
| placeC = placeholderTensor(graph, size(c), Tc) | ||||||||||||||||||||||
| Creating an MPSGraph takes ~2ms per call, which dominates matmul time for small-medium | ||||||||||||||||||||||
| matrices. By caching graphs keyed by their structural parameters (shapes, types, flags), | ||||||||||||||||||||||
| we achieve significant speedup for repeated operations with the same configuration. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( | ||||||||||||||||||||||
| placeA => MPSGraphTensorData(a), | ||||||||||||||||||||||
| placeB => MPSGraphTensorData(b), | ||||||||||||||||||||||
| placeC => MPSGraphTensorData(c) | ||||||||||||||||||||||
| The cache key includes all parameters that affect graph structure: | ||||||||||||||||||||||
| - Input/output shapes and element types | ||||||||||||||||||||||
| - Transpose flags | ||||||||||||||||||||||
| - Alpha/beta values (baked into graph as constants) | ||||||||||||||||||||||
| =# | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Cache key for matmul graphs - includes all structural parameters | ||||||||||||||||||||||
| struct MatmulGraphKey | ||||||||||||||||||||||
| size_a::Tuple{Vararg{Int}} | ||||||||||||||||||||||
| size_b::Tuple{Vararg{Int}} | ||||||||||||||||||||||
| size_c::Tuple{Vararg{Int}} | ||||||||||||||||||||||
| eltype_ab::DataType | ||||||||||||||||||||||
| eltype_c::DataType | ||||||||||||||||||||||
| ndims_a::Int | ||||||||||||||||||||||
| ndims_b::Int | ||||||||||||||||||||||
| alpha::Float64 # Normalized to Float64 for hashing | ||||||||||||||||||||||
| beta::Float64 | ||||||||||||||||||||||
| transpose_a::Bool | ||||||||||||||||||||||
| transpose_b::Bool | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Cached graph with all tensors needed for execution | ||||||||||||||||||||||
| struct CachedMatmulGraph | ||||||||||||||||||||||
| graph::MPSGraph | ||||||||||||||||||||||
| place_c::MPSGraphTensor | ||||||||||||||||||||||
| place_a::MPSGraphTensor | ||||||||||||||||||||||
| place_b::MPSGraphTensor | ||||||||||||||||||||||
| result::MPSGraphTensor | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Thread-safe graph cache with lock | ||||||||||||||||||||||
| const _matmul_graph_cache = Dict{MatmulGraphKey, CachedMatmulGraph}() | ||||||||||||||||||||||
| const _matmul_graph_cache_lock = ReentrantLock() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Build graph key from matmul parameters | ||||||||||||||||||||||
| function _make_matmul_key(a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, c::MtlArray{Tc}, | ||||||||||||||||||||||
| alpha::Number, beta::Number, | ||||||||||||||||||||||
| transpose_a::Bool, transpose_b::Bool) where {Tc, Tab, Na, Nb} | ||||||||||||||||||||||
| MatmulGraphKey( | ||||||||||||||||||||||
| size(a), size(b), size(c), | ||||||||||||||||||||||
| Tab, Tc, | ||||||||||||||||||||||
| Na, Nb, | ||||||||||||||||||||||
| Float64(alpha), Float64(beta), | ||||||||||||||||||||||
| transpose_a, transpose_b | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Build a new matmul graph (called only on cache miss) | ||||||||||||||||||||||
| function _build_matmul_graph(size_a::Tuple, size_b::Tuple, size_c::Tuple, | ||||||||||||||||||||||
| Tab::DataType, Tc::DataType, | ||||||||||||||||||||||
| Na::Int, Nb::Int, | ||||||||||||||||||||||
| transpose_a::Bool, transpose_b::Bool, | ||||||||||||||||||||||
| alpha::Number, beta::Number) | ||||||||||||||||||||||
| graph = MPSGraph() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| placeA = placeholderTensor(graph, size_a, Tab) | ||||||||||||||||||||||
| placeB = placeholderTensor(graph, size_b, Tab) | ||||||||||||||||||||||
| placeC = placeholderTensor(graph, size_c, Tc) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # cast to output eltype if input type is an integer type | ||||||||||||||||||||||
| castT = Tab <: Integer ? Tc : Tab | ||||||||||||||||||||||
|
|
@@ -53,8 +104,8 @@ end | |||||||||||||||||||||
| transA = transpose_a ? transposeTensor(graph, castA, Na-2, Na-1, "transpose_a") : castA | ||||||||||||||||||||||
| transB = transpose_b ? transposeTensor(graph, castB, Nb-2, Nb-1, "transpose_b") : castB | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| nBatchA = Na == 2 ? 1 : size(transA)[1] | ||||||||||||||||||||||
| nBatchB = Nb == 2 ? 1 : size(transB)[1] | ||||||||||||||||||||||
| nBatchA = Na == 2 ? 1 : size_a[1] | ||||||||||||||||||||||
| nBatchB = Nb == 2 ? 1 : size_b[1] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # for batched matmul between different sized tensors | ||||||||||||||||||||||
| broadcastA, broadcastB = if nBatchA == nBatchB | ||||||||||||||||||||||
|
|
@@ -81,12 +132,58 @@ end | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| castC = castTensor(graph, afterbeta, Tc, "castC") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| CachedMatmulGraph(graph, placeC, placeA, placeB, castC) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Get or create cached graph | ||||||||||||||||||||||
| function _get_cached_graph(key::MatmulGraphKey) | ||||||||||||||||||||||
| # Fast path: check cache without lock (safe for reads) | ||||||||||||||||||||||
| cached = get(_matmul_graph_cache, key, nothing) | ||||||||||||||||||||||
| if cached !== nothing | ||||||||||||||||||||||
| return cached | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Slow path: acquire lock and build graph | ||||||||||||||||||||||
| lock(_matmul_graph_cache_lock) do | ||||||||||||||||||||||
| # Double-check after acquiring lock | ||||||||||||||||||||||
| cached = get(_matmul_graph_cache, key, nothing) | ||||||||||||||||||||||
| if cached !== nothing | ||||||||||||||||||||||
| return cached | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Build new graph | ||||||||||||||||||||||
| cached = _build_matmul_graph( | ||||||||||||||||||||||
| key.size_a, key.size_b, key.size_c, | ||||||||||||||||||||||
| key.eltype_ab, key.eltype_c, | ||||||||||||||||||||||
| key.ndims_a, key.ndims_b, | ||||||||||||||||||||||
| key.transpose_a, key.transpose_b, | ||||||||||||||||||||||
| key.alpha, key.beta | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| _matmul_graph_cache[key] = cached | ||||||||||||||||||||||
| return cached | ||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this work?
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, your suggested pattern works. There's a small tradeoff I see though:
In practice, 40ns is negligible compared to the ~250μs total matmul time, so the cleaner pattern is probably the better choice. Happy to go either way - let me know your preference.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Apologies, my comment didn't properly highlight the code I meant to. My intention was to keep the initial lock-free check and use the |
||||||||||||||||||||||
| end | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @autoreleasepool function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, | ||||||||||||||||||||||
| alpha::Number, beta::Number, | ||||||||||||||||||||||
| transpose_a, transpose_b) where {Tc, Tab, Na, Nb} | ||||||||||||||||||||||
| # Get or create cached graph | ||||||||||||||||||||||
| key = _make_matmul_key(a, b, c, alpha, beta, transpose_a, transpose_b) | ||||||||||||||||||||||
| cached = _get_cached_graph(key) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Build feed and result dictionaries with current data | ||||||||||||||||||||||
| feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( | ||||||||||||||||||||||
| cached.place_a => MPSGraphTensorData(a), | ||||||||||||||||||||||
| cached.place_b => MPSGraphTensorData(b), | ||||||||||||||||||||||
| cached.place_c => MPSGraphTensorData(c) | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}( | ||||||||||||||||||||||
| castC => feeds[placeC] | ||||||||||||||||||||||
| cached.result => MPSGraphTensorData(c) | ||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should work and save on constructing a
Suggested change
|
||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| cmdbuf = MPSCommandBuffer(Metal.global_queue(device())) | ||||||||||||||||||||||
| encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(resultdict), nil, default_exec_desc()) | ||||||||||||||||||||||
| encode!(cmdbuf, cached.graph, NSDictionary(feeds), NSDictionary(resultdict), nil, default_exec_desc()) | ||||||||||||||||||||||
| commit!(cmdbuf) | ||||||||||||||||||||||
| wait_completed(cmdbuf) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This only gets called by a function that feeds in each
MatmulGraphKeyfield one-by-one, can you modify this function to just take aMatmulGraphKey?