sqlglot 27.29.0__py3-none-any.whl → 28.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlglot/__main__.py +6 -4
- sqlglot/_version.py +2 -2
- sqlglot/dialects/bigquery.py +116 -295
- sqlglot/dialects/clickhouse.py +67 -2
- sqlglot/dialects/databricks.py +38 -1
- sqlglot/dialects/dialect.py +327 -286
- sqlglot/dialects/dremio.py +4 -1
- sqlglot/dialects/duckdb.py +718 -22
- sqlglot/dialects/exasol.py +243 -10
- sqlglot/dialects/hive.py +8 -8
- sqlglot/dialects/mysql.py +11 -2
- sqlglot/dialects/oracle.py +29 -0
- sqlglot/dialects/postgres.py +46 -24
- sqlglot/dialects/presto.py +47 -16
- sqlglot/dialects/redshift.py +16 -0
- sqlglot/dialects/risingwave.py +3 -0
- sqlglot/dialects/singlestore.py +12 -3
- sqlglot/dialects/snowflake.py +199 -271
- sqlglot/dialects/spark.py +2 -2
- sqlglot/dialects/spark2.py +11 -48
- sqlglot/dialects/sqlite.py +9 -0
- sqlglot/dialects/teradata.py +5 -8
- sqlglot/dialects/trino.py +6 -0
- sqlglot/dialects/tsql.py +61 -25
- sqlglot/diff.py +4 -2
- sqlglot/errors.py +69 -0
- sqlglot/expressions.py +484 -84
- sqlglot/generator.py +143 -41
- sqlglot/helper.py +2 -2
- sqlglot/optimizer/annotate_types.py +247 -140
- sqlglot/optimizer/canonicalize.py +6 -1
- sqlglot/optimizer/eliminate_joins.py +1 -1
- sqlglot/optimizer/eliminate_subqueries.py +2 -2
- sqlglot/optimizer/merge_subqueries.py +5 -5
- sqlglot/optimizer/normalize.py +20 -13
- sqlglot/optimizer/normalize_identifiers.py +17 -3
- sqlglot/optimizer/optimizer.py +4 -0
- sqlglot/optimizer/pushdown_predicates.py +1 -1
- sqlglot/optimizer/qualify.py +14 -6
- sqlglot/optimizer/qualify_columns.py +113 -352
- sqlglot/optimizer/qualify_tables.py +112 -70
- sqlglot/optimizer/resolver.py +374 -0
- sqlglot/optimizer/scope.py +27 -16
- sqlglot/optimizer/simplify.py +1074 -964
- sqlglot/optimizer/unnest_subqueries.py +12 -2
- sqlglot/parser.py +276 -160
- sqlglot/planner.py +2 -2
- sqlglot/schema.py +15 -4
- sqlglot/tokens.py +42 -7
- sqlglot/transforms.py +77 -22
- sqlglot/typing/__init__.py +316 -0
- sqlglot/typing/bigquery.py +376 -0
- sqlglot/typing/hive.py +12 -0
- sqlglot/typing/presto.py +24 -0
- sqlglot/typing/snowflake.py +505 -0
- sqlglot/typing/spark2.py +58 -0
- sqlglot/typing/tsql.py +9 -0
- {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
- sqlglot-28.4.0.dist-info/RECORD +92 -0
- sqlglot-27.29.0.dist-info/RECORD +0 -84
- {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
- {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
- {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -5,9 +5,10 @@ import typing as t
 
 from sqlglot import alias, exp
 from sqlglot.dialects.dialect import Dialect, DialectType
-from sqlglot.errors import OptimizeError
-from sqlglot.helper import seq_get
+from sqlglot.errors import OptimizeError, highlight_sql
+from sqlglot.helper import seq_get
 from sqlglot.optimizer.annotate_types import TypeAnnotator
+from sqlglot.optimizer.resolver import Resolver
 from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 from sqlglot.optimizer.simplify import simplify_parens
 from sqlglot.schema import Schema, ensure_schema
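The `Resolver` helper is now imported from the new `sqlglot/optimizer/resolver.py` module (added in this release, per the file list above) instead of being defined at the bottom of this file; the final hunk below removes the old in-module definition. A minimal compatibility shim for downstream code that imported it from the old location, assuming nothing beyond the move shown in this diff:

```python
try:
    # sqlglot >= 28.x: Resolver lives in its own module
    from sqlglot.optimizer.resolver import Resolver
except ImportError:
    # sqlglot < 28.x: Resolver was defined inside qualify_columns
    from sqlglot.optimizer.qualify_columns import Resolver  # type: ignore
```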
@@ -54,22 +55,17 @@ def qualify_columns(
     schema = ensure_schema(schema, dialect=dialect)
     annotator = TypeAnnotator(schema)
     infer_schema = schema.empty if infer_schema is None else infer_schema
-    dialect = Dialect.get_or_raise(schema.dialect)
+    dialect = schema.dialect or Dialect()
     pseudocolumns = dialect.PSEUDOCOLUMNS
-    bigquery = dialect == "bigquery"
 
     for scope in traverse_scope(expression):
+        if dialect.PREFER_CTE_ALIAS_COLUMN:
+            pushdown_cte_alias_columns(scope)
+
         scope_expression = scope.expression
         is_select = isinstance(scope_expression, exp.Select)
 
-
-        # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
-        # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
-        scope_expression.transform(
-            lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
-            copy=False,
-        )
-        scope.clear_cache()
+        _separate_pseudocolumns(scope, pseudocolumns)
 
         resolver = Resolver(scope, schema, infer_schema=infer_schema)
         _pop_table_column_aliases(scope.ctes)
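This hunk swaps the hard-coded `bigquery = dialect == "bigquery"` check for class-level dialect feature flags (`PREFER_CTE_ALIAS_COLUMN` here, with more appearing below). A sketch of the opt-in pattern for a custom dialect, assuming these flags behave like other `Dialect` class attributes:

```python
from sqlglot.dialects.dialect import Dialect


class MyDialect(Dialect):
    # Feature flags consumed by qualify_columns in 28.x, per this diff
    PREFER_CTE_ALIAS_COLUMN = True  # push CTE alias columns into the CTE's projections
    ANNOTATE_ALL_SCOPES = True      # run the TypeAnnotator over every traversed scope
```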
@@ -81,11 +77,15 @@ def qualify_columns(
                 scope,
                 resolver,
                 dialect,
-                expand_only_groupby=bigquery,
+                expand_only_groupby=dialect.EXPAND_ONLY_GROUP_ALIAS_REF,
             )
 
         _convert_columns_to_dots(scope, resolver)
-        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
+        _qualify_columns(
+            scope,
+            resolver,
+            allow_partial_qualification=allow_partial_qualification,
+        )
 
         if not schema.empty and expand_alias_refs:
             _expand_alias_refs(scope, resolver, dialect)
@@ -107,13 +107,13 @@ def qualify_columns(
         # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
         _expand_order_by_and_distinct_on(scope, resolver)
 
-        if bigquery:
+        if dialect.ANNOTATE_ALL_SCOPES:
             annotator.annotate_scope(scope)
 
     return expression
 
 
-def validate_qualify_columns(expression: E) -> E:
+def validate_qualify_columns(expression: E, sql: t.Optional[str] = None) -> E:
     """Raise an `OptimizeError` if any columns aren't qualified"""
     all_unqualified_columns = []
     for scope in traverse_scope(expression):
@@ -123,7 +123,19 @@ def validate_qualify_columns(expression: E) -> E:
         if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
             column = scope.external_columns[0]
             for_table = f" for table: '{column.table}'" if column.table else ""
-            raise OptimizeError(f"Column '{column.name}' could not be resolved{for_table}")
+            line = column.this.meta.get("line")
+            col = column.this.meta.get("col")
+            start = column.this.meta.get("start")
+            end = column.this.meta.get("end")
+
+            error_msg = f"Column '{column.name}' could not be resolved{for_table}."
+            if line and col:
+                error_msg += f" Line: {line}, Col: {col}"
+            if sql and start is not None and end is not None:
+                formatted_sql = highlight_sql(sql, [(start, end)])[0]
+                error_msg += f"\n {formatted_sql}"
+
+            raise OptimizeError(error_msg)
 
         if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
             # New columns produced by the UNPIVOT can't be qualified, but there may be columns
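`validate_qualify_columns` now accepts the original SQL string, letting it attach line/column info and a `highlight_sql` snippet to the error. A minimal usage sketch (the schema and query are illustrative, and the exact message format follows this hunk):

```python
import sqlglot
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

sql = "SELECT a FROM x JOIN y ON x.id = y.id"
schema = {"x": {"id": "int"}, "y": {"id": "int"}}

expression = qualify_columns(sqlglot.parse_one(sql), schema=schema)

try:
    # Passing the raw SQL enables the highlighted snippet in the message
    validate_qualify_columns(expression, sql=sql)
except OptimizeError as error:
    print(error)  # e.g. "Ambiguous column 'a' (Line: 1, Col: 8)" plus the highlighted SQL
```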
@@ -135,11 +147,46 @@ def validate_qualify_columns(expression: E) -> E:
         all_unqualified_columns.extend(unqualified_columns)
 
     if all_unqualified_columns:
-        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
+        first_column = all_unqualified_columns[0]
+        line = first_column.this.meta.get("line")
+        col = first_column.this.meta.get("col")
+        start = first_column.this.meta.get("start")
+        end = first_column.this.meta.get("end")
+
+        error_msg = f"Ambiguous column '{first_column.name}'"
+        if line and col:
+            error_msg += f" (Line: {line}, Col: {col})"
+        if sql and start is not None and end is not None:
+            formatted_sql = highlight_sql(sql, [(start, end)])[0]
+            error_msg += f"\n {formatted_sql}"
+
+        raise OptimizeError(error_msg)
 
     return expression
 
 
+def _separate_pseudocolumns(scope: Scope, pseudocolumns: t.Set[str]) -> None:
+    if not pseudocolumns:
+        return
+
+    has_pseudocolumns = False
+    scope_expression = scope.expression
+
+    for column in scope.columns:
+        name = column.name.upper()
+        if name not in pseudocolumns:
+            continue
+
+        if name != "LEVEL" or (
+            isinstance(scope_expression, exp.Select) and scope_expression.args.get("connect")
+        ):
+            column.replace(exp.Pseudocolumn(**column.args))
+            has_pseudocolumns = True
+
+    if has_pseudocolumns:
+        scope.clear_cache()
+
+
 def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
     name_columns = [
         field.this
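Pseudocolumn handling is also reworked: instead of rewriting `LEVEL` into a bare identifier, qualification now wraps recognized pseudocolumns in `exp.Pseudocolumn` nodes, and `LEVEL` only counts as one when the SELECT actually has a CONNECT BY clause. A sketch of the effect under Oracle (whose `PSEUDOCOLUMNS` include `LEVEL`); the assertion reflects the behavior added in this hunk:

```python
import sqlglot
from sqlglot import exp
from sqlglot.optimizer.qualify import qualify

sql = "SELECT LEVEL, id FROM t CONNECT BY PRIOR id = parent_id"
expression = sqlglot.parse_one(sql, read="oracle")

qualified = qualify(
    expression,
    schema={"t": {"id": "int", "parent_id": "int"}},
    dialect="oracle",
)

# LEVEL is no longer treated as a regular column of `t`
assert qualified.find(exp.Pseudocolumn) is not None
```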
@@ -274,12 +321,11 @@ def _expand_alias_refs(
     """
     expression = scope.expression
 
-    if not isinstance(expression, exp.Select) or dialect == "oracle":
+    if not isinstance(expression, exp.Select) or dialect.DISABLES_ALIAS_REF_EXPANSION:
        return
 
     alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
     projections = {s.alias_or_name for s in expression.selects}
-    is_bigquery = dialect == "bigquery"
     replaced = False
 
     def replace_columns(
@@ -317,12 +363,12 @@ def _expand_alias_refs(
             # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
             # If "HAVING x" is expanded to "HAVING max(x.b)", BQ would blindly replace the "x" reference with the projection MAX(x.b)
             # i.e HAVING MAX(MAX(x.b).b), resulting in the error: "Aggregations of aggregations are not allowed"
-            if is_having and is_bigquery:
+            if is_having and dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES:
                 skip_replace = skip_replace or any(
                     node.parts[0].name in projections
                     for node in alias_expr.find_all(exp.Column)
                 )
-            elif is_bigquery and (is_group_by or is_having):
+            elif dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES and (is_group_by or is_having):
                 column_table = table.name if table else column.table
                 if column_table in projections:
                     # BigQuery's GROUP BY and HAVING clauses get confused if the column name
@@ -375,9 +421,7 @@ def _expand_alias_refs(
     replace_columns(expression.args.get("having"), resolve_table=True)
     replace_columns(expression.args.get("qualify"), resolve_table=True)
 
-
-    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
-    if dialect == "snowflake":
+    if dialect.SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS:
         for join in expression.args.get("joins") or []:
             replace_columns(join)
 
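Snowflake permits projection aliases in later clauses, including JOIN conditions (https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes, the link dropped from the comment above); the dialect-name check becomes the `SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS` flag. Roughly, for such dialects the alias used in a join condition expands back to its underlying expression. A sketch with an illustrative schema; the printed output is approximate:

```python
import sqlglot
from sqlglot.optimizer.qualify import qualify

sql = "SELECT x.a AS b FROM x JOIN y ON b = y.b"
expression = sqlglot.parse_one(sql, read="snowflake")

qualified = qualify(
    expression,
    schema={"x": {"a": "int"}, "y": {"b": "int"}},
    dialect="snowflake",
)
print(qualified.sql(dialect="snowflake"))
# e.g. SELECT X.A AS B FROM X AS X JOIN Y AS Y ON X.A = Y.B
```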
@@ -385,7 +429,7 @@ def _expand_alias_refs(
     scope.clear_cache()
 
 
-def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
+def _expand_group_by(scope: Scope, dialect: Dialect) -> None:
     expression = scope.expression
     group = expression.args.get("group")
     if not group:
@@ -411,7 +455,7 @@ def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
     for original, expanded in zip(
         modifier_expressions,
         _expand_positional_references(
-            scope, modifier_expressions, resolver.schema.dialect, alias=True
+            scope, modifier_expressions, resolver.dialect, alias=True
         ),
     ):
         for agg in original.find_all(exp.AggFunc):
@@ -433,7 +477,7 @@ def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
 
 
 def _expand_positional_references(
-    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
+    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: Dialect, alias: bool = False
 ) -> t.List[exp.Expression]:
     new_nodes: t.List[exp.Expression] = []
     ambiguous_projections = None
@@ -447,7 +491,7 @@ def _expand_positional_references(
         else:
             select = select.this
 
-        if dialect == "bigquery":
+        if dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES:
             if ambiguous_projections is None:
                 # When a projection name is also a source name and it is referenced in the
                 # GROUP BY clause, BQ can't understand what the identifier corresponds to
@@ -488,10 +532,10 @@ def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
 
 def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
     """
-    Converts `Column` instances that represent struct field lookup into chained `Dots`.
+    Converts `Column` instances that represent STRUCT or JSON field lookup into chained `Dots`.
 
-    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
-    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
+    These lookups may be parsed as columns (e.g. "col"."field"."field2"), but they need to be
+    normalized to `Dot(Dot(...(<table>.<column>, field1), field2, ...))` to be qualified properly.
     """
     converted = False
     for column in itertools.chain(scope.columns, scope.stars):
@@ -499,6 +543,7 @@ def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
             continue
 
         column_table: t.Optional[str | exp.Identifier] = column.table
+        dot_parts = column.meta.pop("dot_parts", [])
         if (
             column_table
             and column_table not in scope.sources
@@ -514,12 +559,20 @@ def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
                 # The struct is already qualified, but we still need to change the AST
                 column_table = root
                 root, *parts = parts
+                was_qualified = True
             else:
                 column_table = resolver.get_table(root.name)
+                was_qualified = False
 
             if column_table:
                 converted = True
-                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
+                new_column = exp.column(root, table=column_table)
+
+                if dot_parts:
+                    # Remove the actual column parts from the rest of dot parts
+                    new_column.meta["dot_parts"] = dot_parts[2 if was_qualified else 1 :]
+
+                column.replace(exp.Dot.build([new_column, *parts]))
 
     if converted:
         # We want to re-aggregate the converted columns, otherwise they'd be skipped in
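For context, this is the transformation `_convert_columns_to_dots` performs; the new `dot_parts` metadata tracks how many leading parts were consumed so later stages can still tell column parts from struct fields. A sketch of the end-to-end effect, assuming a BigQuery STRUCT schema:

```python
import sqlglot
from sqlglot.optimizer.qualify import qualify

# "s.f" parses as column "f" in table "s"; with a schema typing `s` as a
# struct column of `t`, qualification rewrites it into a Dot chain instead.
sql = "SELECT s.f FROM t"
schema = {"t": {"s": "STRUCT<f INT64>"}}

qualified = qualify(sqlglot.parse_one(sql, read="bigquery"), schema=schema, dialect="bigquery")
print(qualified.sql(dialect="bigquery"))
# e.g. SELECT t.s.f AS f FROM t AS t
```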
@@ -527,7 +580,11 @@ def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
         scope.clear_cache()
 
 
-def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
+def _qualify_columns(
+    scope: Scope,
+    resolver: Resolver,
+    allow_partial_qualification: bool,
+) -> None:
     """Disambiguate columns, ensuring each column specifies a source"""
     for column in scope.columns:
         column_table = column.table
@@ -556,11 +613,11 @@ def _qualify_columns(
         if column_table:
             column.set("table", column_table)
         elif (
-            resolver.schema.dialect == "bigquery"
+            resolver.dialect.TABLES_REFERENCEABLE_AS_COLUMNS
             and len(column.parts) == 1
             and column_name in scope.selected_sources
         ):
-            # BigQuery allows tables to be referenced as columns, treating them as structs
+            # BigQuery and Postgres allow tables to be referenced as columns, treating them as structs/records
             scope.replace(column, exp.TableColumn(this=column.this))
 
     for pivot in scope.pivots:
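`TABLES_REFERENCEABLE_AS_COLUMNS` generalizes the old BigQuery-only branch to any dialect where a bare reference to a selected source denotes the whole row (BigQuery structs, Postgres records). A sketch of the resulting AST shape; the assertion assumes the `exp.TableColumn` rewrite shown in this hunk:

```python
import sqlglot
from sqlglot import exp
from sqlglot.optimizer.qualify import qualify

# In BigQuery, selecting a table by name yields the row as a struct
expression = sqlglot.parse_one("SELECT t FROM t", read="bigquery")
qualified = qualify(expression, schema={"t": {"a": "int"}}, dialect="bigquery")

assert qualified.find(exp.TableColumn) is not None
```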
@@ -571,7 +628,7 @@ def _qualify_columns(
             column.set("table", column_table)
 
 
-def _expand_struct_stars_bigquery(
+def _expand_struct_stars_no_parens(
     expression: exp.Dot,
 ) -> t.List[exp.Alias]:
     """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""
@@ -625,7 +682,7 @@ def _expand_struct_stars_bigquery(
     return new_selections
 
 
-def _expand_struct_stars_risingwave(expression: exp.Dot) -> t.List[exp.Alias]:
+def _expand_struct_stars_with_parens(expression: exp.Dot) -> t.List[exp.Alias]:
     """[RisingWave] Expand/Flatten (<exp>.bar).*, where bar is a struct column"""
 
     # it is not (<sub_exp>).* pattern, which means we can't expand
@@ -702,7 +759,7 @@ def _expand_stars(
     rename_columns: t.Dict[int, t.Dict[str, str]] = {}
 
     coalesced_columns = set()
-    dialect = resolver.schema.dialect
+    dialect = resolver.dialect
 
     pivot_output_columns = None
     pivot_exclude_columns: t.Set[str] = set()
@@ -725,10 +782,9 @@ def _expand_stars(
         if not pivot_output_columns:
             pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
 
-    is_bigquery = dialect == "bigquery"
-    is_risingwave = dialect == "risingwave"
-
-    if (is_bigquery or is_risingwave) and any(isinstance(col, exp.Dot) for col in scope.stars):
+    if dialect.SUPPORTS_STRUCT_STAR_EXPANSION and any(
+        isinstance(col, exp.Dot) for col in scope.stars
+    ):
         # Found struct expansion, annotate scope ahead of time
         annotator.annotate_scope(scope)
 
@@ -745,13 +801,16 @@ def _expand_stars(
                 _add_except_columns(expression.this, tables, except_columns)
                 _add_replace_columns(expression.this, tables, replace_columns)
                 _add_rename_columns(expression.this, tables, rename_columns)
-            elif is_bigquery:
-                struct_fields = _expand_struct_stars_bigquery(expression)
+            elif (
+                dialect.SUPPORTS_STRUCT_STAR_EXPANSION
+                and not dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS
+            ):
+                struct_fields = _expand_struct_stars_no_parens(expression)
                 if struct_fields:
                     new_selections.extend(struct_fields)
                     continue
-            elif is_risingwave:
-                struct_fields = _expand_struct_stars_risingwave(expression)
+            elif dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS:
+                struct_fields = _expand_struct_stars_with_parens(expression)
                 if struct_fields:
                     new_selections.extend(struct_fields)
                     continue
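Struct star expansion is likewise flag-driven now: `SUPPORTS_STRUCT_STAR_EXPANSION` gates the feature, and `REQUIRES_PARENTHESIZED_STRUCT_ACCESS` selects between the `foo.bar.*` form (BigQuery) and the `(foo.bar).*` form (RisingWave). A BigQuery-flavored sketch with an illustrative schema:

```python
import sqlglot
from sqlglot.optimizer.qualify import qualify

schema = {"t": {"s": "STRUCT<a INT64, b INT64>"}}

expanded = qualify(
    sqlglot.parse_one("SELECT t.s.* FROM t", read="bigquery"),
    schema=schema,
    dialect="bigquery",
)
print(expanded.sql(dialect="bigquery"))
# e.g. SELECT t.s.a AS a, t.s.b AS b FROM t AS t
```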
@@ -767,7 +826,7 @@ def _expand_stars(
         columns = resolver.get_source_columns(table, only_visible=True)
         columns = columns or scope.outer_columns
 
-        if pseudocolumns:
+        if pseudocolumns and dialect.EXCLUDES_PSEUDOCOLUMNS_FROM_STAR:
             columns = [name for name in columns if name.upper() not in pseudocolumns]
 
         if not columns or "*" in columns:
@@ -821,7 +880,7 @@ def _expand_stars(
 def _add_except_columns(
     expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
 ) -> None:
-    except_ = expression.args.get("except")
+    except_ = expression.args.get("except_")
 
     if not except_:
         return
@@ -901,320 +960,22 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool
     )  # type: ignore
 
 
-def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
+def pushdown_cte_alias_columns(scope: Scope) -> None:
     """
     Pushes down the CTE alias columns into the projection,
 
     This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
 
-    Example:
-        >>> import sqlglot
-        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
-        >>> pushdown_cte_alias_columns(expression).sql()
-        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
-
     Args:
-        expression: Expression to pushdown.
-
-    Returns:
-        The expression with the CTE aliases pushed down into the projection.
+        scope: Scope to find ctes to pushdown aliases.
     """
-    for cte in expression.find_all(exp.CTE):
-        if cte.alias_column_names:
+    for cte in scope.ctes:
+        if cte.alias_column_names and isinstance(cte.this, exp.Select):
             new_expressions = []
             for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                 if isinstance(projection, exp.Alias):
-                    projection.set("alias", _alias)
+                    projection.set("alias", exp.to_identifier(_alias))
                 else:
                     projection = alias(projection, alias=_alias)
                 new_expressions.append(projection)
             cte.this.set("expressions", new_expressions)
-
-    return expression
-
-
-class Resolver:
-    """
-    Helper for resolving columns.
-
-    This is a class so we can lazily load some things and easily share them across functions.
-    """
-
-    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
-        self.scope = scope
-        self.schema = schema
-        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
-        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
-        self._all_columns: t.Optional[t.Set[str]] = None
-        self._infer_schema = infer_schema
-        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
-
-    def get_table(self, column: str | exp.Column) -> t.Optional[exp.Identifier]:
-        """
-        Get the table for a column name.
-
-        Args:
-            column: The column expression (or column name) to find the table for.
-        Returns:
-            The table name if it can be found/inferred.
-        """
-        column_name = column if isinstance(column, str) else column.name
-
-        table_name = self._get_table_name_from_sources(column_name)
-
-        if not table_name and isinstance(column, exp.Column):
-            # Fall-back case: If we couldn't find the `table_name` from ALL of the sources,
-            # attempt to disambiguate the column based on other characteristics e.g if this column is in a join condition,
-            # we may be able to disambiguate based on the source order.
-            if join_context := self._get_column_join_context(column):
-                # In this case, the return value will be the join that _may_ be able to disambiguate the column
-                # and we can use the source columns available at that join to get the table name
-                table_name = self._get_table_name_from_sources(
-                    column_name, self._get_available_source_columns(join_context)
-                )
-
-        if not table_name and self._infer_schema:
-            sources_without_schema = tuple(
-                source
-                for source, columns in self._get_all_source_columns().items()
-                if not columns or "*" in columns
-            )
-            if len(sources_without_schema) == 1:
-                table_name = sources_without_schema[0]
-
-        if table_name not in self.scope.selected_sources:
-            return exp.to_identifier(table_name)
-
-        node, _ = self.scope.selected_sources.get(table_name)
-
-        if isinstance(node, exp.Query):
-            while node and node.alias != table_name:
-                node = node.parent
-
-        node_alias = node.args.get("alias")
-        if node_alias:
-            return exp.to_identifier(node_alias.this)
-
-        return exp.to_identifier(table_name)
-
-    @property
-    def all_columns(self) -> t.Set[str]:
-        """All available columns of all sources in this scope"""
-        if self._all_columns is None:
-            self._all_columns = {
-                column for columns in self._get_all_source_columns().values() for column in columns
-            }
-        return self._all_columns
-
-    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
-        if isinstance(expression, exp.Select):
-            return expression.named_selects
-        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
-            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
-            return self.get_source_columns_from_set_op(expression.this)
-        if not isinstance(expression, exp.SetOperation):
-            raise OptimizeError(f"Unknown set operation: {expression}")
-
-        set_op = expression
-
-        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
-        on_column_list = set_op.args.get("on")
-
-        if on_column_list:
-            # The resulting columns are the columns in the ON clause:
-            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
-            columns = [col.name for col in on_column_list]
-        elif set_op.side or set_op.kind:
-            side = set_op.side
-            kind = set_op.kind
-
-            # Visit the children UNIONs (if any) in a post-order traversal
-            left = self.get_source_columns_from_set_op(set_op.left)
-            right = self.get_source_columns_from_set_op(set_op.right)
-
-            # We use dict.fromkeys to deduplicate keys and maintain insertion order
-            if side == "LEFT":
-                columns = left
-            elif side == "FULL":
-                columns = list(dict.fromkeys(left + right))
-            elif kind == "INNER":
-                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
-        else:
-            columns = set_op.named_selects
-
-        return columns
-
-    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
-        """Resolve the source columns for a given source `name`."""
-        cache_key = (name, only_visible)
-        if cache_key not in self._get_source_columns_cache:
-            if name not in self.scope.sources:
-                raise OptimizeError(f"Unknown table: {name}")
-
-            source = self.scope.sources[name]
-
-            if isinstance(source, exp.Table):
-                columns = self.schema.column_names(source, only_visible)
-            elif isinstance(source, Scope) and isinstance(
-                source.expression, (exp.Values, exp.Unnest)
-            ):
-                columns = source.expression.named_selects
-
-                # in bigquery, unnest structs are automatically scoped as tables, so you can
-                # directly select a struct field in a query.
-                # this handles the case where the unnest is statically defined.
-                if self.schema.dialect == "bigquery":
-                    if source.expression.is_type(exp.DataType.Type.STRUCT):
-                        for k in source.expression.type.expressions:  # type: ignore
-                            columns.append(k.name)
-            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
-                columns = self.get_source_columns_from_set_op(source.expression)
-
-            else:
-                select = seq_get(source.expression.selects, 0)
-
-                if isinstance(select, exp.QueryTransform):
-                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
-                    schema = select.args.get("schema")
-                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
-                else:
-                    columns = source.expression.named_selects
-
-            node, _ = self.scope.selected_sources.get(name) or (None, None)
-            if isinstance(node, Scope):
-                column_aliases = node.expression.alias_column_names
-            elif isinstance(node, exp.Expression):
-                column_aliases = node.alias_column_names
-            else:
-                column_aliases = []
-
-            if column_aliases:
-                # If the source's columns are aliased, their aliases shadow the corresponding column names.
-                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
-                columns = [
-                    alias or name
-                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
-                ]
-
-            self._get_source_columns_cache[cache_key] = columns
-
-        return self._get_source_columns_cache[cache_key]
-
-    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
-        if self._source_columns is None:
-            self._source_columns = {
-                source_name: self.get_source_columns(source_name)
-                for source_name, source in itertools.chain(
-                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
-                )
-            }
-        return self._source_columns
-
-    def _get_table_name_from_sources(
-        self, column_name: str, source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
-    ) -> t.Optional[str]:
-        if not source_columns:
-            # If not supplied, get all sources to calculate unambiguous columns
-            if self._unambiguous_columns is None:
-                self._unambiguous_columns = self._get_unambiguous_columns(
-                    self._get_all_source_columns()
-                )
-
-            unambiguous_columns = self._unambiguous_columns
-        else:
-            unambiguous_columns = self._get_unambiguous_columns(source_columns)
-
-        return unambiguous_columns.get(column_name)
-
-    def _get_column_join_context(self, column: exp.Column) -> t.Optional[exp.Join]:
-        """
-        Check if a column participating in a join can be qualified based on the source order.
-        """
-        args = self.scope.expression.args
-        joins = args.get("joins")
-
-        if not joins or args.get("laterals") or args.get("pivots"):
-            # Feature gap: We currently don't try to disambiguate columns if other sources
-            # (e.g laterals, pivots) exist alongside joins
-            return None
-
-        join_ancestor = column.find_ancestor(exp.Join, exp.Select)
-
-        if (
-            isinstance(join_ancestor, exp.Join)
-            and join_ancestor.alias_or_name in self.scope.selected_sources
-        ):
-            # Ensure that the found ancestor is a join that contains an actual source,
-            # e.g in Clickhouse `b` is an array expression in `a ARRAY JOIN b`
-            return join_ancestor
-
-        return None
-
-    def _get_available_source_columns(
-        self, join_ancestor: exp.Join
-    ) -> t.Dict[str, t.Sequence[str]]:
-        """
-        Get the source columns that are available at the point where a column is referenced.
-
-        For columns in JOIN conditions, this only includes tables that have been joined
-        up to that point. Example:
-
-        ```
-        SELECT * FROM t_1 INNER JOIN ... INNER JOIN t_n ON t_1.a = c INNER JOIN t_n+1 ON ...
-        ```                                                        ^
-                                                                   |
-                          +----------------------------------------+
-                          |
-                          ⌄
-        The unqualified column `c` is not ambiguous if no other sources up until that
-        join i.e t_1, ..., t_n, contain a column named `c`.
-        """
-        args = self.scope.expression.args
-
-        # Collect tables in order: FROM clause tables + joined tables up to current join
-        from_name = args["from"].alias_or_name
-        available_sources = {from_name: self.get_source_columns(from_name)}
-
-        for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]:
-            available_sources[join.alias_or_name] = self.get_source_columns(join.alias_or_name)
-
-        return available_sources
-
-    def _get_unambiguous_columns(
-        self, source_columns: t.Dict[str, t.Sequence[str]]
-    ) -> t.Mapping[str, str]:
-        """
-        Find all the unambiguous columns in sources.
-
-        Args:
-            source_columns: Mapping of names to source columns.
-
-        Returns:
-            Mapping of column name to source name.
-        """
-        if not source_columns:
-            return {}
-
-        source_columns_pairs = list(source_columns.items())
-
-        first_table, first_columns = source_columns_pairs[0]
-
-        if len(source_columns_pairs) == 1:
-            # Performance optimization - avoid copying first_columns if there is only one table.
-            return SingleValuedMapping(first_columns, first_table)
-
-        unambiguous_columns = {col: first_table for col in first_columns}
-        all_columns = set(unambiguous_columns)
-
-        for table, columns in source_columns_pairs[1:]:
-            unique = set(columns)
-            ambiguous = all_columns.intersection(unique)
-            all_columns.update(columns)
-
-            for column in ambiguous:
-                unambiguous_columns.pop(column, None)
-            for column in unique.difference(ambiguous):
-                unambiguous_columns[column] = table
-
-        return unambiguous_columns
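`pushdown_cte_alias_columns` changes from an expression-to-expression transform into a per-scope, in-place step (invoked from `qualify_columns` when `dialect.PREFER_CTE_ALIAS_COLUMN` is set), and the `Resolver` class moves wholesale to `sqlglot/optimizer/resolver.py`. The removed doctest, translated to the new Scope-based API as a sketch:

```python
from sqlglot import parse_one
from sqlglot.optimizer.qualify_columns import pushdown_cte_alias_columns
from sqlglot.optimizer.scope import build_scope

expression = parse_one(
    "WITH y (c) AS (SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0) SELECT c FROM y"
)

# The function now takes a Scope and mutates it in place, returning None
pushdown_cte_alias_columns(build_scope(expression))

print(expression.sql())
# 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
```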