pytrilogy 0.0.2.25__py3-none-any.whl → 0.0.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pytrilogy might be problematic; see the registry's advisory page for this release for more details.

@@ -31,27 +31,183 @@ from trilogy.core.models import (
31
31
  TupleWrapper,
32
32
  )
33
33
 
34
- from trilogy.core.enums import Purpose, Granularity, BooleanOperator, Modifier
35
- from trilogy.core.constants import CONSTANT_DATASET
34
+ from trilogy.core.enums import Purpose, Granularity, BooleanOperator
36
35
  from enum import Enum
37
36
  from trilogy.utility import unique
38
- from collections import defaultdict
37
+
39
38
  from logging import Logger
40
- from pydantic import BaseModel
39
+
41
40
 
42
41
  from trilogy.core.enums import FunctionClass
43
42
 
44
43
 
44
+ from dataclasses import dataclass
45
+
46
+
45
47
  class NodeType(Enum):
46
48
  CONCEPT = 1
47
49
  NODE = 2
48
50
 
49
51
 
50
- class PathInfo(BaseModel):
51
- paths: Dict[str, List[str]]
52
- datasource: Datasource
53
- reduced_concepts: Set[str]
54
- concept_subgraphs: List[List[Concept]]
52
+ @dataclass
53
+ class JoinOrderOutput:
54
+ right: str
55
+ type: JoinType
56
+ keys: dict[str, set[str]]
57
+ left: str | None = None
58
+
59
+ @property
60
+ def lefts(self):
61
+ return set(self.keys.keys())
62
+
63
+
64
+ def resolve_join_order_v2(
65
+ g: nx.Graph, partials: dict[str, list[str]]
66
+ ) -> list[JoinOrderOutput]:
67
+ datasources = [x for x in g.nodes if x.startswith("ds~")]
68
+ concepts = [x for x in g.nodes if x.startswith("c~")]
69
+ # from trilogy.hooks.graph_hook import GraphHook
70
+
71
+ # GraphHook().query_graph_built(g)
72
+
73
+ output: list[JoinOrderOutput] = []
74
+ pivot_map = {
75
+ concept: [x for x in g.neighbors(concept) if x in datasources]
76
+ for concept in concepts
77
+ }
78
+ pivots = list(
79
+ sorted(
80
+ [x for x in pivot_map if len(pivot_map[x]) > 1],
81
+ key=lambda x: len(pivot_map[x]),
82
+ )
83
+ )
84
+ solo = [x for x in pivot_map if len(pivot_map[x]) == 1]
85
+ eligible_left = set()
86
+
87
+ while pivots:
88
+ next_pivots = [
89
+ x for x in pivots if any(y in eligible_left for y in pivot_map[x])
90
+ ]
91
+ if next_pivots:
92
+ root = next_pivots[0]
93
+ pivots = [x for x in pivots if x != root]
94
+ else:
95
+ root = pivots.pop()
96
+
97
+ # sort so less partials is last and eligible lefts are
98
+ def score_key(x: str) -> int:
99
+ base = 1
100
+ # if it's left, higher weight
101
+ if x in eligible_left:
102
+ base += 3
103
+ # if it has the concept as a partial, lower weight
104
+ if root in partials.get(x, []):
105
+ base -= 1
106
+ return base
107
+
108
+ # get remainig un-joined datasets
109
+ to_join = sorted(
110
+ [x for x in pivot_map[root] if x not in eligible_left], key=score_key
111
+ )
112
+ while to_join:
113
+ # need to sort this to ensure we join on the best match
114
+ base = sorted(
115
+ [x for x in pivot_map[root] if x in eligible_left], key=score_key
116
+ )
117
+ if not base:
118
+ new = to_join.pop()
119
+ eligible_left.add(new)
120
+ base = [new]
121
+ right = to_join.pop()
122
+ # we already joined it
123
+ # this could happen if the same pivot is shared with multiple Dses
124
+ if right in eligible_left:
125
+ continue
126
+ joinkeys: dict[str, set[str]] = {}
127
+ # sorting puts the best candidate last for pop
128
+ # so iterate over the reversed list
129
+ join_types = set()
130
+ for left_candidate in reversed(base):
131
+ common = nx.common_neighbors(g, left_candidate, right)
132
+
133
+ if not common:
134
+ continue
135
+ exists = False
136
+ for _, v in joinkeys.items():
137
+ if v == common:
138
+ exists = True
139
+ if exists:
140
+ continue
141
+ left_is_partial = any(
142
+ key in partials.get(left_candidate, []) for key in common
143
+ )
144
+ right_is_partial = any(key in partials.get(right, []) for key in common)
145
+ # we don't care if left is nullable for join type (just keys), but if we did
146
+ # ex: left_is_nullable = any(key in partials.get(left_candidate, [])
147
+ right_is_nullable = any(
148
+ key in partials.get(right, []) for key in common
149
+ )
150
+ if left_is_partial:
151
+ join_type = JoinType.FULL
152
+ elif right_is_partial or right_is_nullable:
153
+ join_type = JoinType.LEFT_OUTER
154
+ # we can't inner join if the left was an outer join
155
+ else:
156
+ join_type = JoinType.INNER
157
+ join_types.add(join_type)
158
+ joinkeys[left_candidate] = common
159
+
160
+ final_join_type = JoinType.INNER
161
+ if any([x == JoinType.LEFT_OUTER for x in join_types]):
162
+ final_join_type = JoinType.LEFT_OUTER
163
+ elif any([x == JoinType.FULL for x in join_types]):
164
+ final_join_type = JoinType.FULL
165
+ output.append(
166
+ JoinOrderOutput(
167
+ # left=left_candidate,
168
+ right=right,
169
+ type=final_join_type,
170
+ keys=joinkeys,
171
+ )
172
+ )
173
+ eligible_left.add(right)
174
+
175
+ for concept in solo:
176
+ for ds in pivot_map[concept]:
177
+ # if we already have it, skip it
178
+
179
+ if ds in eligible_left:
180
+ continue
181
+ # if we haven't had ANY left datasources yet
182
+ # this needs to become it
183
+ if not eligible_left:
184
+ eligible_left.add(ds)
185
+ continue
186
+ # otherwise do a full out join
187
+ output.append(
188
+ JoinOrderOutput(
189
+ # pick random one to be left
190
+ left=list(eligible_left)[0],
191
+ right=ds,
192
+ type=JoinType.FULL,
193
+ keys={},
194
+ )
195
+ )
196
+ eligible_left.add(ds)
197
+ # only once we have all joins
198
+ # do we know if some inners need to be left outers
199
+ for review_join in output:
200
+ if review_join.type in (JoinType.LEFT_OUTER, JoinType.FULL):
201
+ continue
202
+ if any(
203
+ [
204
+ join.right in review_join.lefts
205
+ for join in output
206
+ if join.type in (JoinType.LEFT_OUTER, JoinType.FULL)
207
+ ]
208
+ ):
209
+ review_join.type = JoinType.LEFT_OUTER
210
+ return output
55
211
 
56
212
 
57
213
  def concept_to_relevant_joins(concepts: list[Concept]) -> List[Concept]:
@@ -112,273 +268,94 @@ def calculate_graph_relevance(
112
268
  return relevance
113
269
 
114
270
 
115
- def resolve_join_order(joins: List[BaseJoin]) -> List[BaseJoin]:
116
- available_aliases: set[str] = set()
117
- final_joins_pre = [*joins]
118
- final_joins = []
119
- partial = set()
120
- while final_joins_pre:
121
- new_final_joins_pre: List[BaseJoin] = []
122
- for join in final_joins_pre:
123
- if join.join_type != JoinType.INNER:
124
- partial.add(join.right_datasource.identifier)
125
- # an inner join after a left outer implicitly makes that outer an inner
126
- # so fix that
127
- if (
128
- join.left_datasource.identifier in partial
129
- and join.join_type == JoinType.INNER
130
- ):
131
- join.join_type = JoinType.LEFT_OUTER
132
- if not available_aliases:
133
- final_joins.append(join)
134
- available_aliases.add(join.left_datasource.identifier)
135
- available_aliases.add(join.right_datasource.identifier)
136
- elif join.left_datasource.identifier in available_aliases:
137
- # we don't need to join twice
138
- # so whatever join we found first, works
139
- if join.right_datasource.identifier in available_aliases:
140
- continue
141
- final_joins.append(join)
142
- available_aliases.add(join.left_datasource.identifier)
143
- available_aliases.add(join.right_datasource.identifier)
144
- else:
145
- new_final_joins_pre.append(join)
146
- if len(new_final_joins_pre) == len(final_joins_pre):
147
- remaining = [
148
- join.left_datasource.identifier for join in new_final_joins_pre
149
- ]
150
- remaining_right = [
151
- join.right_datasource.identifier for join in new_final_joins_pre
152
- ]
153
- raise SyntaxError(
154
- f"did not find any new joins, available {available_aliases} remaining is {remaining + remaining_right} "
155
- )
156
- final_joins_pre = new_final_joins_pre
157
- return final_joins
158
-
159
-
160
271
  def add_node_join_concept(
161
272
  graph: nx.DiGraph,
162
273
  concept: Concept,
163
- datasource: Datasource | QueryDatasource,
164
- concepts: List[Concept],
274
+ concept_map: dict[str, Concept],
275
+ ds_node: str,
165
276
  environment: Environment,
166
277
  ):
167
-
168
- concepts.append(concept)
169
-
170
- graph.add_node(concept.address, type=NodeType.CONCEPT)
171
- graph.add_edge(datasource.identifier, concept.address)
278
+ name = f"c~{concept.address}"
279
+ graph.add_node(name, type=NodeType.CONCEPT)
280
+ graph.add_edge(ds_node, name)
281
+ concept_map[name] = concept
172
282
  for v_address in concept.pseudonyms:
173
283
  v = environment.alias_origin_lookup.get(
174
284
  v_address, environment.concepts[v_address]
175
285
  )
176
- if v in concepts:
286
+ if f"c~{v.address}" in graph.nodes:
177
287
  continue
178
288
  if v != concept.address:
179
- add_node_join_concept(graph, v, datasource, concepts, environment)
289
+ add_node_join_concept(
290
+ graph=graph,
291
+ concept=v,
292
+ concept_map=concept_map,
293
+ ds_node=ds_node,
294
+ environment=environment,
295
+ )
296
+
297
+
298
+ def resolve_instantiated_concept(
299
+ concept: Concept, datasource: QueryDatasource
300
+ ) -> Concept:
301
+ if concept.address in datasource.output_concepts:
302
+ return concept
303
+ for k in concept.pseudonyms:
304
+ if k in datasource.output_concepts:
305
+ return [x for x in datasource.output_concepts if x.address == k].pop()
306
+ raise SyntaxError(
307
+ f"Could not find {concept.address} in {datasource.identifier} output {[c.address for c in datasource.output_concepts]}"
308
+ )
180
309
 
181
310
 
182
311
  def get_node_joins(
183
312
  datasources: List[QueryDatasource],
184
- grain: List[Concept],
185
313
  environment: Environment,
186
314
  # concepts:List[Concept],
187
- ) -> List[BaseJoin]:
188
- graph = nx.Graph()
189
- concepts: List[Concept] = []
190
- for datasource in datasources:
191
- graph.add_node(datasource.identifier, type=NodeType.NODE)
192
- for concept in datasource.output_concepts:
193
- add_node_join_concept(graph, concept, datasource, concepts, environment)
315
+ ):
194
316
 
195
- # add edges for every constant to every datasource
317
+ graph = nx.Graph()
318
+ partials: dict[str, list[str]] = {}
319
+ ds_node_map: dict[str, QueryDatasource] = {}
320
+ concept_map: dict[str, Concept] = {}
196
321
  for datasource in datasources:
322
+ ds_node = f"ds~{datasource.identifier}"
323
+ ds_node_map[ds_node] = datasource
324
+ graph.add_node(ds_node, type=NodeType.NODE)
325
+ partials[ds_node] = [f"c~{c.address}" for c in datasource.partial_concepts]
197
326
  for concept in datasource.output_concepts:
198
- if concept.granularity == Granularity.SINGLE_ROW:
199
- for node in graph.nodes:
200
- if graph.nodes[node]["type"] == NodeType.NODE:
201
- graph.add_edge(node, concept.address)
202
- joins: defaultdict[str, set] = defaultdict(set)
203
- identifier_map: dict[str, Datasource | QueryDatasource] = {
204
- x.identifier: x for x in datasources
205
- }
206
- grain_pseudonyms: set[str] = set()
207
- for g in grain:
208
- env_lookup = environment.concepts[g.address]
209
- # if we're looking up a pseudonym, we would have gotten the remapped value
210
- # so double check we got what we were looking for
211
- if env_lookup.address == g.address:
212
- grain_pseudonyms.update(env_lookup.pseudonyms)
213
-
214
- node_list = sorted(
215
- [x for x in graph.nodes if graph.nodes[x]["type"] == NodeType.NODE],
216
- # sort so that anything with a partial match on the target is later
217
- key=lambda x: len(
218
- [
219
- partial
220
- for partial in identifier_map[x].partial_concepts
221
- if partial in grain
222
- ]
223
- + [
224
- output
225
- for output in identifier_map[x].output_concepts
226
- if output.address in grain_pseudonyms
227
- ]
228
- ),
229
- )
230
-
231
- for left in node_list:
232
- # the constant dataset is a special case
233
- # and can never be on the left of a join
234
- if left == CONSTANT_DATASET:
235
- continue
236
-
237
- for cnode in graph.neighbors(left):
238
- if graph.nodes[cnode]["type"] == NodeType.CONCEPT:
239
- for right in graph.neighbors(cnode):
240
- # skip concepts
241
- if graph.nodes[right]["type"] == NodeType.CONCEPT:
242
- continue
243
- if left == right:
244
- continue
245
- identifier = [left, right]
246
- joins["-".join(identifier)].add(cnode)
247
-
248
- final_joins_pre: List[BaseJoin] = []
249
-
250
- for key, join_concepts in joins.items():
251
- left, right = key.split("-")
252
- local_concepts: List[Concept] = unique(
253
- [c for c in concepts if c.address in join_concepts], "address"
254
- )
255
- if all([c.granularity == Granularity.SINGLE_ROW for c in local_concepts]):
256
- # for the constant join, make it a full outer join on 1=1
257
- join_type = JoinType.FULL
258
- local_concepts = []
259
- elif any(
260
- [
261
- c.address in [x.address for x in identifier_map[left].partial_concepts]
262
- for c in local_concepts
263
- ]
264
- ):
265
- join_type = JoinType.FULL
266
- local_concepts = [
267
- c for c in local_concepts if c.granularity != Granularity.SINGLE_ROW
268
- ]
269
- elif any(
270
- [
271
- c.address in [x.address for x in identifier_map[right].partial_concepts]
272
- for c in local_concepts
273
- ]
274
- ) or any(
275
- [
276
- c.address in [x.address for x in identifier_map[left].nullable_concepts]
277
- for c in local_concepts
278
- ]
279
- ):
280
- join_type = JoinType.LEFT_OUTER
281
- local_concepts = [
282
- c for c in local_concepts if c.granularity != Granularity.SINGLE_ROW
283
- ]
284
- else:
285
- join_type = JoinType.INNER
286
- # remove any constants if other join keys exist
287
- local_concepts = [
288
- c for c in local_concepts if c.granularity != Granularity.SINGLE_ROW
289
- ]
290
-
291
- relevant = concept_to_relevant_joins(local_concepts)
292
- left_datasource = identifier_map[left]
293
- right_datasource = identifier_map[right]
294
- join_tuples: list[ConceptPair] = []
295
- for joinc in relevant:
296
- left_arg = joinc
297
- right_arg = joinc
298
- if joinc.address not in [
299
- c.address for c in left_datasource.output_concepts
300
- ]:
301
- try:
302
- left_arg = [
303
- x
304
- for x in left_datasource.output_concepts
305
- if x.address in joinc.pseudonyms
306
- or joinc.address in x.pseudonyms
307
- ].pop()
308
- except IndexError:
309
- raise SyntaxError(
310
- f"Could not find {joinc.address} in {left_datasource.identifier} output {[c.address for c in left_datasource.output_concepts]}"
311
- )
312
- if joinc.address not in [
313
- c.address for c in right_datasource.output_concepts
314
- ]:
315
- try:
316
- right_arg = [
317
- x
318
- for x in right_datasource.output_concepts
319
- if x.address in joinc.pseudonyms
320
- or joinc.address in x.pseudonyms
321
- ].pop()
322
- except IndexError:
323
- raise SyntaxError(
324
- f"Could not find {joinc.address} in {right_datasource.identifier} output {[c.address for c in right_datasource.output_concepts]}"
325
- )
326
- narg = (left_arg, right_arg)
327
- if narg not in join_tuples:
328
- modifiers = set()
329
- if left_arg.address in [
330
- x.address for x in left_datasource.nullable_concepts
331
- ] and right_arg.address in [
332
- x.address for x in right_datasource.nullable_concepts
333
- ]:
334
- modifiers.add(Modifier.NULLABLE)
335
- join_tuples.append(
336
- ConceptPair(
337
- left=left_arg, right=right_arg, modifiers=list(modifiers)
338
- )
339
- )
340
-
341
- # deduplication
342
- all_right = []
343
- for tuple in join_tuples:
344
- all_right.append(tuple.right.address)
345
- right_grain = identifier_map[right].grain
346
- # if the join includes all the right grain components
347
- # we only need to join on those, not everything
348
- if all([x.address in all_right for x in right_grain.components]):
349
- join_tuples = [
350
- x for x in join_tuples if x.right.address in right_grain.components
351
- ]
352
-
353
- final_joins_pre.append(
354
- BaseJoin(
355
- left_datasource=identifier_map[left],
356
- right_datasource=identifier_map[right],
357
- join_type=join_type,
358
- concepts=[],
359
- concept_pairs=join_tuples,
327
+ add_node_join_concept(
328
+ graph=graph,
329
+ concept=concept,
330
+ concept_map=concept_map,
331
+ ds_node=ds_node,
332
+ environment=environment,
360
333
  )
361
- )
362
- final_joins = resolve_join_order(final_joins_pre)
363
334
 
364
- # this is extra validation
365
- non_single_row_ds = [x for x in datasources if not x.grain.abstract]
366
- if len(non_single_row_ds) > 1:
367
- for x in datasources:
368
- if x.grain.abstract:
369
- continue
370
- found = False
371
- for join in final_joins:
372
- if (
373
- join.left_datasource.identifier == x.identifier
374
- or join.right_datasource.identifier == x.identifier
375
- ):
376
- found = True
377
- if not found:
378
- raise SyntaxError(
379
- f"Could not find join for {x.identifier} with output {[c.address for c in x.output_concepts]}, all {[z.identifier for z in datasources]}"
335
+ joins = resolve_join_order_v2(graph, partials=partials)
336
+ return [
337
+ BaseJoin(
338
+ left_datasource=ds_node_map[j.left] if j.left else None,
339
+ right_datasource=ds_node_map[j.right],
340
+ join_type=j.type,
341
+ # preserve empty field for maps
342
+ concepts=[] if not j.keys else None,
343
+ concept_pairs=[
344
+ ConceptPair(
345
+ left=resolve_instantiated_concept(
346
+ concept_map[concept], ds_node_map[k]
347
+ ),
348
+ right=resolve_instantiated_concept(
349
+ concept_map[concept], ds_node_map[j.right]
350
+ ),
351
+ existing_datasource=ds_node_map[k],
380
352
  )
381
- return final_joins
353
+ for k, v in j.keys.items()
354
+ for concept in v
355
+ ],
356
+ )
357
+ for j in joins
358
+ ]
382
359
 
383
360
 
384
361
  def get_disconnected_components(
@@ -523,11 +500,13 @@ def find_nullable_concepts(
523
500
  ]:
524
501
  is_on_nullable_condition = True
525
502
  break
503
+ left_check = (
504
+ join.left_datasource.identifier
505
+ if join.left_datasource is not None
506
+ else pair.existing_datasource.identifier
507
+ )
526
508
  if pair.left.address in [
527
- y.address
528
- for y in datasource_map[
529
- join.left_datasource.identifier
530
- ].nullable_concepts
509
+ y.address for y in datasource_map[left_check].nullable_concepts
531
510
  ]:
532
511
  is_on_nullable_condition = True
533
512
  break
@@ -17,7 +17,6 @@ from trilogy.core.models import (
17
17
  CTE,
18
18
  Join,
19
19
  UnnestJoin,
20
- JoinKey,
21
20
  MaterializedDataset,
22
21
  ProcessedQuery,
23
22
  ProcessedQueryPersist,
@@ -28,6 +27,7 @@ from trilogy.core.models import (
28
27
  Conditional,
29
28
  ProcessedCopyStatement,
30
29
  CopyStatement,
30
+ CTEConceptPair,
31
31
  )
32
32
 
33
33
  from trilogy.utility import unique
@@ -52,44 +52,53 @@ def base_join_to_join(
52
52
  concept_to_unnest=base_join.parent.concept_arguments[0],
53
53
  alias=base_join.alias,
54
54
  )
55
- if base_join.left_datasource.identifier == base_join.right_datasource.identifier:
56
- raise ValueError(f"Joining on same datasource {base_join}")
57
- left_ctes = [
58
- cte
59
- for cte in ctes
60
- if (cte.source.full_name == base_join.left_datasource.full_name)
61
- ]
62
- if not left_ctes:
63
- left_ctes = [
64
- cte
65
- for cte in ctes
66
- if (
67
- cte.source.datasources[0].full_name
68
- == base_join.left_datasource.full_name
55
+
56
+ def get_datasource_cte(datasource: Datasource | QueryDatasource) -> CTE:
57
+ for cte in ctes:
58
+ if cte.source.full_name == datasource.full_name:
59
+ return cte
60
+ for cte in ctes:
61
+ if cte.source.datasources[0].full_name == datasource.full_name:
62
+ return cte
63
+ raise ValueError(f"Could not find CTE for datasource {datasource.full_name}")
64
+
65
+ if base_join.left_datasource is not None:
66
+ left_cte = get_datasource_cte(base_join.left_datasource)
67
+ else:
68
+ # multiple left ctes
69
+ left_cte = None
70
+ right_cte = get_datasource_cte(base_join.right_datasource)
71
+ if base_join.concept_pairs:
72
+ final_pairs = [
73
+ CTEConceptPair(
74
+ left=pair.left,
75
+ right=pair.right,
76
+ existing_datasource=pair.existing_datasource,
77
+ modifiers=pair.modifiers,
78
+ cte=get_datasource_cte(pair.existing_datasource),
69
79
  )
80
+ for pair in base_join.concept_pairs
70
81
  ]
71
- left_cte = left_ctes[0]
72
- right_ctes = [
73
- cte
74
- for cte in ctes
75
- if (cte.source.full_name == base_join.right_datasource.full_name)
76
- ]
77
- if not right_ctes:
78
- right_ctes = [
79
- cte
80
- for cte in ctes
81
- if (
82
- cte.source.datasources[0].full_name
83
- == base_join.right_datasource.full_name
82
+ elif base_join.concepts and base_join.left_datasource:
83
+ final_pairs = [
84
+ CTEConceptPair(
85
+ left=concept,
86
+ right=concept,
87
+ existing_datasource=base_join.left_datasource,
88
+ modifiers=[],
89
+ cte=get_datasource_cte(
90
+ base_join.left_datasource,
91
+ ),
84
92
  )
93
+ for concept in base_join.concepts
85
94
  ]
86
- right_cte = right_ctes[0]
95
+ else:
96
+ final_pairs = []
87
97
  return Join(
88
98
  left_cte=left_cte,
89
99
  right_cte=right_cte,
90
- joinkeys=[JoinKey(concept=concept) for concept in base_join.concepts],
91
100
  jointype=base_join.join_type,
92
- joinkey_pairs=base_join.concept_pairs if base_join.concept_pairs else None,
101
+ joinkey_pairs=final_pairs,
93
102
  )
94
103
 
95
104
 
@@ -195,7 +204,6 @@ def resolve_cte_base_name_and_alias_v2(
195
204
  source_map: Dict[str, list[str]],
196
205
  raw_joins: List[Join | InstantiatedUnnestJoin],
197
206
  ) -> Tuple[str | None, str | None]:
198
- joins: List[Join] = [join for join in raw_joins if isinstance(join, Join)]
199
207
  if (
200
208
  isinstance(source.datasources[0], Datasource)
201
209
  and not source.datasources[0].name == CONSTANT_DATASET
@@ -203,8 +211,12 @@ def resolve_cte_base_name_and_alias_v2(
203
211
  ds = source.datasources[0]
204
212
  return ds.safe_location, ds.identifier
205
213
 
214
+ joins: List[Join] = [join for join in raw_joins if isinstance(join, Join)]
206
215
  if joins and len(joins) > 0:
207
- candidates = [x.left_cte.name for x in joins]
216
+ candidates = [x.left_cte.name for x in joins if x.left_cte]
217
+ for join in joins:
218
+ if join.joinkey_pairs:
219
+ candidates += [x.cte.name for x in join.joinkey_pairs if x.cte]
208
220
  disallowed = [x.right_cte.name for x in joins]
209
221
  try:
210
222
  cte = [y for y in candidates if y not in disallowed][0]
@@ -213,7 +225,6 @@ def resolve_cte_base_name_and_alias_v2(
213
225
  raise SyntaxError(
214
226
  f"Invalid join configuration {candidates} {disallowed} for {name}",
215
227
  )
216
-
217
228
  counts: dict[str, int] = defaultdict(lambda: 0)
218
229
  output_addresses = [x.address for x in source.output_concepts]
219
230
  input_address = [x.address for x in source.input_concepts]
@@ -269,11 +280,8 @@ def datasource_to_ctes(
269
280
 
270
281
  human_id = generate_cte_name(query_datasource.full_name, name_map)
271
282
 
272
- final_joins = [
273
- x
274
- for x in [base_join_to_join(join, parents) for join in query_datasource.joins]
275
- if x
276
- ]
283
+ final_joins = [base_join_to_join(join, parents) for join in query_datasource.joins]
284
+
277
285
  base_name, base_alias = resolve_cte_base_name_and_alias_v2(
278
286
  human_id, query_datasource, source_map, final_joins
279
287
  )