snowpark-connect 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.

This version of snowpark-connect might be problematic.

Files changed (32)
  1. snowflake/snowpark_connect/config.py +10 -0
  2. snowflake/snowpark_connect/dataframe_container.py +16 -0
  3. snowflake/snowpark_connect/expression/map_udf.py +68 -27
  4. snowflake/snowpark_connect/expression/map_unresolved_function.py +22 -21
  5. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  6. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  7. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  8. snowflake/snowpark_connect/relation/map_map_partitions.py +9 -4
  9. snowflake/snowpark_connect/relation/map_relation.py +12 -1
  10. snowflake/snowpark_connect/relation/map_row_ops.py +8 -1
  11. snowflake/snowpark_connect/relation/map_udtf.py +96 -44
  12. snowflake/snowpark_connect/relation/utils.py +44 -0
  13. snowflake/snowpark_connect/relation/write/map_write.py +113 -22
  14. snowflake/snowpark_connect/resources_initializer.py +18 -5
  15. snowflake/snowpark_connect/server.py +8 -1
  16. snowflake/snowpark_connect/utils/concurrent.py +4 -0
  17. snowflake/snowpark_connect/utils/external_udxf_cache.py +36 -0
  18. snowflake/snowpark_connect/utils/scala_udf_utils.py +250 -242
  19. snowflake/snowpark_connect/utils/session.py +4 -0
  20. snowflake/snowpark_connect/utils/udf_utils.py +7 -17
  21. snowflake/snowpark_connect/utils/udtf_utils.py +3 -16
  22. snowflake/snowpark_connect/version.py +1 -1
  23. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/METADATA +1 -1
  24. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/RECORD +32 -28
  25. {snowpark_connect-0.25.0.data → snowpark_connect-0.26.0.data}/scripts/snowpark-connect +0 -0
  26. {snowpark_connect-0.25.0.data → snowpark_connect-0.26.0.data}/scripts/snowpark-session +0 -0
  27. {snowpark_connect-0.25.0.data → snowpark_connect-0.26.0.data}/scripts/snowpark-submit +0 -0
  28. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/WHEEL +0 -0
  29. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/licenses/LICENSE-binary +0 -0
  30. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/licenses/LICENSE.txt +0 -0
  31. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/licenses/NOTICE-binary +0 -0
  32. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/write/map_write.py

@@ -16,7 +16,7 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     unquote_if_quoted,
 )
 from snowflake.snowpark.exceptions import SnowparkSQLException
-from snowflake.snowpark.functions import col, lit, object_construct
+from snowflake.snowpark.functions import col, lit, object_construct, sql_expr
 from snowflake.snowpark.types import (
     ArrayType,
     DataType,
@@ -40,7 +40,10 @@ from snowflake.snowpark_connect.relation.io_utils import (
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.relation.read.reader_config import CsvWriterConfig
 from snowflake.snowpark_connect.relation.stage_locator import get_paths_from_stage
-from snowflake.snowpark_connect.relation.utils import random_string
+from snowflake.snowpark_connect.relation.utils import (
+    generate_spark_compatible_filename,
+    random_string,
+)
 from snowflake.snowpark_connect.type_mapping import snowpark_to_iceberg_type
 from snowflake.snowpark_connect.utils.context import get_session_id
 from snowflake.snowpark_connect.utils.identifiers import (
@@ -133,45 +136,80 @@ def map_write(request: proto_base.ExecutePlanRequest):
        write_op.source = ""

    should_write_to_single_file = str_to_bool(write_op.options.get("single", "false"))
-    if should_write_to_single_file:
-        # providing default size as 1GB
-        max_file_size = int(
-            write_op.options.get("snowflake_max_file_size", "1073741824")
-        )
+
+    # Support Snowflake-specific snowflake_max_file_size option. This is NOT a spark option.
+    max_file_size = None
+    if (
+        "snowflake_max_file_size" in write_op.options
+        and int(write_op.options["snowflake_max_file_size"]) > 0
+    ):
+        max_file_size = int(write_op.options["snowflake_max_file_size"])
+    elif should_write_to_single_file:
+        # providing default size as 1GB for single file write
+        max_file_size = 1073741824

    match write_op.source:
        case "csv" | "parquet" | "json" | "text":
            write_path = get_paths_from_stage(
                [write_op.path],
                session=session,
            )[0]
-            # we need a random prefix to support "append" mode
-            # otherwise copy into with overwrite=False will fail if the file already exists
+            # Generate Spark-compatible filename with proper extension
+            extension = write_op.source if write_op.source != "text" else "txt"
+
+            # Get compression from options for proper filename generation
+            compression_option = write_op.options.get("compression", "none")
+
+            # Generate Spark-compatible filename or prefix
            if should_write_to_single_file:
-                extention = write_op.source if write_op.source != "text" else "txt"
-                temp_file_prefix_on_stage = (
-                    f"{write_path}/{random_string(10, 'sas_file_')}.{extention}"
+                # Single file: generate complete filename with extension
+                spark_filename = generate_spark_compatible_filename(
+                    task_id=0,
+                    attempt_number=0,
+                    compression=compression_option,
+                    format_ext=extension,
                )
+                temp_file_prefix_on_stage = f"{write_path}/{spark_filename}"
            else:
-                temp_file_prefix_on_stage = (
-                    f"{write_path}/{random_string(10, 'sas_file_')}"
+                # Multiple files: generate prefix without extension (Snowflake will add extensions)
+                spark_filename_prefix = generate_spark_compatible_filename(
+                    task_id=0,
+                    attempt_number=0,
+                    compression=compression_option,
+                    format_ext="",  # No extension for prefix
                )
+                temp_file_prefix_on_stage = f"{write_path}/{spark_filename_prefix}"
            overwrite = (
                write_op.mode
                == commands_proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE
            )
+
+            default_compression = "NONE" if write_op.source != "parquet" else "snappy"
+            compression = write_op.options.get(
+                "compression", default_compression
+            ).upper()
            parameters = {
                "location": temp_file_prefix_on_stage,
                "file_format_type": write_op.source
                if write_op.source != "text"
                else "csv",
                "format_type_options": {
-                    "COMPRESSION": "NONE",
+                    "COMPRESSION": compression,
                },
                "overwrite": overwrite,
            }
-            if should_write_to_single_file:
-                parameters["single"] = True
+            # By default, download from the same prefix we wrote to.
+            download_stage_path = temp_file_prefix_on_stage
+
+            # Check for partition hint early to determine precedence over single option
+            partition_hint = result.partition_hint
+
+            # Apply max_file_size for both single and multi-file scenarios
+            # This helps control when Snowflake splits files into multiple parts
+            if max_file_size:
                parameters["max_file_size"] = max_file_size
+            # Only apply single option if no partition hint is present (partition hint takes precedence)
+            if should_write_to_single_file and partition_hint is None:
+                parameters["single"] = True
            rewritten_df: snowpark.DataFrame = rewrite_df(input_df, write_op.source)
            get_param_from_options(parameters, write_op.options, write_op.source)
            if write_op.partitioning_columns:
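
For context, here is a minimal client-side sketch of how these write options might be exercised. The option names ("single", "snowflake_max_file_size", "compression") are the ones read from write_op.options above; the connection URL and output path are hypothetical.

    # Hypothetical Spark Connect client usage of the options handled above.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    df = spark.range(1_000_000)

    (
        df.write
        .option("single", "true")                        # request one output file
        .option("snowflake_max_file_size", "268435456")  # Snowflake-specific, not a Spark option
        .option("compression", "gzip")
        .mode("overwrite")
        .csv("s3://example-bucket/exports/range_data")   # hypothetical output path
    )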
@@ -186,10 +224,50 @@ def map_write(request: proto_base.ExecutePlanRequest):
                    )
                else:
                    parameters["partition_by"] = partitioning_columns[0]
-            rewritten_df.write.copy_into_location(**parameters)
+
+            # If a partition hint is present (from DataFrame.repartition(n)), optionally split the
+            # write into n COPY INTO calls by assigning a synthetic partition id. Controlled by config.
+            # Note: This affects only the number of output files, not computation semantics.
+            # Partition hints take precedence over single option (matches Spark behavior) when enabled.
+            repartition_for_writes_enabled = (
+                global_config.snowflake_repartition_for_writes
+            )
+            if repartition_for_writes_enabled and partition_hint and partition_hint > 0:
+                # Create a stable synthetic file number per row using ROW_NUMBER() over a
+                # randomized order, then modulo partition_hint. We rely on sql_expr to avoid
+                # adding new helpers.
+                file_num_col = "_sas_file_num"
+                partitioned_df = rewritten_df.withColumn(
+                    file_num_col,
+                    sql_expr(
+                        f"(ROW_NUMBER() OVER (ORDER BY RANDOM())) % {partition_hint}"
+                    ),
+                )
+
+                # Execute multiple COPY INTO operations, one per target file.
+                # Since we write per-partition with distinct prefixes, download from the base write path.
+                download_stage_path = write_path
+                for part_idx in range(partition_hint):
+                    part_params = dict(parameters)
+                    # Preserve Spark-like filename prefix per partition so downloaded basenames
+                    # match the expected Spark pattern (with possible Snowflake counters appended).
+                    per_part_prefix = generate_spark_compatible_filename(
+                        task_id=part_idx,
+                        attempt_number=0,
+                        compression=compression_option,
+                        format_ext="",  # prefix only; Snowflake appends extension/counters
+                    )
+                    part_params["location"] = f"{write_path}/{per_part_prefix}"
+                    (
+                        partitioned_df.filter(col(file_num_col) == lit(part_idx))
+                        .drop(file_num_col)
+                        .write.copy_into_location(**part_params)
+                    )
+
+            else:
+                rewritten_df.write.copy_into_location(**parameters)
            if not is_cloud_path(write_op.path):
                store_files_locally(
-                    temp_file_prefix_on_stage,
+                    download_stage_path,
                    write_op.path,
                    overwrite,
                    session,
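
A hedged sketch of what triggers the partition-hint branch from the client side, reusing the spark session from the previous sketch. Whether the fan-out actually happens depends on the server-side snowflake_repartition_for_writes config shown above; the paths are made up.

    # Hypothetical client code: repartition(n) attaches the partition hint that the
    # branch above turns into n separate COPY INTO calls (when the config is enabled).
    df = spark.read.parquet("s3://example-bucket/raw/events")  # hypothetical input

    (
        df.repartition(4)          # partition hint = 4 -> roughly 4 output files
        .write.mode("overwrite")
        .parquet("s3://example-bucket/exports/events")
    )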
@@ -569,7 +647,12 @@ def _validate_schema_and_get_writer(
        col_name = field.name
        renamed = col_name
        matching_field = next(
-            (f for f in table_schema.fields if f.name.lower() == col_name.lower()),
+            (
+                f
+                for f in table_schema.fields
+                if unquote_if_quoted(f.name).lower()
+                == unquote_if_quoted(col_name).lower()
+            ),
            None,
        )
        if matching_field is not None and matching_field != col_name:
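
To illustrate the quote-insensitive match introduced above (illustrative column names only):

    # unquote_if_quoted strips the surrounding double quotes from a quoted Snowflake
    # identifier, so a quoted table column and an unquoted DataFrame column with the
    # same name now match case-insensitively.
    from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted

    table_col = '"order_id"'  # quoted identifier from the table schema
    data_col = "ORDER_ID"     # unquoted column from the incoming DataFrame

    assert unquote_if_quoted(table_col).lower() == unquote_if_quoted(data_col).lower()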
@@ -591,7 +674,10 @@ def _validate_schema_and_get_writer(


 def _validate_schema_for_append(
-    table_schema: DataType, data_schema: DataType, snowpark_table_name: str
+    table_schema: DataType,
+    data_schema: DataType,
+    snowpark_table_name: str,
+    compare_structs: bool = False,
 ):
     match (table_schema, data_schema):
         case (_, _) if table_schema == data_schema:
@@ -600,7 +686,11 @@ def _validate_schema_for_append(
        case (StructType() as table_struct, StructType() as data_struct):

            def _comparable_col_name(col: str) -> str:
-                return col if global_config.spark_sql_caseSensitive else col.lower()
+                name = col if global_config.spark_sql_caseSensitive else col.lower()
+                if compare_structs:
+                    return name
+                else:
+                    return unquote_if_quoted(name)

            def invalid_struct_schema():
                raise AnalysisException(
@@ -640,6 +730,7 @@ def _validate_schema_for_append(
                    matching_table_field.datatype,
                    data_field.datatype,
                    snowpark_table_name,
+                    compare_structs=True,
                )

            return
snowflake/snowpark_connect/resources_initializer.py

@@ -9,6 +9,7 @@ from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_sess
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger

 _resources_initialized = threading.Event()
+_initializer_lock = threading.Lock()
 SPARK_VERSION = "3.5.6"
 RESOURCE_PATH = "/snowflake/snowpark_connect/resources"

@@ -57,6 +58,9 @@ def initialize_resources() -> None:
        f"spark-connect-client-jvm_2.12-{SPARK_VERSION}.jar",
        f"spark-common-utils_2.12-{SPARK_VERSION}.jar",
        "json4s-ast_2.12-3.7.0-M11.jar",
+        "json4s-native_2.12-3.7.0-M11.jar",
+        "json4s-core_2.12-3.7.0-M11.jar",
+        "paranamer-2.8.3.jar",
    ]

    for jar in jar_files:
@@ -94,10 +98,19 @@ def initialize_resources() -> None:
    logger.info(f"All resources initialized in {time.time() - start_time:.2f}s")


+_resource_initializer = threading.Thread(
+    target=initialize_resources, name="ResourceInitializer"
+)
+
+
 def initialize_resources_async() -> threading.Thread:
     """Start resource initialization in background."""
-    thread = threading.Thread(
-        target=initialize_resources, name="ResourceInitializer", daemon=True
-    )
-    thread.start()
-    return thread
+    with _initializer_lock:
+        if not _resource_initializer.is_alive() and _resource_initializer.ident is None:
+            _resource_initializer.start()
+    return _resource_initializer
+
+
+def wait_for_resource_initialization() -> None:
+    with _initializer_lock:
+        _resource_initializer.join()
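
A sketch of the call pattern these new helpers appear to support; the import path matches the file listed above, but where exactly the server calls them is not shown in this diff.

    # The module-level thread plus the lock make initialize_resources_async() safe
    # to call more than once, and wait_for_resource_initialization() blocks until
    # the jar staging finishes.
    from snowflake.snowpark_connect import resources_initializer

    resources_initializer.initialize_resources_async()        # idempotent start
    # ... serve client requests while jars are staged in the background ...
    resources_initializer.wait_for_resource_initialization()  # join before the jars are needed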
snowflake/snowpark_connect/server.py

@@ -88,6 +88,9 @@ from snowflake.snowpark_connect.utils.context import (
     set_spark_version,
 )
 from snowflake.snowpark_connect.utils.env_utils import get_int_from_env
+from snowflake.snowpark_connect.utils.external_udxf_cache import (
+    clear_external_udxf_cache,
+)
 from snowflake.snowpark_connect.utils.interrupt import (
     interrupt_all_queries,
     interrupt_queries_with_tag,
@@ -436,7 +439,8 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
                    lambda: map_local_relation(relation),  # noqa: B023
                    materialize=True,
                )
-            except Exception:
+            except Exception as e:
+                logger.warning("Failed to put df into cache: %s", str(e))
                # fallback - treat as regular artifact
                _handle_regular_artifact()
            else:
@@ -556,6 +560,9 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
        if class_files:
            write_class_files_to_stage(session, class_files)

+        if any(not name.startswith("cache") for name in filenames.keys()):
+            clear_external_udxf_cache(session)
+
        return proto_base.AddArtifactsResponse(artifacts=list(response.values()))

    def ArtifactStatus(self, request, context):
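
An illustration of the invalidation rule in the hunk above, with made-up artifact names; the assumption here is that "cache"-prefixed artifacts are the internal DataFrame-cache blobs, while anything else may ship new UDF/UDTF code.

    # Any uploaded artifact whose name does not start with "cache" clears the
    # external UDXF caches so stale registrations are not reused.
    filenames = {"cache/df_0a1b2c": b"payload", "jars/custom_udfs.jar": b"payload"}
    if any(not name.startswith("cache") for name in filenames):
        pass  # clear_external_udxf_cache(session) would run here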
snowflake/snowpark_connect/utils/concurrent.py

@@ -64,6 +64,10 @@ class SynchronizedDict(Mapping[K, V]):
        with self._lock.reader():
            return iter(list(self._dict.items()))

+    def clear(self) -> None:
+        with self._lock.writer():
+            self._dict.clear()
+

 class ReadWriteLock:
     class _Reader:
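
A minimal usage sketch for the new clear() method, assuming the item-assignment and get() behavior that external_udxf_cache.py (below) relies on.

    from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict

    cache: SynchronizedDict = SynchronizedDict()
    cache["udf-hash-123"] = "registered-udf-handle"  # hypothetical entry
    assert cache.get("udf-hash-123") == "registered-udf-handle"
    cache.clear()                                    # takes the writer lock, drops all entries
    assert cache.get("udf-hash-123") is None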
snowflake/snowpark_connect/utils/external_udxf_cache.py (new file)

@@ -0,0 +1,36 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+from snowflake.snowpark import Session
+from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
+
+
+def init_external_udxf_cache(session: Session) -> None:
+    session.external_udfs_cache = SynchronizedDict()
+    session.external_udtfs_cache = SynchronizedDict()
+
+
+def clear_external_udxf_cache(session: Session) -> None:
+    session.external_udfs_cache.clear()
+    session.external_udtfs_cache.clear()
+
+
+def get_external_udf_from_cache(hash: str):
+    return Session.get_active_session().external_udfs_cache.get(hash)
+
+
+def cache_external_udf(hash: int, udf):
+    Session.get_active_session().external_udfs_cache[hash] = udf
+
+
+def clear_external_udtf_cache(session: Session) -> None:
+    session.external_udtfs_cache.clear()
+
+
+def get_external_udtf_from_cache(hash: int):
+    return Session.get_active_session().external_udtfs_cache.get(hash)
+
+
+def cache_external_udtf(hash: int, udf):
+    Session.get_active_session().external_udtfs_cache[hash] = udf
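
A short usage sketch for the new module. The session setup and cache key are assumptions; the real call sites live in udf_utils.py and udtf_utils.py, which this release also changes but which are not shown here.

    # Hypothetical round-trip through the new UDF cache; the hash key and the cached
    # value are placeholders for whatever udf_utils.py actually stores.
    from snowflake.snowpark import Session
    from snowflake.snowpark_connect.utils.external_udxf_cache import (
        cache_external_udf,
        get_external_udf_from_cache,
        init_external_udxf_cache,
    )

    session = Session.builder.getOrCreate()  # assumes connection parameters are configured
    init_external_udxf_cache(session)

    udf_hash = hash(("my_udf_name", "serialized-udf-bytes"))  # hypothetical cache key
    udf = get_external_udf_from_cache(udf_hash)
    if udf is None:
        udf = "newly-registered-udf-handle"  # placeholder for the real registration result
        cache_external_udf(udf_hash, udf)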