diff --git a/src/NativeProviders/CUDA/blas.cpp b/src/NativeProviders/CUDA/blas.cpp index c8348a32..fbad097e 100644 --- a/src/NativeProviders/CUDA/blas.cpp +++ b/src/NativeProviders/CUDA/blas.cpp @@ -1,3 +1,4 @@ +#include #include "cublas_v2.h" #include "cuda_runtime.h" #include "wrapper_common.h" @@ -54,19 +55,21 @@ void cuda_dot(const cublasHandle_t blasHandle, const int n, const T x[], int inc } template -void cuda_gemm(const cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, int n, int k, const T *alpha, const T A[], int lda, const T B[], int ldb, const T *beta, T C[], int ldc, GEMM gemm) +void cuda_gemm(const cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, int n, int k, const T alpha, const T A[], int lda, const T B[], int ldb, const T beta, T C[], int ldc, GEMM gemm) { T *d_A = NULL; - T *d_B = NULL; - T *d_C = NULL; cudaMalloc((void**)&d_A, m*k*sizeof(T)); - cudaMalloc((void**)&d_B, k*n*sizeof(T)); - cudaMalloc((void**)&d_C, m*n*sizeof(T)); - cublasSetMatrix(m, k, sizeof(T), A, m, d_A, m); + + T *d_B = NULL; + cudaMalloc((void**)&d_B, k*n*sizeof(T)); cublasSetMatrix(k, n, sizeof(T), B, k, d_B, k); - gemm(handle, transa, transb, m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc); + T *d_C = NULL; + cudaMalloc((void**)&d_C, m*n*sizeof(T)); + cublasSetMatrix(m, n, sizeof(T), C, m, d_C, m); + + gemm(handle, transa, transb, m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc); cublasGetMatrix(m, n, sizeof(T), d_C, m, C, m); @@ -137,28 +140,28 @@ extern "C" { int lda = transA == CUBLAS_OP_N ? m : k; int ldb = transB == CUBLAS_OP_N ? k : n; - cuda_gemm(blasHandle, transA, transB, m, n, k, &alpha, x, lda, y, ldb, &beta, c, m, cublasSgemm); + cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasSgemm); } DLLEXPORT void d_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const double alpha, const double x[], const double y[], const double beta, double c[]){ int lda = transA == CUBLAS_OP_N ? m : k; int ldb = transB == CUBLAS_OP_N ? k : n; - cuda_gemm(blasHandle, transA, transB, m, n, k, &alpha, x, lda, y, ldb, &beta, c, m, cublasDgemm); + cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasDgemm); } DLLEXPORT void c_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuComplex alpha, const cuComplex x[], const cuComplex y[], const cuComplex beta, cuComplex c[]){ int lda = transA == CUBLAS_OP_N ? m : k; int ldb = transB == CUBLAS_OP_N ? k : n; - cuda_gemm(blasHandle, transA, transB, m, n, k, &alpha, x, lda, y, ldb, &beta, c, m, cublasCgemm); + cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasCgemm); } DLLEXPORT void z_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuDoubleComplex alpha, const cuDoubleComplex x[], const cuDoubleComplex y[], const cuDoubleComplex beta, cuDoubleComplex c[]){ int lda = transA == CUBLAS_OP_N ? m : k; int ldb = transB == CUBLAS_OP_N ? k : n; - cuda_gemm(blasHandle, transA, transB, m, n, k, &alpha, x, lda, y, ldb, &beta, c, m, cublasZgemm); + cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasZgemm); } }