snowpark-connect 0.27.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  3. snowflake/snowpark_connect/client/__init__.py +15 -0
  4. snowflake/snowpark_connect/client/error_utils.py +30 -0
  5. snowflake/snowpark_connect/client/exceptions.py +36 -0
  6. snowflake/snowpark_connect/client/query_results.py +90 -0
  7. snowflake/snowpark_connect/client/server.py +717 -0
  8. snowflake/snowpark_connect/client/utils/__init__.py +10 -0
  9. snowflake/snowpark_connect/client/utils/session.py +85 -0
  10. snowflake/snowpark_connect/column_name_handler.py +404 -243
  11. snowflake/snowpark_connect/column_qualifier.py +43 -0
  12. snowflake/snowpark_connect/config.py +309 -26
  13. snowflake/snowpark_connect/constants.py +2 -0
  14. snowflake/snowpark_connect/dataframe_container.py +102 -8
  15. snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
  16. snowflake/snowpark_connect/error/error_codes.py +50 -0
  17. snowflake/snowpark_connect/error/error_utils.py +172 -23
  18. snowflake/snowpark_connect/error/exceptions.py +13 -4
  19. snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
  20. snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
  21. snowflake/snowpark_connect/execute_plan/utils.py +5 -1
  22. snowflake/snowpark_connect/expression/error_utils.py +28 -0
  23. snowflake/snowpark_connect/expression/function_defaults.py +9 -2
  24. snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
  25. snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
  26. snowflake/snowpark_connect/expression/literal.py +37 -13
  27. snowflake/snowpark_connect/expression/map_cast.py +224 -15
  28. snowflake/snowpark_connect/expression/map_expression.py +80 -27
  29. snowflake/snowpark_connect/expression/map_extension.py +322 -12
  30. snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
  31. snowflake/snowpark_connect/expression/map_udf.py +86 -20
  32. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
  33. snowflake/snowpark_connect/expression/map_unresolved_function.py +2964 -829
  34. snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
  35. snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
  36. snowflake/snowpark_connect/expression/map_window_function.py +18 -3
  37. snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
  39. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/{spark-connect-client-jvm_2.12-3.5.6.jar → spark-connect-client-jvm_2.13-3.5.6.jar} +0 -0
  43. snowflake/snowpark_connect/includes/jars/{spark-sql_2.12-3.5.6.jar → spark-sql_2.13-3.5.6.jar} +0 -0
  44. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
  45. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
  46. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  47. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  48. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
  49. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
  50. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
  51. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
  52. snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
  53. snowflake/snowpark_connect/relation/io_utils.py +110 -10
  54. snowflake/snowpark_connect/relation/map_aggregate.py +239 -256
  55. snowflake/snowpark_connect/relation/map_catalog.py +5 -1
  56. snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
  57. snowflake/snowpark_connect/relation/map_extension.py +263 -29
  58. snowflake/snowpark_connect/relation/map_join.py +683 -442
  59. snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
  60. snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
  61. snowflake/snowpark_connect/relation/map_relation.py +48 -19
  62. snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
  63. snowflake/snowpark_connect/relation/map_show_string.py +13 -6
  64. snowflake/snowpark_connect/relation/map_sql.py +1233 -222
  65. snowflake/snowpark_connect/relation/map_stats.py +48 -9
  66. snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
  67. snowflake/snowpark_connect/relation/map_udtf.py +14 -4
  68. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
  69. snowflake/snowpark_connect/relation/read/map_read.py +134 -43
  70. snowflake/snowpark_connect/relation/read/map_read_csv.py +326 -47
  71. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +21 -6
  72. snowflake/snowpark_connect/relation/read/map_read_json.py +324 -86
  73. snowflake/snowpark_connect/relation/read/map_read_parquet.py +146 -28
  74. snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
  75. snowflake/snowpark_connect/relation/read/map_read_socket.py +15 -3
  76. snowflake/snowpark_connect/relation/read/map_read_table.py +86 -6
  77. snowflake/snowpark_connect/relation/read/map_read_text.py +22 -4
  78. snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
  79. snowflake/snowpark_connect/relation/read/reader_config.py +42 -3
  80. snowflake/snowpark_connect/relation/read/utils.py +50 -5
  81. snowflake/snowpark_connect/relation/stage_locator.py +91 -55
  82. snowflake/snowpark_connect/relation/utils.py +128 -5
  83. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
  84. snowflake/snowpark_connect/relation/write/map_write.py +929 -319
  85. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
  86. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  87. snowflake/snowpark_connect/resources_initializer.py +171 -48
  88. snowflake/snowpark_connect/server.py +528 -473
  89. snowflake/snowpark_connect/server_common/__init__.py +503 -0
  90. snowflake/snowpark_connect/snowflake_session.py +65 -0
  91. snowflake/snowpark_connect/start_server.py +53 -5
  92. snowflake/snowpark_connect/type_mapping.py +349 -27
  93. snowflake/snowpark_connect/type_support.py +130 -0
  94. snowflake/snowpark_connect/typed_column.py +9 -7
  95. snowflake/snowpark_connect/utils/artifacts.py +9 -8
  96. snowflake/snowpark_connect/utils/cache.py +49 -27
  97. snowflake/snowpark_connect/utils/concurrent.py +36 -1
  98. snowflake/snowpark_connect/utils/context.py +195 -37
  99. snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
  100. snowflake/snowpark_connect/utils/env_utils.py +5 -1
  101. snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
  102. snowflake/snowpark_connect/utils/identifiers.py +137 -3
  103. snowflake/snowpark_connect/utils/io_utils.py +57 -1
  104. snowflake/snowpark_connect/utils/java_stored_procedure.py +151 -0
  105. snowflake/snowpark_connect/utils/java_udaf_utils.py +321 -0
  106. snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
  107. snowflake/snowpark_connect/utils/jvm_udf_utils.py +281 -0
  108. snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
  109. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
  110. snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
  111. snowflake/snowpark_connect/utils/profiling.py +25 -8
  112. snowflake/snowpark_connect/utils/scala_udf_utils.py +185 -340
  113. snowflake/snowpark_connect/utils/sequence.py +21 -0
  114. snowflake/snowpark_connect/utils/session.py +64 -28
  115. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
  116. snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
  117. snowflake/snowpark_connect/utils/telemetry.py +192 -40
  118. snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
  119. snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
  120. snowflake/snowpark_connect/utils/udf_cache.py +117 -41
  121. snowflake/snowpark_connect/utils/udf_helper.py +39 -37
  122. snowflake/snowpark_connect/utils/udf_utils.py +133 -14
  123. snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
  124. snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
  125. snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
  126. snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
  127. snowflake/snowpark_connect/version.py +1 -1
  128. snowflake/snowpark_decoder/dp_session.py +6 -2
  129. snowflake/snowpark_decoder/spark_decoder.py +12 -0
  130. {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +14 -4
  131. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +16 -7
  132. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +139 -168
  133. snowflake/snowpark_connect/hidden_column.py +0 -39
  134. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  135. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  136. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  137. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  138. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  139. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  140. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  141. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  142. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  143. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  144. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  145. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  146. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  147. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  148. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  149. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  150. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  151. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  152. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  153. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  154. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  155. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  156. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  157. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  158. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  159. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  160. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  161. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  162. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  163. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  164. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  165. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  166. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  167. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  168. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  169. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  170. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  171. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  172. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  173. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  174. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  175. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  176. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  177. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  178. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  179. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  180. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  181. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  182. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  183. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  184. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  185. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  186. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  187. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  188. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  189. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  190. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  191. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  192. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  193. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  194. {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
  195. {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
  196. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
  197. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
  198. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  199. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
  200. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,239 @@
1
+ #
2
+ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3
+ #
4
+
5
+ import hashlib
6
+ from dataclasses import dataclass
7
+
8
+ from pyspark.sql.connect.proto.expressions_pb2 import CommonInlineUserDefinedFunction
9
+
10
+ from snowflake.snowpark.types import ArrayType, MapType, VariantType
11
+ from snowflake.snowpark_connect.resources_initializer import (
12
+ ensure_scala_udf_jars_uploaded,
13
+ )
14
+ from snowflake.snowpark_connect.type_mapping import (
15
+ map_type_to_snowflake_type,
16
+ proto_to_snowpark_type,
17
+ )
18
+ from snowflake.snowpark_connect.utils.jvm_udf_utils import (
19
+ NullHandling,
20
+ Param,
21
+ ReturnType,
22
+ Signature,
23
+ build_jvm_udxf_imports,
24
+ map_type_to_java_type,
25
+ )
26
+ from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
27
+
28
+ JAVA_UDTF_PREFIX = "__SC_JAVA_UDTF_"
29
+
30
+ SCALA_INPUT_VARIANT = """
31
+ Object mappedInput = com.snowflake.sas.scala.UdfPacketUtils$.MODULE$.fromVariant(udfPacket, input, 0);
32
+
33
+ java.util.Iterator<Object> javaInput = Arrays.asList(mappedInput).iterator();
34
+ scala.collection.Iterator<Object> scalaInput = new scala.collection.AbstractIterator<Object>() {
35
+ public boolean hasNext() { return javaInput.hasNext(); }
36
+ public Object next() { return javaInput.next(); }
37
+ };
38
+ """
39
+
40
+ SCALA_INPUT_SIMPLE_TYPE = """
41
+ java.util.Iterator<__iterator_type__> javaInput = Arrays.asList(input).iterator();
42
+ scala.collection.Iterator<__iterator_type__> scalaInput = new scala.collection.AbstractIterator<__iterator_type__>() {
43
+ public boolean hasNext() { return javaInput.hasNext(); }
44
+ public __iterator_type__ next() { return javaInput.next(); }
45
+ };
46
+ """
47
+
48
+ UDTF_TEMPLATE = """
49
+ import org.apache.spark.sql.connect.common.UdfPacket;
50
+
51
+ import java.io.IOException;
52
+ import java.io.InputStream;
53
+ import java.io.ObjectInputStream;
54
+ import java.io.Serializable;
55
+ import java.nio.file.Files;
56
+ import java.nio.file.Paths;
57
+
58
+ import java.util.*;
59
+ import java.lang.*;
60
+ import java.util.stream.Collectors;
61
+ import com.snowflake.snowpark_java.types.*;
62
+ import java.util.stream.Stream;
63
+ import java.util.stream.StreamSupport;
64
+
65
+ public class OutputRow {
66
+ public Variant __java_udtf_prefix__C1;
67
+ public OutputRow(Variant __java_udtf_prefix__C1) {
68
+ this.__java_udtf_prefix__C1 = __java_udtf_prefix__C1;
69
+ }
70
+ }
71
+
72
+ public class JavaUdtfHandler {
73
+ private final static String OPERATION_FILE = "__operation_file__";
74
+ private static scala.Function1<scala.collection.Iterator<__iterator_type__>, scala.collection.Iterator<Object>> operation = null;
75
+ private static UdfPacket udfPacket = null;
76
+
77
+ public static Class getOutputClass() { return OutputRow.class; }
78
+
79
+ private static void loadOperation() throws IOException, ClassNotFoundException {
80
+ if (operation != null) {
81
+ return; // Already loaded
82
+ }
83
+
84
+ udfPacket = com.snowflake.sas.scala.Utils$.MODULE$.deserializeUdfPacket(OPERATION_FILE);
85
+ operation = (scala.Function1<scala.collection.Iterator<__iterator_type__>, scala.collection.Iterator<Object>>) udfPacket.function();
86
+ }
87
+
88
+ public Stream<OutputRow> process(__input_type__ input) throws IOException, ClassNotFoundException {
89
+ loadOperation();
90
+
91
+ __scala_input__
92
+
93
+ scala.collection.Iterator<Object> scalaResult = operation.apply(scalaInput);
94
+
95
+ java.util.Iterator<Variant> javaResult = new java.util.Iterator<Variant>() {
96
+ public boolean hasNext() { return scalaResult.hasNext(); }
97
+ public Variant next() {
98
+ return com.snowflake.sas.scala.Utils$.MODULE$.toVariant(scalaResult.next(), udfPacket);
99
+ }
100
+ };
101
+
102
+ return StreamSupport.stream(Spliterators.spliteratorUnknownSize(javaResult, Spliterator.ORDERED), false)
103
+ .map(i -> new OutputRow(i));
104
+ }
105
+
106
+ public Stream<OutputRow> endPartition() {
107
+ return Stream.empty();
108
+ }
109
+ }
110
+ """
111
+
112
+
113
+ @dataclass(frozen=True)
114
+ class JavaUDTFDef:
115
+ """
116
+ Complete definition for creating a Java UDTF in Snowflake.
117
+
118
+ Contains all the information needed to generate the CREATE FUNCTION SQL statement
119
+ and the Java code body for the UDTF.
120
+
121
+ Attributes:
122
+ name: UDTF name
123
+ signature: SQL signature (for Snowflake function definition)
124
+ java_signature: Java signature (for Java code generation)
125
+ imports: List of JAR files to import
126
+ null_handling: Null handling behavior (defaults to RETURNS_NULL_ON_NULL_INPUT)
127
+ """
128
+
129
+ name: str
130
+ signature: Signature
131
+ java_signature: Signature
132
+ imports: list[str]
133
+ null_handling: NullHandling = NullHandling.RETURNS_NULL_ON_NULL_INPUT
134
+
135
+ def _gen_body_java(self) -> str:
136
+ returns_variant = self.signature.returns.data_type == "VARIANT"
137
+ return_type = (
138
+ "Variant" if returns_variant else self.java_signature.returns.data_type
139
+ )
140
+
141
+ is_variant_input = self.java_signature.params[0].data_type.lower() == "variant"
142
+
143
+ scala_input_template = (
144
+ SCALA_INPUT_VARIANT if is_variant_input else SCALA_INPUT_SIMPLE_TYPE
145
+ )
146
+
147
+ iterator_type = (
148
+ "Object" if is_variant_input else self.java_signature.params[0].data_type
149
+ )
150
+
151
+ return (
152
+ UDTF_TEMPLATE.replace("__operation_file__", self.imports[0].split("/")[-1])
153
+ .replace("__scala_input__", scala_input_template)
154
+ .replace("__iterator_type__", iterator_type)
155
+ .replace("__input_type__", self.java_signature.params[0].data_type)
156
+ .replace("__return_type__", return_type)
157
+ .replace("__java_udtf_prefix__", JAVA_UDTF_PREFIX)
158
+ )
159
+
160
+ def to_create_function_sql(self) -> str:
161
+ args = ", ".join(
162
+ [f"{param.name} {param.data_type}" for param in self.signature.params]
163
+ )
164
+
165
+ def quote_single(s: str) -> str:
166
+ """Helper function to wrap strings in single quotes for SQL."""
167
+ return "'" + s + "'"
168
+
169
+ # Handler and imports
170
+ imports_sql = f"IMPORTS = ({', '.join(quote_single(x) for x in self.imports)})"
171
+
172
+ return f"""
173
+ create or replace function {self.name}({args})
174
+ returns table ({JAVA_UDTF_PREFIX}C1 VARIANT)
175
+ language java
176
+ runtime_version = 17
177
+ PACKAGES = ('com.snowflake:snowpark:latest')
178
+ {imports_sql}
179
+ handler='JavaUdtfHandler'
180
+ as
181
+ $$
182
+ {self._gen_body_java()}
183
+ $$;"""
184
+
185
+
186
+ def create_java_udtf_for_scala_flatmap_handling(
187
+ udf_proto: CommonInlineUserDefinedFunction,
188
+ ) -> str:
189
+ ensure_scala_udf_jars_uploaded()
190
+
191
+ return_type = proto_to_snowpark_type(udf_proto.scalar_scala_udf.outputType)
192
+
193
+ session = get_or_create_snowpark_session()
194
+
195
+ return_type_java = map_type_to_java_type(return_type)
196
+ sql_return_type = map_type_to_snowflake_type(return_type)
197
+
198
+ java_input_params: list[Param] = []
199
+ sql_input_params: list[Param] = []
200
+ for i, input_type_proto in enumerate(udf_proto.scalar_scala_udf.inputTypes):
201
+ input_type = proto_to_snowpark_type(input_type_proto)
202
+
203
+ param_name = "arg" + str(i)
204
+
205
+ if isinstance(input_type, (ArrayType, MapType, VariantType)):
206
+ java_type = "Variant"
207
+ snowflake_type = "Variant"
208
+ else:
209
+ java_type = map_type_to_java_type(input_type)
210
+ snowflake_type = map_type_to_snowflake_type(input_type)
211
+
212
+ java_input_params.append(Param(param_name, java_type))
213
+ sql_input_params.append(Param(param_name, snowflake_type))
214
+
215
+ udtf_name = (
216
+ JAVA_UDTF_PREFIX + hashlib.md5(udf_proto.scalar_scala_udf.payload).hexdigest()
217
+ )
218
+
219
+ imports = build_jvm_udxf_imports(
220
+ session,
221
+ udf_proto.scalar_scala_udf.payload,
222
+ udtf_name,
223
+ )
224
+
225
+ udtf = JavaUDTFDef(
226
+ name=udtf_name,
227
+ signature=Signature(
228
+ params=sql_input_params, returns=ReturnType(sql_return_type)
229
+ ),
230
+ imports=imports,
231
+ java_signature=Signature(
232
+ params=java_input_params, returns=ReturnType(return_type_java)
233
+ ),
234
+ )
235
+
236
+ sql = udtf.to_create_function_sql()
237
+ session.sql(sql).collect()
238
+
239
+ return udtf_name
@@ -0,0 +1,281 @@
1
+ #
2
+ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3
+ #
4
+
5
+ from dataclasses import dataclass
6
+ from enum import Enum
7
+ from typing import List, Union
8
+
9
+ import snowflake.snowpark.types as snowpark_type
10
+ import snowflake.snowpark_connect.includes.python.pyspark.sql.connect.proto.types_pb2 as types_proto
11
+ from snowflake import snowpark
12
+ from snowflake.snowpark_connect.config import get_scala_version
13
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
14
+ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
15
+ from snowflake.snowpark_connect.resources_initializer import (
16
+ JSON_4S_JAR_212,
17
+ JSON_4S_JAR_213,
18
+ RESOURCE_PATH,
19
+ SAS_SCALA_UDF_JAR_212,
20
+ SAS_SCALA_UDF_JAR_213,
21
+ SCALA_REFLECT_JAR_212,
22
+ SCALA_REFLECT_JAR_213,
23
+ SPARK_COMMON_UTILS_JAR_212,
24
+ SPARK_COMMON_UTILS_JAR_213,
25
+ SPARK_CONNECT_CLIENT_JAR_212,
26
+ SPARK_CONNECT_CLIENT_JAR_213,
27
+ SPARK_SQL_JAR_212,
28
+ SPARK_SQL_JAR_213,
29
+ )
30
+
31
+
32
+ @dataclass(frozen=True)
33
+ class Param:
34
+ """
35
+ Represents a function parameter with name and data type.
36
+
37
+ Attributes:
38
+ name: Parameter name
39
+ data_type: Parameter data type as a string
40
+ """
41
+
42
+ name: str
43
+ data_type: str
44
+
45
+
46
+ @dataclass(frozen=True)
47
+ class NullHandling(str, Enum):
48
+ """
49
+ Enumeration for UDF null handling behavior.
50
+
51
+ Determines how the UDF behaves when input parameters contain null values.
52
+ """
53
+
54
+ RETURNS_NULL_ON_NULL_INPUT = "RETURNS NULL ON NULL INPUT"
55
+ CALLED_ON_NULL_INPUT = "CALLED ON NULL INPUT"
56
+
57
+
58
+ @dataclass(frozen=True)
59
+ class ReturnType:
60
+ """
61
+ Represents the return type of a function.
62
+
63
+ Attributes:
64
+ data_type: Return data type as a string
65
+ """
66
+
67
+ data_type: str
68
+
69
+
70
+ @dataclass(frozen=True)
71
+ class Signature:
72
+ """
73
+ Represents a function signature with parameters and return type.
74
+
75
+ Attributes:
76
+ params: List of function parameters
77
+ returns: Function return type
78
+ """
79
+
80
+ params: List[Param]
81
+ returns: ReturnType
82
+
83
+
84
+ def build_jvm_udxf_imports(
85
+ session: snowpark.Session, payload: bytes, udf_name: str
86
+ ) -> List[str]:
87
+ """
88
+ Build the list of imports needed for the JVM UDxF.
89
+
90
+ This function:
91
+ 1. Saves the UDF payload to a binary file in the session stage
92
+ 2. Collects user-uploaded JAR files from the stage
93
+ 3. Returns a list of all required JAR files for the UDxF
94
+
95
+ Args:
96
+ session: Snowpark session
97
+ payload: Binary payload containing the serialized Scala UDF
98
+ udf_name: Name of the Scala UDF (used for the binary file name)
99
+ is_map_return: Indicates if the UDxF returns a Map (affects imports)
100
+
101
+ Returns:
102
+ List of JAR file paths to be imported by the UDxF
103
+ """
104
+ # Save pciudf._payload to a bin file:
105
+ import io
106
+
107
+ payload_as_stream = io.BytesIO(payload)
108
+ stage = session.get_session_stage()
109
+ stage_resource_path = stage + RESOURCE_PATH
110
+ closure_binary_file = stage_resource_path + "/scala/bin/" + udf_name + ".bin"
111
+ session.file.put_stream(
112
+ payload_as_stream,
113
+ closure_binary_file,
114
+ overwrite=True,
115
+ )
116
+
117
+ # Format the user jars to be used in the IMPORTS clause of the stored procedure.
118
+ return (
119
+ [closure_binary_file]
120
+ + _scala_static_imports_for_udf(stage_resource_path)
121
+ + list(session._artifact_jars)
122
+ )
123
+
124
+
125
+ def _scala_static_imports_for_udf(stage_resource_path: str) -> list[str]:
126
+ scala_version = get_scala_version()
127
+ if scala_version == "2.12":
128
+ return [
129
+ f"{stage_resource_path}/{SPARK_CONNECT_CLIENT_JAR_212}",
130
+ f"{stage_resource_path}/{SPARK_COMMON_UTILS_JAR_212}",
131
+ f"{stage_resource_path}/{SPARK_SQL_JAR_212}",
132
+ f"{stage_resource_path}/{JSON_4S_JAR_212}",
133
+ f"{stage_resource_path}/{SAS_SCALA_UDF_JAR_212}",
134
+ f"{stage_resource_path}/{SCALA_REFLECT_JAR_212}", # Required for deserializing Scala lambdas
135
+ ]
136
+
137
+ if scala_version == "2.13":
138
+ return [
139
+ f"{stage_resource_path}/{SPARK_CONNECT_CLIENT_JAR_213}",
140
+ f"{stage_resource_path}/{SPARK_COMMON_UTILS_JAR_213}",
141
+ f"{stage_resource_path}/{SPARK_SQL_JAR_213}",
142
+ f"{stage_resource_path}/{JSON_4S_JAR_213}",
143
+ f"{stage_resource_path}/{SAS_SCALA_UDF_JAR_213}",
144
+ f"{stage_resource_path}/{SCALA_REFLECT_JAR_213}", # Required for deserializing Scala lambdas
145
+ ]
146
+
147
+ # invalid Scala version
148
+ exception = ValueError(
149
+ f"Unsupported Scala version: {scala_version}. Snowpark Connect supports Scala 2.12 and 2.13"
150
+ )
151
+ attach_custom_error_code(exception, ErrorCodes.INVALID_CONFIG_VALUE)
152
+ raise exception
153
+
154
+
155
+ def map_type_to_java_type(
156
+ t: Union[snowpark_type.DataType, types_proto.DataType]
157
+ ) -> str:
158
+ """Maps a Snowpark or Spark protobuf type to a Java type string."""
159
+ if not t:
160
+ return "String"
161
+ is_snowpark_type = isinstance(t, snowpark_type.DataType)
162
+ condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
163
+ match condition:
164
+ case snowpark_type.ArrayType | "array":
165
+ return (
166
+ f"{map_type_to_java_type(t.element_type)}[]"
167
+ if is_snowpark_type
168
+ else f"{map_type_to_java_type(t.array.element_type)}[]"
169
+ )
170
+ case snowpark_type.BinaryType | "binary":
171
+ return "byte[]"
172
+ case snowpark_type.BooleanType | "boolean":
173
+ return "Boolean"
174
+ case snowpark_type.ByteType | "byte":
175
+ return "Byte"
176
+ case snowpark_type.DateType | "date":
177
+ return "java.sql.Date"
178
+ case snowpark_type.DecimalType | "decimal":
179
+ return "java.math.BigDecimal"
180
+ case snowpark_type.DoubleType | "double":
181
+ return "Double"
182
+ case snowpark_type.FloatType | "float":
183
+ return "Float"
184
+ case snowpark_type.GeographyType:
185
+ return "Geography"
186
+ case snowpark_type.IntegerType | "integer":
187
+ return "Integer"
188
+ case snowpark_type.LongType | "long":
189
+ return "Long"
190
+ case snowpark_type.MapType | "map": # can also map to OBJECT in Snowflake
191
+ key_type = (
192
+ map_type_to_java_type(t.key_type)
193
+ if is_snowpark_type
194
+ else map_type_to_java_type(t.map.key_type)
195
+ )
196
+ value_type = (
197
+ map_type_to_java_type(t.value_type)
198
+ if is_snowpark_type
199
+ else map_type_to_java_type(t.map.value_type)
200
+ )
201
+ return f"Map<{key_type}, {value_type}>"
202
+ case snowpark_type.NullType | "null":
203
+ return "String" # cannot set the return type to Null in Snowpark Java UDAFs
204
+ case snowpark_type.ShortType | "short":
205
+ return "Short"
206
+ case snowpark_type.StringType | "string" | "char" | "varchar":
207
+ return "String"
208
+ case snowpark_type.StructType | "struct":
209
+ return "Variant"
210
+ case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
211
+ return "java.sql.Timestamp"
212
+ case snowpark_type.VariantType:
213
+ return "Variant"
214
+ case _:
215
+ exception = ValueError(f"Unsupported Snowpark type: {t}")
216
+ attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
217
+ raise exception
218
+
219
+
220
+ def cast_java_map_args_from_given_type(
221
+ arg_name: str, input_type: Union[snowpark_type.DataType, types_proto.DataType]
222
+ ) -> str:
223
+ """If the input_type is a Map or Struct, cast the argument arg_name to the correct type in Java."""
224
+ is_snowpark_type = isinstance(input_type, snowpark_type.DataType)
225
+
226
+ def convert_from_string_to_type(
227
+ arg_name: str, t: Union[snowpark_type.DataType, types_proto.DataType]
228
+ ) -> str:
229
+ """Convert the string argument arg_name to the specified type t in Java."""
230
+ condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
231
+ match condition:
232
+ case snowpark_type.BinaryType | "binary":
233
+ return arg_name + ".getBytes()"
234
+ case snowpark_type.BooleanType | "boolean":
235
+ return f"Boolean.valueOf({arg_name}"
236
+ case snowpark_type.ByteType | "byte":
237
+ return arg_name + ".getBytes()[0]" # TODO: verify if this is correct
238
+ case snowpark_type.DateType | "date":
239
+ return f"java.sql.Date.valueOf({arg_name})"
240
+ case snowpark_type.DecimalType | "decimal":
241
+ return f"new BigDecimal({arg_name})"
242
+ case snowpark_type.DoubleType | "double":
243
+ return f"Double.valueOf({arg_name}"
244
+ case snowpark_type.FloatType | "float":
245
+ return f"Float.valueOf({arg_name}"
246
+ case snowpark_type.IntegerType | "integer":
247
+ return f"Integer.valueOf({arg_name}"
248
+ case snowpark_type.LongType | "long":
249
+ return f"Long.valueOf({arg_name}"
250
+ case snowpark_type.ShortType | "short":
251
+ return f"Short.valueOf({arg_name}"
252
+ case snowpark_type.StringType | "string" | "char" | "varchar":
253
+ return arg_name
254
+ case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
255
+ return f"java.sql.Timestamp.valueOf({arg_name})" # todo add test
256
+ case _:
257
+ exception = ValueError(f"Unsupported Snowpark type: {t}")
258
+ attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
259
+ raise exception
260
+
261
+ if (is_snowpark_type and isinstance(input_type, snowpark_type.MapType)) or (
262
+ not is_snowpark_type and input_type.WhichOneof("kind") == "map"
263
+ ):
264
+ key_type = input_type.key_type if is_snowpark_type else input_type.map.key_type
265
+ value_type = (
266
+ input_type.value_type if is_snowpark_type else input_type.map.value_type
267
+ )
268
+ key_converter = "{" + convert_from_string_to_type("e.getKey()", key_type) + "}"
269
+ value_converter = (
270
+ "{" + convert_from_string_to_type("e.getValue()", value_type) + "}"
271
+ )
272
+ return f"""
273
+ {arg_name}.entrySet()
274
+ .stream()
275
+ .collect(Collectors.toMap(
276
+ e -> {key_converter},
277
+ e -> {value_converter}
278
+ ));
279
+ """
280
+ else:
281
+ return arg_name