genxai-framework 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/commands/__init__.py +3 -1
- cli/commands/connector.py +309 -0
- cli/commands/workflow.py +80 -0
- cli/main.py +3 -1
- genxai/__init__.py +33 -0
- genxai/agents/__init__.py +8 -0
- genxai/agents/presets.py +53 -0
- genxai/connectors/__init__.py +10 -0
- genxai/connectors/config_store.py +106 -0
- genxai/connectors/github.py +117 -0
- genxai/connectors/google_workspace.py +124 -0
- genxai/connectors/jira.py +108 -0
- genxai/connectors/notion.py +97 -0
- genxai/connectors/slack.py +121 -0
- genxai/core/agent/config_io.py +32 -1
- genxai/core/agent/runtime.py +41 -4
- genxai/core/graph/__init__.py +3 -0
- genxai/core/graph/engine.py +218 -11
- genxai/core/graph/executor.py +103 -10
- genxai/core/graph/nodes.py +28 -0
- genxai/core/graph/workflow_io.py +199 -0
- genxai/flows/__init__.py +33 -0
- genxai/flows/auction.py +66 -0
- genxai/flows/base.py +134 -0
- genxai/flows/conditional.py +45 -0
- genxai/flows/coordinator_worker.py +62 -0
- genxai/flows/critic_review.py +62 -0
- genxai/flows/ensemble_voting.py +49 -0
- genxai/flows/loop.py +42 -0
- genxai/flows/map_reduce.py +61 -0
- genxai/flows/p2p.py +146 -0
- genxai/flows/parallel.py +27 -0
- genxai/flows/round_robin.py +24 -0
- genxai/flows/router.py +45 -0
- genxai/flows/selector.py +63 -0
- genxai/flows/subworkflow.py +35 -0
- genxai/llm/factory.py +17 -10
- genxai/llm/providers/anthropic.py +116 -1
- genxai/tools/builtin/__init__.py +3 -0
- genxai/tools/builtin/communication/human_input.py +32 -0
- genxai/tools/custom/test-2.py +19 -0
- genxai/tools/custom/test_tool_ui.py +9 -0
- genxai/tools/persistence/service.py +3 -3
- genxai/utils/tokens.py +6 -0
- {genxai_framework-0.1.0.dist-info → genxai_framework-0.1.1.dist-info}/METADATA +63 -12
- {genxai_framework-0.1.0.dist-info → genxai_framework-0.1.1.dist-info}/RECORD +50 -21
- {genxai_framework-0.1.0.dist-info → genxai_framework-0.1.1.dist-info}/WHEEL +0 -0
- {genxai_framework-0.1.0.dist-info → genxai_framework-0.1.1.dist-info}/entry_points.txt +0 -0
- {genxai_framework-0.1.0.dist-info → genxai_framework-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {genxai_framework-0.1.0.dist-info → genxai_framework-0.1.1.dist-info}/top_level.txt +0 -0
genxai/core/agent/runtime.py
CHANGED
|
@@ -5,6 +5,7 @@ import asyncio
|
|
|
5
5
|
import time
|
|
6
6
|
import logging
|
|
7
7
|
import json
|
|
8
|
+
import copy
|
|
8
9
|
|
|
9
10
|
from genxai.core.agent.base import Agent
|
|
10
11
|
from genxai.llm.base import LLMProvider
|
|
@@ -16,6 +17,7 @@ from genxai.observability.tracing import span, add_event, record_exception
|
|
|
16
17
|
from genxai.security.rbac import get_current_user, Permission
|
|
17
18
|
from genxai.security.policy_engine import get_policy_engine
|
|
18
19
|
from genxai.security.audit import get_audit_log, AuditEvent
|
|
20
|
+
from genxai.core.memory.shared import SharedMemoryBus
|
|
19
21
|
|
|
20
22
|
logger = logging.getLogger(__name__)
|
|
21
23
|
|
|
@@ -34,19 +36,25 @@ class AgentRuntime:
|
|
|
34
36
|
agent: Agent,
|
|
35
37
|
llm_provider: Optional[LLMProvider] = None,
|
|
36
38
|
api_key: Optional[str] = None,
|
|
39
|
+
openai_api_key: Optional[str] = None,
|
|
40
|
+
anthropic_api_key: Optional[str] = None,
|
|
37
41
|
enable_memory: bool = True,
|
|
42
|
+
shared_memory: Optional[SharedMemoryBus] = None,
|
|
38
43
|
) -> None:
|
|
39
44
|
"""Initialize agent runtime.
|
|
40
45
|
|
|
41
46
|
Args:
|
|
42
47
|
agent: Agent to execute
|
|
43
48
|
llm_provider: LLM provider instance (optional, will be created if not provided)
|
|
44
|
-
api_key: API key for LLM provider (optional,
|
|
49
|
+
api_key: API key for LLM provider (optional, deprecated - use openai_api_key or anthropic_api_key)
|
|
50
|
+
openai_api_key: OpenAI API key (for GPT models)
|
|
51
|
+
anthropic_api_key: Anthropic API key (for Claude models)
|
|
45
52
|
enable_memory: Whether to initialize memory system
|
|
46
53
|
"""
|
|
47
54
|
self.agent = agent
|
|
48
55
|
self._tools: Dict[str, Any] = {}
|
|
49
56
|
self._memory: Optional[Any] = None
|
|
57
|
+
self._shared_memory = shared_memory
|
|
50
58
|
|
|
51
59
|
# Initialize LLM provider
|
|
52
60
|
if llm_provider:
|
|
@@ -54,9 +62,25 @@ class AgentRuntime:
|
|
|
54
62
|
else:
|
|
55
63
|
# Create provider from agent config
|
|
56
64
|
try:
|
|
65
|
+
# Determine which API key to use based on model
|
|
66
|
+
model = agent.config.llm_model.lower()
|
|
67
|
+
selected_api_key = api_key # Fallback to deprecated api_key parameter
|
|
68
|
+
|
|
69
|
+
if model.startswith('claude'):
|
|
70
|
+
# Claude models use Anthropic API key
|
|
71
|
+
selected_api_key = anthropic_api_key or api_key
|
|
72
|
+
logger.info(f"Using Anthropic API key for Claude model: {agent.config.llm_model}")
|
|
73
|
+
elif model.startswith('gpt'):
|
|
74
|
+
# GPT models use OpenAI API key
|
|
75
|
+
selected_api_key = openai_api_key or api_key
|
|
76
|
+
logger.info(f"Using OpenAI API key for GPT model: {agent.config.llm_model}")
|
|
77
|
+
else:
|
|
78
|
+
# For other models, try to infer or use openai_api_key as default
|
|
79
|
+
selected_api_key = openai_api_key or anthropic_api_key or api_key
|
|
80
|
+
|
|
57
81
|
self._llm_provider = LLMProviderFactory.create_provider(
|
|
58
82
|
model=agent.config.llm_model,
|
|
59
|
-
api_key=
|
|
83
|
+
api_key=selected_api_key,
|
|
60
84
|
temperature=agent.config.llm_temperature,
|
|
61
85
|
max_tokens=agent.config.llm_max_tokens,
|
|
62
86
|
)
|
|
@@ -177,7 +201,13 @@ class AgentRuntime:
|
|
|
177
201
|
memory_context = await self.get_memory_context(limit=5)
|
|
178
202
|
|
|
179
203
|
# Build prompt (without memory context, as it's handled in _get_llm_response)
|
|
180
|
-
|
|
204
|
+
prompt_context = dict(context)
|
|
205
|
+
if self._shared_memory is not None:
|
|
206
|
+
prompt_context["shared_memory"] = {
|
|
207
|
+
key: self._shared_memory.get(key)
|
|
208
|
+
for key in self._shared_memory.list_keys()
|
|
209
|
+
}
|
|
210
|
+
prompt = self._build_prompt(task, prompt_context, "")
|
|
181
211
|
|
|
182
212
|
# Get LLM response with retry logic and memory context
|
|
183
213
|
if self.agent.config.tools and self._tools and self._provider_supports_tools():
|
|
@@ -193,13 +223,20 @@ class AgentRuntime:
|
|
|
193
223
|
await self._update_memory(task, response)
|
|
194
224
|
|
|
195
225
|
# Build result
|
|
226
|
+
safe_context: Dict[str, Any]
|
|
227
|
+
try:
|
|
228
|
+
safe_context = copy.deepcopy(context)
|
|
229
|
+
except Exception:
|
|
230
|
+
safe_context = dict(context)
|
|
231
|
+
safe_context.pop("llm_provider", None)
|
|
232
|
+
safe_context.pop("shared_memory", None)
|
|
196
233
|
result = {
|
|
197
234
|
"agent_id": self.agent.id,
|
|
198
235
|
"task": task,
|
|
199
236
|
"status": "completed",
|
|
200
237
|
"output": response,
|
|
201
|
-
"context": context,
|
|
202
238
|
"tokens_used": self.agent._total_tokens,
|
|
239
|
+
"context": safe_context,
|
|
203
240
|
}
|
|
204
241
|
|
|
205
242
|
# Store episode in episodic memory
|
genxai/core/graph/__init__.py
CHANGED
|
@@ -14,6 +14,7 @@ from genxai.core.graph.checkpoints import (
|
|
|
14
14
|
)
|
|
15
15
|
from genxai.core.graph.trigger_runner import TriggerWorkflowRunner
|
|
16
16
|
from genxai.core.execution import WorkerQueueEngine
|
|
17
|
+
from genxai.core.graph.workflow_io import load_workflow_yaml, register_workflow_agents
|
|
17
18
|
|
|
18
19
|
__all__ = [
|
|
19
20
|
"Node",
|
|
@@ -27,4 +28,6 @@ __all__ = [
|
|
|
27
28
|
"WorkflowCheckpointManager",
|
|
28
29
|
"TriggerWorkflowRunner",
|
|
29
30
|
"WorkerQueueEngine",
|
|
31
|
+
"load_workflow_yaml",
|
|
32
|
+
"register_workflow_agents",
|
|
30
33
|
]
|
genxai/core/graph/engine.py
CHANGED
|
@@ -1,18 +1,19 @@
|
|
|
1
1
|
"""Graph execution engine for orchestrating agent workflows."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
-
from typing import Any, Dict, List, Optional, Set
|
|
4
|
+
from typing import Any, Callable, Dict, List, Optional, Set
|
|
5
5
|
from collections import defaultdict, deque
|
|
6
6
|
import logging
|
|
7
7
|
import time
|
|
8
8
|
import copy
|
|
9
9
|
from pathlib import Path
|
|
10
10
|
|
|
11
|
-
from genxai.core.graph.nodes import Node, NodeStatus, NodeType
|
|
11
|
+
from genxai.core.graph.nodes import Node, NodeConfig, NodeStatus, NodeType
|
|
12
12
|
from genxai.core.agent.registry import AgentRegistry
|
|
13
13
|
from genxai.core.agent.runtime import AgentRuntime
|
|
14
14
|
from genxai.tools.registry import ToolRegistry
|
|
15
15
|
from genxai.core.graph.edges import Edge
|
|
16
|
+
from genxai.core.memory.shared import SharedMemoryBus
|
|
16
17
|
from genxai.core.graph.checkpoints import (
|
|
17
18
|
WorkflowCheckpoint,
|
|
18
19
|
WorkflowCheckpointManager,
|
|
@@ -47,6 +48,7 @@ class Graph:
|
|
|
47
48
|
self.edges: List[Edge] = []
|
|
48
49
|
self._adjacency_list: Dict[str, List[Edge]] = defaultdict(list)
|
|
49
50
|
self._reverse_adjacency: Dict[str, List[str]] = defaultdict(list)
|
|
51
|
+
self.shared_memory: Optional[SharedMemoryBus] = None
|
|
50
52
|
|
|
51
53
|
def add_node(self, node: Node) -> None:
|
|
52
54
|
"""Add a node to the graph.
|
|
@@ -82,6 +84,10 @@ class Graph:
|
|
|
82
84
|
self._reverse_adjacency[edge.target].append(edge.source)
|
|
83
85
|
logger.debug(f"Added edge: {edge.source} -> {edge.target}")
|
|
84
86
|
|
|
87
|
+
def set_shared_memory(self, shared_memory: Optional[SharedMemoryBus]) -> None:
|
|
88
|
+
"""Attach a shared memory bus to the graph for agent execution."""
|
|
89
|
+
self.shared_memory = shared_memory
|
|
90
|
+
|
|
85
91
|
def get_node(self, node_id: str) -> Optional[Node]:
|
|
86
92
|
"""Get a node by ID.
|
|
87
93
|
|
|
@@ -212,6 +218,7 @@ class Graph:
|
|
|
212
218
|
state: Optional[Dict[str, Any]] = None,
|
|
213
219
|
resume_from: Optional[WorkflowCheckpoint] = None,
|
|
214
220
|
llm_provider: Any = None,
|
|
221
|
+
event_callback: Optional[Callable[[Dict[str, Any]], Any]] = None,
|
|
215
222
|
) -> Dict[str, Any]:
|
|
216
223
|
"""Execute the graph workflow.
|
|
217
224
|
|
|
@@ -244,6 +251,7 @@ class Graph:
|
|
|
244
251
|
state = {}
|
|
245
252
|
state["input"] = input_data
|
|
246
253
|
state["iterations"] = 0
|
|
254
|
+
state.setdefault("node_events", [])
|
|
247
255
|
|
|
248
256
|
if resume_from:
|
|
249
257
|
for node_id, status in resume_from.node_statuses.items():
|
|
@@ -274,7 +282,7 @@ class Graph:
|
|
|
274
282
|
try:
|
|
275
283
|
with span("genxai.workflow.execute", {"workflow_id": self.name}):
|
|
276
284
|
for entry_point in entry_points:
|
|
277
|
-
await self._execute_node(entry_point, state, max_iterations)
|
|
285
|
+
await self._execute_node(entry_point, state, max_iterations, event_callback)
|
|
278
286
|
except Exception as exc:
|
|
279
287
|
status = "error"
|
|
280
288
|
record_exception(exc)
|
|
@@ -288,6 +296,7 @@ class Graph:
|
|
|
288
296
|
)
|
|
289
297
|
|
|
290
298
|
logger.info(f"Graph execution completed: {self.name}")
|
|
299
|
+
state["node_events"] = state.get("node_events", [])
|
|
291
300
|
return state
|
|
292
301
|
|
|
293
302
|
def create_checkpoint(self, name: str, state: Dict[str, Any]) -> WorkflowCheckpoint:
|
|
@@ -307,7 +316,11 @@ class Graph:
|
|
|
307
316
|
return manager.load(name)
|
|
308
317
|
|
|
309
318
|
async def _execute_node(
|
|
310
|
-
self,
|
|
319
|
+
self,
|
|
320
|
+
node_id: str,
|
|
321
|
+
state: Dict[str, Any],
|
|
322
|
+
max_iterations: int,
|
|
323
|
+
event_callback: Optional[Callable[[Dict[str, Any]], Any]] = None,
|
|
311
324
|
) -> None:
|
|
312
325
|
"""Execute a single node and its descendants.
|
|
313
326
|
|
|
@@ -334,6 +347,16 @@ class Graph:
|
|
|
334
347
|
node.status = NodeStatus.RUNNING
|
|
335
348
|
logger.debug(f"Executing node: {node_id}")
|
|
336
349
|
node_start = time.time()
|
|
350
|
+
running_event = {
|
|
351
|
+
"node_id": node_id,
|
|
352
|
+
"status": NodeStatus.RUNNING.value,
|
|
353
|
+
"timestamp": time.time(),
|
|
354
|
+
}
|
|
355
|
+
state.setdefault("node_events", []).append(running_event)
|
|
356
|
+
if event_callback:
|
|
357
|
+
callback_result = event_callback(running_event)
|
|
358
|
+
if asyncio.iscoroutine(callback_result):
|
|
359
|
+
await callback_result
|
|
337
360
|
|
|
338
361
|
try:
|
|
339
362
|
# Execute node (placeholder - will be implemented with actual executors)
|
|
@@ -341,16 +364,35 @@ class Graph:
|
|
|
341
364
|
"genxai.workflow.node",
|
|
342
365
|
{"workflow_id": self.name, "node_id": node_id, "node_type": node.type.value},
|
|
343
366
|
):
|
|
344
|
-
result = await self._execute_node_logic(node, state)
|
|
367
|
+
result = await self._execute_node_logic(node, state, max_iterations)
|
|
345
368
|
node.result = result
|
|
346
369
|
node.status = NodeStatus.COMPLETED
|
|
347
370
|
logger.debug(f"Node completed: {node_id}")
|
|
348
371
|
|
|
372
|
+
node_duration_ms = int((time.time() - node_start) * 1000)
|
|
373
|
+
|
|
349
374
|
record_workflow_node_execution(
|
|
350
375
|
workflow_id=self.name,
|
|
351
376
|
node_id=node_id,
|
|
352
377
|
status="success",
|
|
353
378
|
)
|
|
379
|
+
completed_event = {
|
|
380
|
+
"node_id": node_id,
|
|
381
|
+
"status": NodeStatus.COMPLETED.value,
|
|
382
|
+
"timestamp": time.time(),
|
|
383
|
+
"duration_ms": node_duration_ms,
|
|
384
|
+
}
|
|
385
|
+
state.setdefault("node_events", []).append(completed_event)
|
|
386
|
+
if event_callback:
|
|
387
|
+
callback_result = event_callback(completed_event)
|
|
388
|
+
if asyncio.iscoroutine(callback_result):
|
|
389
|
+
await callback_result
|
|
390
|
+
|
|
391
|
+
state.setdefault("node_results", {})[node_id] = {
|
|
392
|
+
"output": result,
|
|
393
|
+
"status": NodeStatus.COMPLETED.value,
|
|
394
|
+
"duration_ms": node_duration_ms,
|
|
395
|
+
}
|
|
354
396
|
|
|
355
397
|
# Update state with result
|
|
356
398
|
state[node_id] = result
|
|
@@ -367,27 +409,48 @@ class Graph:
|
|
|
367
409
|
tasks = []
|
|
368
410
|
for edge in parallel_edges:
|
|
369
411
|
if edge.evaluate_condition(state):
|
|
370
|
-
tasks.append(self._execute_node(edge.target, state, max_iterations))
|
|
412
|
+
tasks.append(self._execute_node(edge.target, state, max_iterations, event_callback))
|
|
371
413
|
if tasks:
|
|
372
|
-
await
|
|
414
|
+
await self._gather_with_config(tasks, state)
|
|
373
415
|
|
|
374
416
|
# Execute sequential edges in order
|
|
375
417
|
for edge in sorted(sequential_edges, key=lambda e: e.priority):
|
|
376
418
|
if edge.evaluate_condition(state):
|
|
377
|
-
await self._execute_node(edge.target, state, max_iterations)
|
|
419
|
+
await self._execute_node(edge.target, state, max_iterations, event_callback)
|
|
378
420
|
|
|
379
421
|
except Exception as e:
|
|
380
422
|
node.status = NodeStatus.FAILED
|
|
381
423
|
node.error = str(e)
|
|
382
424
|
logger.error(f"Node execution failed: {node_id} - {e}")
|
|
425
|
+
node_duration_ms = int((time.time() - node_start) * 1000)
|
|
383
426
|
record_workflow_node_execution(
|
|
384
427
|
workflow_id=self.name,
|
|
385
428
|
node_id=node_id,
|
|
386
429
|
status="error",
|
|
387
430
|
)
|
|
431
|
+
failed_event = {
|
|
432
|
+
"node_id": node_id,
|
|
433
|
+
"status": NodeStatus.FAILED.value,
|
|
434
|
+
"timestamp": time.time(),
|
|
435
|
+
"error": str(e),
|
|
436
|
+
"duration_ms": node_duration_ms,
|
|
437
|
+
}
|
|
438
|
+
state.setdefault("node_events", []).append(failed_event)
|
|
439
|
+
if event_callback:
|
|
440
|
+
callback_result = event_callback(failed_event)
|
|
441
|
+
if asyncio.iscoroutine(callback_result):
|
|
442
|
+
await callback_result
|
|
443
|
+
state.setdefault("node_results", {})[node_id] = {
|
|
444
|
+
"output": None,
|
|
445
|
+
"status": NodeStatus.FAILED.value,
|
|
446
|
+
"duration_ms": node_duration_ms,
|
|
447
|
+
"error": str(e),
|
|
448
|
+
}
|
|
388
449
|
raise GraphExecutionError(f"Node {node_id} failed: {e}") from e
|
|
389
450
|
|
|
390
|
-
async def _execute_node_logic(
|
|
451
|
+
async def _execute_node_logic(
|
|
452
|
+
self, node: Node, state: Dict[str, Any], max_iterations: int
|
|
453
|
+
) -> Any:
|
|
391
454
|
"""Execute the actual logic of a node.
|
|
392
455
|
|
|
393
456
|
Args:
|
|
@@ -409,6 +472,12 @@ class Graph:
|
|
|
409
472
|
if node.type == NodeType.TOOL:
|
|
410
473
|
return await self._execute_tool_node(node, state)
|
|
411
474
|
|
|
475
|
+
if node.type == NodeType.SUBGRAPH:
|
|
476
|
+
return await self._execute_subgraph_node(node, state, max_iterations)
|
|
477
|
+
|
|
478
|
+
if node.type == NodeType.LOOP:
|
|
479
|
+
return await self._execute_loop_node(node, state, max_iterations)
|
|
480
|
+
|
|
412
481
|
# Default fallback for unsupported nodes
|
|
413
482
|
return {"node_id": node.id, "type": node.type.value}
|
|
414
483
|
|
|
@@ -435,7 +504,12 @@ class Graph:
|
|
|
435
504
|
task = node.config.data.get("task") or state.get("task") or "Process input"
|
|
436
505
|
|
|
437
506
|
llm_provider = state.get("llm_provider")
|
|
438
|
-
runtime = AgentRuntime(
|
|
507
|
+
runtime = AgentRuntime(
|
|
508
|
+
agent=agent,
|
|
509
|
+
llm_provider=llm_provider,
|
|
510
|
+
enable_memory=True,
|
|
511
|
+
shared_memory=self.shared_memory,
|
|
512
|
+
)
|
|
439
513
|
if agent.config.tools:
|
|
440
514
|
tools: Dict[str, Any] = {}
|
|
441
515
|
for tool_name in agent.config.tools:
|
|
@@ -444,7 +518,73 @@ class Graph:
|
|
|
444
518
|
tools[tool_name] = tool
|
|
445
519
|
runtime.set_tools(tools)
|
|
446
520
|
|
|
447
|
-
|
|
521
|
+
context = dict(state)
|
|
522
|
+
if self.shared_memory is not None:
|
|
523
|
+
context["shared_memory"] = self.shared_memory
|
|
524
|
+
return await self._execute_with_config(runtime, task=task, context=context, state=state)
|
|
525
|
+
|
|
526
|
+
def _get_execution_config(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
527
|
+
config = state.get("execution_config") or {}
|
|
528
|
+
return {
|
|
529
|
+
"timeout_seconds": config.get("timeout_seconds", 120.0),
|
|
530
|
+
"retry_count": config.get("retry_count", 3),
|
|
531
|
+
"backoff_base": config.get("backoff_base", 1.0),
|
|
532
|
+
"backoff_multiplier": config.get("backoff_multiplier", 2.0),
|
|
533
|
+
"cancel_on_failure": config.get("cancel_on_failure", True),
|
|
534
|
+
}
|
|
535
|
+
|
|
536
|
+
async def _execute_with_config(
|
|
537
|
+
self,
|
|
538
|
+
runtime: AgentRuntime,
|
|
539
|
+
task: str,
|
|
540
|
+
context: Dict[str, Any],
|
|
541
|
+
state: Dict[str, Any],
|
|
542
|
+
) -> Any:
|
|
543
|
+
config = self._get_execution_config(state)
|
|
544
|
+
delay = config["backoff_base"]
|
|
545
|
+
for attempt in range(config["retry_count"] + 1):
|
|
546
|
+
try:
|
|
547
|
+
coro = runtime.execute(task=task, context=context)
|
|
548
|
+
timeout = config["timeout_seconds"]
|
|
549
|
+
if timeout:
|
|
550
|
+
return await asyncio.wait_for(coro, timeout=timeout)
|
|
551
|
+
return await coro
|
|
552
|
+
except asyncio.CancelledError:
|
|
553
|
+
raise
|
|
554
|
+
except Exception:
|
|
555
|
+
if attempt >= config["retry_count"]:
|
|
556
|
+
raise
|
|
557
|
+
await asyncio.sleep(delay)
|
|
558
|
+
delay *= config["backoff_multiplier"]
|
|
559
|
+
|
|
560
|
+
async def _gather_with_config(self, coros: List[Any], state: Dict[str, Any]) -> List[Any]:
|
|
561
|
+
config = self._get_execution_config(state)
|
|
562
|
+
tasks = [asyncio.create_task(coro) for coro in coros]
|
|
563
|
+
if not tasks:
|
|
564
|
+
return []
|
|
565
|
+
if not config["cancel_on_failure"]:
|
|
566
|
+
return await asyncio.gather(*tasks, return_exceptions=True)
|
|
567
|
+
|
|
568
|
+
results: List[Any] = [None] * len(tasks)
|
|
569
|
+
index_map = {task: idx for idx, task in enumerate(tasks)}
|
|
570
|
+
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
|
|
571
|
+
|
|
572
|
+
for task in done:
|
|
573
|
+
idx = index_map[task]
|
|
574
|
+
exc = task.exception()
|
|
575
|
+
if exc:
|
|
576
|
+
for pending_task in pending:
|
|
577
|
+
pending_task.cancel()
|
|
578
|
+
await asyncio.gather(*pending, return_exceptions=True)
|
|
579
|
+
raise exc
|
|
580
|
+
results[idx] = task.result()
|
|
581
|
+
|
|
582
|
+
if pending:
|
|
583
|
+
pending_results = await asyncio.gather(*pending, return_exceptions=True)
|
|
584
|
+
for task, result in zip(pending, pending_results):
|
|
585
|
+
results[index_map[task]] = result
|
|
586
|
+
|
|
587
|
+
return results
|
|
448
588
|
|
|
449
589
|
async def _execute_tool_node(self, node: Node, state: Dict[str, Any]) -> Any:
|
|
450
590
|
"""Execute a ToolNode using ToolRegistry.
|
|
@@ -477,6 +617,73 @@ class Graph:
|
|
|
477
617
|
result = await tool.execute(**tool_params)
|
|
478
618
|
return result.model_dump() if hasattr(result, "model_dump") else result
|
|
479
619
|
|
|
620
|
+
async def _execute_subgraph_node(
|
|
621
|
+
self, node: Node, state: Dict[str, Any], max_iterations: int
|
|
622
|
+
) -> Any:
|
|
623
|
+
"""Execute a nested workflow defined in the state metadata."""
|
|
624
|
+
workflow_id = node.config.data.get("workflow_id")
|
|
625
|
+
if not workflow_id:
|
|
626
|
+
raise GraphExecutionError(
|
|
627
|
+
f"Subgraph node '{node.id}' missing workflow_id in config.data"
|
|
628
|
+
)
|
|
629
|
+
|
|
630
|
+
subgraphs = state.get("subgraphs", {})
|
|
631
|
+
workflow_def = subgraphs.get(workflow_id)
|
|
632
|
+
if not workflow_def and "subgraphs" in state:
|
|
633
|
+
workflow_def = state["subgraphs"].get(workflow_id)
|
|
634
|
+
if not workflow_def and "metadata" in state:
|
|
635
|
+
workflow_def = state.get("metadata", {}).get("subgraphs", {}).get(workflow_id)
|
|
636
|
+
if not workflow_def:
|
|
637
|
+
raise GraphExecutionError(
|
|
638
|
+
f"Subgraph workflow '{workflow_id}' not found in state.subgraphs"
|
|
639
|
+
)
|
|
640
|
+
|
|
641
|
+
subgraph = Graph(name=f"subgraph:{workflow_id}")
|
|
642
|
+
for node_def in workflow_def.get("nodes", []):
|
|
643
|
+
node_type = node_def.get("type")
|
|
644
|
+
node_id = node_def.get("id")
|
|
645
|
+
if node_type == "input":
|
|
646
|
+
subgraph.add_node(Node(id=node_id, type=NodeType.INPUT, config=NodeConfig(type=NodeType.INPUT)))
|
|
647
|
+
elif node_type == "output":
|
|
648
|
+
subgraph.add_node(Node(id=node_id, type=NodeType.OUTPUT, config=NodeConfig(type=NodeType.OUTPUT)))
|
|
649
|
+
elif node_type == "agent":
|
|
650
|
+
subgraph.add_node(Node(id=node_id, type=NodeType.AGENT, config=NodeConfig(type=NodeType.AGENT, data=node_def.get("config", {}))))
|
|
651
|
+
elif node_type == "tool":
|
|
652
|
+
subgraph.add_node(Node(id=node_id, type=NodeType.TOOL, config=NodeConfig(type=NodeType.TOOL, data=node_def.get("config", {}))))
|
|
653
|
+
else:
|
|
654
|
+
subgraph.add_node(Node(id=node_id, type=NodeType.CONDITION, config=NodeConfig(type=NodeType.CONDITION, data=node_def.get("config", {}))))
|
|
655
|
+
|
|
656
|
+
for edge_def in workflow_def.get("edges", []):
|
|
657
|
+
subgraph.add_edge(Edge(source=edge_def["source"], target=edge_def["target"], condition=edge_def.get("condition")))
|
|
658
|
+
|
|
659
|
+
result_state = await subgraph.run(
|
|
660
|
+
input_data=state.get("input"),
|
|
661
|
+
max_iterations=max_iterations,
|
|
662
|
+
state={"parent_state": state},
|
|
663
|
+
)
|
|
664
|
+
return {"workflow_id": workflow_id, "state": result_state}
|
|
665
|
+
|
|
666
|
+
async def _execute_loop_node(
|
|
667
|
+
self, node: Node, state: Dict[str, Any], max_iterations: int
|
|
668
|
+
) -> Any:
|
|
669
|
+
"""Execute a loop node by iterating until condition is met."""
|
|
670
|
+
condition_key = node.config.data.get("condition")
|
|
671
|
+
loop_limit = int(node.config.data.get("max_iterations", 5))
|
|
672
|
+
loop_iterations = 0
|
|
673
|
+
results = []
|
|
674
|
+
|
|
675
|
+
while loop_iterations < loop_limit:
|
|
676
|
+
loop_iterations += 1
|
|
677
|
+
state_key = f"loop_{node.id}_iteration"
|
|
678
|
+
state[state_key] = loop_iterations
|
|
679
|
+
results.append({"iteration": loop_iterations})
|
|
680
|
+
if condition_key and state.get(condition_key):
|
|
681
|
+
break
|
|
682
|
+
if state.get("iterations", 0) >= max_iterations:
|
|
683
|
+
break
|
|
684
|
+
|
|
685
|
+
return {"iterations": loop_iterations, "results": results}
|
|
686
|
+
|
|
480
687
|
def to_dict(self) -> Dict[str, Any]:
|
|
481
688
|
"""Convert graph to dictionary representation.
|
|
482
689
|
|