@@ -141,7 +141,9 @@ def __getitem__(self, key: Sequence[int]) -> "LlamaState":
141141 if _key is None :
142142 raise KeyError ("Key not found" )
143143 value : "LlamaState" = self .cache .pop (_key ) # type: ignore
144- self .cache .push (_key , side = "front" ) # type: ignore
144+ # NOTE: This puts an integer as key in cache, which breaks
145+ # Llama.longest_token_prefix(k, key) above, since k is not a tuple of ints/tokens
146+ # self.cache.push(_key, side="front") # type: ignore
145147 return value
146148
147149 def __contains__ (self , key : Sequence [int ]) -> bool :
@@ -168,7 +170,7 @@ def __init__(
168170 eval_logits : Deque [List [float ]],
169171 input_ids : npt .NDArray [np .intc ],
170172 scores : npt .NDArray [np .single ],
171- llama_state , # type: llama_cpp.Array[llama_cpp.c_uint8]
173+ llama_state : bytes ,
172174 llama_state_size : int ,
173175 ):
174176 self .eval_tokens = eval_tokens
@@ -1503,7 +1505,7 @@ def save_state(self) -> LlamaState:
15031505 eval_logits = self .eval_logits .copy (),
15041506 scores = self ._scores .copy (),
15051507 input_ids = self ._input_ids .copy (),
1506- llama_state = llama_state_compact ,
1508+ llama_state = bytes ( llama_state_compact ) ,
15071509 llama_state_size = n_bytes ,
15081510 )
15091511
@@ -1514,7 +1516,10 @@ def load_state(self, state: LlamaState) -> None:
15141516 self ._scores = state .scores .copy ()
15151517 self ._input_ids = state .input_ids .copy ()
15161518 state_size = state .llama_state_size
1517- if llama_cpp .llama_set_state_data (self .ctx , state .llama_state ) != state_size :
1519+ LLamaStateArrayType = (llama_cpp .c_uint8 * state_size )
1520+ llama_state = LLamaStateArrayType .from_buffer_copy (state .llama_state )
1521+
1522+ if llama_cpp .llama_set_state_data (self .ctx , llama_state ) != state_size :
15181523 raise RuntimeError ("Failed to set llama state data" )
15191524
15201525 def n_ctx (self ) -> int :
0 commit comments