prompture 0.0.35__py3-none-any.whl → 0.0.40.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +132 -3
- prompture/_version.py +2 -2
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +208 -17
- prompture/async_core.py +16 -0
- prompture/async_driver.py +63 -0
- prompture/async_groups.py +551 -0
- prompture/conversation.py +222 -18
- prompture/core.py +46 -12
- prompture/cost_mixin.py +37 -0
- prompture/discovery.py +132 -44
- prompture/driver.py +77 -0
- prompture/drivers/__init__.py +5 -1
- prompture/drivers/async_azure_driver.py +11 -5
- prompture/drivers/async_claude_driver.py +184 -9
- prompture/drivers/async_google_driver.py +222 -28
- prompture/drivers/async_grok_driver.py +11 -5
- prompture/drivers/async_groq_driver.py +11 -5
- prompture/drivers/async_lmstudio_driver.py +74 -5
- prompture/drivers/async_ollama_driver.py +13 -3
- prompture/drivers/async_openai_driver.py +162 -5
- prompture/drivers/async_openrouter_driver.py +11 -5
- prompture/drivers/async_registry.py +5 -1
- prompture/drivers/azure_driver.py +10 -4
- prompture/drivers/claude_driver.py +17 -1
- prompture/drivers/google_driver.py +227 -33
- prompture/drivers/grok_driver.py +11 -5
- prompture/drivers/groq_driver.py +11 -5
- prompture/drivers/lmstudio_driver.py +73 -8
- prompture/drivers/ollama_driver.py +16 -5
- prompture/drivers/openai_driver.py +26 -11
- prompture/drivers/openrouter_driver.py +11 -5
- prompture/drivers/vision_helpers.py +153 -0
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/ledger.py +252 -0
- prompture/model_rates.py +112 -2
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/serialization.py +218 -0
- prompture/settings.py +1 -0
- prompture-0.0.40.dev1.dist-info/METADATA +369 -0
- prompture-0.0.40.dev1.dist-info/RECORD +78 -0
- prompture-0.0.35.dist-info/METADATA +0 -464
- prompture-0.0.35.dist-info/RECORD +0 -66
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/top_level.txt +0 -0
prompture/async_conversation.py
CHANGED
|
@@ -4,9 +4,11 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import json
|
|
6
6
|
import logging
|
|
7
|
+
import uuid
|
|
7
8
|
from collections.abc import AsyncIterator
|
|
8
|
-
from datetime import date, datetime
|
|
9
|
+
from datetime import date, datetime, timezone
|
|
9
10
|
from decimal import Decimal
|
|
11
|
+
from pathlib import Path
|
|
10
12
|
from typing import Any, Callable, Literal, Union
|
|
11
13
|
|
|
12
14
|
from pydantic import BaseModel
|
|
@@ -15,6 +17,11 @@ from .async_driver import AsyncDriver
|
|
|
15
17
|
from .callbacks import DriverCallbacks
|
|
16
18
|
from .drivers.async_registry import get_async_driver_for_model
|
|
17
19
|
from .field_definitions import get_registry_snapshot
|
|
20
|
+
from .image import ImageInput, make_image
|
|
21
|
+
from .persistence import load_from_file, save_to_file
|
|
22
|
+
from .persona import Persona, get_persona
|
|
23
|
+
from .serialization import export_conversation, import_conversation
|
|
24
|
+
from .session import UsageSession
|
|
18
25
|
from .tools import (
|
|
19
26
|
clean_json_text,
|
|
20
27
|
convert_value,
|
|
@@ -43,13 +50,33 @@ class AsyncConversation:
|
|
|
43
50
|
*,
|
|
44
51
|
driver: AsyncDriver | None = None,
|
|
45
52
|
system_prompt: str | None = None,
|
|
53
|
+
persona: str | Persona | None = None,
|
|
46
54
|
options: dict[str, Any] | None = None,
|
|
47
55
|
callbacks: DriverCallbacks | None = None,
|
|
48
56
|
tools: ToolRegistry | None = None,
|
|
49
57
|
max_tool_rounds: int = 10,
|
|
58
|
+
conversation_id: str | None = None,
|
|
59
|
+
auto_save: str | Path | None = None,
|
|
60
|
+
tags: list[str] | None = None,
|
|
50
61
|
) -> None:
|
|
62
|
+
if system_prompt is not None and persona is not None:
|
|
63
|
+
raise ValueError("Cannot provide both 'system_prompt' and 'persona'. Use one or the other.")
|
|
64
|
+
|
|
65
|
+
# Resolve persona
|
|
66
|
+
resolved_persona: Persona | None = None
|
|
67
|
+
if persona is not None:
|
|
68
|
+
if isinstance(persona, str):
|
|
69
|
+
resolved_persona = get_persona(persona)
|
|
70
|
+
if resolved_persona is None:
|
|
71
|
+
raise ValueError(f"Persona '{persona}' not found in registry.")
|
|
72
|
+
else:
|
|
73
|
+
resolved_persona = persona
|
|
74
|
+
|
|
51
75
|
if model_name is None and driver is None:
|
|
52
|
-
|
|
76
|
+
if resolved_persona is not None and resolved_persona.model_hint:
|
|
77
|
+
model_name = resolved_persona.model_hint
|
|
78
|
+
else:
|
|
79
|
+
raise ValueError("Either model_name or driver must be provided")
|
|
53
80
|
|
|
54
81
|
if driver is not None:
|
|
55
82
|
self._driver = driver
|
|
@@ -60,8 +87,15 @@ class AsyncConversation:
|
|
|
60
87
|
self._driver.callbacks = callbacks
|
|
61
88
|
|
|
62
89
|
self._model_name = model_name or ""
|
|
63
|
-
|
|
64
|
-
|
|
90
|
+
|
|
91
|
+
# Apply persona: render system_prompt and merge settings
|
|
92
|
+
if resolved_persona is not None:
|
|
93
|
+
self._system_prompt = resolved_persona.render()
|
|
94
|
+
self._options = {**resolved_persona.settings, **(dict(options) if options else {})}
|
|
95
|
+
else:
|
|
96
|
+
self._system_prompt = system_prompt
|
|
97
|
+
self._options = dict(options) if options else {}
|
|
98
|
+
|
|
65
99
|
self._messages: list[dict[str, Any]] = []
|
|
66
100
|
self._usage = {
|
|
67
101
|
"prompt_tokens": 0,
|
|
@@ -73,6 +107,14 @@ class AsyncConversation:
|
|
|
73
107
|
self._tools = tools or ToolRegistry()
|
|
74
108
|
self._max_tool_rounds = max_tool_rounds
|
|
75
109
|
|
|
110
|
+
# Persistence
|
|
111
|
+
self._conversation_id = conversation_id or str(uuid.uuid4())
|
|
112
|
+
self._auto_save = Path(auto_save) if auto_save else None
|
|
113
|
+
self._metadata: dict[str, Any] = {
|
|
114
|
+
"created_at": datetime.now(timezone.utc).isoformat(),
|
|
115
|
+
"tags": list(tags) if tags else [],
|
|
116
|
+
}
|
|
117
|
+
|
|
76
118
|
# ------------------------------------------------------------------
|
|
77
119
|
# Public helpers
|
|
78
120
|
# ------------------------------------------------------------------
|
|
@@ -91,11 +133,12 @@ class AsyncConversation:
|
|
|
91
133
|
"""Reset message history (keeps system_prompt and driver)."""
|
|
92
134
|
self._messages.clear()
|
|
93
135
|
|
|
94
|
-
def add_context(self, role: str, content: str) -> None:
|
|
136
|
+
def add_context(self, role: str, content: str, images: list[ImageInput] | None = None) -> None:
|
|
95
137
|
"""Seed the history with a user or assistant message."""
|
|
96
138
|
if role not in ("user", "assistant"):
|
|
97
139
|
raise ValueError("role must be 'user' or 'assistant'")
|
|
98
|
-
self.
|
|
140
|
+
msg_content = self._build_content_with_images(content, images)
|
|
141
|
+
self._messages.append({"role": role, "content": msg_content})
|
|
99
142
|
|
|
100
143
|
def register_tool(
|
|
101
144
|
self,
|
|
@@ -112,17 +155,145 @@ class AsyncConversation:
|
|
|
112
155
|
u = self._usage
|
|
113
156
|
return f"Conversation: {u['total_tokens']:,} tokens across {u['turns']} turn(s) costing ${u['cost']:.4f}"
|
|
114
157
|
|
|
158
|
+
# ------------------------------------------------------------------
|
|
159
|
+
# Persistence properties
|
|
160
|
+
# ------------------------------------------------------------------
|
|
161
|
+
|
|
162
|
+
@property
|
|
163
|
+
def conversation_id(self) -> str:
|
|
164
|
+
"""Unique identifier for this conversation."""
|
|
165
|
+
return self._conversation_id
|
|
166
|
+
|
|
167
|
+
@property
|
|
168
|
+
def tags(self) -> list[str]:
|
|
169
|
+
"""Tags attached to this conversation."""
|
|
170
|
+
return self._metadata.get("tags", [])
|
|
171
|
+
|
|
172
|
+
@tags.setter
|
|
173
|
+
def tags(self, value: list[str]) -> None:
|
|
174
|
+
self._metadata["tags"] = list(value)
|
|
175
|
+
|
|
176
|
+
# ------------------------------------------------------------------
|
|
177
|
+
# Export / Import
|
|
178
|
+
# ------------------------------------------------------------------
|
|
179
|
+
|
|
180
|
+
def export(self, *, usage_session: UsageSession | None = None, strip_images: bool = False) -> dict[str, Any]:
|
|
181
|
+
"""Export conversation state to a JSON-serializable dict."""
|
|
182
|
+
tools_metadata = (
|
|
183
|
+
[
|
|
184
|
+
{"name": td.name, "description": td.description, "parameters": td.parameters}
|
|
185
|
+
for td in self._tools.definitions
|
|
186
|
+
]
|
|
187
|
+
if self._tools and self._tools.definitions
|
|
188
|
+
else None
|
|
189
|
+
)
|
|
190
|
+
return export_conversation(
|
|
191
|
+
model_name=self._model_name,
|
|
192
|
+
system_prompt=self._system_prompt,
|
|
193
|
+
options=self._options,
|
|
194
|
+
messages=self._messages,
|
|
195
|
+
usage=self._usage,
|
|
196
|
+
max_tool_rounds=self._max_tool_rounds,
|
|
197
|
+
tools_metadata=tools_metadata,
|
|
198
|
+
usage_session=usage_session,
|
|
199
|
+
metadata=self._metadata,
|
|
200
|
+
conversation_id=self._conversation_id,
|
|
201
|
+
strip_images=strip_images,
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
@classmethod
|
|
205
|
+
def from_export(
|
|
206
|
+
cls,
|
|
207
|
+
data: dict[str, Any],
|
|
208
|
+
*,
|
|
209
|
+
callbacks: DriverCallbacks | None = None,
|
|
210
|
+
tools: ToolRegistry | None = None,
|
|
211
|
+
) -> AsyncConversation:
|
|
212
|
+
"""Reconstruct an :class:`AsyncConversation` from an export dict.
|
|
213
|
+
|
|
214
|
+
The driver is reconstructed from the stored ``model_name`` using
|
|
215
|
+
:func:`get_async_driver_for_model`. Callbacks and tool functions
|
|
216
|
+
must be re-attached by the caller.
|
|
217
|
+
"""
|
|
218
|
+
imported = import_conversation(data)
|
|
219
|
+
|
|
220
|
+
model_name = imported.get("model_name") or ""
|
|
221
|
+
if not model_name:
|
|
222
|
+
raise ValueError("Cannot restore conversation: export has no model_name")
|
|
223
|
+
conv = cls(
|
|
224
|
+
model_name=model_name,
|
|
225
|
+
system_prompt=imported.get("system_prompt"),
|
|
226
|
+
options=imported.get("options", {}),
|
|
227
|
+
callbacks=callbacks,
|
|
228
|
+
tools=tools,
|
|
229
|
+
max_tool_rounds=imported.get("max_tool_rounds", 10),
|
|
230
|
+
conversation_id=imported.get("conversation_id"),
|
|
231
|
+
tags=imported.get("metadata", {}).get("tags", []),
|
|
232
|
+
)
|
|
233
|
+
conv._messages = imported.get("messages", [])
|
|
234
|
+
conv._usage = imported.get(
|
|
235
|
+
"usage",
|
|
236
|
+
{
|
|
237
|
+
"prompt_tokens": 0,
|
|
238
|
+
"completion_tokens": 0,
|
|
239
|
+
"total_tokens": 0,
|
|
240
|
+
"cost": 0.0,
|
|
241
|
+
"turns": 0,
|
|
242
|
+
},
|
|
243
|
+
)
|
|
244
|
+
meta = imported.get("metadata", {})
|
|
245
|
+
if "created_at" in meta:
|
|
246
|
+
conv._metadata["created_at"] = meta["created_at"]
|
|
247
|
+
return conv
|
|
248
|
+
|
|
249
|
+
def save(self, path: str | Path, **kwargs: Any) -> None:
|
|
250
|
+
"""Export and write to a JSON file."""
|
|
251
|
+
save_to_file(self.export(**kwargs), path)
|
|
252
|
+
|
|
253
|
+
@classmethod
|
|
254
|
+
def load(
|
|
255
|
+
cls,
|
|
256
|
+
path: str | Path,
|
|
257
|
+
*,
|
|
258
|
+
callbacks: DriverCallbacks | None = None,
|
|
259
|
+
tools: ToolRegistry | None = None,
|
|
260
|
+
) -> AsyncConversation:
|
|
261
|
+
"""Load a conversation from a JSON file."""
|
|
262
|
+
data = load_from_file(path)
|
|
263
|
+
return cls.from_export(data, callbacks=callbacks, tools=tools)
|
|
264
|
+
|
|
265
|
+
def _maybe_auto_save(self) -> None:
|
|
266
|
+
"""Auto-save after each turn if configured."""
|
|
267
|
+
if self._auto_save is None:
|
|
268
|
+
return
|
|
269
|
+
try:
|
|
270
|
+
self.save(self._auto_save)
|
|
271
|
+
except Exception:
|
|
272
|
+
logger.debug("Auto-save failed for conversation %s", self._conversation_id, exc_info=True)
|
|
273
|
+
|
|
115
274
|
# ------------------------------------------------------------------
|
|
116
275
|
# Core methods
|
|
117
276
|
# ------------------------------------------------------------------
|
|
118
277
|
|
|
119
|
-
|
|
278
|
+
@staticmethod
|
|
279
|
+
def _build_content_with_images(text: str, images: list[ImageInput] | None = None) -> str | list[dict[str, Any]]:
|
|
280
|
+
"""Return plain string when no images, or a list of content blocks."""
|
|
281
|
+
if not images:
|
|
282
|
+
return text
|
|
283
|
+
blocks: list[dict[str, Any]] = [{"type": "text", "text": text}]
|
|
284
|
+
for img in images:
|
|
285
|
+
ic = make_image(img)
|
|
286
|
+
blocks.append({"type": "image", "source": ic})
|
|
287
|
+
return blocks
|
|
288
|
+
|
|
289
|
+
def _build_messages(self, user_content: str, images: list[ImageInput] | None = None) -> list[dict[str, Any]]:
|
|
120
290
|
"""Build the full messages array for an API call."""
|
|
121
291
|
msgs: list[dict[str, Any]] = []
|
|
122
292
|
if self._system_prompt:
|
|
123
293
|
msgs.append({"role": "system", "content": self._system_prompt})
|
|
124
294
|
msgs.extend(self._messages)
|
|
125
|
-
|
|
295
|
+
content = self._build_content_with_images(user_content, images)
|
|
296
|
+
msgs.append({"role": "user", "content": content})
|
|
126
297
|
return msgs
|
|
127
298
|
|
|
128
299
|
def _accumulate_usage(self, meta: dict[str, Any]) -> None:
|
|
@@ -131,11 +302,22 @@ class AsyncConversation:
|
|
|
131
302
|
self._usage["total_tokens"] += meta.get("total_tokens", 0)
|
|
132
303
|
self._usage["cost"] += meta.get("cost", 0.0)
|
|
133
304
|
self._usage["turns"] += 1
|
|
305
|
+
self._maybe_auto_save()
|
|
306
|
+
|
|
307
|
+
from .ledger import _resolve_api_key_hash, record_model_usage
|
|
308
|
+
|
|
309
|
+
record_model_usage(
|
|
310
|
+
self._model_name,
|
|
311
|
+
api_key_hash=_resolve_api_key_hash(self._model_name),
|
|
312
|
+
tokens=meta.get("total_tokens", 0),
|
|
313
|
+
cost=meta.get("cost", 0.0),
|
|
314
|
+
)
|
|
134
315
|
|
|
135
316
|
async def ask(
|
|
136
317
|
self,
|
|
137
318
|
content: str,
|
|
138
319
|
options: dict[str, Any] | None = None,
|
|
320
|
+
images: list[ImageInput] | None = None,
|
|
139
321
|
) -> str:
|
|
140
322
|
"""Send a message and get a raw text response (async).
|
|
141
323
|
|
|
@@ -143,16 +325,17 @@ class AsyncConversation:
|
|
|
143
325
|
dispatches to the async tool execution loop.
|
|
144
326
|
"""
|
|
145
327
|
if self._tools and getattr(self._driver, "supports_tool_use", False):
|
|
146
|
-
return await self._ask_with_tools(content, options)
|
|
328
|
+
return await self._ask_with_tools(content, options, images=images)
|
|
147
329
|
|
|
148
330
|
merged = {**self._options, **(options or {})}
|
|
149
|
-
messages = self._build_messages(content)
|
|
331
|
+
messages = self._build_messages(content, images=images)
|
|
150
332
|
resp = await self._driver.generate_messages_with_hooks(messages, merged)
|
|
151
333
|
|
|
152
334
|
text = resp.get("text", "")
|
|
153
335
|
meta = resp.get("meta", {})
|
|
154
336
|
|
|
155
|
-
self.
|
|
337
|
+
user_content = self._build_content_with_images(content, images)
|
|
338
|
+
self._messages.append({"role": "user", "content": user_content})
|
|
156
339
|
self._messages.append({"role": "assistant", "content": text})
|
|
157
340
|
self._accumulate_usage(meta)
|
|
158
341
|
|
|
@@ -162,12 +345,14 @@ class AsyncConversation:
|
|
|
162
345
|
self,
|
|
163
346
|
content: str,
|
|
164
347
|
options: dict[str, Any] | None = None,
|
|
348
|
+
images: list[ImageInput] | None = None,
|
|
165
349
|
) -> str:
|
|
166
350
|
"""Async tool-use loop: send -> check tool_calls -> execute -> re-send."""
|
|
167
351
|
merged = {**self._options, **(options or {})}
|
|
168
352
|
tool_defs = self._tools.to_openai_format()
|
|
169
353
|
|
|
170
|
-
self.
|
|
354
|
+
user_content = self._build_content_with_images(content, images)
|
|
355
|
+
self._messages.append({"role": "user", "content": user_content})
|
|
171
356
|
msgs = self._build_messages_raw()
|
|
172
357
|
|
|
173
358
|
for _round in range(self._max_tool_rounds):
|
|
@@ -228,6 +413,7 @@ class AsyncConversation:
|
|
|
228
413
|
self,
|
|
229
414
|
content: str,
|
|
230
415
|
options: dict[str, Any] | None = None,
|
|
416
|
+
images: list[ImageInput] | None = None,
|
|
231
417
|
) -> AsyncIterator[str]:
|
|
232
418
|
"""Send a message and yield text chunks as they arrive (async).
|
|
233
419
|
|
|
@@ -235,13 +421,14 @@ class AsyncConversation:
|
|
|
235
421
|
support streaming.
|
|
236
422
|
"""
|
|
237
423
|
if not getattr(self._driver, "supports_streaming", False):
|
|
238
|
-
yield await self.ask(content, options)
|
|
424
|
+
yield await self.ask(content, options, images=images)
|
|
239
425
|
return
|
|
240
426
|
|
|
241
427
|
merged = {**self._options, **(options or {})}
|
|
242
|
-
messages = self._build_messages(content)
|
|
428
|
+
messages = self._build_messages(content, images=images)
|
|
243
429
|
|
|
244
|
-
self.
|
|
430
|
+
user_content = self._build_content_with_images(content, images)
|
|
431
|
+
self._messages.append({"role": "user", "content": user_content})
|
|
245
432
|
|
|
246
433
|
full_text = ""
|
|
247
434
|
async for chunk in self._driver.generate_messages_stream(messages, merged):
|
|
@@ -267,6 +454,7 @@ class AsyncConversation:
|
|
|
267
454
|
options: dict[str, Any] | None = None,
|
|
268
455
|
output_format: Literal["json", "toon"] = "json",
|
|
269
456
|
json_mode: Literal["auto", "on", "off"] = "auto",
|
|
457
|
+
images: list[ImageInput] | None = None,
|
|
270
458
|
) -> dict[str, Any]:
|
|
271
459
|
"""Send a message with schema enforcement and get structured JSON back (async)."""
|
|
272
460
|
merged = {**self._options, **(options or {})}
|
|
@@ -301,13 +489,14 @@ class AsyncConversation:
|
|
|
301
489
|
|
|
302
490
|
full_user_content = f"{content}\n\n{instruct}"
|
|
303
491
|
|
|
304
|
-
messages = self._build_messages(full_user_content)
|
|
492
|
+
messages = self._build_messages(full_user_content, images=images)
|
|
305
493
|
resp = await self._driver.generate_messages_with_hooks(messages, merged)
|
|
306
494
|
|
|
307
495
|
text = resp.get("text", "")
|
|
308
496
|
meta = resp.get("meta", {})
|
|
309
497
|
|
|
310
|
-
self.
|
|
498
|
+
user_content = self._build_content_with_images(content, images)
|
|
499
|
+
self._messages.append({"role": "user", "content": user_content})
|
|
311
500
|
|
|
312
501
|
cleaned = clean_json_text(text)
|
|
313
502
|
try:
|
|
@@ -361,6 +550,7 @@ class AsyncConversation:
|
|
|
361
550
|
output_format: Literal["json", "toon"] = "json",
|
|
362
551
|
options: dict[str, Any] | None = None,
|
|
363
552
|
json_mode: Literal["auto", "on", "off"] = "auto",
|
|
553
|
+
images: list[ImageInput] | None = None,
|
|
364
554
|
) -> dict[str, Any]:
|
|
365
555
|
"""Extract structured information into a Pydantic model with conversation context (async)."""
|
|
366
556
|
from .core import normalize_field_value
|
|
@@ -375,6 +565,7 @@ class AsyncConversation:
|
|
|
375
565
|
options=options,
|
|
376
566
|
output_format=output_format,
|
|
377
567
|
json_mode=json_mode,
|
|
568
|
+
images=images,
|
|
378
569
|
)
|
|
379
570
|
|
|
380
571
|
json_object = result["json_object"]
|
prompture/async_core.py
CHANGED
|
@@ -35,6 +35,18 @@ from .tools import (
|
|
|
35
35
|
logger = logging.getLogger("prompture.async_core")
|
|
36
36
|
|
|
37
37
|
|
|
38
|
+
def _record_usage_to_ledger(model_name: str, meta: dict[str, Any]) -> None:
|
|
39
|
+
"""Fire-and-forget ledger recording for standalone async core functions."""
|
|
40
|
+
from .ledger import _resolve_api_key_hash, record_model_usage
|
|
41
|
+
|
|
42
|
+
record_model_usage(
|
|
43
|
+
model_name,
|
|
44
|
+
api_key_hash=_resolve_api_key_hash(model_name),
|
|
45
|
+
tokens=meta.get("total_tokens", 0),
|
|
46
|
+
cost=meta.get("cost", 0.0),
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
|
|
38
50
|
async def clean_json_text_with_ai(
|
|
39
51
|
driver: AsyncDriver, text: str, model_name: str = "", options: dict[str, Any] | None = None
|
|
40
52
|
) -> str:
|
|
@@ -117,6 +129,8 @@ async def render_output(
|
|
|
117
129
|
"model_name": model_name or getattr(driver, "model", ""),
|
|
118
130
|
}
|
|
119
131
|
|
|
132
|
+
_record_usage_to_ledger(model_name, resp.get("meta", {}))
|
|
133
|
+
|
|
120
134
|
return {"text": raw, "usage": usage, "output_format": output_format}
|
|
121
135
|
|
|
122
136
|
|
|
@@ -211,6 +225,8 @@ async def ask_for_json(
|
|
|
211
225
|
raw = resp.get("text", "")
|
|
212
226
|
cleaned = clean_json_text(raw)
|
|
213
227
|
|
|
228
|
+
_record_usage_to_ledger(model_name, resp.get("meta", {}))
|
|
229
|
+
|
|
214
230
|
try:
|
|
215
231
|
json_obj = json.loads(cleaned)
|
|
216
232
|
json_string = cleaned
|
prompture/async_driver.py
CHANGED
|
@@ -35,6 +35,7 @@ class AsyncDriver:
|
|
|
35
35
|
supports_messages: bool = False
|
|
36
36
|
supports_tool_use: bool = False
|
|
37
37
|
supports_streaming: bool = False
|
|
38
|
+
supports_vision: bool = False
|
|
38
39
|
|
|
39
40
|
callbacks: DriverCallbacks | None = None
|
|
40
41
|
|
|
@@ -165,5 +166,67 @@ class AsyncDriver:
|
|
|
165
166
|
except Exception:
|
|
166
167
|
logger.exception("Callback %s raised an exception", event)
|
|
167
168
|
|
|
169
|
+
def _validate_model_capabilities(
|
|
170
|
+
self,
|
|
171
|
+
provider: str,
|
|
172
|
+
model: str,
|
|
173
|
+
*,
|
|
174
|
+
using_tool_use: bool = False,
|
|
175
|
+
using_json_schema: bool = False,
|
|
176
|
+
using_vision: bool = False,
|
|
177
|
+
) -> None:
|
|
178
|
+
"""Log warnings when the model may not support a requested feature.
|
|
179
|
+
|
|
180
|
+
Uses models.dev metadata as a secondary signal. Warnings only — the
|
|
181
|
+
API is the final authority and models.dev data may be stale.
|
|
182
|
+
"""
|
|
183
|
+
from .model_rates import get_model_capabilities
|
|
184
|
+
|
|
185
|
+
caps = get_model_capabilities(provider, model)
|
|
186
|
+
if caps is None:
|
|
187
|
+
return
|
|
188
|
+
|
|
189
|
+
if using_tool_use and caps.supports_tool_use is False:
|
|
190
|
+
logger.warning(
|
|
191
|
+
"Model %s/%s may not support tool use according to models.dev metadata",
|
|
192
|
+
provider,
|
|
193
|
+
model,
|
|
194
|
+
)
|
|
195
|
+
if using_json_schema and caps.supports_structured_output is False:
|
|
196
|
+
logger.warning(
|
|
197
|
+
"Model %s/%s may not support structured output / JSON schema according to models.dev metadata",
|
|
198
|
+
provider,
|
|
199
|
+
model,
|
|
200
|
+
)
|
|
201
|
+
if using_vision and caps.supports_vision is False:
|
|
202
|
+
logger.warning(
|
|
203
|
+
"Model %s/%s may not support vision/image inputs according to models.dev metadata",
|
|
204
|
+
provider,
|
|
205
|
+
model,
|
|
206
|
+
)
|
|
207
|
+
|
|
208
|
+
def _check_vision_support(self, messages: list[dict[str, Any]]) -> None:
|
|
209
|
+
"""Raise if messages contain image blocks and the driver lacks vision support."""
|
|
210
|
+
if self.supports_vision:
|
|
211
|
+
return
|
|
212
|
+
for msg in messages:
|
|
213
|
+
content = msg.get("content")
|
|
214
|
+
if isinstance(content, list):
|
|
215
|
+
for block in content:
|
|
216
|
+
if isinstance(block, dict) and block.get("type") == "image":
|
|
217
|
+
raise NotImplementedError(
|
|
218
|
+
f"{self.__class__.__name__} does not support vision/image inputs. "
|
|
219
|
+
"Use a vision-capable model."
|
|
220
|
+
)
|
|
221
|
+
|
|
222
|
+
def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
223
|
+
"""Transform universal message format into provider-specific wire format.
|
|
224
|
+
|
|
225
|
+
Vision-capable async drivers override this to convert the universal
|
|
226
|
+
image blocks into their provider-specific format.
|
|
227
|
+
"""
|
|
228
|
+
self._check_vision_support(messages)
|
|
229
|
+
return messages
|
|
230
|
+
|
|
168
231
|
# Re-export the static helper for convenience
|
|
169
232
|
_flatten_messages = staticmethod(Driver._flatten_messages)
|