Add NVIDIA cuBLAS support (#1044)

pull/1061/head master-8944a13
slaren 1 year ago committed by GitHub
parent 6667401238
commit 8944a13296
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -66,6 +66,7 @@ endif()
# 3rd party libs # 3rd party libs
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF) option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
@ -142,6 +143,26 @@ if (LLAMA_OPENBLAS)
endif() endif()
endif() endif()
if (LLAMA_CUBLAS)
cmake_minimum_required(VERSION 3.17)
find_package(CUDAToolkit)
if (CUDAToolkit_FOUND)
message(STATUS "cuBLAS found")
add_compile_definitions(GGML_USE_CUBLAS)
if (LLAMA_STATIC)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()
else()
message(WARNING "cuBLAS not found")
endif()
endif()
if (LLAMA_ALL_WARNINGS) if (LLAMA_ALL_WARNINGS)
if (NOT MSVC) if (NOT MSVC)
set(c_flags set(c_flags

@ -97,6 +97,10 @@ ifdef LLAMA_OPENBLAS
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
LDFLAGS += -lopenblas LDFLAGS += -lopenblas
endif endif
ifdef LLAMA_CUBLAS
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
LDFLAGS += -lcublas_static -lculibos -lcudart_static -lcublasLt_static -lpthread -ldl -L/usr/local/cuda/lib64
endif
ifdef LLAMA_GPROF ifdef LLAMA_GPROF
CFLAGS += -pg CFLAGS += -pg
CXXFLAGS += -pg CXXFLAGS += -pg

206
ggml.c

@ -142,10 +142,46 @@ inline static void* ggml_aligned_malloc(size_t size) {
} \ } \
} while (0) } while (0)
#ifdef GGML_USE_ACCELERATE #if defined(GGML_USE_ACCELERATE)
#include <Accelerate/Accelerate.h> #include <Accelerate/Accelerate.h>
#elif GGML_USE_OPENBLAS #elif defined(GGML_USE_OPENBLAS)
#include <cblas.h> #include <cblas.h>
#elif defined(GGML_USE_CUBLAS)
#include <cublas_v2.h>
#include <cuda_runtime.h>
#define CUDA_CHECK(err) \
do { \
cudaError_t err_ = (err); \
if (err_ != cudaSuccess) { \
printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
cudaGetErrorString(err_)); \
exit(1); \
} \
} while (0)
#define CUBLAS_CHECK(err) \
do { \
cublasStatus_t err_ = (err); \
if (err_ != CUBLAS_STATUS_SUCCESS) { \
printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
exit(1); \
} \
} while (0)
static cublasHandle_t cublasH = NULL;
static cudaStream_t cudaStream = NULL;
static void init_cublas(void) {
if (cublasH == NULL) {
// create cublas handle, bind a stream
CUBLAS_CHECK(cublasCreate(&cublasH));
CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
// configure logging to stdout
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
}
}
#endif #endif
#undef MIN #undef MIN
@ -3836,6 +3872,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
} }
// initialize cuBLAS
#if defined(GGML_USE_CUBLAS)
init_cublas();
#endif
is_first_call = false; is_first_call = false;
} }
@ -7567,7 +7608,7 @@ static void ggml_compute_forward_rms_norm(
// ggml_compute_forward_mul_mat // ggml_compute_forward_mul_mat
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
// helper function to determine if it is better to use BLAS or not // helper function to determine if it is better to use BLAS or not
// for large matrices, BLAS is faster // for large matrices, BLAS is faster
static bool ggml_compute_forward_mul_mat_use_blas( static bool ggml_compute_forward_mul_mat_use_blas(
@ -7607,7 +7648,7 @@ static void ggml_compute_forward_mul_mat_f32(
const int64_t ne02 = src0->ne[2]; const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3]; const int64_t ne03 = src0->ne[3];
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
const int64_t ne10 = src1->ne[0]; const int64_t ne10 = src1->ne[0];
#endif #endif
const int64_t ne11 = src1->ne[1]; const int64_t ne11 = src1->ne[1];
@ -7664,7 +7705,7 @@ static void ggml_compute_forward_mul_mat_f32(
// nb01 >= nb00 - src0 is not transposed // nb01 >= nb00 - src0 is not transposed
// compute by src0 rows // compute by src0 rows
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
if (params->ith != 0) { if (params->ith != 0) {
return; return;
@ -7678,6 +7719,21 @@ static void ggml_compute_forward_mul_mat_f32(
return; return;
} }
#if defined(GGML_USE_CUBLAS)
float *d_X = NULL;
float *d_Y = NULL;
float *d_D = NULL;
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne10;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
#endif
for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i02 = 0; i02 < ne02; i02++) {
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03); const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
@ -7685,15 +7741,37 @@ static void ggml_compute_forward_mul_mat_f32(
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
#if defined(GGML_USE_CUBLAS)
// copy data to device
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
// compute
CUBLAS_CHECK(
cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, d_X, ne00,
d_Y, ne10,
&beta, d_D, ne01));
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
#else
// zT = y * xT // zT = y * xT
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne11, ne01, ne10, ne11, ne01, ne10,
1.0f, y, ne10, 1.0f, y, ne10,
x, ne00, x, ne00,
0.0f, d, ne01); 0.0f, d, ne01);
#endif
} }
} }
#if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D));
#endif
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
return; return;
@ -7823,7 +7901,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
// nb01 >= nb00 - src0 is not transposed // nb01 >= nb00 - src0 is not transposed
// compute by src0 rows // compute by src0 rows
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
GGML_ASSERT(nb10 == sizeof(float)); GGML_ASSERT(nb10 == sizeof(float));
@ -7839,10 +7917,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
return; return;
} }
float * const wdata = params->wdata; #if defined(GGML_USE_CUBLAS)
ggml_fp16_t * const wdata = params->wdata;
float *d_X = NULL;
float *d_Y = NULL;
float *d_D = NULL;
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne10;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
#else
float * const wdata = params->wdata;
#endif
for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i02 = 0; i02 < ne02; i02++) {
#if defined(GGML_USE_CUBLAS)
// with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
{
size_t id = 0;
for (int64_t i01 = 0; i01 < ne11; ++i01) {
for (int64_t i00 = 0; i00 < ne10; ++i00) {
wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
}
}
}
#else
{ {
size_t id = 0; size_t id = 0;
for (int64_t i01 = 0; i01 < ne01; ++i01) { for (int64_t i01 = 0; i01 < ne01; ++i01) {
@ -7851,7 +7956,32 @@ static void ggml_compute_forward_mul_mat_f16_f32(
} }
} }
} }
#endif
#if defined(GGML_USE_CUBLAS)
const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
// copy data to device
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
// compute
CUBLAS_CHECK(
cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, d_X, CUDA_R_16F, ne00,
d_Y, CUDA_R_16F, ne10,
&beta, d_D, CUDA_R_32F, ne01,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT));
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
#else
const float * x = wdata; const float * x = wdata;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@ -7863,9 +7993,15 @@ static void ggml_compute_forward_mul_mat_f16_f32(
1.0f, y, ne10, 1.0f, y, ne10,
x, ne00, x, ne00,
0.0f, d, ne01); 0.0f, d, ne01);
#endif
} }
} }
#if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D));
#endif
/*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/ /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
return; return;
@ -8017,7 +8153,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
// nb01 >= nb00 - src0 is not transposed // nb01 >= nb00 - src0 is not transposed
// compute by src0 rows // compute by src0 rows
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
if (params->ith != 0) { if (params->ith != 0) {
return; return;
@ -8034,6 +8170,21 @@ static void ggml_compute_forward_mul_mat_q_f32(
float * const wdata = params->wdata; float * const wdata = params->wdata;
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
#if defined(GGML_USE_CUBLAS)
float *d_X = NULL;
float *d_Y = NULL;
float *d_D = NULL;
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne10;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
#endif
for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i02 = 0; i02 < ne02; i02++) {
{ {
@ -8049,15 +8200,38 @@ static void ggml_compute_forward_mul_mat_q_f32(
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
#if defined(GGML_USE_CUBLAS)
// copy data to device
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
// compute
CUBLAS_CHECK(
cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, d_X, ne00,
d_Y, ne10,
&beta, d_D, ne01));
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
#else
// zT = y * xT // zT = y * xT
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne11, ne01, ne10, ne11, ne01, ne10,
1.0f, y, ne10, 1.0f, y, ne10,
x, ne00, x, ne00,
0.0f, d, ne01); 0.0f, d, ne01);
#endif
} }
} }
#if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D));
#endif
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
return; return;
@ -10874,7 +11048,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
size_t cur = 0; size_t cur = 0;
if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) { if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
node->n_tasks = 1; // TODO: this actually is doing nothing node->n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning // the threads are still spinning
@ -10891,7 +11065,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
cur = 0; cur = 0;
} else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) { } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
node->n_tasks = 1; node->n_tasks = 1;
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
@ -12231,7 +12405,15 @@ int ggml_cpu_has_wasm_simd(void) {
} }
int ggml_cpu_has_blas(void) { int ggml_cpu_has_blas(void) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
return 1;
#else
return 0;
#endif
}
int ggml_cpu_has_cublas(void) {
#if defined(GGML_USE_CUBLAS)
return 1; return 1;
#else #else
return 0; return 0;

@ -825,6 +825,7 @@ int ggml_cpu_has_f16c(void);
int ggml_cpu_has_fp16_va(void); int ggml_cpu_has_fp16_va(void);
int ggml_cpu_has_wasm_simd(void); int ggml_cpu_has_wasm_simd(void);
int ggml_cpu_has_blas(void); int ggml_cpu_has_blas(void);
int ggml_cpu_has_cublas(void);
int ggml_cpu_has_sse3(void); int ggml_cpu_has_sse3(void);
int ggml_cpu_has_vsx(void); int ggml_cpu_has_vsx(void);

@ -1069,7 +1069,7 @@ static bool llama_eval_internal(
// for big prompts, if BLAS is enabled, it is better to use only one thread // for big prompts, if BLAS is enabled, it is better to use only one thread
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
ggml_cgraph gf = {}; ggml_cgraph gf = {};
gf.n_threads = N >= 32 && ggml_cpu_has_blas() ? 1 : n_threads; gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_cublas() ? 1 : n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, tokens, N*ggml_element_size(embd)); memcpy(embd->data, tokens, N*ggml_element_size(embd));

Loading…
Cancel
Save