Skip to content

Commit 3ce099a

Browse files
committed
fix stats
1 parent 8320474 commit 3ce099a

File tree

1 file changed

+32
-12
lines changed

1 file changed

+32
-12
lines changed

src/inferencesh/models/llm.py

Lines changed: 32 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -116,7 +116,7 @@ def timing_context():
116116
class TimingInfo:
117117
def __init__(self):
118118
self.start_time = time.time()
119-
self.first_token_time = 0
119+
self.first_token_time = None
120120
self.reasoning_start_time = None
121121
self.total_reasoning_time = 0.0
122122
self.reasoning_tokens = 0
@@ -140,12 +140,17 @@ def end_reasoning(self, token_count: int = 0):
140140

141141
@property
142142
def stats(self):
143-
end_time = time.time()
143+
current_time = time.time()
144144
if self.first_token_time is None:
145-
self.first_token_time = end_time
145+
return {
146+
"time_to_first_token": 0.0,
147+
"generation_time": 0.0,
148+
"reasoning_time": self.total_reasoning_time,
149+
"reasoning_tokens": self.reasoning_tokens
150+
}
146151

147152
time_to_first = self.first_token_time - self.start_time
148-
generation_time = end_time - self.first_token_time
153+
generation_time = current_time - self.first_token_time
149154

150155
return {
151156
"time_to_first_token": time_to_first,
@@ -216,7 +221,7 @@ def __init__(self):
216221
self.tool_calls = None # Changed from [] to None
217222
self.finish_reason = None
218223
self.timing_stats = {
219-
"time_to_first_token": 0.0,
224+
"time_to_first_token": None, # Changed from 0.0 to None
220225
"generation_time": 0.0,
221226
"reasoning_time": 0.0,
222227
"reasoning_tokens": 0,
@@ -233,7 +238,12 @@ def update_from_chunk(self, chunk: Dict[str, Any], timing: Any) -> None:
233238
"""Update response state from a chunk."""
234239
# Update usage stats if present
235240
if "usage" in chunk and chunk["usage"] is not None:
236-
self.usage_stats.update(chunk["usage"])
241+
usage = chunk["usage"]
242+
self.usage_stats.update({
243+
"prompt_tokens": usage.get("prompt_tokens", self.usage_stats["prompt_tokens"]),
244+
"completion_tokens": usage.get("completion_tokens", self.usage_stats["completion_tokens"]),
245+
"total_tokens": usage.get("total_tokens", self.usage_stats["total_tokens"])
246+
})
237247

238248
# Get the delta from the chunk
239249
delta = chunk.get("choices", [{}])[0]
@@ -245,23 +255,33 @@ def update_from_chunk(self, chunk: Dict[str, Any], timing: Any) -> None:
245255
if message.get("tool_calls"):
246256
self._update_tool_calls(message["tool_calls"])
247257
self.finish_reason = delta.get("finish_reason")
258+
if self.finish_reason:
259+
self.usage_stats["stop_reason"] = self.finish_reason
248260
elif "delta" in delta:
249261
delta_content = delta["delta"]
250262
self.content = delta_content.get("content", "")
251263
if delta_content.get("tool_calls"):
252264
self._update_tool_calls(delta_content["tool_calls"])
253265
self.finish_reason = delta.get("finish_reason")
266+
if self.finish_reason:
267+
self.usage_stats["stop_reason"] = self.finish_reason
254268

255-
# Update timing stats while preserving tokens_per_second
269+
# Update timing stats
256270
timing_stats = timing.stats
257-
generation_time = timing_stats["generation_time"]
258-
completion_tokens = self.usage_stats.get("completion_tokens", 0)
259-
tokens_per_second = (completion_tokens / generation_time) if generation_time > 0 and completion_tokens > 0 else 0.0
271+
if self.timing_stats["time_to_first_token"] is None:
272+
self.timing_stats["time_to_first_token"] = timing_stats["time_to_first_token"]
260273

261274
self.timing_stats.update({
262-
**timing_stats,
263-
"tokens_per_second": tokens_per_second
275+
"generation_time": timing_stats["generation_time"],
276+
"reasoning_time": timing_stats["reasoning_time"],
277+
"reasoning_tokens": timing_stats["reasoning_tokens"]
264278
})
279+
280+
# Calculate tokens per second only if we have valid completion tokens and generation time
281+
if self.usage_stats["completion_tokens"] > 0 and timing_stats["generation_time"] > 0:
282+
self.timing_stats["tokens_per_second"] = (
283+
self.usage_stats["completion_tokens"] / timing_stats["generation_time"]
284+
)
265285

266286
def _update_tool_calls(self, new_tool_calls: List[Dict[str, Any]]) -> None:
267287
"""Update tool calls, handling both full and partial updates."""

0 commit comments

Comments (0)