agentflowkit 0.5.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentflow/__init__.py +95 -0
- agentflow/agent.py +536 -0
- agentflow/cache.py +143 -0
- agentflow/cpp_core/bindings.cpp +36 -0
- agentflow/cpp_core/dag_engine.cpp +105 -0
- agentflow/cpp_core/dag_engine.h +22 -0
- agentflow/distillation.py +160 -0
- agentflow/events.py +254 -0
- agentflow/exceptions.py +45 -0
- agentflow/hitl.py +155 -0
- agentflow/llm.py +248 -0
- agentflow/logging.py +131 -0
- agentflow/memory.py +478 -0
- agentflow/observability.py +91 -0
- agentflow/pipeline.py +729 -0
- agentflow/pricing.py +57 -0
- agentflow/py.typed +0 -0
- agentflow/rate_limiter.py +86 -0
- agentflow/sandbox.py +557 -0
- agentflow/swarm.py +262 -0
- agentflow/swarm_routing.py +321 -0
- agentflow/tools.py +154 -0
- agentflow/triggers.py +127 -0
- agentflow/types.py +68 -0
- agentflowkit-0.5.0.dist-info/METADATA +430 -0
- agentflowkit-0.5.0.dist-info/RECORD +28 -0
- agentflowkit-0.5.0.dist-info/WHEEL +5 -0
- agentflowkit-0.5.0.dist-info/licenses/LICENSE +21 -0
agentflow/__init__.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""agentflow - Lightweight multi-agent AI pipeline framework."""
|
|
2
|
+
|
|
3
|
+
# Keep in sync with the published distribution version (agentflowkit 0.5.0).
__version__ = "0.5.0"
|
|
4
|
+
|
|
5
|
+
from .agent import Agent, BaseAgent
|
|
6
|
+
from .cache import InMemoryCache, RedisCache, ResponseCache
|
|
7
|
+
from .events import EventEmitter
|
|
8
|
+
from .exceptions import (
|
|
9
|
+
AgentError,
|
|
10
|
+
AgentFlowError,
|
|
11
|
+
AgentOutputValidationError,
|
|
12
|
+
AgentTimeoutError,
|
|
13
|
+
LLMError,
|
|
14
|
+
PipelineError,
|
|
15
|
+
ToolError,
|
|
16
|
+
)
|
|
17
|
+
from .hitl import ApprovalPolicy, PauseExecution
|
|
18
|
+
from .llm import LLM
|
|
19
|
+
from .logging import PipelineLogger, get_logger
|
|
20
|
+
from .memory import BaseMemory, InMemoryContext, RedisContext, VectorContext
|
|
21
|
+
from .observability import Hooks, LoggingHooks
|
|
22
|
+
from .pipeline import Pipeline
|
|
23
|
+
from .pricing import estimate_cost, register_price
|
|
24
|
+
from .rate_limiter import RateLimiter
|
|
25
|
+
from .sandbox import (
|
|
26
|
+
DockerSandbox,
|
|
27
|
+
SandboxError,
|
|
28
|
+
SandboxTimeoutError,
|
|
29
|
+
SubprocessSandbox,
|
|
30
|
+
create_sandbox,
|
|
31
|
+
execute_code,
|
|
32
|
+
sandboxed_tool,
|
|
33
|
+
)
|
|
34
|
+
from .swarm import SupervisorAgent
|
|
35
|
+
from .tools import Tool, tool
|
|
36
|
+
from .triggers import BaseTrigger, MQTTTrigger
|
|
37
|
+
from .types import AgentResult, Event, PipelineResult
|
|
38
|
+
|
|
39
|
+
# Public API re-exported by ``from agentflow import *``, grouped by feature area.
__all__ = [
    # Core
    "Agent", "BaseAgent", "LLM", "Pipeline", "SupervisorAgent",
    # Tools
    "Tool", "tool",
    # Sandbox
    "DockerSandbox", "SubprocessSandbox", "SandboxError", "SandboxTimeoutError",
    "create_sandbox", "execute_code", "sandboxed_tool",
    # Cost
    "estimate_cost", "register_price",
    # Data models
    "AgentResult", "PipelineResult", "Event", "EventEmitter",
    # Memory
    "BaseMemory", "InMemoryContext", "RedisContext", "VectorContext",
    # Rate limiting
    "RateLimiter",
    # Triggers
    "BaseTrigger", "MQTTTrigger",
    # Caching
    "ResponseCache", "InMemoryCache", "RedisCache",
    # Logging & observability
    "PipelineLogger", "get_logger", "Hooks", "LoggingHooks",
    # Exceptions
    "AgentFlowError", "AgentError", "AgentTimeoutError", "AgentOutputValidationError",
    "PipelineError", "LLMError", "ToolError",
    # HITL
    "ApprovalPolicy", "PauseExecution",
]
|
agentflow/agent.py
ADDED
|
@@ -0,0 +1,536 @@
|
|
|
1
|
+
"""Agent definition via decorators and base class."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
import time
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from collections.abc import Awaitable, Callable
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from pydantic import BaseModel, ValidationError
|
|
13
|
+
|
|
14
|
+
from .exceptions import AgentError, AgentOutputValidationError, ToolError
|
|
15
|
+
from .hitl import ApprovalPolicy, PauseExecution
|
|
16
|
+
from .llm import LLM
|
|
17
|
+
from .memory import BaseMemory
|
|
18
|
+
from .tools import Tool
|
|
19
|
+
from .types import AgentResult
|
|
20
|
+
|
|
21
|
+
# Default cap on tool-calling rounds for an agent (used as the
# ``max_tool_iterations`` default).
DEFAULT_MAX_TOOL_ITERATIONS = 6
# Default character budget for a single tool result before it is truncated.
TOOL_OUTPUT_MAX_CHARS = 5000
# Message-count threshold above which the tool loop trims its history.
MESSAGES_MAX_LENGTH = 20
# Number of extra LLM attempts per tool-loop iteration after a failure.
LLM_RETRIES_PER_ITERATION = 1

# Module-level logger for agent and tool execution events.
_log = logging.getLogger("agentflow.agent")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _truncate_output(output: str, max_chars: int = TOOL_OUTPUT_MAX_CHARS) -> str:
    """Cap *output* at *max_chars* characters.

    Text within the limit is returned unchanged. Longer text is cut at the
    limit and suffixed with a marker stating how many characters were dropped.
    """
    dropped = len(output) - max_chars
    if dropped > 0:
        return f"{output[:max_chars]}...[Truncated: {dropped} chars]"
    return output
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class BaseAgent(ABC):
    """Abstract foundation for hand-written agents.

    Derive from this class when an agent needs custom logic and full control
    over how it runs; subclasses must implement :meth:`execute`.
    """

    # Declared at class level for type checkers; assigned in ``__init__``.
    name: str
    role: str

    def __init__(self, name: str, role: str):
        self.name, self.role = name, role

    @abstractmethod
    async def execute(self, task: str, context: dict[str, str], llm: LLM) -> AgentResult:
        """Run this agent on *task*.

        Args:
            task: The task/topic string.
            context: Mapping of agent name to the output produced by agents
                that ran earlier.
            llm: The LLM provider to use.

        Returns:
            An ``AgentResult`` holding this agent's output.
        """
        ...
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class _DecoratorAgent:
    """Agent created via the @Agent decorator.

    Wraps an async prompt function: the function builds the user message, the
    agent sends it to the LLM (optionally driving a tool-calling loop), and the
    result is returned as an ``AgentResult``. Optional session memory, an HITL
    approval policy and structured-output validation are layered on top.
    """

    def __init__(
        self,
        name: str,
        role: str,
        prompt_fn: Callable[..., Awaitable[str]],
        output_schema: type[BaseModel] | None = None,
        tools: list[Tool] | None = None,
        max_tool_iterations: int = DEFAULT_MAX_TOOL_ITERATIONS,
        memory: BaseMemory | None = None,
    ):
        self.name = name
        self.role = role
        self._prompt_fn = prompt_fn
        self._output_schema = output_schema
        self._tools = tools or []
        self._max_tool_iterations = max_tool_iterations
        self._memory = memory
        self._session_id = "default"
        self._approval_policy: ApprovalPolicy | None = None
        # B5/B8: Pre-compute tool schemas once at construction time.
        self._openai_tools = [t.openai_schema for t in self._tools]

    def set_session(self, session_id: str) -> None:
        """Set the session ID used for memory load/save operations."""
        self._session_id = session_id

    def set_approval_policy(self, policy: ApprovalPolicy | None) -> None:
        """Attach an HITL approval policy that intercepts tool calls before execution."""
        self._approval_policy = policy

    async def execute(self, task: str, context: dict[str, str], llm: LLM) -> AgentResult:
        """Run the agent once for *task*.

        Args:
            task: The task/topic string handed to the prompt function.
            context: Outputs of previous agents, handed to the prompt function.
            llm: The LLM provider used for generation.

        Returns:
            An ``AgentResult`` with output, token/cost totals, duration and
            metadata (model name, tool-call trace, validated output).

        Raises:
            AgentError: If the prompt function or the LLM call fails.
            AgentOutputValidationError: If the output does not match the schema.
            PauseExecution: If a tool call requires human approval.
        """
        start = time.perf_counter()
        try:
            user_message = await self._prompt_fn(task, context)
        except Exception as e:
            raise AgentError(self.name, f"Prompt function failed: {e}") from e

        system_prompt = f"You are a {self.role}. Provide clear, thorough, well-structured responses."

        # M3: Inject prior session context from memory into the system prompt.
        if self._memory is not None:
            prev = await self._memory.load_context(self._session_id)
            if prev:
                # Each remembered output is capped at 300 chars.
                parts = [f"{name}: {output[:300]}" for name, output in prev.items()]
                system_prompt += (
                    "\n\n[Memory — previous outputs from this session:\n"
                    + "\n".join(parts)
                    + "\n]"
                )

        messages: list[dict[str, Any]] = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]

        metadata: dict[str, Any] = {}
        if self._tools:
            content, tokens_used, cost, model_name, trace = await self._run_tool_loop(messages, llm)
            metadata["model"] = model_name
            metadata["tool_calls"] = trace
            cached = False
        else:
            try:
                response = await llm.generate(messages)
            except Exception as e:
                raise AgentError(self.name, str(e)) from e
            content = response["content"]
            tokens_used = response["tokens"]
            cost = response.get("cost", 0.0)
            model_name = response["model"]
            cached = response.get("cached", False)
            metadata["model"] = model_name

        # M3: Persist this agent's output back to memory for downstream agents.
        if self._memory is not None and content:
            await self._memory.save_context(self._session_id, self.name, content)

        self._validate_output(content, metadata)

        duration = time.perf_counter() - start
        return AgentResult(
            agent=self.name,
            output=content,
            tokens_used=tokens_used,
            cost=round(cost, 6),
            duration=round(duration, 3),
            cached=cached,
            metadata=metadata,
        )

    def _validate_output(self, content: str, metadata: dict[str, Any]) -> None:
        """Validate *content* against the output schema, if one is configured.

        On success the dumped model is stored under ``metadata["validated_output"]``.

        Raises:
            AgentOutputValidationError: If *content* does not match the schema.
        """
        if self._output_schema is None:
            return
        try:
            validated = self._output_schema.model_validate_json(content)
            metadata["validated_output"] = validated.model_dump()
        except ValidationError as e:
            raise AgentOutputValidationError(self.name, str(e)) from e

    def _block_duplicate(
        self,
        call: dict[str, Any],
        results_map: dict[str, tuple[str, str, str, str]],
    ) -> None:
        """Record an error observation for a repeated tool call instead of running it."""
        fn = call["function"]
        name = fn["name"]
        arguments = fn["arguments"]
        dup_msg = (
            f"Error: Duplicate tool call detected. You already called "
            f"'{name}' with these arguments. Try a different approach."
        )
        results_map[call["id"]] = (call["id"], name, arguments, dup_msg)
        _log.warning(
            "Duplicate tool call blocked in agent %s: %s(%s)",
            self.name, name, arguments[:120],
        )

    async def _flush_tool_results(
        self,
        call_order: list[str],
        results_map: dict[str, tuple[str, str, str, str]],
        dispatched: list[tuple[str, asyncio.Task[tuple[str, str, str, str]]]],
        messages: list[dict[str, Any]],
        trace: list[dict[str, Any]],
    ) -> None:
        """Await dispatched tool tasks and append every result in call order."""
        # The tasks are already running concurrently; awaiting them one by one
        # only collects their results.
        for cid, task in dispatched:
            results_map[cid] = await task

        for call_id in call_order:
            _, name, arguments, output = results_map[call_id]
            # B4: Truncate oversized tool results.
            output = _truncate_output(output)
            trace.append({"tool": name, "arguments": arguments, "result": output})
            messages.append(
                {"role": "tool", "tool_call_id": call_id, "content": output}
            )

    @staticmethod
    def _trim_messages(
        messages: list[dict[str, Any]], max_length: int = MESSAGES_MAX_LENGTH
    ) -> None:
        """Drop the oldest tool-call groups in place while *messages* is too long.

        The first two messages (system + user) are always kept. Whole groups
        (one message plus the ``tool`` messages that follow it) are removed so a
        tool result is never left without the assistant message that requested
        it. The most recent group is always kept, even if that leaves the list
        longer than *max_length*.
        """
        while len(messages) > max_length:
            end = 3
            while end < len(messages) and messages[end].get("role") == "tool":
                end += 1
            if end >= len(messages):
                # Only the latest group remains; removing it would discard the
                # results the model is about to read.
                break
            del messages[2:end]

    async def _run_tool_loop(
        self,
        messages: list[dict[str, Any]],
        llm: LLM,
        start_iteration: int = 0,
        initial_total_tokens: int = 0,
        initial_total_cost: float = 0.0,
        initial_trace: list[dict[str, Any]] | None = None,
        initial_seen_calls: set[tuple[str, str]] | None = None,
    ) -> tuple[str, int, float, str, list[dict[str, Any]]]:
        """Drive a ReAct-style loop: call the LLM, run any requested tools, repeat.

        Features:
        - Multiple tool calls run concurrently as asyncio tasks (B1).
        - Duplicate tool calls are detected, skipped, and reported as errors (B3).
        - Tool outputs > 5000 chars are truncated (B4).
        - Message list is trimmed by whole tool-call groups to prevent context
          overflow (B2).
        - Transient LLM errors are retried once per iteration (B7).

        When *start_iteration* > 0 the loop resumes from a prior
        :class:`~agentflow.hitl.PauseExecution`, preserving accumulated tokens,
        cost, trace, and seen-calls state. The returned totals therefore
        already include *initial_total_tokens* and *initial_total_cost*.

        Returns:
            (final_content, total_tokens, total_cost, model_name, tool_call_trace)

        Raises:
            AgentError: If the LLM keeps failing or the iteration cap is reached.
            PauseExecution: If the approval policy requires a human decision.
        """
        tool_map = {t.name: t for t in self._tools}
        total_tokens = initial_total_tokens
        total_cost = initial_total_cost
        model_name = llm.model
        trace: list[dict[str, Any]] = initial_trace or []
        seen_calls: set[tuple[str, str]] = initial_seen_calls or set()

        for iteration in range(start_iteration, self._max_tool_iterations):
            # B7: Per-iteration LLM retry for transient failures.
            for retry in range(LLM_RETRIES_PER_ITERATION + 1):
                try:
                    response = await llm.generate(messages, tools=self._openai_tools)
                    break
                except Exception as exc:
                    if retry == LLM_RETRIES_PER_ITERATION:
                        raise AgentError(
                            self.name,
                            f"LLM call failed after {LLM_RETRIES_PER_ITERATION + 1} iteration-level attempts",
                        ) from exc
                    _log.warning(
                        "LLM generation retry in agent %s (iteration %d, attempt %d)",
                        self.name, iteration + 1, retry + 1,
                    )
                    # Exponential backoff before the next attempt.
                    await asyncio.sleep(0.5 * (2 ** retry))

            total_tokens += response["tokens"]
            total_cost += response.get("cost", 0.0)
            model_name = response["model"]
            tool_calls = response["tool_calls"]

            if not tool_calls:
                return response["content"], total_tokens, total_cost, model_name, trace

            messages.append(
                {
                    "role": "assistant",
                    "content": response["content"] or None,
                    "tool_calls": tool_calls,
                }
            )

            # B3: Separate duplicate calls from new ones; duplicates get an
            # immediate error observation without being re-executed.
            results_map: dict[str, tuple[str, str, str, str]] = {}
            dispatched: list[tuple[str, asyncio.Task[tuple[str, str, str, str]]]] = []
            call_order: list[str] = []

            for idx, call in enumerate(tool_calls):
                fn = call["function"]
                name = fn["name"]
                arguments = fn["arguments"]
                key = (name, arguments)

                if key in seen_calls:
                    self._block_duplicate(call, results_map)
                    call_order.append(call["id"])
                    continue

                # HITL: check approval policy before dispatching.
                if self._approval_policy is not None and self._approval_policy.requires_approval(
                    name, arguments
                ):
                    pending: list[dict[str, Any]] = []
                    for rc in tool_calls[idx:]:
                        rfn = rc["function"]
                        if (rfn["name"], rfn["arguments"]) in seen_calls:
                            # Already-seen calls are answered now; they are not
                            # part of the pending batch handed to the resumer.
                            self._block_duplicate(rc, results_map)
                            call_order.append(rc["id"])
                        else:
                            pending.append(rc)
                    # Finish the calls already started in this batch so their
                    # results are part of the saved conversation state rather
                    # than being abandoned when the pause is raised.
                    await self._flush_tool_results(
                        call_order, results_map, dispatched, messages, trace
                    )
                    raise PauseExecution(
                        agent_name=self.name,
                        tool_name=name,
                        tool_arguments=arguments,
                        tool_call_id=call["id"],
                        messages=messages,
                        total_tokens=total_tokens,
                        total_cost=total_cost,
                        model_name=model_name,
                        trace=trace,
                        pending_calls=pending,
                        seen_calls=[[n, a] for n, a in seen_calls],
                        iterations_used=iteration,
                    )

                seen_calls.add(key)
                dispatched.append(
                    (call["id"], asyncio.create_task(self._execute_single_tool(call, tool_map)))
                )
                call_order.append(call["id"])

            # B1: Collect all unique tool calls (already running concurrently)
            # and append results in the original tool_calls order.
            await self._flush_tool_results(call_order, results_map, dispatched, messages, trace)

            # B2: Sliding window — keep system + user, drop oldest tool groups
            # when the message list grows too large.
            self._trim_messages(messages)

        raise AgentError(
            self.name,
            f"exceeded max_tool_iterations={self._max_tool_iterations} without a final answer",
        )

    async def _execute_single_tool(
        self, call: dict[str, Any], tool_map: dict[str, Any]
    ) -> tuple[str, str, str, str]:
        """Execute a single tool call and return (call_id, name, arguments, output).

        Tool failures are turned into an ``"Error: ..."`` output string instead
        of being raised, so the model can observe them.

        B6: Execution is logged for observability.
        """
        fn = call["function"]
        name = fn["name"]
        arguments = fn["arguments"]
        target = tool_map.get(name)

        if target is None:
            output = f"Error: unknown tool '{name}'"
            _log.error("Unknown tool called in agent %s: %s", self.name, name)
        else:
            _log.info("Tool executing in agent %s: %s(%s)", self.name, name, arguments[:120])
            try:
                output = await target.acall(arguments)
                _log.info(
                    "Tool completed in agent %s: %s (output %d chars)",
                    self.name, name, len(output),
                )
            except ToolError as e:
                output = f"Error: {e}"
                _log.error("Tool error in agent %s: %s — %s", self.name, name, e)
            except Exception as e:
                output = f"Error: unexpected exception - {str(e)}"
                _log.error(
                    "Unexpected exception in agent %s tool %s: %s", self.name, name, e,
                )

        return call["id"], name, arguments, output

    async def resume_execution(
        self,
        pause_data: dict[str, Any],
        llm: LLM,
        approved: bool,
        human_feedback: str = "",
    ) -> AgentResult:
        """Resume execution after a :class:`~agentflow.hitl.PauseExecution`.

        Applies the human decision to the paused tool call, processes any
        remaining tool calls from the same LLM response batch, then re-enters
        the ReAct loop from where it left off.

        Args:
            pause_data: Serialized state from ``PauseExecution.as_dict()``.
            llm: The LLM provider for continued generation.
            approved: ``True`` to execute the pending tool; ``False`` to inject
                *human_feedback* as an error observation instead.
            human_feedback: Contextual message injected when *approved* is
                ``False`` so the agent can self-correct.

        Returns:
            An ``AgentResult`` with the final agent output. Token and cost
            totals cover the whole run, including the part before the pause.

        Raises:
            AgentError: If the resumed tool loop fails.
            AgentOutputValidationError: If the output does not match the schema.
            PauseExecution: If another tool call requires human approval.
        """
        start = time.perf_counter()
        tool_map = {t.name: t for t in self._tools}
        messages: list[dict[str, Any]] = pause_data["messages"]
        pending_calls: list[dict[str, Any]] = pause_data.get("pending_calls", [])
        seen_calls: set[tuple[str, str]] = {
            tuple(p) for p in pause_data.get("seen_calls", [])
        }

        if pending_calls:
            paused_call = pending_calls[0]
            paused_fn = paused_call["function"]
            if approved:
                # The paused call runs now, so later repeats count as duplicates.
                seen_calls.add((paused_fn["name"], paused_fn["arguments"]))
                call_id, name, args, output = await self._execute_single_tool(
                    paused_call, tool_map
                )
                output = _truncate_output(output)
                pause_data["trace"].append(
                    {"tool": name, "arguments": args, "result": output}
                )
                messages.append(
                    {"role": "tool", "tool_call_id": call_id, "content": output}
                )
            else:
                name = paused_fn["name"]
                args = paused_fn["arguments"]
                messages.append(
                    {"role": "tool", "tool_call_id": paused_call["id"], "content": human_feedback}
                )
                pause_data["trace"].append(
                    {
                        "tool": name,
                        "arguments": args,
                        "result": f"[HUMAN REJECTED] {human_feedback}",
                    }
                )

            # Remaining calls from the same batch run one after another.
            for call in pending_calls[1:]:
                fn = call["function"]
                key = (fn["name"], fn["arguments"])
                if key in seen_calls:
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": call["id"],
                            "content": (
                                f"Error: Duplicate tool call detected. You already called "
                                f"'{fn['name']}' with these arguments."
                            ),
                        }
                    )
                else:
                    seen_calls.add(key)
                    cid, tname, targs, output = await self._execute_single_tool(
                        call, tool_map
                    )
                    output = _truncate_output(output)
                    pause_data["trace"].append(
                        {"tool": tname, "arguments": targs, "result": output}
                    )
                    messages.append(
                        {"role": "tool", "tool_call_id": cid, "content": output}
                    )

        # The loop is seeded with the pre-pause totals, so the values it
        # returns are already the totals for the whole run; adding the
        # pre-pause figures again would count them twice.
        content, total_tokens, total_cost, model_name, final_trace = await self._run_tool_loop(
            messages,
            llm,
            start_iteration=pause_data["iterations_used"] + 1,
            initial_total_tokens=pause_data["total_tokens"],
            initial_total_cost=pause_data["total_cost"],
            initial_trace=pause_data["trace"],
            initial_seen_calls=seen_calls,
        )

        if self._memory is not None and content:
            await self._memory.save_context(self._session_id, self.name, content)

        metadata: dict[str, Any] = {
            "model": model_name,
            "tool_calls": final_trace,
        }
        self._validate_output(content, metadata)

        duration = time.perf_counter() - start
        return AgentResult(
            agent=self.name,
            output=content,
            tokens_used=total_tokens,
            cost=round(total_cost, 6),
            duration=round(duration, 3),
            cached=False,
            metadata=metadata,
        )

    def __repr__(self) -> str:
        return f"Agent(name={self.name!r}, role={self.role!r})"
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
class Agent:
    """Decorator that turns an async prompt function into an agent.

    The wrapped function is called with ``(task, context)`` and must return
    the user message that will be sent to the LLM.

    Args:
        name: Unique identifier for the agent within a pipeline.
        role: Describes the agent's persona (used as system prompt context).
        output_schema: Optional Pydantic BaseModel subclass. If provided, the
            LLM response must be valid JSON matching the schema, or
            AgentOutputValidationError is raised.
        tools: Optional list of ``Tool`` objects (see :func:`agentflow.tool`).
            When present the agent runs a ReAct loop, letting the model call
            tools and observe results until it produces a final answer.
        max_tool_iterations: Safety cap on tool-calling rounds (default 6).
        memory: Optional memory backend, passed through to the created agent.

    Usage:
        @Agent(name="researcher", role="Research Analyst")
        async def researcher(task: str, context: dict) -> str:
            return f"Research this topic: {task}"

        # With tools:
        @tool
        async def search(query: str) -> str:
            \"\"\"Search the web.\"\"\"
            ...

        @Agent(name="assistant", role="Assistant", tools=[search])
        async def assistant(task: str, context: dict) -> str:
            return task

        # With structured output:
        class Summary(BaseModel):
            title: str
            points: list[str]

        @Agent(name="summarizer", role="Summarizer", output_schema=Summary)
        async def summarizer(task: str, context: dict) -> str:
            return f"Summarize as JSON: {task}"
    """

    def __init__(
        self,
        name: str,
        role: str,
        output_schema: type[BaseModel] | None = None,
        tools: list[Tool] | None = None,
        max_tool_iterations: int = DEFAULT_MAX_TOOL_ITERATIONS,
        memory: BaseMemory | None = None,
    ):
        # Identity of the agent to be created.
        self.name, self.role = name, role
        # Optional behaviour, forwarded unchanged when the decorator is applied.
        self._output_schema = output_schema
        self._tools = tools
        self._max_tool_iterations = max_tool_iterations
        self._memory = memory

    def __call__(self, fn: Callable[..., Awaitable[str]]) -> _DecoratorAgent:
        """Wrap *fn* in a ``_DecoratorAgent`` configured with the stored options."""
        options = {
            "output_schema": self._output_schema,
            "tools": self._tools,
            "max_tool_iterations": self._max_tool_iterations,
            "memory": self._memory,
        }
        return _DecoratorAgent(self.name, self.role, fn, **options)
|