snowpark-connect 0.27.0__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +717 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +309 -26
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/error_utils.py +28 -0
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +224 -15
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +86 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2964 -829
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-connect-client-jvm_2.12-3.5.6.jar → spark-connect-client-jvm_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-sql_2.12-3.5.6.jar → spark-sql_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +239 -256
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +326 -47
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +21 -6
- snowflake/snowpark_connect/relation/read/map_read_json.py +324 -86
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +146 -28
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +15 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +86 -6
- snowflake/snowpark_connect/relation/read/map_read_text.py +22 -4
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +42 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +171 -48
- snowflake/snowpark_connect/server.py +528 -473
- snowflake/snowpark_connect/server_common/__init__.py +503 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/type_support.py +130 -0
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +195 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +151 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +321 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +281 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +185 -340
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +192 -40
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +14 -4
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +16 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +139 -168
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
|
|
5
|
+
from pyspark.errors import AnalysisException
|
|
6
|
+
|
|
7
|
+
import snowflake.snowpark.types as snowpark_type
|
|
8
|
+
from snowflake.snowpark import Session
|
|
9
|
+
from snowflake.snowpark._internal.type_utils import type_string_to_type_object
|
|
10
|
+
from snowflake.snowpark_connect.client.error_utils import attach_custom_error_code
|
|
11
|
+
from snowflake.snowpark_connect.config import (
|
|
12
|
+
get_scala_version,
|
|
13
|
+
is_java_udf_creator_initialized,
|
|
14
|
+
set_java_udf_creator_initialized_state,
|
|
15
|
+
)
|
|
16
|
+
from snowflake.snowpark_connect.error.error_codes import ErrorCodes
|
|
17
|
+
from snowflake.snowpark_connect.resources_initializer import (
|
|
18
|
+
RESOURCE_PATH,
|
|
19
|
+
SPARK_COMMON_UTILS_JAR_212,
|
|
20
|
+
SPARK_COMMON_UTILS_JAR_213,
|
|
21
|
+
SPARK_CONNECT_CLIENT_JAR_212,
|
|
22
|
+
SPARK_CONNECT_CLIENT_JAR_213,
|
|
23
|
+
SPARK_SQL_JAR_212,
|
|
24
|
+
SPARK_SQL_JAR_213,
|
|
25
|
+
ensure_scala_udf_jars_uploaded,
|
|
26
|
+
)
|
|
27
|
+
from snowflake.snowpark_connect.utils.upload_java_jar import upload_java_udf_jar
|
|
28
|
+
|
|
29
|
+
# Prefix prepended to internally generated Java UDF names so they cannot
# collide with user-created objects in the session.
CREATE_JAVA_UDF_PREFIX = "__SC_JAVA_UDF_"
# Name of the temporary stored procedure that performs the actual Java UDF
# registration (created once per session by create_java_udf).
PROCEDURE_NAME = "__SC_JAVA_SP_CREATE_JAVA_UDF"
# DDL template for the creator stored procedure. Two placeholders are
# substituted before execution:
#   __snowflake_udf_imports__ -> full "IMPORTS = (...)" clause (see
#       create_snowflake_imports)
#   __scala_version__         -> configured Scala version, selecting the
#       matching snowpark package artifact
SP_TEMPLATE = """
CREATE OR REPLACE TEMPORARY PROCEDURE __SC_JAVA_SP_CREATE_JAVA_UDF(udf_name VARCHAR, udf_class VARCHAR, imports ARRAY(VARCHAR))
RETURNS VARCHAR
LANGUAGE JAVA
RUNTIME_VERSION = 17
PACKAGES = ('com.snowflake:snowpark___scala_version__:latest')
__snowflake_udf_imports__
HANDLER = 'com.snowflake.snowpark_connect.procedures.JavaUDFCreator.process'
EXECUTE AS CALLER
;
"""
|
|
43
|
+
|
|
44
|
+
class JavaUdf:
    """Lightweight handle for a Java UDF that already exists in Snowflake.

    Mirrors the shape of Python's ``UserDefinedFunction`` so callers can
    treat Java and Python UDFs uniformly. Only metadata is stored here —
    the function name and its parameter/return types — no live session
    resources.
    """

    def __init__(
        self,
        name: str,
        input_types: list[snowpark_type.DataType],
        return_type: snowpark_type.DataType,
    ) -> None:
        """Record the metadata of a registered Java UDF.

        Args:
            name: The name of the UDF in Snowflake.
            input_types: List of input parameter types.
            return_type: The return type of the UDF.
        """
        # Public name is what callers embed in generated SQL.
        self.name = name
        # Kept private, matching UserDefinedFunction's convention.
        self._return_type = return_type
        self._input_types = input_types
|
70
|
+
|
|
71
|
+
def _scala_static_imports_for_sproc(stage_resource_path: str) -> set[str]:
    """Return stage paths of the Spark jars the creator sproc must import.

    The jar set is selected by the configured Scala version; only 2.12 and
    2.13 are supported.

    Raises:
        ValueError: (with INVALID_CONFIG_VALUE attached) for any other
            Scala version.
    """
    # Version -> required jar trio (connect client, common utils, sql).
    jars_by_version = {
        "2.12": (
            SPARK_CONNECT_CLIENT_JAR_212,
            SPARK_COMMON_UTILS_JAR_212,
            SPARK_SQL_JAR_212,
        ),
        "2.13": (
            SPARK_CONNECT_CLIENT_JAR_213,
            SPARK_COMMON_UTILS_JAR_213,
            SPARK_SQL_JAR_213,
        ),
    }
    scala_version = get_scala_version()
    jars = jars_by_version.get(scala_version)
    if jars is None:
        # invalid Scala version
        exception = ValueError(
            f"Unsupported Scala version: {scala_version}. Snowpark Connect supports Scala 2.12 and 2.13"
        )
        attach_custom_error_code(exception, ErrorCodes.INVALID_CONFIG_VALUE)
        raise exception
    return {f"{stage_resource_path}/{jar}" for jar in jars}
|
|
94
|
+
|
|
95
|
+
def get_quoted_imports(session: Session) -> str:
    """Build the comma-separated, single-quoted jar list for UDF/sproc DDL.

    Merges three sources of imports: the session's registered artifact
    jars, the Spark jars staged by the resource initializer (plus the
    bundled ``java_udfs`` jar), and any extra jars from the
    ``snowpark.connect.udf.java.imports`` config value.
    """
    from snowflake.snowpark_connect.config import global_config

    stage_resource_path = session.get_session_stage() + RESOURCE_PATH
    spark_imports = _scala_static_imports_for_sproc(stage_resource_path) | {
        f"{stage_resource_path}/java_udfs-1.0-SNAPSHOT.jar",
    }

    # Config value looks like "[a.jar, b.jar]"; strip the brackets and
    # split on commas, ignoring empty entries.
    raw_value = global_config.get("snowpark.connect.udf.java.imports", "")
    config_imports: set = set()
    if raw_value:
        config_imports = {
            entry.strip()
            for entry in raw_value.strip("[] ").split(",")
            if entry.strip()
        }

    all_jars = session._artifact_jars | spark_imports | config_imports
    return ", ".join("'" + jar + "'" for jar in all_jars)
|
|
118
|
+
|
|
119
|
+
def create_snowflake_imports(session: Session) -> str:
    """Render the ``IMPORTS = (...)`` clause for the creator stored procedure.

    Blocks until the resource-initializer thread has finished uploading
    the Scala UDF jars, because the generated clause references those
    staged files.
    """
    ensure_scala_udf_jars_uploaded()
    return f"IMPORTS = ({get_quoted_imports(session)})"
|
|
126
|
+
|
|
127
|
+
def create_java_udf(session: Session, function_name: str, java_class: str):
    """Register a Java class as a Snowflake UDF and return a handle to it.

    On first use, uploads the helper jar and creates the temporary stored
    procedure that performs the registration; later calls reuse it. The
    procedure returns a ';'-separated list of type strings (inputs
    followed by the return type), which is parsed into Snowpark types.

    Args:
        session: Active Snowpark session.
        function_name: Base name for the UDF (prefixed internally).
        java_class: Fully qualified Java class implementing the UDF.

    Returns:
        A ``JavaUdf`` describing the created function.

    Raises:
        AnalysisException: if the procedure could not load ``java_class``.
    """
    if not is_java_udf_creator_initialized():
        upload_java_udf_jar(session)
        creator_ddl = SP_TEMPLATE.replace(
            "__snowflake_udf_imports__", create_snowflake_imports(session)
        ).replace("__scala_version__", get_scala_version())
        session.sql(creator_ddl).collect()
        set_java_udf_creator_initialized_state(True)

    name = CREATE_JAVA_UDF_PREFIX + function_name
    call_sql = (
        f"CALL {PROCEDURE_NAME}('{name}', '{java_class}', "
        f"ARRAY_CONSTRUCT({get_quoted_imports(session)})::ARRAY(VARCHAR))"
    )
    rows = session.sql(call_sql).collect()
    type_list = rows[0][0]
    if not type_list:
        raise AnalysisException(f"Can not load class {java_class}")

    # Last entry is the return type; everything before it is an input type.
    parts = type_list.split(";")
    return JavaUdf(
        name,
        [type_string_to_type_object(t) for t in parts[:-1]],
        type_string_to_type_object(parts[-1]),
    )
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
import snowflake.snowpark.types as snowpark_type
|
|
8
|
+
from snowflake.snowpark_connect.type_mapping import map_type_to_snowflake_type
|
|
9
|
+
from snowflake.snowpark_connect.utils.jvm_udf_utils import (
|
|
10
|
+
NullHandling,
|
|
11
|
+
Param,
|
|
12
|
+
ReturnType,
|
|
13
|
+
Signature,
|
|
14
|
+
build_jvm_udxf_imports,
|
|
15
|
+
map_type_to_java_type,
|
|
16
|
+
)
|
|
17
|
+
from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
|
|
18
|
+
from snowflake.snowpark_connect.utils.udf_utils import (
|
|
19
|
+
ProcessCommonInlineUserDefinedFunction,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
# Prefix used for internally generated Java UDAF names to avoid conflicts
|
|
23
|
+
# Prefix used for internally generated Java UDAF names to avoid conflicts
CREATE_JAVA_UDAF_PREFIX = "__SC_JAVA_UDAF_"


# Java source for the generated UDAF handler class. Placeholders substituted
# by JavaUDAFDef._gen_body_java:
#   __operation_file__   -> file name of the serialized UdfPacket import
#   __reduce_type__      -> Java type the binary reduce operation works on
#   __accumulator_type__ -> declared Java type of the accumulator parameter
#   __value_type__       -> declared Java type of the input parameter
#   __mapped_value__     -> expression converting the input to __reduce_type__
#   __return_type__      -> Java return type of finish()
#   __response_wrapper__ -> expression converting state.value to the result
UDAF_TEMPLATE = """
import org.apache.spark.sql.connect.common.UdfPacket;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Paths;

// Import types required for mapping
import java.util.*;
import java.util.stream.Collectors;
import com.snowflake.snowpark_java.types.*;

public class JavaUDAF {
    private final static String OPERATION_FILE = "__operation_file__";
    private static scala.Function2<__reduce_type__, __reduce_type__, __reduce_type__> operation = null;
    private static UdfPacket udfPacket = null;

    private static void loadOperation() throws IOException, ClassNotFoundException {
        if (operation != null) {
            return; // Already loaded
        }

        udfPacket = com.snowflake.sas.scala.Utils$.MODULE$.deserializeUdfPacket(OPERATION_FILE);
        operation = (scala.Function2<__reduce_type__, __reduce_type__, __reduce_type__>) udfPacket.function();
    }

    public static class State implements Serializable {
        public __reduce_type__ value = null;
        public boolean initialized = false;
    }

    public static State initialize() throws IOException, ClassNotFoundException {
        loadOperation();
        return new State();
    }

    public static State accumulate(State state, __accumulator_type__ accumulator, __value_type__ input) {
        // TODO: Add conversion between value_type we get in input and the value that we are using in the operation
        if (input == null) {
            return state;
        }

        if (!state.initialized) {
            state.value = __mapped_value__;
            state.initialized = true;
        } else {
            state.value = operation.apply(state.value, __mapped_value__);
        }
        return state;
    }

    public static State merge(State s1, State s2) {
        if (!s2.initialized) {
            return s1;
        }
        if (!s1.initialized) {
            return s2;
        }

        s1.value = operation.apply(s1.value, s2.value);
        return s1;
    }

    public static __return_type__ finish(State state) {
        return state.initialized ? __response_wrapper__ : null;
    }
}"""
|
|
97
|
+
|
|
98
|
+
@dataclass(frozen=True)
class JavaUDAFDef:
    """
    Complete definition for creating a Java UDAF in Snowflake.

    Contains all the information needed to generate the CREATE FUNCTION SQL statement
    and the Java code body for the UDAF.

    Attributes:
        name: UDAF name
        signature: SQL signature (for Snowflake function definition)
        java_signature: Java signature (for Java code generation)
        imports: List of JAR files to import; the first entry must be the
            serialized operation payload, whose file name is substituted
            into the generated Java body.
        null_handling: Null handling behavior (defaults to RETURNS_NULL_ON_NULL_INPUT)
    """

    name: str
    signature: Signature
    java_signature: Signature
    imports: list[str]
    null_handling: NullHandling = NullHandling.RETURNS_NULL_ON_NULL_INPUT

    # -------------------- DDL Emitter --------------------

    def _gen_body_java(self) -> str:
        """
        Generate the Java code body for the UDAF.

        Creates a Java object that loads the serialized function from a binary file
        and provides a run method to execute it.

        Returns:
            String containing the complete Java code for the UDAF body
        """
        # When the SQL return type is VARIANT, the Java side returns Variant
        # and wraps the final state through toVariant().
        returns_variant = self.signature.returns.data_type.lower() == "variant"
        return_type = (
            "Variant" if returns_variant else self.java_signature.params[0].data_type
        )
        response_wrapper = (
            "com.snowflake.sas.scala.Utils$.MODULE$.toVariant(state.value, udfPacket)"
            if returns_variant
            else "state.value"
        )

        # Variant inputs are reduced as plain Object and converted per-row
        # via fromVariant(); otherwise the declared Java type is used directly.
        is_variant_input = self.java_signature.params[0].data_type.lower() == "variant"
        reduce_type = (
            "Object" if is_variant_input else self.java_signature.params[0].data_type
        )
        # Fill every placeholder in the template; imports[0] carries the
        # serialized operation file referenced by the generated class.
        return (
            UDAF_TEMPLATE.replace("__operation_file__", self.imports[0].split("/")[-1])
            .replace("__accumulator_type__", self.java_signature.params[0].data_type)
            .replace("__value_type__", self.java_signature.params[1].data_type)
            .replace(
                "__mapped_value__",
                "com.snowflake.sas.scala.UdfPacketUtils$.MODULE$.fromVariant(udfPacket, input, 0)"
                if is_variant_input
                else "input",
            )
            .replace("__reduce_type__", reduce_type)
            .replace("__return_type__", return_type)
            .replace("__response_wrapper__", response_wrapper)
        )

    def to_create_function_sql(self) -> str:
        """
        Generate the complete CREATE FUNCTION SQL statement for the Java UDAF.

        Creates a Snowflake CREATE OR REPLACE TEMPORARY AGGREGATE FUNCTION statement with
        all necessary clauses including language, runtime version, packages,
        imports, and the Java code body.

        Returns:
            Complete SQL DDL statement for creating the UDAF
        """

        # "arg0 TYPE0, arg1 TYPE1, ..." parameter list for the SQL signature.
        args = ", ".join(
            [f"{param.name} {param.data_type}" for param in self.signature.params]
        )
        ret_type = self.signature.returns.data_type

        def quote_single(s: str) -> str:
            """Helper function to wrap strings in single quotes for SQL."""
            return "'" + s + "'"

        # Handler and imports
        imports_sql = f"IMPORTS = ({', '.join(quote_single(x) for x in self.imports)})"

        return f"""
CREATE OR REPLACE TEMPORARY AGGREGATE FUNCTION {self.name}({args})
RETURNS {ret_type}
LANGUAGE JAVA
{self.null_handling.value}
RUNTIME_VERSION = 17
PACKAGES = ('com.snowflake:snowpark:latest')
{imports_sql}
HANDLER = 'JavaUDAF'
AS
$$
{self._gen_body_java()}
$$;"""
|
|
200
|
+
|
|
201
|
+
class JavaUdaf:
    """Lightweight handle for a Java UDAF that already exists in Snowflake.

    Mirrors the shape of Python's ``UserDefinedFunction`` so callers can
    treat Java and Python aggregates uniformly. Only metadata is stored —
    the function name and its parameter/return types.
    """

    def __init__(
        self,
        name: str,
        input_types: list[snowpark_type.DataType],
        return_type: snowpark_type.DataType,
    ) -> None:
        """Record the metadata of a registered Java UDAF.

        Args:
            name: The name of the UDAF in Snowflake.
            input_types: List of input parameter types.
            return_type: The return type of the UDAF.
        """
        # Public name is what callers embed in generated SQL.
        self.name = name
        # Kept private, matching UserDefinedFunction's convention.
        self._return_type = return_type
        self._input_types = input_types
|
227
|
+
|
|
228
|
+
def create_java_udaf_for_reduce_scala_function(
    pciudf: ProcessCommonInlineUserDefinedFunction,
) -> JavaUdaf:
    """
    Create a Java UDAF in Snowflake from a ProcessCommonInlineUserDefinedFunction object.

    Steps:
    1. Derive a unique function name when none is provided (payload hash).
    2. Build the import list (serialized payload + staged jars).
    3. Map input/return types onto Java and Snowflake SQL types.
    4. Emit and execute the CREATE FUNCTION DDL.

    Args:
        pciudf: The ProcessCommonInlineUserDefinedFunction object containing UDF details.

    Returns:
        A JavaUdaf object representing the Java UDAF.
    """
    from snowflake.snowpark_connect.resources_initializer import (
        ensure_scala_udf_jars_uploaded,
    )

    # The generated UDAF imports jars staged by the resource-initializer
    # thread, so wait for that upload before creating anything.
    ensure_scala_udf_jars_uploaded()

    from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session

    function_name = pciudf._function_name
    if not function_name:
        # No explicit name: hash the serialized payload and take the first
        # ten hex characters so the name is stable for identical payloads.
        import hashlib

        function_name = hashlib.sha256(pciudf._payload).hexdigest()[:10]
    udf_name = CREATE_JAVA_UDAF_PREFIX + function_name

    # Build the Java-side ("arg0: Type0, ...") and SQL-side
    # ("arg0 TYPE0, ...") parameter lists in lockstep.
    java_params: list[Param] = []
    sql_params: list[Param] = []
    # _input_types can be None when no arguments are provided.
    for idx, arg_type in enumerate(pciudf._input_types or []):
        arg_name = "arg" + str(idx)
        if isinstance(
            arg_type,
            (
                snowpark_type.ArrayType,
                snowpark_type.MapType,
                snowpark_type.VariantType,
            ),
        ):
            # Complex types travel as Variant on both sides.
            java_type = "Variant"
            snowflake_type = "Variant"
        else:
            java_type = map_type_to_java_type(arg_type)
            snowflake_type = map_type_to_snowflake_type(arg_type)
        java_params.append(Param(arg_name, java_type))
        sql_params.append(Param(arg_name, snowflake_type))

    java_return_type = map_type_to_java_type(pciudf._original_return_type)
    sql_return_type = map_type_to_snowflake_type(pciudf._original_return_type)
    # MAP/OBJECT/ARRAY return types are problematic for Java UDAFs, so fall
    # back to VARIANT for those.
    if sql_return_type.startswith(("MAP", "OBJECT", "ARRAY")):
        sql_return_type = "VARIANT"

    session = get_or_create_snowpark_session()
    imports = build_jvm_udxf_imports(
        session,
        pciudf._payload,
        udf_name,
    )

    udaf_def = JavaUDAFDef(
        name=udf_name,
        signature=Signature(params=sql_params, returns=ReturnType(sql_return_type)),
        java_signature=Signature(
            params=java_params, returns=ReturnType(java_return_type)
        ),
        imports=imports,
    )
    create_udf_sql = udaf_def.to_create_function_sql()
    logger.info(f"Creating Java UDAF: {create_udf_sql}")
    session.sql(create_udf_sql).collect()
    return JavaUdaf(udf_name, pciudf._input_types, pciudf._return_type)