snowpark-connect 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of snowpark-connect might be problematic.

Files changed (42)
  1. snowflake/snowpark_connect/column_name_handler.py +3 -93
  2. snowflake/snowpark_connect/config.py +99 -4
  3. snowflake/snowpark_connect/dataframe_container.py +0 -6
  4. snowflake/snowpark_connect/expression/map_expression.py +31 -1
  5. snowflake/snowpark_connect/expression/map_sql_expression.py +22 -18
  6. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +22 -26
  7. snowflake/snowpark_connect/expression/map_unresolved_function.py +28 -10
  8. snowflake/snowpark_connect/expression/map_unresolved_star.py +2 -3
  9. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  10. snowflake/snowpark_connect/relation/map_extension.py +7 -1
  11. snowflake/snowpark_connect/relation/map_join.py +62 -258
  12. snowflake/snowpark_connect/relation/map_map_partitions.py +36 -77
  13. snowflake/snowpark_connect/relation/map_relation.py +8 -2
  14. snowflake/snowpark_connect/relation/map_show_string.py +2 -0
  15. snowflake/snowpark_connect/relation/map_sql.py +413 -15
  16. snowflake/snowpark_connect/relation/write/map_write.py +195 -114
  17. snowflake/snowpark_connect/resources_initializer.py +20 -5
  18. snowflake/snowpark_connect/server.py +20 -18
  19. snowflake/snowpark_connect/utils/artifacts.py +4 -5
  20. snowflake/snowpark_connect/utils/concurrent.py +4 -0
  21. snowflake/snowpark_connect/utils/context.py +41 -1
  22. snowflake/snowpark_connect/utils/describe_query_cache.py +57 -51
  23. snowflake/snowpark_connect/utils/identifiers.py +120 -0
  24. snowflake/snowpark_connect/utils/io_utils.py +21 -1
  25. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +86 -2
  26. snowflake/snowpark_connect/utils/scala_udf_utils.py +34 -43
  27. snowflake/snowpark_connect/utils/session.py +16 -26
  28. snowflake/snowpark_connect/utils/telemetry.py +53 -0
  29. snowflake/snowpark_connect/utils/udf_utils.py +66 -103
  30. snowflake/snowpark_connect/utils/udtf_helper.py +17 -7
  31. snowflake/snowpark_connect/version.py +2 -3
  32. {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/METADATA +2 -2
  33. {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/RECORD +41 -42
  34. snowflake/snowpark_connect/hidden_column.py +0 -39
  35. {snowpark_connect-0.26.0.data → snowpark_connect-0.28.0.data}/scripts/snowpark-connect +0 -0
  36. {snowpark_connect-0.26.0.data → snowpark_connect-0.28.0.data}/scripts/snowpark-session +0 -0
  37. {snowpark_connect-0.26.0.data → snowpark_connect-0.28.0.data}/scripts/snowpark-submit +0 -0
  38. {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/WHEEL +0 -0
  39. {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/licenses/LICENSE-binary +0 -0
  40. {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/licenses/LICENSE.txt +0 -0
  41. {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/licenses/NOTICE-binary +0 -0
  42. {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/top_level.txt +0 -0

snowflake/snowpark_connect/utils/context.py

@@ -30,6 +30,9 @@ _sql_aggregate_function_count = ContextVar[int](
     "_contains_aggregate_function", default=0
 )
 
+# Context for parsing map_partitions
+_map_partitions_stack = ContextVar[int]("_map_partitions_stack", default=0)
+
 # We have to generate our own plan IDs that are different from Spark's.
 # Spark plan IDs start at 0, so pick a "big enough" number to avoid overlaps.
 _STARTING_SQL_PLAN_ID = 0x80000000
@@ -49,6 +52,7 @@ _spark_client_type_regex = re.compile(r"spark/(?P<spark_version>\d+\.\d+\.\d+)")
 _current_operation = ContextVar[str]("_current_operation", default="default")
 _resolving_fun_args = ContextVar[bool]("_resolving_fun_args", default=False)
 _resolving_lambda_fun = ContextVar[bool]("_resolving_lambdas", default=False)
+_current_lambda_params = ContextVar[list[str]]("_current_lambda_params", default=[])
 
 _is_window_enabled = ContextVar[bool]("_is_window_enabled", default=False)
 _is_in_pivot = ContextVar[bool]("_is_in_pivot", default=False)
@@ -206,6 +210,16 @@ def push_evaluating_join_condition(join_type, left_keys, right_keys):
         _is_evaluating_join_condition.set(prev)
 
 
+@contextmanager
+def push_map_partitions():
+    _map_partitions_stack.set(_map_partitions_stack.get() + 1)
+    yield
+
+
+def map_partitions_depth() -> int:
+    return _map_partitions_stack.get()
+
+
 @contextmanager
 def push_sql_scope():
     """
@@ -238,16 +252,21 @@ def push_operation_scope(operation: str):
 
 
 @contextmanager
-def resolving_lambda_function():
+def resolving_lambda_function(param_names: list[str] = None):
     """
     Context manager that sets a flag indicating lambda function is being resolved.
+    Also tracks the lambda parameter names for validation.
     """
     prev = _resolving_lambda_fun.get()
+    prev_params = _current_lambda_params.get()
     try:
         _resolving_lambda_fun.set(True)
+        if param_names is not None:
+            _current_lambda_params.set(param_names)
         yield
     finally:
         _resolving_lambda_fun.set(prev)
+        _current_lambda_params.set(prev_params)
 
 
 def is_lambda_being_resolved() -> bool:
@@ -257,6 +276,13 @@ def is_lambda_being_resolved() -> bool:
     return _resolving_lambda_fun.get()
 
 
+def get_current_lambda_params() -> list[str]:
+    """
+    Returns the current lambda parameter names.
+    """
+    return _current_lambda_params.get()
+
+
 @contextmanager
 def resolving_fun_args():
     """
@@ -270,6 +296,19 @@ def resolving_fun_args():
         _resolving_fun_args.set(prev)
 
 
+@contextmanager
+def not_resolving_fun_args():
+    """
+    Context manager that sets a flag indicating function arguments are *not* being resolved.
+    """
+    prev = _resolving_fun_args.get()
+    try:
+        _resolving_fun_args.set(False)
+        yield
+    finally:
+        _resolving_fun_args.set(prev)
+
+
 def is_function_argument_being_resolved() -> bool:
     """
     Returns True if function arguments are being resolved.
@@ -350,6 +389,7 @@ def clear_context_data() -> None:
 
     _next_sql_plan_id.set(_STARTING_SQL_PLAN_ID)
     _sql_plan_name_map.set({})
+    _map_partitions_stack.set(0)
    _sql_aggregate_function_count.set(0)
    _sql_named_args.set({})
    _sql_pos_args.set({})
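
The helpers added above are thin ContextVar wrappers, so their intended use can be read straight off the diff. A minimal usage sketch (the import path comes from the file list above; calling these helpers directly like this is purely illustrative):

    from snowflake.snowpark_connect.utils.context import (
        get_current_lambda_params,
        map_partitions_depth,
        push_map_partitions,
        resolving_lambda_function,
    )

    # Track that we are inside a map_partitions relation while mapping nested plans.
    with push_map_partitions():
        assert map_partitions_depth() >= 1

    # Lambda parameter names are now recorded so expression mapping can validate references.
    with resolving_lambda_function(param_names=["x", "i"]):
        assert get_current_lambda_params() == ["x", "i"]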

snowflake/snowpark_connect/utils/describe_query_cache.py

@@ -6,20 +6,24 @@ import hashlib
 import inspect
 import random
 import re
-import threading
 import time
 from typing import Any
 
 from snowflake import snowpark
 from snowflake.connector.cursor import ResultMetadataV2
 from snowflake.snowpark._internal.server_connection import ServerConnection
+from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import telemetry
 
 DESCRIBE_CACHE_TTL_SECONDS = 15
 USE_DESCRIBE_QUERY_CACHE = True
 
-DDL_DETECTION_PATTERN = re.compile(r"^\s*(CREATE|ALTER|DROP|RENAME)\b", re.IGNORECASE)
+DDL_DETECTION_PATTERN = re.compile(r"\s*(CREATE|ALTER|DROP)\b", re.IGNORECASE)
+PLAIN_CREATE_PATTERN = re.compile(
+    r"\s*CREATE\s+((LOCAL|GLOBAL)\s+)?(TRANSIENT\s+)?TABLE\b", re.IGNORECASE
+)
+
 # Pattern for simple constant queries like: SELECT 3 :: INT AS "3-80000030-0" FROM ( SELECT $1 AS "__DUMMY" FROM VALUES (NULL :: STRING))
 # Using exact spacing pattern from generated SQL for deterministic matching
 # Column ID format: {original_name}-{8_digit_hex_plan_id}-{column_index}
@@ -32,8 +36,7 @@ SIMPLE_CONSTANT_PATTERN = re.compile(
 
 class DescribeQueryCache:
     def __init__(self) -> None:
-        self._cache = {}
-        self._lock = threading.Lock()
+        self._cache = SynchronizedDict()
 
     @staticmethod
     def _hash_query(sql_query: str) -> str:
@@ -48,49 +51,49 @@ class DescribeQueryCache:
         return sql_query
 
     def get(self, sql_query: str) -> list[ResultMetadataV2] | None:
+        telemetry.report_describe_query_cache_lookup()
+
         cache_key = self._get_cache_key(sql_query)
         key = self._hash_query(cache_key)
         current_time = time.monotonic()
 
-        # TODO: maybe too much locking, we could use read-write lock also. Or a thread safe dictionary.
-        with self._lock:
-            if key in self._cache:
-                result, timestamp = self._cache[key]
-                if current_time < timestamp + DESCRIBE_CACHE_TTL_SECONDS:
-                    logger.debug(
-                        f"Returning query result from cache for query: {sql_query[:20]}"
-                    )
-
-                    # If this is a constant query, we need to transform the result metadata
-                    # to match the actual query's column name
-                    if (
-                        cache_key != sql_query
-                    ):  # Only transform if we normalized the key
-                        match = SIMPLE_CONSTANT_PATTERN.match(sql_query)
-                        if match:
-                            number, column_id = match.groups()
-                            expected_column_name = column_id
-
-                            # Transform the cached result to match this query's column name
-                            # There should only be one column in these constant queries
-                            metadata = result[0]
-                            new_metadata = ResultMetadataV2(
-                                name=expected_column_name,
-                                type_code=metadata.type_code,
-                                display_size=metadata.display_size,
-                                internal_size=metadata.internal_size,
-                                precision=metadata.precision,
-                                scale=metadata.scale,
-                                is_nullable=metadata.is_nullable,
-                            )
-                            return [new_metadata]
-
-                    return result
-                else:
-                    logger.debug(
-                        f"Had a cached entry, but it expired for query: {sql_query[:20]}"
-                    )
-                    del self._cache[key]
+        if key in self._cache:
+            result, timestamp = self._cache[key]
+            if current_time < timestamp + DESCRIBE_CACHE_TTL_SECONDS:
+                logger.debug(
+                    f"Returning query result from cache for query: {sql_query[:20]}"
+                )
+                self._cache[key] = (result, current_time)
+
+                # If this is a constant query, we need to transform the result metadata
+                # to match the actual query's column name
+                if cache_key != sql_query:  # Only transform if we normalized the key
+                    match = SIMPLE_CONSTANT_PATTERN.match(sql_query)
+                    if match:
+                        number, column_id = match.groups()
+                        expected_column_name = column_id
+
+                        # Transform the cached result to match this query's column name
+                        # There should only be one column in these constant queries
+                        metadata = result[0]
+                        new_metadata = ResultMetadataV2(
+                            name=expected_column_name,
+                            type_code=metadata.type_code,
+                            display_size=metadata.display_size,
+                            internal_size=metadata.internal_size,
+                            precision=metadata.precision,
+                            scale=metadata.scale,
+                            is_nullable=metadata.is_nullable,
+                        )
+
+                        telemetry.report_describe_query_cache_hit()
+                        return [new_metadata]
+
+                telemetry.report_describe_query_cache_hit()
+                return result
+            else:
+                telemetry.report_describe_query_cache_expired()
+                del self._cache[key]
        return None
 
    def put(self, sql_query: str, result: list[ResultMetadataV2] | None) -> None:
@@ -102,12 +105,18 @@
 
         logger.debug(f"Putting query into cache: {sql_query[:50]}...")
 
-        with self._lock:
-            self._cache[key] = (result, time.monotonic())
+        self._cache[key] = (result, time.monotonic())
 
     def clear(self) -> None:
-        with self._lock:
-            self._cache.clear()
+        self._cache.clear()
+
+    def update_cache_for_query(self, query: str) -> None:
+        # Clear cache for DDL operations that modify existing objects (exclude CREATE TABLE)
+        if DDL_DETECTION_PATTERN.search(query) and not PLAIN_CREATE_PATTERN.search(
+            query
+        ):
+            self.clear()
+            telemetry.report_describe_query_cache_clear(query[:100])
 
 
 def instrument_session_for_describe_cache(session: snowpark.Session):
@@ -126,10 +135,7 @@ def instrument_session_for_describe_cache(session: snowpark.Session):
     if isinstance(cache_instance, DescribeQueryCache):
         cache = cache_instance
 
-        # TODO: This is very broad right now. We should be able to reduce the scope of clearing.
-        if DDL_DETECTION_PATTERN.search(query):
-            logger.debug(f"DDL detected, clearing describe query cache: '{query}'")
-            cache.clear()
+        cache.update_cache_for_query(query)
 
    def wrap_execute(wrapped_fn):
        def fn(query: str, **kwargs):
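
Net effect of the changes above: the plain dict plus explicit lock is replaced by SynchronizedDict, cache hits refresh the entry timestamp and are reported to telemetry, and only DDL that modifies existing objects invalidates the cache, so a plain CREATE TABLE no longer wipes it. A small standalone sketch of the new invalidation rule, reusing the two patterns from the diff (the helper function name is ours, for illustration only):

    import re

    # Copied from the diff above: DDL clears the cache unless it is a plain CREATE TABLE.
    DDL_DETECTION_PATTERN = re.compile(r"\s*(CREATE|ALTER|DROP)\b", re.IGNORECASE)
    PLAIN_CREATE_PATTERN = re.compile(
        r"\s*CREATE\s+((LOCAL|GLOBAL)\s+)?(TRANSIENT\s+)?TABLE\b", re.IGNORECASE
    )

    def clears_describe_cache(query: str) -> bool:
        # Mirrors DescribeQueryCache.update_cache_for_query from the diff.
        return bool(
            DDL_DETECTION_PATTERN.search(query) and not PLAIN_CREATE_PATTERN.search(query)
        )

    assert clears_describe_cache("ALTER TABLE t ADD COLUMN c INT")       # DDL -> clear
    assert clears_describe_cache("CREATE OR REPLACE TABLE t (c INT)")    # not a plain CREATE -> clear
    assert not clears_describe_cache("CREATE TABLE t (c INT)")           # plain CREATE -> keep cache
    assert not clears_describe_cache("SELECT * FROM t")                  # DML -> keep cache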

snowflake/snowpark_connect/utils/identifiers.py

@@ -2,6 +2,7 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 import re
+from typing import Any, TypeVar
 
 from pyspark.errors import AnalysisException
 
@@ -117,3 +118,122 @@ def split_fully_qualified_spark_name(qualified_name: str | None) -> list[str]:
         parts.append("".join(token_chars))
 
     return parts
+
+
+# See https://docs.snowflake.com/en/sql-reference/identifiers-syntax for identifier syntax
+UNQUOTED_IDENTIFIER_REGEX = r"([a-zA-Z_])([a-zA-Z0-9_$]{0,254})"
+QUOTED_IDENTIFIER_REGEX = r'"((""|[^"]){0,255})"'
+VALID_IDENTIFIER_REGEX = f"(?:{UNQUOTED_IDENTIFIER_REGEX}|{QUOTED_IDENTIFIER_REGEX})"
+
+
+Self = TypeVar("Self", bound="FQN")
+
+
+class FQN:
+    """Represents an object identifier, supporting fully qualified names.
+
+    The instance supports builder pattern that allows updating the identifier with database and
+    schema from different sources.
+
+    Examples
+    ________
+    >>> fqn = FQN.from_string("my_schema.object").using_connection(conn)
+
+    >>> fqn = FQN.from_string("my_name").set_database("db").set_schema("foo")
+    """
+
+    def __init__(
+        self,
+        database: str | None,
+        schema: str | None,
+        name: str,
+        signature: str | None = None,
+    ) -> None:
+        self._database = database
+        self._schema = schema
+        self._name = name
+        self.signature = signature
+
+    @property
+    def database(self) -> str | None:
+        return self._database
+
+    @property
+    def schema(self) -> str | None:
+        return self._schema
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def prefix(self) -> str:
+        if self.database:
+            return f"{self.database}.{self.schema if self.schema else 'PUBLIC'}"
+        if self.schema:
+            return f"{self.schema}"
+        return ""
+
+    @property
+    def identifier(self) -> str:
+        if self.prefix:
+            return f"{self.prefix}.{self.name}"
+        return self.name
+
+    def __str__(self) -> str:
+        return self.identifier
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, FQN):
+            raise AnalysisException(f"{other} is not a valid FQN")
+        return self.identifier == other.identifier
+
+    @classmethod
+    def from_string(cls, identifier: str) -> Self:
+        """Take in an object name in the form [[database.]schema.]name and return a new :class:`FQN` instance.
+
+        Raises:
+            InvalidIdentifierError: If the object identifier does not meet identifier requirements.
+        """
+        qualifier_pattern = (
+            rf"(?:(?P<first_qualifier>{VALID_IDENTIFIER_REGEX})\.)?"
+            rf"(?:(?P<second_qualifier>{VALID_IDENTIFIER_REGEX})\.)?"
+            rf"(?P<name>{VALID_IDENTIFIER_REGEX})(?P<signature>\(.*\))?"
+        )
+        result = re.fullmatch(qualifier_pattern, identifier)
+
+        if result is None:
+            raise AnalysisException(f"{identifier} is not a valid identifier")
+
+        unqualified_name = result.group("name")
+        if result.group("second_qualifier") is not None:
+            database = result.group("first_qualifier")
+            schema = result.group("second_qualifier")
+        else:
+            database = None
+            schema = result.group("first_qualifier")
+
+        signature = None
+        if result.group("signature"):
+            signature = result.group("signature")
+        return cls(
+            name=unqualified_name, schema=schema, database=database, signature=signature
+        )
+
+    def set_database(self, database: str | None) -> Self:
+        if database:
+            self._database = database
+        return self
+
+    def set_schema(self, schema: str | None) -> Self:
+        if schema:
+            self._schema = schema
+        return self
+
+    def set_name(self, name: str) -> Self:
+        self._name = name
+        return self
+
+    def to_dict(self) -> dict[str, str | None]:
+        """Return the dictionary representation of the instance."""
+        return {"name": self.name, "schema": self.schema, "database": self.database}

snowflake/snowpark_connect/utils/io_utils.py

@@ -1,10 +1,11 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-
+import contextlib
 import functools
 
 from snowflake.snowpark import Session
+from snowflake.snowpark_connect.utils.identifiers import FQN
 
 
 @functools.cache
@@ -33,3 +34,22 @@ def file_format(
     ).collect()
 
     return file_format_name
+
+
+def get_table_type(
+    snowpark_table_name: str,
+    snowpark_session: Session,
+) -> str:
+    fqn = FQN.from_string(snowpark_table_name)
+    with contextlib.suppress(Exception):
+        if fqn.database is not None:
+            return snowpark_session.catalog.getTable(
+                table_name=fqn.name, schema=fqn.schema, database=fqn.database
+            ).table_type
+        elif fqn.schema is not None:
+            return snowpark_session.catalog.getTable(
+                table_name=fqn.name, schema=fqn.schema
+            ).table_type
+        else:
+            return snowpark_session.catalog.getTable(table_name=fqn.name).table_type
+    return "TABLE"

snowflake/snowpark_connect/utils/pandas_udtf_utils.py

@@ -87,9 +87,93 @@ def get_map_in_arrow_udtf(
 def create_pandas_udtf(
     udtf_proto: CommonInlineUserDefinedFunction,
     spark_column_names: list[str],
-    input_schema: StructType | None = None,
-    return_schema: StructType | None = None,
+    input_schema: StructType,
+    return_schema: StructType,
+):
+    user_function, _ = cloudpickle.loads(udtf_proto.python_udf.command)
+    output_column_names = [field.name for field in return_schema.fields]
+    output_column_original_names = [
+        field.original_column_identifier for field in return_schema.fields
+    ]
+
+    class MapPandasUDTF:
+        def __init__(self) -> None:
+            self.user_function = user_function
+            self.output_column_names = output_column_names
+            self.spark_column_names = spark_column_names
+            self.output_column_original_names = output_column_original_names
+
+        def end_partition(self, df: pd.DataFrame):
+            if df.empty:
+                empty_df = pd.DataFrame(columns=self.output_column_names)
+                yield empty_df
+                return
+
+            df_without_dummy = df.drop(
+                columns=["_DUMMY_PARTITION_KEY"], errors="ignore"
+            )
+            df_without_dummy.columns = self.spark_column_names
+            result_iterator = self.user_function(
+                [pd.DataFrame([row]) for _, row in df_without_dummy.iterrows()]
+            )
+
+            if not isinstance(result_iterator, Iterator) and not hasattr(
+                result_iterator, "__iter__"
+            ):
+                raise RuntimeError(
+                    f"snowpark_connect::UDF_RETURN_TYPE Return type of the user-defined function should be "
+                    f"iterator of pandas.DataFrame, but is {type(result_iterator).__name__}"
+                )
+
+            output_df = pd.concat(result_iterator)
+            generated_output_column_names = list(output_df.columns)
+
+            missing_columns = []
+            for original_column in self.output_column_original_names:
+                if original_column not in generated_output_column_names:
+                    missing_columns.append(original_column)
+
+            if missing_columns:
+                unexpected_columns = [
+                    column
+                    for column in generated_output_column_names
+                    if column not in self.output_column_original_names
+                ]
+                raise RuntimeError(
+                    f"[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF] Column names of the returned pandas.DataFrame do not match specified schema. Missing: {', '.join(sorted(missing_columns))}. Unexpected: {', '.join(sorted(unexpected_columns))}"
+                    "."
+                )
+            reordered_df = output_df[self.output_column_original_names]
+            reordered_df.columns = self.output_column_names
+            yield reordered_df
+
+    return snowpark_fn.pandas_udtf(
+        MapPandasUDTF,
+        output_schema=PandasDataFrameType(
+            [field.datatype for field in return_schema.fields],
+            [field.name for field in return_schema.fields],
+        ),
+        input_types=[
+            PandasDataFrameType(
+                [field.datatype for field in input_schema.fields] + [IntegerType()]
+            )
+        ],
+        input_names=[field.name for field in input_schema.fields]
+        + ["_DUMMY_PARTITION_KEY"],
+        name="map_pandas_udtf",
+        replace=True,
+        packages=["pandas"],
+        is_permanent=False,
+    )
+
+
+def create_pandas_udtf_with_arrow(
+    udtf_proto: CommonInlineUserDefinedFunction,
+    spark_column_names: list[str],
+    input_schema: StructType,
+    return_schema: StructType,
 ) -> str | snowpark.udtf.UserDefinedTableFunction:
+
     user_function, _ = cloudpickle.loads(udtf_proto.python_udf.command)
     output_column_names = [field.name for field in return_schema.fields]
 

snowflake/snowpark_connect/utils/scala_udf_utils.py

@@ -171,12 +171,19 @@ class ScalaUDFDef:
         is_map_return = udf_func_return_type.startswith("Map")
         wrapper_return_type = "String" if is_map_return else udf_func_return_type
 
+        # For handling Seq type correctly, ensure that the wrapper function always uses Array as its input and
+        # return types (when required) and the wrapped function uses Seq.
+        udf_func_return_type = udf_func_return_type.replace("Array", "Seq")
+        is_seq_return = udf_func_return_type.startswith("Seq")
+
         # Need to call the map to JSON string converter when a map is returned by the user's function.
-        invoke_udf_func = (
-            f"write(func({invocation_args}))"
-            if is_map_return
-            else f"func({invocation_args})"
-        )
+        if is_map_return:
+            invoke_udf_func = f"write(func({invocation_args}))"
+        elif is_seq_return:
+            # TODO: SNOW-2339385 Handle Array[T] return types correctly. Currently, only Seq[T] is supported.
+            invoke_udf_func = f"func({invocation_args}).toArray"
+        else:
+            invoke_udf_func = f"func({invocation_args})"
 
         # The lines of code below are required only when a Map is returned by the UDF. This is needed to serialize the
         # map output to a JSON string.
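
To make the Seq handling concrete, here is the string manipulation in isolation for a UDF declared as returning Array[String]; the variable values are hypothetical stand-ins for what ScalaUDFDef computes from the UDF metadata:

    udf_func_return_type = "Array[String]"  # hypothetical return type from the UDF packet
    invocation_args = "arg0"                 # hypothetical wrapper argument list

    # The wrapped Spark closure is typed with Seq, while the wrapper hands an Array back to Snowflake.
    udf_func_return_type = udf_func_return_type.replace("Array", "Seq")
    is_seq_return = udf_func_return_type.startswith("Seq")
    invoke_udf_func = (
        f"func({invocation_args}).toArray" if is_seq_return else f"func({invocation_args})"
    )

    print(udf_func_return_type)  # Seq[String]
    print(invoke_udf_func)       # func(arg0).toArray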
@@ -184,9 +191,9 @@
             ""
             if not is_map_return
             else """
-import org.json4s._
-import org.json4s.native.Serialization._
-import org.json4s.native.Serialization
+import shaded_json4s._
+import shaded_json4s.native.Serialization._
+import shaded_json4s.native.Serialization
 """
         )
         map_return_formatter = (
@@ -199,22 +206,12 @@ import org.json4s.native.Serialization
 
         return f"""import org.apache.spark.sql.connect.common.UdfPacket
 {map_return_imports}
-import java.io.{{ByteArrayInputStream, ObjectInputStream}}
-import java.nio.file.{{Files, Paths}}
+import com.snowflake.sas.scala.Utils
 
 object __RecreatedSparkUdf {{
 {map_return_formatter}
-  private lazy val func: ({udf_func_input_types}) => {udf_func_return_type} = {{
-    val importDirectory = System.getProperty("com.snowflake.import_directory")
-    val fPath = importDirectory + "{self.name}.bin"
-    val bytes = Files.readAllBytes(Paths.get(fPath))
-    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
-    try {{
-      ois.readObject().asInstanceOf[UdfPacket].function.asInstanceOf[({udf_func_input_types}) => {udf_func_return_type}]
-    }} finally {{
-      ois.close()
-    }}
-  }}
+  private lazy val func: ({udf_func_input_types}) => {udf_func_return_type} =
+    Utils.deserializeFunc("{self.name}.bin").asInstanceOf[({udf_func_input_types}) => {udf_func_return_type}]
 
   def __wrapperFunc({wrapper_arg_and_input_types_str}): {wrapper_return_type} = {{
     {invoke_udf_func}
@@ -299,29 +296,15 @@ def build_scala_udf_imports(session, payload, udf_name, is_map_return) -> List[s
         # Remove the stage path since it is not properly formatted.
         user_jars.append(row[0][row[0].find("/") :])
 
-    # Jars used when the return type is a Map.
-    map_jars = (
-        []
-        if not is_map_return
-        else [
-            f"{stage_resource_path}/json4s-core_2.12-3.7.0-M11.jar",
-            f"{stage_resource_path}/json4s-native_2.12-3.7.0-M11.jar",
-            f"{stage_resource_path}/paranamer-2.8.3.jar",
-        ]
-    )
-
     # Format the user jars to be used in the IMPORTS clause of the stored procedure.
-    return (
-        [
-            closure_binary_file,
-            f"{stage_resource_path}/spark-connect-client-jvm_2.12-3.5.6.jar",
-            f"{stage_resource_path}/spark-common-utils_2.12-3.5.6.jar",
-            f"{stage_resource_path}/spark-sql_2.12-3.5.6.jar",
-            f"{stage_resource_path}/json4s-ast_2.12-3.7.0-M11.jar",
-        ]
-        + map_jars
-        + [f"{stage + jar}" for jar in user_jars]
-    )
+    return [
+        closure_binary_file,
+        f"{stage_resource_path}/spark-connect-client-jvm_2.12-3.5.6.jar",
+        f"{stage_resource_path}/spark-common-utils_2.12-3.5.6.jar",
+        f"{stage_resource_path}/spark-sql_2.12-3.5.6.jar",
+        f"{stage_resource_path}/json4s-ast_2.12-3.7.0-M11.jar",
+        f"{stage_resource_path}/sas-scala-udf_2.12-0.1.0.jar",
+    ] + [f"{stage + jar}" for jar in user_jars]
 
 
 def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf:
@@ -343,6 +326,14 @@ def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf
     Returns:
         A ScalaUdf object representing the created or cached Scala UDF.
     """
+    from snowflake.snowpark_connect.resources_initializer import (
+        wait_for_resource_initialization,
+    )
+
+    # Make sure that the resource initializer thread is completed before creating Scala UDFs since we depend on the jars
+    # uploaded by it.
+    wait_for_resource_initialization()
+
     from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 
     function_name = pciudf._function_name
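
resources_initializer.py also changes in this release (+20 -5) but its body is not shown in this diff, so the exact shape of wait_for_resource_initialization is unknown. As a rough, assumed sketch only, the call above implies a gate along these lines, with an Event set by the background thread that stages the jars:

    import threading

    # Hypothetical sketch; the real resources_initializer module is not part of this diff view.
    _resources_ready = threading.Event()

    def start_resource_initialization(upload_jars) -> None:
        def _run() -> None:
            try:
                upload_jars()           # e.g. stage sas-scala-udf_2.12-0.1.0.jar and friends
            finally:
                _resources_ready.set()  # unblock waiters even if staging fails
        threading.Thread(target=_run, daemon=True).start()

    def wait_for_resource_initialization(timeout: float | None = None) -> bool:
        # create_scala_udf blocks here until the jars are staged.
        return _resources_ready.wait(timeout)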