Browse Source

Fix for the matrix inverse bug.

pull/306/head
Matthew Johnson 11 years ago
parent
commit
547745dd22
  1. 79
      src/NativeProviders/CUDA/lapack.cpp

79
src/NativeProviders/CUDA/lapack.cpp

@ -43,8 +43,8 @@ inline int lu_factor(cusolverDnHandle_t solverHandle, int m, T a[], int ipiv[],
return info;
};
// Inverts the n-by-n matrix `a` using an LU factorization carried out on the device.
// The visible parts of the body call `getrf` with the cuSOLVER handle, then the
// inversion callable with the cuBLAS handle, copy a device matrix back into `a`,
// and return the `info` value read back from the device.
// NOTE(review): two template headers and two signatures follow one another here
// (GETRI/getri vs. GETRIBATCHED/getribatched). This looks like the old and the new
// form of a changed declaration shown together - confirm which one is current.
template<typename T, typename GETRF, typename GETRI, typename GETRFBSIZE>
inline int lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle, int n, T a[], GETRF getrf, GETRI getri, GETRFBSIZE getrfbsize)
template<typename T, typename GETRF, typename GETRIBATCHED, typename GETRFBSIZE>
inline int lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle, int n, T a[], GETRF getrf, GETRIBATCHED getribatched, GETRFBSIZE getrfbsize)
{
// Host copy of the status; stays zero until it is overwritten from the device.
int info = 0;
@ -63,13 +63,8 @@ inline int lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle
// Device-side status word (one int) that the factorization routine writes to.
int* d_info = NULL;
cudaMalloc((void**)&d_info, sizeof(int));
// NOTE(review): debug output - reads a[0]..a[8] regardless of n and formats T with %f.
printf("initial %f %f %f %f %f %f %f %f %f\r\n", a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]);
getrf(solverHandle, n, n, d_A, n, work, d_I, d_info);
// Transfer the complete int status; a byte count of 1 would copy only part of it.
cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost);
// Copy the contents of d_A back into the caller's array.
cublasGetMatrix(n, n, sizeof(T), d_A, n, a, n);
printf("after factor %f %f %f %f %f %f %f %f %f\r\n", a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]);
// The factorization workspace is released once getrf has been issued and its status read.
cudaFree(work);
@ -83,20 +78,26 @@ inline int lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle
// Device buffer of n*n elements of T that receives the inverse.
T* d_C = NULL;
cudaMalloc((void**)&d_C, n*n*sizeof(T));
// NOTE(review): both this direct call and the batched call further down are present;
// this looks like the pre- and post-change lines shown together - confirm which applies.
getri(blasHandle, n, d_A, n, d_I, d_C, n, d_info);
// Transfer the complete int status; a byte count of 1 would copy only part of it.
cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost);
cublasGetMatrix(n, n, sizeof(T), d_A, n, a, n);
printf("a inverse %f %f %f %f %f %f %f %f %f\r\n", a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]);
// One-element device arrays that hold the addresses of d_A and d_C, so the
// batched routine receives its matrix-pointer arrays from device memory.
const T **d_Aarray = NULL;
cudaMalloc((void**)&d_Aarray, sizeof(T*));
cudaMemcpy(d_Aarray, &d_A, sizeof(T*), cudaMemcpyHostToDevice);
T **d_Carray = NULL;
cudaMalloc((void**)&d_Carray, sizeof(T*));
cudaMemcpy(d_Carray, &d_C, sizeof(T*), cudaMemcpyHostToDevice);
// Batch size of one: a single matrix is inverted.
getribatched(blasHandle, n, d_Aarray, n, d_I, d_Carray, n, d_info, 1);
// Transfer the complete int status here as well.
cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost);
// Copy the contents of d_C back into the caller's array.
cublasGetMatrix(n, n, sizeof(T), d_C, n, a, n);
printf("c inverse %f %f %f %f %f %f %f %f %f\r\n", a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]);
// Release the device allocations used by this call.
cudaFree(d_A);
cudaFree(d_I);
cudaFree(d_C);
cudaFree(d_info);
cudaFree(d_Aarray);
cudaFree(d_Carray);
return info;
};
@ -122,7 +123,15 @@ inline int lu_inverse_factored(cublasHandle_t blasHandle, int n, T a[], int ipiv
// Device-side status word (one int) that the inversion routine writes to.
int* d_info = NULL;
cudaMalloc((void**)&d_info, sizeof(int));
// NOTE(review): both this direct call and the batched-style call further down are present;
// this looks like the pre- and post-change lines shown together - confirm which applies.
getri(blasHandle, n, d_A, n, d_I, d_C, n, d_info);
// One-element device arrays that hold the addresses of d_A and d_C, so the
// batched routine receives its matrix-pointer arrays from device memory.
const T **d_Aarray = NULL;
cudaMalloc((void**)&d_Aarray, sizeof(T*));
cudaMemcpy(d_Aarray, &d_A, sizeof(T*), cudaMemcpyHostToDevice);
T **d_Carray = NULL;
cudaMalloc((void**)&d_Carray, sizeof(T*));
cudaMemcpy(d_Carray, &d_C, sizeof(T*), cudaMemcpyHostToDevice);
// Batch size of one: a single matrix is inverted.
getri(blasHandle, n, d_Aarray, n, d_I, d_Carray, n, d_info, 1);
// Transfer the complete int status; a byte count of 1 would copy only part of it.
cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost);
// Copy the contents of d_C back into the caller's array.
cublasGetMatrix(n, n, sizeof(T), d_C, n, a, n);
@ -134,6 +143,8 @@ inline int lu_inverse_factored(cublasHandle_t blasHandle, int n, T a[], int ipiv
// Release the remaining device allocations, including the one-element
// pointer arrays, before handing the status back to the caller.
cudaFree(d_I);
cudaFree(d_C);
cudaFree(d_info);
cudaFree(d_Aarray);
cudaFree(d_Carray);
// `info` is the host-side status value.
return info;
}
@ -711,26 +722,10 @@ inline int complex_svd_factor(cusolverDnHandle_t solverHandle, bool compute_vect
// Short aliases for the cuSOLVER gesvd buffer-size query functions.
#define cgesvdbsize cusolverDnCgesvd_bufferSize
#define zgesvdbsize cusolverDnZgesvd_bufferSize
// Inverts a single LU-factored float matrix by calling cublasSgetriBatched
// with a batch size of one.
//   handle     cuBLAS handle
//   n          matrix order
//   a, lda     factored input matrix and its leading dimension
//   ipiv       pivot indices
//   c, ldc     output matrix and its leading dimension
//   info       status location filled in by cuBLAS
// Returns the cuBLAS status as an int, or CUBLAS_STATUS_ALLOC_FAILED when the
// temporary pointer arrays cannot be allocated.
inline int sgetri(cublasHandle_t handle, int n, const float a[], int lda, const int ipiv[], float c[], int ldc, int *info)
{
    // The batched routine expects its arrays of matrix pointers in device memory,
    // so the addresses of the host-side parameters `a` and `c` must not be passed.
    const float **d_Aarray = NULL;
    float **d_Carray = NULL;
    if (cudaMalloc((void**)&d_Aarray, sizeof(float*)) != cudaSuccess ||
        cudaMalloc((void**)&d_Carray, sizeof(float*)) != cudaSuccess)
    {
        // cudaFree accepts a null pointer, so both can be released unconditionally.
        cudaFree(d_Aarray);
        cudaFree(d_Carray);
        return CUBLAS_STATUS_ALLOC_FAILED;
    }
    cudaMemcpy(d_Aarray, &a, sizeof(float*), cudaMemcpyHostToDevice);
    cudaMemcpy(d_Carray, &c, sizeof(float*), cudaMemcpyHostToDevice);

    int status = cublasSgetriBatched(handle, n, d_Aarray, lda, ipiv, d_Carray, ldc, info, 1);

    cudaFree(d_Aarray);
    cudaFree(d_Carray);
    return status;
}
// Inverts a single LU-factored double matrix by calling cublasDgetriBatched
// with a batch size of one.
//   handle     cuBLAS handle
//   n          matrix order
//   a, lda     factored input matrix and its leading dimension
//   ipiv       pivot indices
//   c, ldc     output matrix and its leading dimension
//   info       status location filled in by cuBLAS
// Returns the cuBLAS status as an int, or CUBLAS_STATUS_ALLOC_FAILED when the
// temporary pointer arrays cannot be allocated.
inline int dgetri(cublasHandle_t handle, int n, const double a[], int lda, const int ipiv[], double c[], int ldc, int *info)
{
    // The batched routine expects its arrays of matrix pointers in device memory,
    // so the addresses of the host-side parameters `a` and `c` must not be passed.
    const double **d_Aarray = NULL;
    double **d_Carray = NULL;
    if (cudaMalloc((void**)&d_Aarray, sizeof(double*)) != cudaSuccess ||
        cudaMalloc((void**)&d_Carray, sizeof(double*)) != cudaSuccess)
    {
        // cudaFree accepts a null pointer, so both can be released unconditionally.
        cudaFree(d_Aarray);
        cudaFree(d_Carray);
        return CUBLAS_STATUS_ALLOC_FAILED;
    }
    cudaMemcpy(d_Aarray, &a, sizeof(double*), cudaMemcpyHostToDevice);
    cudaMemcpy(d_Carray, &c, sizeof(double*), cudaMemcpyHostToDevice);

    int status = cublasDgetriBatched(handle, n, d_Aarray, lda, ipiv, d_Carray, ldc, info, 1);

    cudaFree(d_Aarray);
    cudaFree(d_Carray);
    return status;
}
// Inverts a single LU-factored cuComplex matrix by calling cublasCgetriBatched
// with a batch size of one.
//   handle     cuBLAS handle
//   n          matrix order
//   a, lda     factored input matrix and its leading dimension
//   ipiv       pivot indices
//   c, ldc     output matrix and its leading dimension
//   info       status location filled in by cuBLAS
// Returns the cuBLAS status as an int, or CUBLAS_STATUS_ALLOC_FAILED when the
// temporary pointer arrays cannot be allocated.
inline int cgetri(cublasHandle_t handle, int n, const cuComplex a[], int lda, const int ipiv[], cuComplex c[], int ldc, int *info)
{
    // The batched routine expects its arrays of matrix pointers in device memory,
    // so the addresses of the host-side parameters `a` and `c` must not be passed.
    const cuComplex **d_Aarray = NULL;
    cuComplex **d_Carray = NULL;
    if (cudaMalloc((void**)&d_Aarray, sizeof(cuComplex*)) != cudaSuccess ||
        cudaMalloc((void**)&d_Carray, sizeof(cuComplex*)) != cudaSuccess)
    {
        // cudaFree accepts a null pointer, so both can be released unconditionally.
        cudaFree(d_Aarray);
        cudaFree(d_Carray);
        return CUBLAS_STATUS_ALLOC_FAILED;
    }
    cudaMemcpy(d_Aarray, &a, sizeof(cuComplex*), cudaMemcpyHostToDevice);
    cudaMemcpy(d_Carray, &c, sizeof(cuComplex*), cudaMemcpyHostToDevice);

    int status = cublasCgetriBatched(handle, n, d_Aarray, lda, ipiv, d_Carray, ldc, info, 1);

    cudaFree(d_Aarray);
    cudaFree(d_Carray);
    return status;
}
// Inverts a single LU-factored cuDoubleComplex matrix by calling
// cublasZgetriBatched with a batch size of one.
//   handle     cuBLAS handle
//   n          matrix order
//   a, lda     factored input matrix and its leading dimension
//   ipiv       pivot indices
//   c, ldc     output matrix and its leading dimension
//   info       status location filled in by cuBLAS
// Returns the cuBLAS status as an int, or CUBLAS_STATUS_ALLOC_FAILED when the
// temporary pointer arrays cannot be allocated.
inline int zgetri(cublasHandle_t handle, int n, const cuDoubleComplex a[], int lda, const int ipiv[], cuDoubleComplex c[], int ldc, int *info)
{
    // The batched routine expects its arrays of matrix pointers in device memory,
    // so the addresses of the host-side parameters `a` and `c` must not be passed.
    const cuDoubleComplex **d_Aarray = NULL;
    cuDoubleComplex **d_Carray = NULL;
    if (cudaMalloc((void**)&d_Aarray, sizeof(cuDoubleComplex*)) != cudaSuccess ||
        cudaMalloc((void**)&d_Carray, sizeof(cuDoubleComplex*)) != cudaSuccess)
    {
        // cudaFree accepts a null pointer, so both can be released unconditionally.
        cudaFree(d_Aarray);
        cudaFree(d_Carray);
        return CUBLAS_STATUS_ALLOC_FAILED;
    }
    cudaMemcpy(d_Aarray, &a, sizeof(cuDoubleComplex*), cudaMemcpyHostToDevice);
    cudaMemcpy(d_Carray, &c, sizeof(cuDoubleComplex*), cudaMemcpyHostToDevice);

    int status = cublasZgetriBatched(handle, n, d_Aarray, lda, ipiv, d_Carray, ldc, info, 1);

    cudaFree(d_Aarray);
    cudaFree(d_Carray);
    return status;
}
// Short aliases for the cuBLAS batched getri routines (float, double, cuComplex,
// cuDoubleComplex). They are handed to lu_inverse / lu_inverse_factored as the
// inversion callable by the exported wrappers below.
#define sgetribatched cublasSgetriBatched
#define dgetribatched cublasDgetriBatched
#define cgetribatched cublasCgetriBatched
#define zgetribatched cublasZgetriBatched
extern "C" {
@ -756,42 +751,42 @@ extern "C" {
// Exported entry point: inverts the n-by-n float matrix `a` by forwarding to
// lu_inverse with the sgetrf, sgetribatched and sgetrfbsize callables.
// Returns whatever lu_inverse returns.
DLLEXPORT int s_lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle, int n, float a[])
{
    // Single return through the batched getri alias, which matches the device
    // pointer arrays that lu_inverse builds for the inversion call.
    return lu_inverse(solverHandle, blasHandle, n, a, sgetrf, sgetribatched, sgetrfbsize);
}
// Exported entry point: inverts the n-by-n double matrix `a` by forwarding to
// lu_inverse with the dgetrf, dgetribatched and dgetrfbsize callables.
// Returns whatever lu_inverse returns.
DLLEXPORT int d_lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle, int n, double a[])
{
    // Single return through the batched getri alias, which matches the device
    // pointer arrays that lu_inverse builds for the inversion call.
    return lu_inverse(solverHandle, blasHandle, n, a, dgetrf, dgetribatched, dgetrfbsize);
}
// Exported entry point: inverts the n-by-n cuComplex matrix `a` by forwarding to
// lu_inverse with the cgetrf, cgetribatched and cgetrfbsize callables.
// Returns whatever lu_inverse returns.
DLLEXPORT int c_lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle, int n, cuComplex a[])
{
    // Single return through the batched getri alias, which matches the device
    // pointer arrays that lu_inverse builds for the inversion call.
    return lu_inverse(solverHandle, blasHandle, n, a, cgetrf, cgetribatched, cgetrfbsize);
}
// Exported entry point: inverts the n-by-n cuDoubleComplex matrix `a` by forwarding
// to lu_inverse with the zgetrf, zgetribatched and zgetrfbsize callables.
// Returns whatever lu_inverse returns.
DLLEXPORT int z_lu_inverse(cusolverDnHandle_t solverHandle, cublasHandle_t blasHandle, int n, cuDoubleComplex a[])
{
    // Single return through the batched getri alias, which matches the device
    // pointer arrays that lu_inverse builds for the inversion call.
    return lu_inverse(solverHandle, blasHandle, n, a, zgetrf, zgetribatched, zgetrfbsize);
}
// Exported entry point: forwards the float matrix `a` and the pivot array `ipiv`
// to lu_inverse_factored with the sgetribatched callable.
// Returns whatever lu_inverse_factored returns.
DLLEXPORT int s_lu_inverse_factored(cublasHandle_t blasHandle, int n, float a[], int ipiv[])
{
    // Single return through the batched getri alias, which matches the device
    // pointer arrays that lu_inverse_factored builds for the inversion call.
    return lu_inverse_factored(blasHandle, n, a, ipiv, sgetribatched);
}
// Exported entry point: forwards the double matrix `a` and the pivot array `ipiv`
// to lu_inverse_factored with the dgetribatched callable.
// Returns whatever lu_inverse_factored returns.
DLLEXPORT int d_lu_inverse_factored(cublasHandle_t blasHandle, int n, double a[], int ipiv[])
{
    // Single return through the batched getri alias, which matches the device
    // pointer arrays that lu_inverse_factored builds for the inversion call.
    return lu_inverse_factored(blasHandle, n, a, ipiv, dgetribatched);
}
// Exported entry point: forwards the cuComplex matrix `a` and the pivot array `ipiv`
// to lu_inverse_factored with the cgetribatched callable.
// Returns whatever lu_inverse_factored returns.
DLLEXPORT int c_lu_inverse_factored(cublasHandle_t blasHandle, int n, cuComplex a[], int ipiv[])
{
    // Single return through the batched getri alias, which matches the device
    // pointer arrays that lu_inverse_factored builds for the inversion call.
    return lu_inverse_factored(blasHandle, n, a, ipiv, cgetribatched);
}
// Exported entry point: forwards the cuDoubleComplex matrix `a` and the pivot array
// `ipiv` to lu_inverse_factored with the zgetribatched callable.
// Returns whatever lu_inverse_factored returns.
DLLEXPORT int z_lu_inverse_factored(cublasHandle_t blasHandle, int n, cuDoubleComplex a[], int ipiv[])
{
    // Single return through the batched getri alias, which matches the device
    // pointer arrays that lu_inverse_factored builds for the inversion call.
    return lu_inverse_factored(blasHandle, n, a, ipiv, zgetribatched);
}
DLLEXPORT int s_lu_solve_factored(cusolverDnHandle_t solverHandle, int n, int nrhs, float a[], int ipiv[], float b[])

Loading…
Cancel
Save