@@ -53,12 +53,14 @@ class LlamaState:
     def __init__(
         self,
         eval_tokens: Deque[llama_cpp.llama_token],
-        eval_logits: Deque[List[float]],
+        eval_logits: Deque[List[llama_cpp.c_float]],
         llama_state,
+        llama_state_size: llama_cpp.c_size_t,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
         self.llama_state = llama_state
+        self.llama_state_size = llama_state_size
 
 
 class Llama:
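Note (not part of the diff): the new llama_state_size field records how many bytes llama_copy_state_data actually wrote, so load_state can validate against the trimmed buffer instead of the worst-case size reported by llama_get_state_size (see the save_state/load_state hunk below). A minimal sketch of reading the field after a save, assuming a local GGML model file; the path and prompt are placeholders:

# Hypothetical usage sketch: inspect the trimmed state size after save_state().
from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder model path
llm.eval(llm.tokenize(b"Hello, world"))               # populate eval_tokens and the KV cache
state = llm.save_state()
print(int(state.llama_state_size))                    # bytes actually copied, at most llama_get_state_size()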
@@ -394,7 +396,7 @@ def generate(
             and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
         ):
             if self.verbose:
-                print("generate cache hit", file=sys.stderr)
+                print("Llama.generate: cache hit", file=sys.stderr)
             reset = False
             tokens = tokens[len(self.eval_tokens) :]
 
@@ -516,7 +518,7 @@ def _create_completion(
 
         if self.cache and prompt_tokens in self.cache:
             if self.verbose:
-                print("cache hit", file=sys.stderr)
+                print("Llama._create_completion: cache hit", file=sys.stderr)
             self.load_state(self.cache[prompt_tokens])
 
         finish_reason = "length"
@@ -536,7 +538,7 @@ def _create_completion(
             if self.cache and len(completion_tokens) == 0:
                 if prompt_tokens not in self.cache:
                     if self.verbose:
-                        print("cache miss", file=sys.stderr)
+                        print("Llama._create_completion: cache miss", file=sys.stderr)
                     self.cache[prompt_tokens] = self.save_state()
 
             completion_tokens.append(token)
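For the cache-hit and cache-miss messages above to actually appear, a cache has to be attached to the model. A sketch, assuming the LlamaCache helper and the set_cache method present in this version of the library; the model path and prompt are placeholders:

# Hypothetical sketch: attach a cache so a repeated prompt reuses the saved state.
from llama_cpp import Llama, LlamaCache

llm = Llama(model_path="./models/7B/ggml-model.bin", verbose=True)  # placeholder model path
llm.set_cache(LlamaCache())  # assumed wiring for self.cache

llm("Q: What is 1 + 1? A:", max_tokens=8)  # first call: "cache miss", state is saved
llm("Q: What is 1 + 1? A:", max_tokens=8)  # second call: "cache hit", state is loaded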
@@ -950,19 +952,25 @@ def save_state(self) -> LlamaState:
         assert self.ctx is not None
         state_size = llama_cpp.llama_get_state_size(self.ctx)
         llama_state = (llama_cpp.c_uint8 * int(state_size))()
-        if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
+        n_bytes = llama_cpp.llama_copy_state_data(self.ctx, llama_state)
+        if int(n_bytes) > int(state_size):
             raise RuntimeError("Failed to copy llama state data")
+        llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
+        llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
+        if self.verbose:
+            print(f"Llama.save_state: saving {n_bytes} bytes of llama state", file=sys.stderr)
         return LlamaState(
             eval_tokens=self.eval_tokens.copy(),
             eval_logits=self.eval_logits.copy(),
-            llama_state=llama_state,
+            llama_state=llama_state_compact,
+            llama_state_size=n_bytes,
         )
 
     def load_state(self, state: LlamaState) -> None:
         assert self.ctx is not None
         self.eval_tokens = state.eval_tokens.copy()
         self.eval_logits = state.eval_logits.copy()
-        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        state_size = state.llama_state_size
         if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")