sqlglot 27.29.0__py3-none-any.whl → 28.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sqlglot/__main__.py +6 -4
  2. sqlglot/_version.py +2 -2
  3. sqlglot/dialects/bigquery.py +116 -295
  4. sqlglot/dialects/clickhouse.py +67 -2
  5. sqlglot/dialects/databricks.py +38 -1
  6. sqlglot/dialects/dialect.py +327 -286
  7. sqlglot/dialects/dremio.py +4 -1
  8. sqlglot/dialects/duckdb.py +718 -22
  9. sqlglot/dialects/exasol.py +243 -10
  10. sqlglot/dialects/hive.py +8 -8
  11. sqlglot/dialects/mysql.py +11 -2
  12. sqlglot/dialects/oracle.py +29 -0
  13. sqlglot/dialects/postgres.py +46 -24
  14. sqlglot/dialects/presto.py +47 -16
  15. sqlglot/dialects/redshift.py +16 -0
  16. sqlglot/dialects/risingwave.py +3 -0
  17. sqlglot/dialects/singlestore.py +12 -3
  18. sqlglot/dialects/snowflake.py +199 -271
  19. sqlglot/dialects/spark.py +2 -2
  20. sqlglot/dialects/spark2.py +11 -48
  21. sqlglot/dialects/sqlite.py +9 -0
  22. sqlglot/dialects/teradata.py +5 -8
  23. sqlglot/dialects/trino.py +6 -0
  24. sqlglot/dialects/tsql.py +61 -25
  25. sqlglot/diff.py +4 -2
  26. sqlglot/errors.py +69 -0
  27. sqlglot/expressions.py +484 -84
  28. sqlglot/generator.py +143 -41
  29. sqlglot/helper.py +2 -2
  30. sqlglot/optimizer/annotate_types.py +247 -140
  31. sqlglot/optimizer/canonicalize.py +6 -1
  32. sqlglot/optimizer/eliminate_joins.py +1 -1
  33. sqlglot/optimizer/eliminate_subqueries.py +2 -2
  34. sqlglot/optimizer/merge_subqueries.py +5 -5
  35. sqlglot/optimizer/normalize.py +20 -13
  36. sqlglot/optimizer/normalize_identifiers.py +17 -3
  37. sqlglot/optimizer/optimizer.py +4 -0
  38. sqlglot/optimizer/pushdown_predicates.py +1 -1
  39. sqlglot/optimizer/qualify.py +14 -6
  40. sqlglot/optimizer/qualify_columns.py +113 -352
  41. sqlglot/optimizer/qualify_tables.py +112 -70
  42. sqlglot/optimizer/resolver.py +374 -0
  43. sqlglot/optimizer/scope.py +27 -16
  44. sqlglot/optimizer/simplify.py +1074 -964
  45. sqlglot/optimizer/unnest_subqueries.py +12 -2
  46. sqlglot/parser.py +276 -160
  47. sqlglot/planner.py +2 -2
  48. sqlglot/schema.py +15 -4
  49. sqlglot/tokens.py +42 -7
  50. sqlglot/transforms.py +77 -22
  51. sqlglot/typing/__init__.py +316 -0
  52. sqlglot/typing/bigquery.py +376 -0
  53. sqlglot/typing/hive.py +12 -0
  54. sqlglot/typing/presto.py +24 -0
  55. sqlglot/typing/snowflake.py +505 -0
  56. sqlglot/typing/spark2.py +58 -0
  57. sqlglot/typing/tsql.py +9 -0
  58. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
  59. sqlglot-28.4.0.dist-info/RECORD +92 -0
  60. sqlglot-27.29.0.dist-info/RECORD +0 -84
  61. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
  62. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
  63. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
@@ -5,6 +5,7 @@ import logging
5
5
  import typing as t
6
6
 
7
7
  from sqlglot import exp
8
+ from sqlglot.dialects.dialect import Dialect
8
9
  from sqlglot.helper import (
9
10
  ensure_list,
10
11
  is_date_unit,
@@ -14,7 +15,6 @@ from sqlglot.helper import (
14
15
  )
15
16
  from sqlglot.optimizer.scope import Scope, traverse_scope
16
17
  from sqlglot.schema import MappingSchema, Schema, ensure_schema
17
- from sqlglot.dialects.dialect import Dialect
18
18
 
19
19
  if t.TYPE_CHECKING:
20
20
  from sqlglot._typing import B, E
@@ -25,7 +25,8 @@ if t.TYPE_CHECKING:
25
25
  BinaryCoercionFunc,
26
26
  ]
27
27
 
28
- from sqlglot.dialects.dialect import DialectType, AnnotatorsType
28
+ from sqlglot.dialects.dialect import DialectType
29
+ from sqlglot.typing import ExpressionMetadataType
29
30
 
30
31
  logger = logging.getLogger("sqlglot")
31
32
 
@@ -33,9 +34,10 @@ logger = logging.getLogger("sqlglot")
33
34
  def annotate_types(
34
35
  expression: E,
35
36
  schema: t.Optional[t.Dict | Schema] = None,
36
- annotators: t.Optional[AnnotatorsType] = None,
37
+ expression_metadata: t.Optional[ExpressionMetadataType] = None,
37
38
  coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
38
39
  dialect: DialectType = None,
40
+ overwrite_types: bool = True,
39
41
  ) -> E:
40
42
  """
41
43
  Infers the types of an expression, annotating its AST accordingly.
@@ -51,8 +53,9 @@ def annotate_types(
51
53
  Args:
52
54
  expression: Expression to annotate.
53
55
  schema: Database schema.
54
- annotators: Maps expression type to corresponding annotation function.
56
+ expression_metadata: Maps expression type to corresponding annotation function.
55
57
  coerces_to: Maps expression type to set of types that it can be coerced into.
58
+ overwrite_types: Re-annotate the existing AST types.
56
59
 
57
60
  Returns:
58
61
  The expression annotated with types.
@@ -60,7 +63,12 @@ def annotate_types(
60
63
 
61
64
  schema = ensure_schema(schema, dialect=dialect)
62
65
 
63
- return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
66
+ return TypeAnnotator(
67
+ schema=schema,
68
+ expression_metadata=expression_metadata,
69
+ coerces_to=coerces_to,
70
+ overwrite_types=overwrite_types,
71
+ ).annotate(expression)
64
72
 
65
73
 
66
74
  def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
@@ -109,8 +117,10 @@ class _TypeAnnotator(type):
109
117
  exp.DataType.Type.CHAR,
110
118
  )
111
119
  numeric_precedence = (
120
+ exp.DataType.Type.DECFLOAT,
112
121
  exp.DataType.Type.DOUBLE,
113
122
  exp.DataType.Type.FLOAT,
123
+ exp.DataType.Type.BIGDECIMAL,
114
124
  exp.DataType.Type.DECIMAL,
115
125
  exp.DataType.Type.BIGINT,
116
126
  exp.DataType.Type.INT,
@@ -130,14 +140,6 @@ class _TypeAnnotator(type):
130
140
  for data_type in type_precedence:
131
141
  klass.COERCES_TO[data_type] = coerces_to.copy()
132
142
  coerces_to |= {data_type}
133
-
134
- # NULL can be coerced to any type, so e.g. NULL + 1 will have type INT
135
- klass.COERCES_TO[exp.DataType.Type.NULL] = {
136
- *text_precedence,
137
- *numeric_precedence,
138
- *timelike_precedence,
139
- }
140
-
141
143
  return klass
142
144
 
143
145
 
@@ -182,15 +184,16 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
182
184
  def __init__(
183
185
  self,
184
186
  schema: Schema,
185
- annotators: t.Optional[AnnotatorsType] = None,
187
+ expression_metadata: t.Optional[ExpressionMetadataType] = None,
186
188
  coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
187
189
  binary_coercions: t.Optional[BinaryCoercions] = None,
190
+ overwrite_types: bool = True,
188
191
  ) -> None:
189
192
  self.schema = schema
190
- self.annotators = annotators or Dialect.get_or_raise(schema.dialect).ANNOTATORS
191
- self.coerces_to = (
192
- coerces_to or Dialect.get_or_raise(schema.dialect).COERCES_TO or self.COERCES_TO
193
- )
193
+ dialect = schema.dialect or Dialect()
194
+ self.dialect = dialect
195
+ self.expression_metadata = expression_metadata or dialect.EXPRESSION_METADATA
196
+ self.coerces_to = coerces_to or dialect.COERCES_TO or self.COERCES_TO
194
197
  self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
195
198
 
196
199
  # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
@@ -200,16 +203,24 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
200
203
  self._null_expressions: t.Dict[int, exp.Expression] = {}
201
204
 
202
205
  # Databricks and Spark ≥v3 actually support NULL (i.e., VOID) as a type
203
- self._supports_null_type = schema.dialect in ("databricks", "spark")
206
+ self._supports_null_type = dialect.SUPPORTS_NULL_TYPE
204
207
 
205
208
  # Maps an exp.SetOperation's id (e.g. UNION) to its projection types. This is computed if the
206
209
  # exp.SetOperation is the expression of a scope source, as selecting from it multiple times
207
210
  # would reprocess the entire subtree to coerce the types of its operands' projections
208
211
  self._setop_column_types: t.Dict[int, t.Dict[str, exp.DataType | exp.DataType.Type]] = {}
209
212
 
213
+ # When set to False, this enables partial annotation by skipping already-annotated nodes
214
+ self._overwrite_types = overwrite_types
215
+
216
+ def clear(self) -> None:
217
+ self._visited.clear()
218
+ self._null_expressions.clear()
219
+ self._setop_column_types.clear()
220
+
210
221
  def _set_type(
211
- self, expression: exp.Expression, target_type: t.Optional[exp.DataType | exp.DataType.Type]
212
- ) -> None:
222
+ self, expression: E, target_type: t.Optional[exp.DataType | exp.DataType.Type]
223
+ ) -> E:
213
224
  prev_type = expression.type
214
225
  expression_id = id(expression)
215
226
 
@@ -224,22 +235,42 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
224
235
  elif prev_type and t.cast(exp.DataType, prev_type).this == exp.DataType.Type.NULL:
225
236
  self._null_expressions.pop(expression_id, None)
226
237
 
227
- def annotate(self, expression: E) -> E:
228
- for scope in traverse_scope(expression):
229
- self.annotate_scope(scope)
238
+ if (
239
+ isinstance(expression, exp.Column)
240
+ and expression.is_type(exp.DataType.Type.JSON)
241
+ and (dot_parts := expression.meta.get("dot_parts"))
242
+ ):
243
+ # JSON dot access is case sensitive across all dialects, so we need to undo the normalization.
244
+ i = iter(dot_parts)
245
+ parent = expression.parent
246
+ while isinstance(parent, exp.Dot):
247
+ parent.expression.set("this", exp.to_identifier(next(i), quoted=True))
248
+ parent = parent.parent
249
+
250
+ expression.meta.pop("dot_parts", None)
251
+
252
+ return expression
253
+
254
+ def annotate(self, expression: E, annotate_scope: bool = True) -> E:
255
+ # This flag is used to avoid costly scope traversals when we only care about annotating
256
+ # non-column expressions (partial type inference), e.g., when simplifying in the optimizer
257
+ if annotate_scope:
258
+ for scope in traverse_scope(expression):
259
+ self.annotate_scope(scope)
230
260
 
231
261
  # This takes care of non-traversable expressions
232
- expression = self._maybe_annotate(expression)
262
+ self._annotate_expression(expression)
233
263
 
234
- # Replace NULL type with UNKNOWN, since the former is not an actual type;
264
+ # Replace NULL type with the default type of the targeted dialect, since the former is not an actual type;
235
265
  # it is mostly used to aid type coercion, e.g. in query set operations.
236
266
  for expr in self._null_expressions.values():
237
- expr.type = exp.DataType.Type.UNKNOWN
267
+ expr.type = self.dialect.DEFAULT_NULL_TYPE
238
268
 
239
269
  return expression
240
270
 
241
271
  def annotate_scope(self, scope: Scope) -> None:
242
272
  selects = {}
273
+
243
274
  for name, source in scope.sources.items():
244
275
  if not isinstance(source, Scope):
245
276
  continue
@@ -259,66 +290,31 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
259
290
  if not values:
260
291
  continue
261
292
 
262
- selects[name] = {
263
- alias: column.type
264
- for alias, column in zip(expression.alias_column_names, values)
265
- }
293
+ alias_column_names = expression.alias_column_names
294
+
295
+ if (
296
+ isinstance(expression, exp.Unnest)
297
+ and not alias_column_names
298
+ and expression.type
299
+ and expression.type.is_type(exp.DataType.Type.STRUCT)
300
+ ):
301
+ selects[name] = {
302
+ col_def.name: t.cast(t.Union[exp.DataType, exp.DataType.Type], col_def.kind)
303
+ for col_def in expression.type.expressions
304
+ if isinstance(col_def, exp.ColumnDef) and col_def.kind
305
+ }
306
+ else:
307
+ selects[name] = {
308
+ alias: column.type for alias, column in zip(alias_column_names, values)
309
+ }
266
310
  elif isinstance(expression, exp.SetOperation) and len(expression.left.selects) == len(
267
311
  expression.right.selects
268
312
  ):
269
- selects[name] = col_types = self._setop_column_types.setdefault(id(expression), {})
270
-
271
- if not col_types:
272
- # Process a chain / sub-tree of set operations
273
- for set_op in expression.walk(
274
- prune=lambda n: not isinstance(n, (exp.SetOperation, exp.Subquery))
275
- ):
276
- if not isinstance(set_op, exp.SetOperation):
277
- continue
278
-
279
- if set_op.args.get("by_name"):
280
- r_type_by_select = {
281
- s.alias_or_name: s.type for s in set_op.right.selects
282
- }
283
- setop_cols = {
284
- s.alias_or_name: self._maybe_coerce(
285
- t.cast(exp.DataType, s.type),
286
- r_type_by_select.get(s.alias_or_name)
287
- or exp.DataType.Type.UNKNOWN,
288
- )
289
- for s in set_op.left.selects
290
- }
291
- else:
292
- setop_cols = {
293
- ls.alias_or_name: self._maybe_coerce(
294
- t.cast(exp.DataType, ls.type), t.cast(exp.DataType, rs.type)
295
- )
296
- for ls, rs in zip(set_op.left.selects, set_op.right.selects)
297
- }
298
-
299
- # Coerce intermediate results with the previously registered types, if they exist
300
- for col_name, col_type in setop_cols.items():
301
- col_types[col_name] = self._maybe_coerce(
302
- col_type, col_types.get(col_name, exp.DataType.Type.NULL)
303
- )
313
+ selects[name] = self._get_setop_column_types(expression)
304
314
 
305
315
  else:
306
316
  selects[name] = {s.alias_or_name: s.type for s in expression.selects}
307
317
 
308
- # First annotate the current scope's column references
309
- for col in scope.columns:
310
- if not col.table:
311
- continue
312
-
313
- source = scope.sources.get(col.table)
314
- if isinstance(source, exp.Table):
315
- self._set_type(col, self.schema.get_column_type(source, col))
316
- elif source:
317
- if col.table in selects and col.name in selects[col.table]:
318
- self._set_type(col, selects[col.table][col.name])
319
- elif isinstance(source.expression, exp.Unnest):
320
- self._set_type(col, source.expression.type)
321
-
322
318
  if isinstance(self.schema, MappingSchema):
323
319
  for table_column in scope.table_columns:
324
320
  source = scope.sources.get(table_column.name)
@@ -348,10 +344,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
348
344
  ):
349
345
  self._set_type(table_column, source.expression.meta["query_type"])
350
346
 
351
- # Then (possibly) annotate the remaining expressions in the scope
352
- self._maybe_annotate(scope.expression)
347
+ # Iterate through all the expressions of the current scope in post-order, and annotate
348
+ self._annotate_expression(scope.expression, scope, selects)
353
349
 
354
- if self.schema.dialect == "bigquery" and isinstance(scope.expression, exp.Query):
350
+ if self.dialect.QUERY_RESULTS_ARE_STRUCTS and isinstance(scope.expression, exp.Query):
355
351
  struct_type = exp.DataType(
356
352
  this=exp.DataType.Type.STRUCT,
357
353
  expressions=[
@@ -374,23 +370,57 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
374
370
  # annotations, i.e., it shouldn't be interpreted as a STRUCT value.
375
371
  scope.expression.meta["query_type"] = struct_type
376
372
 
377
- def _maybe_annotate(self, expression: E) -> E:
378
- if id(expression) in self._visited:
379
- return expression # We've already inferred the expression's type
373
+ def _annotate_expression(
374
+ self,
375
+ expression: exp.Expression,
376
+ scope: t.Optional[Scope] = None,
377
+ selects: t.Optional[t.Dict[str, t.Dict[str, t.Any]]] = None,
378
+ ) -> None:
379
+ stack = [(expression, False)]
380
+ selects = selects or {}
380
381
 
381
- annotator = self.annotators.get(expression.__class__)
382
+ while stack:
383
+ expr, children_annotated = stack.pop()
382
384
 
383
- return (
384
- annotator(self, expression)
385
- if annotator
386
- else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
387
- )
385
+ if id(expr) in self._visited or (
386
+ not self._overwrite_types
387
+ and expr.type
388
+ and not expr.is_type(exp.DataType.Type.UNKNOWN)
389
+ ):
390
+ continue # We've already inferred the expression's type
388
391
 
389
- def _annotate_args(self, expression: E) -> E:
390
- for value in expression.iter_expressions():
391
- self._maybe_annotate(value)
392
+ if not children_annotated:
393
+ stack.append((expr, True))
394
+ for child_expr in expr.iter_expressions():
395
+ stack.append((child_expr, False))
396
+ continue
392
397
 
393
- return expression
398
+ if scope and isinstance(expr, exp.Column) and expr.table:
399
+ source = scope.sources.get(expr.table)
400
+ if isinstance(source, exp.Table):
401
+ self._set_type(expr, self.schema.get_column_type(source, expr))
402
+ elif source:
403
+ if expr.table in selects and expr.name in selects[expr.table]:
404
+ self._set_type(expr, selects[expr.table][expr.name])
405
+ elif isinstance(source.expression, exp.Unnest):
406
+ self._set_type(expr, source.expression.type)
407
+ else:
408
+ self._set_type(expr, exp.DataType.Type.UNKNOWN)
409
+ else:
410
+ self._set_type(expr, exp.DataType.Type.UNKNOWN)
411
+
412
+ if expr.type and expr.type.args.get("nullable") is False:
413
+ expr.meta["nonnull"] = True
414
+ continue
415
+
416
+ spec = self.expression_metadata.get(expr.__class__)
417
+
418
+ if spec and (annotator := spec.get("annotator")):
419
+ annotator(self, expr)
420
+ elif spec and (returns := spec.get("returns")):
421
+ self._set_type(expr, t.cast(exp.DataType.Type, returns))
422
+ else:
423
+ self._set_type(expr, exp.DataType.Type.UNKNOWN)
394
424
 
395
425
  def _maybe_coerce(
396
426
  self,
@@ -421,14 +451,80 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
421
451
  if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
422
452
  return exp.DataType.Type.UNKNOWN
423
453
 
454
+ if type1_value == exp.DataType.Type.NULL:
455
+ return type2_value
456
+ if type2_value == exp.DataType.Type.NULL:
457
+ return type1_value
458
+
424
459
  return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
425
460
 
426
- def _annotate_binary(self, expression: B) -> B:
427
- self._annotate_args(expression)
461
+ def _get_setop_column_types(
462
+ self, setop: exp.SetOperation
463
+ ) -> t.Dict[str, exp.DataType | exp.DataType.Type]:
464
+ """
465
+ Computes and returns the coerced column types for a SetOperation.
466
+
467
+ This handles UNION, INTERSECT, EXCEPT, etc., coercing types across
468
+ left and right operands for all projections/columns.
428
469
 
470
+ Args:
471
+ setop: The SetOperation expression to analyze
472
+
473
+ Returns:
474
+ Dictionary mapping column names to their coerced types
475
+ """
476
+ setop_id = id(setop)
477
+ if setop_id in self._setop_column_types:
478
+ return self._setop_column_types[setop_id]
479
+
480
+ col_types: t.Dict[str, exp.DataType | exp.DataType.Type] = {}
481
+
482
+ # Validate that left and right have same number of projections
483
+ if not (
484
+ isinstance(setop, exp.SetOperation)
485
+ and setop.left.selects
486
+ and setop.right.selects
487
+ and len(setop.left.selects) == len(setop.right.selects)
488
+ ):
489
+ return col_types
490
+
491
+ # Process a chain / sub-tree of set operations
492
+ for set_op in setop.walk(
493
+ prune=lambda n: not isinstance(n, (exp.SetOperation, exp.Subquery))
494
+ ):
495
+ if not isinstance(set_op, exp.SetOperation):
496
+ continue
497
+
498
+ if set_op.args.get("by_name"):
499
+ r_type_by_select = {s.alias_or_name: s.type for s in set_op.right.selects}
500
+ setop_cols = {
501
+ s.alias_or_name: self._maybe_coerce(
502
+ t.cast(exp.DataType, s.type),
503
+ r_type_by_select.get(s.alias_or_name) or exp.DataType.Type.UNKNOWN,
504
+ )
505
+ for s in set_op.left.selects
506
+ }
507
+ else:
508
+ setop_cols = {
509
+ ls.alias_or_name: self._maybe_coerce(
510
+ t.cast(exp.DataType, ls.type), t.cast(exp.DataType, rs.type)
511
+ )
512
+ for ls, rs in zip(set_op.left.selects, set_op.right.selects)
513
+ }
514
+
515
+ # Coerce intermediate results with the previously registered types, if they exist
516
+ for col_name, col_type in setop_cols.items():
517
+ col_types[col_name] = self._maybe_coerce(
518
+ col_type, col_types.get(col_name, exp.DataType.Type.NULL)
519
+ )
520
+
521
+ self._setop_column_types[setop_id] = col_types
522
+ return col_types
523
+
524
+ def _annotate_binary(self, expression: B) -> B:
429
525
  left, right = expression.left, expression.right
430
526
  if not left or not right:
431
- expression_sql = expression.sql(self.schema.dialect)
527
+ expression_sql = expression.sql(self.dialect)
432
528
  logger.warning(f"Failed to annotate badly formed binary expression: {expression_sql}")
433
529
  self._set_type(expression, None)
434
530
  return expression
@@ -442,16 +538,22 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
442
538
  else:
443
539
  self._set_type(expression, self._maybe_coerce(left_type, right_type))
444
540
 
541
+ if isinstance(expression, exp.Is) or (
542
+ left.meta.get("nonnull") is True and right.meta.get("nonnull") is True
543
+ ):
544
+ expression.meta["nonnull"] = True
545
+
445
546
  return expression
446
547
 
447
548
  def _annotate_unary(self, expression: E) -> E:
448
- self._annotate_args(expression)
449
-
450
549
  if isinstance(expression, exp.Not):
451
550
  self._set_type(expression, exp.DataType.Type.BOOLEAN)
452
551
  else:
453
552
  self._set_type(expression, expression.this.type)
454
553
 
554
+ if expression.this.meta.get("nonnull") is True:
555
+ expression.meta["nonnull"] = True
556
+
455
557
  return expression
456
558
 
457
559
  def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
@@ -462,13 +564,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
462
564
  else:
463
565
  self._set_type(expression, exp.DataType.Type.DOUBLE)
464
566
 
465
- return expression
567
+ expression.meta["nonnull"] = True
466
568
 
467
- def _annotate_with_type(
468
- self, expression: E, target_type: exp.DataType | exp.DataType.Type
469
- ) -> E:
470
- self._set_type(expression, target_type)
471
- return self._annotate_args(expression)
569
+ return expression
472
570
 
473
571
  @t.no_type_check
474
572
  def _annotate_by_args(
@@ -478,8 +576,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
478
576
  promote: bool = False,
479
577
  array: bool = False,
480
578
  ) -> E:
481
- self._annotate_args(expression)
482
-
483
579
  expressions: t.List[exp.Expression] = []
484
580
  for arg in args:
485
581
  arg_expr = expression.args.get(arg)
@@ -497,7 +593,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
497
593
  if not expr_type.is_type(exp.DataType.Type.UNKNOWN):
498
594
  last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
499
595
 
500
- self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
596
+ self._set_type(expression, last_datatype)
501
597
 
502
598
  if promote:
503
599
  if expression.type.this in exp.DataType.INTEGER_TYPES:
@@ -518,8 +614,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
518
614
  def _annotate_timeunit(
519
615
  self, expression: exp.TimeUnit | exp.DateTrunc
520
616
  ) -> exp.TimeUnit | exp.DateTrunc:
521
- self._annotate_args(expression)
522
-
523
617
  if expression.this.type.this in exp.DataType.TEXT_TYPES:
524
618
  datatype = _coerce_date_literal(expression.this, expression.unit)
525
619
  elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES:
@@ -531,8 +625,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
531
625
  return expression
532
626
 
533
627
  def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
534
- self._annotate_args(expression)
535
-
536
628
  bracket_arg = expression.expressions[0]
537
629
  this = expression.this
538
630
 
@@ -550,8 +642,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
550
642
  return expression
551
643
 
552
644
  def _annotate_div(self, expression: exp.Div) -> exp.Div:
553
- self._annotate_args(expression)
554
-
555
645
  left_type, right_type = expression.left.type.this, expression.right.type.this # type: ignore
556
646
 
557
647
  if (
@@ -570,7 +660,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
570
660
  return expression
571
661
 
572
662
  def _annotate_dot(self, expression: exp.Dot) -> exp.Dot:
573
- self._annotate_args(expression)
574
663
  self._set_type(expression, None)
575
664
  this_type = expression.this.type
576
665
 
@@ -583,12 +672,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
583
672
  return expression
584
673
 
585
674
  def _annotate_explode(self, expression: exp.Explode) -> exp.Explode:
586
- self._annotate_args(expression)
587
675
  self._set_type(expression, seq_get(expression.this.type.expressions, 0))
588
676
  return expression
589
677
 
590
678
  def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest:
591
- self._annotate_args(expression)
592
679
  child = seq_get(expression.expressions, 0)
593
680
 
594
681
  if child and child.is_type(exp.DataType.Type.ARRAY):
@@ -599,32 +686,59 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
599
686
  self._set_type(expression, expr_type)
600
687
  return expression
601
688
 
689
+ def _annotate_subquery(self, expression: exp.Subquery) -> exp.Subquery:
690
+ # For scalar subqueries (subqueries with a single projection), infer the type
691
+ # from that single projection. This allows type propagation in cases like:
692
+ # SELECT (SELECT 1 AS c) AS c
693
+ query = expression.unnest()
694
+
695
+ if isinstance(query, exp.Query):
696
+ selects = query.selects
697
+ if len(selects) == 1:
698
+ self._set_type(expression, selects[0].type)
699
+ return expression
700
+
701
+ self._set_type(expression, exp.DataType.Type.UNKNOWN)
702
+ return expression
703
+
602
704
  def _annotate_struct_value(
603
705
  self, expression: exp.Expression
604
706
  ) -> t.Optional[exp.DataType] | exp.ColumnDef:
605
707
  # Case: STRUCT(key AS value)
708
+ this: t.Optional[exp.Expression] = None
709
+ kind = expression.type
710
+
606
711
  if alias := expression.args.get("alias"):
607
- return exp.ColumnDef(this=alias.copy(), kind=expression.type)
712
+ this = alias.copy()
713
+ elif expression.expression:
714
+ # Case: STRUCT(key = value) or STRUCT(key := value)
715
+ this = expression.this.copy()
716
+ kind = expression.expression.type
717
+ elif isinstance(expression, exp.Column):
718
+ # Case: STRUCT(c)
719
+ this = expression.this.copy()
608
720
 
609
- # Case: STRUCT(key = value) or STRUCT(key := value)
610
- if expression.expression:
611
- return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
721
+ if kind and kind.is_type(exp.DataType.Type.UNKNOWN):
722
+ return None
612
723
 
613
- # Case: STRUCT(c)
614
- if isinstance(expression, exp.Column):
615
- return exp.ColumnDef(this=expression.this.copy(), kind=expression.type)
724
+ if this:
725
+ return exp.ColumnDef(this=this, kind=kind)
616
726
 
617
- return expression.type
727
+ return kind
618
728
 
619
729
  def _annotate_struct(self, expression: exp.Struct) -> exp.Struct:
620
- self._annotate_args(expression)
730
+ expressions = []
731
+ for expr in expression.expressions:
732
+ struct_field_type = self._annotate_struct_value(expr)
733
+ if struct_field_type is None:
734
+ self._set_type(expression, None)
735
+ return expression
736
+
737
+ expressions.append(struct_field_type)
738
+
621
739
  self._set_type(
622
740
  expression,
623
- exp.DataType(
624
- this=exp.DataType.Type.STRUCT,
625
- expressions=[self._annotate_struct_value(expr) for expr in expression.expressions],
626
- nested=True,
627
- ),
741
+ exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True),
628
742
  )
629
743
  return expression
630
744
 
@@ -635,8 +749,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
635
749
  def _annotate_map(self, expression: exp.VarMap) -> exp.VarMap: ...
636
750
 
637
751
  def _annotate_map(self, expression):
638
- self._annotate_args(expression)
639
-
640
752
  keys = expression.args.get("keys")
641
753
  values = expression.args.get("values")
642
754
 
@@ -653,8 +765,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
653
765
  return expression
654
766
 
655
767
  def _annotate_to_map(self, expression: exp.ToMap) -> exp.ToMap:
656
- self._annotate_args(expression)
657
-
658
768
  map_type = exp.DataType(this=exp.DataType.Type.MAP)
659
769
  arg = expression.this
660
770
  if arg.is_type(exp.DataType.Type.STRUCT):
@@ -669,7 +779,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
669
779
  return expression
670
780
 
671
781
  def _annotate_extract(self, expression: exp.Extract) -> exp.Extract:
672
- self._annotate_args(expression)
673
782
  part = expression.name
674
783
  if part == "TIME":
675
784
  self._set_type(expression, exp.DataType.Type.TIME)
@@ -680,8 +789,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
680
789
  return expression
681
790
 
682
791
  def _annotate_by_array_element(self, expression: exp.Expression) -> exp.Expression:
683
- self._annotate_args(expression)
684
-
685
792
  array_arg = expression.this
686
793
  if array_arg.type.is_type(exp.DataType.Type.ARRAY):
687
794
  element_type = seq_get(array_arg.type.expressions, 0) or exp.DataType.Type.UNKNOWN
@@ -35,7 +35,12 @@ def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp
35
35
 
36
36
  def add_text_to_concat(node: exp.Expression) -> exp.Expression:
37
37
  if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
38
- node = exp.Concat(expressions=[node.left, node.right])
38
+ node = exp.Concat(
39
+ expressions=[node.left, node.right],
40
+ # All known dialects, i.e. Redshift and T-SQL, that support
41
+ # concatenating strings with the + operator do not coalesce NULLs.
42
+ coalesce=False,
43
+ )
39
44
  return node
40
45
 
41
46
 
@@ -110,7 +110,7 @@ def _has_single_output_row(scope):
110
110
  return isinstance(scope.expression, exp.Select) and (
111
111
  all(isinstance(e.unalias(), exp.AggFunc) for e in scope.expression.selects)
112
112
  or _is_limit_1(scope)
113
- or not scope.expression.args.get("from")
113
+ or not scope.expression.args.get("from_")
114
114
  )
115
115
 
116
116