snowpark-connect 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of snowpark-connect might be problematic.

Files changed (38)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/column_name_handler.py +73 -100
  3. snowflake/snowpark_connect/column_qualifier.py +47 -0
  4. snowflake/snowpark_connect/dataframe_container.py +3 -2
  5. snowflake/snowpark_connect/execute_plan/map_execution_command.py +4 -2
  6. snowflake/snowpark_connect/expression/map_expression.py +5 -4
  7. snowflake/snowpark_connect/expression/map_extension.py +12 -6
  8. snowflake/snowpark_connect/expression/map_sql_expression.py +38 -3
  9. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +5 -5
  10. snowflake/snowpark_connect/expression/map_unresolved_function.py +869 -107
  11. snowflake/snowpark_connect/expression/map_unresolved_star.py +9 -7
  12. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +4 -1
  13. snowflake/snowpark_connect/relation/map_aggregate.py +8 -5
  14. snowflake/snowpark_connect/relation/map_column_ops.py +4 -3
  15. snowflake/snowpark_connect/relation/map_extension.py +10 -9
  16. snowflake/snowpark_connect/relation/map_join.py +5 -2
  17. snowflake/snowpark_connect/relation/map_sql.py +33 -1
  18. snowflake/snowpark_connect/relation/map_subquery_alias.py +4 -1
  19. snowflake/snowpark_connect/relation/read/map_read_table.py +6 -3
  20. snowflake/snowpark_connect/relation/write/map_write.py +29 -14
  21. snowflake/snowpark_connect/server.py +1 -2
  22. snowflake/snowpark_connect/type_mapping.py +36 -3
  23. snowflake/snowpark_connect/typed_column.py +8 -6
  24. snowflake/snowpark_connect/utils/session.py +19 -3
  25. snowflake/snowpark_connect/version.py +1 -1
  26. snowflake/snowpark_decoder/dp_session.py +1 -1
  27. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/METADATA +5 -2
  28. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/RECORD +36 -37
  29. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
  30. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
  31. {snowpark_connect-0.31.0.data → snowpark_connect-0.32.0.data}/scripts/snowpark-connect +0 -0
  32. {snowpark_connect-0.31.0.data → snowpark_connect-0.32.0.data}/scripts/snowpark-session +0 -0
  33. {snowpark_connect-0.31.0.data → snowpark_connect-0.32.0.data}/scripts/snowpark-submit +0 -0
  34. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/WHEEL +0 -0
  35. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/LICENSE-binary +0 -0
  36. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/LICENSE.txt +0 -0
  37. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/NOTICE-binary +0 -0
  38. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/top_level.txt +0 -0

--- a/snowflake/snowpark_connect/expression/map_unresolved_function.py (0.31.0)
+++ b/snowflake/snowpark_connect/expression/map_unresolved_function.py (0.32.0)
@@ -20,7 +20,7 @@ from contextlib import suppress
 from decimal import ROUND_HALF_EVEN, ROUND_HALF_UP, Context, Decimal
 from functools import partial, reduce
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional
 from urllib.parse import quote, unquote
 
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
@@ -66,6 +66,7 @@ from snowflake.snowpark.types import (
     TimestampType,
     VariantType,
     YearMonthIntervalType,
+    _AnsiIntervalType,
     _FractionalType,
     _IntegralType,
     _NumericType,
@@ -74,6 +75,7 @@ from snowflake.snowpark_connect.column_name_handler import (
     ColumnNameMap,
     set_schema_getter,
 )
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.config import (
     get_boolean_session_config_param,
     get_timestamp_type,
@@ -148,7 +150,11 @@ from snowflake.snowpark_connect.utils.xxhash64 import (
 MAX_UINT64 = 2**64 - 1
 MAX_INT64 = 2**63 - 1
 MIN_INT64 = -(2**63)
-MAX_ARRAY_SIZE = 2_147_483_647
+MAX_32BIT_SIGNED_INT = 2_147_483_647
+
+# Interval arithmetic precision limits
+MAX_DAY_TIME_DAYS = 106751991  # Maximum days for day-time intervals
+MAX_10_DIGIT_LIMIT = 1000000000  # 10-digit limit (1 billion) for interval operands
 
 NAN, INFINITY = float("nan"), float("inf")
 
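Note on the new constants: MAX_DAY_TIME_DAYS matches the largest whole day count a signed 64-bit microsecond counter can hold, which is how Spark represents day-time intervals internally, and MAX_32BIT_SIGNED_INT is the month cap for year-month intervals (a 32-bit month counter). A quick plain-Python check of the day figure (illustration only, not part of the package):

    # Spark stores day-time intervals as a signed 64-bit count of microseconds,
    # so the largest representable whole day count is:
    MICROS_PER_DAY = 86_400 * 1_000_000
    print((2**63 - 1) // MICROS_PER_DAY)  # 106751991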
@@ -272,6 +278,40 @@ def _coerce_for_comparison(
     return left_col, right_col
 
 
+def _preprocess_not_equals_expression(exp: expressions_proto.Expression) -> str:
+    """
+    Transform NOT(col1 = col2) expressions to col1 != col2 for Snowflake compatibility.
+
+    Snowflake has issues with NOT (col1 = col2) in subqueries, so we rewrite
+    not(==(a, b)) to a != b by modifying the protobuf expression early.
+
+    Returns:
+        The (potentially modified) function name as a lowercase string.
+    """
+    function_name = exp.unresolved_function.function_name.lower()
+
+    # Snowflake has issues with NOT (col1 = col2) in subqueries.
+    # Transform not(==(a, b)) to a != b by modifying the protobuf early.
+    if (
+        function_name in ("not", "!")
+        and len(exp.unresolved_function.arguments) == 1
+        and exp.unresolved_function.arguments[0].WhichOneof("expr_type")
+        == "unresolved_function"
+        and exp.unresolved_function.arguments[0].unresolved_function.function_name
+        == "=="
+    ):
+        inner_eq_func = exp.unresolved_function.arguments[0].unresolved_function
+        inner_args = list(inner_eq_func.arguments)
+
+        exp.unresolved_function.function_name = "!="
+        exp.unresolved_function.ClearField("arguments")
+        exp.unresolved_function.arguments.extend(inner_args)
+
+        function_name = "!="
+
+    return function_name
+
+
 def map_unresolved_function(
     exp: expressions_proto.Expression,
     column_mapping: ColumnNameMap,
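Note: `_preprocess_not_equals_expression` mutates the incoming Spark Connect protobuf before resolution. A minimal sketch of the shape it rewrites (assumes the pyspark Spark Connect protos imported above; the column names a and b are illustrative):

    # Build not(==(a, b)) the way a Spark Connect client would send it.
    exp = expressions_proto.Expression()
    exp.unresolved_function.function_name = "not"
    inner = expressions_proto.Expression()
    inner.unresolved_function.function_name = "=="
    for name in ("a", "b"):
        arg = expressions_proto.Expression()
        arg.unresolved_attribute.unparsed_identifier = name
        inner.unresolved_function.arguments.append(arg)
    exp.unresolved_function.arguments.append(inner)

    # After _preprocess_not_equals_expression(exp):
    #   exp.unresolved_function.function_name == "!="
    #   exp.unresolved_function.arguments holds a and b directly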
@@ -300,6 +340,9 @@ def map_unresolved_function(
     # Inject default parameters for functions that need them (especially for Scala clients)
     inject_function_defaults(exp.unresolved_function)
 
+    # Transform NOT(col = col) to col != col for Snowflake compatibility
+    function_name = _preprocess_not_equals_expression(exp)
+
     def _resolve_args_expressions(exp: expressions_proto.Expression):
         def _resolve_fn_arg(exp):
             with resolving_fun_args():
@@ -355,7 +398,7 @@ def map_unresolved_function(
     function_name = exp.unresolved_function.function_name.lower()
     telemetry.report_function_usage(function_name)
     result_type: Optional[DataType | List[DateType]] = None
-    qualifiers: List[str] = []
+    qualifier_parts: List[str] = []
 
     pyspark_func = getattr(pyspark_functions, function_name, None)
     if pyspark_func and pyspark_func.__doc__.lstrip().startswith("Aggregate function:"):
@@ -513,9 +556,17 @@ def map_unresolved_function(
             )
             result_type = [f.datatype for f in udtf.output_schema]
         case "!=":
-            result_exp = TypedColumn(
-                snowpark_args[0] != snowpark_args[1], lambda: [BooleanType()]
+            _check_interval_string_comparison(
+                "!=", snowpark_typed_args, snowpark_arg_names
+            )
+            # Make the function name the same as Spark Connect: a != b translates to not(a = b)
+            spark_function_name = (
+                f"(NOT ({snowpark_arg_names[0]} = {snowpark_arg_names[1]}))"
+            )
+            left, right = _coerce_for_comparison(
+                snowpark_typed_args[0], snowpark_typed_args[1]
             )
+            result_exp = TypedColumn(left != right, lambda: [BooleanType()])
         case "%" | "mod":
             if spark_sql_ansi_enabled:
                 result_exp = snowpark_args[0] % snowpark_args[1]
@@ -616,12 +667,87 @@ def map_unresolved_function(
                     result_exp = snowpark_args[0] * snowpark_args[1].try_cast(
                         result_type
                     )
-                case (_NumericType() as t, NullType()) | (
-                    NullType(),
-                    _NumericType() as t,
+                case (StringType(), t) | (t, StringType()) if isinstance(
+                    t, _AnsiIntervalType
+                ):
+                    if isinstance(snowpark_typed_args[0].typ, StringType):
+                        result_type = type(
+                            t
+                        )()  # YearMonthIntervalType() or DayTimeIntervalType()
+                        result_exp = snowpark_args[1] * snowpark_args[0].try_cast(
+                            LongType()
+                        )
+                        spark_function_name = (
+                            f"({snowpark_arg_names[1]} * {snowpark_arg_names[0]})"
+                        )
+                    else:
+                        result_type = type(
+                            t
+                        )()  # YearMonthIntervalType() or DayTimeIntervalType()
+                        result_exp = snowpark_args[0] * snowpark_args[1].try_cast(
+                            LongType()
+                        )
+                        spark_function_name = (
+                            f"({snowpark_arg_names[0]} * {snowpark_arg_names[1]})"
+                        )
+                case (
+                    (_NumericType() as t, NullType())
+                    | (NullType(), _NumericType() as t)
                 ):
                     result_type = t
                     result_exp = snowpark_fn.lit(None)
+                case (NullType(), t) | (t, NullType()) if isinstance(
+                    t, _AnsiIntervalType
+                ):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_fn.lit(None)
+                    if isinstance(snowpark_typed_args[0].typ, NullType):
+                        spark_function_name = (
+                            f"({snowpark_arg_names[1]} * {snowpark_arg_names[0]})"
+                        )
+                    else:
+                        spark_function_name = (
+                            f"({snowpark_arg_names[0]} * {snowpark_arg_names[1]})"
+                        )
+                case (DecimalType(), t) | (t, DecimalType()) if isinstance(
+                    t, _AnsiIntervalType
+                ):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    if isinstance(snowpark_typed_args[0].typ, DecimalType):
+                        result_exp = snowpark_args[1] * snowpark_args[0]
+                        spark_function_name = (
+                            f"({snowpark_arg_names[1]} * {snowpark_arg_names[0]})"
+                        )
+                    else:
+                        result_exp = snowpark_args[0] * snowpark_args[1]
+                        spark_function_name = (
+                            f"({snowpark_arg_names[0]} * {snowpark_arg_names[1]})"
+                        )
+                case (t, _NumericType()) if isinstance(t, _AnsiIntervalType):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_args[0] * snowpark_args[1]
+                case (_NumericType(), t) if isinstance(t, _AnsiIntervalType):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_args[1] * snowpark_args[0]
+                    spark_function_name = (
+                        f"({snowpark_arg_names[1]} * {snowpark_arg_names[0]})"
+                    )
                 case (_NumericType(), _NumericType()):
                     result_type = _find_common_type(
                         [arg.typ for arg in snowpark_typed_args]
@@ -662,7 +788,14 @@ def map_unresolved_function(
                         result_type = DateType()
                         result_exp = snowpark_args[0] + snowpark_args[1]
                     elif isinstance(t, (DayTimeIntervalType, YearMonthIntervalType)):
-                        result_type = TimestampType()
+                        result_type = (
+                            TimestampType()
+                            if isinstance(
+                                snowpark_typed_args[t_param_index].typ,
+                                DayTimeIntervalType,
+                            )
+                            else DateType()
+                        )
                         result_exp = (
                             snowpark_args[date_param_index]
                             + snowpark_args[t_param_index]
@@ -685,6 +818,35 @@ def map_unresolved_function(
                         )
                         attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
                         raise exception
+                case (TimestampType(), t) | (t, TimestampType()):
+                    timestamp_param_index = (
+                        0
+                        if isinstance(snowpark_typed_args[0].typ, TimestampType)
+                        else 1
+                    )
+                    t_param_index = 1 - timestamp_param_index
+                    if isinstance(t, (DayTimeIntervalType, YearMonthIntervalType)):
+                        result_type = TimestampType()
+                        result_exp = (
+                            snowpark_args[timestamp_param_index]
+                            + snowpark_args[t_param_index]
+                        )
+                    elif (
+                        hasattr(
+                            snowpark_typed_args[t_param_index].col._expr1, "pretty_name"
+                        )
+                        and "INTERVAL"
+                        == snowpark_typed_args[t_param_index].col._expr1.pretty_name
+                    ):
+                        result_type = TimestampType()
+                        result_exp = (
+                            snowpark_args[timestamp_param_index]
+                            + snowpark_args[t_param_index]
+                        )
+                    else:
+                        raise AnalysisException(
+                            f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: Parameter 2 requires the ("INTERVAL") type for timestamp operations, however "{snowpark_arg_names[t_param_index]}" has the type "{t}".',
+                        )
                 case (StringType(), StringType()):
                     if spark_sql_ansi_enabled:
                         exception = AnalysisException(
@@ -736,6 +898,86 @@ def map_unresolved_function(
                         )
                         attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
                         raise exception
+                case (t1, t2) | (t2, t1) if isinstance(
+                    t1, _AnsiIntervalType
+                ) and isinstance(t2, _AnsiIntervalType) and type(t1) == type(t2):
+                    # Both operands are the same interval type
+                    result_type = type(t1)(
+                        min(t1.start_field, t2.start_field),
+                        max(t1.end_field, t2.end_field),
+                    )
+                    result_exp = snowpark_args[0] + snowpark_args[1]
+                case (StringType(), t) | (t, StringType()) if isinstance(
+                    t, YearMonthIntervalType
+                ):
+                    # String + YearMonthInterval: Spark tries to cast string to double first, throws error if it fails
+                    result_type = StringType()
+                    if isinstance(snowpark_typed_args[0].typ, StringType):
+                        result_exp = (
+                            snowpark_fn.cast(snowpark_args[0], "double")
+                            + snowpark_args[1]
+                        )
+                    else:
+                        result_exp = snowpark_args[0] + snowpark_fn.cast(
+                            snowpark_args[1], "double"
+                        )
+                case (StringType(), t) | (t, StringType()) if isinstance(
+                    t, DayTimeIntervalType
+                ):
+                    # String + DayTimeInterval: try to parse string as timestamp, return NULL if it fails
+                    # For time-only strings (like '10:00:00'), prepend current date to make it a full timestamp
+                    result_type = StringType()
+                    if isinstance(snowpark_typed_args[0].typ, StringType):
+                        # Check if string looks like time-only (HH:MM:SS or HH:MM pattern)
+                        # If so, prepend current date; otherwise use as-is
+                        time_only_pattern = snowpark_fn.function("regexp_like")(
+                            snowpark_args[0], r"^\d{1,2}:\d{2}(:\d{2})?$"
+                        )
+                        timestamp_expr = snowpark_fn.when(
+                            time_only_pattern,
+                            snowpark_fn.function("try_to_timestamp_ntz")(
+                                snowpark_fn.function("concat")(
+                                    snowpark_fn.function("to_char")(
+                                        snowpark_fn.function("current_date")(),
+                                        "YYYY-MM-DD",
+                                    ),
+                                    snowpark_fn.lit(" "),
+                                    snowpark_args[0],
+                                )
+                            ),
+                        ).otherwise(
+                            snowpark_fn.function("try_to_timestamp_ntz")(
+                                snowpark_args[0]
+                            )
+                        )
+                        result_exp = timestamp_expr + snowpark_args[1]
+                    else:
+                        # interval + string case
+                        time_only_pattern = snowpark_fn.function("regexp_like")(
+                            snowpark_args[1], r"^\d{1,2}:\d{2}(:\d{2})?$"
+                        )
+                        timestamp_expr = snowpark_fn.when(
+                            time_only_pattern,
+                            snowpark_fn.function("try_to_timestamp_ntz")(
+                                snowpark_fn.function("concat")(
+                                    snowpark_fn.function("to_char")(
+                                        snowpark_fn.function("current_date")(),
+                                        "'YYYY-MM-DD'",
+                                    ),
+                                    snowpark_fn.lit(" "),
+                                    snowpark_args[1],
+                                )
+                            ),
+                        ).otherwise(
+                            snowpark_fn.function("try_to_timestamp_ntz")(
+                                snowpark_args[1]
+                            )
+                        )
+                        result_exp = snowpark_args[0] + timestamp_expr
+                        spark_function_name = (
+                            f"{snowpark_arg_names[0]} + {snowpark_arg_names[1]}"
+                        )
+
                 case _:
                     result_type, overflow_possible = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
@@ -781,7 +1023,11 @@ def map_unresolved_function(
                     DateType(),
                     YearMonthIntervalType(),
                 ):
-                    result_type = TimestampType()
+                    result_type = (
+                        TimestampType()
+                        if isinstance(snowpark_typed_args[1].typ, DayTimeIntervalType)
+                        else DateType()
+                    )
                     result_exp = snowpark_args[0] - snowpark_args[1]
                 case (DateType(), StringType()):
                     if (
@@ -799,6 +1045,23 @@ def map_unresolved_function(
                         result_exp = snowpark_args[0] - snowpark_args[1].cast(
                             input_type
                         )
+                case (TimestampType(), DayTimeIntervalType()) | (
+                    TimestampType(),
+                    YearMonthIntervalType(),
+                ):
+                    result_type = TimestampType()
+                    result_exp = snowpark_args[0] - snowpark_args[1]
+                case (TimestampType(), StringType()):
+                    if (
+                        hasattr(snowpark_typed_args[1].col._expr1, "pretty_name")
+                        and "INTERVAL" == snowpark_typed_args[1].col._expr1.pretty_name
+                    ):
+                        result_type = TimestampType()
+                        result_exp = snowpark_args[0] - snowpark_args[1]
+                    else:
+                        raise AnalysisException(
+                            f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: Parameter 2 requires the ("INTERVAL") type for timestamp operations, however "{snowpark_arg_names[1]}" has the type "{snowpark_typed_args[1].typ}".',
+                        )
                 case (StringType(), DateType()):
                     # TODO SNOW-2034420: resolve return type (it should be INTERVAL DAY)
                     result_type = LongType()
@@ -870,6 +1133,16 @@ def map_unresolved_function(
                         )
                         attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
                         raise exception
+                case (StringType(), t) if isinstance(t, _AnsiIntervalType):
+                    # String - Interval: try to parse string as timestamp, return NULL if it fails
+                    result_type = StringType()
+                    result_exp = (
+                        snowpark_fn.function("try_to_timestamp")(snowpark_args[0])
+                        - snowpark_args[1]
+                    )
+                    spark_function_name = (
+                        f"{snowpark_arg_names[0]} - {snowpark_arg_names[1]}"
+                    )
                 case _:
                     result_type, overflow_possible = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
@@ -968,9 +1241,57 @@ def map_unresolved_function(
                     result_exp = _divnull(
                         snowpark_args[0], snowpark_args[1].try_cast(result_type)
                     )
+                case (t, StringType()) if isinstance(t, _AnsiIntervalType):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_args[0] / snowpark_args[1].try_cast(
+                        LongType()
+                    )
+                    spark_function_name = (
+                        f"({snowpark_arg_names[0]} / {snowpark_arg_names[1]})"
+                    )
                 case (_NumericType(), NullType()) | (NullType(), _NumericType()):
                     result_type = DoubleType()
                     result_exp = snowpark_fn.lit(None)
+                case (t, NullType()) if isinstance(t, _AnsiIntervalType):
+                    # Only allow interval / null, not null / interval
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_fn.lit(None)
+                    spark_function_name = (
+                        f"({snowpark_arg_names[0]} / {snowpark_arg_names[1]})"
+                    )
+                case (DecimalType(), t) | (t, DecimalType()) if isinstance(
+                    t, _AnsiIntervalType
+                ):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    if isinstance(snowpark_typed_args[0].typ, DecimalType):
+                        result_exp = snowpark_args[1] / snowpark_args[0]
+                        spark_function_name = (
+                            f"({snowpark_arg_names[1]} / {snowpark_arg_names[0]})"
+                        )
+                    else:
+                        result_exp = snowpark_args[0] / snowpark_args[1]
+                        spark_function_name = (
+                            f"({snowpark_arg_names[0]} / {snowpark_arg_names[1]})"
+                        )
+                case (t, _NumericType()) if isinstance(t, _AnsiIntervalType):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_args[0] / snowpark_args[1]
                 case (_NumericType(), _NumericType()):
                     result_type = DoubleType()
                     result_exp = _divnull(
@@ -2027,11 +2348,6 @@ map_unresolved_function(
             result_exp = snowpark_fn.coalesce(
                 *[col.cast(result_type) for col in snowpark_args]
             )
-        case "col":
-            # TODO: assign type
-            result_exp = snowpark_fn.col(*snowpark_args)
-            result_exp = _type_with_typer(result_exp)
-            qualifiers = snowpark_args[0].get_qualifiers()
         case "collect_list" | "array_agg":
             # TODO: SNOW-1967177 - Support structured types in array_agg
             result_exp = snowpark_fn.array_agg(
@@ -2049,11 +2365,6 @@ map_unresolved_function(
             result_exp = _resolve_aggregate_exp(
                 result_exp, ArrayType(snowpark_typed_args[0].typ)
             )
-        case "column":
-            # TODO: assign type
-            result_exp = snowpark_fn.column(*snowpark_args)
-            result_exp = _type_with_typer(result_exp)
-            qualifiers = snowpark_args[0].get_qualifiers()
         case "concat":
             if len(snowpark_args) == 0:
                 result_exp = TypedColumn(snowpark_fn.lit(""), lambda: [StringType()])
@@ -2232,7 +2543,7 @@ map_unresolved_function(
                     snowpark_fn.col("*", _is_qualified_name=True)
                 )
             else:
-                result_exp = snowpark_fn.count(*snowpark_args)
+                result_exp = snowpark_fn.call_function("COUNT", *snowpark_args)
             result_exp = TypedColumn(result_exp, lambda: [LongType()])
         case "count_if":
             result_exp = snowpark_fn.call_function("COUNT_IF", snowpark_args[0])
@@ -2670,16 +2981,6 @@ map_unresolved_function(
             )
             result_type = LongType()
         case "date_part" | "datepart" | "extract":
-            # Check for interval types and throw NotImplementedError
-            if isinstance(
-                snowpark_typed_args[1].typ, (YearMonthIntervalType, DayTimeIntervalType)
-            ):
-                exception = NotImplementedError(
-                    f"{function_name} with interval types is not supported"
-                )
-                attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
-                raise exception
-
             field_lit: str | None = unwrap_literal(exp.unresolved_function.arguments[0])
 
             if field_lit is None:
@@ -2724,16 +3025,51 @@ map_unresolved_function(
         case "div":
             # Only called from SQL, either as `a div b` or `div(a, b)`
             # Convert it into `(a - a % b) / b`.
-            result_exp = snowpark_fn.cast(
-                (snowpark_args[0] - snowpark_args[0] % snowpark_args[1])
-                / snowpark_args[1],
-                LongType(),
-            )
-            if not spark_sql_ansi_enabled:
-                result_exp = snowpark_fn.when(
-                    snowpark_args[1] == 0, snowpark_fn.lit(None)
-                ).otherwise(result_exp)
-            result_type = LongType()
+            if isinstance(snowpark_typed_args[0].typ, YearMonthIntervalType):
+                if isinstance(snowpark_typed_args[1].typ, YearMonthIntervalType):
+                    dividend_total = _calculate_total_months(snowpark_args[0])
+                    divisor_total = _calculate_total_months(snowpark_args[1])
+
+                    # Handle division by zero interval
+                    if not spark_sql_ansi_enabled:
+                        result_exp = snowpark_fn.when(
+                            divisor_total == 0, snowpark_fn.lit(None)
+                        ).otherwise(snowpark_fn.trunc(dividend_total / divisor_total))
+                    else:
+                        result_exp = snowpark_fn.trunc(dividend_total / divisor_total)
+                    result_type = LongType()
+                else:
+                    raise AnalysisException(
+                        f"""[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "({snowpark_arg_names[0]} div {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ({snowpark_typed_args[0].typ} and {snowpark_typed_args[1].typ}).;"""
+                    )
+            elif isinstance(snowpark_typed_args[0].typ, DayTimeIntervalType):
+                if isinstance(snowpark_typed_args[1].typ, DayTimeIntervalType):
+                    dividend_total = _calculate_total_seconds(snowpark_args[0])
+                    divisor_total = _calculate_total_seconds(snowpark_args[1])
+
+                    # Handle division by zero interval
+                    if not spark_sql_ansi_enabled:
+                        result_exp = snowpark_fn.when(
+                            divisor_total == 0, snowpark_fn.lit(None)
+                        ).otherwise(snowpark_fn.trunc(dividend_total / divisor_total))
+                    else:
+                        result_exp = snowpark_fn.trunc(dividend_total / divisor_total)
+                    result_type = LongType()
+                else:
+                    raise AnalysisException(
+                        f"""[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "({snowpark_arg_names[0]} div {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ({snowpark_typed_args[0].typ} and {snowpark_typed_args[1].typ}).;"""
+                    )
+            else:
+                result_exp = snowpark_fn.cast(
+                    (snowpark_args[0] - snowpark_args[0] % snowpark_args[1])
+                    / snowpark_args[1],
+                    LongType(),
+                )
+                if not spark_sql_ansi_enabled:
+                    result_exp = snowpark_fn.when(
+                        snowpark_args[1] == 0, snowpark_fn.lit(None)
+                    ).otherwise(result_exp)
+                result_type = LongType()
         case "e":
             spark_function_name = "E()"
             result_exp = snowpark_fn.lit(math.e)
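Note: the new `div` branches reproduce Spark's interval integer division, which divides the interval totals (months for year-month, seconds for day-time) and truncates toward zero. A hedged Spark-side illustration (assumes a running `spark` session):

    spark.sql("SELECT INTERVAL '25' MONTH div INTERVAL '1' YEAR").show()
    # -> 2, i.e. 25 months / 12 months, truncated
    spark.sql("SELECT INTERVAL '3' DAY div INTERVAL '36' HOUR").show()
    # -> 2, i.e. 259200 s / 129600 s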
@@ -3565,8 +3901,6 @@ map_unresolved_function(
            result should be either way good enough.
            """
 
-           from datetime import date, datetime, time, timedelta
-
           def __init__(self) -> None:
 
               # init the RNG for breaking ties in histogram merging. A fixed seed is specified here
@@ -3710,7 +4044,8 @@ map_unresolved_function(
               # just increment 'bin'. This is not done now because we don't want to make any
               # assumptions about the range of numeric data being analyzed.
               if bin < self.n_used_bins and self.bins[bin][0] == v:
-                  self.bins[bin][1] += 1
+                  bin_x, bin_y = self.bins[bin]
+                  self.bins[bin] = (bin_x, bin_y + 1)
               else:
                   self.bins.insert(bin + 1, (v, 1.0))
                   self.n_used_bins += 1
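Note: this fix is about tuple immutability. The `else` branch stores `(v, 1.0)` tuples in `self.bins`, so the old in-place `self.bins[bin][1] += 1` raised TypeError once a tuple occupied that slot. Minimal plain-Python reproduction:

    bins = [(1.0, 2.0)]  # (value, count) pairs, as inserted by bins.insert(bin + 1, (v, 1.0))
    try:
        bins[0][1] += 1  # TypeError: tuples don't support item assignment
    except TypeError:
        value, count = bins[0]
        bins[0] = (value, count + 1)  # the rewritten update
    print(bins)  # [(1.0, 3.0)]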
@@ -4504,6 +4839,17 @@ map_unresolved_function(
             date_str_exp = snowpark_fn.concat(y, dash, m, dash, d)
             result_exp = snowpark_fn.builtin(snowpark_function)(date_str_exp)
             result_type = DateType()
+        case "make_dt_interval":
+            # Pad argument names for display purposes
+            padded_arg_names = snowpark_arg_names.copy()
+            while len(padded_arg_names) < 3:  # days, hours, minutes are integers
+                padded_arg_names.append("0")
+            if len(padded_arg_names) < 4:  # seconds can be decimal
+                padded_arg_names.append("0.000000")
+
+            spark_function_name = f"make_dt_interval({', '.join(padded_arg_names)})"
+            result_exp = snowpark_fn.interval_day_time_from_parts(*snowpark_args)
+            result_type = DayTimeIntervalType()
         case "make_timestamp" | "make_timestamp_ltz" | "make_timestamp_ntz":
             y, m, d, h, mins = map(lambda col: col.cast(LongType()), snowpark_args[:5])
             y_abs = snowpark_fn.abs(y)
@@ -4557,6 +4903,15 @@ map_unresolved_function(
             result_exp = snowpark_fn.when(
                 snowpark_fn.is_null(parsed_str_exp), snowpark_fn.lit(None)
             ).otherwise(make_timestamp_res)
+        case "make_ym_interval":
+            # Pad argument names for display purposes
+            padded_arg_names = snowpark_arg_names.copy()
+            while len(padded_arg_names) < 2:  # years, months
+                padded_arg_names.append("0")
+
+            spark_function_name = f"make_ym_interval({', '.join(padded_arg_names)})"
+            result_exp = snowpark_fn.interval_year_month_from_parts(*snowpark_args)
+            result_type = YearMonthIntervalType()
         case "map":
             allow_duplicate_keys = (
                 global_config.spark_sql_mapKeyDedupPolicy == "LAST_WIN"
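Note: the two new constructor cases delegate to Snowpark's `interval_day_time_from_parts` and `interval_year_month_from_parts`; the padding only affects the displayed column name, mirroring how Spark prints defaulted arguments. Expected Spark-side behavior (hedged, assumes a `spark` session):

    spark.sql("SELECT make_dt_interval(1, 12, 30, 01.001001)").show()
    # -> INTERVAL '1 12:30:01.001001' DAY TO SECOND
    spark.sql("SELECT make_ym_interval(2, 3)").show()
    # -> INTERVAL '2-3' YEAR TO MONTH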
@@ -5211,7 +5566,11 @@ map_unresolved_function(
                 spark_function_name = f"(- {snowpark_arg_names[0]})"
             else:
                 spark_function_name = f"negative({snowpark_arg_names[0]})"
-            if isinstance(arg_type, _NumericType):
+            if (
+                isinstance(arg_type, _NumericType)
+                or isinstance(arg_type, YearMonthIntervalType)
+                or isinstance(arg_type, DayTimeIntervalType)
+            ):
                 # Instead of using snowpark_fn.negate which can generate invalid SQL for nested minus operations,
                 # use a direct multiplication by -1 which generates cleaner SQL
                 result_exp = snowpark_args[0] * snowpark_fn.lit(-1)
@@ -5236,6 +5595,8 @@ map_unresolved_function(
             result_type = (
                 snowpark_typed_args[0].types
                 if isinstance(arg_type, _NumericType)
+                or isinstance(arg_type, YearMonthIntervalType)
+                or isinstance(arg_type, DayTimeIntervalType)
                 else DoubleType()
             )
         case "next_day":
@@ -5616,9 +5977,33 @@ map_unresolved_function(
         case "percentile_cont" | "percentiledisc":
             if function_name == "percentiledisc":
                 function_name = "percentile_disc"
+            order_by_col = snowpark_args[0]
+            args = exp.unresolved_function.arguments
+            if len(args) != 3:
+                exception = AssertionError(
+                    f"{function_name} expected 3 args but got {len(args)}"
+                )
+                attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+                raise exception
+            # literal value 0.0 - 1.0
+            percentage_arg = args[1]
+            sort_direction = args[2].sort_order.direction
+            direction_str = ""  # defaultValue
+            if (
+                sort_direction
+                == expressions_proto.Expression.SortOrder.SORT_DIRECTION_DESCENDING
+            ):
+                direction_str = "DESC"
+
+            # Apply sort direction to the order_by column
+            if direction_str == "DESC":
+                order_by_col_with_direction = order_by_col.desc()
+            else:
+                order_by_col_with_direction = order_by_col.asc()
+
             result_exp = snowpark_fn.function(function_name)(
-                _check_percentile_percentage(exp.unresolved_function.arguments[1])
-            ).within_group(snowpark_args[0])
+                _check_percentile_percentage(percentage_arg)
+            ).within_group(order_by_col_with_direction)
             result_exp = (
                 TypedColumn(
                     snowpark_fn.cast(result_exp, FloatType()), lambda: [DoubleType()]
@@ -5627,7 +6012,8 @@ map_unresolved_function(
                 else TypedColumnWithDeferredCast(result_exp, lambda: [DoubleType()])
             )
 
-            spark_function_name = f"{function_name}({unwrap_literal(exp.unresolved_function.arguments[1])}) WITHIN GROUP (ORDER BY {snowpark_arg_names[0]})"
+            direction_part = f" {direction_str}" if direction_str else ""
+            spark_function_name = f"{function_name}({unwrap_literal(percentage_arg)}) WITHIN GROUP (ORDER BY {snowpark_arg_names[0]}{direction_part})"
         case "pi":
             spark_function_name = "PI()"
             result_exp = snowpark_fn.lit(math.pi)
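Note: wiring the sort direction through matters most for percentile_disc, which returns an actual element of the group, so reversing the order changes the answer. A hedged illustration (assumes a `spark` session):

    spark.sql("""
        SELECT percentile_disc(0.25) WITHIN GROUP (ORDER BY v)      AS asc_p25,
               percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC) AS desc_p25
        FROM VALUES (1), (2), (3), (4) AS t(v)
    """).show()
    # asc_p25 -> 1, desc_p25 -> 4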
@@ -5767,7 +6153,11 @@ map_unresolved_function(
         case "positive":
             arg_type = snowpark_typed_args[0].typ
             spark_function_name = f"(+ {snowpark_arg_names[0]})"
-            if isinstance(arg_type, _NumericType):
+            if (
+                isinstance(arg_type, _NumericType)
+                or isinstance(arg_type, YearMonthIntervalType)
+                or isinstance(arg_type, DayTimeIntervalType)
+            ):
                 result_exp = snowpark_args[0]
             elif isinstance(arg_type, StringType):
                 if spark_sql_ansi_enabled:
@@ -5790,6 +6180,8 @@ map_unresolved_function(
             result_type = (
                 snowpark_typed_args[0].types
                 if isinstance(arg_type, _NumericType)
+                or isinstance(arg_type, YearMonthIntervalType)
+                or isinstance(arg_type, DayTimeIntervalType)
                 else DoubleType()
             )
 
@@ -6660,11 +7052,43 @@ map_unresolved_function(
                 fn_name = "sign"
 
             spark_function_name = f"{fn_name}({snowpark_arg_names[0]})"
-            result_exp = snowpark_fn.when(
-                snowpark_args[0] == NAN, snowpark_fn.lit(NAN)
-            ).otherwise(
-                snowpark_fn.cast(snowpark_fn.sign(snowpark_args[0]), DoubleType())
-            )
+
+            if isinstance(snowpark_typed_args[0].typ, YearMonthIntervalType):
+                # Use SQL expression for zero year-month interval comparison
+                result_exp = (
+                    snowpark_fn.when(
+                        snowpark_args[0]
+                        > snowpark_fn.sql_expr("INTERVAL '0-0' YEAR TO MONTH"),
+                        snowpark_fn.lit(1.0),
+                    )
+                    .when(
+                        snowpark_args[0]
+                        < snowpark_fn.sql_expr("INTERVAL '0-0' YEAR TO MONTH"),
+                        snowpark_fn.lit(-1.0),
+                    )
+                    .otherwise(snowpark_fn.lit(0.0))
+                )
+            elif isinstance(snowpark_typed_args[0].typ, DayTimeIntervalType):
+                # Use SQL expression for zero day-time interval comparison
+                result_exp = (
+                    snowpark_fn.when(
+                        snowpark_args[0]
+                        > snowpark_fn.sql_expr("INTERVAL '0 0:0:0' DAY TO SECOND"),
+                        snowpark_fn.lit(1.0),
+                    )
+                    .when(
+                        snowpark_args[0]
+                        < snowpark_fn.sql_expr("INTERVAL '0 0:0:0' DAY TO SECOND"),
+                        snowpark_fn.lit(-1.0),
+                    )
+                    .otherwise(snowpark_fn.lit(0.0))
+                )
+            else:
+                result_exp = snowpark_fn.when(
+                    snowpark_args[0] == NAN, snowpark_fn.lit(NAN)
+                ).otherwise(
+                    snowpark_fn.cast(snowpark_fn.sign(snowpark_args[0]), DoubleType())
+                )
             result_type = DoubleType()
         case "sin":
             spark_function_name = f"SIN({snowpark_arg_names[0]})"
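Note: the interval branches compare against a literal zero interval built with sql_expr because a numeric SIGN cannot be applied to interval columns directly; the result is still a DOUBLE, as in Spark. Hedged illustration of the intended Spark semantics (assumes a `spark` session and a Spark version where signum accepts ANSI intervals):

    spark.sql("SELECT sign(INTERVAL '-5' DAY), sign(INTERVAL '0' DAY)").show()
    # -> -1.0, 0.0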
@@ -6909,7 +7333,16 @@ map_unresolved_function(
             )
             raise exception
         case "split_part":
-            result_exp = snowpark_fn.call_function("split_part", *snowpark_args)
+            # Check for index 0 and throw error to match PySpark behavior
+            raise_error = _raise_error_helper(StringType(), SparkRuntimeException)
+            result_exp = snowpark_fn.when(
+                snowpark_args[2] == 0,
+                raise_error(
+                    snowpark_fn.lit(
+                        "[INVALID_INDEX_OF_ZERO] The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1)."
+                    )
+                ),
+            ).otherwise(snowpark_fn.call_function("split_part", *snowpark_args))
             result_type = StringType()
         case "sqrt":
             spark_function_name = f"SQRT({snowpark_arg_names[0]})"
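Note: the split_part guard reproduces Spark's INVALID_INDEX_OF_ZERO behavior, where valid part indexes count from 1 (or backwards from -1). Hedged illustration (assumes a `spark` session):

    spark.sql("SELECT split_part('a,b,c', ',', 2)").show()   # -> b
    spark.sql("SELECT split_part('a,b,c', ',', -1)").show()  # -> c
    spark.sql("SELECT split_part('a,b,c', ',', 0)")          # raises SparkRuntimeException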
@@ -7939,18 +8372,123 @@ map_unresolved_function(
             )
             result_type = DateType()
         case "try_add":
-            # Check for interval types and throw NotImplementedError
-            for arg in snowpark_typed_args:
-                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
-                    exception = NotImplementedError(
-                        "try_add with interval types is not supported"
-                    )
-                    attach_custom_error_code(
-                        exception, ErrorCodes.UNSUPPORTED_OPERATION
-                    )
-                    raise exception
-            result_exp = _try_arithmetic_helper(snowpark_typed_args, snowpark_args, 0)
-            result_exp = _type_with_typer(result_exp)
+            # Handle interval arithmetic with overflow detection
+            match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (DateType(), t) | (t, DateType()) if isinstance(
+                    t, YearMonthIntervalType
+                ):
+                    result_type = DateType()
+                    result_exp = snowpark_args[0] + snowpark_args[1]
+                case (DateType(), t) | (t, DateType()) if isinstance(
+                    t, DayTimeIntervalType
+                ):
+                    result_type = TimestampType()
+                    result_exp = snowpark_args[0] + snowpark_args[1]
+                case (TimestampType(), t) | (t, TimestampType()) if isinstance(
+                    t, (DayTimeIntervalType, YearMonthIntervalType)
+                ):
+                    result_type = (
+                        snowpark_typed_args[0].typ
+                        if isinstance(snowpark_typed_args[0].typ, TimestampType)
+                        else snowpark_typed_args[1].typ
+                    )
+                    result_exp = snowpark_args[0] + snowpark_args[1]
+                case (t1, t2) if (
+                    isinstance(t1, YearMonthIntervalType)
+                    and isinstance(t2, (_NumericType, StringType))
+                ) or (
+                    isinstance(t2, YearMonthIntervalType)
+                    and isinstance(t1, (_NumericType, StringType))
+                ):
+                    # YearMonthInterval + numeric/string or numeric/string + YearMonthInterval should throw error
+                    exception = AnalysisException(
+                        f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "try_add({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
+                    )
+                    attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
+                    raise exception
+                case (t1, t2) if isinstance(t1, YearMonthIntervalType) and isinstance(
+                    t2, YearMonthIntervalType
+                ):
+                    result_type = YearMonthIntervalType(
+                        min(t1.start_field, t2.start_field),
+                        max(t1.end_field, t2.end_field),
+                    )
+
+                    # For year-month intervals, throw ArithmeticException if operands reach 10+ digits OR result exceeds 9 digits
+                    total1 = _calculate_total_months(snowpark_args[0])
+                    total2 = _calculate_total_months(snowpark_args[1])
+                    ten_digit_limit = snowpark_fn.lit(MAX_10_DIGIT_LIMIT)
+
+                    precision_violation = (
+                        # Check if either operand already reaches 10 digits (parsing limit)
+                        (snowpark_fn.abs(total1) >= ten_digit_limit)
+                        | (snowpark_fn.abs(total2) >= ten_digit_limit)
+                        | (
+                            (total1 > 0)
+                            & (total2 > 0)
+                            & (total1 >= ten_digit_limit - total2)
+                        )
+                        | (
+                            (total1 < 0)
+                            & (total2 < 0)
+                            & (total1 <= -ten_digit_limit - total2)
+                        )
+                    )
+
+                    raise_error = _raise_error_helper(result_type, ArithmeticException)
+                    result_exp = snowpark_fn.when(
+                        precision_violation,
+                        raise_error(
+                            snowpark_fn.lit(
+                                "Year-Month Interval result exceeds Snowflake interval precision limit"
+                            )
+                        ),
+                    ).otherwise(snowpark_args[0] + snowpark_args[1])
+                case (t1, t2) if isinstance(t1, DayTimeIntervalType) and isinstance(
+                    t2, DayTimeIntervalType
+                ):
+                    result_type = DayTimeIntervalType(
+                        min(t1.start_field, t2.start_field),
+                        max(t1.end_field, t2.end_field),
+                    )
+                    # Check for Snowflake's day limit (106751991 days is the cutoff)
+                    days1 = snowpark_fn.date_part("day", snowpark_args[0])
+                    days2 = snowpark_fn.date_part("day", snowpark_args[1])
+                    max_days = snowpark_fn.lit(
+                        MAX_DAY_TIME_DAYS
+                    )  # Snowflake's actual limit
+                    min_days = snowpark_fn.lit(-MAX_DAY_TIME_DAYS)
+
+                    # Check if either operand exceeds the day limit - throw error like Spark does
+                    operand_limit_violation = (snowpark_fn.abs(days1) > max_days) | (
+                        snowpark_fn.abs(days2) > max_days
+                    )
+
+                    # Check if result would exceed day limit (but operands are valid) - return NULL
+                    result_overflow = (
+                        # Check if result would exceed day limit (positive overflow)
+                        ((days1 > 0) & (days2 > 0) & (days1 > max_days - days2))
+                        | ((days1 < 0) & (days2 < 0) & (days1 < min_days - days2))
+                    )
+
+                    raise_error = _raise_error_helper(result_type, ArithmeticException)
+                    result_exp = (
+                        snowpark_fn.when(
+                            operand_limit_violation,
+                            raise_error(
+                                snowpark_fn.lit(
+                                    "Day-Time Interval operand exceeds Snowflake interval precision limit"
+                                )
+                            ),
+                        )
+                        .when(result_overflow, snowpark_fn.lit(None))
+                        .otherwise(snowpark_args[0] + snowpark_args[1])
+                    )
+                case _:
+                    result_exp = _try_arithmetic_helper(
+                        snowpark_typed_args, snowpark_args, 0
+                    )
+                    result_exp = _type_with_typer(result_exp)
         case "try_aes_decrypt":
             result_exp = _aes_helper(
                 "TRY_DECRYPT",
@@ -8002,17 +8540,49 @@ map_unresolved_function(
                 DoubleType(), cleaned, calculating_avg=True
             )
         case "try_divide":
-            # Check for interval types and throw NotImplementedError
-            for arg in snowpark_typed_args:
-                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
-                    exception = NotImplementedError(
-                        "try_divide with interval types is not supported"
-                    )
-                    attach_custom_error_code(
-                        exception, ErrorCodes.UNSUPPORTED_OPERATION
-                    )
-                    raise exception
+            # Handle interval division with overflow detection
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (t1, t2) if isinstance(t1, _AnsiIntervalType) and isinstance(
+                    t2, (_NumericType, StringType)
+                ):
+                    # Interval / numeric/string
+                    result_type = t1
+                    interval_arg = snowpark_args[0]
+                    divisor = (
+                        snowpark_args[1]
+                        if isinstance(t2, _NumericType)
+                        else snowpark_fn.cast(snowpark_args[1], "double")
+                    )
+
+                    # Check for division by zero first
+                    zero_check = divisor == 0
+
+                    if isinstance(result_type, YearMonthIntervalType):
+                        # For year-month intervals, check if result exceeds 32-bit signed integer limit
+                        result_type = YearMonthIntervalType()
+                        total_months = _calculate_total_months(interval_arg)
+                        max_months = snowpark_fn.lit(MAX_32BIT_SIGNED_INT)
+                        overflow_check = (
+                            snowpark_fn.abs(total_months / divisor) > max_months
+                        )
+                        result_exp = (
+                            snowpark_fn.when(zero_check, snowpark_fn.lit(None))
+                            .when(overflow_check, snowpark_fn.lit(None))
+                            .otherwise(interval_arg / divisor)
+                        )
+                    else:  # DayTimeIntervalType
+                        # For day-time intervals, check if result exceeds day limit
+                        result_type = DayTimeIntervalType()
+                        total_days = _calculate_total_days(interval_arg)
+                        max_days = snowpark_fn.lit(MAX_DAY_TIME_DAYS)
+                        overflow_check = (
+                            snowpark_fn.abs(total_days / divisor) > max_days
+                        )
+                        result_exp = (
+                            snowpark_fn.when(zero_check, snowpark_fn.lit(None))
+                            .when(overflow_check, snowpark_fn.lit(None))
+                            .otherwise(interval_arg / divisor)
+                        )
                 case (NullType(), t) | (t, NullType()):
                     result_exp = snowpark_fn.lit(None)
                     result_type = FloatType()
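Note on the try_* interval branches (this hunk and the try_add/try_subtract ones): an operand already past Snowflake's precision limit raises ArithmeticException eagerly, while a result that merely overflows, or a division by zero, degrades to NULL, which is the try_* contract. Hedged Spark-side illustration (assumes a `spark` session):

    spark.sql("SELECT try_divide(INTERVAL '10' DAY, 0)").show()
    # -> NULL (no error, unlike plain division under ANSI mode)
    spark.sql("SELECT try_divide(INTERVAL '10' DAY, 4)").show()
    # -> INTERVAL '2 12:00:00' DAY TO SECOND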
@@ -8124,17 +8694,76 @@ map_unresolved_function(
             attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
             raise exception
         case "try_multiply":
-            # Check for interval types and throw NotImplementedError
-            for arg in snowpark_typed_args:
-                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
-                    exception = NotImplementedError(
-                        "try_multiply with interval types is not supported"
-                    )
-                    attach_custom_error_code(
-                        exception, ErrorCodes.UNSUPPORTED_OPERATION
-                    )
-                    raise exception
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (t1, t2) if isinstance(t1, _AnsiIntervalType) and isinstance(
+                    t2, (_NumericType, StringType)
+                ):
+                    # Interval * numeric/string
+                    result_type = t1
+                    interval_arg = snowpark_args[0]
+                    multiplier = (
+                        snowpark_args[1]
+                        if isinstance(t2, _NumericType)
+                        else snowpark_fn.cast(snowpark_args[1], "double")
+                    )
+
+                    if isinstance(result_type, YearMonthIntervalType):
+                        # For year-month intervals, check if result exceeds 32-bit signed integer limit
+                        result_type = YearMonthIntervalType()
+                        total_months = _calculate_total_months(interval_arg)
+                        max_months = snowpark_fn.lit(MAX_32BIT_SIGNED_INT)
+                        overflow_check = (
+                            snowpark_fn.abs(total_months * multiplier) > max_months
+                        )
+                        result_exp = snowpark_fn.when(
+                            overflow_check, snowpark_fn.lit(None)
+                        ).otherwise(interval_arg * multiplier)
+                    else:  # DayTimeIntervalType
+                        # For day-time intervals, check if result exceeds day limit
+                        result_type = DayTimeIntervalType()
+                        total_days = _calculate_total_days(interval_arg)
+                        max_days = snowpark_fn.lit(MAX_DAY_TIME_DAYS)
+                        overflow_check = (
+                            snowpark_fn.abs(total_days * multiplier) > max_days
+                        )
+                        result_exp = snowpark_fn.when(
+                            overflow_check, snowpark_fn.lit(None)
+                        ).otherwise(interval_arg * multiplier)
+
+                case (t1, t2) if isinstance(t2, _AnsiIntervalType) and isinstance(
+                    t1, (_NumericType, StringType)
+                ):
+                    # numeric/string * Interval
+                    result_type = t2
+                    interval_arg = snowpark_args[1]
+                    multiplier = (
+                        snowpark_args[0]
+                        if isinstance(t1, _NumericType)
+                        else snowpark_fn.cast(snowpark_args[0], "double")
+                    )
+
+                    if isinstance(result_type, YearMonthIntervalType):
+                        # For year-month intervals, check if result exceeds 32-bit signed integer limit
+                        result_type = YearMonthIntervalType()
+                        total_months = _calculate_total_months(interval_arg)
+                        max_months = snowpark_fn.lit(MAX_32BIT_SIGNED_INT)
+                        overflow_check = (
+                            snowpark_fn.abs(total_months * multiplier) > max_months
+                        )
+                        result_exp = snowpark_fn.when(
+                            overflow_check, snowpark_fn.lit(None)
+                        ).otherwise(interval_arg * multiplier)
+                    else:  # DayTimeIntervalType
+                        # For day-time intervals, check if result exceeds day limit
+                        result_type = DayTimeIntervalType()
+                        total_days = _calculate_total_days(interval_arg)
+                        max_days = snowpark_fn.lit(MAX_DAY_TIME_DAYS)
+                        overflow_check = (
+                            snowpark_fn.abs(total_days * multiplier) > max_days
+                        )
+                        result_exp = snowpark_fn.when(
+                            overflow_check, snowpark_fn.lit(None)
+                        ).otherwise(interval_arg * multiplier)
                 case (NullType(), t) | (t, NullType()):
                     result_exp = snowpark_fn.lit(None)
                     match t:
@@ -8234,18 +8863,112 @@ map_unresolved_function(
                 snowpark_typed_args[0].typ, snowpark_args[0]
             )
         case "try_subtract":
-            # Check for interval types and throw NotImplementedError
-            for arg in snowpark_typed_args:
-                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
-                    exception = NotImplementedError(
-                        "try_subtract with interval types is not supported"
-                    )
-                    attach_custom_error_code(
-                        exception, ErrorCodes.UNSUPPORTED_OPERATION
-                    )
-                    raise exception
-            result_exp = _try_arithmetic_helper(snowpark_typed_args, snowpark_args, 1)
-            result_exp = _type_with_typer(result_exp)
+            # Handle interval arithmetic with overflow detection
+            match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (DateType(), t) if isinstance(t, YearMonthIntervalType):
+                    result_type = DateType()
+                    result_exp = snowpark_args[0] - snowpark_args[1]
+                case (DateType(), t) if isinstance(t, DayTimeIntervalType):
+                    result_type = TimestampType()
+                    result_exp = snowpark_args[0] - snowpark_args[1]
+                case (TimestampType(), t) if isinstance(
+                    t, (DayTimeIntervalType, YearMonthIntervalType)
+                ):
+                    result_type = snowpark_typed_args[0].typ
+                    result_exp = snowpark_args[0] - snowpark_args[1]
+                case (t1, t2) if (
+                    isinstance(t1, YearMonthIntervalType)
+                    and isinstance(t2, (_NumericType, StringType))
+                ) or (
+                    isinstance(t2, YearMonthIntervalType)
+                    and isinstance(t1, (_NumericType, StringType))
+                ):
+                    # YearMonthInterval - numeric/string or numeric/string - YearMonthInterval should throw error
+                    exception = AnalysisException(
+                        f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "try_subtract({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
+                    )
+                    attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
+                    raise exception
+                case (t1, t2) if isinstance(t1, YearMonthIntervalType) and isinstance(
+                    t2, YearMonthIntervalType
+                ):
+                    result_type = YearMonthIntervalType(
+                        min(t1.start_field, t2.start_field),
+                        max(t1.end_field, t2.end_field),
+                    )
+                    # Check for Snowflake's precision limits: 10+ digits for operands, 9+ digits for results
+                    total1 = _calculate_total_months(snowpark_args[0])
+                    total2 = _calculate_total_months(snowpark_args[1])
+                    ten_digit_limit = snowpark_fn.lit(MAX_10_DIGIT_LIMIT)
+
+                    precision_violation = (
+                        # Check if either operand already reaches 10 digits (parsing limit)
+                        (snowpark_fn.abs(total1) >= ten_digit_limit)
+                        | (snowpark_fn.abs(total2) >= ten_digit_limit)
+                        | (
+                            (total1 > 0)
+                            & (total2 < 0)
+                            & (total1 >= ten_digit_limit + total2)
+                        )
+                        | (
+                            (total1 < 0)
+                            & (total2 > 0)
+                            & (total1 <= -ten_digit_limit + total2)
+                        )
+                    )
+
+                    raise_error = _raise_error_helper(result_type, ArithmeticException)
+                    result_exp = snowpark_fn.when(
+                        precision_violation,
+                        raise_error(
+                            snowpark_fn.lit(
+                                "Year-Month Interval result exceeds Snowflake interval precision limit"
+                            )
+                        ),
+                    ).otherwise(snowpark_args[0] - snowpark_args[1])
+                case (t1, t2) if isinstance(t1, DayTimeIntervalType) and isinstance(
+                    t2, DayTimeIntervalType
+                ):
+                    result_type = DayTimeIntervalType(
+                        min(t1.start_field, t2.start_field),
+                        max(t1.end_field, t2.end_field),
+                    )
+                    # Check for Snowflake's day limit (106751991 days is the cutoff)
+                    days1 = snowpark_fn.date_part("day", snowpark_args[0])
+                    days2 = snowpark_fn.date_part("day", snowpark_args[1])
+                    max_days = snowpark_fn.lit(
+                        MAX_DAY_TIME_DAYS
+                    )  # Snowflake's actual limit
+                    min_days = snowpark_fn.lit(-MAX_DAY_TIME_DAYS)
+
+                    # Check if either operand exceeds the day limit - throw error like Spark does
+                    operand_limit_violation = (snowpark_fn.abs(days1) > max_days) | (
+                        snowpark_fn.abs(days2) > max_days
+                    )
+
+                    # Check if result would exceed day limit (but operands are valid) - return NULL
+                    result_overflow = (
+                        (days1 > 0) & (days2 < 0) & (days1 > max_days + days2)
+                    ) | ((days1 < 0) & (days2 > 0) & (days1 < min_days + days2))
+
+                    raise_error = _raise_error_helper(result_type, ArithmeticException)
+                    result_exp = (
+                        snowpark_fn.when(
+                            operand_limit_violation,
+                            raise_error(
+                                snowpark_fn.lit(
+                                    "Day-Time Interval operand exceeds day limit"
+                                )
+                            ),
+                        )
+                        .when(result_overflow, snowpark_fn.lit(None))
+                        .otherwise(snowpark_args[0] - snowpark_args[1])
+                    )
+                case _:
+                    result_exp = _try_arithmetic_helper(
+                        snowpark_typed_args, snowpark_args, 1
+                    )
+                    result_exp = _type_with_typer(result_exp)
         case "try_to_number":
             try_to_number = snowpark_fn.function("try_to_number")
             precision, scale = resolve_to_number_precision_and_scale(exp)
@@ -8828,7 +9551,7 @@ map_unresolved_function(
         spark_col_names if len(spark_col_names) > 0 else [spark_function_name]
     )
     typed_col = _to_typed_column(result_exp, result_type, function_name)
-    typed_col.set_qualifiers(qualifiers)
+    typed_col.set_qualifiers({ColumnQualifier(tuple(qualifier_parts))})
    return spark_col_names, typed_col
 
 
@@ -9025,6 +9748,20 @@ _find_common_type(
             key_type = _common(type1.key_type, type2.key_type)
             value_type = _common(type1.value_type, type2.value_type)
             return MapType(key_type, value_type)
+        case (_, _) if isinstance(type1, YearMonthIntervalType) and isinstance(
+            type2, YearMonthIntervalType
+        ):
+            return YearMonthIntervalType(
+                min(type1.start_field, type2.start_field),
+                max(type1.end_field, type2.end_field),
+            )
+        case (_, _) if isinstance(type1, DayTimeIntervalType) and isinstance(
+            type2, DayTimeIntervalType
+        ):
+            return DayTimeIntervalType(
+                min(type1.start_field, type2.start_field),
+                max(type1.end_field, type2.end_field),
+            )
         case _:
             exception = AnalysisException(exception_base_message)
             attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
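Note: for two interval types of the same family, the common type spans the union of their field ranges (min of start_field, max of end_field). A sketch of the widening rule, using the same constructor the diff uses and the convention visible in _get_interval_type_name that field 0 is YEAR (field values here are illustrative):

    from snowflake.snowpark.types import YearMonthIntervalType

    t1 = YearMonthIntervalType(0, 0)  # INTERVAL YEAR
    t2 = YearMonthIntervalType(1, 1)  # INTERVAL MONTH
    common = YearMonthIntervalType(
        min(t1.start_field, t2.start_field), max(t1.end_field, t2.end_field)
    )  # -> INTERVAL YEAR TO MONTH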
@@ -10384,9 +11121,7 @@ _get_add_sub_result_type(
     return result_type, overflow_possible
 
 
-def _get_interval_type_name(
-    interval_type: Union[YearMonthIntervalType, DayTimeIntervalType]
-) -> str:
+def _get_interval_type_name(interval_type: _AnsiIntervalType) -> str:
     """Get the formatted interval type name for error messages."""
     if isinstance(interval_type, YearMonthIntervalType):
         if interval_type.start_field == 0 and interval_type.end_field == 0:
@@ -10413,21 +11148,15 @@ _check_interval_string_comparison(
 ) -> None:
     """Check for invalid interval-string comparisons and raise AnalysisException if found."""
     if (
-        isinstance(
-            snowpark_typed_args[0].typ, (YearMonthIntervalType, DayTimeIntervalType)
-        )
+        isinstance(snowpark_typed_args[0].typ, _AnsiIntervalType)
         and isinstance(snowpark_typed_args[1].typ, StringType)
         or isinstance(snowpark_typed_args[0].typ, StringType)
-        and isinstance(
-            snowpark_typed_args[1].typ, (YearMonthIntervalType, DayTimeIntervalType)
-        )
+        and isinstance(snowpark_typed_args[1].typ, _AnsiIntervalType)
     ):
         # Format interval type name for error message
         interval_type = (
             snowpark_typed_args[0].typ
-            if isinstance(
-                snowpark_typed_args[0].typ, (YearMonthIntervalType, DayTimeIntervalType)
-            )
+            if isinstance(snowpark_typed_args[0].typ, _AnsiIntervalType)
             else snowpark_typed_args[1].typ
         )
         interval_name = _get_interval_type_name(interval_type)
@@ -10494,12 +11223,18 @@ _get_spark_function_name(
         case (DateType(), DayTimeIntervalType()) | (
             DateType(),
             YearMonthIntervalType(),
+        ) | (TimestampType(), DayTimeIntervalType()) | (
+            TimestampType(),
+            YearMonthIntervalType(),
         ):
             date_param_name1 = _get_literal_param_name(exp, 0, snowpark_arg_names[0])
             return f"{date_param_name1} {operation_op} {snowpark_arg_names[1]}"
         case (DayTimeIntervalType(), DateType()) | (
             YearMonthIntervalType(),
             DateType(),
+        ) | (DayTimeIntervalType(), TimestampType()) | (
+            YearMonthIntervalType(),
+            TimestampType(),
         ):
             date_param_name2 = _get_literal_param_name(exp, 1, snowpark_arg_names[1])
             if function_name == "+":
@@ -10887,3 +11622,30 @@ _map_from_spark_tz(value: Column) -> Column:
         .when(value == "VST", snowpark_fn.lit("Asia/Ho_Chi_Minh"))
         .otherwise(value)  # Return original timezone if no mapping found
     )
+
+
+def _calculate_total_months(interval_arg):
+    """Calculate total months from a year-month interval."""
+    years = snowpark_fn.date_part("year", interval_arg)
+    months = snowpark_fn.date_part("month", interval_arg)
+    return years * 12 + months
+
+
+def _calculate_total_days(interval_arg):
+    """Calculate total days from a day-time interval."""
+    days = snowpark_fn.date_part("day", interval_arg)
+    hours = snowpark_fn.date_part("hour", interval_arg)
+    minutes = snowpark_fn.date_part("minute", interval_arg)
+    seconds = snowpark_fn.date_part("second", interval_arg)
+    # Convert hours, minutes, seconds to fractional days
+    fractional_days = (hours * 3600 + minutes * 60 + seconds) / 86400
+    return days + fractional_days
+
+
+def _calculate_total_seconds(interval_arg):
+    """Calculate total seconds from a day-time interval."""
+    days = snowpark_fn.date_part("day", interval_arg)
+    hours = snowpark_fn.date_part("hour", interval_arg)
+    minutes = snowpark_fn.date_part("minute", interval_arg)
+    seconds = snowpark_fn.date_part("second", interval_arg)
+    return days * 86400 + hours * 3600 + minutes * 60 + seconds
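Note: the three helpers decompose an interval with date_part and recombine it into a single total. A worked check of the `_calculate_total_seconds` arithmetic in plain Python:

    # INTERVAL '1 02:03:04' DAY TO SECOND, decomposed as date_part would:
    days, hours, minutes, seconds = 1, 2, 3, 4
    print(days * 86400 + hours * 3600 + minutes * 60 + seconds)  # 93784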