prompture 0.0.36.dev1__py3-none-any.whl → 0.0.37.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. prompture/__init__.py +120 -2
  2. prompture/_version.py +2 -2
  3. prompture/agent.py +925 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/async_agent.py +879 -0
  6. prompture/async_conversation.py +199 -17
  7. prompture/async_driver.py +24 -0
  8. prompture/async_groups.py +551 -0
  9. prompture/conversation.py +213 -18
  10. prompture/core.py +30 -12
  11. prompture/discovery.py +24 -1
  12. prompture/driver.py +38 -0
  13. prompture/drivers/__init__.py +5 -1
  14. prompture/drivers/async_azure_driver.py +7 -1
  15. prompture/drivers/async_claude_driver.py +7 -1
  16. prompture/drivers/async_google_driver.py +24 -4
  17. prompture/drivers/async_grok_driver.py +7 -1
  18. prompture/drivers/async_groq_driver.py +7 -1
  19. prompture/drivers/async_lmstudio_driver.py +59 -3
  20. prompture/drivers/async_ollama_driver.py +7 -0
  21. prompture/drivers/async_openai_driver.py +7 -1
  22. prompture/drivers/async_openrouter_driver.py +7 -1
  23. prompture/drivers/async_registry.py +5 -1
  24. prompture/drivers/azure_driver.py +7 -1
  25. prompture/drivers/claude_driver.py +7 -1
  26. prompture/drivers/google_driver.py +24 -4
  27. prompture/drivers/grok_driver.py +7 -1
  28. prompture/drivers/groq_driver.py +7 -1
  29. prompture/drivers/lmstudio_driver.py +58 -6
  30. prompture/drivers/ollama_driver.py +7 -0
  31. prompture/drivers/openai_driver.py +7 -1
  32. prompture/drivers/openrouter_driver.py +7 -1
  33. prompture/drivers/vision_helpers.py +153 -0
  34. prompture/group_types.py +147 -0
  35. prompture/groups.py +530 -0
  36. prompture/image.py +180 -0
  37. prompture/persistence.py +254 -0
  38. prompture/persona.py +482 -0
  39. prompture/serialization.py +218 -0
  40. prompture/settings.py +1 -0
  41. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/METADATA +1 -1
  42. prompture-0.0.37.dev1.dist-info/RECORD +77 -0
  43. prompture-0.0.36.dev1.dist-info/RECORD +0 -66
  44. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/WHEEL +0 -0
  45. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/entry_points.txt +0 -0
  46. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/licenses/LICENSE +0 -0
  47. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,879 @@
1
+ """Async Agent framework for Prompture.
2
+
3
+ Provides :class:`AsyncAgent`, the async counterpart of :class:`~prompture.agent.Agent`.
4
+ All methods are ``async`` and use :class:`~prompture.async_conversation.AsyncConversation`.
5
+
6
+ Example::
7
+
8
+ from prompture import AsyncAgent
9
+
10
+ agent = AsyncAgent("openai/gpt-4o", system_prompt="You are helpful.")
11
+ result = await agent.run("What is 2 + 2?")
12
+ print(result.output)
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import asyncio
18
+ import inspect
19
+ import json
20
+ import logging
21
+ import time
22
+ import typing
23
+ from collections.abc import AsyncGenerator, Callable
24
+ from typing import Any, Generic
25
+
26
+ from pydantic import BaseModel
27
+
28
+ from .agent_types import (
29
+ AgentCallbacks,
30
+ AgentResult,
31
+ AgentState,
32
+ AgentStep,
33
+ DepsType,
34
+ ModelRetry,
35
+ RunContext,
36
+ StepType,
37
+ StreamEvent,
38
+ StreamEventType,
39
+ )
40
+ from .callbacks import DriverCallbacks
41
+ from .persona import Persona
42
+ from .session import UsageSession
43
+ from .tools import clean_json_text
44
+ from .tools_schema import ToolDefinition, ToolRegistry
45
+
46
+ logger = logging.getLogger("prompture.async_agent")
47
+
48
+ _OUTPUT_PARSE_MAX_RETRIES = 3
49
+ _OUTPUT_GUARDRAIL_MAX_RETRIES = 3
50
+
51
+
52
+ # ------------------------------------------------------------------
53
+ # Helpers
54
+ # ------------------------------------------------------------------
55
+
56
+
57
+ def _is_async_callable(fn: Callable[..., Any]) -> bool:
58
+ """Check if *fn* is an async callable (coroutine function or has async ``__call__``)."""
59
+ if asyncio.iscoroutinefunction(fn):
60
+ return True
61
+ # Check if the object has an async __call__ method (callable class)
62
+ dunder_call = type(fn).__call__ if callable(fn) else None
63
+ return dunder_call is not None and asyncio.iscoroutinefunction(dunder_call)
64
+
65
+
66
def _tool_wants_context(fn: Callable[..., Any]) -> bool:
    """Check whether *fn*'s first parameter is annotated as :class:`RunContext`.

    Resolution order: evaluated type hints (including ``Annotated`` extras),
    then the raw annotation object, then a plain string comparison for
    unresolvable forward references.
    """
    sig = inspect.signature(fn)
    params = list(sig.parameters.keys())
    if not params:
        return False

    # Skip ``self`` so methods are treated by their first logical argument.
    first_param = params[0]
    if first_param == "self":
        if len(params) < 2:
            return False
        first_param = params[1]

    annotation = None
    try:
        # include_extras preserves Annotated[...] metadata.
        hints = typing.get_type_hints(fn, include_extras=True)
        annotation = hints.get(first_param)
    except Exception:
        # get_type_hints can fail on unresolvable forward refs; fall back
        # to the raw annotation below.
        pass

    if annotation is None:
        raw = sig.parameters[first_param].annotation
        if raw is inspect.Parameter.empty:
            return False
        annotation = raw

    # String annotation (e.g. under ``from __future__ import annotations``).
    # NOTE(review): a qualified form such as "agent_types.RunContext" would
    # not match here — confirm callers only use the bare name.
    if isinstance(annotation, str):
        return annotation == "RunContext" or annotation.startswith("RunContext[")

    if annotation is RunContext:
        return True

    # Parameterized form, e.g. RunContext[MyDeps].
    origin = getattr(annotation, "__origin__", None)
    return origin is RunContext
100
+
101
+
102
+ def _get_first_param_name(fn: Callable[..., Any]) -> str:
103
+ """Return the name of the first non-self parameter of *fn*."""
104
+ sig = inspect.signature(fn)
105
+ for name, _param in sig.parameters.items():
106
+ if name != "self":
107
+ return name
108
+ return ""
109
+
110
+
111
+ # ------------------------------------------------------------------
112
+ # AsyncAgent
113
+ # ------------------------------------------------------------------
114
+
115
+
116
+ class AsyncAgent(Generic[DepsType]):
117
+ """Async agent that executes a ReAct loop with tool support.
118
+
119
+ Mirrors :class:`~prompture.agent.Agent` but uses
120
+ :class:`~prompture.async_conversation.AsyncConversation` and
121
+ ``async`` methods throughout.
122
+
123
+ Args:
124
+ model: Model string in ``"provider/model"`` format.
125
+ driver: Pre-built async driver instance.
126
+ tools: Initial tools as a list of callables or a :class:`ToolRegistry`.
127
+ system_prompt: System prompt prepended to every run. May also be a
128
+ callable ``(RunContext) -> str`` for dynamic prompts.
129
+ output_type: Optional Pydantic model class for structured output.
130
+ max_iterations: Maximum tool-use rounds per run.
131
+ max_cost: Soft budget in USD.
132
+ options: Extra driver options forwarded to every call.
133
+ deps_type: Type hint for dependencies.
134
+ agent_callbacks: Agent-level observability callbacks.
135
+ input_guardrails: Functions called before the prompt is sent.
136
+ output_guardrails: Functions called after output is parsed.
137
+ """
138
+
139
def __init__(
    self,
    model: str = "",
    *,
    driver: Any | None = None,
    tools: list[Callable[..., Any]] | ToolRegistry | None = None,
    system_prompt: str | Persona | Callable[..., str] | None = None,
    output_type: type[BaseModel] | None = None,
    max_iterations: int = 10,
    max_cost: float | None = None,
    options: dict[str, Any] | None = None,
    deps_type: type | None = None,
    agent_callbacks: AgentCallbacks | None = None,
    input_guardrails: list[Callable[..., Any]] | None = None,
    output_guardrails: list[Callable[..., Any]] | None = None,
    name: str = "",
    description: str = "",
    output_key: str | None = None,
) -> None:
    """Initialize the agent; see the class docstring for argument details.

    Raises:
        ValueError: If neither ``model`` nor ``driver`` is supplied.
    """
    if not model and driver is None:
        raise ValueError("Either model or driver must be provided")

    self._model = model
    self._driver = driver
    self._system_prompt = system_prompt
    self._output_type = output_type
    self._max_iterations = max_iterations
    self._max_cost = max_cost
    # Defensive copies so a caller mutating its originals later does not
    # affect this agent.
    self._options = dict(options) if options else {}
    self._deps_type = deps_type
    self._agent_callbacks = agent_callbacks or AgentCallbacks()
    self._input_guardrails = list(input_guardrails) if input_guardrails else []
    self._output_guardrails = list(output_guardrails) if output_guardrails else []
    self.name = name
    self.description = description
    self.output_key = output_key

    # Build internal tool registry. NOTE: a ToolRegistry passed in is
    # adopted by reference (not copied), so later external registrations
    # on it are visible to this agent.
    self._tools = ToolRegistry()
    if isinstance(tools, ToolRegistry):
        self._tools = tools
    elif tools is not None:
        for fn in tools:
            self._tools.register(fn)

    self._state = AgentState.idle
    self._stop_requested = False
186
+
187
+ # ------------------------------------------------------------------
188
+ # Public API
189
+ # ------------------------------------------------------------------
190
+
191
def tool(self, fn: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator to register a function as a tool on this agent.

    Returns *fn* unchanged so the decorated function remains directly
    callable.
    """
    self._tools.register(fn)
    return fn

@property
def state(self) -> AgentState:
    """Current lifecycle state of the agent."""
    return self._state

def stop(self) -> None:
    """Request graceful shutdown after the current iteration.

    Only sets a flag; the running loop is expected to observe it.
    """
    self._stop_requested = True
204
+
205
def as_tool(
    self,
    name: str | None = None,
    description: str | None = None,
    custom_output_extractor: Callable[[AgentResult], str] | None = None,
) -> ToolDefinition:
    """Wrap this AsyncAgent as a callable tool for another Agent.

    Creates a :class:`ToolDefinition` whose function accepts a ``prompt``
    string, runs this agent (bridging async to sync), and returns the
    output text.

    Args:
        name: Tool name (defaults to ``self.name`` or ``"agent_tool"``).
        description: Tool description (defaults to ``self.description``).
        custom_output_extractor: Optional function to extract a string
            from :class:`AgentResult`.
    """
    tool_name = name or self.name or "agent_tool"
    tool_desc = description or self.description or f"Run agent {tool_name}"
    # Bind agent and extractor under stable names for the closure below.
    agent = self
    extractor = custom_output_extractor

    def _call_agent(prompt: str) -> str:
        """Run the wrapped async agent with the given prompt."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is not None and loop.is_running():
            # Already inside a running event loop: asyncio.run() would
            # raise here, so drive the agent on a dedicated thread with
            # its own fresh loop and block until it finishes.
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                result = pool.submit(asyncio.run, agent.run(prompt)).result()
        else:
            # No loop running in this thread; safe to run one directly.
            result = asyncio.run(agent.run(prompt))

        if extractor is not None:
            return extractor(result)
        return result.output_text

    return ToolDefinition(
        name=tool_name,
        description=tool_desc,
        parameters={
            "type": "object",
            "properties": {
                "prompt": {"type": "string", "description": "The prompt to send to the agent"},
            },
            "required": ["prompt"],
        },
        function=_call_agent,
    )
259
+
260
async def run(self, prompt: str, *, deps: Any = None) -> AgentResult:
    """Execute the agent loop to completion (async).

    Creates a fresh conversation, sends the prompt, handles tool calls,
    and optionally parses the final response into ``output_type``.
    """
    self._state = AgentState.running
    self._stop_requested = False
    steps: list[AgentStep] = []

    try:
        outcome = await self._execute(prompt, steps, deps)
    except Exception:
        # Any failure leaves the agent in the errored state.
        self._state = AgentState.errored
        raise
    self._state = AgentState.idle
    return outcome
277
+
278
async def iter(self, prompt: str, *, deps: Any = None) -> AsyncAgentIterator:
    """Execute the agent loop and iterate over steps asynchronously.

    Returns an :class:`AsyncAgentIterator` yielding :class:`AgentStep` objects.
    After iteration, :attr:`AsyncAgentIterator.result` holds the final result.

    Note: the underlying async generator is only created here — nothing
    executes until the caller starts iterating.
    """
    gen = self._execute_iter(prompt, deps)
    return AsyncAgentIterator(gen)

async def run_stream(self, prompt: str, *, deps: Any = None) -> AsyncStreamedAgentResult:
    """Execute the agent loop with streaming output (async).

    Returns an :class:`AsyncStreamedAgentResult` yielding :class:`StreamEvent`
    objects. As with :meth:`iter`, execution starts on first iteration.
    """
    gen = self._execute_stream(prompt, deps)
    return AsyncStreamedAgentResult(gen)
294
+
295
+ # ------------------------------------------------------------------
296
+ # RunContext helpers
297
+ # ------------------------------------------------------------------
298
+
299
def _build_run_context(
    self,
    prompt: str,
    deps: Any,
    session: UsageSession,
    messages: list[dict[str, Any]],
    iteration: int,
) -> RunContext[Any]:
    """Assemble a fresh :class:`RunContext` snapshot for one iteration."""
    # Copy the message list so later mutation of the live history does
    # not leak into this context snapshot.
    history_snapshot = list(messages)
    usage_snapshot = session.summary()
    return RunContext(
        deps=deps,
        model=self._model,
        usage=usage_snapshot,
        messages=history_snapshot,
        iteration=iteration,
        prompt=prompt,
    )
315
+
316
+ # ------------------------------------------------------------------
317
+ # Tool wrapping (RunContext injection + ModelRetry + callbacks)
318
+ # ------------------------------------------------------------------
319
+
320
def _wrap_tools_with_context(self, ctx: RunContext[Any]) -> ToolRegistry:
    """Return a new :class:`ToolRegistry` with wrapped tool functions.

    Each wrapper injects *ctx* as the first positional argument when the
    tool requests it, converts :class:`ModelRetry` into an error string
    returned to the model, and fires the agent's tool start/end callbacks.

    All wrappers are **sync** so they work with ``ToolRegistry.execute()``.
    Async tool functions are driven with ``asyncio.run()`` — on a helper
    thread when an event loop is already running in this thread.
    """
    if not self._tools:
        return ToolRegistry()

    new_registry = ToolRegistry()
    cb = self._agent_callbacks

    for td in self._tools.definitions:
        wants_ctx = _tool_wants_context(td.function)
        original_fn = td.function
        tool_name = td.name
        is_async = _is_async_callable(original_fn)

        # Factory so each wrapper captures its own tool's values instead
        # of the loop variables (avoids the late-binding closure pitfall).
        def _make_wrapper(
            _fn: Callable[..., Any],
            _wants: bool,
            _name: str,
            _is_async: bool,
            _cb: AgentCallbacks = cb,
        ) -> Callable[..., Any]:
            def wrapper(**kwargs: Any) -> Any:
                if _cb.on_tool_start:
                    _cb.on_tool_start(_name, kwargs)
                try:
                    # Inject the RunContext as the first positional arg
                    # only when the tool declared it.
                    if _wants:
                        call_args = (ctx,)
                    else:
                        call_args = ()

                    if _is_async:
                        coro = _fn(*call_args, **kwargs)
                        # Try to get running loop; if none, use asyncio.run()
                        try:
                            loop = asyncio.get_running_loop()
                        except RuntimeError:
                            loop = None
                        if loop is not None and loop.is_running():
                            # We're inside an async context — create a new thread
                            import concurrent.futures

                            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                                result = pool.submit(asyncio.run, coro).result()
                        else:
                            result = asyncio.run(coro)
                    else:
                        result = _fn(*call_args, **kwargs)
                except ModelRetry as exc:
                    # Surface retryable tool failures to the model as a
                    # normal string result rather than raising out of the
                    # agent loop.
                    result = f"Error: {exc.message}"
                if _cb.on_tool_end:
                    _cb.on_tool_end(_name, result)
                return result

            return wrapper

        wrapped = _make_wrapper(original_fn, wants_ctx, tool_name, is_async)

        # Build schema: strip RunContext param if present, so the model
        # never sees (or tries to supply) the injected context argument.
        params = dict(td.parameters)
        if wants_ctx:
            ctx_param_name = _get_first_param_name(td.function)
            props = dict(params.get("properties", {}))
            props.pop(ctx_param_name, None)
            params = dict(params)
            params["properties"] = props
            req = list(params.get("required", []))
            if ctx_param_name in req:
                req.remove(ctx_param_name)
            if req:
                params["required"] = req
            elif "required" in params:
                del params["required"]

        new_td = ToolDefinition(
            name=td.name,
            description=td.description,
            parameters=params,
            function=wrapped,
        )
        new_registry.add(new_td)

    return new_registry
407
+
408
+ # ------------------------------------------------------------------
409
+ # Guardrails
410
+ # ------------------------------------------------------------------
411
+
412
+ def _run_input_guardrails(self, ctx: RunContext[Any], prompt: str) -> str:
413
+ for guardrail in self._input_guardrails:
414
+ result = guardrail(ctx, prompt)
415
+ if result is not None:
416
+ prompt = result
417
+ return prompt
418
+
419
async def _run_output_guardrails(
    self,
    ctx: RunContext[Any],
    result: AgentResult,
    conv: Any,
    session: UsageSession,
    steps: list[AgentStep],
    all_tool_calls: list[dict[str, Any]],
) -> AgentResult:
    """Apply each output guardrail, re-asking the model on :class:`ModelRetry`.

    A guardrail may return a replacement :class:`AgentResult` (``None``
    keeps the current one) or raise :class:`ModelRetry` to make the model
    correct its answer. Each guardrail gets up to
    ``_OUTPUT_GUARDRAIL_MAX_RETRIES`` attempts.

    Raises:
        ValueError: If a guardrail still raises after all retries.
    """
    for guardrail in self._output_guardrails:
        for attempt in range(_OUTPUT_GUARDRAIL_MAX_RETRIES):
            try:
                guard_result = guardrail(ctx, result)
                if guard_result is not None:
                    result = guard_result
                break
            except ModelRetry as exc:
                # Stop retrying (without failing) once the cost budget
                # is exhausted.
                if self._is_over_budget(session):
                    break
                if attempt >= _OUTPUT_GUARDRAIL_MAX_RETRIES - 1:
                    raise ValueError(
                        f"Output guardrail failed after {_OUTPUT_GUARDRAIL_MAX_RETRIES} retries: {exc.message}"
                    ) from exc
                retry_text = await conv.ask(
                    f"Your response did not pass validation. Error: {exc.message}\n\nPlease try again."
                )
                # Record only the new user/assistant exchange as steps.
                self._extract_steps(conv.messages[-2:], steps, all_tool_calls)

                # Re-parse the retried answer into the structured output
                # type when one is configured; fall back to raw text on
                # any parse/validation failure.
                if self._output_type is not None:
                    try:
                        cleaned = clean_json_text(retry_text)
                        parsed = json.loads(cleaned)
                        output = self._output_type.model_validate(parsed)
                    except Exception:
                        output = retry_text
                else:
                    output = retry_text

                result = AgentResult(
                    output=output,
                    output_text=retry_text,
                    messages=conv.messages,
                    usage=conv.usage,
                    steps=steps,
                    all_tool_calls=all_tool_calls,
                    state=AgentState.idle,
                    run_usage=session.summary(),
                )
    return result
468
+
469
+ # ------------------------------------------------------------------
470
+ # Budget check
471
+ # ------------------------------------------------------------------
472
+
473
+ def _is_over_budget(self, session: UsageSession) -> bool:
474
+ if self._max_cost is None:
475
+ return False
476
+ return session.total_cost >= self._max_cost
477
+
478
+ # ------------------------------------------------------------------
479
+ # Internals
480
+ # ------------------------------------------------------------------
481
+
482
def _resolve_system_prompt(self, ctx: RunContext[Any] | None = None) -> str | None:
    """Build the effective system prompt for a run.

    Combines, in order: the configured system prompt (a plain string, a
    :class:`Persona`, or a callable receiving the run context) and a
    JSON-schema instruction block when ``output_type`` is set.

    Returns:
        The joined prompt text, or ``None`` when there is nothing to send.
    """
    parts: list[str] = []

    if self._system_prompt is not None:
        if isinstance(self._system_prompt, Persona):
            render_kwargs: dict[str, Any] = {}
            if ctx is not None:
                render_kwargs["model"] = ctx.model
                render_kwargs["iteration"] = ctx.iteration
            parts.append(self._system_prompt.render(**render_kwargs))
        elif callable(self._system_prompt) and not isinstance(self._system_prompt, str):
            if ctx is not None:
                parts.append(self._system_prompt(ctx))
            else:
                # Dynamic prompts are still invoked without a context so
                # a call can proceed; the callable must tolerate None.
                parts.append(self._system_prompt(None))  # type: ignore[arg-type]
        else:
            parts.append(str(self._system_prompt))

    if self._output_type is not None:
        # Instruct the model to emit schema-conforming JSON that
        # _parse_output can later validate.
        schema = self._output_type.model_json_schema()
        schema_str = json.dumps(schema, indent=2)
        parts.append(
            "You MUST respond with a single JSON object (no markdown, "
            "no extra text) that validates against this JSON schema:\n"
            f"{schema_str}\n\n"
            "Use double quotes for keys and strings. "
            "If a value is unknown use null."
        )

    return "\n\n".join(parts) if parts else None
512
+
513
def _build_conversation(
    self,
    system_prompt: str | None = None,
    tools: ToolRegistry | None = None,
    driver_callbacks: DriverCallbacks | None = None,
) -> Any:
    """Create a fresh AsyncConversation for a single run.

    Args:
        system_prompt: Overrides the agent-level resolved prompt.
        tools: Overrides the agent's own registry (used for
            context-wrapped tools).
        driver_callbacks: Per-run usage/error recording hooks.
    """
    # Imported lazily to avoid a circular import at module load time.
    from .async_conversation import AsyncConversation

    effective_tools = tools if tools is not None else (self._tools if self._tools else None)

    kwargs: dict[str, Any] = {
        "system_prompt": system_prompt if system_prompt is not None else self._resolve_system_prompt(),
        "tools": effective_tools,
        "max_tool_rounds": self._max_iterations,
    }
    if self._options:
        kwargs["options"] = self._options
    if driver_callbacks is not None:
        kwargs["callbacks"] = driver_callbacks

    # A pre-built driver wins over a model name.
    if self._driver is not None:
        kwargs["driver"] = self._driver
    else:
        kwargs["model_name"] = self._model

    return AsyncConversation(**kwargs)
540
+
541
async def _execute(self, prompt: str, steps: list[AgentStep], deps: Any) -> AgentResult:
    """Core async execution: run conversation, extract steps, parse output.

    Mutates *steps* in place and returns the assembled
    :class:`AgentResult` (output guardrails may replace it).
    """
    # 1. Create per-run UsageSession so usage/cost is scoped to this run
    session = UsageSession()
    driver_callbacks = DriverCallbacks(
        on_response=session.record,
        on_error=session.record_error,
    )

    # 2. Build initial RunContext (iteration 0, empty history)
    ctx = self._build_run_context(prompt, deps, session, [], 0)

    # 3. Run input guardrails (may rewrite the prompt)
    effective_prompt = self._run_input_guardrails(ctx, prompt)

    # 4. Resolve system prompt
    resolved_system_prompt = self._resolve_system_prompt(ctx)

    # 5. Wrap tools with context
    wrapped_tools = self._wrap_tools_with_context(ctx)

    # 6. Build AsyncConversation
    conv = self._build_conversation(
        system_prompt=resolved_system_prompt,
        tools=wrapped_tools if wrapped_tools else None,
        driver_callbacks=driver_callbacks,
    )

    # 7. Fire on_iteration callback
    if self._agent_callbacks.on_iteration:
        self._agent_callbacks.on_iteration(0)

    # 8. Ask the conversation (handles full tool loop internally)
    t0 = time.perf_counter()
    response_text = await conv.ask(effective_prompt)
    elapsed_ms = (time.perf_counter() - t0) * 1000

    # 9. Extract steps and tool calls from the transcript
    all_tool_calls: list[dict[str, Any]] = []
    self._extract_steps(conv.messages, steps, all_tool_calls)

    # Handle output_type parsing (with model-driven retries)
    if self._output_type is not None:
        output, output_text = await self._parse_output(
            conv, response_text, steps, all_tool_calls, elapsed_ms, session
        )
    else:
        output = response_text
        output_text = response_text

    result = AgentResult(
        output=output,
        output_text=output_text,
        messages=conv.messages,
        usage=conv.usage,
        steps=steps,
        all_tool_calls=all_tool_calls,
        state=AgentState.idle,
        run_usage=session.summary(),
    )

    # 10. Run output guardrails (may re-ask the model and replace result)
    if self._output_guardrails:
        result = await self._run_output_guardrails(ctx, result, conv, session, steps, all_tool_calls)

    # 11. Fire callbacks
    if self._agent_callbacks.on_step:
        for step in steps:
            self._agent_callbacks.on_step(step)
    if self._agent_callbacks.on_output:
        self._agent_callbacks.on_output(result)

    return result
614
+
615
def _extract_steps(
    self,
    messages: list[dict[str, Any]],
    steps: list[AgentStep],
    all_tool_calls: list[dict[str, Any]],
) -> None:
    """Scan conversation messages and populate steps and tool_calls.

    Appends to *steps* and *all_tool_calls* in place. Assistant messages
    with tool calls become ``tool_call`` steps, plain assistant messages
    become ``output`` steps, and ``tool``-role messages become
    ``tool_result`` steps.
    """
    # One timestamp for the whole batch: steps extracted together share
    # an extraction time rather than per-message times.
    now = time.time()

    for msg in messages:
        role = msg.get("role", "")

        if role == "assistant":
            tc_list = msg.get("tool_calls", [])
            if tc_list:
                for tc in tc_list:
                    # Support both OpenAI-style nested {"function": {...}}
                    # and flat {"name": ..., "arguments": ...} shapes.
                    fn = tc.get("function", {})
                    name = fn.get("name", tc.get("name", ""))
                    raw_args = fn.get("arguments", tc.get("arguments", "{}"))
                    if isinstance(raw_args, str):
                        try:
                            args = json.loads(raw_args)
                        except json.JSONDecodeError:
                            # Unparseable argument JSON degrades to {}.
                            args = {}
                    else:
                        args = raw_args

                    steps.append(
                        AgentStep(
                            step_type=StepType.tool_call,
                            timestamp=now,
                            content=msg.get("content", ""),
                            tool_name=name,
                            tool_args=args,
                        )
                    )
                    all_tool_calls.append({"name": name, "arguments": args, "id": tc.get("id", "")})
            else:
                steps.append(
                    AgentStep(
                        step_type=StepType.output,
                        timestamp=now,
                        content=msg.get("content", ""),
                    )
                )

        elif role == "tool":
            steps.append(
                AgentStep(
                    step_type=StepType.tool_result,
                    timestamp=now,
                    content=msg.get("content", ""),
                    # NOTE(review): this stores the tool_call_id (not the
                    # tool's name) in tool_name — confirm that is intended.
                    tool_name=msg.get("tool_call_id"),
                )
            )
670
+
671
async def _parse_output(
    self,
    conv: Any,
    response_text: str,
    steps: list[AgentStep],
    all_tool_calls: list[dict[str, Any]],
    # NOTE(review): elapsed_ms is currently unused in this method.
    elapsed_ms: float,
    session: UsageSession | None = None,
) -> tuple[Any, str]:
    """Try to parse ``response_text`` as the output_type, with retries (async).

    On each failure the model is asked to correct its JSON, up to
    ``_OUTPUT_PARSE_MAX_RETRIES`` attempts total. Retries are skipped once
    *session* reports the run is over budget.

    Returns:
        Tuple of (validated model instance, the text it was parsed from).

    Raises:
        ValueError: When no attempt yields valid, schema-conforming JSON.
    """
    assert self._output_type is not None

    last_error: Exception | None = None
    text = response_text

    for attempt in range(_OUTPUT_PARSE_MAX_RETRIES):
        try:
            cleaned = clean_json_text(text)
            parsed = json.loads(cleaned)
            model_instance = self._output_type.model_validate(parsed)
            return model_instance, text
        except Exception as exc:
            last_error = exc
            if attempt < _OUTPUT_PARSE_MAX_RETRIES - 1:
                # Budget guard: stop asking the model for corrections.
                if session is not None and self._is_over_budget(session):
                    break
                retry_msg = (
                    f"Your previous response could not be parsed as valid JSON "
                    f"matching the required schema. Error: {exc}\n\n"
                    f"Please try again and respond ONLY with valid JSON."
                )
                text = await conv.ask(retry_msg)
                # Record the retry exchange in the step log.
                self._extract_steps(conv.messages[-2:], steps, all_tool_calls)

    raise ValueError(
        f"Failed to parse output as {self._output_type.__name__} "
        f"after {_OUTPUT_PARSE_MAX_RETRIES} attempts: {last_error}"
    )
709
+
710
+ # ------------------------------------------------------------------
711
+ # iter() — async step-by-step
712
+ # ------------------------------------------------------------------
713
+
714
async def _execute_iter(self, prompt: str, deps: Any) -> AsyncGenerator[AgentStep, None]:
    """Async generator that executes the agent loop and yields each step.

    Runs :meth:`_execute` to completion first, then yields the recorded
    steps one by one — steps are not produced live during execution.
    """
    self._state = AgentState.running
    self._stop_requested = False
    steps: list[AgentStep] = []

    try:
        result = await self._execute(prompt, steps, deps)
        for step in result.steps:
            yield step
        self._state = AgentState.idle
        # Store the result on the agent so AsyncAgentIterator can
        # retrieve it after the generator is exhausted.
        self._last_iter_result = result
    except Exception:
        self._state = AgentState.errored
        raise
730
+
731
+ # ------------------------------------------------------------------
732
+ # run_stream() — async streaming
733
+ # ------------------------------------------------------------------
734
+
735
async def _execute_stream(self, prompt: str, deps: Any) -> AsyncGenerator[StreamEvent, None]:
    """Async generator that executes the agent loop and yields stream events.

    Emits ``text_delta`` events for response text and one final ``output``
    event carrying the :class:`AgentResult`.
    """
    self._state = AgentState.running
    self._stop_requested = False
    steps: list[AgentStep] = []

    try:
        # 1. Setup (mirrors _execute)
        session = UsageSession()
        driver_callbacks = DriverCallbacks(
            on_response=session.record,
            on_error=session.record_error,
        )
        ctx = self._build_run_context(prompt, deps, session, [], 0)
        effective_prompt = self._run_input_guardrails(ctx, prompt)
        resolved_system_prompt = self._resolve_system_prompt(ctx)
        wrapped_tools = self._wrap_tools_with_context(ctx)
        has_tools = bool(wrapped_tools)

        conv = self._build_conversation(
            system_prompt=resolved_system_prompt,
            tools=wrapped_tools if wrapped_tools else None,
            driver_callbacks=driver_callbacks,
        )

        if self._agent_callbacks.on_iteration:
            self._agent_callbacks.on_iteration(0)

        if has_tools:
            # Tool runs are not token-streamed here: the full response is
            # fetched and emitted as a single text_delta event.
            response_text = await conv.ask(effective_prompt)
            yield StreamEvent(event_type=StreamEventType.text_delta, data=response_text)
        else:
            response_text = ""
            async for chunk in conv.ask_stream(effective_prompt):
                response_text += chunk
                yield StreamEvent(event_type=StreamEventType.text_delta, data=chunk)

        # Extract steps
        all_tool_calls: list[dict[str, Any]] = []
        self._extract_steps(conv.messages, steps, all_tool_calls)

        # Parse output (elapsed time is not tracked in streaming mode,
        # hence the 0.0 placeholder)
        if self._output_type is not None:
            output, output_text = await self._parse_output(conv, response_text, steps, all_tool_calls, 0.0, session)
        else:
            output = response_text
            output_text = response_text

        result = AgentResult(
            output=output,
            output_text=output_text,
            messages=conv.messages,
            usage=conv.usage,
            steps=steps,
            all_tool_calls=all_tool_calls,
            state=AgentState.idle,
            run_usage=session.summary(),
        )

        if self._output_guardrails:
            result = await self._run_output_guardrails(ctx, result, conv, session, steps, all_tool_calls)

        if self._agent_callbacks.on_step:
            for step in steps:
                self._agent_callbacks.on_step(step)
        if self._agent_callbacks.on_output:
            self._agent_callbacks.on_output(result)

        # Final event carries the full AgentResult.
        yield StreamEvent(event_type=StreamEventType.output, data=result)

        self._state = AgentState.idle
        self._last_stream_result = result
    except Exception:
        self._state = AgentState.errored
        raise
810
+
811
+
812
+ # ------------------------------------------------------------------
813
+ # AsyncAgentIterator
814
+ # ------------------------------------------------------------------
815
+
816
+
817
class AsyncAgentIterator:
    """Wraps the :meth:`AsyncAgent.iter` async generator, capturing the final result.

    After async iteration completes, :attr:`result` holds the
    :class:`AgentResult` that the agent stored as ``_last_iter_result``.
    """

    def __init__(self, gen: AsyncGenerator[AgentStep, None]) -> None:
        self._gen = gen
        self._result: AgentResult | None = None
        # Owning agent, captured lazily from the generator frame on the
        # first __anext__ call — the frame is gone (ag_frame is None)
        # once the generator finishes, so it must be grabbed while
        # iteration is still live.
        self._agent: AsyncAgent[Any] | None = None

    def __aiter__(self) -> AsyncAgentIterator:
        return self

    async def __anext__(self) -> AgentStep:
        # BUG FIX: the previous implementation inspected ag_frame only
        # after StopAsyncIteration, when the frame is already None, so
        # `result` was never captured. Capture the generator's bound
        # ``self`` (the agent) while the frame is alive instead.
        if self._agent is None:
            frame = getattr(self._gen, "ag_frame", None)
            if frame is not None:
                candidate = frame.f_locals.get("self")
                if candidate is not None:
                    self._agent = candidate
        try:
            return await self._gen.__anext__()
        except StopAsyncIteration:
            agent = self._agent
            if agent is not None and hasattr(agent, "_last_iter_result"):
                self._result = agent._last_iter_result
            raise

    @property
    def result(self) -> AgentResult | None:
        """The final :class:`AgentResult`, available after iteration completes."""
        return self._result
845
+
846
+
847
+ # ------------------------------------------------------------------
848
+ # AsyncStreamedAgentResult
849
+ # ------------------------------------------------------------------
850
+
851
+
852
class AsyncStreamedAgentResult:
    """Wraps the :meth:`AsyncAgent.run_stream` async generator.

    Yields :class:`StreamEvent` objects. After iteration completes,
    :attr:`result` holds the :class:`AgentResult`.
    """

    def __init__(self, gen: AsyncGenerator[StreamEvent, None]) -> None:
        self._gen = gen
        self._result: AgentResult | None = None

    def __aiter__(self) -> AsyncStreamedAgentResult:
        return self

    async def __anext__(self) -> StreamEvent:
        event = await self._gen.__anext__()
        # The terminal "output" event carries the final AgentResult;
        # stash it so callers can read it after the stream ends.
        is_output = event.event_type == StreamEventType.output
        if is_output and isinstance(event.data, AgentResult):
            self._result = event.data
        return event

    @property
    def result(self) -> AgentResult | None:
        """The final :class:`AgentResult`, available after iteration completes."""
        return self._result