snowpark-connect 1.6.0__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/client/server.py +37 -0
- snowflake/snowpark_connect/config.py +72 -3
- snowflake/snowpark_connect/expression/error_utils.py +28 -0
- snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
- snowflake/snowpark_connect/expression/map_cast.py +108 -17
- snowflake/snowpark_connect/expression/map_udf.py +1 -0
- snowflake/snowpark_connect/expression/map_unresolved_function.py +229 -96
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/relation/map_aggregate.py +43 -1
- snowflake/snowpark_connect/relation/read/map_read_csv.py +73 -4
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_json.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_socket.py +4 -0
- snowflake/snowpark_connect/relation/read/map_read_table.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_text.py +4 -1
- snowflake/snowpark_connect/relation/read/reader_config.py +6 -0
- snowflake/snowpark_connect/resources_initializer.py +90 -29
- snowflake/snowpark_connect/server.py +6 -41
- snowflake/snowpark_connect/server_common/__init__.py +4 -1
- snowflake/snowpark_connect/type_support.py +130 -0
- snowflake/snowpark_connect/utils/context.py +8 -0
- snowflake/snowpark_connect/utils/java_stored_procedure.py +53 -27
- snowflake/snowpark_connect/utils/java_udaf_utils.py +46 -28
- snowflake/snowpark_connect/utils/java_udtf_utils.py +1 -1
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +48 -15
- snowflake/snowpark_connect/utils/scala_udf_utils.py +98 -22
- snowflake/snowpark_connect/utils/telemetry.py +33 -22
- snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +12 -2
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +4 -2
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +46 -37
- {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
All hunks below are from snowflake/snowpark_connect/expression/map_unresolved_function.py.
@@ -89,6 +89,12 @@ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
 from snowflake.snowpark_connect.expression.function_defaults import (
     inject_function_defaults,
 )
+from snowflake.snowpark_connect.expression.integral_types_support import (
+    apply_abs_overflow_with_ansi_check,
+    apply_arithmetic_overflow_with_ansi_check,
+    apply_unary_overflow_with_ansi_check,
+    get_integral_type_bounds,
+)
 from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
 from snowflake.snowpark_connect.expression.map_cast import (
     CAST_FUNCTIONS,
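The integral_types_support module added in 1.7.0 is new and its body is not part of this diff. A minimal sketch of the contract the imported helpers appear to implement, assuming signed two's-complement bounds and, for the non-ANSI path, JVM-style wrap-around (that fallback is an assumption, not confirmed by this diff):

    # Illustrative sketch only, not the packaged implementation.
    def integral_type_bounds(bits: int) -> tuple[int, int]:
        """(min, max) of a signed two's-complement integer of the given width."""
        return -(1 << (bits - 1)), (1 << (bits - 1)) - 1

    def overflow_with_ansi_check(value: int, bits: int, ansi_enabled: bool) -> int:
        """ANSI mode raises on overflow; the fallback here wraps like the JVM."""
        lo, hi = integral_type_bounds(bits)
        if lo <= value <= hi:
            return value
        if ansi_enabled:
            raise ArithmeticError(f"[ARITHMETIC_OVERFLOW] {value} not in [{lo}, {hi}]")
        return (value - lo) % (1 << bits) + lo  # assumed JVM-style wrap-around

    assert integral_type_bounds(64) == (-(2**63), 2**63 - 1)
    assert overflow_with_ansi_check(2**63, 64, ansi_enabled=False) == -(2**63)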
@@ -607,6 +613,7 @@ def map_unresolved_function(
                 overflow_possible,
                 global_config.spark_sql_ansi_enabled,
                 result_type,
+                "multiply",
             )
         case (NullType(), NullType()):
             result_type = DoubleType()
@@ -750,9 +757,17 @@ def map_unresolved_function(
             result_type = _find_common_type(
                 [arg.typ for arg in snowpark_typed_args]
             )
-            result_exp = snowpark_args[0].cast(result_type) * snowpark_args[
-                1
-            ].cast(result_type)
+            if isinstance(result_type, _IntegralType):
+                raw_result = snowpark_args[0].cast(result_type) * snowpark_args[
+                    1
+                ].cast(result_type)
+                result_exp = apply_arithmetic_overflow_with_ansi_check(
+                    raw_result, result_type, spark_sql_ansi_enabled, "multiply"
+                )
+            else:
+                result_exp = snowpark_args[0].cast(result_type) * snowpark_args[
+                    1
+                ].cast(result_type)
         case _:
             exception = AnalysisException(
                 f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{spark_function_name}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
@@ -1012,6 +1027,7 @@ def map_unresolved_function(
                 overflow_possible,
                 global_config.spark_sql_ansi_enabled,
                 result_type,
+                "add",
             )

         case "-":
@@ -1194,6 +1210,7 @@ def map_unresolved_function(
                 overflow_possible,
                 global_config.spark_sql_ansi_enabled,
                 result_type,
+                "subtract",
             )

         case "/":
@@ -1218,6 +1235,7 @@ def map_unresolved_function(
                 overflow_possible,
                 global_config.spark_sql_ansi_enabled,
                 result_type,
+                "divide",
             )
         case (NullType(), NullType()):
             result_type = DoubleType()
@@ -1471,6 +1489,11 @@ def map_unresolved_function(
                 snowpark_fn.cast(snowpark_args[0], DoubleType())
             )
             result_type = DoubleType()
+        elif isinstance(input_type, _IntegralType):
+            result_exp = apply_abs_overflow_with_ansi_check(
+                snowpark_args[0], input_type, spark_sql_ansi_enabled
+            )
+            result_type = input_type
         else:
             result_exp = snowpark_fn.abs(snowpark_args[0])
             result_type = input_type
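The new abs() branch exists because two's-complement minima have no positive counterpart: for each integral type, abs() overflows on exactly one input. A plain-Python illustration (not the plugin's code):

    INT_MIN, INT_MAX = -(2**31), 2**31 - 1

    # abs(INT_MIN) exceeds INT_MAX, so it cannot be represented in 32 bits.
    assert abs(INT_MIN) == INT_MAX + 1  # 2147483648

    # Under ANSI mode Spark raises here; the JVM's Math.abs would silently
    # return INT_MIN again (wrap-around), hence the dedicated check above.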
@@ -3764,7 +3787,12 @@ def map_unresolved_function(
                 snowpark_args[0].is_null(), snowpark_fn.lit(None)
             ).otherwise(snowpark_fn.cast(csv_result, result_type))
         case "from_json":
-            # TODO: support options.
+            # TODO: support options parameter.
+            # The options map (e.g., map('timestampFormat', 'dd/MM/yyyy')) is validated
+            # but not currently used. To implement:
+            # 1. Extract options from snowpark_args[2]
+            # 2. Pass format options to JSON parsing/coercion logic
+            # 3. Apply custom formats when casting timestamp/date fields
             if len(snowpark_args) > 2:
                 if not isinstance(snowpark_typed_args[2].typ, MapType):
                     exception = AnalysisException(
@@ -3792,6 +3820,29 @@ def map_unresolved_function(
                 logger.debug("Failed to parse datatype json string: %s", e)
                 result_type = map_type_string_to_snowpark_type(lit_schema)

+            # Validate that all MapTypes in the schema have StringType keys.
+            # JSON specification only supports string keys, so from_json cannot parse
+            # into MapType with non-string keys (e.g., IntegerType, LongType).
+            # Spark enforces this and raises INVALID_JSON_MAP_KEY_TYPE error.
+            def _validate_map_key_types(data_type: DataType) -> None:
+                """Recursively validate that all MapType instances have StringType keys."""
+                if isinstance(data_type, MapType):
+                    if not isinstance(data_type.key_type, StringType):
+                        exception = AnalysisException(
+                            f"[INVALID_JSON_MAP_KEY_TYPE] Input schema {lit_schema} can only contain STRING as a key type for a MAP."
+                        )
+                        attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
+                        raise exception
+                    # Check the value type recursively
+                    _validate_map_key_types(data_type.value_type)
+                elif isinstance(data_type, ArrayType):
+                    _validate_map_key_types(data_type.element_type)
+                elif isinstance(data_type, StructType):
+                    for field in data_type.fields:
+                        _validate_map_key_types(field.datatype)
+
+            _validate_map_key_types(result_type)
+
             # if the result is a map, the column is named "entries"
             if isinstance(result_type, MapType):
                 spark_function_name = "entries"
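For reference, a toy invocation of the validator against hypothetical schemas, built from the same Snowpark type classes the code above walks (the helper is nested inside the mapper, so the calls are shown commented out):

    from snowflake.snowpark.types import (
        ArrayType, IntegerType, MapType, StringType, StructField, StructType,
    )

    ok = StructType([StructField("tags", MapType(StringType(), IntegerType()))])
    bad = ArrayType(MapType(IntegerType(), StringType()))  # non-string map keys

    # _validate_map_key_types(ok)   # passes: every MapType key is StringType
    # _validate_map_key_types(bad)  # raises [INVALID_JSON_MAP_KEY_TYPE]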
@@ -3846,11 +3897,14 @@ def map_unresolved_function(
                 # let's optimistically assume that any simple type can be coerced to the expected type automatically
                 return snowpark_fn.lit(True)

-            #
-            #
-            #
-            #
-            #
+            # Snowflake limitation: Casting semi-structured data to structured types fails
+            # if the source doesn't have the exact "shape" (e.g., missing struct fields).
+            # This function constructs an expression that coerces the parsed JSON to match
+            # the expected type structure, filling in NULLs for missing fields.
+            #
+            # For complex types (StructType, ArrayType, MapType), this recursively ensures
+            # nested structures match the schema. For MapType with complex values, it uses
+            # a pure SQL REDUCE approach to avoid UDF-in-lambda errors.
             def _coerce_to_type(
                 exp: Column, t: DataType, top_level: bool = True
             ) -> Column:
@@ -3873,13 +3927,15 @@ def map_unresolved_function(
                     snowpark_fn.as_object(exp).is_null(), snowpark_fn.lit(None)
                 ).otherwise(snowpark_fn.object_construct_keep_null(*key_values))
             elif isinstance(t, ArrayType):
+                # Handle array wrapping behavior for top-level structs
                 if top_level and isinstance(t.element_type, StructType):
-                    # Spark can
+                    # Spark can wrap a single value in an array if the element type is a struct
                     arr_exp = snowpark_fn.to_array(exp)
                 else:
-                    #
+                    # For other types, return null for non-array values
                     arr_exp = snowpark_fn.as_array(exp)
-
+
+                # Get coercion SQL for the array element type using placeholder column "x"
                 analyzer = Session.get_active_session()._analyzer
                 fn_sql = analyzer.analyze(
                     _coerce_to_type(
@@ -3888,7 +3944,7 @@ def map_unresolved_function(
                         defaultdict(),
                     )

-                #
+                # Apply TRANSFORM to coerce each element, or return null if types don't match
                 return snowpark_fn.when(
                     _element_type_matches(arr_exp, t.element_type),
                     snowpark_fn.call_function(
@@ -3896,9 +3952,58 @@ def map_unresolved_function(
                     ),
                 ).otherwise(snowpark_fn.lit(None))
             elif isinstance(t, MapType):
-
+                obj_exp = snowpark_fn.as_object(exp)
+
+                # If value type is simple (no nested complex types), no coercion needed
+                if not isinstance(t.value_type, (StructType, ArrayType, MapType)):
+                    return obj_exp
+
+                # For maps with complex value types, we need to coerce each value.
+                # Strategy: Use pure SQL REDUCE with stateful accumulator to avoid:
+                # 1. UDF-in-lambda errors (which break nested maps)
+                # 2. Column scoping issues (outer columns aren't accessible in lambdas)
+                #
+                # The state is a 2-element array: [result_map, original_map]
+                # This allows the lambda to access the original map's values while building the result.
+
+                analyzer = Session.get_active_session()._analyzer
+
+                # Get the coercion SQL for the value type using a placeholder column
+                fn_sql = analyzer.analyze(
+                    _coerce_to_type(
+                        snowpark_fn.col("v"), t.value_type, False
+                    )._expression,
+                    defaultdict(),
+                )
+
+                # Replace placeholder "V" with reference to original map via state array
+                # In lambda: state[1] = original_map, k = current_key
+                fn_sql_with_value = fn_sql.replace(
+                    '"V"', "strip_null_value(GET(state[1], k))"
+                )
+
+                # Build REDUCE lambda: (state, k) -> [updated_result, original_map]
+                lambda_expr = (
+                    f"(state, k) -> ARRAY_CONSTRUCT("
+                    f"object_insert(state[0], k, ({fn_sql_with_value})::variant, true), "
+                    f"state[1])"
+                )
+
+                # Execute REDUCE with initial state = [{}, original_map]
+                reduce_result = snowpark_fn.call_function(
+                    "reduce",
+                    snowpark_fn.call_function("object_keys", obj_exp),
+                    snowpark_fn.array_construct(
+                        snowpark_fn.object_construct(),  # state[0]: empty result map
+                        obj_exp,  # state[1]: original map
+                    ),
+                    snowpark_fn.sql_expr(lambda_expr),
+                )
+
+                # Extract the result map (state[0]) from the final state
+                return snowpark_fn.get(reduce_result, snowpark_fn.lit(0))
             else:
-                return exp
+                return snowpark_fn.try_cast(snowpark_fn.to_varchar(exp), t)

             # Apply the coercion to handle invalid JSON (creates struct with NULL fields)
             coerced_exp = _coerce_to_type(result_exp, result_type)
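Rendered as SQL, the REDUCE expression that branch assembles has roughly this shape, with the recursive value coercion abbreviated to <coerce(v)> (an illustration, not captured output):

    illustrative_sql = """
    REDUCE(
      OBJECT_KEYS(obj),                              -- iterate over the map's keys
      ARRAY_CONSTRUCT(OBJECT_CONSTRUCT(), obj),      -- state = [result, original]
      (state, k) -> ARRAY_CONSTRUCT(
        OBJECT_INSERT(state[0], k,
                      (<coerce(v)>)::variant, true), -- v = GET(state[1], k)
        state[1]                                     -- carry the original along
      )
    )[0]                                             -- final answer is state[0]
    """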
@@ -5930,7 +6035,11 @@ def map_unresolved_function(
             spark_function_name = f"(- {snowpark_arg_names[0]})"
         else:
             spark_function_name = f"negative({snowpark_arg_names[0]})"
-        if (
+        if isinstance(arg_type, _IntegralType):
+            result_exp = apply_unary_overflow_with_ansi_check(
+                snowpark_args[0], arg_type, spark_sql_ansi_enabled, "negative"
+            )
+        elif (
             isinstance(arg_type, _NumericType)
             or isinstance(arg_type, YearMonthIntervalType)
             or isinstance(arg_type, DayTimeIntervalType)
@@ -8101,10 +8210,20 @@ def map_unresolved_function(
             else:
                 result_type = DoubleType()

-            result_exp = _resolve_aggregate_exp(
-                sum_fn(arg),
-                result_type,
-            )
+            if isinstance(input_type, _IntegralType) and not is_window_enabled():
+                raw_sum = sum_fn(arg)
+                wrapped_sum = apply_arithmetic_overflow_with_ansi_check(
+                    raw_sum, result_type, spark_sql_ansi_enabled, "add"
+                )
+                result_exp = _resolve_aggregate_exp(
+                    wrapped_sum,
+                    result_type,
+                )
+            else:
+                result_exp = _resolve_aggregate_exp(
+                    sum_fn(arg),
+                    result_type,
+                )
         case "tan":
             spark_function_name = f"TAN({snowpark_arg_names[0]})"
             result_exp = snowpark_fn.tan(snowpark_args[0])
@@ -8977,10 +9096,15 @@ def map_unresolved_function(
                     .otherwise(snowpark_args[0] + snowpark_args[1])
                 )
             case _:
-                result_exp = _try_arithmetic_helper(
+                result_exp, result_type = _try_arithmetic_helper(
                     snowpark_typed_args, snowpark_args, 0
                 )
-                result_exp = _type_with_typer(result_exp)
+                if result_type is not None:
+                    result_exp = TypedColumn(
+                        result_exp, lambda rt=result_type: [rt]
+                    )
+                else:
+                    result_exp = _type_with_typer(result_exp)
         case "try_aes_decrypt":
             result_exp = _aes_helper(
                 "TRY_DECRYPT",
@@ -9110,6 +9234,7 @@ def map_unresolved_function(
                     overflow_possible,
                     False,
                     result_type,
+                    "divide",
                 )
             case (_NumericType(), _NumericType()):
                 result_exp = snowpark_fn.when(
@@ -9263,31 +9388,29 @@ def map_unresolved_function(
                     result_type = FloatType()
                 case _:
                     result_type = t
-        case (_IntegralType(), _IntegralType()):
-
-
+        case (_IntegralType() as t1, _IntegralType() as t2):
+            result_type = _find_common_type([t1, t2])
+            min_val, max_val = get_integral_type_bounds(result_type)

-
-            (
-
-
-            ).otherwise(min_long)
+            same_sign = ((snowpark_args[0] > 0) & (snowpark_args[1] > 0)) | (
+                (snowpark_args[0] < 0) & (snowpark_args[1] < 0)
+            )
+            bound = snowpark_fn.when(same_sign, max_val).otherwise(-min_val - 1)

             result_exp = (
-                # Multiplication by 0 must be handled separately, since division by 0 will throw an error.
                 snowpark_fn.when(
                     (snowpark_args[0] == 0) | (snowpark_args[1] == 0),
-                    snowpark_fn.lit(0),
+                    snowpark_fn.lit(0).cast(result_type),
                 )
-                # We check for overflow by seeing if max or min divided by the right argument is greater than the
-                # left argument.
                 .when(
                     snowpark_fn.abs(snowpark_args[0])
-                    > (
+                    > (bound / snowpark_fn.abs(snowpark_args[1])),
                     snowpark_fn.lit(None),
-                )
+                )
+                .otherwise(
+                    (snowpark_args[0] * snowpark_args[1]).cast(result_type)
+                )
             )
-            result_exp = _type_with_typer(result_exp)
         case (
             (DecimalType(), _IntegralType())
             | (
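A worked instance of the division-based overflow test for LongType, in plain Python (`/` standing in for the SQL division the expression performs):

    LONG_MIN, LONG_MAX = -(2**63), 2**63 - 1

    def would_overflow(a: int, b: int) -> bool:
        if a == 0 or b == 0:
            return False  # zero products are short-circuited first, as above
        same_sign = (a > 0) == (b > 0)
        bound = LONG_MAX if same_sign else -LONG_MIN - 1
        # Compare |a| against bound / |b| so the check never computes a * b.
        return abs(a) > bound / abs(b)

    assert would_overflow(3_000_000_000, 4_000_000_000)  # 1.2e19 > LONG_MAX
    assert not would_overflow(3_000_000_000, 2)          # 6e9 fits comfortably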
@@ -9310,6 +9433,7 @@ def map_unresolved_function(
                     overflow_possible,
                     False,
                     result_type,
+                    "multiply",
                 )
             case (_NumericType(), _NumericType()):
                 result_exp = snowpark_args[0] * snowpark_args[1]
@@ -9457,10 +9581,15 @@ def map_unresolved_function(
                     .otherwise(snowpark_args[0] - snowpark_args[1])
                 )
             case _:
-                result_exp = _try_arithmetic_helper(
+                result_exp, result_type = _try_arithmetic_helper(
                     snowpark_typed_args, snowpark_args, 1
                 )
-                result_exp = _type_with_typer(result_exp)
+                if result_type is not None:
+                    result_exp = TypedColumn(
+                        result_exp, lambda rt=result_type: [rt]
+                    )
+                else:
+                    result_exp = _type_with_typer(result_exp)
         case "try_to_number":
             try_to_number = snowpark_fn.function("try_to_number")
             precision, scale = resolve_to_number_precision_and_scale(exp)
@@ -11178,7 +11307,7 @@ def _try_sum_helper(
            input_types=[arg_type],
        )
        # call the udaf
-       return _try_sum_int_udaf(col_name),
+       return _try_sum_int_udaf(col_name), LongType()

    # NOTE: We will never call this function with an IntegerType column and calculating_avg=True. Therefore,
    # we don't need to handle the case where calculating_avg=True here. The caller of this function will handle it.
@@ -11402,8 +11531,15 @@ def _arithmetic_operation(
     op: Callable[[Column, Column], Column],
     overflow_possible: bool,
     should_raise_on_overflow: bool,
-    target_type:
+    target_type: DataType,
+    operation_name: str,
 ) -> Column:
+    if isinstance(target_type, _IntegralType):
+        raw_result = op(arg1.col, arg2.col)
+        return apply_arithmetic_overflow_with_ansi_check(
+            raw_result, target_type, should_raise_on_overflow, operation_name
+        )
+
     def _cast_arg(tc: TypedColumn) -> Column:
         _, s = _get_type_precision(tc.typ)
         typ = (
@@ -11490,12 +11626,12 @@ def _get_decimal_division_result_type(p1, s1, p2, s2) -> tuple[DecimalType, bool]:

 def _try_arithmetic_helper(
     typed_args: List[TypedColumn], snowpark_args: List[Column], operation_type: int
-) -> Column:
+) -> tuple[Column, DataType | None]:
     # Constructs a Snowpark Column expression for a "try-style" arithmetic operation
     # (addition or subtraction, determined by `operation_type`) between two input columns.
     #
     # Key behavioral characteristics:
-    # 1. For **Integral inputs**: Explicitly checks for
+    # 1. For **Integral inputs**: Explicitly checks for overflow/underflow at the result type boundaries.
     #    - BEHAVIOR: Returns a NULL literal if the operation would exceed these limits;
     #      otherwise, returns the result of the standard Snowpark `+` or `-`.
     #
@@ -11508,10 +11644,11 @@ def _try_arithmetic_helper(
     # Arithmetic operations involving **Boolean types** will raise an `AnalysisException`.
     # All other unhandled incompatible type combinations result in a NULL literal.
     # The function returns the resulting Snowpark Column expression.
-
     match (typed_args[0].typ, typed_args[1].typ):
-        case (_IntegralType(), _IntegralType()):
-
+        case (_IntegralType() as t1, _IntegralType() as t2):
+            result_type = _find_common_type([t1, t2])
+            min_val, max_val = get_integral_type_bounds(result_type)
+
             if operation_type == 0:  # Addition
                 result_exp = (
                     snowpark_fn.when(
@@ -11519,20 +11656,20 @@ def _try_arithmetic_helper(
                         & (snowpark_args[1] > 0)
                         & (
                             snowpark_args[0]
-                            > snowpark_fn.lit(
+                            > snowpark_fn.lit(max_val) - snowpark_args[1]
                         ),
-                        snowpark_fn.lit(None),
+                        snowpark_fn.lit(None),
                     )
                     .when(
                         (snowpark_args[0] < 0)
                         & (snowpark_args[1] < 0)
                         & (
                             snowpark_args[0]
-                            < snowpark_fn.lit(
+                            < snowpark_fn.lit(min_val) - snowpark_args[1]
                         ),
-                        snowpark_fn.lit(None),
+                        snowpark_fn.lit(None),
                     )
-                    .otherwise(snowpark_args[0] + snowpark_args[1])
+                    .otherwise((snowpark_args[0] + snowpark_args[1]).cast(result_type))
                 )
             else:  # Subtraction
                 result_exp = (
@@ -11541,22 +11678,22 @@ def _try_arithmetic_helper(
                         & (snowpark_args[1] < 0)
                         & (
                             snowpark_args[0]
-                            > snowpark_fn.lit(
+                            > snowpark_fn.lit(max_val) + snowpark_args[1]
                         ),
-                        snowpark_fn.lit(None),
+                        snowpark_fn.lit(None),
                     )
                     .when(
                         (snowpark_args[0] < 0)
                         & (snowpark_args[1] > 0)
                         & (
                             snowpark_args[0]
-                            < snowpark_fn.lit(
+                            < snowpark_fn.lit(min_val) + snowpark_args[1]
                         ),
-                        snowpark_fn.lit(None),
+                        snowpark_fn.lit(None),
                     )
-                    .otherwise(snowpark_args[0] - snowpark_args[1])
+                    .otherwise((snowpark_args[0] - snowpark_args[1]).cast(result_type))
                 )
-            return result_exp
+            return result_exp, result_type
         case (DateType(), _) | (_, DateType()):
             arg1, arg2 = typed_args[0].typ, typed_args[1].typ
             # Valid input parameter types for try_add - DateType and _NumericType, _NumericType and DateType.
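The addition and subtraction guards above, restated in plain Python to make the rearrangement explicit: a + b > MAX is tested as a > MAX - b (and symmetrically for the other sign combinations), so the test itself can never overflow. LongType bounds are assumed for the sketch:

    LONG_MIN, LONG_MAX = -(2**63), 2**63 - 1

    def try_add(a: int, b: int) -> int | None:
        if a > 0 and b > 0 and a > LONG_MAX - b:
            return None  # positive overflow
        if a < 0 and b < 0 and a < LONG_MIN - b:
            return None  # negative overflow
        return a + b

    def try_subtract(a: int, b: int) -> int | None:
        if a > 0 and b < 0 and a > LONG_MAX + b:
            return None
        if a < 0 and b > 0 and a < LONG_MIN + b:
            return None
        return a - b

    assert try_add(LONG_MAX, 1) is None
    assert try_subtract(LONG_MIN, 1) is None
    assert try_add(LONG_MAX, -1) == LONG_MAX - 1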
@@ -11577,22 +11714,28 @@ def _try_arithmetic_helper(
                 if isinstance(arg1, _IntegralType)
                 else snowpark_args
             )
-            return
-
-
-
+            return (
+                _try_to_cast(
+                    "try_to_date",
+                    snowpark_fn.cast(snowpark_fn.date_add(*args), DateType()),
+                    args[0],
+                ),
+                None,
             )
         else:
             if isinstance(arg1, DateType) and isinstance(arg2, _IntegralType):
-                return
-
-
-                snowpark_fn.
+                return (
+                    _try_to_cast(
+                        "try_to_date",
+                        snowpark_fn.to_date(
+                            snowpark_fn.date_sub(snowpark_args[0], snowpark_args[1])
+                        ),
+                        snowpark_args[0],
                     ),
-
+                    None,
                 )
             elif isinstance(arg1, DateType) and isinstance(arg2, DateType):
-                return snowpark_fn.daydiff(snowpark_args[0], snowpark_args[1])
+                return snowpark_fn.daydiff(snowpark_args[0], snowpark_args[1]), None
             else:
                 exception = AnalysisException(
                     '[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "date_sub(dt, sub)" due to data type mismatch: Parameter 1 requires the "DATE" type and parameter 2 requires the ("INT" or "SMALLINT" or "TINYINT") type'
@@ -11609,12 +11752,16 @@ def _try_arithmetic_helper(
                 "try_add" if operation_type == 0 else "try_subtract",
             )

-            return
-
-
-
-
-
+            return (
+                _arithmetic_operation(
+                    typed_args[0],
+                    typed_args[1],
+                    lambda x, y: x + y if operation_type == 0 else x - y,
+                    overflow_possible,
+                    False,
+                    result_type,
+                    "add" if operation_type == 0 else "subtract",
+                ),
                 result_type,
             )

@@ -11622,11 +11769,12 @@ def _try_arithmetic_helper(
         # matches Spark and goes to inf.
         # Note that we already handle the int,int case above, hence it is okay to use the broader _numeric
         # below.
-        case (_NumericType(), _NumericType()):
+        case (_NumericType() as t1, _NumericType() as t2):
+            result_type = _find_common_type([t1, t2])
             if operation_type == 0:
-                return snowpark_args[0] + snowpark_args[1]
+                return snowpark_args[0] + snowpark_args[1], result_type
             else:
-                return snowpark_args[0] - snowpark_args[1]
+                return snowpark_args[0] - snowpark_args[1], result_type
         # String cases - try to convert to numeric
         case (
             (StringType(), _NumericType())
@@ -11642,12 +11790,12 @@ def _try_arithmetic_helper(
                 updated_args = _validate_numeric_args(
                     "try_add", typed_args, snowpark_args
                 )
-                return updated_args[0] + updated_args[1]
+                return updated_args[0] + updated_args[1], None
             else:
                 updated_args = _validate_numeric_args(
                     "try_subtract", typed_args, snowpark_args
                 )
-                return updated_args[0] - updated_args[1]
+                return updated_args[0] - updated_args[1], None

         case (BooleanType(), _) | (_, BooleanType()):
             exception = AnalysisException(
@@ -11657,7 +11805,7 @@ def _try_arithmetic_helper(
             raise exception
         case _:
             # Return NULL for incompatible types
-            return snowpark_fn.lit(None)
+            return snowpark_fn.lit(None), None


 def _get_add_sub_result_type(
@@ -11853,24 +12001,9 @@ def _get_literal_param_name(exp, arg_index: int, default_param_name: str):


 def _raise_error_helper(return_type: DataType, error_class=None):
-
-        f":{error_class.__name__}"
-        if error_class and hasattr(error_class, "__name__")
-        else ""
-    )
-
-    def _raise_fn(*msgs: Column) -> Column:
-        return snowpark_fn.cast(
-            snowpark_fn.abs(
-                snowpark_fn.concat(
-                    snowpark_fn.lit(f"[snowpark-connect-exception{error_class}]"),
-                    *(msg.try_cast(StringType()) for msg in msgs),
-                )
-            ).cast(StringType()),
-            return_type,
-        )
+    from snowflake.snowpark_connect.expression.error_utils import raise_error_helper

-    return
+    return raise_error_helper(return_type, error_class)


 def _divnull(dividend: Column, divisor: Column) -> Column:
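The removed body documents the trick that now lives in error_utils.raise_error_helper: encode the failure as a "[snowpark-connect-exception:Class]message" string and route it through ABS(), which fails at evaluation time with that string embedded in Snowflake's error text. A hypothetical client-side decoder for that marker (the real counterpart is not shown in this diff):

    import re

    # Marker format taken from the removed f-string above; the decoder is hypothetical.
    MARKER = re.compile(r"\[snowpark-connect-exception(?::(?P<cls>\w+))?\](?P<msg>[^']*)")

    def decode_deferred_error(error_text: str) -> tuple[str | None, str] | None:
        m = MARKER.search(error_text)
        return (m.group("cls"), m.group("msg")) if m else None

    assert decode_deferred_error(
        "Numeric value '[snowpark-connect-exception:ValueError]bad input' is not recognized"
    ) == ("ValueError", "bad input")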
Binary files (the .jar changes listed above) are not shown.