inferencesh 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of inferencesh might be problematic.
- inferencesh/__init__.py +37 -1
- inferencesh/client.py +830 -0
- inferencesh/models/__init__.py +29 -0
- inferencesh/models/base.py +94 -0
- inferencesh/models/file.py +206 -0
- inferencesh/models/llm.py +729 -0
- inferencesh/utils/__init__.py +6 -0
- inferencesh/utils/download.py +51 -0
- inferencesh/utils/storage.py +16 -0
- {inferencesh-0.3.0.dist-info → inferencesh-0.3.1.dist-info}/METADATA +6 -1
- inferencesh-0.3.1.dist-info/RECORD +15 -0
- {inferencesh-0.3.0.dist-info → inferencesh-0.3.1.dist-info}/WHEEL +1 -1
- inferencesh/sdk.py +0 -363
- inferencesh-0.3.0.dist-info/RECORD +0 -8
- {inferencesh-0.3.0.dist-info → inferencesh-0.3.1.dist-info}/entry_points.txt +0 -0
- {inferencesh-0.3.0.dist-info → inferencesh-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {inferencesh-0.3.0.dist-info → inferencesh-0.3.1.dist-info}/top_level.txt +0 -0
inferencesh/models/llm.py
@@ -0,0 +1,729 @@
from typing import Optional, List, Any, Callable, Dict, Generator
from enum import Enum
from pydantic import Field, BaseModel
from queue import Queue, Empty
from threading import Thread
import time
from contextlib import contextmanager
import base64

from .base import BaseAppInput, BaseAppOutput
from .file import File

class ContextMessageRole(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"


class Message(BaseAppInput):
    role: ContextMessageRole
    content: str


class ContextMessage(BaseAppInput):
    role: ContextMessageRole = Field(
        description="the role of the message. user, assistant, or system",
    )
    text: str = Field(
        description="the text content of the message"
    )
    image: Optional[File] = Field(
        description="the image file of the message",
        default=None
    )

class BaseLLMInput(BaseAppInput):
    """Base class with common LLM fields."""
    system_prompt: str = Field(
        description="the system prompt to use for the model",
        default="you are a helpful assistant that can answer questions and help with tasks.",
        examples=[
            "you are a helpful assistant that can answer questions and help with tasks.",
        ]
    )
    context: List[ContextMessage] = Field(
        description="the context to use for the model",
        default=[],
        examples=[
            [
                {"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]},
                {"role": "assistant", "content": [{"type": "text", "text": "The capital of France is Paris."}]}
            ]
        ]
    )
    text: str = Field(
        description="the user prompt to use for the model",
        examples=[
            "write a haiku about artificial general intelligence"
        ]
    )
    temperature: float = Field(default=0.7, ge=0.0, le=1.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    context_size: int = Field(default=4096)

class ImageCapabilityMixin(BaseModel):
    """Mixin for models that support image inputs."""
    image: Optional[File] = Field(
        description="the image to use for the model",
        default=None,
        contentMediaType="image/*",
    )

class ReasoningCapabilityMixin(BaseModel):
    """Mixin for models that support reasoning."""
    reasoning: bool = Field(
        description="enable step-by-step reasoning",
        default=False
    )

class ToolsCapabilityMixin(BaseModel):
    """Mixin for models that support tool/function calling."""
    tools: Optional[List[Dict[str, Any]]] = Field(
        description="tool definitions for function calling",
        default=None
    )

# Example of how to use:
class LLMInput(BaseLLMInput):
    """Default LLM input model with no special capabilities."""
    pass

# For backward compatibility
LLMInput.model_config["title"] = "LLMInput"
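
The capability mixins above are meant to be composed with BaseLLMInput by individual apps. A minimal sketch of such a composition, assuming an app that also accepts an image (illustrative only, not part of this diff; the class name is hypothetical):

# Illustrative sketch only, not part of inferencesh/models/llm.py;
# "VisionLLMInput" is a hypothetical name.
class VisionLLMInput(ImageCapabilityMixin, BaseLLMInput):
    """Input model for an app that also accepts an optional image."""
    pass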

class LLMUsage(BaseAppOutput):
    stop_reason: str = ""
    time_to_first_token: float = 0.0
    tokens_per_second: float = 0.0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    reasoning_tokens: int = 0
    reasoning_time: float = 0.0


class BaseLLMOutput(BaseAppOutput):
    """Base class for LLM outputs with common fields."""
    response: str = Field(description="the generated text response")

class LLMUsageMixin(BaseModel):
    """Mixin for models that provide token usage statistics."""
    usage: Optional[LLMUsage] = Field(
        description="token usage statistics",
        default=None
    )

class ReasoningMixin(BaseModel):
    """Mixin for models that support reasoning."""
    reasoning: Optional[str] = Field(
        description="the reasoning output of the model",
        default=None
    )

class ToolCallsMixin(BaseModel):
    """Mixin for models that support tool calls."""
    tool_calls: Optional[List[Dict[str, Any]]] = Field(
        description="tool calls for function calling",
        default=None
    )

# Example of how to use:
class LLMOutput(LLMUsageMixin, BaseLLMOutput):
    """Default LLM output model with token usage tracking."""
    pass

# For backward compatibility
LLMOutput.model_config["title"] = "LLMOutput"
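
The output-side mixins compose the same way. A hedged sketch (illustrative only, not part of this diff; the class name is hypothetical) of an output model that reports reasoning text, tool calls, and usage:

# Illustrative sketch only, not part of inferencesh/models/llm.py;
# "ReasoningToolLLMOutput" is a hypothetical name.
class ReasoningToolLLMOutput(ToolCallsMixin, ReasoningMixin, LLMUsageMixin, BaseLLMOutput):
    """Output model exposing reasoning text, tool calls, and token usage."""
    pass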

@contextmanager
def timing_context():
    """Context manager to track timing information for LLM generation."""
    class TimingInfo:
        def __init__(self):
            self.start_time = time.time()
            self.first_token_time = None
            self.reasoning_start_time = None
            self.total_reasoning_time = 0.0
            self.reasoning_tokens = 0
            self.in_reasoning = False

        def mark_first_token(self):
            if self.first_token_time is None:
                self.first_token_time = time.time()

        def start_reasoning(self):
            if not self.in_reasoning:
                self.reasoning_start_time = time.time()
                self.in_reasoning = True

        def end_reasoning(self, token_count: int = 0):
            if self.in_reasoning and self.reasoning_start_time:
                self.total_reasoning_time += time.time() - self.reasoning_start_time
                self.reasoning_tokens += token_count
                self.reasoning_start_time = None
                self.in_reasoning = False

        @property
        def stats(self):
            current_time = time.time()
            if self.first_token_time is None:
                return {
                    "time_to_first_token": 0.0,
                    "generation_time": 0.0,
                    "reasoning_time": self.total_reasoning_time,
                    "reasoning_tokens": self.reasoning_tokens
                }

            time_to_first = self.first_token_time - self.start_time
            generation_time = current_time - self.first_token_time

            return {
                "time_to_first_token": time_to_first,
                "generation_time": generation_time,
                "reasoning_time": self.total_reasoning_time,
                "reasoning_tokens": self.reasoning_tokens
            }

    timing = TimingInfo()
    try:
        yield timing
    finally:
        pass
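
As a hedged aside (illustrative only, not part of this diff), the context manager above would typically be driven like this:

# Illustrative sketch only, not part of inferencesh/models/llm.py.
with timing_context() as timing:
    timing.mark_first_token()             # called when the first token arrives
    timing.start_reasoning()              # optional reasoning window
    timing.end_reasoning(token_count=12)
    print(timing.stats)                   # time_to_first_token, generation_time, reasoning stats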

def image_to_base64_data_uri(file_path):
    with open(file_path, "rb") as img_file:
        base64_data = base64.b64encode(img_file.read()).decode('utf-8')
        return f"data:image/png;base64,{base64_data}"

def build_messages(
    input_data: LLMInput,
    transform_user_message: Optional[Callable[[str], str]] = None
) -> List[Dict[str, Any]]:
    """Build messages for LLaMA.cpp chat completion.

    If any message includes image content, builds OpenAI-style multipart format.
    Otherwise, uses plain string-only format.
    """
    def render_message(msg: ContextMessage, allow_multipart: bool) -> str | List[dict]:
        parts = []
        text = transform_user_message(msg.text) if transform_user_message and msg.role == ContextMessageRole.USER else msg.text
        if text:
            parts.append({"type": "text", "text": text})
        if msg.image:
            if msg.image.path:
                image_data_uri = image_to_base64_data_uri(msg.image.path)
                parts.append({"type": "image_url", "image_url": {"url": image_data_uri}})
            elif msg.image.uri:
                parts.append({"type": "image_url", "image_url": {"url": msg.image.uri}})
        if allow_multipart:
            return parts
        if len(parts) == 1 and parts[0]["type"] == "text":
            return parts[0]["text"]
        raise ValueError("Image content requires multipart support")

    messages = [{"role": "system", "content": input_data.system_prompt}] if input_data.system_prompt is not None and input_data.system_prompt != "" else []

    def merge_messages(messages: List[ContextMessage]) -> ContextMessage:
        text = "\n\n".join(msg.text for msg in messages if msg.text)
        images = [msg.image for msg in messages if msg.image]
        image = images[0] if images else None  # TODO: handle multiple images
        return ContextMessage(role=messages[0].role, text=text, image=image)

    user_input_text = ""
    if hasattr(input_data, "text"):
        user_input_text = transform_user_message(input_data.text) if transform_user_message else input_data.text

    user_input_image = None
    multipart = any(m.image for m in input_data.context)
    if hasattr(input_data, "image"):
        user_input_image = input_data.image
        multipart = multipart or input_data.image is not None

    user_msg = ContextMessage(role=ContextMessageRole.USER, text=user_input_text, image=user_input_image)

    input_data.context.append(user_msg)

    current_role = None
    current_messages = []

    for msg in input_data.context:
        if msg.role == current_role or current_role is None:
            current_messages.append(msg)
            current_role = msg.role
        else:
            messages.append({
                "role": current_role,
                "content": render_message(merge_messages(current_messages), allow_multipart=multipart)
            })
            current_messages = [msg]
            current_role = msg.role
    if len(current_messages) > 0:
        messages.append({
            "role": current_role,
            "content": render_message(merge_messages(current_messages), allow_multipart=multipart)
        })

    return messages
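
A hedged usage sketch for build_messages (illustrative only, not part of this diff; the example values are invented): consecutive same-role context messages are merged, the user prompt is appended as a final user turn, and the system prompt becomes the first message.

# Illustrative sketch only, not part of inferencesh/models/llm.py.
example = LLMInput(
    context=[
        ContextMessage(role=ContextMessageRole.USER, text="hello"),
        ContextMessage(role=ContextMessageRole.ASSISTANT, text="hi, how can I help?"),
    ],
    text="write a haiku about artificial general intelligence",
)
messages = build_messages(example)
# -> [{"role": "system", ...}, {"role": user, ...}, {"role": assistant, ...}, {"role": user, ...}]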


class StreamResponse:
    """Holds a single chunk of streamed response."""
    def __init__(self):
        self.content = ""
        self.tool_calls = None  # Changed from [] to None
        self.finish_reason = None
        self.timing_stats = {
            "time_to_first_token": None,  # Changed from 0.0 to None
            "generation_time": 0.0,
            "reasoning_time": 0.0,
            "reasoning_tokens": 0,
            "tokens_per_second": 0.0
        }
        self.usage_stats = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "stop_reason": ""
        }

    def update_from_chunk(self, chunk: Dict[str, Any], timing: Any) -> None:
        """Update response state from a chunk."""
        # Update usage stats if present
        if "usage" in chunk:
            usage = chunk["usage"]
            if usage is not None:
                # Update usage stats preserving existing values if not provided
                self.usage_stats.update({
                    "prompt_tokens": usage.get("prompt_tokens", self.usage_stats["prompt_tokens"]),
                    "completion_tokens": usage.get("completion_tokens", self.usage_stats["completion_tokens"]),
                    "total_tokens": usage.get("total_tokens", self.usage_stats["total_tokens"])
                })

        # Get the delta from the chunk
        delta = chunk.get("choices", [{}])[0]

        # Extract content and tool calls from either message or delta
        if "message" in delta:
            message = delta["message"]
            self.content = message.get("content", "")
            if message.get("tool_calls"):
                self._update_tool_calls(message["tool_calls"])
            self.finish_reason = delta.get("finish_reason")
            if self.finish_reason:
                self.usage_stats["stop_reason"] = self.finish_reason
        elif "delta" in delta:
            delta_content = delta["delta"]
            self.content = delta_content.get("content", "")
            if delta_content.get("tool_calls"):
                self._update_tool_calls(delta_content["tool_calls"])
            self.finish_reason = delta.get("finish_reason")
            if self.finish_reason:
                self.usage_stats["stop_reason"] = self.finish_reason

        # Update timing stats
        timing_stats = timing.stats
        if self.timing_stats["time_to_first_token"] is None:
            self.timing_stats["time_to_first_token"] = timing_stats["time_to_first_token"]

        self.timing_stats.update({
            "generation_time": timing_stats["generation_time"],
            "reasoning_time": timing_stats["reasoning_time"],
            "reasoning_tokens": timing_stats["reasoning_tokens"]
        })

        # Calculate tokens per second only if we have valid completion tokens and generation time
        if self.usage_stats["completion_tokens"] > 0 and timing_stats["generation_time"] > 0:
            self.timing_stats["tokens_per_second"] = (
                self.usage_stats["completion_tokens"] / timing_stats["generation_time"]
            )


    def _update_tool_calls(self, new_tool_calls: List[Dict[str, Any]]) -> None:
        """Update tool calls, handling both full and partial updates."""
        if self.tool_calls is None:
            self.tool_calls = []

        for tool_delta in new_tool_calls:
            tool_id = tool_delta.get("id")
            if not tool_id:
                continue

            # Find or create tool call
            current_tool = next((t for t in self.tool_calls if t["id"] == tool_id), None)
            if not current_tool:
                current_tool = {
                    "id": tool_id,
                    "type": tool_delta.get("type", "function"),
                    "function": {"name": "", "arguments": ""}
                }
                self.tool_calls.append(current_tool)

            # Update tool call
            if "function" in tool_delta:
                func_delta = tool_delta["function"]
                if "name" in func_delta:
                    current_tool["function"]["name"] = func_delta["name"]
                if "arguments" in func_delta:
                    current_tool["function"]["arguments"] += func_delta["arguments"]

    def has_updates(self) -> bool:
        """Check if this response has any content, tool call, or usage updates."""
        has_content = bool(self.content)
        has_tool_calls = bool(self.tool_calls)
        has_usage = self.usage_stats["prompt_tokens"] > 0 or self.usage_stats["completion_tokens"] > 0
        has_finish = bool(self.finish_reason)

        return has_content or has_tool_calls or has_usage or has_finish

    def to_output(self, buffer: str, transformer: Any) -> tuple[BaseLLMOutput, str]:
        """Convert current state to LLMOutput."""
        # Create usage object if we have stats
        usage = None
        if any(self.usage_stats.values()):
            usage = LLMUsage(
                stop_reason=self.usage_stats["stop_reason"],
                time_to_first_token=self.timing_stats["time_to_first_token"] or 0.0,
                tokens_per_second=self.timing_stats["tokens_per_second"],
                prompt_tokens=self.usage_stats["prompt_tokens"],
                completion_tokens=self.usage_stats["completion_tokens"],
                total_tokens=self.usage_stats["total_tokens"],
                reasoning_time=self.timing_stats["reasoning_time"],
                reasoning_tokens=self.timing_stats["reasoning_tokens"]
            )

        buffer, output, _ = transformer(self.content, buffer, usage)

        # Add tool calls if present and supported
        if self.tool_calls and hasattr(output, 'tool_calls'):
            output.tool_calls = self.tool_calls

        return output, buffer

class ResponseState:
    """Holds the state of response transformation."""
    def __init__(self):
        self.buffer = ""
        self.response = ""
        self.reasoning = None
        self.function_calls = None  # For future function calling support
        self.tool_calls = None  # List to accumulate tool calls
        self.current_tool_call = None  # Track current tool call being built
        self.usage = None  # Add usage field
        self.state_changes = {
            "reasoning_started": False,
            "reasoning_ended": False,
            "function_call_started": False,
            "function_call_ended": False,
            "tool_call_started": False,
            "tool_call_ended": False
        }

class ResponseTransformer:
    """Base class for transforming model responses."""
    def __init__(self, output_cls: type[BaseLLMOutput] = LLMOutput):
        self.state = ResponseState()
        self.output_cls = output_cls
        self.timing = None  # Will be set by stream_generate

    def clean_text(self, text: str) -> str:
        """Clean common tokens from the text and apply model-specific cleaning.

        Args:
            text: Raw text to clean

        Returns:
            Cleaned text with common and model-specific tokens removed
        """
        if text is None:
            return ""

        # Common token cleaning across most models
        cleaned = (text.replace("<|im_end|>", "")
                   .replace("<|im_start|>", "")
                   .replace("<start_of_turn>", "")
                   .replace("<end_of_turn>", "")
                   .replace("<eos>", ""))
        return self.additional_cleaning(cleaned)

    def additional_cleaning(self, text: str) -> str:
        """Apply model-specific token cleaning.

        Args:
            text: Text that has had common tokens removed

        Returns:
            Text with model-specific tokens removed
        """
        return text

    def handle_reasoning(self, text: str) -> None:
        """Handle reasoning/thinking detection and extraction.

        Args:
            text: Cleaned text to process for reasoning
        """
        # Default implementation for <think> style reasoning
        if "<think>" in text and not self.state.state_changes["reasoning_started"]:
            self.state.state_changes["reasoning_started"] = True
            if self.timing:
                self.timing.start_reasoning()

        if "</think>" in text and not self.state.state_changes["reasoning_ended"]:
            self.state.state_changes["reasoning_ended"] = True
            if self.timing:
                # Estimate token count from character count (rough approximation)
                token_count = len(self.state.buffer.split("<think>")[1].split("</think>")[0]) // 4
                self.timing.end_reasoning(token_count)

        if "<think>" in self.state.buffer:
            parts = self.state.buffer.split("</think>", 1)
            if len(parts) > 1:
                self.state.reasoning = parts[0].split("<think>", 1)[1].strip()
                self.state.response = parts[1].strip()
            else:
                self.state.reasoning = self.state.buffer.split("<think>", 1)[1].strip()
                self.state.response = ""
        else:
            self.state.response = self.state.buffer

    def handle_function_calls(self, text: str) -> None:
        """Handle function call detection and extraction.

        Args:
            text: Cleaned text to process for function calls
        """
        # Default no-op implementation
        # Models can override this to implement function call handling
        pass

    def handle_tool_calls(self, text: str) -> None:
        """Handle tool call detection and extraction.

        Args:
            text: Cleaned text to process for tool calls
        """
        # Default no-op implementation
        # Models can override this to implement tool call handling
        pass

    def transform_chunk(self, chunk: str) -> None:
        """Transform a single chunk of model output.

        This method orchestrates the transformation process by:
        1. Cleaning the text
        2. Updating the buffer
        3. Processing various capabilities (reasoning, function calls, etc)

        Args:
            chunk: Raw text chunk from the model
        """
        cleaned = self.clean_text(chunk)
        self.state.buffer += cleaned

        # Process different capabilities
        self.handle_reasoning(cleaned)
        self.handle_function_calls(cleaned)
        self.handle_tool_calls(cleaned)

    def build_output(self) -> tuple[str, LLMOutput, dict]:
        """Build the final output tuple.

        Returns:
            Tuple of (buffer, LLMOutput, state_changes)
        """
        # Build base output with required fields
        output_data = {
            "response": self.state.response.strip(),
        }

        # Add optional fields if they exist
        if self.state.usage is not None:
            output_data["usage"] = self.state.usage
        if self.state.reasoning:
            output_data["reasoning"] = self.state.reasoning.strip()
        if self.state.function_calls:
            output_data["function_calls"] = self.state.function_calls
        if self.state.tool_calls:
            output_data["tool_calls"] = self.state.tool_calls

        output = self.output_cls(**output_data)

        return (
            self.state.buffer,
            output,
            self.state.state_changes
        )

    def __call__(self, piece: str, buffer: str, usage: Optional[LLMUsage] = None) -> tuple[str, LLMOutput, dict]:
        """Transform a piece of text and return the result.

        Args:
            piece: New piece of text to transform
            buffer: Existing buffer content
            usage: Optional usage statistics

        Returns:
            Tuple of (new_buffer, output, state_changes)
        """
        self.state.buffer = buffer
        if usage is not None:
            self.state.usage = usage
        self.transform_chunk(piece)
        return self.build_output()
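
Model-specific transformers are expected to subclass ResponseTransformer. A hedged sketch (illustrative only, not part of this diff; the class name and token are invented) of an additional_cleaning override:

# Illustrative sketch only, not part of inferencesh/models/llm.py.
class ExampleTransformer(ResponseTransformer):
    def additional_cleaning(self, text: str) -> str:
        # Strip an invented model-specific marker before the shared
        # reasoning/tool-call handling runs.
        return text.replace("<|assistant|>", "")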


def stream_generate(
    model: Any,
    messages: List[Dict[str, Any]],
    transformer: ResponseTransformer = ResponseTransformer(),
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: Optional[Dict[str, Any]] = None,
    temperature: float = 0.7,
    top_p: float = 0.95,
    stop: Optional[List[str]] = None,
    verbose: bool = False,
    output_cls: type[BaseLLMOutput] = LLMOutput,
) -> Generator[BaseLLMOutput, None, None]:
    """Stream generate from LLaMA.cpp model with timing and usage tracking."""

    # Create queues for communication between threads
    response_queue = Queue()
    error_queue = Queue()
    keep_alive_queue = Queue()

    # Set the output class for the transformer
    transformer.output_cls = output_cls

    def _generate_worker():
        """Worker thread to run the model generation."""
        try:
            # Build completion kwargs
            completion_kwargs = {
                "messages": messages,
                "stream": True,
                "temperature": temperature,
                "top_p": top_p,
                "stop": stop
            }
            if tools is not None:
                completion_kwargs["tools"] = tools
            if tool_choice is not None:
                completion_kwargs["tool_choice"] = tool_choice

            # Signal that we're starting
            keep_alive_queue.put(("init", time.time()))

            completion = model.create_chat_completion(**completion_kwargs)

            for chunk in completion:
                response_queue.put(("chunk", chunk))
                # Update keep-alive timestamp
                keep_alive_queue.put(("alive", time.time()))

            # Signal completion
            response_queue.put(("done", None))

        except Exception as e:
            # Preserve the full exception with traceback
            import sys
            error_queue.put((e, sys.exc_info()[2]))
            response_queue.put(("error", str(e)))

    with timing_context() as timing:
        transformer.timing = timing

        # Start generation thread
        generation_thread = Thread(target=_generate_worker, daemon=True)
        generation_thread.start()

        # Initialize response state
        response = StreamResponse()
        buffer = ""

        # Keep-alive tracking
        last_activity = time.time()
        init_timeout = 30.0  # 30 seconds for initial response
        chunk_timeout = 10.0  # 10 seconds between chunks

        try:
            # Wait for initial setup
            try:
                msg_type, timestamp = keep_alive_queue.get(timeout=init_timeout)
                if msg_type != "init":
                    raise RuntimeError("Unexpected initialization message")
                last_activity = timestamp
            except Empty:
                raise RuntimeError(f"Model failed to initialize within {init_timeout} seconds")

            while True:
                # Check for errors - now with proper exception chaining
                if not error_queue.empty():
                    exc, tb = error_queue.get()
                    if isinstance(exc, Exception):
                        raise exc.with_traceback(tb)
                    else:
                        raise RuntimeError(f"Unknown error in worker thread: {exc}")

                # Check keep-alive
                try:
                    while not keep_alive_queue.empty():
                        _, timestamp = keep_alive_queue.get_nowait()
                        last_activity = timestamp
                except Empty:
                    # Ignore empty queue - this is expected
                    pass

                # Check for timeout
                if time.time() - last_activity > chunk_timeout:
                    raise RuntimeError(f"No response from model for {chunk_timeout} seconds")

                # Get next chunk
                try:
                    msg_type, data = response_queue.get(timeout=0.1)
                except Empty:
                    continue

                if msg_type == "error":
                    # If we get an error message but no exception in error_queue,
                    # create a new error
                    raise RuntimeError(f"Generation error: {data}")
                elif msg_type == "done":
                    break

                chunk = data

                if verbose:
                    print(chunk)

                # Mark first token time
                if not timing.first_token_time:
                    timing.mark_first_token()

                # Update response state from chunk
                response.update_from_chunk(chunk, timing)

                # Yield output if we have updates
                if response.has_updates():
                    output, buffer = response.to_output(buffer, transformer)
                    yield output

                # Break if we're done
                if response.finish_reason:
                    break

            # Wait for generation thread to finish
            if generation_thread.is_alive():
                generation_thread.join(timeout=5.0)  # Increased timeout to 5 seconds
                if generation_thread.is_alive():
                    # Thread didn't finish - this shouldn't happen normally
                    raise RuntimeError("Generation thread failed to finish")

        except Exception as e:
            # Check if there's a thread error we should chain with
            if not error_queue.empty():
                thread_exc, thread_tb = error_queue.get()
                if isinstance(thread_exc, Exception):
                    raise e from thread_exc
            # If no thread error, raise the original exception
            raise
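
Finally, a hedged end-to-end sketch (illustrative only, not part of this diff) of how stream_generate might be driven with a llama-cpp-python model; the model path is an assumption, and passing a fresh ResponseTransformer avoids reusing the mutable default argument:

# Illustrative sketch only; assumes llama-cpp-python is installed and a local GGUF file exists.
from llama_cpp import Llama

llm = Llama(model_path="model.gguf", n_ctx=4096)  # hypothetical local model file
msgs = build_messages(LLMInput(text="write a haiku about artificial general intelligence"))

final = None
for partial in stream_generate(llm, msgs, transformer=ResponseTransformer()):
    final = partial  # each yield carries the accumulated response so far
print(final.response)
print(final.usage)  # LLMUsage with timing and token counts, if reported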