cuBLAS: refactor and optimize f16 mat mul performance (#1259)

* cuBLAS: refactor, convert fp16 to fp32 on device

* cuBLAS: use multiple streams, choose smartly between mul_mat_q and mul_mat_f16

* fix build

* cuBLAS: update block_q5_1
pull/1237/head master-58b367c
slaren 1 year ago committed by GitHub
parent ea3a0ad6b6
commit 58b367c2d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,11 +1,38 @@
#include <cstddef>
#include <cstdint>
#include <stdint.h>
#include <stdio.h>
#include <cuda_fp16.h>
#include <atomic>
#include "ggml-cuda.h"
typedef uint16_t ggml_fp16_t;
static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include "ggml-cuda.h"
#include "ggml.h"
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
#define CUDA_CHECK(err) \
do { \
cudaError_t err_ = (err); \
if (err_ != cudaSuccess) { \
fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
cudaGetErrorString(err_)); \
exit(1); \
} \
} while (0)
#define CUBLAS_CHECK(err) \
do { \
cublasStatus_t err_ = (err); \
if (err_ != CUBLAS_STATUS_SUCCESS) { \
fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
exit(1); \
} \
} while (0)
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
#define QK4_0 32
typedef struct {
@ -24,14 +51,14 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 b
#define QK4_2 16
typedef struct {
__half d; // delta
half d; // delta
uint8_t qs[QK4_2 / 2]; // nibbles / quants
} block_q4_2;
static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
#define QK5_0 32
typedef struct {
__half d; // delta
half d; // delta
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_0 / 2]; // nibbles / quants
} block_q5_0;
@ -39,9 +66,9 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5
#define QK5_1 32
typedef struct {
__half d; // delta
__half m; // min
uint32_t qh; // 5-th bit of quants
half d; // delta
half m; // min
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_1 / 2]; // nibbles / quants
} block_q5_1;
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
@ -162,7 +189,8 @@ static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
const uint8_t * pp = x[i].qs;
const uint32_t qh = x[i].qh;
uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh));
for (int l = 0; l < QK5_1; l += 2) {
const uint8_t vi = pp[l/2];
@ -197,37 +225,50 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
}
}
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_0;
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
}
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_1;
dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
}
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
static void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_2;
dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
}
void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK5_0;
dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
}
void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK5_1;
dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
}
void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK8_0;
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
}
dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
// TODO: optimize
static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
const half * x = (const half *) vx;
const int i = blockIdx.x;
y[i] = __half2float(x[i]);
}
static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) {
convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
}
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
@ -241,6 +282,8 @@ dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
return dequantize_row_q5_1_cuda;
case GGML_TYPE_Q8_0:
return dequantize_row_q8_0_cuda;
case GGML_TYPE_F16:
return convert_fp16_to_fp32_cuda;
default:
return nullptr;
}
@ -271,7 +314,7 @@ struct cuda_buffer {
static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cuda_pool_lock);
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
@ -290,7 +333,7 @@ void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
return ptr;
}
void ggml_cuda_pool_free(void * ptr, size_t size) {
static void ggml_cuda_pool_free(void * ptr, size_t size) {
scoped_spin_lock lock(g_cuda_pool_lock);
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
@ -305,28 +348,55 @@ void ggml_cuda_pool_free(void * ptr, size_t size) {
CUDA_CHECK(cudaFree(ptr));
}
cublasHandle_t g_cublasH = nullptr;
cudaStream_t g_cudaStream = nullptr;
cudaStream_t g_cudaStream2 = nullptr;
cudaEvent_t g_cudaEvent = nullptr;
#define GGML_CUDA_MAX_STREAMS 8
#define GGML_CUDA_MAX_EVENTS 64
static cublasHandle_t g_cublasH = nullptr;
static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_STREAMS] = { nullptr };
static cudaStream_t g_cudaStreams2[GGML_CUDA_MAX_STREAMS] = { nullptr };
static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_EVENTS] = { nullptr };
void ggml_init_cublas() {
if (g_cublasH == nullptr) {
// create cublas handle, bind a stream
CUBLAS_CHECK(cublasCreate(&g_cublasH));
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
// create streams
for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking));
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams2[i], cudaStreamNonBlocking));
}
// create events
for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming));
}
// create additional stream and event for synchronization
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming));
// create cublas handle
CUBLAS_CHECK(cublasCreate(&g_cublasH));
CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH));
// configure logging to stdout
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
}
}
void * ggml_cuda_host_malloc(size_t size) {
if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
return nullptr;
}
void * ptr = nullptr;
cudaError_t err = cudaMallocHost((void **) &ptr, size);
if (err != cudaSuccess) {
fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
size/1024.0/1024.0, cudaGetErrorString(err));
return nullptr;
}
return ptr;
}
void ggml_cuda_host_free(void * ptr) {
CUDA_CHECK(cudaFreeHost(ptr));
}
cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
const uint64_t ne0 = src->ne[0];
const uint64_t ne1 = src->ne[1];
const uint64_t nb0 = src->nb[0];
@ -354,22 +424,293 @@ cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src,
}
}
void * ggml_cuda_host_malloc(size_t size) {
if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
return nullptr;
static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne00;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
const int n_mm = ne03 * ne02;
size_t x_size, y_size, d_size;
float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
int i = i03*ne02 + i02;
cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
float * c_X = d_X + i * x_ne;
float * c_Y = d_Y + i * y_ne;
float * c_D = d_D + i * d_ne;
// copy data to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
// compute
CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
CUBLAS_CHECK(
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, c_X, ne00,
c_Y, ne10,
&beta, c_D, ne01));
// copy dst to host
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
}
}
void * ptr = nullptr;
cudaError_t err = cudaMallocHost((void **) &ptr, size);
if (err != cudaSuccess) {
fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
size/1024.0/1024.0, cudaGetErrorString(err));
return nullptr;
CUDA_CHECK(cudaDeviceSynchronize());
ggml_cuda_pool_free(d_X, x_size);
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
}
static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int nb10 = src1->nb[0];
const int nb11 = src1->nb[1];
const int nb12 = src1->nb[2];
const int nb13 = src1->nb[3];
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne00;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
const int n_mm = ne03 * ne02;
size_t x_size, y_size, d_size;
half * d_X = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size);
half * d_Y = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size);
float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
bool src1_cont_rows = nb10 == sizeof(float);
bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
int i = i03*ne02 + i02;
cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
half * c_X = d_X + i * x_ne;
half * c_Y = d_Y + i * y_ne;
float * c_D = d_D + i * d_ne;
// copy src0 to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
// convert src1 to fp16
// TODO: use multiple threads
ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
if (src1_cont_rows) {
if (src1_cont_cols) {
ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
}
else {
for (int64_t i01 = 0; i01 < ne11; i01++) {
ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
}
}
}
else {
for (int64_t i01 = 0; i01 < ne11; i01++) {
for (int64_t i00 = 0; i00 < ne10; i00++) {
// very slow due to no inlining
tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
}
}
}
// copy src1 to device
CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));
// compute
CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
CUBLAS_CHECK(
cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, c_X, CUDA_R_16F, ne00,
c_Y, CUDA_R_16F, ne10,
&beta, c_D, CUDA_R_32F, ne01,
CUBLAS_COMPUTE_32F_FAST_16F,
CUBLAS_GEMM_DEFAULT));
// copy dst to host
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
}
}
return ptr;
CUDA_CHECK(cudaDeviceSynchronize());
ggml_cuda_pool_free(d_X, x_size);
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
}
void ggml_cuda_host_free(void * ptr) {
CUDA_CHECK(cudaFreeHost(ptr));
static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
const ggml_type type = src0->type;
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne00;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
const int n_mm = ne03 * ne02;
const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
size_t x_size, y_size, d_size, q_size;
float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
GGML_ASSERT(to_fp32_cuda != nullptr);
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
int i = i03*ne02 + i02;
cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
float * c_X = d_X + i * x_ne;
float * c_Y = d_Y + i * y_ne;
float * c_D = d_D + i * d_ne;
char * c_Q = d_Q + i * q_sz;
// copy src0 and convert to fp32 on device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
// copy src1 to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
// wait for conversion
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
// compute
CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
CUBLAS_CHECK(
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, c_X, ne00,
c_Y, ne10,
&beta, c_D, ne01));
// copy dst to host
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
}
}
CUDA_CHECK(cudaDeviceSynchronize());
ggml_cuda_pool_free(d_X, x_size);
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
ggml_cuda_pool_free(d_Q, q_size);
}
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
const int64_t ne10 = src1->ne[0];
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
// TODO: find the optimal values for these
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 &&
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
return true;
}
return false;
}
bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
size_t src0_sz = ggml_nbytes(src0);
size_t src1_sz = ggml_nbytes(src1);
// mul_mat_q: src0 is converted to fp32 on device
size_t mul_mat_q_transfer = src0_sz + src1_sz;
// mul_mat_f16: src1 is converted to fp16 on cpu
size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1);
// choose the smaller one to transfer to the device
// TODO: this is not always the best choice due to the overhead of converting to fp16
return mul_mat_f16_transfer < mul_mat_q_transfer;
}
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
if (src0->type == GGML_TYPE_F32) {
ggml_cuda_mul_mat_f32(src0, src1, dst);
}
else if (src0->type == GGML_TYPE_F16) {
if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
}
else {
ggml_cuda_mul_mat_q_f32(src0, src1, dst);
}
}
else if (ggml_is_quantized(src0->type)) {
ggml_cuda_mul_mat_q_f32(src0, src1, dst);
}
else {
GGML_ASSERT(false);
}
}
size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
return ggml_nelements(src1) * sizeof(ggml_fp16_t);
}
else {
return 0;
}
}

@ -1,54 +1,19 @@
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "ggml.h"
#ifdef __cplusplus
extern "C" {
#endif
#define CUDA_CHECK(err) \
do { \
cudaError_t err_ = (err); \
if (err_ != cudaSuccess) { \
fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
cudaGetErrorString(err_)); \
exit(1); \
} \
} while (0)
#define CUBLAS_CHECK(err) \
do { \
cublasStatus_t err_ = (err); \
if (err_ != CUBLAS_STATUS_SUCCESS) { \
fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
exit(1); \
} \
} while (0)
void ggml_init_cublas(void);
extern cublasHandle_t g_cublasH;
extern cudaStream_t g_cudaStream;
extern cudaStream_t g_cudaStream2;
extern cudaEvent_t g_cudaEvent;
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
void ggml_init_cublas(void);
// TODO: export these with GGML_API
void * ggml_cuda_host_malloc(size_t size);
void ggml_cuda_host_free(void * ptr);
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
void ggml_cuda_pool_free(void * ptr, size_t size);
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream);
typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(enum ggml_type type);
#ifdef __cplusplus
}
#endif

252
ggml.c

@ -135,14 +135,6 @@ inline static void* ggml_aligned_malloc(size_t size) {
#define UNUSED(x) (void)(x)
#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
#define GGML_ASSERT(x) \
do { \
if (!(x)) { \
fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
abort(); \
} \
} while (0)
#if defined(GGML_USE_ACCELERATE)
#include <Accelerate/Accelerate.h>
#elif defined(GGML_USE_OPENBLAS)
@ -370,6 +362,32 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
return GGML_FP32_TO_FP16(x);
}
void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n) {
for (size_t i = 0; i < n; i++) {
y[i] = GGML_FP16_TO_FP32(x[i]);
}
}
void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
size_t i = 0;
#if defined(__F16C__)
for (; i + 7 < n; i += 8) {
__m256 x_vec = _mm256_loadu_ps(x + i);
__m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
_mm_storeu_si128((__m128i *)(y + i), y_vec);
}
for(; i + 3 < n; i += 4) {
__m128 x_vec = _mm_loadu_ps(x + i);
__m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
_mm_storel_epi64((__m128i *)(y + i), y_vec);
}
#endif
for (; i < n; i++) {
y[i] = GGML_FP32_TO_FP16(x[i]);
}
}
//
// timing
//
@ -4325,12 +4343,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
}
// initialize cuBLAS
#if defined(GGML_USE_CUBLAS)
#if defined(GGML_USE_CUBLAS)
ggml_init_cublas();
#elif defined(GGML_USE_CLBLAST)
#elif defined(GGML_USE_CLBLAST)
ggml_cl_init();
#endif
#endif
is_first_call = false;
}
@ -8101,7 +8118,7 @@ static void ggml_compute_forward_rms_norm(
// ggml_compute_forward_mul_mat
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
// helper function to determine if it is better to use BLAS or not
// for large matrices, BLAS is faster
static bool ggml_compute_forward_mul_mat_use_blas(
@ -8117,12 +8134,9 @@ static bool ggml_compute_forward_mul_mat_use_blas(
const int64_t ne1 = dst->ne[1];
// TODO: find the optimal values for these
if (
#if !defined(GGML_USE_CUBLAS)
ggml_is_contiguous(src0) &&
if (ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) &&
#endif
((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
/*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
return true;
@ -8130,7 +8144,6 @@ static bool ggml_compute_forward_mul_mat_use_blas(
return false;
}
#endif
static void ggml_compute_forward_mul_mat_f32(
@ -8146,7 +8159,7 @@ static void ggml_compute_forward_mul_mat_f32(
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
const int64_t ne10 = src1->ne[0];
#endif
const int64_t ne11 = src1->ne[1];
@ -8203,7 +8216,16 @@ static void ggml_compute_forward_mul_mat_f32(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_CUBLAS)
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
}
return;
}
#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
if (params->ith != 0) {
return;
@ -8217,43 +8239,13 @@ static void ggml_compute_forward_mul_mat_f32(
return;
}
#if defined(GGML_USE_CUBLAS)
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne00;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
size_t x_size, y_size, d_size;
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
#endif
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
#if !defined(GGML_USE_CUBLAS)
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
#endif
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
#if defined(GGML_USE_CUBLAS)
// copy data to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
// compute
CUBLAS_CHECK(
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, d_X, ne00,
d_Y, ne10,
&beta, d_D, ne01));
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
#elif defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_CLBLAST)
// zT = y * xT
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
ne11, ne01, ne10,
@ -8270,12 +8262,6 @@ static void ggml_compute_forward_mul_mat_f32(
#endif
}
}
#if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
ggml_cuda_pool_free(d_X, x_size);
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
#endif
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
return;
@ -8405,7 +8391,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_CUBLAS)
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
}
return;
}
#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
GGML_ASSERT(nb10 == sizeof(float));
@ -8421,37 +8416,8 @@ static void ggml_compute_forward_mul_mat_f16_f32(
return;
}
#if defined(GGML_USE_CUBLAS)
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne00;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
size_t x_size, y_size, d_size;
ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
#endif
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
#if defined(GGML_USE_CUBLAS)
// copy src0 while converting src1
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
// with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + (ne11 * ne10) * (i03 * ne02 + i02);
{
size_t id = 0;
for (int64_t i01 = 0; i01 < ne11; ++i01) {
for (int64_t i00 = 0; i00 < ne10; ++i00) {
wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
}
}
assert(id*sizeof(ggml_fp16_t) <= params->wsize);
}
#else
float * const wdata = params->wdata;
{
size_t id = 0;
@ -8463,28 +8429,8 @@ static void ggml_compute_forward_mul_mat_f16_f32(
assert(id*sizeof(float) <= params->wsize);
}
#endif
#if defined(GGML_USE_CUBLAS)
const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
// copy data to device
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
// compute
CUBLAS_CHECK(
cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, d_X, CUDA_R_16F, ne00,
d_Y, CUDA_R_16F, ne10,
&beta, d_D, CUDA_R_32F, ne01,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT));
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
#elif defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_CLBLAST)
const float * x = wdata;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@ -8513,12 +8459,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
}
}
#if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
ggml_cuda_pool_free(d_X, x_size);
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
#endif
/*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
return;
@ -8671,7 +8611,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_CUBLAS)
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
}
return;
}
#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
if (params->ith != 0) {
return;
@ -8685,25 +8634,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
return;
}
#if defined(GGML_USE_CUBLAS)
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne00;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
size_t x_size, y_size, d_size, q_size;
float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
void * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type);
GGML_ASSERT(dequantize_row_q_cuda != NULL);
#else
float * const wdata = params->wdata;
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
#endif
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
@ -8711,14 +8643,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
#if defined(GGML_USE_CUBLAS)
// copy and dequantize on device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream2));
dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2));
#elif defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_CLBLAST)
const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
#else
{
@ -8734,24 +8659,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
const float * x = wdata;
#endif
#if defined(GGML_USE_CUBLAS)
// copy data to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
// wait for dequantization
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0));
// compute
CUBLAS_CHECK(
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, d_X, ne00,
d_Y, ne10,
&beta, d_D, ne01));
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
#elif defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_CLBLAST)
// zT = y * xT
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
ne11, ne01, ne10,
@ -8769,13 +8677,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
}
}
#if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
ggml_cuda_pool_free(d_X, x_size);
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
ggml_cuda_pool_free(d_Q, q_size);
#endif
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
return;
@ -11759,18 +11660,21 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
size_t cur = 0;
#if defined(GGML_USE_CUBLAS)
if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
node->n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
}
else
#endif
if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
node->n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
#if defined(GGML_USE_CUBLAS)
// with cuBLAS, we need memory for the full 3D / 4D data of src1
cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
#else
// here we need memory just for single 2D matrix from src0
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
#endif
} else {
cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
}
@ -11779,13 +11683,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
#endif
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
cur = 0;
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
node->n_tasks = 1;
}
#endif
} else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
node->n_tasks = 1;
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);

@ -197,6 +197,14 @@
#define GGML_MAX_OPT 4
#define GGML_DEFAULT_N_THREADS 4
#define GGML_ASSERT(x) \
do { \
if (!(x)) { \
fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
abort(); \
} \
} while (0)
#ifdef __cplusplus
extern "C" {
#endif
@ -212,6 +220,9 @@ extern "C" {
GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x);
GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n);
GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n);
struct ggml_object;
struct ggml_context;

Loading…
Cancel
Save