pytrilogy 0.0.3.91__py3-none-any.whl → 0.0.3.93__py3-none-any.whl

This diff compares the content of publicly available package versions as released to a supported registry; it is provided for informational purposes only.

Potentially problematic release.


This version of pytrilogy might be problematic.

@@ -0,0 +1,792 @@
+ from functools import lru_cache
+ from typing import TYPE_CHECKING, Dict, List, Optional, Set
+
+ import networkx as nx
+
+ from trilogy.constants import logger
+ from trilogy.core.enums import Derivation
+ from trilogy.core.graph_models import (
+     concept_to_node,
+     get_graph_exact_match,
+     prune_sources_for_conditions,
+ )
+ from trilogy.core.models.build import (
+     BuildConcept,
+     BuildDatasource,
+     BuildGrain,
+     BuildWhereClause,
+     LooseBuildConceptList,
+ )
+ from trilogy.core.models.build_environment import BuildEnvironment
+ from trilogy.core.processing.node_generators.select_helpers.datasource_injection import (
+     get_union_sources,
+ )
+ from trilogy.core.processing.nodes import (
+     ConstantNode,
+     GroupNode,
+     MergeNode,
+     SelectNode,
+     StrategyNode,
+ )
+ from trilogy.core.processing.utility import padding
+
+ if TYPE_CHECKING:
+     from trilogy.core.processing.nodes.union_node import UnionNode
+
+ LOGGER_PREFIX = "[GEN_ROOT_MERGE_NODE]"
+
+
+ # Cache for expensive string operations
+ @lru_cache(maxsize=1024)
+ def extract_address(node: str) -> str:
+     """Cached version of address extraction from node string."""
+     return node.split("~")[1].split("@")[0]
+
+
+ @lru_cache(maxsize=1024)
+ def get_node_type(node: str) -> str:
+     """Extract node type prefix efficiently."""
+     if "~" in node:
+         return node.split("~")[0] + "~"
+     return ""
+
+
+ class GraphAttributeCache:
+     """Cache for expensive NetworkX attribute operations."""
+
+     def __init__(self, graph: nx.DiGraph):
+         self.graph = graph
+         self._datasources: dict | None = None
+         self._concepts: dict | None = None
+         self._ds_nodes: set[str] | None = None
+         self._concept_nodes: set[str] | None = None
+
+     @property
+     def datasources(self) -> Dict:
+         if self._datasources is None:
+             self._datasources = nx.get_node_attributes(self.graph, "datasource") or {}
+         return self._datasources
+
+     @property
+     def concepts(self) -> Dict:
+         if self._concepts is None:
+             self._concepts = nx.get_node_attributes(self.graph, "concept") or {}
+         return self._concepts
+
+     @property
+     def ds_nodes(self) -> Set[str]:
+         if self._ds_nodes is None:
+             self._ds_nodes = {n for n in self.graph.nodes if n.startswith("ds~")}
+         return self._ds_nodes
+
+     @property
+     def concept_nodes(self) -> Set[str]:
+         if self._concept_nodes is None:
+             self._concept_nodes = {n for n in self.graph.nodes if n.startswith("c~")}
+         return self._concept_nodes
+
+
+ def get_graph_partial_nodes(
+     g: nx.DiGraph,
+     conditions: BuildWhereClause | None,
+     cache: GraphAttributeCache | None = None,
+ ) -> dict[str, list[str]]:
+     """Optimized version with caching and early returns."""
+     if cache is None:
+         cache = GraphAttributeCache(g)
+
+     datasources = cache.datasources
+     partial: dict[str, list[str]] = {}
+
+     for node in cache.ds_nodes:  # Only iterate over datasource nodes
+         if node not in datasources:
+             continue
+
+         ds = datasources[node]
+         if not isinstance(ds, list):
+             # Datasource is complete under these conditions; nothing is partial
+             if ds.non_partial_for and conditions == ds.non_partial_for:
+                 partial[node] = []
+                 continue
+             partial[node] = [concept_to_node(c) for c in ds.partial_concepts]
+         else:
+             # Union sources have no partial concepts
+             partial[node] = []
+
+     return partial
+
+
+ def get_graph_grains(
+     g: nx.DiGraph, cache: GraphAttributeCache | None = None
+ ) -> dict[str, set[str]]:
+     """Optimized version using set.update() instead of reduce with union."""
+     if cache is None:
+         cache = GraphAttributeCache(g)
+
+     datasources = cache.datasources
+     grain_components: dict[str, set[str]] = {}
+
+     for node in cache.ds_nodes:  # Only iterate over datasource nodes
+         if node not in datasources:
+             continue
+
+         lookup = datasources[node]
+         if not isinstance(lookup, list):
+             lookup = [lookup]
+
+         # Optimized set building - avoid reduce and intermediate sets
+         components: set[str] = set()
+         for item in lookup:
+             components.update(item.grain.components)
+         grain_components[node] = components
+
+     return grain_components
+
+
+ def subgraph_is_complete(
+     nodes: list[str], targets: set[str], mapping: dict[str, str], g: nx.DiGraph
+ ) -> bool:
+     """Optimized with early exits and reduced iterations."""
+     # Check target presence, stopping as soon as every target is covered
+     mapped = set()
+     for n in nodes:
+         mapped.add(mapping.get(n, n))
+         if len(mapped) >= len(targets) and targets.issubset(mapped):
+             break
+
+     if not targets.issubset(mapped):
+         return False
+
+     # Check datasource edges more efficiently
+     has_ds_edge = {k: False for k in targets}
+
+     # Early exit optimization - stop checking once all targets have edges
+     found_count = 0
+     for n in nodes:
+         if not n.startswith("c~"):
+             continue
+
+         concept_key = mapping.get(n, n)
+         if concept_key in has_ds_edge and not has_ds_edge[concept_key]:
+             # Check for a datasource neighbor
+             for neighbor in nx.neighbors(g, n):
+                 if neighbor.startswith("ds~"):
+                     has_ds_edge[concept_key] = True
+                     found_count += 1
+                     break
+
+             # Early return once all targets have datasource edges
+             if found_count == len(targets):
+                 return True
+
+     return all(has_ds_edge.values())
+
+
+ def create_pruned_concept_graph(
+     g: nx.DiGraph,
+     all_concepts: List[BuildConcept],
+     datasources: list[BuildDatasource],
+     accept_partial: bool = False,
+     conditions: BuildWhereClause | None = None,
+     depth: int = 0,
+ ) -> Optional[nx.DiGraph]:
+     """Optimized version with caching and batch operations; returns None if no
+     single complete subgraph covers the targets."""
+     orig_g = g
+     orig_cache = GraphAttributeCache(orig_g)
+
+     g = g.copy()
+     union_options = get_union_sources(datasources, all_concepts)
+
+     # Batch edge additions for union sources
+     edges_to_add = []
+     for ds_list in union_options:
+         node_address = "ds~" + "-".join([x.name for x in ds_list])
+         logger.info(
+             f"{padding(depth)}{LOGGER_PREFIX} injecting potentially relevant union datasource {node_address}"
+         )
+         common: set[BuildConcept] = set.intersection(
+             *[set(x.output_concepts) for x in ds_list]
+         )
+         g.add_node(node_address, datasource=ds_list)
+
+         # Collect edges for batch addition
+         for c in common:
+             c_node = concept_to_node(c)
+             edges_to_add.extend([(node_address, c_node), (c_node, node_address)])
+
+     # Batch add all edges at once
+     if edges_to_add:
+         g.add_edges_from(edges_to_add)
+
+     prune_sources_for_conditions(g, accept_partial, conditions)
+
+     # Create cache for the modified graph
+     g_cache = GraphAttributeCache(g)
+
+     target_addresses = {c.address for c in all_concepts}
+     concepts = orig_cache.concepts
+     datasource_map = orig_cache.datasources
+
+     # Optimized filtering with early termination
+     relevant_concepts_pre = {}
+     for n in g_cache.concept_nodes:  # Only iterate over concept nodes
+         if n in concepts:
+             concept = concepts[n]
+             if concept.address in target_addresses:
+                 relevant_concepts_pre[n] = concept.address
+
+     relevant_concepts: list[str] = list(relevant_concepts_pre.keys())
+     relevant_datasets: list[str] = []
+
+     if not accept_partial:
+         partial = get_graph_partial_nodes(g, conditions, g_cache)
+         edges_to_remove = []
+
+         # Collect edges to remove
+         for edge in g.edges:
+             if edge[0] in datasource_map and edge[0] in partial:
+                 if edge[1] in partial[edge[0]]:
+                     edges_to_remove.append(edge)
+             if edge[1] in datasource_map and edge[1] in partial:
+                 if edge[0] in partial[edge[1]]:
+                     edges_to_remove.append(edge)
+
+         # Batch remove edges
+         if edges_to_remove:
+             g.remove_edges_from(edges_to_remove)
+
+     # Find relevant datasets more efficiently
+     relevant_concepts_set = set(relevant_concepts)
+     for n in g_cache.ds_nodes:  # Only iterate over datasource nodes
+         # Check if any relevant concepts are neighbors
+         if any(
+             neighbor in relevant_concepts_set for neighbor in nx.all_neighbors(g, n)
+         ):
+             relevant_datasets.append(n)
+
+     # Handle additional join concepts
+     roots: dict[str, set[str]] = {}
+     for n in orig_cache.concept_nodes:  # Only iterate over concept nodes
+         if n not in relevant_concepts:
+             root = n.split("@")[0]
+             neighbors = roots.get(root, set())
+             for neighbor in nx.all_neighbors(orig_g, n):
+                 if neighbor in relevant_datasets:
+                     neighbors.add(neighbor)
+             # A concept reachable from multiple datasources is a join candidate
+             if len(neighbors) > 1:
+                 relevant_concepts.append(n)
+             roots[root] = neighbors
+
+     # Remove irrelevant nodes
+     nodes_to_keep = set(relevant_datasets + relevant_concepts)
+     nodes_to_remove = [n for n in g.nodes() if n not in nodes_to_keep]
+     if nodes_to_remove:
+         g.remove_nodes_from(nodes_to_remove)
+
+     # Check subgraphs
+     subgraphs = list(nx.connected_components(g.to_undirected()))
+     subgraphs = [
+         s
+         for s in subgraphs
+         if subgraph_is_complete(list(s), target_addresses, relevant_concepts_pre, g)
+     ]
+
+     if not subgraphs:
+         logger.info(
+             f"{padding(depth)}{LOGGER_PREFIX} cannot resolve root graph - no subgraphs after node prune"
+         )
+         return None
+
+     if len(subgraphs) != 1:
+         logger.info(
+             f"{padding(depth)}{LOGGER_PREFIX} cannot resolve root graph - subgraphs are split - have {len(subgraphs)} from {subgraphs}"
+         )
+         return None
+
+     # Add back relevant edges - batch operation
+     relevant = set(relevant_concepts + relevant_datasets)
+     edges_to_add = []
+     for edge in orig_g.edges():
+         if edge[0] in relevant and edge[1] in relevant and not g.has_edge(*edge):
+             edges_to_add.append(edge)
+     if edges_to_add:
+         g.add_edges_from(edges_to_add)
+
+     # Early return check
+     if not any(n.startswith("ds~") for n in g.nodes):
+         logger.info(
+             f"{padding(depth)}{LOGGER_PREFIX} cannot resolve root graph - no datasource nodes found"
+         )
+         return None
+
+     return g
+
+
+ def resolve_subgraphs(
+     g: nx.DiGraph,
+     relevant: list[BuildConcept],
+     accept_partial: bool,
+     conditions: BuildWhereClause | None,
+     depth: int = 0,
+ ) -> dict[str, list[str]]:
+     """Optimized version with caching and reduced iterations."""
+     cache = GraphAttributeCache(g)
+     datasources = list(cache.ds_nodes)
+
+     # Build subgraphs more efficiently
+     subgraphs: dict[str, list[str]] = {}
+     for ds in datasources:
+         # Use set to avoid duplicates from the start
+         subgraphs[ds] = list(set(nx.all_neighbors(g, ds)))
+
+     partial_map = get_graph_partial_nodes(g, conditions, cache)
+     exact_map = get_graph_exact_match(g, accept_partial, conditions)
+     grain_components = get_graph_grains(g, cache)
+     concepts = cache.concepts
+
+     # Pre-compute concept addresses for all datasources
+     non_partial_map = {}
+     concept_map = {}
+     for ds in datasources:
+         ds_concepts = subgraphs[ds]
+         partial_concepts = set(partial_map.get(ds, []))
+
+         non_partial_map[ds] = [
+             concepts[c].address
+             for c in ds_concepts
+             if c in concepts and c not in partial_concepts
+         ]
+         concept_map[ds] = [concepts[c].address for c in ds_concepts if c in concepts]
+
+     pruned_subgraphs = {}
+
+     def score_node(node: str) -> tuple:
+         """Optimized scoring function."""
+         logger.debug(f"{padding(depth)}{LOGGER_PREFIX} scoring node {node}")
+         grain = grain_components[node]
+         concept_addresses = concept_map[node]
+
+         # Calculate score components
+         grain_score = len(grain) - sum(1 for x in concept_addresses if x in grain)
+         exact_match_score = 0 if node in exact_map else 0.5
+         concept_count = len(subgraphs[node])
+
+         score = (grain_score, exact_match_score, concept_count, node)
+         logger.debug(f"{padding(depth)}{LOGGER_PREFIX} node {node} has score {score}")
+         return score
+
+     # Optimize subset detection with early termination
+     for key in subgraphs:
+         value = non_partial_map[key]
+         all_concepts = concept_map[key]
+         is_subset = False
+         matches = set()
+
+         # Precompute sets once per key
+         value_set = set(value)
+         all_concepts_set = set(all_concepts)
+
+         for other_key in concept_map:
+             if key == other_key:
+                 continue
+
+             other_value = non_partial_map[other_key]
+             other_all_concepts = concept_map[other_key]
+
+             # Quick length check before detailed comparison
+             if len(value) > len(other_value) or len(all_concepts) > len(
+                 other_all_concepts
+             ):
+                 continue
+
+             other_value_set = set(other_value)
+             other_all_concepts_set = set(other_all_concepts)
+
+             if value_set.issubset(other_value_set) and all_concepts_set.issubset(
+                 other_all_concepts_set
+             ):
+                 if len(value) < len(other_value):
+                     is_subset = True
+                     logger.debug(
+                         f"{padding(depth)}{LOGGER_PREFIX} Dropping subgraph {key} with {value} as it is a subset of {other_key} with {other_value}"
+                     )
+                     break  # Early termination
+                 elif len(value) == len(other_value) and len(all_concepts) == len(
+                     other_all_concepts
+                 ):
+                     matches.add(other_key)
+                     matches.add(key)
+
+         if matches and not is_subset:
+             min_node = min(matches, key=score_node)
+             logger.debug(
+                 f"{padding(depth)}{LOGGER_PREFIX} minimum source score is {min_node}"
+             )
+             is_subset = key != min_node
+
+         if not is_subset:
+             pruned_subgraphs[key] = subgraphs[key]
+
+     # Final node pruning - optimized
+     final_nodes: set[str] = set()
+     for v in pruned_subgraphs.values():
+         final_nodes.update(v)
+
+     relevant_concepts_pre = {
+         n: concepts[n].address
+         for n in cache.concept_nodes
+         if n in concepts and concepts[n].address in relevant
+     }
+
+     # Count node occurrences once
+     node_counts = {}
+     for node in final_nodes:
+         if node.startswith("c~") and node not in relevant_concepts_pre:
+             node_counts[node] = sum(
+                 1 for sub_nodes in pruned_subgraphs.values() if node in sub_nodes
+             )
+
+     # Drop non-relevant concepts that appear in only one subgraph;
+     # they cannot serve as join keys
+     nodes_to_remove = {node for node, count in node_counts.items() if count <= 1}
+
+     if nodes_to_remove:
+         for key in pruned_subgraphs:
+             pruned_subgraphs[key] = [
+                 n for n in pruned_subgraphs[key] if n not in nodes_to_remove
+             ]
+         logger.debug(
+             f"{padding(depth)}{LOGGER_PREFIX} Pruning nodes {nodes_to_remove} as irrelevant after subgraph resolution"
+         )
+
+     return pruned_subgraphs
+
+
+ def create_datasource_node(
+     datasource: BuildDatasource,
+     all_concepts: List[BuildConcept],
+     accept_partial: bool,
+     environment: BuildEnvironment,
+     depth: int,
+     conditions: BuildWhereClause | None = None,
+ ) -> tuple[StrategyNode, bool]:
+     target_grain = BuildGrain.from_concepts(all_concepts, environment=environment)
+     force_group = False
+     if not datasource.grain.issubset(target_grain):
+         logger.info(
+             f"{padding(depth)}{LOGGER_PREFIX}_DS_NODE Select node must be wrapped in group, {datasource.grain} not subset of target grain {target_grain}"
+         )
+         force_group = True
+     else:
+         logger.info(
+             f"{padding(depth)}{LOGGER_PREFIX}_DS_NODE Select node grain {datasource.grain} is subset of target grain {target_grain}, no group required"
+         )
+     if not datasource.grain.components:
+         force_group = True
+
+     # Optimized concept filtering using sets
+     all_concept_addresses = {c.address for c in all_concepts}
+
+     partial_concepts = [
+         c.concept
+         for c in datasource.columns
+         if not c.is_complete and c.concept.address in all_concept_addresses
+     ]
+
+     partial_lcl = LooseBuildConceptList(concepts=partial_concepts)
+
+     nullable_concepts = [
+         c.concept
+         for c in datasource.columns
+         if c.is_nullable and c.concept.address in all_concept_addresses
+     ]
+
+     nullable_lcl = LooseBuildConceptList(concepts=nullable_concepts)
+     partial_is_full = conditions and (conditions == datasource.non_partial_for)
+
+     datasource_conditions = datasource.where.conditional if datasource.where else None
+     rval = SelectNode(
+         input_concepts=[c.concept for c in datasource.columns],
+         output_concepts=sorted(all_concepts, key=lambda x: x.address),
+         environment=environment,
+         parents=[],
+         depth=depth,
+         partial_concepts=(
+             [] if partial_is_full else [c for c in all_concepts if c in partial_lcl]
+         ),
+         nullable_concepts=[c for c in all_concepts if c in nullable_lcl],
+         accept_partial=accept_partial,
+         datasource=datasource,
+         grain=datasource.grain,
+         conditions=datasource_conditions,
+         preexisting_conditions=(
+             conditions.conditional if partial_is_full and conditions else None
+         ),
+     )
+     return (
+         rval,
+         force_group,
+     )
+
+
+ def create_union_datasource(
+     datasource: list[BuildDatasource],
+     all_concepts: List[BuildConcept],
+     accept_partial: bool,
+     environment: BuildEnvironment,
+     depth: int,
+     conditions: BuildWhereClause | None = None,
+ ) -> tuple["UnionNode", bool]:
+     from trilogy.core.processing.nodes.union_node import UnionNode
+
+     logger.info(
+         f"{padding(depth)}{LOGGER_PREFIX} generating union node parents with condition {conditions}"
+     )
+     force_group = False
+     parents = []
+     for x in datasource:
+         subnode, fg = create_datasource_node(
+             x,
+             all_concepts,
+             accept_partial,
+             environment,
+             depth + 1,
+             conditions=conditions,
+         )
+         parents.append(subnode)
+         force_group = force_group or fg
+     logger.info(f"{padding(depth)}{LOGGER_PREFIX} returning union node")
+     return (
+         UnionNode(
+             output_concepts=all_concepts,
+             input_concepts=all_concepts,
+             environment=environment,
+             parents=parents,
+             depth=depth,
+             partial_concepts=[],
+         ),
+         force_group,
+     )
+
+
+ def create_select_node(
+     ds_name: str,
+     subgraph: list[str],
+     accept_partial: bool,
+     g: nx.DiGraph,
+     environment: BuildEnvironment,
+     depth: int,
+     conditions: BuildWhereClause | None = None,
+ ) -> StrategyNode:
+     # Use cached extraction
+     all_concepts = [
+         environment.concepts[extract_address(c)]
+         for c in subgraph
+         if c.startswith("c~")
+     ]
+
+     # Early return for all constants
+     if all(c.derivation == Derivation.CONSTANT for c in all_concepts):
+         logger.info(
+             f"{padding(depth)}{LOGGER_PREFIX} All concepts {[x.address for x in all_concepts]} are constants, returning constant node"
+         )
+         return ConstantNode(
+             output_concepts=all_concepts,
+             input_concepts=[],
+             environment=environment,
+             parents=[],
+             depth=depth,
+             partial_concepts=[],
+             force_group=False,
+             preexisting_conditions=conditions.conditional if conditions else None,
+         )
+
+     cache = GraphAttributeCache(g)
+     datasource = cache.datasources[ds_name]
+
+     if isinstance(datasource, BuildDatasource):
+         bcandidate, force_group = create_datasource_node(
+             datasource,
+             all_concepts,
+             accept_partial,
+             environment,
+             depth,
+             conditions=conditions,
+         )
+     elif isinstance(datasource, list):
+         bcandidate, force_group = create_union_datasource(
+             datasource,
+             all_concepts,
+             accept_partial,
+             environment,
+             depth,
+             conditions=conditions,
+         )
+     else:
+         raise ValueError(f"Unknown datasource type {datasource}")
+
+     # we need to nest the group node one further
+     if force_group is True:
+         logger.info(
+             f"{padding(depth)}{LOGGER_PREFIX} source requires group before consumption."
+         )
+         candidate: StrategyNode = GroupNode(
+             output_concepts=all_concepts,
+             input_concepts=all_concepts,
+             environment=environment,
+             parents=[bcandidate],
+             depth=depth + 1,
+             partial_concepts=bcandidate.partial_concepts,
+             nullable_concepts=bcandidate.nullable_concepts,
+             preexisting_conditions=bcandidate.preexisting_conditions,
+             force_group=force_group,
+         )
+     else:
+         candidate = bcandidate
+
+     return candidate
+
+
+ def gen_select_merge_node(
+     all_concepts: List[BuildConcept],
+     g: nx.DiGraph,
+     environment: BuildEnvironment,
+     depth: int,
+     accept_partial: bool = False,
+     conditions: BuildWhereClause | None = None,
+ ) -> Optional[StrategyNode]:
+     # Early separation of constants and non-constants
+     non_constant = []
+     constants = []
+     for c in all_concepts:
+         if c.derivation == Derivation.CONSTANT:
+             constants.append(c)
+         else:
+             non_constant.append(c)
+
+     # Early return for all constants
+     if not non_constant and constants:
+         logger.info(
+             f"{padding(depth)}{LOGGER_PREFIX} only constant inputs to discovery ({constants}), returning constant node directly"
+         )
+         for x in constants:
+             logger.info(
+                 f"{padding(depth)}{LOGGER_PREFIX} {x} {x.lineage} {x.derivation}"
+             )
+         if conditions:
+             if not all(
+                 x.derivation == Derivation.CONSTANT for x in conditions.row_arguments
+             ):
+                 logger.info(
+                     f"{padding(depth)}{LOGGER_PREFIX} conditions being passed in to constant node {conditions}, but not all concepts are constants."
+                 )
+                 return None
+             else:
+                 constants += conditions.row_arguments
+
+         return ConstantNode(
+             output_concepts=constants,
+             input_concepts=[],
+             environment=environment,
+             parents=[],
+             depth=depth,
+             partial_concepts=[],
+             force_group=False,
+             conditions=conditions.conditional if conditions else None,
+         )
+
+     attempts = [False]
+     if accept_partial:
+         attempts.append(True)
+
+     logger.info(
+         f"{padding(depth)}{LOGGER_PREFIX} searching for root source graph for concepts {[c.address for c in all_concepts]} and conditions {conditions}"
+     )
+
+     pruned_concept_graph = None
+     for attempt in attempts:
+         pruned_concept_graph = create_pruned_concept_graph(
+             g,
+             non_constant,
+             accept_partial=attempt,
+             conditions=conditions,
+             datasources=list(environment.datasources.values()),
+             depth=depth,
+         )
+         if pruned_concept_graph:
+             logger.info(
+                 f"{padding(depth)}{LOGGER_PREFIX} found covering graph w/ partial flag {attempt}"
+             )
+             break
+
+     if not pruned_concept_graph:
+         logger.info(f"{padding(depth)}{LOGGER_PREFIX} no covering graph found.")
+         return None
+
+     sub_nodes = resolve_subgraphs(
+         pruned_concept_graph,
+         relevant=non_constant,
+         accept_partial=accept_partial,
+         conditions=conditions,
+         depth=depth,
+     )
+
+     logger.info(f"{padding(depth)}{LOGGER_PREFIX} fetching subgraphs {sub_nodes}")
+
+     parents = [
+         create_select_node(
+             k,
+             subgraph,
+             g=pruned_concept_graph,
+             accept_partial=accept_partial,
+             environment=environment,
+             depth=depth,
+             conditions=conditions,
+         )
+         for k, subgraph in sub_nodes.items()
+     ]
+
+     if not parents:
+         return None
+
+     if constants:
+         parents.append(
+             ConstantNode(
+                 output_concepts=constants,
+                 input_concepts=[],
+                 environment=environment,
+                 parents=[],
+                 depth=depth,
+                 partial_concepts=[],
+                 force_group=False,
+                 preexisting_conditions=conditions.conditional if conditions else None,
+             )
+         )
+
+     if len(parents) == 1:
+         return parents[0]
+
+     logger.info(
+         f"{padding(depth)}{LOGGER_PREFIX} Multiple parent DS nodes resolved - {[type(x) for x in parents]}, wrapping in merge"
+     )
+
+     preexisting_conditions = None
+     if conditions and all(
+         x.preexisting_conditions and x.preexisting_conditions == conditions.conditional
+         for x in parents
+     ):
+         preexisting_conditions = conditions.conditional
+
+     base = MergeNode(
+         output_concepts=all_concepts,
+         input_concepts=non_constant,
+         environment=environment,
+         depth=depth,
+         parents=parents,
+         preexisting_conditions=preexisting_conditions,
+     )
+
+     return base
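
Note on the node-string convention: the helpers above assume graph node labels of the form "c~<address>@<modifiers>" for concepts and "ds~<name>" for datasources. That format is implied by extract_address's parsing and the startswith checks; the canonical builder is concept_to_node in trilogy.core.graph_models. A minimal sketch of the parsing helpers and GraphAttributeCache on a toy graph; the node labels and the module path in the import are assumptions, since the diff header does not name the file:

    import networkx as nx

    # Assumed module path for the file shown in this diff.
    from trilogy.core.processing.node_generators.select_merge_node import (
        GraphAttributeCache,
        extract_address,
        get_node_type,
    )

    g = nx.DiGraph()
    g.add_node("ds~orders")  # datasource node
    g.add_node("c~local.order_id@local.order_id")  # concept node (label format assumed)
    g.add_edge("ds~orders", "c~local.order_id@local.order_id")

    print(extract_address("c~local.order_id@local.order_id"))  # local.order_id
    print(get_node_type("c~local.order_id@local.order_id"))  # c~

    cache = GraphAttributeCache(g)
    print(cache.ds_nodes)  # {'ds~orders'}
    print(cache.concept_nodes)  # {'c~local.order_id@local.order_id'}

Because both parsers are wrapped in lru_cache, repeated lookups of the same label during graph traversal avoid re-splitting the string, and the cache class fetches each node-attribute dict from NetworkX at most once per graph.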
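On the union-source injection in create_pruned_concept_graph: each candidate list from get_union_sources becomes a synthetic "ds~a-b" node wired, in both directions, to the concepts shared by every member, so a union is only considered where all members can supply the column. A standalone sketch of the intersection step, with made-up datasource names and concept addresses:

    # Output-concept sets for two hypothetical union candidates.
    ds_outputs = {
        "orders_2023": {"order_id", "order_date", "region"},
        "orders_2024": {"order_id", "order_date", "channel"},
    }
    node_address = "ds~" + "-".join(ds_outputs)  # ds~orders_2023-orders_2024
    common = set.intersection(*ds_outputs.values())
    print(node_address, common)  # only order_id and order_date get edges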
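On the pruning in resolve_subgraphs: a datasource whose non-partial concept coverage is a strict subset of another's is dropped outright, and exact ties are broken by taking the minimum score tuple (uncovered grain components, exact-condition-match penalty, concept count, name), which Python compares lexicographically. A toy illustration with plain sets and tuples rather than trilogy objects:

    coverage = {
        "ds~orders_daily": {"order_id", "order_date"},
        "ds~orders_full": {"order_id", "order_date", "customer_id"},
    }
    pruned = {
        name: concepts
        for name, concepts in coverage.items()
        if not any(
            concepts < other  # strict subset: a smaller duplicate source loses
            for other_name, other in coverage.items()
            if other_name != name
        )
    }
    print(pruned)  # only ds~orders_full survives

    # Tie-breaking between equivalent sources via tuple comparison:
    scores = [
        (0, 0.5, 3, "ds~a"),  # inexact condition match scores 0.5
        (0, 0.0, 3, "ds~b"),  # exact match wins on the second element
    ]
    print(min(scores))  # (0, 0.0, 3, 'ds~b')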
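On the force_group decision in create_datasource_node: if the datasource's grain is not a subset of the grain of the requested concepts, reading it directly could emit duplicate rows, so the select is wrapped in a GroupNode. The same check with plain sets standing in for BuildGrain (illustrative only; the real comparison is BuildGrain.issubset):

    # A source stored at (order_id, line_id) grain queried at (order_id) grain:
    datasource_grain = {"order_id", "line_id"}
    target_grain = {"order_id"}

    force_group = not datasource_grain.issubset(target_grain)
    print(force_group)  # True: each order has many lines, so rows would repeat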