prompture 0.0.35__py3-none-any.whl → 0.0.40.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +132 -3
- prompture/_version.py +2 -2
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +208 -17
- prompture/async_core.py +16 -0
- prompture/async_driver.py +63 -0
- prompture/async_groups.py +551 -0
- prompture/conversation.py +222 -18
- prompture/core.py +46 -12
- prompture/cost_mixin.py +37 -0
- prompture/discovery.py +132 -44
- prompture/driver.py +77 -0
- prompture/drivers/__init__.py +5 -1
- prompture/drivers/async_azure_driver.py +11 -5
- prompture/drivers/async_claude_driver.py +184 -9
- prompture/drivers/async_google_driver.py +222 -28
- prompture/drivers/async_grok_driver.py +11 -5
- prompture/drivers/async_groq_driver.py +11 -5
- prompture/drivers/async_lmstudio_driver.py +74 -5
- prompture/drivers/async_ollama_driver.py +13 -3
- prompture/drivers/async_openai_driver.py +162 -5
- prompture/drivers/async_openrouter_driver.py +11 -5
- prompture/drivers/async_registry.py +5 -1
- prompture/drivers/azure_driver.py +10 -4
- prompture/drivers/claude_driver.py +17 -1
- prompture/drivers/google_driver.py +227 -33
- prompture/drivers/grok_driver.py +11 -5
- prompture/drivers/groq_driver.py +11 -5
- prompture/drivers/lmstudio_driver.py +73 -8
- prompture/drivers/ollama_driver.py +16 -5
- prompture/drivers/openai_driver.py +26 -11
- prompture/drivers/openrouter_driver.py +11 -5
- prompture/drivers/vision_helpers.py +153 -0
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/ledger.py +252 -0
- prompture/model_rates.py +112 -2
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/serialization.py +218 -0
- prompture/settings.py +1 -0
- prompture-0.0.40.dev1.dist-info/METADATA +369 -0
- prompture-0.0.40.dev1.dist-info/RECORD +78 -0
- prompture-0.0.35.dist-info/METADATA +0 -464
- prompture-0.0.35.dist-info/RECORD +0 -66
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/top_level.txt +0 -0
prompture/discovery.py
CHANGED
|
@@ -1,7 +1,11 @@
|
|
|
1
1
|
"""Discovery module for auto-detecting available models."""
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import dataclasses
|
|
3
6
|
import logging
|
|
4
7
|
import os
|
|
8
|
+
from typing import Any, overload
|
|
5
9
|
|
|
6
10
|
import requests
|
|
7
11
|
|
|
@@ -22,23 +26,40 @@ from .settings import settings
|
|
|
22
26
|
logger = logging.getLogger(__name__)
|
|
23
27
|
|
|
24
28
|
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
29
|
+
@overload
|
|
30
|
+
def get_available_models(*, include_capabilities: bool = False, verified_only: bool = False) -> list[str]: ...
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@overload
|
|
34
|
+
def get_available_models(*, include_capabilities: bool = True, verified_only: bool = False) -> list[dict[str, Any]]: ...
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_available_models(
|
|
38
|
+
*,
|
|
39
|
+
include_capabilities: bool = False,
|
|
40
|
+
verified_only: bool = False,
|
|
41
|
+
) -> list[str] | list[dict[str, Any]]:
|
|
42
|
+
"""Auto-detect available models based on configured drivers and environment variables.
|
|
28
43
|
|
|
29
|
-
Iterates through supported providers and checks if they are configured
|
|
30
|
-
For static drivers, returns models from their
|
|
31
|
-
For dynamic drivers (like Ollama), attempts to
|
|
44
|
+
Iterates through supported providers and checks if they are configured
|
|
45
|
+
(e.g. API key present). For static drivers, returns models from their
|
|
46
|
+
``MODEL_PRICING`` keys. For dynamic drivers (like Ollama), attempts to
|
|
47
|
+
fetch available models from the endpoint.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
include_capabilities: When ``True``, return enriched dicts with
|
|
51
|
+
``model``, ``provider``, ``model_id``, and ``capabilities``
|
|
52
|
+
fields instead of plain ``"provider/model_id"`` strings.
|
|
53
|
+
verified_only: When ``True``, only return models that have been
|
|
54
|
+
successfully used (as recorded by the usage ledger).
|
|
32
55
|
|
|
33
56
|
Returns:
|
|
34
|
-
A list of unique model strings
|
|
57
|
+
A sorted list of unique model strings (default) or enriched dicts.
|
|
35
58
|
"""
|
|
36
59
|
available_models: set[str] = set()
|
|
37
60
|
configured_providers: set[str] = set()
|
|
38
61
|
|
|
39
62
|
# Map of provider name to driver class
|
|
40
|
-
# We need to map the registry keys to the actual classes to check MODEL_PRICING
|
|
41
|
-
# and instantiate for dynamic checks if needed.
|
|
42
63
|
provider_classes = {
|
|
43
64
|
"openai": OpenAIDriver,
|
|
44
65
|
"azure": AzureDriver,
|
|
@@ -54,11 +75,6 @@ def get_available_models() -> list[str]:
|
|
|
54
75
|
|
|
55
76
|
for provider, driver_cls in provider_classes.items():
|
|
56
77
|
try:
|
|
57
|
-
# 1. Check if the provider is configured (has API key or endpoint)
|
|
58
|
-
# We can check this by looking at the settings or env vars that the driver uses.
|
|
59
|
-
# A simple way is to try to instantiate it with defaults, but that might fail if keys are missing.
|
|
60
|
-
# Instead, let's check the specific requirements for each known provider.
|
|
61
|
-
|
|
62
78
|
is_configured = False
|
|
63
79
|
|
|
64
80
|
if provider == "openai":
|
|
@@ -86,14 +102,11 @@ def get_available_models() -> list[str]:
|
|
|
86
102
|
elif provider == "grok":
|
|
87
103
|
if settings.grok_api_key or os.getenv("GROK_API_KEY"):
|
|
88
104
|
is_configured = True
|
|
89
|
-
elif
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
# LM Studio is similar to Ollama, defaults to localhost
|
|
95
|
-
is_configured = True
|
|
96
|
-
elif provider == "local_http" and (settings.local_http_endpoint or os.getenv("LOCAL_HTTP_ENDPOINT")):
|
|
105
|
+
elif (
|
|
106
|
+
provider == "ollama"
|
|
107
|
+
or provider == "lmstudio"
|
|
108
|
+
or (provider == "local_http" and os.getenv("LOCAL_HTTP_ENDPOINT"))
|
|
109
|
+
):
|
|
97
110
|
is_configured = True
|
|
98
111
|
|
|
99
112
|
if not is_configured:
|
|
@@ -101,36 +114,20 @@ def get_available_models() -> list[str]:
|
|
|
101
114
|
|
|
102
115
|
configured_providers.add(provider)
|
|
103
116
|
|
|
104
|
-
#
|
|
117
|
+
# Static Detection: Get models from MODEL_PRICING
|
|
105
118
|
if hasattr(driver_cls, "MODEL_PRICING"):
|
|
106
119
|
pricing = driver_cls.MODEL_PRICING
|
|
107
120
|
for model_id in pricing:
|
|
108
|
-
# Skip "default" or generic keys if they exist
|
|
109
121
|
if model_id == "default":
|
|
110
122
|
continue
|
|
111
|
-
|
|
112
|
-
# For Azure, the model_id in pricing is usually the base model name,
|
|
113
|
-
# but the user needs to use the deployment ID.
|
|
114
|
-
# However, our Azure driver implementation uses the deployment_id from init
|
|
115
|
-
# as the "model" for the request, but expects the user to pass a model name
|
|
116
|
-
# that maps to pricing?
|
|
117
|
-
# Looking at AzureDriver:
|
|
118
|
-
# kwargs = {"model": self.deployment_id, ...}
|
|
119
|
-
# model = options.get("model", self.model) -> used for pricing lookup
|
|
120
|
-
# So we should list the keys in MODEL_PRICING as available "models"
|
|
121
|
-
# even though for Azure specifically it's a bit weird because of deployment IDs.
|
|
122
|
-
# But for general discovery, listing supported models is correct.
|
|
123
|
-
|
|
124
123
|
available_models.add(f"{provider}/{model_id}")
|
|
125
124
|
|
|
126
|
-
#
|
|
125
|
+
# Dynamic Detection: Specific logic for Ollama
|
|
127
126
|
if provider == "ollama":
|
|
128
127
|
try:
|
|
129
128
|
endpoint = settings.ollama_endpoint or os.getenv(
|
|
130
129
|
"OLLAMA_ENDPOINT", "http://localhost:11434/api/generate"
|
|
131
130
|
)
|
|
132
|
-
# We need the base URL for tags, usually http://localhost:11434/api/tags
|
|
133
|
-
# The configured endpoint might be .../api/generate or .../api/chat
|
|
134
131
|
base_url = endpoint.split("/api/")[0]
|
|
135
132
|
tags_url = f"{base_url}/api/tags"
|
|
136
133
|
|
|
@@ -141,13 +138,34 @@ def get_available_models() -> list[str]:
|
|
|
141
138
|
for model in models:
|
|
142
139
|
name = model.get("name")
|
|
143
140
|
if name:
|
|
144
|
-
# Ollama model names often include tags like "llama3:latest"
|
|
145
|
-
# We can keep them as is.
|
|
146
141
|
available_models.add(f"ollama/{name}")
|
|
147
142
|
except Exception as e:
|
|
148
143
|
logger.debug(f"Failed to fetch Ollama models: {e}")
|
|
149
144
|
|
|
150
|
-
#
|
|
145
|
+
# Dynamic Detection: LM Studio loaded models
|
|
146
|
+
if provider == "lmstudio":
|
|
147
|
+
try:
|
|
148
|
+
endpoint = settings.lmstudio_endpoint or os.getenv(
|
|
149
|
+
"LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions"
|
|
150
|
+
)
|
|
151
|
+
base_url = endpoint.split("/v1/")[0]
|
|
152
|
+
models_url = f"{base_url}/v1/models"
|
|
153
|
+
|
|
154
|
+
headers: dict[str, str] = {}
|
|
155
|
+
api_key = settings.lmstudio_api_key or os.getenv("LMSTUDIO_API_KEY")
|
|
156
|
+
if api_key:
|
|
157
|
+
headers["Authorization"] = f"Bearer {api_key}"
|
|
158
|
+
|
|
159
|
+
resp = requests.get(models_url, headers=headers, timeout=2)
|
|
160
|
+
if resp.status_code == 200:
|
|
161
|
+
data = resp.json()
|
|
162
|
+
models = data.get("data", [])
|
|
163
|
+
for model in models:
|
|
164
|
+
model_id = model.get("id")
|
|
165
|
+
if model_id:
|
|
166
|
+
available_models.add(f"lmstudio/{model_id}")
|
|
167
|
+
except Exception as e:
|
|
168
|
+
logger.debug(f"Failed to fetch LM Studio models: {e}")
|
|
151
169
|
|
|
152
170
|
except Exception as e:
|
|
153
171
|
logger.warning(f"Error detecting models for provider {provider}: {e}")
|
|
@@ -161,4 +179,74 @@ def get_available_models() -> list[str]:
|
|
|
161
179
|
for model_id in get_all_provider_models(api_name):
|
|
162
180
|
available_models.add(f"{prompture_name}/{model_id}")
|
|
163
181
|
|
|
164
|
-
|
|
182
|
+
sorted_models = sorted(available_models)
|
|
183
|
+
|
|
184
|
+
# --- verified_only filtering ---
|
|
185
|
+
verified_set: set[str] | None = None
|
|
186
|
+
if verified_only or include_capabilities:
|
|
187
|
+
try:
|
|
188
|
+
from .ledger import _get_ledger
|
|
189
|
+
|
|
190
|
+
ledger = _get_ledger()
|
|
191
|
+
verified_set = ledger.get_verified_models()
|
|
192
|
+
except Exception:
|
|
193
|
+
logger.debug("Could not load ledger for verified models", exc_info=True)
|
|
194
|
+
verified_set = set()
|
|
195
|
+
|
|
196
|
+
if verified_only and verified_set is not None:
|
|
197
|
+
sorted_models = [m for m in sorted_models if m in verified_set]
|
|
198
|
+
|
|
199
|
+
if not include_capabilities:
|
|
200
|
+
return sorted_models
|
|
201
|
+
|
|
202
|
+
# Build enriched dicts with capabilities from models.dev
|
|
203
|
+
from .model_rates import get_model_capabilities
|
|
204
|
+
|
|
205
|
+
# Fetch all ledger stats for annotation (keyed by model_name)
|
|
206
|
+
ledger_stats: dict[str, dict[str, Any]] = {}
|
|
207
|
+
try:
|
|
208
|
+
from .ledger import _get_ledger
|
|
209
|
+
|
|
210
|
+
for row in _get_ledger().get_all_stats():
|
|
211
|
+
name = row["model_name"]
|
|
212
|
+
if name not in ledger_stats:
|
|
213
|
+
ledger_stats[name] = row
|
|
214
|
+
else:
|
|
215
|
+
# Aggregate across API key hashes
|
|
216
|
+
existing = ledger_stats[name]
|
|
217
|
+
existing["use_count"] += row["use_count"]
|
|
218
|
+
existing["total_tokens"] += row["total_tokens"]
|
|
219
|
+
existing["total_cost"] += row["total_cost"]
|
|
220
|
+
if row["last_used"] > existing["last_used"]:
|
|
221
|
+
existing["last_used"] = row["last_used"]
|
|
222
|
+
except Exception:
|
|
223
|
+
logger.debug("Could not load ledger stats for enrichment", exc_info=True)
|
|
224
|
+
|
|
225
|
+
enriched: list[dict[str, Any]] = []
|
|
226
|
+
for model_str in sorted_models:
|
|
227
|
+
parts = model_str.split("/", 1)
|
|
228
|
+
provider = parts[0]
|
|
229
|
+
model_id = parts[1] if len(parts) > 1 else parts[0]
|
|
230
|
+
|
|
231
|
+
caps = get_model_capabilities(provider, model_id)
|
|
232
|
+
caps_dict = dataclasses.asdict(caps) if caps is not None else None
|
|
233
|
+
|
|
234
|
+
entry: dict[str, Any] = {
|
|
235
|
+
"model": model_str,
|
|
236
|
+
"provider": provider,
|
|
237
|
+
"model_id": model_id,
|
|
238
|
+
"capabilities": caps_dict,
|
|
239
|
+
"verified": verified_set is not None and model_str in verified_set,
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
stats = ledger_stats.get(model_str)
|
|
243
|
+
if stats:
|
|
244
|
+
entry["last_used"] = stats["last_used"]
|
|
245
|
+
entry["use_count"] = stats["use_count"]
|
|
246
|
+
else:
|
|
247
|
+
entry["last_used"] = None
|
|
248
|
+
entry["use_count"] = 0
|
|
249
|
+
|
|
250
|
+
enriched.append(entry)
|
|
251
|
+
|
|
252
|
+
return enriched
|
prompture/driver.py
CHANGED
|
@@ -35,6 +35,7 @@ class Driver:
|
|
|
35
35
|
supports_messages: bool = False
|
|
36
36
|
supports_tool_use: bool = False
|
|
37
37
|
supports_streaming: bool = False
|
|
38
|
+
supports_vision: bool = False
|
|
38
39
|
|
|
39
40
|
callbacks: DriverCallbacks | None = None
|
|
40
41
|
|
|
@@ -52,6 +53,7 @@ class Driver:
|
|
|
52
53
|
support message arrays should override this method and set
|
|
53
54
|
``supports_messages = True``.
|
|
54
55
|
"""
|
|
56
|
+
self._check_vision_support(messages)
|
|
55
57
|
prompt = self._flatten_messages(messages)
|
|
56
58
|
return self.generate(prompt, options)
|
|
57
59
|
|
|
@@ -171,6 +173,69 @@ class Driver:
|
|
|
171
173
|
except Exception:
|
|
172
174
|
logger.exception("Callback %s raised an exception", event)
|
|
173
175
|
|
|
176
|
+
def _validate_model_capabilities(
|
|
177
|
+
self,
|
|
178
|
+
provider: str,
|
|
179
|
+
model: str,
|
|
180
|
+
*,
|
|
181
|
+
using_tool_use: bool = False,
|
|
182
|
+
using_json_schema: bool = False,
|
|
183
|
+
using_vision: bool = False,
|
|
184
|
+
) -> None:
|
|
185
|
+
"""Log warnings when the model may not support a requested feature.
|
|
186
|
+
|
|
187
|
+
Uses models.dev metadata as a secondary signal. Warnings only — the
|
|
188
|
+
API is the final authority and models.dev data may be stale.
|
|
189
|
+
"""
|
|
190
|
+
from .model_rates import get_model_capabilities
|
|
191
|
+
|
|
192
|
+
caps = get_model_capabilities(provider, model)
|
|
193
|
+
if caps is None:
|
|
194
|
+
return
|
|
195
|
+
|
|
196
|
+
if using_tool_use and caps.supports_tool_use is False:
|
|
197
|
+
logger.warning(
|
|
198
|
+
"Model %s/%s may not support tool use according to models.dev metadata",
|
|
199
|
+
provider,
|
|
200
|
+
model,
|
|
201
|
+
)
|
|
202
|
+
if using_json_schema and caps.supports_structured_output is False:
|
|
203
|
+
logger.warning(
|
|
204
|
+
"Model %s/%s may not support structured output / JSON schema according to models.dev metadata",
|
|
205
|
+
provider,
|
|
206
|
+
model,
|
|
207
|
+
)
|
|
208
|
+
if using_vision and caps.supports_vision is False:
|
|
209
|
+
logger.warning(
|
|
210
|
+
"Model %s/%s may not support vision/image inputs according to models.dev metadata",
|
|
211
|
+
provider,
|
|
212
|
+
model,
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
def _check_vision_support(self, messages: list[dict[str, Any]]) -> None:
|
|
216
|
+
"""Raise if messages contain image blocks and the driver lacks vision support."""
|
|
217
|
+
if self.supports_vision:
|
|
218
|
+
return
|
|
219
|
+
for msg in messages:
|
|
220
|
+
content = msg.get("content")
|
|
221
|
+
if isinstance(content, list):
|
|
222
|
+
for block in content:
|
|
223
|
+
if isinstance(block, dict) and block.get("type") == "image":
|
|
224
|
+
raise NotImplementedError(
|
|
225
|
+
f"{self.__class__.__name__} does not support vision/image inputs. "
|
|
226
|
+
"Use a vision-capable model."
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
230
|
+
"""Transform universal message format into provider-specific wire format.
|
|
231
|
+
|
|
232
|
+
Vision-capable drivers override this to convert the universal image
|
|
233
|
+
blocks into their provider-specific format. The base implementation
|
|
234
|
+
validates vision support and returns messages unchanged.
|
|
235
|
+
"""
|
|
236
|
+
self._check_vision_support(messages)
|
|
237
|
+
return messages
|
|
238
|
+
|
|
174
239
|
@staticmethod
|
|
175
240
|
def _flatten_messages(messages: list[dict[str, Any]]) -> str:
|
|
176
241
|
"""Join messages into a single prompt string with role prefixes."""
|
|
@@ -178,6 +243,18 @@ class Driver:
|
|
|
178
243
|
for msg in messages:
|
|
179
244
|
role = msg.get("role", "user")
|
|
180
245
|
content = msg.get("content", "")
|
|
246
|
+
# Handle content that is a list of blocks (vision messages)
|
|
247
|
+
if isinstance(content, list):
|
|
248
|
+
text_parts = []
|
|
249
|
+
for block in content:
|
|
250
|
+
if isinstance(block, dict):
|
|
251
|
+
if block.get("type") == "text":
|
|
252
|
+
text_parts.append(block.get("text", ""))
|
|
253
|
+
elif block.get("type") == "image":
|
|
254
|
+
text_parts.append("[image]")
|
|
255
|
+
elif isinstance(block, str):
|
|
256
|
+
text_parts.append(block)
|
|
257
|
+
content = " ".join(text_parts)
|
|
181
258
|
if role == "system":
|
|
182
259
|
parts.append(f"[System]: {content}")
|
|
183
260
|
elif role == "assistant":
|
prompture/drivers/__init__.py
CHANGED
|
@@ -84,7 +84,11 @@ register_driver(
|
|
|
84
84
|
)
|
|
85
85
|
register_driver(
|
|
86
86
|
"lmstudio",
|
|
87
|
-
lambda model=None: LMStudioDriver(
|
|
87
|
+
lambda model=None: LMStudioDriver(
|
|
88
|
+
endpoint=settings.lmstudio_endpoint,
|
|
89
|
+
model=model or settings.lmstudio_model,
|
|
90
|
+
api_key=settings.lmstudio_api_key,
|
|
91
|
+
),
|
|
88
92
|
overwrite=True,
|
|
89
93
|
)
|
|
90
94
|
register_driver(
|
|
@@ -18,6 +18,7 @@ from .azure_driver import AzureDriver
|
|
|
18
18
|
class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
19
19
|
supports_json_mode = True
|
|
20
20
|
supports_json_schema = True
|
|
21
|
+
supports_vision = True
|
|
21
22
|
|
|
22
23
|
MODEL_PRICING = AzureDriver.MODEL_PRICING
|
|
23
24
|
|
|
@@ -52,21 +53,26 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
52
53
|
|
|
53
54
|
supports_messages = True
|
|
54
55
|
|
|
56
|
+
def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
57
|
+
from .vision_helpers import _prepare_openai_vision_messages
|
|
58
|
+
|
|
59
|
+
return _prepare_openai_vision_messages(messages)
|
|
60
|
+
|
|
55
61
|
async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
|
|
56
62
|
messages = [{"role": "user", "content": prompt}]
|
|
57
63
|
return await self._do_generate(messages, options)
|
|
58
64
|
|
|
59
65
|
async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
|
|
60
|
-
return await self._do_generate(messages, options)
|
|
66
|
+
return await self._do_generate(self._prepare_messages(messages), options)
|
|
61
67
|
|
|
62
68
|
async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
|
|
63
69
|
if self.client is None:
|
|
64
70
|
raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")
|
|
65
71
|
|
|
66
72
|
model = options.get("model", self.model)
|
|
67
|
-
|
|
68
|
-
tokens_param =
|
|
69
|
-
supports_temperature =
|
|
73
|
+
model_config = self._get_model_config("azure", model)
|
|
74
|
+
tokens_param = model_config["tokens_param"]
|
|
75
|
+
supports_temperature = model_config["supports_temperature"]
|
|
70
76
|
|
|
71
77
|
opts = {"temperature": 1.0, "max_tokens": 512, **options}
|
|
72
78
|
|
|
@@ -107,7 +113,7 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
107
113
|
"prompt_tokens": prompt_tokens,
|
|
108
114
|
"completion_tokens": completion_tokens,
|
|
109
115
|
"total_tokens": total_tokens,
|
|
110
|
-
"cost": total_cost,
|
|
116
|
+
"cost": round(total_cost, 6),
|
|
111
117
|
"raw_response": resp.model_dump(),
|
|
112
118
|
"model_name": model,
|
|
113
119
|
"deployment_id": self.deployment_id,
|
|
@@ -4,6 +4,7 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import json
|
|
6
6
|
import os
|
|
7
|
+
from collections.abc import AsyncIterator
|
|
7
8
|
from typing import Any
|
|
8
9
|
|
|
9
10
|
try:
|
|
@@ -19,6 +20,9 @@ from .claude_driver import ClaudeDriver
|
|
|
19
20
|
class AsyncClaudeDriver(CostMixin, AsyncDriver):
|
|
20
21
|
supports_json_mode = True
|
|
21
22
|
supports_json_schema = True
|
|
23
|
+
supports_tool_use = True
|
|
24
|
+
supports_streaming = True
|
|
25
|
+
supports_vision = True
|
|
22
26
|
|
|
23
27
|
MODEL_PRICING = ClaudeDriver.MODEL_PRICING
|
|
24
28
|
|
|
@@ -28,12 +32,17 @@ class AsyncClaudeDriver(CostMixin, AsyncDriver):
|
|
|
28
32
|
|
|
29
33
|
supports_messages = True
|
|
30
34
|
|
|
35
|
+
def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
36
|
+
from .vision_helpers import _prepare_claude_vision_messages
|
|
37
|
+
|
|
38
|
+
return _prepare_claude_vision_messages(messages)
|
|
39
|
+
|
|
31
40
|
async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
|
|
32
41
|
messages = [{"role": "user", "content": prompt}]
|
|
33
42
|
return await self._do_generate(messages, options)
|
|
34
43
|
|
|
35
44
|
async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
|
|
36
|
-
return await self._do_generate(messages, options)
|
|
45
|
+
return await self._do_generate(self._prepare_messages(messages), options)
|
|
37
46
|
|
|
38
47
|
async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
|
|
39
48
|
if anthropic is None:
|
|
@@ -42,16 +51,17 @@ class AsyncClaudeDriver(CostMixin, AsyncDriver):
|
|
|
42
51
|
opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
|
|
43
52
|
model = options.get("model", self.model)
|
|
44
53
|
|
|
54
|
+
# Validate capabilities against models.dev metadata
|
|
55
|
+
self._validate_model_capabilities(
|
|
56
|
+
"claude",
|
|
57
|
+
model,
|
|
58
|
+
using_json_schema=bool(options.get("json_schema")),
|
|
59
|
+
)
|
|
60
|
+
|
|
45
61
|
client = anthropic.AsyncAnthropic(api_key=self.api_key)
|
|
46
62
|
|
|
47
63
|
# Anthropic requires system messages as a top-level parameter
|
|
48
|
-
system_content =
|
|
49
|
-
api_messages = []
|
|
50
|
-
for msg in messages:
|
|
51
|
-
if msg.get("role") == "system":
|
|
52
|
-
system_content = msg.get("content", "")
|
|
53
|
-
else:
|
|
54
|
-
api_messages.append(msg)
|
|
64
|
+
system_content, api_messages = self._extract_system_and_messages(messages)
|
|
55
65
|
|
|
56
66
|
# Build common kwargs
|
|
57
67
|
common_kwargs: dict[str, Any] = {
|
|
@@ -99,9 +109,174 @@ class AsyncClaudeDriver(CostMixin, AsyncDriver):
|
|
|
99
109
|
"prompt_tokens": prompt_tokens,
|
|
100
110
|
"completion_tokens": completion_tokens,
|
|
101
111
|
"total_tokens": total_tokens,
|
|
102
|
-
"cost": total_cost,
|
|
112
|
+
"cost": round(total_cost, 6),
|
|
103
113
|
"raw_response": dict(resp),
|
|
104
114
|
"model_name": model,
|
|
105
115
|
}
|
|
106
116
|
|
|
107
117
|
return {"text": text, "meta": meta}
|
|
118
|
+
|
|
119
|
+
# ------------------------------------------------------------------
|
|
120
|
+
# Helpers
|
|
121
|
+
# ------------------------------------------------------------------
|
|
122
|
+
|
|
123
|
+
def _extract_system_and_messages(
|
|
124
|
+
self, messages: list[dict[str, Any]]
|
|
125
|
+
) -> tuple[str | None, list[dict[str, Any]]]:
|
|
126
|
+
"""Separate system message from conversation messages for Anthropic API."""
|
|
127
|
+
system_content = None
|
|
128
|
+
api_messages: list[dict[str, Any]] = []
|
|
129
|
+
for msg in messages:
|
|
130
|
+
if msg.get("role") == "system":
|
|
131
|
+
system_content = msg.get("content", "")
|
|
132
|
+
else:
|
|
133
|
+
api_messages.append(msg)
|
|
134
|
+
return system_content, api_messages
|
|
135
|
+
|
|
136
|
+
# ------------------------------------------------------------------
|
|
137
|
+
# Tool use
|
|
138
|
+
# ------------------------------------------------------------------
|
|
139
|
+
|
|
140
|
+
async def generate_messages_with_tools(
|
|
141
|
+
self,
|
|
142
|
+
messages: list[dict[str, Any]],
|
|
143
|
+
tools: list[dict[str, Any]],
|
|
144
|
+
options: dict[str, Any],
|
|
145
|
+
) -> dict[str, Any]:
|
|
146
|
+
"""Generate a response that may include tool calls (Anthropic)."""
|
|
147
|
+
if anthropic is None:
|
|
148
|
+
raise RuntimeError("anthropic package not installed")
|
|
149
|
+
|
|
150
|
+
opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
|
|
151
|
+
model = options.get("model", self.model)
|
|
152
|
+
|
|
153
|
+
self._validate_model_capabilities("claude", model, using_tool_use=True)
|
|
154
|
+
|
|
155
|
+
client = anthropic.AsyncAnthropic(api_key=self.api_key)
|
|
156
|
+
|
|
157
|
+
system_content, api_messages = self._extract_system_and_messages(messages)
|
|
158
|
+
|
|
159
|
+
# Convert tools from OpenAI format to Anthropic format if needed
|
|
160
|
+
anthropic_tools = []
|
|
161
|
+
for t in tools:
|
|
162
|
+
if "type" in t and t["type"] == "function":
|
|
163
|
+
# OpenAI format -> Anthropic format
|
|
164
|
+
fn = t["function"]
|
|
165
|
+
anthropic_tools.append({
|
|
166
|
+
"name": fn["name"],
|
|
167
|
+
"description": fn.get("description", ""),
|
|
168
|
+
"input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
|
|
169
|
+
})
|
|
170
|
+
elif "input_schema" in t:
|
|
171
|
+
# Already Anthropic format
|
|
172
|
+
anthropic_tools.append(t)
|
|
173
|
+
else:
|
|
174
|
+
anthropic_tools.append(t)
|
|
175
|
+
|
|
176
|
+
kwargs: dict[str, Any] = {
|
|
177
|
+
"model": model,
|
|
178
|
+
"messages": api_messages,
|
|
179
|
+
"temperature": opts["temperature"],
|
|
180
|
+
"max_tokens": opts["max_tokens"],
|
|
181
|
+
"tools": anthropic_tools,
|
|
182
|
+
}
|
|
183
|
+
if system_content:
|
|
184
|
+
kwargs["system"] = system_content
|
|
185
|
+
|
|
186
|
+
resp = await client.messages.create(**kwargs)
|
|
187
|
+
|
|
188
|
+
prompt_tokens = resp.usage.input_tokens
|
|
189
|
+
completion_tokens = resp.usage.output_tokens
|
|
190
|
+
total_tokens = prompt_tokens + completion_tokens
|
|
191
|
+
total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
|
|
192
|
+
|
|
193
|
+
meta = {
|
|
194
|
+
"prompt_tokens": prompt_tokens,
|
|
195
|
+
"completion_tokens": completion_tokens,
|
|
196
|
+
"total_tokens": total_tokens,
|
|
197
|
+
"cost": round(total_cost, 6),
|
|
198
|
+
"raw_response": dict(resp),
|
|
199
|
+
"model_name": model,
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
text = ""
|
|
203
|
+
tool_calls_out: list[dict[str, Any]] = []
|
|
204
|
+
for block in resp.content:
|
|
205
|
+
if block.type == "text":
|
|
206
|
+
text += block.text
|
|
207
|
+
elif block.type == "tool_use":
|
|
208
|
+
tool_calls_out.append({
|
|
209
|
+
"id": block.id,
|
|
210
|
+
"name": block.name,
|
|
211
|
+
"arguments": block.input,
|
|
212
|
+
})
|
|
213
|
+
|
|
214
|
+
return {
|
|
215
|
+
"text": text,
|
|
216
|
+
"meta": meta,
|
|
217
|
+
"tool_calls": tool_calls_out,
|
|
218
|
+
"stop_reason": resp.stop_reason,
|
|
219
|
+
}
|
|
220
|
+
|
|
221
|
+
# ------------------------------------------------------------------
|
|
222
|
+
# Streaming
|
|
223
|
+
# ------------------------------------------------------------------
|
|
224
|
+
|
|
225
|
+
async def generate_messages_stream(
|
|
226
|
+
self,
|
|
227
|
+
messages: list[dict[str, Any]],
|
|
228
|
+
options: dict[str, Any],
|
|
229
|
+
) -> AsyncIterator[dict[str, Any]]:
|
|
230
|
+
"""Yield response chunks via Anthropic streaming API."""
|
|
231
|
+
if anthropic is None:
|
|
232
|
+
raise RuntimeError("anthropic package not installed")
|
|
233
|
+
|
|
234
|
+
opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
|
|
235
|
+
model = options.get("model", self.model)
|
|
236
|
+
client = anthropic.AsyncAnthropic(api_key=self.api_key)
|
|
237
|
+
|
|
238
|
+
system_content, api_messages = self._extract_system_and_messages(messages)
|
|
239
|
+
|
|
240
|
+
kwargs: dict[str, Any] = {
|
|
241
|
+
"model": model,
|
|
242
|
+
"messages": api_messages,
|
|
243
|
+
"temperature": opts["temperature"],
|
|
244
|
+
"max_tokens": opts["max_tokens"],
|
|
245
|
+
}
|
|
246
|
+
if system_content:
|
|
247
|
+
kwargs["system"] = system_content
|
|
248
|
+
|
|
249
|
+
full_text = ""
|
|
250
|
+
prompt_tokens = 0
|
|
251
|
+
completion_tokens = 0
|
|
252
|
+
|
|
253
|
+
async with client.messages.stream(**kwargs) as stream:
|
|
254
|
+
async for event in stream:
|
|
255
|
+
if hasattr(event, "type"):
|
|
256
|
+
if event.type == "content_block_delta" and hasattr(event, "delta"):
|
|
257
|
+
delta_text = getattr(event.delta, "text", "")
|
|
258
|
+
if delta_text:
|
|
259
|
+
full_text += delta_text
|
|
260
|
+
yield {"type": "delta", "text": delta_text}
|
|
261
|
+
elif event.type == "message_delta" and hasattr(event, "usage"):
|
|
262
|
+
completion_tokens = getattr(event.usage, "output_tokens", 0)
|
|
263
|
+
elif event.type == "message_start" and hasattr(event, "message"):
|
|
264
|
+
usage = getattr(event.message, "usage", None)
|
|
265
|
+
if usage:
|
|
266
|
+
prompt_tokens = getattr(usage, "input_tokens", 0)
|
|
267
|
+
|
|
268
|
+
total_tokens = prompt_tokens + completion_tokens
|
|
269
|
+
total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
|
|
270
|
+
|
|
271
|
+
yield {
|
|
272
|
+
"type": "done",
|
|
273
|
+
"text": full_text,
|
|
274
|
+
"meta": {
|
|
275
|
+
"prompt_tokens": prompt_tokens,
|
|
276
|
+
"completion_tokens": completion_tokens,
|
|
277
|
+
"total_tokens": total_tokens,
|
|
278
|
+
"cost": round(total_cost, 6),
|
|
279
|
+
"raw_response": {},
|
|
280
|
+
"model_name": model,
|
|
281
|
+
},
|
|
282
|
+
}
|