sqlglot 27.27.0__py3-none-any.whl → 28.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlglot/__init__.py +1 -0
- sqlglot/__main__.py +6 -4
- sqlglot/_version.py +2 -2
- sqlglot/dialects/bigquery.py +118 -279
- sqlglot/dialects/clickhouse.py +73 -5
- sqlglot/dialects/databricks.py +38 -1
- sqlglot/dialects/dialect.py +354 -275
- sqlglot/dialects/dremio.py +4 -1
- sqlglot/dialects/duckdb.py +754 -25
- sqlglot/dialects/exasol.py +243 -10
- sqlglot/dialects/hive.py +8 -8
- sqlglot/dialects/mysql.py +14 -4
- sqlglot/dialects/oracle.py +29 -0
- sqlglot/dialects/postgres.py +60 -26
- sqlglot/dialects/presto.py +47 -16
- sqlglot/dialects/redshift.py +16 -0
- sqlglot/dialects/risingwave.py +3 -0
- sqlglot/dialects/singlestore.py +12 -3
- sqlglot/dialects/snowflake.py +239 -218
- sqlglot/dialects/spark.py +15 -4
- sqlglot/dialects/spark2.py +11 -48
- sqlglot/dialects/sqlite.py +10 -0
- sqlglot/dialects/starrocks.py +3 -0
- sqlglot/dialects/teradata.py +5 -8
- sqlglot/dialects/trino.py +6 -0
- sqlglot/dialects/tsql.py +61 -22
- sqlglot/diff.py +4 -2
- sqlglot/errors.py +69 -0
- sqlglot/executor/__init__.py +5 -10
- sqlglot/executor/python.py +1 -29
- sqlglot/expressions.py +637 -100
- sqlglot/generator.py +160 -43
- sqlglot/helper.py +2 -44
- sqlglot/lineage.py +10 -4
- sqlglot/optimizer/annotate_types.py +247 -140
- sqlglot/optimizer/canonicalize.py +6 -1
- sqlglot/optimizer/eliminate_joins.py +1 -1
- sqlglot/optimizer/eliminate_subqueries.py +2 -2
- sqlglot/optimizer/merge_subqueries.py +5 -5
- sqlglot/optimizer/normalize.py +20 -13
- sqlglot/optimizer/normalize_identifiers.py +17 -3
- sqlglot/optimizer/optimizer.py +4 -0
- sqlglot/optimizer/pushdown_predicates.py +1 -1
- sqlglot/optimizer/qualify.py +18 -10
- sqlglot/optimizer/qualify_columns.py +122 -275
- sqlglot/optimizer/qualify_tables.py +128 -76
- sqlglot/optimizer/resolver.py +374 -0
- sqlglot/optimizer/scope.py +27 -16
- sqlglot/optimizer/simplify.py +1075 -959
- sqlglot/optimizer/unnest_subqueries.py +12 -2
- sqlglot/parser.py +296 -170
- sqlglot/planner.py +2 -2
- sqlglot/schema.py +15 -4
- sqlglot/tokens.py +42 -7
- sqlglot/transforms.py +77 -22
- sqlglot/typing/__init__.py +316 -0
- sqlglot/typing/bigquery.py +376 -0
- sqlglot/typing/hive.py +12 -0
- sqlglot/typing/presto.py +24 -0
- sqlglot/typing/snowflake.py +505 -0
- sqlglot/typing/spark2.py +58 -0
- sqlglot/typing/tsql.py +9 -0
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
- sqlglot-28.4.0.dist-info/RECORD +92 -0
- sqlglot-27.27.0.dist-info/RECORD +0 -84
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
sqlglot/optimizer/qualify_columns.py (+122 -275)

@@ -5,9 +5,10 @@ import typing as t
 
 from sqlglot import alias, exp
 from sqlglot.dialects.dialect import Dialect, DialectType
-from sqlglot.errors import OptimizeError
-from sqlglot.helper import seq_get
+from sqlglot.errors import OptimizeError, highlight_sql
+from sqlglot.helper import seq_get
 from sqlglot.optimizer.annotate_types import TypeAnnotator
+from sqlglot.optimizer.resolver import Resolver
 from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 from sqlglot.optimizer.simplify import simplify_parens
 from sqlglot.schema import Schema, ensure_schema
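
Note: the Resolver helper moves out of this module into the new sqlglot/optimizer/resolver.py
(+374 lines in the file list above; the old class is deleted at the bottom of this diff). A
minimal sketch of an import-path fallback downstream code might use; the try/except ordering
is an assumption, not something this diff prescribes:

    # 28.x exposes Resolver from its own module; 27.x exposed it here.
    try:
        from sqlglot.optimizer.resolver import Resolver  # 28.x location
    except ImportError:
        from sqlglot.optimizer.qualify_columns import Resolver  # 27.x location
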
@@ -54,22 +55,17 @@ def qualify_columns(
     schema = ensure_schema(schema, dialect=dialect)
     annotator = TypeAnnotator(schema)
     infer_schema = schema.empty if infer_schema is None else infer_schema
-    dialect = Dialect.get_or_raise(schema.dialect)
+    dialect = schema.dialect or Dialect()
     pseudocolumns = dialect.PSEUDOCOLUMNS
-    bigquery = dialect == "bigquery"
 
     for scope in traverse_scope(expression):
+        if dialect.PREFER_CTE_ALIAS_COLUMN:
+            pushdown_cte_alias_columns(scope)
+
         scope_expression = scope.expression
         is_select = isinstance(scope_expression, exp.Select)
 
-
-        # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
-        # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
-        scope_expression.transform(
-            lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
-            copy=False,
-        )
-        scope.clear_cache()
+        _separate_pseudocolumns(scope, pseudocolumns)
 
         resolver = Resolver(scope, schema, infer_schema=infer_schema)
         _pop_table_column_aliases(scope.ctes)
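
Note: dialect handling at the top of qualify_columns is now flag-driven: the dialect falls
back to a default Dialect(), and the Snowflake-style CTE alias pushdown runs per scope behind
PREFER_CTE_ALIAS_COLUMN. A hedged sketch of the unchanged public entry point (the schema and
dialect values here are illustrative):

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    expression = sqlglot.parse_one("SELECT a FROM t", dialect="snowflake")
    qualified = qualify(expression, schema={"t": {"a": "INT"}}, dialect="snowflake")
    print(qualified.sql("snowflake"))
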
@@ -81,11 +77,15 @@ def qualify_columns(
             scope,
             resolver,
             dialect,
-            expand_only_groupby=bigquery,
+            expand_only_groupby=dialect.EXPAND_ONLY_GROUP_ALIAS_REF,
         )
 
         _convert_columns_to_dots(scope, resolver)
-        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
+        _qualify_columns(
+            scope,
+            resolver,
+            allow_partial_qualification=allow_partial_qualification,
+        )
 
         if not schema.empty and expand_alias_refs:
             _expand_alias_refs(scope, resolver, dialect)

@@ -107,13 +107,13 @@ def qualify_columns(
             # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
             _expand_order_by_and_distinct_on(scope, resolver)
 
-        if bigquery:
+        if dialect.ANNOTATE_ALL_SCOPES:
             annotator.annotate_scope(scope)
 
     return expression
 
 
-def validate_qualify_columns(expression: E) -> E:
+def validate_qualify_columns(expression: E, sql: t.Optional[str] = None) -> E:
     """Raise an `OptimizeError` if any columns aren't qualified"""
     all_unqualified_columns = []
     for scope in traverse_scope(expression):
@@ -123,7 +123,19 @@ def validate_qualify_columns(expression: E) -> E:
         if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
             column = scope.external_columns[0]
             for_table = f" for table: '{column.table}'" if column.table else ""
-            raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
+            line = column.this.meta.get("line")
+            col = column.this.meta.get("col")
+            start = column.this.meta.get("start")
+            end = column.this.meta.get("end")
+
+            error_msg = f"Column '{column.name}' could not be resolved{for_table}."
+            if line and col:
+                error_msg += f" Line: {line}, Col: {col}"
+            if sql and start is not None and end is not None:
+                formatted_sql = highlight_sql(sql, [(start, end)])[0]
+                error_msg += f"\n {formatted_sql}"
+
+            raise OptimizeError(error_msg)
 
         if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
             # New columns produced by the UNPIVOT can't be qualified, but there may be columns
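
Note: validate_qualify_columns now accepts the original SQL string and, together with the
line/col/start/end metadata and the new sqlglot.errors.highlight_sql helper, points at the
offending span. A hedged sketch (the exact message format is an assumption):

    import sqlglot
    from sqlglot.errors import OptimizeError
    from sqlglot.optimizer.qualify_columns import validate_qualify_columns

    sql = "SELECT y.a FROM t"
    expression = sqlglot.parse_one(sql)

    try:
        validate_qualify_columns(expression, sql=sql)
    except OptimizeError as e:
        # e.g. "Column 'a' could not be resolved for table: 'y'. Line: ..., Col: ..."
        # plus a highlighted SQL snippet when position metadata is available
        print(e)
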
@@ -135,11 +147,46 @@ def validate_qualify_columns(expression: E) -> E:
             all_unqualified_columns.extend(unqualified_columns)
 
     if all_unqualified_columns:
-        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
+        first_column = all_unqualified_columns[0]
+        line = first_column.this.meta.get("line")
+        col = first_column.this.meta.get("col")
+        start = first_column.this.meta.get("start")
+        end = first_column.this.meta.get("end")
+
+        error_msg = f"Ambiguous column '{first_column.name}'"
+        if line and col:
+            error_msg += f" (Line: {line}, Col: {col})"
+        if sql and start is not None and end is not None:
+            formatted_sql = highlight_sql(sql, [(start, end)])[0]
+            error_msg += f"\n {formatted_sql}"
+
+        raise OptimizeError(error_msg)
 
     return expression
 
 
+def _separate_pseudocolumns(scope: Scope, pseudocolumns: t.Set[str]) -> None:
+    if not pseudocolumns:
+        return
+
+    has_pseudocolumns = False
+    scope_expression = scope.expression
+
+    for column in scope.columns:
+        name = column.name.upper()
+        if name not in pseudocolumns:
+            continue
+
+        if name != "LEVEL" or (
+            isinstance(scope_expression, exp.Select) and scope_expression.args.get("connect")
+        ):
+            column.replace(exp.Pseudocolumn(**column.args))
+            has_pseudocolumns = True
+
+    if has_pseudocolumns:
+        scope.clear_cache()
+
+
 def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
     name_columns = [
         field.this
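
Note: the old in-loop transform that demoted LEVEL to a bare identifier is replaced by
_separate_pseudocolumns, which rewrites dialect pseudocolumns (LEVEL only inside a
CONNECT BY select) into the new exp.Pseudocolumn node. A hedged sketch; the schema and the
assumption that LEVEL survives qualification as a Pseudocolumn are illustrative:

    import sqlglot
    from sqlglot import exp
    from sqlglot.optimizer.qualify import qualify

    sql = "SELECT LEVEL FROM t START WITH a = 1 CONNECT BY PRIOR a = b"
    qualified = qualify(
        sqlglot.parse_one(sql, dialect="oracle"),
        schema={"t": {"a": "INT", "b": "INT"}},
        dialect="oracle",
    )
    print(qualified.find(exp.Pseudocolumn))
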
@@ -274,16 +321,17 @@ def _expand_alias_refs(
     """
     expression = scope.expression
 
-    if not isinstance(expression, exp.Select) or dialect
+    if not isinstance(expression, exp.Select) or dialect.DISABLES_ALIAS_REF_EXPANSION:
         return
 
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
     projections = {s.alias_or_name for s in expression.selects}
-
+    replaced = False
 
     def replace_columns(
         node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
     ) -> None:
+        nonlocal replaced
         is_group_by = isinstance(node, exp.Group)
         is_having = isinstance(node, exp.Having)
         if not node or (expand_only_groupby and not is_group_by):

@@ -315,12 +363,12 @@ def _expand_alias_refs(
         # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
         # If "HAVING x" is expanded to "HAVING max(x.b)", BQ would blindly replace the "x" reference with the projection MAX(x.b)
         # i.e HAVING MAX(MAX(x.b).b), resulting in the error: "Aggregations of aggregations are not allowed"
-        if is_having and
+        if is_having and dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES:
             skip_replace = skip_replace or any(
                 node.parts[0].name in projections
                 for node in alias_expr.find_all(exp.Column)
             )
-        elif
+        elif dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES and (is_group_by or is_having):
             column_table = table.name if table else column.table
             if column_table in projections:
                 # BigQuery's GROUP BY and HAVING clauses get confused if the column name

@@ -329,6 +377,7 @@ def _expand_alias_refs(
                 # We should not qualify "id" with "custom_fields" in either clause, since the aggregation shadows the actual table
                 # and we'd get the error: "Column custom_fields contains an aggregation function, which is not allowed in GROUP BY clause"
                 column.replace(exp.to_identifier(column.name))
+                replaced = True
                 return
 
         if table and (not alias_expr or skip_replace):

@@ -339,7 +388,9 @@ def _expand_alias_refs(
         ):
             if literal_index:
                 column.replace(exp.Literal.number(i))
+                replaced = True
             else:
+                replaced = True
                 column = column.replace(exp.paren(alias_expr))
                 simplified = simplify_parens(column, dialect)
                 if simplified is not column:
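
Note: throughout this file, hardcoded dialect-name checks give way to class-level feature
flags (DISABLES_ALIAS_REF_EXPANSION, PROJECTION_ALIASES_SHADOW_SOURCE_NAMES,
SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS, ANNOTATE_ALL_SCOPES, ...) defined on Dialect in
sqlglot/dialects/dialect.py. A hedged sketch of how a custom dialect could opt in; the flag
names come from this diff, but the chosen values are illustrative assumptions, not defaults:

    from sqlglot.dialects.dialect import Dialect

    class MyDialect(Dialect):
        # Values below are assumptions for illustration only.
        DISABLES_ALIAS_REF_EXPANSION = False
        PROJECTION_ALIASES_SHADOW_SOURCE_NAMES = True   # BigQuery-like behavior
        SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS = True   # Snowflake-like behavior
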
@@ -370,16 +421,15 @@ def _expand_alias_refs(
     replace_columns(expression.args.get("having"), resolve_table=True)
     replace_columns(expression.args.get("qualify"), resolve_table=True)
 
-
-    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
-    if dialect == "snowflake":
+    if dialect.SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS:
         for join in expression.args.get("joins") or []:
             replace_columns(join)
 
-
+    if replaced:
+        scope.clear_cache()
 
 
-def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
+def _expand_group_by(scope: Scope, dialect: Dialect) -> None:
     expression = scope.expression
     group = expression.args.get("group")
     if not group:

@@ -405,7 +455,7 @@ def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
         for original, expanded in zip(
             modifier_expressions,
             _expand_positional_references(
-                scope, modifier_expressions, resolver.schema.dialect, alias=True
+                scope, modifier_expressions, resolver.dialect, alias=True
             ),
         ):
             for agg in original.find_all(exp.AggFunc):

@@ -427,7 +477,7 @@ def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
 
 
 def _expand_positional_references(
-    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
+    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: Dialect, alias: bool = False
 ) -> t.List[exp.Expression]:
     new_nodes: t.List[exp.Expression] = []
     ambiguous_projections = None

@@ -441,7 +491,7 @@ def _expand_positional_references(
         else:
             select = select.this
 
-        if dialect == "bigquery":
+        if dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES:
             if ambiguous_projections is None:
                 # When a projection name is also a source name and it is referenced in the
                 # GROUP BY clause, BQ can't understand what the identifier corresponds to

@@ -482,10 +532,10 @@ def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
 
 def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
     """
-    Converts `Column` instances that represent
+    Converts `Column` instances that represent STRUCT or JSON field lookup into chained `Dots`.
 
-
-
+    These lookups may be parsed as columns (e.g. "col"."field"."field2"), but they need to be
+    normalized to `Dot(Dot(...(<table>.<column>, field1), field2, ...))` to be qualified properly.
     """
     converted = False
     for column in itertools.chain(scope.columns, scope.stars):

@@ -493,6 +543,7 @@ def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
             continue
 
         column_table: t.Optional[str | exp.Identifier] = column.table
+        dot_parts = column.meta.pop("dot_parts", [])
         if (
             column_table
             and column_table not in scope.sources

@@ -508,12 +559,20 @@
                 # The struct is already qualified, but we still need to change the AST
                 column_table = root
                 root, *parts = parts
+                was_qualified = True
             else:
                 column_table = resolver.get_table(root.name)
+                was_qualified = False
 
             if column_table:
                 converted = True
-                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
+                new_column = exp.column(root, table=column_table)
+
+                if dot_parts:
+                    # Remove the actual column parts from the rest of dot parts
+                    new_column.meta["dot_parts"] = dot_parts[2 if was_qualified else 1 :]
+
+                column.replace(exp.Dot.build([new_column, *parts]))
 
     if converted:
         # We want to re-aggregate the converted columns, otherwise they'd be skipped in
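
Note: _convert_columns_to_dots rewrites "col"."field"."field2" chains that were parsed as
plain columns into nested Dot expressions, and now threads a "dot_parts" meta entry through
so later passes can tell column parts from field parts. A hedged sketch of the observable
effect (the schema is illustrative):

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    sql = "SELECT t.s.f FROM t"
    schema = {"t": {"s": "STRUCT<f INT>"}}
    qualified = qualify(sqlglot.parse_one(sql, dialect="bigquery"), schema=schema, dialect="bigquery")
    print(qualified.sql("bigquery"))
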
@@ -521,7 +580,11 @@
         scope.clear_cache()
 
 
-def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
+def _qualify_columns(
+    scope: Scope,
+    resolver: Resolver,
+    allow_partial_qualification: bool,
+) -> None:
     """Disambiguate columns, ensuring each column specifies a source"""
     for column in scope.columns:
         column_table = column.table

@@ -545,15 +608,16 @@ def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualificati
             continue
 
         # column_table can be a '' because bigquery unnest has no table alias
-        column_table = resolver.get_table(column_name)
+        column_table = resolver.get_table(column)
+
         if column_table:
             column.set("table", column_table)
         elif (
-            resolver.schema.dialect == "bigquery"
+            resolver.dialect.TABLES_REFERENCEABLE_AS_COLUMNS
             and len(column.parts) == 1
             and column_name in scope.selected_sources
         ):
-            # BigQuery
+            # BigQuery and Postgres allow tables to be referenced as columns, treating them as structs/records
             scope.replace(column, exp.TableColumn(this=column.this))
 
         for pivot in scope.pivots:
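
Note: tables referenceable as columns are no longer a BigQuery special case; the behavior is
gated by TABLES_REFERENCEABLE_AS_COLUMNS, which per the new comment also covers Postgres
composite-record semantics. A hedged sketch (whether Postgres actually sets the flag is
inferred from the comment, not verified here):

    import sqlglot
    from sqlglot import exp
    from sqlglot.optimizer.qualify import qualify

    qualified = qualify(
        sqlglot.parse_one("SELECT t FROM t", dialect="postgres"),
        schema={"t": {"a": "INT"}},
        dialect="postgres",
    )
    print(qualified.find(exp.TableColumn))
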
@@ -564,7 +628,7 @@ def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualificati
                 column.set("table", column_table)
 
 
-def _expand_struct_stars_bigquery(
+def _expand_struct_stars_no_parens(
     expression: exp.Dot,
 ) -> t.List[exp.Alias]:
     """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

@@ -618,7 +682,7 @@ def _expand_struct_stars_bigquery(
     return new_selections
 
 
-def _expand_struct_stars_risingwave(expression: exp.Dot) -> t.List[exp.Alias]:
+def _expand_struct_stars_with_parens(expression: exp.Dot) -> t.List[exp.Alias]:
     """[RisingWave] Expand/Flatten (<exp>.bar).*, where bar is a struct column"""
 
     # it is not (<sub_exp>).* pattern, which means we can't expand

@@ -695,7 +759,7 @@ def _expand_stars(
     rename_columns: t.Dict[int, t.Dict[str, str]] = {}
 
     coalesced_columns = set()
-    dialect = resolver.schema.dialect
+    dialect = resolver.dialect
 
     pivot_output_columns = None
     pivot_exclude_columns: t.Set[str] = set()

@@ -718,10 +782,9 @@ def _expand_stars(
             if not pivot_output_columns:
                 pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
 
-    is_bigquery = dialect == "bigquery"
-    is_risingwave = dialect == "risingwave"
-
-    if (is_bigquery or is_risingwave) and any(isinstance(col, exp.Dot) for col in scope.stars):
+    if dialect.SUPPORTS_STRUCT_STAR_EXPANSION and any(
+        isinstance(col, exp.Dot) for col in scope.stars
+    ):
         # Found struct expansion, annotate scope ahead of time
         annotator.annotate_scope(scope)

@@ -738,13 +801,16 @@ def _expand_stars(
                 _add_except_columns(expression.this, tables, except_columns)
                 _add_replace_columns(expression.this, tables, replace_columns)
                 _add_rename_columns(expression.this, tables, rename_columns)
-            elif is_bigquery:
-                struct_fields = _expand_struct_stars_bigquery(expression)
+            elif (
+                dialect.SUPPORTS_STRUCT_STAR_EXPANSION
+                and not dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS
+            ):
+                struct_fields = _expand_struct_stars_no_parens(expression)
                 if struct_fields:
                     new_selections.extend(struct_fields)
                     continue
-            elif is_risingwave:
-                struct_fields = _expand_struct_stars_risingwave(expression)
+            elif dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS:
+                struct_fields = _expand_struct_stars_with_parens(expression)
                 if struct_fields:
                     new_selections.extend(struct_fields)
                     continue

@@ -760,7 +826,7 @@ def _expand_stars(
             columns = resolver.get_source_columns(table, only_visible=True)
             columns = columns or scope.outer_columns
 
-            if pseudocolumns:
+            if pseudocolumns and dialect.EXCLUDES_PSEUDOCOLUMNS_FROM_STAR:
                 columns = [name for name in columns if name.upper() not in pseudocolumns]
 
             if not columns or "*" in columns:

@@ -814,7 +880,7 @@ def _expand_stars(
 def _add_except_columns(
     expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
 ) -> None:
-    except_ = expression.args.get("except")
+    except_ = expression.args.get("except_")
 
     if not except_:
         return
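
Note: the AST arg key for star EXCEPT lists is renamed from "except" to "except_", so code
that reads Star args directly needs the new key. A hedged sketch with a fallback covering
both versions:

    import sqlglot
    from sqlglot import exp

    star = sqlglot.parse_one("SELECT * EXCEPT (a) FROM t", dialect="bigquery").find(exp.Star)
    excepted = star.args.get("except_") or star.args.get("except")  # 28.x key, then 27.x key
    print([col.name for col in excepted])
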
@@ -894,241 +960,22 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool
     )  # type: ignore
 
 
-def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
+def pushdown_cte_alias_columns(scope: Scope) -> None:
     """
     Pushes down the CTE alias columns into the projection,
 
     This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
 
-    Example:
-        >>> import sqlglot
-        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
-        >>> pushdown_cte_alias_columns(expression).sql()
-        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
-
     Args:
-        expression: Expression to pushdown.
-
-    Returns:
-        The expression with the CTE aliases pushed down into the projection.
+        scope: Scope to find ctes to pushdown aliases.
     """
-    for cte in expression.find_all(exp.CTE):
-        if cte.alias_column_names:
+    for cte in scope.ctes:
+        if cte.alias_column_names and isinstance(cte.this, exp.Select):
             new_expressions = []
             for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                 if isinstance(projection, exp.Alias):
-                    projection.set("alias", _alias)
+                    projection.set("alias", exp.to_identifier(_alias))
                 else:
                     projection = alias(projection, alias=_alias)
                 new_expressions.append(projection)
             cte.this.set("expressions", new_expressions)
-
-    return expression
-
-
-class Resolver:
-    """
-    Helper for resolving columns.
-
-    This is a class so we can lazily load some things and easily share them across functions.
-    """
-
-    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
-        self.scope = scope
-        self.schema = schema
-        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
-        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
-        self._all_columns: t.Optional[t.Set[str]] = None
-        self._infer_schema = infer_schema
-        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
-
-    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
-        """
-        Get the table for a column name.
-
-        Args:
-            column_name: The column name to find the table for.
-        Returns:
-            The table name if it can be found/inferred.
-        """
-        if self._unambiguous_columns is None:
-            self._unambiguous_columns = self._get_unambiguous_columns(
-                self._get_all_source_columns()
-            )
-
-        table_name = self._unambiguous_columns.get(column_name)
-
-        if not table_name and self._infer_schema:
-            sources_without_schema = tuple(
-                source
-                for source, columns in self._get_all_source_columns().items()
-                if not columns or "*" in columns
-            )
-            if len(sources_without_schema) == 1:
-                table_name = sources_without_schema[0]
-
-        if table_name not in self.scope.selected_sources:
-            return exp.to_identifier(table_name)
-
-        node, _ = self.scope.selected_sources.get(table_name)
-
-        if isinstance(node, exp.Query):
-            while node and node.alias != table_name:
-                node = node.parent
-
-        node_alias = node.args.get("alias")
-        if node_alias:
-            return exp.to_identifier(node_alias.this)
-
-        return exp.to_identifier(table_name)
-
-    @property
-    def all_columns(self) -> t.Set[str]:
-        """All available columns of all sources in this scope"""
-        if self._all_columns is None:
-            self._all_columns = {
-                column for columns in self._get_all_source_columns().values() for column in columns
-            }
-        return self._all_columns
-
-    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
-        if isinstance(expression, exp.Select):
-            return expression.named_selects
-        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
-            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
-            return self.get_source_columns_from_set_op(expression.this)
-        if not isinstance(expression, exp.SetOperation):
-            raise OptimizeError(f"Unknown set operation: {expression}")
-
-        set_op = expression
-
-        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
-        on_column_list = set_op.args.get("on")
-
-        if on_column_list:
-            # The resulting columns are the columns in the ON clause:
-            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
-            columns = [col.name for col in on_column_list]
-        elif set_op.side or set_op.kind:
-            side = set_op.side
-            kind = set_op.kind
-
-            # Visit the children UNIONs (if any) in a post-order traversal
-            left = self.get_source_columns_from_set_op(set_op.left)
-            right = self.get_source_columns_from_set_op(set_op.right)
-
-            # We use dict.fromkeys to deduplicate keys and maintain insertion order
-            if side == "LEFT":
-                columns = left
-            elif side == "FULL":
-                columns = list(dict.fromkeys(left + right))
-            elif kind == "INNER":
-                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
-        else:
-            columns = set_op.named_selects
-
-        return columns
-
-    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
-        """Resolve the source columns for a given source `name`."""
-        cache_key = (name, only_visible)
-        if cache_key not in self._get_source_columns_cache:
-            if name not in self.scope.sources:
-                raise OptimizeError(f"Unknown table: {name}")
-
-            source = self.scope.sources[name]
-
-            if isinstance(source, exp.Table):
-                columns = self.schema.column_names(source, only_visible)
-            elif isinstance(source, Scope) and isinstance(
-                source.expression, (exp.Values, exp.Unnest)
-            ):
-                columns = source.expression.named_selects
-
-                # in bigquery, unnest structs are automatically scoped as tables, so you can
-                # directly select a struct field in a query.
-                # this handles the case where the unnest is statically defined.
-                if self.schema.dialect == "bigquery":
-                    if source.expression.is_type(exp.DataType.Type.STRUCT):
-                        for k in source.expression.type.expressions:  # type: ignore
-                            columns.append(k.name)
-            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
-                columns = self.get_source_columns_from_set_op(source.expression)
-
-            else:
-                select = seq_get(source.expression.selects, 0)
-
-                if isinstance(select, exp.QueryTransform):
-                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
-                    schema = select.args.get("schema")
-                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
-                else:
-                    columns = source.expression.named_selects
-
-            node, _ = self.scope.selected_sources.get(name) or (None, None)
-            if isinstance(node, Scope):
-                column_aliases = node.expression.alias_column_names
-            elif isinstance(node, exp.Expression):
-                column_aliases = node.alias_column_names
-            else:
-                column_aliases = []
-
-            if column_aliases:
-                # If the source's columns are aliased, their aliases shadow the corresponding column names.
-                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
-                columns = [
-                    alias or name
-                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
-                ]
-
-            self._get_source_columns_cache[cache_key] = columns
-
-        return self._get_source_columns_cache[cache_key]
-
-    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
-        if self._source_columns is None:
-            self._source_columns = {
-                source_name: self.get_source_columns(source_name)
-                for source_name, source in itertools.chain(
-                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
-                )
-            }
-        return self._source_columns
-
-    def _get_unambiguous_columns(
-        self, source_columns: t.Dict[str, t.Sequence[str]]
-    ) -> t.Mapping[str, str]:
-        """
-        Find all the unambiguous columns in sources.
-
-        Args:
-            source_columns: Mapping of names to source columns.
-
-        Returns:
-            Mapping of column name to source name.
-        """
-        if not source_columns:
-            return {}
-
-        source_columns_pairs = list(source_columns.items())
-
-        first_table, first_columns = source_columns_pairs[0]
-
-        if len(source_columns_pairs) == 1:
-            # Performance optimization - avoid copying first_columns if there is only one table.
-            return SingleValuedMapping(first_columns, first_table)
-
-        unambiguous_columns = {col: first_table for col in first_columns}
-        all_columns = set(unambiguous_columns)
-
-        for table, columns in source_columns_pairs[1:]:
-            unique = set(columns)
-            ambiguous = all_columns.intersection(unique)
-            all_columns.update(columns)
-
-            for column in ambiguous:
-                unambiguous_columns.pop(column, None)
-            for column in unique.difference(ambiguous):
-                unambiguous_columns[column] = table
-
-        return unambiguous_columns