altimate-code 0.5.2 → 0.5.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +12 -0
- package/bin/altimate +6 -0
- package/bin/altimate-code +6 -0
- package/dbt-tools/bin/altimate-dbt +2 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/__init__.py +0 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/fetch_schema.py +35 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/utils.py +353 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/validate_sql.py +114 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__init__.py +178 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__main__.py +96 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/_typing.py +17 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/__init__.py +3 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/__init__.py +18 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/_typing.py +18 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/column.py +332 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/dataframe.py +866 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/functions.py +1267 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/group.py +59 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/normalize.py +78 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/operations.py +53 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/readwriter.py +108 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/session.py +190 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/transforms.py +9 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/types.py +212 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/util.py +32 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/window.py +134 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/__init__.py +118 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/athena.py +166 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/bigquery.py +1331 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/clickhouse.py +1393 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/databricks.py +131 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dialect.py +1915 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/doris.py +561 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/drill.py +157 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/druid.py +20 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/duckdb.py +1159 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dune.py +16 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/hive.py +787 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/materialize.py +94 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/mysql.py +1324 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/oracle.py +378 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/postgres.py +778 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/presto.py +788 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/prql.py +203 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/redshift.py +448 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/risingwave.py +78 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/snowflake.py +1464 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark.py +202 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark2.py +349 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/sqlite.py +320 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/starrocks.py +343 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tableau.py +61 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/teradata.py +356 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/trino.py +115 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tsql.py +1403 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/diff.py +456 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/errors.py +93 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/__init__.py +95 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/context.py +101 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/env.py +246 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/python.py +460 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/table.py +155 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/expressions.py +8870 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/generator.py +4993 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/helper.py +582 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/jsonpath.py +227 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/lineage.py +423 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/__init__.py +11 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/annotate_types.py +589 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/canonicalize.py +222 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_ctes.py +43 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_joins.py +181 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py +189 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/isolate_table_selects.py +50 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/merge_subqueries.py +415 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize.py +200 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize_identifiers.py +64 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimize_joins.py +91 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimizer.py +94 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_predicates.py +222 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_projections.py +172 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify.py +104 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_columns.py +1024 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_tables.py +155 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/scope.py +904 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/simplify.py +1587 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/unnest_subqueries.py +302 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/parser.py +8501 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/planner.py +463 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/schema.py +588 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/serde.py +68 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/time.py +687 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/tokens.py +1520 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/transforms.py +1020 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/trie.py +81 -0
- package/dbt-tools/dist/altimate_python_packages/dbt_core_integration.py +825 -0
- package/dbt-tools/dist/altimate_python_packages/dbt_utils.py +157 -0
- package/dbt-tools/dist/index.js +23859 -0
- package/package.json +13 -13
- package/postinstall.mjs +42 -0
- package/skills/altimate-setup/SKILL.md +31 -0
package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/canonicalize.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import itertools
|
|
4
|
+
import typing as t
|
|
5
|
+
|
|
6
|
+
from sqlglot import exp
|
|
7
|
+
from sqlglot.dialects.dialect import Dialect, DialectType
|
|
8
|
+
from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime
|
|
9
|
+
from sqlglot.optimizer.annotate_types import TypeAnnotator
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
    """Rewrite a SQL expression tree into a standard (canonical) form.

    Several of the rewrites depend on inferred types, so the tree is expected
    to have been run through annotate_types beforehand.

    Args:
        expression: The expression to canonicalize.
        dialect: Dialect (name, class or instance) whose settings drive the
            rewrites; it is resolved with ``Dialect.get_or_raise``.

    Returns:
        The result of ``exp.replace_tree`` applying every rewrite pass to each node.
    """
    resolved = Dialect.get_or_raise(dialect)

    # Rewrite passes, applied in this exact order to every node replace_tree visits.
    passes = (
        add_text_to_concat,
        lambda node: replace_date_funcs(node, dialect=resolved),
        lambda node: coerce_type(node, resolved.PROMOTE_TO_INFERRED_DATETIME_TYPE),
        remove_redundant_casts,
        lambda node: ensure_bools(node, _replace_int_predicate),
        remove_ascending_order,
    )

    def _canonicalize(node: exp.Expression) -> exp.Expression:
        for rewrite in passes:
            node = rewrite(node)
        return node

    return exp.replace_tree(expression, _canonicalize)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    """Turn a text-typed ``a + b`` into ``CONCAT(a, b)``; other nodes are returned unchanged."""
    is_text_add = (
        isinstance(node, exp.Add)
        and node.type
        and node.type.this in exp.DataType.TEXT_TYPES
    )
    if not is_text_add:
        return node
    return exp.Concat(expressions=[node.left, node.right])
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression:
    """Replace date/timestamp function calls with equivalent CASTs where possible.

    - ``Date``/``TsOrDsToDate`` over an ISO-date string literal, with no extra
      expressions and no zone, becomes ``CAST(<literal> AS DATE)``.
    - ``Timestamp`` with no zone becomes a CAST of its argument to the node's
      inferred type, or to TIMESTAMP when no type could be inferred.

    Args:
        node: The node to rewrite.
        dialect: Dialect used to annotate a Timestamp node that has no type yet.

    Returns:
        The replacement CAST, or ``node`` when neither rule applies.
    """
    if isinstance(node, (exp.Date, exp.TsOrDsToDate)):
        if (
            not node.expressions
            and not node.args.get("zone")
            and node.this.is_string
            and is_iso_date(node.this.name)
        ):
            return exp.cast(node.this, to=exp.DataType.Type.DATE)

    if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
        if not node.type:
            # Imported lazily: only needed when the node still lacks a type.
            from sqlglot.optimizer.annotate_types import annotate_types

            node = annotate_types(node, dialect=dialect)
        return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)

    return node
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# Binary arithmetic and comparison nodes whose two operands coerce_type hands
# to _coerce_date, so a temporal operand paired with a text operand gets an
# explicit CAST.
COERCIBLE_DATE_OPS = (
    exp.Add, exp.Sub,
    exp.EQ, exp.NEQ, exp.GT, exp.GTE, exp.LT, exp.LTE,
    exp.NullSafeEQ, exp.NullSafeNEQ,
)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression:
    """Insert CASTs so the operands of date-related operations have compatible types.

    Handles the binary nodes in COERCIBLE_DATE_OPS, BETWEEN (its subject and
    lower bound), EXTRACT, DATE_ADD/DATE_SUB/DATE_TRUNC and DATEDIFF. Operands
    are replaced in place inside the tree.

    Args:
        node: The node to inspect.
        promote_to_inferred_datetime_type: Forwarded to _coerce_date, where it
            allows the cast target to come from a text operand's inferred type
            instead of always being the temporal operand's type.

    Returns:
        ``node`` itself.
    """
    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right, promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Between):
        _coerce_date(node.this, node.args["low"], promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Extract):
        # EXTRACT needs a temporal operand. An operand with no inferred type
        # (e.g. an un-annotated tree) is left alone rather than raising
        # AttributeError on ``None.is_type``, like the other coercion helpers.
        operand_type = node.expression.type
        if operand_type and not operand_type.is_type(*exp.DataType.TEMPORAL_TYPES):
            _replace_cast(node.expression, exp.DataType.Type.DATETIME)
    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
    elif isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
    """Drop conversions that do not change their operand's type.

    - ``CAST(x AS T)`` where ``x`` is already typed ``T`` becomes ``x``.
    - ``Date(x)``/``TsOrDsToDate(x)`` where ``x`` is a plain DATE (no type
      parameters) becomes ``x``.

    Returns:
        The unwrapped operand, or ``expression`` when nothing is redundant.
    """
    if isinstance(expression, exp.Cast):
        source_type = expression.this.type
        if source_type and expression.to == source_type:
            return expression.this

    if isinstance(expression, (exp.Date, exp.TsOrDsToDate)):
        source_type = expression.this.type
        if (
            source_type
            and source_type.this == exp.DataType.Type.DATE
            and not source_type.expressions
        ):
            return expression.this

    return expression
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def ensure_bools(
    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
    """Apply ``replace_func`` to each operand of ``expression`` that must be a predicate.

    Those operands are both sides of a Connector, the operand of NOT, the
    condition of an IF (unless it is a WHEN of a simple ``CASE x ...``), and the
    condition of WHERE/HAVING. Operands are read one at a time, in that order.

    Returns:
        ``expression`` itself.
    """
    if isinstance(expression, exp.Connector):
        operand_names = ("left", "right")
    elif isinstance(expression, exp.Not):
        operand_names = ("this",)
    elif isinstance(expression, exp.If) and not (
        isinstance(expression.parent, exp.Case) and expression.parent.this
    ):
        # In ``CASE x WHEN num ...`` the WHEN value is not the full predicate,
        # so only IFs outside such a CASE are handled.
        operand_names = ("this",)
    elif isinstance(expression, (exp.Where, exp.Having)):
        operand_names = ("this",)
    else:
        operand_names = ()

    for name in operand_names:
        replace_func(getattr(expression, name))

    return expression
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    """Strip an explicit ASC from an ORDER BY term (``ORDER BY a ASC`` -> ``ORDER BY a``).

    Only ``Ordered`` nodes whose ``desc`` arg is exactly ``False`` are changed,
    in place; the node is returned either way.
    """
    explicitly_ascending = (
        isinstance(expression, exp.Ordered) and expression.args.get("desc") is False
    )
    if explicitly_ascending:
        expression.set("desc", None)
    return expression
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _coerce_date(
    a: exp.Expression,
    b: exp.Expression,
    promote_to_inferred_datetime_type: bool,
) -> None:
    """Cast a (temporal, text) operand pair to a common type, replacing nodes in place.

    Both orderings of the two operands are examined. For each ordering, if
    ``b`` is an INTERVAL, ``a`` is first passed through _coerce_timeunit_arg
    with the interval's unit. Then, only when ``a`` is typed with a temporal
    type and ``b`` with a text type, a target type is chosen. ``b`` is replaced
    by a CAST to that target, and ``a`` is also replaced when the target
    differs from ``a``'s type.

    Args:
        a: One operand.
        b: The other operand.
        promote_to_inferred_datetime_type: When false, the target is always
            ``a``'s type. When true, a candidate type is derived from ``b``:
            DATE or DATETIME for ISO date/datetime string literals, ``a``'s type
            for other string literals, and DATETIME for non-literals. The
            candidate is used when TypeAnnotator.COERCES_TO lists it for
            ``a``'s type.
    """
    for a, b in itertools.permutations([a, b]):
        if isinstance(b, exp.Interval):
            a = _coerce_timeunit_arg(a, b.unit)

        a_type = a.type
        # Only a (temporal, text) pair is coerced; every other pairing is skipped.
        if (
            not a_type
            or a_type.this not in exp.DataType.TEMPORAL_TYPES
            or not b.type
            or b.type.this not in exp.DataType.TEXT_TYPES
        ):
            continue

        if promote_to_inferred_datetime_type:
            if b.is_string:
                date_text = b.name
                if is_iso_date(date_text):
                    b_type = exp.DataType.Type.DATE
                elif is_iso_datetime(date_text):
                    b_type = exp.DataType.Type.DATETIME
                else:
                    b_type = a_type.this
            else:
                # If b is not a datetime string, we conservatively promote it to a DATETIME,
                # in order to ensure there are no surprising truncations due to downcasting
                b_type = exp.DataType.Type.DATETIME

            # Widen to b's candidate type only when a's type is listed as coercing to it.
            target_type = (
                b_type if b_type in TypeAnnotator.COERCES_TO.get(a_type.this, {}) else a_type
            )
        else:
            target_type = a_type

        if target_type != a_type:
            _replace_cast(a, target_type)

        _replace_cast(b, target_type)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
    """Cast ``arg`` to DATE or DATETIME to suit the granularity of ``unit``.

    - Text holding an ISO date, used with a date unit -> CAST(... AS DATE).
    - Other text holding an ISO date or ISO datetime -> CAST(... AS DATETIME).
    - A DATE used with a non-date unit -> CAST(... AS DATETIME).

    Returns:
        The CAST node that replaced ``arg`` in the tree, or ``arg`` itself when
        no rule applies (including when ``arg`` has no inferred type).
    """
    arg_type = arg.type
    if not arg_type:
        return arg

    target = None
    if arg_type.this in exp.DataType.TEXT_TYPES:
        text = arg.name
        looks_like_date = is_iso_date(text)
        if looks_like_date and is_date_unit(unit):
            target = exp.DataType.Type.DATE
        # Every ISO date is also a valid ISO datetime, but not the other way round.
        elif looks_like_date or is_iso_datetime(text):
            target = exp.DataType.Type.DATETIME
    elif arg_type.this == exp.DataType.Type.DATE and not is_date_unit(unit):
        target = exp.DataType.Type.DATETIME

    if target is None:
        return arg
    return arg.replace(exp.cast(arg.copy(), to=target))
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _coerce_datediff_args(node: exp.DateDiff) -> None:
    """Cast each DATEDIFF operand to DATETIME unless it is already temporal.

    Operands with no inferred type are left untouched instead of raising
    AttributeError on ``None.this``. This matches the ``if not arg.type`` guard
    in _coerce_timeunit_arg.
    """
    for e in (node.this, node.expression):
        if not e.type:
            continue
        if e.type.this not in exp.DataType.TEMPORAL_TYPES:
            e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME))
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _replace_cast(node: exp.Expression, to: exp.DATA_TYPE) -> None:
    """Swap ``node`` in its tree for ``CAST(<copy of node> AS to)``."""
    wrapped = exp.cast(node.copy(), to=to)
    node.replace(wrapped)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
# Originally designed for Presto; a similar transform exists for T-SQL. This one
# only touches integer-typed expressions, because Presto has a boolean type while
# T-SQL does not (people use bits there). Example of an illegal Presto query:
#   with y as (select true as x) select x = 0 FROM y
def _replace_int_predicate(expression: exp.Expression) -> None:
    """Rewrite an integer-typed predicate ``x`` into ``x <> 0`` in place.

    For COALESCE, the rewrite is applied recursively to each of its child
    expressions instead.
    """
    if not isinstance(expression, exp.Coalesce):
        if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
            expression.replace(expression.neq(0))
        return

    for child in expression.iter_expressions():
        _replace_int_predicate(child)
|
|
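package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_ctes.py
ADDED
@@ -0,0 +1,43 @@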
|
|
|
1
|
+
from sqlglot.optimizer.scope import Scope, build_scope
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def eliminate_ctes(expression):
    """
    Remove unused CTEs from an expression.

    A CTE whose reference count has dropped to zero is popped from its WITH
    clause. The WITH clause is dropped when it becomes empty, and the removed
    CTE's references to the scopes it selects from are released.

    Example:
        >>> import sqlglot
        >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_ctes(expression).sql()
        'SELECT a FROM z'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    root = build_scope(expression)
    if not root:
        return expression

    ref_count = root.ref_count()

    # Visit scopes in reverse traversal order so chains of unused CTEs can be removed
    for scope in reversed(list(root.traverse())):
        if not scope.is_cte or ref_count[id(scope)] > 0:
            continue

        cte_node = scope.expression.parent
        with_node = cte_node.parent
        cte_node.pop()

        # Drop the whole WITH clause once its last CTE is gone
        if with_node and len(with_node.expressions) <= 0:
            with_node.pop()

        # Release this CTE's references to the scopes it selects from
        for _, source in scope.selected_sources.values():
            if isinstance(source, Scope):
                ref_count[id(source)] -= 1

    return expression
|
|
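package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_joins.py
ADDED
@@ -0,0 +1,181 @@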
|
|
|
1
|
+
from sqlglot import expressions as exp
|
|
2
|
+
from sqlglot.optimizer.normalize import normalized
|
|
3
|
+
from sqlglot.optimizer.scope import Scope, traverse_scope
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def eliminate_joins(expression):
    """
    Remove unused joins from an expression.

    This only removes joins when we know that the join condition doesn't produce duplicate rows.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_joins(expression).sql()
        'SELECT x.a FROM x'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    for scope in traverse_scope(expression):
        # With unqualified columns we can't reliably tell whether a join is
        # referenced. That could probably be inferred from the outputs of
        # derived tables, but for now such scopes are skipped.
        if scope.unqualified_columns:
            continue

        # Walk joins from last to first so chains of unused joins can be removed
        for join in reversed(scope.expression.args.get("joins", [])):
            alias = join.alias_or_name
            if not _should_eliminate_join(scope, join, alias):
                continue
            join.pop()
            scope.remove_source(alias)
    return expression
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _should_eliminate_join(scope, join, alias):
    """Decide whether ``join`` (aliased ``alias``) can be dropped from ``scope``.

    All of the following must hold:
    - the joined source is a Scope (a subquery, not a physical table);
    - none of its columns are referenced outside the join's ON clause;
    - either it is a LEFT JOIN whose keys cover all of the source's unique
      outputs, or the join has no ON clause and the source is treated as
      single-row by _has_single_output_row.
    """
    inner_source = scope.sources.get(alias)
    if not isinstance(inner_source, Scope):
        return False
    if _join_is_used(scope, join, alias):
        return False
    if join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join):
        return True
    return not join.args.get("on") and _has_single_output_row(inner_source)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _join_is_used(scope, join, alias):
    """Return whether ``scope`` references ``alias`` anywhere outside ``join``'s ON clause."""
    on = join.args.get("on")
    # Column nodes inside the ON clause are matched by identity and ignored:
    # the join condition alone does not make the join "used".
    on_column_ids = set(map(id, on.find_all(exp.Column))) if on else set()
    outside_on = (col for col in scope.source_columns(alias) if id(col) not in on_column_ids)
    return any(outside_on)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _is_joined_on_all_unique_outputs(scope, join):
    """Return True if ``join``'s join keys cover every unique output column of ``scope``.

    Returns False when ``scope`` has no unique outputs at all.
    """
    unique_outputs = _unique_outputs(scope)
    if not unique_outputs:
        return False

    _, join_keys, _ = join_condition(join)
    joined_names = {key.name for key in join_keys}
    return unique_outputs <= joined_names
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _unique_outputs(scope):
    """Return the output names of ``scope`` whose combined values are unique per row.

    - SELECT DISTINCT: every named output.
    - GROUP BY: the outputs that are grouping expressions, but only if every
      grouping expression is projected; otherwise nothing.
    - A query _has_single_output_row accepts: every named output.
    - Otherwise: an empty set.
    """
    query = scope.expression
    if query.args.get("distinct"):
        return set(query.named_selects)

    group = query.args.get("group")
    if not group:
        if _has_single_output_row(scope):
            return set(query.named_selects)
        return set()

    grouped_expressions = set(group.expressions)
    grouped_outputs = set()
    unique_outputs = set()
    for select in query.selects:
        output = select.unalias()
        if output in grouped_expressions:
            grouped_outputs.add(output)
            unique_outputs.add(select.alias_or_name)

    # A grouping expression missing from the projection means the projected
    # grouped columns alone are not guaranteed unique.
    if grouped_expressions.difference(grouped_outputs):
        return set()
    return unique_outputs
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _has_single_output_row(scope):
    """Return True if the SELECT in ``scope`` is treated as producing a single row.

    That is the case when any of the following holds:
    - every projection is an aggregate and there is no GROUP BY (with a
      GROUP BY, aggregates yield one row per group, not one row);
    - the query has ``LIMIT 1``;
    - the query has no FROM clause.
    """
    query = scope.expression
    if not isinstance(query, exp.Select):
        return False
    return bool(
        (
            not query.args.get("group")
            and all(isinstance(e.unalias(), exp.AggFunc) for e in query.selects)
        )
        or _is_limit_1(scope)
        or not query.args.get("from")
    )
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _is_limit_1(scope):
    """Return True if the query in ``scope`` has ``LIMIT 1``.

    The node stored under the "limit" arg may lack an ``expression`` child
    (presumably e.g. a FETCH clause; confirm against the parser). Such nodes
    are treated as "not LIMIT 1" instead of raising AttributeError on
    ``None.this``. A bool is always returned.
    """
    limit = scope.expression.args.get("limit")
    count = limit.expression if limit else None
    return count is not None and count.this == "1"
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def join_condition(join):
    """
    Extract the join condition from a join expression.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression]:
            Tuple of (source key, join key, remaining predicate). The keys are the
            operand expressions of the extracted equalities. The remaining
            predicate is built from a copy of the ON clause (TRUE when there is
            none), with each extracted equality replaced by TRUE.
    """
    name = join.alias_or_name
    on = (join.args.get("on") or exp.true()).copy()
    source_key = []
    join_key = []

    def extract_condition(condition):
        # Record an equality only when exactly one side references the joined
        # table ``name``; that side is the join key and the other the source key.
        # The recorded equality is then replaced by TRUE in ``on``.
        left, right = condition.unnest_operands()
        left_tables = exp.column_table_names(left)
        right_tables = exp.column_table_names(right)

        if name in left_tables and name not in right_tables:
            join_key.append(left)
            source_key.append(right)
            condition.replace(exp.true())
        elif name in right_tables and name not in left_tables:
            join_key.append(right)
            source_key.append(left)
            condition.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        # Wrap a lone predicate in AND so flatten() yields the individual conjuncts.
        on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)

        for condition in on.flatten():
            if isinstance(condition, exp.EQ):
                extract_condition(condition)
    elif normalized(on, dnf=True):
        # Keep only equalities that appear in every disjunct. Every matching copy
        # is kept, so the same key can be recorded more than once.
        conditions = None

        for condition in on.flatten():
            parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
            if conditions is None:
                conditions = parts
            else:
                temp = []
                for p in parts:
                    cs = [c for c in conditions if p == c]

                    if cs:
                        temp.append(p)
                        temp.extend(cs)
                conditions = temp

        for condition in conditions:
            extract_condition(condition)

    return source_key, join_key, on
|
|
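package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py
ADDED
@@ -0,0 +1,189 @@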
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import itertools
|
|
4
|
+
import typing as t
|
|
5
|
+
|
|
6
|
+
from sqlglot import expressions as exp
|
|
7
|
+
from sqlglot.helper import find_new_name
|
|
8
|
+
from sqlglot.optimizer.scope import Scope, build_scope
|
|
9
|
+
|
|
10
|
+
if t.TYPE_CHECKING:
    # Existing CTE query expression -> that CTE's alias.
    ExistingCTEsMapping = t.Dict[exp.Expression, str]
    # Name already used in the query -> the Scope or Table that uses it.
    TakenNameMapping = t.Dict[str, t.Union[Scope, exp.Expression]]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def eliminate_subqueries(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite derived tables as CTEs, deduplicating if possible.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

    This also deduplicates common subqueries:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'

    Args:
        expression (sqlglot.Expression): expression
    Returns:
        sqlglot.Expression: the same expression, rewritten in place
    """
    if isinstance(expression, exp.Subquery):
        # The root itself can be a subquery, e.g. (SELECT * FROM x) LIMIT 1;
        # rewrite the wrapped query and keep the wrapper.
        eliminate_subqueries(expression.this)
        return expression

    root = build_scope(expression)
    if not root:
        return expression

    # alias -> Scope|Table for every name already in use (root CTE aliases and
    # table names); new CTE names must not collide with any of them.
    taken: TakenNameMapping = {}
    for cte_scope in root.cte_scopes:
        taken[cte_scope.expression.parent.alias] = cte_scope
    for scope in root.traverse():
        for source in scope.sources.values():
            if isinstance(source, exp.Table):
                taken[source.name] = source

    # CTE query -> alias for the root's existing CTEs, used for deduplication.
    existing_ctes: ExistingCTEsMapping = {}
    with_ = root.expression.args.get("with")
    recursive = with_.args.get("recursive") if with_ else False
    if with_:
        for cte in with_.expressions:
            existing_ctes[cte.this] = cte.alias

    new_ctes = []

    def _collect(scopes):
        # Hoist each scope (when eligible) and remember the resulting CTE.
        for candidate in scopes:
            new_cte = _eliminate(candidate, existing_ctes, taken)
            if new_cte:
                new_ctes.append(new_cte)

    # Preserve DAG order: derived tables inside an existing CTE become new CTEs
    # placed before that CTE, which is itself appended right after them.
    for cte_scope in root.cte_scopes:
        _collect(child for child in cte_scope.traverse() if child is not cte_scope)
        new_ctes.append(cte_scope.expression.parent)

    # Then everything reachable from the root's union, subquery and table scopes.
    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
        _collect(scope.traverse())

    if new_ctes:
        query = expression.expression if isinstance(expression, exp.DDL) else expression
        query.set("with", exp.With(expressions=new_ctes, recursive=recursive))

    return expression
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _eliminate(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
    """Hoist a single scope into a root-level CTE, if it is of a hoistable kind.

    Derived-table scopes are handed to ``_eliminate_derived_table`` and CTE
    scopes to ``_eliminate_cte`` (checked in that order). Any other scope is
    left alone.

    Returns:
        The new CTE produced by the chosen helper, or None when the scope is
        neither a derived table nor a CTE (or the helper itself returns None).
    """
    handler = None
    if scope.is_derived_table:
        handler = _eliminate_derived_table
    elif scope.is_cte:
        handler = _eliminate_cte

    return handler(scope, existing_ctes, taken) if handler is not None else None
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _eliminate_derived_table(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
    """Replace a derived-table scope with a reference to a (possibly shared) CTE.

    The outermost wrapping subquery is swapped for an aliased table reference
    to the CTE name chosen by ``_new_cte``. The subquery's alias is kept when
    present, and any ``joins`` attached to the wrapper are carried over.

    Returns:
        The new CTE to prepend to the root WITH clause. Returns None when the
        scope is skipped, or when ``_new_cte`` matched an existing CTE.
    """
    outer = scope.parent

    # Skip cases where the rewrite would be wrong:
    # - a pivoted subquery would lose its "pivot" arg
    # - a lateral correlated subquery cannot be hoisted
    if outer.pivots or isinstance(outer.expression, exp.Lateral):
        return None

    # unwrap() collapses redundant exp.Subquery wrappers so the whole wrapper
    # chain is replaced at once.
    # NOTE(review): this is captured *before* _new_cte runs. Building the new
    # exp.CTE there presumably re-parents scope.expression, so keep this order.
    wrapper = scope.expression.parent.unwrap()
    cte_name, cte = _new_cte(scope, existing_ctes, taken)

    replacement = exp.alias_(exp.table_(cte_name), alias=wrapper.alias or cte_name)
    replacement.set("joins", wrapper.args.get("joins"))
    wrapper.replace(replacement)

    return cte
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _eliminate_cte(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
    """Lift a nested CTE scope out of its local WITH clause into the root.

    The original CTE node is popped from its WITH clause, and the WITH clause
    itself is popped once it has no expressions left. Every table reference
    (within the parent scope's traversal) that resolves to this scope is then
    replaced by the name chosen by ``_new_cte``, aliased with the reference's
    previous alias or name.

    Returns:
        The new CTE to prepend to the root WITH clause, or None when
        ``_new_cte`` matched an existing CTE.
    """
    # NOTE(review): captured before _new_cte, which presumably re-parents
    # scope.expression when it builds the new exp.CTE. Keep this order.
    old_cte = scope.expression.parent
    cte_name, cte = _new_cte(scope, existing_ctes, taken)

    with_clause = old_cte.parent
    old_cte.pop()
    if not with_clause.expressions:
        with_clause.pop()

    # Rename references to this CTE. Keeping alias_or_name as the alias keeps
    # existing column qualifiers pointing at the right source.
    for child_scope in scope.parent.traverse():
        for table, source in child_scope.selected_sources.values():
            if source is not scope:
                continue
            table.replace(
                exp.alias_(exp.table_(cte_name), alias=table.alias_or_name, copy=False)
            )

    return cte
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _new_cte(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Tuple[str, t.Optional[exp.Expression]]:
    """
    Returns:
        tuple of (name, cte)
        where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
        If this CTE duplicates an existing CTE, `cte` will be None.

    Name selection:
        - If ``scope.expression`` is already a key of ``existing_ctes``, that
          alias is reused.
        - Otherwise the parent's alias is used (or a fresh ``cte``-based name
          when there is no alias). If that name is already taken, it is made
          unique with ``find_new_name``.

    Side effects:
        ``taken[name]`` is set to ``scope``. For non-duplicates,
        ``existing_ctes[scope.expression]`` is set to ``name``.
    """
    body = scope.expression
    reused_alias = existing_ctes.get(body)
    name = body.parent.alias or find_new_name(taken=taken, base="cte")

    if reused_alias:
        name = reused_alias
    elif taken.get(name):
        name = find_new_name(taken=taken, base=name)

    taken[name] = scope

    if reused_alias:
        return name, None

    # Register before building the CTE so later identical bodies dedupe onto it.
    existing_ctes[body] = name
    fresh_cte = exp.CTE(
        this=body,
        alias=exp.TableAlias(this=exp.to_identifier(name)),
    )
    return name, fresh_cte
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import typing as t
|
|
4
|
+
|
|
5
|
+
from sqlglot import alias, exp
|
|
6
|
+
from sqlglot.errors import OptimizeError
|
|
7
|
+
from sqlglot.optimizer.scope import traverse_scope
|
|
8
|
+
from sqlglot.schema import ensure_schema
|
|
9
|
+
|
|
10
|
+
if t.TYPE_CHECKING:
|
|
11
|
+
from sqlglot._typing import E
|
|
12
|
+
from sqlglot.schema import Schema
|
|
13
|
+
from sqlglot.dialects.dialect import DialectType
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def isolate_table_selects(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    dialect: DialectType = None,
) -> E:
    """Wrap qualifying tables in their own ``SELECT *`` subqueries.

    Only scopes that select from more than one source are processed. Within
    such a scope, a source qualifies when all of the following hold:
        - it is an ``exp.Table``
        - the schema knows its column names
        - it is not directly inside an ``exp.Subquery``
        - its grandparent is not an ``exp.Table``

    Each qualifying table is replaced in place by
    ``(SELECT * FROM <table> AS <alias>) AS <alias>``.

    Args:
        expression: the expression tree to rewrite (mutated in place).
        schema: table schema, normalized via ``ensure_schema``.
        dialect: dialect passed through to ``ensure_schema``.

    Returns:
        The same ``expression`` object after rewriting.

    Raises:
        OptimizeError: if a qualifying table has no alias.
    """
    resolved_schema = ensure_schema(schema, dialect=dialect)

    for scope in traverse_scope(expression):
        sources = scope.selected_sources
        # A scope that reads from a single source has nothing to isolate.
        if len(sources) == 1:
            continue

        for _, source in sources.values():
            assert source.parent

            qualifies = (
                isinstance(source, exp.Table)
                and resolved_schema.column_names(source)
                and not isinstance(source.parent, exp.Subquery)
                and not isinstance(source.parent.parent, exp.Table)
            )
            if not qualifies:
                continue

            if not source.alias:
                raise OptimizeError("Tables require an alias. Run qualify_tables optimization.")

            inner_select = exp.select("*").from_(
                alias(source, source.alias_or_name, table=True),
                copy=False,
            )
            source.replace(inner_select.subquery(source.alias, copy=False))

    return expression
|