prompture 0.0.36.dev1__py3-none-any.whl → 0.0.37.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. prompture/__init__.py +120 -2
  2. prompture/_version.py +2 -2
  3. prompture/agent.py +925 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/async_agent.py +879 -0
  6. prompture/async_conversation.py +199 -17
  7. prompture/async_driver.py +24 -0
  8. prompture/async_groups.py +551 -0
  9. prompture/conversation.py +213 -18
  10. prompture/core.py +30 -12
  11. prompture/discovery.py +24 -1
  12. prompture/driver.py +38 -0
  13. prompture/drivers/__init__.py +5 -1
  14. prompture/drivers/async_azure_driver.py +7 -1
  15. prompture/drivers/async_claude_driver.py +7 -1
  16. prompture/drivers/async_google_driver.py +24 -4
  17. prompture/drivers/async_grok_driver.py +7 -1
  18. prompture/drivers/async_groq_driver.py +7 -1
  19. prompture/drivers/async_lmstudio_driver.py +59 -3
  20. prompture/drivers/async_ollama_driver.py +7 -0
  21. prompture/drivers/async_openai_driver.py +7 -1
  22. prompture/drivers/async_openrouter_driver.py +7 -1
  23. prompture/drivers/async_registry.py +5 -1
  24. prompture/drivers/azure_driver.py +7 -1
  25. prompture/drivers/claude_driver.py +7 -1
  26. prompture/drivers/google_driver.py +24 -4
  27. prompture/drivers/grok_driver.py +7 -1
  28. prompture/drivers/groq_driver.py +7 -1
  29. prompture/drivers/lmstudio_driver.py +58 -6
  30. prompture/drivers/ollama_driver.py +7 -0
  31. prompture/drivers/openai_driver.py +7 -1
  32. prompture/drivers/openrouter_driver.py +7 -1
  33. prompture/drivers/vision_helpers.py +153 -0
  34. prompture/group_types.py +147 -0
  35. prompture/groups.py +530 -0
  36. prompture/image.py +180 -0
  37. prompture/persistence.py +254 -0
  38. prompture/persona.py +482 -0
  39. prompture/serialization.py +218 -0
  40. prompture/settings.py +1 -0
  41. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/METADATA +1 -1
  42. prompture-0.0.37.dev1.dist-info/RECORD +77 -0
  43. prompture-0.0.36.dev1.dist-info/RECORD +0 -66
  44. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/WHEEL +0 -0
  45. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/entry_points.txt +0 -0
  46. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/licenses/LICENSE +0 -0
  47. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/top_level.txt +0 -0
prompture/agent.py ADDED
@@ -0,0 +1,925 @@
1
+ """Agent framework for Prompture.
2
+
3
+ Provides a reusable :class:`Agent` that wraps a ReAct-style loop around
4
+ :class:`~prompture.conversation.Conversation`, with optional structured
5
+ output via Pydantic models and tool support via :class:`ToolRegistry`.
6
+
7
+ Example::
8
+
9
+ from prompture import Agent
10
+
11
+ agent = Agent("openai/gpt-4o", system_prompt="You are a helpful assistant.")
12
+ result = agent.run("What is the capital of France?")
13
+ print(result.output)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import inspect
19
+ import json
20
+ import logging
21
+ import time
22
+ import typing
23
+ from collections.abc import Callable, Generator, Iterator
24
+ from typing import Any, Generic
25
+
26
+ from pydantic import BaseModel
27
+
28
+ from .agent_types import (
29
+ AgentCallbacks,
30
+ AgentResult,
31
+ AgentState,
32
+ AgentStep,
33
+ DepsType,
34
+ ModelRetry,
35
+ RunContext,
36
+ StepType,
37
+ StreamEvent,
38
+ StreamEventType,
39
+ )
40
+ from .callbacks import DriverCallbacks
41
+ from .conversation import Conversation
42
+ from .driver import Driver
43
+ from .persona import Persona
44
+ from .session import UsageSession
45
+ from .tools import clean_json_text
46
+ from .tools_schema import ToolDefinition, ToolRegistry
47
+
48
logger = logging.getLogger("prompture.agent")

# Retry budgets: how many times the agent will re-prompt the LLM when the
# final response fails output_type JSON parsing, and how many times an
# output guardrail may raise ModelRetry before the run errors out.
_OUTPUT_PARSE_MAX_RETRIES = 3
_OUTPUT_GUARDRAIL_MAX_RETRIES = 3
52
+
53
+
54
+ # ------------------------------------------------------------------
55
+ # Module-level helpers for RunContext injection
56
+ # ------------------------------------------------------------------
57
+
58
+
59
+ def _tool_wants_context(fn: Callable[..., Any]) -> bool:
60
+ """Check whether *fn*'s first parameter is annotated as :class:`RunContext`.
61
+
62
+ Uses :func:`typing.get_type_hints` to resolve string annotations
63
+ (from ``from __future__ import annotations``). Falls back to raw
64
+ ``__annotations__`` when ``get_type_hints`` cannot resolve local types.
65
+ """
66
+ sig = inspect.signature(fn)
67
+ params = list(sig.parameters.keys())
68
+ if not params:
69
+ return False
70
+
71
+ first_param = params[0]
72
+ if first_param == "self":
73
+ if len(params) < 2:
74
+ return False
75
+ first_param = params[1]
76
+
77
+ # Try get_type_hints first (resolves string annotations)
78
+ annotation = None
79
+ try:
80
+ hints = typing.get_type_hints(fn, include_extras=True)
81
+ annotation = hints.get(first_param)
82
+ except Exception:
83
+ pass
84
+
85
+ # Fallback: inspect raw annotation (may be a string)
86
+ if annotation is None:
87
+ raw = sig.parameters[first_param].annotation
88
+ if raw is inspect.Parameter.empty:
89
+ return False
90
+ annotation = raw
91
+
92
+ # String annotation: check if it starts with "RunContext"
93
+ if isinstance(annotation, str):
94
+ return annotation == "RunContext" or annotation.startswith("RunContext[")
95
+
96
+ # Direct match
97
+ if annotation is RunContext:
98
+ return True
99
+
100
+ # Generic alias: RunContext[X]
101
+ origin = getattr(annotation, "__origin__", None)
102
+ return origin is RunContext
103
+
104
+
105
+ def _get_first_param_name(fn: Callable[..., Any]) -> str:
106
+ """Return the name of the first non-self parameter of *fn*."""
107
+ sig = inspect.signature(fn)
108
+ for name, _param in sig.parameters.items():
109
+ if name != "self":
110
+ return name
111
+ return ""
112
+
113
+
114
+ # ------------------------------------------------------------------
115
+ # Agent
116
+ # ------------------------------------------------------------------
117
+
118
+
119
class Agent(Generic[DepsType]):
    """A reusable agent that executes a ReAct loop with tool support.

    Each call to :meth:`run` creates a fresh :class:`Conversation`,
    preventing state leakage between runs. The Agent itself is a
    template holding model config, tools, and system prompt.

    Args:
        model: Model string in ``"provider/model"`` format.
        driver: Pre-built driver instance (useful for testing).
        tools: Initial tools as a list of callables or a
            :class:`ToolRegistry`.
        system_prompt: System prompt prepended to every run. May also
            be a callable ``(RunContext) -> str`` for dynamic prompts.
        output_type: Optional Pydantic model class. When set, the
            final LLM response is parsed and validated against this type.
        max_iterations: Maximum tool-use rounds per run.
        max_cost: Soft budget in USD. When exceeded, output parse and
            guardrail retries are skipped.
        options: Extra driver options forwarded to every call.
        deps_type: Type hint for dependencies (for docs/IDE only).
        agent_callbacks: Agent-level observability callbacks.
        input_guardrails: Functions called before the prompt is sent.
        output_guardrails: Functions called after output is parsed.
    """

    def __init__(
        self,
        model: str = "",
        *,
        driver: Driver | None = None,
        tools: list[Callable[..., Any]] | ToolRegistry | None = None,
        system_prompt: str | Persona | Callable[..., str] | None = None,
        output_type: type[BaseModel] | None = None,
        max_iterations: int = 10,
        max_cost: float | None = None,
        options: dict[str, Any] | None = None,
        deps_type: type | None = None,
        agent_callbacks: AgentCallbacks | None = None,
        input_guardrails: list[Callable[..., Any]] | None = None,
        output_guardrails: list[Callable[..., Any]] | None = None,
        name: str = "",
        description: str = "",
        output_key: str | None = None,
    ) -> None:
        # At least one way to reach an LLM is required.
        if not model and driver is None:
            raise ValueError("Either model or driver must be provided")

        self._model = model
        self._driver = driver
        self._system_prompt = system_prompt
        self._output_type = output_type
        self._max_iterations = max_iterations
        self._max_cost = max_cost
        # Defensive copies so later caller-side mutation of the passed
        # dict/lists does not change this agent's configuration.
        self._options = dict(options) if options else {}
        self._deps_type = deps_type
        self._agent_callbacks = agent_callbacks or AgentCallbacks()
        self._input_guardrails = list(input_guardrails) if input_guardrails else []
        self._output_guardrails = list(output_guardrails) if output_guardrails else []
        self.name = name
        self.description = description
        self.output_key = output_key

        # Build internal tool registry
        # NOTE(review): a ToolRegistry passed in is kept by reference (no
        # copy) — tools later added via the .tool() decorator also appear
        # in the caller's registry. A plain list of callables is copied
        # into a fresh registry instead.
        self._tools = ToolRegistry()
        if isinstance(tools, ToolRegistry):
            self._tools = tools
        elif tools is not None:
            for fn in tools:
                self._tools.register(fn)

        self._state = AgentState.idle
        self._stop_requested = False

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def tool(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        """Decorator to register a function as a tool on this agent.

        Returns the original function unchanged.
        """
        self._tools.register(fn)
        return fn

    @property
    def state(self) -> AgentState:
        """Current lifecycle state of the agent."""
        return self._state

    def stop(self) -> None:
        """Request graceful shutdown after the current iteration."""
        # NOTE(review): _stop_requested is set here and reset by
        # run()/iter()/run_stream(), but no method in this class currently
        # reads it — the request appears to have no effect yet. Confirm
        # intended behavior.
        self._stop_requested = True

    def as_tool(
        self,
        name: str | None = None,
        description: str | None = None,
        custom_output_extractor: Callable[[AgentResult], str] | None = None,
    ) -> ToolDefinition:
        """Wrap this Agent as a callable tool for another Agent.

        Creates a :class:`ToolDefinition` whose function accepts a ``prompt``
        string, runs this agent, and returns the output text.

        Args:
            name: Tool name (defaults to ``self.name`` or ``"agent_tool"``).
            description: Tool description (defaults to ``self.description``).
            custom_output_extractor: Optional function to extract a string
                from :class:`AgentResult`. Defaults to ``result.output_text``.
        """
        tool_name = name or self.name or "agent_tool"
        tool_desc = description or self.description or f"Run agent {tool_name}"
        agent = self
        extractor = custom_output_extractor

        def _call_agent(prompt: str) -> str:
            """Run the wrapped agent with the given prompt."""
            result = agent.run(prompt)
            if extractor is not None:
                return extractor(result)
            return result.output_text

        return ToolDefinition(
            name=tool_name,
            description=tool_desc,
            parameters={
                "type": "object",
                "properties": {
                    "prompt": {"type": "string", "description": "The prompt to send to the agent"},
                },
                "required": ["prompt"],
            },
            function=_call_agent,
        )

    def run(self, prompt: str, *, deps: Any = None) -> AgentResult:
        """Execute the agent loop to completion.

        Creates a fresh :class:`Conversation`, sends the prompt,
        handles any tool calls, and optionally parses the final
        response into an ``output_type`` Pydantic model.

        Args:
            prompt: The user prompt to send.
            deps: Optional dependencies injected into :class:`RunContext`.
        """
        self._state = AgentState.running
        self._stop_requested = False
        steps: list[AgentStep] = []

        try:
            result = self._execute(prompt, steps, deps)
            self._state = AgentState.idle
            return result
        except Exception:
            # Leave the agent in the errored state; the caller sees the
            # original exception.
            self._state = AgentState.errored
            raise

    # ------------------------------------------------------------------
    # RunContext helpers
    # ------------------------------------------------------------------

    def _build_run_context(
        self,
        prompt: str,
        deps: Any,
        session: UsageSession,
        messages: list[dict[str, Any]],
        iteration: int,
    ) -> RunContext[Any]:
        """Create a :class:`RunContext` snapshot for the current run."""
        return RunContext(
            deps=deps,
            model=self._model,
            usage=session.summary(),
            # Copy so later conversation mutation does not leak into the snapshot.
            messages=list(messages),
            iteration=iteration,
            prompt=prompt,
        )

    # ------------------------------------------------------------------
    # Tool wrapping (RunContext injection + ModelRetry + callbacks)
    # ------------------------------------------------------------------

    def _wrap_tools_with_context(self, ctx: RunContext[Any]) -> ToolRegistry:
        """Return a new :class:`ToolRegistry` with wrapped tool functions.

        For each registered tool:
        - If the tool's first param is ``RunContext``, inject *ctx* automatically.
        - Catch :class:`ModelRetry` and convert to an error string.
        - Fire ``agent_callbacks.on_tool_start`` / ``on_tool_end``.
        - Strip the ``RunContext`` parameter from the JSON schema sent to the LLM.
        """
        if not self._tools:
            return ToolRegistry()

        new_registry = ToolRegistry()

        cb = self._agent_callbacks

        for td in self._tools.definitions:
            wants_ctx = _tool_wants_context(td.function)
            original_fn = td.function
            tool_name = td.name

            # Factory binds the per-tool values, avoiding the classic
            # late-binding-closure bug inside this loop. ``ctx`` itself is
            # deliberately closed over (same snapshot for every tool call
            # in this run).
            def _make_wrapper(
                _fn: Callable[..., Any],
                _wants: bool,
                _name: str,
                _cb: AgentCallbacks = cb,
            ) -> Callable[..., Any]:
                def wrapper(**kwargs: Any) -> Any:
                    if _cb.on_tool_start:
                        _cb.on_tool_start(_name, kwargs)
                    try:
                        if _wants:
                            result = _fn(ctx, **kwargs)
                        else:
                            result = _fn(**kwargs)
                    except ModelRetry as exc:
                        # ModelRetry is surfaced to the LLM as an error
                        # string instead of propagating as an exception.
                        result = f"Error: {exc.message}"
                    if _cb.on_tool_end:
                        _cb.on_tool_end(_name, result)
                    return result

                return wrapper

            wrapped = _make_wrapper(original_fn, wants_ctx, tool_name)

            # Build schema: strip RunContext param if present
            params = dict(td.parameters)
            if wants_ctx:
                ctx_param_name = _get_first_param_name(td.function)
                props = dict(params.get("properties", {}))
                props.pop(ctx_param_name, None)
                params = dict(params)
                params["properties"] = props
                req = list(params.get("required", []))
                if ctx_param_name in req:
                    req.remove(ctx_param_name)
                if req:
                    params["required"] = req
                elif "required" in params:
                    del params["required"]

            new_td = ToolDefinition(
                name=td.name,
                description=td.description,
                parameters=params,
                function=wrapped,
            )
            new_registry.add(new_td)

        return new_registry

    # ------------------------------------------------------------------
    # Guardrails
    # ------------------------------------------------------------------

    def _run_input_guardrails(self, ctx: RunContext[Any], prompt: str) -> str:
        """Execute input guardrails in order. Returns the (possibly transformed) prompt.

        Each guardrail receives ``(ctx, prompt)`` and may:
        - Return a ``str`` to transform the prompt.
        - Return ``None`` to leave it unchanged.
        - Raise :class:`GuardrailError` to reject entirely.
        """
        for guardrail in self._input_guardrails:
            result = guardrail(ctx, prompt)
            if result is not None:
                prompt = result
        return prompt

    def _run_output_guardrails(
        self,
        ctx: RunContext[Any],
        result: AgentResult,
        conv: Conversation,
        session: UsageSession,
        steps: list[AgentStep],
        all_tool_calls: list[dict[str, Any]],
    ) -> AgentResult:
        """Execute output guardrails. Returns the (possibly modified) result.

        Each guardrail receives ``(ctx, result)`` and may:
        - Return ``None`` to pass (no change).
        - Return an :class:`AgentResult` to modify the result.
        - Raise :class:`ModelRetry` to re-prompt the LLM (up to 3 retries).
        """
        for guardrail in self._output_guardrails:
            for attempt in range(_OUTPUT_GUARDRAIL_MAX_RETRIES):
                try:
                    guard_result = guardrail(ctx, result)
                    if guard_result is not None:
                        result = guard_result
                    break  # guardrail passed
                except ModelRetry as exc:
                    if self._is_over_budget(session):
                        # Budget exhausted: accept the current result
                        # rather than spend more on retries.
                        logger.debug("Over budget, skipping output guardrail retry")
                        break
                    if attempt >= _OUTPUT_GUARDRAIL_MAX_RETRIES - 1:
                        raise ValueError(
                            f"Output guardrail failed after {_OUTPUT_GUARDRAIL_MAX_RETRIES} retries: {exc.message}"
                        ) from exc
                    # Re-prompt the LLM
                    retry_text = conv.ask(
                        f"Your response did not pass validation. Error: {exc.message}\n\nPlease try again."
                    )
                    # Record only the user/assistant pair just appended.
                    self._extract_steps(conv.messages[-2:], steps, all_tool_calls)

                    # Re-parse if output_type is set
                    if self._output_type is not None:
                        try:
                            cleaned = clean_json_text(retry_text)
                            parsed = json.loads(cleaned)
                            output = self._output_type.model_validate(parsed)
                        except Exception:
                            # Best-effort: fall back to the raw text; the
                            # guardrail gets another look on the next attempt.
                            output = retry_text
                    else:
                        output = retry_text

                    result = AgentResult(
                        output=output,
                        output_text=retry_text,
                        messages=conv.messages,
                        usage=conv.usage,
                        steps=steps,
                        all_tool_calls=all_tool_calls,
                        state=AgentState.idle,
                        run_usage=session.summary(),
                    )
        return result

    # ------------------------------------------------------------------
    # Budget check
    # ------------------------------------------------------------------

    def _is_over_budget(self, session: UsageSession) -> bool:
        """Return True if max_cost is set and the session has exceeded it."""
        if self._max_cost is None:
            return False
        return session.total_cost >= self._max_cost

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _resolve_system_prompt(self, ctx: RunContext[Any] | None = None) -> str | None:
        """Build the system prompt, appending output schema if needed."""
        parts: list[str] = []

        if self._system_prompt is not None:
            if isinstance(self._system_prompt, Persona):
                # Render Persona with RunContext variables if available
                render_kwargs: dict[str, Any] = {}
                if ctx is not None:
                    render_kwargs["model"] = ctx.model
                    render_kwargs["iteration"] = ctx.iteration
                parts.append(self._system_prompt.render(**render_kwargs))
            elif callable(self._system_prompt) and not isinstance(self._system_prompt, str):
                if ctx is not None:
                    parts.append(self._system_prompt(ctx))
                else:
                    # Fallback: call without context (shouldn't happen in normal flow)
                    parts.append(self._system_prompt(None))  # type: ignore[arg-type]
            else:
                parts.append(str(self._system_prompt))

        if self._output_type is not None:
            # Instruct the model to emit schema-conforming JSON so
            # _parse_output can validate it.
            schema = self._output_type.model_json_schema()
            schema_str = json.dumps(schema, indent=2)
            parts.append(
                "You MUST respond with a single JSON object (no markdown, "
                "no extra text) that validates against this JSON schema:\n"
                f"{schema_str}\n\n"
                "Use double quotes for keys and strings. "
                "If a value is unknown use null."
            )

        return "\n\n".join(parts) if parts else None

    def _build_conversation(
        self,
        system_prompt: str | None = None,
        tools: ToolRegistry | None = None,
        driver_callbacks: DriverCallbacks | None = None,
    ) -> Conversation:
        """Create a fresh Conversation for a single run."""
        effective_tools = tools if tools is not None else (self._tools if self._tools else None)

        kwargs: dict[str, Any] = {
            "system_prompt": system_prompt if system_prompt is not None else self._resolve_system_prompt(),
            "tools": effective_tools,
            "max_tool_rounds": self._max_iterations,
        }
        if self._options:
            kwargs["options"] = self._options
        if driver_callbacks is not None:
            kwargs["callbacks"] = driver_callbacks

        # A pre-built driver takes precedence over the model string.
        if self._driver is not None:
            kwargs["driver"] = self._driver
        else:
            kwargs["model_name"] = self._model

        return Conversation(**kwargs)

    def _execute(self, prompt: str, steps: list[AgentStep], deps: Any) -> AgentResult:
        """Core execution: run conversation, extract steps, parse output."""
        # 1. Create per-run UsageSession and wire into DriverCallbacks
        session = UsageSession()
        driver_callbacks = DriverCallbacks(
            on_response=session.record,
            on_error=session.record_error,
        )

        # 2. Build initial RunContext
        ctx = self._build_run_context(prompt, deps, session, [], 0)

        # 3. Run input guardrails
        effective_prompt = self._run_input_guardrails(ctx, prompt)

        # 4. Resolve system prompt (call it if callable, passing ctx)
        resolved_system_prompt = self._resolve_system_prompt(ctx)

        # 5. Wrap tools with context
        wrapped_tools = self._wrap_tools_with_context(ctx)

        # 6. Build Conversation
        conv = self._build_conversation(
            system_prompt=resolved_system_prompt,
            tools=wrapped_tools if wrapped_tools else None,
            driver_callbacks=driver_callbacks,
        )

        # 7. Fire on_iteration callback
        if self._agent_callbacks.on_iteration:
            self._agent_callbacks.on_iteration(0)

        # 8. Ask the conversation (handles full tool loop internally)
        t0 = time.perf_counter()
        response_text = conv.ask(effective_prompt)
        elapsed_ms = (time.perf_counter() - t0) * 1000

        # 9. Extract steps and tool calls from conversation messages
        all_tool_calls: list[dict[str, Any]] = []
        self._extract_steps(conv.messages, steps, all_tool_calls)

        # Handle output_type parsing
        if self._output_type is not None:
            output, output_text = self._parse_output(conv, response_text, steps, all_tool_calls, elapsed_ms, session)
        else:
            output = response_text
            output_text = response_text

        # Build result with run_usage
        result = AgentResult(
            output=output,
            output_text=output_text,
            messages=conv.messages,
            usage=conv.usage,
            steps=steps,
            all_tool_calls=all_tool_calls,
            state=AgentState.idle,
            run_usage=session.summary(),
        )

        # 10. Run output guardrails
        if self._output_guardrails:
            result = self._run_output_guardrails(ctx, result, conv, session, steps, all_tool_calls)

        # 11. Fire callbacks
        if self._agent_callbacks.on_step:
            for step in steps:
                self._agent_callbacks.on_step(step)

        if self._agent_callbacks.on_output:
            self._agent_callbacks.on_output(result)

        return result

    def _extract_steps(
        self,
        messages: list[dict[str, Any]],
        steps: list[AgentStep],
        all_tool_calls: list[dict[str, Any]],
    ) -> None:
        """Scan conversation messages and populate steps and tool_calls.

        Note: all steps produced by one scan share the same ``now``
        timestamp; it marks extraction time, not when the step occurred.
        """
        now = time.time()

        for msg in messages:
            role = msg.get("role", "")

            if role == "assistant":
                tc_list = msg.get("tool_calls", [])
                if tc_list:
                    # Assistant message with tool calls
                    for tc in tc_list:
                        # Accept both OpenAI-style nested {"function": {...}}
                        # and flat name/arguments layouts.
                        fn = tc.get("function", {})
                        name = fn.get("name", tc.get("name", ""))
                        raw_args = fn.get("arguments", tc.get("arguments", "{}"))
                        if isinstance(raw_args, str):
                            try:
                                args = json.loads(raw_args)
                            except json.JSONDecodeError:
                                args = {}
                        else:
                            args = raw_args

                        steps.append(
                            AgentStep(
                                step_type=StepType.tool_call,
                                timestamp=now,
                                content=msg.get("content", ""),
                                tool_name=name,
                                tool_args=args,
                            )
                        )
                        all_tool_calls.append({"name": name, "arguments": args, "id": tc.get("id", "")})
                else:
                    # Final assistant message (no tool calls)
                    steps.append(
                        AgentStep(
                            step_type=StepType.output,
                            timestamp=now,
                            content=msg.get("content", ""),
                        )
                    )

            elif role == "tool":
                # NOTE(review): tool_name is filled with the tool_call_id
                # here (the tool's actual name is not present on "tool"
                # messages) — confirm this is intentional.
                steps.append(
                    AgentStep(
                        step_type=StepType.tool_result,
                        timestamp=now,
                        content=msg.get("content", ""),
                        tool_name=msg.get("tool_call_id"),
                    )
                )

    def _parse_output(
        self,
        conv: Conversation,
        response_text: str,
        steps: list[AgentStep],
        all_tool_calls: list[dict[str, Any]],
        elapsed_ms: float,
        session: UsageSession | None = None,
    ) -> tuple[Any, str]:
        """Try to parse ``response_text`` as the output_type, with retries.

        Raises:
            ValueError: when parsing still fails after the retry budget
                (or after an early budget-based stop).
        """
        assert self._output_type is not None

        last_error: Exception | None = None
        text = response_text

        for attempt in range(_OUTPUT_PARSE_MAX_RETRIES):
            try:
                cleaned = clean_json_text(text)
                parsed = json.loads(cleaned)
                model_instance = self._output_type.model_validate(parsed)
                return model_instance, text
            except Exception as exc:
                last_error = exc
                if attempt < _OUTPUT_PARSE_MAX_RETRIES - 1:
                    # Check budget before retrying
                    if session is not None and self._is_over_budget(session):
                        logger.debug("Over budget, skipping output parse retry")
                        break
                    logger.debug("Output parse attempt %d failed: %s", attempt + 1, exc)
                    retry_msg = (
                        f"Your previous response could not be parsed as valid JSON "
                        f"matching the required schema. Error: {exc}\n\n"
                        f"Please try again and respond ONLY with valid JSON."
                    )
                    text = conv.ask(retry_msg)

                    # Record the retry step
                    self._extract_steps(conv.messages[-2:], steps, all_tool_calls)

        raise ValueError(
            f"Failed to parse output as {self._output_type.__name__} "
            f"after {_OUTPUT_PARSE_MAX_RETRIES} attempts: {last_error}"
        )

    # ------------------------------------------------------------------
    # iter() — step-by-step inspection
    # ------------------------------------------------------------------

    def iter(self, prompt: str, *, deps: Any = None) -> AgentIterator:
        """Execute the agent loop and iterate over steps.

        Returns an :class:`AgentIterator` that yields :class:`AgentStep`
        objects. After iteration completes, the final :class:`AgentResult`
        is available via :attr:`AgentIterator.result`.

        Note:
            In Phase 3c the conversation's tool loop runs to completion
            before steps are yielded. True mid-loop yielding is deferred.
        """
        gen = self._execute_iter(prompt, deps)
        return AgentIterator(gen)

    def _execute_iter(self, prompt: str, deps: Any) -> Generator[AgentStep, None, AgentResult]:
        """Generator that executes the agent loop and yields each step."""
        self._state = AgentState.running
        self._stop_requested = False
        steps: list[AgentStep] = []

        try:
            result = self._execute(prompt, steps, deps)
            # Yield each step one at a time
            yield from result.steps
            self._state = AgentState.idle
            # The generator's return value becomes StopIteration.value,
            # which AgentIterator captures as .result.
            return result
        except Exception:
            self._state = AgentState.errored
            raise

    # ------------------------------------------------------------------
    # run_stream() — streaming output
    # ------------------------------------------------------------------

    def run_stream(self, prompt: str, *, deps: Any = None) -> StreamedAgentResult:
        """Execute the agent loop with streaming output.

        Returns a :class:`StreamedAgentResult` that yields
        :class:`StreamEvent` objects. After iteration completes, the
        final :class:`AgentResult` is available via
        :attr:`StreamedAgentResult.result`.

        When tools are registered, streaming falls back to non-streaming
        ``conv.ask()`` and yields the full response as a single
        ``text_delta`` event.
        """
        gen = self._execute_stream(prompt, deps)
        return StreamedAgentResult(gen)

    def _execute_stream(self, prompt: str, deps: Any) -> Generator[StreamEvent, None, AgentResult]:
        """Generator that executes the agent loop and yields stream events."""
        self._state = AgentState.running
        self._stop_requested = False
        steps: list[AgentStep] = []

        try:
            # 1. Create per-run UsageSession and wire into DriverCallbacks
            session = UsageSession()
            driver_callbacks = DriverCallbacks(
                on_response=session.record,
                on_error=session.record_error,
            )

            # 2. Build initial RunContext
            ctx = self._build_run_context(prompt, deps, session, [], 0)

            # 3. Run input guardrails
            effective_prompt = self._run_input_guardrails(ctx, prompt)

            # 4. Resolve system prompt
            resolved_system_prompt = self._resolve_system_prompt(ctx)

            # 5. Wrap tools with context
            wrapped_tools = self._wrap_tools_with_context(ctx)
            has_tools = bool(wrapped_tools)

            # 6. Build Conversation
            conv = self._build_conversation(
                system_prompt=resolved_system_prompt,
                tools=wrapped_tools if wrapped_tools else None,
                driver_callbacks=driver_callbacks,
            )

            # 7. Fire on_iteration callback
            if self._agent_callbacks.on_iteration:
                self._agent_callbacks.on_iteration(0)

            if has_tools:
                # Tools registered: fall back to non-streaming conv.ask()
                t0 = time.perf_counter()
                response_text = conv.ask(effective_prompt)
                _elapsed_ms = (time.perf_counter() - t0) * 1000

                # Yield the full text as a single delta
                yield StreamEvent(
                    event_type=StreamEventType.text_delta,
                    data=response_text,
                )
            else:
                # No tools: use streaming if available
                response_text = ""
                stream_iter: Iterator[str] = conv.ask_stream(effective_prompt)
                for chunk in stream_iter:
                    response_text += chunk
                    yield StreamEvent(
                        event_type=StreamEventType.text_delta,
                        data=chunk,
                    )

            # 8. Extract steps
            all_tool_calls: list[dict[str, Any]] = []
            self._extract_steps(conv.messages, steps, all_tool_calls)

            # 9. Parse output
            if self._output_type is not None:
                output, output_text = self._parse_output(conv, response_text, steps, all_tool_calls, 0.0, session)
            else:
                output = response_text
                output_text = response_text

            # 10. Build result
            result = AgentResult(
                output=output,
                output_text=output_text,
                messages=conv.messages,
                usage=conv.usage,
                steps=steps,
                all_tool_calls=all_tool_calls,
                state=AgentState.idle,
                run_usage=session.summary(),
            )

            # 11. Run output guardrails
            if self._output_guardrails:
                result = self._run_output_guardrails(ctx, result, conv, session, steps, all_tool_calls)

            # 12. Fire callbacks
            if self._agent_callbacks.on_step:
                for step in steps:
                    self._agent_callbacks.on_step(step)
            if self._agent_callbacks.on_output:
                self._agent_callbacks.on_output(result)

            # 13. Yield final output event
            yield StreamEvent(
                event_type=StreamEventType.output,
                data=result,
            )

            self._state = AgentState.idle
            return result
        except Exception:
            self._state = AgentState.errored
            raise
862
+
863
+
864
+ # ------------------------------------------------------------------
865
+ # AgentIterator
866
+ # ------------------------------------------------------------------
867
+
868
+
869
class AgentIterator:
    """Iterates the :meth:`Agent.iter` generator and records its return value.

    While iterating, each ``next()`` call forwards one :class:`AgentStep`
    from the underlying generator. When the generator finishes, the
    :class:`AgentResult` it returned (carried on ``StopIteration.value``)
    becomes available through :attr:`result`.
    """

    def __init__(self, gen: Generator[AgentStep, None, AgentResult]) -> None:
        self._gen = gen
        self._result: AgentResult | None = None

    def __iter__(self) -> AgentIterator:
        return self

    def __next__(self) -> AgentStep:
        try:
            step = next(self._gen)
        except StopIteration as stop:
            # Capture the generator's return value before re-raising so
            # the for-loop terminates normally.
            self._result = stop.value
            raise
        return step

    @property
    def result(self) -> AgentResult | None:
        """The final :class:`AgentResult`; ``None`` until iteration finishes."""
        return self._result
894
+
895
+
896
+ # ------------------------------------------------------------------
897
+ # StreamedAgentResult
898
+ # ------------------------------------------------------------------
899
+
900
+
901
class StreamedAgentResult:
    """Iterates the :meth:`Agent.run_stream` generator and records its return value.

    Each ``next()`` call forwards one :class:`StreamEvent` from the
    underlying generator. Once the generator is exhausted, the
    :class:`AgentResult` it returned (carried on ``StopIteration.value``)
    is exposed via :attr:`result`.
    """

    def __init__(self, gen: Generator[StreamEvent, None, AgentResult]) -> None:
        self._gen = gen
        self._result: AgentResult | None = None

    def __iter__(self) -> StreamedAgentResult:
        return self

    def __next__(self) -> StreamEvent:
        try:
            event = next(self._gen)
        except StopIteration as stop:
            # Capture the generator's return value before re-raising so
            # the for-loop terminates normally.
            self._result = stop.value
            raise
        return event

    @property
    def result(self) -> AgentResult | None:
        """The final :class:`AgentResult`; ``None`` until iteration finishes."""
        return self._result