diff --git a/convert-gptq-to-ggml.py b/convert-gptq-to-ggml.py
new file mode 100644
index 0000000..7fccb4d
--- /dev/null
+++ b/convert-gptq-to-ggml.py
@@ -0,0 +1,172 @@
+# Convert a GPTQ quantized LLaMA model to a ggml compatible file
+# Based on: https://github.com/qwopqwop200/GPTQ-for-LLaMa
+#
+import os
+import re
+import sys
+import json
+import struct
+import numpy as np
+import torch
+from sentencepiece import SentencePieceProcessor
+
+if len(sys.argv) != 4:
+    print("Usage: convert-gptq-to-ggml.py llamaXXb-4bit.pt tokenizer.model out.bin\n")
+    sys.exit(1)
+
+fname_model = sys.argv[1]
+fname_tokenizer = sys.argv[2]
+dir_out = sys.argv[3]
+
+model = torch.load(fname_model, map_location="cpu")
+
+n_vocab, n_embd = model['model.embed_tokens.weight'].shape
+n_layer = 1 + max(int(m.group(1)) for name in model
+                  if (m := re.match(r'model\.layers\.([0-9]+)', name)))
+
+# hardcoded:
+n_mult = 256
+n_head = {32: 32, 40: 40, 60: 52, 80: 64}[n_layer]
+
+tokenizer = SentencePieceProcessor(fname_tokenizer)
+
+assert tokenizer.vocab_size() == n_vocab
+
+fname_out = sys.argv[3]
+
+fout = open(fname_out, "wb")
+
+fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
+fout.write(struct.pack("i", n_vocab))
+fout.write(struct.pack("i", n_embd))
+fout.write(struct.pack("i", n_mult))
+fout.write(struct.pack("i", n_head))
+fout.write(struct.pack("i", n_layer))
+fout.write(struct.pack("i", n_embd // n_head)) # rot (obsolete)
+fout.write(struct.pack("i", 4))
+
+
+# This loop unchanged from convert-pth-to-ggml.py:
+for i in range(tokenizer.vocab_size()):
+    if tokenizer.is_unknown(i):
+        # "<unk>" token (translated as ??)
+        text = " \u2047 ".encode("utf-8")
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+    elif tokenizer.is_control(i):
+        # "<s>"/"</s>" tokens
+        fout.write(struct.pack("i", 0))
+    elif tokenizer.is_byte(i):
+        # "<U+XX>" tokens (which may be invalid UTF-8)
+        piece = tokenizer.id_to_piece(i)
+        if len(piece) != 6:
+            print("Invalid token: " + piece)
+            sys.exit(1)
+        byte_value = int(piece[3:-1], 16)
+        fout.write(struct.pack("i", 1))
+        fout.write(struct.pack("B", byte_value))
+    else:
+        # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces.
+        text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+
+def write_header(shape, dst_name, ftype_cur):
+    sname = dst_name.encode('utf-8')
+    fout.write(struct.pack("iii", len(shape), len(sname), ftype_cur))
+    fout.write(struct.pack("i" * len(shape), *shape[::-1]))
+    fout.write(sname)
+
+def convert_non_q4(src_name, dst_name):
+    v = model[src_name]
+    shape = v.shape
+    print("Processing non-Q4 variable: " + src_name + " with shape: ", shape, " and type: ", v.dtype)
+    if len(shape) == 1:
+        print("  Converting to float32")
+        v = v.to(torch.float32)
+
+    ftype_cur = {torch.float16: 1, torch.float32: 0}[v.dtype]
+
+    # header
+    write_header(shape, dst_name, ftype_cur)
+
+    # data
+    v.numpy().tofile(fout)
+
+def convert_q4(src_name, dst_name, permute=False):
+    zeros = model[f"{src_name}.zeros"].numpy()
+    scales = model[f"{src_name}.scales"].numpy()
+    bias = model[f"{src_name}.bias"].numpy()
+    qweight = model[f"{src_name}.qweight"].numpy().T # transpose
+
+    # Q4_1 does not support bias; good thing the bias is always all zeros.
+    assert not np.any(bias)
+
+    # Each int32 item is actually 8 int4 items packed together, and it's transposed.
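+    # (The packed nibbles are copied through unchanged further below; this assumes
+    #  GPTQ's low-nibble-first packing matches ggml's Q4_1 layout, where each
+    #  weight is later reconstructed as scale * int4 + addend.)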
+    shape = (qweight.shape[0], qweight.shape[1] * 8)
+
+    print("Processing Q4 variable: " + src_name + " with shape: ", shape)
+
+    # The output format has the int4 weights in groups of 32 rather than 8.
+    # It looks like this:
+    # For each row:
+    #   For each group of 32 columns:
+    #     - scale (float32, 4 bytes)
+    #     - addend (float32, 4 bytes)
+    #     - weights (int4 * 32, 16 bytes)
+    # Note that in the input, the scales and addends are shared between all
+    # the columns in a row, so we end up wasting quite a bit of memory with
+    # repeated scales and addends.
+
+    addends = -zeros # flip sign
+
+    # Since the output format is mixed between integers and floats, we have
+    # to hackily view the floats as int32s just so numpy will let us
+    # concatenate them.
+    addends_view = addends.view(dtype=np.int32)
+    scales_view = scales.view(dtype=np.int32)
+
+    # Split into groups of 4 columns (i.e. 32 columns of quantized data):
+    grouped = qweight.reshape([qweight.shape[0], qweight.shape[1] // 4, 4])
+
+    # Repeat addends and scales:
+    addends_rep = np.atleast_3d(addends_view).repeat(grouped.shape[1], axis=1)
+    scales_rep = np.atleast_3d(scales_view).repeat(grouped.shape[1], axis=1)
+
+    blob = np.concatenate([scales_rep, addends_rep, grouped], axis=2, casting='no')
+
+    if permute:
+        # Permute some rows to undo the permutation done by convert_llama_weights_to_hf.py.
+        # This can be done after the above conversion because it doesn't affect column order/layout.
+        blob = (blob.reshape(n_head, 2, shape[0] // n_head // 2, *blob.shape[1:])
+                    .swapaxes(1, 2)
+                    .reshape(blob.shape))
+
+    # header
+    write_header(shape, dst_name, 3) # ftype = Q4_1
+
+    # data
+    blob.tofile(fout)
+
+convert_non_q4("model.embed_tokens.weight", "tok_embeddings.weight")
+convert_non_q4("model.norm.weight", "norm.weight")
+convert_non_q4("lm_head.weight", "output.weight")
+
+for i in range(n_layer):
+    convert_q4(f"model.layers.{i}.self_attn.q_proj", f"layers.{i}.attention.wq.weight", permute=True)
+    convert_q4(f"model.layers.{i}.self_attn.k_proj", f"layers.{i}.attention.wk.weight", permute=True)
+    convert_q4(f"model.layers.{i}.self_attn.v_proj", f"layers.{i}.attention.wv.weight")
+    convert_q4(f"model.layers.{i}.self_attn.o_proj", f"layers.{i}.attention.wo.weight")
+
+    convert_q4(f"model.layers.{i}.mlp.gate_proj", f"layers.{i}.feed_forward.w1.weight")
+    convert_q4(f"model.layers.{i}.mlp.down_proj", f"layers.{i}.feed_forward.w2.weight")
+    convert_q4(f"model.layers.{i}.mlp.up_proj", f"layers.{i}.feed_forward.w3.weight")
+
+    convert_non_q4(f"model.layers.{i}.input_layernorm.weight", f"layers.{i}.attention_norm.weight")
+    convert_non_q4(f"model.layers.{i}.post_attention_layernorm.weight", f"layers.{i}.ffn_norm.weight")
+
+
+fout.close()
+
+print("Done. Output file: " + fname_out)
+print("")
diff --git a/main.cpp b/main.cpp
index 9f46d56..4b220c8 100644
--- a/main.cpp
+++ b/main.cpp
@@ -157,6 +157,12 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca
         n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
     }
 
+    // temp warning to tell the user to use "--n_parts"
+    if (hparams.f16 == 4 && n_parts != 1) {
+        fprintf(stderr, "%s: GPTQ model detected - are you sure n_parts should be %d? we normally expect it to be 1\n", __func__, n_parts);
we normally expect it to be 1\n", __func__, n_parts); + fprintf(stderr, "%s: use '--n_parts 1' if necessary\n", __func__); + } + fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); fprintf(stderr, "%s: n_ctx = %d\n", __func__, hparams.n_ctx); fprintf(stderr, "%s: n_embd = %d\n", __func__, hparams.n_embd); @@ -198,12 +204,14 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca // for the big tensors, we have the option to store the data in 16-bit floats or quantized // in order to save memory and also to speed up the computation - ggml_type wtype = GGML_TYPE_COUNT; + // wtype is for per-layer weights, while vtype is for other weights + ggml_type wtype, vtype; switch (model.hparams.f16) { - case 0: wtype = GGML_TYPE_F32; break; - case 1: wtype = GGML_TYPE_F16; break; - case 2: wtype = GGML_TYPE_Q4_0; break; - case 3: wtype = GGML_TYPE_Q4_1; break; + case 0: wtype = vtype = GGML_TYPE_F32; break; + case 1: wtype = vtype = GGML_TYPE_F16; break; + case 2: wtype = vtype = GGML_TYPE_Q4_0; break; + case 3: wtype = vtype = GGML_TYPE_Q4_1; break; + case 4: wtype = GGML_TYPE_Q4_1; vtype = GGML_TYPE_F16; break; default: { fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n", @@ -224,11 +232,11 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca const int n_ctx = hparams.n_ctx; const int n_vocab = hparams.n_vocab; - ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // tok_embeddings + ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // tok_embeddings ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm - ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // output + ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // output ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm @@ -275,10 +283,10 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca model.layers.resize(n_layer); - model.tok_embeddings = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab); + model.tok_embeddings = ggml_new_tensor_2d(ctx, vtype, n_embd, n_vocab); model.norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); - model.output = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab); + model.output = ggml_new_tensor_2d(ctx, vtype, n_embd, n_vocab); // map by name model.tensors["tok_embeddings.weight"] = model.tok_embeddings;