pytrilogy 0.0.2.23__py3-none-any.whl → 0.0.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pytrilogy might be problematic. Click here for more details.

@@ -31,27 +31,183 @@ from trilogy.core.models import (
31
31
  TupleWrapper,
32
32
  )
33
33
 
34
- from trilogy.core.enums import Purpose, Granularity, BooleanOperator, Modifier
35
- from trilogy.core.constants import CONSTANT_DATASET
34
+ from trilogy.core.enums import Purpose, Granularity, BooleanOperator
36
35
  from enum import Enum
37
36
  from trilogy.utility import unique
38
- from collections import defaultdict
37
+
39
38
  from logging import Logger
40
- from pydantic import BaseModel
39
+
41
40
 
42
41
  from trilogy.core.enums import FunctionClass
43
42
 
44
43
 
44
+ from dataclasses import dataclass
45
+
46
+
45
47
  class NodeType(Enum):
46
48
  CONCEPT = 1
47
49
  NODE = 2
48
50
 
49
51
 
50
- class PathInfo(BaseModel):
51
- paths: Dict[str, List[str]]
52
- datasource: Datasource
53
- reduced_concepts: Set[str]
54
- concept_subgraphs: List[List[Concept]]
52
@dataclass
class JoinOrderOutput:
    """A single planned join step: which node is joined in, how, and on what."""

    # node name of the datasource being joined in
    right: str
    # join type to use for this step
    type: JoinType
    # left datasource node -> set of concept nodes used as join keys
    keys: dict[str, set[str]]
    # optional single left datasource node; defaults to None
    left: str | None = None

    @property
    def lefts(self) -> set[str]:
        """Return the left-side datasource nodes, i.e. the keys of ``keys``."""
        return set(self.keys)
62
+
63
+
64
def resolve_join_order_v2(
    g: nx.Graph, partials: dict[str, list[str]]
) -> list[JoinOrderOutput]:
    """Plan the order and type of joins between the datasource nodes of ``g``.

    Nodes whose name starts with ``ds~`` are treated as datasources and nodes
    whose name starts with ``c~`` as concepts. A concept adjacent to more than
    one datasource is a "pivot" and can act as a join key; a concept adjacent
    to exactly one datasource is "solo".

    Args:
        g: graph linking datasource nodes to the concept nodes they expose.
        partials: maps a datasource node to the concept nodes it only
            partially provides; a missing entry means no partial concepts.

    Returns:
        Join steps in the order they were planned. The first datasource picked
        as a base never appears as ``right``. Datasources reachable only
        through solo concepts are attached with a key-less FULL join. INNER
        steps whose left side was itself introduced by a LEFT_OUTER or FULL
        join are rewritten to LEFT_OUTER before returning.
    """
    datasource_nodes = {x for x in g.nodes if x.startswith("ds~")}
    concepts = [x for x in g.nodes if x.startswith("c~")]

    output: list[JoinOrderOutput] = []
    pivot_map = {
        concept: [x for x in g.neighbors(concept) if x in datasource_nodes]
        for concept in concepts
    }
    # pivots shared by the fewest datasources come first; ``pop`` below
    # therefore takes the most widely shared pivot when nothing is joined yet
    pivots = sorted(
        [x for x in pivot_map if len(pivot_map[x]) > 1],
        key=lambda x: len(pivot_map[x]),
    )
    solo = [x for x in pivot_map if len(pivot_map[x]) == 1]
    eligible_left: set[str] = set()

    while pivots:
        # prefer a pivot that touches something we have already joined
        next_pivots = [
            x for x in pivots if any(y in eligible_left for y in pivot_map[x])
        ]
        if next_pivots:
            root = next_pivots[0]
            pivots = [x for x in pivots if x != root]
        else:
            root = pivots.pop()

        # sort so that datasources with fewer partials come last and
        # already-joined (eligible left) datasources rank highest
        def score_key(x: str, pivot: str = root) -> int:
            score = 1
            # if it's already a left, higher weight
            if x in eligible_left:
                score += 3
            # if it has the pivot concept as a partial, lower weight
            if pivot in partials.get(x, []):
                score -= 1
            return score

        # get remaining un-joined datasets
        to_join = sorted(
            [x for x in pivot_map[root] if x not in eligible_left], key=score_key
        )
        while to_join:
            # need to sort this to ensure we join on the best match
            base = sorted(
                [x for x in pivot_map[root] if x in eligible_left], key=score_key
            )
            if not base:
                # nothing joined yet: the best candidate becomes the base
                new = to_join.pop()
                eligible_left.add(new)
                base = [new]
            right = to_join.pop()
            # we already joined it
            # this could happen if the same pivot is shared with multiple datasources
            if right in eligible_left:
                continue
            joinkeys: dict[str, set[str]] = {}
            join_types = set()
            # sorting puts the best candidate last,
            # so iterate over the reversed list
            for left_candidate in reversed(base):
                # materialize as a set: depending on the networkx version this
                # call yields a lazy iterator, which is always truthy and never
                # equal to a previously stored key set
                common = set(nx.common_neighbors(g, left_candidate, right))
                if not common:
                    continue
                # skip a left candidate that would only repeat existing keys
                if any(existing == common for existing in joinkeys.values()):
                    continue
                left_is_partial = any(
                    key in partials.get(left_candidate, []) for key in common
                )
                # nullability is not tracked in ``partials``, so only
                # partial-ness of the keys drives the join type here
                right_is_partial = any(
                    key in partials.get(right, []) for key in common
                )
                if left_is_partial:
                    join_type = JoinType.FULL
                elif right_is_partial:
                    join_type = JoinType.LEFT_OUTER
                else:
                    join_type = JoinType.INNER
                join_types.add(join_type)
                joinkeys[left_candidate] = common

            # LEFT_OUTER takes precedence over FULL, which beats INNER
            final_join_type = JoinType.INNER
            if JoinType.LEFT_OUTER in join_types:
                final_join_type = JoinType.LEFT_OUTER
            elif JoinType.FULL in join_types:
                final_join_type = JoinType.FULL
            output.append(
                JoinOrderOutput(
                    right=right,
                    type=final_join_type,
                    keys=joinkeys,
                )
            )
            eligible_left.add(right)

    for concept in solo:
        for ds in pivot_map[concept]:
            # if we already have it, skip it
            if ds in eligible_left:
                continue
            # if we haven't had ANY left datasources yet
            # this needs to become it
            if not eligible_left:
                eligible_left.add(ds)
                continue
            # otherwise do a full outer join with no keys
            output.append(
                JoinOrderOutput(
                    # pick an arbitrary already-joined datasource to be left
                    left=next(iter(eligible_left)),
                    right=ds,
                    type=JoinType.FULL,
                    keys={},
                )
            )
            eligible_left.add(ds)

    # only once we have all joins
    # do we know if some inners need to be left outers
    outer_types = (JoinType.LEFT_OUTER, JoinType.FULL)
    for review_join in output:
        if review_join.type in outer_types:
            continue
        lefts = review_join.lefts
        if any(
            join.right in lefts for join in output if join.type in outer_types
        ):
            review_join.type = JoinType.LEFT_OUTER
    return output
55
211
 
56
212
 
57
213
  def concept_to_relevant_joins(concepts: list[Concept]) -> List[Concept]:
@@ -112,270 +268,94 @@ def calculate_graph_relevance(
112
268
  return relevance
113
269
 
114
270
 
115
- def resolve_join_order(joins: List[BaseJoin]) -> List[BaseJoin]:
116
- available_aliases: set[str] = set()
117
- final_joins_pre = [*joins]
118
- final_joins = []
119
- partial = set()
120
- while final_joins_pre:
121
- new_final_joins_pre: List[BaseJoin] = []
122
- for join in final_joins_pre:
123
- if join.join_type != JoinType.INNER:
124
- partial.add(join.right_datasource.identifier)
125
- # an inner join after a left outer implicitly makes that outer an inner
126
- # so fix that
127
- if (
128
- join.left_datasource.identifier in partial
129
- and join.join_type == JoinType.INNER
130
- ):
131
- join.join_type = JoinType.LEFT_OUTER
132
- if not available_aliases:
133
- final_joins.append(join)
134
- available_aliases.add(join.left_datasource.identifier)
135
- available_aliases.add(join.right_datasource.identifier)
136
- elif join.left_datasource.identifier in available_aliases:
137
- # we don't need to join twice
138
- # so whatever join we found first, works
139
- if join.right_datasource.identifier in available_aliases:
140
- continue
141
- final_joins.append(join)
142
- available_aliases.add(join.left_datasource.identifier)
143
- available_aliases.add(join.right_datasource.identifier)
144
- else:
145
- new_final_joins_pre.append(join)
146
- if len(new_final_joins_pre) == len(final_joins_pre):
147
- remaining = [
148
- join.left_datasource.identifier for join in new_final_joins_pre
149
- ]
150
- remaining_right = [
151
- join.right_datasource.identifier for join in new_final_joins_pre
152
- ]
153
- raise SyntaxError(
154
- f"did not find any new joins, available {available_aliases} remaining is {remaining + remaining_right} "
155
- )
156
- final_joins_pre = new_final_joins_pre
157
- return final_joins
158
-
159
-
160
271
  def add_node_join_concept(
161
272
  graph: nx.DiGraph,
162
273
  concept: Concept,
163
- datasource: Datasource | QueryDatasource,
164
- concepts: List[Concept],
274
+ concept_map: dict[str, Concept],
275
+ ds_node: str,
276
+ environment: Environment,
165
277
  ):
278
+ name = f"c~{concept.address}"
279
+ graph.add_node(name, type=NodeType.CONCEPT)
280
+ graph.add_edge(ds_node, name)
281
+ concept_map[name] = concept
282
+ for v_address in concept.pseudonyms:
283
+ v = environment.alias_origin_lookup.get(
284
+ v_address, environment.concepts[v_address]
285
+ )
286
+ if f"c~{v.address}" in graph.nodes:
287
+ continue
288
+ if v != concept.address:
289
+ add_node_join_concept(
290
+ graph=graph,
291
+ concept=v,
292
+ concept_map=concept_map,
293
+ ds_node=ds_node,
294
+ environment=environment,
295
+ )
166
296
 
167
- concepts.append(concept)
168
297
 
169
- graph.add_node(concept.address, type=NodeType.CONCEPT)
170
- graph.add_edge(datasource.identifier, concept.address)
171
- for _, v in concept.pseudonyms.items():
172
- if v in concepts:
173
- continue
174
- if v.address != concept.address:
175
- add_node_join_concept(graph, v, datasource, concepts)
298
def resolve_instantiated_concept(
    concept: Concept, datasource: QueryDatasource
) -> Concept:
    """Return the concept under which ``concept`` appears in ``datasource``.

    If the concept's own address is found in the datasource outputs, the
    concept itself is returned. Otherwise its pseudonyms are checked in
    iteration order, and for the first one found in the outputs, the last
    output concept whose address equals that pseudonym is returned.

    NOTE(review): the membership checks compare address strings against the
    entries of ``output_concepts``; this presumably relies on the concept
    type's equality accepting addresses — confirm against the model.

    Raises:
        SyntaxError: if neither the address nor any pseudonym is found.
    """
    available = datasource.output_concepts
    if concept.address in available:
        return concept
    for alias in concept.pseudonyms:
        if alias not in available:
            continue
        matches = [candidate for candidate in available if candidate.address == alias]
        return matches.pop()
    raise SyntaxError(
        f"Could not find {concept.address} in {datasource.identifier} output {[c.address for c in datasource.output_concepts]}"
    )
176
309
 
177
310
 
178
311
  def get_node_joins(
179
312
  datasources: List[QueryDatasource],
180
- grain: List[Concept],
181
313
  environment: Environment,
182
314
  # concepts:List[Concept],
183
- ) -> List[BaseJoin]:
184
- graph = nx.Graph()
185
- concepts: List[Concept] = []
186
- for datasource in datasources:
187
- graph.add_node(datasource.identifier, type=NodeType.NODE)
188
- for concept in datasource.output_concepts:
189
- add_node_join_concept(graph, concept, datasource, concepts)
315
+ ):
190
316
 
191
- # add edges for every constant to every datasource
317
+ graph = nx.Graph()
318
+ partials: dict[str, list[str]] = {}
319
+ ds_node_map: dict[str, QueryDatasource] = {}
320
+ concept_map: dict[str, Concept] = {}
192
321
  for datasource in datasources:
322
+ ds_node = f"ds~{datasource.identifier}"
323
+ ds_node_map[ds_node] = datasource
324
+ graph.add_node(ds_node, type=NodeType.NODE)
325
+ partials[ds_node] = [f"c~{c.address}" for c in datasource.partial_concepts]
193
326
  for concept in datasource.output_concepts:
194
- if concept.granularity == Granularity.SINGLE_ROW:
195
- for node in graph.nodes:
196
- if graph.nodes[node]["type"] == NodeType.NODE:
197
- graph.add_edge(node, concept.address)
198
-
199
- joins: defaultdict[str, set] = defaultdict(set)
200
- identifier_map: dict[str, Datasource | QueryDatasource] = {
201
- x.identifier: x for x in datasources
202
- }
203
- grain_pseudonyms: set[str] = set()
204
- for g in grain:
205
- env_lookup = environment.concepts[g.address]
206
- # if we're looking up a pseudonym, we would have gotten the remapped value
207
- # so double check we got what we were looking for
208
- if env_lookup.address == g.address:
209
- grain_pseudonyms.update(env_lookup.pseudonyms.keys())
210
-
211
- node_list = sorted(
212
- [x for x in graph.nodes if graph.nodes[x]["type"] == NodeType.NODE],
213
- # sort so that anything with a partial match on the target is later
214
- key=lambda x: len(
215
- [
216
- partial
217
- for partial in identifier_map[x].partial_concepts
218
- if partial in grain
219
- ]
220
- + [
221
- output
222
- for output in identifier_map[x].output_concepts
223
- if output.address in grain_pseudonyms
224
- ]
225
- ),
226
- )
227
-
228
- for left in node_list:
229
- # the constant dataset is a special case
230
- # and can never be on the left of a join
231
- if left == CONSTANT_DATASET:
232
- continue
233
-
234
- for cnode in graph.neighbors(left):
235
- if graph.nodes[cnode]["type"] == NodeType.CONCEPT:
236
- for right in graph.neighbors(cnode):
237
- # skip concepts
238
- if graph.nodes[right]["type"] == NodeType.CONCEPT:
239
- continue
240
- if left == right:
241
- continue
242
- identifier = [left, right]
243
- joins["-".join(identifier)].add(cnode)
244
-
245
- final_joins_pre: List[BaseJoin] = []
246
-
247
- for key, join_concepts in joins.items():
248
- left, right = key.split("-")
249
- local_concepts: List[Concept] = unique(
250
- [c for c in concepts if c.address in join_concepts], "address"
251
- )
252
- if all([c.granularity == Granularity.SINGLE_ROW for c in local_concepts]):
253
- # for the constant join, make it a full outer join on 1=1
254
- join_type = JoinType.FULL
255
- local_concepts = []
256
- elif any(
257
- [
258
- c.address in [x.address for x in identifier_map[left].partial_concepts]
259
- for c in local_concepts
260
- ]
261
- ):
262
- join_type = JoinType.FULL
263
- local_concepts = [
264
- c for c in local_concepts if c.granularity != Granularity.SINGLE_ROW
265
- ]
266
- elif any(
267
- [
268
- c.address in [x.address for x in identifier_map[right].partial_concepts]
269
- for c in local_concepts
270
- ]
271
- ) or any(
272
- [
273
- c.address in [x.address for x in identifier_map[left].nullable_concepts]
274
- for c in local_concepts
275
- ]
276
- ):
277
- join_type = JoinType.LEFT_OUTER
278
- local_concepts = [
279
- c for c in local_concepts if c.granularity != Granularity.SINGLE_ROW
280
- ]
281
- else:
282
- join_type = JoinType.INNER
283
- # remove any constants if other join keys exist
284
- local_concepts = [
285
- c for c in local_concepts if c.granularity != Granularity.SINGLE_ROW
286
- ]
287
-
288
- relevant = concept_to_relevant_joins(local_concepts)
289
- left_datasource = identifier_map[left]
290
- right_datasource = identifier_map[right]
291
- join_tuples: list[ConceptPair] = []
292
- for joinc in relevant:
293
- left_arg = joinc
294
- right_arg = joinc
295
- if joinc.address not in [
296
- c.address for c in left_datasource.output_concepts
297
- ]:
298
- try:
299
- left_arg = [
300
- x
301
- for x in left_datasource.output_concepts
302
- if x.address in joinc.pseudonyms
303
- or joinc.address in x.pseudonyms
304
- ].pop()
305
- except IndexError:
306
- raise SyntaxError(
307
- f"Could not find {joinc.address} in {left_datasource.identifier} output {[c.address for c in left_datasource.output_concepts]}"
308
- )
309
- if joinc.address not in [
310
- c.address for c in right_datasource.output_concepts
311
- ]:
312
- try:
313
- right_arg = [
314
- x
315
- for x in right_datasource.output_concepts
316
- if x.address in joinc.pseudonyms
317
- or joinc.address in x.pseudonyms
318
- ].pop()
319
- except IndexError:
320
- raise SyntaxError(
321
- f"Could not find {joinc.address} in {right_datasource.identifier} output {[c.address for c in right_datasource.output_concepts]}"
322
- )
323
- narg = (left_arg, right_arg)
324
- if narg not in join_tuples:
325
- modifiers = set()
326
- if left_arg.address in [
327
- x.address for x in left_datasource.nullable_concepts
328
- ] and right_arg.address in [
329
- x.address for x in right_datasource.nullable_concepts
330
- ]:
331
- modifiers.add(Modifier.NULLABLE)
332
- join_tuples.append(
333
- ConceptPair(
334
- left=left_arg, right=right_arg, modifiers=list(modifiers)
335
- )
336
- )
337
-
338
- # deduplication
339
- all_right = []
340
- for tuple in join_tuples:
341
- all_right.append(tuple.right.address)
342
- right_grain = identifier_map[right].grain
343
- # if the join includes all the right grain components
344
- # we only need to join on those, not everything
345
- if all([x.address in all_right for x in right_grain.components]):
346
- join_tuples = [
347
- x for x in join_tuples if x.right.address in right_grain.components
348
- ]
349
-
350
- final_joins_pre.append(
351
- BaseJoin(
352
- left_datasource=identifier_map[left],
353
- right_datasource=identifier_map[right],
354
- join_type=join_type,
355
- concepts=[],
356
- concept_pairs=join_tuples,
327
+ add_node_join_concept(
328
+ graph=graph,
329
+ concept=concept,
330
+ concept_map=concept_map,
331
+ ds_node=ds_node,
332
+ environment=environment,
357
333
  )
358
- )
359
- final_joins = resolve_join_order(final_joins_pre)
360
334
 
361
- # this is extra validation
362
- non_single_row_ds = [x for x in datasources if not x.grain.abstract]
363
- if len(non_single_row_ds) > 1:
364
- for x in datasources:
365
- if x.grain.abstract:
366
- continue
367
- found = False
368
- for join in final_joins:
369
- if (
370
- join.left_datasource.identifier == x.identifier
371
- or join.right_datasource.identifier == x.identifier
372
- ):
373
- found = True
374
- if not found:
375
- raise SyntaxError(
376
- f"Could not find join for {x.identifier} with output {[c.address for c in x.output_concepts]}, all {[z.identifier for z in datasources]}"
335
+ joins = resolve_join_order_v2(graph, partials=partials)
336
+ return [
337
+ BaseJoin(
338
+ left_datasource=ds_node_map[j.left] if j.left else None,
339
+ right_datasource=ds_node_map[j.right],
340
+ join_type=j.type,
341
+ # preserve empty field for maps
342
+ concepts=[] if not j.keys else None,
343
+ concept_pairs=[
344
+ ConceptPair(
345
+ left=resolve_instantiated_concept(
346
+ concept_map[concept], ds_node_map[k]
347
+ ),
348
+ right=resolve_instantiated_concept(
349
+ concept_map[concept], ds_node_map[j.right]
350
+ ),
351
+ existing_datasource=ds_node_map[k],
377
352
  )
378
- return final_joins
353
+ for k, v in j.keys.items()
354
+ for concept in v
355
+ ],
356
+ )
357
+ for j in joins
358
+ ]
379
359
 
380
360
 
381
361
  def get_disconnected_components(
@@ -520,11 +500,13 @@ def find_nullable_concepts(
520
500
  ]:
521
501
  is_on_nullable_condition = True
522
502
  break
503
+ left_check = (
504
+ join.left_datasource.identifier
505
+ if join.left_datasource is not None
506
+ else pair.existing_datasource.identifier
507
+ )
523
508
  if pair.left.address in [
524
- y.address
525
- for y in datasource_map[
526
- join.left_datasource.identifier
527
- ].nullable_concepts
509
+ y.address for y in datasource_map[left_check].nullable_concepts
528
510
  ]:
529
511
  is_on_nullable_condition = True
530
512
  break
@@ -17,7 +17,6 @@ from trilogy.core.models import (
17
17
  CTE,
18
18
  Join,
19
19
  UnnestJoin,
20
- JoinKey,
21
20
  MaterializedDataset,
22
21
  ProcessedQuery,
23
22
  ProcessedQueryPersist,
@@ -28,6 +27,7 @@ from trilogy.core.models import (
28
27
  Conditional,
29
28
  ProcessedCopyStatement,
30
29
  CopyStatement,
30
+ CTEConceptPair,
31
31
  )
32
32
 
33
33
  from trilogy.utility import unique
@@ -52,44 +52,53 @@ def base_join_to_join(
52
52
  concept_to_unnest=base_join.parent.concept_arguments[0],
53
53
  alias=base_join.alias,
54
54
  )
55
- if base_join.left_datasource.identifier == base_join.right_datasource.identifier:
56
- raise ValueError(f"Joining on same datasource {base_join}")
57
- left_ctes = [
58
- cte
59
- for cte in ctes
60
- if (cte.source.full_name == base_join.left_datasource.full_name)
61
- ]
62
- if not left_ctes:
63
- left_ctes = [
64
- cte
65
- for cte in ctes
66
- if (
67
- cte.source.datasources[0].full_name
68
- == base_join.left_datasource.full_name
55
+
56
+ def get_datasource_cte(datasource: Datasource | QueryDatasource) -> CTE:
57
+ for cte in ctes:
58
+ if cte.source.full_name == datasource.full_name:
59
+ return cte
60
+ for cte in ctes:
61
+ if cte.source.datasources[0].full_name == datasource.full_name:
62
+ return cte
63
+ raise ValueError(f"Could not find CTE for datasource {datasource.full_name}")
64
+
65
+ if base_join.left_datasource is not None:
66
+ left_cte = get_datasource_cte(base_join.left_datasource)
67
+ else:
68
+ # multiple left ctes
69
+ left_cte = None
70
+ right_cte = get_datasource_cte(base_join.right_datasource)
71
+ if base_join.concept_pairs:
72
+ final_pairs = [
73
+ CTEConceptPair(
74
+ left=pair.left,
75
+ right=pair.right,
76
+ existing_datasource=pair.existing_datasource,
77
+ modifiers=pair.modifiers,
78
+ cte=get_datasource_cte(pair.existing_datasource),
69
79
  )
80
+ for pair in base_join.concept_pairs
70
81
  ]
71
- left_cte = left_ctes[0]
72
- right_ctes = [
73
- cte
74
- for cte in ctes
75
- if (cte.source.full_name == base_join.right_datasource.full_name)
76
- ]
77
- if not right_ctes:
78
- right_ctes = [
79
- cte
80
- for cte in ctes
81
- if (
82
- cte.source.datasources[0].full_name
83
- == base_join.right_datasource.full_name
82
+ elif base_join.concepts and base_join.left_datasource:
83
+ final_pairs = [
84
+ CTEConceptPair(
85
+ left=concept,
86
+ right=concept,
87
+ existing_datasource=base_join.left_datasource,
88
+ modifiers=[],
89
+ cte=get_datasource_cte(
90
+ base_join.left_datasource,
91
+ ),
84
92
  )
93
+ for concept in base_join.concepts
85
94
  ]
86
- right_cte = right_ctes[0]
95
+ else:
96
+ final_pairs = []
87
97
  return Join(
88
98
  left_cte=left_cte,
89
99
  right_cte=right_cte,
90
- joinkeys=[JoinKey(concept=concept) for concept in base_join.concepts],
91
100
  jointype=base_join.join_type,
92
- joinkey_pairs=base_join.concept_pairs if base_join.concept_pairs else None,
101
+ joinkey_pairs=final_pairs,
93
102
  )
94
103
 
95
104
 
@@ -195,7 +204,6 @@ def resolve_cte_base_name_and_alias_v2(
195
204
  source_map: Dict[str, list[str]],
196
205
  raw_joins: List[Join | InstantiatedUnnestJoin],
197
206
  ) -> Tuple[str | None, str | None]:
198
- joins: List[Join] = [join for join in raw_joins if isinstance(join, Join)]
199
207
  if (
200
208
  isinstance(source.datasources[0], Datasource)
201
209
  and not source.datasources[0].name == CONSTANT_DATASET
@@ -203,8 +211,12 @@ def resolve_cte_base_name_and_alias_v2(
203
211
  ds = source.datasources[0]
204
212
  return ds.safe_location, ds.identifier
205
213
 
214
+ joins: List[Join] = [join for join in raw_joins if isinstance(join, Join)]
206
215
  if joins and len(joins) > 0:
207
- candidates = [x.left_cte.name for x in joins]
216
+ candidates = [x.left_cte.name for x in joins if x.left_cte]
217
+ for join in joins:
218
+ if join.joinkey_pairs:
219
+ candidates += [x.cte.name for x in join.joinkey_pairs if x.cte]
208
220
  disallowed = [x.right_cte.name for x in joins]
209
221
  try:
210
222
  cte = [y for y in candidates if y not in disallowed][0]
@@ -213,7 +225,6 @@ def resolve_cte_base_name_and_alias_v2(
213
225
  raise SyntaxError(
214
226
  f"Invalid join configuration {candidates} {disallowed} for {name}",
215
227
  )
216
-
217
228
  counts: dict[str, int] = defaultdict(lambda: 0)
218
229
  output_addresses = [x.address for x in source.output_concepts]
219
230
  input_address = [x.address for x in source.input_concepts]
@@ -269,11 +280,8 @@ def datasource_to_ctes(
269
280
 
270
281
  human_id = generate_cte_name(query_datasource.full_name, name_map)
271
282
 
272
- final_joins = [
273
- x
274
- for x in [base_join_to_join(join, parents) for join in query_datasource.joins]
275
- if x
276
- ]
283
+ final_joins = [base_join_to_join(join, parents) for join in query_datasource.joins]
284
+
277
285
  base_name, base_alias = resolve_cte_base_name_and_alias_v2(
278
286
  human_id, query_datasource, source_map, final_joins
279
287
  )