From 8944a1329648c57bb7d66851170938230587a52c Mon Sep 17 00:00:00 2001
From: slaren <2141330+slaren@users.noreply.github.com>
Date: Wed, 19 Apr 2023 11:22:45 +0200
Subject: [PATCH] Add NVIDIA cuBLAS support (#1044)

---
 CMakeLists.txt |  21 +++++
 Makefile       |   4 +
 ggml.c         | 206 ++++++++++++++++++++++++++++++++++++++++++++++---
 ggml.h         |   1 +
 llama.cpp      |   2 +-
 5 files changed, 221 insertions(+), 13 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index ed9a3aa..8eadea4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -66,6 +66,7 @@ endif()
 # 3rd party libs
 option(LLAMA_ACCELERATE             "llama: enable Accelerate framework" ON)
 option(LLAMA_OPENBLAS               "llama: use OpenBLAS" OFF)
+option(LLAMA_CUBLAS                 "llama: use cuBLAS" OFF)
 
 option(LLAMA_BUILD_TESTS            "llama: build tests"    ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES         "llama: build examples" ${LLAMA_STANDALONE})
@@ -142,6 +143,26 @@ if (LLAMA_OPENBLAS)
     endif()
 endif()
 
+if (LLAMA_CUBLAS)
+    cmake_minimum_required(VERSION 3.17)
+
+    find_package(CUDAToolkit)
+    if (CUDAToolkit_FOUND)
+        message(STATUS "cuBLAS found")
+
+        add_compile_definitions(GGML_USE_CUBLAS)
+
+        if (LLAMA_STATIC)
+            set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+        else()
+            set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
+        endif()
+
+    else()
+        message(WARNING "cuBLAS not found")
+    endif()
+endif()
+
 if (LLAMA_ALL_WARNINGS)
     if (NOT MSVC)
         set(c_flags
diff --git a/Makefile b/Makefile
index 071d956..deb0d00 100644
--- a/Makefile
+++ b/Makefile
@@ -97,6 +97,10 @@ ifdef LLAMA_OPENBLAS
 	CFLAGS  += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
 	LDFLAGS += -lopenblas
 endif
+ifdef LLAMA_CUBLAS
+	CFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
+	LDFLAGS += -lcublas_static -lculibos -lcudart_static -lcublasLt_static -lpthread -ldl -L/usr/local/cuda/lib64
+endif
 ifdef LLAMA_GPROF
 	CFLAGS   += -pg
 	CXXFLAGS += -pg
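Usage notes: with the Makefile, cuBLAS is enabled by defining LLAMA_CUBLAS
(e.g. "make LLAMA_CUBLAS=1"); the -I/-L paths above assume a stock CUDA
toolkit install under /usr/local/cuda, and the CUDA runtime and cuBLAS are
linked statically. With CMake, configure with -DLLAMA_CUBLAS=ON; the
cmake_minimum_required(VERSION 3.17) bump inside the block is needed because
find_package(CUDAToolkit) first appeared in CMake 3.17.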
diff --git a/ggml.c b/ggml.c
index f4b8fc2..13c1548 100644
--- a/ggml.c
+++ b/ggml.c
@@ -142,10 +142,46 @@ inline static void* ggml_aligned_malloc(size_t size) {
     } \
     } while (0)
 
-#ifdef GGML_USE_ACCELERATE
+#if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
-#elif GGML_USE_OPENBLAS
+#elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
+#elif defined(GGML_USE_CUBLAS)
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+#define CUDA_CHECK(err)                                                      \
+    do {                                                                     \
+        cudaError_t err_ = (err);                                            \
+        if (err_ != cudaSuccess) {                                           \
+            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
+                cudaGetErrorString(err_));                                   \
+            exit(1);                                                         \
+        }                                                                    \
+    } while (0)
+
+#define CUBLAS_CHECK(err)                                                    \
+    do {                                                                     \
+        cublasStatus_t err_ = (err);                                         \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                 \
+            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);  \
+            exit(1);                                                         \
+        }                                                                    \
+    } while (0)
+
+static cublasHandle_t cublasH = NULL;
+static cudaStream_t cudaStream = NULL;
+static void init_cublas(void) {
+    if (cublasH == NULL) {
+        // create cublas handle, bind a stream
+        CUBLAS_CHECK(cublasCreate(&cublasH));
+
+        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
+
+        // configure logging to stdout
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
+    }
+}
 #endif
 
 #undef MIN
@@ -3836,6 +3872,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
     }
 
+    // initialize cuBLAS
+    #if defined(GGML_USE_CUBLAS)
+    init_cublas();
+    #endif
+
     is_first_call = false;
 }
@@ -7567,7 +7608,7 @@ static void ggml_compute_forward_rms_norm(
 
 // ggml_compute_forward_mul_mat
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
 static bool ggml_compute_forward_mul_mat_use_blas(
@@ -7607,7 +7648,7 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     const int64_t ne10 = src1->ne[0];
 #endif
     const int64_t ne11 = src1->ne[1];
@@ -7664,7 +7705,7 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -7678,6 +7719,21 @@ static void ggml_compute_forward_mul_mat_f32(
             return;
         }
 
+#if defined(GGML_USE_CUBLAS)
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#endif
+
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
@@ -7685,15 +7741,37 @@ static void ggml_compute_forward_mul_mat_f32(
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
-
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
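The transpose flags above deserve a note: ggml stores matrices row-major,
while cuBLAS assumes column-major, so a row-major array handed to cuBLAS is
read as its own transpose. With CUBLAS_OP_T on X and CUBLAS_OP_N on Y, the
column-major result (ne01 x ne11, ldc = ne01) read back row-major is exactly
the zT = y * xT that the cblas_sgemm branch computes. The standalone sketch
below (not part of the patch; the file name and test sizes are made up,
build with e.g. "nvcc sgemm_check.c -lcublas") replays the same parameter
layout on a tiny example and checks it against a plain-C loop; error
checking is omitted for brevity, where real code would use the
CUDA_CHECK/CUBLAS_CHECK macros defined earlier.

    #include <math.h>
    #include <stdio.h>
    #include <cublas_v2.h>
    #include <cuda_runtime.h>

    int main(void) {
        // x: ne01 x ne00 row-major, y: ne11 x ne10 row-major, ne00 == ne10
        const int ne00 = 4, ne01 = 3, ne10 = 4, ne11 = 2;
        float x[3*4], y[2*4], d[2*3], ref[2*3];
        for (int i = 0; i < 3*4; i++) x[i] = (float) (i + 1);
        for (int i = 0; i < 2*4; i++) y[i] = (float) (i % 5);

        // CPU reference, same as the cblas_sgemm branch: d = y * xT
        for (int j = 0; j < ne11; j++)
            for (int i = 0; i < ne01; i++) {
                ref[j*ne01 + i] = 0.0f;
                for (int k = 0; k < ne10; k++)
                    ref[j*ne01 + i] += y[j*ne10 + k] * x[i*ne00 + k];
            }

        cublasHandle_t h;
        float *d_X, *d_Y, *d_D;
        const float alpha = 1.0f, beta = 0.0f;
        cublasCreate(&h);
        cudaMalloc((void **) &d_X, sizeof(x));
        cudaMalloc((void **) &d_Y, sizeof(y));
        cudaMalloc((void **) &d_D, sizeof(d));
        cudaMemcpy(d_X, x, sizeof(x), cudaMemcpyHostToDevice);
        cudaMemcpy(d_Y, y, sizeof(y), cudaMemcpyHostToDevice);

        // identical parameter layout to the patch: m = ne01, n = ne11, k = ne10
        cublasSgemm(h, CUBLAS_OP_T, CUBLAS_OP_N,
                ne01, ne11, ne10,
                &alpha, d_X, ne00,
                        d_Y, ne10,
                &beta,  d_D, ne01);
        cudaMemcpy(d, d_D, sizeof(d), cudaMemcpyDeviceToHost);

        for (int i = 0; i < ne11*ne01; i++)
            if (fabsf(d[i] - ref[i]) > 1e-5f) { printf("mismatch\n"); return 1; }
        printf("cuBLAS result matches the CPU reference\n");
        cudaFree(d_X); cudaFree(d_Y); cudaFree(d_D);
        cublasDestroy(h);
        return 0;
    }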
@@ -7823,7 +7901,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
@@ -7839,10 +7917,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             return;
         }
 
-        float * const wdata = params->wdata;
+#if defined(GGML_USE_CUBLAS)
+        ggml_fp16_t * const wdata = params->wdata;
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#else
+        float * const wdata = params->wdata;
+#endif
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+#if defined(GGML_USE_CUBLAS)
+                // with cuBLAS, instead of converting src0 to fp32, we convert src1 to fp16
+                {
+                    size_t id = 0;
+                    for (int64_t i01 = 0; i01 < ne11; ++i01) {
+                        for (int64_t i00 = 0; i00 < ne10; ++i00) {
+                            wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
+                        }
+                    }
+                }
+#else
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7851,7 +7956,32 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         }
                     }
                 }
+#endif
 
+#if defined(GGML_USE_CUBLAS)
+                const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
+                const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, CUDA_R_16F, ne00,
+                                    d_Y, CUDA_R_16F, ne10,
+                            &beta,  d_D, CUDA_R_32F, ne01,
+                            CUBLAS_COMPUTE_32F,
+                            CUBLAS_GEMM_DEFAULT));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
@@ -7863,9 +7993,15 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
         return;
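A note on the mixed-precision choice above: the CPU branch widens the fp16
src0 into an fp32 wdata buffer and runs sgemm, while the cuBLAS branch goes
the other way, narrowing the fp32 src1 to fp16 so that both GEMM inputs are
fp16 and only the accumulation stays in fp32 (CUBLAS_COMPUTE_32F). This
halves the host-to-device traffic: for a hypothetical slice with
ne01 = ne00 = 4096 and ne11 = 512, X uploads as 4096 * 4096 * 2 bytes =
32 MiB instead of a 64 MiB fp32 conversion, and Y as 4096 * 512 * 2 bytes =
4 MiB instead of 8 MiB, while fp16 inputs also make the GemmEx call eligible
for tensor cores on hardware that has them.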
@@ -8017,7 +8153,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -8034,6 +8170,21 @@ static void ggml_compute_forward_mul_mat_q_f32(
         float * const wdata = params->wdata;
         dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 
+#if defined(GGML_USE_CUBLAS)
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#endif
+
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 {
@@ -8049,15 +8200,38 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
@@ -10874,7 +11048,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         size_t cur = 0;
 
                         if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1; // TODO: this actually is doing nothing
                                                    //       the threads are still spinning
@@ -10891,7 +11065,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                             cur = 0;
                         } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1;
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
@@ -12231,7 +12405,15 @@ int ggml_cpu_has_wasm_simd(void) {
 }
 
 int ggml_cpu_has_blas(void) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_cublas(void) {
+#if defined(GGML_USE_CUBLAS)
     return 1;
 #else
     return 0;
diff --git a/ggml.h b/ggml.h
index 603be84..570147f 100644
--- a/ggml.h
+++ b/ggml.h
@@ -825,6 +825,7 @@
 int ggml_cpu_has_f16c(void);
 int ggml_cpu_has_fp16_va(void);
 int ggml_cpu_has_wasm_simd(void);
 int ggml_cpu_has_blas(void);
+int ggml_cpu_has_cublas(void);
 int ggml_cpu_has_sse3(void);
 int ggml_cpu_has_vsx(void);
diff --git a/llama.cpp b/llama.cpp
index f14324f..3ff5dc1 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1069,7 +1069,7 @@ static bool llama_eval_internal(
     // for big prompts, if BLAS is enabled, it is better to use only one thread
     // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
     ggml_cgraph gf = {};
-    gf.n_threads = N >= 32 && ggml_cpu_has_blas() ? 1 : n_threads;
+    gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_cublas() ? 1 : n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, tokens, N*ggml_element_size(embd));
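The llama.cpp change keeps the drop to a single thread for long prompts when
a CPU BLAS is used, but skips it for cuBLAS, presumably because the GEMMs now
run on the GPU so the spinning worker threads no longer compete with the BLAS
library, and the rest of the graph still benefits from n_threads. A minimal
sketch (hypothetical example, not part of the patch) of querying the new
capability flag from application code:

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        // ggml_cpu_has_cublas() is the accessor added above; it returns 1
        // only when ggml was built with GGML_USE_CUBLAS defined
        printf("blas:   %d\n", ggml_cpu_has_blas());
        printf("cublas: %d\n", ggml_cpu_has_cublas());
        return 0;
    }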