
Commit 23b897b

Commit message: types
Parent: b208b4a

File tree: 1 file changed, +34 -21 lines


src/inferencesh/models/llm.py

Lines changed: 34 additions & 21 deletions
@@ -342,25 +342,26 @@ def has_updates(self) -> bool:
 
         return has_content or has_tool_calls or has_usage or has_finish
 
-    def to_output(self, buffer: str, transformer: Any) -> LLMOutput:
+    def to_output(self, buffer: str, transformer: Any) -> tuple[BaseLLMOutput, str]:
         """Convert current state to LLMOutput."""
         buffer, output, _ = transformer(self.content, buffer)
 
-        # Add tool calls if present
-        if self.tool_calls:
+        # Add tool calls if present and supported
+        if self.tool_calls and hasattr(output, 'tool_calls'):
             output.tool_calls = self.tool_calls
 
-        # Add usage stats
-        output.usage = LLMUsage(
-            stop_reason=self.usage_stats["stop_reason"],
-            time_to_first_token=self.timing_stats["time_to_first_token"] or 0.0,
-            tokens_per_second=self.timing_stats["tokens_per_second"],
-            prompt_tokens=self.usage_stats["prompt_tokens"],
-            completion_tokens=self.usage_stats["completion_tokens"],
-            total_tokens=self.usage_stats["total_tokens"],
-            reasoning_time=self.timing_stats["reasoning_time"],
-            reasoning_tokens=self.timing_stats["reasoning_tokens"]
-        )
+        # Add usage stats if supported
+        if hasattr(output, 'usage'):
+            output.usage = LLMUsage(
+                stop_reason=self.usage_stats["stop_reason"],
+                time_to_first_token=self.timing_stats["time_to_first_token"] or 0.0,
+                tokens_per_second=self.timing_stats["tokens_per_second"],
+                prompt_tokens=self.usage_stats["prompt_tokens"],
+                completion_tokens=self.usage_stats["completion_tokens"],
+                total_tokens=self.usage_stats["total_tokens"],
+                reasoning_time=self.timing_stats["reasoning_time"],
+                reasoning_tokens=self.timing_stats["reasoning_tokens"]
+            )
 
         return output, buffer

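The hunk above replaces unconditional attribute assignment with hasattr guards, so to_output can populate an output object even when its class does not declare tool_calls or usage. A minimal, self-contained sketch of that duck-typing pattern follows; the dataclasses below are illustrative stand-ins, not the repo's BaseLLMOutput/LLMOutput definitions.

# Illustrative sketch only: stand-in classes, not the definitions from llm.py.
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class MinimalOutput:
    # A bare output type with no tool_calls/usage fields.
    text: str = ""


@dataclass
class RichOutput(MinimalOutput):
    # A richer output type that does declare the optional fields.
    tool_calls: Optional[List[Any]] = None
    usage: Optional[dict] = None


def attach_extras(output: MinimalOutput, tool_calls: List[Any], usage: dict) -> MinimalOutput:
    # Mirror of the guards added above: only set what the class declares.
    if tool_calls and hasattr(output, "tool_calls"):
        output.tool_calls = tool_calls
    if hasattr(output, "usage"):
        output.usage = usage
    return output


print(attach_extras(MinimalOutput(text="hi"), [{"name": "f"}], {"total_tokens": 3}))
print(attach_extras(RichOutput(text="hi"), [{"name": "f"}], {"total_tokens": 3}))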
@@ -495,14 +496,22 @@ def build_output(self) -> tuple[str, LLMOutput, dict]:
         Returns:
             Tuple of (buffer, LLMOutput, state_changes)
         """
+        output = self.output_cls(
+            response=self.state.response.strip(),
+            text=self.state.response.strip()  # text is required by BaseLLMOutput
+        )
+
+        # Add optional fields if supported
+        if hasattr(output, 'reasoning') and self.state.reasoning:
+            output.reasoning = self.state.reasoning.strip()
+        if hasattr(output, 'function_calls') and self.state.function_calls:
+            output.function_calls = self.state.function_calls
+        if hasattr(output, 'tool_calls') and self.state.tool_calls:
+            output.tool_calls = self.state.tool_calls
+
         return (
             self.state.buffer,
-            self.output_cls(
-                response=self.state.response.strip(),
-                reasoning=self.state.reasoning.strip() if self.state.reasoning else None,
-                function_calls=self.state.function_calls,
-                tool_calls=self.state.tool_calls
-            ),
+            output,
             self.state.state_changes
         )
 

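Where the old build_output passed reasoning, function_calls, and tool_calls straight into self.output_cls(...), the new version constructs the object from only the shared fields and attaches the optional ones afterwards, so an output class that does not accept those keyword arguments no longer raises a TypeError. A small sketch of that difference, again with illustrative stand-in classes rather than the repo's types:

# Illustrative sketch only: stand-in classes, not the repo's output types.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SlimOutput:
    response: str = ""
    text: str = ""


@dataclass
class FullOutput(SlimOutput):
    reasoning: Optional[str] = None


# Old style: pass every field to the constructor; fails for the slim class.
try:
    SlimOutput(response="hi", text="hi", reasoning="because")
except TypeError as exc:
    print("old-style construction fails:", exc)

# New style: construct with the shared fields, then attach extras if declared.
out = SlimOutput(response="hi", text="hi")
if hasattr(out, "reasoning"):  # False for SlimOutput, so nothing is forced on it
    out.reasoning = "because"
print(out)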
@@ -532,14 +541,18 @@ def stream_generate(
     max_tokens: int = 4096,
     stop: Optional[List[str]] = None,
     verbose: bool = False,
-) -> Generator[LLMOutput, None, None]:
+    output_cls: type[BaseLLMOutput] = LLMOutput,
+) -> Generator[BaseLLMOutput, None, None]:
     """Stream generate from LLaMA.cpp model with timing and usage tracking."""
 
     # Create queues for communication between threads
     response_queue = Queue()
     error_queue = Queue()
     keep_alive_queue = Queue()
 
+    # Set the output class for the transformer
+    transformer.output_cls = output_cls
+
     def _generate_worker():
         """Worker thread to run the model generation."""
         try:

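stream_generate already pushes results through queues from a worker thread; the new output_cls parameter is simply handed to the transformer before that machinery starts. For readers unfamiliar with the queue-and-worker pattern the function is built on, here is a generic, self-contained sketch with stand-in names, not the repo's implementation:

# Generic sketch of the queue-and-worker streaming pattern; not the repo's code.
import threading
from queue import Queue
from typing import Generator

_SENTINEL = object()


def stream_numbers(n: int) -> Generator[int, None, None]:
    response_queue: Queue = Queue()

    def _generate_worker() -> None:
        # Producer thread: push each result as it is ready, then a sentinel.
        for i in range(n):
            response_queue.put(i)
        response_queue.put(_SENTINEL)

    threading.Thread(target=_generate_worker, daemon=True).start()

    # Consumer side: yield items until the sentinel signals completion.
    while True:
        item = response_queue.get()
        if item is _SENTINEL:
            break
        yield item


print(list(stream_numbers(5)))  # [0, 1, 2, 3, 4]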