agentrun-sdk 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agentrun-sdk might be problematic. Click here for more details.
- agentrun_operation_sdk/cli/__init__.py +1 -0
- agentrun_operation_sdk/cli/cli.py +19 -0
- agentrun_operation_sdk/cli/common.py +21 -0
- agentrun_operation_sdk/cli/runtime/__init__.py +1 -0
- agentrun_operation_sdk/cli/runtime/commands.py +203 -0
- agentrun_operation_sdk/client/client.py +75 -0
- agentrun_operation_sdk/operations/runtime/__init__.py +8 -0
- agentrun_operation_sdk/operations/runtime/configure.py +101 -0
- agentrun_operation_sdk/operations/runtime/launch.py +82 -0
- agentrun_operation_sdk/operations/runtime/models.py +31 -0
- agentrun_operation_sdk/services/runtime.py +152 -0
- agentrun_operation_sdk/utils/logging_config.py +72 -0
- agentrun_operation_sdk/utils/runtime/config.py +94 -0
- agentrun_operation_sdk/utils/runtime/container.py +280 -0
- agentrun_operation_sdk/utils/runtime/entrypoint.py +203 -0
- agentrun_operation_sdk/utils/runtime/schema.py +56 -0
- agentrun_sdk/__init__.py +7 -0
- agentrun_sdk/agent/__init__.py +25 -0
- agentrun_sdk/agent/agent.py +696 -0
- agentrun_sdk/agent/agent_result.py +46 -0
- agentrun_sdk/agent/conversation_manager/__init__.py +26 -0
- agentrun_sdk/agent/conversation_manager/conversation_manager.py +88 -0
- agentrun_sdk/agent/conversation_manager/null_conversation_manager.py +46 -0
- agentrun_sdk/agent/conversation_manager/sliding_window_conversation_manager.py +179 -0
- agentrun_sdk/agent/conversation_manager/summarizing_conversation_manager.py +252 -0
- agentrun_sdk/agent/state.py +97 -0
- agentrun_sdk/event_loop/__init__.py +9 -0
- agentrun_sdk/event_loop/event_loop.py +499 -0
- agentrun_sdk/event_loop/streaming.py +319 -0
- agentrun_sdk/experimental/__init__.py +4 -0
- agentrun_sdk/experimental/hooks/__init__.py +15 -0
- agentrun_sdk/experimental/hooks/events.py +123 -0
- agentrun_sdk/handlers/__init__.py +10 -0
- agentrun_sdk/handlers/callback_handler.py +70 -0
- agentrun_sdk/hooks/__init__.py +49 -0
- agentrun_sdk/hooks/events.py +80 -0
- agentrun_sdk/hooks/registry.py +247 -0
- agentrun_sdk/models/__init__.py +10 -0
- agentrun_sdk/models/anthropic.py +432 -0
- agentrun_sdk/models/bedrock.py +649 -0
- agentrun_sdk/models/litellm.py +225 -0
- agentrun_sdk/models/llamaapi.py +438 -0
- agentrun_sdk/models/mistral.py +539 -0
- agentrun_sdk/models/model.py +95 -0
- agentrun_sdk/models/ollama.py +357 -0
- agentrun_sdk/models/openai.py +436 -0
- agentrun_sdk/models/sagemaker.py +598 -0
- agentrun_sdk/models/writer.py +449 -0
- agentrun_sdk/multiagent/__init__.py +22 -0
- agentrun_sdk/multiagent/a2a/__init__.py +15 -0
- agentrun_sdk/multiagent/a2a/executor.py +148 -0
- agentrun_sdk/multiagent/a2a/server.py +252 -0
- agentrun_sdk/multiagent/base.py +92 -0
- agentrun_sdk/multiagent/graph.py +555 -0
- agentrun_sdk/multiagent/swarm.py +656 -0
- agentrun_sdk/py.typed +1 -0
- agentrun_sdk/session/__init__.py +18 -0
- agentrun_sdk/session/file_session_manager.py +216 -0
- agentrun_sdk/session/repository_session_manager.py +152 -0
- agentrun_sdk/session/s3_session_manager.py +272 -0
- agentrun_sdk/session/session_manager.py +73 -0
- agentrun_sdk/session/session_repository.py +51 -0
- agentrun_sdk/telemetry/__init__.py +21 -0
- agentrun_sdk/telemetry/config.py +194 -0
- agentrun_sdk/telemetry/metrics.py +476 -0
- agentrun_sdk/telemetry/metrics_constants.py +15 -0
- agentrun_sdk/telemetry/tracer.py +563 -0
- agentrun_sdk/tools/__init__.py +17 -0
- agentrun_sdk/tools/decorator.py +569 -0
- agentrun_sdk/tools/executor.py +137 -0
- agentrun_sdk/tools/loader.py +152 -0
- agentrun_sdk/tools/mcp/__init__.py +13 -0
- agentrun_sdk/tools/mcp/mcp_agent_tool.py +99 -0
- agentrun_sdk/tools/mcp/mcp_client.py +423 -0
- agentrun_sdk/tools/mcp/mcp_instrumentation.py +322 -0
- agentrun_sdk/tools/mcp/mcp_types.py +63 -0
- agentrun_sdk/tools/registry.py +607 -0
- agentrun_sdk/tools/structured_output.py +421 -0
- agentrun_sdk/tools/tools.py +217 -0
- agentrun_sdk/tools/watcher.py +136 -0
- agentrun_sdk/types/__init__.py +5 -0
- agentrun_sdk/types/collections.py +23 -0
- agentrun_sdk/types/content.py +188 -0
- agentrun_sdk/types/event_loop.py +48 -0
- agentrun_sdk/types/exceptions.py +81 -0
- agentrun_sdk/types/guardrails.py +254 -0
- agentrun_sdk/types/media.py +89 -0
- agentrun_sdk/types/session.py +152 -0
- agentrun_sdk/types/streaming.py +201 -0
- agentrun_sdk/types/tools.py +258 -0
- agentrun_sdk/types/traces.py +5 -0
- agentrun_sdk-0.1.2.dist-info/METADATA +51 -0
- agentrun_sdk-0.1.2.dist-info/RECORD +115 -0
- agentrun_sdk-0.1.2.dist-info/WHEEL +5 -0
- agentrun_sdk-0.1.2.dist-info/entry_points.txt +2 -0
- agentrun_sdk-0.1.2.dist-info/top_level.txt +3 -0
- agentrun_wrapper/__init__.py +11 -0
- agentrun_wrapper/_utils/__init__.py +6 -0
- agentrun_wrapper/_utils/endpoints.py +16 -0
- agentrun_wrapper/identity/__init__.py +5 -0
- agentrun_wrapper/identity/auth.py +211 -0
- agentrun_wrapper/memory/__init__.py +6 -0
- agentrun_wrapper/memory/client.py +1697 -0
- agentrun_wrapper/memory/constants.py +103 -0
- agentrun_wrapper/memory/controlplane.py +626 -0
- agentrun_wrapper/py.typed +1 -0
- agentrun_wrapper/runtime/__init__.py +13 -0
- agentrun_wrapper/runtime/app.py +473 -0
- agentrun_wrapper/runtime/context.py +34 -0
- agentrun_wrapper/runtime/models.py +25 -0
- agentrun_wrapper/services/__init__.py +1 -0
- agentrun_wrapper/services/identity.py +192 -0
- agentrun_wrapper/tools/__init__.py +6 -0
- agentrun_wrapper/tools/browser_client.py +325 -0
- agentrun_wrapper/tools/code_interpreter_client.py +186 -0
|
@@ -0,0 +1,555 @@
|
|
|
1
|
+
"""Directed Acyclic Graph (DAG) Multi-Agent Pattern Implementation.
|
|
2
|
+
|
|
3
|
+
This module provides a deterministic DAG-based agent orchestration system where
|
|
4
|
+
agents or MultiAgentBase instances (like Swarm or Graph) are nodes in a graph,
|
|
5
|
+
executed according to edge dependencies, with output from one node passed as input
|
|
6
|
+
to connected nodes.
|
|
7
|
+
|
|
8
|
+
Key Features:
|
|
9
|
+
- Agents and MultiAgentBase instances (Swarm, Graph, etc.) as graph nodes
|
|
10
|
+
- Deterministic execution order based on DAG structure
|
|
11
|
+
- Output propagation along edges
|
|
12
|
+
- Topological sort for execution ordering
|
|
13
|
+
- Clear dependency management
|
|
14
|
+
- Supports nested graphs (Graph as a node in another Graph)
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import logging
|
|
19
|
+
import time
|
|
20
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
21
|
+
from dataclasses import dataclass, field
|
|
22
|
+
from typing import Any, Callable, Tuple
|
|
23
|
+
|
|
24
|
+
from opentelemetry import trace as trace_api
|
|
25
|
+
|
|
26
|
+
from ..agent import Agent
|
|
27
|
+
from ..telemetry import get_tracer
|
|
28
|
+
from ..types.content import ContentBlock
|
|
29
|
+
from ..types.event_loop import Metrics, Usage
|
|
30
|
+
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class GraphState:
    """Mutable bookkeeping for a single graph run.

    A fresh instance is created at the start of every ``Graph.invoke_async`` call
    and is handed to edge conditions so they can inspect progress so far.

    Attributes:
        task: The original input prompt/query given to the graph. Nodes without any
            satisfied dependency receive this task as their input.
        status: Current execution status of the graph as a whole.
        completed_nodes: Nodes that finished executing successfully.
        failed_nodes: Nodes whose execution raised an exception.
        execution_order: Successfully completed nodes, in completion order.
        results: Per-node results keyed by ``node_id`` (failed nodes included).
        accumulated_usage: Token usage summed over completed nodes.
        accumulated_metrics: Latency metrics summed over completed nodes.
        execution_count: Sum of the execution counts reported by completed nodes.
        execution_time: Wall-clock duration of the graph run in milliseconds.
        total_nodes: Number of nodes in the graph.
        edges: ``(from_node, to_node)`` pairs describing the graph structure.
        entry_points: Nodes the run starts from.
    """

    # --- Input -----------------------------------------------------------
    task: str | list[ContentBlock] = ""

    # --- Progress tracking -----------------------------------------------
    status: Status = Status.PENDING
    completed_nodes: set["GraphNode"] = field(default_factory=set)
    failed_nodes: set["GraphNode"] = field(default_factory=set)
    execution_order: list["GraphNode"] = field(default_factory=list)

    # --- Per-node outputs, keyed by node_id ------------------------------
    results: dict[str, NodeResult] = field(default_factory=dict)

    # --- Aggregated metrics (each instance gets its own zeroed counters) -
    accumulated_usage: Usage = field(
        default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)
    )
    accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
    execution_count: int = 0
    execution_time: int = 0

    # --- Static description of the graph being executed ------------------
    total_nodes: int = 0
    edges: list[Tuple["GraphNode", "GraphNode"]] = field(default_factory=list)
    entry_points: list["GraphNode"] = field(default_factory=list)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass
class GraphResult(MultiAgentResult):
    """Outcome of a graph run: the MultiAgentResult fields plus graph-specific details.

    Attributes:
        total_nodes: Number of nodes in the graph.
        completed_nodes: Count of nodes that completed successfully.
        failed_nodes: Count of nodes that failed.
        execution_order: Completed nodes in the order they finished.
        edges: ``(from_node, to_node)`` pairs describing the graph structure.
        entry_points: Nodes the run started from.
    """

    # Node counters
    total_nodes: int = 0
    completed_nodes: int = 0
    failed_nodes: int = 0

    # Structure and ordering information copied from the run's GraphState
    execution_order: list["GraphNode"] = field(default_factory=list)
    edges: list[Tuple["GraphNode", "GraphNode"]] = field(default_factory=list)
    entry_points: list["GraphNode"] = field(default_factory=list)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass
class GraphEdge:
    """A directed connection between two nodes, optionally gated by a predicate.

    Attributes:
        from_node: Node whose output flows along this edge.
        to_node: Node that receives the output.
        condition: Optional predicate over the current GraphState; when ``None``
            the edge is always traversable.
    """

    from_node: "GraphNode"
    to_node: "GraphNode"
    condition: Callable[[GraphState], bool] | None = None

    def __hash__(self) -> int:
        """Hash the edge by the ids of its two endpoints (the condition is ignored)."""
        endpoint_ids = (self.from_node.node_id, self.to_node.node_id)
        return hash(endpoint_ids)

    def should_traverse(self, state: GraphState) -> bool:
        """Report whether this edge may be followed given the current graph state.

        Args:
            state: Graph execution state passed to the condition callable.

        Returns:
            ``True`` for unconditional edges, otherwise whatever the condition returns.
        """
        predicate = self.condition
        return True if predicate is None else predicate(state)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@dataclass
class GraphNode:
    """A single unit of work inside a graph.

    ``execution_status`` follows the node through graph orchestration:
    - PENDING: the node has not started yet
    - EXECUTING: the node is currently running
    - COMPLETED/FAILED: the node has finished (independent of output quality)

    Identity is defined solely by ``node_id``: two nodes with the same id compare
    equal and hash identically, whatever executor they wrap.

    Attributes:
        node_id: Unique identifier of the node within its graph.
        executor: The Agent or MultiAgentBase instance run by this node.
        dependencies: Nodes with an edge pointing at this node.
        execution_status: Lifecycle status described above.
        result: Result recorded by the last execution, if any.
        execution_time: Duration of the last execution in milliseconds.
    """

    node_id: str
    executor: Agent | MultiAgentBase
    dependencies: set["GraphNode"] = field(default_factory=set)
    execution_status: Status = Status.PENDING
    result: NodeResult | None = None
    execution_time: int = 0

    def __hash__(self) -> int:
        """Hash the node by its ``node_id`` so it can live in sets and dict keys."""
        return hash(self.node_id)

    def __eq__(self, other: Any) -> bool:
        """Compare by ``node_id``; anything that is not a GraphNode is unequal."""
        return isinstance(other, GraphNode) and self.node_id == other.node_id
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _validate_node_executor(
    executor: Agent | MultiAgentBase, existing_nodes: dict[str, GraphNode] | None = None
) -> None:
    """Ensure an executor can be used as a graph node.

    Args:
        executor: The Agent or MultiAgentBase instance to check.
        existing_nodes: Already registered nodes; when given (and non-empty) the
            executor must not be the very same object as any of their executors.

    Raises:
        ValueError: If the executor object is already used by an existing node, or
            if it is an Agent with a session manager or registered hook callbacks.
    """
    # Object identity (not equality) decides what counts as a duplicate.
    if existing_nodes and any(registered.executor is executor for registered in existing_nodes.values()):
        raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.")

    # The remaining constraints only apply to Agent executors.
    if not isinstance(executor, Agent):
        return

    # Session persistence is rejected for graph agents.
    if executor._session_manager is not None:
        raise ValueError("Session persistence is not supported for Graph agents yet.")

    # Hook callbacks are rejected for graph agents.
    if executor.hooks.has_callbacks():
        raise ValueError("Agent callbacks are not supported for Graph agents yet.")
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class GraphBuilder:
    """Builder pattern for constructing graphs.

    Typical use: call ``add_node`` for every executor, ``add_edge`` for every
    dependency, optionally ``set_entry_point``, then ``build`` to obtain a
    validated ``Graph``.
    """

    def __init__(self) -> None:
        """Initialize GraphBuilder with empty collections."""
        self.nodes: dict[str, GraphNode] = {}
        self.edges: set[GraphEdge] = set()
        self.entry_points: set[GraphNode] = set()

    def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode:
        """Add an Agent or MultiAgentBase instance as a node to the graph.

        Args:
            executor: The Agent or MultiAgentBase instance the node will run.
            node_id: Identifier for the node. When omitted, the executor's ``id``
                attribute, then its ``name`` attribute, then ``node_<index>`` is used.

        Returns:
            The newly created GraphNode.

        Raises:
            ValueError: If the executor fails validation or ``node_id`` is already taken.
        """
        _validate_node_executor(executor, self.nodes)

        # Auto-generate node_id if not provided
        if node_id is None:
            node_id = getattr(executor, "id", None) or getattr(executor, "name", None) or f"node_{len(self.nodes)}"

        if node_id in self.nodes:
            raise ValueError(f"Node '{node_id}' already exists")

        node = GraphNode(node_id=node_id, executor=executor)
        self.nodes[node_id] = node
        return node

    def add_edge(
        self,
        from_node: str | GraphNode,
        to_node: str | GraphNode,
        condition: Callable[[GraphState], bool] | None = None,
    ) -> GraphEdge:
        """Add an edge between two nodes with optional condition function that receives full GraphState.

        Args:
            from_node: Source node, given as a node id or a GraphNode.
            to_node: Target node, given as a node id or a GraphNode.
            condition: Optional predicate over GraphState gating traversal of the edge.

        Returns:
            The created GraphEdge.

        Raises:
            ValueError: If either endpoint has not been added to this builder.
        """

        def resolve_node(node: str | GraphNode, node_type: str) -> GraphNode:
            if isinstance(node, str):
                if node not in self.nodes:
                    raise ValueError(f"{node_type} node '{node}' not found")
                return self.nodes[node]

            # Always hand back the instance owned by this builder. GraphNode equality
            # is id-based, so an equal-looking node from elsewhere must not receive
            # the dependency bookkeeping meant for the registered node.
            registered = self.nodes.get(node.node_id) if isinstance(node, GraphNode) else None
            if registered is None:
                raise ValueError(f"{node_type} node object has not been added to the graph, use graph.add_node")
            return registered

        from_node_obj = resolve_node(from_node, "Source")
        to_node_obj = resolve_node(to_node, "Target")

        # Add edge and update dependencies
        edge = GraphEdge(from_node=from_node_obj, to_node=to_node_obj, condition=condition)
        self.edges.add(edge)
        to_node_obj.dependencies.add(from_node_obj)
        return edge

    def set_entry_point(self, node_id: str) -> "GraphBuilder":
        """Set a node as an entry point for graph execution.

        Args:
            node_id: Id of a node previously added with ``add_node``.

        Returns:
            This builder, to allow call chaining.

        Raises:
            ValueError: If no node with ``node_id`` exists.
        """
        if node_id not in self.nodes:
            raise ValueError(f"Node '{node_id}' not found")
        self.entry_points.add(self.nodes[node_id])
        return self

    def build(self) -> "Graph":
        """Build and validate the graph.

        Returns:
            A Graph constructed from copies of the builder's node, edge and
            entry-point collections.

        Raises:
            ValueError: If the builder has no nodes, no entry point can be
                determined, or the structure fails validation (e.g. contains a cycle).
        """
        if not self.nodes:
            raise ValueError("Graph must contain at least one node")

        # Auto-detect entry points if none specified
        if not self.entry_points:
            self.entry_points = {node for node in self.nodes.values() if not node.dependencies}
            logger.debug(
                "entry_points=<%s> | auto-detected entrypoints", ", ".join(node.node_id for node in self.entry_points)
            )
            if not self.entry_points:
                raise ValueError("No entry points found - all nodes have dependencies")

        # Validate entry points and check for cycles
        self._validate_graph()

        return Graph(nodes=self.nodes.copy(), edges=self.edges.copy(), entry_points=self.entry_points.copy())

    def _validate_graph(self) -> None:
        """Validate graph structure and detect cycles.

        Raises:
            ValueError: If an entry point is not a registered node or the edges
                form a cycle.
        """
        # Validate entry points exist
        entry_point_ids = {node.node_id for node in self.entry_points}
        invalid_entries = entry_point_ids.difference(self.nodes)
        if invalid_entries:
            raise ValueError(f"Entry points not found in nodes: {invalid_entries}")

        # Build the adjacency map once so each edge is examined a single time
        # instead of rescanning every edge for every visited node.
        successors: dict[str, list[str]] = {}
        for edge in self.edges:
            successors.setdefault(edge.from_node.node_id, []).append(edge.to_node.node_id)

        # Check for cycles using DFS with color coding:
        # WHITE = unvisited, GRAY = on the current DFS path, BLACK = fully explored.
        WHITE, GRAY, BLACK = 0, 1, 2
        colors = {node_id: WHITE for node_id in self.nodes}

        # An explicit stack replaces recursion so long dependency chains cannot
        # exhaust the interpreter's recursion limit.
        for root_id in self.nodes:
            if colors[root_id] != WHITE:
                continue

            colors[root_id] = GRAY
            stack = [(root_id, iter(successors.get(root_id, ())))]
            while stack:
                current_id, remaining_children = stack[-1]
                for child_id in remaining_children:
                    if colors[child_id] == GRAY:
                        # Back edge to a node on the current path - cycle detected
                        raise ValueError("Graph contains cycles - must be a directed acyclic graph")
                    if colors[child_id] == WHITE:
                        colors[child_id] = GRAY
                        stack.append((child_id, iter(successors.get(child_id, ()))))
                        break
                else:
                    # All children explored without finding a back edge
                    colors[current_id] = BLACK
                    stack.pop()
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class Graph(MultiAgentBase):
    """Directed Acyclic Graph multi-agent orchestration.

    Execution proceeds in batches. The first batch is the set of entry points;
    every node of a batch runs concurrently as an asyncio task. After a batch
    finishes, every node that is neither completed nor failed and has at least
    one completed predecessor whose edge condition passes forms the next batch.
    """

    def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_points: set[GraphNode]) -> None:
        """Initialize Graph.

        Args:
            nodes: Graph nodes keyed by node id.
            edges: Directed edges between the nodes.
            entry_points: Nodes executed first.

        Raises:
            ValueError: If two nodes share the same executor object or an executor
                fails ``_validate_node_executor``.
        """
        super().__init__()

        # Validate nodes for duplicate instances
        self._validate_graph(nodes)

        self.nodes = nodes
        self.edges = edges
        self.entry_points = entry_points
        self.state = GraphState()
        self.tracer = get_tracer()

    def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
        """Invoke the graph synchronously.

        The coroutine is driven by ``asyncio.run`` on a worker thread, so the
        event loop it creates lives in a thread that has no running loop.

        Args:
            task: The prompt or content blocks to process.
            **kwargs: Additional keyword arguments forwarded to ``invoke_async``.

        Returns:
            The GraphResult produced by ``invoke_async``.
        """

        def execute() -> GraphResult:
            # Forward kwargs instead of silently dropping them.
            return asyncio.run(self.invoke_async(task, **kwargs))

        with ThreadPoolExecutor() as executor:
            future = executor.submit(execute)
            return future.result()

    async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
        """Invoke the graph asynchronously.

        Args:
            task: The prompt or content blocks to process.
            **kwargs: Accepted for interface compatibility; not used by this method.

        Returns:
            A GraphResult built from the final execution state.

        Raises:
            Exception: Any exception raised during graph execution is re-raised
                after the graph status has been set to FAILED.
        """
        logger.debug("task=<%s> | starting graph execution", task)

        # Initialize state
        self.state = GraphState(
            status=Status.EXECUTING,
            task=task,
            total_nodes=len(self.nodes),
            edges=[(edge.from_node, edge.to_node) for edge in self.edges],
            entry_points=list(self.entry_points),
        )

        start_time = time.time()
        span = self.tracer.start_multiagent_span(task, "graph")
        with trace_api.use_span(span, end_on_exit=True):
            try:
                await self._execute_graph()
                self.state.status = Status.COMPLETED
                logger.debug("status=<%s> | graph execution completed", self.state.status)

            except Exception:
                logger.exception("graph execution failed")
                self.state.status = Status.FAILED
                raise
            finally:
                # Record the duration (milliseconds) on success and failure alike
                self.state.execution_time = round((time.time() - start_time) * 1000)
            return self._build_result()

    def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
        """Validate graph nodes for duplicate instances.

        Args:
            nodes: Graph nodes keyed by node id.

        Raises:
            ValueError: If two nodes wrap the same executor object or an executor
                fails ``_validate_node_executor``.
        """
        # Check for duplicate node instances
        seen_instances = set()
        for node in nodes.values():
            if id(node.executor) in seen_instances:
                raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.")
            seen_instances.add(id(node.executor))

            # Validate Agent-specific constraints for each node
            _validate_node_executor(node.executor)

    async def _execute_graph(self) -> None:
        """Unified execution flow with conditional routing.

        Runs batches of ready nodes until no further node becomes ready. An
        exception raised by a node propagates out of this method.
        """
        ready_nodes = list(self.entry_points)

        while ready_nodes:
            current_batch = ready_nodes.copy()
            ready_nodes.clear()

            # Execute current batch of ready nodes concurrently
            tasks = [
                asyncio.create_task(self._execute_node(node))
                for node in current_batch
                if node not in self.state.completed_nodes
            ]

            # Tasks are already running; awaiting in order just collects them
            for pending in tasks:
                await pending

            # Find newly ready nodes after batch execution
            ready_nodes.extend(self._find_newly_ready_nodes())

    def _find_newly_ready_nodes(self) -> list["GraphNode"]:
        """Find nodes that became ready after the last execution.

        Returns:
            Nodes that are neither completed nor failed and satisfy
            ``_is_node_ready_with_conditions``.
        """
        return [
            node
            for node in self.nodes.values()
            if (
                node not in self.state.completed_nodes
                and node not in self.state.failed_nodes
                and self._is_node_ready_with_conditions(node)
            )
        ]

    def _is_node_ready_with_conditions(self, node: GraphNode) -> bool:
        """Check if a node is ready considering conditional edges.

        Args:
            node: The node to check.

        Returns:
            For nodes without incoming edges, whether the node is an entry point.
            Otherwise ``True`` as soon as one incoming edge has a completed source
            node and a passing condition.
        """
        # Get incoming edges to this node
        incoming_edges = [edge for edge in self.edges if edge.to_node == node]

        if not incoming_edges:
            return node in self.entry_points

        # Check if at least one incoming edge condition is satisfied
        for edge in incoming_edges:
            if edge.from_node in self.state.completed_nodes:
                if edge.should_traverse(self.state):
                    logger.debug(
                        "from=<%s>, to=<%s> | edge ready via satisfied condition", edge.from_node.node_id, node.node_id
                    )
                    return True
                else:
                    logger.debug(
                        "from=<%s>, to=<%s> | edge condition not satisfied", edge.from_node.node_id, node.node_id
                    )
        return False

    async def _execute_node(self, node: GraphNode) -> None:
        """Execute a single node with error handling.

        On success the node and the graph state are updated with the result and
        the metrics are accumulated. On failure a FAILED NodeResult wrapping the
        exception is recorded and the exception is re-raised.

        Args:
            node: The node to execute.

        Raises:
            ValueError: If the node's executor is neither a MultiAgentBase nor an Agent.
            Exception: Any exception raised by the executor.
        """
        node.execution_status = Status.EXECUTING
        logger.debug("node_id=<%s> | executing node", node.node_id)

        start_time = time.time()
        try:
            # Build node input from satisfied dependencies
            node_input = self._build_node_input(node)

            # Execute based on node type and create unified NodeResult
            if isinstance(node.executor, MultiAgentBase):
                multi_agent_result = await node.executor.invoke_async(node_input)

                # Create NodeResult with MultiAgentResult directly
                node_result = NodeResult(
                    result=multi_agent_result,  # type is MultiAgentResult
                    execution_time=multi_agent_result.execution_time,
                    status=Status.COMPLETED,
                    accumulated_usage=multi_agent_result.accumulated_usage,
                    accumulated_metrics=multi_agent_result.accumulated_metrics,
                    execution_count=multi_agent_result.execution_count,
                )

            elif isinstance(node.executor, Agent):
                agent_response = await node.executor.invoke_async(node_input)

                # Extract metrics from agent response, falling back to zeroed values
                usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
                metrics = Metrics(latencyMs=0)
                if hasattr(agent_response, "metrics") and agent_response.metrics:
                    if hasattr(agent_response.metrics, "accumulated_usage"):
                        usage = agent_response.metrics.accumulated_usage
                    if hasattr(agent_response.metrics, "accumulated_metrics"):
                        metrics = agent_response.metrics.accumulated_metrics

                node_result = NodeResult(
                    result=agent_response,  # type is AgentResult
                    execution_time=round((time.time() - start_time) * 1000),
                    status=Status.COMPLETED,
                    accumulated_usage=usage,
                    accumulated_metrics=metrics,
                    execution_count=1,
                )
            else:
                raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported")

            # Mark as completed
            node.execution_status = Status.COMPLETED
            node.result = node_result
            node.execution_time = node_result.execution_time
            self.state.completed_nodes.add(node)
            self.state.results[node.node_id] = node_result
            self.state.execution_order.append(node)

            # Accumulate metrics
            self._accumulate_metrics(node_result)

            logger.debug(
                "node_id=<%s>, execution_time=<%dms> | node completed successfully", node.node_id, node.execution_time
            )

        except Exception as e:
            logger.error("node_id=<%s>, error=<%s> | node failed", node.node_id, e)
            execution_time = round((time.time() - start_time) * 1000)

            # Create a NodeResult for the failed node
            node_result = NodeResult(
                result=e,  # Store exception as result
                execution_time=execution_time,
                status=Status.FAILED,
                accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0),
                accumulated_metrics=Metrics(latencyMs=execution_time),
                execution_count=1,
            )

            node.execution_status = Status.FAILED
            node.result = node_result
            node.execution_time = execution_time
            self.state.failed_nodes.add(node)
            self.state.results[node.node_id] = node_result  # Store in results for consistency

            raise

    def _accumulate_metrics(self, node_result: NodeResult) -> None:
        """Accumulate metrics from a node result.

        Args:
            node_result: Result whose usage, latency and execution count are added
                to the running totals in the graph state.
        """
        self.state.accumulated_usage["inputTokens"] += node_result.accumulated_usage.get("inputTokens", 0)
        self.state.accumulated_usage["outputTokens"] += node_result.accumulated_usage.get("outputTokens", 0)
        self.state.accumulated_usage["totalTokens"] += node_result.accumulated_usage.get("totalTokens", 0)
        self.state.accumulated_metrics["latencyMs"] += node_result.accumulated_metrics.get("latencyMs", 0)
        self.state.execution_count += node_result.execution_count

    def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
        """Build input text for a node based on dependency outputs.

        Args:
            node: The node whose input is being assembled.

        Returns:
            The original task alone when no incoming edge is satisfied; otherwise
            the task followed by the outputs of every satisfied dependency.

        Example formatted output:
        ```
        Original Task: Analyze the quarterly sales data and create a summary report

        Inputs from previous nodes:

        From data_processor:
          - Agent: Sales data processed successfully. Found 1,247 transactions totaling $89,432.
          - Agent: Key trends: 15% increase in Q3, top product category is Electronics.

        From validator:
          - Agent: Data validation complete. All records verified, no anomalies detected.
        ```
        """
        # Get satisfied dependencies
        dependency_results = {}
        for edge in self.edges:
            if (
                edge.to_node == node
                and edge.from_node in self.state.completed_nodes
                and edge.from_node.node_id in self.state.results
            ):
                if edge.should_traverse(self.state):
                    dependency_results[edge.from_node.node_id] = self.state.results[edge.from_node.node_id]

        if not dependency_results:
            # No dependencies - return task as ContentBlocks
            if isinstance(self.state.task, str):
                return [ContentBlock(text=self.state.task)]
            else:
                return self.state.task

        # Combine task with dependency outputs
        node_input = []

        # Add original task
        if isinstance(self.state.task, str):
            node_input.append(ContentBlock(text=f"Original Task: {self.state.task}"))
        else:
            # Add task content blocks with a prefix
            node_input.append(ContentBlock(text="Original Task:"))
            node_input.extend(self.state.task)

        # Add dependency outputs
        node_input.append(ContentBlock(text="\nInputs from previous nodes:"))

        for dep_id, node_result in dependency_results.items():
            node_input.append(ContentBlock(text=f"\nFrom {dep_id}:"))
            # Get all agent results from this node (flattened if nested)
            agent_results = node_result.get_agent_results()
            for result in agent_results:
                agent_name = getattr(result, "agent_name", "Agent")
                result_text = str(result)
                node_input.append(ContentBlock(text=f"  - {agent_name}: {result_text}"))

        return node_input

    def _build_result(self) -> GraphResult:
        """Build graph result from current state.

        Returns:
            A GraphResult populated from ``self.state``; completed and failed nodes
            are reported as counts.
        """
        return GraphResult(
            status=self.state.status,
            results=self.state.results,
            accumulated_usage=self.state.accumulated_usage,
            accumulated_metrics=self.state.accumulated_metrics,
            execution_count=self.state.execution_count,
            execution_time=self.state.execution_time,
            total_nodes=self.state.total_nodes,
            completed_nodes=len(self.state.completed_nodes),
            failed_nodes=len(self.state.failed_nodes),
            execution_order=self.state.execution_order,
            edges=self.state.edges,
            entry_points=self.state.entry_points,
        )
|