Commit a118b7d

improve llm sdk

1 parent c8ac873

File tree

1 file changed: +181 -63 lines

1 file changed

+181
-63
lines changed

src/inferencesh/models/llm.py

Lines changed: 181 additions & 63 deletions
@@ -88,7 +88,7 @@ class LLMInput(BaseAppInput):
     context_size: int = Field(default=4096)
 
     # Model specific flags
-    enable_thinking: bool = Field(default=False)
+    reasoning: bool = Field(default=False)
 
 class LLMUsage(BaseAppOutput):
     stop_reason: str = ""
@@ -97,11 +97,13 @@ class LLMUsage(BaseAppOutput):
     prompt_tokens: int = 0
     completion_tokens: int = 0
     total_tokens: int = 0
+    reasoning_tokens: int = 0
+    reasoning_time: float = 0.0
 
 
 class LLMOutput(BaseAppOutput):
     response: str
-    thinking_content: Optional[str] = None
+    reasoning: Optional[str] = None
     usage: Optional[LLMUsage] = None
 
 
@@ -112,11 +114,27 @@ class TimingInfo:
     def __init__(self):
         self.start_time = time.time()
         self.first_token_time = None
+        self.reasoning_start_time = None
+        self.total_reasoning_time = 0.0
+        self.reasoning_tokens = 0
+        self.in_reasoning = False
 
     def mark_first_token(self):
         if self.first_token_time is None:
             self.first_token_time = time.time()
 
+    def start_reasoning(self):
+        if not self.in_reasoning:
+            self.reasoning_start_time = time.time()
+            self.in_reasoning = True
+
+    def end_reasoning(self, token_count: int = 0):
+        if self.in_reasoning and self.reasoning_start_time:
+            self.total_reasoning_time += time.time() - self.reasoning_start_time
+            self.reasoning_tokens += token_count
+            self.reasoning_start_time = None
+            self.in_reasoning = False
+
     @property
     def stats(self):
         end_time = time.time()
@@ -128,7 +146,9 @@ def stats(self):
 
         return {
             "time_to_first_token": time_to_first,
-            "generation_time": generation_time
+            "generation_time": generation_time,
+            "reasoning_time": self.total_reasoning_time,
+            "reasoning_tokens": self.reasoning_tokens
         }
 
 timing = TimingInfo()
@@ -186,29 +206,170 @@ def render_message(msg: ContextMessage, allow_multipart: bool) -> str | List[dic
     return messages
 
 
+class ResponseState:
+    """Holds the state of response transformation."""
+    def __init__(self):
+        self.buffer = ""
+        self.response = ""
+        self.reasoning = None
+        self.function_calls = None  # For future function calling support
+        self.tool_calls = None  # For future tool calling support
+        self.state_changes = {
+            "reasoning_started": False,
+            "reasoning_ended": False,
+            "function_call_started": False,
+            "function_call_ended": False,
+            "tool_call_started": False,
+            "tool_call_ended": False
+        }
+
+class ResponseTransformer:
+    """Base class for transforming model responses."""
+    def __init__(self, output_cls: type[LLMOutput] = LLMOutput):
+        self.state = ResponseState()
+        self.output_cls = output_cls
+
+    def clean_text(self, text: str) -> str:
+        """Clean common tokens from the text and apply model-specific cleaning.
+
+        Args:
+            text: Raw text to clean
+
+        Returns:
+            Cleaned text with common and model-specific tokens removed
+        """
+        # Common token cleaning across most models
+        cleaned = (text.replace("<|im_end|>", "")
+                   .replace("<|im_start|>", "")
+                   .replace("<start_of_turn>", "")
+                   .replace("<end_of_turn>", "")
+                   .replace("<eos>", ""))
+        return self.additional_cleaning(cleaned)
+
+    def additional_cleaning(self, text: str) -> str:
+        """Apply model-specific token cleaning.
+
+        Args:
+            text: Text that has had common tokens removed
+
+        Returns:
+            Text with model-specific tokens removed
+        """
+        return text
+
+    def handle_reasoning(self, text: str) -> None:
+        """Handle reasoning/thinking detection and extraction.
+
+        Args:
+            text: Cleaned text to process for reasoning
+        """
+        # Default implementation for <think> style reasoning
+        if "<think>" in text:
+            self.state.state_changes["reasoning_started"] = True
+        if "</think>" in text:
+            self.state.state_changes["reasoning_ended"] = True
+
+        if "<think>" in self.state.buffer:
+            parts = self.state.buffer.split("</think>", 1)
+            if len(parts) > 1:
+                self.state.reasoning = parts[0].split("<think>", 1)[1].strip()
+                self.state.response = parts[1].strip()
+            else:
+                self.state.reasoning = self.state.buffer.split("<think>", 1)[1].strip()
+                self.state.response = ""
+        else:
+            self.state.response = self.state.buffer
+
+    def handle_function_calls(self, text: str) -> None:
+        """Handle function call detection and extraction.
+
+        Args:
+            text: Cleaned text to process for function calls
+        """
+        # Default no-op implementation
+        # Models can override this to implement function call handling
+        pass
+
+    def handle_tool_calls(self, text: str) -> None:
+        """Handle tool call detection and extraction.
+
+        Args:
+            text: Cleaned text to process for tool calls
+        """
+        # Default no-op implementation
+        # Models can override this to implement tool call handling
+        pass
+
+    def transform_chunk(self, chunk: str) -> None:
+        """Transform a single chunk of model output.
+
+        This method orchestrates the transformation process by:
+        1. Cleaning the text
+        2. Updating the buffer
+        3. Processing various capabilities (reasoning, function calls, etc)
+
+        Args:
+            chunk: Raw text chunk from the model
+        """
+        cleaned = self.clean_text(chunk)
+        self.state.buffer += cleaned
+
+        # Process different capabilities
+        self.handle_reasoning(cleaned)
+        self.handle_function_calls(cleaned)
+        self.handle_tool_calls(cleaned)
+
+    def build_output(self) -> tuple[str, LLMOutput, dict]:
+        """Build the final output tuple.
+
+        Returns:
+            Tuple of (buffer, LLMOutput, state_changes)
+        """
+        return (
+            self.state.buffer,
+            self.output_cls(
+                response=self.state.response.strip(),
+                reasoning=self.state.reasoning.strip() if self.state.reasoning else None,
+                function_calls=self.state.function_calls,
+                tool_calls=self.state.tool_calls
+            ),
+            self.state.state_changes
+        )
+
+    def __call__(self, piece: str, buffer: str) -> tuple[str, LLMOutput, dict]:
+        """Transform a piece of text and return the result.
+
+        Args:
+            piece: New piece of text to transform
+            buffer: Existing buffer content
+
+        Returns:
+            Tuple of (new_buffer, output, state_changes)
+        """
+        self.state.buffer = buffer
+        self.transform_chunk(piece)
+        return self.build_output()
+
+
 def stream_generate(
     model: Any,
     messages: List[Dict[str, Any]],
-    output_cls: type[LLMOutput],
+    transformer: ResponseTransformer,
     temperature: float = 0.7,
     top_p: float = 0.95,
     max_tokens: int = 4096,
     stop: Optional[List[str]] = None,
-    handle_thinking: bool = False,
-    transform_response: Optional[Callable[[str, str], tuple[str, LLMOutput]]] = None,
 ) -> Generator[LLMOutput, None, None]:
     """Stream generate from LLaMA.cpp model with timing and usage tracking.
 
     Args:
         model: The LLaMA.cpp model instance
         messages: List of messages to send to the model
-        output_cls: Output class type to use for responses
+        transformer: ResponseTransformer instance to use for processing output
         temperature: Sampling temperature
         top_p: Top-p sampling threshold
         max_tokens: Maximum tokens to generate
         stop: Optional list of stop sequences
-        handle_thinking: Whether to handle thinking tags
-        transform_response: Optional function to transform responses, takes (piece, buffer) and returns (new_buffer, output)
     """
     response_queue: Queue[Optional[tuple[str, dict]]] = Queue()
     thread_exception = None
@@ -233,11 +394,9 @@ def generation_thread():
         )
 
         for chunk in completion:
-            # Get usage from root level if present
             if "usage" in chunk and chunk["usage"] is not None:
                 usage_stats.update(chunk["usage"])
 
-            # Get content from choices
             delta = chunk.get("choices", [{}])[0]
             content = None
             finish_reason = None
@@ -265,15 +424,15 @@ def generation_thread():
         tokens_per_second = (usage_stats["completion_tokens"] / generation_time) if generation_time > 0 else 0
         response_queue.put((None, {
             "time_to_first_token": timing_stats["time_to_first_token"],
-            "tokens_per_second": tokens_per_second
+            "tokens_per_second": tokens_per_second,
+            "reasoning_time": timing_stats["reasoning_time"],
+            "reasoning_tokens": timing_stats["reasoning_tokens"]
         }))
 
     thread = Thread(target=generation_thread, daemon=True)
     thread.start()
 
     buffer = ""
-    thinking_content = "" if handle_thinking else None
-    in_thinking = handle_thinking
     try:
         while True:
             try:
@@ -290,59 +449,18 @@ def generation_thread():
                         tokens_per_second=timing_stats["tokens_per_second"],
                         prompt_tokens=usage_stats["prompt_tokens"],
                         completion_tokens=usage_stats["completion_tokens"],
-                        total_tokens=usage_stats["total_tokens"]
+                        total_tokens=usage_stats["total_tokens"],
+                        reasoning_time=timing_stats["reasoning_time"],
+                        reasoning_tokens=timing_stats["reasoning_tokens"]
                     )
 
-                    if transform_response:
-                        buffer, output = transform_response(piece or "", buffer)
-                        output.usage = usage
-                        yield output
-                    else:
-                        # Handle thinking vs response content if enabled
-                        if handle_thinking and "</think>" in piece:
-                            parts = piece.split("</think>")
-                            if in_thinking:
-                                thinking_content += parts[0].replace("<think>", "")
-                                buffer = parts[1] if len(parts) > 1 else ""
-                                in_thinking = False
-                            else:
-                                buffer += piece
-                        else:
-                            if in_thinking:
-                                thinking_content += piece.replace("<think>", "")
-                            else:
-                                buffer += piece
-
-                        yield output_cls(
-                            response=buffer.strip(),
-                            thinking_content=thinking_content.strip() if thinking_content else None,
-                            usage=usage
-                        )
-                    break
-
-                if transform_response:
-                    buffer, output = transform_response(piece, buffer)
+                    buffer, output, _ = transformer(piece or "", buffer)
+                    output.usage = usage
                     yield output
-                else:
-                    # Handle thinking vs response content if enabled
-                    if handle_thinking and "</think>" in piece:
-                        parts = piece.split("</think>")
-                        if in_thinking:
-                            thinking_content += parts[0].replace("<think>", "")
-                            buffer = parts[1] if len(parts) > 1 else ""
-                            in_thinking = False
-                        else:
-                            buffer += piece
-                    else:
-                        if in_thinking:
-                            thinking_content += piece.replace("<think>", "")
-                        else:
-                            buffer += piece
+                    break
 
-                yield output_cls(
-                    response=buffer.strip(),
-                    thinking_content=thinking_content.strip() if thinking_content else None
-                )
+                buffer, output, _ = transformer(piece, buffer)
+                yield output
 
             except Exception as e:
                 if thread_exception and isinstance(e, thread_exception.__class__):
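For context, a minimal usage sketch of the transformer-based API this commit introduces (not part of the diff): it assumes `ResponseTransformer`, `LLMOutput`, and `stream_generate` are importable from `inferencesh.models.llm`, and the `ExampleTransformer` subclass, the extra `<end_of_text>` token, and the llama-cpp-python model setup are illustrative assumptions only.

from llama_cpp import Llama  # llama-cpp-python; model loading shown only for illustration

from inferencesh.models.llm import LLMOutput, ResponseTransformer, stream_generate


class ExampleTransformer(ResponseTransformer):
    """Illustrative subclass: strips one extra, assumed model-specific token."""

    def additional_cleaning(self, text: str) -> str:
        # "<end_of_text>" is an assumed token, used here only to show the override point.
        return text.replace("<end_of_text>", "")


model = Llama(model_path="path/to/model.gguf")  # placeholder path
messages = [{"role": "user", "content": "Why is the sky blue?"}]

# stream_generate now takes a ResponseTransformer instead of output_cls /
# handle_thinking / transform_response; each yielded LLMOutput carries the
# partial response, any extracted reasoning, and (on the final chunk) usage stats.
for output in stream_generate(
    model,
    messages,
    transformer=ExampleTransformer(output_cls=LLMOutput),
    temperature=0.7,
    max_tokens=256,
):
    print(output.response)
    if output.reasoning:
        print("reasoning:", output.reasoning)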
