Commit e478cd0

llm threading

1 parent d7d6ee8

File tree

1 file changed: +93 -19 lines changed


src/inferencesh/models/llm.py

Lines changed: 93 additions & 19 deletions
@@ -9,7 +9,6 @@
 
 from .base import BaseAppInput, BaseAppOutput
 from .file import File
-from .types import ContextMessage
 
 class ContextMessageRole(str, Enum):
     USER = "user"
@@ -535,34 +534,102 @@ def stream_generate(
     verbose: bool = False,
 ) -> Generator[LLMOutput, None, None]:
     """Stream generate from LLaMA.cpp model with timing and usage tracking."""
+
+    # Create queues for communication between threads
+    response_queue = Queue()
+    error_queue = Queue()
+    keep_alive_queue = Queue()
+
+    def _generate_worker():
+        """Worker thread to run the model generation."""
+        try:
+            # Build completion kwargs
+            completion_kwargs = {
+                "messages": messages,
+                "stream": True,
+                "temperature": temperature,
+                "top_p": top_p,
+                "max_tokens": max_tokens,
+                "stop": stop
+            }
+            if tools is not None:
+                completion_kwargs["tools"] = tools
+            if tool_choice is not None:
+                completion_kwargs["tool_choice"] = tool_choice
+
+            # Signal that we're starting
+            keep_alive_queue.put(("init", time.time()))
+
+            completion = model.create_chat_completion(**completion_kwargs)
+
+            for chunk in completion:
+                if verbose:
+                    print(chunk)
+                response_queue.put(("chunk", chunk))
+                # Update keep-alive timestamp
+                keep_alive_queue.put(("alive", time.time()))
+
+            # Signal completion
+            response_queue.put(("done", None))
+
+        except Exception as e:
+            error_queue.put(e)
+            response_queue.put(("error", str(e)))
+
     with timing_context() as timing:
         transformer.timing = timing
 
-        # Build completion kwargs
-        completion_kwargs = {
-            "messages": messages,
-            "stream": True,
-            "temperature": temperature,
-            "top_p": top_p,
-            "max_tokens": max_tokens,
-            "stop": stop
-        }
-        if tools is not None:
-            completion_kwargs["tools"] = tools
-        if tool_choice is not None:
-            completion_kwargs["tool_choice"] = tool_choice
+        # Start generation thread
+        generation_thread = Thread(target=_generate_worker, daemon=True)
+        generation_thread.start()
 
         # Initialize response state
         response = StreamResponse()
         buffer = ""
 
+        # Keep-alive tracking
+        last_activity = time.time()
+        init_timeout = 30.0  # 30 seconds for initial response
+        chunk_timeout = 10.0  # 10 seconds between chunks
+
         try:
-            completion = model.create_chat_completion(**completion_kwargs)
+            # Wait for initial setup
+            try:
+                msg_type, timestamp = keep_alive_queue.get(timeout=init_timeout)
+                if msg_type != "init":
+                    raise RuntimeError("Unexpected initialization message")
+                last_activity = timestamp
+            except Empty:  # queue.Empty; Queue instances have no .Empty attribute
+                raise RuntimeError(f"Model failed to initialize within {init_timeout} seconds")
 
-            for chunk in completion:
-                if verbose:
-                    print(chunk)
-                # Mark first token time as soon as we get any response
+            while True:
+                # Check for errors
+                if not error_queue.empty():
+                    raise error_queue.get()
+
+                # Check keep-alive
+                while not keep_alive_queue.empty():
+                    _, timestamp = keep_alive_queue.get_nowait()
+                    last_activity = timestamp
+
+                # Check for timeout
+                if time.time() - last_activity > chunk_timeout:
+                    raise RuntimeError(f"No response from model for {chunk_timeout} seconds")
+
+                # Get next chunk
+                try:
+                    msg_type, data = response_queue.get(timeout=0.1)
+                except Empty:  # queue.Empty
+                    continue
+
+                if msg_type == "error":
+                    raise RuntimeError(f"Generation error: {data}")
+                elif msg_type == "done":
+                    break
+
+                chunk = data
+
+                # Mark first token time
                 if not timing.first_token_time:
                     timing.mark_first_token()

@@ -577,6 +644,13 @@ def stream_generate(
                 # Break if we're done
                 if response.finish_reason:
                     break
+
+            # Wait for generation thread to finish
+            generation_thread.join(timeout=5.0)  # Increased timeout to 5 seconds
+            if generation_thread.is_alive():
+                # Thread didn't finish - this shouldn't happen normally,
+                # so raise instead of hanging on a stuck worker
+                raise RuntimeError("Generation thread failed to finish")
 
         except Exception as e:
             # Ensure any error is properly propagated
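
For reference, below is a minimal, self-contained sketch of the producer/consumer pattern this diff adds: a daemon worker streams chunks into a response queue while posting timestamps to a keep-alive queue, and the consumer raises if the worker goes silent. fake_stream and the timeout values are illustrative stand-ins, not part of the actual model code.

import time
from queue import Queue, Empty
from threading import Thread


def fake_stream():
    """Stand-in for a streaming completion such as model.create_chat_completion(stream=True)."""
    for token in ["Hello", ", ", "world", "!"]:
        time.sleep(0.05)  # simulate generation latency
        yield token


def consume(init_timeout=5.0, chunk_timeout=2.0):
    response_queue, keep_alive_queue = Queue(), Queue()

    def worker():
        try:
            keep_alive_queue.put(("init", time.time()))  # signal startup
            for chunk in fake_stream():
                response_queue.put(("chunk", chunk))
                keep_alive_queue.put(("alive", time.time()))
            response_queue.put(("done", None))
        except Exception as e:
            response_queue.put(("error", str(e)))

    Thread(target=worker, daemon=True).start()

    # Block until the worker signals that generation has started.
    try:
        _, last_activity = keep_alive_queue.get(timeout=init_timeout)
    except Empty:
        raise RuntimeError(f"worker did not start within {init_timeout}s")

    while True:
        # Drain keep-alive timestamps so a slow-but-alive worker is not killed.
        while not keep_alive_queue.empty():
            _, last_activity = keep_alive_queue.get_nowait()
        if time.time() - last_activity > chunk_timeout:
            raise RuntimeError(f"worker silent for more than {chunk_timeout}s")
        try:
            msg_type, data = response_queue.get(timeout=0.1)
        except Empty:
            continue  # no chunk yet; loop back and re-check liveness
        if msg_type == "done":
            return
        if msg_type == "error":
            raise RuntimeError(data)
        print(data, end="", flush=True)


if __name__ == "__main__":
    consume()  # prints "Hello, world!"

Keeping the keep-alive channel separate from the data channel is what lets the consumer distinguish a crashed worker from a merely slow one, which is the failure mode the chunk_timeout guards against.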
