sqlglot-28.4.1-py3-none-any.whl → sqlglot-28.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. sqlglot/_version.py +2 -2
  2. sqlglot/dialects/bigquery.py +20 -23
  3. sqlglot/dialects/clickhouse.py +2 -0
  4. sqlglot/dialects/dialect.py +355 -18
  5. sqlglot/dialects/doris.py +38 -90
  6. sqlglot/dialects/druid.py +1 -0
  7. sqlglot/dialects/duckdb.py +1739 -163
  8. sqlglot/dialects/exasol.py +17 -1
  9. sqlglot/dialects/hive.py +27 -2
  10. sqlglot/dialects/mysql.py +103 -11
  11. sqlglot/dialects/oracle.py +38 -1
  12. sqlglot/dialects/postgres.py +142 -33
  13. sqlglot/dialects/presto.py +6 -2
  14. sqlglot/dialects/redshift.py +7 -1
  15. sqlglot/dialects/singlestore.py +13 -3
  16. sqlglot/dialects/snowflake.py +271 -21
  17. sqlglot/dialects/spark.py +25 -0
  18. sqlglot/dialects/spark2.py +4 -3
  19. sqlglot/dialects/starrocks.py +152 -17
  20. sqlglot/dialects/trino.py +1 -0
  21. sqlglot/dialects/tsql.py +5 -0
  22. sqlglot/diff.py +1 -1
  23. sqlglot/expressions.py +239 -47
  24. sqlglot/generator.py +173 -44
  25. sqlglot/optimizer/annotate_types.py +129 -60
  26. sqlglot/optimizer/merge_subqueries.py +13 -2
  27. sqlglot/optimizer/qualify_columns.py +7 -0
  28. sqlglot/optimizer/resolver.py +19 -0
  29. sqlglot/optimizer/scope.py +12 -0
  30. sqlglot/optimizer/unnest_subqueries.py +7 -0
  31. sqlglot/parser.py +251 -58
  32. sqlglot/schema.py +186 -14
  33. sqlglot/tokens.py +36 -6
  34. sqlglot/transforms.py +6 -5
  35. sqlglot/typing/__init__.py +29 -10
  36. sqlglot/typing/bigquery.py +5 -10
  37. sqlglot/typing/duckdb.py +39 -0
  38. sqlglot/typing/hive.py +50 -1
  39. sqlglot/typing/mysql.py +32 -0
  40. sqlglot/typing/presto.py +0 -1
  41. sqlglot/typing/snowflake.py +80 -17
  42. sqlglot/typing/spark.py +29 -0
  43. sqlglot/typing/spark2.py +9 -1
  44. sqlglot/typing/tsql.py +21 -0
  45. {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/METADATA +47 -2
  46. sqlglot-28.8.0.dist-info/RECORD +95 -0
  47. {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/WHEEL +1 -1
  48. sqlglot-28.4.1.dist-info/RECORD +0 -92
  49. {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/licenses/LICENSE +0 -0
  50. {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/top_level.txt +0 -0
sqlglot/dialects/snowflake.py CHANGED
@@ -6,6 +6,8 @@ from sqlglot import exp, generator, jsonpath, parser, tokens, transforms
6
6
  from sqlglot.dialects.dialect import (
7
7
  Dialect,
8
8
  NormalizationStrategy,
9
+ array_append_sql,
10
+ array_concat_sql,
9
11
  build_timetostr_or_tochar,
10
12
  build_like,
11
13
  binary_from_function,
@@ -32,7 +34,7 @@ from sqlglot.dialects.dialect import (
32
34
  groupconcat_sql,
33
35
  )
34
36
  from sqlglot.generator import unsupported_args
35
- from sqlglot.helper import find_new_name, flatten, is_float, is_int, seq_get
37
+ from sqlglot.helper import find_new_name, flatten, is_date_unit, is_int, seq_get
36
38
  from sqlglot.optimizer.scope import build_scope, find_all_in_scope
37
39
  from sqlglot.tokens import TokenType
38
40
  from sqlglot.typing.snowflake import EXPRESSION_METADATA
@@ -41,6 +43,15 @@ if t.TYPE_CHECKING:
41
43
  from sqlglot._typing import E, B
42
44
 
43
45
 
46
+ # Timestamp types used in _build_datetime
47
+ TIMESTAMP_TYPES = {
48
+ exp.DataType.Type.TIMESTAMP: "TO_TIMESTAMP",
49
+ exp.DataType.Type.TIMESTAMPLTZ: "TO_TIMESTAMP_LTZ",
50
+ exp.DataType.Type.TIMESTAMPNTZ: "TO_TIMESTAMP_NTZ",
51
+ exp.DataType.Type.TIMESTAMPTZ: "TO_TIMESTAMP_TZ",
52
+ }
53
+
54
+
44
55
  def _build_strtok(args: t.List) -> exp.SplitPart:
45
56
  # Add default delimiter (space) if missing - per Snowflake docs
46
57
  if len(args) == 1:
@@ -68,6 +79,15 @@ def _build_approx_top_k(args: t.List) -> exp.ApproxTopK:
68
79
  return exp.ApproxTopK.from_arg_list(args)
69
80
 
70
81
 
82
+ def _build_date_from_parts(args: t.List) -> exp.DateFromParts:
83
+ return exp.DateFromParts(
84
+ year=seq_get(args, 0),
85
+ month=seq_get(args, 1),
86
+ day=seq_get(args, 2),
87
+ allow_overflow=True,
88
+ )
89
+
90
+
71
91
  def _build_datetime(
72
92
  name: str, kind: exp.DataType.Type, safe: bool = False
73
93
  ) -> t.Callable[[t.List], exp.Func]:
@@ -78,7 +98,7 @@ def _build_datetime(
78
98
  int_value = value is not None and is_int(value.name)
79
99
  int_scale_or_fmt = scale_or_fmt is not None and scale_or_fmt.is_int
80
100
 
81
- if isinstance(value, exp.Literal) or (value and scale_or_fmt):
101
+ if isinstance(value, (exp.Literal, exp.Neg)) or (value and scale_or_fmt):
82
102
  # Converts calls like `TO_TIME('01:02:03')` into casts
83
103
  if len(args) == 1 and value.is_string and not int_value:
84
104
  return (
@@ -89,17 +109,27 @@ def _build_datetime(
89
109
 
90
110
  # Handles `TO_TIMESTAMP(str, fmt)` and `TO_TIMESTAMP(num, scale)` as special
91
111
  # cases so we can transpile them, since they're relatively common
92
- if kind == exp.DataType.Type.TIMESTAMP:
93
- if not safe and (int_value or int_scale_or_fmt):
112
+ if kind in TIMESTAMP_TYPES:
113
+ if not safe and (int_scale_or_fmt or (int_value and scale_or_fmt is None)):
94
114
  # TRY_TO_TIMESTAMP('integer') is not parsed into exp.UnixToTime as
95
- # it's not easily transpilable
96
- return exp.UnixToTime(this=value, scale=scale_or_fmt)
97
- if not int_scale_or_fmt and not is_float(value.name):
98
- expr = build_formatted_time(exp.StrToTime, "snowflake")(args)
99
- expr.set("safe", safe)
100
- return expr
101
-
102
- if kind in (exp.DataType.Type.DATE, exp.DataType.Type.TIME) and not int_value:
115
+ # it's not easily transpilable. Also, numeric-looking strings with
116
+ # format strings (e.g., TO_TIMESTAMP('20240115', 'YYYYMMDD')) should
117
+ # use StrToTime, not UnixToTime.
118
+ unix_expr = exp.UnixToTime(this=value, scale=scale_or_fmt)
119
+ unix_expr.set("target_type", exp.DataType.build(kind, dialect="snowflake"))
120
+ return unix_expr
121
+ if scale_or_fmt and not int_scale_or_fmt:
122
+ # Format string provided (e.g., 'YYYY-MM-DD'), use StrToTime
123
+ strtotime_expr = build_formatted_time(exp.StrToTime, "snowflake")(args)
124
+ strtotime_expr.set("safe", safe)
125
+ strtotime_expr.set("target_type", exp.DataType.build(kind, dialect="snowflake"))
126
+ return strtotime_expr
127
+
128
+ # Handle DATE/TIME with format strings - allow int_value if a format string is provided
129
+ has_format_string = scale_or_fmt and not int_scale_or_fmt
130
+ if kind in (exp.DataType.Type.DATE, exp.DataType.Type.TIME) and (
131
+ not int_value or has_format_string
132
+ ):
103
133
  klass = exp.TsOrDsToDate if kind == exp.DataType.Type.DATE else exp.TsOrDsToTime
104
134
  formatted_exp = build_formatted_time(klass, "snowflake")(args)
105
135
  formatted_exp.set("safe", safe)
@@ -125,7 +155,10 @@ def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
125
155
 
126
156
  def _build_datediff(args: t.List) -> exp.DateDiff:
127
157
  return exp.DateDiff(
128
- this=seq_get(args, 2), expression=seq_get(args, 1), unit=map_date_part(seq_get(args, 0))
158
+ this=seq_get(args, 2),
159
+ expression=seq_get(args, 1),
160
+ unit=map_date_part(seq_get(args, 0)),
161
+ date_part_boundary=True,
129
162
  )
130
163
 
131
164
 
@@ -150,7 +183,13 @@ def _build_bitwise(expr_type: t.Type[B], name: str) -> t.Callable[[t.List], B |
150
183
  )
151
184
  return exp.Anonymous(this=name, expressions=args)
152
185
 
153
- return binary_from_function(expr_type)(args)
186
+ result = binary_from_function(expr_type)(args)
187
+
188
+ # Snowflake specifies INT128 for bitwise shifts
189
+ if expr_type in (exp.BitwiseLeftShift, exp.BitwiseRightShift):
190
+ result.set("requires_int128", True)
191
+
192
+ return result
154
193
 
155
194
  return _builder
156
195
 
@@ -232,7 +271,13 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser]
232
271
 
233
272
  def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
234
273
  trunc = date_trunc_to_time(args)
235
- trunc.set("unit", map_date_part(trunc.args["unit"]))
274
+ unit = map_date_part(trunc.args["unit"])
275
+ trunc.set("unit", unit)
276
+ is_time_input = trunc.this.is_type(exp.DataType.Type.TIME, exp.DataType.Type.TIMETZ)
277
+ if (isinstance(trunc, exp.TimestampTrunc) and is_date_unit(unit) or is_time_input) or (
278
+ isinstance(trunc, exp.DateTrunc) and not is_date_unit(unit)
279
+ ):
280
+ trunc.set("input_type_preserved", True)
236
281
  return trunc
237
282
 
238
283
 
@@ -595,6 +640,25 @@ def _build_round(args: t.List) -> exp.Round:
595
640
  return expression
596
641
 
597
642
 
643
+ def _build_generator(args: t.List) -> exp.Generator:
644
+ """
645
+ Build Generator expression, unwrapping Snowflake's named parameters.
646
+
647
+ Maps ROWCOUNT => rowcount, TIMELIMIT => time_limit.
648
+ """
649
+ kwarg_map = {"ROWCOUNT": "rowcount", "TIMELIMIT": "time_limit"}
650
+ gen_args = {}
651
+
652
+ for arg in args:
653
+ if isinstance(arg, exp.Kwarg):
654
+ key = arg.this.name.upper()
655
+ gen_key = kwarg_map.get(key)
656
+ if gen_key:
657
+ gen_args[gen_key] = arg.expression
658
+
659
+ return exp.Generator(**gen_args)
660
+
661
+
598
662
  def _build_try_to_number(args: t.List[exp.Expression]) -> exp.Expression:
599
663
  return exp.ToNumber(
600
664
  this=seq_get(args, 0),
@@ -616,15 +680,21 @@ class Snowflake(Dialect):
616
680
  TABLESAMPLE_SIZE_IS_PERCENT = True
617
681
  COPY_PARAMS_ARE_CSV = False
618
682
  ARRAY_AGG_INCLUDES_NULLS = None
683
+ ARRAY_FUNCS_PROPAGATES_NULLS = True
619
684
  ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
620
685
  TRY_CAST_REQUIRES_STRING = True
621
686
  SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS = True
687
+ LEAST_GREATEST_IGNORES_NULLS = False
622
688
 
623
689
  EXPRESSION_METADATA = EXPRESSION_METADATA.copy()
624
690
 
625
691
  # https://docs.snowflake.com/en/en/sql-reference/functions/initcap
626
692
  INITCAP_DEFAULT_DELIMITER_CHARS = ' \t\n\r\f\v!?@"^#$&~_,.:;+\\-*%/|\\[\\](){}<>'
627
693
 
694
+ INVERSE_TIME_MAPPING = {
695
+ "T": "T", # in TIME_MAPPING we map '"T"' with the double quotes to 'T', and we want to prevent 'T' from being mapped back to '"T"' so that 'AUTO' doesn't become 'AU"T"O'
696
+ }
697
+
628
698
  TIME_MAPPING = {
629
699
  "YYYY": "%Y",
630
700
  "yyyy": "%Y",
@@ -648,13 +718,55 @@ class Snowflake(Dialect):
648
718
  "mi": "%M",
649
719
  "SS": "%S",
650
720
  "ss": "%S",
721
+ "FF": "%f_nine", # %f_ internal representation with precision specified
722
+ "ff": "%f_nine",
723
+ "FF0": "%f_zero",
724
+ "ff0": "%f_zero",
725
+ "FF1": "%f_one",
726
+ "ff1": "%f_one",
727
+ "FF2": "%f_two",
728
+ "ff2": "%f_two",
729
+ "FF3": "%f_three",
730
+ "ff3": "%f_three",
731
+ "FF4": "%f_four",
732
+ "ff4": "%f_four",
733
+ "FF5": "%f_five",
734
+ "ff5": "%f_five",
651
735
  "FF6": "%f",
652
736
  "ff6": "%f",
737
+ "FF7": "%f_seven",
738
+ "ff7": "%f_seven",
739
+ "FF8": "%f_eight",
740
+ "ff8": "%f_eight",
741
+ "FF9": "%f_nine",
742
+ "ff9": "%f_nine",
743
+ "TZHTZM": "%z",
744
+ "tzhtzm": "%z",
745
+ "TZH:TZM": "%:z", # internal representation for ±HH:MM
746
+ "tzh:tzm": "%:z",
747
+ "TZH": "%-z", # internal representation ±HH
748
+ "tzh": "%-z",
749
+ '"T"': "T", # remove the optional double quotes around the separator between the date and time
750
+ # Seems like Snowflake treats AM/PM in the format string as equivalent,
751
+ # only the time (stamp) value's AM/PM affects the output
752
+ "AM": "%p",
753
+ "am": "%p",
754
+ "PM": "%p",
755
+ "pm": "%p",
653
756
  }
654
757
 
655
758
  DATE_PART_MAPPING = {
656
759
  **Dialect.DATE_PART_MAPPING,
657
760
  "ISOWEEK": "WEEKISO",
761
+ # The base Dialect maps EPOCH_SECOND -> EPOCH, but we need to preserve
762
+ # EPOCH_SECOND as a distinct value for two reasons:
763
+ # 1. Type annotation: EPOCH_SECOND returns BIGINT, while EPOCH returns DOUBLE
764
+ # 2. Transpilation: DuckDB's EPOCH() returns float, so we cast EPOCH_SECOND
765
+ # to BIGINT to match Snowflake's integer behavior
766
+ # Without this override, EXTRACT(EPOCH_SECOND FROM ts) would be normalized
767
+ # to EXTRACT(EPOCH FROM ts) and lose the integer semantics.
768
+ "EPOCH_SECOND": "EPOCH_SECOND",
769
+ "EPOCH_SECONDS": "EPOCH_SECOND",
658
770
  }
659
771
 
660
772
  PSEUDOCOLUMNS = {"LEVEL"}
@@ -689,9 +801,20 @@ class Snowflake(Dialect):
689
801
 
690
802
  COLON_PLACEHOLDER_TOKENS = ID_VAR_TOKENS | {TokenType.NUMBER}
691
803
 
804
+ NO_PAREN_FUNCTIONS = {
805
+ **parser.Parser.NO_PAREN_FUNCTIONS,
806
+ TokenType.CURRENT_TIME: exp.Localtime,
807
+ }
808
+
692
809
  FUNCTIONS = {
693
810
  **parser.Parser.FUNCTIONS,
811
+ "ADD_MONTHS": lambda args: exp.AddMonths(
812
+ this=seq_get(args, 0),
813
+ expression=seq_get(args, 1),
814
+ preserve_end_of_month=True,
815
+ ),
694
816
  "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
817
+ "CURRENT_TIME": lambda args: exp.Localtime(this=seq_get(args, 0)),
695
818
  "APPROX_TOP_K": _build_approx_top_k,
696
819
  "ARRAY_CONSTRUCT": lambda args: exp.Array(expressions=args),
697
820
  "ARRAY_CONTAINS": lambda args: exp.ArrayContains(
@@ -704,6 +827,7 @@ class Snowflake(Dialect):
704
827
  step=seq_get(args, 2),
705
828
  ),
706
829
  "ARRAY_SORT": exp.SortArray.from_arg_list,
830
+ "ARRAY_FLATTEN": exp.Flatten.from_arg_list,
707
831
  "BITAND": _build_bitwise(exp.BitwiseAnd, "BITAND"),
708
832
  "BIT_AND": _build_bitwise(exp.BitwiseAnd, "BITAND"),
709
833
  "BITNOT": lambda args: exp.BitwiseNot(this=seq_get(args, 0)),
@@ -729,11 +853,28 @@ class Snowflake(Dialect):
729
853
  "BIT_XOR_AGG": exp.BitwiseXorAgg.from_arg_list,
730
854
  "BIT_XORAGG": exp.BitwiseXorAgg.from_arg_list,
731
855
  "BITMAP_OR_AGG": exp.BitmapOrAgg.from_arg_list,
732
- "BOOLXOR": _build_bitwise(exp.Xor, "BOOLXOR"),
856
+ "BOOLAND": lambda args: exp.Booland(
857
+ this=seq_get(args, 0), expression=seq_get(args, 1), round_input=True
858
+ ),
859
+ "BOOLOR": lambda args: exp.Boolor(
860
+ this=seq_get(args, 0), expression=seq_get(args, 1), round_input=True
861
+ ),
862
+ "BOOLNOT": lambda args: exp.Boolnot(this=seq_get(args, 0), round_input=True),
863
+ "BOOLXOR": lambda args: exp.Xor(
864
+ this=seq_get(args, 0), expression=seq_get(args, 1), round_input=True
865
+ ),
866
+ "CORR": lambda args: exp.Corr(
867
+ this=seq_get(args, 0),
868
+ expression=seq_get(args, 1),
869
+ null_on_zero_variance=True,
870
+ ),
733
871
  "DATE": _build_datetime("DATE", exp.DataType.Type.DATE),
872
+ "DATEFROMPARTS": _build_date_from_parts,
873
+ "DATE_FROM_PARTS": _build_date_from_parts,
734
874
  "DATE_TRUNC": _date_trunc_to_time,
735
875
  "DATEADD": _build_date_time_add(exp.DateAdd),
736
876
  "DATEDIFF": _build_datediff,
877
+ "DAYNAME": lambda args: exp.Dayname(this=seq_get(args, 0), abbreviated=True),
737
878
  "DAYOFWEEKISO": exp.DayOfWeekIso.from_arg_list,
738
879
  "DIV0": _build_if_from_div0,
739
880
  "DIV0NULL": _build_if_from_div0null,
@@ -741,6 +882,7 @@ class Snowflake(Dialect):
741
882
  this=seq_get(args, 0), expression=seq_get(args, 1), max_dist=seq_get(args, 2)
742
883
  ),
743
884
  "FLATTEN": exp.Explode.from_arg_list,
885
+ "GENERATOR": _build_generator,
744
886
  "GET": exp.GetExtract.from_arg_list,
745
887
  "GETDATE": exp.CurrentTimestamp.from_arg_list,
746
888
  "GET_PATH": lambda args, dialect: exp.JSONExtract(
@@ -748,19 +890,28 @@ class Snowflake(Dialect):
748
890
  expression=dialect.to_json_path(seq_get(args, 1)),
749
891
  requires_json=True,
750
892
  ),
893
+ "GREATEST_IGNORE_NULLS": lambda args: exp.Greatest(
894
+ this=seq_get(args, 0), expressions=args[1:], ignore_nulls=True
895
+ ),
896
+ "LEAST_IGNORE_NULLS": lambda args: exp.Least(
897
+ this=seq_get(args, 0), expressions=args[1:], ignore_nulls=True
898
+ ),
751
899
  "HEX_DECODE_BINARY": exp.Unhex.from_arg_list,
752
900
  "IFF": exp.If.from_arg_list,
753
901
  "MD5_HEX": exp.MD5.from_arg_list,
754
902
  "MD5_BINARY": exp.MD5Digest.from_arg_list,
755
903
  "MD5_NUMBER_LOWER64": exp.MD5NumberLower64.from_arg_list,
756
904
  "MD5_NUMBER_UPPER64": exp.MD5NumberUpper64.from_arg_list,
905
+ "MONTHNAME": lambda args: exp.Monthname(this=seq_get(args, 0), abbreviated=True),
757
906
  "LAST_DAY": lambda args: exp.LastDay(
758
907
  this=seq_get(args, 0), unit=map_date_part(seq_get(args, 1))
759
908
  ),
760
909
  "LEN": lambda args: exp.Length(this=seq_get(args, 0), binary=True),
761
910
  "LENGTH": lambda args: exp.Length(this=seq_get(args, 0), binary=True),
911
+ "LOCALTIMESTAMP": exp.CurrentTimestamp.from_arg_list,
762
912
  "NULLIFZERO": _build_if_from_nullifzero,
763
913
  "OBJECT_CONSTRUCT": _build_object_construct,
914
+ "OBJECT_KEYS": exp.JSONKeys.from_arg_list,
764
915
  "OCTET_LENGTH": exp.ByteLength.from_arg_list,
765
916
  "PARSE_URL": lambda args: exp.ParseUrl(
766
917
  this=seq_get(args, 0), permissive=seq_get(args, 1)
@@ -777,16 +928,41 @@ class Snowflake(Dialect):
777
928
  "SHA2_BINARY": exp.SHA2Digest.from_arg_list,
778
929
  "SHA2_HEX": exp.SHA2.from_arg_list,
779
930
  "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
931
+ "STDDEV_SAMP": exp.Stddev.from_arg_list,
780
932
  "STRTOK": _build_strtok,
933
+ "SYSDATE": lambda args: exp.CurrentTimestamp(this=seq_get(args, 0), sysdate=True),
781
934
  "TABLE": lambda args: exp.TableFromRows(this=seq_get(args, 0)),
782
935
  "TIMEADD": _build_date_time_add(exp.TimeAdd),
783
936
  "TIMEDIFF": _build_datediff,
937
+ "TIME_FROM_PARTS": lambda args: exp.TimeFromParts(
938
+ hour=seq_get(args, 0),
939
+ min=seq_get(args, 1),
940
+ sec=seq_get(args, 2),
941
+ nano=seq_get(args, 3),
942
+ overflow=True,
943
+ ),
784
944
  "TIMESTAMPADD": _build_date_time_add(exp.DateAdd),
785
945
  "TIMESTAMPDIFF": _build_datediff,
786
946
  "TIMESTAMPFROMPARTS": _build_timestamp_from_parts,
787
947
  "TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts,
788
948
  "TIMESTAMPNTZFROMPARTS": _build_timestamp_from_parts,
789
949
  "TIMESTAMP_NTZ_FROM_PARTS": _build_timestamp_from_parts,
950
+ "TRY_DECRYPT": lambda args: exp.Decrypt(
951
+ this=seq_get(args, 0),
952
+ passphrase=seq_get(args, 1),
953
+ aad=seq_get(args, 2),
954
+ encryption_method=seq_get(args, 3),
955
+ safe=True,
956
+ ),
957
+ "TRY_DECRYPT_RAW": lambda args: exp.DecryptRaw(
958
+ this=seq_get(args, 0),
959
+ key=seq_get(args, 1),
960
+ iv=seq_get(args, 2),
961
+ aad=seq_get(args, 3),
962
+ encryption_method=seq_get(args, 4),
963
+ aead=seq_get(args, 5),
964
+ safe=True,
965
+ ),
790
966
  "TRY_PARSE_JSON": lambda args: exp.ParseJSON(this=seq_get(args, 0), safe=True),
791
967
  "TRY_TO_BINARY": lambda args: exp.ToBinary(
792
968
  this=seq_get(args, 0), format=seq_get(args, 1), safe=True
@@ -806,6 +982,15 @@ class Snowflake(Dialect):
806
982
  "TRY_TO_TIMESTAMP": _build_datetime(
807
983
  "TRY_TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP, safe=True
808
984
  ),
985
+ "TRY_TO_TIMESTAMP_LTZ": _build_datetime(
986
+ "TRY_TO_TIMESTAMP_LTZ", exp.DataType.Type.TIMESTAMPLTZ, safe=True
987
+ ),
988
+ "TRY_TO_TIMESTAMP_NTZ": _build_datetime(
989
+ "TRY_TO_TIMESTAMP_NTZ", exp.DataType.Type.TIMESTAMPNTZ, safe=True
990
+ ),
991
+ "TRY_TO_TIMESTAMP_TZ": _build_datetime(
992
+ "TRY_TO_TIMESTAMP_TZ", exp.DataType.Type.TIMESTAMPTZ, safe=True
993
+ ),
809
994
  "TO_CHAR": build_timetostr_or_tochar,
810
995
  "TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
811
996
  **dict.fromkeys(
@@ -820,7 +1005,7 @@ class Snowflake(Dialect):
820
1005
  "TO_TIME": _build_datetime("TO_TIME", exp.DataType.Type.TIME),
821
1006
  "TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP),
822
1007
  "TO_TIMESTAMP_LTZ": _build_datetime("TO_TIMESTAMP_LTZ", exp.DataType.Type.TIMESTAMPLTZ),
823
- "TO_TIMESTAMP_NTZ": _build_datetime("TO_TIMESTAMP_NTZ", exp.DataType.Type.TIMESTAMP),
1008
+ "TO_TIMESTAMP_NTZ": _build_datetime("TO_TIMESTAMP_NTZ", exp.DataType.Type.TIMESTAMPNTZ),
824
1009
  "TO_TIMESTAMP_TZ": _build_datetime("TO_TIMESTAMP_TZ", exp.DataType.Type.TIMESTAMPTZ),
825
1010
  "TO_VARCHAR": build_timetostr_or_tochar,
826
1011
  "TO_JSON": exp.JSONFormat.from_arg_list,
@@ -1045,7 +1230,9 @@ class Snowflake(Dialect):
1045
1230
  expression = (
1046
1231
  self._match_set((TokenType.FROM, TokenType.COMMA)) and self._parse_bitwise()
1047
1232
  )
1048
- return self.expression(exp.Extract, this=map_date_part(this), expression=expression)
1233
+ return self.expression(
1234
+ exp.Extract, this=map_date_part(this, self.dialect), expression=expression
1235
+ )
1049
1236
 
1050
1237
  def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
1051
1238
  if is_map:
@@ -1387,6 +1574,7 @@ class Snowflake(Dialect):
1387
1574
  ARRAY_SIZE_NAME = "ARRAY_SIZE"
1388
1575
  SUPPORTS_DECODE_CASE = True
1389
1576
  IS_BOOL_ALLOWED = False
1577
+ DIRECTED_JOINS = True
1390
1578
 
1391
1579
  TRANSFORMS = {
1392
1580
  **generator.Generator.TRANSFORMS,
@@ -1394,7 +1582,9 @@ class Snowflake(Dialect):
1394
1582
  exp.ArgMax: rename_func("MAX_BY"),
1395
1583
  exp.ArgMin: rename_func("MIN_BY"),
1396
1584
  exp.Array: transforms.preprocess([transforms.inherit_struct_field_names]),
1397
- exp.ArrayConcat: lambda self, e: self.arrayconcat_sql(e, name="ARRAY_CAT"),
1585
+ exp.ArrayConcat: array_concat_sql("ARRAY_CAT"),
1586
+ exp.ArrayAppend: array_append_sql("ARRAY_APPEND"),
1587
+ exp.ArrayPrepend: array_append_sql("ARRAY_PREPEND"),
1398
1588
  exp.ArrayContains: lambda self, e: self.func(
1399
1589
  "ARRAY_CONTAINS",
1400
1590
  e.expression
@@ -1416,11 +1606,36 @@ class Snowflake(Dialect):
1416
1606
  exp.BitwiseLeftShift: rename_func("BITSHIFTLEFT"),
1417
1607
  exp.BitwiseRightShift: rename_func("BITSHIFTRIGHT"),
1418
1608
  exp.Create: transforms.preprocess([_flatten_structured_types_unless_iceberg]),
1609
+ exp.CurrentTimestamp: lambda self, e: self.func("SYSDATE")
1610
+ if e.args.get("sysdate")
1611
+ else self.function_fallback_sql(e),
1612
+ exp.Localtime: lambda self, e: self.func("CURRENT_TIME", e.this)
1613
+ if e.this
1614
+ else "CURRENT_TIME",
1615
+ exp.Localtimestamp: lambda self, e: self.func("CURRENT_TIMESTAMP", e.this)
1616
+ if e.this
1617
+ else "CURRENT_TIMESTAMP",
1419
1618
  exp.DateAdd: date_delta_sql("DATEADD"),
1420
1619
  exp.DateDiff: date_delta_sql("DATEDIFF"),
1421
1620
  exp.DatetimeAdd: date_delta_sql("TIMESTAMPADD"),
1422
1621
  exp.DatetimeDiff: timestampdiff_sql,
1423
1622
  exp.DateStrToDate: datestrtodate_sql,
1623
+ exp.Decrypt: lambda self, e: self.func(
1624
+ f"{'TRY_' if e.args.get('safe') else ''}DECRYPT",
1625
+ e.this,
1626
+ e.args.get("passphrase"),
1627
+ e.args.get("aad"),
1628
+ e.args.get("encryption_method"),
1629
+ ),
1630
+ exp.DecryptRaw: lambda self, e: self.func(
1631
+ f"{'TRY_' if e.args.get('safe') else ''}DECRYPT_RAW",
1632
+ e.this,
1633
+ e.args.get("key"),
1634
+ e.args.get("iv"),
1635
+ e.args.get("aad"),
1636
+ e.args.get("encryption_method"),
1637
+ e.args.get("aead"),
1638
+ ),
1424
1639
  exp.DayOfMonth: rename_func("DAYOFMONTH"),
1425
1640
  exp.DayOfWeek: rename_func("DAYOFWEEK"),
1426
1641
  exp.DayOfWeekIso: rename_func("DAYOFWEEKISO"),
@@ -1447,6 +1662,7 @@ class Snowflake(Dialect):
1447
1662
  exp.JSONExtractScalar: lambda self, e: self.func(
1448
1663
  "JSON_EXTRACT_PATH_TEXT", e.this, e.expression
1449
1664
  ),
1665
+ exp.JSONKeys: rename_func("OBJECT_KEYS"),
1450
1666
  exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions),
1451
1667
  exp.JSONPathRoot: lambda *_: "",
1452
1668
  exp.JSONValueArray: _json_extract_value_array_sql,
@@ -1552,7 +1768,7 @@ class Snowflake(Dialect):
1552
1768
  f"{'TRY_' if e.args.get('safe') else ''}TO_TIME", e.this, self.format_time(e)
1553
1769
  ),
1554
1770
  exp.Unhex: rename_func("HEX_DECODE_BINARY"),
1555
- exp.UnixToTime: rename_func("TO_TIMESTAMP"),
1771
+ exp.UnixToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, e.args.get("scale")),
1556
1772
  exp.Uuid: rename_func("UUID_STRING"),
1557
1773
  exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
1558
1774
  exp.Booland: rename_func("BOOLAND"),
@@ -1562,6 +1778,7 @@ class Snowflake(Dialect):
1562
1778
  exp.YearOfWeekIso: rename_func("YEAROFWEEKISO"),
1563
1779
  exp.Xor: rename_func("BOOLXOR"),
1564
1780
  exp.ByteLength: rename_func("OCTET_LENGTH"),
1781
+ exp.Flatten: rename_func("ARRAY_FLATTEN"),
1565
1782
  exp.ArrayConcatAgg: lambda self, e: self.func(
1566
1783
  "ARRAY_FLATTEN", exp.ArrayAgg(this=e.this)
1567
1784
  ),
@@ -1693,6 +1910,26 @@ class Snowflake(Dialect):
1693
1910
 
1694
1911
  return super().log_sql(expression)
1695
1912
 
1913
+ def greatest_sql(self, expression: exp.Greatest) -> str:
1914
+ name = "GREATEST_IGNORE_NULLS" if expression.args.get("ignore_nulls") else "GREATEST"
1915
+ return self.func(name, expression.this, *expression.expressions)
1916
+
1917
+ def least_sql(self, expression: exp.Least) -> str:
1918
+ name = "LEAST_IGNORE_NULLS" if expression.args.get("ignore_nulls") else "LEAST"
1919
+ return self.func(name, expression.this, *expression.expressions)
1920
+
1921
+ def generator_sql(self, expression: exp.Generator) -> str:
1922
+ args = []
1923
+ rowcount = expression.args.get("rowcount")
1924
+ time_limit = expression.args.get("time_limit")
1925
+
1926
+ if rowcount:
1927
+ args.append(exp.Kwarg(this=exp.var("ROWCOUNT"), expression=rowcount))
1928
+ if time_limit:
1929
+ args.append(exp.Kwarg(this=exp.var("TIMELIMIT"), expression=time_limit))
1930
+
1931
+ return self.func("GENERATOR", *args)
1932
+
1696
1933
  def unnest_sql(self, expression: exp.Unnest) -> str:
1697
1934
  unnest_alias = expression.args.get("alias")
1698
1935
  offset = expression.args.get("offset")
@@ -1830,8 +2067,21 @@ class Snowflake(Dialect):
1830
2067
  return f"SET{exprs}{file_format}{copy_options}{tag}"
1831
2068
 
1832
2069
  def strtotime_sql(self, expression: exp.StrToTime):
2070
+ # target_type is stored as a DataType instance
2071
+ target_type = expression.args.get("target_type")
2072
+
2073
+ # Get the type enum from DataType instance or from type annotation
2074
+ if isinstance(target_type, exp.DataType):
2075
+ type_enum = target_type.this
2076
+ elif expression.type:
2077
+ type_enum = expression.type.this
2078
+ else:
2079
+ type_enum = exp.DataType.Type.TIMESTAMP
2080
+
2081
+ func_name = TIMESTAMP_TYPES.get(type_enum, "TO_TIMESTAMP")
2082
+
1833
2083
  return self.func(
1834
- f"{'TRY_' if expression.args.get('safe') else ''}TO_TIMESTAMP",
2084
+ f"{'TRY_' if expression.args.get('safe') else ''}{func_name}",
1835
2085
  expression.this,
1836
2086
  self.format_time(expression),
1837
2087
  )
sqlglot/dialects/spark.py CHANGED
@@ -4,6 +4,7 @@ import typing as t
4
4
 
5
5
  from sqlglot import exp
6
6
  from sqlglot.dialects.dialect import (
7
+ array_append_sql,
7
8
  rename_func,
8
9
  build_like,
9
10
  unit_to_var,
@@ -14,6 +15,7 @@ from sqlglot.dialects.dialect import (
14
15
  )
15
16
  from sqlglot.dialects.hive import _build_with_ignore_nulls
16
17
  from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider, _build_as_cast
18
+ from sqlglot.typing.spark import EXPRESSION_METADATA
17
19
  from sqlglot.helper import ensure_list, seq_get
18
20
  from sqlglot.tokens import TokenType
19
21
  from sqlglot.transforms import (
@@ -112,6 +114,8 @@ def _groupconcat_sql(self: Spark.Generator, expression: exp.GroupConcat) -> str:
112
114
  class Spark(Spark2):
113
115
  SUPPORTS_ORDER_BY_ALL = True
114
116
  SUPPORTS_NULL_TYPE = True
117
+ ARRAY_FUNCS_PROPAGATES_NULLS = True
118
+ EXPRESSION_METADATA = EXPRESSION_METADATA.copy()
115
119
 
116
120
  class Tokenizer(Spark2.Tokenizer):
117
121
  STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
@@ -126,6 +130,12 @@ class Spark(Spark2):
126
130
  FUNCTIONS = {
127
131
  **Spark2.Parser.FUNCTIONS,
128
132
  "ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
133
+ "ARRAY_INSERT": lambda args: exp.ArrayInsert(
134
+ this=seq_get(args, 0),
135
+ position=seq_get(args, 1),
136
+ expression=seq_get(args, 2),
137
+ offset=1,
138
+ ),
129
139
  "BIT_AND": exp.BitwiseAndAgg.from_arg_list,
130
140
  "BIT_OR": exp.BitwiseOrAgg.from_arg_list,
131
141
  "BIT_XOR": exp.BitwiseXorAgg.from_arg_list,
@@ -139,6 +149,7 @@ class Spark(Spark2):
139
149
  "TRY_SUBTRACT": exp.SafeSubtract.from_arg_list,
140
150
  "DATEDIFF": _build_datediff,
141
151
  "DATE_DIFF": _build_datediff,
152
+ "JSON_OBJECT_KEYS": exp.JSONKeys.from_arg_list,
142
153
  "LISTAGG": exp.GroupConcat.from_arg_list,
143
154
  "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
144
155
  "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
@@ -162,6 +173,11 @@ class Spark(Spark2):
162
173
  self._match(TokenType.R_BRACE)
163
174
  return self.expression(exp.Placeholder, this=this, widget=True)
164
175
 
176
+ FUNCTION_PARSERS = {
177
+ **Spark2.Parser.FUNCTION_PARSERS,
178
+ "SUBSTR": lambda self: self._parse_substring(),
179
+ }
180
+
165
181
  def _parse_generated_as_identity(
166
182
  self,
167
183
  ) -> (
@@ -174,6 +190,12 @@ class Spark(Spark2):
174
190
  return self.expression(exp.ComputedColumnConstraint, this=this.expression)
175
191
  return this
176
192
 
193
+ def _parse_pivot_aggregation(self) -> t.Optional[exp.Expression]:
194
+ # Spark 3+ and Databricks support non aggregate functions in PIVOT too, e.g
195
+ # PIVOT (..., 'foo' AS bar FOR col_to_pivot IN (...))
196
+ aggregate_expr = self._parse_function() or self._parse_disjunction()
197
+ return self._parse_alias(aggregate_expr)
198
+
177
199
  class Generator(Spark2.Generator):
178
200
  SUPPORTS_TO_NUMBER = True
179
201
  PAD_FILL_PATTERN_IS_REQUIRED = False
@@ -196,6 +218,8 @@ class Spark(Spark2):
196
218
  exp.ArrayConstructCompact: lambda self, e: self.func(
197
219
  "ARRAY_COMPACT", self.func("ARRAY", *e.expressions)
198
220
  ),
221
+ exp.ArrayAppend: array_append_sql("ARRAY_APPEND"),
222
+ exp.ArrayPrepend: array_append_sql("ARRAY_PREPEND"),
199
223
  exp.BitwiseAndAgg: rename_func("BIT_AND"),
200
224
  exp.BitwiseOrAgg: rename_func("BIT_OR"),
201
225
  exp.BitwiseXorAgg: rename_func("BIT_XOR"),
@@ -214,6 +238,7 @@ class Spark(Spark2):
214
238
  exp.DatetimeSub: date_delta_to_binary_interval_op(cast=False),
215
239
  exp.GroupConcat: _groupconcat_sql,
216
240
  exp.EndsWith: rename_func("ENDSWITH"),
241
+ exp.JSONKeys: rename_func("JSON_OBJECT_KEYS"),
217
242
  exp.PartitionedByProperty: lambda self,
218
243
  e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
219
244
  exp.SafeAdd: rename_func("TRY_ADD"),
sqlglot/dialects/spark2.py CHANGED
@@ -9,11 +9,11 @@ from sqlglot.dialects.dialect import (
9
9
  is_parse_json,
10
10
  pivot_column_names,
11
11
  rename_func,
12
- trim_sql,
13
12
  unit_to_str,
14
13
  )
15
14
  from sqlglot.dialects.hive import Hive
16
15
  from sqlglot.helper import seq_get
16
+ from sqlglot.parser import build_trim
17
17
  from sqlglot.tokens import TokenType
18
18
  from sqlglot.transforms import (
19
19
  preprocess,
@@ -139,7 +139,6 @@ class Spark2(Hive):
139
139
  FUNCTIONS = {
140
140
  **Hive.Parser.FUNCTIONS,
141
141
  "AGGREGATE": exp.Reduce.from_arg_list,
142
- "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
143
142
  "BOOLEAN": _build_as_cast("boolean"),
144
143
  "DATE": _build_as_cast("date"),
145
144
  "DATE_TRUNC": lambda args: exp.TimestampTrunc(
@@ -159,9 +158,11 @@ class Spark2(Hive):
159
158
  ),
160
159
  zone=seq_get(args, 1),
161
160
  ),
161
+ "LTRIM": lambda args: build_trim(args, reverse_args=True),
162
162
  "INT": _build_as_cast("int"),
163
163
  "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
164
164
  "RLIKE": exp.RegexpLike.from_arg_list,
165
+ "RTRIM": lambda args: build_trim(args, is_left=False, reverse_args=True),
165
166
  "SHIFTLEFT": binary_from_function(exp.BitwiseLeftShift),
166
167
  "SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift),
167
168
  "STRING": _build_as_cast("string"),
@@ -187,6 +188,7 @@ class Spark2(Hive):
187
188
 
188
189
  FUNCTION_PARSERS = {
189
190
  **Hive.Parser.FUNCTION_PARSERS,
191
+ "APPROX_PERCENTILE": lambda self: self._parse_quantile_function(exp.ApproxQuantile),
190
192
  "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
191
193
  "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
192
194
  "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -288,7 +290,6 @@ class Spark2(Hive):
288
290
  exp.StrToDate: _str_to_date,
289
291
  exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
290
292
  exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
291
- exp.Trim: trim_sql,
292
293
  exp.UnixToTime: _unix_to_time_sql,
293
294
  exp.VariancePop: rename_func("VAR_POP"),
294
295
  exp.WeekOfYear: rename_func("WEEKOFYEAR"),