snowpark-connect 0.32.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of snowpark-connect might be problematic.

Files changed (106)
  1. snowflake/snowpark_connect/column_name_handler.py +91 -40
  2. snowflake/snowpark_connect/column_qualifier.py +0 -4
  3. snowflake/snowpark_connect/config.py +9 -0
  4. snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
  5. snowflake/snowpark_connect/expression/literal.py +12 -12
  6. snowflake/snowpark_connect/expression/map_sql_expression.py +18 -4
  7. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +150 -29
  8. snowflake/snowpark_connect/expression/map_unresolved_function.py +93 -55
  9. snowflake/snowpark_connect/relation/map_aggregate.py +156 -257
  10. snowflake/snowpark_connect/relation/map_column_ops.py +19 -0
  11. snowflake/snowpark_connect/relation/map_join.py +454 -252
  12. snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
  13. snowflake/snowpark_connect/relation/map_sql.py +335 -90
  14. snowflake/snowpark_connect/relation/read/map_read.py +9 -1
  15. snowflake/snowpark_connect/relation/read/map_read_csv.py +19 -2
  16. snowflake/snowpark_connect/relation/read/map_read_json.py +90 -2
  17. snowflake/snowpark_connect/relation/read/map_read_parquet.py +3 -0
  18. snowflake/snowpark_connect/relation/read/map_read_text.py +4 -0
  19. snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
  20. snowflake/snowpark_connect/relation/read/utils.py +41 -0
  21. snowflake/snowpark_connect/relation/utils.py +50 -2
  22. snowflake/snowpark_connect/relation/write/map_write.py +251 -292
  23. snowflake/snowpark_connect/resources_initializer.py +25 -13
  24. snowflake/snowpark_connect/server.py +9 -24
  25. snowflake/snowpark_connect/type_mapping.py +2 -0
  26. snowflake/snowpark_connect/typed_column.py +2 -2
  27. snowflake/snowpark_connect/utils/context.py +0 -14
  28. snowflake/snowpark_connect/utils/expression_transformer.py +163 -0
  29. snowflake/snowpark_connect/utils/sequence.py +21 -0
  30. snowflake/snowpark_connect/utils/session.py +4 -1
  31. snowflake/snowpark_connect/utils/udf_helper.py +1 -0
  32. snowflake/snowpark_connect/utils/udtf_helper.py +3 -0
  33. snowflake/snowpark_connect/version.py +1 -1
  34. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/METADATA +4 -2
  35. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/RECORD +43 -104
  36. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  37. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  39. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  46. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  47. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  48. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  52. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  53. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  54. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  55. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  56. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  57. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  58. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  59. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  60. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  61. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  62. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  63. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  64. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  65. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  66. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  67. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  68. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  69. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  70. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  71. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  72. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  73. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  74. snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
  75. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  76. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  77. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  78. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  79. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  80. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  81. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
  82. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  83. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  84. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  85. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  86. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  87. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  88. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  89. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  90. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  91. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  92. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  93. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  94. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  95. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  96. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  97. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  98. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  99. {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-connect +0 -0
  100. {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-session +0 -0
  101. {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-submit +0 -0
  102. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/WHEEL +0 -0
  103. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE-binary +0 -0
  104. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE.txt +0 -0
  105. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/NOTICE-binary +0 -0
  106. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/top_level.txt +0 -0
@@ -24,7 +24,6 @@ from typing import List, Optional
 from urllib.parse import quote, unquote

 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
-import pyspark.sql.functions as pyspark_functions
 from google.protobuf.message import Message
 from pyspark.errors.exceptions.base import (
     AnalysisException,
@@ -101,6 +100,7 @@ from snowflake.snowpark_connect.expression.map_unresolved_star import (
 )
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.catalogs.utils import CURRENT_CATALOG_NAME
+from snowflake.snowpark_connect.relation.utils import is_aggregate_function
 from snowflake.snowpark_connect.type_mapping import (
     map_json_schema_to_snowpark,
     map_pyspark_types_to_snowpark_types,
@@ -119,7 +119,6 @@ from snowflake.snowpark_connect.utils.context import (
     get_is_evaluating_sql,
     get_is_in_udtf_context,
     get_spark_version,
-    is_in_pivot,
     is_window_enabled,
     push_udtf_context,
     resolving_fun_args,
@@ -400,9 +399,8 @@ def map_unresolved_function(
     result_type: Optional[DataType | List[DateType]] = None
     qualifier_parts: List[str] = []

-    pyspark_func = getattr(pyspark_functions, function_name, None)
-    if pyspark_func and pyspark_func.__doc__.lstrip().startswith("Aggregate function:"):
-        # Used by the GROUP BY ALL implementation. Far from ideal, but it seems to work...
+    # Check if this is an aggregate function (used by GROUP BY ALL implementation)
+    if is_aggregate_function(function_name):
        add_sql_aggregate_function()

    def _type_with_typer(col: Column) -> TypedColumn:
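
The docstring-sniffing heuristic is replaced by a dedicated is_aggregate_function helper imported from relation/utils.py. Its body is not shown in this diff; a minimal sketch of what a name-set lookup could look like (the set contents below are illustrative, not the actual list):

    # Hypothetical sketch; the real helper lives in
    # snowflake/snowpark_connect/relation/utils.py and may differ.
    _AGGREGATE_FUNCTIONS = frozenset({
        "avg", "count", "max", "min", "sum",
        "stddev", "variance", "first", "last",  # illustrative subset
    })

    def is_aggregate_function(function_name: str) -> bool:
        # Case-insensitive membership test against a known set of aggregates.
        return function_name.lower() in _AGGREGATE_FUNCTIONS

Unlike the old approach, a lookup like this cannot break on functions without docstrings (the removed code would raise AttributeError on pyspark_func.__doc__.lstrip() whenever __doc__ was None).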
@@ -412,10 +410,7 @@ def map_unresolved_function(
     def _resolve_aggregate_exp(
         result_exp: Column, default_result_type: DataType
     ) -> TypedColumn:
-        if is_in_pivot():
-            # it's not possible to cast the result in pivot
-            return _type_with_typer(result_exp)
-        elif is_window_enabled():
+        if is_window_enabled():
             # defer casting to capture whole window expression
             return TypedColumnWithDeferredCast(
                 result_exp, lambda: [default_result_type]
@@ -518,7 +513,6 @@ def map_unresolved_function(
         STRUCTURED_INCOMPATIBLE_AGGREGATES = {"min", "max"}
         if (
             aggregate_func.__name__ in STRUCTURED_INCOMPATIBLE_AGGREGATES
-            and not is_in_pivot()
             and not is_window_enabled()
             and isinstance(typed_arg.typ, (ArrayType, MapType, StructType))
         ):
@@ -526,12 +520,7 @@ def map_unresolved_function(
             variant_arg = snowpark_fn.to_variant(typed_arg.col)
             result = aggregate_func(variant_arg)

-            if is_in_pivot():
-                # In PIVOT context, casting is not allowed, so return VARIANT type
-                return _type_with_typer(result)
-            else:
-                # Cast back to the original type to maintain the expected schema
-                return TypedColumn(result.cast(typed_arg.typ), lambda: expected_types)
+            return TypedColumn(result.cast(typed_arg.typ), lambda: expected_types)
         else:
             # No structured type conversion needed
             result = aggregate_func(typed_arg.col)
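
With the PIVOT branch removed, every structured MIN/MAX outside a window now takes a single path: route the argument through VARIANT, aggregate, and cast the result back to its original type. That round trip in isolation, as a hedged Snowpark sketch (column name and element type are illustrative):

    from snowflake.snowpark import functions as F
    from snowflake.snowpark.types import ArrayType, StringType

    # MIN/MAX do not accept structured types directly, so the column goes
    # through VARIANT and the aggregate result is cast back afterwards.
    original_type = ArrayType(StringType())  # illustrative
    expr = F.max(F.to_variant(F.col("arr"))).cast(original_type)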
@@ -912,15 +901,28 @@ def map_unresolved_function(
         ):
             # String + YearMonthInterval: Spark tries to cast the string to double first and throws an error if that fails
             result_type = StringType()
+            raise_error = _raise_error_helper(StringType(), AnalysisException)
             if isinstance(snowpark_typed_args[0].typ, StringType):
-                result_exp = (
-                    snowpark_fn.cast(snowpark_args[0], "double")
-                    + snowpark_args[1]
-                )
+                # Try to cast the string to double; if that fails (returns NULL), raise an exception
+                cast_result = snowpark_fn.try_cast(snowpark_args[0], "double")
+                result_exp = snowpark_fn.when(
+                    cast_result.is_null(),
+                    raise_error(
+                        snowpark_fn.lit(
+                            f'The value \'{snowpark_args[0]}\' of the type {snowpark_typed_args[0].typ} cannot be cast to "DOUBLE" because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.'
+                        )
+                    ),
+                ).otherwise(cast_result + snowpark_args[1])
             else:
-                result_exp = snowpark_args[0] + snowpark_fn.cast(
-                    snowpark_args[1], "double"
-                )
+                cast_result = snowpark_fn.try_cast(snowpark_args[1], "double")
+                result_exp = snowpark_fn.when(
+                    cast_result.is_null(),
+                    raise_error(
+                        snowpark_fn.lit(
+                            f'The value \'{snowpark_args[0]}\' of the type {snowpark_typed_args[0].typ} cannot be cast to "DOUBLE" because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.'
+                        )
+                    ),
+                ).otherwise(snowpark_args[0] + cast_result)
         case (StringType(), t) | (t, StringType()) if isinstance(
             t, DayTimeIntervalType
         ):
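
The when/otherwise above reproduces Spark's ANSI behavior: a malformed string makes try_cast return NULL, and that NULL is converted into the CAST_INVALID_INPUT error instead of being silently propagated. A hedged PySpark sketch of the same cast-or-raise shape (column name and message are illustrative, and the sketch adds a NULL-input guard the diff does not have):

    from pyspark.sql import functions as F

    s = F.col("s")
    tried = F.expr("try_cast(s AS double)")  # NULL on malformed input
    cast_or_raise = F.when(
        tried.isNull() & s.isNotNull(),
        F.raise_error(F.lit("CAST_INVALID_INPUT: 's' is malformed")),
    ).otherwise(tried)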
@@ -1016,9 +1018,12 @@ def map_unresolved_function(
             result_exp = snowpark_args[0] - snowpark_args[1]
             result_exp = result_exp.cast(result_type)
         case (DateType(), DateType()):
-            # TODO SNOW-2034420: resolve return type (it should be INTERVAL DAY)
-            result_type = LongType()
-            result_exp = snowpark_args[0] - snowpark_args[1]
+            result_type = DayTimeIntervalType(
+                DayTimeIntervalType.DAY, DayTimeIntervalType.DAY
+            )
+            result_exp = snowpark_fn.interval_day_time_from_parts(
+                snowpark_args[0] - snowpark_args[1]
+            )
         case (DateType(), DayTimeIntervalType()) | (
             DateType(),
             YearMonthIntervalType(),
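
This closes TODO SNOW-2034420: DATE minus DATE now yields an INTERVAL DAY rather than a LONG, matching Spark 3.x. For reference, the behavior being matched (illustrative, assuming an active SparkSession named spark):

    import datetime

    # Spark 3.5: DATE - DATE has type DayTimeIntervalType(DAY) and arrives
    # in Python as a datetime.timedelta.
    row = spark.sql("SELECT DATE'2024-03-10' - DATE'2024-03-01' AS diff").first()
    assert row["diff"] == datetime.timedelta(days=9)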
@@ -1037,14 +1042,22 @@ def map_unresolved_function(
                 result_type = TimestampType()
                 result_exp = snowpark_args[0] - snowpark_args[1]
             else:
-                # TODO SNOW-2034420: resolve return type (it should be INTERVAL DAY)
-                result_type = LongType()
                 input_type = (
                     DateType() if spark_sql_ansi_enabled else DoubleType()
                 )
-                result_exp = snowpark_args[0] - snowpark_args[1].cast(
-                    input_type
-                )
+                if isinstance(input_type, DateType):
+                    result_type = DayTimeIntervalType(
+                        DayTimeIntervalType.DAY, DayTimeIntervalType.DAY
+                    )
+                    result_exp = snowpark_fn.interval_day_time_from_parts(
+                        snowpark_args[0] - snowpark_args[1].cast(input_type)
+                    )
+                else:
+                    # If ANSI is disabled, cast to DoubleType and return long (legacy behavior)
+                    result_type = LongType()
+                    result_exp = snowpark_args[0] - snowpark_args[1].cast(
+                        input_type
+                    )
         case (TimestampType(), DayTimeIntervalType()) | (
             TimestampType(),
             YearMonthIntervalType(),
@@ -1063,10 +1076,12 @@ def map_unresolved_function(
                 f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: Parameter 2 requires the ("INTERVAL") type for timestamp operations, however "{snowpark_arg_names[1]}" has the type "{snowpark_typed_args[1].typ}".',
             )
         case (StringType(), DateType()):
-            # TODO SNOW-2034420: resolve return type (it should be INTERVAL DAY)
-            result_type = LongType()
-            input_type = DateType()
-            result_exp = snowpark_args[0].cast(input_type) - snowpark_args[1]
+            result_type = DayTimeIntervalType(
+                DayTimeIntervalType.DAY, DayTimeIntervalType.DAY
+            )
+            result_exp = snowpark_fn.interval_day_time_from_parts(
+                snowpark_args[0].cast(DateType()) - snowpark_args[1]
+            )
         case (DateType(), (IntegerType() | ShortType() | ByteType())):
             result_type = DateType()
             result_exp = snowpark_args[0] - snowpark_args[1]
@@ -6184,15 +6199,19 @@ def map_unresolved_function(
                 or isinstance(arg_type, DayTimeIntervalType)
                 else DoubleType()
             )
-
-        case "pow":
-            result_exp = snowpark_fn.pow(snowpark_args[0], snowpark_args[1])
-            result_type = DoubleType()
-        case "power":
-            spark_function_name = (
-                f"POWER({snowpark_arg_names[0]}, {snowpark_arg_names[1]})"
-            )
-            result_exp = snowpark_fn.pow(snowpark_args[0], snowpark_args[1])
+        case "pow" | "power":
+            spark_function_name = f"{function_name if function_name == 'pow' else function_name.upper()}({snowpark_arg_names[0]}, {snowpark_arg_names[1]})"
+            if not spark_sql_ansi_enabled:
+                snowpark_args = _validate_numeric_args(
+                    function_name, snowpark_typed_args, snowpark_args
+                )
+            result_exp = snowpark_fn.when(
+                snowpark_fn.equal_nan(snowpark_fn.cast(snowpark_args[0], FloatType()))
+                | snowpark_fn.equal_nan(
+                    snowpark_fn.cast(snowpark_args[1], FloatType())
+                ),
+                NAN,
+            ).otherwise(snowpark_fn.pow(snowpark_args[0], snowpark_args[1]))
             result_type = DoubleType()
         case "product":
             col = snowpark_args[0]
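
The merged pow/power branch also adds NaN propagation on top of POWER. The guard, expressed in plain PySpark terms (a sketch; column names are illustrative):

    from pyspark.sql import functions as F

    x, y = F.col("x"), F.col("y")
    # If either operand is NaN, return NaN; otherwise defer to pow().
    pow_expr = F.when(
        F.isnan(x.cast("float")) | F.isnan(y.cast("float")),
        F.lit(float("nan")),
    ).otherwise(F.pow(x, y))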
@@ -6939,10 +6958,10 @@ def map_unresolved_function(
             )
             result_type = ArrayType(ArrayType(StringType()))
         case "sequence":
-            if snowpark_typed_args[0].typ != snowpark_typed_args[1].typ or (
-                not isinstance(snowpark_typed_args[0].typ, _IntegralType)
-                or not isinstance(snowpark_typed_args[1].typ, _IntegralType)
-            ):
+            both_integral = isinstance(
+                snowpark_typed_args[0].typ, _IntegralType
+            ) and isinstance(snowpark_typed_args[1].typ, _IntegralType)
+            if not both_integral:
                 exception = AnalysisException(
                     f"""[DATATYPE_MISMATCH.SEQUENCE_WRONG_INPUT_TYPES] Cannot resolve "sequence({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: `sequence` uses the wrong parameter type. The parameter type must conform to:
 1. The start and stop expressions must resolve to the same type.
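
The rewritten predicate drops the old same-type requirement, so mixed integral widths (say, INT and BIGINT endpoints) now pass, while non-integral endpoints are still rejected with SEQUENCE_WRONG_INPUT_TYPES. A minimal sketch of the check, assuming Snowpark's private _IntegralType base class:

    from snowflake.snowpark.types import (
        IntegerType,
        LongType,
        StringType,
        _IntegralType,
    )

    def both_integral(t0, t1) -> bool:
        # True only when both endpoint types are integral (any widths).
        return isinstance(t0, _IntegralType) and isinstance(t1, _IntegralType)

    assert both_integral(IntegerType(), LongType())  # mixed widths allowed
    assert not both_integral(IntegerType(), StringType())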
@@ -9458,15 +9477,21 @@ def map_unresolved_function(
             result_exp = snowpark_fn.year(snowpark_fn.to_date(snowpark_args[0]))
             result_type = LongType()
         case binary_method if binary_method in ("to_binary", "try_to_binary"):
-            binary_format = "hex"
+            binary_format = snowpark_fn.lit("hex")
+            arg_str = snowpark_fn.cast(snowpark_args[0], StringType())
             if len(snowpark_args) > 1:
                 binary_format = snowpark_args[1]
             result_exp = snowpark_fn.when(
                 snowpark_args[0].isNull(), snowpark_fn.lit(None)
             ).otherwise(
                 snowpark_fn.function(binary_method)(
-                    snowpark_fn.cast(snowpark_args[0], StringType()), binary_format
-                ),
+                    snowpark_fn.when(
+                        (snowpark_fn.length(arg_str) % 2 == 1)
+                        & (snowpark_fn.lower(binary_format) == snowpark_fn.lit("hex")),
+                        snowpark_fn.concat(snowpark_fn.lit("0"), arg_str),
+                    ).otherwise(arg_str),
+                    binary_format,
+                )
             )
             result_type = BinaryType()
         case udtf_name if udtf_name.lower() in session._udtfs:
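
The inner when() exists because Snowflake's TO_BINARY(..., 'hex') rejects hex strings with an odd number of digits; left-padding a single '0' makes the length even without changing the value. The padding rule in isolation:

    def pad_hex(s: str) -> str:
        # TO_BINARY(..., 'hex') needs an even number of hex digits;
        # 'f' and '0f' encode the same byte value.
        return "0" + s if len(s) % 2 == 1 else s

    assert pad_hex("f") == "0f"
    assert pad_hex("0f") == "0f"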
@@ -10705,12 +10730,18 @@ def _try_sum_helper(
                     return snowpark_fn.lit(None), new_type
                 else:
                     non_null_rows = snowpark_fn.count(col_name)
-                    return aggregate_sum / non_null_rows, new_type
+                    # Use _divnull to handle the case when non_null_rows is 0
+                    return _divnull(aggregate_sum, non_null_rows), new_type
             else:
                 new_type = DecimalType(
                     precision=min(38, arg_type.precision + 10), scale=arg_type.scale
                 )
-                return aggregate_sum, new_type
+                # Return NULL when there are no non-null values (all values are NULL); the case/when guard covers both SUM and the sum component of AVG
+                non_null_rows = snowpark_fn.count(col_name)
+                result = snowpark_fn.when(
+                    non_null_rows == 0, snowpark_fn.lit(None)
+                ).otherwise(aggregate_sum)
+                return result, new_type

         case _:
             # If the input column is floating point (double and float are synonymous in Snowflake per
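
_divnull itself is not part of this diff; presumably it returns NULL instead of dividing when the row count is zero, which keeps AVG over an all-NULL column NULL. A hypothetical Snowpark sketch of such a helper:

    from snowflake.snowpark import Column
    from snowflake.snowpark import functions as F

    def _divnull(numerator: Column, denominator: Column) -> Column:
        # Hypothetical: NULL when the denominator is 0, plain division otherwise.
        return F.when(denominator == 0, F.lit(None)).otherwise(
            numerator / denominator
        )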
@@ -10728,9 +10759,16 @@ def _try_sum_helper(
                     return snowpark_fn.lit(None), DoubleType()
                 else:
                     non_null_rows = snowpark_fn.count(col_name)
-                    return aggregate_sum / non_null_rows, DoubleType()
+                    # Use _divnull to handle the case when non_null_rows is 0
+                    return _divnull(aggregate_sum, non_null_rows), DoubleType()
             else:
-                return aggregate_sum, DoubleType()
+                # When all values are NULL, SUM should return NULL (not 0).
+                # Use case/when to return NULL when there are no non-null values.
+                non_null_rows = snowpark_fn.count(col_name)
+                result = snowpark_fn.when(
+                    non_null_rows == 0, snowpark_fn.lit(None)
+                ).otherwise(aggregate_sum)
+                return result, DoubleType()


 def _get_type_precision(typ: DataType) -> tuple[int, int]:
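
Both branches encode the same Spark rule: SUM ignores NULLs, but a group with no non-null values yields NULL, never 0 (COUNT of non-null rows being zero is the trigger). The semantics in plain Python:

    def sum_or_null(values):
        # Spark semantics: SUM skips NULLs, but an all-NULL (or empty)
        # input yields NULL rather than 0.
        non_null = [v for v in values if v is not None]
        return sum(non_null) if non_null else None

    assert sum_or_null([1, None, 2]) == 3
    assert sum_or_null([None, None]) is None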