kailash 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +31 -0
- kailash/__main__.py +11 -0
- kailash/cli/__init__.py +5 -0
- kailash/cli/commands.py +563 -0
- kailash/manifest.py +778 -0
- kailash/nodes/__init__.py +23 -0
- kailash/nodes/ai/__init__.py +26 -0
- kailash/nodes/ai/agents.py +417 -0
- kailash/nodes/ai/models.py +488 -0
- kailash/nodes/api/__init__.py +52 -0
- kailash/nodes/api/auth.py +567 -0
- kailash/nodes/api/graphql.py +480 -0
- kailash/nodes/api/http.py +598 -0
- kailash/nodes/api/rate_limiting.py +572 -0
- kailash/nodes/api/rest.py +665 -0
- kailash/nodes/base.py +1032 -0
- kailash/nodes/base_async.py +128 -0
- kailash/nodes/code/__init__.py +32 -0
- kailash/nodes/code/python.py +1021 -0
- kailash/nodes/data/__init__.py +125 -0
- kailash/nodes/data/readers.py +496 -0
- kailash/nodes/data/sharepoint_graph.py +623 -0
- kailash/nodes/data/sql.py +380 -0
- kailash/nodes/data/streaming.py +1168 -0
- kailash/nodes/data/vector_db.py +964 -0
- kailash/nodes/data/writers.py +529 -0
- kailash/nodes/logic/__init__.py +6 -0
- kailash/nodes/logic/async_operations.py +702 -0
- kailash/nodes/logic/operations.py +551 -0
- kailash/nodes/transform/__init__.py +5 -0
- kailash/nodes/transform/processors.py +379 -0
- kailash/runtime/__init__.py +6 -0
- kailash/runtime/async_local.py +356 -0
- kailash/runtime/docker.py +697 -0
- kailash/runtime/local.py +434 -0
- kailash/runtime/parallel.py +557 -0
- kailash/runtime/runner.py +110 -0
- kailash/runtime/testing.py +347 -0
- kailash/sdk_exceptions.py +307 -0
- kailash/tracking/__init__.py +7 -0
- kailash/tracking/manager.py +885 -0
- kailash/tracking/metrics_collector.py +342 -0
- kailash/tracking/models.py +535 -0
- kailash/tracking/storage/__init__.py +0 -0
- kailash/tracking/storage/base.py +113 -0
- kailash/tracking/storage/database.py +619 -0
- kailash/tracking/storage/filesystem.py +543 -0
- kailash/utils/__init__.py +0 -0
- kailash/utils/export.py +924 -0
- kailash/utils/templates.py +680 -0
- kailash/visualization/__init__.py +62 -0
- kailash/visualization/api.py +732 -0
- kailash/visualization/dashboard.py +951 -0
- kailash/visualization/performance.py +808 -0
- kailash/visualization/reports.py +1471 -0
- kailash/workflow/__init__.py +15 -0
- kailash/workflow/builder.py +245 -0
- kailash/workflow/graph.py +827 -0
- kailash/workflow/mermaid_visualizer.py +628 -0
- kailash/workflow/mock_registry.py +63 -0
- kailash/workflow/runner.py +302 -0
- kailash/workflow/state.py +238 -0
- kailash/workflow/visualization.py +588 -0
- kailash-0.1.0.dist-info/METADATA +710 -0
- kailash-0.1.0.dist-info/RECORD +69 -0
- kailash-0.1.0.dist-info/WHEEL +5 -0
- kailash-0.1.0.dist-info/entry_points.txt +2 -0
- kailash-0.1.0.dist-info/licenses/LICENSE +21 -0
- kailash-0.1.0.dist-info/top_level.txt +1 -0
kailash/workflow/graph.py
@@ -0,0 +1,827 @@
+"""Workflow DAG implementation for the Kailash SDK."""
+
+import json
+import logging
+import uuid
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional, Tuple
+
+import networkx as nx
+import yaml
+from pydantic import BaseModel, Field, ValidationError
+
+from kailash.nodes import Node
+
+try:
+    # For normal runtime, use the actual registry
+    from kailash.nodes import NodeRegistry
+except ImportError:
+    # For tests, use the mock registry
+    from kailash.workflow.mock_registry import MockRegistry as NodeRegistry
+
+from kailash.sdk_exceptions import (
+    ConnectionError,
+    ExportException,
+    NodeConfigurationError,
+    WorkflowExecutionError,
+    WorkflowValidationError,
+)
+from kailash.tracking import TaskManager, TaskStatus
+from kailash.workflow.state import WorkflowStateWrapper
+
+logger = logging.getLogger(__name__)
+
+
+class NodeInstance(BaseModel):
+    """Instance of a node in a workflow."""
+
+    node_id: str = Field(..., description="Unique identifier for this instance")
+    node_type: str = Field(..., description="Type of node")
+    config: Dict[str, Any] = Field(
+        default_factory=dict, description="Node configuration"
+    )
+    position: Tuple[float, float] = Field(default=(0, 0), description="Visual position")
+
+
+class Connection(BaseModel):
+    """Connection between two nodes in a workflow."""
+
+    source_node: str = Field(..., description="Source node ID")
+    source_output: str = Field(..., description="Output field from source")
+    target_node: str = Field(..., description="Target node ID")
+    target_input: str = Field(..., description="Input field on target")
+
+
+class Workflow:
+    """Represents a workflow DAG of nodes."""
+
+    def __init__(
+        self,
+        workflow_id: str,
+        name: str,
+        description: str = "",
+        version: str = "1.0.0",
+        author: str = "",
+        metadata: Optional[Dict[str, Any]] = None,
+    ):
+        """Initialize a workflow.
+
+        Args:
+            workflow_id: Unique workflow identifier
+            name: Workflow name
+            description: Workflow description
+            version: Workflow version
+            author: Workflow author
+            metadata: Additional metadata
+
+        Raises:
+            WorkflowValidationError: If workflow initialization fails
+        """
+        self.workflow_id = workflow_id
+        self.name = name
+        self.description = description
+        self.version = version
+        self.author = author
+        self.metadata = metadata or {}
+
+        # Add standard metadata
+        if "author" not in self.metadata and author:
+            self.metadata["author"] = author
+        if "version" not in self.metadata and version:
+            self.metadata["version"] = version
+        if "created_at" not in self.metadata:
+            self.metadata["created_at"] = datetime.now(timezone.utc).isoformat()
+
+        # Create directed graph for the workflow
+        self.graph = nx.DiGraph()
+
+        # Storage for node instances and node metadata
+        self._node_instances = {}  # Maps node_id to Node instances
+        self.nodes = {}  # Maps node_id to NodeInstance metadata objects
+        self.connections = []  # List of Connection objects
+
+        logger.info(f"Created workflow '{name}' (ID: {workflow_id})")
+
+    def add_node(self, node_id: str, node_or_type: Any, **config) -> None:
+        """Add a node to the workflow.
+
+        Args:
+            node_id: Unique identifier for this node instance
+            node_or_type: Either a Node instance, Node class, or node type name
+            **config: Configuration for the node
+
+        Raises:
+            WorkflowValidationError: If node is invalid
+            NodeConfigurationError: If node configuration fails
+        """
+        if node_id in self.nodes:
+            raise WorkflowValidationError(
+                f"Node '{node_id}' already exists in workflow. "
+                f"Existing nodes: {list(self.nodes.keys())}"
+            )
+
+        try:
+            # Handle different input types
+            if isinstance(node_or_type, str):
+                # Node type name provided
+                node_class = NodeRegistry.get(node_or_type)
+                node_instance = node_class(id=node_id, **config)
+                node_type = node_or_type
+            elif isinstance(node_or_type, type) and issubclass(node_or_type, Node):
+                # Node class provided
+                node_instance = node_or_type(id=node_id, **config)
+                node_type = node_or_type.__name__
+            elif isinstance(node_or_type, Node):
+                # Node instance provided
+                node_instance = node_or_type
+                node_instance.id = node_id
+                node_type = node_instance.__class__.__name__
+                # Update config - handle nested config case
+                if "config" in node_instance.config and isinstance(
+                    node_instance.config["config"], dict
+                ):
+                    # If config is nested, extract it
+                    actual_config = node_instance.config["config"]
+                    node_instance.config.update(actual_config)
+                    # Remove the nested config key
+                    del node_instance.config["config"]
+                # Now update with provided config
+                node_instance.config.update(config)
+                node_instance._validate_config()
+            else:
+                raise WorkflowValidationError(
+                    f"Invalid node type: {type(node_or_type)}. "
+                    "Expected: str (node type name), Node class, or Node instance"
+                )
+        except NodeConfigurationError:
+            # Re-raise configuration errors with additional context
+            raise
+        except Exception as e:
+            raise NodeConfigurationError(
+                f"Failed to create node '{node_id}' of type '{node_or_type}': {e}"
+            ) from e
+
+        # Store node instance and metadata
+        try:
+            node_instance_data = NodeInstance(
+                node_id=node_id,
+                node_type=node_type,
+                config=config,
+                position=(len(self.nodes) * 150, 100),
+            )
+            self.nodes[node_id] = node_instance_data
+        except ValidationError as e:
+            raise WorkflowValidationError(f"Invalid node instance data: {e}") from e
+
+        self._node_instances[node_id] = node_instance
+
+        # Add to graph
+        self.graph.add_node(node_id, node=node_instance, type=node_type, config=config)
+        logger.info(f"Added node '{node_id}' of type '{node_type}'")
+
+    def _add_node_internal(
+        self, node_id: str, node_type: str, config: Optional[Dict[str, Any]] = None
+    ) -> None:
+        """Add a node to the workflow (internal method).
+
+        Args:
+            node_id: Node identifier
+            node_type: Node type name
+            config: Node configuration
+        """
+        # This method is used by WorkflowBuilder and from_dict
+        config = config or {}
+        self.add_node(node_id=node_id, node_or_type=node_type, **config)
+
+    def connect(
+        self,
+        source_node: str,
+        target_node: str,
+        mapping: Optional[Dict[str, str]] = None,
+    ) -> None:
+        """Connect two nodes in the workflow.
+
+        Args:
+            source_node: Source node ID
+            target_node: Target node ID
+            mapping: Dict mapping source outputs to target inputs
+
+        Raises:
+            ConnectionError: If connection is invalid
+            WorkflowValidationError: If nodes don't exist
+        """
+        if source_node not in self.nodes:
+            available_nodes = ", ".join(self.nodes.keys())
+            raise WorkflowValidationError(
+                f"Source node '{source_node}' not found in workflow. "
+                f"Available nodes: {available_nodes}"
+            )
+        if target_node not in self.nodes:
+            available_nodes = ", ".join(self.nodes.keys())
+            raise WorkflowValidationError(
+                f"Target node '{target_node}' not found in workflow. "
+                f"Available nodes: {available_nodes}"
+            )
+
+        # Self-connection check
+        if source_node == target_node:
+            raise ConnectionError(f"Cannot connect node '{source_node}' to itself")
+
+        # Default mapping if not provided
+        if mapping is None:
+            mapping = {"output": "input"}
+
+        # Check for existing connections
+        existing_connections = [
+            c
+            for c in self.connections
+            if c.source_node == source_node and c.target_node == target_node
+        ]
+        if existing_connections:
+            raise ConnectionError(
+                f"Connection already exists between '{source_node}' and '{target_node}'. "
+                f"Existing mappings: {[c.model_dump() for c in existing_connections]}"
+            )
+
+        # Create connections
+        for source_output, target_input in mapping.items():
+            try:
+                connection = Connection(
+                    source_node=source_node,
+                    source_output=source_output,
+                    target_node=target_node,
+                    target_input=target_input,
+                )
+            except ValidationError as e:
+                raise ConnectionError(f"Invalid connection data: {e}") from e
+
+            self.connections.append(connection)
+
+            # Add edge to graph
+            self.graph.add_edge(
+                source_node,
+                target_node,
+                from_output=source_output,
+                to_input=target_input,
+                mapping={
+                    source_output: target_input
+                },  # Keep for backward compatibility
+            )
+
+        logger.info(
+            f"Connected '{source_node}' to '{target_node}' with mapping: {mapping}"
+        )
+
+    def _add_edge_internal(
+        self, from_node: str, from_output: str, to_node: str, to_input: str
+    ) -> None:
+        """Add an edge between nodes (internal method).
+
+        Args:
+            from_node: Source node ID
+            from_output: Output field from source
+            to_node: Target node ID
+            to_input: Input field on target
+        """
+        # This method is used by WorkflowBuilder and from_dict
+        self.connect(
+            source_node=from_node, target_node=to_node, mapping={from_output: to_input}
+        )
+
+    def get_node(self, node_id: str) -> Optional[Node]:
+        """Get node instance by ID.
+
+        Args:
+            node_id: Node identifier
+
+        Returns:
+            Node instance or None if not found
+        """
+        if node_id not in self.graph.nodes:
+            return None
+
+        # First try to get from graph (for test compatibility)
+        graph_node = self.graph.nodes[node_id].get("node")
+        if graph_node:
+            return graph_node
+
+        # Fallback to _node_instances
+        return self._node_instances.get(node_id)
+
+    def get_execution_order(self) -> List[str]:
+        """Get topological execution order for nodes.
+
+        Returns:
+            List of node IDs in execution order
+
+        Raises:
+            WorkflowValidationError: If workflow contains cycles
+        """
+        try:
+            return list(nx.topological_sort(self.graph))
+        except nx.NetworkXUnfeasible:
+            cycles = list(nx.simple_cycles(self.graph))
+            raise WorkflowValidationError(
+                f"Workflow contains cycles: {cycles}. "
+                "Remove circular dependencies to create a valid workflow."
+            )
+
+    def validate(self) -> None:
+        """Validate the workflow structure.
+
+        Raises:
+            WorkflowValidationError: If workflow is invalid
+        """
+        # Check for cycles
+        try:
+            self.get_execution_order()
+        except WorkflowValidationError:
+            raise
+
+        # Check all nodes have required inputs
+        for node_id, node_instance in self._node_instances.items():
+            try:
+                params = node_instance.get_parameters()
+            except Exception as e:
+                raise WorkflowValidationError(
+                    f"Failed to get parameters for node '{node_id}': {e}"
+                ) from e
+
+            # Get inputs from connections
+            incoming_edges = self.graph.in_edges(node_id, data=True)
+            connected_inputs = set()
+
+            for _, _, data in incoming_edges:
+                to_input = data.get("to_input")
+                if to_input:
+                    connected_inputs.add(to_input)
+                # For backward compatibility
+                mapping = data.get("mapping", {})
+                connected_inputs.update(mapping.values())
+
+            # Check required parameters
+            missing_inputs = []
+            for param_name, param_def in params.items():
+                if param_def.required and param_name not in connected_inputs:
+                    # Check if it's provided in config
+                    # Handle nested config case (for PythonCodeNode and similar)
+                    found_in_config = param_name in node_instance.config
+                    if not found_in_config and "config" in node_instance.config:
+                        # Check nested config
+                        found_in_config = param_name in node_instance.config["config"]
+
+                    if not found_in_config:
+                        if param_def.default is None:
+                            missing_inputs.append(param_name)
+
+            if missing_inputs:
+                raise WorkflowValidationError(
+                    f"Node '{node_id}' missing required inputs: {missing_inputs}. "
+                    f"Provide these inputs via connections or node configuration"
+                )
+
+        logger.info(f"Workflow '{self.name}' validated successfully")
+
+    def run(
+        self, task_manager: Optional[TaskManager] = None, **overrides
+    ) -> Tuple[Dict[str, Any], Optional[str]]:
+        """Execute the workflow.
+
+        Args:
+            task_manager: Optional task manager for tracking
+            **overrides: Parameter overrides
+
+        Returns:
+            Tuple of (results dict, run_id)
+
+        Raises:
+            WorkflowExecutionError: If workflow execution fails
+            WorkflowValidationError: If workflow is invalid
+        """
+        # For backward compatibility with original graph.py's run method
+        return self.execute(inputs=overrides, task_manager=task_manager), None
+
+    def execute(
+        self,
+        inputs: Optional[Dict[str, Any]] = None,
+        task_manager: Optional[TaskManager] = None,
+    ) -> Dict[str, Any]:
+        """Execute the workflow.
+
+        Args:
+            inputs: Input data for the workflow (can include node overrides)
+            task_manager: Optional task manager for tracking
+
+        Returns:
+            Execution results by node
+
+        Raises:
+            WorkflowExecutionError: If execution fails
+        """
+        try:
+            self.validate()
+        except Exception as e:
+            raise WorkflowValidationError(f"Workflow validation failed: {e}") from e
+
+        # Initialize task tracking
+        run_id = None
+        if task_manager:
+            try:
+                run_id = task_manager.create_run(
+                    workflow_name=self.name, metadata={"inputs": inputs}
+                )
+            except Exception as e:
+                logger.warning(f"Failed to create task run: {e}")
+                # Continue without task tracking
+
+        # Get execution order
+        try:
+            execution_order = self.get_execution_order()
+        except Exception as e:
+            raise WorkflowExecutionError(
+                f"Failed to determine execution order: {e}"
+            ) from e
+
+        # Execute nodes in order
+        results = {}
+        inputs = inputs or {}
+        failed_nodes = []
+
+        for node_id in execution_order:
+            node_instance = self._node_instances[node_id]
+
+            # Start task tracking
+            task = None
+            if task_manager and run_id:
+                try:
+                    task = task_manager.create_task(
+                        run_id=run_id,
+                        node_id=node_id,
+                        node_type=node_instance.__class__.__name__,
+                    )
+                    task.update_status(TaskStatus.RUNNING)
+                except Exception as e:
+                    logger.warning(f"Failed to create task for node '{node_id}': {e}")
+
+            try:
+                # Gather inputs from previous nodes
+                node_inputs = {}
+
+                # Add config values
+                node_inputs.update(node_instance.config)
+
+                # Get inputs from connected nodes
+                for edge in self.graph.in_edges(node_id, data=True):
+                    source_node_id = edge[0]
+                    edge_data = self.graph[source_node_id][node_id]
+
+                    # Try both connection formats for backward compatibility
+                    from_output = edge_data.get("from_output")
+                    to_input = edge_data.get("to_input")
+                    mapping = edge_data.get("mapping", {})
+
+                    source_results = results.get(source_node_id, {})
+
+                    # Add connections using from_output/to_input format
+                    if from_output and to_input and from_output in source_results:
+                        node_inputs[to_input] = source_results[from_output]
+
+                    # Also add connections using mapping format for backward compatibility
+                    for source_key, target_key in mapping.items():
+                        if source_key in source_results:
+                            node_inputs[target_key] = source_results[source_key]
+
+                # Apply overrides
+                node_overrides = inputs.get(node_id, {})
+                node_inputs.update(node_overrides)
+
+                # Execute node
+                logger.info(
+                    f"Executing node '{node_id}' with inputs: {list(node_inputs.keys())}"
+                )
+
+                # Support both process() and execute() methods
+                if hasattr(node_instance, "process") and callable(
+                    node_instance.process
+                ):
+                    node_results = node_instance.process(node_inputs)
+                else:
+                    node_results = node_instance.execute(**node_inputs)
+
+                results[node_id] = node_results
+
+                if task:
+                    task.update_status(TaskStatus.COMPLETED, result=node_results)
+
+                logger.info(f"Node '{node_id}' completed successfully")
+
+            except Exception as e:
+                failed_nodes.append(node_id)
+                if task:
+                    task.update_status(TaskStatus.FAILED, error=str(e))
+
+                # Include previous failures in error message
+                error_msg = f"Node '{node_id}' failed: {e}"
+                if len(failed_nodes) > 1:
+                    error_msg += f" (Previously failed nodes: {failed_nodes[:-1]})"
+
+                raise WorkflowExecutionError(error_msg) from e
+
+        logger.info(
+            f"Workflow '{self.name}' completed successfully. "
+            f"Executed {len(execution_order)} nodes"
+        )
+        return results
+
+    def export_to_kailash(
+        self, output_path: str, format: str = "yaml", **config
+    ) -> None:
+        """Export workflow to Kailash-compatible format.
+
+        Args:
+            output_path: Path to write file
+            format: Export format (yaml, json, manifest)
+            **config: Additional export configuration
+
+        Raises:
+            ExportException: If export fails
+        """
+        try:
+            from kailash.utils.export import export_workflow
+
+            export_workflow(self, format=format, output_path=output_path, **config)
+        except ImportError as e:
+            raise ExportException(f"Failed to import export utilities: {e}") from e
+        except Exception as e:
+            raise ExportException(
+                f"Failed to export workflow to '{output_path}': {e}"
+            ) from e
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert workflow to dictionary.
+
+        Returns:
+            Dictionary representation
+        """
+        # Build nodes dictionary
+        nodes_dict = {}
+        for node_id, node_data in self.nodes.items():
+            nodes_dict[node_id] = node_data.model_dump()
+
+        # Build connections list
+        connections_list = [conn.model_dump() for conn in self.connections]
+
+        # Build workflow dictionary
+        return {
+            "workflow_id": self.workflow_id,
+            "name": self.name,
+            "description": self.description,
+            "version": self.version,
+            "author": self.author,
+            "metadata": self.metadata,
+            "nodes": nodes_dict,
+            "connections": connections_list,
+        }
+
+    def to_json(self) -> str:
+        """Convert workflow to JSON string.
+
+        Returns:
+            JSON representation
+        """
+        return json.dumps(self.to_dict(), indent=2)
+
+    def to_yaml(self) -> str:
+        """Convert workflow to YAML string.
+
+        Returns:
+            YAML representation
+        """
+        return yaml.dump(self.to_dict(), default_flow_style=False)
+
+    def save(self, path: str, format: str = "json") -> None:
+        """Save workflow to file.
+
+        Args:
+            path: Output file path
+            format: Output format (json or yaml)
+
+        Raises:
+            ValueError: If format is invalid
+        """
+        if format == "json":
+            with open(path, "w") as f:
+                f.write(self.to_json())
+        elif format == "yaml":
+            with open(path, "w") as f:
+                f.write(self.to_yaml())
+        else:
+            raise ValueError(f"Unsupported format: {format}")
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "Workflow":
+        """Create workflow from dictionary.
+
+        Args:
+            data: Dictionary representation
+
+        Returns:
+            Workflow instance
+
+        Raises:
+            WorkflowValidationError: If data is invalid
+        """
+        try:
+            # Extract basic data
+            workflow_id = data.get("workflow_id", str(uuid.uuid4()))
+            name = data.get("name", "Unnamed Workflow")
+            description = data.get("description", "")
+            version = data.get("version", "1.0.0")
+            author = data.get("author", "")
+            metadata = data.get("metadata", {})
+
+            # Create workflow
+            workflow = cls(
+                workflow_id=workflow_id,
+                name=name,
+                description=description,
+                version=version,
+                author=author,
+                metadata=metadata,
+            )
+
+            # Add nodes
+            nodes_data = data.get("nodes", {})
+            for node_id, node_data in nodes_data.items():
+                # Handle both formats of node data
+                if isinstance(node_data, dict):
+                    # Get node type
+                    node_type = node_data.get("node_type") or node_data.get("type")
+                    if not node_type:
+                        raise WorkflowValidationError(
+                            f"Node type not specified for node '{node_id}'"
+                        )
+
+                    # Get node config
+                    config = node_data.get("config", {})
+
+                    # Add the node
+                    workflow._add_node_internal(node_id, node_type, config)
+                else:
+                    raise WorkflowValidationError(
+                        f"Invalid node data format for node '{node_id}': {type(node_data)}"
+                    )
+
+            # Add connections
+            connections = data.get("connections", [])
+            for conn_data in connections:
+                # Handle both connection formats
+                if "source_node" in conn_data and "target_node" in conn_data:
+                    # Original format
+                    source_node = conn_data.get("source_node")
+                    source_output = conn_data.get("source_output")
+                    target_node = conn_data.get("target_node")
+                    target_input = conn_data.get("target_input")
+                    workflow._add_edge_internal(
+                        source_node, source_output, target_node, target_input
+                    )
+                elif "from_node" in conn_data and "to_node" in conn_data:
+                    # Updated format
+                    from_node = conn_data.get("from_node")
+                    from_output = conn_data.get("from_output", "output")
+                    to_node = conn_data.get("to_node")
+                    to_input = conn_data.get("to_input", "input")
+                    workflow._add_edge_internal(
+                        from_node, from_output, to_node, to_input
+                    )
+                else:
+                    raise WorkflowValidationError(
+                        f"Invalid connection data: {conn_data}"
+                    )
+
+            return workflow
+
+        except Exception as e:
+            if isinstance(e, WorkflowValidationError):
+                raise
+            raise WorkflowValidationError(
+                f"Failed to create workflow from dict: {e}"
+            ) from e
+
+    def __repr__(self) -> str:
+        """Get string representation."""
+        return f"Workflow(id='{self.workflow_id}', name='{self.name}', nodes={len(self.graph.nodes)}, connections={len(self.graph.edges)})"
+
+    def __str__(self) -> str:
+        """Get readable string."""
+        return f"Workflow '{self.name}' (ID: {self.workflow_id}) with {len(self.graph.nodes)} nodes and {len(self.graph.edges)} connections"
+
+    def create_state_wrapper(self, state_model: BaseModel) -> WorkflowStateWrapper:
+        """Create a state manager wrapper for a workflow.
+
+        This wrapper provides convenient methods for updating state immutably,
+        making it easier to manage state in workflow nodes.
+
+        Args:
+            state_model: The Pydantic model state object to wrap
+
+        Returns:
+            A WorkflowStateWrapper instance
+
+        Raises:
+            TypeError: If state_model is not a Pydantic BaseModel
+        """
+        if not isinstance(state_model, BaseModel):
+            raise TypeError(f"Expected BaseModel, got {type(state_model)}")
+
+        return WorkflowStateWrapper(state_model)
+
+    def execute_with_state(
+        self,
+        state_model: BaseModel,
+        wrap_state: bool = True,
+        task_manager: Optional[TaskManager] = None,
+        **overrides,
+    ) -> Tuple[BaseModel, Dict[str, Any]]:
+        """Execute the workflow with state management.
+
+        This method provides a simplified interface for executing workflows
+        with automatic state management, making it easier to manage state
+        transitions.
+
+        Args:
+            state_model: The initial state for workflow execution
+            wrap_state: Whether to wrap state in WorkflowStateWrapper
+            task_manager: Optional task manager for tracking
+            **overrides: Additional parameter overrides
+
+        Returns:
+            Tuple of (final state, all results)
+
+        Raises:
+            WorkflowExecutionError: If execution fails
+            WorkflowValidationError: If workflow is invalid
+        """
+        # Validate input
+        if not isinstance(state_model, BaseModel):
+            raise TypeError(f"Expected BaseModel, got {type(state_model)}")
+
+        # Prepare inputs
+        inputs = {}
+
+        # Wrap the state if needed
+        if wrap_state:
+            state_wrapper = self.create_state_wrapper(state_model)
+            # Find entry nodes (nodes with no incoming edges) and provide state_wrapper to them
+            for node_id in self.nodes:
+                if self.graph.in_degree(node_id) == 0:  # Entry node
+                    inputs[node_id] = {"state_wrapper": state_wrapper}
+        else:
+            # Find entry nodes and provide unwrapped state to them
+            for node_id in self.nodes:
+                if self.graph.in_degree(node_id) == 0:  # Entry node
+                    inputs[node_id] = {"state": state_model}
+
+        # Add any additional overrides
+        for key, value in overrides.items():
+            if key in self.nodes:
+                inputs.setdefault(key, {}).update(value)
+
+        # Execute the workflow
+        results = self.execute(inputs=inputs, task_manager=task_manager)
+
+        # Find the final state
+        # First try to find state_wrapper in the last node's outputs
+        execution_order = self.get_execution_order()
+        if execution_order:
+            last_node_id = execution_order[-1]
+            last_node_results = results.get(last_node_id, {})
+
+            if wrap_state:
+                final_state_wrapper = last_node_results.get("state_wrapper")
+                if final_state_wrapper and isinstance(
+                    final_state_wrapper, WorkflowStateWrapper
+                ):
+                    return final_state_wrapper.get_state(), results
+
+                # Try to find another key with a WorkflowStateWrapper
+                for key, value in last_node_results.items():
+                    if isinstance(value, WorkflowStateWrapper):
+                        return value.get_state(), results
+            else:
+                final_state = last_node_results.get("state")
+                if final_state and isinstance(final_state, BaseModel):
+                    return final_state, results
+
+                # Try to find another key with a BaseModel
+                for key, value in last_node_results.items():
+                    if isinstance(value, BaseModel) and type(value) == type(
+                        state_model
+                    ):
+                        return value, results
+
+        # Fallback to original state
+        logger.warning(
+            "Failed to find final state in workflow results, returning original state"
+        )
+        return state_model, results
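
Taken together, add_node, connect, validate, and execute form the basic build-and-run loop of the Workflow class above. The following is a minimal usage sketch of that loop; the node type names ("CSVReaderNode", "FilterNode") and their parameters are illustrative placeholders, not confirmed entries in this package's NodeRegistry.

from kailash.workflow.graph import Workflow

# Build a two-node pipeline (node types below are hypothetical placeholders).
workflow = Workflow(workflow_id="demo-001", name="Demo Pipeline", author="docs")
workflow.add_node("reader", "CSVReaderNode", file_path="input.csv")
workflow.add_node("filter", "FilterNode", column="status", value="ok")
workflow.connect("reader", "filter", mapping={"data": "data"})

workflow.validate()                     # raises WorkflowValidationError on cycles or missing inputs
print(workflow.get_execution_order())   # topological order, e.g. ['reader', 'filter']

# Per-node overrides are passed as {node_id: {param: value}}.
results = workflow.execute(inputs={"reader": {"file_path": "other.csv"}})
print(results["filter"])                # output dict of the final node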
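
to_dict/save and from_dict are symmetric, so a workflow built in code can be persisted and re-created later. A sketch, continuing the workflow above and assuming the node types referenced in the saved spec resolve through NodeRegistry when reloaded:

workflow.save("pipeline.yaml", format="yaml")   # or format="json"

spec = workflow.to_dict()
clone = Workflow.from_dict(spec)   # re-instantiates nodes and connections from the dict
assert clone.name == workflow.name
print(clone)                       # "Workflow 'Demo Pipeline' (ID: ...) with 2 nodes and 1 connections"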
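
execute_with_state wraps a Pydantic model in a WorkflowStateWrapper, injects it into entry nodes as "state_wrapper" (or as "state" when wrap_state=False), and pulls the final state out of the last node's results. A sketch with a hypothetical state model, again continuing the workflow above; what the nodes actually do with the wrapper depends on their implementations, which live outside this file:

from pydantic import BaseModel

class OrderState(BaseModel):
    total: float = 0.0
    approved: bool = False

final_state, results = workflow.execute_with_state(OrderState(total=99.0))
print(final_state)     # updated OrderState, or the original if no node returned state
print(list(results))   # node IDs that produced results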
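
Because get_execution_order is a topological sort of the underlying networkx DiGraph, any cycle surfaces as a WorkflowValidationError rather than an infinite loop; a sketch, with the node type once more a hypothetical placeholder:

from kailash.sdk_exceptions import WorkflowValidationError

looped = Workflow(workflow_id="loop-001", name="Loop Demo")
looped.add_node("a", "IdentityNode")
looped.add_node("b", "IdentityNode")
looped.connect("a", "b")   # default mapping is {"output": "input"}
looped.connect("b", "a")
try:
    looped.validate()
except WorkflowValidationError as err:
    print(err)             # "Workflow contains cycles: [['a', 'b']] ..."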