agentflowkit 0.5.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
agentflow/__init__.py ADDED
@@ -0,0 +1,95 @@
1
+ """agentflow - Lightweight multi-agent AI pipeline framework."""
2
+
3
# Keep in lockstep with the published distribution version: this module ships
# inside the agentflowkit 0.5.0 wheel, so reporting "0.3.0" here was stale.
# NOTE(review): confirm against the project's packaging metadata.
__version__ = "0.5.0"
4
+
5
+ from .agent import Agent, BaseAgent
6
+ from .cache import InMemoryCache, RedisCache, ResponseCache
7
+ from .events import EventEmitter
8
+ from .exceptions import (
9
+ AgentError,
10
+ AgentFlowError,
11
+ AgentOutputValidationError,
12
+ AgentTimeoutError,
13
+ LLMError,
14
+ PipelineError,
15
+ ToolError,
16
+ )
17
+ from .hitl import ApprovalPolicy, PauseExecution
18
+ from .llm import LLM
19
+ from .logging import PipelineLogger, get_logger
20
+ from .memory import BaseMemory, InMemoryContext, RedisContext, VectorContext
21
+ from .observability import Hooks, LoggingHooks
22
+ from .pipeline import Pipeline
23
+ from .pricing import estimate_cost, register_price
24
+ from .rate_limiter import RateLimiter
25
+ from .sandbox import (
26
+ DockerSandbox,
27
+ SandboxError,
28
+ SandboxTimeoutError,
29
+ SubprocessSandbox,
30
+ create_sandbox,
31
+ execute_code,
32
+ sandboxed_tool,
33
+ )
34
+ from .swarm import SupervisorAgent
35
+ from .tools import Tool, tool
36
+ from .triggers import BaseTrigger, MQTTTrigger
37
+ from .types import AgentResult, Event, PipelineResult
38
+
39
# Public API of the package, grouped by concern. Every name listed here is
# imported at the top of this module.
__all__ = [
    # Core
    "Agent", "BaseAgent", "LLM", "Pipeline", "SupervisorAgent",
    # Tools
    "Tool", "tool",
    # Sandbox
    "DockerSandbox", "SubprocessSandbox", "SandboxError", "SandboxTimeoutError",
    "create_sandbox", "execute_code", "sandboxed_tool",
    # Cost
    "estimate_cost", "register_price",
    # Data models
    "AgentResult", "PipelineResult", "Event", "EventEmitter",
    # Memory
    "BaseMemory", "InMemoryContext", "RedisContext", "VectorContext",
    # Rate limiting
    "RateLimiter",
    # Triggers
    "BaseTrigger", "MQTTTrigger",
    # Caching
    "ResponseCache", "InMemoryCache", "RedisCache",
    # Logging & observability
    "PipelineLogger", "get_logger", "Hooks", "LoggingHooks",
    # Exceptions
    "AgentFlowError", "AgentError", "AgentTimeoutError",
    "AgentOutputValidationError", "PipelineError", "LLMError", "ToolError",
    # HITL
    "ApprovalPolicy", "PauseExecution",
]
agentflow/agent.py ADDED
@@ -0,0 +1,536 @@
1
+ """Agent definition via decorators and base class."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import logging
7
+ import time
8
+ from abc import ABC, abstractmethod
9
+ from collections.abc import Awaitable, Callable
10
+ from typing import Any
11
+
12
+ from pydantic import BaseModel, ValidationError
13
+
14
+ from .exceptions import AgentError, AgentOutputValidationError, ToolError
15
+ from .hitl import ApprovalPolicy, PauseExecution
16
+ from .llm import LLM
17
+ from .memory import BaseMemory
18
+ from .tools import Tool
19
+ from .types import AgentResult
20
+
21
# Module-level logger shared by every agent defined in this file.
_log = logging.getLogger("agentflow.agent")

# Default cap on LLM/tool round-trips in one tool loop.
DEFAULT_MAX_TOOL_ITERATIONS = 6
# Tool outputs longer than this many characters are cut by _truncate_output.
TOOL_OUTPUT_MAX_CHARS = 5_000
# Conversation length above which the oldest tool exchanges are dropped.
MESSAGES_MAX_LENGTH = 20
# Extra LLM attempts per loop iteration after the first one fails.
LLM_RETRIES_PER_ITERATION = 1
27
+
28
+
29
def _truncate_output(output: str, max_chars: int = TOOL_OUTPUT_MAX_CHARS) -> str:
    """Return *output* cut to *max_chars* characters, noting how much was removed.

    Text that already fits is returned unchanged. Otherwise the first
    *max_chars* characters are kept and a ``...[Truncated: N chars]`` marker
    is appended, where N is the number of characters dropped.
    """
    excess = len(output) - max_chars
    if excess <= 0:
        return output
    return f"{output[:max_chars]}...[Truncated: {excess} chars]"
34
+
35
+
36
class BaseAgent(ABC):
    """Abstract parent for hand-written agents.

    Derive from this class when an agent needs custom logic and implement
    :meth:`execute`.
    """

    name: str
    role: str

    def __init__(self, name: str, role: str):
        """Store the agent's identifier and its role description."""
        self.name, self.role = name, role

    @abstractmethod
    async def execute(self, task: str, context: dict[str, str], llm: LLM) -> AgentResult:
        """Carry out the agent's work for one task.

        Args:
            task: The task/topic string.
            context: Mapping of agent name to the output produced by agents
                that ran earlier.
            llm: The LLM provider to use.

        Returns:
            An AgentResult holding this agent's output.
        """
62
+
63
+
64
class _DecoratorAgent:
    """Agent created via the @Agent decorator.

    Wraps an async prompt function: the function builds the user message, this
    class sends it to the LLM and, when tools are configured, drives a
    ReAct-style tool loop. Optional extras are session memory, structured
    output validation, and human-in-the-loop (HITL) approval of tool calls.
    """

    def __init__(
        self,
        name: str,
        role: str,
        prompt_fn: Callable[..., Awaitable[str]],
        output_schema: type[BaseModel] | None = None,
        tools: list[Tool] | None = None,
        max_tool_iterations: int = DEFAULT_MAX_TOOL_ITERATIONS,
        memory: BaseMemory | None = None,
    ):
        """Configure the agent.

        Args:
            name: Agent identifier; also the key used when saving to memory.
            role: Persona text interpolated into the system prompt.
            prompt_fn: Awaitable callable invoked as ``prompt_fn(task, context)``
                that returns the user message.
            output_schema: Optional Pydantic model the final output must
                satisfy when parsed as JSON.
            tools: Tools the model may call; ``None`` means no tools.
            max_tool_iterations: Upper bound on tool-loop rounds.
            memory: Optional memory backend used to load/save session context.
        """
        self.name = name
        self.role = role
        self._prompt_fn = prompt_fn
        self._output_schema = output_schema
        self._tools = tools or []
        self._max_tool_iterations = max_tool_iterations
        self._memory = memory
        self._session_id = "default"
        self._approval_policy: ApprovalPolicy | None = None
        # B5/B8: Pre-compute tool schemas once at construction time.
        self._openai_tools = [t.openai_schema for t in self._tools]

    def set_session(self, session_id: str) -> None:
        """Set the session ID used for memory load/save operations."""
        self._session_id = session_id

    def set_approval_policy(self, policy: ApprovalPolicy | None) -> None:
        """Attach an HITL approval policy that intercepts tool calls before execution."""
        self._approval_policy = policy

    async def execute(self, task: str, context: dict[str, str], llm: LLM) -> AgentResult:
        """Run the agent once for *task* and return its result.

        Args:
            task: The task/topic string passed to the prompt function.
            context: Outputs of previous agents, passed to the prompt function.
            llm: The LLM provider to use.

        Returns:
            An ``AgentResult`` with the output, token/cost totals, duration and
            metadata (model name, tool-call trace, validated output).

        Raises:
            AgentError: If the prompt function or the LLM call fails, or the
                tool loop runs out of iterations.
            AgentOutputValidationError: If the output fails schema validation.
            PauseExecution: If a tool call requires human approval.
        """
        start = time.perf_counter()
        try:
            user_message = await self._prompt_fn(task, context)
        except Exception as e:
            raise AgentError(self.name, f"Prompt function failed: {e}") from e

        system_prompt = f"You are a {self.role}. Provide clear, thorough, well-structured responses."

        # M3: Inject prior session context from memory into the system prompt.
        if self._memory is not None:
            prev = await self._memory.load_context(self._session_id)
            if prev:
                # Each remembered output is capped at 300 characters.
                parts = [f"{name}: {output[:300]}" for name, output in prev.items()]
                system_prompt += (
                    "\n\n[Memory — previous outputs from this session:\n"
                    + "\n".join(parts)
                    + "\n]"
                )

        messages: list[dict[str, Any]] = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]

        metadata: dict[str, Any] = {}
        if self._tools:
            content, tokens_used, cost, model_name, trace = await self._run_tool_loop(messages, llm)
            metadata["model"] = model_name
            metadata["tool_calls"] = trace
            cached = False
        else:
            try:
                response = await llm.generate(messages)
            except Exception as e:
                raise AgentError(self.name, str(e)) from e
            content = response["content"]
            tokens_used = response["tokens"]
            cost = response.get("cost", 0.0)
            model_name = response["model"]
            cached = response.get("cached", False)
            metadata["model"] = model_name

        # M3: Persist this agent's output back to memory for downstream agents.
        if self._memory is not None and content:
            await self._memory.save_context(self._session_id, self.name, content)

        if self._output_schema is not None:
            try:
                validated = self._output_schema.model_validate_json(content)
                metadata["validated_output"] = validated.model_dump()
            except ValidationError as e:
                raise AgentOutputValidationError(self.name, str(e)) from e

        duration = time.perf_counter() - start
        return AgentResult(
            agent=self.name,
            output=content,
            tokens_used=tokens_used,
            cost=round(cost, 6),
            duration=round(duration, 3),
            cached=cached,
            metadata=metadata,
        )

    def _needs_approval(self, name: str, arguments: str) -> bool:
        """Return True when the attached HITL policy wants this call approved."""
        policy = self._approval_policy
        return policy is not None and bool(policy.requires_approval(name, arguments))

    def _pause_for_approval(
        self,
        *,
        call: dict[str, Any],
        remaining_calls: list[dict[str, Any]],
        messages: list[dict[str, Any]],
        total_tokens: int,
        total_cost: float,
        model_name: str,
        trace: list[dict[str, Any]],
        seen_calls: set[tuple[str, str]],
        iteration: int,
    ) -> PauseExecution:
        """Build the PauseExecution carrying everything needed to resume later.

        *remaining_calls* must start with *call* (the one awaiting approval),
        because ``resume_execution`` treats ``pending_calls[0]`` as the paused call.
        """
        fn = call["function"]
        return PauseExecution(
            agent_name=self.name,
            tool_name=fn["name"],
            tool_arguments=fn["arguments"],
            tool_call_id=call["id"],
            messages=messages,
            total_tokens=total_tokens,
            total_cost=total_cost,
            model_name=model_name,
            trace=trace,
            pending_calls=remaining_calls,
            seen_calls=[[n, a] for n, a in seen_calls],
            iterations_used=iteration,
        )

    @staticmethod
    def _record_observation(
        messages: list[dict[str, Any]],
        trace: list[dict[str, Any]],
        call_id: str,
        name: str,
        arguments: str,
        output: str,
    ) -> None:
        """Append one tool result to the trace and to the conversation."""
        # B4: Truncate oversized tool results.
        output = _truncate_output(output)
        trace.append({"tool": name, "arguments": arguments, "result": output})
        messages.append({"role": "tool", "tool_call_id": call_id, "content": output})

    async def _settle_batch(
        self,
        in_flight: list[tuple[str, asyncio.Task[tuple[str, str, str, str]]]],
        results_map: dict[str, tuple[str, str, str, str]],
        call_order: list[str],
        messages: list[dict[str, Any]],
        trace: list[dict[str, Any]],
    ) -> None:
        """Wait for dispatched tool tasks and record results in request order."""
        # The tasks were started with create_task, so they already run
        # concurrently; awaiting them one by one only collects the results.
        for call_id, task in in_flight:
            results_map[call_id] = await task
        for call_id in call_order:
            _, name, arguments, output = results_map[call_id]
            self._record_observation(messages, trace, call_id, name, arguments, output)

    @staticmethod
    def _trim_messages(messages: list[dict[str, Any]]) -> None:
        """Drop the oldest tool exchanges once the conversation is too long (B2).

        The first two messages are always kept; this assumes they are the
        system and user messages, as built by ``execute``. The cut is widened
        to the end of an assistant/tool group so that no ``tool`` message is
        left without the assistant message that requested it.
        """
        size = len(messages)
        overflow = size - MESSAGES_MAX_LENGTH
        if overflow <= 0:
            return
        cut_end = 2 + overflow
        while cut_end < size and messages[cut_end].get("role") == "tool":
            cut_end += 1
        if cut_end == size:
            # The newest batch alone overflows the window. Keep it whole rather
            # than discard observations the model has not seen yet: walk back
            # to the assistant message that opened the batch.
            cut_end = 2 + overflow
            while cut_end > 2 and messages[cut_end].get("role") == "tool":
                cut_end -= 1
        del messages[2:cut_end]

    async def _generate_with_retry(
        self, messages: list[dict[str, Any]], llm: LLM, iteration: int
    ) -> dict[str, Any]:
        """Call the LLM with tools, retrying failures within one iteration (B7).

        Raises:
            AgentError: When every attempt for this iteration has failed.
        """
        for retry in range(LLM_RETRIES_PER_ITERATION + 1):
            try:
                return await llm.generate(messages, tools=self._openai_tools)
            except Exception as exc:
                if retry == LLM_RETRIES_PER_ITERATION:
                    raise AgentError(
                        self.name,
                        f"LLM call failed after {LLM_RETRIES_PER_ITERATION + 1} iteration-level attempts",
                    ) from exc
                _log.warning(
                    "LLM generation retry in agent %s (iteration %d, attempt %d)",
                    self.name, iteration + 1, retry + 1,
                )
                # Exponential back-off: 0.5s, 1s, 2s, ...
                await asyncio.sleep(0.5 * (2 ** retry))
        # Only reachable if the retry constant is negative and no attempt ran.
        raise AgentError(self.name, "LLM call was not attempted: retry budget is misconfigured")

    async def _run_tool_loop(
        self,
        messages: list[dict[str, Any]],
        llm: LLM,
        start_iteration: int = 0,
        initial_total_tokens: int = 0,
        initial_total_cost: float = 0.0,
        initial_trace: list[dict[str, Any]] | None = None,
        initial_seen_calls: set[tuple[str, str]] | None = None,
    ) -> tuple[str, int, float, str, list[dict[str, Any]]]:
        """Drive a ReAct-style loop: call the LLM, run any requested tools, repeat.

        Features:
        - Multiple tool calls run concurrently as asyncio tasks (B1).
        - Duplicate tool calls are detected, skipped, and reported as errors (B3).
        - Tool outputs > 5000 chars are truncated (B4).
        - Message list is trimmed to prevent context overflow (B2).
        - Transient LLM errors are retried once per iteration (B7).

        When *start_iteration* > 0 the loop resumes from a prior
        :class:`~agentflow.hitl.PauseExecution`, preserving accumulated tokens,
        cost, trace, and seen-calls state.

        The returned token and cost totals are cumulative: they already include
        *initial_total_tokens* and *initial_total_cost*.

        Returns:
            (final_content, total_tokens, total_cost, model_name, tool_call_trace)

        Raises:
            AgentError: If the LLM keeps failing or no final answer is produced
                within ``max_tool_iterations``.
            PauseExecution: If a tool call requires human approval.
        """
        tool_map = {t.name: t for t in self._tools}
        total_tokens = initial_total_tokens
        total_cost = initial_total_cost
        model_name = llm.model
        trace: list[dict[str, Any]] = [] if initial_trace is None else initial_trace
        seen_calls: set[tuple[str, str]] = (
            set() if initial_seen_calls is None else initial_seen_calls
        )

        for iteration in range(start_iteration, self._max_tool_iterations):
            response = await self._generate_with_retry(messages, llm, iteration)

            total_tokens += response["tokens"]
            total_cost += response.get("cost", 0.0)
            model_name = response["model"]
            tool_calls = response["tool_calls"]

            if not tool_calls:
                return response["content"], total_tokens, total_cost, model_name, trace

            messages.append(
                {
                    "role": "assistant",
                    "content": response["content"] or None,
                    "tool_calls": tool_calls,
                }
            )

            # B3: Separate duplicate calls from new ones; duplicates get an
            # immediate error observation without being re-executed.
            results_map: dict[str, tuple[str, str, str, str]] = {}
            in_flight: list[tuple[str, asyncio.Task[tuple[str, str, str, str]]]] = []
            call_order: list[str] = []

            for index, call in enumerate(tool_calls):
                fn = call["function"]
                name = fn["name"]
                arguments = fn["arguments"]
                key = (name, arguments)

                if key in seen_calls:
                    dup_msg = (
                        f"Error: Duplicate tool call detected. You already called "
                        f"'{name}' with these arguments. Try a different approach."
                    )
                    results_map[call["id"]] = (call["id"], name, arguments, dup_msg)
                    _log.warning(
                        "Duplicate tool call blocked in agent %s: %s(%s)",
                        self.name, name, arguments[:120],
                    )
                elif self._needs_approval(name, arguments):
                    # HITL: this call must be approved first. Earlier calls of
                    # the same batch may already be running; collect and record
                    # their results so they are not lost and every earlier
                    # tool_call id has its tool response in the saved messages.
                    await self._settle_batch(in_flight, results_map, call_order, messages, trace)
                    # Hand over every remaining call, including already-seen
                    # ones, so resume_execution can answer each of them.
                    raise self._pause_for_approval(
                        call=call,
                        remaining_calls=list(tool_calls[index:]),
                        messages=messages,
                        total_tokens=total_tokens,
                        total_cost=total_cost,
                        model_name=model_name,
                        trace=trace,
                        seen_calls=seen_calls,
                        iteration=iteration,
                    )
                else:
                    seen_calls.add(key)
                    in_flight.append(
                        (call["id"], asyncio.create_task(self._execute_single_tool(call, tool_map)))
                    )
                call_order.append(call["id"])

            # B1: Await all unique tool calls, then append results in the
            # original tool_calls order.
            await self._settle_batch(in_flight, results_map, call_order, messages, trace)

            # B2: Sliding window over the conversation.
            self._trim_messages(messages)

        raise AgentError(
            self.name,
            f"exceeded max_tool_iterations={self._max_tool_iterations} without a final answer",
        )

    async def _execute_single_tool(
        self, call: dict[str, Any], tool_map: dict[str, Any]
    ) -> tuple[str, str, str, str]:
        """Execute a single tool call and return (call_id, name, arguments, output).

        Failures never propagate: an unknown tool name or an exception raised
        by the tool is turned into an ``Error: ...`` output string so the model
        can observe it.

        B6: Execution is logged for observability.
        """
        fn = call["function"]
        name = fn["name"]
        arguments = fn["arguments"]
        target = tool_map.get(name)

        if target is None:
            output = f"Error: unknown tool '{name}'"
            _log.error("Unknown tool called in agent %s: %s", self.name, name)
        else:
            _log.info("Tool executing in agent %s: %s(%s)", self.name, name, arguments[:120])
            try:
                output = await target.acall(arguments)
                _log.info(
                    "Tool completed in agent %s: %s (output %d chars)",
                    self.name, name, len(output),
                )
            except ToolError as e:
                output = f"Error: {e}"
                _log.error("Tool error in agent %s: %s — %s", self.name, name, e)
            except Exception as e:
                output = f"Error: unexpected exception - {str(e)}"
                _log.error(
                    "Unexpected exception in agent %s tool %s: %s", self.name, name, e,
                )

        return call["id"], name, arguments, output

    async def resume_execution(
        self,
        pause_data: dict[str, Any],
        llm: LLM,
        approved: bool,
        human_feedback: str = "",
    ) -> AgentResult:
        """Resume execution after a :class:`~agentflow.hitl.PauseExecution`.

        Applies the human decision to the paused tool call, processes any
        remaining tool calls from the same LLM response batch, then re-enters
        the ReAct loop from where it left off.

        Args:
            pause_data: Serialized state from ``PauseExecution.as_dict()``.
                Its ``messages`` and ``trace`` lists are extended in place.
            llm: The LLM provider for continued generation.
            approved: ``True`` to execute the pending tool; ``False`` to inject
                *human_feedback* as an error observation instead.
            human_feedback: Contextual message injected when *approved* is
                ``False`` so the agent can self-correct.

        Returns:
            An ``AgentResult`` with the final agent output. Token and cost
            totals cover the whole run, before and after the pause.

        Raises:
            PauseExecution: If another call, in this batch or a later one,
                requires approval.
            AgentError: If the continued tool loop fails.
            AgentOutputValidationError: If the output fails schema validation.
        """
        start = time.perf_counter()
        tool_map = {t.name: t for t in self._tools}
        messages: list[dict[str, Any]] = pause_data["messages"]
        pending_calls: list[dict[str, Any]] = pause_data.get("pending_calls", [])
        trace: list[dict[str, Any]] = pause_data["trace"]
        seen_calls: set[tuple[str, str]] = {
            tuple(p) for p in pause_data.get("seen_calls", [])
        }

        if pending_calls:
            paused_call = pending_calls[0]
            paused_fn = paused_call["function"]
            if approved:
                # Record the call like any other executed call so duplicate
                # detection also covers it.
                seen_calls.add((paused_fn["name"], paused_fn["arguments"]))
                call_id, name, args, output = await self._execute_single_tool(
                    paused_call, tool_map
                )
                self._record_observation(messages, trace, call_id, name, args, output)
            else:
                messages.append(
                    {"role": "tool", "tool_call_id": paused_call["id"], "content": human_feedback}
                )
                trace.append(
                    {
                        "tool": paused_fn["name"],
                        "arguments": paused_fn["arguments"],
                        "result": f"[HUMAN REJECTED] {human_feedback}",
                    }
                )

            for offset, call in enumerate(pending_calls[1:], start=1):
                fn = call["function"]
                name = fn["name"]
                arguments = fn["arguments"]
                key = (name, arguments)
                if key in seen_calls:
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": call["id"],
                            "content": (
                                f"Error: Duplicate tool call detected. You already called "
                                f"'{fn['name']}' with these arguments."
                            ),
                        }
                    )
                    continue
                if self._needs_approval(name, arguments):
                    # Approving the first call must not silently approve the
                    # rest of the batch: pause again for this one.
                    raise self._pause_for_approval(
                        call=call,
                        remaining_calls=list(pending_calls[offset:]),
                        messages=messages,
                        total_tokens=pause_data["total_tokens"],
                        total_cost=pause_data["total_cost"],
                        # NOTE(review): assumes as_dict() stores "model_name";
                        # falls back to the provider's model otherwise.
                        model_name=pause_data.get("model_name", llm.model),
                        trace=trace,
                        seen_calls=seen_calls,
                        iteration=pause_data["iterations_used"],
                    )
                seen_calls.add(key)
                cid, tname, targs, output = await self._execute_single_tool(call, tool_map)
                self._record_observation(messages, trace, cid, tname, targs, output)

        # The loop is seeded with the totals accumulated before the pause and
        # returns cumulative values, so they must not be added a second time.
        content, total_tokens, total_cost, model_name, final_trace = await self._run_tool_loop(
            messages,
            llm,
            start_iteration=pause_data["iterations_used"] + 1,
            initial_total_tokens=pause_data["total_tokens"],
            initial_total_cost=pause_data["total_cost"],
            initial_trace=trace,
            initial_seen_calls=seen_calls,
        )

        if self._memory is not None and content:
            await self._memory.save_context(self._session_id, self.name, content)

        metadata: dict[str, Any] = {
            "model": model_name,
            "tool_calls": final_trace,
        }
        if self._output_schema is not None:
            try:
                validated = self._output_schema.model_validate_json(content)
            except Exception as e:
                raise AgentOutputValidationError(self.name, str(e)) from e
            # Mirror execute(): expose the parsed object to callers.
            metadata["validated_output"] = validated.model_dump()

        duration = time.perf_counter() - start
        return AgentResult(
            agent=self.name,
            output=content,
            tokens_used=total_tokens,
            cost=round(total_cost, 6),
            duration=round(duration, 3),
            cached=False,
            metadata=metadata,
        )

    def __repr__(self) -> str:
        return f"Agent(name={self.name!r}, role={self.role!r})"
467
+
468
+
469
class Agent:
    """Decorator that turns an async prompt function into an agent.

    The wrapped function is called with ``(task, context)`` and must return
    the user message that will be sent to the LLM.

    Args:
        name: Unique identifier for the agent within a pipeline.
        role: Describes the agent's persona (used as system prompt context).
        output_schema: Optional Pydantic BaseModel subclass. When given, the
            LLM response has to be valid JSON matching the schema, otherwise
            AgentOutputValidationError is raised.
        tools: Optional list of ``Tool`` objects (see :func:`agentflow.tool`).
            With tools the agent runs a ReAct loop in which the model may call
            tools and observe results until it produces a final answer.
        max_tool_iterations: Safety cap on tool-calling rounds (default 6).
        memory: Optional memory backend handed to the created agent.

    Usage:
        @Agent(name="researcher", role="Research Analyst")
        async def researcher(task: str, context: dict) -> str:
            return f"Research this topic: {task}"

        # With tools:
        @tool
        async def search(query: str) -> str:
            "Search the web."
            ...

        @Agent(name="assistant", role="Assistant", tools=[search])
        async def assistant(task: str, context: dict) -> str:
            return task

        # With structured output:
        class Summary(BaseModel):
            title: str
            points: list[str]

        @Agent(name="summarizer", role="Summarizer", output_schema=Summary)
        async def summarizer(task: str, context: dict) -> str:
            return f"Summarize as JSON: {task}"
    """

    def __init__(
        self,
        name: str,
        role: str,
        output_schema: type[BaseModel] | None = None,
        tools: list[Tool] | None = None,
        max_tool_iterations: int = DEFAULT_MAX_TOOL_ITERATIONS,
        memory: BaseMemory | None = None,
    ):
        """Remember the settings to apply when the decorator is used."""
        self.name, self.role = name, role
        self._output_schema = output_schema
        self._tools = tools
        self._max_tool_iterations = max_tool_iterations
        self._memory = memory

    def __call__(self, fn: Callable[..., Awaitable[str]]) -> _DecoratorAgent:
        """Wrap *fn* and return the agent object that replaces it."""
        return _DecoratorAgent(
            name=self.name,
            role=self.role,
            prompt_fn=fn,
            output_schema=self._output_schema,
            tools=self._tools,
            max_tool_iterations=self._max_tool_iterations,
            memory=self._memory,
        )