1+ #include < cuda_runtime.h>
2+
3+ #include " ../ops.hpp"
4+ #include " pmpp/utils/address.hpp"
5+ #include " pmpp/utils/common.cuh"
6+ #include " pmpp/utils/math.hpp"
7+
8+ namespace pmpp ::ops::cuda
9+ {
10+
11+ template <typename ScalarT, dim3 TILE_DIM>
12+ __global__ void stencilKernel (const ScalarT* input, ScalarT* output,
13+ dim3 shape, const std::array<ScalarT, 7 >& coeffs)
14+ {
15+ int32_t iStart = blockIdx .z * TILE_DIM.z ;
16+ int32_t j = blockIdx .y * TILE_DIM.y + threadIdx .y - 1 ;
17+ int32_t k = blockIdx .x * TILE_DIM.x + threadIdx .x - 1 ;
18+
19+ __shared__ ScalarT inPrev_s[TILE_DIM.x ][TILE_DIM.y ];
20+ __shared__ ScalarT inCurr_s[TILE_DIM.x ][TILE_DIM.y ];
21+ __shared__ ScalarT inNext_s[TILE_DIM.x ][TILE_DIM.y ];
22+
23+ if (iStart - 1 >= 0 && iStart - 1 < shape.z && j >= 0 && j < shape.y &&
24+ k >= 0 && k < shape.x ) {
25+ inPrev_s[threadIdx .y ][threadIdx .x ] = input[offset<uint32_t >(
26+ iStart - 1 , j, k, shape.z , shape.y , shape.x )];
27+ }
28+
29+ if (iStart >= 0 && iStart < shape.z && j >= 0 && j < shape.y && k >= 0 &&
30+ k < shape.x ) {
31+ inCurr_s[threadIdx .y ][threadIdx .x ] =
32+ input[offset<uint32_t >(iStart, j, k, shape.z , shape.y , shape.x )];
33+ }
34+
35+ for (int32_t i = iStart; i < iStart + TILE_DIM.z ; ++i) {
36+ if (i + 1 >= 0 && i + 1 < shape.z && j >= 0 && j < shape.y && k >= 0 &&
37+ k < shape.x ) {
38+ inNext_s[threadIdx .y ][threadIdx .x ] = input[offset<uint32_t >(
39+ i + 1 , j, k, shape.z , shape.y , shape.x )];
40+ }
41+ __syncthreads ();
42+ if (i >= 1 && i < shape.z - 1 && j >= 1 && j < shape.y - 1 && k >= 1 &&
43+ k < shape.x - 1 ) {
44+ if (threadIdx .y >= 1 && threadIdx .y < TILE_DIM.y - 1 &&
45+ threadIdx .x >= 1 && threadIdx .x < TILE_DIM.x - 1 ) {
46+ output[offset<uint32_t >(i, j, k, shape.z , shape.y , shape.x )] =
47+ coeffs[0 ] * inCurr_s[threadIdx .y ][threadIdx .x ] +
48+ coeffs[1 ] * inCurr_s[threadIdx .y ][threadIdx .x - 1 ] +
49+ coeffs[2 ] * inCurr_s[threadIdx .y ][threadIdx .x + 1 ] +
50+ coeffs[3 ] * inCurr_s[threadIdx .y - 1 ][threadIdx .x ] +
51+ coeffs[4 ] * inCurr_s[threadIdx .y + 1 ][threadIdx .x ] +
52+ coeffs[5 ] * inPrev_s[threadIdx .y ][threadIdx .x ] +
53+ coeffs[6 ] * inNext_s[threadIdx .y ][threadIdx .x ];
54+ }
55+ }
56+ __syncthreads ();
57+ inPrev_s[threadIdx .y ][threadIdx .x ] =
58+ inCurr_s[threadIdx .y ][threadIdx .x ];
59+ inCurr_s[threadIdx .y ][threadIdx .x ] =
60+ inNext_s[threadIdx .y ][threadIdx .x ];
61+ }
62+ }
63+
64+ template <>
65+ void launchStencil3d (const fp32_t * input, fp32_t * output, dim3 shape,
66+ const std::array<fp32_t , 7 >& coeffs)
67+ {
68+ constexpr dim3 BLOCK_DIM = {8 , 8 , 8 };
69+ dim3 gridDim = {ceilDiv (shape.x , BLOCK_DIM.x ),
70+ ceilDiv (shape.y , BLOCK_DIM.y ),
71+ ceilDiv (shape.z , BLOCK_DIM.z )};
72+
73+
74+
75+ stencilKernel<fp32_t , BLOCK_DIM>
76+ <<<gridDim , BLOCK_DIM>>> (input, output, shape, coeffs);
77+
78+ PMPP_DEBUG_CUDA_ERR_CHECK (cudaGetLastError ());
79+ }
80+
81+ } // namespace pmpp::ops::cuda
0 commit comments