prompture 0.0.35__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +120 -2
- prompture/_version.py +2 -2
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +199 -17
- prompture/async_driver.py +24 -0
- prompture/async_groups.py +551 -0
- prompture/conversation.py +213 -18
- prompture/core.py +30 -12
- prompture/discovery.py +24 -1
- prompture/driver.py +38 -0
- prompture/drivers/__init__.py +5 -1
- prompture/drivers/async_azure_driver.py +7 -1
- prompture/drivers/async_claude_driver.py +7 -1
- prompture/drivers/async_google_driver.py +212 -28
- prompture/drivers/async_grok_driver.py +7 -1
- prompture/drivers/async_groq_driver.py +7 -1
- prompture/drivers/async_lmstudio_driver.py +74 -5
- prompture/drivers/async_ollama_driver.py +13 -3
- prompture/drivers/async_openai_driver.py +7 -1
- prompture/drivers/async_openrouter_driver.py +7 -1
- prompture/drivers/async_registry.py +5 -1
- prompture/drivers/azure_driver.py +7 -1
- prompture/drivers/claude_driver.py +7 -1
- prompture/drivers/google_driver.py +217 -33
- prompture/drivers/grok_driver.py +7 -1
- prompture/drivers/groq_driver.py +7 -1
- prompture/drivers/lmstudio_driver.py +73 -8
- prompture/drivers/ollama_driver.py +16 -5
- prompture/drivers/openai_driver.py +7 -1
- prompture/drivers/openrouter_driver.py +7 -1
- prompture/drivers/vision_helpers.py +153 -0
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/serialization.py +218 -0
- prompture/settings.py +1 -0
- prompture-0.0.38.dev2.dist-info/METADATA +369 -0
- prompture-0.0.38.dev2.dist-info/RECORD +77 -0
- prompture-0.0.35.dist-info/METADATA +0 -464
- prompture-0.0.35.dist-info/RECORD +0 -66
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
prompture/async_agent.py
ADDED
|
@@ -0,0 +1,880 @@
|
|
|
1
|
+
"""Async Agent framework for Prompture.
|
|
2
|
+
|
|
3
|
+
Provides :class:`AsyncAgent`, the async counterpart of :class:`~prompture.agent.Agent`.
|
|
4
|
+
All methods are ``async`` and use :class:`~prompture.async_conversation.AsyncConversation`.
|
|
5
|
+
|
|
6
|
+
Example::
|
|
7
|
+
|
|
8
|
+
from prompture import AsyncAgent
|
|
9
|
+
|
|
10
|
+
agent = AsyncAgent("openai/gpt-4o", system_prompt="You are helpful.")
|
|
11
|
+
result = await agent.run("What is 2 + 2?")
|
|
12
|
+
print(result.output)
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import inspect
|
|
19
|
+
import json
|
|
20
|
+
import logging
|
|
21
|
+
import time
|
|
22
|
+
import typing
|
|
23
|
+
from collections.abc import AsyncGenerator, Callable
|
|
24
|
+
from typing import Any, Generic
|
|
25
|
+
|
|
26
|
+
from pydantic import BaseModel
|
|
27
|
+
|
|
28
|
+
from .agent_types import (
|
|
29
|
+
AgentCallbacks,
|
|
30
|
+
AgentResult,
|
|
31
|
+
AgentState,
|
|
32
|
+
AgentStep,
|
|
33
|
+
DepsType,
|
|
34
|
+
ModelRetry,
|
|
35
|
+
RunContext,
|
|
36
|
+
StepType,
|
|
37
|
+
StreamEvent,
|
|
38
|
+
StreamEventType,
|
|
39
|
+
)
|
|
40
|
+
from .callbacks import DriverCallbacks
|
|
41
|
+
from .persona import Persona
|
|
42
|
+
from .session import UsageSession
|
|
43
|
+
from .tools import clean_json_text
|
|
44
|
+
from .tools_schema import ToolDefinition, ToolRegistry
|
|
45
|
+
|
|
46
|
+
logger = logging.getLogger("prompture.async_agent")
|
|
47
|
+
|
|
48
|
+
_OUTPUT_PARSE_MAX_RETRIES = 3
|
|
49
|
+
_OUTPUT_GUARDRAIL_MAX_RETRIES = 3
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# ------------------------------------------------------------------
|
|
53
|
+
# Helpers
|
|
54
|
+
# ------------------------------------------------------------------
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _is_async_callable(fn: Callable[..., Any]) -> bool:
|
|
58
|
+
"""Check if *fn* is an async callable (coroutine function or has async ``__call__``)."""
|
|
59
|
+
if asyncio.iscoroutinefunction(fn):
|
|
60
|
+
return True
|
|
61
|
+
# Check if the object has an async __call__ method (callable class)
|
|
62
|
+
dunder_call = type(fn).__call__ if callable(fn) else None
|
|
63
|
+
return dunder_call is not None and asyncio.iscoroutinefunction(dunder_call)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _tool_wants_context(fn: Callable[..., Any]) -> bool:
|
|
67
|
+
"""Check whether *fn*'s first parameter is annotated as :class:`RunContext`."""
|
|
68
|
+
sig = inspect.signature(fn)
|
|
69
|
+
params = list(sig.parameters.keys())
|
|
70
|
+
if not params:
|
|
71
|
+
return False
|
|
72
|
+
|
|
73
|
+
first_param = params[0]
|
|
74
|
+
if first_param == "self":
|
|
75
|
+
if len(params) < 2:
|
|
76
|
+
return False
|
|
77
|
+
first_param = params[1]
|
|
78
|
+
|
|
79
|
+
annotation = None
|
|
80
|
+
try:
|
|
81
|
+
hints = typing.get_type_hints(fn, include_extras=True)
|
|
82
|
+
annotation = hints.get(first_param)
|
|
83
|
+
except Exception:
|
|
84
|
+
# get_type_hints can fail with local/forward references; fall back to raw annotation
|
|
85
|
+
pass
|
|
86
|
+
|
|
87
|
+
if annotation is None:
|
|
88
|
+
raw = sig.parameters[first_param].annotation
|
|
89
|
+
if raw is inspect.Parameter.empty:
|
|
90
|
+
return False
|
|
91
|
+
annotation = raw
|
|
92
|
+
|
|
93
|
+
if isinstance(annotation, str):
|
|
94
|
+
return annotation == "RunContext" or annotation.startswith("RunContext[")
|
|
95
|
+
|
|
96
|
+
if annotation is RunContext:
|
|
97
|
+
return True
|
|
98
|
+
|
|
99
|
+
origin = getattr(annotation, "__origin__", None)
|
|
100
|
+
return origin is RunContext
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _get_first_param_name(fn: Callable[..., Any]) -> str:
|
|
104
|
+
"""Return the name of the first non-self parameter of *fn*."""
|
|
105
|
+
sig = inspect.signature(fn)
|
|
106
|
+
for name, _param in sig.parameters.items():
|
|
107
|
+
if name != "self":
|
|
108
|
+
return name
|
|
109
|
+
return ""
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
# ------------------------------------------------------------------
|
|
113
|
+
# AsyncAgent
|
|
114
|
+
# ------------------------------------------------------------------
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class AsyncAgent(Generic[DepsType]):
|
|
118
|
+
"""Async agent that executes a ReAct loop with tool support.
|
|
119
|
+
|
|
120
|
+
Mirrors :class:`~prompture.agent.Agent` but uses
|
|
121
|
+
:class:`~prompture.async_conversation.AsyncConversation` and
|
|
122
|
+
``async`` methods throughout.
|
|
123
|
+
|
|
124
|
+
Args:
|
|
125
|
+
model: Model string in ``"provider/model"`` format.
|
|
126
|
+
driver: Pre-built async driver instance.
|
|
127
|
+
tools: Initial tools as a list of callables or a :class:`ToolRegistry`.
|
|
128
|
+
system_prompt: System prompt prepended to every run. May also be a
|
|
129
|
+
callable ``(RunContext) -> str`` for dynamic prompts.
|
|
130
|
+
output_type: Optional Pydantic model class for structured output.
|
|
131
|
+
max_iterations: Maximum tool-use rounds per run.
|
|
132
|
+
max_cost: Soft budget in USD.
|
|
133
|
+
options: Extra driver options forwarded to every call.
|
|
134
|
+
deps_type: Type hint for dependencies.
|
|
135
|
+
agent_callbacks: Agent-level observability callbacks.
|
|
136
|
+
input_guardrails: Functions called before the prompt is sent.
|
|
137
|
+
output_guardrails: Functions called after output is parsed.
|
|
138
|
+
"""
|
|
139
|
+
|
|
140
|
+
def __init__(
    self,
    model: str = "",
    *,
    driver: Any | None = None,
    tools: list[Callable[..., Any]] | ToolRegistry | None = None,
    system_prompt: str | Persona | Callable[..., str] | None = None,
    output_type: type[BaseModel] | None = None,
    max_iterations: int = 10,
    max_cost: float | None = None,
    options: dict[str, Any] | None = None,
    deps_type: type | None = None,
    agent_callbacks: AgentCallbacks | None = None,
    input_guardrails: list[Callable[..., Any]] | None = None,
    output_guardrails: list[Callable[..., Any]] | None = None,
    name: str = "",
    description: str = "",
    output_key: str | None = None,
) -> None:
    """Initialize the agent; either *model* or *driver* is required.

    Raises:
        ValueError: If neither ``model`` nor ``driver`` is provided.
    """
    if not model and driver is None:
        raise ValueError("Either model or driver must be provided")

    self._model = model
    self._driver = driver
    self._system_prompt = system_prompt
    self._output_type = output_type
    self._max_iterations = max_iterations
    self._max_cost = max_cost
    # Copy caller-supplied mutables so later external mutation has no effect.
    self._options = dict(options) if options else {}
    self._deps_type = deps_type
    self._agent_callbacks = agent_callbacks or AgentCallbacks()
    self._input_guardrails = list(input_guardrails) if input_guardrails else []
    self._output_guardrails = list(output_guardrails) if output_guardrails else []
    self.name = name
    self.description = description
    self.output_key = output_key

    # Tool registry: adopt a caller-provided registry directly (the original
    # built a throwaway ToolRegistry() and then overwrote it); otherwise
    # build a fresh one from the list of callables.
    if isinstance(tools, ToolRegistry):
        self._tools = tools
    else:
        self._tools = ToolRegistry()
        if tools is not None:
            for fn in tools:
                self._tools.register(fn)

    self._state = AgentState.idle
    self._stop_requested = False
    # Written by _execute_iter(); declared here so the attribute always
    # exists on every instance.
    self._last_iter_result: AgentResult | None = None
|
|
187
|
+
|
|
188
|
+
# ------------------------------------------------------------------
|
|
189
|
+
# Public API
|
|
190
|
+
# ------------------------------------------------------------------
|
|
191
|
+
|
|
192
|
+
def tool(self, fn: Callable[..., Any]) -> Callable[..., Any]:
    """Register *fn* as a tool (decorator form) and return it unchanged."""
    self._tools.register(fn)
    return fn
|
|
196
|
+
|
|
197
|
+
@property
def state(self) -> AgentState:
    """Lifecycle state the agent is currently in (idle/running/errored)."""
    return self._state
|
|
201
|
+
|
|
202
|
+
def stop(self) -> None:
    """Ask the agent to shut down gracefully once the current iteration ends."""
    self._stop_requested = True
|
|
205
|
+
|
|
206
|
+
def as_tool(
    self,
    name: str | None = None,
    description: str | None = None,
    custom_output_extractor: Callable[[AgentResult], str] | None = None,
) -> ToolDefinition:
    """Expose this AsyncAgent as a callable tool for another Agent.

    The returned :class:`ToolDefinition` takes a single ``prompt`` string,
    runs this agent (bridging async to sync as needed), and returns the
    resulting output text.

    Args:
        name: Tool name (defaults to ``self.name`` or ``"agent_tool"``).
        description: Tool description (defaults to ``self.description``).
        custom_output_extractor: Optional function to extract a string
            from :class:`AgentResult`.
    """
    tool_name = name or self.name or "agent_tool"
    tool_desc = description or self.description or f"Run agent {tool_name}"
    agent = self
    extractor = custom_output_extractor

    def _call_agent(prompt: str) -> str:
        """Execute the wrapped async agent from synchronous code."""
        try:
            running = asyncio.get_running_loop()
        except RuntimeError:
            running = None

        if running is not None and running.is_running():
            # Already inside an event loop — run the agent on its own thread
            # so asyncio.run() can create a fresh loop there.
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                outcome = pool.submit(asyncio.run, agent.run(prompt)).result()
        else:
            outcome = asyncio.run(agent.run(prompt))

        return extractor(outcome) if extractor is not None else outcome.output_text

    return ToolDefinition(
        name=tool_name,
        description=tool_desc,
        parameters={
            "type": "object",
            "properties": {
                "prompt": {"type": "string", "description": "The prompt to send to the agent"},
            },
            "required": ["prompt"],
        },
        function=_call_agent,
    )
|
|
260
|
+
|
|
261
|
+
async def run(self, prompt: str, *, deps: Any = None) -> AgentResult:
    """Run the full agent loop once and return the final result.

    A fresh conversation is created per run; tool calls are handled
    internally and the final response is optionally parsed into
    ``output_type``.
    """
    self._state = AgentState.running
    self._stop_requested = False
    collected_steps: list[AgentStep] = []

    try:
        outcome = await self._execute(prompt, collected_steps, deps)
    except Exception:
        self._state = AgentState.errored
        raise
    self._state = AgentState.idle
    return outcome
|
|
278
|
+
|
|
279
|
+
async def iter(self, prompt: str, *, deps: Any = None) -> AsyncAgentIterator:
    """Start a run and expose it as a step-by-step async iterator.

    Yields :class:`AgentStep` objects; once exhausted, the iterator's
    ``result`` attribute holds the final :class:`AgentResult`.
    """
    return AsyncAgentIterator(self._execute_iter(prompt, deps))
|
|
287
|
+
|
|
288
|
+
async def run_stream(self, prompt: str, *, deps: Any = None) -> AsyncStreamedAgentResult:
    """Start a streaming run.

    Returns an :class:`AsyncStreamedAgentResult` that yields
    :class:`StreamEvent` objects as the run progresses.
    """
    return AsyncStreamedAgentResult(self._execute_stream(prompt, deps))
|
|
295
|
+
|
|
296
|
+
# ------------------------------------------------------------------
|
|
297
|
+
# RunContext helpers
|
|
298
|
+
# ------------------------------------------------------------------
|
|
299
|
+
|
|
300
|
+
def _build_run_context(
    self,
    prompt: str,
    deps: Any,
    session: UsageSession,
    messages: list[dict[str, Any]],
    iteration: int,
) -> RunContext[Any]:
    """Snapshot the current run state into a RunContext.

    ``messages`` is copied so later conversation growth is not reflected
    in the context handed to guardrails and tools.
    """
    return RunContext(
        prompt=prompt,
        deps=deps,
        model=self._model,
        usage=session.summary(),
        messages=list(messages),
        iteration=iteration,
    )
|
|
316
|
+
|
|
317
|
+
# ------------------------------------------------------------------
|
|
318
|
+
# Tool wrapping (RunContext injection + ModelRetry + callbacks)
|
|
319
|
+
# ------------------------------------------------------------------
|
|
320
|
+
|
|
321
|
+
def _wrap_tools_with_context(self, ctx: RunContext[Any]) -> ToolRegistry:
    """Return a registry of wrapped tools that inject *ctx* and handle retries.

    All wrappers are **sync** so they work with ``ToolRegistry.execute()``.
    Async tool functions are bridged via ``asyncio.run`` (on a dedicated
    thread when an event loop is already running).  Tools whose first
    parameter is a :class:`RunContext` receive *ctx* automatically and have
    that parameter stripped from the advertised JSON schema.  A tool raising
    :class:`ModelRetry` has its message returned as an ``"Error: ..."``
    string so the model can retry.
    """
    if not self._tools:
        return ToolRegistry()

    new_registry = ToolRegistry()
    cb = self._agent_callbacks

    for td in self._tools.definitions:
        wants_ctx = _tool_wants_context(td.function)
        original_fn = td.function
        tool_name = td.name
        is_async = _is_async_callable(original_fn)

        # Factory binds per-tool values as arguments, avoiding the classic
        # late-binding-closure-in-a-loop bug.
        def _make_wrapper(
            _fn: Callable[..., Any],
            _wants: bool,
            _name: str,
            _is_async: bool,
            _cb: AgentCallbacks = cb,
        ) -> Callable[..., Any]:
            def wrapper(**kwargs: Any) -> Any:
                if _cb.on_tool_start:
                    _cb.on_tool_start(_name, kwargs)
                try:
                    call_args = (ctx,) if _wants else ()
                    if _is_async:
                        coro = _fn(*call_args, **kwargs)
                        # Try to detect a running loop; if none, asyncio.run().
                        try:
                            loop = asyncio.get_running_loop()
                        except RuntimeError:
                            loop = None
                        if loop is not None and loop.is_running():
                            # Inside an async context — run on its own thread.
                            import concurrent.futures

                            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                                result = pool.submit(asyncio.run, coro).result()
                        else:
                            result = asyncio.run(coro)
                    else:
                        result = _fn(*call_args, **kwargs)
                except ModelRetry as exc:
                    # Surface retry requests to the model as tool output.
                    result = f"Error: {exc.message}"
                if _cb.on_tool_end:
                    _cb.on_tool_end(_name, result)
                return result

            return wrapper

        wrapped = _make_wrapper(original_fn, wants_ctx, tool_name, is_async)

        # Advertised schema: hide the injected RunContext parameter.
        # (The original copied ``params`` twice here; one copy suffices.)
        params = dict(td.parameters)
        if wants_ctx:
            ctx_param_name = _get_first_param_name(td.function)
            props = dict(params.get("properties", {}))
            props.pop(ctx_param_name, None)
            params["properties"] = props
            req = list(params.get("required", []))
            if ctx_param_name in req:
                req.remove(ctx_param_name)
            if req:
                params["required"] = req
            elif "required" in params:
                del params["required"]

        new_registry.add(
            ToolDefinition(
                name=td.name,
                description=td.description,
                parameters=params,
                function=wrapped,
            )
        )

    return new_registry
|
|
408
|
+
|
|
409
|
+
# ------------------------------------------------------------------
|
|
410
|
+
# Guardrails
|
|
411
|
+
# ------------------------------------------------------------------
|
|
412
|
+
|
|
413
|
+
def _run_input_guardrails(self, ctx: RunContext[Any], prompt: str) -> str:
    """Give each input guardrail a chance to rewrite the prompt.

    A guardrail returning ``None`` leaves the prompt unchanged; any other
    return value replaces it for subsequent guardrails and the model call.
    """
    for check in self._input_guardrails:
        replacement = check(ctx, prompt)
        if replacement is not None:
            prompt = replacement
    return prompt
|
|
419
|
+
|
|
420
|
+
async def _run_output_guardrails(
    self,
    ctx: RunContext[Any],
    result: AgentResult,
    conv: Any,
    session: UsageSession,
    steps: list[AgentStep],
    all_tool_calls: list[dict[str, Any]],
) -> AgentResult:
    """Validate *result* against each output guardrail, retrying on ModelRetry.

    A guardrail may return a replacement :class:`AgentResult` (adopted
    as-is), return ``None`` (result accepted unchanged), or raise
    :class:`ModelRetry` to have the model regenerate its answer.  On
    ``ModelRetry`` the conversation is re-asked with the validation error,
    the new response is re-parsed, and the same guardrail is evaluated
    again, up to ``_OUTPUT_GUARDRAIL_MAX_RETRIES`` attempts (subject to the
    cost budget).

    Fix vs. the original: a guardrail that passes by returning ``None`` is
    no longer re-invoked on the remaining retry attempts — any non-raising
    call now ends that guardrail's retry loop.

    Raises:
        ValueError: When a guardrail still rejects after all retries.
    """
    for guardrail in self._output_guardrails:
        for attempt in range(_OUTPUT_GUARDRAIL_MAX_RETRIES):
            try:
                guard_result = guardrail(ctx, result)
            except ModelRetry as exc:
                # Stop retrying once the budget is exhausted.
                if self._is_over_budget(session):
                    break
                if attempt >= _OUTPUT_GUARDRAIL_MAX_RETRIES - 1:
                    raise ValueError(
                        f"Output guardrail failed after {_OUTPUT_GUARDRAIL_MAX_RETRIES} retries: {exc.message}"
                    ) from exc
                retry_text = await conv.ask(
                    f"Your response did not pass validation. Error: {exc.message}\n\nPlease try again."
                )
                self._extract_steps(conv.messages[-2:], steps, all_tool_calls)

                # Re-parse the regenerated response, falling back to raw text.
                if self._output_type is not None:
                    try:
                        cleaned = clean_json_text(retry_text)
                        parsed = json.loads(cleaned)
                        output: Any = self._output_type.model_validate(parsed)
                    except Exception:
                        output = retry_text
                else:
                    output = retry_text

                result = AgentResult(
                    output=output,
                    output_text=retry_text,
                    messages=conv.messages,
                    usage=conv.usage,
                    steps=steps,
                    all_tool_calls=all_tool_calls,
                    state=AgentState.idle,
                    run_usage=session.summary(),
                )
            else:
                if guard_result is not None:
                    result = guard_result
                # Non-raising call means the guardrail passed — stop retrying.
                break
    return result
|
|
469
|
+
|
|
470
|
+
# ------------------------------------------------------------------
|
|
471
|
+
# Budget check
|
|
472
|
+
# ------------------------------------------------------------------
|
|
473
|
+
|
|
474
|
+
def _is_over_budget(self, session: UsageSession) -> bool:
    """Return True when the run's accumulated cost has reached ``max_cost``.

    Always False when no budget was configured.
    """
    return self._max_cost is not None and session.total_cost >= self._max_cost
|
|
478
|
+
|
|
479
|
+
# ------------------------------------------------------------------
|
|
480
|
+
# Internals
|
|
481
|
+
# ------------------------------------------------------------------
|
|
482
|
+
|
|
483
|
+
def _resolve_system_prompt(self, ctx: RunContext[Any] | None = None) -> str | None:
    """Build the effective system prompt for a run.

    Combines the configured prompt (plain string, :class:`Persona`, or a
    callable receiving the :class:`RunContext`) with a JSON-schema
    instruction block when ``output_type`` is set.  Returns ``None`` when
    there is nothing to send.
    """
    sections: list[str] = []
    configured = self._system_prompt

    if configured is not None:
        if isinstance(configured, Persona):
            render_kwargs: dict[str, Any] = {}
            if ctx is not None:
                render_kwargs["model"] = ctx.model
                render_kwargs["iteration"] = ctx.iteration
            sections.append(configured.render(**render_kwargs))
        elif callable(configured) and not isinstance(configured, str):
            if ctx is not None:
                sections.append(configured(ctx))
            else:
                sections.append(configured(None))  # type: ignore[arg-type]
        else:
            sections.append(str(configured))

    if self._output_type is not None:
        schema_str = json.dumps(self._output_type.model_json_schema(), indent=2)
        sections.append(
            "You MUST respond with a single JSON object (no markdown, "
            "no extra text) that validates against this JSON schema:\n"
            f"{schema_str}\n\n"
            "Use double quotes for keys and strings. "
            "If a value is unknown use null."
        )

    if not sections:
        return None
    return "\n\n".join(sections)
|
|
513
|
+
|
|
514
|
+
def _build_conversation(
    self,
    system_prompt: str | None = None,
    tools: ToolRegistry | None = None,
    driver_callbacks: DriverCallbacks | None = None,
) -> Any:
    """Assemble a fresh AsyncConversation configured for a single run."""
    from .async_conversation import AsyncConversation

    if tools is not None:
        effective_tools = tools
    else:
        effective_tools = self._tools if self._tools else None

    kwargs: dict[str, Any] = {
        "system_prompt": system_prompt if system_prompt is not None else self._resolve_system_prompt(),
        "tools": effective_tools,
        "max_tool_rounds": self._max_iterations,
    }
    if self._options:
        kwargs["options"] = self._options
    if driver_callbacks is not None:
        kwargs["callbacks"] = driver_callbacks

    # Prefer an explicit driver; otherwise resolve from the model string.
    if self._driver is not None:
        kwargs["driver"] = self._driver
    else:
        kwargs["model_name"] = self._model

    return AsyncConversation(**kwargs)
|
|
541
|
+
|
|
542
|
+
async def _execute(self, prompt: str, steps: list[AgentStep], deps: Any) -> AgentResult:
    """Run one full agent pass: guardrails, conversation, parsing, callbacks."""
    # Per-run usage accounting, wired into the driver via callbacks.
    session = UsageSession()
    driver_callbacks = DriverCallbacks(
        on_response=session.record,
        on_error=session.record_error,
    )

    # Context for guardrails, dynamic prompts, and tool injection.
    ctx = self._build_run_context(prompt, deps, session, [], 0)
    effective_prompt = self._run_input_guardrails(ctx, prompt)
    resolved_system_prompt = self._resolve_system_prompt(ctx)
    wrapped_tools = self._wrap_tools_with_context(ctx)

    conv = self._build_conversation(
        system_prompt=resolved_system_prompt,
        tools=wrapped_tools if wrapped_tools else None,
        driver_callbacks=driver_callbacks,
    )

    if self._agent_callbacks.on_iteration:
        self._agent_callbacks.on_iteration(0)

    # The conversation runs the entire tool loop internally.
    t0 = time.perf_counter()
    response_text = await conv.ask(effective_prompt)
    elapsed_ms = (time.perf_counter() - t0) * 1000

    # Reconstruct step/tool-call history from the message log.
    all_tool_calls: list[dict[str, Any]] = []
    self._extract_steps(conv.messages, steps, all_tool_calls)

    # Optionally coerce the final text into the structured output type.
    if self._output_type is not None:
        output, output_text = await self._parse_output(
            conv, response_text, steps, all_tool_calls, elapsed_ms, session
        )
    else:
        output = response_text
        output_text = response_text

    result = AgentResult(
        output=output,
        output_text=output_text,
        messages=conv.messages,
        usage=conv.usage,
        steps=steps,
        all_tool_calls=all_tool_calls,
        state=AgentState.idle,
        run_usage=session.summary(),
    )

    if self._output_guardrails:
        result = await self._run_output_guardrails(ctx, result, conv, session, steps, all_tool_calls)

    # Observability callbacks fire after everything settled.
    if self._agent_callbacks.on_step:
        for step in steps:
            self._agent_callbacks.on_step(step)
    if self._agent_callbacks.on_output:
        self._agent_callbacks.on_output(result)

    return result
|
|
615
|
+
|
|
616
|
+
def _extract_steps(
    self,
    messages: list[dict[str, Any]],
    steps: list[AgentStep],
    all_tool_calls: list[dict[str, Any]],
) -> None:
    """Translate raw conversation messages into AgentStep records, in place.

    Assistant messages with tool calls become ``tool_call`` steps (one per
    call), plain assistant messages become ``output`` steps, and tool-role
    messages become ``tool_result`` steps.
    """
    now = time.time()

    for msg in messages:
        role = msg.get("role", "")

        if role == "assistant":
            tc_list = msg.get("tool_calls", [])
            if not tc_list:
                # Plain assistant text — intermediate or final output.
                steps.append(
                    AgentStep(
                        step_type=StepType.output,
                        timestamp=now,
                        content=msg.get("content", ""),
                    )
                )
                continue
            for tc in tc_list:
                fn = tc.get("function", {})
                name = fn.get("name", tc.get("name", ""))
                raw_args = fn.get("arguments", tc.get("arguments", "{}"))
                if isinstance(raw_args, str):
                    try:
                        args = json.loads(raw_args)
                    except json.JSONDecodeError:
                        args = {}
                else:
                    args = raw_args

                steps.append(
                    AgentStep(
                        step_type=StepType.tool_call,
                        timestamp=now,
                        content=msg.get("content", ""),
                        tool_name=name,
                        tool_args=args,
                    )
                )
                all_tool_calls.append({"name": name, "arguments": args, "id": tc.get("id", "")})

        elif role == "tool":
            # NOTE(review): tool_name is populated from tool_call_id here —
            # presumably the call id rather than the tool's name; confirm
            # against how AgentStep consumers use it.
            steps.append(
                AgentStep(
                    step_type=StepType.tool_result,
                    timestamp=now,
                    content=msg.get("content", ""),
                    tool_name=msg.get("tool_call_id"),
                )
            )
|
|
671
|
+
|
|
672
|
+
async def _parse_output(
    self,
    conv: Any,
    response_text: str,
    steps: list[AgentStep],
    all_tool_calls: list[dict[str, Any]],
    elapsed_ms: float,
    session: UsageSession | None = None,
) -> tuple[Any, str]:
    """Parse the model's text into ``output_type``, re-asking on failure.

    Up to ``_OUTPUT_PARSE_MAX_RETRIES`` attempts are made; each failed
    attempt (budget permitting) asks the conversation to produce valid
    JSON again.

    Raises:
        ValueError: When every attempt fails to validate.
    """
    assert self._output_type is not None

    text = response_text
    last_error: Exception | None = None

    for attempt in range(_OUTPUT_PARSE_MAX_RETRIES):
        try:
            candidate = json.loads(clean_json_text(text))
            return self._output_type.model_validate(candidate), text
        except Exception as exc:
            last_error = exc
            if attempt >= _OUTPUT_PARSE_MAX_RETRIES - 1:
                break
            # Respect the cost budget before paying for another round-trip.
            if session is not None and self._is_over_budget(session):
                break
            text = await conv.ask(
                f"Your previous response could not be parsed as valid JSON "
                f"matching the required schema. Error: {exc}\n\n"
                f"Please try again and respond ONLY with valid JSON."
            )
            self._extract_steps(conv.messages[-2:], steps, all_tool_calls)

    raise ValueError(
        f"Failed to parse output as {self._output_type.__name__} "
        f"after {_OUTPUT_PARSE_MAX_RETRIES} attempts: {last_error}"
    )
|
|
710
|
+
|
|
711
|
+
# ------------------------------------------------------------------
|
|
712
|
+
# iter() — async step-by-step
|
|
713
|
+
# ------------------------------------------------------------------
|
|
714
|
+
|
|
715
|
+
async def _execute_iter(self, prompt: str, deps: Any) -> AsyncGenerator[AgentStep, None]:
    """Run the agent to completion, then yield its steps one at a time."""
    self._state = AgentState.running
    self._stop_requested = False
    collected: list[AgentStep] = []

    try:
        outcome = await self._execute(prompt, collected, deps)
        for step in outcome.steps:
            yield step
        self._state = AgentState.idle
        # Keep the result reachable after iteration finishes.
        self._last_iter_result = outcome
    except Exception:
        self._state = AgentState.errored
        raise
|
|
731
|
+
|
|
732
|
+
# ------------------------------------------------------------------
|
|
733
|
+
# run_stream() — async streaming
|
|
734
|
+
# ------------------------------------------------------------------
|
|
735
|
+
|
|
736
|
+
async def _execute_stream(self, prompt: str, deps: Any) -> AsyncGenerator[StreamEvent, None]:
    """Async generator that executes the agent loop and yields stream events.

    Emits ``text_delta`` :class:`StreamEvent` items while the model responds,
    followed by a single ``output`` event carrying the final
    :class:`AgentResult`. ``self._state`` goes running -> idle on success,
    or -> errored if any exception escapes.
    """
    self._state = AgentState.running
    self._stop_requested = False
    steps: list[AgentStep] = []

    try:
        # 1. Setup: per-run usage accounting wired into the driver callbacks.
        session = UsageSession()
        driver_callbacks = DriverCallbacks(
            on_response=session.record,
            on_error=session.record_error,
        )
        ctx = self._build_run_context(prompt, deps, session, [], 0)
        # Input guardrails may rewrite the prompt before it reaches the model.
        effective_prompt = self._run_input_guardrails(ctx, prompt)
        resolved_system_prompt = self._resolve_system_prompt(ctx)
        wrapped_tools = self._wrap_tools_with_context(ctx)
        has_tools = bool(wrapped_tools)

        conv = self._build_conversation(
            system_prompt=resolved_system_prompt,
            tools=wrapped_tools if wrapped_tools else None,
            driver_callbacks=driver_callbacks,
        )

        if self._agent_callbacks.on_iteration:
            self._agent_callbacks.on_iteration(0)

        if has_tools:
            # Tool runs use the non-streaming ask(); the entire reply arrives
            # as one text_delta event instead of incremental chunks.
            response_text = await conv.ask(effective_prompt)
            yield StreamEvent(event_type=StreamEventType.text_delta, data=response_text)
        else:
            # True token streaming: forward each chunk as its own event
            # while accumulating the full text for output parsing below.
            response_text = ""
            async for chunk in conv.ask_stream(effective_prompt):
                response_text += chunk
                yield StreamEvent(event_type=StreamEventType.text_delta, data=chunk)

        # Extract steps from the accumulated conversation transcript.
        all_tool_calls: list[dict[str, Any]] = []
        self._extract_steps(conv.messages, steps, all_tool_calls)

        # Parse output into the structured output type, if one was configured.
        if self._output_type is not None:
            # NOTE(review): 0.0 looks like an elapsed-time argument to
            # _parse_output — confirm against its signature.
            output, output_text = await self._parse_output(conv, response_text, steps, all_tool_calls, 0.0, session)
        else:
            output = response_text
            output_text = response_text

        result = AgentResult(
            output=output,
            output_text=output_text,
            messages=conv.messages,
            usage=conv.usage,
            steps=steps,
            all_tool_calls=all_tool_calls,
            state=AgentState.idle,
            run_usage=session.summary(),
        )

        # Output guardrails may replace the result entirely.
        if self._output_guardrails:
            result = await self._run_output_guardrails(ctx, result, conv, session, steps, all_tool_calls)

        # Fire per-step and final-output callbacks only after guardrails ran.
        if self._agent_callbacks.on_step:
            for step in steps:
                self._agent_callbacks.on_step(step)
        if self._agent_callbacks.on_output:
            self._agent_callbacks.on_output(result)

        yield StreamEvent(event_type=StreamEventType.output, data=result)

        self._state = AgentState.idle
        # Stashed so the streaming wrapper (or callers) can retrieve it.
        self._last_stream_result = result
    except Exception:
        self._state = AgentState.errored
        raise
|
|
811
|
+
|
|
812
|
+
|
|
813
|
+
# ------------------------------------------------------------------
|
|
814
|
+
# AsyncAgentIterator
|
|
815
|
+
# ------------------------------------------------------------------
|
|
816
|
+
|
|
817
|
+
|
|
818
|
+
class AsyncAgentIterator:
    """Wraps the :meth:`AsyncAgent.iter` async generator, capturing the final result.

    After async iteration completes, :attr:`result` holds the :class:`AgentResult`.
    """

    def __init__(self, gen: AsyncGenerator[AgentStep, None]) -> None:
        self._gen = gen
        self._result: AgentResult | None = None
        # The owning agent, captured lazily from the generator's frame while
        # the generator is suspended (the frame is gone once it finishes).
        self._agent: AsyncAgent[Any] | None = None

    def __aiter__(self) -> AsyncAgentIterator:
        return self

    def _capture_agent(self) -> None:
        """Snapshot the generator's ``self`` local (the agent) from its live frame.

        BUG FIX: the previous implementation inspected ``ag_frame`` only after
        ``StopAsyncIteration``, but an exhausted async generator has
        ``ag_frame is None`` (CPython releases the frame on completion), so the
        result was never captured. We now snapshot the agent while the
        generator is suspended between steps, when the frame still exists.
        """
        if self._agent is None:
            frame = getattr(self._gen, "ag_frame", None)
            if frame is not None:
                self._agent = frame.f_locals.get("self")

    async def __anext__(self) -> AgentStep:
        try:
            step = await self._gen.__anext__()
        except StopAsyncIteration:
            # Generator finished: use the agent reference captured earlier
            # (try the frame once more for implementations that retain it).
            self._capture_agent()
            if self._agent is not None and hasattr(self._agent, "_last_iter_result"):
                self._result = self._agent._last_iter_result
            raise
        # The generator is suspended at a yield: its frame and locals exist.
        self._capture_agent()
        return step

    @property
    def result(self) -> AgentResult | None:
        """The final :class:`AgentResult`, available after iteration completes."""
        return self._result
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
# ------------------------------------------------------------------
|
|
849
|
+
# AsyncStreamedAgentResult
|
|
850
|
+
# ------------------------------------------------------------------
|
|
851
|
+
|
|
852
|
+
|
|
853
|
+
class AsyncStreamedAgentResult:
    """Wraps the :meth:`AsyncAgent.run_stream` async generator.

    Yields :class:`StreamEvent` objects. After iteration completes,
    :attr:`result` holds the :class:`AgentResult`.
    """

    def __init__(self, gen: AsyncGenerator[StreamEvent, None]) -> None:
        self._gen = gen
        self._result: AgentResult | None = None

    def __aiter__(self) -> AsyncStreamedAgentResult:
        return self

    async def __anext__(self) -> StreamEvent:
        event = await self._gen.__anext__()
        # The terminal "output" event carries the final AgentResult; keep it
        # so callers can read it via the result property after iteration.
        is_output_event = event.event_type == StreamEventType.output
        if is_output_event and isinstance(event.data, AgentResult):
            self._result = event.data
        return event

    @property
    def result(self) -> AgentResult | None:
        """The final :class:`AgentResult`, available after iteration completes."""
        return self._result
|