prompture 0.0.35__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. prompture/__init__.py +120 -2
  2. prompture/_version.py +2 -2
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/async_agent.py +880 -0
  6. prompture/async_conversation.py +199 -17
  7. prompture/async_driver.py +24 -0
  8. prompture/async_groups.py +551 -0
  9. prompture/conversation.py +213 -18
  10. prompture/core.py +30 -12
  11. prompture/discovery.py +24 -1
  12. prompture/driver.py +38 -0
  13. prompture/drivers/__init__.py +5 -1
  14. prompture/drivers/async_azure_driver.py +7 -1
  15. prompture/drivers/async_claude_driver.py +7 -1
  16. prompture/drivers/async_google_driver.py +212 -28
  17. prompture/drivers/async_grok_driver.py +7 -1
  18. prompture/drivers/async_groq_driver.py +7 -1
  19. prompture/drivers/async_lmstudio_driver.py +74 -5
  20. prompture/drivers/async_ollama_driver.py +13 -3
  21. prompture/drivers/async_openai_driver.py +7 -1
  22. prompture/drivers/async_openrouter_driver.py +7 -1
  23. prompture/drivers/async_registry.py +5 -1
  24. prompture/drivers/azure_driver.py +7 -1
  25. prompture/drivers/claude_driver.py +7 -1
  26. prompture/drivers/google_driver.py +217 -33
  27. prompture/drivers/grok_driver.py +7 -1
  28. prompture/drivers/groq_driver.py +7 -1
  29. prompture/drivers/lmstudio_driver.py +73 -8
  30. prompture/drivers/ollama_driver.py +16 -5
  31. prompture/drivers/openai_driver.py +7 -1
  32. prompture/drivers/openrouter_driver.py +7 -1
  33. prompture/drivers/vision_helpers.py +153 -0
  34. prompture/group_types.py +147 -0
  35. prompture/groups.py +530 -0
  36. prompture/image.py +180 -0
  37. prompture/persistence.py +254 -0
  38. prompture/persona.py +482 -0
  39. prompture/serialization.py +218 -0
  40. prompture/settings.py +1 -0
  41. prompture-0.0.38.dev2.dist-info/METADATA +369 -0
  42. prompture-0.0.38.dev2.dist-info/RECORD +77 -0
  43. prompture-0.0.35.dist-info/METADATA +0 -464
  44. prompture-0.0.35.dist-info/RECORD +0 -66
  45. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +0 -0
  46. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
  47. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
  48. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
prompture/drivers/google_driver.py

@@ -1,5 +1,7 @@
 import logging
 import os
+import uuid
+from collections.abc import Iterator
 from typing import Any, Optional
 
 import google.generativeai as genai
@@ -15,6 +17,9 @@ class GoogleDriver(CostMixin, Driver):
 
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
+    supports_tool_use = True
+    supports_streaming = True
 
     # Based on current Gemini pricing (as of 2025)
     # Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
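Note: the three new capability flags are plain class attributes, so calling code can feature-detect before relying on vision, tool use, or streaming. A minimal sketch of such a check; the helper below is illustrative and not part of prompture:

# Illustrative only: assumes `driver` is any prompture Driver instance,
# e.g. a configured GoogleDriver; construction details are not shown in this diff.
def describe_capabilities(driver) -> dict[str, bool]:
    """Report which optional features a driver advertises via its class flags."""
    flag_names = (
        "supports_vision",
        "supports_tool_use",
        "supports_streaming",
        "supports_json_mode",
        "supports_json_schema",
    )
    return {name: bool(getattr(driver, name, False)) for name in flag_names}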
@@ -105,25 +110,62 @@ class GoogleDriver(CostMixin, Driver):
         completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
         return round(prompt_cost + completion_cost, 6)
 
+    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
+        """Extract token counts from response, falling back to character estimation."""
+        usage = getattr(response, "usage_metadata", None)
+        if usage:
+            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
+            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
+        else:
+            # Fallback: estimate from character counts
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
+            completion_chars = len(response.text) if response.text else 0
+            prompt_tokens = total_prompt_chars // 4
+            completion_tokens = completion_chars // 4
+            total_tokens = prompt_tokens + completion_tokens
+            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+        return {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(cost, 6),
+        }
+
     supports_messages = True
 
-    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
-        messages = [{"role": "user", "content": prompt}]
-        return self._do_generate(messages, options)
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages
 
-    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return _prepare_google_vision_messages(messages)
 
-    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+    def _build_generation_args(
+        self, messages: list[dict[str, Any]], options: Optional[dict[str, Any]] = None
+    ) -> tuple[Any, dict[str, Any]]:
+        """Parse messages and options into (gen_input, kwargs) for generate_content.
+
+        Returns the content input and a dict of keyword arguments
+        (generation_config, safety_settings, model kwargs including system_instruction).
+        """
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
 
-        # Extract specific options for Google's API
         generation_config = merged_options.get("generation_config", {})
         safety_settings = merged_options.get("safety_settings", {})
 
-        # Map common options to generation_config if not present
         if "temperature" in merged_options and "temperature" not in generation_config:
             generation_config["temperature"] = merged_options["temperature"]
         if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
@@ -147,44 +189,59 @@ class GoogleDriver(CostMixin, Driver):
             role = msg.get("role", "user")
             content = msg.get("content", "")
             if role == "system":
-                system_instruction = content
+                system_instruction = content if isinstance(content, str) else str(content)
             else:
-                # Gemini uses "model" for assistant role
                 gemini_role = "model" if role == "assistant" else "user"
-                contents.append({"role": gemini_role, "parts": [content]})
+                if msg.get("_vision_parts"):
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
+
+        # For a single message, unwrap only if it has exactly one string part
+        if len(contents) == 1:
+            parts = contents[0]["parts"]
+            if len(parts) == 1 and isinstance(parts[0], str):
+                gen_input = parts[0]
+            else:
+                gen_input = contents
+        else:
+            gen_input = contents
+
+        model_kwargs: dict[str, Any] = {}
+        if system_instruction:
+            model_kwargs["system_instruction"] = system_instruction
+
+        gen_kwargs: dict[str, Any] = {
+            "generation_config": generation_config if generation_config else None,
+            "safety_settings": safety_settings if safety_settings else None,
+        }
+
+        return gen_input, gen_kwargs, model_kwargs
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
 
         try:
             logger.debug(f"Initializing {self.model} for generation")
-            model_kwargs: dict[str, Any] = {}
-            if system_instruction:
-                model_kwargs["system_instruction"] = system_instruction
             model = genai.GenerativeModel(self.model, **model_kwargs)
 
-            # Generate response
-            logger.debug(f"Generating with {len(contents)} content parts")
-            # If single user message, pass content directly for backward compatibility
-            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
-            response = model.generate_content(
-                gen_input,
-                generation_config=generation_config if generation_config else None,
-                safety_settings=safety_settings if safety_settings else None,
-            )
+            logger.debug(f"Generating with model {self.model}")
+            response = model.generate_content(gen_input, **gen_kwargs)
 
             if not response.text:
                 raise ValueError("Empty response from model")
 
-            # Calculate token usage and cost
-            total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
-            completion_chars = len(response.text)
-
-            # Google uses character-based cost estimation
-            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+            usage_meta = self._extract_usage_metadata(response, messages)
 
             meta = {
-                "prompt_chars": total_prompt_chars,
-                "completion_chars": completion_chars,
-                "total_chars": total_prompt_chars + completion_chars,
-                "cost": total_cost,
+                **usage_meta,
                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
             }
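Note: per the new _build_generation_args logic above, a system message becomes the model's system_instruction, the assistant role maps to Gemini's "model" role, and a lone single-string user message is passed to generate_content directly. A rough sketch of the mapping with illustrative values, not taken from this diff:

# Input (OpenAI-style) messages:
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello."},
    {"role": "user", "content": "Summarize X."},
]
# Expected output of _build_generation_args (sketch):
# model_kwargs == {"system_instruction": "You are terse."}
# gen_input == [
#     {"role": "user", "parts": ["Hi"]},
#     {"role": "model", "parts": ["Hello."]},
#     {"role": "user", "parts": ["Summarize X."]},
# ]
# With a single user message holding one string part, gen_input is just that string.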
@@ -194,3 +251,130 @@ class GoogleDriver(CostMixin, Driver):
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
             raise RuntimeError(f"Google API request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool/function calls."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        # Convert tools from OpenAI format to Gemini function declarations
+        function_declarations = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                fn = t["function"]
+                decl = {
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                }
+                params = fn.get("parameters")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+            elif "name" in t:
+                # Already in a generic format
+                decl = {"name": t["name"], "description": t.get("description", "")}
+                params = t.get("parameters") or t.get("input_schema")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
+            response = model.generate_content(gen_input, tools=gemini_tools, **gen_kwargs)
+
+            usage_meta = self._extract_usage_metadata(response, messages)
+            meta = {
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            text = ""
+            tool_calls_out: list[dict[str, Any]] = []
+            stop_reason = "stop"
+
+            for candidate in response.candidates:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text += part.text
+                    if hasattr(part, "function_call") and part.function_call.name:
+                        fc = part.function_call
+                        tool_calls_out.append({
+                            "id": str(uuid.uuid4()),
+                            "name": fc.name,
+                            "arguments": dict(fc.args) if fc.args else {},
+                        })
+
+                finish_reason = getattr(candidate, "finish_reason", None)
+                if finish_reason is not None:
+                    # Map Gemini finish reasons to standard stop reasons
+                    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
+                    stop_reason = reason_map.get(finish_reason, "stop")
+
+            if tool_calls_out:
+                stop_reason = "tool_use"
+
+            return {
+                "text": text,
+                "meta": meta,
+                "tool_calls": tool_calls_out,
+                "stop_reason": stop_reason,
+            }
+
+        except Exception as e:
+            logger.error(f"Google API tool call request failed: {e}")
+            raise RuntimeError(f"Google API tool call request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks via Gemini streaming API."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+            response = model.generate_content(gen_input, stream=True, **gen_kwargs)
+
+            full_text = ""
+            for chunk in response:
+                chunk_text = getattr(chunk, "text", None) or ""
+                if chunk_text:
+                    full_text += chunk_text
+                    yield {"type": "delta", "text": chunk_text}
+
+            # After iteration completes, resolve() has been called on the response
+            usage_meta = self._extract_usage_metadata(response, messages)
+
+            yield {
+                "type": "done",
+                "text": full_text,
+                "meta": {
+                    **usage_meta,
+                    "raw_response": {},
+                    "model_name": self.model,
+                },
+            }
+
+        except Exception as e:
+            logger.error(f"Google API streaming request failed: {e}")
+            raise RuntimeError(f"Google API streaming request failed: {e}") from e
prompture/drivers/grok_driver.py

@@ -13,6 +13,7 @@ from ..driver import Driver
 
 class GrokDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_vision = True
 
     # Pricing per 1M tokens based on xAI's documentation
     _PRICING_UNIT = 1_000_000
@@ -80,12 +81,17 @@ class GrokDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
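Note: the same _prepare_messages hook is added to every OpenAI-compatible driver in this release (Grok, Groq, LM Studio, OpenAI, OpenRouter), each delegating to _prepare_openai_vision_messages from the new vision_helpers module. That helper's code is not shown in this diff; the sketch below shows the standard OpenAI-style multimodal message layout such drivers typically send, and is an assumption based on the helper's name rather than on its implementation:

# Assumed OpenAI-style multimodal message shape (vision_helpers.py is not shown here):
message_with_image = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    ],
}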
prompture/drivers/groq_driver.py

@@ -16,6 +16,7 @@ from ..driver import Driver
 
 class GroqDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_vision = True
 
     # Approximate pricing per 1K tokens (to be updated with official pricing)
     # Each model entry defines token parameters and temperature support
@@ -50,12 +51,17 @@ class GroqDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
prompture/drivers/lmstudio_driver.py

@@ -12,27 +12,47 @@ logger = logging.getLogger(__name__)
 
 class LMStudioDriver(Driver):
     supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True
 
     # LM Studio is local – costs are always zero.
    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
-    def __init__(self, endpoint: str | None = None, model: str = "deepseek/deepseek-r1-0528-qwen3-8b"):
+    def __init__(
+        self,
+        endpoint: str | None = None,
+        model: str = "deepseek/deepseek-r1-0528-qwen3-8b",
+        api_key: str | None = None,
+    ):
         # Allow override via env var
         self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
         self.model = model
         self.options: dict[str, Any] = {}
 
+        # Derive base_url once for reuse across management endpoints
+        self.base_url = self.endpoint.split("/v1/")[0]
+
+        # API key for LM Studio 0.4.0+ authentication
+        self.api_key = api_key or os.getenv("LMSTUDIO_API_KEY")
+        self._headers = self._build_headers()
+
         # Validate connection to LM Studio server
         self._validate_connection()
 
+    def _build_headers(self) -> dict[str, str]:
+        """Build request headers, including auth if an API key is configured."""
+        headers: dict[str, str] = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
     def _validate_connection(self):
         """Validate connection to the LM Studio server."""
         try:
-            base_url = self.endpoint.split("/v1/")[0]
-            health_url = f"{base_url}/v1/models"
+            health_url = f"{self.base_url}/v1/models"
 
             logger.debug(f"Validating connection to LM Studio server at: {health_url}")
-            response = requests.get(health_url, timeout=5)
+            response = requests.get(health_url, headers=self._headers, timeout=5)
             response.raise_for_status()
             logger.debug("Connection to LM Studio server validated successfully")
         except requests.exceptions.RequestException as e:
@@ -40,12 +60,17 @@ class LMStudioDriver(Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         merged_options = self.options.copy()
@@ -58,15 +83,27 @@ class LMStudioDriver(Driver):
             "temperature": merged_options.get("temperature", 0.7),
         }
 
-        # Native JSON mode support
+        # Native JSON mode support (LM Studio requires json_schema, not json_object)
         if merged_options.get("json_mode"):
-            payload["response_format"] = {"type": "json_object"}
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                payload["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                # No schema provided — omit response_format entirely;
+                # LM Studio rejects "json_object" type.
+                pass
 
         try:
             logger.debug(f"Sending request to LM Studio endpoint: {self.endpoint}")
             logger.debug(f"Request payload: {payload}")
 
-            r = requests.post(self.endpoint, json=payload, timeout=120)
+            r = requests.post(self.endpoint, json=payload, headers=self._headers, timeout=120)
             r.raise_for_status()
 
             response_data = r.json()
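Note: with json_mode plus a json_schema in the options, the request body LM Studio receives now carries a json_schema response_format rather than the json_object type it rejects. A sketch of the kind of request body this produces; field values are illustrative, not from this diff:

payload = {
    "model": "deepseek/deepseek-r1-0528-qwen3-8b",
    "messages": [{"role": "user", "content": "Extract the user's name as JSON."}],
    "temperature": 0.7,
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "extraction",
            "schema": {"type": "object", "properties": {"name": {"type": "string"}}},
        },
    },
}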
@@ -104,3 +141,31 @@ class LMStudioDriver(Driver):
         }
 
         return {"text": text, "meta": meta}
+
+    # -- Model management (LM Studio 0.4.0+) ----------------------------------
+
+    def list_models(self) -> list[dict[str, Any]]:
+        """List currently loaded models via GET /v1/models (OpenAI-compatible)."""
+        url = f"{self.base_url}/v1/models"
+        r = requests.get(url, headers=self._headers, timeout=10)
+        r.raise_for_status()
+        data = r.json()
+        return data.get("data", [])
+
+    def load_model(self, model: str, context_length: int | None = None) -> dict[str, Any]:
+        """Load a model into LM Studio via POST /api/v1/models/load."""
+        url = f"{self.base_url}/api/v1/models/load"
+        payload: dict[str, Any] = {"model": model}
+        if context_length is not None:
+            payload["context_length"] = context_length
+        r = requests.post(url, json=payload, headers=self._headers, timeout=120)
+        r.raise_for_status()
+        return r.json()
+
+    def unload_model(self, model: str) -> dict[str, Any]:
+        """Unload a model from LM Studio via POST /api/v1/models/unload."""
+        url = f"{self.base_url}/api/v1/models/unload"
+        payload = {"instance_id": model}
+        r = requests.post(url, json=payload, headers=self._headers, timeout=30)
+        r.raise_for_status()
+        return r.json()
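Note: the three helpers above expose LM Studio's model-management endpoints directly from the driver, reusing the same authenticated headers as generation. A hedged usage sketch, assuming a local LM Studio server is running; the API key and model identifiers are illustrative:

driver = LMStudioDriver(api_key="sk-lmstudio-example")  # key is optional; also read from LMSTUDIO_API_KEY
for m in driver.list_models():
    print(m.get("id"))
driver.load_model("qwen2.5-7b-instruct", context_length=8192)  # hypothetical model id
# ... run generations ...
driver.unload_model("qwen2.5-7b-instruct")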
prompture/drivers/ollama_driver.py

@@ -13,7 +13,9 @@ logger = logging.getLogger(__name__)
 
 class OllamaDriver(Driver):
     supports_json_mode = True
+    supports_json_schema = True
     supports_streaming = True
+    supports_vision = True
 
     # Ollama is free – costs are always zero.
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
@@ -46,6 +48,11 @@ class OllamaDriver(Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_ollama_vision_messages
+
+        return _prepare_ollama_vision_messages(messages)
+
     def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         # Merge instance options with call-specific options
         merged_options = self.options.copy()
@@ -58,9 +65,10 @@ class OllamaDriver(Driver):
             "stream": False,
         }
 
-        # Native JSON mode support
+        # Native JSON mode / structured output support
         if merged_options.get("json_mode"):
-            payload["format"] = "json"
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
 
         # Add any Ollama-specific options from merged_options
         if "temperature" in merged_options:
@@ -146,7 +154,8 @@ class OllamaDriver(Driver):
         }
 
         if merged_options.get("json_mode"):
-            payload["format"] = "json"
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
         if "top_p" in merged_options:
@@ -190,6 +199,7 @@ class OllamaDriver(Driver):
 
     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+        messages = self._prepare_messages(messages)
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
@@ -203,9 +213,10 @@ class OllamaDriver(Driver):
             "stream": False,
         }
 
-        # Native JSON mode support
+        # Native JSON mode / structured output support
        if merged_options.get("json_mode"):
-            payload["format"] = "json"
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
 
         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
prompture/drivers/openai_driver.py

@@ -21,6 +21,7 @@ class OpenAIDriver(CostMixin, Driver):
     supports_json_schema = True
     supports_tool_use = True
     supports_streaming = True
+    supports_vision = True
 
     # Approximate pricing per 1K tokens (keep updated with OpenAI's official pricing)
     # Each model entry also defines which token parameter it supports and
@@ -74,12 +75,17 @@ class OpenAIDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
prompture/drivers/openrouter_driver.py

@@ -13,6 +13,7 @@ from ..driver import Driver
 
 class OpenRouterDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_vision = True
 
     # Approximate pricing per 1K tokens based on OpenRouter's pricing
     # https://openrouter.ai/docs#pricing
@@ -66,12 +67,17 @@ class OpenRouterDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key: