@@ -532,6 +532,7 @@ def format_phind(
532532 _prompt = _format_add_colon_single (_system_message , _messages , _sep )
533533 return ChatFormatterResponse (prompt = _prompt )
534534
535+
535536@register_chat_format ("intel" )
536537def format_intel (
537538 messages : List [llama_types .ChatCompletionRequestMessage ],
@@ -588,6 +589,7 @@ def format_mistrallite(
588589 _prompt = _format_no_colon_single (system_message , _messages , _sep )
589590 return ChatFormatterResponse (prompt = _prompt )
590591
592+
591593@register_chat_format ("chatml" )
592594def format_chatml (
593595 messages : List [llama_types .ChatCompletionRequestMessage ],
@@ -604,6 +606,7 @@ def format_chatml(
604606 _prompt = _format_chatml (system_message , _messages , _sep )
605607 return ChatFormatterResponse (prompt = _prompt , stop = _sep )
606608
609+
607610@register_chat_format ("openchat" )
608611def format_openchat (
609612 messages : List [llama_types .ChatCompletionRequestMessage ],
@@ -612,7 +615,9 @@ def format_openchat(
612615 system_template = "{system_message}<|end_of_turn|>"
613616 system_message = _get_system_message (messages )
614617 system_message = system_template .format (system_message = system_message )
615- _roles = dict (user = "GPT4 Correct User: " , assistant = "<|end_of_turn|>GPT4 Correct Assistant: " )
618+ _roles = dict (
619+ user = "GPT4 Correct User: " , assistant = "<|end_of_turn|>GPT4 Correct Assistant: "
620+ )
616621 _sep = "<|end_of_turn|>"
617622 _messages = _map_roles (messages , _roles )
618623 _messages .append ((_roles ["assistant" ], None ))
@@ -651,46 +656,60 @@ def functionary_chat_handler(
651656) -> Union [llama_types .ChatCompletion , Iterator [llama_types .ChatCompletionChunk ]]:
652657 SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
653658
654- def generate_type_definition (param : Dict [str , llama_types .JsonType ], indent_level : int , shared_defs ) -> str :
655- indent = ' ' * indent_level
656- if '$ref' in param :
659+ def generate_type_definition (
660+ param : Dict [str , llama_types .JsonType ], indent_level : int , shared_defs
661+ ) -> str :
662+ indent = " " * indent_level
663+ if "$ref" in param :
657664 # Reference to a shared definition
658- ref_name = param ['$ref' ].split ('/' )[- 1 ] # Extract the type name from the reference
665+ ref_name = param ["$ref" ].split ("/" )[
666+ - 1
667+ ] # Extract the type name from the reference
659668 return ref_name
660- elif param .get (' type' ) == ' array' :
661- items = param .get (' items' , {})
669+ elif param .get (" type" ) == " array" :
670+ items = param .get (" items" , {})
662671 item_type = generate_type_definition (items , indent_level + 1 , shared_defs )
663672 return f"Array<{ item_type } >"
664- elif param .get (' type' ) == ' object' :
665- properties = param .get (' properties' , {})
673+ elif param .get (" type" ) == " object" :
674+ properties = param .get (" properties" , {})
666675 nested_schema = "{\n "
667676 for nested_param_name , nested_param in properties .items ():
668- nested_param_type = generate_type_definition (nested_param , indent_level + 1 , shared_defs )
669- nested_schema += f"{ indent } { nested_param_name } : { nested_param_type } ,\n "
677+ nested_param_type = generate_type_definition (
678+ nested_param , indent_level + 1 , shared_defs
679+ )
680+ nested_schema += (
681+ f"{ indent } { nested_param_name } : { nested_param_type } ,\n "
682+ )
670683 nested_schema += indent + "}"
671684 return nested_schema
672- elif ' enum' in param :
685+ elif " enum" in param :
673686 # Enum type
674- return " | " .join ([f'"{ enum_value } "' for enum_value in param [' enum' ]])
687+ return " | " .join ([f'"{ enum_value } "' for enum_value in param [" enum" ]])
675688 else :
676689 # Simple type
677- return param .get (' type' , ' any' )
690+ return param .get (" type" , " any" )
678691
679692 def generate_shared_definitions (shared_defs , indent_level : int ) -> str :
680- indent = ' ' * indent_level
693+ indent = " " * indent_level
681694 shared_definitions = ""
682695 for def_name , def_properties in shared_defs .items ():
683696 shared_definitions += f"{ indent } type { def_name } = "
684- if def_properties .get ('type' ) == 'object' :
685- shared_definitions += generate_type_definition (def_properties , indent_level , shared_defs )
686- elif 'enum' in def_properties :
697+ if def_properties .get ("type" ) == "object" :
698+ shared_definitions += generate_type_definition (
699+ def_properties , indent_level , shared_defs
700+ )
701+ elif "enum" in def_properties :
687702 # Enum type
688- shared_definitions += " | " .join ([f'"{ enum_value } "' for enum_value in def_properties ['enum' ]])
703+ shared_definitions += " | " .join (
704+ [f'"{ enum_value } "' for enum_value in def_properties ["enum" ]]
705+ )
689706 shared_definitions += ";\n "
690707 return shared_definitions
691708
692709 def generate_schema_from_functions (functions , namespace = "functions" ) -> str :
693- schema = "// Supported function definitions that should be called when necessary.\n "
710+ schema = (
711+ "// Supported function definitions that should be called when necessary.\n "
712+ )
694713 schema += f"namespace { namespace } {{\n \n "
695714
696715 # Generate shared definitions
@@ -706,10 +725,10 @@ def generate_schema_from_functions(functions, namespace="functions") -> str:
706725 description = function .get ("description" , "" )
707726 parameters = function .get ("parameters" , {})
708727 required_params = parameters .get ("required" , [])
709-
728+
710729 schema += f" // { description } \n "
711730 schema += f" type { function_name } = (_: {{\n "
712-
731+
713732 for param_name , param in parameters .get ("properties" , {}).items ():
714733 param_description = param .get ("description" , "" )
715734 param_type = generate_type_definition (param , 2 , shared_definitions )
@@ -733,13 +752,18 @@ def prepare_messages_for_inference(
733752 role = "system" , content = generate_schema_from_functions (functions )
734753 )
735754 )
736-
755+
737756 if tools is not None :
738757 all_messages .append (
739758 llama_types .ChatCompletionRequestSystemMessage (
740- role = "system" , content = generate_schema_from_functions (
741- [tool ["function" ] for tool in tools if tool ["type" ] == "function" ]
742- )
759+ role = "system" ,
760+ content = generate_schema_from_functions (
761+ [
762+ tool ["function" ]
763+ for tool in tools
764+ if tool ["type" ] == "function"
765+ ]
766+ ),
743767 )
744768 )
745769
@@ -790,7 +814,9 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
790814 elif "function_call" in msg :
791815 return f"assistant to={ msg ['function_call' ]['name' ]} :\n { msg ['function_call' ]['arguments' ]} </s>\n "
792816 elif "tool_calls" in msg and len (msg ["tool_calls" ]) > 0 :
793- for tool_call in msg ["tool_calls" ]: # NOTE: probably doesn't work with the functionary model
817+ for tool_call in msg [
818+ "tool_calls"
819+ ]: # NOTE: probably doesn't work with the functionary model
794820 return f"assistant to={ tool_call ['id' ]} :\n { tool_call ['function' ]['arguments' ]} </s>\n "
795821 elif msg ["content" ] is None :
796822 return "assistant"
@@ -800,12 +826,14 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
800826 raise ValueError (f"Unsupported role: { msg ['role' ]} " )
801827
802828 return "" .join ([message_to_str (msg ) for msg in all_messages ])
803-
829+
804830 if tools is not None :
805831 functions = [tool ["function" ] for tool in tools if tool ["type" ] == "function" ]
806-
832+
807833 if tool_choice is not None :
808- function_call = tool_choice if isinstance (tool_choice , str ) else tool_choice ["function" ]
834+ function_call = (
835+ tool_choice if isinstance (tool_choice , str ) else tool_choice ["function" ]
836+ )
809837
810838 prompt = prepare_messages_for_inference (messages , functions , tools )
811839
@@ -861,19 +889,27 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
861889 if tool ["type" ] == "function" and tool ["function" ]["name" ] == function_call :
862890 function_body = tool ["function" ]["parameters" ]
863891 break
864-
892+
865893 if function_body is not None :
866894 try :
867895 with suppress_stdout_stderr (disable = llama .verbose ):
868- grammar_text = llama_grammar .json_schema_to_gbnf (json .dumps (function_body ))
869- grammar = llama_grammar .LlamaGrammar .from_string (llama_grammar .json_schema_to_gbnf (json .dumps (function_body )))
896+ grammar_text = llama_grammar .json_schema_to_gbnf (
897+ json .dumps (function_body )
898+ )
899+ grammar = llama_grammar .LlamaGrammar .from_string (
900+ llama_grammar .json_schema_to_gbnf (json .dumps (function_body ))
901+ )
870902 print (grammar_text )
871903 except Exception as e :
872904 if llama .verbose :
873- print ("Failed to parse function body as JSON schema, falling back to default grammar" )
905+ print (
906+ "Failed to parse function body as JSON schema, falling back to default grammar"
907+ )
874908 print (e )
875909 with suppress_stdout_stderr (disable = llama .verbose ):
876- grammar = llama_grammar .LlamaGrammar .from_string (llama_grammar .JSON_GBNF )
910+ grammar = llama_grammar .LlamaGrammar .from_string (
911+ llama_grammar .JSON_GBNF
912+ )
877913 else :
878914 with suppress_stdout_stderr (disable = llama .verbose ):
879915 grammar = llama_grammar .LlamaGrammar .from_string (llama_grammar .JSON_GBNF )
@@ -929,9 +965,9 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
929965 "function" : {
930966 "name" : function_call ,
931967 "arguments" : completion ["choices" ][0 ]["text" ],
932- }
968+ },
933969 }
934- ]
970+ ],
935971 },
936972 "finish_reason" : "tool_calls" ,
937973 }
0 commit comments