inferencesh 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of inferencesh might be problematic.

@@ -0,0 +1,729 @@
+ from typing import Optional, List, Any, Callable, Dict, Generator
+ from enum import Enum
+ from pydantic import Field, BaseModel
+ from queue import Queue, Empty
+ from threading import Thread
+ import time
+ from contextlib import contextmanager
+ import base64
+
+ from .base import BaseAppInput, BaseAppOutput
+ from .file import File
+
+ class ContextMessageRole(str, Enum):
+     USER = "user"
+     ASSISTANT = "assistant"
+     SYSTEM = "system"
+
+
+ class Message(BaseAppInput):
+     role: ContextMessageRole
+     content: str
+
+
+ class ContextMessage(BaseAppInput):
+     role: ContextMessageRole = Field(
+         description="the role of the message. user, assistant, or system",
+     )
+     text: str = Field(
+         description="the text content of the message"
+     )
+     image: Optional[File] = Field(
+         description="the image file of the message",
+         default=None
+     )
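
For orientation, a minimal sketch of constructing a context message. The `File(path=...)` call is an assumption: the diff shows `image.path` and `image.uri` being read but not File's constructor.

    msg = ContextMessage(
        role=ContextMessageRole.USER,
        text="What is in this picture?",
        image=File(path="cat.png"),  # hypothetical; File's constructor is not shown in this diff
    )
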
+
+ class BaseLLMInput(BaseAppInput):
+     """Base class with common LLM fields."""
+     system_prompt: str = Field(
+         description="the system prompt to use for the model",
+         default="you are a helpful assistant that can answer questions and help with tasks.",
+         examples=[
+             "you are a helpful assistant that can answer questions and help with tasks.",
+         ]
+     )
+     context: List[ContextMessage] = Field(
+         description="the context to use for the model",
+         default=[],
+         examples=[
+             [
+                 {"role": "user", "text": "What is the capital of France?"},
+                 {"role": "assistant", "text": "The capital of France is Paris."}
+             ]
+         ]
+     )
+     text: str = Field(
+         description="the user prompt to use for the model",
+         examples=[
+             "write a haiku about artificial general intelligence"
+         ]
+     )
+     temperature: float = Field(default=0.7, ge=0.0, le=1.0)
+     top_p: float = Field(default=0.95, ge=0.0, le=1.0)
+     context_size: int = Field(default=4096)
+
+ class ImageCapabilityMixin(BaseModel):
+     """Mixin for models that support image inputs."""
+     image: Optional[File] = Field(
+         description="the image to use for the model",
+         default=None,
+         contentMediaType="image/*",
+     )
+
+ class ReasoningCapabilityMixin(BaseModel):
+     """Mixin for models that support reasoning."""
+     reasoning: bool = Field(
+         description="enable step-by-step reasoning",
+         default=False
+     )
+
+ class ToolsCapabilityMixin(BaseModel):
+     """Mixin for models that support tool/function calling."""
+     tools: Optional[List[Dict[str, Any]]] = Field(
+         description="tool definitions for function calling",
+         default=None
+     )
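
Because the capability mixins are plain pydantic models, inputs for more capable models can be composed by inheritance; an illustrative sketch (the class name is not from the package):

    class VisionLLMInput(ImageCapabilityMixin, BaseLLMInput):
        """BaseLLMInput fields plus an optional image."""
        pass
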
+
+ # Example of how to use:
+ class LLMInput(BaseLLMInput):
+     """Default LLM input model with no special capabilities."""
+     pass
+
+ # For backward compatibility
+ LLMInput.model_config["title"] = "LLMInput"
+
+ class LLMUsage(BaseAppOutput):
+     stop_reason: str = ""
+     time_to_first_token: float = 0.0
+     tokens_per_second: float = 0.0
+     prompt_tokens: int = 0
+     completion_tokens: int = 0
+     total_tokens: int = 0
+     reasoning_tokens: int = 0
+     reasoning_time: float = 0.0
+
+
+ class BaseLLMOutput(BaseAppOutput):
+     """Base class for LLM outputs with common fields."""
+     response: str = Field(description="the generated text response")
+
+ class LLMUsageMixin(BaseModel):
+     """Mixin for models that provide token usage statistics."""
+     usage: Optional[LLMUsage] = Field(
+         description="token usage statistics",
+         default=None
+     )
+
+ class ReasoningMixin(BaseModel):
+     """Mixin for models that support reasoning."""
+     reasoning: Optional[str] = Field(
+         description="the reasoning output of the model",
+         default=None
+     )
+
+ class ToolCallsMixin(BaseModel):
+     """Mixin for models that support tool calls."""
+     tool_calls: Optional[List[Dict[str, Any]]] = Field(
+         description="tool calls for function calling",
+         default=None
+     )
+
+ # Example of how to use:
+ class LLMOutput(LLMUsageMixin, BaseLLMOutput):
+     """Default LLM output model with token usage tracking."""
+     pass
+
+ # For backward compatibility
+ LLMOutput.model_config["title"] = "LLMOutput"
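
The output mixins compose the same way; note that ResponseTransformer.build_output (below) only attaches reasoning and tool_calls when the chosen output class declares those fields. An illustrative composite, not from the package:

    class AgentLLMOutput(ToolCallsMixin, ReasoningMixin, LLMUsageMixin, BaseLLMOutput):
        """Output with usage, reasoning, and tool-call fields."""
        pass
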
+
+ @contextmanager
+ def timing_context():
+     """Context manager to track timing information for LLM generation."""
+     class TimingInfo:
+         def __init__(self):
+             self.start_time = time.time()
+             self.first_token_time = None
+             self.reasoning_start_time = None
+             self.total_reasoning_time = 0.0
+             self.reasoning_tokens = 0
+             self.in_reasoning = False
+
+         def mark_first_token(self):
+             if self.first_token_time is None:
+                 self.first_token_time = time.time()
+
+         def start_reasoning(self):
+             if not self.in_reasoning:
+                 self.reasoning_start_time = time.time()
+                 self.in_reasoning = True
+
+         def end_reasoning(self, token_count: int = 0):
+             if self.in_reasoning and self.reasoning_start_time:
+                 self.total_reasoning_time += time.time() - self.reasoning_start_time
+                 self.reasoning_tokens += token_count
+                 self.reasoning_start_time = None
+                 self.in_reasoning = False
+
+         @property
+         def stats(self):
+             current_time = time.time()
+             if self.first_token_time is None:
+                 return {
+                     "time_to_first_token": 0.0,
+                     "generation_time": 0.0,
+                     "reasoning_time": self.total_reasoning_time,
+                     "reasoning_tokens": self.reasoning_tokens
+                 }
+
+             time_to_first = self.first_token_time - self.start_time
+             generation_time = current_time - self.first_token_time
+
+             return {
+                 "time_to_first_token": time_to_first,
+                 "generation_time": generation_time,
+                 "reasoning_time": self.total_reasoning_time,
+                 "reasoning_tokens": self.reasoning_tokens
+             }
+
+     timing = TimingInfo()
+     try:
+         yield timing
+     finally:
+         pass
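
A usage sketch for the timing context; the call sites mirror how stream_generate and ResponseTransformer drive it below (the token count is illustrative):

    with timing_context() as timing:
        timing.mark_first_token()              # when the first token arrives
        timing.start_reasoning()               # when a <think> block opens
        timing.end_reasoning(token_count=42)   # when it closes
        print(timing.stats)
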
+
+ def image_to_base64_data_uri(file_path):
+     # NOTE: the MIME type is hardcoded, so non-PNG images are labeled image/png.
+     with open(file_path, "rb") as img_file:
+         base64_data = base64.b64encode(img_file.read()).decode('utf-8')
+         return f"data:image/png;base64,{base64_data}"
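
If exact MIME types matter, a guessed-type variant is straightforward; a sketch reusing the module's base64 import, not part of the package:

    import mimetypes

    def image_to_data_uri(file_path: str) -> str:
        mime, _ = mimetypes.guess_type(file_path)
        with open(file_path, "rb") as f:
            data = base64.b64encode(f.read()).decode("utf-8")
        return f"data:{mime or 'application/octet-stream'};base64,{data}"
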
+
+ def build_messages(
+     input_data: LLMInput,
+     transform_user_message: Optional[Callable[[str], str]] = None
+ ) -> List[Dict[str, Any]]:
+     """Build messages for LLaMA.cpp chat completion.
+
+     If any message includes image content, builds OpenAI-style multipart format.
+     Otherwise, uses plain string-only format.
+     """
+     def render_message(msg: ContextMessage, allow_multipart: bool) -> str | List[dict]:
+         parts = []
+         text = transform_user_message(msg.text) if transform_user_message and msg.role == ContextMessageRole.USER else msg.text
+         if text:
+             parts.append({"type": "text", "text": text})
+         if msg.image:
+             if msg.image.path:
+                 image_data_uri = image_to_base64_data_uri(msg.image.path)
+                 parts.append({"type": "image_url", "image_url": {"url": image_data_uri}})
+             elif msg.image.uri:
+                 parts.append({"type": "image_url", "image_url": {"url": msg.image.uri}})
+         if allow_multipart:
+             return parts
+         if len(parts) == 1 and parts[0]["type"] == "text":
+             return parts[0]["text"]
+         raise ValueError("Image content requires multipart support")
+
+     messages = [{"role": "system", "content": input_data.system_prompt}] if input_data.system_prompt is not None and input_data.system_prompt != "" else []
+
+     def merge_messages(messages: List[ContextMessage]) -> ContextMessage:
+         text = "\n\n".join(msg.text for msg in messages if msg.text)
+         images = [msg.image for msg in messages if msg.image]
+         image = images[0] if images else None  # TODO: handle multiple images
+         return ContextMessage(role=messages[0].role, text=text, image=image)
+
+     user_input_text = ""
+     if hasattr(input_data, "text"):
+         # Use the raw text here; render_message already applies
+         # transform_user_message to user-role messages, so transforming
+         # at this point would apply it twice.
+         user_input_text = input_data.text
+
+     user_input_image = None
+     multipart = any(m.image for m in input_data.context)
+     if hasattr(input_data, "image"):
+         user_input_image = input_data.image
+         multipart = multipart or input_data.image is not None
+
+     user_msg = ContextMessage(role=ContextMessageRole.USER, text=user_input_text, image=user_input_image)
+
+     # Work on a copy so repeated calls don't keep appending the user
+     # message to the caller's context.
+     context = [*input_data.context, user_msg]
+
+     current_role = None
+     current_messages = []
+
+     for msg in context:
+         if msg.role == current_role or current_role is None:
+             current_messages.append(msg)
+             current_role = msg.role
+         else:
+             messages.append({
+                 "role": current_role,
+                 "content": render_message(merge_messages(current_messages), allow_multipart=multipart)
+             })
+             current_messages = [msg]
+             current_role = msg.role
+     if len(current_messages) > 0:
+         messages.append({
+             "role": current_role,
+             "content": render_message(merge_messages(current_messages), allow_multipart=multipart)
+         })
+
+     return messages
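
For a text-only input the rendered content stays a plain string; a sketch of the expected shape (default system prompt elided):

    messages = build_messages(LLMInput(text="hi"))
    # [{"role": "system", "content": "you are a helpful assistant ..."},
    #  {"role": "user", "content": "hi"}]
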
+
+
+ class StreamResponse:
+     """Holds a single chunk of streamed response."""
+     def __init__(self):
+         self.content = ""
+         self.tool_calls = None  # Changed from [] to None
+         self.finish_reason = None
+         self.timing_stats = {
+             "time_to_first_token": None,  # Changed from 0.0 to None
+             "generation_time": 0.0,
+             "reasoning_time": 0.0,
+             "reasoning_tokens": 0,
+             "tokens_per_second": 0.0
+         }
+         self.usage_stats = {
+             "prompt_tokens": 0,
+             "completion_tokens": 0,
+             "total_tokens": 0,
+             "stop_reason": ""
+         }
+
+     def update_from_chunk(self, chunk: Dict[str, Any], timing: Any) -> None:
+         """Update response state from a chunk."""
+         # Update usage stats if present
+         if "usage" in chunk:
+             usage = chunk["usage"]
+             if usage is not None:
+                 # Update usage stats preserving existing values if not provided
+                 self.usage_stats.update({
+                     "prompt_tokens": usage.get("prompt_tokens", self.usage_stats["prompt_tokens"]),
+                     "completion_tokens": usage.get("completion_tokens", self.usage_stats["completion_tokens"]),
+                     "total_tokens": usage.get("total_tokens", self.usage_stats["total_tokens"])
+                 })
+
+         # Get the delta from the chunk; usage-only chunks can carry an
+         # empty choices list, so guard before indexing
+         choices = chunk.get("choices") or [{}]
+         delta = choices[0]
+
+         # Extract content and tool calls from either message or delta
+         if "message" in delta:
+             message = delta["message"]
+             self.content = message.get("content", "")
+             if message.get("tool_calls"):
+                 self._update_tool_calls(message["tool_calls"])
+             self.finish_reason = delta.get("finish_reason")
+             if self.finish_reason:
+                 self.usage_stats["stop_reason"] = self.finish_reason
+         elif "delta" in delta:
+             delta_content = delta["delta"]
+             self.content = delta_content.get("content", "")
+             if delta_content.get("tool_calls"):
+                 self._update_tool_calls(delta_content["tool_calls"])
+             self.finish_reason = delta.get("finish_reason")
+             if self.finish_reason:
+                 self.usage_stats["stop_reason"] = self.finish_reason
+
+         # Update timing stats
+         timing_stats = timing.stats
+         if self.timing_stats["time_to_first_token"] is None:
+             self.timing_stats["time_to_first_token"] = timing_stats["time_to_first_token"]
+
+         self.timing_stats.update({
+             "generation_time": timing_stats["generation_time"],
+             "reasoning_time": timing_stats["reasoning_time"],
+             "reasoning_tokens": timing_stats["reasoning_tokens"]
+         })
+
+         # Calculate tokens per second only if we have valid completion tokens and generation time
+         if self.usage_stats["completion_tokens"] > 0 and timing_stats["generation_time"] > 0:
+             self.timing_stats["tokens_per_second"] = (
+                 self.usage_stats["completion_tokens"] / timing_stats["generation_time"]
+             )
+
+     def _update_tool_calls(self, new_tool_calls: List[Dict[str, Any]]) -> None:
+         """Update tool calls, handling both full and partial updates."""
+         if self.tool_calls is None:
+             self.tool_calls = []
+
+         for tool_delta in new_tool_calls:
+             tool_id = tool_delta.get("id")
+             if not tool_id:
+                 continue
+
+             # Find or create tool call
+             current_tool = next((t for t in self.tool_calls if t["id"] == tool_id), None)
+             if not current_tool:
+                 current_tool = {
+                     "id": tool_id,
+                     "type": tool_delta.get("type", "function"),
+                     "function": {"name": "", "arguments": ""}
+                 }
+                 self.tool_calls.append(current_tool)
+
+             # Update tool call
+             if "function" in tool_delta:
+                 func_delta = tool_delta["function"]
+                 if "name" in func_delta:
+                     current_tool["function"]["name"] = func_delta["name"]
+                 if "arguments" in func_delta:
+                     current_tool["function"]["arguments"] += func_delta["arguments"]
+
+     def has_updates(self) -> bool:
+         """Check if this response has any content, tool call, or usage updates."""
+         has_content = bool(self.content)
+         has_tool_calls = bool(self.tool_calls)
+         has_usage = self.usage_stats["prompt_tokens"] > 0 or self.usage_stats["completion_tokens"] > 0
+         has_finish = bool(self.finish_reason)
+
+         return has_content or has_tool_calls or has_usage or has_finish
+
+     def to_output(self, buffer: str, transformer: Any) -> tuple[BaseLLMOutput, str]:
+         """Convert current state to an output object via the transformer."""
+         # Create usage object if we have stats
+         usage = None
+         if any(self.usage_stats.values()):
+             usage = LLMUsage(
+                 stop_reason=self.usage_stats["stop_reason"],
+                 time_to_first_token=self.timing_stats["time_to_first_token"] or 0.0,
+                 tokens_per_second=self.timing_stats["tokens_per_second"],
+                 prompt_tokens=self.usage_stats["prompt_tokens"],
+                 completion_tokens=self.usage_stats["completion_tokens"],
+                 total_tokens=self.usage_stats["total_tokens"],
+                 reasoning_time=self.timing_stats["reasoning_time"],
+                 reasoning_tokens=self.timing_stats["reasoning_tokens"]
+             )
+
+         buffer, output, _ = transformer(self.content, buffer, usage)
+
+         # Add tool calls if present and supported
+         if self.tool_calls and hasattr(output, 'tool_calls'):
+             output.tool_calls = self.tool_calls
+
+         return output, buffer
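
The chunks consumed here follow the OpenAI-style streaming shape; an illustrative update, given a timing object from timing_context():

    with timing_context() as timing:
        response = StreamResponse()
        response.update_from_chunk(
            {"choices": [{"delta": {"content": "Hello"}, "finish_reason": None}]},
            timing,
        )
        assert response.has_updates()
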
+
+ class ResponseState:
+     """Holds the state of response transformation."""
+     def __init__(self):
+         self.buffer = ""
+         self.response = ""
+         self.reasoning = None
+         self.function_calls = None  # For future function calling support
+         self.tool_calls = None  # List to accumulate tool calls
+         self.current_tool_call = None  # Track current tool call being built
+         self.usage = None  # Usage statistics
+         self.state_changes = {
+             "reasoning_started": False,
+             "reasoning_ended": False,
+             "function_call_started": False,
+             "function_call_ended": False,
+             "tool_call_started": False,
+             "tool_call_ended": False
+         }
+
+ class ResponseTransformer:
+     """Base class for transforming model responses."""
+     def __init__(self, output_cls: type[BaseLLMOutput] = LLMOutput):
+         self.state = ResponseState()
+         self.output_cls = output_cls
+         self.timing = None  # Will be set by stream_generate
+
+     def clean_text(self, text: str) -> str:
+         """Clean common tokens from the text and apply model-specific cleaning.
+
+         Args:
+             text: Raw text to clean
+
+         Returns:
+             Cleaned text with common and model-specific tokens removed
+         """
+         if text is None:
+             return ""
+
+         # Common token cleaning across most models
+         cleaned = (text.replace("<|im_end|>", "")
+                    .replace("<|im_start|>", "")
+                    .replace("<start_of_turn>", "")
+                    .replace("<end_of_turn>", "")
+                    .replace("<eos>", ""))
+         return self.additional_cleaning(cleaned)
+
+     def additional_cleaning(self, text: str) -> str:
+         """Apply model-specific token cleaning.
+
+         Args:
+             text: Text that has had common tokens removed
+
+         Returns:
+             Text with model-specific tokens removed
+         """
+         return text
+
+     def handle_reasoning(self, text: str) -> None:
+         """Handle reasoning/thinking detection and extraction.
+
+         Args:
+             text: Cleaned text to process for reasoning
+         """
+         # Default implementation for <think> style reasoning
+         if "<think>" in text and not self.state.state_changes["reasoning_started"]:
+             self.state.state_changes["reasoning_started"] = True
+             if self.timing:
+                 self.timing.start_reasoning()
+
+         if "</think>" in text and not self.state.state_changes["reasoning_ended"]:
+             self.state.state_changes["reasoning_ended"] = True
+             if self.timing:
+                 # Estimate token count from character count (rough approximation);
+                 # guard against a closing tag that was never opened
+                 token_count = 0
+                 if "<think>" in self.state.buffer:
+                     token_count = len(self.state.buffer.split("<think>")[1].split("</think>")[0]) // 4
+                 self.timing.end_reasoning(token_count)
+
+         if "<think>" in self.state.buffer:
+             parts = self.state.buffer.split("</think>", 1)
+             if len(parts) > 1:
+                 self.state.reasoning = parts[0].split("<think>", 1)[1].strip()
+                 self.state.response = parts[1].strip()
+             else:
+                 self.state.reasoning = self.state.buffer.split("<think>", 1)[1].strip()
+                 self.state.response = ""
+         else:
+             self.state.response = self.state.buffer
+
+     def handle_function_calls(self, text: str) -> None:
+         """Handle function call detection and extraction.
+
+         Args:
+             text: Cleaned text to process for function calls
+         """
+         # Default no-op implementation
+         # Models can override this to implement function call handling
+         pass
+
+     def handle_tool_calls(self, text: str) -> None:
+         """Handle tool call detection and extraction.
+
+         Args:
+             text: Cleaned text to process for tool calls
+         """
+         # Default no-op implementation
+         # Models can override this to implement tool call handling
+         pass
+
+     def transform_chunk(self, chunk: str) -> None:
+         """Transform a single chunk of model output.
+
+         This method orchestrates the transformation process by:
+         1. Cleaning the text
+         2. Updating the buffer
+         3. Processing various capabilities (reasoning, function calls, etc)
+
+         Args:
+             chunk: Raw text chunk from the model
+         """
+         cleaned = self.clean_text(chunk)
+         self.state.buffer += cleaned
+
+         # Process different capabilities
+         self.handle_reasoning(cleaned)
+         self.handle_function_calls(cleaned)
+         self.handle_tool_calls(cleaned)
+
+     def build_output(self) -> tuple[str, BaseLLMOutput, dict]:
+         """Build the final output tuple.
+
+         Returns:
+             Tuple of (buffer, output, state_changes)
+         """
+         # Build base output with required fields
+         output_data = {
+             "response": self.state.response.strip(),
+         }
+
+         # Add optional fields if they exist
+         if self.state.usage is not None:
+             output_data["usage"] = self.state.usage
+         if self.state.reasoning:
+             output_data["reasoning"] = self.state.reasoning.strip()
+         if self.state.function_calls:
+             output_data["function_calls"] = self.state.function_calls
+         if self.state.tool_calls:
+             output_data["tool_calls"] = self.state.tool_calls
+
+         output = self.output_cls(**output_data)
+
+         return (
+             self.state.buffer,
+             output,
+             self.state.state_changes
+         )
+
+     def __call__(self, piece: str, buffer: str, usage: Optional[LLMUsage] = None) -> tuple[str, BaseLLMOutput, dict]:
+         """Transform a piece of text and return the result.
+
+         Args:
+             piece: New piece of text to transform
+             buffer: Existing buffer content
+             usage: Optional usage statistics
+
+         Returns:
+             Tuple of (new_buffer, output, state_changes)
+         """
+         self.state.buffer = buffer
+         if usage is not None:
+             self.state.usage = usage
+         self.transform_chunk(piece)
+         return self.build_output()
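
Model-specific transformers are expected to override the hooks above; an illustrative subclass (the class name and tokens are hypothetical):

    class MyModelTransformer(ResponseTransformer):
        def additional_cleaning(self, text: str) -> str:
            # Strip tokens specific to this hypothetical model
            return text.replace("<|assistant|>", "").replace("<|end|>", "")
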
+
+
+ def stream_generate(
+     model: Any,
+     messages: List[Dict[str, Any]],
+     transformer: Optional[ResponseTransformer] = None,
+     tools: Optional[List[Dict[str, Any]]] = None,
+     tool_choice: Optional[Dict[str, Any]] = None,
+     temperature: float = 0.7,
+     top_p: float = 0.95,
+     stop: Optional[List[str]] = None,
+     verbose: bool = False,
+     output_cls: type[BaseLLMOutput] = LLMOutput,
+ ) -> Generator[BaseLLMOutput, None, None]:
+     """Stream generate from LLaMA.cpp model with timing and usage tracking."""
+
+     # Create a fresh transformer per call; a shared default instance would
+     # leak buffer and reasoning state between generations.
+     if transformer is None:
+         transformer = ResponseTransformer()
+
+     # Create queues for communication between threads
+     response_queue = Queue()
+     error_queue = Queue()
+     keep_alive_queue = Queue()
+
+     # Set the output class for the transformer
+     transformer.output_cls = output_cls
+
+     def _generate_worker():
+         """Worker thread to run the model generation."""
+         try:
+             # Build completion kwargs
+             completion_kwargs = {
+                 "messages": messages,
+                 "stream": True,
+                 "temperature": temperature,
+                 "top_p": top_p,
+                 "stop": stop
+             }
+             if tools is not None:
+                 completion_kwargs["tools"] = tools
+             if tool_choice is not None:
+                 completion_kwargs["tool_choice"] = tool_choice
+
+             # Signal that we're starting
+             keep_alive_queue.put(("init", time.time()))
+
+             completion = model.create_chat_completion(**completion_kwargs)
+
+             for chunk in completion:
+                 response_queue.put(("chunk", chunk))
+                 # Update keep-alive timestamp
+                 keep_alive_queue.put(("alive", time.time()))
+
+             # Signal completion
+             response_queue.put(("done", None))
+
+         except Exception as e:
+             # Preserve the full exception with traceback
+             import sys
+             error_queue.put((e, sys.exc_info()[2]))
+             response_queue.put(("error", str(e)))
+
+     with timing_context() as timing:
+         transformer.timing = timing
+
+         # Start generation thread
+         generation_thread = Thread(target=_generate_worker, daemon=True)
+         generation_thread.start()
+
+         # Initialize response state
+         response = StreamResponse()
+         buffer = ""
+
+         # Keep-alive tracking
+         last_activity = time.time()
+         init_timeout = 30.0  # 30 seconds for initial response
+         chunk_timeout = 10.0  # 10 seconds between chunks
+
+         try:
+             # Wait for initial setup
+             try:
+                 msg_type, timestamp = keep_alive_queue.get(timeout=init_timeout)
+                 if msg_type != "init":
+                     raise RuntimeError("Unexpected initialization message")
+                 last_activity = timestamp
+             except Empty:
+                 raise RuntimeError(f"Model failed to initialize within {init_timeout} seconds")
+
+             while True:
+                 # Check for errors - now with proper exception chaining
+                 if not error_queue.empty():
+                     exc, tb = error_queue.get()
+                     if isinstance(exc, Exception):
+                         raise exc.with_traceback(tb)
+                     else:
+                         raise RuntimeError(f"Unknown error in worker thread: {exc}")
+
+                 # Check keep-alive
+                 try:
+                     while not keep_alive_queue.empty():
+                         _, timestamp = keep_alive_queue.get_nowait()
+                         last_activity = timestamp
+                 except Empty:
+                     # Ignore empty queue - this is expected
+                     pass
+
+                 # Check for timeout
+                 if time.time() - last_activity > chunk_timeout:
+                     raise RuntimeError(f"No response from model for {chunk_timeout} seconds")
+
+                 # Get next chunk
+                 try:
+                     msg_type, data = response_queue.get(timeout=0.1)
+                 except Empty:
+                     continue
+
+                 if msg_type == "error":
+                     # If we get an error message but no exception in error_queue,
+                     # create a new error
+                     raise RuntimeError(f"Generation error: {data}")
+                 elif msg_type == "done":
+                     break
+
+                 chunk = data
+
+                 if verbose:
+                     print(chunk)
+
+                 # Mark first token time
+                 if not timing.first_token_time:
+                     timing.mark_first_token()
+
+                 # Update response state from chunk
+                 response.update_from_chunk(chunk, timing)
+
+                 # Yield output if we have updates
+                 if response.has_updates():
+                     output, buffer = response.to_output(buffer, transformer)
+                     yield output
+
+                 # Break if we're done
+                 if response.finish_reason:
+                     break
+
+             # Wait for generation thread to finish
+             if generation_thread.is_alive():
+                 generation_thread.join(timeout=5.0)  # Increased timeout to 5 seconds
+                 if generation_thread.is_alive():
+                     # Thread didn't finish - this shouldn't happen normally
+                     raise RuntimeError("Generation thread failed to finish")
+
+         except Exception as e:
+             # Check if there's a thread error we should chain with
+             if not error_queue.empty():
+                 thread_exc, thread_tb = error_queue.get()
+                 if isinstance(thread_exc, Exception):
+                     raise e from thread_exc
+             # If no thread error, raise the original exception
+             raise
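
End to end, the module targets a llama-cpp-python-style create_chat_completion API (the model parameter is typed Any, so this pairing is inferred from the docstrings); a hedged usage sketch with illustrative paths and parameters:

    from llama_cpp import Llama

    model = Llama(model_path="model.gguf", n_ctx=4096)
    input_data = LLMInput(text="write a haiku about artificial general intelligence")

    final = None
    for final in stream_generate(model, build_messages(input_data)):
        pass  # each yielded output carries the cumulative response so far
    if final is not None:
        print(final.response)
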