99
1010from .base import BaseAppInput , BaseAppOutput
1111from .file import File
12- from .types import ContextMessage
1312
1413class ContextMessageRole (str , Enum ):
1514 USER = "user"
@@ -535,34 +534,102 @@ def stream_generate(
535534 verbose : bool = False ,
536535) -> Generator [LLMOutput , None , None ]:
537536 """Stream generate from LLaMA.cpp model with timing and usage tracking."""
537+
538+ # Create queues for communication between threads
539+ response_queue = Queue ()
540+ error_queue = Queue ()
541+ keep_alive_queue = Queue ()
542+
def _generate_worker():
    """Run the streaming chat completion and relay its output via the queues.

    Used as the target of the generation thread. Message protocol:

    * keep_alive_queue receives ("init", timestamp) once, before the
      completion is requested, and ("alive", timestamp) after every chunk.
    * response_queue receives ("chunk", chunk) for each streamed chunk,
      then ("done", None) when the stream is exhausted.
    * On any exception, the exception object is put on error_queue and
      ("error", str(exception)) is put on response_queue.
    """
    try:
        # Assemble the request; the optional tool settings are only
        # forwarded when the caller actually supplied them.
        request = dict(
            messages=messages,
            stream=True,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=stop,
        )
        for key, value in (("tools", tools), ("tool_choice", tool_choice)):
            if value is not None:
                request[key] = value

        # Tell the consumer the worker is up before asking the model.
        keep_alive_queue.put(("init", time.time()))

        for piece in model.create_chat_completion(**request):
            if verbose:
                print(piece)
            response_queue.put(("chunk", piece))
            # Refresh the keep-alive timestamp after each chunk.
            keep_alive_queue.put(("alive", time.time()))

        # The stream ended normally.
        response_queue.put(("done", None))

    except Exception as exc:
        # Publish the failure on both channels: the original exception
        # object on error_queue, its text on response_queue.
        error_queue.put(exc)
        response_queue.put(("error", str(exc)))
578+
538579 with timing_context () as timing :
539580 transformer .timing = timing
540581
541- # Build completion kwargs
542- completion_kwargs = {
543- "messages" : messages ,
544- "stream" : True ,
545- "temperature" : temperature ,
546- "top_p" : top_p ,
547- "max_tokens" : max_tokens ,
548- "stop" : stop
549- }
550- if tools is not None :
551- completion_kwargs ["tools" ] = tools
552- if tool_choice is not None :
553- completion_kwargs ["tool_choice" ] = tool_choice
582+ # Start generation thread
583+ generation_thread = Thread (target = _generate_worker , daemon = True )
584+ generation_thread .start ()
554585
555586 # Initialize response state
556587 response = StreamResponse ()
557588 buffer = ""
558589
590+ # Keep-alive tracking
591+ last_activity = time .time ()
592+ init_timeout = 30.0 # 30 seconds for initial response
593+ chunk_timeout = 10.0 # 10 seconds between chunks
594+
559595 try :
560- completion = model .create_chat_completion (** completion_kwargs )
596+ # Wait for initial setup
597+ try :
598+ msg_type , timestamp = keep_alive_queue .get (timeout = init_timeout )
599+ if msg_type != "init" :
600+ raise RuntimeError ("Unexpected initialization message" )
601+ last_activity = timestamp
602+ except Queue .Empty :
603+ raise RuntimeError (f"Model failed to initialize within { init_timeout } seconds" )
561604
562- for chunk in completion :
563- if verbose :
564- print (chunk )
565- # Mark first token time as soon as we get any response
605+ while True :
606+ # Check for errors
607+ if not error_queue .empty ():
608+ raise error_queue .get ()
609+
610+ # Check keep-alive
611+ while not keep_alive_queue .empty ():
612+ _ , timestamp = keep_alive_queue .get_nowait ()
613+ last_activity = timestamp
614+
615+ # Check for timeout
616+ if time .time () - last_activity > chunk_timeout :
617+ raise RuntimeError (f"No response from model for { chunk_timeout } seconds" )
618+
619+ # Get next chunk
620+ try :
621+ msg_type , data = response_queue .get (timeout = 0.1 )
622+ except Queue .Empty :
623+ continue
624+
625+ if msg_type == "error" :
626+ raise RuntimeError (f"Generation error: { data } " )
627+ elif msg_type == "done" :
628+ break
629+
630+ chunk = data
631+
632+ # Mark first token time
566633 if not timing .first_token_time :
567634 timing .mark_first_token ()
568635
@@ -577,6 +644,13 @@ def stream_generate(
577644 # Break if we're done
578645 if response .finish_reason :
579646 break
647+
648+ # Wait for generation thread to finish
649+ generation_thread .join (timeout = 5.0 ) # Increased timeout to 5 seconds
650+ if generation_thread .is_alive ():
651+ # Thread didn't finish - this shouldn't happen normally
652+ # but we handle it gracefully
653+ raise RuntimeError ("Generation thread failed to finish" )
580654
581655 except Exception as e :
582656 # Ensure any error is properly propagated
0 commit comments