|
8 | 8 | namespace pmpp::ops::cuda |
9 | 9 | { |
10 | 10 |
|
11 | | -template <typename ScalarT> |
| 11 | +constexpr int32_t MAX_CONV2D_KERNEL_SIZE = 9; |
| 12 | +__constant__ fp32_t |
| 13 | + CONV2D_KERNEL[MAX_CONV2D_KERNEL_SIZE * MAX_CONV2D_KERNEL_SIZE]; |
| 14 | + |
| 15 | +template <typename ScalarT, uint32_t IN_TILE_SIZE = 32> |
12 | 16 | __global__ void conv2DKernel(const ScalarT* input, const ScalarT* kernel, |
13 | | - ScalarT* output, int32_t input_height, |
14 | | - int32_t input_width, int32_t kernel_size) |
| 17 | + ScalarT* output, int32_t inHeight, |
| 18 | + int32_t inWidth, int32_t kernelSize) |
15 | 19 | { |
16 | | - int32_t outRow = blockIdx.x * blockDim.x + threadIdx.x; |
17 | | - int32_t outCol = blockIdx.y * blockDim.y + threadIdx.y; |
18 | | - |
19 | | - ScalarT tmp = 0; |
20 | | - for (int32_t kRow = 0; kRow < kernel_size; ++kRow) { |
21 | | - for (int32_t kCol = 0; kCol < kernel_size; ++kCol) { |
22 | | - int32_t inRow = outRow + kRow - kernel_size / 2; |
23 | | - int32_t inCol = outCol + kCol - kernel_size / 2; |
24 | | - if (inRow >= 0 && inRow < input_height && inCol >= 0 && |
25 | | - inCol < input_width) { |
26 | | - tmp += input[computeOffset<int32_t>(inRow, inCol, input_width, |
27 | | - input_width)] * |
28 | | - kernel[computeOffset<int32_t>(kRow, kCol, kernel_size, |
29 | | - kernel_size)]; |
| 20 | + uint32_t OUT_TILE_SIZE = IN_TILE_SIZE - kernelSize / 2 * 2; |
| 21 | + |
| 22 | + int32_t outRow = blockIdx.x * OUT_TILE_SIZE + threadIdx.x - kernelSize / 2; |
| 23 | + int32_t outCol = blockIdx.y * OUT_TILE_SIZE + threadIdx.y - kernelSize / 2; |
| 24 | + |
  | 25 | +  // [NOTE] IN_TILE_SIZE must equal blockDim.x and blockDim.y.
| 26 | + __shared__ ScalarT inTile[IN_TILE_SIZE][IN_TILE_SIZE]; |
| 27 | + |
| 28 | + if (outRow >= 0 && outRow < inHeight && outCol >= 0 && outCol < inWidth) { |
| 29 | + inTile[threadIdx.x][threadIdx.y] = |
  | 30 | +        input[computeOffset<uint32_t>(outRow, outCol, inWidth, inWidth)];
| 31 | + } else { |
  | 32 | +    inTile[threadIdx.x][threadIdx.y] = ScalarT(0);
| 33 | + } |
| 34 | + __syncthreads(); |
| 35 | + |
| 36 | + int32_t outTileRow = threadIdx.x - kernelSize / 2; |
| 37 | + int32_t outTileCol = threadIdx.y - kernelSize / 2; |
| 38 | + |
| 39 | + if (outRow >= 0 && outRow < inHeight && outCol >= 0 && outCol < inWidth) { |
| 40 | + if (outTileRow >= 0 && outTileRow < OUT_TILE_SIZE && outTileCol >= 0 && |
| 41 | + outTileCol < OUT_TILE_SIZE) { |
| 42 | + ScalarT tmp = 0; |
| 43 | + for (int32_t kRow = 0; kRow < kernelSize; ++kRow) { |
| 44 | + for (int32_t kCol = 0; kCol < kernelSize; ++kCol) { |
| 45 | + tmp += CONV2D_KERNEL[computeOffset<uint32_t>( |
| 46 | + kRow, kCol, kernelSize, kernelSize)] * |
| 47 | + inTile[kRow + outTileRow][kCol + outTileCol]; |
| 48 | + } |
30 | 49 | } |
| 50 | + output[computeOffset<uint32_t>(outRow, outCol, inWidth, inWidth)] = |
| 51 | + tmp; |
31 | 52 | } |
32 | 53 | } |
33 | | - output[computeOffset<int32_t>(outRow, outCol, input_width, input_width)] = |
34 | | - tmp; |
35 | 54 | } |
36 | 55 |
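
For reference, the kernel above leans on two small helpers that live elsewhere in the repo and are not part of this diff. A minimal sketch of plausible definitions (assumed here, the actual pmpp implementations may differ) looks like:

  // Assumed sketch: row-major linear offset; only the column count (the last
  // argument) acts as the stride.
  template <typename T>
  __host__ __device__ constexpr T computeOffset(T row, T col, T /*nRows*/, T nCols)
  {
      return row * nCols + col;
  }

  // Assumed sketch: ceiling division used to size the launch grid.
  template <typename T, typename U>
  __host__ __device__ constexpr auto ceilDiv(T a, U b)
  {
      return (a + b - 1) / b;
  }
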
|
37 | 56 | template <> |
38 | 57 | void launchConv2D<fp32_t>(const fp32_t* d_input, const fp32_t* d_kernel, |
39 | 58 | fp32_t* d_output, int32_t inputHeight, |
40 | 59 | int32_t inputWidth, int32_t kernelSize) |
41 | 60 | { |
42 | | - dim3 blockSize = {32, 32, 1}; |
43 | | - dim3 gridSize = {uint32_t(ceilDiv(inputWidth, blockSize.x)), |
44 | | - uint32_t(ceilDiv(inputHeight, blockSize.y))}; |
45 | | - conv2DKernel<<<gridSize, blockSize>>>(d_input, d_kernel, d_output, |
46 | | - inputHeight, inputWidth, kernelSize); |
| 61 | + if (kernelSize > MAX_CONV2D_KERNEL_SIZE) { |
| 62 | + throw std::runtime_error("Kernel size is too large"); |
| 63 | + } |
| 64 | + |
  | 65 | +  cudaMemcpyToSymbol(CONV2D_KERNEL, d_kernel,
  | 66 | +                     kernelSize * kernelSize * sizeof(fp32_t), 0,
  | 67 | +                     cudaMemcpyDeviceToDevice);
  | 68 | +
  | 69 | +  // The grid is sized by the output tile (each block computes outTileSize^2
  | 70 | +  // results); gridDim.x covers rows since the kernel maps threadIdx.x to rows.
  | 71 | +  uint32_t outTileSize = 32 - uint32_t(kernelSize / 2) * 2;
  | 72 | +  dim3 blockDim = {32, 32, 1};
  | 73 | +  dim3 gridDim = {uint32_t(ceilDiv(inputHeight, outTileSize)),
  | 74 | +                  uint32_t(ceilDiv(inputWidth, outTileSize))};
  | 75 | +  conv2DKernel<fp32_t, 32><<<gridDim, blockDim>>>(
  | 76 | +      d_input, d_kernel, d_output, inputHeight, inputWidth, kernelSize);
47 | 73 |
|
48 | 74 | PMPP_DEBUG_CUDA_ERR_CHECK(cudaGetLastError()); |
49 | 75 | } |
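
As a usage sketch (not part of this commit), a host-side caller might look like the following; the function and buffer names, and the assumption that fp32_t aliases float, are illustrative only:

  // Hypothetical host-side driver for launchConv2D<fp32_t>; error checking of
  // the CUDA calls is omitted for brevity.
  void runConv2D(const fp32_t* h_input, const fp32_t* h_kernel, fp32_t* h_output,
                 int32_t height, int32_t width, int32_t kernelSize)
  {
      fp32_t *d_input = nullptr, *d_kernel = nullptr, *d_output = nullptr;
      size_t imageBytes = size_t(height) * size_t(width) * sizeof(fp32_t);
      size_t kernelBytes = size_t(kernelSize) * size_t(kernelSize) * sizeof(fp32_t);

      cudaMalloc(&d_input, imageBytes);
      cudaMalloc(&d_kernel, kernelBytes);
      cudaMalloc(&d_output, imageBytes);

      cudaMemcpy(d_input, h_input, imageBytes, cudaMemcpyHostToDevice);
      cudaMemcpy(d_kernel, h_kernel, kernelBytes, cudaMemcpyHostToDevice);

      // The launcher copies the kernel weights into constant memory and runs
      // the tiled convolution.
      pmpp::ops::cuda::launchConv2D<fp32_t>(d_input, d_kernel, d_output, height,
                                            width, kernelSize);

      cudaMemcpy(h_output, d_output, imageBytes, cudaMemcpyDeviceToHost);

      cudaFree(d_input);
      cudaFree(d_kernel);
      cudaFree(d_output);
  }
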
|