kailash-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. kailash/__init__.py +31 -0
  2. kailash/__main__.py +11 -0
  3. kailash/cli/__init__.py +5 -0
  4. kailash/cli/commands.py +563 -0
  5. kailash/manifest.py +778 -0
  6. kailash/nodes/__init__.py +23 -0
  7. kailash/nodes/ai/__init__.py +26 -0
  8. kailash/nodes/ai/agents.py +417 -0
  9. kailash/nodes/ai/models.py +488 -0
  10. kailash/nodes/api/__init__.py +52 -0
  11. kailash/nodes/api/auth.py +567 -0
  12. kailash/nodes/api/graphql.py +480 -0
  13. kailash/nodes/api/http.py +598 -0
  14. kailash/nodes/api/rate_limiting.py +572 -0
  15. kailash/nodes/api/rest.py +665 -0
  16. kailash/nodes/base.py +1032 -0
  17. kailash/nodes/base_async.py +128 -0
  18. kailash/nodes/code/__init__.py +32 -0
  19. kailash/nodes/code/python.py +1021 -0
  20. kailash/nodes/data/__init__.py +125 -0
  21. kailash/nodes/data/readers.py +496 -0
  22. kailash/nodes/data/sharepoint_graph.py +623 -0
  23. kailash/nodes/data/sql.py +380 -0
  24. kailash/nodes/data/streaming.py +1168 -0
  25. kailash/nodes/data/vector_db.py +964 -0
  26. kailash/nodes/data/writers.py +529 -0
  27. kailash/nodes/logic/__init__.py +6 -0
  28. kailash/nodes/logic/async_operations.py +702 -0
  29. kailash/nodes/logic/operations.py +551 -0
  30. kailash/nodes/transform/__init__.py +5 -0
  31. kailash/nodes/transform/processors.py +379 -0
  32. kailash/runtime/__init__.py +6 -0
  33. kailash/runtime/async_local.py +356 -0
  34. kailash/runtime/docker.py +697 -0
  35. kailash/runtime/local.py +434 -0
  36. kailash/runtime/parallel.py +557 -0
  37. kailash/runtime/runner.py +110 -0
  38. kailash/runtime/testing.py +347 -0
  39. kailash/sdk_exceptions.py +307 -0
  40. kailash/tracking/__init__.py +7 -0
  41. kailash/tracking/manager.py +885 -0
  42. kailash/tracking/metrics_collector.py +342 -0
  43. kailash/tracking/models.py +535 -0
  44. kailash/tracking/storage/__init__.py +0 -0
  45. kailash/tracking/storage/base.py +113 -0
  46. kailash/tracking/storage/database.py +619 -0
  47. kailash/tracking/storage/filesystem.py +543 -0
  48. kailash/utils/__init__.py +0 -0
  49. kailash/utils/export.py +924 -0
  50. kailash/utils/templates.py +680 -0
  51. kailash/visualization/__init__.py +62 -0
  52. kailash/visualization/api.py +732 -0
  53. kailash/visualization/dashboard.py +951 -0
  54. kailash/visualization/performance.py +808 -0
  55. kailash/visualization/reports.py +1471 -0
  56. kailash/workflow/__init__.py +15 -0
  57. kailash/workflow/builder.py +245 -0
  58. kailash/workflow/graph.py +827 -0
  59. kailash/workflow/mermaid_visualizer.py +628 -0
  60. kailash/workflow/mock_registry.py +63 -0
  61. kailash/workflow/runner.py +302 -0
  62. kailash/workflow/state.py +238 -0
  63. kailash/workflow/visualization.py +588 -0
  64. kailash-0.1.0.dist-info/METADATA +710 -0
  65. kailash-0.1.0.dist-info/RECORD +69 -0
  66. kailash-0.1.0.dist-info/WHEEL +5 -0
  67. kailash-0.1.0.dist-info/entry_points.txt +2 -0
  68. kailash-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. kailash-0.1.0.dist-info/top_level.txt +1 -0
kailash/workflow/graph.py
@@ -0,0 +1,827 @@
+ """Workflow DAG implementation for the Kailash SDK."""
+
+ import json
+ import logging
+ import uuid
+ from datetime import datetime, timezone
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import networkx as nx
+ import yaml
+ from pydantic import BaseModel, Field, ValidationError
+
+ from kailash.nodes import Node
+
+ try:
+     # For normal runtime, use the actual registry
+     from kailash.nodes import NodeRegistry
+ except ImportError:
+     # For tests, use the mock registry
+     from kailash.workflow.mock_registry import MockRegistry as NodeRegistry
+
+ from kailash.sdk_exceptions import (
+     ConnectionError,
+     ExportException,
+     NodeConfigurationError,
+     WorkflowExecutionError,
+     WorkflowValidationError,
+ )
+ from kailash.tracking import TaskManager, TaskStatus
+ from kailash.workflow.state import WorkflowStateWrapper
+
+ logger = logging.getLogger(__name__)
+
+
+ class NodeInstance(BaseModel):
+     """Instance of a node in a workflow."""
+
+     node_id: str = Field(..., description="Unique identifier for this instance")
+     node_type: str = Field(..., description="Type of node")
+     config: Dict[str, Any] = Field(
+         default_factory=dict, description="Node configuration"
+     )
+     position: Tuple[float, float] = Field(default=(0, 0), description="Visual position")
+
+
+ class Connection(BaseModel):
+     """Connection between two nodes in a workflow."""
+
+     source_node: str = Field(..., description="Source node ID")
+     source_output: str = Field(..., description="Output field from source")
+     target_node: str = Field(..., description="Target node ID")
+     target_input: str = Field(..., description="Input field on target")
+
+
+ class Workflow:
+     """Represents a workflow DAG of nodes."""
+
+     def __init__(
+         self,
+         workflow_id: str,
+         name: str,
+         description: str = "",
+         version: str = "1.0.0",
+         author: str = "",
+         metadata: Optional[Dict[str, Any]] = None,
+     ):
+         """Initialize a workflow.
+
+         Args:
+             workflow_id: Unique workflow identifier
+             name: Workflow name
+             description: Workflow description
+             version: Workflow version
+             author: Workflow author
+             metadata: Additional metadata
+
+         Raises:
+             WorkflowValidationError: If workflow initialization fails
+         """
+         self.workflow_id = workflow_id
+         self.name = name
+         self.description = description
+         self.version = version
+         self.author = author
+         self.metadata = metadata or {}
+
+         # Add standard metadata
+         if "author" not in self.metadata and author:
+             self.metadata["author"] = author
+         if "version" not in self.metadata and version:
+             self.metadata["version"] = version
+         if "created_at" not in self.metadata:
+             self.metadata["created_at"] = datetime.now(timezone.utc).isoformat()
+
+         # Create directed graph for the workflow
+         self.graph = nx.DiGraph()
+
+         # Storage for node instances and node metadata
+         self._node_instances = {}  # Maps node_id to Node instances
+         self.nodes = {}  # Maps node_id to NodeInstance metadata objects
+         self.connections = []  # List of Connection objects
+
+         logger.info(f"Created workflow '{name}' (ID: {workflow_id})")
+
+     def add_node(self, node_id: str, node_or_type: Any, **config) -> None:
+         """Add a node to the workflow.
+
+         Args:
+             node_id: Unique identifier for this node instance
+             node_or_type: Either a Node instance, Node class, or node type name
+             **config: Configuration for the node
+
+         Raises:
+             WorkflowValidationError: If node is invalid
+             NodeConfigurationError: If node configuration fails
+         """
+         if node_id in self.nodes:
+             raise WorkflowValidationError(
+                 f"Node '{node_id}' already exists in workflow. "
+                 f"Existing nodes: {list(self.nodes.keys())}"
+             )
+
+         try:
+             # Handle different input types
+             if isinstance(node_or_type, str):
+                 # Node type name provided
+                 node_class = NodeRegistry.get(node_or_type)
+                 node_instance = node_class(id=node_id, **config)
+                 node_type = node_or_type
+             elif isinstance(node_or_type, type) and issubclass(node_or_type, Node):
+                 # Node class provided
+                 node_instance = node_or_type(id=node_id, **config)
+                 node_type = node_or_type.__name__
+             elif isinstance(node_or_type, Node):
+                 # Node instance provided
+                 node_instance = node_or_type
+                 node_instance.id = node_id
+                 node_type = node_instance.__class__.__name__
+                 # Update config - handle nested config case
+                 if "config" in node_instance.config and isinstance(
+                     node_instance.config["config"], dict
+                 ):
+                     # If config is nested, extract it
+                     actual_config = node_instance.config["config"]
+                     node_instance.config.update(actual_config)
+                     # Remove the nested config key
+                     del node_instance.config["config"]
+                 # Now update with provided config
+                 node_instance.config.update(config)
+                 node_instance._validate_config()
+             else:
+                 raise WorkflowValidationError(
+                     f"Invalid node type: {type(node_or_type)}. "
+                     "Expected: str (node type name), Node class, or Node instance"
+                 )
+         except NodeConfigurationError:
+             # Re-raise configuration errors with additional context
+             raise
+         except Exception as e:
+             raise NodeConfigurationError(
+                 f"Failed to create node '{node_id}' of type '{node_or_type}': {e}"
+             ) from e
+
+         # Store node instance and metadata
+         try:
+             node_instance_data = NodeInstance(
+                 node_id=node_id,
+                 node_type=node_type,
+                 config=config,
+                 position=(len(self.nodes) * 150, 100),
+             )
+             self.nodes[node_id] = node_instance_data
+         except ValidationError as e:
+             raise WorkflowValidationError(f"Invalid node instance data: {e}") from e
+
+         self._node_instances[node_id] = node_instance
+
+         # Add to graph
+         self.graph.add_node(node_id, node=node_instance, type=node_type, config=config)
+         logger.info(f"Added node '{node_id}' of type '{node_type}'")
+
+     def _add_node_internal(
+         self, node_id: str, node_type: str, config: Optional[Dict[str, Any]] = None
+     ) -> None:
+         """Add a node to the workflow (internal method).
+
+         Args:
+             node_id: Node identifier
+             node_type: Node type name
+             config: Node configuration
+         """
+         # This method is used by WorkflowBuilder and from_dict
+         config = config or {}
+         self.add_node(node_id=node_id, node_or_type=node_type, **config)
+
+     def connect(
+         self,
+         source_node: str,
+         target_node: str,
+         mapping: Optional[Dict[str, str]] = None,
+     ) -> None:
+         """Connect two nodes in the workflow.
+
+         Args:
+             source_node: Source node ID
+             target_node: Target node ID
+             mapping: Dict mapping source outputs to target inputs
+
+         Raises:
+             ConnectionError: If connection is invalid
+             WorkflowValidationError: If nodes don't exist
+         """
+         if source_node not in self.nodes:
+             available_nodes = ", ".join(self.nodes.keys())
+             raise WorkflowValidationError(
+                 f"Source node '{source_node}' not found in workflow. "
+                 f"Available nodes: {available_nodes}"
+             )
+         if target_node not in self.nodes:
+             available_nodes = ", ".join(self.nodes.keys())
+             raise WorkflowValidationError(
+                 f"Target node '{target_node}' not found in workflow. "
+                 f"Available nodes: {available_nodes}"
+             )
+
+         # Self-connection check
+         if source_node == target_node:
+             raise ConnectionError(f"Cannot connect node '{source_node}' to itself")
+
+         # Default mapping if not provided
+         if mapping is None:
+             mapping = {"output": "input"}
+
+         # Check for existing connections
+         existing_connections = [
+             c
+             for c in self.connections
+             if c.source_node == source_node and c.target_node == target_node
+         ]
+         if existing_connections:
+             raise ConnectionError(
+                 f"Connection already exists between '{source_node}' and '{target_node}'. "
+                 f"Existing mappings: {[c.model_dump() for c in existing_connections]}"
+             )
+
+         # Create connections
+         for source_output, target_input in mapping.items():
+             try:
+                 connection = Connection(
+                     source_node=source_node,
+                     source_output=source_output,
+                     target_node=target_node,
+                     target_input=target_input,
+                 )
+             except ValidationError as e:
+                 raise ConnectionError(f"Invalid connection data: {e}") from e
+
+             self.connections.append(connection)
+
+             # Add edge to graph
+             self.graph.add_edge(
+                 source_node,
+                 target_node,
+                 from_output=source_output,
+                 to_input=target_input,
+                 mapping={
+                     source_output: target_input
+                 },  # Keep for backward compatibility
+             )
+
+         logger.info(
+             f"Connected '{source_node}' to '{target_node}' with mapping: {mapping}"
+         )
+
+     def _add_edge_internal(
+         self, from_node: str, from_output: str, to_node: str, to_input: str
+     ) -> None:
+         """Add an edge between nodes (internal method).
+
+         Args:
+             from_node: Source node ID
+             from_output: Output field from source
+             to_node: Target node ID
+             to_input: Input field on target
+         """
+         # This method is used by WorkflowBuilder and from_dict
+         self.connect(
+             source_node=from_node, target_node=to_node, mapping={from_output: to_input}
+         )
+
+     def get_node(self, node_id: str) -> Optional[Node]:
+         """Get node instance by ID.
+
+         Args:
+             node_id: Node identifier
+
+         Returns:
+             Node instance or None if not found
+         """
+         if node_id not in self.graph.nodes:
+             return None
+
+         # First try to get from graph (for test compatibility)
+         graph_node = self.graph.nodes[node_id].get("node")
+         if graph_node:
+             return graph_node
+
+         # Fallback to _node_instances
+         return self._node_instances.get(node_id)
+
+     def get_execution_order(self) -> List[str]:
+         """Get topological execution order for nodes.
+
+         Returns:
+             List of node IDs in execution order
+
+         Raises:
+             WorkflowValidationError: If workflow contains cycles
+         """
+         try:
+             return list(nx.topological_sort(self.graph))
+         except nx.NetworkXUnfeasible:
+             cycles = list(nx.simple_cycles(self.graph))
+             raise WorkflowValidationError(
+                 f"Workflow contains cycles: {cycles}. "
+                 "Remove circular dependencies to create a valid workflow."
+             )
+
+     def validate(self) -> None:
+         """Validate the workflow structure.
+
+         Raises:
+             WorkflowValidationError: If workflow is invalid
+         """
+         # Check for cycles
+         try:
+             self.get_execution_order()
+         except WorkflowValidationError:
+             raise
+
+         # Check all nodes have required inputs
+         for node_id, node_instance in self._node_instances.items():
+             try:
+                 params = node_instance.get_parameters()
+             except Exception as e:
+                 raise WorkflowValidationError(
+                     f"Failed to get parameters for node '{node_id}': {e}"
+                 ) from e
+
+             # Get inputs from connections
+             incoming_edges = self.graph.in_edges(node_id, data=True)
+             connected_inputs = set()
+
+             for _, _, data in incoming_edges:
+                 to_input = data.get("to_input")
+                 if to_input:
+                     connected_inputs.add(to_input)
+                 # For backward compatibility
+                 mapping = data.get("mapping", {})
+                 connected_inputs.update(mapping.values())
+
+             # Check required parameters
+             missing_inputs = []
+             for param_name, param_def in params.items():
+                 if param_def.required and param_name not in connected_inputs:
+                     # Check if it's provided in config
+                     # Handle nested config case (for PythonCodeNode and similar)
+                     found_in_config = param_name in node_instance.config
+                     if not found_in_config and "config" in node_instance.config:
+                         # Check nested config
+                         found_in_config = param_name in node_instance.config["config"]
+
+                     if not found_in_config:
+                         if param_def.default is None:
+                             missing_inputs.append(param_name)
+
+             if missing_inputs:
+                 raise WorkflowValidationError(
+                     f"Node '{node_id}' missing required inputs: {missing_inputs}. "
+                     f"Provide these inputs via connections or node configuration"
+                 )
+
+         logger.info(f"Workflow '{self.name}' validated successfully")
+
+     def run(
+         self, task_manager: Optional[TaskManager] = None, **overrides
+     ) -> Tuple[Dict[str, Any], Optional[str]]:
+         """Execute the workflow.
+
+         Args:
+             task_manager: Optional task manager for tracking
+             **overrides: Parameter overrides
+
+         Returns:
+             Tuple of (results dict, run_id)
+
+         Raises:
+             WorkflowExecutionError: If workflow execution fails
+             WorkflowValidationError: If workflow is invalid
+         """
+         # For backward compatibility with original graph.py's run method
+         return self.execute(inputs=overrides, task_manager=task_manager), None
+
+     def execute(
+         self,
+         inputs: Optional[Dict[str, Any]] = None,
+         task_manager: Optional[TaskManager] = None,
+     ) -> Dict[str, Any]:
+         """Execute the workflow.
+
+         Args:
+             inputs: Input data for the workflow (can include node overrides)
+             task_manager: Optional task manager for tracking
+
+         Returns:
+             Execution results by node
+
+         Raises:
+             WorkflowExecutionError: If execution fails
+         """
+         try:
+             self.validate()
+         except Exception as e:
+             raise WorkflowValidationError(f"Workflow validation failed: {e}") from e
+
+         # Initialize task tracking
+         run_id = None
+         if task_manager:
+             try:
+                 run_id = task_manager.create_run(
+                     workflow_name=self.name, metadata={"inputs": inputs}
+                 )
+             except Exception as e:
+                 logger.warning(f"Failed to create task run: {e}")
+                 # Continue without task tracking
+
+         # Get execution order
+         try:
+             execution_order = self.get_execution_order()
+         except Exception as e:
+             raise WorkflowExecutionError(
+                 f"Failed to determine execution order: {e}"
+             ) from e
+
+         # Execute nodes in order
+         results = {}
+         inputs = inputs or {}
+         failed_nodes = []
+
+         for node_id in execution_order:
+             node_instance = self._node_instances[node_id]
+
+             # Start task tracking
+             task = None
+             if task_manager and run_id:
+                 try:
+                     task = task_manager.create_task(
+                         run_id=run_id,
+                         node_id=node_id,
+                         node_type=node_instance.__class__.__name__,
+                     )
+                     task.update_status(TaskStatus.RUNNING)
+                 except Exception as e:
+                     logger.warning(f"Failed to create task for node '{node_id}': {e}")
+
+             try:
+                 # Gather inputs from previous nodes
+                 node_inputs = {}
+
+                 # Add config values
+                 node_inputs.update(node_instance.config)
+
+                 # Get inputs from connected nodes
+                 for edge in self.graph.in_edges(node_id, data=True):
+                     source_node_id = edge[0]
+                     edge_data = self.graph[source_node_id][node_id]
+
+                     # Try both connection formats for backward compatibility
+                     from_output = edge_data.get("from_output")
+                     to_input = edge_data.get("to_input")
+                     mapping = edge_data.get("mapping", {})
+
+                     source_results = results.get(source_node_id, {})
+
+                     # Add connections using from_output/to_input format
+                     if from_output and to_input and from_output in source_results:
+                         node_inputs[to_input] = source_results[from_output]
+
+                     # Also add connections using mapping format for backward compatibility
+                     for source_key, target_key in mapping.items():
+                         if source_key in source_results:
+                             node_inputs[target_key] = source_results[source_key]
+
+                 # Apply overrides
+                 node_overrides = inputs.get(node_id, {})
+                 node_inputs.update(node_overrides)
+
+                 # Execute node
+                 logger.info(
+                     f"Executing node '{node_id}' with inputs: {list(node_inputs.keys())}"
+                 )
+
+                 # Support both process() and execute() methods
+                 if hasattr(node_instance, "process") and callable(
+                     node_instance.process
+                 ):
+                     node_results = node_instance.process(node_inputs)
+                 else:
+                     node_results = node_instance.execute(**node_inputs)
+
+                 results[node_id] = node_results
+
+                 if task:
+                     task.update_status(TaskStatus.COMPLETED, result=node_results)
+
+                 logger.info(f"Node '{node_id}' completed successfully")
+
+             except Exception as e:
+                 failed_nodes.append(node_id)
+                 if task:
+                     task.update_status(TaskStatus.FAILED, error=str(e))
+
+                 # Include previous failures in error message
+                 error_msg = f"Node '{node_id}' failed: {e}"
+                 if len(failed_nodes) > 1:
+                     error_msg += f" (Previously failed nodes: {failed_nodes[:-1]})"
+
+                 raise WorkflowExecutionError(error_msg) from e
+
+         logger.info(
+             f"Workflow '{self.name}' completed successfully. "
+             f"Executed {len(execution_order)} nodes"
+         )
+         return results
+
+     def export_to_kailash(
+         self, output_path: str, format: str = "yaml", **config
+     ) -> None:
+         """Export workflow to Kailash-compatible format.
+
+         Args:
+             output_path: Path to write file
+             format: Export format (yaml, json, manifest)
+             **config: Additional export configuration
+
+         Raises:
+             ExportException: If export fails
+         """
+         try:
+             from kailash.utils.export import export_workflow
+
+             export_workflow(self, format=format, output_path=output_path, **config)
+         except ImportError as e:
+             raise ExportException(f"Failed to import export utilities: {e}") from e
+         except Exception as e:
+             raise ExportException(
+                 f"Failed to export workflow to '{output_path}': {e}"
+             ) from e
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert workflow to dictionary.
+
+         Returns:
+             Dictionary representation
+         """
+         # Build nodes dictionary
+         nodes_dict = {}
+         for node_id, node_data in self.nodes.items():
+             nodes_dict[node_id] = node_data.model_dump()
+
+         # Build connections list
+         connections_list = [conn.model_dump() for conn in self.connections]
+
+         # Build workflow dictionary
+         return {
+             "workflow_id": self.workflow_id,
+             "name": self.name,
+             "description": self.description,
+             "version": self.version,
+             "author": self.author,
+             "metadata": self.metadata,
+             "nodes": nodes_dict,
+             "connections": connections_list,
+         }
+
+     def to_json(self) -> str:
+         """Convert workflow to JSON string.
+
+         Returns:
+             JSON representation
+         """
+         return json.dumps(self.to_dict(), indent=2)
+
+     def to_yaml(self) -> str:
+         """Convert workflow to YAML string.
+
+         Returns:
+             YAML representation
+         """
+         return yaml.dump(self.to_dict(), default_flow_style=False)
+
+     def save(self, path: str, format: str = "json") -> None:
+         """Save workflow to file.
+
+         Args:
+             path: Output file path
+             format: Output format (json or yaml)
+
+         Raises:
+             ValueError: If format is invalid
+         """
+         if format == "json":
+             with open(path, "w") as f:
+                 f.write(self.to_json())
+         elif format == "yaml":
+             with open(path, "w") as f:
+                 f.write(self.to_yaml())
+         else:
+             raise ValueError(f"Unsupported format: {format}")
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> "Workflow":
+         """Create workflow from dictionary.
+
+         Args:
+             data: Dictionary representation
+
+         Returns:
+             Workflow instance
+
+         Raises:
+             WorkflowValidationError: If data is invalid
+         """
+         try:
+             # Extract basic data
+             workflow_id = data.get("workflow_id", str(uuid.uuid4()))
+             name = data.get("name", "Unnamed Workflow")
+             description = data.get("description", "")
+             version = data.get("version", "1.0.0")
+             author = data.get("author", "")
+             metadata = data.get("metadata", {})
+
+             # Create workflow
+             workflow = cls(
+                 workflow_id=workflow_id,
+                 name=name,
+                 description=description,
+                 version=version,
+                 author=author,
+                 metadata=metadata,
+             )
+
+             # Add nodes
+             nodes_data = data.get("nodes", {})
+             for node_id, node_data in nodes_data.items():
+                 # Handle both formats of node data
+                 if isinstance(node_data, dict):
+                     # Get node type
+                     node_type = node_data.get("node_type") or node_data.get("type")
+                     if not node_type:
+                         raise WorkflowValidationError(
+                             f"Node type not specified for node '{node_id}'"
+                         )
+
+                     # Get node config
+                     config = node_data.get("config", {})
+
+                     # Add the node
+                     workflow._add_node_internal(node_id, node_type, config)
+                 else:
+                     raise WorkflowValidationError(
+                         f"Invalid node data format for node '{node_id}': {type(node_data)}"
+                     )
+
+             # Add connections
+             connections = data.get("connections", [])
+             for conn_data in connections:
+                 # Handle both connection formats
+                 if "source_node" in conn_data and "target_node" in conn_data:
+                     # Original format
+                     source_node = conn_data.get("source_node")
+                     source_output = conn_data.get("source_output")
+                     target_node = conn_data.get("target_node")
+                     target_input = conn_data.get("target_input")
+                     workflow._add_edge_internal(
+                         source_node, source_output, target_node, target_input
+                     )
+                 elif "from_node" in conn_data and "to_node" in conn_data:
+                     # Updated format
+                     from_node = conn_data.get("from_node")
+                     from_output = conn_data.get("from_output", "output")
+                     to_node = conn_data.get("to_node")
+                     to_input = conn_data.get("to_input", "input")
+                     workflow._add_edge_internal(
+                         from_node, from_output, to_node, to_input
+                     )
+                 else:
+                     raise WorkflowValidationError(
+                         f"Invalid connection data: {conn_data}"
+                     )
+
+             return workflow
+
+         except Exception as e:
+             if isinstance(e, WorkflowValidationError):
+                 raise
+             raise WorkflowValidationError(
+                 f"Failed to create workflow from dict: {e}"
+             ) from e
+
+     def __repr__(self) -> str:
+         """Get string representation."""
+         return f"Workflow(id='{self.workflow_id}', name='{self.name}', nodes={len(self.graph.nodes)}, connections={len(self.graph.edges)})"
+
+     def __str__(self) -> str:
+         """Get readable string."""
+         return f"Workflow '{self.name}' (ID: {self.workflow_id}) with {len(self.graph.nodes)} nodes and {len(self.graph.edges)} connections"
+
+     def create_state_wrapper(self, state_model: BaseModel) -> WorkflowStateWrapper:
+         """Create a state manager wrapper for a workflow.
+
+         This wrapper provides convenient methods for updating state immutably,
+         making it easier to manage state in workflow nodes.
+
+         Args:
+             state_model: The Pydantic model state object to wrap
+
+         Returns:
+             A WorkflowStateWrapper instance
+
+         Raises:
+             TypeError: If state_model is not a Pydantic BaseModel
+         """
+         if not isinstance(state_model, BaseModel):
+             raise TypeError(f"Expected BaseModel, got {type(state_model)}")
+
+         return WorkflowStateWrapper(state_model)
+
+     def execute_with_state(
+         self,
+         state_model: BaseModel,
+         wrap_state: bool = True,
+         task_manager: Optional[TaskManager] = None,
+         **overrides,
+     ) -> Tuple[BaseModel, Dict[str, Any]]:
+         """Execute the workflow with state management.
+
+         This method provides a simplified interface for executing workflows
+         with automatic state management, making it easier to manage state
+         transitions.
+
+         Args:
+             state_model: The initial state for workflow execution
+             wrap_state: Whether to wrap state in WorkflowStateWrapper
+             task_manager: Optional task manager for tracking
+             **overrides: Additional parameter overrides
+
+         Returns:
+             Tuple of (final state, all results)
+
+         Raises:
+             WorkflowExecutionError: If execution fails
+             WorkflowValidationError: If workflow is invalid
+         """
+         # Validate input
+         if not isinstance(state_model, BaseModel):
+             raise TypeError(f"Expected BaseModel, got {type(state_model)}")
+
+         # Prepare inputs
+         inputs = {}
+
+         # Wrap the state if needed
+         if wrap_state:
+             state_wrapper = self.create_state_wrapper(state_model)
+             # Find entry nodes (nodes with no incoming edges) and provide state_wrapper to them
+             for node_id in self.nodes:
+                 if self.graph.in_degree(node_id) == 0:  # Entry node
+                     inputs[node_id] = {"state_wrapper": state_wrapper}
+         else:
+             # Find entry nodes and provide unwrapped state to them
+             for node_id in self.nodes:
+                 if self.graph.in_degree(node_id) == 0:  # Entry node
+                     inputs[node_id] = {"state": state_model}
+
+         # Add any additional overrides
+         for key, value in overrides.items():
+             if key in self.nodes:
+                 inputs.setdefault(key, {}).update(value)
+
+         # Execute the workflow
+         results = self.execute(inputs=inputs, task_manager=task_manager)
+
+         # Find the final state
+         # First try to find state_wrapper in the last node's outputs
+         execution_order = self.get_execution_order()
+         if execution_order:
+             last_node_id = execution_order[-1]
+             last_node_results = results.get(last_node_id, {})
+
+             if wrap_state:
+                 final_state_wrapper = last_node_results.get("state_wrapper")
+                 if final_state_wrapper and isinstance(
+                     final_state_wrapper, WorkflowStateWrapper
+                 ):
+                     return final_state_wrapper.get_state(), results
+
+                 # Try to find another key with a WorkflowStateWrapper
+                 for key, value in last_node_results.items():
+                     if isinstance(value, WorkflowStateWrapper):
+                         return value.get_state(), results
+             else:
+                 final_state = last_node_results.get("state")
+                 if final_state and isinstance(final_state, BaseModel):
+                     return final_state, results
+
+                 # Try to find another key with a BaseModel
+                 for key, value in last_node_results.items():
+                     if isinstance(value, BaseModel) and type(value) == type(
+                         state_model
+                     ):
+                         return value, results
+
+         # Fallback to original state
+         logger.warning(
+             "Failed to find final state in workflow results, returning original state"
+         )
+         return state_model, results
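
A minimal usage sketch of the Workflow API added in this file. The node type
names "CSVReader" and "Filter" and their parameters are hypothetical
placeholders, not part of this package; any node types registered with
NodeRegistry would be used the same way:

    from kailash.workflow.graph import Workflow

    # Build a two-node pipeline; string type names are resolved via NodeRegistry.
    wf = Workflow(workflow_id="demo-001", name="demo")
    wf.add_node("reader", "CSVReader", file_path="input.csv")  # hypothetical type
    wf.add_node("filter", "Filter", column="status")           # hypothetical type

    # Route the reader's "output" field into the filter's "data" input.
    wf.connect("reader", "filter", mapping={"output": "data"})

    wf.validate()           # raises WorkflowValidationError on cycles/missing inputs
    results = wf.execute()  # per-node result dicts keyed by node_id

    # Per-node parameter overrides are passed keyed by node_id:
    results = wf.execute(inputs={"reader": {"file_path": "other.csv"}})

    # Round-trip through the serialized form used by save()/from_dict():
    clone = Workflow.from_dict(wf.to_dict())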