pytrilogy 0.0.1.102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pytrilogy might be problematic.

Files changed (77)
  1. pytrilogy-0.0.1.102.dist-info/LICENSE.md +19 -0
  2. pytrilogy-0.0.1.102.dist-info/METADATA +277 -0
  3. pytrilogy-0.0.1.102.dist-info/RECORD +77 -0
  4. pytrilogy-0.0.1.102.dist-info/WHEEL +5 -0
  5. pytrilogy-0.0.1.102.dist-info/entry_points.txt +2 -0
  6. pytrilogy-0.0.1.102.dist-info/top_level.txt +1 -0
  7. trilogy/__init__.py +8 -0
  8. trilogy/compiler.py +0 -0
  9. trilogy/constants.py +30 -0
  10. trilogy/core/__init__.py +0 -0
  11. trilogy/core/constants.py +3 -0
  12. trilogy/core/enums.py +270 -0
  13. trilogy/core/env_processor.py +33 -0
  14. trilogy/core/environment_helpers.py +156 -0
  15. trilogy/core/ergonomics.py +187 -0
  16. trilogy/core/exceptions.py +23 -0
  17. trilogy/core/functions.py +320 -0
  18. trilogy/core/graph_models.py +55 -0
  19. trilogy/core/internal.py +37 -0
  20. trilogy/core/models.py +3145 -0
  21. trilogy/core/processing/__init__.py +0 -0
  22. trilogy/core/processing/concept_strategies_v3.py +603 -0
  23. trilogy/core/processing/graph_utils.py +44 -0
  24. trilogy/core/processing/node_generators/__init__.py +25 -0
  25. trilogy/core/processing/node_generators/basic_node.py +71 -0
  26. trilogy/core/processing/node_generators/common.py +239 -0
  27. trilogy/core/processing/node_generators/concept_merge.py +152 -0
  28. trilogy/core/processing/node_generators/filter_node.py +83 -0
  29. trilogy/core/processing/node_generators/group_node.py +92 -0
  30. trilogy/core/processing/node_generators/group_to_node.py +99 -0
  31. trilogy/core/processing/node_generators/merge_node.py +148 -0
  32. trilogy/core/processing/node_generators/multiselect_node.py +189 -0
  33. trilogy/core/processing/node_generators/rowset_node.py +130 -0
  34. trilogy/core/processing/node_generators/select_node.py +328 -0
  35. trilogy/core/processing/node_generators/unnest_node.py +37 -0
  36. trilogy/core/processing/node_generators/window_node.py +85 -0
  37. trilogy/core/processing/nodes/__init__.py +76 -0
  38. trilogy/core/processing/nodes/base_node.py +251 -0
  39. trilogy/core/processing/nodes/filter_node.py +49 -0
  40. trilogy/core/processing/nodes/group_node.py +110 -0
  41. trilogy/core/processing/nodes/merge_node.py +326 -0
  42. trilogy/core/processing/nodes/select_node_v2.py +198 -0
  43. trilogy/core/processing/nodes/unnest_node.py +54 -0
  44. trilogy/core/processing/nodes/window_node.py +34 -0
  45. trilogy/core/processing/utility.py +278 -0
  46. trilogy/core/query_processor.py +331 -0
  47. trilogy/dialect/__init__.py +0 -0
  48. trilogy/dialect/base.py +679 -0
  49. trilogy/dialect/bigquery.py +80 -0
  50. trilogy/dialect/common.py +43 -0
  51. trilogy/dialect/config.py +55 -0
  52. trilogy/dialect/duckdb.py +83 -0
  53. trilogy/dialect/enums.py +95 -0
  54. trilogy/dialect/postgres.py +86 -0
  55. trilogy/dialect/presto.py +82 -0
  56. trilogy/dialect/snowflake.py +82 -0
  57. trilogy/dialect/sql_server.py +89 -0
  58. trilogy/docs/__init__.py +0 -0
  59. trilogy/engine.py +48 -0
  60. trilogy/executor.py +242 -0
  61. trilogy/hooks/__init__.py +0 -0
  62. trilogy/hooks/base_hook.py +37 -0
  63. trilogy/hooks/graph_hook.py +24 -0
  64. trilogy/hooks/query_debugger.py +133 -0
  65. trilogy/metadata/__init__.py +0 -0
  66. trilogy/parser.py +10 -0
  67. trilogy/parsing/__init__.py +0 -0
  68. trilogy/parsing/common.py +176 -0
  69. trilogy/parsing/config.py +5 -0
  70. trilogy/parsing/exceptions.py +2 -0
  71. trilogy/parsing/helpers.py +1 -0
  72. trilogy/parsing/parse_engine.py +1951 -0
  73. trilogy/parsing/render.py +483 -0
  74. trilogy/py.typed +0 -0
  75. trilogy/scripts/__init__.py +0 -0
  76. trilogy/scripts/trilogy.py +127 -0
  77. trilogy/utility.py +31 -0
trilogy/core/processing/utility.py
@@ -0,0 +1,278 @@
+ from typing import List, Tuple, Dict, Set
+ import networkx as nx
+ from trilogy.core.models import (
+     Datasource,
+     JoinType,
+     BaseJoin,
+     Concept,
+     QueryDatasource,
+     LooseConceptList,
+ )
+
+ from trilogy.core.enums import Purpose, Granularity
+ from trilogy.core.constants import CONSTANT_DATASET
+ from enum import Enum
+ from trilogy.utility import unique
+ from collections import defaultdict
+ from logging import Logger
+ from pydantic import BaseModel
+
+
+ class NodeType(Enum):
+     CONCEPT = 1
+     NODE = 2
+
+
+ class PathInfo(BaseModel):
+     paths: Dict[str, List[str]]
+     datasource: Datasource
+     reduced_concepts: Set[str]
+     concept_subgraphs: List[List[Concept]]
+
+
+ def concept_to_relevant_joins(concepts: list[Concept]) -> List[Concept]:
+     addresses = LooseConceptList(concepts=concepts)
+     sub_props = LooseConceptList(
+         concepts=[
+             x for x in concepts if x.keys and all([key in addresses for key in x.keys])
+         ]
+     )
+     final = [c for c in concepts if c not in sub_props]
+     return unique(final, "address")
+
+
+ def padding(x: int) -> str:
+     return "\t" * x
+
+
+ def create_log_lambda(prefix: str, depth: int, logger: Logger):
+     pad = padding(depth)
+
+     def log_lambda(msg: str):
+         logger.info(f"{pad} {prefix} {msg}")
+
+     return log_lambda
+
+
+ def calculate_graph_relevance(
+     g: nx.DiGraph, subset_nodes: set[str], concepts: set[Concept]
+ ) -> int:
+     """Calculate the relevance of each node in a graph
+     Relevance is used to prune irrelevant nodes from the graph
+     """
+     relevance = 0
+     for node in g.nodes:
+         if node not in subset_nodes:
+             continue
+         if not g.nodes[node]["type"] == NodeType.CONCEPT:
+             continue
+         concept = [x for x in concepts if x.address == node].pop()
+
+         # a single row concept can always be crossjoined
+         # therefore a graph with only single row concepts is always relevant
+         if concept.granularity == Granularity.SINGLE_ROW:
+             continue
+         # if it's an aggregate up to an arbitrary grain, it can be joined in later
+         # and can be ignored in subgraph
+         if concept.purpose == Purpose.METRIC:
+             if not concept.grain:
+                 continue
+             if len(concept.grain.components) == 0:
+                 continue
+         if concept.grain and len(concept.grain.components) > 0:
+             relevance += 1
+             continue
+         # Added 2023-10-18 since we seemed to be strangely dropping things
+         relevance += 1
+
+     return relevance
+
+
+ def resolve_join_order(joins: List[BaseJoin]) -> List[BaseJoin]:
+     available_aliases: set[str] = set()
+     final_joins_pre = [*joins]
+     final_joins = []
+     while final_joins_pre:
+         new_final_joins_pre: List[BaseJoin] = []
+         for join in final_joins_pre:
+             if not available_aliases:
+                 final_joins.append(join)
+                 available_aliases.add(join.left_datasource.identifier)
+                 available_aliases.add(join.right_datasource.identifier)
+             elif join.left_datasource.identifier in available_aliases:
+                 # we don't need to join twice
+                 # so whatever join we found first, works
+                 if join.right_datasource.identifier in available_aliases:
+                     continue
+                 final_joins.append(join)
+                 available_aliases.add(join.left_datasource.identifier)
+                 available_aliases.add(join.right_datasource.identifier)
+             else:
+                 new_final_joins_pre.append(join)
+         if len(new_final_joins_pre) == len(final_joins_pre):
+             remaining = [
+                 join.left_datasource.identifier for join in new_final_joins_pre
+             ]
+             remaining_right = [
+                 join.right_datasource.identifier for join in new_final_joins_pre
+             ]
+             raise SyntaxError(
+                 f"did not find any new joins, available {available_aliases} remaining is {remaining + remaining_right} "
+             )
+         final_joins_pre = new_final_joins_pre
+     return final_joins
+
+
+ def get_node_joins(
+     datasources: List[QueryDatasource],
+     grain: List[Concept],
+     # concepts:List[Concept],
+ ) -> List[BaseJoin]:
+     graph = nx.Graph()
+     concepts: List[Concept] = []
+     for datasource in datasources:
+         graph.add_node(datasource.identifier, type=NodeType.NODE)
+         for concept in datasource.output_concepts:
+             # we don't need to join on a concept if all of the keys exist in the grain
+             # if concept.keys and all([x in grain for x in concept.keys]):
+             #     continue
+             concepts.append(concept)
+             graph.add_node(concept.address, type=NodeType.CONCEPT)
+             graph.add_edge(datasource.identifier, concept.address)
+
+     # add edges for every constant to every datasource
+     for datasource in datasources:
+         for concept in datasource.output_concepts:
+             if concept.granularity == Granularity.SINGLE_ROW:
+                 for node in graph.nodes:
+                     if graph.nodes[node]["type"] == NodeType.NODE:
+                         graph.add_edge(node, concept.address)
+
+     joins: defaultdict[str, set] = defaultdict(set)
+     identifier_map = {x.identifier: x for x in datasources}
+
+     node_list = sorted(
+         [x for x in graph.nodes if graph.nodes[x]["type"] == NodeType.NODE],
+         # sort so that anything with a partial match on the target is later
+         key=lambda x: len(
+             [x for x in identifier_map[x].partial_concepts if x in grain]
+         ),
+     )
+     for left in node_list:
+         # the constant dataset is a special case
+         # and can never be on the left of a join
+         if left == CONSTANT_DATASET:
+             continue
+
+         for cnode in graph.neighbors(left):
+             if graph.nodes[cnode]["type"] == NodeType.CONCEPT:
+                 for right in graph.neighbors(cnode):
+                     # skip concepts
+                     if graph.nodes[right]["type"] == NodeType.CONCEPT:
+                         continue
+                     if left == right:
+                         continue
+                     identifier = [left, right]
+                     joins["-".join(identifier)].add(cnode)
+
+     final_joins_pre: List[BaseJoin] = []
+
+     for key, join_concepts in joins.items():
+         left, right = key.split("-")
+         local_concepts: List[Concept] = unique(
+             [c for c in concepts if c.address in join_concepts], "address"
+         )
+         if all([c.granularity == Granularity.SINGLE_ROW for c in local_concepts]):
+             # for the constant join, make it a full outer join on 1=1
+             join_type = JoinType.FULL
+             local_concepts = []
+         elif any(
+             [
+                 c.address in [x.address for x in identifier_map[left].partial_concepts]
+                 for c in local_concepts
+             ]
+         ):
+             join_type = JoinType.FULL
+             local_concepts = [
+                 c for c in local_concepts if c.granularity != Granularity.SINGLE_ROW
+             ]
+         else:
+             join_type = JoinType.LEFT_OUTER
+             # remove any constants if other join keys exist
+             local_concepts = [
+                 c for c in local_concepts if c.granularity != Granularity.SINGLE_ROW
+             ]
+
+         # if concept.keys and all([x in grain for x in concept.keys]):
+         #     continue
+         final_joins_pre.append(
+             BaseJoin(
+                 left_datasource=identifier_map[left],
+                 right_datasource=identifier_map[right],
+                 join_type=join_type,
+                 concepts=concept_to_relevant_joins(local_concepts),
+             )
+         )
+     final_joins = resolve_join_order(final_joins_pre)
+
+     # this is extra validation
+     non_single_row_ds = [x for x in datasources if not x.grain.abstract]
+     if len(non_single_row_ds) > 1:
+         for x in datasources:
+             if x.grain.abstract:
+                 continue
+             found = False
+             for join in final_joins:
+                 if (
+                     join.left_datasource.identifier == x.identifier
+                     or join.right_datasource.identifier == x.identifier
+                 ):
+                     found = True
+             if not found:
+                 raise SyntaxError(
+                     f"Could not find join for {x.identifier} with output {[c.address for c in x.output_concepts]}, all {[z.identifier for z in datasources]}"
+                 )
+     single_row = [x for x in datasources if x.grain.abstract]
+     for x in single_row:
+         for join in final_joins:
+             found = False
+             for join in final_joins:
+                 if (
+                     join.left_datasource.identifier == x.identifier
+                     or join.right_datasource.identifier == x.identifier
+                 ):
+                     found = True
+             if not found:
+                 raise SyntaxError(
+                     f"Could not find join for {x.identifier} with output {[c.address for c in x.output_concepts]}, all {[z.identifier for z in datasources]}"
+                 )
+     return final_joins
+
+
+ def get_disconnected_components(
+     concept_map: Dict[str, Set[Concept]]
+ ) -> Tuple[int, List]:
+     """Find if any of the datasources are not linked"""
+     import networkx as nx
+
+     graph = nx.Graph()
+     all_concepts = set()
+     for datasource, concepts in concept_map.items():
+         graph.add_node(datasource, type=NodeType.NODE)
+         for concept in concepts:
+             # TODO: determine if this is the right way to handle things
+             # if concept.derivation in (PurposeLineage.FILTER, PurposeLineage.WINDOW):
+             #     if isinstance(concept.lineage, FilterItem):
+             #         graph.add_node(concept.lineage.content.address, type=NodeType.CONCEPT)
+             #         graph.add_edge(datasource, concept.lineage.content.address)
+             #     if isinstance(concept.lineage, WindowItem):
+             #         graph.add_node(concept.lineage.content.address, type=NodeType.CONCEPT)
+             #         graph.add_edge(datasource, concept.lineage.content.address)
+             graph.add_node(concept.address, type=NodeType.CONCEPT)
+             graph.add_edge(datasource, concept.address)
+             all_concepts.add(concept)
+     sub_graphs = list(nx.connected_components(graph))
+     sub_graphs = [
+         x for x in sub_graphs if calculate_graph_relevance(graph, x, all_concepts) > 0
+     ]
+     return len(sub_graphs), sub_graphs
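
The join resolution above is graph based: get_node_joins builds an undirected graph with one node per datasource and one node per output concept, links each datasource to the concepts it can provide, and then treats any concept adjacent to two datasources as a candidate join key between them. A minimal, runnable sketch of that structure with plain networkx, using hypothetical datasource and concept names rather than anything from this package:

import networkx as nx

# Hypothetical datasources and concept addresses, for illustration only.
graph = nx.Graph()
for datasource in ("orders", "customers"):
    graph.add_node(datasource, type="node")
for concept in ("order.id", "customer.id", "customer.name"):
    graph.add_node(concept, type="concept")

# An edge means "this datasource can output this concept".
graph.add_edge("orders", "order.id")
graph.add_edge("orders", "customer.id")
graph.add_edge("customers", "customer.id")
graph.add_edge("customers", "customer.name")

# A concept adjacent to both datasources is a join candidate between them,
# which is what the joins defaultdict in get_node_joins accumulates.
shared = set(graph.neighbors("orders")) & set(graph.neighbors("customers"))
print(shared)  # {'customer.id'}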
trilogy/core/query_processor.py
@@ -0,0 +1,331 @@
+ from typing import List, Optional, Set, Union, Dict
+
+ from trilogy.core.env_processor import generate_graph
+ from trilogy.core.graph_models import ReferenceGraph
+ from trilogy.core.constants import CONSTANT_DATASET
+ from trilogy.core.processing.concept_strategies_v3 import source_query_concepts
+ from trilogy.constants import CONFIG, DEFAULT_NAMESPACE
+ from trilogy.core.models import (
+     Environment,
+     PersistStatement,
+     SelectStatement,
+     MultiSelectStatement,
+     CTE,
+     Join,
+     UnnestJoin,
+     JoinKey,
+     MaterializedDataset,
+     ProcessedQuery,
+     ProcessedQueryPersist,
+     QueryDatasource,
+     Datasource,
+     BaseJoin,
+     InstantiatedUnnestJoin,
+ )
+
+ from trilogy.utility import unique
+ from collections import defaultdict
+ from trilogy.hooks.base_hook import BaseHook
+ from trilogy.constants import logger
+ from random import shuffle
+ from trilogy.core.ergonomics import CTE_NAMES
+ from math import ceil
+
+ LOGGER_PREFIX = "[QUERY BUILD]"
+
+
+ def base_join_to_join(
+     base_join: BaseJoin | UnnestJoin, ctes: List[CTE]
+ ) -> Join | InstantiatedUnnestJoin:
+     """This function converts joins at the datasource level
+     to joins at the CTE level"""
+     if isinstance(base_join, UnnestJoin):
+         return InstantiatedUnnestJoin(concept=base_join.concept, alias=base_join.alias)
+     if base_join.left_datasource.identifier == base_join.right_datasource.identifier:
+         raise ValueError(f"Joining on same datasource {base_join}")
+     left_ctes = [
+         cte
+         for cte in ctes
+         if (cte.source.full_name == base_join.left_datasource.full_name)
+     ]
+     if not left_ctes:
+         left_ctes = [
+             cte
+             for cte in ctes
+             if (
+                 cte.source.datasources[0].full_name
+                 == base_join.left_datasource.full_name
+             )
+         ]
+     left_cte = left_ctes[0]
+     right_ctes = [
+         cte
+         for cte in ctes
+         if (cte.source.full_name == base_join.right_datasource.full_name)
+     ]
+     if not right_ctes:
+         right_ctes = [
+             cte
+             for cte in ctes
+             if (
+                 cte.source.datasources[0].full_name
+                 == base_join.right_datasource.full_name
+             )
+         ]
+     right_cte = right_ctes[0]
+     return Join(
+         left_cte=left_cte,
+         right_cte=right_cte,
+         joinkeys=[JoinKey(concept=concept) for concept in base_join.concepts],
+         jointype=base_join.join_type,
+     )
+
+
+ def generate_source_map(
+     query_datasource: QueryDatasource, all_new_ctes: List[CTE]
+ ) -> Dict[str, str | list[str]]:
+     source_map: Dict[str, list[str]] = defaultdict(list)
+     # now populate anything derived in this level
+     for qdk, qdv in query_datasource.source_map.items():
+         if (
+             qdk not in source_map
+             and len(qdv) == 1
+             and isinstance(list(qdv)[0], UnnestJoin)
+         ):
+             source_map[qdk] = []
+
+         else:
+             for cte in all_new_ctes:
+                 output_address = [
+                     x.address
+                     for x in cte.output_columns
+                     if x.address not in [z.address for z in cte.partial_concepts]
+                 ]
+                 if qdk in output_address:
+                     source_map[qdk].append(cte.name)
+             # now do a pass that accepts partials
+             # TODO: move this into a second loop by first creationg all sub sourcdes
+             # then loop through this
+             for cte in all_new_ctes:
+                 output_address = [x.address for x in cte.output_columns]
+                 if qdk in output_address:
+                     if qdk not in source_map:
+                         source_map[qdk] = [cte.name]
+         if qdk not in source_map and not qdv:
+             # set source to empty, as it must be derived in this element
+             source_map[qdk] = []
+         if qdk not in source_map:
+             raise ValueError(
+                 f"Missing {qdk} in {source_map}, source map {query_datasource.source_map.keys()} "
+             )
+     return {k: "" if not v else v for k, v in source_map.items()}
+
+
+ def datasource_to_query_datasource(datasource: Datasource) -> QueryDatasource:
+     sub_select: Dict[str, Set[Union[Datasource, QueryDatasource, UnnestJoin]]] = {
+         **{c.address: {datasource} for c in datasource.concepts},
+     }
+     concepts = [c for c in datasource.concepts]
+     concepts = unique(concepts, "address")
+     return QueryDatasource(
+         output_concepts=concepts,
+         input_concepts=concepts,
+         source_map=sub_select,
+         grain=datasource.grain,
+         datasources=[datasource],
+         joins=[],
+         partial_concepts=[x.concept for x in datasource.columns if not x.is_complete],
+     )
+
+
+ def generate_cte_name(full_name: str, name_map: dict[str, str]) -> str:
+     if CONFIG.human_identifiers:
+         if full_name in name_map:
+             return name_map[full_name]
+         suffix = ""
+         idx = len(name_map)
+         if idx >= len(CTE_NAMES):
+             int = ceil(idx / len(CTE_NAMES))
+             suffix = f"_{int}"
+         valid = [x for x in CTE_NAMES if x + suffix not in name_map.values()]
+         shuffle(valid)
+         lookup = valid[0]
+         new_name = f"{lookup}{suffix}"
+         name_map[full_name] = new_name
+         return new_name
+     else:
+         return full_name.replace("<", "").replace(">", "").replace(",", "_")
+
+
+ def datasource_to_ctes(
+     query_datasource: QueryDatasource, name_map: dict[str, str]
+ ) -> List[CTE]:
+     output: List[CTE] = []
+     parents: list[CTE] = []
+     if len(query_datasource.datasources) > 1 or any(
+         [isinstance(x, QueryDatasource) for x in query_datasource.datasources]
+     ):
+         all_new_ctes: List[CTE] = []
+         for datasource in query_datasource.datasources:
+             if isinstance(datasource, QueryDatasource):
+                 sub_datasource = datasource
+             else:
+                 sub_datasource = datasource_to_query_datasource(datasource)
+
+             sub_cte = datasource_to_ctes(sub_datasource, name_map)
+             parents += sub_cte
+             all_new_ctes += sub_cte
+         source_map = generate_source_map(query_datasource, all_new_ctes)
+     else:
+         # source is the first datasource of the query datasource
+         source = query_datasource.datasources[0]
+         # this is required to ensure that constant datasets
+         # render properly on initial access; since they have
+         # no actual source
+         if source.full_name == DEFAULT_NAMESPACE + "_" + CONSTANT_DATASET:
+             source_map = {k: "" for k in query_datasource.source_map}
+         else:
+             source_map = {
+                 k: "" if not v else source.full_name
+                 for k, v in query_datasource.source_map.items()
+             }
+     human_id = generate_cte_name(query_datasource.full_name, name_map)
+     cte = CTE(
+         name=human_id,
+         source=query_datasource,
+         # output columns are what are selected/grouped by
+         output_columns=[
+             c.with_grain(query_datasource.grain)
+             for c in query_datasource.output_concepts
+         ],
+         source_map=source_map,
+         # related columns include all referenced columns, such as filtering
+         joins=[
+             x
+             for x in [
+                 base_join_to_join(join, parents) for join in query_datasource.joins
+             ]
+             if x
+         ],
+         grain=query_datasource.grain,
+         group_to_grain=query_datasource.group_required,
+         # we restrict parent_ctes to one level
+         # as this set is used as the base for rendering the query
+         parent_ctes=parents,
+         condition=query_datasource.condition,
+         partial_concepts=query_datasource.partial_concepts,
+         join_derived_concepts=query_datasource.join_derived_concepts,
+     )
+     if cte.grain != query_datasource.grain:
+         raise ValueError("Grain was corrupted in CTE generation")
+     for x in cte.output_columns:
+         if x.address not in cte.source_map:
+             raise ValueError(
+                 f"Missing {x.address} in {cte.source_map}, source map {cte.source.source_map.keys()} "
+             )
+
+     output.append(cte)
+     return output
+
+
+ def get_query_datasources(
+     environment: Environment,
+     statement: SelectStatement | MultiSelectStatement,
+     graph: Optional[ReferenceGraph] = None,
+     hooks: Optional[List[BaseHook]] = None,
+ ) -> QueryDatasource:
+     graph = graph or generate_graph(environment)
+     logger.info(
+         f"{LOGGER_PREFIX} getting source datasource for query with output {[str(c) for c in statement.output_components]}"
+     )
+     if not statement.output_components:
+         raise ValueError(f"Statement has no output components {statement}")
+     ds = source_query_concepts(
+         statement.output_components, environment=environment, g=graph
+     )
+     if hooks:
+         for hook in hooks:
+             hook.process_root_strategy_node(ds)
+     final_qds = ds.resolve()
+     return final_qds
+
+
+ def flatten_ctes(input: CTE) -> list[CTE]:
+     output = [input]
+     for cte in input.parent_ctes:
+         output += flatten_ctes(cte)
+     return output
+
+
+ def process_auto(
+     environment: Environment,
+     statement: PersistStatement | SelectStatement,
+     hooks: List[BaseHook] | None = None,
+ ):
+     if isinstance(statement, PersistStatement):
+         return process_persist(environment, statement, hooks)
+     elif isinstance(statement, SelectStatement):
+         return process_query(environment, statement, hooks)
+     raise ValueError(f"Do not know how to process {type(statement)}")
+
+
+ def process_persist(
+     environment: Environment,
+     statement: PersistStatement,
+     hooks: List[BaseHook] | None = None,
+ ) -> ProcessedQueryPersist:
+     select = process_query(
+         environment=environment, statement=statement.select, hooks=hooks
+     )
+
+     # build our object to return
+     arg_dict = {k: v for k, v in select.__dict__.items()}
+     return ProcessedQueryPersist(
+         **arg_dict,
+         output_to=MaterializedDataset(address=statement.address),
+         datasource=statement.datasource,
+     )
+
+
+ def process_query(
+     environment: Environment,
+     statement: SelectStatement | MultiSelectStatement,
+     hooks: List[BaseHook] | None = None,
+ ) -> ProcessedQuery:
+     hooks = hooks or []
+     graph = generate_graph(environment)
+     root_datasource = get_query_datasources(
+         environment=environment, graph=graph, statement=statement, hooks=hooks
+     )
+     for hook in hooks:
+         hook.process_root_datasource(root_datasource)
+     # this should always return 1 - TODO, refactor
+     root_cte = datasource_to_ctes(root_datasource, environment.cte_name_map)[0]
+     for hook in hooks:
+         hook.process_root_cte(root_cte)
+     raw_ctes: List[CTE] = list(reversed(flatten_ctes(root_cte)))
+     seen = dict()
+     # we can have duplicate CTEs at this point
+     # so merge them together
+     for cte in raw_ctes:
+         if cte.name not in seen:
+             seen[cte.name] = cte
+         else:
+             # merge them up
+             seen[cte.name] = seen[cte.name] + cte
+     for cte in raw_ctes:
+         cte.parent_ctes = [seen[x.name] for x in cte.parent_ctes]
+     final_ctes: List[CTE] = list(seen.values())
+
+     return ProcessedQuery(
+         order_by=statement.order_by,
+         grain=statement.grain,
+         limit=statement.limit,
+         where_clause=statement.where_clause,
+         output_columns=statement.output_components,
+         ctes=final_ctes,
+         base=root_cte,
+         # we no longer do any joins at final level, this should always happen in parent CTEs
+         joins=[],
+         hidden_columns=[x for x in statement.hidden_components],
+     )
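
generate_cte_name above gives CTEs human-friendly names when CONFIG.human_identifiers is set: names are drawn from the CTE_NAMES pool, and once the pool is exhausted a numeric suffix of ceil(index / pool size) is appended to keep names unique. A small self-contained sketch of that scheme follows; it is deterministic where the real function shuffles its candidates, and the three-name pool below is an assumption, not the package's actual CTE_NAMES:

from math import ceil

CTE_NAMES = ["quizzical", "cheerful", "thoughtful"]  # assumed example pool
name_map: dict[str, str] = {}

def pick_cte_name(full_name: str) -> str:
    # Reuse the name already assigned to this datasource, if any.
    if full_name in name_map:
        return name_map[full_name]
    idx = len(name_map)
    # Past the end of the pool, append a pass-count suffix to stay unique.
    suffix = f"_{ceil(idx / len(CTE_NAMES))}" if idx >= len(CTE_NAMES) else ""
    candidate = next(n for n in CTE_NAMES if n + suffix not in name_map.values())
    name_map[full_name] = candidate + suffix
    return name_map[full_name]

for ds in ["<a>", "<b>", "<c>", "<a,b>"]:
    print(ds, "->", pick_cte_name(ds))
# <a> -> quizzical, <b> -> cheerful, <c> -> thoughtful, <a,b> -> quizzical_1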