prompture 0.0.38.dev2__py3-none-any.whl → 0.0.40.dev1__py3-none-any.whl

This diff compares the contents of two package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (32)
  1. prompture/__init__.py +12 -1
  2. prompture/_version.py +2 -2
  3. prompture/async_conversation.py +9 -0
  4. prompture/async_core.py +16 -0
  5. prompture/async_driver.py +39 -0
  6. prompture/conversation.py +9 -0
  7. prompture/core.py +16 -0
  8. prompture/cost_mixin.py +37 -0
  9. prompture/discovery.py +108 -43
  10. prompture/driver.py +39 -0
  11. prompture/drivers/async_azure_driver.py +4 -4
  12. prompture/drivers/async_claude_driver.py +177 -8
  13. prompture/drivers/async_google_driver.py +10 -0
  14. prompture/drivers/async_grok_driver.py +4 -4
  15. prompture/drivers/async_groq_driver.py +4 -4
  16. prompture/drivers/async_openai_driver.py +155 -4
  17. prompture/drivers/async_openrouter_driver.py +4 -4
  18. prompture/drivers/azure_driver.py +3 -3
  19. prompture/drivers/claude_driver.py +10 -0
  20. prompture/drivers/google_driver.py +10 -0
  21. prompture/drivers/grok_driver.py +4 -4
  22. prompture/drivers/groq_driver.py +4 -4
  23. prompture/drivers/openai_driver.py +19 -10
  24. prompture/drivers/openrouter_driver.py +4 -4
  25. prompture/ledger.py +252 -0
  26. prompture/model_rates.py +112 -2
  27. {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/METADATA +1 -1
  28. {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/RECORD +32 -31
  29. {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/WHEEL +0 -0
  30. {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/entry_points.txt +0 -0
  31. {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/licenses/LICENSE +0 -0
  32. {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/top_level.txt +0 -0

prompture/drivers/async_claude_driver.py
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import json
 import os
+from collections.abc import AsyncIterator
 from typing import Any
 
 try:
@@ -19,6 +20,8 @@ from .claude_driver import ClaudeDriver
 class AsyncClaudeDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
     supports_vision = True
 
     MODEL_PRICING = ClaudeDriver.MODEL_PRICING
@@ -48,16 +51,17 @@ class AsyncClaudeDriver(CostMixin, AsyncDriver):
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
 
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "claude",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
         client = anthropic.AsyncAnthropic(api_key=self.api_key)
 
         # Anthropic requires system messages as a top-level parameter
-        system_content = None
-        api_messages = []
-        for msg in messages:
-            if msg.get("role") == "system":
-                system_content = msg.get("content", "")
-            else:
-                api_messages.append(msg)
+        system_content, api_messages = self._extract_system_and_messages(messages)
 
         # Build common kwargs
         common_kwargs: dict[str, Any] = {
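
Note: _validate_model_capabilities itself is not part of this diff view; it presumably lands with the cost_mixin.py changes (+37 lines). Purely as a reading aid, here is a hypothetical sketch of the contract implied by the call sites in this diff. The keyword names come from the calls, but the defaults, the capability keys, and the raise-on-unsupported behavior are all assumptions:

    # Hypothetical sketch; would live on CostMixin alongside _get_model_config.
    def _validate_model_capabilities(
        self,
        provider: str,
        model: str,
        *,
        using_json_schema: bool = False,
        using_tool_use: bool = False,
    ) -> None:
        # Assumed: capability flags come from models.dev metadata, with a
        # hardcoded fallback mirroring the pricing lookup.
        caps = self._get_model_config(provider, model)
        if using_json_schema and not caps.get("supports_json_schema", True):
            raise ValueError(f"{provider}/{model} does not support JSON schema output")
        if using_tool_use and not caps.get("supports_tool_use", True):
            raise ValueError(f"{provider}/{model} does not support tool use")
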
@@ -105,9 +109,174 @@ class AsyncClaudeDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": dict(resp),
             "model_name": model,
         }
 
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
+
+    def _extract_system_and_messages(
+        self, messages: list[dict[str, Any]]
+    ) -> tuple[str | None, list[dict[str, Any]]]:
+        """Separate system message from conversation messages for Anthropic API."""
+        system_content = None
+        api_messages: list[dict[str, Any]] = []
+        for msg in messages:
+            if msg.get("role") == "system":
+                system_content = msg.get("content", "")
+            else:
+                api_messages.append(msg)
+        return system_content, api_messages
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls (Anthropic)."""
+        if anthropic is None:
+            raise RuntimeError("anthropic package not installed")
+
+        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+        model = options.get("model", self.model)
+
+        self._validate_model_capabilities("claude", model, using_tool_use=True)
+
+        client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+        system_content, api_messages = self._extract_system_and_messages(messages)
+
+        # Convert tools from OpenAI format to Anthropic format if needed
+        anthropic_tools = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                # OpenAI format -> Anthropic format
+                fn = t["function"]
+                anthropic_tools.append({
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                    "input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
+                })
+            elif "input_schema" in t:
+                # Already Anthropic format
+                anthropic_tools.append(t)
+            else:
+                anthropic_tools.append(t)
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+            "tools": anthropic_tools,
+        }
+        if system_content:
+            kwargs["system"] = system_content
+
+        resp = await client.messages.create(**kwargs)
+
+        prompt_tokens = resp.usage.input_tokens
+        completion_tokens = resp.usage.output_tokens
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": dict(resp),
+            "model_name": model,
+        }
+
+        text = ""
+        tool_calls_out: list[dict[str, Any]] = []
+        for block in resp.content:
+            if block.type == "text":
+                text += block.text
+            elif block.type == "tool_use":
+                tool_calls_out.append({
+                    "id": block.id,
+                    "name": block.name,
+                    "arguments": block.input,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": resp.stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via Anthropic streaming API."""
+        if anthropic is None:
+            raise RuntimeError("anthropic package not installed")
+
+        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+        model = options.get("model", self.model)
+        client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+        system_content, api_messages = self._extract_system_and_messages(messages)
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+        }
+        if system_content:
+            kwargs["system"] = system_content
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        async with client.messages.stream(**kwargs) as stream:
+            async for event in stream:
+                if hasattr(event, "type"):
+                    if event.type == "content_block_delta" and hasattr(event, "delta"):
+                        delta_text = getattr(event.delta, "text", "")
+                        if delta_text:
+                            full_text += delta_text
+                            yield {"type": "delta", "text": delta_text}
+                    elif event.type == "message_delta" and hasattr(event, "usage"):
+                        completion_tokens = getattr(event.usage, "output_tokens", 0)
+                    elif event.type == "message_start" and hasattr(event, "message"):
+                        usage = getattr(event.message, "usage", None)
+                        if usage:
+                            prompt_tokens = getattr(usage, "input_tokens", 0)
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
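
With supports_tool_use and supports_streaming now advertised, the async Claude driver exposes the same chunked streaming contract as the async OpenAI driver further down: "delta" chunks carry incremental text, and a final "done" chunk carries the full text plus usage and cost metadata. A minimal consumption sketch follows; the constructor arguments and model name are assumptions, and only the chunk shape is taken from the diff:

    import asyncio

    from prompture.drivers.async_claude_driver import AsyncClaudeDriver

    async def main() -> None:
        # Constructor signature assumed; it is not shown in this diff.
        driver = AsyncClaudeDriver(api_key="sk-ant-...", model="claude-3-5-sonnet-latest")

        async for chunk in driver.generate_messages_stream(
            [{"role": "user", "content": "Say hello"}], options={}
        ):
            if chunk["type"] == "delta":
                print(chunk["text"], end="", flush=True)  # incremental text
            elif chunk["type"] == "done":
                print("\ncost:", chunk["meta"]["cost"])   # rounded to 6 decimals

    asyncio.run(main())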

prompture/drivers/async_google_driver.py
@@ -169,6 +169,13 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
     ) -> dict[str, Any]:
         gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
 
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "google",
+            self.model,
+            using_json_schema=bool((options or {}).get("json_schema")),
+        )
+
         try:
             model = genai.GenerativeModel(self.model, **model_kwargs)
             response = await model.generate_content_async(gen_input, **gen_kwargs)
@@ -201,6 +208,9 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
         options: dict[str, Any],
     ) -> dict[str, Any]:
         """Generate a response that may include tool/function calls (async)."""
+        model = options.get("model", self.model)
+        self._validate_model_capabilities("google", model, using_tool_use=True)
+
         gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
             self._prepare_messages(messages), options
         )

prompture/drivers/async_grok_driver.py
@@ -44,9 +44,9 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
 
         model = options.get("model", self.model)
 
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("grok", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
@@ -88,7 +88,7 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp,
             "model_name": model,
         }
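
The _get_model_config helper replaces direct MODEL_PRICING lookups in every driver touched by this diff. Note that the call sites index the result with [...] rather than .get(...), so the helper must always supply tokens_param and supports_temperature; per the updated comments, it merges live models.dev data with the hardcoded table as a fallback. Its implementation is not shown here, so the dict below is only the return shape implied by the call sites, and any key beyond the two indexed ones is a guess:

    # Implied shape of self._get_model_config(provider, model):
    {
        "tokens_param": "max_tokens",   # or "max_completion_tokens" on newer models
        "supports_temperature": True,   # False for models that reject the parameter
        # ...presumably also the pricing fields consumed by self._calculate_cost
    }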

prompture/drivers/async_groq_driver.py
@@ -49,9 +49,9 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
 
         model = options.get("model", self.model)
 
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("groq", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 0.7, "max_tokens": 512, **options}
 
@@ -81,7 +81,7 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp.model_dump(),
             "model_name": model,
         }

prompture/drivers/async_openai_driver.py
@@ -2,7 +2,9 @@
 
 from __future__ import annotations
 
+import json
 import os
+from collections.abc import AsyncIterator
 from typing import Any
 
 try:
@@ -18,6 +20,8 @@ from .openai_driver import OpenAIDriver
 class AsyncOpenAIDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
     supports_vision = True
 
     MODEL_PRICING = OpenAIDriver.MODEL_PRICING
@@ -50,9 +54,16 @@ class AsyncOpenAIDriver(CostMixin, AsyncDriver):
 
         model = options.get("model", self.model)
 
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "openai",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
@@ -93,10 +104,150 @@
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp.model_dump(),
             "model_name": model,
         }
 
         text = resp.choices[0].message.content
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) is not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("openai", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        resp = await self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+        total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+        }
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        stop_reason = choice.finish_reason
+
+        tool_calls_out: list[dict[str, Any]] = []
+        if choice.message.tool_calls:
+            for tc in choice.message.tool_calls:
+                try:
+                    args = json.loads(tc.function.arguments)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+                tool_calls_out.append({
+                    "id": tc.id,
+                    "name": tc.function.name,
+                    "arguments": args,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via OpenAI streaming API."""
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) is not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        stream = await self.client.chat.completions.create(**kwargs)
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        async for chunk in stream:
+            # Usage comes in the final chunk
+            if getattr(chunk, "usage", None):
+                prompt_tokens = chunk.usage.prompt_tokens or 0
+                completion_tokens = chunk.usage.completion_tokens or 0
+
+            if chunk.choices:
+                delta = chunk.choices[0].delta
+                content = getattr(delta, "content", None) or ""
+                if content:
+                    full_text += content
+                    yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
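
generate_messages_with_tools returns {"text", "meta", "tool_calls", "stop_reason"}, with each tool call reduced to {"id", "name", "arguments"} and the arguments already JSON-decoded. The driver does not feed tool results back by itself. One plausible round trip, sketched under the assumption that the caller echoes the raw assistant message from meta["raw_response"] and replies using the standard OpenAI "tool" role format (neither step is prescribed by this diff, and get_weather is a hypothetical tool):

    import json

    TOOLS = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]

    async def run_tool_round_trip(driver) -> str:
        # driver: a configured AsyncOpenAIDriver instance
        messages = [{"role": "user", "content": "What is the weather in Paris?"}]
        result = await driver.generate_messages_with_tools(messages, TOOLS, options={})
        if result["tool_calls"]:
            # Echo the assistant turn that requested the calls...
            messages.append(result["meta"]["raw_response"]["choices"][0]["message"])
            # ...then answer each call in the standard OpenAI "tool" role format.
            for call in result["tool_calls"]:
                output = {"temp_c": 18}  # execute the named tool however you like
                messages.append({
                    "role": "tool",
                    "tool_call_id": call["id"],
                    "content": json.dumps(output),
                })
            result = await driver.generate_messages_with_tools(messages, TOOLS, options={})
        return result["text"]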

prompture/drivers/async_openrouter_driver.py
@@ -47,9 +47,9 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         model = options.get("model", self.model)
 
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
@@ -93,7 +93,7 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp,
             "model_name": model,
         }

prompture/drivers/azure_driver.py
@@ -108,9 +108,9 @@ class AzureDriver(CostMixin, Driver):
             raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")
 
         model = options.get("model", self.model)
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("azure", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 

prompture/drivers/claude_driver.py
@@ -77,6 +77,13 @@ class ClaudeDriver(CostMixin, Driver):
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
 
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "claude",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
         client = anthropic.Anthropic(api_key=self.api_key)
 
         # Anthropic requires system messages as a top-level parameter
@@ -177,6 +184,9 @@ class ClaudeDriver(CostMixin, Driver):
 
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
+
+        self._validate_model_capabilities("claude", model, using_tool_use=True)
+
         client = anthropic.Anthropic(api_key=self.api_key)
 
         system_content, api_messages = self._extract_system_and_messages(messages)

prompture/drivers/google_driver.py
@@ -228,6 +228,13 @@ class GoogleDriver(CostMixin, Driver):
     def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
 
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "google",
+            self.model,
+            using_json_schema=bool((options or {}).get("json_schema")),
+        )
+
         try:
             logger.debug(f"Initializing {self.model} for generation")
             model = genai.GenerativeModel(self.model, **model_kwargs)
@@ -263,6 +270,9 @@ class GoogleDriver(CostMixin, Driver):
         options: dict[str, Any],
     ) -> dict[str, Any]:
         """Generate a response that may include tool/function calls."""
+        model = options.get("model", self.model)
+        self._validate_model_capabilities("google", model, using_tool_use=True)
+
         gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
             self._prepare_messages(messages), options
         )

prompture/drivers/grok_driver.py
@@ -99,10 +99,10 @@ class GrokDriver(CostMixin, Driver):
 
         model = options.get("model", self.model)
 
-        # Lookup model-specific config
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        # Lookup model-specific config (live models.dev data + hardcoded fallback)
+        model_config = self._get_model_config("grok", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         # Defaults
         opts = {"temperature": 1.0, "max_tokens": 512, **options}

prompture/drivers/groq_driver.py
@@ -69,10 +69,10 @@ class GroqDriver(CostMixin, Driver):
 
         model = options.get("model", self.model)
 
-        # Lookup model-specific config
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        # Lookup model-specific config (live models.dev data + hardcoded fallback)
+        model_config = self._get_model_config("groq", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         # Base configuration
         opts = {"temperature": 0.7, "max_tokens": 512, **options}

prompture/drivers/openai_driver.py
@@ -93,10 +93,17 @@ class OpenAIDriver(CostMixin, Driver):
 
         model = options.get("model", self.model)
 
-        # Lookup model-specific config
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        # Lookup model-specific config (live models.dev data + hardcoded fallback)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "openai",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
 
         # Defaults
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
@@ -168,9 +175,11 @@ class OpenAIDriver(CostMixin, Driver):
             raise RuntimeError("openai package (>=1.0.0) is not installed")
 
         model = options.get("model", self.model)
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("openai", model, using_tool_use=True)
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
@@ -239,9 +248,9 @@ class OpenAIDriver(CostMixin, Driver):
             raise RuntimeError("openai package (>=1.0.0) is not installed")
 
         model = options.get("model", self.model)
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 

prompture/drivers/openrouter_driver.py
@@ -85,10 +85,10 @@ class OpenRouterDriver(CostMixin, Driver):
 
         model = options.get("model", self.model)
 
-        # Lookup model-specific config
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        # Lookup model-specific config (live models.dev data + hardcoded fallback)
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         # Defaults
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
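
Every driver above now rounds cost to six decimal places before returning it, keeping per-call values stable down to micro-dollar precision. That pairs naturally with the new prompture/ledger.py (+252 lines), whose hunks are not included in this view. Since its actual API is not visible here, the accumulator below is purely illustrative of how the meta dicts the drivers return could be aggregated:

    from collections import defaultdict

    class CostLedger:  # hypothetical; the real ledger.py API may differ entirely
        def __init__(self) -> None:
            self.cost_by_model: dict[str, float] = defaultdict(float)
            self.tokens_by_model: dict[str, int] = defaultdict(int)

        def record(self, meta: dict) -> None:
            # meta is the dict each driver returns; cost is pre-rounded to 6 decimals
            self.cost_by_model[meta["model_name"]] += meta["cost"]
            self.tokens_by_model[meta["model_name"]] += meta["total_tokens"]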