snowpark-connect 0.27.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +680 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +237 -23
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +123 -5
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +85 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2748 -746
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +196 -255
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +255 -45
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
- snowflake/snowpark_connect/relation/read/map_read_json.py +320 -85
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +142 -27
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +82 -5
- snowflake/snowpark_connect/relation/read/map_read_text.py +18 -3
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +36 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +110 -48
- snowflake/snowpark_connect/server.py +546 -456
- snowflake/snowpark_connect/server_common/__init__.py +500 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +187 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +125 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +303 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +248 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +101 -332
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +163 -22
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-submit +2 -2
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/METADATA +14 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/RECORD +129 -167
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/utils/udf_cache.py:

@@ -3,13 +3,9 @@
 #
 
 import functools
-import importlib.resources
-import tempfile
 import threading
 import typing
-import zipfile
 from collections.abc import Callable
-from pathlib import Path
 from types import ModuleType
 from typing import List, Optional, Tuple, Union
 
@@ -19,13 +15,15 @@ from snowflake.snowpark.functions import call_udf, udaf, udf, udtf
 from snowflake.snowpark.types import DataType, StructType
 from snowflake.snowpark_connect import tcm
 from snowflake.snowpark_connect.utils.telemetry import telemetry
+from snowflake.snowpark_connect.utils.upload_java_jar import (
+    JAVA_UDFS_JAR_NAME,
+    upload_java_udf_jar,
+)
 
 _lock = threading.RLock()
 
 _BUILTIN_UDF_PREFIX = "__SC_BUILTIN_"
 
-_JAVA_UDFS_JAR_NAME = "java_udfs-1.0-SNAPSHOT.jar"
-
 
 def init_builtin_udf_cache(session: Session) -> None:
     with _lock:
@@ -34,6 +32,7 @@ def init_builtin_udf_cache(session: Session) -> None:
         session._cached_udtfs = {}
         session._cached_java_udfs = {}
         session._cached_sql_udfs = {}
+        session._cached_sprocs = {}
 
 
 def _hash_types(types: list) -> str:
@@ -98,7 +97,11 @@ def cached_udaf(
     # Register the function outside the lock to avoid contention
     wrapped_func = udaf(
         udaf_type,
-        name=
+        name=[
+            Session.get_active_session().get_current_database(),
+            Session.get_active_session().get_current_schema(),
+            name,
+        ],
         return_type=return_type,
         input_types=input_types,
         imports=imports,
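
The hunk above (and its twins for cached_udf and cached_udtf below) replaces a bare function name with an explicit [database, schema, name] identifier. A bare name is resolved against whatever schema is current when the function is looked up, so a later USE SCHEMA would orphan the cached object; qualifying the name at registration pins it down. A minimal sketch of the difference, assuming an already-configured Snowpark session (connection parameters omitted; ADD_ONE and OTHER_SCHEMA are hypothetical names):

```python
from snowflake.snowpark import Session
from snowflake.snowpark.functions import call_udf, udf
from snowflake.snowpark.types import IntegerType

session = Session.builder.getOrCreate()  # assumes credentials are configured
db, schema = session.get_current_database(), session.get_current_schema()

# Snowpark accepts a list of identifiers as a fully qualified name.
add_one = udf(
    lambda x: x + 1,
    name=[db, schema, "ADD_ONE"],
    return_type=IntegerType(),
    input_types=[IntegerType()],
)

# After switching schemas, the qualified identifier still resolves;
# a bare "ADD_ONE" would now miss the registered function.
session.sql("USE SCHEMA OTHER_SCHEMA").collect()  # hypothetical schema
session.range(3).select(call_udf(f"{db}.{schema}.ADD_ONE", "id")).collect()
```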
@@ -114,7 +117,7 @@ def cached_udaf(
 
     if class_type is None:
         raise ValueError(
-            "Type must be provided for cached_udaf. UDAF contains multiple functions hence it has to be represented by a type. Functions are not supported."
+            "[snowpark_connect::internal_error] Type must be provided for cached_udaf. UDAF contains multiple functions hence it has to be represented by a type. Functions are not supported."
         )
     else:
         # return udaf
@@ -155,7 +158,11 @@ def cached_udf(
     # but this will not cause any issues.
     wrapped_func = udf(
         _null_safe_wrapper,
-        name=
+        name=[
+            Session.get_active_session().get_current_database(),
+            Session.get_active_session().get_current_schema(),
+            name,
+        ],
         return_type=return_type,
         input_types=input_types,
         imports=imports,
@@ -205,7 +212,11 @@ def cached_udtf(
     # Register the function outside the lock to avoid contention
     wrapped_func = udtf(
         func,
-        name=
+        name=[
+            Session.get_active_session().get_current_database(),
+            Session.get_active_session().get_current_schema(),
+            name,
+        ],
         output_schema=output_schema,
         input_types=input_types,
         imports=imports,
@@ -306,11 +317,20 @@ def register_cached_sql_udf(
     )
 
     with _lock:
-
+            function_identifier = ".".join(
+                [
+                    Session.get_active_session().get_current_database(),
+                    Session.get_active_session().get_current_schema(),
+                    function_name,
+                ]
+            )
+            cache[function_name] = function_identifier
+        else:
+            function_identifier = cache[function_name]
 
     return functools.partial(
         call_udf,
-
+        function_identifier,
     )
 
 
@@ -343,32 +363,7 @@ def register_cached_java_udf(
 
     if len(cache) == 0:
         # This is the first Java UDF being registered, so we need to upload the JAR with UDF definitions first
-
-        try:
-            jar_path = importlib.resources.files(
-                "snowflake.snowpark_connect.resources"
-            ).joinpath(_JAVA_UDFS_JAR_NAME)
-        except NotADirectoryError:
-            # importlib.resource doesn't work in Stage Package method
-            zip_path = Path(__file__).parent.parent.parent.parent
-            jar_path_in_zip = (
-                f"snowflake/snowpark_connect/resources/{_JAVA_UDFS_JAR_NAME}"
-            )
-            temp_dir = tempfile.gettempdir()
-
-            with zipfile.ZipFile(zip_path, "r") as zip_ref:
-                if jar_path_in_zip not in zip_ref.namelist():
-                    raise FileNotFoundError(f"{jar_path_in_zip} not found")
-                zip_ref.extract(jar_path_in_zip, temp_dir)
-
-            jar_path = f"{temp_dir}/{jar_path_in_zip}"
-
-        upload_result = session.file.put(str(jar_path), stage, overwrite=True)
-
-        if upload_result[0].status != "UPLOADED":
-            raise RuntimeError(
-                f"Failed to upload JAR with UDF definitions to stage: {upload_result[0].message}"
-            )
+        upload_java_udf_jar(session)
 
     udf_is_cached = function_name in cache
 
@@ -378,15 +373,96 @@ def register_cached_java_udf(
         function_name,
         input_types,
         return_type,
-        [f"{stage}/{
+        [f"{stage}/snowflake/snowpark_connect/resources/{JAVA_UDFS_JAR_NAME}"],
         java_handler,
         packages,
     )
 
     with _lock:
-
+            function_identifier = ".".join(
+                [
+                    Session.get_active_session().get_current_database(),
+                    Session.get_active_session().get_current_schema(),
+                    function_name,
+                ]
+            )
+            cache[function_name] = function_identifier
+        else:
+            function_identifier = cache[function_name]
 
     return functools.partial(
         call_udf,
-
+        function_identifier,
+    )
+
+
+def register_cached_sproc(
+    sproc_body: str,
+    handler_name: str,
+    input_arg_types: list[str],
+    return_type: str = "STRING",
+    runtime_version: str = "3.11",
+    packages: list[str] | None = None,
+) -> str:
+    """
+    Register a cached stored procedure that persists across schema/database changes.
+
+    Args:
+        sproc_body: The Python code for the stored procedure
+        handler_name: Name of the handler function in the sproc_body
+        input_arg_types: List of SQL types for input arguments (e.g. ['STRING', 'STRING'])
+        return_type: SQL return type (default: 'STRING')
+        runtime_version: Python runtime version (default: '3.11')
+        packages: List of Python packages to include
+
+    Returns:
+        Fully qualified stored procedure name for calling
+    """
+    if packages is None:
+        packages = ["snowflake-snowpark-python"]
+
+    # Create a unique hash based on the procedure content and signature
+    content_hash = _hash_types(
+        [sproc_body, handler_name, return_type, runtime_version]
+        + input_arg_types
+        + packages
     )
+
+    # Generate unique procedure name with hash
+    sproc_name = f"{_BUILTIN_UDF_PREFIX}SPROC_{content_hash}"
+
+    with _lock:
+        session = Session.get_active_session()
+        cache = session._cached_sprocs
+
+        # Create fully qualified name with current database and schema
+        fully_qualified_name = ".".join(
+            [session.get_current_database(), session.get_current_schema(), sproc_name]
+        )
+
+        if sproc_name in cache:
+            return cache[sproc_name]
+
+    args_str = ",".join(
+        f"arg{idx} {type_}" for idx, type_ in enumerate(input_arg_types)
+    )
+    packages_str = ",".join(f"'{pkg}'" for pkg in packages)
+
+    session.sql(
+        f"""
+        CREATE OR REPLACE TEMPORARY PROCEDURE {sproc_name}({args_str})
+        RETURNS {return_type}
+        LANGUAGE PYTHON
+        RUNTIME_VERSION = '{runtime_version}'
+        PACKAGES = ({packages_str})
+        HANDLER = '{handler_name}'
+        AS $$
+{sproc_body}
+        $$
+        """
+    ).collect()
+
+    with _lock:
+        cache[sproc_name] = fully_qualified_name
+
+    return fully_qualified_name
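
A hedged usage sketch for the new register_cached_sproc helper above (the body and argument types are illustrative; the handler follows the Snowflake Python stored-procedure convention of receiving the session as its first argument). Because the returned identifier is fully qualified, it keeps working after USE SCHEMA/DATABASE changes, and a second registration with the same content hits the session-level cache:

```python
# Assumes register_cached_sproc and Session from the module shown above.
sproc_body = '''
def handler(session, who: str) -> str:
    return f"hello {who}"
'''

qualified_name = register_cached_sproc(
    sproc_body=sproc_body,
    handler_name="handler",
    input_arg_types=["STRING"],
    return_type="STRING",
)

result = Session.get_active_session().call(qualified_name, "world")
```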
snowflake/snowpark_connect/utils/udf_helper.py:

@@ -14,14 +14,20 @@ from pyspark.errors.exceptions.base import AnalysisException
 import snowflake.snowpark.functions as snowpark_fn
 import snowflake.snowpark_connect.tcm as tcm
 import snowflake.snowpark_connect.utils.udf_utils as udf_utils
-from snowflake.snowpark import 
+from snowflake.snowpark import Session
 from snowflake.snowpark.types import DataType, _parse_datatype_json_value
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.config import global_config
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
 from snowflake.snowpark_connect.expression.map_expression import (
     map_single_column_expression,
 )
+from snowflake.snowpark_connect.expression.map_unresolved_star import (
+    map_unresolved_star_as_single_column,
+)
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
+from snowflake.snowpark_connect.typed_column import TypedColumn
 from snowflake.snowpark_connect.utils.context import (
     get_is_aggregate_function,
     get_is_evaluating_join_condition,
@@ -38,6 +44,7 @@ class SnowparkUDF(NamedTuple):
     return_type: DataType
     input_types: list[DataType]
     original_return_type: DataType | None
+    cast_to_original_return_type: bool = False
 
 
 def require_creating_udf_in_sproc(
@@ -184,6 +191,7 @@ def parse_return_type(return_type_json_str) -> Optional[DataType]:
 
 
 def create(session, called_from, return_type_json_str, input_types_json_str, input_column_names_json_str, udf_name, replace, udf_packages, udf_imports, b64_str, original_return_type):
+    session._use_scoped_temp_objects = False
     import snowflake.snowpark.context as context
     context._use_structured_type_semantics = True
     context._is_snowpark_connect_compatible_mode = True
@@ -227,25 +235,15 @@ def _check_supported_udf(
         case "python_udf":
             pass
         case "java_udf":
-
-                get_or_create_snowpark_session,
-            )
-
-            session = get_or_create_snowpark_session()
-            if udf_proto.java_udf.class_name not in session._cached_java_udfs:
-                raise AnalysisException(
-                    f"Can not load class {udf_proto.java_udf.class_name}"
-                )
-            else:
-                raise ValueError(
-                    "Function type java_udf not supported for common inline user-defined function"
-                )
+            pass
         case "scalar_scala_udf":
             pass
         case _ as function_type:
-
+            exception = ValueError(
                 f"Function type {function_type} not supported for common inline user-defined function"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
 
 
 def _aggregate_function_check(
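
Several hunks in this release replace a bare raise with an attach-then-raise pattern. attach_custom_error_code and ErrorCodes are the real names from the diff; the stand-in implementation below is an assumption, sketched only to show the shape of the pattern and how it relates to the "[snowpark_connect::...]" message prefixes seen elsewhere in the diff:

```python
from enum import Enum


class ErrorCodes(Enum):  # stand-in for snowflake.snowpark_connect.error.error_codes
    INTERNAL_ERROR = "internal_error"
    INVALID_OPERATION = "invalid_operation"
    UNSUPPORTED_OPERATION = "unsupported_operation"


def attach_custom_error_code(exc: Exception, code: ErrorCodes) -> None:
    # Hypothetical implementation: tag the exception so a central handler
    # can render "[snowpark_connect::<code>]" prefixes consistently.
    exc.custom_error_code = code


try:
    exception = ValueError("Function type java_udf not supported")
    attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
    raise exception
except ValueError as err:
    print(f"[snowpark_connect::{err.custom_error_code.value}] {err}")
```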
@@ -253,9 +251,11 @@ def _aggregate_function_check(
 ):
     name, is_aggregate_function = get_is_aggregate_function()
     if not udf_proto.deterministic and name != "default" and is_aggregate_function:
-
+        exception = AnalysisException(
             f"[AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION] Non-deterministic expression {name}({udf_proto.function_name}) should not appear in the arguments of an aggregate function."
         )
+        attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+        raise exception
 
 
 def _join_checks(snowpark_udf_arg_names: list[str]):
@@ -282,49 +282,51 @@ def _join_checks(snowpark_udf_arg_names: list[str]):
         and is_left_evaluable
         and is_right_evaluable
     ):
-
+        exception = AnalysisException(
             f"Detected implicit cartesian product for {is_evaluating_join_condition[0]} join between logical plans. \n"
             f"Join condition is missing or trivial. \n"
             f"Either: use the CROSS JOIN syntax to allow cartesian products between those relations, or; "
             f"enable implicit cartesian products by setting the configuration variable spark.sql.crossJoin.enabled=True."
         )
+        attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+        raise exception
     if (
         is_evaluating_join_condition[0] != "INNER"
         and is_evaluating_join_condition[1]
         and is_left_evaluable
         and is_right_evaluable
     ):
-
+        exception = AnalysisException(
             f"[UNSUPPORTED_FEATURE.PYTHON_UDF_IN_ON_CLAUSE] The feature is not supported: "
            f"Python UDF in the ON clause of a {is_evaluating_join_condition[0]} JOIN. "
             f"In case of an INNNER JOIN consider rewriting to a CROSS JOIN with a WHERE clause."
         )
+        attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+        raise exception
 
 
 def infer_snowpark_arguments(
     udf_proto: expressions_proto.CommonInlineUserDefinedFunction,
     column_mapping: ColumnNameMap,
     typer: ExpressionTyper,
-) -> tuple[list[str], list[
-    snowpark_udf_args: list[
+) -> tuple[list[str], list[TypedColumn]]:
+    snowpark_udf_args: list[TypedColumn] = []
     snowpark_udf_arg_names: list[str] = []
     for arg_exp in udf_proto.arguments:
-
-
-
-
-
-
-
+        # Handle unresolved_star expressions specially
+        if arg_exp.HasField("unresolved_star"):
+            # Use map_unresolved_star_as_struct to expand star into a single combined column
+            spark_name, typed_column = map_unresolved_star_as_single_column(
+                arg_exp, column_mapping, typer
+            )
+            snowpark_udf_args.append(typed_column)
+            snowpark_udf_arg_names.append(spark_name)
+        else:
+            (
+                snowpark_udf_arg_name,
+                snowpark_udf_arg,
+            ) = map_single_column_expression(arg_exp, column_mapping, typer)
+            snowpark_udf_args.append(snowpark_udf_arg)
+            snowpark_udf_arg_names.append(snowpark_udf_arg_name)
     _join_checks(snowpark_udf_arg_names)
     return snowpark_udf_arg_names, snowpark_udf_args
-
-
-def gen_input_types(
-    snowpark_udf_args: list[Column],
-    typer: ExpressionTyper,
-):
-    input_types = []
-    for udf_arg in snowpark_udf_args:
-        input_types.extend(typer.type(udf_arg))
-    return input_types
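
For context on the unresolved-star branch added to infer_snowpark_arguments: a PySpark client can pass "*" as a UDF argument, which arrives at the server as an unresolved_star expression. Per the diff, the server expands that star into a single combined (struct-like) column rather than failing. A hypothetical client-side snippet (the connection URL is a placeholder, and how vanilla Spark treats a star argument may differ; this shows the mapping the diff implements):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
df = spark.createDataFrame([(1, "a")], ["id", "tag"])

@udf("string")
def describe(row):
    # With the change above, the star arrives as one combined column,
    # so the whole row is available through a single argument.
    return f"{row.id}:{row.tag}"

df.select(describe("*")).show()
```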
snowflake/snowpark_connect/utils/udf_utils.py:

@@ -103,7 +103,7 @@ class ProcessCommonInlineUserDefinedFunction:
                 )
             case _:
                 raise ValueError(
-                    f"Function type {self._function_type} not supported for common inline user-defined function"
+                    f"[snowpark_connect::unsupported_operation] Function type {self._function_type} not supported for common inline user-defined function"
                 )
 
     @property
@@ -112,7 +112,7 @@ class ProcessCommonInlineUserDefinedFunction:
             return self._snowpark_udf_args
         else:
             raise ValueError(
-                "Column mapping is not provided, cannot get snowpark udf args"
+                "[snowpark_connect::internal_error] Column mapping is not provided, cannot get snowpark udf args"
             )
 
     @property
@@ -121,7 +121,7 @@ class ProcessCommonInlineUserDefinedFunction:
             return self._snowpark_udf_arg_names
         else:
             raise ValueError(
-                "Column mapping is not provided, cannot get snowpark udf arg names"
+                "[snowpark_connect::internal_error] Column mapping is not provided, cannot get snowpark udf arg names"
             )
 
     def _create_python_udf(self):
@@ -148,7 +148,12 @@ class ProcessCommonInlineUserDefinedFunction:
 
         # Change directory to the one containing the UDF imported files
         import_path = sys._xoptions["snowflake_import_directory"]
-
+        if os.name == "nt":
+            import tempfile
+
+            tmp_path = os.path.join(tempfile.gettempdir(), f"sas-{os.getpid()}")
+        else:
+            tmp_path = f"/tmp/sas-{os.getpid()}"
         os.makedirs(tmp_path, exist_ok=True)
         os.chdir(tmp_path)
         shutil.copytree(import_path, tmp_path, dirs_exist_ok=True)
@@ -176,14 +181,6 @@ class ProcessCommonInlineUserDefinedFunction:
             tar_ref.extractall(archive[: -len(".archive")])
             os.remove(archive)
 
-        def callable_func(*args, **kwargs):
-            import_staged_files()
-            return original_callable(*args, **kwargs)
-
-        callable_func.__signature__ = inspect.signature(original_callable)
-        if hasattr(original_callable, "__annotations__"):
-            callable_func.__annotations__ = original_callable.__annotations__
-
         if self._udf_packages:
             packages = [p.strip() for p in self._udf_packages.strip("[]").split(",")]
         else:
@@ -193,13 +190,109 @@ class ProcessCommonInlineUserDefinedFunction:
         else:
             imports = []
 
+        def callable_func(*args, **kwargs):
+            if imports:
+                import_staged_files()
+            return original_callable(*args, **kwargs)
+
+        callable_func.__signature__ = inspect.signature(original_callable)
+        if hasattr(original_callable, "__annotations__"):
+            callable_func.__annotations__ = original_callable.__annotations__
+
         update_none_input_types()
 
+        struct_positions = [
+            i
+            for i, t in enumerate(self._input_types or [])
+            if isinstance(t, StructType)
+        ]
+
+        if struct_positions:
+
+            class StructRowProxy:
+                """Row-like object supporting positional and named access for PySpark compatibility."""
+
+                def __init__(self, fields, values) -> None:
+                    self._fields = fields
+                    self._values = values
+                    self._field_to_index = {field: i for i, field in enumerate(fields)}
+
+                def __getitem__(self, key):
+                    if isinstance(key, int):
+                        return self._values[key]
+                    elif isinstance(key, str):
+                        if key in self._field_to_index:
+                            return self._values[self._field_to_index[key]]
+                        raise KeyError(f"Field '{key}' not found in struct")
+                    else:
+                        raise TypeError(f"Invalid key type: {type(key)}")
+
+                def __getattr__(self, name):
+                    if name.startswith("_"):
+                        raise AttributeError(f"Attribute '{name}' not found")
+                    if name in self._field_to_index:
+                        return self._values[self._field_to_index[name]]
+                    raise AttributeError(f"Attribute '{name}' not found")
+
+                def __len__(self):
+                    return len(self._values)
+
+                def __iter__(self):
+                    return iter(self._values)
+
+                def __repr__(self):
+                    field_values = [
+                        f"{field}={repr(value)}"
+                        for field, value in zip(self._fields, self._values)
+                    ]
+                    return f"Row({', '.join(field_values)})"
+
+                def asDict(self):
+                    """Convert to dict (like PySpark Row.asDict())."""
+                    return dict(zip(self._fields, self._values))
+
+            def convert_to_row(arg):
+                """Convert dict to StructRowProxy. Only called for struct positions."""
+                if isinstance(arg, dict) and arg:
+                    fields = list(arg.keys())
+                    values = [arg[k] for k in fields]
+                    return StructRowProxy(fields, values)
+                return arg
+
+            def convert_from_row(result):
+                """Convert StructRowProxy back to dict for serialization."""
+                if isinstance(result, StructRowProxy):
+                    return result.asDict()
+                return result
+
+            def struct_input_wrapper(*args, **kwargs):
+                if struct_positions:
+                    processed_args = []
+                    for i, arg in enumerate(args):
+                        if i in struct_positions:
+                            processed_args.append(convert_to_row(arg))
+                        else:
+                            processed_args.append(arg)
+
+                    processed_kwargs = {k: convert_to_row(v) for k, v in kwargs.items()}
+                    result = callable_func(*tuple(processed_args), **processed_kwargs)
+                    # Convert any StructRowProxy in return value back to dict for serialization
+                    return convert_from_row(result)
+                return callable_func(*args, **kwargs)
+
         needs_struct_conversion = isinstance(self._original_return_type, StructType)
 
+        # Use callable_func directly when there are no struct inputs to avoid closure issues.
+        # struct_input_wrapper captures convert_to_row in its closure, but convert_to_row is only
+        # defined when struct_positions is truthy. Cloudpickle serializes all closure variables,
+        # so using struct_input_wrapper without struct positions would fail during serialization.
+        updated_callable_func = (
+            struct_input_wrapper if struct_positions else callable_func
+        )
+
         if not needs_struct_conversion:
             return snowpark_fn.udf(
-                create_null_safe_wrapper(
+                create_null_safe_wrapper(updated_callable_func),
                 return_type=self._return_type,
                 input_types=self._input_types,
                 name=self._udf_name,
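
The four-line comment in the hunk above is the crux of the updated_callable_func selection. A self-contained illustration of the underlying mechanism (requires the cloudpickle package): pickling a nested function serializes every free variable it closes over, so a wrapper must not capture helpers that exist only on one branch.

```python
import cloudpickle


def make_wrapper():
    big_closure_value = list(range(1_000))  # becomes part of the pickle

    def wrapper(x):
        return x + len(big_closure_value)

    return wrapper


payload = cloudpickle.dumps(make_wrapper())
# The payload embeds big_closure_value; a closure over an *undefined* helper
# (as in the struct-less branch above) could not be serialized at all.
print(len(payload))
```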
@@ -225,7 +318,21 @@ class ProcessCommonInlineUserDefinedFunction:
         field_names = [field.name for field in self._original_return_type.fields]
 
         def struct_wrapper(*args):
+            if struct_positions:
+                processed_args = []
+                for i, arg in enumerate(args):
+                    if i in struct_positions:
+                        processed_args.append(convert_to_row(arg))
+                    else:
+                        processed_args.append(arg)
+                args = tuple(processed_args)
+
             result = callable_func(*args)
+
+            # Convert StructRowProxy back to dict for serialization
+            if struct_positions:
+                result = convert_from_row(result)
+
             if isinstance(result, (tuple, list)):
                 # Convert tuple/list to dict using struct field names
                 if len(result) == len(field_names):
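
The struct_wrapper/StructRowProxy machinery emulates PySpark's Row semantics for struct-typed values. For reference, the behavior being matched, demonstrated with PySpark's own Row type:

```python
from pyspark.sql import Row

r = Row(id=1, name="ada")
assert r[0] == 1              # positional access
assert r["name"] == "ada"     # key access
assert r.name == "ada"        # attribute access
assert r.asDict() == {"id": 1, "name": "ada"}
print(r)  # Row(id=1, name='ada')
```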
@@ -283,6 +390,18 @@ class ProcessCommonInlineUserDefinedFunction:
             case "python_udf":
                 return self._create_python_udf()
             case "scalar_scala_udf":
+                from snowflake.snowpark_connect.utils.context import (
+                    get_is_aggregate_function,
+                )
+
+                name, is_aggregate_function = get_is_aggregate_function()
+                if is_aggregate_function and name.lower() == "reduce":
+                    # Handling of Scala Reduce function requires usage of Java UDAF
+                    from snowflake.snowpark_connect.utils.java_udaf_utils import (
+                        create_java_udaf_for_reduce_scala_function,
+                    )
+
+                    return create_java_udaf_for_reduce_scala_function(self)
                 from snowflake.snowpark_connect.utils.scala_udf_utils import (
                     create_scala_udf,
                 )
@@ -290,5 +409,5 @@ class ProcessCommonInlineUserDefinedFunction:
                 return create_scala_udf(self)
             case _:
                 raise ValueError(
-                    f"Function type {self._function_type} not supported for common inline user-defined function"
+                    f"[snowpark_connect::unsupported_operation] Function type {self._function_type} not supported for common inline user-defined function"
                 )
snowflake/snowpark_connect/utils/udtf_helper.py:

@@ -16,6 +16,8 @@ import snowflake.snowpark_connect.tcm as tcm
 from snowflake import snowpark
 from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
 from snowflake.snowpark.types import DataType, StructType, _parse_datatype_json_value
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
 from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
 from snowflake.snowpark_connect.utils import pandas_udtf_utils, udtf_utils
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
@@ -37,7 +39,9 @@ def udtf_check(
     udtf_proto: relation_proto.CommonInlineUserDefinedTableFunction,
 ) -> None:
     if udtf_proto.WhichOneof("function") != "python_udtf":
-
+        exception = ValueError(f"Not python udtf {udtf_proto.function}")
+        attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+        raise exception
 
 
 def require_creating_udtf_in_sproc(
@@ -149,6 +153,7 @@ def parse_types(types_json_str) -> Optional[list[DataType]]:
     return json.loads(types_json_str)
 
 def create(session, b64_str, expected_types_json_str, output_schema_json_str, packages, imports, is_arrow_enabled, is_spark_compatible_udtf_mode_enabled, called_from):
+    session._use_scoped_temp_objects = False
     import snowflake.snowpark.context as context
     context._use_structured_type_semantics = True
     context._is_snowpark_connect_compatible_mode = True
@@ -253,6 +258,7 @@ from snowflake.snowpark.types import _parse_datatype_json_value
 {inline_udtf_utils_py_code}
 
 def create(session, b64_str, spark_column_names_json_str, input_schema_json_str, return_schema_json_str):
+    session._use_scoped_temp_objects = False
     import snowflake.snowpark.context as context
     context._use_structured_type_semantics = True
     context._is_snowpark_connect_compatible_mode = True
@@ -326,6 +332,7 @@ from snowflake.snowpark.types import _parse_datatype_json_value
 from pyspark.serializers import CloudPickleSerializer
 
 def create(session, func_info_json):
+    session._use_scoped_temp_objects = False
     import snowflake.snowpark.context as context
     context._use_structured_type_semantics = True
     context._is_snowpark_connect_compatible_mode = True