@@ -357,7 +357,7 @@ def generate(self):
357357
358358 # Apply params.logit_bias map
359359 for key , value in self .params .logit_bias .items ():
360- logits [key ] += value
360+ logits [key ] += llama_cpp . c_float ( value )
361361
362362 _arr = (llama_cpp .llama_token_data * n_vocab )(* [
363363 llama_cpp .llama_token_data (token_id , logits [token_id ], 0.0 )
@@ -372,34 +372,34 @@ def generate(self):
372372 _arr = (llama_cpp .llama_token * last_n_repeat )(* self .last_n_tokens [len (self .last_n_tokens ) - last_n_repeat :])
373373 llama_cpp .llama_sample_repetition_penalty (self .ctx , candidates_p ,
374374 _arr ,
375- last_n_repeat , self .params .repeat_penalty )
375+ last_n_repeat , llama_cpp . c_float ( self .params .repeat_penalty ) )
376376 llama_cpp .llama_sample_frequency_and_presence_penalties (self .ctx , candidates_p ,
377377 _arr ,
378- last_n_repeat , self .params .frequency_penalty , self .params .presence_penalty )
378+ last_n_repeat , llama_cpp . c_float ( self .params .frequency_penalty ), llama_cpp . c_float ( self .params .presence_penalty ) )
379379
380380 if not self .params .penalize_nl :
381381 logits [llama_cpp .llama_token_nl ()] = nl_logit
382-
382+
383383 if self .params .temp <= 0 :
384384 # Greedy sampling
385385 id = llama_cpp .llama_sample_token_greedy (self .ctx , candidates_p )
386386 else :
387387 if self .params .mirostat == 1 :
388388 mirostat_mu = 2.0 * self .params .mirostat_tau
389389 mirostat_m = 100
390- llama_cpp .llama_sample_temperature (self .ctx , candidates_p , self .params .temp )
391- id = llama_cpp .llama_sample_token_mirostat (self .ctx , candidates_p , self .params .mirostat_tau , self .params .mirostat_eta , mirostat_m , mirostat_mu )
390+ llama_cpp .llama_sample_temperature (self .ctx , candidates_p , llama_cpp . c_float ( self .params .temp ) )
391+ id = llama_cpp .llama_sample_token_mirostat (self .ctx , candidates_p , llama_cpp . c_float ( self .params .mirostat_tau ), llama_cpp . c_float ( self .params .mirostat_eta ), llama_cpp . c_int ( mirostat_m ), llama_cpp . c_float ( mirostat_mu ) )
392392 elif self .params .mirostat == 2 :
393393 mirostat_mu = 2.0 * self .params .mirostat_tau
394- llama_cpp .llama_sample_temperature (self .ctx , candidates_p , self .params .temp )
395- id = llama_cpp .llama_sample_token_mirostat_v2 (self .ctx , candidates_p , self .params .mirostat_tau , self .params .mirostat_eta , mirostat_mu )
394+ llama_cpp .llama_sample_temperature (self .ctx , candidates_p , llama_cpp . c_float ( self .params .temp ) )
395+ id = llama_cpp .llama_sample_token_mirostat_v2 (self .ctx , candidates_p , llama_cpp . c_float ( self .params .mirostat_tau ), llama_cpp . c_float ( self .params .mirostat_eta ), llama_cpp . c_float ( mirostat_mu ) )
396396 else :
397397 # Temperature sampling
398398 llama_cpp .llama_sample_top_k (self .ctx , candidates_p , top_k )
399- llama_cpp .llama_sample_tail_free (self .ctx , candidates_p , self .params .tfs_z )
400- llama_cpp .llama_sample_typical (self .ctx , candidates_p , self .params .typical_p )
401- llama_cpp .llama_sample_top_p (self .ctx , candidates_p , self .params .top_p )
402- llama_cpp .llama_sample_temperature (self .ctx , candidates_p , self .params .temp )
399+ llama_cpp .llama_sample_tail_free (self .ctx , candidates_p , llama_cpp . c_float ( self .params .tfs_z ) )
400+ llama_cpp .llama_sample_typical (self .ctx , candidates_p , llama_cpp . c_float ( self .params .typical_p ) )
401+ llama_cpp .llama_sample_top_p (self .ctx , candidates_p , llama_cpp . c_float ( self .params .top_p ) )
402+ llama_cpp .llama_sample_temperature (self .ctx , candidates_p , llama_cpp . c_float ( self .params .temp ) )
403403 id = llama_cpp .llama_sample_token (self .ctx , candidates_p )
404404 # print("`{}`".format(candidates_p.size))
405405
0 commit comments