prompture 0.0.35__py3-none-any.whl → 0.0.40.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. prompture/__init__.py +132 -3
  2. prompture/_version.py +2 -2
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/async_agent.py +880 -0
  6. prompture/async_conversation.py +208 -17
  7. prompture/async_core.py +16 -0
  8. prompture/async_driver.py +63 -0
  9. prompture/async_groups.py +551 -0
  10. prompture/conversation.py +222 -18
  11. prompture/core.py +46 -12
  12. prompture/cost_mixin.py +37 -0
  13. prompture/discovery.py +132 -44
  14. prompture/driver.py +77 -0
  15. prompture/drivers/__init__.py +5 -1
  16. prompture/drivers/async_azure_driver.py +11 -5
  17. prompture/drivers/async_claude_driver.py +184 -9
  18. prompture/drivers/async_google_driver.py +222 -28
  19. prompture/drivers/async_grok_driver.py +11 -5
  20. prompture/drivers/async_groq_driver.py +11 -5
  21. prompture/drivers/async_lmstudio_driver.py +74 -5
  22. prompture/drivers/async_ollama_driver.py +13 -3
  23. prompture/drivers/async_openai_driver.py +162 -5
  24. prompture/drivers/async_openrouter_driver.py +11 -5
  25. prompture/drivers/async_registry.py +5 -1
  26. prompture/drivers/azure_driver.py +10 -4
  27. prompture/drivers/claude_driver.py +17 -1
  28. prompture/drivers/google_driver.py +227 -33
  29. prompture/drivers/grok_driver.py +11 -5
  30. prompture/drivers/groq_driver.py +11 -5
  31. prompture/drivers/lmstudio_driver.py +73 -8
  32. prompture/drivers/ollama_driver.py +16 -5
  33. prompture/drivers/openai_driver.py +26 -11
  34. prompture/drivers/openrouter_driver.py +11 -5
  35. prompture/drivers/vision_helpers.py +153 -0
  36. prompture/group_types.py +147 -0
  37. prompture/groups.py +530 -0
  38. prompture/image.py +180 -0
  39. prompture/ledger.py +252 -0
  40. prompture/model_rates.py +112 -2
  41. prompture/persistence.py +254 -0
  42. prompture/persona.py +482 -0
  43. prompture/serialization.py +218 -0
  44. prompture/settings.py +1 -0
  45. prompture-0.0.40.dev1.dist-info/METADATA +369 -0
  46. prompture-0.0.40.dev1.dist-info/RECORD +78 -0
  47. prompture-0.0.35.dist-info/METADATA +0 -464
  48. prompture-0.0.35.dist-info/RECORD +0 -66
  49. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/WHEEL +0 -0
  50. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/entry_points.txt +0 -0
  51. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/licenses/LICENSE +0 -0
  52. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/top_level.txt +0 -0
--- a/prompture/drivers/async_google_driver.py
+++ b/prompture/drivers/async_google_driver.py
@@ -4,6 +4,8 @@ from __future__ import annotations

 import logging
 import os
+import uuid
+from collections.abc import AsyncIterator
 from typing import Any

 import google.generativeai as genai
@@ -20,6 +22,9 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):

     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
+    supports_tool_use = True
+    supports_streaming = True

     MODEL_PRICING = GoogleDriver.MODEL_PRICING
     _PRICING_UNIT = 1_000_000
@@ -48,18 +53,51 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
         completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
         return round(prompt_cost + completion_cost, 6)

+    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
+        """Extract token counts from response, falling back to character estimation."""
+        usage = getattr(response, "usage_metadata", None)
+        if usage:
+            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
+            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
+        else:
+            # Fallback: estimate from character counts
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
+            completion_chars = len(response.text) if response.text else 0
+            prompt_tokens = total_prompt_chars // 4
+            completion_tokens = completion_chars // 4
+            total_tokens = prompt_tokens + completion_tokens
+            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+        return {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(cost, 6),
+        }
+
     supports_messages = True

-    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
-        messages = [{"role": "user", "content": prompt}]
-        return await self._do_generate(messages, options)
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages

-    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return _prepare_google_vision_messages(messages)

-    async def _do_generate(
-        self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
-    ) -> dict[str, Any]:
+    def _build_generation_args(
+        self, messages: list[dict[str, Any]], options: dict[str, Any] | None = None
+    ) -> tuple[Any, dict[str, Any], dict[str, Any]]:
+        """Parse messages and options into (gen_input, gen_kwargs, model_kwargs)."""
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
@@ -90,37 +128,65 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
             role = msg.get("role", "user")
             content = msg.get("content", "")
             if role == "system":
-                system_instruction = content
+                system_instruction = content if isinstance(content, str) else str(content)
             else:
                 gemini_role = "model" if role == "assistant" else "user"
-                contents.append({"role": gemini_role, "parts": [content]})
+                if msg.get("_vision_parts"):
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
+
+        # For a single message, unwrap only if it has exactly one string part
+        if len(contents) == 1:
+            parts = contents[0]["parts"]
+            if len(parts) == 1 and isinstance(parts[0], str):
+                gen_input = parts[0]
+            else:
+                gen_input = contents
+        else:
+            gen_input = contents
+
+        model_kwargs: dict[str, Any] = {}
+        if system_instruction:
+            model_kwargs["system_instruction"] = system_instruction
+
+        gen_kwargs: dict[str, Any] = {
+            "generation_config": generation_config if generation_config else None,
+            "safety_settings": safety_settings if safety_settings else None,
+        }
+
+        return gen_input, gen_kwargs, model_kwargs
+
+    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(self._prepare_messages(messages), options)
+
+    async def _do_generate(
+        self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
+
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "google",
+            self.model,
+            using_json_schema=bool((options or {}).get("json_schema")),
+        )

         try:
-            model_kwargs: dict[str, Any] = {}
-            if system_instruction:
-                model_kwargs["system_instruction"] = system_instruction
             model = genai.GenerativeModel(self.model, **model_kwargs)
-
-            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
-            response = await model.generate_content_async(
-                gen_input,
-                generation_config=generation_config if generation_config else None,
-                safety_settings=safety_settings if safety_settings else None,
-            )
+            response = await model.generate_content_async(gen_input, **gen_kwargs)

             if not response.text:
                 raise ValueError("Empty response from model")

-            total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
-            completion_chars = len(response.text)
-
-            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+            usage_meta = self._extract_usage_metadata(response, messages)

             meta = {
-                "prompt_chars": total_prompt_chars,
-                "completion_chars": completion_chars,
-                "total_chars": total_prompt_chars + completion_chars,
-                "cost": total_cost,
+                **usage_meta,
                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
             }
@@ -130,3 +196,131 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
             raise RuntimeError(f"Google API request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool/function calls (async)."""
+        model = options.get("model", self.model)
+        self._validate_model_capabilities("google", model, using_tool_use=True)
+
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        # Convert tools from OpenAI format to Gemini function declarations
+        function_declarations = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                fn = t["function"]
+                decl = {
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                }
+                params = fn.get("parameters")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+            elif "name" in t:
+                decl = {"name": t["name"], "description": t.get("description", "")}
+                params = t.get("parameters") or t.get("input_schema")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
+            response = await model.generate_content_async(gen_input, tools=gemini_tools, **gen_kwargs)
+
+            usage_meta = self._extract_usage_metadata(response, messages)
+            meta = {
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            text = ""
+            tool_calls_out: list[dict[str, Any]] = []
+            stop_reason = "stop"
+
+            for candidate in response.candidates:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text += part.text
+                    if hasattr(part, "function_call") and part.function_call.name:
+                        fc = part.function_call
+                        tool_calls_out.append({
+                            "id": str(uuid.uuid4()),
+                            "name": fc.name,
+                            "arguments": dict(fc.args) if fc.args else {},
+                        })
+
+                finish_reason = getattr(candidate, "finish_reason", None)
+                if finish_reason is not None:
+                    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
+                    stop_reason = reason_map.get(finish_reason, "stop")
+
+            if tool_calls_out:
+                stop_reason = "tool_use"
+
+            return {
+                "text": text,
+                "meta": meta,
+                "tool_calls": tool_calls_out,
+                "stop_reason": stop_reason,
+            }
+
+        except Exception as e:
+            logger.error(f"Google API tool call request failed: {e}")
+            raise RuntimeError(f"Google API tool call request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via Gemini async streaming API."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+            response = await model.generate_content_async(gen_input, stream=True, **gen_kwargs)
+
+            full_text = ""
+            async for chunk in response:
+                chunk_text = getattr(chunk, "text", None) or ""
+                if chunk_text:
+                    full_text += chunk_text
+                    yield {"type": "delta", "text": chunk_text}
+
+            # After iteration completes, usage_metadata should be available
+            usage_meta = self._extract_usage_metadata(response, messages)
+
+            yield {
+                "type": "done",
+                "text": full_text,
+                "meta": {
+                    **usage_meta,
+                    "raw_response": {},
+                    "model_name": self.model,
+                },
+            }
+
+        except Exception as e:
+            logger.error(f"Google API streaming request failed: {e}")
+            raise RuntimeError(f"Google API streaming request failed: {e}") from e
--- a/prompture/drivers/async_grok_driver.py
+++ b/prompture/drivers/async_grok_driver.py
@@ -14,6 +14,7 @@ from .grok_driver import GrokDriver

 class AsyncGrokDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True

     MODEL_PRICING = GrokDriver.MODEL_PRICING
     _PRICING_UNIT = 1_000_000
@@ -25,12 +26,17 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):

     supports_messages = True

+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)

     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)

     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
@@ -38,9 +44,9 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):

         model = options.get("model", self.model)

-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("grok", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]

         opts = {"temperature": 1.0, "max_tokens": 512, **options}

@@ -82,7 +88,7 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp,
             "model_name": model,
         }
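
The Grok driver above (and the Groq driver below) stops reading request parameters from its local MODEL_PRICING table and instead asks CostMixin._get_model_config(provider, model) for them. Only two fields are consumed at these call sites; the helper sketched below is hypothetical and only illustrates how those fields shape the request, since cost_mixin.py is not included in this excerpt.

    from typing import Any

    def build_token_opts(model_config: dict[str, Any], opts: dict[str, Any]) -> dict[str, Any]:
        # Illustrative only: the token-limit parameter name varies by model,
        # and temperature is dropped for models that reject it.
        payload: dict[str, Any] = {model_config["tokens_param"]: opts.get("max_tokens", 512)}
        if model_config["supports_temperature"]:
            payload["temperature"] = opts.get("temperature", 1.0)
        return payload

    print(build_token_opts({"tokens_param": "max_tokens", "supports_temperature": True}, {}))
    # {'max_tokens': 512, 'temperature': 1.0}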
--- a/prompture/drivers/async_groq_driver.py
+++ b/prompture/drivers/async_groq_driver.py
@@ -17,6 +17,7 @@ from .groq_driver import GroqDriver

 class AsyncGroqDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True

     MODEL_PRICING = GroqDriver.MODEL_PRICING

@@ -30,12 +31,17 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):

     supports_messages = True

+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)

     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)

     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
@@ -43,9 +49,9 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):

         model = options.get("model", self.model)

-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("groq", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]

         opts = {"temperature": 0.7, "max_tokens": 512, **options}

@@ -75,7 +81,7 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp.model_dump(),
             "model_name": model,
         }
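
The Grok, Groq, and LM Studio drivers all pass messages through _prepare_openai_vision_messages before sending. vision_helpers.py itself is not shown in this excerpt, so the sketch below is only an assumption about the OpenAI-style multimodal content parts such chat-completions endpoints accept, not prompture's confirmed wire format.

    import base64

    # Assumed OpenAI-compatible layout: content becomes a list of parts
    # mixing text and data-URL images. Not taken from vision_helpers.py.
    with open("chart.png", "rb") as f:
        b64 = base64.b64encode(f.read()).decode()

    message = {
        "role": "user",
        "content": [
            {"type": "text", "text": "What does this chart show?"},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
        ],
    }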
--- a/prompture/drivers/async_lmstudio_driver.py
+++ b/prompture/drivers/async_lmstudio_driver.py
@@ -15,22 +15,48 @@ logger = logging.getLogger(__name__)

 class AsyncLMStudioDriver(AsyncDriver):
     supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True

     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}

-    def __init__(self, endpoint: str | None = None, model: str = "deepseek/deepseek-r1-0528-qwen3-8b"):
+    def __init__(
+        self,
+        endpoint: str | None = None,
+        model: str = "deepseek/deepseek-r1-0528-qwen3-8b",
+        api_key: str | None = None,
+    ):
         self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
         self.model = model
         self.options: dict[str, Any] = {}

+        # Derive base_url once for reuse across management endpoints
+        self.base_url = self.endpoint.split("/v1/")[0]
+
+        # API key for LM Studio 0.4.0+ authentication
+        self.api_key = api_key or os.getenv("LMSTUDIO_API_KEY")
+        self._headers = self._build_headers()
+
     supports_messages = True

+    def _build_headers(self) -> dict[str, str]:
+        """Build request headers, including auth if an API key is configured."""
+        headers: dict[str, str] = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)

     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)

     async def _do_generate(
         self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
@@ -45,13 +71,25 @@ class AsyncLMStudioDriver(AsyncDriver):
             "temperature": merged_options.get("temperature", 0.7),
         }

-        # Native JSON mode support
+        # Native JSON mode support (LM Studio requires json_schema, not json_object)
         if merged_options.get("json_mode"):
-            payload["response_format"] = {"type": "json_object"}
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                payload["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                # No schema provided — omit response_format entirely;
+                # LM Studio rejects "json_object" type.
+                pass

         async with httpx.AsyncClient() as client:
             try:
-                r = await client.post(self.endpoint, json=payload, timeout=120)
+                r = await client.post(self.endpoint, json=payload, headers=self._headers, timeout=120)
                 r.raise_for_status()
                 response_data = r.json()
             except Exception as e:
@@ -77,3 +115,34 @@ class AsyncLMStudioDriver(AsyncDriver):
         }

         return {"text": text, "meta": meta}
+
+    # -- Model management (LM Studio 0.4.0+) ----------------------------------
+
+    async def list_models(self) -> list[dict[str, Any]]:
+        """List currently loaded models via GET /v1/models (OpenAI-compatible)."""
+        url = f"{self.base_url}/v1/models"
+        async with httpx.AsyncClient() as client:
+            r = await client.get(url, headers=self._headers, timeout=10)
+            r.raise_for_status()
+            data = r.json()
+            return data.get("data", [])
+
+    async def load_model(self, model: str, context_length: int | None = None) -> dict[str, Any]:
+        """Load a model into LM Studio via POST /api/v1/models/load."""
+        url = f"{self.base_url}/api/v1/models/load"
+        payload: dict[str, Any] = {"model": model}
+        if context_length is not None:
+            payload["context_length"] = context_length
+        async with httpx.AsyncClient() as client:
+            r = await client.post(url, json=payload, headers=self._headers, timeout=120)
+            r.raise_for_status()
+            return r.json()
+
+    async def unload_model(self, model: str) -> dict[str, Any]:
+        """Unload a model from LM Studio via POST /api/v1/models/unload."""
+        url = f"{self.base_url}/api/v1/models/unload"
+        payload = {"instance_id": model}
+        async with httpx.AsyncClient() as client:
+            r = await client.post(url, json=payload, headers=self._headers, timeout=30)
+            r.raise_for_status()
+            return r.json()
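
Taken together, the LM Studio changes add bearer-token auth, schema-constrained JSON output, and wrappers for the 0.4.0+ model-management endpoints. A usage sketch under those assumptions; the model name and API key are placeholders, and the server-side behaviour is LM Studio's, not prompture's.

    import asyncio

    from prompture.drivers.async_lmstudio_driver import AsyncLMStudioDriver

    async def main() -> None:
        # api_key is optional; the driver falls back to the LMSTUDIO_API_KEY env var.
        driver = AsyncLMStudioDriver(api_key="lm-studio-key")

        print([m.get("id") for m in await driver.list_models()])              # GET  {base_url}/v1/models
        await driver.load_model("qwen2.5-7b-instruct", context_length=8192)   # POST /api/v1/models/load

        schema = {"type": "object", "properties": {"title": {"type": "string"}}}
        result = await driver.generate_messages(
            [{"role": "user", "content": "Give this diff a title."}],
            {"json_mode": True, "json_schema": schema},  # sent as response_format: json_schema
        )
        print(result["text"])

    asyncio.run(main())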
--- a/prompture/drivers/async_ollama_driver.py
+++ b/prompture/drivers/async_ollama_driver.py
@@ -15,6 +15,8 @@ logger = logging.getLogger(__name__)

 class AsyncOllamaDriver(AsyncDriver):
     supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True

     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}

@@ -25,6 +27,11 @@ class AsyncOllamaDriver(AsyncDriver):

     supports_messages = True

+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_ollama_vision_messages
+
+        return _prepare_ollama_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
         merged_options = self.options.copy()
         if options:
@@ -36,9 +43,10 @@ class AsyncOllamaDriver(AsyncDriver):
             "stream": False,
         }

-        # Native JSON mode support
+        # Native JSON mode / structured output support
         if merged_options.get("json_mode"):
-            payload["format"] = "json"
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"

         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
@@ -74,6 +82,7 @@ class AsyncOllamaDriver(AsyncDriver):

     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+        messages = self._prepare_messages(messages)
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
@@ -88,7 +97,8 @@ class AsyncOllamaDriver(AsyncDriver):
         }

         if merged_options.get("json_mode"):
-            payload["format"] = "json"
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"

         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]