prompture 0.0.38.dev2__py3-none-any.whl → 0.0.40.dev1__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- prompture/__init__.py +12 -1
- prompture/_version.py +2 -2
- prompture/async_conversation.py +9 -0
- prompture/async_core.py +16 -0
- prompture/async_driver.py +39 -0
- prompture/conversation.py +9 -0
- prompture/core.py +16 -0
- prompture/cost_mixin.py +37 -0
- prompture/discovery.py +108 -43
- prompture/driver.py +39 -0
- prompture/drivers/async_azure_driver.py +4 -4
- prompture/drivers/async_claude_driver.py +177 -8
- prompture/drivers/async_google_driver.py +10 -0
- prompture/drivers/async_grok_driver.py +4 -4
- prompture/drivers/async_groq_driver.py +4 -4
- prompture/drivers/async_openai_driver.py +155 -4
- prompture/drivers/async_openrouter_driver.py +4 -4
- prompture/drivers/azure_driver.py +3 -3
- prompture/drivers/claude_driver.py +10 -0
- prompture/drivers/google_driver.py +10 -0
- prompture/drivers/grok_driver.py +4 -4
- prompture/drivers/groq_driver.py +4 -4
- prompture/drivers/openai_driver.py +19 -10
- prompture/drivers/openrouter_driver.py +4 -4
- prompture/ledger.py +252 -0
- prompture/model_rates.py +112 -2
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/METADATA +1 -1
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/RECORD +32 -31
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/top_level.txt +0 -0
prompture/drivers/async_claude_driver.py
CHANGED

```diff
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import json
 import os
+from collections.abc import AsyncIterator
 from typing import Any
 
 try:
```
```diff
@@ -19,6 +20,8 @@ from .claude_driver import ClaudeDriver
 class AsyncClaudeDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
     supports_vision = True
 
     MODEL_PRICING = ClaudeDriver.MODEL_PRICING
```
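The new `supports_tool_use` and `supports_streaming` class flags let calling code pick a request path before touching the API. A minimal caller-side sketch; the `pick_path` helper is illustrative, not part of the package:

```python
def pick_path(driver) -> str:
    # Caller-side capability check; only the attribute names come from this diff.
    if getattr(driver, "supports_streaming", False):
        return "stream"   # route to generate_messages_stream(...)
    if getattr(driver, "supports_tool_use", False):
        return "tools"    # route to generate_messages_with_tools(...)
    return "plain"
```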
```diff
@@ -48,16 +51,17 @@ class AsyncClaudeDriver(CostMixin, AsyncDriver):
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
 
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "claude",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
         client = anthropic.AsyncAnthropic(api_key=self.api_key)
 
         # Anthropic requires system messages as a top-level parameter
-        system_content = None
-        api_messages = []
-        for msg in messages:
-            if msg.get("role") == "system":
-                system_content = msg.get("content", "")
-            else:
-                api_messages.append(msg)
+        system_content, api_messages = self._extract_system_and_messages(messages)
 
         # Build common kwargs
         common_kwargs: dict[str, Any] = {
```
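The inline system-message loop is replaced by the `_extract_system_and_messages` helper added later in this diff. Its behavior, restated as a standalone sketch with a concrete input (the `split_system` name is illustrative):

```python
from typing import Any

def split_system(messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
    """Standalone restatement of the helper's behavior shown in this diff."""
    system_content = None
    api_messages: list[dict[str, Any]] = []
    for msg in messages:
        if msg.get("role") == "system":
            system_content = msg.get("content", "")
        else:
            api_messages.append(msg)
    return system_content, api_messages

msgs = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
]
assert split_system(msgs) == ("You are terse.", [{"role": "user", "content": "Hi"}])
```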
```diff
@@ -105,9 +109,174 @@ class AsyncClaudeDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": dict(resp),
             "model_name": model,
         }
 
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
+
+    def _extract_system_and_messages(
+        self, messages: list[dict[str, Any]]
+    ) -> tuple[str | None, list[dict[str, Any]]]:
+        """Separate system message from conversation messages for Anthropic API."""
+        system_content = None
+        api_messages: list[dict[str, Any]] = []
+        for msg in messages:
+            if msg.get("role") == "system":
+                system_content = msg.get("content", "")
+            else:
+                api_messages.append(msg)
+        return system_content, api_messages
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls (Anthropic)."""
+        if anthropic is None:
+            raise RuntimeError("anthropic package not installed")
+
+        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+        model = options.get("model", self.model)
+
+        self._validate_model_capabilities("claude", model, using_tool_use=True)
+
+        client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+        system_content, api_messages = self._extract_system_and_messages(messages)
+
+        # Convert tools from OpenAI format to Anthropic format if needed
+        anthropic_tools = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                # OpenAI format -> Anthropic format
+                fn = t["function"]
+                anthropic_tools.append({
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                    "input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
+                })
+            elif "input_schema" in t:
+                # Already Anthropic format
+                anthropic_tools.append(t)
+            else:
+                anthropic_tools.append(t)
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+            "tools": anthropic_tools,
+        }
+        if system_content:
+            kwargs["system"] = system_content
+
+        resp = await client.messages.create(**kwargs)
+
+        prompt_tokens = resp.usage.input_tokens
+        completion_tokens = resp.usage.output_tokens
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": dict(resp),
+            "model_name": model,
+        }
+
+        text = ""
+        tool_calls_out: list[dict[str, Any]] = []
+        for block in resp.content:
+            if block.type == "text":
+                text += block.text
+            elif block.type == "tool_use":
+                tool_calls_out.append({
+                    "id": block.id,
+                    "name": block.name,
+                    "arguments": block.input,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": resp.stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via Anthropic streaming API."""
+        if anthropic is None:
+            raise RuntimeError("anthropic package not installed")
+
+        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+        model = options.get("model", self.model)
+        client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+        system_content, api_messages = self._extract_system_and_messages(messages)
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+        }
+        if system_content:
+            kwargs["system"] = system_content
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        async with client.messages.stream(**kwargs) as stream:
+            async for event in stream:
+                if hasattr(event, "type"):
+                    if event.type == "content_block_delta" and hasattr(event, "delta"):
+                        delta_text = getattr(event.delta, "text", "")
+                        if delta_text:
+                            full_text += delta_text
+                            yield {"type": "delta", "text": delta_text}
+                    elif event.type == "message_delta" and hasattr(event, "usage"):
+                        completion_tokens = getattr(event.usage, "output_tokens", 0)
+                    elif event.type == "message_start" and hasattr(event, "message"):
+                        usage = getattr(event.message, "usage", None)
+                        if usage:
+                            prompt_tokens = getattr(usage, "input_tokens", 0)
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
```
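`generate_messages_with_tools` accepts tool definitions in OpenAI function format or Anthropic `input_schema` format and normalizes them before calling the API. A call sketch; the `get_weather` tool, the `driver` instance, and the surrounding coroutine are illustrative assumptions, not taken from the package:

```python
async def demo(driver):
    # OpenAI-style tool definition; the driver converts it to Anthropic's
    # input_schema format internally (see the conversion loop above).
    # "get_weather" and its argument schema are made up for illustration.
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]
    result = await driver.generate_messages_with_tools(
        [{"role": "user", "content": "Weather in Paris?"}],
        tools,
        {"max_tokens": 256},
    )
    for call in result["tool_calls"]:
        print(call["id"], call["name"], call["arguments"])
```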
prompture/drivers/async_google_driver.py
CHANGED

```diff
@@ -169,6 +169,13 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
     ) -> dict[str, Any]:
         gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
 
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "google",
+            self.model,
+            using_json_schema=bool((options or {}).get("json_schema")),
+        )
+
         try:
             model = genai.GenerativeModel(self.model, **model_kwargs)
             response = await model.generate_content_async(gen_input, **gen_kwargs)
@@ -201,6 +208,9 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
         options: dict[str, Any],
     ) -> dict[str, Any]:
         """Generate a response that may include tool/function calls (async)."""
+        model = options.get("model", self.model)
+        self._validate_model_capabilities("google", model, using_tool_use=True)
+
         gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
             self._prepare_messages(messages), options
         )
```
prompture/drivers/async_grok_driver.py
CHANGED

```diff
@@ -44,9 +44,9 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
 
         model = options.get("model", self.model)
 
-
-        tokens_param =
-        supports_temperature =
+        model_config = self._get_model_config("grok", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
@@ -88,7 +88,7 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp,
             "model_name": model,
         }
```
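All of the reworked drivers now read `tokens_param` and `supports_temperature` from `_get_model_config(provider, model)` instead of hardcoding them. A sketch of how those two fields shape the request payload; the example values are assumptions, not data from the package:

```python
# Assumed shape of the returned config; the key names come from this diff,
# the example values do not.
model_config = {"tokens_param": "max_completion_tokens", "supports_temperature": False}

kwargs = {"model": "example-model", "messages": []}
kwargs[model_config["tokens_param"]] = 512      # e.g. max_tokens vs. max_completion_tokens
if model_config["supports_temperature"]:
    kwargs["temperature"] = 1.0                 # skipped for models that reject the parameter

print(kwargs)  # {'model': 'example-model', 'messages': [], 'max_completion_tokens': 512}
```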
prompture/drivers/async_groq_driver.py
CHANGED

```diff
@@ -49,9 +49,9 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
 
         model = options.get("model", self.model)
 
-
-        tokens_param =
-        supports_temperature =
+        model_config = self._get_model_config("groq", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 0.7, "max_tokens": 512, **options}
 
@@ -81,7 +81,7 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp.model_dump(),
             "model_name": model,
         }
```
prompture/drivers/async_openai_driver.py
CHANGED

```diff
@@ -2,7 +2,9 @@
 
 from __future__ import annotations
 
+import json
 import os
+from collections.abc import AsyncIterator
 from typing import Any
 
 try:
@@ -18,6 +20,8 @@ from .openai_driver import OpenAIDriver
 class AsyncOpenAIDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
     supports_vision = True
 
     MODEL_PRICING = OpenAIDriver.MODEL_PRICING
@@ -50,9 +54,16 @@ class AsyncOpenAIDriver(CostMixin, AsyncDriver):
 
         model = options.get("model", self.model)
 
-
-        tokens_param =
-        supports_temperature =
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "openai",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
@@ -93,10 +104,150 @@ class AsyncOpenAIDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp.model_dump(),
             "model_name": model,
         }
 
         text = resp.choices[0].message.content
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) is not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("openai", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        resp = await self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+        total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+        }
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        stop_reason = choice.finish_reason
+
+        tool_calls_out: list[dict[str, Any]] = []
+        if choice.message.tool_calls:
+            for tc in choice.message.tool_calls:
+                try:
+                    args = json.loads(tc.function.arguments)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+                tool_calls_out.append({
+                    "id": tc.id,
+                    "name": tc.function.name,
+                    "arguments": args,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via OpenAI streaming API."""
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) is not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        stream = await self.client.chat.completions.create(**kwargs)
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        async for chunk in stream:
+            # Usage comes in the final chunk
+            if getattr(chunk, "usage", None):
+                prompt_tokens = chunk.usage.prompt_tokens or 0
+                completion_tokens = chunk.usage.completion_tokens or 0
+
+            if chunk.choices:
+                delta = chunk.choices[0].delta
+                content = getattr(delta, "content", None) or ""
+                if content:
+                    full_text += content
+                    yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
```
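Both new streaming implementations yield the same chunk protocol: `{"type": "delta", "text": ...}` while tokens arrive, then a final `{"type": "done", ...}` carrying the full text and usage metadata. A consumer sketch, assuming an already-configured `driver`:

```python
async def consume(driver, messages, options):
    # Chunk shapes ("delta" / "done") come from this diff; driver setup does not.
    async for chunk in driver.generate_messages_stream(messages, options):
        if chunk["type"] == "delta":
            print(chunk["text"], end="", flush=True)
        elif chunk["type"] == "done":
            print()
            print("tokens:", chunk["meta"]["total_tokens"], "cost:", chunk["meta"]["cost"])

# e.g. asyncio.run(consume(driver, [{"role": "user", "content": "Hello"}], {}))
```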
prompture/drivers/async_openrouter_driver.py
CHANGED

```diff
@@ -47,9 +47,9 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         model = options.get("model", self.model)
 
-
-        tokens_param =
-        supports_temperature =
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
@@ -93,7 +93,7 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp,
             "model_name": model,
         }
```
prompture/drivers/azure_driver.py
CHANGED

```diff
@@ -108,9 +108,9 @@ class AzureDriver(CostMixin, Driver):
             raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")
 
         model = options.get("model", self.model)
-
-        tokens_param =
-        supports_temperature =
+        model_config = self._get_model_config("azure", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
```
prompture/drivers/claude_driver.py
CHANGED

```diff
@@ -77,6 +77,13 @@ class ClaudeDriver(CostMixin, Driver):
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
 
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "claude",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
         client = anthropic.Anthropic(api_key=self.api_key)
 
         # Anthropic requires system messages as a top-level parameter
@@ -177,6 +184,9 @@ class ClaudeDriver(CostMixin, Driver):
 
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
+
+        self._validate_model_capabilities("claude", model, using_tool_use=True)
+
         client = anthropic.Anthropic(api_key=self.api_key)
 
         system_content, api_messages = self._extract_system_and_messages(messages)
```
prompture/drivers/google_driver.py
CHANGED

```diff
@@ -228,6 +228,13 @@ class GoogleDriver(CostMixin, Driver):
     def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
 
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "google",
+            self.model,
+            using_json_schema=bool((options or {}).get("json_schema")),
+        )
+
         try:
             logger.debug(f"Initializing {self.model} for generation")
             model = genai.GenerativeModel(self.model, **model_kwargs)
@@ -263,6 +270,9 @@ class GoogleDriver(CostMixin, Driver):
         options: dict[str, Any],
     ) -> dict[str, Any]:
         """Generate a response that may include tool/function calls."""
+        model = options.get("model", self.model)
+        self._validate_model_capabilities("google", model, using_tool_use=True)
+
         gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
             self._prepare_messages(messages), options
         )
```
prompture/drivers/grok_driver.py
CHANGED

```diff
@@ -99,10 +99,10 @@ class GrokDriver(CostMixin, Driver):
 
         model = options.get("model", self.model)
 
-        # Lookup model-specific config
-
-        tokens_param =
-        supports_temperature =
+        # Lookup model-specific config (live models.dev data + hardcoded fallback)
+        model_config = self._get_model_config("grok", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         # Defaults
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
```
prompture/drivers/groq_driver.py
CHANGED

```diff
@@ -69,10 +69,10 @@ class GroqDriver(CostMixin, Driver):
 
         model = options.get("model", self.model)
 
-        # Lookup model-specific config
-
-        tokens_param =
-        supports_temperature =
+        # Lookup model-specific config (live models.dev data + hardcoded fallback)
+        model_config = self._get_model_config("groq", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         # Base configuration
         opts = {"temperature": 0.7, "max_tokens": 512, **options}
```
prompture/drivers/openai_driver.py
CHANGED

```diff
@@ -93,10 +93,17 @@ class OpenAIDriver(CostMixin, Driver):
 
         model = options.get("model", self.model)
 
-        # Lookup model-specific config
-
-        tokens_param =
-        supports_temperature =
+        # Lookup model-specific config (live models.dev data + hardcoded fallback)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "openai",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
 
         # Defaults
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
@@ -168,9 +175,11 @@ class OpenAIDriver(CostMixin, Driver):
             raise RuntimeError("openai package (>=1.0.0) is not installed")
 
         model = options.get("model", self.model)
-
-        tokens_param =
-        supports_temperature =
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("openai", model, using_tool_use=True)
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
@@ -239,9 +248,9 @@ class OpenAIDriver(CostMixin, Driver):
             raise RuntimeError("openai package (>=1.0.0) is not installed")
 
         model = options.get("model", self.model)
-
-        tokens_param =
-        supports_temperature =
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
```
prompture/drivers/openrouter_driver.py
CHANGED

```diff
@@ -85,10 +85,10 @@ class OpenRouterDriver(CostMixin, Driver):
 
         model = options.get("model", self.model)
 
-        # Lookup model-specific config
-
-        tokens_param =
-        supports_temperature =
+        # Lookup model-specific config (live models.dev data + hardcoded fallback)
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         # Defaults
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
```
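Across every driver touched by this release, the reported `meta["cost"]` is now rounded to six decimal places. The arithmetic, with illustrative per-token rates (real pricing comes from `MODEL_PRICING` / models.dev, not from this diff):

```python
# Illustrative rates only.
input_rate = 3.00 / 1_000_000    # USD per prompt token
output_rate = 15.00 / 1_000_000  # USD per completion token

prompt_tokens, completion_tokens = 1234, 567
total_cost = prompt_tokens * input_rate + completion_tokens * output_rate
print(round(total_cost, 6))  # 0.012207
```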