Commit 2f40b5a

Update Reduction
1 parent 7a99eea commit 2f40b5a

4 files changed: 32 additions & 24 deletions


.vscode/settings.json

Lines changed: 1 addition & 2 deletions
@@ -26,9 +26,8 @@
     "--completion-style=bundled",
     "--cross-file-rename",
     "--header-insertion=never",
-    "--header-insertion-decorators",
     "--background-index",
-    "-j=16",
+    "-j=12",
     "--pch-storage=memory",
     "--function-arg-placeholders=false",
 ],

csrc/lib/ops/reduction/op.cuh

Lines changed: 27 additions & 16 deletions
@@ -8,34 +8,46 @@ namespace pmpp::ops::cuda
 {
 
 template <typename ScalarT, typename PredT>
-__global__ void reductionKernel(ScalarT* in, ScalarT* out, const PredT& pred)
+__global__ void reductionKernel(const ScalarT* in, uint32_t n, ScalarT* out,
+                                const PredT& pred)
 {
     // Thread index in the block
-    int32_t bTid = threadIdx.x;
-    int32_t i = bTid * 2;
-    for (uint32_t stride = 1; stride < blockDim.x; stride *= 2) {
-        if (bTid % stride == 0) {
-            in[i] = pred(in[i], in[i + stride]);
-        }
+    uint32_t bTid = threadIdx.x;
+    extern __shared__ ScalarT shmem[];
+
+    uint32_t stride = blockDim.x;
+    shmem[bTid] = pred(in[bTid], in[bTid + stride]);
+    stride /= 2;
+
+    for (; stride >= 1; stride /= 2) {
         __syncthreads();
+        if (bTid < stride) {
+            shmem[bTid] = pred(shmem[bTid], shmem[bTid + stride]);
+        }
     }
     if (bTid == 0) {
-        out[blockIdx.x] = in[0];
+        out[0] = shmem[0];
     }
 }
 
 template <typename ScalarT, typename PredT>
-[[nodiscard]] auto launchReduction(ScalarT* in, int32_t n, const PredT& pred)
-    -> ScalarT
+[[nodiscard]] auto launchReduction(const ScalarT* in, uint32_t n,
+                                   const PredT& pred) -> ScalarT
 {
-    dim3 blockDim = {uint32_t(n), 1, 1};
-    dim3 gridDim = {uint32_t(ceilDiv(n, blockDim.x)), 1, 1};
     ScalarT* d_out;
-    cudaMalloc(&d_out, gridDim.x * sizeof(ScalarT));
-    reductionKernel<<<gridDim, blockDim>>>(in, d_out, pred);
+    cudaMalloc(&d_out, 1 * sizeof(ScalarT));
+
+    uint32_t nThreads = n / 2;
+    dim3 blockDim = {nThreads, 1, 1};
+    dim3 gridDim = {1, 1, 1};
+    uint32_t shmemSize = blockDim.x * sizeof(ScalarT);
+
+    reductionKernel<<<gridDim, blockDim, shmemSize>>>(in, n, d_out, pred);
+
     ScalarT out;
     cudaMemcpy(&out, d_out, sizeof(ScalarT), cudaMemcpyDeviceToHost);
     cudaFree(d_out);
+
     PMPP_DEBUG_CUDA_ERR_CHECK(cudaGetLastError());
 
     return out;
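
For reference, the rewritten kernel is the standard convergent shared-memory reduction: a single block of n/2 threads first folds one pair of inputs from global memory into shared memory, then the low half of the threads folds the shared array in half each iteration until shmem[0] holds the result. Because gridDim is fixed at {1, 1, 1}, this implicitly assumes n is a power of two and n/2 fits in one block. A minimal standalone sketch of the same pattern (a hypothetical sumReduce over plain float with +, not the templated kernel from this commit) might look like this:

// Single-block sum reduction sketch; assumes n is a power of two and
// n/2 does not exceed the maximum block size, as in the committed kernel.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void sumReduce(const float* in, float* out)
{
    extern __shared__ float shmem[];
    uint32_t tid = threadIdx.x;

    // First pass: each of the n/2 threads folds one pair from global memory.
    uint32_t stride = blockDim.x;
    shmem[tid] = in[tid] + in[tid + stride];
    stride /= 2;

    // Convergent tree reduction: halve the number of active threads each step.
    for (; stride >= 1; stride /= 2) {
        __syncthreads();
        if (tid < stride) {
            shmem[tid] += shmem[tid + stride];
        }
    }
    if (tid == 0) {
        out[0] = shmem[0];
    }
}

int main()
{
    constexpr uint32_t n = 1024;  // power of two
    float h_in[n];
    for (uint32_t i = 0; i < n; ++i) {
        h_in[i] = 1.0f;
    }

    float* d_in = nullptr;
    float* d_out = nullptr;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

    uint32_t nThreads = n / 2;
    sumReduce<<<1, nThreads, nThreads * sizeof(float)>>>(d_in, d_out);

    float h_out = 0.0f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("sum = %f (expected %u)\n", h_out, n);

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

Keeping the active threads packed at the low indices (tid < stride) lets whole warps retire early instead of diverging as they did under the old bTid % stride test, and because the partial results live in shared memory the input pointer can now be const.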
@@ -46,13 +58,12 @@ namespace torch_impl
 [[nodiscard]] inline auto mulReduction(const torch::Tensor& in)
     -> torch::Tensor
 {
-    torch::Tensor mutableIn = in.contiguous();
     torch::Tensor result = {};
 
     switch (in.scalar_type()) {
     case torch::kFloat32: {
         result =
-            torch::tensor(launchReduction(mutableIn.mutable_data_ptr<fp32_t>(),
+            torch::tensor(launchReduction(in.const_data_ptr<fp32_t>(),
                                           in.numel(), std::multiplies<>()),
                           in.options());
         break;
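
On the torch_impl side, the dispatch now reads the tensor through const_data_ptr, which matches the new const ScalarT* parameter of launchReduction and drops the mutable copy. The surrounding pattern is an ordinary switch on the scalar type; below is a self-contained sketch of that pattern, using a toy host-side hostSum over a contiguous CPU tensor rather than this commit's CUDA launcher, with a hypothetical kFloat64 case added.

// Sketch of the dtype-dispatch pattern only; hostSum, sumDispatch, and the
// kFloat64 case are illustrative assumptions, not code from this commit.
#include <torch/torch.h>

template <typename ScalarT>
ScalarT hostSum(const ScalarT* p, int64_t n)
{
    ScalarT acc = ScalarT(0);
    for (int64_t i = 0; i < n; ++i) {
        acc += p[i];
    }
    return acc;
}

torch::Tensor sumDispatch(const torch::Tensor& in)
{
    // Assumes a contiguous CPU tensor so the raw const pointer is valid.
    switch (in.scalar_type()) {
    case torch::kFloat32:
        return torch::tensor(hostSum(in.const_data_ptr<float>(), in.numel()),
                             in.options());
    case torch::kFloat64:
        return torch::tensor(hostSum(in.const_data_ptr<double>(), in.numel()),
                             in.options());
    default:
        TORCH_CHECK(false, "sumDispatch: unsupported dtype");
        return {};  // unreachable; keeps compilers quiet
    }
}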

csrc/test/OpTest/MulReduction.cpp

Lines changed: 3 additions & 5 deletions
@@ -16,15 +16,13 @@ TEST_F(OpTest, MulRedection)
 
     for (auto cfg : configs) {
         auto nInputs = cfg["nInputs"].as<pmpp::int64_t>();
-        Tensor input = torch::randint(1, 10, {nInputs}).to(torch::kFloat32);
+        Tensor input = torch::rand({nInputs}).to(torch::kFloat32) * 1.5 + 0.5;
 
         Tensor resultCPU = custom_op.call(input);
         Tensor resultCUDA = custom_op.call(input.cuda());
 
-        std::cout << resultCPU << std::endl;
-        std::cout << resultCUDA << std::endl;
-
-        EXPECT_TRUE(resultCPU.equal(resultCUDA.cpu()));
+        Tensor diff = resultCPU - resultCUDA.cpu();
+        EXPECT_LE(diff.abs().max().item<fp32_t>(), 1e-3);
     }
 }
 } // namespace pmpp::test::ops
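
A note on the relaxed check: the CUDA path reduces in tree order while the CPU reference folds sequentially, so the two float products can legitimately differ in the last bits, and the new input range of roughly [0.5, 2.0) keeps the running product from overflowing or underflowing. A hypothetical alternative to the manual max-abs-diff comparison, inside the same loop, would be torch::allclose, which applies a relative and an absolute bound at once:

// Hypothetical alternative to the manual max-abs-diff check above.
EXPECT_TRUE(torch::allclose(resultCUDA.cpu(), resultCPU,
                            /*rtol=*/1e-5, /*atol=*/1e-3));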

scripts/build.sh

Lines changed: 1 addition & 1 deletion
@@ -48,4 +48,4 @@ cmake -S $SOURCE_DIR -B $BUILD_DIR -G Ninja \
     -DVCPKG_OVERLAY_TRIPLETS="csrc/cmake/vcpkg-triplets"
 
 GTEST_COLOR=yes \
-cmake --build $BUILD_DIR --parallel 16 --target all check
+cmake --build $BUILD_DIR --parallel 12 --target all check
