prompture-0.0.36.dev1-py3-none-any.whl → prompture-0.0.37.dev1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. prompture/__init__.py +120 -2
  2. prompture/_version.py +2 -2
  3. prompture/agent.py +925 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/async_agent.py +879 -0
  6. prompture/async_conversation.py +199 -17
  7. prompture/async_driver.py +24 -0
  8. prompture/async_groups.py +551 -0
  9. prompture/conversation.py +213 -18
  10. prompture/core.py +30 -12
  11. prompture/discovery.py +24 -1
  12. prompture/driver.py +38 -0
  13. prompture/drivers/__init__.py +5 -1
  14. prompture/drivers/async_azure_driver.py +7 -1
  15. prompture/drivers/async_claude_driver.py +7 -1
  16. prompture/drivers/async_google_driver.py +24 -4
  17. prompture/drivers/async_grok_driver.py +7 -1
  18. prompture/drivers/async_groq_driver.py +7 -1
  19. prompture/drivers/async_lmstudio_driver.py +59 -3
  20. prompture/drivers/async_ollama_driver.py +7 -0
  21. prompture/drivers/async_openai_driver.py +7 -1
  22. prompture/drivers/async_openrouter_driver.py +7 -1
  23. prompture/drivers/async_registry.py +5 -1
  24. prompture/drivers/azure_driver.py +7 -1
  25. prompture/drivers/claude_driver.py +7 -1
  26. prompture/drivers/google_driver.py +24 -4
  27. prompture/drivers/grok_driver.py +7 -1
  28. prompture/drivers/groq_driver.py +7 -1
  29. prompture/drivers/lmstudio_driver.py +58 -6
  30. prompture/drivers/ollama_driver.py +7 -0
  31. prompture/drivers/openai_driver.py +7 -1
  32. prompture/drivers/openrouter_driver.py +7 -1
  33. prompture/drivers/vision_helpers.py +153 -0
  34. prompture/group_types.py +147 -0
  35. prompture/groups.py +530 -0
  36. prompture/image.py +180 -0
  37. prompture/persistence.py +254 -0
  38. prompture/persona.py +482 -0
  39. prompture/serialization.py +218 -0
  40. prompture/settings.py +1 -0
  41. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/METADATA +1 -1
  42. prompture-0.0.37.dev1.dist-info/RECORD +77 -0
  43. prompture-0.0.36.dev1.dist-info/RECORD +0 -66
  44. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/WHEEL +0 -0
  45. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/entry_points.txt +0 -0
  46. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/licenses/LICENSE +0 -0
  47. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/top_level.txt +0 -0
prompture/drivers/async_google_driver.py

@@ -20,6 +20,7 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
 
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     MODEL_PRICING = GoogleDriver.MODEL_PRICING
     _PRICING_UNIT = 1_000_000
@@ -50,12 +51,17 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages
+
+        return _prepare_google_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(
         self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
@@ -90,10 +96,14 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
             role = msg.get("role", "user")
             content = msg.get("content", "")
             if role == "system":
-                system_instruction = content
+                system_instruction = content if isinstance(content, str) else str(content)
             else:
                 gemini_role = "model" if role == "assistant" else "user"
-                contents.append({"role": gemini_role, "parts": [content]})
+                if msg.get("_vision_parts"):
+                    # Already converted to Gemini parts by _prepare_messages
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
 
         try:
             model_kwargs: dict[str, Any] = {}
@@ -111,7 +121,17 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
         if not response.text:
             raise ValueError("Empty response from model")
 
-        total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
+        total_prompt_chars = 0
+        for msg in messages:
+            c = msg.get("content", "")
+            if isinstance(c, str):
+                total_prompt_chars += len(c)
+            elif isinstance(c, list):
+                for part in c:
+                    if isinstance(part, str):
+                        total_prompt_chars += len(part)
+                    elif isinstance(part, dict) and "text" in part:
+                        total_prompt_chars += len(part["text"])
        completion_chars = len(response.text)
 
         total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
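The new _vision_parts flag is what lets pre-converted Gemini parts pass through _do_generate untouched. A minimal sketch of that dispatch, with an assumed part shape (the inline_data dict follows Gemini's API; what _prepare_google_vision_messages actually emits is internal to prompture and not shown in this diff):

# Illustrative sketch only; the inline_data part is an assumption about
# what _prepare_google_vision_messages produces.
msg = {
    "role": "user",
    "content": [
        "Describe this image",
        {"inline_data": {"mime_type": "image/png", "data": "<base64>"}},
    ],
    "_vision_parts": True,  # flag set by _prepare_messages after conversion
}
gemini_role = "model" if msg["role"] == "assistant" else "user"
if msg.get("_vision_parts"):
    entry = {"role": gemini_role, "parts": msg["content"]}    # parts list used as-is
else:
    entry = {"role": gemini_role, "parts": [msg["content"]]}  # plain text wrapped in a list
print(entry["role"], len(entry["parts"]))  # user 2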
prompture/drivers/async_grok_driver.py

@@ -14,6 +14,7 @@ from .grok_driver import GrokDriver
 
 class AsyncGrokDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = GrokDriver.MODEL_PRICING
     _PRICING_UNIT = 1_000_000
@@ -25,12 +26,17 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
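The OpenAI-compatible drivers (Grok here, and likewise Groq, LM Studio, OpenAI, OpenRouter, and Azure below) all delegate to the same _prepare_openai_vision_messages helper. A plausible target shape, assuming the helper normalizes image inputs into OpenAI's documented content-part format (its actual output is not shown in this diff):

# Assumed result of vision preparation for OpenAI-style chat APIs:
# text and image_url parts inside a single user message.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,<...>"}},
        ],
    }
]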
prompture/drivers/async_groq_driver.py

@@ -17,6 +17,7 @@ from .groq_driver import GroqDriver
 
 class AsyncGroqDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = GroqDriver.MODEL_PRICING
 
@@ -30,12 +31,17 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
prompture/drivers/async_lmstudio_driver.py

@@ -15,22 +15,47 @@ logger = logging.getLogger(__name__)
 
 class AsyncLMStudioDriver(AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
-    def __init__(self, endpoint: str | None = None, model: str = "deepseek/deepseek-r1-0528-qwen3-8b"):
+    def __init__(
+        self,
+        endpoint: str | None = None,
+        model: str = "deepseek/deepseek-r1-0528-qwen3-8b",
+        api_key: str | None = None,
+    ):
         self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
         self.model = model
         self.options: dict[str, Any] = {}
 
+        # Derive base_url once for reuse across management endpoints
+        self.base_url = self.endpoint.split("/v1/")[0]
+
+        # API key for LM Studio 0.4.0+ authentication
+        self.api_key = api_key or os.getenv("LMSTUDIO_API_KEY")
+        self._headers = self._build_headers()
+
     supports_messages = True
 
+    def _build_headers(self) -> dict[str, str]:
+        """Build request headers, including auth if an API key is configured."""
+        headers: dict[str, str] = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(
         self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
@@ -51,7 +76,7 @@ class AsyncLMStudioDriver(AsyncDriver):
 
         async with httpx.AsyncClient() as client:
             try:
-                r = await client.post(self.endpoint, json=payload, timeout=120)
+                r = await client.post(self.endpoint, json=payload, headers=self._headers, timeout=120)
                 r.raise_for_status()
                 response_data = r.json()
             except Exception as e:
@@ -77,3 +102,34 @@ class AsyncLMStudioDriver(AsyncDriver):
         }
 
         return {"text": text, "meta": meta}
+
+    # -- Model management (LM Studio 0.4.0+) ----------------------------------
+
+    async def list_models(self) -> list[dict[str, Any]]:
+        """List currently loaded models via GET /v1/models (OpenAI-compatible)."""
+        url = f"{self.base_url}/v1/models"
+        async with httpx.AsyncClient() as client:
+            r = await client.get(url, headers=self._headers, timeout=10)
+            r.raise_for_status()
+            data = r.json()
+            return data.get("data", [])
+
+    async def load_model(self, model: str, context_length: int | None = None) -> dict[str, Any]:
+        """Load a model into LM Studio via POST /api/v1/models/load."""
+        url = f"{self.base_url}/api/v1/models/load"
+        payload: dict[str, Any] = {"model": model}
+        if context_length is not None:
+            payload["context_length"] = context_length
+        async with httpx.AsyncClient() as client:
+            r = await client.post(url, json=payload, headers=self._headers, timeout=120)
+            r.raise_for_status()
+            return r.json()
+
+    async def unload_model(self, model: str) -> dict[str, Any]:
+        """Unload a model from LM Studio via POST /api/v1/models/unload."""
+        url = f"{self.base_url}/api/v1/models/unload"
+        payload = {"model": model}
+        async with httpx.AsyncClient() as client:
+            r = await client.post(url, json=payload, headers=self._headers, timeout=30)
+            r.raise_for_status()
+            return r.json()
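A short usage sketch for the new async management endpoints, assuming a reachable LM Studio 0.4.0+ server; the model name is illustrative:

import asyncio

from prompture.drivers.async_lmstudio_driver import AsyncLMStudioDriver

async def main() -> None:
    # api_key is optional and falls back to the LMSTUDIO_API_KEY env var
    driver = AsyncLMStudioDriver()
    loaded = await driver.list_models()  # GET /v1/models
    print([m.get("id") for m in loaded])
    # illustrative model name, not from this diff
    await driver.load_model("qwen2.5-7b-instruct", context_length=8192)
    await driver.unload_model("qwen2.5-7b-instruct")

asyncio.run(main())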
prompture/drivers/async_ollama_driver.py

@@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)
 
 class AsyncOllamaDriver(AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
@@ -25,6 +26,11 @@ class AsyncOllamaDriver(AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_ollama_vision_messages
+
+        return _prepare_ollama_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
         merged_options = self.options.copy()
         if options:
@@ -74,6 +80,7 @@ class AsyncOllamaDriver(AsyncDriver):
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+        messages = self._prepare_messages(messages)
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
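Ollama's /api/chat format differs from the OpenAI shape: images travel as a list of base64 strings on the message itself, which is presumably why Ollama gets its own _prepare_ollama_vision_messages helper. A sketch of the assumed prepared shape (not shown in this diff):

# Assumed Ollama chat message after vision preparation: content stays a
# plain string and images are attached as base64 payloads alongside it.
messages = [
    {
        "role": "user",
        "content": "What is in this picture?",
        "images": ["<base64-encoded image bytes>"],
    }
]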
prompture/drivers/async_openai_driver.py

@@ -18,6 +18,7 @@ from .openai_driver import OpenAIDriver
 class AsyncOpenAIDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     MODEL_PRICING = OpenAIDriver.MODEL_PRICING
 
@@ -31,12 +32,17 @@ class AsyncOpenAIDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
prompture/drivers/async_openrouter_driver.py

@@ -14,6 +14,7 @@ from .openrouter_driver import OpenRouterDriver
 
 class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = OpenRouterDriver.MODEL_PRICING
 
@@ -31,12 +32,17 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         model = options.get("model", self.model)
prompture/drivers/async_registry.py

@@ -49,7 +49,11 @@ register_async_driver(
 )
 register_async_driver(
     "lmstudio",
-    lambda model=None: AsyncLMStudioDriver(endpoint=settings.lmstudio_endpoint, model=model or settings.lmstudio_model),
+    lambda model=None: AsyncLMStudioDriver(
+        endpoint=settings.lmstudio_endpoint,
+        model=model or settings.lmstudio_model,
+        api_key=settings.lmstudio_api_key,
+    ),
     overwrite=True,
 )
 register_async_driver(
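The registry now threads settings.lmstudio_api_key into the async driver; the driver itself also falls back to the LMSTUDIO_API_KEY environment variable. A minimal sketch of the fallback path, grounded in the __init__ and _build_headers changes above (the key value is a placeholder):

import os

os.environ.setdefault("LMSTUDIO_API_KEY", "lm-placeholder")  # placeholder key

from prompture.drivers.async_lmstudio_driver import AsyncLMStudioDriver

driver = AsyncLMStudioDriver()  # no explicit api_key: env var is picked up
assert driver._headers["Authorization"].startswith("Bearer ")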
prompture/drivers/azure_driver.py

@@ -17,6 +17,7 @@ from ..driver import Driver
 class AzureDriver(CostMixin, Driver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     # Pricing per 1K tokens (adjust if your Azure pricing differs from OpenAI defaults)
     MODEL_PRICING = {
@@ -90,12 +91,17 @@ class AzureDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
prompture/drivers/claude_driver.py

@@ -21,6 +21,7 @@ class ClaudeDriver(CostMixin, Driver):
     supports_json_schema = True
     supports_tool_use = True
     supports_streaming = True
+    supports_vision = True
 
     # Claude pricing per 1000 tokens (prices should be kept current with Anthropic's pricing)
     MODEL_PRICING = {
@@ -57,12 +58,17 @@ class ClaudeDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_claude_vision_messages
+
+        return _prepare_claude_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         if anthropic is None:
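ClaudeDriver routes through its own _prepare_claude_vision_messages helper rather than the OpenAI one, since Anthropic uses a distinct content-block format. A plausible prepared message, following Anthropic's documented image block (assumed to be the helper's target shape; not shown in this diff):

# Assumed Claude message after vision preparation: a base64 image source
# block plus a text block, per Anthropic's Messages API conventions.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "<base64>"}},
            {"type": "text", "text": "Describe the image."},
        ],
    }
]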
prompture/drivers/google_driver.py

@@ -15,6 +15,7 @@ class GoogleDriver(CostMixin, Driver):
 
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     # Based on current Gemini pricing (as of 2025)
     # Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
@@ -107,12 +108,17 @@ class GoogleDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages
+
+        return _prepare_google_vision_messages(messages)
+
     def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         merged_options = self.options.copy()
@@ -147,11 +153,15 @@ class GoogleDriver(CostMixin, Driver):
             role = msg.get("role", "user")
             content = msg.get("content", "")
             if role == "system":
-                system_instruction = content
+                system_instruction = content if isinstance(content, str) else str(content)
             else:
                 # Gemini uses "model" for assistant role
                 gemini_role = "model" if role == "assistant" else "user"
-                contents.append({"role": gemini_role, "parts": [content]})
+                if msg.get("_vision_parts"):
+                    # Already converted to Gemini parts by _prepare_messages
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
 
         try:
             logger.debug(f"Initializing {self.model} for generation")
@@ -174,7 +184,17 @@ class GoogleDriver(CostMixin, Driver):
             raise ValueError("Empty response from model")
 
         # Calculate token usage and cost
-        total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
+        total_prompt_chars = 0
+        for msg in messages:
+            c = msg.get("content", "")
+            if isinstance(c, str):
+                total_prompt_chars += len(c)
+            elif isinstance(c, list):
+                for part in c:
+                    if isinstance(part, str):
+                        total_prompt_chars += len(part)
+                    elif isinstance(part, dict) and "text" in part:
+                        total_prompt_chars += len(part["text"])
        completion_chars = len(response.text)
 
         # Google uses character-based cost estimation
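Google's cost estimate is character-based, and the replaced one-liner would raise a TypeError on list-valued vision content; the new loop counts only textual parts. A small worked example of the same counting logic:

# Mirrors the counting logic added above: strings count fully, list parts
# count their text, and non-text parts (e.g. images) contribute zero chars.
messages = [
    {"role": "system", "content": "Be terse."},                         # 9 chars
    {"role": "user", "content": ["What is shown?", {"text": "hint"}]},  # 14 + 4 chars
]
total_prompt_chars = 0
for msg in messages:
    c = msg.get("content", "")
    if isinstance(c, str):
        total_prompt_chars += len(c)
    elif isinstance(c, list):
        for part in c:
            if isinstance(part, str):
                total_prompt_chars += len(part)
            elif isinstance(part, dict) and "text" in part:
                total_prompt_chars += len(part["text"])
print(total_prompt_chars)  # 27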
prompture/drivers/grok_driver.py

@@ -13,6 +13,7 @@ from ..driver import Driver
 
 class GrokDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_vision = True
 
     # Pricing per 1M tokens based on xAI's documentation
     _PRICING_UNIT = 1_000_000
@@ -80,12 +81,17 @@ class GrokDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
prompture/drivers/groq_driver.py

@@ -16,6 +16,7 @@ from ..driver import Driver
 
 class GroqDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_vision = True
 
     # Approximate pricing per 1K tokens (to be updated with official pricing)
     # Each model entry defines token parameters and temperature support
@@ -50,12 +51,17 @@ class GroqDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
prompture/drivers/lmstudio_driver.py

@@ -12,27 +12,46 @@ logger = logging.getLogger(__name__)
 
 class LMStudioDriver(Driver):
     supports_json_mode = True
+    supports_vision = True
 
     # LM Studio is local – costs are always zero.
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
-    def __init__(self, endpoint: str | None = None, model: str = "deepseek/deepseek-r1-0528-qwen3-8b"):
+    def __init__(
+        self,
+        endpoint: str | None = None,
+        model: str = "deepseek/deepseek-r1-0528-qwen3-8b",
+        api_key: str | None = None,
+    ):
         # Allow override via env var
         self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
         self.model = model
         self.options: dict[str, Any] = {}
 
+        # Derive base_url once for reuse across management endpoints
+        self.base_url = self.endpoint.split("/v1/")[0]
+
+        # API key for LM Studio 0.4.0+ authentication
+        self.api_key = api_key or os.getenv("LMSTUDIO_API_KEY")
+        self._headers = self._build_headers()
+
         # Validate connection to LM Studio server
         self._validate_connection()
 
+    def _build_headers(self) -> dict[str, str]:
+        """Build request headers, including auth if an API key is configured."""
+        headers: dict[str, str] = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
     def _validate_connection(self):
         """Validate connection to the LM Studio server."""
         try:
-            base_url = self.endpoint.split("/v1/")[0]
-            health_url = f"{base_url}/v1/models"
+            health_url = f"{self.base_url}/v1/models"
 
             logger.debug(f"Validating connection to LM Studio server at: {health_url}")
-            response = requests.get(health_url, timeout=5)
+            response = requests.get(health_url, headers=self._headers, timeout=5)
             response.raise_for_status()
             logger.debug("Connection to LM Studio server validated successfully")
         except requests.exceptions.RequestException as e:
@@ -40,12 +59,17 @@ class LMStudioDriver(Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         merged_options = self.options.copy()
@@ -66,7 +90,7 @@ class LMStudioDriver(Driver):
         logger.debug(f"Sending request to LM Studio endpoint: {self.endpoint}")
         logger.debug(f"Request payload: {payload}")
 
-        r = requests.post(self.endpoint, json=payload, timeout=120)
+        r = requests.post(self.endpoint, json=payload, headers=self._headers, timeout=120)
         r.raise_for_status()
 
         response_data = r.json()
@@ -104,3 +128,31 @@ class LMStudioDriver(Driver):
         }
 
         return {"text": text, "meta": meta}
+
+    # -- Model management (LM Studio 0.4.0+) ----------------------------------
+
+    def list_models(self) -> list[dict[str, Any]]:
+        """List currently loaded models via GET /v1/models (OpenAI-compatible)."""
+        url = f"{self.base_url}/v1/models"
+        r = requests.get(url, headers=self._headers, timeout=10)
+        r.raise_for_status()
+        data = r.json()
+        return data.get("data", [])
+
+    def load_model(self, model: str, context_length: int | None = None) -> dict[str, Any]:
+        """Load a model into LM Studio via POST /api/v1/models/load."""
+        url = f"{self.base_url}/api/v1/models/load"
+        payload: dict[str, Any] = {"model": model}
+        if context_length is not None:
+            payload["context_length"] = context_length
+        r = requests.post(url, json=payload, headers=self._headers, timeout=120)
+        r.raise_for_status()
+        return r.json()
+
+    def unload_model(self, model: str) -> dict[str, Any]:
+        """Unload a model from LM Studio via POST /api/v1/models/unload."""
+        url = f"{self.base_url}/api/v1/models/unload"
+        payload = {"model": model}
+        r = requests.post(url, json=payload, headers=self._headers, timeout=30)
+        r.raise_for_status()
+        return r.json()
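The synchronous driver gains the same management surface. A brief sketch, assuming a reachable server (the constructor validates the connection) and an illustrative model name:

from prompture.drivers.lmstudio_driver import LMStudioDriver

driver = LMStudioDriver()       # __init__ validates the server connection
for m in driver.list_models():  # OpenAI-compatible /v1/models
    print(m.get("id"))
driver.load_model("llava-v1.6", context_length=4096)  # illustrative model name
driver.unload_model("llava-v1.6")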
prompture/drivers/ollama_driver.py

@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
 class OllamaDriver(Driver):
     supports_json_mode = True
     supports_streaming = True
+    supports_vision = True
 
     # Ollama is free – costs are always zero.
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
@@ -46,6 +47,11 @@ class OllamaDriver(Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_ollama_vision_messages
+
+        return _prepare_ollama_vision_messages(messages)
+
     def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         # Merge instance options with call-specific options
         merged_options = self.options.copy()
@@ -190,6 +196,7 @@ class OllamaDriver(Driver):
 
     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+        messages = self._prepare_messages(messages)
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
prompture/drivers/openai_driver.py

@@ -21,6 +21,7 @@ class OpenAIDriver(CostMixin, Driver):
     supports_json_schema = True
     supports_tool_use = True
     supports_streaming = True
+    supports_vision = True
 
     # Approximate pricing per 1K tokens (keep updated with OpenAI's official pricing)
     # Each model entry also defines which token parameter it supports and
@@ -74,12 +75,17 @@ class OpenAIDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None: