prompture 0.0.35__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +120 -2
- prompture/_version.py +2 -2
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +199 -17
- prompture/async_driver.py +24 -0
- prompture/async_groups.py +551 -0
- prompture/conversation.py +213 -18
- prompture/core.py +30 -12
- prompture/discovery.py +24 -1
- prompture/driver.py +38 -0
- prompture/drivers/__init__.py +5 -1
- prompture/drivers/async_azure_driver.py +7 -1
- prompture/drivers/async_claude_driver.py +7 -1
- prompture/drivers/async_google_driver.py +212 -28
- prompture/drivers/async_grok_driver.py +7 -1
- prompture/drivers/async_groq_driver.py +7 -1
- prompture/drivers/async_lmstudio_driver.py +74 -5
- prompture/drivers/async_ollama_driver.py +13 -3
- prompture/drivers/async_openai_driver.py +7 -1
- prompture/drivers/async_openrouter_driver.py +7 -1
- prompture/drivers/async_registry.py +5 -1
- prompture/drivers/azure_driver.py +7 -1
- prompture/drivers/claude_driver.py +7 -1
- prompture/drivers/google_driver.py +217 -33
- prompture/drivers/grok_driver.py +7 -1
- prompture/drivers/groq_driver.py +7 -1
- prompture/drivers/lmstudio_driver.py +73 -8
- prompture/drivers/ollama_driver.py +16 -5
- prompture/drivers/openai_driver.py +7 -1
- prompture/drivers/openrouter_driver.py +7 -1
- prompture/drivers/vision_helpers.py +153 -0
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/serialization.py +218 -0
- prompture/settings.py +1 -0
- prompture-0.0.38.dev2.dist-info/METADATA +369 -0
- prompture-0.0.38.dev2.dist-info/RECORD +77 -0
- prompture-0.0.35.dist-info/METADATA +0 -464
- prompture-0.0.35.dist-info/RECORD +0 -66
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
prompture/drivers/async_google_driver.py

@@ -4,6 +4,8 @@ from __future__ import annotations
 
 import logging
 import os
+import uuid
+from collections.abc import AsyncIterator
 from typing import Any
 
 import google.generativeai as genai
@@ -20,6 +22,9 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
 
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
+    supports_tool_use = True
+    supports_streaming = True
 
     MODEL_PRICING = GoogleDriver.MODEL_PRICING
     _PRICING_UNIT = 1_000_000
@@ -48,18 +53,51 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
         completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
         return round(prompt_cost + completion_cost, 6)
 
+    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
+        """Extract token counts from response, falling back to character estimation."""
+        usage = getattr(response, "usage_metadata", None)
+        if usage:
+            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
+            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
+        else:
+            # Fallback: estimate from character counts
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
+            completion_chars = len(response.text) if response.text else 0
+            prompt_tokens = total_prompt_chars // 4
+            completion_tokens = completion_chars // 4
+            total_tokens = prompt_tokens + completion_tokens
+            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+        return {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(cost, 6),
+        }
+
     supports_messages = True
 
-
-
-        return await self._do_generate(messages, options)
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages
 
-
-        return await self._do_generate(messages, options)
+        return _prepare_google_vision_messages(messages)
 
-
-        self, messages: list[dict[str,
-    ) -> dict[str, Any]:
+    def _build_generation_args(
+        self, messages: list[dict[str, Any]], options: dict[str, Any] | None = None
+    ) -> tuple[Any, dict[str, Any], dict[str, Any]]:
+        """Parse messages and options into (gen_input, gen_kwargs, model_kwargs)."""
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
@@ -90,37 +128,58 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
             role = msg.get("role", "user")
             content = msg.get("content", "")
             if role == "system":
-                system_instruction = content
+                system_instruction = content if isinstance(content, str) else str(content)
             else:
                 gemini_role = "model" if role == "assistant" else "user"
-
+                if msg.get("_vision_parts"):
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
+
+        # For a single message, unwrap only if it has exactly one string part
+        if len(contents) == 1:
+            parts = contents[0]["parts"]
+            if len(parts) == 1 and isinstance(parts[0], str):
+                gen_input = parts[0]
+            else:
+                gen_input = contents
+        else:
+            gen_input = contents
+
+        model_kwargs: dict[str, Any] = {}
+        if system_instruction:
+            model_kwargs["system_instruction"] = system_instruction
+
+        gen_kwargs: dict[str, Any] = {
+            "generation_config": generation_config if generation_config else None,
+            "safety_settings": safety_settings if safety_settings else None,
+        }
+
+        return gen_input, gen_kwargs, model_kwargs
+
+    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(self._prepare_messages(messages), options)
+
+    async def _do_generate(
+        self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
 
         try:
-            model_kwargs: dict[str, Any] = {}
-            if system_instruction:
-                model_kwargs["system_instruction"] = system_instruction
             model = genai.GenerativeModel(self.model, **model_kwargs)
-
-            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
-            response = await model.generate_content_async(
-                gen_input,
-                generation_config=generation_config if generation_config else None,
-                safety_settings=safety_settings if safety_settings else None,
-            )
+            response = await model.generate_content_async(gen_input, **gen_kwargs)
 
             if not response.text:
                 raise ValueError("Empty response from model")
 
-
-            completion_chars = len(response.text)
-
-            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+            usage_meta = self._extract_usage_metadata(response, messages)
 
             meta = {
-
-                "completion_chars": completion_chars,
-                "total_chars": total_prompt_chars + completion_chars,
-                "cost": total_cost,
+                **usage_meta,
                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
             }
@@ -130,3 +189,128 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
             raise RuntimeError(f"Google API request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool/function calls (async)."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        # Convert tools from OpenAI format to Gemini function declarations
+        function_declarations = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                fn = t["function"]
+                decl = {
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                }
+                params = fn.get("parameters")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+            elif "name" in t:
+                decl = {"name": t["name"], "description": t.get("description", "")}
+                params = t.get("parameters") or t.get("input_schema")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
+            response = await model.generate_content_async(gen_input, tools=gemini_tools, **gen_kwargs)
+
+            usage_meta = self._extract_usage_metadata(response, messages)
+            meta = {
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            text = ""
+            tool_calls_out: list[dict[str, Any]] = []
+            stop_reason = "stop"
+
+            for candidate in response.candidates:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text += part.text
+                    if hasattr(part, "function_call") and part.function_call.name:
+                        fc = part.function_call
+                        tool_calls_out.append({
+                            "id": str(uuid.uuid4()),
+                            "name": fc.name,
+                            "arguments": dict(fc.args) if fc.args else {},
+                        })
+
+                finish_reason = getattr(candidate, "finish_reason", None)
+                if finish_reason is not None:
+                    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
+                    stop_reason = reason_map.get(finish_reason, "stop")
+
+            if tool_calls_out:
+                stop_reason = "tool_use"
+
+            return {
+                "text": text,
+                "meta": meta,
+                "tool_calls": tool_calls_out,
+                "stop_reason": stop_reason,
+            }
+
+        except Exception as e:
+            logger.error(f"Google API tool call request failed: {e}")
+            raise RuntimeError(f"Google API tool call request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via Gemini async streaming API."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+            response = await model.generate_content_async(gen_input, stream=True, **gen_kwargs)
+
+            full_text = ""
+            async for chunk in response:
+                chunk_text = getattr(chunk, "text", None) or ""
+                if chunk_text:
+                    full_text += chunk_text
+                    yield {"type": "delta", "text": chunk_text}
+
+            # After iteration completes, usage_metadata should be available
+            usage_meta = self._extract_usage_metadata(response, messages)
+
+            yield {
+                "type": "done",
+                "text": full_text,
+                "meta": {
+                    **usage_meta,
+                    "raw_response": {},
+                    "model_name": self.model,
+                },
+            }
+
+        except Exception as e:
+            logger.error(f"Google API streaming request failed: {e}")
+            raise RuntimeError(f"Google API streaming request failed: {e}") from e
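The new generate_messages_with_tools path above converts OpenAI-style tool definitions into Gemini function declarations and returns normalized tool calls alongside the usual text and meta fields. A minimal sketch of calling it, assuming the class is importable from prompture/drivers/async_google_driver.py as listed in the file summary; the constructor arguments and the model name are illustrative, since __init__ is not part of this diff:

```python
import asyncio

from prompture.drivers.async_google_driver import AsyncGoogleDriver

# OpenAI-style tool definition; the diff above shows these being converted
# into Gemini function declarations inside generate_messages_with_tools().
WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}


async def main() -> None:
    # Constructor arguments are assumptions; __init__ is not shown in this diff.
    driver = AsyncGoogleDriver(model="gemini-1.5-flash")

    result = await driver.generate_messages_with_tools(
        messages=[{"role": "user", "content": "What's the weather in Oslo?"}],
        tools=[WEATHER_TOOL],
        options={},
    )

    # stop_reason is "tool_use" whenever the model emitted function calls.
    if result["stop_reason"] == "tool_use":
        for call in result["tool_calls"]:
            print(call["id"], call["name"], call["arguments"])
    else:
        print(result["text"], result["meta"])


asyncio.run(main())
```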
prompture/drivers/async_grok_driver.py

@@ -14,6 +14,7 @@ from .grok_driver import GrokDriver
 
 class AsyncGrokDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = GrokDriver.MODEL_PRICING
     _PRICING_UNIT = 1_000_000
@@ -25,12 +26,17 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
prompture/drivers/async_groq_driver.py

@@ -17,6 +17,7 @@ from .groq_driver import GroqDriver
 
 class AsyncGroqDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = GroqDriver.MODEL_PRICING
 
@@ -30,12 +31,17 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
prompture/drivers/async_lmstudio_driver.py

@@ -15,22 +15,48 @@ logger = logging.getLogger(__name__)
 
 class AsyncLMStudioDriver(AsyncDriver):
     supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True
 
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
-    def __init__(
+    def __init__(
+        self,
+        endpoint: str | None = None,
+        model: str = "deepseek/deepseek-r1-0528-qwen3-8b",
+        api_key: str | None = None,
+    ):
         self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
         self.model = model
         self.options: dict[str, Any] = {}
 
+        # Derive base_url once for reuse across management endpoints
+        self.base_url = self.endpoint.split("/v1/")[0]
+
+        # API key for LM Studio 0.4.0+ authentication
+        self.api_key = api_key or os.getenv("LMSTUDIO_API_KEY")
+        self._headers = self._build_headers()
+
     supports_messages = True
 
+    def _build_headers(self) -> dict[str, str]:
+        """Build request headers, including auth if an API key is configured."""
+        headers: dict[str, str] = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(
         self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
@@ -45,13 +71,25 @@ class AsyncLMStudioDriver(AsyncDriver):
             "temperature": merged_options.get("temperature", 0.7),
         }
 
-        # Native JSON mode support
+        # Native JSON mode support (LM Studio requires json_schema, not json_object)
        if merged_options.get("json_mode"):
-
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                payload["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                # No schema provided — omit response_format entirely;
+                # LM Studio rejects "json_object" type.
+                pass
 
         async with httpx.AsyncClient() as client:
             try:
-                r = await client.post(self.endpoint, json=payload, timeout=120)
+                r = await client.post(self.endpoint, json=payload, headers=self._headers, timeout=120)
                 r.raise_for_status()
                 response_data = r.json()
             except Exception as e:
@@ -77,3 +115,34 @@ class AsyncLMStudioDriver(AsyncDriver):
         }
 
         return {"text": text, "meta": meta}
+
+    # -- Model management (LM Studio 0.4.0+) ----------------------------------
+
+    async def list_models(self) -> list[dict[str, Any]]:
+        """List currently loaded models via GET /v1/models (OpenAI-compatible)."""
+        url = f"{self.base_url}/v1/models"
+        async with httpx.AsyncClient() as client:
+            r = await client.get(url, headers=self._headers, timeout=10)
+            r.raise_for_status()
+            data = r.json()
+        return data.get("data", [])
+
+    async def load_model(self, model: str, context_length: int | None = None) -> dict[str, Any]:
+        """Load a model into LM Studio via POST /api/v1/models/load."""
+        url = f"{self.base_url}/api/v1/models/load"
+        payload: dict[str, Any] = {"model": model}
+        if context_length is not None:
+            payload["context_length"] = context_length
+        async with httpx.AsyncClient() as client:
+            r = await client.post(url, json=payload, headers=self._headers, timeout=120)
+            r.raise_for_status()
+        return r.json()
+
+    async def unload_model(self, model: str) -> dict[str, Any]:
+        """Unload a model from LM Studio via POST /api/v1/models/unload."""
+        url = f"{self.base_url}/api/v1/models/unload"
+        payload = {"instance_id": model}
+        async with httpx.AsyncClient() as client:
+            r = await client.post(url, json=payload, headers=self._headers, timeout=30)
+            r.raise_for_status()
+        return r.json()
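Besides the api_key handling, AsyncLMStudioDriver gains model-management helpers that wrap LM Studio's REST endpoints. A usage sketch against the methods added above, assuming the class is importable from prompture/drivers/async_lmstudio_driver.py as listed in the file summary; the model identifier passed to load_model/unload_model is a placeholder:

```python
import asyncio

from prompture.drivers.async_lmstudio_driver import AsyncLMStudioDriver


async def main() -> None:
    # api_key is only needed when LM Studio 0.4.0+ authentication is enabled;
    # per the diff it also falls back to the LMSTUDIO_API_KEY environment variable.
    driver = AsyncLMStudioDriver(
        endpoint="http://127.0.0.1:1234/v1/chat/completions",
        model="deepseek/deepseek-r1-0528-qwen3-8b",
        api_key=None,
    )

    loaded = await driver.list_models()  # GET /v1/models
    print([m.get("id") for m in loaded])

    # "my-local-model" is a placeholder identifier, not a real model name.
    await driver.load_model("my-local-model", context_length=8192)

    result = await driver.generate("Return a short greeting as JSON.", {"json_mode": True})
    print(result["text"], result["meta"])

    await driver.unload_model("my-local-model")


asyncio.run(main())
```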
prompture/drivers/async_ollama_driver.py

@@ -15,6 +15,8 @@ logger = logging.getLogger(__name__)
 
 class AsyncOllamaDriver(AsyncDriver):
     supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True
 
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
@@ -25,6 +27,11 @@ class AsyncOllamaDriver(AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_ollama_vision_messages
+
+        return _prepare_ollama_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
         merged_options = self.options.copy()
         if options:
@@ -36,9 +43,10 @@ class AsyncOllamaDriver(AsyncDriver):
             "stream": False,
         }
 
-        # Native JSON mode support
+        # Native JSON mode / structured output support
         if merged_options.get("json_mode"):
-
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
 
         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
@@ -74,6 +82,7 @@ class AsyncOllamaDriver(AsyncDriver):
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+        messages = self._prepare_messages(messages)
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
@@ -88,7 +97,8 @@ class AsyncOllamaDriver(AsyncDriver):
         }
 
         if merged_options.get("json_mode"):
-
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
 
         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
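The Ollama driver now forwards a JSON schema straight into the request's "format" field when one is supplied, falling back to plain "json". A hedged sketch of that option flow; the no-argument constructor is an assumption, since __init__ is not shown in this diff:

```python
import asyncio

from prompture.drivers.async_ollama_driver import AsyncOllamaDriver

# Schema forwarded to Ollama's structured-output "format" field, as added above.
PERSON_SCHEMA = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}


async def main() -> None:
    # Default construction is an assumption; __init__ is not part of this diff.
    driver = AsyncOllamaDriver()

    result = await driver.generate(
        "Extract the person from: 'Ada Lovelace, 36 years old.'",
        {"json_mode": True, "json_schema": PERSON_SCHEMA},
    )
    print(result)


asyncio.run(main())
```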
prompture/drivers/async_openai_driver.py

@@ -18,6 +18,7 @@ from .openai_driver import OpenAIDriver
 class AsyncOpenAIDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     MODEL_PRICING = OpenAIDriver.MODEL_PRICING
 
@@ -31,12 +32,17 @@ class AsyncOpenAIDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
prompture/drivers/async_openrouter_driver.py

@@ -14,6 +14,7 @@ from .openrouter_driver import OpenRouterDriver
 
 class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = OpenRouterDriver.MODEL_PRICING
 
@@ -31,12 +32,17 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         model = options.get("model", self.model)
prompture/drivers/async_registry.py

@@ -49,7 +49,11 @@ register_async_driver(
 )
 register_async_driver(
     "lmstudio",
-    lambda model=None: AsyncLMStudioDriver(
+    lambda model=None: AsyncLMStudioDriver(
+        endpoint=settings.lmstudio_endpoint,
+        model=model or settings.lmstudio_model,
+        api_key=settings.lmstudio_api_key,
+    ),
     overwrite=True,
 )
 register_async_driver(
prompture/drivers/azure_driver.py

@@ -17,6 +17,7 @@ from ..driver import Driver
 class AzureDriver(CostMixin, Driver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     # Pricing per 1K tokens (adjust if your Azure pricing differs from OpenAI defaults)
     MODEL_PRICING = {
@@ -90,12 +91,17 @@ class AzureDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
prompture/drivers/claude_driver.py

@@ -21,6 +21,7 @@ class ClaudeDriver(CostMixin, Driver):
     supports_json_schema = True
     supports_tool_use = True
     supports_streaming = True
+    supports_vision = True
 
     # Claude pricing per 1000 tokens (prices should be kept current with Anthropic's pricing)
     MODEL_PRICING = {
@@ -57,12 +58,17 @@ class ClaudeDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_claude_vision_messages
+
+        return _prepare_claude_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         if anthropic is None: