@@ -88,7 +88,7 @@ class LLMInput(BaseAppInput):
     context_size: int = Field(default=4096)

     # Model specific flags
-    enable_thinking: bool = Field(default=False)
+    reasoning: bool = Field(default=False)

 class LLMUsage(BaseAppOutput):
     stop_reason: str = ""
@@ -97,11 +97,13 @@ class LLMUsage(BaseAppOutput):
     prompt_tokens: int = 0
     completion_tokens: int = 0
     total_tokens: int = 0
+    reasoning_tokens: int = 0
+    reasoning_time: float = 0.0


 class LLMOutput(BaseAppOutput):
     response: str
-    thinking_content: Optional[str] = None
+    reasoning: Optional[str] = None
     usage: Optional[LLMUsage] = None


@@ -112,11 +114,27 @@ class TimingInfo:
     def __init__(self):
         self.start_time = time.time()
         self.first_token_time = None
+        self.reasoning_start_time = None
+        self.total_reasoning_time = 0.0
+        self.reasoning_tokens = 0
+        self.in_reasoning = False

     def mark_first_token(self):
         if self.first_token_time is None:
             self.first_token_time = time.time()

+    def start_reasoning(self):
+        if not self.in_reasoning:
+            self.reasoning_start_time = time.time()
+            self.in_reasoning = True
+
+    def end_reasoning(self, token_count: int = 0):
+        if self.in_reasoning and self.reasoning_start_time:
+            self.total_reasoning_time += time.time() - self.reasoning_start_time
+            self.reasoning_tokens += token_count
+            self.reasoning_start_time = None
+            self.in_reasoning = False
+
     @property
     def stats(self):
         end_time = time.time()
@@ -128,7 +146,9 @@ def stats(self):

         return {
             "time_to_first_token": time_to_first,
-            "generation_time": generation_time
+            "generation_time": generation_time,
+            "reasoning_time": self.total_reasoning_time,
+            "reasoning_tokens": self.reasoning_tokens
         }

 timing = TimingInfo()
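
For context, a minimal sketch of how the new reasoning hooks on TimingInfo are meant to be driven (illustrative only; the call sites are not part of this hunk, and the token count is a placeholder):

timing = TimingInfo()

timing.mark_first_token()
timing.start_reasoning()              # e.g. when an opening <think> tag is seen
# ... reasoning tokens stream in ...
timing.end_reasoning(token_count=42)  # e.g. when </think> is seen; 42 is a placeholder count

stats = timing.stats
print(stats["reasoning_time"], stats["reasoning_tokens"])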
@@ -186,29 +206,170 @@ def render_message(msg: ContextMessage, allow_multipart: bool) -> str | List[dic
     return messages


+class ResponseState:
+    """Holds the state of response transformation."""
+    def __init__(self):
+        self.buffer = ""
+        self.response = ""
+        self.reasoning = None
+        self.function_calls = None  # For future function calling support
+        self.tool_calls = None  # For future tool calling support
+        self.state_changes = {
+            "reasoning_started": False,
+            "reasoning_ended": False,
+            "function_call_started": False,
+            "function_call_ended": False,
+            "tool_call_started": False,
+            "tool_call_ended": False
+        }
+
+class ResponseTransformer:
+    """Base class for transforming model responses."""
+    def __init__(self, output_cls: type[LLMOutput] = LLMOutput):
+        self.state = ResponseState()
+        self.output_cls = output_cls
+
+    def clean_text(self, text: str) -> str:
+        """Clean common tokens from the text and apply model-specific cleaning.
+
+        Args:
+            text: Raw text to clean
+
+        Returns:
+            Cleaned text with common and model-specific tokens removed
+        """
+        # Common token cleaning across most models
+        cleaned = (text.replace("<|im_end|>", "")
+                   .replace("<|im_start|>", "")
+                   .replace("<start_of_turn>", "")
+                   .replace("<end_of_turn>", "")
+                   .replace("<eos>", ""))
+        return self.additional_cleaning(cleaned)
+
+    def additional_cleaning(self, text: str) -> str:
+        """Apply model-specific token cleaning.
+
+        Args:
+            text: Text that has had common tokens removed
+
+        Returns:
+            Text with model-specific tokens removed
+        """
+        return text
+
+    def handle_reasoning(self, text: str) -> None:
+        """Handle reasoning/thinking detection and extraction.
+
+        Args:
+            text: Cleaned text to process for reasoning
+        """
+        # Default implementation for <think> style reasoning
+        if "<think>" in text:
+            self.state.state_changes["reasoning_started"] = True
+        if "</think>" in text:
+            self.state.state_changes["reasoning_ended"] = True
+
+        if "<think>" in self.state.buffer:
+            parts = self.state.buffer.split("</think>", 1)
+            if len(parts) > 1:
+                self.state.reasoning = parts[0].split("<think>", 1)[1].strip()
+                self.state.response = parts[1].strip()
+            else:
+                self.state.reasoning = self.state.buffer.split("<think>", 1)[1].strip()
+                self.state.response = ""
+        else:
+            self.state.response = self.state.buffer
+
+    def handle_function_calls(self, text: str) -> None:
+        """Handle function call detection and extraction.
+
+        Args:
+            text: Cleaned text to process for function calls
+        """
+        # Default no-op implementation
+        # Models can override this to implement function call handling
+        pass
+
+    def handle_tool_calls(self, text: str) -> None:
+        """Handle tool call detection and extraction.
+
+        Args:
+            text: Cleaned text to process for tool calls
+        """
+        # Default no-op implementation
+        # Models can override this to implement tool call handling
+        pass
+
+    def transform_chunk(self, chunk: str) -> None:
+        """Transform a single chunk of model output.
+
+        This method orchestrates the transformation process by:
+        1. Cleaning the text
+        2. Updating the buffer
+        3. Processing various capabilities (reasoning, function calls, etc.)
+
+        Args:
+            chunk: Raw text chunk from the model
+        """
+        cleaned = self.clean_text(chunk)
+        self.state.buffer += cleaned
+
+        # Process different capabilities
+        self.handle_reasoning(cleaned)
+        self.handle_function_calls(cleaned)
+        self.handle_tool_calls(cleaned)
+
+    def build_output(self) -> tuple[str, LLMOutput, dict]:
+        """Build the final output tuple.
+
+        Returns:
+            Tuple of (buffer, LLMOutput, state_changes)
+        """
+        return (
+            self.state.buffer,
+            self.output_cls(
+                response=self.state.response.strip(),
+                reasoning=self.state.reasoning.strip() if self.state.reasoning else None,
+                function_calls=self.state.function_calls,
+                tool_calls=self.state.tool_calls
+            ),
+            self.state.state_changes
+        )
+
+    def __call__(self, piece: str, buffer: str) -> tuple[str, LLMOutput, dict]:
+        """Transform a piece of text and return the result.
+
+        Args:
+            piece: New piece of text to transform
+            buffer: Existing buffer content
+
+        Returns:
+            Tuple of (new_buffer, output, state_changes)
+        """
+        self.state.buffer = buffer
+        self.transform_chunk(piece)
+        return self.build_output()
+
+
 def stream_generate(
     model: Any,
     messages: List[Dict[str, Any]],
-    output_cls: type[LLMOutput],
+    transformer: ResponseTransformer,
     temperature: float = 0.7,
     top_p: float = 0.95,
     max_tokens: int = 4096,
     stop: Optional[List[str]] = None,
-    handle_thinking: bool = False,
-    transform_response: Optional[Callable[[str, str], tuple[str, LLMOutput]]] = None,
 ) -> Generator[LLMOutput, None, None]:
     """Stream generate from LLaMA.cpp model with timing and usage tracking.

     Args:
         model: The LLaMA.cpp model instance
         messages: List of messages to send to the model
-        output_cls: Output class type to use for responses
+        transformer: ResponseTransformer instance to use for processing output
         temperature: Sampling temperature
         top_p: Top-p sampling threshold
         max_tokens: Maximum tokens to generate
         stop: Optional list of stop sequences
-        handle_thinking: Whether to handle thinking tags
-        transform_response: Optional function to transform responses, takes (piece, buffer) and returns (new_buffer, output)
     """
     response_queue: Queue[Optional[tuple[str, dict]]] = Queue()
     thread_exception = None
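
To illustrate the intended extension point, here is a hedged sketch of a model-specific transformer built on the base class above; the class name and the "<|assistant|>" token are hypothetical, not part of this change:

class ExampleTransformer(ResponseTransformer):
    def additional_cleaning(self, text: str) -> str:
        # Strip a hypothetical model-specific control token on top of the common cleaning
        return text.replace("<|assistant|>", "")

transformer = ExampleTransformer()
# The default handle_reasoning() splits <think>...</think> out of the buffer:
buffer, output, state_changes = transformer("<think>plan the answer</think>Hello!", "")
# Expected (assuming LLMOutput tolerates the extra function_calls/tool_calls kwargs):
#   output.reasoning == "plan the answer", output.response == "Hello!"
#   state_changes["reasoning_started"] and state_changes["reasoning_ended"] are True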
@@ -233,11 +394,9 @@ def generation_thread():
             )

             for chunk in completion:
-                # Get usage from root level if present
                 if "usage" in chunk and chunk["usage"] is not None:
                     usage_stats.update(chunk["usage"])

-                # Get content from choices
                 delta = chunk.get("choices", [{}])[0]
                 content = None
                 finish_reason = None
@@ -265,15 +424,15 @@ def generation_thread():
             tokens_per_second = (usage_stats["completion_tokens"] / generation_time) if generation_time > 0 else 0
             response_queue.put((None, {
                 "time_to_first_token": timing_stats["time_to_first_token"],
-                "tokens_per_second": tokens_per_second
+                "tokens_per_second": tokens_per_second,
+                "reasoning_time": timing_stats["reasoning_time"],
+                "reasoning_tokens": timing_stats["reasoning_tokens"]
             }))

     thread = Thread(target=generation_thread, daemon=True)
     thread.start()

     buffer = ""
-    thinking_content = "" if handle_thinking else None
-    in_thinking = handle_thinking
     try:
         while True:
             try:
@@ -290,59 +449,18 @@ def generation_thread():
                         tokens_per_second=timing_stats["tokens_per_second"],
                         prompt_tokens=usage_stats["prompt_tokens"],
                         completion_tokens=usage_stats["completion_tokens"],
-                        total_tokens=usage_stats["total_tokens"]
+                        total_tokens=usage_stats["total_tokens"],
+                        reasoning_time=timing_stats["reasoning_time"],
+                        reasoning_tokens=timing_stats["reasoning_tokens"]
                     )

-                    if transform_response:
-                        buffer, output = transform_response(piece or "", buffer)
-                        output.usage = usage
-                        yield output
-                    else:
-                        # Handle thinking vs response content if enabled
-                        if handle_thinking and "</think>" in piece:
-                            parts = piece.split("</think>")
-                            if in_thinking:
-                                thinking_content += parts[0].replace("<think>", "")
-                                buffer = parts[1] if len(parts) > 1 else ""
-                                in_thinking = False
-                            else:
-                                buffer += piece
-                        else:
-                            if in_thinking:
-                                thinking_content += piece.replace("<think>", "")
-                            else:
-                                buffer += piece
-
-                        yield output_cls(
-                            response=buffer.strip(),
-                            thinking_content=thinking_content.strip() if thinking_content else None,
-                            usage=usage
-                        )
-                    break
-
-                if transform_response:
-                    buffer, output = transform_response(piece, buffer)
+                    buffer, output, _ = transformer(piece or "", buffer)
+                    output.usage = usage
                     yield output
-                else:
-                    # Handle thinking vs response content if enabled
-                    if handle_thinking and "</think>" in piece:
-                        parts = piece.split("</think>")
-                        if in_thinking:
-                            thinking_content += parts[0].replace("<think>", "")
-                            buffer = parts[1] if len(parts) > 1 else ""
-                            in_thinking = False
-                        else:
-                            buffer += piece
-                    else:
-                        if in_thinking:
-                            thinking_content += piece.replace("<think>", "")
-                        else:
-                            buffer += piece
+                    break

-                    yield output_cls(
-                        response=buffer.strip(),
-                        thinking_content=thinking_content.strip() if thinking_content else None
-                    )
+                buffer, output, _ = transformer(piece, buffer)
+                yield output

             except Exception as e:
                 if thread_exception and isinstance(e, thread_exception.__class__):
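
A sketch of the updated call site, assuming `model` and `messages` are prepared as elsewhere in this file; parameter values are placeholders:

final = None
for final in stream_generate(
    model=model,
    messages=messages,
    transformer=ResponseTransformer(),
    temperature=0.7,
    max_tokens=1024,
):
    pass  # each yielded LLMOutput carries the cumulative response so far

if final is not None:
    print(final.response)
    if final.reasoning:
        print("reasoning:", final.reasoning)
    if final.usage:
        print("reasoning_time:", final.usage.reasoning_time)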