
Commit 8626a8d: Update Tiled Conv

Committed by root · 1 parent 344fac9

File tree: 6 files changed, +64 −33 lines

.github/workflows/ci-auto-format-and-commit.yml

Lines changed: 2 additions & 1 deletion

@@ -29,7 +29,8 @@ jobs:
       - name: Install formatter
         shell: bash
         run: |
-          wget https://apt.llvm.org/llvm.sh && sudo bash ./llvm.sh 20 && rm ./llvm.sh
+          wget https://apt.llvm.org/llvm.sh && chmod +x ./llvm.sh && ./llvm.sh 20
+          sudo apt-get update
           sudo apt-get install clang-format-20
           sudo ln -sf $(which clang-format-20) /usr/bin/clang-format
           python -m pip install black

configs/ctests.yml

Lines changed: 4 additions & 4 deletions

@@ -1,8 +1,8 @@
 OpTest:
   Conv2D:
-    - inputHeight: 32
-      inputWidth: 32
+    - inputHeight: 30
+      inputWidth: 30
       kernelSize: 3
-    - inputHeight: 320
-      inputWidth: 320
+    - inputHeight: 300
+      inputWidth: 300
       kernelSize: 3
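
A note on the new sizes (an inference, not stated in the commit): they appear chosen to tile evenly under the tiled CUDA kernel below. With a 32x32 input tile and a 3x3 kernel, integer division gives an output tile of 32 - (3 / 2) * 2 = 30, so ceilDiv(30, 32) = 1 block and ceilDiv(300, 32) = 10 blocks each write exactly 30 output columns, covering 30 and 300 exactly, where 32 and 320 would not divide cleanly. A standalone check of the arithmetic:

#include <cstdint>
#include <cstdio>

int main()
{
    constexpr int32_t IN_TILE_SIZE = 32; // matches the kernel's default below
    constexpr int32_t kernelSize = 3;    // matches kernelSize: 3 above
    // Same expression as in op.cu; integer division makes 3 / 2 == 1,
    // so a one-cell halo on each side leaves a 30x30 output tile.
    constexpr int32_t OUT_TILE_SIZE = IN_TILE_SIZE - kernelSize / 2 * 2;
    std::printf("OUT_TILE_SIZE = %d\n", OUT_TILE_SIZE); // 30
    // ceilDiv(300, 32) == 10 blocks, each producing 30 output columns.
    std::printf("covered = %d\n", ((300 + 31) / 32) * OUT_TILE_SIZE); // 300
    return 0;
}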

csrc/lib/ops/conv2d/op.cpp

Lines changed: 6 additions & 4 deletions

@@ -17,12 +17,14 @@ void launchConv2D<fp32_t>(const fp32_t* input, const fp32_t* kernel,
                                  ? inputHeight - 1
                                  : i + kernelSize / 2;
             int32_t endCol = j + kernelSize / 2 >= inputWidth
-                ? inputWidth - 1
-                : j + kernelSize / 2;
-
+                                 ? inputWidth - 1
+                                 : j + kernelSize / 2;
+
             for (int32_t k = startRow; k <= endRow; ++k) {
                 for (int32_t l = startCol; l <= endCol; ++l) {
-                    tmp += input[k * inputWidth + l] * kernel[(k - i + kernelSize / 2) * kernelSize + (l - j + kernelSize / 2)];
+                    tmp += input[k * inputWidth + l] *
+                           kernel[(k - i + kernelSize / 2) * kernelSize +
+                                  (l - j + kernelSize / 2)];
                 }
             }
             output[i * inputWidth + j] = tmp;
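
For context, this hunk touches only the innermost loops of the CPU reference path; a sketch of the plausible surrounding function follows. The outer loops, the startRow/startCol clamps, and the fp32_t alias are assumptions reconstructed from the clamped endRow/endCol visible in the diff. Clamping the tap range to the image border is equivalent to zero padding, since the skipped taps contribute nothing.

#include <algorithm>
#include <cstdint>

using fp32_t = float; // assumption: the repo's fp32_t alias

// Sketch, not the repo's exact launchConv2D<fp32_t>: "same"-size 2D
// convolution with implicit zero padding via a clamped tap range.
void conv2DReference(const fp32_t* input, const fp32_t* kernel, fp32_t* output,
                     int32_t inputHeight, int32_t inputWidth,
                     int32_t kernelSize)
{
    for (int32_t i = 0; i < inputHeight; ++i) {
        for (int32_t j = 0; j < inputWidth; ++j) {
            fp32_t tmp = 0;
            int32_t startRow = std::max(i - kernelSize / 2, 0);
            int32_t startCol = std::max(j - kernelSize / 2, 0);
            int32_t endRow = std::min(i + kernelSize / 2, inputHeight - 1);
            int32_t endCol = std::min(j + kernelSize / 2, inputWidth - 1);
            for (int32_t k = startRow; k <= endRow; ++k) {
                for (int32_t l = startCol; l <= endCol; ++l) {
                    tmp += input[k * inputWidth + l] *
                           kernel[(k - i + kernelSize / 2) * kernelSize +
                                  (l - j + kernelSize / 2)];
                }
            }
            output[i * inputWidth + j] = tmp;
        }
    }
}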

csrc/lib/ops/conv2d/op.cu

Lines changed: 50 additions & 24 deletions

@@ -8,42 +8,68 @@
 namespace pmpp::ops::cuda
 {
 
-template <typename ScalarT>
+constexpr int32_t MAX_CONV2D_KERNEL_SIZE = 9;
+__constant__ fp32_t
+    CONV2D_KERNEL[MAX_CONV2D_KERNEL_SIZE * MAX_CONV2D_KERNEL_SIZE];
+
+template <typename ScalarT, uint32_t IN_TILE_SIZE = 32>
 __global__ void conv2DKernel(const ScalarT* input, const ScalarT* kernel,
-                             ScalarT* output, int32_t input_height,
-                             int32_t input_width, int32_t kernel_size)
+                             ScalarT* output, int32_t inHeight,
+                             int32_t inWidth, int32_t kernelSize)
 {
-    int32_t outRow = blockIdx.x * blockDim.x + threadIdx.x;
-    int32_t outCol = blockIdx.y * blockDim.y + threadIdx.y;
-
-    ScalarT tmp = 0;
-    for (int32_t kRow = 0; kRow < kernel_size; ++kRow) {
-        for (int32_t kCol = 0; kCol < kernel_size; ++kCol) {
-            int32_t inRow = outRow + kRow - kernel_size / 2;
-            int32_t inCol = outCol + kCol - kernel_size / 2;
-            if (inRow >= 0 && inRow < input_height && inCol >= 0 &&
-                inCol < input_width) {
-                tmp += input[computeOffset<int32_t>(inRow, inCol, input_width,
-                                                    input_width)] *
-                       kernel[computeOffset<int32_t>(kRow, kCol, kernel_size,
-                                                     kernel_size)];
+    uint32_t OUT_TILE_SIZE = IN_TILE_SIZE - kernelSize / 2 * 2;
+
+    int32_t outRow = blockIdx.x * OUT_TILE_SIZE + threadIdx.x - kernelSize / 2;
+    int32_t outCol = blockIdx.y * OUT_TILE_SIZE + threadIdx.y - kernelSize / 2;
+
+    // [NOTE] IN_TILE_SIZE must equal to blockDim.x and blockDim.y
+    __shared__ ScalarT inTile[IN_TILE_SIZE][IN_TILE_SIZE];
+
+    if (outRow >= 0 && outRow < inHeight && outCol >= 0 && outCol < inWidth) {
+        inTile[threadIdx.x][threadIdx.y] =
+            input[computeOffset<uint32_t>(outRow, outCol, inWidth, inHeight)];
+    } else {
+        inTile[threadIdx.x][threadIdx.y] = 0.0;
+    }
+    __syncthreads();
+
+    int32_t outTileRow = threadIdx.x - kernelSize / 2;
+    int32_t outTileCol = threadIdx.y - kernelSize / 2;
+
+    if (outRow >= 0 && outRow < inHeight && outCol >= 0 && outCol < inWidth) {
+        if (outTileRow >= 0 && outTileRow < OUT_TILE_SIZE && outTileCol >= 0 &&
+            outTileCol < OUT_TILE_SIZE) {
+            ScalarT tmp = 0;
+            for (int32_t kRow = 0; kRow < kernelSize; ++kRow) {
+                for (int32_t kCol = 0; kCol < kernelSize; ++kCol) {
+                    tmp += CONV2D_KERNEL[computeOffset<uint32_t>(
+                               kRow, kCol, kernelSize, kernelSize)] *
+                           inTile[kRow + outTileRow][kCol + outTileCol];
+                }
             }
+            output[computeOffset<uint32_t>(outRow, outCol, inWidth, inWidth)] =
+                tmp;
         }
     }
-    output[computeOffset<int32_t>(outRow, outCol, input_width, input_width)] =
-        tmp;
 }
 
 template <>
 void launchConv2D<fp32_t>(const fp32_t* d_input, const fp32_t* d_kernel,
                           fp32_t* d_output, int32_t inputHeight,
                           int32_t inputWidth, int32_t kernelSize)
 {
-    dim3 blockSize = {32, 32, 1};
-    dim3 gridSize = {uint32_t(ceilDiv(inputWidth, blockSize.x)),
-                     uint32_t(ceilDiv(inputHeight, blockSize.y))};
-    conv2DKernel<<<gridSize, blockSize>>>(d_input, d_kernel, d_output,
-                                          inputHeight, inputWidth, kernelSize);
+    if (kernelSize > MAX_CONV2D_KERNEL_SIZE) {
+        throw std::runtime_error("Kernel size is too large");
+    }
+
+    cudaMemcpyToSymbol(CONV2D_KERNEL, d_kernel,
+                       kernelSize * kernelSize * sizeof(fp32_t));
+
+    dim3 blockDim = {32, 32, 1};
+    dim3 gridDim = {uint32_t(ceilDiv(inputWidth, blockDim.x)),
+                    uint32_t(ceilDiv(inputHeight, blockDim.y))};
+    conv2DKernel<fp32_t, 32><<<gridDim, blockDim>>>(
+        d_input, d_kernel, d_output, inputHeight, inputWidth, kernelSize);
 
     PMPP_DEBUG_CUDA_ERR_CHECK(cudaGetLastError());
 }
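
Two properties of this tiled design worth spelling out. First, each 32x32 block computes only a 30x30 output tile for a 3x3 kernel, while the launch above sizes the grid by blockDim; that covers the 30- and 300-wide test inputs exactly (10 blocks x 30 columns = 300), but the textbook tiled launch sizes the grid by the output tile so arbitrary sizes are covered. A sketch under that convention (not from this commit), reusing ceilDiv and conv2DKernel from above:

// Sketch: grid sized by output tiles rather than by blockDim, so any
// inputWidth/inputHeight is fully covered by OUT_TILE_SIZE-wide writes.
constexpr uint32_t IN_TILE_SIZE = 32;
int32_t outTileSize = IN_TILE_SIZE - kernelSize / 2 * 2; // 30 for a 3x3 kernel

dim3 blockDim = {IN_TILE_SIZE, IN_TILE_SIZE, 1};
dim3 gridDim = {uint32_t(ceilDiv(inputWidth, outTileSize)),
                uint32_t(ceilDiv(inputHeight, outTileSize))};
conv2DKernel<fp32_t, IN_TILE_SIZE><<<gridDim, blockDim>>>(
    d_input, d_kernel, d_output, inputHeight, inputWidth, kernelSize);

Second, cudaMemcpyToSymbol's trailing kind parameter defaults to cudaMemcpyHostToDevice; if the d_ prefix on d_kernel means a device-resident source (an assumption from the naming convention), the device-to-device form would be:

// offset 0, then an explicit kind for a device-resident source.
cudaMemcpyToSymbol(CONV2D_KERNEL, d_kernel,
                   kernelSize * kernelSize * sizeof(fp32_t), 0,
                   cudaMemcpyDeviceToDevice);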

scripts/build.sh

Lines changed: 1 addition & 0 deletions

@@ -47,4 +47,5 @@ cmake -S $SOURCE_DIR -B $BUILD_DIR -G Ninja \
     -DVCPKG_TARGET_TRIPLET="x64-linux" \
     -DVCPKG_OVERLAY_TRIPLETS="csrc/cmake/vcpkg-triplets"
 
+GTEST_COLOR=yes \
 cmake --build $BUILD_DIR -j $(nproc) --target all check

test/test.py

Lines changed: 1 addition & 0 deletions

@@ -20,3 +20,4 @@
 print(pic_out_cuda.cpu())
 
 print(torch.ops.pmpp.matmul(torch.ones((32, 32)).cuda(), torch.ones((32, 32)).cuda()))
+
