snowpark-connect 0.23.0__py3-none-any.whl → 0.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/expression/function_defaults.py +207 -0
- snowflake/snowpark_connect/expression/literal.py +5 -0
- snowflake/snowpark_connect/expression/map_expression.py +10 -1
- snowflake/snowpark_connect/expression/map_extension.py +12 -2
- snowflake/snowpark_connect/expression/map_unresolved_function.py +11 -12
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
- snowflake/snowpark_connect/relation/map_column_ops.py +1 -32
- snowflake/snowpark_connect/relation/map_extension.py +7 -7
- snowflake/snowpark_connect/relation/map_row_ops.py +2 -29
- snowflake/snowpark_connect/relation/read/utils.py +6 -7
- snowflake/snowpark_connect/relation/utils.py +1 -170
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/METADATA +1 -1
- {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/RECORD +25 -20
- {snowpark_connect-0.23.0.data → snowpark_connect-0.24.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.23.0.data → snowpark_connect-0.24.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.23.0.data → snowpark_connect-0.24.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/expression/function_defaults.py
@@ -0,0 +1,207 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+from dataclasses import dataclass
+from typing import Any
+
+import pyspark.sql.connect.proto.expressions_pb2 as expressions_pb2
+import pyspark.sql.connect.proto.types_pb2 as types_pb2
+
+
+@dataclass(frozen=True)
+class DefaultParameter:
+    """Represents a single default parameter for a function."""
+
+    name: str
+    value: Any
+
+
+@dataclass(frozen=True)
+class FunctionDefaults:
+    """Represents default parameter configuration for a function."""
+
+    total_args: int
+    defaults: list[DefaultParameter]
+
+
+# FUNCTION_DEFAULTS dictionary to hold operation name with default values.
+# This is required as non pyspark clients such as scala or sql won't send all the parameters.
+# We use this dict to inject the missing parameters before processing the unresolved function.
+FUNCTION_DEFAULTS: dict[str, FunctionDefaults] = {
+    "aes_decrypt": FunctionDefaults(
+        total_args=5,
+        defaults=[
+            DefaultParameter("mode", "GCM"),  # Spark SQL default: GCM
+            DefaultParameter("padding", "NONE"),  # Spark SQL default: NONE for GCM mode
+            DefaultParameter("aad", ""),  # Spark SQL default: empty string
+        ],
+    ),
+    "aes_encrypt": FunctionDefaults(
+        total_args=6,
+        defaults=[
+            DefaultParameter("mode", "GCM"),  # Spark SQL default: GCM
+            DefaultParameter("padding", "NONE"),  # Spark SQL default: NONE for GCM mode
+            DefaultParameter(
+                "iv", ""
+            ),  # Spark SQL default: empty string (random generated if not provided)
+            DefaultParameter("aad", ""),  # Spark SQL default: empty string
+        ],
+    ),
+    "approx_percentile": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("accuracy", 10000)],
+    ),
+    "bround": FunctionDefaults(
+        total_args=2,
+        defaults=[DefaultParameter("scale", 0)],
+    ),
+    "first": FunctionDefaults(
+        total_args=2,
+        defaults=[DefaultParameter("ignorenulls", False)],
+    ),
+    "lag": FunctionDefaults(
+        total_args=2,
+        defaults=[
+            DefaultParameter("offset", 1),
+        ],
+    ),
+    "last": FunctionDefaults(
+        total_args=2,
+        defaults=[DefaultParameter("ignorenulls", False)],
+    ),
+    "lead": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("offset", 1), DefaultParameter("default", None)],
+    ),
+    "locate": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("pos", 1)],
+    ),
+    "months_between": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("roundOff", True)],
+    ),
+    "nth_value": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("ignoreNulls", False)],
+    ),
+    "overlay": FunctionDefaults(
+        total_args=4,
+        defaults=[DefaultParameter("len", -1)],
+    ),
+    "percentile": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("frequency", 1)],
+    ),
+    "percentile_approx": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("accuracy", 10000)],
+    ),
+    "round": FunctionDefaults(
+        total_args=2,
+        defaults=[DefaultParameter("scale", 0)],
+    ),
+    "sentences": FunctionDefaults(
+        total_args=3,
+        defaults=[
+            DefaultParameter("language", ""),
+            DefaultParameter("country", ""),
+        ],
+    ),
+    "sort_array": FunctionDefaults(
+        total_args=2,
+        defaults=[DefaultParameter("asc", True)],
+    ),
+    "split": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("limit", -1)],
+    ),
+    "str_to_map": FunctionDefaults(
+        total_args=3,
+        defaults=[
+            DefaultParameter(
+                "pairDelim", ","
+            ),  # Spark SQL default: comma for splitting pairs
+            DefaultParameter(
+                "keyValueDelim", ":"
+            ),  # Spark SQL default: colon for splitting key/value
+        ],
+    ),
+    "try_aes_decrypt": FunctionDefaults(
+        total_args=5,
+        defaults=[
+            DefaultParameter("mode", "GCM"),  # Spark SQL default: GCM
+            DefaultParameter("padding", "NONE"),  # Spark SQL default: NONE for GCM mode
+            DefaultParameter("aad", ""),  # Spark SQL default: empty string
+        ],
+    ),
+}
+
+
+def _create_literal_expression(value: Any) -> expressions_pb2.Expression:
+    """Create a literal expression for the given value."""
+    expr = expressions_pb2.Expression()
+    if isinstance(value, bool):
+        expr.literal.boolean = value
+    elif isinstance(value, int):
+        expr.literal.integer = value
+    elif isinstance(value, str):
+        expr.literal.string = value
+    elif isinstance(value, float):
+        expr.literal.double = value
+    elif value is None:
+        null_type = types_pb2.DataType()
+        null_type.null.SetInParent()
+        expr.literal.null.CopyFrom(null_type)
+    else:
+        raise ValueError(f"Unsupported literal type: {value}")
+
+    return expr
+
+
+def inject_function_defaults(
+    unresolved_function: expressions_pb2.Expression.UnresolvedFunction,
+) -> bool:
+    """
+    Inject missing default parameters into an UnresolvedFunction protobuf.
+
+    Args:
+        unresolved_function: The protobuf UnresolvedFunction to modify
+
+    Returns:
+        bool: True if any defaults were injected, False otherwise
+    """
+    function_name = unresolved_function.function_name.lower()
+
+    if function_name not in FUNCTION_DEFAULTS:
+        return False
+
+    func_config = FUNCTION_DEFAULTS[function_name]
+    current_arg_count = len(unresolved_function.arguments)
+    total_args = func_config.total_args
+    defaults = func_config.defaults
+
+    if not defaults or current_arg_count >= total_args:
+        return False
+
+    # Calculate how many defaults to append
+    missing_arg_count = total_args - current_arg_count
+
+    # Check if any required params are missing.
+    if missing_arg_count > len(defaults):
+        raise ValueError(
+            f"Function '{function_name}' is missing required arguments. "
+            f"Expected {total_args} args, got {current_arg_count}, "
+            f"but only {len(defaults)} defaults are defined."
+        )
+
+    defaults_to_append = defaults[-missing_arg_count:]
+    injected = False
+
+    # Simply append the needed default values
+    for default_param in defaults_to_append:
+        default_expr = _create_literal_expression(default_param.value)
+        unresolved_function.arguments.append(default_expr)
+        injected = True
+
+    return injected
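Illustrative sketch (not part of this diff): how the new inject_function_defaults helper above behaves when a non-Python client omits trailing arguments. The placeholder argument expressions below are hypothetical stand-ins for whatever the client actually sends.

    import pyspark.sql.connect.proto.expressions_pb2 as expressions_pb2

    from snowflake.snowpark_connect.expression.function_defaults import (
        inject_function_defaults,
    )

    # e.g. a Scala or SQL client sends locate(substr, str) with only two arguments
    fn = expressions_pb2.Expression.UnresolvedFunction()
    fn.function_name = "locate"
    fn.arguments.extend([expressions_pb2.Expression(), expressions_pb2.Expression()])

    inject_function_defaults(fn)  # appends the configured default pos=1 as a literal
    assert len(fn.arguments) == 3
    assert fn.arguments[2].literal.integer == 1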
snowflake/snowpark_connect/expression/literal.py
@@ -81,6 +81,11 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
         case "decimal":
             # literal.decimal.precision & scale are ignored, as decimal.Decimal doesn't accept them
             return decimal.Decimal(literal.decimal.value), literal.decimal.value
+        case "array":
+            array_values, element_names = zip(
+                *(get_literal_field_and_name(e) for e in literal.array.elements)
+            )
+            return array_values, f"ARRAY({', '.join(element_names)})"
         case "null" | None:
             return None, "NULL"
         case other:
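Illustrative sketch (not part of this diff): exercising the new "array" branch with a small Spark Connect literal proto. The per-element names come from the existing scalar branches, so the rendered name shown in the comment is only indicative.

    import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto

    from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name

    lit = expressions_proto.Expression.Literal()
    for v in (1, 2):
        element = lit.array.elements.add()
        element.integer = v

    values, name = get_literal_field_and_name(lit)
    # values is the tuple of converted elements, e.g. (1, 2),
    # and name is rendered over the per-element names, e.g. "ARRAY(1, 2)"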
snowflake/snowpark_connect/expression/map_expression.py
@@ -27,7 +27,10 @@ from snowflake.snowpark_connect.expression.literal import get_literal_field_and_
 from snowflake.snowpark_connect.expression.map_cast import map_cast
 from snowflake.snowpark_connect.expression.map_sql_expression import map_sql_expr
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
-from snowflake.snowpark_connect.type_mapping import
+from snowflake.snowpark_connect.type_mapping import (
+    map_simple_types,
+    proto_to_snowpark_type,
+)
 from snowflake.snowpark_connect.typed_column import TypedColumn
 from snowflake.snowpark_connect.utils.context import (
     gen_sql_plan_id,
@@ -166,6 +169,12 @@ def map_expression(
                 lambda: [map_simple_types(lit_type_str)],
             )
 
+            if lit_type_str == "array":
+                result_exp = snowpark_fn.lit(lit_value)
+                element_types = proto_to_snowpark_type(exp.literal.array.element_type)
+                array_type = snowpark.types.ArrayType(element_types)
+                return [lit_name], TypedColumn(result_exp, lambda: [array_type])
+
             # Decimal needs further processing to get the precision and scale properly.
             if lit_type_str == "decimal":
                 # Precision and scale are optional in the proto.
snowflake/snowpark_connect/expression/map_extension.py
@@ -10,7 +10,10 @@ from snowflake.snowpark.types import BooleanType
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.typed_column import TypedColumn
-from snowflake.snowpark_connect.utils.context import
+from snowflake.snowpark_connect.utils.context import (
+    push_evaluating_sql_scope,
+    push_outer_dataframe,
+)
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
@@ -52,12 +55,19 @@ def map_extension(
             return [name], typed_col
 
         case "subquery_expression":
+            from snowflake.snowpark_connect.dataframe_container import (
+                DataFrameContainer,
+            )
             from snowflake.snowpark_connect.expression.map_expression import (
                 map_expression,
             )
             from snowflake.snowpark_connect.relation.map_relation import map_relation
 
-
+            current_outer_df = DataFrameContainer(
+                dataframe=typer.df, column_map=column_mapping
+            )
+
+            with push_evaluating_sql_scope(), push_outer_dataframe(current_outer_df):
                 df_container = map_relation(extension.subquery_expression.input)
                 df = df_container.dataframe
 
snowflake/snowpark_connect/expression/map_unresolved_function.py
@@ -80,6 +80,9 @@ from snowflake.snowpark_connect.constants import (
     SPARK_TZ_ABBREVIATIONS_OVERRIDES,
     STRUCTURED_TYPES_ENABLED,
 )
+from snowflake.snowpark_connect.expression.function_defaults import (
+    inject_function_defaults,
+)
 from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
 from snowflake.snowpark_connect.expression.map_cast import (
     CAST_FUNCTIONS,
@@ -299,6 +302,9 @@ def map_unresolved_function(
     function_name = exp.unresolved_function.function_name.lower()
     is_udtf_call = function_name in session._udtfs
 
+    # Inject default parameters for functions that need them (especially for Scala clients)
+    inject_function_defaults(exp.unresolved_function)
+
     def _resolve_args_expressions(exp: expressions_proto.Expression):
         def _resolve_fn_arg(exp):
             with resolving_fun_args():
@@ -3761,7 +3767,7 @@ def map_unresolved_function(
                     snowpark_args[1] <= 0, snowpark_fn.lit("")
                 ).otherwise(snowpark_fn.left(*snowpark_args))
                 result_type = StringType()
-            case "length" | "char_length" | "character_length":
+            case "length" | "char_length" | "character_length" | "len":
                 if exp.unresolved_function.arguments[0].HasField("literal"):
                     # Only update the name if it has the literal field.
                     # If it doesn't, it means it's binary data.
@@ -3822,13 +3828,7 @@ def map_unresolved_function(
             case "locate":
                 substr = unwrap_literal(exp.unresolved_function.arguments[0])
                 value = snowpark_args[1]
-
-                    start_pos = unwrap_literal(exp.unresolved_function.arguments[2])
-                else:
-                    # start_pos is an optional argument and if not provided we should default to 1.
-                    # This path will only be reached by spark connect scala clients.
-                    start_pos = 1
-                    spark_function_name = f"locate({', '.join(snowpark_arg_names)}, 1)"
+                start_pos = unwrap_literal(exp.unresolved_function.arguments[2])
 
                 if start_pos > 0:
                     result_exp = snowpark_fn.locate(substr, value, start_pos)
@@ -4677,7 +4677,7 @@ def map_unresolved_function(
                     snowpark_args[0],
                 )
                 result_type = DateType()
-            case "not":
+            case "not" | "!":
                 spark_function_name = f"(NOT {snowpark_arg_names[0]})"
                 result_exp = ~snowpark_args[0]
                 result_type = BooleanType()
@@ -5253,9 +5253,8 @@ def map_unresolved_function(
             # TODO: Seems like more validation of the arguments is appropriate.
             args = exp.unresolved_function.arguments
             if len(args) > 0:
-                if not (
-
-                    or isinstance(snowpark_typed_args[0].typ, NullType)
+                if not isinstance(
+                    snowpark_typed_args[0].typ, (IntegerType, LongType, NullType)
                 ):
                     raise AnalysisException(
                         f"""[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: Parameter 1 requires the ("INT" or "BIGINT") type, however {snowpark_arg_names[0]} has the type "{snowpark_typed_args[0].typ}"""
snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#