@@ -8,34 +8,46 @@ namespace pmpp::ops::cuda
 {
 
 template <typename ScalarT, typename PredT>
-__global__ void reductionKernel(ScalarT* in, ScalarT* out, const PredT& pred)
+__global__ void reductionKernel(const ScalarT* in, uint32_t n, ScalarT* out,
+                                const PredT& pred)
 {
     // Thread index in the block
-    int32_t bTid = threadIdx.x;
-    int32_t i = bTid * 2;
-    for (uint32_t stride = 1; stride < blockDim.x; stride *= 2) {
-        if (bTid % stride == 0) {
-            in[i] = pred(in[i], in[i + stride]);
-        }
+    uint32_t bTid = threadIdx.x;
+    extern __shared__ ScalarT shmem[];
+
+    uint32_t stride = blockDim.x;
+    shmem[bTid] = pred(in[bTid], in[bTid + stride]);
+    stride /= 2;
+
+    for (; stride >= 1; stride /= 2) {
         __syncthreads();
+        if (bTid < stride) {
+            shmem[bTid] = pred(shmem[bTid], shmem[bTid + stride]);
+        }
     }
     if (bTid == 0) {
-        out[blockIdx.x] = in[0];
+        out[0] = shmem[0];
     }
 }
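The rewritten kernel swaps the old divergent, in-place scheme (strided `bTid % stride` updates through global memory) for a convergent tree reduction in dynamically sized shared memory: each thread first folds one pair of inputs while staging into `shmem`, and every following pass halves the number of active threads, keeping the survivors packed into the low warps. Two constraints are implicit: `n` must be a power of two (at most twice the maximum block size), and the host-side `pred` object is only safe to use from the kernel because stateless functors like `std::multiplies<>` never read through the reference. For spot-checking the kernel, a host-side reference fold could look like this (a sketch; `referenceReduction` is not part of the commit):

    // Left fold of pred over in[0..n) on the host, for validating
    // the device result on small inputs.
    template <typename ScalarT, typename PredT>
    ScalarT referenceReduction(const ScalarT* in, uint32_t n, const PredT& pred)
    {
        ScalarT acc = in[0];
        for (uint32_t i = 1; i < n; ++i) {
            acc = pred(acc, in[i]);
        }
        return acc;
    }

Since the tree reduction reassociates operations, floating-point results may differ from this sequential fold by rounding error.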
 
 template <typename ScalarT, typename PredT>
-[[nodiscard]] auto launchReduction(ScalarT* in, int32_t n, const PredT& pred)
-    -> ScalarT
+[[nodiscard]] auto launchReduction(const ScalarT* in, uint32_t n,
+                                   const PredT& pred) -> ScalarT
 {
-    dim3 blockDim = {uint32_t(n), 1, 1};
-    dim3 gridDim = {uint32_t(ceilDiv(n, blockDim.x)), 1, 1};
     ScalarT* d_out;
-    cudaMalloc(&d_out, gridDim.x * sizeof(ScalarT));
-    reductionKernel<<<gridDim, blockDim>>>(in, d_out, pred);
+    cudaMalloc(&d_out, 1 * sizeof(ScalarT));
+
+    uint32_t nThreads = n / 2;
+    dim3 blockDim = {nThreads, 1, 1};
+    dim3 gridDim = {1, 1, 1};
+    uint32_t shmemSize = blockDim.x * sizeof(ScalarT);
+
+    reductionKernel<<<gridDim, blockDim, shmemSize>>>(in, n, d_out, pred);
+
     ScalarT out;
     cudaMemcpy(&out, d_out, sizeof(ScalarT), cudaMemcpyDeviceToHost);
     cudaFree(d_out);
+
     PMPP_DEBUG_CUDA_ERR_CHECK(cudaGetLastError());
 
     return out;
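On the host side, the launcher now runs a single block of n / 2 threads with matching dynamic shared memory, so one launch consumes the entire input. A minimal call site might look like the following sketch (not part of the commit; the size 1024 and `std::plus<>` are arbitrary choices, and the build is assumed to allow standard functors in device code, as the `std::multiplies<>` call below already requires):

    #include <cuda_runtime.h>
    #include <functional>
    #include <vector>

    float sumExample()
    {
        // Stage 1024 ones on the device; 1024 is a power of two and fits
        // the single-block limit (at most 2048 elements with 1024 threads).
        std::vector<float> host(1024, 1.0f);
        float* d_in = nullptr;
        cudaMalloc(&d_in, host.size() * sizeof(float));
        cudaMemcpy(d_in, host.data(), host.size() * sizeof(float),
                   cudaMemcpyHostToDevice);

        float sum = launchReduction(d_in, 1024u, std::plus<>());

        cudaFree(d_in);
        return sum;  // expected: 1024.0f
    }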
@@ -46,13 +58,12 @@ namespace torch_impl
 [[nodiscard]] inline auto mulReduction(const torch::Tensor& in)
     -> torch::Tensor
 {
-    torch::Tensor mutableIn = in.contiguous();
     torch::Tensor result = {};
 
     switch (in.scalar_type()) {
     case torch::kFloat32: {
         result =
-            torch::tensor(launchReduction(mutableIn.mutable_data_ptr<fp32_t>(),
+            torch::tensor(launchReduction(in.const_data_ptr<fp32_t>(),
                                           in.numel(), std::multiplies<>()),
                           in.options());
         break;
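Dropping the `mutableIn` copy follows directly from the kernel change: the input is now read through a const pointer and never written, so `const_data_ptr` on the incoming tensor suffices. One caveat: without the old `contiguous()` call, the wrapper now assumes the caller already passes a contiguous CUDA tensor. A hypothetical call site (not part of the commit):

    // 1024 elements: contiguous, power of two, within the single-block limit.
    torch::Tensor x = torch::rand(
        {1024}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));
    torch::Tensor prod = torch_impl::mulReduction(x);  // product of all elements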