Skip to content

Commit e9eb453

Browse files
[skip ci] Auto-format code with clang-format and black
1 parent 7d31268 commit e9eb453

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

csrc/lib/ops/matmul/op.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ __global__ void matMulKernel(ScalarT* M, ScalarT* N, ScalarT* P, int32_t Width)
2020

2121
fp32_t Pvalue = 0.0F;
2222
for (int32_t ph = 0; ph < Width / TILE_SIZE; ++ph) {
23-
Mds[threadIdx.y][threadIdx.x] = M[Row * Width + (ph * TILE_SIZE + threadIdx.x)];
24-
Nds[threadIdx.y][threadIdx.x] = N[(ph * TILE_SIZE + threadIdx.y) * Width + Col];
23+
Mds[threadIdx.y][threadIdx.x] =
24+
M[Row * Width + (ph * TILE_SIZE + threadIdx.x)];
25+
Nds[threadIdx.y][threadIdx.x] =
26+
N[(ph * TILE_SIZE + threadIdx.y) * Width + Col];
2527
__syncthreads();
2628

2729
for (int32_t k = 0; k < TILE_SIZE; ++k) {

src/pmpp/models/attention.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,7 @@ def forward(
6262
# v -> (num_heads, kv_len, head_size)
6363
v = v.view(kv_len, self.num_heads, self.head_size).transpose(0, 1)
6464
# scores -> (num_heads, q_len, kv_len)
65-
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(
66-
self.head_size
67-
)
65+
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_size)
6866
# scores -> (num_heads, q_len, kv_len)
6967
scores = scores + mask if mask is not None else scores
7068
# scores -> (num_heads, q_len, kv_len)
@@ -247,9 +245,7 @@ def forward(
247245
# out -> (seq_len, embed_dim)
248246
# k -> (kv_len, embed_dim)
249247
# v -> (kv_len, embed_dim)
250-
out, k, v = self.cached_mha(
251-
embedded_prompt, self.k_cache, self.v_cache
252-
)
248+
out, k, v = self.cached_mha(embedded_prompt, self.k_cache, self.v_cache)
253249

254250
# Update k cache and v cache
255251
# [NOTE]

0 commit comments

Comments
 (0)