@@ -623,7 +623,7 @@ def _create_completion(
623623 b" " + prompt .encode ("utf-8" )
624624 )
625625 text : bytes = b""
626- returned_characters : int = 0
626+ returned_tokens : int = 0
627627 stop = stop if stop is not None else []
628628 model_name : str = model if model is not None else self .model_path
629629
@@ -707,33 +707,42 @@ def _create_completion(
707707 break
708708
709709 if stream :
710- start = returned_characters
711- longest = 0
712710 # We want to avoid yielding any characters from
713711 # the generated text if they are part of a stop
714712 # sequence.
713+ longest = 0
715714 for s in stop_sequences :
716715 for i in range (len (s ), 0 , - 1 ):
717716 if all_text .endswith (s [:i ]):
718717 if i > longest :
719718 longest = i
720719 break
721- text = all_text [: len (all_text ) - longest ]
722- returned_characters += len (text [start :])
723- yield {
724- "id" : completion_id ,
725- "object" : "text_completion" ,
726- "created" : created ,
727- "model" : model_name ,
728- "choices" : [
729- {
730- "text" : text [start :].decode ("utf-8" , errors = "ignore" ),
731- "index" : 0 ,
732- "logprobs" : None ,
733- "finish_reason" : None ,
734- }
735- ],
736- }
720+
721+ offset = 0
722+ remaining_tokens = completion_tokens [returned_tokens :]
723+ remaining_length = len (self .detokenize (remaining_tokens ))
724+ for token in remaining_tokens :
725+ offset += len (self .detokenize ([token ]))
726+ # Check if stop sequence is not in the token
727+ if offset >= (remaining_length - longest - 1 ):
728+ break
729+ returned_tokens += 1
730+ yield {
731+ "id" : completion_id ,
732+ "object" : "text_completion" ,
733+ "created" : created ,
734+ "model" : model_name ,
735+ "choices" : [
736+ {
737+ "text" : self .detokenize ([token ]).decode (
738+ "utf-8" , errors = "ignore"
739+ ),
740+ "index" : 0 ,
741+ "logprobs" : None ,
742+ "finish_reason" : None ,
743+ }
744+ ],
745+ }
737746
738747 if len (completion_tokens ) >= max_tokens :
739748 text = self .detokenize (completion_tokens )
@@ -749,22 +758,57 @@ def _create_completion(
749758 llama_cpp .llama_print_timings (self .ctx )
750759
751760 if stream :
752- yield {
753- "id" : completion_id ,
754- "object" : "text_completion" ,
755- "created" : created ,
756- "model" : model_name ,
757- "choices" : [
758- {
759- "text" : text [returned_characters :].decode (
760- "utf-8" , errors = "ignore"
761- ),
762- "index" : 0 ,
763- "logprobs" : None ,
764- "finish_reason" : finish_reason ,
761+ remaining_tokens = completion_tokens [returned_tokens :]
762+ all_text = self .detokenize (remaining_tokens )
763+ any_stop = [s for s in stop_sequences if s in all_text ]
764+ if len (any_stop ) > 0 :
765+ end = min (all_text .index (stop ) for stop in any_stop )
766+ else :
767+ end = len (all_text )
768+
769+ offset = 0
770+ for token in remaining_tokens :
771+ offset += len (self .detokenize ([token ]))
772+ if offset >= end :
773+ last_text = self .detokenize ([token ])
774+ if offset == end - 1 :
775+ break
776+ yield {
777+ "id" : completion_id ,
778+ "object" : "text_completion" ,
779+ "created" : created ,
780+ "model" : model_name ,
781+ "choices" : [
782+ {
783+ "text" : last_text [
784+ : len (last_text ) - (offset - end )
785+ ].decode ("utf-8" , errors = "ignore" ),
786+ "index" : 0 ,
787+ "logprobs" : None ,
788+ "finish_reason" : finish_reason ,
789+ }
790+ ],
765791 }
766- ],
767- }
792+ break
793+ returned_tokens += 1
794+ yield {
795+ "id" : completion_id ,
796+ "object" : "text_completion" ,
797+ "created" : created ,
798+ "model" : model_name ,
799+ "choices" : [
800+ {
801+ "text" : self .detokenize ([token ]).decode (
802+ "utf-8" , errors = "ignore"
803+ ),
804+ "index" : 0 ,
805+ "logprobs" : None ,
806+ "finish_reason" : finish_reason
807+ if returned_tokens == len (completion_tokens )
808+ else None ,
809+ }
810+ ],
811+ }
768812 return
769813
770814 text_str = text .decode ("utf-8" , errors = "ignore" )