prompture 0.0.35__py3-none-any.whl → 0.0.40.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +132 -3
- prompture/_version.py +2 -2
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +208 -17
- prompture/async_core.py +16 -0
- prompture/async_driver.py +63 -0
- prompture/async_groups.py +551 -0
- prompture/conversation.py +222 -18
- prompture/core.py +46 -12
- prompture/cost_mixin.py +37 -0
- prompture/discovery.py +132 -44
- prompture/driver.py +77 -0
- prompture/drivers/__init__.py +5 -1
- prompture/drivers/async_azure_driver.py +11 -5
- prompture/drivers/async_claude_driver.py +184 -9
- prompture/drivers/async_google_driver.py +222 -28
- prompture/drivers/async_grok_driver.py +11 -5
- prompture/drivers/async_groq_driver.py +11 -5
- prompture/drivers/async_lmstudio_driver.py +74 -5
- prompture/drivers/async_ollama_driver.py +13 -3
- prompture/drivers/async_openai_driver.py +162 -5
- prompture/drivers/async_openrouter_driver.py +11 -5
- prompture/drivers/async_registry.py +5 -1
- prompture/drivers/azure_driver.py +10 -4
- prompture/drivers/claude_driver.py +17 -1
- prompture/drivers/google_driver.py +227 -33
- prompture/drivers/grok_driver.py +11 -5
- prompture/drivers/groq_driver.py +11 -5
- prompture/drivers/lmstudio_driver.py +73 -8
- prompture/drivers/ollama_driver.py +16 -5
- prompture/drivers/openai_driver.py +26 -11
- prompture/drivers/openrouter_driver.py +11 -5
- prompture/drivers/vision_helpers.py +153 -0
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/ledger.py +252 -0
- prompture/model_rates.py +112 -2
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/serialization.py +218 -0
- prompture/settings.py +1 -0
- prompture-0.0.40.dev1.dist-info/METADATA +369 -0
- prompture-0.0.40.dev1.dist-info/RECORD +78 -0
- prompture-0.0.35.dist-info/METADATA +0 -464
- prompture-0.0.35.dist-info/RECORD +0 -66
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/top_level.txt +0 -0
prompture/drivers/grok_driver.py
CHANGED
@@ -13,6 +13,7 @@ from ..driver import Driver

 class GrokDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_vision = True

     # Pricing per 1M tokens based on xAI's documentation
     _PRICING_UNIT = 1_000_000

@@ -80,12 +81,17 @@ class GrokDriver(CostMixin, Driver):

     supports_messages = True

+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)

     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)

     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:

@@ -93,10 +99,10 @@ class GrokDriver(CostMixin, Driver):

         model = options.get("model", self.model)

-        # Lookup model-specific config
-
-        tokens_param =
-        supports_temperature =
+        # Lookup model-specific config (live models.dev data + hardcoded fallback)
+        model_config = self._get_model_config("grok", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]

         # Defaults
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
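With the new _prepare_messages hook, callers can pass prompture's universal image blocks straight into generate_messages and the driver converts them to the OpenAI-style image_url blocks that xAI's API accepts. A minimal sketch of that flow follows; the constructor keyword arguments, the model name, and the ImageContent(media_type=..., data=...) call are assumptions inferred from this diff and the new prompture/image.py module, not a confirmed public API.

# Hypothetical usage sketch -- argument names and ImageContent fields are
# assumptions inferred from the diff, not a confirmed public API.
from prompture.drivers.grok_driver import GrokDriver
from prompture.image import ImageContent  # new module in 0.0.40.dev1

driver = GrokDriver(api_key="XAI_API_KEY", model="grok-2-vision")  # assumed kwargs

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            # Universal block; _prepare_messages rewrites it to
            # {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
            {"type": "image", "source": ImageContent(media_type="image/png", data="iVBORw0...")},
        ],
    }
]

result = driver.generate_messages(messages, {"max_tokens": 256})
print(result)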
prompture/drivers/groq_driver.py
CHANGED
@@ -16,6 +16,7 @@ from ..driver import Driver

 class GroqDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_vision = True

     # Approximate pricing per 1K tokens (to be updated with official pricing)
     # Each model entry defines token parameters and temperature support

@@ -50,12 +51,17 @@ class GroqDriver(CostMixin, Driver):

     supports_messages = True

+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)

     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)

     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:

@@ -63,10 +69,10 @@ class GroqDriver(CostMixin, Driver):

         model = options.get("model", self.model)

-        # Lookup model-specific config
-
-        tokens_param =
-        supports_temperature =
+        # Lookup model-specific config (live models.dev data + hardcoded fallback)
+        model_config = self._get_model_config("groq", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]

         # Base configuration
         opts = {"temperature": 0.7, "max_tokens": 512, **options}
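GroqDriver, GrokDriver, OpenAIDriver, and OpenRouterDriver all replace their hardcoded per-model lookups with self._get_model_config(provider, model) from CostMixin. Only two keys are visible at these call sites, "tokens_param" and "supports_temperature"; the sketch below illustrates how such a config would typically drive the request kwargs. It is an illustration of the pattern only, not the CostMixin implementation, which is outside this diff.

from typing import Any

def build_request_kwargs(model_config: dict[str, Any], opts: dict[str, Any]) -> dict[str, Any]:
    """Illustration only: turn the looked-up config into request kwargs."""
    tokens_param = model_config["tokens_param"]            # e.g. "max_tokens" or "max_completion_tokens"
    supports_temperature = model_config["supports_temperature"]

    kwargs: dict[str, Any] = {tokens_param: opts.get("max_tokens", 512)}
    if supports_temperature:
        kwargs["temperature"] = opts.get("temperature", 0.7)
    return kwargs

# A reasoning-style model that only takes max_completion_tokens and no temperature:
print(build_request_kwargs(
    {"tokens_param": "max_completion_tokens", "supports_temperature": False},
    {"max_tokens": 1024, "temperature": 0.2},
))
# {'max_completion_tokens': 1024}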
prompture/drivers/lmstudio_driver.py
CHANGED

@@ -12,27 +12,47 @@ logger = logging.getLogger(__name__)

 class LMStudioDriver(Driver):
     supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True

     # LM Studio is local – costs are always zero.
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}

-    def __init__(
+    def __init__(
+        self,
+        endpoint: str | None = None,
+        model: str = "deepseek/deepseek-r1-0528-qwen3-8b",
+        api_key: str | None = None,
+    ):
         # Allow override via env var
         self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
         self.model = model
         self.options: dict[str, Any] = {}

+        # Derive base_url once for reuse across management endpoints
+        self.base_url = self.endpoint.split("/v1/")[0]
+
+        # API key for LM Studio 0.4.0+ authentication
+        self.api_key = api_key or os.getenv("LMSTUDIO_API_KEY")
+        self._headers = self._build_headers()
+
         # Validate connection to LM Studio server
         self._validate_connection()

+    def _build_headers(self) -> dict[str, str]:
+        """Build request headers, including auth if an API key is configured."""
+        headers: dict[str, str] = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
     def _validate_connection(self):
         """Validate connection to the LM Studio server."""
         try:
-
-            health_url = f"{base_url}/v1/models"
+            health_url = f"{self.base_url}/v1/models"

             logger.debug(f"Validating connection to LM Studio server at: {health_url}")
-            response = requests.get(health_url, timeout=5)
+            response = requests.get(health_url, headers=self._headers, timeout=5)
             response.raise_for_status()
             logger.debug("Connection to LM Studio server validated successfully")
         except requests.exceptions.RequestException as e:

@@ -40,12 +60,17 @@ class LMStudioDriver(Driver):

     supports_messages = True

+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)

     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)

     def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         merged_options = self.options.copy()

@@ -58,15 +83,27 @@ class LMStudioDriver(Driver):
             "temperature": merged_options.get("temperature", 0.7),
         }

-        # Native JSON mode support
+        # Native JSON mode support (LM Studio requires json_schema, not json_object)
         if merged_options.get("json_mode"):
-
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                payload["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                # No schema provided — omit response_format entirely;
+                # LM Studio rejects "json_object" type.
+                pass

         try:
             logger.debug(f"Sending request to LM Studio endpoint: {self.endpoint}")
             logger.debug(f"Request payload: {payload}")

-            r = requests.post(self.endpoint, json=payload, timeout=120)
+            r = requests.post(self.endpoint, json=payload, headers=self._headers, timeout=120)
             r.raise_for_status()

             response_data = r.json()

@@ -104,3 +141,31 @@ class LMStudioDriver(Driver):
         }

         return {"text": text, "meta": meta}
+
+    # -- Model management (LM Studio 0.4.0+) ----------------------------------
+
+    def list_models(self) -> list[dict[str, Any]]:
+        """List currently loaded models via GET /v1/models (OpenAI-compatible)."""
+        url = f"{self.base_url}/v1/models"
+        r = requests.get(url, headers=self._headers, timeout=10)
+        r.raise_for_status()
+        data = r.json()
+        return data.get("data", [])
+
+    def load_model(self, model: str, context_length: int | None = None) -> dict[str, Any]:
+        """Load a model into LM Studio via POST /api/v1/models/load."""
+        url = f"{self.base_url}/api/v1/models/load"
+        payload: dict[str, Any] = {"model": model}
+        if context_length is not None:
+            payload["context_length"] = context_length
+        r = requests.post(url, json=payload, headers=self._headers, timeout=120)
+        r.raise_for_status()
+        return r.json()
+
+    def unload_model(self, model: str) -> dict[str, Any]:
+        """Unload a model from LM Studio via POST /api/v1/models/unload."""
+        url = f"{self.base_url}/api/v1/models/unload"
+        payload = {"instance_id": model}
+        r = requests.post(url, json=payload, headers=self._headers, timeout=30)
+        r.raise_for_status()
+        return r.json()
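Taken together, the LM Studio changes add bearer-token auth (LM Studio 0.4.0+), a cached base_url, and three management helpers that wrap the server's REST endpoints. A usage sketch follows; it assumes a running LM Studio server, and the model ids and API key are placeholders.

from prompture.drivers.lmstudio_driver import LMStudioDriver

# The constructor validates the connection, so the server must be reachable.
driver = LMStudioDriver(
    endpoint="http://127.0.0.1:1234/v1/chat/completions",
    api_key="lmstudio-local-key",  # or set LMSTUDIO_API_KEY; omit if auth is disabled
)

# GET {base_url}/v1/models -- OpenAI-compatible list of loaded models
for entry in driver.list_models():
    print(entry.get("id"))

# POST {base_url}/api/v1/models/load, optionally with a context_length
driver.load_model("qwen/qwen2.5-vl-7b", context_length=8192)  # placeholder model id

# POST {base_url}/api/v1/models/unload -- the payload field is "instance_id"
driver.unload_model("qwen/qwen2.5-vl-7b")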
prompture/drivers/ollama_driver.py
CHANGED

@@ -13,7 +13,9 @@ logger = logging.getLogger(__name__)

 class OllamaDriver(Driver):
     supports_json_mode = True
+    supports_json_schema = True
     supports_streaming = True
+    supports_vision = True

     # Ollama is free – costs are always zero.
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}

@@ -46,6 +48,11 @@ class OllamaDriver(Driver):

     supports_messages = True

+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_ollama_vision_messages
+
+        return _prepare_ollama_vision_messages(messages)
+
     def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         # Merge instance options with call-specific options
         merged_options = self.options.copy()

@@ -58,9 +65,10 @@ class OllamaDriver(Driver):
             "stream": False,
         }

-        # Native JSON mode support
+        # Native JSON mode / structured output support
         if merged_options.get("json_mode"):
-
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"

         # Add any Ollama-specific options from merged_options
         if "temperature" in merged_options:

@@ -146,7 +154,8 @@
         }

         if merged_options.get("json_mode"):
-
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
         if "top_p" in merged_options:

@@ -190,6 +199,7 @@

     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+        messages = self._prepare_messages(messages)
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)

@@ -203,9 +213,10 @@
             "stream": False,
         }

-        # Native JSON mode support
+        # Native JSON mode / structured output support
         if merged_options.get("json_mode"):
-
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"

         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
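For Ollama, json_mode now produces one of two request shapes: when the caller also supplies a json_schema option, the schema itself is sent in Ollama's "format" field (structured outputs); without one, the driver falls back to the plain "json" format. Images, meanwhile, are moved out of the content blocks into Ollama's top-level "images" list of base64 strings by _prepare_ollama_vision_messages. The sketch below shows the two payload variants implied by the diff; the model name is a placeholder.

# Sketch of the two request bodies the driver builds (model name is a placeholder).
schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

structured_payload = {
    "model": "llama3.2",
    "stream": False,
    "format": schema,      # json_mode=True with a json_schema option
}

plain_json_payload = {
    "model": "llama3.2",
    "stream": False,
    "format": "json",      # json_mode=True without a schema
}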
prompture/drivers/openai_driver.py
CHANGED

@@ -21,6 +21,7 @@ class OpenAIDriver(CostMixin, Driver):
     supports_json_schema = True
     supports_tool_use = True
     supports_streaming = True
+    supports_vision = True

     # Approximate pricing per 1K tokens (keep updated with OpenAI's official pricing)
     # Each model entry also defines which token parameter it supports and

@@ -74,12 +75,17 @@ class OpenAIDriver(CostMixin, Driver):

     supports_messages = True

+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)

     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)

     def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:

@@ -87,10 +93,17 @@ class OpenAIDriver(CostMixin, Driver):

         model = options.get("model", self.model)

-        # Lookup model-specific config
-
-        tokens_param =
-        supports_temperature =
+        # Lookup model-specific config (live models.dev data + hardcoded fallback)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "openai",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )

         # Defaults
         opts = {"temperature": 1.0, "max_tokens": 512, **options}

@@ -162,9 +175,11 @@ class OpenAIDriver(CostMixin, Driver):
             raise RuntimeError("openai package (>=1.0.0) is not installed")

         model = options.get("model", self.model)
-
-        tokens_param =
-        supports_temperature =
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("openai", model, using_tool_use=True)

         opts = {"temperature": 1.0, "max_tokens": 512, **options}

@@ -233,9 +248,9 @@ class OpenAIDriver(CostMixin, Driver):
             raise RuntimeError("openai package (>=1.0.0) is not installed")

         model = options.get("model", self.model)
-
-        tokens_param =
-        supports_temperature =
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]

         opts = {"temperature": 1.0, "max_tokens": 512, **options}
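Beyond the shared vision and model-config changes, OpenAIDriver now calls _validate_model_capabilities before each request, passing using_json_schema when the caller supplies a json_schema option and using_tool_use=True on the tool-calling path. The mixin's body is not part of this diff; the sketch below only illustrates the kind of gate those call sites imply, with an assumed models.dev-style capabilities dict.

def validate_capabilities(
    capabilities: dict[str, bool],
    *,
    using_json_schema: bool = False,
    using_tool_use: bool = False,
) -> None:
    """Illustration of the check implied by the call sites; not the CostMixin code."""
    if using_json_schema and not capabilities.get("json_schema", True):
        raise ValueError("model does not advertise structured-output support")
    if using_tool_use and not capabilities.get("tool_use", True):
        raise ValueError("model does not advertise tool-use support")

# Example with an assumed capability record:
validate_capabilities({"json_schema": True, "tool_use": False}, using_json_schema=True)   # passes
# validate_capabilities({"tool_use": False}, using_tool_use=True)  # would raise ValueError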
prompture/drivers/openrouter_driver.py
CHANGED

@@ -13,6 +13,7 @@ from ..driver import Driver

 class OpenRouterDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_vision = True

     # Approximate pricing per 1K tokens based on OpenRouter's pricing
     # https://openrouter.ai/docs#pricing

@@ -66,12 +67,17 @@ class OpenRouterDriver(CostMixin, Driver):

     supports_messages = True

+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)

     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)

     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:

@@ -79,10 +85,10 @@ class OpenRouterDriver(CostMixin, Driver):

         model = options.get("model", self.model)

-        # Lookup model-specific config
-
-        tokens_param =
-        supports_temperature =
+        # Lookup model-specific config (live models.dev data + hardcoded fallback)
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]

         # Defaults
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
prompture/drivers/vision_helpers.py
ADDED

@@ -0,0 +1,153 @@
+"""Shared helpers for converting universal vision message blocks to provider-specific formats."""
+
+from __future__ import annotations
+
+from typing import Any
+
+
+def _prepare_openai_vision_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    """Convert universal image blocks to OpenAI-compatible vision format.
+
+    Works for OpenAI, Azure, Groq, Grok, LM Studio, and OpenRouter.
+
+    Universal format::
+
+        {"type": "image", "source": ImageContent(...)}
+
+    OpenAI format::
+
+        {"type": "image_url", "image_url": {"url": "data:mime;base64,..."}}
+    """
+    out: list[dict[str, Any]] = []
+    for msg in messages:
+        content = msg.get("content")
+        if not isinstance(content, list):
+            out.append(msg)
+            continue
+        new_blocks: list[dict[str, Any]] = []
+        for block in content:
+            if isinstance(block, dict) and block.get("type") == "image":
+                source = block["source"]
+                if source.source_type == "url" and source.url:
+                    url = source.url
+                else:
+                    url = f"data:{source.media_type};base64,{source.data}"
+                new_blocks.append(
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": url},
+                    }
+                )
+            else:
+                new_blocks.append(block)
+        out.append({**msg, "content": new_blocks})
+    return out
+
+
+def _prepare_claude_vision_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    """Convert universal image blocks to Anthropic Claude format.
+
+    Claude format::
+
+        {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
+    """
+    out: list[dict[str, Any]] = []
+    for msg in messages:
+        content = msg.get("content")
+        if not isinstance(content, list):
+            out.append(msg)
+            continue
+        new_blocks: list[dict[str, Any]] = []
+        for block in content:
+            if isinstance(block, dict) and block.get("type") == "image":
+                source = block["source"]
+                if source.source_type == "url" and source.url:
+                    new_blocks.append(
+                        {
+                            "type": "image",
+                            "source": {
+                                "type": "url",
+                                "url": source.url,
+                            },
+                        }
+                    )
+                else:
+                    new_blocks.append(
+                        {
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": source.media_type,
+                                "data": source.data,
+                            },
+                        }
+                    )
+            else:
+                new_blocks.append(block)
+        out.append({**msg, "content": new_blocks})
+    return out
+
+
+def _prepare_google_vision_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    """Convert universal image blocks to Google Gemini format.
+
+    Gemini expects ``parts`` arrays containing text and inline_data dicts::
+
+        {"role": "user", "parts": [
+            "text prompt",
+            {"inline_data": {"mime_type": "image/png", "data": "base64..."}},
+        ]}
+    """
+    out: list[dict[str, Any]] = []
+    for msg in messages:
+        content = msg.get("content")
+        if not isinstance(content, list):
+            out.append(msg)
+            continue
+        # Convert content blocks to Gemini parts
+        parts: list[Any] = []
+        for block in content:
+            if isinstance(block, dict) and block.get("type") == "text":
+                parts.append(block["text"])
+            elif isinstance(block, dict) and block.get("type") == "image":
+                source = block["source"]
+                parts.append(
+                    {
+                        "inline_data": {
+                            "mime_type": source.media_type,
+                            "data": source.data,
+                        }
+                    }
+                )
+            else:
+                parts.append(block)
+        out.append({**msg, "content": parts, "_vision_parts": True})
+    return out
+
+
+def _prepare_ollama_vision_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    """Convert universal image blocks to Ollama format.
+
+    Ollama expects images as a separate field::
+
+        {"role": "user", "content": "text", "images": ["base64..."]}
+    """
+    out: list[dict[str, Any]] = []
+    for msg in messages:
+        content = msg.get("content")
+        if not isinstance(content, list):
+            out.append(msg)
+            continue
+        text_parts: list[str] = []
+        images: list[str] = []
+        for block in content:
+            if isinstance(block, dict) and block.get("type") == "text":
+                text_parts.append(block["text"])
+            elif isinstance(block, dict) and block.get("type") == "image":
+                source = block["source"]
+                images.append(source.data)
+        new_msg = {**msg, "content": " ".join(text_parts)}
+        if images:
+            new_msg["images"] = images
+        out.append(new_msg)
+    return out
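All of the new helpers consume the same universal block, {"type": "image", "source": ImageContent(...)}, and differ only in the provider shape they emit. The demonstration below runs two of them on one message, using a small stand-in class with the attributes the helpers actually read (source_type, url, media_type, data); the real ImageContent lives in prompture/image.py and may differ.

from dataclasses import dataclass

from prompture.drivers.vision_helpers import (
    _prepare_ollama_vision_messages,
    _prepare_openai_vision_messages,
)

@dataclass
class FakeImageContent:
    # Stand-in for prompture.image.ImageContent (attribute names taken from the helpers).
    source_type: str = "base64"
    url: str | None = None
    media_type: str = "image/png"
    data: str = "iVBORw0KGgo..."  # truncated base64 payload

msg = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image", "source": FakeImageContent()},
    ],
}]

print(_prepare_openai_vision_messages(msg)[0]["content"][1])
# {'type': 'image_url', 'image_url': {'url': 'data:image/png;base64,iVBORw0KGgo...'}}

print(_prepare_ollama_vision_messages(msg)[0])
# {'role': 'user', 'content': 'What is in this picture?', 'images': ['iVBORw0KGgo...']}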