sqlglot 27.27.0__py3-none-any.whl → 28.4.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (68)
  1. sqlglot/__init__.py +1 -0
  2. sqlglot/__main__.py +6 -4
  3. sqlglot/_version.py +2 -2
  4. sqlglot/dialects/bigquery.py +118 -279
  5. sqlglot/dialects/clickhouse.py +73 -5
  6. sqlglot/dialects/databricks.py +38 -1
  7. sqlglot/dialects/dialect.py +354 -275
  8. sqlglot/dialects/dremio.py +4 -1
  9. sqlglot/dialects/duckdb.py +754 -25
  10. sqlglot/dialects/exasol.py +243 -10
  11. sqlglot/dialects/hive.py +8 -8
  12. sqlglot/dialects/mysql.py +14 -4
  13. sqlglot/dialects/oracle.py +29 -0
  14. sqlglot/dialects/postgres.py +60 -26
  15. sqlglot/dialects/presto.py +47 -16
  16. sqlglot/dialects/redshift.py +16 -0
  17. sqlglot/dialects/risingwave.py +3 -0
  18. sqlglot/dialects/singlestore.py +12 -3
  19. sqlglot/dialects/snowflake.py +239 -218
  20. sqlglot/dialects/spark.py +15 -4
  21. sqlglot/dialects/spark2.py +11 -48
  22. sqlglot/dialects/sqlite.py +10 -0
  23. sqlglot/dialects/starrocks.py +3 -0
  24. sqlglot/dialects/teradata.py +5 -8
  25. sqlglot/dialects/trino.py +6 -0
  26. sqlglot/dialects/tsql.py +61 -22
  27. sqlglot/diff.py +4 -2
  28. sqlglot/errors.py +69 -0
  29. sqlglot/executor/__init__.py +5 -10
  30. sqlglot/executor/python.py +1 -29
  31. sqlglot/expressions.py +637 -100
  32. sqlglot/generator.py +160 -43
  33. sqlglot/helper.py +2 -44
  34. sqlglot/lineage.py +10 -4
  35. sqlglot/optimizer/annotate_types.py +247 -140
  36. sqlglot/optimizer/canonicalize.py +6 -1
  37. sqlglot/optimizer/eliminate_joins.py +1 -1
  38. sqlglot/optimizer/eliminate_subqueries.py +2 -2
  39. sqlglot/optimizer/merge_subqueries.py +5 -5
  40. sqlglot/optimizer/normalize.py +20 -13
  41. sqlglot/optimizer/normalize_identifiers.py +17 -3
  42. sqlglot/optimizer/optimizer.py +4 -0
  43. sqlglot/optimizer/pushdown_predicates.py +1 -1
  44. sqlglot/optimizer/qualify.py +18 -10
  45. sqlglot/optimizer/qualify_columns.py +122 -275
  46. sqlglot/optimizer/qualify_tables.py +128 -76
  47. sqlglot/optimizer/resolver.py +374 -0
  48. sqlglot/optimizer/scope.py +27 -16
  49. sqlglot/optimizer/simplify.py +1075 -959
  50. sqlglot/optimizer/unnest_subqueries.py +12 -2
  51. sqlglot/parser.py +296 -170
  52. sqlglot/planner.py +2 -2
  53. sqlglot/schema.py +15 -4
  54. sqlglot/tokens.py +42 -7
  55. sqlglot/transforms.py +77 -22
  56. sqlglot/typing/__init__.py +316 -0
  57. sqlglot/typing/bigquery.py +376 -0
  58. sqlglot/typing/hive.py +12 -0
  59. sqlglot/typing/presto.py +24 -0
  60. sqlglot/typing/snowflake.py +505 -0
  61. sqlglot/typing/spark2.py +58 -0
  62. sqlglot/typing/tsql.py +9 -0
  63. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
  64. sqlglot-28.4.0.dist-info/RECORD +92 -0
  65. sqlglot-27.27.0.dist-info/RECORD +0 -84
  66. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
  67. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
  68. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
sqlglot/dialects/snowflake.py
@@ -6,12 +6,11 @@ from sqlglot import exp, generator, jsonpath, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     NormalizationStrategy,
-    annotate_with_type_lambda,
     build_timetostr_or_tochar,
+    build_like,
     binary_from_function,
     build_default_decimal_type,
     build_replace_with_optional_replacement,
-    build_timestamp_from_parts,
     date_delta_sql,
     date_trunc_to_time,
     datestrtodate_sql,
@@ -23,6 +22,7 @@ from sqlglot.dialects.dialect import (
     rename_func,
     timestamptrunc_sql,
     timestrtotime_sql,
+    unit_to_str,
     var_map_sql,
     map_date_part,
     no_timestamp_sql,
@@ -33,9 +33,9 @@ from sqlglot.dialects.dialect import (
 )
 from sqlglot.generator import unsupported_args
 from sqlglot.helper import find_new_name, flatten, is_float, is_int, seq_get
-from sqlglot.optimizer.annotate_types import TypeAnnotator
 from sqlglot.optimizer.scope import build_scope, find_all_in_scope
 from sqlglot.tokens import TokenType
+from sqlglot.typing.snowflake import EXPRESSION_METADATA
 
 if t.TYPE_CHECKING:
     from sqlglot._typing import E, B
@@ -53,6 +53,21 @@ def _build_strtok(args: t.List) -> exp.SplitPart:
     return exp.SplitPart.from_arg_list(args)
 
 
+def _build_approx_top_k(args: t.List) -> exp.ApproxTopK:
+    """
+    Normalizes APPROX_TOP_K arguments to match Snowflake semantics.
+
+    Snowflake APPROX_TOP_K signature: APPROX_TOP_K(column [, k] [, counters])
+    - k defaults to 1 if omitted (per Snowflake documentation)
+    - counters is optional precision parameter
+    """
+    # Add default k=1 if only column is provided
+    if len(args) == 1:
+        args.append(exp.Literal.number(1))
+
+    return exp.ApproxTopK.from_arg_list(args)
+
+
 def _build_datetime(
     name: str, kind: exp.DataType.Type, safe: bool = False
 ) -> t.Callable[[t.List], exp.Func]:
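A minimal usage sketch (not part of the diff): given the builder above and the "APPROX_TOP_K" parser mapping added later in this file, a one-argument call is expected to round-trip with the default k made explicit.

import sqlglot

# Parse a one-argument APPROX_TOP_K call with the Snowflake dialect.
ast = sqlglot.parse_one("SELECT APPROX_TOP_K(category)", read="snowflake")

# Per _build_approx_top_k above, the implicit k=1 should become explicit
# when the tree is rendered back to Snowflake SQL.
print(ast.sql(dialect="snowflake"))  # expected: SELECT APPROX_TOP_K(category, 1)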
@@ -128,6 +143,11 @@ def _build_date_time_add(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
 def _build_bitwise(expr_type: t.Type[B], name: str) -> t.Callable[[t.List], B | exp.Anonymous]:
     def _builder(args: t.List) -> B | exp.Anonymous:
         if len(args) == 3:
+            # Special handling for bitwise operations with padside argument
+            if expr_type in (exp.BitwiseAnd, exp.BitwiseOr, exp.BitwiseXor):
+                return expr_type(
+                    this=seq_get(args, 0), expression=seq_get(args, 1), padside=seq_get(args, 2)
+                )
             return exp.Anonymous(this=name, expressions=args)
 
         return binary_from_function(expr_type)(args)
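A rough sketch of the padside handling above (the 'LEFT' value is just an illustrative placeholder): with the "BITAND" mapping added later in this file, a three-argument call keeps its third argument on the typed node instead of falling back to exp.Anonymous.

import sqlglot
from sqlglot import exp

ast = sqlglot.parse_one("SELECT BITAND(a, b, 'LEFT')", read="snowflake")
bitand = ast.find(exp.BitwiseAnd)

# The third argument is expected to be stored under the new padside arg.
print(bitand is not None and bitand.args.get("padside"))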
@@ -344,8 +364,8 @@ def _transform_generate_date_array(expression: exp.Expression) -> exp.Expression
     return expression
 
 
-def _build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
-    def _builder(args: t.List) -> E:
+def _build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Snowflake], E]:
+    def _builder(args: t.List, dialect: Snowflake) -> E:
         return expr_type(
             this=seq_get(args, 0),
             expression=seq_get(args, 1),
@@ -353,20 +373,16 @@ def _build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
             occurrence=seq_get(args, 3),
             parameters=seq_get(args, 4),
             group=seq_get(args, 5) or exp.Literal.number(0),
+            **(
+                {"null_if_pos_overflow": dialect.REGEXP_EXTRACT_POSITION_OVERFLOW_RETURNS_NULL}
+                if expr_type is exp.RegexpExtract
+                else {}
+            ),
         )
 
     return _builder
 
 
-def _build_like(expr_type: t.Type[E]) -> t.Callable[[t.List], E | exp.Escape]:
-    def _builder(args: t.List) -> E | exp.Escape:
-        like_expr = expr_type(this=args[0], expression=args[1])
-        escape = seq_get(args, 2)
-        return exp.Escape(this=like_expr, expression=escape) if escape else like_expr
-
-    return _builder
-
-
 def _regexpextract_sql(self, expression: exp.RegexpExtract | exp.RegexpExtractAll) -> str:
     # Other dialects don't support all of the following parameters, so we need to
     # generate default values as necessary to ensure the transpilation is correct
@@ -538,15 +554,57 @@ def _eliminate_dot_variant_lookup(expression: exp.Expression) -> exp.Expression:
     return expression
 
 
-def _annotate_reverse(self: TypeAnnotator, expression: exp.Reverse) -> exp.Reverse:
-    expression = self._annotate_by_args(expression, "this")
-    if expression.is_type(exp.DataType.Type.NULL):
-        # Snowflake treats REVERSE(NULL) as a VARCHAR
-        self._set_type(expression, exp.DataType.Type.VARCHAR)
+def _build_timestamp_from_parts(args: t.List) -> exp.Func:
+    """Build TimestampFromParts with support for both syntaxes:
+    1. TIMESTAMP_FROM_PARTS(year, month, day, hour, minute, second [, nanosecond] [, time_zone])
+    2. TIMESTAMP_FROM_PARTS(date_expr, time_expr) - Snowflake specific
+    """
+    if len(args) == 2:
+        return exp.TimestampFromParts(this=seq_get(args, 0), expression=seq_get(args, 1))
+
+    return exp.TimestampFromParts.from_arg_list(args)
+
+
+def _build_round(args: t.List) -> exp.Round:
+    """
+    Build Round expression, unwrapping Snowflake's named parameters.
+
+    Maps EXPR => this, SCALE => decimals, ROUNDING_MODE => truncate.
 
+    Note: Snowflake does not support mixing named and positional arguments.
+    Arguments are either all named or all positional.
+    """
+    kwarg_map = {"EXPR": "this", "SCALE": "decimals", "ROUNDING_MODE": "truncate"}
+    round_args = {}
+    positional_keys = ["this", "decimals", "truncate"]
+    positional_idx = 0
+
+    for arg in args:
+        if isinstance(arg, exp.Kwarg):
+            key = arg.this.name.upper()
+            round_key = kwarg_map.get(key)
+            if round_key:
+                round_args[round_key] = arg.expression
+        else:
+            if positional_idx < len(positional_keys):
+                round_args[positional_keys[positional_idx]] = arg
+                positional_idx += 1
+
+    expression = exp.Round(**round_args)
+    expression.set("casts_non_integer_decimals", True)
     return expression
 
 
+def _build_try_to_number(args: t.List[exp.Expression]) -> exp.Expression:
+    return exp.ToNumber(
+        this=seq_get(args, 0),
+        format=seq_get(args, 1),
+        precision=seq_get(args, 2),
+        scale=seq_get(args, 3),
+        safe=True,
+    )
+
+
 class Snowflake(Dialect):
     # https://docs.snowflake.com/en/sql-reference/identifiers-syntax
     NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
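A hedged sketch of how the two builders above are expected to behave once the "ROUND" and "TIMESTAMP_FROM_PARTS" parser mappings later in this diff are in place (outputs are expectations, not assertions):

import sqlglot
from sqlglot import exp

# Snowflake named arguments (EXPR/SCALE) are unwrapped into a plain Round node.
round_ast = sqlglot.parse_one("SELECT ROUND(EXPR => price, SCALE => 2)", read="snowflake")
print(round_ast.find(exp.Round).args.get("decimals"))  # expected: the literal 2

# The Snowflake-specific two-argument form parses into TimestampFromParts directly.
ts_ast = sqlglot.parse_one("SELECT TIMESTAMP_FROM_PARTS(d, t)", read="snowflake")
print(ts_ast.find(exp.TimestampFromParts) is not None)  # expected: True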
@@ -560,137 +618,12 @@ class Snowflake(Dialect):
     ARRAY_AGG_INCLUDES_NULLS = None
     ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
     TRY_CAST_REQUIRES_STRING = True
+    SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS = True
 
-    TYPE_TO_EXPRESSIONS = {
-        **Dialect.TYPE_TO_EXPRESSIONS,
-        exp.DataType.Type.DOUBLE: {
-            *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.DOUBLE],
-            exp.Cos,
-            exp.Cosh,
-            exp.Cot,
-            exp.Degrees,
-            exp.Exp,
-            exp.Sin,
-            exp.Sinh,
-            exp.Tan,
-            exp.Tanh,
-            exp.Asin,
-            exp.Asinh,
-            exp.Atan,
-            exp.Atan2,
-            exp.Atanh,
-            exp.Cbrt,
-        },
-        exp.DataType.Type.INT: {
-            *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.INT],
-            exp.Ascii,
-            exp.ByteLength,
-            exp.Length,
-            exp.RtrimmedLength,
-            exp.BitLength,
-            exp.Levenshtein,
-            exp.JarowinklerSimilarity,
-            exp.StrPosition,
-            exp.Unicode,
-        },
-        exp.DataType.Type.VARCHAR: {
-            *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.VARCHAR],
-            exp.Base64DecodeString,
-            exp.TryBase64DecodeString,
-            exp.Base64Encode,
-            exp.DecompressString,
-            exp.MD5,
-            exp.AIAgg,
-            exp.AIClassify,
-            exp.AISummarizeAgg,
-            exp.Chr,
-            exp.Collate,
-            exp.Collation,
-            exp.HexDecodeString,
-            exp.TryHexDecodeString,
-            exp.HexEncode,
-            exp.Initcap,
-            exp.RegexpExtract,
-            exp.RegexpReplace,
-            exp.Repeat,
-            exp.Replace,
-            exp.SHA,
-            exp.SHA2,
-            exp.Soundex,
-            exp.SoundexP123,
-            exp.Space,
-            exp.SplitPart,
-            exp.Translate,
-            exp.Uuid,
-        },
-        exp.DataType.Type.BINARY: {
-            *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.BINARY],
-            exp.Base64DecodeBinary,
-            exp.TryBase64DecodeBinary,
-            exp.TryHexDecodeBinary,
-            exp.Compress,
-            exp.DecompressBinary,
-            exp.MD5Digest,
-            exp.SHA1Digest,
-            exp.SHA2Digest,
-            exp.Unhex,
-        },
-        exp.DataType.Type.BIGINT: {
-            *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.BIGINT],
-            exp.Factorial,
-            exp.MD5NumberLower64,
-            exp.MD5NumberUpper64,
-        },
-        exp.DataType.Type.ARRAY: {
-            exp.Split,
-            exp.RegexpExtractAll,
-            exp.StringToArray,
-        },
-        exp.DataType.Type.OBJECT: {
-            exp.ParseUrl,
-            exp.ParseIp,
-        },
-        exp.DataType.Type.DECIMAL: {
-            exp.RegexpCount,
-        },
-        exp.DataType.Type.BOOLEAN: {
-            *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.BOOLEAN],
-            exp.Search,
-        },
-    }
+    EXPRESSION_METADATA = EXPRESSION_METADATA.copy()
 
-    ANNOTATORS = {
-        **Dialect.ANNOTATORS,
-        **{
-            expr_type: annotate_with_type_lambda(data_type)
-            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
-            for expr_type in expressions
-        },
-        **{
-            expr_type: lambda self, e: self._annotate_by_args(e, "this")
-            for expr_type in (
-                exp.Floor,
-                exp.Left,
-                exp.Pad,
-                exp.Right,
-                exp.Stuff,
-                exp.Substring,
-                exp.Round,
-                exp.Ceil,
-            )
-        },
-        **{
-            expr_type: lambda self, e: self._annotate_with_type(
-                e, exp.DataType.build("NUMBER", dialect="snowflake")
-            )
-            for expr_type in (
-                exp.RegexpCount,
-                exp.RegexpInstr,
-            )
-        },
-        exp.ConcatWs: lambda self, e: self._annotate_by_args(e, "expressions"),
-        exp.Reverse: _annotate_reverse,
-    }
+    # https://docs.snowflake.com/en/en/sql-reference/functions/initcap
+    INITCAP_DEFAULT_DELIMITER_CHARS = ' \t\n\r\f\v!?@"^#$&~_,.:;+\\-*%/|\\[\\](){}<>'
 
     TIME_MAPPING = {
         "YYYY": "%Y",
@@ -724,17 +657,16 @@ class Snowflake(Dialect):
         "ISOWEEK": "WEEKISO",
     }
 
-    def quote_identifier(self, expression: E, identify: bool = True) -> E:
+    PSEUDOCOLUMNS = {"LEVEL"}
+
+    def can_quote(self, identifier: exp.Identifier, identify: str | bool = "safe") -> bool:
         # This disables quoting DUAL in SELECT ... FROM DUAL, because Snowflake treats an
         # unquoted DUAL keyword in a special way and does not map it to a user-defined table
-        if (
-            isinstance(expression, exp.Identifier)
-            and isinstance(expression.parent, exp.Table)
-            and expression.name.lower() == "dual"
-        ):
-            return expression  # type: ignore
-
-        return super().quote_identifier(expression, identify=identify)
+        return super().can_quote(identifier, identify) and not (
+            isinstance(identifier.parent, exp.Table)
+            and not identifier.quoted
+            and identifier.name.lower() == "dual"
+        )
 
     class JSONPathTokenizer(jsonpath.JSONPathTokenizer):
         SINGLE_TOKENS = jsonpath.JSONPathTokenizer.SINGLE_TOKENS.copy()
@@ -760,6 +692,7 @@ class Snowflake(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
+            "APPROX_TOP_K": _build_approx_top_k,
            "ARRAY_CONSTRUCT": lambda args: exp.Array(expressions=args),
             "ARRAY_CONTAINS": lambda args: exp.ArrayContains(
                 this=seq_get(args, 1), expression=seq_get(args, 0), ensure_variant=False
@@ -771,6 +704,10 @@ class Snowflake(Dialect):
                 step=seq_get(args, 2),
             ),
             "ARRAY_SORT": exp.SortArray.from_arg_list,
+            "BITAND": _build_bitwise(exp.BitwiseAnd, "BITAND"),
+            "BIT_AND": _build_bitwise(exp.BitwiseAnd, "BITAND"),
+            "BITNOT": lambda args: exp.BitwiseNot(this=seq_get(args, 0)),
+            "BIT_NOT": lambda args: exp.BitwiseNot(this=seq_get(args, 0)),
             "BITXOR": _build_bitwise(exp.BitwiseXor, "BITXOR"),
             "BIT_XOR": _build_bitwise(exp.BitwiseXor, "BITXOR"),
             "BITOR": _build_bitwise(exp.BitwiseOr, "BITOR"),
@@ -791,6 +728,7 @@ class Snowflake(Dialect):
             "BITXOR_AGG": exp.BitwiseXorAgg.from_arg_list,
             "BIT_XOR_AGG": exp.BitwiseXorAgg.from_arg_list,
             "BIT_XORAGG": exp.BitwiseXorAgg.from_arg_list,
+            "BITMAP_OR_AGG": exp.BitmapOrAgg.from_arg_list,
             "BOOLXOR": _build_bitwise(exp.Xor, "BOOLXOR"),
             "DATE": _build_datetime("DATE", exp.DataType.Type.DATE),
             "DATE_TRUNC": _date_trunc_to_time,
@@ -804,6 +742,7 @@ class Snowflake(Dialect):
             ),
             "FLATTEN": exp.Explode.from_arg_list,
             "GET": exp.GetExtract.from_arg_list,
+            "GETDATE": exp.CurrentTimestamp.from_arg_list,
             "GET_PATH": lambda args, dialect: exp.JSONExtract(
                 this=seq_get(args, 0),
                 expression=dialect.to_json_path(seq_get(args, 1)),
@@ -832,6 +771,7 @@ class Snowflake(Dialect):
             "REGEXP_SUBSTR_ALL": _build_regexp_extract(exp.RegexpExtractAll),
             "REPLACE": build_replace_with_optional_replacement,
             "RLIKE": exp.RegexpLike.from_arg_list,
+            "ROUND": _build_round,
             "SHA1_BINARY": exp.SHA1Digest.from_arg_list,
             "SHA1_HEX": exp.SHA.from_arg_list,
             "SHA2_BINARY": exp.SHA2Digest.from_arg_list,
@@ -843,23 +783,39 @@ class Snowflake(Dialect):
             "TIMEDIFF": _build_datediff,
             "TIMESTAMPADD": _build_date_time_add(exp.DateAdd),
             "TIMESTAMPDIFF": _build_datediff,
-            "TIMESTAMPFROMPARTS": build_timestamp_from_parts,
-            "TIMESTAMP_FROM_PARTS": build_timestamp_from_parts,
-            "TIMESTAMPNTZFROMPARTS": build_timestamp_from_parts,
-            "TIMESTAMP_NTZ_FROM_PARTS": build_timestamp_from_parts,
+            "TIMESTAMPFROMPARTS": _build_timestamp_from_parts,
+            "TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts,
+            "TIMESTAMPNTZFROMPARTS": _build_timestamp_from_parts,
+            "TIMESTAMP_NTZ_FROM_PARTS": _build_timestamp_from_parts,
             "TRY_PARSE_JSON": lambda args: exp.ParseJSON(this=seq_get(args, 0), safe=True),
+            "TRY_TO_BINARY": lambda args: exp.ToBinary(
+                this=seq_get(args, 0), format=seq_get(args, 1), safe=True
+            ),
+            "TRY_TO_BOOLEAN": lambda args: exp.ToBoolean(this=seq_get(args, 0), safe=True),
             "TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True),
+            **dict.fromkeys(
+                ("TRY_TO_DECIMAL", "TRY_TO_NUMBER", "TRY_TO_NUMERIC"), _build_try_to_number
+            ),
+            "TRY_TO_DOUBLE": lambda args: exp.ToDouble(
+                this=seq_get(args, 0), format=seq_get(args, 1), safe=True
+            ),
+            "TRY_TO_FILE": lambda args: exp.ToFile(
+                this=seq_get(args, 0), path=seq_get(args, 1), safe=True
+            ),
             "TRY_TO_TIME": _build_datetime("TRY_TO_TIME", exp.DataType.Type.TIME, safe=True),
             "TRY_TO_TIMESTAMP": _build_datetime(
                 "TRY_TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP, safe=True
             ),
             "TO_CHAR": build_timetostr_or_tochar,
             "TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
-            "TO_NUMBER": lambda args: exp.ToNumber(
-                this=seq_get(args, 0),
-                format=seq_get(args, 1),
-                precision=seq_get(args, 2),
-                scale=seq_get(args, 3),
+            **dict.fromkeys(
+                ("TO_DECIMAL", "TO_NUMBER", "TO_NUMERIC"),
+                lambda args: exp.ToNumber(
+                    this=seq_get(args, 0),
+                    format=seq_get(args, 1),
+                    precision=seq_get(args, 2),
+                    scale=seq_get(args, 3),
+                ),
             ),
             "TO_TIME": _build_datetime("TO_TIME", exp.DataType.Type.TIME),
             "TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP),
@@ -868,11 +824,18 @@ class Snowflake(Dialect):
             "TO_TIMESTAMP_TZ": _build_datetime("TO_TIMESTAMP_TZ", exp.DataType.Type.TIMESTAMPTZ),
             "TO_VARCHAR": build_timetostr_or_tochar,
             "TO_JSON": exp.JSONFormat.from_arg_list,
+            "VECTOR_COSINE_SIMILARITY": exp.CosineDistance.from_arg_list,
+            "VECTOR_INNER_PRODUCT": exp.DotProduct.from_arg_list,
+            "VECTOR_L1_DISTANCE": exp.ManhattanDistance.from_arg_list,
             "VECTOR_L2_DISTANCE": exp.EuclideanDistance.from_arg_list,
             "ZEROIFNULL": _build_if_from_zeroifnull,
-            "LIKE": _build_like(exp.Like),
-            "ILIKE": _build_like(exp.ILike),
+            "LIKE": build_like(exp.Like),
+            "ILIKE": build_like(exp.ILike),
             "SEARCH": _build_search,
+            "SKEW": exp.Skewness.from_arg_list,
+            "SYSTIMESTAMP": exp.CurrentTimestamp.from_arg_list,
+            "WEEKISO": exp.WeekOfYear.from_arg_list,
+            "WEEKOFYEAR": exp.Week.from_arg_list,
         }
         FUNCTIONS.pop("PREDICT")
 
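The two hunks above map more Snowflake built-ins onto canonical sqlglot expressions: the TRY_TO_* conversion family becomes the same nodes as its strict counterparts with a safe flag, and the VECTOR_* functions get dedicated distance/product expression types. A small sketch of the expected behavior, using only the public API:

import sqlglot
from sqlglot import exp

# TRY_TO_DECIMAL / TRY_TO_NUMBER / TRY_TO_NUMERIC all build exp.ToNumber(safe=True);
# per the generator mapping later in this diff, they are expected to render as TRY_TO_NUMBER.
print(sqlglot.transpile("SELECT TRY_TO_DECIMAL(col, 10, 2)", read="snowflake", write="snowflake")[0])

# VECTOR_COSINE_SIMILARITY now parses into a canonical expression type,
# which is what enables transpilation to other dialects.
ast = sqlglot.parse_one("SELECT VECTOR_COSINE_SIMILARITY(a, b)", read="snowflake")
print(ast.find(exp.CosineDistance) is not None)  # expected: True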
@@ -1078,30 +1041,11 @@ class Snowflake(Dialect):
             if not this:
                 return None
 
-            self._match(TokenType.COMMA)
-            expression = self._parse_bitwise()
-            this = map_date_part(this)
-            name = this.name.upper()
-
-            if name.startswith("EPOCH"):
-                if name == "EPOCH_MILLISECOND":
-                    scale = 10**3
-                elif name == "EPOCH_MICROSECOND":
-                    scale = 10**6
-                elif name == "EPOCH_NANOSECOND":
-                    scale = 10**9
-                else:
-                    scale = None
-
-                ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
-                to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts)
-
-                if scale:
-                    to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
-
-                return to_unix
-
-            return self.expression(exp.Extract, this=this, expression=expression)
+            # Handle both syntaxes: DATE_PART(part, expr) and DATE_PART(part FROM expr)
+            expression = (
+                self._match_set((TokenType.FROM, TokenType.COMMA)) and self._parse_bitwise()
+            )
+            return self.expression(exp.Extract, this=map_date_part(this), expression=expression)
 
         def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
             if is_map:
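With the simplification above, DATE_PART is parsed uniformly into exp.Extract and both argument styles are accepted. A minimal sketch:

import sqlglot
from sqlglot import exp

# The comma form and the FROM form are expected to produce the same Extract node.
for sql in ("SELECT DATE_PART(year, d)", "SELECT DATE_PART(year FROM d)"):
    ast = sqlglot.parse_one(sql, read="snowflake")
    print(isinstance(ast.selects[0], exp.Extract))  # expected: True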
@@ -1237,19 +1181,17 @@ class Snowflake(Dialect):
 
             return self.expression(
                 exp.Show,
-                **{
-                    "terse": terse,
-                    "this": this,
-                    "history": history,
-                    "like": like,
-                    "scope": scope,
-                    "scope_kind": scope_kind,
-                    "starts_with": self._match_text_seq("STARTS", "WITH") and self._parse_string(),
-                    "limit": self._parse_limit(),
-                    "from": self._parse_string() if self._match(TokenType.FROM) else None,
-                    "privileges": self._match_text_seq("WITH", "PRIVILEGES")
-                    and self._parse_csv(lambda: self._parse_var(any_token=True, upper=True)),
-                },
+                terse=terse,
+                this=this,
+                history=history,
+                like=like,
+                scope=scope,
+                scope_kind=scope_kind,
+                starts_with=self._match_text_seq("STARTS", "WITH") and self._parse_string(),
+                limit=self._parse_limit(),
+                from_=self._parse_string() if self._match(TokenType.FROM) else None,
+                privileges=self._match_text_seq("WITH", "PRIVILEGES")
+                and self._parse_csv(lambda: self._parse_var(any_token=True, upper=True)),
             )
 
         def _parse_put(self) -> exp.Put | exp.Command:
@@ -1353,15 +1295,26 @@ class Snowflake(Dialect):
             kwargs: t.Dict[str, t.Any] = {"this": self._parse_table_parts()}
 
             while self._curr and not self._match(TokenType.R_PAREN, advance=False):
-                if self._match_text_seq("DIMENSIONS"):
-                    kwargs["dimensions"] = self._parse_csv(self._parse_disjunction)
-                if self._match_text_seq("METRICS"):
-                    kwargs["metrics"] = self._parse_csv(self._parse_disjunction)
-                if self._match_text_seq("WHERE"):
+                if self._match_texts(("DIMENSIONS", "METRICS", "FACTS")):
+                    keyword = self._prev.text.lower()
+                    kwargs[keyword] = self._parse_csv(self._parse_disjunction)
+                elif self._match_text_seq("WHERE"):
                     kwargs["where"] = self._parse_expression()
+                else:
+                    self.raise_error("Expecting ) or encountered unexpected keyword")
+                    break
 
             return self.expression(exp.SemanticView, **kwargs)
 
+        def _parse_set(self, unset: bool = False, tag: bool = False) -> exp.Set | exp.Command:
+            set = super()._parse_set(unset=unset, tag=tag)
+
+            if isinstance(set, exp.Set):
+                for expr in set.expressions:
+                    if isinstance(expr, exp.SetItem):
+                        expr.set("kind", "VARIABLE")
+            return set
+
     class Tokenizer(tokens.Tokenizer):
         STRING_ESCAPES = ["\\", "'"]
         HEX_STRINGS = [("x'", "'"), ("X'", "'")]
@@ -1393,6 +1346,9 @@ class Snowflake(Dialect):
             "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
             "TOP": TokenType.TOP,
             "WAREHOUSE": TokenType.WAREHOUSE,
+            # https://docs.snowflake.com/en/sql-reference/data-types-numeric#float
+            # FLOAT is a synonym for DOUBLE in Snowflake
+            "FLOAT": TokenType.DOUBLE,
         }
         KEYWORDS.pop("/*+")
 
@@ -1437,6 +1393,7 @@ class Snowflake(Dialect):
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
             exp.ArgMax: rename_func("MAX_BY"),
             exp.ArgMin: rename_func("MIN_BY"),
+            exp.Array: transforms.preprocess([transforms.inherit_struct_field_names]),
             exp.ArrayConcat: lambda self, e: self.arrayconcat_sql(e, name="ARRAY_CAT"),
             exp.ArrayContains: lambda self, e: self.func(
                 "ARRAY_CONTAINS",
@@ -1468,10 +1425,12 @@ class Snowflake(Dialect):
             exp.DayOfWeek: rename_func("DAYOFWEEK"),
             exp.DayOfWeekIso: rename_func("DAYOFWEEKISO"),
             exp.DayOfYear: rename_func("DAYOFYEAR"),
+            exp.DotProduct: rename_func("VECTOR_INNER_PRODUCT"),
             exp.Explode: rename_func("FLATTEN"),
             exp.Extract: lambda self, e: self.func(
                 "DATE_PART", map_date_part(e.this, self.dialect), e.expression
             ),
+            exp.CosineDistance: rename_func("VECTOR_COSINE_SIMILARITY"),
             exp.EuclideanDistance: rename_func("VECTOR_L2_DISTANCE"),
             exp.FileFormatProperty: lambda self,
             e: f"FILE_FORMAT=({self.expressions(e, 'expressions', sep=' ')})",
@@ -1498,11 +1457,31 @@ class Snowflake(Dialect):
             exp.LogicalAnd: rename_func("BOOLAND_AGG"),
             exp.LogicalOr: rename_func("BOOLOR_AGG"),
             exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
+            exp.ManhattanDistance: rename_func("VECTOR_L1_DISTANCE"),
             exp.MakeInterval: no_make_interval_sql,
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,
             exp.ParseJSON: lambda self, e: self.func(
-                "TRY_PARSE_JSON" if e.args.get("safe") else "PARSE_JSON", e.this
+                f"{'TRY_' if e.args.get('safe') else ''}PARSE_JSON", e.this
+            ),
+            exp.ToBinary: lambda self, e: self.func(
+                f"{'TRY_' if e.args.get('safe') else ''}TO_BINARY", e.this, e.args.get("format")
+            ),
+            exp.ToBoolean: lambda self, e: self.func(
+                f"{'TRY_' if e.args.get('safe') else ''}TO_BOOLEAN", e.this
+            ),
+            exp.ToDouble: lambda self, e: self.func(
+                f"{'TRY_' if e.args.get('safe') else ''}TO_DOUBLE", e.this, e.args.get("format")
+            ),
+            exp.ToFile: lambda self, e: self.func(
+                f"{'TRY_' if e.args.get('safe') else ''}TO_FILE", e.this, e.args.get("path")
+            ),
+            exp.ToNumber: lambda self, e: self.func(
+                f"{'TRY_' if e.args.get('safe') else ''}TO_NUMBER",
+                e.this,
+                e.args.get("format"),
+                e.args.get("precision"),
+                e.args.get("scale"),
             ),
             exp.JSONFormat: rename_func("TO_JSON"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@@ -1529,11 +1508,13 @@ class Snowflake(Dialect):
                 ]
             ),
             exp.SHA: rename_func("SHA1"),
+            exp.SHA1Digest: rename_func("SHA1_BINARY"),
             exp.MD5Digest: rename_func("MD5_BINARY"),
             exp.MD5NumberLower64: rename_func("MD5_NUMBER_LOWER64"),
             exp.MD5NumberUpper64: rename_func("MD5_NUMBER_UPPER64"),
             exp.LowerHex: rename_func("TO_CHAR"),
             exp.SortArray: rename_func("ARRAY_SORT"),
+            exp.Skewness: rename_func("SKEW"),
             exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
             exp.StartsWith: rename_func("STARTSWITH"),
             exp.EndsWith: rename_func("ENDSWITH"),
@@ -1545,6 +1526,13 @@ class Snowflake(Dialect):
             exp.Stuff: rename_func("INSERT"),
             exp.StPoint: rename_func("ST_MAKEPOINT"),
             exp.TimeAdd: date_delta_sql("TIMEADD"),
+            exp.TimeSlice: lambda self, e: self.func(
+                "TIME_SLICE",
+                e.this,
+                e.expression,
+                unit_to_str(e),
+                e.args.get("kind"),
+            ),
             exp.Timestamp: no_timestamp_sql,
             exp.TimestampAdd: date_delta_sql("TIMESTAMPADD"),
             exp.TimestampDiff: lambda self, e: self.func(
@@ -1555,22 +1543,31 @@ class Snowflake(Dialect):
             exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
             exp.ToArray: rename_func("TO_ARRAY"),
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
-            exp.ToDouble: rename_func("TO_DOUBLE"),
             exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
             exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
             exp.TsOrDsToDate: lambda self, e: self.func(
-                "TRY_TO_DATE" if e.args.get("safe") else "TO_DATE", e.this, self.format_time(e)
+                f"{'TRY_' if e.args.get('safe') else ''}TO_DATE", e.this, self.format_time(e)
             ),
             exp.TsOrDsToTime: lambda self, e: self.func(
-                "TRY_TO_TIME" if e.args.get("safe") else "TO_TIME", e.this, self.format_time(e)
+                f"{'TRY_' if e.args.get('safe') else ''}TO_TIME", e.this, self.format_time(e)
             ),
             exp.Unhex: rename_func("HEX_DECODE_BINARY"),
             exp.UnixToTime: rename_func("TO_TIMESTAMP"),
             exp.Uuid: rename_func("UUID_STRING"),
             exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
-            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
+            exp.Booland: rename_func("BOOLAND"),
+            exp.Boolor: rename_func("BOOLOR"),
+            exp.WeekOfYear: rename_func("WEEKISO"),
+            exp.YearOfWeek: rename_func("YEAROFWEEK"),
+            exp.YearOfWeekIso: rename_func("YEAROFWEEKISO"),
             exp.Xor: rename_func("BOOLXOR"),
             exp.ByteLength: rename_func("OCTET_LENGTH"),
+            exp.ArrayConcatAgg: lambda self, e: self.func(
+                "ARRAY_FLATTEN", exp.ArrayAgg(this=e.this)
+            ),
+            exp.SHA2Digest: lambda self, e: self.func(
+                "SHA2_BINARY", e.this, e.args.get("length") or exp.Literal.number(256)
+            ),
         }
 
         SUPPORTED_JSON_PATH_PARTS = {
@@ -1619,6 +1616,15 @@ class Snowflake(Dialect):
            return super().values_sql(expression, values_as_table=values_as_table)
 
         def datatype_sql(self, expression: exp.DataType) -> str:
+            # Check if this is a FLOAT type nested inside a VECTOR type
+            # VECTOR only accepts FLOAT (not DOUBLE), INT, and STRING as element types
+            # https://docs.snowflake.com/en/sql-reference/data-types-vector
+            if expression.is_type(exp.DataType.Type.DOUBLE):
+                parent = expression.parent
+                if isinstance(parent, exp.DataType) and parent.is_type(exp.DataType.Type.VECTOR):
+                    # Preserve FLOAT for VECTOR types instead of mapping to synonym DOUBLE
+                    return "FLOAT"
+
             expressions = expression.expressions
             if expressions and expression.is_type(*exp.DataType.STRUCT_TYPES):
                 for field_type in expressions:
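A hedged round-trip sketch of the override above: because the tokenizer change earlier in this diff maps FLOAT to DOUBLE, this branch is what keeps the element type spelled FLOAT inside VECTOR(...), the only spelling Snowflake accepts there.

import sqlglot

# The nested FLOAT is expected to be preserved verbatim on output,
# while a standalone FLOAT is otherwise treated as a synonym for DOUBLE.
sql = "SELECT CAST(v AS VECTOR(FLOAT, 3))"
print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])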
@@ -1748,7 +1754,7 @@
 
             limit = self.sql(expression, "limit")
 
-            from_ = self.sql(expression, "from")
+            from_ = self.sql(expression, "from_")
             if from_:
                 from_ = f" FROM {from_}"
 
@@ -1824,9 +1830,10 @@ class Snowflake(Dialect):
             return f"SET{exprs}{file_format}{copy_options}{tag}"
 
         def strtotime_sql(self, expression: exp.StrToTime):
-            safe_prefix = "TRY_" if expression.args.get("safe") else ""
             return self.func(
-                f"{safe_prefix}TO_TIMESTAMP", expression.this, self.format_time(expression)
+                f"{'TRY_' if expression.args.get('safe') else ''}TO_TIMESTAMP",
+                expression.this,
+                self.format_time(expression),
             )
 
         def timestampsub_sql(self, expression: exp.TimestampSub):
@@ -1984,3 +1991,17 @@ class Snowflake(Dialect):
                 expression.set("part_index", exp.Literal.number(1))
 
             return rename_func("SPLIT_PART")(self, expression)
+
+        def uniform_sql(self, expression: exp.Uniform) -> str:
+            gen = expression.args.get("gen")
+            seed = expression.args.get("seed")
+
+            # From Databricks UNIFORM(min, max, seed) -> Wrap gen in RANDOM(seed)
+            if seed:
+                gen = exp.Rand(this=seed)
+
+            # No gen argument (from Databricks 2-arg UNIFORM(min, max)) -> Add RANDOM()
+            if not gen:
+                gen = exp.Rand()
+
+            return self.func("UNIFORM", expression.this, expression.expression, gen)
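A rough sketch of the new generator method, constructing the expression directly; the gen/seed arg names are taken from the code above, and the exact output is an expectation rather than an assertion.

from sqlglot import exp

# UNIFORM with only min/max and no generator argument: per uniform_sql above, the
# Snowflake generator is expected to supply a generator (rendered as RANDOM(), per
# the comments in this hunk).
uniform = exp.Uniform(this=exp.Literal.number(1), expression=exp.Literal.number(10))
print(uniform.sql(dialect="snowflake"))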