snowpark-connect 1.6.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. snowflake/snowpark_connect/client/server.py +37 -0
  2. snowflake/snowpark_connect/config.py +72 -3
  3. snowflake/snowpark_connect/expression/error_utils.py +28 -0
  4. snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
  5. snowflake/snowpark_connect/expression/map_cast.py +108 -17
  6. snowflake/snowpark_connect/expression/map_udf.py +1 -0
  7. snowflake/snowpark_connect/expression/map_unresolved_function.py +229 -96
  8. snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
  9. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.2.0.jar +0 -0
  10. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
  11. snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
  12. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
  13. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.13-3.5.6.jar +0 -0
  14. snowflake/snowpark_connect/includes/jars/spark-sql_2.13-3.5.6.jar +0 -0
  15. snowflake/snowpark_connect/relation/map_aggregate.py +43 -1
  16. snowflake/snowpark_connect/relation/read/map_read_csv.py +73 -4
  17. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +4 -1
  18. snowflake/snowpark_connect/relation/read/map_read_json.py +4 -1
  19. snowflake/snowpark_connect/relation/read/map_read_parquet.py +4 -1
  20. snowflake/snowpark_connect/relation/read/map_read_socket.py +4 -0
  21. snowflake/snowpark_connect/relation/read/map_read_table.py +4 -1
  22. snowflake/snowpark_connect/relation/read/map_read_text.py +4 -1
  23. snowflake/snowpark_connect/relation/read/reader_config.py +6 -0
  24. snowflake/snowpark_connect/resources_initializer.py +90 -29
  25. snowflake/snowpark_connect/server.py +6 -41
  26. snowflake/snowpark_connect/server_common/__init__.py +4 -1
  27. snowflake/snowpark_connect/type_support.py +130 -0
  28. snowflake/snowpark_connect/utils/context.py +8 -0
  29. snowflake/snowpark_connect/utils/java_stored_procedure.py +53 -27
  30. snowflake/snowpark_connect/utils/java_udaf_utils.py +46 -28
  31. snowflake/snowpark_connect/utils/java_udtf_utils.py +1 -1
  32. snowflake/snowpark_connect/utils/jvm_udf_utils.py +48 -15
  33. snowflake/snowpark_connect/utils/scala_udf_utils.py +98 -22
  34. snowflake/snowpark_connect/utils/telemetry.py +33 -22
  35. snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
  36. snowflake/snowpark_connect/version.py +1 -1
  37. {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +12 -2
  38. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +4 -2
  39. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +46 -37
  40. {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
  41. {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
  42. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
  43. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
  44. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  45. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
  46. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
@@ -89,6 +89,12 @@ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_cod
  from snowflake.snowpark_connect.expression.function_defaults import (
  inject_function_defaults,
  )
+ from snowflake.snowpark_connect.expression.integral_types_support import (
+ apply_abs_overflow_with_ansi_check,
+ apply_arithmetic_overflow_with_ansi_check,
+ apply_unary_overflow_with_ansi_check,
+ get_integral_type_bounds,
+ )
  from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
  from snowflake.snowpark_connect.expression.map_cast import (
  CAST_FUNCTIONS,
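
The new imports come from integral_types_support.py, a module added in this release (+219 lines) that centralizes per-type integer bounds and ANSI-aware overflow handling used throughout the hunks below. A minimal sketch of the bounds lookup these helpers appear to rely on, based on how get_integral_type_bounds is unpacked later in this diff (the name get_integral_type_bounds_sketch and its body are illustrative, not the packaged implementation):

    from snowflake.snowpark.types import ByteType, ShortType, IntegerType, LongType

    # Two's-complement ranges for Spark's integral types.
    _INTEGRAL_BOUNDS = {
        ByteType: (-128, 127),
        ShortType: (-32768, 32767),
        IntegerType: (-2147483648, 2147483647),
        LongType: (-9223372036854775808, 9223372036854775807),
    }

    def get_integral_type_bounds_sketch(dtype):
        """Illustrative only: return (min, max) for a Snowpark integral type instance."""
        return _INTEGRAL_BOUNDS[type(dtype)]
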
@@ -607,6 +613,7 @@ def map_unresolved_function(
  overflow_possible,
  global_config.spark_sql_ansi_enabled,
  result_type,
+ "multiply",
  )
  case (NullType(), NullType()):
  result_type = DoubleType()
@@ -750,9 +757,17 @@ def map_unresolved_function(
  result_type = _find_common_type(
  [arg.typ for arg in snowpark_typed_args]
  )
- result_exp = snowpark_args[0].cast(result_type) * snowpark_args[
- 1
- ].cast(result_type)
+ if isinstance(result_type, _IntegralType):
+ raw_result = snowpark_args[0].cast(result_type) * snowpark_args[
+ 1
+ ].cast(result_type)
+ result_exp = apply_arithmetic_overflow_with_ansi_check(
+ raw_result, result_type, spark_sql_ansi_enabled, "multiply"
+ )
+ else:
+ result_exp = snowpark_args[0].cast(result_type) * snowpark_args[
+ 1
+ ].cast(result_type)
  case _:
  exception = AnalysisException(
  f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{spark_function_name}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
@@ -1012,6 +1027,7 @@ def map_unresolved_function(
  overflow_possible,
  global_config.spark_sql_ansi_enabled,
  result_type,
+ "add",
  )

  case "-":
@@ -1194,6 +1210,7 @@ def map_unresolved_function(
  overflow_possible,
  global_config.spark_sql_ansi_enabled,
  result_type,
+ "subtract",
  )

  case "/":
@@ -1218,6 +1235,7 @@ def map_unresolved_function(
  overflow_possible,
  global_config.spark_sql_ansi_enabled,
  result_type,
+ "divide",
  )
  case (NullType(), NullType()):
  result_type = DoubleType()
@@ -1471,6 +1489,11 @@ def map_unresolved_function(
  snowpark_fn.cast(snowpark_args[0], DoubleType())
  )
  result_type = DoubleType()
+ elif isinstance(input_type, _IntegralType):
+ result_exp = apply_abs_overflow_with_ansi_check(
+ snowpark_args[0], input_type, spark_sql_ansi_enabled
+ )
+ result_type = input_type
  else:
  result_exp = snowpark_fn.abs(snowpark_args[0])
  result_type = input_type
@@ -3764,7 +3787,12 @@ def map_unresolved_function(
  snowpark_args[0].is_null(), snowpark_fn.lit(None)
  ).otherwise(snowpark_fn.cast(csv_result, result_type))
  case "from_json":
- # TODO: support options.
+ # TODO: support options parameter.
+ # The options map (e.g., map('timestampFormat', 'dd/MM/yyyy')) is validated
+ # but not currently used. To implement:
+ # 1. Extract options from snowpark_args[2]
+ # 2. Pass format options to JSON parsing/coercion logic
+ # 3. Apply custom formats when casting timestamp/date fields
  if len(snowpark_args) > 2:
  if not isinstance(snowpark_typed_args[2].typ, MapType):
  exception = AnalysisException(
@@ -3792,6 +3820,29 @@ def map_unresolved_function(
  logger.debug("Failed to parse datatype json string: %s", e)
  result_type = map_type_string_to_snowpark_type(lit_schema)

+ # Validate that all MapTypes in the schema have StringType keys.
+ # JSON specification only supports string keys, so from_json cannot parse
+ # into MapType with non-string keys (e.g., IntegerType, LongType).
+ # Spark enforces this and raises INVALID_JSON_MAP_KEY_TYPE error.
+ def _validate_map_key_types(data_type: DataType) -> None:
+ """Recursively validate that all MapType instances have StringType keys."""
+ if isinstance(data_type, MapType):
+ if not isinstance(data_type.key_type, StringType):
+ exception = AnalysisException(
+ f"[INVALID_JSON_MAP_KEY_TYPE] Input schema {lit_schema} can only contain STRING as a key type for a MAP."
+ )
+ attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
+ raise exception
+ # Check the value type recursively
+ _validate_map_key_types(data_type.value_type)
+ elif isinstance(data_type, ArrayType):
+ _validate_map_key_types(data_type.element_type)
+ elif isinstance(data_type, StructType):
+ for field in data_type.fields:
+ _validate_map_key_types(field.datatype)
+
+ _validate_map_key_types(result_type)
+
  # if the result is a map, the column is named "entries"
  if isinstance(result_type, MapType):
  spark_function_name = "entries"
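
The _validate_map_key_types check added above reproduces Spark's rule that a from_json target schema may only use STRING keys in maps. A minimal PySpark illustration of the behaviour being matched (assumes a locally configured Spark session; not taken from this package's tests):

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([('{"1": "a"}',)], ["js"])

    # String keys parse normally.
    df.select(F.from_json("js", "map<string,string>")).show()

    # Non-string keys are rejected with an INVALID_JSON_MAP_KEY_TYPE AnalysisException.
    df.select(F.from_json("js", "map<int,string>")).collect()
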
@@ -3846,11 +3897,14 @@ def map_unresolved_function(
  # let's optimistically assume that any simple type can be coerced to the expected type automatically
  return snowpark_fn.lit(True)

- # There is a known limitation in snowflake while casting semi structured data to structured data
- # that if some keys are missing in value the cast would fail
- # we need to make sure it has the same "shape" as the result_type.
- # This function will construct an expression
- # that will convert the parsed json to the expected type.
+ # Snowflake limitation: Casting semi-structured data to structured types fails
+ # if the source doesn't have the exact "shape" (e.g., missing struct fields).
+ # This function constructs an expression that coerces the parsed JSON to match
+ # the expected type structure, filling in NULLs for missing fields.
+ #
+ # For complex types (StructType, ArrayType, MapType), this recursively ensures
+ # nested structures match the schema. For MapType with complex values, it uses
+ # a pure SQL REDUCE approach to avoid UDF-in-lambda errors.
  def _coerce_to_type(
  exp: Column, t: DataType, top_level: bool = True
  ) -> Column:
@@ -3873,13 +3927,15 @@ def map_unresolved_function(
  snowpark_fn.as_object(exp).is_null(), snowpark_fn.lit(None)
  ).otherwise(snowpark_fn.object_construct_keep_null(*key_values))
  elif isinstance(t, ArrayType):
+ # Handle array wrapping behavior for top-level structs
  if top_level and isinstance(t.element_type, StructType):
- # Spark can still wrap a single value in an array if the internal type is a struct
+ # Spark can wrap a single value in an array if the element type is a struct
  arr_exp = snowpark_fn.to_array(exp)
  else:
- # if it's not a struct, we can return null for any non-array values
+ # For other types, return null for non-array values
  arr_exp = snowpark_fn.as_array(exp)
- # adjust the parsed json to match the expected type so that we can cast it later
+
+ # Get coercion SQL for the array element type using placeholder column "x"
  analyzer = Session.get_active_session()._analyzer
  fn_sql = analyzer.analyze(
  _coerce_to_type(
@@ -3888,7 +3944,7 @@ def map_unresolved_function(
  defaultdict(),
  )

- # if there's even a single incorrect element, return null
+ # Apply TRANSFORM to coerce each element, or return null if types don't match
  return snowpark_fn.when(
  _element_type_matches(arr_exp, t.element_type),
  snowpark_fn.call_function(
@@ -3896,9 +3952,58 @@ def map_unresolved_function(
  ),
  ).otherwise(snowpark_fn.lit(None))
  elif isinstance(t, MapType):
- return snowpark_fn.as_object(exp)
+ obj_exp = snowpark_fn.as_object(exp)
+
+ # If value type is simple (no nested complex types), no coercion needed
+ if not isinstance(t.value_type, (StructType, ArrayType, MapType)):
+ return obj_exp
+
+ # For maps with complex value types, we need to coerce each value.
+ # Strategy: Use pure SQL REDUCE with stateful accumulator to avoid:
+ # 1. UDF-in-lambda errors (which break nested maps)
+ # 2. Column scoping issues (outer columns aren't accessible in lambdas)
+ #
+ # The state is a 2-element array: [result_map, original_map]
+ # This allows the lambda to access the original map's values while building the result.
+
+ analyzer = Session.get_active_session()._analyzer
+
+ # Get the coercion SQL for the value type using a placeholder column
+ fn_sql = analyzer.analyze(
+ _coerce_to_type(
+ snowpark_fn.col("v"), t.value_type, False
+ )._expression,
+ defaultdict(),
+ )
+
+ # Replace placeholder "V" with reference to original map via state array
+ # In lambda: state[1] = original_map, k = current_key
+ fn_sql_with_value = fn_sql.replace(
+ '"V"', "strip_null_value(GET(state[1], k))"
+ )
+
+ # Build REDUCE lambda: (state, k) -> [updated_result, original_map]
+ lambda_expr = (
+ f"(state, k) -> ARRAY_CONSTRUCT("
+ f"object_insert(state[0], k, ({fn_sql_with_value})::variant, true), "
+ f"state[1])"
+ )
+
+ # Execute REDUCE with initial state = [{}, original_map]
+ reduce_result = snowpark_fn.call_function(
+ "reduce",
+ snowpark_fn.call_function("object_keys", obj_exp),
+ snowpark_fn.array_construct(
+ snowpark_fn.object_construct(), # state[0]: empty result map
+ obj_exp, # state[1]: original map
+ ),
+ snowpark_fn.sql_expr(lambda_expr),
+ )
+
+ # Extract the result map (state[0]) from the final state
+ return snowpark_fn.get(reduce_result, snowpark_fn.lit(0))
  else:
- return exp
+ return snowpark_fn.try_cast(snowpark_fn.to_varchar(exp), t)

  # Apply the coercion to handle invalid JSON (creates struct with NULL fields)
  coerced_exp = _coerce_to_type(result_exp, result_type)
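
To make the state-array trick concrete, the REDUCE call assembled above has roughly the following SQL shape, shown here as a hand-written sketch rather than actual generated output; <coerce(v)> stands for the analyzer-produced coercion SQL after the '"V"' placeholder is replaced with strip_null_value(GET(state[1], k)):

    # Schematic only; the real text is built from fn_sql_with_value at runtime.
    reduce_sql_shape = """
    GET(
      REDUCE(
        OBJECT_KEYS(obj),
        ARRAY_CONSTRUCT(OBJECT_CONSTRUCT(), obj),
        (state, k) -> ARRAY_CONSTRUCT(
          OBJECT_INSERT(state[0], k, (<coerce(v)>)::VARIANT, TRUE),
          state[1]
        )
      ),
      0
    )
    """
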
@@ -5930,7 +6035,11 @@ def map_unresolved_function(
  spark_function_name = f"(- {snowpark_arg_names[0]})"
  else:
  spark_function_name = f"negative({snowpark_arg_names[0]})"
- if (
+ if isinstance(arg_type, _IntegralType):
+ result_exp = apply_unary_overflow_with_ansi_check(
+ snowpark_args[0], arg_type, spark_sql_ansi_enabled, "negative"
+ )
+ elif (
  isinstance(arg_type, _NumericType)
  or isinstance(arg_type, YearMonthIntervalType)
  or isinstance(arg_type, DayTimeIntervalType)
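
For context on why negation needs its own overflow check: in two's-complement arithmetic the only integral value whose negation does not fit the same type is the type minimum, e.g. for a 32-bit INT (illustrative values, not code from the package):

    INT_MIN, INT_MAX = -2**31, 2**31 - 1
    assert -INT_MIN == 2**31                      # one past INT_MAX, so -INT_MIN overflows
    assert not (INT_MIN <= -INT_MIN <= INT_MAX)
    assert INT_MIN <= -(INT_MIN + 1) <= INT_MAX   # every other value negates safely
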
@@ -8101,10 +8210,20 @@ def map_unresolved_function(
  else:
  result_type = DoubleType()

- result_exp = _resolve_aggregate_exp(
- sum_fn(arg),
- result_type,
- )
+ if isinstance(input_type, _IntegralType) and not is_window_enabled():
+ raw_sum = sum_fn(arg)
+ wrapped_sum = apply_arithmetic_overflow_with_ansi_check(
+ raw_sum, result_type, spark_sql_ansi_enabled, "add"
+ )
+ result_exp = _resolve_aggregate_exp(
+ wrapped_sum,
+ result_type,
+ )
+ else:
+ result_exp = _resolve_aggregate_exp(
+ sum_fn(arg),
+ result_type,
+ )
  case "tan":
  spark_function_name = f"TAN({snowpark_arg_names[0]})"
  result_exp = snowpark_fn.tan(snowpark_args[0])
@@ -8977,10 +9096,15 @@ def map_unresolved_function(
  .otherwise(snowpark_args[0] + snowpark_args[1])
  )
  case _:
- result_exp = _try_arithmetic_helper(
+ result_exp, result_type = _try_arithmetic_helper(
  snowpark_typed_args, snowpark_args, 0
  )
- result_exp = _type_with_typer(result_exp)
+ if result_type is not None:
+ result_exp = TypedColumn(
+ result_exp, lambda rt=result_type: [rt]
+ )
+ else:
+ result_exp = _type_with_typer(result_exp)
  case "try_aes_decrypt":
  result_exp = _aes_helper(
  "TRY_DECRYPT",
@@ -9110,6 +9234,7 @@ def map_unresolved_function(
  overflow_possible,
  False,
  result_type,
+ "divide",
  )
  case (_NumericType(), _NumericType()):
  result_exp = snowpark_fn.when(
@@ -9263,31 +9388,29 @@ def map_unresolved_function(
  result_type = FloatType()
  case _:
  result_type = t
- case (_IntegralType(), _IntegralType()):
- min_long = sys.maxsize + 1
- max_long = sys.maxsize
+ case (_IntegralType() as t1, _IntegralType() as t2):
+ result_type = _find_common_type([t1, t2])
+ min_val, max_val = get_integral_type_bounds(result_type)

- max_value = snowpark_fn.when(
- ((snowpark_args[0] > 0) & (snowpark_args[1] > 0))
- | ((snowpark_args[0] < 0) & (snowpark_args[1] < 0)),
- max_long,
- ).otherwise(min_long)
+ same_sign = ((snowpark_args[0] > 0) & (snowpark_args[1] > 0)) | (
+ (snowpark_args[0] < 0) & (snowpark_args[1] < 0)
+ )
+ bound = snowpark_fn.when(same_sign, max_val).otherwise(-min_val - 1)

  result_exp = (
- # Multiplication by 0 must be handled separately, since division by 0 will throw an error.
  snowpark_fn.when(
  (snowpark_args[0] == 0) | (snowpark_args[1] == 0),
- snowpark_fn.lit(0),
+ snowpark_fn.lit(0).cast(result_type),
  )
- # We check for overflow by seeing if max or min divided by the right argument is greater than the
- # left argument.
  .when(
  snowpark_fn.abs(snowpark_args[0])
- > (max_value / snowpark_fn.abs(snowpark_args[1])),
+ > (bound / snowpark_fn.abs(snowpark_args[1])),
  snowpark_fn.lit(None),
- ).otherwise(snowpark_args[0] * snowpark_args[1])
+ )
+ .otherwise(
+ (snowpark_args[0] * snowpark_args[1]).cast(result_type)
+ )
  )
- result_exp = _type_with_typer(result_exp)
  case (
  (DecimalType(), _IntegralType())
  | (
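
A worked example of the division-based overflow predicate used above, abs(a) > bound / abs(b), with illustrative values for a 32-bit INT result type where both operands share a sign (bound = 2147483647):

    bound = 2147483647
    a, b = 100_000, 100_000
    assert abs(a) > bound / abs(b)          # 100000 > ~21474.8, so try_multiply yields NULL

    a, b = 50_000, 40_000
    assert a * b == 2_000_000_000           # fits in INT
    assert not (abs(a) > bound / abs(b))    # 50000 <= ~53687, so the product is returned
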
@@ -9310,6 +9433,7 @@ def map_unresolved_function(
  overflow_possible,
  False,
  result_type,
+ "multiply",
  )
  case (_NumericType(), _NumericType()):
  result_exp = snowpark_args[0] * snowpark_args[1]
@@ -9457,10 +9581,15 @@ def map_unresolved_function(
  .otherwise(snowpark_args[0] - snowpark_args[1])
  )
  case _:
- result_exp = _try_arithmetic_helper(
+ result_exp, result_type = _try_arithmetic_helper(
  snowpark_typed_args, snowpark_args, 1
  )
- result_exp = _type_with_typer(result_exp)
+ if result_type is not None:
+ result_exp = TypedColumn(
+ result_exp, lambda rt=result_type: [rt]
+ )
+ else:
+ result_exp = _type_with_typer(result_exp)
  case "try_to_number":
  try_to_number = snowpark_fn.function("try_to_number")
  precision, scale = resolve_to_number_precision_and_scale(exp)
@@ -11178,7 +11307,7 @@ def _try_sum_helper(
  input_types=[arg_type],
  )
  # call the udaf
- return _try_sum_int_udaf(col_name), arg_type
+ return _try_sum_int_udaf(col_name), LongType()

  # NOTE: We will never call this function with an IntegerType column and calculating_avg=True. Therefore,
  # we don't need to handle the case where calculating_avg=True here. The caller of this function will handle it.
@@ -11402,8 +11531,15 @@ def _arithmetic_operation(
  op: Callable[[Column, Column], Column],
  overflow_possible: bool,
  should_raise_on_overflow: bool,
- target_type: DecimalType,
+ target_type: DataType,
+ operation_name: str,
  ) -> Column:
+ if isinstance(target_type, _IntegralType):
+ raw_result = op(arg1.col, arg2.col)
+ return apply_arithmetic_overflow_with_ansi_check(
+ raw_result, target_type, should_raise_on_overflow, operation_name
+ )
+
  def _cast_arg(tc: TypedColumn) -> Column:
  _, s = _get_type_precision(tc.typ)
  typ = (
@@ -11490,12 +11626,12 @@ def _get_decimal_division_result_type(p1, s1, p2, s2) -> tuple[DecimalType, bool

  def _try_arithmetic_helper(
  typed_args: List[TypedColumn], snowpark_args: List[Column], operation_type: int
- ) -> Column:
+ ) -> tuple[Column, DataType | None]:
  # Constructs a Snowpark Column expression for a "try-style" arithmetic operation
  # (addition or subtraction, determined by `operation_type`) between two input columns.
  #
  # Key behavioral characteristics:
- # 1. For **Integral inputs**: Explicitly checks for 64-bit overflow/underflow.
+ # 1. For **Integral inputs**: Explicitly checks for overflow/underflow at the result type boundaries.
  # - BEHAVIOR: Returns a NULL literal if the operation would exceed these limits;
  # otherwise, returns the result of the standard Snowpark `+` or `-`.
  #
@@ -11508,10 +11644,11 @@ def _try_arithmetic_helper(
  # Arithmetic operations involving **Boolean types** will raise an `AnalysisException`.
  # All other unhandled incompatible type combinations result in a NULL literal.
  # The function returns the resulting Snowpark Column expression.
-
  match (typed_args[0].typ, typed_args[1].typ):
- case (_IntegralType(), _IntegralType()):
- # For integer addition, overflow errors by default in Snowflake. We need it to return null.
+ case (_IntegralType() as t1, _IntegralType() as t2):
+ result_type = _find_common_type([t1, t2])
+ min_val, max_val = get_integral_type_bounds(result_type)
+
  if operation_type == 0: # Addition
  result_exp = (
  snowpark_fn.when(
@@ -11519,20 +11656,20 @@ def _try_arithmetic_helper(
  & (snowpark_args[1] > 0)
  & (
  snowpark_args[0]
- > snowpark_fn.lit(MAX_INT64) - snowpark_args[1]
+ > snowpark_fn.lit(max_val) - snowpark_args[1]
  ),
- snowpark_fn.lit(None), # Overflow
+ snowpark_fn.lit(None),
  )
  .when(
  (snowpark_args[0] < 0)
  & (snowpark_args[1] < 0)
  & (
  snowpark_args[0]
- < snowpark_fn.lit(MIN_INT64) - snowpark_args[1]
+ < snowpark_fn.lit(min_val) - snowpark_args[1]
  ),
- snowpark_fn.lit(None), # Underflow
+ snowpark_fn.lit(None),
  )
- .otherwise(snowpark_args[0] + snowpark_args[1])
+ .otherwise((snowpark_args[0] + snowpark_args[1]).cast(result_type))
  )
  else: # Subtraction
  result_exp = (
@@ -11541,22 +11678,22 @@ def _try_arithmetic_helper(
  & (snowpark_args[1] < 0)
  & (
  snowpark_args[0]
- > snowpark_fn.lit(MAX_INT64) + snowpark_args[1]
+ > snowpark_fn.lit(max_val) + snowpark_args[1]
  ),
- snowpark_fn.lit(None), # Overflow
+ snowpark_fn.lit(None),
  )
  .when(
  (snowpark_args[0] < 0)
  & (snowpark_args[1] > 0)
  & (
  snowpark_args[0]
- < snowpark_fn.lit(MIN_INT64) + snowpark_args[1]
+ < snowpark_fn.lit(min_val) + snowpark_args[1]
  ),
- snowpark_fn.lit(None), # Underflow
+ snowpark_fn.lit(None),
  )
- .otherwise(snowpark_args[0] - snowpark_args[1])
+ .otherwise((snowpark_args[0] - snowpark_args[1]).cast(result_type))
  )
- return result_exp
+ return result_exp, result_type
  case (DateType(), _) | (_, DateType()):
  arg1, arg2 = typed_args[0].typ, typed_args[1].typ
  # Valid input parameter types for try_add - DateType and _NumericType, _NumericType and DateType.
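
The overflow and underflow conditions in the two preceding integral hunks translate directly into plain Python; the sketch below mirrors them for a generic result type with bounds (min_val, max_val) (the try_add_sketch / try_subtract_sketch names are hypothetical, for illustration only):

    def try_add_sketch(a, b, min_val, max_val):
        """Return None instead of an out-of-range sum, like the column expressions above."""
        if a > 0 and b > 0 and a > max_val - b:
            return None   # overflow
        if a < 0 and b < 0 and a < min_val - b:
            return None   # underflow
        return a + b

    def try_subtract_sketch(a, b, min_val, max_val):
        if a > 0 and b < 0 and a > max_val + b:
            return None   # overflow
        if a < 0 and b > 0 and a < min_val + b:
            return None   # underflow
        return a - b

    assert try_add_sketch(2**31 - 1, 1, -2**31, 2**31 - 1) is None
    assert try_subtract_sketch(-2**31, 1, -2**31, 2**31 - 1) is None
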
@@ -11577,22 +11714,28 @@ def _try_arithmetic_helper(
  if isinstance(arg1, _IntegralType)
  else snowpark_args
  )
- return _try_to_cast(
- "try_to_date",
- snowpark_fn.cast(snowpark_fn.date_add(*args), DateType()),
- args[0],
+ return (
+ _try_to_cast(
+ "try_to_date",
+ snowpark_fn.cast(snowpark_fn.date_add(*args), DateType()),
+ args[0],
+ ),
+ None,
  )
  else:
  if isinstance(arg1, DateType) and isinstance(arg2, _IntegralType):
- return _try_to_cast(
- "try_to_date",
- snowpark_fn.to_date(
- snowpark_fn.date_sub(snowpark_args[0], snowpark_args[1])
+ return (
+ _try_to_cast(
+ "try_to_date",
+ snowpark_fn.to_date(
+ snowpark_fn.date_sub(snowpark_args[0], snowpark_args[1])
+ ),
+ snowpark_args[0],
  ),
- snowpark_args[0],
+ None,
  )
  elif isinstance(arg1, DateType) and isinstance(arg2, DateType):
- return snowpark_fn.daydiff(snowpark_args[0], snowpark_args[1])
+ return snowpark_fn.daydiff(snowpark_args[0], snowpark_args[1]), None
  else:
  exception = AnalysisException(
  '[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "date_sub(dt, sub)" due to data type mismatch: Parameter 1 requires the "DATE" type and parameter 2 requires the ("INT" or "SMALLINT" or "TINYINT") type'
@@ -11609,12 +11752,16 @@ def _try_arithmetic_helper(
  "try_add" if operation_type == 0 else "try_subtract",
  )

- return _arithmetic_operation(
- typed_args[0],
- typed_args[1],
- lambda x, y: x + y if operation_type == 0 else x - y,
- overflow_possible,
- False,
+ return (
+ _arithmetic_operation(
+ typed_args[0],
+ typed_args[1],
+ lambda x, y: x + y if operation_type == 0 else x - y,
+ overflow_possible,
+ False,
+ result_type,
+ "add" if operation_type == 0 else "subtract",
+ ),
  result_type,
  )

@@ -11622,11 +11769,12 @@ def _try_arithmetic_helper(
  # matches Spark and goes to inf.
  # Note that we already handle the int,int case above, hence it is okay to use the broader _numeric
  # below.
- case (_NumericType(), _NumericType()):
+ case (_NumericType() as t1, _NumericType() as t2):
+ result_type = _find_common_type([t1, t2])
  if operation_type == 0:
- return snowpark_args[0] + snowpark_args[1]
+ return snowpark_args[0] + snowpark_args[1], result_type
  else:
- return snowpark_args[0] - snowpark_args[1]
+ return snowpark_args[0] - snowpark_args[1], result_type
  # String cases - try to convert to numeric
  case (
  (StringType(), _NumericType())
@@ -11642,12 +11790,12 @@ def _try_arithmetic_helper(
  updated_args = _validate_numeric_args(
  "try_add", typed_args, snowpark_args
  )
- return updated_args[0] + updated_args[1]
+ return updated_args[0] + updated_args[1], None
  else:
  updated_args = _validate_numeric_args(
  "try_subtract", typed_args, snowpark_args
  )
- return updated_args[0] - updated_args[1]
+ return updated_args[0] - updated_args[1], None

  case (BooleanType(), _) | (_, BooleanType()):
  exception = AnalysisException(
@@ -11657,7 +11805,7 @@ def _try_arithmetic_helper(
  raise exception
  case _:
  # Return NULL for incompatible types
- return snowpark_fn.lit(None)
+ return snowpark_fn.lit(None), None


  def _get_add_sub_result_type(
@@ -11853,24 +12001,9 @@ def _get_literal_param_name(exp, arg_index: int, default_param_name: str):


  def _raise_error_helper(return_type: DataType, error_class=None):
- error_class = (
- f":{error_class.__name__}"
- if error_class and hasattr(error_class, "__name__")
- else ""
- )
-
- def _raise_fn(*msgs: Column) -> Column:
- return snowpark_fn.cast(
- snowpark_fn.abs(
- snowpark_fn.concat(
- snowpark_fn.lit(f"[snowpark-connect-exception{error_class}]"),
- *(msg.try_cast(StringType()) for msg in msgs),
- )
- ).cast(StringType()),
- return_type,
- )
+ from snowflake.snowpark_connect.expression.error_utils import raise_error_helper

- return _raise_fn
+ return raise_error_helper(return_type, error_class)


  def _divnull(dividend: Column, divisor: Column) -> Column: