prompture 0.0.29.dev8__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +264 -23
- prompture/_version.py +34 -0
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +789 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +193 -0
- prompture/async_groups.py +551 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +55 -0
- prompture/cli.py +63 -4
- prompture/conversation.py +826 -0
- prompture/core.py +894 -263
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +187 -0
- prompture/driver.py +206 -5
- prompture/drivers/__init__.py +175 -67
- prompture/drivers/airllm_driver.py +109 -0
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +123 -0
- prompture/drivers/async_claude_driver.py +113 -0
- prompture/drivers/async_google_driver.py +316 -0
- prompture/drivers/async_grok_driver.py +97 -0
- prompture/drivers/async_groq_driver.py +90 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +148 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +135 -0
- prompture/drivers/async_openai_driver.py +102 -0
- prompture/drivers/async_openrouter_driver.py +102 -0
- prompture/drivers/async_registry.py +133 -0
- prompture/drivers/azure_driver.py +42 -9
- prompture/drivers/claude_driver.py +257 -34
- prompture/drivers/google_driver.py +295 -42
- prompture/drivers/grok_driver.py +35 -32
- prompture/drivers/groq_driver.py +33 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +97 -19
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +168 -23
- prompture/drivers/openai_driver.py +184 -9
- prompture/drivers/openrouter_driver.py +37 -25
- prompture/drivers/registry.py +306 -0
- prompture/drivers/vision_helpers.py +153 -0
- prompture/field_definitions.py +106 -96
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/runner.py +49 -47
- prompture/scaffold/__init__.py +1 -0
- prompture/scaffold/generator.py +84 -0
- prompture/scaffold/templates/Dockerfile.j2 +12 -0
- prompture/scaffold/templates/README.md.j2 +41 -0
- prompture/scaffold/templates/config.py.j2 +21 -0
- prompture/scaffold/templates/env.example.j2 +8 -0
- prompture/scaffold/templates/main.py.j2 +86 -0
- prompture/scaffold/templates/models.py.j2 +40 -0
- prompture/scaffold/templates/requirements.txt.j2 +5 -0
- prompture/serialization.py +218 -0
- prompture/server.py +183 -0
- prompture/session.py +117 -0
- prompture/settings.py +19 -1
- prompture/tools.py +219 -267
- prompture/tools_schema.py +254 -0
- prompture/validator.py +3 -3
- prompture-0.0.38.dev2.dist-info/METADATA +369 -0
- prompture-0.0.38.dev2.dist-info/RECORD +77 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
- prompture-0.0.29.dev8.dist-info/METADATA +0 -368
- prompture-0.0.29.dev8.dist-info/RECORD +0 -27
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,551 @@
|
|
|
1
|
+
"""Async multi-agent group coordination.
|
|
2
|
+
|
|
3
|
+
Provides :class:`ParallelGroup`, :class:`AsyncSequentialGroup`,
|
|
4
|
+
:class:`AsyncLoopGroup`, and :class:`AsyncRouterAgent` for composing
|
|
5
|
+
multiple async agents into deterministic workflows.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import asyncio
|
|
11
|
+
import logging
|
|
12
|
+
import time
|
|
13
|
+
from collections.abc import Callable
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
from .agent_types import AgentResult, AgentState
|
|
17
|
+
from .group_types import (
|
|
18
|
+
AgentError,
|
|
19
|
+
ErrorPolicy,
|
|
20
|
+
GroupCallbacks,
|
|
21
|
+
GroupResult,
|
|
22
|
+
GroupStep,
|
|
23
|
+
_aggregate_usage,
|
|
24
|
+
)
|
|
25
|
+
from .groups import _agent_name, _inject_state, _normalise_agents
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger("prompture.async_groups")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# ------------------------------------------------------------------
|
|
31
|
+
# ParallelGroup
|
|
32
|
+
# ------------------------------------------------------------------
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ParallelGroup:
    """Execute agents concurrently and collect results.

    Agents read from a frozen snapshot of the shared state taken at
    the start of the run. Output key writes are applied after all
    agents complete, in agent index order.

    Args:
        agents: List of async agents or ``(agent, prompt_template)`` tuples.
        state: Initial shared state dict.
        error_policy: How to handle agent failures.
        timeout_ms: Per-agent timeout in milliseconds.
        callbacks: Observability hooks.
        max_total_cost: Budget cap in USD.
    """

    def __init__(
        self,
        agents: list[Any],
        *,
        state: dict[str, Any] | None = None,
        error_policy: ErrorPolicy = ErrorPolicy.fail_fast,
        timeout_ms: float | None = None,
        callbacks: GroupCallbacks | None = None,
        max_total_cost: float | None = None,
    ) -> None:
        # NOTE(review): error_policy and max_total_cost are stored but never
        # consulted in run_async below — all agents always run to completion
        # regardless of failures or spend. Confirm whether enforcement is
        # intended here (it is enforced in the sequential group).
        self._agents = _normalise_agents(agents)
        self._state: dict[str, Any] = dict(state) if state else {}
        self._error_policy = error_policy
        self._timeout_ms = timeout_ms
        self._callbacks = callbacks or GroupCallbacks()
        self._max_total_cost = max_total_cost
        self._stop_requested = False

    def stop(self) -> None:
        """Request graceful shutdown."""
        # NOTE(review): run_async resets this flag on entry and never checks
        # it after launching the gather, so stop() currently has no effect on
        # an in-flight parallel run — confirm intended.
        self._stop_requested = True

    async def run_async(self, prompt: str = "") -> GroupResult:
        """Execute all agents concurrently."""
        self._stop_requested = False
        t0 = time.perf_counter()

        # Frozen state snapshot for all agents: every agent sees the state as
        # it was when the run started, not intermediate writes.
        frozen_state = dict(self._state)

        async def _run_one(
            idx: int, agent: Any, custom_prompt: str | None
        ) -> tuple[int, str, AgentResult | None, AgentError | None, GroupStep]:
            # Run a single agent. Agent exceptions (other than cancellation)
            # are converted into AgentError tuples so one failure cannot
            # abort the gather() below.
            name = _agent_name(agent, idx)

            # Per-agent template takes precedence over the group-level prompt.
            if custom_prompt is not None:
                effective = _inject_state(custom_prompt, frozen_state)
            elif prompt:
                effective = _inject_state(prompt, frozen_state)
            else:
                effective = ""

            if self._callbacks.on_agent_start:
                self._callbacks.on_agent_start(name, effective)

            step_t0 = time.perf_counter()
            try:
                coro = agent.run(effective)
                if self._timeout_ms is not None:
                    # wait_for takes seconds; a timeout surfaces as
                    # TimeoutError and is handled by the except below.
                    result = await asyncio.wait_for(coro, timeout=self._timeout_ms / 1000)
                else:
                    result = await coro

                duration_ms = (time.perf_counter() - step_t0) * 1000
                step = GroupStep(
                    agent_name=name,
                    step_type="agent_run",
                    timestamp=step_t0,
                    duration_ms=duration_ms,
                    usage_delta=getattr(result, "run_usage", {}),
                )
                if self._callbacks.on_agent_complete:
                    self._callbacks.on_agent_complete(name, result)
                return idx, name, result, None, step

            except Exception as exc:
                duration_ms = (time.perf_counter() - step_t0) * 1000
                err = AgentError(
                    agent_name=name,
                    error=exc,
                    output_key=getattr(agent, "output_key", None),
                )
                step = GroupStep(
                    agent_name=name,
                    step_type="agent_error",
                    timestamp=step_t0,
                    duration_ms=duration_ms,
                    error=str(exc),
                )
                if self._callbacks.on_agent_error:
                    self._callbacks.on_agent_error(name, exc)
                return idx, name, None, err, step

        # Launch all agents concurrently
        tasks = [_run_one(idx, agent, custom_prompt) for idx, (agent, custom_prompt) in enumerate(self._agents)]
        completed = await asyncio.gather(*tasks, return_exceptions=False)

        # Sort by original index to maintain deterministic ordering
        completed_sorted = sorted(completed, key=lambda x: x[0])

        agent_results: dict[str, Any] = {}
        errors: list[AgentError] = []
        timeline: list[GroupStep] = []
        usage_summaries: list[dict[str, Any]] = []

        for idx, name, result, err, step in completed_sorted:
            timeline.append(step)
            if err is not None:
                errors.append(err)
            elif result is not None:
                # NOTE(review): agents sharing a name silently overwrite each
                # other in agent_results — confirm names are unique upstream.
                agent_results[name] = result
                usage_summaries.append(getattr(result, "run_usage", {}))

                # Apply output_key writes in order
                agent_obj = self._agents[idx][0]
                output_key = getattr(agent_obj, "output_key", None)
                if output_key:
                    self._state[output_key] = result.output_text
                    if self._callbacks.on_state_update:
                        self._callbacks.on_state_update(output_key, result.output_text)

        elapsed_ms = (time.perf_counter() - t0) * 1000
        return GroupResult(
            agent_results=agent_results,
            aggregate_usage=_aggregate_usage(*usage_summaries),
            shared_state=dict(self._state),
            elapsed_ms=elapsed_ms,
            timeline=timeline,
            errors=errors,
            success=len(errors) == 0,
        )

    def run(self, prompt: str = "") -> GroupResult:
        """Sync wrapper around :meth:`run_async`."""
        # asyncio.run raises RuntimeError when called from a running loop;
        # use run_async directly in async contexts.
        return asyncio.run(self.run_async(prompt))
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
# ------------------------------------------------------------------
|
|
179
|
+
# AsyncSequentialGroup
|
|
180
|
+
# ------------------------------------------------------------------
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class AsyncSequentialGroup:
    """Async version of :class:`~prompture.groups.SequentialGroup`.

    Agents run one after another; each agent's ``output_key`` write is
    applied to the shared state before the next agent renders its prompt.

    Args:
        agents: List of async agents or ``(agent, prompt_template)`` tuples.
        state: Initial shared state dict.
        error_policy: How to handle agent failures.
        max_total_turns: Limit on total agent runs.
        callbacks: Observability hooks.
        max_total_cost: Budget cap in USD.
    """

    def __init__(
        self,
        agents: list[Any],
        *,
        state: dict[str, Any] | None = None,
        error_policy: ErrorPolicy = ErrorPolicy.fail_fast,
        max_total_turns: int | None = None,
        callbacks: GroupCallbacks | None = None,
        max_total_cost: float | None = None,
    ) -> None:
        self._agents = _normalise_agents(agents)
        self._state: dict[str, Any] = dict(state) if state else {}
        self._error_policy = error_policy
        self._max_total_turns = max_total_turns
        self._callbacks = callbacks or GroupCallbacks()
        self._max_total_cost = max_total_cost
        self._stop_requested = False

    def stop(self) -> None:
        """Request graceful shutdown before the next agent starts."""
        self._stop_requested = True

    async def run(self, prompt: str = "") -> GroupResult:
        """Execute all agents in sequence (async).

        Args:
            prompt: Group-level prompt template, used by agents that do not
                carry their own per-agent template.

        Returns:
            GroupResult: Per-agent results, aggregate usage, final shared
            state, timeline, and any errors collected.
        """
        self._stop_requested = False
        t0 = time.perf_counter()
        timeline: list[GroupStep] = []
        agent_results: dict[str, Any] = {}
        errors: list[AgentError] = []
        usage_summaries: list[dict[str, Any]] = []
        turns = 0
        # Running cost total. Previously the budget check re-summed every
        # entry in usage_summaries on each iteration (accidental O(n^2));
        # an accumulator gives the identical value in O(1) per step.
        spent_cost = 0.0

        for idx, (agent, custom_prompt) in enumerate(self._agents):
            if self._stop_requested:
                break

            name = _agent_name(agent, idx)

            # Per-agent template takes precedence over the group-level prompt.
            if custom_prompt is not None:
                effective = _inject_state(custom_prompt, self._state)
            elif prompt:
                effective = _inject_state(prompt, self._state)
            else:
                effective = ""

            # Budget cap: stop before launching an agent once spend reaches
            # the configured maximum.
            if self._max_total_cost is not None and spent_cost >= self._max_total_cost:
                break

            if self._max_total_turns is not None and turns >= self._max_total_turns:
                break

            if self._callbacks.on_agent_start:
                self._callbacks.on_agent_start(name, effective)

            step_t0 = time.perf_counter()
            try:
                result = await agent.run(effective)
                duration_ms = (time.perf_counter() - step_t0) * 1000
                turns += 1

                agent_results[name] = result
                usage = getattr(result, "run_usage", {})
                usage_summaries.append(usage)
                spent_cost += usage.get("total_cost", 0.0)

                # Make this agent's output visible to downstream agents.
                output_key = getattr(agent, "output_key", None)
                if output_key:
                    self._state[output_key] = result.output_text
                    if self._callbacks.on_state_update:
                        self._callbacks.on_state_update(output_key, result.output_text)

                timeline.append(
                    GroupStep(
                        agent_name=name,
                        step_type="agent_run",
                        timestamp=step_t0,
                        duration_ms=duration_ms,
                        usage_delta=usage,
                    )
                )

                if self._callbacks.on_agent_complete:
                    self._callbacks.on_agent_complete(name, result)

            except Exception as exc:
                duration_ms = (time.perf_counter() - step_t0) * 1000
                turns += 1
                err = AgentError(
                    agent_name=name,
                    error=exc,
                    output_key=getattr(agent, "output_key", None),
                )
                errors.append(err)
                timeline.append(
                    GroupStep(
                        agent_name=name,
                        step_type="agent_error",
                        timestamp=step_t0,
                        duration_ms=duration_ms,
                        error=str(exc),
                    )
                )

                if self._callbacks.on_agent_error:
                    self._callbacks.on_agent_error(name, exc)

                if self._error_policy == ErrorPolicy.fail_fast:
                    break

        elapsed_ms = (time.perf_counter() - t0) * 1000
        return GroupResult(
            agent_results=agent_results,
            aggregate_usage=_aggregate_usage(*usage_summaries),
            shared_state=dict(self._state),
            elapsed_ms=elapsed_ms,
            timeline=timeline,
            errors=errors,
            success=len(errors) == 0,
        )
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
# ------------------------------------------------------------------
|
|
317
|
+
# AsyncLoopGroup
|
|
318
|
+
# ------------------------------------------------------------------
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
class AsyncLoopGroup:
    """Async version of :class:`~prompture.groups.LoopGroup`.

    Repeatedly runs the agent sequence until ``exit_condition`` returns
    True, ``max_iterations`` is exhausted, or a stop is requested.

    Args:
        agents: List of async agents or ``(agent, prompt_template)`` tuples.
        exit_condition: Callable ``(state, iteration) -> bool``.
        max_iterations: Hard cap on loop iterations.
        state: Initial shared state dict.
        error_policy: How to handle agent failures.
        callbacks: Observability hooks.
    """

    def __init__(
        self,
        agents: list[Any],
        *,
        exit_condition: Callable[[dict[str, Any], int], bool],
        max_iterations: int = 10,
        state: dict[str, Any] | None = None,
        error_policy: ErrorPolicy = ErrorPolicy.fail_fast,
        callbacks: GroupCallbacks | None = None,
    ) -> None:
        self._agents = _normalise_agents(agents)
        self._exit_condition = exit_condition
        self._max_iterations = max_iterations
        self._state: dict[str, Any] = dict(state) if state else {}
        self._error_policy = error_policy
        self._callbacks = callbacks or GroupCallbacks()
        self._stop_requested = False

    def stop(self) -> None:
        """Request graceful shutdown at the next checkpoint."""
        self._stop_requested = True

    def _render(self, template: str | None, group_prompt: str) -> str:
        """Resolve the effective prompt: per-agent template first, then the
        group-level prompt, else empty."""
        if template is not None:
            return _inject_state(template, self._state)
        if group_prompt:
            return _inject_state(group_prompt, self._state)
        return ""

    async def run(self, prompt: str = "") -> GroupResult:
        """Execute the loop (async)."""
        self._stop_requested = False
        started = time.perf_counter()
        steps: list[GroupStep] = []
        results: dict[str, Any] = {}
        failures: list[AgentError] = []
        usages: list[dict[str, Any]] = []

        iteration = 0
        while iteration < self._max_iterations:
            # Stop flag is checked before the exit condition, which is only
            # consulted when no stop was requested.
            if self._stop_requested or self._exit_condition(self._state, iteration):
                break

            abort = False
            for position, (agent, template) in enumerate(self._agents):
                if self._stop_requested:
                    break

                label = _agent_name(agent, position)
                rendered = self._render(template, prompt)

                if self._callbacks.on_agent_start:
                    self._callbacks.on_agent_start(label, rendered)

                begun = time.perf_counter()
                try:
                    outcome = await agent.run(rendered)
                except Exception as exc:
                    elapsed = (time.perf_counter() - begun) * 1000
                    failures.append(
                        AgentError(
                            agent_name=label,
                            error=exc,
                            output_key=getattr(agent, "output_key", None),
                        )
                    )
                    steps.append(
                        GroupStep(
                            agent_name=label,
                            step_type="agent_error",
                            timestamp=begun,
                            duration_ms=elapsed,
                            error=str(exc),
                        )
                    )
                    if self._callbacks.on_agent_error:
                        self._callbacks.on_agent_error(label, exc)
                    if self._error_policy == ErrorPolicy.fail_fast:
                        # Propagate the break out of the outer loop as well.
                        abort = True
                        break
                else:
                    elapsed = (time.perf_counter() - begun) * 1000
                    # Keyed per iteration so repeated runs don't clobber
                    # earlier results.
                    results[f"{label}_iter{iteration}"] = outcome
                    used = getattr(outcome, "run_usage", {})
                    usages.append(used)

                    key = getattr(agent, "output_key", None)
                    if key:
                        self._state[key] = outcome.output_text
                        if self._callbacks.on_state_update:
                            self._callbacks.on_state_update(key, outcome.output_text)

                    steps.append(
                        GroupStep(
                            agent_name=label,
                            step_type="agent_run",
                            timestamp=begun,
                            duration_ms=elapsed,
                            usage_delta=used,
                        )
                    )

                    if self._callbacks.on_agent_complete:
                        self._callbacks.on_agent_complete(label, outcome)

            if abort:
                break
            iteration += 1

        return GroupResult(
            agent_results=results,
            aggregate_usage=_aggregate_usage(*usages),
            shared_state=dict(self._state),
            elapsed_ms=(time.perf_counter() - started) * 1000,
            timeline=steps,
            errors=failures,
            success=not failures,
        )
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
# ------------------------------------------------------------------
|
|
454
|
+
# AsyncRouterAgent
|
|
455
|
+
# ------------------------------------------------------------------
|
|
456
|
+
|
|
457
|
+
_DEFAULT_ROUTING_PROMPT = """Given these specialists:
|
|
458
|
+
{agent_list}
|
|
459
|
+
|
|
460
|
+
Which should handle this? Reply with ONLY the name.
|
|
461
|
+
|
|
462
|
+
Request: {prompt}"""
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
class AsyncRouterAgent:
    """Async LLM-driven router that delegates to the best-matching agent.

    Args:
        model: Model string for the routing LLM call.
        agents: List of async agents to route between.
        routing_prompt: Custom prompt template.
        fallback: Agent to use when routing fails.
        driver: Pre-built async driver instance.
        name: Optional name for this router (when nested in groups).
        description: Optional description (shown to parent routers).
        output_key: Shared-state key written by containing groups.
    """

    def __init__(
        self,
        model: str = "",
        *,
        agents: list[Any],
        routing_prompt: str | None = None,
        fallback: Any | None = None,
        driver: Any | None = None,
        name: str = "",
        description: str = "",
        output_key: str | None = None,
    ) -> None:
        self._model = model
        self._driver = driver
        # NOTE(review): duplicate agent names silently overwrite earlier
        # entries here — confirm names are unique upstream.
        self._agents = {_agent_name(a, i): a for i, a in enumerate(agents)}
        self._routing_prompt = routing_prompt or _DEFAULT_ROUTING_PROMPT
        self._fallback = fallback
        self.name = name
        self.description = description
        self.output_key = output_key

    async def run(self, prompt: str, *, deps: Any = None) -> AgentResult:
        """Route the prompt to the best agent (async).

        Makes one routing LLM call, fuzzy-matches the reply against agent
        names, and delegates. Falls back to ``fallback`` when no agent
        matches; with no fallback, returns the raw routing reply.
        """
        # Imported lazily to avoid a circular import at module load time.
        from .async_conversation import AsyncConversation

        agent_lines = []
        for name, agent in self._agents.items():
            desc = getattr(agent, "description", "") or ""
            agent_lines.append(f"- {name}: {desc}" if desc else f"- {name}")
        agent_list = "\n".join(agent_lines)

        # str.replace (not str.format) so custom templates may contain
        # other literal braces.
        routing_text = self._routing_prompt.replace("{agent_list}", agent_list).replace("{prompt}", prompt)

        kwargs: dict[str, Any] = {}
        if self._driver is not None:
            kwargs["driver"] = self._driver
        else:
            kwargs["model_name"] = self._model

        conv = AsyncConversation(**kwargs)
        route_response = await conv.ask(routing_text)

        selected = self._fuzzy_match(route_response.strip())

        if selected is not None:
            return await selected.run(prompt, deps=deps) if deps is not None else await selected.run(prompt)
        elif self._fallback is not None:
            return await self._fallback.run(prompt, deps=deps) if deps is not None else await self._fallback.run(prompt)
        else:
            # No match and no fallback: surface the routing reply itself.
            return AgentResult(
                output=route_response,
                output_text=route_response,
                messages=conv.messages,
                usage=conv.usage,
                state=AgentState.idle,
            )

    def _fuzzy_match(self, response: str) -> Any | None:
        """Find the best matching agent name in the LLM response.

        Matching passes, strongest first:
        1. exact case-insensitive name match;
        2. name contained in the response — longer names are tried first,
           so e.g. ``search_web`` beats ``search`` when both would match
           (previously the winner depended on dict insertion order);
        3. any word overlap between the (underscore-split) name and the
           response.

        Returns:
            The matched agent, or ``None`` when nothing matches.
        """
        response_lower = response.lower().strip()

        for name, agent in self._agents.items():
            if name.lower() == response_lower:
                return agent

        # Longest name first: avoids selecting an agent whose name is a
        # substring of another agent's name that also appears in the reply.
        for name in sorted(self._agents, key=len, reverse=True):
            if name.lower() in response_lower:
                return self._agents[name]

        response_words = set(response_lower.split())
        for name, agent in self._agents.items():
            name_words = set(name.lower().replace("_", " ").split())
            if name_words & response_words:
                return agent

        return None
|