snowpark-connect 0.27.0__py3-none-any.whl → 0.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (42)
  1. snowflake/snowpark_connect/column_name_handler.py +3 -93
  2. snowflake/snowpark_connect/config.py +99 -1
  3. snowflake/snowpark_connect/dataframe_container.py +0 -6
  4. snowflake/snowpark_connect/execute_plan/map_execution_command.py +31 -68
  5. snowflake/snowpark_connect/expression/map_expression.py +22 -7
  6. snowflake/snowpark_connect/expression/map_sql_expression.py +22 -18
  7. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +4 -26
  8. snowflake/snowpark_connect/expression/map_unresolved_function.py +12 -3
  9. snowflake/snowpark_connect/expression/map_unresolved_star.py +2 -3
  10. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  11. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +207 -20
  12. snowflake/snowpark_connect/relation/map_extension.py +14 -10
  13. snowflake/snowpark_connect/relation/map_join.py +62 -258
  14. snowflake/snowpark_connect/relation/map_relation.py +5 -1
  15. snowflake/snowpark_connect/relation/map_sql.py +464 -68
  16. snowflake/snowpark_connect/relation/read/map_read_table.py +58 -0
  17. snowflake/snowpark_connect/relation/write/map_write.py +228 -120
  18. snowflake/snowpark_connect/resources_initializer.py +20 -5
  19. snowflake/snowpark_connect/server.py +16 -17
  20. snowflake/snowpark_connect/utils/concurrent.py +4 -0
  21. snowflake/snowpark_connect/utils/context.py +21 -0
  22. snowflake/snowpark_connect/utils/describe_query_cache.py +57 -51
  23. snowflake/snowpark_connect/utils/identifiers.py +128 -2
  24. snowflake/snowpark_connect/utils/io_utils.py +21 -1
  25. snowflake/snowpark_connect/utils/scala_udf_utils.py +34 -43
  26. snowflake/snowpark_connect/utils/session.py +16 -26
  27. snowflake/snowpark_connect/utils/telemetry.py +53 -0
  28. snowflake/snowpark_connect/utils/temporary_view_cache.py +61 -0
  29. snowflake/snowpark_connect/utils/udf_utils.py +9 -8
  30. snowflake/snowpark_connect/utils/udtf_utils.py +3 -2
  31. snowflake/snowpark_connect/version.py +1 -1
  32. {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/METADATA +2 -2
  33. {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/RECORD +41 -41
  34. snowflake/snowpark_connect/hidden_column.py +0 -39
  35. {snowpark_connect-0.27.0.data → snowpark_connect-0.28.1.data}/scripts/snowpark-connect +0 -0
  36. {snowpark_connect-0.27.0.data → snowpark_connect-0.28.1.data}/scripts/snowpark-session +0 -0
  37. {snowpark_connect-0.27.0.data → snowpark_connect-0.28.1.data}/scripts/snowpark-submit +0 -0
  38. {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/WHEEL +0 -0
  39. {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/licenses/LICENSE-binary +0 -0
  40. {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/licenses/LICENSE.txt +0 -0
  41. {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/licenses/NOTICE-binary +0 -0
  42. {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/top_level.txt +0 -0
@@ -12,6 +12,7 @@ _resources_initialized = threading.Event()
 _initializer_lock = threading.Lock()
 SPARK_VERSION = "3.5.6"
 RESOURCE_PATH = "/snowflake/snowpark_connect/resources"
+_upload_jars = True  # Flag to control whether to upload jars. Required for Scala UDFs.


 def initialize_resources() -> None:
@@ -57,10 +58,8 @@ def initialize_resources() -> None:
         f"spark-sql_2.12-{SPARK_VERSION}.jar",
         f"spark-connect-client-jvm_2.12-{SPARK_VERSION}.jar",
         f"spark-common-utils_2.12-{SPARK_VERSION}.jar",
+        "sas-scala-udf_2.12-0.1.0.jar",
         "json4s-ast_2.12-3.7.0-M11.jar",
-        "json4s-native_2.12-3.7.0-M11.jar",
-        "json4s-core_2.12-3.7.0-M11.jar",
-        "paranamer-2.8.3.jar",
     ]

     for jar in jar_files:
@@ -80,9 +79,11 @@ def initialize_resources() -> None:
         ("Initialize Session Stage", initialize_session_stage),  # Takes about 0.3s
         ("Initialize Session Catalog", initialize_catalog),  # Takes about 1.2s
         ("Snowflake Connection Warm Up", warm_up_sf_connection),  # Takes about 1s
-        ("Upload Scala UDF Jars", upload_scala_udf_jars),
     ]

+    if _upload_jars:
+        resources.append(("Upload Scala UDF Jars", upload_scala_udf_jars))
+
     for name, resource_func in resources:
         resource_start = time.time()
         try:
@@ -113,4 +114,18 @@ def initialize_resources_async() -> threading.Thread:

 def wait_for_resource_initialization() -> None:
     with _initializer_lock:
-        _resource_initializer.join()
+        _resource_initializer.join(timeout=300)  # wait at most 300 seconds
+        if _resource_initializer.is_alive():
+            logger.error(
+                "Resource initialization failed - initializer thread has been running for over 300 seconds."
+            )
+            raise RuntimeError(
+                "Resource initialization failed - initializer thread has been running for over 300 seconds."
+            )
+
+
+def set_upload_jars(upload: bool) -> None:
+    """Set whether to upload jars required for Scala UDFs. This should be set to False if Scala UDFs
+    are not used, to avoid the overhead of uploading jars."""
+    global _upload_jars
+    _upload_jars = upload
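Note on the resources_initializer changes above: wait_for_resource_initialization() now gives up after 300 seconds and raises RuntimeError instead of joining indefinitely, and the new set_upload_jars() toggle lets the "Upload Scala UDF Jars" step be skipped when Scala UDFs are not needed. A minimal sketch of how the toggle might be used (whether this is intended as a public entry point is not stated in the diff):

from snowflake.snowpark_connect.resources_initializer import (
    initialize_resources_async,
    set_upload_jars,
)

# Assumption: calling this before initialization starts is what causes the
# "Upload Scala UDF Jars" step to be skipped; the flag is read when the
# resources list is built.
set_upload_jars(False)
initializer = initialize_resources_async()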
@@ -725,30 +725,33 @@ def _serve(
         # No need to start grpc server in TCM
         return

+    grpc_max_msg_size = get_int_from_env(
+        "SNOWFLAKE_GRPC_MAX_MESSAGE_SIZE",
+        _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE,
+    )
+    grpc_max_metadata_size = get_int_from_env(
+        "SNOWFLAKE_GRPC_MAX_METADATA_SIZE",
+        _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE,
+    )
     server_options = [
         (
             "grpc.max_receive_message_length",
-            get_int_from_env(
-                "SNOWFLAKE_GRPC_MAX_MESSAGE_SIZE",
-                _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE,
-            ),
+            grpc_max_msg_size,
         ),
         (
             "grpc.max_metadata_size",
-            get_int_from_env(
-                "SNOWFLAKE_GRPC_MAX_METADATA_SIZE",
-                _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE,
-            ),
+            grpc_max_metadata_size,
         ),
         (
             "grpc.absolute_max_metadata_size",
-            get_int_from_env(
-                "SNOWFLAKE_GRPC_MAX_METADATA_SIZE",
-                _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE,
-            )
-            * 2,
+            grpc_max_metadata_size * 2,
         ),
     ]
+
+    from pyspark.sql.connect.client import ChannelBuilder
+
+    ChannelBuilder.MAX_MESSAGE_LENGTH = grpc_max_msg_size
+
     server = grpc.server(
         futures.ThreadPoolExecutor(max_workers=10), options=server_options
     )
@@ -1053,10 +1056,6 @@ def start_session(
     global _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE
     _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE = max_grpc_message_size

-    from pyspark.sql.connect.client import ChannelBuilder
-
-    ChannelBuilder.MAX_MESSAGE_LENGTH = max_grpc_message_size
-
     if os.environ.get("SPARK_ENV_LOADED"):
         raise RuntimeError(
             "Snowpark Connect cannot be run inside of a Spark environment"
@@ -52,6 +52,10 @@ class SynchronizedDict(Mapping[K, V]):
         with self._lock.writer():
             self._dict[key] = value

+    def __delitem__(self, key: K) -> None:
+        with self._lock.writer():
+            del self._dict[key]
+
     def __contains__(self, key: K) -> bool:
         with self._lock.reader():
             return key in self._dict
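SynchronizedDict gains __delitem__ under the writer lock, which is what allows the describe-query cache below to evict expired entries from a SynchronizedDict with a plain del. A small sketch of the resulting behavior (constructing the dict with no arguments, as the cache does):

from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict

cache = SynchronizedDict()
cache["key"] = ("value", 123.0)   # writer lock
if "key" in cache:                # reader lock
    del cache["key"]              # writer lock, new in 0.28.1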
@@ -70,6 +70,26 @@ _lca_alias_map: ContextVar[dict[str, TypedColumn]] = ContextVar(
     default={},
 )

+_view_process_context = ContextVar("_view_process_context", default=[])
+
+
+@contextmanager
+def push_processed_view(name: str):
+    _view_process_context.set(_view_process_context.get() + [name])
+    yield
+    _view_process_context.set(_view_process_context.get()[:-1])
+
+
+def get_processed_views() -> list[str]:
+    return _view_process_context.get()
+
+
+def register_processed_view(name: str) -> None:
+    context = _view_process_context.get()
+    context.append(name)
+    _view_process_context.set(context)
+
+
 # Context variable to track current grouping columns for grouping_id() function
 _current_grouping_columns: ContextVar[list[str]] = ContextVar(
     "_current_grouping_columns",
@@ -387,6 +407,7 @@ def clear_context_data() -> None:
     _plan_id_map.set({})
     _alias_map.set({})

+    _view_process_context.set([])
     _next_sql_plan_id.set(_STARTING_SQL_PLAN_ID)
     _sql_plan_name_map.set({})
     _map_partitions_stack.set(0)
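utils/context.py adds a context-local stack of view names: push_processed_view() appends on entry and pops on exit, register_processed_view() mutates the current list, and clear_context_data() now resets the stack. The diff itself does not state the purpose; a plausible reading is tracking views while they are being expanded (possibly tied to the new temporary_view_cache module). A sketch of the nesting behavior, with hypothetical view names:

from snowflake.snowpark_connect.utils.context import (
    get_processed_views,
    push_processed_view,
)

with push_processed_view("v_outer"):
    with push_processed_view("v_inner"):
        # Both views are on the stack while their bodies are being processed.
        assert get_processed_views() == ["v_outer", "v_inner"]
    assert get_processed_views() == ["v_outer"]
assert get_processed_views() == []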
@@ -6,20 +6,24 @@ import hashlib
 import inspect
 import random
 import re
-import threading
 import time
 from typing import Any

 from snowflake import snowpark
 from snowflake.connector.cursor import ResultMetadataV2
 from snowflake.snowpark._internal.server_connection import ServerConnection
+from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import telemetry

 DESCRIBE_CACHE_TTL_SECONDS = 15
 USE_DESCRIBE_QUERY_CACHE = True

-DDL_DETECTION_PATTERN = re.compile(r"^\s*(CREATE|ALTER|DROP|RENAME)\b", re.IGNORECASE)
+DDL_DETECTION_PATTERN = re.compile(r"\s*(CREATE|ALTER|DROP)\b", re.IGNORECASE)
+PLAIN_CREATE_PATTERN = re.compile(
+    r"\s*CREATE\s+((LOCAL|GLOBAL)\s+)?(TRANSIENT\s+)?TABLE\b", re.IGNORECASE
+)
+
 # Pattern for simple constant queries like: SELECT 3 :: INT AS "3-80000030-0" FROM ( SELECT $1 AS "__DUMMY" FROM VALUES (NULL :: STRING))
 # Using exact spacing pattern from generated SQL for deterministic matching
 # Column ID format: {original_name}-{8_digit_hex_plan_id}-{column_index}
@@ -32,8 +36,7 @@ SIMPLE_CONSTANT_PATTERN = re.compile(

 class DescribeQueryCache:
     def __init__(self) -> None:
-        self._cache = {}
-        self._lock = threading.Lock()
+        self._cache = SynchronizedDict()

     @staticmethod
     def _hash_query(sql_query: str) -> str:
@@ -48,49 +51,49 @@ class DescribeQueryCache:
         return sql_query

     def get(self, sql_query: str) -> list[ResultMetadataV2] | None:
+        telemetry.report_describe_query_cache_lookup()
+
         cache_key = self._get_cache_key(sql_query)
         key = self._hash_query(cache_key)
         current_time = time.monotonic()

-        # TODO: maybe too much locking, we could use read-write lock also. Or a thread safe dictionary.
-        with self._lock:
-            if key in self._cache:
-                result, timestamp = self._cache[key]
-                if current_time < timestamp + DESCRIBE_CACHE_TTL_SECONDS:
-                    logger.debug(
-                        f"Returning query result from cache for query: {sql_query[:20]}"
-                    )
-
-                    # If this is a constant query, we need to transform the result metadata
-                    # to match the actual query's column name
-                    if (
-                        cache_key != sql_query
-                    ):  # Only transform if we normalized the key
-                        match = SIMPLE_CONSTANT_PATTERN.match(sql_query)
-                        if match:
-                            number, column_id = match.groups()
-                            expected_column_name = column_id
-
-                            # Transform the cached result to match this query's column name
-                            # There should only be one column in these constant queries
-                            metadata = result[0]
-                            new_metadata = ResultMetadataV2(
-                                name=expected_column_name,
-                                type_code=metadata.type_code,
-                                display_size=metadata.display_size,
-                                internal_size=metadata.internal_size,
-                                precision=metadata.precision,
-                                scale=metadata.scale,
-                                is_nullable=metadata.is_nullable,
-                            )
-                            return [new_metadata]
-
-                    return result
-                else:
-                    logger.debug(
-                        f"Had a cached entry, but it expired for query: {sql_query[:20]}"
-                    )
-                    del self._cache[key]
+        if key in self._cache:
+            result, timestamp = self._cache[key]
+            if current_time < timestamp + DESCRIBE_CACHE_TTL_SECONDS:
+                logger.debug(
+                    f"Returning query result from cache for query: {sql_query[:20]}"
+                )
+                self._cache[key] = (result, current_time)
+
+                # If this is a constant query, we need to transform the result metadata
+                # to match the actual query's column name
+                if cache_key != sql_query:  # Only transform if we normalized the key
+                    match = SIMPLE_CONSTANT_PATTERN.match(sql_query)
+                    if match:
+                        number, column_id = match.groups()
+                        expected_column_name = column_id
+
+                        # Transform the cached result to match this query's column name
+                        # There should only be one column in these constant queries
+                        metadata = result[0]
+                        new_metadata = ResultMetadataV2(
+                            name=expected_column_name,
+                            type_code=metadata.type_code,
+                            display_size=metadata.display_size,
+                            internal_size=metadata.internal_size,
+                            precision=metadata.precision,
+                            scale=metadata.scale,
+                            is_nullable=metadata.is_nullable,
+                        )
+
+                        telemetry.report_describe_query_cache_hit()
+                        return [new_metadata]
+
+                telemetry.report_describe_query_cache_hit()
+                return result
+            else:
+                telemetry.report_describe_query_cache_expired()
+                del self._cache[key]
         return None

     def put(self, sql_query: str, result: list[ResultMetadataV2] | None) -> None:
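Two behavior changes are easy to miss in the rewritten get(): cache hits now write the entry back with the current timestamp, so the 15-second TTL behaves as a sliding window rather than a fixed expiry, and lookups, hits, and expirations are counted through the new telemetry hooks. An annotated timeline of the sliding expiry, assuming the default DESCRIBE_CACHE_TTL_SECONDS = 15:

# t=0   put(q)          -> entry stored with timestamp 0
# t=10  get(q) -> hit   -> timestamp refreshed to 10
# t=20  get(q) -> hit   -> only 10s since the last hit, refreshed to 20
# t=40  get(q) -> None  -> 20s since the last hit: entry deleted, telemetry
#                          reports an expiration, caller re-describes the query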
@@ -102,12 +105,18 @@ class DescribeQueryCache:

         logger.debug(f"Putting query into cache: {sql_query[:50]}...")

-        with self._lock:
-            self._cache[key] = (result, time.monotonic())
+        self._cache[key] = (result, time.monotonic())

     def clear(self) -> None:
-        with self._lock:
-            self._cache.clear()
+        self._cache.clear()
+
+    def update_cache_for_query(self, query: str) -> None:
+        # Clear cache for DDL operations that modify existing objects (exclude CREATE TABLE)
+        if DDL_DETECTION_PATTERN.search(query) and not PLAIN_CREATE_PATTERN.search(
+            query
+        ):
+            self.clear()
+            telemetry.report_describe_query_cache_clear(query[:100])


 def instrument_session_for_describe_cache(session: snowpark.Session):
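update_cache_for_query() narrows the old blanket invalidation: the DDL pattern lost RENAME and its ^ anchor, and a plain CREATE [LOCAL|GLOBAL] [TRANSIENT] TABLE no longer clears the cache, since creating a new table cannot invalidate metadata cached for existing objects. The decision is purely regex-driven, so it can be checked directly against the two patterns:

from snowflake.snowpark_connect.utils.describe_query_cache import (
    DDL_DETECTION_PATTERN,
    PLAIN_CREATE_PATTERN,
)

for stmt in (
    "ALTER TABLE t ADD COLUMN c INT",         # clears
    "DROP VIEW v",                            # clears
    "CREATE TRANSIENT TABLE t (c INT)",       # kept: plain CREATE TABLE is excluded
    "CREATE OR REPLACE TABLE t AS SELECT 1",  # clears: OR REPLACE is not "plain"
    "SELECT * FROM t",                        # kept: no DDL keyword
):
    clears = bool(
        DDL_DETECTION_PATTERN.search(stmt) and not PLAIN_CREATE_PATTERN.search(stmt)
    )
    print(f"{clears!s:<5} {stmt}")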
@@ -126,10 +135,7 @@ def instrument_session_for_describe_cache(session: snowpark.Session):
     if isinstance(cache_instance, DescribeQueryCache):
         cache = cache_instance

-        # TODO: This is very broad right now. We should be able to reduce the scope of clearing.
-        if DDL_DETECTION_PATTERN.search(query):
-            logger.debug(f"DDL detected, clearing describe query cache: '{query}'")
-            cache.clear()
+        cache.update_cache_for_query(query)

     def wrap_execute(wrapped_fn):
         def fn(query: str, **kwargs):
@@ -2,6 +2,7 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 import re
+from typing import Any, TypeVar

 from pyspark.errors import AnalysisException

@@ -27,12 +28,18 @@ def unquote_spark_identifier_if_quoted(spark_name: str) -> str:
     raise AnalysisException(f"Invalid name: {spark_name}")


-def spark_to_sf_single_id_with_unquoting(name: str) -> str:
+def spark_to_sf_single_id_with_unquoting(
+    name: str, use_auto_upper_case: bool = False
+) -> str:
     """
     Transforms a spark name to a valid snowflake name by quoting and potentially uppercasing it.
     Unquotes the spark name if necessary. Will raise an AnalysisException if given name is not valid.
     """
-    return spark_to_sf_single_id(unquote_spark_identifier_if_quoted(name))
+    return (
+        spark_to_sf_single_id(unquote_spark_identifier_if_quoted(name))
+        if use_auto_upper_case
+        else quote_name_without_upper_casing(unquote_spark_identifier_if_quoted(name))
+    )


 def spark_to_sf_single_id(name: str, is_column: bool = False) -> str:
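The default path of spark_to_sf_single_id_with_unquoting therefore changes: unless use_auto_upper_case=True is passed, the identifier is quoted without upper-casing instead of going through spark_to_sf_single_id. Illustrative calls; the exact quoting output is inferred from the helper names rather than verified against the package:

# Assumption: quote_name_without_upper_casing preserves case, e.g. 'myTable' -> '"myTable"',
# while the old path could upper-case unquoted names.
spark_to_sf_single_id_with_unquoting("myTable")                            # new default: case preserved
spark_to_sf_single_id_with_unquoting("myTable", use_auto_upper_case=True)  # previous behavior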
@@ -117,3 +124,122 @@ def split_fully_qualified_spark_name(qualified_name: str | None) -> list[str]:
         parts.append("".join(token_chars))

     return parts
+
+
+# See https://docs.snowflake.com/en/sql-reference/identifiers-syntax for identifier syntax
+UNQUOTED_IDENTIFIER_REGEX = r"([a-zA-Z_])([a-zA-Z0-9_$]{0,254})"
+QUOTED_IDENTIFIER_REGEX = r'"((""|[^"]){0,255})"'
+VALID_IDENTIFIER_REGEX = f"(?:{UNQUOTED_IDENTIFIER_REGEX}|{QUOTED_IDENTIFIER_REGEX})"
+
+
+Self = TypeVar("Self", bound="FQN")
+
+
+class FQN:
+    """Represents an object identifier, supporting fully qualified names.
+
+    The instance supports builder pattern that allows updating the identifier with database and
+    schema from different sources.
+
+    Examples
+    ________
+    >>> fqn = FQN.from_string("my_schema.object").using_connection(conn)
+
+    >>> fqn = FQN.from_string("my_name").set_database("db").set_schema("foo")
+    """
+
+    def __init__(
+        self,
+        database: str | None,
+        schema: str | None,
+        name: str,
+        signature: str | None = None,
+    ) -> None:
+        self._database = database
+        self._schema = schema
+        self._name = name
+        self.signature = signature
+
+    @property
+    def database(self) -> str | None:
+        return self._database
+
+    @property
+    def schema(self) -> str | None:
+        return self._schema
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def prefix(self) -> str:
+        if self.database:
+            return f"{self.database}.{self.schema if self.schema else 'PUBLIC'}"
+        if self.schema:
+            return f"{self.schema}"
+        return ""
+
+    @property
+    def identifier(self) -> str:
+        if self.prefix:
+            return f"{self.prefix}.{self.name}"
+        return self.name
+
+    def __str__(self) -> str:
+        return self.identifier
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, FQN):
+            raise AnalysisException(f"{other} is not a valid FQN")
+        return self.identifier == other.identifier
+
+    @classmethod
+    def from_string(cls, identifier: str) -> Self:
+        """Take in an object name in the form [[database.]schema.]name and return a new :class:`FQN` instance.
+
+        Raises:
+            InvalidIdentifierError: If the object identifier does not meet identifier requirements.
+        """
+        qualifier_pattern = (
+            rf"(?:(?P<first_qualifier>{VALID_IDENTIFIER_REGEX})\.)?"
+            rf"(?:(?P<second_qualifier>{VALID_IDENTIFIER_REGEX})\.)?"
+            rf"(?P<name>{VALID_IDENTIFIER_REGEX})(?P<signature>\(.*\))?"
+        )
+        result = re.fullmatch(qualifier_pattern, identifier)
+
+        if result is None:
+            raise AnalysisException(f"{identifier} is not a valid identifier")
+
+        unqualified_name = result.group("name")
+        if result.group("second_qualifier") is not None:
+            database = result.group("first_qualifier")
+            schema = result.group("second_qualifier")
+        else:
+            database = None
+            schema = result.group("first_qualifier")
+
+        signature = None
+        if result.group("signature"):
+            signature = result.group("signature")
+        return cls(
+            name=unqualified_name, schema=schema, database=database, signature=signature
+        )
+
+    def set_database(self, database: str | None) -> Self:
+        if database:
+            self._database = database
+        return self
+
+    def set_schema(self, schema: str | None) -> Self:
+        if schema:
+            self._schema = schema
+        return self
+
+    def set_name(self, name: str) -> Self:
+        self._name = name
+        return self
+
+    def to_dict(self) -> dict[str, str | None]:
+        """Return the dictionary representation of the instance."""
+        return {"name": self.name, "schema": self.schema, "database": self.database}
@@ -1,10 +1,11 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-
+import contextlib
 import functools

 from snowflake.snowpark import Session
+from snowflake.snowpark_connect.utils.identifiers import FQN


 @functools.cache
@@ -33,3 +34,22 @@ def file_format(
     ).collect()

     return file_format_name
+
+
+def get_table_type(
+    snowpark_table_name: str,
+    snowpark_session: Session,
+) -> str:
+    fqn = FQN.from_string(snowpark_table_name)
+    with contextlib.suppress(Exception):
+        if fqn.database is not None:
+            return snowpark_session.catalog.getTable(
+                table_name=fqn.name, schema=fqn.schema, database=fqn.database
+            ).table_type
+        elif fqn.schema is not None:
+            return snowpark_session.catalog.getTable(
+                table_name=fqn.name, schema=fqn.schema
+            ).table_type
+        else:
+            return snowpark_session.catalog.getTable(table_name=fqn.name).table_type
+    return "TABLE"
@@ -171,12 +171,19 @@ class ScalaUDFDef:
         is_map_return = udf_func_return_type.startswith("Map")
         wrapper_return_type = "String" if is_map_return else udf_func_return_type

+        # For handling Seq type correctly, ensure that the wrapper function always uses Array as its input and
+        # return types (when required) and the wrapped function uses Seq.
+        udf_func_return_type = udf_func_return_type.replace("Array", "Seq")
+        is_seq_return = udf_func_return_type.startswith("Seq")
+
         # Need to call the map to JSON string converter when a map is returned by the user's function.
-        invoke_udf_func = (
-            f"write(func({invocation_args}))"
-            if is_map_return
-            else f"func({invocation_args})"
-        )
+        if is_map_return:
+            invoke_udf_func = f"write(func({invocation_args}))"
+        elif is_seq_return:
+            # TODO: SNOW-2339385 Handle Array[T] return types correctly. Currently, only Seq[T] is supported.
+            invoke_udf_func = f"func({invocation_args}).toArray"
+        else:
+            invoke_udf_func = f"func({invocation_args})"

         # The lines of code below are required only when a Map is returned by the UDF. This is needed to serialize the
         # map output to a JSON string.
184
191
  ""
185
192
  if not is_map_return
186
193
  else """
187
- import org.json4s._
188
- import org.json4s.native.Serialization._
189
- import org.json4s.native.Serialization
194
+ import shaded_json4s._
195
+ import shaded_json4s.native.Serialization._
196
+ import shaded_json4s.native.Serialization
190
197
  """
191
198
  )
192
199
  map_return_formatter = (
@@ -199,22 +206,12 @@ import org.json4s.native.Serialization
199
206
 
200
207
  return f"""import org.apache.spark.sql.connect.common.UdfPacket
201
208
  {map_return_imports}
202
- import java.io.{{ByteArrayInputStream, ObjectInputStream}}
203
- import java.nio.file.{{Files, Paths}}
209
+ import com.snowflake.sas.scala.Utils
204
210
 
205
211
  object __RecreatedSparkUdf {{
206
212
  {map_return_formatter}
207
- private lazy val func: ({udf_func_input_types}) => {udf_func_return_type} = {{
208
- val importDirectory = System.getProperty("com.snowflake.import_directory")
209
- val fPath = importDirectory + "{self.name}.bin"
210
- val bytes = Files.readAllBytes(Paths.get(fPath))
211
- val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
212
- try {{
213
- ois.readObject().asInstanceOf[UdfPacket].function.asInstanceOf[({udf_func_input_types}) => {udf_func_return_type}]
214
- }} finally {{
215
- ois.close()
216
- }}
217
- }}
213
+ private lazy val func: ({udf_func_input_types}) => {udf_func_return_type} =
214
+ Utils.deserializeFunc("{self.name}.bin").asInstanceOf[({udf_func_input_types}) => {udf_func_return_type}]
218
215
 
219
216
  def __wrapperFunc({wrapper_arg_and_input_types_str}): {wrapper_return_type} = {{
220
217
  {invoke_udf_func}
@@ -299,29 +296,15 @@ def build_scala_udf_imports(session, payload, udf_name, is_map_return) -> List[s
         # Remove the stage path since it is not properly formatted.
         user_jars.append(row[0][row[0].find("/") :])

-    # Jars used when the return type is a Map.
-    map_jars = (
-        []
-        if not is_map_return
-        else [
-            f"{stage_resource_path}/json4s-core_2.12-3.7.0-M11.jar",
-            f"{stage_resource_path}/json4s-native_2.12-3.7.0-M11.jar",
-            f"{stage_resource_path}/paranamer-2.8.3.jar",
-        ]
-    )
-
     # Format the user jars to be used in the IMPORTS clause of the stored procedure.
-    return (
-        [
-            closure_binary_file,
-            f"{stage_resource_path}/spark-connect-client-jvm_2.12-3.5.6.jar",
-            f"{stage_resource_path}/spark-common-utils_2.12-3.5.6.jar",
-            f"{stage_resource_path}/spark-sql_2.12-3.5.6.jar",
-            f"{stage_resource_path}/json4s-ast_2.12-3.7.0-M11.jar",
-        ]
-        + map_jars
-        + [f"{stage + jar}" for jar in user_jars]
-    )
+    return [
+        closure_binary_file,
+        f"{stage_resource_path}/spark-connect-client-jvm_2.12-3.5.6.jar",
+        f"{stage_resource_path}/spark-common-utils_2.12-3.5.6.jar",
+        f"{stage_resource_path}/spark-sql_2.12-3.5.6.jar",
+        f"{stage_resource_path}/json4s-ast_2.12-3.7.0-M11.jar",
+        f"{stage_resource_path}/sas-scala-udf_2.12-0.1.0.jar",
+    ] + [f"{stage + jar}" for jar in user_jars]
@@ -343,6 +326,14 @@ def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf
     Returns:
         A ScalaUdf object representing the created or cached Scala UDF.
     """
+    from snowflake.snowpark_connect.resources_initializer import (
+        wait_for_resource_initialization,
+    )
+
+    # Make sure that the resource initializer thread is completed before creating Scala UDFs since we depend on the jars
+    # uploaded by it.
+    wait_for_resource_initialization()
+
     from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session

     function_name = pciudf._function_name