altimate-code 0.5.2 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. package/CHANGELOG.md +12 -0
  2. package/bin/altimate +6 -0
  3. package/bin/altimate-code +6 -0
  4. package/dbt-tools/bin/altimate-dbt +2 -0
  5. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/__init__.py +0 -0
  6. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/fetch_schema.py +35 -0
  7. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/utils.py +353 -0
  8. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/validate_sql.py +114 -0
  9. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__init__.py +178 -0
  10. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__main__.py +96 -0
  11. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/_typing.py +17 -0
  12. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/__init__.py +3 -0
  13. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/__init__.py +18 -0
  14. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/_typing.py +18 -0
  15. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/column.py +332 -0
  16. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/dataframe.py +866 -0
  17. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/functions.py +1267 -0
  18. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/group.py +59 -0
  19. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/normalize.py +78 -0
  20. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/operations.py +53 -0
  21. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/readwriter.py +108 -0
  22. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/session.py +190 -0
  23. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/transforms.py +9 -0
  24. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/types.py +212 -0
  25. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/util.py +32 -0
  26. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/window.py +134 -0
  27. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/__init__.py +118 -0
  28. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/athena.py +166 -0
  29. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/bigquery.py +1331 -0
  30. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/clickhouse.py +1393 -0
  31. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/databricks.py +131 -0
  32. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dialect.py +1915 -0
  33. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/doris.py +561 -0
  34. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/drill.py +157 -0
  35. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/druid.py +20 -0
  36. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/duckdb.py +1159 -0
  37. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dune.py +16 -0
  38. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/hive.py +787 -0
  39. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/materialize.py +94 -0
  40. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/mysql.py +1324 -0
  41. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/oracle.py +378 -0
  42. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/postgres.py +778 -0
  43. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/presto.py +788 -0
  44. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/prql.py +203 -0
  45. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/redshift.py +448 -0
  46. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/risingwave.py +78 -0
  47. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/snowflake.py +1464 -0
  48. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark.py +202 -0
  49. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark2.py +349 -0
  50. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/sqlite.py +320 -0
  51. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/starrocks.py +343 -0
  52. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tableau.py +61 -0
  53. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/teradata.py +356 -0
  54. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/trino.py +115 -0
  55. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tsql.py +1403 -0
  56. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/diff.py +456 -0
  57. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/errors.py +93 -0
  58. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/__init__.py +95 -0
  59. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/context.py +101 -0
  60. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/env.py +246 -0
  61. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/python.py +460 -0
  62. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/table.py +155 -0
  63. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/expressions.py +8870 -0
  64. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/generator.py +4993 -0
  65. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/helper.py +582 -0
  66. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/jsonpath.py +227 -0
  67. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/lineage.py +423 -0
  68. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/__init__.py +11 -0
  69. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/annotate_types.py +589 -0
  70. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/canonicalize.py +222 -0
  71. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_ctes.py +43 -0
  72. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_joins.py +181 -0
  73. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py +189 -0
  74. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/isolate_table_selects.py +50 -0
  75. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/merge_subqueries.py +415 -0
  76. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize.py +200 -0
  77. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize_identifiers.py +64 -0
  78. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimize_joins.py +91 -0
  79. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimizer.py +94 -0
  80. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_predicates.py +222 -0
  81. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_projections.py +172 -0
  82. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify.py +104 -0
  83. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_columns.py +1024 -0
  84. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_tables.py +155 -0
  85. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/scope.py +904 -0
  86. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/simplify.py +1587 -0
  87. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/unnest_subqueries.py +302 -0
  88. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/parser.py +8501 -0
  89. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/planner.py +463 -0
  90. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/schema.py +588 -0
  91. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/serde.py +68 -0
  92. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/time.py +687 -0
  93. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/tokens.py +1520 -0
  94. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/transforms.py +1020 -0
  95. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/trie.py +81 -0
  96. package/dbt-tools/dist/altimate_python_packages/dbt_core_integration.py +825 -0
  97. package/dbt-tools/dist/altimate_python_packages/dbt_utils.py +157 -0
  98. package/dbt-tools/dist/index.js +23859 -0
  99. package/package.json +13 -13
  100. package/postinstall.mjs +42 -0
  101. package/skills/altimate-setup/SKILL.md +31 -0
@@ -0,0 +1,1020 @@
1
+ from __future__ import annotations
2
+
3
+ import typing as t
4
+
5
+ from sqlglot import expressions as exp
6
+ from sqlglot.errors import UnsupportedError
7
+ from sqlglot.helper import find_new_name, name_sequence
8
+
9
+
10
+ if t.TYPE_CHECKING:
11
+ from sqlglot._typing import E
12
+ from sqlglot.generator import Generator
13
+
14
+
15
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Build a generator transform that chains `transforms` and then renders the result as SQL.

    Once every transform has been applied, the resulting expression is rendered with the
    generator's "<key>_sql" method when one exists; otherwise the generator's `TRANSFORMS`
    entry for the resulting expression type is used (with the loop guard described below).

    Args:
        transforms: the transform functions to apply; they are called in the given order.

    Returns:
        A `(generator, expression) -> str` callable usable as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        original_type = type(expression)

        try:
            head, tail = transforms[0], transforms[1:]
            expression = head(expression)
            for step in tail:
                expression = step(expression)
        except UnsupportedError as error:
            # Report via the generator and carry on rendering whatever the chain produced so far.
            self.unsupported(str(error))

        sql_method = getattr(self, f"{expression.key}_sql", None)
        if sql_method:
            return sql_method(expression)

        transform = self.TRANSFORMS.get(type(expression))
        if not transform:
            raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

        if type(expression) is not original_type:
            return transform(self, expression)

        # Same type as the input and no _sql method: invoking the TRANSFORMS entry again would
        # re-enter _to_sql forever, so fall back to the generic function rendering or fail.
        if isinstance(expression, exp.Func):
            return self.function_fallback_sql(expression)

        raise ValueError(
            f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
        )

    return _to_sql
62
+
63
+
64
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    """
    Replace `UNNEST(GENERATE_DATE_ARRAY(start, end, step))` sources with recursive CTEs.

    Each qualifying UNNEST that appears directly in a FROM or JOIN clause (single argument, with a
    start, an end and an INTERVAL step) is replaced by a subquery over a new recursive CTE named
    `_generated_dates`, `_generated_dates_1`, ... The CTE starts at `CAST(start AS DATE)` and keeps
    adding the step while the result is `<= CAST(end AS DATE)`. The new CTEs are prepended to the
    query's WITH clause, which is marked RECURSIVE.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            # A table alias can carry only a table name (e.g. `UNNEST(...) AS t`) without a column
            # list; use the default column name then instead of indexing an empty list.
            alias_columns = alias.columns if isinstance(alias, exp.TableAlias) else []
            column_name = alias_columns[0] if alias_columns else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            # Anchor row: the start date; recursive step: previous date + step, bounded by end.
            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression
119
+
120
+
121
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Wrap a GENERATE_SERIES / SEQUENCE table reference in an UNNEST, keeping its alias."""
    series = expression.this
    if not isinstance(expression, exp.Table) or not isinstance(series, exp.GenerateSeries):
        return expression

    unnest = exp.Unnest(expressions=[series])

    table_alias = expression.alias
    if not table_alias:
        return unnest

    # The original table alias becomes the single column name under a synthetic "_u" source.
    return exp.alias_(unnest, alias="_u", table=[table_alias], copy=False)
132
+
133
+
134
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY references to projection aliases as 1-based positional references.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Map each projection alias to its 1-based position (later duplicates win).
    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for key in expression.expressions:
        is_unqualified_column = isinstance(key, exp.Column) and not key.table
        if is_unqualified_column and key.name in position_by_alias:
            key.replace(exp.Literal.number(position_by_alias[key.name]))

    return expression
165
+
166
+
167
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The DISTINCT ON columns become the PARTITION BY of a ROW_NUMBER() window. The window is ordered
    by the query's ORDER BY, or by the DISTINCT ON columns when there is none. An outer query then
    keeps only the rows whose row number equals 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
    ):
        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")

        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

        order = expression.args.get("order")
        if order:
            # The query's ORDER BY decides which row of each partition is kept.
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number_window_alias)
        expression.select(window, copy=False)

        # We add aliases to the projections so that we can safely reference them in the outer query
        new_selects = []
        taken_names = {row_number_window_alias}
        # The last projection is the ROW_NUMBER() window appended just above.
        for select in expression.selects[:-1]:
            if select.is_star:
                new_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            # Copy the identifier so the inner projection keeps its own node; otherwise the same
            # node would live in both queries and be re-parented to the outer SELECT.
            new_selects.append(select.args["alias"].copy())

        return (
            exp.select(*new_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number_window_alias).eq(1), copy=False)
        )

    return expression
221
+
222
+
223
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give unnamed projections a name so the outer query can reference them.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                # Preserve the quoting of the original identifier in the outer projection.
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in list(qualify_filters.find_all(select_candidates)):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            # Insert a copy: the aliased expression must stay owned by its
                            # projection, and one alias may be referenced several times.
                            column.replace(expr.copy())

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window is the whole QUALIFY condition.
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
285
+
286
+
287
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip the `DataTypeParam` arguments (e.g. precision) from every data type in `expression`.

    Some dialects only accept parameterized types in DDL, not in other expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = list(
            filter(lambda arg: not isinstance(arg, exp.DataTypeParam), data_type.expressions)
        )
        data_type.set("expressions", kept)

    return expression
298
+
299
+
300
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Aliases of UNNESTs used directly as FROM/JOIN sources within this SELECT's scope.
    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        qualifier = column.parts[0]
        # When parts[0] is the column's own name ("this"), the column is unqualified.
        if qualifier.arg_key == "this":
            continue
        if qualifier.this in unnest_aliases:
            qualifier.pop()

    return expression
317
+
318
+
319
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST in the FROM clause becomes an EXPLODE / POSEXPLODE / INLINE table function. Each
    UNNEST in a JOIN (optionally wrapped in LATERAL) is removed from the joins and appended to the
    query as a LATERAL VIEW.

    Args:
        expression: the expression that will be transformed.
        unnest_using_arrays_zip: whether an UNNEST over several arrays may be rewritten as
            INLINE(ARRAYS_ZIP(...)); when False such an UNNEST raises UnsupportedError.

    Returns:
        The transformed expression.

    Raises:
        UnsupportedError: for multi-array UNNESTs when `unnest_using_arrays_zip` is False, or when a
            joined single-array UNNEST does not have exactly one or two column aliases.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Collapse several arrays into one ARRAYS_ZIP(...) argument, updating `u` in place.
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # WITH OFFSET needs positions (POSEXPLODE); zipped arrays are expanded with INLINE.
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html
                # Validate before removing the join: `preprocess` reports this error and keeps the
                # tree, so removing first would silently drop the UNNEST from the query.
                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                joins.remove(join)

                for e, _ in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression
410
+
411
+
412
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode projections into unnests.

    Returns a transform that rewrites a SELECT whose projections contain [POS]EXPLODE calls. It
    adds one UNNEST(GENERATE_SERIES(...)) "position" source to the query. For every exploded
    array it cross joins an `UNNEST(array) WITH OFFSET` source and replaces the EXPLODE call with
    `IF(series_pos = unnest_pos, element)`; that IF has no ELSE branch. It also adds WHERE
    filters so rows from the different arrays line up by position.

    Args:
        index_offset: start value of the generated position series. Presumably this is the
            target dialect's array/offset base (0 or 1); TODO confirm against callers.

    Returns:
        A function that takes an expression and returns the (possibly rewritten) expression.
    """

    def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick a fresh name and reserve it so later calls cannot reuse it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of each exploded array; the series must reach the largest of them.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # e.g. POSEXPLODE(x) AS (pos, col): first alias is the position.
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # alias_ copied the projection, so look EXPLODE up again in the new tree.
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # For an empty or NULL array, unnest a one-element array holding
                        # array[0] (marked safe) so the row is kept.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # Even plain EXPLODE needs an offset column to align with the series.
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also projects the position, right after the element.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The position series source is added once, with the first exploded array.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Last valid offset of this array: ARRAY_SIZE - 1 unless positions start at 1.
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows whose positions match. Past the end of this array, keep only its
                    # last offset, so the IF above yields NULL for the missing element.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series has to span the longest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest
563
+
564
+
565
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Move a percentile's ordering column into a WITHIN GROUP (ORDER BY ...) clause.

    Only percentile functions that have a second argument and are not already wrapped in
    WITHIN GROUP are rewritten. The function keeps that second argument as its sole argument,
    and its original first argument becomes the ORDER BY key.
    """
    if (
        not isinstance(expression, exp.PERCENTILES)
        or isinstance(expression.parent, exp.WithinGroup)
        or not expression.expression
    ):
        return expression

    ordered_column = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_column)]),
    )
578
+
579
+
580
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Replace `<percentile>(q) WITHIN GROUP (ORDER BY x)` with an ApproxQuantile of x at q."""
    if not (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        return expression

    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(approx)
592
+
593
+
594
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give each CTE of a WITH RECURSIVE clause an explicit column list when it has none.

    Column names come from the projections of the CTE's query. For a set operation, only the
    left-most operand at the first level is used. Unnamed projections get "_c_<n>" names.
    """
    if not isinstance(expression, exp.With) or not expression.recursive:
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        anchor = cte.this
        if isinstance(anchor, exp.SetOperation):
            anchor = anchor.this

        column_names = []
        for projection in anchor.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or next_name()))
        cte_alias.set("columns", column_names)

    return expression
611
+
612
+
613
+ def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
614
+ """Replace 'epoch' in casts by the equivalent date literal."""
615
+ if (
616
+ isinstance(expression, (exp.Cast, exp.TryCast))
617
+ and expression.name.lower() == "epoch"
618
+ and expression.to.this in exp.DataType.TEMPORAL_TYPES
619
+ ):
620
+ expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
621
+
622
+ return expression
623
+
624
+
625
+ def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
626
+ """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
627
+ if isinstance(expression, exp.Select):
628
+ for join in expression.args.get("joins") or []:
629
+ on = join.args.get("on")
630
+ if on and join.kind in ("SEMI", "ANTI"):
631
+ subquery = exp.select("1").from_(join.this).where(on)
632
+ exists = exp.Exists(this=subquery)
633
+ if join.kind == "ANTI":
634
+ exists = exists.not_(copy=False)
635
+
636
+ join.pop()
637
+ expression.where(exists, copy=False)
638
+
639
+ return expression
640
+
641
+
642
+ def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
643
+ """
644
+ Converts a query with a FULL OUTER join to a union of identical queries that
645
+ use LEFT/RIGHT OUTER joins instead. This transformation currently only works
646
+ for queries that have a single FULL OUTER join.
647
+ """
648
+ if isinstance(expression, exp.Select):
649
+ full_outer_joins = [
650
+ (index, join)
651
+ for index, join in enumerate(expression.args.get("joins") or [])
652
+ if join.side == "FULL"
653
+ ]
654
+
655
+ if len(full_outer_joins) == 1:
656
+ expression_copy = expression.copy()
657
+ expression.set("limit", None)
658
+ index, full_outer_join = full_outer_joins[0]
659
+
660
+ tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
661
+ join_conditions = full_outer_join.args.get("on") or exp.and_(
662
+ *[
663
+ exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
664
+ for col in full_outer_join.args.get("using")
665
+ ]
666
+ )
667
+
668
+ full_outer_join.set("side", "left")
669
+ anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
670
+ expression_copy.args["joins"][index].set("side", "right")
671
+ expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
672
+ expression_copy.args.pop("with", None) # remove CTEs from RIGHT side
673
+ expression.args.pop("order", None) # remove order by from LEFT side
674
+
675
+ return exp.union(expression, expression_copy, copy=False, distinct=False)
676
+
677
+ return expression
678
+
679
+
680
+ def move_ctes_to_top_level(expression: E) -> E:
681
+ """
682
+ Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
683
+ defined at the top-level, so for example queries like:
684
+
685
+ SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
686
+
687
+ are invalid in those dialects. This transformation can be used to ensure all CTEs are
688
+ moved to the top level so that the final SQL code is valid from a syntax standpoint.
689
+
690
+ TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
691
+ """
692
+ top_level_with = expression.args.get("with")
693
+ for inner_with in expression.find_all(exp.With):
694
+ if inner_with.parent is expression:
695
+ continue
696
+
697
+ if not top_level_with:
698
+ top_level_with = inner_with.pop()
699
+ expression.set("with", top_level_with)
700
+ else:
701
+ if inner_with.recursive:
702
+ top_level_with.set("recursive", True)
703
+
704
+ parent_cte = inner_with.find_ancestor(exp.CTE)
705
+ inner_with.pop()
706
+
707
+ if parent_cte:
708
+ i = top_level_with.expressions.index(parent_cte)
709
+ top_level_with.expressions[i:i] = inner_with.expressions
710
+ top_level_with.set("expressions", top_level_with.expressions)
711
+ else:
712
+ top_level_with.set(
713
+ "expressions", top_level_with.expressions + inner_with.expressions
714
+ )
715
+
716
+ return expression
717
+
718
+
719
+ def ensure_bools(expression: exp.Expression) -> exp.Expression:
720
+ """Converts numeric values used in conditions into explicit boolean expressions."""
721
+ from sqlglot.optimizer.canonicalize import ensure_bools
722
+
723
+ def _ensure_bool(node: exp.Expression) -> None:
724
+ if (
725
+ node.is_number
726
+ or (
727
+ not isinstance(node, exp.SubqueryPredicate)
728
+ and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
729
+ )
730
+ or (isinstance(node, exp.Column) and not node.type)
731
+ ):
732
+ node.replace(node.neq(0))
733
+
734
+ for node in expression.walk():
735
+ ensure_bools(node, _ensure_bool)
736
+
737
+ return expression
738
+
739
+
740
+ def unqualify_columns(expression: exp.Expression) -> exp.Expression:
741
+ for column in expression.find_all(exp.Column):
742
+ # We only wanna pop off the table, db, catalog args
743
+ for part in column.parts[:-1]:
744
+ part.pop()
745
+
746
+ return expression
747
+
748
+
749
+ def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
750
+ assert isinstance(expression, exp.Create)
751
+ for constraint in expression.find_all(exp.UniqueColumnConstraint):
752
+ if constraint.parent:
753
+ constraint.parent.pop()
754
+
755
+ return expression
756
+
757
+
758
+ def ctas_with_tmp_tables_to_create_tmp_view(
759
+ expression: exp.Expression,
760
+ tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
761
+ ) -> exp.Expression:
762
+ assert isinstance(expression, exp.Create)
763
+ properties = expression.args.get("properties")
764
+ temporary = any(
765
+ isinstance(prop, exp.TemporaryProperty)
766
+ for prop in (properties.expressions if properties else [])
767
+ )
768
+
769
+ # CTAS with temp tables map to CREATE TEMPORARY VIEW
770
+ if expression.kind == "TABLE" and temporary:
771
+ if expression.expression:
772
+ return exp.Create(
773
+ kind="TEMPORARY VIEW",
774
+ this=expression.this,
775
+ expression=expression.expression,
776
+ )
777
+ return tmp_storage_provider(expression)
778
+
779
+ return expression
780
+
781
+
782
+ def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
783
+ """
784
+ In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
785
+ PARTITIONED BY value is an array of column names, they are transformed into a schema.
786
+ The corresponding columns are removed from the create statement.
787
+ """
788
+ assert isinstance(expression, exp.Create)
789
+ has_schema = isinstance(expression.this, exp.Schema)
790
+ is_partitionable = expression.kind in {"TABLE", "VIEW"}
791
+
792
+ if has_schema and is_partitionable:
793
+ prop = expression.find(exp.PartitionedByProperty)
794
+ if prop and prop.this and not isinstance(prop.this, exp.Schema):
795
+ schema = expression.this
796
+ columns = {v.name.upper() for v in prop.this.expressions}
797
+ partitions = [col for col in schema.expressions if col.name.upper() in columns]
798
+ schema.set("expressions", [e for e in schema.expressions if e not in partitions])
799
+ prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
800
+ expression.set("this", schema)
801
+
802
+ return expression
803
+
804
+
805
+ def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
806
+ """
807
+ Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
808
+
809
+ Currently, SQLGlot uses the DATASOURCE format for Spark 3.
810
+ """
811
+ assert isinstance(expression, exp.Create)
812
+ prop = expression.find(exp.PartitionedByProperty)
813
+ if (
814
+ prop
815
+ and prop.this
816
+ and isinstance(prop.this, exp.Schema)
817
+ and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
818
+ ):
819
+ prop_this = exp.Tuple(
820
+ expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
821
+ )
822
+ schema = expression.this
823
+ for e in prop.this.expressions:
824
+ schema.append("expressions", e)
825
+ prop.set("this", prop_this)
826
+
827
+ return expression
828
+
829
+
830
+ def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
831
+ """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
832
+ if isinstance(expression, exp.Struct):
833
+ expression.set(
834
+ "expressions",
835
+ [
836
+ exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
837
+ for e in expression.expressions
838
+ ],
839
+ )
840
+
841
+ return expression
842
+
843
+
844
+ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
845
+ """https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178
846
+
847
+ 1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.
848
+
849
+ 2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.
850
+
851
+ The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.
852
+
853
+ You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.
854
+
855
+ The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.
856
+
857
+ A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.
858
+
859
+ A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.
860
+
861
+ A WHERE condition cannot compare any column marked with the (+) operator with a subquery.
862
+
863
+ -- example with WHERE
864
+ SELECT d.department_name, sum(e.salary) as total_salary
865
+ FROM departments d, employees e
866
+ WHERE e.department_id(+) = d.department_id
867
+ group by department_name
868
+
869
+ -- example of left correlation in select
870
+ SELECT d.department_name, (
871
+ SELECT SUM(e.salary)
872
+ FROM employees e
873
+ WHERE e.department_id(+) = d.department_id) AS total_salary
874
+ FROM departments d;
875
+
876
+ -- example of left correlation in from
877
+ SELECT d.department_name, t.total_salary
878
+ FROM departments d, (
879
+ SELECT SUM(e.salary) AS total_salary
880
+ FROM employees e
881
+ WHERE e.department_id(+) = d.department_id
882
+ ) t
883
+ """
884
+
885
+ from sqlglot.optimizer.scope import traverse_scope
886
+ from sqlglot.optimizer.normalize import normalize, normalized
887
+ from collections import defaultdict
888
+
889
+ # we go in reverse to check the main query for left correlation
890
+ for scope in reversed(traverse_scope(expression)):
891
+ query = scope.expression
892
+
893
+ where = query.args.get("where")
894
+ joins = query.args.get("joins", [])
895
+
896
+ # knockout: we do not support left correlation (see point 2)
897
+ assert not scope.is_correlated_subquery, "Correlated queries are not supported"
898
+
899
+ # nothing to do - we check it here after knockout above
900
+ if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)):
901
+ continue
902
+
903
+ # make sure we have AND of ORs to have clear join terms
904
+ where = normalize(where.this)
905
+ assert normalized(where), "Cannot normalize JOIN predicates"
906
+
907
+ joins_ons = defaultdict(list) # dict of {name: list of join AND conditions}
908
+ for cond in [where] if not isinstance(where, exp.And) else where.flatten():
909
+ join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")]
910
+
911
+ left_join_table = set(col.table for col in join_cols)
912
+ if not left_join_table:
913
+ continue
914
+
915
+ assert not (
916
+ len(left_join_table) > 1
917
+ ), "Cannot combine JOIN predicates from different tables"
918
+
919
+ for col in join_cols:
920
+ col.set("join_mark", False)
921
+
922
+ joins_ons[left_join_table.pop()].append(cond)
923
+
924
+ old_joins = {join.alias_or_name: join for join in joins}
925
+ new_joins = {}
926
+ query_from = query.args["from"]
927
+
928
+ for table, predicates in joins_ons.items():
929
+ join_what = old_joins.get(table, query_from).this.copy()
930
+ new_joins[join_what.alias_or_name] = exp.Join(
931
+ this=join_what, on=exp.and_(*predicates), kind="LEFT"
932
+ )
933
+
934
+ for p in predicates:
935
+ while isinstance(p.parent, exp.Paren):
936
+ p.parent.replace(p)
937
+
938
+ parent = p.parent
939
+ p.pop()
940
+ if isinstance(parent, exp.Binary):
941
+ parent.replace(parent.right if parent.left is None else parent.left)
942
+ elif isinstance(parent, exp.Where):
943
+ parent.pop()
944
+
945
+ if query_from.alias_or_name in new_joins:
946
+ only_old_joins = old_joins.keys() - new_joins.keys()
947
+ assert (
948
+ len(only_old_joins) >= 1
949
+ ), "Cannot determine which table to use in the new FROM clause"
950
+
951
+ new_from_name = list(only_old_joins)[0]
952
+ query.set("from", exp.From(this=old_joins[new_from_name].this))
953
+
954
+ if new_joins:
955
+ for n, j in old_joins.items(): # preserve any other joins
956
+ if n not in new_joins and n != query.args["from"].name:
957
+ if not j.kind:
958
+ j.set("kind", "CROSS")
959
+ new_joins[n] = j
960
+ query.set("joins", list(new_joins.values()))
961
+
962
+ return expression
963
+
964
+
965
+ def any_to_exists(expression: exp.Expression) -> exp.Expression:
966
+ """
967
+ Transform ANY operator to Spark's EXISTS
968
+
969
+ For example,
970
+ - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
971
+ - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)
972
+
973
+ Both ANY and EXISTS accept queries but currently only array expressions are supported for this
974
+ transformation
975
+ """
976
+ if isinstance(expression, exp.Select):
977
+ for any_expr in expression.find_all(exp.Any):
978
+ this = any_expr.this
979
+ if isinstance(this, exp.Query):
980
+ continue
981
+
982
+ binop = any_expr.parent
983
+ if isinstance(binop, exp.Binary):
984
+ lambda_arg = exp.to_identifier("x")
985
+ any_expr.replace(lambda_arg)
986
+ lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
987
+ binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))
988
+
989
+ return expression
990
+
991
+
992
+ def eliminate_window_clause(expression: exp.Expression) -> exp.Expression:
993
+ """Eliminates the `WINDOW` query clause by inling each named window."""
994
+ if isinstance(expression, exp.Select) and expression.args.get("windows"):
995
+ from sqlglot.optimizer.scope import find_all_in_scope
996
+
997
+ windows = expression.args["windows"]
998
+ expression.set("windows", None)
999
+
1000
+ window_expression: t.Dict[str, exp.Expression] = {}
1001
+
1002
+ def _inline_inherited_window(window: exp.Expression) -> None:
1003
+ inherited_window = window_expression.get(window.alias.lower())
1004
+ if not inherited_window:
1005
+ return
1006
+
1007
+ window.set("alias", None)
1008
+ for key in ("partition_by", "order", "spec"):
1009
+ arg = inherited_window.args.get(key)
1010
+ if arg:
1011
+ window.set(key, arg.copy())
1012
+
1013
+ for window in windows:
1014
+ _inline_inherited_window(window)
1015
+ window_expression[window.name.lower()] = window
1016
+
1017
+ for window in find_all_in_scope(expression, exp.Window):
1018
+ _inline_inherited_window(window)
1019
+
1020
+ return expression