kailash 0.9.2__py3-none-any.whl → 0.9.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +1 -1
- kailash/analysis/__init__.py +9 -0
- kailash/analysis/conditional_branch_analyzer.py +696 -0
- kailash/nodes/logic/intelligent_merge.py +475 -0
- kailash/nodes/logic/operations.py +41 -8
- kailash/planning/__init__.py +9 -0
- kailash/planning/dynamic_execution_planner.py +776 -0
- kailash/runtime/compatibility_reporter.py +497 -0
- kailash/runtime/hierarchical_switch_executor.py +548 -0
- kailash/runtime/local.py +1904 -27
- kailash/runtime/parallel.py +1 -1
- kailash/runtime/performance_monitor.py +215 -0
- kailash/runtime/validation/import_validator.py +7 -0
- kailash/workflow/cyclic_runner.py +436 -27
- {kailash-0.9.2.dist-info → kailash-0.9.4.dist-info}/METADATA +22 -12
- {kailash-0.9.2.dist-info → kailash-0.9.4.dist-info}/RECORD +20 -12
- {kailash-0.9.2.dist-info → kailash-0.9.4.dist-info}/WHEEL +0 -0
- {kailash-0.9.2.dist-info → kailash-0.9.4.dist-info}/entry_points.txt +0 -0
- {kailash-0.9.2.dist-info → kailash-0.9.4.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.9.2.dist-info → kailash-0.9.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,776 @@
|
|
1
|
+
"""
|
2
|
+
DynamicExecutionPlanner for creating optimized execution plans.
|
3
|
+
|
4
|
+
Creates execution plans based on runtime conditional results, pruning unreachable
|
5
|
+
branches to optimize workflow execution performance.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import logging
|
9
|
+
from collections import defaultdict, deque
|
10
|
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
11
|
+
|
12
|
+
import networkx as nx
|
13
|
+
|
14
|
+
from kailash.analysis.conditional_branch_analyzer import ConditionalBranchAnalyzer
|
15
|
+
from kailash.workflow.graph import Workflow
|
16
|
+
|
17
|
+
logger = logging.getLogger(__name__)
|
18
|
+
|
19
|
+
|
20
|
+
class DynamicExecutionPlanner:
    """
    Creates execution plans based on runtime conditional results.

    The planner analyzes SwitchNode results and creates optimized execution plans
    that skip unreachable branches entirely, improving performance.
    """

    def __init__(self, workflow: Workflow):
        """
        Initialize the DynamicExecutionPlanner.

        Args:
            workflow: The workflow to create execution plans for
        """
        self.workflow = workflow
        # Branch analyzer used to compute which nodes a given set of
        # switch results makes reachable.
        self.analyzer = ConditionalBranchAnalyzer(workflow)
        # Maps a stable cache key (see _create_cache_key) to a previously
        # computed plan, so repeated switch outcomes are served cheaply.
        self._execution_plan_cache: Dict[str, List[str]] = {}
        # Lazily built node_id -> [predecessor ids] map; None until the first
        # call to _analyze_dependencies, reset by invalidate_cache().
        self._dependency_cache: Optional[Dict[str, List[str]]] = None
def create_execution_plan(
|
41
|
+
self, switch_results: Dict[str, Dict[str, Any]]
|
42
|
+
) -> List[str]:
|
43
|
+
"""
|
44
|
+
Create pruned execution plan based on SwitchNode results.
|
45
|
+
|
46
|
+
Args:
|
47
|
+
switch_results: Dictionary mapping switch_id -> {output_port -> result}
|
48
|
+
None results indicate the port was not activated
|
49
|
+
|
50
|
+
Returns:
|
51
|
+
List of node IDs in execution order, with unreachable branches pruned
|
52
|
+
"""
|
53
|
+
# Handle None or invalid switch_results
|
54
|
+
if switch_results is None:
|
55
|
+
return self._get_all_nodes_topological_order()
|
56
|
+
|
57
|
+
if not switch_results:
|
58
|
+
# No switches or no switch results - return all nodes in topological order
|
59
|
+
return self._get_all_nodes_topological_order()
|
60
|
+
|
61
|
+
# Create cache key from switch results
|
62
|
+
cache_key = self._create_cache_key(switch_results)
|
63
|
+
if cache_key in self._execution_plan_cache:
|
64
|
+
logger.debug("Using cached execution plan")
|
65
|
+
return self._execution_plan_cache[cache_key]
|
66
|
+
|
67
|
+
try:
|
68
|
+
# Get all nodes in topological order
|
69
|
+
all_nodes = self._get_all_nodes_topological_order()
|
70
|
+
|
71
|
+
# Determine reachable nodes
|
72
|
+
reachable_nodes = self.analyzer.get_reachable_nodes(switch_results)
|
73
|
+
|
74
|
+
# Add nodes that are not dependent on any switches (always reachable)
|
75
|
+
always_reachable = self._get_always_reachable_nodes(switch_results.keys())
|
76
|
+
reachable_nodes.update(always_reachable)
|
77
|
+
|
78
|
+
# Create pruned execution plan
|
79
|
+
pruned_plan = self._prune_unreachable_branches(all_nodes, reachable_nodes)
|
80
|
+
|
81
|
+
# Cache the result
|
82
|
+
self._execution_plan_cache[cache_key] = pruned_plan
|
83
|
+
|
84
|
+
logger.info(
|
85
|
+
f"Created execution plan: {len(pruned_plan)}/{len(all_nodes)} nodes"
|
86
|
+
)
|
87
|
+
return pruned_plan
|
88
|
+
|
89
|
+
except Exception as e:
|
90
|
+
logger.error(f"Error creating execution plan: {e}")
|
91
|
+
# Fallback to all nodes
|
92
|
+
return self._get_all_nodes_topological_order()
|
93
|
+
|
94
|
+
def _analyze_dependencies(self) -> Dict[str, List[str]]:
|
95
|
+
"""
|
96
|
+
Analyze dependencies for SwitchNode execution ordering.
|
97
|
+
|
98
|
+
Returns:
|
99
|
+
Dictionary mapping node_id -> list of dependency node_ids
|
100
|
+
"""
|
101
|
+
if self._dependency_cache is not None:
|
102
|
+
return self._dependency_cache
|
103
|
+
|
104
|
+
dependencies = defaultdict(list)
|
105
|
+
|
106
|
+
if not hasattr(self.workflow, "graph") or self.workflow.graph is None:
|
107
|
+
return dict(dependencies)
|
108
|
+
|
109
|
+
# Build dependency map from graph edges
|
110
|
+
for source, target, edge_data in self.workflow.graph.edges(data=True):
|
111
|
+
dependencies[target].append(source)
|
112
|
+
|
113
|
+
# Ensure all nodes are in the dependency map
|
114
|
+
for node_id in self.workflow.graph.nodes():
|
115
|
+
if node_id not in dependencies:
|
116
|
+
dependencies[node_id] = []
|
117
|
+
|
118
|
+
self._dependency_cache = dict(dependencies)
|
119
|
+
logger.debug(f"Analyzed dependencies for {len(dependencies)} nodes")
|
120
|
+
return self._dependency_cache
|
121
|
+
|
122
|
+
def _prune_unreachable_branches(
|
123
|
+
self, all_nodes: List[str], reachable_nodes: Set[str]
|
124
|
+
) -> List[str]:
|
125
|
+
"""
|
126
|
+
Prune unreachable branches from execution plan.
|
127
|
+
|
128
|
+
Args:
|
129
|
+
all_nodes: All nodes in topological order
|
130
|
+
reachable_nodes: Set of nodes that are reachable
|
131
|
+
|
132
|
+
Returns:
|
133
|
+
Pruned list of nodes in execution order
|
134
|
+
"""
|
135
|
+
pruned_plan = []
|
136
|
+
|
137
|
+
for node_id in all_nodes:
|
138
|
+
if node_id in reachable_nodes:
|
139
|
+
pruned_plan.append(node_id)
|
140
|
+
else:
|
141
|
+
logger.debug(f"Pruning unreachable node: {node_id}")
|
142
|
+
|
143
|
+
return pruned_plan
|
144
|
+
|
145
|
+
def validate_execution_plan(
|
146
|
+
self, execution_plan: List[str]
|
147
|
+
) -> Tuple[bool, List[str]]:
|
148
|
+
"""
|
149
|
+
Validate execution plan for correctness.
|
150
|
+
|
151
|
+
Args:
|
152
|
+
execution_plan: List of node IDs in execution order
|
153
|
+
|
154
|
+
Returns:
|
155
|
+
Tuple of (is_valid, list_of_errors)
|
156
|
+
"""
|
157
|
+
errors = []
|
158
|
+
|
159
|
+
if not hasattr(self.workflow, "graph") or self.workflow.graph is None:
|
160
|
+
errors.append("Workflow has no graph to validate against")
|
161
|
+
return False, errors
|
162
|
+
|
163
|
+
# Check all nodes exist in workflow
|
164
|
+
workflow_nodes = set(self.workflow.graph.nodes())
|
165
|
+
for node_id in execution_plan:
|
166
|
+
if node_id not in workflow_nodes:
|
167
|
+
errors.append(f"Node '{node_id}' not found in workflow")
|
168
|
+
|
169
|
+
# Check dependencies are satisfied
|
170
|
+
dependencies = self._analyze_dependencies()
|
171
|
+
seen_nodes = set()
|
172
|
+
|
173
|
+
for node_id in execution_plan:
|
174
|
+
# Check if all dependencies have been seen
|
175
|
+
for dep in dependencies.get(node_id, []):
|
176
|
+
if dep not in seen_nodes:
|
177
|
+
if dep in execution_plan:
|
178
|
+
# Dependency is in plan but hasn't been seen yet - order issue
|
179
|
+
errors.append(
|
180
|
+
f"Node '{node_id}' dependency '{dep}' not satisfied"
|
181
|
+
)
|
182
|
+
else:
|
183
|
+
# Dependency is missing from execution plan entirely
|
184
|
+
errors.append(
|
185
|
+
f"Node '{node_id}' dependency '{dep}' missing from execution plan"
|
186
|
+
)
|
187
|
+
|
188
|
+
seen_nodes.add(node_id)
|
189
|
+
|
190
|
+
is_valid = len(errors) == 0
|
191
|
+
if is_valid:
|
192
|
+
logger.debug("Execution plan validation passed")
|
193
|
+
else:
|
194
|
+
logger.warning(f"Execution plan validation failed: {errors}")
|
195
|
+
|
196
|
+
return is_valid, errors
|
197
|
+
|
198
|
+
def _get_all_nodes_topological_order(self) -> List[str]:
|
199
|
+
"""Get all nodes in topological order."""
|
200
|
+
if not hasattr(self.workflow, "graph") or self.workflow.graph is None:
|
201
|
+
return []
|
202
|
+
|
203
|
+
try:
|
204
|
+
return list(nx.topological_sort(self.workflow.graph))
|
205
|
+
except Exception as e:
|
206
|
+
logger.error(f"Error getting topological order: {e}")
|
207
|
+
# Fallback to node list
|
208
|
+
return list(self.workflow.graph.nodes())
|
209
|
+
|
210
|
+
def _get_always_reachable_nodes(self, switch_node_ids: Set[str]) -> Set[str]:
|
211
|
+
"""
|
212
|
+
Get nodes that are always reachable (not dependent on any switches).
|
213
|
+
|
214
|
+
Args:
|
215
|
+
switch_node_ids: Set of switch node IDs
|
216
|
+
|
217
|
+
Returns:
|
218
|
+
Set of node IDs that are always reachable
|
219
|
+
"""
|
220
|
+
always_reachable = set()
|
221
|
+
|
222
|
+
if not hasattr(self.workflow, "graph") or self.workflow.graph is None:
|
223
|
+
return always_reachable
|
224
|
+
|
225
|
+
# Find nodes that don't depend on any switches
|
226
|
+
for node_id in self.workflow.graph.nodes():
|
227
|
+
if self._is_reachable_without_switches(node_id, switch_node_ids):
|
228
|
+
always_reachable.add(node_id)
|
229
|
+
|
230
|
+
logger.debug(f"Found {len(always_reachable)} always reachable nodes")
|
231
|
+
return always_reachable
|
232
|
+
|
233
|
+
def _is_reachable_without_switches(
|
234
|
+
self, node_id: str, switch_node_ids: Set[str]
|
235
|
+
) -> bool:
|
236
|
+
"""
|
237
|
+
Check if a node is reachable without going through any switches.
|
238
|
+
|
239
|
+
Args:
|
240
|
+
node_id: Node to check
|
241
|
+
switch_node_ids: Set of switch node IDs
|
242
|
+
|
243
|
+
Returns:
|
244
|
+
True if node is reachable without switches
|
245
|
+
"""
|
246
|
+
if node_id in switch_node_ids:
|
247
|
+
return True # Switches themselves are always reachable
|
248
|
+
|
249
|
+
# BFS backwards to see if we can reach a source without going through switches
|
250
|
+
visited = set()
|
251
|
+
queue = deque([node_id])
|
252
|
+
|
253
|
+
while queue:
|
254
|
+
current = queue.popleft()
|
255
|
+
if current in visited:
|
256
|
+
continue
|
257
|
+
|
258
|
+
visited.add(current)
|
259
|
+
|
260
|
+
# Get predecessors
|
261
|
+
if hasattr(self.workflow.graph, "predecessors"):
|
262
|
+
predecessors = list(self.workflow.graph.predecessors(current))
|
263
|
+
|
264
|
+
if not predecessors:
|
265
|
+
# Found a source node - reachable without switches
|
266
|
+
return True
|
267
|
+
|
268
|
+
for pred in predecessors:
|
269
|
+
if pred in switch_node_ids:
|
270
|
+
# Path goes through a switch - not always reachable
|
271
|
+
continue
|
272
|
+
queue.append(pred)
|
273
|
+
|
274
|
+
return False
|
275
|
+
|
276
|
+
def _create_cache_key(self, switch_results: Dict[str, Dict[str, Any]]) -> str:
|
277
|
+
"""Create cache key from switch results."""
|
278
|
+
# Create a stable string representation of switch results
|
279
|
+
key_parts = []
|
280
|
+
for switch_id in sorted(switch_results.keys()):
|
281
|
+
ports = switch_results[switch_id]
|
282
|
+
port_parts = []
|
283
|
+
|
284
|
+
# Handle None ports (invalid switch results)
|
285
|
+
if ports is None:
|
286
|
+
port_parts.append("None")
|
287
|
+
elif isinstance(ports, dict):
|
288
|
+
for port in sorted(ports.keys()):
|
289
|
+
result = ports[port]
|
290
|
+
# Create simple representation (None vs not-None)
|
291
|
+
result_repr = "None" if result is None else "active"
|
292
|
+
port_parts.append(f"{port}:{result_repr}")
|
293
|
+
else:
|
294
|
+
# Invalid port format, represent as string
|
295
|
+
port_parts.append(f"invalid:{str(ports)}")
|
296
|
+
|
297
|
+
key_parts.append(f"{switch_id}({','.join(port_parts)})")
|
298
|
+
|
299
|
+
return "|".join(key_parts)
|
300
|
+
|
301
|
+
def create_hierarchical_plan(self, workflow: Workflow) -> List[List[str]]:
|
302
|
+
"""
|
303
|
+
Create execution plan with hierarchical switch dependencies.
|
304
|
+
|
305
|
+
Execute SwitchNodes in dependency layers:
|
306
|
+
- Phase 1: Independent SwitchNodes
|
307
|
+
- Phase 2: Dependent SwitchNodes based on Phase 1 results
|
308
|
+
- Phase 3: Final conditional branches
|
309
|
+
|
310
|
+
Args:
|
311
|
+
workflow: Workflow to analyze
|
312
|
+
|
313
|
+
Returns:
|
314
|
+
List of execution phases, each containing list of node IDs
|
315
|
+
"""
|
316
|
+
switch_nodes = self.analyzer._find_switch_nodes()
|
317
|
+
if not switch_nodes:
|
318
|
+
# No switches - single phase with all nodes
|
319
|
+
return [self._get_all_nodes_topological_order()]
|
320
|
+
|
321
|
+
dependencies = self._analyze_dependencies()
|
322
|
+
|
323
|
+
# Build switch dependency graph
|
324
|
+
switch_deps = nx.DiGraph()
|
325
|
+
switch_deps.add_nodes_from(switch_nodes)
|
326
|
+
|
327
|
+
for switch_id in switch_nodes:
|
328
|
+
for dep in dependencies.get(switch_id, []):
|
329
|
+
if dep in switch_nodes:
|
330
|
+
switch_deps.add_edge(dep, switch_id)
|
331
|
+
|
332
|
+
# Get switch execution layers
|
333
|
+
try:
|
334
|
+
layers = []
|
335
|
+
remaining_switches = set(switch_nodes)
|
336
|
+
|
337
|
+
while remaining_switches:
|
338
|
+
# Find switches with no dependencies in remaining set
|
339
|
+
current_layer = []
|
340
|
+
for switch_id in remaining_switches:
|
341
|
+
deps = [
|
342
|
+
d
|
343
|
+
for d in switch_deps.predecessors(switch_id)
|
344
|
+
if d in remaining_switches
|
345
|
+
]
|
346
|
+
if not deps:
|
347
|
+
current_layer.append(switch_id)
|
348
|
+
|
349
|
+
if not current_layer:
|
350
|
+
# Circular dependency - add all remaining
|
351
|
+
current_layer = list(remaining_switches)
|
352
|
+
|
353
|
+
layers.append(current_layer)
|
354
|
+
remaining_switches -= set(current_layer)
|
355
|
+
|
356
|
+
return layers
|
357
|
+
|
358
|
+
except Exception as e:
|
359
|
+
logger.error(f"Error creating hierarchical plan: {e}")
|
360
|
+
# Fallback to single layer
|
361
|
+
return [switch_nodes]
|
362
|
+
|
363
|
+
def _handle_merge_with_conditional_inputs(
|
364
|
+
self,
|
365
|
+
merge_node: str,
|
366
|
+
workflow: Workflow,
|
367
|
+
switch_results: Dict[str, Dict[str, Any]],
|
368
|
+
) -> bool:
|
369
|
+
"""
|
370
|
+
Handle merge node with conditional inputs.
|
371
|
+
|
372
|
+
Args:
|
373
|
+
merge_node: ID of the merge node
|
374
|
+
workflow: Workflow containing the merge node
|
375
|
+
switch_results: Switch results to determine available inputs
|
376
|
+
|
377
|
+
Returns:
|
378
|
+
True if merge node should be included in execution plan
|
379
|
+
"""
|
380
|
+
if not hasattr(workflow, "graph") or workflow.graph is None:
|
381
|
+
return True
|
382
|
+
|
383
|
+
# Check how many inputs to the merge node are available
|
384
|
+
available_inputs = 0
|
385
|
+
|
386
|
+
for pred in workflow.graph.predecessors(merge_node):
|
387
|
+
# Check if predecessor is reachable based on switch results
|
388
|
+
reachable_nodes = self.analyzer.get_reachable_nodes(switch_results)
|
389
|
+
if pred in reachable_nodes:
|
390
|
+
available_inputs += 1
|
391
|
+
|
392
|
+
# Merge nodes can typically handle partial inputs
|
393
|
+
# Include if at least one input is available
|
394
|
+
return available_inputs > 0
|
395
|
+
|
396
|
+
def invalidate_cache(self):
|
397
|
+
"""Invalidate cached execution plans and dependencies."""
|
398
|
+
self._execution_plan_cache.clear()
|
399
|
+
self._dependency_cache = None
|
400
|
+
# Also invalidate the analyzer's cache when workflow structure changes
|
401
|
+
if hasattr(self.analyzer, "invalidate_cache"):
|
402
|
+
self.analyzer.invalidate_cache()
|
403
|
+
logger.debug("DynamicExecutionPlanner cache invalidated")
|
404
|
+
|
405
|
+
# ===== PHASE 4: ADVANCED FEATURES =====
|
406
|
+
|
407
|
+
    def create_hierarchical_execution_plan(
        self, switch_results: Dict[str, Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Create advanced execution plan with hierarchical switch support and merge strategies.

        Delegates to the analyzer's hierarchical planner when that capability
        exists, backfills ``execution_plan`` from create_execution_plan(), and
        attaches simple pruning metrics.  Any failure degrades to the basic
        (non-hierarchical) plan rather than raising.

        Args:
            switch_results: Results from SwitchNode execution

        Returns:
            Dictionary containing detailed execution plan with layers and strategies
        """
        plan = {
            "switch_layers": [],
            "execution_plan": [],
            "merge_strategies": {},
            "performance_metrics": {},
            "reachable_nodes": set(),
            "skipped_nodes": set(),
        }

        try:
            # Use analyzer's hierarchical capabilities when the installed
            # analyzer exposes them (duck-typed for forward compatibility).
            # NOTE: update() lets the analyzer's result override the default
            # keys initialized above.
            if hasattr(self.analyzer, "create_hierarchical_execution_plan"):
                hierarchical_plan = self.analyzer.create_hierarchical_execution_plan(
                    switch_results
                )
                plan.update(hierarchical_plan)

            # Create traditional execution plan as fallback
            traditional_plan = self.create_execution_plan(switch_results)
            if not plan["execution_plan"]:
                plan["execution_plan"] = traditional_plan

            # Calculate performance metrics
            total_nodes = len(self.workflow.graph.nodes)
            executed_nodes = len(plan["execution_plan"])
            plan["performance_metrics"] = {
                "total_nodes": total_nodes,
                "executed_nodes": executed_nodes,
                "skipped_nodes": total_nodes - executed_nodes,
                # Fraction of nodes pruned away (0 when nothing is skipped).
                "performance_improvement": (
                    (total_nodes - executed_nodes) / total_nodes
                    if total_nodes > 0
                    else 0
                ),
            }

            # Convert sets to lists for JSON serialization
            if isinstance(plan["reachable_nodes"], set):
                plan["reachable_nodes"] = list(plan["reachable_nodes"])
            if isinstance(plan["skipped_nodes"], set):
                plan["skipped_nodes"] = list(plan["skipped_nodes"])

        except Exception as e:
            logger.warning(f"Error creating hierarchical execution plan: {e}")
            # Fallback to basic execution plan
            plan["execution_plan"] = self.create_execution_plan(switch_results)

        return plan
    def handle_merge_nodes_with_conditional_inputs(
        self, execution_plan: List[str], switch_results: Dict[str, Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Analyze and handle MergeNodes that receive conditional inputs.

        For every merge node that appears in ``execution_plan``, a strategy is
        computed via _create_merge_strategy and, for "skip" or "partial"
        strategies, a matching execution modification is recorded.  Failures
        are reported through the returned ``warnings`` list instead of raising.

        Args:
            execution_plan: Current execution plan
            switch_results: Results from switch execution

        Returns:
            Dictionary with merge handling strategies
        """
        merge_handling = {
            "merge_nodes": [],
            "strategies": {},
            "execution_modifications": [],
            "warnings": [],
        }

        try:
            # Find merge nodes in the workflow; prefer the analyzer's
            # detection when available, otherwise use the local fallback scan.
            if hasattr(self.analyzer, "_find_merge_nodes"):
                merge_nodes = self.analyzer._find_merge_nodes()
            else:
                merge_nodes = self._find_merge_nodes_fallback()

            # A node is treated as "reachable" exactly when it is in the plan.
            reachable_nodes = set(execution_plan)

            for merge_id in merge_nodes:
                if merge_id in execution_plan:
                    strategy = self._create_merge_strategy(
                        merge_id, reachable_nodes, switch_results
                    )
                    merge_handling["strategies"][merge_id] = strategy
                    merge_handling["merge_nodes"].append(merge_id)

                    # Add execution modifications if needed
                    if strategy["strategy_type"] == "skip":
                        merge_handling["execution_modifications"].append(
                            {
                                "type": "skip_node",
                                "node_id": merge_id,
                                "reason": "No available inputs",
                            }
                        )
                    elif strategy["strategy_type"] == "partial":
                        merge_handling["execution_modifications"].append(
                            {
                                "type": "partial_merge",
                                "node_id": merge_id,
                                "available_inputs": strategy["available_inputs"],
                                "missing_inputs": strategy["missing_inputs"],
                            }
                        )

        except Exception as e:
            logger.warning(f"Error handling merge nodes: {e}")
            merge_handling["warnings"].append(f"Merge node analysis failed: {e}")

        return merge_handling
def _find_merge_nodes_fallback(self) -> List[str]:
|
531
|
+
"""Fallback method to find merge nodes when analyzer doesn't have the method."""
|
532
|
+
merge_nodes = []
|
533
|
+
|
534
|
+
try:
|
535
|
+
from kailash.nodes.logic.operations import MergeNode
|
536
|
+
|
537
|
+
for node_id, node_data in self.workflow.graph.nodes(data=True):
|
538
|
+
node_instance = node_data.get("node") or node_data.get("instance")
|
539
|
+
if node_instance and isinstance(node_instance, MergeNode):
|
540
|
+
merge_nodes.append(node_id)
|
541
|
+
|
542
|
+
except Exception as e:
|
543
|
+
logger.warning(f"Error in merge node fallback detection: {e}")
|
544
|
+
|
545
|
+
return merge_nodes
|
546
|
+
|
547
|
+
def _create_merge_strategy(
|
548
|
+
self,
|
549
|
+
merge_id: str,
|
550
|
+
reachable_nodes: Set[str],
|
551
|
+
switch_results: Dict[str, Dict[str, Any]],
|
552
|
+
) -> Dict[str, Any]:
|
553
|
+
"""
|
554
|
+
Create merge strategy for a specific MergeNode.
|
555
|
+
|
556
|
+
Args:
|
557
|
+
merge_id: ID of the merge node
|
558
|
+
reachable_nodes: Set of nodes that will be executed
|
559
|
+
switch_results: Results from switch execution
|
560
|
+
|
561
|
+
Returns:
|
562
|
+
Dictionary describing merge strategy
|
563
|
+
"""
|
564
|
+
strategy = {
|
565
|
+
"merge_id": merge_id,
|
566
|
+
"available_inputs": [],
|
567
|
+
"missing_inputs": [],
|
568
|
+
"strategy_type": "unknown",
|
569
|
+
"confidence": 0.0,
|
570
|
+
"recommendations": [],
|
571
|
+
}
|
572
|
+
|
573
|
+
try:
|
574
|
+
# Get all predecessors of the merge node
|
575
|
+
predecessors = list(self.workflow.graph.predecessors(merge_id))
|
576
|
+
|
577
|
+
for pred in predecessors:
|
578
|
+
if pred in reachable_nodes:
|
579
|
+
strategy["available_inputs"].append(pred)
|
580
|
+
else:
|
581
|
+
strategy["missing_inputs"].append(pred)
|
582
|
+
|
583
|
+
# Determine strategy type
|
584
|
+
available_count = len(strategy["available_inputs"])
|
585
|
+
total_count = len(predecessors)
|
586
|
+
|
587
|
+
if available_count == 0:
|
588
|
+
strategy["strategy_type"] = "skip"
|
589
|
+
strategy["confidence"] = 1.0
|
590
|
+
strategy["recommendations"].append(
|
591
|
+
"Skip merge node - no inputs available"
|
592
|
+
)
|
593
|
+
elif available_count == total_count:
|
594
|
+
strategy["strategy_type"] = "full"
|
595
|
+
strategy["confidence"] = 1.0
|
596
|
+
strategy["recommendations"].append("Execute merge with all inputs")
|
597
|
+
else:
|
598
|
+
strategy["strategy_type"] = "partial"
|
599
|
+
strategy["confidence"] = available_count / total_count
|
600
|
+
strategy["recommendations"].append(
|
601
|
+
f"Execute merge with {available_count}/{total_count} inputs"
|
602
|
+
)
|
603
|
+
strategy["recommendations"].append(
|
604
|
+
"Consider merge node's skip_none parameter"
|
605
|
+
)
|
606
|
+
|
607
|
+
except Exception as e:
|
608
|
+
logger.warning(f"Error creating merge strategy for {merge_id}: {e}")
|
609
|
+
strategy["strategy_type"] = "error"
|
610
|
+
strategy["recommendations"].append(f"Error analyzing merge: {e}")
|
611
|
+
|
612
|
+
return strategy
|
613
|
+
|
614
|
+
    def optimize_execution_plan(
        self, execution_plan: List[str], switch_results: Dict[str, Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Optimize execution plan with advanced performance techniques.

        Applies two analyses: parallel-group identification (recorded for
        callers; the plan order itself is not changed) and merge-node
        handling, where merges with no live inputs are removed from the
        optimized plan.  Errors are captured under ``analysis["error"]``.

        Args:
            execution_plan: Basic execution plan
            switch_results: Results from switch execution

        Returns:
            Dictionary with optimized execution plan and performance data
        """
        optimization_result = {
            "original_plan": execution_plan.copy(),
            "optimized_plan": execution_plan.copy(),
            "optimizations_applied": [],
            "performance_improvement": 0.0,
            "analysis": {},
        }

        try:
            # Parallel execution opportunities
            parallel_groups = self._identify_parallel_execution_groups(execution_plan)
            if parallel_groups:
                optimization_result["optimizations_applied"].append(
                    "parallel_execution_grouping"
                )
                optimization_result["analysis"]["parallel_groups"] = parallel_groups

            # Merge node optimizations
            merge_handling = self.handle_merge_nodes_with_conditional_inputs(
                execution_plan, switch_results
            )
            if merge_handling["execution_modifications"]:
                optimization_result["optimizations_applied"].append(
                    "merge_node_optimization"
                )
                optimization_result["analysis"]["merge_optimizations"] = merge_handling

            # Apply merge optimizations to plan: only "skip_node"
            # modifications actually change the plan contents.
            modified_plan = execution_plan.copy()
            for mod in merge_handling["execution_modifications"]:
                if mod["type"] == "skip_node":
                    if mod["node_id"] in modified_plan:
                        modified_plan.remove(mod["node_id"])

            optimization_result["optimized_plan"] = modified_plan

            # Calculate performance improvement as the fraction of nodes
            # removed from the original plan.
            original_count = len(optimization_result["original_plan"])
            optimized_count = len(optimization_result["optimized_plan"])
            if original_count > 0:
                optimization_result["performance_improvement"] = (
                    original_count - optimized_count
                ) / original_count

        except Exception as e:
            logger.warning(f"Error optimizing execution plan: {e}")
            optimization_result["analysis"]["error"] = str(e)

        return optimization_result
def _identify_parallel_execution_groups(
|
678
|
+
self, execution_plan: List[str]
|
679
|
+
) -> List[List[str]]:
|
680
|
+
"""
|
681
|
+
Identify groups of nodes that can be executed in parallel.
|
682
|
+
|
683
|
+
Args:
|
684
|
+
execution_plan: List of nodes in execution order
|
685
|
+
|
686
|
+
Returns:
|
687
|
+
List of parallel execution groups
|
688
|
+
"""
|
689
|
+
parallel_groups = []
|
690
|
+
|
691
|
+
try:
|
692
|
+
# Build dependency graph for nodes in the execution plan
|
693
|
+
plan_nodes = set(execution_plan)
|
694
|
+
dependencies = {}
|
695
|
+
|
696
|
+
for node_id in execution_plan:
|
697
|
+
predecessors = list(self.workflow.graph.predecessors(node_id))
|
698
|
+
# Only consider dependencies within the execution plan
|
699
|
+
plan_predecessors = [
|
700
|
+
pred for pred in predecessors if pred in plan_nodes
|
701
|
+
]
|
702
|
+
dependencies[node_id] = plan_predecessors
|
703
|
+
|
704
|
+
# Group nodes by their dependency depth
|
705
|
+
depth_groups = self._group_by_dependency_depth(execution_plan, dependencies)
|
706
|
+
|
707
|
+
# Each depth level can potentially be executed in parallel
|
708
|
+
for depth, nodes in depth_groups.items():
|
709
|
+
if len(nodes) > 1:
|
710
|
+
parallel_groups.append(nodes)
|
711
|
+
|
712
|
+
except Exception as e:
|
713
|
+
logger.warning(f"Error identifying parallel execution groups: {e}")
|
714
|
+
|
715
|
+
return parallel_groups
|
716
|
+
|
717
|
+
def _group_by_dependency_depth(
|
718
|
+
self, execution_plan: List[str], dependencies: Dict[str, List[str]]
|
719
|
+
) -> Dict[int, List[str]]:
|
720
|
+
"""
|
721
|
+
Group nodes by their dependency depth for parallel execution analysis.
|
722
|
+
|
723
|
+
Args:
|
724
|
+
execution_plan: List of nodes in execution order
|
725
|
+
dependencies: Dictionary mapping node_id -> list of dependencies
|
726
|
+
|
727
|
+
Returns:
|
728
|
+
Dictionary mapping depth -> list of nodes at that depth
|
729
|
+
"""
|
730
|
+
depth_groups = defaultdict(list)
|
731
|
+
node_depths = {}
|
732
|
+
|
733
|
+
try:
|
734
|
+
# Calculate depth for each node
|
735
|
+
for node_id in execution_plan:
|
736
|
+
depth = self._calculate_node_depth(node_id, dependencies, node_depths)
|
737
|
+
depth_groups[depth].append(node_id)
|
738
|
+
node_depths[node_id] = depth
|
739
|
+
|
740
|
+
except Exception as e:
|
741
|
+
logger.warning(f"Error grouping by dependency depth: {e}")
|
742
|
+
|
743
|
+
return dict(depth_groups)
|
744
|
+
|
745
|
+
def _calculate_node_depth(
|
746
|
+
self,
|
747
|
+
node_id: str,
|
748
|
+
dependencies: Dict[str, List[str]],
|
749
|
+
node_depths: Dict[str, int],
|
750
|
+
) -> int:
|
751
|
+
"""
|
752
|
+
Calculate the dependency depth of a node.
|
753
|
+
|
754
|
+
Args:
|
755
|
+
node_id: ID of the node
|
756
|
+
dependencies: Dictionary mapping node_id -> list of dependencies
|
757
|
+
node_depths: Cache of already calculated depths
|
758
|
+
|
759
|
+
Returns:
|
760
|
+
Dependency depth of the node
|
761
|
+
"""
|
762
|
+
if node_id in node_depths:
|
763
|
+
return node_depths[node_id]
|
764
|
+
|
765
|
+
node_deps = dependencies.get(node_id, [])
|
766
|
+
if not node_deps:
|
767
|
+
# No dependencies - depth 0
|
768
|
+
return 0
|
769
|
+
|
770
|
+
# Depth is 1 + max depth of dependencies
|
771
|
+
max_dep_depth = 0
|
772
|
+
for dep in node_deps:
|
773
|
+
dep_depth = self._calculate_node_depth(dep, dependencies, node_depths)
|
774
|
+
max_dep_depth = max(max_dep_depth, dep_depth)
|
775
|
+
|
776
|
+
return max_dep_depth + 1
|