altimate-code 0.5.2 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. package/CHANGELOG.md +12 -0
  2. package/bin/altimate +6 -0
  3. package/bin/altimate-code +6 -0
  4. package/dbt-tools/bin/altimate-dbt +2 -0
  5. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/__init__.py +0 -0
  6. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/fetch_schema.py +35 -0
  7. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/utils.py +353 -0
  8. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/validate_sql.py +114 -0
  9. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__init__.py +178 -0
  10. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__main__.py +96 -0
  11. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/_typing.py +17 -0
  12. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/__init__.py +3 -0
  13. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/__init__.py +18 -0
  14. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/_typing.py +18 -0
  15. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/column.py +332 -0
  16. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/dataframe.py +866 -0
  17. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/functions.py +1267 -0
  18. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/group.py +59 -0
  19. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/normalize.py +78 -0
  20. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/operations.py +53 -0
  21. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/readwriter.py +108 -0
  22. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/session.py +190 -0
  23. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/transforms.py +9 -0
  24. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/types.py +212 -0
  25. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/util.py +32 -0
  26. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/window.py +134 -0
  27. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/__init__.py +118 -0
  28. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/athena.py +166 -0
  29. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/bigquery.py +1331 -0
  30. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/clickhouse.py +1393 -0
  31. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/databricks.py +131 -0
  32. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dialect.py +1915 -0
  33. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/doris.py +561 -0
  34. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/drill.py +157 -0
  35. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/druid.py +20 -0
  36. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/duckdb.py +1159 -0
  37. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dune.py +16 -0
  38. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/hive.py +787 -0
  39. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/materialize.py +94 -0
  40. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/mysql.py +1324 -0
  41. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/oracle.py +378 -0
  42. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/postgres.py +778 -0
  43. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/presto.py +788 -0
  44. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/prql.py +203 -0
  45. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/redshift.py +448 -0
  46. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/risingwave.py +78 -0
  47. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/snowflake.py +1464 -0
  48. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark.py +202 -0
  49. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark2.py +349 -0
  50. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/sqlite.py +320 -0
  51. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/starrocks.py +343 -0
  52. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tableau.py +61 -0
  53. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/teradata.py +356 -0
  54. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/trino.py +115 -0
  55. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tsql.py +1403 -0
  56. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/diff.py +456 -0
  57. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/errors.py +93 -0
  58. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/__init__.py +95 -0
  59. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/context.py +101 -0
  60. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/env.py +246 -0
  61. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/python.py +460 -0
  62. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/table.py +155 -0
  63. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/expressions.py +8870 -0
  64. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/generator.py +4993 -0
  65. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/helper.py +582 -0
  66. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/jsonpath.py +227 -0
  67. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/lineage.py +423 -0
  68. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/__init__.py +11 -0
  69. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/annotate_types.py +589 -0
  70. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/canonicalize.py +222 -0
  71. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_ctes.py +43 -0
  72. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_joins.py +181 -0
  73. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py +189 -0
  74. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/isolate_table_selects.py +50 -0
  75. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/merge_subqueries.py +415 -0
  76. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize.py +200 -0
  77. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize_identifiers.py +64 -0
  78. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimize_joins.py +91 -0
  79. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimizer.py +94 -0
  80. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_predicates.py +222 -0
  81. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_projections.py +172 -0
  82. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify.py +104 -0
  83. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_columns.py +1024 -0
  84. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_tables.py +155 -0
  85. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/scope.py +904 -0
  86. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/simplify.py +1587 -0
  87. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/unnest_subqueries.py +302 -0
  88. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/parser.py +8501 -0
  89. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/planner.py +463 -0
  90. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/schema.py +588 -0
  91. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/serde.py +68 -0
  92. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/time.py +687 -0
  93. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/tokens.py +1520 -0
  94. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/transforms.py +1020 -0
  95. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/trie.py +81 -0
  96. package/dbt-tools/dist/altimate_python_packages/dbt_core_integration.py +825 -0
  97. package/dbt-tools/dist/altimate_python_packages/dbt_utils.py +157 -0
  98. package/dbt-tools/dist/index.js +23859 -0
  99. package/package.json +13 -13
  100. package/postinstall.mjs +42 -0
  101. package/skills/altimate-setup/SKILL.md +31 -0
@@ -0,0 +1,222 @@
1
+ from __future__ import annotations
2
+
3
+ import itertools
4
+ import typing as t
5
+
6
+ from sqlglot import exp
7
+ from sqlglot.dialects.dialect import Dialect, DialectType
8
+ from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime
9
+ from sqlglot.optimizer.annotate_types import TypeAnnotator
10
+
11
+
12
def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
    """Rewrite a SQL expression tree into a standard (canonical) form.

    This method relies on annotate_types because many of the
    conversions rely on type inference.

    Every node handed to the rewrite callback by ``exp.replace_tree`` goes through
    the same ordered pipeline: ``add_text_to_concat``, ``replace_date_funcs``,
    ``coerce_type``, ``remove_redundant_casts``, ``ensure_bools`` and finally
    ``remove_ascending_order``.

    Args:
        expression: The expression to canonicalize.
        dialect: The dialect (resolved through ``Dialect.get_or_raise``) whose
            ``PROMOTE_TO_INFERRED_DATETIME_TYPE`` setting drives date coercion.

    Returns:
        Whatever ``exp.replace_tree`` returns for the rewritten tree.
    """
    resolved_dialect = Dialect.get_or_raise(dialect)

    def _apply_rules(node: exp.Expression) -> exp.Expression:
        # The order of the rules is significant: each rule sees the output of the previous one.
        node = add_text_to_concat(node)
        node = replace_date_funcs(node, dialect=resolved_dialect)
        node = coerce_type(node, resolved_dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE)
        node = remove_redundant_casts(node)
        node = ensure_bools(node, _replace_int_predicate)
        return remove_ascending_order(node)

    return exp.replace_tree(expression, _apply_rules)
34
+
35
+
36
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    """Turn an ``Add`` whose annotated type is textual into a ``Concat``.

    Args:
        node: The expression to inspect.

    Returns:
        A new ``exp.Concat`` over the left and right operands when ``node`` is an
        ``exp.Add`` typed as one of ``exp.DataType.TEXT_TYPES``; otherwise ``node``.
    """
    if not isinstance(node, exp.Add):
        return node

    node_type = node.type
    if node_type and node_type.this in exp.DataType.TEXT_TYPES:
        return exp.Concat(expressions=[node.left, node.right])

    return node
40
+
41
+
42
def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression:
    """Replace simple date/timestamp function calls with equivalent casts.

    Args:
        node: The expression to inspect.
        dialect: Dialect forwarded to ``annotate_types`` when a ``Timestamp`` node
            has no type yet.

    Returns:
        * a cast to DATE for ``Date`` / ``TsOrDsToDate`` nodes that have no extra
          expressions, no ``zone`` arg and an ISO-date string argument;
        * a cast to the node's type (TIMESTAMP when it has none) for ``Timestamp``
          nodes without a ``zone`` arg;
        * ``node`` itself in every other case.
    """
    if isinstance(node, (exp.Date, exp.TsOrDsToDate)):
        is_plain_call = not node.expressions and not node.args.get("zone")
        if is_plain_call and node.this.is_string and is_iso_date(node.this.name):
            return exp.cast(node.this, to=exp.DataType.Type.DATE)

    if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
        if not node.type:
            # Imported lazily, exactly where the type information is needed.
            from sqlglot.optimizer.annotate_types import annotate_types

            node = annotate_types(node, dialect=dialect)

        cast_target = node.type or exp.DataType.Type.TIMESTAMP
        return exp.cast(node.this, to=cast_target)

    return node
59
+
60
+
61
# Binary operators whose operands coerce_type passes to _coerce_date.
COERCIBLE_DATE_OPS = (
    # arithmetic
    exp.Add, exp.Sub,
    # equality / ordering comparisons
    exp.EQ, exp.NEQ, exp.GT, exp.GTE, exp.LT, exp.LTE,
    # null-safe comparisons
    exp.NullSafeEQ, exp.NullSafeNEQ,
)
73
+
74
+
75
def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression:
    """Coerce the operands of date-related expressions, mutating the tree in place.

    Args:
        node: The expression whose operands may need casts.
        promote_to_inferred_datetime_type: Forwarded to ``_coerce_date``.

    Returns:
        ``node`` itself; any change is applied to its children via ``replace``.
    """
    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right, promote_to_inferred_datetime_type)
        return node

    if isinstance(node, exp.Between):
        # Only the BETWEEN subject and its lower bound are handed to _coerce_date.
        _coerce_date(node.this, node.args["low"], promote_to_inferred_datetime_type)
        return node

    if isinstance(node, exp.Extract) and not node.expression.type.is_type(
        *exp.DataType.TEMPORAL_TYPES
    ):
        # EXTRACT over a non-temporal operand: wrap the operand in a DATETIME cast.
        _replace_cast(node.expression, exp.DataType.Type.DATETIME)
        return node

    if isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
    elif isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
90
+
91
+
92
def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
    """Strip casts and date conversions that do not change the operand's type.

    Args:
        expression: The expression to inspect.

    Returns:
        The wrapped operand when ``expression`` is
        * a ``Cast`` whose target type equals the operand's annotated type, or
        * a ``Date`` / ``TsOrDsToDate`` whose operand is already typed as a plain
          DATE (a DATE type carrying no nested expressions);
        otherwise ``expression`` unchanged.
    """
    if isinstance(expression, exp.Cast):
        operand = expression.this
        if operand.type and expression.to == operand.type:
            return operand

    if isinstance(expression, (exp.Date, exp.TsOrDsToDate)):
        operand = expression.this
        operand_type = operand.type
        if (
            operand_type
            and operand_type.this == exp.DataType.Type.DATE
            and not operand_type.expressions
        ):
            return operand

    return expression
109
+
110
+
111
def ensure_bools(
    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
    """Call ``replace_func`` on each operand of ``expression`` used as a predicate.

    Args:
        expression: The expression to inspect.
        replace_func: Callback invoked on each predicate operand; it is expected to
            rewrite the operand in place.

    Returns:
        ``expression`` itself.
    """
    if isinstance(expression, exp.Connector):
        replace_func(expression.left)
        replace_func(expression.right)
        return expression

    if isinstance(expression, exp.Not):
        replace_func(expression.this)
        return expression

    if isinstance(expression, exp.If):
        owner = expression.parent
        # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
        is_simple_case_branch = isinstance(owner, exp.Case) and owner.this
        if not is_simple_case_branch:
            replace_func(expression.this)
            return expression

    if isinstance(expression, (exp.Where, exp.Having)):
        replace_func(expression.this)

    return expression
128
+
129
+
130
def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    """Drop an explicit ascending marker, e.g. ORDER BY a ASC becomes ORDER BY a.

    Args:
        expression: The expression to inspect.

    Returns:
        ``expression`` itself; an ``Ordered`` node whose ``desc`` arg is exactly
        ``False`` has that arg reset to ``None`` in place.
    """
    if not isinstance(expression, exp.Ordered):
        return expression

    # Only an explicit False counts as "ASC"; a missing/None value is left alone.
    if expression.args.get("desc") is False:
        expression.set("desc", None)

    return expression
136
+
137
+
138
def _coerce_date(
    a: exp.Expression,
    b: exp.Expression,
    promote_to_inferred_datetime_type: bool,
) -> None:
    """Insert casts so a temporal operand and a text operand agree on a type.

    Both orderings of ``(a, b)`` are examined. In each ordering the first operand
    must be typed as one of ``TEMPORAL_TYPES`` and the second as one of
    ``TEXT_TYPES``; otherwise that ordering is skipped.

    Args:
        a: One operand of the expression.
        b: The other operand of the expression.
        promote_to_inferred_datetime_type: When true, the cast target is inferred
            from the text operand (DATE / DATETIME for ISO literals, DATETIME for
            non-literals) provided ``TypeAnnotator.COERCES_TO`` lists it for the
            temporal operand's type; otherwise the temporal operand's type is used.
    """
    temporal_types = exp.DataType.TEMPORAL_TYPES
    text_types = exp.DataType.TEXT_TYPES

    for temporal, text in itertools.permutations([a, b]):
        if isinstance(text, exp.Interval):
            temporal = _coerce_timeunit_arg(temporal, text.unit)

        temporal_type = temporal.type
        if not (temporal_type and temporal_type.this in temporal_types):
            continue

        text_type = text.type
        if not (text_type and text_type.this in text_types):
            continue

        if not promote_to_inferred_datetime_type:
            cast_to = temporal_type
        else:
            if text.is_string:
                literal = text.name
                if is_iso_date(literal):
                    inferred = exp.DataType.Type.DATE
                elif is_iso_datetime(literal):
                    inferred = exp.DataType.Type.DATETIME
                else:
                    inferred = temporal_type.this
            else:
                # If b is not a datetime string, we conservatively promote it to a DATETIME,
                # in order to ensure there are no surprising truncations due to downcasting
                inferred = exp.DataType.Type.DATETIME

            coercible_targets = TypeAnnotator.COERCES_TO.get(temporal_type.this, {})
            cast_to = inferred if inferred in coercible_targets else temporal_type

        if cast_to != temporal_type:
            _replace_cast(temporal, cast_to)

        _replace_cast(text, cast_to)
180
+
181
+
182
def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
    """Cast the argument of a unit-based date function to DATE or DATETIME if needed.

    Args:
        arg: The date-like argument; left untouched when it has no annotated type.
        unit: The time unit the enclosing function operates on.

    Returns:
        The result of ``arg.replace(...)`` when a cast was inserted, else ``arg``.
    """
    arg_type = arg.type
    if not arg_type:
        return arg

    kind = arg_type.this
    cast_to = None

    if kind in exp.DataType.TEXT_TYPES:
        text = arg.name
        looks_like_date = is_iso_date(text)

        if looks_like_date and is_date_unit(unit):
            cast_to = exp.DataType.Type.DATE
        elif looks_like_date or is_iso_datetime(text):
            # An ISO date is also an ISO datetime, but not vice versa
            cast_to = exp.DataType.Type.DATETIME
    elif kind == exp.DataType.Type.DATE and not is_date_unit(unit):
        # A DATE combined with a non-date unit is widened to DATETIME.
        cast_to = exp.DataType.Type.DATETIME

    if cast_to is None:
        return arg

    return arg.replace(exp.cast(arg.copy(), to=cast_to))
201
+
202
+
203
def _coerce_datediff_args(node: exp.DateDiff) -> None:
    """Cast each non-temporal operand of a ``DateDiff`` to DATETIME, in place.

    Args:
        node: The ``DateDiff`` whose ``this`` and ``expression`` operands are checked.
            Both operands are read through ``.type.this``, so they are expected to
            carry a type already.
    """
    for operand in (node.this, node.expression):
        if operand.type.this in exp.DataType.TEMPORAL_TYPES:
            continue

        datetime_cast = exp.cast(operand.copy(), to=exp.DataType.Type.DATETIME)
        operand.replace(datetime_cast)
207
+
208
+
209
def _replace_cast(node: exp.Expression, to: exp.DATA_TYPE) -> None:
    """Replace ``node`` in its tree with a cast of a copy of itself to ``to``.

    Args:
        node: The expression to wrap.
        to: The target type handed to ``exp.cast``.
    """
    wrapped = exp.cast(node.copy(), to=to)
    node.replace(wrapped)
211
+
212
+
213
# This was originally designed for presto; there is a similar transform for tsql.
# It differs in that it only operates on int types, because presto has a boolean
# type whereas tsql doesn't (people use bits), e.g.
# with y as (select true as x) select x = 0 FROM y -- illegal presto query
def _replace_int_predicate(expression: exp.Expression) -> None:
    """Rewrite an integer-typed predicate ``x`` as ``x <> 0``, in place.

    ``Coalesce`` nodes are not rewritten themselves; the rule is applied
    recursively to each of their child expressions instead.

    Args:
        expression: The predicate operand to rewrite.
    """
    if isinstance(expression, exp.Coalesce):
        for argument in expression.iter_expressions():
            _replace_int_predicate(argument)
        return

    expression_type = expression.type
    if expression_type and expression_type.this in exp.DataType.INTEGER_TYPES:
        expression.replace(expression.neq(0))
@@ -0,0 +1,43 @@
1
+ from sqlglot.optimizer.scope import Scope, build_scope
2
+
3
+
4
def eliminate_ctes(expression):
    """
    Remove unused CTEs from an expression.

    Example:
        >>> import sqlglot
        >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_ctes(expression).sql()
        'SELECT a FROM z'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression (the same object, modified in place)
    """
    root_scope = build_scope(expression)
    if not root_scope:
        return expression

    references = root_scope.ref_count()

    # Traverse the scope tree in reverse so we can remove chains of unused CTEs
    for scope in reversed(list(root_scope.traverse())):
        if not scope.is_cte:
            continue

        if references[id(scope)] > 0:
            continue

        cte = scope.expression.parent
        with_clause = cte.parent
        cte.pop()

        # Pop the entire WITH clause if this is the last CTE
        if with_clause and not with_clause.expressions:
            with_clause.pop()

        # Decrement the ref count for all sources this CTE selects from
        for _, selected in scope.selected_sources.values():
            if isinstance(selected, Scope):
                references[id(selected)] -= 1

    return expression
@@ -0,0 +1,181 @@
1
+ from sqlglot import expressions as exp
2
+ from sqlglot.optimizer.normalize import normalized
3
+ from sqlglot.optimizer.scope import Scope, traverse_scope
4
+
5
+
6
def eliminate_joins(expression):
    """
    Remove unused joins from an expression.

    This only removes joins when we know that the join condition doesn't produce duplicate rows.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_joins(expression).sql()
        'SELECT x.a FROM x'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression (the same object, modified in place)
    """
    for scope in traverse_scope(expression):
        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
        # It's probably possible to infer this from the outputs of derived tables.
        # But for now, let's just skip this rule.
        if scope.unqualified_columns:
            continue

        scope_joins = scope.expression.args.get("joins", [])

        # Reverse the joins so we can remove chains of unused joins
        for candidate in reversed(scope_joins):
            source_name = candidate.alias_or_name
            if not _should_eliminate_join(scope, candidate, source_name):
                continue

            candidate.pop()
            scope.remove_source(source_name)

    return expression
40
+
41
+
42
def _should_eliminate_join(scope, join, alias):
    """Decide whether ``join`` can be dropped from ``scope``.

    A join qualifies when its source is a nested ``Scope``, nothing outside its ON
    clause references it, and either it is a LEFT join on all unique outputs of
    that source or it has no ON clause and the source yields a single row.

    Args:
        scope: The scope that owns the join.
        join: The join expression under consideration.
        alias: The name under which the joined source is registered in ``scope``.

    Returns:
        A truthy value when the join may be removed, a falsy one otherwise.
    """
    joined_source = scope.sources.get(alias)

    if not isinstance(joined_source, Scope):
        return False

    if _join_is_used(scope, join, alias):
        return False

    return (join.side == "LEFT" and _is_joined_on_all_unique_outputs(joined_source, join)) or (
        not join.args.get("on") and _has_single_output_row(joined_source)
    )
52
+
53
+
54
def _join_is_used(scope, join, alias):
    """Tell whether any column outside the join's ON clause references the join.

    Args:
        scope: The scope that owns the join.
        join: The join expression under consideration.
        alias: The source name whose column references are inspected.

    Returns:
        bool: True if at least one such column reference exists.
    """
    # We need to find all columns that reference this join.
    # But columns in the ON clause shouldn't count.
    on_clause = join.args.get("on")
    ignored_ids = {id(col) for col in on_clause.find_all(exp.Column)} if on_clause else set()

    for col in scope.source_columns(alias):
        if id(col) not in ignored_ids and col:
            return True

    return False
65
+
66
+
67
def _is_joined_on_all_unique_outputs(scope, join):
    """Check that the join keys cover every unique output column of ``scope``.

    Args:
        scope: The scope of the joined source.
        join: The join expression whose keys are extracted with ``join_condition``.

    Returns:
        bool: False when ``scope`` has no unique outputs; otherwise True exactly
        when each unique output name appears among the join key names.
    """
    unique_names = _unique_outputs(scope)
    if not unique_names:
        return False

    join_keys = join_condition(join)[1]
    key_names = {key.name for key in join_keys}

    # Every unique output must be matched by a join key.
    return unique_names <= key_names
75
+
76
+
77
def _unique_outputs(scope):
    """Determine output columns of `scope` that must have a unique combination per row

    Returns:
        set: output names; empty when uniqueness cannot be established.
    """
    query = scope.expression

    if query.args.get("distinct"):
        return set(query.named_selects)

    group = query.args.get("group")
    if group:
        group_keys = set(group.expressions)
        projected_keys = set()
        output_names = set()

        for projection in query.selects:
            projected = projection.unalias()
            if projected in group_keys:
                projected_keys.add(projected)
                output_names.add(projection.alias_or_name)

        # All the grouped expressions must be in the output
        if group_keys <= projected_keys:
            return output_names
        return set()

    if _has_single_output_row(scope):
        return set(query.named_selects)

    return set()
104
+
105
+
106
def _has_single_output_row(scope):
    """Tell whether the query of ``scope`` is a SELECT that yields a single row.

    That is assumed when every projection is an aggregate function, when the query
    has ``LIMIT 1``, or when it has no FROM clause.

    Returns:
        A truthy value when one of those conditions holds, a falsy one otherwise.
    """
    query = scope.expression
    if not isinstance(query, exp.Select):
        return False

    if all(isinstance(projection.unalias(), exp.AggFunc) for projection in query.selects):
        return True

    return _is_limit_1(scope) or not query.args.get("from")
112
+
113
+
114
def _is_limit_1(scope):
    """Tell whether the query of ``scope`` has a limit whose value is the string "1".

    Returns:
        The falsy ``limit`` arg itself when there is no limit, otherwise a bool.
    """
    limit_clause = scope.expression.args.get("limit")
    if not limit_clause:
        return limit_clause

    return limit_clause.expression.this == "1"
117
+
118
+
119
def join_condition(join):
    """
    Extract the join condition from a join expression.

    Equality predicates that compare the joined table with something else are
    pulled out of a copy of the ON clause and replaced there by TRUE.

    Args:
        join (exp.Join)
    Returns:
        tuple[list, list, exp.Expression]:
            Tuple of (source key, join key, remaining predicate). The two lists
            hold the operand expressions of the extracted equalities.
    """
    joined_name = join.alias_or_name
    remaining = (join.args.get("on") or exp.true()).copy()
    source_key = []
    join_key = []

    def _extract(equality):
        lhs, rhs = equality.unnest_operands()
        joined_in_lhs = joined_name in exp.column_table_names(lhs)
        joined_in_rhs = joined_name in exp.column_table_names(rhs)

        # Only extract when exactly one side refers to the joined table.
        if joined_in_lhs and not joined_in_rhs:
            join_side, source_side = lhs, rhs
        elif joined_in_rhs and not joined_in_lhs:
            join_side, source_side = rhs, lhs
        else:
            return

        join_key.append(join_side)
        source_key.append(source_side)
        equality.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(remaining):
        if not isinstance(remaining, exp.And):
            remaining = exp.and_(remaining, exp.true(), copy=False)

        for predicate in remaining.flatten():
            if isinstance(predicate, exp.EQ):
                _extract(predicate)
    elif normalized(remaining, dnf=True):
        shared = None

        for branch in remaining.flatten():
            equalities = [part for part in branch.flatten() if isinstance(part, exp.EQ)]

            if shared is None:
                shared = equalities
                continue

            # Keep only equalities that also occur in every branch seen so far.
            matched = []
            for equality in equalities:
                same = [known for known in shared if equality == known]
                if same:
                    matched.append(equality)
                    matched.extend(same)
            shared = matched

        for equality in shared:
            _extract(equality)

    return source_key, join_key, remaining
@@ -0,0 +1,189 @@
1
+ from __future__ import annotations
2
+
3
+ import itertools
4
+ import typing as t
5
+
6
+ from sqlglot import expressions as exp
7
+ from sqlglot.helper import find_new_name
8
+ from sqlglot.optimizer.scope import Scope, build_scope
9
+
10
# Aliases used only in annotations, so they are defined for type checkers alone.
if t.TYPE_CHECKING:
    # CTE body expression -> alias of the CTE that already holds it
    ExistingCTEsMapping = t.Dict[exp.Expression, str]
    # name already in use -> the scope or table expression that uses it
    TakenNameMapping = t.Dict[str, t.Union[Scope, exp.Expression]]
13
+
14
+
15
def eliminate_subqueries(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite derived tables as CTES, deduplicating if possible.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

    This also deduplicates common subqueries:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'

    Args:
        expression (sqlglot.Expression): expression
    Returns:
        sqlglot.Expression: expression (the same object, modified in place)
    """
    if isinstance(expression, exp.Subquery):
        # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
        eliminate_subqueries(expression.this)
        return expression

    root = build_scope(expression)
    if not root:
        return expression

    # Map of alias->Scope|Table
    # These are all aliases that are already used in the expression.
    # We don't want to create new CTEs that conflict with these names.
    taken: TakenNameMapping = {}

    # All CTE aliases in the root scope are taken
    for cte_scope in root.cte_scopes:
        taken[cte_scope.expression.parent.alias] = cte_scope

    # All table names are taken
    for visited in root.traverse():
        for source in visited.sources.values():
            if isinstance(source, exp.Table):
                taken[source.name] = source

    # Map of Expression->alias
    # Existing CTES in the root expression. We'll use this for deduplication.
    existing_ctes: ExistingCTEsMapping = {}

    with_clause = root.expression.args.get("with")
    is_recursive = False
    if with_clause:
        is_recursive = with_clause.args.get("recursive")
        for cte in with_clause.expressions:
            existing_ctes[cte.this] = cte.alias

    ordered_ctes = []

    # We're adding more CTEs, but we want to maintain the DAG order.
    # Derived tables within an existing CTE need to come before the existing CTE.
    for cte_scope in root.cte_scopes:
        # Append all the new CTEs from this existing CTE
        for inner_scope in cte_scope.traverse():
            if inner_scope is cte_scope:
                # Don't try to eliminate this CTE itself
                continue

            created = _eliminate(inner_scope, existing_ctes, taken)
            if created:
                ordered_ctes.append(created)

        # Append the existing CTE itself
        ordered_ctes.append(cte_scope.expression.parent)

    # Now append the rest
    remaining_scopes = itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes)
    for top_scope in remaining_scopes:
        for inner_scope in top_scope.traverse():
            created = _eliminate(inner_scope, existing_ctes, taken)
            if created:
                ordered_ctes.append(created)

    if ordered_ctes:
        # For DDL statements the WITH clause is attached to the wrapped query.
        target = expression.expression if isinstance(expression, exp.DDL) else expression
        target.set("with", exp.With(expressions=ordered_ctes, recursive=is_recursive))

    return expression
103
+
104
+
105
def _eliminate(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
    """Dispatch ``scope`` to the matching elimination routine.

    Args:
        scope: The scope to rewrite.
        existing_ctes: Map of CTE body expression to its alias, updated in place.
        taken: Map of names already in use, updated in place.

    Returns:
        The new CTE produced for a derived-table or CTE scope (possibly None when
        it duplicates an existing one), or None for any other kind of scope.
    """
    handler = None
    if scope.is_derived_table:
        handler = _eliminate_derived_table
    elif scope.is_cte:
        handler = _eliminate_cte

    if handler is None:
        return None

    return handler(scope, existing_ctes, taken)
115
+
116
+
117
def _eliminate_derived_table(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
    """Replace a derived table with a reference to a (possibly new) CTE.

    Args:
        scope: Scope of the derived table.
        existing_ctes: Map of CTE body expression to its alias, updated in place.
        taken: Map of names already in use, updated in place.

    Returns:
        The new CTE, or None when nothing was rewritten or an existing CTE is reused.
    """
    enclosing = scope.parent

    # This makes sure that we don't:
    # - drop the "pivot" arg from a pivoted subquery
    # - eliminate a lateral correlated subquery
    if enclosing.pivots or isinstance(enclosing.expression, exp.Lateral):
        return None

    # Get rid of redundant exp.Subquery expressions, i.e. those that are just used as wrappers
    subquery = scope.expression.parent.unwrap()
    cte_name, new_cte = _new_cte(scope, existing_ctes, taken)

    # The reference keeps the subquery's alias (or the CTE name) and its joins.
    reference = exp.alias_(exp.table_(cte_name), alias=subquery.alias or cte_name)
    reference.set("joins", subquery.args.get("joins"))
    subquery.replace(reference)

    return new_cte
135
+
136
+
137
def _eliminate_cte(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
    """Detach a nested CTE and point its references at the (possibly new) CTE name.

    Args:
        scope: Scope of the CTE being lifted.
        existing_ctes: Map of CTE body expression to its alias, updated in place.
        taken: Map of names already in use, updated in place.

    Returns:
        The new CTE, or None when an existing CTE is reused.
    """
    # Capture the current CTE node before _new_cte is called on the scope.
    old_cte = scope.expression.parent
    cte_name, new_cte = _new_cte(scope, existing_ctes, taken)

    with_clause = old_cte.parent
    old_cte.pop()
    if not with_clause.expressions:
        # The WITH clause is now empty, so remove it as well.
        with_clause.pop()

    # Rename references to this CTE
    for sibling_scope in scope.parent.traverse():
        for table, source in sibling_scope.selected_sources.values():
            if source is not scope:
                continue

            renamed = exp.alias_(exp.table_(cte_name), alias=table.alias_or_name, copy=False)
            table.replace(renamed)

    return new_cte
156
+
157
+
158
def _new_cte(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Tuple[str, t.Optional[exp.Expression]]:
    """
    Pick a root-level name for the query of `scope` and build its CTE if needed.

    Both `taken` and `existing_ctes` are updated in place.

    Returns:
        tuple of (name, cte)
        where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
        If this CTE duplicates an existing CTE, `cte` will be None.
    """
    query = scope.expression
    reused_alias = existing_ctes.get(query)

    # Start from the current alias, or generate one when there is none.
    name = query.parent.alias
    if not name:
        name = find_new_name(taken=taken, base="cte")

    if reused_alias:
        # Same body as an existing CTE: reuse its name and create nothing.
        taken[reused_alias] = scope
        return reused_alias, None

    if taken.get(name):
        name = find_new_name(taken=taken, base=name)

    taken[name] = scope
    existing_ctes[query] = name

    new_cte = exp.CTE(
        this=query,
        alias=exp.TableAlias(this=exp.to_identifier(name)),
    )
    return name, new_cte
@@ -0,0 +1,50 @@
1
+ from __future__ import annotations
2
+
3
+ import typing as t
4
+
5
+ from sqlglot import alias, exp
6
+ from sqlglot.errors import OptimizeError
7
+ from sqlglot.optimizer.scope import traverse_scope
8
+ from sqlglot.schema import ensure_schema
9
+
10
+ if t.TYPE_CHECKING:
11
+ from sqlglot._typing import E
12
+ from sqlglot.schema import Schema
13
+ from sqlglot.dialects.dialect import DialectType
14
+
15
+
16
def isolate_table_selects(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    dialect: DialectType = None,
) -> E:
    """Wrap tables with known columns in ``SELECT *`` subqueries, in place.

    Scopes that select from exactly one source are left alone. In the remaining
    scopes a table is wrapped only if the schema knows its columns and it is
    neither directly inside a subquery nor nested inside another table expression.

    Args:
        expression: The expression to rewrite.
        schema: Schema (or mapping accepted by ``ensure_schema``) used to look up
            column names.
        dialect: Dialect forwarded to ``ensure_schema``.

    Returns:
        ``expression`` itself.

    Raises:
        OptimizeError: If a table that should be wrapped has no alias.
    """
    schema = ensure_schema(schema, dialect=dialect)

    for scope in traverse_scope(expression):
        if len(scope.selected_sources) == 1:
            continue

        for _, source in scope.selected_sources.values():
            assert source.parent

            should_isolate = (
                isinstance(source, exp.Table)
                and schema.column_names(source)
                and not isinstance(source.parent, exp.Subquery)
                and not isinstance(source.parent.parent, exp.Table)
            )
            if not should_isolate:
                continue

            if not source.alias:
                raise OptimizeError("Tables require an alias. Run qualify_tables optimization.")

            # SELECT * FROM <table> AS <alias>, itself aliased as <alias>.
            star_select = exp.select("*").from_(
                alias(source, source.alias_or_name, table=True),
                copy=False,
            )
            source.replace(star_select.subquery(source.alias, copy=False))

    return expression