|
|
|
@ -1,3 +1,4 @@ |
|
|
|
#include <stdio.h> |
|
|
|
#include "cublas_v2.h" |
|
|
|
#include "cuda_runtime.h" |
|
|
|
#include "wrapper_common.h" |
|
|
|
@ -54,19 +55,21 @@ void cuda_dot(const cublasHandle_t blasHandle, const int n, const T x[], int inc |
|
|
|
} |
|
|
|
|
|
|
|
template<typename T, typename GEMM> |
|
|
|
void cuda_gemm(const cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, int n, int k, const T *alpha, const T A[], int lda, const T B[], int ldb, const T *beta, T C[], int ldc, GEMM gemm) |
|
|
|
void cuda_gemm(const cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, int n, int k, const T alpha, const T A[], int lda, const T B[], int ldb, const T beta, T C[], int ldc, GEMM gemm) |
|
|
|
{ |
|
|
|
T *d_A = NULL; |
|
|
|
T *d_B = NULL; |
|
|
|
T *d_C = NULL; |
|
|
|
cudaMalloc((void**)&d_A, m*k*sizeof(T)); |
|
|
|
cudaMalloc((void**)&d_B, k*n*sizeof(T)); |
|
|
|
cudaMalloc((void**)&d_C, m*n*sizeof(T)); |
|
|
|
|
|
|
|
cublasSetMatrix(m, k, sizeof(T), A, m, d_A, m); |
|
|
|
|
|
|
|
T *d_B = NULL; |
|
|
|
cudaMalloc((void**)&d_B, k*n*sizeof(T)); |
|
|
|
cublasSetMatrix(k, n, sizeof(T), B, k, d_B, k); |
|
|
|
|
|
|
|
gemm(handle, transa, transb, m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc); |
|
|
|
T *d_C = NULL; |
|
|
|
cudaMalloc((void**)&d_C, m*n*sizeof(T)); |
|
|
|
cublasSetMatrix(m, n, sizeof(T), C, m, d_C, m); |
|
|
|
|
|
|
|
gemm(handle, transa, transb, m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc); |
|
|
|
|
|
|
|
cublasGetMatrix(m, n, sizeof(T), d_C, m, C, m); |
|
|
|
|
|
|
|
@ -137,28 +140,28 @@ extern "C" { |
|
|
|
int lda = transA == CUBLAS_OP_N ? m : k; |
|
|
|
int ldb = transB == CUBLAS_OP_N ? k : n; |
|
|
|
|
|
|
|
cuda_gemm(blasHandle, transA, transB, m, n, k, &alpha, x, lda, y, ldb, &beta, c, m, cublasSgemm); |
|
|
|
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasSgemm); |
|
|
|
} |
|
|
|
|
|
|
|
DLLEXPORT void d_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const double alpha, const double x[], const double y[], const double beta, double c[]){ |
|
|
|
int lda = transA == CUBLAS_OP_N ? m : k; |
|
|
|
int ldb = transB == CUBLAS_OP_N ? k : n; |
|
|
|
|
|
|
|
cuda_gemm(blasHandle, transA, transB, m, n, k, &alpha, x, lda, y, ldb, &beta, c, m, cublasDgemm); |
|
|
|
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasDgemm); |
|
|
|
} |
|
|
|
|
|
|
|
DLLEXPORT void c_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuComplex alpha, const cuComplex x[], const cuComplex y[], const cuComplex beta, cuComplex c[]){ |
|
|
|
int lda = transA == CUBLAS_OP_N ? m : k; |
|
|
|
int ldb = transB == CUBLAS_OP_N ? k : n; |
|
|
|
|
|
|
|
cuda_gemm(blasHandle, transA, transB, m, n, k, &alpha, x, lda, y, ldb, &beta, c, m, cublasCgemm); |
|
|
|
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasCgemm); |
|
|
|
} |
|
|
|
|
|
|
|
DLLEXPORT void z_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuDoubleComplex alpha, const cuDoubleComplex x[], const cuDoubleComplex y[], const cuDoubleComplex beta, cuDoubleComplex c[]){ |
|
|
|
int lda = transA == CUBLAS_OP_N ? m : k; |
|
|
|
int ldb = transB == CUBLAS_OP_N ? k : n; |
|
|
|
|
|
|
|
cuda_gemm(blasHandle, transA, transB, m, n, k, &alpha, x, lda, y, ldb, &beta, c, m, cublasZgemm); |
|
|
|
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasZgemm); |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|