snowpark-connect 0.27.0__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +717 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +309 -26
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/error_utils.py +28 -0
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +224 -15
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +86 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2964 -829
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-connect-client-jvm_2.12-3.5.6.jar → spark-connect-client-jvm_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-sql_2.12-3.5.6.jar → spark-sql_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +239 -256
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +326 -47
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +21 -6
- snowflake/snowpark_connect/relation/read/map_read_json.py +324 -86
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +146 -28
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +15 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +86 -6
- snowflake/snowpark_connect/relation/read/map_read_text.py +22 -4
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +42 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +171 -48
- snowflake/snowpark_connect/server.py +528 -473
- snowflake/snowpark_connect/server_common/__init__.py +503 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/type_support.py +130 -0
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +195 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +151 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +321 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +281 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +185 -340
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +192 -40
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +14 -4
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +16 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +139 -168
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
|
|
8
|
+
from pyspark.sql.connect.proto.expressions_pb2 import CommonInlineUserDefinedFunction
|
|
9
|
+
|
|
10
|
+
from snowflake.snowpark.types import ArrayType, MapType, VariantType
|
|
11
|
+
from snowflake.snowpark_connect.resources_initializer import (
|
|
12
|
+
ensure_scala_udf_jars_uploaded,
|
|
13
|
+
)
|
|
14
|
+
from snowflake.snowpark_connect.type_mapping import (
|
|
15
|
+
map_type_to_snowflake_type,
|
|
16
|
+
proto_to_snowpark_type,
|
|
17
|
+
)
|
|
18
|
+
from snowflake.snowpark_connect.utils.jvm_udf_utils import (
|
|
19
|
+
NullHandling,
|
|
20
|
+
Param,
|
|
21
|
+
ReturnType,
|
|
22
|
+
Signature,
|
|
23
|
+
build_jvm_udxf_imports,
|
|
24
|
+
map_type_to_java_type,
|
|
25
|
+
)
|
|
26
|
+
from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
|
|
27
|
+
|
|
28
|
+
# Prefix used both for generated Java UDTF names and for the output column name
# inside the generated handler (see __java_udtf_prefix__ substitution below).
JAVA_UDTF_PREFIX = "__SC_JAVA_UDTF_"

# Java snippet spliced into UDTF_TEMPLATE when the input arrives as a VARIANT:
# decodes the variant via the Scala-side UdfPacketUtils helper, then wraps the
# single decoded value in a Scala Iterator for the flatMap closure.
SCALA_INPUT_VARIANT = """
Object mappedInput = com.snowflake.sas.scala.UdfPacketUtils$.MODULE$.fromVariant(udfPacket, input, 0);

java.util.Iterator<Object> javaInput = Arrays.asList(mappedInput).iterator();
scala.collection.Iterator<Object> scalaInput = new scala.collection.AbstractIterator<Object>() {
public boolean hasNext() { return javaInput.hasNext(); }
public Object next() { return javaInput.next(); }
};
"""

# Java snippet spliced into UDTF_TEMPLATE for simple (non-VARIANT) inputs:
# wraps the raw input value directly; __iterator_type__ is substituted with the
# concrete Java element type.
SCALA_INPUT_SIMPLE_TYPE = """
java.util.Iterator<__iterator_type__> javaInput = Arrays.asList(input).iterator();
scala.collection.Iterator<__iterator_type__> scalaInput = new scala.collection.AbstractIterator<__iterator_type__>() {
public boolean hasNext() { return javaInput.hasNext(); }
public __iterator_type__ next() { return javaInput.next(); }
};
"""

# Full Java source for the generated UDTF handler. Placeholders substituted by
# JavaUDTFDef._gen_body_java:
#   __operation_file__    - file name of the serialized UdfPacket on the stage
#   __iterator_type__     - Java element type for the Scala iterator
#   __input_type__        - Java type of the process() parameter
#   __scala_input__       - one of the SCALA_INPUT_* snippets above
#   __java_udtf_prefix__  - JAVA_UDTF_PREFIX (output column name)
# The handler lazily deserializes the Scala closure once, applies it to a
# one-element iterator per input row, and converts each result to a Variant.
UDTF_TEMPLATE = """
import org.apache.spark.sql.connect.common.UdfPacket;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Paths;

import java.util.*;
import java.lang.*;
import java.util.stream.Collectors;
import com.snowflake.snowpark_java.types.*;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public class OutputRow {
public Variant __java_udtf_prefix__C1;
public OutputRow(Variant __java_udtf_prefix__C1) {
this.__java_udtf_prefix__C1 = __java_udtf_prefix__C1;
}
}

public class JavaUdtfHandler {
private final static String OPERATION_FILE = "__operation_file__";
private static scala.Function1<scala.collection.Iterator<__iterator_type__>, scala.collection.Iterator<Object>> operation = null;
private static UdfPacket udfPacket = null;

public static Class getOutputClass() { return OutputRow.class; }

private static void loadOperation() throws IOException, ClassNotFoundException {
if (operation != null) {
return; // Already loaded
}

udfPacket = com.snowflake.sas.scala.Utils$.MODULE$.deserializeUdfPacket(OPERATION_FILE);
operation = (scala.Function1<scala.collection.Iterator<__iterator_type__>, scala.collection.Iterator<Object>>) udfPacket.function();
}

public Stream<OutputRow> process(__input_type__ input) throws IOException, ClassNotFoundException {
loadOperation();

__scala_input__

scala.collection.Iterator<Object> scalaResult = operation.apply(scalaInput);

java.util.Iterator<Variant> javaResult = new java.util.Iterator<Variant>() {
public boolean hasNext() { return scalaResult.hasNext(); }
public Variant next() {
return com.snowflake.sas.scala.Utils$.MODULE$.toVariant(scalaResult.next(), udfPacket);
}
};

return StreamSupport.stream(Spliterators.spliteratorUnknownSize(javaResult, Spliterator.ORDERED), false)
.map(i -> new OutputRow(i));
}

public Stream<OutputRow> endPartition() {
return Stream.empty();
}
}
"""
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass(frozen=True)
class JavaUDTFDef:
    """
    Complete definition for creating a Java UDTF in Snowflake.

    Contains all the information needed to generate the CREATE FUNCTION SQL statement
    and the Java code body for the UDTF.

    Attributes:
        name: UDTF name
        signature: SQL signature (for Snowflake function definition)
        java_signature: Java signature (for Java code generation)
        imports: List of JAR files to import
        null_handling: Null handling behavior (defaults to RETURNS_NULL_ON_NULL_INPUT)
    """

    name: str
    signature: Signature
    java_signature: Signature
    imports: list[str]
    null_handling: NullHandling = NullHandling.RETURNS_NULL_ON_NULL_INPUT

    def _gen_body_java(self) -> str:
        """Render UDTF_TEMPLATE into the concrete Java handler source."""
        returns_variant = self.signature.returns.data_type == "VARIANT"
        return_type = (
            "Variant" if returns_variant else self.java_signature.returns.data_type
        )

        # The input routing depends only on the FIRST parameter's Java type.
        is_variant_input = self.java_signature.params[0].data_type.lower() == "variant"

        scala_input_template = (
            SCALA_INPUT_VARIANT if is_variant_input else SCALA_INPUT_SIMPLE_TYPE
        )

        iterator_type = (
            "Object" if is_variant_input else self.java_signature.params[0].data_type
        )

        # imports[0] is the serialized-closure .bin file uploaded by
        # build_jvm_udxf_imports; only its basename is needed at runtime.
        return (
            UDTF_TEMPLATE.replace("__operation_file__", self.imports[0].split("/")[-1])
            .replace("__scala_input__", scala_input_template)
            .replace("__iterator_type__", iterator_type)
            .replace("__input_type__", self.java_signature.params[0].data_type)
            # NOTE(review): UDTF_TEMPLATE contains no "__return_type__"
            # placeholder, so this replace appears to be a no-op — confirm.
            .replace("__return_type__", return_type)
            .replace("__java_udtf_prefix__", JAVA_UDTF_PREFIX)
        )

    def to_create_function_sql(self) -> str:
        """Build the CREATE OR REPLACE FUNCTION statement for this UDTF."""
        args = ", ".join(
            [f"{param.name} {param.data_type}" for param in self.signature.params]
        )

        def quote_single(s: str) -> str:
            """Helper function to wrap strings in single quotes for SQL."""
            return "'" + s + "'"

        # Handler and imports
        imports_sql = f"IMPORTS = ({', '.join(quote_single(x) for x in self.imports)})"

        # The returned table always has a single VARIANT column; callers decode
        # it according to the declared Spark return type.
        return f"""
create or replace function {self.name}({args})
returns table ({JAVA_UDTF_PREFIX}C1 VARIANT)
language java
runtime_version = 17
PACKAGES = ('com.snowflake:snowpark:latest')
{imports_sql}
handler='JavaUdtfHandler'
as
$$
{self._gen_body_java()}
$$;"""
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def create_java_udtf_for_scala_flatmap_handling(
    udf_proto: CommonInlineUserDefinedFunction,
) -> str:
    """
    Register a Java UDTF that wraps a serialized Scala flatMap closure.

    Derives SQL and Java signatures from the proto input/output types (complex
    inputs travel through Snowflake as VARIANT), uploads the closure payload,
    executes the CREATE FUNCTION statement, and returns the generated name.

    Args:
        udf_proto: Spark Connect proto carrying the scalar Scala UDF payload.

    Returns:
        The name of the created UDTF.
    """
    ensure_scala_udf_jars_uploaded()

    scala_udf = udf_proto.scalar_scala_udf
    session = get_or_create_snowpark_session()

    out_type = proto_to_snowpark_type(scala_udf.outputType)
    java_return = map_type_to_java_type(out_type)
    sql_return = map_type_to_snowflake_type(out_type)

    java_params: list[Param] = []
    sql_params: list[Param] = []
    for idx, type_proto in enumerate(scala_udf.inputTypes):
        arg_type = proto_to_snowpark_type(type_proto)
        arg = "arg" + str(idx)

        # Arrays, maps and variants are passed through Snowflake as VARIANT.
        if isinstance(arg_type, (ArrayType, MapType, VariantType)):
            java_params.append(Param(arg, "Variant"))
            sql_params.append(Param(arg, "Variant"))
        else:
            java_params.append(Param(arg, map_type_to_java_type(arg_type)))
            sql_params.append(Param(arg, map_type_to_snowflake_type(arg_type)))

    # Deterministic name derived from the closure bytes, so identical closures
    # map to the same function name.
    udtf_name = (
        JAVA_UDTF_PREFIX + hashlib.md5(scala_udf.payload).hexdigest()
    )

    imports = build_jvm_udxf_imports(
        session,
        scala_udf.payload,
        udtf_name,
    )

    udtf = JavaUDTFDef(
        name=udtf_name,
        signature=Signature(params=sql_params, returns=ReturnType(sql_return)),
        imports=imports,
        java_signature=Signature(params=java_params, returns=ReturnType(java_return)),
    )

    session.sql(udtf.to_create_function_sql()).collect()

    return udtf_name
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import List, Union
|
|
8
|
+
|
|
9
|
+
import snowflake.snowpark.types as snowpark_type
|
|
10
|
+
import snowflake.snowpark_connect.includes.python.pyspark.sql.connect.proto.types_pb2 as types_proto
|
|
11
|
+
from snowflake import snowpark
|
|
12
|
+
from snowflake.snowpark_connect.config import get_scala_version
|
|
13
|
+
from snowflake.snowpark_connect.error.error_codes import ErrorCodes
|
|
14
|
+
from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
|
|
15
|
+
from snowflake.snowpark_connect.resources_initializer import (
|
|
16
|
+
JSON_4S_JAR_212,
|
|
17
|
+
JSON_4S_JAR_213,
|
|
18
|
+
RESOURCE_PATH,
|
|
19
|
+
SAS_SCALA_UDF_JAR_212,
|
|
20
|
+
SAS_SCALA_UDF_JAR_213,
|
|
21
|
+
SCALA_REFLECT_JAR_212,
|
|
22
|
+
SCALA_REFLECT_JAR_213,
|
|
23
|
+
SPARK_COMMON_UTILS_JAR_212,
|
|
24
|
+
SPARK_COMMON_UTILS_JAR_213,
|
|
25
|
+
SPARK_CONNECT_CLIENT_JAR_212,
|
|
26
|
+
SPARK_CONNECT_CLIENT_JAR_213,
|
|
27
|
+
SPARK_SQL_JAR_212,
|
|
28
|
+
SPARK_SQL_JAR_213,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass(frozen=True)
class Param:
    """
    Represents a function parameter with name and data type.

    Attributes:
        name: Parameter name
        data_type: Parameter data type as a string (SQL or Java, depending on
            which signature the Param belongs to)
    """

    name: str
    data_type: str
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class NullHandling(str, Enum):
    """
    Enumeration for UDF null handling behavior.

    Determines how the UDF behaves when input parameters contain null values.
    Members are ``str`` subclasses, so their values can be interpolated
    directly into generated SQL.
    """

    # NOTE: a previous revision decorated this Enum with @dataclass(frozen=True).
    # Applying @dataclass to an Enum generates a field-less __eq__/__hash__,
    # which makes ALL members compare equal and hash identically — the
    # decorator has been removed to restore correct Enum semantics.

    RETURNS_NULL_ON_NULL_INPUT = "RETURNS NULL ON NULL INPUT"
    CALLED_ON_NULL_INPUT = "CALLED ON NULL INPUT"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass(frozen=True)
class ReturnType:
    """
    Represents the return type of a function.

    Attributes:
        data_type: Return data type as a string
    """

    data_type: str
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@dataclass(frozen=True)
class Signature:
    """
    Represents a function signature with parameters and return type.

    Attributes:
        params: List of function parameters
        returns: Function return type
    """

    params: List[Param]
    returns: ReturnType
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def build_jvm_udxf_imports(
    session: snowpark.Session, payload: bytes, udf_name: str
) -> List[str]:
    """
    Build the list of imports needed for the JVM UDxF.

    This function:
    1. Saves the UDF payload to a binary file in the session stage
    2. Collects user-uploaded JAR files from the stage
    3. Returns a list of all required JAR files for the UDxF

    Args:
        session: Snowpark session
        payload: Binary payload containing the serialized Scala UDF
        udf_name: Name of the Scala UDF (used for the binary file name)

    Returns:
        List of JAR file paths to be imported by the UDxF
    """
    import io

    stage_resource_path = session.get_session_stage() + RESOURCE_PATH
    closure_binary_file = f"{stage_resource_path}/scala/bin/{udf_name}.bin"

    # Persist the serialized closure so the generated UDxF can load it at runtime.
    session.file.put_stream(
        io.BytesIO(payload),
        closure_binary_file,
        overwrite=True,
    )

    # Closure first, then the static Scala/Spark jars, then any jars the user
    # attached to the session — formatted for the IMPORTS clause.
    static_jars = _scala_static_imports_for_udf(stage_resource_path)
    user_jars = list(session._artifact_jars)
    return [closure_binary_file] + static_jars + user_jars
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _scala_static_imports_for_udf(stage_resource_path: str) -> list[str]:
    """Return the static jar imports for the configured Scala version.

    Raises:
        ValueError: if the configured Scala version is not 2.12 or 2.13
            (tagged with ``ErrorCodes.INVALID_CONFIG_VALUE``).
    """
    # Jar sets are identical in shape; only the Scala binary version differs.
    jars_by_version = {
        "2.12": [
            SPARK_CONNECT_CLIENT_JAR_212,
            SPARK_COMMON_UTILS_JAR_212,
            SPARK_SQL_JAR_212,
            JSON_4S_JAR_212,
            SAS_SCALA_UDF_JAR_212,
            SCALA_REFLECT_JAR_212,  # Required for deserializing Scala lambdas
        ],
        "2.13": [
            SPARK_CONNECT_CLIENT_JAR_213,
            SPARK_COMMON_UTILS_JAR_213,
            SPARK_SQL_JAR_213,
            JSON_4S_JAR_213,
            SAS_SCALA_UDF_JAR_213,
            SCALA_REFLECT_JAR_213,  # Required for deserializing Scala lambdas
        ],
    }

    scala_version = get_scala_version()
    jars = jars_by_version.get(scala_version)
    if jars is None:
        # invalid Scala version
        exception = ValueError(
            f"Unsupported Scala version: {scala_version}. Snowpark Connect supports Scala 2.12 and 2.13"
        )
        attach_custom_error_code(exception, ErrorCodes.INVALID_CONFIG_VALUE)
        raise exception

    return [f"{stage_resource_path}/{jar}" for jar in jars]
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def map_type_to_java_type(
    t: Union[snowpark_type.DataType, types_proto.DataType]
) -> str:
    """Maps a Snowpark or Spark protobuf type to a Java type string.

    Falsy input (``None`` or an empty proto message) defaults to ``String``.

    Raises:
        ValueError: for types with no Java mapping (tagged with
            ``ErrorCodes.UNSUPPORTED_TYPE``).
    """
    if not t:
        return "String"
    is_snowpark_type = isinstance(t, snowpark_type.DataType)
    # Snowpark types are matched on the Python class itself; proto types on the
    # populated oneof kind name. Both arms are value patterns (compared via ==).
    condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
    match condition:
        case snowpark_type.ArrayType | "array":
            # Recurse on the element type; the two representations expose it
            # under different attribute paths.
            return (
                f"{map_type_to_java_type(t.element_type)}[]"
                if is_snowpark_type
                else f"{map_type_to_java_type(t.array.element_type)}[]"
            )
        case snowpark_type.BinaryType | "binary":
            return "byte[]"
        case snowpark_type.BooleanType | "boolean":
            return "Boolean"
        case snowpark_type.ByteType | "byte":
            return "Byte"
        case snowpark_type.DateType | "date":
            return "java.sql.Date"
        case snowpark_type.DecimalType | "decimal":
            return "java.math.BigDecimal"
        case snowpark_type.DoubleType | "double":
            return "Double"
        case snowpark_type.FloatType | "float":
            return "Float"
        case snowpark_type.GeographyType:
            # Snowpark-only; the Spark proto has no geography kind.
            return "Geography"
        case snowpark_type.IntegerType | "integer":
            return "Integer"
        case snowpark_type.LongType | "long":
            return "Long"
        case snowpark_type.MapType | "map":  # can also map to OBJECT in Snowflake
            key_type = (
                map_type_to_java_type(t.key_type)
                if is_snowpark_type
                else map_type_to_java_type(t.map.key_type)
            )
            value_type = (
                map_type_to_java_type(t.value_type)
                if is_snowpark_type
                else map_type_to_java_type(t.map.value_type)
            )
            return f"Map<{key_type}, {value_type}>"
        case snowpark_type.NullType | "null":
            return "String"  # cannot set the return type to Null in Snowpark Java UDAFs
        case snowpark_type.ShortType | "short":
            return "Short"
        case snowpark_type.StringType | "string" | "char" | "varchar":
            return "String"
        case snowpark_type.StructType | "struct":
            # Structs are carried through Snowflake as VARIANT.
            return "Variant"
        case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
            return "java.sql.Timestamp"
        case snowpark_type.VariantType:
            # Snowpark-only; the Spark proto has no variant kind.
            return "Variant"
        case _:
            exception = ValueError(f"Unsupported Snowpark type: {t}")
            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
            raise exception
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def cast_java_map_args_from_given_type(
|
|
221
|
+
arg_name: str, input_type: Union[snowpark_type.DataType, types_proto.DataType]
|
|
222
|
+
) -> str:
|
|
223
|
+
"""If the input_type is a Map or Struct, cast the argument arg_name to the correct type in Java."""
|
|
224
|
+
is_snowpark_type = isinstance(input_type, snowpark_type.DataType)
|
|
225
|
+
|
|
226
|
+
def convert_from_string_to_type(
|
|
227
|
+
arg_name: str, t: Union[snowpark_type.DataType, types_proto.DataType]
|
|
228
|
+
) -> str:
|
|
229
|
+
"""Convert the string argument arg_name to the specified type t in Java."""
|
|
230
|
+
condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
|
|
231
|
+
match condition:
|
|
232
|
+
case snowpark_type.BinaryType | "binary":
|
|
233
|
+
return arg_name + ".getBytes()"
|
|
234
|
+
case snowpark_type.BooleanType | "boolean":
|
|
235
|
+
return f"Boolean.valueOf({arg_name}"
|
|
236
|
+
case snowpark_type.ByteType | "byte":
|
|
237
|
+
return arg_name + ".getBytes()[0]" # TODO: verify if this is correct
|
|
238
|
+
case snowpark_type.DateType | "date":
|
|
239
|
+
return f"java.sql.Date.valueOf({arg_name})"
|
|
240
|
+
case snowpark_type.DecimalType | "decimal":
|
|
241
|
+
return f"new BigDecimal({arg_name})"
|
|
242
|
+
case snowpark_type.DoubleType | "double":
|
|
243
|
+
return f"Double.valueOf({arg_name}"
|
|
244
|
+
case snowpark_type.FloatType | "float":
|
|
245
|
+
return f"Float.valueOf({arg_name}"
|
|
246
|
+
case snowpark_type.IntegerType | "integer":
|
|
247
|
+
return f"Integer.valueOf({arg_name}"
|
|
248
|
+
case snowpark_type.LongType | "long":
|
|
249
|
+
return f"Long.valueOf({arg_name}"
|
|
250
|
+
case snowpark_type.ShortType | "short":
|
|
251
|
+
return f"Short.valueOf({arg_name}"
|
|
252
|
+
case snowpark_type.StringType | "string" | "char" | "varchar":
|
|
253
|
+
return arg_name
|
|
254
|
+
case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
|
|
255
|
+
return f"java.sql.Timestamp.valueOf({arg_name})" # todo add test
|
|
256
|
+
case _:
|
|
257
|
+
exception = ValueError(f"Unsupported Snowpark type: {t}")
|
|
258
|
+
attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
|
|
259
|
+
raise exception
|
|
260
|
+
|
|
261
|
+
if (is_snowpark_type and isinstance(input_type, snowpark_type.MapType)) or (
|
|
262
|
+
not is_snowpark_type and input_type.WhichOneof("kind") == "map"
|
|
263
|
+
):
|
|
264
|
+
key_type = input_type.key_type if is_snowpark_type else input_type.map.key_type
|
|
265
|
+
value_type = (
|
|
266
|
+
input_type.value_type if is_snowpark_type else input_type.map.value_type
|
|
267
|
+
)
|
|
268
|
+
key_converter = "{" + convert_from_string_to_type("e.getKey()", key_type) + "}"
|
|
269
|
+
value_converter = (
|
|
270
|
+
"{" + convert_from_string_to_type("e.getValue()", value_type) + "}"
|
|
271
|
+
)
|
|
272
|
+
return f"""
|
|
273
|
+
{arg_name}.entrySet()
|
|
274
|
+
.stream()
|
|
275
|
+
.collect(Collectors.toMap(
|
|
276
|
+
e -> {key_converter},
|
|
277
|
+
e -> {value_converter}
|
|
278
|
+
));
|
|
279
|
+
"""
|
|
280
|
+
else:
|
|
281
|
+
return arg_name
|