sqlglot 27.13.2__py3-none-any.whl → 27.15.0__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
sqlglot/dialects/snowflake.py CHANGED
@@ -6,6 +6,7 @@ from sqlglot import exp, generator, jsonpath, parser, tokens, transforms
6
6
  from sqlglot.dialects.dialect import (
7
7
  Dialect,
8
8
  NormalizationStrategy,
9
+ annotate_with_type_lambda,
9
10
  build_timetostr_or_tochar,
10
11
  binary_from_function,
11
12
  build_default_decimal_type,
@@ -32,6 +33,7 @@ from sqlglot.dialects.dialect import (
32
33
  )
33
34
  from sqlglot.generator import unsupported_args
34
35
  from sqlglot.helper import find_new_name, flatten, is_float, is_int, seq_get
36
+ from sqlglot.optimizer.annotate_types import TypeAnnotator
35
37
  from sqlglot.optimizer.scope import build_scope, find_all_in_scope
36
38
  from sqlglot.tokens import TokenType
37
39
 
@@ -376,6 +378,7 @@ def _qualify_unnested_columns(expression: exp.Expression) -> exp.Expression:
376
378
 
377
379
  taken_source_names = set(scope.sources)
378
380
  column_source: t.Dict[str, exp.Identifier] = {}
381
+ unnest_to_identifier: t.Dict[exp.Unnest, exp.Identifier] = {}
379
382
 
380
383
  unnest_identifier: t.Optional[exp.Identifier] = None
381
384
  orig_expression = expression.copy()
@@ -428,6 +431,7 @@ def _qualify_unnested_columns(expression: exp.Expression) -> exp.Expression:
428
431
  if not isinstance(unnest_identifier, exp.Identifier):
429
432
  return orig_expression
430
433
 
434
+ unnest_to_identifier[unnest] = unnest_identifier
431
435
  column_source.update({c.lower(): unnest_identifier for c in unnest_columns})
432
436
 
433
437
  for column in scope.columns:
@@ -441,6 +445,15 @@ def _qualify_unnested_columns(expression: exp.Expression) -> exp.Expression:
441
445
  and len(scope.sources) == 1
442
446
  and column.name.lower() != unnest_identifier.name.lower()
443
447
  ):
448
+ unnest_ancestor = column.find_ancestor(exp.Unnest, exp.Select)
449
+ ancestor_identifier = unnest_to_identifier.get(unnest_ancestor)
450
+ if (
451
+ isinstance(unnest_ancestor, exp.Unnest)
452
+ and ancestor_identifier
453
+ and ancestor_identifier.name.lower() == unnest_identifier.name.lower()
454
+ ):
455
+ continue
456
+
444
457
  table = unnest_identifier
445
458
 
446
459
  column.set("table", table and table.copy())
@@ -482,6 +495,15 @@ def _eliminate_dot_variant_lookup(expression: exp.Expression) -> exp.Expression:
482
495
  return expression
483
496
 
484
497
 
498
+ def _annotate_reverse(self: TypeAnnotator, expression: exp.Reverse) -> exp.Reverse:
499
+ expression = self._annotate_by_args(expression, "this")
500
+ if expression.is_type(exp.DataType.Type.NULL):
501
+ # Snowflake treats REVERSE(NULL) as a VARCHAR
502
+ self._set_type(expression, exp.DataType.Type.VARCHAR)
503
+
504
+ return expression
505
+
506
+
485
507
  class Snowflake(Dialect):
486
508
  # https://docs.snowflake.com/en/sql-reference/identifiers-syntax
487
509
  NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@@ -496,13 +518,59 @@ class Snowflake(Dialect):
496
518
  ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
497
519
  TRY_CAST_REQUIRES_STRING = True
498
520
 
521
+ TYPE_TO_EXPRESSIONS = {
522
+ **Dialect.TYPE_TO_EXPRESSIONS,
523
+ exp.DataType.Type.INT: {
524
+ *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.INT],
525
+ exp.Length,
526
+ },
527
+ exp.DataType.Type.VARCHAR: {
528
+ *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.VARCHAR],
529
+ exp.MD5,
530
+ exp.AIAgg,
531
+ exp.AISummarizeAgg,
532
+ exp.RegexpExtract,
533
+ exp.RegexpReplace,
534
+ exp.Repeat,
535
+ exp.Replace,
536
+ exp.SHA,
537
+ exp.SHA2,
538
+ exp.Space,
539
+ exp.Uuid,
540
+ },
541
+ exp.DataType.Type.BINARY: {
542
+ *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.BINARY],
543
+ exp.MD5Digest,
544
+ exp.SHA1Digest,
545
+ exp.SHA2Digest,
546
+ },
547
+ exp.DataType.Type.BIGINT: {
548
+ *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.BIGINT],
549
+ exp.MD5NumberLower64,
550
+ exp.MD5NumberUpper64,
551
+ },
552
+ exp.DataType.Type.ARRAY: {
553
+ exp.Split,
554
+ },
555
+ }
556
+
499
557
  ANNOTATORS = {
500
558
  **Dialect.ANNOTATORS,
559
+ **{
560
+ expr_type: annotate_with_type_lambda(data_type)
561
+ for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
562
+ for expr_type in expressions
563
+ },
501
564
  **{
502
565
  expr_type: lambda self, e: self._annotate_by_args(e, "this")
503
- for expr_type in (exp.Reverse,)
566
+ for expr_type in (
567
+ exp.Left,
568
+ exp.Right,
569
+ exp.Substring,
570
+ )
504
571
  },
505
572
  exp.ConcatWs: lambda self, e: self._annotate_by_args(e, "expressions"),
573
+ exp.Reverse: _annotate_reverse,
506
574
  }
507
575
 
508
576
  TIME_MAPPING = {
@@ -622,6 +690,10 @@ class Snowflake(Dialect):
622
690
  ),
623
691
  "HEX_DECODE_BINARY": exp.Unhex.from_arg_list,
624
692
  "IFF": exp.If.from_arg_list,
693
+ "MD5_HEX": exp.MD5.from_arg_list,
694
+ "MD5_BINARY": exp.MD5Digest.from_arg_list,
695
+ "MD5_NUMBER_LOWER64": exp.MD5NumberLower64.from_arg_list,
696
+ "MD5_NUMBER_UPPER64": exp.MD5NumberUpper64.from_arg_list,
625
697
  "LAST_DAY": lambda args: exp.LastDay(
626
698
  this=seq_get(args, 0), unit=map_date_part(seq_get(args, 1))
627
699
  ),
@@ -629,12 +701,17 @@ class Snowflake(Dialect):
629
701
  "LENGTH": lambda args: exp.Length(this=seq_get(args, 0), binary=True),
630
702
  "NULLIFZERO": _build_if_from_nullifzero,
631
703
  "OBJECT_CONSTRUCT": _build_object_construct,
704
+ "OCTET_LENGTH": exp.ByteLength.from_arg_list,
632
705
  "REGEXP_EXTRACT_ALL": _build_regexp_extract(exp.RegexpExtractAll),
633
706
  "REGEXP_REPLACE": _build_regexp_replace,
634
707
  "REGEXP_SUBSTR": _build_regexp_extract(exp.RegexpExtract),
635
708
  "REGEXP_SUBSTR_ALL": _build_regexp_extract(exp.RegexpExtractAll),
636
709
  "REPLACE": build_replace_with_optional_replacement,
637
710
  "RLIKE": exp.RegexpLike.from_arg_list,
711
+ "SHA1_BINARY": exp.SHA1Digest.from_arg_list,
712
+ "SHA1_HEX": exp.SHA.from_arg_list,
713
+ "SHA2_BINARY": exp.SHA2Digest.from_arg_list,
714
+ "SHA2_HEX": exp.SHA2.from_arg_list,
638
715
  "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
639
716
  "TABLE": lambda args: exp.TableFromRows(this=seq_get(args, 0)),
640
717
  "TIMEADD": _build_date_time_add(exp.TimeAdd),
@@ -664,7 +741,8 @@ class Snowflake(Dialect):
664
741
  "TO_TIMESTAMP_LTZ": _build_datetime("TO_TIMESTAMP_LTZ", exp.DataType.Type.TIMESTAMPLTZ),
665
742
  "TO_TIMESTAMP_NTZ": _build_datetime("TO_TIMESTAMP_NTZ", exp.DataType.Type.TIMESTAMP),
666
743
  "TO_TIMESTAMP_TZ": _build_datetime("TO_TIMESTAMP_TZ", exp.DataType.Type.TIMESTAMPTZ),
667
- "TO_VARCHAR": exp.ToChar.from_arg_list,
744
+ "TO_VARCHAR": build_timetostr_or_tochar,
745
+ "TO_JSON": exp.JSONFormat.from_arg_list,
668
746
  "VECTOR_L2_DISTANCE": exp.EuclideanDistance.from_arg_list,
669
747
  "ZEROIFNULL": _build_if_from_zeroifnull,
670
748
  }
@@ -1273,6 +1351,7 @@ class Snowflake(Dialect):
1273
1351
  exp.ParseJSON: lambda self, e: self.func(
1274
1352
  "TRY_PARSE_JSON" if e.args.get("safe") else "PARSE_JSON", e.this
1275
1353
  ),
1354
+ exp.JSONFormat: rename_func("TO_JSON"),
1276
1355
  exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
1277
1356
  exp.PercentileCont: transforms.preprocess(
1278
1357
  [transforms.add_within_group_for_percentiles]
@@ -1297,6 +1376,10 @@ class Snowflake(Dialect):
1297
1376
  ]
1298
1377
  ),
1299
1378
  exp.SHA: rename_func("SHA1"),
1379
+ exp.MD5Digest: rename_func("MD5_BINARY"),
1380
+ exp.MD5NumberLower64: rename_func("MD5_NUMBER_LOWER64"),
1381
+ exp.MD5NumberUpper64: rename_func("MD5_NUMBER_UPPER64"),
1382
+ exp.LowerHex: rename_func("TO_CHAR"),
1300
1383
  exp.SortArray: rename_func("ARRAY_SORT"),
1301
1384
  exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
1302
1385
  exp.StartsWith: rename_func("STARTSWITH"),
@@ -1334,6 +1417,7 @@ class Snowflake(Dialect):
1334
1417
  exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
1335
1418
  exp.WeekOfYear: rename_func("WEEKOFYEAR"),
1336
1419
  exp.Xor: rename_func("BOOLXOR"),
1420
+ exp.ByteLength: rename_func("OCTET_LENGTH"),
1337
1421
  }
1338
1422
 
1339
1423
  SUPPORTED_JSON_PATH_PARTS = {
@@ -1344,9 +1428,10 @@ class Snowflake(Dialect):
1344
1428
 
1345
1429
  TYPE_MAPPING = {
1346
1430
  **generator.Generator.TYPE_MAPPING,
1431
+ exp.DataType.Type.BIGDECIMAL: "DOUBLE",
1347
1432
  exp.DataType.Type.NESTED: "OBJECT",
1348
1433
  exp.DataType.Type.STRUCT: "OBJECT",
1349
- exp.DataType.Type.BIGDECIMAL: "DOUBLE",
1434
+ exp.DataType.Type.TEXT: "VARCHAR",
1350
1435
  }
1351
1436
 
1352
1437
  TOKEN_MAPPING = {
sqlglot/dialects/sqlite.py CHANGED
@@ -110,6 +110,7 @@ class SQLite(Dialect):
110
110
  STRING_ALIASES = True
111
111
  ALTER_RENAME_REQUIRES_COLUMN = False
112
112
  JOINS_HAVE_EQUAL_PRECEDENCE = True
113
+ ADD_JOIN_ON_TRUE = True
113
114
 
114
115
  FUNCTIONS = {
115
116
  **parser.Parser.FUNCTIONS,
sqlglot/dialects/tsql.py CHANGED
@@ -650,6 +650,16 @@ class TSQL(Dialect):
650
650
  "NEXT": lambda self: self._parse_next_value_for(),
651
651
  }
652
652
 
653
+ FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
654
+ **parser.Parser.FUNCTION_PARSERS,
655
+ "JSON_ARRAYAGG": lambda self: self.expression(
656
+ exp.JSONArrayAgg,
657
+ this=self._parse_bitwise(),
658
+ order=self._parse_order(),
659
+ null_handling=self._parse_on_handling("NULL", "NULL", "ABSENT"),
660
+ ),
661
+ }
662
+
653
663
  # The DCOLON (::) operator serves as a scope resolution (exp.ScopeResolution) operator in T-SQL
654
664
  COLUMN_OPERATORS = {
655
665
  **parser.Parser.COLUMN_OPERATORS,
sqlglot/expressions.py CHANGED
@@ -134,6 +134,11 @@ class Expression(metaclass=_Expression):
134
134
 
135
135
  return hash((self.__class__, self.hashable_args))
136
136
 
137
+ def __reduce__(self) -> t.Tuple[t.Callable, t.Tuple[t.List[t.Dict[str, t.Any]]]]:
138
+ from sqlglot.serde import dump, load
139
+
140
+ return (load, (dump(self),))
141
+
137
142
  @property
138
143
  def this(self) -> t.Any:
139
144
  """
@@ -259,7 +264,7 @@ class Expression(metaclass=_Expression):
259
264
  return self.type is not None and self.type.is_type(*dtypes)
260
265
 
261
266
  def is_leaf(self) -> bool:
262
- return not any(isinstance(v, (Expression, list)) for v in self.args.values())
267
+ return not any(isinstance(v, (Expression, list)) and v for v in self.args.values())
263
268
 
264
269
  @property
265
270
  def meta(self) -> t.Dict[str, t.Any]:
@@ -1646,6 +1651,12 @@ class Show(Expression):
1646
1651
  "position": False,
1647
1652
  "types": False,
1648
1653
  "privileges": False,
1654
+ "for_table": False,
1655
+ "for_group": False,
1656
+ "for_user": False,
1657
+ "for_role": False,
1658
+ "into_outfile": False,
1659
+ "json": False,
1649
1660
  }
1650
1661
 
1651
1662
 
@@ -2054,7 +2065,7 @@ class ProjectionPolicyColumnConstraint(ColumnConstraintKind):
2054
2065
  # computed column expression
2055
2066
  # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql?view=sql-server-ver16
2056
2067
  class ComputedColumnConstraint(ColumnConstraintKind):
2057
- arg_types = {"this": True, "persisted": False, "not_null": False}
2068
+ arg_types = {"this": True, "persisted": False, "not_null": False, "data_type": False}
2058
2069
 
2059
2070
 
2060
2071
  class Constraint(Expression):
@@ -2197,7 +2208,7 @@ class Copy(DML):
2197
2208
  arg_types = {
2198
2209
  "this": True,
2199
2210
  "kind": True,
2200
- "files": True,
2211
+ "files": False,
2201
2212
  "credentials": False,
2202
2213
  "format": False,
2203
2214
  "params": False,
@@ -5755,6 +5766,15 @@ class ArrayUniqueAgg(AggFunc):
5755
5766
  pass
5756
5767
 
5757
5768
 
5769
+ class AIAgg(AggFunc):
5770
+ arg_types = {"this": True, "expression": True}
5771
+ _sql_names = ["AI_AGG"]
5772
+
5773
+
5774
+ class AISummarizeAgg(AggFunc):
5775
+ _sql_names = ["AI_SUMMARIZE_AGG"]
5776
+
5777
+
5758
5778
  class ArrayAll(Func):
5759
5779
  arg_types = {"this": True, "expression": True}
5760
5780
 
@@ -6694,11 +6714,26 @@ class JSONBContains(Binary, Func):
6694
6714
  _sql_names = ["JSONB_CONTAINS"]
6695
6715
 
6696
6716
 
6717
+ # https://www.postgresql.org/docs/9.5/functions-json.html
6718
+ class JSONBContainsAnyTopKeys(Binary, Func):
6719
+ pass
6720
+
6721
+
6722
+ # https://www.postgresql.org/docs/9.5/functions-json.html
6723
+ class JSONBContainsAllTopKeys(Binary, Func):
6724
+ pass
6725
+
6726
+
6697
6727
  class JSONBExists(Func):
6698
6728
  arg_types = {"this": True, "path": True}
6699
6729
  _sql_names = ["JSONB_EXISTS"]
6700
6730
 
6701
6731
 
6732
+ # https://www.postgresql.org/docs/9.5/functions-json.html
6733
+ class JSONBDeleteAtPath(Binary, Func):
6734
+ pass
6735
+
6736
+
6702
6737
  class JSONExtract(Binary, Func):
6703
6738
  arg_types = {
6704
6739
  "this": True,
@@ -6925,6 +6960,16 @@ class MD5Digest(Func):
6925
6960
  _sql_names = ["MD5_DIGEST"]
6926
6961
 
6927
6962
 
6963
+ # https://docs.snowflake.com/en/sql-reference/functions/md5_number_lower64
6964
+ class MD5NumberLower64(Func):
6965
+ pass
6966
+
6967
+
6968
+ # https://docs.snowflake.com/en/sql-reference/functions/md5_number_upper64
6969
+ class MD5NumberUpper64(Func):
6970
+ pass
6971
+
6972
+
6928
6973
  class Median(AggFunc):
6929
6974
  pass
6930
6975
 
@@ -6963,6 +7008,11 @@ class Predict(Func):
6963
7008
  arg_types = {"this": True, "expression": True, "params_struct": False}
6964
7009
 
6965
7010
 
7011
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-translate#mltranslate_function
7012
+ class MLTranslate(Func):
7013
+ arg_types = {"this": True, "expression": True, "params_struct": True}
7014
+
7015
+
6966
7016
  # https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-feature-time
6967
7017
  class FeaturesAtTime(Func):
6968
7018
  arg_types = {"this": True, "time": False, "num_rows": False, "ignore_feature_nulls": False}
@@ -6970,7 +7020,11 @@ class FeaturesAtTime(Func):
6970
7020
 
6971
7021
  # https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-embedding
6972
7022
  class GenerateEmbedding(Func):
6973
- arg_types = {"this": True, "expression": True, "params_struct": False}
7023
+ arg_types = {"this": True, "expression": True, "params_struct": False, "is_text": False}
7024
+
7025
+
7026
+ class MLForecast(Func):
7027
+ arg_types = {"this": True, "expression": False, "params_struct": False}
6974
7028
 
6975
7029
 
6976
7030
  # https://cloud.google.com/bigquery/docs/reference/standard-sql/search_functions#vector_search
@@ -7166,6 +7220,16 @@ class SHA2(Func):
7166
7220
  arg_types = {"this": True, "length": False}
7167
7221
 
7168
7222
 
7223
+ # Represents the variant of the SHA1 function that returns a binary value
7224
+ class SHA1Digest(Func):
7225
+ pass
7226
+
7227
+
7228
+ # Represents the variant of the SHA2 function that returns a binary value
7229
+ class SHA2Digest(Func):
7230
+ arg_types = {"this": True, "length": False}
7231
+
7232
+
7169
7233
  class Sign(Func):
7170
7234
  _sql_names = ["SIGN", "SIGNUM"]
7171
7235
 
sqlglot/generator.py CHANGED
@@ -160,6 +160,9 @@ class Generator(metaclass=_Generator):
160
160
  exp.Intersect: lambda self, e: self.set_operations(e),
161
161
  exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}",
162
162
  exp.Int64: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.BIGINT)),
163
+ exp.JSONBContainsAnyTopKeys: lambda self, e: self.binary(e, "?|"),
164
+ exp.JSONBContainsAllTopKeys: lambda self, e: self.binary(e, "?&"),
165
+ exp.JSONBDeleteAtPath: lambda self, e: self.binary(e, "#-"),
163
166
  exp.LanguageProperty: lambda self, e: self.naked_property(e),
164
167
  exp.LocationProperty: lambda self, e: self.naked_property(e),
165
168
  exp.LogProperty: lambda _, e: f"{'NO ' if e.args.get('no') else ''}LOG",
@@ -4214,21 +4217,32 @@ class Generator(metaclass=_Generator):
4214
4217
  def opclass_sql(self, expression: exp.Opclass) -> str:
4215
4218
  return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
4216
4219
 
4217
- def predict_sql(self, expression: exp.Predict) -> str:
4220
+ def _ml_sql(self, expression: exp.Func, name: str) -> str:
4218
4221
  model = self.sql(expression, "this")
4219
4222
  model = f"MODEL {model}"
4220
- table = self.sql(expression, "expression")
4221
- table = f"TABLE {table}" if not isinstance(expression.expression, exp.Subquery) else table
4222
- parameters = self.sql(expression, "params_struct")
4223
- return self.func("PREDICT", model, table, parameters or None)
4223
+ expr = expression.expression
4224
+ if expr:
4225
+ expr_sql = self.sql(expression, "expression")
4226
+ expr_sql = f"TABLE {expr_sql}" if not isinstance(expr, exp.Subquery) else expr_sql
4227
+ else:
4228
+ expr_sql = None
4229
+
4230
+ parameters = self.sql(expression, "params_struct") or None
4231
+
4232
+ return self.func(name, model, expr_sql, parameters)
4233
+
4234
+ def predict_sql(self, expression: exp.Predict) -> str:
4235
+ return self._ml_sql(expression, "PREDICT")
4224
4236
 
4225
4237
  def generateembedding_sql(self, expression: exp.GenerateEmbedding) -> str:
4226
- model = self.sql(expression, "this")
4227
- model = f"MODEL {model}"
4228
- table = self.sql(expression, "expression")
4229
- table = f"TABLE {table}" if not isinstance(expression.expression, exp.Subquery) else table
4230
- parameters = self.sql(expression, "params_struct")
4231
- return self.func("GENERATE_EMBEDDING", model, table, parameters or None)
4238
+ name = "GENERATE_TEXT_EMBEDDING" if expression.args.get("is_text") else "GENERATE_EMBEDDING"
4239
+ return self._ml_sql(expression, name)
4240
+
4241
+ def mltranslate_sql(self, expression: exp.MLTranslate) -> str:
4242
+ return self._ml_sql(expression, "TRANSLATE")
4243
+
4244
+ def mlforecast_sql(self, expression: exp.MLForecast) -> str:
4245
+ return self._ml_sql(expression, "FORECAST")
4232
4246
 
4233
4247
  def featuresattime_sql(self, expression: exp.FeaturesAtTime) -> str:
4234
4248
  this_sql = self.sql(expression, "this")
@@ -4579,8 +4593,8 @@ class Generator(metaclass=_Generator):
4579
4593
 
4580
4594
  credentials = self.sql(expression, "credentials")
4581
4595
  credentials = self.seg(credentials) if credentials else ""
4582
- kind = self.seg("FROM" if expression.args.get("kind") else "TO")
4583
4596
  files = self.expressions(expression, key="files", flat=True)
4597
+ kind = self.seg("FROM" if expression.args.get("kind") else "TO") if files else ""
4584
4598
 
4585
4599
  sep = ", " if self.dialect.COPY_PARAMS_ARE_CSV else " "
4586
4600
  params = self.expressions(
@@ -4596,7 +4610,7 @@ class Generator(metaclass=_Generator):
4596
4610
  if params:
4597
4611
  if self.COPY_PARAMS_ARE_WRAPPED:
4598
4612
  params = f" WITH ({params})"
4599
- elif not self.pretty:
4613
+ elif not self.pretty and (files or credentials):
4600
4614
  params = f" {params}"
4601
4615
 
4602
4616
  return f"COPY{this}{kind} {files}{credentials}{params}"
sqlglot/optimizer/annotate_types.py CHANGED
@@ -193,6 +193,12 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
193
193
  # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
194
194
  self._visited: t.Set[int] = set()
195
195
 
196
+ # Caches NULL-annotated expressions to set them to UNKNOWN after type inference is completed
197
+ self._null_expressions: t.Dict[int, exp.Expression] = {}
198
+
199
+ # Databricks and Spark ≥v3 actually support NULL (i.e., VOID) as a type
200
+ self._supports_null_type = schema.dialect in ("databricks", "spark")
201
+
196
202
  # Maps an exp.SetOperation's id (e.g. UNION) to its projection types. This is computed if the
197
203
  # exp.SetOperation is the expression of a scope source, as selecting from it multiple times
198
204
  # would reprocess the entire subtree to coerce the types of its operands' projections
@@ -201,13 +207,33 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
201
207
  def _set_type(
202
208
  self, expression: exp.Expression, target_type: t.Optional[exp.DataType | exp.DataType.Type]
203
209
  ) -> None:
210
+ prev_type = expression.type
211
+ expression_id = id(expression)
212
+
204
213
  expression.type = target_type or exp.DataType.Type.UNKNOWN # type: ignore
205
- self._visited.add(id(expression))
214
+ self._visited.add(expression_id)
215
+
216
+ if (
217
+ not self._supports_null_type
218
+ and t.cast(exp.DataType, expression.type).this == exp.DataType.Type.NULL
219
+ ):
220
+ self._null_expressions[expression_id] = expression
221
+ elif prev_type and t.cast(exp.DataType, prev_type).this == exp.DataType.Type.NULL:
222
+ self._null_expressions.pop(expression_id, None)
206
223
 
207
224
  def annotate(self, expression: E) -> E:
208
225
  for scope in traverse_scope(expression):
209
226
  self.annotate_scope(scope)
210
- return self._maybe_annotate(expression) # This takes care of non-traversable expressions
227
+
228
+ # This takes care of non-traversable expressions
229
+ expression = self._maybe_annotate(expression)
230
+
231
+ # Replace NULL type with UNKNOWN, since the former is not an actual type;
232
+ # it is mostly used to aid type coercion, e.g. in query set operations.
233
+ for expr in self._null_expressions.values():
234
+ expr.type = exp.DataType.Type.UNKNOWN
235
+
236
+ return expression
211
237
 
212
238
  def annotate_scope(self, scope: Scope) -> None:
213
239
  selects = {}
@@ -567,14 +593,18 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
567
593
  def _annotate_struct_value(
568
594
  self, expression: exp.Expression
569
595
  ) -> t.Optional[exp.DataType] | exp.ColumnDef:
570
- alias = expression.args.get("alias")
571
- if alias:
596
+ # Case: STRUCT(key AS value)
597
+ if alias := expression.args.get("alias"):
572
598
  return exp.ColumnDef(this=alias.copy(), kind=expression.type)
573
599
 
574
- # Case: key = value or key := value
600
+ # Case: STRUCT(key = value) or STRUCT(key := value)
575
601
  if expression.expression:
576
602
  return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
577
603
 
604
+ # Case: STRUCT(c)
605
+ if isinstance(expression, exp.Column):
606
+ return exp.ColumnDef(this=expression.this.copy(), kind=expression.type)
607
+
578
608
  return expression.type
579
609
 
580
610
  def _annotate_struct(self, expression: exp.Struct) -> exp.Struct:
sqlglot/optimizer/qualify_columns.py CHANGED
@@ -351,11 +351,15 @@ def _expand_alias_refs(
351
351
  alias_to_expression[projection.alias] = (projection.this, i + 1)
352
352
 
353
353
  parent_scope = scope
354
- while parent_scope.is_union:
354
+ on_right_sub_tree = False
355
+ while parent_scope and not parent_scope.is_cte:
356
+ if parent_scope.is_union:
357
+ on_right_sub_tree = parent_scope.parent.expression.right is parent_scope.expression
355
358
  parent_scope = parent_scope.parent
356
359
 
357
360
  # We shouldn't expand aliases if they match the recursive CTE's columns
358
- if parent_scope.is_cte:
361
+ # and we are in the recursive part (right sub tree) of the CTE
362
+ if parent_scope and on_right_sub_tree:
359
363
  cte = parent_scope.expression.parent
360
364
  if cte.find_ancestor(exp.With).recursive:
361
365
  for recursive_cte_column in cte.args["alias"].columns or cte.this.selects: