sqlglot 27.27.0__py3-none-any.whl → 28.4.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (68)
  1. sqlglot/__init__.py +1 -0
  2. sqlglot/__main__.py +6 -4
  3. sqlglot/_version.py +2 -2
  4. sqlglot/dialects/bigquery.py +118 -279
  5. sqlglot/dialects/clickhouse.py +73 -5
  6. sqlglot/dialects/databricks.py +38 -1
  7. sqlglot/dialects/dialect.py +354 -275
  8. sqlglot/dialects/dremio.py +4 -1
  9. sqlglot/dialects/duckdb.py +754 -25
  10. sqlglot/dialects/exasol.py +243 -10
  11. sqlglot/dialects/hive.py +8 -8
  12. sqlglot/dialects/mysql.py +14 -4
  13. sqlglot/dialects/oracle.py +29 -0
  14. sqlglot/dialects/postgres.py +60 -26
  15. sqlglot/dialects/presto.py +47 -16
  16. sqlglot/dialects/redshift.py +16 -0
  17. sqlglot/dialects/risingwave.py +3 -0
  18. sqlglot/dialects/singlestore.py +12 -3
  19. sqlglot/dialects/snowflake.py +239 -218
  20. sqlglot/dialects/spark.py +15 -4
  21. sqlglot/dialects/spark2.py +11 -48
  22. sqlglot/dialects/sqlite.py +10 -0
  23. sqlglot/dialects/starrocks.py +3 -0
  24. sqlglot/dialects/teradata.py +5 -8
  25. sqlglot/dialects/trino.py +6 -0
  26. sqlglot/dialects/tsql.py +61 -22
  27. sqlglot/diff.py +4 -2
  28. sqlglot/errors.py +69 -0
  29. sqlglot/executor/__init__.py +5 -10
  30. sqlglot/executor/python.py +1 -29
  31. sqlglot/expressions.py +637 -100
  32. sqlglot/generator.py +160 -43
  33. sqlglot/helper.py +2 -44
  34. sqlglot/lineage.py +10 -4
  35. sqlglot/optimizer/annotate_types.py +247 -140
  36. sqlglot/optimizer/canonicalize.py +6 -1
  37. sqlglot/optimizer/eliminate_joins.py +1 -1
  38. sqlglot/optimizer/eliminate_subqueries.py +2 -2
  39. sqlglot/optimizer/merge_subqueries.py +5 -5
  40. sqlglot/optimizer/normalize.py +20 -13
  41. sqlglot/optimizer/normalize_identifiers.py +17 -3
  42. sqlglot/optimizer/optimizer.py +4 -0
  43. sqlglot/optimizer/pushdown_predicates.py +1 -1
  44. sqlglot/optimizer/qualify.py +18 -10
  45. sqlglot/optimizer/qualify_columns.py +122 -275
  46. sqlglot/optimizer/qualify_tables.py +128 -76
  47. sqlglot/optimizer/resolver.py +374 -0
  48. sqlglot/optimizer/scope.py +27 -16
  49. sqlglot/optimizer/simplify.py +1075 -959
  50. sqlglot/optimizer/unnest_subqueries.py +12 -2
  51. sqlglot/parser.py +296 -170
  52. sqlglot/planner.py +2 -2
  53. sqlglot/schema.py +15 -4
  54. sqlglot/tokens.py +42 -7
  55. sqlglot/transforms.py +77 -22
  56. sqlglot/typing/__init__.py +316 -0
  57. sqlglot/typing/bigquery.py +376 -0
  58. sqlglot/typing/hive.py +12 -0
  59. sqlglot/typing/presto.py +24 -0
  60. sqlglot/typing/snowflake.py +505 -0
  61. sqlglot/typing/spark2.py +58 -0
  62. sqlglot/typing/tsql.py +9 -0
  63. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
  64. sqlglot-28.4.0.dist-info/RECORD +92 -0
  65. sqlglot-27.27.0.dist-info/RECORD +0 -84
  66. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
  67. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
  68. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
@@ -5,9 +5,10 @@ import typing as t
 
 from sqlglot import alias, exp
 from sqlglot.dialects.dialect import Dialect, DialectType
-from sqlglot.errors import OptimizeError
-from sqlglot.helper import seq_get, SingleValuedMapping
+from sqlglot.errors import OptimizeError, highlight_sql
+from sqlglot.helper import seq_get
 from sqlglot.optimizer.annotate_types import TypeAnnotator
+from sqlglot.optimizer.resolver import Resolver
 from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 from sqlglot.optimizer.simplify import simplify_parens
 from sqlglot.schema import Schema, ensure_schema
@@ -54,22 +55,17 @@ def qualify_columns(
     schema = ensure_schema(schema, dialect=dialect)
     annotator = TypeAnnotator(schema)
     infer_schema = schema.empty if infer_schema is None else infer_schema
-    dialect = Dialect.get_or_raise(schema.dialect)
+    dialect = schema.dialect or Dialect()
     pseudocolumns = dialect.PSEUDOCOLUMNS
-    bigquery = dialect == "bigquery"
 
     for scope in traverse_scope(expression):
+        if dialect.PREFER_CTE_ALIAS_COLUMN:
+            pushdown_cte_alias_columns(scope)
+
         scope_expression = scope.expression
         is_select = isinstance(scope_expression, exp.Select)
 
-        if is_select and scope_expression.args.get("connect"):
-            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
-            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
-            scope_expression.transform(
-                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
-                copy=False,
-            )
-            scope.clear_cache()
+        _separate_pseudocolumns(scope, pseudocolumns)
 
         resolver = Resolver(scope, schema, infer_schema=infer_schema)
         _pop_table_column_aliases(scope.ctes)
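
Note: the pattern in this hunk recurs throughout the release. Hard-coded dialect-name checks such as `bigquery = dialect == "bigquery"` are replaced with feature flags declared on `Dialect` subclasses (`PREFER_CTE_ALIAS_COLUMN`, `EXPAND_ONLY_GROUP_ALIAS_REF`, `ANNOTATE_ALL_SCOPES`, and so on). A minimal sketch of how such a flag is consumed, assuming the flag names shown in this diff:

    from sqlglot.dialects.dialect import Dialect

    # Illustrative only: each dialect declares the behaviors it supports as
    # class-level booleans, and optimizer passes consult the flag instead of
    # comparing dialect names.
    dialect = Dialect.get_or_raise("snowflake")
    if dialect.PREFER_CTE_ALIAS_COLUMN:
        # e.g. qualify_columns calls pushdown_cte_alias_columns(scope) here
        pass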
@@ -81,11 +77,15 @@ def qualify_columns(
                 scope,
                 resolver,
                 dialect,
-                expand_only_groupby=bigquery,
+                expand_only_groupby=dialect.EXPAND_ONLY_GROUP_ALIAS_REF,
             )
 
         _convert_columns_to_dots(scope, resolver)
-        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
+        _qualify_columns(
+            scope,
+            resolver,
+            allow_partial_qualification=allow_partial_qualification,
+        )
 
         if not schema.empty and expand_alias_refs:
             _expand_alias_refs(scope, resolver, dialect)
@@ -107,13 +107,13 @@ def qualify_columns(
             # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
             _expand_order_by_and_distinct_on(scope, resolver)
 
-        if bigquery:
+        if dialect.ANNOTATE_ALL_SCOPES:
             annotator.annotate_scope(scope)
 
     return expression
 
 
-def validate_qualify_columns(expression: E) -> E:
+def validate_qualify_columns(expression: E, sql: t.Optional[str] = None) -> E:
     """Raise an `OptimizeError` if any columns aren't qualified"""
     all_unqualified_columns = []
     for scope in traverse_scope(expression):
@@ -123,7 +123,19 @@ def validate_qualify_columns(expression: E) -> E:
             if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                 column = scope.external_columns[0]
                 for_table = f" for table: '{column.table}'" if column.table else ""
-                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
+                line = column.this.meta.get("line")
+                col = column.this.meta.get("col")
+                start = column.this.meta.get("start")
+                end = column.this.meta.get("end")
+
+                error_msg = f"Column '{column.name}' could not be resolved{for_table}."
+                if line and col:
+                    error_msg += f" Line: {line}, Col: {col}"
+                if sql and start is not None and end is not None:
+                    formatted_sql = highlight_sql(sql, [(start, end)])[0]
+                    error_msg += f"\n {formatted_sql}"
+
+                raise OptimizeError(error_msg)
 
             if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                 # New columns produced by the UNPIVOT can't be qualified, but there may be columns
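
With the new optional `sql` argument, callers can get position info and a highlighted snippet in the error. A hedged usage sketch (the query and message shape are illustrative; the line/column details only appear when the parser recorded position metadata on the node):

    import sqlglot
    from sqlglot.errors import OptimizeError
    from sqlglot.optimizer.qualify_columns import validate_qualify_columns

    sql = "SELECT t2.x FROM t1"

    try:
        # Passing the original SQL text lets the error embed a snippet
        # rendered by highlight_sql from the column node's start/end offsets.
        validate_qualify_columns(sqlglot.parse_one(sql), sql=sql)
    except OptimizeError as e:
        print(e)  # e.g. Column 'x' could not be resolved for table: 't2'. ...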
@@ -135,11 +147,46 @@ def validate_qualify_columns(expression: E) -> E:
             all_unqualified_columns.extend(unqualified_columns)
 
     if all_unqualified_columns:
-        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
+        first_column = all_unqualified_columns[0]
+        line = first_column.this.meta.get("line")
+        col = first_column.this.meta.get("col")
+        start = first_column.this.meta.get("start")
+        end = first_column.this.meta.get("end")
+
+        error_msg = f"Ambiguous column '{first_column.name}'"
+        if line and col:
+            error_msg += f" (Line: {line}, Col: {col})"
+        if sql and start is not None and end is not None:
+            formatted_sql = highlight_sql(sql, [(start, end)])[0]
+            error_msg += f"\n {formatted_sql}"
+
+        raise OptimizeError(error_msg)
 
     return expression
 
 
+def _separate_pseudocolumns(scope: Scope, pseudocolumns: t.Set[str]) -> None:
+    if not pseudocolumns:
+        return
+
+    has_pseudocolumns = False
+    scope_expression = scope.expression
+
+    for column in scope.columns:
+        name = column.name.upper()
+        if name not in pseudocolumns:
+            continue
+
+        if name != "LEVEL" or (
+            isinstance(scope_expression, exp.Select) and scope_expression.args.get("connect")
+        ):
+            column.replace(exp.Pseudocolumn(**column.args))
+            has_pseudocolumns = True
+
+    if has_pseudocolumns:
+        scope.clear_cache()
+
+
 def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
     name_columns = [
         field.this
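
The new `_separate_pseudocolumns` generalizes the removed CONNECT BY special case: any column whose upper-cased name appears in the dialect's `PSEUDOCOLUMNS` set is rewritten to an `exp.Pseudocolumn` node, with `LEVEL` still gated on the query actually having a CONNECT BY clause. A hedged illustration (the Oracle query is illustrative):

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    # LEVEL is only a pseudocolumn in the presence of CONNECT BY, so it is
    # separated out instead of being qualified against the employees table.
    sql = "SELECT LEVEL, last_name FROM employees CONNECT BY PRIOR employee_id = manager_id"
    print(qualify(sqlglot.parse_one(sql, read="oracle"), dialect="oracle").sql("oracle"))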
@@ -274,16 +321,17 @@ def _expand_alias_refs(
     """
     expression = scope.expression
 
-    if not isinstance(expression, exp.Select) or dialect == "oracle":
+    if not isinstance(expression, exp.Select) or dialect.DISABLES_ALIAS_REF_EXPANSION:
         return
 
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
     projections = {s.alias_or_name for s in expression.selects}
-    is_bigquery = dialect == "bigquery"
+    replaced = False
 
     def replace_columns(
         node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
     ) -> None:
+        nonlocal replaced
         is_group_by = isinstance(node, exp.Group)
         is_having = isinstance(node, exp.Having)
         if not node or (expand_only_groupby and not is_group_by):
@@ -315,12 +363,12 @@ def _expand_alias_refs(
             # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
             # If "HAVING x" is expanded to "HAVING max(x.b)", BQ would blindly replace the "x" reference with the projection MAX(x.b)
             # i.e HAVING MAX(MAX(x.b).b), resulting in the error: "Aggregations of aggregations are not allowed"
-            if is_having and is_bigquery:
+            if is_having and dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES:
                 skip_replace = skip_replace or any(
                     node.parts[0].name in projections
                     for node in alias_expr.find_all(exp.Column)
                 )
-            elif is_bigquery and (is_group_by or is_having):
+            elif dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES and (is_group_by or is_having):
                 column_table = table.name if table else column.table
                 if column_table in projections:
                     # BigQuery's GROUP BY and HAVING clauses get confused if the column name
@@ -329,6 +377,7 @@ def _expand_alias_refs(
                     # We should not qualify "id" with "custom_fields" in either clause, since the aggregation shadows the actual table
                     # and we'd get the error: "Column custom_fields contains an aggregation function, which is not allowed in GROUP BY clause"
                     column.replace(exp.to_identifier(column.name))
+                    replaced = True
                     return
 
             if table and (not alias_expr or skip_replace):
@@ -339,7 +388,9 @@ def _expand_alias_refs(
             ):
                 if literal_index:
                     column.replace(exp.Literal.number(i))
+                    replaced = True
                 else:
+                    replaced = True
                     column = column.replace(exp.paren(alias_expr))
                     simplified = simplify_parens(column, dialect)
                     if simplified is not column:
@@ -370,16 +421,15 @@ def _expand_alias_refs(
     replace_columns(expression.args.get("having"), resolve_table=True)
     replace_columns(expression.args.get("qualify"), resolve_table=True)
 
-    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
-    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
-    if dialect == "snowflake":
+    if dialect.SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS:
         for join in expression.args.get("joins") or []:
             replace_columns(join)
 
-    scope.clear_cache()
+    if replaced:
+        scope.clear_cache()
 
 
-def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
+def _expand_group_by(scope: Scope, dialect: Dialect) -> None:
     expression = scope.expression
     group = expression.args.get("group")
     if not group:
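
`SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS` captures the Snowflake behavior the removed comment documented (https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes): projection aliases may be referenced in JOIN ... ON. A sketch of the query shape this path handles, assuming Snowflake semantics (table and column names are hypothetical):

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    # "full_name" is defined in the SELECT list but referenced in the JOIN
    # condition; _expand_alias_refs inlines the aliased expression there.
    sql = """
    SELECT o.first_name || ' ' || o.last_name AS full_name
    FROM officers AS o
    JOIN employees AS e
      ON full_name = e.name
    """
    schema = {
        "officers": {"first_name": "TEXT", "last_name": "TEXT"},
        "employees": {"name": "TEXT"},
    }
    qualified = qualify(
        sqlglot.parse_one(sql, read="snowflake"), schema=schema, dialect="snowflake"
    )
    print(qualified.sql("snowflake"))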
@@ -405,7 +455,7 @@ def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
         for original, expanded in zip(
             modifier_expressions,
             _expand_positional_references(
-                scope, modifier_expressions, resolver.schema.dialect, alias=True
+                scope, modifier_expressions, resolver.dialect, alias=True
             ),
         ):
             for agg in original.find_all(exp.AggFunc):
@@ -427,7 +477,7 @@ def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
 
 
 def _expand_positional_references(
-    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
+    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: Dialect, alias: bool = False
 ) -> t.List[exp.Expression]:
     new_nodes: t.List[exp.Expression] = []
     ambiguous_projections = None
@@ -441,7 +491,7 @@
             else:
                 select = select.this
 
-                if dialect == "bigquery":
+                if dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES:
                     if ambiguous_projections is None:
                         # When a projection name is also a source name and it is referenced in the
                         # GROUP BY clause, BQ can't understand what the identifier corresponds to
@@ -482,10 +532,10 @@ def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
 
 
 def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
     """
-    Converts `Column` instances that represent struct field lookup into chained `Dots`.
+    Converts `Column` instances that represent STRUCT or JSON field lookup into chained `Dots`.
 
-    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
-    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
+    These lookups may be parsed as columns (e.g. "col"."field"."field2"), but they need to be
+    normalized to `Dot(Dot(...(<table>.<column>, field1), field2, ...))` to be qualified properly.
     """
     converted = False
     for column in itertools.chain(scope.columns, scope.stars):
@@ -493,6 +543,7 @@ def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
             continue
 
         column_table: t.Optional[str | exp.Identifier] = column.table
+        dot_parts = column.meta.pop("dot_parts", [])
         if (
             column_table
             and column_table not in scope.sources
@@ -508,12 +559,20 @@
                 # The struct is already qualified, but we still need to change the AST
                 column_table = root
                 root, *parts = parts
+                was_qualified = True
             else:
                 column_table = resolver.get_table(root.name)
+                was_qualified = False
 
             if column_table:
                 converted = True
-                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
+                new_column = exp.column(root, table=column_table)
+
+                if dot_parts:
+                    # Remove the actual column parts from the rest of dot parts
+                    new_column.meta["dot_parts"] = dot_parts[2 if was_qualified else 1 :]
+
+                column.replace(exp.Dot.build([new_column, *parts]))
 
     if converted:
         # We want to re-aggregate the converted columns, otherwise they'd be skipped in
@@ -521,7 +580,11 @@
         scope.clear_cache()
 
 
-def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
+def _qualify_columns(
+    scope: Scope,
+    resolver: Resolver,
+    allow_partial_qualification: bool,
+) -> None:
     """Disambiguate columns, ensuring each column specifies a source"""
     for column in scope.columns:
         column_table = column.table
@@ -545,15 +608,16 @@ def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
                 continue
 
         # column_table can be a '' because bigquery unnest has no table alias
-        column_table = resolver.get_table(column_name)
+        column_table = resolver.get_table(column)
+
         if column_table:
             column.set("table", column_table)
         elif (
-            resolver.schema.dialect == "bigquery"
+            resolver.dialect.TABLES_REFERENCEABLE_AS_COLUMNS
             and len(column.parts) == 1
             and column_name in scope.selected_sources
         ):
-            # BigQuery allows tables to be referenced as columns, treating them as structs
+            # BigQuery and Postgres allow tables to be referenced as columns, treating them as structs/records
             scope.replace(column, exp.TableColumn(this=column.this))
 
     for pivot in scope.pivots:
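
`TABLES_REFERENCEABLE_AS_COLUMNS` widens the old BigQuery-only branch; per the updated comment, Postgres gets the same treatment since a bare table name in a select list denotes the whole row/record. An illustrative sketch (the exact printed AST is not guaranteed):

    import sqlglot
    from sqlglot import exp
    from sqlglot.optimizer.qualify import qualify

    # "t" in the projection names the source itself, not one of its columns,
    # so it becomes an exp.TableColumn instead of being qualified as t.t.
    sql = "SELECT t FROM (SELECT 1 AS a, 2 AS b) AS t"
    qualified = qualify(sqlglot.parse_one(sql, read="bigquery"), dialect="bigquery")
    print(qualified.find(exp.TableColumn))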
@@ -564,7 +628,7 @@
             column.set("table", column_table)
 
 
-def _expand_struct_stars_bigquery(
+def _expand_struct_stars_no_parens(
     expression: exp.Dot,
 ) -> t.List[exp.Alias]:
     """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""
618
682
  return new_selections
619
683
 
620
684
 
621
- def _expand_struct_stars_risingwave(expression: exp.Dot) -> t.List[exp.Alias]:
685
+ def _expand_struct_stars_with_parens(expression: exp.Dot) -> t.List[exp.Alias]:
622
686
  """[RisingWave] Expand/Flatten (<exp>.bar).*, where bar is a struct column"""
623
687
 
624
688
  # it is not (<sub_exp>).* pattern, which means we can't expand
@@ -695,7 +759,7 @@ def _expand_stars(
     rename_columns: t.Dict[int, t.Dict[str, str]] = {}
 
     coalesced_columns = set()
-    dialect = resolver.schema.dialect
+    dialect = resolver.dialect
 
     pivot_output_columns = None
     pivot_exclude_columns: t.Set[str] = set()
@@ -718,10 +782,9 @@ def _expand_stars(
         if not pivot_output_columns:
             pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
 
-    is_bigquery = dialect == "bigquery"
-    is_risingwave = dialect == "risingwave"
-
-    if (is_bigquery or is_risingwave) and any(isinstance(col, exp.Dot) for col in scope.stars):
+    if dialect.SUPPORTS_STRUCT_STAR_EXPANSION and any(
+        isinstance(col, exp.Dot) for col in scope.stars
+    ):
         # Found struct expansion, annotate scope ahead of time
         annotator.annotate_scope(scope)
 
@@ -738,13 +801,16 @@ def _expand_stars(
                 _add_except_columns(expression.this, tables, except_columns)
                 _add_replace_columns(expression.this, tables, replace_columns)
                 _add_rename_columns(expression.this, tables, rename_columns)
-            elif is_bigquery:
-                struct_fields = _expand_struct_stars_bigquery(expression)
+            elif (
+                dialect.SUPPORTS_STRUCT_STAR_EXPANSION
+                and not dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS
+            ):
+                struct_fields = _expand_struct_stars_no_parens(expression)
                 if struct_fields:
                     new_selections.extend(struct_fields)
                     continue
-            elif is_risingwave:
-                struct_fields = _expand_struct_stars_risingwave(expression)
+            elif dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS:
+                struct_fields = _expand_struct_stars_with_parens(expression)
                 if struct_fields:
                     new_selections.extend(struct_fields)
                     continue
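
The renamed helpers differ only in the surface syntax they accept: `_expand_struct_stars_no_parens` handles BigQuery-style `foo.bar.*`, while `_expand_struct_stars_with_parens` handles RisingWave-style `(foo.bar).*`. A hedged BigQuery example (the schema is hypothetical):

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    schema = {"t": {"bar": "STRUCT<a INT64, b INT64>"}}

    # t.bar.* is expanded into one projection per struct field, roughly
    # t.bar.a AS a, t.bar.b AS b.
    expanded = qualify(
        sqlglot.parse_one("SELECT t.bar.* FROM t", read="bigquery"),
        schema=schema,
        dialect="bigquery",
    )
    print(expanded.sql("bigquery"))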
@@ -760,7 +826,7 @@ def _expand_stars(
             columns = resolver.get_source_columns(table, only_visible=True)
             columns = columns or scope.outer_columns
 
-            if pseudocolumns:
+            if pseudocolumns and dialect.EXCLUDES_PSEUDOCOLUMNS_FROM_STAR:
                 columns = [name for name in columns if name.upper() not in pseudocolumns]
 
             if not columns or "*" in columns:
@@ -814,7 +880,7 @@ def _expand_stars(
 def _add_except_columns(
     expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
 ) -> None:
-    except_ = expression.args.get("except")
+    except_ = expression.args.get("except_")
 
     if not except_:
         return
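
The star modifier's AST key changed from `except` to `except_` (sidestepping the Python keyword), so code that inspects `SELECT * EXCEPT (...)` nodes must read the new key. A hedged sketch, assuming the corresponding rename in this release's expressions.py:

    import sqlglot
    from sqlglot import exp

    star = sqlglot.parse_one("SELECT * EXCEPT (secret) FROM t", read="bigquery").find(exp.Star)
    # The arg was previously read via star.args.get("except")
    excepted = star.args.get("except_") if star else None
    print(excepted)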
@@ -894,241 +960,22 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
     )  # type: ignore
 
 
-def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
+def pushdown_cte_alias_columns(scope: Scope) -> None:
     """
     Pushes down the CTE alias columns into the projection,
 
     This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
 
-    Example:
-        >>> import sqlglot
-        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
-        >>> pushdown_cte_alias_columns(expression).sql()
-        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
-
     Args:
-        expression: Expression to pushdown.
-
-    Returns:
-        The expression with the CTE aliases pushed down into the projection.
+        scope: Scope to find ctes to pushdown aliases.
     """
-    for cte in expression.find_all(exp.CTE):
-        if cte.alias_column_names:
+    for cte in scope.ctes:
+        if cte.alias_column_names and isinstance(cte.this, exp.Select):
             new_expressions = []
             for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                 if isinstance(projection, exp.Alias):
-                    projection.set("alias", _alias)
+                    projection.set("alias", exp.to_identifier(_alias))
                 else:
                     projection = alias(projection, alias=_alias)
                 new_expressions.append(projection)
             cte.this.set("expressions", new_expressions)
-
-    return expression
-
-
-class Resolver:
-    """
-    Helper for resolving columns.
-
-    This is a class so we can lazily load some things and easily share them across functions.
-    """
-
-    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
-        self.scope = scope
-        self.schema = schema
-        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
-        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
-        self._all_columns: t.Optional[t.Set[str]] = None
-        self._infer_schema = infer_schema
-        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
-
-    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
-        """
-        Get the table for a column name.
-
-        Args:
-            column_name: The column name to find the table for.
-        Returns:
-            The table name if it can be found/inferred.
-        """
-        if self._unambiguous_columns is None:
-            self._unambiguous_columns = self._get_unambiguous_columns(
-                self._get_all_source_columns()
-            )
-
-        table_name = self._unambiguous_columns.get(column_name)
-
-        if not table_name and self._infer_schema:
-            sources_without_schema = tuple(
-                source
-                for source, columns in self._get_all_source_columns().items()
-                if not columns or "*" in columns
-            )
-            if len(sources_without_schema) == 1:
-                table_name = sources_without_schema[0]
-
-        if table_name not in self.scope.selected_sources:
-            return exp.to_identifier(table_name)
-
-        node, _ = self.scope.selected_sources.get(table_name)
-
-        if isinstance(node, exp.Query):
-            while node and node.alias != table_name:
-                node = node.parent
-
-            node_alias = node.args.get("alias")
-            if node_alias:
-                return exp.to_identifier(node_alias.this)
-
-        return exp.to_identifier(table_name)
-
-    @property
-    def all_columns(self) -> t.Set[str]:
-        """All available columns of all sources in this scope"""
-        if self._all_columns is None:
-            self._all_columns = {
-                column for columns in self._get_all_source_columns().values() for column in columns
-            }
-        return self._all_columns
-
-    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
-        if isinstance(expression, exp.Select):
-            return expression.named_selects
-        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
-            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
-            return self.get_source_columns_from_set_op(expression.this)
-        if not isinstance(expression, exp.SetOperation):
-            raise OptimizeError(f"Unknown set operation: {expression}")
-
-        set_op = expression
-
-        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
-        on_column_list = set_op.args.get("on")
-
-        if on_column_list:
-            # The resulting columns are the columns in the ON clause:
-            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
-            columns = [col.name for col in on_column_list]
-        elif set_op.side or set_op.kind:
-            side = set_op.side
-            kind = set_op.kind
-
-            # Visit the children UNIONs (if any) in a post-order traversal
-            left = self.get_source_columns_from_set_op(set_op.left)
-            right = self.get_source_columns_from_set_op(set_op.right)
-
-            # We use dict.fromkeys to deduplicate keys and maintain insertion order
-            if side == "LEFT":
-                columns = left
-            elif side == "FULL":
-                columns = list(dict.fromkeys(left + right))
-            elif kind == "INNER":
-                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
-        else:
-            columns = set_op.named_selects
-
-        return columns
-
-    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
-        """Resolve the source columns for a given source `name`."""
-        cache_key = (name, only_visible)
-        if cache_key not in self._get_source_columns_cache:
-            if name not in self.scope.sources:
-                raise OptimizeError(f"Unknown table: {name}")
-
-            source = self.scope.sources[name]
-
-            if isinstance(source, exp.Table):
-                columns = self.schema.column_names(source, only_visible)
-            elif isinstance(source, Scope) and isinstance(
-                source.expression, (exp.Values, exp.Unnest)
-            ):
-                columns = source.expression.named_selects
-
-                # in bigquery, unnest structs are automatically scoped as tables, so you can
-                # directly select a struct field in a query.
-                # this handles the case where the unnest is statically defined.
-                if self.schema.dialect == "bigquery":
-                    if source.expression.is_type(exp.DataType.Type.STRUCT):
-                        for k in source.expression.type.expressions:  # type: ignore
-                            columns.append(k.name)
-            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
-                columns = self.get_source_columns_from_set_op(source.expression)
-
-            else:
-                select = seq_get(source.expression.selects, 0)
-
-                if isinstance(select, exp.QueryTransform):
-                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
-                    schema = select.args.get("schema")
-                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
-                else:
-                    columns = source.expression.named_selects
-
-            node, _ = self.scope.selected_sources.get(name) or (None, None)
-            if isinstance(node, Scope):
-                column_aliases = node.expression.alias_column_names
-            elif isinstance(node, exp.Expression):
-                column_aliases = node.alias_column_names
-            else:
-                column_aliases = []
-
-            if column_aliases:
-                # If the source's columns are aliased, their aliases shadow the corresponding column names.
-                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
-                columns = [
-                    alias or name
-                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
-                ]
-
-            self._get_source_columns_cache[cache_key] = columns
-
-        return self._get_source_columns_cache[cache_key]
-
-    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
-        if self._source_columns is None:
-            self._source_columns = {
-                source_name: self.get_source_columns(source_name)
-                for source_name, source in itertools.chain(
-                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
-                )
-            }
-        return self._source_columns
-
-    def _get_unambiguous_columns(
-        self, source_columns: t.Dict[str, t.Sequence[str]]
-    ) -> t.Mapping[str, str]:
-        """
-        Find all the unambiguous columns in sources.
-
-        Args:
-            source_columns: Mapping of names to source columns.
-
-        Returns:
-            Mapping of column name to source name.
-        """
-        if not source_columns:
-            return {}
-
-        source_columns_pairs = list(source_columns.items())
-
-        first_table, first_columns = source_columns_pairs[0]
-
-        if len(source_columns_pairs) == 1:
-            # Performance optimization - avoid copying first_columns if there is only one table.
-            return SingleValuedMapping(first_columns, first_table)
-
-        unambiguous_columns = {col: first_table for col in first_columns}
-        all_columns = set(unambiguous_columns)
-
-        for table, columns in source_columns_pairs[1:]:
-            unique = set(columns)
-            ambiguous = all_columns.intersection(unique)
-            all_columns.update(columns)
-
-            for column in ambiguous:
-                unambiguous_columns.pop(column, None)
-            for column in unique.difference(ambiguous):
-                unambiguous_columns[column] = table
-
-        return unambiguous_columns
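
Two structural notes on this final hunk. First, the removed `Resolver` class is not deleted outright: it moves to the new `sqlglot/optimizer/resolver.py` (file 47 in the list above, +374 lines), and this module now imports it. Second, `pushdown_cte_alias_columns` is no longer a standalone expression-to-expression rewrite; it operates on a single `Scope` and runs inside `qualify_columns` whenever the dialect sets `PREFER_CTE_ALIAS_COLUMN` (see the loop near the top of this diff). The rewrite itself still matches the removed doctest, as in this hedged sketch:

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_columns import qualify_columns

    # For dialects with PREFER_CTE_ALIAS_COLUMN = True (e.g. Snowflake), the CTE
    # alias columns are pushed into the projections during qualification:
    # WITH y (c) AS (SELECT SUM(a) ... HAVING c > 0) ...
    #   -> WITH y (c) AS (SELECT SUM(a) AS c ... HAVING c > 0) ...
    sql = "WITH y (c) AS (SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0) SELECT c FROM y"
    qualified = qualify_columns(parse_one(sql, read="snowflake"), schema=None, dialect="snowflake")
    print(qualified.sql("snowflake"))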