snowpark-connect 0.31.0__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of snowpark-connect might be problematic.

Files changed (111)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/column_name_handler.py +143 -105
  3. snowflake/snowpark_connect/column_qualifier.py +43 -0
  4. snowflake/snowpark_connect/dataframe_container.py +3 -2
  5. snowflake/snowpark_connect/execute_plan/map_execution_command.py +4 -2
  6. snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
  7. snowflake/snowpark_connect/expression/map_expression.py +5 -4
  8. snowflake/snowpark_connect/expression/map_extension.py +12 -6
  9. snowflake/snowpark_connect/expression/map_sql_expression.py +50 -7
  10. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +62 -25
  11. snowflake/snowpark_connect/expression/map_unresolved_function.py +924 -127
  12. snowflake/snowpark_connect/expression/map_unresolved_star.py +9 -7
  13. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
  14. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
  15. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
  16. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
  17. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +4 -1
  18. snowflake/snowpark_connect/relation/map_aggregate.py +6 -5
  19. snowflake/snowpark_connect/relation/map_column_ops.py +9 -3
  20. snowflake/snowpark_connect/relation/map_extension.py +10 -9
  21. snowflake/snowpark_connect/relation/map_join.py +219 -144
  22. snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
  23. snowflake/snowpark_connect/relation/map_sql.py +134 -16
  24. snowflake/snowpark_connect/relation/map_subquery_alias.py +4 -1
  25. snowflake/snowpark_connect/relation/read/map_read_json.py +87 -2
  26. snowflake/snowpark_connect/relation/read/map_read_table.py +6 -3
  27. snowflake/snowpark_connect/relation/utils.py +46 -0
  28. snowflake/snowpark_connect/relation/write/map_write.py +215 -289
  29. snowflake/snowpark_connect/resources_initializer.py +25 -13
  30. snowflake/snowpark_connect/server.py +10 -26
  31. snowflake/snowpark_connect/type_mapping.py +38 -3
  32. snowflake/snowpark_connect/typed_column.py +8 -6
  33. snowflake/snowpark_connect/utils/sequence.py +21 -0
  34. snowflake/snowpark_connect/utils/session.py +27 -4
  35. snowflake/snowpark_connect/version.py +1 -1
  36. snowflake/snowpark_decoder/dp_session.py +1 -1
  37. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/METADATA +7 -2
  38. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/RECORD +46 -105
  39. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  46. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  47. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  48. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  52. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  53. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  54. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  55. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  56. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  57. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  58. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  59. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  60. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  61. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  62. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  63. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  64. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  65. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  66. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  67. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  68. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  69. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  70. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  71. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  72. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  73. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  74. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  75. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  76. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  77. snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
  78. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  79. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  80. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  81. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  82. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  83. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  84. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
  85. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  86. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  87. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  88. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  89. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  90. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  91. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  92. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  93. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  94. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  95. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  96. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  97. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  98. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  99. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  100. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  101. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  102. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
  103. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
  104. {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-connect +0 -0
  105. {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-session +0 -0
  106. {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-submit +0 -0
  107. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/WHEEL +0 -0
  108. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE-binary +0 -0
  109. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE.txt +0 -0
  110. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/NOTICE-binary +0 -0
  111. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/top_level.txt +0 -0
@@ -20,11 +20,10 @@ from contextlib import suppress
 from decimal import ROUND_HALF_EVEN, ROUND_HALF_UP, Context, Decimal
 from functools import partial, reduce
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional
 from urllib.parse import quote, unquote
 
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
-import pyspark.sql.functions as pyspark_functions
 from google.protobuf.message import Message
 from pyspark.errors.exceptions.base import (
     AnalysisException,
@@ -66,6 +65,7 @@ from snowflake.snowpark.types import (
     TimestampType,
     VariantType,
     YearMonthIntervalType,
+    _AnsiIntervalType,
     _FractionalType,
     _IntegralType,
     _NumericType,
@@ -74,6 +74,7 @@ from snowflake.snowpark_connect.column_name_handler import (
     ColumnNameMap,
     set_schema_getter,
 )
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.config import (
     get_boolean_session_config_param,
     get_timestamp_type,
@@ -99,6 +100,7 @@ from snowflake.snowpark_connect.expression.map_unresolved_star import (
 )
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.catalogs.utils import CURRENT_CATALOG_NAME
+from snowflake.snowpark_connect.relation.utils import is_aggregate_function
 from snowflake.snowpark_connect.type_mapping import (
     map_json_schema_to_snowpark,
     map_pyspark_types_to_snowpark_types,
@@ -148,7 +150,11 @@ from snowflake.snowpark_connect.utils.xxhash64 import (
 MAX_UINT64 = 2**64 - 1
 MAX_INT64 = 2**63 - 1
 MIN_INT64 = -(2**63)
-MAX_ARRAY_SIZE = 2_147_483_647
+MAX_32BIT_SIGNED_INT = 2_147_483_647
+
+# Interval arithmetic precision limits
+MAX_DAY_TIME_DAYS = 106751991  # Maximum days for day-time intervals
+MAX_10_DIGIT_LIMIT = 1000000000  # 10-digit limit (1 billion) for interval operands
 
 NAN, INFINITY = float("nan"), float("inf")
 
@@ -272,6 +278,40 @@ def _coerce_for_comparison(
     return left_col, right_col
 
 
+def _preprocess_not_equals_expression(exp: expressions_proto.Expression) -> str:
+    """
+    Transform NOT(col1 = col2) expressions to col1 != col2 for Snowflake compatibility.
+
+    Snowflake has issues with NOT (col1 = col2) in subqueries, so we rewrite
+    not(==(a, b)) to a != b by modifying the protobuf expression early.
+
+    Returns:
+        The (potentially modified) function name as a lowercase string.
+    """
+    function_name = exp.unresolved_function.function_name.lower()
+
+    # Snowflake has issues with NOT (col1 = col2) in subqueries.
+    # Transform not(==(a, b)) to a != b by modifying the protobuf early.
+    if (
+        function_name in ("not", "!")
+        and len(exp.unresolved_function.arguments) == 1
+        and exp.unresolved_function.arguments[0].WhichOneof("expr_type")
+        == "unresolved_function"
+        and exp.unresolved_function.arguments[0].unresolved_function.function_name
+        == "=="
+    ):
+        inner_eq_func = exp.unresolved_function.arguments[0].unresolved_function
+        inner_args = list(inner_eq_func.arguments)
+
+        exp.unresolved_function.function_name = "!="
+        exp.unresolved_function.ClearField("arguments")
+        exp.unresolved_function.arguments.extend(inner_args)
+
+        function_name = "!="
+
+    return function_name
+
+
 def map_unresolved_function(
     exp: expressions_proto.Expression,
     column_mapping: ColumnNameMap,
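
The rewrite is easiest to see on a hand-built message. A minimal sketch (field names follow the Spark Connect `Expression` proto; the helper is the one added above):

    import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto

    # Build not(==(a, b)) the way a Spark Connect client would send it.
    exp = expressions_proto.Expression()
    exp.unresolved_function.function_name = "not"
    inner = exp.unresolved_function.arguments.add()
    inner.unresolved_function.function_name = "=="
    for name in ("a", "b"):
        arg = inner.unresolved_function.arguments.add()
        arg.unresolved_attribute.unparsed_identifier = name

    # _preprocess_not_equals_expression(exp) returns "!=" and flattens the
    # tree: the inner equality's arguments are hoisted into the outer
    # function, so the SQL generator never emits NOT (a = b) in a subquery.
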
@@ -300,6 +340,9 @@ def map_unresolved_function(
     # Inject default parameters for functions that need them (especially for Scala clients)
     inject_function_defaults(exp.unresolved_function)
 
+    # Transform NOT(col = col) to col != col for Snowflake compatibility
+    function_name = _preprocess_not_equals_expression(exp)
+
     def _resolve_args_expressions(exp: expressions_proto.Expression):
         def _resolve_fn_arg(exp):
             with resolving_fun_args():
@@ -355,11 +398,10 @@ map_unresolved_function(
     function_name = exp.unresolved_function.function_name.lower()
     telemetry.report_function_usage(function_name)
     result_type: Optional[DataType | List[DateType]] = None
-    qualifiers: List[str] = []
+    qualifier_parts: List[str] = []
 
-    pyspark_func = getattr(pyspark_functions, function_name, None)
-    if pyspark_func and pyspark_func.__doc__.lstrip().startswith("Aggregate function:"):
-        # Used by the GROUP BY ALL implementation. Far from ideal, but it seems to work...
+    # Check if this is an aggregate function (used by GROUP BY ALL implementation)
+    if is_aggregate_function(function_name):
         add_sql_aggregate_function()
 
     def _type_with_typer(col: Column) -> TypedColumn:
@@ -513,9 +555,17 @@ map_unresolved_function(
             )
             result_type = [f.datatype for f in udtf.output_schema]
         case "!=":
-            result_exp = TypedColumn(
-                snowpark_args[0] != snowpark_args[1], lambda: [BooleanType()]
+            _check_interval_string_comparison(
+                "!=", snowpark_typed_args, snowpark_arg_names
+            )
+            # Make the function name same as spark connect. a != b translate's to not(a=b)
+            spark_function_name = (
+                f"(NOT ({snowpark_arg_names[0]} = {snowpark_arg_names[1]}))"
+            )
+            left, right = _coerce_for_comparison(
+                snowpark_typed_args[0], snowpark_typed_args[1]
             )
+            result_exp = TypedColumn(left != right, lambda: [BooleanType()])
         case "%" | "mod":
             if spark_sql_ansi_enabled:
                 result_exp = snowpark_args[0] % snowpark_args[1]
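
For reference, Spark itself labels the result of `!=` as `(NOT (a = b))`, which is the spelling the new `spark_function_name` reproduces. Observable from a plain PySpark 3.5 session:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 2)], ["a", "b"])

    # Spark lowers a != b to not(a = b), so the column label keeps that
    # shape even though the engine can evaluate a plain !=.
    print(df.select(df.a != df.b).columns)  # ['(NOT (a = b))']
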
@@ -616,12 +666,87 @@ map_unresolved_function(
                     result_exp = snowpark_args[0] * snowpark_args[1].try_cast(
                         result_type
                     )
-                case (_NumericType() as t, NullType()) | (
-                    NullType(),
-                    _NumericType() as t,
+                case (StringType(), t) | (t, StringType()) if isinstance(
+                    t, _AnsiIntervalType
+                ):
+                    if isinstance(snowpark_typed_args[0].typ, StringType):
+                        result_type = type(
+                            t
+                        )()  # YearMonthIntervalType() or DayTimeIntervalType()
+                        result_exp = snowpark_args[1] * snowpark_args[0].try_cast(
+                            LongType()
+                        )
+                        spark_function_name = (
+                            f"({snowpark_arg_names[1]} * {snowpark_arg_names[0]})"
+                        )
+                    else:
+                        result_type = type(
+                            t
+                        )()  # YearMonthIntervalType() or DayTimeIntervalType()
+                        result_exp = snowpark_args[0] * snowpark_args[1].try_cast(
+                            LongType()
+                        )
+                        spark_function_name = (
+                            f"({snowpark_arg_names[0]} * {snowpark_arg_names[1]})"
+                        )
+                case (
+                    (_NumericType() as t, NullType())
+                    | (NullType(), _NumericType() as t)
                 ):
                     result_type = t
                     result_exp = snowpark_fn.lit(None)
+                case (NullType(), t) | (t, NullType()) if isinstance(
+                    t, _AnsiIntervalType
+                ):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_fn.lit(None)
+                    if isinstance(snowpark_typed_args[0].typ, NullType):
+                        spark_function_name = (
+                            f"({snowpark_arg_names[1]} * {snowpark_arg_names[0]})"
+                        )
+                    else:
+                        spark_function_name = (
+                            f"({snowpark_arg_names[0]} * {snowpark_arg_names[1]})"
+                        )
+                case (DecimalType(), t) | (t, DecimalType()) if isinstance(
+                    t, _AnsiIntervalType
+                ):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    if isinstance(snowpark_typed_args[0].typ, DecimalType):
+                        result_exp = snowpark_args[1] * snowpark_args[0]
+                        spark_function_name = (
+                            f"({snowpark_arg_names[1]} * {snowpark_arg_names[0]})"
+                        )
+                    else:
+                        result_exp = snowpark_args[0] * snowpark_args[1]
+                        spark_function_name = (
+                            f"({snowpark_arg_names[0]} * {snowpark_arg_names[1]})"
+                        )
+                case (t, _NumericType()) if isinstance(t, _AnsiIntervalType):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_args[0] * snowpark_args[1]
+                case (_NumericType(), t) if isinstance(t, _AnsiIntervalType):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_args[1] * snowpark_args[0]
+                    spark_function_name = (
+                        f"({snowpark_arg_names[1]} * {snowpark_arg_names[0]})"
+                    )
                 case (_NumericType(), _NumericType()):
                     result_type = _find_common_type(
                         [arg.typ for arg in snowpark_typed_args]
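
The Spark semantics these multiplication cases chase, as a runnable reference (Spark 3.5; expected results in the trailing comments):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Interval * numeric keeps the interval type: 18 months * 2 = 36 months.
    spark.sql("SELECT INTERVAL '1-6' YEAR TO MONTH * 2").show()      # 3-0
    # A NULL operand keeps the interval result type but yields NULL.
    spark.sql("SELECT INTERVAL '2' DAY * CAST(NULL AS INT)").show()  # NULL
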
@@ -662,7 +787,14 @@ map_unresolved_function(
                         result_type = DateType()
                         result_exp = snowpark_args[0] + snowpark_args[1]
                     elif isinstance(t, (DayTimeIntervalType, YearMonthIntervalType)):
-                        result_type = TimestampType()
+                        result_type = (
+                            TimestampType()
+                            if isinstance(
+                                snowpark_typed_args[t_param_index].typ,
+                                DayTimeIntervalType,
+                            )
+                            else DateType()
+                        )
                         result_exp = (
                             snowpark_args[date_param_index]
                             + snowpark_args[t_param_index]
@@ -685,6 +817,35 @@ map_unresolved_function(
                         )
                         attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
                         raise exception
+                case (TimestampType(), t) | (t, TimestampType()):
+                    timestamp_param_index = (
+                        0
+                        if isinstance(snowpark_typed_args[0].typ, TimestampType)
+                        else 1
+                    )
+                    t_param_index = 1 - timestamp_param_index
+                    if isinstance(t, (DayTimeIntervalType, YearMonthIntervalType)):
+                        result_type = TimestampType()
+                        result_exp = (
+                            snowpark_args[timestamp_param_index]
+                            + snowpark_args[t_param_index]
+                        )
+                    elif (
+                        hasattr(
+                            snowpark_typed_args[t_param_index].col._expr1, "pretty_name"
+                        )
+                        and "INTERVAL"
+                        == snowpark_typed_args[t_param_index].col._expr1.pretty_name
+                    ):
+                        result_type = TimestampType()
+                        result_exp = (
+                            snowpark_args[timestamp_param_index]
+                            + snowpark_args[t_param_index]
+                        )
+                    else:
+                        raise AnalysisException(
+                            f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: Parameter 2 requires the ("INTERVAL") type for timestamp operations, however "{snowpark_arg_names[t_param_index]}" has the type "{t}".',
+                        )
                 case (StringType(), StringType()):
                     if spark_sql_ansi_enabled:
                         exception = AnalysisException(
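
A PySpark-side check of the behavior this case targets (a sketch; Spark 3.5):

    import datetime
    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(datetime.datetime(2024, 1, 31, 12, 0),)], ["ts"])

    # timestamp + day-time interval stays a timestamp...
    df.select(F.col("ts") + F.expr("INTERVAL '1' DAY")).show()
    # ...while a non-interval second operand fails analysis, matching the
    # DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE error raised above.
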
@@ -736,6 +897,99 @@ map_unresolved_function(
                         )
                         attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
                         raise exception
+                case (t1, t2) | (t2, t1) if isinstance(
+                    t1, _AnsiIntervalType
+                ) and isinstance(t2, _AnsiIntervalType) and type(t1) == type(t2):
+                    # Both operands are the same interval type
+                    result_type = type(t1)(
+                        min(t1.start_field, t2.start_field),
+                        max(t1.end_field, t2.end_field),
+                    )
+                    result_exp = snowpark_args[0] + snowpark_args[1]
+                case (StringType(), t) | (t, StringType()) if isinstance(
+                    t, YearMonthIntervalType
+                ):
+                    # String + YearMonthInterval: Spark tries to cast string to double first, throws error if it fails
+                    result_type = StringType()
+                    raise_error = _raise_error_helper(StringType(), AnalysisException)
+                    if isinstance(snowpark_typed_args[0].typ, StringType):
+                        # Try to cast string to double, if it fails (returns null), raise exception
+                        cast_result = snowpark_fn.try_cast(snowpark_args[0], "double")
+                        result_exp = snowpark_fn.when(
+                            cast_result.is_null(),
+                            raise_error(
+                                snowpark_fn.lit(
+                                    f'The value \'{snowpark_args[0]}\' of the type {snowpark_typed_args[0].typ} cannot be cast to "DOUBLE" because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.'
+                                )
+                            ),
+                        ).otherwise(cast_result + snowpark_args[1])
+                    else:
+                        cast_result = snowpark_fn.try_cast(snowpark_args[1], "double")
+                        result_exp = snowpark_fn.when(
+                            cast_result.is_null(),
+                            raise_error(
+                                snowpark_fn.lit(
+                                    f'The value \'{snowpark_args[0]}\' of the type {snowpark_typed_args[0].typ} cannot be cast to "DOUBLE" because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.'
+                                )
+                            ),
+                        ).otherwise(snowpark_args[0] + cast_result)
+                case (StringType(), t) | (t, StringType()) if isinstance(
+                    t, DayTimeIntervalType
+                ):
+                    # String + DayTimeInterval: try to parse string as timestamp, return NULL if it fails
+                    # For time-only strings (like '10:00:00'), prepend current date to make it a full timestamp
+                    result_type = StringType()
+                    if isinstance(snowpark_typed_args[0].typ, StringType):
+                        # Check if string looks like time-only (HH:MM:SS or HH:MM pattern)
+                        # If so, prepend current date; otherwise use as-is
+                        time_only_pattern = snowpark_fn.function("regexp_like")(
+                            snowpark_args[0], r"^\d{1,2}:\d{2}(:\d{2})?$"
+                        )
+                        timestamp_expr = snowpark_fn.when(
+                            time_only_pattern,
+                            snowpark_fn.function("try_to_timestamp_ntz")(
+                                snowpark_fn.function("concat")(
+                                    snowpark_fn.function("to_char")(
+                                        snowpark_fn.function("current_date")(),
+                                        "YYYY-MM-DD",
+                                    ),
+                                    snowpark_fn.lit(" "),
+                                    snowpark_args[0],
+                                )
+                            ),
+                        ).otherwise(
+                            snowpark_fn.function("try_to_timestamp_ntz")(
+                                snowpark_args[0]
+                            )
+                        )
+                        result_exp = timestamp_expr + snowpark_args[1]
+                    else:
+                        # interval + string case
+                        time_only_pattern = snowpark_fn.function("regexp_like")(
+                            snowpark_args[1], r"^\d{1,2}:\d{2}(:\d{2})?$"
+                        )
+                        timestamp_expr = snowpark_fn.when(
+                            time_only_pattern,
+                            snowpark_fn.function("try_to_timestamp_ntz")(
+                                snowpark_fn.function("concat")(
+                                    snowpark_fn.function("to_char")(
+                                        snowpark_fn.function("current_date")(),
+                                        "'YYYY-MM-DD'",
+                                    ),
+                                    snowpark_fn.lit(" "),
+                                    snowpark_args[1],
+                                )
+                            ),
+                        ).otherwise(
+                            snowpark_fn.function("try_to_timestamp_ntz")(
+                                snowpark_args[1]
+                            )
+                        )
+                        result_exp = snowpark_args[0] + timestamp_expr
+                    spark_function_name = (
+                        f"{snowpark_arg_names[0]} + {snowpark_arg_names[1]}"
+                    )
+
                 case _:
                     result_type, overflow_possible = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
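
The time-only heuristic is just the regex used above; it can be checked in isolation:

    import re

    # The same pattern the mapping hands to regexp_like to spot time-only
    # strings such as '10:00:00' before prepending the current date.
    TIME_ONLY = re.compile(r"^\d{1,2}:\d{2}(:\d{2})?$")

    for s in ("10:00:00", "9:30", "2024-01-01 10:00:00"):
        print(s, bool(TIME_ONLY.match(s)))
    # 10:00:00 True
    # 9:30 True
    # 2024-01-01 10:00:00 False
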
@@ -781,7 +1035,11 @@ map_unresolved_function(
                     DateType(),
                     YearMonthIntervalType(),
                 ):
-                    result_type = TimestampType()
+                    result_type = (
+                        TimestampType()
+                        if isinstance(snowpark_typed_args[1].typ, DayTimeIntervalType)
+                        else DateType()
+                    )
                     result_exp = snowpark_args[0] - snowpark_args[1]
                 case (DateType(), StringType()):
                     if (
@@ -799,6 +1057,23 @@ map_unresolved_function(
                         result_exp = snowpark_args[0] - snowpark_args[1].cast(
                             input_type
                         )
+                case (TimestampType(), DayTimeIntervalType()) | (
+                    TimestampType(),
+                    YearMonthIntervalType(),
+                ):
+                    result_type = TimestampType()
+                    result_exp = snowpark_args[0] - snowpark_args[1]
+                case (TimestampType(), StringType()):
+                    if (
+                        hasattr(snowpark_typed_args[1].col._expr1, "pretty_name")
+                        and "INTERVAL" == snowpark_typed_args[1].col._expr1.pretty_name
+                    ):
+                        result_type = TimestampType()
+                        result_exp = snowpark_args[0] - snowpark_args[1]
+                    else:
+                        raise AnalysisException(
+                            f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: Parameter 2 requires the ("INTERVAL") type for timestamp operations, however "{snowpark_arg_names[1]}" has the type "{snowpark_typed_args[1].typ}".',
+                        )
                 case (StringType(), DateType()):
                     # TODO SNOW-2034420: resolve return type (it should be INTERVAL DAY)
                     result_type = LongType()
@@ -870,6 +1145,16 @@ map_unresolved_function(
                         )
                         attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
                         raise exception
+                case (StringType(), t) if isinstance(t, _AnsiIntervalType):
+                    # String - Interval: try to parse string as timestamp, return NULL if it fails
+                    result_type = StringType()
+                    result_exp = (
+                        snowpark_fn.function("try_to_timestamp")(snowpark_args[0])
+                        - snowpark_args[1]
+                    )
+                    spark_function_name = (
+                        f"{snowpark_arg_names[0]} - {snowpark_arg_names[1]}"
+                    )
                 case _:
                     result_type, overflow_possible = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
@@ -968,9 +1253,57 @@ map_unresolved_function(
                     result_exp = _divnull(
                         snowpark_args[0], snowpark_args[1].try_cast(result_type)
                     )
+                case (t, StringType()) if isinstance(t, _AnsiIntervalType):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_args[0] / snowpark_args[1].try_cast(
+                        LongType()
+                    )
+                    spark_function_name = (
+                        f"({snowpark_arg_names[0]} / {snowpark_arg_names[1]})"
+                    )
                 case (_NumericType(), NullType()) | (NullType(), _NumericType()):
                     result_type = DoubleType()
                     result_exp = snowpark_fn.lit(None)
+                case (t, NullType()) if isinstance(t, _AnsiIntervalType):
+                    # Only allow interval / null, not null / interval
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_fn.lit(None)
+                    spark_function_name = (
+                        f"({snowpark_arg_names[0]} / {snowpark_arg_names[1]})"
+                    )
+                case (DecimalType(), t) | (t, DecimalType()) if isinstance(
+                    t, _AnsiIntervalType
+                ):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    if isinstance(snowpark_typed_args[0].typ, DecimalType):
+                        result_exp = snowpark_args[1] / snowpark_args[0]
+                        spark_function_name = (
+                            f"({snowpark_arg_names[1]} / {snowpark_arg_names[0]})"
+                        )
+                    else:
+                        result_exp = snowpark_args[0] / snowpark_args[1]
+                        spark_function_name = (
+                            f"({snowpark_arg_names[0]} / {snowpark_arg_names[1]})"
+                        )
+                case (t, _NumericType()) if isinstance(t, _AnsiIntervalType):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_args[0] / snowpark_args[1]
                 case (_NumericType(), _NumericType()):
                     result_type = DoubleType()
                     result_exp = _divnull(
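
A runnable Spark-side reference for the new division cases (Spark 3.5; expected results in trailing comments):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Interval / numeric keeps the interval type.
    spark.sql("SELECT INTERVAL '2-0' YEAR TO MONTH / 2").show()          # 1-0
    spark.sql("SELECT INTERVAL '10 00:00:00' DAY TO SECOND / 4").show()  # 2 12:00:00
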
@@ -2027,11 +2360,6 @@ map_unresolved_function(
             result_exp = snowpark_fn.coalesce(
                 *[col.cast(result_type) for col in snowpark_args]
             )
-        case "col":
-            # TODO: assign type
-            result_exp = snowpark_fn.col(*snowpark_args)
-            result_exp = _type_with_typer(result_exp)
-            qualifiers = snowpark_args[0].get_qualifiers()
         case "collect_list" | "array_agg":
             # TODO: SNOW-1967177 - Support structured types in array_agg
             result_exp = snowpark_fn.array_agg(
@@ -2049,11 +2377,6 @@ map_unresolved_function(
             result_exp = _resolve_aggregate_exp(
                 result_exp, ArrayType(snowpark_typed_args[0].typ)
             )
-        case "column":
-            # TODO: assign type
-            result_exp = snowpark_fn.column(*snowpark_args)
-            result_exp = _type_with_typer(result_exp)
-            qualifiers = snowpark_args[0].get_qualifiers()
         case "concat":
             if len(snowpark_args) == 0:
                 result_exp = TypedColumn(snowpark_fn.lit(""), lambda: [StringType()])
@@ -2232,7 +2555,7 @@ map_unresolved_function(
                     snowpark_fn.col("*", _is_qualified_name=True)
                 )
             else:
-                result_exp = snowpark_fn.count(*snowpark_args)
+                result_exp = snowpark_fn.call_function("COUNT", *snowpark_args)
             result_exp = TypedColumn(result_exp, lambda: [LongType()])
         case "count_if":
             result_exp = snowpark_fn.call_function("COUNT_IF", snowpark_args[0])
@@ -2670,16 +2993,6 @@ map_unresolved_function(
             )
             result_type = LongType()
         case "date_part" | "datepart" | "extract":
-            # Check for interval types and throw NotImplementedError
-            if isinstance(
-                snowpark_typed_args[1].typ, (YearMonthIntervalType, DayTimeIntervalType)
-            ):
-                exception = NotImplementedError(
-                    f"{function_name} with interval types is not supported"
-                )
-                attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
-                raise exception
-
             field_lit: str | None = unwrap_literal(exp.unresolved_function.arguments[0])
 
             if field_lit is None:
@@ -2724,16 +3037,51 @@ map_unresolved_function(
         case "div":
            # Only called from SQL, either as `a div b` or `div(a, b)`
            # Convert it into `(a - a % b) / b`.
-            result_exp = snowpark_fn.cast(
-                (snowpark_args[0] - snowpark_args[0] % snowpark_args[1])
-                / snowpark_args[1],
-                LongType(),
-            )
-            if not spark_sql_ansi_enabled:
-                result_exp = snowpark_fn.when(
-                    snowpark_args[1] == 0, snowpark_fn.lit(None)
-                ).otherwise(result_exp)
-            result_type = LongType()
+            if isinstance(snowpark_typed_args[0].typ, YearMonthIntervalType):
+                if isinstance(snowpark_typed_args[1].typ, YearMonthIntervalType):
+                    dividend_total = _calculate_total_months(snowpark_args[0])
+                    divisor_total = _calculate_total_months(snowpark_args[1])
+
+                    # Handle division by zero interval
+                    if not spark_sql_ansi_enabled:
+                        result_exp = snowpark_fn.when(
+                            divisor_total == 0, snowpark_fn.lit(None)
+                        ).otherwise(snowpark_fn.trunc(dividend_total / divisor_total))
+                    else:
+                        result_exp = snowpark_fn.trunc(dividend_total / divisor_total)
+                    result_type = LongType()
+                else:
+                    raise AnalysisException(
+                        f"""[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "({snowpark_arg_names[0]} div {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ({snowpark_typed_args[0].typ} and {snowpark_typed_args[1].typ}).;"""
+                    )
+            elif isinstance(snowpark_typed_args[0].typ, DayTimeIntervalType):
+                if isinstance(snowpark_typed_args[1].typ, DayTimeIntervalType):
+                    dividend_total = _calculate_total_seconds(snowpark_args[0])
+                    divisor_total = _calculate_total_seconds(snowpark_args[1])
+
+                    # Handle division by zero interval
+                    if not spark_sql_ansi_enabled:
+                        result_exp = snowpark_fn.when(
+                            divisor_total == 0, snowpark_fn.lit(None)
+                        ).otherwise(snowpark_fn.trunc(dividend_total / divisor_total))
+                    else:
+                        result_exp = snowpark_fn.trunc(dividend_total / divisor_total)
+                    result_type = LongType()
+                else:
+                    raise AnalysisException(
+                        f"""[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "({snowpark_arg_names[0]} div {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ({snowpark_typed_args[0].typ} and {snowpark_typed_args[1].typ}).;"""
+                    )
+            else:
+                result_exp = snowpark_fn.cast(
+                    (snowpark_args[0] - snowpark_args[0] % snowpark_args[1])
+                    / snowpark_args[1],
+                    LongType(),
+                )
+                if not spark_sql_ansi_enabled:
+                    result_exp = snowpark_fn.when(
+                        snowpark_args[1] == 0, snowpark_fn.lit(None)
+                    ).otherwise(result_exp)
+                result_type = LongType()
         case "e":
             spark_function_name = "E()"
             result_exp = snowpark_fn.lit(math.e)
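
Worked numbers for the interval `div` path: both sides are reduced to a common unit, divided, and truncated toward zero, yielding a BIGINT.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # 25 months / 12 months = 2.083... -> trunc -> 2
    spark.sql("SELECT INTERVAL '25' MONTH div INTERVAL '12' MONTH").show()  # 2
    # With ANSI mode off, a zero-length divisor interval yields NULL
    # instead of a division-by-zero error, mirroring the when/otherwise above.
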
@@ -3565,8 +3913,6 @@ map_unresolved_function(
             result should be either way good enough.
             """
 
-            from datetime import date, datetime, time, timedelta
-
             def __init__(self) -> None:
 
                 # init the RNG for breaking ties in histogram merging. A fixed seed is specified here
@@ -3710,7 +4056,8 @@ map_unresolved_function(
                 # just increment 'bin'. This is not done now because we don't want to make any
                 # assumptions about the range of numeric data being analyzed.
                 if bin < self.n_used_bins and self.bins[bin][0] == v:
-                    self.bins[bin][1] += 1
+                    bin_x, bin_y = self.bins[bin]
+                    self.bins[bin] = (bin_x, bin_y + 1)
                 else:
                     self.bins.insert(bin + 1, (v, 1.0))
                     self.n_used_bins += 1
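
The one-line version had to go because the bins hold (value, count) tuples (see the `insert` call above), and a tuple cannot be updated in place:

    bins = [(2.5, 1.0)]
    try:
        bins[0][1] += 1          # the old code path
    except TypeError as e:
        print(e)                 # 'tuple' object does not support item assignment
    bin_x, bin_y = bins[0]       # the new code path: rebuild the tuple
    bins[0] = (bin_x, bin_y + 1)
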
@@ -4504,6 +4851,17 @@ map_unresolved_function(
             date_str_exp = snowpark_fn.concat(y, dash, m, dash, d)
             result_exp = snowpark_fn.builtin(snowpark_function)(date_str_exp)
             result_type = DateType()
+        case "make_dt_interval":
+            # Pad argument names for display purposes
+            padded_arg_names = snowpark_arg_names.copy()
+            while len(padded_arg_names) < 3:  # days, hours, minutes are integers
+                padded_arg_names.append("0")
+            if len(padded_arg_names) < 4:  # seconds can be decimal
+                padded_arg_names.append("0.000000")
+
+            spark_function_name = f"make_dt_interval({', '.join(padded_arg_names)})"
+            result_exp = snowpark_fn.interval_day_time_from_parts(*snowpark_args)
+            result_type = DayTimeIntervalType()
         case "make_timestamp" | "make_timestamp_ltz" | "make_timestamp_ntz":
             y, m, d, h, mins = map(lambda col: col.cast(LongType()), snowpark_args[:5])
             y_abs = snowpark_fn.abs(y)
@@ -4557,6 +4915,15 @@ map_unresolved_function(
             result_exp = snowpark_fn.when(
                 snowpark_fn.is_null(parsed_str_exp), snowpark_fn.lit(None)
             ).otherwise(make_timestamp_res)
+        case "make_ym_interval":
+            # Pad argument names for display purposes
+            padded_arg_names = snowpark_arg_names.copy()
+            while len(padded_arg_names) < 2:  # years, months
+                padded_arg_names.append("0")
+
+            spark_function_name = f"make_ym_interval({', '.join(padded_arg_names)})"
+            result_exp = snowpark_fn.interval_year_month_from_parts(*snowpark_args)
+            result_type = YearMonthIntervalType()
         case "map":
             allow_duplicate_keys = (
                 global_config.spark_sql_mapKeyDedupPolicy == "LAST_WIN"
@@ -5211,7 +5578,11 @@ map_unresolved_function(
                 spark_function_name = f"(- {snowpark_arg_names[0]})"
             else:
                 spark_function_name = f"negative({snowpark_arg_names[0]})"
-            if isinstance(arg_type, _NumericType):
+            if (
+                isinstance(arg_type, _NumericType)
+                or isinstance(arg_type, YearMonthIntervalType)
+                or isinstance(arg_type, DayTimeIntervalType)
+            ):
                 # Instead of using snowpark_fn.negate which can generate invalid SQL for nested minus operations,
                 # use a direct multiplication by -1 which generates cleaner SQL
                 result_exp = snowpark_args[0] * snowpark_fn.lit(-1)
@@ -5236,6 +5607,8 @@ map_unresolved_function(
             result_type = (
                 snowpark_typed_args[0].types
                 if isinstance(arg_type, _NumericType)
+                or isinstance(arg_type, YearMonthIntervalType)
+                or isinstance(arg_type, DayTimeIntervalType)
                 else DoubleType()
             )
         case "next_day":
@@ -5616,9 +5989,33 @@ map_unresolved_function(
         case "percentile_cont" | "percentiledisc":
             if function_name == "percentiledisc":
                 function_name = "percentile_disc"
+            order_by_col = snowpark_args[0]
+            args = exp.unresolved_function.arguments
+            if len(args) != 3:
+                exception = AssertionError(
+                    f"{function_name} expected 3 args but got {len(args)}"
+                )
+                attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+                raise exception
+            # literal value 0.0 - 1.0
+            percentage_arg = args[1]
+            sort_direction = args[2].sort_order.direction
+            direction_str = ""  # defaultValue
+            if (
+                sort_direction
+                == expressions_proto.Expression.SortOrder.SORT_DIRECTION_DESCENDING
+            ):
+                direction_str = "DESC"
+
+            # Apply sort direction to the order_by column
+            if direction_str == "DESC":
+                order_by_col_with_direction = order_by_col.desc()
+            else:
+                order_by_col_with_direction = order_by_col.asc()
+
             result_exp = snowpark_fn.function(function_name)(
-                _check_percentile_percentage(exp.unresolved_function.arguments[1])
-            ).within_group(snowpark_args[0])
+                _check_percentile_percentage(percentage_arg)
+            ).within_group(order_by_col_with_direction)
             result_exp = (
                 TypedColumn(
                     snowpark_fn.cast(result_exp, FloatType()), lambda: [DoubleType()]
@@ -5627,7 +6024,8 @@ map_unresolved_function(
                 else TypedColumnWithDeferredCast(result_exp, lambda: [DoubleType()])
             )
 
-            spark_function_name = f"{function_name}({unwrap_literal(exp.unresolved_function.arguments[1])}) WITHIN GROUP (ORDER BY {snowpark_arg_names[0]})"
+            direction_part = f" {direction_str}" if direction_str else ""
+            spark_function_name = f"{function_name}({unwrap_literal(percentage_arg)}) WITHIN GROUP (ORDER BY {snowpark_arg_names[0]}{direction_part})"
         case "pi":
             spark_function_name = "PI()"
             result_exp = snowpark_fn.lit(math.pi)
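
Threading the sort direction through makes descending WITHIN GROUP queries work; a quick check against Spark itself (3.5):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    spark.sql(
        "SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) "
        "FROM VALUES (1.0), (2.0), (3.0), (4.0) AS t(v)"
    ).show()
    # 3.25: with DESC the 0.25 quantile is interpolated over 4,3,2,1
    # instead of 1,2,3,4.
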
@@ -5767,7 +6165,11 @@ map_unresolved_function(
         case "positive":
             arg_type = snowpark_typed_args[0].typ
             spark_function_name = f"(+ {snowpark_arg_names[0]})"
-            if isinstance(arg_type, _NumericType):
+            if (
+                isinstance(arg_type, _NumericType)
+                or isinstance(arg_type, YearMonthIntervalType)
+                or isinstance(arg_type, DayTimeIntervalType)
+            ):
                 result_exp = snowpark_args[0]
             elif isinstance(arg_type, StringType):
                 if spark_sql_ansi_enabled:
@@ -5790,17 +6192,23 @@ map_unresolved_function(
             result_type = (
                 snowpark_typed_args[0].types
                 if isinstance(arg_type, _NumericType)
+                or isinstance(arg_type, YearMonthIntervalType)
+                or isinstance(arg_type, DayTimeIntervalType)
                 else DoubleType()
             )
-
-        case "pow":
-            result_exp = snowpark_fn.pow(snowpark_args[0], snowpark_args[1])
-            result_type = DoubleType()
-        case "power":
-            spark_function_name = (
-                f"POWER({snowpark_arg_names[0]}, {snowpark_arg_names[1]})"
-            )
-            result_exp = snowpark_fn.pow(snowpark_args[0], snowpark_args[1])
+        case "pow" | "power":
+            spark_function_name = f"{function_name if function_name == 'pow' else function_name.upper()}({snowpark_arg_names[0]}, {snowpark_arg_names[1]})"
+            if not spark_sql_ansi_enabled:
+                snowpark_args = _validate_numeric_args(
+                    function_name, snowpark_typed_args, snowpark_args
+                )
+            result_exp = snowpark_fn.when(
+                snowpark_fn.equal_nan(snowpark_fn.cast(snowpark_args[0], FloatType()))
+                | snowpark_fn.equal_nan(
+                    snowpark_fn.cast(snowpark_args[1], FloatType())
+                ),
+                NAN,
+            ).otherwise(snowpark_fn.pow(snowpark_args[0], snowpark_args[1]))
             result_type = DoubleType()
         case "product":
             col = snowpark_args[0]
@@ -6660,11 +7068,43 @@ map_unresolved_function(
                 fn_name = "sign"
 
             spark_function_name = f"{fn_name}({snowpark_arg_names[0]})"
-            result_exp = snowpark_fn.when(
-                snowpark_args[0] == NAN, snowpark_fn.lit(NAN)
-            ).otherwise(
-                snowpark_fn.cast(snowpark_fn.sign(snowpark_args[0]), DoubleType())
-            )
+
+            if isinstance(snowpark_typed_args[0].typ, YearMonthIntervalType):
+                # Use SQL expression for zero year-month interval comparison
+                result_exp = (
+                    snowpark_fn.when(
+                        snowpark_args[0]
+                        > snowpark_fn.sql_expr("INTERVAL '0-0' YEAR TO MONTH"),
+                        snowpark_fn.lit(1.0),
+                    )
+                    .when(
+                        snowpark_args[0]
+                        < snowpark_fn.sql_expr("INTERVAL '0-0' YEAR TO MONTH"),
+                        snowpark_fn.lit(-1.0),
+                    )
+                    .otherwise(snowpark_fn.lit(0.0))
+                )
+            elif isinstance(snowpark_typed_args[0].typ, DayTimeIntervalType):
+                # Use SQL expression for zero day-time interval comparison
+                result_exp = (
+                    snowpark_fn.when(
+                        snowpark_args[0]
+                        > snowpark_fn.sql_expr("INTERVAL '0 0:0:0' DAY TO SECOND"),
+                        snowpark_fn.lit(1.0),
+                    )
+                    .when(
+                        snowpark_args[0]
+                        < snowpark_fn.sql_expr("INTERVAL '0 0:0:0' DAY TO SECOND"),
+                        snowpark_fn.lit(-1.0),
+                    )
+                    .otherwise(snowpark_fn.lit(0.0))
+                )
+            else:
+                result_exp = snowpark_fn.when(
+                    snowpark_args[0] == NAN, snowpark_fn.lit(NAN)
+                ).otherwise(
+                    snowpark_fn.cast(snowpark_fn.sign(snowpark_args[0]), DoubleType())
+                )
             result_type = DoubleType()
         case "sin":
             spark_function_name = f"SIN({snowpark_arg_names[0]})"
@@ -6909,7 +7349,16 @@ map_unresolved_function(
             )
             raise exception
         case "split_part":
-            result_exp = snowpark_fn.call_function("split_part", *snowpark_args)
+            # Check for index 0 and throw error to match PySpark behavior
+            raise_error = _raise_error_helper(StringType(), SparkRuntimeException)
+            result_exp = snowpark_fn.when(
+                snowpark_args[2] == 0,
+                raise_error(
+                    snowpark_fn.lit(
+                        "[INVALID_INDEX_OF_ZERO] The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1)."
+                    )
+                ),
+            ).otherwise(snowpark_fn.call_function("split_part", *snowpark_args))
             result_type = StringType()
         case "sqrt":
             spark_function_name = f"SQRT({snowpark_arg_names[0]})"
@@ -7939,18 +8388,123 @@ map_unresolved_function(
             )
             result_type = DateType()
         case "try_add":
-            # Check for interval types and throw NotImplementedError
-            for arg in snowpark_typed_args:
-                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
-                    exception = NotImplementedError(
-                        "try_add with interval types is not supported"
+            # Handle interval arithmetic with overflow detection
+            match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (DateType(), t) | (t, DateType()) if isinstance(
+                    t, YearMonthIntervalType
+                ):
+                    result_type = DateType()
+                    result_exp = snowpark_args[0] + snowpark_args[1]
+                case (DateType(), t) | (t, DateType()) if isinstance(
+                    t, DayTimeIntervalType
+                ):
+                    result_type = TimestampType()
+                    result_exp = snowpark_args[0] + snowpark_args[1]
+                case (TimestampType(), t) | (t, TimestampType()) if isinstance(
+                    t, (DayTimeIntervalType, YearMonthIntervalType)
+                ):
+                    result_type = (
+                        snowpark_typed_args[0].typ
+                        if isinstance(snowpark_typed_args[0].typ, TimestampType)
+                        else snowpark_typed_args[1].typ
                     )
-                    attach_custom_error_code(
-                        exception, ErrorCodes.UNSUPPORTED_OPERATION
+                    result_exp = snowpark_args[0] + snowpark_args[1]
+                case (t1, t2) if (
+                    isinstance(t1, YearMonthIntervalType)
+                    and isinstance(t2, (_NumericType, StringType))
+                ) or (
+                    isinstance(t2, YearMonthIntervalType)
+                    and isinstance(t1, (_NumericType, StringType))
+                ):
+                    # YearMonthInterval + numeric/string or numeric/string + YearMonthInterval should throw error
+                    exception = AnalysisException(
+                        f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "try_add({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
                     )
+                    attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
                     raise exception
-            result_exp = _try_arithmetic_helper(snowpark_typed_args, snowpark_args, 0)
-            result_exp = _type_with_typer(result_exp)
+                case (t1, t2) if isinstance(t1, YearMonthIntervalType) and isinstance(
+                    t2, YearMonthIntervalType
+                ):
+                    result_type = YearMonthIntervalType(
+                        min(t1.start_field, t2.start_field),
+                        max(t1.end_field, t2.end_field),
+                    )
+
+                    # For year-month intervals, throw ArithmeticException if operands reach 10+ digits OR result exceeds 9 digits
+                    total1 = _calculate_total_months(snowpark_args[0])
+                    total2 = _calculate_total_months(snowpark_args[1])
+                    ten_digit_limit = snowpark_fn.lit(MAX_10_DIGIT_LIMIT)
+
+                    precision_violation = (
+                        # Check if either operand already reaches 10 digits (parsing limit)
+                        (snowpark_fn.abs(total1) >= ten_digit_limit)
+                        | (snowpark_fn.abs(total2) >= ten_digit_limit)
+                        | (
+                            (total1 > 0)
+                            & (total2 > 0)
+                            & (total1 >= ten_digit_limit - total2)
+                        )
+                        | (
+                            (total1 < 0)
+                            & (total2 < 0)
+                            & (total1 <= -ten_digit_limit - total2)
+                        )
+                    )
+
+                    raise_error = _raise_error_helper(result_type, ArithmeticException)
+                    result_exp = snowpark_fn.when(
+                        precision_violation,
+                        raise_error(
+                            snowpark_fn.lit(
+                                "Year-Month Interval result exceeds Snowflake interval precision limit"
+                            )
+                        ),
+                    ).otherwise(snowpark_args[0] + snowpark_args[1])
+                case (t1, t2) if isinstance(t1, DayTimeIntervalType) and isinstance(
+                    t2, DayTimeIntervalType
+                ):
+                    result_type = DayTimeIntervalType(
+                        min(t1.start_field, t2.start_field),
+                        max(t1.end_field, t2.end_field),
+                    )
+                    # Check for Snowflake's day limit (106751991 days is the cutoff)
+                    days1 = snowpark_fn.date_part("day", snowpark_args[0])
+                    days2 = snowpark_fn.date_part("day", snowpark_args[1])
+                    max_days = snowpark_fn.lit(
+                        MAX_DAY_TIME_DAYS
+                    )  # Snowflake's actual limit
+                    min_days = snowpark_fn.lit(-MAX_DAY_TIME_DAYS)
+
+                    # Check if either operand exceeds the day limit - throw error like Spark does
+                    operand_limit_violation = (snowpark_fn.abs(days1) > max_days) | (
+                        snowpark_fn.abs(days2) > max_days
+                    )
+
+                    # Check if result would exceed day limit (but operands are valid) - return NULL
+                    result_overflow = (
+                        # Check if result would exceed day limit (positive overflow)
+                        ((days1 > 0) & (days2 > 0) & (days1 > max_days - days2))
+                        | ((days1 < 0) & (days2 < 0) & (days1 < min_days - days2))
+                    )
+
+                    raise_error = _raise_error_helper(result_type, ArithmeticException)
+                    result_exp = (
+                        snowpark_fn.when(
+                            operand_limit_violation,
+                            raise_error(
+                                snowpark_fn.lit(
+                                    "Day-Time Interval operand exceeds Snowflake interval precision limit"
+                                )
+                            ),
+                        )
+                        .when(result_overflow, snowpark_fn.lit(None))
+                        .otherwise(snowpark_args[0] + snowpark_args[1])
+                    )
+                case _:
+                    result_exp = _try_arithmetic_helper(
+                        snowpark_typed_args, snowpark_args, 0
+                    )
+                    result_exp = _type_with_typer(result_exp)
         case "try_aes_decrypt":
             result_exp = _aes_helper(
                 "TRY_DECRYPT",
@@ -8002,17 +8556,49 @@ map_unresolved_function(
                 DoubleType(), cleaned, calculating_avg=True
             )
         case "try_divide":
-            # Check for interval types and throw NotImplementedError
-            for arg in snowpark_typed_args:
-                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
-                    exception = NotImplementedError(
-                        "try_divide with interval types is not supported"
-                    )
-                    attach_custom_error_code(
-                        exception, ErrorCodes.UNSUPPORTED_OPERATION
-                    )
-                    raise exception
+            # Handle interval division with overflow detection
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (t1, t2) if isinstance(t1, _AnsiIntervalType) and isinstance(
+                    t2, (_NumericType, StringType)
+                ):
+                    # Interval / numeric/string
+                    result_type = t1
+                    interval_arg = snowpark_args[0]
+                    divisor = (
+                        snowpark_args[1]
+                        if isinstance(t2, _NumericType)
+                        else snowpark_fn.cast(snowpark_args[1], "double")
+                    )
+
+                    # Check for division by zero first
+                    zero_check = divisor == 0
+
+                    if isinstance(result_type, YearMonthIntervalType):
+                        # For year-month intervals, check if result exceeds 32-bit signed integer limit
+                        result_type = YearMonthIntervalType()
+                        total_months = _calculate_total_months(interval_arg)
+                        max_months = snowpark_fn.lit(MAX_32BIT_SIGNED_INT)
+                        overflow_check = (
+                            snowpark_fn.abs(total_months / divisor) > max_months
+                        )
+                        result_exp = (
+                            snowpark_fn.when(zero_check, snowpark_fn.lit(None))
+                            .when(overflow_check, snowpark_fn.lit(None))
+                            .otherwise(interval_arg / divisor)
+                        )
+                    else:  # DayTimeIntervalType
+                        # For day-time intervals, check if result exceeds day limit
+                        result_type = DayTimeIntervalType()
+                        total_days = _calculate_total_days(interval_arg)
+                        max_days = snowpark_fn.lit(MAX_DAY_TIME_DAYS)
+                        overflow_check = (
+                            snowpark_fn.abs(total_days / divisor) > max_days
+                        )
+                        result_exp = (
+                            snowpark_fn.when(zero_check, snowpark_fn.lit(None))
+                            .when(overflow_check, snowpark_fn.lit(None))
+                            .otherwise(interval_arg / divisor)
+                        )
                 case (NullType(), t) | (t, NullType()):
                     result_exp = snowpark_fn.lit(None)
                     result_type = FloatType()
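
The try_* semantics being emulated: return NULL where plain arithmetic would raise. A quick Spark-side reference (3.5):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    # Division by zero (and, in the mapping above, interval results that
    # overflow Snowflake's precision) come back as NULL, not an error.
    spark.sql("SELECT try_divide(INTERVAL '2' YEAR, 0)").show()  # NULL
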
@@ -8124,17 +8710,76 @@ map_unresolved_function(
                     attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
                     raise exception
         case "try_multiply":
-            # Check for interval types and throw NotImplementedError
-            for arg in snowpark_typed_args:
-                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
-                    exception = NotImplementedError(
-                        "try_multiply with interval types is not supported"
-                    )
-                    attach_custom_error_code(
-                        exception, ErrorCodes.UNSUPPORTED_OPERATION
-                    )
-                    raise exception
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (t1, t2) if isinstance(t1, _AnsiIntervalType) and isinstance(
+                    t2, (_NumericType, StringType)
+                ):
+                    # Interval * numeric/string
+                    result_type = t1
+                    interval_arg = snowpark_args[0]
+                    multiplier = (
+                        snowpark_args[1]
+                        if isinstance(t2, _NumericType)
+                        else snowpark_fn.cast(snowpark_args[1], "double")
+                    )
+
+                    if isinstance(result_type, YearMonthIntervalType):
+                        # For year-month intervals, check if result exceeds 32-bit signed integer limit
+                        result_type = YearMonthIntervalType()
+                        total_months = _calculate_total_months(interval_arg)
+                        max_months = snowpark_fn.lit(MAX_32BIT_SIGNED_INT)
+                        overflow_check = (
+                            snowpark_fn.abs(total_months * multiplier) > max_months
+                        )
+                        result_exp = snowpark_fn.when(
+                            overflow_check, snowpark_fn.lit(None)
+                        ).otherwise(interval_arg * multiplier)
+                    else:  # DayTimeIntervalType
+                        # For day-time intervals, check if result exceeds day limit
+                        result_type = DayTimeIntervalType()
+                        total_days = _calculate_total_days(interval_arg)
+                        max_days = snowpark_fn.lit(MAX_DAY_TIME_DAYS)
+                        overflow_check = (
+                            snowpark_fn.abs(total_days * multiplier) > max_days
+                        )
+                        result_exp = snowpark_fn.when(
+                            overflow_check, snowpark_fn.lit(None)
+                        ).otherwise(interval_arg * multiplier)
+
+                case (t1, t2) if isinstance(t2, _AnsiIntervalType) and isinstance(
+                    t1, (_NumericType, StringType)
+                ):
+                    # numeric/string * Interval
+                    result_type = t2
+                    interval_arg = snowpark_args[1]
+                    multiplier = (
+                        snowpark_args[0]
+                        if isinstance(t1, _NumericType)
+                        else snowpark_fn.cast(snowpark_args[0], "double")
+                    )
+
+                    if isinstance(result_type, YearMonthIntervalType):
+                        # For year-month intervals, check if result exceeds 32-bit signed integer limit
+                        result_type = YearMonthIntervalType()
+                        total_months = _calculate_total_months(interval_arg)
+                        max_months = snowpark_fn.lit(MAX_32BIT_SIGNED_INT)
+                        overflow_check = (
+                            snowpark_fn.abs(total_months * multiplier) > max_months
+                        )
+                        result_exp = snowpark_fn.when(
+                            overflow_check, snowpark_fn.lit(None)
+                        ).otherwise(interval_arg * multiplier)
+                    else:  # DayTimeIntervalType
+                        # For day-time intervals, check if result exceeds day limit
+                        result_type = DayTimeIntervalType()
+                        total_days = _calculate_total_days(interval_arg)
+                        max_days = snowpark_fn.lit(MAX_DAY_TIME_DAYS)
+                        overflow_check = (
+                            snowpark_fn.abs(total_days * multiplier) > max_days
+                        )
+                        result_exp = snowpark_fn.when(
+                            overflow_check, snowpark_fn.lit(None)
+                        ).otherwise(interval_arg * multiplier)
                 case (NullType(), t) | (t, NullType()):
                     result_exp = snowpark_fn.lit(None)
                     match t:
@@ -8234,18 +8879,112 @@ def map_unresolved_function(
             snowpark_typed_args[0].typ, snowpark_args[0]
         )
         case "try_subtract":
-            # Check for interval types and throw NotImplementedError
-            for arg in snowpark_typed_args:
-                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
-                    exception = NotImplementedError(
-                        "try_subtract with interval types is not supported"
-                    )
-                    attach_custom_error_code(
-                        exception, ErrorCodes.UNSUPPORTED_OPERATION
+            # Handle interval arithmetic with overflow detection
+            match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (DateType(), t) if isinstance(t, YearMonthIntervalType):
+                    result_type = DateType()
+                    result_exp = snowpark_args[0] - snowpark_args[1]
+                case (DateType(), t) if isinstance(t, DayTimeIntervalType):
+                    result_type = TimestampType()
+                    result_exp = snowpark_args[0] - snowpark_args[1]
+                case (TimestampType(), t) if isinstance(
+                    t, (DayTimeIntervalType, YearMonthIntervalType)
+                ):
+                    result_type = snowpark_typed_args[0].typ
+                    result_exp = snowpark_args[0] - snowpark_args[1]
+                case (t1, t2) if (
+                    isinstance(t1, YearMonthIntervalType)
+                    and isinstance(t2, (_NumericType, StringType))
+                ) or (
+                    isinstance(t2, YearMonthIntervalType)
+                    and isinstance(t1, (_NumericType, StringType))
+                ):
+                    # YearMonthInterval - numeric/string or numeric/string - YearMonthInterval should throw error
+                    exception = AnalysisException(
+                        f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "try_subtract({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
                     )
+                    attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
                     raise exception
-            result_exp = _try_arithmetic_helper(snowpark_typed_args, snowpark_args, 1)
-            result_exp = _type_with_typer(result_exp)
+                case (t1, t2) if isinstance(t1, YearMonthIntervalType) and isinstance(
+                    t2, YearMonthIntervalType
+                ):
+                    result_type = YearMonthIntervalType(
+                        min(t1.start_field, t2.start_field),
+                        max(t1.end_field, t2.end_field),
+                    )
+                    # Check for Snowflake's precision limits: 10+ digits for operands, 9+ digits for results
+                    total1 = _calculate_total_months(snowpark_args[0])
+                    total2 = _calculate_total_months(snowpark_args[1])
+                    ten_digit_limit = snowpark_fn.lit(MAX_10_DIGIT_LIMIT)
+
+                    precision_violation = (
+                        # Check if either operand already reaches 10 digits (parsing limit)
+                        (snowpark_fn.abs(total1) >= ten_digit_limit)
+                        | (snowpark_fn.abs(total2) >= ten_digit_limit)
+                        | (
+                            (total1 > 0)
+                            & (total2 < 0)
+                            & (total1 >= ten_digit_limit + total2)
+                        )
+                        | (
+                            (total1 < 0)
+                            & (total2 > 0)
+                            & (total1 <= -ten_digit_limit + total2)
+                        )
+                    )
+
+                    raise_error = _raise_error_helper(result_type, ArithmeticException)
+                    result_exp = snowpark_fn.when(
+                        precision_violation,
+                        raise_error(
+                            snowpark_fn.lit(
+                                "Year-Month Interval result exceeds Snowflake interval precision limit"
+                            )
+                        ),
+                    ).otherwise(snowpark_args[0] - snowpark_args[1])
+                case (t1, t2) if isinstance(t1, DayTimeIntervalType) and isinstance(
+                    t2, DayTimeIntervalType
+                ):
+                    result_type = DayTimeIntervalType(
+                        min(t1.start_field, t2.start_field),
+                        max(t1.end_field, t2.end_field),
+                    )
+                    # Check for Snowflake's day limit (106751991 days is the cutoff)
+                    days1 = snowpark_fn.date_part("day", snowpark_args[0])
+                    days2 = snowpark_fn.date_part("day", snowpark_args[1])
+                    max_days = snowpark_fn.lit(
+                        MAX_DAY_TIME_DAYS
+                    )  # Snowflake's actual limit
+                    min_days = snowpark_fn.lit(-MAX_DAY_TIME_DAYS)
+
+                    # Check if either operand exceeds the day limit - throw error like Spark does
+                    operand_limit_violation = (snowpark_fn.abs(days1) > max_days) | (
+                        snowpark_fn.abs(days2) > max_days
+                    )
+
+                    # Check if result would exceed day limit (but operands are valid) - return NULL
+                    result_overflow = (
+                        (days1 > 0) & (days2 < 0) & (days1 > max_days + days2)
+                    ) | ((days1 < 0) & (days2 > 0) & (days1 < min_days + days2))
+
+                    raise_error = _raise_error_helper(result_type, ArithmeticException)
+                    result_exp = (
+                        snowpark_fn.when(
+                            operand_limit_violation,
+                            raise_error(
+                                snowpark_fn.lit(
+                                    "Day-Time Interval operand exceeds day limit"
+                                )
+                            ),
+                        )
+                        .when(result_overflow, snowpark_fn.lit(None))
+                        .otherwise(snowpark_args[0] - snowpark_args[1])
+                    )
+                case _:
+                    result_exp = _try_arithmetic_helper(
+                        snowpark_typed_args, snowpark_args, 1
+                    )
+                    result_exp = _type_with_typer(result_exp)
         case "try_to_number":
             try_to_number = snowpark_fn.function("try_to_number")
             precision, scale = resolve_to_number_precision_and_scale(exp)
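The day-time branch distinguishes two failure modes: an operand that is already past the day cap raises an ArithmeticException (matching Spark), while valid operands whose difference would cross the cap produce NULL. A plain-Python sketch of that decision, assuming MAX_DAY_TIME_DAYS is the 106751991-day cutoff the comment cites:

MAX_DAY_TIME_DAYS = 106751991  # assumed value of the constant referenced above

def try_subtract_day_time(days1, days2):
    # Tier 1: an out-of-range operand is an error, as in Spark.
    if abs(days1) > MAX_DAY_TIME_DAYS or abs(days2) > MAX_DAY_TIME_DAYS:
        raise ArithmeticError("Day-Time Interval operand exceeds day limit")
    # Tier 2: in-range operands whose difference overflows yield NULL (None).
    if (days1 > 0 and days2 < 0 and days1 > MAX_DAY_TIME_DAYS + days2) or (
        days1 < 0 and days2 > 0 and days1 < -MAX_DAY_TIME_DAYS + days2
    ):
        return None
    return days1 - days2

print(try_subtract_day_time(10, 3))                  # 7
print(try_subtract_day_time(MAX_DAY_TIME_DAYS, -1))  # None: result would overflow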
@@ -8735,15 +9474,21 @@ def map_unresolved_function(
             result_exp = snowpark_fn.year(snowpark_fn.to_date(snowpark_args[0]))
             result_type = LongType()
         case binary_method if binary_method in ("to_binary", "try_to_binary"):
-            binary_format = "hex"
+            binary_format = snowpark_fn.lit("hex")
+            arg_str = snowpark_fn.cast(snowpark_args[0], StringType())
             if len(snowpark_args) > 1:
                 binary_format = snowpark_args[1]
             result_exp = snowpark_fn.when(
                 snowpark_args[0].isNull(), snowpark_fn.lit(None)
             ).otherwise(
                 snowpark_fn.function(binary_method)(
-                    snowpark_fn.cast(snowpark_args[0], StringType()), binary_format
-                ),
+                    snowpark_fn.when(
+                        (snowpark_fn.length(arg_str) % 2 == 1)
+                        & (snowpark_fn.lower(binary_format) == snowpark_fn.lit("hex")),
+                        snowpark_fn.concat(snowpark_fn.lit("0"), arg_str),
+                    ).otherwise(arg_str),
+                    binary_format,
+                )
             )
             result_type = BinaryType()
         case udtf_name if udtf_name.lower() in session._udtfs:
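The new when/otherwise wrapper targets odd-length hex input: Spark's hex decoding tolerates an odd number of digits by assuming a leading zero nibble, so the string is left-padded with "0" before it reaches Snowflake's TO_BINARY/TRY_TO_BINARY. A standalone sketch of the padding rule, with plain strings standing in for Snowpark columns:

def normalize_hex(s, fmt="hex"):
    """Pad odd-length hex to whole bytes; leave other formats untouched."""
    if fmt.lower() == "hex" and len(s) % 2 == 1:
        return "0" + s
    return s

assert normalize_hex("abc") == "0abc"           # odd length: pad to whole bytes
assert normalize_hex("abcd") == "abcd"          # even length: unchanged
assert normalize_hex("abc", "base64") == "abc"  # only hex input is padded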
@@ -8828,7 +9573,7 @@ def map_unresolved_function(
         spark_col_names if len(spark_col_names) > 0 else [spark_function_name]
     )
     typed_col = _to_typed_column(result_exp, result_type, function_name)
-    typed_col.set_qualifiers(qualifiers)
+    typed_col.set_qualifiers({ColumnQualifier(tuple(qualifier_parts))})
     return spark_col_names, typed_col
 
 
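This one-line change tags the result column with a single ColumnQualifier built from the surrounding function's qualifier_parts instead of forwarding a prebuilt qualifiers set. The sketch below is hypothetical (the real class lives in the newly added column_qualifier.py, which this section does not show), but a frozen wrapper around a tuple of name parts would explain the tuple(...) call: the qualifier must be hashable to be placed in a set.

from dataclasses import dataclass

@dataclass(frozen=True)
class ColumnQualifier:  # hypothetical stand-in for the real class
    parts: tuple        # e.g. ("db", "schema", "table")

qualifier_parts = ["db", "schema", "table"]  # illustrative values only
qualifiers = {ColumnQualifier(tuple(qualifier_parts))}  # lists are unhashable; tuples are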
@@ -9025,6 +9770,20 @@ def _find_common_type(
             key_type = _common(type1.key_type, type2.key_type)
             value_type = _common(type1.value_type, type2.value_type)
             return MapType(key_type, value_type)
+        case (_, _) if isinstance(type1, YearMonthIntervalType) and isinstance(
+            type2, YearMonthIntervalType
+        ):
+            return YearMonthIntervalType(
+                min(type1.start_field, type2.start_field),
+                max(type1.end_field, type2.end_field),
+            )
+        case (_, _) if isinstance(type1, DayTimeIntervalType) and isinstance(
+            type2, DayTimeIntervalType
+        ):
+            return DayTimeIntervalType(
+                min(type1.start_field, type2.start_field),
+                max(type1.end_field, type2.end_field),
+            )
         case _:
             exception = AnalysisException(exception_base_message)
             attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
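The two new cases give a pair of interval types a common type by widening the field range: the result runs from the smaller start_field to the larger end_field. With PySpark's year-month field ordinals (YEAR = 0, MONTH = 1), a minimal sketch:

def common_interval_fields(fields1, fields2):
    """Widen two (start_field, end_field) pairs to their common span."""
    (s1, e1), (s2, e2) = fields1, fields2
    return (min(s1, s2), max(e1, e2))

# INTERVAL YEAR (0..0) combined with INTERVAL MONTH (1..1) widens to
# INTERVAL YEAR TO MONTH (0..1):
assert common_interval_fields((0, 0), (1, 1)) == (0, 1)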
@@ -9968,12 +10727,18 @@ def _try_sum_helper(
                     return snowpark_fn.lit(None), new_type
                 else:
                     non_null_rows = snowpark_fn.count(col_name)
-                    return aggregate_sum / non_null_rows, new_type
+                    # Use _divnull to handle case when non_null_rows is 0
+                    return _divnull(aggregate_sum, non_null_rows), new_type
             else:
                 new_type = DecimalType(
                     precision=min(38, arg_type.precision + 10), scale=arg_type.scale
                 )
-                return aggregate_sum, new_type
+                # Return NULL when there are no non-null values (i.e., all values
+                # are NULL); a case/when on the non-null count handles this for
+                # both SUM and the sum component of AVG.
+                non_null_rows = snowpark_fn.count(col_name)
+                result = snowpark_fn.when(
+                    non_null_rows == 0, snowpark_fn.lit(None)
+                ).otherwise(aggregate_sum)
+                return result, new_type
 
         case _:
             # If the input column is floating point (double and float are synonymous in Snowflake per
@@ -9991,9 +10756,16 @@ def _try_sum_helper(
                     return snowpark_fn.lit(None), DoubleType()
                 else:
                     non_null_rows = snowpark_fn.count(col_name)
-                    return aggregate_sum / non_null_rows, DoubleType()
+                    # Use _divnull to handle case when non_null_rows is 0
+                    return _divnull(aggregate_sum, non_null_rows), DoubleType()
             else:
-                return aggregate_sum, DoubleType()
+                # When all values are NULL, SUM should return NULL (not 0); use
+                # case/when to detect that there are no non-null values.
+                non_null_rows = snowpark_fn.count(col_name)
+                result = snowpark_fn.when(
+                    non_null_rows == 0, snowpark_fn.lit(None)
+                ).otherwise(aggregate_sum)
+                return result, DoubleType()
 
 
 def _get_type_precision(typ: DataType) -> tuple[int, int]:
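Both hunks enforce the same Spark semantics: SUM over only NULLs is NULL (not 0), and AVG divides by the count of non-null rows, yielding NULL when that count is zero (the job _divnull performs above). A plain-Python restatement:

def try_sum(values):
    non_null = [v for v in values if v is not None]
    return sum(non_null) if non_null else None

def try_avg(values):
    non_null = [v for v in values if v is not None]
    return sum(non_null) / len(non_null) if non_null else None

assert try_sum([None, None]) is None   # not 0
assert try_avg([None, None]) is None   # no division-by-zero error
assert try_avg([2, None, 4]) == 3.0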
@@ -10384,9 +11156,7 @@ def _get_add_sub_result_type(
     return result_type, overflow_possible
 
 
-def _get_interval_type_name(
-    interval_type: Union[YearMonthIntervalType, DayTimeIntervalType]
-) -> str:
+def _get_interval_type_name(interval_type: _AnsiIntervalType) -> str:
     """Get the formatted interval type name for error messages."""
     if isinstance(interval_type, YearMonthIntervalType):
         if interval_type.start_field == 0 and interval_type.end_field == 0:
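This hunk and the next collapse the repeated Union[YearMonthIntervalType, DayTimeIntervalType] annotations into a single _AnsiIntervalType name. Its definition is outside this section; one plausible reading (an assumption, not shown in the diff) is a name that covers both interval types for isinstance checks, for example:

from pyspark.sql.types import DayTimeIntervalType, YearMonthIntervalType

# Assumption: either PySpark's shared interval base class or, equivalently for
# isinstance checks, a tuple of the two concrete types.
_AnsiIntervalType = (YearMonthIntervalType, DayTimeIntervalType)

assert isinstance(YearMonthIntervalType(), _AnsiIntervalType)
assert isinstance(DayTimeIntervalType(), _AnsiIntervalType)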
@@ -10413,21 +11183,15 @@ def _check_interval_string_comparison(
 ) -> None:
     """Check for invalid interval-string comparisons and raise AnalysisException if found."""
     if (
-        isinstance(
-            snowpark_typed_args[0].typ, (YearMonthIntervalType, DayTimeIntervalType)
-        )
+        isinstance(snowpark_typed_args[0].typ, _AnsiIntervalType)
         and isinstance(snowpark_typed_args[1].typ, StringType)
         or isinstance(snowpark_typed_args[0].typ, StringType)
-        and isinstance(
-            snowpark_typed_args[1].typ, (YearMonthIntervalType, DayTimeIntervalType)
-        )
+        and isinstance(snowpark_typed_args[1].typ, _AnsiIntervalType)
     ):
         # Format interval type name for error message
         interval_type = (
             snowpark_typed_args[0].typ
-            if isinstance(
-                snowpark_typed_args[0].typ, (YearMonthIntervalType, DayTimeIntervalType)
-            )
+            if isinstance(snowpark_typed_args[0].typ, _AnsiIntervalType)
             else snowpark_typed_args[1].typ
         )
         interval_name = _get_interval_type_name(interval_type)
@@ -10494,12 +11258,18 @@ def _get_spark_function_name(
         case (DateType(), DayTimeIntervalType()) | (
             DateType(),
             YearMonthIntervalType(),
+        ) | (TimestampType(), DayTimeIntervalType()) | (
+            TimestampType(),
+            YearMonthIntervalType(),
         ):
             date_param_name1 = _get_literal_param_name(exp, 0, snowpark_arg_names[0])
             return f"{date_param_name1} {operation_op} {snowpark_arg_names[1]}"
         case (DayTimeIntervalType(), DateType()) | (
             YearMonthIntervalType(),
             DateType(),
+        ) | (DayTimeIntervalType(), TimestampType()) | (
+            YearMonthIntervalType(),
+            TimestampType(),
         ):
             date_param_name2 = _get_literal_param_name(exp, 1, snowpark_arg_names[1])
             if function_name == "+":
@@ -10887,3 +11657,30 @@ def _map_from_spark_tz(value: Column) -> Column:
         .when(value == "VST", snowpark_fn.lit("Asia/Ho_Chi_Minh"))
         .otherwise(value)  # Return original timezone if no mapping found
     )
+
+
+def _calculate_total_months(interval_arg):
+    """Calculate total months from a year-month interval."""
+    years = snowpark_fn.date_part("year", interval_arg)
+    months = snowpark_fn.date_part("month", interval_arg)
+    return years * 12 + months
+
+
+def _calculate_total_days(interval_arg):
+    """Calculate total days from a day-time interval."""
+    days = snowpark_fn.date_part("day", interval_arg)
+    hours = snowpark_fn.date_part("hour", interval_arg)
+    minutes = snowpark_fn.date_part("minute", interval_arg)
+    seconds = snowpark_fn.date_part("second", interval_arg)
+    # Convert hours, minutes, seconds to fractional days
+    fractional_days = (hours * 3600 + minutes * 60 + seconds) / 86400
+    return days + fractional_days
+
+
+def _calculate_total_seconds(interval_arg):
+    """Calculate total seconds from a day-time interval."""
+    days = snowpark_fn.date_part("day", interval_arg)
+    hours = snowpark_fn.date_part("hour", interval_arg)
+    minutes = snowpark_fn.date_part("minute", interval_arg)
+    seconds = snowpark_fn.date_part("second", interval_arg)
+    return days * 86400 + hours * 3600 + minutes * 60 + seconds
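These three helpers reduce an interval to a single scalar so the overflow checks earlier in the file can compare against one limit. A worked check of the arithmetic, with plain ints standing in for Snowpark columns:

def total_months(years, months):
    return years * 12 + months

def total_days(days, hours, minutes, seconds):
    return days + (hours * 3600 + minutes * 60 + seconds) / 86400

def total_seconds(days, hours, minutes, seconds):
    return days * 86400 + hours * 3600 + minutes * 60 + seconds

assert total_months(2, 3) == 27             # 2 years 3 months
assert total_days(1, 6, 0, 0) == 1.25       # 1 day 6 hours
assert total_seconds(1, 6, 0, 0) == 108000  # same interval in seconds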