kailash 0.9.1__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,696 @@
1
+ """
2
+ ConditionalBranchAnalyzer for analyzing workflow conditional patterns.
3
+
4
+ Analyzes workflow graphs to identify conditional execution patterns and determine
5
+ which nodes are reachable based on SwitchNode outputs.
6
+ """
7
+
8
+ import logging
9
+ from typing import Any, Dict, List, Optional, Set
10
+
11
+ import networkx as nx
12
+
13
+ from kailash.nodes.logic.operations import MergeNode, SwitchNode
14
+ from kailash.workflow.graph import Workflow
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class ConditionalBranchAnalyzer:
    """
    Analyzes workflow graph to identify conditional execution patterns.

    This analyzer examines workflows containing SwitchNode instances and maps
    out the conditional branches to determine which nodes are reachable based
    on runtime switch results.

    Analysis results are cached on the instance; call ``invalidate_cache()``
    after mutating the workflow graph so the next query re-scans it.
    """

    def __init__(self, workflow: Workflow):
        """
        Initialize the ConditionalBranchAnalyzer.

        Args:
            workflow: The workflow to analyze for conditional patterns
        """
        self.workflow = workflow
        # Lazily-populated caches; None means "not computed yet".
        self._switch_nodes: Optional[List[str]] = None
        self._branch_map: Optional[Dict[str, Dict[str, Set[str]]]] = None

    def _find_switch_nodes(self) -> List[str]:
        """
        Find all SwitchNode instances in the workflow.

        Returns:
            List of node IDs that are SwitchNode instances. Empty when the
            workflow has no graph attached.
        """
        if self._switch_nodes is not None:
            return self._switch_nodes

        switch_nodes: List[str] = []

        if not hasattr(self.workflow, "graph") or self.workflow.graph is None:
            logger.warning("Workflow has no graph to analyze")
            return switch_nodes

        for node_id in self.workflow.graph.nodes():
            try:
                node_data = self.workflow.graph.nodes[node_id]
                # Try both 'node' and 'instance' keys for compatibility
                node_instance = node_data.get("node") or node_data.get("instance")

                if isinstance(node_instance, SwitchNode):
                    switch_nodes.append(node_id)
                    logger.debug(f"Found SwitchNode: {node_id}")

            except (KeyError, AttributeError) as e:
                logger.debug(f"Skipping node {node_id} - no valid instance: {e}")
                continue

        self._switch_nodes = switch_nodes
        logger.info(f"Found {len(switch_nodes)} SwitchNode instances")
        return switch_nodes

    def _build_branch_map(self) -> Dict[str, Dict[str, Set[str]]]:
        """
        Build map of SwitchNode -> {output_port -> downstream_nodes}.

        Each port's node set contains the directly connected targets plus all
        nodes transitively reachable from them (traversal stops at other
        switches, which own their own branches).

        Returns:
            Dictionary mapping switch node IDs to their output ports and
            the downstream nodes reachable from each port
        """
        if self._branch_map is not None:
            return self._branch_map

        branch_map: Dict[str, Dict[str, Set[str]]] = {}
        switch_nodes = self._find_switch_nodes()

        for switch_id in switch_nodes:
            branch_map[switch_id] = {}

            # Find all outgoing edges from this switch
            if hasattr(self.workflow.graph, "out_edges"):
                for _source, target, edge_data in self.workflow.graph.out_edges(
                    switch_id, data=True
                ):
                    mapping = edge_data.get("mapping", {})

                    # Extract output ports from mapping
                    for source_port, _target_port in mapping.items():
                        # Output ports typically start with output name (true_output, false_output, case_X)
                        if source_port not in branch_map[switch_id]:
                            branch_map[switch_id][source_port] = set()

                        branch_map[switch_id][source_port].add(target)
                        logger.debug(
                            f"Switch {switch_id} port {source_port} -> {target}"
                        )

                        # Find downstream nodes recursively
                        downstream = self._find_downstream_nodes(target, switch_nodes)
                        branch_map[switch_id][source_port].update(downstream)

        self._branch_map = branch_map
        logger.info(f"Built branch map for {len(branch_map)} switches")
        return branch_map

    def _find_downstream_nodes(
        self, start_node: str, exclude_switches: List[str]
    ) -> Set[str]:
        """
        Find all downstream nodes from a starting node, excluding switches.

        Performs an iterative depth-first traversal of graph successors,
        stopping at any node in ``exclude_switches`` (other than the start
        node itself) because those switches define their own branches.

        Args:
            start_node: Node to start traversal from
            exclude_switches: Switch nodes to stop traversal at

        Returns:
            Set of downstream node IDs (never includes ``start_node``)
        """
        downstream: Set[str] = set()
        visited: Set[str] = set()
        stack = [start_node]

        while stack:
            current = stack.pop()
            if current in visited:
                continue

            visited.add(current)

            # Don't traverse past other switches (they create their own branches)
            if current in exclude_switches and current != start_node:
                continue

            downstream.add(current)

            # Add successors to stack
            if hasattr(self.workflow.graph, "successors"):
                for successor in self.workflow.graph.successors(current):
                    if successor not in visited:
                        stack.append(successor)

        # Remove the start node from downstream set
        downstream.discard(start_node)
        return downstream

    def get_reachable_nodes(
        self, switch_results: Dict[str, Dict[str, Any]]
    ) -> Set[str]:
        """
        Get reachable nodes based on SwitchNode results.

        Args:
            switch_results: Dictionary mapping switch_id -> {output_port -> result}
                None results indicate the port was not activated

        Returns:
            Set of node IDs that are reachable based on switch results,
            including the executed switches themselves and everything
            transitively downstream of the activated branches
        """
        reachable: Set[str] = set()
        branch_map = self._build_branch_map()

        # Always include switches that executed
        reachable.update(switch_results.keys())

        # Track nodes to process for downstream traversal
        to_process: Set[str] = set()

        # Process each switch result to find directly connected nodes
        for switch_id, port_results in switch_results.items():
            if switch_id not in branch_map:
                logger.warning(f"Switch {switch_id} not found in branch map")
                continue

            switch_branches = branch_map[switch_id]

            for port, result in port_results.items():
                if result is not None:  # This port was activated
                    if port in switch_branches:
                        direct_nodes = switch_branches[port]
                        reachable.update(direct_nodes)
                        to_process.update(direct_nodes)
                        logger.debug(
                            f"Switch {switch_id} port {port} activated - added {len(direct_nodes)} direct nodes"
                        )

        # Now traverse the graph to find ALL downstream nodes from the activated branches
        while to_process:
            current_node = to_process.pop()

            # Get all successors of the current node
            if hasattr(self.workflow.graph, "successors"):
                successors = list(self.workflow.graph.successors(current_node))
                for successor in successors:
                    if successor not in reachable:
                        reachable.add(successor)
                        to_process.add(successor)
                        logger.debug(
                            f"Added downstream node {successor} from {current_node}"
                        )

        logger.info(
            f"Determined {len(reachable)} reachable nodes from switch results (including downstream)"
        )
        return reachable

    def detect_conditional_patterns(self) -> Dict[str, Any]:
        """
        Detect complex conditional patterns in the workflow.

        Returns:
            Dictionary containing information about detected patterns:
            total switch count, cycle presence, single/multiple/cascading
            switch classification, merge nodes, circular switch
            dependencies, and multi-case switches.
        """
        patterns: Dict[str, Any] = {}
        switch_nodes = self._find_switch_nodes()
        self._build_branch_map()  # ensure branch cache is populated for helpers below

        # Basic statistics
        patterns["total_switches"] = len(switch_nodes)
        patterns["has_cycles"] = self._detect_cycles()

        # Pattern classification
        if len(switch_nodes) == 1:
            patterns["single_switch"] = switch_nodes
        elif len(switch_nodes) > 1:
            patterns["multiple_switches"] = switch_nodes
            patterns["cascading_switches"] = self._detect_cascading_switches(
                switch_nodes
            )

        # Detect merge nodes
        merge_nodes = self._find_merge_nodes()
        if merge_nodes:
            patterns["merge_nodes"] = merge_nodes

        # Complex pattern detection
        if patterns["has_cycles"] and switch_nodes:
            patterns["cyclic_conditional"] = True

        if len(merge_nodes) > 1:
            patterns["complex_merge_patterns"] = True

        # Detect circular switch dependencies
        if self._detect_circular_switch_dependencies(switch_nodes):
            patterns["circular_switches"] = True

        # Multi-case switch detection
        multi_case_switches = self._detect_multi_case_switches(switch_nodes)
        if multi_case_switches:
            patterns["multi_case_switches"] = multi_case_switches

        logger.info(f"Detected conditional patterns: {patterns}")
        return patterns

    def _detect_cycles(self) -> bool:
        """
        Detect if the workflow has cycles.

        Checks structural cycles via NetworkX first, then falls back to the
        workflow's own ``has_cycles()`` marker when available. Any analysis
        error is treated as "no cycles" (best-effort).
        """
        try:
            # Prioritize NetworkX cycle detection (detects structural cycles)
            if hasattr(self.workflow.graph, "nodes"):
                # Use NetworkX to detect cycles in graph structure
                has_structural_cycles = not nx.is_directed_acyclic_graph(
                    self.workflow.graph
                )
                if has_structural_cycles:
                    return True

            # Also check for explicitly marked cycle connections
            if hasattr(self.workflow, "has_cycles"):
                return self.workflow.has_cycles()

            return False
        except Exception as e:
            logger.debug(f"Error detecting cycles: {e}")
            return False

    def _detect_cascading_switches(self, switch_nodes: List[str]) -> List[List[str]]:
        """
        Detect cascading switch patterns (switch -> switch -> ...).

        Args:
            switch_nodes: All switch node IDs in the workflow

        Returns:
            List of chains; each chain starts with a switch followed by the
            switches found downstream of one of its output ports
        """
        cascading: List[List[str]] = []
        # Use the builder (cached) rather than the raw attribute, which is
        # None until the map has been built at least once.
        branch_map = self._build_branch_map()

        for switch_id in switch_nodes:
            # Check if this switch leads to other switches
            if switch_id in branch_map:
                for _port, downstream in branch_map[switch_id].items():
                    switch_chain = [switch_id]

                    # Find switches in downstream nodes
                    downstream_switches = [
                        node for node in downstream if node in switch_nodes
                    ]
                    if downstream_switches:
                        switch_chain.extend(downstream_switches)
                        if len(switch_chain) > 1:
                            cascading.append(switch_chain)

        return cascading

    def _find_merge_nodes(self) -> List[str]:
        """
        Find all MergeNode instances in the workflow.

        Returns:
            List of node IDs that are MergeNode instances (not cached).
        """
        merge_nodes: List[str] = []

        if not hasattr(self.workflow, "graph") or self.workflow.graph is None:
            return merge_nodes

        for node_id in self.workflow.graph.nodes():
            try:
                node_data = self.workflow.graph.nodes[node_id]
                # Try both 'node' and 'instance' keys for compatibility
                node_instance = node_data.get("node") or node_data.get("instance")

                if isinstance(node_instance, MergeNode):
                    merge_nodes.append(node_id)

            except (KeyError, AttributeError):
                continue

        return merge_nodes

    def _detect_circular_switch_dependencies(self, switch_nodes: List[str]) -> bool:
        """
        Detect circular dependencies between switches.

        Builds a switch-only dependency digraph from the branch map and
        reports whether it contains a cycle.
        """
        if len(switch_nodes) < 2:
            return False

        # Use the builder (cached) rather than the raw attribute, which is
        # None until the map has been built at least once.
        branch_map = self._build_branch_map()

        # Build dependency graph between switches
        switch_deps = nx.DiGraph()
        switch_deps.add_nodes_from(switch_nodes)

        for switch_id in switch_nodes:
            if switch_id in branch_map:
                for _port, downstream in branch_map[switch_id].items():
                    for downstream_switch in downstream:
                        if downstream_switch in switch_nodes:
                            switch_deps.add_edge(switch_id, downstream_switch)

        # Check for cycles in switch dependency graph
        try:
            return not nx.is_directed_acyclic_graph(switch_deps)
        except Exception:
            return False

    def _detect_multi_case_switches(self, switch_nodes: List[str]) -> List[str]:
        """
        Detect multi-case switches (more than true/false outputs).

        A switch is multi-case when it exposes more than two output ports or
        any port following the ``case_X`` naming pattern.
        """
        multi_case: List[str] = []
        # Use the builder (cached) rather than the raw attribute, which is
        # None until the map has been built at least once.
        branch_map = self._build_branch_map()

        for switch_id in switch_nodes:
            if switch_id in branch_map:
                ports = list(branch_map[switch_id].keys())

                # Multi-case switches have more than 2 output ports or case_X patterns
                case_ports = [p for p in ports if p.startswith("case_")]
                if len(ports) > 2 or case_ports:
                    multi_case.append(switch_id)

        return multi_case

    def _get_switch_branch_map(self, switch_id: str) -> Dict[str, Set[str]]:
        """
        Get the branch map for a specific switch node.

        Args:
            switch_id: ID of the switch node

        Returns:
            Dictionary mapping output ports to downstream nodes (empty when
            the switch is unknown)
        """
        branch_map = self._build_branch_map()
        return branch_map.get(switch_id, {})

    def detect_switch_hierarchies(self) -> List[Dict[str, Any]]:
        """
        Detect hierarchical switch patterns.

        Returns:
            List of hierarchy information dictionaries with keys
            ``layers``, ``max_depth`` and ``dependency_chains``; empty when
            the workflow has at most one switch or no hierarchy exists.
        """
        hierarchies: List[Dict[str, Any]] = []
        switch_nodes = self._find_switch_nodes()

        if len(switch_nodes) <= 1:
            return hierarchies

        # Get hierarchy analysis
        hierarchy_info = self.analyze_switch_hierarchies(switch_nodes)

        if hierarchy_info["has_hierarchies"]:
            hierarchies.append(
                {
                    "layers": hierarchy_info["execution_layers"],
                    "max_depth": hierarchy_info["max_depth"],
                    "dependency_chains": hierarchy_info["dependency_chains"],
                }
            )

        return hierarchies

    def invalidate_cache(self):
        """Invalidate cached analysis results so the next query re-scans the graph."""
        self._switch_nodes = None
        self._branch_map = None
        logger.debug("ConditionalBranchAnalyzer cache invalidated")

    # ===== PHASE 4: ADVANCED FEATURES =====

    def analyze_switch_hierarchies(
        self, switch_nodes: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Analyze hierarchical relationships between SwitchNodes.

        Args:
            switch_nodes: Optional list of SwitchNode IDs to analyze;
                defaults to all switches found in the workflow

        Returns:
            Dictionary with hierarchy analysis results: ``has_hierarchies``,
            ``max_depth``, ``dependency_chains``, ``independent_switches``
            and ``execution_layers``
        """
        if switch_nodes is None:
            switch_nodes = self._find_switch_nodes()

        hierarchy_info: Dict[str, Any] = {
            "has_hierarchies": False,
            "max_depth": 0,
            "dependency_chains": [],
            "independent_switches": [],
            "execution_layers": [],
        }

        try:
            if len(switch_nodes) <= 1:
                hierarchy_info["independent_switches"] = switch_nodes
                hierarchy_info["execution_layers"] = (
                    [switch_nodes] if switch_nodes else []
                )
                hierarchy_info["max_depth"] = 1 if switch_nodes else 0
                return hierarchy_info

            # Build dependency graph between switches: a switch depends on
            # any direct predecessor that is itself a switch.
            switch_dependencies: Dict[str, List[str]] = {}
            for switch_id in switch_nodes:
                predecessors = list(self.workflow.graph.predecessors(switch_id))
                switch_predecessors = [
                    pred for pred in predecessors if pred in switch_nodes
                ]
                switch_dependencies[switch_id] = switch_predecessors

                if switch_predecessors:
                    hierarchy_info["has_hierarchies"] = True

            # Calculate execution layers (topological ordering of switches)
            if hierarchy_info["has_hierarchies"]:
                layers = self._create_execution_layers(
                    switch_nodes, switch_dependencies
                )
                hierarchy_info["execution_layers"] = layers
                hierarchy_info["max_depth"] = len(layers)

                # Find dependency chains
                hierarchy_info["dependency_chains"] = self._find_dependency_chains(
                    switch_dependencies
                )
            else:
                hierarchy_info["independent_switches"] = switch_nodes
                hierarchy_info["execution_layers"] = [switch_nodes]
                hierarchy_info["max_depth"] = 1

        except Exception as e:
            logger.warning(f"Error analyzing switch hierarchies: {e}")

        return hierarchy_info

    def _create_execution_layers(
        self, switch_nodes: List[str], dependencies: Dict[str, List[str]]
    ) -> List[List[str]]:
        """
        Create execution layers for hierarchical switches.

        Performs a Kahn-style layered topological sort: each layer contains
        the switches whose dependencies have all been placed in earlier
        layers. On a circular dependency all remaining switches are dumped
        into one final layer to guarantee termination.

        Args:
            switch_nodes: List of all switch node IDs
            dependencies: Dictionary mapping switch_id -> list of dependent switches

        Returns:
            List of execution layers, each containing switches that can execute in parallel
        """
        layers: List[List[str]] = []
        remaining_switches = set(switch_nodes)
        processed_switches: Set[str] = set()

        while remaining_switches:
            # Find switches with no unprocessed dependencies
            current_layer = []
            for switch_id in remaining_switches:
                switch_deps = dependencies.get(switch_id, [])
                if all(dep in processed_switches for dep in switch_deps):
                    current_layer.append(switch_id)

            if not current_layer:
                # Circular dependency or error - add remaining switches to avoid infinite loop
                logger.warning("Circular dependency detected in switch hierarchy")
                current_layer = list(remaining_switches)

            layers.append(current_layer)
            remaining_switches -= set(current_layer)
            processed_switches.update(current_layer)

        return layers

    def _find_dependency_chains(
        self, dependencies: Dict[str, List[str]]
    ) -> List[List[str]]:
        """
        Find dependency chains in switch hierarchies.

        Walks from each switch along its dependency edges, recording every
        path of length > 1 that ends at a switch with no dependencies.
        Cycles are cut off rather than followed.

        Args:
            dependencies: Dictionary mapping switch_id -> list of dependent switches

        Returns:
            List of dependency chains (each chain is a list of switch IDs)
        """
        chains: List[List[str]] = []
        visited: Set[str] = set()

        def build_chain(switch_id: str, current_chain: List[str]):
            if switch_id in visited or switch_id in current_chain:
                return  # Avoid cycles

            current_chain.append(switch_id)
            deps = dependencies.get(switch_id, [])

            if not deps:
                # End of chain
                if len(current_chain) > 1:
                    chains.append(current_chain.copy())
            else:
                for dep in deps:
                    build_chain(dep, current_chain.copy())

        for switch_id in dependencies:
            if switch_id not in visited:
                build_chain(switch_id, [])
                visited.add(switch_id)

        return chains

    def create_hierarchical_execution_plan(
        self, switch_results: Dict[str, Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Create execution plan that handles hierarchical switch dependencies.

        Args:
            switch_results: Results from SwitchNode execution

        Returns:
            Dictionary containing hierarchical execution plan with keys
            ``execution_layers``, ``reachable_nodes``, ``merge_strategies``
            and ``performance_estimate`` (fraction of nodes skippable)
        """
        plan: Dict[str, Any] = {
            "execution_layers": [],
            "reachable_nodes": set(),
            "merge_strategies": {},
            "performance_estimate": 0,
        }

        try:
            switch_nodes = self._find_switch_nodes()
            hierarchy_info = self.analyze_switch_hierarchies(switch_nodes)

            # Process each execution layer
            for layer_index, layer_switches in enumerate(
                hierarchy_info["execution_layers"]
            ):
                layer_plan = {
                    "layer_index": layer_index,
                    "switches": layer_switches,
                    "reachable_from_layer": set(),
                    "blocked_from_layer": set(),
                }

                # Determine reachable nodes from this layer
                for switch_id in layer_switches:
                    if switch_id in switch_results:
                        reachable = self._get_reachable_from_switch(
                            switch_id, switch_results[switch_id]
                        )
                        layer_plan["reachable_from_layer"].update(reachable)
                        plan["reachable_nodes"].update(reachable)
                    else:
                        # Switch not executed yet - add to blocked
                        layer_plan["blocked_from_layer"].add(switch_id)

                plan["execution_layers"].append(layer_plan)

            # Analyze merge nodes and their strategies
            merge_nodes = self._find_merge_nodes()
            for merge_id in merge_nodes:
                plan["merge_strategies"][merge_id] = self._determine_merge_strategy(
                    merge_id, plan["reachable_nodes"]
                )

            # Estimate performance improvement
            total_nodes = len(self.workflow.graph.nodes)
            reachable_count = len(plan["reachable_nodes"])
            plan["performance_estimate"] = (
                (total_nodes - reachable_count) / total_nodes if total_nodes > 0 else 0
            )

        except Exception as e:
            logger.warning(f"Error creating hierarchical execution plan: {e}")

        return plan

    def _get_reachable_from_switch(
        self, switch_id: str, switch_result: Dict[str, Any]
    ) -> Set[str]:
        """
        Get nodes reachable from a specific switch result.

        Args:
            switch_id: ID of the switch node
            switch_result: Result from switch execution

        Returns:
            Set of node IDs reachable from this switch (ports with a
            non-None result are considered activated)
        """
        reachable: Set[str] = set()

        try:
            # Get the branch map for this switch
            branch_map = self._build_branch_map()
            switch_branches = branch_map.get(switch_id, {})

            # Check which branches are active
            for output_key, nodes in switch_branches.items():
                if (
                    output_key in switch_result
                    and switch_result[output_key] is not None
                ):
                    reachable.update(nodes)

        except Exception as e:
            logger.warning(
                f"Error getting reachable nodes from switch {switch_id}: {e}"
            )

        return reachable

    def _determine_merge_strategy(
        self, merge_id: str, reachable_nodes: Set[str]
    ) -> Dict[str, Any]:
        """
        Determine merge strategy for a MergeNode based on reachable inputs.

        Args:
            merge_id: ID of the merge node
            reachable_nodes: Set of nodes that will be executed

        Returns:
            Dictionary describing the merge strategy: ``skip`` when no
            input is reachable, ``full`` when all are, ``partial`` otherwise
        """
        strategy: Dict[str, Any] = {
            "merge_id": merge_id,
            "available_inputs": [],
            "missing_inputs": [],
            "strategy_type": "partial",
            "skip_merge": False,
        }

        try:
            # Get predecessors of the merge node
            predecessors = list(self.workflow.graph.predecessors(merge_id))

            for pred in predecessors:
                if pred in reachable_nodes:
                    strategy["available_inputs"].append(pred)
                else:
                    strategy["missing_inputs"].append(pred)

            # Determine strategy based on available inputs
            if not strategy["available_inputs"]:
                strategy["strategy_type"] = "skip"
                strategy["skip_merge"] = True
            elif not strategy["missing_inputs"]:
                strategy["strategy_type"] = "full"
            else:
                strategy["strategy_type"] = "partial"

        except Exception as e:
            logger.warning(f"Error determining merge strategy for {merge_id}: {e}")

        return strategy