diff --git a/src/NativeProviders/CUDA/lapack.cpp b/src/NativeProviders/CUDA/lapack.cpp index 718c4909..d0a7c96d 100644 --- a/src/NativeProviders/CUDA/lapack.cpp +++ b/src/NativeProviders/CUDA/lapack.cpp @@ -43,8 +43,8 @@ inline int lu_factor(cusolverDnHandle_t solverHandle, int m, T a[], int ipiv[], return info; }; -template -inline int lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle, int n, T a[], GETRF getrf, GETRI getri, GETRFBSIZE getrfbsize) +template +inline int lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle, int n, T a[], GETRF getrf, GETRIBATCHED getribatched, GETRFBSIZE getrfbsize) { int info = 0; @@ -63,13 +63,8 @@ inline int lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle int* d_info = NULL; cudaMalloc((void**)&d_info, sizeof(int)); - printf("initial %f %f %f %f %f %f %f %f %f\r\n", a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]); - getrf(solverHandle, n, n, d_A, n, work, d_I, d_info); cudaMemcpy(&info, d_info, 1, cudaMemcpyDeviceToHost); - - cublasGetMatrix(n, n, sizeof(T), d_A, n, a, n); - printf("after factor %f %f %f %f %f %f %f %f %f\r\n", a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]); cudaFree(work); @@ -83,20 +78,26 @@ inline int lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle T* d_C = NULL; cudaMalloc((void**)&d_C, n*n*sizeof(T)); - - getri(blasHandle, n, d_A, n, d_I, d_C, n, d_info); - cudaMemcpy(&info, d_info, 1, cudaMemcpyDeviceToHost); - cublasGetMatrix(n, n, sizeof(T), d_A, n, a, n); - printf("a inverse %f %f %f %f %f %f %f %f %f\r\n", a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]); + const T **d_Aarray = NULL; + cudaMalloc((void**)&d_Aarray, sizeof(T*)); + cudaMemcpy(d_Aarray, &d_A, sizeof(T*), cudaMemcpyHostToDevice); + + T **d_Carray = NULL; + cudaMalloc((void**)&d_Carray, sizeof(T*)); + cudaMemcpy(d_Carray, &d_C, sizeof(T*), cudaMemcpyHostToDevice); + + getribatched(blasHandle, n, d_Aarray, n, d_I, d_Carray, n, d_info, 1); + cudaMemcpy(&info, d_info, 1, cudaMemcpyDeviceToHost); cublasGetMatrix(n, n, sizeof(T), d_C, n, a, n); - printf("c inverse %f %f %f %f %f %f %f %f %f\r\n", a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]); cudaFree(d_A); cudaFree(d_I); cudaFree(d_C); cudaFree(d_info); + cudaFree(d_Aarray); + cudaFree(d_Carray); return info; }; @@ -122,7 +123,15 @@ inline int lu_inverse_factored(cublasHandle_t blasHandle, int n, T a[], int ipiv int* d_info = NULL; cudaMalloc((void**)&d_info, sizeof(int)); - getri(blasHandle, n, d_A, n, d_I, d_C, n, d_info); + const T **d_Aarray = NULL; + cudaMalloc((void**)&d_Aarray, sizeof(T*)); + cudaMemcpy(d_Aarray, &d_A, sizeof(T*), cudaMemcpyHostToDevice); + + T **d_Carray = NULL; + cudaMalloc((void**)&d_Carray, sizeof(T*)); + cudaMemcpy(d_Carray, &d_C, sizeof(T*), cudaMemcpyHostToDevice); + + getri(blasHandle, n, d_Aarray, n, d_I, d_Carray, n, d_info, 1); cudaMemcpy(&info, d_info, 1, cudaMemcpyDeviceToHost); cublasGetMatrix(n, n, sizeof(T), d_C, n, a, n); @@ -134,6 +143,8 @@ inline int lu_inverse_factored(cublasHandle_t blasHandle, int n, T a[], int ipiv cudaFree(d_I); cudaFree(d_C); cudaFree(d_info); + cudaFree(d_Aarray); + cudaFree(d_Carray); return info; } @@ -711,26 +722,10 @@ inline int complex_svd_factor(cusolverDnHandle_t solverHandle, bool compute_vect #define cgesvdbsize cusolverDnCgesvd_bufferSize #define zgesvdbsize cusolverDnZgesvd_bufferSize - -inline int sgetri(cublasHandle_t handle, int n, const float a[], int lda, const int ipiv[], float c[], int ldc, int *info) -{ - return cublasSgetriBatched(handle, n, &a, lda, ipiv, &c, ldc, info, 1); -} - -inline int dgetri(cublasHandle_t handle, int n, const double a[], int lda, const int ipiv[], double c[], int ldc, int *info) -{ - return cublasDgetriBatched(handle, n, &a, lda, ipiv, &c, ldc, info, 1); -} - -inline int cgetri(cublasHandle_t handle, int n, const cuComplex a[], int lda, const int ipiv[], cuComplex c[], int ldc, int *info) -{ - return cublasCgetriBatched(handle, n, &a, lda, ipiv, &c, ldc, info, 1); -} - -inline int zgetri(cublasHandle_t handle, int n, const cuDoubleComplex a[], int lda, const int ipiv[], cuDoubleComplex c[], int ldc, int *info) -{ - return cublasZgetriBatched(handle, n, &a, lda, ipiv, &c, ldc, info, 1); -} +#define sgetribatched cublasSgetriBatched +#define dgetribatched cublasDgetriBatched +#define cgetribatched cublasCgetriBatched +#define zgetribatched cublasZgetriBatched extern "C" { @@ -756,42 +751,42 @@ extern "C" { DLLEXPORT int s_lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle, int n, float a[]) { - return lu_inverse(solverHandle, blasHandle, n, a, sgetrf, sgetri, sgetrfbsize); + return lu_inverse(solverHandle, blasHandle, n, a, sgetrf, sgetribatched, sgetrfbsize); } DLLEXPORT int d_lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle, int n, double a[]) { - return lu_inverse(solverHandle, blasHandle, n, a, dgetrf, dgetri, dgetrfbsize); + return lu_inverse(solverHandle, blasHandle, n, a, dgetrf, dgetribatched, dgetrfbsize); } DLLEXPORT int c_lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle, int n, cuComplex a[]) { - return lu_inverse(solverHandle, blasHandle, n, a, cgetrf, cgetri, cgetrfbsize); + return lu_inverse(solverHandle, blasHandle, n, a, cgetrf, cgetribatched, cgetrfbsize); } DLLEXPORT int z_lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle, int n, cuDoubleComplex a[]) { - return lu_inverse(solverHandle, blasHandle, n, a, zgetrf, zgetri, zgetrfbsize); + return lu_inverse(solverHandle, blasHandle, n, a, zgetrf, zgetribatched, zgetrfbsize); } DLLEXPORT int s_lu_inverse_factored(cublasHandle_t blasHandle, int n, float a[], int ipiv[]) { - return lu_inverse_factored(blasHandle, n, a, ipiv, sgetri); + return lu_inverse_factored(blasHandle, n, a, ipiv, sgetribatched); } DLLEXPORT int d_lu_inverse_factored(cublasHandle_t blasHandle, int n, double a[], int ipiv[]) { - return lu_inverse_factored(blasHandle, n, a, ipiv, dgetri); + return lu_inverse_factored(blasHandle, n, a, ipiv, dgetribatched); } DLLEXPORT int c_lu_inverse_factored(cublasHandle_t blasHandle, int n, cuComplex a[], int ipiv[]) { - return lu_inverse_factored(blasHandle, n, a, ipiv, cgetri); + return lu_inverse_factored(blasHandle, n, a, ipiv, cgetribatched); } DLLEXPORT int z_lu_inverse_factored(cublasHandle_t blasHandle, int n, cuDoubleComplex a[], int ipiv[]) { - return lu_inverse_factored(blasHandle, n, a, ipiv, zgetri); + return lu_inverse_factored(blasHandle, n, a, ipiv, zgetribatched); } DLLEXPORT int s_lu_solve_factored(cusolverDnHandle_t solverHandle, int n, int nrhs, float a[], int ipiv[], float b[])