prompture 0.0.35__py3-none-any.whl → 0.0.40.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +132 -3
- prompture/_version.py +2 -2
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +208 -17
- prompture/async_core.py +16 -0
- prompture/async_driver.py +63 -0
- prompture/async_groups.py +551 -0
- prompture/conversation.py +222 -18
- prompture/core.py +46 -12
- prompture/cost_mixin.py +37 -0
- prompture/discovery.py +132 -44
- prompture/driver.py +77 -0
- prompture/drivers/__init__.py +5 -1
- prompture/drivers/async_azure_driver.py +11 -5
- prompture/drivers/async_claude_driver.py +184 -9
- prompture/drivers/async_google_driver.py +222 -28
- prompture/drivers/async_grok_driver.py +11 -5
- prompture/drivers/async_groq_driver.py +11 -5
- prompture/drivers/async_lmstudio_driver.py +74 -5
- prompture/drivers/async_ollama_driver.py +13 -3
- prompture/drivers/async_openai_driver.py +162 -5
- prompture/drivers/async_openrouter_driver.py +11 -5
- prompture/drivers/async_registry.py +5 -1
- prompture/drivers/azure_driver.py +10 -4
- prompture/drivers/claude_driver.py +17 -1
- prompture/drivers/google_driver.py +227 -33
- prompture/drivers/grok_driver.py +11 -5
- prompture/drivers/groq_driver.py +11 -5
- prompture/drivers/lmstudio_driver.py +73 -8
- prompture/drivers/ollama_driver.py +16 -5
- prompture/drivers/openai_driver.py +26 -11
- prompture/drivers/openrouter_driver.py +11 -5
- prompture/drivers/vision_helpers.py +153 -0
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/ledger.py +252 -0
- prompture/model_rates.py +112 -2
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/serialization.py +218 -0
- prompture/settings.py +1 -0
- prompture-0.0.40.dev1.dist-info/METADATA +369 -0
- prompture-0.0.40.dev1.dist-info/RECORD +78 -0
- prompture-0.0.35.dist-info/METADATA +0 -464
- prompture-0.0.35.dist-info/RECORD +0 -66
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/top_level.txt +0 -0
prompture/drivers/async_google_driver.py

@@ -4,6 +4,8 @@ from __future__ import annotations
 
 import logging
 import os
+import uuid
+from collections.abc import AsyncIterator
 from typing import Any
 
 import google.generativeai as genai
@@ -20,6 +22,9 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
 
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
+    supports_tool_use = True
+    supports_streaming = True
 
     MODEL_PRICING = GoogleDriver.MODEL_PRICING
     _PRICING_UNIT = 1_000_000
@@ -48,18 +53,51 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
         completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
         return round(prompt_cost + completion_cost, 6)
 
+    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
+        """Extract token counts from response, falling back to character estimation."""
+        usage = getattr(response, "usage_metadata", None)
+        if usage:
+            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
+            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
+        else:
+            # Fallback: estimate from character counts
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
+            completion_chars = len(response.text) if response.text else 0
+            prompt_tokens = total_prompt_chars // 4
+            completion_tokens = completion_chars // 4
+            total_tokens = prompt_tokens + completion_tokens
+            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+        return {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(cost, 6),
+        }
+
     supports_messages = True
 
-
-
-        return await self._do_generate(messages, options)
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages
 
-
-        return await self._do_generate(messages, options)
+        return _prepare_google_vision_messages(messages)
 
-
-        self, messages: list[dict[str,
-    ) -> dict[str, Any]:
+    def _build_generation_args(
+        self, messages: list[dict[str, Any]], options: dict[str, Any] | None = None
+    ) -> tuple[Any, dict[str, Any], dict[str, Any]]:
+        """Parse messages and options into (gen_input, gen_kwargs, model_kwargs)."""
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
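The `_extract_usage_metadata` helper added above prefers the SDK's `usage_metadata` token counts and only falls back to a character-based estimate (characters divided by four) when the response carries no usage data. A small illustration of that fallback arithmetic, with an assumed completion length:

```python
# Fallback arithmetic from _extract_usage_metadata above: tokens are estimated as
# chars // 4 when the Gemini response exposes no usage_metadata.
prompt_chars = len("Summarize the plot of Dracula in one paragraph.")  # 47 characters
completion_chars = 600  # assumed length of the model's reply

estimate = {
    "prompt_tokens": prompt_chars // 4,                           # 11
    "completion_tokens": completion_chars // 4,                   # 150
    "total_tokens": prompt_chars // 4 + completion_chars // 4,    # 161
}
print(estimate)
```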
prompture/drivers/async_google_driver.py (continued)

@@ -90,37 +128,65 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
             role = msg.get("role", "user")
             content = msg.get("content", "")
             if role == "system":
-                system_instruction = content
+                system_instruction = content if isinstance(content, str) else str(content)
             else:
                 gemini_role = "model" if role == "assistant" else "user"
-
+                if msg.get("_vision_parts"):
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
+
+        # For a single message, unwrap only if it has exactly one string part
+        if len(contents) == 1:
+            parts = contents[0]["parts"]
+            if len(parts) == 1 and isinstance(parts[0], str):
+                gen_input = parts[0]
+            else:
+                gen_input = contents
+        else:
+            gen_input = contents
+
+        model_kwargs: dict[str, Any] = {}
+        if system_instruction:
+            model_kwargs["system_instruction"] = system_instruction
+
+        gen_kwargs: dict[str, Any] = {
+            "generation_config": generation_config if generation_config else None,
+            "safety_settings": safety_settings if safety_settings else None,
+        }
+
+        return gen_input, gen_kwargs, model_kwargs
+
+    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(self._prepare_messages(messages), options)
+
+    async def _do_generate(
+        self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
+
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "google",
+            self.model,
+            using_json_schema=bool((options or {}).get("json_schema")),
+        )
 
         try:
-            model_kwargs: dict[str, Any] = {}
-            if system_instruction:
-                model_kwargs["system_instruction"] = system_instruction
             model = genai.GenerativeModel(self.model, **model_kwargs)
-
-            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
-            response = await model.generate_content_async(
-                gen_input,
-                generation_config=generation_config if generation_config else None,
-                safety_settings=safety_settings if safety_settings else None,
-            )
+            response = await model.generate_content_async(gen_input, **gen_kwargs)
 
             if not response.text:
                 raise ValueError("Empty response from model")
 
-
-            completion_chars = len(response.text)
-
-            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+            usage_meta = self._extract_usage_metadata(response, messages)
 
             meta = {
-
-                "completion_chars": completion_chars,
-                "total_chars": total_prompt_chars + completion_chars,
-                "cost": total_cost,
+                **usage_meta,
                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
             }
@@ -130,3 +196,131 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
             raise RuntimeError(f"Google API request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool/function calls (async)."""
+        model = options.get("model", self.model)
+        self._validate_model_capabilities("google", model, using_tool_use=True)
+
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        # Convert tools from OpenAI format to Gemini function declarations
+        function_declarations = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                fn = t["function"]
+                decl = {
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                }
+                params = fn.get("parameters")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+            elif "name" in t:
+                decl = {"name": t["name"], "description": t.get("description", "")}
+                params = t.get("parameters") or t.get("input_schema")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
+            response = await model.generate_content_async(gen_input, tools=gemini_tools, **gen_kwargs)
+
+            usage_meta = self._extract_usage_metadata(response, messages)
+            meta = {
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            text = ""
+            tool_calls_out: list[dict[str, Any]] = []
+            stop_reason = "stop"
+
+            for candidate in response.candidates:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text += part.text
+                    if hasattr(part, "function_call") and part.function_call.name:
+                        fc = part.function_call
+                        tool_calls_out.append({
+                            "id": str(uuid.uuid4()),
+                            "name": fc.name,
+                            "arguments": dict(fc.args) if fc.args else {},
+                        })
+
+                finish_reason = getattr(candidate, "finish_reason", None)
+                if finish_reason is not None:
+                    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
+                    stop_reason = reason_map.get(finish_reason, "stop")
+
+            if tool_calls_out:
+                stop_reason = "tool_use"
+
+            return {
+                "text": text,
+                "meta": meta,
+                "tool_calls": tool_calls_out,
+                "stop_reason": stop_reason,
+            }
+
+        except Exception as e:
+            logger.error(f"Google API tool call request failed: {e}")
+            raise RuntimeError(f"Google API tool call request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via Gemini async streaming API."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+            response = await model.generate_content_async(gen_input, stream=True, **gen_kwargs)
+
+            full_text = ""
+            async for chunk in response:
+                chunk_text = getattr(chunk, "text", None) or ""
+                if chunk_text:
+                    full_text += chunk_text
+                    yield {"type": "delta", "text": chunk_text}
+
+            # After iteration completes, usage_metadata should be available
+            usage_meta = self._extract_usage_metadata(response, messages)
+
+            yield {
+                "type": "done",
+                "text": full_text,
+                "meta": {
+                    **usage_meta,
+                    "raw_response": {},
+                    "model_name": self.model,
+                },
+            }
+
+        except Exception as e:
+            logger.error(f"Google API streaming request failed: {e}")
+            raise RuntimeError(f"Google API streaming request failed: {e}") from e
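Taken together, these hunks give the async Google driver tool calling (`generate_messages_with_tools`) and streaming (`generate_messages_stream`). The sketch below shows how those entry points might be exercised; it assumes an already-constructed `AsyncGoogleDriver` instance (construction is not part of these hunks), and the `get_weather` tool and its schema are purely illustrative.

```python
from typing import Any


async def demo(driver: Any) -> None:
    """Minimal sketch against the new tool-use and streaming methods shown above."""
    messages = [{"role": "user", "content": "What's the weather in Paris?"}]

    # OpenAI-style tool definition; the driver converts this into a Gemini
    # function declaration (see generate_messages_with_tools in the diff).
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]

    result = await driver.generate_messages_with_tools(messages, tools, options={})
    if result["stop_reason"] == "tool_use":
        for call in result["tool_calls"]:
            print(call["id"], call["name"], call["arguments"])

    # Streaming: "delta" chunks first, then a final "done" chunk with usage metadata.
    async for chunk in driver.generate_messages_stream(messages, options={}):
        if chunk["type"] == "delta":
            print(chunk["text"], end="", flush=True)
        elif chunk["type"] == "done":
            print("\n", chunk["meta"]["total_tokens"], "tokens")
```

Note that tool-call ids are synthesized with `uuid.uuid4()` in the driver, so they are only stable within a single response.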
prompture/drivers/async_grok_driver.py

@@ -14,6 +14,7 @@ from .grok_driver import GrokDriver
 
 class AsyncGrokDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = GrokDriver.MODEL_PRICING
     _PRICING_UNIT = 1_000_000
@@ -25,12 +26,17 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
@@ -38,9 +44,9 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
 
         model = options.get("model", self.model)
 
-
-        tokens_param =
-        supports_temperature =
+        model_config = self._get_model_config("grok", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
@@ -82,7 +88,7 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp,
             "model_name": model,
         }
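The Grok driver (and the Groq and LM Studio drivers below) now routes `generate_messages` through `_prepare_openai_vision_messages`. That helper lives in the new `vision_helpers.py`, which is not captured in these hunks, so the snippet below only illustrates the OpenAI-style content-part layout such a helper conventionally targets; treat the exact shape as an assumption.

```python
# Assumed OpenAI-style vision message layout (vision_helpers itself is not shown
# in this diff): multimodal content is expressed as a list of typed parts.
vision_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this chart."},
        {
            "type": "image_url",
            "image_url": {"url": "data:image/png;base64,<BASE64_IMAGE_BYTES>"},
        },
    ],
}
```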
prompture/drivers/async_groq_driver.py

@@ -17,6 +17,7 @@ from .groq_driver import GroqDriver
 
 class AsyncGroqDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = GroqDriver.MODEL_PRICING
 
@@ -30,12 +31,17 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
@@ -43,9 +49,9 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
 
         model = options.get("model", self.model)
 
-
-        tokens_param =
-        supports_temperature =
+        model_config = self._get_model_config("groq", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 0.7, "max_tokens": 512, **options}
 
@@ -75,7 +81,7 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp.model_dump(),
             "model_name": model,
         }
prompture/drivers/async_lmstudio_driver.py

@@ -15,22 +15,48 @@ logger = logging.getLogger(__name__)
 
 class AsyncLMStudioDriver(AsyncDriver):
     supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True
 
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
-    def __init__(
+    def __init__(
+        self,
+        endpoint: str | None = None,
+        model: str = "deepseek/deepseek-r1-0528-qwen3-8b",
+        api_key: str | None = None,
+    ):
         self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
         self.model = model
         self.options: dict[str, Any] = {}
 
+        # Derive base_url once for reuse across management endpoints
+        self.base_url = self.endpoint.split("/v1/")[0]
+
+        # API key for LM Studio 0.4.0+ authentication
+        self.api_key = api_key or os.getenv("LMSTUDIO_API_KEY")
+        self._headers = self._build_headers()
+
     supports_messages = True
 
+    def _build_headers(self) -> dict[str, str]:
+        """Build request headers, including auth if an API key is configured."""
+        headers: dict[str, str] = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(
         self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
@@ -45,13 +71,25 @@ class AsyncLMStudioDriver(AsyncDriver):
             "temperature": merged_options.get("temperature", 0.7),
         }
 
-        # Native JSON mode support
+        # Native JSON mode support (LM Studio requires json_schema, not json_object)
         if merged_options.get("json_mode"):
-
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                payload["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                # No schema provided — omit response_format entirely;
+                # LM Studio rejects "json_object" type.
+                pass
 
         async with httpx.AsyncClient() as client:
             try:
-                r = await client.post(self.endpoint, json=payload, timeout=120)
+                r = await client.post(self.endpoint, json=payload, headers=self._headers, timeout=120)
                 r.raise_for_status()
                 response_data = r.json()
             except Exception as e:
@@ -77,3 +115,34 @@ class AsyncLMStudioDriver(AsyncDriver):
             }
 
         return {"text": text, "meta": meta}
+
+    # -- Model management (LM Studio 0.4.0+) ----------------------------------
+
+    async def list_models(self) -> list[dict[str, Any]]:
+        """List currently loaded models via GET /v1/models (OpenAI-compatible)."""
+        url = f"{self.base_url}/v1/models"
+        async with httpx.AsyncClient() as client:
+            r = await client.get(url, headers=self._headers, timeout=10)
+            r.raise_for_status()
+            data = r.json()
+            return data.get("data", [])
+
+    async def load_model(self, model: str, context_length: int | None = None) -> dict[str, Any]:
+        """Load a model into LM Studio via POST /api/v1/models/load."""
+        url = f"{self.base_url}/api/v1/models/load"
+        payload: dict[str, Any] = {"model": model}
+        if context_length is not None:
+            payload["context_length"] = context_length
+        async with httpx.AsyncClient() as client:
+            r = await client.post(url, json=payload, headers=self._headers, timeout=120)
+            r.raise_for_status()
+            return r.json()
+
+    async def unload_model(self, model: str) -> dict[str, Any]:
+        """Unload a model from LM Studio via POST /api/v1/models/unload."""
+        url = f"{self.base_url}/api/v1/models/unload"
+        payload = {"instance_id": model}
+        async with httpx.AsyncClient() as client:
+            r = await client.post(url, json=payload, headers=self._headers, timeout=30)
+            r.raise_for_status()
+            return r.json()
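Beyond vision and schema-constrained JSON, the LM Studio driver gains bearer-token auth and a small model-management surface (`list_models`, `load_model`, `unload_model`). A usage sketch, assuming a local LM Studio server is running; the model id and API key below are placeholders, not values from this diff.

```python
import asyncio

from prompture.drivers.async_lmstudio_driver import AsyncLMStudioDriver


async def main() -> None:
    driver = AsyncLMStudioDriver(
        endpoint="http://127.0.0.1:1234/v1/chat/completions",
        api_key="sk-local-example",  # placeholder; only needed when LM Studio auth is enabled
    )

    loaded = await driver.list_models()  # GET {base_url}/v1/models
    print([m.get("id") for m in loaded])

    # Placeholder model id; load_model POSTs to /api/v1/models/load.
    await driver.load_model("qwen2.5-7b-instruct", context_length=8192)
    result = await driver.generate("Say hello in one sentence.")
    print(result["text"])
    await driver.unload_model("qwen2.5-7b-instruct")


if __name__ == "__main__":
    asyncio.run(main())
```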
prompture/drivers/async_ollama_driver.py

@@ -15,6 +15,8 @@ logger = logging.getLogger(__name__)
 
 class AsyncOllamaDriver(AsyncDriver):
     supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True
 
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
@@ -25,6 +27,11 @@ class AsyncOllamaDriver(AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_ollama_vision_messages
+
+        return _prepare_ollama_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
         merged_options = self.options.copy()
         if options:
@@ -36,9 +43,10 @@ class AsyncOllamaDriver(AsyncDriver):
             "stream": False,
         }
 
-        # Native JSON mode support
+        # Native JSON mode / structured output support
        if merged_options.get("json_mode"):
-
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
 
         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
@@ -74,6 +82,7 @@ class AsyncOllamaDriver(AsyncDriver):
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+        messages = self._prepare_messages(messages)
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
@@ -88,7 +97,8 @@ class AsyncOllamaDriver(AsyncDriver):
         }
 
         if merged_options.get("json_mode"):
-
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
 
         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
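The Ollama changes mirror the LM Studio ones: when `json_mode` is set and a `json_schema` is supplied, the schema is passed straight through as Ollama's `format` field, otherwise the driver falls back to `format="json"`. A sketch of requesting schema-constrained output; the default constructor and the shape of the returned dict are assumptions, since neither appears in these hunks.

```python
import asyncio

from prompture.drivers.async_ollama_driver import AsyncOllamaDriver


async def main() -> None:
    driver = AsyncOllamaDriver()  # assumption: default constructor, not shown in this diff

    schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }
    result = await driver.generate(
        "Extract the person: Ada Lovelace, 36 years old.",
        {"json_mode": True, "json_schema": schema},  # becomes payload["format"] = schema
    )
    print(result["text"])  # assumed {"text": ..., "meta": ...} shape, as in the other drivers


if __name__ == "__main__":
    asyncio.run(main())
```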