diff --git a/llama.cpp b/llama.cpp
index 447fa91..14de611 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -267,14 +267,16 @@ static void kv_cache_free(struct llama_kv_cache & cache) {
 
 struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
-        /*.n_ctx      =*/ 512,
-        /*.n_parts    =*/ -1,
-        /*.seed       =*/ 0,
-        /*.f16_kv     =*/ false,
-        /*.logits_all =*/ false,
-        /*.vocab_only =*/ false,
-        /*.use_mlock  =*/ false,
-        /*.embedding  =*/ false,
+        /*.n_ctx                       =*/ 512,
+        /*.n_parts                     =*/ -1,
+        /*.seed                        =*/ 0,
+        /*.f16_kv                      =*/ false,
+        /*.logits_all                  =*/ false,
+        /*.vocab_only                  =*/ false,
+        /*.use_mlock                   =*/ false,
+        /*.embedding                   =*/ false,
+        /*.progress_callback           =*/ nullptr,
+        /*.progress_callback_user_data =*/ nullptr,
     };
 
     return result;
@@ -290,7 +292,9 @@ static bool llama_model_load(
         int n_ctx,
         int n_parts,
         ggml_type memory_type,
-        bool vocab_only) {
+        bool vocab_only,
+        llama_progress_callback progress_callback,
+        void *progress_callback_user_data) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     const int64_t t_start_us = ggml_time_us();
@@ -576,6 +580,10 @@ static bool llama_model_load(
 
     std::vector<uint8_t> tmp;
 
+    if (progress_callback) {
+        progress_callback(0.0, progress_callback_user_data);
+    }
+
     for (int i = 0; i < n_parts; ++i) {
         const int part_id = i;
         //const int part_id = n_parts - i - 1;
@@ -589,6 +597,10 @@ static bool llama_model_load(
         fin = std::ifstream(fname_part, std::ios::binary);
         fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
+
+        fin.seekg(0, fin.end);
+        const size_t file_size = fin.tellg();
+
         fin.seekg(file_offset);
 
         // load weights
@@ -764,6 +776,11 @@ static bool llama_model_load(
                 model.n_loaded++;
 
                 // progress
+                if (progress_callback) {
+                    double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset);
+                    double current_progress = (double(i) + current_file_progress) / double(n_parts);
+                    progress_callback(current_progress, progress_callback_user_data);
+                }
                 if (model.n_loaded % 8 == 0) {
                     fprintf(stderr, ".");
                     fflush(stderr);
@@ -786,6 +803,10 @@ static bool llama_model_load(
 
     lctx.t_load_us = ggml_time_us() - t_start_us;
 
+    if (progress_callback) {
+        progress_callback(1.0, progress_callback_user_data);
+    }
+
     return true;
 }
 
@@ -1617,7 +1638,8 @@ struct llama_context * llama_init_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
     if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type,
-                          params.vocab_only)) {
+                          params.vocab_only, params.progress_callback,
+                          params.progress_callback_user_data)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         llama_free(ctx);
         return nullptr;
diff --git a/llama.h b/llama.h
index 57123db..827abc1 100644
--- a/llama.h
+++ b/llama.h
@@ -45,6 +45,8 @@ extern "C" {
 
     } llama_token_data;
 
+    typedef void (*llama_progress_callback)(double progress, void *ctx);
+
     struct llama_context_params {
         int n_ctx;   // text context
         int n_parts; // -1 for default
@@ -55,6 +57,11 @@ extern "C" {
         bool vocab_only; // only load the vocabulary, no weights
         bool use_mlock;  // force system to keep model in RAM
         bool embedding;  // embedding mode only
+
+        // called with a progress value between 0 and 1, pass NULL to disable
+        llama_progress_callback progress_callback;
+        // context pointer passed to the progress callback
+        void * progress_callback_user_data;
     };
 
     LLAMA_API struct llama_context_params llama_context_default_params();
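
Usage note: a minimal sketch of how a caller can wire up the new callback through llama_context_params. It relies only on the entry points visible in the patch above (llama_context_default_params, llama_init_from_file, llama_free, and the progress_callback fields); the print_progress helper and its output format are illustrative and not part of this patch.

// progress_example.c -- illustrative only, not part of this patch
#include <stdio.h>

#include "llama.h"

// Hypothetical callback: print a coarse percentage as the model loads.
// Matches the llama_progress_callback signature added in llama.h.
static void print_progress(double progress, void * user_data) {
    (void) user_data; // no extra state needed for this example
    fprintf(stderr, "\rloading: %3.0f%%", progress * 100.0);
    if (progress >= 1.0) {
        fprintf(stderr, "\n");
    }
}

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model path>\n", argv[0]);
        return 1;
    }

    // Start from the library defaults, then opt in to progress reporting.
    struct llama_context_params params = llama_context_default_params();
    params.progress_callback           = print_progress;
    params.progress_callback_user_data = NULL; // nothing to pass through here

    struct llama_context * ctx = llama_init_from_file(argv[1], params);
    if (ctx == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    llama_free(ctx);
    return 0;
}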