snowpark-connect 0.27.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +680 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +237 -23
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +123 -5
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +85 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2748 -746
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +196 -255
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +255 -45
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
- snowflake/snowpark_connect/relation/read/map_read_json.py +320 -85
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +142 -27
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +82 -5
- snowflake/snowpark_connect/relation/read/map_read_text.py +18 -3
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +36 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +110 -48
- snowflake/snowpark_connect/server.py +546 -456
- snowflake/snowpark_connect/server_common/__init__.py +500 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +187 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +125 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +303 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +248 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +101 -332
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +163 -22
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-submit +2 -2
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/METADATA +14 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/RECORD +129 -167
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
import snowflake.snowpark.types as snowpark_type
|
|
8
|
+
from snowflake.snowpark_connect.type_mapping import map_type_to_snowflake_type
|
|
9
|
+
from snowflake.snowpark_connect.utils.jvm_udf_utils import (
|
|
10
|
+
NullHandling,
|
|
11
|
+
Param,
|
|
12
|
+
ReturnType,
|
|
13
|
+
Signature,
|
|
14
|
+
build_jvm_udxf_imports,
|
|
15
|
+
cast_java_map_args_from_given_type,
|
|
16
|
+
map_type_to_java_type,
|
|
17
|
+
)
|
|
18
|
+
from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
|
|
19
|
+
from snowflake.snowpark_connect.utils.udf_utils import (
|
|
20
|
+
ProcessCommonInlineUserDefinedFunction,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Prefix used for internally generated Java UDAF names to avoid conflicts
CREATE_JAVA_UDAF_PREFIX = "__SC_JAVA_UDAF_"


# Java source template for the Snowflake aggregate-function handler class.
# The placeholders (__operation_file__, __accumulator_type__, __value_type__,
# __return_type__, __response_wrapper__) are substituted via str.replace by
# JavaUDAFDef._gen_body_java(). __operation_file__ names the imported file
# holding the serialized Scala UdfPacket that supplies the reduce operation.
UDAF_TEMPLATE = """
import org.apache.spark.sql.connect.common.UdfPacket;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Paths;

// Import types required for mapping
import java.util.*;
import java.util.stream.Collectors;
import com.snowflake.snowpark_java.types.*;

public class JavaUDAF {
    private final static String OPERATION_FILE = "__operation_file__";
    private static scala.Function2<__accumulator_type__, __value_type__, __value_type__> operation = null;

    private static void loadOperation() throws IOException, ClassNotFoundException {
        if (operation != null) {
            return; // Already loaded
        }

        final UdfPacket udfPacket = com.snowflake.sas.scala.Utils$.MODULE$.deserializeUdfPacket(OPERATION_FILE);
        operation = (scala.Function2<__accumulator_type__, __value_type__, __value_type__>) udfPacket.function();
    }

    public static class State implements Serializable {
        public __accumulator_type__ value = null;
        public boolean initialized = false;
    }

    public static State initialize() throws IOException, ClassNotFoundException {
        loadOperation();
        return new State();
    }

    public static State accumulate(State state, __accumulator_type__ accumulator, __value_type__ input) {
        // TODO: Add conversion between value_type we get in input and the value that we are using in the operation
        if (input == null) {
            return state;
        }

        if (!state.initialized) {
            state.value = input;
            state.initialized = true;
        } else {
            state.value = operation.apply(state.value, input);
        }
        return state;
    }

    public static State merge(State s1, State s2) {
        if (!s2.initialized) {
            return s1;
        }
        if (!s1.initialized) {
            return s2;
        }

        s1.value = operation.apply(s1.value, s2.value);
        return s1;
    }

    public static __return_type__ finish(State state) {
        return state.initialized ? __response_wrapper__ : null;
    }
}"""
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass(frozen=True)
class JavaUDAFDef:
    """
    Complete definition for creating a Java UDAF in Snowflake.

    Contains all the information needed to generate the CREATE FUNCTION SQL statement
    and the Java code body for the UDAF.

    Attributes:
        name: UDAF name
        signature: SQL signature (for Snowflake function definition)
        java_signature: Java signature (for Java code generation)
        java_invocation_args: List of transformed arguments passed to the Java UDAF invocation, with type casting applied for Map types and other necessary conversions.
        imports: List of JAR files to import
        null_handling: Null handling behavior (defaults to RETURNS_NULL_ON_NULL_INPUT)
    """

    name: str
    signature: Signature
    java_signature: Signature
    java_invocation_args: list[str]
    imports: list[str]
    null_handling: NullHandling = NullHandling.RETURNS_NULL_ON_NULL_INPUT

    # -------------------- DDL Emitter --------------------

    def _gen_body_java(self) -> str:
        """
        Generate the Java code body for the UDAF.

        Fills in the UDAF_TEMPLATE placeholders: the serialized-operation file
        name, the accumulator/value Java types, the Java return type, and the
        expression used to wrap the final accumulator value.

        Returns:
            String containing the complete Java code for the UDAF body
        """
        returns_variant = self.signature.returns.data_type == "VARIANT"
        # NOTE(review): the non-variant return type is taken from params[0]
        # (the accumulator type) — consistent with a reduce-style UDAF where
        # the result type equals the accumulator type; confirm for other uses.
        return_type = (
            "Variant" if returns_variant else self.java_signature.params[0].data_type
        )
        # When the SQL return type is VARIANT, the raw accumulator value must
        # be wrapped in a Snowpark Variant before being returned.
        response_wrapper = (
            "new Variant(state.value)" if returns_variant else "state.value"
        )
        return (
            # __operation_file__ is the basename of the first import, which the
            # generated Java code deserializes into the Scala operation.
            UDAF_TEMPLATE.replace("__operation_file__", self.imports[0].split("/")[-1])
            .replace("__accumulator_type__", self.java_signature.params[0].data_type)
            .replace("__value_type__", self.java_signature.params[1].data_type)
            .replace("__return_type__", return_type)
            .replace("__response_wrapper__", response_wrapper)
        )

    def to_create_function_sql(self) -> str:
        """
        Generate the complete CREATE FUNCTION SQL statement for the Java UDAF.

        Creates a Snowflake CREATE OR REPLACE TEMPORARY AGGREGATE FUNCTION statement with
        all necessary clauses including language, runtime version, packages,
        imports, and the Java code body.

        Returns:
            Complete SQL DDL statement for creating the UDAF
        """

        args = ", ".join(
            [f"{param.name} {param.data_type}" for param in self.signature.params]
        )
        ret_type = self.signature.returns.data_type

        def quote_single(s: str) -> str:
            """Helper function to wrap strings in single quotes for SQL."""
            return "'" + s + "'"

        # Handler and imports
        imports_sql = f"IMPORTS = ({', '.join(quote_single(x) for x in self.imports)})"

        return f"""
CREATE OR REPLACE TEMPORARY AGGREGATE FUNCTION {self.name}({args})
RETURNS {ret_type}
LANGUAGE JAVA
{self.null_handling.value}
RUNTIME_VERSION = 17
PACKAGES = ('com.snowflake:snowpark:latest')
{imports_sql}
HANDLER = 'JavaUDAF'
AS
$$
{self._gen_body_java()}
$$;"""
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class JavaUdaf:
    """
    Lightweight handle for a Java UDAF registered in Snowflake.

    Mirrors the shape of Python's UserDefinedFunction reference objects by
    recording the function's name together with its input and return types —
    the metadata later call sites need in order to invoke the function.
    """

    def __init__(
        self,
        name: str,
        input_types: list[snowpark_type.DataType],
        return_type: snowpark_type.DataType,
    ) -> None:
        """
        Record the metadata of an already-created Java UDAF.

        Args:
            name: Name of the UDAF as registered in Snowflake.
            input_types: Types of the UDAF's input parameters.
            return_type: Type the UDAF produces.
        """
        # Store everything verbatim; this object performs no validation.
        self.name, self._input_types, self._return_type = (
            name,
            input_types,
            return_type,
        )
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def create_java_udaf_for_reduce_scala_function(
    pciudf: ProcessCommonInlineUserDefinedFunction,
) -> JavaUdaf:
    """
    Register a Java UDAF in Snowflake built from a serialized Scala reduce function.

    Steps performed:
      1. Wait for the shared resource initializer (jar uploads) to finish.
      2. Derive a function name (hash of the payload when none is supplied).
      3. Map input/return types across Snowpark, Java, and Snowflake SQL.
      4. Emit and execute the CREATE FUNCTION DDL.

    Args:
        pciudf: The ProcessCommonInlineUserDefinedFunction object containing UDF details.

    Returns:
        A JavaUdaf object representing the Java UDAF.
    """
    from snowflake.snowpark_connect.resources_initializer import (
        wait_for_resource_initialization,
    )

    # The resource initializer uploads the jars this UDAF imports; block until
    # that background work has completed before creating the function.
    wait_for_resource_initialization()

    from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session

    fn_name = pciudf._function_name
    if not fn_name:
        # No explicit name: derive a stable one from the first ten hex chars
        # of the payload's SHA-256 digest.
        import hashlib

        fn_name = hashlib.sha256(pciudf._payload).hexdigest()[:10]
    udaf_name = CREATE_JAVA_UDAF_PREFIX + fn_name

    arg_types = pciudf._input_types

    java_params: list[Param] = []
    sql_params: list[Param] = []
    invocation_args: list[str] = []  # arguments passed into the udf function
    # arg_types can be None when the function takes no arguments.
    for idx, arg_type in enumerate(arg_types or []):
        arg_name = "arg" + str(idx)
        # Java-side parameter list: "arg0 Type0, arg1 Type1, ...".
        java_params.append(Param(arg_name, map_type_to_java_type(arg_type)))
        # Snowflake SQL parameter list: "arg0 TYPE0, arg1 TYPE1, ...".
        sql_params.append(Param(arg_name, map_type_to_snowflake_type(arg_type)))
        # Map inputs need an explicit Java-side cast: Snowflake SQL Java can
        # only hand over MAP[VARCHAR, VARCHAR] as an input type.
        invocation_args.append(cast_java_map_args_from_given_type(arg_name, arg_type))

    java_ret = map_type_to_java_type(pciudf._original_return_type)
    sql_ret = map_type_to_snowflake_type(pciudf._original_return_type)
    session = get_or_create_snowpark_session()

    imports = build_jvm_udxf_imports(
        session,
        pciudf._payload,
        udaf_name,
    )
    # MAP/OBJECT return types are problematic for Java UDAFs; return VARIANT instead.
    if sql_ret.startswith(("MAP", "OBJECT")):
        sql_ret = "VARIANT"

    udaf_def = JavaUDAFDef(
        name=udaf_name,
        signature=Signature(params=sql_params, returns=ReturnType(sql_ret)),
        imports=imports,
        java_signature=Signature(params=java_params, returns=ReturnType(java_ret)),
        java_invocation_args=invocation_args,
    )
    ddl = udaf_def.to_create_function_sql()
    logger.info(f"Creating Java UDAF: {ddl}")
    session.sql(ddl).collect()
    return JavaUdaf(udaf_name, pciudf._input_types, pciudf._return_type)
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
|
|
8
|
+
from pyspark.sql.connect.proto.expressions_pb2 import CommonInlineUserDefinedFunction
|
|
9
|
+
|
|
10
|
+
from snowflake.snowpark.types import ArrayType, MapType, VariantType
|
|
11
|
+
from snowflake.snowpark_connect.resources_initializer import (
|
|
12
|
+
ensure_scala_udf_jars_uploaded,
|
|
13
|
+
)
|
|
14
|
+
from snowflake.snowpark_connect.type_mapping import (
|
|
15
|
+
map_type_to_snowflake_type,
|
|
16
|
+
proto_to_snowpark_type,
|
|
17
|
+
)
|
|
18
|
+
from snowflake.snowpark_connect.utils.jvm_udf_utils import (
|
|
19
|
+
NullHandling,
|
|
20
|
+
Param,
|
|
21
|
+
ReturnType,
|
|
22
|
+
Signature,
|
|
23
|
+
build_jvm_udxf_imports,
|
|
24
|
+
map_type_to_java_type,
|
|
25
|
+
)
|
|
26
|
+
from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
|
|
27
|
+
|
|
28
|
+
# Prefix used for internally generated Java UDTF names/columns to avoid conflicts.
JAVA_UDTF_PREFIX = "__SC_JAVA_UDTF_"

# Java snippet used when the UDTF input arrives as a VARIANT: the value is
# first decoded through UdfPacketUtils.fromVariant, then wrapped in a
# single-element Scala iterator of Object.
SCALA_INPUT_VARIANT = """
Object mappedInput = com.snowflake.sas.scala.UdfPacketUtils$.MODULE$.fromVariant(udfPacket, input, 0);

java.util.Iterator<Object> javaInput = Arrays.asList(mappedInput).iterator();
scala.collection.Iterator<Object> scalaInput = new scala.collection.AbstractIterator<Object>() {
    public boolean hasNext() { return javaInput.hasNext(); }
    public Object next() { return javaInput.next(); }
};
"""

# Java snippet used for simple (non-VARIANT) input types: the raw input is
# wrapped directly in a single-element Scala iterator of __iterator_type__.
SCALA_INPUT_SIMPLE_TYPE = """
java.util.Iterator<__iterator_type__> javaInput = Arrays.asList(input).iterator();
scala.collection.Iterator<__iterator_type__> scalaInput = new scala.collection.AbstractIterator<__iterator_type__>() {
    public boolean hasNext() { return javaInput.hasNext(); }
    public __iterator_type__ next() { return javaInput.next(); }
};
"""

# Java source template for the Snowflake UDTF handler. The placeholders
# (__operation_file__, __scala_input__, __iterator_type__, __input_type__,
# __return_type__, __java_udtf_prefix__) are substituted via str.replace by
# JavaUDTFDef._gen_body_java(). Output rows are emitted as a single VARIANT
# column produced from the Scala operation's result iterator.
UDTF_TEMPLATE = """
import org.apache.spark.sql.connect.common.UdfPacket;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Paths;

import java.util.*;
import java.lang.*;
import java.util.stream.Collectors;
import com.snowflake.snowpark_java.types.*;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public class OutputRow {
    public Variant __java_udtf_prefix__C1;
    public OutputRow(Variant __java_udtf_prefix__C1) {
        this.__java_udtf_prefix__C1 = __java_udtf_prefix__C1;
    }
}

public class JavaUdtfHandler {
    private final static String OPERATION_FILE = "__operation_file__";
    private static scala.Function1<scala.collection.Iterator<__iterator_type__>, scala.collection.Iterator<Object>> operation = null;
    private static UdfPacket udfPacket = null;

    public static Class getOutputClass() { return OutputRow.class; }

    private static void loadOperation() throws IOException, ClassNotFoundException {
        if (operation != null) {
            return; // Already loaded
        }

        udfPacket = com.snowflake.sas.scala.Utils$.MODULE$.deserializeUdfPacket(OPERATION_FILE);
        operation = (scala.Function1<scala.collection.Iterator<__iterator_type__>, scala.collection.Iterator<Object>>) udfPacket.function();
    }

    public Stream<OutputRow> process(__input_type__ input) throws IOException, ClassNotFoundException {
        loadOperation();

        __scala_input__

        scala.collection.Iterator<Object> scalaResult = operation.apply(scalaInput);

        java.util.Iterator<Variant> javaResult = new java.util.Iterator<Variant>() {
            public boolean hasNext() { return scalaResult.hasNext(); }
            public Variant next() {
                return com.snowflake.sas.scala.Utils$.MODULE$.toVariant(scalaResult.next());
            }
        };

        return StreamSupport.stream(Spliterators.spliteratorUnknownSize(javaResult, Spliterator.ORDERED), false)
            .map(i -> new OutputRow(i));
    }

    public Stream<OutputRow> endPartition() {
        return Stream.empty();
    }
}
"""
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass(frozen=True)
class JavaUDTFDef:
    """
    Immutable definition of a Java UDTF to be registered in Snowflake.

    Bundles everything needed to render both the CREATE FUNCTION statement
    and the inline Java handler source.

    Attributes:
        name: Name under which the UDTF is created.
        signature: SQL-side signature (Snowflake parameter/return types).
        java_signature: Java-side signature (used when generating the handler code).
        imports: JAR paths placed in the IMPORTS clause; the first entry is the
            operation payload file referenced by the generated handler.
        null_handling: Null-input behavior; defaults to RETURNS_NULL_ON_NULL_INPUT.
    """

    name: str
    signature: Signature
    java_signature: Signature
    imports: list[str]
    null_handling: NullHandling = NullHandling.RETURNS_NULL_ON_NULL_INPUT

    def _gen_body_java(self) -> str:
        """Render the Java handler source by filling the UDTF_TEMPLATE placeholders."""
        first_param_type = self.java_signature.params[0].data_type
        takes_variant = first_param_type.lower() == "variant"

        # When the SQL return type is VARIANT the Java side must hand back
        # Variant objects instead of the mapped Java type.
        if self.signature.returns.data_type == "VARIANT":
            java_return = "Variant"
        else:
            java_return = self.java_signature.returns.data_type

        # Substitution order matters: the snippet spliced in for
        # __scala_input__ may itself contain later placeholders.
        substitutions = [
            ("__operation_file__", self.imports[0].split("/")[-1]),
            (
                "__scala_input__",
                SCALA_INPUT_VARIANT if takes_variant else SCALA_INPUT_SIMPLE_TYPE,
            ),
            ("__iterator_type__", "Object" if takes_variant else first_param_type),
            ("__input_type__", first_param_type),
            ("__return_type__", java_return),
            ("__java_udtf_prefix__", JAVA_UDTF_PREFIX),
        ]
        body = UDTF_TEMPLATE
        for placeholder, value in substitutions:
            body = body.replace(placeholder, value)
        return body

    def to_create_function_sql(self) -> str:
        """Build the CREATE OR REPLACE FUNCTION statement for this UDTF."""
        arg_list = ", ".join(f"{p.name} {p.data_type}" for p in self.signature.params)

        # Each JAR path is single-quoted for the IMPORTS clause.
        quoted_imports = ", ".join(f"'{path}'" for path in self.imports)
        imports_clause = f"IMPORTS = ({quoted_imports})"

        return f"""
create or replace function {self.name}({arg_list})
returns table ({JAVA_UDTF_PREFIX}C1 VARIANT)
language java
runtime_version = 17
PACKAGES = ('com.snowflake:snowpark:latest')
{imports_clause}
handler='JavaUdtfHandler'
as
$$
{self._gen_body_java()}
$$;"""
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def create_java_udtf_for_scala_flatmap_handling(
    udf_proto: CommonInlineUserDefinedFunction,
) -> str:
    """
    Create a Java UDTF in Snowflake that wraps a serialized Scala function
    for flatMap-style handling, and return the created UDTF's name.

    Args:
        udf_proto: Spark Connect proto carrying the Scala UDF payload plus
            its input/output type information.

    Returns:
        The name of the UDTF that was created (or replaced) in the session.
    """
    ensure_scala_udf_jars_uploaded()

    scala_udf = udf_proto.scalar_scala_udf
    session = get_or_create_snowpark_session()

    # Map the declared output type to both type systems.
    output_type = proto_to_snowpark_type(scala_udf.outputType)
    java_return = map_type_to_java_type(output_type)
    snowflake_return = map_type_to_snowflake_type(output_type)

    # Build the parallel Java-side and SQL-side parameter lists.
    java_params: list[Param] = []
    snowflake_params: list[Param] = []
    for idx, type_proto in enumerate(scala_udf.inputTypes):
        snowpark_type = proto_to_snowpark_type(type_proto)
        param_name = f"arg{idx}"

        # Complex inputs are passed as Variant on both sides.
        if isinstance(snowpark_type, (ArrayType, MapType, VariantType)):
            java_params.append(Param(param_name, "Variant"))
            snowflake_params.append(Param(param_name, "Variant"))
        else:
            java_params.append(Param(param_name, map_type_to_java_type(snowpark_type)))
            snowflake_params.append(
                Param(param_name, map_type_to_snowflake_type(snowpark_type))
            )

    # Deterministic name derived from the payload hash (prefixed so the
    # generated function is recognizable).
    udtf_name = JAVA_UDTF_PREFIX + hashlib.md5(scala_udf.payload).hexdigest()

    jar_imports = build_jvm_udxf_imports(session, scala_udf.payload, udtf_name)

    definition = JavaUDTFDef(
        name=udtf_name,
        signature=Signature(
            params=snowflake_params, returns=ReturnType(snowflake_return)
        ),
        imports=jar_imports,
        java_signature=Signature(
            params=java_params, returns=ReturnType(java_return)
        ),
    )

    # Execute the CREATE FUNCTION statement against the session.
    session.sql(definition.to_create_function_sql()).collect()

    return udtf_name
|