diff --git a/Grid/algorithms/blas/BatchedBlas.h b/Grid/algorithms/blas/BatchedBlas.h index bd01ab741b..580e8166c4 100644 --- a/Grid/algorithms/blas/BatchedBlas.h +++ b/Grid/algorithms/blas/BatchedBlas.h @@ -28,6 +28,7 @@ Author: Peter Boyle #pragma once #ifdef GRID_HIP +#include #include #endif #ifdef GRID_CUDA @@ -255,16 +256,29 @@ class GridBLAS { if ( OpB == GridBLAS_OP_N ) hOpB = HIPBLAS_OP_N; if ( OpB == GridBLAS_OP_T ) hOpB = HIPBLAS_OP_T; if ( OpB == GridBLAS_OP_C ) hOpB = HIPBLAS_OP_C; +#if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >=7) auto err = hipblasZgemmBatched(gridblasHandle, hOpA, hOpB, m,n,k, - (hipblasDoubleComplex *) &alpha_p[0], - (hipblasDoubleComplex **)&Amk[0], lda, - (hipblasDoubleComplex **)&Bkn[0], ldb, - (hipblasDoubleComplex *) &beta_p[0], - (hipblasDoubleComplex **)&Cmn[0], ldc, + (hipDoubleComplex *) &alpha_p[0], + (hipDoubleComplex **)&Amk[0], lda, + (hipDoubleComplex **)&Bkn[0], ldb, + (hipDoubleComplex *) &beta_p[0], + (hipDoubleComplex **)&Cmn[0], ldc, batchCount); +#else + auto err = hipblasZgemmBatched(gridblasHandle, + hOpA, + hOpB, + m,n,k, + (hipblasDoubleComplex *) &alpha_p[0], + (hipblasDoubleComplex **)&Amk[0], lda, + (hipblasDoubleComplex **)&Bkn[0], ldb, + (hipblasDoubleComplex *) &beta_p[0], + (hipblasDoubleComplex **)&Cmn[0], ldc, + batchCount); +#endif // std::cout << " hipblas return code " <<(int)err<=7) auto err = hipblasCgemmBatched(gridblasHandle, hOpA, hOpB, m,n,k, - (hipblasComplex *) &alpha_p[0], - (hipblasComplex **)&Amk[0], lda, - (hipblasComplex **)&Bkn[0], ldb, - (hipblasComplex *) &beta_p[0], - (hipblasComplex **)&Cmn[0], ldc, + (hipComplex *) &alpha_p[0], + (hipComplex **)&Amk[0], lda, + (hipComplex **)&Bkn[0], ldb, + (hipComplex *) &beta_p[0], + (hipComplex **)&Cmn[0], ldc, batchCount); +#else + auto err = hipblasCgemmBatched(gridblasHandle, + hOpA, + hOpB, + m,n,k, + (hipblasComplex *) &alpha_p[0], + (hipblasComplex **)&Amk[0], lda, + (hipblasComplex **)&Bkn[0], ldb, + (hipblasComplex *) &beta_p[0], + (hipblasComplex **)&Cmn[0], ldc, + batchCount); +#endif GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS); #endif #ifdef GRID_CUDA @@ -1094,11 +1121,19 @@ class GridBLAS { GRID_ASSERT(info.size()==batchCount); #ifdef GRID_HIP +#if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >=7) auto err = hipblasZgetrfBatched(gridblasHandle,(int)n, - (hipblasDoubleComplex **)&Ann[0], (int)n, + (hipDoubleComplex **)&Ann[0], (int)n, (int*) &ipiv[0], (int*) &info[0], (int)batchCount); +#else + auto err = hipblasZgetrfBatched(gridblasHandle,(int)n, + (hipblasDoubleComplex **)&Ann[0], (int)n, + (int*) &ipiv[0], + (int*) &info[0], + (int)batchCount); +#endif GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS); #endif #ifdef GRID_CUDA @@ -1124,11 +1159,20 @@ class GridBLAS { GRID_ASSERT(info.size()==batchCount); #ifdef GRID_HIP +#if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >=7) auto err = hipblasCgetrfBatched(gridblasHandle,(int)n, - (hipblasComplex **)&Ann[0], (int)n, + (hipComplex **)&Ann[0], (int)n, (int*) &ipiv[0], (int*) &info[0], (int)batchCount); +#else + auto err = hipblasCgetrfBatched(gridblasHandle,(int)n, + (hipblasComplex **)&Ann[0], (int)n, + (int*) &ipiv[0], + (int*) &info[0], + (int)batchCount); +#endif + GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS); #endif #ifdef GRID_CUDA @@ -1201,12 +1245,22 @@ class GridBLAS { GRID_ASSERT(Cnn.size()==batchCount); #ifdef GRID_HIP +#if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >=7) auto err = hipblasZgetriBatched(gridblasHandle,(int)n, - (hipblasDoubleComplex **)&Ann[0], (int)n, + (hipDoubleComplex **)&Ann[0], (int)n, (int*) &ipiv[0], - (hipblasDoubleComplex **)&Cnn[0], (int)n, + (hipDoubleComplex **)&Cnn[0], (int)n, (int*) &info[0], (int)batchCount); +#else + auto err = hipblasZgetriBatched(gridblasHandle,(int)n, + (hipblasDoubleComplex **)&Ann[0], (int)n, + (int*) &ipiv[0], + (hipblasDoubleComplex **)&Cnn[0], (int)n, + (int*) &info[0], + (int)batchCount); + +#endif GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS); #endif #ifdef GRID_CUDA @@ -1235,12 +1289,21 @@ class GridBLAS { GRID_ASSERT(Cnn.size()==batchCount); #ifdef GRID_HIP +#if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >=7) auto err = hipblasCgetriBatched(gridblasHandle,(int)n, - (hipblasComplex **)&Ann[0], (int)n, + (hipComplex **)&Ann[0], (int)n, (int*) &ipiv[0], - (hipblasComplex **)&Cnn[0], (int)n, + (hipComplex **)&Cnn[0], (int)n, (int*) &info[0], (int)batchCount); +#else + auto err = hipblasCgetriBatched(gridblasHandle,(int)n, + (hipblasComplex **)&Ann[0], (int)n, + (int*) &ipiv[0], + (hipblasComplex **)&Cnn[0], (int)n, + (int*) &info[0], + (int)batchCount); +#endif GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS); #endif #ifdef GRID_CUDA