snowpark-connect 0.27.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  3. snowflake/snowpark_connect/client/__init__.py +15 -0
  4. snowflake/snowpark_connect/client/error_utils.py +30 -0
  5. snowflake/snowpark_connect/client/exceptions.py +36 -0
  6. snowflake/snowpark_connect/client/query_results.py +90 -0
  7. snowflake/snowpark_connect/client/server.py +680 -0
  8. snowflake/snowpark_connect/client/utils/__init__.py +10 -0
  9. snowflake/snowpark_connect/client/utils/session.py +85 -0
  10. snowflake/snowpark_connect/column_name_handler.py +404 -243
  11. snowflake/snowpark_connect/column_qualifier.py +43 -0
  12. snowflake/snowpark_connect/config.py +237 -23
  13. snowflake/snowpark_connect/constants.py +2 -0
  14. snowflake/snowpark_connect/dataframe_container.py +102 -8
  15. snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
  16. snowflake/snowpark_connect/error/error_codes.py +50 -0
  17. snowflake/snowpark_connect/error/error_utils.py +172 -23
  18. snowflake/snowpark_connect/error/exceptions.py +13 -4
  19. snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
  20. snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
  21. snowflake/snowpark_connect/execute_plan/utils.py +5 -1
  22. snowflake/snowpark_connect/expression/function_defaults.py +9 -2
  23. snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
  24. snowflake/snowpark_connect/expression/literal.py +37 -13
  25. snowflake/snowpark_connect/expression/map_cast.py +123 -5
  26. snowflake/snowpark_connect/expression/map_expression.py +80 -27
  27. snowflake/snowpark_connect/expression/map_extension.py +322 -12
  28. snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
  29. snowflake/snowpark_connect/expression/map_udf.py +85 -20
  30. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
  31. snowflake/snowpark_connect/expression/map_unresolved_function.py +2748 -746
  32. snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
  33. snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
  34. snowflake/snowpark_connect/expression/map_window_function.py +18 -3
  35. snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
  36. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
  37. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
  38. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  39. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  40. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
  41. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
  42. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
  43. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
  44. snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
  45. snowflake/snowpark_connect/relation/io_utils.py +110 -10
  46. snowflake/snowpark_connect/relation/map_aggregate.py +196 -255
  47. snowflake/snowpark_connect/relation/map_catalog.py +5 -1
  48. snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
  49. snowflake/snowpark_connect/relation/map_extension.py +263 -29
  50. snowflake/snowpark_connect/relation/map_join.py +683 -442
  51. snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
  52. snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
  53. snowflake/snowpark_connect/relation/map_relation.py +48 -19
  54. snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
  55. snowflake/snowpark_connect/relation/map_show_string.py +13 -6
  56. snowflake/snowpark_connect/relation/map_sql.py +1233 -222
  57. snowflake/snowpark_connect/relation/map_stats.py +48 -9
  58. snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
  59. snowflake/snowpark_connect/relation/map_udtf.py +14 -4
  60. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
  61. snowflake/snowpark_connect/relation/read/map_read.py +134 -43
  62. snowflake/snowpark_connect/relation/read/map_read_csv.py +255 -45
  63. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
  64. snowflake/snowpark_connect/relation/read/map_read_json.py +320 -85
  65. snowflake/snowpark_connect/relation/read/map_read_parquet.py +142 -27
  66. snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
  67. snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
  68. snowflake/snowpark_connect/relation/read/map_read_table.py +82 -5
  69. snowflake/snowpark_connect/relation/read/map_read_text.py +18 -3
  70. snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
  71. snowflake/snowpark_connect/relation/read/reader_config.py +36 -3
  72. snowflake/snowpark_connect/relation/read/utils.py +50 -5
  73. snowflake/snowpark_connect/relation/stage_locator.py +91 -55
  74. snowflake/snowpark_connect/relation/utils.py +128 -5
  75. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
  76. snowflake/snowpark_connect/relation/write/map_write.py +929 -319
  77. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
  78. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  79. snowflake/snowpark_connect/resources_initializer.py +110 -48
  80. snowflake/snowpark_connect/server.py +546 -456
  81. snowflake/snowpark_connect/server_common/__init__.py +500 -0
  82. snowflake/snowpark_connect/snowflake_session.py +65 -0
  83. snowflake/snowpark_connect/start_server.py +53 -5
  84. snowflake/snowpark_connect/type_mapping.py +349 -27
  85. snowflake/snowpark_connect/typed_column.py +9 -7
  86. snowflake/snowpark_connect/utils/artifacts.py +9 -8
  87. snowflake/snowpark_connect/utils/cache.py +49 -27
  88. snowflake/snowpark_connect/utils/concurrent.py +36 -1
  89. snowflake/snowpark_connect/utils/context.py +187 -37
  90. snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
  91. snowflake/snowpark_connect/utils/env_utils.py +5 -1
  92. snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
  93. snowflake/snowpark_connect/utils/identifiers.py +137 -3
  94. snowflake/snowpark_connect/utils/io_utils.py +57 -1
  95. snowflake/snowpark_connect/utils/java_stored_procedure.py +125 -0
  96. snowflake/snowpark_connect/utils/java_udaf_utils.py +303 -0
  97. snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
  98. snowflake/snowpark_connect/utils/jvm_udf_utils.py +248 -0
  99. snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
  100. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
  101. snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
  102. snowflake/snowpark_connect/utils/profiling.py +25 -8
  103. snowflake/snowpark_connect/utils/scala_udf_utils.py +101 -332
  104. snowflake/snowpark_connect/utils/sequence.py +21 -0
  105. snowflake/snowpark_connect/utils/session.py +64 -28
  106. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
  107. snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
  108. snowflake/snowpark_connect/utils/telemetry.py +163 -22
  109. snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
  110. snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
  111. snowflake/snowpark_connect/utils/udf_cache.py +117 -41
  112. snowflake/snowpark_connect/utils/udf_helper.py +39 -37
  113. snowflake/snowpark_connect/utils/udf_utils.py +133 -14
  114. snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
  115. snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
  116. snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
  117. snowflake/snowpark_connect/version.py +1 -1
  118. snowflake/snowpark_decoder/dp_session.py +6 -2
  119. snowflake/snowpark_decoder/spark_decoder.py +12 -0
  120. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-submit +2 -2
  121. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/METADATA +14 -7
  122. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/RECORD +129 -167
  123. snowflake/snowpark_connect/hidden_column.py +0 -39
  124. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  125. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  126. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  127. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  128. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  129. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  130. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  131. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  132. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  133. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  134. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  135. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  136. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  137. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  138. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  139. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  140. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  141. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  142. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  143. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  144. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  145. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  146. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  147. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  148. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  149. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  150. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  151. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  152. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  153. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  154. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  155. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  156. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  157. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  158. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  159. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  160. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  161. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  162. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  163. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  164. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  165. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  166. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  167. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  168. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
  169. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  170. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  171. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  172. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  173. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  174. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  175. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  176. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  177. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  178. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  179. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  180. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  181. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  182. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  183. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  184. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  185. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  186. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-connect +0 -0
  187. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-session +0 -0
  188. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/WHEEL +0 -0
  189. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE-binary +0 -0
  190. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
  191. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/NOTICE-binary +0 -0
  192. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,303 @@
1
+ #
2
+ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3
+ #
4
+
5
+ from dataclasses import dataclass
6
+
7
+ import snowflake.snowpark.types as snowpark_type
8
+ from snowflake.snowpark_connect.type_mapping import map_type_to_snowflake_type
9
+ from snowflake.snowpark_connect.utils.jvm_udf_utils import (
10
+ NullHandling,
11
+ Param,
12
+ ReturnType,
13
+ Signature,
14
+ build_jvm_udxf_imports,
15
+ cast_java_map_args_from_given_type,
16
+ map_type_to_java_type,
17
+ )
18
+ from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
19
+ from snowflake.snowpark_connect.utils.udf_utils import (
20
+ ProcessCommonInlineUserDefinedFunction,
21
+ )
22
+
23
# Prefix used for internally generated Java UDAF names to avoid conflicts
CREATE_JAVA_UDAF_PREFIX = "__SC_JAVA_UDAF_"


# Java source for the aggregate handler class, kept as a plain-text template.
# JavaUDAFDef._gen_body_java fills in the __operation_file__, __accumulator_type__,
# __value_type__, __return_type__ and __response_wrapper__ placeholders with
# str.replace calls.
#
# The generated class exposes initialize / accumulate / merge / finish and keeps the
# running value in the nested State class. The Scala function is deserialized from
# OPERATION_FILE by loadOperation(), which returns early once `operation` is set.
#
# NOTE(review): accumulate() never reads its `accumulator` parameter, and
# state.value (declared as __accumulator_type__) is assigned both `input` and the
# result of operation.apply (both typed as __value_type__), so the generated Java
# presumably only compiles when the two types coincide - confirm.
UDAF_TEMPLATE = """
import org.apache.spark.sql.connect.common.UdfPacket;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Paths;

// Import types required for mapping
import java.util.*;
import java.util.stream.Collectors;
import com.snowflake.snowpark_java.types.*;

public class JavaUDAF {
    private final static String OPERATION_FILE = "__operation_file__";
    private static scala.Function2<__accumulator_type__, __value_type__, __value_type__> operation = null;

    private static void loadOperation() throws IOException, ClassNotFoundException {
        if (operation != null) {
            return; // Already loaded
        }

        final UdfPacket udfPacket = com.snowflake.sas.scala.Utils$.MODULE$.deserializeUdfPacket(OPERATION_FILE);
        operation = (scala.Function2<__accumulator_type__, __value_type__, __value_type__>) udfPacket.function();
    }

    public static class State implements Serializable {
        public __accumulator_type__ value = null;
        public boolean initialized = false;
    }

    public static State initialize() throws IOException, ClassNotFoundException {
        loadOperation();
        return new State();
    }

    public static State accumulate(State state, __accumulator_type__ accumulator, __value_type__ input) {
        // TODO: Add conversion between value_type we get in input and the value that we are using in the operation
        if (input == null) {
            return state;
        }

        if (!state.initialized) {
            state.value = input;
            state.initialized = true;
        } else {
            state.value = operation.apply(state.value, input);
        }
        return state;
    }

    public static State merge(State s1, State s2) {
        if (!s2.initialized) {
            return s1;
        }
        if (!s1.initialized) {
            return s2;
        }

        s1.value = operation.apply(s1.value, s2.value);
        return s1;
    }

    public static __return_type__ finish(State state) {
        return state.initialized ? __response_wrapper__ : null;
    }
}"""
96
+
97
+
98
@dataclass(frozen=True)
class JavaUDAFDef:
    """
    Immutable description of a Java UDAF to be created in Snowflake.

    Bundles everything needed to render both the CREATE FUNCTION statement and
    the Java handler source embedded in it.

    Attributes:
        name: Name of the UDAF.
        signature: SQL-side signature (parameter and return types as emitted in the DDL).
        java_signature: Java-side signature used when filling in the handler template.
        java_invocation_args: Argument expressions prepared for the Java invocation
            (with Map casts applied). Stored on the definition; the methods of this
            class do not read it.
        imports: Entries of the IMPORTS clause. The file name of the first entry is
            used as the serialized operation file.
        null_handling: Null-handling clause emitted in the DDL
            (defaults to RETURNS_NULL_ON_NULL_INPUT).
    """

    name: str
    signature: Signature
    java_signature: Signature
    java_invocation_args: list[str]
    imports: list[str]
    null_handling: NullHandling = NullHandling.RETURNS_NULL_ON_NULL_INPUT

    # -------------------- DDL Emitter --------------------

    def _gen_body_java(self) -> str:
        """
        Render UDAF_TEMPLATE into the Java source of the aggregate handler.

        The rendered class loads the serialized function from the first import
        and exposes the initialize / accumulate / merge / finish methods defined
        by the template. At least one import and two Java parameters
        (accumulator first, value second) are required; otherwise the indexing
        below raises IndexError.

        Returns:
            The Java source for the UDAF body.
        """
        if self.signature.returns.data_type == "VARIANT":
            # A VARIANT SQL return type means finish() must wrap the value.
            java_return = "Variant"
            finish_expr = "new Variant(state.value)"
        else:
            # finish() hands back state.value, so its type follows the
            # accumulator parameter.
            java_return = self.java_signature.params[0].data_type
            finish_expr = "state.value"

        # Applied in insertion order, one placeholder at a time.
        substitutions = {
            "__operation_file__": self.imports[0].rsplit("/", 1)[-1],
            "__accumulator_type__": self.java_signature.params[0].data_type,
            "__value_type__": self.java_signature.params[1].data_type,
            "__return_type__": java_return,
            "__response_wrapper__": finish_expr,
        }
        java_source = UDAF_TEMPLATE
        for placeholder, value in substitutions.items():
            java_source = java_source.replace(placeholder, value)
        return java_source

    def to_create_function_sql(self) -> str:
        """
        Build the CREATE OR REPLACE TEMPORARY AGGREGATE FUNCTION statement.

        The statement carries the SQL signature, the null-handling clause, the
        runtime version, the Snowpark package, the IMPORTS list and the rendered
        Java handler body.

        Returns:
            The complete SQL DDL statement for creating the UDAF.
        """
        arg_list = ", ".join(
            f"{param.name} {param.data_type}" for param in self.signature.params
        )
        ret_type = self.signature.returns.data_type

        # Each import is emitted as a single-quoted SQL string literal.
        quoted_imports = ", ".join(f"'{path}'" for path in self.imports)
        imports_sql = f"IMPORTS = ({quoted_imports})"

        return f"""
CREATE OR REPLACE TEMPORARY AGGREGATE FUNCTION {self.name}({arg_list})
RETURNS {ret_type}
LANGUAGE JAVA
{self.null_handling.value}
RUNTIME_VERSION = 17
PACKAGES = ('com.snowflake:snowpark:latest')
{imports_sql}
HANDLER = 'JavaUDAF'
AS
$$
{self._gen_body_java()}
$$;"""
186
+
187
+
188
class JavaUdaf:
    """
    Lightweight handle for a Java UDAF, mirroring the shape of a Python
    UserDefinedFunction reference.

    Instances only hold the function name and its input/return types; nothing
    is created or looked up in Snowflake by this class itself.
    """

    def __init__(
        self,
        name: str,
        input_types: list[snowpark_type.DataType],
        return_type: snowpark_type.DataType,
    ) -> None:
        """
        Store the metadata describing a Java UDAF.

        Args:
            name: Name of the UDAF in Snowflake.
            input_types: Types of the input parameters.
            return_type: Type returned by the UDAF.
        """
        self.name, self._input_types, self._return_type = (
            name,
            input_types,
            return_type,
        )
213
+
214
+
215
def create_java_udaf_for_reduce_scala_function(
    pciudf: ProcessCommonInlineUserDefinedFunction,
) -> JavaUdaf:
    """
    Register a temporary Java UDAF for the function carried by ``pciudf``.

    Steps performed:
        1. Wait for resource initialization to finish.
        2. Derive the UDAF name: CREATE_JAVA_UDAF_PREFIX followed by the function
           name from ``pciudf`` or, when that is empty, by the first ten hex
           characters of the SHA-256 of the payload.
        3. Map every input type and the return type to their Java and Snowflake
           SQL counterparts.
        4. Build the imports list, render the CREATE FUNCTION statement and
           execute it on the Snowpark session.

    Args:
        pciudf: The ProcessCommonInlineUserDefinedFunction holding the UDF details.

    Returns:
        A JavaUdaf reference carrying the generated name together with the input
        and return types taken from ``pciudf``.
    """
    from snowflake.snowpark_connect.resources_initializer import (
        wait_for_resource_initialization,
    )

    # The jars uploaded by the resource initializer thread are required by Java
    # UDFs, so block until that thread has completed.
    wait_for_resource_initialization()

    from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session

    base_name = pciudf._function_name
    if not base_name:
        # No explicit name: fall back to a short digest of the serialized function.
        import hashlib

        base_name = hashlib.sha256(pciudf._payload).hexdigest()[:10]
    udf_name = CREATE_JAVA_UDAF_PREFIX + base_name

    java_input_params: list[Param] = []
    sql_input_params: list[Param] = []
    java_invocation_args: list[str] = []  # arguments passed into the udf function
    # _input_types can be None when no arguments are provided.
    for position, arg_type in enumerate(pciudf._input_types or ()):
        arg_name = f"arg{position}"
        # Java-side parameter: (argN, JavaType).
        java_input_params.append(Param(arg_name, map_type_to_java_type(arg_type)))
        # SQL-side parameter: (argN, SNOWFLAKE_TYPE).
        sql_input_params.append(
            Param(arg_name, map_type_to_snowflake_type(arg_type))
        )
        # Map inputs need a cast on the Java side, since Snowflake SQL Java can
        # only handle MAP[VARCHAR, VARCHAR] as input types.
        java_invocation_args.append(
            cast_java_map_args_from_given_type(arg_name, arg_type)
        )

    java_return_type = map_type_to_java_type(pciudf._original_return_type)
    mapped_sql_return = map_type_to_snowflake_type(pciudf._original_return_type)
    session = get_or_create_snowpark_session()

    imports = build_jvm_udxf_imports(
        session,
        pciudf._payload,
        udf_name,
    )

    # MAP and OBJECT SQL return types are replaced by VARIANT because of issues
    # with Java UDAFs.
    if mapped_sql_return.startswith(("MAP", "OBJECT")):
        sql_return_type = "VARIANT"
    else:
        sql_return_type = mapped_sql_return

    sql_signature = Signature(
        params=sql_input_params, returns=ReturnType(sql_return_type)
    )
    java_signature = Signature(
        params=java_input_params, returns=ReturnType(java_return_type)
    )
    udf_def = JavaUDAFDef(
        name=udf_name,
        signature=sql_signature,
        imports=imports,
        java_signature=java_signature,
        java_invocation_args=java_invocation_args,
    )

    create_udf_sql = udf_def.to_create_function_sql()
    logger.info(f"Creating Java UDAF: {create_udf_sql}")
    session.sql(create_udf_sql).collect()
    return JavaUdaf(udf_name, pciudf._input_types, pciudf._return_type)
@@ -0,0 +1,239 @@
1
+ #
2
+ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3
+ #
4
+
5
+ import hashlib
6
+ from dataclasses import dataclass
7
+
8
+ from pyspark.sql.connect.proto.expressions_pb2 import CommonInlineUserDefinedFunction
9
+
10
+ from snowflake.snowpark.types import ArrayType, MapType, VariantType
11
+ from snowflake.snowpark_connect.resources_initializer import (
12
+ ensure_scala_udf_jars_uploaded,
13
+ )
14
+ from snowflake.snowpark_connect.type_mapping import (
15
+ map_type_to_snowflake_type,
16
+ proto_to_snowpark_type,
17
+ )
18
+ from snowflake.snowpark_connect.utils.jvm_udf_utils import (
19
+ NullHandling,
20
+ Param,
21
+ ReturnType,
22
+ Signature,
23
+ build_jvm_udxf_imports,
24
+ map_type_to_java_type,
25
+ )
26
+ from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
27
+
28
# Prefix of generated Java UDTF names. It is also substituted for
# __java_udtf_prefix__ in UDTF_TEMPLATE and used in the DDL, so the single output
# column is named "<prefix>C1".
JAVA_UDTF_PREFIX = "__SC_JAVA_UDTF_"

# Java snippet substituted for __scala_input__ when the first Java parameter type is
# "variant" (compared case-insensitively in JavaUDTFDef._gen_body_java). It converts
# `input` through UdfPacketUtils.fromVariant and wraps the result in a single-element
# Scala iterator named `scalaInput`.
SCALA_INPUT_VARIANT = """
        Object mappedInput = com.snowflake.sas.scala.UdfPacketUtils$.MODULE$.fromVariant(udfPacket, input, 0);

        java.util.Iterator<Object> javaInput = Arrays.asList(mappedInput).iterator();
        scala.collection.Iterator<Object> scalaInput = new scala.collection.AbstractIterator<Object>() {
            public boolean hasNext() { return javaInput.hasNext(); }
            public Object next() { return javaInput.next(); }
        };
"""

# Java snippet substituted for __scala_input__ for every other parameter type: `input`
# is wrapped as-is in a single-element Scala iterator named `scalaInput`. The
# __iterator_type__ placeholder inside it is filled in by the later template pass.
SCALA_INPUT_SIMPLE_TYPE = """
        java.util.Iterator<__iterator_type__> javaInput = Arrays.asList(input).iterator();
        scala.collection.Iterator<__iterator_type__> scalaInput = new scala.collection.AbstractIterator<__iterator_type__>() {
            public boolean hasNext() { return javaInput.hasNext(); }
            public __iterator_type__ next() { return javaInput.next(); }
        };
"""

# Java source for the table-function handler, kept as a plain-text template.
# Placeholders present in the text: __operation_file__, __iterator_type__,
# __input_type__, __scala_input__ and __java_udtf_prefix__.
#
# process() takes a single `input` argument, applies the deserialized Scala function
# to the one-element iterator built by the __scala_input__ snippet, and emits each
# result as an OutputRow holding one Variant column. endPartition() emits nothing.
UDTF_TEMPLATE = """
import org.apache.spark.sql.connect.common.UdfPacket;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Paths;

import java.util.*;
import java.lang.*;
import java.util.stream.Collectors;
import com.snowflake.snowpark_java.types.*;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public class OutputRow {
    public Variant __java_udtf_prefix__C1;
    public OutputRow(Variant __java_udtf_prefix__C1) {
        this.__java_udtf_prefix__C1 = __java_udtf_prefix__C1;
    }
}

public class JavaUdtfHandler {
    private final static String OPERATION_FILE = "__operation_file__";
    private static scala.Function1<scala.collection.Iterator<__iterator_type__>, scala.collection.Iterator<Object>> operation = null;
    private static UdfPacket udfPacket = null;

    public static Class getOutputClass() { return OutputRow.class; }

    private static void loadOperation() throws IOException, ClassNotFoundException {
        if (operation != null) {
            return; // Already loaded
        }

        udfPacket = com.snowflake.sas.scala.Utils$.MODULE$.deserializeUdfPacket(OPERATION_FILE);
        operation = (scala.Function1<scala.collection.Iterator<__iterator_type__>, scala.collection.Iterator<Object>>) udfPacket.function();
    }

    public Stream<OutputRow> process(__input_type__ input) throws IOException, ClassNotFoundException {
        loadOperation();

        __scala_input__

        scala.collection.Iterator<Object> scalaResult = operation.apply(scalaInput);

        java.util.Iterator<Variant> javaResult = new java.util.Iterator<Variant>() {
            public boolean hasNext() { return scalaResult.hasNext(); }
            public Variant next() {
                return com.snowflake.sas.scala.Utils$.MODULE$.toVariant(scalaResult.next());
            }
        };

        return StreamSupport.stream(Spliterators.spliteratorUnknownSize(javaResult, Spliterator.ORDERED), false)
                .map(i -> new OutputRow(i));
    }

    public Stream<OutputRow> endPartition() {
        return Stream.empty();
    }
}
"""
111
+
112
+
113
@dataclass(frozen=True)
class JavaUDTFDef:
    """
    Immutable description of a Java UDTF to be created in Snowflake.

    Bundles everything needed to render both the CREATE FUNCTION statement and
    the Java handler source embedded in it.

    Attributes:
        name: Name of the UDTF.
        signature: SQL-side signature (parameter and return types).
        java_signature: Java-side signature used when filling in the handler template.
        imports: Entries of the IMPORTS clause. The file name of the first entry is
            used as the serialized operation file.
        null_handling: Null handling behavior (defaults to RETURNS_NULL_ON_NULL_INPUT).
            Stored on the definition; the DDL rendered by this class does not emit it.
    """

    name: str
    signature: Signature
    java_signature: Signature
    imports: list[str]
    null_handling: NullHandling = NullHandling.RETURNS_NULL_ON_NULL_INPUT

    def _gen_body_java(self) -> str:
        """
        Render UDTF_TEMPLATE into the Java source of the table-function handler.

        Only the first Java parameter is consulted. At least one import and one
        Java parameter are required; otherwise the indexing below raises
        IndexError.

        Returns:
            The Java source for the UDTF body.
        """
        if self.signature.returns.data_type == "VARIANT":
            return_type = "Variant"
        else:
            return_type = self.java_signature.returns.data_type

        first_param_type = self.java_signature.params[0].data_type
        if first_param_type.lower() == "variant":
            # Variant inputs are converted first, so the iterator carries Object.
            scala_input_template, iterator_type = SCALA_INPUT_VARIANT, "Object"
        else:
            scala_input_template, iterator_type = (
                SCALA_INPUT_SIMPLE_TYPE,
                first_param_type,
            )

        # Order matters: the __scala_input__ snippet may itself contain
        # __iterator_type__, so it has to be spliced in before that placeholder
        # is resolved. UDTF_TEMPLATE contains no __return_type__ placeholder, so
        # that pass leaves the text unchanged.
        replacements = (
            ("__operation_file__", self.imports[0].rsplit("/", 1)[-1]),
            ("__scala_input__", scala_input_template),
            ("__iterator_type__", iterator_type),
            ("__input_type__", first_param_type),
            ("__return_type__", return_type),
            ("__java_udtf_prefix__", JAVA_UDTF_PREFIX),
        )
        java_source = UDTF_TEMPLATE
        for placeholder, value in replacements:
            java_source = java_source.replace(placeholder, value)
        return java_source

    def to_create_function_sql(self) -> str:
        """
        Build the ``create or replace function`` statement for the Java UDTF.

        The function returns a table with a single VARIANT column named
        ``<JAVA_UDTF_PREFIX>C1`` and embeds the rendered Java handler body.

        Returns:
            The complete SQL DDL statement for creating the UDTF.
        """
        arg_list = ", ".join(
            f"{param.name} {param.data_type}" for param in self.signature.params
        )

        # Each import is emitted as a single-quoted SQL string literal.
        quoted_imports = ", ".join(f"'{path}'" for path in self.imports)
        imports_sql = f"IMPORTS = ({quoted_imports})"

        return f"""
create or replace function {self.name}({arg_list})
returns table ({JAVA_UDTF_PREFIX}C1 VARIANT)
language java
runtime_version = 17
PACKAGES = ('com.snowflake:snowpark:latest')
{imports_sql}
handler='JavaUdtfHandler'
as
$$
{self._gen_body_java()}
$$;"""
184
+
185
+
186
def create_java_udtf_for_scala_flatmap_handling(
    udf_proto: CommonInlineUserDefinedFunction,
) -> str:
    """
    Create a Java UDTF in Snowflake for the Scala UDF carried by ``udf_proto``.

    Steps performed:
        1. Ensure the Scala UDF jars are uploaded.
        2. Map the declared output type and every declared input type to their
           Java and Snowflake SQL counterparts. Array, map and variant inputs
           are passed as ``Variant`` on both sides.
        3. Derive the function name from JAVA_UDTF_PREFIX and the MD5 hex digest
           of the serialized payload.
        4. Build the imports list, render the DDL and execute it on the
           Snowpark session.

    Args:
        udf_proto: The CommonInlineUserDefinedFunction whose ``scalar_scala_udf``
            field holds the payload, input types and output type.

    Returns:
        The name of the created UDTF.
    """
    ensure_scala_udf_jars_uploaded()

    scala_udf = udf_proto.scalar_scala_udf
    return_type = proto_to_snowpark_type(scala_udf.outputType)

    session = get_or_create_snowpark_session()

    return_type_java = map_type_to_java_type(return_type)
    sql_return_type = map_type_to_snowflake_type(return_type)

    java_input_params: list[Param] = []
    sql_input_params: list[Param] = []
    for i, input_type_proto in enumerate(scala_udf.inputTypes):
        input_type = proto_to_snowpark_type(input_type_proto)

        param_name = "arg" + str(i)

        if isinstance(input_type, (ArrayType, MapType, VariantType)):
            # Collection-like and variant inputs cross the boundary as Variant.
            java_type = "Variant"
            snowflake_type = "Variant"
        else:
            java_type = map_type_to_java_type(input_type)
            snowflake_type = map_type_to_snowflake_type(input_type)

        java_input_params.append(Param(param_name, java_type))
        sql_input_params.append(Param(param_name, snowflake_type))

    # MD5 only serves as a content fingerprint for the function name here.
    # usedforsecurity=False keeps the digest identical while allowing the call
    # on Python builds that block MD5 for security purposes (e.g. FIPS mode).
    payload_digest = hashlib.md5(
        scala_udf.payload, usedforsecurity=False
    ).hexdigest()
    udtf_name = JAVA_UDTF_PREFIX + payload_digest

    imports = build_jvm_udxf_imports(
        session,
        scala_udf.payload,
        udtf_name,
    )

    udtf = JavaUDTFDef(
        name=udtf_name,
        signature=Signature(
            params=sql_input_params, returns=ReturnType(sql_return_type)
        ),
        imports=imports,
        java_signature=Signature(
            params=java_input_params, returns=ReturnType(return_type_java)
        ),
    )

    sql = udtf.to_create_function_sql()
    session.sql(sql).collect()

    return udtf_name