diff --git a/convert-lora-to-ggml.py b/convert-lora-to-ggml.py new file mode 100644 index 0000000..8a2085c --- /dev/null +++ b/convert-lora-to-ggml.py @@ -0,0 +1,124 @@ +import json +import os +import re +import struct +import sys +from typing import Any, Dict, Sequence, TextIO + +import torch + +from convert import DATA_TYPE_TO_FTYPE, NUMPY_TYPE_TO_DATA_TYPE, DataType + +HF_SUBLAYER_TO_GGML = { + "self_attn.q_proj": "attention.wq", + "self_attn.k_proj": "attention.wk", + "self_attn.v_proj": "attention.wv", + "self_attn.o_proj": "attention.wo", + "mlp.gate_proj": "feed_forward.w1", + "mlp.down_proj": "feed_forward.w2", + "mlp.up_proj": "feed_forward.w3", + "input_layernorm": "attention_norm", + "post_attention_layernorm": "ffn_norm", + # "norm": "norm", + # "embed_tokens": "tok_embeddings", + # "lm_head": "output", +} + + +def translate_tensor_name(t: str) -> str: + match = re.match(r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight", t) + if match: + nn = match.group(1) + sub_layer = match.group(2) + lora_type = match.group(3) + + sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer) + if sub_layer_renamed is None: + print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}") + sys.exit(1) + + output_string = ( + f"layers.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.weight.lora{lora_type}" + ) + return output_string + else: + print(f"Error: unrecognized tensor {t}") + sys.exit(1) + + +def write_file_header(fout: TextIO, params: Dict[str, Any]) -> None: + fout.write(b"ggla"[::-1]) # magic (ggml lora) + fout.write(struct.pack("i", 1)) # file version + fout.write(struct.pack("ii", params["r"], params["lora_alpha"])) + + +def write_tensor_header( + self, name: str, shape: Sequence[int], data_type: DataType +) -> None: + sname = name.encode("utf-8") + fout.write( + struct.pack( + "iii", + len(shape), + len(sname), + DATA_TYPE_TO_FTYPE[NUMPY_TYPE_TO_DATA_TYPE[data_type]], + ) + ) + fout.write(struct.pack("i" * len(shape), *shape[::-1])) + fout.write(sname) + fout.seek((fout.tell() + 31) & -32) + + +if len(sys.argv) != 2: + print(f"Usage: python {sys.argv[0]} ") + print( + "Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'" + ) + sys.exit(1) + +input_json = os.path.join(sys.argv[1], "adapter_config.json") +input_model = os.path.join(sys.argv[1], "adapter_model.bin") +output_path = os.path.join(sys.argv[1], "ggml-adapter-model.bin") + +model = torch.load(input_model, map_location="cpu") + +with open(input_json, "r") as f: + params = json.load(f) + +if params["peft_type"] != "LORA": + print(f"Error: unsupported adapter type {params['peft_type']}, expected LORA") + sys.exit(1) + +if params["fan_in_fan_out"] == True: + print("Error: param fan_in_fan_out is not supported") + sys.exit(1) + +if params["bias"] is not None and params["bias"] != "none": + print("Error: param bias is not supported") + sys.exit(1) + +# TODO: these seem to be layers that have been trained but without lora. +# doesn't seem widely used but eventually should be supported +if params["modules_to_save"] is not None and len(params["modules_to_save"]) > 0: + print("Error: param modules_to_save is not supported") + sys.exit(1) + +with open(output_path, "wb") as fout: + fout.truncate() + + write_file_header(fout, params) + for k, v in model.items(): + if k.endswith("lora_A.weight"): + if v.dtype != torch.float16 and v.dtype != torch.float32: + v = v.float() + v = v.T + else: + v = v.float() + + t = v.numpy() + tname = translate_tensor_name(k) + print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB") + write_tensor_header(fout, tname, t.shape, t.dtype) + t.tofile(fout) + +print(f"Converted {input_json} and {input_model} to {output_path}") diff --git a/examples/common.cpp b/examples/common.cpp index 0772dbf..a0b6f10 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -139,6 +139,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.model = argv[i]; + } else if (arg == "--lora") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.lora_adapter = argv[i]; + params.use_mmap = false; + } else if (arg == "--lora-base") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.lora_base = argv[i]; } else if (arg == "-i" || arg == "--interactive") { params.interactive = true; } else if (arg == "--embedding") { @@ -242,6 +255,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { } fprintf(stderr, " --mtest compute maximum memory usage\n"); fprintf(stderr, " --verbose-prompt print prompt before generation\n"); + fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); + fprintf(stderr, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n"); fprintf(stderr, " -m FNAME, --model FNAME\n"); fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); fprintf(stderr, "\n"); diff --git a/examples/common.h b/examples/common.h index 1ea6f74..cbbc2df 100644 --- a/examples/common.h +++ b/examples/common.h @@ -31,11 +31,12 @@ struct gpt_params { std::string model = "models/lamma-7B/ggml-model.bin"; // model path std::string prompt = ""; - std::string input_prefix = ""; // string to prefix user inputs with - - + std::string input_prefix = ""; // string to prefix user inputs with std::vector antiprompt; // string upon seeing which more user input is prompted + std::string lora_adapter = ""; // lora adapter path + std::string lora_base = ""; // base model path for the lora adapter + bool memory_f16 = true; // use f16 instead of f32 for memory kv bool random_prompt = false; // do not randomize prompt if none provided bool use_color = false; // use color to distinguish generations and inputs diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3e4b003..b7b3c41 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -114,6 +114,17 @@ int main(int argc, char ** argv) { } } + if (!params.lora_adapter.empty()) { + int err = llama_apply_lora_from_file(ctx, + params.lora_adapter.c_str(), + params.lora_base.empty() ? NULL : params.lora_base.c_str(), + params.n_threads); + if (err != 0) { + fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); + return 1; + } + } + // print system information { fprintf(stderr, "\n"); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 19449e1..80792ea 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -134,6 +134,17 @@ int main(int argc, char ** argv) { } } + if (!params.lora_adapter.empty()) { + int err = llama_apply_lora_from_file(ctx, + params.lora_adapter.c_str(), + params.lora_base.empty() ? NULL : params.lora_base.c_str(), + params.n_threads); + if (err != 0) { + fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); + return 1; + } + } + // print system information { fprintf(stderr, "\n"); diff --git a/ggml.c b/ggml.c index 995a2fa..acdba03 100644 --- a/ggml.c +++ b/ggml.c @@ -1420,6 +1420,34 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in #endif } +static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); + +static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { + [GGML_TYPE_Q4_0] = { + .dequantize_row_q = dequantize_row_q4_0, + .quantize_row_q = quantize_row_q4_0, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference, + .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = ggml_vec_dot_q4_0_q8_0, + }, + [GGML_TYPE_Q4_1] = { + .dequantize_row_q = dequantize_row_q4_1, + .quantize_row_q = quantize_row_q4_1, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference, + .quantize_row_q_dot = quantize_row_q4_1, + .vec_dot_q = ggml_vec_dot_q4_1, + }, + // TODO: GGML_TYPE_Q8_0 +}; + +// For internal test use +quantize_fns_t ggml_internal_get_quantize_fn(size_t i) { + GGML_ASSERT(i < GGML_TYPE_COUNT); + return quantize_fns[i]; +} + + // // simd mappings // @@ -5588,6 +5616,26 @@ static void ggml_compute_forward_dup_f16( } } } + } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) { + quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q; + size_t id = 0; + uint8_t * dst_ptr = (uint8_t *) dst->data; + size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]); + float * src0_f32 = (float *) params->wdata; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + // convert to f32 and quantize + for (int i00 = 0; i00 < ne00; i00++) { + src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]); + } + quantize_row_q(src0_f32, dst_ptr + id, ne00); + id += dst_row_size; + } + } + } } else { GGML_ASSERT(false); // TODO: implement } @@ -5780,6 +5828,21 @@ static void ggml_compute_forward_dup_f32( } } } + } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) { + quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q; + size_t id = 0; + uint8_t * dst_ptr = (uint8_t *) dst->data; + size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]); + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + quantize_row_q(src0_ptr, dst_ptr + id, ne00); + id += dst_row_size; + } + } + } } else { GGML_ASSERT(false); // TODO: implement } @@ -5968,6 +6031,212 @@ static void ggml_compute_forward_add_f32( } } +static void ggml_compute_forward_add_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + + const size_t nb10 = src1->nb[0]; + const size_t nb11 = src1->nb[1]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + if (nb10 == sizeof(float)) { + for (int j = ith; j < n; j += nth) { + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01); + for (int i = 0; i < nc; i++) { + float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10); + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr); + } + } + } + else { + // src1 is not contiguous + GGML_ASSERT(false); + } +} + +static void ggml_compute_forward_add_f16_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + + const size_t nb10 = src1->nb[0]; + const size_t nb11 = src1->nb[1]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + if (nb10 == sizeof(ggml_fp16_t)) { + for (int j = ith; j < n; j += nth) { + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01); + for (int i = 0; i < nc; i++) { + ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10); + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr)); + } + } + } + else { + // src1 is not contiguous + GGML_ASSERT(false); + } +} + +static void ggml_compute_forward_add_q_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + //const int64_t ne10 = src1->ne[0]; + //const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + //const int64_t ne0 = dst->ne[0]; + //const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + const enum ggml_type type = src0->type; + dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; + quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]); + GGML_ASSERT(nb10 == sizeof(float)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1); + GGML_ASSERT(dst->type == src0->type); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * wdata = (float*) params->wdata + ne00 * ith; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + // src1 and dst are same shape as src0 => same indices + const int i13 = i03; + const int i12 = i02; + const int i11 = i01; + + const int i3 = i03; + const int i2 = i02; + const int i1 = i01; + + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)); + void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb0)); + + assert(ne00 % 32 == 0); + + // unquantize row from src0 to temp buffer + dequantize_row_q(src0_row, wdata, ne00); + // add src1 + ggml_vec_acc_f32(ne00, wdata, src1_row); + // quantize row to dst + quantize_row_q(wdata, dst_row, ne00); + } +} + static void ggml_compute_forward_add( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -5978,6 +6247,23 @@ static void ggml_compute_forward_add( { ggml_compute_forward_add_f32(params, src0, src1, dst); } break; + case GGML_TYPE_F16: + { + if (src1->type == GGML_TYPE_F16) { + ggml_compute_forward_add_f16_f16(params, src0, src1, dst); + } + else if (src1->type == GGML_TYPE_F32) { + ggml_compute_forward_add_f16_f32(params, src0, src1, dst); + } + else { + GGML_ASSERT(false); + } + } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + { + ggml_compute_forward_add_q_f32(params, src0, src1, dst); + } break; default: { GGML_ASSERT(false); @@ -7257,30 +7543,6 @@ static void ggml_compute_forward_mul_mat_f16_f32( //} } -static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { - [GGML_TYPE_Q4_0] = { - .dequantize_row_q = dequantize_row_q4_0, - .quantize_row_q = quantize_row_q4_0, - .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference, - .quantize_row_q_dot = quantize_row_q8_0, - .vec_dot_q = ggml_vec_dot_q4_0_q8_0, - }, - [GGML_TYPE_Q4_1] = { - .dequantize_row_q = dequantize_row_q4_1, - .quantize_row_q = quantize_row_q4_1, - .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference, - .quantize_row_q_dot = quantize_row_q4_1, - .vec_dot_q = ggml_vec_dot_q4_1, - }, - // TODO: GGML_TYPE_Q8_0 -}; - -// For internal test use -quantize_fns_t ggml_internal_get_quantize_fn(size_t i) { - GGML_ASSERT(i < GGML_TYPE_COUNT); - return quantize_fns[i]; -} - static void ggml_compute_forward_mul_mat_q_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -10137,13 +10399,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) struct ggml_tensor * node = cgraph->nodes[i]; switch (node->op) { + case GGML_OP_CPY: case GGML_OP_DUP: { node->n_tasks = 1; + + size_t cur = 0; + if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1) { + cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0]; + } + + work_size = MAX(work_size, cur); } break; case GGML_OP_ADD: { node->n_tasks = n_threads; + + size_t cur = 0; + + if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) { + cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads; + } + + work_size = MAX(work_size, cur); } break; case GGML_OP_SUB: case GGML_OP_MUL: @@ -10224,7 +10502,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { node->n_tasks = n_threads; } break; - case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_RESHAPE: case GGML_OP_VIEW: diff --git a/ggml.h b/ggml.h index e693754..59de0cb 100644 --- a/ggml.h +++ b/ggml.h @@ -430,6 +430,12 @@ struct ggml_tensor * ggml_add( struct ggml_tensor * a, struct ggml_tensor * b); + +struct ggml_tensor * ggml_add_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + struct ggml_tensor * ggml_sub( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/llama.cpp b/llama.cpp index cdd8bd1..db71c03 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1,6 +1,8 @@ // Defines fileno on msys: #ifndef _GNU_SOURCE #define _GNU_SOURCE +#include +#include #endif #include "llama_util.h" @@ -633,6 +635,7 @@ struct llama_model_loader { throw format("llama.cpp: tensor '%s' has wrong shape; expected %s, got %s", name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str()); } + return get_tensor_for(lt); } @@ -1774,6 +1777,254 @@ int llama_model_quantize( } } +int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) { + fprintf(stderr, "%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora); + + auto & model = ctx->model; + + const int64_t t_start_lora_us = ggml_time_us(); + + auto fin = std::ifstream(path_lora, std::ios::binary); + if (!fin) { + fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_lora); + return 1; + } + + // verify magic and version + { + uint32_t magic; + fin.read((char *) &magic, sizeof(magic)); + if (magic != 'ggla') { + fprintf(stderr, "%s: bad file magic\n", __func__); + return 1; + } + uint32_t format_version; + fin.read((char *) &format_version, sizeof(format_version)); + + if (format_version != 1) { + fprintf(stderr, "%s: unsupported file version\n", __func__ ); + return 1; + } + } + + int32_t lora_r; + int32_t lora_alpha; + fin.read((char *) &lora_r, sizeof(lora_r)); + fin.read((char *) &lora_alpha, sizeof(lora_alpha)); + float scaling = (float)lora_alpha / (float)lora_r; + + fprintf(stderr, "%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling); + + + // create a temporary ggml context to store the lora tensors + // todo: calculate size from biggest possible tensor + std::vector lora_buf(1024ull * 1024ull * 1024ull); + struct ggml_init_params params; + params.mem_size = lora_buf.size(); + params.mem_buffer = lora_buf.data(); + params.no_alloc = false; + + ggml_context * lora_ctx = ggml_init(params); + std::unordered_map lora_tensors; + + // create a name -> tensor map of the model to accelerate lookups + std::unordered_map model_tensors; + for (auto & kv: model.tensors_by_name) { + model_tensors.insert(kv); + } + + + // load base model + std::unique_ptr model_loader; + ggml_context * base_ctx = NULL; + llama_buffer base_buf; + if (path_base_model) { + fprintf(stderr, "%s: loading base model from '%s'\n", __func__, path_base_model); + model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*vocab_only*/ false)); + + size_t ctx_size, mmapped_size; + model_loader->calc_sizes(&ctx_size, &mmapped_size); + base_buf.resize(ctx_size); + + ggml_init_params base_params; + base_params.mem_size = base_buf.size; + base_params.mem_buffer = base_buf.addr; + base_params.no_alloc = model_loader->use_mmap; + + base_ctx = ggml_init(base_params); + + model_loader->ggml_ctx = base_ctx; + + // maybe this should in llama_model_loader + if (model_loader->use_mmap) { + model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, /* prefetch */ false)); + } + } + + // read tensors and apply + bool warned = false; + int n_tensors = 0; + while (true) { + int32_t n_dims; + int32_t length; + int32_t ftype; + + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ftype), sizeof(ftype)); + if (fin.eof()) { + break; + } + + int32_t ne[2] = { 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + } + + std::string name(length, 0); + fin.read(&name[0], length); + + // check for lora suffix and get the type of tensor + const std::string lora_suffix = ".lora"; + size_t pos = name.rfind(lora_suffix); + if (pos == std::string::npos) { + fprintf(stderr, "%s: error: '%s' is not a lora tensor\n", __func__, name.c_str()); + return 1; + } + + std::string lora_type = name.substr(pos + lora_suffix.length()); + std::string base_name = name; + base_name.erase(pos); + // fprintf(stderr, "%s: %s => %s (lora type %s) ", __func__, name.c_str(),base_name.c_str(), lora_type.c_str()); + + if (model_tensors.find(base_name.data()) == model_tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in lora adapter\n", __func__, name.data()); + return 1; + } + + // create ggml tensor + ggml_type wtype; + switch (ftype) { + case 0: wtype = GGML_TYPE_F32; break; + case 1: wtype = GGML_TYPE_F16; break; + default: + { + fprintf(stderr, "%s: invalid tensor data type '%d'\n", + __func__, ftype); + return false; + } + } + ggml_tensor* lora_tensor; + if (n_dims == 2) { + lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]); + } + else { + fprintf(stderr, "%s: unsupported tensor dimension %d\n", __func__, n_dims); + return 1; + } + + // load tensor data + size_t offset = fin.tellg(); + size_t tensor_data_size = ggml_nbytes(lora_tensor); + offset = (offset + 31) & -32; + fin.seekg(offset); + fin.read((char*)lora_tensor->data, tensor_data_size); + + lora_tensors[name] = lora_tensor; + + // check if we have both A and B tensors and apply + if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() && + lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) { + + ggml_tensor * dest_t = model_tensors[base_name]; + ggml_tensor * base_t; + if (model_loader) { + // load from base model + if (model_loader->tensors_map.name_to_idx.find(base_name) == model_loader->tensors_map.name_to_idx.end()) { + fprintf(stderr, "%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str()); + return 1; + } + size_t idx = model_loader->tensors_map.name_to_idx[base_name]; + llama_load_tensor & lt = model_loader->tensors_map.tensors[idx]; + base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }); + lt.data = (uint8_t *) lt.ggml_tensor->data; + model_loader->load_data_for(lt); + lt.ggml_tensor->data = lt.data; + } + else { + base_t = dest_t; + } + + if (base_t->type == GGML_TYPE_Q4_0 || base_t->type == GGML_TYPE_Q4_1) { + if (!warned) { + fprintf(stderr, "%s: warning: using a lora adapter with a quantized model may result in poor quality, " + "use a f16 or f32 base model with --lora-base\n", __func__); + warned = true; + } + } + + ggml_tensor * loraA = lora_tensors[base_name + ".loraA"]; + ggml_tensor * loraB = lora_tensors[base_name + ".loraB"]; + + if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) { + fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");" + " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]); + return 1; + } + + // w = w + BA*s + ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB); + + if (scaling != 1.0f) { + ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling); + BA = ggml_scale(lora_ctx, BA, scale_tensor); + } + + ggml_tensor * r; + if (base_t == dest_t) { + r = ggml_add_inplace(lora_ctx, dest_t, BA); + } + else { + r = ggml_add(lora_ctx, base_t, BA); + r = ggml_cpy(lora_ctx, r, dest_t); + } + + struct ggml_cgraph gf = ggml_build_forward(r); + gf.n_threads = n_threads; + ggml_graph_compute(lora_ctx, &gf); + + // we won't need these tensors again, reset the context to save memory + ggml_free(lora_ctx); + lora_ctx = ggml_init(params); + lora_tensors.clear(); + + n_tensors++; + if (n_tensors % 4 == 0) + fprintf(stderr, "."); + } + } + + // TODO: this should be in a destructor, it will leak on failure + ggml_free(lora_ctx); + if (base_ctx) { + ggml_free(base_ctx); + } + + const int64_t t_lora_us = ggml_time_us() - t_start_lora_us; + fprintf(stderr, " done (%.2f ms)\n", t_lora_us / 1000.0); + + return 0; +} + +int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) { + try { + return llama_apply_lora_from_file_internal(ctx, path_lora, path_base_model, n_threads); + } catch (const std::string & err) { + fprintf(stderr, "%s: failed to apply lora adapter: %s\n", __func__, err.c_str()); + return 1; + } +} + // Returns the KV cache that will contain the context for the // ongoing prediction with the model. const uint8_t * llama_get_kv_cache(struct llama_context * ctx) { diff --git a/llama.h b/llama.h index 1922175..c35193a 100644 --- a/llama.h +++ b/llama.h @@ -96,6 +96,18 @@ extern "C" { const char * fname_out, enum llama_ftype ftype); + // Apply a LoRA adapter to a loaded model + // path_base_model is the path to a higher quality model to use as a base for + // the layers modified by the adapter. Can be NULL to use the current loaded model. + // The model needs to be reloaded before applying a new adapter, otherwise the adapter + // will be applied on top of the previous one + // Returns 0 on success + LLAMA_API int llama_apply_lora_from_file( + struct llama_context * ctx, + const char * path_lora, + const char * path_base_model, + int n_threads); + // Returns the KV cache that will contain the context for the // ongoing prediction with the model. LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx); diff --git a/llama_util.h b/llama_util.h index c92c0cc..ee9b2a6 100755 --- a/llama_util.h +++ b/llama_util.h @@ -168,7 +168,7 @@ struct llama_mmap { #ifdef _POSIX_MAPPED_FILES static constexpr bool SUPPORTED = true; - llama_mmap(struct llama_file * file) { + llama_mmap(struct llama_file * file, bool prefetch = true) { size = file->size; int fd = fileno(file->fp); int flags = MAP_SHARED; @@ -180,10 +180,12 @@ struct llama_mmap { throw format("mmap failed: %s", strerror(errno)); } - // Advise the kernel to preload the mapped memory - if (madvise(addr, file->size, MADV_WILLNEED)) { - fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n", - strerror(errno)); + if (prefetch) { + // Advise the kernel to preload the mapped memory + if (madvise(addr, file->size, MADV_WILLNEED)) { + fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n", + strerror(errno)); + } } } @@ -193,7 +195,7 @@ struct llama_mmap { #elif defined(_WIN32) static constexpr bool SUPPORTED = true; - llama_mmap(struct llama_file * file) { + llama_mmap(struct llama_file * file, bool prefetch = true) { size = file->size; HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp)); @@ -215,13 +217,15 @@ struct llama_mmap { } #if _WIN32_WINNT >= _WIN32_WINNT_WIN8 - // Advise the kernel to preload the mapped memory - WIN32_MEMORY_RANGE_ENTRY range; - range.VirtualAddress = addr; - range.NumberOfBytes = (SIZE_T)size; - if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) { - fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); + if (prefetch) { + // Advise the kernel to preload the mapped memory + WIN32_MEMORY_RANGE_ENTRY range; + range.VirtualAddress = addr; + range.NumberOfBytes = (SIZE_T)size; + if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) { + fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } } #else #pragma message("warning: You are building for pre-Windows 8; prefetch not supported")