inferencesh 0.2.23__py3-none-any.whl → 0.4.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inferencesh/__init__.py +5 -0
- inferencesh/client.py +1081 -0
- inferencesh/models/base.py +81 -3
- inferencesh/models/file.py +120 -21
- inferencesh/models/llm.py +485 -136
- inferencesh/utils/download.py +15 -7
- inferencesh-0.4.29.dist-info/METADATA +196 -0
- inferencesh-0.4.29.dist-info/RECORD +15 -0
- inferencesh-0.2.23.dist-info/METADATA +0 -105
- inferencesh-0.2.23.dist-info/RECORD +0 -14
- {inferencesh-0.2.23.dist-info → inferencesh-0.4.29.dist-info}/WHEEL +0 -0
- {inferencesh-0.2.23.dist-info → inferencesh-0.4.29.dist-info}/entry_points.txt +0 -0
- {inferencesh-0.2.23.dist-info → inferencesh-0.4.29.dist-info}/licenses/LICENSE +0 -0
- {inferencesh-0.2.23.dist-info → inferencesh-0.4.29.dist-info}/top_level.txt +0 -0
inferencesh/models/llm.py
CHANGED
@@ -1,96 +1,124 @@
 from typing import Optional, List, Any, Callable, Dict, Generator
 from enum import Enum
-from pydantic import Field
-from queue import Queue
+from pydantic import Field, BaseModel
+from queue import Queue, Empty
 from threading import Thread
 import time
 from contextlib import contextmanager
 import base64
+import json
 
 from .base import BaseAppInput, BaseAppOutput
 from .file import File
 
-
 class ContextMessageRole(str, Enum):
     USER = "user"
     ASSISTANT = "assistant"
     SYSTEM = "system"
+    TOOL = "tool"
 
 
 class Message(BaseAppInput):
     role: ContextMessageRole
     content: str
 
-
 class ContextMessage(BaseAppInput):
     role: ContextMessageRole = Field(
-        description="
+        description="the role of the message. user, assistant, or system",
     )
     text: str = Field(
-        description="
+        description="the text content of the message"
     )
     image: Optional[File] = Field(
-        description="
+        description="the image file of the message",
+        default=None
+    )
+    images: Optional[List[File]] = Field(
+        description="the images of the message",
+        default=None
+    )
+    tool_calls: Optional[List[Dict[str, Any]]] = Field(
+        description="the tool calls of the message",
+        default=None
+    )
+    tool_call_id: Optional[str] = Field(
+        description="the tool call id for tool role messages",
        default=None
    )
 
-class 
+class BaseLLMInput(BaseAppInput):
+    """Base class with common LLM fields."""
     system_prompt: str = Field(
-        description="
-        default="
+        description="the system prompt to use for the model",
+        default="you are a helpful assistant that can answer questions and help with tasks.",
         examples=[
-            "
-            "You are a certified medical professional who can provide accurate health information.",
-            "You are a certified financial advisor who can give sound investment guidance.",
-            "You are a certified cybersecurity expert who can explain security best practices.",
-            "You are a certified environmental scientist who can discuss climate and sustainability.",
+            "you are a helpful assistant that can answer questions and help with tasks.",
         ]
     )
     context: List[ContextMessage] = Field(
-        description="
+        description="the context to use for the model",
+        default=[],
         examples=[
             [
-            {"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]},
+                {"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]},
                 {"role": "assistant", "content": [{"type": "text", "text": "The capital of France is Paris."}]}
-            ],
-            [
-                {"role": "user", "content": [{"type": "text", "text": "What is the weather like today?"}]},
-                {"role": "assistant", "content": [{"type": "text", "text": "I apologize, but I don't have access to real-time weather information. You would need to check a weather service or app to get current weather conditions for your location."}]}
-            ],
-            [
-                {"role": "user", "content": [{"type": "text", "text": "Can you help me write a poem about spring?"}]},
-                {"role": "assistant", "content": [{"type": "text", "text": "Here's a short poem about spring:\n\nGreen buds awakening,\nSoft rain gently falling down,\nNew life springs anew.\n\nWarm sun breaks through clouds,\nBirds return with joyful song,\nNature's sweet rebirth."}]}
-            ],
-            [
-                {"role": "user", "content": [{"type": "text", "text": "Explain quantum computing in simple terms"}]},
-                {"role": "assistant", "content": [{"type": "text", "text": "Quantum computing is like having a super-powerful calculator that can solve many problems at once instead of one at a time. While regular computers use bits (0s and 1s), quantum computers use quantum bits or \"qubits\" that can be both 0 and 1 at the same time - kind of like being in two places at once! This allows them to process huge amounts of information much faster than regular computers for certain types of problems."}]}
             ]
-        ]
-
+        ]
+    )
+    role: ContextMessageRole = Field(
+        description="the role of the input text",
+        default=ContextMessageRole.USER
     )
     text: str = Field(
-        description="
+        description="the input text to use for the model",
         examples=[
-            "
-
-            "Can you help me write a poem about spring?",
-            "Explain quantum computing in simple terms"
-        ],
+            "write a haiku about artificial general intelligence"
+        ]
     )
+    temperature: float = Field(default=0.7, ge=0.0, le=1.0)
+    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
+    context_size: int = Field(default=4096)
+
+class ImageCapabilityMixin(BaseModel):
+    """Mixin for models that support image inputs."""
     image: Optional[File] = Field(
-        description="
-        default=None
+        description="the image to use for the model",
+        default=None,
+        contentMediaType="image/*",
     )
-    # Optional parameters
-    temperature: float = Field(default=0.7)
-    top_p: float = Field(default=0.95)
-    max_tokens: int = Field(default=4096)
-    context_size: int = Field(default=4096)
 
-
-
-
-
+class MultipleImageCapabilityMixin(BaseModel):
+    """Mixin for models that support image inputs."""
+    images: Optional[List[File]] = Field(
+        description="the images to use for the model",
+        default=None,
+    )
+
+class ReasoningCapabilityMixin(BaseModel):
+    """Mixin for models that support reasoning."""
+    reasoning: bool = Field(
+        description="enable step-by-step reasoning",
+        default=False
+    )
+
+class ToolsCapabilityMixin(BaseModel):
+    """Mixin for models that support tool/function calling."""
+    tools: Optional[List[Dict[str, Any]]] = Field(
+        description="tool definitions for function calling",
+        default=None
+    )
+    tool_call_id: Optional[str] = Field(
+        description="the tool call id for tool role messages",
+        default=None
+    )
+
+# Example of how to use:
+class LLMInput(BaseLLMInput):
+    """Default LLM input model with no special capabilities."""
+    pass
+
+# For backward compatibility
+LLMInput.model_config["title"] = "LLMInput"
 
 class LLMUsage(BaseAppOutput):
     stop_reason: str = ""
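The input side now splits into a common BaseLLMInput plus capability mixins. As a rough sketch of how an app-specific input might compose them (the class name VisionToolInput and the import path are illustrative assumptions, not documented API):

# Sketch only: composing the new input mixins into an app-specific model.
# Assumes `inferencesh.models.llm` resolves to the file shown in this diff.
from inferencesh.models.llm import BaseLLMInput, ImageCapabilityMixin, ToolsCapabilityMixin

class VisionToolInput(ToolsCapabilityMixin, ImageCapabilityMixin, BaseLLMInput):
    """Hypothetical input that accepts an image plus tool definitions on top of the
    common fields (system_prompt, context, role, text, temperature, top_p, context_size)."""
    pass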
@@ -103,12 +131,38 @@ class LLMUsage(BaseAppOutput):
     reasoning_time: float = 0.0
 
 
-class 
-
-
-    tool_calls: Optional[List[Dict[str, Any]]] = None
-    usage: Optional[LLMUsage] = None
+class BaseLLMOutput(BaseAppOutput):
+    """Base class for LLM outputs with common fields."""
+    response: str = Field(description="the generated text response")
 
+class LLMUsageMixin(BaseModel):
+    """Mixin for models that provide token usage statistics."""
+    usage: Optional[LLMUsage] = Field(
+        description="token usage statistics",
+        default=None
+    )
+
+class ReasoningMixin(BaseModel):
+    """Mixin for models that support reasoning."""
+    reasoning: Optional[str] = Field(
+        description="the reasoning output of the model",
+        default=None
+    )
+
+class ToolCallsMixin(BaseModel):
+    """Mixin for models that support tool calls."""
+    tool_calls: Optional[List[Dict[str, Any]]] = Field(
+        description="tool calls for function calling",
+        default=None
+    )
+
+# Example of how to use:
+class LLMOutput(LLMUsageMixin, BaseLLMOutput):
+    """Default LLM output model with token usage tracking."""
+    pass
+
+# For backward compatibility
+LLMOutput.model_config["title"] = "LLMOutput"
 
 @contextmanager
 def timing_context():
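The output models follow the same pattern: BaseLLMOutput carries the response field and the mixins add usage, reasoning, and tool_calls. A hedged sketch of a richer output class (names are illustrative, not part of the package):

# Sketch only: an output that reports reasoning and tool calls alongside usage.
from inferencesh.models.llm import (  # assumed import path
    BaseLLMOutput, LLMUsageMixin, ReasoningMixin, ToolCallsMixin,
)

class ReasoningToolOutput(ToolCallsMixin, ReasoningMixin, LLMUsageMixin, BaseLLMOutput):
    """Hypothetical output with response, reasoning, tool_calls and usage fields."""
    pass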
@@ -116,7 +170,7 @@ def timing_context():
     class TimingInfo:
         def __init__(self):
             self.start_time = time.time()
-            self.first_token_time = 
+            self.first_token_time = None
             self.reasoning_start_time = None
             self.total_reasoning_time = 0.0
             self.reasoning_tokens = 0
@@ -140,12 +194,17 @@ def timing_context():
 
         @property
         def stats(self):
-
+            current_time = time.time()
             if self.first_token_time is None:
-
+                return {
+                    "time_to_first_token": 0.0,
+                    "generation_time": 0.0,
+                    "reasoning_time": self.total_reasoning_time,
+                    "reasoning_tokens": self.reasoning_tokens
+                }
 
             time_to_first = self.first_token_time - self.start_time
-            generation_time = 
+            generation_time = current_time - self.first_token_time
 
             return {
                 "time_to_first_token": time_to_first,
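The stats property now snapshots the clock once and returns a zeroed dictionary until the first token is marked. A sketch of how the timing object is driven (the method names appear elsewhere in this diff; the direct use outside stream_generate is an assumption):

# Sketch only: driving the timing context the way stream_generate does further down.
from inferencesh.models.llm import timing_context  # assumed import path

with timing_context() as timing:
    timing.mark_first_token()   # called when the first chunk arrives
    timing.start_reasoning()    # entering a <think> section
    timing.end_reasoning(42)    # leaving it, with an estimated token count
    print(timing.stats)         # time_to_first_token, generation_time,
                                # reasoning_time, reasoning_tokens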
@@ -179,36 +238,184 @@ def build_messages(
         text = transform_user_message(msg.text) if transform_user_message and msg.role == ContextMessageRole.USER else msg.text
         if text:
             parts.append({"type": "text", "text": text})
+        else:
+            parts.append({"type": "text", "text": ""})
         if msg.image:
             if msg.image.path:
                 image_data_uri = image_to_base64_data_uri(msg.image.path)
                 parts.append({"type": "image_url", "image_url": {"url": image_data_uri}})
             elif msg.image.uri:
                 parts.append({"type": "image_url", "image_url": {"url": msg.image.uri}})
+        if msg.images:
+            for image in msg.images:
+                if image.path:
+                    image_data_uri = image_to_base64_data_uri(image.path)
+                    parts.append({"type": "image_url", "image_url": {"url": image_data_uri}})
+                elif image.uri:
+                    parts.append({"type": "image_url", "image_url": {"url": image.uri}})
         if allow_multipart:
             return parts
         if len(parts) == 1 and parts[0]["type"] == "text":
             return parts[0]["text"]
-
+        if len(parts) > 1:
+            if parts.any(lambda x: x["type"] == "image_url"):
+                raise ValueError("Image content requires multipart support")
+            return parts
+        raise ValueError("Invalid message content")
 
-    multipart = any(m.image for m in input_data.context) or input_data.image is not None
     messages = [{"role": "system", "content": input_data.system_prompt}] if input_data.system_prompt is not None and input_data.system_prompt != "" else []
 
-
-    messages.
-
-
-
+    def merge_messages(messages: List[ContextMessage]) -> ContextMessage:
+        text = "\n\n".join(msg.text for msg in messages if msg.text)
+        images = []
+        # Collect single images
+        for msg in messages:
+            if msg.image:
+                images.append(msg.image)
+        # Collect multiple images (flatten the list)
+        for msg in messages:
+            if msg.images:
+                images.extend(msg.images)
+        # Set image to single File if there's exactly one, otherwise None
+        image = images[0] if len(images) == 1 else None
+        # Set images to the list if there are multiple, otherwise None
+        images_list = images if len(images) > 1 else None
+        return ContextMessage(role=messages[0].role, text=text, image=image, images=images_list)
+
+    def merge_tool_calls(messages: List[ContextMessage]) -> List[Dict[str, Any]]:
+        tool_calls = []
+        for msg in messages:
+            if msg.tool_calls:
+                tool_calls.extend(msg.tool_calls)
+        return tool_calls
+
+    user_input_text = ""
+    if hasattr(input_data, "text"):
+        user_input_text = transform_user_message(input_data.text) if transform_user_message else input_data.text
+
+    user_input_image = None
+    multipart = any(m.image for m in input_data.context)
+    if hasattr(input_data, "image"):
+        user_input_image = input_data.image
+        multipart = multipart or input_data.image is not None
+
+    user_input_images = None
+    if hasattr(input_data, "images"):
+        user_input_images = input_data.images
+        multipart = multipart or input_data.images is not None
 
-
-
-
-
-
+    input_role = input_data.role if hasattr(input_data, "role") else ContextMessageRole.USER
+    input_tool_call_id = input_data.tool_call_id if hasattr(input_data, "tool_call_id") else None
+    user_msg = ContextMessage(role=input_role, text=user_input_text, image=user_input_image, images=user_input_images, tool_call_id=input_tool_call_id)
+
+    input_data.context.append(user_msg)
+
+    current_role = None
+    current_messages = []
+
+    for msg in input_data.context:
+        if msg.role == current_role or current_role is None:
+            current_messages.append(msg)
+            current_role = msg.role
+        else:
+            # Convert role enum to string for OpenAI API compatibility
+            role_str = current_role.value if hasattr(current_role, "value") else current_role
+            msg_dict = {
+                "role": role_str,
+                "content": render_message(merge_messages(current_messages), allow_multipart=multipart),
+            }
+
+            # Only add tool_calls if not empty
+            tool_calls = merge_tool_calls(current_messages)
+            if tool_calls:
+                # Ensure arguments are JSON strings (OpenAI API requirement)
+                for tc in tool_calls:
+                    if "function" in tc and "arguments" in tc["function"]:
+                        if isinstance(tc["function"]["arguments"], dict):
+                            tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
+                msg_dict["tool_calls"] = tool_calls
+
+            # Add tool_call_id for tool role messages (required by OpenAI API)
+            if role_str == "tool":
+                if current_messages and current_messages[0].tool_call_id:
+                    msg_dict["tool_call_id"] = current_messages[0].tool_call_id
+                else:
+                    # If not provided, use empty string to satisfy schema
+                    msg_dict["tool_call_id"] = ""
+
+            messages.append(msg_dict)
+            current_messages = [msg]
+            current_role = msg.role
+
+    if len(current_messages) > 0:
+        # Convert role enum to string for OpenAI API compatibility
+        role_str = current_role.value if hasattr(current_role, "value") else current_role
+        msg_dict = {
+            "role": role_str,
+            "content": render_message(merge_messages(current_messages), allow_multipart=multipart),
+        }
+
+        # Only add tool_calls if not empty
+        tool_calls = merge_tool_calls(current_messages)
+        if tool_calls:
+            # Ensure arguments are JSON strings (OpenAI API requirement)
+            for tc in tool_calls:
+                if "function" in tc and "arguments" in tc["function"]:
+                    if isinstance(tc["function"]["arguments"], dict):
+                        tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
+            msg_dict["tool_calls"] = tool_calls
+
+        # Add tool_call_id for tool role messages (required by OpenAI API)
+        if role_str == "tool":
+            if current_messages and current_messages[0].tool_call_id:
+                msg_dict["tool_call_id"] = current_messages[0].tool_call_id
+            else:
+                # If not provided, use empty string to satisfy schema
+                msg_dict["tool_call_id"] = ""
+
+        messages.append(msg_dict)
 
     return messages
 
 
+def build_tools(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[Dict[str, Any]]]:
+    """Build tools in OpenAI API format.
+
+    Ensures tools are properly formatted:
+    - Wrapped in {"type": "function", "function": {...}}
+    - Parameters is never None (OpenAI API requirement)
+    """
+    if not tools:
+        return None
+
+    result = []
+    for tool in tools:
+        # Extract function definition
+        if "type" in tool and "function" in tool:
+            func_def = tool["function"].copy()
+        else:
+            func_def = tool.copy()
+
+        # Ensure parameters is not None (OpenAI API requirement)
+        if func_def.get("parameters") is None:
+            func_def["parameters"] = {"type": "object", "properties": {}}
+        # Also ensure properties within parameters is not None
+        elif func_def["parameters"].get("properties") is None:
+            func_def["parameters"]["properties"] = {}
+        else:
+            # Remove properties with null values (OpenAI API doesn't accept them)
+            properties = func_def["parameters"].get("properties", {})
+            if properties:
+                func_def["parameters"]["properties"] = {
+                    k: v for k, v in properties.items() if v is not None
+                }
+
+        # Wrap in OpenAI format
+        result.append({"type": "function", "function": func_def})
+
+    return result
+
+
 class StreamResponse:
     """Holds a single chunk of streamed response."""
     def __init__(self):
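The new build_tools helper normalizes whatever the caller passes into the wrapped OpenAI tool shape. A small illustration with an invented tool definition:

# Sketch only: what build_tools does to a bare function definition (values invented).
from inferencesh.models.llm import build_tools  # assumed import path

bare = {"name": "get_weather", "description": "Look up the weather", "parameters": None}
print(build_tools([bare]))
# [{'type': 'function',
#   'function': {'name': 'get_weather',
#                'description': 'Look up the weather',
#                'parameters': {'type': 'object', 'properties': {}}}}]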
@@ -216,7 +423,7 @@ class StreamResponse:
         self.tool_calls = None  # Changed from [] to None
         self.finish_reason = None
         self.timing_stats = {
-            "time_to_first_token": 0.0
+            "time_to_first_token": None,  # Changed from 0.0 to None
             "generation_time": 0.0,
             "reasoning_time": 0.0,
             "reasoning_tokens": 0,
@@ -232,8 +439,15 @@ class StreamResponse:
     def update_from_chunk(self, chunk: Dict[str, Any], timing: Any) -> None:
         """Update response state from a chunk."""
         # Update usage stats if present
-        if "usage" in chunk
-
+        if "usage" in chunk:
+            usage = chunk["usage"]
+            if usage is not None:
+                # Update usage stats preserving existing values if not provided
+                self.usage_stats.update({
+                    "prompt_tokens": usage.get("prompt_tokens", self.usage_stats["prompt_tokens"]),
+                    "completion_tokens": usage.get("completion_tokens", self.usage_stats["completion_tokens"]),
+                    "total_tokens": usage.get("total_tokens", self.usage_stats["total_tokens"])
+                })
 
         # Get the delta from the chunk
         delta = chunk.get("choices", [{}])[0]
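The usage branch now tolerates a null usage payload and only overwrites counters the chunk actually provides. A hedged sketch with an invented OpenAI-style streaming chunk:

# Sketch only: feeding one streaming chunk into StreamResponse (chunk values invented).
from inferencesh.models.llm import StreamResponse, timing_context  # assumed import path

chunk = {
    "usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15},
    "choices": [{"delta": {"content": "Hel"}, "finish_reason": None}],
}

with timing_context() as timing:
    timing.mark_first_token()
    response = StreamResponse()
    response.update_from_chunk(chunk, timing)
    print(response.content)                       # "Hel"
    print(response.usage_stats["prompt_tokens"])  # 12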
@@ -245,23 +459,34 @@ class StreamResponse:
             if message.get("tool_calls"):
                 self._update_tool_calls(message["tool_calls"])
             self.finish_reason = delta.get("finish_reason")
+            if self.finish_reason:
+                self.usage_stats["stop_reason"] = self.finish_reason
         elif "delta" in delta:
             delta_content = delta["delta"]
             self.content = delta_content.get("content", "")
             if delta_content.get("tool_calls"):
                 self._update_tool_calls(delta_content["tool_calls"])
             self.finish_reason = delta.get("finish_reason")
+            if self.finish_reason:
+                self.usage_stats["stop_reason"] = self.finish_reason
 
-        # Update timing stats
+        # Update timing stats
         timing_stats = timing.stats
-
-
-        tokens_per_second = (completion_tokens / generation_time) if generation_time > 0 and completion_tokens > 0 else 0.0
+        if self.timing_stats["time_to_first_token"] is None:
+            self.timing_stats["time_to_first_token"] = timing_stats["time_to_first_token"]
 
         self.timing_stats.update({
-
-            "
+            "generation_time": timing_stats["generation_time"],
+            "reasoning_time": timing_stats["reasoning_time"],
+            "reasoning_tokens": timing_stats["reasoning_tokens"]
         })
+
+        # Calculate tokens per second only if we have valid completion tokens and generation time
+        if self.usage_stats["completion_tokens"] > 0 and timing_stats["generation_time"] > 0:
+            self.timing_stats["tokens_per_second"] = (
+                self.usage_stats["completion_tokens"] / timing_stats["generation_time"]
+            )
+
 
     def _update_tool_calls(self, new_tool_calls: List[Dict[str, Any]]) -> None:
         """Update tool calls, handling both full and partial updates."""
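The tokens-per-second figure is now derived from the accumulated usage counters instead of the removed local variables, and only when both operands are positive. The arithmetic itself is simple:

# Worked example of the guarded tokens-per-second calculation (numbers invented).
completion_tokens = 48
generation_time = 1.6   # seconds since the first token
tokens_per_second = completion_tokens / generation_time if generation_time > 0 and completion_tokens > 0 else 0.0
print(tokens_per_second)  # 30.0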
@@ -292,22 +517,22 @@ class StreamResponse:
                     current_tool["function"]["arguments"] += func_delta["arguments"]
 
     def has_updates(self) -> bool:
-        """Check if this response has any content
-
-
-
-
-        buffer, output, _ = transformer(self.content, buffer)
+        """Check if this response has any content, tool call, or usage updates."""
+        has_content = bool(self.content)
+        has_tool_calls = bool(self.tool_calls)
+        has_usage = self.usage_stats["prompt_tokens"] > 0 or self.usage_stats["completion_tokens"] > 0
+        has_finish = bool(self.finish_reason)
 
-
-
-
-
-        #
-
-
+        return has_content or has_tool_calls or has_usage or has_finish
+
+    def to_output(self, buffer: str, transformer: Any) -> tuple[BaseLLMOutput, str]:
+        """Convert current state to LLMOutput."""
+        # Create usage object if we have stats
+        usage = None
+        if any(self.usage_stats.values()):
+            usage = LLMUsage(
                 stop_reason=self.usage_stats["stop_reason"],
-                time_to_first_token=self.timing_stats["time_to_first_token"],
+                time_to_first_token=self.timing_stats["time_to_first_token"] or 0.0,
                 tokens_per_second=self.timing_stats["tokens_per_second"],
                 prompt_tokens=self.usage_stats["prompt_tokens"],
                 completion_tokens=self.usage_stats["completion_tokens"],
@@ -315,6 +540,12 @@ class StreamResponse:
                 reasoning_time=self.timing_stats["reasoning_time"],
                 reasoning_tokens=self.timing_stats["reasoning_tokens"]
             )
+
+        buffer, output, _ = transformer(self.content, buffer, usage)
+
+        # Add tool calls if present and supported
+        if self.tool_calls and hasattr(output, 'tool_calls'):
+            output.tool_calls = self.tool_calls
 
         return output, buffer
 
@@ -327,6 +558,7 @@ class ResponseState:
         self.function_calls = None  # For future function calling support
         self.tool_calls = None  # List to accumulate tool calls
         self.current_tool_call = None  # Track current tool call being built
+        self.usage = None  # Add usage field
         self.state_changes = {
             "reasoning_started": False,
             "reasoning_ended": False,
@@ -338,7 +570,7 @@ class ResponseState:
 
 class ResponseTransformer:
     """Base class for transforming model responses."""
-    def __init__(self, output_cls: type[
+    def __init__(self, output_cls: type[BaseLLMOutput] = LLMOutput):
         self.state = ResponseState()
         self.output_cls = output_cls
         self.timing = None  # Will be set by stream_generate
@@ -381,26 +613,27 @@ class ResponseTransformer:
             text: Cleaned text to process for reasoning
         """
         # Default implementation for <think> style reasoning
-
+        # Check for tags in the complete buffer
+        if "<think>" in self.state.buffer and not self.state.state_changes["reasoning_started"]:
             self.state.state_changes["reasoning_started"] = True
             if self.timing:
                 self.timing.start_reasoning()
 
-
-
-
-
-
-
-
-
-
-        if 
-            self.state.
-            self.
-
-
-
+        # Extract content and handle end of reasoning
+        parts = self.state.buffer.split("<think>", 1)
+        if len(parts) > 1:
+            reasoning_text = parts[1]
+            end_parts = reasoning_text.split("</think>", 1)
+            self.state.reasoning = end_parts[0].strip()
+            self.state.response = end_parts[1].strip() if len(end_parts) > 1 else ""
+
+        # Check for end tag in complete buffer
+        if "</think>" in self.state.buffer and not self.state.state_changes["reasoning_ended"]:
+            self.state.state_changes["reasoning_ended"] = True
+            if self.timing:
+                # Estimate token count from character count (rough approximation)
+                token_count = len(self.state.reasoning) // 4
+                self.timing.end_reasoning(token_count)
         else:
             self.state.response = self.state.buffer
 
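The default reasoning handler now scans the complete buffer for <think> and </think> tags instead of per-chunk fragments. A hedged sketch of the resulting behaviour; it assumes transform_chunk, which is not shown in this diff, appends the piece to the internal buffer, and it pairs the transformer with an output class that has a reasoning field:

# Sketch only: default <think>-style reasoning extraction (assumptions noted above).
from inferencesh.models.llm import BaseLLMOutput, ReasoningMixin, ResponseTransformer

class ReasoningOutput(ReasoningMixin, BaseLLMOutput):
    pass

transformer = ResponseTransformer(output_cls=ReasoningOutput)
buffer, output, state = transformer("<think>recall European capitals</think>Paris.", "")
print(output.response)   # expected: "Paris."
print(output.reasoning)  # expected: "recall European capitals"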
@@ -449,28 +682,43 @@ class ResponseTransformer:
         Returns:
             Tuple of (buffer, LLMOutput, state_changes)
         """
+        # Build base output with required fields
+        output_data = {
+            "response": self.state.response.strip(),
+        }
+
+        # Add optional fields if they exist
+        if self.state.usage is not None:
+            output_data["usage"] = self.state.usage
+        if self.state.reasoning:
+            output_data["reasoning"] = self.state.reasoning.strip()
+        if self.state.function_calls:
+            output_data["function_calls"] = self.state.function_calls
+        if self.state.tool_calls:
+            output_data["tool_calls"] = self.state.tool_calls
+
+        output = self.output_cls(**output_data)
+
         return (
             self.state.buffer,
-
-                response=self.state.response.strip(),
-                reasoning=self.state.reasoning.strip() if self.state.reasoning else None,
-                function_calls=self.state.function_calls,
-                tool_calls=self.state.tool_calls
-            ),
+            output,
             self.state.state_changes
         )
 
-    def __call__(self, piece: str, buffer: str) -> tuple[str, LLMOutput, dict]:
+    def __call__(self, piece: str, buffer: str, usage: Optional[LLMUsage] = None) -> tuple[str, LLMOutput, dict]:
         """Transform a piece of text and return the result.
 
         Args:
             piece: New piece of text to transform
             buffer: Existing buffer content
+            usage: Optional usage statistics
 
         Returns:
             Tuple of (new_buffer, output, state_changes)
         """
         self.state.buffer = buffer
+        if usage is not None:
+            self.state.usage = usage
         self.transform_chunk(piece)
         return self.build_output()
 
@@ -483,42 +731,131 @@ def stream_generate(
     tool_choice: Optional[Dict[str, Any]] = None,
     temperature: float = 0.7,
     top_p: float = 0.95,
-    max_tokens: int = 4096,
     stop: Optional[List[str]] = None,
     verbose: bool = False,
-
+    output_cls: type[BaseLLMOutput] = LLMOutput,
+    kwargs: Optional[Dict[str, Any]] = None,
+) -> Generator[BaseLLMOutput, None, None]:
     """Stream generate from LLaMA.cpp model with timing and usage tracking."""
+
+    # Create queues for communication between threads
+    response_queue = Queue()
+    error_queue = Queue()
+    keep_alive_queue = Queue()
+
+    # Set the output class for the transformer
+    transformer.output_cls = output_cls
+
+    def _generate_worker():
+        """Worker thread to run the model generation."""
+        try:
+            # Build completion kwargs
+            completion_kwargs = {
+                "messages": messages,
+                "stream": True,
+                "temperature": temperature,
+                "top_p": top_p,
+                "stop": stop,
+            }
+            if kwargs:
+                completion_kwargs.update(kwargs)
+            if tools is not None:
+                completion_kwargs["tools"] = tools
+            if tool_choice is not None:
+                completion_kwargs["tool_choice"] = tool_choice
+
+            # Signal that we're starting
+            keep_alive_queue.put(("init", time.time()))
+
+            completion = model.create_chat_completion(**completion_kwargs)
+
+            for chunk in completion:
+                response_queue.put(("chunk", chunk))
+                # Update keep-alive timestamp
+                keep_alive_queue.put(("alive", time.time()))
+
+            # Signal completion
+            response_queue.put(("done", None))
+
+        except Exception as e:
+            # Preserve the full exception with traceback
+            import sys
+            error_queue.put((e, sys.exc_info()[2]))
+            response_queue.put(("error", str(e)))
+
     with timing_context() as timing:
         transformer.timing = timing
 
-        #
-
-
-            "stream": True,
-            "temperature": temperature,
-            "top_p": top_p,
-            "max_tokens": max_tokens,
-            "stop": stop
-        }
-        if tools is not None:
-            completion_kwargs["tools"] = tools
-        if tool_choice is not None:
-            completion_kwargs["tool_choice"] = tool_choice
+        # Start generation thread
+        generation_thread = Thread(target=_generate_worker, daemon=True)
+        generation_thread.start()
 
         # Initialize response state
         response = StreamResponse()
         buffer = ""
 
+        # Keep-alive tracking
+        last_activity = time.time()
+        init_timeout = 30.0  # 30 seconds for initial response
+        chunk_timeout = 10.0  # 10 seconds between chunks
+        chunks_begun = False
+
         try:
-
+            # Wait for initial setup
+            try:
+                msg_type, timestamp = keep_alive_queue.get(timeout=init_timeout)
+                if msg_type != "init":
+                    raise RuntimeError("Unexpected initialization message")
+                last_activity = timestamp
+            except Empty:
+                raise RuntimeError(f"Model failed to initialize within {init_timeout} seconds")
 
-
+            while True:
+                # Check for errors - now with proper exception chaining
+                if not error_queue.empty():
+                    exc, tb = error_queue.get()
+                    if isinstance(exc, Exception):
+                        raise exc.with_traceback(tb)
+                    else:
+                        raise RuntimeError(f"Unknown error in worker thread: {exc}")
+
+                # Check keep-alive
+                try:
+                    while not keep_alive_queue.empty():
+                        _, timestamp = keep_alive_queue.get_nowait()
+                        last_activity = timestamp
+                except Empty:
+                    # Ignore empty queue - this is expected
+                    pass
+
+                # Check for timeout
+                if chunks_begun and time.time() - last_activity > chunk_timeout:
+                    raise RuntimeError(f"No response from model for {chunk_timeout} seconds")
+
+                # Get next chunk
+                try:
+                    msg_type, data = response_queue.get(timeout=0.1)
+                except Empty:
+                    continue
+
+                if msg_type == "error":
+                    # If we get an error message but no exception in error_queue,
+                    # create a new error
+                    raise RuntimeError(f"Generation error: {data}")
+                elif msg_type == "done":
+                    break
+
+                chunk = data
+
                 if verbose:
                     print(chunk)
-
+
+                # Mark first token time
                 if not timing.first_token_time:
                     timing.mark_first_token()
 
+                chunks_begun = True
+
                 # Update response state from chunk
                 response.update_from_chunk(chunk, timing)
 
@@ -530,7 +867,19 @@ def stream_generate(
|
|
|
530
867
|
# Break if we're done
|
|
531
868
|
if response.finish_reason:
|
|
532
869
|
break
|
|
870
|
+
|
|
871
|
+
# Wait for generation thread to finish
|
|
872
|
+
if generation_thread.is_alive():
|
|
873
|
+
generation_thread.join(timeout=5.0) # Increased timeout to 5 seconds
|
|
874
|
+
if generation_thread.is_alive():
|
|
875
|
+
# Thread didn't finish - this shouldn't happen normally
|
|
876
|
+
raise RuntimeError("Generation thread failed to finish")
|
|
533
877
|
|
|
534
878
|
except Exception as e:
|
|
535
|
-
#
|
|
536
|
-
|
|
879
|
+
# Check if there's a thread error we should chain with
|
|
880
|
+
if not error_queue.empty():
|
|
881
|
+
thread_exc, thread_tb = error_queue.get()
|
|
882
|
+
if isinstance(thread_exc, Exception):
|
|
883
|
+
raise e from thread_exc
|
|
884
|
+
# If no thread error, raise the original exception
|
|
885
|
+
raise
|