@@ -89,6 +89,8 @@ class LLMInput(BaseAppInput):
 
     # Model specific flags
     reasoning: bool = Field(default=False)
+
+    tools: List[Dict[str, Any]] = Field(default=[])
 
 class LLMUsage(BaseAppOutput):
     stop_reason: str = ""
@@ -104,6 +106,7 @@ class LLMUsage(BaseAppOutput):
 class LLMOutput(BaseAppOutput):
     response: str
     reasoning: Optional[str] = None
+    tool_calls: Optional[List[Dict[str, Any]]] = None
     usage: Optional[LLMUsage] = None
 
 
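
The new tools field carries OpenAI-style function schemas, which is the format llama-cpp-python's create_chat_completion accepts. A minimal sketch of one entry (the get_weather tool is a hypothetical example, not part of this PR):

# Hypothetical tool schema in the OpenAI function-calling format.
example_tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical example tool
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]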
@@ -362,6 +365,8 @@ def __call__(self, piece: str, buffer: str) -> tuple[str, LLMOutput, dict]:
 def stream_generate(
     model: Any,
     messages: List[Dict[str, Any]],
+    tools: List[Dict[str, Any]],
+    tool_choice: Dict[str, Any],
     transformer: ResponseTransformer,
     temperature: float = 0.7,
     top_p: float = 0.95,
@@ -379,7 +384,7 @@ def stream_generate(
         max_tokens: Maximum tokens to generate
         stop: Optional list of stop sequences
     """
-    response_queue: Queue[Optional[tuple[str, dict]]] = Queue()
+    response_queue: Queue[Optional[tuple[str, dict, Optional[List[Dict[str, Any]]]]]] = Queue()
     thread_exception = None
     usage_stats = {
         "prompt_tokens": 0,
@@ -397,6 +402,8 @@ def generation_thread():
         try:
             completion = model.create_chat_completion(
                 messages=messages,
+                tools=tools,
+                tool_choice=tool_choice,
                 stream=True,
                 temperature=temperature,
                 top_p=top_p,
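
These kwargs only take effect if the loaded model's chat handler supports function calling; with llama-cpp-python one option is the built-in chatml-function-calling chat format. A loading sketch, with the model path as a placeholder:

# Sketch: load the model with a chat format that honors tools/tool_choice.
from llama_cpp import Llama

llm = Llama(
    model_path="/path/to/model.gguf",       # placeholder path
    chat_format="chatml-function-calling",  # built-in format with tool support
    n_ctx=4096,
)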
@@ -411,18 +418,23 @@ def generation_thread():
                 delta = chunk.get("choices", [{}])[0]
                 content = None
                 finish_reason = None
+                tool_calls = None
 
                 if "message" in delta:
-                    content = delta["message"].get("content", "")
+                    message = delta["message"]
+                    content = message.get("content", "")
+                    tool_calls = message.get("tool_calls")
                     finish_reason = delta.get("finish_reason")
                 elif "delta" in delta:
-                    content = delta["delta"].get("content", "")
+                    delta_content = delta["delta"]
+                    content = delta_content.get("content", "")
+                    tool_calls = delta_content.get("tool_calls")
                     finish_reason = delta.get("finish_reason")
 
-                if content:
+                if content or tool_calls:
                     if not timing.first_token_time:
                         timing.mark_first_token()
-                    response_queue.put((content, {}))
+                    response_queue.put((content or "", {}, tool_calls))
 
                 if finish_reason:
                     usage_stats["stop_reason"] = finish_reason
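
One caveat for consumers: in OpenAI-style streaming, tool_calls arrive as fragments keyed by "index", with the function arguments string split across chunks, so complete calls have to be accumulated. A sketch of such an accumulator (an assumed helper, not part of this PR):

# Hypothetical helper that merges streamed tool-call fragments.
def merge_tool_call_deltas(acc: dict, deltas: list) -> dict:
    for d in deltas or []:
        slot = acc.setdefault(d.get("index", 0),
                              {"id": None, "name": "", "arguments": ""})
        if d.get("id"):
            slot["id"] = d["id"]  # the id arrives once, on the first fragment
        fn = d.get("function", {})
        slot["name"] += fn.get("name") or ""
        slot["arguments"] += fn.get("arguments") or ""  # JSON string pieces
    return acc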
@@ -438,7 +450,7 @@ def generation_thread():
                 "tokens_per_second": tokens_per_second,
                 "reasoning_time": timing_stats["reasoning_time"],
                 "reasoning_tokens": timing_stats["reasoning_tokens"]
-            }))
+            }, None))
 
     thread = Thread(target=generation_thread, daemon=True)
     thread.start()
@@ -451,7 +463,7 @@ def generation_thread():
             if thread_exception:
                 raise thread_exception
 
-            piece, timing_stats = result
+            piece, timing_stats, tool_calls = result
             if piece is None:
                 # Final yield with complete usage stats
                 usage = LLMUsage(
@@ -467,10 +479,14 @@ def generation_thread():
 
                 buffer, output, _ = transformer(piece or "", buffer)
                 output.usage = usage
+                if tool_calls:
+                    output.tool_calls = tool_calls
                 yield output
                 break
 
             buffer, output, _ = transformer(piece, buffer)
+            if tool_calls:
+                output.tool_calls = tool_calls
             yield output
 
     except Exception as e:
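
With the plumbing above, a caller forwards the new fields and checks each streamed output for tool calls. A rough usage sketch, assuming the llm, example_tools, and transformer names from the earlier sketches (the dict form of tool_choice matches the new Dict[str, Any] annotation and forces the named tool):

# Rough usage sketch; llm, example_tools, and transformer are assumed above.
for output in stream_generate(
    model=llm,
    messages=[{"role": "user", "content": "What's the weather in Oslo?"}],
    tools=example_tools,
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
    transformer=transformer,  # assumed ResponseTransformer instance
):
    if output.tool_calls:
        print("tool call fragment:", output.tool_calls)
    elif output.response:
        print(output.response, end="")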