prompture 0.0.29.dev8-py3-none-any.whl → 0.0.38.dev2-py3-none-any.whl
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- prompture/__init__.py +264 -23
- prompture/_version.py +34 -0
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +789 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +193 -0
- prompture/async_groups.py +551 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +55 -0
- prompture/cli.py +63 -4
- prompture/conversation.py +826 -0
- prompture/core.py +894 -263
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +187 -0
- prompture/driver.py +206 -5
- prompture/drivers/__init__.py +175 -67
- prompture/drivers/airllm_driver.py +109 -0
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +123 -0
- prompture/drivers/async_claude_driver.py +113 -0
- prompture/drivers/async_google_driver.py +316 -0
- prompture/drivers/async_grok_driver.py +97 -0
- prompture/drivers/async_groq_driver.py +90 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +148 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +135 -0
- prompture/drivers/async_openai_driver.py +102 -0
- prompture/drivers/async_openrouter_driver.py +102 -0
- prompture/drivers/async_registry.py +133 -0
- prompture/drivers/azure_driver.py +42 -9
- prompture/drivers/claude_driver.py +257 -34
- prompture/drivers/google_driver.py +295 -42
- prompture/drivers/grok_driver.py +35 -32
- prompture/drivers/groq_driver.py +33 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +97 -19
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +168 -23
- prompture/drivers/openai_driver.py +184 -9
- prompture/drivers/openrouter_driver.py +37 -25
- prompture/drivers/registry.py +306 -0
- prompture/drivers/vision_helpers.py +153 -0
- prompture/field_definitions.py +106 -96
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/runner.py +49 -47
- prompture/scaffold/__init__.py +1 -0
- prompture/scaffold/generator.py +84 -0
- prompture/scaffold/templates/Dockerfile.j2 +12 -0
- prompture/scaffold/templates/README.md.j2 +41 -0
- prompture/scaffold/templates/config.py.j2 +21 -0
- prompture/scaffold/templates/env.example.j2 +8 -0
- prompture/scaffold/templates/main.py.j2 +86 -0
- prompture/scaffold/templates/models.py.j2 +40 -0
- prompture/scaffold/templates/requirements.txt.j2 +5 -0
- prompture/serialization.py +218 -0
- prompture/server.py +183 -0
- prompture/session.py +117 -0
- prompture/settings.py +19 -1
- prompture/tools.py +219 -267
- prompture/tools_schema.py +254 -0
- prompture/validator.py +3 -3
- prompture-0.0.38.dev2.dist-info/METADATA +369 -0
- prompture-0.0.38.dev2.dist-info/RECORD +77 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
- prompture-0.0.29.dev8.dist-info/METADATA +0 -368
- prompture-0.0.29.dev8.dist-info/RECORD +0 -27
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
`prompture/conversation.py` is new in this release. The full 826-line hunk (`@@ -0,0 +1,826 @@`) is reproduced below in sections.

```python
"""Stateful multi-turn conversation support for Prompture."""

from __future__ import annotations

import json
import logging
import uuid
from collections.abc import Iterator
from datetime import date, datetime, timezone
from decimal import Decimal
from pathlib import Path
from typing import Any, Callable, Literal, Union

from pydantic import BaseModel

from .callbacks import DriverCallbacks
from .driver import Driver
from .drivers import get_driver_for_model
from .field_definitions import get_registry_snapshot
from .image import ImageInput, make_image
from .persistence import load_from_file, save_to_file
from .persona import Persona, get_persona
from .serialization import export_conversation, import_conversation
from .session import UsageSession
from .tools import (
    clean_json_text,
    convert_value,
    get_field_default,
)
from .tools_schema import ToolRegistry

logger = logging.getLogger("prompture.conversation")


class Conversation:
    """Stateful multi-turn conversation with an LLM.

    Maintains a message history across calls so the model can reference
    previous turns. Works with any Prompture driver.

    Example::

        conv = Conversation("openai/gpt-4", system_prompt="You are a data extractor")
        r1 = conv.ask_for_json("Extract names from: John, age 30", name_schema)
        r2 = conv.ask_for_json("Now extract ages", age_schema)  # sees turn 1
    """

    def __init__(
        self,
        model_name: str | None = None,
        *,
        driver: Driver | None = None,
        system_prompt: str | None = None,
        persona: str | Persona | None = None,
        options: dict[str, Any] | None = None,
        callbacks: DriverCallbacks | None = None,
        tools: ToolRegistry | None = None,
        max_tool_rounds: int = 10,
        conversation_id: str | None = None,
        auto_save: str | Path | None = None,
        tags: list[str] | None = None,
    ) -> None:
        if system_prompt is not None and persona is not None:
            raise ValueError("Cannot provide both 'system_prompt' and 'persona'. Use one or the other.")

        # Resolve persona
        resolved_persona: Persona | None = None
        if persona is not None:
            if isinstance(persona, str):
                resolved_persona = get_persona(persona)
                if resolved_persona is None:
                    raise ValueError(f"Persona '{persona}' not found in registry.")
            else:
                resolved_persona = persona

        if model_name is None and driver is None:
            # Check persona for model_hint
            if resolved_persona is not None and resolved_persona.model_hint:
                model_name = resolved_persona.model_hint
            else:
                raise ValueError("Either model_name or driver must be provided")

        if driver is not None:
            self._driver = driver
        else:
            self._driver = get_driver_for_model(model_name)

        if callbacks is not None:
            self._driver.callbacks = callbacks

        self._model_name = model_name or ""

        # Apply persona: render system_prompt and merge settings
        if resolved_persona is not None:
            self._system_prompt = resolved_persona.render()
            # Persona settings as defaults, explicit options override
            self._options = {**resolved_persona.settings, **(dict(options) if options else {})}
        else:
            self._system_prompt = system_prompt
            self._options = dict(options) if options else {}

        self._messages: list[dict[str, Any]] = []
        self._usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "cost": 0.0,
            "turns": 0,
        }
        self._tools = tools or ToolRegistry()
        self._max_tool_rounds = max_tool_rounds

        # Persistence
        self._conversation_id = conversation_id or str(uuid.uuid4())
        self._auto_save = Path(auto_save) if auto_save else None
        self._metadata: dict[str, Any] = {
            "created_at": datetime.now(timezone.utc).isoformat(),
            "tags": list(tags) if tags else [],
        }
```
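A minimal construction sketch (not part of the diff); the model string follows the `provider/model` convention from the class docstring, and the `"extractor"` persona name is hypothetical:

```python
from prompture.conversation import Conversation

# Plain system prompt, auto-saving the full state to disk after every turn
conv = Conversation(
    "openai/gpt-4",
    system_prompt="You are a data extractor",
    auto_save="conv.json",
    tags=["demo"],
)

# Alternatively, a registered persona supplies the system prompt, default
# options, and (via its model_hint) possibly the model itself:
# conv = Conversation(persona="extractor")  # hypothetical persona name
```

Note that `system_prompt` and `persona` are mutually exclusive; passing both raises `ValueError`.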
```python
    # ------------------------------------------------------------------
    # Public helpers
    # ------------------------------------------------------------------

    @property
    def messages(self) -> list[dict[str, Any]]:
        """Read-only view of the conversation history."""
        return list(self._messages)

    @property
    def usage(self) -> dict[str, Any]:
        """Accumulated token/cost totals across all turns."""
        return dict(self._usage)

    def clear(self) -> None:
        """Reset message history (keeps system_prompt and driver)."""
        self._messages.clear()

    def add_context(self, role: str, content: str, images: list[ImageInput] | None = None) -> None:
        """Seed the history with a user or assistant message."""
        if role not in ("user", "assistant"):
            raise ValueError("role must be 'user' or 'assistant'")
        msg_content = self._build_content_with_images(content, images)
        self._messages.append({"role": role, "content": msg_content})

    def register_tool(
        self,
        fn: Callable[..., Any],
        *,
        name: str | None = None,
        description: str | None = None,
    ) -> None:
        """Register a Python function as a tool the LLM can call."""
        self._tools.register(fn, name=name, description=description)

    def usage_summary(self) -> str:
        """Human-readable summary of accumulated usage."""
        u = self._usage
        return f"Conversation: {u['total_tokens']:,} tokens across {u['turns']} turn(s) costing ${u['cost']:.4f}"
```
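A tool-registration sketch (not part of the diff); when `name` and `description` are omitted, `ToolRegistry.register` presumably derives them from the function itself:

```python
def get_weather(city: str) -> str:
    """Return a short weather report for a city."""
    return f"Sunny in {city}"  # stand-in for a real lookup

conv.register_tool(get_weather)
# or with explicit metadata:
conv.register_tool(get_weather, name="weather", description="Look up current weather")

print(conv.usage_summary())  # e.g. "Conversation: 1,234 tokens across 2 turn(s) costing $0.0042"
```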
```python
    # ------------------------------------------------------------------
    # Persistence properties
    # ------------------------------------------------------------------

    @property
    def conversation_id(self) -> str:
        """Unique identifier for this conversation."""
        return self._conversation_id

    @property
    def tags(self) -> list[str]:
        """Tags attached to this conversation."""
        return self._metadata.get("tags", [])

    @tags.setter
    def tags(self, value: list[str]) -> None:
        self._metadata["tags"] = list(value)

    # ------------------------------------------------------------------
    # Export / Import
    # ------------------------------------------------------------------

    def export(self, *, usage_session: UsageSession | None = None, strip_images: bool = False) -> dict[str, Any]:
        """Export conversation state to a JSON-serializable dict."""
        tools_metadata = (
            [
                {"name": td.name, "description": td.description, "parameters": td.parameters}
                for td in self._tools.definitions
            ]
            if self._tools and self._tools.definitions
            else None
        )
        return export_conversation(
            model_name=self._model_name,
            system_prompt=self._system_prompt,
            options=self._options,
            messages=self._messages,
            usage=self._usage,
            max_tool_rounds=self._max_tool_rounds,
            tools_metadata=tools_metadata,
            usage_session=usage_session,
            metadata=self._metadata,
            conversation_id=self._conversation_id,
            strip_images=strip_images,
        )

    @classmethod
    def from_export(
        cls,
        data: dict[str, Any],
        *,
        callbacks: DriverCallbacks | None = None,
        tools: ToolRegistry | None = None,
    ) -> Conversation:
        """Reconstruct a :class:`Conversation` from an export dict.

        The driver is reconstructed from the stored ``model_name``.
        Callbacks and tool *functions* must be re-attached by the caller
        (tool metadata — name/description/parameters — is preserved in
        the export but executable functions cannot be serialized).
        """
        imported = import_conversation(data)

        model_name = imported.get("model_name") or ""
        if not model_name:
            raise ValueError("Cannot restore conversation: export has no model_name")
        conv = cls(
            model_name=model_name,
            system_prompt=imported.get("system_prompt"),
            options=imported.get("options", {}),
            callbacks=callbacks,
            tools=tools,
            max_tool_rounds=imported.get("max_tool_rounds", 10),
            conversation_id=imported.get("conversation_id"),
            tags=imported.get("metadata", {}).get("tags", []),
        )
        conv._messages = imported.get("messages", [])
        conv._usage = imported.get(
            "usage",
            {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
                "cost": 0.0,
                "turns": 0,
            },
        )
        meta = imported.get("metadata", {})
        if "created_at" in meta:
            conv._metadata["created_at"] = meta["created_at"]
        return conv

    def save(self, path: str | Path, **kwargs: Any) -> None:
        """Export and write to a JSON file.

        Keyword arguments are forwarded to :meth:`export`.
        """
        save_to_file(self.export(**kwargs), path)

    @classmethod
    def load(
        cls,
        path: str | Path,
        *,
        callbacks: DriverCallbacks | None = None,
        tools: ToolRegistry | None = None,
    ) -> Conversation:
        """Load a conversation from a JSON file."""
        data = load_from_file(path)
        return cls.from_export(data, callbacks=callbacks, tools=tools)

    def _maybe_auto_save(self) -> None:
        """Auto-save after each turn if configured. Errors are silently logged."""
        if self._auto_save is None:
            return
        try:
            self.save(self._auto_save)
        except Exception:
            logger.debug("Auto-save failed for conversation %s", self._conversation_id, exc_info=True)
```
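A persistence sketch (not part of the diff). Callbacks and executable tool functions are not serialized, so they are re-attached at load time:

```python
conv.save("support_chat.json")

restored = Conversation.load("support_chat.json")
assert restored.conversation_id == conv.conversation_id
assert restored.messages == conv.messages

# Re-attach executable tools if the conversation used any
# (my_registry is a hypothetical pre-built ToolRegistry):
# restored = Conversation.load("support_chat.json", tools=my_registry)
```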
```python
    # ------------------------------------------------------------------
    # Core methods
    # ------------------------------------------------------------------

    @staticmethod
    def _build_content_with_images(text: str, images: list[ImageInput] | None = None) -> str | list[dict[str, Any]]:
        """Return plain string when no images, or a list of content blocks."""
        if not images:
            return text
        blocks: list[dict[str, Any]] = [{"type": "text", "text": text}]
        for img in images:
            ic = make_image(img)
            blocks.append({"type": "image", "source": ic})
        return blocks

    def _build_messages(self, user_content: str, images: list[ImageInput] | None = None) -> list[dict[str, Any]]:
        """Build the full messages array for an API call."""
        msgs: list[dict[str, Any]] = []
        if self._system_prompt:
            msgs.append({"role": "system", "content": self._system_prompt})
        msgs.extend(self._messages)
        content = self._build_content_with_images(user_content, images)
        msgs.append({"role": "user", "content": content})
        return msgs

    def _accumulate_usage(self, meta: dict[str, Any]) -> None:
        self._usage["prompt_tokens"] += meta.get("prompt_tokens", 0)
        self._usage["completion_tokens"] += meta.get("completion_tokens", 0)
        self._usage["total_tokens"] += meta.get("total_tokens", 0)
        self._usage["cost"] += meta.get("cost", 0.0)
        self._usage["turns"] += 1
        self._maybe_auto_save()

    def ask(
        self,
        content: str,
        options: dict[str, Any] | None = None,
        images: list[ImageInput] | None = None,
    ) -> str:
        """Send a message and get a raw text response.

        Appends the user message and assistant response to history.
        If tools are registered and the driver supports tool use,
        dispatches to the tool execution loop.

        Args:
            content: The text message to send.
            options: Additional options for the driver.
            images: Optional list of images to include (bytes, path, URL,
                base64 string, or :class:`ImageContent`).
        """
        if self._tools and getattr(self._driver, "supports_tool_use", False):
            return self._ask_with_tools(content, options, images=images)

        merged = {**self._options, **(options or {})}
        messages = self._build_messages(content, images=images)
        resp = self._driver.generate_messages_with_hooks(messages, merged)

        text = resp.get("text", "")
        meta = resp.get("meta", {})

        # Record in history — store content with images for context
        user_content = self._build_content_with_images(content, images)
        self._messages.append({"role": "user", "content": user_content})
        self._messages.append({"role": "assistant", "content": text})
        self._accumulate_usage(meta)

        return text
```
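A plain-text usage sketch (not part of the diff); the multimodal call assumes a local file `chart.png`:

```python
conv = Conversation("openai/gpt-4")
reply = conv.ask("What is the capital of France?")
followup = conv.ask("How many people live in it?")  # "it" resolves against turn 1

# Images ride along as content blocks (bytes, path, URL, or base64):
# conv.ask("Describe this chart", images=["chart.png"])

print(conv.usage)  # {'prompt_tokens': ..., 'completion_tokens': ..., 'turns': 2, ...}
```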
```python
    def _ask_with_tools(
        self,
        content: str,
        options: dict[str, Any] | None = None,
        images: list[ImageInput] | None = None,
    ) -> str:
        """Execute the tool-use loop: send -> check tool_calls -> execute -> re-send."""
        merged = {**self._options, **(options or {})}
        tool_defs = self._tools.to_openai_format()

        # Build messages including user content
        user_content = self._build_content_with_images(content, images)
        self._messages.append({"role": "user", "content": user_content})
        msgs = self._build_messages_raw()

        for _round in range(self._max_tool_rounds):
            resp = self._driver.generate_messages_with_tools(msgs, tool_defs, merged)

            meta = resp.get("meta", {})
            self._accumulate_usage(meta)

            tool_calls = resp.get("tool_calls", [])
            text = resp.get("text", "")

            if not tool_calls:
                # No tool calls -> final response
                self._messages.append({"role": "assistant", "content": text})
                return text

            # Record assistant message with tool_calls
            assistant_msg: dict[str, Any] = {"role": "assistant", "content": text}
            assistant_msg["tool_calls"] = [
                {
                    "id": tc["id"],
                    "type": "function",
                    "function": {"name": tc["name"], "arguments": json.dumps(tc["arguments"])},
                }
                for tc in tool_calls
            ]
            self._messages.append(assistant_msg)
            msgs.append(assistant_msg)

            # Execute each tool call and append results
            for tc in tool_calls:
                try:
                    result = self._tools.execute(tc["name"], tc["arguments"])
                    result_str = json.dumps(result) if not isinstance(result, str) else result
                except Exception as exc:
                    result_str = f"Error: {exc}"

                tool_result_msg: dict[str, Any] = {
                    "role": "tool",
                    "tool_call_id": tc["id"],
                    "content": result_str,
                }
                self._messages.append(tool_result_msg)
                msgs.append(tool_result_msg)

        raise RuntimeError(f"Tool execution loop exceeded {self._max_tool_rounds} rounds")

    def _build_messages_raw(self) -> list[dict[str, Any]]:
        """Build messages array from system prompt + full history (including tool messages)."""
        msgs: list[dict[str, Any]] = []
        if self._system_prompt:
            msgs.append({"role": "system", "content": self._system_prompt})
        msgs.extend(self._messages)
        return msgs
```
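For reference, the loop above records history in the OpenAI function-calling shape; the ids and arguments in this sketch are illustrative only:

```python
# Assistant turn that requested a tool call:
assistant_turn = {
    "role": "assistant",
    "content": "",
    "tool_calls": [
        {
            "id": "call_0",  # illustrative; real ids come from the driver
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        }
    ],
}
# ...followed by one result message per call:
tool_turn = {"role": "tool", "tool_call_id": "call_0", "content": '"Sunny in Paris"'}
```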
```python
    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    def ask_stream(
        self,
        content: str,
        options: dict[str, Any] | None = None,
        images: list[ImageInput] | None = None,
    ) -> Iterator[str]:
        """Send a message and yield text chunks as they arrive.

        Falls back to non-streaming :meth:`ask` if the driver doesn't
        support streaming. After iteration completes, the full response
        is recorded in history.
        """
        if not getattr(self._driver, "supports_streaming", False):
            yield self.ask(content, options, images=images)
            return

        merged = {**self._options, **(options or {})}
        messages = self._build_messages(content, images=images)

        user_content = self._build_content_with_images(content, images)
        self._messages.append({"role": "user", "content": user_content})

        full_text = ""
        for chunk in self._driver.generate_messages_stream(messages, merged):
            if chunk["type"] == "delta":
                full_text += chunk["text"]
                # Fire stream delta callback
                self._driver._fire_callback(
                    "on_stream_delta",
                    {"text": chunk["text"], "driver": getattr(self._driver, "model", self._driver.__class__.__name__)},
                )
                yield chunk["text"]
            elif chunk["type"] == "done":
                meta = chunk.get("meta", {})
                self._accumulate_usage(meta)

        self._messages.append({"role": "assistant", "content": full_text})
```
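A streaming sketch (not part of the diff); drivers without streaming support fall back to a single chunk:

```python
conv = Conversation("openai/gpt-4")
for piece in conv.ask_stream("Summarize our discussion so far"):
    print(piece, end="", flush=True)
print()
# The full response is now recorded in history, same as a non-streaming ask().
```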
```python
    def ask_for_json(
        self,
        content: str,
        json_schema: dict[str, Any],
        *,
        ai_cleanup: bool = True,
        options: dict[str, Any] | None = None,
        output_format: Literal["json", "toon"] = "json",
        json_mode: Literal["auto", "on", "off"] = "auto",
        images: list[ImageInput] | None = None,
    ) -> dict[str, Any]:
        """Send a message with schema enforcement and get structured JSON back.

        The schema instructions are appended to the prompt but only the
        original *content* is stored in conversation history to keep
        context clean for subsequent turns.
        """

        merged = {**self._options, **(options or {})}

        # Build the full prompt with schema instructions inline (handled by ask_for_json)
        # We use a special approach: call ask_for_json with the driver but pass messages context
        schema_string = json.dumps(json_schema, indent=2)

        # Determine JSON mode
        use_json_mode = False
        if json_mode == "on":
            use_json_mode = True
        elif json_mode == "auto":
            use_json_mode = getattr(self._driver, "supports_json_mode", False)

        if use_json_mode:
            merged = {**merged, "json_mode": True}
            if getattr(self._driver, "supports_json_schema", False):
                merged["json_schema"] = json_schema

        # Build instruction based on JSON mode
        if use_json_mode and getattr(self._driver, "supports_json_schema", False):
            instruct = "Extract data matching the requested schema.\nIf a value is unknown use null."
        elif use_json_mode:
            instruct = (
                "Return a JSON object that validates against this schema:\n"
                f"{schema_string}\n\n"
                "If a value is unknown use null."
            )
        else:
            instruct = (
                "Return only a single JSON object (no markdown, no extra text) that validates against this JSON schema:\n"
                f"{schema_string}\n\n"
                "If a value is unknown use null. Use double quotes for keys and strings."
            )

        full_user_content = f"{content}\n\n{instruct}"

        messages = self._build_messages(full_user_content, images=images)
        resp = self._driver.generate_messages_with_hooks(messages, merged)

        text = resp.get("text", "")
        meta = resp.get("meta", {})

        # Store original content (without schema boilerplate) for cleaner context
        # Include images in history so subsequent turns can reference them
        user_content = self._build_content_with_images(content, images)
        self._messages.append({"role": "user", "content": user_content})

        # Parse JSON
        cleaned = clean_json_text(text)
        try:
            json_obj = json.loads(cleaned)
        except json.JSONDecodeError:
            if ai_cleanup:
                from .core import clean_json_text_with_ai

                cleaned = clean_json_text_with_ai(self._driver, cleaned, self._model_name, merged)
                json_obj = json.loads(cleaned)
            else:
                raise

        # Store assistant response in history
        self._messages.append({"role": "assistant", "content": cleaned})
        self._accumulate_usage(meta)

        model_name = self._model_name
        if "/" in model_name:
            model_name = model_name.split("/", 1)[1]

        usage = {
            **meta,
            "raw_response": resp,
            "model_name": model_name or getattr(self._driver, "model", ""),
        }

        result: dict[str, Any] = {
            "json_string": cleaned,
            "json_object": json_obj,
            "usage": usage,
            "output_format": output_format,
        }

        if output_format == "toon":
            try:
                import toon

                result["toon_string"] = toon.encode(json_obj)
            except ImportError:
                raise RuntimeError("TOON requested but 'python-toon' is not installed.") from None

        return result
```
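A structured-output sketch (not part of the diff); the extracted values are model output, so they are not guaranteed:

```python
conv = Conversation("openai/gpt-4")
person_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": ["integer", "null"]},
    },
}
result = conv.ask_for_json("Extract the person from: John, age 30", person_schema)
result["json_object"]  # e.g. {'name': 'John', 'age': 30}
result["usage"]        # per-call token/cost metadata plus the raw response
```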
```python
    def extract_with_model(
        self,
        model_cls: type[BaseModel],
        text: str,
        *,
        instruction_template: str = "Extract information from the following text:",
        ai_cleanup: bool = True,
        output_format: Literal["json", "toon"] = "json",
        options: dict[str, Any] | None = None,
        json_mode: Literal["auto", "on", "off"] = "auto",
        images: list[ImageInput] | None = None,
    ) -> dict[str, Any]:
        """Extract structured information into a Pydantic model with conversation context."""
        from .core import normalize_field_value

        schema = model_cls.model_json_schema()
        content_prompt = f"{instruction_template} {text}"

        result = self.ask_for_json(
            content=content_prompt,
            json_schema=schema,
            ai_cleanup=ai_cleanup,
            options=options,
            output_format=output_format,
            json_mode=json_mode,
            images=images,
        )

        # Normalize field values
        json_object = result["json_object"]
        schema_properties = schema.get("properties", {})

        for field_name, field_info in model_cls.model_fields.items():
            if field_name in json_object and field_name in schema_properties:
                field_def = {
                    "nullable": not schema_properties[field_name].get("type")
                    or "null"
                    in (
                        schema_properties[field_name].get("anyOf", [])
                        if isinstance(schema_properties[field_name].get("anyOf"), list)
                        else []
                    ),
                    "default": field_info.default
                    if hasattr(field_info, "default") and field_info.default is not ...
                    else None,
                }
                json_object[field_name] = normalize_field_value(
                    json_object[field_name], field_info.annotation, field_def
                )

        model_instance = model_cls(**json_object)

        result_dict = {
            "json_string": result["json_string"],
            "json_object": result["json_object"],
            "usage": result["usage"],
        }
        result_dict["model"] = model_instance

        return type(
            "ExtractResult",
            (dict,),
            {
                "__getattr__": lambda self, key: self.get(key),
                "__call__": lambda self: self["model"],
            },
        )(result_dict)
```
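A Pydantic extraction sketch (not part of the diff). The dynamically built `ExtractResult` returned above resolves attributes via `__getattr__` and yields the validated model instance when called:

```python
from pydantic import BaseModel

class Person(BaseModel):
    name: str
    age: int | None = None

conv = Conversation("openai/gpt-4")
res = conv.extract_with_model(Person, "John is a 30-year-old engineer.")
person = res()          # the validated Person instance
print(res.json_object)  # the normalized dict the model returned
```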
```python
    # ------------------------------------------------------------------
    # Internal: stepwise with shared context
    # ------------------------------------------------------------------

    def _stepwise_extract(
        self,
        model_cls: type[BaseModel],
        text: str,
        instruction_template: str,
        ai_cleanup: bool,
        fields: list[str] | None,
        field_definitions: dict[str, Any] | None,
        json_mode: Literal["auto", "on", "off"],
    ) -> dict[str, Union[str, dict[str, Any]]]:
        """Stepwise extraction using conversation context between fields."""
        if field_definitions is None:
            field_definitions = get_registry_snapshot()

        data: dict[str, Any] = {}
        validation_errors: list[str] = []
        field_results: dict[str, Any] = {}

        accumulated_usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "cost": 0.0,
            "model_name": self._model_name,
            "field_usages": {},
        }

        valid_fields = set(model_cls.model_fields.keys())
        if fields is not None:
            invalid_fields = set(fields) - valid_fields
            if invalid_fields:
                raise KeyError(f"Fields not found in model: {', '.join(invalid_fields)}")
            field_items = [(name, model_cls.model_fields[name]) for name in fields]
        else:
            field_items = list(model_cls.model_fields.items())

        # Seed conversation with the source text
        self.add_context("user", f"I need to extract information from this text:\n\n{text}")
        self.add_context(
            "assistant", "I'll help you extract the information from that text. What would you like to extract?"
        )

        for field_name, field_info in field_items:
            logger.debug("[stepwise-conv] Extracting field: %s", field_name)

            field_schema = {
                "value": {
                    "type": "integer" if field_info.annotation is int else "string",
                    "description": field_info.description or f"Value for {field_name}",
                }
            }

            try:
                prompt = instruction_template.format(field_name=field_name)
                result = self.ask_for_json(
                    content=f"{prompt} {text}",
                    json_schema=field_schema,
                    ai_cleanup=ai_cleanup,
                    json_mode=json_mode,
                )

                field_usage = result.get("usage", {})
                accumulated_usage["prompt_tokens"] += field_usage.get("prompt_tokens", 0)
                accumulated_usage["completion_tokens"] += field_usage.get("completion_tokens", 0)
                accumulated_usage["total_tokens"] += field_usage.get("total_tokens", 0)
                accumulated_usage["cost"] += field_usage.get("cost", 0.0)
                accumulated_usage["field_usages"][field_name] = field_usage

                extracted_value = result["json_object"]["value"]
                if isinstance(extracted_value, dict) and "value" in extracted_value:
                    raw_value = extracted_value["value"]
                else:
                    raw_value = extracted_value

                # Normalize
                from .core import normalize_field_value

                field_def = {}
                if field_definitions and field_name in field_definitions:
                    field_def = field_definitions[field_name] if isinstance(field_definitions[field_name], dict) else {}

                nullable = field_def.get("nullable", True)
                default_value = field_def.get("default")
                if (
                    default_value is None
                    and hasattr(field_info, "default")
                    and field_info.default is not ...
                    and str(field_info.default) != "PydanticUndefined"
                ):
                    default_value = field_info.default

                normalize_def = {"nullable": nullable, "default": default_value}
                raw_value = normalize_field_value(raw_value, field_info.annotation, normalize_def)

                try:
                    converted_value = convert_value(raw_value, field_info.annotation, allow_shorthand=True)
                    data[field_name] = converted_value
                    field_results[field_name] = {"status": "success", "used_default": False}
                except ValueError as e:
                    error_msg = f"Type conversion failed for {field_name}: {e!s}"
                    has_default = _has_default(field_name, field_info, field_definitions)
                    if not has_default:
                        validation_errors.append(error_msg)
                    default_value = get_field_default(field_name, field_info, field_definitions)
                    data[field_name] = default_value
                    field_results[field_name] = {
                        "status": "conversion_failed",
                        "error": error_msg,
                        "used_default": True,
                    }

            except Exception as e:
                error_msg = f"Extraction failed for {field_name}: {e!s}"
                has_default = _has_default(field_name, field_info, field_definitions)
                if not has_default:
                    validation_errors.append(error_msg)
                default_value = get_field_default(field_name, field_info, field_definitions)
                data[field_name] = default_value
                field_results[field_name] = {"status": "extraction_failed", "error": error_msg, "used_default": True}
                accumulated_usage["field_usages"][field_name] = {
                    "error": str(e),
                    "status": "failed",
                    "used_default": True,
                    "default_value": default_value,
                }

        if validation_errors:
            accumulated_usage["validation_errors"] = validation_errors

        try:
            model_instance = model_cls(**data)
            model_dict = model_instance.model_dump()

            class ExtendedJSONEncoder(json.JSONEncoder):
                def default(self, obj):
                    if isinstance(obj, (datetime, date)):
                        return obj.isoformat()
                    if isinstance(obj, Decimal):
                        return str(obj)
                    return super().default(obj)

            json_string = json.dumps(model_dict, cls=ExtendedJSONEncoder)

            result = {
                "json_string": json_string,
                "json_object": json.loads(json_string),
                "usage": accumulated_usage,
                "field_results": field_results,
            }
            result["model"] = model_instance
            return type(
                "ExtractResult",
                (dict,),
                {"__getattr__": lambda self, key: self.get(key), "__call__": lambda self: self["model"]},
            )(result)
        except Exception as e:
            error_msg = f"Model validation error: {e!s}"
            if "validation_errors" not in accumulated_usage:
                accumulated_usage["validation_errors"] = []
            accumulated_usage["validation_errors"].append(error_msg)

            error_result = {
                "json_string": "{}",
                "json_object": {},
                "usage": accumulated_usage,
                "field_results": field_results,
                "error": error_msg,
            }
            return type(
                "ExtractResult",
                (dict,),
                {"__getattr__": lambda self, key: self.get(key), "__call__": lambda self: None},
            )(error_result)


def _has_default(field_name: str, field_info: Any, field_definitions: dict[str, Any] | None) -> bool:
    """Check whether a Pydantic field has a usable default value."""
    if field_definitions and field_name in field_definitions:
        fd = field_definitions[field_name]
        if isinstance(fd, dict) and "default" in fd:
            return True
    if hasattr(field_info, "default"):
        val = field_info.default
        if val is not ... and str(val) != "PydanticUndefined":
            return True
    return False
```