Browse Source

Fixed a bug with GEMM

pull/306/head
Matthew Johnson 11 years ago
parent
commit
bc4b926557
  1. 25
      src/NativeProviders/CUDA/blas.cpp

25
src/NativeProviders/CUDA/blas.cpp

@ -1,3 +1,4 @@
#include <stdio.h>
#include "cublas_v2.h"
#include "cuda_runtime.h"
#include "wrapper_common.h"
@ -54,19 +55,21 @@ void cuda_dot(const cublasHandle_t blasHandle, const int n, const T x[], int inc
}
template<typename T, typename GEMM>
void cuda_gemm(const cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, int n, int k, const T *alpha, const T A[], int lda, const T B[], int ldb, const T *beta, T C[], int ldc, GEMM gemm)
void cuda_gemm(const cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, int n, int k, const T alpha, const T A[], int lda, const T B[], int ldb, const T beta, T C[], int ldc, GEMM gemm)
{
T *d_A = NULL;
T *d_B = NULL;
T *d_C = NULL;
cudaMalloc((void**)&d_A, m*k*sizeof(T));
cudaMalloc((void**)&d_B, k*n*sizeof(T));
cudaMalloc((void**)&d_C, m*n*sizeof(T));
cublasSetMatrix(m, k, sizeof(T), A, m, d_A, m);
T *d_B = NULL;
cudaMalloc((void**)&d_B, k*n*sizeof(T));
cublasSetMatrix(k, n, sizeof(T), B, k, d_B, k);
gemm(handle, transa, transb, m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc);
T *d_C = NULL;
cudaMalloc((void**)&d_C, m*n*sizeof(T));
cublasSetMatrix(m, n, sizeof(T), C, m, d_C, m);
gemm(handle, transa, transb, m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc);
cublasGetMatrix(m, n, sizeof(T), d_C, m, C, m);
@ -137,28 +140,28 @@ extern "C" {
int lda = transA == CUBLAS_OP_N ? m : k;
int ldb = transB == CUBLAS_OP_N ? k : n;
cuda_gemm(blasHandle, transA, transB, m, n, k, &alpha, x, lda, y, ldb, &beta, c, m, cublasSgemm);
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasSgemm);
}
DLLEXPORT void d_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const double alpha, const double x[], const double y[], const double beta, double c[]){
int lda = transA == CUBLAS_OP_N ? m : k;
int ldb = transB == CUBLAS_OP_N ? k : n;
cuda_gemm(blasHandle, transA, transB, m, n, k, &alpha, x, lda, y, ldb, &beta, c, m, cublasDgemm);
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasDgemm);
}
DLLEXPORT void c_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuComplex alpha, const cuComplex x[], const cuComplex y[], const cuComplex beta, cuComplex c[]){
int lda = transA == CUBLAS_OP_N ? m : k;
int ldb = transB == CUBLAS_OP_N ? k : n;
cuda_gemm(blasHandle, transA, transB, m, n, k, &alpha, x, lda, y, ldb, &beta, c, m, cublasCgemm);
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasCgemm);
}
DLLEXPORT void z_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuDoubleComplex alpha, const cuDoubleComplex x[], const cuDoubleComplex y[], const cuDoubleComplex beta, cuDoubleComplex c[]){
int lda = transA == CUBLAS_OP_N ? m : k;
int ldb = transB == CUBLAS_OP_N ? k : n;
cuda_gemm(blasHandle, transA, transB, m, n, k, &alpha, x, lda, y, ldb, &beta, c, m, cublasZgemm);
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasZgemm);
}
}

Loading…
Cancel
Save