@ -27,6 +27,7 @@
# include <thread>
# include <thread>
# include <atomic>
# include <atomic>
# include <mutex>
# include <mutex>
# include <sstream>
# define LLAMA_USE_SCRATCH
# define LLAMA_USE_SCRATCH
# define LLAMA_MAX_SCRATCH_BUFFERS 16
# define LLAMA_MAX_SCRATCH_BUFFERS 16
@ -1787,7 +1788,7 @@ struct llama_context * llama_init_from_file(
if ( params . logits_all ) {
if ( params . logits_all ) {
ctx - > logits . reserve ( hparams . n_ctx * hparams . n_vocab ) ;
ctx - > logits . reserve ( hparams . n_ctx * hparams . n_vocab ) ;
} else {
} else {
ctx - > logits . reserve ( hparams . n_ ctx ) ;
ctx - > logits . reserve ( hparams . n_ vocab ) ;
}
}
if ( params . embedding ) {
if ( params . embedding ) {
@ -2252,3 +2253,122 @@ const char * llama_print_system_info(void) {
// Returns a mutable reference to ctx->model.tensors_by_name, a vector of
// (name, ggml_tensor pointer) pairs. No copy is made.
// NOTE(review): every line of this definition appears twice in this view, which looks like
// unified-diff context residue rather than real source - confirm against the actual file.
std : : vector < std : : pair < std : : string , struct ggml_tensor * > > & llama_internal_get_tensor_map ( struct llama_context * ctx ) {
std : : vector < std : : pair < std : : string , struct ggml_tensor * > > & llama_internal_get_tensor_map ( struct llama_context * ctx ) {
return ctx - > model . tensors_by_name ;
return ctx - > model . tensors_by_name ;
}
}
// Returns the size of the state
size_t llama_get_state_size ( struct llama_context * ctx ) {
const size_t s_bool = sizeof ( int32_t ) ;
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
// for reference, std::mt19937(1337) serializes to 6701 bytes.
const size_t s_rng_size = sizeof ( size_t ) ;
const size_t s_rng = 64 * 1024 ;
const size_t s_logits_capacity = sizeof ( size_t ) ;
const size_t s_logits_size = sizeof ( size_t ) ;
const size_t s_logits = ctx - > logits . capacity ( ) * sizeof ( float ) ;
const size_t s_embedding_size = sizeof ( size_t ) ;
const size_t s_embedding = ctx - > embedding . size ( ) * sizeof ( float ) ;
const size_t s_kv_size = sizeof ( size_t ) ;
const size_t s_kv_ntok = sizeof ( int ) ;
const size_t s_kv = llama_get_kv_cache_size ( ctx ) ;
const size_t s_total = (
+ s_rng_size
+ s_rng
+ s_logits_capacity
+ s_logits_size
+ s_logits
+ s_embedding_size
+ s_embedding
+ s_kv_size
+ s_kv_ntok
+ s_kv
) ;
return s_total ;
}
// Copies the state to the specified destination address
size_t llama_copy_state_data ( struct llama_context * ctx , uint8_t * dest ) {
std : : stringstream rng_ss ;
rng_ss < < ctx - > rng ;
const size_t rng_size = rng_ss . str ( ) . size ( ) ;
char rng_buf [ 64 * 1024 ] ;
memset ( & rng_buf [ 0 ] , 0 , 64 * 1024 ) ;
memcpy ( & rng_buf [ 0 ] , rng_ss . str ( ) . data ( ) , rng_ss . str ( ) . size ( ) ) ;
const size_t logits_capacity = ctx - > logits . capacity ( ) ;
const size_t logits_size = ctx - > logits . size ( ) ;
const size_t embedding_size = ctx - > embedding . size ( ) ;
const size_t kv_size = llama_get_kv_cache_size ( ctx ) ;
const int kv_ntok = llama_get_kv_cache_token_count ( ctx ) ;
uint8_t * out = dest ;
memcpy ( out , & rng_size , sizeof ( size_t ) ) ; out + = sizeof ( size_t ) ;
memcpy ( out , & rng_buf [ 0 ] , 64 * 1024 ) ; out + = 64 * 1024 ;
memcpy ( out , & logits_capacity , sizeof ( size_t ) ) ; out + = sizeof ( size_t ) ;
memcpy ( out , & logits_size , sizeof ( size_t ) ) ; out + = sizeof ( size_t ) ;
if ( logits_size ) {
memcpy ( out , ctx - > logits . data ( ) , logits_size * sizeof ( float ) ) ;
}
out + = logits_capacity * sizeof ( float ) ;
memcpy ( out , & embedding_size , sizeof ( size_t ) ) ; out + = sizeof ( size_t ) ;
if ( embedding_size ) {
memcpy ( out , ctx - > embedding . data ( ) , embedding_size * sizeof ( float ) ) ; out + = embedding_size * sizeof ( float ) ;
}
memcpy ( out , & kv_size , sizeof ( size_t ) ) ; out + = sizeof ( size_t ) ;
memcpy ( out , & kv_ntok , sizeof ( int ) ) ; out + = sizeof ( int ) ;
if ( kv_size ) {
memcpy ( out , llama_get_kv_cache ( ctx ) , kv_size ) ; out + = kv_size ;
}
const size_t written = out - dest ;
const size_t expected = llama_get_state_size ( ctx ) ;
LLAMA_ASSERT ( written = = expected ) ;
return written ;
}
// Sets the state reading from the specified source address
size_t llama_set_state_data ( struct llama_context * ctx , const uint8_t * src ) {
size_t rng_size ;
char rng_buf [ 64 * 1024 ] ;
std : : stringstream rng_ss ;
const uint8_t * in = src ;
memcpy ( & rng_size , in , sizeof ( size_t ) ) ; in + = sizeof ( size_t ) ;
memcpy ( & rng_buf [ 0 ] , in , 64 * 1024 ) ; in + = 64 * 1024 ;
rng_ss . str ( std : : string ( & rng_buf [ 0 ] , rng_size ) ) ;
rng_ss > > ctx - > rng ;
LLAMA_ASSERT ( rng_ss . fail ( ) = = false ) ;
size_t logits_capacity ;
size_t logits_size ;
size_t embedding_size ;
size_t kv_size ;
int kv_ntok ;
memcpy ( & logits_capacity , in , sizeof ( size_t ) ) ; in + = sizeof ( size_t ) ;
memcpy ( & logits_size , in , sizeof ( size_t ) ) ; in + = sizeof ( size_t ) ;
LLAMA_ASSERT ( ctx - > logits . capacity ( ) = = logits_capacity ) ;
if ( logits_size ) {
ctx - > logits . resize ( logits_size ) ;
memcpy ( ctx - > logits . data ( ) , in , logits_size * sizeof ( float ) ) ;
}
in + = logits_capacity * sizeof ( float ) ;
memcpy ( & embedding_size , in , sizeof ( size_t ) ) ; in + = sizeof ( size_t ) ;
LLAMA_ASSERT ( ctx - > embedding . capacity ( ) = = embedding_size ) ;
if ( embedding_size ) {
memcpy ( ctx - > embedding . data ( ) , in , embedding_size * sizeof ( float ) ) ;
in + = embedding_size * sizeof ( float ) ;
}
memcpy ( & kv_size , in , sizeof ( size_t ) ) ; in + = sizeof ( size_t ) ;
memcpy ( & kv_ntok , in , sizeof ( int ) ) ; in + = sizeof ( int ) ;
if ( kv_size ) {
LLAMA_ASSERT ( ctx - > model . kv_self . buf . size = = kv_size ) ;
void * k_data = ctx - > model . kv_self . k - > data ; // remember data pointers
void * v_data = ctx - > model . kv_self . v - > data ; // because their value is stored in buf and overwritten by memcpy
memcpy ( ctx - > model . kv_self . buf . addr , in , kv_size ) ;
ctx - > model . kv_self . k - > data = k_data ; // restore correct data pointers
ctx - > model . kv_self . v - > data = v_data ;
in + = kv_size ;
}
ctx - > model . kv_self . n = kv_ntok ;
const size_t nread = in - src ;
const size_t expected = llama_get_state_size ( ctx ) ;
LLAMA_ASSERT ( nread = = expected ) ;
return nread ;
}