127 changes: 112 additions & 15 deletions lib/mpsgraphs/matmul.jl
@@ -30,20 +30,71 @@ else
end


-@autoreleasepool function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb},
-                                   alpha::Number, beta::Number,
-                                   transpose_a, transpose_b) where {Tc, Tab, Na, Nb}
-    graph = MPSGraph()
-
-    placeA = placeholderTensor(graph, size(a), Tab)
-    placeB = placeholderTensor(graph, size(b), Tab)
-    placeC = placeholderTensor(graph, size(c), Tc)
-
-    feeds = Dict{MPSGraphTensor, MPSGraphTensorData}(
-        placeA => MPSGraphTensorData(a),
-        placeB => MPSGraphTensorData(b),
-        placeC => MPSGraphTensorData(c)
+#=
+MPSGraph caching infrastructure.
+
+Creating an MPSGraph takes ~2ms per call, which dominates matmul time for small-to-medium
+matrices. By caching graphs keyed by their structural parameters (shapes, types, flags),
+we achieve significant speedup for repeated operations with the same configuration.
+
+The cache key includes all parameters that affect graph structure:
+- Input/output shapes and element types
+- Transpose flags
+- Alpha/beta values (baked into the graph as constants)
+=#

# Cache key for matmul graphs - includes all structural parameters
struct MatmulGraphKey
size_a::Tuple{Vararg{Int}}
size_b::Tuple{Vararg{Int}}
size_c::Tuple{Vararg{Int}}
eltype_ab::DataType
eltype_c::DataType
ndims_a::Int
ndims_b::Int
alpha::Float64 # Normalized to Float64 for hashing
beta::Float64
transpose_a::Bool
transpose_b::Bool
end

# Cached graph with all tensors needed for execution
struct CachedMatmulGraph
graph::MPSGraph
place_c::MPSGraphTensor
place_a::MPSGraphTensor
place_b::MPSGraphTensor
result::MPSGraphTensor
end

# Thread-safe graph cache with lock
const _matmul_graph_cache = Dict{MatmulGraphKey, CachedMatmulGraph}()
const _matmul_graph_cache_lock = ReentrantLock()

# Build graph key from matmul parameters
function _make_matmul_key(a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, c::MtlArray{Tc},
alpha::Number, beta::Number,
transpose_a::Bool, transpose_b::Bool) where {Tc, Tab, Na, Nb}
MatmulGraphKey(
size(a), size(b), size(c),
Tab, Tc,
Na, Nb,
Float64(alpha), Float64(beta),
transpose_a, transpose_b
)
end

# Build a new matmul graph (called only on cache miss)
function _build_matmul_graph(size_a::Tuple, size_b::Tuple, size_c::Tuple,
Member:
This only gets called by a function that feeds in each MatmulGraphKey field one-by-one, can you modify this function to just take a MatmulGraphKey?
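One way to satisfy this without touching the graph-construction logic, sketched as a hypothetical forwarding method (not part of this diff) that unpacks the key and delegates to the positional method below:

```julia
# Hypothetical forwarding method: accept the whole key so callers can
# simply write _build_matmul_graph(key); it unpacks the fields and
# delegates to the positional method defined in this diff.
function _build_matmul_graph(key::MatmulGraphKey)
    _build_matmul_graph(
        key.size_a, key.size_b, key.size_c,
        key.eltype_ab, key.eltype_c,
        key.ndims_a, key.ndims_b,
        key.transpose_a, key.transpose_b,
        key.alpha, key.beta
    )
end
```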

Tab::DataType, Tc::DataType,
Na::Int, Nb::Int,
transpose_a::Bool, transpose_b::Bool,
alpha::Number, beta::Number)
graph = MPSGraph()

placeA = placeholderTensor(graph, size_a, Tab)
placeB = placeholderTensor(graph, size_b, Tab)
placeC = placeholderTensor(graph, size_c, Tc)

# cast to output eltype if input type is an integer type
castT = Tab <: Integer ? Tc : Tab
@@ -53,8 +104,8 @@ end
transA = transpose_a ? transposeTensor(graph, castA, Na-2, Na-1, "transpose_a") : castA
transB = transpose_b ? transposeTensor(graph, castB, Nb-2, Nb-1, "transpose_b") : castB

-    nBatchA = Na == 2 ? 1 : size(transA)[1]
-    nBatchB = Nb == 2 ? 1 : size(transB)[1]
+    nBatchA = Na == 2 ? 1 : size_a[1]
+    nBatchB = Nb == 2 ? 1 : size_b[1]

# for batched matmul between different sized tensors
broadcastA, broadcastB = if nBatchA == nBatchB
@@ -81,12 +132,58 @@

castC = castTensor(graph, afterbeta, Tc, "castC")

CachedMatmulGraph(graph, placeC, placeA, placeB, castC)
end

# Get or create cached graph
function _get_cached_graph(key::MatmulGraphKey)
# Fast path: check cache without lock (safe for reads)
cached = get(_matmul_graph_cache, key, nothing)
if cached !== nothing
return cached
end

# Slow path: acquire lock and build graph
lock(_matmul_graph_cache_lock) do
# Double-check after acquiring lock
cached = get(_matmul_graph_cache, key, nothing)
if cached !== nothing
return cached
end

# Build new graph
cached = _build_matmul_graph(
key.size_a, key.size_b, key.size_c,
key.eltype_ab, key.eltype_c,
key.ndims_a, key.ndims_b,
key.transpose_a, key.transpose_b,
key.alpha, key.beta
)
_matmul_graph_cache[key] = cached
return cached
Member:
Does this work?

Suggested change:
-        return cached
+    @lock _matmul_graph_cache_lock get!(_matmul_graph_cache, key) do
+        _build_matmul_graph(
+            key.size_a, key.size_b, key.size_c,
+            key.eltype_ab, key.eltype_c,
+            key.ndims_a, key.ndims_b,
+            key.transpose_a, key.transpose_b,
+            key.alpha, key.beta
+        )
+    end

Contributor Author:
Yes, your suggested pattern works. There's a small tradeoff I see though:

- Current (double-checked locking): lock-free on cache hits (~0ns), takes the lock only on a miss
- Your suggestion (@lock get!): always takes the lock (~40ns overhead on hits)

In practice, 40ns is negligible compared to the ~250μs total matmul time, so the cleaner pattern is probably the better choice. Happy to go either way - let me know your preference.

Member:
Apologies, my comment didn't properly highlight the code I meant to. My intention was to keep the initial lock-free check and use the @lock <thelock> get!... pattern to replace what is, at the time of writing, lines 146-164.
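For concreteness, a minimal sketch of that combined shape (hypothetical; assumes the key-taking _build_matmul_graph variant floated in the thread above):

```julia
function _get_cached_graph(key::MatmulGraphKey)
    # Fast path: a cache hit takes no lock.
    cached = get(_matmul_graph_cache, key, nothing)
    cached === nothing || return cached

    # Slow path: get! under the lock builds and inserts the graph at most
    # once per key; concurrent misses wait on the lock and find the entry.
    @lock _matmul_graph_cache_lock get!(_matmul_graph_cache, key) do
        _build_matmul_graph(key)
    end
end
```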

end
end

@autoreleasepool function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb},
alpha::Number, beta::Number,
transpose_a, transpose_b) where {Tc, Tab, Na, Nb}
# Get or create cached graph
key = _make_matmul_key(a, b, c, alpha, beta, transpose_a, transpose_b)
cached = _get_cached_graph(key)

# Build feed and result dictionaries with current data
feeds = Dict{MPSGraphTensor, MPSGraphTensorData}(
cached.place_a => MPSGraphTensorData(a),
cached.place_b => MPSGraphTensorData(b),
cached.place_c => MPSGraphTensorData(c)
)

    resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}(
-        castC => feeds[placeC]
+        cached.result => MPSGraphTensorData(c)
Member:
This should work and save on constructing an MPSGraphTensorData to wrap data that's already been wrapped.

Suggested change:
-        cached.result => MPSGraphTensorData(c)
+        cached.result => feeds[cached.place_c]

)

cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))
-    encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(resultdict), nil, default_exec_desc())
+    encode!(cmdbuf, cached.graph, NSDictionary(feeds), NSDictionary(resultdict), nil, default_exec_desc())
commit!(cmdbuf)
wait_completed(cmdbuf)
