kailash 0.9.2__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +1 -1
- kailash/analysis/__init__.py +9 -0
- kailash/analysis/conditional_branch_analyzer.py +696 -0
- kailash/nodes/logic/intelligent_merge.py +475 -0
- kailash/nodes/logic/operations.py +41 -8
- kailash/planning/__init__.py +9 -0
- kailash/planning/dynamic_execution_planner.py +776 -0
- kailash/runtime/compatibility_reporter.py +497 -0
- kailash/runtime/hierarchical_switch_executor.py +548 -0
- kailash/runtime/local.py +1787 -26
- kailash/runtime/parallel.py +1 -1
- kailash/runtime/performance_monitor.py +215 -0
- kailash/runtime/validation/import_validator.py +7 -0
- kailash/workflow/cyclic_runner.py +436 -27
- {kailash-0.9.2.dist-info → kailash-0.9.3.dist-info}/METADATA +1 -1
- {kailash-0.9.2.dist-info → kailash-0.9.3.dist-info}/RECORD +20 -12
- {kailash-0.9.2.dist-info → kailash-0.9.3.dist-info}/WHEEL +0 -0
- {kailash-0.9.2.dist-info → kailash-0.9.3.dist-info}/entry_points.txt +0 -0
- {kailash-0.9.2.dist-info → kailash-0.9.3.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.9.2.dist-info → kailash-0.9.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,696 @@
|
|
1
|
+
"""
|
2
|
+
ConditionalBranchAnalyzer for analyzing workflow conditional patterns.
|
3
|
+
|
4
|
+
Analyzes workflow graphs to identify conditional execution patterns and determine
|
5
|
+
which nodes are reachable based on SwitchNode outputs.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import logging
|
9
|
+
from typing import Any, Dict, List, Optional, Set
|
10
|
+
|
11
|
+
import networkx as nx
|
12
|
+
|
13
|
+
from kailash.nodes.logic.operations import MergeNode, SwitchNode
|
14
|
+
from kailash.workflow.graph import Workflow
|
15
|
+
|
16
|
+
logger = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
|
19
|
+
class ConditionalBranchAnalyzer:
    """
    Analyzes workflow graph to identify conditional execution patterns.

    This analyzer examines workflows containing SwitchNode instances and maps
    out the conditional branches to determine which nodes are reachable based
    on runtime switch results.

    Results of the two expensive passes (switch discovery and branch mapping)
    are memoized on the instance; call :meth:`invalidate_cache` after the
    workflow graph changes.
    """

    def __init__(self, workflow: Workflow):
        """
        Initialize the ConditionalBranchAnalyzer.

        Args:
            workflow: The workflow to analyze for conditional patterns
        """
        self.workflow = workflow
        # Lazily-computed caches; None means "not computed yet".
        self._switch_nodes: Optional[List[str]] = None
        self._branch_map: Optional[Dict[str, Dict[str, Set[str]]]] = None

    def _find_switch_nodes(self) -> List[str]:
        """
        Find all SwitchNode instances in the workflow.

        Returns:
            List of node IDs that are SwitchNode instances
        """
        if self._switch_nodes is not None:
            return self._switch_nodes

        switch_nodes = []

        if not hasattr(self.workflow, "graph") or self.workflow.graph is None:
            logger.warning("Workflow has no graph to analyze")
            return switch_nodes

        for node_id in self.workflow.graph.nodes():
            try:
                node_data = self.workflow.graph.nodes[node_id]
                # Try both 'node' and 'instance' keys for compatibility
                node_instance = node_data.get("node") or node_data.get("instance")

                if isinstance(node_instance, SwitchNode):
                    switch_nodes.append(node_id)
                    logger.debug(f"Found SwitchNode: {node_id}")

            except (KeyError, AttributeError) as e:
                logger.debug(f"Skipping node {node_id} - no valid instance: {e}")
                continue

        self._switch_nodes = switch_nodes
        logger.info(f"Found {len(switch_nodes)} SwitchNode instances")
        return switch_nodes

    def _build_branch_map(self) -> Dict[str, Dict[str, Set[str]]]:
        """
        Build map of SwitchNode -> {output_port -> downstream_nodes}.

        Returns:
            Dictionary mapping switch node IDs to their output ports and
            the downstream nodes reachable from each port
        """
        if self._branch_map is not None:
            return self._branch_map

        branch_map: Dict[str, Dict[str, Set[str]]] = {}
        switch_nodes = self._find_switch_nodes()

        for switch_id in switch_nodes:
            branch_map[switch_id] = {}

            # Find all outgoing edges from this switch
            if hasattr(self.workflow.graph, "out_edges"):
                for source, target, edge_data in self.workflow.graph.out_edges(
                    switch_id, data=True
                ):
                    mapping = edge_data.get("mapping", {})

                    # Extract output ports from mapping
                    for source_port, target_port in mapping.items():
                        # Output ports typically start with output name (true_output, false_output, case_X)
                        if source_port not in branch_map[switch_id]:
                            branch_map[switch_id][source_port] = set()

                        branch_map[switch_id][source_port].add(target)
                        logger.debug(
                            f"Switch {switch_id} port {source_port} -> {target}"
                        )

                        # Find downstream nodes recursively
                        downstream = self._find_downstream_nodes(target, switch_nodes)
                        branch_map[switch_id][source_port].update(downstream)

        self._branch_map = branch_map
        logger.info(f"Built branch map for {len(branch_map)} switches")
        return branch_map

    def _find_downstream_nodes(
        self, start_node: str, exclude_switches: List[str]
    ) -> Set[str]:
        """
        Find all downstream nodes from a starting node, excluding switches.

        Args:
            start_node: Node to start traversal from
            exclude_switches: Switch nodes to stop traversal at

        Returns:
            Set of downstream node IDs
        """
        downstream = set()
        visited = set()
        stack = [start_node]

        while stack:
            current = stack.pop()
            if current in visited:
                continue

            visited.add(current)

            # Don't traverse past other switches (they create their own branches)
            if current in exclude_switches and current != start_node:
                continue

            downstream.add(current)

            # Add successors to stack
            if hasattr(self.workflow.graph, "successors"):
                for successor in self.workflow.graph.successors(current):
                    if successor not in visited:
                        stack.append(successor)

        # Remove the start node from downstream set
        downstream.discard(start_node)
        return downstream

    def get_reachable_nodes(
        self, switch_results: Dict[str, Dict[str, Any]]
    ) -> Set[str]:
        """
        Get reachable nodes based on SwitchNode results.

        Args:
            switch_results: Dictionary mapping switch_id -> {output_port -> result}
                None results indicate the port was not activated

        Returns:
            Set of node IDs that are reachable based on switch results
        """
        reachable = set()
        branch_map = self._build_branch_map()

        # Always include switches that executed
        reachable.update(switch_results.keys())

        # Track nodes to process for downstream traversal
        to_process = set()

        # Process each switch result to find directly connected nodes
        for switch_id, port_results in switch_results.items():
            if switch_id not in branch_map:
                logger.warning(f"Switch {switch_id} not found in branch map")
                continue

            switch_branches = branch_map[switch_id]

            for port, result in port_results.items():
                if result is not None:  # This port was activated
                    if port in switch_branches:
                        direct_nodes = switch_branches[port]
                        reachable.update(direct_nodes)
                        to_process.update(direct_nodes)
                        logger.debug(
                            f"Switch {switch_id} port {port} activated - added {len(direct_nodes)} direct nodes"
                        )

        # Now traverse the graph to find ALL downstream nodes from the activated branches
        while to_process:
            current_node = to_process.pop()

            # Get all successors of the current node
            if hasattr(self.workflow.graph, "successors"):
                successors = list(self.workflow.graph.successors(current_node))
                for successor in successors:
                    if successor not in reachable:
                        reachable.add(successor)
                        to_process.add(successor)
                        logger.debug(
                            f"Added downstream node {successor} from {current_node}"
                        )

        logger.info(
            f"Determined {len(reachable)} reachable nodes from switch results (including downstream)"
        )
        return reachable

    def detect_conditional_patterns(self) -> Dict[str, Any]:
        """
        Detect complex conditional patterns in the workflow.

        Returns:
            Dictionary containing information about detected patterns
        """
        patterns = {}
        switch_nodes = self._find_switch_nodes()
        branch_map = self._build_branch_map()

        # Basic statistics
        patterns["total_switches"] = len(switch_nodes)
        patterns["has_cycles"] = self._detect_cycles()

        # Pattern classification
        if len(switch_nodes) == 1:
            patterns["single_switch"] = switch_nodes
        elif len(switch_nodes) > 1:
            patterns["multiple_switches"] = switch_nodes
            patterns["cascading_switches"] = self._detect_cascading_switches(
                switch_nodes
            )

        # Detect merge nodes
        merge_nodes = self._find_merge_nodes()
        if merge_nodes:
            patterns["merge_nodes"] = merge_nodes

        # Complex pattern detection
        if patterns["has_cycles"] and switch_nodes:
            patterns["cyclic_conditional"] = True

        if len(merge_nodes) > 1:
            patterns["complex_merge_patterns"] = True

        # Detect circular switch dependencies
        if self._detect_circular_switch_dependencies(switch_nodes):
            patterns["circular_switches"] = True

        # Multi-case switch detection
        multi_case_switches = self._detect_multi_case_switches(switch_nodes)
        if multi_case_switches:
            patterns["multi_case_switches"] = multi_case_switches

        logger.info(f"Detected conditional patterns: {patterns}")
        return patterns

    def _detect_cycles(self) -> bool:
        """Detect if the workflow has cycles."""
        try:
            # Prioritize NetworkX cycle detection (detects structural cycles)
            if hasattr(self.workflow.graph, "nodes"):
                # Use NetworkX to detect cycles in graph structure
                has_structural_cycles = not nx.is_directed_acyclic_graph(
                    self.workflow.graph
                )
                if has_structural_cycles:
                    return True

            # Also check for explicitly marked cycle connections
            if hasattr(self.workflow, "has_cycles"):
                return self.workflow.has_cycles()

            return False
        except Exception as e:
            logger.debug(f"Error detecting cycles: {e}")
            return False

    def _detect_cascading_switches(self, switch_nodes: List[str]) -> List[List[str]]:
        """Detect cascading switch patterns (switch -> switch -> ...)."""
        cascading = []
        # BUGFIX: previously read self._branch_map directly, which is None
        # until _build_branch_map() has run; this raised TypeError when the
        # helper was called before detect_conditional_patterns(). The cached
        # accessor is a no-op when the map is already built.
        branch_map = self._build_branch_map()

        for switch_id in switch_nodes:
            # Check if this switch leads to other switches
            if switch_id in branch_map:
                for port, downstream in branch_map[switch_id].items():
                    switch_chain = [switch_id]

                    # Find switches in downstream nodes
                    downstream_switches = [
                        node for node in downstream if node in switch_nodes
                    ]
                    if downstream_switches:
                        switch_chain.extend(downstream_switches)
                        if len(switch_chain) > 1:
                            cascading.append(switch_chain)

        return cascading

    def _find_merge_nodes(self) -> List[str]:
        """Find all MergeNode instances in the workflow."""
        merge_nodes = []

        if not hasattr(self.workflow, "graph") or self.workflow.graph is None:
            return merge_nodes

        for node_id in self.workflow.graph.nodes():
            try:
                node_data = self.workflow.graph.nodes[node_id]
                # Try both 'node' and 'instance' keys for compatibility
                node_instance = node_data.get("node") or node_data.get("instance")

                if isinstance(node_instance, MergeNode):
                    merge_nodes.append(node_id)

            except (KeyError, AttributeError):
                continue

        return merge_nodes

    def _detect_circular_switch_dependencies(self, switch_nodes: List[str]) -> bool:
        """Detect circular dependencies between switches."""
        if len(switch_nodes) < 2:
            return False

        # BUGFIX: use the cached accessor instead of reading self._branch_map,
        # which may still be None (see _detect_cascading_switches).
        branch_map = self._build_branch_map()

        # Build dependency graph between switches
        switch_deps = nx.DiGraph()
        switch_deps.add_nodes_from(switch_nodes)

        for switch_id in switch_nodes:
            if switch_id in branch_map:
                for port, downstream in branch_map[switch_id].items():
                    for downstream_switch in downstream:
                        if downstream_switch in switch_nodes:
                            switch_deps.add_edge(switch_id, downstream_switch)

        # Check for cycles in switch dependency graph
        try:
            return not nx.is_directed_acyclic_graph(switch_deps)
        except Exception:
            return False

    def _detect_multi_case_switches(self, switch_nodes: List[str]) -> List[str]:
        """Detect multi-case switches (more than true/false outputs)."""
        multi_case = []
        # BUGFIX: use the cached accessor instead of reading self._branch_map,
        # which may still be None (see _detect_cascading_switches).
        branch_map = self._build_branch_map()

        for switch_id in switch_nodes:
            if switch_id in branch_map:
                ports = list(branch_map[switch_id].keys())

                # Multi-case switches have more than 2 output ports or case_X patterns
                case_ports = [p for p in ports if p.startswith("case_")]
                if len(ports) > 2 or case_ports:
                    multi_case.append(switch_id)

        return multi_case

    def _get_switch_branch_map(self, switch_id: str) -> Dict[str, Set[str]]:
        """
        Get the branch map for a specific switch node.

        Args:
            switch_id: ID of the switch node

        Returns:
            Dictionary mapping output ports to downstream nodes
        """
        branch_map = self._build_branch_map()
        return branch_map.get(switch_id, {})

    def detect_switch_hierarchies(self) -> List[Dict[str, Any]]:
        """
        Detect hierarchical switch patterns.

        Returns:
            List of hierarchy information dictionaries
        """
        hierarchies = []
        switch_nodes = self._find_switch_nodes()

        if len(switch_nodes) <= 1:
            return hierarchies

        # Get hierarchy analysis
        hierarchy_info = self.analyze_switch_hierarchies(switch_nodes)

        if hierarchy_info["has_hierarchies"]:
            hierarchies.append(
                {
                    "layers": hierarchy_info["execution_layers"],
                    "max_depth": hierarchy_info["max_depth"],
                    "dependency_chains": hierarchy_info["dependency_chains"],
                }
            )

        return hierarchies

    def invalidate_cache(self):
        """Invalidate cached analysis results."""
        self._switch_nodes = None
        self._branch_map = None
        logger.debug("ConditionalBranchAnalyzer cache invalidated")

    # ===== PHASE 4: ADVANCED FEATURES =====

    def analyze_switch_hierarchies(
        self, switch_nodes: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Analyze hierarchical relationships between SwitchNodes.

        Args:
            switch_nodes: Optional list of SwitchNode IDs to analyze

        Returns:
            Dictionary with hierarchy analysis results
        """
        if switch_nodes is None:
            switch_nodes = self._find_switch_nodes()

        hierarchy_info = {
            "has_hierarchies": False,
            "max_depth": 0,
            "dependency_chains": [],
            "independent_switches": [],
            "execution_layers": [],
        }

        try:
            if len(switch_nodes) <= 1:
                hierarchy_info["independent_switches"] = switch_nodes
                hierarchy_info["execution_layers"] = (
                    [switch_nodes] if switch_nodes else []
                )
                hierarchy_info["max_depth"] = 1 if switch_nodes else 0
                return hierarchy_info

            # Build dependency graph between switches
            switch_dependencies = {}
            for switch_id in switch_nodes:
                predecessors = list(self.workflow.graph.predecessors(switch_id))
                switch_predecessors = [
                    pred for pred in predecessors if pred in switch_nodes
                ]
                switch_dependencies[switch_id] = switch_predecessors

                if switch_predecessors:
                    hierarchy_info["has_hierarchies"] = True

            # Calculate execution layers (topological ordering of switches)
            if hierarchy_info["has_hierarchies"]:
                layers = self._create_execution_layers(
                    switch_nodes, switch_dependencies
                )
                hierarchy_info["execution_layers"] = layers
                hierarchy_info["max_depth"] = len(layers)

                # Find dependency chains
                hierarchy_info["dependency_chains"] = self._find_dependency_chains(
                    switch_dependencies
                )
            else:
                hierarchy_info["independent_switches"] = switch_nodes
                hierarchy_info["execution_layers"] = [switch_nodes]
                hierarchy_info["max_depth"] = 1

        except Exception as e:
            logger.warning(f"Error analyzing switch hierarchies: {e}")

        return hierarchy_info

    def _create_execution_layers(
        self, switch_nodes: List[str], dependencies: Dict[str, List[str]]
    ) -> List[List[str]]:
        """
        Create execution layers for hierarchical switches.

        Args:
            switch_nodes: List of all switch node IDs
            dependencies: Dictionary mapping switch_id -> list of dependent switches

        Returns:
            List of execution layers, each containing switches that can execute in parallel
        """
        layers = []
        remaining_switches = set(switch_nodes)
        processed_switches = set()

        while remaining_switches:
            # Find switches with no unprocessed dependencies
            current_layer = []
            for switch_id in remaining_switches:
                switch_deps = dependencies.get(switch_id, [])
                if all(dep in processed_switches for dep in switch_deps):
                    current_layer.append(switch_id)

            if not current_layer:
                # Circular dependency or error - add remaining switches to avoid infinite loop
                logger.warning("Circular dependency detected in switch hierarchy")
                current_layer = list(remaining_switches)

            layers.append(current_layer)
            remaining_switches -= set(current_layer)
            processed_switches.update(current_layer)

        return layers

    def _find_dependency_chains(
        self, dependencies: Dict[str, List[str]]
    ) -> List[List[str]]:
        """
        Find dependency chains in switch hierarchies.

        Args:
            dependencies: Dictionary mapping switch_id -> list of dependent switches

        Returns:
            List of dependency chains (each chain is a list of switch IDs)
        """
        chains = []
        visited = set()

        def build_chain(switch_id: str, current_chain: List[str]):
            if switch_id in visited or switch_id in current_chain:
                return  # Avoid cycles

            current_chain.append(switch_id)
            deps = dependencies.get(switch_id, [])

            if not deps:
                # End of chain
                if len(current_chain) > 1:
                    chains.append(current_chain.copy())
            else:
                for dep in deps:
                    build_chain(dep, current_chain.copy())

        for switch_id in dependencies:
            if switch_id not in visited:
                build_chain(switch_id, [])
                visited.add(switch_id)

        return chains

    def create_hierarchical_execution_plan(
        self, switch_results: Dict[str, Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Create execution plan that handles hierarchical switch dependencies.

        Args:
            switch_results: Results from SwitchNode execution

        Returns:
            Dictionary containing hierarchical execution plan
        """
        plan = {
            "execution_layers": [],
            "reachable_nodes": set(),
            "merge_strategies": {},
            "performance_estimate": 0,
        }

        try:
            switch_nodes = self._find_switch_nodes()
            hierarchy_info = self.analyze_switch_hierarchies(switch_nodes)

            # Process each execution layer
            for layer_index, layer_switches in enumerate(
                hierarchy_info["execution_layers"]
            ):
                layer_plan = {
                    "layer_index": layer_index,
                    "switches": layer_switches,
                    "reachable_from_layer": set(),
                    "blocked_from_layer": set(),
                }

                # Determine reachable nodes from this layer
                for switch_id in layer_switches:
                    if switch_id in switch_results:
                        reachable = self._get_reachable_from_switch(
                            switch_id, switch_results[switch_id]
                        )
                        layer_plan["reachable_from_layer"].update(reachable)
                        plan["reachable_nodes"].update(reachable)
                    else:
                        # Switch not executed yet - add to blocked
                        layer_plan["blocked_from_layer"].add(switch_id)

                plan["execution_layers"].append(layer_plan)

            # Analyze merge nodes and their strategies
            merge_nodes = self._find_merge_nodes()
            for merge_id in merge_nodes:
                plan["merge_strategies"][merge_id] = self._determine_merge_strategy(
                    merge_id, plan["reachable_nodes"]
                )

            # Estimate performance improvement
            total_nodes = len(self.workflow.graph.nodes)
            reachable_count = len(plan["reachable_nodes"])
            plan["performance_estimate"] = (
                (total_nodes - reachable_count) / total_nodes if total_nodes > 0 else 0
            )

        except Exception as e:
            logger.warning(f"Error creating hierarchical execution plan: {e}")

        return plan

    def _get_reachable_from_switch(
        self, switch_id: str, switch_result: Dict[str, Any]
    ) -> Set[str]:
        """
        Get nodes reachable from a specific switch result.

        Args:
            switch_id: ID of the switch node
            switch_result: Result from switch execution

        Returns:
            Set of node IDs reachable from this switch
        """
        reachable = set()

        try:
            # Get the branch map for this switch
            branch_map = self._build_branch_map()
            switch_branches = branch_map.get(switch_id, {})

            # Check which branches are active
            for output_key, nodes in switch_branches.items():
                if (
                    output_key in switch_result
                    and switch_result[output_key] is not None
                ):
                    reachable.update(nodes)

        except Exception as e:
            logger.warning(
                f"Error getting reachable nodes from switch {switch_id}: {e}"
            )

        return reachable

    def _determine_merge_strategy(
        self, merge_id: str, reachable_nodes: Set[str]
    ) -> Dict[str, Any]:
        """
        Determine merge strategy for a MergeNode based on reachable inputs.

        Args:
            merge_id: ID of the merge node
            reachable_nodes: Set of nodes that will be executed

        Returns:
            Dictionary describing the merge strategy
        """
        strategy = {
            "merge_id": merge_id,
            "available_inputs": [],
            "missing_inputs": [],
            "strategy_type": "partial",
            "skip_merge": False,
        }

        try:
            # Get predecessors of the merge node
            predecessors = list(self.workflow.graph.predecessors(merge_id))

            for pred in predecessors:
                if pred in reachable_nodes:
                    strategy["available_inputs"].append(pred)
                else:
                    strategy["missing_inputs"].append(pred)

            # Determine strategy based on available inputs
            if not strategy["available_inputs"]:
                strategy["strategy_type"] = "skip"
                strategy["skip_merge"] = True
            elif not strategy["missing_inputs"]:
                strategy["strategy_type"] = "full"
            else:
                strategy["strategy_type"] = "partial"

        except Exception as e:
            logger.warning(f"Error determining merge strategy for {merge_id}: {e}")

        return strategy