altimate-code 0.5.2 → 0.5.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +27 -0
- package/bin/altimate +6 -0
- package/bin/altimate-code +6 -0
- package/dbt-tools/bin/altimate-dbt +2 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/__init__.py +0 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/fetch_schema.py +35 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/utils.py +353 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/validate_sql.py +114 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__init__.py +178 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__main__.py +96 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/_typing.py +17 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/__init__.py +3 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/__init__.py +18 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/_typing.py +18 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/column.py +332 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/dataframe.py +866 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/functions.py +1267 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/group.py +59 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/normalize.py +78 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/operations.py +53 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/readwriter.py +108 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/session.py +190 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/transforms.py +9 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/types.py +212 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/util.py +32 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/window.py +134 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/__init__.py +118 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/athena.py +166 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/bigquery.py +1331 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/clickhouse.py +1393 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/databricks.py +131 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dialect.py +1915 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/doris.py +561 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/drill.py +157 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/druid.py +20 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/duckdb.py +1159 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dune.py +16 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/hive.py +787 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/materialize.py +94 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/mysql.py +1324 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/oracle.py +378 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/postgres.py +778 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/presto.py +788 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/prql.py +203 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/redshift.py +448 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/risingwave.py +78 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/snowflake.py +1464 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark.py +202 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark2.py +349 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/sqlite.py +320 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/starrocks.py +343 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tableau.py +61 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/teradata.py +356 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/trino.py +115 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tsql.py +1403 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/diff.py +456 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/errors.py +93 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/__init__.py +95 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/context.py +101 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/env.py +246 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/python.py +460 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/table.py +155 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/expressions.py +8870 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/generator.py +4993 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/helper.py +582 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/jsonpath.py +227 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/lineage.py +423 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/__init__.py +11 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/annotate_types.py +589 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/canonicalize.py +222 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_ctes.py +43 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_joins.py +181 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py +189 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/isolate_table_selects.py +50 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/merge_subqueries.py +415 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize.py +200 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize_identifiers.py +64 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimize_joins.py +91 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimizer.py +94 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_predicates.py +222 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_projections.py +172 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify.py +104 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_columns.py +1024 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_tables.py +155 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/scope.py +904 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/simplify.py +1587 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/unnest_subqueries.py +302 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/parser.py +8501 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/planner.py +463 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/schema.py +588 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/serde.py +68 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/time.py +687 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/tokens.py +1520 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/transforms.py +1020 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/trie.py +81 -0
- package/dbt-tools/dist/altimate_python_packages/dbt_core_integration.py +825 -0
- package/dbt-tools/dist/altimate_python_packages/dbt_utils.py +157 -0
- package/dbt-tools/dist/index.js +23859 -0
- package/package.json +14 -18
- package/postinstall.mjs +42 -0
- package/skills/altimate-setup/SKILL.md +31 -0
|
@@ -0,0 +1,1024 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import itertools
|
|
4
|
+
import typing as t
|
|
5
|
+
|
|
6
|
+
from sqlglot import alias, exp
|
|
7
|
+
from sqlglot.dialects.dialect import Dialect, DialectType
|
|
8
|
+
from sqlglot.errors import OptimizeError
|
|
9
|
+
from sqlglot.helper import seq_get, SingleValuedMapping
|
|
10
|
+
from sqlglot.optimizer.annotate_types import TypeAnnotator
|
|
11
|
+
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
|
|
12
|
+
from sqlglot.optimizer.simplify import simplify_parens
|
|
13
|
+
from sqlglot.schema import Schema, ensure_schema
|
|
14
|
+
|
|
15
|
+
if t.TYPE_CHECKING:
|
|
16
|
+
from sqlglot._typing import E
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: DialectType = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.
        dialect: Dialect used to build the schema when `schema` is a plain dict.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema, dialect=dialect)
    annotator = TypeAnnotator(schema)
    # Default: only infer missing schema information when no schema was provided at all
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS
    bigquery = dialect == "bigquery"

    # traverse_scope yields scopes innermost-first, so each subquery is fully
    # qualified before its parent is processed
    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        # Alias column lists like `... AS foo(col1, col2)` are dropped; the aliases
        # were already applied by earlier qualification steps
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects (and schema-less qualification) need alias references expanded
        # before columns are qualified, otherwise the aliases would shadow real columns
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=bigquery,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if bigquery:
            annotator.annotate_scope(scope)

    return expression
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns: t.List[exp.Column] = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        pending = scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            # A column referencing something outside this scope that isn't a correlated
            # subquery or pivot reference can never be resolved — fail fast.
            column = scope.external_columns[0]
            for_table = f" for table: '{column.table}'" if column.table else ""
            raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

        if pending and scope.pivots and scope.pivots[0].unpivot:
            # New columns produced by the UNPIVOT can't be qualified, but there may be columns
            # under the UNPIVOT's IN clause that can and should be qualified. We recompute
            # this list here to ensure those in the former category will be excluded.
            produced_by_unpivot = set(_unpivot_columns(scope.pivots[0]))
            pending = [c for c in pending if c not in produced_by_unpivot]

        all_unqualified_columns.extend(pending)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Yield columns referenced by an UNPIVOT: the name columns from its IN
    fields first, then every column found in its value expressions."""
    for field in unpivot.fields:
        if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
            yield field.this

    for value_expr in unpivot.expressions:
        yield from value_expr.find_all(exp.Column)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for table in derived_tables:
        parent = table.parent
        if isinstance(parent, exp.With) and parent.recursive:
            # Recursive CTEs depend on their declared column names, so keep them
            continue

        alias_node = table.args.get("alias")
        if alias_node:
            alias_node.args.pop("columns", None)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite JOIN ... USING clauses into equivalent JOIN ... ON conditions.

    Unqualified references to USING columns elsewhere in the scope are replaced
    with COALESCE over the participating sources (preserving the column's output
    name when it is projected or used as an anonymous STRUCT field).

    Returns:
        Mapping of automatically joined column names to an ordered set of source
        names (represented as a dict whose values are ignored).
    """
    # column name -> first source (in FROM/JOIN order) that provides it
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        # Record each of the source's columns, keeping the first source seen per name
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources that are not join targets, i.e. the FROM clause sources
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only raise when we actually know the columns of both sides;
                # "*" means the source's columns are unknown/unexpanded
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # Later joins may need to COALESCE over all prior sources that
                # expose the USING column
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        # Replace USING with the equivalent ON condition
        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
        => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y

    Args:
        scope: The scope whose SELECT's alias references should be expanded.
        resolver: Resolver used to find a table for unqualified column names.
        dialect: Target dialect; some dialects (e.g. Oracle) disallow expansion.
        expand_only_groupby: If True, only expand aliases inside GROUP BY clauses
            (BigQuery-specific behavior).
    """
    expression = scope.expression

    # Oracle does not support alias references in these clauses, so skip entirely
    if not isinstance(expression, exp.Select) or dialect == "oracle":
        return

    # alias name -> (aliased expression, 1-based projection position)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
    projections = {s.alias_or_name for s in expression.selects}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        # Replace alias references under `node`; optionally qualify remaining
        # columns (resolve_table) or turn literal aliases into positions (literal_index)
        is_group_by = isinstance(node, exp.Group)
        is_having = isinstance(node, exp.Having)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            # SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            # SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            skip_replace = False
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                # Don't expand an aggregate alias inside another aggregate
                # (unless the reference lives in a window specification)
                skip_replace = bool(
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )

                # BigQuery's having clause gets confused if an alias matches a source.
                # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
                # If HAVING x is expanded to max(x.b), bigquery treats x as the new projection x instead of the table
                if is_having and dialect == "bigquery":
                    skip_replace = skip_replace or any(
                        node.parts[0].name in projections
                        for node in alias_expr.find_all(exp.Column)
                    )

            if table and (not alias_expr or skip_replace):
                column.set("table", table)
            elif not column.table and alias_expr and not skip_replace:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    # Parenthesize to preserve precedence, then simplify away
                    # redundant parens if possible
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Replace positional GROUP BY references (e.g. GROUP BY 1) with the
    projections they point at."""
    select = scope.expression
    group = select.args.get("group")

    if group:
        expanded = _expand_positional_references(scope, group.expressions, dialect)
        group.set("expressions", expanded)
        select.set("group", group)
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional references in ORDER BY and DISTINCT ON, qualifying any
    columns that appear under aggregates along the way.

    When the query also has a GROUP BY, expressions that match a projection
    are replaced by the projection's output name (identifier), mirroring how
    ORDER BY/DISTINCT ON resolve against the select list.
    """
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            # DISTINCT ON stores its expressions under the "on" arg
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            # ORDER BY wraps each expression in an Ordered node; unwrap it
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.schema.dialect, alias=True
            ),
        ):
            # Qualify columns inside aggregates before the replacement happens
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """
    Replace integer literals in `expressions` with the projection they refer to.

    Args:
        scope: Scope whose select list resolves the positions.
        expressions: Candidate expressions (e.g. GROUP BY / ORDER BY items).
        dialect: Target dialect; BigQuery gets extra ambiguity handling.
        alias: If True, substitute the projection's alias (as a column) instead
            of copying the projected expression itself.

    Returns:
        A new list with positional references expanded; non-integer nodes are
        passed through unchanged.
    """
    new_nodes: t.List[exp.Expression] = []
    # Lazily computed (BigQuery only): projection names that clash with source names
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                # Keep the positional reference when expansion would be invalid:
                # constants, expressions producing rows (EXPLODE/UNNEST), or
                # ambiguous BigQuery names
                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """Return the aliased projection at the 1-based position given by `node`,
    raising an `OptimizeError` when the position is out of range."""
    position = int(node.this) - 1
    try:
        selected = scope.expression.selects[position]
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")
    return selected.assert_is(exp.Alias)
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        # A "table" qualifier that isn't a known source (here or, for correlated
        # subqueries, in the parent scope) is actually a struct column name
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source.

    Args:
        scope: Scope whose column nodes get a "table" arg set in place.
        resolver: Used to validate known sources and infer missing tables.
        allow_partial_qualification: When True, skip the unknown-column check for
            columns that already carry a table qualifier.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        # Column is already qualified with a known source: validate it exists there.
        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    # Qualify columns that live inside the PIVOT clause itself.
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column.

    Args:
        expression: A Dot chain ending in a star (e.g. ``t.c.f1.*``).

    Returns:
        One aliased column per struct field, or an empty list when expansion is
        not possible (non-struct column, anonymous/ambiguous fields, or no
        matching nested field).
    """

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names: t.Set[str] = set()
    new_selections: t.List[exp.Alias] = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        # Rebuild the access path as <root>.<fields...>.<field> and alias it by the field name.
        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections.

    Args:
        scope: Scope whose SELECT list is rewritten in place.
        resolver: Used to enumerate each source's columns.
        using_column_tables: Columns joined via USING, mapped to their tables;
            these are expanded as a single COALESCE projection.
        pseudocolumns: Upper-cased column names to drop from expansions.
        annotator: Used to type-annotate the scope before BigQuery struct-star
            expansion.
    """

    new_selections: t.List[exp.Expression] = []
    # Per-source (keyed by id(table)) modifiers collected from star expressions.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns: t.Set[str] = set()

    # Only the first pivot is considered; compute which columns it consumes/produces.
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            for field in pivot.fields:
                if isinstance(field, exp.In):
                    pivot_exclude_columns.update(
                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                    )

        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            # Bare `SELECT *`: expands over every selected source.
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                # `SELECT t.*`: expands over a single source.
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                # `SELECT t.struct_col.*` (BigQuery struct expansion).
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            # Not a star projection; keep it verbatim.
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # Bail out entirely: without a full column list the star can't be expanded.
            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # USING column: emit it once, as COALESCE over all joined tables.
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record a star's EXCEPT (...) column names under each table's id()."""
    excluded = expression.args.get("except")
    if excluded:
        names = {e.name for e in excluded}
        for source_table in tables:
            except_columns[id(source_table)] = names
def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record a star's RENAME (<col> AS <alias>) mapping under each table's id()."""
    renames = expression.args.get("rename")
    if renames:
        mapping = {item.this.name: item.alias for item in renames}
        for source_table in tables:
            rename_columns[id(source_table)] = mapping
def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    """Record a star's REPLACE (<expr> AS <col>) aliases under each table's id()."""
    replacements = expression.args.get("replace")
    if replacements:
        mapping = {item.alias: item for item in replacements}
        for source_table in tables:
            replace_columns[id(source_table)] = mapping
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased.

    Args:
        scope_or_expression: A Scope, or an Expression from which a scope is built
            (a no-op if no scope can be built).
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    # Pair each projection with the outer column name that should shadow its alias.
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        # Stop when outer_columns outnumber the projections, or on TRANSFORM clauses.
        if selection is None or isinstance(selection, exp.QueryTransform):
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            # Unnamed scalar projection: give it a synthetic positional alias.
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            # The outer column name takes precedence over any existing alias.
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    quote_fn = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote_fn, identify=identify, copy=False)  # type: ignore
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if not cte.alias_column_names:
            continue

        projections = []
        # Pair each CTE alias column with its positional projection; extra
        # projections (beyond the alias list) are dropped by zip.
        for column_name, projection in zip(cte.alias_column_names, cte.this.expressions):
            if isinstance(projection, exp.Alias):
                projection.set("alias", column_name)
            else:
                projection = alias(projection, alias=column_name)
            projections.append(projection)

        cte.this.set("expressions", projections)

    return expression
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        # Scope being qualified and the schema used to look up table columns.
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches (see the corresponding accessors below).
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has no (or wildcard) columns, assume the
            # unresolved column belongs to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the node that actually carries this alias.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                # Physical table: defer to the schema.
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                set_op = source.expression

                # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
                on_column_list = set_op.args.get("on")

                if on_column_list:
                    # The resulting columns are the columns in the ON clause:
                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
                    columns = [col.name for col in on_column_list]
                elif set_op.side or set_op.kind:
                    side = set_op.side
                    kind = set_op.kind

                    left = set_op.left.named_selects
                    right = set_op.right.named_selects

                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
                    if side == "LEFT":
                        columns = left
                    elif side == "FULL":
                        columns = list(dict.fromkeys(left + right))
                    elif kind == "INNER":
                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
                else:
                    columns = set_op.named_selects
            else:
                select = seq_get(source.expression.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        # Cache the column list of every selected and lateral source, keyed by name.
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column seen in more than one source is ambiguous; drop it.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns