snowpark-connect 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. snowflake/snowpark_connect/config.py +10 -3
  2. snowflake/snowpark_connect/dataframe_container.py +16 -0
  3. snowflake/snowpark_connect/expression/map_expression.py +15 -0
  4. snowflake/snowpark_connect/expression/map_udf.py +68 -27
  5. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +18 -0
  6. snowflake/snowpark_connect/expression/map_unresolved_function.py +38 -28
  7. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  8. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  9. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  10. snowflake/snowpark_connect/relation/map_extension.py +9 -7
  11. snowflake/snowpark_connect/relation/map_map_partitions.py +36 -72
  12. snowflake/snowpark_connect/relation/map_relation.py +15 -2
  13. snowflake/snowpark_connect/relation/map_row_ops.py +8 -1
  14. snowflake/snowpark_connect/relation/map_show_string.py +2 -0
  15. snowflake/snowpark_connect/relation/map_sql.py +63 -2
  16. snowflake/snowpark_connect/relation/map_udtf.py +96 -44
  17. snowflake/snowpark_connect/relation/utils.py +44 -0
  18. snowflake/snowpark_connect/relation/write/map_write.py +135 -24
  19. snowflake/snowpark_connect/resources_initializer.py +18 -5
  20. snowflake/snowpark_connect/server.py +12 -2
  21. snowflake/snowpark_connect/utils/artifacts.py +4 -5
  22. snowflake/snowpark_connect/utils/concurrent.py +4 -0
  23. snowflake/snowpark_connect/utils/context.py +41 -1
  24. snowflake/snowpark_connect/utils/external_udxf_cache.py +36 -0
  25. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +86 -2
  26. snowflake/snowpark_connect/utils/scala_udf_utils.py +250 -242
  27. snowflake/snowpark_connect/utils/session.py +4 -0
  28. snowflake/snowpark_connect/utils/udf_utils.py +71 -118
  29. snowflake/snowpark_connect/utils/udtf_helper.py +17 -7
  30. snowflake/snowpark_connect/utils/udtf_utils.py +3 -16
  31. snowflake/snowpark_connect/version.py +2 -3
  32. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/METADATA +2 -2
  33. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/RECORD +41 -37
  34. {snowpark_connect-0.25.0.data → snowpark_connect-0.27.0.data}/scripts/snowpark-connect +0 -0
  35. {snowpark_connect-0.25.0.data → snowpark_connect-0.27.0.data}/scripts/snowpark-session +0 -0
  36. {snowpark_connect-0.25.0.data → snowpark_connect-0.27.0.data}/scripts/snowpark-submit +0 -0
  37. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/WHEEL +0 -0
  38. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/licenses/LICENSE-binary +0 -0
  39. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/licenses/LICENSE.txt +0 -0
  40. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/licenses/NOTICE-binary +0 -0
  41. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/write/map_write.py

@@ -16,7 +16,7 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     unquote_if_quoted,
 )
 from snowflake.snowpark.exceptions import SnowparkSQLException
-from snowflake.snowpark.functions import col, lit, object_construct
+from snowflake.snowpark.functions import col, lit, object_construct, sql_expr
 from snowflake.snowpark.types import (
     ArrayType,
     DataType,
@@ -40,7 +40,10 @@ from snowflake.snowpark_connect.relation.io_utils import (
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.relation.read.reader_config import CsvWriterConfig
 from snowflake.snowpark_connect.relation.stage_locator import get_paths_from_stage
-from snowflake.snowpark_connect.relation.utils import random_string
+from snowflake.snowpark_connect.relation.utils import (
+    generate_spark_compatible_filename,
+    random_string,
+)
 from snowflake.snowpark_connect.type_mapping import snowpark_to_iceberg_type
 from snowflake.snowpark_connect.utils.context import get_session_id
 from snowflake.snowpark_connect.utils.identifiers import (
@@ -48,6 +51,7 @@ from snowflake.snowpark_connect.utils.identifiers import (
     split_fully_qualified_spark_name,
 )
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
+from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
     telemetry,
@@ -133,45 +137,99 @@ def map_write(request: proto_base.ExecutePlanRequest):
         write_op.source = ""

     should_write_to_single_file = str_to_bool(write_op.options.get("single", "false"))
-    if should_write_to_single_file:
-        # providing default size as 1GB
-        max_file_size = int(
-            write_op.options.get("snowflake_max_file_size", "1073741824")
-        )
+
+    # Support Snowflake-specific snowflake_max_file_size option. This is NOT a spark option.
+    max_file_size = None
+    if (
+        "snowflake_max_file_size" in write_op.options
+        and int(write_op.options["snowflake_max_file_size"]) > 0
+    ):
+        max_file_size = int(write_op.options["snowflake_max_file_size"])
+    elif should_write_to_single_file:
+        # providing default size as 1GB for single file write
+        max_file_size = 1073741824
     match write_op.source:
         case "csv" | "parquet" | "json" | "text":
             write_path = get_paths_from_stage(
                 [write_op.path],
                 session=session,
             )[0]
+            # Generate Spark-compatible filename with proper extension
+            extension = write_op.source if write_op.source != "text" else "txt"
+
+            # Get compression from options for proper filename generation
+            compression_option = write_op.options.get("compression", "none")
+
+            # Generate Spark-compatible filename or prefix
            # we need a random prefix to support "append" mode
            # otherwise copy into with overwrite=False will fail if the file already exists
-            if should_write_to_single_file:
-                extention = write_op.source if write_op.source != "text" else "txt"
-                temp_file_prefix_on_stage = (
-                    f"{write_path}/{random_string(10, 'sas_file_')}.{extention}"
-                )
-            else:
-                temp_file_prefix_on_stage = (
-                    f"{write_path}/{random_string(10, 'sas_file_')}"
-                )
             overwrite = (
                 write_op.mode
                 == commands_proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE
             )
+
+            if overwrite:
+                try:
+                    path_after_stage = (
+                        write_path.split("/", 1)[1] if "/" in write_path else ""
+                    )
+                    if not path_after_stage or path_after_stage == "/":
+                        logger.warning(
+                            f"Skipping REMOVE for root path {write_path} - too broad scope"
+                        )
+                    else:
+                        remove_command = f"REMOVE {write_path}/"
+                        session.sql(remove_command).collect()
+                        logger.info(f"Successfully cleared directory: {write_path}")
+                except Exception as e:
+                    logger.warning(f"Could not clear directory {write_path}: {e}")
+
+            if should_write_to_single_file:
+                # Single file: generate complete filename with extension
+                spark_filename = generate_spark_compatible_filename(
+                    task_id=0,
+                    attempt_number=0,
+                    compression=compression_option,
+                    format_ext=extension,
+                )
+                temp_file_prefix_on_stage = f"{write_path}/{spark_filename}"
+            else:
+                # Multiple files: generate prefix without extension (Snowflake will add extensions)
+                spark_filename_prefix = generate_spark_compatible_filename(
+                    task_id=0,
+                    attempt_number=0,
+                    compression=compression_option,
+                    format_ext="",  # No extension for prefix
+                )
+                temp_file_prefix_on_stage = f"{write_path}/{spark_filename_prefix}"
+
+            default_compression = "NONE" if write_op.source != "parquet" else "snappy"
+            compression = write_op.options.get(
+                "compression", default_compression
+            ).upper()
             parameters = {
                 "location": temp_file_prefix_on_stage,
                 "file_format_type": write_op.source
                 if write_op.source != "text"
                 else "csv",
                 "format_type_options": {
-                    "COMPRESSION": "NONE",
+                    "COMPRESSION": compression,
                 },
                 "overwrite": overwrite,
             }
-            if should_write_to_single_file:
-                parameters["single"] = True
+            # By default, download from the same prefix we wrote to.
+            download_stage_path = temp_file_prefix_on_stage
+
+            # Check for partition hint early to determine precedence over single option
+            partition_hint = result.partition_hint
+
+            # Apply max_file_size for both single and multi-file scenarios
+            # This helps control when Snowflake splits files into multiple parts
+            if max_file_size:
                 parameters["max_file_size"] = max_file_size
+            # Only apply single option if no partition hint is present (partition hint takes precedence)
+            if should_write_to_single_file and partition_hint is None:
+                parameters["single"] = True
             rewritten_df: snowpark.DataFrame = rewrite_df(input_df, write_op.source)
             get_param_from_options(parameters, write_op.options, write_op.source)
             if write_op.partitioning_columns:
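The filename helper used above lives in relation/utils.py (listed as +44 lines but not shown in this diff). A minimal sketch of what a Spark-style part-file name generator could look like, assuming the usual part-<task>-<uuid>-c000 convention; the name and keyword arguments are taken from the calls in this hunk, the body is illustrative only:

    import uuid

    # Illustrative sketch; the real relation/utils.py implementation is not shown in this diff.
    _COMPRESSION_SUFFIX = {"none": "", "gzip": ".gz", "snappy": ".snappy"}

    def generate_spark_compatible_filename(
        task_id: int, attempt_number: int, compression: str, format_ext: str
    ) -> str:
        # Spark names output files roughly like: part-00000-<uuid>-c000.csv.gz
        base = f"part-{task_id:05d}-{uuid.uuid4()}-c{attempt_number:03d}"
        suffix = _COMPRESSION_SUFFIX.get(compression.lower(), "")
        return f"{base}.{format_ext}{suffix}" if format_ext else base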
@@ -186,10 +244,50 @@ def map_write(request: proto_base.ExecutePlanRequest):
                     )
                 else:
                     parameters["partition_by"] = partitioning_columns[0]
-            rewritten_df.write.copy_into_location(**parameters)
+
+            # If a partition hint is present (from DataFrame.repartition(n)), optionally split the
+            # write into n COPY INTO calls by assigning a synthetic partition id. Controlled by config.
+            # Note: This affects only the number of output files, not computation semantics.
+            # Partition hints take precedence over single option (matches Spark behavior) when enabled.
+            repartition_for_writes_enabled = (
+                global_config.snowflake_repartition_for_writes
+            )
+            if repartition_for_writes_enabled and partition_hint and partition_hint > 0:
+                # Create a stable synthetic file number per row using ROW_NUMBER() over a
+                # randomized order, then modulo partition_hint. We rely on sql_expr to avoid
+                # adding new helpers.
+                file_num_col = "_sas_file_num"
+                partitioned_df = rewritten_df.withColumn(
+                    file_num_col,
+                    sql_expr(
+                        f"(ROW_NUMBER() OVER (ORDER BY RANDOM())) % {partition_hint}"
+                    ),
+                )
+
+                # Execute multiple COPY INTO operations, one per target file.
+                # Since we write per-partition with distinct prefixes, download from the base write path.
+                download_stage_path = write_path
+                for part_idx in range(partition_hint):
+                    part_params = dict(parameters)
+                    # Preserve Spark-like filename prefix per partition so downloaded basenames
+                    # match the expected Spark pattern (with possible Snowflake counters appended).
+                    per_part_prefix = generate_spark_compatible_filename(
+                        task_id=part_idx,
+                        attempt_number=0,
+                        compression=compression_option,
+                        format_ext="",  # prefix only; Snowflake appends extension/counters
+                    )
+                    part_params["location"] = f"{write_path}/{per_part_prefix}"
+                    (
+                        partitioned_df.filter(col(file_num_col) == lit(part_idx))
+                        .drop(file_num_col)
+                        .write.copy_into_location(**part_params)
+                    )
+            else:
+                rewritten_df.write.copy_into_location(**parameters)
             if not is_cloud_path(write_op.path):
                 store_files_locally(
-                    temp_file_prefix_on_stage,
+                    download_stage_path,
                     write_op.path,
                     overwrite,
                     session,
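For context, this branch is reached through ordinary Spark Connect write calls. A client-side example that would exercise the single-file path and the new repartition-based splitting (paths and option values are illustrative; `spark` is assumed to be a Spark Connect session backed by snowpark-connect):

    df = spark.range(1_000_000)

    # Single output file: the "single" option maps to COPY INTO ... with single=True above.
    df.write.mode("overwrite").option("single", "true").option("compression", "gzip").csv("/tmp/one_file")

    # repartition(4) sets a partition hint; when snowflake_repartition_for_writes is enabled,
    # the write becomes four COPY INTO calls, one per synthetic partition.
    df.repartition(4).write.mode("overwrite").parquet("/tmp/four_files")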
@@ -569,7 +667,12 @@ def _validate_schema_and_get_writer(
         col_name = field.name
         renamed = col_name
         matching_field = next(
-            (f for f in table_schema.fields if f.name.lower() == col_name.lower()),
+            (
+                f
+                for f in table_schema.fields
+                if unquote_if_quoted(f.name).lower()
+                == unquote_if_quoted(col_name).lower()
+            ),
             None,
         )
         if matching_field is not None and matching_field != col_name:
@@ -591,7 +694,10 @@ def _validate_schema_and_get_writer(


 def _validate_schema_for_append(
-    table_schema: DataType, data_schema: DataType, snowpark_table_name: str
+    table_schema: DataType,
+    data_schema: DataType,
+    snowpark_table_name: str,
+    compare_structs: bool = False,
 ):
     match (table_schema, data_schema):
         case (_, _) if table_schema == data_schema:
@@ -600,7 +706,11 @@ def _validate_schema_for_append(
         case (StructType() as table_struct, StructType() as data_struct):

             def _comparable_col_name(col: str) -> str:
-                return col if global_config.spark_sql_caseSensitive else col.lower()
+                name = col if global_config.spark_sql_caseSensitive else col.lower()
+                if compare_structs:
+                    return name
+                else:
+                    return unquote_if_quoted(name)

             def invalid_struct_schema():
                 raise AnalysisException(
@@ -640,6 +750,7 @@ def _validate_schema_for_append(
                     matching_table_field.datatype,
                     data_field.datatype,
                     snowpark_table_name,
+                    compare_structs=True,
                 )

             return
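The switch to unquote_if_quoted matters because Snowflake identifiers created with quotes otherwise never compare equal to their unquoted, case-normalized forms. A small illustration of the comparison the new generator expression performs, using a simplified stand-in for snowpark's unquote_if_quoted (assumed here to strip surrounding double quotes):

    def _unquote_if_quoted(name: str) -> str:
        # Simplified stand-in: drop surrounding double quotes and unescape doubled quotes.
        if name.startswith('"') and name.endswith('"'):
            return name[1:-1].replace('""', '"')
        return name

    table_field, data_field = '"order_id"', "ORDER_ID"
    # Old comparison: '"order_id"'.lower() == 'order_id' -> False, so no matching field was found.
    # New comparison: strip quotes first, then compare case-insensitively -> True.
    assert _unquote_if_quoted(table_field).lower() == _unquote_if_quoted(data_field).lower()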
snowflake/snowpark_connect/resources_initializer.py

@@ -9,6 +9,7 @@ from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_sess
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger

 _resources_initialized = threading.Event()
+_initializer_lock = threading.Lock()
 SPARK_VERSION = "3.5.6"
 RESOURCE_PATH = "/snowflake/snowpark_connect/resources"

@@ -57,6 +58,9 @@ def initialize_resources() -> None:
         f"spark-connect-client-jvm_2.12-{SPARK_VERSION}.jar",
         f"spark-common-utils_2.12-{SPARK_VERSION}.jar",
         "json4s-ast_2.12-3.7.0-M11.jar",
+        "json4s-native_2.12-3.7.0-M11.jar",
+        "json4s-core_2.12-3.7.0-M11.jar",
+        "paranamer-2.8.3.jar",
     ]

     for jar in jar_files:
@@ -94,10 +98,19 @@ def initialize_resources() -> None:
     logger.info(f"All resources initialized in {time.time() - start_time:.2f}s")


+_resource_initializer = threading.Thread(
+    target=initialize_resources, name="ResourceInitializer"
+)
+
+
 def initialize_resources_async() -> threading.Thread:
     """Start resource initialization in background."""
-    thread = threading.Thread(
-        target=initialize_resources, name="ResourceInitializer", daemon=True
-    )
-    thread.start()
-    return thread
+    with _initializer_lock:
+        if not _resource_initializer.is_alive() and _resource_initializer.ident is None:
+            _resource_initializer.start()
+        return _resource_initializer
+
+
+def wait_for_resource_initialization() -> None:
+    with _initializer_lock:
+        _resource_initializer.join()
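With the module-level thread and lock, initialization becomes idempotent: repeated calls return the same thread, and callers that need the staged jars can block on it. A usage sketch built only from the functions added above (the call sites are assumptions):

    # During server startup: kick off jar staging without blocking request handling.
    init_thread = initialize_resources_async()   # starts the ResourceInitializer thread at most once

    # ... serve requests that do not need the JVM resources ...

    # Before anything that requires the staged jars (e.g. a Scala UDF registration):
    wait_for_resource_initialization()           # joins the shared thread under the lock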
snowflake/snowpark_connect/server.py

@@ -88,6 +88,9 @@ from snowflake.snowpark_connect.utils.context import (
     set_spark_version,
 )
 from snowflake.snowpark_connect.utils.env_utils import get_int_from_env
+from snowflake.snowpark_connect.utils.external_udxf_cache import (
+    clear_external_udxf_cache,
+)
 from snowflake.snowpark_connect.utils.interrupt import (
     interrupt_all_queries,
     interrupt_queries_with_tag,
@@ -436,7 +439,8 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
                     lambda: map_local_relation(relation),  # noqa: B023
                     materialize=True,
                 )
-            except Exception:
+            except Exception as e:
+                logger.warning("Failed to put df into cache: %s", str(e))
                 # fallback - treat as regular artifact
                 _handle_regular_artifact()
             else:
@@ -527,7 +531,10 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
                 if name.endswith(".class"):
                     # name is <dir>/<package>/<class_name>
                     # we don't need the dir name, but require the package, so only remove dir
-                    class_files[name.split("/", 1)[-1]] = filepath
+                    if os.name != "nt":
+                        class_files[name.split("/", 1)[-1]] = filepath
+                    else:
+                        class_files[name.split("\\", 1)[-1]] = filepath
                     continue
                 session.file.put(
                     filepath,
@@ -556,6 +563,9 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
         if class_files:
             write_class_files_to_stage(session, class_files)

+        if any(not name.startswith("cache") for name in filenames.keys()):
+            clear_external_udxf_cache(session)
+
         return proto_base.AddArtifactsResponse(artifacts=list(response.values()))

     def ArtifactStatus(self, request, context):
snowflake/snowpark_connect/utils/artifacts.py

@@ -39,7 +39,7 @@ def write_temporary_artifact(
     if os.name != "nt":
         filepath = f"/tmp/sas-{session.session_id}/{name}"
     else:
-        filepath = f"{tempfile.gettempdir()}/sas-{session.session_id}/{name}"
+        filepath = f"{tempfile.gettempdir()}\\sas-{session.session_id}\\{name}"
     # The name comes to us as a path (e.g. cache/<name>), so we need to create
     # the parent directory if it doesn't exist to avoid errors during writing.
     pathlib.Path(filepath).parent.mkdir(parents=True, exist_ok=True)
@@ -55,11 +55,10 @@
 ) -> None:
     if os.name != "nt":
         filepath = f"/tmp/sas-{session.session_id}"
+        jar_name = f'{filepath}/{hashlib.sha256(str(files).encode("utf-8")).hexdigest()[:10]}.jar'
     else:
-        filepath = f"{tempfile.gettempdir()}/sas-{session.session_id}"
-    jar_name = (
-        f'{filepath}/{hashlib.sha256(str(files).encode("utf-8")).hexdigest()[:10]}.jar'
-    )
+        filepath = f"{tempfile.gettempdir()}\\sas-{session.session_id}"
+        jar_name = f'{filepath}\\{hashlib.sha256(str(files).encode("utf-8")).hexdigest()[:10]}.jar'
     with zipfile.ZipFile(jar_name, "w", zipfile.ZIP_DEFLATED) as jar:
         for name, path in files.items():
             jar.write(path, name)
snowflake/snowpark_connect/utils/concurrent.py

@@ -64,6 +64,10 @@ class SynchronizedDict(Mapping[K, V]):
         with self._lock.reader():
             return iter(list(self._dict.items()))

+    def clear(self) -> None:
+        with self._lock.writer():
+            self._dict.clear()
+

 class ReadWriteLock:
     class _Reader:
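The new clear() takes the writer side of the ReadWriteLock, so it is safe against concurrent readers iterating over a snapshot. A brief usage sketch; SynchronizedDict is constructed with no arguments as in external_udxf_cache.py below, and item assignment is used the same way cache_external_udf uses it:

    cache = SynchronizedDict()
    cache["udf_hash"] = "registered"     # item assignment, as done by cache_external_udf()
    for key, value in cache:             # __iter__ yields (key, value) pairs from a snapshot taken under the reader lock
        print(key, value)
    cache.clear()                        # new in 0.27.0: empties the dict under the writer lock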
snowflake/snowpark_connect/utils/context.py

@@ -30,6 +30,9 @@ _sql_aggregate_function_count = ContextVar[int](
     "_contains_aggregate_function", default=0
 )

+# Context for parsing map_partitions
+_map_partitions_stack = ContextVar[int]("_map_partitions_stack", default=0)
+
 # We have to generate our own plan IDs that are different from Spark's.
 # Spark plan IDs start at 0, so pick a "big enough" number to avoid overlaps.
 _STARTING_SQL_PLAN_ID = 0x80000000
@@ -49,6 +52,7 @@ _spark_client_type_regex = re.compile(r"spark/(?P<spark_version>\d+\.\d+\.\d+)")
 _current_operation = ContextVar[str]("_current_operation", default="default")
 _resolving_fun_args = ContextVar[bool]("_resolving_fun_args", default=False)
 _resolving_lambda_fun = ContextVar[bool]("_resolving_lambdas", default=False)
+_current_lambda_params = ContextVar[list[str]]("_current_lambda_params", default=[])

 _is_window_enabled = ContextVar[bool]("_is_window_enabled", default=False)
 _is_in_pivot = ContextVar[bool]("_is_in_pivot", default=False)
@@ -206,6 +210,16 @@ def push_evaluating_join_condition(join_type, left_keys, right_keys):
         _is_evaluating_join_condition.set(prev)


+@contextmanager
+def push_map_partitions():
+    _map_partitions_stack.set(_map_partitions_stack.get() + 1)
+    yield
+
+
+def map_partitions_depth() -> int:
+    return _map_partitions_stack.get()
+
+
 @contextmanager
 def push_sql_scope():
     """
@@ -238,16 +252,21 @@ def push_operation_scope(operation: str):


 @contextmanager
-def resolving_lambda_function():
+def resolving_lambda_function(param_names: list[str] = None):
     """
     Context manager that sets a flag indicating lambda function is being resolved.
+    Also tracks the lambda parameter names for validation.
     """
     prev = _resolving_lambda_fun.get()
+    prev_params = _current_lambda_params.get()
     try:
         _resolving_lambda_fun.set(True)
+        if param_names is not None:
+            _current_lambda_params.set(param_names)
         yield
     finally:
         _resolving_lambda_fun.set(prev)
+        _current_lambda_params.set(prev_params)


 def is_lambda_being_resolved() -> bool:
@@ -257,6 +276,13 @@ def is_lambda_being_resolved() -> bool:
     return _resolving_lambda_fun.get()


+def get_current_lambda_params() -> list[str]:
+    """
+    Returns the current lambda parameter names.
+    """
+    return _current_lambda_params.get()
+
+
 @contextmanager
 def resolving_fun_args():
     """
@@ -270,6 +296,19 @@ def resolving_fun_args():
         _resolving_fun_args.set(prev)


+@contextmanager
+def not_resolving_fun_args():
+    """
+    Context manager that sets a flag indicating function arguments are *not* being resolved.
+    """
+    prev = _resolving_fun_args.get()
+    try:
+        _resolving_fun_args.set(False)
+        yield
+    finally:
+        _resolving_fun_args.set(prev)
+
+
 def is_function_argument_being_resolved() -> bool:
     """
     Returns True if function arguments are being resolved.
@@ -350,6 +389,7 @@ def clear_context_data() -> None:

     _next_sql_plan_id.set(_STARTING_SQL_PLAN_ID)
     _sql_plan_name_map.set({})
+    _map_partitions_stack.set(0)
     _sql_aggregate_function_count.set(0)
     _sql_named_args.set({})
     _sql_pos_args.set({})
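A sketch of how the new context helpers compose, based only on the functions added above (the real call sites are in map_map_partitions.py and the expression mappers, which are not shown in this diff):

    # Track nesting while translating a mapPartitions relation.
    with push_map_partitions():
        assert map_partitions_depth() >= 1
        # ... translate the child plan ...

    # Resolve a lambda while remembering its parameter names for validation.
    with resolving_lambda_function(param_names=["x", "acc"]):
        assert is_lambda_being_resolved()
        assert get_current_lambda_params() == ["x", "acc"]

    # Temporarily leave "function argument" resolution mode.
    with not_resolving_fun_args():
        assert not is_function_argument_being_resolved()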
snowflake/snowpark_connect/utils/external_udxf_cache.py (new file)

@@ -0,0 +1,36 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+from snowflake.snowpark import Session
+from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
+
+
+def init_external_udxf_cache(session: Session) -> None:
+    session.external_udfs_cache = SynchronizedDict()
+    session.external_udtfs_cache = SynchronizedDict()
+
+
+def clear_external_udxf_cache(session: Session) -> None:
+    session.external_udfs_cache.clear()
+    session.external_udtfs_cache.clear()
+
+
+def get_external_udf_from_cache(hash: str):
+    return Session.get_active_session().external_udfs_cache.get(hash)
+
+
+def cache_external_udf(hash: int, udf):
+    Session.get_active_session().external_udfs_cache[hash] = udf
+
+
+def clear_external_udtf_cache(session: Session) -> None:
+    session.external_udtfs_cache.clear()
+
+
+def get_external_udtf_from_cache(hash: int):
+    return Session.get_active_session().external_udtfs_cache.get(hash)
+
+
+def cache_external_udtf(hash: int, udf):
+    Session.get_active_session().external_udtfs_cache[hash] = udf
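A usage sketch of the new per-session cache, using only the functions defined in this file; the hash value and the registration call are illustrative placeholders (the real callers, presumably in the UDF/UDTF mappers, compute the hash from the pickled function payload):

    session = get_or_create_snowpark_session()   # helper imported elsewhere in this package
    init_external_udxf_cache(session)            # attach empty SynchronizedDicts to the session

    udf_hash = "3f2c9a"                          # illustrative; derived from the UDF payload by callers
    if (udf := get_external_udf_from_cache(udf_hash)) is None:
        udf = register_external_udf()            # hypothetical placeholder for the real registration path
        cache_external_udf(udf_hash, udf)

    # AddArtifacts clears both caches when non-cache artifacts arrive (see server.py above).
    clear_external_udxf_cache(session)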
snowflake/snowpark_connect/utils/pandas_udtf_utils.py

@@ -87,9 +87,93 @@ def get_map_in_arrow_udtf(
 def create_pandas_udtf(
     udtf_proto: CommonInlineUserDefinedFunction,
     spark_column_names: list[str],
-    input_schema: StructType | None = None,
-    return_schema: StructType | None = None,
+    input_schema: StructType,
+    return_schema: StructType,
+):
+    user_function, _ = cloudpickle.loads(udtf_proto.python_udf.command)
+    output_column_names = [field.name for field in return_schema.fields]
+    output_column_original_names = [
+        field.original_column_identifier for field in return_schema.fields
+    ]
+
+    class MapPandasUDTF:
+        def __init__(self) -> None:
+            self.user_function = user_function
+            self.output_column_names = output_column_names
+            self.spark_column_names = spark_column_names
+            self.output_column_original_names = output_column_original_names
+
+        def end_partition(self, df: pd.DataFrame):
+            if df.empty:
+                empty_df = pd.DataFrame(columns=self.output_column_names)
+                yield empty_df
+                return
+
+            df_without_dummy = df.drop(
+                columns=["_DUMMY_PARTITION_KEY"], errors="ignore"
+            )
+            df_without_dummy.columns = self.spark_column_names
+            result_iterator = self.user_function(
+                [pd.DataFrame([row]) for _, row in df_without_dummy.iterrows()]
+            )
+
+            if not isinstance(result_iterator, Iterator) and not hasattr(
+                result_iterator, "__iter__"
+            ):
+                raise RuntimeError(
+                    f"snowpark_connect::UDF_RETURN_TYPE Return type of the user-defined function should be "
+                    f"iterator of pandas.DataFrame, but is {type(result_iterator).__name__}"
+                )
+
+            output_df = pd.concat(result_iterator)
+            generated_output_column_names = list(output_df.columns)
+
+            missing_columns = []
+            for original_column in self.output_column_original_names:
+                if original_column not in generated_output_column_names:
+                    missing_columns.append(original_column)
+
+            if missing_columns:
+                unexpected_columns = [
+                    column
+                    for column in generated_output_column_names
+                    if column not in self.output_column_original_names
+                ]
+                raise RuntimeError(
+                    f"[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF] Column names of the returned pandas.DataFrame do not match specified schema. Missing: {', '.join(sorted(missing_columns))}. Unexpected: {', '.join(sorted(unexpected_columns))}"
+                    "."
+                )
+            reordered_df = output_df[self.output_column_original_names]
+            reordered_df.columns = self.output_column_names
+            yield reordered_df
+
+    return snowpark_fn.pandas_udtf(
+        MapPandasUDTF,
+        output_schema=PandasDataFrameType(
+            [field.datatype for field in return_schema.fields],
+            [field.name for field in return_schema.fields],
+        ),
+        input_types=[
+            PandasDataFrameType(
+                [field.datatype for field in input_schema.fields] + [IntegerType()]
+            )
+        ],
+        input_names=[field.name for field in input_schema.fields]
+        + ["_DUMMY_PARTITION_KEY"],
+        name="map_pandas_udtf",
+        replace=True,
+        packages=["pandas"],
+        is_permanent=False,
+    )
+
+
+def create_pandas_udtf_with_arrow(
+    udtf_proto: CommonInlineUserDefinedFunction,
+    spark_column_names: list[str],
+    input_schema: StructType,
+    return_schema: StructType,
 ) -> str | snowpark.udtf.UserDefinedTableFunction:
+
     user_function, _ = cloudpickle.loads(udtf_proto.python_udf.command)
     output_column_names = [field.name for field in return_schema.fields]
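The MapPandasUDTF wrapper above backs the standard DataFrame.mapInPandas API; if the user function yields columns that do not match the declared schema, the RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF error above is raised. A client-side example (data and schema are illustrative; `spark` is assumed to be a Spark Connect session backed by snowpark-connect):

    import pandas as pd

    def add_flag(batches):
        # Receives an iterator of pandas DataFrames; must yield DataFrames whose
        # columns match the declared schema ("id", "is_even").
        for pdf in batches:
            yield pd.DataFrame({"id": pdf["id"], "is_even": pdf["id"] % 2 == 0})

    spark.range(10).mapInPandas(add_flag, schema="id long, is_even boolean").show()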