snowpark-connect 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/config.py +10 -3
- snowflake/snowpark_connect/dataframe_container.py +16 -0
- snowflake/snowpark_connect/expression/map_expression.py +15 -0
- snowflake/snowpark_connect/expression/map_udf.py +68 -27
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +18 -0
- snowflake/snowpark_connect/expression/map_unresolved_function.py +38 -28
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/relation/map_extension.py +9 -7
- snowflake/snowpark_connect/relation/map_map_partitions.py +36 -72
- snowflake/snowpark_connect/relation/map_relation.py +15 -2
- snowflake/snowpark_connect/relation/map_row_ops.py +8 -1
- snowflake/snowpark_connect/relation/map_show_string.py +2 -0
- snowflake/snowpark_connect/relation/map_sql.py +63 -2
- snowflake/snowpark_connect/relation/map_udtf.py +96 -44
- snowflake/snowpark_connect/relation/utils.py +44 -0
- snowflake/snowpark_connect/relation/write/map_write.py +135 -24
- snowflake/snowpark_connect/resources_initializer.py +18 -5
- snowflake/snowpark_connect/server.py +12 -2
- snowflake/snowpark_connect/utils/artifacts.py +4 -5
- snowflake/snowpark_connect/utils/concurrent.py +4 -0
- snowflake/snowpark_connect/utils/context.py +41 -1
- snowflake/snowpark_connect/utils/external_udxf_cache.py +36 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +86 -2
- snowflake/snowpark_connect/utils/scala_udf_utils.py +250 -242
- snowflake/snowpark_connect/utils/session.py +4 -0
- snowflake/snowpark_connect/utils/udf_utils.py +71 -118
- snowflake/snowpark_connect/utils/udtf_helper.py +17 -7
- snowflake/snowpark_connect/utils/udtf_utils.py +3 -16
- snowflake/snowpark_connect/version.py +2 -3
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/METADATA +2 -2
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/RECORD +41 -37
- {snowpark_connect-0.25.0.data → snowpark_connect-0.27.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.25.0.data → snowpark_connect-0.27.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.25.0.data → snowpark_connect-0.27.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/config.py

@@ -22,6 +22,9 @@ from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.types import TimestampTimeZone, TimestampType
 from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
 from snowflake.snowpark_connect.utils.context import get_session_id
+from snowflake.snowpark_connect.utils.external_udxf_cache import (
+    clear_external_udxf_cache,
+)
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import (
@@ -136,6 +139,9 @@ class GlobalConfig:
         "spark.sql.parser.quotedRegexColumnNames": "false",
         # custom configs
         "snowpark.connect.version": ".".join(map(str, sas_version)),
+        # Control whether repartition(n) on a DataFrame forces splitting into n files during writes
+        # This matches spark behavior more closely, but introduces overhead.
+        "snowflake.repartition.for.writes": "false",
     }

     boolean_config_list = [
@@ -148,6 +154,7 @@ class GlobalConfig:
         "spark.sql.legacy.allowHashOnMapType",
         "spark.Catalog.databaseFilterInformationSchema",
         "spark.sql.parser.quotedRegexColumnNames",
+        "snowflake.repartition.for.writes",
     ]

     int_config_list = [
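The new "snowflake.repartition.for.writes" key is registered above as a boolean config that defaults to "false". As a rough, hedged sketch of how a Spark Connect client could opt in (only the config key comes from this diff; the endpoint and output path are placeholders):

# Hypothetical usage sketch; only the config key is taken from this diff.
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()  # placeholder endpoint
spark.conf.set("snowflake.repartition.for.writes", "true")

# With the flag enabled, repartition(n) is meant to influence how many files a
# write produces (closer to Spark's behavior, at the cost of extra overhead).
spark.range(1_000).repartition(4).write.mode("overwrite").parquet("/tmp/out")  # placeholder path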
@@ -164,9 +171,6 @@ class GlobalConfig:
         "spark.app.name": lambda session, name: setattr(
             session, "query_tag", f"Spark-Connect-App-Name={name}"
         ),
-        "snowpark.connect.udf.packages": lambda session, packages: session.add_packages(
-            *packages.strip("[] ").split(",")
-        ),
         "snowpark.connect.udf.imports": lambda session, imports: parse_imports(
             session, imports
         ),
@@ -592,6 +596,9 @@ def parse_imports(session: snowpark.Session, imports: str | None) -> None:
     if not imports:
         return

+    # UDF needs to be recreated to include new imports
+    clear_external_udxf_cache(session)
+
     for udf_import in imports.strip("[] ").split(","):
         session.add_import(udf_import)

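Because parse_imports now clears the external UDxF cache before adding imports, any UDF created after an imports change is re-registered with the new files. The handler strips the surrounding brackets and splits on commas, so the value is a bracketed, comma-separated list; a hedged sketch (the file paths are placeholders, and `spark` is the Spark Connect session from the earlier sketch):

# Hypothetical runtime update of UDF imports; cached UDFs are invalidated so
# subsequent UDF calls are rebuilt with the new import list.
spark.conf.set(
    "snowpark.connect.udf.imports",
    "[/tmp/helpers.py,/tmp/models.zip]",  # placeholder files
)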
snowflake/snowpark_connect/dataframe_container.py

@@ -29,6 +29,7 @@ class DataFrameContainer:
         table_name: str | None = None,
         alias: str | None = None,
         cached_schema_getter: Callable[[], StructType] | None = None,
+        partition_hint: int | None = None,
     ) -> None:
         """
         Initialize a new DataFrameContainer.
@@ -39,11 +40,13 @@
             table_name: Optional table name for the DataFrame
             alias: Optional alias for the DataFrame
             cached_schema_getter: Optional function to get cached schema
+            partition_hint: Optional partition count from repartition() operations
         """
         self._dataframe = dataframe
         self._column_map = self._create_default_column_map(column_map)
         self._table_name = table_name
         self._alias = alias
+        self._partition_hint = partition_hint

         if cached_schema_getter is not None:
             self._apply_cached_schema_getter(cached_schema_getter)
@@ -62,6 +65,7 @@ class DataFrameContainer:
         table_name: str | None = None,
         alias: str | None = None,
         cached_schema_getter: Callable[[], StructType] | None = None,
+        partition_hint: int | None = None,
     ) -> DataFrameContainer:
         """
         Create a new container with complete column mapping configuration.
@@ -78,6 +82,7 @@
             table_name: Optional table name
             alias: Optional alias
             cached_schema_getter: Optional function to get cached schema
+            partition_hint: Optional partition count from repartition() operations

         Returns:
             A new DataFrameContainer instance
@@ -123,6 +128,7 @@
             table_name=table_name,
             alias=alias,
             cached_schema_getter=final_schema_getter,
+            partition_hint=partition_hint,
         )

     @property
@@ -163,6 +169,16 @@
         """Set the alias name."""
         self._alias = value

+    @property
+    def partition_hint(self) -> int | None:
+        """Get the partition hint count."""
+        return self._partition_hint
+
+    @partition_hint.setter
+    def partition_hint(self, value: int | None) -> None:
+        """Set the partition hint count."""
+        self._partition_hint = value
+
     def _create_default_column_map(
         self, column_map: ColumnNameMap | None
     ) -> ColumnNameMap:
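The container-level hint is how a repartition(n) request can survive until the write path (map_write.py also changes in the file list above). A minimal sketch of the intended flow, assuming `container` is an existing DataFrameContainer instance; only the property itself comes from this diff:

# Hypothetical internal flow: the relation mapper records the hint, and the
# write path can branch on it when "snowflake.repartition.for.writes" is true.
container.partition_hint = 8          # e.g. recorded while mapping repartition(8)
if container.partition_hint:
    num_output_files = container.partition_hint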
snowflake/snowpark_connect/expression/map_expression.py

@@ -6,6 +6,7 @@ import datetime
 from collections import defaultdict

 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
+from pyspark.errors.exceptions.connect import AnalysisException

 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
@@ -34,6 +35,7 @@ from snowflake.snowpark_connect.type_mapping import (
 from snowflake.snowpark_connect.typed_column import TypedColumn
 from snowflake.snowpark_connect.utils.context import (
     gen_sql_plan_id,
+    get_current_lambda_params,
     is_function_argument_being_resolved,
     is_lambda_being_resolved,
 )
@@ -271,6 +273,19 @@ def map_expression(
         case "unresolved_function":
             return map_func.map_unresolved_function(exp, column_mapping, typer)
         case "unresolved_named_lambda_variable":
+            # Validate that this lambda variable is in scope
+            var_name = exp.unresolved_named_lambda_variable.name_parts[0]
+            current_params = get_current_lambda_params()
+
+            if current_params and var_name not in current_params:
+                raise AnalysisException(
+                    f"Reference to non-lambda variable '{var_name}' within lambda function. "
+                    f"Lambda functions can only access their own parameters. "
+                    f"Available lambda parameters are: {current_params}. "
+                    f"If '{var_name}' is an outer scope lambda variable from a nested lambda, "
+                    f"that is an unsupported feature in Snowflake SQL."
+                )
+
             col = snowpark_fn.Column(
                 UnresolvedAttribute(exp.unresolved_named_lambda_variable.name_parts[0])
             )
snowflake/snowpark_connect/expression/map_udf.py

@@ -13,6 +13,10 @@ from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
 from snowflake.snowpark_connect.typed_column import TypedColumn
+from snowflake.snowpark_connect.utils.external_udxf_cache import (
+    cache_external_udf,
+    get_external_udf_from_cache,
+)
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.udf_helper import (
     SnowparkUDF,
@@ -30,6 +34,39 @@ from snowflake.snowpark_connect.utils.udxf_import_utils import (
 )


+def cache_external_udf_wrapper(from_register_udf: bool):
+    def outer_wrapper(wrapper_func):
+        def wrapper(
+            udf_proto: expressions_proto.CommonInlineUserDefinedFunction,
+        ) -> SnowparkUDF | None:
+            udf_hash = hash(str(udf_proto))
+            cached_udf = get_external_udf_from_cache(udf_hash)
+
+            if cached_udf:
+                session = get_or_create_snowpark_session()
+                function_type = udf_proto.WhichOneof("function")
+                # TODO: Align this with SNOW-2316798 after merge
+                match function_type:
+                    case "scalar_scala_udf":
+                        session._udfs[cached_udf.name] = cached_udf
+                    case "python_udf" if from_register_udf:
+                        session._udfs[udf_proto.function_name.lower()] = cached_udf
+                    case "python_udf":
+                        pass
+                    case _:
+                        raise ValueError(f"Unsupported UDF type: {function_type}")
+
+                return cached_udf
+
+            snowpark_udf = wrapper_func(udf_proto)
+            cache_external_udf(udf_hash, snowpark_udf)
+            return snowpark_udf
+
+        return wrapper
+
+    return outer_wrapper
+
+
 def process_udf_return_type(
     return_type: types_proto.DataType,
 ) -> tuple[snowpark.types.DataType, snowpark.types.DataType]:
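The decorator above memoizes UDF creation on hash(str(udf_proto)), so repeated references to the same inline UDF reuse one server-side function instead of re-registering it on every expression. Stripped of the Snowpark specifics, this is ordinary hash-keyed memoization; a self-contained illustration (the names here are illustrative, not this package's API):

# Illustrative pattern only; the real cache lives in utils/external_udxf_cache.py
# and is keyed the same way: hash(str(proto)).
_cache: dict[int, str] = {}

def cache_by_proto(factory):
    def wrapper(proto):
        key = hash(str(proto))       # cheap, stable-enough cache key
        if key in _cache:
            return _cache[key]       # reuse the already-created function
        created = factory(proto)     # expensive path: actually register it
        _cache[key] = created
        return created
    return wrapper

@cache_by_proto
def register(proto):
    return f"udf_for_{abs(hash(str(proto)))}"   # stand-in for real registration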
@@ -49,6 +86,7 @@ def process_udf_return_type(
     return original_snowpark_type, original_snowpark_type


+@cache_external_udf_wrapper(from_register_udf=True)
 def register_udf(
     udf_proto: expressions_proto.CommonInlineUserDefinedFunction,
 ) -> SnowparkUDF:
@@ -84,12 +122,10 @@ def register_udf(
         return_type=udf._return_type,
         original_return_type=original_return_type,
     )
-
-    #
+    session._udfs[udf_proto.function_name.lower()] = udf
+    # scala udfs can be also accessed using `udf.name`
     if udf_processor._function_type == "scalar_scala_udf":
         session._udfs[udf.name] = udf
-    else:
-        session._udfs[udf_proto.function_name.lower()] = udf
     return udf


@@ -114,29 +150,34 @@ def map_common_inline_user_defined_udf(
         udf_proto.scalar_scala_udf.outputType
     )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    @cache_external_udf_wrapper(from_register_udf=False)
+    def get_snowpark_udf(
+        udf_proto: expressions_proto.CommonInlineUserDefinedFunction,
+    ) -> SnowparkUDF:
+        session = get_or_create_snowpark_session()
+        kwargs = {
+            "common_inline_user_defined_function": udf_proto,
+            "input_types": input_types,
+            "called_from": "map_common_inline_user_defined_udf",
+            "return_type": processed_return_type,
+            "udf_packages": global_config.get("snowpark.connect.udf.packages", ""),
+            "udf_imports": get_python_udxf_import_files(session),
+            "original_return_type": original_return_type,
+        }
+        if require_creating_udf_in_sproc(udf_proto):
+            snowpark_udf = process_udf_in_sproc(**kwargs)
+        else:
+            udf_processor = ProcessCommonInlineUserDefinedFunction(**kwargs)
+            udf = udf_processor.create_udf()
+            snowpark_udf = SnowparkUDF(
+                name=udf.name,
+                input_types=udf._input_types,
+                return_type=udf._return_type,
+                original_return_type=original_return_type,
+            )
+        return snowpark_udf
+
+    snowpark_udf = get_snowpark_udf(udf_proto)
     udf_call_expr = snowpark_fn.call_udf(snowpark_udf.name, *snowpark_udf_args)

     # If the original return type was MapType or StructType but we converted it to VariantType,
snowflake/snowpark_connect/expression/map_unresolved_attribute.py

@@ -22,6 +22,7 @@ from snowflake.snowpark_connect.utils.context import (
     get_is_evaluating_sql,
     get_outer_dataframes,
     get_plan_id_map,
+    is_lambda_being_resolved,
     resolve_lca_alias,
 )
 from snowflake.snowpark_connect.utils.identifiers import (
@@ -356,6 +357,23 @@ def map_unresolved_attribute(
             return (unqualified_name, typed_col)

     if snowpark_name is None:
+        # Check if we're inside a lambda and trying to reference an outer column
+        # This catches direct column references (not lambda variables)
+        if is_lambda_being_resolved() and column_mapping:
+            # Check if this column exists in the outer scope (not lambda params)
+            outer_col_name = (
+                column_mapping.get_snowpark_column_name_from_spark_column_name(
+                    attr_name, allow_non_exists=True
+                )
+            )
+            if outer_col_name:
+                # This is an outer scope column being referenced inside a lambda
+                raise AnalysisException(
+                    f"Reference to non-lambda variable '{attr_name}' within lambda function. "
+                    f"Lambda functions can only access their own parameters. "
+                    f"If '{attr_name}' is a table column, it must be passed as an explicit parameter to the enclosing function."
+                )
+
         if has_plan_id:
             raise AnalysisException(
                 f'[RESOLVED_REFERENCE_COLUMN_NOT_FOUND] The column "{attr_name}" does not exist in the target dataframe.'
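Together with the lambda-variable check in map_expression.py, this surfaces one user-facing restriction: a higher-order-function lambda translated to Snowflake SQL can only see its own parameters. A hedged example of a query that would now fail with the new AnalysisException instead of an opaque SQL error (data and names are placeholders; `spark` is a Spark Connect session as in the earlier sketch):

# Hypothetical failing call: the lambda closes over the table column "offset".
from pyspark.sql import functions as F

df = spark.createDataFrame([([1, 2, 3], 10)], ["nums", "offset"])

# Expected to raise: "Reference to non-lambda variable 'offset' within lambda function..."
df.select(F.transform("nums", lambda x: x + F.col("offset"))).show()

# Works: the lambda only uses its own parameter.
df.select(F.transform("nums", lambda x: x + F.lit(10))).show()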
snowflake/snowpark_connect/expression/map_unresolved_function.py

@@ -476,11 +476,8 @@ def map_unresolved_function(
         return TypedColumn(result, lambda: expected_types)

     match function_name:
-        case func_name if (
-
-        ):
-            # TODO: In Spark, UDFs can override built-in functions in SQL,
-            # but not in DataFrame ops.
+        case func_name if func_name.lower() in session._udfs:
+            # In Spark, UDFs can override built-in functions
             udf = session._udfs[func_name.lower()]
             result_exp = snowpark_fn.call_udf(
                 udf.name,
@@ -714,6 +711,9 @@
                 "-",
             )
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (DateType(), NullType()) | (NullType(), DateType()):
+                    result_type = LongType()
+                    result_exp = snowpark_fn.lit(None).cast(result_type)
                 case (NullType(), _) | (_, NullType()):
                     result_type = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
@@ -727,7 +727,10 @@
                     result_type = LongType()
                     result_exp = snowpark_args[0] - snowpark_args[1]
                 case (DateType(), StringType()):
-                    if
+                    if (
+                        hasattr(snowpark_typed_args[1].col._expr1, "pretty_name")
+                        and "INTERVAL" == snowpark_typed_args[1].col._expr1.pretty_name
+                    ):
                         result_type = TimestampType()
                         result_exp = snowpark_args[0] - snowpark_args[1]
                     else:
@@ -2424,7 +2427,7 @@
                     "try_to_date",
                     snowpark_fn.cast(
                         truncated_date,
-                        TimestampType(
+                        TimestampType(),
                     ),
                     snowpark_args[1],
                 )
@@ -6479,6 +6482,18 @@
             if pattern_value is None:
                 return snowpark_fn.lit(None)

+            # Optimization: treat escaped regex that resolves to a pure literal delimiter
+            # - Single char: "\\."
+            # - Multi char: e.g., "\\.505\\."
+            if re.fullmatch(r"(?:\\.)+", pattern_value):
+                literal_delim = re.sub(r"\\(.)", r"\1", pattern_value)
+                return snowpark_fn.when(
+                    limit <= 0,
+                    snowpark_fn.split(
+                        str_, snowpark_fn.lit(literal_delim)
+                    ).cast(result_type),
+                ).otherwise(native_split)
+
             is_regexp = re.match(
                 ".*[\\[\\.\\]\\*\\?\\+\\^\\$\\{\\}\\|\\(\\)\\\\].*",
                 pattern_value,
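The new branch recognizes split patterns that are nothing but backslash-escaped characters (a pure literal delimiter such as "\\." or "\\.505\\.") and, when no positive limit is given, routes them through a literal split rather than a regex split. A rough example of the client call shape this targets (the data is illustrative; `spark` is a Spark Connect session as above):

# Hypothetical example of the pattern the optimization targets.
from pyspark.sql import functions as F

df = spark.createDataFrame([("10.20.30",)], ["v"])

# "\\." escapes the dot, so the effective delimiter is the literal "." and the
# default limit (-1) keeps the query on the optimized literal-split path.
df.select(F.split("v", "\\.")).show()   # -> [10, 20, 30]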
@@ -8285,15 +8300,6 @@
                 ),
             )
             result_type = BinaryType()
-        case udf_name if udf_name.lower() in session._udfs:
-            # TODO: In Spark, UDFs can override built-in functions in SQL,
-            # but not in DataFrame ops.
-            udf = session._udfs[udf_name.lower()]
-            result_exp = snowpark_fn.call_udf(
-                udf.name,
-                *(snowpark_fn.cast(arg, VariantType()) for arg in snowpark_args),
-            )
-            result_type = udf.return_type
         case udtf_name if udtf_name.lower() in session._udtfs:
             udtf, spark_col_names = session._udtfs[udtf_name.lower()]
             result_exp = snowpark_fn.call_table_function(
@@ -8725,7 +8731,7 @@ def _resolve_function_with_lambda(
     artificial_df = Session.get_active_session().create_dataframe([], schema)
     set_schema_getter(artificial_df, lambda: schema)

-    with resolving_lambda_function():
+    with resolving_lambda_function(names):
         return map_expression(
             (
                 lambda_exp.lambda_function.function
@@ -9623,13 +9629,14 @@ def _get_decimal_division_result_exp(
     snowpark_args: list[Column],
     spark_function_name: str,
 ) -> Column:
-    if
-
-
-
-
-
+    if (
+        isinstance(other_type, DecimalType)
+        and overflow_detected
+        and global_config.spark_sql_ansi_enabled
+    ):
+        raise ArithmeticException(
+            f'[NUMERIC_VALUE_OUT_OF_RANGE] {spark_function_name} cannot be represented as Decimal({result_type.precision}, {result_type.scale}). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error, and return NULL instead.'
+        )
     else:
         dividend = snowpark_args[0].cast(DoubleType())
         divisor = snowpark_args[1]
@@ -9910,7 +9917,10 @@ def _get_spark_function_name(
             return f"({date_param_name1} {operation_op} {date_param_name2})"
         case (StringType(), DateType()):
             date_param_name2 = _get_literal_param_name(exp, 1, snowpark_arg_names[1])
-            if
+            if (
+                hasattr(col1.col._expr1, "pretty_name")
+                and "INTERVAL" == col1.col._expr1.pretty_name
+            ):
                 return f"{date_param_name2} {operation_op} {snowpark_arg_names[0]}"
             elif global_config.spark_sql_ansi_enabled and function_name == "+":
                 return f"{operation_func}(cast({date_param_name2} as date), cast({snowpark_arg_names[0]} as double))"
@@ -9918,9 +9928,9 @@
             return f"({snowpark_arg_names[0]} {operation_op} {date_param_name2})"
         case (DateType(), StringType()):
             date_param_name1 = _get_literal_param_name(exp, 0, snowpark_arg_names[0])
-            if (
-
-
+            if global_config.spark_sql_ansi_enabled or (
+                hasattr(col2.col._expr1, "pretty_name")
+                and "INTERVAL" == col2.col._expr1.pretty_name
             ):
                 return f"{date_param_name1} {operation_op} {snowpark_arg_names[1]}"
             else:
Binary file
Binary file
snowflake/snowpark_connect/relation/map_extension.py

@@ -23,6 +23,7 @@ from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.typed_column import TypedColumn
 from snowflake.snowpark_connect.utils.context import (
     get_sql_aggregate_function_count,
+    not_resolving_fun_args,
     push_outer_dataframe,
     set_current_grouping_columns,
 )
@@ -335,14 +336,15 @@ def map_aggregate(
     typer = ExpressionTyper(input_df)

     def _map_column(exp: expression_proto.Expression) -> tuple[str, TypedColumn]:
-
-
-
-        if len(new_names) != 1:
-            raise SnowparkConnectNotImplementedError(
-                "Multi-column aggregate expressions are not supported"
+        with not_resolving_fun_args():
+            new_names, snowpark_column = map_expression(
+                exp, input_container.column_map, typer
             )
-
+            if len(new_names) != 1:
+                raise SnowparkConnectNotImplementedError(
+                    "Multi-column aggregate expressions are not supported"
+                )
+            return new_names[0], snowpark_column

     raw_groupings: list[tuple[str, TypedColumn]] = []
     raw_aggregations: list[tuple[str, TypedColumn]] = []
snowflake/snowpark_connect/relation/map_map_partitions.py

@@ -8,28 +8,20 @@ from pyspark.sql.connect.proto.expressions_pb2 import CommonInlineUserDefinedFun
 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
 from snowflake.snowpark.types import StructType
-from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.constants import MAP_IN_ARROW_EVAL_TYPE
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
-from snowflake.snowpark_connect.utils.
-from snowflake.snowpark_connect.utils.
-
-
-    require_creating_udf_in_sproc,
-    udf_check,
-)
-from snowflake.snowpark_connect.utils.udf_utils import (
-    ProcessCommonInlineUserDefinedFunction,
+from snowflake.snowpark_connect.utils.context import map_partitions_depth
+from snowflake.snowpark_connect.utils.pandas_udtf_utils import (
+    create_pandas_udtf,
+    create_pandas_udtf_with_arrow,
 )
+from snowflake.snowpark_connect.utils.udf_helper import udf_check
 from snowflake.snowpark_connect.utils.udtf_helper import (
     create_pandas_udtf_in_sproc,
     require_creating_udtf_in_sproc,
 )
-from snowflake.snowpark_connect.utils.udxf_import_utils import (
-    get_python_udxf_import_files,
-)


 def map_map_partitions(
@@ -41,17 +33,10 @@ def map_map_partitions(
     This is a simple wrapper around the `mapInPandas` method in Snowpark.
     """
     input_container = map_relation(rel.map_partitions.input)
-    input_df = input_container.dataframe
     udf_proto = rel.map_partitions.func
     udf_check(udf_proto)

-
-    eval_type = udf_proto.python_udf.eval_type
-
-    if eval_type == MAP_IN_ARROW_EVAL_TYPE:
-        return _map_in_arrow_with_pandas_udtf(input_container, udf_proto)
-    else:
-        return _map_partitions_with_udf(input_df, udf_proto)
+    return _map_with_pandas_udtf(input_container, udf_proto)


 def _call_udtf(
@@ -70,12 +55,17 @@ def _call_udtf(

     udtf_columns = input_df.columns + [snowpark_fn.col("_DUMMY_PARTITION_KEY")]

-
-        snowpark_fn.
-        partition_by=[snowpark_fn.col("_DUMMY_PARTITION_KEY")]
-    )
+    tfc = snowpark_fn.call_table_function(udtf_name, *udtf_columns).over(
+        partition_by=[snowpark_fn.col("_DUMMY_PARTITION_KEY")]
     )

+    # Use map_partitions_depth only when mapping non nested map_partitions
+    # When mapping chained functions additional column casting is necessary
+    if map_partitions_depth() == 1:
+        result_df_with_dummy = input_df_with_dummy.join_table_function(tfc)
+    else:
+        result_df_with_dummy = input_df_with_dummy.select(tfc)
+
     output_cols = [field.name for field in return_type.fields]

     # Only return the output columns.
@@ -89,7 +79,7 @@
     )


-def
+def _map_with_pandas_udtf(
     input_df_container: DataFrameContainer,
     udf_proto: CommonInlineUserDefinedFunction,
 ) -> snowpark.DataFrame:
@@ -99,55 +89,29 @@ def _map_in_arrow_with_pandas_udtf(
     input_df = input_df_container.dataframe
     input_schema = input_df.schema
     spark_column_names = input_df_container.column_map.get_spark_columns()
-    return_type = proto_to_snowpark_type(
+    return_type = proto_to_snowpark_type(
+        udf_proto.python_udf.output_type
+        if udf_proto.WhichOneof("function") == "python_udf"
+        else udf_proto.scalar_scala_udf.outputType
+    )
+
+    # Check if this is mapInArrow (eval_type == 207)
+    map_in_arrow = (
+        udf_proto.WhichOneof("function") == "python_udf"
+        and udf_proto.python_udf.eval_type == MAP_IN_ARROW_EVAL_TYPE
+    )
     if require_creating_udtf_in_sproc(udf_proto):
         udtf_name = create_pandas_udtf_in_sproc(
             udf_proto, spark_column_names, input_schema, return_type
         )
     else:
-
-
-
-
+        if map_in_arrow:
+            map_udtf = create_pandas_udtf_with_arrow(
+                udf_proto, spark_column_names, input_schema, return_type
+            )
+        else:
+            map_udtf = create_pandas_udtf(
+                udf_proto, spark_column_names, input_schema, return_type
+            )
+        udtf_name = map_udtf.name
     return _call_udtf(udtf_name, input_df, return_type)
-
-
-def _map_partitions_with_udf(
-    input_df: snowpark.DataFrame, udf_proto
-) -> snowpark.DataFrame:
-    """
-    Original UDF-based approach for non-mapInArrow map_partitions cases.
-    """
-    input_column_names = input_df.columns
-    kwargs = {
-        "common_inline_user_defined_function": udf_proto,
-        "input_types": [f.datatype for f in input_df.schema.fields],
-        "called_from": "map_map_partitions",
-        "udf_name": "spark_map_partitions_udf",
-        "input_column_names": input_column_names,
-        "replace": True,
-        "return_type": proto_to_snowpark_type(udf_proto.python_udf.output_type),
-        "udf_packages": global_config.get("snowpark.connect.udf.packages", ""),
-        "udf_imports": get_python_udxf_import_files(input_df.session),
-    }
-
-    if require_creating_udf_in_sproc(udf_proto):
-        snowpark_udf = process_udf_in_sproc(**kwargs)
-    else:
-        udf_processor = ProcessCommonInlineUserDefinedFunction(**kwargs)
-        udf = udf_processor.create_udf()
-        snowpark_udf = SnowparkUDF(
-            name=udf.name,
-            input_types=udf._input_types,
-            return_type=udf._return_type,
-            original_return_type=None,
-        )
-    udf_column_name = "UDF_OUTPUT"
-    snowpark_columns = [snowpark_fn.col(name) for name in input_df.columns]
-    result = input_df.select(snowpark_fn.call_udf(snowpark_udf.name, *snowpark_columns))
-    return DataFrameContainer.create_with_column_mapping(
-        dataframe=result,
-        spark_column_names=[udf_column_name],
-        snowpark_column_names=[udf_column_name],
-        snowpark_column_types=[snowpark_udf.return_type],
-    )
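With these changes, both mapInPandas and mapInArrow are served by a single pandas UDTF path (_map_with_pandas_udtf), with the Arrow variant selected by the eval type and nested calls handled via map_partitions_depth. A hedged sketch of the client-side calls routed through this path (the schema and transformation logic are placeholders; `spark` is a Spark Connect session as in the earlier sketches):

# Hypothetical client-side calls routed through _map_with_pandas_udtf.
def double_values(batches):          # mapInPandas: iterator of pandas DataFrames
    for pdf in batches:
        pdf["id"] = pdf["id"] * 2
        yield pdf

df = spark.range(10)
df.mapInPandas(double_values, schema="id long").show()

def passthrough(batches):            # mapInArrow: iterator of pyarrow RecordBatches
    yield from batches

df.mapInArrow(passthrough, schema="id long").show()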