@@ -53,12 +53,14 @@ class LlamaState:
5353 def __init__ (
5454 self ,
5555 eval_tokens : Deque [llama_cpp .llama_token ],
56- eval_logits : Deque [List [float ]],
56+ eval_logits : Deque [List [llama_cpp . c_float ]],
5757 llama_state ,
58+ llama_state_size : llama_cpp .c_size_t ,
5859 ):
5960 self .eval_tokens = eval_tokens
6061 self .eval_logits = eval_logits
6162 self .llama_state = llama_state
63+ self .llama_state_size = llama_state_size
6264
6365
6466class Llama :
@@ -950,19 +952,23 @@ def save_state(self) -> LlamaState:
950952 assert self .ctx is not None
951953 state_size = llama_cpp .llama_get_state_size (self .ctx )
952954 llama_state = (llama_cpp .c_uint8 * int (state_size ))()
953- if llama_cpp .llama_copy_state_data (self .ctx , llama_state ) != state_size :
955+ n_bytes = llama_cpp .llama_copy_state_data (self .ctx , llama_state )
956+ if int (n_bytes ) > int (state_size ):
954957 raise RuntimeError ("Failed to copy llama state data" )
958+ llama_state_compact = (llama_cpp .c_uint8 * int (n_bytes ))()
959+ llama_cpp .ctypes .memmove (llama_state_compact , llama_state , int (n_bytes ))
955960 return LlamaState (
956961 eval_tokens = self .eval_tokens .copy (),
957962 eval_logits = self .eval_logits .copy (),
958- llama_state = llama_state ,
963+ llama_state = llama_state_compact ,
964+ llama_state_size = n_bytes ,
959965 )
960966
961967 def load_state (self , state : LlamaState ) -> None :
962968 assert self .ctx is not None
963969 self .eval_tokens = state .eval_tokens .copy ()
964970 self .eval_logits = state .eval_logits .copy ()
965- state_size = llama_cpp . llama_get_state_size ( self . ctx )
971+ state_size = state . llama_state_size
966972 if llama_cpp .llama_set_state_data (self .ctx , state .llama_state ) != state_size :
967973 raise RuntimeError ("Failed to set llama state data" )
968974
0 commit comments