11#include < cuda_runtime.h>
22
3- #include " pmpp/types/cxx_types.hpp"
3+ #include " ../ops.hpp"
4+ #include " pmpp/utils/math.hpp"
45
56namespace pmpp ::ops::cuda
67{
78/* *
8- * Assumes:
9- * 1. M, N, P are square matrices of size width x width;
10- * 2. Each thread computes one element;
9+ * @brief Matrix multiplication kernel
10+ *
11+ * @note 1. A, B, C are square matrices of size (m, m);
12+ * 2. Each thread computes 1 element of C and each block computes
13+ * (TILE_SIZE, TILE_SIZE) elements of C, which means block size should
14+ * be (TILE_SIZE, TILE_SIZE);
15+ * @todo Add boundary checks.
1116 */
12- template <int32_t TILE_SIZE = 16 , typename ScalarT = fp32_t >
13- __global__ void matMulKernel (ScalarT* M, ScalarT* N, ScalarT* P, int32_t Width)
17+ template <int32_t TILE_SIZE = 32 , typename ScalarT = fp32_t >
18+ __global__ void matmulKernel (const ScalarT* A, const ScalarT* B, ScalarT* C,
19+ int32_t m)
1420{
1521 __shared__ ScalarT Mds[TILE_SIZE][TILE_SIZE];
1622 __shared__ ScalarT Nds[TILE_SIZE][TILE_SIZE];
1723
18- int32_t Row = blockIdx .y * TILE_SIZE + threadIdx .y ;
19- int32_t Col = blockIdx .x * TILE_SIZE + threadIdx .x ;
24+ int32_t row = blockIdx .x * TILE_SIZE + threadIdx .x ;
25+ int32_t col = blockIdx .y * TILE_SIZE + threadIdx .y ;
2026
21- fp32_t Pvalue = 0 .0F ;
22- for (int32_t ph = 0 ; ph < Width / TILE_SIZE; ++ph) {
23- Mds[threadIdx .y ][threadIdx .x ] =
24- M[Row * Width + (ph * TILE_SIZE + threadIdx .x )];
25- Nds[threadIdx .y ][threadIdx .x ] =
26- N [(ph * TILE_SIZE + threadIdx .y ) * Width + Col ];
27+ ScalarT tmp = 0 .0F ;
28+ for (int32_t ph = 0 ; ph < m / TILE_SIZE; ++ph) {
29+ Mds[threadIdx .x ][threadIdx .y ] =
30+ A[row * m + (ph * TILE_SIZE + threadIdx .y )];
31+ Nds[threadIdx .x ][threadIdx .y ] =
32+ B [(ph * TILE_SIZE + threadIdx .x ) * m + col ];
2733 __syncthreads ();
2834
2935 for (int32_t k = 0 ; k < TILE_SIZE; ++k) {
30- Pvalue += Mds[threadIdx .y ][k] * Nds[k][threadIdx .x ];
36+ tmp += Mds[threadIdx .x ][k] * Nds[k][threadIdx .y ];
3137 }
3238 __syncthreads ();
3339 }
3440
35- P[Row * Width + Col] = Pvalue;
41+ C[row * m + col] = tmp;
42+ }
43+
44+ void launchMatmul (const fp32_t * dA, const fp32_t * dB, fp32_t * dC, size_t m)
45+ {
46+ constexpr uint32_t tileSize = 32 ;
47+
48+ dim3 blockSize = {tileSize, tileSize};
49+ dim3 gridSize = {uint32_t (ceilDiv (m, tileSize)),
50+ uint32_t (ceilDiv (m, tileSize))};
51+
52+ matmulKernel<tileSize, fp32_t >
53+ <<<gridSize, blockSize>>> (dA, dB, dC, int32_t (m));
3654}
3755} // namespace pmpp::ops::cuda
0 commit comments