snowpark-connect 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of snowpark-connect might be problematic.

Files changed (56)
  1. snowflake/snowpark_connect/config.py +19 -14
  2. snowflake/snowpark_connect/error/error_utils.py +32 -0
  3. snowflake/snowpark_connect/error/exceptions.py +4 -0
  4. snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
  5. snowflake/snowpark_connect/expression/literal.py +9 -12
  6. snowflake/snowpark_connect/expression/map_cast.py +20 -4
  7. snowflake/snowpark_connect/expression/map_expression.py +8 -1
  8. snowflake/snowpark_connect/expression/map_udf.py +4 -4
  9. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
  10. snowflake/snowpark_connect/expression/map_unresolved_function.py +269 -134
  11. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
  12. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
  13. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
  14. snowflake/snowpark_connect/relation/map_aggregate.py +154 -18
  15. snowflake/snowpark_connect/relation/map_column_ops.py +59 -8
  16. snowflake/snowpark_connect/relation/map_extension.py +58 -24
  17. snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
  18. snowflake/snowpark_connect/relation/map_map_partitions.py +3 -1
  19. snowflake/snowpark_connect/relation/map_row_ops.py +30 -1
  20. snowflake/snowpark_connect/relation/map_sql.py +40 -196
  21. snowflake/snowpark_connect/relation/map_udtf.py +4 -4
  22. snowflake/snowpark_connect/relation/read/map_read.py +2 -1
  23. snowflake/snowpark_connect/relation/read/map_read_json.py +12 -1
  24. snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
  25. snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
  26. snowflake/snowpark_connect/relation/read/utils.py +7 -6
  27. snowflake/snowpark_connect/relation/utils.py +170 -1
  28. snowflake/snowpark_connect/relation/write/map_write.py +306 -87
  29. snowflake/snowpark_connect/server.py +34 -5
  30. snowflake/snowpark_connect/type_mapping.py +6 -2
  31. snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
  32. snowflake/snowpark_connect/utils/env_utils.py +55 -0
  33. snowflake/snowpark_connect/utils/session.py +21 -4
  34. snowflake/snowpark_connect/utils/telemetry.py +213 -61
  35. snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
  36. snowflake/snowpark_connect/version.py +1 -1
  37. snowflake/snowpark_decoder/__init__.py +0 -0
  38. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
  39. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
  40. snowflake/snowpark_decoder/dp_session.py +111 -0
  41. snowflake/snowpark_decoder/spark_decoder.py +76 -0
  42. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/METADATA +2 -2
  43. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/RECORD +55 -44
  44. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/top_level.txt +1 -0
  45. spark/__init__.py +0 -0
  46. spark/connect/__init__.py +0 -0
  47. spark/connect/envelope_pb2.py +31 -0
  48. spark/connect/envelope_pb2.pyi +46 -0
  49. snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
  50. {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-connect +0 -0
  51. {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-session +0 -0
  52. {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-submit +0 -0
  53. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/WHEEL +0 -0
  54. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE-binary +0 -0
  55. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE.txt +0 -0
  56. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/NOTICE-binary +0 -0
@@ -28,6 +28,7 @@ from google.protobuf.message import Message
 from pyspark.errors.exceptions.base import (
     AnalysisException,
     ArithmeticException,
+    ArrayIndexOutOfBoundsException,
     DateTimeException,
     IllegalArgumentException,
     NumberFormatException,
@@ -39,6 +40,7 @@ from pyspark.sql.types import _parse_datatype_json_string
 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
 from snowflake.snowpark import Column, Session
+from snowflake.snowpark._internal.analyzer.expression import Literal
 from snowflake.snowpark._internal.analyzer.unary_expression import Alias
 from snowflake.snowpark.types import (
     ArrayType,
@@ -139,7 +141,7 @@ from snowflake.snowpark_connect.utils.xxhash64 import (
 MAX_UINT64 = 2**64 - 1
 MAX_INT64 = 2**63 - 1
 MIN_INT64 = -(2**63)
-
+MAX_ARRAY_SIZE = 2_147_483_647

 NAN, INFINITY = float("nan"), float("inf")

@@ -638,37 +640,22 @@ def map_unresolved_function(
                 [arg.typ for arg in snowpark_typed_args]
             )
         case "/":
-            if isinstance(
-                snowpark_typed_args[0].typ, (IntegerType, LongType, ShortType)
-            ) and isinstance(
-                snowpark_typed_args[1].typ, (IntegerType, LongType, ShortType)
-            ):
-                # Check if both arguments are integer types. Snowpark performs integer division, and precision is lost.
-                # Cast to double and perform division
-                result_exp = _divnull(
-                    snowpark_args[0].cast(DoubleType()),
-                    snowpark_args[1].cast(DoubleType()),
-                )
-                result_type = DoubleType()
-            elif (
-                isinstance(snowpark_typed_args[0].typ, DecimalType)
-                and isinstance(snowpark_typed_args[1].typ, DecimalType)
-                or isinstance(snowpark_typed_args[0].typ, DecimalType)
-                and isinstance(snowpark_typed_args[1].typ, _IntegralType)
-                or isinstance(snowpark_typed_args[0].typ, _IntegralType)
-                and isinstance(snowpark_typed_args[1].typ, DecimalType)
-            ):
-                result_exp, (
-                    return_type_precision,
-                    return_type_scale,
-                ) = _mul_div_precision_helper(snowpark_typed_args, snowpark_args, 1)
-                result_type = DecimalType(return_type_precision, return_type_scale)
-            else:
-                # Perform division directly
-                result_exp = _divnull(snowpark_args[0], snowpark_args[1])
-                result_type = _find_common_type(
-                    [arg.typ for arg in snowpark_typed_args]
-                )
+            match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (DecimalType(), t) | (t, DecimalType()) if isinstance(
+                    t, DecimalType
+                ) or isinstance(t, _IntegralType) or isinstance(
+                    snowpark_typed_args[1].typ, NullType
+                ):
+                    result_exp, (
+                        return_type_precision,
+                        return_type_scale,
+                    ) = _mul_div_precision_helper(snowpark_typed_args, snowpark_args, 1)
+                    result_type = DecimalType(return_type_precision, return_type_scale)
+                case _:
+                    result_type = DoubleType()
+                    dividend = snowpark_args[0].cast(result_type)
+                    divisor = snowpark_args[1].cast(result_type)
+                    result_exp = _divnull(dividend, divisor)
         case "~":
             result_exp = TypedColumn(
                 snowpark_fn.bitnot(snowpark_args[0]),
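
Note on the `/` change above: the old isinstance chain is collapsed into a structural match that follows Spark SQL's division typing, where anything other than a decimal combination is evaluated as double, so integer inputs no longer risk truncating division. A minimal PySpark illustration of the target behavior (illustrative only, not part of this package):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.sql("SELECT 3 / 2 AS int_div, CAST(1.10 AS DECIMAL(3, 2)) / 2 AS dec_div")
df.printSchema()
df.show()
# int_div is DOUBLE and equals 1.5 (integer operands are widened, not truncated);
# dec_div stays DECIMAL, with precision/scale derived from the operand types
```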
@@ -1120,7 +1107,7 @@ def map_unresolved_function(
             result_exp = TypedColumn(
                 result_exp, lambda: [ArrayType(snowpark_typed_args[0].typ)]
             )
-        case "array_size" | "cardinality":
+        case "array_size":
             array_type = snowpark_typed_args[0].typ
             if not isinstance(array_type, ArrayType):
                 raise AnalysisException(
@@ -1129,6 +1116,16 @@ def map_unresolved_function(
             result_exp = TypedColumn(
                 snowpark_fn.array_size(*snowpark_args), lambda: [LongType()]
             )
+        case "cardinality":
+            arg_type = snowpark_typed_args[0].typ
+            if isinstance(arg_type, (ArrayType, MapType)):
+                result_exp = TypedColumn(
+                    snowpark_fn.size(*snowpark_args), lambda: [LongType()]
+                )
+            else:
+                raise AnalysisException(
+                    f"Expected argument '{snowpark_arg_names[0]}' to have an ArrayType or MapType, but got {arg_type.simpleString()}."
+                )
         case "array_sort":
             result_exp = TypedColumn(
                 snowpark_fn.array_sort(*snowpark_args),
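
The new `cardinality` branch above accepts both arrays and maps (delegating to Snowpark's `size`), while `array_size` stays array-only, which is how Spark distinguishes the two. Illustrative PySpark (not part of this package):

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [([1, 2, 3], {"a": 1, "b": 2})], "arr array<int>, m map<string,int>"
)

df.select(F.expr("cardinality(arr)"), F.expr("cardinality(m)")).show()
# -> 3 and 2; array_size(m) would instead fail analysis because m is not an array
```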
@@ -1201,35 +1198,18 @@ def map_unresolved_function(
                 snowpark_fn.asinh(snowpark_args[0]), lambda: [DoubleType()]
             )
         case "assert_true":
+            result_type = NullType()
+            raise_error = _raise_error_helper(result_type)

-            @cached_udf(
-                input_types=[BooleanType()],
-                return_type=StringType(),
-            )
-            def _assert_true_single(expr):
-                if not expr:
-                    raise ValueError("assertion failed")
-                return None
-
-            @cached_udf(
-                input_types=[BooleanType(), StringType()],
-                return_type=StringType(),
-            )
-            def _assert_true_with_message(expr, message):
-                if not expr:
-                    raise ValueError(message)
-                return None
-
-            # Handle different argument counts using match pattern
             match snowpark_args:
                 case [expr]:
-                    result_exp = TypedColumn(
-                        _assert_true_single(expr), lambda: [StringType()]
-                    )
+                    result_exp = snowpark_fn.when(
+                        expr, snowpark_fn.lit(None)
+                    ).otherwise(raise_error(snowpark_fn.lit("assertion failed")))
                 case [expr, message]:
-                    result_exp = TypedColumn(
-                        _assert_true_with_message(expr, message), lambda: [StringType()]
-                    )
+                    result_exp = snowpark_fn.when(
+                        expr, snowpark_fn.lit(None)
+                    ).otherwise(raise_error(snowpark_fn.cast(message, StringType())))
                 case _:
                     raise AnalysisException(
                         f"[WRONG_NUM_ARGS.WITHOUT_SUGGESTION] The `assert_true` requires 1 or 2 parameters but the actual number is {len(snowpark_args)}."
@@ -1325,10 +1305,35 @@ def map_unresolved_function(
             )
             result_exp = TypedColumn(result_exp, lambda: [LongType()])
         case "bit_get" | "getbit":
-            bit_get_function = snowpark_fn.function("GETBIT")
-            result_exp = TypedColumn(
-                bit_get_function(*snowpark_args), lambda: [LongType()]
+            snowflake_compat = get_boolean_session_config_param(
+                "enable_snowflake_extension_behavior"
             )
+            col, pos = snowpark_args
+            if snowflake_compat:
+                bit_get_function = snowpark_fn.function("GETBIT")(col, pos)
+            else:
+                raise_error = _raise_error_helper(LongType())
+                bit_get_function = snowpark_fn.when(
+                    (snowpark_fn.lit(0) <= pos) & (pos <= snowpark_fn.lit(63))
+                    | snowpark_fn.is_null(pos),
+                    snowpark_fn.function("GETBIT")(col, pos),
+                ).otherwise(
+                    raise_error(
+                        snowpark_fn.concat(
+                            snowpark_fn.lit(
+                                "Invalid bit position: ",
+                            ),
+                            snowpark_fn.cast(
+                                pos,
+                                StringType(),
+                            ),
+                            snowpark_fn.lit(
+                                " exceeds the bit upper limit",
+                            ),
+                        )
+                    )
+                )
+            result_exp = TypedColumn(bit_get_function, lambda: [LongType()])
         case "bit_length":
             bit_length_function = snowpark_fn.function("bit_length")
             result_exp = TypedColumn(
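
Outside the `enable_snowflake_extension_behavior` mode, `bit_get`/`getbit` now checks that the position is NULL or within 0..63 before calling GETBIT and otherwise raises through `_raise_error_helper`, mirroring Spark's "Invalid bit position" error. Illustrative Spark SQL (not part of this package):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

spark.sql("SELECT getbit(11, 0) AS b0, getbit(11, 2) AS b2").show()
# 11 is 0b1011, so b0 = 1 and b2 = 0
# SELECT getbit(11, 64) fails with an "Invalid bit position ..." error
# instead of silently returning a value
```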
@@ -2291,31 +2296,32 @@ def map_unresolved_function(
             )
         case "elt":
             n = snowpark_args[0]
-
             values = snowpark_fn.array_construct(*snowpark_args[1:])

             if spark_sql_ansi_enabled:
-
-                @cached_udf(
-                    input_types=[IntegerType()],
-                    return_type=StringType(),
+                raise_error = _raise_error_helper(
+                    StringType(), error_class=ArrayIndexOutOfBoundsException
                 )
-                def _raise_out_of_bounds_error(n: int) -> str:
-                    raise ValueError(
-                        f"ArrayIndexOutOfBoundsException: {n} is not within the input bounds."
-                    )
-
                 values_size = snowpark_fn.lit(len(snowpark_args) - 1)

                 result_exp = (
                     snowpark_fn.when(snowpark_fn.is_null(n), snowpark_fn.lit(None))
                     .when(
                         (snowpark_fn.lit(1) <= n) & (n <= values_size),
-                        snowpark_fn.get(
-                            values, snowpark_fn.nvl(n - 1, snowpark_fn.lit(0))
+                        snowpark_fn.cast(
+                            snowpark_fn.get(
+                                values, snowpark_fn.nvl(n - 1, snowpark_fn.lit(0))
+                            ),
+                            StringType(),
                         ),
                     )
-                    .otherwise(_raise_out_of_bounds_error(n))
+                    .otherwise(
+                        raise_error(
+                            snowpark_fn.lit("[INVALID_ARRAY_INDEX] The index "),
+                            snowpark_fn.cast(n, StringType()),
+                            snowpark_fn.lit(" is out of bounds."),
+                        )
+                    )
                 )
             else:
                 result_exp = snowpark_fn.when(
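
Under `spark.sql.ansi.enabled`, an out-of-range `elt` index now raises an ArrayIndexOutOfBoundsException-classed error built from `_raise_error_helper`; with ANSI off, the result is NULL, as in Spark. Illustrative (not part of this package):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT elt(4, 'a', 'b', 'c')").show()   # index out of range -> NULL

spark.conf.set("spark.sql.ansi.enabled", "true")
# the same query now fails with an [INVALID_ARRAY_INDEX] /
# ArrayIndexOutOfBoundsException-style error
```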
@@ -3289,7 +3295,7 @@ def map_unresolved_function(
             ).cast(LongType())
             result_type = LongType()
         case "hll_union_agg":
-            raise_error = _raise_error_udf_helper(BinaryType())
+            raise_error = _raise_error_helper(BinaryType())
             args = exp.unresolved_function.arguments
             allow_different_lgConfigK = len(args) == 2 and unwrap_literal(args[1])
             spark_function_name = f"{function_name}({snowpark_arg_names[0]}, {str(allow_different_lgConfigK).lower()})"
@@ -3329,7 +3335,7 @@ def map_unresolved_function(
                 SELECT arg1 as x)
                 """,
             )
-            raise_error = _raise_error_udf_helper(BinaryType())
+            raise_error = _raise_error_helper(BinaryType())
            args = exp.unresolved_function.arguments
            allow_different_lgConfigK = len(args) == 3 and unwrap_literal(args[2])
            spark_function_name = f"{function_name}({snowpark_arg_names[0]}, {snowpark_arg_names[1]}, {str(allow_different_lgConfigK).lower()})"
@@ -3816,7 +3822,13 @@ def map_unresolved_function(
         case "locate":
             substr = unwrap_literal(exp.unresolved_function.arguments[0])
             value = snowpark_args[1]
-            start_pos = unwrap_literal(exp.unresolved_function.arguments[2])
+            if len(exp.unresolved_function.arguments) == 3:
+                start_pos = unwrap_literal(exp.unresolved_function.arguments[2])
+            else:
+                # start_pos is an optional argument and if not provided we should default to 1.
+                # This path will only be reached by spark connect scala clients.
+                start_pos = 1
+                spark_function_name = f"locate({', '.join(snowpark_arg_names)}, 1)"

             if start_pos > 0:
                 result_exp = snowpark_fn.locate(substr, value, start_pos)
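
Spark's `locate(substr, str[, pos])` treats the start position as optional with a default of 1, and the Scala Spark Connect client omits it, so the mapping now supplies the default itself. Illustrative PySpark (not part of this package):

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("barbar",)], "s string")

df.select(F.locate("bar", "s"), F.locate("bar", "s", 2)).show()
# -> 1 (default start position 1) and 4 (search starting at position 2)
```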
@@ -4053,7 +4065,7 @@ def map_unresolved_function(
                     snowpark_fn.is_null(snowpark_args[i]),
                     # udf execution on XP seems to be lazy, so this should only run when there is a null key
                     # otherwise there should be no udf env setup or execution
-                    _raise_error_udf_helper(VariantType())(
+                    _raise_error_helper(VariantType())(
                         snowpark_fn.lit(
                             "[NULL_MAP_KEY] Cannot use null as map key."
                         )
@@ -4115,6 +4127,14 @@ def map_unresolved_function(
             )
             result_type = MapType(key_type, value_type)
         case "map_contains_key":
+            if isinstance(snowpark_typed_args[0].typ, NullType):
+                raise AnalysisException(
+                    f"""[DATATYPE_MISMATCH.MAP_FUNCTION_DIFF_TYPES] Cannot resolve "map_contains_key({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: Input to `map_contains_key` should have been "MAP" followed by a value with same key type, but it's ["VOID", "INT"]."""
+                )
+            if isinstance(snowpark_typed_args[1].typ, NullType):
+                raise AnalysisException(
+                    f"""[DATATYPE_MISMATCH.NULL_TYPE] Cannot resolve "map_contains_key({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: Null typed values cannot be used as arguments of `map_contains_key`."""
+                )
             args = (
                 [snowpark_args[1], snowpark_args[0]]
                 if isinstance(snowpark_typed_args[0].typ, MapType)
@@ -4244,17 +4264,37 @@ def map_unresolved_function(

             last_win_dedup = global_config.spark_sql_mapKeyDedupPolicy == "LAST_WIN"

-            result_exp = snowpark_fn.cast(
-                snowpark_fn.function("reduce")(
-                    snowpark_args[0],
-                    snowpark_fn.object_construct(),
-                    snowpark_fn.sql_expr(
-                        # value_field is cast to variant because object_insert doesn't allow structured types,
-                        # and structured types are not coercible to variant
-                        # TODO: allow structured types in object_insert?
-                        f"(acc, e) -> object_insert(acc, e:{key_field}, e:{value_field}::variant, {last_win_dedup})"
-                    ),
+            # Check if any entry has a NULL key
+            has_null_key = (
+                snowpark_fn.function("array_size")(
+                    snowpark_fn.function("filter")(
+                        snowpark_args[0],
+                        snowpark_fn.sql_expr(f"e -> e:{key_field} IS NULL"),
+                    )
+                )
+                > 0
+            )
+
+            # Create error UDF for NULL keys (same pattern as map function)
+            null_key_error = _raise_error_helper(VariantType())(
+                snowpark_fn.lit("[NULL_MAP_KEY] Cannot use null as map key.")
+            )
+
+            # Create the reduce operation
+            reduce_result = snowpark_fn.function("reduce")(
+                snowpark_args[0],
+                snowpark_fn.object_construct(),
+                snowpark_fn.sql_expr(
+                    # value_field is cast to variant because object_insert doesn't allow structured types,
+                    # and structured types are not coercible to variant
+                    # TODO: allow structured types in object_insert?
+                    f"(acc, e) -> object_insert(acc, e:{key_field}, e:{value_field}::variant, {last_win_dedup})"
                 ),
+            )
+
+            # Use conditional logic: if there are NULL keys, throw error; otherwise proceed with reduce
+            result_exp = snowpark_fn.cast(
+                snowpark_fn.when(has_null_key, null_key_error).otherwise(reduce_result),
                 MapType(key_type, value_type),
             )
             result_type = MapType(key_type, value_type)
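
The entries-to-map reduction above (the path that appears to back `map_from_entries`-style construction) now pre-scans the entry array with FILTER/ARRAY_SIZE and raises Spark's NULL_MAP_KEY error when any key is NULL before running the reduce. Illustrative Spark SQL (not part of this package):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

spark.sql(
    "SELECT map_from_entries(array(struct('a', 1), struct('b', 2))) AS m"
).show(truncate=False)
# -> {a -> 1, b -> 2}
# map_from_entries(array(struct(CAST(NULL AS STRING), 1))) instead fails with
# "Cannot use null as map key" ([NULL_MAP_KEY])
```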
@@ -4273,23 +4313,35 @@ def map_unresolved_function(
             # TODO: implement in Snowflake/Snowpark
             # technically this could be done with a lateral join, but it's probably not worth the effort
             arg_type = snowpark_typed_args[0].typ
-            if not isinstance(arg_type, MapType):
+            if not isinstance(arg_type, (MapType, NullType)):
                 raise AnalysisException(
                     f"map_values requires a MapType argument, got {arg_type}"
                 )

             def _map_values(obj: dict) -> list:
-                return list(obj.values()) if obj else None
+                if obj is None:
+                    return None
+                return list(obj.values())

             map_values = cached_udf(
                 _map_values, return_type=ArrayType(), input_types=[StructType()]
             )

-            result_exp = snowpark_fn.cast(
-                map_values(snowpark_fn.cast(snowpark_args[0], StructType())),
-                ArrayType(arg_type.value_type),
-            )
-            result_type = ArrayType(arg_type.value_type)
+            # Handle NULL input directly at expression level
+            if isinstance(arg_type, NullType):
+                # If input is NULL literal, return NULL
+                result_exp = snowpark_fn.lit(None)
+                result_type = ArrayType(NullType())
+            else:
+                result_exp = snowpark_fn.when(
+                    snowpark_args[0].is_null(), snowpark_fn.lit(None)
+                ).otherwise(
+                    snowpark_fn.cast(
+                        map_values(snowpark_fn.cast(snowpark_args[0], StructType())),
+                        ArrayType(arg_type.value_type),
+                    )
+                )
+                result_type = ArrayType(arg_type.value_type)
         case "mask":

             number_of_args = len(snowpark_args)
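
`map_values` now handles NULL at the expression level (a VOID-typed argument or a NULL map yields NULL), and the UDF body distinguishes None from an empty map, so an empty map can produce `[]` rather than NULL. Illustrative Spark SQL (not part of this package):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

spark.sql(
    "SELECT map_values(map('a', 1, 'b', 2)) AS vals, "
    "map_values(map()) AS empty, "
    "map_values(CAST(NULL AS MAP<STRING, INT>)) AS missing"
).show()
# -> [1, 2], [] and NULL respectively
```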
@@ -5194,7 +5246,7 @@ def map_unresolved_function(
             result_type = DoubleType()
         case "raise_error":
             result_type = StringType()
-            raise_error = _raise_error_udf_helper(result_type)
+            raise_error = _raise_error_helper(result_type)
             result_exp = raise_error(*snowpark_args)
         case "rand" | "random":
             # Snowpark random() generates a 64 bit signed integer, but pyspark is [0.0, 1.0).
@@ -5279,7 +5331,7 @@ def map_unresolved_function(
                         snowpark_args[2],
                     ),
                 ),
-                _raise_error_udf_helper(StringType())(
+                _raise_error_helper(StringType())(
                     snowpark_fn.lit(
                         "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract` is invalid."
                     )
@@ -5329,7 +5381,7 @@ def map_unresolved_function(
                         idx,
                     )
                 ),
-                _raise_error_udf_helper(ArrayType(StringType()))(
+                _raise_error_helper(ArrayType(StringType()))(
                     snowpark_fn.lit(
                         "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract_all` is invalid."
                     )
@@ -5485,9 +5537,27 @@ def map_unresolved_function(
             ):
                 result_exp = snowpark_fn.lit(None)
             else:
+                right_expr = snowpark_fn.right(*snowpark_args)
+                if isinstance(snowpark_typed_args[0].typ, TimestampType):
+                    # Spark format is always displayed as YYY-MM-DD HH:mm:ss.FF6
+                    # When microseconds are equal to 0 .FF6 part is removed
+                    # When microseconds are equal to 0 at the end, they are removed i.e. .123000 -> .123 when displayed
+
+                    formated_timestamp = snowpark_fn.to_varchar(
+                        snowpark_args[0], "YYYY-MM-DD HH:MI:SS.FF6"
+                    )
+                    right_expr = snowpark_fn.right(
+                        snowpark_fn.regexp_replace(
+                            snowpark_fn.regexp_replace(formated_timestamp, "0+$", ""),
+                            "\\.$",
+                            "",
+                        ),
+                        snowpark_args[1],
+                    )
+
                 result_exp = snowpark_fn.when(
                     snowpark_args[1] <= 0, snowpark_fn.lit("")
-                ).otherwise(snowpark_fn.right(*snowpark_args))
+                ).otherwise(right_expr)
             result_type = StringType()
         case "rint":
             result_exp = snowpark_fn.cast(
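
For a timestamp argument, Spark's string form drops trailing zero microseconds (`.123000` prints as `.123`), so `right` now formats with `FF6` and strips trailing zeros (and any dangling dot) before slicing. Roughly, the Spark-side behavior being matched (illustrative, not part of this package):

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.sql("SELECT TIMESTAMP '2024-05-01 10:20:30.123000' AS ts")

df.select(F.expr("right(CAST(ts AS STRING), 6)")).show()
# Spark renders the timestamp as '2024-05-01 10:20:30.123' (trailing zeros dropped),
# so the result is '30.123' rather than a slice of '...123000'
```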
@@ -6043,7 +6113,7 @@ def map_unresolved_function(
             result_exp = snowpark_fn.skew(snowpark_fn.lit(None))
             result_type = DoubleType()
         case "slice":
-            raise_error = _raise_error_udf_helper(snowpark_typed_args[0].typ)
+            raise_error = _raise_error_helper(snowpark_typed_args[0].typ)
             spark_index = snowpark_args[1]
             arr_size = snowpark_fn.array_size(snowpark_args[0])
             slice_len = snowpark_args[2]
@@ -6113,10 +6183,11 @@ def map_unresolved_function(
             result_exp = snowpark_fn.lit(0)
             result_type = LongType()
         case "split":
+            result_type = ArrayType(StringType())

             @cached_udf(
                 input_types=[StringType(), StringType(), IntegerType()],
-                return_type=ArrayType(StringType()),
+                return_type=result_type,
             )
             def _split(
                 input: Optional[str], pattern: Optional[str], limit: Optional[int]
@@ -6124,34 +6195,80 @@ def map_unresolved_function(
                 if input is None or pattern is None:
                     return None

+                import re
+
+                try:
+                    re.compile(pattern)
+                except re.error:
+                    raise ValueError(
+                        f"Failed to split string, provided pattern: {pattern} is invalid"
+                    )
+
                 if limit == 1:
                     return [input]

-                import re
+                if not input:
+                    return []

                 # A default of -1 is passed in PySpark, but RE needs it to be 0 to provide all splits.
                 # In PySpark, the limit also indicates the max size of the resulting array, but in RE
                 # the remainder is returned as another element.
                 maxsplit = limit - 1 if limit > 0 else 0

-                split_result = re.split(pattern, input, maxsplit)
                 if len(pattern) == 0:
-                    # RE.split provides a first and last empty element that is not there in PySpark.
-                    split_result = split_result[1 : len(split_result) - 1]
+                    return list(input) if limit <= 0 else list(input)[:limit]
+
+                match pattern:
+                    case "|":
+                        split_result = re.split(pattern, input, 0)
+                        input_limit = limit + 1 if limit > 0 else len(split_result)
+                        return (
+                            split_result
+                            if input_limit == 0
+                            else split_result[1:input_limit]
+                        )
+                    case "$":
+                        return [input, ""] if maxsplit >= 0 else [input]
+                    case "^":
+                        return [input]
+                    case _:
+                        return re.split(pattern, input, maxsplit)

-                return split_result
+            def split_string(str_: Column, pattern: Column, limit: Column):
+                native_split = _split(str_, pattern, limit)
+                # When pattern is a literal and doesn't contain any regex special characters
+                # And when limit is less than or equal to 0
+                # Native Snowflake Split function is used to optimise performance
+                if isinstance(pattern._expression, Literal):
+                    pattern_value = pattern._expression.value
+
+                    if pattern_value is None:
+                        return snowpark_fn.lit(None)
+
+                    is_regexp = re.match(
+                        ".*[\\[\\.\\]\\*\\?\\+\\^\\$\\{\\}\\|\\(\\)\\\\].*",
+                        pattern_value,
+                    )
+                    is_empty = len(pattern_value) == 0
+
+                    if not is_empty and not is_regexp:
+                        return snowpark_fn.when(
+                            limit <= 0,
+                            snowpark_fn.split(str_, pattern).cast(result_type),
+                        ).otherwise(native_split)
+
+                return native_split

             match snowpark_args:
                 case [str_, pattern]:
                     spark_function_name = (
                         f"split({snowpark_arg_names[0]}, {snowpark_arg_names[1]}, -1)"
                     )
-                    result_exp = _split(str_, pattern, snowpark_fn.lit(0))
+                    result_exp = split_string(str_, pattern, snowpark_fn.lit(-1))
                 case [str_, pattern, limit]:  # noqa: F841
-                    result_exp = _split(str_, pattern, limit)
+                    result_exp = split_string(str_, pattern, limit)
                 case _:
                     raise ValueError(f"Invalid number of arguments to {function_name}")
-            result_type = ArrayType(StringType())
         case "split_part":
             result_exp = snowpark_fn.call_function("split_part", *snowpark_args)
             result_type = StringType()
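
The reworked `split` mapping reproduces Spark's edge cases (empty pattern, `|`, `^`, `$`, explicit limits) inside the UDF and, when the pattern is a plain literal with no regex metacharacters and the limit is non-positive, falls back to Snowflake's native SPLIT for performance. The Spark behavior it targets, roughly (illustrative, not part of this package):

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("oneAtwoBthree",)], "s string")

df.select(
    F.split("s", "[AB]"),     # -> ['one', 'two', 'three']
    F.split("s", "[AB]", 2),  # limit = 2 -> ['one', 'twoBthree']
    F.split("s", ""),         # empty pattern -> one element per character
).show(truncate=False)
```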
@@ -6671,6 +6788,7 @@ def map_unresolved_function(
                         if value == "" or any(
                             c in value for c in [",", "\n", "\r", '"', "'"]
                         ):
+                            value = value.replace("\\", "\\\\").replace('"', '\\"')
                             result.append(f'"{value}"')
                         else:
                             result.append(value)
@@ -6920,7 +7038,20 @@ def map_unresolved_function(
             result_type = StringType()
         case "trunc":
             part = unwrap_literal(exp.unresolved_function.arguments[1])
-            if part is None:
+            part = None if part is None else part.lower()
+
+            allowed_parts = {
+                "year",
+                "yyyy",
+                "yy",
+                "month",
+                "mon",
+                "mm",
+                "week",
+                "quarter",
+            }
+
+            if part not in allowed_parts:
                 result_exp = snowpark_fn.lit(None)
             else:
                 result_exp = _try_to_cast(
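
`trunc` now lower-cases the unit and returns NULL for anything outside Spark's supported set (year/yyyy/yy, month/mon/mm, week, quarter), rather than attempting the cast; Spark likewise returns NULL for an unrecognized format. Illustrative (not part of this package):

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("2019-08-14",)], "d string")

df.select(
    F.trunc("d", "MM"),       # -> 2019-08-01
    F.trunc("d", "quarter"),  # -> 2019-07-01
    F.trunc("d", "day"),      # unsupported unit for trunc -> NULL
).show()
```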
@@ -7311,7 +7442,7 @@ def map_unresolved_function(
                     )
                 )
             )
-            raise_fn = _raise_error_udf_helper(BinaryType())
+            raise_fn = _raise_error_helper(BinaryType(), IllegalArgumentException)
             result_exp = (
                 snowpark_fn.when(unbase_arg.is_null(), snowpark_fn.lit(None))
                 .when(result_exp.is_null(), raise_fn(snowpark_fn.lit("Invalid input")))
@@ -7927,16 +8058,8 @@ def _handle_current_timestamp():


 def _equivalent_decimal(type):
-    match (type):
-        case ByteType():
-            return DecimalType(3, 0)
-        case ShortType():
-            return DecimalType(5, 0)
-        case IntegerType():
-            return DecimalType(10, 0)
-        case LongType():
-            return DecimalType(20, 0)
-    return DecimalType(38, 0)
+    (precision, scale) = _get_type_precision(type)
+    return DecimalType(precision, scale)


 def _resolve_decimal_and_numeric(type1: DecimalType, type2: _NumericType) -> DataType:
@@ -8995,7 +9118,9 @@ def _get_type_precision(typ: DataType) -> tuple[int, int]:
         case IntegerType():
             return 10, 0  # -2147483648 to 2147483647
         case LongType():
-            return 19, 0  # -9223372036854775808 to 9223372036854775807
+            return 20, 0  # -9223372036854775808 to 9223372036854775807
+        case NullType():
+            return 6, 2  # NULL
         case _:
             return 38, 0  # Default to maximum precision for other types

@@ -9210,16 +9335,12 @@ def _try_arithmetic_helper(
         typed_args[1].typ, DecimalType
     ):
         new_scale = s2
-        new_precision = (
-            p1 + s2 + 1
-        )  # Integral precision + decimal scale + 1 for carry
+        new_precision = max(p2, p1 + s2)
     elif isinstance(typed_args[0].typ, DecimalType) and isinstance(
         typed_args[1].typ, _IntegralType
     ):
         new_scale = s1
-        new_precision = (
-            p2 + s1 + 1
-        )  # Integral precision + decimal scale + 1 for carry
+        new_precision = max(p1, p2 + s1)
     else:
         # Both decimal types
         if operation_type == 1 and s1 == s2:  # subtraction with matching scales
@@ -9298,13 +9419,13 @@ def _add_sub_precision_helper(
         typed_args[1].typ, DecimalType
     ):
         new_scale = s2
-        new_precision = p1 + s2 + 1  # Integral precision + decimal scale + 1 for carry
+        new_precision = max(p2, p1 + s2)
         return_type_precision, return_type_scale = new_precision, new_scale
     elif isinstance(typed_args[0].typ, DecimalType) and isinstance(
         typed_args[1].typ, _IntegralType
     ):
         new_scale = s1
-        new_precision = p2 + s1 + 1  # Integral precision + decimal scale + 1 for carry
+        new_precision = max(p1, p2 + s1)
         return_type_precision, return_type_scale = new_precision, new_scale
     else:
         (
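
Worked example of the revised add/subtract rule above: combining a DECIMAL(12, 2) with an INT (precision 10, scale 0 per `_get_type_precision`) keeps scale 2 and now gets precision max(12, 10 + 2) = 12, where the previous `p + s + 1` rule produced 13. The same `max(p_dec, p_other + s_dec)` form is applied symmetrically in both `_try_arithmetic_helper` and `_add_sub_precision_helper`.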
@@ -9386,11 +9507,25 @@ def _mul_div_precision_helper(
     )


-def _raise_error_udf_helper(return_type: DataType):
-    def _raise_error(message=None):
-        raise ValueError(message)
+def _raise_error_helper(return_type: DataType, error_class=None):
+    error_class = (
+        f":{error_class.__name__}"
+        if error_class and hasattr(error_class, "__name__")
+        else ""
+    )
+
+    def _raise_fn(*msgs: Column) -> Column:
+        return snowpark_fn.cast(
+            snowpark_fn.abs(
+                snowpark_fn.concat(
+                    snowpark_fn.lit(f"[snowpark-connect-exception{error_class}]"),
+                    *(msg.try_cast(StringType()) for msg in msgs),
+                )
+            ).cast(StringType()),
+            return_type,
+        )

-    return cached_udf(_raise_error, return_type=return_type, input_types=[StringType()])
+    return _raise_fn


 def _divnull(dividend: Column, divisor: Column) -> Column:
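
For context on the helper that most hunks above now call: `_raise_error_helper` replaces the cached error UDF with a pure column expression. It concatenates a `[snowpark-connect-exception:<ErrorClass>]` sentinel with the message columns and wraps the result in `abs(...)`, which presumably only fails once the branch is actually evaluated (the string is not numeric), so the expanded error/error_utils.py in this release can translate the sentinel back into the matching PySpark exception class. A sketch of the call-site pattern seen throughout the diff; the condition and value columns below are placeholders, not names from the package:

```python
# Sketch only: mirrors the when/otherwise pattern used by assert_true, bit_get, elt, etc.
raise_error = _raise_error_helper(LongType(), IllegalArgumentException)  # column-producing helper

result_exp = snowpark_fn.when(
    some_valid_condition,   # placeholder condition column
    some_result_column,     # placeholder value column
).otherwise(
    raise_error(
        snowpark_fn.lit("error message: "),
        snowpark_fn.cast(some_offending_value, StringType()),  # placeholder column
    )
)
```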