snowpark-connect 0.27.0__py3-none-any.whl → 0.28.1__py3-none-any.whl
This diff shows the changes between package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- snowflake/snowpark_connect/column_name_handler.py +3 -93
- snowflake/snowpark_connect/config.py +99 -1
- snowflake/snowpark_connect/dataframe_container.py +0 -6
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +31 -68
- snowflake/snowpark_connect/expression/map_expression.py +22 -7
- snowflake/snowpark_connect/expression/map_sql_expression.py +22 -18
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +4 -26
- snowflake/snowpark_connect/expression/map_unresolved_function.py +12 -3
- snowflake/snowpark_connect/expression/map_unresolved_star.py +2 -3
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +207 -20
- snowflake/snowpark_connect/relation/map_extension.py +14 -10
- snowflake/snowpark_connect/relation/map_join.py +62 -258
- snowflake/snowpark_connect/relation/map_relation.py +5 -1
- snowflake/snowpark_connect/relation/map_sql.py +464 -68
- snowflake/snowpark_connect/relation/read/map_read_table.py +58 -0
- snowflake/snowpark_connect/relation/write/map_write.py +228 -120
- snowflake/snowpark_connect/resources_initializer.py +20 -5
- snowflake/snowpark_connect/server.py +16 -17
- snowflake/snowpark_connect/utils/concurrent.py +4 -0
- snowflake/snowpark_connect/utils/context.py +21 -0
- snowflake/snowpark_connect/utils/describe_query_cache.py +57 -51
- snowflake/snowpark_connect/utils/identifiers.py +128 -2
- snowflake/snowpark_connect/utils/io_utils.py +21 -1
- snowflake/snowpark_connect/utils/scala_udf_utils.py +34 -43
- snowflake/snowpark_connect/utils/session.py +16 -26
- snowflake/snowpark_connect/utils/telemetry.py +53 -0
- snowflake/snowpark_connect/utils/temporary_view_cache.py +61 -0
- snowflake/snowpark_connect/utils/udf_utils.py +9 -8
- snowflake/snowpark_connect/utils/udtf_utils.py +3 -2
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/METADATA +2 -2
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/RECORD +41 -41
- snowflake/snowpark_connect/hidden_column.py +0 -39
- {snowpark_connect-0.27.0.data → snowpark_connect-0.28.1.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-0.28.1.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-0.28.1.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/resources_initializer.py:

@@ -12,6 +12,7 @@ _resources_initialized = threading.Event()
 _initializer_lock = threading.Lock()
 SPARK_VERSION = "3.5.6"
 RESOURCE_PATH = "/snowflake/snowpark_connect/resources"
+_upload_jars = True  # Flag to control whether to upload jars. Required for Scala UDFs.


 def initialize_resources() -> None:
@@ -57,10 +58,8 @@ def initialize_resources() -> None:
         f"spark-sql_2.12-{SPARK_VERSION}.jar",
         f"spark-connect-client-jvm_2.12-{SPARK_VERSION}.jar",
         f"spark-common-utils_2.12-{SPARK_VERSION}.jar",
+        "sas-scala-udf_2.12-0.1.0.jar",
         "json4s-ast_2.12-3.7.0-M11.jar",
-        "json4s-native_2.12-3.7.0-M11.jar",
-        "json4s-core_2.12-3.7.0-M11.jar",
-        "paranamer-2.8.3.jar",
     ]

     for jar in jar_files:
@@ -80,9 +79,11 @@ def initialize_resources() -> None:
         ("Initialize Session Stage", initialize_session_stage),  # Takes about 0.3s
         ("Initialize Session Catalog", initialize_catalog),  # Takes about 1.2s
         ("Snowflake Connection Warm Up", warm_up_sf_connection),  # Takes about 1s
-        ("Upload Scala UDF Jars", upload_scala_udf_jars),
     ]

+    if _upload_jars:
+        resources.append(("Upload Scala UDF Jars", upload_scala_udf_jars))
+
     for name, resource_func in resources:
         resource_start = time.time()
         try:
@@ -113,4 +114,18 @@ def initialize_resources_async() -> threading.Thread:

 def wait_for_resource_initialization() -> None:
     with _initializer_lock:
-        _resource_initializer.join()
+        _resource_initializer.join(timeout=300)  # wait at most 300 seconds
+        if _resource_initializer.is_alive():
+            logger.error(
+                "Resource initialization failed - initializer thread has been running for over 300 seconds."
+            )
+            raise RuntimeError(
+                "Resource initialization failed - initializer thread has been running for over 300 seconds."
+            )
+
+
+def set_upload_jars(upload: bool) -> None:
+    """Set whether to upload jars required for Scala UDFs. This should be set to False if Scala UDFs
+    are not used, to avoid the overhead of uploading jars."""
+    global _upload_jars
+    _upload_jars = upload
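For context, a minimal usage sketch of the new knob and the bounded wait. The names come from this diff's hunks; the call order is an assumption, not taken from the package documentation.

```python
# Hedged sketch: disable the Scala UDF jar upload when Scala UDFs are not used.
from snowflake.snowpark_connect.resources_initializer import (
    initialize_resources_async,
    set_upload_jars,
    wait_for_resource_initialization,
)

set_upload_jars(False)                 # skips "Upload Scala UDF Jars" in the resource list above
initializer = initialize_resources_async()
wait_for_resource_initialization()     # now raises RuntimeError after the 300-second timeout
```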
snowflake/snowpark_connect/server.py:

@@ -725,30 +725,33 @@ def _serve(
         # No need to start grpc server in TCM
         return

+    grpc_max_msg_size = get_int_from_env(
+        "SNOWFLAKE_GRPC_MAX_MESSAGE_SIZE",
+        _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE,
+    )
+    grpc_max_metadata_size = get_int_from_env(
+        "SNOWFLAKE_GRPC_MAX_METADATA_SIZE",
+        _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE,
+    )
     server_options = [
         (
             "grpc.max_receive_message_length",
-
-                "SNOWFLAKE_GRPC_MAX_MESSAGE_SIZE",
-                _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE,
-            ),
+            grpc_max_msg_size,
         ),
         (
             "grpc.max_metadata_size",
-
-                "SNOWFLAKE_GRPC_MAX_METADATA_SIZE",
-                _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE,
-            ),
+            grpc_max_metadata_size,
         ),
         (
             "grpc.absolute_max_metadata_size",
-
-                "SNOWFLAKE_GRPC_MAX_METADATA_SIZE",
-                _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE,
-            )
-            * 2,
+            grpc_max_metadata_size * 2,
         ),
     ]
+
+    from pyspark.sql.connect.client import ChannelBuilder
+
+    ChannelBuilder.MAX_MESSAGE_LENGTH = grpc_max_msg_size
+
     server = grpc.server(
         futures.ThreadPoolExecutor(max_workers=10), options=server_options
     )
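The limits are now read once through get_int_from_env and also pushed into the PySpark client's ChannelBuilder. A hedged sketch of how a deployment might override them before starting the server; the byte values below are illustrative, not package defaults.

```python
# Hypothetical override of the gRPC size limits via the environment variables named in the hunk.
import os

os.environ["SNOWFLAKE_GRPC_MAX_MESSAGE_SIZE"] = str(256 * 1024 * 1024)  # 256 MiB, illustrative
os.environ["SNOWFLAKE_GRPC_MAX_METADATA_SIZE"] = str(1024 * 1024)       # 1 MiB, illustrative
# ...then start the server/session as usual; _serve() picks these up via get_int_from_env.
```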
@@ -1053,10 +1056,6 @@ def start_session(
     global _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE
     _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE = max_grpc_message_size

-    from pyspark.sql.connect.client import ChannelBuilder
-
-    ChannelBuilder.MAX_MESSAGE_LENGTH = max_grpc_message_size
-
     if os.environ.get("SPARK_ENV_LOADED"):
         raise RuntimeError(
             "Snowpark Connect cannot be run inside of a Spark environment"
snowflake/snowpark_connect/utils/concurrent.py:

@@ -52,6 +52,10 @@ class SynchronizedDict(Mapping[K, V]):
         with self._lock.writer():
             self._dict[key] = value

+    def __delitem__(self, key: K) -> None:
+        with self._lock.writer():
+            del self._dict[key]
+
     def __contains__(self, key: K) -> bool:
         with self._lock.reader():
             return key in self._dict
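A short usage sketch of the new deletion support; the no-argument constructor matches how DescribeQueryCache instantiates the class later in this diff.

```python
# SynchronizedDict now supports `del`, guarded by the same writer lock as item assignment.
from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict

cache = SynchronizedDict()
cache["plan-1"] = "metadata"
if "plan-1" in cache:        # reader lock
    del cache["plan-1"]      # writer lock, via the new __delitem__
```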
snowflake/snowpark_connect/utils/context.py:

@@ -70,6 +70,26 @@ _lca_alias_map: ContextVar[dict[str, TypedColumn]] = ContextVar(
     default={},
 )

+_view_process_context = ContextVar("_view_process_context", default=[])
+
+
+@contextmanager
+def push_processed_view(name: str):
+    _view_process_context.set(_view_process_context.get() + [name])
+    yield
+    _view_process_context.set(_view_process_context.get()[:-1])
+
+
+def get_processed_views() -> list[str]:
+    return _view_process_context.get()
+
+
+def register_processed_view(name: str) -> None:
+    context = _view_process_context.get()
+    context.append(name)
+    _view_process_context.set(context)
+
+
 # Context variable to track current grouping columns for grouping_id() function
 _current_grouping_columns: ContextVar[list[str]] = ContextVar(
     "_current_grouping_columns",
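A hedged sketch of how the new view-processing context could be used while expanding views; the cycle check is an illustration only, not something this diff shows.

```python
# Illustrative only: track which views are currently being expanded, per context.
from snowflake.snowpark_connect.utils.context import (
    get_processed_views,
    push_processed_view,
)

def expand_view(view_name: str) -> None:
    if view_name in get_processed_views():
        raise RecursionError(f"View {view_name} is (transitively) self-referencing")
    with push_processed_view(view_name):
        ...  # resolve the view body; nested expand_view() calls see the updated stack
```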
@@ -387,6 +407,7 @@ def clear_context_data() -> None:
     _plan_id_map.set({})
     _alias_map.set({})

+    _view_process_context.set([])
     _next_sql_plan_id.set(_STARTING_SQL_PLAN_ID)
     _sql_plan_name_map.set({})
     _map_partitions_stack.set(0)
snowflake/snowpark_connect/utils/describe_query_cache.py:

@@ -6,20 +6,24 @@ import hashlib
 import inspect
 import random
 import re
-import threading
 import time
 from typing import Any

 from snowflake import snowpark
 from snowflake.connector.cursor import ResultMetadataV2
 from snowflake.snowpark._internal.server_connection import ServerConnection
+from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import telemetry

 DESCRIBE_CACHE_TTL_SECONDS = 15
 USE_DESCRIBE_QUERY_CACHE = True

-DDL_DETECTION_PATTERN = re.compile(r"
+DDL_DETECTION_PATTERN = re.compile(r"\s*(CREATE|ALTER|DROP)\b", re.IGNORECASE)
+PLAIN_CREATE_PATTERN = re.compile(
+    r"\s*CREATE\s+((LOCAL|GLOBAL)\s+)?(TRANSIENT\s+)?TABLE\b", re.IGNORECASE
+)
+
 # Pattern for simple constant queries like: SELECT 3 :: INT AS "3-80000030-0" FROM ( SELECT $1 AS "__DUMMY" FROM VALUES (NULL :: STRING))
 # Using exact spacing pattern from generated SQL for deterministic matching
 # Column ID format: {original_name}-{8_digit_hex_plan_id}-{column_index}
@@ -32,8 +36,7 @@ SIMPLE_CONSTANT_PATTERN = re.compile(

 class DescribeQueryCache:
     def __init__(self) -> None:
-        self._cache =
-        self._lock = threading.Lock()
+        self._cache = SynchronizedDict()

     @staticmethod
     def _hash_query(sql_query: str) -> str:
@@ -48,49 +51,49 @@ class DescribeQueryCache:
         return sql_query

     def get(self, sql_query: str) -> list[ResultMetadataV2] | None:
+        telemetry.report_describe_query_cache_lookup()
+
         cache_key = self._get_cache_key(sql_query)
         key = self._hash_query(cache_key)
         current_time = time.monotonic()

-        # [old lookup/expiry implementation; its lines were not preserved in this rendering]
-            del self._cache[key]
+        if key in self._cache:
+            result, timestamp = self._cache[key]
+            if current_time < timestamp + DESCRIBE_CACHE_TTL_SECONDS:
+                logger.debug(
+                    f"Returning query result from cache for query: {sql_query[:20]}"
+                )
+                self._cache[key] = (result, current_time)
+
+                # If this is a constant query, we need to transform the result metadata
+                # to match the actual query's column name
+                if cache_key != sql_query:  # Only transform if we normalized the key
+                    match = SIMPLE_CONSTANT_PATTERN.match(sql_query)
+                    if match:
+                        number, column_id = match.groups()
+                        expected_column_name = column_id
+
+                        # Transform the cached result to match this query's column name
+                        # There should only be one column in these constant queries
+                        metadata = result[0]
+                        new_metadata = ResultMetadataV2(
+                            name=expected_column_name,
+                            type_code=metadata.type_code,
+                            display_size=metadata.display_size,
+                            internal_size=metadata.internal_size,
+                            precision=metadata.precision,
+                            scale=metadata.scale,
+                            is_nullable=metadata.is_nullable,
+                        )
+
+                        telemetry.report_describe_query_cache_hit()
+                        return [new_metadata]
+
+                telemetry.report_describe_query_cache_hit()
+                return result
+            else:
+                telemetry.report_describe_query_cache_expired()
+                del self._cache[key]
         return None

     def put(self, sql_query: str, result: list[ResultMetadataV2] | None) -> None:
@@ -102,12 +105,18 @@ class DescribeQueryCache:

         logger.debug(f"Putting query into cache: {sql_query[:50]}...")

-
-            self._cache[key] = (result, time.monotonic())
+        self._cache[key] = (result, time.monotonic())

     def clear(self) -> None:
-
-
+        self._cache.clear()
+
+    def update_cache_for_query(self, query: str) -> None:
+        # Clear cache for DDL operations that modify existing objects (exclude CREATE TABLE)
+        if DDL_DETECTION_PATTERN.search(query) and not PLAIN_CREATE_PATTERN.search(
+            query
+        ):
+            self.clear()
+            telemetry.report_describe_query_cache_clear(query[:100])


 def instrument_session_for_describe_cache(session: snowpark.Session):
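The two patterns from the import hunk above drive this method: any CREATE/ALTER/DROP statement clears the cache unless it is a plain CREATE TABLE. A standalone sketch that mirrors that check; would_clear_cache is a hypothetical helper, not part of the package.

```python
import re

DDL_DETECTION_PATTERN = re.compile(r"\s*(CREATE|ALTER|DROP)\b", re.IGNORECASE)
PLAIN_CREATE_PATTERN = re.compile(
    r"\s*CREATE\s+((LOCAL|GLOBAL)\s+)?(TRANSIENT\s+)?TABLE\b", re.IGNORECASE
)

def would_clear_cache(query: str) -> bool:
    # Mirrors DescribeQueryCache.update_cache_for_query as shown in the hunk above.
    return bool(
        DDL_DETECTION_PATTERN.search(query)
    ) and not PLAIN_CREATE_PATTERN.search(query)

assert would_clear_cache("ALTER TABLE t ADD COLUMN c INT")   # modifies an existing object
assert not would_clear_cache("CREATE TABLE t (c INT)")       # plain CREATE TABLE is excluded
assert would_clear_cache("DROP VIEW v")                      # DROP clears the cache
assert not would_clear_cache("SELECT * FROM t")              # plain queries never clear it
```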
@@ -126,10 +135,7 @@ def instrument_session_for_describe_cache(session: snowpark.Session):
     if isinstance(cache_instance, DescribeQueryCache):
         cache = cache_instance

-
-    if DDL_DETECTION_PATTERN.search(query):
-        logger.debug(f"DDL detected, clearing describe query cache: '{query}'")
-        cache.clear()
+    cache.update_cache_for_query(query)

     def wrap_execute(wrapped_fn):
         def fn(query: str, **kwargs):
snowflake/snowpark_connect/utils/identifiers.py:

@@ -2,6 +2,7 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 import re
+from typing import Any, TypeVar

 from pyspark.errors import AnalysisException

@@ -27,12 +28,18 @@ def unquote_spark_identifier_if_quoted(spark_name: str) -> str:
     raise AnalysisException(f"Invalid name: {spark_name}")


-def spark_to_sf_single_id_with_unquoting(
+def spark_to_sf_single_id_with_unquoting(
+    name: str, use_auto_upper_case: bool = False
+) -> str:
     """
     Transforms a spark name to a valid snowflake name by quoting and potentially uppercasing it.
     Unquotes the spark name if necessary. Will raise an AnalysisException if given name is not valid.
     """
-    return
+    return (
+        spark_to_sf_single_id(unquote_spark_identifier_if_quoted(name))
+        if use_auto_upper_case
+        else quote_name_without_upper_casing(unquote_spark_identifier_if_quoted(name))
+    )


 def spark_to_sf_single_id(name: str, is_column: bool = False) -> str:
@@ -117,3 +124,122 @@ def split_fully_qualified_spark_name(qualified_name: str | None) -> list[str]:
         parts.append("".join(token_chars))

     return parts
+
+
+# See https://docs.snowflake.com/en/sql-reference/identifiers-syntax for identifier syntax
+UNQUOTED_IDENTIFIER_REGEX = r"([a-zA-Z_])([a-zA-Z0-9_$]{0,254})"
+QUOTED_IDENTIFIER_REGEX = r'"((""|[^"]){0,255})"'
+VALID_IDENTIFIER_REGEX = f"(?:{UNQUOTED_IDENTIFIER_REGEX}|{QUOTED_IDENTIFIER_REGEX})"
+
+
+Self = TypeVar("Self", bound="FQN")
+
+
+class FQN:
+    """Represents an object identifier, supporting fully qualified names.
+
+    The instance supports builder pattern that allows updating the identifier with database and
+    schema from different sources.
+
+    Examples
+    ________
+    >>> fqn = FQN.from_string("my_schema.object").using_connection(conn)
+
+    >>> fqn = FQN.from_string("my_name").set_database("db").set_schema("foo")
+    """
+
+    def __init__(
+        self,
+        database: str | None,
+        schema: str | None,
+        name: str,
+        signature: str | None = None,
+    ) -> None:
+        self._database = database
+        self._schema = schema
+        self._name = name
+        self.signature = signature
+
+    @property
+    def database(self) -> str | None:
+        return self._database
+
+    @property
+    def schema(self) -> str | None:
+        return self._schema
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def prefix(self) -> str:
+        if self.database:
+            return f"{self.database}.{self.schema if self.schema else 'PUBLIC'}"
+        if self.schema:
+            return f"{self.schema}"
+        return ""
+
+    @property
+    def identifier(self) -> str:
+        if self.prefix:
+            return f"{self.prefix}.{self.name}"
+        return self.name
+
+    def __str__(self) -> str:
+        return self.identifier
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, FQN):
+            raise AnalysisException(f"{other} is not a valid FQN")
+        return self.identifier == other.identifier
+
+    @classmethod
+    def from_string(cls, identifier: str) -> Self:
+        """Take in an object name in the form [[database.]schema.]name and return a new :class:`FQN` instance.
+
+        Raises:
+            InvalidIdentifierError: If the object identifier does not meet identifier requirements.
+        """
+        qualifier_pattern = (
+            rf"(?:(?P<first_qualifier>{VALID_IDENTIFIER_REGEX})\.)?"
+            rf"(?:(?P<second_qualifier>{VALID_IDENTIFIER_REGEX})\.)?"
+            rf"(?P<name>{VALID_IDENTIFIER_REGEX})(?P<signature>\(.*\))?"
+        )
+        result = re.fullmatch(qualifier_pattern, identifier)
+
+        if result is None:
+            raise AnalysisException(f"{identifier} is not a valid identifier")
+
+        unqualified_name = result.group("name")
+        if result.group("second_qualifier") is not None:
+            database = result.group("first_qualifier")
+            schema = result.group("second_qualifier")
+        else:
+            database = None
+            schema = result.group("first_qualifier")
+
+        signature = None
+        if result.group("signature"):
+            signature = result.group("signature")
+        return cls(
+            name=unqualified_name, schema=schema, database=database, signature=signature
+        )
+
+    def set_database(self, database: str | None) -> Self:
+        if database:
+            self._database = database
+        return self
+
+    def set_schema(self, schema: str | None) -> Self:
+        if schema:
+            self._schema = schema
+        return self
+
+    def set_name(self, name: str) -> Self:
+        self._name = name
+        return self
+
+    def to_dict(self) -> dict[str, str | None]:
+        """Return the dictionary representation of the instance."""
+        return {"name": self.name, "schema": self.schema, "database": self.database}
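A hedged usage sketch of the new FQN helper; the commented values follow from the parsing code above and are not verified output.

```python
from snowflake.snowpark_connect.utils.identifiers import FQN

fqn = FQN.from_string('mydb."My Schema".orders')
print(fqn.database)    # mydb
print(fqn.schema)      # "My Schema"  (quoted identifiers keep their quotes)
print(fqn.identifier)  # mydb."My Schema".orders

# set_database / set_schema ignore falsy arguments and otherwise overwrite the current value.
partial = FQN.from_string("orders").set_database("mydb").set_schema("public")
print(partial.identifier)  # mydb.public.orders
```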
snowflake/snowpark_connect/utils/io_utils.py:

@@ -1,10 +1,11 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-
+import contextlib
 import functools

 from snowflake.snowpark import Session
+from snowflake.snowpark_connect.utils.identifiers import FQN


 @functools.cache
@@ -33,3 +34,22 @@ def file_format(
     ).collect()

     return file_format_name
+
+
+def get_table_type(
+    snowpark_table_name: str,
+    snowpark_session: Session,
+) -> str:
+    fqn = FQN.from_string(snowpark_table_name)
+    with contextlib.suppress(Exception):
+        if fqn.database is not None:
+            return snowpark_session.catalog.getTable(
+                table_name=fqn.name, schema=fqn.schema, database=fqn.database
+            ).table_type
+        elif fqn.schema is not None:
+            return snowpark_session.catalog.getTable(
+                table_name=fqn.name, schema=fqn.schema
+            ).table_type
+        else:
+            return snowpark_session.catalog.getTable(table_name=fqn.name).table_type
+    return "TABLE"
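A hedged sketch of the new helper's behaviour: the qualified name is parsed with FQN, the matching catalog.getTable overload is called, and any exception falls back to "TABLE". The session setup below is illustrative only.

```python
from snowflake.snowpark import Session
from snowflake.snowpark_connect.utils.io_utils import get_table_type

session = Session.builder.getOrCreate()  # assumes connection parameters are already configured

print(get_table_type("mydb.public.orders", session))  # table type reported by the catalog
print(get_table_type("no_such_table", session))       # "TABLE" - lookup errors are suppressed
```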
snowflake/snowpark_connect/utils/scala_udf_utils.py:

@@ -171,12 +171,19 @@ class ScalaUDFDef:
         is_map_return = udf_func_return_type.startswith("Map")
         wrapper_return_type = "String" if is_map_return else udf_func_return_type

+        # For handling Seq type correctly, ensure that the wrapper function always uses Array as its input and
+        # return types (when required) and the wrapped function uses Seq.
+        udf_func_return_type = udf_func_return_type.replace("Array", "Seq")
+        is_seq_return = udf_func_return_type.startswith("Seq")
+
         # Need to call the map to JSON string converter when a map is returned by the user's function.
-
-            f"write(func({invocation_args}))"
-
-
-
+        if is_map_return:
+            invoke_udf_func = f"write(func({invocation_args}))"
+        elif is_seq_return:
+            # TODO: SNOW-2339385 Handle Array[T] return types correctly. Currently, only Seq[T] is supported.
+            invoke_udf_func = f"func({invocation_args}).toArray"
+        else:
+            invoke_udf_func = f"func({invocation_args})"

         # The lines of code below are required only when a Map is returned by the UDF. This is needed to serialize the
         # map output to a JSON string.
@@ -184,9 +191,9 @@ class ScalaUDFDef:
             ""
             if not is_map_return
            else """
-import
-import
-import
+import shaded_json4s._
+import shaded_json4s.native.Serialization._
+import shaded_json4s.native.Serialization
 """
         )
         map_return_formatter = (
@@ -199,22 +206,12 @@ import org.json4s.native.Serialization

         return f"""import org.apache.spark.sql.connect.common.UdfPacket
 {map_return_imports}
-import
-import java.nio.file.{{Files, Paths}}
+import com.snowflake.sas.scala.Utils

 object __RecreatedSparkUdf {{
 {map_return_formatter}
-    private lazy val func: ({udf_func_input_types}) => {udf_func_return_type} =
-
-        val fPath = importDirectory + "{self.name}.bin"
-        val bytes = Files.readAllBytes(Paths.get(fPath))
-        val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
-        try {{
-            ois.readObject().asInstanceOf[UdfPacket].function.asInstanceOf[({udf_func_input_types}) => {udf_func_return_type}]
-        }} finally {{
-            ois.close()
-        }}
-    }}
+    private lazy val func: ({udf_func_input_types}) => {udf_func_return_type} =
+        Utils.deserializeFunc("{self.name}.bin").asInstanceOf[({udf_func_input_types}) => {udf_func_return_type}]

     def __wrapperFunc({wrapper_arg_and_input_types_str}): {wrapper_return_type} = {{
         {invoke_udf_func}
@@ -299,29 +296,15 @@ def build_scala_udf_imports(session, payload, udf_name, is_map_return) -> List[s
         # Remove the stage path since it is not properly formatted.
         user_jars.append(row[0][row[0].find("/") :])

-    # Jars used when the return type is a Map.
-    map_jars = (
-        []
-        if not is_map_return
-        else [
-            f"{stage_resource_path}/json4s-core_2.12-3.7.0-M11.jar",
-            f"{stage_resource_path}/json4s-native_2.12-3.7.0-M11.jar",
-            f"{stage_resource_path}/paranamer-2.8.3.jar",
-        ]
-    )
-
     # Format the user jars to be used in the IMPORTS clause of the stored procedure.
-    return
-
-
-
-
-
-
-
-        + map_jars
-        + [f"{stage + jar}" for jar in user_jars]
-    )
+    return [
+        closure_binary_file,
+        f"{stage_resource_path}/spark-connect-client-jvm_2.12-3.5.6.jar",
+        f"{stage_resource_path}/spark-common-utils_2.12-3.5.6.jar",
+        f"{stage_resource_path}/spark-sql_2.12-3.5.6.jar",
+        f"{stage_resource_path}/json4s-ast_2.12-3.7.0-M11.jar",
+        f"{stage_resource_path}/sas-scala-udf_2.12-0.1.0.jar",
+    ] + [f"{stage + jar}" for jar in user_jars]
@@ -343,6 +326,14 @@ def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf
     Returns:
         A ScalaUdf object representing the created or cached Scala UDF.
     """
+    from snowflake.snowpark_connect.resources_initializer import (
+        wait_for_resource_initialization,
+    )
+
+    # Make sure that the resource initializer thread is completed before creating Scala UDFs since we depend on the jars
+    # uploaded by it.
+    wait_for_resource_initialization()
+
     from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session

     function_name = pciudf._function_name