@@ -9,47 +9,60 @@ namespace pmpp::ops::cuda
99{
1010
1111constexpr int32_t MAX_CONV2D_KERNEL_SIZE = 9 ;
12- __constant__ fp32_t
13- CONV2D_KERNEL[MAX_CONV2D_KERNEL_SIZE * MAX_CONV2D_KERNEL_SIZE];
12+ __constant__ fp32_t KERNEL[MAX_CONV2D_KERNEL_SIZE * MAX_CONV2D_KERNEL_SIZE];
1413
15- template <typename ScalarT, uint32_t IN_TILE_SIZE = 32 >
14+ template <typename ScalarT, uint32_t TILE_SIZE = 32 >
1615__global__ void conv2DKernel (const ScalarT* input, const ScalarT* kernel,
17- ScalarT* output, int32_t inHeight ,
18- int32_t inWidth, int32_t kernelSize)
16+ ScalarT* output, int32_t nRows, int32_t nCols ,
17+ int32_t kernelSize)
1918{
20- uint32_t OUT_TILE_SIZE = IN_TILE_SIZE - kernelSize / 2 * 2 ;
19+ // Each block computes (TILE_SIZE, TILE_SIZE) output elements
20+ // Each block contains (TILE_SIZE, TILE_SIZE) threads
21+ // TILE_SIZE must be equal to blockDim.x and blockDim.y
2122
22- int32_t outRow = blockIdx .x * OUT_TILE_SIZE + threadIdx .x - kernelSize / 2 ;
23- int32_t outCol = blockIdx .y * OUT_TILE_SIZE + threadIdx .y - kernelSize / 2 ;
23+ // Current thread computes element at output[outRow, outCol]
24+ int32_t outRow = blockIdx .x * TILE_SIZE + threadIdx .x ;
25+ int32_t outCol = blockIdx .y * TILE_SIZE + threadIdx .y ;
2426
25- // [NOTE] IN_TILE_SIZE must equal to blockDim.x and blockDim.y
26- __shared__ ScalarT inTile[IN_TILE_SIZE][IN_TILE_SIZE];
27-
28- if (outRow >= 0 && outRow < inHeight && outCol >= 0 && outCol < inWidth) {
27+ __shared__ ScalarT inTile[TILE_SIZE][TILE_SIZE];
28+ // Load input tile into shared memory
29+ if (outRow < nRows && outCol < nCols) {
2930 inTile[threadIdx .x ][threadIdx .y ] =
30- input[computeOffset <uint32_t >(outRow, outCol, inWidth, inHeight )];
31+ input[offset <uint32_t >(outRow, outCol, nRows, nCols )];
3132 } else {
3233 inTile[threadIdx .x ][threadIdx .y ] = 0.0 ;
3334 }
3435 __syncthreads ();
3536
36- int32_t outTileRow = threadIdx .x - kernelSize / 2 ;
37- int32_t outTileCol = threadIdx .y - kernelSize / 2 ;
38-
39- if (outRow >= 0 && outRow < inHeight && outCol >= 0 && outCol < inWidth) {
40- if (outTileRow >= 0 && outTileRow < OUT_TILE_SIZE && outTileCol >= 0 &&
41- outTileCol < OUT_TILE_SIZE) {
42- ScalarT tmp = 0 ;
43- for (int32_t kRow = 0 ; kRow < kernelSize; ++kRow ) {
44- for (int32_t kCol = 0 ; kCol < kernelSize; ++kCol ) {
45- tmp += CONV2D_KERNEL[computeOffset<uint32_t >(
46- kRow , kCol , kernelSize, kernelSize)] *
47- inTile[kRow + outTileRow][kCol + outTileCol];
37+ if (outRow < nRows && outCol < nCols) {
38+ ScalarT tmp = 0 ;
39+ // To compute one output element, each thread needs to loop over the
40+ // kernel:
41+ for (int32_t kRow = 0 ; kRow < kernelSize; ++kRow ) {
42+ for (int32_t kCol = 0 ; kCol < kernelSize; ++kCol ) {
43+ // Index, relative to the input tile, of the input element under this kernel element
44+ int32_t rkInRow = threadIdx .x - kernelSize / 2 + kRow ;
45+ int32_t rkInCol = threadIdx .y - kernelSize / 2 + kCol ;
46+ if (rkInRow >= 0 && rkInRow < TILE_SIZE && rkInCol >= 0 &&
47+ rkInCol < TILE_SIZE) {
48+ tmp += inTile[rkInRow][rkInCol] *
49+ KERNEL[offset<uint32_t >(kRow , kCol , kernelSize,
50+ kernelSize)];
51+ } else {
52+ // Neighbour lies outside this block's shared tile: read it from `input` directly
53+ int32_t inRow = outRow - kernelSize / 2 + kRow ;
54+ int32_t inCol = outCol - kernelSize / 2 + kCol ;
55+ if (inRow >= 0 && inRow < nRows && inCol >= 0 &&
56+ inCol < nCols) {
57+ tmp += input[offset<uint32_t >(inRow, inCol, nRows,
58+ nCols)] *
59+ KERNEL[offset<uint32_t >(kRow , kCol , kernelSize,
60+ kernelSize)];
61+ }
4862 }
4963 }
50- output[computeOffset<uint32_t >(outRow, outCol, inWidth, inWidth)] =
51- tmp;
5264 }
65+ output[offset<uint32_t >(outRow, outCol, nRows, nCols)] = tmp;
5366 }
5467}
5568
@@ -62,7 +75,7 @@ void launchConv2d<fp32_t>(const fp32_t* d_input, const fp32_t* d_kernel,
6275 throw std::runtime_error (" Kernel size is too large" );
6376 }
6477
65- cudaMemcpyToSymbol (CONV2D_KERNEL , d_kernel,
78+ cudaMemcpyToSymbol (KERNEL , d_kernel,
6679 kernelSize * kernelSize * sizeof (fp32_t ));
6780
6881 dim3 blockDim = {32 , 32 , 1 };
0 commit comments