kailash 0.9.1__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,548 @@
+ """
+ Hierarchical Switch Executor for LocalRuntime.
+
+ This module implements hierarchical switch execution to optimize conditional workflow
+ execution by respecting switch dependencies and executing them in layers.
+ """
+
+ import asyncio
+ import logging
+ from typing import Any, Dict, List, Optional, Set, Tuple
+
+ from kailash.analysis import ConditionalBranchAnalyzer
+ from kailash.tracking import TaskManager
+ from kailash.workflow.graph import Workflow
+
+ logger = logging.getLogger(__name__)
+
+
+ class HierarchicalSwitchExecutor:
+     """
+     Executes switches in hierarchical layers to optimize conditional execution.
+
+     This executor analyzes switch dependencies and executes them in layers where
+     switches in the same layer can be executed in parallel, and each layer depends
+     on the results of the previous layer.
+     """
+
+     def __init__(
+         self,
+         workflow: Workflow,
+         debug: bool = False,
+         max_parallelism: int = 10,
+         layer_timeout: Optional[float] = None,
+     ):
+         """
+         Initialize the hierarchical switch executor.
+
+         Args:
+             workflow: The workflow containing switches to execute
+             debug: Enable debug logging
+             max_parallelism: Maximum concurrent switches per layer
+             layer_timeout: Timeout in seconds for each layer execution
+         """
+         self.workflow = workflow
+         self.debug = debug
+         self.analyzer = ConditionalBranchAnalyzer(workflow)
+         self.max_parallelism = max_parallelism
+         self.layer_timeout = layer_timeout
+         self._execution_metrics = {
+             "layer_timings": [],
+             "parallelism_achieved": [],
+             "errors_by_layer": [],
+         }
+
+     async def execute_switches_hierarchically(
+         self,
+         parameters: Dict[str, Any],
+         task_manager: Optional[TaskManager] = None,
+         run_id: str = "",
+         workflow_context: Optional[Dict[str, Any]] = None,
+         node_executor=None,  # Function to execute individual nodes
+     ) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Any]]:
+         """
+         Execute switches in hierarchical layers.
+
+         Args:
+             parameters: Node-specific parameters
+             task_manager: Task manager for execution tracking
+             run_id: Unique run identifier
+             workflow_context: Workflow execution context
+             node_executor: Function to execute individual nodes
+
+         Returns:
+             Tuple of (all_results, switch_results) where:
+             - all_results: All node execution results including dependencies
+             - switch_results: Just the switch node results
+         """
+         if workflow_context is None:
+             workflow_context = {}
+
+         all_results = {}
+         switch_results = {}
+
+         # Find all switch nodes
+         switch_node_ids = self.analyzer._find_switch_nodes()
+         if not switch_node_ids:
+             logger.info("No switch nodes found in workflow")
+             return all_results, switch_results
+
+         # Analyze switch hierarchies
+         hierarchy_info = self.analyzer.analyze_switch_hierarchies(switch_node_ids)
+         execution_layers = hierarchy_info.get("execution_layers", [])
+
+         if not execution_layers:
+             # Fallback to simple execution if no layers detected
+             logger.warning(
+                 "No execution layers detected, falling back to simple execution"
+             )
+             execution_layers = [switch_node_ids]
+
+         logger.info(
+             f"Executing switches in {len(execution_layers)} hierarchical layers"
+         )
+
+         # Execute each layer
+         for layer_index, layer_switches in enumerate(execution_layers):
+             layer_start_time = asyncio.get_event_loop().time()
+             layer_errors = []
+
+             logger.info(
+                 f"Executing layer {layer_index + 1}/{len(execution_layers)} with {len(layer_switches)} switches"
+             )
+
+             # First, execute dependencies for switches in this layer
+             layer_dependencies = self._get_layer_dependencies(
+                 layer_switches, all_results.keys()
+             )
+
+             if layer_dependencies:
+                 logger.debug(
+                     f"Executing {len(layer_dependencies)} dependencies for layer {layer_index + 1}"
+                 )
+                 # Execute dependencies sequentially (could be optimized for parallel execution)
+                 for dep_node_id in layer_dependencies:
+                     if dep_node_id not in all_results:
+                         result = await self._execute_node_with_dependencies(
+                             dep_node_id,
+                             all_results,
+                             parameters,
+                             task_manager,
+                             run_id,
+                             workflow_context,
+                             node_executor,
+                         )
+                         if result is not None:
+                             all_results[dep_node_id] = result
+
+             # Execute switches in this layer in parallel with concurrency limit
+             layer_tasks = []
+             for switch_id in layer_switches:
+                 if switch_id not in all_results:
+                     task = self._execute_switch_with_context(
+                         switch_id,
+                         all_results,
+                         parameters,
+                         task_manager,
+                         run_id,
+                         workflow_context,
+                         node_executor,
+                     )
+                     layer_tasks.append((switch_id, task))
+
+             # Wait for all switches in this layer to complete with timeout
+             if layer_tasks:
+                 # Apply concurrency limit by chunking tasks
+                 task_chunks = [
+                     layer_tasks[i : i + self.max_parallelism]
+                     for i in range(0, len(layer_tasks), self.max_parallelism)
+                 ]
+
+                 for chunk in task_chunks:
+                     try:
+                         if self.layer_timeout:
+                             chunk_results = await asyncio.wait_for(
+                                 asyncio.gather(
+                                     *[task for _, task in chunk], return_exceptions=True
+                                 ),
+                                 timeout=self.layer_timeout,
+                             )
+                         else:
+                             chunk_results = await asyncio.gather(
+                                 *[task for _, task in chunk], return_exceptions=True
+                             )
+
+                         # Process results
+                         for (switch_id, _), result in zip(chunk, chunk_results):
+                             if isinstance(result, Exception):
+                                 logger.error(
+                                     f"Error executing switch {switch_id}: {result}"
+                                 )
+                                 layer_errors.append(
+                                     {"switch": switch_id, "error": str(result)}
+                                 )
+                                 # Store error result
+                                 all_results[switch_id] = {"error": str(result)}
+                                 switch_results[switch_id] = {"error": str(result)}
+                             else:
+                                 all_results[switch_id] = result
+                                 switch_results[switch_id] = result
+
+                     except asyncio.TimeoutError:
+                         logger.error(
+                             f"Layer {layer_index + 1} execution timed out after {self.layer_timeout}s"
+                         )
+                         for switch_id, _ in chunk:
+                             if switch_id not in all_results:
+                                 error_msg = f"Timeout after {self.layer_timeout}s"
+                                 layer_errors.append(
+                                     {"switch": switch_id, "error": error_msg}
+                                 )
+                                 all_results[switch_id] = {"error": error_msg}
+                                 switch_results[switch_id] = {"error": error_msg}
+
+             # Record layer metrics
+             layer_execution_time = asyncio.get_event_loop().time() - layer_start_time
+             self._execution_metrics["layer_timings"].append(
+                 {
+                     "layer": layer_index + 1,
+                     "switches": len(layer_switches),
+                     "execution_time": layer_execution_time,
+                     "parallelism": min(len(layer_switches), self.max_parallelism),
+                 }
+             )
+             self._execution_metrics["parallelism_achieved"].append(
+                 min(len(layer_switches), self.max_parallelism)
+             )
+             self._execution_metrics["errors_by_layer"].append(layer_errors)
+
+         # Log execution summary
+         successful_switches = sum(
+             1 for r in switch_results.values() if "error" not in r
+         )
+         logger.info(
+             f"Hierarchical switch execution complete: {successful_switches}/{len(switch_results)} switches executed successfully"
+         )
+
+         return all_results, switch_results
+
+     def _get_layer_dependencies(
+         self, layer_switches: List[str], already_executed: Set[str]
+     ) -> List[str]:
+         """
+         Get all dependencies needed for switches in this layer.
+
+         Args:
+             layer_switches: Switches in the current layer
+             already_executed: Set of node IDs that have already been executed
+
+         Returns:
+             List of node IDs that need to be executed before the layer switches
+         """
+         dependencies = []
+         visited = set(already_executed)
+
+         for switch_id in layer_switches:
+             # Get all predecessors of this switch
+             predecessors = list(self.workflow.graph.predecessors(switch_id))
+
+             for pred_id in predecessors:
+                 if pred_id not in visited:
+                     # Recursively get dependencies of this predecessor
+                     self._collect_dependencies(pred_id, dependencies, visited)
+
+         return dependencies
+
+     def _collect_dependencies(
+         self, node_id: str, dependencies: List[str], visited: Set[str]
+     ):
+         """
+         Recursively collect dependencies for a node.
+
+         Args:
+             node_id: Node to collect dependencies for
+             dependencies: List to append dependencies to
+             visited: Set of already visited nodes
+         """
+         if node_id in visited:
+             return
+
+         visited.add(node_id)
+
+         # Get predecessors
+         predecessors = list(self.workflow.graph.predecessors(node_id))
+
+         # Recursively collect their dependencies first (depth-first)
+         for pred_id in predecessors:
+             if pred_id not in visited:
+                 self._collect_dependencies(pred_id, dependencies, visited)
+
+         # Add this node after its dependencies
+         dependencies.append(node_id)
+
+     async def _execute_node_with_dependencies(
+         self,
+         node_id: str,
+         all_results: Dict[str, Dict[str, Any]],
+         parameters: Dict[str, Any],
+         task_manager: Optional[TaskManager],
+         run_id: str,
+         workflow_context: Dict[str, Any],
+         node_executor,
+     ) -> Optional[Dict[str, Any]]:
+         """
+         Execute a node after ensuring its dependencies are met.
+
+         Args:
+             node_id: Node to execute
+             all_results: Results from previously executed nodes
+             parameters: Node-specific parameters
+             task_manager: Task manager for execution tracking
+             run_id: Unique run identifier
+             workflow_context: Workflow execution context
+             node_executor: Function to execute the node
+
+         Returns:
+             Execution result or None if execution failed
+         """
+         try:
+             # Get node instance
+             node_data = self.workflow.graph.nodes[node_id]
+             node_instance = node_data.get("node") or node_data.get("instance")
+
+             if node_instance is None:
+                 logger.warning(f"No instance found for node {node_id}")
+                 return None
+
+             # Execute using provided executor
+             if node_executor:
+                 result = await node_executor(
+                     node_id=node_id,
+                     node_instance=node_instance,
+                     all_results=all_results,
+                     parameters=parameters,
+                     task_manager=task_manager,
+                     workflow=self.workflow,
+                     workflow_context=workflow_context,
+                 )
+                 return result
+             else:
+                 logger.error(f"No node executor provided for {node_id}")
+                 return None
+
+         except Exception as e:
+             logger.error(f"Error executing node {node_id}: {e}")
+             return {"error": str(e)}
+
+     async def _execute_switch_with_context(
+         self,
+         switch_id: str,
+         all_results: Dict[str, Dict[str, Any]],
+         parameters: Dict[str, Any],
+         task_manager: Optional[TaskManager],
+         run_id: str,
+         workflow_context: Dict[str, Any],
+         node_executor,
+     ) -> Dict[str, Any]:
+         """
+         Execute a switch node with proper context from dependencies.
+
+         Args:
+             switch_id: Switch node to execute
+             all_results: Results from previously executed nodes
+             parameters: Node-specific parameters
+             task_manager: Task manager for execution tracking
+             run_id: Unique run identifier
+             workflow_context: Workflow execution context
+             node_executor: Function to execute the node
+
+         Returns:
+             Switch execution result
+         """
+         logger.debug(
+             f"Executing switch {switch_id} with context from {len(all_results)} previous results"
+         )
+
+         # Execute the switch using the standard node execution
+         result = await self._execute_node_with_dependencies(
+             switch_id,
+             all_results,
+             parameters,
+             task_manager,
+             run_id,
+             workflow_context,
+             node_executor,
+         )
+
+         if result and self.debug:
+             # Log switch decision for debugging
+             if "true_output" in result and result["true_output"] is not None:
+                 logger.debug(f"Switch {switch_id} took TRUE branch")
+             elif "false_output" in result and result["false_output"] is not None:
+                 logger.debug(f"Switch {switch_id} took FALSE branch")
+             else:
+                 logger.debug(f"Switch {switch_id} result: {result}")
+
+         return result or {"error": "Switch execution failed"}
+
+     def get_execution_summary(
+         self, switch_results: Dict[str, Dict[str, Any]]
+     ) -> Dict[str, Any]:
+         """
+         Get a summary of the hierarchical switch execution.
+
+         Args:
+             switch_results: Results from switch execution
+
+         Returns:
+             Summary dictionary with execution statistics
+         """
+         hierarchy_info = self.analyzer.analyze_switch_hierarchies(
+             list(switch_results.keys())
+         )
+
+         summary = {
+             "total_switches": len(switch_results),
+             "successful_switches": sum(
+                 1 for r in switch_results.values() if "error" not in r
+             ),
+             "failed_switches": sum(1 for r in switch_results.values() if "error" in r),
+             "execution_layers": hierarchy_info.get("execution_layers", []),
+             "max_depth": hierarchy_info.get("max_depth", 0),
+             "has_circular_dependencies": hierarchy_info.get(
+                 "has_circular_dependencies", False
+             ),
+             "dependency_chains": hierarchy_info.get("dependency_chains", []),
+         }
+
+         # Analyze branch decisions
+         true_branches = 0
+         false_branches = 0
+         multi_branches = 0
+
+         for result in switch_results.values():
+             if "error" not in result:
+                 if "true_output" in result and result["true_output"] is not None:
+                     true_branches += 1
+                 elif "false_output" in result and result["false_output"] is not None:
+                     false_branches += 1
+                 else:
+                     # Multi-way switch or other pattern
+                     multi_branches += 1
+
+         summary["branch_decisions"] = {
+             "true_branches": true_branches,
+             "false_branches": false_branches,
+             "multi_branches": multi_branches,
+         }
+
+         return summary
+
+     def get_execution_metrics(self) -> Dict[str, Any]:
+         """
+         Get detailed execution metrics for performance analysis.
+
+         Returns:
+             Dictionary containing execution metrics
+         """
+         total_time = sum(
+             timing["execution_time"]
+             for timing in self._execution_metrics["layer_timings"]
+         )
+         total_errors = sum(
+             len(errors) for errors in self._execution_metrics["errors_by_layer"]
+         )
+
+         metrics = {
+             "total_execution_time": total_time,
+             "layer_count": len(self._execution_metrics["layer_timings"]),
+             "layer_timings": self._execution_metrics["layer_timings"],
+             "average_layer_time": (
+                 total_time / len(self._execution_metrics["layer_timings"])
+                 if self._execution_metrics["layer_timings"]
+                 else 0
+             ),
+             "max_parallelism_used": (
+                 max(self._execution_metrics["parallelism_achieved"])
+                 if self._execution_metrics["parallelism_achieved"]
+                 else 0
+             ),
+             "total_errors": total_errors,
+             "errors_by_layer": self._execution_metrics["errors_by_layer"],
+             "configuration": {
+                 "max_parallelism": self.max_parallelism,
+                 "layer_timeout": self.layer_timeout,
+             },
+         }
+
+         return metrics
+
+     def handle_circular_dependencies(self, switch_nodes: List[str]) -> List[List[str]]:
+         """
+         Handle circular dependencies by breaking cycles intelligently.
+
+         Args:
+             switch_nodes: List of switch node IDs
+
+         Returns:
+             Execution layers with cycles broken
+         """
+         # Try to detect and break cycles
+         try:
+             import networkx as nx
+
+             # Create directed graph of dependencies
+             G = nx.DiGraph()
+             for switch_id in switch_nodes:
+                 G.add_node(switch_id)
+                 predecessors = list(self.workflow.graph.predecessors(switch_id))
+                 for pred in predecessors:
+                     if pred in switch_nodes:
+                         G.add_edge(pred, switch_id)
+
+             # Find cycles
+             cycles = list(nx.simple_cycles(G))
+
+             if cycles:
+                 logger.warning(
+                     f"Found {len(cycles)} circular dependencies in switch hierarchy"
+                 )
+
+                 # Break cycles by removing back edges
+                 for cycle in cycles:
+                     # Remove the edge that creates the cycle (last -> first); guard against edges already removed by overlapping cycles
+                     if len(cycle) >= 2 and G.has_edge(cycle[-1], cycle[0]):
+                         G.remove_edge(cycle[-1], cycle[0])
+                         logger.debug(
+                             f"Breaking cycle by removing edge {cycle[-1]} -> {cycle[0]}"
+                         )
+
+             # Now create layers from the acyclic graph
+             layers = []
+             remaining = set(switch_nodes)
+
+             while remaining:
+                 # Find nodes with no incoming edges from remaining nodes
+                 current_layer = []
+                 for node in remaining:
+                     has_deps = any(
+                         pred in remaining for pred in G.predecessors(node)
+                     )
+                     if not has_deps:
+                         current_layer.append(node)
+
+                 if not current_layer:
+                     # Shouldn't happen after breaking cycles, but handle it
+                     current_layer = list(remaining)
+
+                 layers.append(current_layer)
+                 remaining -= set(current_layer)
+
+             return layers
+
+         except ImportError:
+             logger.warning("NetworkX not available for cycle detection")
+
+             # Fall back to the analyzer's method
+             hierarchy_info = self.analyzer.analyze_switch_hierarchies(switch_nodes)
+             return hierarchy_info.get("execution_layers", [switch_nodes])
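For context, here is a minimal usage sketch of the executor added in this diff. It is illustrative only: the import path, build_workflow(), and the run_node executor are hypothetical stand-ins (the diff does not show the module's location inside the package or LocalRuntime's real node executor), while the constructor arguments and method calls follow the code above.

    import asyncio

    # Hypothetical import; the module path is not shown in this diff.
    # from kailash.runtime... import HierarchicalSwitchExecutor

    async def run_node(node_id, node_instance, all_results, parameters,
                       task_manager, workflow, workflow_context):
        # Stand-in for LocalRuntime's node executor. It must accept the keyword
        # arguments passed by _execute_node_with_dependencies above and return a
        # result dict; the execute() call here assumes a typical node interface
        # and may differ from kailash's actual API.
        return node_instance.execute(**parameters.get(node_id, {}))

    async def main(workflow):
        executor = HierarchicalSwitchExecutor(
            workflow,
            debug=True,
            max_parallelism=5,    # at most 5 switches in flight per layer
            layer_timeout=30.0,   # remaining switches in a layer fail after 30s
        )
        all_results, switch_results = await executor.execute_switches_hierarchically(
            parameters={},
            node_executor=run_node,
        )
        print(executor.get_execution_summary(switch_results))
        print(executor.get_execution_metrics())

    # asyncio.run(main(build_workflow()))  # build_workflow() is a stand-in

Per-layer timings and errors come back through get_execution_metrics(), so a caller can check whether the layering actually bought parallelism by comparing "max_parallelism_used" against the configured "max_parallelism".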