genxai-framework 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. cli/__init__.py +3 -0
  2. cli/commands/__init__.py +6 -0
  3. cli/commands/approval.py +85 -0
  4. cli/commands/audit.py +127 -0
  5. cli/commands/metrics.py +25 -0
  6. cli/commands/tool.py +389 -0
  7. cli/main.py +32 -0
  8. genxai/__init__.py +81 -0
  9. genxai/api/__init__.py +5 -0
  10. genxai/api/app.py +21 -0
  11. genxai/config/__init__.py +5 -0
  12. genxai/config/settings.py +37 -0
  13. genxai/connectors/__init__.py +19 -0
  14. genxai/connectors/base.py +122 -0
  15. genxai/connectors/kafka.py +92 -0
  16. genxai/connectors/postgres_cdc.py +95 -0
  17. genxai/connectors/registry.py +44 -0
  18. genxai/connectors/sqs.py +94 -0
  19. genxai/connectors/webhook.py +73 -0
  20. genxai/core/__init__.py +37 -0
  21. genxai/core/agent/__init__.py +32 -0
  22. genxai/core/agent/base.py +206 -0
  23. genxai/core/agent/config_io.py +59 -0
  24. genxai/core/agent/registry.py +98 -0
  25. genxai/core/agent/runtime.py +970 -0
  26. genxai/core/communication/__init__.py +6 -0
  27. genxai/core/communication/collaboration.py +44 -0
  28. genxai/core/communication/message_bus.py +192 -0
  29. genxai/core/communication/protocols.py +35 -0
  30. genxai/core/execution/__init__.py +22 -0
  31. genxai/core/execution/metadata.py +181 -0
  32. genxai/core/execution/queue.py +201 -0
  33. genxai/core/graph/__init__.py +30 -0
  34. genxai/core/graph/checkpoints.py +77 -0
  35. genxai/core/graph/edges.py +131 -0
  36. genxai/core/graph/engine.py +813 -0
  37. genxai/core/graph/executor.py +516 -0
  38. genxai/core/graph/nodes.py +161 -0
  39. genxai/core/graph/trigger_runner.py +40 -0
  40. genxai/core/memory/__init__.py +19 -0
  41. genxai/core/memory/base.py +72 -0
  42. genxai/core/memory/embedding.py +327 -0
  43. genxai/core/memory/episodic.py +448 -0
  44. genxai/core/memory/long_term.py +467 -0
  45. genxai/core/memory/manager.py +543 -0
  46. genxai/core/memory/persistence.py +297 -0
  47. genxai/core/memory/procedural.py +461 -0
  48. genxai/core/memory/semantic.py +526 -0
  49. genxai/core/memory/shared.py +62 -0
  50. genxai/core/memory/short_term.py +303 -0
  51. genxai/core/memory/vector_store.py +508 -0
  52. genxai/core/memory/working.py +211 -0
  53. genxai/core/state/__init__.py +6 -0
  54. genxai/core/state/manager.py +293 -0
  55. genxai/core/state/schema.py +115 -0
  56. genxai/llm/__init__.py +14 -0
  57. genxai/llm/base.py +150 -0
  58. genxai/llm/factory.py +329 -0
  59. genxai/llm/providers/__init__.py +1 -0
  60. genxai/llm/providers/anthropic.py +249 -0
  61. genxai/llm/providers/cohere.py +274 -0
  62. genxai/llm/providers/google.py +334 -0
  63. genxai/llm/providers/ollama.py +147 -0
  64. genxai/llm/providers/openai.py +257 -0
  65. genxai/llm/routing.py +83 -0
  66. genxai/observability/__init__.py +6 -0
  67. genxai/observability/logging.py +327 -0
  68. genxai/observability/metrics.py +494 -0
  69. genxai/observability/tracing.py +372 -0
  70. genxai/performance/__init__.py +39 -0
  71. genxai/performance/cache.py +256 -0
  72. genxai/performance/pooling.py +289 -0
  73. genxai/security/audit.py +304 -0
  74. genxai/security/auth.py +315 -0
  75. genxai/security/cost_control.py +528 -0
  76. genxai/security/default_policies.py +44 -0
  77. genxai/security/jwt.py +142 -0
  78. genxai/security/oauth.py +226 -0
  79. genxai/security/pii.py +366 -0
  80. genxai/security/policy_engine.py +82 -0
  81. genxai/security/rate_limit.py +341 -0
  82. genxai/security/rbac.py +247 -0
  83. genxai/security/validation.py +218 -0
  84. genxai/tools/__init__.py +21 -0
  85. genxai/tools/base.py +383 -0
  86. genxai/tools/builtin/__init__.py +131 -0
  87. genxai/tools/builtin/communication/__init__.py +15 -0
  88. genxai/tools/builtin/communication/email_sender.py +159 -0
  89. genxai/tools/builtin/communication/notification_manager.py +167 -0
  90. genxai/tools/builtin/communication/slack_notifier.py +118 -0
  91. genxai/tools/builtin/communication/sms_sender.py +118 -0
  92. genxai/tools/builtin/communication/webhook_caller.py +136 -0
  93. genxai/tools/builtin/computation/__init__.py +15 -0
  94. genxai/tools/builtin/computation/calculator.py +101 -0
  95. genxai/tools/builtin/computation/code_executor.py +183 -0
  96. genxai/tools/builtin/computation/data_validator.py +259 -0
  97. genxai/tools/builtin/computation/hash_generator.py +129 -0
  98. genxai/tools/builtin/computation/regex_matcher.py +201 -0
  99. genxai/tools/builtin/data/__init__.py +15 -0
  100. genxai/tools/builtin/data/csv_processor.py +213 -0
  101. genxai/tools/builtin/data/data_transformer.py +299 -0
  102. genxai/tools/builtin/data/json_processor.py +233 -0
  103. genxai/tools/builtin/data/text_analyzer.py +288 -0
  104. genxai/tools/builtin/data/xml_processor.py +175 -0
  105. genxai/tools/builtin/database/__init__.py +15 -0
  106. genxai/tools/builtin/database/database_inspector.py +157 -0
  107. genxai/tools/builtin/database/mongodb_query.py +196 -0
  108. genxai/tools/builtin/database/redis_cache.py +167 -0
  109. genxai/tools/builtin/database/sql_query.py +145 -0
  110. genxai/tools/builtin/database/vector_search.py +163 -0
  111. genxai/tools/builtin/file/__init__.py +17 -0
  112. genxai/tools/builtin/file/directory_scanner.py +214 -0
  113. genxai/tools/builtin/file/file_compressor.py +237 -0
  114. genxai/tools/builtin/file/file_reader.py +102 -0
  115. genxai/tools/builtin/file/file_writer.py +122 -0
  116. genxai/tools/builtin/file/image_processor.py +186 -0
  117. genxai/tools/builtin/file/pdf_parser.py +144 -0
  118. genxai/tools/builtin/test/__init__.py +15 -0
  119. genxai/tools/builtin/test/async_simulator.py +62 -0
  120. genxai/tools/builtin/test/data_transformer.py +99 -0
  121. genxai/tools/builtin/test/error_generator.py +82 -0
  122. genxai/tools/builtin/test/simple_math.py +94 -0
  123. genxai/tools/builtin/test/string_processor.py +72 -0
  124. genxai/tools/builtin/web/__init__.py +15 -0
  125. genxai/tools/builtin/web/api_caller.py +161 -0
  126. genxai/tools/builtin/web/html_parser.py +330 -0
  127. genxai/tools/builtin/web/http_client.py +187 -0
  128. genxai/tools/builtin/web/url_validator.py +162 -0
  129. genxai/tools/builtin/web/web_scraper.py +170 -0
  130. genxai/tools/custom/my_test_tool_2.py +9 -0
  131. genxai/tools/dynamic.py +105 -0
  132. genxai/tools/mcp_server.py +167 -0
  133. genxai/tools/persistence/__init__.py +6 -0
  134. genxai/tools/persistence/models.py +55 -0
  135. genxai/tools/persistence/service.py +322 -0
  136. genxai/tools/registry.py +227 -0
  137. genxai/tools/security/__init__.py +11 -0
  138. genxai/tools/security/limits.py +214 -0
  139. genxai/tools/security/policy.py +20 -0
  140. genxai/tools/security/sandbox.py +248 -0
  141. genxai/tools/templates.py +435 -0
  142. genxai/triggers/__init__.py +19 -0
  143. genxai/triggers/base.py +104 -0
  144. genxai/triggers/file_watcher.py +75 -0
  145. genxai/triggers/queue.py +68 -0
  146. genxai/triggers/registry.py +82 -0
  147. genxai/triggers/schedule.py +66 -0
  148. genxai/triggers/webhook.py +68 -0
  149. genxai/utils/__init__.py +1 -0
  150. genxai/utils/tokens.py +295 -0
  151. genxai_framework-0.1.0.dist-info/METADATA +495 -0
  152. genxai_framework-0.1.0.dist-info/RECORD +156 -0
  153. genxai_framework-0.1.0.dist-info/WHEEL +5 -0
  154. genxai_framework-0.1.0.dist-info/entry_points.txt +2 -0
  155. genxai_framework-0.1.0.dist-info/licenses/LICENSE +21 -0
  156. genxai_framework-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,813 @@
1
+ """Graph execution engine for orchestrating agent workflows."""
2
+
3
+ import asyncio
4
+ from typing import Any, Dict, List, Optional, Set
5
+ from collections import defaultdict, deque
6
+ import logging
7
+ import time
8
+ import copy
9
+ from pathlib import Path
10
+
11
+ from genxai.core.graph.nodes import Node, NodeStatus, NodeType
12
+ from genxai.core.agent.registry import AgentRegistry
13
+ from genxai.core.agent.runtime import AgentRuntime
14
+ from genxai.tools.registry import ToolRegistry
15
+ from genxai.core.graph.edges import Edge
16
+ from genxai.core.graph.checkpoints import (
17
+ WorkflowCheckpoint,
18
+ WorkflowCheckpointManager,
19
+ create_checkpoint,
20
+ )
21
+ from genxai.observability.metrics import (
22
+ record_workflow_execution,
23
+ record_workflow_node_execution,
24
+ )
25
+ from genxai.observability.tracing import span, record_exception
26
+
27
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
28
+
29
+
30
class GraphExecutionError(Exception):
    """Signals a failure while validating or executing a workflow graph."""
34
+
35
+
36
class Graph:
    """Main graph class for orchestrating agent workflows.

    Nodes are stored by ID; edges are kept both as a flat list and as
    forward (source -> edges) and reverse (target -> source IDs) adjacency
    maps used during traversal.
    """

    def __init__(self, name: str = "workflow") -> None:
        """Initialize the graph.

        Args:
            name: Name of the workflow graph
        """
        self.name = name
        self.nodes: Dict[str, Node] = {}
        self.edges: List[Edge] = []
        self._adjacency_list: Dict[str, List[Edge]] = defaultdict(list)
        self._reverse_adjacency: Dict[str, List[str]] = defaultdict(list)
        # Nodes restored as COMPLETED from a checkpoint whose outgoing edges
        # must still be followed (once) during a resumed run.
        self._resumed_nodes: Set[str] = set()

    def add_node(self, node: Node) -> None:
        """Add a node to the graph.

        Args:
            node: Node to add

        Raises:
            ValueError: If node with same ID already exists
        """
        if node.id in self.nodes:
            raise ValueError(f"Node with id '{node.id}' already exists")

        self.nodes[node.id] = node
        logger.debug(f"Added node: {node.id} (type: {node.type})")

    def add_edge(self, edge: Edge) -> None:
        """Add an edge to the graph.

        Args:
            edge: Edge to add

        Raises:
            ValueError: If source or target node doesn't exist
        """
        if edge.source not in self.nodes:
            raise ValueError(f"Source node '{edge.source}' not found")
        if edge.target not in self.nodes:
            raise ValueError(f"Target node '{edge.target}' not found")

        self.edges.append(edge)
        self._adjacency_list[edge.source].append(edge)
        self._reverse_adjacency[edge.target].append(edge.source)
        logger.debug(f"Added edge: {edge.source} -> {edge.target}")

    def get_node(self, node_id: str) -> Optional[Node]:
        """Get a node by ID.

        Args:
            node_id: ID of the node

        Returns:
            Node if found, None otherwise
        """
        return self.nodes.get(node_id)

    def get_outgoing_edges(self, node_id: str) -> List[Edge]:
        """Get all outgoing edges from a node.

        Args:
            node_id: ID of the node

        Returns:
            List of outgoing edges
        """
        return self._adjacency_list.get(node_id, [])

    def get_incoming_nodes(self, node_id: str) -> List[str]:
        """Get all nodes with edges pointing to this node.

        Args:
            node_id: ID of the node

        Returns:
            List of incoming node IDs
        """
        return self._reverse_adjacency.get(node_id, [])

    def validate(self) -> bool:
        """Validate the graph structure.

        Cycles are allowed; disconnected components only produce a warning.

        Returns:
            True if graph is valid

        Raises:
            GraphExecutionError: If graph is invalid
        """
        # Check for at least one node
        if not self.nodes:
            raise GraphExecutionError("Graph must have at least one node")

        # Check for disconnected components (edges treated as undirected)
        visited = self._dfs_visit(next(iter(self.nodes.keys())))
        if len(visited) != len(self.nodes):
            logger.warning("Graph has disconnected components")

        # Check that all edges reference valid nodes
        for edge in self.edges:
            if edge.source not in self.nodes or edge.target not in self.nodes:
                raise GraphExecutionError(
                    f"Edge references non-existent node: {edge.source} -> {edge.target}"
                )

        logger.info(f"Graph '{self.name}' validated successfully")
        return True

    def _dfs_visit(self, start_node: str) -> Set[str]:
        """Perform an undirected DFS traversal from start node.

        Args:
            start_node: Starting node ID

        Returns:
            Set of visited node IDs
        """
        visited: Set[str] = set()
        stack = [start_node]

        while stack:
            node_id = stack.pop()
            if node_id in visited:
                continue

            visited.add(node_id)

            # Add neighbors (both outgoing and incoming for undirected check)
            for edge in self.get_outgoing_edges(node_id):
                if edge.target not in visited:
                    stack.append(edge.target)

            for incoming in self.get_incoming_nodes(node_id):
                if incoming not in visited:
                    stack.append(incoming)

        return visited

    def topological_sort(self) -> List[str]:
        """Perform a topological sort (Kahn's algorithm) on the graph.

        Returns:
            List of node IDs in topological order

        Raises:
            GraphExecutionError: If graph has cycles
        """
        in_degree = {node_id: 0 for node_id in self.nodes}

        for edge in self.edges:
            in_degree[edge.target] += 1

        queue: deque[str] = deque([node_id for node_id, degree in in_degree.items() if degree == 0])
        result: List[str] = []

        while queue:
            node_id = queue.popleft()
            result.append(node_id)

            for edge in self.get_outgoing_edges(node_id):
                in_degree[edge.target] -= 1
                if in_degree[edge.target] == 0:
                    queue.append(edge.target)

        if len(result) != len(self.nodes):
            raise GraphExecutionError("Graph contains cycles - cannot perform topological sort")

        return result

    async def run(
        self,
        input_data: Any,
        max_iterations: int = 100,
        state: Optional[Dict[str, Any]] = None,
        resume_from: Optional[WorkflowCheckpoint] = None,
        llm_provider: Any = None,
    ) -> Dict[str, Any]:
        """Execute the graph workflow.

        Node statuses are reset to PENDING at the start of every run (and then
        restored from ``resume_from`` when given), so a graph can be run
        repeatedly.

        Args:
            input_data: Input data for the workflow; stored as ``state["input"]``
            max_iterations: Maximum number of node visits (guards against cycles)
            state: Initial state dictionary, mutated in place; ignored when
                ``resume_from`` is given
            resume_from: Checkpoint to resume from. Its state is shallow-copied
                and its node statuses restored; completed nodes are not
                re-executed, but their outgoing edges are followed again
            llm_provider: Optional provider stored as ``state["llm_provider"]``
                for agent nodes

        Returns:
            Final state after execution

        Raises:
            GraphExecutionError: If the graph is empty or invalid, has no entry
                point, or execution fails
        """
        if not self.nodes:
            raise GraphExecutionError("Cannot run empty graph")

        self.validate()

        start_time = time.time()
        # Overall run status reported to metrics; must not be reused as a
        # loop variable (it previously was, corrupting the recorded status).
        status = "success"

        # Initialize state
        if resume_from:
            state = resume_from.state.copy()
            state["input"] = input_data
            state.setdefault("iterations", 0)
        else:
            if state is None:
                state = {}
            state["input"] = input_data
            state["iterations"] = 0

        self._prepare_node_statuses(resume_from)

        # Find entry points (nodes with no incoming edges)
        entry_points = [
            node_id for node_id in self.nodes if not self.get_incoming_nodes(node_id)
        ]

        if not entry_points:
            # If no clear entry point, look for INPUT node
            entry_points = [
                node_id for node_id, node in self.nodes.items() if node.type == NodeType.INPUT
            ]

        if not entry_points:
            raise GraphExecutionError("No entry point found in graph")

        logger.info(f"Starting graph execution: {self.name}")
        logger.debug(f"Entry points: {entry_points}")

        if llm_provider is not None:
            state["llm_provider"] = llm_provider

        # Execute from entry points
        try:
            with span("genxai.workflow.execute", {"workflow_id": self.name}):
                for entry_point in entry_points:
                    await self._execute_node(entry_point, state, max_iterations)
        except Exception as exc:
            status = "error"
            record_exception(exc)
            raise
        finally:
            duration = time.time() - start_time
            record_workflow_execution(
                workflow_id=self.name,
                duration=duration,
                status=status,
            )

        logger.info(f"Graph execution completed: {self.name}")
        return state

    def _prepare_node_statuses(self, resume_from: Optional[WorkflowCheckpoint]) -> None:
        """Reset all node statuses for a new run, then restore them from a checkpoint.

        Nodes restored as COMPLETED are remembered in ``_resumed_nodes`` so the
        traversal can still follow their outgoing edges without re-running them.
        """
        for node in self.nodes.values():
            node.status = NodeStatus.PENDING
        self._resumed_nodes = set()

        if not resume_from:
            return

        for node_id, saved_status in resume_from.node_statuses.items():
            node = self.nodes.get(node_id)
            if node is None:
                continue
            node.status = NodeStatus(saved_status)
            if node.status == NodeStatus.COMPLETED:
                self._resumed_nodes.add(node_id)

    def create_checkpoint(self, name: str, state: Dict[str, Any]) -> WorkflowCheckpoint:
        """Create a checkpoint from current workflow state and node statuses."""
        node_statuses = {node_id: node.status for node_id, node in self.nodes.items()}
        return create_checkpoint(name=name, workflow=self.name, state=state, node_statuses=node_statuses)

    def save_checkpoint(self, name: str, state: Dict[str, Any], path: Path) -> Path:
        """Persist a checkpoint to disk via ``WorkflowCheckpointManager``."""
        manager = WorkflowCheckpointManager(path)
        checkpoint = self.create_checkpoint(name=name, state=state)
        return manager.save(checkpoint)

    def load_checkpoint(self, name: str, path: Path) -> WorkflowCheckpoint:
        """Load a checkpoint from disk via ``WorkflowCheckpointManager``."""
        manager = WorkflowCheckpointManager(path)
        return manager.load(name)

    async def _execute_node(
        self, node_id: str, state: Dict[str, Any], max_iterations: int
    ) -> None:
        """Execute a single node, then follow its outgoing edges.

        Every call counts as one iteration in ``state["iterations"]``. A node
        that is already COMPLETED is not executed again; if it was restored as
        COMPLETED from a checkpoint, its outgoing edges are followed once so a
        resumed run can reach the remaining nodes.

        Only a failure of this node's own logic marks it FAILED. Errors raised
        further downstream propagate unchanged, so an ancestor that already
        completed keeps its COMPLETED status and single "success" metric.

        Args:
            node_id: ID of the node to execute
            state: Current state (mutated: iteration count and node result)
            max_iterations: Maximum iterations allowed

        Raises:
            GraphExecutionError: If this node or a descendant fails, an edge
                condition fails, or max iterations are exceeded
        """
        iterations = state.get("iterations", 0)
        if iterations >= max_iterations:
            raise GraphExecutionError(f"Maximum iterations ({max_iterations}) exceeded")
        state["iterations"] = iterations + 1

        node = self.nodes[node_id]

        # Skip if already completed (prevents re-running shared descendants
        # and cycles); resumed nodes still get their edges followed once.
        if node.status == NodeStatus.COMPLETED:
            resumed = getattr(self, "_resumed_nodes", None)
            if resumed and node_id in resumed:
                resumed.discard(node_id)
                await self._follow_edges(node_id, state, max_iterations)
            return

        # Mark as running
        node.status = NodeStatus.RUNNING
        logger.debug(f"Executing node: {node_id}")

        try:
            with span(
                "genxai.workflow.node",
                {"workflow_id": self.name, "node_id": node_id, "node_type": node.type.value},
            ):
                result = await self._execute_node_logic(node, state)
        except Exception as e:
            node.status = NodeStatus.FAILED
            node.error = str(e)
            logger.error(f"Node execution failed: {node_id} - {e}")
            record_workflow_node_execution(
                workflow_id=self.name,
                node_id=node_id,
                status="error",
            )
            raise GraphExecutionError(f"Node {node_id} failed: {e}") from e

        node.result = result
        node.status = NodeStatus.COMPLETED
        logger.debug(f"Node completed: {node_id}")
        record_workflow_node_execution(
            workflow_id=self.name,
            node_id=node_id,
            status="success",
        )

        # Update state with result
        state[node_id] = result

        await self._follow_edges(node_id, state, max_iterations)

    async def _follow_edges(
        self, node_id: str, state: Dict[str, Any], max_iterations: int
    ) -> None:
        """Execute the targets of ``node_id``'s outgoing edges whose conditions hold.

        Edges with ``metadata["parallel"]`` are run concurrently first; the
        remaining edges then run one after another in ascending ``priority``.
        """
        outgoing_edges = self.get_outgoing_edges(node_id)

        parallel_edges = [e for e in outgoing_edges if e.metadata.get("parallel", False)]
        sequential_edges = [e for e in outgoing_edges if not e.metadata.get("parallel", False)]

        # Evaluate all conditions before creating coroutines so a failing
        # condition cannot leave already-created coroutines un-awaited.
        ready = [edge for edge in parallel_edges if self._edge_applies(edge, state)]
        if ready:
            await asyncio.gather(
                *(self._execute_node(edge.target, state, max_iterations) for edge in ready)
            )

        for edge in sorted(sequential_edges, key=lambda e: e.priority):
            if self._edge_applies(edge, state):
                await self._execute_node(edge.target, state, max_iterations)

    def _edge_applies(self, edge: Edge, state: Dict[str, Any]) -> bool:
        """Evaluate an edge condition, wrapping unexpected errors in GraphExecutionError."""
        try:
            return bool(edge.evaluate_condition(state))
        except GraphExecutionError:
            raise
        except Exception as exc:
            raise GraphExecutionError(
                f"Condition on edge {edge.source} -> {edge.target} failed: {exc}"
            ) from exc

    async def _execute_node_logic(self, node: Node, state: Dict[str, Any]) -> Any:
        """Execute the actual logic of a node.

        Args:
            node: Node to execute
            state: Current state

        Returns:
            Result of node execution: a deep copy of the input (INPUT), a deep
            copy of the state (OUTPUT), the agent/tool result (AGENT/TOOL), or a
            small descriptor dict for other node types
        """
        if node.type == NodeType.INPUT:
            return copy.deepcopy(state.get("input"))

        if node.type == NodeType.OUTPUT:
            # Keep the injected llm_provider shared by reference: pre-seeding
            # the deepcopy memo stops copy.deepcopy from trying to clone the
            # provider object (which may not support deep copying).
            provider = state.get("llm_provider")
            memo: Dict[int, Any] = {} if provider is None else {id(provider): provider}
            return copy.deepcopy(state, memo)

        if node.type == NodeType.AGENT:
            return await self._execute_agent_node(node, state)

        if node.type == NodeType.TOOL:
            return await self._execute_tool_node(node, state)

        # Default fallback for unsupported nodes
        return {"node_id": node.id, "type": node.type.value}

    async def _execute_agent_node(self, node: Node, state: Dict[str, Any]) -> Dict[str, Any]:
        """Execute an AgentNode using AgentRuntime.

        Args:
            node: Agent node to execute (``config.data["agent_id"]`` required)
            state: Current workflow state (passed to the runtime as context)

        Returns:
            Agent execution result

        Raises:
            GraphExecutionError: If agent_id is missing or not registered
        """
        agent_id = node.config.data.get("agent_id")
        if not agent_id:
            raise GraphExecutionError(
                f"Agent node '{node.id}' missing agent_id in config.data"
            )

        agent = AgentRegistry.get(agent_id)
        if agent is None:
            raise GraphExecutionError(f"Agent '{agent_id}' not found in registry")

        task = node.config.data.get("task") or state.get("task") or "Process input"

        llm_provider = state.get("llm_provider")
        runtime = AgentRuntime(agent=agent, llm_provider=llm_provider, enable_memory=True)
        if agent.config.tools:
            # Tools missing from the registry are silently skipped.
            tools: Dict[str, Any] = {}
            for tool_name in agent.config.tools:
                tool = ToolRegistry.get(tool_name)
                if tool:
                    tools[tool_name] = tool
            runtime.set_tools(tools)

        return await runtime.execute(task, context=state)

    async def _execute_tool_node(self, node: Node, state: Dict[str, Any]) -> Any:
        """Execute a ToolNode using ToolRegistry.

        Args:
            node: Tool node to execute (``config.data["tool_name"]`` required,
                optional ``config.data["tool_params"]`` dict)
            state: Current workflow state (currently unused)

        Returns:
            Tool execution result (``model_dump()``-ed when available)

        Raises:
            GraphExecutionError: If tool_name is missing/unregistered or
                tool_params is not a dict
        """
        tool_name = node.config.data.get("tool_name")
        if not tool_name:
            raise GraphExecutionError(
                f"Tool node '{node.id}' missing tool_name in config.data"
            )

        tool = ToolRegistry.get(tool_name)
        if tool is None:
            raise GraphExecutionError(f"Tool '{tool_name}' not found in registry")

        tool_params = node.config.data.get("tool_params", {})
        if tool_params is None:
            tool_params = {}
        if not isinstance(tool_params, dict):
            raise GraphExecutionError(
                f"Tool node '{node.id}' tool_params must be a dict"
            )

        result = await tool.execute(**tool_params)
        return result.model_dump() if hasattr(result, "model_dump") else result

    def to_dict(self) -> Dict[str, Any]:
        """Convert graph to dictionary representation.

        Returns:
            Dictionary representation of the graph
        """
        return {
            "name": self.name,
            "nodes": [
                {
                    "id": node.id,
                    "type": node.type.value,
                    "config": node.config.model_dump(),
                    "status": node.status.value,
                }
                for node in self.nodes.values()
            ],
            "edges": [
                {
                    "source": edge.source,
                    "target": edge.target,
                    "metadata": edge.metadata,
                    "priority": edge.priority,
                }
                for edge in self.edges
            ],
        }

    def __repr__(self) -> str:
        """String representation of the graph."""
        return f"Graph(name={self.name}, nodes={len(self.nodes)}, edges={len(self.edges)})"

    def draw_ascii(self) -> str:
        """Generate ASCII art representation of the graph.

        Returns:
            String containing ASCII art visualization of the graph
        """
        if not self.nodes:
            return "Empty graph"

        lines = []
        lines.append(f"Graph: {self.name}")
        lines.append("=" * 60)
        lines.append("")

        # Find entry points
        entry_points = [
            node_id for node_id in self.nodes if not self.get_incoming_nodes(node_id)
        ]

        if not entry_points:
            entry_points = [
                node_id
                for node_id, node in self.nodes.items()
                if node.type == NodeType.INPUT
            ]

        if not entry_points and self.nodes:
            entry_points = [next(iter(self.nodes.keys()))]

        # Build tree structure
        visited = set()
        for entry in entry_points:
            self._draw_node_tree(entry, lines, visited, prefix="", is_last=True)

        lines.append("")
        lines.append("=" * 60)
        lines.append(f"Total Nodes: {len(self.nodes)} | Total Edges: {len(self.edges)}")

        return "\n".join(lines)

    def _draw_node_tree(
        self, node_id: str, lines: List[str], visited: Set[str], prefix: str, is_last: bool
    ) -> None:
        """Recursively draw node tree structure.

        Args:
            node_id: Current node ID
            lines: List to append output lines to
            visited: Set of node IDs already drawn on the current path
            prefix: Current line prefix for indentation
            is_last: Whether this is the last child
        """
        if node_id not in self.nodes:
            return

        node = self.nodes[node_id]

        # Draw current node
        connector = "└── " if is_last else "├── "
        status_symbol = {
            NodeStatus.PENDING: "○",
            NodeStatus.RUNNING: "◐",
            NodeStatus.COMPLETED: "●",
            NodeStatus.FAILED: "✗",
            NodeStatus.SKIPPED: "⊘",
        }.get(node.status, "?")

        node_line = f"{prefix}{connector}{status_symbol} {node.id} [{node.type.value}]"
        lines.append(node_line)

        # Avoid infinite loops in cyclic graphs
        if node_id in visited:
            extension = " " if is_last else "│ "
            lines.append(f"{prefix}{extension}↻ (cycle detected)")
            return

        visited.add(node_id)

        # Get outgoing edges
        outgoing = self.get_outgoing_edges(node_id)
        if not outgoing:
            return

        # Group edges by type
        parallel_edges = [e for e in outgoing if e.metadata.get("parallel", False)]
        sequential_edges = [e for e in outgoing if not e.metadata.get("parallel", False)]

        # Every child is drawn under the same continuation prefix.
        child_prefix = prefix + (" " if is_last else "│ ")

        # Draw parallel edges. Each child gets a copy of ``visited`` so cycle
        # detection is per path and shared descendants are drawn under each parent.
        if parallel_edges:
            lines.append(f"{child_prefix}║")
            lines.append(f"{child_prefix}╠══ [PARALLEL]")

            for i, edge in enumerate(parallel_edges):
                is_last_parallel = i == len(parallel_edges) - 1 and not sequential_edges
                lines.append(f"{child_prefix}║")
                self._draw_node_tree(
                    edge.target, lines, visited.copy(), child_prefix, is_last_parallel
                )

        # Draw sequential edges
        for i, edge in enumerate(sequential_edges):
            is_last_edge = i == len(sequential_edges) - 1

            if edge.condition:
                lines.append(f"{child_prefix}│")
                lines.append(f"{child_prefix}├── [IF condition]")

            self._draw_node_tree(edge.target, lines, visited.copy(), child_prefix, is_last_edge)

    def to_mermaid(self) -> str:
        """Generate Mermaid diagram syntax for the graph.

        Returns:
            String containing Mermaid flowchart syntax
        """
        if not self.nodes:
            return "graph TD\n empty[Empty Graph]"

        lines = ["graph TD"]

        # Define nodes with appropriate shapes
        for node_id, node in self.nodes.items():
            label = f"{node_id}\\n[{node.type.value}]"

            # Choose shape based on node type: stadium for I/O, hexagon for
            # conditions, rectangle for everything else.
            if node.type in (NodeType.INPUT, NodeType.OUTPUT):
                shape = f' {node_id}(["{label}"])'
            elif node.type == NodeType.CONDITION:
                shape = f' {node_id}{{{{{label}}}}}'
            else:
                shape = f' {node_id}["{label}"]'

            lines.append(shape)

        lines.append("")

        # Define edges
        for edge in self.edges:
            if edge.condition:
                lines.append(f" {edge.source} -->|conditional| {edge.target}")
            elif edge.metadata.get("parallel", False):
                lines.append(f" {edge.source} -.parallel.-> {edge.target}")
            else:
                lines.append(f" {edge.source} --> {edge.target}")

        return "\n".join(lines)

    def to_dot(self) -> str:
        """Generate GraphViz DOT format for the graph.

        Returns:
            String containing DOT format graph definition
        """
        if not self.nodes:
            return "digraph empty { }"

        lines = [f'digraph "{self.name}" {{']
        lines.append(" rankdir=TB;")
        lines.append(" node [fontname=Arial, fontsize=10];")
        lines.append(" edge [fontname=Arial, fontsize=9];")
        lines.append("")

        # Define node styles by type
        node_styles = {
            NodeType.INPUT: 'shape=ellipse, style=filled, fillcolor=lightblue',
            NodeType.OUTPUT: 'shape=ellipse, style=filled, fillcolor=lightgreen',
            NodeType.CONDITION: 'shape=diamond, style=filled, fillcolor=lightyellow',
            NodeType.AGENT: 'shape=box, style="rounded,filled", fillcolor=lightcoral',
            NodeType.TOOL: 'shape=box, style=filled, fillcolor=lightgray',
            NodeType.HUMAN: 'shape=box, style=filled, fillcolor=lightpink',
            NodeType.SUBGRAPH: 'shape=box3d, style=filled, fillcolor=lavender',
        }

        # Define nodes
        for node_id, node in self.nodes.items():
            style = node_styles.get(node.type, 'shape=box')
            label = f"{node_id}\\n[{node.type.value}]"

            # Add status indicator
            if node.status != NodeStatus.PENDING:
                label += f"\\n({node.status.value})"

            lines.append(f' {node_id} [label="{label}", {style}];')

        lines.append("")

        # Define edges
        for edge in self.edges:
            attrs = []

            if edge.condition:
                attrs.append('label="conditional"')
                attrs.append('style=dashed')

            if edge.metadata.get("parallel", False):
                attrs.append('label="parallel"')
                attrs.append('color=blue')

            if edge.priority != 0:
                attrs.append(f'weight={edge.priority}')

            attr_str = ", ".join(attrs) if attrs else ""
            if attr_str:
                lines.append(f" {edge.source} -> {edge.target} [{attr_str}];")
            else:
                lines.append(f" {edge.source} -> {edge.target};")

        lines.append("}")

        return "\n".join(lines)

    def print_structure(self) -> None:
        """Print a simple text summary of the graph structure to stdout."""
        print(f"\nGraph: {self.name}")
        print("=" * 60)
        print(f"Nodes: {len(self.nodes)}")
        print(f"Edges: {len(self.edges)}")
        print()

        if self.nodes:
            print("Node List:")
            print("-" * 60)
            for node_id, node in self.nodes.items():
                status = node.status.value
                print(f" • {node_id:20} [{node.type.value:10}] ({status})")
            print()

        if self.edges:
            print("Edge List:")
            print("-" * 60)
            for edge in self.edges:
                condition = "conditional" if edge.condition else "unconditional"
                parallel = " [PARALLEL]" if edge.metadata.get("parallel", False) else ""
                print(f" • {edge.source:15} → {edge.target:15} ({condition}){parallel}")
            print()

        # Find entry and exit points
        entry_points = [
            node_id for node_id in self.nodes if not self.get_incoming_nodes(node_id)
        ]
        exit_points = [
            node_id for node_id in self.nodes if not self.get_outgoing_edges(node_id)
        ]

        if entry_points:
            print(f"Entry Points: {', '.join(entry_points)}")
        if exit_points:
            print(f"Exit Points: {', '.join(exit_points)}")

        print("=" * 60)
        print()
774
+
775
+
776
class WorkflowEngine(Graph):
    """Public, user-facing workflow engine.

    This is a thin compatibility wrapper around :class:`~genxai.core.graph.engine.Graph`
    to match the API expected by integration tests and external users.
    """

    def __init__(self, name: str = "workflow") -> None:
        super().__init__(name=name)

    async def execute(self, start_node: str, llm_provider: Any = None, **kwargs: Any) -> Dict[str, Any]:
        """Execute a workflow starting from a given node.

        Unlike ``Graph.run``, this does not validate the graph, search for
        entry points, or touch existing node statuses; it visits ``start_node``
        and whatever its edges lead to, using the same per-node pipeline
        (agent and tool nodes run via AgentRuntime/ToolRegistry).

        Args:
            start_node: ID of the node to start execution from.
            llm_provider: Optional provider stored as ``state["llm_provider"]``;
                agent nodes read it from the state.
            **kwargs: Recognised keys:
                ``state`` - initial state dict, mutated in place (``None`` or
                missing means a fresh empty dict);
                ``input_data`` - stored as ``state["input"]`` when not ``None``;
                ``max_iterations`` - visit limit, default 100.
                Any other keys are ignored.

        Returns:
            Dict with ``status`` (always ``"completed"`` on return),
            ``node_results`` (every state entry except ``"iterations"``) and
            ``state`` (the final state dict).

        Raises:
            GraphExecutionError: If a node fails or max_iterations is exceeded.
            KeyError: If ``start_node`` is not a node of this graph.
        """
        # An explicit ``state=None`` is treated like no state at all instead of
        # crashing on the first item assignment.
        state = kwargs.pop("state", None)
        if state is None:
            state = {}

        input_data = kwargs.pop("input_data", None)
        if input_data is not None:
            state["input"] = input_data

        # Ensure max_iterations propagates.
        max_iterations = kwargs.pop("max_iterations", 100)

        if llm_provider is not None:
            state["llm_provider"] = llm_provider

        # Execute from specified start node.
        await self._execute_node(start_node, state, max_iterations)
        return {
            "status": "completed",
            "node_results": {key: value for key, value in state.items() if key != "iterations"},
            "state": state,
        }