prompture-0.0.35-py3-none-any.whl → prompture-0.0.38.dev2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +120 -2
- prompture/_version.py +2 -2
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +199 -17
- prompture/async_driver.py +24 -0
- prompture/async_groups.py +551 -0
- prompture/conversation.py +213 -18
- prompture/core.py +30 -12
- prompture/discovery.py +24 -1
- prompture/driver.py +38 -0
- prompture/drivers/__init__.py +5 -1
- prompture/drivers/async_azure_driver.py +7 -1
- prompture/drivers/async_claude_driver.py +7 -1
- prompture/drivers/async_google_driver.py +212 -28
- prompture/drivers/async_grok_driver.py +7 -1
- prompture/drivers/async_groq_driver.py +7 -1
- prompture/drivers/async_lmstudio_driver.py +74 -5
- prompture/drivers/async_ollama_driver.py +13 -3
- prompture/drivers/async_openai_driver.py +7 -1
- prompture/drivers/async_openrouter_driver.py +7 -1
- prompture/drivers/async_registry.py +5 -1
- prompture/drivers/azure_driver.py +7 -1
- prompture/drivers/claude_driver.py +7 -1
- prompture/drivers/google_driver.py +217 -33
- prompture/drivers/grok_driver.py +7 -1
- prompture/drivers/groq_driver.py +7 -1
- prompture/drivers/lmstudio_driver.py +73 -8
- prompture/drivers/ollama_driver.py +16 -5
- prompture/drivers/openai_driver.py +7 -1
- prompture/drivers/openrouter_driver.py +7 -1
- prompture/drivers/vision_helpers.py +153 -0
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/serialization.py +218 -0
- prompture/settings.py +1 -0
- prompture-0.0.38.dev2.dist-info/METADATA +369 -0
- prompture-0.0.38.dev2.dist-info/RECORD +77 -0
- prompture-0.0.35.dist-info/METADATA +0 -464
- prompture-0.0.35.dist-info/RECORD +0 -66
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
prompture/drivers/google_driver.py
CHANGED

@@ -1,5 +1,7 @@
 import logging
 import os
+import uuid
+from collections.abc import Iterator
 from typing import Any, Optional
 
 import google.generativeai as genai

@@ -15,6 +17,9 @@ class GoogleDriver(CostMixin, Driver):
 
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
+    supports_tool_use = True
+    supports_streaming = True
 
     # Based on current Gemini pricing (as of 2025)
     # Source: https://cloud.google.com/vertex-ai/pricing#gemini_models

@@ -105,25 +110,62 @@ class GoogleDriver(CostMixin, Driver):
         completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
         return round(prompt_cost + completion_cost, 6)
 
+    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
+        """Extract token counts from response, falling back to character estimation."""
+        usage = getattr(response, "usage_metadata", None)
+        if usage:
+            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
+            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
+        else:
+            # Fallback: estimate from character counts
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
+            completion_chars = len(response.text) if response.text else 0
+            prompt_tokens = total_prompt_chars // 4
+            completion_tokens = completion_chars // 4
+            total_tokens = prompt_tokens + completion_tokens
+            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+        return {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(cost, 6),
+        }
+
     supports_messages = True
 
-    def
-
-        return self._do_generate(messages, options)
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages
 
-
-        return self._do_generate(messages, options)
+        return _prepare_google_vision_messages(messages)
 
-    def
+    def _build_generation_args(
+        self, messages: list[dict[str, Any]], options: Optional[dict[str, Any]] = None
+    ) -> tuple[Any, dict[str, Any]]:
+        """Parse messages and options into (gen_input, kwargs) for generate_content.
+
+        Returns the content input and a dict of keyword arguments
+        (generation_config, safety_settings, model kwargs including system_instruction).
+        """
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
 
-        # Extract specific options for Google's API
         generation_config = merged_options.get("generation_config", {})
         safety_settings = merged_options.get("safety_settings", {})
 
-        # Map common options to generation_config if not present
         if "temperature" in merged_options and "temperature" not in generation_config:
             generation_config["temperature"] = merged_options["temperature"]
         if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:

@@ -147,44 +189,59 @@
             role = msg.get("role", "user")
             content = msg.get("content", "")
             if role == "system":
-                system_instruction = content
+                system_instruction = content if isinstance(content, str) else str(content)
             else:
-                # Gemini uses "model" for assistant role
                 gemini_role = "model" if role == "assistant" else "user"
-
+                if msg.get("_vision_parts"):
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
+
+        # For a single message, unwrap only if it has exactly one string part
+        if len(contents) == 1:
+            parts = contents[0]["parts"]
+            if len(parts) == 1 and isinstance(parts[0], str):
+                gen_input = parts[0]
+            else:
+                gen_input = contents
+        else:
+            gen_input = contents
+
+        model_kwargs: dict[str, Any] = {}
+        if system_instruction:
+            model_kwargs["system_instruction"] = system_instruction
+
+        gen_kwargs: dict[str, Any] = {
+            "generation_config": generation_config if generation_config else None,
+            "safety_settings": safety_settings if safety_settings else None,
+        }
+
+        return gen_input, gen_kwargs, model_kwargs
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
 
         try:
             logger.debug(f"Initializing {self.model} for generation")
-            model_kwargs: dict[str, Any] = {}
-            if system_instruction:
-                model_kwargs["system_instruction"] = system_instruction
             model = genai.GenerativeModel(self.model, **model_kwargs)
 
-
-
-            # If single user message, pass content directly for backward compatibility
-            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
-            response = model.generate_content(
-                gen_input,
-                generation_config=generation_config if generation_config else None,
-                safety_settings=safety_settings if safety_settings else None,
-            )
+            logger.debug(f"Generating with model {self.model}")
+            response = model.generate_content(gen_input, **gen_kwargs)
 
             if not response.text:
                 raise ValueError("Empty response from model")
 
-
-            total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
-            completion_chars = len(response.text)
-
-            # Google uses character-based cost estimation
-            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+            usage_meta = self._extract_usage_metadata(response, messages)
 
             meta = {
-
-                "completion_chars": completion_chars,
-                "total_chars": total_prompt_chars + completion_chars,
-                "cost": total_cost,
+                **usage_meta,
                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
             }

@@ -194,3 +251,130 @@ class GoogleDriver(CostMixin, Driver):
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
             raise RuntimeError(f"Google API request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool/function calls."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        # Convert tools from OpenAI format to Gemini function declarations
+        function_declarations = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                fn = t["function"]
+                decl = {
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                }
+                params = fn.get("parameters")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+            elif "name" in t:
+                # Already in a generic format
+                decl = {"name": t["name"], "description": t.get("description", "")}
+                params = t.get("parameters") or t.get("input_schema")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
+            response = model.generate_content(gen_input, tools=gemini_tools, **gen_kwargs)
+
+            usage_meta = self._extract_usage_metadata(response, messages)
+            meta = {
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            text = ""
+            tool_calls_out: list[dict[str, Any]] = []
+            stop_reason = "stop"
+
+            for candidate in response.candidates:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text += part.text
+                    if hasattr(part, "function_call") and part.function_call.name:
+                        fc = part.function_call
+                        tool_calls_out.append({
+                            "id": str(uuid.uuid4()),
+                            "name": fc.name,
+                            "arguments": dict(fc.args) if fc.args else {},
+                        })
+
+                finish_reason = getattr(candidate, "finish_reason", None)
+                if finish_reason is not None:
+                    # Map Gemini finish reasons to standard stop reasons
+                    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
+                    stop_reason = reason_map.get(finish_reason, "stop")
+
+            if tool_calls_out:
+                stop_reason = "tool_use"
+
+            return {
+                "text": text,
+                "meta": meta,
+                "tool_calls": tool_calls_out,
+                "stop_reason": stop_reason,
+            }
+
+        except Exception as e:
+            logger.error(f"Google API tool call request failed: {e}")
+            raise RuntimeError(f"Google API tool call request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks via Gemini streaming API."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+            response = model.generate_content(gen_input, stream=True, **gen_kwargs)
+
+            full_text = ""
+            for chunk in response:
+                chunk_text = getattr(chunk, "text", None) or ""
+                if chunk_text:
+                    full_text += chunk_text
+                    yield {"type": "delta", "text": chunk_text}
+
+            # After iteration completes, resolve() has been called on the response
+            usage_meta = self._extract_usage_metadata(response, messages)
+
+            yield {
+                "type": "done",
+                "text": full_text,
+                "meta": {
+                    **usage_meta,
+                    "raw_response": {},
+                    "model_name": self.model,
+                },
+            }
+
+        except Exception as e:
+            logger.error(f"Google API streaming request failed: {e}")
+            raise RuntimeError(f"Google API streaming request failed: {e}") from e
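For orientation (not part of the diff): a minimal sketch of how the new tool-use and streaming entry points added above might be called. The driver constructor arguments, tool shape, and model name are assumptions inferred from the signatures in the diff, not documented prompture API.

# Hypothetical usage sketch; GoogleDriver(model=...) construction and the
# OpenAI-style tool dict are assumptions based on the diff, not documented API.
from prompture.drivers.google_driver import GoogleDriver

driver = GoogleDriver(model="gemini-1.5-flash")  # constructor args assumed

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]
messages = [{"role": "user", "content": "What's the weather in Paris?"}]

# Tool-use path: stop_reason becomes "tool_use" when function calls come back.
result = driver.generate_messages_with_tools(messages, tools, options={})
for call in result["tool_calls"]:
    print(call["name"], call["arguments"])

# Streaming path: "delta" chunks followed by a final "done" chunk with usage meta.
for chunk in driver.generate_messages_stream(messages, options={}):
    if chunk["type"] == "delta":
        print(chunk["text"], end="")
    elif chunk["type"] == "done":
        print("\ntokens:", chunk["meta"]["total_tokens"])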
prompture/drivers/grok_driver.py
CHANGED
@@ -13,6 +13,7 @@ from ..driver import Driver
 
 class GrokDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_vision = True
 
     # Pricing per 1M tokens based on xAI's documentation
     _PRICING_UNIT = 1_000_000

@@ -80,12 +81,17 @@ class GrokDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
prompture/drivers/groq_driver.py
CHANGED
@@ -16,6 +16,7 @@ from ..driver import Driver
 
 class GroqDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_vision = True
 
     # Approximate pricing per 1K tokens (to be updated with official pricing)
     # Each model entry defines token parameters and temperature support

@@ -50,12 +51,17 @@ class GroqDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
prompture/drivers/lmstudio_driver.py
CHANGED

@@ -12,27 +12,47 @@ logger = logging.getLogger(__name__)
 
 class LMStudioDriver(Driver):
     supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True
 
     # LM Studio is local – costs are always zero.
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
-    def __init__(
+    def __init__(
+        self,
+        endpoint: str | None = None,
+        model: str = "deepseek/deepseek-r1-0528-qwen3-8b",
+        api_key: str | None = None,
+    ):
         # Allow override via env var
         self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
         self.model = model
         self.options: dict[str, Any] = {}
 
+        # Derive base_url once for reuse across management endpoints
+        self.base_url = self.endpoint.split("/v1/")[0]
+
+        # API key for LM Studio 0.4.0+ authentication
+        self.api_key = api_key or os.getenv("LMSTUDIO_API_KEY")
+        self._headers = self._build_headers()
+
         # Validate connection to LM Studio server
         self._validate_connection()
 
+    def _build_headers(self) -> dict[str, str]:
+        """Build request headers, including auth if an API key is configured."""
+        headers: dict[str, str] = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
     def _validate_connection(self):
         """Validate connection to the LM Studio server."""
         try:
-
-            health_url = f"{base_url}/v1/models"
+            health_url = f"{self.base_url}/v1/models"
 
             logger.debug(f"Validating connection to LM Studio server at: {health_url}")
-            response = requests.get(health_url, timeout=5)
+            response = requests.get(health_url, headers=self._headers, timeout=5)
             response.raise_for_status()
             logger.debug("Connection to LM Studio server validated successfully")
         except requests.exceptions.RequestException as e:

@@ -40,12 +60,17 @@ class LMStudioDriver(Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         merged_options = self.options.copy()

@@ -58,15 +83,27 @@ class LMStudioDriver(Driver):
             "temperature": merged_options.get("temperature", 0.7),
         }
 
-        # Native JSON mode support
+        # Native JSON mode support (LM Studio requires json_schema, not json_object)
        if merged_options.get("json_mode"):
-
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                payload["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                # No schema provided — omit response_format entirely;
+                # LM Studio rejects "json_object" type.
+                pass
 
         try:
             logger.debug(f"Sending request to LM Studio endpoint: {self.endpoint}")
             logger.debug(f"Request payload: {payload}")
 
-            r = requests.post(self.endpoint, json=payload, timeout=120)
+            r = requests.post(self.endpoint, json=payload, headers=self._headers, timeout=120)
             r.raise_for_status()
 
             response_data = r.json()

@@ -104,3 +141,31 @@
             }
 
         return {"text": text, "meta": meta}
+
+    # -- Model management (LM Studio 0.4.0+) ----------------------------------
+
+    def list_models(self) -> list[dict[str, Any]]:
+        """List currently loaded models via GET /v1/models (OpenAI-compatible)."""
+        url = f"{self.base_url}/v1/models"
+        r = requests.get(url, headers=self._headers, timeout=10)
+        r.raise_for_status()
+        data = r.json()
+        return data.get("data", [])
+
+    def load_model(self, model: str, context_length: int | None = None) -> dict[str, Any]:
+        """Load a model into LM Studio via POST /api/v1/models/load."""
+        url = f"{self.base_url}/api/v1/models/load"
+        payload: dict[str, Any] = {"model": model}
+        if context_length is not None:
+            payload["context_length"] = context_length
+        r = requests.post(url, json=payload, headers=self._headers, timeout=120)
+        r.raise_for_status()
+        return r.json()
+
+    def unload_model(self, model: str) -> dict[str, Any]:
+        """Unload a model from LM Studio via POST /api/v1/models/unload."""
+        url = f"{self.base_url}/api/v1/models/unload"
+        payload = {"instance_id": model}
+        r = requests.post(url, json=payload, headers=self._headers, timeout=30)
+        r.raise_for_status()
+        return r.json()
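For orientation (not part of the diff): a minimal sketch of the new LM Studio authentication and model-management surface. The method names and endpoints come from the diff above; the model identifiers, API key, and overall flow are illustrative assumptions.

# Illustrative only; model ids and the load/use/unload flow are assumptions,
# the methods and endpoints are the ones added in the diff above.
from prompture.drivers.lmstudio_driver import LMStudioDriver

driver = LMStudioDriver(
    endpoint="http://127.0.0.1:1234/v1/chat/completions",
    model="qwen2.5-7b-instruct",      # hypothetical model id
    api_key="lm-studio-local-key",    # only needed on LM Studio 0.4.0+ with auth enabled
)

# GET /v1/models; assumes the OpenAI-style "id" field in each entry.
print([m.get("id") for m in driver.list_models()])

driver.load_model("qwen2.5-7b-instruct", context_length=8192)  # POST /api/v1/models/load
result = driver.generate("Say hello.", options={})
print(result["text"])
driver.unload_model("qwen2.5-7b-instruct")  # POST /api/v1/models/unload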
prompture/drivers/ollama_driver.py
CHANGED

@@ -13,7 +13,9 @@ logger = logging.getLogger(__name__)
 
 class OllamaDriver(Driver):
     supports_json_mode = True
+    supports_json_schema = True
     supports_streaming = True
+    supports_vision = True
 
     # Ollama is free – costs are always zero.
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}

@@ -46,6 +48,11 @@ class OllamaDriver(Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_ollama_vision_messages
+
+        return _prepare_ollama_vision_messages(messages)
+
     def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         # Merge instance options with call-specific options
         merged_options = self.options.copy()

@@ -58,9 +65,10 @@ class OllamaDriver(Driver):
             "stream": False,
         }
 
-        # Native JSON mode support
+        # Native JSON mode / structured output support
         if merged_options.get("json_mode"):
-
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
 
         # Add any Ollama-specific options from merged_options
         if "temperature" in merged_options:

@@ -146,7 +154,8 @@
         }
 
         if merged_options.get("json_mode"):
-
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
         if "top_p" in merged_options:

@@ -190,6 +199,7 @@
 
     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+        messages = self._prepare_messages(messages)
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)

@@ -203,9 +213,10 @@
             "stream": False,
         }
 
-        # Native JSON mode support
+        # Native JSON mode / structured output support
         if merged_options.get("json_mode"):
-
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
 
         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
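For orientation (not part of the diff): the Ollama changes pass an optional json_schema option straight through to Ollama's format field (structured outputs), falling back to plain "json" when no schema is given. A sketch of the option shape, with the constructor arguments and return shape assumed:

# Illustrative option shape; OllamaDriver construction details are assumed.
from prompture.drivers.ollama_driver import OllamaDriver

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}

driver = OllamaDriver(model="llama3.2")  # constructor args assumed
result = driver.generate(
    "Extract the person: 'Ana is 31 years old.'",
    options={"json_mode": True, "json_schema": schema},  # becomes payload["format"] = schema
)
# Return shape assumed to match the other drivers ({"text", "meta"}).
print(result["text"])

# Without a schema, json_mode alone falls back to payload["format"] = "json".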
prompture/drivers/openai_driver.py
CHANGED

@@ -21,6 +21,7 @@ class OpenAIDriver(CostMixin, Driver):
     supports_json_schema = True
     supports_tool_use = True
     supports_streaming = True
+    supports_vision = True
 
     # Approximate pricing per 1K tokens (keep updated with OpenAI's official pricing)
     # Each model entry also defines which token parameter it supports and

@@ -74,12 +75,17 @@ class OpenAIDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
prompture/drivers/openrouter_driver.py
CHANGED

@@ -13,6 +13,7 @@ from ..driver import Driver
 
 class OpenRouterDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_vision = True
 
     # Approximate pricing per 1K tokens based on OpenRouter's pricing
     # https://openrouter.ai/docs#pricing

@@ -66,12 +67,17 @@ class OpenRouterDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
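For orientation (not part of the diff): every OpenAI-compatible driver above (OpenAI, OpenRouter, Grok, Groq, LM Studio) now routes incoming messages through _prepare_openai_vision_messages from the new vision_helpers module. The helper's exact input contract is not shown in this diff; the sketch below uses a hypothetical prompture-side "image" field, and only the prepared output follows the standard OpenAI content-parts format.

# The "image" key on the incoming message is a hypothetical illustration;
# the prepared structure is the standard OpenAI multimodal message format.
incoming = [
    {
        "role": "user",
        "content": "What is in this picture?",
        "image": "https://example.com/cat.png",  # assumed prompture-side field
    },
]

prepared = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    },
]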