snowpark-connect 0.20.2__py3-none-any.whl → 0.22.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of snowpark-connect might be problematic.

Files changed (84)
  1. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +3 -2
  2. snowflake/snowpark_connect/column_name_handler.py +6 -65
  3. snowflake/snowpark_connect/config.py +47 -17
  4. snowflake/snowpark_connect/dataframe_container.py +242 -0
  5. snowflake/snowpark_connect/error/error_utils.py +25 -0
  6. snowflake/snowpark_connect/execute_plan/map_execution_command.py +13 -23
  7. snowflake/snowpark_connect/execute_plan/map_execution_root.py +9 -5
  8. snowflake/snowpark_connect/expression/map_extension.py +2 -1
  9. snowflake/snowpark_connect/expression/map_udf.py +4 -4
  10. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +8 -7
  11. snowflake/snowpark_connect/expression/map_unresolved_function.py +481 -170
  12. snowflake/snowpark_connect/expression/map_unresolved_star.py +8 -8
  13. snowflake/snowpark_connect/expression/map_update_fields.py +1 -1
  14. snowflake/snowpark_connect/expression/typer.py +6 -6
  15. snowflake/snowpark_connect/proto/control_pb2.py +17 -16
  16. snowflake/snowpark_connect/proto/control_pb2.pyi +17 -17
  17. snowflake/snowpark_connect/proto/control_pb2_grpc.py +12 -63
  18. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +15 -14
  19. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +19 -14
  20. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
  21. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +27 -26
  22. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +74 -68
  23. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
  24. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +5 -5
  25. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +25 -17
  26. snowflake/snowpark_connect/relation/map_aggregate.py +170 -61
  27. snowflake/snowpark_connect/relation/map_catalog.py +2 -2
  28. snowflake/snowpark_connect/relation/map_column_ops.py +227 -145
  29. snowflake/snowpark_connect/relation/map_crosstab.py +25 -6
  30. snowflake/snowpark_connect/relation/map_extension.py +81 -56
  31. snowflake/snowpark_connect/relation/map_join.py +72 -63
  32. snowflake/snowpark_connect/relation/map_local_relation.py +35 -20
  33. snowflake/snowpark_connect/relation/map_map_partitions.py +24 -17
  34. snowflake/snowpark_connect/relation/map_relation.py +22 -16
  35. snowflake/snowpark_connect/relation/map_row_ops.py +232 -146
  36. snowflake/snowpark_connect/relation/map_sample_by.py +15 -8
  37. snowflake/snowpark_connect/relation/map_show_string.py +42 -5
  38. snowflake/snowpark_connect/relation/map_sql.py +141 -237
  39. snowflake/snowpark_connect/relation/map_stats.py +88 -39
  40. snowflake/snowpark_connect/relation/map_subquery_alias.py +13 -14
  41. snowflake/snowpark_connect/relation/map_udtf.py +10 -13
  42. snowflake/snowpark_connect/relation/read/map_read.py +8 -3
  43. snowflake/snowpark_connect/relation/read/map_read_csv.py +7 -7
  44. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +7 -7
  45. snowflake/snowpark_connect/relation/read/map_read_json.py +19 -8
  46. snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -7
  47. snowflake/snowpark_connect/relation/read/map_read_socket.py +7 -3
  48. snowflake/snowpark_connect/relation/read/map_read_table.py +25 -16
  49. snowflake/snowpark_connect/relation/read/map_read_text.py +7 -7
  50. snowflake/snowpark_connect/relation/read/reader_config.py +1 -0
  51. snowflake/snowpark_connect/relation/utils.py +11 -5
  52. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +15 -12
  53. snowflake/snowpark_connect/relation/write/map_write.py +259 -56
  54. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +3 -2
  55. snowflake/snowpark_connect/server.py +43 -4
  56. snowflake/snowpark_connect/type_mapping.py +6 -23
  57. snowflake/snowpark_connect/utils/cache.py +27 -22
  58. snowflake/snowpark_connect/utils/context.py +33 -17
  59. snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
  60. snowflake/snowpark_connect/utils/{attribute_handling.py → identifiers.py} +47 -0
  61. snowflake/snowpark_connect/utils/session.py +41 -38
  62. snowflake/snowpark_connect/utils/telemetry.py +214 -63
  63. snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
  64. snowflake/snowpark_connect/version.py +1 -1
  65. snowflake/snowpark_decoder/__init__.py +0 -0
  66. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
  67. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
  68. snowflake/snowpark_decoder/dp_session.py +111 -0
  69. snowflake/snowpark_decoder/spark_decoder.py +76 -0
  70. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/METADATA +6 -4
  71. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/RECORD +83 -69
  72. snowpark_connect-0.22.1.dist-info/licenses/LICENSE-binary +568 -0
  73. snowpark_connect-0.22.1.dist-info/licenses/NOTICE-binary +1533 -0
  74. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/top_level.txt +1 -0
  75. spark/__init__.py +0 -0
  76. spark/connect/__init__.py +0 -0
  77. spark/connect/envelope_pb2.py +31 -0
  78. spark/connect/envelope_pb2.pyi +46 -0
  79. snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
  80. {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-connect +0 -0
  81. {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-session +0 -0
  82. {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-submit +0 -0
  83. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/WHEEL +0 -0
  84. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
@@ -28,6 +28,7 @@ from google.protobuf.message import Message
28
28
  from pyspark.errors.exceptions.base import (
29
29
  AnalysisException,
30
30
  ArithmeticException,
31
+ ArrayIndexOutOfBoundsException,
31
32
  DateTimeException,
32
33
  IllegalArgumentException,
33
34
  NumberFormatException,
@@ -39,6 +40,7 @@ from pyspark.sql.types import _parse_datatype_json_string
39
40
  import snowflake.snowpark.functions as snowpark_fn
40
41
  from snowflake import snowpark
41
42
  from snowflake.snowpark import Column, Session
43
+ from snowflake.snowpark._internal.analyzer.expression import Literal
42
44
  from snowflake.snowpark._internal.analyzer.unary_expression import Alias
43
45
  from snowflake.snowpark.types import (
44
46
  ArrayType,
@@ -69,7 +71,10 @@ from snowflake.snowpark_connect.column_name_handler import (
69
71
  ColumnNameMap,
70
72
  set_schema_getter,
71
73
  )
72
- from snowflake.snowpark_connect.config import global_config
74
+ from snowflake.snowpark_connect.config import (
75
+ get_boolean_session_config_param,
76
+ global_config,
77
+ )
73
78
  from snowflake.snowpark_connect.constants import (
74
79
  DUPLICATE_KEY_FOUND_ERROR_TEMPLATE,
75
80
  SPARK_TZ_ABBREVIATIONS_OVERRIDES,
@@ -100,6 +105,7 @@ from snowflake.snowpark_connect.typed_column import (
100
105
  )
101
106
  from snowflake.snowpark_connect.utils.context import (
102
107
  add_sql_aggregate_function,
108
+ get_current_grouping_columns,
103
109
  get_is_aggregate_function,
104
110
  get_is_evaluating_sql,
105
111
  get_is_in_udtf_context,
@@ -135,7 +141,7 @@ from snowflake.snowpark_connect.utils.xxhash64 import (
135
141
  MAX_UINT64 = 2**64 - 1
136
142
  MAX_INT64 = 2**63 - 1
137
143
  MIN_INT64 = -(2**63)
138
-
144
+ MAX_ARRAY_SIZE = 2_147_483_647
139
145
 
140
146
  NAN, INFINITY = float("nan"), float("inf")
141
147
 
@@ -341,6 +347,9 @@ def map_unresolved_function(
341
347
  )
342
348
  spark_col_names = []
343
349
  spark_sql_ansi_enabled = global_config.spark_sql_ansi_enabled
350
+ spark_sql_legacy_allow_hash_on_map_type = (
351
+ global_config.spark_sql_legacy_allowHashOnMapType
352
+ )
344
353
 
345
354
  function_name = exp.unresolved_function.function_name.lower()
346
355
  telemetry.report_function_usage(function_name)
@@ -631,37 +640,22 @@ def map_unresolved_function(
631
640
  [arg.typ for arg in snowpark_typed_args]
632
641
  )
633
642
  case "/":
634
- if isinstance(
635
- snowpark_typed_args[0].typ, (IntegerType, LongType, ShortType)
636
- ) and isinstance(
637
- snowpark_typed_args[1].typ, (IntegerType, LongType, ShortType)
638
- ):
639
- # Check if both arguments are integer types. Snowpark performs integer division, and precision is lost.
640
- # Cast to double and perform division
641
- result_exp = _divnull(
642
- snowpark_args[0].cast(DoubleType()),
643
- snowpark_args[1].cast(DoubleType()),
644
- )
645
- result_type = DoubleType()
646
- elif (
647
- isinstance(snowpark_typed_args[0].typ, DecimalType)
648
- and isinstance(snowpark_typed_args[1].typ, DecimalType)
649
- or isinstance(snowpark_typed_args[0].typ, DecimalType)
650
- and isinstance(snowpark_typed_args[1].typ, _IntegralType)
651
- or isinstance(snowpark_typed_args[0].typ, _IntegralType)
652
- and isinstance(snowpark_typed_args[1].typ, DecimalType)
653
- ):
654
- result_exp, (
655
- return_type_precision,
656
- return_type_scale,
657
- ) = _mul_div_precision_helper(snowpark_typed_args, snowpark_args, 1)
658
- result_type = DecimalType(return_type_precision, return_type_scale)
659
- else:
660
- # Perform division directly
661
- result_exp = _divnull(snowpark_args[0], snowpark_args[1])
662
- result_type = _find_common_type(
663
- [arg.typ for arg in snowpark_typed_args]
664
- )
643
+ match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
644
+ case (DecimalType(), t) | (t, DecimalType()) if isinstance(
645
+ t, DecimalType
646
+ ) or isinstance(t, _IntegralType) or isinstance(
647
+ snowpark_typed_args[1].typ, NullType
648
+ ):
649
+ result_exp, (
650
+ return_type_precision,
651
+ return_type_scale,
652
+ ) = _mul_div_precision_helper(snowpark_typed_args, snowpark_args, 1)
653
+ result_type = DecimalType(return_type_precision, return_type_scale)
654
+ case _:
655
+ result_type = DoubleType()
656
+ dividend = snowpark_args[0].cast(result_type)
657
+ divisor = snowpark_args[1].cast(result_type)
658
+ result_exp = _divnull(dividend, divisor)
665
659
  case "~":
666
660
  result_exp = TypedColumn(
667
661
  snowpark_fn.bitnot(snowpark_args[0]),
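
For reference, a minimal PySpark sketch (assuming a local SparkSession) of the division semantics the rewritten "/" branch targets: integral operands always produce a DOUBLE result, while DECIMAL operands keep a derived precision and scale.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(3, 2)], "a INT, b INT")
# Spark never does integer division with "/": 3 / 2 evaluates to 1.5 (DOUBLE).
df.select((F.col("a") / F.col("b")).alias("q")).printSchema()
# root
#  |-- q: double (nullable = true)
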
@@ -867,14 +861,30 @@ def map_unresolved_function(
867
861
  )
868
862
  case "approx_percentile" | "percentile_approx":
869
863
  # SNOW-1955784: Support accuracy parameter
864
+ # Use percentile_disc to return actual values from dataset (matches PySpark behavior)
870
865
 
871
- # Even though the Spark function accepts a Column for percentage, it will fail unless it's a literal.
872
- # Therefore, we can do error checking right here.
873
- def _check_percentage(exp: expressions_proto.Expression) -> Column:
874
- perc = unwrap_literal(exp)
875
- if not 0.0 <= perc <= 1.0:
866
+ def _pyspark_approx_percentile(
867
+ column: Column, percentage: float, original_type: DataType
868
+ ) -> Column:
869
+ """
870
+ PySpark-compatible percentile that returns actual values from dataset.
871
+ - PySpark's approx_percentile returns the "smallest value in the ordered col values
872
+ such that no more than percentage of col values is less than or equal to that value"
873
+ - This means it MUST return an actual value from the original dataset
874
+ - Snowflake's approx_percentile() may interpolate between values, breaking compatibility
875
+ - percentile_disc() returns discrete values (actual dataset values), matching PySpark
876
+ """
877
+ # Even though the Spark function accepts a Column for percentage, it will fail unless it's a literal.
878
+ # Therefore, we can do error checking right here.
879
+ if not 0.0 <= percentage <= 1.0:
876
880
  raise AnalysisException("percentage must be between [0.0, 1.0]")
877
- return snowpark_fn.lit(perc)
881
+
882
+ result = snowpark_fn.function("percentile_disc")(
883
+ snowpark_fn.lit(percentage)
884
+ ).within_group(column)
885
+ return snowpark_fn.cast(result, original_type)
886
+
887
+ column_type = snowpark_typed_args[0].typ
878
888
 
879
889
  if isinstance(snowpark_typed_args[1].typ, ArrayType):
880
890
  # Snowpark doesn't accept a list of percentile values.
@@ -882,26 +892,26 @@ def map_unresolved_function(
882
892
  array_func = exp.unresolved_function.arguments[1].unresolved_function
883
893
  assert array_func.function_name == "array", array_func
884
894
 
885
- result_exp = snowpark_fn.array_construct(
886
- *[
887
- snowpark_fn.approx_percentile(
888
- snowpark_args[0], _check_percentage(arg)
889
- )
890
- for arg in array_func.arguments
891
- ]
892
- )
895
+ percentile_results = [
896
+ _pyspark_approx_percentile(
897
+ snowpark_args[0], unwrap_literal(arg), column_type
898
+ )
899
+ for arg in array_func.arguments
900
+ ]
901
+
902
+ result_type = ArrayType(element_type=column_type, contains_null=False)
893
903
  result_exp = snowpark_fn.cast(
894
- result_exp,
895
- ArrayType(element_type=DoubleType(), contains_null=False),
904
+ snowpark_fn.array_construct(*percentile_results),
905
+ result_type,
896
906
  )
897
- result_type = ArrayType(element_type=DoubleType(), contains_null=False)
898
907
  else:
908
+ # Handle single percentile
909
+ percentage = unwrap_literal(exp.unresolved_function.arguments[1])
899
910
  result_exp = TypedColumn(
900
- snowpark_fn.approx_percentile(
901
- snowpark_args[0],
902
- _check_percentage(exp.unresolved_function.arguments[1]),
911
+ _pyspark_approx_percentile(
912
+ snowpark_args[0], percentage, column_type
903
913
  ),
904
- lambda: [DoubleType()],
914
+ lambda: [column_type],
905
915
  )
906
916
  case "array":
907
917
  if len(snowpark_args) == 0:
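
A hedged Snowpark sketch (assuming an existing Session named session) of the percentile_disc pattern that _pyspark_approx_percentile relies on: PERCENTILE_DISC ... WITHIN GROUP returns an actual value from the column rather than an interpolated one, which is what PySpark's approx_percentile promises.

import snowflake.snowpark.functions as snowpark_fn

df = session.create_dataframe([[1], [2], [3], [4]], schema=["v"])
# Discrete percentile: the result is one of the original values (2 for p=0.5 here),
# never an interpolated 2.5.
p50 = snowpark_fn.function("percentile_disc")(snowpark_fn.lit(0.5)).within_group(df["v"])
df.agg(p50.alias("p50")).show()
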
@@ -1178,35 +1188,18 @@ def map_unresolved_function(
1178
1188
  snowpark_fn.asinh(snowpark_args[0]), lambda: [DoubleType()]
1179
1189
  )
1180
1190
  case "assert_true":
1191
+ result_type = NullType()
1192
+ raise_error = _raise_error_helper(result_type)
1181
1193
 
1182
- @cached_udf(
1183
- input_types=[BooleanType()],
1184
- return_type=StringType(),
1185
- )
1186
- def _assert_true_single(expr):
1187
- if not expr:
1188
- raise ValueError("assertion failed")
1189
- return None
1190
-
1191
- @cached_udf(
1192
- input_types=[BooleanType(), StringType()],
1193
- return_type=StringType(),
1194
- )
1195
- def _assert_true_with_message(expr, message):
1196
- if not expr:
1197
- raise ValueError(message)
1198
- return None
1199
-
1200
- # Handle different argument counts using match pattern
1201
1194
  match snowpark_args:
1202
1195
  case [expr]:
1203
- result_exp = TypedColumn(
1204
- _assert_true_single(expr), lambda: [StringType()]
1205
- )
1196
+ result_exp = snowpark_fn.when(
1197
+ expr, snowpark_fn.lit(None)
1198
+ ).otherwise(raise_error(snowpark_fn.lit("assertion failed")))
1206
1199
  case [expr, message]:
1207
- result_exp = TypedColumn(
1208
- _assert_true_with_message(expr, message), lambda: [StringType()]
1209
- )
1200
+ result_exp = snowpark_fn.when(
1201
+ expr, snowpark_fn.lit(None)
1202
+ ).otherwise(raise_error(snowpark_fn.cast(message, StringType())))
1210
1203
  case _:
1211
1204
  raise AnalysisException(
1212
1205
  f"[WRONG_NUM_ARGS.WITHOUT_SUGGESTION] The `assert_true` requires 1 or 2 parameters but the actual number is {len(snowpark_args)}."
@@ -2073,14 +2066,22 @@ def map_unresolved_function(
2073
2066
  assert (
2074
2067
  len(exp.unresolved_function.arguments) == 2
2075
2068
  ), "date_format takes 2 arguments"
2076
- result_exp = snowpark_fn.date_format(
2077
- snowpark_args[0],
2078
- snowpark_fn.lit(
2079
- map_spark_timestamp_format_expression(
2080
- exp.unresolved_function.arguments[1], snowpark_typed_args[0].typ
2081
- )
2082
- ),
2083
- )
2069
+
2070
+ # Check if format parameter is NULL
2071
+ format_literal = unwrap_literal(exp.unresolved_function.arguments[1])
2072
+ if format_literal is None:
2073
+ # If format is NULL, return NULL for all rows
2074
+ result_exp = snowpark_fn.lit(None)
2075
+ else:
2076
+ result_exp = snowpark_fn.date_format(
2077
+ snowpark_args[0],
2078
+ snowpark_fn.lit(
2079
+ map_spark_timestamp_format_expression(
2080
+ exp.unresolved_function.arguments[1],
2081
+ snowpark_typed_args[0].typ,
2082
+ )
2083
+ ),
2084
+ )
2084
2085
  result_exp = TypedColumn(result_exp, lambda: [StringType()])
2085
2086
  case "date_from_unix_date":
2086
2087
  result_exp = snowpark_fn.date_add(
@@ -2260,31 +2261,32 @@ def map_unresolved_function(
2260
2261
  )
2261
2262
  case "elt":
2262
2263
  n = snowpark_args[0]
2263
-
2264
2264
  values = snowpark_fn.array_construct(*snowpark_args[1:])
2265
2265
 
2266
2266
  if spark_sql_ansi_enabled:
2267
-
2268
- @cached_udf(
2269
- input_types=[IntegerType()],
2270
- return_type=StringType(),
2267
+ raise_error = _raise_error_helper(
2268
+ StringType(), error_class=ArrayIndexOutOfBoundsException
2271
2269
  )
2272
- def _raise_out_of_bounds_error(n: int) -> str:
2273
- raise ValueError(
2274
- f"ArrayIndexOutOfBoundsException: {n} is not within the input bounds."
2275
- )
2276
-
2277
2270
  values_size = snowpark_fn.lit(len(snowpark_args) - 1)
2278
2271
 
2279
2272
  result_exp = (
2280
2273
  snowpark_fn.when(snowpark_fn.is_null(n), snowpark_fn.lit(None))
2281
2274
  .when(
2282
2275
  (snowpark_fn.lit(1) <= n) & (n <= values_size),
2283
- snowpark_fn.get(
2284
- values, snowpark_fn.nvl(n - 1, snowpark_fn.lit(0))
2276
+ snowpark_fn.cast(
2277
+ snowpark_fn.get(
2278
+ values, snowpark_fn.nvl(n - 1, snowpark_fn.lit(0))
2279
+ ),
2280
+ StringType(),
2285
2281
  ),
2286
2282
  )
2287
- .otherwise(_raise_out_of_bounds_error(n))
2283
+ .otherwise(
2284
+ raise_error(
2285
+ snowpark_fn.lit("[INVALID_ARRAY_INDEX] The index "),
2286
+ snowpark_fn.cast(n, StringType()),
2287
+ snowpark_fn.lit(" is out of bounds."),
2288
+ )
2289
+ )
2288
2290
  )
2289
2291
  else:
2290
2292
  result_exp = snowpark_fn.when(
@@ -2535,6 +2537,19 @@ def map_unresolved_function(
2535
2537
  input_types=[StringType(), StringType(), StructType()],
2536
2538
  )
2537
2539
  def _from_csv(csv_data: str, schema: str, options: Optional[dict]):
2540
+ if csv_data is None:
2541
+ return None
2542
+
2543
+ if csv_data == "":
2544
+ # Return dict with None values for empty string
2545
+ schemas = schema.split(",")
2546
+ results = {}
2547
+ for sc in schemas:
2548
+ parts = [i for i in sc.split(" ") if len(i) != 0]
2549
+ assert len(parts) == 2, f"{sc} is not a valid schema"
2550
+ results[parts[0]] = None
2551
+ return results
2552
+
2538
2553
  max_chars_per_column = -1
2539
2554
  sep = ","
2540
2555
 
@@ -2617,7 +2632,9 @@ def map_unresolved_function(
2617
2632
  case _:
2618
2633
  raise ValueError("Unrecognized from_csv parameters")
2619
2634
 
2620
- result_exp = snowpark_fn.cast(csv_result, ddl_schema)
2635
+ result_exp = snowpark_fn.when(
2636
+ snowpark_args[0].is_null(), snowpark_fn.lit(None)
2637
+ ).otherwise(snowpark_fn.cast(csv_result, ddl_schema))
2621
2638
  result_type = ddl_schema
2622
2639
  case "from_json":
2623
2640
  # TODO: support options.
@@ -2651,6 +2668,9 @@ def map_unresolved_function(
2651
2668
  # try to parse first, since spark returns null for invalid json
2652
2669
  result_exp = snowpark_fn.call_function("try_parse_json", snowpark_args[0])
2653
2670
 
2671
+ # Check if the original input is NULL - if so, return NULL for the entire result
2672
+ original_input_is_null = snowpark_args[0].is_null()
2673
+
2654
2674
  # helper function to make sure we have the expected array element type
2655
2675
  def _element_type_matches(
2656
2676
  array_exp: Column, element_type: DataType
@@ -2749,9 +2769,13 @@ def map_unresolved_function(
2749
2769
  else:
2750
2770
  return exp
2751
2771
 
2752
- result_exp = snowpark_fn.cast(
2753
- _coerce_to_type(result_exp, result_type), result_type
2754
- )
2772
+ # Apply the coercion to handle invalid JSON (creates struct with NULL fields)
2773
+ coerced_exp = _coerce_to_type(result_exp, result_type)
2774
+
2775
+ # If the original input was NULL, return NULL instead of a struct
2776
+ result_exp = snowpark_fn.when(
2777
+ original_input_is_null, snowpark_fn.lit(None)
2778
+ ).otherwise(snowpark_fn.cast(coerced_exp, result_type))
2755
2779
  case "from_unixtime":
2756
2780
 
2757
2781
  def raise_analysis_exception(
@@ -2896,10 +2920,53 @@ def map_unresolved_function(
2896
2920
  )
2897
2921
  case "grouping" | "grouping_id":
2898
2922
  # grouping_id is not an alias for grouping in PySpark, but Snowflake's implementation handles both
2899
- result_exp = snowpark_fn.grouping(*snowpark_args)
2923
+ current_grouping_cols = get_current_grouping_columns()
2924
+ if function_name == "grouping_id":
2925
+ if not snowpark_args:
2926
+ # grouping_id() with empty args means use all grouping columns
2927
+ spark_function_name = "grouping_id()"
2928
+ snowpark_args = [
2929
+ column_mapping.get_snowpark_column_name_from_spark_column_name(
2930
+ spark_col
2931
+ )
2932
+ for spark_col in current_grouping_cols
2933
+ ]
2934
+ else:
2935
+ # Verify that grouping arguments match current grouping columns
2936
+ spark_col_args = [
2937
+ column_mapping.get_spark_column_name_from_snowpark_column_name(
2938
+ sp_col.getName()
2939
+ )
2940
+ for sp_col in snowpark_args
2941
+ ]
2942
+ if current_grouping_cols != spark_col_args:
2943
+ raise AnalysisException(
2944
+ f"[GROUPING_ID_COLUMN_MISMATCH] Columns of grouping_id: {spark_col_args} doesnt match "
2945
+ f"Grouping columns: {current_grouping_cols}"
2946
+ )
2947
+ if function_name == "grouping_id":
2948
+ result_exp = snowpark_fn.grouping_id(*snowpark_args)
2949
+ else:
2950
+ result_exp = snowpark_fn.grouping(*snowpark_args)
2900
2951
  result_type = LongType()
2901
2952
  case "hash":
2902
2953
  # TODO: See the spark-compatibility-issues.md explanation, this is quite different from Spark.
2954
+ # MapType columns as input should raise an exception as they are not hashable.
2955
+ snowflake_compat = get_boolean_session_config_param(
2956
+ "enable_snowflake_extension_behavior"
2957
+ )
2958
+ # Snowflake's hash function does allow MAP types, but Spark does not. Therefore, if we have the extension behavior flag enabled
2959
+ # we want to let it pass through and hash MAP types.
2960
+ # Also allow if the legacy config spark.sql.legacy.allowHashOnMapType is set to true
2961
+ if not snowflake_compat and not spark_sql_legacy_allow_hash_on_map_type:
2962
+ for arg in snowpark_typed_args:
2963
+ if any(isinstance(t, MapType) for t in arg.types):
2964
+ raise AnalysisException(
2965
+ '[DATATYPE_MISMATCH.HASH_MAP_TYPE] Cannot resolve "hash(value)" due to data type mismatch: '
2966
+ 'Input to the function `hash` cannot contain elements of the "MAP" type. '
2967
+ 'In Spark, same maps may have different hashcode, thus hash expressions are prohibited on "MAP" elements. '
2968
+ 'To restore previous behavior set "spark.sql.legacy.allowHashOnMapType" to "true".'
2969
+ )
2903
2970
  result_exp = snowpark_fn.hash(*snowpark_args)
2904
2971
  result_type = LongType()
2905
2972
  case "hex":
@@ -2934,6 +3001,14 @@ def map_unresolved_function(
2934
3001
  result_type = StringType()
2935
3002
  case "histogram_numeric":
2936
3003
  aggregate_input_typ = snowpark_typed_args[0].typ
3004
+
3005
+ if isinstance(aggregate_input_typ, DecimalType):
3006
+ # Mimic a bug from Spark 3.5.3.
3007
+ # It is fixed in 3.5.5, where this exception should no longer be thrown.
3008
+ raise ValueError(
3009
+ "class org.apache.spark.sql.types.Decimal cannot be cast to class java.lang.Number (org.apache.spark.sql.types.Decimal is in unnamed module of loader 'app'; java.lang.Number is in module java.base of loader 'bootstrap')"
3010
+ )
3011
+
2937
3012
  histogram_return_type = ArrayType(
2938
3013
  StructType(
2939
3014
  [
@@ -3154,6 +3229,18 @@ def map_unresolved_function(
3154
3229
  )
3155
3230
  result_type = histogram_return_type
3156
3231
  case "hll_sketch_agg":
3232
+ # check if input type is correct
3233
+ if type(snowpark_typed_args[0].typ) not in [
3234
+ IntegerType,
3235
+ LongType,
3236
+ StringType,
3237
+ BinaryType,
3238
+ ]:
3239
+ type_str = snowpark_typed_args[0].typ.simpleString().upper()
3240
+ raise AnalysisException(
3241
+ f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: Parameter 1 requires the ("INT" or "BIGINT" or "STRING" or "BINARY") type, however "{snowpark_arg_names[0]}" has the type "{type_str}".'
3242
+ )
3243
+
3157
3244
  match snowpark_args:
3158
3245
  case [sketch]:
3159
3246
  spark_function_name = (
@@ -3173,7 +3260,7 @@ def map_unresolved_function(
3173
3260
  ).cast(LongType())
3174
3261
  result_type = LongType()
3175
3262
  case "hll_union_agg":
3176
- raise_error = _raise_error_udf_helper(BinaryType())
3263
+ raise_error = _raise_error_helper(BinaryType())
3177
3264
  args = exp.unresolved_function.arguments
3178
3265
  allow_different_lgConfigK = len(args) == 2 and unwrap_literal(args[1])
3179
3266
  spark_function_name = f"{function_name}({snowpark_arg_names[0]}, {str(allow_different_lgConfigK).lower()})"
@@ -3213,7 +3300,7 @@ def map_unresolved_function(
3213
3300
  SELECT arg1 as x)
3214
3301
  """,
3215
3302
  )
3216
- raise_error = _raise_error_udf_helper(BinaryType())
3303
+ raise_error = _raise_error_helper(BinaryType())
3217
3304
  args = exp.unresolved_function.arguments
3218
3305
  allow_different_lgConfigK = len(args) == 3 and unwrap_literal(args[2])
3219
3306
  spark_function_name = f"{function_name}({snowpark_arg_names[0]}, {snowpark_arg_names[1]}, {str(allow_different_lgConfigK).lower()})"
@@ -3796,12 +3883,47 @@ def map_unresolved_function(
3796
3883
  )
3797
3884
 
3798
3885
  result_type = StringType()
3799
- case "ltrim":
3886
+ case "ltrim" | "rtrim":
3887
+ function_name_argument = (
3888
+ "TRAILING" if function_name == "rtrim" else "LEADING"
3889
+ )
3800
3890
  if len(snowpark_args) == 2:
3801
3891
  # Only possible using SQL
3802
- spark_function_name = f"TRIM(LEADING {snowpark_arg_names[1]} FROM {snowpark_arg_names[0]})"
3892
+ spark_function_name = f"TRIM({function_name_argument} {snowpark_arg_names[1]} FROM {snowpark_arg_names[0]})"
3803
3893
  result_exp = snowpark_fn.ltrim(*snowpark_args)
3804
3894
  result_type = StringType()
3895
+ if isinstance(snowpark_typed_args[0].typ, BinaryType):
3896
+ argument_name = snowpark_arg_names[0]
3897
+ if exp.unresolved_function.arguments[0].HasField("literal"):
3898
+ argument_name = f"""X'{exp.unresolved_function.arguments[0].literal.binary.hex()}'"""
3899
+ if len(snowpark_args) == 1:
3900
+ spark_function_name = f"{function_name}({argument_name})"
3901
+ trim_value = snowpark_fn.lit(b"\x20")
3902
+ if len(snowpark_args) == 2:
3903
+ # Only possible using SQL
3904
+ trim_arg = snowpark_arg_names[1]
3905
+ if isinstance(
3906
+ snowpark_typed_args[1].typ, BinaryType
3907
+ ) and exp.unresolved_function.arguments[1].HasField("literal"):
3908
+ trim_arg = f"""X'{exp.unresolved_function.arguments[1].literal.binary.hex()}'"""
3909
+ trim_value = snowpark_args[1]
3910
+ else:
3911
+ trim_value = snowpark_fn.lit(None)
3912
+ function_name_argument = (
3913
+ "TRAILING" if function_name == "rtrim" else "LEADING"
3914
+ )
3915
+ spark_function_name = f"TRIM({function_name_argument} {trim_arg} FROM {argument_name})"
3916
+ result_exp = _trim_helper(
3917
+ snowpark_args[0], trim_value, snowpark_fn.lit(function_name)
3918
+ )
3919
+ result_type = BinaryType()
3920
+ else:
3921
+ if function_name == "ltrim":
3922
+ result_exp = snowpark_fn.ltrim(*snowpark_args)
3923
+ result_type = StringType()
3924
+ elif function_name == "rtrim":
3925
+ result_exp = snowpark_fn.rtrim(*snowpark_args)
3926
+ result_type = StringType()
3805
3927
  case "make_date":
3806
3928
  y = snowpark_args[0].cast(LongType())
3807
3929
  m = snowpark_args[1].cast(LongType())
@@ -3902,7 +4024,7 @@ def map_unresolved_function(
3902
4024
  snowpark_fn.is_null(snowpark_args[i]),
3903
4025
  # udf execution on XP seems to be lazy, so this should only run when there is a null key
3904
4026
  # otherwise there should be no udf env setup or execution
3905
- _raise_error_udf_helper(VariantType())(
4027
+ _raise_error_helper(VariantType())(
3906
4028
  snowpark_fn.lit(
3907
4029
  "[NULL_MAP_KEY] Cannot use null as map key."
3908
4030
  )
@@ -3964,6 +4086,14 @@ def map_unresolved_function(
3964
4086
  )
3965
4087
  result_type = MapType(key_type, value_type)
3966
4088
  case "map_contains_key":
4089
+ if isinstance(snowpark_typed_args[0].typ, NullType):
4090
+ raise AnalysisException(
4091
+ f"""[DATATYPE_MISMATCH.MAP_FUNCTION_DIFF_TYPES] Cannot resolve "map_contains_key({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: Input to `map_contains_key` should have been "MAP" followed by a value with same key type, but it's ["VOID", "INT"]."""
4092
+ )
4093
+ if isinstance(snowpark_typed_args[1].typ, NullType):
4094
+ raise AnalysisException(
4095
+ f"""[DATATYPE_MISMATCH.NULL_TYPE] Cannot resolve "map_contains_key({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: Null typed values cannot be used as arguments of `map_contains_key`."""
4096
+ )
3967
4097
  args = (
3968
4098
  [snowpark_args[1], snowpark_args[0]]
3969
4099
  if isinstance(snowpark_typed_args[0].typ, MapType)
@@ -4093,17 +4223,37 @@ def map_unresolved_function(
4093
4223
 
4094
4224
  last_win_dedup = global_config.spark_sql_mapKeyDedupPolicy == "LAST_WIN"
4095
4225
 
4096
- result_exp = snowpark_fn.cast(
4097
- snowpark_fn.function("reduce")(
4098
- snowpark_args[0],
4099
- snowpark_fn.object_construct(),
4100
- snowpark_fn.sql_expr(
4101
- # value_field is cast to variant because object_insert doesn't allow structured types,
4102
- # and structured types are not coercible to variant
4103
- # TODO: allow structured types in object_insert?
4104
- f"(acc, e) -> object_insert(acc, e:{key_field}, e:{value_field}::variant, {last_win_dedup})"
4105
- ),
4226
+ # Check if any entry has a NULL key
4227
+ has_null_key = (
4228
+ snowpark_fn.function("array_size")(
4229
+ snowpark_fn.function("filter")(
4230
+ snowpark_args[0],
4231
+ snowpark_fn.sql_expr(f"e -> e:{key_field} IS NULL"),
4232
+ )
4233
+ )
4234
+ > 0
4235
+ )
4236
+
4237
+ # Create error expression for NULL keys (same pattern as the map function)
4238
+ null_key_error = _raise_error_helper(VariantType())(
4239
+ snowpark_fn.lit("[NULL_MAP_KEY] Cannot use null as map key.")
4240
+ )
4241
+
4242
+ # Create the reduce operation
4243
+ reduce_result = snowpark_fn.function("reduce")(
4244
+ snowpark_args[0],
4245
+ snowpark_fn.object_construct(),
4246
+ snowpark_fn.sql_expr(
4247
+ # value_field is cast to variant because object_insert doesn't allow structured types,
4248
+ # and structured types are not coercible to variant
4249
+ # TODO: allow structured types in object_insert?
4250
+ f"(acc, e) -> object_insert(acc, e:{key_field}, e:{value_field}::variant, {last_win_dedup})"
4106
4251
  ),
4252
+ )
4253
+
4254
+ # Use conditional logic: if there are NULL keys, throw error; otherwise proceed with reduce
4255
+ result_exp = snowpark_fn.cast(
4256
+ snowpark_fn.when(has_null_key, null_key_error).otherwise(reduce_result),
4107
4257
  MapType(key_type, value_type),
4108
4258
  )
4109
4259
  result_type = MapType(key_type, value_type)
@@ -4122,23 +4272,35 @@ def map_unresolved_function(
4122
4272
  # TODO: implement in Snowflake/Snowpark
4123
4273
  # technically this could be done with a lateral join, but it's probably not worth the effort
4124
4274
  arg_type = snowpark_typed_args[0].typ
4125
- if not isinstance(arg_type, MapType):
4275
+ if not isinstance(arg_type, (MapType, NullType)):
4126
4276
  raise AnalysisException(
4127
4277
  f"map_values requires a MapType argument, got {arg_type}"
4128
4278
  )
4129
4279
 
4130
4280
  def _map_values(obj: dict) -> list:
4131
- return list(obj.values()) if obj else None
4281
+ if obj is None:
4282
+ return None
4283
+ return list(obj.values())
4132
4284
 
4133
4285
  map_values = cached_udf(
4134
4286
  _map_values, return_type=ArrayType(), input_types=[StructType()]
4135
4287
  )
4136
4288
 
4137
- result_exp = snowpark_fn.cast(
4138
- map_values(snowpark_fn.cast(snowpark_args[0], StructType())),
4139
- ArrayType(arg_type.value_type),
4140
- )
4141
- result_type = ArrayType(arg_type.value_type)
4289
+ # Handle NULL input directly at expression level
4290
+ if isinstance(arg_type, NullType):
4291
+ # If input is NULL literal, return NULL
4292
+ result_exp = snowpark_fn.lit(None)
4293
+ result_type = ArrayType(NullType())
4294
+ else:
4295
+ result_exp = snowpark_fn.when(
4296
+ snowpark_args[0].is_null(), snowpark_fn.lit(None)
4297
+ ).otherwise(
4298
+ snowpark_fn.cast(
4299
+ map_values(snowpark_fn.cast(snowpark_args[0], StructType())),
4300
+ ArrayType(arg_type.value_type),
4301
+ )
4302
+ )
4303
+ result_type = ArrayType(arg_type.value_type)
4142
4304
  case "mask":
4143
4305
 
4144
4306
  number_of_args = len(snowpark_args)
@@ -4258,6 +4420,17 @@ def map_unresolved_function(
4258
4420
  lambda: snowpark_typed_args[0].types,
4259
4421
  )
4260
4422
  case "md5":
4423
+ snowflake_compat = get_boolean_session_config_param(
4424
+ "enable_snowflake_extension_behavior"
4425
+ )
4426
+
4427
+ # MD5 in Spark only accepts BinaryType or types that can be implicitly cast to it (StringType)
4428
+ if not snowflake_compat:
4429
+ if not isinstance(snowpark_typed_args[0].typ, (BinaryType, StringType)):
4430
+ raise AnalysisException(
4431
+ f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "md5({snowpark_arg_names[0]})" due to data type mismatch: '
4432
+ f'Parameter 1 requires the "BINARY" type, however "{snowpark_arg_names[0]}" has the type "{snowpark_typed_args[0].typ}".'
4433
+ )
4261
4434
  result_exp = snowpark_fn.md5(snowpark_args[0])
4262
4435
  result_type = StringType(32)
4263
4436
  case "median":
@@ -5032,7 +5205,7 @@ def map_unresolved_function(
5032
5205
  result_type = DoubleType()
5033
5206
  case "raise_error":
5034
5207
  result_type = StringType()
5035
- raise_error = _raise_error_udf_helper(result_type)
5208
+ raise_error = _raise_error_helper(result_type)
5036
5209
  result_exp = raise_error(*snowpark_args)
5037
5210
  case "rand" | "random":
5038
5211
  # Snowpark random() generates a 64 bit signed integer, but pyspark is [0.0, 1.0).
@@ -5117,7 +5290,7 @@ def map_unresolved_function(
5117
5290
  snowpark_args[2],
5118
5291
  ),
5119
5292
  ),
5120
- _raise_error_udf_helper(StringType())(
5293
+ _raise_error_helper(StringType())(
5121
5294
  snowpark_fn.lit(
5122
5295
  "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract` is invalid."
5123
5296
  )
@@ -5167,7 +5340,7 @@ def map_unresolved_function(
5167
5340
  idx,
5168
5341
  )
5169
5342
  ),
5170
- _raise_error_udf_helper(ArrayType(StringType()))(
5343
+ _raise_error_helper(ArrayType(StringType()))(
5171
5344
  snowpark_fn.lit(
5172
5345
  "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract_all` is invalid."
5173
5346
  )
@@ -5466,13 +5639,28 @@ def map_unresolved_function(
5466
5639
  case "row_number":
5467
5640
  result_exp = snowpark_fn.row_number()
5468
5641
  result_exp = TypedColumn(result_exp, lambda: [LongType()])
5469
- case "rtrim":
5470
- if len(snowpark_args) == 2:
5471
- # Only possible using SQL
5472
- spark_function_name = f"TRIM(TRAILING {snowpark_arg_names[1]} FROM {snowpark_arg_names[0]})"
5473
- result_exp = snowpark_fn.rtrim(*snowpark_args)
5474
- result_type = StringType()
5475
5642
  case "schema_of_csv":
5643
+ # Validate that the input is a foldable STRING expression
5644
+ if (
5645
+ exp.unresolved_function.arguments[0].WhichOneof("expr_type")
5646
+ != "literal"
5647
+ ):
5648
+ raise AnalysisException(
5649
+ "[DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "
5650
+ f'"schema_of_csv({snowpark_arg_names[0]})" due to data type mismatch: '
5651
+ 'the input csv should be a foldable "STRING" expression; however, '
5652
+ f'got "{snowpark_arg_names[0]}".'
5653
+ )
5654
+
5655
+ if isinstance(snowpark_typed_args[0].typ, StringType):
5656
+ if exp.unresolved_function.arguments[0].literal.string == "":
5657
+ raise AnalysisException(
5658
+ "[DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "
5659
+ f'"schema_of_csv({snowpark_arg_names[0]})" due to data type mismatch: '
5660
+ 'the input csv should be a foldable "STRING" expression; however, '
5661
+ f'got "{snowpark_arg_names[0]}".'
5662
+ )
5663
+
5476
5664
  snowpark_args = [
5477
5665
  typed_arg.column(to_semi_structure=True)
5478
5666
  for typed_arg in snowpark_typed_args
@@ -5689,6 +5877,16 @@ def map_unresolved_function(
5689
5877
  )
5690
5878
  result_type = ArrayType(ArrayType(StringType()))
5691
5879
  case "sequence":
5880
+ if snowpark_typed_args[0].typ != snowpark_typed_args[1].typ or (
5881
+ not isinstance(snowpark_typed_args[0].typ, _IntegralType)
5882
+ or not isinstance(snowpark_typed_args[1].typ, _IntegralType)
5883
+ ):
5884
+ raise AnalysisException(
5885
+ f"""[DATATYPE_MISMATCH.SEQUENCE_WRONG_INPUT_TYPES] Cannot resolve "sequence({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: `sequence` uses the wrong parameter type. The parameter type must conform to:
5886
+ 1. The start and stop expressions must resolve to the same type.
5887
+ 2. Otherwise, if start and stop expressions resolve to the "INTEGRAL" type, then the step expression must resolve to the same type.
5888
+ """
5889
+ )
5692
5890
  result_exp = snowpark_fn.cast(
5693
5891
  snowpark_fn.sequence(*snowpark_args),
5694
5892
  ArrayType(LongType(), contains_null=False),
@@ -5856,7 +6054,7 @@ def map_unresolved_function(
5856
6054
  result_exp = snowpark_fn.skew(snowpark_fn.lit(None))
5857
6055
  result_type = DoubleType()
5858
6056
  case "slice":
5859
- raise_error = _raise_error_udf_helper(snowpark_typed_args[0].typ)
6057
+ raise_error = _raise_error_helper(snowpark_typed_args[0].typ)
5860
6058
  spark_index = snowpark_args[1]
5861
6059
  arr_size = snowpark_fn.array_size(snowpark_args[0])
5862
6060
  slice_len = snowpark_args[2]
@@ -5926,10 +6124,11 @@ def map_unresolved_function(
5926
6124
  result_exp = snowpark_fn.lit(0)
5927
6125
  result_type = LongType()
5928
6126
  case "split":
6127
+ result_type = ArrayType(StringType())
5929
6128
 
5930
6129
  @cached_udf(
5931
6130
  input_types=[StringType(), StringType(), IntegerType()],
5932
- return_type=ArrayType(StringType()),
6131
+ return_type=result_type,
5933
6132
  )
5934
6133
  def _split(
5935
6134
  input: Optional[str], pattern: Optional[str], limit: Optional[int]
@@ -5937,34 +6136,80 @@ def map_unresolved_function(
5937
6136
  if input is None or pattern is None:
5938
6137
  return None
5939
6138
 
6139
+ import re
6140
+
6141
+ try:
6142
+ re.compile(pattern)
6143
+ except re.error:
6144
+ raise ValueError(
6145
+ f"Failed to split string, provided pattern: {pattern} is invalid"
6146
+ )
6147
+
5940
6148
  if limit == 1:
5941
6149
  return [input]
5942
6150
 
5943
- import re
6151
+ if not input:
6152
+ return []
5944
6153
 
5945
6154
  # A default of -1 is passed in PySpark, but RE needs it to be 0 to provide all splits.
5946
6155
  # In PySpark, the limit also indicates the max size of the resulting array, but in RE
5947
6156
  # the remainder is returned as another element.
5948
6157
  maxsplit = limit - 1 if limit > 0 else 0
5949
6158
 
5950
- split_result = re.split(pattern, input, maxsplit)
5951
6159
  if len(pattern) == 0:
5952
- # RE.split provides a first and last empty element that is not there in PySpark.
5953
- split_result = split_result[1 : len(split_result) - 1]
6160
+ return list(input) if limit <= 0 else list(input)[:limit]
6161
+
6162
+ match pattern:
6163
+ case "|":
6164
+ split_result = re.split(pattern, input, 0)
6165
+ input_limit = limit + 1 if limit > 0 else len(split_result)
6166
+ return (
6167
+ split_result
6168
+ if input_limit == 0
6169
+ else split_result[1:input_limit]
6170
+ )
6171
+ case "$":
6172
+ return [input, ""] if maxsplit >= 0 else [input]
6173
+ case "^":
6174
+ return [input]
6175
+ case _:
6176
+ return re.split(pattern, input, maxsplit)
6177
+
6178
+ def split_string(str_: Column, pattern: Column, limit: Column):
6179
+ native_split = _split(str_, pattern, limit)
6180
+ # When pattern is a literal and doesn't contain any regex special characters
6181
+ # And when limit is less than or equal to 0
6182
+ # Native Snowflake Split function is used to optimise performance
6183
+ if isinstance(pattern._expression, Literal):
6184
+ pattern_value = pattern._expression.value
6185
+
6186
+ if pattern_value is None:
6187
+ return snowpark_fn.lit(None)
6188
+
6189
+ is_regexp = re.match(
6190
+ ".*[\\[\\.\\]\\*\\?\\+\\^\\$\\{\\}\\|\\(\\)\\\\].*",
6191
+ pattern_value,
6192
+ )
6193
+ is_empty = len(pattern_value) == 0
6194
+
6195
+ if not is_empty and not is_regexp:
6196
+ return snowpark_fn.when(
6197
+ limit <= 0,
6198
+ snowpark_fn.split(str_, pattern).cast(result_type),
6199
+ ).otherwise(native_split)
5954
6200
 
5955
- return split_result
6201
+ return native_split
5956
6202
 
5957
6203
  match snowpark_args:
5958
6204
  case [str_, pattern]:
5959
6205
  spark_function_name = (
5960
6206
  f"split({snowpark_arg_names[0]}, {snowpark_arg_names[1]}, -1)"
5961
6207
  )
5962
- result_exp = _split(str_, pattern, snowpark_fn.lit(0))
6208
+ result_exp = split_string(str_, pattern, snowpark_fn.lit(-1))
5963
6209
  case [str_, pattern, limit]: # noqa: F841
5964
- result_exp = _split(str_, pattern, limit)
6210
+ result_exp = split_string(str_, pattern, limit)
5965
6211
  case _:
5966
6212
  raise ValueError(f"Invalid number of arguments to {function_name}")
5967
- result_type = ArrayType(StringType())
5968
6213
  case "split_part":
5969
6214
  result_exp = snowpark_fn.call_function("split_part", *snowpark_args)
5970
6215
  result_type = StringType()
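
A hedged Snowpark sketch (assuming a Session named session) of the fast path taken by split_string above: when the delimiter is a plain literal with no regex metacharacters and limit <= 0, Snowflake's native SPLIT can be used instead of the regex-based Python UDF.

import snowflake.snowpark.functions as snowpark_fn
from snowflake.snowpark.types import ArrayType, StringType

df = session.create_dataframe([["a,b,c"]], schema=["s"])
# Native SPLIT returns an ARRAY; the cast matches the ArrayType(StringType()) result type.
parts = snowpark_fn.split(df["s"], snowpark_fn.lit(",")).cast(ArrayType(StringType()))
df.select(parts.alias("parts")).show()
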
@@ -6274,6 +6519,10 @@ def map_unresolved_function(
6274
6519
  )
6275
6520
  result_type = TimestampType(snowpark.types.TimestampTimeZone.NTZ)
6276
6521
  case "timestamp_millis":
6522
+ if not isinstance(snowpark_typed_args[0].typ, _IntegralType):
6523
+ raise AnalysisException(
6524
+ f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "timestamp_millis({snowpark_arg_names[0]}" due to data type mismatch: Parameter 1 requires the "INTEGRAL" type, however "{snowpark_arg_names[0]}" has the type "{snowpark_typed_args[0].typ}".'
6525
+ )
6277
6526
  result_exp = snowpark_fn.cast(
6278
6527
  snowpark_fn.to_timestamp(snowpark_args[0] * 1_000, 6),
6279
6528
  TimestampType(snowpark.types.TimestampTimeZone.NTZ),
@@ -6283,6 +6532,10 @@ def map_unresolved_function(
6283
6532
  # Spark allows seconds to be fractional. Snowflake does not allow that
6284
6533
  # even though the documentation explicitly says that it does.
6285
6534
  # As a workaround, use integer milliseconds instead of fractional seconds.
6535
+ if not isinstance(snowpark_typed_args[0].typ, _NumericType):
6536
+ raise AnalysisException(
6537
+ f"""AnalysisException: [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{function_name}({snowpark_arg_names[0]})" due to data type mismatch: Parameter 1 requires the "NUMERIC" type, however "{snowpark_arg_names[0]}" has the type "{snowpark_typed_args[0].typ}".;"""
6538
+ )
6286
6539
  result_exp = snowpark_fn.cast(
6287
6540
  snowpark_fn.to_timestamp(
6288
6541
  snowpark_fn.cast(snowpark_args[0] * 1_000_000, LongType()), 6
@@ -6725,7 +6978,20 @@ def map_unresolved_function(
6725
6978
  result_type = StringType()
6726
6979
  case "trunc":
6727
6980
  part = unwrap_literal(exp.unresolved_function.arguments[1])
6728
- if part is None:
6981
+ part = None if part is None else part.lower()
6982
+
6983
+ allowed_parts = {
6984
+ "year",
6985
+ "yyyy",
6986
+ "yy",
6987
+ "month",
6988
+ "mon",
6989
+ "mm",
6990
+ "week",
6991
+ "quarter",
6992
+ }
6993
+
6994
+ if part not in allowed_parts:
6729
6995
  result_exp = snowpark_fn.lit(None)
6730
6996
  else:
6731
6997
  result_exp = _try_to_cast(
@@ -7116,6 +7382,12 @@ def map_unresolved_function(
7116
7382
  )
7117
7383
  )
7118
7384
  )
7385
+ raise_fn = _raise_error_helper(BinaryType(), IllegalArgumentException)
7386
+ result_exp = (
7387
+ snowpark_fn.when(unbase_arg.is_null(), snowpark_fn.lit(None))
7388
+ .when(result_exp.is_null(), raise_fn(snowpark_fn.lit("Invalid input")))
7389
+ .otherwise(result_exp)
7390
+ )
7119
7391
  result_type = BinaryType()
7120
7392
  case "unhex":
7121
7393
  # Non string columns, convert them to string type. This mimics pyspark behavior.
@@ -7316,6 +7588,15 @@ def map_unresolved_function(
7316
7588
  )
7317
7589
  result_type = LongType()
7318
7590
  case "when" | "if":
7591
+ # Validate that the condition is a boolean expression
7592
+ if len(snowpark_typed_args) > 0:
7593
+ condition_type = snowpark_typed_args[0].typ
7594
+ if not isinstance(condition_type, BooleanType):
7595
+ raise AnalysisException(
7596
+ f"[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve CASE WHEN condition due to data type mismatch: "
7597
+ f"Parameter 1 requires the 'BOOLEAN' type, however got '{condition_type}'"
7598
+ )
7599
+
7319
7600
  name_components = ["CASE"]
7320
7601
  name_components.append("WHEN")
7321
7602
  name_components.append(snowpark_arg_names[0])
@@ -7334,6 +7615,13 @@ def map_unresolved_function(
7334
7615
  name_components.append(snowpark_arg_names[i])
7335
7616
  name_components.append("THEN")
7336
7617
  name_components.append(snowpark_arg_names[i + 1])
7618
+ # Validate each WHEN condition
7619
+ condition_type = snowpark_typed_args[i].typ
7620
+ if not isinstance(condition_type, BooleanType):
7621
+ raise AnalysisException(
7622
+ f"[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve CASE WHEN condition due to data type mismatch: "
7623
+ f"Parameter {i + 1} requires the 'BOOLEAN' type, however got '{condition_type}'"
7624
+ )
7337
7625
  result_exp = result_exp.when(snowpark_args[i], snowpark_args[i + 1])
7338
7626
  result_type_indexes.append(i + 1)
7339
7627
  name_components.append("END")
@@ -7710,16 +7998,8 @@ def _handle_current_timestamp():
7710
7998
 
7711
7999
 
7712
8000
  def _equivalent_decimal(type):
7713
- match (type):
7714
- case ByteType():
7715
- return DecimalType(3, 0)
7716
- case ShortType():
7717
- return DecimalType(5, 0)
7718
- case IntegerType():
7719
- return DecimalType(10, 0)
7720
- case LongType():
7721
- return DecimalType(20, 0)
7722
- return DecimalType(38, 0)
8001
+ (precision, scale) = _get_type_precision(type)
8002
+ return DecimalType(precision, scale)
7723
8003
 
7724
8004
 
7725
8005
  def _resolve_decimal_and_numeric(type1: DecimalType, type2: _NumericType) -> DataType:
@@ -8778,7 +9058,9 @@ def _get_type_precision(typ: DataType) -> tuple[int, int]:
8778
9058
  case IntegerType():
8779
9059
  return 10, 0 # -2147483648 to 2147483647
8780
9060
  case LongType():
8781
- return 19, 0 # -9223372036854775808 to 9223372036854775807
9061
+ return 20, 0 # -9223372036854775808 to 9223372036854775807
9062
+ case NullType():
9063
+ return 6, 2 # NULL
8782
9064
  case _:
8783
9065
  return 38, 0 # Default to maximum precision for other types
8784
9066
 
@@ -8993,16 +9275,12 @@ def _try_arithmetic_helper(
8993
9275
  typed_args[1].typ, DecimalType
8994
9276
  ):
8995
9277
  new_scale = s2
8996
- new_precision = (
8997
- p1 + s2 + 1
8998
- ) # Integral precision + decimal scale + 1 for carry
9278
+ new_precision = max(p2, p1 + s2)
8999
9279
  elif isinstance(typed_args[0].typ, DecimalType) and isinstance(
9000
9280
  typed_args[1].typ, _IntegralType
9001
9281
  ):
9002
9282
  new_scale = s1
9003
- new_precision = (
9004
- p2 + s1 + 1
9005
- ) # Integral precision + decimal scale + 1 for carry
9283
+ new_precision = max(p1, p2 + s1)
9006
9284
  else:
9007
9285
  # Both decimal types
9008
9286
  if operation_type == 1 and s1 == s2: # subtraction with matching scales
@@ -9081,13 +9359,13 @@ def _add_sub_precision_helper(
9081
9359
  typed_args[1].typ, DecimalType
9082
9360
  ):
9083
9361
  new_scale = s2
9084
- new_precision = p1 + s2 + 1 # Integral precision + decimal scale + 1 for carry
9362
+ new_precision = max(p2, p1 + s2)
9085
9363
  return_type_precision, return_type_scale = new_precision, new_scale
9086
9364
  elif isinstance(typed_args[0].typ, DecimalType) and isinstance(
9087
9365
  typed_args[1].typ, _IntegralType
9088
9366
  ):
9089
9367
  new_scale = s1
9090
- new_precision = p2 + s1 + 1 # Integral precision + decimal scale + 1 for carry
9368
+ new_precision = max(p1, p2 + s1)
9091
9369
  return_type_precision, return_type_scale = new_precision, new_scale
9092
9370
  else:
9093
9371
  (
@@ -9169,11 +9447,25 @@ def _mul_div_precision_helper(
9169
9447
  )
9170
9448
 
9171
9449
 
9172
- def _raise_error_udf_helper(return_type: DataType):
9173
- def _raise_error(message=None):
9174
- raise ValueError(message)
9450
+ def _raise_error_helper(return_type: DataType, error_class=None):
9451
+ error_class = (
9452
+ f":{error_class.__name__}"
9453
+ if error_class and hasattr(error_class, "__name__")
9454
+ else ""
9455
+ )
9456
+
9457
+ def _raise_fn(*msgs: Column) -> Column:
9458
+ return snowpark_fn.cast(
9459
+ snowpark_fn.abs(
9460
+ snowpark_fn.concat(
9461
+ snowpark_fn.lit(f"[snowpark-connect-exception{error_class}]"),
9462
+ *(msg.try_cast(StringType()) for msg in msgs),
9463
+ )
9464
+ ).cast(StringType()),
9465
+ return_type,
9466
+ )
9175
9467
 
9176
- return cached_udf(_raise_error, return_type=return_type, input_types=[StringType()])
9468
+ return _raise_fn
9177
9469
 
9178
9470
 
9179
9471
  def _divnull(dividend: Column, divisor: Column) -> Column:
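
A hedged standalone sketch of the column-level raise trick that _raise_error_helper uses in place of the old UDF: ABS over a non-numeric string only fails when the expression is actually evaluated, so the tagged message surfaces lazily for rows that reach the error branch of a CASE/WHEN, with no UDF setup cost.

import snowflake.snowpark.functions as snowpark_fn
from snowflake.snowpark.types import DataType, StringType

def raise_error_column(message: str, return_type: DataType = StringType()):
    # Concatenate a recognizable tag with the message, force a runtime failure via ABS,
    # and cast back so the expression stays type-compatible with sibling WHEN branches.
    tagged = snowpark_fn.concat(
        snowpark_fn.lit("[snowpark-connect-exception]"), snowpark_fn.lit(message)
    )
    return snowpark_fn.cast(snowpark_fn.abs(tagged).cast(StringType()), return_type)
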
@@ -9448,3 +9740,22 @@ def _validate_number_format_string(format_str: str) -> None:
9448
9740
  raise AnalysisException(
9449
9741
  f"[INVALID_FORMAT.WRONG_NUM_DIGIT] The format is invalid: '{format_str}'. The format string requires at least one number digit."
9450
9742
  )
9743
+
9744
+
9745
+ def _trim_helper(value: Column, trim_value: Column, trim_type: Column) -> Column:
9746
+ @cached_udf(
9747
+ return_type=BinaryType(),
9748
+ input_types=[BinaryType(), BinaryType(), StringType()],
9749
+ )
9750
+ def _binary_trim_udf(value: bytes, trim_value: bytes, trim_type: str) -> bytes:
9751
+ if value is None or trim_value is None:
9752
+ return value
9753
+ if trim_type in ("rtrim", "btrim", "trim"):
9754
+ while value.endswith(trim_value):
9755
+ value = value[: -len(trim_value)]
9756
+ if trim_type in ("ltrim", "btrim", "trim"):
9757
+ while value.startswith(trim_value):
9758
+ value = value[len(trim_value) :]
9759
+ return value
9760
+
9761
+ return _binary_trim_udf(value, trim_value, trim_type)