sqlglot 27.27.0__py3-none-any.whl → 28.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlglot/__init__.py +1 -0
- sqlglot/__main__.py +6 -4
- sqlglot/_version.py +2 -2
- sqlglot/dialects/bigquery.py +118 -279
- sqlglot/dialects/clickhouse.py +73 -5
- sqlglot/dialects/databricks.py +38 -1
- sqlglot/dialects/dialect.py +354 -275
- sqlglot/dialects/dremio.py +4 -1
- sqlglot/dialects/duckdb.py +754 -25
- sqlglot/dialects/exasol.py +243 -10
- sqlglot/dialects/hive.py +8 -8
- sqlglot/dialects/mysql.py +14 -4
- sqlglot/dialects/oracle.py +29 -0
- sqlglot/dialects/postgres.py +60 -26
- sqlglot/dialects/presto.py +47 -16
- sqlglot/dialects/redshift.py +16 -0
- sqlglot/dialects/risingwave.py +3 -0
- sqlglot/dialects/singlestore.py +12 -3
- sqlglot/dialects/snowflake.py +239 -218
- sqlglot/dialects/spark.py +15 -4
- sqlglot/dialects/spark2.py +11 -48
- sqlglot/dialects/sqlite.py +10 -0
- sqlglot/dialects/starrocks.py +3 -0
- sqlglot/dialects/teradata.py +5 -8
- sqlglot/dialects/trino.py +6 -0
- sqlglot/dialects/tsql.py +61 -22
- sqlglot/diff.py +4 -2
- sqlglot/errors.py +69 -0
- sqlglot/executor/__init__.py +5 -10
- sqlglot/executor/python.py +1 -29
- sqlglot/expressions.py +637 -100
- sqlglot/generator.py +160 -43
- sqlglot/helper.py +2 -44
- sqlglot/lineage.py +10 -4
- sqlglot/optimizer/annotate_types.py +247 -140
- sqlglot/optimizer/canonicalize.py +6 -1
- sqlglot/optimizer/eliminate_joins.py +1 -1
- sqlglot/optimizer/eliminate_subqueries.py +2 -2
- sqlglot/optimizer/merge_subqueries.py +5 -5
- sqlglot/optimizer/normalize.py +20 -13
- sqlglot/optimizer/normalize_identifiers.py +17 -3
- sqlglot/optimizer/optimizer.py +4 -0
- sqlglot/optimizer/pushdown_predicates.py +1 -1
- sqlglot/optimizer/qualify.py +18 -10
- sqlglot/optimizer/qualify_columns.py +122 -275
- sqlglot/optimizer/qualify_tables.py +128 -76
- sqlglot/optimizer/resolver.py +374 -0
- sqlglot/optimizer/scope.py +27 -16
- sqlglot/optimizer/simplify.py +1075 -959
- sqlglot/optimizer/unnest_subqueries.py +12 -2
- sqlglot/parser.py +296 -170
- sqlglot/planner.py +2 -2
- sqlglot/schema.py +15 -4
- sqlglot/tokens.py +42 -7
- sqlglot/transforms.py +77 -22
- sqlglot/typing/__init__.py +316 -0
- sqlglot/typing/bigquery.py +376 -0
- sqlglot/typing/hive.py +12 -0
- sqlglot/typing/presto.py +24 -0
- sqlglot/typing/snowflake.py +505 -0
- sqlglot/typing/spark2.py +58 -0
- sqlglot/typing/tsql.py +9 -0
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
- sqlglot-28.4.0.dist-info/RECORD +92 -0
- sqlglot-27.27.0.dist-info/RECORD +0 -84
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
|
@@ -5,6 +5,7 @@ import logging
|
|
|
5
5
|
import typing as t
|
|
6
6
|
|
|
7
7
|
from sqlglot import exp
|
|
8
|
+
from sqlglot.dialects.dialect import Dialect
|
|
8
9
|
from sqlglot.helper import (
|
|
9
10
|
ensure_list,
|
|
10
11
|
is_date_unit,
|
|
@@ -14,7 +15,6 @@ from sqlglot.helper import (
|
|
|
14
15
|
)
|
|
15
16
|
from sqlglot.optimizer.scope import Scope, traverse_scope
|
|
16
17
|
from sqlglot.schema import MappingSchema, Schema, ensure_schema
|
|
17
|
-
from sqlglot.dialects.dialect import Dialect
|
|
18
18
|
|
|
19
19
|
if t.TYPE_CHECKING:
|
|
20
20
|
from sqlglot._typing import B, E
|
|
@@ -25,7 +25,8 @@ if t.TYPE_CHECKING:
|
|
|
25
25
|
BinaryCoercionFunc,
|
|
26
26
|
]
|
|
27
27
|
|
|
28
|
-
from sqlglot.dialects.dialect import DialectType
|
|
28
|
+
from sqlglot.dialects.dialect import DialectType
|
|
29
|
+
from sqlglot.typing import ExpressionMetadataType
|
|
29
30
|
|
|
30
31
|
logger = logging.getLogger("sqlglot")
|
|
31
32
|
|
|
@@ -33,9 +34,10 @@ logger = logging.getLogger("sqlglot")
|
|
|
33
34
|
def annotate_types(
|
|
34
35
|
expression: E,
|
|
35
36
|
schema: t.Optional[t.Dict | Schema] = None,
|
|
36
|
-
|
|
37
|
+
expression_metadata: t.Optional[ExpressionMetadataType] = None,
|
|
37
38
|
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
|
|
38
39
|
dialect: DialectType = None,
|
|
40
|
+
overwrite_types: bool = True,
|
|
39
41
|
) -> E:
|
|
40
42
|
"""
|
|
41
43
|
Infers the types of an expression, annotating its AST accordingly.
|
|
@@ -51,8 +53,9 @@ def annotate_types(
|
|
|
51
53
|
Args:
|
|
52
54
|
expression: Expression to annotate.
|
|
53
55
|
schema: Database schema.
|
|
54
|
-
|
|
56
|
+
expression_metadata: Maps expression type to corresponding annotation function.
|
|
55
57
|
coerces_to: Maps expression type to set of types that it can be coerced into.
|
|
58
|
+
overwrite_types: Re-annotate the existing AST types.
|
|
56
59
|
|
|
57
60
|
Returns:
|
|
58
61
|
The expression annotated with types.
|
|
@@ -60,7 +63,12 @@ def annotate_types(
|
|
|
60
63
|
|
|
61
64
|
schema = ensure_schema(schema, dialect=dialect)
|
|
62
65
|
|
|
63
|
-
return TypeAnnotator(
|
|
66
|
+
return TypeAnnotator(
|
|
67
|
+
schema=schema,
|
|
68
|
+
expression_metadata=expression_metadata,
|
|
69
|
+
coerces_to=coerces_to,
|
|
70
|
+
overwrite_types=overwrite_types,
|
|
71
|
+
).annotate(expression)
|
|
64
72
|
|
|
65
73
|
|
|
66
74
|
def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
|
|
@@ -109,8 +117,10 @@ class _TypeAnnotator(type):
|
|
|
109
117
|
exp.DataType.Type.CHAR,
|
|
110
118
|
)
|
|
111
119
|
numeric_precedence = (
|
|
120
|
+
exp.DataType.Type.DECFLOAT,
|
|
112
121
|
exp.DataType.Type.DOUBLE,
|
|
113
122
|
exp.DataType.Type.FLOAT,
|
|
123
|
+
exp.DataType.Type.BIGDECIMAL,
|
|
114
124
|
exp.DataType.Type.DECIMAL,
|
|
115
125
|
exp.DataType.Type.BIGINT,
|
|
116
126
|
exp.DataType.Type.INT,
|
|
@@ -130,14 +140,6 @@ class _TypeAnnotator(type):
|
|
|
130
140
|
for data_type in type_precedence:
|
|
131
141
|
klass.COERCES_TO[data_type] = coerces_to.copy()
|
|
132
142
|
coerces_to |= {data_type}
|
|
133
|
-
|
|
134
|
-
# NULL can be coerced to any type, so e.g. NULL + 1 will have type INT
|
|
135
|
-
klass.COERCES_TO[exp.DataType.Type.NULL] = {
|
|
136
|
-
*text_precedence,
|
|
137
|
-
*numeric_precedence,
|
|
138
|
-
*timelike_precedence,
|
|
139
|
-
}
|
|
140
|
-
|
|
141
143
|
return klass
|
|
142
144
|
|
|
143
145
|
|
|
@@ -182,15 +184,16 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
182
184
|
def __init__(
|
|
183
185
|
self,
|
|
184
186
|
schema: Schema,
|
|
185
|
-
|
|
187
|
+
expression_metadata: t.Optional[ExpressionMetadataType] = None,
|
|
186
188
|
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
|
|
187
189
|
binary_coercions: t.Optional[BinaryCoercions] = None,
|
|
190
|
+
overwrite_types: bool = True,
|
|
188
191
|
) -> None:
|
|
189
192
|
self.schema = schema
|
|
190
|
-
|
|
191
|
-
self.
|
|
192
|
-
|
|
193
|
-
|
|
193
|
+
dialect = schema.dialect or Dialect()
|
|
194
|
+
self.dialect = dialect
|
|
195
|
+
self.expression_metadata = expression_metadata or dialect.EXPRESSION_METADATA
|
|
196
|
+
self.coerces_to = coerces_to or dialect.COERCES_TO or self.COERCES_TO
|
|
194
197
|
self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
|
|
195
198
|
|
|
196
199
|
# Caches the ids of annotated sub-Expressions, to ensure we only visit them once
|
|
@@ -200,16 +203,24 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
200
203
|
self._null_expressions: t.Dict[int, exp.Expression] = {}
|
|
201
204
|
|
|
202
205
|
# Databricks and Spark ≥v3 actually support NULL (i.e., VOID) as a type
|
|
203
|
-
self._supports_null_type =
|
|
206
|
+
self._supports_null_type = dialect.SUPPORTS_NULL_TYPE
|
|
204
207
|
|
|
205
208
|
# Maps an exp.SetOperation's id (e.g. UNION) to its projection types. This is computed if the
|
|
206
209
|
# exp.SetOperation is the expression of a scope source, as selecting from it multiple times
|
|
207
210
|
# would reprocess the entire subtree to coerce the types of its operands' projections
|
|
208
211
|
self._setop_column_types: t.Dict[int, t.Dict[str, exp.DataType | exp.DataType.Type]] = {}
|
|
209
212
|
|
|
213
|
+
# When set to False, this enables partial annotation by skipping already-annotated nodes
|
|
214
|
+
self._overwrite_types = overwrite_types
|
|
215
|
+
|
|
216
|
+
def clear(self) -> None:
|
|
217
|
+
self._visited.clear()
|
|
218
|
+
self._null_expressions.clear()
|
|
219
|
+
self._setop_column_types.clear()
|
|
220
|
+
|
|
210
221
|
def _set_type(
|
|
211
|
-
self, expression:
|
|
212
|
-
) ->
|
|
222
|
+
self, expression: E, target_type: t.Optional[exp.DataType | exp.DataType.Type]
|
|
223
|
+
) -> E:
|
|
213
224
|
prev_type = expression.type
|
|
214
225
|
expression_id = id(expression)
|
|
215
226
|
|
|
@@ -224,22 +235,42 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
224
235
|
elif prev_type and t.cast(exp.DataType, prev_type).this == exp.DataType.Type.NULL:
|
|
225
236
|
self._null_expressions.pop(expression_id, None)
|
|
226
237
|
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
238
|
+
if (
|
|
239
|
+
isinstance(expression, exp.Column)
|
|
240
|
+
and expression.is_type(exp.DataType.Type.JSON)
|
|
241
|
+
and (dot_parts := expression.meta.get("dot_parts"))
|
|
242
|
+
):
|
|
243
|
+
# JSON dot access is case sensitive across all dialects, so we need to undo the normalization.
|
|
244
|
+
i = iter(dot_parts)
|
|
245
|
+
parent = expression.parent
|
|
246
|
+
while isinstance(parent, exp.Dot):
|
|
247
|
+
parent.expression.set("this", exp.to_identifier(next(i), quoted=True))
|
|
248
|
+
parent = parent.parent
|
|
249
|
+
|
|
250
|
+
expression.meta.pop("dot_parts", None)
|
|
251
|
+
|
|
252
|
+
return expression
|
|
253
|
+
|
|
254
|
+
def annotate(self, expression: E, annotate_scope: bool = True) -> E:
|
|
255
|
+
# This flag is used to avoid costly scope traversals when we only care about annotating
|
|
256
|
+
# non-column expressions (partial type inference), e.g., when simplifying in the optimizer
|
|
257
|
+
if annotate_scope:
|
|
258
|
+
for scope in traverse_scope(expression):
|
|
259
|
+
self.annotate_scope(scope)
|
|
230
260
|
|
|
231
261
|
# This takes care of non-traversable expressions
|
|
232
|
-
|
|
262
|
+
self._annotate_expression(expression)
|
|
233
263
|
|
|
234
|
-
# Replace NULL type with
|
|
264
|
+
# Replace NULL type with the default type of the targeted dialect, since the former is not an actual type;
|
|
235
265
|
# it is mostly used to aid type coercion, e.g. in query set operations.
|
|
236
266
|
for expr in self._null_expressions.values():
|
|
237
|
-
expr.type =
|
|
267
|
+
expr.type = self.dialect.DEFAULT_NULL_TYPE
|
|
238
268
|
|
|
239
269
|
return expression
|
|
240
270
|
|
|
241
271
|
def annotate_scope(self, scope: Scope) -> None:
|
|
242
272
|
selects = {}
|
|
273
|
+
|
|
243
274
|
for name, source in scope.sources.items():
|
|
244
275
|
if not isinstance(source, Scope):
|
|
245
276
|
continue
|
|
@@ -259,66 +290,31 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
259
290
|
if not values:
|
|
260
291
|
continue
|
|
261
292
|
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
293
|
+
alias_column_names = expression.alias_column_names
|
|
294
|
+
|
|
295
|
+
if (
|
|
296
|
+
isinstance(expression, exp.Unnest)
|
|
297
|
+
and not alias_column_names
|
|
298
|
+
and expression.type
|
|
299
|
+
and expression.type.is_type(exp.DataType.Type.STRUCT)
|
|
300
|
+
):
|
|
301
|
+
selects[name] = {
|
|
302
|
+
col_def.name: t.cast(t.Union[exp.DataType, exp.DataType.Type], col_def.kind)
|
|
303
|
+
for col_def in expression.type.expressions
|
|
304
|
+
if isinstance(col_def, exp.ColumnDef) and col_def.kind
|
|
305
|
+
}
|
|
306
|
+
else:
|
|
307
|
+
selects[name] = {
|
|
308
|
+
alias: column.type for alias, column in zip(alias_column_names, values)
|
|
309
|
+
}
|
|
266
310
|
elif isinstance(expression, exp.SetOperation) and len(expression.left.selects) == len(
|
|
267
311
|
expression.right.selects
|
|
268
312
|
):
|
|
269
|
-
selects[name] =
|
|
270
|
-
|
|
271
|
-
if not col_types:
|
|
272
|
-
# Process a chain / sub-tree of set operations
|
|
273
|
-
for set_op in expression.walk(
|
|
274
|
-
prune=lambda n: not isinstance(n, (exp.SetOperation, exp.Subquery))
|
|
275
|
-
):
|
|
276
|
-
if not isinstance(set_op, exp.SetOperation):
|
|
277
|
-
continue
|
|
278
|
-
|
|
279
|
-
if set_op.args.get("by_name"):
|
|
280
|
-
r_type_by_select = {
|
|
281
|
-
s.alias_or_name: s.type for s in set_op.right.selects
|
|
282
|
-
}
|
|
283
|
-
setop_cols = {
|
|
284
|
-
s.alias_or_name: self._maybe_coerce(
|
|
285
|
-
t.cast(exp.DataType, s.type),
|
|
286
|
-
r_type_by_select.get(s.alias_or_name)
|
|
287
|
-
or exp.DataType.Type.UNKNOWN,
|
|
288
|
-
)
|
|
289
|
-
for s in set_op.left.selects
|
|
290
|
-
}
|
|
291
|
-
else:
|
|
292
|
-
setop_cols = {
|
|
293
|
-
ls.alias_or_name: self._maybe_coerce(
|
|
294
|
-
t.cast(exp.DataType, ls.type), t.cast(exp.DataType, rs.type)
|
|
295
|
-
)
|
|
296
|
-
for ls, rs in zip(set_op.left.selects, set_op.right.selects)
|
|
297
|
-
}
|
|
298
|
-
|
|
299
|
-
# Coerce intermediate results with the previously registered types, if they exist
|
|
300
|
-
for col_name, col_type in setop_cols.items():
|
|
301
|
-
col_types[col_name] = self._maybe_coerce(
|
|
302
|
-
col_type, col_types.get(col_name, exp.DataType.Type.NULL)
|
|
303
|
-
)
|
|
313
|
+
selects[name] = self._get_setop_column_types(expression)
|
|
304
314
|
|
|
305
315
|
else:
|
|
306
316
|
selects[name] = {s.alias_or_name: s.type for s in expression.selects}
|
|
307
317
|
|
|
308
|
-
# First annotate the current scope's column references
|
|
309
|
-
for col in scope.columns:
|
|
310
|
-
if not col.table:
|
|
311
|
-
continue
|
|
312
|
-
|
|
313
|
-
source = scope.sources.get(col.table)
|
|
314
|
-
if isinstance(source, exp.Table):
|
|
315
|
-
self._set_type(col, self.schema.get_column_type(source, col))
|
|
316
|
-
elif source:
|
|
317
|
-
if col.table in selects and col.name in selects[col.table]:
|
|
318
|
-
self._set_type(col, selects[col.table][col.name])
|
|
319
|
-
elif isinstance(source.expression, exp.Unnest):
|
|
320
|
-
self._set_type(col, source.expression.type)
|
|
321
|
-
|
|
322
318
|
if isinstance(self.schema, MappingSchema):
|
|
323
319
|
for table_column in scope.table_columns:
|
|
324
320
|
source = scope.sources.get(table_column.name)
|
|
@@ -348,10 +344,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
348
344
|
):
|
|
349
345
|
self._set_type(table_column, source.expression.meta["query_type"])
|
|
350
346
|
|
|
351
|
-
#
|
|
352
|
-
self.
|
|
347
|
+
# Iterate through all the expressions of the current scope in post-order, and annotate
|
|
348
|
+
self._annotate_expression(scope.expression, scope, selects)
|
|
353
349
|
|
|
354
|
-
if self.
|
|
350
|
+
if self.dialect.QUERY_RESULTS_ARE_STRUCTS and isinstance(scope.expression, exp.Query):
|
|
355
351
|
struct_type = exp.DataType(
|
|
356
352
|
this=exp.DataType.Type.STRUCT,
|
|
357
353
|
expressions=[
|
|
@@ -374,23 +370,57 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
374
370
|
# annotations, i.e., it shouldn't be interpreted as a STRUCT value.
|
|
375
371
|
scope.expression.meta["query_type"] = struct_type
|
|
376
372
|
|
|
377
|
-
def
|
|
378
|
-
|
|
379
|
-
|
|
373
|
+
def _annotate_expression(
|
|
374
|
+
self,
|
|
375
|
+
expression: exp.Expression,
|
|
376
|
+
scope: t.Optional[Scope] = None,
|
|
377
|
+
selects: t.Optional[t.Dict[str, t.Dict[str, t.Any]]] = None,
|
|
378
|
+
) -> None:
|
|
379
|
+
stack = [(expression, False)]
|
|
380
|
+
selects = selects or {}
|
|
380
381
|
|
|
381
|
-
|
|
382
|
+
while stack:
|
|
383
|
+
expr, children_annotated = stack.pop()
|
|
382
384
|
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
385
|
+
if id(expr) in self._visited or (
|
|
386
|
+
not self._overwrite_types
|
|
387
|
+
and expr.type
|
|
388
|
+
and not expr.is_type(exp.DataType.Type.UNKNOWN)
|
|
389
|
+
):
|
|
390
|
+
continue # We've already inferred the expression's type
|
|
388
391
|
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
+
if not children_annotated:
|
|
393
|
+
stack.append((expr, True))
|
|
394
|
+
for child_expr in expr.iter_expressions():
|
|
395
|
+
stack.append((child_expr, False))
|
|
396
|
+
continue
|
|
392
397
|
|
|
393
|
-
|
|
398
|
+
if scope and isinstance(expr, exp.Column) and expr.table:
|
|
399
|
+
source = scope.sources.get(expr.table)
|
|
400
|
+
if isinstance(source, exp.Table):
|
|
401
|
+
self._set_type(expr, self.schema.get_column_type(source, expr))
|
|
402
|
+
elif source:
|
|
403
|
+
if expr.table in selects and expr.name in selects[expr.table]:
|
|
404
|
+
self._set_type(expr, selects[expr.table][expr.name])
|
|
405
|
+
elif isinstance(source.expression, exp.Unnest):
|
|
406
|
+
self._set_type(expr, source.expression.type)
|
|
407
|
+
else:
|
|
408
|
+
self._set_type(expr, exp.DataType.Type.UNKNOWN)
|
|
409
|
+
else:
|
|
410
|
+
self._set_type(expr, exp.DataType.Type.UNKNOWN)
|
|
411
|
+
|
|
412
|
+
if expr.type and expr.type.args.get("nullable") is False:
|
|
413
|
+
expr.meta["nonnull"] = True
|
|
414
|
+
continue
|
|
415
|
+
|
|
416
|
+
spec = self.expression_metadata.get(expr.__class__)
|
|
417
|
+
|
|
418
|
+
if spec and (annotator := spec.get("annotator")):
|
|
419
|
+
annotator(self, expr)
|
|
420
|
+
elif spec and (returns := spec.get("returns")):
|
|
421
|
+
self._set_type(expr, t.cast(exp.DataType.Type, returns))
|
|
422
|
+
else:
|
|
423
|
+
self._set_type(expr, exp.DataType.Type.UNKNOWN)
|
|
394
424
|
|
|
395
425
|
def _maybe_coerce(
|
|
396
426
|
self,
|
|
@@ -421,14 +451,80 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
421
451
|
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
|
|
422
452
|
return exp.DataType.Type.UNKNOWN
|
|
423
453
|
|
|
454
|
+
if type1_value == exp.DataType.Type.NULL:
|
|
455
|
+
return type2_value
|
|
456
|
+
if type2_value == exp.DataType.Type.NULL:
|
|
457
|
+
return type1_value
|
|
458
|
+
|
|
424
459
|
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
|
|
425
460
|
|
|
426
|
-
def
|
|
427
|
-
self.
|
|
461
|
+
def _get_setop_column_types(
|
|
462
|
+
self, setop: exp.SetOperation
|
|
463
|
+
) -> t.Dict[str, exp.DataType | exp.DataType.Type]:
|
|
464
|
+
"""
|
|
465
|
+
Computes and returns the coerced column types for a SetOperation.
|
|
466
|
+
|
|
467
|
+
This handles UNION, INTERSECT, EXCEPT, etc., coercing types across
|
|
468
|
+
left and right operands for all projections/columns.
|
|
428
469
|
|
|
470
|
+
Args:
|
|
471
|
+
setop: The SetOperation expression to analyze
|
|
472
|
+
|
|
473
|
+
Returns:
|
|
474
|
+
Dictionary mapping column names to their coerced types
|
|
475
|
+
"""
|
|
476
|
+
setop_id = id(setop)
|
|
477
|
+
if setop_id in self._setop_column_types:
|
|
478
|
+
return self._setop_column_types[setop_id]
|
|
479
|
+
|
|
480
|
+
col_types: t.Dict[str, exp.DataType | exp.DataType.Type] = {}
|
|
481
|
+
|
|
482
|
+
# Validate that left and right have same number of projections
|
|
483
|
+
if not (
|
|
484
|
+
isinstance(setop, exp.SetOperation)
|
|
485
|
+
and setop.left.selects
|
|
486
|
+
and setop.right.selects
|
|
487
|
+
and len(setop.left.selects) == len(setop.right.selects)
|
|
488
|
+
):
|
|
489
|
+
return col_types
|
|
490
|
+
|
|
491
|
+
# Process a chain / sub-tree of set operations
|
|
492
|
+
for set_op in setop.walk(
|
|
493
|
+
prune=lambda n: not isinstance(n, (exp.SetOperation, exp.Subquery))
|
|
494
|
+
):
|
|
495
|
+
if not isinstance(set_op, exp.SetOperation):
|
|
496
|
+
continue
|
|
497
|
+
|
|
498
|
+
if set_op.args.get("by_name"):
|
|
499
|
+
r_type_by_select = {s.alias_or_name: s.type for s in set_op.right.selects}
|
|
500
|
+
setop_cols = {
|
|
501
|
+
s.alias_or_name: self._maybe_coerce(
|
|
502
|
+
t.cast(exp.DataType, s.type),
|
|
503
|
+
r_type_by_select.get(s.alias_or_name) or exp.DataType.Type.UNKNOWN,
|
|
504
|
+
)
|
|
505
|
+
for s in set_op.left.selects
|
|
506
|
+
}
|
|
507
|
+
else:
|
|
508
|
+
setop_cols = {
|
|
509
|
+
ls.alias_or_name: self._maybe_coerce(
|
|
510
|
+
t.cast(exp.DataType, ls.type), t.cast(exp.DataType, rs.type)
|
|
511
|
+
)
|
|
512
|
+
for ls, rs in zip(set_op.left.selects, set_op.right.selects)
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
# Coerce intermediate results with the previously registered types, if they exist
|
|
516
|
+
for col_name, col_type in setop_cols.items():
|
|
517
|
+
col_types[col_name] = self._maybe_coerce(
|
|
518
|
+
col_type, col_types.get(col_name, exp.DataType.Type.NULL)
|
|
519
|
+
)
|
|
520
|
+
|
|
521
|
+
self._setop_column_types[setop_id] = col_types
|
|
522
|
+
return col_types
|
|
523
|
+
|
|
524
|
+
def _annotate_binary(self, expression: B) -> B:
|
|
429
525
|
left, right = expression.left, expression.right
|
|
430
526
|
if not left or not right:
|
|
431
|
-
expression_sql = expression.sql(self.
|
|
527
|
+
expression_sql = expression.sql(self.dialect)
|
|
432
528
|
logger.warning(f"Failed to annotate badly formed binary expression: {expression_sql}")
|
|
433
529
|
self._set_type(expression, None)
|
|
434
530
|
return expression
|
|
@@ -442,16 +538,22 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
442
538
|
else:
|
|
443
539
|
self._set_type(expression, self._maybe_coerce(left_type, right_type))
|
|
444
540
|
|
|
541
|
+
if isinstance(expression, exp.Is) or (
|
|
542
|
+
left.meta.get("nonnull") is True and right.meta.get("nonnull") is True
|
|
543
|
+
):
|
|
544
|
+
expression.meta["nonnull"] = True
|
|
545
|
+
|
|
445
546
|
return expression
|
|
446
547
|
|
|
447
548
|
def _annotate_unary(self, expression: E) -> E:
|
|
448
|
-
self._annotate_args(expression)
|
|
449
|
-
|
|
450
549
|
if isinstance(expression, exp.Not):
|
|
451
550
|
self._set_type(expression, exp.DataType.Type.BOOLEAN)
|
|
452
551
|
else:
|
|
453
552
|
self._set_type(expression, expression.this.type)
|
|
454
553
|
|
|
554
|
+
if expression.this.meta.get("nonnull") is True:
|
|
555
|
+
expression.meta["nonnull"] = True
|
|
556
|
+
|
|
455
557
|
return expression
|
|
456
558
|
|
|
457
559
|
def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
|
|
@@ -462,13 +564,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
462
564
|
else:
|
|
463
565
|
self._set_type(expression, exp.DataType.Type.DOUBLE)
|
|
464
566
|
|
|
465
|
-
|
|
567
|
+
expression.meta["nonnull"] = True
|
|
466
568
|
|
|
467
|
-
|
|
468
|
-
self, expression: E, target_type: exp.DataType | exp.DataType.Type
|
|
469
|
-
) -> E:
|
|
470
|
-
self._set_type(expression, target_type)
|
|
471
|
-
return self._annotate_args(expression)
|
|
569
|
+
return expression
|
|
472
570
|
|
|
473
571
|
@t.no_type_check
|
|
474
572
|
def _annotate_by_args(
|
|
@@ -478,8 +576,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
478
576
|
promote: bool = False,
|
|
479
577
|
array: bool = False,
|
|
480
578
|
) -> E:
|
|
481
|
-
self._annotate_args(expression)
|
|
482
|
-
|
|
483
579
|
expressions: t.List[exp.Expression] = []
|
|
484
580
|
for arg in args:
|
|
485
581
|
arg_expr = expression.args.get(arg)
|
|
@@ -497,7 +593,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
497
593
|
if not expr_type.is_type(exp.DataType.Type.UNKNOWN):
|
|
498
594
|
last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
|
|
499
595
|
|
|
500
|
-
self._set_type(expression, last_datatype
|
|
596
|
+
self._set_type(expression, last_datatype)
|
|
501
597
|
|
|
502
598
|
if promote:
|
|
503
599
|
if expression.type.this in exp.DataType.INTEGER_TYPES:
|
|
@@ -518,8 +614,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
518
614
|
def _annotate_timeunit(
|
|
519
615
|
self, expression: exp.TimeUnit | exp.DateTrunc
|
|
520
616
|
) -> exp.TimeUnit | exp.DateTrunc:
|
|
521
|
-
self._annotate_args(expression)
|
|
522
|
-
|
|
523
617
|
if expression.this.type.this in exp.DataType.TEXT_TYPES:
|
|
524
618
|
datatype = _coerce_date_literal(expression.this, expression.unit)
|
|
525
619
|
elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES:
|
|
@@ -531,8 +625,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
531
625
|
return expression
|
|
532
626
|
|
|
533
627
|
def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
|
|
534
|
-
self._annotate_args(expression)
|
|
535
|
-
|
|
536
628
|
bracket_arg = expression.expressions[0]
|
|
537
629
|
this = expression.this
|
|
538
630
|
|
|
@@ -550,8 +642,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
550
642
|
return expression
|
|
551
643
|
|
|
552
644
|
def _annotate_div(self, expression: exp.Div) -> exp.Div:
|
|
553
|
-
self._annotate_args(expression)
|
|
554
|
-
|
|
555
645
|
left_type, right_type = expression.left.type.this, expression.right.type.this # type: ignore
|
|
556
646
|
|
|
557
647
|
if (
|
|
@@ -570,7 +660,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
570
660
|
return expression
|
|
571
661
|
|
|
572
662
|
def _annotate_dot(self, expression: exp.Dot) -> exp.Dot:
|
|
573
|
-
self._annotate_args(expression)
|
|
574
663
|
self._set_type(expression, None)
|
|
575
664
|
this_type = expression.this.type
|
|
576
665
|
|
|
@@ -583,12 +672,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
583
672
|
return expression
|
|
584
673
|
|
|
585
674
|
def _annotate_explode(self, expression: exp.Explode) -> exp.Explode:
|
|
586
|
-
self._annotate_args(expression)
|
|
587
675
|
self._set_type(expression, seq_get(expression.this.type.expressions, 0))
|
|
588
676
|
return expression
|
|
589
677
|
|
|
590
678
|
def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest:
|
|
591
|
-
self._annotate_args(expression)
|
|
592
679
|
child = seq_get(expression.expressions, 0)
|
|
593
680
|
|
|
594
681
|
if child and child.is_type(exp.DataType.Type.ARRAY):
|
|
@@ -599,32 +686,59 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
599
686
|
self._set_type(expression, expr_type)
|
|
600
687
|
return expression
|
|
601
688
|
|
|
689
|
+
def _annotate_subquery(self, expression: exp.Subquery) -> exp.Subquery:
|
|
690
|
+
# For scalar subqueries (subqueries with a single projection), infer the type
|
|
691
|
+
# from that single projection. This allows type propagation in cases like:
|
|
692
|
+
# SELECT (SELECT 1 AS c) AS c
|
|
693
|
+
query = expression.unnest()
|
|
694
|
+
|
|
695
|
+
if isinstance(query, exp.Query):
|
|
696
|
+
selects = query.selects
|
|
697
|
+
if len(selects) == 1:
|
|
698
|
+
self._set_type(expression, selects[0].type)
|
|
699
|
+
return expression
|
|
700
|
+
|
|
701
|
+
self._set_type(expression, exp.DataType.Type.UNKNOWN)
|
|
702
|
+
return expression
|
|
703
|
+
|
|
602
704
|
def _annotate_struct_value(
|
|
603
705
|
self, expression: exp.Expression
|
|
604
706
|
) -> t.Optional[exp.DataType] | exp.ColumnDef:
|
|
605
707
|
# Case: STRUCT(key AS value)
|
|
708
|
+
this: t.Optional[exp.Expression] = None
|
|
709
|
+
kind = expression.type
|
|
710
|
+
|
|
606
711
|
if alias := expression.args.get("alias"):
|
|
607
|
-
|
|
712
|
+
this = alias.copy()
|
|
713
|
+
elif expression.expression:
|
|
714
|
+
# Case: STRUCT(key = value) or STRUCT(key := value)
|
|
715
|
+
this = expression.this.copy()
|
|
716
|
+
kind = expression.expression.type
|
|
717
|
+
elif isinstance(expression, exp.Column):
|
|
718
|
+
# Case: STRUCT(c)
|
|
719
|
+
this = expression.this.copy()
|
|
608
720
|
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
|
|
721
|
+
if kind and kind.is_type(exp.DataType.Type.UNKNOWN):
|
|
722
|
+
return None
|
|
612
723
|
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
return exp.ColumnDef(this=expression.this.copy(), kind=expression.type)
|
|
724
|
+
if this:
|
|
725
|
+
return exp.ColumnDef(this=this, kind=kind)
|
|
616
726
|
|
|
617
|
-
return
|
|
727
|
+
return kind
|
|
618
728
|
|
|
619
729
|
def _annotate_struct(self, expression: exp.Struct) -> exp.Struct:
|
|
620
|
-
|
|
730
|
+
expressions = []
|
|
731
|
+
for expr in expression.expressions:
|
|
732
|
+
struct_field_type = self._annotate_struct_value(expr)
|
|
733
|
+
if struct_field_type is None:
|
|
734
|
+
self._set_type(expression, None)
|
|
735
|
+
return expression
|
|
736
|
+
|
|
737
|
+
expressions.append(struct_field_type)
|
|
738
|
+
|
|
621
739
|
self._set_type(
|
|
622
740
|
expression,
|
|
623
|
-
exp.DataType(
|
|
624
|
-
this=exp.DataType.Type.STRUCT,
|
|
625
|
-
expressions=[self._annotate_struct_value(expr) for expr in expression.expressions],
|
|
626
|
-
nested=True,
|
|
627
|
-
),
|
|
741
|
+
exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True),
|
|
628
742
|
)
|
|
629
743
|
return expression
|
|
630
744
|
|
|
@@ -635,8 +749,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
635
749
|
def _annotate_map(self, expression: exp.VarMap) -> exp.VarMap: ...
|
|
636
750
|
|
|
637
751
|
def _annotate_map(self, expression):
|
|
638
|
-
self._annotate_args(expression)
|
|
639
|
-
|
|
640
752
|
keys = expression.args.get("keys")
|
|
641
753
|
values = expression.args.get("values")
|
|
642
754
|
|
|
@@ -653,8 +765,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
653
765
|
return expression
|
|
654
766
|
|
|
655
767
|
def _annotate_to_map(self, expression: exp.ToMap) -> exp.ToMap:
|
|
656
|
-
self._annotate_args(expression)
|
|
657
|
-
|
|
658
768
|
map_type = exp.DataType(this=exp.DataType.Type.MAP)
|
|
659
769
|
arg = expression.this
|
|
660
770
|
if arg.is_type(exp.DataType.Type.STRUCT):
|
|
@@ -669,7 +779,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
669
779
|
return expression
|
|
670
780
|
|
|
671
781
|
def _annotate_extract(self, expression: exp.Extract) -> exp.Extract:
|
|
672
|
-
self._annotate_args(expression)
|
|
673
782
|
part = expression.name
|
|
674
783
|
if part == "TIME":
|
|
675
784
|
self._set_type(expression, exp.DataType.Type.TIME)
|
|
@@ -680,8 +789,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|
|
680
789
|
return expression
|
|
681
790
|
|
|
682
791
|
def _annotate_by_array_element(self, expression: exp.Expression) -> exp.Expression:
|
|
683
|
-
self._annotate_args(expression)
|
|
684
|
-
|
|
685
792
|
array_arg = expression.this
|
|
686
793
|
if array_arg.type.is_type(exp.DataType.Type.ARRAY):
|
|
687
794
|
element_type = seq_get(array_arg.type.expressions, 0) or exp.DataType.Type.UNKNOWN
|
|
@@ -35,7 +35,12 @@ def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp
|
|
|
35
35
|
|
|
36
36
|
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
|
|
37
37
|
if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
|
|
38
|
-
node = exp.Concat(
|
|
38
|
+
node = exp.Concat(
|
|
39
|
+
expressions=[node.left, node.right],
|
|
40
|
+
# All known dialects, i.e. Redshift and T-SQL, that support
|
|
41
|
+
# concatenating strings with the + operator do not coalesce NULLs.
|
|
42
|
+
coalesce=False,
|
|
43
|
+
)
|
|
39
44
|
return node
|
|
40
45
|
|
|
41
46
|
|
|
@@ -110,7 +110,7 @@ def _has_single_output_row(scope):
|
|
|
110
110
|
return isinstance(scope.expression, exp.Select) and (
|
|
111
111
|
all(isinstance(e.unalias(), exp.AggFunc) for e in scope.expression.selects)
|
|
112
112
|
or _is_limit_1(scope)
|
|
113
|
-
or not scope.expression.args.get("
|
|
113
|
+
or not scope.expression.args.get("from_")
|
|
114
114
|
)
|
|
115
115
|
|
|
116
116
|
|