kailash 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. kailash/__init__.py +31 -0
  2. kailash/__main__.py +11 -0
  3. kailash/cli/__init__.py +5 -0
  4. kailash/cli/commands.py +563 -0
  5. kailash/manifest.py +778 -0
  6. kailash/nodes/__init__.py +23 -0
  7. kailash/nodes/ai/__init__.py +26 -0
  8. kailash/nodes/ai/agents.py +417 -0
  9. kailash/nodes/ai/models.py +488 -0
  10. kailash/nodes/api/__init__.py +52 -0
  11. kailash/nodes/api/auth.py +567 -0
  12. kailash/nodes/api/graphql.py +480 -0
  13. kailash/nodes/api/http.py +598 -0
  14. kailash/nodes/api/rate_limiting.py +572 -0
  15. kailash/nodes/api/rest.py +665 -0
  16. kailash/nodes/base.py +1032 -0
  17. kailash/nodes/base_async.py +128 -0
  18. kailash/nodes/code/__init__.py +32 -0
  19. kailash/nodes/code/python.py +1021 -0
  20. kailash/nodes/data/__init__.py +125 -0
  21. kailash/nodes/data/readers.py +496 -0
  22. kailash/nodes/data/sharepoint_graph.py +623 -0
  23. kailash/nodes/data/sql.py +380 -0
  24. kailash/nodes/data/streaming.py +1168 -0
  25. kailash/nodes/data/vector_db.py +964 -0
  26. kailash/nodes/data/writers.py +529 -0
  27. kailash/nodes/logic/__init__.py +6 -0
  28. kailash/nodes/logic/async_operations.py +702 -0
  29. kailash/nodes/logic/operations.py +551 -0
  30. kailash/nodes/transform/__init__.py +5 -0
  31. kailash/nodes/transform/processors.py +379 -0
  32. kailash/runtime/__init__.py +6 -0
  33. kailash/runtime/async_local.py +356 -0
  34. kailash/runtime/docker.py +697 -0
  35. kailash/runtime/local.py +434 -0
  36. kailash/runtime/parallel.py +557 -0
  37. kailash/runtime/runner.py +110 -0
  38. kailash/runtime/testing.py +347 -0
  39. kailash/sdk_exceptions.py +307 -0
  40. kailash/tracking/__init__.py +7 -0
  41. kailash/tracking/manager.py +885 -0
  42. kailash/tracking/metrics_collector.py +342 -0
  43. kailash/tracking/models.py +535 -0
  44. kailash/tracking/storage/__init__.py +0 -0
  45. kailash/tracking/storage/base.py +113 -0
  46. kailash/tracking/storage/database.py +619 -0
  47. kailash/tracking/storage/filesystem.py +543 -0
  48. kailash/utils/__init__.py +0 -0
  49. kailash/utils/export.py +924 -0
  50. kailash/utils/templates.py +680 -0
  51. kailash/visualization/__init__.py +62 -0
  52. kailash/visualization/api.py +732 -0
  53. kailash/visualization/dashboard.py +951 -0
  54. kailash/visualization/performance.py +808 -0
  55. kailash/visualization/reports.py +1471 -0
  56. kailash/workflow/__init__.py +15 -0
  57. kailash/workflow/builder.py +245 -0
  58. kailash/workflow/graph.py +827 -0
  59. kailash/workflow/mermaid_visualizer.py +628 -0
  60. kailash/workflow/mock_registry.py +63 -0
  61. kailash/workflow/runner.py +302 -0
  62. kailash/workflow/state.py +238 -0
  63. kailash/workflow/visualization.py +588 -0
  64. kailash-0.1.0.dist-info/METADATA +710 -0
  65. kailash-0.1.0.dist-info/RECORD +69 -0
  66. kailash-0.1.0.dist-info/WHEEL +5 -0
  67. kailash-0.1.0.dist-info/entry_points.txt +2 -0
  68. kailash-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. kailash-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,628 @@
1
+ """Mermaid diagram visualization for workflows.
2
+
3
+ This module provides Mermaid diagram generation for workflow visualization,
4
+ offering a text-based format that can be embedded in markdown files and
5
+ rendered in various documentation platforms.
6
+ """
7
+
8
+ from typing import Dict, Optional, Tuple
9
+
10
+ from kailash.workflow.graph import Workflow
11
+
12
+
13
class MermaidVisualizer:
    """Generate Mermaid diagrams for workflow visualization.

    This class converts Kailash workflows into Mermaid ``flowchart`` syntax,
    which can be embedded in markdown files for documentation and
    visualization.

    Attributes:
        workflow: The workflow to visualize
        node_styles: Style strings keyed by node category ("reader", "writer",
            "transform", "logic", "ai", "api", "code", "default"). Caller
            overrides are merged on top of the defaults.
        direction: Graph direction (TB, LR, etc.)
    """

    # Emission order of node-definition sections in generate():
    # (category key, section comment, opening bracket, closing bracket).
    _NODE_SECTIONS = (
        ("readers", "Data Input nodes", "[", "]"),
        ("validators", "Validation nodes", "{", "}"),
        ("processors", "Processing nodes", "[", "]"),
        ("routers", "Routing/Decision nodes", "{", "}"),
        ("mergers", "Merge nodes", "((", "))"),
        ("writers", "Data Output nodes", "[", "]"),
    )

    # Style shared by the synthetic input_data / output_data terminal nodes.
    _TERMINAL_STYLE = (
        "fill:#e3f2fd,stroke:#1565c0,stroke-width:2px,stroke-dasharray: 5 5"
    )

    # Characters allowed verbatim in a Mermaid node ID; all others become "_".
    _ID_CHARS = frozenset(
        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"
    )

    def __init__(
        self,
        workflow: "Workflow",
        direction: str = "TB",
        node_styles: Optional[Dict[str, str]] = None,
    ):
        """Initialize the Mermaid visualizer.

        Args:
            workflow: The workflow to visualize
            direction: Graph direction (TB=top-bottom, LR=left-right, etc.)
            node_styles: Custom node styles mapping node categories to Mermaid
                style strings. Categories that are not given fall back to the
                defaults. When a non-empty mapping is supplied, generate()
                styles nodes with it instead of its built-in pattern palette.
        """
        self.workflow = workflow
        self.direction = direction
        # Merge overrides onto the defaults so a partial mapping cannot raise
        # KeyError in _get_node_style().
        self.node_styles = {**self._default_node_styles(), **(node_styles or {})}
        # Remember whether the caller customised styles so generate() honours
        # them; otherwise generate() keeps using _get_pattern_style().
        self._use_custom_styles = bool(node_styles)

    def _default_node_styles(self) -> Dict[str, str]:
        """Get default node styles for different node types.

        Returns:
            Dict mapping node category names to Mermaid style strings
        """
        return {
            "reader": "fill:#e1f5fe,stroke:#01579b,stroke-width:2px",
            "writer": "fill:#f3e5f5,stroke:#4a148c,stroke-width:2px",
            "transform": "fill:#fff3e0,stroke:#e65100,stroke-width:2px",
            "logic": "fill:#fce4ec,stroke:#880e4f,stroke-width:2px",
            "ai": "fill:#e8f5e9,stroke:#1b5e20,stroke-width:2px",
            "api": "fill:#f3e5f5,stroke:#4527a0,stroke-width:2px",
            "code": "fill:#fffde7,stroke:#f57f17,stroke-width:2px",
            "default": "fill:#f5f5f5,stroke:#424242,stroke-width:2px",
        }

    @staticmethod
    def _escape_label(text: str, escape_pipes: bool = False) -> str:
        """Escape characters that would terminate a Mermaid label early.

        Args:
            text: Raw label text
            escape_pipes: Also escape ``|``, which delimits edge labels in
                ``A -->|label| B`` syntax

        Returns:
            Label text using Mermaid entity codes for the escaped characters
        """
        escaped = str(text).replace('"', "#quot;")
        if escape_pipes:
            escaped = escaped.replace("|", "#124;")
        return escaped

    @staticmethod
    def _escape_table_cell(value) -> str:
        """Escape ``|`` so a value cannot split a markdown table cell."""
        return str(value).replace("|", "\\|")

    def _get_pattern_label(self, node_id: str, node_instance) -> str:
        """Get a pattern-oriented label for a node.

        Args:
            node_id: The node ID
            node_instance: The node instance (must expose ``node_type``)

        Returns:
            The node's ``name`` if it has a truthy one, otherwise the node type
            (without a trailing "Node") and the ID separated by ``<br/>``
        """
        node_type = node_instance.node_type

        # Try to get a meaningful name from the node
        node = self.workflow.get_node(node_id)
        if node and hasattr(node, "name") and node.name:
            return node.name

        # Otherwise use the node type with ID
        clean_type = self._get_node_type_label(node_type)
        # Use line break without parentheses to avoid Mermaid parsing issues
        return f"{clean_type}<br/>{node_id}"

    def _get_pattern_edge_label(self, source: str, target: str, data: Dict) -> str:
        """Get a pattern-oriented edge label.

        Args:
            source: Source node ID
            target: Target node ID
            data: Edge data

        Returns:
            Pattern-oriented edge label (unescaped; may contain ``|``)
        """
        # Get basic edge label
        basic_label = self._get_edge_label(source, target, data)

        # Check if this is a validation or error path
        source_node = self.workflow.nodes.get(source)
        target_node = self.workflow.nodes.get(target)

        if source_node and target_node:
            source_type = source_node.node_type.lower()
            target_type = target_node.node_type.lower()

            # Check for validation patterns
            if "valid" in source_type or "check" in source_type:
                if "error" in target_type or "fail" in target_type:
                    return "Invalid"
                elif basic_label:
                    return f"Valid|{basic_label}"
                else:
                    return "Valid"

            # Check for switch/router patterns
            if "switch" in source_type or "router" in source_type:
                if basic_label and "case_" in basic_label:
                    case_name = basic_label.replace("case_", "").split("→")[0]
                    return case_name.title()

        return basic_label

    def _get_pattern_style(self, node_type: str) -> str:
        """Get pattern-oriented styling for a node type.

        Matching is by case-insensitive substring, checked in the order below;
        the first match wins.

        Args:
            node_type: The node type

        Returns:
            Style string for the node
        """
        node_type_lower = node_type.lower()

        # Data I/O nodes
        if "reader" in node_type_lower:
            return "fill:#e1f5fe,stroke:#01579b,stroke-width:2px"
        elif "writer" in node_type_lower:
            return "fill:#f3e5f5,stroke:#4a148c,stroke-width:2px"

        # Validation nodes
        elif any(x in node_type_lower for x in ["valid", "check", "verify"]):
            return "fill:#fff3e0,stroke:#ff6f00,stroke-width:2px"

        # Error handling nodes
        elif any(x in node_type_lower for x in ["error", "fail", "exception"]):
            return "fill:#ffebee,stroke:#c62828,stroke-width:2px"

        # Logic nodes
        elif any(x in node_type_lower for x in ["switch", "router", "conditional"]):
            return "fill:#fce4ec,stroke:#880e4f,stroke-width:2px"
        elif "merge" in node_type_lower:
            return "fill:#f3e5f5,stroke:#4a148c,stroke-width:2px"

        # Processing nodes
        elif any(
            x in node_type_lower
            for x in ["transform", "filter", "process", "aggregate"]
        ):
            return "fill:#fff3e0,stroke:#e65100,stroke-width:2px"

        # Code execution nodes
        elif "python" in node_type_lower or "code" in node_type_lower:
            return "fill:#fffde7,stroke:#f57f17,stroke-width:2px"

        # AI/ML nodes
        elif any(x in node_type_lower for x in ["ai", "ml", "model", "embedding"]):
            return "fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px"

        # API nodes
        elif any(x in node_type_lower for x in ["api", "http", "rest", "graphql"]):
            return "fill:#e8eaf6,stroke:#283593,stroke-width:2px"

        # Default
        else:
            return "fill:#f5f5f5,stroke:#616161,stroke-width:2px"

    def _get_node_style(self, node_type: str) -> str:
        """Get the configured ``node_styles`` entry for a node type.

        Args:
            node_type: The type of the node

        Returns:
            Mermaid style string for the node
        """
        node_type_lower = node_type.lower()

        if "reader" in node_type_lower:
            return self.node_styles["reader"]
        elif "writer" in node_type_lower:
            return self.node_styles["writer"]
        elif any(
            x in node_type_lower
            for x in ["transform", "filter", "processor", "aggregator"]
        ):
            return self.node_styles["transform"]
        elif any(
            x in node_type_lower for x in ["switch", "merge", "conditional", "logic"]
        ):
            return self.node_styles["logic"]
        elif any(x in node_type_lower for x in ["ai", "llm", "model", "embedding"]):
            return self.node_styles["ai"]
        elif any(
            x in node_type_lower for x in ["api", "http", "rest", "graphql", "oauth"]
        ):
            return self.node_styles["api"]
        elif "python" in node_type_lower or "code" in node_type_lower:
            return self.node_styles["code"]
        else:
            return self.node_styles["default"]

    def _sanitize_node_id(self, node_id: str) -> str:
        """Sanitize node ID for Mermaid compatibility.

        Every character other than an ASCII letter, digit or underscore is
        replaced with an underscore. IDs that start with a digit, and the
        reserved lowercase flowchart keyword ``end``, get a ``node_`` prefix.

        Args:
            node_id: Original node ID

        Returns:
            Sanitized node ID safe for Mermaid
        """
        sanitized = "".join(
            ch if ch in self._ID_CHARS else "_" for ch in node_id
        )
        if sanitized and (sanitized[0].isdigit() or sanitized == "end"):
            sanitized = f"node_{sanitized}"
        return sanitized

    def _get_node_label(self, node_id: str) -> str:
        """Get display label for a node.

        Args:
            node_id: The node ID

        Returns:
            Display label for the node
        """
        node = self.workflow.get_node(node_id)
        if node:
            # Use node name if available
            if hasattr(node, "name") and node.name:
                return node.name
            # Fall back to node type
            if hasattr(node, "node_type"):
                return f"{node_id}<br/>({node.node_type})"

        # Last resort: use node instance from workflow
        node_instance = self.workflow.nodes.get(node_id)
        if node_instance:
            return f"{node_id}<br/>({node_instance.node_type})"

        return node_id

    def _get_node_type_label(self, node_type: str) -> str:
        """Get a clean label for a node type.

        Args:
            node_type: The node type string

        Returns:
            The type with a trailing "Node" suffix removed, if present
        """
        if node_type.endswith("Node"):
            return node_type[:-4]
        return node_type

    def _get_node_shape(self, node_type: str) -> Tuple[str, str]:
        """Get the shape brackets for a node type.

        Args:
            node_type: The type of the node

        Returns:
            Tuple of (opening bracket, closing bracket)
        """
        node_type_lower = node_type.lower()

        # Different shapes for different node types
        if "reader" in node_type_lower:
            return "([", "])"  # Stadium shape for inputs
        elif "writer" in node_type_lower:
            return "([", "])"  # Stadium shape for outputs
        elif any(x in node_type_lower for x in ["switch", "conditional"]):
            return "{", "}"  # Rhombus for decisions
        elif any(x in node_type_lower for x in ["merge"]):
            return "((", "))"  # Circle for merge
        else:
            return "[", "]"  # Rectangle for processing

    def _categorize_node(self, node_type: str) -> str:
        """Return the generate() section key for a node type.

        Checks are case-insensitive substring matches; the first match wins.
        """
        node_type_lower = node_type.lower()
        if "reader" in node_type_lower:
            return "readers"
        if "writer" in node_type_lower:
            return "writers"
        if any(x in node_type_lower for x in ("switch", "router", "conditional")):
            return "routers"
        if "merge" in node_type_lower:
            return "mergers"
        if any(x in node_type_lower for x in ("valid", "check", "verify")):
            return "validators"
        return "processors"

    def _find_sources_and_sinks(self) -> Tuple[list, list]:
        """Split graph nodes into entry points and exit points.

        A node with no incoming edges is a source; a node with no outgoing
        edges is a sink. An isolated node is both, so it is connected to the
        Input Data and the Output Data terminals.

        Returns:
            Tuple of (source node IDs, sink node IDs) in graph order
        """
        graph = self.workflow.graph
        sources, sinks = [], []
        for node_id in graph.nodes():
            if graph.in_degree(node_id) == 0:
                sources.append(node_id)
            if graph.out_degree(node_id) == 0:
                sinks.append(node_id)
        return sources, sinks

    def generate(self) -> str:
        """Generate the Mermaid diagram code.

        The diagram contains these sections, in order:
        - an optional ``input_data`` terminal;
        - node definitions grouped by category;
        - an optional ``output_data`` terminal;
        - the flow edges;
        - style declarations.

        Labels are escaped so quotes and pipes cannot break the syntax.

        Returns:
            Complete Mermaid diagram as a string
        """
        graph = self.workflow.graph
        source_nodes, sink_nodes = self._find_sources_and_sinks()

        lines = [f"flowchart {self.direction}", ""]

        # Add input data node if there are sources
        if source_nodes:
            lines.extend([" %% Input Data", " input_data([Input Data])", ""])

        # Group nodes by category for better organization
        grouped: Dict[str, list] = {key: [] for key, _, _, _ in self._NODE_SECTIONS}
        for node_id in graph.nodes():
            node_instance = self.workflow.nodes.get(node_id)
            if node_instance:
                category = self._categorize_node(node_instance.node_type)
                grouped[category].append((node_id, node_instance))

        # Generate node definitions section by section
        for category, comment, opening, closing in self._NODE_SECTIONS:
            members = grouped[category]
            if not members:
                continue
            lines.append(f" %% {comment}")
            for node_id, node_instance in members:
                sanitized_id = self._sanitize_node_id(node_id)
                label = self._escape_label(
                    self._get_pattern_label(node_id, node_instance)
                )
                # Quoted labels tolerate spaces, parentheses and <br/>
                lines.append(f' {sanitized_id}{opening}"{label}"{closing}')
            lines.append("")

        # Add output data node if there are sinks
        if sink_nodes:
            lines.extend([" %% Output Data", " output_data([Output Data])", ""])

        # Generate flow section
        lines.append(" %% Flow")
        for source in source_nodes:
            lines.append(f" input_data --> {self._sanitize_node_id(source)}")

        for source, target, data in graph.edges(data=True):
            source_id = self._sanitize_node_id(source)
            target_id = self._sanitize_node_id(target)
            edge_label = self._get_pattern_edge_label(source, target, data)
            if edge_label:
                # "|" delimits edge text, so it must be escaped inside it
                escaped = self._escape_label(edge_label, escape_pipes=True)
                lines.append(f" {source_id} -->|{escaped}| {target_id}")
            else:
                lines.append(f" {source_id} --> {target_id}")

        for sink in sink_nodes:
            lines.append(f" {self._sanitize_node_id(sink)} --> output_data")

        # Generate styling section
        lines.append("")
        lines.append(" %% Styling")
        if source_nodes:
            lines.append(f" style input_data {self._TERMINAL_STYLE}")
        if sink_nodes:
            lines.append(f" style output_data {self._TERMINAL_STYLE}")

        style_for = (
            self._get_node_style
            if self._use_custom_styles
            else self._get_pattern_style
        )
        for node_id in graph.nodes():
            node_instance = self.workflow.nodes.get(node_id)
            if node_instance:
                sanitized_id = self._sanitize_node_id(node_id)
                style = style_for(node_instance.node_type)
                lines.append(f" style {sanitized_id} {style}")

        return "\n".join(lines)

    def _get_edge_label(self, source: str, target: str, data: Dict) -> str:
        """Get label for an edge.

        Args:
            source: Source node ID
            target: Target node ID
            data: Edge data dictionary

        Returns:
            The edge label:
            - ``"out→in"`` when both ``from_output`` and ``to_input`` are set;
            - ``"src→dst"`` for a single-entry ``mapping``;
            - ``"<n> mappings"`` for a larger ``mapping``;
            - otherwise the empty string.
        """
        # Check for direct output/input mapping
        from_output = data.get("from_output")
        to_input = data.get("to_input")

        if from_output and to_input:
            return f"{from_output}→{to_input}"

        # Check for mapping dictionary
        mapping = data.get("mapping", {})
        if mapping:
            # For single mapping, show inline
            if len(mapping) == 1:
                src, dst = next(iter(mapping.items()))
                return f"{src}→{dst}"
            # For multiple mappings, show count
            return f"{len(mapping)} mappings"

        return ""

    def generate_markdown(self, title: Optional[str] = None) -> str:
        """Generate a complete markdown section with the Mermaid diagram.

        Args:
            title: Optional title for the diagram section; defaults to
                "Workflow: <workflow name>"

        Returns:
            Complete markdown text with embedded Mermaid diagram plus node and
            connection tables (cell contents are pipe-escaped)
        """
        lines = []

        lines.append(f"## {title}" if title else f"## Workflow: {self.workflow.name}")
        lines.append("")

        # Add description if available
        if hasattr(self.workflow, "description") and self.workflow.description:
            lines.append(f"_{self.workflow.description}_")
            lines.append("")

        # Add the Mermaid diagram
        lines.append("```mermaid")
        lines.append(self.generate())
        lines.append("```")
        lines.append("")

        # Add node summary
        lines.append("### Nodes")
        lines.append("")
        lines.append("| Node ID | Type | Description |")
        lines.append("|---------|------|-------------|")

        for node_id in sorted(self.workflow.graph.nodes()):
            node = self.workflow.get_node(node_id)
            node_instance = self.workflow.nodes.get(node_id)
            if not node_instance:
                continue

            description = ""
            doc = getattr(node, "__doc__", None) if node else None
            if doc:
                # Get first line of docstring
                description = doc.strip().split("\n")[0]

            cell_id = self._escape_table_cell(node_id)
            cell_type = self._escape_table_cell(node_instance.node_type)
            cell_desc = self._escape_table_cell(description)
            lines.append(f"| {cell_id} | {cell_type} | {cell_desc} |")

        lines.append("")

        # Add edge summary if there are connections
        edges = list(self.workflow.graph.edges(data=True))
        if edges:
            lines.append("### Connections")
            lines.append("")
            lines.append("| From | To | Mapping |")
            lines.append("|------|-----|---------|")

            for source, target, data in edges:
                edge_label = self._get_edge_label(source, target, data)
                lines.append(
                    f"| {self._escape_table_cell(source)}"
                    f" | {self._escape_table_cell(target)}"
                    f" | {self._escape_table_cell(edge_label)} |"
                )

            lines.append("")

        return "\n".join(lines)

    def save_markdown(self, filepath: str, title: Optional[str] = None) -> None:
        """Save the Mermaid diagram as a UTF-8 markdown file.

        Args:
            filepath: Path to save the markdown file
            title: Optional title for the diagram
        """
        content = self.generate_markdown(title)
        # Explicit UTF-8: labels contain "→", which locale encodings such as
        # cp1252 cannot represent.
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(content)

    def save_mermaid(self, filepath: str) -> None:
        """Save just the Mermaid diagram code as UTF-8 text.

        Args:
            filepath: Path to save the Mermaid file
        """
        content = self.generate()
        # Explicit UTF-8 for the same reason as save_markdown().
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(content)
+
583
+
584
def add_mermaid_to_workflow():
    """Install Mermaid convenience methods on the Workflow class.

    Defines ``to_mermaid``, ``to_mermaid_markdown`` and
    ``save_mermaid_markdown`` and sets each one as an attribute of
    ``Workflow``. Every method builds a new MermaidVisualizer for the workflow
    it is called on.
    """

    def to_mermaid(self, direction: str = "TB") -> str:
        """Generate Mermaid diagram for this workflow.

        Args:
            direction: Graph direction (TB, LR, etc.)

        Returns:
            Mermaid diagram as string
        """
        return MermaidVisualizer(self, direction=direction).generate()

    def to_mermaid_markdown(self, title: Optional[str] = None) -> str:
        """Generate markdown with embedded Mermaid diagram.

        Args:
            title: Optional title for the diagram

        Returns:
            Complete markdown text
        """
        return MermaidVisualizer(self).generate_markdown(title)

    def save_mermaid_markdown(self, filepath: str, title: Optional[str] = None) -> None:
        """Save workflow as markdown with Mermaid diagram.

        Args:
            filepath: Path to save the markdown file
            title: Optional title for the diagram
        """
        MermaidVisualizer(self).save_markdown(filepath, title)

    # Attach each helper under its own function name, in definition order.
    for method in (to_mermaid, to_mermaid_markdown, save_mermaid_markdown):
        setattr(Workflow, method.__name__, method)


# Patch Workflow as a side effect of importing this module.
add_mermaid_to_workflow()
@@ -0,0 +1,63 @@
1
+ """Mock node registry for tests."""
2
+
3
+ from typing import Any, Dict, Type
4
+
5
+ from kailash.nodes.base import Node, NodeRegistry
6
+ from kailash.sdk_exceptions import NodeConfigurationError
7
+
8
+
9
class MockNode(Node):
    """Lightweight stand-in node used by tests.

    It does not call ``Node.__init__``. It only records an identifier, a
    display name and any extra keyword configuration. Execution doubles the
    ``value`` input.
    """

    def __init__(self, node_id: str = None, name: str = None, **kwargs):
        """Record the identifier, display name and remaining keyword config."""
        # The base-class initializer is deliberately bypassed.
        self.node_id = node_id
        # A falsy name (None or "") falls back to the node ID.
        self.name = name if name else node_id
        self.config = dict(kwargs)

    def process(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``{"value": 2 * data["value"]}``, treating a missing value as 0."""
        doubled = data.get("value", 0) * 2
        return {"value": doubled}

    def execute(self, **kwargs) -> Dict[str, Any]:
        """Run process() with the keyword arguments gathered into a dict."""
        return self.process(kwargs)

    def get_parameters(self) -> Dict[str, Any]:
        """Report that the mock declares no parameters."""
        return dict()
29
+
30
+
31
# Register mock nodes with the real registry for tests.
# Every name below resolves to MockNode. Plain assignment overwrites any class
# already registered under the same name.
NODE_TYPES = [
    "MockNode",
    "DataReader",
    "DataWriter",
    "Processor",
    "Merger",
    "DataFilter",
    "AIProcessor",
    "Transformer",
]

for node_type in NODE_TYPES:
    try:
        NodeRegistry._registry[node_type] = MockNode
    except (AttributeError, TypeError):
        # Best-effort registration: skip if the registry's private storage is
        # missing or does not support item assignment. Unlike the previous
        # bare ``except:``, this no longer swallows KeyboardInterrupt,
        # SystemExit or unrelated errors.
        pass
48
+
49
+
50
class MockRegistry:
    """Registry double that maps every name in NODE_TYPES to MockNode."""

    _registry: Dict[str, Type[Node]] = dict.fromkeys(NODE_TYPES, MockNode)

    @classmethod
    def get(cls, node_type: str) -> Type[Node]:
        """Look up a registered node class by its type name.

        Args:
            node_type: Name the class was registered under

        Returns:
            The registered node class

        Raises:
            NodeConfigurationError: If ``node_type`` is not registered; the
                message lists the available names.
        """
        registry = cls._registry
        if node_type in registry:
            return registry[node_type]
        available = list(registry.keys())
        raise NodeConfigurationError(
            f"Node '{node_type}' not found in registry. "
            f"Available nodes: {available}"
        )