prompture-0.0.35-py3-none-any.whl → prompture-0.0.38.dev2-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (48)
  1. prompture/__init__.py +120 -2
  2. prompture/_version.py +2 -2
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/async_agent.py +880 -0
  6. prompture/async_conversation.py +199 -17
  7. prompture/async_driver.py +24 -0
  8. prompture/async_groups.py +551 -0
  9. prompture/conversation.py +213 -18
  10. prompture/core.py +30 -12
  11. prompture/discovery.py +24 -1
  12. prompture/driver.py +38 -0
  13. prompture/drivers/__init__.py +5 -1
  14. prompture/drivers/async_azure_driver.py +7 -1
  15. prompture/drivers/async_claude_driver.py +7 -1
  16. prompture/drivers/async_google_driver.py +212 -28
  17. prompture/drivers/async_grok_driver.py +7 -1
  18. prompture/drivers/async_groq_driver.py +7 -1
  19. prompture/drivers/async_lmstudio_driver.py +74 -5
  20. prompture/drivers/async_ollama_driver.py +13 -3
  21. prompture/drivers/async_openai_driver.py +7 -1
  22. prompture/drivers/async_openrouter_driver.py +7 -1
  23. prompture/drivers/async_registry.py +5 -1
  24. prompture/drivers/azure_driver.py +7 -1
  25. prompture/drivers/claude_driver.py +7 -1
  26. prompture/drivers/google_driver.py +217 -33
  27. prompture/drivers/grok_driver.py +7 -1
  28. prompture/drivers/groq_driver.py +7 -1
  29. prompture/drivers/lmstudio_driver.py +73 -8
  30. prompture/drivers/ollama_driver.py +16 -5
  31. prompture/drivers/openai_driver.py +7 -1
  32. prompture/drivers/openrouter_driver.py +7 -1
  33. prompture/drivers/vision_helpers.py +153 -0
  34. prompture/group_types.py +147 -0
  35. prompture/groups.py +530 -0
  36. prompture/image.py +180 -0
  37. prompture/persistence.py +254 -0
  38. prompture/persona.py +482 -0
  39. prompture/serialization.py +218 -0
  40. prompture/settings.py +1 -0
  41. prompture-0.0.38.dev2.dist-info/METADATA +369 -0
  42. prompture-0.0.38.dev2.dist-info/RECORD +77 -0
  43. prompture-0.0.35.dist-info/METADATA +0 -464
  44. prompture-0.0.35.dist-info/RECORD +0 -66
  45. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +0 -0
  46. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
  47. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
  48. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
prompture/drivers/async_google_driver.py

@@ -4,6 +4,8 @@ from __future__ import annotations
 
 import logging
 import os
+import uuid
+from collections.abc import AsyncIterator
 from typing import Any
 
 import google.generativeai as genai
@@ -20,6 +22,9 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
 
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
+    supports_tool_use = True
+    supports_streaming = True
 
     MODEL_PRICING = GoogleDriver.MODEL_PRICING
     _PRICING_UNIT = 1_000_000
@@ -48,18 +53,51 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
         completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
         return round(prompt_cost + completion_cost, 6)
 
+    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
+        """Extract token counts from response, falling back to character estimation."""
+        usage = getattr(response, "usage_metadata", None)
+        if usage:
+            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
+            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
+        else:
+            # Fallback: estimate from character counts
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
+            completion_chars = len(response.text) if response.text else 0
+            prompt_tokens = total_prompt_chars // 4
+            completion_tokens = completion_chars // 4
+            total_tokens = prompt_tokens + completion_tokens
+            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+        return {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(cost, 6),
+        }
+
     supports_messages = True
 
-    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
-        messages = [{"role": "user", "content": prompt}]
-        return await self._do_generate(messages, options)
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages
 
-    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return _prepare_google_vision_messages(messages)
 
-    async def _do_generate(
-        self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
-    ) -> dict[str, Any]:
+    def _build_generation_args(
+        self, messages: list[dict[str, Any]], options: dict[str, Any] | None = None
+    ) -> tuple[Any, dict[str, Any], dict[str, Any]]:
+        """Parse messages and options into (gen_input, gen_kwargs, model_kwargs)."""
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
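The notable behavior change in this hunk is the usage accounting: real token counts are taken from `usage_metadata` when Gemini returns them, and only otherwise does the driver fall back to the old characters-divided-by-four estimate. A standalone sketch of the two paths (stub objects here are illustrative, not prompture's classes):

```python
# Minimal re-implementation of the two branches of _extract_usage_metadata,
# using stand-in response objects to show the arithmetic.
from types import SimpleNamespace

def usage_from(response, messages):
    usage = getattr(response, "usage_metadata", None)
    if usage:
        prompt = usage.prompt_token_count          # exact counts from the API
        completion = usage.candidates_token_count
    else:
        prompt_chars = sum(len(m["content"]) for m in messages)
        prompt = prompt_chars // 4                 # rough 4-chars-per-token heuristic
        completion = len(response.text) // 4
    return {"prompt_tokens": prompt, "completion_tokens": completion}

msgs = [{"role": "user", "content": "x" * 200}]
with_meta = SimpleNamespace(
    text="hi",
    usage_metadata=SimpleNamespace(prompt_token_count=48, candidates_token_count=2),
)
without_meta = SimpleNamespace(text="y" * 40, usage_metadata=None)

print(usage_from(with_meta, msgs))     # {'prompt_tokens': 48, 'completion_tokens': 2}
print(usage_from(without_meta, msgs))  # {'prompt_tokens': 50, 'completion_tokens': 10}
```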
@@ -90,37 +128,58 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
             role = msg.get("role", "user")
             content = msg.get("content", "")
             if role == "system":
-                system_instruction = content
+                system_instruction = content if isinstance(content, str) else str(content)
             else:
                 gemini_role = "model" if role == "assistant" else "user"
-                contents.append({"role": gemini_role, "parts": [content]})
+                if msg.get("_vision_parts"):
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
+
+        # For a single message, unwrap only if it has exactly one string part
+        if len(contents) == 1:
+            parts = contents[0]["parts"]
+            if len(parts) == 1 and isinstance(parts[0], str):
+                gen_input = parts[0]
+            else:
+                gen_input = contents
+        else:
+            gen_input = contents
+
+        model_kwargs: dict[str, Any] = {}
+        if system_instruction:
+            model_kwargs["system_instruction"] = system_instruction
+
+        gen_kwargs: dict[str, Any] = {
+            "generation_config": generation_config if generation_config else None,
+            "safety_settings": safety_settings if safety_settings else None,
+        }
+
+        return gen_input, gen_kwargs, model_kwargs
+
+    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(self._prepare_messages(messages), options)
+
+    async def _do_generate(
+        self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
 
         try:
-            model_kwargs: dict[str, Any] = {}
-            if system_instruction:
-                model_kwargs["system_instruction"] = system_instruction
             model = genai.GenerativeModel(self.model, **model_kwargs)
-
-            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
-            response = await model.generate_content_async(
-                gen_input,
-                generation_config=generation_config if generation_config else None,
-                safety_settings=safety_settings if safety_settings else None,
-            )
+            response = await model.generate_content_async(gen_input, **gen_kwargs)
 
             if not response.text:
                 raise ValueError("Empty response from model")
 
-            total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
-            completion_chars = len(response.text)
-
-            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+            usage_meta = self._extract_usage_metadata(response, messages)
 
             meta = {
-                "prompt_chars": total_prompt_chars,
-                "completion_chars": completion_chars,
-                "total_chars": total_prompt_chars + completion_chars,
-                "cost": total_cost,
+                **usage_meta,
                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
             }
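The refactor pulls message translation into `_build_generation_args`: system messages become the `system_instruction` model kwarg, assistant turns map to Gemini's "model" role, and a lone single-string-part message collapses to a bare string. A standalone re-implementation of just that mapping, to make the shapes concrete (not the driver itself):

```python
# Illustrative sketch of the message mapping performed in the hunk above.
from typing import Any

def map_messages(messages: list[dict[str, Any]]) -> tuple[Any, str | None]:
    system_instruction = None
    contents: list[dict[str, Any]] = []
    for msg in messages:
        role, content = msg.get("role", "user"), msg.get("content", "")
        if role == "system":
            system_instruction = content
        else:
            gemini_role = "model" if role == "assistant" else "user"
            contents.append({"role": gemini_role, "parts": [content]})
    # A single message with exactly one string part collapses to a bare string
    if (len(contents) == 1 and len(contents[0]["parts"]) == 1
            and isinstance(contents[0]["parts"][0], str)):
        return contents[0]["parts"][0], system_instruction
    return contents, system_instruction

print(map_messages([{"role": "user", "content": "hi"}]))
# ('hi', None)
print(map_messages([
    {"role": "system", "content": "be terse"},
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
]))
# ([{'role': 'user', 'parts': ['hi']}, {'role': 'model', 'parts': ['hello']}], 'be terse')
```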
@@ -130,3 +189,128 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
             raise RuntimeError(f"Google API request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool/function calls (async)."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        # Convert tools from OpenAI format to Gemini function declarations
+        function_declarations = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                fn = t["function"]
+                decl = {
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                }
+                params = fn.get("parameters")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+            elif "name" in t:
+                decl = {"name": t["name"], "description": t.get("description", "")}
+                params = t.get("parameters") or t.get("input_schema")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
+            response = await model.generate_content_async(gen_input, tools=gemini_tools, **gen_kwargs)
+
+            usage_meta = self._extract_usage_metadata(response, messages)
+            meta = {
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            text = ""
+            tool_calls_out: list[dict[str, Any]] = []
+            stop_reason = "stop"
+
+            for candidate in response.candidates:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text += part.text
+                    if hasattr(part, "function_call") and part.function_call.name:
+                        fc = part.function_call
+                        tool_calls_out.append({
+                            "id": str(uuid.uuid4()),
+                            "name": fc.name,
+                            "arguments": dict(fc.args) if fc.args else {},
+                        })
+
+                finish_reason = getattr(candidate, "finish_reason", None)
+                if finish_reason is not None:
+                    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
+                    stop_reason = reason_map.get(finish_reason, "stop")
+
+            if tool_calls_out:
+                stop_reason = "tool_use"
+
+            return {
+                "text": text,
+                "meta": meta,
+                "tool_calls": tool_calls_out,
+                "stop_reason": stop_reason,
+            }
+
+        except Exception as e:
+            logger.error(f"Google API tool call request failed: {e}")
+            raise RuntimeError(f"Google API tool call request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via Gemini async streaming API."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+            response = await model.generate_content_async(gen_input, stream=True, **gen_kwargs)
+
+            full_text = ""
+            async for chunk in response:
+                chunk_text = getattr(chunk, "text", None) or ""
+                if chunk_text:
+                    full_text += chunk_text
+                    yield {"type": "delta", "text": chunk_text}
+
+            # After iteration completes, usage_metadata should be available
+            usage_meta = self._extract_usage_metadata(response, messages)
+
+            yield {
+                "type": "done",
+                "text": full_text,
+                "meta": {
+                    **usage_meta,
+                    "raw_response": {},
+                    "model_name": self.model,
+                },
+            }
+
+        except Exception as e:
+            logger.error(f"Google API streaming request failed: {e}")
+            raise RuntimeError(f"Google API streaming request failed: {e}") from e
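The streaming method defines a small chunk protocol: `{"type": "delta", "text": ...}` for incremental text and a final `{"type": "done", ...}` carrying the full text plus usage metadata. A hedged consumer sketch; the constructor arguments are assumed, so check prompture's docs for the actual driver configuration:

```python
# Consuming the delta/done protocol emitted by generate_messages_stream.
import asyncio
from prompture.drivers.async_google_driver import AsyncGoogleDriver

async def main() -> None:
    driver = AsyncGoogleDriver(model="gemini-1.5-flash")  # assumed ctor args
    messages = [{"role": "user", "content": "Tell me a short joke."}]
    async for chunk in driver.generate_messages_stream(messages, options={}):
        if chunk["type"] == "delta":
            print(chunk["text"], end="", flush=True)   # incremental text
        elif chunk["type"] == "done":
            meta = chunk["meta"]                        # totals arrive at the end
            print(f"\n-- {meta['total_tokens']} tokens, ${meta['cost']}")

asyncio.run(main())
```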
prompture/drivers/async_grok_driver.py

@@ -14,6 +14,7 @@ from .grok_driver import GrokDriver
 
 class AsyncGrokDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = GrokDriver.MODEL_PRICING
     _PRICING_UNIT = 1_000_000
@@ -25,12 +26,17 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
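This same hook pattern repeats across the OpenAI-compatible drivers below (Groq, LM Studio, OpenAI, OpenRouter, Azure): `generate_messages` now routes through `_prepare_openai_vision_messages` from the new `vision_helpers.py` (+153 lines, not expanded in this diff). The helper's body isn't shown, so the following is only the presumed target shape; OpenAI-compatible chat APIs expect image inputs as `image_url` content parts alongside text parts:

```python
# Presumed output shape of _prepare_openai_vision_messages for a message
# carrying an image (illustrative; the helper's actual code is not in this diff).
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {
            "type": "image_url",
            # A remote URL or a base64 data URL both work in this format
            "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."},
        },
    ],
}
```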
prompture/drivers/async_groq_driver.py

@@ -17,6 +17,7 @@ from .groq_driver import GroqDriver
 
 class AsyncGroqDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = GroqDriver.MODEL_PRICING
 
@@ -30,12 +31,17 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
prompture/drivers/async_lmstudio_driver.py

@@ -15,22 +15,48 @@ logger = logging.getLogger(__name__)
 
 class AsyncLMStudioDriver(AsyncDriver):
     supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True
 
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
-    def __init__(self, endpoint: str | None = None, model: str = "deepseek/deepseek-r1-0528-qwen3-8b"):
+    def __init__(
+        self,
+        endpoint: str | None = None,
+        model: str = "deepseek/deepseek-r1-0528-qwen3-8b",
+        api_key: str | None = None,
+    ):
         self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
         self.model = model
         self.options: dict[str, Any] = {}
 
+        # Derive base_url once for reuse across management endpoints
+        self.base_url = self.endpoint.split("/v1/")[0]
+
+        # API key for LM Studio 0.4.0+ authentication
+        self.api_key = api_key or os.getenv("LMSTUDIO_API_KEY")
+        self._headers = self._build_headers()
+
     supports_messages = True
 
+    def _build_headers(self) -> dict[str, str]:
+        """Build request headers, including auth if an API key is configured."""
+        headers: dict[str, str] = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(
         self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
@@ -45,13 +71,25 @@ class AsyncLMStudioDriver(AsyncDriver):
             "temperature": merged_options.get("temperature", 0.7),
         }
 
-        # Native JSON mode support
+        # Native JSON mode support (LM Studio requires json_schema, not json_object)
        if merged_options.get("json_mode"):
-            payload["response_format"] = {"type": "json_object"}
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                payload["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                # No schema provided — omit response_format entirely;
+                # LM Studio rejects "json_object" type.
+                pass
 
         async with httpx.AsyncClient() as client:
             try:
-                r = await client.post(self.endpoint, json=payload, timeout=120)
+                r = await client.post(self.endpoint, json=payload, headers=self._headers, timeout=120)
                 r.raise_for_status()
                 response_data = r.json()
             except Exception as e:
118
+
119
+ # -- Model management (LM Studio 0.4.0+) ----------------------------------
120
+
121
+ async def list_models(self) -> list[dict[str, Any]]:
122
+ """List currently loaded models via GET /v1/models (OpenAI-compatible)."""
123
+ url = f"{self.base_url}/v1/models"
124
+ async with httpx.AsyncClient() as client:
125
+ r = await client.get(url, headers=self._headers, timeout=10)
126
+ r.raise_for_status()
127
+ data = r.json()
128
+ return data.get("data", [])
129
+
130
+ async def load_model(self, model: str, context_length: int | None = None) -> dict[str, Any]:
131
+ """Load a model into LM Studio via POST /api/v1/models/load."""
132
+ url = f"{self.base_url}/api/v1/models/load"
133
+ payload: dict[str, Any] = {"model": model}
134
+ if context_length is not None:
135
+ payload["context_length"] = context_length
136
+ async with httpx.AsyncClient() as client:
137
+ r = await client.post(url, json=payload, headers=self._headers, timeout=120)
138
+ r.raise_for_status()
139
+ return r.json()
140
+
141
+ async def unload_model(self, model: str) -> dict[str, Any]:
142
+ """Unload a model from LM Studio via POST /api/v1/models/unload."""
143
+ url = f"{self.base_url}/api/v1/models/unload"
144
+ payload = {"instance_id": model}
145
+ async with httpx.AsyncClient() as client:
146
+ r = await client.post(url, json=payload, headers=self._headers, timeout=30)
147
+ r.raise_for_status()
148
+ return r.json()
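A hedged usage sketch for the new management helpers. The endpoint paths come from the hunk above; note that `unload_model` sends the argument as `instance_id`, which this sketch assumes equals the model name, and the model identifier itself is illustrative:

```python
# Load, list, and unload a model against a local LM Studio (0.4.0+) instance.
import asyncio
from prompture.drivers.async_lmstudio_driver import AsyncLMStudioDriver

async def main() -> None:
    driver = AsyncLMStudioDriver()  # defaults to http://127.0.0.1:1234
    await driver.load_model("qwen2.5-7b-instruct", context_length=8192)
    for m in await driver.list_models():
        print(m.get("id"))          # OpenAI-style model entries
    await driver.unload_model("qwen2.5-7b-instruct")

asyncio.run(main())
```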
prompture/drivers/async_ollama_driver.py

@@ -15,6 +15,8 @@ logger = logging.getLogger(__name__)
 
 class AsyncOllamaDriver(AsyncDriver):
     supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True
 
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
@@ -25,6 +27,11 @@ class AsyncOllamaDriver(AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_ollama_vision_messages
+
+        return _prepare_ollama_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
         merged_options = self.options.copy()
         if options:
@@ -36,9 +43,10 @@ class AsyncOllamaDriver(AsyncDriver):
             "stream": False,
         }
 
-        # Native JSON mode support
+        # Native JSON mode / structured output support
         if merged_options.get("json_mode"):
-            payload["format"] = "json"
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
 
         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
@@ -74,6 +82,7 @@ class AsyncOllamaDriver(AsyncDriver):
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+        messages = self._prepare_messages(messages)
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
@@ -88,7 +97,8 @@ class AsyncOllamaDriver(AsyncDriver):
         }
 
         if merged_options.get("json_mode"):
-            payload["format"] = "json"
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
 
         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
prompture/drivers/async_openai_driver.py

@@ -18,6 +18,7 @@ from .openai_driver import OpenAIDriver
 class AsyncOpenAIDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     MODEL_PRICING = OpenAIDriver.MODEL_PRICING
 
@@ -31,12 +32,17 @@ class AsyncOpenAIDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
prompture/drivers/async_openrouter_driver.py

@@ -14,6 +14,7 @@ from .openrouter_driver import OpenRouterDriver
 
 class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = OpenRouterDriver.MODEL_PRICING
 
@@ -31,12 +32,17 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         model = options.get("model", self.model)
prompture/drivers/async_registry.py

@@ -49,7 +49,11 @@ register_async_driver(
 )
 register_async_driver(
     "lmstudio",
-    lambda model=None: AsyncLMStudioDriver(endpoint=settings.lmstudio_endpoint, model=model or settings.lmstudio_model),
+    lambda model=None: AsyncLMStudioDriver(
+        endpoint=settings.lmstudio_endpoint,
+        model=model or settings.lmstudio_model,
+        api_key=settings.lmstudio_api_key,
+    ),
     overwrite=True,
 )
 register_async_driver(
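The registry now threads `settings.lmstudio_api_key` into the driver (presumably the +1 line in `prompture/settings.py` above), and the driver's `__init__` separately falls back to the `LMSTUDIO_API_KEY` environment variable. A sketch of the environment-variable path, using only names visible in this diff:

```python
# The env fallback in AsyncLMStudioDriver.__init__ picks up LMSTUDIO_API_KEY
# when no api_key argument (or settings value) is supplied.
import os

os.environ["LMSTUDIO_API_KEY"] = "lm-studio-key"  # must be set before construction
from prompture.drivers.async_lmstudio_driver import AsyncLMStudioDriver

driver = AsyncLMStudioDriver(model="qwen2.5-7b-instruct")  # model name illustrative
print(driver._headers)
# {'Content-Type': 'application/json', 'Authorization': 'Bearer lm-studio-key'}
```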
prompture/drivers/azure_driver.py

@@ -17,6 +17,7 @@ from ..driver import Driver
 class AzureDriver(CostMixin, Driver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     # Pricing per 1K tokens (adjust if your Azure pricing differs from OpenAI defaults)
     MODEL_PRICING = {
@@ -90,12 +91,17 @@ class AzureDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
prompture/drivers/claude_driver.py

@@ -21,6 +21,7 @@ class ClaudeDriver(CostMixin, Driver):
     supports_json_schema = True
     supports_tool_use = True
     supports_streaming = True
+    supports_vision = True
 
     # Claude pricing per 1000 tokens (prices should be kept current with Anthropic's pricing)
     MODEL_PRICING = {
@@ -57,12 +58,17 @@ class ClaudeDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_claude_vision_messages
+
+        return _prepare_claude_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         if anthropic is None:
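Claude gets its own preparer, `_prepare_claude_vision_messages`, because Anthropic's message format differs from the OpenAI `image_url` shape used by the other drivers. The helper's body is not part of this diff; the following is just Anthropic's documented image content-block shape that a Claude-specific preparer would need to emit:

```python
# Anthropic-style image content block (documented API format; the helper's
# actual output is assumed to match it).
message = {
    "role": "user",
    "content": [
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/png",
                "data": "iVBORw0KGgo...",  # base64 image bytes, truncated here
            },
        },
        {"type": "text", "text": "Describe this image."},
    ],
}
```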