prompture 0.0.35__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +120 -2
- prompture/_version.py +2 -2
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +199 -17
- prompture/async_driver.py +24 -0
- prompture/async_groups.py +551 -0
- prompture/conversation.py +213 -18
- prompture/core.py +30 -12
- prompture/discovery.py +24 -1
- prompture/driver.py +38 -0
- prompture/drivers/__init__.py +5 -1
- prompture/drivers/async_azure_driver.py +7 -1
- prompture/drivers/async_claude_driver.py +7 -1
- prompture/drivers/async_google_driver.py +212 -28
- prompture/drivers/async_grok_driver.py +7 -1
- prompture/drivers/async_groq_driver.py +7 -1
- prompture/drivers/async_lmstudio_driver.py +74 -5
- prompture/drivers/async_ollama_driver.py +13 -3
- prompture/drivers/async_openai_driver.py +7 -1
- prompture/drivers/async_openrouter_driver.py +7 -1
- prompture/drivers/async_registry.py +5 -1
- prompture/drivers/azure_driver.py +7 -1
- prompture/drivers/claude_driver.py +7 -1
- prompture/drivers/google_driver.py +217 -33
- prompture/drivers/grok_driver.py +7 -1
- prompture/drivers/groq_driver.py +7 -1
- prompture/drivers/lmstudio_driver.py +73 -8
- prompture/drivers/ollama_driver.py +16 -5
- prompture/drivers/openai_driver.py +7 -1
- prompture/drivers/openrouter_driver.py +7 -1
- prompture/drivers/vision_helpers.py +153 -0
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/serialization.py +218 -0
- prompture/settings.py +1 -0
- prompture-0.0.38.dev2.dist-info/METADATA +369 -0
- prompture-0.0.38.dev2.dist-info/RECORD +77 -0
- prompture-0.0.35.dist-info/METADATA +0 -464
- prompture-0.0.35.dist-info/RECORD +0 -66
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
prompture/agent.py
ADDED
|
@@ -0,0 +1,924 @@
|
|
|
1
|
+
"""Agent framework for Prompture.
|
|
2
|
+
|
|
3
|
+
Provides a reusable :class:`Agent` that wraps a ReAct-style loop around
|
|
4
|
+
:class:`~prompture.conversation.Conversation`, with optional structured
|
|
5
|
+
output via Pydantic models and tool support via :class:`ToolRegistry`.
|
|
6
|
+
|
|
7
|
+
Example::
|
|
8
|
+
|
|
9
|
+
from prompture import Agent
|
|
10
|
+
|
|
11
|
+
agent = Agent("openai/gpt-4o", system_prompt="You are a helpful assistant.")
|
|
12
|
+
result = agent.run("What is the capital of France?")
|
|
13
|
+
print(result.output)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import inspect
|
|
19
|
+
import json
|
|
20
|
+
import logging
|
|
21
|
+
import time
|
|
22
|
+
import typing
|
|
23
|
+
from collections.abc import Callable, Generator, Iterator
|
|
24
|
+
from typing import Any, Generic
|
|
25
|
+
|
|
26
|
+
from pydantic import BaseModel
|
|
27
|
+
|
|
28
|
+
from .agent_types import (
|
|
29
|
+
AgentCallbacks,
|
|
30
|
+
AgentResult,
|
|
31
|
+
AgentState,
|
|
32
|
+
AgentStep,
|
|
33
|
+
DepsType,
|
|
34
|
+
ModelRetry,
|
|
35
|
+
RunContext,
|
|
36
|
+
StepType,
|
|
37
|
+
StreamEvent,
|
|
38
|
+
StreamEventType,
|
|
39
|
+
)
|
|
40
|
+
from .callbacks import DriverCallbacks
|
|
41
|
+
from .conversation import Conversation
|
|
42
|
+
from .driver import Driver
|
|
43
|
+
from .persona import Persona
|
|
44
|
+
from .session import UsageSession
|
|
45
|
+
from .tools import clean_json_text
|
|
46
|
+
from .tools_schema import ToolDefinition, ToolRegistry
|
|
47
|
+
|
|
48
|
+
# Logger used by every helper and method in this module.
logger: logging.Logger = logging.getLogger("prompture.agent")

# Total number of parse attempts Agent._parse_output makes for a reply that
# must validate against ``output_type`` (the first attempt plus up to N-1
# re-prompts).
_OUTPUT_PARSE_MAX_RETRIES: int = 3

# Maximum number of evaluations of each output guardrail during one run; a
# ``ModelRetry`` raised on the last evaluation is turned into ``ValueError``.
_OUTPUT_GUARDRAIL_MAX_RETRIES: int = 3
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# ------------------------------------------------------------------
|
|
55
|
+
# Module-level helpers for RunContext injection
|
|
56
|
+
# ------------------------------------------------------------------
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _tool_wants_context(fn: Callable[..., Any]) -> bool:
    """Return ``True`` when the first non-``self`` parameter of *fn* is a :class:`RunContext`.

    Resolution order:

    1. :func:`typing.get_type_hints` (with ``include_extras=True``), which
       resolves the string annotations produced by
       ``from __future__ import annotations``.
    2. The raw annotation from :func:`inspect.signature`. It is used when
       the hints lookup raises, for example on unresolvable local or forward
       references, or when the lookup has no entry for the parameter.

    A string annotation matches when it is exactly ``"RunContext"`` or
    starts with ``"RunContext["``. A resolved annotation matches when it is
    ``RunContext`` itself or an alias whose ``__origin__`` is ``RunContext``
    (e.g. ``RunContext[MyDeps]``).
    """
    signature = inspect.signature(fn)
    names = list(signature.parameters)

    # Skip a single leading ``self`` so plain (unbound) methods work too.
    if names[:1] == ["self"]:
        names = names[1:]
    if not names:
        return False
    target = names[0]

    try:
        resolved = typing.get_type_hints(fn, include_extras=True).get(target)
    except Exception:
        # Local/forward references may not resolve; use the raw annotation below.
        resolved = None

    if resolved is None:
        resolved = signature.parameters[target].annotation
        if resolved is inspect.Parameter.empty:
            return False

    if isinstance(resolved, str):
        return resolved == "RunContext" or resolved.startswith("RunContext[")

    return resolved is RunContext or getattr(resolved, "__origin__", None) is RunContext
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _get_first_param_name(fn: Callable[..., Any]) -> str:
    """Return the name of the first parameter of *fn* that is not ``self``.

    The tool wrapper uses this to find which JSON-schema property belongs to
    an injected :class:`RunContext` argument. Returns ``""`` when *fn* has no
    parameters, or when every parameter is named ``self``.
    """
    parameter_names = inspect.signature(fn).parameters
    return next((candidate for candidate in parameter_names if candidate != "self"), "")
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# ------------------------------------------------------------------
|
|
116
|
+
# Agent
|
|
117
|
+
# ------------------------------------------------------------------
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class Agent(Generic[DepsType]):
|
|
121
|
+
"""A reusable agent that executes a ReAct loop with tool support.
|
|
122
|
+
|
|
123
|
+
Each call to :meth:`run` creates a fresh :class:`Conversation`,
|
|
124
|
+
preventing state leakage between runs. The Agent itself is a
|
|
125
|
+
template holding model config, tools, and system prompt.
|
|
126
|
+
|
|
127
|
+
Args:
|
|
128
|
+
model: Model string in ``"provider/model"`` format.
|
|
129
|
+
driver: Pre-built driver instance (useful for testing).
|
|
130
|
+
tools: Initial tools as a list of callables or a
|
|
131
|
+
:class:`ToolRegistry`.
|
|
132
|
+
system_prompt: System prompt prepended to every run. May also
|
|
133
|
+
be a callable ``(RunContext) -> str`` for dynamic prompts.
|
|
134
|
+
output_type: Optional Pydantic model class. When set, the
|
|
135
|
+
final LLM response is parsed and validated against this type.
|
|
136
|
+
max_iterations: Maximum tool-use rounds per run.
|
|
137
|
+
max_cost: Soft budget in USD. When exceeded, output parse and
|
|
138
|
+
guardrail retries are skipped.
|
|
139
|
+
options: Extra driver options forwarded to every call.
|
|
140
|
+
deps_type: Type hint for dependencies (for docs/IDE only).
|
|
141
|
+
agent_callbacks: Agent-level observability callbacks.
|
|
142
|
+
input_guardrails: Functions called before the prompt is sent.
|
|
143
|
+
output_guardrails: Functions called after output is parsed.
|
|
144
|
+
"""
|
|
145
|
+
|
|
146
|
+
def __init__(
    self,
    model: str = "",
    *,
    driver: Driver | None = None,
    tools: list[Callable[..., Any]] | ToolRegistry | None = None,
    system_prompt: str | Persona | Callable[..., str] | None = None,
    output_type: type[BaseModel] | None = None,
    max_iterations: int = 10,
    max_cost: float | None = None,
    options: dict[str, Any] | None = None,
    deps_type: type | None = None,
    agent_callbacks: AgentCallbacks | None = None,
    input_guardrails: list[Callable[..., Any]] | None = None,
    output_guardrails: list[Callable[..., Any]] | None = None,
    name: str = "",
    description: str = "",
    output_key: str | None = None,
) -> None:
    """Store the reusable agent configuration (no LLM call happens here).

    A :class:`ToolRegistry` passed as *tools* is used as-is, not copied.
    Later :meth:`tool` registrations therefore land in that same registry.
    A list of callables is registered into a new, private registry.
    *options* and the guardrail lists are shallow-copied, so the caller can
    keep mutating its own objects.

    Raises:
        ValueError: If neither *model* nor *driver* is supplied.
    """
    if driver is None and not model:
        raise ValueError("Either model or driver must be provided")

    # Model / driver selection.
    self._model = model
    self._driver = driver

    # Prompting and structured-output settings.
    self._system_prompt = system_prompt
    self._output_type = output_type

    # Per-run limits.
    self._max_iterations = max_iterations
    self._max_cost = max_cost

    # Defensive copies of caller-owned containers.
    self._options = dict(options or {})
    self._deps_type = deps_type
    self._agent_callbacks = agent_callbacks or AgentCallbacks()
    self._input_guardrails = list(input_guardrails or ())
    self._output_guardrails = list(output_guardrails or ())

    # Public metadata; name/description also serve as as_tool() defaults.
    self.name = name
    self.description = description
    self.output_key = output_key

    if isinstance(tools, ToolRegistry):
        self._tools = tools
    else:
        self._tools = ToolRegistry()
        for fn in tools or ():
            self._tools.register(fn)

    self._state = AgentState.idle
    self._stop_requested = False
|
|
193
|
+
|
|
194
|
+
# ------------------------------------------------------------------
|
|
195
|
+
# Public API
|
|
196
|
+
# ------------------------------------------------------------------
|
|
197
|
+
|
|
198
|
+
def tool(self, fn: Callable[..., Any]) -> Callable[..., Any]:
    """Register *fn* as one of this agent's tools; usable as a decorator.

    Example::

        @agent.tool
        def add(a: int, b: int) -> int:
            \"\"\"Add two integers.\"\"\"
            return a + b

    Returns:
        *fn* itself, unmodified, so the decorated name still refers to the
        plain function.
    """
    registry = self._tools
    registry.register(fn)
    return fn
|
|
205
|
+
|
|
206
|
+
@property
def state(self) -> AgentState:
    """The agent's lifecycle state, e.g. ``AgentState.idle``, ``running`` or ``errored``."""
    current = self._state
    return current
|
|
210
|
+
|
|
211
|
+
def stop(self) -> None:
    """Ask the agent to shut down gracefully after the current iteration.

    This only sets an internal flag. The flag is cleared again at the start
    of each ``run``/``iter``/``run_stream`` call.

    NOTE(review): none of the loop code visible next to this method reads
    the flag; the tool loop itself runs inside ``Conversation.ask``.
    Confirm where, if anywhere, it is honoured.
    """
    self._stop_requested = bool(True)
|
|
214
|
+
|
|
215
|
+
def as_tool(
    self,
    name: str | None = None,
    description: str | None = None,
    custom_output_extractor: Callable[[AgentResult], str] | None = None,
) -> ToolDefinition:
    """Expose this agent as a :class:`ToolDefinition` that another agent can call.

    The generated tool takes one required string argument, ``prompt``.
    Calling it runs :meth:`run` on this agent (without ``deps``) and returns
    a string.

    Args:
        name: Tool name. Falls back to ``self.name``, then ``"agent_tool"``.
        description: Tool description. Falls back to ``self.description``,
            then ``"Run agent <name>"``.
        custom_output_extractor: Turns the inner :class:`AgentResult` into
            the returned string. When omitted, ``result.output_text`` is used.

    Returns:
        A :class:`ToolDefinition` wrapping this agent.
    """
    resolved_name = name or self.name or "agent_tool"
    resolved_desc = description or self.description or f"Run agent {resolved_name}"
    wrapped_agent = self
    extract = custom_output_extractor

    def _call_agent(prompt: str) -> str:
        """Run the wrapped agent with the given prompt."""
        outcome = wrapped_agent.run(prompt)
        return outcome.output_text if extract is None else extract(outcome)

    input_schema = {
        "type": "object",
        "properties": {
            "prompt": {"type": "string", "description": "The prompt to send to the agent"},
        },
        "required": ["prompt"],
    }
    return ToolDefinition(
        name=resolved_name,
        description=resolved_desc,
        parameters=input_schema,
        function=_call_agent,
    )
|
|
256
|
+
|
|
257
|
+
def run(self, prompt: str, *, deps: Any = None) -> AgentResult:
    """Run the agent on *prompt* and return the finished result.

    Each call builds a new :class:`Conversation`, so no history leaks
    between runs. Tool calls are resolved inside the conversation. When
    ``output_type`` is configured, the final reply is parsed into it.

    Args:
        prompt: The user message to send.
        deps: Arbitrary dependencies, exposed through :class:`RunContext` to
            context-aware tools, guardrails and callable system prompts.

    Returns:
        The :class:`AgentResult` of this run.

    Raises:
        Exception: Whatever the run raises is re-raised after the agent
            state is set to ``AgentState.errored``.
    """
    self._state = AgentState.running
    self._stop_requested = False
    collected_steps: list[AgentStep] = []

    try:
        outcome = self._execute(prompt, collected_steps, deps)
    except Exception:
        self._state = AgentState.errored
        raise

    self._state = AgentState.idle
    return outcome
|
|
279
|
+
|
|
280
|
+
# ------------------------------------------------------------------
|
|
281
|
+
# RunContext helpers
|
|
282
|
+
# ------------------------------------------------------------------
|
|
283
|
+
|
|
284
|
+
def _build_run_context(
    self,
    prompt: str,
    deps: Any,
    session: UsageSession,
    messages: list[dict[str, Any]],
    iteration: int,
) -> RunContext[Any]:
    """Package the current run state into a :class:`RunContext`.

    *messages* is copied into a new list, so later appends by the caller
    are not reflected. ``usage`` is whatever ``session.summary()`` returns
    at the moment of the call.
    """
    history_snapshot = list(messages)
    return RunContext(
        prompt=prompt,
        deps=deps,
        model=self._model,
        iteration=iteration,
        messages=history_snapshot,
        usage=session.summary(),
    )
|
|
301
|
+
|
|
302
|
+
# ------------------------------------------------------------------
|
|
303
|
+
# Tool wrapping (RunContext injection + ModelRetry + callbacks)
|
|
304
|
+
# ------------------------------------------------------------------
|
|
305
|
+
|
|
306
|
+
def _wrap_tools_with_context(self, ctx: RunContext[Any]) -> ToolRegistry:
    """Build a per-run :class:`ToolRegistry` whose tools are bound to *ctx*.

    Each registered tool is replaced by a wrapper that:

    - passes *ctx* as the first positional argument when the tool's first
      parameter is annotated as :class:`RunContext`. The same *ctx* object
      is used for every call in the run;
    - converts a :class:`ModelRetry` raised by the tool into the returned
      string ``"Error: <message>"``;
    - calls ``agent_callbacks.on_tool_start`` before the tool and
      ``on_tool_end`` after it. ``on_tool_end`` is skipped when the tool
      raises anything other than ``ModelRetry``.

    The injected parameter is removed from the tool's JSON schema
    (``properties`` and ``required``), so the LLM never has to supply it.
    The agent's own registry is not modified.
    """
    wrapped_registry = ToolRegistry()
    if not self._tools:
        return wrapped_registry

    callbacks = self._agent_callbacks

    def _bind(target: Callable[..., Any], inject_ctx: bool, label: str) -> Callable[..., Any]:
        # A factory, so each wrapper captures its own tool rather than the
        # loop variables' final values.
        def wrapper(**kwargs: Any) -> Any:
            if callbacks.on_tool_start:
                callbacks.on_tool_start(label, kwargs)
            try:
                outcome = target(ctx, **kwargs) if inject_ctx else target(**kwargs)
            except ModelRetry as exc:
                # Hand the retry request back to the LLM as the tool's output.
                outcome = f"Error: {exc.message}"
            if callbacks.on_tool_end:
                callbacks.on_tool_end(label, outcome)
            return outcome

        return wrapper

    for definition in self._tools.definitions:
        inject_ctx = _tool_wants_context(definition.function)
        # Shallow copy: only top-level keys are replaced below.
        schema = dict(definition.parameters)

        if inject_ctx:
            hidden = _get_first_param_name(definition.function)

            properties = dict(schema.get("properties", {}))
            properties.pop(hidden, None)
            schema["properties"] = properties

            required = list(schema.get("required", []))
            if hidden in required:
                required.remove(hidden)
            if required:
                schema["required"] = required
            else:
                schema.pop("required", None)

        wrapped_registry.add(
            ToolDefinition(
                name=definition.name,
                description=definition.description,
                parameters=schema,
                function=_bind(definition.function, inject_ctx, definition.name),
            )
        )

    return wrapped_registry
|
|
376
|
+
|
|
377
|
+
# ------------------------------------------------------------------
|
|
378
|
+
# Guardrails
|
|
379
|
+
# ------------------------------------------------------------------
|
|
380
|
+
|
|
381
|
+
def _run_input_guardrails(self, ctx: RunContext[Any], prompt: str) -> str:
    """Pass *prompt* through the input guardrails, in registration order.

    Each guardrail is called as ``guardrail(ctx, prompt)``. It receives the
    prompt as already transformed by the earlier guardrails.

    - Returning ``None`` keeps the prompt unchanged.
    - Any other return value replaces it.
    - Exceptions are not caught and propagate to the caller (e.g. a
      ``GuardrailError`` raised to reject the input).

    Returns:
        The prompt after every guardrail has run.
    """
    current = prompt
    for check in self._input_guardrails:
        replacement = check(ctx, current)
        current = current if replacement is None else replacement
    return current
|
|
394
|
+
|
|
395
|
+
def _run_output_guardrails(
    self,
    ctx: RunContext[Any],
    result: AgentResult,
    conv: Conversation,
    session: UsageSession,
    steps: list[AgentStep],
    all_tool_calls: list[dict[str, Any]],
) -> AgentResult:
    """Run the output guardrails over *result*, re-prompting the LLM on ``ModelRetry``.

    Guardrails run in registration order. Each one sees the result left by
    the previous guardrails. A guardrail is called as
    ``guardrail(ctx, result)`` and may:

    - return ``None`` to accept the result unchanged;
    - return an :class:`AgentResult` to replace it;
    - raise :class:`ModelRetry` to have the LLM try again. The error message
      is sent through *conv*. The new reply is re-parsed into
      ``output_type`` when one is set; if parsing fails, the raw text is
      kept. The guardrail is then evaluated again.

    Each guardrail is evaluated at most ``_OUTPUT_GUARDRAIL_MAX_RETRIES``
    times. When the run is over budget (:meth:`_is_over_budget`), a
    ``ModelRetry`` is not retried and the current result is kept.

    Every message a retry exchange adds is turned into steps and appended
    to *steps* / *all_tool_calls*. This includes any tool-call rounds the
    LLM performs while answering the retry.

    Returns:
        The final, possibly replaced, :class:`AgentResult`.

    Raises:
        ValueError: If a guardrail still raises ``ModelRetry`` on its last
            allowed evaluation.
    """
    for guardrail in self._output_guardrails:
        for attempt in range(_OUTPUT_GUARDRAIL_MAX_RETRIES):
            try:
                guard_result = guardrail(ctx, result)
                if guard_result is not None:
                    result = guard_result
                break  # guardrail passed
            except ModelRetry as exc:
                if self._is_over_budget(session):
                    logger.debug("Over budget, skipping output guardrail retry")
                    break
                if attempt >= _OUTPUT_GUARDRAIL_MAX_RETRIES - 1:
                    raise ValueError(
                        f"Output guardrail failed after {_OUTPUT_GUARDRAIL_MAX_RETRIES} retries: {exc.message}"
                    ) from exc

                # Note where the history ended, so every message this retry adds
                # (user retry prompt, any tool-call rounds, final reply) is recorded.
                # Previously only the last two messages were kept, which dropped
                # tool calls made during the retry.
                # NOTE(review): assumes Conversation only appends to ``messages``.
                history_len = len(conv.messages)
                retry_text = conv.ask(
                    f"Your response did not pass validation. Error: {exc.message}\n\nPlease try again."
                )
                self._extract_steps(conv.messages[history_len:], steps, all_tool_calls)

                # Re-parse when a structured output type is configured.
                output: Any = retry_text
                if self._output_type is not None:
                    try:
                        parsed = json.loads(clean_json_text(retry_text))
                        output = self._output_type.model_validate(parsed)
                    except Exception:
                        # Keep the raw text; the guardrail re-evaluates it next.
                        output = retry_text

                result = AgentResult(
                    output=output,
                    output_text=retry_text,
                    messages=conv.messages,
                    usage=conv.usage,
                    steps=steps,
                    all_tool_calls=all_tool_calls,
                    state=AgentState.idle,
                    run_usage=session.summary(),
                )
    return result
|
|
454
|
+
|
|
455
|
+
# ------------------------------------------------------------------
|
|
456
|
+
# Budget check
|
|
457
|
+
# ------------------------------------------------------------------
|
|
458
|
+
|
|
459
|
+
def _is_over_budget(self, session: UsageSession) -> bool:
    """Return whether this run has spent at least ``max_cost`` (USD).

    The result is always ``False`` when no ``max_cost`` was configured.
    Otherwise ``session.total_cost`` is compared with the limit, and
    reaching the limit exactly counts as over budget.
    """
    limit = self._max_cost
    return limit is not None and session.total_cost >= limit
|
|
464
|
+
|
|
465
|
+
# ------------------------------------------------------------------
|
|
466
|
+
# Internals
|
|
467
|
+
# ------------------------------------------------------------------
|
|
468
|
+
|
|
469
|
+
def _resolve_system_prompt(self, ctx: RunContext[Any] | None = None) -> str | None:
    """Assemble the system prompt text for a run.

    The configured ``system_prompt`` becomes text as follows:

    - :class:`Persona`: rendered. When *ctx* is given, its ``model`` and
      ``iteration`` are passed as render variables.
    - any other callable that is not a ``str``: called with *ctx*, which
      may be ``None``.
    - anything else that is not ``None``: converted with ``str()``.

    When ``output_type`` is set, an instruction embedding the model's JSON
    schema is appended. Sections are separated by a blank line.

    Returns:
        The combined prompt, or ``None`` when there is neither a system
        prompt nor an ``output_type``.
    """
    sections: list[str] = []
    base = self._system_prompt

    if isinstance(base, Persona):
        variables: dict[str, Any] = (
            {} if ctx is None else {"model": ctx.model, "iteration": ctx.iteration}
        )
        sections.append(base.render(**variables))
    elif base is not None and callable(base) and not isinstance(base, str):
        # ctx is None only when called outside a run (e.g. _build_conversation's fallback).
        sections.append(base(ctx))
    elif base is not None:
        sections.append(str(base))

    if self._output_type is not None:
        schema_text = json.dumps(self._output_type.model_json_schema(), indent=2)
        sections.append(
            "You MUST respond with a single JSON object (no markdown, "
            "no extra text) that validates against this JSON schema:\n"
            f"{schema_text}\n\n"
            "Use double quotes for keys and strings. "
            "If a value is unknown use null."
        )

    return "\n\n".join(sections) if sections else None
|
|
502
|
+
|
|
503
|
+
def _build_conversation(
    self,
    system_prompt: str | None = None,
    tools: ToolRegistry | None = None,
    driver_callbacks: DriverCallbacks | None = None,
) -> Conversation:
    """Create a new :class:`Conversation` for a single run.

    Args:
        system_prompt: Prompt text to use. When ``None``, it is resolved via
            :meth:`_resolve_system_prompt` without a run context.
        tools: Registry to expose. When ``None``, the agent's own registry
            is used if it is non-empty; otherwise no tools are passed.
        driver_callbacks: Forwarded as ``callbacks`` when given.

    The conversation uses the agent's pre-built driver when there is one,
    otherwise the ``model`` string. ``max_iterations`` is passed as
    ``max_tool_rounds``, and non-empty ``options`` are forwarded.
    """
    conv_kwargs: dict[str, Any] = {
        "system_prompt": self._resolve_system_prompt() if system_prompt is None else system_prompt,
        "tools": tools if tools is not None else (self._tools or None),
        "max_tool_rounds": self._max_iterations,
    }
    if self._options:
        conv_kwargs["options"] = self._options
    if driver_callbacks is not None:
        conv_kwargs["callbacks"] = driver_callbacks

    if self._driver is None:
        conv_kwargs["model_name"] = self._model
    else:
        conv_kwargs["driver"] = self._driver

    return Conversation(**conv_kwargs)
|
|
528
|
+
|
|
529
|
+
def _execute(self, prompt: str, steps: list[AgentStep], deps: Any) -> AgentResult:
    """Perform one complete agent run and build its :class:`AgentResult`.

    Order of operations:

    1. Create a :class:`UsageSession` wired into the driver callbacks, so
       each driver response and error is recorded.
    2. Build the iteration-0 :class:`RunContext` and apply the input
       guardrails.
    3. Resolve the system prompt and bind the tools to the context.
    4. Fire ``on_iteration(0)`` and send the prompt. ``Conversation.ask``
       drives the whole tool loop.
    5. Turn the conversation messages into *steps* and tool calls, parse
       ``output_type`` if configured, then apply the output guardrails.
    6. Fire ``on_step`` for every step and ``on_output`` for the result.

    Args:
        prompt: The raw user prompt, before input guardrails.
        steps: List filled in place with this run's steps.
        deps: Dependencies exposed through the :class:`RunContext`.
    """
    usage_session = UsageSession()
    callbacks_for_driver = DriverCallbacks(
        on_response=usage_session.record,
        on_error=usage_session.record_error,
    )

    run_ctx = self._build_run_context(prompt, deps, usage_session, [], 0)
    guarded_prompt = self._run_input_guardrails(run_ctx, prompt)
    system_text = self._resolve_system_prompt(run_ctx)
    ctx_tools = self._wrap_tools_with_context(run_ctx)

    conv = self._build_conversation(
        system_prompt=system_text,
        tools=ctx_tools or None,
        driver_callbacks=callbacks_for_driver,
    )

    hooks = self._agent_callbacks
    if hooks.on_iteration:
        hooks.on_iteration(0)

    started = time.perf_counter()
    reply = conv.ask(guarded_prompt)
    elapsed_ms = (time.perf_counter() - started) * 1000

    tool_calls: list[dict[str, Any]] = []
    self._extract_steps(conv.messages, steps, tool_calls)

    if self._output_type is None:
        output, output_text = reply, reply
    else:
        output, output_text = self._parse_output(conv, reply, steps, tool_calls, elapsed_ms, usage_session)

    result = AgentResult(
        output=output,
        output_text=output_text,
        messages=conv.messages,
        usage=conv.usage,
        steps=steps,
        all_tool_calls=tool_calls,
        state=AgentState.idle,
        run_usage=usage_session.summary(),
    )

    if self._output_guardrails:
        result = self._run_output_guardrails(run_ctx, result, conv, usage_session, steps, tool_calls)

    if hooks.on_step:
        for step in steps:
            hooks.on_step(step)
    if hooks.on_output:
        hooks.on_output(result)

    return result
|
|
602
|
+
|
|
603
|
+
def _extract_steps(
    self,
    messages: list[dict[str, Any]],
    steps: list[AgentStep],
    all_tool_calls: list[dict[str, Any]],
) -> None:
    """Append steps and tool calls derived from *messages*.

    - An assistant message with ``tool_calls`` yields one ``tool_call`` step
      and one *all_tool_calls* entry per call. String arguments are
      JSON-decoded, falling back to ``{}`` when they are not valid JSON.
    - An assistant message without tool calls yields an ``output`` step.
    - A ``tool`` message yields a ``tool_result`` step. Its ``tool_name``
      is chosen in this order: the message's ``name`` if present; else the
      name of the matching tool call seen earlier in *messages*; else the
      raw ``tool_call_id``. Previously the id was always stored there.

    Messages with other roles (e.g. ``user``) are ignored. All steps
    created by one call share the same timestamp.
    """
    now = time.time()
    # tool_call id -> function name, so later tool-result messages can be
    # labelled with the tool's name rather than its call id.
    names_by_call_id: dict[str, str] = {}

    for msg in messages:
        role = msg.get("role", "")

        if role == "assistant":
            tc_list = msg.get("tool_calls", [])
            if tc_list:
                # Assistant message with tool calls
                for tc in tc_list:
                    # ``or {}`` also tolerates an explicit ``"function": None``.
                    fn = tc.get("function") or {}
                    name = fn.get("name", tc.get("name", ""))
                    raw_args = fn.get("arguments", tc.get("arguments", "{}"))
                    if isinstance(raw_args, str):
                        try:
                            args = json.loads(raw_args)
                        except json.JSONDecodeError:
                            args = {}
                    else:
                        args = raw_args

                    call_id = tc.get("id", "")
                    if call_id:
                        names_by_call_id[call_id] = name

                    steps.append(
                        AgentStep(
                            step_type=StepType.tool_call,
                            timestamp=now,
                            content=msg.get("content", ""),
                            tool_name=name,
                            tool_args=args,
                        )
                    )
                    all_tool_calls.append({"name": name, "arguments": args, "id": call_id})
            else:
                # Final assistant message (no tool calls)
                steps.append(
                    AgentStep(
                        step_type=StepType.output,
                        timestamp=now,
                        content=msg.get("content", ""),
                    )
                )

        elif role == "tool":
            call_id = msg.get("tool_call_id")
            tool_name = msg.get("name") or names_by_call_id.get(call_id) or call_id
            steps.append(
                AgentStep(
                    step_type=StepType.tool_result,
                    timestamp=now,
                    content=msg.get("content", ""),
                    tool_name=tool_name,
                )
            )
|
|
660
|
+
|
|
661
|
+
def _parse_output(
    self,
    conv: Conversation,
    response_text: str,
    steps: list[AgentStep],
    all_tool_calls: list[dict[str, Any]],
    elapsed_ms: float,
    session: UsageSession | None = None,
) -> tuple[Any, str]:
    """Parse *response_text* into ``output_type``, re-prompting the LLM on failure.

    Each attempt cleans the text with :func:`clean_json_text`, JSON-decodes
    it and validates it with ``output_type.model_validate``. On any failure
    the LLM is asked, through *conv*, to answer again with valid JSON. At
    most ``_OUTPUT_PARSE_MAX_RETRIES`` attempts are made in total. Retries
    stop early when *session* is given and the run is over budget. Every
    message a retry exchange adds, including any tool-call rounds, is
    recorded into *steps* / *all_tool_calls*.

    Args:
        conv: Conversation used to send retry prompts.
        response_text: The LLM reply to parse first.
        steps: Run steps, extended in place with retry steps.
        all_tool_calls: Run tool calls, extended in place.
        elapsed_ms: Accepted for interface compatibility; not used here.
        session: Usage session for the budget check, if any.

    Returns:
        ``(model_instance, text)``, where *text* is the reply that parsed.

    Raises:
        ValueError: If no attempt produced a valid instance. The message
            reports the number of attempts actually made.
    """
    assert self._output_type is not None
    output_type = self._output_type

    last_error: Exception | None = None
    text = response_text
    attempts_made = 0

    for attempt in range(_OUTPUT_PARSE_MAX_RETRIES):
        attempts_made = attempt + 1
        try:
            parsed = json.loads(clean_json_text(text))
            return output_type.model_validate(parsed), text
        except Exception as exc:
            last_error = exc

        if attempt >= _OUTPUT_PARSE_MAX_RETRIES - 1:
            break
        # Check budget before spending another LLM call on a retry.
        if session is not None and self._is_over_budget(session):
            logger.debug("Over budget, skipping output parse retry")
            break

        logger.debug("Output parse attempt %d failed: %s", attempt + 1, last_error)
        retry_msg = (
            "Your previous response could not be parsed as valid JSON "
            f"matching the required schema. Error: {last_error}\n\n"
            "Please try again and respond ONLY with valid JSON."
        )
        # Record every message this retry adds, not just the last two, so tool
        # calls made while answering the retry are not lost.
        # NOTE(review): assumes Conversation only appends to ``messages``.
        history_len = len(conv.messages)
        text = conv.ask(retry_msg)
        self._extract_steps(conv.messages[history_len:], steps, all_tool_calls)

    raise ValueError(
        f"Failed to parse output as {output_type.__name__} "
        f"after {attempts_made} attempts: {last_error}"
    )
|
|
704
|
+
|
|
705
|
+
# ------------------------------------------------------------------
|
|
706
|
+
# iter() — step-by-step inspection
|
|
707
|
+
# ------------------------------------------------------------------
|
|
708
|
+
|
|
709
|
+
def iter(self, prompt: str, *, deps: Any = None) -> AgentIterator:
    """Run the agent and return an iterator over its :class:`AgentStep` objects.

    Nothing executes until iteration begins. After iteration completes, the
    run's :class:`AgentResult` is available as
    :attr:`AgentIterator.result`.

    Note:
        In Phase 3c the whole run, including the conversation's tool loop,
        finishes before the first step is yielded. True mid-loop yielding is
        deferred.
    """
    return AgentIterator(self._execute_iter(prompt, deps))
|
|
722
|
+
|
|
723
|
+
def _execute_iter(self, prompt: str, deps: Any) -> Generator[AgentStep, None, AgentResult]:
    """Generator backing :meth:`iter`: run the agent, then yield each step.

    The whole agent loop is executed by :meth:`_execute` before the first
    step is yielded; the steps stored on the resulting
    :class:`AgentResult` are then replayed one at a time.

    Args:
        prompt: The user prompt for this run.
        deps: Optional dependencies forwarded to :meth:`_execute`.

    Yields:
        Each :class:`AgentStep` from ``result.steps``, in order.

    Returns:
        The :class:`AgentResult` produced by :meth:`_execute`, delivered as
        the generator's return value (``StopIteration.value``).

    Raises:
        Exception: Anything raised by :meth:`_execute` is re-raised after
            the agent state has been set to ``errored``.
    """
    self._state = AgentState.running
    self._stop_requested = False
    steps: list[AgentStep] = []

    try:
        result = self._execute(prompt, steps, deps)
        # Yield each step one at a time
        yield from result.steps
    except GeneratorExit:
        # The consumer stopped iterating early (``break``, or the iterator
        # being closed / garbage-collected). GeneratorExit is not an
        # ``Exception`` subclass, so without this branch the agent would be
        # left reporting ``running`` even though the run finished above.
        self._state = AgentState.idle
        raise
    except Exception:
        self._state = AgentState.errored
        raise

    self._state = AgentState.idle
    return result
|
|
738
|
+
|
|
739
|
+
# ------------------------------------------------------------------
# run_stream() — streaming output
# ------------------------------------------------------------------

def run_stream(self, prompt: str, *, deps: Any = None) -> StreamedAgentResult:
    """Run the agent and expose its output as a stream of events.

    The returned :class:`StreamedAgentResult` produces :class:`StreamEvent`
    objects as the response is generated. Once it has been exhausted, the
    final :class:`AgentResult` is available through
    :attr:`StreamedAgentResult.result`.

    If any tools are registered, token streaming is not used: the reply is
    obtained with a regular ``conv.ask()`` call and emitted as a single
    ``text_delta`` event.

    Because the underlying ``_execute_stream`` is a generator, nothing runs
    until the returned object is first iterated.

    Args:
        prompt: The user prompt for this run.
        deps: Optional dependencies forwarded to the run.
    """
    event_generator = self._execute_stream(prompt, deps)
    streamed = StreamedAgentResult(event_generator)
    return streamed
|
|
757
|
+
|
|
758
|
+
def _execute_stream(self, prompt: str, deps: Any) -> Generator[StreamEvent, None, AgentResult]:
    """Generator backing :meth:`run_stream`: run the agent and yield stream events.

    Flow:
      1. Create a per-run :class:`UsageSession` wired into driver callbacks.
      2. Build the run context.
      3. Apply input guardrails.
      4. Resolve the system prompt.
      5. Wrap the tools with the context.
      6. Build the conversation.
      7. Produce the reply:
         - no tools registered: stream it chunk by chunk via ``ask_stream``;
         - tools registered: fall back to one non-streaming ``ask`` call.
      8. Extract steps.
      9. Parse or validate the output (when an output type is set).
      10. Build the :class:`AgentResult`.
      11. Run output guardrails.
      12. Fire callbacks.
      13. Yield a final ``output`` event.

    Args:
        prompt: The user prompt for this run.
        deps: Optional dependencies passed to the run-context builder.

    Yields:
        ``text_delta`` :class:`StreamEvent` objects carrying response text,
        followed by one ``output`` event whose ``data`` is the final
        :class:`AgentResult`.

    Returns:
        The final :class:`AgentResult`, as the generator's return value.

    Raises:
        Exception: Anything raised during the run is re-raised after the
            agent state has been set to ``errored``.
    """
    self._state = AgentState.running
    self._stop_requested = False
    steps: list[AgentStep] = []

    try:
        # 1. Create per-run UsageSession and wire into DriverCallbacks
        session = UsageSession()
        driver_callbacks = DriverCallbacks(
            on_response=session.record,
            on_error=session.record_error,
        )

        # 2. Build initial RunContext
        ctx = self._build_run_context(prompt, deps, session, [], 0)

        # 3. Run input guardrails
        effective_prompt = self._run_input_guardrails(ctx, prompt)

        # 4. Resolve system prompt
        resolved_system_prompt = self._resolve_system_prompt(ctx)

        # 5. Wrap tools with context
        wrapped_tools = self._wrap_tools_with_context(ctx)
        has_tools = bool(wrapped_tools)

        # 6. Build Conversation
        conv = self._build_conversation(
            system_prompt=resolved_system_prompt,
            tools=wrapped_tools if wrapped_tools else None,
            driver_callbacks=driver_callbacks,
        )

        # 7. Fire on_iteration callback
        if self._agent_callbacks.on_iteration:
            self._agent_callbacks.on_iteration(0)

        if has_tools:
            # Tools registered: fall back to non-streaming conv.ask()
            response_text = conv.ask(effective_prompt)

            # Yield the full text as a single delta
            yield StreamEvent(
                event_type=StreamEventType.text_delta,
                data=response_text,
            )
        else:
            # No tools: use streaming if available
            response_text = ""
            stream_iter: Iterator[str] = conv.ask_stream(effective_prompt)
            for chunk in stream_iter:
                response_text += chunk
                yield StreamEvent(
                    event_type=StreamEventType.text_delta,
                    data=chunk,
                )

        # 8. Extract steps
        all_tool_calls: list[dict[str, Any]] = []
        self._extract_steps(conv.messages, steps, all_tool_calls)

        # 9. Parse output
        if self._output_type is not None:
            output, output_text = self._parse_output(conv, response_text, steps, all_tool_calls, 0.0, session)
        else:
            output = response_text
            output_text = response_text

        # 10. Build result
        result = AgentResult(
            output=output,
            output_text=output_text,
            messages=conv.messages,
            usage=conv.usage,
            steps=steps,
            all_tool_calls=all_tool_calls,
            state=AgentState.idle,
            run_usage=session.summary(),
        )

        # 11. Run output guardrails
        if self._output_guardrails:
            result = self._run_output_guardrails(ctx, result, conv, session, steps, all_tool_calls)

        # 12. Fire callbacks
        if self._agent_callbacks.on_step:
            for step in steps:
                self._agent_callbacks.on_step(step)
        if self._agent_callbacks.on_output:
            self._agent_callbacks.on_output(result)

        # 13. Yield final output event
        yield StreamEvent(
            event_type=StreamEventType.output,
            data=result,
        )

        self._state = AgentState.idle
        return result
    except GeneratorExit:
        # The consumer stopped iterating early. This can be a ``break``
        # right after the final ``output`` event, a stream abandoned
        # mid-way, or the wrapper being closed / garbage-collected.
        # GeneratorExit is not an ``Exception`` subclass, so without this
        # branch the agent would stay in ``running`` although no run is in
        # progress any more.
        self._state = AgentState.idle
        raise
    except Exception:
        self._state = AgentState.errored
        raise
|
|
861
|
+
|
|
862
|
+
|
|
863
|
+
# ------------------------------------------------------------------
# AgentIterator
# ------------------------------------------------------------------


class AgentIterator:
    """Wraps the :meth:`Agent.iter` generator, capturing the final result.

    After iteration completes (the ``for`` loop ends), the
    :attr:`result` property holds the :class:`AgentResult`. Iterating
    again after exhaustion yields nothing and leaves :attr:`result`
    unchanged.
    """

    def __init__(self, gen: Generator[AgentStep, None, AgentResult]) -> None:
        self._gen = gen
        self._result: AgentResult | None = None
        # Set once the wrapped generator has returned. A later next() on a
        # spent generator raises a bare StopIteration whose value is None;
        # this flag prevents that from overwriting the captured result.
        self._exhausted = False

    def __iter__(self) -> AgentIterator:
        return self

    def __next__(self) -> AgentStep:
        if self._exhausted:
            raise StopIteration
        try:
            return next(self._gen)
        except StopIteration as e:
            # The generator's ``return`` value arrives as StopIteration.value.
            self._exhausted = True
            self._result = e.value
            raise

    @property
    def result(self) -> AgentResult | None:
        """The final :class:`AgentResult`, available after iteration completes."""
        return self._result
|
|
893
|
+
|
|
894
|
+
|
|
895
|
+
# ------------------------------------------------------------------
# StreamedAgentResult
# ------------------------------------------------------------------


class StreamedAgentResult:
    """Wraps the :meth:`Agent.run_stream` generator, capturing the final result.

    Yields :class:`StreamEvent` objects during iteration. After iteration
    completes, the :attr:`result` property holds the :class:`AgentResult`.
    Iterating again after exhaustion yields nothing and leaves
    :attr:`result` unchanged.
    """

    def __init__(self, gen: Generator[StreamEvent, None, AgentResult]) -> None:
        self._gen = gen
        self._result: AgentResult | None = None
        # Set once the wrapped generator has returned. A later next() on a
        # spent generator raises a bare StopIteration whose value is None;
        # this flag prevents that from overwriting the captured result.
        self._exhausted = False

    def __iter__(self) -> StreamedAgentResult:
        return self

    def __next__(self) -> StreamEvent:
        if self._exhausted:
            raise StopIteration
        try:
            return next(self._gen)
        except StopIteration as e:
            # The generator's ``return`` value arrives as StopIteration.value.
            self._exhausted = True
            self._result = e.value
            raise

    @property
    def result(self) -> AgentResult | None:
        """The final :class:`AgentResult`, available after iteration completes."""
        return self._result
|