Skip to content

Commit 11fc491

Browse files
committed
types
1 parent 23b897b commit 11fc491

File tree

1 file changed

+31
-23
lines changed

1 file changed

+31
-23
lines changed

src/inferencesh/models/llm.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,7 @@ class LLMUsage(BaseAppOutput):
112112

113113
class BaseLLMOutput(BaseAppOutput):
114114
"""Base class for LLM outputs with common fields."""
115-
text: str = Field(description="The generated text response")
116-
done: bool = Field(default=False, description="Whether this is the final chunk")
115+
response: str = Field(description="The generated text response")
117116

118117
class LLMUsageMixin(BaseModel):
119118
"""Mixin for models that provide token usage statistics."""
@@ -344,15 +343,10 @@ def has_updates(self) -> bool:
344343

345344
def to_output(self, buffer: str, transformer: Any) -> tuple[BaseLLMOutput, str]:
346345
"""Convert current state to LLMOutput."""
347-
buffer, output, _ = transformer(self.content, buffer)
348-
349-
# Add tool calls if present and supported
350-
if self.tool_calls and hasattr(output, 'tool_calls'):
351-
output.tool_calls = self.tool_calls
352-
353-
# Add usage stats if supported
354-
if hasattr(output, 'usage'):
355-
output.usage = LLMUsage(
346+
# Create usage object if we have stats
347+
usage = None
348+
if any(self.usage_stats.values()):
349+
usage = LLMUsage(
356350
stop_reason=self.usage_stats["stop_reason"],
357351
time_to_first_token=self.timing_stats["time_to_first_token"] or 0.0,
358352
tokens_per_second=self.timing_stats["tokens_per_second"],
@@ -362,6 +356,12 @@ def to_output(self, buffer: str, transformer: Any) -> tuple[BaseLLMOutput, str]:
362356
reasoning_time=self.timing_stats["reasoning_time"],
363357
reasoning_tokens=self.timing_stats["reasoning_tokens"]
364358
)
359+
360+
buffer, output, _ = transformer(self.content, buffer, usage)
361+
362+
# Add tool calls if present and supported
363+
if self.tool_calls and hasattr(output, 'tool_calls'):
364+
output.tool_calls = self.tool_calls
365365

366366
return output, buffer
367367

@@ -374,6 +374,7 @@ def __init__(self):
374374
self.function_calls = None # For future function calling support
375375
self.tool_calls = None # List to accumulate tool calls
376376
self.current_tool_call = None # Track current tool call being built
377+
self.usage = None # Add usage field
377378
self.state_changes = {
378379
"reasoning_started": False,
379380
"reasoning_ended": False,
@@ -496,36 +497,43 @@ def build_output(self) -> tuple[str, LLMOutput, dict]:
496497
Returns:
497498
Tuple of (buffer, LLMOutput, state_changes)
498499
"""
499-
output = self.output_cls(
500-
response=self.state.response.strip(),
501-
text=self.state.response.strip() # text is required by BaseLLMOutput
502-
)
500+
# Build base output with required fields
501+
output_data = {
502+
"response": self.state.response.strip(),
503+
}
503504

504-
# Add optional fields if supported
505-
if hasattr(output, 'reasoning') and self.state.reasoning:
506-
output.reasoning = self.state.reasoning.strip()
507-
if hasattr(output, 'function_calls') and self.state.function_calls:
508-
output.function_calls = self.state.function_calls
509-
if hasattr(output, 'tool_calls') and self.state.tool_calls:
510-
output.tool_calls = self.state.tool_calls
505+
# Add optional fields if they exist
506+
if self.state.usage is not None:
507+
output_data["usage"] = self.state.usage
508+
if self.state.reasoning:
509+
output_data["reasoning"] = self.state.reasoning.strip()
510+
if self.state.function_calls:
511+
output_data["function_calls"] = self.state.function_calls
512+
if self.state.tool_calls:
513+
output_data["tool_calls"] = self.state.tool_calls
514+
515+
output = self.output_cls(**output_data)
511516

512517
return (
513518
self.state.buffer,
514519
output,
515520
self.state.state_changes
516521
)
517522

518-
def __call__(self, piece: str, buffer: str) -> tuple[str, LLMOutput, dict]:
523+
def __call__(self, piece: str, buffer: str, usage: Optional[LLMUsage] = None) -> tuple[str, LLMOutput, dict]:
519524
"""Transform a piece of text and return the result.
520525
521526
Args:
522527
piece: New piece of text to transform
523528
buffer: Existing buffer content
529+
usage: Optional usage statistics
524530
525531
Returns:
526532
Tuple of (new_buffer, output, state_changes)
527533
"""
528534
self.state.buffer = buffer
535+
if usage is not None:
536+
self.state.usage = usage
529537
self.transform_chunk(piece)
530538
return self.build_output()
531539

0 commit comments

Comments
 (0)