snowpark-connect 0.21.0__py3-none-any.whl → 0.22.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of snowpark-connect might be problematic.

Files changed (41)
  1. snowflake/snowpark_connect/config.py +19 -3
  2. snowflake/snowpark_connect/error/error_utils.py +25 -0
  3. snowflake/snowpark_connect/expression/map_udf.py +4 -4
  4. snowflake/snowpark_connect/expression/map_unresolved_function.py +203 -128
  5. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
  6. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
  7. snowflake/snowpark_connect/relation/map_aggregate.py +102 -18
  8. snowflake/snowpark_connect/relation/map_column_ops.py +21 -2
  9. snowflake/snowpark_connect/relation/map_map_partitions.py +3 -1
  10. snowflake/snowpark_connect/relation/map_sql.py +18 -191
  11. snowflake/snowpark_connect/relation/map_udtf.py +4 -4
  12. snowflake/snowpark_connect/relation/read/map_read_json.py +12 -1
  13. snowflake/snowpark_connect/relation/read/reader_config.py +1 -0
  14. snowflake/snowpark_connect/relation/write/map_write.py +68 -24
  15. snowflake/snowpark_connect/server.py +9 -0
  16. snowflake/snowpark_connect/type_mapping.py +4 -0
  17. snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
  18. snowflake/snowpark_connect/utils/session.py +0 -4
  19. snowflake/snowpark_connect/utils/telemetry.py +213 -61
  20. snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
  21. snowflake/snowpark_connect/version.py +1 -1
  22. snowflake/snowpark_decoder/__init__.py +0 -0
  23. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
  24. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
  25. snowflake/snowpark_decoder/dp_session.py +111 -0
  26. snowflake/snowpark_decoder/spark_decoder.py +76 -0
  27. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/METADATA +2 -2
  28. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/RECORD +40 -29
  29. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/top_level.txt +1 -0
  30. spark/__init__.py +0 -0
  31. spark/connect/__init__.py +0 -0
  32. spark/connect/envelope_pb2.py +31 -0
  33. spark/connect/envelope_pb2.pyi +46 -0
  34. snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
  35. {snowpark_connect-0.21.0.data → snowpark_connect-0.22.1.data}/scripts/snowpark-connect +0 -0
  36. {snowpark_connect-0.21.0.data → snowpark_connect-0.22.1.data}/scripts/snowpark-session +0 -0
  37. {snowpark_connect-0.21.0.data → snowpark_connect-0.22.1.data}/scripts/snowpark-submit +0 -0
  38. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/WHEEL +0 -0
  39. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE-binary +0 -0
  40. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
  41. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/NOTICE-binary +0 -0
@@ -28,6 +28,7 @@ from google.protobuf.message import Message
  from pyspark.errors.exceptions.base import (
  AnalysisException,
  ArithmeticException,
+ ArrayIndexOutOfBoundsException,
  DateTimeException,
  IllegalArgumentException,
  NumberFormatException,
@@ -39,6 +40,7 @@ from pyspark.sql.types import _parse_datatype_json_string
  import snowflake.snowpark.functions as snowpark_fn
  from snowflake import snowpark
  from snowflake.snowpark import Column, Session
+ from snowflake.snowpark._internal.analyzer.expression import Literal
  from snowflake.snowpark._internal.analyzer.unary_expression import Alias
  from snowflake.snowpark.types import (
  ArrayType,
@@ -139,7 +141,7 @@ from snowflake.snowpark_connect.utils.xxhash64 import (
  MAX_UINT64 = 2**64 - 1
  MAX_INT64 = 2**63 - 1
  MIN_INT64 = -(2**63)
-
+ MAX_ARRAY_SIZE = 2_147_483_647

  NAN, INFINITY = float("nan"), float("inf")

@@ -638,37 +640,22 @@ def map_unresolved_function(
  [arg.typ for arg in snowpark_typed_args]
  )
  case "/":
- if isinstance(
- snowpark_typed_args[0].typ, (IntegerType, LongType, ShortType)
- ) and isinstance(
- snowpark_typed_args[1].typ, (IntegerType, LongType, ShortType)
- ):
- # Check if both arguments are integer types. Snowpark performs integer division, and precision is lost.
- # Cast to double and perform division
- result_exp = _divnull(
- snowpark_args[0].cast(DoubleType()),
- snowpark_args[1].cast(DoubleType()),
- )
- result_type = DoubleType()
- elif (
- isinstance(snowpark_typed_args[0].typ, DecimalType)
- and isinstance(snowpark_typed_args[1].typ, DecimalType)
- or isinstance(snowpark_typed_args[0].typ, DecimalType)
- and isinstance(snowpark_typed_args[1].typ, _IntegralType)
- or isinstance(snowpark_typed_args[0].typ, _IntegralType)
- and isinstance(snowpark_typed_args[1].typ, DecimalType)
- ):
- result_exp, (
- return_type_precision,
- return_type_scale,
- ) = _mul_div_precision_helper(snowpark_typed_args, snowpark_args, 1)
- result_type = DecimalType(return_type_precision, return_type_scale)
- else:
- # Perform division directly
- result_exp = _divnull(snowpark_args[0], snowpark_args[1])
- result_type = _find_common_type(
- [arg.typ for arg in snowpark_typed_args]
- )
+ match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+ case (DecimalType(), t) | (t, DecimalType()) if isinstance(
+ t, DecimalType
+ ) or isinstance(t, _IntegralType) or isinstance(
+ snowpark_typed_args[1].typ, NullType
+ ):
+ result_exp, (
+ return_type_precision,
+ return_type_scale,
+ ) = _mul_div_precision_helper(snowpark_typed_args, snowpark_args, 1)
+ result_type = DecimalType(return_type_precision, return_type_scale)
+ case _:
+ result_type = DoubleType()
+ dividend = snowpark_args[0].cast(result_type)
+ divisor = snowpark_args[1].cast(result_type)
+ result_exp = _divnull(dividend, divisor)
  case "~":
  result_exp = TypedColumn(
  snowpark_fn.bitnot(snowpark_args[0]),
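Note on the new `/` mapping: the former if/elif/else collapses into a structural match. When one operand is DECIMAL and the other is DECIMAL, an integral type, or (on the right-hand side) an untyped NULL, the result type comes from `_mul_div_precision_helper`; every other combination, including integer/integer, is cast to DOUBLE before `_divnull`, which matches Spark, where `/` never performs integer division. The sketch below is an illustrative reminder of Spark's documented decimal-division result type (before precision-loss adjustment); the helper is assumed to implement an equivalent rule, and `spark_divide_result_type` is not a name from this package.

    # Minimal sketch: Spark's documented rule for DECIMAL(p1, s1) / DECIMAL(p2, s2),
    # before any precision-loss adjustment.
    def spark_divide_result_type(p1: int, s1: int, p2: int, s2: int) -> tuple[int, int]:
        scale = max(6, s1 + p2 + 1)
        precision = p1 - s1 + s2 + scale
        return precision, scale

    print(spark_divide_result_type(10, 2, 10, 0))  # DECIMAL(10, 2) / DECIMAL(10, 0) -> (21, 13)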
@@ -1201,35 +1188,18 @@ def map_unresolved_function(
  snowpark_fn.asinh(snowpark_args[0]), lambda: [DoubleType()]
  )
  case "assert_true":
+ result_type = NullType()
+ raise_error = _raise_error_helper(result_type)

- @cached_udf(
- input_types=[BooleanType()],
- return_type=StringType(),
- )
- def _assert_true_single(expr):
- if not expr:
- raise ValueError("assertion failed")
- return None
-
- @cached_udf(
- input_types=[BooleanType(), StringType()],
- return_type=StringType(),
- )
- def _assert_true_with_message(expr, message):
- if not expr:
- raise ValueError(message)
- return None
-
- # Handle different argument counts using match pattern
  match snowpark_args:
  case [expr]:
- result_exp = TypedColumn(
- _assert_true_single(expr), lambda: [StringType()]
- )
+ result_exp = snowpark_fn.when(
+ expr, snowpark_fn.lit(None)
+ ).otherwise(raise_error(snowpark_fn.lit("assertion failed")))
  case [expr, message]:
- result_exp = TypedColumn(
- _assert_true_with_message(expr, message), lambda: [StringType()]
- )
+ result_exp = snowpark_fn.when(
+ expr, snowpark_fn.lit(None)
+ ).otherwise(raise_error(snowpark_fn.cast(message, StringType())))
  case _:
  raise AnalysisException(
  f"[WRONG_NUM_ARGS.WITHOUT_SUGGESTION] The `assert_true` requires 1 or 2 parameters but the actual number is {len(snowpark_args)}."
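With this change `assert_true` no longer registers a Python UDF: the check compiles to a conditional expression, so the error branch is only evaluated when the assertion actually fails. A minimal sketch of the shape being built (`assert_true_expr` is illustrative; it assumes an active Snowpark session and the `_raise_error_helper` introduced in this diff):

    import snowflake.snowpark.functions as snowpark_fn
    from snowflake.snowpark.types import StringType

    def assert_true_expr(cond, raise_error, message=None):
        # Compiles to: CASE WHEN cond THEN NULL ELSE <error-raising expression> END
        msg = (
            snowpark_fn.lit("assertion failed")
            if message is None
            else snowpark_fn.cast(message, StringType())
        )
        return snowpark_fn.when(cond, snowpark_fn.lit(None)).otherwise(raise_error(msg))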
@@ -2291,31 +2261,32 @@ def map_unresolved_function(
  )
  case "elt":
  n = snowpark_args[0]
-
  values = snowpark_fn.array_construct(*snowpark_args[1:])

  if spark_sql_ansi_enabled:
-
- @cached_udf(
- input_types=[IntegerType()],
- return_type=StringType(),
+ raise_error = _raise_error_helper(
+ StringType(), error_class=ArrayIndexOutOfBoundsException
  )
- def _raise_out_of_bounds_error(n: int) -> str:
- raise ValueError(
- f"ArrayIndexOutOfBoundsException: {n} is not within the input bounds."
- )
-
  values_size = snowpark_fn.lit(len(snowpark_args) - 1)

  result_exp = (
  snowpark_fn.when(snowpark_fn.is_null(n), snowpark_fn.lit(None))
  .when(
  (snowpark_fn.lit(1) <= n) & (n <= values_size),
- snowpark_fn.get(
- values, snowpark_fn.nvl(n - 1, snowpark_fn.lit(0))
+ snowpark_fn.cast(
+ snowpark_fn.get(
+ values, snowpark_fn.nvl(n - 1, snowpark_fn.lit(0))
+ ),
+ StringType(),
  ),
  )
- .otherwise(_raise_out_of_bounds_error(n))
+ .otherwise(
+ raise_error(
+ snowpark_fn.lit("[INVALID_ARRAY_INDEX] The index "),
+ snowpark_fn.cast(n, StringType()),
+ snowpark_fn.lit(" is out of bounds."),
+ )
+ )
  )
  else:
  result_exp = snowpark_fn.when(
@@ -3289,7 +3260,7 @@ def map_unresolved_function(
  ).cast(LongType())
  result_type = LongType()
  case "hll_union_agg":
- raise_error = _raise_error_udf_helper(BinaryType())
+ raise_error = _raise_error_helper(BinaryType())
  args = exp.unresolved_function.arguments
  allow_different_lgConfigK = len(args) == 2 and unwrap_literal(args[1])
  spark_function_name = f"{function_name}({snowpark_arg_names[0]}, {str(allow_different_lgConfigK).lower()})"
@@ -3329,7 +3300,7 @@ def map_unresolved_function(
  SELECT arg1 as x)
  """,
  )
- raise_error = _raise_error_udf_helper(BinaryType())
+ raise_error = _raise_error_helper(BinaryType())
  args = exp.unresolved_function.arguments
  allow_different_lgConfigK = len(args) == 3 and unwrap_literal(args[2])
  spark_function_name = f"{function_name}({snowpark_arg_names[0]}, {snowpark_arg_names[1]}, {str(allow_different_lgConfigK).lower()})"
@@ -4053,7 +4024,7 @@ def map_unresolved_function(
  snowpark_fn.is_null(snowpark_args[i]),
  # udf execution on XP seems to be lazy, so this should only run when there is a null key
  # otherwise there should be no udf env setup or execution
- _raise_error_udf_helper(VariantType())(
+ _raise_error_helper(VariantType())(
  snowpark_fn.lit(
  "[NULL_MAP_KEY] Cannot use null as map key."
  )
@@ -4115,6 +4086,14 @@ def map_unresolved_function(
  )
  result_type = MapType(key_type, value_type)
  case "map_contains_key":
+ if isinstance(snowpark_typed_args[0].typ, NullType):
+ raise AnalysisException(
+ f"""[DATATYPE_MISMATCH.MAP_FUNCTION_DIFF_TYPES] Cannot resolve "map_contains_key({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: Input to `map_contains_key` should have been "MAP" followed by a value with same key type, but it's ["VOID", "INT"]."""
+ )
+ if isinstance(snowpark_typed_args[1].typ, NullType):
+ raise AnalysisException(
+ f"""[DATATYPE_MISMATCH.NULL_TYPE] Cannot resolve "map_contains_key({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: Null typed values cannot be used as arguments of `map_contains_key`."""
+ )
  args = (
  [snowpark_args[1], snowpark_args[0]]
  if isinstance(snowpark_typed_args[0].typ, MapType)
@@ -4244,17 +4223,37 @@ def map_unresolved_function(

  last_win_dedup = global_config.spark_sql_mapKeyDedupPolicy == "LAST_WIN"

- result_exp = snowpark_fn.cast(
- snowpark_fn.function("reduce")(
- snowpark_args[0],
- snowpark_fn.object_construct(),
- snowpark_fn.sql_expr(
- # value_field is cast to variant because object_insert doesn't allow structured types,
- # and structured types are not coercible to variant
- # TODO: allow structured types in object_insert?
- f"(acc, e) -> object_insert(acc, e:{key_field}, e:{value_field}::variant, {last_win_dedup})"
- ),
+ # Check if any entry has a NULL key
+ has_null_key = (
+ snowpark_fn.function("array_size")(
+ snowpark_fn.function("filter")(
+ snowpark_args[0],
+ snowpark_fn.sql_expr(f"e -> e:{key_field} IS NULL"),
+ )
+ )
+ > 0
+ )
+
+ # Create error UDF for NULL keys (same pattern as map function)
+ null_key_error = _raise_error_helper(VariantType())(
+ snowpark_fn.lit("[NULL_MAP_KEY] Cannot use null as map key.")
+ )
+
+ # Create the reduce operation
+ reduce_result = snowpark_fn.function("reduce")(
+ snowpark_args[0],
+ snowpark_fn.object_construct(),
+ snowpark_fn.sql_expr(
+ # value_field is cast to variant because object_insert doesn't allow structured types,
+ # and structured types are not coercible to variant
+ # TODO: allow structured types in object_insert?
+ f"(acc, e) -> object_insert(acc, e:{key_field}, e:{value_field}::variant, {last_win_dedup})"
  ),
+ )
+
+ # Use conditional logic: if there are NULL keys, throw error; otherwise proceed with reduce
+ result_exp = snowpark_fn.cast(
+ snowpark_fn.when(has_null_key, null_key_error).otherwise(reduce_result),
  MapType(key_type, value_type),
  )
  result_type = MapType(key_type, value_type)
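The NULL-key guard added here is expressed with Snowflake higher-order functions, so it lives in the same expression tree as the REDUCE and needs no UDF just to detect the bad key. A minimal sketch of the predicate it builds (the entry field name `k` is illustrative; the real code interpolates `key_field`):

    import snowflake.snowpark.functions as snowpark_fn

    def has_null_key(entries_col):
        # ARRAY_SIZE(FILTER(entries, e -> e:k IS NULL)) > 0
        return (
            snowpark_fn.function("array_size")(
                snowpark_fn.function("filter")(
                    entries_col, snowpark_fn.sql_expr("e -> e:k IS NULL")
                )
            )
            > 0
        )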
@@ -4273,23 +4272,35 @@ def map_unresolved_function(
  # TODO: implement in Snowflake/Snowpark
  # technically this could be done with a lateral join, but it's probably not worth the effort
  arg_type = snowpark_typed_args[0].typ
- if not isinstance(arg_type, MapType):
+ if not isinstance(arg_type, (MapType, NullType)):
  raise AnalysisException(
  f"map_values requires a MapType argument, got {arg_type}"
  )

  def _map_values(obj: dict) -> list:
- return list(obj.values()) if obj else None
+ if obj is None:
+ return None
+ return list(obj.values())

  map_values = cached_udf(
  _map_values, return_type=ArrayType(), input_types=[StructType()]
  )

- result_exp = snowpark_fn.cast(
- map_values(snowpark_fn.cast(snowpark_args[0], StructType())),
- ArrayType(arg_type.value_type),
- )
- result_type = ArrayType(arg_type.value_type)
+ # Handle NULL input directly at expression level
+ if isinstance(arg_type, NullType):
+ # If input is NULL literal, return NULL
+ result_exp = snowpark_fn.lit(None)
+ result_type = ArrayType(NullType())
+ else:
+ result_exp = snowpark_fn.when(
+ snowpark_args[0].is_null(), snowpark_fn.lit(None)
+ ).otherwise(
+ snowpark_fn.cast(
+ map_values(snowpark_fn.cast(snowpark_args[0], StructType())),
+ ArrayType(arg_type.value_type),
+ )
+ )
+ result_type = ArrayType(arg_type.value_type)
  case "mask":

  number_of_args = len(snowpark_args)
@@ -5194,7 +5205,7 @@ def map_unresolved_function(
  result_type = DoubleType()
  case "raise_error":
  result_type = StringType()
- raise_error = _raise_error_udf_helper(result_type)
+ raise_error = _raise_error_helper(result_type)
  result_exp = raise_error(*snowpark_args)
  case "rand" | "random":
  # Snowpark random() generates a 64 bit signed integer, but pyspark is [0.0, 1.0).
@@ -5279,7 +5290,7 @@ def map_unresolved_function(
  snowpark_args[2],
  ),
  ),
- _raise_error_udf_helper(StringType())(
+ _raise_error_helper(StringType())(
  snowpark_fn.lit(
  "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract` is invalid."
  )
@@ -5329,7 +5340,7 @@ def map_unresolved_function(
  idx,
  )
  ),
- _raise_error_udf_helper(ArrayType(StringType()))(
+ _raise_error_helper(ArrayType(StringType()))(
  snowpark_fn.lit(
  "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract_all` is invalid."
  )
@@ -6043,7 +6054,7 @@ def map_unresolved_function(
  result_exp = snowpark_fn.skew(snowpark_fn.lit(None))
  result_type = DoubleType()
  case "slice":
- raise_error = _raise_error_udf_helper(snowpark_typed_args[0].typ)
+ raise_error = _raise_error_helper(snowpark_typed_args[0].typ)
  spark_index = snowpark_args[1]
  arr_size = snowpark_fn.array_size(snowpark_args[0])
  slice_len = snowpark_args[2]
@@ -6113,10 +6124,11 @@ def map_unresolved_function(
  result_exp = snowpark_fn.lit(0)
  result_type = LongType()
  case "split":
+ result_type = ArrayType(StringType())

  @cached_udf(
  input_types=[StringType(), StringType(), IntegerType()],
- return_type=ArrayType(StringType()),
+ return_type=result_type,
  )
  def _split(
  input: Optional[str], pattern: Optional[str], limit: Optional[int]
@@ -6124,34 +6136,80 @@ def map_unresolved_function(
  if input is None or pattern is None:
  return None

+ import re
+
+ try:
+ re.compile(pattern)
+ except re.error:
+ raise ValueError(
+ f"Failed to split string, provided pattern: {pattern} is invalid"
+ )
+
  if limit == 1:
  return [input]

- import re
+ if not input:
+ return []

  # A default of -1 is passed in PySpark, but RE needs it to be 0 to provide all splits.
  # In PySpark, the limit also indicates the max size of the resulting array, but in RE
  # the remainder is returned as another element.
  maxsplit = limit - 1 if limit > 0 else 0

- split_result = re.split(pattern, input, maxsplit)
  if len(pattern) == 0:
- # RE.split provides a first and last empty element that is not there in PySpark.
- split_result = split_result[1 : len(split_result) - 1]
+ return list(input) if limit <= 0 else list(input)[:limit]

- return split_result
+ match pattern:
+ case "|":
+ split_result = re.split(pattern, input, 0)
+ input_limit = limit + 1 if limit > 0 else len(split_result)
+ return (
+ split_result
+ if input_limit == 0
+ else split_result[1:input_limit]
+ )
+ case "$":
+ return [input, ""] if maxsplit >= 0 else [input]
+ case "^":
+ return [input]
+ case _:
+ return re.split(pattern, input, maxsplit)
+
+ def split_string(str_: Column, pattern: Column, limit: Column):
+ native_split = _split(str_, pattern, limit)
+ # When pattern is a literal and doesn't contain any regex special characters
+ # And when limit is less than or equal to 0
+ # Native Snowflake Split function is used to optimise performance
+ if isinstance(pattern._expression, Literal):
+ pattern_value = pattern._expression.value
+
+ if pattern_value is None:
+ return snowpark_fn.lit(None)
+
+ is_regexp = re.match(
+ ".*[\\[\\.\\]\\*\\?\\+\\^\\$\\{\\}\\|\\(\\)\\\\].*",
+ pattern_value,
+ )
+ is_empty = len(pattern_value) == 0
+
+ if not is_empty and not is_regexp:
+ return snowpark_fn.when(
+ limit <= 0,
+ snowpark_fn.split(str_, pattern).cast(result_type),
+ ).otherwise(native_split)
+
+ return native_split

  match snowpark_args:
  case [str_, pattern]:
  spark_function_name = (
  f"split({snowpark_arg_names[0]}, {snowpark_arg_names[1]}, -1)"
  )
- result_exp = _split(str_, pattern, snowpark_fn.lit(0))
+ result_exp = split_string(str_, pattern, snowpark_fn.lit(-1))
  case [str_, pattern, limit]: # noqa: F841
- result_exp = _split(str_, pattern, limit)
+ result_exp = split_string(str_, pattern, limit)
  case _:
  raise ValueError(f"Invalid number of arguments to {function_name}")
- result_type = ArrayType(StringType())
  case "split_part":
  result_exp = snowpark_fn.call_function("split_part", *snowpark_args)
  result_type = StringType()
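The rewritten `split` keeps the Python UDF for full Spark regex semantics (invalid patterns, empty pattern, the `|`, `$`, and `^` corner cases, limit handling) but, via `split_string`, short-circuits to Snowflake's native SPLIT when the pattern is a plain literal with no regex metacharacters and the limit is non-positive. A minimal illustrative sketch of that gate (`can_use_native_split` is not a name from this package; the real check is built as a Snowpark expression, with the limit compared inside WHEN):

    import re

    _REGEX_META = re.compile(r".*[\[\.\]\*\?\+\^\$\{\}\|\(\)\\].*")

    def can_use_native_split(pattern_value, limit):
        # Native SPLIT is only safe for a non-empty literal pattern with no regex
        # metacharacters, and only when limit <= 0 (Spark's default of -1).
        return (
            pattern_value is not None
            and len(pattern_value) > 0
            and _REGEX_META.match(pattern_value) is None
            and limit <= 0
        )

    print(can_use_native_split(",", -1))  # True  -> native Snowflake SPLIT
    print(can_use_native_split("|", -1))  # False -> Python UDF handles regex semantics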
@@ -6920,7 +6978,20 @@ def map_unresolved_function(
  result_type = StringType()
  case "trunc":
  part = unwrap_literal(exp.unresolved_function.arguments[1])
- if part is None:
+ part = None if part is None else part.lower()
+
+ allowed_parts = {
+ "year",
+ "yyyy",
+ "yy",
+ "month",
+ "mon",
+ "mm",
+ "week",
+ "quarter",
+ }
+
+ if part not in allowed_parts:
  result_exp = snowpark_fn.lit(None)
  else:
  result_exp = _try_to_cast(
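Spark's `trunc` only accepts year-, quarter-, month-, and week-level format strings (case-insensitively) and returns NULL for anything else, so the new allow-list filters the part before it ever reaches Snowflake. A small illustrative check (`normalized_trunc_part` is a sketch, not a helper from this package):

    ALLOWED_TRUNC_PARTS = {"year", "yyyy", "yy", "month", "mon", "mm", "week", "quarter"}

    def normalized_trunc_part(part):
        # Anything outside the allow-list (None, "dd", "hour", ...) makes trunc return NULL.
        part = None if part is None else part.lower()
        return part if part in ALLOWED_TRUNC_PARTS else None

    print(normalized_trunc_part("MM"))  # 'mm'  -> truncate to the first day of the month
    print(normalized_trunc_part("DD"))  # None  -> result is NULL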
@@ -7311,7 +7382,7 @@ def map_unresolved_function(
  )
  )
  )
- raise_fn = _raise_error_udf_helper(BinaryType())
+ raise_fn = _raise_error_helper(BinaryType(), IllegalArgumentException)
  result_exp = (
  snowpark_fn.when(unbase_arg.is_null(), snowpark_fn.lit(None))
  .when(result_exp.is_null(), raise_fn(snowpark_fn.lit("Invalid input")))
@@ -7927,16 +7998,8 @@ def _handle_current_timestamp():


  def _equivalent_decimal(type):
- match (type):
- case ByteType():
- return DecimalType(3, 0)
- case ShortType():
- return DecimalType(5, 0)
- case IntegerType():
- return DecimalType(10, 0)
- case LongType():
- return DecimalType(20, 0)
- return DecimalType(38, 0)
+ (precision, scale) = _get_type_precision(type)
+ return DecimalType(precision, scale)


  def _resolve_decimal_and_numeric(type1: DecimalType, type2: _NumericType) -> DataType:
@@ -8995,7 +9058,9 @@ def _get_type_precision(typ: DataType) -> tuple[int, int]:
  case IntegerType():
  return 10, 0 # -2147483648 to 2147483647
  case LongType():
- return 19, 0 # -9223372036854775808 to 9223372036854775807
+ return 20, 0 # -9223372036854775808 to 9223372036854775807
+ case NullType():
+ return 6, 2 # NULL
  case _:
  return 38, 0 # Default to maximum precision for other types

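Taken together with the `_equivalent_decimal` change above, the decimal equivalents now look roughly like the summary below. This is a hedged sketch, not the exhaustive mapping (`EXPECTED_PRECISION` is an illustrative name, and the ByteType/ShortType rows are inferred from the removed `_equivalent_decimal` branches); other types fall back to (38, 0).

    from snowflake.snowpark.types import ByteType, IntegerType, LongType, NullType, ShortType

    EXPECTED_PRECISION = {
        ByteType: (3, 0),
        ShortType: (5, 0),
        IntegerType: (10, 0),
        LongType: (20, 0),  # was (19, 0); now matches Spark's DecimalType(20, 0) for BIGINT
        NullType: (6, 2),   # new: default precision/scale for untyped NULLs
    }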
@@ -9210,16 +9275,12 @@ def _try_arithmetic_helper(
  typed_args[1].typ, DecimalType
  ):
  new_scale = s2
- new_precision = (
- p1 + s2 + 1
- ) # Integral precision + decimal scale + 1 for carry
+ new_precision = max(p2, p1 + s2)
  elif isinstance(typed_args[0].typ, DecimalType) and isinstance(
  typed_args[1].typ, _IntegralType
  ):
  new_scale = s1
- new_precision = (
- p2 + s1 + 1
- ) # Integral precision + decimal scale + 1 for carry
+ new_precision = max(p1, p2 + s1)
  else:
  # Both decimal types
  if operation_type == 1 and s1 == s2: # subtraction with matching scales
@@ -9298,13 +9359,13 @@ def _add_sub_precision_helper(
  typed_args[1].typ, DecimalType
  ):
  new_scale = s2
- new_precision = p1 + s2 + 1 # Integral precision + decimal scale + 1 for carry
+ new_precision = max(p2, p1 + s2)
  return_type_precision, return_type_scale = new_precision, new_scale
  elif isinstance(typed_args[0].typ, DecimalType) and isinstance(
  typed_args[1].typ, _IntegralType
  ):
  new_scale = s1
- new_precision = p2 + s1 + 1 # Integral precision + decimal scale + 1 for carry
+ new_precision = max(p1, p2 + s1)
  return_type_precision, return_type_scale = new_precision, new_scale
  else:
  (
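Worked example of the revised mixed decimal/integral rule shared by `_try_arithmetic_helper` and `_add_sub_precision_helper`: for DECIMAL(p, s) plus or minus an integral type of precision q, the scale stays s and the precision becomes max(p, q + s) instead of the previous q + s + 1 (`add_sub_result_type` below is an illustrative sketch).

    def add_sub_result_type(p, s, q):
        # DECIMAL(p, s) +/- integral with precision q (scale 0)
        return max(p, q + s), s  # previously q + s + 1

    print(add_sub_result_type(10, 2, 20))  # DECIMAL(10, 2) +/- BIGINT -> (22, 2), was (23, 2)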
@@ -9386,11 +9447,25 @@ def _mul_div_precision_helper(
  )


- def _raise_error_udf_helper(return_type: DataType):
- def _raise_error(message=None):
- raise ValueError(message)
+ def _raise_error_helper(return_type: DataType, error_class=None):
+ error_class = (
+ f":{error_class.__name__}"
+ if error_class and hasattr(error_class, "__name__")
+ else ""
+ )
+
+ def _raise_fn(*msgs: Column) -> Column:
+ return snowpark_fn.cast(
+ snowpark_fn.abs(
+ snowpark_fn.concat(
+ snowpark_fn.lit(f"[snowpark-connect-exception{error_class}]"),
+ *(msg.try_cast(StringType()) for msg in msgs),
+ )
+ ).cast(StringType()),
+ return_type,
+ )

- return cached_udf(_raise_error, return_type=return_type, input_types=[StringType()])
+ return _raise_fn


  def _divnull(dividend: Column, divisor: Column) -> Column:
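`_raise_error_udf_helper`, a Python UDF that raised ValueError, is replaced by a pure expression: the message is concatenated behind a `[snowpark-connect-exception:<ErrorClass>]` tag, and the ABS() wrapper apparently forces a numeric cast of the tagged string, which fails at execution time and surfaces the message, but only if that branch of the expression is actually evaluated, so nothing is registered or run on the happy path. The enlarged error_utils.py in this release presumably translates the tag back into the matching PySpark exception. A hedged usage sketch, using names from this diff (`bad_index` is an illustrative column):

    # The error only fires when this branch of a CASE expression is evaluated,
    # so it composes safely with WHEN/OTHERWISE without registering any UDF.
    raise_error = _raise_error_helper(
        StringType(), error_class=ArrayIndexOutOfBoundsException
    )
    error_branch = raise_error(
        snowpark_fn.lit("[INVALID_ARRAY_INDEX] The index "),
        snowpark_fn.cast(bad_index, StringType()),
        snowpark_fn.lit(" is out of bounds."),
    )
    # error_branch renders roughly as:
    #   CAST(ABS(CONCAT('[snowpark-connect-exception:ArrayIndexOutOfBoundsException]', ...)) AS <type>)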
@@ -0,0 +1,4 @@
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+ """Client and server classes corresponding to protobuf-defined services."""
+ import grpc
+
@@ -0,0 +1,4 @@
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+ """Client and server classes corresponding to protobuf-defined services."""
+ import grpc
+