prompture 0.0.35__py3-none-any.whl → 0.0.40.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. prompture/__init__.py +132 -3
  2. prompture/_version.py +2 -2
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/async_agent.py +880 -0
  6. prompture/async_conversation.py +208 -17
  7. prompture/async_core.py +16 -0
  8. prompture/async_driver.py +63 -0
  9. prompture/async_groups.py +551 -0
  10. prompture/conversation.py +222 -18
  11. prompture/core.py +46 -12
  12. prompture/cost_mixin.py +37 -0
  13. prompture/discovery.py +132 -44
  14. prompture/driver.py +77 -0
  15. prompture/drivers/__init__.py +5 -1
  16. prompture/drivers/async_azure_driver.py +11 -5
  17. prompture/drivers/async_claude_driver.py +184 -9
  18. prompture/drivers/async_google_driver.py +222 -28
  19. prompture/drivers/async_grok_driver.py +11 -5
  20. prompture/drivers/async_groq_driver.py +11 -5
  21. prompture/drivers/async_lmstudio_driver.py +74 -5
  22. prompture/drivers/async_ollama_driver.py +13 -3
  23. prompture/drivers/async_openai_driver.py +162 -5
  24. prompture/drivers/async_openrouter_driver.py +11 -5
  25. prompture/drivers/async_registry.py +5 -1
  26. prompture/drivers/azure_driver.py +10 -4
  27. prompture/drivers/claude_driver.py +17 -1
  28. prompture/drivers/google_driver.py +227 -33
  29. prompture/drivers/grok_driver.py +11 -5
  30. prompture/drivers/groq_driver.py +11 -5
  31. prompture/drivers/lmstudio_driver.py +73 -8
  32. prompture/drivers/ollama_driver.py +16 -5
  33. prompture/drivers/openai_driver.py +26 -11
  34. prompture/drivers/openrouter_driver.py +11 -5
  35. prompture/drivers/vision_helpers.py +153 -0
  36. prompture/group_types.py +147 -0
  37. prompture/groups.py +530 -0
  38. prompture/image.py +180 -0
  39. prompture/ledger.py +252 -0
  40. prompture/model_rates.py +112 -2
  41. prompture/persistence.py +254 -0
  42. prompture/persona.py +482 -0
  43. prompture/serialization.py +218 -0
  44. prompture/settings.py +1 -0
  45. prompture-0.0.40.dev1.dist-info/METADATA +369 -0
  46. prompture-0.0.40.dev1.dist-info/RECORD +78 -0
  47. prompture-0.0.35.dist-info/METADATA +0 -464
  48. prompture-0.0.35.dist-info/RECORD +0 -66
  49. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/WHEEL +0 -0
  50. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/entry_points.txt +0 -0
  51. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/licenses/LICENSE +0 -0
  52. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,880 @@
1
+ """Async Agent framework for Prompture.
2
+
3
+ Provides :class:`AsyncAgent`, the async counterpart of :class:`~prompture.agent.Agent`.
4
+ All methods are ``async`` and use :class:`~prompture.async_conversation.AsyncConversation`.
5
+
6
+ Example::
7
+
8
+ from prompture import AsyncAgent
9
+
10
+ agent = AsyncAgent("openai/gpt-4o", system_prompt="You are helpful.")
11
+ result = await agent.run("What is 2 + 2?")
12
+ print(result.output)
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import asyncio
18
+ import inspect
19
+ import json
20
+ import logging
21
+ import time
22
+ import typing
23
+ from collections.abc import AsyncGenerator, Callable
24
+ from typing import Any, Generic
25
+
26
+ from pydantic import BaseModel
27
+
28
+ from .agent_types import (
29
+ AgentCallbacks,
30
+ AgentResult,
31
+ AgentState,
32
+ AgentStep,
33
+ DepsType,
34
+ ModelRetry,
35
+ RunContext,
36
+ StepType,
37
+ StreamEvent,
38
+ StreamEventType,
39
+ )
40
+ from .callbacks import DriverCallbacks
41
+ from .persona import Persona
42
+ from .session import UsageSession
43
+ from .tools import clean_json_text
44
+ from .tools_schema import ToolDefinition, ToolRegistry
45
+
46
+ logger = logging.getLogger("prompture.async_agent")
47
+
48
+ _OUTPUT_PARSE_MAX_RETRIES = 3
49
+ _OUTPUT_GUARDRAIL_MAX_RETRIES = 3
50
+
51
+
52
+ # ------------------------------------------------------------------
53
+ # Helpers
54
+ # ------------------------------------------------------------------
55
+
56
+
57
def _is_async_callable(fn: Callable[..., Any]) -> bool:
    """Return True when calling *fn* produces an awaitable.

    Covers plain coroutine functions as well as callable objects whose
    ``__call__`` is itself a coroutine function.
    """
    if asyncio.iscoroutinefunction(fn):
        return True
    # Callable class instances: inspect the type's __call__ instead.
    call_attr = type(fn).__call__ if callable(fn) else None
    if call_attr is None:
        return False
    return asyncio.iscoroutinefunction(call_attr)
64
+
65
+
66
def _tool_wants_context(fn: Callable[..., Any]) -> bool:
    """Return True when *fn*'s first real parameter is annotated as :class:`RunContext`.

    A leading ``self`` parameter is skipped. Both resolved annotations and
    string (forward-reference) annotations such as ``"RunContext[Deps]"``
    are recognised.
    """
    sig = inspect.signature(fn)
    names = list(sig.parameters)
    if not names:
        return False

    target = names[0]
    if target == "self":
        if len(names) == 1:
            return False
        target = names[1]

    annotation = None
    try:
        annotation = typing.get_type_hints(fn, include_extras=True).get(target)
    except Exception:
        # get_type_hints can fail on local/forward references; fall back to
        # the raw (possibly string) annotation below.
        pass

    if annotation is None:
        raw = sig.parameters[target].annotation
        if raw is inspect.Parameter.empty:
            return False
        annotation = raw

    if isinstance(annotation, str):
        return annotation == "RunContext" or annotation.startswith("RunContext[")
    if annotation is RunContext:
        return True
    return getattr(annotation, "__origin__", None) is RunContext
101
+
102
+
103
def _get_first_param_name(fn: Callable[..., Any]) -> str:
    """Return the name of *fn*'s first parameter, skipping ``self``; '' if none."""
    parameters = inspect.signature(fn).parameters
    return next((name for name in parameters if name != "self"), "")
110
+
111
+
112
+ # ------------------------------------------------------------------
113
+ # AsyncAgent
114
+ # ------------------------------------------------------------------
115
+
116
+
117
+ class AsyncAgent(Generic[DepsType]):
118
+ """Async agent that executes a ReAct loop with tool support.
119
+
120
+ Mirrors :class:`~prompture.agent.Agent` but uses
121
+ :class:`~prompture.async_conversation.AsyncConversation` and
122
+ ``async`` methods throughout.
123
+
124
+ Args:
125
+ model: Model string in ``"provider/model"`` format.
126
+ driver: Pre-built async driver instance.
127
+ tools: Initial tools as a list of callables or a :class:`ToolRegistry`.
128
+ system_prompt: System prompt prepended to every run. May also be a
129
+ callable ``(RunContext) -> str`` for dynamic prompts.
130
+ output_type: Optional Pydantic model class for structured output.
131
+ max_iterations: Maximum tool-use rounds per run.
132
+ max_cost: Soft budget in USD.
133
+ options: Extra driver options forwarded to every call.
134
+ deps_type: Type hint for dependencies.
135
+ agent_callbacks: Agent-level observability callbacks.
136
+ input_guardrails: Functions called before the prompt is sent.
137
+ output_guardrails: Functions called after output is parsed.
138
+ """
139
+
140
+ def __init__(
141
+ self,
142
+ model: str = "",
143
+ *,
144
+ driver: Any | None = None,
145
+ tools: list[Callable[..., Any]] | ToolRegistry | None = None,
146
+ system_prompt: str | Persona | Callable[..., str] | None = None,
147
+ output_type: type[BaseModel] | None = None,
148
+ max_iterations: int = 10,
149
+ max_cost: float | None = None,
150
+ options: dict[str, Any] | None = None,
151
+ deps_type: type | None = None,
152
+ agent_callbacks: AgentCallbacks | None = None,
153
+ input_guardrails: list[Callable[..., Any]] | None = None,
154
+ output_guardrails: list[Callable[..., Any]] | None = None,
155
+ name: str = "",
156
+ description: str = "",
157
+ output_key: str | None = None,
158
+ ) -> None:
159
+ if not model and driver is None:
160
+ raise ValueError("Either model or driver must be provided")
161
+
162
+ self._model = model
163
+ self._driver = driver
164
+ self._system_prompt = system_prompt
165
+ self._output_type = output_type
166
+ self._max_iterations = max_iterations
167
+ self._max_cost = max_cost
168
+ self._options = dict(options) if options else {}
169
+ self._deps_type = deps_type
170
+ self._agent_callbacks = agent_callbacks or AgentCallbacks()
171
+ self._input_guardrails = list(input_guardrails) if input_guardrails else []
172
+ self._output_guardrails = list(output_guardrails) if output_guardrails else []
173
+ self.name = name
174
+ self.description = description
175
+ self.output_key = output_key
176
+
177
+ # Build internal tool registry
178
+ self._tools = ToolRegistry()
179
+ if isinstance(tools, ToolRegistry):
180
+ self._tools = tools
181
+ elif tools is not None:
182
+ for fn in tools:
183
+ self._tools.register(fn)
184
+
185
+ self._state = AgentState.idle
186
+ self._stop_requested = False
187
+
188
+ # ------------------------------------------------------------------
189
+ # Public API
190
+ # ------------------------------------------------------------------
191
+
192
+ def tool(self, fn: Callable[..., Any]) -> Callable[..., Any]:
193
+ """Decorator to register a function as a tool on this agent."""
194
+ self._tools.register(fn)
195
+ return fn
196
+
197
+ @property
198
+ def state(self) -> AgentState:
199
+ """Current lifecycle state of the agent."""
200
+ return self._state
201
+
202
+ def stop(self) -> None:
203
+ """Request graceful shutdown after the current iteration."""
204
+ self._stop_requested = True
205
+
206
+ def as_tool(
207
+ self,
208
+ name: str | None = None,
209
+ description: str | None = None,
210
+ custom_output_extractor: Callable[[AgentResult], str] | None = None,
211
+ ) -> ToolDefinition:
212
+ """Wrap this AsyncAgent as a callable tool for another Agent.
213
+
214
+ Creates a :class:`ToolDefinition` whose function accepts a ``prompt``
215
+ string, runs this agent (bridging async to sync), and returns the
216
+ output text.
217
+
218
+ Args:
219
+ name: Tool name (defaults to ``self.name`` or ``"agent_tool"``).
220
+ description: Tool description (defaults to ``self.description``).
221
+ custom_output_extractor: Optional function to extract a string
222
+ from :class:`AgentResult`.
223
+ """
224
+ tool_name = name or self.name or "agent_tool"
225
+ tool_desc = description or self.description or f"Run agent {tool_name}"
226
+ agent = self
227
+ extractor = custom_output_extractor
228
+
229
+ def _call_agent(prompt: str) -> str:
230
+ """Run the wrapped async agent with the given prompt."""
231
+ try:
232
+ loop = asyncio.get_running_loop()
233
+ except RuntimeError:
234
+ loop = None
235
+
236
+ if loop is not None and loop.is_running():
237
+ import concurrent.futures
238
+
239
+ with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
240
+ result = pool.submit(asyncio.run, agent.run(prompt)).result()
241
+ else:
242
+ result = asyncio.run(agent.run(prompt))
243
+
244
+ if extractor is not None:
245
+ return extractor(result)
246
+ return result.output_text
247
+
248
+ return ToolDefinition(
249
+ name=tool_name,
250
+ description=tool_desc,
251
+ parameters={
252
+ "type": "object",
253
+ "properties": {
254
+ "prompt": {"type": "string", "description": "The prompt to send to the agent"},
255
+ },
256
+ "required": ["prompt"],
257
+ },
258
+ function=_call_agent,
259
+ )
260
+
261
+ async def run(self, prompt: str, *, deps: Any = None) -> AgentResult:
262
+ """Execute the agent loop to completion (async).
263
+
264
+ Creates a fresh conversation, sends the prompt, handles tool calls,
265
+ and optionally parses the final response into ``output_type``.
266
+ """
267
+ self._state = AgentState.running
268
+ self._stop_requested = False
269
+ steps: list[AgentStep] = []
270
+
271
+ try:
272
+ result = await self._execute(prompt, steps, deps)
273
+ self._state = AgentState.idle
274
+ return result
275
+ except Exception:
276
+ self._state = AgentState.errored
277
+ raise
278
+
279
+ async def iter(self, prompt: str, *, deps: Any = None) -> AsyncAgentIterator:
280
+ """Execute the agent loop and iterate over steps asynchronously.
281
+
282
+ Returns an :class:`AsyncAgentIterator` yielding :class:`AgentStep` objects.
283
+ After iteration, :attr:`AsyncAgentIterator.result` holds the final result.
284
+ """
285
+ gen = self._execute_iter(prompt, deps)
286
+ return AsyncAgentIterator(gen)
287
+
288
+ async def run_stream(self, prompt: str, *, deps: Any = None) -> AsyncStreamedAgentResult:
289
+ """Execute the agent loop with streaming output (async).
290
+
291
+ Returns an :class:`AsyncStreamedAgentResult` yielding :class:`StreamEvent` objects.
292
+ """
293
+ gen = self._execute_stream(prompt, deps)
294
+ return AsyncStreamedAgentResult(gen)
295
+
296
+ # ------------------------------------------------------------------
297
+ # RunContext helpers
298
+ # ------------------------------------------------------------------
299
+
300
+ def _build_run_context(
301
+ self,
302
+ prompt: str,
303
+ deps: Any,
304
+ session: UsageSession,
305
+ messages: list[dict[str, Any]],
306
+ iteration: int,
307
+ ) -> RunContext[Any]:
308
+ return RunContext(
309
+ deps=deps,
310
+ model=self._model,
311
+ usage=session.summary(),
312
+ messages=list(messages),
313
+ iteration=iteration,
314
+ prompt=prompt,
315
+ )
316
+
317
+ # ------------------------------------------------------------------
318
+ # Tool wrapping (RunContext injection + ModelRetry + callbacks)
319
+ # ------------------------------------------------------------------
320
+
321
    def _wrap_tools_with_context(self, ctx: RunContext[Any]) -> ToolRegistry:
        """Return a new :class:`ToolRegistry` with wrapped tool functions.

        Each wrapper is **sync** so it works with ``ToolRegistry.execute()``.
        A wrapper: injects *ctx* when the tool's first parameter is a
        :class:`RunContext`; bridges async tools to sync via ``asyncio.run``
        (or a worker thread when a loop is already running); converts
        :class:`ModelRetry` into an ``"Error: ..."`` string result; and fires
        the agent-level on_tool_start/on_tool_end callbacks.
        """
        if not self._tools:
            return ToolRegistry()

        new_registry = ToolRegistry()
        cb = self._agent_callbacks

        for td in self._tools.definitions:
            wants_ctx = _tool_wants_context(td.function)
            original_fn = td.function
            tool_name = td.name
            is_async = _is_async_callable(original_fn)

            # Factory binds per-tool values as arguments so each wrapper keeps
            # its own tool (avoids the classic late-binding closure bug).
            def _make_wrapper(
                _fn: Callable[..., Any],
                _wants: bool,
                _name: str,
                _is_async: bool,
                _cb: AgentCallbacks = cb,
            ) -> Callable[..., Any]:
                def wrapper(**kwargs: Any) -> Any:
                    if _cb.on_tool_start:
                        _cb.on_tool_start(_name, kwargs)
                    try:
                        # Positionally inject the RunContext only when wanted.
                        if _wants:
                            call_args = (ctx,)
                        else:
                            call_args = ()

                        if _is_async:
                            coro = _fn(*call_args, **kwargs)
                            # Try to get running loop; if none, use asyncio.run()
                            try:
                                loop = asyncio.get_running_loop()
                            except RuntimeError:
                                loop = None
                            if loop is not None and loop.is_running():
                                # We're inside an async context — create a new thread
                                import concurrent.futures

                                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                                    result = pool.submit(asyncio.run, coro).result()
                            else:
                                result = asyncio.run(coro)
                        else:
                            result = _fn(*call_args, **kwargs)
                    except ModelRetry as exc:
                        # Tool requested a model retry: surface it as a normal
                        # string result rather than propagating the exception.
                        result = f"Error: {exc.message}"
                    if _cb.on_tool_end:
                        _cb.on_tool_end(_name, result)
                    return result

                return wrapper

            wrapped = _make_wrapper(original_fn, wants_ctx, tool_name, is_async)

            # Build schema: strip the RunContext param if present so the model
            # never sees (or tries to supply) the injected context argument.
            params = dict(td.parameters)
            if wants_ctx:
                ctx_param_name = _get_first_param_name(td.function)
                props = dict(params.get("properties", {}))
                props.pop(ctx_param_name, None)
                params = dict(params)
                params["properties"] = props
                req = list(params.get("required", []))
                if ctx_param_name in req:
                    req.remove(ctx_param_name)
                    if req:
                        params["required"] = req
                    elif "required" in params:
                        del params["required"]

            new_td = ToolDefinition(
                name=td.name,
                description=td.description,
                parameters=params,
                function=wrapped,
            )
            new_registry.add(new_td)

        return new_registry
408
+
409
+ # ------------------------------------------------------------------
410
+ # Guardrails
411
+ # ------------------------------------------------------------------
412
+
413
+ def _run_input_guardrails(self, ctx: RunContext[Any], prompt: str) -> str:
414
+ for guardrail in self._input_guardrails:
415
+ result = guardrail(ctx, prompt)
416
+ if result is not None:
417
+ prompt = result
418
+ return prompt
419
+
420
    async def _run_output_guardrails(
        self,
        ctx: RunContext[Any],
        result: AgentResult,
        conv: Any,
        session: UsageSession,
        steps: list[AgentStep],
        all_tool_calls: list[dict[str, Any]],
    ) -> AgentResult:
        """Validate *result* against each output guardrail, retrying on ModelRetry.

        A guardrail returning a non-None value replaces the result. When a
        guardrail raises :class:`ModelRetry`, the model is asked to try again
        (up to ``_OUTPUT_GUARDRAIL_MAX_RETRIES`` attempts per guardrail) and a
        fresh :class:`AgentResult` is rebuilt from the retry text; retries stop
        early if the cost budget is exhausted.

        Raises:
            ValueError: when a guardrail still rejects after all retries.
        """
        for guardrail in self._output_guardrails:
            for attempt in range(_OUTPUT_GUARDRAIL_MAX_RETRIES):
                try:
                    guard_result = guardrail(ctx, result)
                    if guard_result is not None:
                        result = guard_result
                    # Guardrail accepted (or replaced) the result — next guardrail.
                    break
                except ModelRetry as exc:
                    # Over budget: give up on retries for this guardrail silently.
                    if self._is_over_budget(session):
                        break
                    if attempt >= _OUTPUT_GUARDRAIL_MAX_RETRIES - 1:
                        raise ValueError(
                            f"Output guardrail failed after {_OUTPUT_GUARDRAIL_MAX_RETRIES} retries: {exc.message}"
                        ) from exc
                    retry_text = await conv.ask(
                        f"Your response did not pass validation. Error: {exc.message}\n\nPlease try again."
                    )
                    # Record the retry exchange (last user + assistant message).
                    self._extract_steps(conv.messages[-2:], steps, all_tool_calls)

                    # Best-effort re-parse into the structured output type; fall
                    # back to the raw retry text on any parse/validation error.
                    if self._output_type is not None:
                        try:
                            cleaned = clean_json_text(retry_text)
                            parsed = json.loads(cleaned)
                            output = self._output_type.model_validate(parsed)
                        except Exception:
                            output = retry_text
                    else:
                        output = retry_text

                    # Rebuild the result; the next loop iteration re-runs the
                    # guardrail against it.
                    result = AgentResult(
                        output=output,
                        output_text=retry_text,
                        messages=conv.messages,
                        usage=conv.usage,
                        steps=steps,
                        all_tool_calls=all_tool_calls,
                        state=AgentState.idle,
                        run_usage=session.summary(),
                    )
        return result
469
+
470
+ # ------------------------------------------------------------------
471
+ # Budget check
472
+ # ------------------------------------------------------------------
473
+
474
+ def _is_over_budget(self, session: UsageSession) -> bool:
475
+ if self._max_cost is None:
476
+ return False
477
+ return session.total_cost >= self._max_cost
478
+
479
+ # ------------------------------------------------------------------
480
+ # Internals
481
+ # ------------------------------------------------------------------
482
+
483
+ def _resolve_system_prompt(self, ctx: RunContext[Any] | None = None) -> str | None:
484
+ parts: list[str] = []
485
+
486
+ if self._system_prompt is not None:
487
+ if isinstance(self._system_prompt, Persona):
488
+ render_kwargs: dict[str, Any] = {}
489
+ if ctx is not None:
490
+ render_kwargs["model"] = ctx.model
491
+ render_kwargs["iteration"] = ctx.iteration
492
+ parts.append(self._system_prompt.render(**render_kwargs))
493
+ elif callable(self._system_prompt) and not isinstance(self._system_prompt, str):
494
+ if ctx is not None:
495
+ parts.append(self._system_prompt(ctx))
496
+ else:
497
+ parts.append(self._system_prompt(None)) # type: ignore[arg-type]
498
+ else:
499
+ parts.append(str(self._system_prompt))
500
+
501
+ if self._output_type is not None:
502
+ schema = self._output_type.model_json_schema()
503
+ schema_str = json.dumps(schema, indent=2)
504
+ parts.append(
505
+ "You MUST respond with a single JSON object (no markdown, "
506
+ "no extra text) that validates against this JSON schema:\n"
507
+ f"{schema_str}\n\n"
508
+ "Use double quotes for keys and strings. "
509
+ "If a value is unknown use null."
510
+ )
511
+
512
+ return "\n\n".join(parts) if parts else None
513
+
514
+ def _build_conversation(
515
+ self,
516
+ system_prompt: str | None = None,
517
+ tools: ToolRegistry | None = None,
518
+ driver_callbacks: DriverCallbacks | None = None,
519
+ ) -> Any:
520
+ """Create a fresh AsyncConversation for a single run."""
521
+ from .async_conversation import AsyncConversation
522
+
523
+ effective_tools = tools if tools is not None else (self._tools if self._tools else None)
524
+
525
+ kwargs: dict[str, Any] = {
526
+ "system_prompt": system_prompt if system_prompt is not None else self._resolve_system_prompt(),
527
+ "tools": effective_tools,
528
+ "max_tool_rounds": self._max_iterations,
529
+ }
530
+ if self._options:
531
+ kwargs["options"] = self._options
532
+ if driver_callbacks is not None:
533
+ kwargs["callbacks"] = driver_callbacks
534
+
535
+ if self._driver is not None:
536
+ kwargs["driver"] = self._driver
537
+ else:
538
+ kwargs["model_name"] = self._model
539
+
540
+ return AsyncConversation(**kwargs)
541
+
542
    async def _execute(self, prompt: str, steps: list[AgentStep], deps: Any) -> AgentResult:
        """Core async execution: run conversation, extract steps, parse output.

        Mutates *steps* in place and returns the assembled
        :class:`AgentResult`. The individual stages are strictly ordered:
        usage session, context, guardrails, conversation, parsing, callbacks.
        """
        # 1. Create per-run UsageSession (records cost/usage via driver callbacks)
        session = UsageSession()
        driver_callbacks = DriverCallbacks(
            on_response=session.record,
            on_error=session.record_error,
        )

        # 2. Build initial RunContext (iteration 0, no messages yet)
        ctx = self._build_run_context(prompt, deps, session, [], 0)

        # 3. Run input guardrails (may rewrite the prompt)
        effective_prompt = self._run_input_guardrails(ctx, prompt)

        # 4. Resolve system prompt (Persona / callable / schema instruction)
        resolved_system_prompt = self._resolve_system_prompt(ctx)

        # 5. Wrap tools with context (RunContext injection, ModelRetry handling)
        wrapped_tools = self._wrap_tools_with_context(ctx)

        # 6. Build AsyncConversation
        conv = self._build_conversation(
            system_prompt=resolved_system_prompt,
            tools=wrapped_tools if wrapped_tools else None,
            driver_callbacks=driver_callbacks,
        )

        # 7. Fire on_iteration callback
        if self._agent_callbacks.on_iteration:
            self._agent_callbacks.on_iteration(0)

        # 8. Ask the conversation (handles full tool loop internally)
        t0 = time.perf_counter()
        response_text = await conv.ask(effective_prompt)
        elapsed_ms = (time.perf_counter() - t0) * 1000

        # 9. Extract steps and tool calls from the full message history
        all_tool_calls: list[dict[str, Any]] = []
        self._extract_steps(conv.messages, steps, all_tool_calls)

        # Handle output_type parsing (with model-assisted retries)
        if self._output_type is not None:
            output, output_text = await self._parse_output(
                conv, response_text, steps, all_tool_calls, elapsed_ms, session
            )
        else:
            output = response_text
            output_text = response_text

        result = AgentResult(
            output=output,
            output_text=output_text,
            messages=conv.messages,
            usage=conv.usage,
            steps=steps,
            all_tool_calls=all_tool_calls,
            state=AgentState.idle,
            run_usage=session.summary(),
        )

        # 10. Run output guardrails (may replace the result)
        if self._output_guardrails:
            result = await self._run_output_guardrails(ctx, result, conv, session, steps, all_tool_calls)

        # 11. Fire observability callbacks
        if self._agent_callbacks.on_step:
            for step in steps:
                self._agent_callbacks.on_step(step)
        if self._agent_callbacks.on_output:
            self._agent_callbacks.on_output(result)

        return result
615
+
616
+ def _extract_steps(
617
+ self,
618
+ messages: list[dict[str, Any]],
619
+ steps: list[AgentStep],
620
+ all_tool_calls: list[dict[str, Any]],
621
+ ) -> None:
622
+ """Scan conversation messages and populate steps and tool_calls."""
623
+ now = time.time()
624
+
625
+ for msg in messages:
626
+ role = msg.get("role", "")
627
+
628
+ if role == "assistant":
629
+ tc_list = msg.get("tool_calls", [])
630
+ if tc_list:
631
+ for tc in tc_list:
632
+ fn = tc.get("function", {})
633
+ name = fn.get("name", tc.get("name", ""))
634
+ raw_args = fn.get("arguments", tc.get("arguments", "{}"))
635
+ if isinstance(raw_args, str):
636
+ try:
637
+ args = json.loads(raw_args)
638
+ except json.JSONDecodeError:
639
+ args = {}
640
+ else:
641
+ args = raw_args
642
+
643
+ steps.append(
644
+ AgentStep(
645
+ step_type=StepType.tool_call,
646
+ timestamp=now,
647
+ content=msg.get("content", ""),
648
+ tool_name=name,
649
+ tool_args=args,
650
+ )
651
+ )
652
+ all_tool_calls.append({"name": name, "arguments": args, "id": tc.get("id", "")})
653
+ else:
654
+ steps.append(
655
+ AgentStep(
656
+ step_type=StepType.output,
657
+ timestamp=now,
658
+ content=msg.get("content", ""),
659
+ )
660
+ )
661
+
662
+ elif role == "tool":
663
+ steps.append(
664
+ AgentStep(
665
+ step_type=StepType.tool_result,
666
+ timestamp=now,
667
+ content=msg.get("content", ""),
668
+ tool_name=msg.get("tool_call_id"),
669
+ )
670
+ )
671
+
672
    async def _parse_output(
        self,
        conv: Any,
        response_text: str,
        steps: list[AgentStep],
        all_tool_calls: list[dict[str, Any]],
        elapsed_ms: float,
        session: UsageSession | None = None,
    ) -> tuple[Any, str]:
        """Try to parse ``response_text`` as the output_type, with retries (async).

        On each failed attempt the model is asked to re-emit valid JSON (up to
        ``_OUTPUT_PARSE_MAX_RETRIES`` attempts total); retries stop early when
        the cost budget is exhausted. Returns ``(model_instance, raw_text)``.

        Raises:
            ValueError: when every attempt fails to validate.
        """
        assert self._output_type is not None

        last_error: Exception | None = None
        text = response_text

        for attempt in range(_OUTPUT_PARSE_MAX_RETRIES):
            try:
                # Strip markdown fences etc. before JSON decoding + validation.
                cleaned = clean_json_text(text)
                parsed = json.loads(cleaned)
                model_instance = self._output_type.model_validate(parsed)
                return model_instance, text
            except Exception as exc:
                last_error = exc
                if attempt < _OUTPUT_PARSE_MAX_RETRIES - 1:
                    # Respect the cost ceiling: no retry prompt if over budget.
                    if session is not None and self._is_over_budget(session):
                        break
                    retry_msg = (
                        f"Your previous response could not be parsed as valid JSON "
                        f"matching the required schema. Error: {exc}\n\n"
                        f"Please try again and respond ONLY with valid JSON."
                    )
                    text = await conv.ask(retry_msg)
                    # Record the retry exchange (last user + assistant message).
                    self._extract_steps(conv.messages[-2:], steps, all_tool_calls)

        raise ValueError(
            f"Failed to parse output as {self._output_type.__name__} "
            f"after {_OUTPUT_PARSE_MAX_RETRIES} attempts: {last_error}"
        )
710
+
711
+ # ------------------------------------------------------------------
712
+ # iter() — async step-by-step
713
+ # ------------------------------------------------------------------
714
+
715
+ async def _execute_iter(self, prompt: str, deps: Any) -> AsyncGenerator[AgentStep, None]:
716
+ """Async generator that executes the agent loop and yields each step."""
717
+ self._state = AgentState.running
718
+ self._stop_requested = False
719
+ steps: list[AgentStep] = []
720
+
721
+ try:
722
+ result = await self._execute(prompt, steps, deps)
723
+ for step in result.steps:
724
+ yield step
725
+ self._state = AgentState.idle
726
+ # Store result on the generator for retrieval
727
+ self._last_iter_result = result
728
+ except Exception:
729
+ self._state = AgentState.errored
730
+ raise
731
+
732
+ # ------------------------------------------------------------------
733
+ # run_stream() — async streaming
734
+ # ------------------------------------------------------------------
735
+
736
    async def _execute_stream(self, prompt: str, deps: Any) -> AsyncGenerator[StreamEvent, None]:
        """Async generator that executes the agent loop and yields stream events.

        Yields ``text_delta`` events while text arrives, then a final
        ``output`` event carrying the :class:`AgentResult`. When tools are
        registered the conversation cannot stream token-by-token, so the full
        response is emitted as a single ``text_delta``.
        """
        self._state = AgentState.running
        self._stop_requested = False
        steps: list[AgentStep] = []

        try:
            # 1. Setup: per-run usage tracking, context, guardrails, tools.
            session = UsageSession()
            driver_callbacks = DriverCallbacks(
                on_response=session.record,
                on_error=session.record_error,
            )
            ctx = self._build_run_context(prompt, deps, session, [], 0)
            effective_prompt = self._run_input_guardrails(ctx, prompt)
            resolved_system_prompt = self._resolve_system_prompt(ctx)
            wrapped_tools = self._wrap_tools_with_context(ctx)
            has_tools = bool(wrapped_tools)

            conv = self._build_conversation(
                system_prompt=resolved_system_prompt,
                tools=wrapped_tools if wrapped_tools else None,
                driver_callbacks=driver_callbacks,
            )

            if self._agent_callbacks.on_iteration:
                self._agent_callbacks.on_iteration(0)

            if has_tools:
                # Tool loop requires the non-streaming path; emit one delta.
                response_text = await conv.ask(effective_prompt)
                yield StreamEvent(event_type=StreamEventType.text_delta, data=response_text)
            else:
                response_text = ""
                async for chunk in conv.ask_stream(effective_prompt):
                    response_text += chunk
                    yield StreamEvent(event_type=StreamEventType.text_delta, data=chunk)

            # Extract steps from the full message history
            all_tool_calls: list[dict[str, Any]] = []
            self._extract_steps(conv.messages, steps, all_tool_calls)

            # Parse output into the structured type, if configured
            if self._output_type is not None:
                output, output_text = await self._parse_output(conv, response_text, steps, all_tool_calls, 0.0, session)
            else:
                output = response_text
                output_text = response_text

            result = AgentResult(
                output=output,
                output_text=output_text,
                messages=conv.messages,
                usage=conv.usage,
                steps=steps,
                all_tool_calls=all_tool_calls,
                state=AgentState.idle,
                run_usage=session.summary(),
            )

            if self._output_guardrails:
                result = await self._run_output_guardrails(ctx, result, conv, session, steps, all_tool_calls)

            # Observability callbacks fire once everything is final.
            if self._agent_callbacks.on_step:
                for step in steps:
                    self._agent_callbacks.on_step(step)
            if self._agent_callbacks.on_output:
                self._agent_callbacks.on_output(result)

            # Terminal event carries the final AgentResult.
            yield StreamEvent(event_type=StreamEventType.output, data=result)

            self._state = AgentState.idle
            self._last_stream_result = result
        except Exception:
            self._state = AgentState.errored
            raise
811
+
812
+
813
+ # ------------------------------------------------------------------
814
+ # AsyncAgentIterator
815
+ # ------------------------------------------------------------------
816
+
817
+
818
class AsyncAgentIterator:
    """Wraps the :meth:`AsyncAgent.iter` async generator, capturing the final result.

    After async iteration completes, :attr:`result` holds the :class:`AgentResult`.
    """

    def __init__(self, gen: AsyncGenerator[AgentStep, None]) -> None:
        self._gen = gen
        self._result: AgentResult | None = None
        # Owning agent, captured lazily from the generator frame while it is
        # suspended (the frame is released once the generator finishes).
        self._agent: AsyncAgent[Any] | None = None

    def __aiter__(self) -> AsyncAgentIterator:
        return self

    async def __anext__(self) -> AgentStep:
        """Advance the underlying generator, capturing the final result on exhaustion.

        Bug fix: ``ag_frame`` is ``None`` once an async generator has finished,
        so the agent reference must be grabbed while the generator is still
        suspended at a yield — not after ``StopAsyncIteration``.
        """
        try:
            step = await self._gen.__anext__()
        except StopAsyncIteration:
            agent = self._agent
            if agent is None:
                # Fallback for the degenerate zero-step case; usually None here.
                frame = getattr(self._gen, "ag_frame", None)
                agent = frame.f_locals.get("self") if frame is not None else None
            if agent is not None and hasattr(agent, "_last_iter_result"):
                self._result = agent._last_iter_result
            raise
        # Generator is suspended at a yield: its frame (and local ``self``,
        # the AsyncAgent) is alive — remember it for use at exhaustion.
        if self._agent is None:
            frame = getattr(self._gen, "ag_frame", None)
            if frame is not None:
                self._agent = frame.f_locals.get("self")
        return step

    @property
    def result(self) -> AgentResult | None:
        """The final :class:`AgentResult`, available after iteration completes."""
        return self._result
846
+
847
+
848
+ # ------------------------------------------------------------------
849
+ # AsyncStreamedAgentResult
850
+ # ------------------------------------------------------------------
851
+
852
+
853
class AsyncStreamedAgentResult:
    """Async-iterable wrapper around :meth:`AsyncAgent.run_stream`.

    Yields :class:`StreamEvent` objects; once the stream ends, :attr:`result`
    exposes the final :class:`AgentResult` taken from the ``output`` event.
    """

    def __init__(self, gen: AsyncGenerator[StreamEvent, None]) -> None:
        self._gen = gen
        self._result: AgentResult | None = None

    def __aiter__(self) -> AsyncStreamedAgentResult:
        return self

    async def __anext__(self) -> StreamEvent:
        event = await self._gen.__anext__()
        # The terminal "output" event carries the final AgentResult; keep it.
        if event.event_type == StreamEventType.output and isinstance(event.data, AgentResult):
            self._result = event.data
        return event

    @property
    def result(self) -> AgentResult | None:
        """The final :class:`AgentResult`, available after iteration completes."""
        return self._result