prompture 0.0.33.dev1__py3-none-any.whl → 0.0.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +133 -49
- prompture/_version.py +34 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_conversation.py +484 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +131 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +50 -0
- prompture/cli.py +7 -3
- prompture/conversation.py +504 -0
- prompture/core.py +475 -352
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +50 -35
- prompture/driver.py +125 -5
- prompture/drivers/__init__.py +171 -73
- prompture/drivers/airllm_driver.py +13 -20
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +117 -0
- prompture/drivers/async_claude_driver.py +107 -0
- prompture/drivers/async_google_driver.py +132 -0
- prompture/drivers/async_grok_driver.py +91 -0
- prompture/drivers/async_groq_driver.py +84 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +79 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +125 -0
- prompture/drivers/async_openai_driver.py +96 -0
- prompture/drivers/async_openrouter_driver.py +96 -0
- prompture/drivers/async_registry.py +129 -0
- prompture/drivers/azure_driver.py +36 -9
- prompture/drivers/claude_driver.py +86 -34
- prompture/drivers/google_driver.py +87 -51
- prompture/drivers/grok_driver.py +29 -32
- prompture/drivers/groq_driver.py +27 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +26 -13
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +90 -23
- prompture/drivers/openai_driver.py +36 -9
- prompture/drivers/openrouter_driver.py +31 -25
- prompture/drivers/registry.py +306 -0
- prompture/field_definitions.py +106 -96
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/runner.py +49 -47
- prompture/session.py +117 -0
- prompture/settings.py +14 -1
- prompture/tools.py +172 -265
- prompture/validator.py +3 -3
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/METADATA +18 -20
- prompture-0.0.34.dist-info/RECORD +55 -0
- prompture-0.0.33.dev1.dist-info/RECORD +0 -29
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/WHEEL +0 -0
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,484 @@
|
|
|
1
|
+
"""Async stateful multi-turn conversation support for Prompture."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
from datetime import date, datetime
|
|
8
|
+
from decimal import Decimal
|
|
9
|
+
from typing import Any, Literal, Union
|
|
10
|
+
|
|
11
|
+
from pydantic import BaseModel
|
|
12
|
+
|
|
13
|
+
from .async_driver import AsyncDriver
|
|
14
|
+
from .callbacks import DriverCallbacks
|
|
15
|
+
from .drivers.async_registry import get_async_driver_for_model
|
|
16
|
+
from .field_definitions import get_registry_snapshot
|
|
17
|
+
from .tools import (
|
|
18
|
+
clean_json_text,
|
|
19
|
+
convert_value,
|
|
20
|
+
get_field_default,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Module-level logger, namespaced under the "prompture" logger hierarchy.
logger = logging.getLogger("prompture.async_conversation")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class AsyncConversation:
    """Async stateful multi-turn conversation with an LLM.

    Mirrors :class:`Conversation` but all methods are ``async``.

    The instance holds the message history, default request options, and
    running token/cost totals, so successive calls share context with the
    model.

    Example::

        conv = AsyncConversation("openai/gpt-4", system_prompt="You are a data extractor")
        r1 = await conv.ask_for_json("Extract names from: John, age 30", name_schema)
        r2 = await conv.ask_for_json("Now extract ages", age_schema)
    """

    def __init__(
        self,
        model_name: str | None = None,
        *,
        driver: AsyncDriver | None = None,
        system_prompt: str | None = None,
        options: dict[str, Any] | None = None,
        callbacks: DriverCallbacks | None = None,
    ) -> None:
        """Bind the conversation to an async driver.

        Args:
            model_name: Provider-qualified model id (e.g. ``"openai/gpt-4"``)
                used to resolve a driver when ``driver`` is not supplied.
            driver: Pre-constructed :class:`AsyncDriver`; takes precedence
                over resolution via ``model_name``.
            system_prompt: Optional system message prepended to every request.
            options: Default request options merged into every call.
            callbacks: Optional hooks assigned onto the driver.

        Raises:
            ValueError: If neither ``model_name`` nor ``driver`` is given.
        """
        if model_name is None and driver is None:
            raise ValueError("Either model_name or driver must be provided")

        if driver is not None:
            self._driver = driver
        else:
            self._driver = get_async_driver_for_model(model_name)

        if callbacks is not None:
            self._driver.callbacks = callbacks

        self._model_name = model_name or ""
        self._system_prompt = system_prompt
        # Copy so later mutation of the caller's dict cannot affect us.
        self._options = dict(options) if options else {}
        self._messages: list[dict[str, str]] = []
        # Running totals across all turns; updated by _accumulate_usage().
        self._usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "cost": 0.0,
            "turns": 0,
        }

    # ------------------------------------------------------------------
    # Public helpers
    # ------------------------------------------------------------------

    @property
    def messages(self) -> list[dict[str, str]]:
        """Read-only view of the conversation history (a shallow copy)."""
        return list(self._messages)

    @property
    def usage(self) -> dict[str, Any]:
        """Accumulated token/cost totals across all turns (a shallow copy)."""
        return dict(self._usage)

    def clear(self) -> None:
        """Reset message history (keeps system_prompt, driver, and usage totals)."""
        self._messages.clear()

    def add_context(self, role: str, content: str) -> None:
        """Seed the history with a user or assistant message.

        Raises:
            ValueError: If ``role`` is not ``"user"`` or ``"assistant"``.
        """
        if role not in ("user", "assistant"):
            raise ValueError("role must be 'user' or 'assistant'")
        self._messages.append({"role": role, "content": content})

    def usage_summary(self) -> str:
        """Human-readable summary of accumulated usage."""
        u = self._usage
        return f"Conversation: {u['total_tokens']:,} tokens across {u['turns']} turn(s) costing ${u['cost']:.4f}"

    # ------------------------------------------------------------------
    # Core methods
    # ------------------------------------------------------------------

    def _build_messages(self, user_content: str) -> list[dict[str, str]]:
        """Build the full messages array for an API call.

        Order: optional system prompt, full stored history, then the new
        user message. The history itself is not modified here.
        """
        msgs: list[dict[str, str]] = []
        if self._system_prompt:
            msgs.append({"role": "system", "content": self._system_prompt})
        msgs.extend(self._messages)
        msgs.append({"role": "user", "content": user_content})
        return msgs

    def _accumulate_usage(self, meta: dict[str, Any]) -> None:
        """Fold one response's ``meta`` counters into the running totals.

        Missing keys count as zero, and every call counts as one turn.
        """
        self._usage["prompt_tokens"] += meta.get("prompt_tokens", 0)
        self._usage["completion_tokens"] += meta.get("completion_tokens", 0)
        self._usage["total_tokens"] += meta.get("total_tokens", 0)
        self._usage["cost"] += meta.get("cost", 0.0)
        self._usage["turns"] += 1

    async def ask(
        self,
        content: str,
        options: dict[str, Any] | None = None,
    ) -> str:
        """Send a message and get a raw text response (async).

        Appends both the user message and the assistant reply to the
        history and updates usage totals.

        Args:
            content: The user message to send.
            options: Per-call options; override the instance defaults.

        Returns:
            The assistant's response text (empty string if absent).
        """
        # Per-call options win over instance-level defaults.
        merged = {**self._options, **(options or {})}
        messages = self._build_messages(content)
        resp = await self._driver.generate_messages_with_hooks(messages, merged)

        text = resp.get("text", "")
        meta = resp.get("meta", {})

        self._messages.append({"role": "user", "content": content})
        self._messages.append({"role": "assistant", "content": text})
        self._accumulate_usage(meta)

        return text

    async def ask_for_json(
        self,
        content: str,
        json_schema: dict[str, Any],
        *,
        ai_cleanup: bool = True,
        options: dict[str, Any] | None = None,
        output_format: Literal["json", "toon"] = "json",
        json_mode: Literal["auto", "on", "off"] = "auto",
    ) -> dict[str, Any]:
        """Send a message with schema enforcement and get structured JSON back (async).

        The schema instruction is appended to the message actually sent, but
        only the bare ``content`` is recorded in the user history entry; the
        assistant entry stores the cleaned JSON string.

        Args:
            content: The user message to send.
            json_schema: JSON schema the response must validate against.
            ai_cleanup: On a parse failure, ask the model to repair the JSON.
            options: Per-call options; override the instance defaults.
            output_format: ``"toon"`` additionally encodes the result via the
                optional ``toon`` package.
            json_mode: ``"on"`` forces driver-native JSON mode, ``"off"``
                disables it, ``"auto"`` uses it when the driver advertises
                ``supports_json_mode``.

        Returns:
            Dict with ``json_string``, ``json_object``, ``usage``,
            ``output_format`` (and ``toon_string`` when requested).

        Raises:
            json.JSONDecodeError: If parsing fails and ``ai_cleanup`` is off
                (or the AI-repaired text still fails to parse).
            RuntimeError: If ``output_format="toon"`` but the package is
                not installed.
        """
        merged = {**self._options, **(options or {})}

        schema_string = json.dumps(json_schema, indent=2)

        # Decide whether to use the driver's native JSON mode.
        use_json_mode = False
        if json_mode == "on":
            use_json_mode = True
        elif json_mode == "auto":
            use_json_mode = getattr(self._driver, "supports_json_mode", False)

        if use_json_mode:
            merged = {**merged, "json_mode": True}
            if getattr(self._driver, "supports_json_schema", False):
                merged["json_schema"] = json_schema

        # Prompt instructions get lighter as driver-side enforcement gets
        # stronger (schema-aware > JSON mode only > plain prompting).
        if use_json_mode and getattr(self._driver, "supports_json_schema", False):
            instruct = "Extract data matching the requested schema.\nIf a value is unknown use null."
        elif use_json_mode:
            instruct = (
                "Return a JSON object that validates against this schema:\n"
                f"{schema_string}\n\n"
                "If a value is unknown use null."
            )
        else:
            instruct = (
                "Return only a single JSON object (no markdown, no extra text) that validates against this JSON schema:\n"
                f"{schema_string}\n\n"
                "If a value is unknown use null. Use double quotes for keys and strings."
            )

        full_user_content = f"{content}\n\n{instruct}"

        messages = self._build_messages(full_user_content)
        resp = await self._driver.generate_messages_with_hooks(messages, merged)

        text = resp.get("text", "")
        meta = resp.get("meta", {})

        # NOTE: only the original content (without the schema instruction)
        # is stored in history, keeping future context compact.
        self._messages.append({"role": "user", "content": content})

        cleaned = clean_json_text(text)
        try:
            json_obj = json.loads(cleaned)
        except json.JSONDecodeError:
            if ai_cleanup:
                # Local import avoids a module-level import cycle with async_core.
                from .async_core import clean_json_text_with_ai

                cleaned = await clean_json_text_with_ai(self._driver, cleaned, self._model_name, merged)
                json_obj = json.loads(cleaned)
            else:
                raise

        self._messages.append({"role": "assistant", "content": cleaned})
        self._accumulate_usage(meta)

        # Strip the provider prefix ("provider/model" -> "model").
        model_name = self._model_name
        if "/" in model_name:
            model_name = model_name.split("/", 1)[1]

        usage = {
            **meta,
            "raw_response": resp,
            "model_name": model_name or getattr(self._driver, "model", ""),
        }

        result: dict[str, Any] = {
            "json_string": cleaned,
            "json_object": json_obj,
            "usage": usage,
            "output_format": output_format,
        }

        if output_format == "toon":
            try:
                import toon

                result["toon_string"] = toon.encode(json_obj)
            except ImportError:
                raise RuntimeError("TOON requested but 'python-toon' is not installed.") from None

        return result

    async def extract_with_model(
        self,
        model_cls: type[BaseModel],
        text: str,
        *,
        instruction_template: str = "Extract information from the following text:",
        ai_cleanup: bool = True,
        output_format: Literal["json", "toon"] = "json",
        options: dict[str, Any] | None = None,
        json_mode: Literal["auto", "on", "off"] = "auto",
    ) -> dict[str, Any]:
        """Extract structured information into a Pydantic model with conversation context (async).

        Uses :meth:`ask_for_json` with the model's generated JSON schema,
        normalizes each returned field, then validates by instantiating
        ``model_cls``.

        Returns:
            A dict-like ``ExtractResult`` exposing ``json_string``,
            ``json_object``, ``usage`` and ``model`` as attributes; calling
            the result returns the validated model instance.
        """
        # Local import avoids a module-level import cycle with core.
        from .core import normalize_field_value

        schema = model_cls.model_json_schema()
        content_prompt = f"{instruction_template} {text}"

        result = await self.ask_for_json(
            content=content_prompt,
            json_schema=schema,
            ai_cleanup=ai_cleanup,
            options=options,
            output_format=output_format,
            json_mode=json_mode,
        )

        json_object = result["json_object"]
        schema_properties = schema.get("properties", {})

        # Normalize each extracted field against its annotation before
        # Pydantic validation.
        for field_name, field_info in model_cls.model_fields.items():
            if field_name in json_object and field_name in schema_properties:
                field_def = {
                    # Parses as: (no "type" key) OR ("null" listed in anyOf) —
                    # i.e. the schema allows null for this field.
                    "nullable": not schema_properties[field_name].get("type")
                    or "null"
                    in (
                        schema_properties[field_name].get("anyOf", [])
                        if isinstance(schema_properties[field_name].get("anyOf"), list)
                        else []
                    ),
                    "default": field_info.default
                    if hasattr(field_info, "default") and field_info.default is not ...
                    else None,
                }
                json_object[field_name] = normalize_field_value(
                    json_object[field_name], field_info.annotation, field_def
                )

        # Raises pydantic.ValidationError if the data still doesn't fit.
        model_instance = model_cls(**json_object)

        result_dict = {
            "json_string": result["json_string"],
            "json_object": result["json_object"],
            "usage": result["usage"],
        }
        result_dict["model"] = model_instance

        # Dynamic dict subclass: attribute access falls through to keys,
        # and calling the result returns the validated model.
        return type(
            "ExtractResult",
            (dict,),
            {
                "__getattr__": lambda self, key: self.get(key),
                "__call__": lambda self: self["model"],
            },
        )(result_dict)

    # ------------------------------------------------------------------
    # Internal: stepwise with shared context
    # ------------------------------------------------------------------

    async def _stepwise_extract(
        self,
        model_cls: type[BaseModel],
        text: str,
        instruction_template: str,
        ai_cleanup: bool,
        fields: list[str] | None,
        field_definitions: dict[str, Any] | None,
        json_mode: Literal["auto", "on", "off"],
    ) -> dict[str, Union[str, dict[str, Any]]]:
        """Stepwise extraction using async conversation context between fields.

        Extracts one field per turn inside this conversation; earlier turns
        remain in the history and give context to later ones. Per-field
        failures fall back to defaults where available and are recorded in
        ``field_results`` / ``validation_errors`` rather than aborting.

        Args:
            model_cls: Pydantic model whose fields drive the extraction.
            text: Source text to extract from.
            instruction_template: Per-field prompt; formatted with
                ``field_name``.
            ai_cleanup: Passed through to :meth:`ask_for_json`.
            fields: Optional subset of field names to extract.
            field_definitions: Optional per-field metadata (nullable/default);
                defaults to the global registry snapshot.
            json_mode: Passed through to :meth:`ask_for_json`.

        Returns:
            A dict-like ``ExtractResult``; on validation failure it carries
            an ``error`` key and calling it returns ``None``.

        Raises:
            KeyError: If ``fields`` names a field not on ``model_cls``.
        """
        if field_definitions is None:
            field_definitions = get_registry_snapshot()

        data: dict[str, Any] = {}
        validation_errors: list[str] = []
        field_results: dict[str, Any] = {}

        accumulated_usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "cost": 0.0,
            "model_name": self._model_name,
            "field_usages": {},
        }

        # Validate the requested subset before doing any work.
        valid_fields = set(model_cls.model_fields.keys())
        if fields is not None:
            invalid_fields = set(fields) - valid_fields
            if invalid_fields:
                raise KeyError(f"Fields not found in model: {', '.join(invalid_fields)}")
            field_items = [(name, model_cls.model_fields[name]) for name in fields]
        else:
            field_items = list(model_cls.model_fields.items())

        # Seed conversation with the source text
        self.add_context("user", f"I need to extract information from this text:\n\n{text}")
        self.add_context(
            "assistant", "I'll help you extract the information from that text. What would you like to extract?"
        )

        for field_name, field_info in field_items:
            logger.debug("[stepwise-conv] Extracting field: %s", field_name)

            # Minimal single-key schema; only int vs. string is distinguished
            # here — other annotations are coerced later by convert_value.
            field_schema = {
                "value": {
                    "type": "integer" if field_info.annotation is int else "string",
                    "description": field_info.description or f"Value for {field_name}",
                }
            }

            try:
                prompt = instruction_template.format(field_name=field_name)
                result = await self.ask_for_json(
                    content=f"{prompt} {text}",
                    json_schema=field_schema,
                    ai_cleanup=ai_cleanup,
                    json_mode=json_mode,
                )

                # Fold this field's usage into the aggregate and keep the
                # per-field breakdown.
                field_usage = result.get("usage", {})
                accumulated_usage["prompt_tokens"] += field_usage.get("prompt_tokens", 0)
                accumulated_usage["completion_tokens"] += field_usage.get("completion_tokens", 0)
                accumulated_usage["total_tokens"] += field_usage.get("total_tokens", 0)
                accumulated_usage["cost"] += field_usage.get("cost", 0.0)
                accumulated_usage["field_usages"][field_name] = field_usage

                # Some models nest the answer one level deeper ({"value": {"value": x}}).
                extracted_value = result["json_object"]["value"]
                if isinstance(extracted_value, dict) and "value" in extracted_value:
                    raw_value = extracted_value["value"]
                else:
                    raw_value = extracted_value

                # Local import avoids a module-level import cycle with core.
                from .core import normalize_field_value

                field_def = {}
                if field_definitions and field_name in field_definitions:
                    field_def = field_definitions[field_name] if isinstance(field_definitions[field_name], dict) else {}

                nullable = field_def.get("nullable", True)
                default_value = field_def.get("default")
                # Fall back to the Pydantic-declared default when the registry
                # doesn't supply one ("PydanticUndefined" marks "no default").
                if (
                    default_value is None
                    and hasattr(field_info, "default")
                    and field_info.default is not ...
                    and str(field_info.default) != "PydanticUndefined"
                ):
                    default_value = field_info.default

                normalize_def = {"nullable": nullable, "default": default_value}
                raw_value = normalize_field_value(raw_value, field_info.annotation, normalize_def)

                try:
                    converted_value = convert_value(raw_value, field_info.annotation, allow_shorthand=True)
                    data[field_name] = converted_value
                    field_results[field_name] = {"status": "success", "used_default": False}
                except ValueError as e:
                    # Conversion failed: substitute the default; only count it
                    # as a validation error when no default exists.
                    error_msg = f"Type conversion failed for {field_name}: {e!s}"
                    has_default = _has_default(field_name, field_info, field_definitions)
                    if not has_default:
                        validation_errors.append(error_msg)
                    default_value = get_field_default(field_name, field_info, field_definitions)
                    data[field_name] = default_value
                    field_results[field_name] = {
                        "status": "conversion_failed",
                        "error": error_msg,
                        "used_default": True,
                    }

            except Exception as e:
                # Any failure in the extraction turn itself (API error, bad
                # JSON, missing "value" key): substitute the default.
                error_msg = f"Extraction failed for {field_name}: {e!s}"
                has_default = _has_default(field_name, field_info, field_definitions)
                if not has_default:
                    validation_errors.append(error_msg)
                default_value = get_field_default(field_name, field_info, field_definitions)
                data[field_name] = default_value
                field_results[field_name] = {"status": "extraction_failed", "error": error_msg, "used_default": True}
                accumulated_usage["field_usages"][field_name] = {
                    "error": str(e),
                    "status": "failed",
                    "used_default": True,
                    "default_value": default_value,
                }

        if validation_errors:
            accumulated_usage["validation_errors"] = validation_errors

        try:
            model_instance = model_cls(**data)
            model_dict = model_instance.model_dump()

            # Encoder for types json.dumps can't serialize natively.
            class ExtendedJSONEncoder(json.JSONEncoder):
                def default(self, obj):
                    if isinstance(obj, (datetime, date)):
                        return obj.isoformat()
                    if isinstance(obj, Decimal):
                        return str(obj)
                    return super().default(obj)

            json_string = json.dumps(model_dict, cls=ExtendedJSONEncoder)

            result = {
                "json_string": json_string,
                "json_object": json.loads(json_string),
                "usage": accumulated_usage,
                "field_results": field_results,
            }
            result["model"] = model_instance
            # Same dynamic ExtractResult shape as extract_with_model().
            return type(
                "ExtractResult",
                (dict,),
                {"__getattr__": lambda self, key: self.get(key), "__call__": lambda self: self["model"]},
            )(result)
        except Exception as e:
            # Final model validation failed: return an empty result carrying
            # the error; calling the result yields None instead of a model.
            error_msg = f"Model validation error: {e!s}"
            if "validation_errors" not in accumulated_usage:
                accumulated_usage["validation_errors"] = []
            accumulated_usage["validation_errors"].append(error_msg)

            error_result = {
                "json_string": "{}",
                "json_object": {},
                "usage": accumulated_usage,
                "field_results": field_results,
                "error": error_msg,
            }
            return type(
                "ExtractResult",
                (dict,),
                {"__getattr__": lambda self, key: self.get(key), "__call__": lambda self: None},
            )(error_result)
|
472
|
+
|
|
473
|
+
|
|
474
|
+
def _has_default(field_name: str, field_info: Any, field_definitions: dict[str, Any] | None) -> bool:
|
|
475
|
+
"""Check whether a Pydantic field has a usable default value."""
|
|
476
|
+
if field_definitions and field_name in field_definitions:
|
|
477
|
+
fd = field_definitions[field_name]
|
|
478
|
+
if isinstance(fd, dict) and "default" in fd:
|
|
479
|
+
return True
|
|
480
|
+
if hasattr(field_info, "default"):
|
|
481
|
+
val = field_info.default
|
|
482
|
+
if val is not ... and str(val) != "PydanticUndefined":
|
|
483
|
+
return True
|
|
484
|
+
return False
|