snowpark-connect 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/config.py +19 -14
- snowflake/snowpark_connect/error/error_utils.py +32 -0
- snowflake/snowpark_connect/error/exceptions.py +4 -0
- snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
- snowflake/snowpark_connect/expression/literal.py +9 -12
- snowflake/snowpark_connect/expression/map_cast.py +20 -4
- snowflake/snowpark_connect/expression/map_expression.py +8 -1
- snowflake/snowpark_connect/expression/map_udf.py +4 -4
- snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
- snowflake/snowpark_connect/expression/map_unresolved_function.py +269 -134
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
- snowflake/snowpark_connect/relation/map_aggregate.py +154 -18
- snowflake/snowpark_connect/relation/map_column_ops.py +59 -8
- snowflake/snowpark_connect/relation/map_extension.py +58 -24
- snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +3 -1
- snowflake/snowpark_connect/relation/map_row_ops.py +30 -1
- snowflake/snowpark_connect/relation/map_sql.py +40 -196
- snowflake/snowpark_connect/relation/map_udtf.py +4 -4
- snowflake/snowpark_connect/relation/read/map_read.py +2 -1
- snowflake/snowpark_connect/relation/read/map_read_json.py +12 -1
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
- snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
- snowflake/snowpark_connect/relation/read/utils.py +7 -6
- snowflake/snowpark_connect/relation/utils.py +170 -1
- snowflake/snowpark_connect/relation/write/map_write.py +306 -87
- snowflake/snowpark_connect/server.py +34 -5
- snowflake/snowpark_connect/type_mapping.py +6 -2
- snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
- snowflake/snowpark_connect/utils/env_utils.py +55 -0
- snowflake/snowpark_connect/utils/session.py +21 -4
- snowflake/snowpark_connect/utils/telemetry.py +213 -61
- snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/__init__.py +0 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
- snowflake/snowpark_decoder/dp_session.py +111 -0
- snowflake/snowpark_decoder/spark_decoder.py +76 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/METADATA +2 -2
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/RECORD +55 -44
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/top_level.txt +1 -0
- spark/__init__.py +0 -0
- spark/connect/__init__.py +0 -0
- spark/connect/envelope_pb2.py +31 -0
- spark/connect/envelope_pb2.pyi +46 -0
- snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/NOTICE-binary +0 -0
@@ -28,6 +28,7 @@ from google.protobuf.message import Message
 from pyspark.errors.exceptions.base import (
     AnalysisException,
     ArithmeticException,
+    ArrayIndexOutOfBoundsException,
     DateTimeException,
     IllegalArgumentException,
     NumberFormatException,

@@ -39,6 +40,7 @@ from pyspark.sql.types import _parse_datatype_json_string
 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
 from snowflake.snowpark import Column, Session
+from snowflake.snowpark._internal.analyzer.expression import Literal
 from snowflake.snowpark._internal.analyzer.unary_expression import Alias
 from snowflake.snowpark.types import (
     ArrayType,

@@ -139,7 +141,7 @@ from snowflake.snowpark_connect.utils.xxhash64 import (
 MAX_UINT64 = 2**64 - 1
 MAX_INT64 = 2**63 - 1
 MIN_INT64 = -(2**63)
-
+MAX_ARRAY_SIZE = 2_147_483_647

 NAN, INFINITY = float("nan"), float("inf")

@@ -638,37 +640,22 @@ def map_unresolved_function(
                 [arg.typ for arg in snowpark_typed_args]
             )
         case "/":
-
-
-
-
-
-
-
-
-
-                snowpark_args
-
-
-
-
-
-
-                and isinstance(snowpark_typed_args[1].typ, _IntegralType)
-                or isinstance(snowpark_typed_args[0].typ, _IntegralType)
-                and isinstance(snowpark_typed_args[1].typ, DecimalType)
-            ):
-                result_exp, (
-                    return_type_precision,
-                    return_type_scale,
-                ) = _mul_div_precision_helper(snowpark_typed_args, snowpark_args, 1)
-                result_type = DecimalType(return_type_precision, return_type_scale)
-            else:
-                # Perform division directly
-                result_exp = _divnull(snowpark_args[0], snowpark_args[1])
-                result_type = _find_common_type(
-                    [arg.typ for arg in snowpark_typed_args]
-                )
+            match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (DecimalType(), t) | (t, DecimalType()) if isinstance(
+                    t, DecimalType
+                ) or isinstance(t, _IntegralType) or isinstance(
+                    snowpark_typed_args[1].typ, NullType
+                ):
+                    result_exp, (
+                        return_type_precision,
+                        return_type_scale,
+                    ) = _mul_div_precision_helper(snowpark_typed_args, snowpark_args, 1)
+                    result_type = DecimalType(return_type_precision, return_type_scale)
+                case _:
+                    result_type = DoubleType()
+                    dividend = snowpark_args[0].cast(result_type)
+                    divisor = snowpark_args[1].cast(result_type)
+                    result_exp = _divnull(dividend, divisor)
         case "~":
             result_exp = TypedColumn(
                 snowpark_fn.bitnot(snowpark_args[0]),

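For context, the rewritten `/` branch dispatches on the operand types with a structural `match`. A standalone sketch of the same dispatch (stand-in classes and a simplified guard; the real code uses Snowpark's type objects and `_mul_div_precision_helper` to compute the decimal precision/scale):

```python
# Illustrative sketch only; DecimalType/IntegralType/NullType here are
# stand-ins, not the real Snowpark types.
from dataclasses import dataclass

@dataclass
class DecimalType:
    precision: int = 38
    scale: int = 0

class IntegralType: ...
class NullType: ...

def divide_result_type(left: object, right: object) -> str:
    match (left, right):
        # Decimal-involved division stays DECIMAL with recomputed precision.
        case (DecimalType(), t) | (t, DecimalType()) if isinstance(
            t, (DecimalType, IntegralType, NullType)
        ):
            return "decimal"  # precision/scale via _mul_div_precision_helper
        # Everything else: both operands cast to DOUBLE, then divided.
        case _:
            return "double"

print(divide_result_type(DecimalType(10, 2), IntegralType()))  # decimal
print(divide_result_type(IntegralType(), IntegralType()))      # double
```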
@@ -1120,7 +1107,7 @@ def map_unresolved_function(
             result_exp = TypedColumn(
                 result_exp, lambda: [ArrayType(snowpark_typed_args[0].typ)]
             )
-        case "array_size"
+        case "array_size":
             array_type = snowpark_typed_args[0].typ
             if not isinstance(array_type, ArrayType):
                 raise AnalysisException(

@@ -1129,6 +1116,16 @@ def map_unresolved_function(
             result_exp = TypedColumn(
                 snowpark_fn.array_size(*snowpark_args), lambda: [LongType()]
             )
+        case "cardinality":
+            arg_type = snowpark_typed_args[0].typ
+            if isinstance(arg_type, (ArrayType, MapType)):
+                result_exp = TypedColumn(
+                    snowpark_fn.size(*snowpark_args), lambda: [LongType()]
+                )
+            else:
+                raise AnalysisException(
+                    f"Expected argument '{snowpark_arg_names[0]}' to have an ArrayType or MapType, but got {arg_type.simpleString()}."
+                )
         case "array_sort":
             result_exp = TypedColumn(
                 snowpark_fn.array_sort(*snowpark_args),

@@ -1201,35 +1198,18 @@ def map_unresolved_function(
                 snowpark_fn.asinh(snowpark_args[0]), lambda: [DoubleType()]
             )
         case "assert_true":
+            result_type = NullType()
+            raise_error = _raise_error_helper(result_type)

-            @cached_udf(
-                input_types=[BooleanType()],
-                return_type=StringType(),
-            )
-            def _assert_true_single(expr):
-                if not expr:
-                    raise ValueError("assertion failed")
-                return None
-
-            @cached_udf(
-                input_types=[BooleanType(), StringType()],
-                return_type=StringType(),
-            )
-            def _assert_true_with_message(expr, message):
-                if not expr:
-                    raise ValueError(message)
-                return None
-
-            # Handle different argument counts using match pattern
             match snowpark_args:
                 case [expr]:
-                    result_exp =
-
-                    )
+                    result_exp = snowpark_fn.when(
+                        expr, snowpark_fn.lit(None)
+                    ).otherwise(raise_error(snowpark_fn.lit("assertion failed")))
                 case [expr, message]:
-                    result_exp =
-
-                    )
+                    result_exp = snowpark_fn.when(
+                        expr, snowpark_fn.lit(None)
+                    ).otherwise(raise_error(snowpark_fn.cast(message, StringType())))
                 case _:
                     raise AnalysisException(
                         f"[WRONG_NUM_ARGS.WITHOUT_SUGGESTION] The `assert_true` requires 1 or 2 parameters but the actual number is {len(snowpark_args)}."

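The new `assert_true` lowering drops the two cached UDFs entirely: it compiles to a conditional column whose failure branch is the shared raise-error column from `_raise_error_helper`. A condensed sketch of the shape (not the full dispatcher; `raise_error` is assumed to be the helper's returned callable):

```python
# cond -> CASE WHEN cond THEN NULL ELSE <error column> END
import snowflake.snowpark.functions as snowpark_fn
from snowflake.snowpark import Column
from snowflake.snowpark.types import StringType

def lower_assert_true(raise_error, cond: Column, message: Column | None = None) -> Column:
    msg = (
        snowpark_fn.lit("assertion failed")
        if message is None
        else snowpark_fn.cast(message, StringType())
    )
    return snowpark_fn.when(cond, snowpark_fn.lit(None)).otherwise(raise_error(msg))
```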
@@ -1325,10 +1305,35 @@ def map_unresolved_function(
             )
             result_exp = TypedColumn(result_exp, lambda: [LongType()])
         case "bit_get" | "getbit":
-
-
-                bit_get_function(*snowpark_args), lambda: [LongType()]
+            snowflake_compat = get_boolean_session_config_param(
+                "enable_snowflake_extension_behavior"
             )
+            col, pos = snowpark_args
+            if snowflake_compat:
+                bit_get_function = snowpark_fn.function("GETBIT")(col, pos)
+            else:
+                raise_error = _raise_error_helper(LongType())
+                bit_get_function = snowpark_fn.when(
+                    (snowpark_fn.lit(0) <= pos) & (pos <= snowpark_fn.lit(63))
+                    | snowpark_fn.is_null(pos),
+                    snowpark_fn.function("GETBIT")(col, pos),
+                ).otherwise(
+                    raise_error(
+                        snowpark_fn.concat(
+                            snowpark_fn.lit(
+                                "Invalid bit position: ",
+                            ),
+                            snowpark_fn.cast(
+                                pos,
+                                StringType(),
+                            ),
+                            snowpark_fn.lit(
+                                " exceeds the bit upper limit",
+                            ),
+                        )
+                    )
+                )
+            result_exp = TypedColumn(bit_get_function, lambda: [LongType()])
         case "bit_length":
             bit_length_function = snowpark_fn.function("bit_length")
             result_exp = TypedColumn(

@@ -2291,31 +2296,32 @@ def map_unresolved_function(
             )
         case "elt":
             n = snowpark_args[0]
-
             values = snowpark_fn.array_construct(*snowpark_args[1:])

             if spark_sql_ansi_enabled:
-
-
-                    input_types=[IntegerType()],
-                    return_type=StringType(),
+                raise_error = _raise_error_helper(
+                    StringType(), error_class=ArrayIndexOutOfBoundsException
                 )
-                def _raise_out_of_bounds_error(n: int) -> str:
-                    raise ValueError(
-                        f"ArrayIndexOutOfBoundsException: {n} is not within the input bounds."
-                    )
-
                 values_size = snowpark_fn.lit(len(snowpark_args) - 1)

                 result_exp = (
                     snowpark_fn.when(snowpark_fn.is_null(n), snowpark_fn.lit(None))
                     .when(
                         (snowpark_fn.lit(1) <= n) & (n <= values_size),
-                        snowpark_fn.
-
+                        snowpark_fn.cast(
+                            snowpark_fn.get(
+                                values, snowpark_fn.nvl(n - 1, snowpark_fn.lit(0))
+                            ),
+                            StringType(),
                         ),
                     )
-                    .otherwise(
+                    .otherwise(
+                        raise_error(
+                            snowpark_fn.lit("[INVALID_ARRAY_INDEX] The index "),
+                            snowpark_fn.cast(n, StringType()),
+                            snowpark_fn.lit(" is out of bounds."),
+                        )
+                    )
                 )
             else:
                 result_exp = snowpark_fn.when(

@@ -3289,7 +3295,7 @@ def map_unresolved_function(
             ).cast(LongType())
             result_type = LongType()
         case "hll_union_agg":
-            raise_error =
+            raise_error = _raise_error_helper(BinaryType())
             args = exp.unresolved_function.arguments
             allow_different_lgConfigK = len(args) == 2 and unwrap_literal(args[1])
             spark_function_name = f"{function_name}({snowpark_arg_names[0]}, {str(allow_different_lgConfigK).lower()})"

@@ -3329,7 +3335,7 @@ def map_unresolved_function(
                     SELECT arg1 as x)
                 """,
             )
-            raise_error =
+            raise_error = _raise_error_helper(BinaryType())
             args = exp.unresolved_function.arguments
             allow_different_lgConfigK = len(args) == 3 and unwrap_literal(args[2])
             spark_function_name = f"{function_name}({snowpark_arg_names[0]}, {snowpark_arg_names[1]}, {str(allow_different_lgConfigK).lower()})"

@@ -3816,7 +3822,13 @@ def map_unresolved_function(
         case "locate":
             substr = unwrap_literal(exp.unresolved_function.arguments[0])
             value = snowpark_args[1]
-
+            if len(exp.unresolved_function.arguments) == 3:
+                start_pos = unwrap_literal(exp.unresolved_function.arguments[2])
+            else:
+                # start_pos is an optional argument and if not provided we should default to 1.
+                # This path will only be reached by spark connect scala clients.
+                start_pos = 1
+            spark_function_name = f"locate({', '.join(snowpark_arg_names)}, 1)"

             if start_pos > 0:
                 result_exp = snowpark_fn.locate(substr, value, start_pos)

@@ -4053,7 +4065,7 @@ def map_unresolved_function(
                     snowpark_fn.is_null(snowpark_args[i]),
                     # udf execution on XP seems to be lazy, so this should only run when there is a null key
                     # otherwise there should be no udf env setup or execution
-
+                    _raise_error_helper(VariantType())(
                         snowpark_fn.lit(
                             "[NULL_MAP_KEY] Cannot use null as map key."
                         )

@@ -4115,6 +4127,14 @@ def map_unresolved_function(
             )
             result_type = MapType(key_type, value_type)
         case "map_contains_key":
+            if isinstance(snowpark_typed_args[0].typ, NullType):
+                raise AnalysisException(
+                    f"""[DATATYPE_MISMATCH.MAP_FUNCTION_DIFF_TYPES] Cannot resolve "map_contains_key({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: Input to `map_contains_key` should have been "MAP" followed by a value with same key type, but it's ["VOID", "INT"]."""
+                )
+            if isinstance(snowpark_typed_args[1].typ, NullType):
+                raise AnalysisException(
+                    f"""[DATATYPE_MISMATCH.NULL_TYPE] Cannot resolve "map_contains_key({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: Null typed values cannot be used as arguments of `map_contains_key`."""
+                )
             args = (
                 [snowpark_args[1], snowpark_args[0]]
                 if isinstance(snowpark_typed_args[0].typ, MapType)

@@ -4244,17 +4264,37 @@ def map_unresolved_function(

             last_win_dedup = global_config.spark_sql_mapKeyDedupPolicy == "LAST_WIN"

-
-
-
-            snowpark_fn.
-
-
-
-
-
-
+            # Check if any entry has a NULL key
+            has_null_key = (
+                snowpark_fn.function("array_size")(
+                    snowpark_fn.function("filter")(
+                        snowpark_args[0],
+                        snowpark_fn.sql_expr(f"e -> e:{key_field} IS NULL"),
+                    )
+                )
+                > 0
+            )
+
+            # Create error UDF for NULL keys (same pattern as map function)
+            null_key_error = _raise_error_helper(VariantType())(
+                snowpark_fn.lit("[NULL_MAP_KEY] Cannot use null as map key.")
+            )
+
+            # Create the reduce operation
+            reduce_result = snowpark_fn.function("reduce")(
+                snowpark_args[0],
+                snowpark_fn.object_construct(),
+                snowpark_fn.sql_expr(
+                    # value_field is cast to variant because object_insert doesn't allow structured types,
+                    # and structured types are not coercible to variant
+                    # TODO: allow structured types in object_insert?
+                    f"(acc, e) -> object_insert(acc, e:{key_field}, e:{value_field}::variant, {last_win_dedup})"
                 ),
+            )
+
+            # Use conditional logic: if there are NULL keys, throw error; otherwise proceed with reduce
+            result_exp = snowpark_fn.cast(
+                snowpark_fn.when(has_null_key, null_key_error).otherwise(reduce_result),
                 MapType(key_type, value_type),
             )
             result_type = MapType(key_type, value_type)

@@ -4273,23 +4313,35 @@ def map_unresolved_function(
             # TODO: implement in Snowflake/Snowpark
             # technically this could be done with a lateral join, but it's probably not worth the effort
             arg_type = snowpark_typed_args[0].typ
-            if not isinstance(arg_type, MapType):
+            if not isinstance(arg_type, (MapType, NullType)):
                 raise AnalysisException(
                     f"map_values requires a MapType argument, got {arg_type}"
                 )

             def _map_values(obj: dict) -> list:
-
+                if obj is None:
+                    return None
+                return list(obj.values())

             map_values = cached_udf(
                 _map_values, return_type=ArrayType(), input_types=[StructType()]
             )

-
-
-
-
-
+            # Handle NULL input directly at expression level
+            if isinstance(arg_type, NullType):
+                # If input is NULL literal, return NULL
+                result_exp = snowpark_fn.lit(None)
+                result_type = ArrayType(NullType())
+            else:
+                result_exp = snowpark_fn.when(
+                    snowpark_args[0].is_null(), snowpark_fn.lit(None)
+                ).otherwise(
+                    snowpark_fn.cast(
+                        map_values(snowpark_fn.cast(snowpark_args[0], StructType())),
+                        ArrayType(arg_type.value_type),
+                    )
+                )
+                result_type = ArrayType(arg_type.value_type)
         case "mask":

             number_of_args = len(snowpark_args)

@@ -5194,7 +5246,7 @@ def map_unresolved_function(
             result_type = DoubleType()
         case "raise_error":
             result_type = StringType()
-            raise_error =
+            raise_error = _raise_error_helper(result_type)
             result_exp = raise_error(*snowpark_args)
         case "rand" | "random":
             # Snowpark random() generates a 64 bit signed integer, but pyspark is [0.0, 1.0).

@@ -5279,7 +5331,7 @@ def map_unresolved_function(
                     snowpark_args[2],
                 ),
             ),
-
+            _raise_error_helper(StringType())(
                 snowpark_fn.lit(
                     "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract` is invalid."
                 )

@@ -5329,7 +5381,7 @@ def map_unresolved_function(
                     idx,
                 )
             ),
-
+            _raise_error_helper(ArrayType(StringType()))(
                 snowpark_fn.lit(
                     "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract_all` is invalid."
                 )

@@ -5485,9 +5537,27 @@ def map_unresolved_function(
             ):
                 result_exp = snowpark_fn.lit(None)
             else:
+                right_expr = snowpark_fn.right(*snowpark_args)
+                if isinstance(snowpark_typed_args[0].typ, TimestampType):
+                    # Spark format is always displayed as YYY-MM-DD HH:mm:ss.FF6
+                    # When microseconds are equal to 0 .FF6 part is removed
+                    # When microseconds are equal to 0 at the end, they are removed i.e. .123000 -> .123 when displayed
+
+                    formated_timestamp = snowpark_fn.to_varchar(
+                        snowpark_args[0], "YYYY-MM-DD HH:MI:SS.FF6"
+                    )
+                    right_expr = snowpark_fn.right(
+                        snowpark_fn.regexp_replace(
+                            snowpark_fn.regexp_replace(formated_timestamp, "0+$", ""),
+                            "\\.$",
+                            "",
+                        ),
+                        snowpark_args[1],
+                    )
+
                 result_exp = snowpark_fn.when(
                     snowpark_args[1] <= 0, snowpark_fn.lit("")
-                ).otherwise(
+                ).otherwise(right_expr)
                 result_type = StringType()
         case "rint":
             result_exp = snowpark_fn.cast(

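The two nested `regexp_replace` calls implement Spark's display rule for fractional seconds (strip trailing zeros, then a dangling dot). The same transformation in plain Python, as a quick check:

```python
import re

def trim_fractional_seconds(ts: str) -> str:
    # 1) strip trailing zeros of the .FF6 fraction, 2) drop a now-dangling ".".
    return re.sub(r"\.$", "", re.sub(r"0+$", "", ts))

print(trim_fractional_seconds("2024-01-02 03:04:05.123000"))  # 2024-01-02 03:04:05.123
print(trim_fractional_seconds("2024-01-02 03:04:05.000000"))  # 2024-01-02 03:04:05
```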
@@ -6043,7 +6113,7 @@ def map_unresolved_function(
             result_exp = snowpark_fn.skew(snowpark_fn.lit(None))
             result_type = DoubleType()
         case "slice":
-            raise_error =
+            raise_error = _raise_error_helper(snowpark_typed_args[0].typ)
             spark_index = snowpark_args[1]
             arr_size = snowpark_fn.array_size(snowpark_args[0])
             slice_len = snowpark_args[2]

@@ -6113,10 +6183,11 @@ def map_unresolved_function(
             result_exp = snowpark_fn.lit(0)
             result_type = LongType()
         case "split":
+            result_type = ArrayType(StringType())

             @cached_udf(
                 input_types=[StringType(), StringType(), IntegerType()],
-                return_type=
+                return_type=result_type,
             )
             def _split(
                 input: Optional[str], pattern: Optional[str], limit: Optional[int]

@@ -6124,34 +6195,80 @@ def map_unresolved_function(
                 if input is None or pattern is None:
                     return None

+                import re
+
+                try:
+                    re.compile(pattern)
+                except re.error:
+                    raise ValueError(
+                        f"Failed to split string, provided pattern: {pattern} is invalid"
+                    )
+
                 if limit == 1:
                     return [input]

-
+                if not input:
+                    return []

                 # A default of -1 is passed in PySpark, but RE needs it to be 0 to provide all splits.
                 # In PySpark, the limit also indicates the max size of the resulting array, but in RE
                 # the remainder is returned as another element.
                 maxsplit = limit - 1 if limit > 0 else 0

-                split_result = re.split(pattern, input, maxsplit)
                 if len(pattern) == 0:
-
-
+                    return list(input) if limit <= 0 else list(input)[:limit]
+
+                match pattern:
+                    case "|":
+                        split_result = re.split(pattern, input, 0)
+                        input_limit = limit + 1 if limit > 0 else len(split_result)
+                        return (
+                            split_result
+                            if input_limit == 0
+                            else split_result[1:input_limit]
+                        )
+                    case "$":
+                        return [input, ""] if maxsplit >= 0 else [input]
+                    case "^":
+                        return [input]
+                    case _:
+                        return re.split(pattern, input, maxsplit)

-
+            def split_string(str_: Column, pattern: Column, limit: Column):
+                native_split = _split(str_, pattern, limit)
+                # When pattern is a literal and doesn't contain any regex special characters
+                # And when limit is less than or equal to 0
+                # Native Snowflake Split function is used to optimise performance
+                if isinstance(pattern._expression, Literal):
+                    pattern_value = pattern._expression.value
+
+                    if pattern_value is None:
+                        return snowpark_fn.lit(None)
+
+                    is_regexp = re.match(
+                        ".*[\\[\\.\\]\\*\\?\\+\\^\\$\\{\\}\\|\\(\\)\\\\].*",
+                        pattern_value,
+                    )
+                    is_empty = len(pattern_value) == 0
+
+                    if not is_empty and not is_regexp:
+                        return snowpark_fn.when(
+                            limit <= 0,
+                            snowpark_fn.split(str_, pattern).cast(result_type),
+                        ).otherwise(native_split)
+
+                return native_split

             match snowpark_args:
                 case [str_, pattern]:
                     spark_function_name = (
                         f"split({snowpark_arg_names[0]}, {snowpark_arg_names[1]}, -1)"
                     )
-                    result_exp =
+                    result_exp = split_string(str_, pattern, snowpark_fn.lit(-1))
                 case [str_, pattern, limit]:  # noqa: F841
-                    result_exp =
+                    result_exp = split_string(str_, pattern, limit)
                 case _:
                     raise ValueError(f"Invalid number of arguments to {function_name}")
-            result_type = ArrayType(StringType())
         case "split_part":
             result_exp = snowpark_fn.call_function("split_part", *snowpark_args)
             result_type = StringType()

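The dedicated `case` arms for `"|"`, `"$"`, and `"^"` exist because Python's `re.split` handles these empty-alternation and zero-width patterns differently from Java's `String.split`, which Spark follows. The Python results the special cases correct for:

```python
import re

print(re.split("|", "abc"))  # ['', 'a', 'b', 'c', ''] -- extra empty elements
print(re.split("^", "abc"))  # ['', 'abc']             -- leading empty element
print(re.split("$", "abc"))  # ['abc', '']
```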
@@ -6671,6 +6788,7 @@ def map_unresolved_function(
                 if value == "" or any(
                     c in value for c in [",", "\n", "\r", '"', "'"]
                 ):
+                    value = value.replace("\\", "\\\\").replace('"', '\\"')
                     result.append(f'"{value}"')
                 else:
                     result.append(value)

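The added line escapes backslashes and embedded double quotes before the field is wrapped in quotes. A quick check of the transformation:

```python
value = 'a "quoted" \\ value'
escaped = value.replace("\\", "\\\\").replace('"', '\\"')
print(f'"{escaped}"')  # "a \"quoted\" \\ value"
```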
@@ -6920,7 +7038,20 @@ def map_unresolved_function(
             result_type = StringType()
         case "trunc":
             part = unwrap_literal(exp.unresolved_function.arguments[1])
-            if part is None
+            part = None if part is None else part.lower()
+
+            allowed_parts = {
+                "year",
+                "yyyy",
+                "yy",
+                "month",
+                "mon",
+                "mm",
+                "week",
+                "quarter",
+            }
+
+            if part not in allowed_parts:
                 result_exp = snowpark_fn.lit(None)
             else:
                 result_exp = _try_to_cast(

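`trunc` now normalizes the unit to lowercase and gates it against an explicit allow-list, so unsupported units return NULL instead of reaching the cast. The same gate in isolation:

```python
allowed_parts = {"year", "yyyy", "yy", "month", "mon", "mm", "week", "quarter"}

for part in ("MM", "Quarter", "dd", None):
    norm = None if part is None else part.lower()
    print(part, "->", "truncate" if norm in allowed_parts else "NULL")
# MM -> truncate, Quarter -> truncate, dd -> NULL, None -> NULL
```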
@@ -7311,7 +7442,7 @@ def map_unresolved_function(
                     )
                 )
             )
-            raise_fn =
+            raise_fn = _raise_error_helper(BinaryType(), IllegalArgumentException)
             result_exp = (
                 snowpark_fn.when(unbase_arg.is_null(), snowpark_fn.lit(None))
                 .when(result_exp.is_null(), raise_fn(snowpark_fn.lit("Invalid input")))

@@ -7927,16 +8058,8 @@ def _handle_current_timestamp():


 def _equivalent_decimal(type):
-
-
-            return DecimalType(3, 0)
-        case ShortType():
-            return DecimalType(5, 0)
-        case IntegerType():
-            return DecimalType(10, 0)
-        case LongType():
-            return DecimalType(20, 0)
-    return DecimalType(38, 0)
+    (precision, scale) = _get_type_precision(type)
+    return DecimalType(precision, scale)

@@ -8995,7 +9118,9 @@ def _get_type_precision(typ: DataType) -> tuple[int, int]:
         case IntegerType():
             return 10, 0  # -2147483648 to 2147483647
         case LongType():
-            return
+            return 20, 0  # -9223372036854775808 to 9223372036854775807
+        case NullType():
+            return 6, 2  # NULL
         case _:
             return 38, 0  # Default to maximum precision for other types

@@ -9210,16 +9335,12 @@ def _try_arithmetic_helper(
         typed_args[1].typ, DecimalType
     ):
         new_scale = s2
-        new_precision = (
-            p1 + s2 + 1
-        )  # Integral precision + decimal scale + 1 for carry
+        new_precision = max(p2, p1 + s2)
     elif isinstance(typed_args[0].typ, DecimalType) and isinstance(
         typed_args[1].typ, _IntegralType
     ):
         new_scale = s1
-        new_precision = (
-            p2 + s1 + 1
-        )  # Integral precision + decimal scale + 1 for carry
+        new_precision = max(p1, p2 + s1)
     else:
         # Both decimal types
         if operation_type == 1 and s1 == s2:  # subtraction with matching scales

@@ -9298,13 +9419,13 @@ def _add_sub_precision_helper(
         typed_args[1].typ, DecimalType
     ):
         new_scale = s2
-        new_precision = p1 + s2
+        new_precision = max(p2, p1 + s2)
         return_type_precision, return_type_scale = new_precision, new_scale
     elif isinstance(typed_args[0].typ, DecimalType) and isinstance(
         typed_args[1].typ, _IntegralType
     ):
         new_scale = s1
-        new_precision = p2 + s1
+        new_precision = max(p1, p2 + s1)
         return_type_precision, return_type_scale = new_precision, new_scale
     else:
         (

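Both precision helpers swap `p_int + s_dec` for `max(p_dec, p_int + s_dec)`, so a wide decimal operand no longer loses precision when combined with a narrow integral one. A worked check:

```python
# int (precision 10, scale 0) combined with DECIMAL(38, 2):
p_int = 10
p_dec, s_dec = 38, 2
old_precision = p_int + s_dec              # 12 -- narrower than the decimal itself
new_precision = max(p_dec, p_int + s_dec)  # 38 -- keeps the wider operand
print(old_precision, new_precision)
```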
@@ -9386,11 +9507,25 @@
     )


-def
-
-
+def _raise_error_helper(return_type: DataType, error_class=None):
+    error_class = (
+        f":{error_class.__name__}"
+        if error_class and hasattr(error_class, "__name__")
+        else ""
+    )
+
+    def _raise_fn(*msgs: Column) -> Column:
+        return snowpark_fn.cast(
+            snowpark_fn.abs(
+                snowpark_fn.concat(
+                    snowpark_fn.lit(f"[snowpark-connect-exception{error_class}]"),
+                    *(msg.try_cast(StringType()) for msg in msgs),
+                )
+            ).cast(StringType()),
+            return_type,
+        )

-    return
+    return _raise_fn


 def _divnull(dividend: Column, divisor: Column) -> Column:

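`_raise_error_helper` is the primitive behind all the `raise_error = ...` call sites above: it returns a callable that concatenates a tagged prefix with the message columns and routes the result through `ABS()`, presumably so that evaluating the column fails at runtime with the message embedded in the error, while the outer cast gives the column whatever type the surrounding branch expects. A pure-Python illustration of the message tagging (the mapping of the tagged prefix back to PySpark exception classes is assumed to live in the package's error utilities):

```python
def tag(message: str, error_class: str | None = None) -> str:
    # Mirrors the f-string in _raise_error_helper above (illustrative only).
    suffix = f":{error_class}" if error_class else ""
    return f"[snowpark-connect-exception{suffix}]{message}"

print(tag("Invalid input", "IllegalArgumentException"))
# [snowpark-connect-exception:IllegalArgumentException]Invalid input
```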