kailash 0.9.2__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,776 @@
+ """
+ DynamicExecutionPlanner for creating optimized execution plans.
+
+ Creates execution plans based on runtime conditional results, pruning unreachable
+ branches to optimize workflow execution performance.
+ """
+
+ import logging
+ from collections import defaultdict, deque
+ from typing import Any, Dict, List, Optional, Set, Tuple
+
+ import networkx as nx
+
+ from kailash.analysis.conditional_branch_analyzer import ConditionalBranchAnalyzer
+ from kailash.workflow.graph import Workflow
+
+ logger = logging.getLogger(__name__)
+
+
+ class DynamicExecutionPlanner:
+     """
+     Creates execution plans based on runtime conditional results.
+
+     The planner analyzes SwitchNode results and creates optimized execution plans
+     that skip unreachable branches entirely, improving performance.
+     """
+
+     def __init__(self, workflow: Workflow):
+         """
+         Initialize the DynamicExecutionPlanner.
+
+         Args:
+             workflow: The workflow to create execution plans for
+         """
+         self.workflow = workflow
+         self.analyzer = ConditionalBranchAnalyzer(workflow)
+         self._execution_plan_cache: Dict[str, List[str]] = {}
+         self._dependency_cache: Optional[Dict[str, List[str]]] = None
+
+     def create_execution_plan(
+         self, switch_results: Dict[str, Dict[str, Any]]
+     ) -> List[str]:
+         """
+         Create pruned execution plan based on SwitchNode results.
+
+         Args:
+             switch_results: Dictionary mapping switch_id -> {output_port -> result}
+                 None results indicate the port was not activated
+
+         Returns:
+             List of node IDs in execution order, with unreachable branches pruned
+         """
+         # Handle None or invalid switch_results
+         if switch_results is None:
+             return self._get_all_nodes_topological_order()
+
+         if not switch_results:
+             # No switches or no switch results - return all nodes in topological order
+             return self._get_all_nodes_topological_order()
+
+         # Create cache key from switch results
+         cache_key = self._create_cache_key(switch_results)
+         if cache_key in self._execution_plan_cache:
+             logger.debug("Using cached execution plan")
+             return self._execution_plan_cache[cache_key]
+
+         try:
+             # Get all nodes in topological order
+             all_nodes = self._get_all_nodes_topological_order()
+
+             # Determine reachable nodes
+             reachable_nodes = self.analyzer.get_reachable_nodes(switch_results)
+
+             # Add nodes that are not dependent on any switches (always reachable)
+             always_reachable = self._get_always_reachable_nodes(switch_results.keys())
+             reachable_nodes.update(always_reachable)
+
+             # Create pruned execution plan
+             pruned_plan = self._prune_unreachable_branches(all_nodes, reachable_nodes)
+
+             # Cache the result
+             self._execution_plan_cache[cache_key] = pruned_plan
+
+             logger.info(
+                 f"Created execution plan: {len(pruned_plan)}/{len(all_nodes)} nodes"
+             )
+             return pruned_plan
+
+         except Exception as e:
+             logger.error(f"Error creating execution plan: {e}")
+             # Fallback to all nodes
+             return self._get_all_nodes_topological_order()
+
+     def _analyze_dependencies(self) -> Dict[str, List[str]]:
+         """
+         Analyze dependencies for SwitchNode execution ordering.
+
+         Returns:
+             Dictionary mapping node_id -> list of dependency node_ids
+         """
+         if self._dependency_cache is not None:
+             return self._dependency_cache
+
+         dependencies = defaultdict(list)
+
+         if not hasattr(self.workflow, "graph") or self.workflow.graph is None:
+             return dict(dependencies)
+
+         # Build dependency map from graph edges
+         for source, target, edge_data in self.workflow.graph.edges(data=True):
+             dependencies[target].append(source)
+
+         # Ensure all nodes are in the dependency map
+         for node_id in self.workflow.graph.nodes():
+             if node_id not in dependencies:
+                 dependencies[node_id] = []
+
+         self._dependency_cache = dict(dependencies)
+         logger.debug(f"Analyzed dependencies for {len(dependencies)} nodes")
+         return self._dependency_cache
+
+     def _prune_unreachable_branches(
+         self, all_nodes: List[str], reachable_nodes: Set[str]
+     ) -> List[str]:
+         """
+         Prune unreachable branches from execution plan.
+
+         Args:
+             all_nodes: All nodes in topological order
+             reachable_nodes: Set of nodes that are reachable
+
+         Returns:
+             Pruned list of nodes in execution order
+         """
+         pruned_plan = []
+
+         for node_id in all_nodes:
+             if node_id in reachable_nodes:
+                 pruned_plan.append(node_id)
+             else:
+                 logger.debug(f"Pruning unreachable node: {node_id}")
+
+         return pruned_plan
+
+     def validate_execution_plan(
+         self, execution_plan: List[str]
+     ) -> Tuple[bool, List[str]]:
+         """
+         Validate execution plan for correctness.
+
+         Args:
+             execution_plan: List of node IDs in execution order
+
+         Returns:
+             Tuple of (is_valid, list_of_errors)
+         """
+         errors = []
+
+         if not hasattr(self.workflow, "graph") or self.workflow.graph is None:
+             errors.append("Workflow has no graph to validate against")
+             return False, errors
+
+         # Check all nodes exist in workflow
+         workflow_nodes = set(self.workflow.graph.nodes())
+         for node_id in execution_plan:
+             if node_id not in workflow_nodes:
+                 errors.append(f"Node '{node_id}' not found in workflow")
+
+         # Check dependencies are satisfied
+         dependencies = self._analyze_dependencies()
+         seen_nodes = set()
+
+         for node_id in execution_plan:
+             # Check if all dependencies have been seen
+             for dep in dependencies.get(node_id, []):
+                 if dep not in seen_nodes:
+                     if dep in execution_plan:
+                         # Dependency is in plan but hasn't been seen yet - order issue
+                         errors.append(
+                             f"Node '{node_id}' dependency '{dep}' not satisfied"
+                         )
+                     else:
+                         # Dependency is missing from execution plan entirely
+                         errors.append(
+                             f"Node '{node_id}' dependency '{dep}' missing from execution plan"
+                         )
+
+             seen_nodes.add(node_id)
+
+         is_valid = len(errors) == 0
+         if is_valid:
+             logger.debug("Execution plan validation passed")
+         else:
+             logger.warning(f"Execution plan validation failed: {errors}")
+
+         return is_valid, errors
+
+     def _get_all_nodes_topological_order(self) -> List[str]:
+         """Get all nodes in topological order."""
+         if not hasattr(self.workflow, "graph") or self.workflow.graph is None:
+             return []
+
+         try:
+             return list(nx.topological_sort(self.workflow.graph))
+         except Exception as e:
+             logger.error(f"Error getting topological order: {e}")
+             # Fallback to node list
+             return list(self.workflow.graph.nodes())
+
+     def _get_always_reachable_nodes(self, switch_node_ids: Set[str]) -> Set[str]:
+         """
+         Get nodes that are always reachable (not dependent on any switches).
+
+         Args:
+             switch_node_ids: Set of switch node IDs
+
+         Returns:
+             Set of node IDs that are always reachable
+         """
+         always_reachable = set()
+
+         if not hasattr(self.workflow, "graph") or self.workflow.graph is None:
+             return always_reachable
+
+         # Find nodes that don't depend on any switches
+         for node_id in self.workflow.graph.nodes():
+             if self._is_reachable_without_switches(node_id, switch_node_ids):
+                 always_reachable.add(node_id)
+
+         logger.debug(f"Found {len(always_reachable)} always reachable nodes")
+         return always_reachable
+
+     def _is_reachable_without_switches(
+         self, node_id: str, switch_node_ids: Set[str]
+     ) -> bool:
+         """
+         Check if a node is reachable without going through any switches.
+
+         Args:
+             node_id: Node to check
+             switch_node_ids: Set of switch node IDs
+
+         Returns:
+             True if node is reachable without switches
+         """
+         if node_id in switch_node_ids:
+             return True  # Switches themselves are always reachable
+
+         # BFS backwards to see if we can reach a source without going through switches
+         visited = set()
+         queue = deque([node_id])
+
+         while queue:
+             current = queue.popleft()
+             if current in visited:
+                 continue
+
+             visited.add(current)
+
+             # Get predecessors
+             if hasattr(self.workflow.graph, "predecessors"):
+                 predecessors = list(self.workflow.graph.predecessors(current))
+
+                 if not predecessors:
+                     # Found a source node - reachable without switches
+                     return True
+
+                 for pred in predecessors:
+                     if pred in switch_node_ids:
+                         # Path goes through a switch - not always reachable
+                         continue
+                     queue.append(pred)
+
+         return False
+
+     def _create_cache_key(self, switch_results: Dict[str, Dict[str, Any]]) -> str:
+         """Create cache key from switch results."""
+         # Create a stable string representation of switch results
+         key_parts = []
+         for switch_id in sorted(switch_results.keys()):
+             ports = switch_results[switch_id]
+             port_parts = []
+
+             # Handle None ports (invalid switch results)
+             if ports is None:
+                 port_parts.append("None")
+             elif isinstance(ports, dict):
+                 for port in sorted(ports.keys()):
+                     result = ports[port]
+                     # Create simple representation (None vs not-None)
+                     result_repr = "None" if result is None else "active"
+                     port_parts.append(f"{port}:{result_repr}")
+             else:
+                 # Invalid port format, represent as string
+                 port_parts.append(f"invalid:{str(ports)}")
+
+             key_parts.append(f"{switch_id}({','.join(port_parts)})")
+
+         return "|".join(key_parts)
+
+     def create_hierarchical_plan(self, workflow: Workflow) -> List[List[str]]:
+         """
+         Create execution plan with hierarchical switch dependencies.
+
+         Execute SwitchNodes in dependency layers:
+         - Phase 1: Independent SwitchNodes
+         - Phase 2: Dependent SwitchNodes based on Phase 1 results
+         - Phase 3: Final conditional branches
+
+         Args:
+             workflow: Workflow to analyze
+
+         Returns:
+             List of execution phases, each containing list of node IDs
+         """
+         switch_nodes = self.analyzer._find_switch_nodes()
+         if not switch_nodes:
+             # No switches - single phase with all nodes
+             return [self._get_all_nodes_topological_order()]
+
+         dependencies = self._analyze_dependencies()
+
+         # Build switch dependency graph
+         switch_deps = nx.DiGraph()
+         switch_deps.add_nodes_from(switch_nodes)
+
+         for switch_id in switch_nodes:
+             for dep in dependencies.get(switch_id, []):
+                 if dep in switch_nodes:
+                     switch_deps.add_edge(dep, switch_id)
+
+         # Get switch execution layers
+         try:
+             layers = []
+             remaining_switches = set(switch_nodes)
+
+             while remaining_switches:
+                 # Find switches with no dependencies in remaining set
+                 current_layer = []
+                 for switch_id in remaining_switches:
+                     deps = [
+                         d
+                         for d in switch_deps.predecessors(switch_id)
+                         if d in remaining_switches
+                     ]
+                     if not deps:
+                         current_layer.append(switch_id)
+
+                 if not current_layer:
+                     # Circular dependency - add all remaining
+                     current_layer = list(remaining_switches)
+
+                 layers.append(current_layer)
+                 remaining_switches -= set(current_layer)
+
+             return layers
+
+         except Exception as e:
+             logger.error(f"Error creating hierarchical plan: {e}")
+             # Fallback to single layer
+             return [switch_nodes]
+
+     def _handle_merge_with_conditional_inputs(
+         self,
+         merge_node: str,
+         workflow: Workflow,
+         switch_results: Dict[str, Dict[str, Any]],
+     ) -> bool:
+         """
+         Handle merge node with conditional inputs.
+
+         Args:
+             merge_node: ID of the merge node
+             workflow: Workflow containing the merge node
+             switch_results: Switch results to determine available inputs
+
+         Returns:
+             True if merge node should be included in execution plan
+         """
+         if not hasattr(workflow, "graph") or workflow.graph is None:
+             return True
+
+         # Check how many inputs to the merge node are available
+         available_inputs = 0
+
+         for pred in workflow.graph.predecessors(merge_node):
+             # Check if predecessor is reachable based on switch results
+             reachable_nodes = self.analyzer.get_reachable_nodes(switch_results)
+             if pred in reachable_nodes:
+                 available_inputs += 1
+
+         # Merge nodes can typically handle partial inputs
+         # Include if at least one input is available
+         return available_inputs > 0
+
+     def invalidate_cache(self):
+         """Invalidate cached execution plans and dependencies."""
+         self._execution_plan_cache.clear()
+         self._dependency_cache = None
+         # Also invalidate the analyzer's cache when workflow structure changes
+         if hasattr(self.analyzer, "invalidate_cache"):
+             self.analyzer.invalidate_cache()
+         logger.debug("DynamicExecutionPlanner cache invalidated")
+
+     # ===== PHASE 4: ADVANCED FEATURES =====
+
+     def create_hierarchical_execution_plan(
+         self, switch_results: Dict[str, Dict[str, Any]]
+     ) -> Dict[str, Any]:
+         """
+         Create advanced execution plan with hierarchical switch support and merge strategies.
+
+         Args:
+             switch_results: Results from SwitchNode execution
+
+         Returns:
+             Dictionary containing detailed execution plan with layers and strategies
+         """
+         plan = {
+             "switch_layers": [],
+             "execution_plan": [],
+             "merge_strategies": {},
+             "performance_metrics": {},
+             "reachable_nodes": set(),
+             "skipped_nodes": set(),
+         }
+
+         try:
+             # Use analyzer's hierarchical capabilities
+             if hasattr(self.analyzer, "create_hierarchical_execution_plan"):
+                 hierarchical_plan = self.analyzer.create_hierarchical_execution_plan(
+                     switch_results
+                 )
+                 plan.update(hierarchical_plan)
+
+             # Create traditional execution plan as fallback
+             traditional_plan = self.create_execution_plan(switch_results)
+             if not plan["execution_plan"]:
+                 plan["execution_plan"] = traditional_plan
+
+             # Calculate performance metrics
+             total_nodes = len(self.workflow.graph.nodes)
+             executed_nodes = len(plan["execution_plan"])
+             plan["performance_metrics"] = {
+                 "total_nodes": total_nodes,
+                 "executed_nodes": executed_nodes,
+                 "skipped_nodes": total_nodes - executed_nodes,
+                 "performance_improvement": (
+                     (total_nodes - executed_nodes) / total_nodes
+                     if total_nodes > 0
+                     else 0
+                 ),
+             }
+
+             # Convert sets to lists for JSON serialization
+             if isinstance(plan["reachable_nodes"], set):
+                 plan["reachable_nodes"] = list(plan["reachable_nodes"])
+             if isinstance(plan["skipped_nodes"], set):
+                 plan["skipped_nodes"] = list(plan["skipped_nodes"])
+
+         except Exception as e:
+             logger.warning(f"Error creating hierarchical execution plan: {e}")
+             # Fallback to basic execution plan
+             plan["execution_plan"] = self.create_execution_plan(switch_results)
+
+         return plan
+
+     def handle_merge_nodes_with_conditional_inputs(
+         self, execution_plan: List[str], switch_results: Dict[str, Dict[str, Any]]
+     ) -> Dict[str, Any]:
+         """
+         Analyze and handle MergeNodes that receive conditional inputs.
+
+         Args:
+             execution_plan: Current execution plan
+             switch_results: Results from switch execution
+
+         Returns:
+             Dictionary with merge handling strategies
+         """
+         merge_handling = {
+             "merge_nodes": [],
+             "strategies": {},
+             "execution_modifications": [],
+             "warnings": [],
+         }
+
+         try:
+             # Find merge nodes in the workflow
+             if hasattr(self.analyzer, "_find_merge_nodes"):
+                 merge_nodes = self.analyzer._find_merge_nodes()
+             else:
+                 merge_nodes = self._find_merge_nodes_fallback()
+
+             reachable_nodes = set(execution_plan)
+
+             for merge_id in merge_nodes:
+                 if merge_id in execution_plan:
+                     strategy = self._create_merge_strategy(
+                         merge_id, reachable_nodes, switch_results
+                     )
+                     merge_handling["strategies"][merge_id] = strategy
+                     merge_handling["merge_nodes"].append(merge_id)
+
+                     # Add execution modifications if needed
+                     if strategy["strategy_type"] == "skip":
+                         merge_handling["execution_modifications"].append(
+                             {
+                                 "type": "skip_node",
+                                 "node_id": merge_id,
+                                 "reason": "No available inputs",
+                             }
+                         )
+                     elif strategy["strategy_type"] == "partial":
+                         merge_handling["execution_modifications"].append(
+                             {
+                                 "type": "partial_merge",
+                                 "node_id": merge_id,
+                                 "available_inputs": strategy["available_inputs"],
+                                 "missing_inputs": strategy["missing_inputs"],
+                             }
+                         )
+
+         except Exception as e:
+             logger.warning(f"Error handling merge nodes: {e}")
+             merge_handling["warnings"].append(f"Merge node analysis failed: {e}")
+
+         return merge_handling
+
+     def _find_merge_nodes_fallback(self) -> List[str]:
+         """Fallback method to find merge nodes when analyzer doesn't have the method."""
+         merge_nodes = []
+
+         try:
+             from kailash.nodes.logic.operations import MergeNode
+
+             for node_id, node_data in self.workflow.graph.nodes(data=True):
+                 node_instance = node_data.get("node") or node_data.get("instance")
+                 if node_instance and isinstance(node_instance, MergeNode):
+                     merge_nodes.append(node_id)
+
+         except Exception as e:
+             logger.warning(f"Error in merge node fallback detection: {e}")
+
+         return merge_nodes
+
+     def _create_merge_strategy(
+         self,
+         merge_id: str,
+         reachable_nodes: Set[str],
+         switch_results: Dict[str, Dict[str, Any]],
+     ) -> Dict[str, Any]:
+         """
+         Create merge strategy for a specific MergeNode.
+
+         Args:
+             merge_id: ID of the merge node
+             reachable_nodes: Set of nodes that will be executed
+             switch_results: Results from switch execution
+
+         Returns:
+             Dictionary describing merge strategy
+         """
+         strategy = {
+             "merge_id": merge_id,
+             "available_inputs": [],
+             "missing_inputs": [],
+             "strategy_type": "unknown",
+             "confidence": 0.0,
+             "recommendations": [],
+         }
+
+         try:
+             # Get all predecessors of the merge node
+             predecessors = list(self.workflow.graph.predecessors(merge_id))
+
+             for pred in predecessors:
+                 if pred in reachable_nodes:
+                     strategy["available_inputs"].append(pred)
+                 else:
+                     strategy["missing_inputs"].append(pred)
+
+             # Determine strategy type
+             available_count = len(strategy["available_inputs"])
+             total_count = len(predecessors)
+
+             if available_count == 0:
+                 strategy["strategy_type"] = "skip"
+                 strategy["confidence"] = 1.0
+                 strategy["recommendations"].append(
+                     "Skip merge node - no inputs available"
+                 )
+             elif available_count == total_count:
+                 strategy["strategy_type"] = "full"
+                 strategy["confidence"] = 1.0
+                 strategy["recommendations"].append("Execute merge with all inputs")
+             else:
+                 strategy["strategy_type"] = "partial"
+                 strategy["confidence"] = available_count / total_count
+                 strategy["recommendations"].append(
+                     f"Execute merge with {available_count}/{total_count} inputs"
+                 )
+                 strategy["recommendations"].append(
+                     "Consider merge node's skip_none parameter"
+                 )
+
+         except Exception as e:
+             logger.warning(f"Error creating merge strategy for {merge_id}: {e}")
+             strategy["strategy_type"] = "error"
+             strategy["recommendations"].append(f"Error analyzing merge: {e}")
+
+         return strategy
+
+     def optimize_execution_plan(
+         self, execution_plan: List[str], switch_results: Dict[str, Dict[str, Any]]
+     ) -> Dict[str, Any]:
+         """
+         Optimize execution plan with advanced performance techniques.
+
+         Args:
+             execution_plan: Basic execution plan
+             switch_results: Results from switch execution
+
+         Returns:
+             Dictionary with optimized execution plan and performance data
+         """
+         optimization_result = {
+             "original_plan": execution_plan.copy(),
+             "optimized_plan": execution_plan.copy(),
+             "optimizations_applied": [],
+             "performance_improvement": 0.0,
+             "analysis": {},
+         }
+
+         try:
+             # Parallel execution opportunities
+             parallel_groups = self._identify_parallel_execution_groups(execution_plan)
+             if parallel_groups:
+                 optimization_result["optimizations_applied"].append(
+                     "parallel_execution_grouping"
+                 )
+                 optimization_result["analysis"]["parallel_groups"] = parallel_groups
+
+             # Merge node optimizations
+             merge_handling = self.handle_merge_nodes_with_conditional_inputs(
+                 execution_plan, switch_results
+             )
+             if merge_handling["execution_modifications"]:
+                 optimization_result["optimizations_applied"].append(
+                     "merge_node_optimization"
+                 )
+                 optimization_result["analysis"]["merge_optimizations"] = merge_handling
+
+                 # Apply merge optimizations to plan
+                 modified_plan = execution_plan.copy()
+                 for mod in merge_handling["execution_modifications"]:
+                     if mod["type"] == "skip_node":
+                         if mod["node_id"] in modified_plan:
+                             modified_plan.remove(mod["node_id"])
+
+                 optimization_result["optimized_plan"] = modified_plan
+
+             # Calculate performance improvement
+             original_count = len(optimization_result["original_plan"])
+             optimized_count = len(optimization_result["optimized_plan"])
+             if original_count > 0:
+                 optimization_result["performance_improvement"] = (
+                     original_count - optimized_count
+                 ) / original_count
+
+         except Exception as e:
+             logger.warning(f"Error optimizing execution plan: {e}")
+             optimization_result["analysis"]["error"] = str(e)
+
+         return optimization_result
+
+     def _identify_parallel_execution_groups(
+         self, execution_plan: List[str]
+     ) -> List[List[str]]:
+         """
+         Identify groups of nodes that can be executed in parallel.
+
+         Args:
+             execution_plan: List of nodes in execution order
+
+         Returns:
+             List of parallel execution groups
+         """
+         parallel_groups = []
+
+         try:
+             # Build dependency graph for nodes in the execution plan
+             plan_nodes = set(execution_plan)
+             dependencies = {}
+
+             for node_id in execution_plan:
+                 predecessors = list(self.workflow.graph.predecessors(node_id))
+                 # Only consider dependencies within the execution plan
+                 plan_predecessors = [
+                     pred for pred in predecessors if pred in plan_nodes
+                 ]
+                 dependencies[node_id] = plan_predecessors
+
+             # Group nodes by their dependency depth
+             depth_groups = self._group_by_dependency_depth(execution_plan, dependencies)
+
+             # Each depth level can potentially be executed in parallel
+             for depth, nodes in depth_groups.items():
+                 if len(nodes) > 1:
+                     parallel_groups.append(nodes)
+
+         except Exception as e:
+             logger.warning(f"Error identifying parallel execution groups: {e}")
+
+         return parallel_groups
+
+     def _group_by_dependency_depth(
+         self, execution_plan: List[str], dependencies: Dict[str, List[str]]
+     ) -> Dict[int, List[str]]:
+         """
+         Group nodes by their dependency depth for parallel execution analysis.
+
+         Args:
+             execution_plan: List of nodes in execution order
+             dependencies: Dictionary mapping node_id -> list of dependencies
+
+         Returns:
+             Dictionary mapping depth -> list of nodes at that depth
+         """
+         depth_groups = defaultdict(list)
+         node_depths = {}
+
+         try:
+             # Calculate depth for each node
+             for node_id in execution_plan:
+                 depth = self._calculate_node_depth(node_id, dependencies, node_depths)
+                 depth_groups[depth].append(node_id)
+                 node_depths[node_id] = depth
+
+         except Exception as e:
+             logger.warning(f"Error grouping by dependency depth: {e}")
+
+         return dict(depth_groups)
+
+     def _calculate_node_depth(
+         self,
+         node_id: str,
+         dependencies: Dict[str, List[str]],
+         node_depths: Dict[str, int],
+     ) -> int:
+         """
+         Calculate the dependency depth of a node.
+
+         Args:
+             node_id: ID of the node
+             dependencies: Dictionary mapping node_id -> list of dependencies
+             node_depths: Cache of already calculated depths
+
+         Returns:
+             Dependency depth of the node
+         """
+         if node_id in node_depths:
+             return node_depths[node_id]
+
+         node_deps = dependencies.get(node_id, [])
+         if not node_deps:
+             # No dependencies - depth 0
+             return 0
+
+         # Depth is 1 + max depth of dependencies
+         max_dep_depth = 0
+         for dep in node_deps:
+             dep_depth = self._calculate_node_depth(dep, dependencies, node_depths)
+             max_dep_depth = max(max_dep_depth, dep_depth)
+
+         return max_dep_depth + 1
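
For orientation, a usage sketch follows; it is not part of the package diff. It assumes kailash 0.9.3 is installed, that the class is importable from a path such as kailash.analysis.dynamic_execution_planner (the file path is not shown in this hunk), that a Workflow containing a SwitchNode with id "router" has already been built elsewhere, and that the port names "true_output"/"false_output" are illustrative. The switch_results shape follows the create_execution_plan docstring: switch_id -> {output_port -> result}, with None marking a port that did not fire.

# Hedged usage sketch; build_workflow() is a hypothetical stand-in for
# whatever code actually constructs the kailash Workflow in your project.
from kailash.analysis.dynamic_execution_planner import DynamicExecutionPlanner  # import path assumed

workflow = build_workflow()  # hypothetical helper returning a kailash Workflow
planner = DynamicExecutionPlanner(workflow)

switch_results = {
    "router": {
        "true_output": {"approved": True},  # branch taken
        "false_output": None,               # branch not taken -> pruned
    }
}

plan = planner.create_execution_plan(switch_results)      # pruned node order
is_valid, errors = planner.validate_execution_plan(plan)  # dependency/order check
report = planner.optimize_execution_plan(plan, switch_results)
planner.invalidate_cache()  # call after the workflow structure changes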
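
create_hierarchical_plan peels SwitchNodes off in dependency layers: switches with no unresolved switch predecessors form the next layer, and a cycle flushes everything remaining into a single layer. The same loop is shown in isolation below, using the networkx dependency the module already imports; the switch ids are made up for illustration.

# Standalone sketch of the layer-peeling loop in create_hierarchical_plan.
import networkx as nx

switch_deps = nx.DiGraph()
switch_deps.add_nodes_from(["switch_a", "switch_b", "switch_c"])
switch_deps.add_edge("switch_a", "switch_c")  # switch_c depends on switch_a

layers = []
remaining = set(switch_deps.nodes())
while remaining:
    # Switches whose remaining predecessors are all resolved form the next layer.
    layer = [s for s in remaining
             if not any(p in remaining for p in switch_deps.predecessors(s))]
    if not layer:  # circular dependency: flush everything, as the planner does
        layer = list(remaining)
    layers.append(sorted(layer))
    remaining -= set(layer)

print(layers)  # [['switch_a', 'switch_b'], ['switch_c']]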
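
The parallel-grouping helpers at the end of the file (_identify_parallel_execution_groups, _group_by_dependency_depth, _calculate_node_depth) boil down to a simple rule: a node's depth is 0 if it has no in-plan dependencies, otherwise 1 plus the maximum depth of its dependencies, and any depth level holding more than one node is a candidate parallel group. A self-contained sketch of that idea, independent of kailash and with invented node names:

# Standalone sketch of the dependency-depth grouping idea.
from collections import defaultdict
from typing import Dict, List


def node_depth(node: str, deps: Dict[str, List[str]], memo: Dict[str, int]) -> int:
    """Depth is 0 for nodes with no in-plan dependencies, else 1 + max(dep depths)."""
    if node in memo:
        return memo[node]
    parents = deps.get(node, [])
    depth = 0 if not parents else 1 + max(node_depth(p, deps, memo) for p in parents)
    memo[node] = depth
    return depth


def group_by_depth(plan: List[str], deps: Dict[str, List[str]]) -> Dict[int, List[str]]:
    groups: Dict[int, List[str]] = defaultdict(list)
    memo: Dict[str, int] = {}
    for node in plan:
        groups[node_depth(node, deps, memo)].append(node)
    return dict(groups)


if __name__ == "__main__":
    deps = {"load": [], "clean_a": ["load"], "clean_b": ["load"], "join": ["clean_a", "clean_b"]}
    print(group_by_depth(["load", "clean_a", "clean_b", "join"], deps))
    # {0: ['load'], 1: ['clean_a', 'clean_b'], 2: ['join']} -> the depth-1 nodes could run in parallel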