sqlglot 27.27.0__py3-none-any.whl → 28.4.0__py3-none-any.whl
This diff compares the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- sqlglot/__init__.py +1 -0
- sqlglot/__main__.py +6 -4
- sqlglot/_version.py +2 -2
- sqlglot/dialects/bigquery.py +118 -279
- sqlglot/dialects/clickhouse.py +73 -5
- sqlglot/dialects/databricks.py +38 -1
- sqlglot/dialects/dialect.py +354 -275
- sqlglot/dialects/dremio.py +4 -1
- sqlglot/dialects/duckdb.py +754 -25
- sqlglot/dialects/exasol.py +243 -10
- sqlglot/dialects/hive.py +8 -8
- sqlglot/dialects/mysql.py +14 -4
- sqlglot/dialects/oracle.py +29 -0
- sqlglot/dialects/postgres.py +60 -26
- sqlglot/dialects/presto.py +47 -16
- sqlglot/dialects/redshift.py +16 -0
- sqlglot/dialects/risingwave.py +3 -0
- sqlglot/dialects/singlestore.py +12 -3
- sqlglot/dialects/snowflake.py +239 -218
- sqlglot/dialects/spark.py +15 -4
- sqlglot/dialects/spark2.py +11 -48
- sqlglot/dialects/sqlite.py +10 -0
- sqlglot/dialects/starrocks.py +3 -0
- sqlglot/dialects/teradata.py +5 -8
- sqlglot/dialects/trino.py +6 -0
- sqlglot/dialects/tsql.py +61 -22
- sqlglot/diff.py +4 -2
- sqlglot/errors.py +69 -0
- sqlglot/executor/__init__.py +5 -10
- sqlglot/executor/python.py +1 -29
- sqlglot/expressions.py +637 -100
- sqlglot/generator.py +160 -43
- sqlglot/helper.py +2 -44
- sqlglot/lineage.py +10 -4
- sqlglot/optimizer/annotate_types.py +247 -140
- sqlglot/optimizer/canonicalize.py +6 -1
- sqlglot/optimizer/eliminate_joins.py +1 -1
- sqlglot/optimizer/eliminate_subqueries.py +2 -2
- sqlglot/optimizer/merge_subqueries.py +5 -5
- sqlglot/optimizer/normalize.py +20 -13
- sqlglot/optimizer/normalize_identifiers.py +17 -3
- sqlglot/optimizer/optimizer.py +4 -0
- sqlglot/optimizer/pushdown_predicates.py +1 -1
- sqlglot/optimizer/qualify.py +18 -10
- sqlglot/optimizer/qualify_columns.py +122 -275
- sqlglot/optimizer/qualify_tables.py +128 -76
- sqlglot/optimizer/resolver.py +374 -0
- sqlglot/optimizer/scope.py +27 -16
- sqlglot/optimizer/simplify.py +1075 -959
- sqlglot/optimizer/unnest_subqueries.py +12 -2
- sqlglot/parser.py +296 -170
- sqlglot/planner.py +2 -2
- sqlglot/schema.py +15 -4
- sqlglot/tokens.py +42 -7
- sqlglot/transforms.py +77 -22
- sqlglot/typing/__init__.py +316 -0
- sqlglot/typing/bigquery.py +376 -0
- sqlglot/typing/hive.py +12 -0
- sqlglot/typing/presto.py +24 -0
- sqlglot/typing/snowflake.py +505 -0
- sqlglot/typing/spark2.py +58 -0
- sqlglot/typing/tsql.py +9 -0
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
- sqlglot-28.4.0.dist-info/RECORD +92 -0
- sqlglot-27.27.0.dist-info/RECORD +0 -84
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
Of the files listed above, two optimizer modules are shown in full below: the rewritten `sqlglot/optimizer/qualify_tables.py` and the new `sqlglot/optimizer/resolver.py`.

sqlglot/optimizer/qualify_tables.py

```diff
@@ -1,14 +1,12 @@
 from __future__ import annotations
 
-import itertools
 import typing as t
 
-from sqlglot import alias, exp
-from sqlglot.dialects.dialect import DialectType
-from sqlglot.helper import csv_reader, name_sequence, seq_get
+from sqlglot import exp
+from sqlglot.dialects.dialect import Dialect, DialectType
+from sqlglot.helper import name_sequence, seq_get, ensure_list
+from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
 from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.schema import Schema
-from sqlglot.dialects.dialect import Dialect
 
 if t.TYPE_CHECKING:
     from sqlglot._typing import E
@@ -18,9 +16,9 @@ def qualify_tables(
     expression: E,
     db: t.Optional[str | exp.Identifier] = None,
     catalog: t.Optional[str | exp.Identifier] = None,
-    schema: t.Optional[Schema] = None,
-    infer_csv_schemas: bool = False,
+    on_qualify: t.Optional[t.Callable[[exp.Table], None]] = None,
     dialect: DialectType = None,
+    canonicalize_table_aliases: bool = False,
 ) -> E:
     """
     Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
```
```diff
@@ -40,17 +38,25 @@ def qualify_tables(
         expression: Expression to qualify
         db: Database name
         catalog: Catalog name
-        schema: A schema to populate
-        infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
+        on_qualify: Callback after a table has been qualified.
         dialect: The dialect to parse catalog and schema into.
+        canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources
+            instead of preserving table names. Defaults to False.
 
     Returns:
         The qualified expression.
     """
-    next_alias_name = name_sequence("_q_")
-    db = exp.parse_identifier(db, dialect=dialect) if db else None
-    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
     dialect = Dialect.get_or_raise(dialect)
+    next_alias_name = name_sequence("_")
+
+    if db := db or None:
+        db = exp.parse_identifier(db, dialect=dialect)
+        db.meta["is_table"] = True
+        db = normalize_identifiers(db, dialect=dialect)
+    if catalog := catalog or None:
+        catalog = exp.parse_identifier(catalog, dialect=dialect)
+        catalog.meta["is_table"] = True
+        catalog = normalize_identifiers(catalog, dialect=dialect)
 
     def _qualify(table: exp.Table) -> None:
         if isinstance(table.this, exp.Identifier):
```
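The prologue change above moves identifier parsing and normalization for `db`/`catalog` up front; call sites are unaffected. A minimal sketch of the qualification behavior (the query is our own example, adapted from sqlglot's docstrings, not part of this diff):

```python
import sqlglot
from sqlglot.optimizer.qualify_tables import qualify_tables

# Tables get fully qualified with the provided catalog/db and aliased.
expression = sqlglot.parse_one("SELECT 1 FROM tbl")
print(qualify_tables(expression, db="db", catalog="cat").sql())
# expected: SELECT 1 FROM cat.db.tbl AS tbl
```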
```diff
@@ -60,93 +66,132 @@ def qualify_tables(
             table.set("catalog", catalog.copy())
 
     if (db or catalog) and not isinstance(expression, exp.Query):
-        with_ = expression.args.get("with") or exp.With()
+        with_ = expression.args.get("with_") or exp.With()
         cte_names = {cte.alias_or_name for cte in with_.expressions}
 
         for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
             if isinstance(node, exp.Table) and node.name not in cte_names:
                 _qualify(node)
 
+    def _set_alias(
+        expression: exp.Expression,
+        canonical_aliases: t.Dict[str, str],
+        target_alias: t.Optional[str] = None,
+        scope: t.Optional[Scope] = None,
+        normalize: bool = False,
+        columns: t.Optional[t.List[t.Union[str, exp.Identifier]]] = None,
+    ) -> None:
+        alias = expression.args.get("alias") or exp.TableAlias()
+
+        if canonicalize_table_aliases:
+            new_alias_name = next_alias_name()
+            canonical_aliases[alias.name or target_alias or ""] = new_alias_name
+        elif not alias.name:
+            new_alias_name = target_alias or next_alias_name()
+            if normalize and target_alias:
+                new_alias_name = normalize_identifiers(new_alias_name, dialect=dialect).name
+        else:
+            return
+
+        alias.set("this", exp.to_identifier(new_alias_name))
+
+        if columns:
+            alias.set("columns", [exp.to_identifier(c) for c in columns])
+
+        expression.set("alias", alias)
+
+        if scope:
+            scope.rename_source(None, new_alias_name)
+
     for scope in traverse_scope(expression):
-        [... 17 removed lines not captured in this extract ...]
+        local_columns = scope.local_columns
+        canonical_aliases: t.Dict[str, str] = {}
+
+        for query in scope.subqueries:
+            subquery = query.parent
+            if isinstance(subquery, exp.Subquery):
+                subquery.unwrap().replace(subquery)
+
+        for derived_table in scope.derived_tables:
+            unnested = derived_table.unnest()
+            if isinstance(unnested, exp.Table):
+                joins = unnested.args.get("joins")
+                unnested.set("joins", None)
+                derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
+                derived_table.this.set("joins", joins)
+
+            _set_alias(derived_table, canonical_aliases, scope=scope)
+            if pivot := seq_get(derived_table.args.get("pivots") or [], 0):
+                _set_alias(pivot, canonical_aliases)
 
         table_aliases = {}
 
         for name, source in scope.sources.items():
             if isinstance(source, exp.Table):
-                [... 11 removed lines not captured in this extract ...]
+                # When the name is empty, it means that we have a non-table source, e.g. a pivoted cte
+                is_real_table_source = bool(name)
+
+                if pivot := seq_get(source.args.get("pivots") or [], 0):
+                    name = source.name
+
+                table_this = source.this
+                table_alias = source.args.get("alias")
+                function_columns: t.List[t.Union[str, exp.Identifier]] = []
+                if isinstance(table_this, exp.Func):
+                    if not table_alias:
+                        function_columns = ensure_list(
+                            dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES.get(type(table_this))
+                        )
+                    elif columns := table_alias.columns:
+                        function_columns = columns
+                    elif type(table_this) in dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES:
+                        function_columns = ensure_list(source.alias_or_name)
+                        source.set("alias", None)
+                        name = None
+
+                _set_alias(
+                    source,
+                    canonical_aliases,
+                    target_alias=name or source.name or None,
+                    normalize=True,
+                    columns=function_columns,
                 )
 
-                [... 5 removed lines not captured in this extract ...]
+                source_fqn = ".".join(p.name for p in source.parts)
+                table_aliases[source_fqn] = source.args["alias"].this.copy()
+
+                if pivot:
+                    target_alias = source.alias if pivot.unpivot else None
+                    _set_alias(pivot, canonical_aliases, target_alias=target_alias, normalize=True)
 
                 # This case corresponds to a pivoted CTE, we don't want to qualify that
                 if isinstance(scope.sources.get(source.alias_or_name), Scope):
                     continue
 
-                _qualify(source)
+                if is_real_table_source:
+                    _qualify(source)
 
-                [... 2 removed lines not captured in this extract ...]
-                    header = next(reader)
-                    columns = next(reader)
-                    schema.add_table(
-                        source,
-                        {k: type(v).__name__ for k, v in zip(header, columns)},
-                        match_depth=False,
-                    )
+                if on_qualify:
+                    on_qualify(source)
             elif isinstance(source, Scope) and source.is_udtf:
-                udtf = source.expression
-                [... 2 removed lines not captured in this extract ...]
-                )
-                udtf.set("alias", table_alias)
+                _set_alias(udtf := source.expression, canonical_aliases)
+
+                table_alias = udtf.args["alias"]
 
-                if not table_alias.name:
-                    table_alias.set("this", exp.to_identifier(next_alias_name()))
                 if isinstance(udtf, exp.Values) and not table_alias.columns:
-                    column_aliases = dialect.generate_values_aliases(udtf)
+                    column_aliases = [
+                        normalize_identifiers(i, dialect=dialect)
+                        for i in dialect.generate_values_aliases(udtf)
+                    ]
                     table_alias.set("columns", column_aliases)
-                [... 8 removed lines not captured in this extract ...]
-                alias(node, node.name, copy=False, table=True)
-                [... 1 removed line not captured in this extract ...]
-        for column in scope.columns:
+
+        for table in scope.tables:
+            if not table.alias and isinstance(table.parent, (exp.From, exp.Join)):
+                _set_alias(table, canonical_aliases, target_alias=table.name)
+
+        for column in local_columns:
+            table = column.table
+
             if column.db:
                 table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))
```
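The removed branch above was the READ_CSV schema-inference path; in its place, the new `on_qualify` callback fires for each real table source once `_qualify` has run. A minimal sketch of the hook (our own example, relying only on the `t.Callable[[exp.Table], None]` signature shown in the hunk):

```python
import sqlglot
from sqlglot.optimizer.qualify_tables import qualify_tables

seen = []
qualify_tables(
    sqlglot.parse_one("SELECT * FROM t JOIN u ON t.a = u.b"),
    db="db",
    on_qualify=lambda table: seen.append(table.sql()),  # called with each qualified exp.Table
)
print(seen)  # e.g. ['db.t AS t', 'db.u AS u']
```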
```diff
@@ -155,5 +200,12 @@ def qualify_tables(
                     column.set(p, None)
 
                 column.set("table", table_alias.copy())
+            elif (
+                canonical_aliases
+                and table
+                and (canonical_table := canonical_aliases.get(table, "")) != column.table
+            ):
+                # Amend existing aliases, e.g. t.c -> _0.c if t is aliased to _0
+                column.set("table", exp.to_identifier(canonical_table))
 
     return expression
```
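The final hunk is the column-side counterpart of `canonicalize_table_aliases`: sources are renamed through `name_sequence("_")`, and existing column prefixes are amended to the canonical alias. A sketch of the net effect (the printed output is our reading of these hunks, not taken from upstream docs):

```python
import sqlglot
from sqlglot.optimizer.qualify_tables import qualify_tables

sql = "SELECT t.a, u.b FROM t JOIN u ON t.a = u.b"
print(qualify_tables(sqlglot.parse_one(sql), canonicalize_table_aliases=True).sql())
# e.g. SELECT _0.a, _1.b FROM t AS _0 JOIN u AS _1 ON _0.a = _1.b
```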
sqlglot/optimizer/resolver.py (new file, @@ -0,0 +1,374 @@)

```python
from __future__ import annotations

import itertools
import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.scope import Scope

if t.TYPE_CHECKING:
    from sqlglot.schema import Schema


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self.dialect = schema.dialect or Dialect()
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column: str | exp.Column) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column: The column expression (or column name) to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        column_name = column if isinstance(column, str) else column.name

        table_name = self._get_table_name_from_sources(column_name)

        if not table_name and isinstance(column, exp.Column):
            # Fall-back case: If we couldn't find the `table_name` from ALL of the sources,
            # attempt to disambiguate the column based on other characteristics e.g if this column is in a join condition,
            # we may be able to disambiguate based on the source order.
            if join_context := self._get_column_join_context(column):
                # In this case, the return value will be the join that _may_ be able to disambiguate the column
                # and we can use the source columns available at that join to get the table name
                # catch OptimizeError if column is still ambiguous and try to resolve with schema inference below
                try:
                    table_name = self._get_table_name_from_sources(
                        column_name, self._get_available_source_columns(join_context)
                    )
                except OptimizeError:
                    pass

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
        if isinstance(expression, exp.Select):
            return expression.named_selects
        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
            return self.get_source_columns_from_set_op(expression.this)
        if not isinstance(expression, exp.SetOperation):
            raise OptimizeError(f"Unknown set operation: {expression}")

        set_op = expression

        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
        on_column_list = set_op.args.get("on")

        if on_column_list:
            # The resulting columns are the columns in the ON clause:
            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
            columns = [col.name for col in on_column_list]
        elif set_op.side or set_op.kind:
            side = set_op.side
            kind = set_op.kind

            # Visit the children UNIONs (if any) in a post-order traversal
            left = self.get_source_columns_from_set_op(set_op.left)
            right = self.get_source_columns_from_set_op(set_op.right)

            # We use dict.fromkeys to deduplicate keys and maintain insertion order
            if side == "LEFT":
                columns = left
            elif side == "FULL":
                columns = list(dict.fromkeys(left + right))
            elif kind == "INNER":
                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
        else:
            columns = set_op.named_selects

        return columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
                    unnest = source.expression

                    # if type is not annotated yet, try to get it from the schema
                    if not unnest.type or unnest.type.is_type(exp.DataType.Type.UNKNOWN):
                        unnest_expr = seq_get(unnest.expressions, 0)
                        if isinstance(unnest_expr, exp.Column) and self.scope.parent:
                            col_type = self._get_unnest_column_type(unnest_expr)
                            # extract element type if it's an ARRAY
                            if col_type and col_type.is_type(exp.DataType.Type.ARRAY):
                                element_types = col_type.expressions
                                if element_types:
                                    unnest.type = element_types[0].copy()
                            else:
                                if col_type:
                                    unnest.type = col_type.copy()
                    # check if the result type is a STRUCT - extract struct field names
                    if unnest.is_type(exp.DataType.Type.STRUCT):
                        for k in unnest.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                columns = self.get_source_columns_from_set_op(source.expression)

            else:
                select = seq_get(source.expression.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_table_name_from_sources(
        self, column_name: str, source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
    ) -> t.Optional[str]:
        if not source_columns:
            # If not supplied, get all sources to calculate unambiguous columns
            if self._unambiguous_columns is None:
                self._unambiguous_columns = self._get_unambiguous_columns(
                    self._get_all_source_columns()
                )

            unambiguous_columns = self._unambiguous_columns
        else:
            unambiguous_columns = self._get_unambiguous_columns(source_columns)

        return unambiguous_columns.get(column_name)

    def _get_column_join_context(self, column: exp.Column) -> t.Optional[exp.Join]:
        """
        Check if a column participating in a join can be qualified based on the source order.
        """
        args = self.scope.expression.args
        joins = args.get("joins")

        if not joins or args.get("laterals") or args.get("pivots"):
            # Feature gap: We currently don't try to disambiguate columns if other sources
            # (e.g laterals, pivots) exist alongside joins
            return None

        join_ancestor = column.find_ancestor(exp.Join, exp.Select)

        if (
            isinstance(join_ancestor, exp.Join)
            and join_ancestor.alias_or_name in self.scope.selected_sources
        ):
            # Ensure that the found ancestor is a join that contains an actual source,
            # e.g in Clickhouse `b` is an array expression in `a ARRAY JOIN b`
            return join_ancestor

        return None

    def _get_available_source_columns(
        self, join_ancestor: exp.Join
    ) -> t.Dict[str, t.Sequence[str]]:
        """
        Get the source columns that are available at the point where a column is referenced.

        For columns in JOIN conditions, this only includes tables that have been joined
        up to that point. Example:

        ```
        SELECT * FROM t_1 INNER JOIN ... INNER JOIN t_n ON t_1.a = c INNER JOIN t_n+1 ON ...
        ```                                                         ^
                                                                    |
                                  +---------------------------------+
                                  |
                                  ⌄
        The unqualified column `c` is not ambiguous if no other sources up until that
        join i.e t_1, ..., t_n, contain a column named `c`.

        """
        args = self.scope.expression.args

        # Collect tables in order: FROM clause tables + joined tables up to current join
        from_name = args["from_"].alias_or_name
        available_sources = {from_name: self.get_source_columns(from_name)}

        for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]:
            available_sources[join.alias_or_name] = self.get_source_columns(join.alias_or_name)

        return available_sources

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    def _get_unnest_column_type(self, column: exp.Column) -> t.Optional[exp.DataType]:
        """
        Get the type of a column being unnested, tracing through CTEs/subqueries to find the base table.

        Args:
            column: The column expression being unnested.

        Returns:
            The DataType of the column, or None if not found.
        """
        scope = self.scope.parent

        # if column is qualified, use that table, otherwise disambiguate using the resolver
        if column.table:
            table_name = column.table
        else:
            # use the parent scope's resolver to disambiguate the column
            parent_resolver = Resolver(scope, self.schema, self._infer_schema)
            table_identifier = parent_resolver.get_table(column)
            if not table_identifier:
                return None
            table_name = table_identifier.name

        source = scope.sources.get(table_name)
        return self._get_column_type_from_scope(source, column) if source else None

    def _get_column_type_from_scope(
        self, source: t.Union[Scope, exp.Table], column: exp.Column
    ) -> t.Optional[exp.DataType]:
        """
        Get a column's type by tracing through scopes/tables to find the base table.

        Args:
            source: The source to search - can be a Scope (to iterate its sources) or a Table.
            column: The column to find the type for.

        Returns:
            The DataType of the column, or None if not found.
        """
        if isinstance(source, exp.Table):
            # base table - get the column type from schema
            col_type: t.Optional[exp.DataType] = self.schema.get_column_type(source, column)
            if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
                return col_type
        elif isinstance(source, Scope):
            # iterate over all sources in the scope
            for source_name, nested_source in source.sources.items():
                col_type = self._get_column_type_from_scope(nested_source, column)
                if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
                    return col_type

        return None
```
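Judging by the line counts in the file list (`qualify_columns.py` -275, `resolver.py` +374), the `Resolver` appears to have been extracted from `qualify_columns` into its own importable module. A minimal sketch of standalone use (the query and schema are our own; `MappingSchema` and `traverse_scope` are existing sqlglot APIs):

```python
import sqlglot
from sqlglot.optimizer.resolver import Resolver
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.schema import MappingSchema

schema = MappingSchema({"t": {"a": "INT"}, "u": {"b": "INT"}})
scope = traverse_scope(sqlglot.parse_one("SELECT a, b FROM t JOIN u ON t.a = u.b"))[0]

resolver = Resolver(scope, schema)
print(resolver.get_table("a"))           # t  (only `t` has a column `a`)
print(resolver.get_table("b"))           # u
print(resolver.get_source_columns("t"))  # ['a']
```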