@@ -115,6 +115,7 @@ def __init__(
115115 type_k : Optional [int ] = None ,
116116 type_v : Optional [int ] = None ,
117117 # Misc
118+ spm_infill : bool = False ,
118119 verbose : bool = True ,
119120 # Extra Params
120121 ** kwargs , # type: ignore
@@ -185,6 +186,7 @@ def __init__(
185186 verbose: Print verbose output to stderr.
186187 type_k: KV cache data type for K (default: f16)
187188 type_v: KV cache data type for V (default: f16)
189+ spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
188190
189191 Raises:
190192 ValueError: If the model path does not exist.
@@ -343,6 +345,8 @@ def __init__(
343345 self .lora_scale = lora_scale
344346 self .lora_path = lora_path
345347
348+ self .spm_infill = spm_infill
349+
346350 if not os .path .exists (model_path ):
347351 raise ValueError (f"Model path does not exist: { model_path } " )
348352
@@ -972,14 +976,33 @@ def _create_completion(
972976
973977 completion_id : str = f"cmpl-{ str (uuid .uuid4 ())} "
974978 created : int = int (time .time ())
979+ bos_token_id : int = self .token_bos ()
980+ cls_token_id : int = self ._model .token_cls ()
981+ sep_token_id : int = self ._model .token_sep ()
975982 prefix_token_id : int = self ._model .token_prefix ()
976983 middle_token_id : int = self ._model .token_middle ()
977984 suffix_token_id : int = self ._model .token_suffix ()
985+ add_space_prefix : bool = self .metadata .get ("tokenizer.ggml.add_space_prefix" , "true" ) == "true"
986+ bos_tokens : List [int ] = [cls_token_id if cls_token_id != - 1 else bos_token_id ]
987+ eos_tokens : List [int ] = [sep_token_id if sep_token_id != - 1 else self .token_eos ()]
988+
989+ if (isinstance (prompt , list ) and suffix is None ) or self ._model .add_bos_token () == 0 or bos_tokens [:1 ] == [- 1 ]:
990+ bos_tokens = []
991+
992+ if (isinstance (prompt , list ) and suffix is None ) or (self ._model .add_eos_token () != 1 and sep_token_id == - 1 ):
993+ eos_tokens = []
994+
995+ suffix_space_prefix : int = 0
996+ # Tokenizer hack to remove leading space
997+ if add_space_prefix and suffix_token_id >= 0 and suffix :
998+ suffix = "☺" + suffix
999+ suffix_space_prefix = 2
1000+
9781001 # If prompt is empty, initialize completion with BOS token to avoid
9791002 # detokenization including a space at the beginning of the completion
980- completion_tokens : List [int ] = [] if len (prompt ) > 0 else [self . token_bos () ]
1003+ completion_tokens : List [int ] = [] if len (prompt ) > 0 else [bos_token_id ]
9811004 # Add blank space to start of prompt to match OG llama tokenizer
982- prompt_tokens : List [int ] = (
1005+ prefix_tokens : List [int ] = (
9831006 (
9841007 [prefix_token_id ]
9851008 if prefix_token_id >= 0 and suffix is not None
@@ -988,38 +1011,33 @@ def _create_completion(
9881011 +
9891012 (
9901013 (
991- self .tokenize (prompt .encode ("utf-8" ), add_bos = ( prefix_token_id < 0 or suffix is None ) , special = (prefix_token_id < 0 or suffix is None ))
1014+ self .tokenize (prompt .encode ("utf-8" ), add_bos = False , special = (prefix_token_id < 0 or suffix is None ))
9921015 if prompt != ""
993- else (
994- []
995- if prefix_token_id >= 0 and suffix is not None
996- else [self .token_bos ()]
997- )
1016+ else []
9981017 )
9991018 if isinstance (prompt , str )
10001019 else prompt
10011020 )
1002- +
1021+ )
1022+ suffix_tokens : List [int ] = (
10031023 (
1024+ [suffix_token_id ]
1025+ +
10041026 (
1005- [suffix_token_id ]
1006- +
1007- (
1008- self .tokenize (suffix .encode ("utf-8" ), add_bos = False , special = False )
1009- if suffix
1010- else []
1011- )
1027+ self .tokenize (suffix .encode ("utf-8" ), add_bos = False , special = False )[suffix_space_prefix :]
1028+ if suffix
1029+ else []
10121030 )
1013- if suffix_token_id >= 0 and suffix is not None
1014- else []
1015- )
1016- +
1017- (
1018- [middle_token_id ]
1019- if middle_token_id >= 0 and suffix is not None
1020- else []
10211031 )
1032+ if suffix_token_id >= 0 and suffix is not None
1033+ else []
1034+ )
1035+ middle_tokens : List [int ] = (
1036+ [middle_token_id ]
1037+ if middle_token_id >= 0 and suffix is not None
1038+ else []
10221039 )
1040+ prompt_tokens : List [int ] = bos_tokens + ((suffix_tokens + prefix_tokens + middle_tokens ) if self .spm_infill else (prefix_tokens + suffix_tokens + middle_tokens )) + eos_tokens
10231041 text : bytes = b""
10241042 returned_tokens : int = 0
10251043 stop = (
@@ -1176,7 +1194,7 @@ def logit_bias_processor(
11761194 # not sure how to handle this branch when dealing
11771195 # with CJK output, so keep it unchanged
11781196 for token in remaining_tokens :
1179- if token == self . token_bos () :
1197+ if token == bos_token_id :
11801198 continue
11811199 token_end_position += len (self .detokenize ([token ], prev_tokens = prompt_tokens + completion_tokens [:returned_tokens ]))
11821200 # Check if stop sequence is in the token
@@ -1303,7 +1321,7 @@ def logit_bias_processor(
13031321
13041322 logprobs_or_none : Optional [CompletionLogprobs ] = None
13051323 if logprobs is not None :
1306- if token == self . token_bos () :
1324+ if token == bos_token_id :
13071325 continue
13081326 token_str = self .detokenize ([token ]).decode (
13091327 "utf-8" , errors = "ignore"
@@ -1431,7 +1449,7 @@ def logit_bias_processor(
14311449 for idx , (token , token_str , logprobs_token ) in enumerate (
14321450 zip (all_tokens , all_token_strs , all_logprobs )
14331451 ):
1434- if token == self . token_bos () :
1452+ if token == bos_token_id :
14351453 continue
14361454 text_offsets .append (
14371455 text_offset
@@ -1858,6 +1876,7 @@ def __getstate__(self):
18581876 type_k = self .context_params .type_k ,
18591877 type_v = self .context_params .type_v ,
18601878 # Misc
1879+ spm_infill = self .spm_infill ,
18611880 verbose = self .verbose ,
18621881 )
18631882
0 commit comments