inferencesh 0.2.23__py3-none-any.whl → 0.4.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
inferencesh/models/llm.py CHANGED
@@ -1,96 +1,124 @@
 from typing import Optional, List, Any, Callable, Dict, Generator
 from enum import Enum
-from pydantic import Field
-from queue import Queue
+from pydantic import Field, BaseModel
+from queue import Queue, Empty
 from threading import Thread
 import time
 from contextlib import contextmanager
 import base64
+import json
 
 from .base import BaseAppInput, BaseAppOutput
 from .file import File
 
-
 class ContextMessageRole(str, Enum):
     USER = "user"
     ASSISTANT = "assistant"
     SYSTEM = "system"
+    TOOL = "tool"
 
 
 class Message(BaseAppInput):
     role: ContextMessageRole
     content: str
 
-
 class ContextMessage(BaseAppInput):
     role: ContextMessageRole = Field(
-        description="The role of the message",
+        description="the role of the message. user, assistant, or system",
     )
     text: str = Field(
-        description="The text content of the message"
+        description="the text content of the message"
     )
     image: Optional[File] = Field(
-        description="The image url of the message",
+        description="the image file of the message",
+        default=None
+    )
+    images: Optional[List[File]] = Field(
+        description="the images of the message",
+        default=None
+    )
+    tool_calls: Optional[List[Dict[str, Any]]] = Field(
+        description="the tool calls of the message",
+        default=None
+    )
+    tool_call_id: Optional[str] = Field(
+        description="the tool call id for tool role messages",
         default=None
     )
 
-class LLMInput(BaseAppInput):
+class BaseLLMInput(BaseAppInput):
+    """Base class with common LLM fields."""
     system_prompt: str = Field(
-        description="The system prompt to use for the model",
-        default="You are a helpful assistant that can answer questions and help with tasks.",
+        description="the system prompt to use for the model",
+        default="you are a helpful assistant that can answer questions and help with tasks.",
         examples=[
-            "You are a helpful assistant that can answer questions and help with tasks.",
-            "You are a certified medical professional who can provide accurate health information.",
-            "You are a certified financial advisor who can give sound investment guidance.",
-            "You are a certified cybersecurity expert who can explain security best practices.",
-            "You are a certified environmental scientist who can discuss climate and sustainability.",
+            "you are a helpful assistant that can answer questions and help with tasks.",
         ]
     )
     context: List[ContextMessage] = Field(
-        description="The context to use for the model",
+        description="the context to use for the model",
+        default=[],
         examples=[
             [
-                {"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]},
+                {"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]},
                 {"role": "assistant", "content": [{"type": "text", "text": "The capital of France is Paris."}]}
-            ],
-            [
-                {"role": "user", "content": [{"type": "text", "text": "What is the weather like today?"}]},
-                {"role": "assistant", "content": [{"type": "text", "text": "I apologize, but I don't have access to real-time weather information. You would need to check a weather service or app to get current weather conditions for your location."}]}
-            ],
-            [
-                {"role": "user", "content": [{"type": "text", "text": "Can you help me write a poem about spring?"}]},
-                {"role": "assistant", "content": [{"type": "text", "text": "Here's a short poem about spring:\n\nGreen buds awakening,\nSoft rain gently falling down,\nNew life springs anew.\n\nWarm sun breaks through clouds,\nBirds return with joyful song,\nNature's sweet rebirth."}]}
-            ],
-            [
-                {"role": "user", "content": [{"type": "text", "text": "Explain quantum computing in simple terms"}]},
-                {"role": "assistant", "content": [{"type": "text", "text": "Quantum computing is like having a super-powerful calculator that can solve many problems at once instead of one at a time. While regular computers use bits (0s and 1s), quantum computers use quantum bits or \"qubits\" that can be both 0 and 1 at the same time - kind of like being in two places at once! This allows them to process huge amounts of information much faster than regular computers for certain types of problems."}]}
             ]
-        ],
-        default=[]
+        ]
+    )
+    role: ContextMessageRole = Field(
+        description="the role of the input text",
+        default=ContextMessageRole.USER
     )
     text: str = Field(
-        description="The user prompt to use for the model",
+        description="the input text to use for the model",
         examples=[
-            "What is the capital of France?",
-            "What is the weather like today?",
-            "Can you help me write a poem about spring?",
-            "Explain quantum computing in simple terms"
-        ],
+            "write a haiku about artificial general intelligence"
+        ]
     )
+    temperature: float = Field(default=0.7, ge=0.0, le=1.0)
+    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
+    context_size: int = Field(default=4096)
+
+class ImageCapabilityMixin(BaseModel):
+    """Mixin for models that support image inputs."""
     image: Optional[File] = Field(
-        description="The image to use for the model",
-        default=None
+        description="the image to use for the model",
+        default=None,
+        contentMediaType="image/*",
     )
-    # Optional parameters
-    temperature: float = Field(default=0.7)
-    top_p: float = Field(default=0.95)
-    max_tokens: int = Field(default=4096)
-    context_size: int = Field(default=4096)
 
-    # Model specific flags
-    reasoning: bool = Field(default=False)
-
-    tools: List[Dict[str, Any]] = Field(default=[])
+class MultipleImageCapabilityMixin(BaseModel):
+    """Mixin for models that support image inputs."""
+    images: Optional[List[File]] = Field(
+        description="the images to use for the model",
+        default=None,
+    )
+
+class ReasoningCapabilityMixin(BaseModel):
+    """Mixin for models that support reasoning."""
+    reasoning: bool = Field(
+        description="enable step-by-step reasoning",
+        default=False
+    )
+
+class ToolsCapabilityMixin(BaseModel):
+    """Mixin for models that support tool/function calling."""
+    tools: Optional[List[Dict[str, Any]]] = Field(
+        description="tool definitions for function calling",
+        default=None
+    )
+    tool_call_id: Optional[str] = Field(
+        description="the tool call id for tool role messages",
+        default=None
+    )
+
+# Example of how to use:
+class LLMInput(BaseLLMInput):
+    """Default LLM input model with no special capabilities."""
+    pass
+
+# For backward compatibility
+LLMInput.model_config["title"] = "LLMInput"
 
 class LLMUsage(BaseAppOutput):
     stop_reason: str = ""
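
Example (not part of the diff; the class name VisionLLMInput is hypothetical): apps compose the capability mixins above with BaseLLMInput, and pydantic merges the mixin fields into a single input schema.

    # Hypothetical composition of the mixins defined in this hunk.
    class VisionLLMInput(ImageCapabilityMixin, ReasoningCapabilityMixin, BaseLLMInput):
        """Input for an app that accepts one image and exposes a reasoning toggle."""
        pass
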
@@ -103,12 +131,38 @@ class LLMUsage(BaseAppOutput):
     reasoning_time: float = 0.0
 
 
-class LLMOutput(BaseAppOutput):
-    response: str
-    reasoning: Optional[str] = None
-    tool_calls: Optional[List[Dict[str, Any]]] = None
-    usage: Optional[LLMUsage] = None
+class BaseLLMOutput(BaseAppOutput):
+    """Base class for LLM outputs with common fields."""
+    response: str = Field(description="the generated text response")
 
+class LLMUsageMixin(BaseModel):
+    """Mixin for models that provide token usage statistics."""
+    usage: Optional[LLMUsage] = Field(
+        description="token usage statistics",
+        default=None
+    )
+
+class ReasoningMixin(BaseModel):
+    """Mixin for models that support reasoning."""
+    reasoning: Optional[str] = Field(
+        description="the reasoning output of the model",
+        default=None
+    )
+
+class ToolCallsMixin(BaseModel):
+    """Mixin for models that support tool calls."""
+    tool_calls: Optional[List[Dict[str, Any]]] = Field(
+        description="tool calls for function calling",
+        default=None
+    )
+
+# Example of how to use:
+class LLMOutput(LLMUsageMixin, BaseLLMOutput):
+    """Default LLM output model with token usage tracking."""
+    pass
+
+# For backward compatibility
+LLMOutput.model_config["title"] = "LLMOutput"
 
 @contextmanager
 def timing_context():
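
Example (not part of the diff; ToolLLMOutput is a hypothetical name): output models are assembled the same way, picking only the mixins a given model actually reports.

    # Hypothetical composition: response plus usage, reasoning, and tool calls.
    class ToolLLMOutput(ToolCallsMixin, ReasoningMixin, LLMUsageMixin, BaseLLMOutput):
        pass
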
@@ -116,7 +170,7 @@ def timing_context():
     class TimingInfo:
         def __init__(self):
             self.start_time = time.time()
-            self.first_token_time = 0
+            self.first_token_time = None
             self.reasoning_start_time = None
             self.total_reasoning_time = 0.0
             self.reasoning_tokens = 0
@@ -140,12 +194,17 @@ def timing_context():
 
         @property
         def stats(self):
-            end_time = time.time()
+            current_time = time.time()
             if self.first_token_time is None:
-                self.first_token_time = end_time
+                return {
+                    "time_to_first_token": 0.0,
+                    "generation_time": 0.0,
+                    "reasoning_time": self.total_reasoning_time,
+                    "reasoning_tokens": self.reasoning_tokens
+                }
 
             time_to_first = self.first_token_time - self.start_time
-            generation_time = end_time - self.first_token_time
+            generation_time = current_time - self.first_token_time
 
             return {
                 "time_to_first_token": time_to_first,
@@ -179,36 +238,184 @@ def build_messages(
         text = transform_user_message(msg.text) if transform_user_message and msg.role == ContextMessageRole.USER else msg.text
         if text:
             parts.append({"type": "text", "text": text})
+        else:
+            parts.append({"type": "text", "text": ""})
         if msg.image:
             if msg.image.path:
                 image_data_uri = image_to_base64_data_uri(msg.image.path)
                 parts.append({"type": "image_url", "image_url": {"url": image_data_uri}})
             elif msg.image.uri:
                 parts.append({"type": "image_url", "image_url": {"url": msg.image.uri}})
+        if msg.images:
+            for image in msg.images:
+                if image.path:
+                    image_data_uri = image_to_base64_data_uri(image.path)
+                    parts.append({"type": "image_url", "image_url": {"url": image_data_uri}})
+                elif image.uri:
+                    parts.append({"type": "image_url", "image_url": {"url": image.uri}})
         if allow_multipart:
             return parts
         if len(parts) == 1 and parts[0]["type"] == "text":
             return parts[0]["text"]
-        raise ValueError("Image content requires multipart support")
+        if len(parts) > 1:
+            if any(part["type"] == "image_url" for part in parts):
+                raise ValueError("Image content requires multipart support")
+            return parts
+        raise ValueError("Invalid message content")
 
-    multipart = any(m.image for m in input_data.context) or input_data.image is not None
     messages = [{"role": "system", "content": input_data.system_prompt}] if input_data.system_prompt is not None and input_data.system_prompt != "" else []
 
-    for msg in input_data.context:
-        messages.append({
-            "role": msg.role,
-            "content": render_message(msg, allow_multipart=multipart)
-        })
+    def merge_messages(messages: List[ContextMessage]) -> ContextMessage:
+        text = "\n\n".join(msg.text for msg in messages if msg.text)
+        images = []
+        # Collect single images
+        for msg in messages:
+            if msg.image:
+                images.append(msg.image)
+        # Collect multiple images (flatten the list)
+        for msg in messages:
+            if msg.images:
+                images.extend(msg.images)
+        # Set image to single File if there's exactly one, otherwise None
+        image = images[0] if len(images) == 1 else None
+        # Set images to the list if there are multiple, otherwise None
+        images_list = images if len(images) > 1 else None
+        return ContextMessage(role=messages[0].role, text=text, image=image, images=images_list)
+
+    def merge_tool_calls(messages: List[ContextMessage]) -> List[Dict[str, Any]]:
+        tool_calls = []
+        for msg in messages:
+            if msg.tool_calls:
+                tool_calls.extend(msg.tool_calls)
+        return tool_calls
+
+    user_input_text = ""
+    if hasattr(input_data, "text"):
+        user_input_text = transform_user_message(input_data.text) if transform_user_message else input_data.text
+
+    user_input_image = None
+    multipart = any(m.image for m in input_data.context)
+    if hasattr(input_data, "image"):
+        user_input_image = input_data.image
+        multipart = multipart or input_data.image is not None
+
+    user_input_images = None
+    if hasattr(input_data, "images"):
+        user_input_images = input_data.images
+        multipart = multipart or input_data.images is not None
 
-    user_msg = ContextMessage(role=ContextMessageRole.USER, text=input_data.text, image=input_data.image)
-    messages.append({
-        "role": "user",
-        "content": render_message(user_msg, allow_multipart=multipart)
-    })
+    input_role = input_data.role if hasattr(input_data, "role") else ContextMessageRole.USER
+    input_tool_call_id = input_data.tool_call_id if hasattr(input_data, "tool_call_id") else None
+    user_msg = ContextMessage(role=input_role, text=user_input_text, image=user_input_image, images=user_input_images, tool_call_id=input_tool_call_id)
+
+    input_data.context.append(user_msg)
+
+    current_role = None
+    current_messages = []
+
+    for msg in input_data.context:
+        if msg.role == current_role or current_role is None:
+            current_messages.append(msg)
+            current_role = msg.role
+        else:
+            # Convert role enum to string for OpenAI API compatibility
+            role_str = current_role.value if hasattr(current_role, "value") else current_role
+            msg_dict = {
+                "role": role_str,
+                "content": render_message(merge_messages(current_messages), allow_multipart=multipart),
+            }
+
+            # Only add tool_calls if not empty
+            tool_calls = merge_tool_calls(current_messages)
+            if tool_calls:
+                # Ensure arguments are JSON strings (OpenAI API requirement)
+                for tc in tool_calls:
+                    if "function" in tc and "arguments" in tc["function"]:
+                        if isinstance(tc["function"]["arguments"], dict):
+                            tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
+                msg_dict["tool_calls"] = tool_calls
+
+            # Add tool_call_id for tool role messages (required by OpenAI API)
+            if role_str == "tool":
+                if current_messages and current_messages[0].tool_call_id:
+                    msg_dict["tool_call_id"] = current_messages[0].tool_call_id
+                else:
+                    # If not provided, use empty string to satisfy schema
+                    msg_dict["tool_call_id"] = ""
+
+            messages.append(msg_dict)
+            current_messages = [msg]
+            current_role = msg.role
+
+    if len(current_messages) > 0:
+        # Convert role enum to string for OpenAI API compatibility
+        role_str = current_role.value if hasattr(current_role, "value") else current_role
+        msg_dict = {
+            "role": role_str,
+            "content": render_message(merge_messages(current_messages), allow_multipart=multipart),
+        }
+
+        # Only add tool_calls if not empty
+        tool_calls = merge_tool_calls(current_messages)
+        if tool_calls:
+            # Ensure arguments are JSON strings (OpenAI API requirement)
+            for tc in tool_calls:
+                if "function" in tc and "arguments" in tc["function"]:
+                    if isinstance(tc["function"]["arguments"], dict):
+                        tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
+            msg_dict["tool_calls"] = tool_calls
+
+        # Add tool_call_id for tool role messages (required by OpenAI API)
+        if role_str == "tool":
+            if current_messages and current_messages[0].tool_call_id:
+                msg_dict["tool_call_id"] = current_messages[0].tool_call_id
+            else:
+                # If not provided, use empty string to satisfy schema
+                msg_dict["tool_call_id"] = ""
+
+        messages.append(msg_dict)
 
     return messages
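
Worked example (not part of the diff; the texts are made up): because consecutive messages with the same role are merged, a context ending with the user message "What is the capital of France?" plus the new input text "and of Germany?" yields a single user entry, roughly:

    [
        {"role": "system", "content": "you are a helpful assistant that can answer questions and help with tasks."},
        {"role": "user", "content": "What is the capital of France?\n\nand of Germany?"},
    ]

(assuming no images anywhere in the context, so render_message returns plain strings).
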
 
 
+def build_tools(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[Dict[str, Any]]]:
+    """Build tools in OpenAI API format.
+
+    Ensures tools are properly formatted:
+    - Wrapped in {"type": "function", "function": {...}}
+    - Parameters is never None (OpenAI API requirement)
+    """
+    if not tools:
+        return None
+
+    result = []
+    for tool in tools:
+        # Extract function definition
+        if "type" in tool and "function" in tool:
+            func_def = tool["function"].copy()
+        else:
+            func_def = tool.copy()
+
+        # Ensure parameters is not None (OpenAI API requirement)
+        if func_def.get("parameters") is None:
+            func_def["parameters"] = {"type": "object", "properties": {}}
+        # Also ensure properties within parameters is not None
+        elif func_def["parameters"].get("properties") is None:
+            func_def["parameters"]["properties"] = {}
+        else:
+            # Remove properties with null values (OpenAI API doesn't accept them)
+            properties = func_def["parameters"].get("properties", {})
+            if properties:
+                func_def["parameters"]["properties"] = {
+                    k: v for k, v in properties.items() if v is not None
+                }
+
+        # Wrap in OpenAI format
+        result.append({"type": "function", "function": func_def})
+
+    return result
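
Worked example (not part of the diff; the tool definition is made up): a bare function definition with parameters=None is wrapped and given an empty parameters object.

    build_tools([{"name": "get_weather", "description": "Look up the weather", "parameters": None}])
    # -> [{"type": "function",
    #      "function": {"name": "get_weather", "description": "Look up the weather",
    #                   "parameters": {"type": "object", "properties": {}}}}]
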
+
+
 class StreamResponse:
     """Holds a single chunk of streamed response."""
     def __init__(self):
@@ -216,7 +423,7 @@ class StreamResponse:
         self.tool_calls = None # Changed from [] to None
         self.finish_reason = None
         self.timing_stats = {
-            "time_to_first_token": 0.0,
+            "time_to_first_token": None, # Changed from 0.0 to None
            "generation_time": 0.0,
            "reasoning_time": 0.0,
            "reasoning_tokens": 0,
@@ -232,8 +439,15 @@ class StreamResponse:
     def update_from_chunk(self, chunk: Dict[str, Any], timing: Any) -> None:
         """Update response state from a chunk."""
         # Update usage stats if present
-        if "usage" in chunk and chunk["usage"] is not None:
-            self.usage_stats.update(chunk["usage"])
+        if "usage" in chunk:
+            usage = chunk["usage"]
+            if usage is not None:
+                # Update usage stats preserving existing values if not provided
+                self.usage_stats.update({
+                    "prompt_tokens": usage.get("prompt_tokens", self.usage_stats["prompt_tokens"]),
+                    "completion_tokens": usage.get("completion_tokens", self.usage_stats["completion_tokens"]),
+                    "total_tokens": usage.get("total_tokens", self.usage_stats["total_tokens"])
+                })
 
         # Get the delta from the chunk
         delta = chunk.get("choices", [{}])[0]
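
For reference (not part of the diff; the values are illustrative assumptions): the chunks consumed here are expected to follow the OpenAI-style streaming shape, e.g.

    chunk = {
        "choices": [{"delta": {"content": "Paris"}, "finish_reason": None}],
        "usage": {"prompt_tokens": 12, "completion_tokens": 1, "total_tokens": 13},
    }
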
@@ -245,23 +459,34 @@ class StreamResponse:
             if message.get("tool_calls"):
                 self._update_tool_calls(message["tool_calls"])
             self.finish_reason = delta.get("finish_reason")
+            if self.finish_reason:
+                self.usage_stats["stop_reason"] = self.finish_reason
         elif "delta" in delta:
             delta_content = delta["delta"]
             self.content = delta_content.get("content", "")
             if delta_content.get("tool_calls"):
                 self._update_tool_calls(delta_content["tool_calls"])
             self.finish_reason = delta.get("finish_reason")
+            if self.finish_reason:
+                self.usage_stats["stop_reason"] = self.finish_reason
 
-        # Update timing stats while preserving tokens_per_second
+        # Update timing stats
         timing_stats = timing.stats
-        generation_time = timing_stats["generation_time"]
-        completion_tokens = self.usage_stats.get("completion_tokens", 0)
-        tokens_per_second = (completion_tokens / generation_time) if generation_time > 0 and completion_tokens > 0 else 0.0
+        if self.timing_stats["time_to_first_token"] is None:
+            self.timing_stats["time_to_first_token"] = timing_stats["time_to_first_token"]
 
         self.timing_stats.update({
-            **timing_stats,
-            "tokens_per_second": tokens_per_second
+            "generation_time": timing_stats["generation_time"],
+            "reasoning_time": timing_stats["reasoning_time"],
+            "reasoning_tokens": timing_stats["reasoning_tokens"]
         })
+
+        # Calculate tokens per second only if we have valid completion tokens and generation time
+        if self.usage_stats["completion_tokens"] > 0 and timing_stats["generation_time"] > 0:
+            self.timing_stats["tokens_per_second"] = (
+                self.usage_stats["completion_tokens"] / timing_stats["generation_time"]
+            )
+
 
     def _update_tool_calls(self, new_tool_calls: List[Dict[str, Any]]) -> None:
         """Update tool calls, handling both full and partial updates."""
@@ -292,22 +517,22 @@ class StreamResponse:
                 current_tool["function"]["arguments"] += func_delta["arguments"]
 
     def has_updates(self) -> bool:
-        """Check if this response has any content or tool call updates."""
-        return bool(self.content) or bool(self.tool_calls)
-
-    def to_output(self, buffer: str, transformer: Any) -> LLMOutput:
-        """Convert current state to LLMOutput."""
-        buffer, output, _ = transformer(self.content, buffer)
+        """Check if this response has any content, tool call, or usage updates."""
+        has_content = bool(self.content)
+        has_tool_calls = bool(self.tool_calls)
+        has_usage = self.usage_stats["prompt_tokens"] > 0 or self.usage_stats["completion_tokens"] > 0
+        has_finish = bool(self.finish_reason)
 
-        # Add tool calls if present
-        if self.tool_calls:
-            output.tool_calls = self.tool_calls
-
-        # Add usage stats if this is final
-        if self.finish_reason:
-            output.usage = LLMUsage(
+        return has_content or has_tool_calls or has_usage or has_finish
+
+    def to_output(self, buffer: str, transformer: Any) -> tuple[BaseLLMOutput, str]:
+        """Convert current state to LLMOutput."""
+        # Create usage object if we have stats
+        usage = None
+        if any(self.usage_stats.values()):
+            usage = LLMUsage(
                 stop_reason=self.usage_stats["stop_reason"],
-                time_to_first_token=self.timing_stats["time_to_first_token"],
+                time_to_first_token=self.timing_stats["time_to_first_token"] or 0.0,
                 tokens_per_second=self.timing_stats["tokens_per_second"],
                 prompt_tokens=self.usage_stats["prompt_tokens"],
                 completion_tokens=self.usage_stats["completion_tokens"],
@@ -315,6 +540,12 @@ class StreamResponse:
                 reasoning_time=self.timing_stats["reasoning_time"],
                 reasoning_tokens=self.timing_stats["reasoning_tokens"]
             )
+
+        buffer, output, _ = transformer(self.content, buffer, usage)
+
+        # Add tool calls if present and supported
+        if self.tool_calls and hasattr(output, 'tool_calls'):
+            output.tool_calls = self.tool_calls
 
         return output, buffer
 
@@ -327,6 +558,7 @@ class ResponseState:
         self.function_calls = None # For future function calling support
         self.tool_calls = None # List to accumulate tool calls
         self.current_tool_call = None # Track current tool call being built
+        self.usage = None # Add usage field
         self.state_changes = {
             "reasoning_started": False,
             "reasoning_ended": False,
@@ -338,7 +570,7 @@
 
 class ResponseTransformer:
     """Base class for transforming model responses."""
-    def __init__(self, output_cls: type[LLMOutput] = LLMOutput):
+    def __init__(self, output_cls: type[BaseLLMOutput] = LLMOutput):
         self.state = ResponseState()
         self.output_cls = output_cls
         self.timing = None # Will be set by stream_generate
@@ -381,26 +613,27 @@ class ResponseTransformer:
             text: Cleaned text to process for reasoning
         """
         # Default implementation for <think> style reasoning
-        if "<think>" in text and not self.state.state_changes["reasoning_started"]:
+        # Check for tags in the complete buffer
+        if "<think>" in self.state.buffer and not self.state.state_changes["reasoning_started"]:
            self.state.state_changes["reasoning_started"] = True
            if self.timing:
                self.timing.start_reasoning()
 
-        if "</think>" in text and not self.state.state_changes["reasoning_ended"]:
-            self.state.state_changes["reasoning_ended"] = True
-            if self.timing:
-                # Estimate token count from character count (rough approximation)
-                token_count = len(self.state.buffer.split("<think>")[1].split("</think>")[0]) // 4
-                self.timing.end_reasoning(token_count)
-
-        if "<think>" in self.state.buffer:
-            parts = self.state.buffer.split("</think>", 1)
-            if len(parts) > 1:
-                self.state.reasoning = parts[0].split("<think>", 1)[1].strip()
-                self.state.response = parts[1].strip()
-            else:
-                self.state.reasoning = self.state.buffer.split("<think>", 1)[1].strip()
-                self.state.response = ""
+        # Extract content and handle end of reasoning
+        parts = self.state.buffer.split("<think>", 1)
+        if len(parts) > 1:
+            reasoning_text = parts[1]
+            end_parts = reasoning_text.split("</think>", 1)
+            self.state.reasoning = end_parts[0].strip()
+            self.state.response = end_parts[1].strip() if len(end_parts) > 1 else ""
+
+            # Check for end tag in complete buffer
+            if "</think>" in self.state.buffer and not self.state.state_changes["reasoning_ended"]:
+                self.state.state_changes["reasoning_ended"] = True
+                if self.timing:
+                    # Estimate token count from character count (rough approximation)
+                    token_count = len(self.state.reasoning) // 4
+                    self.timing.end_reasoning(token_count)
         else:
            self.state.response = self.state.buffer
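
Worked example (not part of the diff; the buffer text is made up) of the <think> handling above:

    buffer = "<think>plan the haiku</think>Silicon minds wake"
    reasoning = buffer.split("<think>", 1)[1].split("</think>", 1)[0].strip()  # "plan the haiku"
    response = buffer.split("<think>", 1)[1].split("</think>", 1)[1].strip()   # "Silicon minds wake"
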
 
@@ -449,28 +682,43 @@ class ResponseTransformer:
         Returns:
             Tuple of (buffer, LLMOutput, state_changes)
         """
+        # Build base output with required fields
+        output_data = {
+            "response": self.state.response.strip(),
+        }
+
+        # Add optional fields if they exist
+        if self.state.usage is not None:
+            output_data["usage"] = self.state.usage
+        if self.state.reasoning:
+            output_data["reasoning"] = self.state.reasoning.strip()
+        if self.state.function_calls:
+            output_data["function_calls"] = self.state.function_calls
+        if self.state.tool_calls:
+            output_data["tool_calls"] = self.state.tool_calls
+
+        output = self.output_cls(**output_data)
+
         return (
             self.state.buffer,
-            self.output_cls(
-                response=self.state.response.strip(),
-                reasoning=self.state.reasoning.strip() if self.state.reasoning else None,
-                function_calls=self.state.function_calls,
-                tool_calls=self.state.tool_calls
-            ),
+            output,
             self.state.state_changes
         )
 
-    def __call__(self, piece: str, buffer: str) -> tuple[str, LLMOutput, dict]:
+    def __call__(self, piece: str, buffer: str, usage: Optional[LLMUsage] = None) -> tuple[str, LLMOutput, dict]:
         """Transform a piece of text and return the result.
 
         Args:
             piece: New piece of text to transform
             buffer: Existing buffer content
+            usage: Optional usage statistics
 
         Returns:
             Tuple of (new_buffer, output, state_changes)
         """
         self.state.buffer = buffer
+        if usage is not None:
+            self.state.usage = usage
         self.transform_chunk(piece)
         return self.build_output()
 
@@ -483,42 +731,131 @@
     tool_choice: Optional[Dict[str, Any]] = None,
     temperature: float = 0.7,
     top_p: float = 0.95,
-    max_tokens: int = 4096,
     stop: Optional[List[str]] = None,
     verbose: bool = False,
-) -> Generator[LLMOutput, None, None]:
+    output_cls: type[BaseLLMOutput] = LLMOutput,
+    kwargs: Optional[Dict[str, Any]] = None,
+) -> Generator[BaseLLMOutput, None, None]:
    """Stream generate from LLaMA.cpp model with timing and usage tracking."""
+
+    # Create queues for communication between threads
+    response_queue = Queue()
+    error_queue = Queue()
+    keep_alive_queue = Queue()
+
+    # Set the output class for the transformer
+    transformer.output_cls = output_cls
+
+    def _generate_worker():
+        """Worker thread to run the model generation."""
+        try:
+            # Build completion kwargs
+            completion_kwargs = {
+                "messages": messages,
+                "stream": True,
+                "temperature": temperature,
+                "top_p": top_p,
+                "stop": stop,
+            }
+            if kwargs:
+                completion_kwargs.update(kwargs)
+            if tools is not None:
+                completion_kwargs["tools"] = tools
+            if tool_choice is not None:
+                completion_kwargs["tool_choice"] = tool_choice
+
+            # Signal that we're starting
+            keep_alive_queue.put(("init", time.time()))
+
+            completion = model.create_chat_completion(**completion_kwargs)
+
+            for chunk in completion:
+                response_queue.put(("chunk", chunk))
+                # Update keep-alive timestamp
+                keep_alive_queue.put(("alive", time.time()))
+
+            # Signal completion
+            response_queue.put(("done", None))
+
+        except Exception as e:
+            # Preserve the full exception with traceback
+            import sys
+            error_queue.put((e, sys.exc_info()[2]))
+            response_queue.put(("error", str(e)))
+
    with timing_context() as timing:
        transformer.timing = timing
 
-        # Build completion kwargs
-        completion_kwargs = {
-            "messages": messages,
-            "stream": True,
-            "temperature": temperature,
-            "top_p": top_p,
-            "max_tokens": max_tokens,
-            "stop": stop
-        }
-        if tools is not None:
-            completion_kwargs["tools"] = tools
-        if tool_choice is not None:
-            completion_kwargs["tool_choice"] = tool_choice
+        # Start generation thread
+        generation_thread = Thread(target=_generate_worker, daemon=True)
+        generation_thread.start()
 
        # Initialize response state
        response = StreamResponse()
        buffer = ""
 
+        # Keep-alive tracking
+        last_activity = time.time()
+        init_timeout = 30.0 # 30 seconds for initial response
+        chunk_timeout = 10.0 # 10 seconds between chunks
+        chunks_begun = False
+
        try:
-            completion = model.create_chat_completion(**completion_kwargs)
+            # Wait for initial setup
+            try:
+                msg_type, timestamp = keep_alive_queue.get(timeout=init_timeout)
+                if msg_type != "init":
+                    raise RuntimeError("Unexpected initialization message")
+                last_activity = timestamp
+            except Empty:
+                raise RuntimeError(f"Model failed to initialize within {init_timeout} seconds")
 
-            for chunk in completion:
+            while True:
+                # Check for errors - now with proper exception chaining
+                if not error_queue.empty():
+                    exc, tb = error_queue.get()
+                    if isinstance(exc, Exception):
+                        raise exc.with_traceback(tb)
+                    else:
+                        raise RuntimeError(f"Unknown error in worker thread: {exc}")
+
+                # Check keep-alive
+                try:
+                    while not keep_alive_queue.empty():
+                        _, timestamp = keep_alive_queue.get_nowait()
+                        last_activity = timestamp
+                except Empty:
+                    # Ignore empty queue - this is expected
+                    pass
+
+                # Check for timeout
+                if chunks_begun and time.time() - last_activity > chunk_timeout:
+                    raise RuntimeError(f"No response from model for {chunk_timeout} seconds")
+
+                # Get next chunk
+                try:
+                    msg_type, data = response_queue.get(timeout=0.1)
+                except Empty:
+                    continue
+
+                if msg_type == "error":
+                    # If we get an error message but no exception in error_queue,
+                    # create a new error
+                    raise RuntimeError(f"Generation error: {data}")
+                elif msg_type == "done":
+                    break
+
+                chunk = data
+
                if verbose:
                    print(chunk)
-                # Mark first token time as soon as we get any response
+
+                # Mark first token time
                if not timing.first_token_time:
                    timing.mark_first_token()
 
+                chunks_begun = True
+
                # Update response state from chunk
                response.update_from_chunk(chunk, timing)
 
@@ -530,7 +867,19 @@ def stream_generate(
                # Break if we're done
                if response.finish_reason:
                    break
+
+            # Wait for generation thread to finish
+            if generation_thread.is_alive():
+                generation_thread.join(timeout=5.0) # Increased timeout to 5 seconds
+                if generation_thread.is_alive():
+                    # Thread didn't finish - this shouldn't happen normally
+                    raise RuntimeError("Generation thread failed to finish")
 
        except Exception as e:
-            # Ensure any error is properly propagated
-            raise e
+            # Check if there's a thread error we should chain with
+            if not error_queue.empty():
+                thread_exc, thread_tb = error_queue.get()
+                if isinstance(thread_exc, Exception):
+                    raise e from thread_exc
+            # If no thread error, raise the original exception
+            raise
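
Consumption sketch (not part of the diff; `stream` stands for the generator returned by stream_generate, whose leading parameters are not shown in this hunk):

    for output in stream:                       # each item is an output_cls instance
        print(output.response, end="", flush=True)
        if output.usage is not None:            # LLMOutput carries usage via LLMUsageMixin
            print(f"\n[{output.usage.tokens_per_second:.1f} tok/s]")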