Skip to content

Commit 3a67087

Browse files
committed
llm function calling v0
1 parent d35e69f commit 3a67087

File tree

1 file changed

+40
-20
lines changed

1 file changed

+40
-20
lines changed

src/inferencesh/models/llm.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ def __init__(self):
216216
self.response = ""
217217
self.reasoning = None
218218
self.function_calls = None # For future function calling support
219-
self.tool_calls = None # For future tool calling support
219+
self.tool_calls = [] # List to accumulate tool calls
220+
self.current_tool_call = None # Track current tool call being built
220221
self.state_changes = {
221222
"reasoning_started": False,
222223
"reasoning_ended": False,
@@ -373,17 +374,7 @@ def stream_generate(
373374
max_tokens: int = 4096,
374375
stop: Optional[List[str]] = None,
375376
) -> Generator[LLMOutput, None, None]:
376-
"""Stream generate from LLaMA.cpp model with timing and usage tracking.
377-
378-
Args:
379-
model: The LLaMA.cpp model instance
380-
messages: List of messages to send to the model
381-
transformer: ResponseTransformer instance to use for processing output
382-
temperature: Sampling temperature
383-
top_p: Top-p sampling threshold
384-
max_tokens: Maximum tokens to generate
385-
stop: Optional list of stop sequences
386-
"""
377+
"""Stream generate from LLaMA.cpp model with timing and usage tracking."""
387378
response_queue: Queue[Optional[tuple[str, dict, Optional[List[Dict[str, Any]]]]]] = Queue()
388379
thread_exception = None
389380
usage_stats = {
@@ -394,7 +385,6 @@ def stream_generate(
394385
}
395386

396387
with timing_context() as timing:
397-
# Set timing context in transformer
398388
transformer.timing = timing
399389

400390
def generation_thread():
@@ -411,30 +401,60 @@ def generation_thread():
411401
stop=stop
412402
)
413403

404+
tool_calls = []
405+
current_tool = None
406+
414407
for chunk in completion:
415408
if "usage" in chunk and chunk["usage"] is not None:
416409
usage_stats.update(chunk["usage"])
417410

418411
delta = chunk.get("choices", [{}])[0]
419-
content = None
412+
content = ""
420413
finish_reason = None
421-
tool_calls = None
422414

415+
# Extract delta content from either message or delta
423416
if "message" in delta:
424417
message = delta["message"]
425418
content = message.get("content", "")
426-
tool_calls = message.get("tool_calls")
419+
if "tool_calls" in message:
420+
for tool in message["tool_calls"]:
421+
if tool.get("id") not in {t.get("id") for t in tool_calls}:
422+
tool_calls.append(tool)
427423
finish_reason = delta.get("finish_reason")
428424
elif "delta" in delta:
429425
delta_content = delta["delta"]
430426
content = delta_content.get("content", "")
431-
tool_calls = delta_content.get("tool_calls")
427+
428+
# Handle streaming tool calls
429+
if "tool_calls" in delta_content:
430+
for tool_delta in delta_content["tool_calls"]:
431+
tool_id = tool_delta.get("id")
432+
433+
# Find or create tool call
434+
if tool_id:
435+
current_tool = next((t for t in tool_calls if t["id"] == tool_id), None)
436+
if not current_tool:
437+
current_tool = {
438+
"id": tool_id,
439+
"type": tool_delta.get("type", "function"),
440+
"function": {"name": "", "arguments": ""}
441+
}
442+
tool_calls.append(current_tool)
443+
444+
# Update tool call
445+
if current_tool and "function" in tool_delta:
446+
func_delta = tool_delta["function"]
447+
if "name" in func_delta:
448+
current_tool["function"]["name"] = func_delta["name"]
449+
if "arguments" in func_delta:
450+
current_tool["function"]["arguments"] += func_delta["arguments"]
451+
432452
finish_reason = delta.get("finish_reason")
433453

434-
if content or tool_calls:
454+
if content or "tool_calls" in (delta.get("message", {}) or delta.get("delta", {})):
435455
if not timing.first_token_time:
436456
timing.mark_first_token()
437-
response_queue.put((content or "", {}, tool_calls))
457+
response_queue.put((content, {}, tool_calls[:] if tool_calls else None))
438458

439459
if finish_reason:
440460
usage_stats["stop_reason"] = finish_reason
@@ -450,7 +470,7 @@ def generation_thread():
450470
"tokens_per_second": tokens_per_second,
451471
"reasoning_time": timing_stats["reasoning_time"],
452472
"reasoning_tokens": timing_stats["reasoning_tokens"]
453-
}, None))
473+
}, tool_calls if tool_calls else None))
454474

455475
thread = Thread(target=generation_thread, daemon=True)
456476
thread.start()

0 commit comments

Comments (0)