@@ -255,28 +255,28 @@ def __init__(
             for i, (k, v) in enumerate(kv_overrides.items()):
                 self._kv_overrides_array[i].key = k.encode("utf-8")
                 if isinstance(v, bool):
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
                     self._kv_overrides_array[i].value.val_bool = v
                 elif isinstance(v, int):
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
                     self._kv_overrides_array[i].value.val_i64 = v
                 elif isinstance(v, float):
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
                     self._kv_overrides_array[i].value.val_f64 = v
                 elif isinstance(v, str):  # type: ignore
                     v_bytes = v.encode("utf-8")
                     if len(v_bytes) > 128:  # TODO: Make this a constant
                         raise ValueError(f"Value for {k} is too long: {v}")
                     v_bytes = v_bytes.ljust(128, b"\0")
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
                     # copy min(v_bytes, 128) to str_value
                     address = typing.cast(
                         int,
@@ -292,9 +292,9 @@ def __init__(
             else:
                 raise ValueError(f"Unknown value type for {k}: {v}")

-            self._kv_overrides_array[-1].key = (
-                b"\0"  # ensure sentinel element is zeroed
-            )
+            self._kv_overrides_array[
+                -1
+            ].key = b"\0"  # ensure sentinel element is zeroed
             self.model_params.kv_overrides = self._kv_overrides_array

         self.n_batch = min(n_ctx, n_batch)  # ???
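The kv_overrides hunks above only reflow how each override is written into the C array llama.cpp expects: every key is tagged with its value type (bool, int, float, or a NUL-padded string of at most 128 bytes) and the array ends with a zeroed sentinel key. As a rough usage sketch, with a placeholder model path and example override keys that are not taken from this commit:

    from llama_cpp import Llama

    # Hypothetical example: overrides are passed as a plain dict of
    # bool / int / float / str values (str values are limited to 128 bytes).
    llm = Llama(
        model_path="./model.gguf",  # placeholder path
        kv_overrides={
            "tokenizer.ggml.add_bos_token": False,  # example keys only; use the
            "llama.context_length": 4096,           # keys your model actually defines
        },
    )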
@@ -431,9 +431,9 @@ def free_lora_adapter():

         self.chat_format = chat_format
         self.chat_handler = chat_handler
-        self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = (
-            {}
-        )
+        self._chat_handlers: Dict[
+            str, llama_chat_format.LlamaChatCompletionHandler
+        ] = {}

         self.draft_model = draft_model

@@ -580,7 +580,10 @@ def tokenize(
         return self.tokenizer_.tokenize(text, add_bos, special)

     def detokenize(
-        self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
+        self,
+        tokens: List[int],
+        prev_tokens: Optional[List[int]] = None,
+        special: bool = False,
     ) -> bytes:
         """Detokenize a list of tokens.

@@ -592,7 +595,9 @@ def detokenize(
         Returns:
             The detokenized string.
         """
-        return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens, special=special)
+        return self.tokenizer_.detokenize(
+            tokens, prev_tokens=prev_tokens, special=special
+        )

     def set_cache(self, cache: Optional[BaseLlamaCache]):
         """Set the cache.
@@ -681,12 +686,16 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
                 recarray = np.recarray(
                     shape=(size,),
                     dtype=np.dtype(
-                        [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
+                        [("id", np.intc), ("logit", np.single), ("p", np.single)],
+                        align=True,
+                    ),
+                    buf=(llama_cpp.llama_token_data * size).from_address(
+                        data_soa_address
                     ),
-                    buf=(llama_cpp.llama_token_data * size).from_address(data_soa_address),
                 )
                 for logit_processor in logits_processor:
                     recarray.logit[:] = logit_processor(self._input_ids, recarray.logit)
+
             sampler.add_custom(apply_func)

             sampler.add_penalties(
@@ -698,7 +707,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
             penalty_freq=frequency_penalty,
             penalty_present=presence_penalty,
             penalize_nl=penalize_nl,
-            ignore_eos=False
+            ignore_eos=False,
         )

         if grammar is not None:
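In the sampler hunks above, apply_func wraps user-supplied logits processors as a custom llama.cpp sampler stage: it views the token-data struct array as a NumPy recarray and rewrites its logit column in place. A processor is just a callable over (input_ids, scores). A minimal sketch, with a placeholder model path and a purely illustrative banned token id:

    import numpy as np
    from llama_cpp import Llama, LogitsProcessorList

    def ban_token(input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        scores[0] = -np.inf  # illustrative: forbid token id 0
        return scores

    llm = Llama(model_path="./model.gguf")  # placeholder path
    out = llm(
        "Once upon a time",
        max_tokens=16,
        logits_processor=LogitsProcessorList([ban_token]),
    )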
@@ -841,22 +850,22 @@ def generate(
         # Reset mirostat sampling
         self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
         self._sampler = self._init_sampler(
-                top_k=top_k,
-                top_p=top_p,
-                min_p=min_p,
-                typical_p=typical_p,
-                temp=temp,
-                repeat_penalty=repeat_penalty,
-                frequency_penalty=frequency_penalty,
-                presence_penalty=presence_penalty,
-                tfs_z=tfs_z,
-                mirostat_mode=mirostat_mode,
-                mirostat_tau=mirostat_tau,
-                mirostat_eta=mirostat_eta,
-                penalize_nl=penalize_nl,
-                logits_processor=logits_processor,
-                grammar=grammar,
-                seed=seed,
+            top_k=top_k,
+            top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
+            temp=temp,
+            repeat_penalty=repeat_penalty,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
+            tfs_z=tfs_z,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            penalize_nl=penalize_nl,
+            logits_processor=logits_processor,
+            grammar=grammar,
+            seed=seed,
         )

         # Check for kv cache prefix match
@@ -872,8 +881,11 @@ def generate(
             tokens = tokens[longest_prefix:]
             self.n_tokens = longest_prefix
             if self.verbose:
-                print(f"Llama.generate: {longest_prefix} prefix-match hit, "
-                      f"remaining {len(tokens)} prompt tokens to eval", file=sys.stderr)
+                print(
+                    f"Llama.generate: {longest_prefix} prefix-match hit, "
+                    f"remaining {len(tokens)} prompt tokens to eval",
+                    file=sys.stderr,
+                )

         # Reset the model state
         if reset:
@@ -1032,7 +1044,9 @@ def decode_batch(seq_sizes: List[int]):
                         for j in range(size)
                     ]
                     if normalize:
-                        embedding = [internals.normalize_embedding(e) for e in embedding]
+                        embedding = [
+                            internals.normalize_embedding(e) for e in embedding
+                        ]
                     data.append(embedding)
                     pos += size
             else:
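The final hunk only reflows the per-token normalization inside decode_batch: when pooling is disabled, each sequence yields a list of token embeddings that is optionally normalized vector by vector. From the public API this corresponds roughly to the following sketch (placeholder path, illustrative inputs):

    from llama_cpp import Llama

    llm = Llama(model_path="./model.gguf", embedding=True)  # placeholder path
    vectors = llm.embed(["first sentence", "second sentence"], normalize=True)
    print(len(vectors), len(vectors[0]))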