snowpark-connect 0.27.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200) hide show
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  3. snowflake/snowpark_connect/client/__init__.py +15 -0
  4. snowflake/snowpark_connect/client/error_utils.py +30 -0
  5. snowflake/snowpark_connect/client/exceptions.py +36 -0
  6. snowflake/snowpark_connect/client/query_results.py +90 -0
  7. snowflake/snowpark_connect/client/server.py +717 -0
  8. snowflake/snowpark_connect/client/utils/__init__.py +10 -0
  9. snowflake/snowpark_connect/client/utils/session.py +85 -0
  10. snowflake/snowpark_connect/column_name_handler.py +404 -243
  11. snowflake/snowpark_connect/column_qualifier.py +43 -0
  12. snowflake/snowpark_connect/config.py +309 -26
  13. snowflake/snowpark_connect/constants.py +2 -0
  14. snowflake/snowpark_connect/dataframe_container.py +102 -8
  15. snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
  16. snowflake/snowpark_connect/error/error_codes.py +50 -0
  17. snowflake/snowpark_connect/error/error_utils.py +172 -23
  18. snowflake/snowpark_connect/error/exceptions.py +13 -4
  19. snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
  20. snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
  21. snowflake/snowpark_connect/execute_plan/utils.py +5 -1
  22. snowflake/snowpark_connect/expression/error_utils.py +28 -0
  23. snowflake/snowpark_connect/expression/function_defaults.py +9 -2
  24. snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
  25. snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
  26. snowflake/snowpark_connect/expression/literal.py +37 -13
  27. snowflake/snowpark_connect/expression/map_cast.py +224 -15
  28. snowflake/snowpark_connect/expression/map_expression.py +80 -27
  29. snowflake/snowpark_connect/expression/map_extension.py +322 -12
  30. snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
  31. snowflake/snowpark_connect/expression/map_udf.py +86 -20
  32. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
  33. snowflake/snowpark_connect/expression/map_unresolved_function.py +2964 -829
  34. snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
  35. snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
  36. snowflake/snowpark_connect/expression/map_window_function.py +18 -3
  37. snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
  39. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/{spark-connect-client-jvm_2.12-3.5.6.jar → spark-connect-client-jvm_2.13-3.5.6.jar} +0 -0
  43. snowflake/snowpark_connect/includes/jars/{spark-sql_2.12-3.5.6.jar → spark-sql_2.13-3.5.6.jar} +0 -0
  44. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
  45. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
  46. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  47. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  48. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
  49. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
  50. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
  51. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
  52. snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
  53. snowflake/snowpark_connect/relation/io_utils.py +110 -10
  54. snowflake/snowpark_connect/relation/map_aggregate.py +239 -256
  55. snowflake/snowpark_connect/relation/map_catalog.py +5 -1
  56. snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
  57. snowflake/snowpark_connect/relation/map_extension.py +263 -29
  58. snowflake/snowpark_connect/relation/map_join.py +683 -442
  59. snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
  60. snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
  61. snowflake/snowpark_connect/relation/map_relation.py +48 -19
  62. snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
  63. snowflake/snowpark_connect/relation/map_show_string.py +13 -6
  64. snowflake/snowpark_connect/relation/map_sql.py +1233 -222
  65. snowflake/snowpark_connect/relation/map_stats.py +48 -9
  66. snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
  67. snowflake/snowpark_connect/relation/map_udtf.py +14 -4
  68. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
  69. snowflake/snowpark_connect/relation/read/map_read.py +134 -43
  70. snowflake/snowpark_connect/relation/read/map_read_csv.py +326 -47
  71. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +21 -6
  72. snowflake/snowpark_connect/relation/read/map_read_json.py +324 -86
  73. snowflake/snowpark_connect/relation/read/map_read_parquet.py +146 -28
  74. snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
  75. snowflake/snowpark_connect/relation/read/map_read_socket.py +15 -3
  76. snowflake/snowpark_connect/relation/read/map_read_table.py +86 -6
  77. snowflake/snowpark_connect/relation/read/map_read_text.py +22 -4
  78. snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
  79. snowflake/snowpark_connect/relation/read/reader_config.py +42 -3
  80. snowflake/snowpark_connect/relation/read/utils.py +50 -5
  81. snowflake/snowpark_connect/relation/stage_locator.py +91 -55
  82. snowflake/snowpark_connect/relation/utils.py +128 -5
  83. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
  84. snowflake/snowpark_connect/relation/write/map_write.py +929 -319
  85. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
  86. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  87. snowflake/snowpark_connect/resources_initializer.py +171 -48
  88. snowflake/snowpark_connect/server.py +528 -473
  89. snowflake/snowpark_connect/server_common/__init__.py +503 -0
  90. snowflake/snowpark_connect/snowflake_session.py +65 -0
  91. snowflake/snowpark_connect/start_server.py +53 -5
  92. snowflake/snowpark_connect/type_mapping.py +349 -27
  93. snowflake/snowpark_connect/type_support.py +130 -0
  94. snowflake/snowpark_connect/typed_column.py +9 -7
  95. snowflake/snowpark_connect/utils/artifacts.py +9 -8
  96. snowflake/snowpark_connect/utils/cache.py +49 -27
  97. snowflake/snowpark_connect/utils/concurrent.py +36 -1
  98. snowflake/snowpark_connect/utils/context.py +195 -37
  99. snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
  100. snowflake/snowpark_connect/utils/env_utils.py +5 -1
  101. snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
  102. snowflake/snowpark_connect/utils/identifiers.py +137 -3
  103. snowflake/snowpark_connect/utils/io_utils.py +57 -1
  104. snowflake/snowpark_connect/utils/java_stored_procedure.py +151 -0
  105. snowflake/snowpark_connect/utils/java_udaf_utils.py +321 -0
  106. snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
  107. snowflake/snowpark_connect/utils/jvm_udf_utils.py +281 -0
  108. snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
  109. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
  110. snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
  111. snowflake/snowpark_connect/utils/profiling.py +25 -8
  112. snowflake/snowpark_connect/utils/scala_udf_utils.py +185 -340
  113. snowflake/snowpark_connect/utils/sequence.py +21 -0
  114. snowflake/snowpark_connect/utils/session.py +64 -28
  115. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
  116. snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
  117. snowflake/snowpark_connect/utils/telemetry.py +192 -40
  118. snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
  119. snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
  120. snowflake/snowpark_connect/utils/udf_cache.py +117 -41
  121. snowflake/snowpark_connect/utils/udf_helper.py +39 -37
  122. snowflake/snowpark_connect/utils/udf_utils.py +133 -14
  123. snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
  124. snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
  125. snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
  126. snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
  127. snowflake/snowpark_connect/version.py +1 -1
  128. snowflake/snowpark_decoder/dp_session.py +6 -2
  129. snowflake/snowpark_decoder/spark_decoder.py +12 -0
  130. {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +14 -4
  131. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +16 -7
  132. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +139 -168
  133. snowflake/snowpark_connect/hidden_column.py +0 -39
  134. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  135. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  136. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  137. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  138. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  139. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  140. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  141. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  142. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  143. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  144. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  145. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  146. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  147. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  148. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  149. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  150. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  151. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  152. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  153. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  154. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  155. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  156. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  157. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  158. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  159. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  160. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  161. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  162. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  163. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  164. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  165. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  166. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  167. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  168. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  169. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  170. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  171. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  172. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  173. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  174. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  175. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  176. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  177. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  178. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  179. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  180. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  181. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  182. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  183. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  184. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  185. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  186. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  187. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  188. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  189. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  190. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  191. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  192. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  193. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  194. {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
  195. {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
  196. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
  197. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
  198. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  199. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
  200. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,151 @@
1
+ #
2
+ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3
+ #
4
+
5
+ from pyspark.errors import AnalysisException
6
+
7
+ import snowflake.snowpark.types as snowpark_type
8
+ from snowflake.snowpark import Session
9
+ from snowflake.snowpark._internal.type_utils import type_string_to_type_object
10
+ from snowflake.snowpark_connect.client.error_utils import attach_custom_error_code
11
+ from snowflake.snowpark_connect.config import (
12
+ get_scala_version,
13
+ is_java_udf_creator_initialized,
14
+ set_java_udf_creator_initialized_state,
15
+ )
16
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
17
+ from snowflake.snowpark_connect.resources_initializer import (
18
+ RESOURCE_PATH,
19
+ SPARK_COMMON_UTILS_JAR_212,
20
+ SPARK_COMMON_UTILS_JAR_213,
21
+ SPARK_CONNECT_CLIENT_JAR_212,
22
+ SPARK_CONNECT_CLIENT_JAR_213,
23
+ SPARK_SQL_JAR_212,
24
+ SPARK_SQL_JAR_213,
25
+ ensure_scala_udf_jars_uploaded,
26
+ )
27
+ from snowflake.snowpark_connect.utils.upload_java_jar import upload_java_udf_jar
28
+
29
# Prefix that create_java_udf prepends to the caller-supplied function name.
CREATE_JAVA_UDF_PREFIX = "__SC_JAVA_UDF_"
# Name of the temporary stored procedure that create_java_udf invokes via CALL.
# It must stay in sync with the procedure name hard-coded in SP_TEMPLATE below.
PROCEDURE_NAME = "__SC_JAVA_SP_CREATE_JAVA_UDF"
# DDL for the temporary stored procedure. Two placeholders are substituted by
# create_java_udf before the statement is executed:
#   __snowflake_udf_imports__ -> the "IMPORTS = (...)" clause
#   __scala_version__         -> the value returned by get_scala_version()
SP_TEMPLATE = """
CREATE OR REPLACE TEMPORARY PROCEDURE __SC_JAVA_SP_CREATE_JAVA_UDF(udf_name VARCHAR, udf_class VARCHAR, imports ARRAY(VARCHAR))
RETURNS VARCHAR
LANGUAGE JAVA
RUNTIME_VERSION = 17
PACKAGES = ('com.snowflake:snowpark___scala_version__:latest')
__snowflake_udf_imports__
HANDLER = 'com.snowflake.snowpark_connect.procedures.JavaUDFCreator.process'
EXECUTE AS CALLER
;
"""
42
+
43
+
44
class JavaUdf:
    """
    Handle for a Java UDF that already exists in Snowflake.

    Shaped like a Python UserDefinedFunction reference: it only records the
    function's name and its argument/result types so that calls can be built
    later. Constructing it does not create anything.
    """

    def __init__(
        self,
        name: str,
        input_types: list[snowpark_type.DataType],
        return_type: snowpark_type.DataType,
    ) -> None:
        """
        Record the metadata of a Java UDF.

        Args:
            name: Name of the UDF in Snowflake.
            input_types: Types of the UDF's parameters, in order.
            return_type: Type of the value the UDF returns.
        """
        self.name, self._input_types, self._return_type = (
            name,
            input_types,
            return_type,
        )
69
+
70
+
71
def _scala_static_imports_for_sproc(stage_resource_path: str) -> set[str]:
    """
    Return the stage paths of the Spark jars matching the configured Scala version.

    Args:
        stage_resource_path: Stage directory that holds the uploaded jars.

    Returns:
        Set with the Spark Connect client, common-utils and SQL jar paths.

    Raises:
        ValueError: If the configured Scala version is neither 2.12 nor 2.13.
            The exception is tagged with ErrorCodes.INVALID_CONFIG_VALUE.
    """
    scala_version = get_scala_version()

    if scala_version == "2.12":
        jar_names = (
            SPARK_CONNECT_CLIENT_JAR_212,
            SPARK_COMMON_UTILS_JAR_212,
            SPARK_SQL_JAR_212,
        )
    elif scala_version == "2.13":
        jar_names = (
            SPARK_CONNECT_CLIENT_JAR_213,
            SPARK_COMMON_UTILS_JAR_213,
            SPARK_SQL_JAR_213,
        )
    else:
        # invalid Scala version
        exception = ValueError(
            f"Unsupported Scala version: {scala_version}. Snowpark Connect supports Scala 2.12 and 2.13"
        )
        attach_custom_error_code(exception, ErrorCodes.INVALID_CONFIG_VALUE)
        raise exception

    return {f"{stage_resource_path}/{jar_name}" for jar_name in jar_names}
93
+
94
+
95
def get_quoted_imports(session: Session) -> str:
    """
    Build the comma-separated list of single-quoted jar imports for Java UDF SQL.

    The list is the union of the session's artifact jars, the Spark jars for
    the configured Scala version, the bundled java_udfs jar, and any extra
    entries from the "snowpark.connect.udf.java.imports" config value (a
    comma-separated list, optionally wrapped in square brackets).

    Args:
        session: Snowpark session whose session stage and artifact jars are used.

    Returns:
        A string such as "'a.jar', 'b.jar'", suitable for an IMPORTS = (...)
        clause or an ARRAY_CONSTRUCT(...) call. Entries are sorted so the
        generated SQL text is deterministic.

    Raises:
        ValueError: If the configured Scala version is not supported.
    """
    stage_resource_path = session.get_session_stage() + RESOURCE_PATH
    spark_imports = _scala_static_imports_for_sproc(stage_resource_path) | {
        f"{stage_resource_path}/java_udfs-1.0-SNAPSHOT.jar",
    }

    def quote_single(s: str) -> str:
        """Wrap a string in single quotes for SQL, doubling embedded quotes."""
        # Without the doubling, a path containing "'" would terminate the SQL
        # string literal early and corrupt (or inject into) the statement.
        return "'" + s.replace("'", "''") + "'"

    from snowflake.snowpark_connect.config import global_config

    config_imports = global_config.get("snowpark.connect.udf.java.imports", "")
    config_imports = (
        {x.strip() for x in config_imports.strip("[] ").split(",") if x.strip()}
        if config_imports
        else set()
    )

    all_imports = session._artifact_jars | spark_imports | config_imports
    # Set iteration order is hash-dependent; sort for a stable statement.
    return ", ".join(quote_single(x) for x in sorted(all_imports))
117
+
118
+
119
def create_snowflake_imports(session: Session) -> str:
    """
    Return the "IMPORTS = (...)" clause for Java UDF/procedure DDL.

    Args:
        session: Snowpark session passed through to get_quoted_imports.

    Returns:
        The IMPORTS clause containing every quoted import path.
    """
    # The jars referenced below are uploaded by the resource initializer thread,
    # so wait for it to finish before building the clause.
    ensure_scala_udf_jars_uploaded()

    quoted_imports = get_quoted_imports(session)
    return f"IMPORTS = ({quoted_imports})"
125
+
126
+
127
def create_java_udf(session: Session, function_name: str, java_class: str):
    """
    Create a Java UDF in Snowflake for the given class and return a reference to it.

    On first use the java_udfs jar is uploaded and the helper stored procedure
    from SP_TEMPLATE is created; the procedure is then called with the UDF
    name, the Java class name and the import list. The procedure's result is
    read as a ';'-separated list of type strings whose last element is the
    return type and whose preceding elements are the input types.

    Args:
        session: Snowpark session used to run the SQL statements.
        function_name: User-facing function name; CREATE_JAVA_UDF_PREFIX is
            prepended to form the Snowflake UDF name.
        java_class: Name of the Java class implementing the UDF.

    Returns:
        A JavaUdf holding the prefixed name and the parsed input/return types.

    Raises:
        AnalysisException: If the procedure returns no row or an empty value,
            i.e. the class could not be loaded.
    """

    def sql_literal(value: str) -> str:
        """Render a Python string as a single-quoted SQL string literal."""
        # Double embedded quotes so caller-supplied names cannot break out of
        # the literal in the CALL statement.
        return "'" + value.replace("'", "''") + "'"

    if not is_java_udf_creator_initialized():
        upload_java_udf_jar(session)
        create_procedure_sql = SP_TEMPLATE.replace(
            "__snowflake_udf_imports__", create_snowflake_imports(session)
        ).replace("__scala_version__", get_scala_version())
        session.sql(create_procedure_sql).collect()
        set_java_udf_creator_initialized_state(True)

    name = CREATE_JAVA_UDF_PREFIX + function_name
    result = session.sql(
        f"CALL {PROCEDURE_NAME}({sql_literal(name)}, {sql_literal(java_class)}, "
        f"ARRAY_CONSTRUCT({get_quoted_imports(session)})::ARRAY(VARCHAR))"
    ).collect()

    # An empty result set is treated like an empty value instead of surfacing
    # as an IndexError.
    result_value = result[0][0] if result else None
    if not result_value:
        raise AnalysisException(f"Can not load class {java_class}")

    types = result_value.split(";")
    input_types = [type_string_to_type_object(t) for t in types[:-1]]
    output_type = types[-1]

    return JavaUdf(
        name,
        input_types,
        type_string_to_type_object(output_type),
    )
@@ -0,0 +1,321 @@
1
+ #
2
+ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3
+ #
4
+
5
+ from dataclasses import dataclass
6
+
7
+ import snowflake.snowpark.types as snowpark_type
8
+ from snowflake.snowpark_connect.type_mapping import map_type_to_snowflake_type
9
+ from snowflake.snowpark_connect.utils.jvm_udf_utils import (
10
+ NullHandling,
11
+ Param,
12
+ ReturnType,
13
+ Signature,
14
+ build_jvm_udxf_imports,
15
+ map_type_to_java_type,
16
+ )
17
+ from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
18
+ from snowflake.snowpark_connect.utils.udf_utils import (
19
+ ProcessCommonInlineUserDefinedFunction,
20
+ )
21
+
22
# Prefix used for internally generated Java UDAF names to avoid conflicts
CREATE_JAVA_UDAF_PREFIX = "__SC_JAVA_UDAF_"


# Java source of the UDAF handler class; JavaUDAFDef.to_create_function_sql
# embeds it between $$ markers with HANDLER = 'JavaUDAF'.
# JavaUDAFDef._gen_body_java substitutes these placeholders:
#   __operation_file__    -> last path component of the first import
#   __accumulator_type__  -> Java type of the first parameter
#   __value_type__        -> Java type of the second parameter
#   __mapped_value__      -> "input", or a fromVariant(...) conversion of it
#   __reduce_type__       -> type the reduce function operates on
#   __return_type__       -> return type of finish()
#   __response_wrapper__  -> "state.value", or a toVariant(...) conversion of it
UDAF_TEMPLATE = """
import org.apache.spark.sql.connect.common.UdfPacket;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Paths;

// Import types required for mapping
import java.util.*;
import java.util.stream.Collectors;
import com.snowflake.snowpark_java.types.*;

public class JavaUDAF {
private final static String OPERATION_FILE = "__operation_file__";
private static scala.Function2<__reduce_type__, __reduce_type__, __reduce_type__> operation = null;
private static UdfPacket udfPacket = null;

private static void loadOperation() throws IOException, ClassNotFoundException {
if (operation != null) {
return; // Already loaded
}

udfPacket = com.snowflake.sas.scala.Utils$.MODULE$.deserializeUdfPacket(OPERATION_FILE);
operation = (scala.Function2<__reduce_type__, __reduce_type__, __reduce_type__>) udfPacket.function();
}

public static class State implements Serializable {
public __reduce_type__ value = null;
public boolean initialized = false;
}

public static State initialize() throws IOException, ClassNotFoundException {
loadOperation();
return new State();
}

public static State accumulate(State state, __accumulator_type__ accumulator, __value_type__ input) {
// TODO: Add conversion between value_type we get in input and the value that we are using in the operation
if (input == null) {
return state;
}

if (!state.initialized) {
state.value = __mapped_value__;
state.initialized = true;
} else {
state.value = operation.apply(state.value, __mapped_value__);
}
return state;
}

public static State merge(State s1, State s2) {
if (!s2.initialized) {
return s1;
}
if (!s1.initialized) {
return s2;
}

s1.value = operation.apply(s1.value, s2.value);
return s1;
}

public static __return_type__ finish(State state) {
return state.initialized ? __response_wrapper__ : null;
}
}"""
96
+
97
+
98
@dataclass(frozen=True)
class JavaUDAFDef:
    """
    Complete definition for creating a Java UDAF in Snowflake.

    Contains all the information needed to generate the CREATE FUNCTION SQL statement
    and the Java code body for the UDAF.

    Attributes:
        name: UDAF name
        signature: SQL signature (for Snowflake function definition)
        java_signature: Java signature (for Java code generation). The first
            parameter's type is used as the accumulator type and the second
            parameter's type as the input value type, so two parameters are
            required to generate the body.
        imports: List of files to import. The last path component of the first
            entry is used as the name of the serialized operation file.
        null_handling: Null handling behavior (defaults to RETURNS_NULL_ON_NULL_INPUT)
    """

    name: str
    signature: Signature
    java_signature: Signature
    imports: list[str]
    null_handling: NullHandling = NullHandling.RETURNS_NULL_ON_NULL_INPUT

    # -------------------- DDL Emitter --------------------

    def _gen_body_java(self) -> str:
        """
        Generate the Java code body for the UDAF.

        Fills the placeholders of UDAF_TEMPLATE, producing a JavaUDAF class
        that loads the serialized reduce function from the operation file and
        exposes initialize / accumulate / merge / finish methods.

        Returns:
            String containing the complete Java code for the UDAF body
        """
        returns_variant = self.signature.returns.data_type.lower() == "variant"
        # For a reduce the result has the accumulator's type, hence params[0].
        return_type = (
            "Variant" if returns_variant else self.java_signature.params[0].data_type
        )
        response_wrapper = (
            "com.snowflake.sas.scala.Utils$.MODULE$.toVariant(state.value, udfPacket)"
            if returns_variant
            else "state.value"
        )

        is_variant_input = self.java_signature.params[0].data_type.lower() == "variant"
        # Variant inputs are converted before reducing, so the reduce function
        # is typed on Object in that case.
        reduce_type = (
            "Object" if is_variant_input else self.java_signature.params[0].data_type
        )
        return (
            UDAF_TEMPLATE.replace("__operation_file__", self.imports[0].split("/")[-1])
            .replace("__accumulator_type__", self.java_signature.params[0].data_type)
            .replace("__value_type__", self.java_signature.params[1].data_type)
            .replace(
                "__mapped_value__",
                "com.snowflake.sas.scala.UdfPacketUtils$.MODULE$.fromVariant(udfPacket, input, 0)"
                if is_variant_input
                else "input",
            )
            .replace("__reduce_type__", reduce_type)
            .replace("__return_type__", return_type)
            .replace("__response_wrapper__", response_wrapper)
        )

    def to_create_function_sql(self) -> str:
        """
        Generate the complete CREATE FUNCTION SQL statement for the Java UDAF.

        Creates a Snowflake CREATE OR REPLACE TEMPORARY AGGREGATE FUNCTION statement with
        all necessary clauses including language, runtime version, packages,
        imports, and the Java code body.

        Returns:
            Complete SQL DDL statement for creating the UDAF
        """

        args = ", ".join(
            [f"{param.name} {param.data_type}" for param in self.signature.params]
        )
        ret_type = self.signature.returns.data_type

        def quote_single(s: str) -> str:
            """Wrap a string in single quotes for SQL, doubling embedded quotes."""
            # Without the doubling, an import path containing "'" would end
            # the SQL string literal early and corrupt the statement.
            return "'" + s.replace("'", "''") + "'"

        # Handler and imports
        imports_sql = f"IMPORTS = ({', '.join(quote_single(x) for x in self.imports)})"

        return f"""
CREATE OR REPLACE TEMPORARY AGGREGATE FUNCTION {self.name}({args})
RETURNS {ret_type}
LANGUAGE JAVA
{self.null_handling.value}
RUNTIME_VERSION = 17
PACKAGES = ('com.snowflake:snowpark:latest')
{imports_sql}
HANDLER = 'JavaUDAF'
AS
$$
{self._gen_body_java()}
$$;"""
199
+
200
+
201
class JavaUdaf:
    """
    Handle for a Java UDAF that already exists in Snowflake.

    Shaped like a Python UserDefinedFunction reference: it only records the
    aggregate function's name and its argument/result types so that calls can
    be built later. Constructing it does not create anything.
    """

    def __init__(
        self,
        name: str,
        input_types: list[snowpark_type.DataType],
        return_type: snowpark_type.DataType,
    ) -> None:
        """
        Record the metadata of a Java UDAF.

        Args:
            name: Name of the UDAF in Snowflake.
            input_types: Types of the UDAF's parameters, in order.
            return_type: Type of the value the UDAF returns.
        """
        self.name, self._input_types, self._return_type = (
            name,
            input_types,
            return_type,
        )
226
+
227
+
228
def create_java_udaf_for_reduce_scala_function(
    pciudf: ProcessCommonInlineUserDefinedFunction,
) -> JavaUdaf:
    """
    Register a Java UDAF in Snowflake for a Scala reduce function.

    Steps performed:
    1. Pick the function name; when none is given, use the first ten hex
       characters of the SHA-256 of the payload.
    2. Map every input type to its Java and Snowflake SQL counterparts
       (array, map and variant inputs are passed as Variant).
    3. Build the import list for the payload.
    4. Generate and execute the CREATE FUNCTION statement.

    Args:
        pciudf: The ProcessCommonInlineUserDefinedFunction describing the UDF.

    Returns:
        A JavaUdaf referencing the created aggregate function.
    """
    from snowflake.snowpark_connect.resources_initializer import (
        ensure_scala_udf_jars_uploaded,
    )

    # Java UDAFs depend on the Scala UDF jars, so make sure they are uploaded first.
    ensure_scala_udf_jars_uploaded()

    from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session

    base_name = pciudf._function_name
    if not base_name:
        # No explicit name: derive one from the serialized function itself.
        import hashlib

        base_name = hashlib.sha256(pciudf._payload).hexdigest()[:10]
    udf_name = CREATE_JAVA_UDAF_PREFIX + base_name

    variant_backed_types = (
        snowpark_type.ArrayType,
        snowpark_type.MapType,
        snowpark_type.VariantType,
    )
    java_input_params: list[Param] = []
    sql_input_params: list[Param] = []
    # _input_types can be None when no arguments are provided.
    for index, input_type in enumerate(pciudf._input_types or ()):
        if isinstance(input_type, variant_backed_types):
            java_type = snowflake_type = "Variant"
        else:
            java_type = map_type_to_java_type(input_type)
            snowflake_type = map_type_to_snowflake_type(input_type)
        param_name = "arg" + str(index)
        # Parallel lists: one entry for the Java signature, one for the SQL signature.
        java_input_params.append(Param(param_name, java_type))
        sql_input_params.append(Param(param_name, snowflake_type))

    java_return_type = map_type_to_java_type(pciudf._original_return_type)
    sql_return_type = map_type_to_snowflake_type(pciudf._original_return_type)
    session = get_or_create_snowpark_session()

    imports = build_jvm_udxf_imports(
        session,
        pciudf._payload,
        udf_name,
    )

    # MAP, OBJECT and ARRAY results are returned as VARIANT because of issues
    # with those SQL return types in Java UDAFs.
    if sql_return_type.startswith(("MAP", "OBJECT", "ARRAY")):
        sql_return_type = "VARIANT"

    sql_signature = Signature(
        params=sql_input_params, returns=ReturnType(sql_return_type)
    )
    java_signature = Signature(
        params=java_input_params, returns=ReturnType(java_return_type)
    )
    udf_def = JavaUDAFDef(
        name=udf_name,
        signature=sql_signature,
        imports=imports,
        java_signature=java_signature,
    )

    create_udf_sql = udf_def.to_create_function_sql()
    logger.info(f"Creating Java UDAF: {create_udf_sql}")
    session.sql(create_udf_sql).collect()
    return JavaUdaf(udf_name, pciudf._input_types, pciudf._return_type)