snowpark-connect 0.21.0__py3-none-any.whl → 0.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/config.py +19 -3
- snowflake/snowpark_connect/error/error_utils.py +25 -0
- snowflake/snowpark_connect/expression/map_udf.py +4 -4
- snowflake/snowpark_connect/expression/map_unresolved_function.py +203 -128
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/relation/map_aggregate.py +102 -18
- snowflake/snowpark_connect/relation/map_column_ops.py +21 -2
- snowflake/snowpark_connect/relation/map_map_partitions.py +3 -1
- snowflake/snowpark_connect/relation/map_sql.py +18 -191
- snowflake/snowpark_connect/relation/map_udtf.py +4 -4
- snowflake/snowpark_connect/relation/read/map_read_json.py +12 -1
- snowflake/snowpark_connect/relation/read/reader_config.py +1 -0
- snowflake/snowpark_connect/relation/write/map_write.py +68 -24
- snowflake/snowpark_connect/server.py +9 -0
- snowflake/snowpark_connect/type_mapping.py +4 -0
- snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
- snowflake/snowpark_connect/utils/session.py +0 -4
- snowflake/snowpark_connect/utils/telemetry.py +213 -61
- snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/__init__.py +0 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
- snowflake/snowpark_decoder/dp_session.py +111 -0
- snowflake/snowpark_decoder/spark_decoder.py +76 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/METADATA +2 -2
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/RECORD +40 -29
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/top_level.txt +1 -0
- spark/__init__.py +0 -0
- spark/connect/__init__.py +0 -0
- spark/connect/envelope_pb2.py +31 -0
- spark/connect/envelope_pb2.pyi +46 -0
- snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.22.1.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.22.1.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.22.1.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/NOTICE-binary +0 -0
snowflake/snowpark_connect/expression/map_unresolved_function.py

@@ -28,6 +28,7 @@ from google.protobuf.message import Message
 from pyspark.errors.exceptions.base import (
     AnalysisException,
     ArithmeticException,
+    ArrayIndexOutOfBoundsException,
     DateTimeException,
     IllegalArgumentException,
     NumberFormatException,
@@ -39,6 +40,7 @@ from pyspark.sql.types import _parse_datatype_json_string
 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
 from snowflake.snowpark import Column, Session
+from snowflake.snowpark._internal.analyzer.expression import Literal
 from snowflake.snowpark._internal.analyzer.unary_expression import Alias
 from snowflake.snowpark.types import (
     ArrayType,
@@ -139,7 +141,7 @@ from snowflake.snowpark_connect.utils.xxhash64 import (
 MAX_UINT64 = 2**64 - 1
 MAX_INT64 = 2**63 - 1
 MIN_INT64 = -(2**63)
-
+MAX_ARRAY_SIZE = 2_147_483_647
 
 NAN, INFINITY = float("nan"), float("inf")
 
@@ -638,37 +640,22 @@ def map_unresolved_function(
                 [arg.typ for arg in snowpark_typed_args]
             )
         case "/":
-
-
-
-
-
-
-
-
-
-                snowpark_args
-
-
-
-
-
-
-                and isinstance(snowpark_typed_args[1].typ, _IntegralType)
-                or isinstance(snowpark_typed_args[0].typ, _IntegralType)
-                and isinstance(snowpark_typed_args[1].typ, DecimalType)
-            ):
-                result_exp, (
-                    return_type_precision,
-                    return_type_scale,
-                ) = _mul_div_precision_helper(snowpark_typed_args, snowpark_args, 1)
-                result_type = DecimalType(return_type_precision, return_type_scale)
-            else:
-                # Perform division directly
-                result_exp = _divnull(snowpark_args[0], snowpark_args[1])
-                result_type = _find_common_type(
-                    [arg.typ for arg in snowpark_typed_args]
-                )
+            match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (DecimalType(), t) | (t, DecimalType()) if isinstance(
+                    t, DecimalType
+                ) or isinstance(t, _IntegralType) or isinstance(
+                    snowpark_typed_args[1].typ, NullType
+                ):
+                    result_exp, (
+                        return_type_precision,
+                        return_type_scale,
+                    ) = _mul_div_precision_helper(snowpark_typed_args, snowpark_args, 1)
+                    result_type = DecimalType(return_type_precision, return_type_scale)
+                case _:
+                    result_type = DoubleType()
+                    dividend = snowpark_args[0].cast(result_type)
+                    divisor = snowpark_args[1].cast(result_type)
+                    result_exp = _divnull(dividend, divisor)
         case "~":
             result_exp = TypedColumn(
                 snowpark_fn.bitnot(snowpark_args[0]),
@@ -1201,35 +1188,18 @@ def map_unresolved_function(
                 snowpark_fn.asinh(snowpark_args[0]), lambda: [DoubleType()]
             )
         case "assert_true":
+            result_type = NullType()
+            raise_error = _raise_error_helper(result_type)
 
-            @cached_udf(
-                input_types=[BooleanType()],
-                return_type=StringType(),
-            )
-            def _assert_true_single(expr):
-                if not expr:
-                    raise ValueError("assertion failed")
-                return None
-
-            @cached_udf(
-                input_types=[BooleanType(), StringType()],
-                return_type=StringType(),
-            )
-            def _assert_true_with_message(expr, message):
-                if not expr:
-                    raise ValueError(message)
-                return None
-
-            # Handle different argument counts using match pattern
             match snowpark_args:
                 case [expr]:
-                    result_exp =
-
-                    )
+                    result_exp = snowpark_fn.when(
+                        expr, snowpark_fn.lit(None)
+                    ).otherwise(raise_error(snowpark_fn.lit("assertion failed")))
                 case [expr, message]:
-                    result_exp =
-
-                    )
+                    result_exp = snowpark_fn.when(
+                        expr, snowpark_fn.lit(None)
+                    ).otherwise(raise_error(snowpark_fn.cast(message, StringType())))
                 case _:
                     raise AnalysisException(
                         f"[WRONG_NUM_ARGS.WITHOUT_SUGGESTION] The `assert_true` requires 1 or 2 parameters but the actual number is {len(snowpark_args)}."
@@ -2291,31 +2261,32 @@ def map_unresolved_function(
             )
         case "elt":
             n = snowpark_args[0]
-
             values = snowpark_fn.array_construct(*snowpark_args[1:])
 
             if spark_sql_ansi_enabled:
-
-
-                    input_types=[IntegerType()],
-                    return_type=StringType(),
+                raise_error = _raise_error_helper(
+                    StringType(), error_class=ArrayIndexOutOfBoundsException
                 )
-                def _raise_out_of_bounds_error(n: int) -> str:
-                    raise ValueError(
-                        f"ArrayIndexOutOfBoundsException: {n} is not within the input bounds."
-                    )
-
                 values_size = snowpark_fn.lit(len(snowpark_args) - 1)
 
                 result_exp = (
                     snowpark_fn.when(snowpark_fn.is_null(n), snowpark_fn.lit(None))
                     .when(
                         (snowpark_fn.lit(1) <= n) & (n <= values_size),
-                        snowpark_fn.
-
+                        snowpark_fn.cast(
+                            snowpark_fn.get(
+                                values, snowpark_fn.nvl(n - 1, snowpark_fn.lit(0))
+                            ),
+                            StringType(),
                         ),
                     )
-                    .otherwise(
+                    .otherwise(
+                        raise_error(
+                            snowpark_fn.lit("[INVALID_ARRAY_INDEX] The index "),
+                            snowpark_fn.cast(n, StringType()),
+                            snowpark_fn.lit(" is out of bounds."),
+                        )
+                    )
                 )
             else:
                 result_exp = snowpark_fn.when(
@@ -3289,7 +3260,7 @@ def map_unresolved_function(
             ).cast(LongType())
             result_type = LongType()
         case "hll_union_agg":
-            raise_error =
+            raise_error = _raise_error_helper(BinaryType())
             args = exp.unresolved_function.arguments
             allow_different_lgConfigK = len(args) == 2 and unwrap_literal(args[1])
             spark_function_name = f"{function_name}({snowpark_arg_names[0]}, {str(allow_different_lgConfigK).lower()})"
@@ -3329,7 +3300,7 @@ def map_unresolved_function(
                 SELECT arg1 as x)
                 """,
             )
-            raise_error =
+            raise_error = _raise_error_helper(BinaryType())
             args = exp.unresolved_function.arguments
             allow_different_lgConfigK = len(args) == 3 and unwrap_literal(args[2])
             spark_function_name = f"{function_name}({snowpark_arg_names[0]}, {snowpark_arg_names[1]}, {str(allow_different_lgConfigK).lower()})"
@@ -4053,7 +4024,7 @@ def map_unresolved_function(
                     snowpark_fn.is_null(snowpark_args[i]),
                     # udf execution on XP seems to be lazy, so this should only run when there is a null key
                     # otherwise there should be no udf env setup or execution
-
+                    _raise_error_helper(VariantType())(
                         snowpark_fn.lit(
                             "[NULL_MAP_KEY] Cannot use null as map key."
                         )
@@ -4115,6 +4086,14 @@ def map_unresolved_function(
             )
             result_type = MapType(key_type, value_type)
         case "map_contains_key":
+            if isinstance(snowpark_typed_args[0].typ, NullType):
+                raise AnalysisException(
+                    f"""[DATATYPE_MISMATCH.MAP_FUNCTION_DIFF_TYPES] Cannot resolve "map_contains_key({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: Input to `map_contains_key` should have been "MAP" followed by a value with same key type, but it's ["VOID", "INT"]."""
+                )
+            if isinstance(snowpark_typed_args[1].typ, NullType):
+                raise AnalysisException(
+                    f"""[DATATYPE_MISMATCH.NULL_TYPE] Cannot resolve "map_contains_key({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: Null typed values cannot be used as arguments of `map_contains_key`."""
+                )
             args = (
                 [snowpark_args[1], snowpark_args[0]]
                 if isinstance(snowpark_typed_args[0].typ, MapType)
@@ -4244,17 +4223,37 @@ def map_unresolved_function(
 
             last_win_dedup = global_config.spark_sql_mapKeyDedupPolicy == "LAST_WIN"
 
-
-
-
-            snowpark_fn.
-
-
-
-
-
-
+            # Check if any entry has a NULL key
+            has_null_key = (
+                snowpark_fn.function("array_size")(
+                    snowpark_fn.function("filter")(
+                        snowpark_args[0],
+                        snowpark_fn.sql_expr(f"e -> e:{key_field} IS NULL"),
+                    )
+                )
+                > 0
+            )
+
+            # Create error UDF for NULL keys (same pattern as map function)
+            null_key_error = _raise_error_helper(VariantType())(
+                snowpark_fn.lit("[NULL_MAP_KEY] Cannot use null as map key.")
+            )
+
+            # Create the reduce operation
+            reduce_result = snowpark_fn.function("reduce")(
+                snowpark_args[0],
+                snowpark_fn.object_construct(),
+                snowpark_fn.sql_expr(
+                    # value_field is cast to variant because object_insert doesn't allow structured types,
+                    # and structured types are not coercible to variant
+                    # TODO: allow structured types in object_insert?
+                    f"(acc, e) -> object_insert(acc, e:{key_field}, e:{value_field}::variant, {last_win_dedup})"
                 ),
+            )
+
+            # Use conditional logic: if there are NULL keys, throw error; otherwise proceed with reduce
+            result_exp = snowpark_fn.cast(
+                snowpark_fn.when(has_null_key, null_key_error).otherwise(reduce_result),
                 MapType(key_type, value_type),
             )
             result_type = MapType(key_type, value_type)
@@ -4273,23 +4272,35 @@ def map_unresolved_function(
             # TODO: implement in Snowflake/Snowpark
             # technically this could be done with a lateral join, but it's probably not worth the effort
             arg_type = snowpark_typed_args[0].typ
-            if not isinstance(arg_type, MapType):
+            if not isinstance(arg_type, (MapType, NullType)):
                 raise AnalysisException(
                     f"map_values requires a MapType argument, got {arg_type}"
                 )
 
             def _map_values(obj: dict) -> list:
-
+                if obj is None:
+                    return None
+                return list(obj.values())
 
             map_values = cached_udf(
                 _map_values, return_type=ArrayType(), input_types=[StructType()]
             )
 
-
-
-
-
-
+            # Handle NULL input directly at expression level
+            if isinstance(arg_type, NullType):
+                # If input is NULL literal, return NULL
+                result_exp = snowpark_fn.lit(None)
+                result_type = ArrayType(NullType())
+            else:
+                result_exp = snowpark_fn.when(
+                    snowpark_args[0].is_null(), snowpark_fn.lit(None)
+                ).otherwise(
+                    snowpark_fn.cast(
+                        map_values(snowpark_fn.cast(snowpark_args[0], StructType())),
+                        ArrayType(arg_type.value_type),
+                    )
+                )
+                result_type = ArrayType(arg_type.value_type)
         case "mask":
 
             number_of_args = len(snowpark_args)
@@ -5194,7 +5205,7 @@ def map_unresolved_function(
             result_type = DoubleType()
         case "raise_error":
             result_type = StringType()
-            raise_error =
+            raise_error = _raise_error_helper(result_type)
             result_exp = raise_error(*snowpark_args)
         case "rand" | "random":
             # Snowpark random() generates a 64 bit signed integer, but pyspark is [0.0, 1.0).
@@ -5279,7 +5290,7 @@ def map_unresolved_function(
                     snowpark_args[2],
                 ),
             ),
-
+            _raise_error_helper(StringType())(
                 snowpark_fn.lit(
                     "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract` is invalid."
                 )
@@ -5329,7 +5340,7 @@ def map_unresolved_function(
                         idx,
                     )
                 ),
-
+                _raise_error_helper(ArrayType(StringType()))(
                     snowpark_fn.lit(
                         "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract_all` is invalid."
                     )
@@ -6043,7 +6054,7 @@ def map_unresolved_function(
             result_exp = snowpark_fn.skew(snowpark_fn.lit(None))
             result_type = DoubleType()
         case "slice":
-            raise_error =
+            raise_error = _raise_error_helper(snowpark_typed_args[0].typ)
             spark_index = snowpark_args[1]
             arr_size = snowpark_fn.array_size(snowpark_args[0])
             slice_len = snowpark_args[2]
@@ -6113,10 +6124,11 @@ def map_unresolved_function(
             result_exp = snowpark_fn.lit(0)
             result_type = LongType()
         case "split":
+            result_type = ArrayType(StringType())
 
             @cached_udf(
                 input_types=[StringType(), StringType(), IntegerType()],
-                return_type=
+                return_type=result_type,
             )
             def _split(
                 input: Optional[str], pattern: Optional[str], limit: Optional[int]
@@ -6124,34 +6136,80 @@ def map_unresolved_function(
                 if input is None or pattern is None:
                     return None
 
+                import re
+
+                try:
+                    re.compile(pattern)
+                except re.error:
+                    raise ValueError(
+                        f"Failed to split string, provided pattern: {pattern} is invalid"
+                    )
+
                 if limit == 1:
                     return [input]
 
-
+                if not input:
+                    return []
 
                 # A default of -1 is passed in PySpark, but RE needs it to be 0 to provide all splits.
                 # In PySpark, the limit also indicates the max size of the resulting array, but in RE
                 # the remainder is returned as another element.
                 maxsplit = limit - 1 if limit > 0 else 0
 
-                split_result = re.split(pattern, input, maxsplit)
                 if len(pattern) == 0:
-
-                    split_result = split_result[1 : len(split_result) - 1]
+                    return list(input) if limit <= 0 else list(input)[:limit]
 
-
+                match pattern:
+                    case "|":
+                        split_result = re.split(pattern, input, 0)
+                        input_limit = limit + 1 if limit > 0 else len(split_result)
+                        return (
+                            split_result
+                            if input_limit == 0
+                            else split_result[1:input_limit]
+                        )
+                    case "$":
+                        return [input, ""] if maxsplit >= 0 else [input]
+                    case "^":
+                        return [input]
+                    case _:
+                        return re.split(pattern, input, maxsplit)
+
+            def split_string(str_: Column, pattern: Column, limit: Column):
+                native_split = _split(str_, pattern, limit)
+                # When pattern is a literal and doesn't contain any regex special characters
+                # And when limit is less than or equal to 0
+                # Native Snowflake Split function is used to optimise performance
+                if isinstance(pattern._expression, Literal):
+                    pattern_value = pattern._expression.value
+
+                    if pattern_value is None:
+                        return snowpark_fn.lit(None)
+
+                    is_regexp = re.match(
+                        ".*[\\[\\.\\]\\*\\?\\+\\^\\$\\{\\}\\|\\(\\)\\\\].*",
+                        pattern_value,
+                    )
+                    is_empty = len(pattern_value) == 0
+
+                    if not is_empty and not is_regexp:
+                        return snowpark_fn.when(
+                            limit <= 0,
+                            snowpark_fn.split(str_, pattern).cast(result_type),
+                        ).otherwise(native_split)
+
+                return native_split
 
             match snowpark_args:
                 case [str_, pattern]:
                     spark_function_name = (
                         f"split({snowpark_arg_names[0]}, {snowpark_arg_names[1]}, -1)"
                     )
-                    result_exp =
+                    result_exp = split_string(str_, pattern, snowpark_fn.lit(-1))
                 case [str_, pattern, limit]:  # noqa: F841
-                    result_exp =
+                    result_exp = split_string(str_, pattern, limit)
                 case _:
                     raise ValueError(f"Invalid number of arguments to {function_name}")
-            result_type = ArrayType(StringType())
         case "split_part":
             result_exp = snowpark_fn.call_function("split_part", *snowpark_args)
             result_type = StringType()
@@ -6920,7 +6978,20 @@ def map_unresolved_function(
             result_type = StringType()
         case "trunc":
             part = unwrap_literal(exp.unresolved_function.arguments[1])
-            if part is None
+            part = None if part is None else part.lower()
+
+            allowed_parts = {
+                "year",
+                "yyyy",
+                "yy",
+                "month",
+                "mon",
+                "mm",
+                "week",
+                "quarter",
+            }
+
+            if part not in allowed_parts:
                 result_exp = snowpark_fn.lit(None)
             else:
                 result_exp = _try_to_cast(
@@ -7311,7 +7382,7 @@ def map_unresolved_function(
                     )
                 )
             )
-            raise_fn =
+            raise_fn = _raise_error_helper(BinaryType(), IllegalArgumentException)
            result_exp = (
                 snowpark_fn.when(unbase_arg.is_null(), snowpark_fn.lit(None))
                 .when(result_exp.is_null(), raise_fn(snowpark_fn.lit("Invalid input")))
@@ -7927,16 +7998,8 @@ def _handle_current_timestamp():
 
 
 def _equivalent_decimal(type):
-
-
-            return DecimalType(3, 0)
-        case ShortType():
-            return DecimalType(5, 0)
-        case IntegerType():
-            return DecimalType(10, 0)
-        case LongType():
-            return DecimalType(20, 0)
-    return DecimalType(38, 0)
+    (precision, scale) = _get_type_precision(type)
+    return DecimalType(precision, scale)
 
 
 def _resolve_decimal_and_numeric(type1: DecimalType, type2: _NumericType) -> DataType:
@@ -8995,7 +9058,9 @@ def _get_type_precision(typ: DataType) -> tuple[int, int]:
         case IntegerType():
             return 10, 0  # -2147483648 to 2147483647
         case LongType():
-            return
+            return 20, 0  # -9223372036854775808 to 9223372036854775807
+        case NullType():
+            return 6, 2  # NULL
         case _:
             return 38, 0  # Default to maximum precision for other types
 
@@ -9210,16 +9275,12 @@ def _try_arithmetic_helper(
             typed_args[1].typ, DecimalType
         ):
             new_scale = s2
-            new_precision = (
-                p1 + s2 + 1
-            )  # Integral precision + decimal scale + 1 for carry
+            new_precision = max(p2, p1 + s2)
         elif isinstance(typed_args[0].typ, DecimalType) and isinstance(
             typed_args[1].typ, _IntegralType
         ):
             new_scale = s1
-            new_precision = (
-                p2 + s1 + 1
-            )  # Integral precision + decimal scale + 1 for carry
+            new_precision = max(p1, p2 + s1)
         else:
             # Both decimal types
             if operation_type == 1 and s1 == s2:  # subtraction with matching scales
@@ -9298,13 +9359,13 @@ def _add_sub_precision_helper(
             typed_args[1].typ, DecimalType
         ):
             new_scale = s2
-            new_precision = p1 + s2
+            new_precision = max(p2, p1 + s2)
             return_type_precision, return_type_scale = new_precision, new_scale
         elif isinstance(typed_args[0].typ, DecimalType) and isinstance(
             typed_args[1].typ, _IntegralType
         ):
             new_scale = s1
-            new_precision = p2 + s1
+            new_precision = max(p1, p2 + s1)
             return_type_precision, return_type_scale = new_precision, new_scale
         else:
             (
@@ -9386,11 +9447,25 @@ def _mul_div_precision_helper(
     )
 
 
-def
-
-
+def _raise_error_helper(return_type: DataType, error_class=None):
+    error_class = (
+        f":{error_class.__name__}"
+        if error_class and hasattr(error_class, "__name__")
+        else ""
+    )
+
+    def _raise_fn(*msgs: Column) -> Column:
+        return snowpark_fn.cast(
+            snowpark_fn.abs(
+                snowpark_fn.concat(
+                    snowpark_fn.lit(f"[snowpark-connect-exception{error_class}]"),
+                    *(msg.try_cast(StringType()) for msg in msgs),
+                )
+            ).cast(StringType()),
+            return_type,
+        )
 
-    return
+    return _raise_fn
 
 
 def _divnull(dividend: Column, divisor: Column) -> Column: