Commit dd221ce

llm function calling v0
1 parent a8f82d3 commit dd221ce

File tree

1 file changed: +23 -7 lines changed

src/inferencesh/models/llm.py

Lines changed: 23 additions & 7 deletions
@@ -89,6 +89,8 @@ class LLMInput(BaseAppInput):
 
     # Model specific flags
     reasoning: bool = Field(default=False)
+
+    tools: List[Dict[str, Any]] = Field(default=[])
 
 class LLMUsage(BaseAppOutput):
     stop_reason: str = ""
@@ -104,6 +106,7 @@ class LLMUsage(BaseAppOutput):
 class LLMOutput(BaseAppOutput):
     response: str
     reasoning: Optional[str] = None
+    tool_calls: Optional[List[Dict[str, Any]]] = None
     usage: Optional[LLMUsage] = None
 
 
@@ -362,6 +365,8 @@ def __call__(self, piece: str, buffer: str) -> tuple[str, LLMOutput, dict]:
 def stream_generate(
     model: Any,
     messages: List[Dict[str, Any]],
+    tools: List[Dict[str, Any]],
+    tool_choice: Dict[str, Any],
     transformer: ResponseTransformer,
     temperature: float = 0.7,
     top_p: float = 0.95,
@@ -379,7 +384,7 @@ def stream_generate(
         max_tokens: Maximum tokens to generate
         stop: Optional list of stop sequences
     """
-    response_queue: Queue[Optional[tuple[str, dict]]] = Queue()
+    response_queue: Queue[Optional[tuple[str, dict, Optional[List[Dict[str, Any]]]]]] = Queue()
     thread_exception = None
     usage_stats = {
         "prompt_tokens": 0,
@@ -397,6 +402,8 @@ def generation_thread():
         try:
             completion = model.create_chat_completion(
                 messages=messages,
+                tools=tools,
+                tool_choice=tool_choice,
                 stream=True,
                 temperature=temperature,
                 top_p=top_p,
@@ -411,18 +418,23 @@ def generation_thread():
                 delta = chunk.get("choices", [{}])[0]
                 content = None
                 finish_reason = None
+                tool_calls = None
 
                 if "message" in delta:
-                    content = delta["message"].get("content", "")
+                    message = delta["message"]
+                    content = message.get("content", "")
+                    tool_calls = message.get("tool_calls")
                     finish_reason = delta.get("finish_reason")
                 elif "delta" in delta:
-                    content = delta["delta"].get("content", "")
+                    delta_content = delta["delta"]
+                    content = delta_content.get("content", "")
+                    tool_calls = delta_content.get("tool_calls")
                     finish_reason = delta.get("finish_reason")
 
-                if content:
+                if content or tool_calls:
                     if not timing.first_token_time:
                         timing.mark_first_token()
-                    response_queue.put((content, {}))
+                    response_queue.put((content or "", {}, tool_calls))
 
                 if finish_reason:
                     usage_stats["stop_reason"] = finish_reason
@@ -438,7 +450,7 @@ def generation_thread():
                 "tokens_per_second": tokens_per_second,
                 "reasoning_time": timing_stats["reasoning_time"],
                 "reasoning_tokens": timing_stats["reasoning_tokens"]
-            }))
+            }, None))
 
     thread = Thread(target=generation_thread, daemon=True)
     thread.start()
@@ -451,7 +463,7 @@ def generation_thread():
             if thread_exception:
                 raise thread_exception
 
-            piece, timing_stats = result
+            piece, timing_stats, tool_calls = result
             if piece is None:
                 # Final yield with complete usage stats
                 usage = LLMUsage(
@@ -467,10 +479,14 @@ def generation_thread():
 
                 buffer, output, _ = transformer(piece or "", buffer)
                 output.usage = usage
+                if tool_calls:
+                    output.tool_calls = tool_calls
                 yield output
                 break
 
             buffer, output, _ = transformer(piece, buffer)
+            if tool_calls:
+                output.tool_calls = tool_calls
             yield output
 
     except Exception as e:
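
Usage note: a minimal caller-side sketch of the new tools / tool_choice parameters and of reading tool_calls off the streamed LLMOutput. It is not part of this commit; it assumes an OpenAI-style tool schema (the format llama.cpp-style create_chat_completion backends accept), an already-loaded model, a ResponseTransformer instance named transformer, and an import path derived from the file location. The tool name get_weather and the surrounding setup are illustrative only.

from inferencesh.models.llm import stream_generate  # assumed import path

# Hypothetical OpenAI-style tool definition (assumption, not from this diff)
weather_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

for output in stream_generate(
    model=model,              # assumed: an already-loaded chat model instance
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=[weather_tool],
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
    transformer=transformer,  # assumed: a ResponseTransformer built elsewhere
):
    if output.tool_calls:
        # Tool calls are passed through from the backend's message/delta
        # "tool_calls" entries as plain dicts.
        for call in output.tool_calls:
            print("tool call:", call)
    elif output.response:
        print(output.response, end="", flush=True)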
