snowpark-connect 0.28.1__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of snowpark-connect has been flagged as potentially problematic.

Files changed (47)
  1. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  2. snowflake/snowpark_connect/client.py +65 -0
  3. snowflake/snowpark_connect/column_name_handler.py +6 -0
  4. snowflake/snowpark_connect/config.py +33 -5
  5. snowflake/snowpark_connect/execute_plan/map_execution_root.py +21 -19
  6. snowflake/snowpark_connect/expression/map_extension.py +277 -1
  7. snowflake/snowpark_connect/expression/map_sql_expression.py +107 -2
  8. snowflake/snowpark_connect/expression/map_unresolved_function.py +425 -269
  9. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  10. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  11. snowflake/snowpark_connect/relation/io_utils.py +21 -1
  12. snowflake/snowpark_connect/relation/map_column_ops.py +9 -4
  13. snowflake/snowpark_connect/relation/map_extension.py +21 -4
  14. snowflake/snowpark_connect/relation/map_join.py +8 -0
  15. snowflake/snowpark_connect/relation/map_map_partitions.py +7 -8
  16. snowflake/snowpark_connect/relation/map_relation.py +1 -3
  17. snowflake/snowpark_connect/relation/map_row_ops.py +116 -15
  18. snowflake/snowpark_connect/relation/map_show_string.py +14 -6
  19. snowflake/snowpark_connect/relation/map_sql.py +39 -5
  20. snowflake/snowpark_connect/relation/map_stats.py +1 -1
  21. snowflake/snowpark_connect/relation/read/map_read.py +22 -3
  22. snowflake/snowpark_connect/relation/read/map_read_csv.py +119 -29
  23. snowflake/snowpark_connect/relation/read/map_read_json.py +57 -36
  24. snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -1
  25. snowflake/snowpark_connect/relation/read/map_read_text.py +6 -1
  26. snowflake/snowpark_connect/relation/read/metadata_utils.py +159 -0
  27. snowflake/snowpark_connect/relation/stage_locator.py +85 -53
  28. snowflake/snowpark_connect/relation/write/map_write.py +67 -4
  29. snowflake/snowpark_connect/server.py +29 -16
  30. snowflake/snowpark_connect/type_mapping.py +75 -3
  31. snowflake/snowpark_connect/utils/context.py +0 -14
  32. snowflake/snowpark_connect/utils/describe_query_cache.py +6 -3
  33. snowflake/snowpark_connect/utils/io_utils.py +36 -0
  34. snowflake/snowpark_connect/utils/session.py +4 -0
  35. snowflake/snowpark_connect/utils/telemetry.py +30 -5
  36. snowflake/snowpark_connect/utils/udf_cache.py +37 -7
  37. snowflake/snowpark_connect/version.py +1 -1
  38. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/METADATA +3 -2
  39. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/RECORD +47 -45
  40. {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-connect +0 -0
  41. {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-session +0 -0
  42. {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-submit +0 -0
  43. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/WHEEL +0 -0
  44. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/LICENSE-binary +0 -0
  45. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/LICENSE.txt +0 -0
  46. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/NOTICE-binary +0 -0
  47. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/top_level.txt +0 -0
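The hunks below appear to come from snowflake/snowpark_connect/expression/map_unresolved_function.py (item 8 in the list above, +425 -269); changes to the other files are not reproduced in this excerpt.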
@@ -15,11 +15,12 @@ import tempfile
  import time
  import uuid
  from collections import defaultdict
+ from collections.abc import Callable
  from contextlib import suppress
  from decimal import ROUND_HALF_EVEN, ROUND_HALF_UP, Context, Decimal
  from functools import partial, reduce
  from pathlib import Path
- from typing import List, Optional
+ from typing import List, Optional, Union
  from urllib.parse import quote, unquote

  import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
@@ -49,6 +50,7 @@ from snowflake.snowpark.types import (
  ByteType,
  DataType,
  DateType,
+ DayTimeIntervalType,
  DecimalType,
  DoubleType,
  FloatType,
@@ -63,6 +65,7 @@ from snowflake.snowpark.types import (
  TimestampTimeZone,
  TimestampType,
  VariantType,
+ YearMonthIntervalType,
  _FractionalType,
  _IntegralType,
  _NumericType,
@@ -199,7 +202,7 @@ def _validate_numeric_args(
  case StringType():
  # Cast strings to doubles following Spark
  # https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala#L204
- modified_args[i] = _try_cast_helper(snowpark_args[i], DoubleType())
+ modified_args[i] = snowpark_fn.try_cast(snowpark_args[i], DoubleType())
  case _:
  raise TypeError(
  f"Data type mismatch: {function_name} requires numeric types, but got {typed_args[0].typ} and {typed_args[1].typ}."
@@ -519,7 +522,7 @@ def map_unresolved_function(
  DecimalType() as t,
  ):
  p1, s1 = _get_type_precision(t)
- result_type = _get_decimal_multiplication_result_type(
+ result_type, _ = _get_decimal_multiplication_result_type(
  p1, s1, p1, s1
  )
  result_exp = snowpark_fn.lit(None)
@@ -528,11 +531,17 @@ def map_unresolved_function(
  ):
  p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
  p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
- result_type = _get_decimal_multiplication_result_type(
- p1, s1, p2, s2
- )
- result_exp = _get_decimal_multiplication_result_exp(
- result_type, t, snowpark_args
+ (
+ result_type,
+ overflow_possible,
+ ) = _get_decimal_multiplication_result_type(p1, s1, p2, s2)
+ result_exp = _arithmetic_operation(
+ snowpark_typed_args[0],
+ snowpark_typed_args[1],
+ lambda x, y: x * y,
+ overflow_possible,
+ global_config.spark_sql_ansi_enabled,
+ result_type,
  )
  case (NullType(), NullType()):
  result_type = DoubleType()
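For reference: the multiplication branch above now derives both the result type and an overflow flag from the operand precisions, using the rule defined later in this diff in _get_decimal_multiplication_result_type (precision p1 + p2 + 1, scale s1 + s2, capped at Spark's maximum precision of 38). A minimal standalone sketch of that rule, not the shipped helper:

def decimal_multiplication_result_type(p1: int, s1: int, p2: int, s2: int):
    # Mirrors the helper shown further down in this diff; names here are illustrative.
    precision = p1 + p2 + 1
    scale = s1 + s2
    overflow_possible = False
    if precision > 38:  # Spark's DecimalType precision cap
        overflow_possible = True
        if scale > 6:
            scale = max(6, scale - (precision - 38))
        precision = 38
    return (precision, scale), overflow_possible

assert decimal_multiplication_result_type(20, 4, 25, 6) == ((38, 6), True)
assert decimal_multiplication_result_type(10, 2, 5, 3) == ((16, 5), False)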
@@ -617,7 +626,7 @@ def map_unresolved_function(
  )
  match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
  case (NullType(), _) | (_, NullType()):
- result_type = _get_add_sub_result_type(
+ result_type, _ = _get_add_sub_result_type(
  snowpark_typed_args[0].typ,
  snowpark_typed_args[1].typ,
  spark_function_name,
@@ -632,8 +641,17 @@ def map_unresolved_function(
  if isinstance(t, (IntegerType, ShortType, ByteType)):
  result_type = DateType()
  result_exp = snowpark_args[0] + snowpark_args[1]
+ elif isinstance(t, (DayTimeIntervalType, YearMonthIntervalType)):
+ result_type = TimestampType()
+ result_exp = (
+ snowpark_args[date_param_index]
+ + snowpark_args[t_param_index]
+ )
  elif (
- "INTERVAL"
+ hasattr(
+ snowpark_typed_args[t_param_index].col._expr1, "pretty_name"
+ )
+ and "INTERVAL"
  == snowpark_typed_args[t_param_index].col._expr1.pretty_name
  ):
  result_type = TimestampType()
@@ -693,14 +711,21 @@ def map_unresolved_function(
  f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{spark_function_name}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
  )
  case _:
- result_type = _get_add_sub_result_type(
+ result_type, overflow_possible = _get_add_sub_result_type(
  snowpark_typed_args[0].typ,
  snowpark_typed_args[1].typ,
  spark_function_name,
  )
- result_exp = snowpark_args[0] + snowpark_args[1]
- if isinstance(result_type, DecimalType):
- result_exp = _cast_helper(result_exp, result_type)
+
+ result_exp = _arithmetic_operation(
+ snowpark_typed_args[0],
+ snowpark_typed_args[1],
+ lambda x, y: x + y,
+ overflow_possible,
+ global_config.spark_sql_ansi_enabled,
+ result_type,
+ )
+
  case "-":
  spark_function_name = _get_spark_function_name(
  snowpark_typed_args[0],
@@ -715,7 +740,7 @@ def map_unresolved_function(
  result_type = LongType()
  result_exp = snowpark_fn.lit(None).cast(result_type)
  case (NullType(), _) | (_, NullType()):
- result_type = _get_add_sub_result_type(
+ result_type, _ = _get_add_sub_result_type(
  snowpark_typed_args[0].typ,
  snowpark_typed_args[1].typ,
  spark_function_name,
@@ -726,6 +751,12 @@ def map_unresolved_function(
  # TODO SNOW-2034420: resolve return type (it should be INTERVAL DAY)
  result_type = LongType()
  result_exp = snowpark_args[0] - snowpark_args[1]
+ case (DateType(), DayTimeIntervalType()) | (
+ DateType(),
+ YearMonthIntervalType(),
+ ):
+ result_type = TimestampType()
+ result_exp = snowpark_args[0] - snowpark_args[1]
  case (DateType(), StringType()):
  if (
  hasattr(snowpark_typed_args[1].col._expr1, "pretty_name")
@@ -806,14 +837,20 @@ def map_unresolved_function(
  f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{spark_function_name}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
  )
  case _:
- result_type = _get_add_sub_result_type(
+ result_type, overflow_possible = _get_add_sub_result_type(
  snowpark_typed_args[0].typ,
  snowpark_typed_args[1].typ,
  spark_function_name,
  )
- result_exp = snowpark_args[0] - snowpark_args[1]
- if isinstance(result_type, DecimalType):
- result_exp = _cast_helper(result_exp, result_type)
+ result_exp = _arithmetic_operation(
+ snowpark_typed_args[0],
+ snowpark_typed_args[1],
+ lambda x, y: x - y,
+ overflow_possible,
+ global_config.spark_sql_ansi_enabled,
+ result_type,
+ )
+
  case "/":
  match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
  case (DecimalType() as t1, NullType()):
@@ -825,15 +862,17 @@ def map_unresolved_function(
  ):
  p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
  p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
- result_type, overflow_detected = _get_decimal_division_result_type(
+ result_type, overflow_possible = _get_decimal_division_result_type(
  p1, s1, p2, s2
  )
- result_exp = _get_decimal_division_result_exp(
+
+ result_exp = _arithmetic_operation(
+ snowpark_typed_args[0],
+ snowpark_typed_args[1],
+ lambda x, y: _divnull(x, y),
+ overflow_possible,
+ global_config.spark_sql_ansi_enabled,
  result_type,
- t,
- overflow_detected,
- snowpark_args,
- spark_function_name,
  )
  case (NullType(), NullType()):
  result_type = DoubleType()
@@ -922,6 +961,10 @@ def map_unresolved_function(
  raise AnalysisException(
  f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} < {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
  )
+ # Check for interval-string comparisons
+ _check_interval_string_comparison(
+ "<", snowpark_typed_args, snowpark_arg_names
+ )
  left, right = _coerce_for_comparison(
  snowpark_typed_args[0], snowpark_typed_args[1]
  )
@@ -936,6 +979,10 @@ def map_unresolved_function(
  raise AnalysisException(
  f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} <= {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
  )
+ # Check for interval-string comparisons
+ _check_interval_string_comparison(
+ "<=", snowpark_typed_args, snowpark_arg_names
+ )
  left, right = _coerce_for_comparison(
  snowpark_typed_args[0], snowpark_typed_args[1]
  )
@@ -954,6 +1001,10 @@ def map_unresolved_function(
  )
  result_exp = TypedColumn(left.eqNullSafe(right), lambda: [BooleanType()])
  case "==" | "=":
+ # Check for interval-string comparisons
+ _check_interval_string_comparison(
+ "=", snowpark_typed_args, snowpark_arg_names
+ )
  spark_function_name = f"({snowpark_arg_names[0]} = {snowpark_arg_names[1]})"
  left, right = _coerce_for_comparison(
  snowpark_typed_args[0], snowpark_typed_args[1]
@@ -969,6 +1020,10 @@ def map_unresolved_function(
  raise AnalysisException(
  f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} > {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
  )
+ # Check for interval-string comparisons
+ _check_interval_string_comparison(
+ ">", snowpark_typed_args, snowpark_arg_names
+ )
  left, right = _coerce_for_comparison(
  snowpark_typed_args[0], snowpark_typed_args[1]
  )
@@ -983,6 +1038,10 @@ def map_unresolved_function(
  raise AnalysisException(
  f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} >= {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
  )
+ # Check for interval-string comparisons
+ _check_interval_string_comparison(
+ ">=", snowpark_typed_args, snowpark_arg_names
+ )
  left, right = _coerce_for_comparison(
  snowpark_typed_args[0], snowpark_typed_args[1]
  )
@@ -1113,9 +1172,7 @@ def map_unresolved_function(
  # SNOW-1955784: Support accuracy parameter
  # Use percentile_disc to return actual values from dataset (matches PySpark behavior)

- def _pyspark_approx_percentile(
- column: Column, percentage: float, original_type: DataType
- ) -> Column:
+ def _pyspark_approx_percentile(column: Column, percentage: float) -> Column:
  """
  PySpark-compatible percentile that returns actual values from dataset.
  - PySpark's approx_percentile returns the "smallest value in the ordered col values
@@ -1132,7 +1189,7 @@ def map_unresolved_function(
  result = snowpark_fn.function("percentile_disc")(
  snowpark_fn.lit(percentage)
  ).within_group(column)
- return snowpark_fn.cast(result, original_type)
+ return result

  column_type = snowpark_typed_args[0].typ

@@ -1143,26 +1200,18 @@ def map_unresolved_function(
  assert array_func.function_name == "array", array_func

  percentile_results = [
- _pyspark_approx_percentile(
- snowpark_args[0], unwrap_literal(arg), column_type
- )
+ _pyspark_approx_percentile(snowpark_args[0], unwrap_literal(arg))
  for arg in array_func.arguments
  ]

  result_type = ArrayType(element_type=column_type, contains_null=False)
- result_exp = snowpark_fn.cast(
- snowpark_fn.array_construct(*percentile_results),
- result_type,
- )
+ result_exp = snowpark_fn.array_construct(*percentile_results)
+ result_exp = _resolve_aggregate_exp(result_exp, result_type)
  else:
  # Handle single percentile
  percentage = unwrap_literal(exp.unresolved_function.arguments[1])
- result_exp = TypedColumn(
- _pyspark_approx_percentile(
- snowpark_args[0], percentage, column_type
- ),
- lambda: [column_type],
- )
+ result_exp = _pyspark_approx_percentile(snowpark_args[0], percentage)
+ result_exp = _resolve_aggregate_exp(result_exp, column_type)
  case "array":
  if len(snowpark_args) == 0:
  result_exp = snowpark_fn.cast(
@@ -1336,27 +1385,55 @@ def map_unresolved_function(
  result_exp = snowpark_fn.cast(result_exp, array_type)
  result_exp = TypedColumn(result_exp, lambda: snowpark_typed_args[0].types)
  case "array_repeat":
+ elem, count = snowpark_args[0], snowpark_args[1]
+ elem_type = snowpark_typed_args[0].typ
+ result_type = ArrayType(elem_type)

- @cached_udf(
- input_types=[VariantType(), LongType()],
- return_type=ArrayType(),
- )
- def _array_repeat(elem, n):
- if n is None:
- return None
- if n < 0:
- return []
- return [elem] * n
+ fallback_to_udf = True

- result_exp = snowpark_fn.cast(
- _array_repeat(
- snowpark_fn.cast(snowpark_args[0], VariantType()), snowpark_args[1]
- ),
- ArrayType(snowpark_typed_args[0].typ),
- )
- result_exp = TypedColumn(
- result_exp, lambda: [ArrayType(snowpark_typed_args[0].typ)]
- )
+ if isinstance(count._expression, Literal):
+ count_value = count._expression.value
+ fallback_to_udf = False
+
+ if count_value is None:
+ result_exp = snowpark_fn.lit(None).cast(result_type)
+ elif count_value <= 0:
+ result_exp = snowpark_fn.array_construct().cast(result_type)
+ elif count_value <= 16:
+ # count_value is small enough to initialize the array directly in memory
+ elem_variant = snowpark_fn.cast(elem, VariantType())
+ result_exp = snowpark_fn.array_construct(
+ *([elem_variant] * count_value)
+ ).cast(result_type)
+ else:
+ fallback_to_udf = True
+
+ if fallback_to_udf:
+
+ @cached_udf(
+ input_types=[VariantType(), LongType()],
+ return_type=ArrayType(),
+ )
+ def _array_repeat(elem, n):
+ if n is None:
+ return None
+ if n < 0:
+ return []
+ return [elem] * n
+
+ elem_variant = snowpark_fn.cast(elem, VariantType())
+
+ result_exp = (
+ snowpark_fn.when(
+ count.is_null(), snowpark_fn.lit(None).cast(result_type)
+ )
+ .when(count <= 0, snowpark_fn.array_construct().cast(result_type))
+ .otherwise(
+ snowpark_fn.cast(
+ _array_repeat(elem_variant, count), result_type
+ )
+ )
+ )
  case "array_size":
  array_type = snowpark_typed_args[0].typ
  if not isinstance(array_type, ArrayType):
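For reference: array_repeat now avoids the cached UDF whenever the repeat count is a literal. NULL and non-positive counts are answered directly, counts up to 16 are expanded inline with array_construct, and only larger or non-literal counts fall back to the UDF. A rough standalone sketch of that decision (illustrative names, not the shipped code):

def array_repeat_strategy(count_is_literal: bool, count_value=None) -> str:
    # Mirrors the branch structure of the hunk above; returns which plan would be built.
    if count_is_literal:
        if count_value is None:
            return "NULL literal cast to the array type"
        if count_value <= 0:
            return "empty array_construct() cast to the array type"
        if count_value <= 16:
            return "inline array_construct() with count_value copies of the element"
    return "cached _array_repeat UDF wrapped in when(count is null / count <= 0) guards"

assert array_repeat_strategy(True, 3).startswith("inline array_construct()")
assert array_repeat_strategy(False).startswith("cached _array_repeat UDF")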
@@ -1556,7 +1633,7 @@ def map_unresolved_function(
  result_exp = TypedColumn(result_exp, lambda: [LongType()])
  case "bit_get" | "getbit":
  snowflake_compat = get_boolean_session_config_param(
- "enable_snowflake_extension_behavior"
+ "snowpark.connect.enable_snowflake_extension_behavior"
  )
  col, pos = snowpark_args
  if snowflake_compat:
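Note that this session parameter is now read under its fully qualified name, snowpark.connect.enable_snowflake_extension_behavior; the same rename appears in the hash and md5 branches further down. Assuming the flag is surfaced through the Spark session configuration (the exact mechanism is not shown in this diff), enabling it would look roughly like:

# Assumption: the flag is settable via the Spark Connect session conf;
# "spark" is an existing snowpark-connect session.
spark.conf.set("snowpark.connect.enable_snowflake_extension_behavior", "true")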
@@ -1863,14 +1940,11 @@ def map_unresolved_function(
  qualifiers = snowpark_args[0].get_qualifiers()
  case "collect_list" | "array_agg":
  # TODO: SNOW-1967177 - Support structured types in array_agg
- result_exp = snowpark_fn.cast(
- snowpark_fn.array_agg(
- snowpark_typed_args[0].column(to_semi_structure=True)
- ),
- ArrayType(snowpark_typed_args[0].typ),
+ result_exp = snowpark_fn.array_agg(
+ snowpark_typed_args[0].column(to_semi_structure=True)
  )
- result_exp = TypedColumn(
- result_exp, lambda: [ArrayType(snowpark_typed_args[0].typ)]
+ result_exp = _resolve_aggregate_exp(
+ result_exp, ArrayType(snowpark_typed_args[0].typ)
  )
  spark_function_name = f"collect_list({snowpark_arg_names[0]})"
  case "collect_set":
@@ -2357,15 +2431,30 @@ def map_unresolved_function(
  # If format is NULL, return NULL for all rows
  result_exp = snowpark_fn.lit(None)
  else:
+ format_lit = snowpark_fn.lit(
+ map_spark_timestamp_format_expression(
+ exp.unresolved_function.arguments[1],
+ snowpark_typed_args[0].typ,
+ )
+ )
  result_exp = snowpark_fn.date_format(
  snowpark_args[0],
- snowpark_fn.lit(
- map_spark_timestamp_format_expression(
- exp.unresolved_function.arguments[1],
- snowpark_typed_args[0].typ,
- )
- ),
+ format_lit,
  )
+
+ if format_literal == "EEEE":
+ # TODO: SNOW-2356874, for weekday, Snowflake only supports abbreviated name, e.g. "Fri". Patch spark "EEEE" until
+ # snowflake supports full weekday name.
+ result_exp = (
+ snowpark_fn.when(result_exp == "Mon", "Monday")
+ .when(result_exp == "Tue", "Tuesday")
+ .when(result_exp == "Wed", "Wednesday")
+ .when(result_exp == "Thu", "Thursday")
+ .when(result_exp == "Fri", "Friday")
+ .when(result_exp == "Sat", "Saturday")
+ .when(result_exp == "Sun", "Sunday")
+ .otherwise(result_exp)
+ )
  result_exp = TypedColumn(result_exp, lambda: [StringType()])
  case "date_from_unix_date":
  result_exp = snowpark_fn.date_add(
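For reference: the new branch patches Spark's "EEEE" pattern by mapping Snowflake's abbreviated day names back to full names. A usage-level sketch, assuming an existing Spark Connect session named spark (not taken from the package's tests):

from pyspark.sql import functions as F

# 2024-03-01 is a Friday; with this release "EEEE" should yield "Friday" rather than "Fri".
df = spark.createDataFrame([("2024-03-01",)], ["d"]).select(F.to_date("d").alias("d"))
df.select(F.date_format("d", "EEEE").alias("weekday")).show()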
@@ -2464,6 +2553,14 @@ def map_unresolved_function(
  )
  result_type = LongType()
  case "date_part" | "datepart" | "extract":
+ # Check for interval types and throw NotImplementedError
+ if isinstance(
+ snowpark_typed_args[1].typ, (YearMonthIntervalType, DayTimeIntervalType)
+ ):
+ raise NotImplementedError(
+ f"{function_name} with interval types is not supported"
+ )
+
  field_lit: str | None = unwrap_literal(exp.unresolved_function.arguments[0])

  if field_lit is None:
@@ -3239,7 +3336,7 @@ def map_unresolved_function(
  # TODO: See the spark-compatibility-issues.md explanation, this is quite different from Spark.
  # MapType columns as input should raise an exception as they are not hashable.
  snowflake_compat = get_boolean_session_config_param(
- "enable_snowflake_extension_behavior"
+ "snowpark.connect.enable_snowflake_extension_behavior"
  )
  # Snowflake's hash function does allow MAP types, but Spark does not. Therefore, if we have the expansion flag enabled
  # we want to let it pass through and hash MAP types.
@@ -3746,10 +3843,21 @@ def map_unresolved_function(
  snowpark_fn.lit(is_outer),
  )
  case "input_file_name":
- # TODO: Use METADATA$FILENAME to get an actual file name, but this requires a refactor, because this is not possible straightaway here.
- raise SnowparkConnectNotImplementedError(
- "input_file_name is not yet supported."
+ # Return the filename metadata column for file-based DataFrames
+ # If METADATA$FILENAME doesn't exist (e.g., for DataFrames created from local data),
+ # return empty string to match Spark's behavior
+ from snowflake.snowpark_connect.relation.read.metadata_utils import (
+ METADATA_FILENAME_COLUMN,
  )
+
+ available_columns = column_mapping.get_snowpark_columns()
+ if METADATA_FILENAME_COLUMN in available_columns:
+ result_exp = snowpark_fn.col(METADATA_FILENAME_COLUMN)
+ else:
+ # Return empty when METADATA$FILENAME column doesn't exist, matching Spark behavior
+ result_exp = snowpark_fn.lit("").cast(StringType())
+ result_type = StringType()
+ spark_function_name = "input_file_name()"
  case "instr":
  result_exp = snowpark_fn.charindex(snowpark_args[1], snowpark_args[0])
  result_type = LongType()
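For reference: input_file_name() is now backed by the METADATA$FILENAME column that metadata_utils.py (file 26 above) attaches to file-based reads, and it falls back to an empty string when that column is absent. A hedged usage sketch, assuming an existing session and a hypothetical stage path:

from pyspark.sql import functions as F

# "@my_stage/data/" is a placeholder path; any file-based read should carry the filename metadata.
files_df = spark.read.option("header", "true").csv("@my_stage/data/")
files_df.select("*", F.input_file_name().alias("source_file")).show()

# A DataFrame built from local rows has no METADATA$FILENAME column, so the same expression returns "".
spark.createDataFrame([(1,)], ["x"]).select(F.input_file_name()).show()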
@@ -4709,7 +4817,7 @@ def map_unresolved_function(
  )
  case "md5":
  snowflake_compat = get_boolean_session_config_param(
- "enable_snowflake_extension_behavior"
+ "snowpark.connect.enable_snowflake_extension_behavior"
  )

  # MD5 in Spark only accepts BinaryType or types that can be implicitly cast to it (StringType)
@@ -5283,9 +5391,14 @@ def map_unresolved_function(
  result_exp = snowpark_fn.function(function_name)(
  _check_percentile_percentage(exp.unresolved_function.arguments[1])
  ).within_group(snowpark_args[0])
- result_exp = TypedColumn(
- snowpark_fn.cast(result_exp, FloatType()), lambda: [DoubleType()]
+ result_exp = (
+ TypedColumn(
+ snowpark_fn.cast(result_exp, FloatType()), lambda: [DoubleType()]
+ )
+ if not is_window_enabled()
+ else TypedColumnWithDeferredCast(result_exp, lambda: [DoubleType()])
  )
+
  spark_function_name = f"{function_name}({unwrap_literal(exp.unresolved_function.arguments[1])}) WITHIN GROUP (ORDER BY {snowpark_arg_names[0]})"
  case "pi":
  spark_function_name = "PI()"
@@ -7504,6 +7617,12 @@ def map_unresolved_function(
  )
  result_type = DateType()
  case "try_add":
+ # Check for interval types and throw NotImplementedError
+ for arg in snowpark_typed_args:
+ if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+ raise NotImplementedError(
+ "try_add with interval types is not supported"
+ )
  result_exp = _try_arithmetic_helper(snowpark_typed_args, snowpark_args, 0)
  result_exp = _type_with_typer(result_exp)
  case "try_aes_decrypt":
@@ -7557,6 +7676,12 @@ def map_unresolved_function(
  DoubleType(), cleaned, calculating_avg=True
  )
  case "try_divide":
+ # Check for interval types and throw NotImplementedError
+ for arg in snowpark_typed_args:
+ if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+ raise NotImplementedError(
+ "try_divide with interval types is not supported"
+ )
  match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
  case (NullType(), t) | (t, NullType()):
  result_exp = snowpark_fn.lit(None)
@@ -7580,63 +7705,20 @@ def map_unresolved_function(
  )
  | (DecimalType(), DecimalType())
  ):
- # compute new precision and scale using correct decimal division rules
- if isinstance(
- snowpark_typed_args[0].typ, DecimalType
- ) and isinstance(snowpark_typed_args[1].typ, DecimalType):
- s1, p1 = (
- snowpark_typed_args[0].typ.scale,
- snowpark_typed_args[0].typ.precision,
- )
- s2, p2 = (
- snowpark_typed_args[1].typ.scale,
- snowpark_typed_args[1].typ.precision,
- )
- # The scale and precision formula that Spark follows for DecimalType
- # arithmetic operations can be found in the following Spark source
- # code file:
- # https://github.com/apache/spark/blob/a584cc48ef63fefb2e035349c8684250f8b936c4/docs/sql-ref-ansi-compliance.md
- new_scale = max(6, s1 + p2 + 1)
- new_precision = p1 - s1 + s2 + new_scale
-
- elif isinstance(snowpark_typed_args[0].typ, DecimalType):
- s1, p1 = (
- snowpark_typed_args[0].typ.scale,
- snowpark_typed_args[0].typ.precision,
- )
- # INT is treated as Decimal(10, 0)
- new_scale = max(6, s1 + 11)
- new_precision = p1 - s1 + new_scale
-
- else: # right is DecimalType
- s2, p2 = (
- snowpark_typed_args[1].typ.scale,
- snowpark_typed_args[1].typ.precision,
- )
- # INT is treated as Decimal(10, 0)
- new_scale = max(6, 11 + p2)
- new_precision = (
- 10 - 0 + s2 + new_scale
- ) # INT has precision 10, scale 0
-
- # apply precision cap
- if new_precision > 38:
- new_scale -= new_precision - 38
- new_precision = 38
- new_scale = max(new_scale, 6)
-
- left_double = snowpark_fn.cast(snowpark_args[0], DoubleType())
- right_double = snowpark_fn.cast(snowpark_args[1], DoubleType())
-
- quotient = snowpark_fn.when(
- snowpark_args[1] == 0, snowpark_fn.lit(None)
- ).otherwise(left_double / right_double)
- quotient = snowpark_fn.cast(quotient, StringType())
+ p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
+ p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
+ result_type, overflow_possible = _get_decimal_division_result_type(
+ p1, s1, p2, s2
+ )

- result_exp = _try_cast_helper(
- quotient, DecimalType(new_precision, new_scale)
+ result_exp = _arithmetic_operation(
+ snowpark_typed_args[0],
+ snowpark_typed_args[1],
+ lambda x, y: _divnull(x, y),
+ overflow_possible,
+ False,
+ result_type,
  )
- result_type = DecimalType(new_precision, new_scale)
  case (_NumericType(), _NumericType()):
  result_exp = snowpark_fn.when(
  snowpark_args[1] == 0, snowpark_fn.lit(None)
@@ -7708,6 +7790,12 @@ def map_unresolved_function(
  f"Expected either (ArrayType, IntegralType) or (MapType, StringType), got {snowpark_typed_args[0].typ}, {snowpark_typed_args[1].typ}."
  )
  case "try_multiply":
+ # Check for interval types and throw NotImplementedError
+ for arg in snowpark_typed_args:
+ if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+ raise NotImplementedError(
+ "try_multiply with interval types is not supported"
+ )
  match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
  case (NullType(), t) | (t, NullType()):
  result_exp = snowpark_fn.lit(None)
@@ -7749,42 +7837,21 @@ def map_unresolved_function(
  )
  | (DecimalType(), DecimalType())
  ):
- # figure out what precision to use as the overflow amount
- if isinstance(
- snowpark_typed_args[0].typ, DecimalType
- ) and isinstance(snowpark_typed_args[1].typ, DecimalType):
- new_precision = (
- snowpark_typed_args[0].typ.precision
- + snowpark_typed_args[1].typ.precision
- + 1
- )
- new_scale = (
- snowpark_typed_args[0].typ.scale
- + snowpark_typed_args[1].typ.scale
- )
- elif isinstance(snowpark_typed_args[0].typ, DecimalType):
- new_precision = snowpark_typed_args[0].typ.precision + 11
- new_scale = snowpark_typed_args[0].typ.scale
- else:
- new_precision = snowpark_typed_args[1].typ.precision + 11
- new_scale = snowpark_typed_args[1].typ.scale
-
- # truncating down appropriately
- if new_precision > 38:
- new_precision = 38
- if new_scale > new_precision:
- new_scale = new_precision
-
- left_double = snowpark_fn.cast(snowpark_args[0], DoubleType())
- right_double = snowpark_fn.cast(snowpark_args[1], DoubleType())
-
- product = left_double * right_double
-
- product = snowpark_fn.cast(product, StringType())
- result_exp = _try_cast_helper(
- product, DecimalType(new_precision, new_scale)
+ p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
+ p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
+ (
+ result_type,
+ overflow_possible,
+ ) = _get_decimal_multiplication_result_type(p1, s1, p2, s2)
+
+ result_exp = _arithmetic_operation(
+ snowpark_typed_args[0],
+ snowpark_typed_args[1],
+ lambda x, y: x * y,
+ overflow_possible,
+ False,
+ result_type,
  )
- result_type = DecimalType(new_precision, new_scale)
  case (_NumericType(), _NumericType()):
  result_exp = snowpark_args[0] * snowpark_args[1]
  result_exp = _type_with_typer(result_exp)
@@ -7827,6 +7894,12 @@ def map_unresolved_function(
  snowpark_typed_args[0].typ, snowpark_args[0]
  )
  case "try_subtract":
+ # Check for interval types and throw NotImplementedError
+ for arg in snowpark_typed_args:
+ if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+ raise NotImplementedError(
+ "try_subtract with interval types is not supported"
+ )
  result_exp = _try_arithmetic_helper(snowpark_typed_args, snowpark_args, 1)
  result_exp = _type_with_typer(result_exp)
  case "try_to_number":
@@ -8391,20 +8464,10 @@ def map_unresolved_function(
  return spark_col_names, typed_col


- def _cast_helper(column: Column, to: DataType) -> Column:
- if global_config.spark_sql_ansi_enabled:
- column_mediator = (
- snowpark_fn.cast(column, StringType())
- if isinstance(to, DecimalType)
- else column
- )
- return snowpark_fn.cast(column_mediator, to)
- else:
- return _try_cast_helper(column, to)
-
-
  def _try_cast_helper(column: Column, to: DataType) -> Column:
  """
+ DEPRECATED because of performance issues
+
  Attempts to cast a given column to a specified data type using the same behaviour as Spark.

  Args:
@@ -9600,71 +9663,109 @@ def _decimal_add_sub_result_type_helper(p1, s1, p2, s2):
  return result_precision, min_scale, return_type_precision, return_type_scale


- def _get_decimal_multiplication_result_exp(
- result_type: DecimalType | DataType,
- other_type: DataType,
- snowpark_args: list[Column],
- ) -> Column:
- if global_config.spark_sql_ansi_enabled:
- result_exp = snowpark_args[0] * snowpark_args[1]
- else:
- if isinstance(other_type, _IntegralType):
- result_exp = snowpark_args[0].cast(result_type) * snowpark_args[1].cast(
- result_type
- )
- else:
- result_exp = snowpark_args[0].cast(DoubleType()) * snowpark_args[1].cast(
- DoubleType()
- )
- result_exp = _try_cast_helper(result_exp, result_type)
- return result_exp
-
-
- def _get_decimal_multiplication_result_type(p1, s1, p2, s2) -> DecimalType:
+ def _get_decimal_multiplication_result_type(p1, s1, p2, s2) -> tuple[DecimalType, bool]:
  result_precision = p1 + p2 + 1
  result_scale = s1 + s2
+ overflow_possible = False
  if result_precision > 38:
+ overflow_possible = True
  if result_scale > 6:
  overflow = result_precision - 38
  result_scale = max(6, result_scale - overflow)
  result_precision = 38
- return DecimalType(result_precision, result_scale)
+ return DecimalType(result_precision, result_scale), overflow_possible


- def _get_decimal_division_result_exp(
- result_type: DecimalType | DataType,
- other_type: DataType,
- overflow_detected: bool,
- snowpark_args: list[Column],
- spark_function_name: str,
+ def _arithmetic_operation(
+ arg1: TypedColumn,
+ arg2: TypedColumn,
+ op: Callable[[Column, Column], Column],
+ overflow_possible: bool,
+ should_raise_on_overflow: bool,
+ target_type: DecimalType,
  ) -> Column:
- if (
- isinstance(other_type, DecimalType)
- and overflow_detected
- and global_config.spark_sql_ansi_enabled
- ):
- raise ArithmeticException(
- f'[NUMERIC_VALUE_OUT_OF_RANGE] {spark_function_name} cannot be represented as Decimal({result_type.precision}, {result_type.scale}). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error, and return NULL instead.'
+ def _cast_arg(tc: TypedColumn) -> Column:
+ _, s = _get_type_precision(tc.typ)
+ typ = (
+ DoubleType()
+ if s > 0
+ or (
+ isinstance(tc.typ, _FractionalType)
+ and not isinstance(tc.typ, DecimalType)
+ )
+ else LongType()
+ )
+ return tc.col.cast(typ)
+
+ op_for_overflow_check = op(arg1.col.cast(DoubleType()), arg2.col.cast(DoubleType()))
+ safe_op = op(_cast_arg(arg1), _cast_arg(arg2))
+
+ if overflow_possible:
+ return _cast_arithmetic_operation_result(
+ op_for_overflow_check, safe_op, target_type, should_raise_on_overflow
  )
  else:
- dividend = snowpark_args[0].cast(DoubleType())
- divisor = snowpark_args[1]
- result_exp = _divnull(dividend, divisor)
- result_exp = _cast_helper(result_exp, result_type)
- return result_exp
+ return op(arg1.col, arg2.col).cast(target_type)
+
+
+ def _cast_arithmetic_operation_result(
+ overflow_check_expr: Column,
+ result_expr: Column,
+ target_type: DecimalType,
+ should_raise_on_overflow: bool,
+ ) -> Column:
+ """
+ Casts an arithmetic operation result to the target decimal type with overflow detection.
+ This function uses a dual-expression approach for robust overflow handling:
+ Args:
+ overflow_check_expr: Arithmetic expression using DoubleType operands for overflow detection.
+ This expression is used ONLY for boundary checking against the target
+ decimal's min/max values. DoubleType preserves the magnitude of large
+ intermediate results that might overflow in decimal arithmetic.
+ result_expr: Arithmetic expression using safer operand types (LongType for integers,
+ DoubleType for fractionals) for the actual result computation.
+ target_type: Target DecimalType to cast the result to.
+ should_raise_on_overflow: If True raises ArithmeticException on overflow, if False, returns NULL on overflow.
+ """
+
+ def create_overflow_handler(min_val, max_val, type_name: str):
+ if should_raise_on_overflow:
+ raise_error = _raise_error_helper(target_type, ArithmeticException)
+ return snowpark_fn.when(
+ (overflow_check_expr < snowpark_fn.lit(min_val))
+ | (overflow_check_expr > snowpark_fn.lit(max_val)),
+ raise_error(
+ snowpark_fn.lit(
+ f'[NUMERIC_VALUE_OUT_OF_RANGE] Value cannot be represented as {type_name}. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error, and return NULL instead.'
+ )
+ ),
+ ).otherwise(result_expr.cast(target_type))
+ else:
+ return snowpark_fn.when(
+ (overflow_check_expr < snowpark_fn.lit(min_val))
+ | (overflow_check_expr > snowpark_fn.lit(max_val)),
+ snowpark_fn.lit(None),
+ ).otherwise(result_expr.cast(target_type))
+
+ precision = target_type.precision
+ scale = target_type.scale
+
+ max_val = (10**precision - 1) / (10**scale)
+ min_val = -max_val
+
+ return create_overflow_handler(min_val, max_val, f"DECIMAL({precision},{scale})")


  def _get_decimal_division_result_type(p1, s1, p2, s2) -> tuple[DecimalType, bool]:
- overflow_detected = False
+ overflow_possible = False
  result_scale = max(6, s1 + p2 + 1)
  result_precision = p1 - s1 + s2 + result_scale
  if result_precision > 38:
- if result_precision > 40:
- overflow_detected = True
+ overflow_possible = True
  overflow = result_precision - 38
  result_scale = max(6, result_scale - overflow)
  result_precision = 38
- return DecimalType(result_precision, result_scale), overflow_detected
+ return DecimalType(result_precision, result_scale), overflow_possible


  def _try_arithmetic_helper(
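For reference: _cast_arithmetic_operation_result bounds the double-typed check expression by the largest magnitude the target decimal can hold, max_val = (10**precision - 1) / 10**scale, and then either returns NULL or raises depending on should_raise_on_overflow (which the call sites wire to spark.sql.ansi.enabled). A small standalone illustration of that bound, not the shipped code:

def decimal_bounds(precision: int, scale: int):
    # Same bound the diff computes: largest representable magnitude for DECIMAL(precision, scale).
    max_val = (10**precision - 1) / (10**scale)
    return -max_val, max_val

low, high = decimal_bounds(38, 6)

def handle_overflow(value: float, ansi_enabled: bool):
    # Mirrors the when/otherwise structure: out-of-range values become NULL (non-ANSI)
    # or an error (ANSI); in-range values are kept and cast to the target decimal.
    if value < low or value > high:
        if ansi_enabled:
            raise ArithmeticError("[NUMERIC_VALUE_OUT_OF_RANGE] value does not fit DECIMAL(38,6)")
        return None
    return value

assert handle_overflow(1e30, ansi_enabled=False) == 1e30
assert handle_overflow(1e40, ansi_enabled=False) is None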
@@ -9778,46 +9879,20 @@ def _try_arithmetic_helper(
  DecimalType(),
  DecimalType(),
  ):
- p1, s1 = _get_type_precision(typed_args[0].typ)
- p2, s2 = _get_type_precision(typed_args[1].typ)
-
- if isinstance(typed_args[0].typ, _IntegralType) and isinstance(
- typed_args[1].typ, DecimalType
- ):
- new_scale = s2
- new_precision = max(p2, p1 + s2)
- elif isinstance(typed_args[0].typ, DecimalType) and isinstance(
- typed_args[1].typ, _IntegralType
- ):
- new_scale = s1
- new_precision = max(p1, p2 + s1)
- else:
- # Both decimal types
- if operation_type == 1 and s1 == s2: # subtraction with matching scales
- new_scale = s1
- max_integral_digits = max(p1 - s1, p2 - s2)
- new_precision = max_integral_digits + new_scale
- else:
- new_scale = max(s1, s2)
- max_integral_digits = max(p1 - s1, p2 - s2)
- new_precision = max_integral_digits + new_scale + 1
-
- # Overflow check
- if new_precision > 38:
- if global_config.spark_sql_ansi_enabled:
- raise ArithmeticException(
- f'[NUMERIC_VALUE_OUT_OF_RANGE] Precision {new_precision} exceeds maximum allowed precision of 38. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error, and return NULL instead.'
- )
- return snowpark_fn.lit(None)
-
- left_operand, right_operand = snowpark_args[0], snowpark_args[1]
-
- result = (
- left_operand + right_operand
- if operation_type == 0
- else left_operand - right_operand
+ result_type, overflow_possible = _get_add_sub_result_type(
+ typed_args[0].typ,
+ typed_args[1].typ,
+ "try_add" if operation_type == 0 else "try_subtract",
+ )
+
+ return _arithmetic_operation(
+ typed_args[0],
+ typed_args[1],
+ lambda x, y: x + y if operation_type == 0 else x - y,
+ overflow_possible,
+ False,
+ result_type,
  )
- return snowpark_fn.cast(result, DecimalType(new_precision, new_scale))

  # If either of the inputs is floating point, we can just let it go through to Snowflake, where overflow
  # matches Spark and goes to inf.
@@ -9863,7 +9938,8 @@ def _get_add_sub_result_type(
  type1: DataType,
  type2: DataType,
  spark_function_name: str,
- ) -> DataType:
+ ) -> tuple[DataType, bool]:
+ overflow_possible = False
  result_type = _find_common_type([type1, type2])
  match result_type:
  case DecimalType():
@@ -9872,6 +9948,7 @@ def _get_add_sub_result_type(
  result_scale = max(s1, s2)
  result_precision = max(p1 - s1, p2 - s2) + result_scale + 1
  if result_precision > 38:
+ overflow_possible = True
  if result_scale > 6:
  overflow = result_precision - 38
  result_scale = max(6, result_scale - overflow)
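For reference: for + and - the decimal branch keeps Spark's rule (scale = max(s1, s2), precision = max(p1 - s1, p2 - s2) + scale + 1) with the same cap-at-38 adjustment, and this change additionally reports whether the cap was hit. A worked standalone example, not the shipped helper:

def decimal_add_sub_result_type(p1, s1, p2, s2):
    scale = max(s1, s2)
    precision = max(p1 - s1, p2 - s2) + scale + 1
    overflow_possible = False
    if precision > 38:
        overflow_possible = True
        if scale > 6:
            scale = max(6, scale - (precision - 38))
        precision = 38
    return (precision, scale), overflow_possible

# DECIMAL(38,18) + DECIMAL(38,18): the exact result needs precision 39, so the
# scale is trimmed to 17 and overflow becomes possible.
assert decimal_add_sub_result_type(38, 18, 38, 18) == ((38, 17), True)
assert decimal_add_sub_result_type(10, 2, 12, 4) == ((13, 4), False)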
@@ -9900,7 +9977,71 @@ def _get_add_sub_result_type(
  raise AnalysisException(
  f'[DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: the binary operator requires the input type ("NUMERIC" or "INTERVAL DAY TO SECOND" or "INTERVAL YEAR TO MONTH" or "INTERVAL"), not "BOOLEAN".',
  )
- return result_type
+ return result_type, overflow_possible
+
+
+ def _get_interval_type_name(
+ interval_type: Union[YearMonthIntervalType, DayTimeIntervalType]
+ ) -> str:
+ """Get the formatted interval type name for error messages."""
+ if isinstance(interval_type, YearMonthIntervalType):
+ if interval_type.start_field == 0 and interval_type.end_field == 0:
+ return "INTERVAL YEAR"
+ elif interval_type.start_field == 1 and interval_type.end_field == 1:
+ return "INTERVAL MONTH"
+ else:
+ return "INTERVAL YEAR TO MONTH"
+ else: # DayTimeIntervalType
+ if interval_type.start_field == 0 and interval_type.end_field == 0:
+ return "INTERVAL DAY"
+ elif interval_type.start_field == 1 and interval_type.end_field == 1:
+ return "INTERVAL HOUR"
+ elif interval_type.start_field == 2 and interval_type.end_field == 2:
+ return "INTERVAL MINUTE"
+ elif interval_type.start_field == 3 and interval_type.end_field == 3:
+ return "INTERVAL SECOND"
+ else:
+ return "INTERVAL DAY TO SECOND"
+
+
+ def _check_interval_string_comparison(
+ operator: str, snowpark_typed_args: List[TypedColumn], snowpark_arg_names: List[str]
+ ) -> None:
+ """Check for invalid interval-string comparisons and raise AnalysisException if found."""
+ if (
+ isinstance(
+ snowpark_typed_args[0].typ, (YearMonthIntervalType, DayTimeIntervalType)
+ )
+ and isinstance(snowpark_typed_args[1].typ, StringType)
+ or isinstance(snowpark_typed_args[0].typ, StringType)
+ and isinstance(
+ snowpark_typed_args[1].typ, (YearMonthIntervalType, DayTimeIntervalType)
+ )
+ ):
+ # Format interval type name for error message
+ interval_type = (
+ snowpark_typed_args[0].typ
+ if isinstance(
+ snowpark_typed_args[0].typ, (YearMonthIntervalType, DayTimeIntervalType)
+ )
+ else snowpark_typed_args[1].typ
+ )
+ interval_name = _get_interval_type_name(interval_type)
+
+ left_type = (
+ "STRING"
+ if isinstance(snowpark_typed_args[0].typ, StringType)
+ else interval_name
+ )
+ right_type = (
+ "STRING"
+ if isinstance(snowpark_typed_args[1].typ, StringType)
+ else interval_name
+ )
+
+ raise AnalysisException(
+ f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "({snowpark_arg_names[0]} {operator} {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{left_type}" and "{right_type}").;'
+ )


  def _get_spark_function_name(
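For reference: _get_interval_type_name maps the interval's start_field/end_field pair to the name Spark uses in error text, and _check_interval_string_comparison raises AnalysisException whenever one operand of <, <=, =, >, >= is an interval and the other a string. A standalone sketch of the name mapping only (field codes as they appear in the diff; the real helper takes Snowpark type objects):

def interval_type_name(is_year_month: bool, start_field: int, end_field: int) -> str:
    # Field codes follow the diff: YearMonth uses 0=YEAR, 1=MONTH;
    # DayTime uses 0=DAY, 1=HOUR, 2=MINUTE, 3=SECOND.
    if is_year_month:
        single = {0: "INTERVAL YEAR", 1: "INTERVAL MONTH"}
        default = "INTERVAL YEAR TO MONTH"
    else:
        single = {0: "INTERVAL DAY", 1: "INTERVAL HOUR", 2: "INTERVAL MINUTE", 3: "INTERVAL SECOND"}
        default = "INTERVAL DAY TO SECOND"
    if start_field == end_field and start_field in single:
        return single[start_field]
    return default

assert interval_type_name(True, 0, 0) == "INTERVAL YEAR"
assert interval_type_name(False, 0, 3) == "INTERVAL DAY TO SECOND"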
@@ -9944,6 +10085,21 @@ def _get_spark_function_name(
  return f"{date_param_name1} {operation_op} {snowpark_arg_names[1]}"
  else:
  return f"{operation_func}(cast({date_param_name1} as date), cast({snowpark_arg_names[1]} as double))"
+ case (DateType(), DayTimeIntervalType()) | (
+ DateType(),
+ YearMonthIntervalType(),
+ ):
+ date_param_name1 = _get_literal_param_name(exp, 0, snowpark_arg_names[0])
+ return f"{date_param_name1} {operation_op} {snowpark_arg_names[1]}"
+ case (DayTimeIntervalType(), DateType()) | (
+ YearMonthIntervalType(),
+ DateType(),
+ ):
+ date_param_name2 = _get_literal_param_name(exp, 1, snowpark_arg_names[1])
+ if function_name == "+":
+ return f"{date_param_name2} {operation_op} {snowpark_arg_names[0]}"
+ else:
+ return default_spark_function_name
  case (DateType() as dt, _) | (_, DateType() as dt):
  date_param_index = 0 if dt == col1.typ else 1
  date_param_name = _get_literal_param_name(