Commit 7d31268: Update matmul
1 parent: d8cf784

File tree (4 files changed: 110 additions, 41 deletions)
  .clangd
  csrc/cmake/compilers/cxx-compiler-configs.cmake
  csrc/lib/ops/matmul/op.cu
  src/pmpp/models/attention.py

.clangd (1 addition, 1 deletion)

@@ -38,7 +38,6 @@ CompileFlags:
 
 Diagnostics:
   UnusedIncludes: Strict
-  # MissingIncludes: Strict
 
   ClangTidy:
     Add: [
@@ -53,6 +52,7 @@ Diagnostics:
       readability-identifier-length,
      readability-magic-numbers,
      readability-function-cognitive-complexity,
+      modernize-avoid-c-arrays
    ]
 
    CheckOptions:

csrc/cmake/compilers/cxx-compiler-configs.cmake (1 addition, 1 deletion)

@@ -11,8 +11,8 @@
 
 include(${PROJECT_SOURCE_DIR}/cmake/utils/common.cmake)
 
-set_default_values(CMAKE_CXX_SCAN_FOR_MODULES OFF)
 enable_language(CXX)
+set_default_values(CMAKE_CXX_SCAN_FOR_MODULES OFF)
 
 # Generate compile_commands.json in build directory
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

csrc/lib/ops/matmul/op.cu (new file: 35 additions, 0 deletions)

@@ -0,0 +1,35 @@
+#include <cuda_runtime.h>
+
+#include "pmpp/types/cxx_types.hpp"
+
+namespace pmpp::ops::cuda
+{
+/**
+ * Assumes:
+ *   1. M, N, P are square matrices of size width x width;
+ *   2. Each thread computes one element;
+ */
+template <int32_t TILE_SIZE = 16, typename ScalarT = fp32_t>
+__global__ void matMulKernel(ScalarT* M, ScalarT* N, ScalarT* P, int32_t Width)
+{
+    __shared__ ScalarT Mds[TILE_SIZE][TILE_SIZE];
+    __shared__ ScalarT Nds[TILE_SIZE][TILE_SIZE];
+
+    int32_t Row = blockIdx.y * TILE_SIZE + threadIdx.y;
+    int32_t Col = blockIdx.x * TILE_SIZE + threadIdx.x;
+
+    fp32_t Pvalue = 0.0F;
+    for (int32_t ph = 0; ph < Width / TILE_SIZE; ++ph) {
+        Mds[threadIdx.y][threadIdx.x] = M[Row * Width + (ph * TILE_SIZE + threadIdx.x)];
+        Nds[threadIdx.y][threadIdx.x] = N[(ph * TILE_SIZE + threadIdx.y) * Width + Col];
+        __syncthreads();
+
+        for (int32_t k = 0; k < TILE_SIZE; ++k) {
+            Pvalue += Mds[threadIdx.y][k] * Nds[k][threadIdx.x];
+        }
+        __syncthreads();
+    }
+
+    P[Row * Width + Col] = Pvalue;
+}
+} // namespace pmpp::ops::cuda
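
The kernel added above is the standard shared-memory tiled matrix multiplication: per phase, each thread block stages one TILE_SIZE x TILE_SIZE tile of M and of N into shared memory, accumulates the partial dot products, and finally writes a single element of P. As an illustration only (not part of this commit), a phase-by-phase PyTorch reference such as the sketch below can be used to cross-check the kernel's output; like the kernel, it assumes square matrices whose width is a multiple of the tile size. The function name tiled_matmul_reference is hypothetical.

import torch

def tiled_matmul_reference(M: torch.Tensor, N: torch.Tensor, tile: int = 16) -> torch.Tensor:
    """Phase-wise tiled matmul mirroring matMulKernel; assumes square inputs
    whose width is a multiple of `tile` (same restriction as the kernel)."""
    width = M.size(0)
    P = torch.zeros(width, width, dtype=M.dtype)
    for row0 in range(0, width, tile):          # one "thread block" per output tile
        for col0 in range(0, width, tile):
            acc = torch.zeros(tile, tile, dtype=M.dtype)
            for ph in range(width // tile):     # the kernel's `ph` loop
                m_tile = M[row0:row0 + tile, ph * tile:(ph + 1) * tile]  # Mds
                n_tile = N[ph * tile:(ph + 1) * tile, col0:col0 + tile]  # Nds
                acc += m_tile @ n_tile
            P[row0:row0 + tile, col0:col0 + tile] = acc
    return P

# Sanity check against torch.matmul:
A = torch.randn(64, 64)
B = torch.randn(64, 64)
assert torch.allclose(tiled_matmul_reference(A, B), A @ B, atol=1e-4)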

src/pmpp/models/attention.py (73 additions, 39 deletions)

@@ -1,5 +1,8 @@
+import math
+from typing import Optional
 import torch
 from torch import nn
+from torch.nn import functional as F
 import numpy as np
 
 
@@ -24,7 +27,13 @@ def __init__(self, embed_dim: int, num_heads: int):
         self.num_heads: int = num_heads
         self.head_size: int = embed_dim // num_heads
 
-    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ):
         """
         Calculates softmax(Q @ KT / sqrt(dk)) @ V .
 
@@ -46,30 +55,34 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
         """
 
         q_len, kv_len = q.size(0), k.size(0)
-        # q: (num_heads, q_len, head_size)
+        # q -> (num_heads, q_len, head_size)
         q = q.view(q_len, self.num_heads, self.head_size).transpose(0, 1)
-        # k: (num_heads, kv_len, head_size)
+        # k -> (num_heads, kv_len, head_size)
         k = k.view(kv_len, self.num_heads, self.head_size).transpose(0, 1)
-        # v: (num_heads, kv_len, head_size)
+        # v -> (num_heads, kv_len, head_size)
         v = v.view(kv_len, self.num_heads, self.head_size).transpose(0, 1)
-
-        attn_weights = torch.matmul(q, k.transpose(-1, -2)) / torch.sqrt(
-            torch.tensor(self.head_size, dtype=torch.float32)
+        # scores -> (num_heads, q_len, kv_len)
+        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(
+            self.head_size
         )
-
-        # logits: (num_heads, q_len, kv_len)
-        logits = torch.softmax(attn_weights, dim=-1)
-
-        # out: (num_head, q_len, head_size)
-        out = torch.matmul(logits, v)
-        # out: (q_len, embed_dim)
+        # scores -> (num_heads, q_len, kv_len)
+        scores = scores + mask if mask is not None else scores
+        # scores -> (num_heads, q_len, kv_len)
+        scores = F.softmax(scores, dim=-1)
+        # out -> (num_heads, q_len, head_size)
+        out = torch.matmul(scores, v)
+        # out -> (q_len, num_heads, head_size)
         out = out.transpose(0, 1).reshape(q_len, self.embed_dim)
 
         return out
 
 
 class MultiHeadAttention(nn.Module):
-    def __init__(self, embed_dim: int, num_heads: int):
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+    ):
         super().__init__()
 
         self.embed_dim: int = embed_dim
@@ -82,7 +95,11 @@ def __init__(self, embed_dim: int, num_heads: int):
 
         self.attn_kernel = MultiHeadAttentionKernel(embed_dim, num_heads)
 
-    def forward(self, seq: torch.Tensor):
+    def forward(
+        self,
+        seq: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ):
         """
         Parameters
         ----------
@@ -96,15 +113,15 @@ def forward(self, seq: torch.Tensor):
             Attention output, cached K and cached V.
         """
 
-        # q: (seq_len, embed_dim)
+        # q -> (seq_len, embed_dim)
         q = self.Wq(seq)
-        # k: (seq_len, embed_dim)
+        # k -> (seq_len, embed_dim)
         k = self.Wk(seq)
-        # v: (seq_len, embed_dim)
+        # v -> (seq_len, embed_dim)
         v = self.Wv(seq)
 
-        # out: (seq_len, embed_dim)
-        out = self.Wo(self.attn_kernel(q, k, v))
+        # out -> (seq_len, embed_dim)
+        out = self.Wo(self.attn_kernel(q, k, v, mask))
 
         return out, k, v
 
@@ -125,7 +142,10 @@ def __init__(self, embed_dim: int, num_heads: int):
         self.attn_kernel = MultiHeadAttentionKernel(embed_dim, num_heads)
 
     def forward(
-        self, seq: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor
+        self,
+        seq: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Parameters
@@ -147,19 +167,19 @@ def forward(
         When decoing, the input seq only has ONE newly generated token.
         """
 
-        # q: (1, embed_dim)
+        # q -> (1, embed_dim)
         q = self.Wq(seq)
-        # k: (1, embed_dim)
+        # k -> (1, embed_dim)
         k = self.Wk(seq)
-        # v: (1, embed_dim)
+        # v -> (1, embed_dim)
         v = self.Wv(seq)
 
-        # k_cache: (kv_len + 1, embed_dim)
+        # k_cache -> (kv_len + 1, embed_dim)
         k_cache = torch.cat([k_cache, k.detach()], dim=0)
-        # v_cache: (kv_len + 1, embed_dim)
+        # v_cache -> (kv_len + 1, embed_dim)
         v_cache = torch.cat([v_cache, v.detach()], dim=0)
 
-        # out: (seq_len, embed_dim)
+        # out -> (seq_len, embed_dim)
         out = self.Wo(self.attn_kernel(q, k_cache, v_cache))
 
         return out, k_cache, v_cache
@@ -189,7 +209,11 @@ def __init__(
         self.k_cache = nn.Buffer(torch.zeros(size=(0, embed_dim)))
         self.v_cache = nn.Buffer(torch.zeros(size=(0, embed_dim)))
 
-    def forward(self, prompt: torch.Tensor, is_prefilling: bool = True):
+    def forward(
+        self,
+        prompt: torch.Tensor,
+        is_prefilling: bool = True,
+    ):
         """
         Parameters
         ----------
@@ -203,19 +227,29 @@ def forward(self, prompt: torch.Tensor, is_prefilling: bool = True):
            step, which means `seq_len` should equal to `1`.
         """
 
-        # embedded_prompt: (seq_len, embed_dim)
+        # embedded_prompt -> (seq_len, embed_dim)
         embedded_prompt = self.embed(prompt)
 
         if is_prefilling:
-            # out: (seq_len, embed_dim)
-            # k: (seq_len, embed_dim)
-            # v: (seq_len, embed_dim)
-            out, k, v = self.mha(embedded_prompt)
+            seq_len = prompt.size(0)
+            mask = None
+            if seq_len > 1:
+                mask = torch.full(
+                    (seq_len, seq_len), -float("Inf"), device=prompt.device
+                )
+                mask = torch.triu(mask, diagonal=1)
+            # out -> (seq_len, embed_dim)
+            # k -> (seq_len, embed_dim)
+            # v -> (seq_len, embed_dim)
+            out, k, v = self.mha(embedded_prompt, mask)
         else:
-            # out: (seq_len, embed_dim)
-            # k: (kv_len, embed_dim)
-            # v: (kv_len, embed_dim)
-            out, k, v = self.cached_mha(embedded_prompt, self.k_cache, self.v_cache)
+            assert prompt.size(0) == 1
+            # out -> (seq_len, embed_dim)
+            # k -> (kv_len, embed_dim)
+            # v -> (kv_len, embed_dim)
+            out, k, v = self.cached_mha(
+                embedded_prompt, self.k_cache, self.v_cache
+            )
 
         # Update k cache and v cache
         # [NOTE]
@@ -226,7 +260,7 @@ def forward(self, prompt: torch.Tensor, is_prefilling: bool = True):
 
         # Use the last token to calculate the probability of each word in the
         # vocabulary bank:
-        # probs: (vocab_size,)
+        # probs -> (vocab_size,)
         probs = torch.softmax(self.proj_to_vocab(out[-1]), dim=-1)
 
         return probs
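
The substantive change in this file is the optional causal mask: during prefill the model now builds an upper-triangular matrix of -inf and adds it to the attention scores before the softmax, so each token attends only to itself and to earlier positions. Below is a minimal usage sketch of MultiHeadAttention with such a mask; the module path pmpp.models.attention and the toy sizes are assumptions for illustration, not part of this commit.

import torch

from pmpp.models.attention import MultiHeadAttention  # assumed import path

embed_dim, num_heads, seq_len = 8, 2, 4  # toy sizes for illustration
mha = MultiHeadAttention(embed_dim, num_heads)
seq = torch.randn(seq_len, embed_dim)

# Same causal mask as the prefill branch above: -inf strictly above the
# diagonal, 0 elsewhere, so position i only attends to positions <= i.
mask = torch.triu(
    torch.full((seq_len, seq_len), -float("Inf")), diagonal=1
)
# mask:
# [[0., -inf, -inf, -inf],
#  [0.,   0., -inf, -inf],
#  [0.,   0.,   0., -inf],
#  [0.,   0.,   0.,   0.]]

out, k, v = mha(seq, mask)          # prefill step
print(out.shape, k.shape, v.shape)  # (4, 8) for all three: (seq_len, embed_dim)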
