snowpark-connect 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Potentially problematic release.
- snowflake/snowpark_connect/config.py +10 -0
- snowflake/snowpark_connect/dataframe_container.py +16 -0
- snowflake/snowpark_connect/expression/map_udf.py +68 -27
- snowflake/snowpark_connect/expression/map_unresolved_function.py +22 -21
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/relation/map_map_partitions.py +9 -4
- snowflake/snowpark_connect/relation/map_relation.py +12 -1
- snowflake/snowpark_connect/relation/map_row_ops.py +8 -1
- snowflake/snowpark_connect/relation/map_udtf.py +96 -44
- snowflake/snowpark_connect/relation/utils.py +44 -0
- snowflake/snowpark_connect/relation/write/map_write.py +113 -22
- snowflake/snowpark_connect/resources_initializer.py +18 -5
- snowflake/snowpark_connect/server.py +8 -1
- snowflake/snowpark_connect/utils/concurrent.py +4 -0
- snowflake/snowpark_connect/utils/external_udxf_cache.py +36 -0
- snowflake/snowpark_connect/utils/scala_udf_utils.py +250 -242
- snowflake/snowpark_connect/utils/session.py +4 -0
- snowflake/snowpark_connect/utils/udf_utils.py +7 -17
- snowflake/snowpark_connect/utils/udtf_utils.py +3 -16
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/METADATA +1 -1
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/RECORD +32 -28
- {snowpark_connect-0.25.0.data → snowpark_connect-0.26.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.25.0.data → snowpark_connect-0.26.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.25.0.data → snowpark_connect-0.26.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/config.py

@@ -22,6 +22,9 @@ from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.types import TimestampTimeZone, TimestampType
 from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
 from snowflake.snowpark_connect.utils.context import get_session_id
+from snowflake.snowpark_connect.utils.external_udxf_cache import (
+    clear_external_udxf_cache,
+)
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import (

@@ -136,6 +139,9 @@ class GlobalConfig:
         "spark.sql.parser.quotedRegexColumnNames": "false",
         # custom configs
         "snowpark.connect.version": ".".join(map(str, sas_version)),
+        # Control whether repartition(n) on a DataFrame forces splitting into n files during writes
+        # This matches spark behavior more closely, but introduces overhead.
+        "snowflake.repartition.for.writes": "false",
     }

     boolean_config_list = [

@@ -148,6 +154,7 @@ class GlobalConfig:
         "spark.sql.legacy.allowHashOnMapType",
         "spark.Catalog.databaseFilterInformationSchema",
         "spark.sql.parser.quotedRegexColumnNames",
+        "snowflake.repartition.for.writes",
     ]

     int_config_list = [

@@ -592,6 +599,9 @@ def parse_imports(session: snowpark.Session, imports: str | None) -> None:
     if not imports:
         return

+    # UDF needs to be recreated to include new imports
+    clear_external_udxf_cache(session)
+
     for udf_import in imports.strip("[] ").split(","):
         session.add_import(udf_import)

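The new "snowflake.repartition.for.writes" key is registered as a boolean config alongside the other custom options, so it should be toggleable like any other Spark conf on the client. A minimal usage sketch; how the spark session is obtained depends on your Snowpark Connect setup, and everything except the flag name is illustrative:

    # Hypothetical sketch: enable Spark-like file splitting on writes.
    spark.conf.set("snowflake.repartition.for.writes", "true")

    # With the flag on, a repartition(4) before a write is expected to fan the
    # output out into 4 files (see the partition_hint plumbing below), at the
    # cost of extra overhead.
    df = spark.range(100)
    df.repartition(4).write.mode("overwrite").parquet("/tmp/out")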
snowflake/snowpark_connect/dataframe_container.py

@@ -29,6 +29,7 @@ class DataFrameContainer:
         table_name: str | None = None,
         alias: str | None = None,
         cached_schema_getter: Callable[[], StructType] | None = None,
+        partition_hint: int | None = None,
     ) -> None:
         """
         Initialize a new DataFrameContainer.

@@ -39,11 +40,13 @@
             table_name: Optional table name for the DataFrame
             alias: Optional alias for the DataFrame
             cached_schema_getter: Optional function to get cached schema
+            partition_hint: Optional partition count from repartition() operations
         """
         self._dataframe = dataframe
         self._column_map = self._create_default_column_map(column_map)
         self._table_name = table_name
         self._alias = alias
+        self._partition_hint = partition_hint

         if cached_schema_getter is not None:
             self._apply_cached_schema_getter(cached_schema_getter)

@@ -62,6 +65,7 @@
         table_name: str | None = None,
         alias: str | None = None,
         cached_schema_getter: Callable[[], StructType] | None = None,
+        partition_hint: int | None = None,
     ) -> DataFrameContainer:
         """
         Create a new container with complete column mapping configuration.

@@ -78,6 +82,7 @@
             table_name: Optional table name
             alias: Optional alias
             cached_schema_getter: Optional function to get cached schema
+            partition_hint: Optional partition count from repartition() operations

         Returns:
             A new DataFrameContainer instance

@@ -123,6 +128,7 @@
             table_name=table_name,
             alias=alias,
             cached_schema_getter=final_schema_getter,
+            partition_hint=partition_hint,
         )

     @property

@@ -163,6 +169,16 @@
         """Set the alias name."""
         self._alias = value

+    @property
+    def partition_hint(self) -> int | None:
+        """Get the partition hint count."""
+        return self._partition_hint
+
+    @partition_hint.setter
+    def partition_hint(self, value: int | None) -> None:
+        """Set the partition hint count."""
+        self._partition_hint = value
+
     def _create_default_column_map(
         self, column_map: ColumnNameMap | None
     ) -> ColumnNameMap:
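For context, partition_hint is how the repartition count travels from the logical plan to the writer: map_relation stores it on the container and the write path can read it back. A rough sketch of that flow; constructor arguments other than partition_hint are simplified, and snowpark_df is a placeholder:

    # Hypothetical sketch of the hint's life cycle.
    container = DataFrameContainer(
        dataframe=snowpark_df,   # a snowpark DataFrame produced by earlier mapping steps
        column_map=None,         # defaulted via _create_default_column_map()
        partition_hint=4,        # set by map_relation when it sees repartition(4)
    )

    # Later, at write time, the hint (if any) can drive how many files are produced.
    if container.partition_hint is not None:
        target_file_count = container.partition_hint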
snowflake/snowpark_connect/expression/map_udf.py

@@ -13,6 +13,10 @@ from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
 from snowflake.snowpark_connect.typed_column import TypedColumn
+from snowflake.snowpark_connect.utils.external_udxf_cache import (
+    cache_external_udf,
+    get_external_udf_from_cache,
+)
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.udf_helper import (
     SnowparkUDF,

@@ -30,6 +34,39 @@ from snowflake.snowpark_connect.utils.udxf_import_utils import (
 )


+def cache_external_udf_wrapper(from_register_udf: bool):
+    def outer_wrapper(wrapper_func):
+        def wrapper(
+            udf_proto: expressions_proto.CommonInlineUserDefinedFunction,
+        ) -> SnowparkUDF | None:
+            udf_hash = hash(str(udf_proto))
+            cached_udf = get_external_udf_from_cache(udf_hash)
+
+            if cached_udf:
+                session = get_or_create_snowpark_session()
+                function_type = udf_proto.WhichOneof("function")
+                # TODO: Align this with SNOW-2316798 after merge
+                match function_type:
+                    case "scalar_scala_udf":
+                        session._udfs[cached_udf.name] = cached_udf
+                    case "python_udf" if from_register_udf:
+                        session._udfs[udf_proto.function_name.lower()] = cached_udf
+                    case "python_udf":
+                        pass
+                    case _:
+                        raise ValueError(f"Unsupported UDF type: {function_type}")
+
+                return cached_udf
+
+            snowpark_udf = wrapper_func(udf_proto)
+            cache_external_udf(udf_hash, snowpark_udf)
+            return snowpark_udf
+
+        return wrapper
+
+    return outer_wrapper
+
+
 def process_udf_return_type(
     return_type: types_proto.DataType,
 ) -> tuple[snowpark.types.DataType, snowpark.types.DataType]:

@@ -49,6 +86,7 @@ def process_udf_return_type(
     return original_snowpark_type, original_snowpark_type


+@cache_external_udf_wrapper(from_register_udf=True)
 def register_udf(
     udf_proto: expressions_proto.CommonInlineUserDefinedFunction,
 ) -> SnowparkUDF:

@@ -84,12 +122,10 @@ def register_udf(
         return_type=udf._return_type,
         original_return_type=original_return_type,
     )
-
-    #
+    session._udfs[udf_proto.function_name.lower()] = udf
+    # scala udfs can be also accessed using `udf.name`
     if udf_processor._function_type == "scalar_scala_udf":
         session._udfs[udf.name] = udf
-    else:
-        session._udfs[udf_proto.function_name.lower()] = udf
     return udf


@@ -114,29 +150,34 @@ def map_common_inline_user_defined_udf(
         udf_proto.scalar_scala_udf.outputType
     )

- … [23 removed lines not shown in this rendering]
+    @cache_external_udf_wrapper(from_register_udf=False)
+    def get_snowpark_udf(
+        udf_proto: expressions_proto.CommonInlineUserDefinedFunction,
+    ) -> SnowparkUDF:
+        session = get_or_create_snowpark_session()
+        kwargs = {
+            "common_inline_user_defined_function": udf_proto,
+            "input_types": input_types,
+            "called_from": "map_common_inline_user_defined_udf",
+            "return_type": processed_return_type,
+            "udf_packages": global_config.get("snowpark.connect.udf.packages", ""),
+            "udf_imports": get_python_udxf_import_files(session),
+            "original_return_type": original_return_type,
+        }
+        if require_creating_udf_in_sproc(udf_proto):
+            snowpark_udf = process_udf_in_sproc(**kwargs)
+        else:
+            udf_processor = ProcessCommonInlineUserDefinedFunction(**kwargs)
+            udf = udf_processor.create_udf()
+            snowpark_udf = SnowparkUDF(
+                name=udf.name,
+                input_types=udf._input_types,
+                return_type=udf._return_type,
+                original_return_type=original_return_type,
+            )
+        return snowpark_udf
+
+    snowpark_udf = get_snowpark_udf(udf_proto)
     udf_call_expr = snowpark_fn.call_udf(snowpark_udf.name, *snowpark_udf_args)

     # If the original return type was MapType or StructType but we converted it to VariantType,
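Both register_udf and the new get_snowpark_udf helper are wrapped by cache_external_udf_wrapper, which keys the cache on hash(str(udf_proto)) so that mapping the same inline UDF proto twice reuses the already-created Snowflake function. The new utils/external_udxf_cache.py module itself is not shown in this section; a minimal sketch of the interface its imports imply, illustrative only and not the actual implementation:

    # Illustrative sketch of a cache keyed by the proto hash.
    _udf_cache: dict[int, "SnowparkUDF"] = {}

    def cache_external_udf(udf_hash: int, udf: "SnowparkUDF") -> None:
        _udf_cache[udf_hash] = udf

    def get_external_udf_from_cache(udf_hash: int) -> "SnowparkUDF | None":
        return _udf_cache.get(udf_hash)

    def clear_external_udxf_cache(session) -> None:
        # parse_imports() calls this so cached UDFs are rebuilt with the new imports.
        _udf_cache.clear()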
snowflake/snowpark_connect/expression/map_unresolved_function.py

@@ -476,11 +476,8 @@ def map_unresolved_function(
         return TypedColumn(result, lambda: expected_types)

     match function_name:
-        case func_name if (
-
-        ):
-            # TODO: In Spark, UDFs can override built-in functions in SQL,
-            # but not in DataFrame ops.
+        case func_name if func_name.lower() in session._udfs:
+            # In Spark, UDFs can override built-in functions
             udf = session._udfs[func_name.lower()]
             result_exp = snowpark_fn.call_udf(
                 udf.name,

@@ -6479,6 +6476,18 @@ def map_unresolved_function(
             if pattern_value is None:
                 return snowpark_fn.lit(None)

+            # Optimization: treat escaped regex that resolves to a pure literal delimiter
+            # - Single char: "\\."
+            # - Multi char: e.g., "\\.505\\."
+            if re.fullmatch(r"(?:\\.)+", pattern_value):
+                literal_delim = re.sub(r"\\(.)", r"\1", pattern_value)
+                return snowpark_fn.when(
+                    limit <= 0,
+                    snowpark_fn.split(
+                        str_, snowpark_fn.lit(literal_delim)
+                    ).cast(result_type),
+                ).otherwise(native_split)
+
             is_regexp = re.match(
                 ".*[\\[\\.\\]\\*\\?\\+\\^\\$\\{\\}\\|\\(\\)\\\\].*",
                 pattern_value,
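The added branch gives split() a fast path: when the pattern consists only of escaped characters, it is unescaped and passed to Snowpark's split as a literal delimiter instead of going through the general regex handling. The two regexes can be exercised on their own; a small plain-Python sketch, outside Snowpark:

    import re

    for pattern in [r"\.", r"\|", "a.b"]:
        if re.fullmatch(r"(?:\\.)+", pattern):
            literal = re.sub(r"\\(.)", r"\1", pattern)
            print(f"{pattern!r} -> literal delimiter {literal!r}")   # e.g. '\\.' -> '.'
        else:
            print(f"{pattern!r} -> handled by the regex path")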
@@ -8285,15 +8294,6 @@
                 ),
             )
             result_type = BinaryType()
-        case udf_name if udf_name.lower() in session._udfs:
-            # TODO: In Spark, UDFs can override built-in functions in SQL,
-            # but not in DataFrame ops.
-            udf = session._udfs[udf_name.lower()]
-            result_exp = snowpark_fn.call_udf(
-                udf.name,
-                *(snowpark_fn.cast(arg, VariantType()) for arg in snowpark_args),
-            )
-            result_type = udf.return_type
         case udtf_name if udtf_name.lower() in session._udtfs:
             udtf, spark_col_names = session._udtfs[udtf_name.lower()]
             result_exp = snowpark_fn.call_table_function(

@@ -9623,13 +9623,14 @@ def _get_decimal_division_result_exp(
     snowpark_args: list[Column],
     spark_function_name: str,
 ) -> Column:
-    if
- … [6 removed lines not shown in this rendering]
+    if (
+        isinstance(other_type, DecimalType)
+        and overflow_detected
+        and global_config.spark_sql_ansi_enabled
+    ):
+        raise ArithmeticException(
+            f'[NUMERIC_VALUE_OUT_OF_RANGE] {spark_function_name} cannot be represented as Decimal({result_type.precision}, {result_type.scale}). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error, and return NULL instead.'
+        )
     else:
         dividend = snowpark_args[0].cast(DoubleType())
         divisor = snowpark_args[1]
Binary file
Binary file
snowflake/snowpark_connect/relation/map_map_partitions.py

@@ -46,9 +46,10 @@ def map_map_partitions(
     udf_check(udf_proto)

     # Check if this is mapInArrow (eval_type == 207)
- … [3 removed lines not shown in this rendering]
+    if (
+        udf_proto.WhichOneof("function") == "python_udf"
+        and udf_proto.python_udf.eval_type == MAP_IN_ARROW_EVAL_TYPE
+    ):
         return _map_in_arrow_with_pandas_udtf(input_container, udf_proto)
     else:
         return _map_partitions_with_udf(input_df, udf_proto)

@@ -126,7 +127,11 @@ def _map_partitions_with_udf(
         "udf_name": "spark_map_partitions_udf",
         "input_column_names": input_column_names,
         "replace": True,
-        "return_type": proto_to_snowpark_type(
+        "return_type": proto_to_snowpark_type(
+            udf_proto.python_udf.output_type
+            if udf_proto.WhichOneof("function") == "python_udf"
+            else udf_proto.scalar_scala_udf.outputType
+        ),
         "udf_packages": global_config.get("snowpark.connect.udf.packages", ""),
         "udf_imports": get_python_udxf_import_files(input_df.session),
     }
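The eval_type check corresponds to DataFrame.mapInArrow on the PySpark client (eval type 207), which now routes to the pandas-UDTF based implementation. A client-side sketch of a call that would take that branch, assuming a Spark Connect session named spark and pyarrow installed:

    import pyarrow as pa

    def double_ids(batches):
        # Receives an iterator of pyarrow.RecordBatch, yields RecordBatch.
        for batch in batches:
            yield pa.RecordBatch.from_pydict(
                {"id": pa.compute.multiply(batch.column(0), 2)}
            )

    spark.range(5).mapInArrow(double_ids, schema="id long").show()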
snowflake/snowpark_connect/relation/map_relation.py

@@ -90,6 +90,7 @@ def map_relation(
             table_name=copy.deepcopy(cached_container.table_name),
             alias=cached_container.alias,
             cached_schema_getter=lambda: cached_df.schema,
+            partition_hint=cached_container.partition_hint,
         )
         # If we don't make a copy of the df._output, the expression IDs for attributes in Snowpark DataFrames will differ from those stored in the cache,
         # leading to errors during query execution.

@@ -189,13 +190,23 @@ def map_relation(
         case "read":
             result = read.map_read(rel)
         case "repartition":
-            #
+            # Preserve partition hint for file output control
+            # This handles both repartition(n) with shuffle=True and coalesce(n) with shuffle=False
             result = map_relation(rel.repartition.input)
+            if rel.repartition.num_partitions > 0:
+                result.partition_hint = rel.repartition.num_partitions
         case "repartition_by_expression":
             # This is a no-op operation in SAS as Snowpark doesn't have the concept of partitions.
             # All the data in the dataframe will be treated as a single partition, and this will not
             # have any side effects.
             result = map_relation(rel.repartition_by_expression.input)
+            # Only preserve partition hint if num_partitions is explicitly specified and > 0
+            # Column-based repartitioning without count should clear any existing partition hints
+            if rel.repartition_by_expression.num_partitions > 0:
+                result.partition_hint = rel.repartition_by_expression.num_partitions
+            else:
+                # Column-based repartitioning clears partition hint (resets to default behavior)
+                result.partition_hint = None
         case "replace":
             result = map_row_ops.map_replace(rel)
         case "sample":
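On the client these branches fire for repartition and coalesce calls; per the mapping above, only an explicit partition count sets the hint, and column-only repartitioning clears it. Sketched with PySpark calls (df is any DataFrame on a Snowpark Connect session):

    df = spark.range(1000)

    df.repartition(8)        # repartition, num_partitions=8              -> partition_hint = 8
    df.coalesce(2)           # repartition with shuffle=False             -> partition_hint = 2
    df.repartition(4, "id")  # repartition_by_expression, num_partitions=4 -> partition_hint = 4
    df.repartition("id")     # repartition_by_expression, no count        -> partition_hint cleared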
snowflake/snowpark_connect/relation/map_row_ops.py

@@ -553,7 +553,14 @@ def map_filter(
         rel.filter.condition, input_container.column_map, typer
     )

-
+    if rel.filter.input.WhichOneof("rel_type") == "subquery_alias":
+        # map_subquery_alias does not actually wrap the DataFrame in an alias or subquery.
+        # Apparently, there are cases (e.g., TpcdsQ53) where this is required, without it, we get
+        # SQL compilation error.
+        # To mitigate it, we are doing .select("*"), .alias() introduces additional describe queries
+        result = input_df.select("*").filter(condition.col)
+    else:
+        result = input_df.filter(condition.col)

     return DataFrameContainer(
         result,
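The workaround only adds a pass-through projection when the filter's input is a subquery alias; in Snowpark terms the two paths differ by a single extra select, as sketched here (df and condition stand in for input_df and the mapped filter condition):

    filtered_default = df.filter(condition)                # normal path
    filtered_aliased = df.select("*").filter(condition)    # subquery_alias path; the extra
                                                           # projection avoids the SQL compilation
                                                           # error seen on queries like TPC-DS Q53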
snowflake/snowpark_connect/relation/map_udtf.py

@@ -31,6 +31,10 @@ from snowflake.snowpark_connect.type_mapping import (
     proto_to_snowpark_type,
 )
 from snowflake.snowpark_connect.utils.context import push_udtf_context
+from snowflake.snowpark_connect.utils.external_udxf_cache import (
+    cache_external_udtf,
+    get_external_udtf_from_cache,
+)
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.udtf_helper import (
     SnowparkUDTF,

@@ -44,6 +48,34 @@ from snowflake.snowpark_connect.utils.udxf_import_utils import (
 )


+def cache_external_udtf_wrapper(from_register_udtf: bool):
+    def outer_wrapper(wrapper_func):
+        def wrapper(
+            udtf_proto: relation_proto.CommonInlineUserDefinedTableFunction,
+            spark_column_names,
+        ) -> SnowparkUDTF | None:
+            udf_hash = hash(str(udtf_proto))
+            cached_udtf = get_external_udtf_from_cache(udf_hash)
+
+            if cached_udtf:
+                if from_register_udtf:
+                    session = get_or_create_snowpark_session()
+                    session._udtfs[udtf_proto.function_name.lower()] = (
+                        cached_udtf,
+                        spark_column_names,
+                    )
+
+                return cached_udtf
+
+            snowpark_udf = wrapper_func(udtf_proto, spark_column_names)
+            cache_external_udtf(udf_hash, snowpark_udf)
+            return snowpark_udf
+
+        return wrapper
+
+    return outer_wrapper
+
+
 def build_expected_types_from_parsed(
     parsed_return: types_proto.DataType,
 ) -> List[Tuple[str, Any]]:

@@ -165,26 +197,37 @@ def register_udtf(
     ) = process_return_type(python_udft.return_type)
     function_name = udtf_proto.function_name

- … [19 removed lines not shown in this rendering]
+    @cache_external_udtf_wrapper(from_register_udtf=True)
+    def _register_udtf(
+        udtf_proto: relation_proto.CommonInlineUserDefinedTableFunction,
+        spark_column_names,
+    ):
+        kwargs = {
+            "session": session,
+            "udtf_proto": udtf_proto,
+            "expected_types": expected_types,
+            "output_schema": output_schema,
+            "packages": global_config.get("snowpark.connect.udf.packages", ""),
+            "imports": get_python_udxf_import_files(session),
+            "called_from": "register_udtf",
+            "is_arrow_enabled": is_arrow_enabled_in_udtf(),
+            "is_spark_compatible_udtf_mode_enabled": is_spark_compatible_udtf_mode_enabled(),
+        }
+
+        if require_creating_udtf_in_sproc(udtf_proto):
+            snowpark_udtf = create_udtf_in_sproc(**kwargs)
+        else:
+            udtf = create_udtf(**kwargs)
+            snowpark_udtf = SnowparkUDTF(
+                name=udtf.name,
+                input_types=udtf._input_types,
+                output_schema=output_schema,
+            )
+
+        return snowpark_udtf

+    snowpark_udtf = _register_udtf(udtf_proto, spark_column_names)
+    # We have to update cached _udtfs here, because function could have been cached in map_common_inline_user_defined_table_function
     session._udtfs[function_name.lower()] = (snowpark_udtf, spark_column_names)
     return snowpark_udtf

@@ -213,32 +256,41 @@ def map_common_inline_user_defined_table_function(
         spark_column_names,
     ) = process_return_type(python_udft.return_type)

- … [25 removed lines not shown in this rendering]
+    @cache_external_udtf_wrapper(from_register_udtf=False)
+    def _get_udtf(
+        udtf_proto: relation_proto.CommonInlineUserDefinedTableFunction,
+        spark_column_names,
+    ):
+        kwargs = {
+            "session": session,
+            "udtf_proto": udtf_proto,
+            "expected_types": expected_types,
+            "output_schema": output_schema,
+            "packages": global_config.get("snowpark.connect.udf.packages", ""),
+            "imports": get_python_udxf_import_files(session),
+            "called_from": "map_common_inline_user_defined_table_function",
+            "is_arrow_enabled": is_arrow_enabled_in_udtf(),
+            "is_spark_compatible_udtf_mode_enabled": is_spark_compatible_udtf_mode_enabled(),
+        }
+
+        if require_creating_udtf_in_sproc(udtf_proto):
+            snowpark_udtf_or_error = create_udtf_in_sproc(**kwargs)
+            if isinstance(snowpark_udtf_or_error, str):
+                raise PythonException(snowpark_udtf_or_error)
+            snowpark_udtf = snowpark_udtf_or_error
+        else:
+            udtf_or_error = create_udtf(**kwargs)
+            if isinstance(udtf_or_error, str):
+                raise PythonException(udtf_or_error)
+            udtf = udtf_or_error
+            snowpark_udtf = SnowparkUDTF(
+                name=udtf.name,
+                input_types=udtf._input_types,
+                output_schema=output_schema,
+            )
+        return snowpark_udtf

+    snowpark_udtf = _get_udtf(rel, spark_column_names)
     column_map = ColumnNameMap([], [])
     snowpark_udtf_args = []

snowflake/snowpark_connect/relation/utils.py

@@ -6,6 +6,7 @@ import random
 import re
 import string
 import time
+import uuid
 from typing import Sequence

 import pyspark.sql.connect.proto.relations_pb2 as relation_proto

@@ -153,6 +154,49 @@ def random_string(
     return "".join([prefix, random_part, suffix])


+def generate_spark_compatible_filename(
+    task_id: int = 0,
+    attempt_number: int = 0,
+    compression: str = None,
+    format_ext: str = "parquet",
+) -> str:
+    """Generate a Spark-compatible filename following the convention:
+    part-<task-id>-<uuid>-c<attempt-number>.<compression>.<format>
+
+    Args:
+        task_id: Task ID (usually 0 for single partition)
+        attempt_number: Attempt number (usually 0)
+        compression: Compression type (e.g., 'snappy', 'gzip', 'none')
+        format_ext: File format extension (e.g., 'parquet', 'csv', 'json')
+
+    Returns:
+        A filename string following Spark's naming convention
+    """
+    # Generate a UUID for uniqueness
+    file_uuid = str(uuid.uuid4())
+
+    # Format task ID with leading zeros (5 digits)
+    formatted_task_id = f"{task_id:05d}"
+
+    # Format attempt number with leading zeros (3 digits)
+    formatted_attempt = f"{attempt_number:03d}"
+
+    # Build the base filename
+    base_name = f"part-{formatted_task_id}-{file_uuid}-c{formatted_attempt}"
+
+    # Add compression if specified and not 'none'
+    if compression and compression.lower() not in ("none", "uncompressed"):
+        compression_part = f".{compression.lower()}"
+    else:
+        compression_part = ""
+
+    # Add format extension if specified
+    if format_ext:
+        return f"{base_name}{compression_part}.{format_ext}"
+    else:
+        return f"{base_name}{compression_part}"
+
+
 def _normalize_query_for_semantic_hash(query_str: str) -> str:
     """
     Normalize a query string for semantic comparison by extracting original names from
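Since the new filename helper above is pure Python, its output can be checked directly; a quick sketch of the names it produces (the UUID portion differs on every call, so the values shown are only indicative):

    generate_spark_compatible_filename()
    # e.g. 'part-00000-<uuid>-c000.parquet'

    generate_spark_compatible_filename(task_id=3, compression="snappy", format_ext="parquet")
    # e.g. 'part-00003-<uuid>-c000.snappy.parquet'

    generate_spark_compatible_filename(compression="none", format_ext="csv")
    # e.g. 'part-00000-<uuid>-c000.csv'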
|