sqlglot 28.4.1__py3-none-any.whl → 28.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. sqlglot/_version.py +2 -2
  2. sqlglot/dialects/bigquery.py +20 -23
  3. sqlglot/dialects/clickhouse.py +2 -0
  4. sqlglot/dialects/dialect.py +355 -18
  5. sqlglot/dialects/doris.py +38 -90
  6. sqlglot/dialects/druid.py +1 -0
  7. sqlglot/dialects/duckdb.py +1739 -163
  8. sqlglot/dialects/exasol.py +17 -1
  9. sqlglot/dialects/hive.py +27 -2
  10. sqlglot/dialects/mysql.py +103 -11
  11. sqlglot/dialects/oracle.py +38 -1
  12. sqlglot/dialects/postgres.py +142 -33
  13. sqlglot/dialects/presto.py +6 -2
  14. sqlglot/dialects/redshift.py +7 -1
  15. sqlglot/dialects/singlestore.py +13 -3
  16. sqlglot/dialects/snowflake.py +271 -21
  17. sqlglot/dialects/spark.py +25 -0
  18. sqlglot/dialects/spark2.py +4 -3
  19. sqlglot/dialects/starrocks.py +152 -17
  20. sqlglot/dialects/trino.py +1 -0
  21. sqlglot/dialects/tsql.py +5 -0
  22. sqlglot/diff.py +1 -1
  23. sqlglot/expressions.py +239 -47
  24. sqlglot/generator.py +173 -44
  25. sqlglot/optimizer/annotate_types.py +129 -60
  26. sqlglot/optimizer/merge_subqueries.py +13 -2
  27. sqlglot/optimizer/qualify_columns.py +7 -0
  28. sqlglot/optimizer/resolver.py +19 -0
  29. sqlglot/optimizer/scope.py +12 -0
  30. sqlglot/optimizer/unnest_subqueries.py +7 -0
  31. sqlglot/parser.py +251 -58
  32. sqlglot/schema.py +186 -14
  33. sqlglot/tokens.py +36 -6
  34. sqlglot/transforms.py +6 -5
  35. sqlglot/typing/__init__.py +29 -10
  36. sqlglot/typing/bigquery.py +5 -10
  37. sqlglot/typing/duckdb.py +39 -0
  38. sqlglot/typing/hive.py +50 -1
  39. sqlglot/typing/mysql.py +32 -0
  40. sqlglot/typing/presto.py +0 -1
  41. sqlglot/typing/snowflake.py +80 -17
  42. sqlglot/typing/spark.py +29 -0
  43. sqlglot/typing/spark2.py +9 -1
  44. sqlglot/typing/tsql.py +21 -0
  45. {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/METADATA +47 -2
  46. sqlglot-28.8.0.dist-info/RECORD +95 -0
  47. {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/WHEEL +1 -1
  48. sqlglot-28.4.1.dist-info/RECORD +0 -92
  49. {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/licenses/LICENSE +0 -0
  50. {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/top_level.txt +0 -0
sqlglot/schema.py CHANGED
@@ -111,6 +111,25 @@ class Schema(abc.ABC):
         name = column if isinstance(column, str) else column.name
         return name in self.column_names(table, dialect=dialect, normalize=normalize)

+    def get_udf_type(
+        self,
+        udf: exp.Anonymous | str,
+        dialect: DialectType = None,
+        normalize: t.Optional[bool] = None,
+    ) -> exp.DataType:
+        """
+        Get the return type of a UDF.
+
+        Args:
+            udf: the UDF expression or string.
+            dialect: the SQL dialect for parsing string arguments.
+            normalize: whether to normalize identifiers.
+
+        Returns:
+            The return type as a DataType, or UNKNOWN if not found.
+        """
+        return exp.DataType.build("unknown")
+
     @property
     @abc.abstractmethod
     def supported_table_args(self) -> t.Tuple[str, ...]:
@@ -128,11 +147,18 @@ class AbstractMappingSchema:
     def __init__(
         self,
         mapping: t.Optional[t.Dict] = None,
+        udf_mapping: t.Optional[t.Dict] = None,
     ) -> None:
         self.mapping = mapping or {}
         self.mapping_trie = new_trie(
             tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
         )
+
+        self.udf_mapping = udf_mapping or {}
+        self.udf_trie = new_trie(
+            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
+        )
+
         self._supported_table_args: t.Tuple[str, ...] = tuple()

     @property
@@ -142,6 +168,9 @@ class AbstractMappingSchema:
     def depth(self) -> int:
         return dict_depth(self.mapping)

+    def udf_depth(self) -> int:
+        return dict_depth(self.udf_mapping)
+
     @property
     def supported_table_args(self) -> t.Tuple[str, ...]:
         if not self._supported_table_args and self.mapping:
@@ -157,7 +186,39 @@ class AbstractMappingSchema:
         return self._supported_table_args

     def table_parts(self, table: exp.Table) -> t.List[str]:
-        return [part.name for part in reversed(table.parts)]
+        return [p.name for p in reversed(table.parts)]
+
+    def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
+        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
+        parent = udf.parent
+        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
+        return list(reversed(parts))[0 : self.udf_depth()]
+
+    def _find_in_trie(
+        self,
+        parts: t.List[str],
+        trie: t.Dict,
+        raise_on_missing: bool,
+    ) -> t.Optional[t.List[str]]:
+        value, trie = in_trie(trie, parts)
+
+        if value == TrieResult.FAILED:
+            return None
+
+        if value == TrieResult.PREFIX:
+            possibilities = flatten_schema(trie)
+
+            if len(possibilities) == 1:
+                parts.extend(possibilities[0])
+            else:
+                if raise_on_missing:
+                    joined_parts = ".".join(parts)
+                    message = ", ".join(".".join(p) for p in possibilities)
+                    raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")
+
+                return None
+
+        return parts

     def find(
         self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
@@ -174,23 +235,35 @@ class AbstractMappingSchema:
             The schema of the target table.
         """
         parts = self.table_parts(table)[0 : len(self.supported_table_args)]
-        value, trie = in_trie(self.mapping_trie, parts)
+        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

-        if value == TrieResult.FAILED:
+        if resolved_parts is None:
             return None

-        if value == TrieResult.PREFIX:
-            possibilities = flatten_schema(trie)
+        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

-            if len(possibilities) == 1:
-                parts.extend(possibilities[0])
-            else:
-                message = ", ".join(".".join(parts) for parts in possibilities)
-                if raise_on_missing:
-                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
-                return None
+    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
+        """
+        Returns the return type of a given UDF.
+
+        Args:
+            udf: the target UDF expression.
+            raise_on_missing: whether to raise if the UDF is not found.
+
+        Returns:
+            The return type of the UDF, or None if not found.
+        """
+        parts = self.udf_parts(udf)
+        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)

-        return self.nested_get(parts, raise_on_missing=raise_on_missing)
+        if resolved_parts is None:
+            return None
+
+        return nested_get(
+            self.udf_mapping,
+            *zip(resolved_parts, reversed(resolved_parts)),
+            raise_on_missing=raise_on_missing,
+        )

     def nested_get(
         self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
@@ -227,6 +300,7 @@ class MappingSchema(AbstractMappingSchema, Schema):
         visible: t.Optional[t.Dict] = None,
         dialect: DialectType = None,
         normalize: bool = True,
+        udf_mapping: t.Optional[t.Dict] = None,
     ) -> None:
         self.visible = {} if visible is None else visible
         self.normalize = normalize
@@ -234,8 +308,12 @@ class MappingSchema(AbstractMappingSchema, Schema):
         self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
         self._depth = 0
         schema = {} if schema is None else schema
+        udf_mapping = {} if udf_mapping is None else udf_mapping

-        super().__init__(self._normalize(schema) if self.normalize else schema)
+        super().__init__(
+            self._normalize(schema) if self.normalize else schema,
+            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
+        )

     @property
     def dialect(self) -> Dialect:
@@ -249,6 +327,7 @@ class MappingSchema(AbstractMappingSchema, Schema):
             visible=mapping_schema.visible,
             dialect=mapping_schema.dialect,
             normalize=mapping_schema.normalize,
+            udf_mapping=mapping_schema.udf_mapping,
         )

     def find(
@@ -272,6 +351,7 @@ class MappingSchema(AbstractMappingSchema, Schema):
                 "visible": self.visible.copy(),
                 "dialect": self.dialect,
                 "normalize": self.normalize,
+                "udf_mapping": self.udf_mapping.copy(),
                 **kwargs,
             }
@@ -360,6 +440,42 @@ class MappingSchema(AbstractMappingSchema, Schema):

         return exp.DataType.build("unknown")

+    def get_udf_type(
+        self,
+        udf: exp.Anonymous | str,
+        dialect: DialectType = None,
+        normalize: t.Optional[bool] = None,
+    ) -> exp.DataType:
+        """
+        Get the return type of a UDF.
+
+        Args:
+            udf: the UDF expression or string (e.g., "db.my_func()").
+            dialect: the SQL dialect for parsing string arguments.
+            normalize: whether to normalize identifiers.
+
+        Returns:
+            The return type as a DataType, or UNKNOWN if not found.
+        """
+        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
+        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)
+
+        if resolved_parts is None:
+            return exp.DataType.build("unknown")
+
+        udf_type = nested_get(
+            self.udf_mapping,
+            *zip(resolved_parts, reversed(resolved_parts)),
+            raise_on_missing=False,
+        )
+
+        if isinstance(udf_type, exp.DataType):
+            return udf_type
+        elif isinstance(udf_type, str):
+            return self._to_data_type(udf_type, dialect=dialect)
+
+        return exp.DataType.build("unknown")
+
     def has_column(
         self,
         table: exp.Table | str,
@@ -414,6 +530,61 @@ class MappingSchema(AbstractMappingSchema, Schema):

         return normalized_mapping

+    def _normalize_udfs(self, udfs: t.Dict) -> t.Dict:
+        """
+        Normalizes all identifiers in the UDF mapping.
+
+        Args:
+            udfs: the UDF mapping to normalize.
+
+        Returns:
+            The normalized UDF mapping.
+        """
+        normalized_mapping: t.Dict = {}
+
+        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
+            udf_type = nested_get(udfs, *zip(keys, keys))
+            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
+            nested_set(normalized_mapping, normalized_keys, udf_type)
+
+        return normalized_mapping
+
+    def _normalize_udf(
+        self,
+        udf: exp.Anonymous | str,
+        dialect: DialectType = None,
+        normalize: t.Optional[bool] = None,
+    ) -> t.List[str]:
+        """
+        Extract and normalize UDF parts for lookup.
+
+        Args:
+            udf: the UDF expression or qualified string (e.g., "db.my_func()").
+            dialect: the SQL dialect for parsing.
+            normalize: whether to normalize identifiers.
+
+        Returns:
+            A list of normalized UDF parts (reversed for trie lookup).
+        """
+        dialect = dialect or self.dialect
+        normalize = self.normalize if normalize is None else normalize
+
+        if isinstance(udf, str):
+            parsed: exp.Expression = exp.maybe_parse(udf, dialect=dialect)
+
+            if isinstance(parsed, exp.Anonymous):
+                udf = parsed
+            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
+                udf = parsed.expression
+            else:
+                raise SchemaError(f"Unable to parse UDF from: {udf!r}")
+        parts = self.udf_parts(udf)
+
+        if normalize:
+            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]
+
+        return parts
+
     def _normalize_table(
         self,
         table: exp.Table | str,
@@ -471,6 +642,7 @@ class MappingSchema(AbstractMappingSchema, Schema):

         try:
             expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
+            expression.transform(dialect.normalize_identifier, copy=False)
             self._type_mapping_cache[schema_type] = expression
         except AttributeError:
             in_dialect = f" in dialect {dialect}" if dialect else ""
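
Usage sketch (illustrative, not part of the diff): a MappingSchema can now carry a udf_mapping alongside the table schema, and get_udf_type resolves a UDF's return type through the new udf_trie, falling back to UNKNOWN. The mapping and UDF names below are hypothetical.

    from sqlglot import exp
    from sqlglot.schema import MappingSchema

    schema = MappingSchema(
        schema={"db": {"t": {"x": "int"}}},
        udf_mapping={"db": {"my_func": "bigint"}},  # hypothetical UDF entry
    )

    print(schema.get_udf_type("db.my_func()").sql())  # expected: BIGINT
    print(schema.get_udf_type("db.missing()").sql())  # expected: UNKNOWN
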
sqlglot/tokens.py CHANGED
@@ -68,7 +68,7 @@ class TokenType(AutoName):
     DPIPE_SLASH = auto()
     CARET = auto()
     CARET_AT = auto()
-    TILDA = auto()
+    TILDE = auto()
     ARROW = auto()
     DARROW = auto()
     FARROW = auto()
@@ -87,6 +87,7 @@ class TokenType(AutoName):
     DAMP = auto()
     AMP_LT = auto()
     AMP_GT = auto()
+    ADJACENT = auto()
     XOR = auto()
     DSTAR = auto()
     QMARK_AMP = auto()
@@ -207,6 +208,7 @@ class TokenType(AutoName):
     LINESTRING = auto()
     LOCALTIME = auto()
     LOCALTIMESTAMP = auto()
+    SYSTIMESTAMP = auto()
     MULTILINESTRING = auto()
     POLYGON = auto()
     MULTIPOLYGON = auto()
@@ -370,6 +372,8 @@ class TokenType(AutoName):
     ORDER_SIBLINGS_BY = auto()
     ORDERED = auto()
     ORDINALITY = auto()
+    OUT = auto()
+    INOUT = auto()
     OUTER = auto()
     OVER = auto()
     OVERLAPS = auto()
@@ -436,6 +440,7 @@ class TokenType(AutoName):
     USE = auto()
     USING = auto()
     VALUES = auto()
+    VARIADIC = auto()
     VIEW = auto()
     SEMANTIC_VIEW = auto()
     VOLATILE = auto()
@@ -552,7 +557,11 @@ class _Tokenizer(type):
             **_quotes_to_format(TokenType.UNICODE_STRING, klass.UNICODE_STRINGS),
         }

+        if "BYTE_STRING_ESCAPES" not in klass.__dict__:
+            klass.BYTE_STRING_ESCAPES = klass.STRING_ESCAPES.copy()
+
         klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
+        klass._BYTE_STRING_ESCAPES = set(klass.BYTE_STRING_ESCAPES)
         klass._ESCAPE_FOLLOW_CHARS = set(klass.ESCAPE_FOLLOW_CHARS)
         klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES)
         klass._COMMENTS = {
@@ -585,6 +594,7 @@ class _Tokenizer(type):
                 identifiers=klass._IDENTIFIERS,
                 identifier_escapes=klass._IDENTIFIER_ESCAPES,
                 string_escapes=klass._STRING_ESCAPES,
+                byte_string_escapes=klass._BYTE_STRING_ESCAPES,
                 quotes=klass._QUOTES,
                 format_strings={
                     k: (v1, _TOKEN_TYPE_TO_INDEX[v2])
@@ -609,6 +619,7 @@ class _Tokenizer(type):
             )
             token_types = RsTokenTypeSettings(
                 bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING],
+                byte_string=_TOKEN_TYPE_TO_INDEX[TokenType.BYTE_STRING],
                 break_=_TOKEN_TYPE_TO_INDEX[TokenType.BREAK],
                 dcolon=_TOKEN_TYPE_TO_INDEX[TokenType.DCOLON],
                 heredoc_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEREDOC_STRING],
@@ -655,7 +666,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "/": TokenType.SLASH,
         "\\": TokenType.BACKSLASH,
         "*": TokenType.STAR,
-        "~": TokenType.TILDA,
+        "~": TokenType.TILDE,
         "?": TokenType.PLACEHOLDER,
         "@": TokenType.PARAMETER,
         "#": TokenType.HASH,
@@ -674,6 +685,7 @@ class Tokenizer(metaclass=_Tokenizer):
     IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
     QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
     STRING_ESCAPES = ["'"]
+    BYTE_STRING_ESCAPES: t.List[str] = []
     VAR_SINGLE_TOKENS: t.Set[str] = set()
     ESCAPE_FOLLOW_CHARS: t.List[str] = []
@@ -704,6 +716,7 @@ class Tokenizer(metaclass=_Tokenizer):
     _IDENTIFIER_ESCAPES: t.Set[str] = set()
     _QUOTES: t.Dict[str, str] = {}
     _STRING_ESCAPES: t.Set[str] = set()
+    _BYTE_STRING_ESCAPES: t.Set[str] = set()
     _KEYWORD_TRIE: t.Dict = {}
     _RS_TOKENIZER: t.Optional[t.Any] = None
     _ESCAPE_FOLLOW_CHARS: t.Set[str] = set()
@@ -714,6 +727,8 @@ class Tokenizer(metaclass=_Tokenizer):
         **{f"{{{{{postfix}": TokenType.BLOCK_START for postfix in ("+", "-")},
         **{f"{prefix}}}}}": TokenType.BLOCK_END for prefix in ("+", "-")},
         HINT_START: TokenType.HINT,
+        "&<": TokenType.AMP_LT,
+        "&>": TokenType.AMP_GT,
         "==": TokenType.EQ,
         "::": TokenType.DCOLON,
         "?::": TokenType.QDCOLON,
@@ -737,6 +752,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "~~": TokenType.LIKE,
         "~~*": TokenType.ILIKE,
         "~*": TokenType.IRLIKE,
+        "-|-": TokenType.ADJACENT,
         "ALL": TokenType.ALL,
         "AND": TokenType.AND,
         "ANTI": TokenType.ANTI,
@@ -837,6 +853,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "XOR": TokenType.XOR,
         "ORDER BY": TokenType.ORDER_BY,
         "ORDINALITY": TokenType.ORDINALITY,
+        "OUT": TokenType.OUT,
         "OUTER": TokenType.OUTER,
         "OVER": TokenType.OVER,
         "OVERLAPS": TokenType.OVERLAPS,
@@ -850,6 +867,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "PRAGMA": TokenType.PRAGMA,
         "PRIMARY KEY": TokenType.PRIMARY_KEY,
         "PROCEDURE": TokenType.PROCEDURE,
+        "OPERATOR": TokenType.OPERATOR,
         "QUALIFY": TokenType.QUALIFY,
         "RANGE": TokenType.RANGE,
         "RECURSIVE": TokenType.RECURSIVE,
@@ -1363,8 +1381,12 @@ class Tokenizer(metaclass=_Tokenizer):
                 decimal = True
                 self._advance()
             elif self._peek in ("-", "+") and scientific == 1:
-                scientific += 1
-                self._advance()
+                # Only consume +/- if followed by a digit
+                if self._current + 1 < self.size and self.sql[self._current + 1].isdigit():
+                    scientific += 1
+                    self._advance()
+                else:
+                    return self._add(TokenType.NUMBER)
             elif self._peek.upper() == "E" and not scientific:
                 scientific += 1
                 self._advance()
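
The effect of the number-scanning fix above, sketched with the public tokenize entry point (the leading SELECT token is sliced off; exact output depends on the active tokenizer):

    import sqlglot

    # "+" followed by a digit is still consumed into the scientific literal:
    print([(t.token_type.name, t.text) for t in sqlglot.tokenize("SELECT 1e+2")][1:])
    # expected: [('NUMBER', '1e+2')]

    # "+" not followed by a digit now ends the number instead of being swallowed:
    print([(t.token_type.name, t.text) for t in sqlglot.tokenize("SELECT 1e+x")][1:])
    # expected: [('NUMBER', '1e'), ('PLUS', '+'), ('VAR', 'x')]
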
@@ -1464,7 +1486,15 @@ class Tokenizer(metaclass=_Tokenizer):
             return False

         self._advance(len(start))
-        text = self._extract_string(end, raw_string=token_type == TokenType.RAW_STRING)
+        text = self._extract_string(
+            end,
+            escapes=(
+                self._BYTE_STRING_ESCAPES
+                if token_type == TokenType.BYTE_STRING
+                else self._STRING_ESCAPES
+            ),
+            raw_string=token_type == TokenType.RAW_STRING,
+        )

         if base and text:
             try:
@@ -1514,7 +1544,7 @@ class Tokenizer(metaclass=_Tokenizer):
                 not raw_string
                 and self.dialect.UNESCAPED_SEQUENCES
                 and self._peek
-                and self._char in self.STRING_ESCAPES
+                and self._char in escapes
             ):
                 unescaped_sequence = self.dialect.UNESCAPED_SEQUENCES.get(self._char + self._peek)
                 if unescaped_sequence:
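
Taken together, the BYTE_STRING_ESCAPES changes let a dialect scope escape characters to byte strings only; since the attribute defaults to a copy of STRING_ESCAPES, dialects that don't set it behave as before. A minimal sketch with a hypothetical tokenizer subclass:

    from sqlglot.tokens import Tokenizer

    class MyTokenizer(Tokenizer):  # hypothetical dialect tokenizer
        BYTE_STRINGS = [("b'", "'")]    # tokenize b'...' as BYTE_STRING
        BYTE_STRING_ESCAPES = ["\\"]    # backslash escapes apply only inside byte strings
        # STRING_ESCAPES is untouched, so regular '...' literals keep "'" as their escape
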
sqlglot/transforms.py CHANGED
@@ -1042,12 +1042,13 @@ def inherit_struct_field_names(expression: exp.Expression) -> exp.Expression:
         new_expressions = []
         for i, expr in enumerate(struct.expressions):
             if not isinstance(expr, exp.PropertyEQ):
-                # Create PropertyEQ: field_name := value
-                new_expressions.append(
-                    exp.PropertyEQ(
-                        this=exp.Identifier(this=field_names[i].copy()), expression=expr
-                    )
+                # Create PropertyEQ: field_name := value, preserving the type from the inner expression
+                property_eq = exp.PropertyEQ(
+                    this=exp.Identifier(this=field_names[i].copy()),
+                    expression=expr,
                 )
+                property_eq.type = expr.type
+                new_expressions.append(property_eq)
             else:
                 new_expressions.append(expr)

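The observable effect of this transforms.py change, sketched outside the transform with the plain expressions API: a PropertyEQ synthesized for an unnamed struct field now carries the type of the expression it wraps.

    from sqlglot import exp

    value = exp.Literal.number(1)
    value.type = exp.DataType.build("int")

    prop = exp.PropertyEQ(this=exp.Identifier(this="f"), expression=value)
    prop.type = value.type  # mirrors what inherit_struct_field_names now does
    assert prop.type.is_type("int")
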
sqlglot/typing/__init__.py CHANGED
@@ -30,7 +30,6 @@ EXPRESSION_METADATA: ExpressionMetadataType = {
             exp.ArraySize,
             exp.CountIf,
             exp.Int64,
-            exp.Length,
             exp.UnixDate,
             exp.UnixSeconds,
             exp.UnixMicros,
@@ -47,11 +46,16 @@ EXPRESSION_METADATA: ExpressionMetadataType = {
     **{
         expr_type: {"returns": exp.DataType.Type.BOOLEAN}
         for expr_type in {
+            exp.All,
+            exp.Any,
             exp.Between,
             exp.Boolean,
             exp.Contains,
             exp.EndsWith,
+            exp.Exists,
             exp.In,
+            exp.IsInf,
+            exp.IsNan,
             exp.LogicalAnd,
             exp.LogicalOr,
             exp.RegexpLike,
@@ -86,7 +90,9 @@ EXPRESSION_METADATA: ExpressionMetadataType = {
         for expr_type in {
             exp.ApproxQuantile,
             exp.Avg,
+            exp.Cbrt,
             exp.Exp,
+            exp.Kurtosis,
             exp.Ln,
             exp.Log,
             exp.Pi,
@@ -109,16 +115,20 @@ EXPRESSION_METADATA: ExpressionMetadataType = {
         expr_type: {"returns": exp.DataType.Type.INT}
         for expr_type in {
             exp.Ascii,
+            exp.BitLength,
             exp.Ceil,
             exp.DatetimeDiff,
+            exp.Getbit,
             exp.TimestampDiff,
             exp.TimeDiff,
             exp.Unicode,
             exp.DateToDi,
             exp.Levenshtein,
+            exp.Length,
             exp.Sign,
             exp.StrPosition,
             exp.TsOrDiToDi,
+            exp.Quarter,
         }
     },
     **{
@@ -141,6 +151,7 @@ EXPRESSION_METADATA: ExpressionMetadataType = {
         expr_type: {"returns": exp.DataType.Type.TIME}
         for expr_type in {
             exp.CurrentTime,
+            exp.Localtime,
             exp.Time,
             exp.TimeAdd,
             exp.TimeSub,
@@ -169,7 +180,6 @@ EXPRESSION_METADATA: ExpressionMetadataType = {
             exp.DayOfWeekIso,
             exp.DayOfYear,
             exp.Month,
-            exp.Quarter,
             exp.Week,
             exp.WeekOfYear,
             exp.Year,
@@ -184,11 +194,14 @@ EXPRESSION_METADATA: ExpressionMetadataType = {
             exp.Concat,
             exp.ConcatWs,
             exp.Chr,
+            exp.Dayname,
             exp.DateToDateStr,
             exp.DPipe,
             exp.GroupConcat,
             exp.Initcap,
             exp.Lower,
+            exp.SHA,
+            exp.SHA2,
             exp.Substring,
             exp.String,
             exp.TimeToStr,
@@ -200,6 +213,8 @@ EXPRESSION_METADATA: ExpressionMetadataType = {
             exp.UnixToStr,
             exp.UnixToTimeStr,
             exp.Upper,
+            exp.RawString,
+            exp.Space,
         }
     },
     **{
@@ -237,13 +252,7 @@ EXPRESSION_METADATA: ExpressionMetadataType = {
             exp.ArrayLast,
         }
     },
-    **{
-        expr_type: {"returns": exp.DataType.Type.UNKNOWN}
-        for expr_type in {
-            exp.Anonymous,
-            exp.Slice,
-        }
-    },
+    exp.Anonymous: {"annotator": lambda self, e: self._set_type(e, self.schema.get_udf_type(e))},
     **{
         expr_type: {"annotator": lambda self, e: self._annotate_timeunit(e)}
         for expr_type in {
@@ -269,7 +278,11 @@ EXPRESSION_METADATA: ExpressionMetadataType = {
     exp.Array: {"annotator": lambda self, e: self._annotate_by_args(e, "expressions", array=True)},
     exp.ArrayAgg: {"annotator": lambda self, e: self._annotate_by_args(e, "this", array=True)},
     exp.Bracket: {"annotator": lambda self, e: self._annotate_bracket(e)},
-    exp.Case: {"annotator": lambda self, e: self._annotate_by_args(e, "default", "ifs")},
+    exp.Case: {
+        "annotator": lambda self, e: self._annotate_by_args(
+            e, *[if_expr.args["true"] for if_expr in e.args["ifs"]], "default"
+        )
+    },
     exp.Count: {
         "annotator": lambda self, e: self._set_type(
             e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT
@@ -286,6 +299,12 @@ EXPRESSION_METADATA: ExpressionMetadataType = {
     exp.Dot: {"annotator": lambda self, e: self._annotate_dot(e)},
     exp.Explode: {"annotator": lambda self, e: self._annotate_explode(e)},
     exp.Extract: {"annotator": lambda self, e: self._annotate_extract(e)},
+    exp.HexString: {
+        "annotator": lambda self, e: self._set_type(
+            e,
+            exp.DataType.Type.BIGINT if e.args.get("is_integer") else exp.DataType.Type.BINARY,
+        )
+    },
     exp.GenerateSeries: {
         "annotator": lambda self, e: self._annotate_by_args(e, "start", "end", "step", array=True)
     },
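
With exp.Anonymous now routed through schema.get_udf_type instead of a hard-coded UNKNOWN, type annotation can resolve UDF return types from a udf_mapping. A sketch using a hypothetical UDF name:

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.schema import MappingSchema

    schema = MappingSchema(schema={"t": {"x": "int"}}, udf_mapping={"my_udf": "double"})
    expr = annotate_types(sqlglot.parse_one("SELECT my_udf(x) FROM t"), schema=schema)
    print(expr.selects[0].type)  # expected: DOUBLE
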
sqlglot/typing/bigquery.py CHANGED
@@ -163,9 +163,9 @@ EXPRESSION_METADATA = {
     **{
         expr_type: {"annotator": lambda self, e: self._annotate_by_args(e, "this")}
         for expr_type in {
-            exp.Abs,
             exp.ArgMax,
             exp.ArgMin,
+            exp.DateAdd,
             exp.DateTrunc,
             exp.DatetimeTrunc,
             exp.FirstValue,
@@ -175,6 +175,7 @@ EXPRESSION_METADATA = {
             exp.Lead,
             exp.Left,
             exp.Lower,
+            exp.NetFunc,
             exp.NthValue,
             exp.Pad,
             exp.PercentileDisc,
@@ -185,6 +186,7 @@ EXPRESSION_METADATA = {
             exp.RespectNulls,
             exp.Reverse,
             exp.Right,
+            exp.SafeFunc,
             exp.SafeNegate,
             exp.Sign,
             exp.Substring,
@@ -197,7 +199,6 @@ EXPRESSION_METADATA = {
     **{
         expr_type: {"returns": exp.DataType.Type.BIGINT}
         for expr_type in {
-            exp.Ascii,
             exp.BitwiseAndAgg,
             exp.BitwiseCount,
             exp.BitwiseOrAgg,
@@ -213,7 +214,6 @@ EXPRESSION_METADATA = {
             exp.RangeBucket,
             exp.RegexpInstr,
             exp.RowNumber,
-            exp.Unicode,
         }
     },
     **{
@@ -232,8 +232,6 @@ EXPRESSION_METADATA = {
     **{
         expr_type: {"returns": exp.DataType.Type.BOOLEAN}
         for expr_type in {
-            exp.IsInf,
-            exp.IsNan,
             exp.JSONBool,
             exp.LaxBool,
         }
@@ -255,7 +253,6 @@ EXPRESSION_METADATA = {
             exp.Atan,
             exp.Atan2,
             exp.Atanh,
-            exp.Cbrt,
             exp.Corr,
             exp.CosineDistance,
             exp.Cot,
@@ -302,13 +299,14 @@ EXPRESSION_METADATA = {
         for expr_type in {
             exp.CodePointsToString,
             exp.Format,
+            exp.Host,
             exp.JSONExtractScalar,
             exp.JSONType,
             exp.LaxString,
             exp.LowerHex,
             exp.MD5,
-            exp.NetHost,
             exp.Normalize,
+            exp.RegDomain,
             exp.SafeConvertBytesToString,
             exp.Soundex,
             exp.Uuid,
@@ -339,9 +337,6 @@ EXPRESSION_METADATA = {
     exp.ApproxTopK: {"annotator": lambda self, e: _annotate_by_args_approx_top(self, e)},
     exp.ApproxTopSum: {"annotator": lambda self, e: _annotate_by_args_approx_top(self, e)},
     exp.Array: {"annotator": _annotate_array},
-    exp.ArrayConcat: {
-        "annotator": lambda self, e: self._annotate_by_args(e, "this", "expressions")
-    },
     exp.Concat: {"annotator": _annotate_concat},
     exp.DateFromUnixDate: {"returns": exp.DataType.Type.DATE},
     exp.GenerateTimestampArray: {
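
One observable effect of moving exp.DateAdd into the _annotate_by_args(e, "this") group for BigQuery: DATE_ADD now inherits its type from its first argument. A sketch (table and column names hypothetical):

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types

    expr = annotate_types(
        sqlglot.parse_one("SELECT DATE_ADD(d, INTERVAL 1 DAY) FROM t", read="bigquery"),
        schema={"t": {"d": "date"}},
        dialect="bigquery",
    )
    print(expr.selects[0].type)  # expected: DATE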