altimate-code 0.5.2 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. package/CHANGELOG.md +12 -0
  2. package/bin/altimate +6 -0
  3. package/bin/altimate-code +6 -0
  4. package/dbt-tools/bin/altimate-dbt +2 -0
  5. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/__init__.py +0 -0
  6. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/fetch_schema.py +35 -0
  7. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/utils.py +353 -0
  8. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/validate_sql.py +114 -0
  9. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__init__.py +178 -0
  10. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__main__.py +96 -0
  11. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/_typing.py +17 -0
  12. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/__init__.py +3 -0
  13. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/__init__.py +18 -0
  14. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/_typing.py +18 -0
  15. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/column.py +332 -0
  16. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/dataframe.py +866 -0
  17. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/functions.py +1267 -0
  18. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/group.py +59 -0
  19. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/normalize.py +78 -0
  20. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/operations.py +53 -0
  21. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/readwriter.py +108 -0
  22. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/session.py +190 -0
  23. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/transforms.py +9 -0
  24. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/types.py +212 -0
  25. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/util.py +32 -0
  26. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/window.py +134 -0
  27. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/__init__.py +118 -0
  28. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/athena.py +166 -0
  29. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/bigquery.py +1331 -0
  30. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/clickhouse.py +1393 -0
  31. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/databricks.py +131 -0
  32. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dialect.py +1915 -0
  33. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/doris.py +561 -0
  34. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/drill.py +157 -0
  35. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/druid.py +20 -0
  36. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/duckdb.py +1159 -0
  37. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dune.py +16 -0
  38. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/hive.py +787 -0
  39. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/materialize.py +94 -0
  40. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/mysql.py +1324 -0
  41. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/oracle.py +378 -0
  42. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/postgres.py +778 -0
  43. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/presto.py +788 -0
  44. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/prql.py +203 -0
  45. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/redshift.py +448 -0
  46. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/risingwave.py +78 -0
  47. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/snowflake.py +1464 -0
  48. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark.py +202 -0
  49. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark2.py +349 -0
  50. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/sqlite.py +320 -0
  51. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/starrocks.py +343 -0
  52. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tableau.py +61 -0
  53. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/teradata.py +356 -0
  54. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/trino.py +115 -0
  55. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tsql.py +1403 -0
  56. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/diff.py +456 -0
  57. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/errors.py +93 -0
  58. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/__init__.py +95 -0
  59. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/context.py +101 -0
  60. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/env.py +246 -0
  61. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/python.py +460 -0
  62. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/table.py +155 -0
  63. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/expressions.py +8870 -0
  64. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/generator.py +4993 -0
  65. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/helper.py +582 -0
  66. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/jsonpath.py +227 -0
  67. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/lineage.py +423 -0
  68. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/__init__.py +11 -0
  69. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/annotate_types.py +589 -0
  70. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/canonicalize.py +222 -0
  71. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_ctes.py +43 -0
  72. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_joins.py +181 -0
  73. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py +189 -0
  74. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/isolate_table_selects.py +50 -0
  75. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/merge_subqueries.py +415 -0
  76. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize.py +200 -0
  77. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize_identifiers.py +64 -0
  78. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimize_joins.py +91 -0
  79. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimizer.py +94 -0
  80. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_predicates.py +222 -0
  81. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_projections.py +172 -0
  82. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify.py +104 -0
  83. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_columns.py +1024 -0
  84. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_tables.py +155 -0
  85. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/scope.py +904 -0
  86. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/simplify.py +1587 -0
  87. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/unnest_subqueries.py +302 -0
  88. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/parser.py +8501 -0
  89. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/planner.py +463 -0
  90. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/schema.py +588 -0
  91. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/serde.py +68 -0
  92. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/time.py +687 -0
  93. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/tokens.py +1520 -0
  94. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/transforms.py +1020 -0
  95. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/trie.py +81 -0
  96. package/dbt-tools/dist/altimate_python_packages/dbt_core_integration.py +825 -0
  97. package/dbt-tools/dist/altimate_python_packages/dbt_utils.py +157 -0
  98. package/dbt-tools/dist/index.js +23859 -0
  99. package/package.json +13 -13
  100. package/postinstall.mjs +42 -0
  101. package/skills/altimate-setup/SKILL.md +31 -0
@@ -0,0 +1,415 @@
1
+ from __future__ import annotations
2
+
3
+ import typing as t
4
+
5
+ from collections import defaultdict
6
+
7
+ from sqlglot import expressions as exp
8
+ from sqlglot.helper import find_new_name, seq_get
9
+ from sqlglot.optimizer.scope import Scope, traverse_scope
10
+
11
+ if t.TYPE_CHECKING:
12
+ from sqlglot._typing import E
13
+
14
# The outer-query clause that holds a mergeable source: either its FROM or one of its JOINs.
FromOrJoin = t.Union[exp.From, exp.Join]
15
+
16
+
17
def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Fold derived tables into the query that selects from them.

    CTEs are folded as well, provided they are selected from exactly once.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x CROSS JOIN y'

    If `leave_tables_isolated` is True, this will not merge inner queries into outer
    queries if it would result in multiple table selects in a single query:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): when True, skip merges that would leave the
            outer query selecting from more than one source
    Returns:
        sqlglot.Expression: optimized expression
    """
    # CTEs are merged first, then the derived tables of the resulting tree.
    without_ctes = merge_ctes(expression, leave_tables_isolated)
    return merge_derived_tables(without_ctes, leave_tables_isolated)
46
+
47
+
48
# Every Select arg except the ones a merge knows how to carry over. A derived table
# that sets any arg in this set can't be merged.
UNMERGABLE_ARGS = set(exp.Select.arg_types).difference(
    ("expressions", "from", "joins", "where", "order", "hint")
)
57
+
58
+
59
# Inner projections of these types can be substituted into the outer query as-is:
# no parentheses are needed because operator precedence can't be altered by them.
SAFE_TO_REPLACE_UNWRAPPED = (exp.Column, exp.EQ, exp.Func, exp.NEQ, exp.Paren)
68
+
69
+
70
def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Merge every CTE that is selected from exactly once into the query selecting from it.

    Args:
        expression: expression to optimize (modified in place)
        leave_tables_isolated: forwarded to the mergeability check
    Returns:
        The same expression, after merging.
    """
    # Group every "select from a CTE" occurrence by the CTE scope it targets, so that
    # CTEs referenced more than once can be told apart from single-use ones.
    selections_by_cte = defaultdict(list)
    for selecting_scope in traverse_scope(expression):
        for source_table, source_scope in selecting_scope.selected_sources.values():
            if not isinstance(source_scope, Scope):
                continue
            if source_scope.is_cte:
                selections_by_cte[id(source_scope)].append(
                    (selecting_scope, source_scope, source_table)
                )

    single_use = [hits[0] for hits in selections_by_cte.values() if len(hits) == 1]

    for outer_scope, inner_scope, table in single_use:
        from_or_join = table.find_ancestor(exp.From, exp.Join)
        if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
            continue

        alias = table.alias_or_name
        _rename_inner_sources(outer_scope, inner_scope, alias)
        _merge_from(outer_scope, inner_scope, table, alias)
        _merge_expressions(outer_scope, inner_scope, alias)
        _merge_order(outer_scope, inner_scope)
        _merge_joins(outer_scope, inner_scope, from_or_join)
        _merge_where(outer_scope, inner_scope, from_or_join)
        _merge_hints(outer_scope, inner_scope)
        _pop_cte(inner_scope)
        outer_scope.clear_cache()

    return expression
102
+
103
+
104
def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Merge each mergeable derived table (subquery in FROM/JOIN) into its outer query.

    Args:
        expression: expression to optimize (modified in place)
        leave_tables_isolated: forwarded to the mergeability check
    Returns:
        The same expression, after merging.
    """
    for outer_scope in traverse_scope(expression):
        for subquery in outer_scope.derived_tables:
            from_or_join = subquery.find_ancestor(exp.From, exp.Join)
            alias = subquery.alias_or_name
            inner_scope = outer_scope.sources[alias]

            if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
                continue

            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, subquery, alias)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_order(outer_scope, inner_scope)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_hints(outer_scope, inner_scope)
            outer_scope.clear_cache()

    return expression
121
+
122
+
123
def _mergeable(
    outer_scope: Scope, inner_scope: Scope, leave_tables_isolated: bool, from_or_join: FromOrJoin
) -> bool:
    """
    Return True if the inner query can be merged into the outer query.

    Args:
        outer_scope: scope of the query that selects from the inner query
        inner_scope: scope of the derived table / CTE that would be merged
        leave_tables_isolated: when True, refuse the merge if the outer query selects
            from more than one source
        from_or_join: the outer FROM or JOIN clause that contains the inner query
    """
    inner_select = inner_scope.expression.unnest()

    def _is_a_window_expression_in_unmergable_operation():
        # A window projection of the inner query can't be inlined into these outer
        # clauses (WHERE, GROUP BY, ORDER BY, JOIN, HAVING, aggregate arguments).
        window_aliases = {s.alias_or_name for s in inner_select.selects if s.find(exp.Window)}
        inner_select_name = from_or_join.alias_or_name
        unmergable_window_columns = [
            column
            for column in outer_scope.columns
            if column.find_ancestor(
                exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
            )
        ]
        window_expressions_in_unmergable = [
            column
            for column in unmergable_window_columns
            if column.table == inner_select_name and column.name in window_aliases
        ]
        return any(window_expressions_in_unmergable)

    def _outer_select_joins_on_inner_select_join():
        """
        All columns from the inner select in the ON clause must be from the first FROM table.

        That is, this can be merged:
            SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^           ^
        But this can't:
            SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^                  ^
        """
        if not isinstance(from_or_join, exp.Join):
            return False

        alias = from_or_join.alias_or_name

        on = from_or_join.args.get("on")
        if not on:
            return False
        selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
        inner_from = inner_scope.expression.args.get("from")
        if not inner_from:
            return False
        inner_from_table = inner_from.alias_or_name
        inner_projections = {s.alias_or_name: s for s in inner_scope.expression.selects}

        for selection in selections:
            projection = inner_projections.get(selection)
            if projection is None:
                # The ON clause references a name the inner query doesn't project
                # explicitly (e.g. an unexpanded `SELECT *`). We can't prove which
                # table it comes from, so block the merge instead of raising KeyError.
                return True
            if any(col.table != inner_from_table for col in projection.find_all(exp.Column)):
                return True
        return False

    def _is_recursive():
        # Recursive CTEs look like this:
        #     WITH RECURSIVE cte AS (
        #       SELECT * FROM x  <-- inner scope
        #       UNION ALL
        #       SELECT * FROM cte  <-- outer scope
        #     )
        cte = inner_scope.expression.parent
        node = outer_scope.expression.parent

        # The outer query is recursive if it lives inside the very CTE it selects from.
        while node:
            if node is cte:
                return True
            node = node.parent
        return False

    return (
        isinstance(outer_scope.expression, exp.Select)
        and not outer_scope.expression.is_star
        and isinstance(inner_select, exp.Select)
        and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
        and inner_select.args.get("from") is not None
        and not outer_scope.pivots
        and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
        and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
        # An inner WHERE can't be hoisted out of the nullable side of an outer join.
        and not (
            isinstance(from_or_join, exp.Join)
            and inner_select.args.get("where")
            and from_or_join.side in ("FULL", "LEFT", "RIGHT")
        )
        and not (
            isinstance(from_or_join, exp.From)
            and inner_select.args.get("where")
            and any(
                j.side in ("FULL", "RIGHT") for j in outer_scope.expression.args.get("joins", [])
            )
        )
        and not _outer_select_joins_on_inner_select_join()
        and not _is_a_window_expression_in_unmergable_operation()
        and not _is_recursive()
        and not (inner_select.args.get("order") and outer_scope.is_union)
        and not isinstance(seq_get(inner_select.expressions, 0), exp.QueryTransform)
    )
222
+
223
+
224
def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
    """
    Give fresh names to inner-query sources whose names clash with outer-query sources.

    The name `alias` (the inner query's own alias in the outer query) is not treated
    as a clash.
    """
    inner_taken = set(inner_scope.selected_sources)
    outer_taken = set(outer_scope.selected_sources)

    conflicts = outer_taken.intersection(inner_taken)
    conflicts.discard(alias)

    taken = outer_taken.union(inner_taken)

    for clashing_name in conflicts:
        fresh_name = find_new_name(taken, clashing_name)
        source, _ = inner_scope.selected_sources[clashing_name]
        fresh_alias = exp.to_identifier(fresh_name)

        if isinstance(source, exp.Table):
            if source.alias:
                source.set("alias", fresh_alias)
            else:
                source.replace(exp.alias_(source, fresh_alias))
        elif isinstance(source.parent, exp.Subquery):
            source.parent.set("alias", exp.TableAlias(this=fresh_alias))

        # Re-point every inner column that referenced the old source name.
        for column in inner_scope.source_columns(clashing_name):
            column.set("table", exp.to_identifier(fresh_name))

        inner_scope.rename_source(clashing_name, fresh_name)
252
+
253
+
254
def _merge_from(
    outer_scope: Scope,
    inner_scope: Scope,
    node_to_replace: t.Union[exp.Subquery, exp.Table],
    alias: str,
) -> None:
    """
    Replace the inner query's node in the outer query with the inner query's FROM source.

    Args:
        outer_scope: scope of the query being merged into
        inner_scope: scope of the query being merged
        node_to_replace: the subquery / CTE reference in the outer query
        alias: name under which the inner query is registered in `outer_scope`
    """
    inner_source = inner_scope.expression.args["from"].this
    inner_source.set("joins", node_to_replace.args.get("joins"))
    node_to_replace.replace(inner_source)

    replaced_name = node_to_replace.alias_or_name
    inner_source_name = inner_source.alias_or_name

    # Join hints that named the replaced node must now name the source taking its place.
    for join_hint in outer_scope.join_hints:
        for hinted_table in join_hint.find_all(exp.Table):
            if hinted_table.alias_or_name == replaced_name:
                hinted_table.set("this", exp.to_identifier(inner_source_name))

    outer_scope.remove_source(alias)
    outer_scope.add_source(inner_source_name, inner_scope.sources[inner_source_name])
275
+
276
+
277
def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
    """
    Merge JOIN clauses of inner query into outer query.

    Args:
        outer_scope: scope of the query being merged into
        inner_scope: scope of the query being merged
        from_or_join: the outer FROM or JOIN clause that contained the inner query;
            it determines where the inner joins are spliced in
    """
    new_joins = list(inner_scope.expression.args.get("joins") or [])
    if not new_joins:
        return

    for join in new_joins:
        outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])

    # `or []` (rather than a `.get` default) also covers a "joins" key that is present
    # but holds None, mirroring how the inner joins are read above.
    outer_joins = outer_scope.expression.args.get("joins") or []

    # Maintain the join order: inner joins go right after the clause they came from.
    if isinstance(from_or_join, exp.From):
        position = 0
    else:
        position = outer_joins.index(from_or_join) + 1
    outer_joins[position:position] = new_joins

    outer_scope.expression.set("joins", outer_joins)
300
+
301
+
302
def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
    """
    Merge projections of inner query into outer query.

    Every outer column that refers to the inner query through `alias` is replaced with
    a copy of the inner projection it names.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    # Collect all columns that reference the alias of the inner query
    outer_columns = defaultdict(list)
    for column in outer_scope.columns:
        if column.table == alias:
            outer_columns[column.name].append(column)

    # Replace columns with the projection expression in the inner query
    for projection in inner_scope.expression.expressions:
        projection_name = projection.alias_or_name
        if not projection_name:
            continue

        columns_to_replace = outer_columns.get(projection_name)
        if not columns_to_replace:
            continue

        inner_expression = projection.unalias()
        must_wrap_expression = not isinstance(inner_expression, SAFE_TO_REPLACE_UNWRAPPED)

        for column in columns_to_replace:
            # Build the replacement per column: wrapping must not leak into the
            # replacements of later columns, nor mutate the inner projection itself.
            replacement = inner_expression.copy()

            # Ensures we don't alter the intended operator precedence if there's additional
            # context surrounding the outer expression (i.e. it's not a simple projection).
            if must_wrap_expression and isinstance(column.parent, (exp.Unary, exp.Binary)):
                replacement = exp.paren(replacement, copy=False)

            column.replace(replacement)
334
+
335
+
336
def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
    """
    Move the inner query's WHERE condition into the outer query.

    When the inner query sat in a JOIN and its condition only touches sources joined
    up to and including that JOIN, the condition is added to the JOIN's ON clause;
    otherwise it is added to the outer WHERE.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """
    inner_where = inner_scope.expression.args.get("where")
    if not (inner_where and inner_where.this):
        return

    condition = inner_where.this
    outer_select = outer_scope.expression

    if isinstance(from_or_join, exp.Join):
        # Gather the sources available at the point of this join: the FROM source
        # plus every join up to and including `from_or_join`.
        outer_from = outer_select.args.get("from")
        available = {outer_from.alias_or_name} if outer_from else set()
        target_name = from_or_join.alias_or_name

        for join in outer_select.args["joins"]:
            join_name = join.alias_or_name
            available.add(join_name)
            if join_name == target_name:
                break

        if exp.column_table_names(condition) <= available:
            from_or_join.on(condition, copy=False)
            from_or_join.set("on", from_or_join.args.get("on"))
            return

    outer_select.where(condition, copy=False)
369
+
370
+
371
def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None:
    """
    Carry the inner query's ORDER over to the outer query.

    Nothing happens if the outer query has a GROUP, DISTINCT, HAVING or ORDER of its
    own, selects from anything other than exactly one source, or projects an aggregate.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    outer_select = outer_scope.expression

    for blocking_arg in ["group", "distinct", "having", "order"]:
        if outer_select.args.get(blocking_arg):
            return

    if len(outer_scope.selected_sources) != 1:
        return

    for projection in outer_select.expressions:
        if projection.find(exp.AggFunc):
            return

    outer_select.set("order", inner_scope.expression.args.get("order"))
389
+
390
+
391
def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None:
    """
    Carry the inner query's hints over to the outer query.

    If the outer query already has a hint, the inner hint expressions are appended
    to it; otherwise the inner hint becomes the outer query's hint.
    """
    inner_hint = inner_scope.expression.args.get("hint")
    if not inner_hint:
        return

    outer_select = outer_scope.expression
    outer_hint = outer_select.args.get("hint")
    if not outer_hint:
        outer_select.set("hint", inner_hint)
        return

    for hint_expression in inner_hint.expressions:
        outer_hint.append("expressions", hint_expression)
401
+
402
+
403
def _pop_cte(inner_scope: Scope) -> None:
    """
    Detach a merged CTE from the AST.

    When it is the only CTE of its WITH clause, the whole WITH clause is removed.

    Args:
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    cte = inner_scope.expression.parent
    with_clause = cte.parent
    is_only_cte = len(with_clause.expressions) == 1
    node_to_remove = with_clause if is_only_cte else cte
    node_to_remove.pop()
@@ -0,0 +1,200 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+
5
+ from sqlglot import exp
6
+ from sqlglot.errors import OptimizeError
7
+ from sqlglot.helper import while_changing
8
+ from sqlglot.optimizer.scope import find_all_in_scope
9
+ from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
10
+
11
+ logger = logging.getLogger("sqlglot")
12
+
13
+
14
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
    Returns:
        sqlglot.Expression: normalized expression
    """

    def _apply_distributive_law(e):
        return distributive_law(e, dnf, max_distance)

    # Only the outermost connectors are visited; the walk doesn't descend into them.
    outermost = tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector)))

    for node in outermost:
        if not isinstance(node, exp.Connector):
            continue
        if normalized(node, dnf=dnf):
            continue

        root = node is expression
        original = node.copy()

        node.transform(rewrite_between, copy=False)
        distance = normalization_distance(node, dnf=dnf, max_=max_distance)

        if distance > max_distance:
            logger.info(
                f"Skipping normalization because distance {distance} exceeds max {max_distance}"
            )
            return expression

        try:
            node = node.replace(while_changing(node, _apply_distributive_law))
        except OptimizeError as e:
            # Conversion was abandoned part-way: put the untouched copy back.
            logger.info(e)
            node.replace(original)
            return original if root else expression

        if root:
            expression = node

    return expression
62
+
63
+
64
def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
    """
    Tell whether an expression is already in the requested normal form.

    Example:
        >>> from sqlglot import parse_one
        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
        True
        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
        True
        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
        False

    Args:
        expression: The expression to check if it's normalized.
        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
    """
    if dnf:
        ancestor, root = exp.And, exp.Or
    else:
        ancestor, root = exp.Or, exp.And

    # The form is violated as soon as a `root` connector is nested under an `ancestor` one.
    for connector in find_all_in_scope(expression, root):
        if connector.find_ancestor(ancestor):
            return False
    return True
86
+
87
+
88
def normalization_distance(
    expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
) -> int:
    """
    The difference in the number of predicates between a given expression and its normalized form.

    This is used as an estimate of the cost of the conversion which is exponential in complexity.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression: The expression to compute the normalization distance for.
        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
        max_: stop early if count exceeds this.

    Returns:
        The normalization distance.
    """
    connector_count = 0
    for _ in expression.find_all(exp.Connector):
        connector_count += 1

    # Start from minus (connectors + 1), then add the expanded predicate lengths.
    distance = -(connector_count + 1)

    for length in _predicate_lengths(expression, dnf, max_):
        distance += length
        if distance > max_:
            break

    return distance
119
+
120
+
121
def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
    """
    Yield the predicate lengths the expression would have once expanded to normalized form.

    (A AND B) OR C -> 2, 2 because len(A OR C), len(B OR C).

    Once `depth` exceeds `max_`, the current depth is yielded and recursion stops.
    """
    if depth > max_:
        yield depth
        return

    expression = expression.unnest()

    if not isinstance(expression, exp.Connector):
        yield 1
        return

    depth += 1
    left, right = expression.args.values()
    distributed_type = exp.And if dnf else exp.Or

    if not isinstance(expression, distributed_type):
        yield from _predicate_lengths(left, dnf, max_, depth)
        yield from _predicate_lengths(right, dnf, max_, depth)
        return

    # This connector gets distributed: every left length pairs with every right length.
    for left_length in _predicate_lengths(left, dnf, max_, depth):
        for right_length in _predicate_lengths(right, dnf, max_, depth):
            yield left_length + right_length
147
+
148
+
149
def distributive_law(expression, dnf, max_distance):
    """
    Apply one round of the distributive law, children first.

    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Raises:
        OptimizeError: if the normalization distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf, max_=max_distance)

    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))

    if dnf:
        to_exp, from_exp = exp.Or, exp.And
    else:
        to_exp, from_exp = exp.And, exp.Or

    if not isinstance(expression, from_exp):
        return expression

    a, b = expression.unnest_operands()

    from_func = exp.and_ if from_exp == exp.And else exp.or_
    to_func = exp.and_ if to_exp == exp.And else exp.or_

    a_is_target = isinstance(a, to_exp)
    b_is_target = isinstance(b, to_exp)

    if a_is_target and b_is_target:
        # Both sides qualify: distribute over the side holding more connectors.
        a_connectors = len(tuple(a.find_all(exp.Connector)))
        b_connectors = len(tuple(b.find_all(exp.Connector)))
        if a_connectors > b_connectors:
            return _distribute(a, b, from_func, to_func)
        return _distribute(b, a, from_func, to_func)

    if a_is_target:
        return _distribute(b, a, from_func, to_func)

    if b_is_target:
        return _distribute(a, b, from_func, to_func)

    return expression
181
+
182
+
183
def _distribute(a, b, from_func, to_func):
    """
    Distribute `a` over the two operands of `b`.

    If `a` is itself a connector, each of its children is expanded in place and `a`
    is returned; otherwise the expansion of `a` is returned.
    """

    def _expand(operand):
        # operand <from> (l <to> r)  ->  (operand <from> l) <to> (operand <from> r)
        return to_func(
            uniq_sort(flatten(from_func(operand, b.left))),
            uniq_sort(flatten(from_func(operand, b.right))),
            copy=False,
        )

    if not isinstance(a, exp.Connector):
        return _expand(a)

    exp.replace_children(a, _expand)
    return a
@@ -0,0 +1,64 @@
1
+ from __future__ import annotations
2
+
3
+ import typing as t
4
+
5
+ from sqlglot import exp
6
+ from sqlglot.dialects.dialect import Dialect, DialectType
7
+
8
+ if t.TYPE_CHECKING:
9
+ from sqlglot._typing import E
10
+
11
+
12
@t.overload
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...


@t.overload
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...


def normalize_identifiers(expression, dialect=None):
    """
    Normalize identifiers by converting them to either lower or upper case,
    ensuring the semantics are preserved in each case (e.g. by respecting
    case-sensitivity).

    This transformation reflects how identifiers would be resolved by the engine corresponding
    to each SQL dialect, and plays a very important role in the standardization of the AST.

    It's possible to make this a no-op by adding a special comment next to the
    identifier of interest:

        SELECT a /* sqlglot.meta case_sensitive */ FROM table

    In this example, the identifier `a` will not be normalized.

    Note:
        Some dialects (e.g. DuckDB) treat all identifiers as case-insensitive even
        when they're quoted, so in these cases all identifiers are normalized.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
        >>> normalize_identifiers(expression).sql()
        'SELECT bar.a AS a FROM "Foo".bar'
        >>> normalize_identifiers("foo", dialect="snowflake").sql(dialect="snowflake")
        'FOO'

    Args:
        expression: The expression to transform, or a string to parse as an identifier.
        dialect: The dialect to use in order to decide how to normalize identifiers.

    Returns:
        The transformed expression.
    """
    dialect = Dialect.get_or_raise(dialect)

    if isinstance(expression, str):
        expression = exp.parse_identifier(expression, dialect=dialect)

    def _is_case_sensitive(node):
        return node.meta.get("case_sensitive")

    # Nodes flagged as case sensitive are left alone, and the walk doesn't go below them.
    for node in expression.walk(prune=_is_case_sensitive):
        if _is_case_sensitive(node):
            continue
        dialect.normalize_identifier(node)

    return expression