snowpark-connect 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/config.py +10 -3
- snowflake/snowpark_connect/dataframe_container.py +16 -0
- snowflake/snowpark_connect/expression/map_expression.py +15 -0
- snowflake/snowpark_connect/expression/map_udf.py +68 -27
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +18 -0
- snowflake/snowpark_connect/expression/map_unresolved_function.py +38 -28
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/relation/map_extension.py +9 -7
- snowflake/snowpark_connect/relation/map_map_partitions.py +36 -72
- snowflake/snowpark_connect/relation/map_relation.py +15 -2
- snowflake/snowpark_connect/relation/map_row_ops.py +8 -1
- snowflake/snowpark_connect/relation/map_show_string.py +2 -0
- snowflake/snowpark_connect/relation/map_sql.py +63 -2
- snowflake/snowpark_connect/relation/map_udtf.py +96 -44
- snowflake/snowpark_connect/relation/utils.py +44 -0
- snowflake/snowpark_connect/relation/write/map_write.py +135 -24
- snowflake/snowpark_connect/resources_initializer.py +18 -5
- snowflake/snowpark_connect/server.py +12 -2
- snowflake/snowpark_connect/utils/artifacts.py +4 -5
- snowflake/snowpark_connect/utils/concurrent.py +4 -0
- snowflake/snowpark_connect/utils/context.py +41 -1
- snowflake/snowpark_connect/utils/external_udxf_cache.py +36 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +86 -2
- snowflake/snowpark_connect/utils/scala_udf_utils.py +250 -242
- snowflake/snowpark_connect/utils/session.py +4 -0
- snowflake/snowpark_connect/utils/udf_utils.py +71 -118
- snowflake/snowpark_connect/utils/udtf_helper.py +17 -7
- snowflake/snowpark_connect/utils/udtf_utils.py +3 -16
- snowflake/snowpark_connect/version.py +2 -3
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/METADATA +2 -2
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/RECORD +41 -37
- {snowpark_connect-0.25.0.data → snowpark_connect-0.27.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.25.0.data → snowpark_connect-0.27.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.25.0.data → snowpark_connect-0.27.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/top_level.txt +0 -0
--- a/snowflake/snowpark_connect/relation/write/map_write.py
+++ b/snowflake/snowpark_connect/relation/write/map_write.py
@@ -16,7 +16,7 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     unquote_if_quoted,
 )
 from snowflake.snowpark.exceptions import SnowparkSQLException
-from snowflake.snowpark.functions import col, lit, object_construct
+from snowflake.snowpark.functions import col, lit, object_construct, sql_expr
 from snowflake.snowpark.types import (
     ArrayType,
     DataType,
@@ -40,7 +40,10 @@ from snowflake.snowpark_connect.relation.io_utils import (
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.relation.read.reader_config import CsvWriterConfig
 from snowflake.snowpark_connect.relation.stage_locator import get_paths_from_stage
-from snowflake.snowpark_connect.relation.utils import
+from snowflake.snowpark_connect.relation.utils import (
+    generate_spark_compatible_filename,
+    random_string,
+)
 from snowflake.snowpark_connect.type_mapping import snowpark_to_iceberg_type
 from snowflake.snowpark_connect.utils.context import get_session_id
 from snowflake.snowpark_connect.utils.identifiers import (
@@ -48,6 +51,7 @@ from snowflake.snowpark_connect.utils.identifiers import (
     split_fully_qualified_spark_name,
 )
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
+from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
     telemetry,
@@ -133,45 +137,99 @@ def map_write(request: proto_base.ExecutePlanRequest):
         write_op.source = ""

     should_write_to_single_file = str_to_bool(write_op.options.get("single", "false"))
-
-
-
-
-
+
+    # Support Snowflake-specific snowflake_max_file_size option. This is NOT a spark option.
+    max_file_size = None
+    if (
+        "snowflake_max_file_size" in write_op.options
+        and int(write_op.options["snowflake_max_file_size"]) > 0
+    ):
+        max_file_size = int(write_op.options["snowflake_max_file_size"])
+    elif should_write_to_single_file:
+        # providing default size as 1GB for single file write
+        max_file_size = 1073741824
     match write_op.source:
         case "csv" | "parquet" | "json" | "text":
             write_path = get_paths_from_stage(
                 [write_op.path],
                 session=session,
             )[0]
+            # Generate Spark-compatible filename with proper extension
+            extension = write_op.source if write_op.source != "text" else "txt"
+
+            # Get compression from options for proper filename generation
+            compression_option = write_op.options.get("compression", "none")
+
+            # Generate Spark-compatible filename or prefix
             # we need a random prefix to support "append" mode
             # otherwise copy into with overwrite=False will fail if the file already exists
-            if should_write_to_single_file:
-                extention = write_op.source if write_op.source != "text" else "txt"
-                temp_file_prefix_on_stage = (
-                    f"{write_path}/{random_string(10, 'sas_file_')}.{extention}"
-                )
-            else:
-                temp_file_prefix_on_stage = (
-                    f"{write_path}/{random_string(10, 'sas_file_')}"
-                )
             overwrite = (
                 write_op.mode
                 == commands_proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE
             )
+
+            if overwrite:
+                try:
+                    path_after_stage = (
+                        write_path.split("/", 1)[1] if "/" in write_path else ""
+                    )
+                    if not path_after_stage or path_after_stage == "/":
+                        logger.warning(
+                            f"Skipping REMOVE for root path {write_path} - too broad scope"
+                        )
+                    else:
+                        remove_command = f"REMOVE {write_path}/"
+                        session.sql(remove_command).collect()
+                        logger.info(f"Successfully cleared directory: {write_path}")
+                except Exception as e:
+                    logger.warning(f"Could not clear directory {write_path}: {e}")
+
+            if should_write_to_single_file:
+                # Single file: generate complete filename with extension
+                spark_filename = generate_spark_compatible_filename(
+                    task_id=0,
+                    attempt_number=0,
+                    compression=compression_option,
+                    format_ext=extension,
+                )
+                temp_file_prefix_on_stage = f"{write_path}/{spark_filename}"
+            else:
+                # Multiple files: generate prefix without extension (Snowflake will add extensions)
+                spark_filename_prefix = generate_spark_compatible_filename(
+                    task_id=0,
+                    attempt_number=0,
+                    compression=compression_option,
+                    format_ext="",  # No extension for prefix
+                )
+                temp_file_prefix_on_stage = f"{write_path}/{spark_filename_prefix}"
+
+            default_compression = "NONE" if write_op.source != "parquet" else "snappy"
+            compression = write_op.options.get(
+                "compression", default_compression
+            ).upper()
             parameters = {
                 "location": temp_file_prefix_on_stage,
                 "file_format_type": write_op.source
                 if write_op.source != "text"
                 else "csv",
                 "format_type_options": {
-                    "COMPRESSION":
+                    "COMPRESSION": compression,
                 },
                 "overwrite": overwrite,
             }
-
-
+            # By default, download from the same prefix we wrote to.
+            download_stage_path = temp_file_prefix_on_stage
+
+            # Check for partition hint early to determine precedence over single option
+            partition_hint = result.partition_hint
+
+            # Apply max_file_size for both single and multi-file scenarios
+            # This helps control when Snowflake splits files into multiple parts
+            if max_file_size:
                 parameters["max_file_size"] = max_file_size
+            # Only apply single option if no partition hint is present (partition hint takes precedence)
+            if should_write_to_single_file and partition_hint is None:
+                parameters["single"] = True
             rewritten_df: snowpark.DataFrame = rewrite_df(input_df, write_op.source)
             get_param_from_options(parameters, write_op.options, write_op.source)
             if write_op.partitioning_columns:
@@ -186,10 +244,50 @@ def map_write(request: proto_base.ExecutePlanRequest):
                     )
                 else:
                     parameters["partition_by"] = partitioning_columns[0]
-
+
+            # If a partition hint is present (from DataFrame.repartition(n)), optionally split the
+            # write into n COPY INTO calls by assigning a synthetic partition id. Controlled by config.
+            # Note: This affects only the number of output files, not computation semantics.
+            # Partition hints take precedence over single option (matches Spark behavior) when enabled.
+            repartition_for_writes_enabled = (
+                global_config.snowflake_repartition_for_writes
+            )
+            if repartition_for_writes_enabled and partition_hint and partition_hint > 0:
+                # Create a stable synthetic file number per row using ROW_NUMBER() over a
+                # randomized order, then modulo partition_hint. We rely on sql_expr to avoid
+                # adding new helpers.
+                file_num_col = "_sas_file_num"
+                partitioned_df = rewritten_df.withColumn(
+                    file_num_col,
+                    sql_expr(
+                        f"(ROW_NUMBER() OVER (ORDER BY RANDOM())) % {partition_hint}"
+                    ),
+                )
+
+                # Execute multiple COPY INTO operations, one per target file.
+                # Since we write per-partition with distinct prefixes, download from the base write path.
+                download_stage_path = write_path
+                for part_idx in range(partition_hint):
+                    part_params = dict(parameters)
+                    # Preserve Spark-like filename prefix per partition so downloaded basenames
+                    # match the expected Spark pattern (with possible Snowflake counters appended).
+                    per_part_prefix = generate_spark_compatible_filename(
+                        task_id=part_idx,
+                        attempt_number=0,
+                        compression=compression_option,
+                        format_ext="",  # prefix only; Snowflake appends extension/counters
+                    )
+                    part_params["location"] = f"{write_path}/{per_part_prefix}"
+                    (
+                        partitioned_df.filter(col(file_num_col) == lit(part_idx))
+                        .drop(file_num_col)
+                        .write.copy_into_location(**part_params)
+                    )
+            else:
+                rewritten_df.write.copy_into_location(**parameters)
             if not is_cloud_path(write_op.path):
                 store_files_locally(
-
+                    download_stage_path,
                     write_op.path,
                     overwrite,
                     session,
@@ -569,7 +667,12 @@ def _validate_schema_and_get_writer(
         col_name = field.name
         renamed = col_name
         matching_field = next(
-            (
+            (
+                f
+                for f in table_schema.fields
+                if unquote_if_quoted(f.name).lower()
+                == unquote_if_quoted(col_name).lower()
+            ),
             None,
         )
         if matching_field is not None and matching_field != col_name:
@@ -591,7 +694,10 @@ def _validate_schema_and_get_writer(


 def _validate_schema_for_append(
-    table_schema: DataType,
+    table_schema: DataType,
+    data_schema: DataType,
+    snowpark_table_name: str,
+    compare_structs: bool = False,
 ):
     match (table_schema, data_schema):
         case (_, _) if table_schema == data_schema:
@@ -600,7 +706,11 @@ def _validate_schema_for_append(
         case (StructType() as table_struct, StructType() as data_struct):

             def _comparable_col_name(col: str) -> str:
-
+                name = col if global_config.spark_sql_caseSensitive else col.lower()
+                if compare_structs:
+                    return name
+                else:
+                    return unquote_if_quoted(name)

             def invalid_struct_schema():
                 raise AnalysisException(
@@ -640,6 +750,7 @@ def _validate_schema_for_append(
                         matching_table_field.datatype,
                         data_field.datatype,
                         snowpark_table_name,
+                        compare_structs=True,
                     )

             return

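Note on the repartition-for-writes path above: rows are tagged with a synthetic file number via `ROW_NUMBER() OVER (ORDER BY RANDOM()) % n`, and one `COPY INTO <location>` runs per number, so a `repartition(n)` hint maps to roughly n output files. Below is a minimal standalone Snowpark sketch of that splitting idea; the table name `MY_TABLE` and stage path `@my_stage/out` are hypothetical placeholders, not part of this package.

```python
# Sketch only: split one logical write into n output files by tagging rows with a
# synthetic file number and issuing one COPY INTO per number.
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col, lit, sql_expr


def write_in_n_files(session: Session, n: int) -> None:
    df = session.table("MY_TABLE")  # hypothetical source table
    # Stable synthetic file number in [0, n) per row, over a randomized order.
    numbered = df.with_column(
        "FILE_NUM", sql_expr(f"(ROW_NUMBER() OVER (ORDER BY RANDOM())) % {n}")
    )
    for i in range(n):
        (
            numbered.filter(col("FILE_NUM") == lit(i))
            .drop("FILE_NUM")
            .write.copy_into_location(
                f"@my_stage/out/part-{i:05d}",  # hypothetical per-file prefix
                file_format_type="parquet",
                overwrite=True,
            )
        )
```
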
--- a/snowflake/snowpark_connect/resources_initializer.py
+++ b/snowflake/snowpark_connect/resources_initializer.py
@@ -9,6 +9,7 @@ from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_sess
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger

 _resources_initialized = threading.Event()
+_initializer_lock = threading.Lock()
 SPARK_VERSION = "3.5.6"
 RESOURCE_PATH = "/snowflake/snowpark_connect/resources"

@@ -57,6 +58,9 @@ def initialize_resources() -> None:
         f"spark-connect-client-jvm_2.12-{SPARK_VERSION}.jar",
         f"spark-common-utils_2.12-{SPARK_VERSION}.jar",
         "json4s-ast_2.12-3.7.0-M11.jar",
+        "json4s-native_2.12-3.7.0-M11.jar",
+        "json4s-core_2.12-3.7.0-M11.jar",
+        "paranamer-2.8.3.jar",
     ]

     for jar in jar_files:
@@ -94,10 +98,19 @@ def initialize_resources() -> None:
     logger.info(f"All resources initialized in {time.time() - start_time:.2f}s")


+_resource_initializer = threading.Thread(
+    target=initialize_resources, name="ResourceInitializer"
+)
+
+
 def initialize_resources_async() -> threading.Thread:
     """Start resource initialization in background."""
-
-
-
-
-
+    with _initializer_lock:
+        if not _resource_initializer.is_alive() and _resource_initializer.ident is None:
+            _resource_initializer.start()
+        return _resource_initializer
+
+
+def wait_for_resource_initialization() -> None:
+    with _initializer_lock:
+        _resource_initializer.join()

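The initializer change above starts a single module-level thread at most once and lets callers either kick it off or block until it finishes. A self-contained sketch of the same start-once pattern, with illustrative names only:

```python
# Start-once background initializer guarded by a lock (illustrative names only).
import threading
import time


def _do_init() -> None:
    time.sleep(0.1)  # stand-in for downloading jars and other setup


_worker = threading.Thread(target=_do_init, name="InitWorker")
_lock = threading.Lock()


def start_init_async() -> threading.Thread:
    with _lock:
        # Thread.ident stays None until start(), so this guards against
        # calling start() twice, which would raise RuntimeError.
        if not _worker.is_alive() and _worker.ident is None:
            _worker.start()
        return _worker


def wait_for_init() -> None:
    with _lock:
        _worker.join()


if __name__ == "__main__":
    start_init_async()
    wait_for_init()
```
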
--- a/snowflake/snowpark_connect/server.py
+++ b/snowflake/snowpark_connect/server.py
@@ -88,6 +88,9 @@ from snowflake.snowpark_connect.utils.context import (
     set_spark_version,
 )
 from snowflake.snowpark_connect.utils.env_utils import get_int_from_env
+from snowflake.snowpark_connect.utils.external_udxf_cache import (
+    clear_external_udxf_cache,
+)
 from snowflake.snowpark_connect.utils.interrupt import (
     interrupt_all_queries,
     interrupt_queries_with_tag,
@@ -436,7 +439,8 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
                     lambda: map_local_relation(relation),  # noqa: B023
                     materialize=True,
                 )
-            except Exception:
+            except Exception as e:
+                logger.warning("Failed to put df into cache: %s", str(e))
                 # fallback - treat as regular artifact
                 _handle_regular_artifact()
             else:
@@ -527,7 +531,10 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
                 if name.endswith(".class"):
                     # name is <dir>/<package>/<class_name>
                     # we don't need the dir name, but require the package, so only remove dir
-
+                    if os.name != "nt":
+                        class_files[name.split("/", 1)[-1]] = filepath
+                    else:
+                        class_files[name.split("\\", 1)[-1]] = filepath
                     continue
                 session.file.put(
                     filepath,
@@ -556,6 +563,9 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
         if class_files:
             write_class_files_to_stage(session, class_files)

+        if any(not name.startswith("cache") for name in filenames.keys()):
+            clear_external_udxf_cache(session)
+
         return proto_base.AddArtifactsResponse(artifacts=list(response.values()))

     def ArtifactStatus(self, request, context):

--- a/snowflake/snowpark_connect/utils/artifacts.py
+++ b/snowflake/snowpark_connect/utils/artifacts.py
@@ -39,7 +39,7 @@ def write_temporary_artifact(
     if os.name != "nt":
         filepath = f"/tmp/sas-{session.session_id}/{name}"
     else:
-        filepath = f"{tempfile.gettempdir()}
+        filepath = f"{tempfile.gettempdir()}\\sas-{session.session_id}\\{name}"
     # The name comes to us as a path (e.g. cache/<name>), so we need to create
     # the parent directory if it doesn't exist to avoid errors during writing.
     pathlib.Path(filepath).parent.mkdir(parents=True, exist_ok=True)
@@ -55,11 +55,10 @@ def write_class_files_to_stage(
 ) -> None:
     if os.name != "nt":
         filepath = f"/tmp/sas-{session.session_id}"
+        jar_name = f'{filepath}/{hashlib.sha256(str(files).encode("utf-8")).hexdigest()[:10]}.jar'
     else:
-        filepath = f"{tempfile.gettempdir()}
-
-        f'{filepath}/{hashlib.sha256(str(files).encode("utf-8")).hexdigest()[:10]}.jar'
-    )
+        filepath = f"{tempfile.gettempdir()}\\sas-{session.session_id}"
+        jar_name = f'{filepath}\\{hashlib.sha256(str(files).encode("utf-8")).hexdigest()[:10]}.jar'
     with zipfile.ZipFile(jar_name, "w", zipfile.ZIP_DEFLATED) as jar:
         for name, path in files.items():
             jar.write(path, name)

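The packaging change above derives the jar name from a short SHA-256 digest of the class-file mapping and zips each file under its package-qualified name. A compact sketch of that naming-plus-zipping step, using hypothetical inputs:

```python
# Sketch: pack .class files into a jar whose name is derived from the file mapping.
import hashlib
import os
import tempfile
import zipfile


def pack_class_files(files: dict[str, str]) -> str:
    """files maps archive names (e.g. 'com/acme/Foo.class') to local file paths."""
    digest = hashlib.sha256(str(files).encode("utf-8")).hexdigest()[:10]
    jar_name = os.path.join(tempfile.gettempdir(), f"{digest}.jar")
    with zipfile.ZipFile(jar_name, "w", zipfile.ZIP_DEFLATED) as jar:
        for name, path in files.items():
            # Store each file under its package-qualified archive name.
            jar.write(path, name)
    return jar_name
```
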
--- a/snowflake/snowpark_connect/utils/context.py
+++ b/snowflake/snowpark_connect/utils/context.py
@@ -30,6 +30,9 @@ _sql_aggregate_function_count = ContextVar[int](
     "_contains_aggregate_function", default=0
 )

+# Context for parsing map_partitions
+_map_partitions_stack = ContextVar[int]("_map_partitions_stack", default=0)
+
 # We have to generate our own plan IDs that are different from Spark's.
 # Spark plan IDs start at 0, so pick a "big enough" number to avoid overlaps.
 _STARTING_SQL_PLAN_ID = 0x80000000
@@ -49,6 +52,7 @@ _spark_client_type_regex = re.compile(r"spark/(?P<spark_version>\d+\.\d+\.\d+)")
 _current_operation = ContextVar[str]("_current_operation", default="default")
 _resolving_fun_args = ContextVar[bool]("_resolving_fun_args", default=False)
 _resolving_lambda_fun = ContextVar[bool]("_resolving_lambdas", default=False)
+_current_lambda_params = ContextVar[list[str]]("_current_lambda_params", default=[])

 _is_window_enabled = ContextVar[bool]("_is_window_enabled", default=False)
 _is_in_pivot = ContextVar[bool]("_is_in_pivot", default=False)
@@ -206,6 +210,16 @@ def push_evaluating_join_condition(join_type, left_keys, right_keys):
         _is_evaluating_join_condition.set(prev)


+@contextmanager
+def push_map_partitions():
+    _map_partitions_stack.set(_map_partitions_stack.get() + 1)
+    yield
+
+
+def map_partitions_depth() -> int:
+    return _map_partitions_stack.get()
+
+
 @contextmanager
 def push_sql_scope():
     """
@@ -238,16 +252,21 @@ def push_operation_scope(operation: str):


 @contextmanager
-def resolving_lambda_function():
+def resolving_lambda_function(param_names: list[str] = None):
     """
     Context manager that sets a flag indicating lambda function is being resolved.
+    Also tracks the lambda parameter names for validation.
     """
     prev = _resolving_lambda_fun.get()
+    prev_params = _current_lambda_params.get()
     try:
         _resolving_lambda_fun.set(True)
+        if param_names is not None:
+            _current_lambda_params.set(param_names)
         yield
     finally:
         _resolving_lambda_fun.set(prev)
+        _current_lambda_params.set(prev_params)


 def is_lambda_being_resolved() -> bool:
@@ -257,6 +276,13 @@ def is_lambda_being_resolved() -> bool:
     return _resolving_lambda_fun.get()


+def get_current_lambda_params() -> list[str]:
+    """
+    Returns the current lambda parameter names.
+    """
+    return _current_lambda_params.get()
+
+
 @contextmanager
 def resolving_fun_args():
     """
@@ -270,6 +296,19 @@ def resolving_fun_args():
         _resolving_fun_args.set(prev)


+@contextmanager
+def not_resolving_fun_args():
+    """
+    Context manager that sets a flag indicating function arguments are *not* being resolved.
+    """
+    prev = _resolving_fun_args.get()
+    try:
+        _resolving_fun_args.set(False)
+        yield
+    finally:
+        _resolving_fun_args.set(prev)
+
+
 def is_function_argument_being_resolved() -> bool:
     """
     Returns True if function arguments are being resolved.
@@ -350,6 +389,7 @@ def clear_context_data() -> None:

     _next_sql_plan_id.set(_STARTING_SQL_PLAN_ID)
     _sql_plan_name_map.set({})
+    _map_partitions_stack.set(0)
     _sql_aggregate_function_count.set(0)
     _sql_named_args.set({})
     _sql_pos_args.set({})

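The new context helpers above all follow the same ContextVar discipline: read the previous value, set a scoped value, and restore the previous value in a `finally` block so nesting and exceptions behave correctly. A generic sketch of that pattern with illustrative names, not the module's actual API:

```python
# Generic save/set/restore pattern for a ContextVar-backed scope (illustrative names).
from contextlib import contextmanager
from contextvars import ContextVar

_current_params: ContextVar[list[str]] = ContextVar("_current_params", default=[])


@contextmanager
def scoped_params(params: list[str]):
    prev = _current_params.get()
    try:
        _current_params.set(params)
        yield
    finally:
        # Restore the previous value even if the body raised.
        _current_params.set(prev)


def current_params() -> list[str]:
    return _current_params.get()


# Values set inside the block are visible to code called from it,
# and the previous value is restored afterwards.
with scoped_params(["x", "y"]):
    assert current_params() == ["x", "y"]
assert current_params() == []
```
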
--- /dev/null
+++ b/snowflake/snowpark_connect/utils/external_udxf_cache.py
@@ -0,0 +1,36 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+from snowflake.snowpark import Session
+from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
+
+
+def init_external_udxf_cache(session: Session) -> None:
+    session.external_udfs_cache = SynchronizedDict()
+    session.external_udtfs_cache = SynchronizedDict()
+
+
+def clear_external_udxf_cache(session: Session) -> None:
+    session.external_udfs_cache.clear()
+    session.external_udtfs_cache.clear()
+
+
+def get_external_udf_from_cache(hash: str):
+    return Session.get_active_session().external_udfs_cache.get(hash)
+
+
+def cache_external_udf(hash: int, udf):
+    Session.get_active_session().external_udfs_cache[hash] = udf
+
+
+def clear_external_udtf_cache(session: Session) -> None:
+    session.external_udtfs_cache.clear()
+
+
+def get_external_udtf_from_cache(hash: int):
+    return Session.get_active_session().external_udtfs_cache.get(hash)
+
+
+def cache_external_udtf(hash: int, udf):
+    Session.get_active_session().external_udtfs_cache[hash] = udf

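The new cache module depends on `SynchronizedDict` from `utils/concurrent.py`, whose body is not part of this diff (that file changes by only +4 lines). The sketch below is an assumption about what such a lock-guarded dict might provide, covering just the operations used above (`[key] = value`, `.get()`, `.clear()`); the real implementation may differ:

```python
# Assumed sketch of a lock-guarded dict; the real SynchronizedDict is not shown here.
import threading
from typing import Any


class SynchronizedDict:
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._data: dict[Any, Any] = {}

    def __setitem__(self, key: Any, value: Any) -> None:
        with self._lock:
            self._data[key] = value

    def get(self, key: Any, default: Any = None) -> Any:
        with self._lock:
            return self._data.get(key, default)

    def clear(self) -> None:
        with self._lock:
            self._data.clear()
```
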
--- a/snowflake/snowpark_connect/utils/pandas_udtf_utils.py
+++ b/snowflake/snowpark_connect/utils/pandas_udtf_utils.py
@@ -87,9 +87,93 @@ def get_map_in_arrow_udtf(
 def create_pandas_udtf(
     udtf_proto: CommonInlineUserDefinedFunction,
     spark_column_names: list[str],
-    input_schema: StructType
-    return_schema: StructType
+    input_schema: StructType,
+    return_schema: StructType,
+):
+    user_function, _ = cloudpickle.loads(udtf_proto.python_udf.command)
+    output_column_names = [field.name for field in return_schema.fields]
+    output_column_original_names = [
+        field.original_column_identifier for field in return_schema.fields
+    ]
+
+    class MapPandasUDTF:
+        def __init__(self) -> None:
+            self.user_function = user_function
+            self.output_column_names = output_column_names
+            self.spark_column_names = spark_column_names
+            self.output_column_original_names = output_column_original_names
+
+        def end_partition(self, df: pd.DataFrame):
+            if df.empty:
+                empty_df = pd.DataFrame(columns=self.output_column_names)
+                yield empty_df
+                return
+
+            df_without_dummy = df.drop(
+                columns=["_DUMMY_PARTITION_KEY"], errors="ignore"
+            )
+            df_without_dummy.columns = self.spark_column_names
+            result_iterator = self.user_function(
+                [pd.DataFrame([row]) for _, row in df_without_dummy.iterrows()]
+            )
+
+            if not isinstance(result_iterator, Iterator) and not hasattr(
+                result_iterator, "__iter__"
+            ):
+                raise RuntimeError(
+                    f"snowpark_connect::UDF_RETURN_TYPE Return type of the user-defined function should be "
+                    f"iterator of pandas.DataFrame, but is {type(result_iterator).__name__}"
+                )
+
+            output_df = pd.concat(result_iterator)
+            generated_output_column_names = list(output_df.columns)
+
+            missing_columns = []
+            for original_column in self.output_column_original_names:
+                if original_column not in generated_output_column_names:
+                    missing_columns.append(original_column)
+
+            if missing_columns:
+                unexpected_columns = [
+                    column
+                    for column in generated_output_column_names
+                    if column not in self.output_column_original_names
+                ]
+                raise RuntimeError(
+                    f"[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF] Column names of the returned pandas.DataFrame do not match specified schema. Missing: {', '.join(sorted(missing_columns))}. Unexpected: {', '.join(sorted(unexpected_columns))}"
+                    "."
+                )
+            reordered_df = output_df[self.output_column_original_names]
+            reordered_df.columns = self.output_column_names
+            yield reordered_df
+
+    return snowpark_fn.pandas_udtf(
+        MapPandasUDTF,
+        output_schema=PandasDataFrameType(
+            [field.datatype for field in return_schema.fields],
+            [field.name for field in return_schema.fields],
+        ),
+        input_types=[
+            PandasDataFrameType(
+                [field.datatype for field in input_schema.fields] + [IntegerType()]
+            )
+        ],
+        input_names=[field.name for field in input_schema.fields]
+        + ["_DUMMY_PARTITION_KEY"],
+        name="map_pandas_udtf",
+        replace=True,
+        packages=["pandas"],
+        is_permanent=False,
+    )
+
+
+def create_pandas_udtf_with_arrow(
+    udtf_proto: CommonInlineUserDefinedFunction,
+    spark_column_names: list[str],
+    input_schema: StructType,
+    return_schema: StructType,
 ) -> str | snowpark.udtf.UserDefinedTableFunction:
+
     user_function, _ = cloudpickle.loads(udtf_proto.python_udf.command)
     output_column_names = [field.name for field in return_schema.fields]
|