prompture 0.0.29.dev8__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. prompture/__init__.py +264 -23
  2. prompture/_version.py +34 -0
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/aio/__init__.py +74 -0
  6. prompture/async_agent.py +880 -0
  7. prompture/async_conversation.py +789 -0
  8. prompture/async_core.py +803 -0
  9. prompture/async_driver.py +193 -0
  10. prompture/async_groups.py +551 -0
  11. prompture/cache.py +469 -0
  12. prompture/callbacks.py +55 -0
  13. prompture/cli.py +63 -4
  14. prompture/conversation.py +826 -0
  15. prompture/core.py +894 -263
  16. prompture/cost_mixin.py +51 -0
  17. prompture/discovery.py +187 -0
  18. prompture/driver.py +206 -5
  19. prompture/drivers/__init__.py +175 -67
  20. prompture/drivers/airllm_driver.py +109 -0
  21. prompture/drivers/async_airllm_driver.py +26 -0
  22. prompture/drivers/async_azure_driver.py +123 -0
  23. prompture/drivers/async_claude_driver.py +113 -0
  24. prompture/drivers/async_google_driver.py +316 -0
  25. prompture/drivers/async_grok_driver.py +97 -0
  26. prompture/drivers/async_groq_driver.py +90 -0
  27. prompture/drivers/async_hugging_driver.py +61 -0
  28. prompture/drivers/async_lmstudio_driver.py +148 -0
  29. prompture/drivers/async_local_http_driver.py +44 -0
  30. prompture/drivers/async_ollama_driver.py +135 -0
  31. prompture/drivers/async_openai_driver.py +102 -0
  32. prompture/drivers/async_openrouter_driver.py +102 -0
  33. prompture/drivers/async_registry.py +133 -0
  34. prompture/drivers/azure_driver.py +42 -9
  35. prompture/drivers/claude_driver.py +257 -34
  36. prompture/drivers/google_driver.py +295 -42
  37. prompture/drivers/grok_driver.py +35 -32
  38. prompture/drivers/groq_driver.py +33 -26
  39. prompture/drivers/hugging_driver.py +6 -6
  40. prompture/drivers/lmstudio_driver.py +97 -19
  41. prompture/drivers/local_http_driver.py +6 -6
  42. prompture/drivers/ollama_driver.py +168 -23
  43. prompture/drivers/openai_driver.py +184 -9
  44. prompture/drivers/openrouter_driver.py +37 -25
  45. prompture/drivers/registry.py +306 -0
  46. prompture/drivers/vision_helpers.py +153 -0
  47. prompture/field_definitions.py +106 -96
  48. prompture/group_types.py +147 -0
  49. prompture/groups.py +530 -0
  50. prompture/image.py +180 -0
  51. prompture/logging.py +80 -0
  52. prompture/model_rates.py +217 -0
  53. prompture/persistence.py +254 -0
  54. prompture/persona.py +482 -0
  55. prompture/runner.py +49 -47
  56. prompture/scaffold/__init__.py +1 -0
  57. prompture/scaffold/generator.py +84 -0
  58. prompture/scaffold/templates/Dockerfile.j2 +12 -0
  59. prompture/scaffold/templates/README.md.j2 +41 -0
  60. prompture/scaffold/templates/config.py.j2 +21 -0
  61. prompture/scaffold/templates/env.example.j2 +8 -0
  62. prompture/scaffold/templates/main.py.j2 +86 -0
  63. prompture/scaffold/templates/models.py.j2 +40 -0
  64. prompture/scaffold/templates/requirements.txt.j2 +5 -0
  65. prompture/serialization.py +218 -0
  66. prompture/server.py +183 -0
  67. prompture/session.py +117 -0
  68. prompture/settings.py +19 -1
  69. prompture/tools.py +219 -267
  70. prompture/tools_schema.py +254 -0
  71. prompture/validator.py +3 -3
  72. prompture-0.0.38.dev2.dist-info/METADATA +369 -0
  73. prompture-0.0.38.dev2.dist-info/RECORD +77 -0
  74. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
  75. prompture-0.0.29.dev8.dist-info/METADATA +0 -368
  76. prompture-0.0.29.dev8.dist-info/RECORD +0 -27
  77. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
  78. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
  79. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,789 @@
1
+ """Async stateful multi-turn conversation support for Prompture."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import logging
7
+ import uuid
8
+ from collections.abc import AsyncIterator
9
+ from datetime import date, datetime, timezone
10
+ from decimal import Decimal
11
+ from pathlib import Path
12
+ from typing import Any, Callable, Literal, Union
13
+
14
+ from pydantic import BaseModel
15
+
16
+ from .async_driver import AsyncDriver
17
+ from .callbacks import DriverCallbacks
18
+ from .drivers.async_registry import get_async_driver_for_model
19
+ from .field_definitions import get_registry_snapshot
20
+ from .image import ImageInput, make_image
21
+ from .persistence import load_from_file, save_to_file
22
+ from .persona import Persona, get_persona
23
+ from .serialization import export_conversation, import_conversation
24
+ from .session import UsageSession
25
+ from .tools import (
26
+ clean_json_text,
27
+ convert_value,
28
+ get_field_default,
29
+ )
30
+ from .tools_schema import ToolRegistry
31
+
32
+ logger = logging.getLogger("prompture.async_conversation")
33
+
34
+
35
class AsyncConversation:
    """Async stateful multi-turn conversation with an LLM.

    Mirrors :class:`Conversation` but all methods are ``async``.

    Example::

        conv = AsyncConversation("openai/gpt-4", system_prompt="You are a data extractor")
        r1 = await conv.ask_for_json("Extract names from: John, age 30", name_schema)
        r2 = await conv.ask_for_json("Now extract ages", age_schema)
    """

    def __init__(
        self,
        model_name: str | None = None,
        *,
        driver: AsyncDriver | None = None,
        system_prompt: str | None = None,
        persona: str | Persona | None = None,
        options: dict[str, Any] | None = None,
        callbacks: DriverCallbacks | None = None,
        tools: ToolRegistry | None = None,
        max_tool_rounds: int = 10,
        conversation_id: str | None = None,
        auto_save: str | Path | None = None,
        tags: list[str] | None = None,
    ) -> None:
        """Create a conversation bound to a model name or an explicit driver.

        Args:
            model_name: Model identifier resolved via the async driver
                registry; optional when ``driver`` is given or the persona
                carries a ``model_hint``.
            driver: Pre-built :class:`AsyncDriver`; used as-is when given.
            system_prompt: System message text; mutually exclusive with
                ``persona``.
            persona: Persona name (looked up in the registry) or a
                :class:`Persona` instance; renders the system prompt and
                contributes default settings.
            options: Per-conversation driver options merged into every call.
            callbacks: Optional callbacks attached to the driver.
            tools: Tool registry used by the tool-execution loop.
            max_tool_rounds: Upper bound on send/execute rounds in
                :meth:`_ask_with_tools`.
            conversation_id: Stable identifier; a UUID4 string is generated
                when omitted.
            auto_save: File path; when set, the conversation is saved after
                every turn (best-effort).
            tags: Free-form labels stored in the conversation metadata.

        Raises:
            ValueError: If both ``system_prompt`` and ``persona`` are given,
                if a persona name is unknown, or if neither ``model_name``
                nor ``driver`` (nor a persona model hint) is available.
        """
        if system_prompt is not None and persona is not None:
            raise ValueError("Cannot provide both 'system_prompt' and 'persona'. Use one or the other.")

        # Resolve persona: a string is a registry lookup, an instance is used directly.
        resolved_persona: Persona | None = None
        if persona is not None:
            if isinstance(persona, str):
                resolved_persona = get_persona(persona)
                if resolved_persona is None:
                    raise ValueError(f"Persona '{persona}' not found in registry.")
            else:
                resolved_persona = persona

        if model_name is None and driver is None:
            # A persona's model_hint may stand in for an explicit model name.
            if resolved_persona is not None and resolved_persona.model_hint:
                model_name = resolved_persona.model_hint
            else:
                raise ValueError("Either model_name or driver must be provided")

        if driver is not None:
            self._driver = driver
        else:
            self._driver = get_async_driver_for_model(model_name)

        if callbacks is not None:
            self._driver.callbacks = callbacks

        self._model_name = model_name or ""

        # Apply persona: render system_prompt and merge settings.
        # Explicit `options` take precedence over persona settings.
        if resolved_persona is not None:
            self._system_prompt = resolved_persona.render()
            self._options = {**resolved_persona.settings, **(dict(options) if options else {})}
        else:
            self._system_prompt = system_prompt
            self._options = dict(options) if options else {}

        self._messages: list[dict[str, Any]] = []
        # Running totals accumulated by _accumulate_usage after each turn.
        self._usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "cost": 0.0,
            "turns": 0,
        }
        self._tools = tools or ToolRegistry()
        self._max_tool_rounds = max_tool_rounds

        # Persistence
        self._conversation_id = conversation_id or str(uuid.uuid4())
        self._auto_save = Path(auto_save) if auto_save else None
        self._metadata: dict[str, Any] = {
            "created_at": datetime.now(timezone.utc).isoformat(),
            "tags": list(tags) if tags else [],
        }

    # ------------------------------------------------------------------
    # Public helpers
    # ------------------------------------------------------------------

    @property
    def messages(self) -> list[dict[str, Any]]:
        """Read-only view of the conversation history (a shallow copy)."""
        return list(self._messages)

    @property
    def usage(self) -> dict[str, Any]:
        """Accumulated token/cost totals across all turns (a copy)."""
        return dict(self._usage)

    def clear(self) -> None:
        """Reset message history (keeps system_prompt and driver)."""
        self._messages.clear()

    def add_context(self, role: str, content: str, images: list[ImageInput] | None = None) -> None:
        """Seed the history with a user or assistant message.

        Raises:
            ValueError: If ``role`` is not ``"user"`` or ``"assistant"``.
        """
        if role not in ("user", "assistant"):
            raise ValueError("role must be 'user' or 'assistant'")
        msg_content = self._build_content_with_images(content, images)
        self._messages.append({"role": role, "content": msg_content})

    def register_tool(
        self,
        fn: Callable[..., Any],
        *,
        name: str | None = None,
        description: str | None = None,
    ) -> None:
        """Register a Python function as a tool the LLM can call."""
        self._tools.register(fn, name=name, description=description)

    def usage_summary(self) -> str:
        """Human-readable summary of accumulated usage."""
        u = self._usage
        return f"Conversation: {u['total_tokens']:,} tokens across {u['turns']} turn(s) costing ${u['cost']:.4f}"

    # ------------------------------------------------------------------
    # Persistence properties
    # ------------------------------------------------------------------

    @property
    def conversation_id(self) -> str:
        """Unique identifier for this conversation."""
        return self._conversation_id

    @property
    def tags(self) -> list[str]:
        """Tags attached to this conversation."""
        return self._metadata.get("tags", [])

    @tags.setter
    def tags(self, value: list[str]) -> None:
        self._metadata["tags"] = list(value)

    # ------------------------------------------------------------------
    # Export / Import
    # ------------------------------------------------------------------

    def export(self, *, usage_session: UsageSession | None = None, strip_images: bool = False) -> dict[str, Any]:
        """Export conversation state to a JSON-serializable dict.

        Tool *functions* are not serialized — only their name/description/
        parameters metadata; callers must re-register functions on import.
        """
        tools_metadata = (
            [
                {"name": td.name, "description": td.description, "parameters": td.parameters}
                for td in self._tools.definitions
            ]
            if self._tools and self._tools.definitions
            else None
        )
        return export_conversation(
            model_name=self._model_name,
            system_prompt=self._system_prompt,
            options=self._options,
            messages=self._messages,
            usage=self._usage,
            max_tool_rounds=self._max_tool_rounds,
            tools_metadata=tools_metadata,
            usage_session=usage_session,
            metadata=self._metadata,
            conversation_id=self._conversation_id,
            strip_images=strip_images,
        )

    @classmethod
    def from_export(
        cls,
        data: dict[str, Any],
        *,
        callbacks: DriverCallbacks | None = None,
        tools: ToolRegistry | None = None,
    ) -> AsyncConversation:
        """Reconstruct an :class:`AsyncConversation` from an export dict.

        The driver is reconstructed from the stored ``model_name`` using
        :func:`get_async_driver_for_model`. Callbacks and tool functions
        must be re-attached by the caller.

        Raises:
            ValueError: If the export carries no ``model_name`` to rebuild
                a driver from.
        """
        imported = import_conversation(data)

        model_name = imported.get("model_name") or ""
        if not model_name:
            raise ValueError("Cannot restore conversation: export has no model_name")
        conv = cls(
            model_name=model_name,
            system_prompt=imported.get("system_prompt"),
            options=imported.get("options", {}),
            callbacks=callbacks,
            tools=tools,
            max_tool_rounds=imported.get("max_tool_rounds", 10),
            conversation_id=imported.get("conversation_id"),
            tags=imported.get("metadata", {}).get("tags", []),
        )
        # Restore history and usage totals directly onto the new instance.
        conv._messages = imported.get("messages", [])
        conv._usage = imported.get(
            "usage",
            {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
                "cost": 0.0,
                "turns": 0,
            },
        )
        meta = imported.get("metadata", {})
        if "created_at" in meta:
            # Preserve the original creation timestamp over the fresh one.
            conv._metadata["created_at"] = meta["created_at"]
        return conv

    def save(self, path: str | Path, **kwargs: Any) -> None:
        """Export and write to a JSON file (kwargs forwarded to :meth:`export`)."""
        save_to_file(self.export(**kwargs), path)

    @classmethod
    def load(
        cls,
        path: str | Path,
        *,
        callbacks: DriverCallbacks | None = None,
        tools: ToolRegistry | None = None,
    ) -> AsyncConversation:
        """Load a conversation from a JSON file."""
        data = load_from_file(path)
        return cls.from_export(data, callbacks=callbacks, tools=tools)

    def _maybe_auto_save(self) -> None:
        """Auto-save after each turn if configured (failures are logged, not raised)."""
        if self._auto_save is None:
            return
        try:
            self.save(self._auto_save)
        except Exception:
            # Best-effort: a failed auto-save must not break the conversation.
            logger.debug("Auto-save failed for conversation %s", self._conversation_id, exc_info=True)

    # ------------------------------------------------------------------
    # Core methods
    # ------------------------------------------------------------------

    @staticmethod
    def _build_content_with_images(text: str, images: list[ImageInput] | None = None) -> str | list[dict[str, Any]]:
        """Return plain string when no images, or a list of content blocks."""
        if not images:
            return text
        blocks: list[dict[str, Any]] = [{"type": "text", "text": text}]
        for img in images:
            ic = make_image(img)
            blocks.append({"type": "image", "source": ic})
        return blocks

    def _build_messages(self, user_content: str, images: list[ImageInput] | None = None) -> list[dict[str, Any]]:
        """Build the full messages array for an API call: system + history + new user turn."""
        msgs: list[dict[str, Any]] = []
        if self._system_prompt:
            msgs.append({"role": "system", "content": self._system_prompt})
        msgs.extend(self._messages)
        content = self._build_content_with_images(user_content, images)
        msgs.append({"role": "user", "content": content})
        return msgs

    def _accumulate_usage(self, meta: dict[str, Any]) -> None:
        """Fold one response's usage meta into the running totals.

        Also counts a turn and triggers the (best-effort) auto-save.
        """
        self._usage["prompt_tokens"] += meta.get("prompt_tokens", 0)
        self._usage["completion_tokens"] += meta.get("completion_tokens", 0)
        self._usage["total_tokens"] += meta.get("total_tokens", 0)
        self._usage["cost"] += meta.get("cost", 0.0)
        self._usage["turns"] += 1
        self._maybe_auto_save()

    async def ask(
        self,
        content: str,
        options: dict[str, Any] | None = None,
        images: list[ImageInput] | None = None,
    ) -> str:
        """Send a message and get a raw text response (async).

        If tools are registered and the driver supports tool use,
        dispatches to the async tool execution loop.
        """
        if self._tools and getattr(self._driver, "supports_tool_use", False):
            return await self._ask_with_tools(content, options, images=images)

        # Per-call options override the conversation-level defaults.
        merged = {**self._options, **(options or {})}
        messages = self._build_messages(content, images=images)
        resp = await self._driver.generate_messages_with_hooks(messages, merged)

        text = resp.get("text", "")
        meta = resp.get("meta", {})

        # Record both sides of the exchange in history.
        user_content = self._build_content_with_images(content, images)
        self._messages.append({"role": "user", "content": user_content})
        self._messages.append({"role": "assistant", "content": text})
        self._accumulate_usage(meta)

        return text

    async def _ask_with_tools(
        self,
        content: str,
        options: dict[str, Any] | None = None,
        images: list[ImageInput] | None = None,
    ) -> str:
        """Async tool-use loop: send -> check tool_calls -> execute -> re-send.

        Raises:
            RuntimeError: If the model keeps requesting tools beyond
                ``max_tool_rounds`` rounds.
        """
        merged = {**self._options, **(options or {})}
        tool_defs = self._tools.to_openai_format()

        user_content = self._build_content_with_images(content, images)
        self._messages.append({"role": "user", "content": user_content})
        msgs = self._build_messages_raw()

        for _round in range(self._max_tool_rounds):
            resp = await self._driver.generate_messages_with_tools(msgs, tool_defs, merged)

            meta = resp.get("meta", {})
            self._accumulate_usage(meta)

            tool_calls = resp.get("tool_calls", [])
            text = resp.get("text", "")

            if not tool_calls:
                # Model answered without requesting tools: the loop is done.
                self._messages.append({"role": "assistant", "content": text})
                return text

            # Record the assistant turn including its tool calls in
            # OpenAI message format (arguments serialized as JSON).
            assistant_msg: dict[str, Any] = {"role": "assistant", "content": text}
            assistant_msg["tool_calls"] = [
                {
                    "id": tc["id"],
                    "type": "function",
                    "function": {"name": tc["name"], "arguments": json.dumps(tc["arguments"])},
                }
                for tc in tool_calls
            ]
            self._messages.append(assistant_msg)
            msgs.append(assistant_msg)

            for tc in tool_calls:
                try:
                    result = self._tools.execute(tc["name"], tc["arguments"])
                    result_str = json.dumps(result) if not isinstance(result, str) else result
                except Exception as exc:
                    # Tool failures are reported back to the model, not raised.
                    result_str = f"Error: {exc}"

                tool_result_msg: dict[str, Any] = {
                    "role": "tool",
                    "tool_call_id": tc["id"],
                    "content": result_str,
                }
                self._messages.append(tool_result_msg)
                msgs.append(tool_result_msg)

        raise RuntimeError(f"Tool execution loop exceeded {self._max_tool_rounds} rounds")

    def _build_messages_raw(self) -> list[dict[str, Any]]:
        """Build messages array from system prompt + full history (including tool messages)."""
        msgs: list[dict[str, Any]] = []
        if self._system_prompt:
            msgs.append({"role": "system", "content": self._system_prompt})
        msgs.extend(self._messages)
        return msgs

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    async def ask_stream(
        self,
        content: str,
        options: dict[str, Any] | None = None,
        images: list[ImageInput] | None = None,
    ) -> AsyncIterator[str]:
        """Send a message and yield text chunks as they arrive (async).

        Falls back to non-streaming :meth:`ask` if the driver doesn't
        support streaming.
        """
        if not getattr(self._driver, "supports_streaming", False):
            # Fallback: yield the whole response as a single chunk.
            yield await self.ask(content, options, images=images)
            return

        merged = {**self._options, **(options or {})}
        messages = self._build_messages(content, images=images)

        # NOTE(review): the user message is appended before streaming starts,
        # so an exception mid-stream leaves a user turn without an assistant
        # reply in history — confirm this is intended.
        user_content = self._build_content_with_images(content, images)
        self._messages.append({"role": "user", "content": user_content})

        full_text = ""
        async for chunk in self._driver.generate_messages_stream(messages, merged):
            if chunk["type"] == "delta":
                full_text += chunk["text"]
                # Uses the driver's private callback hook to notify listeners.
                self._driver._fire_callback(
                    "on_stream_delta",
                    {"text": chunk["text"], "driver": getattr(self._driver, "model", self._driver.__class__.__name__)},
                )
                yield chunk["text"]
            elif chunk["type"] == "done":
                meta = chunk.get("meta", {})
                self._accumulate_usage(meta)

        self._messages.append({"role": "assistant", "content": full_text})

    async def ask_for_json(
        self,
        content: str,
        json_schema: dict[str, Any],
        *,
        ai_cleanup: bool = True,
        options: dict[str, Any] | None = None,
        output_format: Literal["json", "toon"] = "json",
        json_mode: Literal["auto", "on", "off"] = "auto",
        images: list[ImageInput] | None = None,
    ) -> dict[str, Any]:
        """Send a message with schema enforcement and get structured JSON back (async).

        Args:
            content: The user prompt.
            json_schema: JSON schema the response must validate against.
            ai_cleanup: When the raw response fails to parse as JSON, make a
                second model call to repair it instead of raising.
            options: Per-call driver options merged over the defaults.
            output_format: ``"toon"`` additionally encodes the result via the
                optional ``toon`` package.
            json_mode: ``"on"`` forces driver JSON mode, ``"off"`` disables
                it, ``"auto"`` follows ``driver.supports_json_mode``.
            images: Optional images attached to the user turn.

        Returns:
            Dict with ``json_string``, ``json_object``, ``usage``,
            ``output_format`` and, for TOON, ``toon_string``.

        Raises:
            json.JSONDecodeError: If parsing fails and ``ai_cleanup`` is off.
            RuntimeError: If TOON output is requested but not installed.
        """
        merged = {**self._options, **(options or {})}

        schema_string = json.dumps(json_schema, indent=2)

        # Decide whether to use the driver's native JSON mode.
        use_json_mode = False
        if json_mode == "on":
            use_json_mode = True
        elif json_mode == "auto":
            use_json_mode = getattr(self._driver, "supports_json_mode", False)

        if use_json_mode:
            merged = {**merged, "json_mode": True}
            if getattr(self._driver, "supports_json_schema", False):
                merged["json_schema"] = json_schema

        # The prompt instruction shrinks as driver-side enforcement grows:
        # native schema support needs no inline schema at all.
        if use_json_mode and getattr(self._driver, "supports_json_schema", False):
            instruct = "Extract data matching the requested schema.\nIf a value is unknown use null."
        elif use_json_mode:
            instruct = (
                "Return a JSON object that validates against this schema:\n"
                f"{schema_string}\n\n"
                "If a value is unknown use null."
            )
        else:
            instruct = (
                "Return only a single JSON object (no markdown, no extra text) that validates against this JSON schema:\n"
                f"{schema_string}\n\n"
                "If a value is unknown use null. Use double quotes for keys and strings."
            )

        full_user_content = f"{content}\n\n{instruct}"

        messages = self._build_messages(full_user_content, images=images)
        resp = await self._driver.generate_messages_with_hooks(messages, merged)

        text = resp.get("text", "")
        meta = resp.get("meta", {})

        # History stores the original content, not the schema-augmented prompt.
        user_content = self._build_content_with_images(content, images)
        self._messages.append({"role": "user", "content": user_content})

        cleaned = clean_json_text(text)
        try:
            json_obj = json.loads(cleaned)
        except json.JSONDecodeError:
            if ai_cleanup:
                # Second-chance repair via an extra model call.
                from .async_core import clean_json_text_with_ai

                cleaned = await clean_json_text_with_ai(self._driver, cleaned, self._model_name, merged)
                json_obj = json.loads(cleaned)
            else:
                raise

        self._messages.append({"role": "assistant", "content": cleaned})
        self._accumulate_usage(meta)

        # Strip the provider prefix ("openai/gpt-4" -> "gpt-4") for reporting.
        model_name = self._model_name
        if "/" in model_name:
            model_name = model_name.split("/", 1)[1]

        usage = {
            **meta,
            "raw_response": resp,
            "model_name": model_name or getattr(self._driver, "model", ""),
        }

        result: dict[str, Any] = {
            "json_string": cleaned,
            "json_object": json_obj,
            "usage": usage,
            "output_format": output_format,
        }

        if output_format == "toon":
            try:
                import toon

                result["toon_string"] = toon.encode(json_obj)
            except ImportError:
                raise RuntimeError("TOON requested but 'python-toon' is not installed.") from None

        return result

    async def extract_with_model(
        self,
        model_cls: type[BaseModel],
        text: str,
        *,
        instruction_template: str = "Extract information from the following text:",
        ai_cleanup: bool = True,
        output_format: Literal["json", "toon"] = "json",
        options: dict[str, Any] | None = None,
        json_mode: Literal["auto", "on", "off"] = "auto",
        images: list[ImageInput] | None = None,
    ) -> dict[str, Any]:
        """Extract structured information into a Pydantic model with conversation context (async).

        Returns a dict-like ``ExtractResult`` exposing ``json_string``,
        ``json_object``, ``usage`` and ``model`` both as keys and as
        attributes; calling the result returns the model instance.
        """
        from .core import normalize_field_value

        schema = model_cls.model_json_schema()
        content_prompt = f"{instruction_template} {text}"

        result = await self.ask_for_json(
            content=content_prompt,
            json_schema=schema,
            ai_cleanup=ai_cleanup,
            options=options,
            output_format=output_format,
            json_mode=json_mode,
            images=images,
        )

        json_object = result["json_object"]
        schema_properties = schema.get("properties", {})

        # Normalize each returned field before Pydantic validation.
        for field_name, field_info in model_cls.model_fields.items():
            if field_name in json_object and field_name in schema_properties:
                # NOTE(review): `"null" in anyOf` tests membership against the
                # raw anyOf list, whose entries are typically dicts like
                # {"type": "null"} rather than the string "null" — confirm
                # this nullable detection behaves as intended.
                field_def = {
                    "nullable": not schema_properties[field_name].get("type")
                    or "null"
                    in (
                        schema_properties[field_name].get("anyOf", [])
                        if isinstance(schema_properties[field_name].get("anyOf"), list)
                        else []
                    ),
                    "default": field_info.default
                    if hasattr(field_info, "default") and field_info.default is not ...
                    else None,
                }
                json_object[field_name] = normalize_field_value(
                    json_object[field_name], field_info.annotation, field_def
                )

        model_instance = model_cls(**json_object)

        result_dict = {
            "json_string": result["json_string"],
            "json_object": result["json_object"],
            "usage": result["usage"],
        }
        result_dict["model"] = model_instance

        # Dynamic dict subclass: attribute access falls back to keys, and
        # calling the result returns the validated model instance.
        return type(
            "ExtractResult",
            (dict,),
            {
                "__getattr__": lambda self, key: self.get(key),
                "__call__": lambda self: self["model"],
            },
        )(result_dict)

    # ------------------------------------------------------------------
    # Internal: stepwise with shared context
    # ------------------------------------------------------------------

    async def _stepwise_extract(
        self,
        model_cls: type[BaseModel],
        text: str,
        instruction_template: str,
        ai_cleanup: bool,
        fields: list[str] | None,
        field_definitions: dict[str, Any] | None,
        json_mode: Literal["auto", "on", "off"],
    ) -> dict[str, Union[str, dict[str, Any]]]:
        """Stepwise extraction using async conversation context between fields.

        Extracts one field per turn, accumulating usage per field, applying
        per-field defaults on failure, and finally validating the collected
        data against ``model_cls``.

        Raises:
            KeyError: If ``fields`` names a field not declared on the model.
        """
        if field_definitions is None:
            field_definitions = get_registry_snapshot()

        data: dict[str, Any] = {}
        validation_errors: list[str] = []
        field_results: dict[str, Any] = {}

        accumulated_usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "cost": 0.0,
            "model_name": self._model_name,
            "field_usages": {},
        }

        # Validate the field selection before doing any model calls.
        valid_fields = set(model_cls.model_fields.keys())
        if fields is not None:
            invalid_fields = set(fields) - valid_fields
            if invalid_fields:
                raise KeyError(f"Fields not found in model: {', '.join(invalid_fields)}")
            field_items = [(name, model_cls.model_fields[name]) for name in fields]
        else:
            field_items = list(model_cls.model_fields.items())

        # Seed conversation with the source text
        self.add_context("user", f"I need to extract information from this text:\n\n{text}")
        self.add_context(
            "assistant", "I'll help you extract the information from that text. What would you like to extract?"
        )

        for field_name, field_info in field_items:
            logger.debug("[stepwise-conv] Extracting field: %s", field_name)

            # Minimal single-value schema for this field's turn.
            field_schema = {
                "value": {
                    "type": "integer" if field_info.annotation is int else "string",
                    "description": field_info.description or f"Value for {field_name}",
                }
            }

            try:
                prompt = instruction_template.format(field_name=field_name)
                result = await self.ask_for_json(
                    content=f"{prompt} {text}",
                    json_schema=field_schema,
                    ai_cleanup=ai_cleanup,
                    json_mode=json_mode,
                )

                field_usage = result.get("usage", {})
                accumulated_usage["prompt_tokens"] += field_usage.get("prompt_tokens", 0)
                accumulated_usage["completion_tokens"] += field_usage.get("completion_tokens", 0)
                accumulated_usage["total_tokens"] += field_usage.get("total_tokens", 0)
                accumulated_usage["cost"] += field_usage.get("cost", 0.0)
                accumulated_usage["field_usages"][field_name] = field_usage

                # Unwrap a possibly double-nested {"value": {"value": ...}}.
                extracted_value = result["json_object"]["value"]
                if isinstance(extracted_value, dict) and "value" in extracted_value:
                    raw_value = extracted_value["value"]
                else:
                    raw_value = extracted_value

                from .core import normalize_field_value

                field_def = {}
                if field_definitions and field_name in field_definitions:
                    field_def = field_definitions[field_name] if isinstance(field_definitions[field_name], dict) else {}

                nullable = field_def.get("nullable", True)
                default_value = field_def.get("default")
                # Fall back to the Pydantic default when the registry has none.
                if (
                    default_value is None
                    and hasattr(field_info, "default")
                    and field_info.default is not ...
                    and str(field_info.default) != "PydanticUndefined"
                ):
                    default_value = field_info.default

                normalize_def = {"nullable": nullable, "default": default_value}
                raw_value = normalize_field_value(raw_value, field_info.annotation, normalize_def)

                try:
                    converted_value = convert_value(raw_value, field_info.annotation, allow_shorthand=True)
                    data[field_name] = converted_value
                    field_results[field_name] = {"status": "success", "used_default": False}
                except ValueError as e:
                    # Conversion failure: substitute the field's default; it
                    # only counts as a validation error when no default exists.
                    error_msg = f"Type conversion failed for {field_name}: {e!s}"
                    has_default = _has_default(field_name, field_info, field_definitions)
                    if not has_default:
                        validation_errors.append(error_msg)
                    default_value = get_field_default(field_name, field_info, field_definitions)
                    data[field_name] = default_value
                    field_results[field_name] = {
                        "status": "conversion_failed",
                        "error": error_msg,
                        "used_default": True,
                    }

            except Exception as e:
                # Whole-turn failure (API error, unparsable JSON, ...):
                # same default-substitution policy as conversion failures.
                error_msg = f"Extraction failed for {field_name}: {e!s}"
                has_default = _has_default(field_name, field_info, field_definitions)
                if not has_default:
                    validation_errors.append(error_msg)
                default_value = get_field_default(field_name, field_info, field_definitions)
                data[field_name] = default_value
                field_results[field_name] = {"status": "extraction_failed", "error": error_msg, "used_default": True}
                accumulated_usage["field_usages"][field_name] = {
                    "error": str(e),
                    "status": "failed",
                    "used_default": True,
                    "default_value": default_value,
                }

        if validation_errors:
            accumulated_usage["validation_errors"] = validation_errors

        try:
            model_instance = model_cls(**data)
            model_dict = model_instance.model_dump()

            # Serialize datetimes/Decimals that json.dumps cannot handle natively.
            class ExtendedJSONEncoder(json.JSONEncoder):
                def default(self, obj):
                    if isinstance(obj, (datetime, date)):
                        return obj.isoformat()
                    if isinstance(obj, Decimal):
                        return str(obj)
                    return super().default(obj)

            json_string = json.dumps(model_dict, cls=ExtendedJSONEncoder)

            result = {
                "json_string": json_string,
                "json_object": json.loads(json_string),
                "usage": accumulated_usage,
                "field_results": field_results,
            }
            result["model"] = model_instance
            return type(
                "ExtractResult",
                (dict,),
                {"__getattr__": lambda self, key: self.get(key), "__call__": lambda self: self["model"]},
            )(result)
        except Exception as e:
            # Validation failed: return an error-shaped result whose __call__
            # yields None instead of a model instance.
            error_msg = f"Model validation error: {e!s}"
            if "validation_errors" not in accumulated_usage:
                accumulated_usage["validation_errors"] = []
            accumulated_usage["validation_errors"].append(error_msg)

            error_result = {
                "json_string": "{}",
                "json_object": {},
                "usage": accumulated_usage,
                "field_results": field_results,
                "error": error_msg,
            }
            return type(
                "ExtractResult",
                (dict,),
                {"__getattr__": lambda self, key: self.get(key), "__call__": lambda self: None},
            )(error_result)
+
778
+
779
def _has_default(field_name: str, field_info: Any, field_definitions: dict[str, Any] | None) -> bool:
    """Return ``True`` when a usable default exists for *field_name*.

    A default may come from an explicit ``"default"`` entry in
    *field_definitions*, or from the Pydantic ``field_info.default`` when
    that value is neither the Ellipsis sentinel nor ``PydanticUndefined``.
    """
    # Registry-supplied definition wins: any explicit "default" key counts.
    definition = (field_definitions or {}).get(field_name)
    if isinstance(definition, dict) and "default" in definition:
        return True

    # Otherwise consult the Pydantic field itself, filtering out the
    # "no default" sentinels (Ellipsis and PydanticUndefined).
    candidate = getattr(field_info, "default", ...)
    return candidate is not ... and str(candidate) != "PydanticUndefined"
+ return False