@@ -1596,13 +1596,15 @@ def prepare_messages_for_inference(
15961596 function_call = (
15971597 tool_choice if isinstance (tool_choice , str ) else tool_choice ["function" ]
15981598 )
1599+ else :
1600+ function_call = "auto"
15991601
16001602 prompt = prepare_messages_for_inference (
16011603 messages , tokenizer , version , functions , tools
16021604 )
16031605
16041606 # If no tools/functions are provided
1605- if function_call is None and ( functions is None or len (functions ) == 0 ) :
1607+ if function_call == "none" or functions is None or len (functions ) == 0 :
16061608 if version == "v1" :
16071609 stop = END_ASSISTANT_TOKEN
16081610 else :
@@ -1630,6 +1632,7 @@ def prepare_messages_for_inference(
16301632 logits_processor = logits_processor ,
16311633 grammar = grammar ,
16321634 )
1635+ completion_or_completion_chunks ["choices" ][0 ]["text" ] = completion_or_completion_chunks ["choices" ][0 ]["text" ].lstrip ()
16331636 return _convert_completion_to_chat (completion_or_completion_chunks , stream = stream ) # type: ignore
16341637
16351638 assert stream is False # TODO: support stream mode
@@ -1692,13 +1695,12 @@ def create_completion(stop):
16921695
16931696 return completion
16941697
1698+ content = ""
16951699 function_calls , function_bodies = [], []
16961700
16971701 if version == "v1" :
16981702 # If no or "auto" tool_choice/function_call
1699- if function_call is None or (
1700- isinstance (function_call , str ) and function_call == "auto"
1701- ):
1703+ if isinstance (function_call , str ) and function_call == "auto" :
17021704 stops = ["\n " , END_ASSISTANT_TOKEN ]
17031705 # If tool_choice/function_call is "none"
17041706 elif isinstance (function_call , str ) and function_call == "none" :
@@ -1747,70 +1749,67 @@ def create_completion(stop):
17471749 else :
17481750 function_bodies .append (completion_text .strip ())
17491751 else :
1750- # Loop until all parallel function calls are generated
1751- while True :
1752- # If no or "auto" tool_choice/function_call
1753- if function_call is None or (
1754- isinstance (function_call , str ) and function_call == "auto"
1755- ):
1756- grammar = None
1757- stops = CONTENT_TOKEN
1758- # If tool_choice/function_call is "none"
1759- elif isinstance (function_call , str ) and function_call == "none" :
1760- prompt = (
1761- prepare_messages_for_inference (messages , tokenizer , version , [], [])
1762- + "all\n <|content|>"
1763- )
1764- stops = STOP_TOKEN
1765- # If tool_choice/function_call is provided
1766- elif isinstance (function_call , dict ):
1767- prompt += f"{ function_call ['name' ]} \n { CONTENT_TOKEN } "
1768- stops = STOP_TOKEN
1769- function_call = function_call ["name" ]
1770- function_calls .append (function_call )
1771- grammar = get_grammar (function_call )
1772- else :
1773- prompt = prompt
1774- stops = STOP_TOKEN
1775-
1752+ # If tool_choice/function_call is "none"
1753+ if isinstance (function_call , str ) and function_call == "none" :
1754+ prompt = (
1755+ prepare_messages_for_inference (messages , tokenizer , version , [], [])
1756+ + "all\n <|content|>"
1757+ )
1758+ stops = [STOP_TOKEN , FROM_TOKEN ]
1759+ completion = create_completion (stop = stops )
1760+ completion ["choices" ][0 ]["text" ] = completion ["choices" ][0 ]["text" ].strip ()
1761+ return _convert_completion_to_chat (completion , stream = stream ) # type: ignore
1762+ # If tool_choice/function_call is provided
1763+ elif isinstance (function_call , dict ):
1764+ prompt += f"{ function_call ['name' ]} \n { CONTENT_TOKEN } "
1765+ function_call = function_call ["name" ]
1766+ function_calls .append (function_call )
1767+ grammar = get_grammar (function_call )
1768+ stops = [STOP_TOKEN , FROM_TOKEN ]
17761769 completion = create_completion (stop = stops )
17771770 completion_text = completion ["choices" ][0 ]["text" ]
1778-
1779- # If the generation does not involve a function call
1780- if prompt .endswith ("all\n <|content|>" ) and not completion_text .startswith (
1781- "all"
1782- ):
1783- return _convert_completion_to_chat (completion , stream = stream ) # type: ignore
1784- # Generate model response if the model decides not to call any function
1785- elif prompt .endswith (RECIPIENT_TOKEN ) and completion_text .startswith ("all" ):
1786- prompt += completion_text + CONTENT_TOKEN
1787- completion = create_completion (stop = STOP_TOKEN )
1788- return _convert_completion_to_chat (completion , stream = stream ) # type: ignore
1789- # Generate parameters if model decides to call a function
1790- elif prompt .endswith (RECIPIENT_TOKEN ):
1791- function_calls .append (completion_text [:- 1 ])
1792- grammar = get_grammar (function_calls [- 1 ])
1793- completion = create_completion (stop = [STOP_TOKEN , "\n " ])
1794- function_bodies .append (completion ["choices" ][0 ]["text" ].strip ())
1795- prompt += f"{ function_calls [- 1 ]} \n { CONTENT_TOKEN } { function_bodies [- 1 ]} "
1771+ function_bodies .append (completion_text .strip ())
1772+ # If "auto" or no tool_choice/function_call
1773+ elif isinstance (function_call , str ) and function_call == "auto" :
1774+ while True :
1775+ # Generate function name first
17961776 grammar = None
1797-
1798- # Try to generate the beginning of next turn
1799- # If empty completion, break from loop
1800- next_turn_completion_text = create_completion (
1801- stop = [STOP_TOKEN , RECIPIENT_TOKEN ]
1802- )["choices" ][0 ]["text" ]
1803- if len (next_turn_completion_text ) > 0 :
1804- prompt += f"\n { FROM_TOKEN } assistant\n { RECIPIENT_TOKEN } "
1777+ stops = CONTENT_TOKEN
1778+ completion = create_completion (stop = stops )
1779+ completion_text = completion ["choices" ][0 ]["text" ]
1780+ function_name = completion_text .strip ()
1781+ if function_name == "all" :
1782+ prompt += "all\n <|content|>"
18051783 else :
1806- break
1807- # Break from loop if tool_choice/function_call is provided as a dict
1808- else :
1809- function_bodies .append (completion_text .strip ())
1810- break
1784+ function_call = completion_text .strip ()
1785+ prompt += f"{ function_call } \n <|content|>"
1786+ function_calls .append (function_call )
1787+ grammar = get_grammar (function_call )
1788+ # Generate content
1789+ stops = [RECIPIENT_TOKEN , STOP_TOKEN ]
1790+ completion = create_completion (stop = stops )
1791+ completion_text = completion ["choices" ][0 ]["text" ]
1792+ if function_name == "all" :
1793+ content += completion_text .removesuffix ("\n <|from|>assistant\n " ).removesuffix ("\n <|from|> assistant\n " )
1794+ content = content .lstrip ()
1795+ # Check whether the model wants to generate another turn
1796+ if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text :
1797+ cleaned_completion_text = completion_text .removesuffix ("\n <|from|>assistant\n " ).removesuffix ("\n <|from|> assistant\n " ).strip ()
1798+ prompt += f"{ cleaned_completion_text } \n <|from|>assistant\n <|recipient|>"
1799+ else :
1800+ break
1801+ else :
1802+ function_bodies .append (completion_text .strip ())
1803+ # Check whether the model wants to generate another turn
1804+ prompt += completion_text .strip ()
1805+ grammar = None
1806+ completion = create_completion (stop = stops )
1807+ if "<|from|> assistant" in completion ["choices" ][0 ]["text" ] or "<|from|>assistant" in completion ["choices" ][0 ]["text" ]:
1808+ prompt += "\n <|from|>assistant\n <|recipient|>"
1809+ else :
1810+ break
18111811
18121812 assert "usage" in completion
1813- assert len (function_calls ) > 0
18141813 assert len (function_calls ) == len (function_bodies )
18151814
18161815 tool_calls = []
@@ -1843,14 +1842,14 @@ def create_completion(stop):
18431842 "index" : 0 ,
18441843 "message" : {
18451844 "role" : "assistant" ,
1846- "content" : None ,
1845+ "content" : None if content == "" else content ,
18471846 "function_call" : {
18481847 "name" : tool_calls [0 ]["function" ]["name" ],
18491848 "arguments" : tool_calls [0 ]["function" ]["arguments" ],
1850- },
1851- "tool_calls" : tool_calls ,
1849+ } if len ( tool_calls ) > 0 else None ,
1850+ "tool_calls" : tool_calls if len ( tool_calls ) > 0 else None ,
18521851 },
1853- "finish_reason" : "tool_calls" ,
1852+ "finish_reason" : "tool_calls" if len ( tool_calls ) > 0 else "stop" ,
18541853 }
18551854 ],
18561855 usage = completion ["usage" ],
0 commit comments