sqlglot 27.29.0__py3-none-any.whl → 28.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sqlglot/__main__.py +6 -4
  2. sqlglot/_version.py +2 -2
  3. sqlglot/dialects/bigquery.py +116 -295
  4. sqlglot/dialects/clickhouse.py +67 -2
  5. sqlglot/dialects/databricks.py +38 -1
  6. sqlglot/dialects/dialect.py +327 -286
  7. sqlglot/dialects/dremio.py +4 -1
  8. sqlglot/dialects/duckdb.py +718 -22
  9. sqlglot/dialects/exasol.py +243 -10
  10. sqlglot/dialects/hive.py +8 -8
  11. sqlglot/dialects/mysql.py +11 -2
  12. sqlglot/dialects/oracle.py +29 -0
  13. sqlglot/dialects/postgres.py +46 -24
  14. sqlglot/dialects/presto.py +47 -16
  15. sqlglot/dialects/redshift.py +16 -0
  16. sqlglot/dialects/risingwave.py +3 -0
  17. sqlglot/dialects/singlestore.py +12 -3
  18. sqlglot/dialects/snowflake.py +199 -271
  19. sqlglot/dialects/spark.py +2 -2
  20. sqlglot/dialects/spark2.py +11 -48
  21. sqlglot/dialects/sqlite.py +9 -0
  22. sqlglot/dialects/teradata.py +5 -8
  23. sqlglot/dialects/trino.py +6 -0
  24. sqlglot/dialects/tsql.py +61 -25
  25. sqlglot/diff.py +4 -2
  26. sqlglot/errors.py +69 -0
  27. sqlglot/expressions.py +484 -84
  28. sqlglot/generator.py +143 -41
  29. sqlglot/helper.py +2 -2
  30. sqlglot/optimizer/annotate_types.py +247 -140
  31. sqlglot/optimizer/canonicalize.py +6 -1
  32. sqlglot/optimizer/eliminate_joins.py +1 -1
  33. sqlglot/optimizer/eliminate_subqueries.py +2 -2
  34. sqlglot/optimizer/merge_subqueries.py +5 -5
  35. sqlglot/optimizer/normalize.py +20 -13
  36. sqlglot/optimizer/normalize_identifiers.py +17 -3
  37. sqlglot/optimizer/optimizer.py +4 -0
  38. sqlglot/optimizer/pushdown_predicates.py +1 -1
  39. sqlglot/optimizer/qualify.py +14 -6
  40. sqlglot/optimizer/qualify_columns.py +113 -352
  41. sqlglot/optimizer/qualify_tables.py +112 -70
  42. sqlglot/optimizer/resolver.py +374 -0
  43. sqlglot/optimizer/scope.py +27 -16
  44. sqlglot/optimizer/simplify.py +1074 -964
  45. sqlglot/optimizer/unnest_subqueries.py +12 -2
  46. sqlglot/parser.py +276 -160
  47. sqlglot/planner.py +2 -2
  48. sqlglot/schema.py +15 -4
  49. sqlglot/tokens.py +42 -7
  50. sqlglot/transforms.py +77 -22
  51. sqlglot/typing/__init__.py +316 -0
  52. sqlglot/typing/bigquery.py +376 -0
  53. sqlglot/typing/hive.py +12 -0
  54. sqlglot/typing/presto.py +24 -0
  55. sqlglot/typing/snowflake.py +505 -0
  56. sqlglot/typing/spark2.py +58 -0
  57. sqlglot/typing/tsql.py +9 -0
  58. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
  59. sqlglot-28.4.0.dist-info/RECORD +92 -0
  60. sqlglot-27.29.0.dist-info/RECORD +0 -84
  61. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
  62. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
  63. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
sqlglot/optimizer/qualify_columns.py
@@ -5,9 +5,10 @@ import typing as t
 
 from sqlglot import alias, exp
 from sqlglot.dialects.dialect import Dialect, DialectType
-from sqlglot.errors import OptimizeError
-from sqlglot.helper import seq_get, SingleValuedMapping
+from sqlglot.errors import OptimizeError, highlight_sql
+from sqlglot.helper import seq_get
 from sqlglot.optimizer.annotate_types import TypeAnnotator
+from sqlglot.optimizer.resolver import Resolver
 from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 from sqlglot.optimizer.simplify import simplify_parens
 from sqlglot.schema import Schema, ensure_schema
@@ -54,22 +55,17 @@ def qualify_columns(
     schema = ensure_schema(schema, dialect=dialect)
     annotator = TypeAnnotator(schema)
     infer_schema = schema.empty if infer_schema is None else infer_schema
-    dialect = Dialect.get_or_raise(schema.dialect)
+    dialect = schema.dialect or Dialect()
     pseudocolumns = dialect.PSEUDOCOLUMNS
-    bigquery = dialect == "bigquery"
 
     for scope in traverse_scope(expression):
+        if dialect.PREFER_CTE_ALIAS_COLUMN:
+            pushdown_cte_alias_columns(scope)
+
         scope_expression = scope.expression
         is_select = isinstance(scope_expression, exp.Select)
 
-        if is_select and scope_expression.args.get("connect"):
-            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
-            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
-            scope_expression.transform(
-                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
-                copy=False,
-            )
-            scope.clear_cache()
+        _separate_pseudocolumns(scope, pseudocolumns)
 
         resolver = Resolver(scope, schema, infer_schema=infer_schema)
         _pop_table_column_aliases(scope.ctes)
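
Note: CTE alias-column pushdown now runs per scope, gated by the new `PREFER_CTE_ALIAS_COLUMN` dialect flag (Snowflake sets it, per the `pushdown_cte_alias_columns` docstring further down). A minimal sketch of the observable effect, reusing the doctest that this release removes from that function (the `qualify` wiring is an assumption):

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    sql = "WITH y (c) AS (SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0) SELECT c FROM y"
    qualified = qualify(sqlglot.parse_one(sql, read="snowflake"), dialect="snowflake")
    # Expected shape, per the removed doctest: the CTE alias column "c" is pushed
    # into the projection, i.e. ... (SELECT SUM(a) AS c ... HAVING c > 0) ...
    print(qualified.sql("snowflake"))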
@@ -81,11 +77,15 @@ def qualify_columns(
                 scope,
                 resolver,
                 dialect,
-                expand_only_groupby=bigquery,
+                expand_only_groupby=dialect.EXPAND_ONLY_GROUP_ALIAS_REF,
             )
 
         _convert_columns_to_dots(scope, resolver)
-        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
+        _qualify_columns(
+            scope,
+            resolver,
+            allow_partial_qualification=allow_partial_qualification,
+        )
 
         if not schema.empty and expand_alias_refs:
             _expand_alias_refs(scope, resolver, dialect)
@@ -107,13 +107,13 @@ def qualify_columns(
         # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
         _expand_order_by_and_distinct_on(scope, resolver)
 
-        if bigquery:
+        if dialect.ANNOTATE_ALL_SCOPES:
             annotator.annotate_scope(scope)
 
     return expression
 
 
-def validate_qualify_columns(expression: E) -> E:
+def validate_qualify_columns(expression: E, sql: t.Optional[str] = None) -> E:
     """Raise an `OptimizeError` if any columns aren't qualified"""
     all_unqualified_columns = []
     for scope in traverse_scope(expression):
@@ -123,7 +123,19 @@ def validate_qualify_columns(expression: E) -> E:
         if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
             column = scope.external_columns[0]
             for_table = f" for table: '{column.table}'" if column.table else ""
-            raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
+            line = column.this.meta.get("line")
+            col = column.this.meta.get("col")
+            start = column.this.meta.get("start")
+            end = column.this.meta.get("end")
+
+            error_msg = f"Column '{column.name}' could not be resolved{for_table}."
+            if line and col:
+                error_msg += f" Line: {line}, Col: {col}"
+            if sql and start is not None and end is not None:
+                formatted_sql = highlight_sql(sql, [(start, end)])[0]
+                error_msg += f"\n  {formatted_sql}"
+
+            raise OptimizeError(error_msg)
 
         if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
             # New columns produced by the UNPIVOT can't be qualified, but there may be columns
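
Note: `validate_qualify_columns` gains an optional `sql` argument. When the original SQL text is supplied and the column's identifier carries `line`/`col`/`start`/`end` positions in its `meta`, the error (both this path and the "ambiguous column" path in the next hunk) includes the location and a snippet rendered by the new `sqlglot.errors.highlight_sql` helper. A minimal sketch, assuming validation is disabled inside `qualify` so the standalone call can receive the raw SQL:

    import sqlglot
    from sqlglot.errors import OptimizeError
    from sqlglot.optimizer.qualify import qualify
    from sqlglot.optimizer.qualify_columns import validate_qualify_columns

    sql = "SELECT x FROM a CROSS JOIN b"  # ambiguous: both a and b have column x
    expression = qualify(
        sqlglot.parse_one(sql),
        schema={"a": {"x": "int"}, "b": {"x": "int"}},
        validate_qualify_columns=False,
    )
    try:
        validate_qualify_columns(expression, sql=sql)
    except OptimizeError as e:
        print(e)  # e.g. "Ambiguous column 'x' (Line: 1, Col: ...)" plus the highlighted SQL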
@@ -135,11 +147,46 @@ def validate_qualify_columns(expression: E) -> E:
         all_unqualified_columns.extend(unqualified_columns)
 
     if all_unqualified_columns:
-        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
+        first_column = all_unqualified_columns[0]
+        line = first_column.this.meta.get("line")
+        col = first_column.this.meta.get("col")
+        start = first_column.this.meta.get("start")
+        end = first_column.this.meta.get("end")
+
+        error_msg = f"Ambiguous column '{first_column.name}'"
+        if line and col:
+            error_msg += f" (Line: {line}, Col: {col})"
+        if sql and start is not None and end is not None:
+            formatted_sql = highlight_sql(sql, [(start, end)])[0]
+            error_msg += f"\n  {formatted_sql}"
+
+        raise OptimizeError(error_msg)
 
     return expression
 
 
+def _separate_pseudocolumns(scope: Scope, pseudocolumns: t.Set[str]) -> None:
+    if not pseudocolumns:
+        return
+
+    has_pseudocolumns = False
+    scope_expression = scope.expression
+
+    for column in scope.columns:
+        name = column.name.upper()
+        if name not in pseudocolumns:
+            continue
+
+        if name != "LEVEL" or (
+            isinstance(scope_expression, exp.Select) and scope_expression.args.get("connect")
+        ):
+            column.replace(exp.Pseudocolumn(**column.args))
+            has_pseudocolumns = True
+
+    if has_pseudocolumns:
+        scope.clear_cache()
+
+
 def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
     name_columns = [
         field.this
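
Note: pseudocolumn handling is now centralized in `_separate_pseudocolumns`: any column whose name appears in `dialect.PSEUDOCOLUMNS` becomes an `exp.Pseudocolumn` node (with `LEVEL` only treated that way under a `CONNECT BY` clause), replacing the old Snowflake/Oracle-specific rewrite of `LEVEL` into a bare identifier. A rough sketch of what this means downstream (assumes `exp.Pseudocolumn` is exposed in 28.x, as this diff indicates):

    import sqlglot
    from sqlglot import exp
    from sqlglot.optimizer.qualify import qualify

    ast = qualify(
        sqlglot.parse_one("SELECT LEVEL FROM t CONNECT BY PRIOR id = parent_id", read="oracle"),
        schema={"t": {"id": "int", "parent_id": "int"}},
        dialect="oracle",
    )
    # LEVEL survives qualification as a Pseudocolumn instead of a plain identifier
    print(any(isinstance(node, exp.Pseudocolumn) for node in ast.walk()))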
@@ -274,12 +321,11 @@ def _expand_alias_refs(
     """
     expression = scope.expression
 
-    if not isinstance(expression, exp.Select) or dialect == "oracle":
+    if not isinstance(expression, exp.Select) or dialect.DISABLES_ALIAS_REF_EXPANSION:
        return
 
     alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
     projections = {s.alias_or_name for s in expression.selects}
-    is_bigquery = dialect == "bigquery"
     replaced = False
 
     def replace_columns(
@@ -317,12 +363,12 @@ def _expand_alias_refs(
             # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
             # If "HAVING x" is expanded to "HAVING max(x.b)", BQ would blindly replace the "x" reference with the projection MAX(x.b)
             # i.e HAVING MAX(MAX(x.b).b), resulting in the error: "Aggregations of aggregations are not allowed"
-            if is_having and is_bigquery:
+            if is_having and dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES:
                 skip_replace = skip_replace or any(
                     node.parts[0].name in projections
                     for node in alias_expr.find_all(exp.Column)
                 )
-            elif is_bigquery and (is_group_by or is_having):
+            elif dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES and (is_group_by or is_having):
                 column_table = table.name if table else column.table
                 if column_table in projections:
                     # BigQuery's GROUP BY and HAVING clauses get confused if the column name
@@ -375,9 +421,7 @@ def _expand_alias_refs(
     replace_columns(expression.args.get("having"), resolve_table=True)
     replace_columns(expression.args.get("qualify"), resolve_table=True)
 
-    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
-    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
-    if dialect == "snowflake":
+    if dialect.SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS:
         for join in expression.args.get("joins") or []:
             replace_columns(join)
 
@@ -385,7 +429,7 @@
         scope.clear_cache()
 
 
-def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
+def _expand_group_by(scope: Scope, dialect: Dialect) -> None:
     expression = scope.expression
     group = expression.args.get("group")
     if not group:
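
Note: a recurring theme in this release is visible in the hunks above: internal helpers now take a concrete `Dialect` instance instead of `DialectType`, and behavior that used to be keyed on dialect names is keyed on class-level capability flags. A hypothetical dialect illustrating the flags (the flag names are real per this diff; the subclass is for illustration only):

    from sqlglot.dialects.dialect import Dialect

    class MyDialect(Dialect):  # hypothetical
        DISABLES_ALIAS_REF_EXPANSION = True  # was: dialect == "oracle"
        SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS = True  # was: dialect == "snowflake"
        PROJECTION_ALIASES_SHADOW_SOURCE_NAMES = True  # was: dialect == "bigquery"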
@@ -411,7 +455,7 @@ def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
     for original, expanded in zip(
         modifier_expressions,
         _expand_positional_references(
-            scope, modifier_expressions, resolver.schema.dialect, alias=True
+            scope, modifier_expressions, resolver.dialect, alias=True
         ),
     ):
         for agg in original.find_all(exp.AggFunc):
@@ -433,7 +477,7 @@ def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
 
 
 def _expand_positional_references(
-    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
+    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: Dialect, alias: bool = False
 ) -> t.List[exp.Expression]:
     new_nodes: t.List[exp.Expression] = []
     ambiguous_projections = None
@@ -447,7 +491,7 @@ def _expand_positional_references(
         else:
             select = select.this
 
-        if dialect == "bigquery":
+        if dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES:
             if ambiguous_projections is None:
                 # When a projection name is also a source name and it is referenced in the
                 # GROUP BY clause, BQ can't understand what the identifier corresponds to
@@ -488,10 +532,10 @@ def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
 
 def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
     """
-    Converts `Column` instances that represent struct field lookup into chained `Dots`.
+    Converts `Column` instances that represent STRUCT or JSON field lookup into chained `Dots`.
 
-    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
-    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
+    These lookups may be parsed as columns (e.g. "col"."field"."field2"), but they need to be
+    normalized to `Dot(Dot(...(<table>.<column>, field1), field2, ...))` to be qualified properly.
     """
     converted = False
     for column in itertools.chain(scope.columns, scope.stars):
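
Note: a sketch of the normalization this docstring describes, assuming a BigQuery STRUCT column (exact output depends on identifier quoting settings):

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    ast = qualify(
        sqlglot.parse_one("SELECT col.field FROM t", read="bigquery"),
        schema={"t": {"col": "STRUCT<field INT64>"}},
        dialect="bigquery",
    )
    # "col"."field" was parsed as a Column, but is now Dot(Column(t.col), field),
    # so it renders roughly as: SELECT t.col.field AS field FROM t AS t
    print(ast.sql("bigquery"))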
@@ -499,6 +543,7 @@ def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
             continue
 
         column_table: t.Optional[str | exp.Identifier] = column.table
+        dot_parts = column.meta.pop("dot_parts", [])
         if (
             column_table
             and column_table not in scope.sources
@@ -514,12 +559,20 @@ def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
                 # The struct is already qualified, but we still need to change the AST
                 column_table = root
                 root, *parts = parts
+                was_qualified = True
             else:
                 column_table = resolver.get_table(root.name)
+                was_qualified = False
 
             if column_table:
                 converted = True
-                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
+                new_column = exp.column(root, table=column_table)
+
+                if dot_parts:
+                    # Remove the actual column parts from the rest of dot parts
+                    new_column.meta["dot_parts"] = dot_parts[2 if was_qualified else 1 :]
+
+                column.replace(exp.Dot.build([new_column, *parts]))
 
     if converted:
         # We want to re-aggregate the converted columns, otherwise they'd be skipped in
@@ -527,7 +580,11 @@ def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
         scope.clear_cache()
 
 
-def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
+def _qualify_columns(
+    scope: Scope,
+    resolver: Resolver,
+    allow_partial_qualification: bool,
+) -> None:
     """Disambiguate columns, ensuring each column specifies a source"""
     for column in scope.columns:
         column_table = column.table
@@ -556,11 +613,11 @@ def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualificati
         if column_table:
             column.set("table", column_table)
         elif (
-            resolver.schema.dialect == "bigquery"
+            resolver.dialect.TABLES_REFERENCEABLE_AS_COLUMNS
             and len(column.parts) == 1
             and column_name in scope.selected_sources
         ):
-            # BigQuery allows tables to be referenced as columns, treating them as structs
+            # BigQuery and Postgres allow tables to be referenced as columns, treating them as structs/records
             scope.replace(column, exp.TableColumn(this=column.this))
 
     for pivot in scope.pivots:
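
Note: table-as-column references are now flag-gated too; per the updated comment this covers Postgres records in addition to BigQuery structs. A rough sketch (assumes `TABLES_REFERENCEABLE_AS_COLUMNS` is set on the BigQuery dialect):

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    ast = qualify(
        sqlglot.parse_one("SELECT t FROM t", read="bigquery"),
        schema={"t": {"a": "int"}},
        dialect="bigquery",
    )
    # The bare "t" projection should be an exp.TableColumn node, not an exp.Column
    print(type(ast.selects[0]).__name__)  # expected: TableColumn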
@@ -571,7 +628,7 @@ def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualificati
             column.set("table", column_table)
 
 
-def _expand_struct_stars_bigquery(
+def _expand_struct_stars_no_parens(
     expression: exp.Dot,
 ) -> t.List[exp.Alias]:
     """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""
@@ -625,7 +682,7 @@ def _expand_struct_stars_no_parens(
     return new_selections
 
 
-def _expand_struct_stars_risingwave(expression: exp.Dot) -> t.List[exp.Alias]:
+def _expand_struct_stars_with_parens(expression: exp.Dot) -> t.List[exp.Alias]:
     """[RisingWave] Expand/Flatten (<exp>.bar).*, where bar is a struct column"""
 
     # it is not (<sub_exp>).* pattern, which means we can't expand
@@ -702,7 +759,7 @@ def _expand_stars(
     rename_columns: t.Dict[int, t.Dict[str, str]] = {}
 
     coalesced_columns = set()
-    dialect = resolver.schema.dialect
+    dialect = resolver.dialect
 
     pivot_output_columns = None
     pivot_exclude_columns: t.Set[str] = set()
@@ -725,10 +782,9 @@ def _expand_stars(
         if not pivot_output_columns:
             pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
 
-    is_bigquery = dialect == "bigquery"
-    is_risingwave = dialect == "risingwave"
-
-    if (is_bigquery or is_risingwave) and any(isinstance(col, exp.Dot) for col in scope.stars):
+    if dialect.SUPPORTS_STRUCT_STAR_EXPANSION and any(
+        isinstance(col, exp.Dot) for col in scope.stars
+    ):
         # Found struct expansion, annotate scope ahead of time
         annotator.annotate_scope(scope)
@@ -745,13 +801,16 @@ def _expand_stars(
                 _add_except_columns(expression.this, tables, except_columns)
                 _add_replace_columns(expression.this, tables, replace_columns)
                 _add_rename_columns(expression.this, tables, rename_columns)
-            elif is_bigquery:
-                struct_fields = _expand_struct_stars_bigquery(expression)
+            elif (
+                dialect.SUPPORTS_STRUCT_STAR_EXPANSION
+                and not dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS
+            ):
+                struct_fields = _expand_struct_stars_no_parens(expression)
                 if struct_fields:
                     new_selections.extend(struct_fields)
                     continue
-            elif is_risingwave:
-                struct_fields = _expand_struct_stars_risingwave(expression)
+            elif dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS:
+                struct_fields = _expand_struct_stars_with_parens(expression)
                 if struct_fields:
                     new_selections.extend(struct_fields)
                     continue
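
Note: the struct-star helpers are renamed by access style rather than by dialect, and selection is flag-driven: `SUPPORTS_STRUCT_STAR_EXPANSION` alone picks the unparenthesized `foo.bar.*` form (BigQuery), while `REQUIRES_PARENTHESIZED_STRUCT_ACCESS` picks the `(<expr>.bar).*` form (RisingWave, which presumably sets both flags). A rough sketch of the parenthesized case:

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    ast = qualify(
        sqlglot.parse_one("SELECT (t.col).* FROM t", read="risingwave"),
        schema={"t": {"col": "STRUCT<a INT, b INT>"}},
        dialect="risingwave",
    )
    # roughly: SELECT (t.col).a AS a, (t.col).b AS b FROM t AS t
    print(ast.sql("risingwave"))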
@@ -767,7 +826,7 @@ def _expand_stars(
             columns = resolver.get_source_columns(table, only_visible=True)
             columns = columns or scope.outer_columns
 
-            if pseudocolumns:
+            if pseudocolumns and dialect.EXCLUDES_PSEUDOCOLUMNS_FROM_STAR:
                 columns = [name for name in columns if name.upper() not in pseudocolumns]
 
            if not columns or "*" in columns:
@@ -821,7 +880,7 @@ def _expand_stars(
 def _add_except_columns(
     expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
 ) -> None:
-    except_ = expression.args.get("except")
+    except_ = expression.args.get("except_")
 
     if not except_:
         return
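
Note: the arg key for the star EXCEPT list changed from "except" to "except_". Code that inspected these args directly needs updating; a minimal sketch, assuming the key was renamed on `exp.Star` as this hunk suggests:

    import sqlglot

    star = sqlglot.parse_one("SELECT * EXCEPT (a) FROM t", read="bigquery").selects[0]
    print(star.args.get("except_"))  # was star.args.get("except") in 27.x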
@@ -901,320 +960,22 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool
     )  # type: ignore
 
 
-def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
+def pushdown_cte_alias_columns(scope: Scope) -> None:
     """
     Pushes down the CTE alias columns into the projection,
 
     This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
 
-    Example:
-        >>> import sqlglot
-        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
-        >>> pushdown_cte_alias_columns(expression).sql()
-        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
-
     Args:
-        expression: Expression to pushdown.
-
-    Returns:
-        The expression with the CTE aliases pushed down into the projection.
+        scope: Scope to find ctes to pushdown aliases.
     """
-    for cte in expression.find_all(exp.CTE):
-        if cte.alias_column_names:
+    for cte in scope.ctes:
+        if cte.alias_column_names and isinstance(cte.this, exp.Select):
             new_expressions = []
             for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                 if isinstance(projection, exp.Alias):
-                    projection.set("alias", _alias)
+                    projection.set("alias", exp.to_identifier(_alias))
                 else:
                     projection = alias(projection, alias=_alias)
                 new_expressions.append(projection)
             cte.this.set("expressions", new_expressions)
-
-    return expression
-
-
-class Resolver:
-    """
-    Helper for resolving columns.
-
-    This is a class so we can lazily load some things and easily share them across functions.
-    """
-
-    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
-        self.scope = scope
-        self.schema = schema
-        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
-        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
-        self._all_columns: t.Optional[t.Set[str]] = None
-        self._infer_schema = infer_schema
-        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
-
-    def get_table(self, column: str | exp.Column) -> t.Optional[exp.Identifier]:
-        """
-        Get the table for a column name.
-
-        Args:
-            column: The column expression (or column name) to find the table for.
-        Returns:
-            The table name if it can be found/inferred.
-        """
-        column_name = column if isinstance(column, str) else column.name
-
-        table_name = self._get_table_name_from_sources(column_name)
-
-        if not table_name and isinstance(column, exp.Column):
-            # Fall-back case: If we couldn't find the `table_name` from ALL of the sources,
-            # attempt to disambiguate the column based on other characteristics e.g if this column is in a join condition,
-            # we may be able to disambiguate based on the source order.
-            if join_context := self._get_column_join_context(column):
-                # In this case, the return value will be the join that _may_ be able to disambiguate the column
-                # and we can use the source columns available at that join to get the table name
-                table_name = self._get_table_name_from_sources(
-                    column_name, self._get_available_source_columns(join_context)
-                )
-
-        if not table_name and self._infer_schema:
-            sources_without_schema = tuple(
-                source
-                for source, columns in self._get_all_source_columns().items()
-                if not columns or "*" in columns
-            )
-            if len(sources_without_schema) == 1:
-                table_name = sources_without_schema[0]
-
-        if table_name not in self.scope.selected_sources:
-            return exp.to_identifier(table_name)
-
-        node, _ = self.scope.selected_sources.get(table_name)
-
-        if isinstance(node, exp.Query):
-            while node and node.alias != table_name:
-                node = node.parent
-
-            node_alias = node.args.get("alias")
-            if node_alias:
-                return exp.to_identifier(node_alias.this)
-
-        return exp.to_identifier(table_name)
-
-    @property
-    def all_columns(self) -> t.Set[str]:
-        """All available columns of all sources in this scope"""
-        if self._all_columns is None:
-            self._all_columns = {
-                column for columns in self._get_all_source_columns().values() for column in columns
-            }
-        return self._all_columns
-
-    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
-        if isinstance(expression, exp.Select):
-            return expression.named_selects
-        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
-            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
-            return self.get_source_columns_from_set_op(expression.this)
-        if not isinstance(expression, exp.SetOperation):
-            raise OptimizeError(f"Unknown set operation: {expression}")
-
-        set_op = expression
-
-        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
-        on_column_list = set_op.args.get("on")
-
-        if on_column_list:
-            # The resulting columns are the columns in the ON clause:
-            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
-            columns = [col.name for col in on_column_list]
-        elif set_op.side or set_op.kind:
-            side = set_op.side
-            kind = set_op.kind
-
-            # Visit the children UNIONs (if any) in a post-order traversal
-            left = self.get_source_columns_from_set_op(set_op.left)
-            right = self.get_source_columns_from_set_op(set_op.right)
-
-            # We use dict.fromkeys to deduplicate keys and maintain insertion order
-            if side == "LEFT":
-                columns = left
-            elif side == "FULL":
-                columns = list(dict.fromkeys(left + right))
-            elif kind == "INNER":
-                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
-        else:
-            columns = set_op.named_selects
-
-        return columns
-
-    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
-        """Resolve the source columns for a given source `name`."""
-        cache_key = (name, only_visible)
-        if cache_key not in self._get_source_columns_cache:
-            if name not in self.scope.sources:
-                raise OptimizeError(f"Unknown table: {name}")
-
-            source = self.scope.sources[name]
-
-            if isinstance(source, exp.Table):
-                columns = self.schema.column_names(source, only_visible)
-            elif isinstance(source, Scope) and isinstance(
-                source.expression, (exp.Values, exp.Unnest)
-            ):
-                columns = source.expression.named_selects
-
-                # in bigquery, unnest structs are automatically scoped as tables, so you can
-                # directly select a struct field in a query.
-                # this handles the case where the unnest is statically defined.
-                if self.schema.dialect == "bigquery":
-                    if source.expression.is_type(exp.DataType.Type.STRUCT):
-                        for k in source.expression.type.expressions:  # type: ignore
-                            columns.append(k.name)
-            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
-                columns = self.get_source_columns_from_set_op(source.expression)
-
-            else:
-                select = seq_get(source.expression.selects, 0)
-
-                if isinstance(select, exp.QueryTransform):
-                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
-                    schema = select.args.get("schema")
-                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
-                else:
-                    columns = source.expression.named_selects
-
-            node, _ = self.scope.selected_sources.get(name) or (None, None)
-            if isinstance(node, Scope):
-                column_aliases = node.expression.alias_column_names
-            elif isinstance(node, exp.Expression):
-                column_aliases = node.alias_column_names
-            else:
-                column_aliases = []
-
-            if column_aliases:
-                # If the source's columns are aliased, their aliases shadow the corresponding column names.
-                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
-                columns = [
-                    alias or name
-                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
-                ]
-
-            self._get_source_columns_cache[cache_key] = columns
-
-        return self._get_source_columns_cache[cache_key]
-
-    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
-        if self._source_columns is None:
-            self._source_columns = {
-                source_name: self.get_source_columns(source_name)
-                for source_name, source in itertools.chain(
-                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
-                )
-            }
-        return self._source_columns
-
-    def _get_table_name_from_sources(
-        self, column_name: str, source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
-    ) -> t.Optional[str]:
-        if not source_columns:
-            # If not supplied, get all sources to calculate unambiguous columns
-            if self._unambiguous_columns is None:
-                self._unambiguous_columns = self._get_unambiguous_columns(
-                    self._get_all_source_columns()
-                )
-
-            unambiguous_columns = self._unambiguous_columns
-        else:
-            unambiguous_columns = self._get_unambiguous_columns(source_columns)
-
-        return unambiguous_columns.get(column_name)
-
-    def _get_column_join_context(self, column: exp.Column) -> t.Optional[exp.Join]:
-        """
-        Check if a column participating in a join can be qualified based on the source order.
-        """
-        args = self.scope.expression.args
-        joins = args.get("joins")
-
-        if not joins or args.get("laterals") or args.get("pivots"):
-            # Feature gap: We currently don't try to disambiguate columns if other sources
-            # (e.g laterals, pivots) exist alongside joins
-            return None
-
-        join_ancestor = column.find_ancestor(exp.Join, exp.Select)
-
-        if (
-            isinstance(join_ancestor, exp.Join)
-            and join_ancestor.alias_or_name in self.scope.selected_sources
-        ):
-            # Ensure that the found ancestor is a join that contains an actual source,
-            # e.g in Clickhouse `b` is an array expression in `a ARRAY JOIN b`
-            return join_ancestor
-
-        return None
-
-    def _get_available_source_columns(
-        self, join_ancestor: exp.Join
-    ) -> t.Dict[str, t.Sequence[str]]:
-        """
-        Get the source columns that are available at the point where a column is referenced.
-
-        For columns in JOIN conditions, this only includes tables that have been joined
-        up to that point. Example:
-
-        ```
-        SELECT * FROM t_1 INNER JOIN ... INNER JOIN t_n ON t_1.a = c INNER JOIN t_n+1 ON ...
-                                                                    ^
-        ```
-
-        The unqualified column `c` is not ambiguous if no other sources up until that
-        join i.e t_1, ..., t_n, contain a column named `c`.
-        """
-        args = self.scope.expression.args
-
-        # Collect tables in order: FROM clause tables + joined tables up to current join
-        from_name = args["from"].alias_or_name
-        available_sources = {from_name: self.get_source_columns(from_name)}
-
-        for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]:
-            available_sources[join.alias_or_name] = self.get_source_columns(join.alias_or_name)
-
-        return available_sources
-
-    def _get_unambiguous_columns(
-        self, source_columns: t.Dict[str, t.Sequence[str]]
-    ) -> t.Mapping[str, str]:
-        """
-        Find all the unambiguous columns in sources.
-
-        Args:
-            source_columns: Mapping of names to source columns.
-
-        Returns:
-            Mapping of column name to source name.
-        """
-        if not source_columns:
-            return {}
-
-        source_columns_pairs = list(source_columns.items())
-
-        first_table, first_columns = source_columns_pairs[0]
-
-        if len(source_columns_pairs) == 1:
-            # Performance optimization - avoid copying first_columns if there is only one table.
-            return SingleValuedMapping(first_columns, first_table)
-
-        unambiguous_columns = {col: first_table for col in first_columns}
-        all_columns = set(unambiguous_columns)
-
-        for table, columns in source_columns_pairs[1:]:
-            unique = set(columns)
-            ambiguous = all_columns.intersection(unique)
-            all_columns.update(columns)
-
-            for column in ambiguous:
-                unambiguous_columns.pop(column, None)
-            for column in unique.difference(ambiguous):
-                unambiguous_columns[column] = table
-
-        return unambiguous_columns
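
Note: the roughly 320 lines removed here are the `Resolver` class, which moved to the new module sqlglot/optimizer/resolver.py (file 42 in the list above, +374 lines); the hunks above already import it from there, and it now exposes the `resolver.dialect` property used throughout. External code should switch imports; a minimal sketch of standalone usage, assuming the constructor kept the signature of the removed class:

    import sqlglot
    from sqlglot.optimizer.resolver import Resolver
    from sqlglot.optimizer.scope import build_scope
    from sqlglot.schema import ensure_schema

    expression = sqlglot.parse_one("SELECT a FROM t")
    scope = build_scope(expression)
    resolver = Resolver(scope, ensure_schema({"t": {"a": "int"}}))
    print(resolver.get_table("a"))  # expected: t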