prompture 0.0.35__py3-none-any.whl → 0.0.40.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. prompture/__init__.py +132 -3
  2. prompture/_version.py +2 -2
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/async_agent.py +880 -0
  6. prompture/async_conversation.py +208 -17
  7. prompture/async_core.py +16 -0
  8. prompture/async_driver.py +63 -0
  9. prompture/async_groups.py +551 -0
  10. prompture/conversation.py +222 -18
  11. prompture/core.py +46 -12
  12. prompture/cost_mixin.py +37 -0
  13. prompture/discovery.py +132 -44
  14. prompture/driver.py +77 -0
  15. prompture/drivers/__init__.py +5 -1
  16. prompture/drivers/async_azure_driver.py +11 -5
  17. prompture/drivers/async_claude_driver.py +184 -9
  18. prompture/drivers/async_google_driver.py +222 -28
  19. prompture/drivers/async_grok_driver.py +11 -5
  20. prompture/drivers/async_groq_driver.py +11 -5
  21. prompture/drivers/async_lmstudio_driver.py +74 -5
  22. prompture/drivers/async_ollama_driver.py +13 -3
  23. prompture/drivers/async_openai_driver.py +162 -5
  24. prompture/drivers/async_openrouter_driver.py +11 -5
  25. prompture/drivers/async_registry.py +5 -1
  26. prompture/drivers/azure_driver.py +10 -4
  27. prompture/drivers/claude_driver.py +17 -1
  28. prompture/drivers/google_driver.py +227 -33
  29. prompture/drivers/grok_driver.py +11 -5
  30. prompture/drivers/groq_driver.py +11 -5
  31. prompture/drivers/lmstudio_driver.py +73 -8
  32. prompture/drivers/ollama_driver.py +16 -5
  33. prompture/drivers/openai_driver.py +26 -11
  34. prompture/drivers/openrouter_driver.py +11 -5
  35. prompture/drivers/vision_helpers.py +153 -0
  36. prompture/group_types.py +147 -0
  37. prompture/groups.py +530 -0
  38. prompture/image.py +180 -0
  39. prompture/ledger.py +252 -0
  40. prompture/model_rates.py +112 -2
  41. prompture/persistence.py +254 -0
  42. prompture/persona.py +482 -0
  43. prompture/serialization.py +218 -0
  44. prompture/settings.py +1 -0
  45. prompture-0.0.40.dev1.dist-info/METADATA +369 -0
  46. prompture-0.0.40.dev1.dist-info/RECORD +78 -0
  47. prompture-0.0.35.dist-info/METADATA +0 -464
  48. prompture-0.0.35.dist-info/RECORD +0 -66
  49. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/WHEEL +0 -0
  50. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/entry_points.txt +0 -0
  51. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/licenses/LICENSE +0 -0
  52. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/top_level.txt +0 -0
prompture/agent.py ADDED
@@ -0,0 +1,924 @@
1
+ """Agent framework for Prompture.
2
+
3
+ Provides a reusable :class:`Agent` that wraps a ReAct-style loop around
4
+ :class:`~prompture.conversation.Conversation`, with optional structured
5
+ output via Pydantic models and tool support via :class:`ToolRegistry`.
6
+
7
+ Example::
8
+
9
+ from prompture import Agent
10
+
11
+ agent = Agent("openai/gpt-4o", system_prompt="You are a helpful assistant.")
12
+ result = agent.run("What is the capital of France?")
13
+ print(result.output)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import inspect
19
+ import json
20
+ import logging
21
+ import time
22
+ import typing
23
+ from collections.abc import Callable, Generator, Iterator
24
+ from typing import Any, Generic
25
+
26
+ from pydantic import BaseModel
27
+
28
+ from .agent_types import (
29
+ AgentCallbacks,
30
+ AgentResult,
31
+ AgentState,
32
+ AgentStep,
33
+ DepsType,
34
+ ModelRetry,
35
+ RunContext,
36
+ StepType,
37
+ StreamEvent,
38
+ StreamEventType,
39
+ )
40
+ from .callbacks import DriverCallbacks
41
+ from .conversation import Conversation
42
+ from .driver import Driver
43
+ from .persona import Persona
44
+ from .session import UsageSession
45
+ from .tools import clean_json_text
46
+ from .tools_schema import ToolDefinition, ToolRegistry
47
+
48
+ logger = logging.getLogger("prompture.agent")
49
+
50
+ _OUTPUT_PARSE_MAX_RETRIES = 3
51
+ _OUTPUT_GUARDRAIL_MAX_RETRIES = 3
52
+
53
+
54
+ # ------------------------------------------------------------------
55
+ # Module-level helpers for RunContext injection
56
+ # ------------------------------------------------------------------
57
+
58
+
59
+ def _tool_wants_context(fn: Callable[..., Any]) -> bool:
60
+ """Check whether *fn*'s first parameter is annotated as :class:`RunContext`.
61
+
62
+ Uses :func:`typing.get_type_hints` to resolve string annotations
63
+ (from ``from __future__ import annotations``). Falls back to raw
64
+ ``__annotations__`` when ``get_type_hints`` cannot resolve local types.
65
+ """
66
+ sig = inspect.signature(fn)
67
+ params = list(sig.parameters.keys())
68
+ if not params:
69
+ return False
70
+
71
+ first_param = params[0]
72
+ if first_param == "self":
73
+ if len(params) < 2:
74
+ return False
75
+ first_param = params[1]
76
+
77
+ # Try get_type_hints first (resolves string annotations)
78
+ annotation = None
79
+ try:
80
+ hints = typing.get_type_hints(fn, include_extras=True)
81
+ annotation = hints.get(first_param)
82
+ except Exception:
83
+ # get_type_hints can fail with local/forward references; fall back to raw annotation
84
+ pass
85
+
86
+ # Fallback: inspect raw annotation (may be a string)
87
+ if annotation is None:
88
+ raw = sig.parameters[first_param].annotation
89
+ if raw is inspect.Parameter.empty:
90
+ return False
91
+ annotation = raw
92
+
93
+ # String annotation: check if it starts with "RunContext"
94
+ if isinstance(annotation, str):
95
+ return annotation == "RunContext" or annotation.startswith("RunContext[")
96
+
97
+ # Direct match
98
+ if annotation is RunContext:
99
+ return True
100
+
101
+ # Generic alias: RunContext[X]
102
+ origin = getattr(annotation, "__origin__", None)
103
+ return origin is RunContext
104
+
105
+
106
+ def _get_first_param_name(fn: Callable[..., Any]) -> str:
107
+ """Return the name of the first non-self parameter of *fn*."""
108
+ sig = inspect.signature(fn)
109
+ for name, _param in sig.parameters.items():
110
+ if name != "self":
111
+ return name
112
+ return ""
113
+
114
+
115
+ # ------------------------------------------------------------------
116
+ # Agent
117
+ # ------------------------------------------------------------------
118
+
119
+
120
class Agent(Generic[DepsType]):
    """Reusable ReAct-style agent template with tools and optional typed output.

    An Agent holds configuration only — model, tools, prompts, budgets.
    Every :meth:`run` creates a brand-new
    :class:`~prompture.conversation.Conversation`, so no state leaks from
    one run into the next.

    Args:
        model: Model string in ``"provider/model"`` format.
        driver: Pre-built driver instance (useful for testing).
        tools: Initial tools as a list of callables or a
            :class:`ToolRegistry`.
        system_prompt: System prompt prepended to every run. May also
            be a :class:`Persona` or a callable ``(RunContext) -> str``
            for dynamic prompts.
        output_type: Optional Pydantic model class. When set, the
            final LLM response is parsed and validated against this type.
        max_iterations: Maximum tool-use rounds per run.
        max_cost: Soft budget in USD. When exceeded, output parse and
            guardrail retries are skipped.
        options: Extra driver options forwarded to every call.
        deps_type: Type hint for dependencies (for docs/IDE only).
        agent_callbacks: Agent-level observability callbacks.
        input_guardrails: Functions called before the prompt is sent.
        output_guardrails: Functions called after output is parsed.
        name: Optional agent name (used as the default tool name by
            :meth:`as_tool`).
        description: Optional human-readable description.
        output_key: Optional output key (consumer-defined; not read
            within this module).
    """

    def __init__(
        self,
        model: str = "",
        *,
        driver: Driver | None = None,
        tools: list[Callable[..., Any]] | ToolRegistry | None = None,
        system_prompt: str | Persona | Callable[..., str] | None = None,
        output_type: type[BaseModel] | None = None,
        max_iterations: int = 10,
        max_cost: float | None = None,
        options: dict[str, Any] | None = None,
        deps_type: type | None = None,
        agent_callbacks: AgentCallbacks | None = None,
        input_guardrails: list[Callable[..., Any]] | None = None,
        output_guardrails: list[Callable[..., Any]] | None = None,
        name: str = "",
        description: str = "",
        output_key: str | None = None,
    ) -> None:
        if not model and driver is None:
            raise ValueError("Either model or driver must be provided")

        # Identity / metadata.
        self.name = name
        self.description = description
        self.output_key = output_key

        # Core configuration (copies taken so callers can reuse their dicts/lists).
        self._model = model
        self._driver = driver
        self._system_prompt = system_prompt
        self._output_type = output_type
        self._max_iterations = max_iterations
        self._max_cost = max_cost
        self._options = dict(options) if options else {}
        self._deps_type = deps_type

        # Hooks.
        self._agent_callbacks = agent_callbacks or AgentCallbacks()
        self._input_guardrails = list(input_guardrails or [])
        self._output_guardrails = list(output_guardrails or [])

        # Tool registry: adopt a ready-made one, or register callables one by one.
        if isinstance(tools, ToolRegistry):
            registry = tools
        else:
            registry = ToolRegistry()
            for fn in tools or []:
                registry.register(fn)
        self._tools = registry

        # Lifecycle flags.
        self._state = AgentState.idle
        self._stop_requested = False
193
+
194
+ # ------------------------------------------------------------------
195
+ # Public API
196
+ # ------------------------------------------------------------------
197
+
198
+ def tool(self, fn: Callable[..., Any]) -> Callable[..., Any]:
199
+ """Decorator to register a function as a tool on this agent.
200
+
201
+ Returns the original function unchanged.
202
+ """
203
+ self._tools.register(fn)
204
+ return fn
205
+
206
    @property
    def state(self) -> AgentState:
        """Current lifecycle state of the agent (idle, running, or errored)."""
        return self._state
210
+
211
    def stop(self) -> None:
        """Request graceful shutdown after the current iteration.

        Only sets a flag; it does not interrupt an in-flight model call.
        NOTE(review): no visible code path in this module reads
        ``_stop_requested`` back — confirm the flag is honored elsewhere.
        """
        self._stop_requested = True
214
+
215
+ def as_tool(
216
+ self,
217
+ name: str | None = None,
218
+ description: str | None = None,
219
+ custom_output_extractor: Callable[[AgentResult], str] | None = None,
220
+ ) -> ToolDefinition:
221
+ """Wrap this Agent as a callable tool for another Agent.
222
+
223
+ Creates a :class:`ToolDefinition` whose function accepts a ``prompt``
224
+ string, runs this agent, and returns the output text.
225
+
226
+ Args:
227
+ name: Tool name (defaults to ``self.name`` or ``"agent_tool"``).
228
+ description: Tool description (defaults to ``self.description``).
229
+ custom_output_extractor: Optional function to extract a string
230
+ from :class:`AgentResult`. Defaults to ``result.output_text``.
231
+ """
232
+ tool_name = name or self.name or "agent_tool"
233
+ tool_desc = description or self.description or f"Run agent {tool_name}"
234
+ agent = self
235
+ extractor = custom_output_extractor
236
+
237
+ def _call_agent(prompt: str) -> str:
238
+ """Run the wrapped agent with the given prompt."""
239
+ result = agent.run(prompt)
240
+ if extractor is not None:
241
+ return extractor(result)
242
+ return result.output_text
243
+
244
+ return ToolDefinition(
245
+ name=tool_name,
246
+ description=tool_desc,
247
+ parameters={
248
+ "type": "object",
249
+ "properties": {
250
+ "prompt": {"type": "string", "description": "The prompt to send to the agent"},
251
+ },
252
+ "required": ["prompt"],
253
+ },
254
+ function=_call_agent,
255
+ )
256
+
257
+ def run(self, prompt: str, *, deps: Any = None) -> AgentResult:
258
+ """Execute the agent loop to completion.
259
+
260
+ Creates a fresh :class:`Conversation`, sends the prompt,
261
+ handles any tool calls, and optionally parses the final
262
+ response into an ``output_type`` Pydantic model.
263
+
264
+ Args:
265
+ prompt: The user prompt to send.
266
+ deps: Optional dependencies injected into :class:`RunContext`.
267
+ """
268
+ self._state = AgentState.running
269
+ self._stop_requested = False
270
+ steps: list[AgentStep] = []
271
+
272
+ try:
273
+ result = self._execute(prompt, steps, deps)
274
+ self._state = AgentState.idle
275
+ return result
276
+ except Exception:
277
+ self._state = AgentState.errored
278
+ raise
279
+
280
+ # ------------------------------------------------------------------
281
+ # RunContext helpers
282
+ # ------------------------------------------------------------------
283
+
284
+ def _build_run_context(
285
+ self,
286
+ prompt: str,
287
+ deps: Any,
288
+ session: UsageSession,
289
+ messages: list[dict[str, Any]],
290
+ iteration: int,
291
+ ) -> RunContext[Any]:
292
+ """Create a :class:`RunContext` snapshot for the current run."""
293
+ return RunContext(
294
+ deps=deps,
295
+ model=self._model,
296
+ usage=session.summary(),
297
+ messages=list(messages),
298
+ iteration=iteration,
299
+ prompt=prompt,
300
+ )
301
+
302
+ # ------------------------------------------------------------------
303
+ # Tool wrapping (RunContext injection + ModelRetry + callbacks)
304
+ # ------------------------------------------------------------------
305
+
306
    def _wrap_tools_with_context(self, ctx: RunContext[Any]) -> ToolRegistry:
        """Return a new :class:`ToolRegistry` with wrapped tool functions.

        For each registered tool:
        - If the tool's first param is ``RunContext``, inject *ctx* automatically.
        - Catch :class:`ModelRetry` and convert to an error string.
        - Fire ``agent_callbacks.on_tool_start`` / ``on_tool_end``.
        - Strip the ``RunContext`` parameter from the JSON schema sent to the LLM.
        """
        if not self._tools:
            return ToolRegistry()

        new_registry = ToolRegistry()

        cb = self._agent_callbacks

        for td in self._tools.definitions:
            wants_ctx = _tool_wants_context(td.function)
            original_fn = td.function
            tool_name = td.name

            # Factory so each wrapper closes over this iteration's tool:
            # binding via parameters avoids the late-binding closure pitfall.
            def _make_wrapper(
                _fn: Callable[..., Any],
                _wants: bool,
                _name: str,
                _cb: AgentCallbacks = cb,
            ) -> Callable[..., Any]:
                def wrapper(**kwargs: Any) -> Any:
                    if _cb.on_tool_start:
                        _cb.on_tool_start(_name, kwargs)
                    try:
                        if _wants:
                            # Inject the RunContext as the first positional arg.
                            result = _fn(ctx, **kwargs)
                        else:
                            result = _fn(**kwargs)
                    except ModelRetry as exc:
                        # Surface the retry request to the LLM as tool output
                        # instead of propagating the exception.
                        result = f"Error: {exc.message}"
                    if _cb.on_tool_end:
                        _cb.on_tool_end(_name, result)
                    return result

                return wrapper

            wrapped = _make_wrapper(original_fn, wants_ctx, tool_name)

            # Build schema: strip RunContext param if present
            params = dict(td.parameters)
            if wants_ctx:
                ctx_param_name = _get_first_param_name(td.function)
                props = dict(params.get("properties", {}))
                props.pop(ctx_param_name, None)
                params = dict(params)
                params["properties"] = props
                req = list(params.get("required", []))
                if ctx_param_name in req:
                    req.remove(ctx_param_name)
                if req:
                    params["required"] = req
                elif "required" in params:
                    # Drop an emptied 'required' list entirely.
                    del params["required"]

            new_td = ToolDefinition(
                name=td.name,
                description=td.description,
                parameters=params,
                function=wrapped,
            )
            new_registry.add(new_td)

        return new_registry
376
+
377
+ # ------------------------------------------------------------------
378
+ # Guardrails
379
+ # ------------------------------------------------------------------
380
+
381
+ def _run_input_guardrails(self, ctx: RunContext[Any], prompt: str) -> str:
382
+ """Execute input guardrails in order. Returns the (possibly transformed) prompt.
383
+
384
+ Each guardrail receives ``(ctx, prompt)`` and may:
385
+ - Return a ``str`` to transform the prompt.
386
+ - Return ``None`` to leave it unchanged.
387
+ - Raise :class:`GuardrailError` to reject entirely.
388
+ """
389
+ for guardrail in self._input_guardrails:
390
+ result = guardrail(ctx, prompt)
391
+ if result is not None:
392
+ prompt = result
393
+ return prompt
394
+
395
    def _run_output_guardrails(
        self,
        ctx: RunContext[Any],
        result: AgentResult,
        conv: Conversation,
        session: UsageSession,
        steps: list[AgentStep],
        all_tool_calls: list[dict[str, Any]],
    ) -> AgentResult:
        """Execute output guardrails. Returns the (possibly modified) result.

        Each guardrail receives ``(ctx, result)`` and may:
        - Return ``None`` to pass (no change).
        - Return an :class:`AgentResult` to modify the result.
        - Raise :class:`ModelRetry` to re-prompt the LLM (up to 3 retries).

        Args:
            ctx: Run context passed to each guardrail.
            result: Result produced by the main execution path.
            conv: Live conversation, used to re-prompt on :class:`ModelRetry`.
            session: Usage session, consulted for the cost budget.
            steps: Mutable step list extended with retry exchanges.
            all_tool_calls: Mutable tool-call list extended alongside *steps*.

        Raises:
            ValueError: When a guardrail still raises after all retries.
        """
        for guardrail in self._output_guardrails:
            for attempt in range(_OUTPUT_GUARDRAIL_MAX_RETRIES):
                try:
                    guard_result = guardrail(ctx, result)
                    if guard_result is not None:
                        result = guard_result
                    break  # guardrail passed
                except ModelRetry as exc:
                    if self._is_over_budget(session):
                        # Budget exhausted: keep the current result as-is.
                        logger.debug("Over budget, skipping output guardrail retry")
                        break
                    if attempt >= _OUTPUT_GUARDRAIL_MAX_RETRIES - 1:
                        raise ValueError(
                            f"Output guardrail failed after {_OUTPUT_GUARDRAIL_MAX_RETRIES} retries: {exc.message}"
                        ) from exc
                    # Re-prompt the LLM
                    retry_text = conv.ask(
                        f"Your response did not pass validation. Error: {exc.message}\n\nPlease try again."
                    )
                    # Record the retry exchange (last user + assistant messages).
                    self._extract_steps(conv.messages[-2:], steps, all_tool_calls)

                    # Re-parse if output_type is set
                    if self._output_type is not None:
                        try:
                            cleaned = clean_json_text(retry_text)
                            parsed = json.loads(cleaned)
                            output = self._output_type.model_validate(parsed)
                        except Exception:
                            # Parse failure here is non-fatal; fall back to raw text
                            # and let the guardrail judge it on the next attempt.
                            output = retry_text
                    else:
                        output = retry_text

                    result = AgentResult(
                        output=output,
                        output_text=retry_text,
                        messages=conv.messages,
                        usage=conv.usage,
                        steps=steps,
                        all_tool_calls=all_tool_calls,
                        state=AgentState.idle,
                        run_usage=session.summary(),
                    )
        return result
454
+
455
+ # ------------------------------------------------------------------
456
+ # Budget check
457
+ # ------------------------------------------------------------------
458
+
459
+ def _is_over_budget(self, session: UsageSession) -> bool:
460
+ """Return True if max_cost is set and the session has exceeded it."""
461
+ if self._max_cost is None:
462
+ return False
463
+ return session.total_cost >= self._max_cost
464
+
465
+ # ------------------------------------------------------------------
466
+ # Internals
467
+ # ------------------------------------------------------------------
468
+
469
+ def _resolve_system_prompt(self, ctx: RunContext[Any] | None = None) -> str | None:
470
+ """Build the system prompt, appending output schema if needed."""
471
+ parts: list[str] = []
472
+
473
+ if self._system_prompt is not None:
474
+ if isinstance(self._system_prompt, Persona):
475
+ # Render Persona with RunContext variables if available
476
+ render_kwargs: dict[str, Any] = {}
477
+ if ctx is not None:
478
+ render_kwargs["model"] = ctx.model
479
+ render_kwargs["iteration"] = ctx.iteration
480
+ parts.append(self._system_prompt.render(**render_kwargs))
481
+ elif callable(self._system_prompt) and not isinstance(self._system_prompt, str):
482
+ if ctx is not None:
483
+ parts.append(self._system_prompt(ctx))
484
+ else:
485
+ # Fallback: call without context (shouldn't happen in normal flow)
486
+ parts.append(self._system_prompt(None)) # type: ignore[arg-type]
487
+ else:
488
+ parts.append(str(self._system_prompt))
489
+
490
+ if self._output_type is not None:
491
+ schema = self._output_type.model_json_schema()
492
+ schema_str = json.dumps(schema, indent=2)
493
+ parts.append(
494
+ "You MUST respond with a single JSON object (no markdown, "
495
+ "no extra text) that validates against this JSON schema:\n"
496
+ f"{schema_str}\n\n"
497
+ "Use double quotes for keys and strings. "
498
+ "If a value is unknown use null."
499
+ )
500
+
501
+ return "\n\n".join(parts) if parts else None
502
+
503
+ def _build_conversation(
504
+ self,
505
+ system_prompt: str | None = None,
506
+ tools: ToolRegistry | None = None,
507
+ driver_callbacks: DriverCallbacks | None = None,
508
+ ) -> Conversation:
509
+ """Create a fresh Conversation for a single run."""
510
+ effective_tools = tools if tools is not None else (self._tools if self._tools else None)
511
+
512
+ kwargs: dict[str, Any] = {
513
+ "system_prompt": system_prompt if system_prompt is not None else self._resolve_system_prompt(),
514
+ "tools": effective_tools,
515
+ "max_tool_rounds": self._max_iterations,
516
+ }
517
+ if self._options:
518
+ kwargs["options"] = self._options
519
+ if driver_callbacks is not None:
520
+ kwargs["callbacks"] = driver_callbacks
521
+
522
+ if self._driver is not None:
523
+ kwargs["driver"] = self._driver
524
+ else:
525
+ kwargs["model_name"] = self._model
526
+
527
+ return Conversation(**kwargs)
528
+
529
    def _execute(self, prompt: str, steps: list[AgentStep], deps: Any) -> AgentResult:
        """Core execution: run conversation, extract steps, parse output.

        Args:
            prompt: Raw user prompt (input guardrails may transform it).
            steps: Mutable list populated in place with the run's steps.
            deps: Caller-supplied dependencies exposed via :class:`RunContext`.

        Returns:
            The fully assembled :class:`AgentResult`.

        Raises:
            ValueError: Propagated from output parsing or guardrail retries.
        """
        # 1. Create per-run UsageSession and wire into DriverCallbacks
        session = UsageSession()
        driver_callbacks = DriverCallbacks(
            on_response=session.record,
            on_error=session.record_error,
        )

        # 2. Build initial RunContext (no messages yet, iteration 0)
        ctx = self._build_run_context(prompt, deps, session, [], 0)

        # 3. Run input guardrails
        effective_prompt = self._run_input_guardrails(ctx, prompt)

        # 4. Resolve system prompt (call it if callable, passing ctx)
        resolved_system_prompt = self._resolve_system_prompt(ctx)

        # 5. Wrap tools with context
        wrapped_tools = self._wrap_tools_with_context(ctx)

        # 6. Build Conversation
        conv = self._build_conversation(
            system_prompt=resolved_system_prompt,
            tools=wrapped_tools if wrapped_tools else None,
            driver_callbacks=driver_callbacks,
        )

        # 7. Fire on_iteration callback
        if self._agent_callbacks.on_iteration:
            self._agent_callbacks.on_iteration(0)

        # 8. Ask the conversation (handles full tool loop internally)
        t0 = time.perf_counter()
        response_text = conv.ask(effective_prompt)
        elapsed_ms = (time.perf_counter() - t0) * 1000

        # 9. Extract steps and tool calls from conversation messages
        all_tool_calls: list[dict[str, Any]] = []
        self._extract_steps(conv.messages, steps, all_tool_calls)

        # Handle output_type parsing
        if self._output_type is not None:
            output, output_text = self._parse_output(conv, response_text, steps, all_tool_calls, elapsed_ms, session)
        else:
            output = response_text
            output_text = response_text

        # Build result with run_usage
        result = AgentResult(
            output=output,
            output_text=output_text,
            messages=conv.messages,
            usage=conv.usage,
            steps=steps,
            all_tool_calls=all_tool_calls,
            state=AgentState.idle,
            run_usage=session.summary(),
        )

        # 10. Run output guardrails (may re-prompt and replace the result)
        if self._output_guardrails:
            result = self._run_output_guardrails(ctx, result, conv, session, steps, all_tool_calls)

        # 11. Fire callbacks
        if self._agent_callbacks.on_step:
            for step in steps:
                self._agent_callbacks.on_step(step)

        if self._agent_callbacks.on_output:
            self._agent_callbacks.on_output(result)

        return result
602
+
603
    def _extract_steps(
        self,
        messages: list[dict[str, Any]],
        steps: list[AgentStep],
        all_tool_calls: list[dict[str, Any]],
    ) -> None:
        """Scan conversation messages and populate steps and tool_calls.

        Args:
            messages: Conversation messages as role-keyed dicts.
            steps: Mutable list extended in place with :class:`AgentStep`s.
            all_tool_calls: Mutable list extended with
                ``{"name", "arguments", "id"}`` dicts, one per tool call.
        """
        # All steps from one scan share a single timestamp; per-message
        # timing cannot be reconstructed from the transcript.
        now = time.time()

        for msg in messages:
            role = msg.get("role", "")

            if role == "assistant":
                tc_list = msg.get("tool_calls", [])
                if tc_list:
                    # Assistant message with tool calls
                    for tc in tc_list:
                        # Accept both nested ("function": {...}) and flat layouts.
                        fn = tc.get("function", {})
                        name = fn.get("name", tc.get("name", ""))
                        raw_args = fn.get("arguments", tc.get("arguments", "{}"))
                        if isinstance(raw_args, str):
                            try:
                                args = json.loads(raw_args)
                            except json.JSONDecodeError:
                                # Malformed argument JSON degrades to an empty dict.
                                args = {}
                        else:
                            args = raw_args

                        steps.append(
                            AgentStep(
                                step_type=StepType.tool_call,
                                timestamp=now,
                                content=msg.get("content", ""),
                                tool_name=name,
                                tool_args=args,
                            )
                        )
                        all_tool_calls.append({"name": name, "arguments": args, "id": tc.get("id", "")})
                else:
                    # Final assistant message (no tool calls)
                    steps.append(
                        AgentStep(
                            step_type=StepType.output,
                            timestamp=now,
                            content=msg.get("content", ""),
                        )
                    )

            elif role == "tool":
                steps.append(
                    AgentStep(
                        step_type=StepType.tool_result,
                        timestamp=now,
                        content=msg.get("content", ""),
                        # NOTE(review): this stores the message's tool_call_id in
                        # the tool_name field — confirm that is intentional.
                        tool_name=msg.get("tool_call_id"),
                    )
                )
660
+
661
    def _parse_output(
        self,
        conv: Conversation,
        response_text: str,
        steps: list[AgentStep],
        all_tool_calls: list[dict[str, Any]],
        elapsed_ms: float,
        session: UsageSession | None = None,
    ) -> tuple[Any, str]:
        """Try to parse ``response_text`` as the output_type, with retries.

        Args:
            conv: Live conversation used to re-prompt on parse failure.
            response_text: Candidate final response from the model.
            steps: Mutable step list extended with any retry exchanges.
            all_tool_calls: Mutable tool-call list extended alongside *steps*.
            elapsed_ms: Duration of the initial ask. NOTE(review): currently
                unused inside this method.
            session: Optional usage session consulted for the cost budget.

        Returns:
            ``(validated_model_instance, raw_text)`` on success.

        Raises:
            ValueError: When parsing fails after all retries (or after the
                budget check cuts retries short).
        """
        assert self._output_type is not None

        last_error: Exception | None = None
        text = response_text

        for attempt in range(_OUTPUT_PARSE_MAX_RETRIES):
            try:
                cleaned = clean_json_text(text)
                parsed = json.loads(cleaned)
                model_instance = self._output_type.model_validate(parsed)
                return model_instance, text
            except Exception as exc:
                last_error = exc
                if attempt < _OUTPUT_PARSE_MAX_RETRIES - 1:
                    # Check budget before retrying
                    if session is not None and self._is_over_budget(session):
                        logger.debug("Over budget, skipping output parse retry")
                        break
                    logger.debug("Output parse attempt %d failed: %s", attempt + 1, exc)
                    retry_msg = (
                        f"Your previous response could not be parsed as valid JSON "
                        f"matching the required schema. Error: {exc}\n\n"
                        f"Please try again and respond ONLY with valid JSON."
                    )
                    text = conv.ask(retry_msg)

                    # Record the retry step
                    self._extract_steps(conv.messages[-2:], steps, all_tool_calls)

        # NOTE(review): when the budget break fires, fewer than
        # _OUTPUT_PARSE_MAX_RETRIES attempts were made, but the message
        # below still reports the full count.
        raise ValueError(
            f"Failed to parse output as {self._output_type.__name__} "
            f"after {_OUTPUT_PARSE_MAX_RETRIES} attempts: {last_error}"
        )
704
+
705
+ # ------------------------------------------------------------------
706
+ # iter() — step-by-step inspection
707
+ # ------------------------------------------------------------------
708
+
709
+ def iter(self, prompt: str, *, deps: Any = None) -> AgentIterator:
710
+ """Execute the agent loop and iterate over steps.
711
+
712
+ Returns an :class:`AgentIterator` that yields :class:`AgentStep`
713
+ objects. After iteration completes, the final :class:`AgentResult`
714
+ is available via :attr:`AgentIterator.result`.
715
+
716
+ Note:
717
+ In Phase 3c the conversation's tool loop runs to completion
718
+ before steps are yielded. True mid-loop yielding is deferred.
719
+ """
720
+ gen = self._execute_iter(prompt, deps)
721
+ return AgentIterator(gen)
722
+
723
+ def _execute_iter(self, prompt: str, deps: Any) -> Generator[AgentStep, None, AgentResult]:
724
+ """Generator that executes the agent loop and yields each step."""
725
+ self._state = AgentState.running
726
+ self._stop_requested = False
727
+ steps: list[AgentStep] = []
728
+
729
+ try:
730
+ result = self._execute(prompt, steps, deps)
731
+ # Yield each step one at a time
732
+ yield from result.steps
733
+ self._state = AgentState.idle
734
+ return result
735
+ except Exception:
736
+ self._state = AgentState.errored
737
+ raise
738
+
739
+ # ------------------------------------------------------------------
740
+ # run_stream() — streaming output
741
+ # ------------------------------------------------------------------
742
+
743
+ def run_stream(self, prompt: str, *, deps: Any = None) -> StreamedAgentResult:
744
+ """Execute the agent loop with streaming output.
745
+
746
+ Returns a :class:`StreamedAgentResult` that yields
747
+ :class:`StreamEvent` objects. After iteration completes, the
748
+ final :class:`AgentResult` is available via
749
+ :attr:`StreamedAgentResult.result`.
750
+
751
+ When tools are registered, streaming falls back to non-streaming
752
+ ``conv.ask()`` and yields the full response as a single
753
+ ``text_delta`` event.
754
+ """
755
+ gen = self._execute_stream(prompt, deps)
756
+ return StreamedAgentResult(gen)
757
+
758
    def _execute_stream(self, prompt: str, deps: Any) -> Generator[StreamEvent, None, AgentResult]:
        """Generator that executes the agent loop and yields stream events.

        Mirrors :meth:`_execute`, but emits ``text_delta`` events as the
        response arrives (one event per chunk, or a single event when tools
        force the non-streaming path), then a final ``output`` event
        carrying the :class:`AgentResult`, which is also the generator's
        return value.
        """
        self._state = AgentState.running
        self._stop_requested = False
        steps: list[AgentStep] = []

        try:
            # 1. Create per-run UsageSession and wire into DriverCallbacks
            session = UsageSession()
            driver_callbacks = DriverCallbacks(
                on_response=session.record,
                on_error=session.record_error,
            )

            # 2. Build initial RunContext
            ctx = self._build_run_context(prompt, deps, session, [], 0)

            # 3. Run input guardrails
            effective_prompt = self._run_input_guardrails(ctx, prompt)

            # 4. Resolve system prompt
            resolved_system_prompt = self._resolve_system_prompt(ctx)

            # 5. Wrap tools with context
            wrapped_tools = self._wrap_tools_with_context(ctx)
            has_tools = bool(wrapped_tools)

            # 6. Build Conversation
            conv = self._build_conversation(
                system_prompt=resolved_system_prompt,
                tools=wrapped_tools if wrapped_tools else None,
                driver_callbacks=driver_callbacks,
            )

            # 7. Fire on_iteration callback
            if self._agent_callbacks.on_iteration:
                self._agent_callbacks.on_iteration(0)

            if has_tools:
                # Tools registered: fall back to non-streaming conv.ask()
                response_text = conv.ask(effective_prompt)

                # Yield the full text as a single delta
                yield StreamEvent(
                    event_type=StreamEventType.text_delta,
                    data=response_text,
                )
            else:
                # No tools: use streaming if available
                response_text = ""
                stream_iter: Iterator[str] = conv.ask_stream(effective_prompt)
                for chunk in stream_iter:
                    response_text += chunk
                    yield StreamEvent(
                        event_type=StreamEventType.text_delta,
                        data=chunk,
                    )

            # 8. Extract steps
            all_tool_calls: list[dict[str, Any]] = []
            self._extract_steps(conv.messages, steps, all_tool_calls)

            # 9. Parse output (elapsed time is not tracked on this path: 0.0)
            if self._output_type is not None:
                output, output_text = self._parse_output(conv, response_text, steps, all_tool_calls, 0.0, session)
            else:
                output = response_text
                output_text = response_text

            # 10. Build result
            result = AgentResult(
                output=output,
                output_text=output_text,
                messages=conv.messages,
                usage=conv.usage,
                steps=steps,
                all_tool_calls=all_tool_calls,
                state=AgentState.idle,
                run_usage=session.summary(),
            )

            # 11. Run output guardrails (may re-prompt and replace the result)
            if self._output_guardrails:
                result = self._run_output_guardrails(ctx, result, conv, session, steps, all_tool_calls)

            # 12. Fire callbacks
            if self._agent_callbacks.on_step:
                for step in steps:
                    self._agent_callbacks.on_step(step)
            if self._agent_callbacks.on_output:
                self._agent_callbacks.on_output(result)

            # 13. Yield final output event
            yield StreamEvent(
                event_type=StreamEventType.output,
                data=result,
            )

            self._state = AgentState.idle
            return result
        except Exception:
            self._state = AgentState.errored
            raise
861
+
862
+
863
+ # ------------------------------------------------------------------
864
+ # AgentIterator
865
+ # ------------------------------------------------------------------
866
+
867
+
868
class AgentIterator:
    """Iterator over :meth:`Agent.iter` steps that records the final result.

    Exhausting the iterator (e.g. by a ``for`` loop) stores the wrapped
    generator's return value, exposed afterwards through :attr:`result`.
    """

    def __init__(self, gen: Generator[AgentStep, None, AgentResult]) -> None:
        self._gen = gen
        self._result: AgentResult | None = None

    def __iter__(self) -> AgentIterator:
        return self

    def __next__(self) -> AgentStep:
        try:
            step = next(self._gen)
        except StopIteration as stop:
            # A generator's return value rides on StopIteration.value.
            self._result = stop.value
            raise
        return step

    @property
    def result(self) -> AgentResult | None:
        """Final :class:`AgentResult`; ``None`` until iteration finishes."""
        return self._result
893
+
894
+
895
+ # ------------------------------------------------------------------
896
+ # StreamedAgentResult
897
+ # ------------------------------------------------------------------
898
+
899
+
900
class StreamedAgentResult:
    """Iterator over :meth:`Agent.run_stream` events that records the result.

    Yields :class:`StreamEvent` objects while iterating; exhausting the
    iterator stores the wrapped generator's return value, exposed
    afterwards through :attr:`result`.
    """

    def __init__(self, gen: Generator[StreamEvent, None, AgentResult]) -> None:
        self._gen = gen
        self._result: AgentResult | None = None

    def __iter__(self) -> StreamedAgentResult:
        return self

    def __next__(self) -> StreamEvent:
        try:
            event = next(self._gen)
        except StopIteration as stop:
            # A generator's return value rides on StopIteration.value.
            self._result = stop.value
            raise
        return event

    @property
    def result(self) -> AgentResult | None:
        """Final :class:`AgentResult`; ``None`` until iteration finishes."""
        return self._result