altimate-code 0.5.2 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101):
  1. package/CHANGELOG.md +12 -0
  2. package/bin/altimate +6 -0
  3. package/bin/altimate-code +6 -0
  4. package/dbt-tools/bin/altimate-dbt +2 -0
  5. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/__init__.py +0 -0
  6. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/fetch_schema.py +35 -0
  7. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/utils.py +353 -0
  8. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/validate_sql.py +114 -0
  9. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__init__.py +178 -0
  10. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__main__.py +96 -0
  11. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/_typing.py +17 -0
  12. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/__init__.py +3 -0
  13. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/__init__.py +18 -0
  14. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/_typing.py +18 -0
  15. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/column.py +332 -0
  16. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/dataframe.py +866 -0
  17. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/functions.py +1267 -0
  18. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/group.py +59 -0
  19. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/normalize.py +78 -0
  20. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/operations.py +53 -0
  21. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/readwriter.py +108 -0
  22. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/session.py +190 -0
  23. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/transforms.py +9 -0
  24. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/types.py +212 -0
  25. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/util.py +32 -0
  26. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/window.py +134 -0
  27. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/__init__.py +118 -0
  28. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/athena.py +166 -0
  29. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/bigquery.py +1331 -0
  30. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/clickhouse.py +1393 -0
  31. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/databricks.py +131 -0
  32. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dialect.py +1915 -0
  33. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/doris.py +561 -0
  34. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/drill.py +157 -0
  35. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/druid.py +20 -0
  36. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/duckdb.py +1159 -0
  37. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dune.py +16 -0
  38. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/hive.py +787 -0
  39. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/materialize.py +94 -0
  40. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/mysql.py +1324 -0
  41. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/oracle.py +378 -0
  42. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/postgres.py +778 -0
  43. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/presto.py +788 -0
  44. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/prql.py +203 -0
  45. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/redshift.py +448 -0
  46. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/risingwave.py +78 -0
  47. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/snowflake.py +1464 -0
  48. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark.py +202 -0
  49. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark2.py +349 -0
  50. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/sqlite.py +320 -0
  51. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/starrocks.py +343 -0
  52. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tableau.py +61 -0
  53. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/teradata.py +356 -0
  54. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/trino.py +115 -0
  55. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tsql.py +1403 -0
  56. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/diff.py +456 -0
  57. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/errors.py +93 -0
  58. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/__init__.py +95 -0
  59. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/context.py +101 -0
  60. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/env.py +246 -0
  61. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/python.py +460 -0
  62. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/table.py +155 -0
  63. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/expressions.py +8870 -0
  64. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/generator.py +4993 -0
  65. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/helper.py +582 -0
  66. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/jsonpath.py +227 -0
  67. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/lineage.py +423 -0
  68. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/__init__.py +11 -0
  69. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/annotate_types.py +589 -0
  70. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/canonicalize.py +222 -0
  71. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_ctes.py +43 -0
  72. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_joins.py +181 -0
  73. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py +189 -0
  74. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/isolate_table_selects.py +50 -0
  75. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/merge_subqueries.py +415 -0
  76. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize.py +200 -0
  77. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize_identifiers.py +64 -0
  78. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimize_joins.py +91 -0
  79. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimizer.py +94 -0
  80. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_predicates.py +222 -0
  81. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_projections.py +172 -0
  82. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify.py +104 -0
  83. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_columns.py +1024 -0
  84. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_tables.py +155 -0
  85. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/scope.py +904 -0
  86. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/simplify.py +1587 -0
  87. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/unnest_subqueries.py +302 -0
  88. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/parser.py +8501 -0
  89. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/planner.py +463 -0
  90. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/schema.py +588 -0
  91. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/serde.py +68 -0
  92. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/time.py +687 -0
  93. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/tokens.py +1520 -0
  94. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/transforms.py +1020 -0
  95. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/trie.py +81 -0
  96. package/dbt-tools/dist/altimate_python_packages/dbt_core_integration.py +825 -0
  97. package/dbt-tools/dist/altimate_python_packages/dbt_utils.py +157 -0
  98. package/dbt-tools/dist/index.js +23859 -0
  99. package/package.json +13 -13
  100. package/postinstall.mjs +42 -0
  101. package/skills/altimate-setup/SKILL.md +31 -0
@@ -0,0 +1,1024 @@
1
+ from __future__ import annotations
2
+
3
+ import itertools
4
+ import typing as t
5
+
6
+ from sqlglot import alias, exp
7
+ from sqlglot.dialects.dialect import Dialect, DialectType
8
+ from sqlglot.errors import OptimizeError
9
+ from sqlglot.helper import seq_get, SingleValuedMapping
10
+ from sqlglot.optimizer.annotate_types import TypeAnnotator
11
+ from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
12
+ from sqlglot.optimizer.simplify import simplify_parens
13
+ from sqlglot.schema import Schema, ensure_schema
14
+
15
+ if t.TYPE_CHECKING:
16
+ from sqlglot._typing import E
17
+
18
+
19
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: DialectType = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.
        dialect: The dialect used to build/normalize the schema when `schema` is a dict.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema, dialect=dialect)
    annotator = TypeAnnotator(schema)
    # When the caller doesn't say, only infer a schema if we don't have one at all
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS
    bigquery = dialect == "bigquery"

    # Scopes are traversed bottom-up, so inner queries are qualified first
    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects resolve aliases before the schema (early alias-ref expansion)
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=bigquery,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if bigquery:
            annotator.annotate_scope(scope)

    return expression
114
+
115
+
116
def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    ambiguous: t.List[exp.Expression] = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        candidates = scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            column = scope.external_columns[0]
            for_table = f" for table: '{column.table}'" if column.table else ""
            raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

        if candidates and scope.pivots and scope.pivots[0].unpivot:
            # New columns produced by the UNPIVOT can't be qualified, but there may be columns
            # under the UNPIVOT's IN clause that can and should be qualified. We recompute
            # this list here to ensure those in the former category will be excluded.
            produced = set(_unpivot_columns(scope.pivots[0]))
            candidates = [c for c in candidates if c not in produced]

        ambiguous.extend(candidates)

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")

    return expression
141
+
142
+
143
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Yield the UNPIVOT's name columns (from its IN fields), then its value columns."""
    for field in unpivot.fields:
        if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
            yield field.this

    for expression in unpivot.expressions:
        yield from expression.find_all(exp.Column)
152
+
153
+
154
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        parent = derived_table.parent
        if isinstance(parent, exp.With) and parent.recursive:
            # Recursive CTEs rely on their column list to resolve self-references
            continue

        alias_node = derived_table.args.get("alias")
        if alias_node:
            alias_node.args.pop("columns", None)
166
+
167
+
168
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite every JOIN ... USING (...) in `scope` into an explicit ON condition.

    Returns:
        Mapping of each USING column name to an ordered "set" (dict with None values)
        of the source names it appears in; used later to COALESCE unqualified references.
    """
    # Maps each column name to the first (left-most) source that provides it
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources that aren't joins, i.e. the FROM clause table(s)
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        # The most recently seen source is the left-hand side of this join
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only fail if we have enough schema info to know the join is impossible
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # Later joins must COALESCE over every prior source that has the column
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        # Replace the USING clause with the equivalent ON condition
        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        # Unqualified references to USING columns become COALESCE over all their sources
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables
264
+
265
+
266
def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
        => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y

    Args:
        scope: The scope whose SELECT is rewritten in place.
        resolver: Used to resolve unqualified columns to their source tables.
        dialect: The active dialect; Oracle is skipped, BigQuery/Snowflake get special cases.
        expand_only_groupby: If True, only expand aliases inside the GROUP BY clause.
    """
    expression = scope.expression

    # Oracle does not support lateral alias references, so there is nothing to expand
    if not isinstance(expression, exp.Select) or dialect == "oracle":
        return

    # Maps projection alias -> (aliased expression, 1-based projection position)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
    projections = {s.alias_or_name for s in expression.selects}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        # Expand alias references inside `node`, optionally qualifying tables
        # (HAVING/QUALIFY) or rewriting to positional literals (GROUP BY).
        is_group_by = isinstance(node, exp.Group)
        is_having = isinstance(node, exp.Having)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            skip_replace = False
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                # Don't expand an aggregate alias inside another aggregate
                # (unless the reference sits under a window) — it would nest aggregations
                skip_replace = bool(
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )

                # BigQuery's having clause gets confused if an alias matches a source.
                # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
                # If HAVING x is expanded to max(x.b), bigquery treats x as the new projection x instead of the table
                if is_having and dialect == "bigquery":
                    skip_replace = skip_replace or any(
                        node.parts[0].name in projections
                        for node in alias_expr.find_all(exp.Column)
                    )

            if table and (not alias_expr or skip_replace):
                column.set("table", table)
            elif not column.table and alias_expr and not skip_replace:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    # Parenthesize to preserve precedence, then simplify redundant parens
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    # Build the alias map while also expanding lateral references between projections
    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()
362
+
363
+
364
def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Replace positional (ordinal) references in GROUP BY with the projections they point to."""
    select = scope.expression
    group = select.args.get("group")
    if not group:
        return

    expanded = _expand_positional_references(scope, group.expressions, dialect)
    group.set("expressions", expanded)
    select.set("group", group)
372
+
373
+
374
def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional references in ORDER BY and DISTINCT ON, and — when the query
    is grouped — replace expressions that match a projection with that projection's alias.
    """
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            # DISTINCT ON expressions live under the Distinct node's "on" arg
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            # ORDER BY wraps each expression in an Ordered node; unwrap to the inner expression
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.schema.dialect, alias=True
            ),
        ):
            # Qualify columns inside aggregates before replacing the node
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            # In grouped queries, refer to projections by their output alias
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )
409
+
410
+
411
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """
    Replace integer literals in `expressions` with the projection they refer to.

    Args:
        scope: The scope whose SELECT list resolves the ordinals.
        expressions: Candidate expressions (e.g. GROUP BY / ORDER BY items).
        dialect: Used for BigQuery-specific ambiguity checks.
        alias: If True, substitute the projection's alias instead of its expression.

    Returns:
        A new list with ordinals expanded; non-integer nodes are passed through unchanged.
    """
    new_nodes: t.List[exp.Expression] = []
    # Lazily computed set of projection names that collide with source names (BigQuery only)
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                # Keep the ordinal when expansion would be invalid or ambiguous
                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes
455
+
456
+
457
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """Return the projection referenced by a 1-based ordinal literal; it must be aliased."""
    position = int(node.this) - 1
    try:
        return scope.expression.selects[position].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")
462
+
463
+
464
def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        # A "table" that isn't a known source (here or in a correlated parent)
        # is actually a struct column being dotted into
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()
503
+
504
+
505
def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            # Column is already qualified; validate it exists in that source (when schema known)
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    # Qualify columns that appear inside PIVOT expressions themselves
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
539
+
540
+
541
def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column

    Returns:
        One aliased column per struct field, or an empty list when the star
        cannot be expanded (non-struct, anonymous, or ambiguous fields).
    """

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        # Rebuild the full path table.col.f1...field as a Column with `fields`
        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections
593
+
594
+
595
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections

    Args:
        scope: The scope whose SELECT list is rewritten in place.
        resolver: Column/table resolver used to enumerate source columns.
        using_column_tables: Mapping of column name -> tables joined via USING,
            used to coalesce shared join columns into a single output column.
        pseudocolumns: Upper-cased pseudocolumn names to drop from expansion.
        annotator: Type annotator, needed up front for BigQuery struct stars.
    """

    new_selections: t.List[exp.Expression] = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns: t.Set[str] = set()

    # Only the first pivot in scope is considered for star expansion
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            for field in pivot.fields:
                if isinstance(field, exp.In):
                    pivot_exclude_columns.update(
                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                    )

        else:
            # Columns referenced inside the pivot are consumed by it, so they
            # must not be re-emitted by the star expansion
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            # Bare `*`: expand over every selected source
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                # Qualified `tbl.*`: expand over just that table
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                # `t.struct_col.*`: flatten the struct's fields
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        # Non-star selections (and unexpandable stars) pass through untouched
        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # Unknown column sets make expansion impossible; bail out entirely
            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # USING join columns are emitted once, coalesced across tables
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
717
+
718
+
719
+ def _add_except_columns(
720
+ expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
721
+ ) -> None:
722
+ except_ = expression.args.get("except")
723
+
724
+ if not except_:
725
+ return
726
+
727
+ columns = {e.name for e in except_}
728
+
729
+ for table in tables:
730
+ except_columns[id(table)] = columns
731
+
732
+
733
+ def _add_rename_columns(
734
+ expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
735
+ ) -> None:
736
+ rename = expression.args.get("rename")
737
+
738
+ if not rename:
739
+ return
740
+
741
+ columns = {e.this.name: e.alias for e in rename}
742
+
743
+ for table in tables:
744
+ rename_columns[id(table)] = columns
745
+
746
+
747
+ def _add_replace_columns(
748
+ expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
749
+ ) -> None:
750
+ replace = expression.args.get("replace")
751
+
752
+ if not replace:
753
+ return
754
+
755
+ columns = {e.alias: e for e in replace}
756
+
757
+ for table in tables:
758
+ replace_columns[id(table)] = columns
759
+
760
+
761
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased

    Unaliased, non-star selections get a generated `_col_{i}` alias (or their
    own output name when they have one). Outer column names, e.g. from a CTE's
    column list, override any generated alias. Mutates the expression in place.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        # Stop once selections are exhausted (extra outer columns are ignored)
        if selection is None or isinstance(selection, exp.QueryTransform):
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            # The outer column name wins over any generated/existing alias
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
793
+
794
+
795
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore
800
+
801
+
802
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if not cte.alias_column_names:
            continue

        pushed_down = []
        for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
            # Existing aliases are renamed; bare projections get wrapped
            if isinstance(projection, exp.Alias):
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            pushed_down.append(projection)

        cte.this.set("expressions", pushed_down)

    return expression
832
+
833
+
834
+ class Resolver:
835
+ """
836
+ Helper for resolving columns.
837
+
838
+ This is a class so we can lazily load some things and easily share them across functions.
839
+ """
840
+
841
+ def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
842
+ self.scope = scope
843
+ self.schema = schema
844
+ self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
845
+ self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
846
+ self._all_columns: t.Optional[t.Set[str]] = None
847
+ self._infer_schema = infer_schema
848
+ self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
849
+
850
    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        # Build the unambiguous column -> table mapping once, lazily
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # Fall back to inference: if exactly one source has an unknown
            # column set, assume the column comes from it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        # NOTE(review): table_name may still be None here; presumably
        # exp.to_identifier(None) yields None — confirm against sqlglot API
        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the node that actually carries this alias, then
            # return the alias identifier itself so quoting is preserved
            while node and node.alias != table_name:
                node = node.parent

            node_alias = node.args.get("alias")
            if node_alias:
                return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)
889
+
890
+ @property
891
+ def all_columns(self) -> t.Set[str]:
892
+ """All available columns of all sources in this scope"""
893
+ if self._all_columns is None:
894
+ self._all_columns = {
895
+ column for columns in self._get_all_source_columns().values() for column in columns
896
+ }
897
+ return self._all_columns
898
+
899
    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`.

        Results are memoized per (name, only_visible) pair.

        Raises:
            OptimizeError: If `name` is not a source in this scope.
        """
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                # Physical table: columns come from the schema
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                set_op = source.expression

                # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
                on_column_list = set_op.args.get("on")

                if on_column_list:
                    # The resulting columns are the columns in the ON clause:
                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
                    columns = [col.name for col in on_column_list]
                elif set_op.side or set_op.kind:
                    side = set_op.side
                    kind = set_op.kind

                    left = set_op.left.named_selects
                    right = set_op.right.named_selects

                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
                    # NOTE(review): side/kind combinations other than LEFT/FULL/INNER fall
                    # through without assigning `columns` — presumably unreachable for
                    # parsed SQL; confirm against the parser's accepted modifiers
                    if side == "LEFT":
                        columns = left
                    elif side == "FULL":
                        columns = list(dict.fromkeys(left + right))
                    elif kind == "INNER":
                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
                else:
                    columns = set_op.named_selects
            else:
                select = seq_get(source.expression.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]
977
+
978
+ def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
979
+ if self._source_columns is None:
980
+ self._source_columns = {
981
+ source_name: self.get_source_columns(source_name)
982
+ for source_name, source in itertools.chain(
983
+ self.scope.selected_sources.items(), self.scope.lateral_sources.items()
984
+ )
985
+ }
986
+ return self._source_columns
987
+
988
+ def _get_unambiguous_columns(
989
+ self, source_columns: t.Dict[str, t.Sequence[str]]
990
+ ) -> t.Mapping[str, str]:
991
+ """
992
+ Find all the unambiguous columns in sources.
993
+
994
+ Args:
995
+ source_columns: Mapping of names to source columns.
996
+
997
+ Returns:
998
+ Mapping of column name to source name.
999
+ """
1000
+ if not source_columns:
1001
+ return {}
1002
+
1003
+ source_columns_pairs = list(source_columns.items())
1004
+
1005
+ first_table, first_columns = source_columns_pairs[0]
1006
+
1007
+ if len(source_columns_pairs) == 1:
1008
+ # Performance optimization - avoid copying first_columns if there is only one table.
1009
+ return SingleValuedMapping(first_columns, first_table)
1010
+
1011
+ unambiguous_columns = {col: first_table for col in first_columns}
1012
+ all_columns = set(unambiguous_columns)
1013
+
1014
+ for table, columns in source_columns_pairs[1:]:
1015
+ unique = set(columns)
1016
+ ambiguous = all_columns.intersection(unique)
1017
+ all_columns.update(columns)
1018
+
1019
+ for column in ambiguous:
1020
+ unambiguous_columns.pop(column, None)
1021
+ for column in unique.difference(ambiguous):
1022
+ unambiguous_columns[column] = table
1023
+
1024
+ return unambiguous_columns