snowpark-connect 0.20.2__py3-none-any.whl → 0.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +3 -2
- snowflake/snowpark_connect/column_name_handler.py +6 -65
- snowflake/snowpark_connect/config.py +47 -17
- snowflake/snowpark_connect/dataframe_container.py +242 -0
- snowflake/snowpark_connect/error/error_utils.py +25 -0
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +13 -23
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +9 -5
- snowflake/snowpark_connect/expression/map_extension.py +2 -1
- snowflake/snowpark_connect/expression/map_udf.py +4 -4
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +8 -7
- snowflake/snowpark_connect/expression/map_unresolved_function.py +481 -170
- snowflake/snowpark_connect/expression/map_unresolved_star.py +8 -8
- snowflake/snowpark_connect/expression/map_update_fields.py +1 -1
- snowflake/snowpark_connect/expression/typer.py +6 -6
- snowflake/snowpark_connect/proto/control_pb2.py +17 -16
- snowflake/snowpark_connect/proto/control_pb2.pyi +17 -17
- snowflake/snowpark_connect/proto/control_pb2_grpc.py +12 -63
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +15 -14
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +19 -14
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +27 -26
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +74 -68
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +5 -5
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +25 -17
- snowflake/snowpark_connect/relation/map_aggregate.py +170 -61
- snowflake/snowpark_connect/relation/map_catalog.py +2 -2
- snowflake/snowpark_connect/relation/map_column_ops.py +227 -145
- snowflake/snowpark_connect/relation/map_crosstab.py +25 -6
- snowflake/snowpark_connect/relation/map_extension.py +81 -56
- snowflake/snowpark_connect/relation/map_join.py +72 -63
- snowflake/snowpark_connect/relation/map_local_relation.py +35 -20
- snowflake/snowpark_connect/relation/map_map_partitions.py +24 -17
- snowflake/snowpark_connect/relation/map_relation.py +22 -16
- snowflake/snowpark_connect/relation/map_row_ops.py +232 -146
- snowflake/snowpark_connect/relation/map_sample_by.py +15 -8
- snowflake/snowpark_connect/relation/map_show_string.py +42 -5
- snowflake/snowpark_connect/relation/map_sql.py +141 -237
- snowflake/snowpark_connect/relation/map_stats.py +88 -39
- snowflake/snowpark_connect/relation/map_subquery_alias.py +13 -14
- snowflake/snowpark_connect/relation/map_udtf.py +10 -13
- snowflake/snowpark_connect/relation/read/map_read.py +8 -3
- snowflake/snowpark_connect/relation/read/map_read_csv.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_json.py +19 -8
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_socket.py +7 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +25 -16
- snowflake/snowpark_connect/relation/read/map_read_text.py +7 -7
- snowflake/snowpark_connect/relation/read/reader_config.py +1 -0
- snowflake/snowpark_connect/relation/utils.py +11 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +15 -12
- snowflake/snowpark_connect/relation/write/map_write.py +259 -56
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +3 -2
- snowflake/snowpark_connect/server.py +43 -4
- snowflake/snowpark_connect/type_mapping.py +6 -23
- snowflake/snowpark_connect/utils/cache.py +27 -22
- snowflake/snowpark_connect/utils/context.py +33 -17
- snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
- snowflake/snowpark_connect/utils/{attribute_handling.py → identifiers.py} +47 -0
- snowflake/snowpark_connect/utils/session.py +41 -38
- snowflake/snowpark_connect/utils/telemetry.py +214 -63
- snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/__init__.py +0 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
- snowflake/snowpark_decoder/dp_session.py +111 -0
- snowflake/snowpark_decoder/spark_decoder.py +76 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/METADATA +6 -4
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/RECORD +83 -69
- snowpark_connect-0.22.1.dist-info/licenses/LICENSE-binary +568 -0
- snowpark_connect-0.22.1.dist-info/licenses/NOTICE-binary +1533 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/top_level.txt +1 -0
- spark/__init__.py +0 -0
- spark/connect/__init__.py +0 -0
- spark/connect/envelope_pb2.py +31 -0
- spark/connect/envelope_pb2.pyi +46 -0
- snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
- {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
snowflake/snowpark_connect/expression/map_unresolved_function.py

@@ -28,6 +28,7 @@ from google.protobuf.message import Message
 from pyspark.errors.exceptions.base import (
     AnalysisException,
     ArithmeticException,
+    ArrayIndexOutOfBoundsException,
     DateTimeException,
     IllegalArgumentException,
     NumberFormatException,
@@ -39,6 +40,7 @@ from pyspark.sql.types import _parse_datatype_json_string
 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
 from snowflake.snowpark import Column, Session
+from snowflake.snowpark._internal.analyzer.expression import Literal
 from snowflake.snowpark._internal.analyzer.unary_expression import Alias
 from snowflake.snowpark.types import (
     ArrayType,
@@ -69,7 +71,10 @@ from snowflake.snowpark_connect.column_name_handler import (
     ColumnNameMap,
     set_schema_getter,
 )
-from snowflake.snowpark_connect.config import
+from snowflake.snowpark_connect.config import (
+    get_boolean_session_config_param,
+    global_config,
+)
 from snowflake.snowpark_connect.constants import (
     DUPLICATE_KEY_FOUND_ERROR_TEMPLATE,
     SPARK_TZ_ABBREVIATIONS_OVERRIDES,
@@ -100,6 +105,7 @@ from snowflake.snowpark_connect.typed_column import (
 )
 from snowflake.snowpark_connect.utils.context import (
     add_sql_aggregate_function,
+    get_current_grouping_columns,
     get_is_aggregate_function,
     get_is_evaluating_sql,
     get_is_in_udtf_context,
@@ -135,7 +141,7 @@ from snowflake.snowpark_connect.utils.xxhash64 import (
 MAX_UINT64 = 2**64 - 1
 MAX_INT64 = 2**63 - 1
 MIN_INT64 = -(2**63)
-
+MAX_ARRAY_SIZE = 2_147_483_647
 
 NAN, INFINITY = float("nan"), float("inf")
 
@@ -341,6 +347,9 @@ def map_unresolved_function(
     )
     spark_col_names = []
     spark_sql_ansi_enabled = global_config.spark_sql_ansi_enabled
+    spark_sql_legacy_allow_hash_on_map_type = (
+        global_config.spark_sql_legacy_allowHashOnMapType
+    )
 
     function_name = exp.unresolved_function.function_name.lower()
     telemetry.report_function_usage(function_name)
@@ -631,37 +640,22 @@ def map_unresolved_function(
                 [arg.typ for arg in snowpark_typed_args]
             )
         case "/":
-
-
-
-
-
-
-
-
-
-                snowpark_args
-
-
-
-
-
-
-                and isinstance(snowpark_typed_args[1].typ, _IntegralType)
-                or isinstance(snowpark_typed_args[0].typ, _IntegralType)
-                and isinstance(snowpark_typed_args[1].typ, DecimalType)
-            ):
-                result_exp, (
-                    return_type_precision,
-                    return_type_scale,
-                ) = _mul_div_precision_helper(snowpark_typed_args, snowpark_args, 1)
-                result_type = DecimalType(return_type_precision, return_type_scale)
-            else:
-                # Perform division directly
-                result_exp = _divnull(snowpark_args[0], snowpark_args[1])
-                result_type = _find_common_type(
-                    [arg.typ for arg in snowpark_typed_args]
-                )
+            match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (DecimalType(), t) | (t, DecimalType()) if isinstance(
+                    t, DecimalType
+                ) or isinstance(t, _IntegralType) or isinstance(
+                    snowpark_typed_args[1].typ, NullType
+                ):
+                    result_exp, (
+                        return_type_precision,
+                        return_type_scale,
+                    ) = _mul_div_precision_helper(snowpark_typed_args, snowpark_args, 1)
+                    result_type = DecimalType(return_type_precision, return_type_scale)
+                case _:
+                    result_type = DoubleType()
+                    dividend = snowpark_args[0].cast(result_type)
+                    divisor = snowpark_args[1].cast(result_type)
+                    result_exp = _divnull(dividend, divisor)
         case "~":
             result_exp = TypedColumn(
                 snowpark_fn.bitnot(snowpark_args[0]),
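
Note on the new "/" mapping above: operands that involve DECIMAL (paired with another DECIMAL, an integral, or NULL) keep a decimal result computed by _mul_div_precision_helper, while everything else is cast to DOUBLE before _divnull. A minimal PySpark sketch of the observable behavior (session and endpoint names here are illustrative, not part of this release):

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    df = spark.range(1).select(
        (F.lit(7) / F.lit(2)).alias("int_div"),                        # integral / integral -> DOUBLE 3.5
        (F.lit(7).cast("decimal(10,2)") / F.lit(2)).alias("dec_div"),  # DECIMAL operand -> DECIMAL result
    )
    df.printSchema()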
@@ -867,14 +861,30 @@ def map_unresolved_function(
             )
         case "approx_percentile" | "percentile_approx":
             # SNOW-1955784: Support accuracy parameter
+            # Use percentile_disc to return actual values from dataset (matches PySpark behavior)
 
-
-
-
-
-
+            def _pyspark_approx_percentile(
+                column: Column, percentage: float, original_type: DataType
+            ) -> Column:
+                """
+                PySpark-compatible percentile that returns actual values from dataset.
+                - PySpark's approx_percentile returns the "smallest value in the ordered col values
+                  such that no more than percentage of col values is less than or equal to that value"
+                - This means it MUST return an actual value from the original dataset
+                - Snowflake's approx_percentile() may interpolate between values, breaking compatibility
+                - percentile_disc() returns discrete values (actual dataset values), matching PySpark
+                """
+                # Even though the Spark function accepts a Column for percentage, it will fail unless it's a literal.
+                # Therefore, we can do error checking right here.
+                if not 0.0 <= percentage <= 1.0:
                     raise AnalysisException("percentage must be between [0.0, 1.0]")
-
+
+                result = snowpark_fn.function("percentile_disc")(
+                    snowpark_fn.lit(percentage)
+                ).within_group(column)
+                return snowpark_fn.cast(result, original_type)
+
+            column_type = snowpark_typed_args[0].typ
 
             if isinstance(snowpark_typed_args[1].typ, ArrayType):
                 # Snowpark doesn't accept a list of percentile values.
@@ -882,26 +892,26 @@ def map_unresolved_function(
                 array_func = exp.unresolved_function.arguments[1].unresolved_function
                 assert array_func.function_name == "array", array_func
 
-
-
-
-
-
-
-
-                )
+                percentile_results = [
+                    _pyspark_approx_percentile(
+                        snowpark_args[0], unwrap_literal(arg), column_type
+                    )
+                    for arg in array_func.arguments
+                ]
+
+                result_type = ArrayType(element_type=column_type, contains_null=False)
                 result_exp = snowpark_fn.cast(
-
-
+                    snowpark_fn.array_construct(*percentile_results),
+                    result_type,
                 )
-                result_type = ArrayType(element_type=DoubleType(), contains_null=False)
             else:
+                # Handle single percentile
+                percentage = unwrap_literal(exp.unresolved_function.arguments[1])
                 result_exp = TypedColumn(
-
-                    snowpark_args[0],
-                    _check_percentage(exp.unresolved_function.arguments[1]),
+                    _pyspark_approx_percentile(
+                        snowpark_args[0], percentage, column_type
                     ),
-                    lambda: [
+                    lambda: [column_type],
                 )
         case "array":
             if len(snowpark_args) == 0:
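
Note on the percentile_disc-based mapping above: approx_percentile / percentile_approx should now return actual values from the column rather than interpolated ones, and the result keeps the input column's type. A small illustrative PySpark example (dataframe contents are hypothetical):

    df = spark.createDataFrame([(1,), (2,), (3,), (4,)], ["v"])
    df.select(F.percentile_approx("v", 0.5).alias("p50")).show()          # expect an element of v, e.g. 2
    df.select(F.percentile_approx("v", [0.25, 0.75]).alias("p")).show()   # list form maps to an ARRAY result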
@@ -1178,35 +1188,18 @@ def map_unresolved_function(
                 snowpark_fn.asinh(snowpark_args[0]), lambda: [DoubleType()]
             )
         case "assert_true":
+            result_type = NullType()
+            raise_error = _raise_error_helper(result_type)
 
-            @cached_udf(
-                input_types=[BooleanType()],
-                return_type=StringType(),
-            )
-            def _assert_true_single(expr):
-                if not expr:
-                    raise ValueError("assertion failed")
-                return None
-
-            @cached_udf(
-                input_types=[BooleanType(), StringType()],
-                return_type=StringType(),
-            )
-            def _assert_true_with_message(expr, message):
-                if not expr:
-                    raise ValueError(message)
-                return None
-
-            # Handle different argument counts using match pattern
             match snowpark_args:
                 case [expr]:
-                    result_exp =
-
-                    )
+                    result_exp = snowpark_fn.when(
+                        expr, snowpark_fn.lit(None)
+                    ).otherwise(raise_error(snowpark_fn.lit("assertion failed")))
                 case [expr, message]:
-                    result_exp =
-
-                    )
+                    result_exp = snowpark_fn.when(
+                        expr, snowpark_fn.lit(None)
+                    ).otherwise(raise_error(snowpark_fn.cast(message, StringType())))
                 case _:
                     raise AnalysisException(
                         f"[WRONG_NUM_ARGS.WITHOUT_SUGGESTION] The `assert_true` requires 1 or 2 parameters but the actual number is {len(snowpark_args)}."
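
Note: assert_true is now expressed with when/otherwise around the shared _raise_error_helper instead of dedicated UDFs. The expected surface behavior, sketched in PySpark (example values are illustrative):

    df = spark.createDataFrame([(1,)], ["x"])
    df.select(F.assert_true(F.col("x") > 0)).show()                    # NULL when the predicate holds
    # df.select(F.assert_true(F.col("x") < 0, "x must be positive"))   # errors only when the row is evaluated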
@@ -2073,14 +2066,22 @@ def map_unresolved_function(
             assert (
                 len(exp.unresolved_function.arguments) == 2
             ), "date_format takes 2 arguments"
-
-
-
-
-
-
-
-
+
+            # Check if format parameter is NULL
+            format_literal = unwrap_literal(exp.unresolved_function.arguments[1])
+            if format_literal is None:
+                # If format is NULL, return NULL for all rows
+                result_exp = snowpark_fn.lit(None)
+            else:
+                result_exp = snowpark_fn.date_format(
+                    snowpark_args[0],
+                    snowpark_fn.lit(
+                        map_spark_timestamp_format_expression(
+                            exp.unresolved_function.arguments[1],
+                            snowpark_typed_args[0].typ,
+                        )
+                    ),
+                )
             result_exp = TypedColumn(result_exp, lambda: [StringType()])
         case "date_from_unix_date":
             result_exp = snowpark_fn.date_add(
@@ -2260,31 +2261,32 @@ def map_unresolved_function(
             )
         case "elt":
             n = snowpark_args[0]
-
             values = snowpark_fn.array_construct(*snowpark_args[1:])
 
             if spark_sql_ansi_enabled:
-
-
-                    input_types=[IntegerType()],
-                    return_type=StringType(),
+                raise_error = _raise_error_helper(
+                    StringType(), error_class=ArrayIndexOutOfBoundsException
                 )
-                def _raise_out_of_bounds_error(n: int) -> str:
-                    raise ValueError(
-                        f"ArrayIndexOutOfBoundsException: {n} is not within the input bounds."
-                    )
-
                 values_size = snowpark_fn.lit(len(snowpark_args) - 1)
 
                 result_exp = (
                     snowpark_fn.when(snowpark_fn.is_null(n), snowpark_fn.lit(None))
                     .when(
                         (snowpark_fn.lit(1) <= n) & (n <= values_size),
-                        snowpark_fn.
-
+                        snowpark_fn.cast(
+                            snowpark_fn.get(
+                                values, snowpark_fn.nvl(n - 1, snowpark_fn.lit(0))
+                            ),
+                            StringType(),
                         ),
                     )
-                    .otherwise(
+                    .otherwise(
+                        raise_error(
+                            snowpark_fn.lit("[INVALID_ARRAY_INDEX] The index "),
+                            snowpark_fn.cast(n, StringType()),
+                            snowpark_fn.lit(" is out of bounds."),
+                        )
+                    )
                 )
             else:
                 result_exp = snowpark_fn.when(
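
Note: under ANSI mode an out-of-range index for elt now raises an ArrayIndexOutOfBoundsException through _raise_error_helper rather than a UDF-based error. Illustrative PySpark usage (toggling ANSI mode via spark.conf is an assumption about how this session is configured):

    spark.conf.set("spark.sql.ansi.enabled", "true")
    df = spark.range(1)
    df.select(F.elt(F.lit(2), F.lit("a"), F.lit("b")).alias("second")).show()   # "b"
    # F.elt(F.lit(5), F.lit("a"), F.lit("b")) -> error under ANSI mode, NULL otherwise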
@@ -2535,6 +2537,19 @@ def map_unresolved_function(
                 input_types=[StringType(), StringType(), StructType()],
             )
             def _from_csv(csv_data: str, schema: str, options: Optional[dict]):
+                if csv_data is None:
+                    return None
+
+                if csv_data == "":
+                    # Return dict with None values for empty string
+                    schemas = schema.split(",")
+                    results = {}
+                    for sc in schemas:
+                        parts = [i for i in sc.split(" ") if len(i) != 0]
+                        assert len(parts) == 2, f"{sc} is not a valid schema"
+                        results[parts[0]] = None
+                    return results
+
                 max_chars_per_column = -1
                 sep = ","
 
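
Note: _from_csv now distinguishes a NULL input (result is NULL) from an empty string (result is a struct whose fields are all NULL). A short illustrative example:

    df = spark.createDataFrame([(None,), ("",), ("1,x",)], ["c"])
    df.select(F.from_csv("c", "a INT, b STRING").alias("r")).show()
    # NULL -> NULL, "" -> {NULL, NULL}, "1,x" -> {1, x}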
@@ -2617,7 +2632,9 @@ def map_unresolved_function(
                 case _:
                     raise ValueError("Unrecognized from_csv parameters")
 
-            result_exp = snowpark_fn.
+            result_exp = snowpark_fn.when(
+                snowpark_args[0].is_null(), snowpark_fn.lit(None)
+            ).otherwise(snowpark_fn.cast(csv_result, ddl_schema))
             result_type = ddl_schema
         case "from_json":
             # TODO: support options.
@@ -2651,6 +2668,9 @@ def map_unresolved_function(
             # try to parse first, since spark returns null for invalid json
             result_exp = snowpark_fn.call_function("try_parse_json", snowpark_args[0])
 
+            # Check if the original input is NULL - if so, return NULL for the entire result
+            original_input_is_null = snowpark_args[0].is_null()
+
             # helper function to make sure we have the expected array element type
             def _element_type_matches(
                 array_exp: Column, element_type: DataType
@@ -2749,9 +2769,13 @@ def map_unresolved_function(
                 else:
                     return exp
 
-
-
-
+            # Apply the coercion to handle invalid JSON (creates struct with NULL fields)
+            coerced_exp = _coerce_to_type(result_exp, result_type)
+
+            # If the original input was NULL, return NULL instead of a struct
+            result_exp = snowpark_fn.when(
+                original_input_is_null, snowpark_fn.lit(None)
+            ).otherwise(snowpark_fn.cast(coerced_exp, result_type))
         case "from_unixtime":
 
             def raise_analysis_exception(
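
Note: from_json now separates a NULL input column (result is NULL) from unparsable JSON (result is a struct with NULL fields). Illustrative example:

    df = spark.createDataFrame([(None,), ("not json",), ('{"a": 1}',)], ["js"])
    df.select(F.from_json("js", "a INT").alias("parsed")).show()
    # NULL -> NULL, "not json" -> {NULL}, '{"a": 1}' -> {1}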
@@ -2896,10 +2920,53 @@ def map_unresolved_function(
             )
         case "grouping" | "grouping_id":
             # grouping_id is not an alias for grouping in PySpark, but Snowflake's implementation handles both
-
+            current_grouping_cols = get_current_grouping_columns()
+            if function_name == "grouping_id":
+                if not snowpark_args:
+                    # grouping_id() with empty args means use all grouping columns
+                    spark_function_name = "grouping_id()"
+                    snowpark_args = [
+                        column_mapping.get_snowpark_column_name_from_spark_column_name(
+                            spark_col
+                        )
+                        for spark_col in current_grouping_cols
+                    ]
+                else:
+                    # Verify that grouping arguments match current grouping columns
+                    spark_col_args = [
+                        column_mapping.get_spark_column_name_from_snowpark_column_name(
+                            sp_col.getName()
+                        )
+                        for sp_col in snowpark_args
+                    ]
+                    if current_grouping_cols != spark_col_args:
+                        raise AnalysisException(
+                            f"[GROUPING_ID_COLUMN_MISMATCH] Columns of grouping_id: {spark_col_args} doesnt match "
+                            f"Grouping columns: {current_grouping_cols}"
+                        )
+            if function_name == "grouping_id":
+                result_exp = snowpark_fn.grouping_id(*snowpark_args)
+            else:
+                result_exp = snowpark_fn.grouping(*snowpark_args)
             result_type = LongType()
         case "hash":
             # TODO: See the spark-compatibility-issues.md explanation, this is quite different from Spark.
+            # MapType columns as input should raise an exception as they are not hashable.
+            snowflake_compat = get_boolean_session_config_param(
+                "enable_snowflake_extension_behavior"
+            )
+            # Snowflake's hash function does allow MAP types, but Spark does not. Therefore, if we have the expansion flag enabled
+            # we want to let it pass through and hash MAP types.
+            # Also allow if the legacy config spark.sql.legacy.allowHashOnMapType is set to true
+            if not snowflake_compat and not spark_sql_legacy_allow_hash_on_map_type:
+                for arg in snowpark_typed_args:
+                    if any(isinstance(t, MapType) for t in arg.types):
+                        raise AnalysisException(
+                            '[DATATYPE_MISMATCH.HASH_MAP_TYPE] Cannot resolve "hash(value)" due to data type mismatch: '
+                            'Input to the function `hash` cannot contain elements of the "MAP" type. '
+                            'In Spark, same maps may have different hashcode, thus hash expressions are prohibited on "MAP" elements. '
+                            'To restore previous behavior set "spark.sql.legacy.allowHashOnMapType" to "true".'
+                        )
             result_exp = snowpark_fn.hash(*snowpark_args)
             result_type = LongType()
         case "hex":
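
Note: hash() on MAP-typed input is now rejected to match Spark, unless the Snowflake extension behavior flag or the legacy Spark config is enabled. Illustrative usage (the config names follow the error text above; applying them through a snowpark-connect session is assumed to behave as in Spark):

    df = spark.sql("SELECT map('k', 1) AS m")
    # df.select(F.hash("m"))                                    # raises DATATYPE_MISMATCH.HASH_MAP_TYPE
    spark.conf.set("spark.sql.legacy.allowHashOnMapType", "true")
    df.select(F.hash("m")).show()                               # allowed again under the legacy setting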
@@ -2934,6 +3001,14 @@ def map_unresolved_function(
             result_type = StringType()
         case "histogram_numeric":
             aggregate_input_typ = snowpark_typed_args[0].typ
+
+            if isinstance(aggregate_input_typ, DecimalType):
+                # mimic bug from Spark 3.5.3.
+                # In 3.5.5 it's fixed and this exception shouldn't be thrown
+                raise ValueError(
+                    "class org.apache.spark.sql.types.Decimal cannot be cast to class java.lang.Number (org.apache.spark.sql.types.Decimal is in unnamed module of loader 'app'; java.lang.Number is in module java.base of loader 'bootstrap')"
+                )
+
             histogram_return_type = ArrayType(
                 StructType(
                     [
@@ -3154,6 +3229,18 @@ def map_unresolved_function(
             )
             result_type = histogram_return_type
         case "hll_sketch_agg":
+            # check if input type is correct
+            if type(snowpark_typed_args[0].typ) not in [
+                IntegerType,
+                LongType,
+                StringType,
+                BinaryType,
+            ]:
+                type_str = snowpark_typed_args[0].typ.simpleString().upper()
+                raise AnalysisException(
+                    f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: Parameter 1 requires the ("INT" or "BIGINT" or "STRING" or "BINARY") type, however "{snowpark_arg_names[0]}" has the type "{type_str}".'
+                )
+
             match snowpark_args:
                 case [sketch]:
                     spark_function_name = (
@@ -3173,7 +3260,7 @@ def map_unresolved_function(
             ).cast(LongType())
             result_type = LongType()
         case "hll_union_agg":
-            raise_error =
+            raise_error = _raise_error_helper(BinaryType())
             args = exp.unresolved_function.arguments
             allow_different_lgConfigK = len(args) == 2 and unwrap_literal(args[1])
             spark_function_name = f"{function_name}({snowpark_arg_names[0]}, {str(allow_different_lgConfigK).lower()})"
@@ -3213,7 +3300,7 @@ def map_unresolved_function(
                 SELECT arg1 as x)
                 """,
             )
-            raise_error =
+            raise_error = _raise_error_helper(BinaryType())
             args = exp.unresolved_function.arguments
             allow_different_lgConfigK = len(args) == 3 and unwrap_literal(args[2])
             spark_function_name = f"{function_name}({snowpark_arg_names[0]}, {snowpark_arg_names[1]}, {str(allow_different_lgConfigK).lower()})"
@@ -3796,12 +3883,47 @@ def map_unresolved_function(
                 )
 
             result_type = StringType()
-        case "ltrim":
+        case "ltrim" | "rtrim":
+            function_name_argument = (
+                "TRAILING" if function_name == "rtrim" else "LEADING"
+            )
             if len(snowpark_args) == 2:
                 # Only possible using SQL
-                spark_function_name = f"TRIM(
+                spark_function_name = f"TRIM({function_name_argument} {snowpark_arg_names[1]} FROM {snowpark_arg_names[0]})"
                 result_exp = snowpark_fn.ltrim(*snowpark_args)
                 result_type = StringType()
+            if isinstance(snowpark_typed_args[0].typ, BinaryType):
+                argument_name = snowpark_arg_names[0]
+                if exp.unresolved_function.arguments[0].HasField("literal"):
+                    argument_name = f"""X'{exp.unresolved_function.arguments[0].literal.binary.hex()}'"""
+                if len(snowpark_args) == 1:
+                    spark_function_name = f"{function_name}({argument_name})"
+                    trim_value = snowpark_fn.lit(b"\x20")
+                if len(snowpark_args) == 2:
+                    # Only possible using SQL
+                    trim_arg = snowpark_arg_names[1]
+                    if isinstance(
+                        snowpark_typed_args[1].typ, BinaryType
+                    ) and exp.unresolved_function.arguments[1].HasField("literal"):
+                        trim_arg = f"""X'{exp.unresolved_function.arguments[1].literal.binary.hex()}'"""
+                        trim_value = snowpark_args[1]
+                    else:
+                        trim_value = snowpark_fn.lit(None)
+                    function_name_argument = (
+                        "TRAILING" if function_name == "rtrim" else "LEADING"
+                    )
+                    spark_function_name = f"TRIM({function_name_argument} {trim_arg} FROM {argument_name})"
+                result_exp = _trim_helper(
+                    snowpark_args[0], trim_value, snowpark_fn.lit(function_name)
+                )
+                result_type = BinaryType()
+            else:
+                if function_name == "ltrim":
+                    result_exp = snowpark_fn.ltrim(*snowpark_args)
+                    result_type = StringType()
+                elif function_name == "rtrim":
+                    result_exp = snowpark_fn.rtrim(*snowpark_args)
+                    result_type = StringType()
         case "make_date":
             y = snowpark_args[0].cast(LongType())
             m = snowpark_args[1].cast(LongType())
@@ -3902,7 +4024,7 @@ def map_unresolved_function(
                     snowpark_fn.is_null(snowpark_args[i]),
                     # udf execution on XP seems to be lazy, so this should only run when there is a null key
                     # otherwise there should be no udf env setup or execution
-
+                    _raise_error_helper(VariantType())(
                         snowpark_fn.lit(
                             "[NULL_MAP_KEY] Cannot use null as map key."
                         )
@@ -3964,6 +4086,14 @@ def map_unresolved_function(
             )
             result_type = MapType(key_type, value_type)
         case "map_contains_key":
+            if isinstance(snowpark_typed_args[0].typ, NullType):
+                raise AnalysisException(
+                    f"""[DATATYPE_MISMATCH.MAP_FUNCTION_DIFF_TYPES] Cannot resolve "map_contains_key({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: Input to `map_contains_key` should have been "MAP" followed by a value with same key type, but it's ["VOID", "INT"]."""
+                )
+            if isinstance(snowpark_typed_args[1].typ, NullType):
+                raise AnalysisException(
+                    f"""[DATATYPE_MISMATCH.NULL_TYPE] Cannot resolve "map_contains_key({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: Null typed values cannot be used as arguments of `map_contains_key`."""
+                )
             args = (
                 [snowpark_args[1], snowpark_args[0]]
                 if isinstance(snowpark_typed_args[0].typ, MapType)
@@ -4093,17 +4223,37 @@ def map_unresolved_function(
 
             last_win_dedup = global_config.spark_sql_mapKeyDedupPolicy == "LAST_WIN"
 
-
-
-
-                snowpark_fn.
-
-
-
-
-
-
+            # Check if any entry has a NULL key
+            has_null_key = (
+                snowpark_fn.function("array_size")(
+                    snowpark_fn.function("filter")(
+                        snowpark_args[0],
+                        snowpark_fn.sql_expr(f"e -> e:{key_field} IS NULL"),
+                    )
+                )
+                > 0
+            )
+
+            # Create error UDF for NULL keys (same pattern as map function)
+            null_key_error = _raise_error_helper(VariantType())(
+                snowpark_fn.lit("[NULL_MAP_KEY] Cannot use null as map key.")
+            )
+
+            # Create the reduce operation
+            reduce_result = snowpark_fn.function("reduce")(
+                snowpark_args[0],
+                snowpark_fn.object_construct(),
+                snowpark_fn.sql_expr(
+                    # value_field is cast to variant because object_insert doesn't allow structured types,
+                    # and structured types are not coercible to variant
+                    # TODO: allow structured types in object_insert?
+                    f"(acc, e) -> object_insert(acc, e:{key_field}, e:{value_field}::variant, {last_win_dedup})"
                 ),
+            )
+
+            # Use conditional logic: if there are NULL keys, throw error; otherwise proceed with reduce
+            result_exp = snowpark_fn.cast(
+                snowpark_fn.when(has_null_key, null_key_error).otherwise(reduce_result),
                 MapType(key_type, value_type),
             )
             result_type = MapType(key_type, value_type)
@@ -4122,23 +4272,35 @@ def map_unresolved_function(
             # TODO: implement in Snowflake/Snowpark
             # technically this could be done with a lateral join, but it's probably not worth the effort
             arg_type = snowpark_typed_args[0].typ
-            if not isinstance(arg_type, MapType):
+            if not isinstance(arg_type, (MapType, NullType)):
                 raise AnalysisException(
                     f"map_values requires a MapType argument, got {arg_type}"
                 )
 
             def _map_values(obj: dict) -> list:
-
+                if obj is None:
+                    return None
+                return list(obj.values())
 
             map_values = cached_udf(
                 _map_values, return_type=ArrayType(), input_types=[StructType()]
             )
 
-
-
-
-
-
+            # Handle NULL input directly at expression level
+            if isinstance(arg_type, NullType):
+                # If input is NULL literal, return NULL
+                result_exp = snowpark_fn.lit(None)
+                result_type = ArrayType(NullType())
+            else:
+                result_exp = snowpark_fn.when(
+                    snowpark_args[0].is_null(), snowpark_fn.lit(None)
+                ).otherwise(
+                    snowpark_fn.cast(
+                        map_values(snowpark_fn.cast(snowpark_args[0], StructType())),
+                        ArrayType(arg_type.value_type),
+                    )
+                )
+                result_type = ArrayType(arg_type.value_type)
         case "mask":
 
             number_of_args = len(snowpark_args)
@@ -4258,6 +4420,17 @@ def map_unresolved_function(
                 lambda: snowpark_typed_args[0].types,
             )
         case "md5":
+            snowflake_compat = get_boolean_session_config_param(
+                "enable_snowflake_extension_behavior"
+            )
+
+            # MD5 in Spark only accepts BinaryType or types that can be implicitly cast to it (StringType)
+            if not snowflake_compat:
+                if not isinstance(snowpark_typed_args[0].typ, (BinaryType, StringType)):
+                    raise AnalysisException(
+                        f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "md5({snowpark_arg_names[0]})" due to data type mismatch: '
+                        f'Parameter 1 requires the "BINARY" type, however "{snowpark_arg_names[0]}" has the type "{snowpark_typed_args[0].typ}".'
+                    )
             result_exp = snowpark_fn.md5(snowpark_args[0])
             result_type = StringType(32)
         case "median":
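
Note: md5 now enforces Spark's input typing (BINARY, or STRING via implicit cast) unless the Snowflake extension behavior flag is enabled. Illustrative example:

    df = spark.range(1).select(F.lit("abc").alias("s"), F.lit(1).alias("i"))
    df.select(F.md5("s")).show()     # STRING is accepted (implicitly castable to BINARY)
    # df.select(F.md5("i"))          # now raises DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE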
@@ -5032,7 +5205,7 @@ def map_unresolved_function(
             result_type = DoubleType()
         case "raise_error":
             result_type = StringType()
-            raise_error =
+            raise_error = _raise_error_helper(result_type)
             result_exp = raise_error(*snowpark_args)
         case "rand" | "random":
             # Snowpark random() generates a 64 bit signed integer, but pyspark is [0.0, 1.0).
@@ -5117,7 +5290,7 @@ def map_unresolved_function(
                         snowpark_args[2],
                     ),
                 ),
-
+                _raise_error_helper(StringType())(
                     snowpark_fn.lit(
                         "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract` is invalid."
                     )
@@ -5167,7 +5340,7 @@ def map_unresolved_function(
                         idx,
                     )
                 ),
-
+                _raise_error_helper(ArrayType(StringType()))(
                     snowpark_fn.lit(
                         "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract_all` is invalid."
                     )
@@ -5466,13 +5639,28 @@ def map_unresolved_function(
         case "row_number":
             result_exp = snowpark_fn.row_number()
            result_exp = TypedColumn(result_exp, lambda: [LongType()])
-        case "rtrim":
-            if len(snowpark_args) == 2:
-                # Only possible using SQL
-                spark_function_name = f"TRIM(TRAILING {snowpark_arg_names[1]} FROM {snowpark_arg_names[0]})"
-            result_exp = snowpark_fn.rtrim(*snowpark_args)
-            result_type = StringType()
         case "schema_of_csv":
+            # Validate that the input is a foldable STRING expression
+            if (
+                exp.unresolved_function.arguments[0].WhichOneof("expr_type")
+                != "literal"
+            ):
+                raise AnalysisException(
+                    "[DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "
+                    f'"schema_of_csv({snowpark_arg_names[0]})" due to data type mismatch: '
+                    'the input csv should be a foldable "STRING" expression; however, '
+                    f'got "{snowpark_arg_names[0]}".'
+                )
+
+            if isinstance(snowpark_typed_args[0].typ, StringType):
+                if exp.unresolved_function.arguments[0].literal.string == "":
+                    raise AnalysisException(
+                        "[DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "
+                        f'"schema_of_csv({snowpark_arg_names[0]})" due to data type mismatch: '
+                        'the input csv should be a foldable "STRING" expression; however, '
+                        f'got "{snowpark_arg_names[0]}".'
+                    )
+
             snowpark_args = [
                 typed_arg.column(to_semi_structure=True)
                 for typed_arg in snowpark_typed_args
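
Note: schema_of_csv now requires a non-empty literal (foldable STRING) argument, mirroring Spark's NON_FOLDABLE_INPUT error. Illustrative example (the column name is hypothetical):

    df = spark.range(1)
    df.select(F.schema_of_csv(F.lit("1,abc"))).show()        # literal input: OK
    # df.select(F.schema_of_csv(F.col("some_string_col")))   # rejected: not foldable
    # df.select(F.schema_of_csv(F.lit("")))                  # rejected: empty literal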
@@ -5689,6 +5877,16 @@ def map_unresolved_function(
             )
             result_type = ArrayType(ArrayType(StringType()))
         case "sequence":
+            if snowpark_typed_args[0].typ != snowpark_typed_args[1].typ or (
+                not isinstance(snowpark_typed_args[0].typ, _IntegralType)
+                or not isinstance(snowpark_typed_args[1].typ, _IntegralType)
+            ):
+                raise AnalysisException(
+                    f"""[DATATYPE_MISMATCH.SEQUENCE_WRONG_INPUT_TYPES] Cannot resolve "sequence({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: `sequence` uses the wrong parameter type. The parameter type must conform to:
1. The start and stop expressions must resolve to the same type.
2. Otherwise, if start and stop expressions resolve to the "INTEGRAL" type, then the step expression must resolve to the same type.
"""
+                )
             result_exp = snowpark_fn.cast(
                 snowpark_fn.sequence(*snowpark_args),
                 ArrayType(LongType(), contains_null=False),
@@ -5856,7 +6054,7 @@ def map_unresolved_function(
             result_exp = snowpark_fn.skew(snowpark_fn.lit(None))
             result_type = DoubleType()
         case "slice":
-            raise_error =
+            raise_error = _raise_error_helper(snowpark_typed_args[0].typ)
             spark_index = snowpark_args[1]
             arr_size = snowpark_fn.array_size(snowpark_args[0])
             slice_len = snowpark_args[2]
@@ -5926,10 +6124,11 @@ def map_unresolved_function(
             result_exp = snowpark_fn.lit(0)
             result_type = LongType()
         case "split":
+            result_type = ArrayType(StringType())
 
             @cached_udf(
                 input_types=[StringType(), StringType(), IntegerType()],
-                return_type=
+                return_type=result_type,
             )
             def _split(
                 input: Optional[str], pattern: Optional[str], limit: Optional[int]
@@ -5937,34 +6136,80 @@ def map_unresolved_function(
                 if input is None or pattern is None:
                     return None
 
+                import re
+
+                try:
+                    re.compile(pattern)
+                except re.error:
+                    raise ValueError(
+                        f"Failed to split string, provided pattern: {pattern} is invalid"
+                    )
+
                 if limit == 1:
                     return [input]
 
-
+                if not input:
+                    return []
 
                 # A default of -1 is passed in PySpark, but RE needs it to be 0 to provide all splits.
                 # In PySpark, the limit also indicates the max size of the resulting array, but in RE
                 # the remainder is returned as another element.
                 maxsplit = limit - 1 if limit > 0 else 0
 
-                split_result = re.split(pattern, input, maxsplit)
                 if len(pattern) == 0:
-
-
+                    return list(input) if limit <= 0 else list(input)[:limit]
+
+                match pattern:
+                    case "|":
+                        split_result = re.split(pattern, input, 0)
+                        input_limit = limit + 1 if limit > 0 else len(split_result)
+                        return (
+                            split_result
+                            if input_limit == 0
+                            else split_result[1:input_limit]
+                        )
+                    case "$":
+                        return [input, ""] if maxsplit >= 0 else [input]
+                    case "^":
+                        return [input]
+                    case _:
+                        return re.split(pattern, input, maxsplit)
+
+            def split_string(str_: Column, pattern: Column, limit: Column):
+                native_split = _split(str_, pattern, limit)
+                # When pattern is a literal and doesn't contain any regex special characters
+                # And when limit is less than or equal to 0
+                # Native Snowflake Split function is used to optimise performance
+                if isinstance(pattern._expression, Literal):
+                    pattern_value = pattern._expression.value
+
+                    if pattern_value is None:
+                        return snowpark_fn.lit(None)
+
+                    is_regexp = re.match(
+                        ".*[\\[\\.\\]\\*\\?\\+\\^\\$\\{\\}\\|\\(\\)\\\\].*",
+                        pattern_value,
+                    )
+                    is_empty = len(pattern_value) == 0
+
+                    if not is_empty and not is_regexp:
+                        return snowpark_fn.when(
+                            limit <= 0,
+                            snowpark_fn.split(str_, pattern).cast(result_type),
+                        ).otherwise(native_split)
 
-                return
+                return native_split
 
             match snowpark_args:
                 case [str_, pattern]:
                     spark_function_name = (
                         f"split({snowpark_arg_names[0]}, {snowpark_arg_names[1]}, -1)"
                     )
-                    result_exp =
+                    result_exp = split_string(str_, pattern, snowpark_fn.lit(-1))
                 case [str_, pattern, limit]:  # noqa: F841
-                    result_exp =
+                    result_exp = split_string(str_, pattern, limit)
                 case _:
                     raise ValueError(f"Invalid number of arguments to {function_name}")
-            result_type = ArrayType(StringType())
         case "split_part":
             result_exp = snowpark_fn.call_function("split_part", *snowpark_args)
             result_type = StringType()
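
Note: split now validates the pattern, special-cases the "|", "$" and "^" patterns, and short-circuits to Snowflake's native SPLIT when the pattern is a plain literal with no regex metacharacters and the limit is <= 0. Illustrative example:

    df = spark.createDataFrame([("a,b,c",)], ["s"])
    df.select(F.split("s", ",").alias("native")).show()       # plain separator, limit -1 -> native SPLIT path
    df.select(F.split("s", ",", 2).alias("limited")).show()   # positive limit -> UDF path, at most 2 elements
    df.select(F.split("s", "[,;]").alias("regex")).show()     # regex metacharacters -> UDF path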
@@ -6274,6 +6519,10 @@ def map_unresolved_function(
             )
             result_type = TimestampType(snowpark.types.TimestampTimeZone.NTZ)
         case "timestamp_millis":
+            if not isinstance(snowpark_typed_args[0].typ, _IntegralType):
+                raise AnalysisException(
+                    f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "timestamp_millis({snowpark_arg_names[0]}" due to data type mismatch: Parameter 1 requires the "INTEGRAL" type, however "{snowpark_arg_names[0]}" has the type "{snowpark_typed_args[0].typ}".'
+                )
             result_exp = snowpark_fn.cast(
                 snowpark_fn.to_timestamp(snowpark_args[0] * 1_000, 6),
                 TimestampType(snowpark.types.TimestampTimeZone.NTZ),
@@ -6283,6 +6532,10 @@ def map_unresolved_function(
             # Spark allows seconds to be fractional. Snowflake does not allow that
             # even though the documentation explicitly says that it does.
             # As a workaround, use integer milliseconds instead of fractional seconds.
+            if not isinstance(snowpark_typed_args[0].typ, _NumericType):
+                raise AnalysisException(
+                    f"""AnalysisException: [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{function_name}({snowpark_arg_names[0]})" due to data type mismatch: Parameter 1 requires the "NUMERIC" type, however "{snowpark_arg_names[0]}" has the type "{snowpark_typed_args[0].typ}".;"""
+                )
             result_exp = snowpark_fn.cast(
                 snowpark_fn.to_timestamp(
                     snowpark_fn.cast(snowpark_args[0] * 1_000_000, LongType()), 6
@@ -6725,7 +6978,20 @@ def map_unresolved_function(
             result_type = StringType()
         case "trunc":
             part = unwrap_literal(exp.unresolved_function.arguments[1])
-            if part is None
+            part = None if part is None else part.lower()
+
+            allowed_parts = {
+                "year",
+                "yyyy",
+                "yy",
+                "month",
+                "mon",
+                "mm",
+                "week",
+                "quarter",
+            }
+
+            if part not in allowed_parts:
                 result_exp = snowpark_fn.lit(None)
             else:
                 result_exp = _try_to_cast(
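
Note: trunc now lower-cases the unit and returns NULL for anything outside the allowed set (year/yyyy/yy, month/mon/mm, week, quarter). Illustrative example:

    df = spark.sql("SELECT to_date('2024-05-17') AS d")
    df.select(F.trunc("d", "MM").alias("month"), F.trunc("d", "decade").alias("bad")).show()
    # "MM" is normalized to "mm" and truncates to 2024-05-01; "decade" is not allowed, so the result is NULL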
@@ -7116,6 +7382,12 @@ def map_unresolved_function(
                     )
                 )
             )
+            raise_fn = _raise_error_helper(BinaryType(), IllegalArgumentException)
+            result_exp = (
+                snowpark_fn.when(unbase_arg.is_null(), snowpark_fn.lit(None))
+                .when(result_exp.is_null(), raise_fn(snowpark_fn.lit("Invalid input")))
+                .otherwise(result_exp)
+            )
             result_type = BinaryType()
         case "unhex":
             # Non string columns, convert them to string type. This mimics pyspark behavior.
@@ -7316,6 +7588,15 @@ def map_unresolved_function(
             )
             result_type = LongType()
         case "when" | "if":
+            # Validate that the condition is a boolean expression
+            if len(snowpark_typed_args) > 0:
+                condition_type = snowpark_typed_args[0].typ
+                if not isinstance(condition_type, BooleanType):
+                    raise AnalysisException(
+                        f"[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve CASE WHEN condition due to data type mismatch: "
+                        f"Parameter 1 requires the 'BOOLEAN' type, however got '{condition_type}'"
+                    )
+
             name_components = ["CASE"]
             name_components.append("WHEN")
             name_components.append(snowpark_arg_names[0])
@@ -7334,6 +7615,13 @@ def map_unresolved_function(
                 name_components.append(snowpark_arg_names[i])
                 name_components.append("THEN")
                 name_components.append(snowpark_arg_names[i + 1])
+                # Validate each WHEN condition
+                condition_type = snowpark_typed_args[i].typ
+                if not isinstance(condition_type, BooleanType):
+                    raise AnalysisException(
+                        f"[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve CASE WHEN condition due to data type mismatch: "
+                        f"Parameter {i + 1} requires the 'BOOLEAN' type, however got '{condition_type}'"
+                    )
                 result_exp = result_exp.when(snowpark_args[i], snowpark_args[i + 1])
                 result_type_indexes.append(i + 1)
             name_components.append("END")
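
Note: CASE WHEN / if conditions are now type-checked; every condition must already be BOOLEAN. Illustrative example:

    df = spark.range(3)
    df.select(F.when(F.col("id") > 1, "big").otherwise("small").alias("label")).show()   # boolean condition: OK
    # F.when(F.col("id") + 1, "x")   # rejected: the condition must be of the 'BOOLEAN' type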
@@ -7710,16 +7998,8 @@ def _handle_current_timestamp():
 
 
 def _equivalent_decimal(type):
-
-
-            return DecimalType(3, 0)
-        case ShortType():
-            return DecimalType(5, 0)
-        case IntegerType():
-            return DecimalType(10, 0)
-        case LongType():
-            return DecimalType(20, 0)
-    return DecimalType(38, 0)
+    (precision, scale) = _get_type_precision(type)
+    return DecimalType(precision, scale)
 
 
 def _resolve_decimal_and_numeric(type1: DecimalType, type2: _NumericType) -> DataType:
@@ -8778,7 +9058,9 @@ def _get_type_precision(typ: DataType) -> tuple[int, int]:
         case IntegerType():
             return 10, 0  # -2147483648 to 2147483647
         case LongType():
-            return
+            return 20, 0  # -9223372036854775808 to 9223372036854775807
+        case NullType():
+            return 6, 2  # NULL
         case _:
             return 38, 0  # Default to maximum precision for other types
 
@@ -8993,16 +9275,12 @@ def _try_arithmetic_helper(
         typed_args[1].typ, DecimalType
     ):
         new_scale = s2
-        new_precision = (
-            p1 + s2 + 1
-        )  # Integral precision + decimal scale + 1 for carry
+        new_precision = max(p2, p1 + s2)
     elif isinstance(typed_args[0].typ, DecimalType) and isinstance(
         typed_args[1].typ, _IntegralType
     ):
         new_scale = s1
-        new_precision = (
-            p2 + s1 + 1
-        )  # Integral precision + decimal scale + 1 for carry
+        new_precision = max(p1, p2 + s1)
     else:
         # Both decimal types
         if operation_type == 1 and s1 == s2:  # subtraction with matching scales
@@ -9081,13 +9359,13 @@ def _add_sub_precision_helper(
         typed_args[1].typ, DecimalType
     ):
         new_scale = s2
-        new_precision = p1 + s2
+        new_precision = max(p2, p1 + s2)
         return_type_precision, return_type_scale = new_precision, new_scale
     elif isinstance(typed_args[0].typ, DecimalType) and isinstance(
         typed_args[1].typ, _IntegralType
    ):
         new_scale = s1
-        new_precision = p2 + s1
+        new_precision = max(p1, p2 + s1)
         return_type_precision, return_type_scale = new_precision, new_scale
     else:
         (
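
Worked example of the revised add/subtract precision rule above (values are illustrative): combining a BYTE operand (precision 3, scale 0) with DECIMAL(30, 2), the old rule produced precision 3 + 2 = 5, which cannot even hold the decimal operand; the new rule keeps scale 2 and takes precision max(30, 3 + 2) = 30, so the result stays DECIMAL(30, 2).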
@@ -9169,11 +9447,25 @@ def _mul_div_precision_helper(
     )
 
 
-def
-
-
+def _raise_error_helper(return_type: DataType, error_class=None):
+    error_class = (
+        f":{error_class.__name__}"
+        if error_class and hasattr(error_class, "__name__")
+        else ""
+    )
+
+    def _raise_fn(*msgs: Column) -> Column:
+        return snowpark_fn.cast(
+            snowpark_fn.abs(
+                snowpark_fn.concat(
+                    snowpark_fn.lit(f"[snowpark-connect-exception{error_class}]"),
+                    *(msg.try_cast(StringType()) for msg in msgs),
+                )
+            ).cast(StringType()),
+            return_type,
+        )
 
-    return
+    return _raise_fn
 
 
 def _divnull(dividend: Column, divisor: Column) -> Column:
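
Note on the shared _raise_error_helper above: the returned callable builds a tagged message and forces a runtime failure by feeding the non-numeric string through abs(), so the error (carrying the message) only fires for rows where the surrounding WHEN branch is actually evaluated. A standalone sketch of the same pattern, not the packaged helper itself:

    import snowflake.snowpark.functions as snowpark_fn
    from snowflake.snowpark.types import StringType

    def raise_with_message(msg: str):
        tagged = snowpark_fn.concat(snowpark_fn.lit("[snowpark-connect-exception]"), snowpark_fn.lit(msg))
        # abs() cannot evaluate a non-numeric string, so this column errors out at execution time
        # with the tagged message embedded in the failure text
        return snowpark_fn.abs(tagged.cast(StringType()))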
@@ -9448,3 +9740,22 @@ def _validate_number_format_string(format_str: str) -> None:
         raise AnalysisException(
             f"[INVALID_FORMAT.WRONG_NUM_DIGIT] The format is invalid: '{format_str}'. The format string requires at least one number digit."
         )
+
+
+def _trim_helper(value: Column, trim_value: Column, trim_type: Column) -> Column:
+    @cached_udf(
+        return_type=BinaryType(),
+        input_types=[BinaryType(), BinaryType(), StringType()],
+    )
+    def _binary_trim_udf(value: bytes, trim_value: bytes, trim_type: str) -> bytes:
+        if value is None or trim_value is None:
+            return value
+        if trim_type in ("rtrim", "btrim", "trim"):
+            while value.endswith(trim_value):
+                value = value[: -len(trim_value)]
+        if trim_type in ("ltrim", "btrim", "trim"):
+            while value.startswith(trim_value):
+                value = value[len(trim_value) :]
+        return value
+
+    return _binary_trim_udf(value, trim_value, trim_type)