From 87a6f846d3e929632c45916dd08f1e2a9c72d2a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81sgeir=20Bjarni=20Ingvarsson?= Date: Wed, 26 Apr 2023 20:08:43 +0000 Subject: [PATCH] Allow setting the rng seed after initialization. (#1184) The llama_set_state_data function restores the rng state to what it was at the time llama_copy_state_data was called. But users may want to restore the state and proceed with a different seed. --- llama.cpp | 7 +++++++ llama.h | 3 +++ 2 files changed, 10 insertions(+) diff --git a/llama.cpp b/llama.cpp index 25203c9..8334553 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2082,6 +2082,13 @@ int llama_get_kv_cache_token_count(struct llama_context * ctx) { #define LLAMA_MAX_RNG_STATE 64*1024 +void llama_set_rng_seed(struct llama_context * ctx, int seed) { + if (seed <= 0) { + seed = time(NULL); + } + ctx->rng.seed(seed); +} + // Returns the size of the state size_t llama_get_state_size(struct llama_context * ctx) { // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. diff --git a/llama.h b/llama.h index ab41798..24c48cc 100644 --- a/llama.h +++ b/llama.h @@ -116,6 +116,9 @@ extern "C" { // Returns the number of tokens in the KV cache LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx); + // Sets the current rng seed. + LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed); + // Returns the size in bytes of the state (rng, logits, embedding and kv_cache) LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);