sqlglot 27.29.0__py3-none-any.whl → 28.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sqlglot/__main__.py +6 -4
  2. sqlglot/_version.py +2 -2
  3. sqlglot/dialects/bigquery.py +116 -295
  4. sqlglot/dialects/clickhouse.py +67 -2
  5. sqlglot/dialects/databricks.py +38 -1
  6. sqlglot/dialects/dialect.py +327 -286
  7. sqlglot/dialects/dremio.py +4 -1
  8. sqlglot/dialects/duckdb.py +718 -22
  9. sqlglot/dialects/exasol.py +243 -10
  10. sqlglot/dialects/hive.py +8 -8
  11. sqlglot/dialects/mysql.py +11 -2
  12. sqlglot/dialects/oracle.py +29 -0
  13. sqlglot/dialects/postgres.py +46 -24
  14. sqlglot/dialects/presto.py +47 -16
  15. sqlglot/dialects/redshift.py +16 -0
  16. sqlglot/dialects/risingwave.py +3 -0
  17. sqlglot/dialects/singlestore.py +12 -3
  18. sqlglot/dialects/snowflake.py +199 -271
  19. sqlglot/dialects/spark.py +2 -2
  20. sqlglot/dialects/spark2.py +11 -48
  21. sqlglot/dialects/sqlite.py +9 -0
  22. sqlglot/dialects/teradata.py +5 -8
  23. sqlglot/dialects/trino.py +6 -0
  24. sqlglot/dialects/tsql.py +61 -25
  25. sqlglot/diff.py +4 -2
  26. sqlglot/errors.py +69 -0
  27. sqlglot/expressions.py +484 -84
  28. sqlglot/generator.py +143 -41
  29. sqlglot/helper.py +2 -2
  30. sqlglot/optimizer/annotate_types.py +247 -140
  31. sqlglot/optimizer/canonicalize.py +6 -1
  32. sqlglot/optimizer/eliminate_joins.py +1 -1
  33. sqlglot/optimizer/eliminate_subqueries.py +2 -2
  34. sqlglot/optimizer/merge_subqueries.py +5 -5
  35. sqlglot/optimizer/normalize.py +20 -13
  36. sqlglot/optimizer/normalize_identifiers.py +17 -3
  37. sqlglot/optimizer/optimizer.py +4 -0
  38. sqlglot/optimizer/pushdown_predicates.py +1 -1
  39. sqlglot/optimizer/qualify.py +14 -6
  40. sqlglot/optimizer/qualify_columns.py +113 -352
  41. sqlglot/optimizer/qualify_tables.py +112 -70
  42. sqlglot/optimizer/resolver.py +374 -0
  43. sqlglot/optimizer/scope.py +27 -16
  44. sqlglot/optimizer/simplify.py +1074 -964
  45. sqlglot/optimizer/unnest_subqueries.py +12 -2
  46. sqlglot/parser.py +276 -160
  47. sqlglot/planner.py +2 -2
  48. sqlglot/schema.py +15 -4
  49. sqlglot/tokens.py +42 -7
  50. sqlglot/transforms.py +77 -22
  51. sqlglot/typing/__init__.py +316 -0
  52. sqlglot/typing/bigquery.py +376 -0
  53. sqlglot/typing/hive.py +12 -0
  54. sqlglot/typing/presto.py +24 -0
  55. sqlglot/typing/snowflake.py +505 -0
  56. sqlglot/typing/spark2.py +58 -0
  57. sqlglot/typing/tsql.py +9 -0
  58. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
  59. sqlglot-28.4.0.dist-info/RECORD +92 -0
  60. sqlglot-27.29.0.dist-info/RECORD +0 -84
  61. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
  62. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
  63. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
sqlglot/optimizer/qualify_tables.py
@@ -1,11 +1,10 @@
  from __future__ import annotations
 
- import itertools
  import typing as t
 
- from sqlglot import alias, exp
+ from sqlglot import exp
  from sqlglot.dialects.dialect import Dialect, DialectType
- from sqlglot.helper import name_sequence
+ from sqlglot.helper import name_sequence, seq_get, ensure_list
  from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
  from sqlglot.optimizer.scope import Scope, traverse_scope
 
@@ -17,8 +16,9 @@ def qualify_tables(
  expression: E,
  db: t.Optional[str | exp.Identifier] = None,
  catalog: t.Optional[str | exp.Identifier] = None,
- on_qualify: t.Optional[t.Callable[[exp.Expression], None]] = None,
+ on_qualify: t.Optional[t.Callable[[exp.Table], None]] = None,
  dialect: DialectType = None,
+ canonicalize_table_aliases: bool = False,
  ) -> E:
  """
  Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
@@ -40,16 +40,14 @@ def qualify_tables(
  catalog: Catalog name
  on_qualify: Callback after a table has been qualified.
  dialect: The dialect to parse catalog and schema into.
+ canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources
+ instead of preserving table names. Defaults to False.
 
  Returns:
  The qualified expression.
  """
  dialect = Dialect.get_or_raise(dialect)
-
- alias_sequence = name_sequence("_q_")
-
- def next_alias_name() -> str:
- return normalize_identifiers(alias_sequence(), dialect=dialect).name
+ next_alias_name = name_sequence("_")
 
  if db := db or None:
  db = exp.parse_identifier(db, dialect=dialect)
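
The two signature changes above are small but relevant for callers: `on_qualify` is now typed against `exp.Table` rather than the broader `exp.Expression`, and the new `canonicalize_table_aliases` flag switches alias generation to the `_0, _1, ...` scheme described in the docstring. A minimal sketch of calling the updated API; the sample query and callback are illustrative and not taken from the package:

```python
import sqlglot
from sqlglot import exp
from sqlglot.optimizer.qualify_tables import qualify_tables

def on_qualify(table: exp.Table) -> None:
    # The callback now receives exp.Table (previously typed as exp.Expression)
    print("qualified:", table.sql())

expression = sqlglot.parse_one("SELECT t.c FROM t JOIN s ON t.id = s.id")

qualified = qualify_tables(
    expression,
    db="db",
    catalog="cat",
    on_qualify=on_qualify,
    canonicalize_table_aliases=True,  # per the new docstring: sources get aliases _0, _1, ...
)
print(qualified.sql())
```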
@@ -68,95 +66,132 @@ def qualify_tables(
  table.set("catalog", catalog.copy())
 
  if (db or catalog) and not isinstance(expression, exp.Query):
- with_ = expression.args.get("with") or exp.With()
+ with_ = expression.args.get("with_") or exp.With()
  cte_names = {cte.alias_or_name for cte in with_.expressions}
 
  for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
  if isinstance(node, exp.Table) and node.name not in cte_names:
  _qualify(node)
 
+ def _set_alias(
+ expression: exp.Expression,
+ canonical_aliases: t.Dict[str, str],
+ target_alias: t.Optional[str] = None,
+ scope: t.Optional[Scope] = None,
+ normalize: bool = False,
+ columns: t.Optional[t.List[t.Union[str, exp.Identifier]]] = None,
+ ) -> None:
+ alias = expression.args.get("alias") or exp.TableAlias()
+
+ if canonicalize_table_aliases:
+ new_alias_name = next_alias_name()
+ canonical_aliases[alias.name or target_alias or ""] = new_alias_name
+ elif not alias.name:
+ new_alias_name = target_alias or next_alias_name()
+ if normalize and target_alias:
+ new_alias_name = normalize_identifiers(new_alias_name, dialect=dialect).name
+ else:
+ return
+
+ alias.set("this", exp.to_identifier(new_alias_name))
+
+ if columns:
+ alias.set("columns", [exp.to_identifier(c) for c in columns])
+
+ expression.set("alias", alias)
+
+ if scope:
+ scope.rename_source(None, new_alias_name)
+
  for scope in traverse_scope(expression):
- for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
- if isinstance(derived_table, exp.Subquery):
- unnested = derived_table.unnest()
- if isinstance(unnested, exp.Table):
- joins = unnested.args.get("joins")
- unnested.set("joins", None)
- derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
- derived_table.this.set("joins", joins)
-
- if not derived_table.args.get("alias"):
- alias_ = next_alias_name()
- derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
- scope.rename_source(None, alias_)
-
- pivots = derived_table.args.get("pivots")
- if pivots and not pivots[0].alias:
- pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
+ local_columns = scope.local_columns
+ canonical_aliases: t.Dict[str, str] = {}
+
+ for query in scope.subqueries:
+ subquery = query.parent
+ if isinstance(subquery, exp.Subquery):
+ subquery.unwrap().replace(subquery)
+
+ for derived_table in scope.derived_tables:
+ unnested = derived_table.unnest()
+ if isinstance(unnested, exp.Table):
+ joins = unnested.args.get("joins")
+ unnested.set("joins", None)
+ derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
+ derived_table.this.set("joins", joins)
+
+ _set_alias(derived_table, canonical_aliases, scope=scope)
+ if pivot := seq_get(derived_table.args.get("pivots") or [], 0):
+ _set_alias(pivot, canonical_aliases)
 
  table_aliases = {}
 
  for name, source in scope.sources.items():
  if isinstance(source, exp.Table):
- pivots = source.args.get("pivots")
- if not source.alias:
- # Don't add the pivot's alias to the pivoted table, use the table's name instead
- if pivots and pivots[0].alias == name:
- name = source.name
-
- # Mutates the source by attaching an alias to it
- normalized_alias = normalize_identifiers(
- name or source.name or alias_sequence(), dialect=dialect
- )
- alias(source, normalized_alias, copy=False, table=True)
-
- table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
- source.alias
+ # When the name is empty, it means that we have a non-table source, e.g. a pivoted cte
+ is_real_table_source = bool(name)
+
+ if pivot := seq_get(source.args.get("pivots") or [], 0):
+ name = source.name
+
+ table_this = source.this
+ table_alias = source.args.get("alias")
+ function_columns: t.List[t.Union[str, exp.Identifier]] = []
+ if isinstance(table_this, exp.Func):
+ if not table_alias:
+ function_columns = ensure_list(
+ dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES.get(type(table_this))
+ )
+ elif columns := table_alias.columns:
+ function_columns = columns
+ elif type(table_this) in dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES:
+ function_columns = ensure_list(source.alias_or_name)
+ source.set("alias", None)
+ name = None
+
+ _set_alias(
+ source,
+ canonical_aliases,
+ target_alias=name or source.name or None,
+ normalize=True,
+ columns=function_columns,
  )
 
- if pivots:
- pivot = pivots[0]
- if not pivot.alias:
- pivot_alias = normalize_identifiers(
- source.alias if pivot.unpivot else alias_sequence(),
- dialect=dialect,
- )
- pivot.set("alias", exp.TableAlias(this=exp.to_identifier(pivot_alias)))
+ source_fqn = ".".join(p.name for p in source.parts)
+ table_aliases[source_fqn] = source.args["alias"].this.copy()
+
+ if pivot:
+ target_alias = source.alias if pivot.unpivot else None
+ _set_alias(pivot, canonical_aliases, target_alias=target_alias, normalize=True)
 
  # This case corresponds to a pivoted CTE, we don't want to qualify that
  if isinstance(scope.sources.get(source.alias_or_name), Scope):
  continue
 
- _qualify(source)
+ if is_real_table_source:
+ _qualify(source)
 
- if on_qualify:
- on_qualify(source)
+ if on_qualify:
+ on_qualify(source)
  elif isinstance(source, Scope) and source.is_udtf:
- udtf = source.expression
- table_alias = udtf.args.get("alias") or exp.TableAlias(
- this=exp.to_identifier(next_alias_name())
- )
- udtf.set("alias", table_alias)
+ _set_alias(udtf := source.expression, canonical_aliases)
+
+ table_alias = udtf.args["alias"]
 
- if not table_alias.name:
- table_alias.set("this", exp.to_identifier(next_alias_name()))
  if isinstance(udtf, exp.Values) and not table_alias.columns:
  column_aliases = [
  normalize_identifiers(i, dialect=dialect)
  for i in dialect.generate_values_aliases(udtf)
  ]
  table_alias.set("columns", column_aliases)
- else:
- for node in scope.walk():
- if (
- isinstance(node, exp.Table)
- and not node.alias
- and isinstance(node.parent, (exp.From, exp.Join))
- ):
- # Mutates the table by attaching an alias to it
- alias(node, node.name, copy=False, table=True)
-
- for column in scope.columns:
+
+ for table in scope.tables:
+ if not table.alias and isinstance(table.parent, (exp.From, exp.Join)):
+ _set_alias(table, canonical_aliases, target_alias=table.name)
+
+ for column in local_columns:
+ table = column.table
+
  if column.db:
  table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))
 
@@ -165,5 +200,12 @@ def qualify_tables(
  column.set(p, None)
 
  column.set("table", table_alias.copy())
+ elif (
+ canonical_aliases
+ and table
+ and (canonical_table := canonical_aliases.get(table, "")) != column.table
+ ):
+ # Amend existing aliases, e.g. t.c -> _0.c if t is aliased to _0
+ column.set("table", exp.to_identifier(canonical_table))
 
  return expression
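
The last hunk rewrites pre-existing column qualifiers so they track the canonical aliases, as the inline comment notes (`t.c -> _0.c` when `t` is aliased to `_0`). A rough sketch of the observable effect, assuming the `_0`-style aliases described in the docstring; the exact output of 28.4.0 has not been verified here:

```python
import sqlglot
from sqlglot.optimizer.qualify_tables import qualify_tables

expr = sqlglot.parse_one("SELECT t.c FROM t")
print(qualify_tables(expr, canonicalize_table_aliases=True).sql())
# Expected shape (illustrative, unverified): SELECT _0.c FROM t AS _0
```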
sqlglot/optimizer/resolver.py (new file)
@@ -0,0 +1,374 @@
+ from __future__ import annotations
+
+ import itertools
+ import typing as t
+
+ from sqlglot import exp
+ from sqlglot.dialects.dialect import Dialect
+ from sqlglot.errors import OptimizeError
+ from sqlglot.helper import seq_get, SingleValuedMapping
+ from sqlglot.optimizer.scope import Scope
+
+ if t.TYPE_CHECKING:
+ from sqlglot.schema import Schema
+
+
+ class Resolver:
+ """
+ Helper for resolving columns.
+
+ This is a class so we can lazily load some things and easily share them across functions.
+ """
+
+ def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
+ self.scope = scope
+ self.schema = schema
+ self.dialect = schema.dialect or Dialect()
+ self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
+ self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
+ self._all_columns: t.Optional[t.Set[str]] = None
+ self._infer_schema = infer_schema
+ self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
+
+ def get_table(self, column: str | exp.Column) -> t.Optional[exp.Identifier]:
+ """
+ Get the table for a column name.
+
+ Args:
+ column: The column expression (or column name) to find the table for.
+ Returns:
+ The table name if it can be found/inferred.
+ """
+ column_name = column if isinstance(column, str) else column.name
+
+ table_name = self._get_table_name_from_sources(column_name)
+
+ if not table_name and isinstance(column, exp.Column):
+ # Fall-back case: If we couldn't find the `table_name` from ALL of the sources,
+ # attempt to disambiguate the column based on other characteristics e.g if this column is in a join condition,
+ # we may be able to disambiguate based on the source order.
+ if join_context := self._get_column_join_context(column):
+ # In this case, the return value will be the join that _may_ be able to disambiguate the column
+ # and we can use the source columns available at that join to get the table name
+ # catch OptimizeError if column is still ambiguous and try to resolve with schema inference below
+ try:
+ table_name = self._get_table_name_from_sources(
+ column_name, self._get_available_source_columns(join_context)
+ )
+ except OptimizeError:
+ pass
+
+ if not table_name and self._infer_schema:
+ sources_without_schema = tuple(
+ source
+ for source, columns in self._get_all_source_columns().items()
+ if not columns or "*" in columns
+ )
+ if len(sources_without_schema) == 1:
+ table_name = sources_without_schema[0]
+
+ if table_name not in self.scope.selected_sources:
+ return exp.to_identifier(table_name)
+
+ node, _ = self.scope.selected_sources.get(table_name)
+
+ if isinstance(node, exp.Query):
+ while node and node.alias != table_name:
+ node = node.parent
+
+ node_alias = node.args.get("alias")
+ if node_alias:
+ return exp.to_identifier(node_alias.this)
+
+ return exp.to_identifier(table_name)
+
+ @property
+ def all_columns(self) -> t.Set[str]:
+ """All available columns of all sources in this scope"""
+ if self._all_columns is None:
+ self._all_columns = {
+ column for columns in self._get_all_source_columns().values() for column in columns
+ }
+ return self._all_columns
+
+ def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
+ if isinstance(expression, exp.Select):
+ return expression.named_selects
+ if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
+ # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
+ return self.get_source_columns_from_set_op(expression.this)
+ if not isinstance(expression, exp.SetOperation):
+ raise OptimizeError(f"Unknown set operation: {expression}")
+
+ set_op = expression
+
+ # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
+ on_column_list = set_op.args.get("on")
+
+ if on_column_list:
+ # The resulting columns are the columns in the ON clause:
+ # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
+ columns = [col.name for col in on_column_list]
+ elif set_op.side or set_op.kind:
+ side = set_op.side
+ kind = set_op.kind
+
+ # Visit the children UNIONs (if any) in a post-order traversal
+ left = self.get_source_columns_from_set_op(set_op.left)
+ right = self.get_source_columns_from_set_op(set_op.right)
+
+ # We use dict.fromkeys to deduplicate keys and maintain insertion order
+ if side == "LEFT":
+ columns = left
+ elif side == "FULL":
+ columns = list(dict.fromkeys(left + right))
+ elif kind == "INNER":
+ columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
+ else:
+ columns = set_op.named_selects
+
+ return columns
+
+ def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
+ """Resolve the source columns for a given source `name`."""
+ cache_key = (name, only_visible)
+ if cache_key not in self._get_source_columns_cache:
+ if name not in self.scope.sources:
+ raise OptimizeError(f"Unknown table: {name}")
+
+ source = self.scope.sources[name]
+
+ if isinstance(source, exp.Table):
+ columns = self.schema.column_names(source, only_visible)
+ elif isinstance(source, Scope) and isinstance(
+ source.expression, (exp.Values, exp.Unnest)
+ ):
+ columns = source.expression.named_selects
+
+ # in bigquery, unnest structs are automatically scoped as tables, so you can
+ # directly select a struct field in a query.
+ # this handles the case where the unnest is statically defined.
+ if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
+ unnest = source.expression
+
+ # if type is not annotated yet, try to get it from the schema
+ if not unnest.type or unnest.type.is_type(exp.DataType.Type.UNKNOWN):
+ unnest_expr = seq_get(unnest.expressions, 0)
+ if isinstance(unnest_expr, exp.Column) and self.scope.parent:
+ col_type = self._get_unnest_column_type(unnest_expr)
+ # extract element type if it's an ARRAY
+ if col_type and col_type.is_type(exp.DataType.Type.ARRAY):
+ element_types = col_type.expressions
+ if element_types:
+ unnest.type = element_types[0].copy()
+ else:
+ if col_type:
+ unnest.type = col_type.copy()
+ # check if the result type is a STRUCT - extract struct field names
+ if unnest.is_type(exp.DataType.Type.STRUCT):
+ for k in unnest.type.expressions: # type: ignore
+ columns.append(k.name)
+ elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
+ columns = self.get_source_columns_from_set_op(source.expression)
+
+ else:
+ select = seq_get(source.expression.selects, 0)
+
+ if isinstance(select, exp.QueryTransform):
+ # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
+ schema = select.args.get("schema")
+ columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
+ else:
+ columns = source.expression.named_selects
+
+ node, _ = self.scope.selected_sources.get(name) or (None, None)
+ if isinstance(node, Scope):
+ column_aliases = node.expression.alias_column_names
+ elif isinstance(node, exp.Expression):
+ column_aliases = node.alias_column_names
+ else:
+ column_aliases = []
+
+ if column_aliases:
+ # If the source's columns are aliased, their aliases shadow the corresponding column names.
+ # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
+ columns = [
+ alias or name
+ for (name, alias) in itertools.zip_longest(columns, column_aliases)
+ ]
+
+ self._get_source_columns_cache[cache_key] = columns
+
+ return self._get_source_columns_cache[cache_key]
+
+ def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
+ if self._source_columns is None:
+ self._source_columns = {
+ source_name: self.get_source_columns(source_name)
+ for source_name, source in itertools.chain(
+ self.scope.selected_sources.items(), self.scope.lateral_sources.items()
+ )
+ }
+ return self._source_columns
+
+ def _get_table_name_from_sources(
+ self, column_name: str, source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
+ ) -> t.Optional[str]:
+ if not source_columns:
+ # If not supplied, get all sources to calculate unambiguous columns
+ if self._unambiguous_columns is None:
+ self._unambiguous_columns = self._get_unambiguous_columns(
+ self._get_all_source_columns()
+ )
+
+ unambiguous_columns = self._unambiguous_columns
+ else:
+ unambiguous_columns = self._get_unambiguous_columns(source_columns)
+
+ return unambiguous_columns.get(column_name)
+
+ def _get_column_join_context(self, column: exp.Column) -> t.Optional[exp.Join]:
+ """
+ Check if a column participating in a join can be qualified based on the source order.
+ """
+ args = self.scope.expression.args
+ joins = args.get("joins")
+
+ if not joins or args.get("laterals") or args.get("pivots"):
+ # Feature gap: We currently don't try to disambiguate columns if other sources
+ # (e.g laterals, pivots) exist alongside joins
+ return None
+
+ join_ancestor = column.find_ancestor(exp.Join, exp.Select)
+
+ if (
+ isinstance(join_ancestor, exp.Join)
+ and join_ancestor.alias_or_name in self.scope.selected_sources
+ ):
+ # Ensure that the found ancestor is a join that contains an actual source,
+ # e.g in Clickhouse `b` is an array expression in `a ARRAY JOIN b`
+ return join_ancestor
+
+ return None
+
+ def _get_available_source_columns(
+ self, join_ancestor: exp.Join
+ ) -> t.Dict[str, t.Sequence[str]]:
+ """
+ Get the source columns that are available at the point where a column is referenced.
+
+ For columns in JOIN conditions, this only includes tables that have been joined
+ up to that point. Example:
+
+ ```
+ SELECT * FROM t_1 INNER JOIN ... INNER JOIN t_n ON t_1.a = c INNER JOIN t_n+1 ON ...
+ ``` ^
+ |
+ +----------------------------------+
+ |
+
+ The unqualified column `c` is not ambiguous if no other sources up until that
+ join i.e t_1, ..., t_n, contain a column named `c`.
+
+ """
+ args = self.scope.expression.args
+
+ # Collect tables in order: FROM clause tables + joined tables up to current join
+ from_name = args["from_"].alias_or_name
+ available_sources = {from_name: self.get_source_columns(from_name)}
+
+ for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]:
+ available_sources[join.alias_or_name] = self.get_source_columns(join.alias_or_name)
+
+ return available_sources
+
+ def _get_unambiguous_columns(
+ self, source_columns: t.Dict[str, t.Sequence[str]]
+ ) -> t.Mapping[str, str]:
+ """
+ Find all the unambiguous columns in sources.
+
+ Args:
+ source_columns: Mapping of names to source columns.
+
+ Returns:
+ Mapping of column name to source name.
+ """
+ if not source_columns:
+ return {}
+
+ source_columns_pairs = list(source_columns.items())
+
+ first_table, first_columns = source_columns_pairs[0]
+
+ if len(source_columns_pairs) == 1:
+ # Performance optimization - avoid copying first_columns if there is only one table.
+ return SingleValuedMapping(first_columns, first_table)
+
+ unambiguous_columns = {col: first_table for col in first_columns}
+ all_columns = set(unambiguous_columns)
+
+ for table, columns in source_columns_pairs[1:]:
+ unique = set(columns)
+ ambiguous = all_columns.intersection(unique)
+ all_columns.update(columns)
+
+ for column in ambiguous:
+ unambiguous_columns.pop(column, None)
+ for column in unique.difference(ambiguous):
+ unambiguous_columns[column] = table
+
+ return unambiguous_columns
+
+ def _get_unnest_column_type(self, column: exp.Column) -> t.Optional[exp.DataType]:
+ """
+ Get the type of a column being unnested, tracing through CTEs/subqueries to find the base table.
+
+ Args:
+ column: The column expression being unnested.
+
+ Returns:
+ The DataType of the column, or None if not found.
+ """
+ scope = self.scope.parent
+
+ # if column is qualified, use that table, otherwise disambiguate using the resolver
+ if column.table:
+ table_name = column.table
+ else:
+ # use the parent scope's resolver to disambiguate the column
+ parent_resolver = Resolver(scope, self.schema, self._infer_schema)
+ table_identifier = parent_resolver.get_table(column)
+ if not table_identifier:
+ return None
+ table_name = table_identifier.name
+
+ source = scope.sources.get(table_name)
+ return self._get_column_type_from_scope(source, column) if source else None
+
+ def _get_column_type_from_scope(
+ self, source: t.Union[Scope, exp.Table], column: exp.Column
+ ) -> t.Optional[exp.DataType]:
+ """
+ Get a column's type by tracing through scopes/tables to find the base table.
+
+ Args:
+ source: The source to search - can be a Scope (to iterate its sources) or a Table.
+ column: The column to find the type for.
+
+ Returns:
+ The DataType of the column, or None if not found.
+ """
+ if isinstance(source, exp.Table):
+ # base table - get the column type from schema
+ col_type: t.Optional[exp.DataType] = self.schema.get_column_type(source, column)
+ if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
+ return col_type
+ elif isinstance(source, Scope):
+ # iterate over all sources in the scope
+ for source_name, nested_source in source.sources.items():
+ col_type = self._get_column_type_from_scope(nested_source, column)
+ if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
+ return col_type
+
+ return None
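
The new `sqlglot/optimizer/resolver.py` module exposes the column-resolution helper as a standalone class. A minimal sketch of driving it directly, using the constructor and methods defined above; `MappingSchema` and `build_scope` are existing sqlglot helpers, and the schema and query here are invented for illustration:

```python
import sqlglot
from sqlglot.optimizer.resolver import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import MappingSchema

schema = MappingSchema(
    {
        "orders": {"id": "int", "total": "double"},
        "users": {"id": "int", "name": "text"},
    }
)

expression = sqlglot.parse_one(
    "SELECT name, total FROM orders JOIN users ON orders.id = users.id"
)
scope = build_scope(expression)

resolver = Resolver(scope, schema)

table = resolver.get_table("total")
print(table.name if table else None)         # expected: orders (only source with `total`)
print(resolver.get_source_columns("users"))  # expected: ['id', 'name']
print(sorted(resolver.all_columns))          # expected: ['id', 'name', 'total']
```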