snowpark-connect 0.27.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  3. snowflake/snowpark_connect/client/__init__.py +15 -0
  4. snowflake/snowpark_connect/client/error_utils.py +30 -0
  5. snowflake/snowpark_connect/client/exceptions.py +36 -0
  6. snowflake/snowpark_connect/client/query_results.py +90 -0
  7. snowflake/snowpark_connect/client/server.py +680 -0
  8. snowflake/snowpark_connect/client/utils/__init__.py +10 -0
  9. snowflake/snowpark_connect/client/utils/session.py +85 -0
  10. snowflake/snowpark_connect/column_name_handler.py +404 -243
  11. snowflake/snowpark_connect/column_qualifier.py +43 -0
  12. snowflake/snowpark_connect/config.py +237 -23
  13. snowflake/snowpark_connect/constants.py +2 -0
  14. snowflake/snowpark_connect/dataframe_container.py +102 -8
  15. snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
  16. snowflake/snowpark_connect/error/error_codes.py +50 -0
  17. snowflake/snowpark_connect/error/error_utils.py +172 -23
  18. snowflake/snowpark_connect/error/exceptions.py +13 -4
  19. snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
  20. snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
  21. snowflake/snowpark_connect/execute_plan/utils.py +5 -1
  22. snowflake/snowpark_connect/expression/function_defaults.py +9 -2
  23. snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
  24. snowflake/snowpark_connect/expression/literal.py +37 -13
  25. snowflake/snowpark_connect/expression/map_cast.py +123 -5
  26. snowflake/snowpark_connect/expression/map_expression.py +80 -27
  27. snowflake/snowpark_connect/expression/map_extension.py +322 -12
  28. snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
  29. snowflake/snowpark_connect/expression/map_udf.py +85 -20
  30. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
  31. snowflake/snowpark_connect/expression/map_unresolved_function.py +2748 -746
  32. snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
  33. snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
  34. snowflake/snowpark_connect/expression/map_window_function.py +18 -3
  35. snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
  36. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
  37. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
  38. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  39. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  40. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
  41. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
  42. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
  43. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
  44. snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
  45. snowflake/snowpark_connect/relation/io_utils.py +110 -10
  46. snowflake/snowpark_connect/relation/map_aggregate.py +196 -255
  47. snowflake/snowpark_connect/relation/map_catalog.py +5 -1
  48. snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
  49. snowflake/snowpark_connect/relation/map_extension.py +263 -29
  50. snowflake/snowpark_connect/relation/map_join.py +683 -442
  51. snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
  52. snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
  53. snowflake/snowpark_connect/relation/map_relation.py +48 -19
  54. snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
  55. snowflake/snowpark_connect/relation/map_show_string.py +13 -6
  56. snowflake/snowpark_connect/relation/map_sql.py +1233 -222
  57. snowflake/snowpark_connect/relation/map_stats.py +48 -9
  58. snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
  59. snowflake/snowpark_connect/relation/map_udtf.py +14 -4
  60. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
  61. snowflake/snowpark_connect/relation/read/map_read.py +134 -43
  62. snowflake/snowpark_connect/relation/read/map_read_csv.py +255 -45
  63. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
  64. snowflake/snowpark_connect/relation/read/map_read_json.py +320 -85
  65. snowflake/snowpark_connect/relation/read/map_read_parquet.py +142 -27
  66. snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
  67. snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
  68. snowflake/snowpark_connect/relation/read/map_read_table.py +82 -5
  69. snowflake/snowpark_connect/relation/read/map_read_text.py +18 -3
  70. snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
  71. snowflake/snowpark_connect/relation/read/reader_config.py +36 -3
  72. snowflake/snowpark_connect/relation/read/utils.py +50 -5
  73. snowflake/snowpark_connect/relation/stage_locator.py +91 -55
  74. snowflake/snowpark_connect/relation/utils.py +128 -5
  75. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
  76. snowflake/snowpark_connect/relation/write/map_write.py +929 -319
  77. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
  78. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  79. snowflake/snowpark_connect/resources_initializer.py +110 -48
  80. snowflake/snowpark_connect/server.py +546 -456
  81. snowflake/snowpark_connect/server_common/__init__.py +500 -0
  82. snowflake/snowpark_connect/snowflake_session.py +65 -0
  83. snowflake/snowpark_connect/start_server.py +53 -5
  84. snowflake/snowpark_connect/type_mapping.py +349 -27
  85. snowflake/snowpark_connect/typed_column.py +9 -7
  86. snowflake/snowpark_connect/utils/artifacts.py +9 -8
  87. snowflake/snowpark_connect/utils/cache.py +49 -27
  88. snowflake/snowpark_connect/utils/concurrent.py +36 -1
  89. snowflake/snowpark_connect/utils/context.py +187 -37
  90. snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
  91. snowflake/snowpark_connect/utils/env_utils.py +5 -1
  92. snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
  93. snowflake/snowpark_connect/utils/identifiers.py +137 -3
  94. snowflake/snowpark_connect/utils/io_utils.py +57 -1
  95. snowflake/snowpark_connect/utils/java_stored_procedure.py +125 -0
  96. snowflake/snowpark_connect/utils/java_udaf_utils.py +303 -0
  97. snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
  98. snowflake/snowpark_connect/utils/jvm_udf_utils.py +248 -0
  99. snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
  100. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
  101. snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
  102. snowflake/snowpark_connect/utils/profiling.py +25 -8
  103. snowflake/snowpark_connect/utils/scala_udf_utils.py +101 -332
  104. snowflake/snowpark_connect/utils/sequence.py +21 -0
  105. snowflake/snowpark_connect/utils/session.py +64 -28
  106. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
  107. snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
  108. snowflake/snowpark_connect/utils/telemetry.py +163 -22
  109. snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
  110. snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
  111. snowflake/snowpark_connect/utils/udf_cache.py +117 -41
  112. snowflake/snowpark_connect/utils/udf_helper.py +39 -37
  113. snowflake/snowpark_connect/utils/udf_utils.py +133 -14
  114. snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
  115. snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
  116. snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
  117. snowflake/snowpark_connect/version.py +1 -1
  118. snowflake/snowpark_decoder/dp_session.py +6 -2
  119. snowflake/snowpark_decoder/spark_decoder.py +12 -0
  120. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-submit +2 -2
  121. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/METADATA +14 -7
  122. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/RECORD +129 -167
  123. snowflake/snowpark_connect/hidden_column.py +0 -39
  124. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  125. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  126. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  127. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  128. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  129. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  130. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  131. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  132. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  133. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  134. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  135. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  136. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  137. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  138. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  139. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  140. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  141. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  142. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  143. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  144. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  145. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  146. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  147. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  148. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  149. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  150. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  151. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  152. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  153. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  154. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  155. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  156. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  157. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  158. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  159. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  160. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  161. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  162. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  163. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  164. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  165. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  166. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  167. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  168. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
  169. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  170. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  171. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  172. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  173. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  174. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  175. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  176. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  177. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  178. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  179. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  180. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  181. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  182. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  183. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  184. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  185. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  186. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-connect +0 -0
  187. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-session +0 -0
  188. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/WHEEL +0 -0
  189. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE-binary +0 -0
  190. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
  191. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/NOTICE-binary +0 -0
  192. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/utils/scala_udf_utils.py
@@ -15,14 +15,21 @@ Key components:
 - Type mapping functions for different type systems
 - UDF creation and management utilities
 """
-import re
 from dataclasses import dataclass
-from enum import Enum
 from typing import List, Union
 
 import snowflake.snowpark.types as snowpark_type
 import snowflake.snowpark_connect.includes.python.pyspark.sql.connect.proto.types_pb2 as types_proto
-from snowflake.snowpark_connect.resources_initializer import RESOURCE_PATH
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
+from snowflake.snowpark_connect.type_mapping import map_type_to_snowflake_type
+from snowflake.snowpark_connect.utils.jvm_udf_utils import (
+    NullHandling,
+    Param,
+    ReturnType,
+    Signature,
+    build_jvm_udxf_imports,
+)
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.udf_utils import (
     ProcessCommonInlineUserDefinedFunction,
@@ -59,58 +66,6 @@ class ScalaUdf:
         self._return_type = return_type
 
 
-@dataclass(frozen=True)
-class Param:
-    """
-    Represents a function parameter with name and data type.
-
-    Attributes:
-        name: Parameter name
-        data_type: Parameter data type as a string
-    """
-
-    name: str
-    data_type: str
-
-
-@dataclass(frozen=True)
-class NullHandling(str, Enum):
-    """
-    Enumeration for UDF null handling behavior.
-
-    Determines how the UDF behaves when input parameters contain null values.
-    """
-
-    RETURNS_NULL_ON_NULL_INPUT = "RETURNS NULL ON NULL INPUT"
-    CALLED_ON_NULL_INPUT = "CALLED ON NULL INPUT"
-
-
-@dataclass(frozen=True)
-class ReturnType:
-    """
-    Represents the return type of a function.
-
-    Attributes:
-        data_type: Return data type as a string
-    """
-
-    data_type: str
-
-
-@dataclass(frozen=True)
-class Signature:
-    """
-    Represents a function signature with parameters and return type.
-
-    Attributes:
-        params: List of function parameters
-        returns: Function return type
-    """
-
-    params: List[Param]
-    returns: ReturnType
-
-
 @dataclass(frozen=True)
 class ScalaUDFDef:
     """
@@ -147,76 +102,45 @@ class ScalaUDFDef:
             String containing the complete Scala code for the UDF body
         """
         # Convert Array to Seq for Scala compatibility in function signatures.
-        udf_func_input_types = (
-            ", ".join(p.data_type for p in self.scala_signature.params)
-        ).replace("Array", "Seq")
+        # Replace each "Variant" type with "Any" in the function signature since fromVariant returns Any
+        udf_func_input_types = ", ".join(
+            "Any" if p.data_type == "Variant" else p.data_type.replace("Array", "Seq")
+            for p in self.scala_signature.params
+        )
+        udf_func_return_type = self.scala_signature.returns.data_type.replace(
+            "Array", "Seq"
+        )
+
         # Create the Scala arguments and input types string: "arg0: Type0, arg1: Type1, ...".
         joined_wrapper_arg_and_input_types_str = ", ".join(
             f"{p.name}: {p.data_type}" for p in self.scala_signature.params
         )
-        # This is used in defining the input types for the wrapper function. For Maps to work correctly with Scala UDFs,
-        # we need to set the Map types to Map[String, String]. These get cast to the respective original types
-        # when the original UDF function is invoked.
-        wrapper_arg_and_input_types_str = re.sub(
-            pattern=r"Map\[\w+,\s\w+\]",
-            repl="Map[String, String]",
-            string=joined_wrapper_arg_and_input_types_str,
-        )
-        invocation_args = ", ".join(self.scala_invocation_args)
-
-        # Cannot directly return a map from a Scala UDF due to issues with non-String values. Snowflake SQL Scala only
-        # supports Map[String, String] as input types. Therefore, we convert the map to a JSON string before returning.
-        # This is processed as a Variant by SQL.
-        udf_func_return_type = self.scala_signature.returns.data_type
-        is_map_return = udf_func_return_type.startswith("Map")
-        wrapper_return_type = "String" if is_map_return else udf_func_return_type
-
-        # Need to call the map to JSON string converter when a map is returned by the user's function.
-        invoke_udf_func = (
-            f"write(func({invocation_args}))"
-            if is_map_return
-            else f"func({invocation_args})"
-        )
 
-        # The lines of code below are required only when a Map is returned by the UDF. This is needed to serialize the
-        # map output to a JSON string.
-        map_return_imports = (
-            ""
-            if not is_map_return
-            else """
-import org.json4s._
-import org.json4s.native.Serialization._
-import org.json4s.native.Serialization
-"""
-        )
-        map_return_formatter = (
-            ""
-            if not is_map_return
-            else """
-implicit val formats = Serialization.formats(NoTypeHints)
-"""
-        )
+        # All Scala UDFs return Variant to ensure consistency and avoid type conversion issues.
+        wrapper_return_type = "Variant"
+        wrapped_args = [
+            f"udfPacket.fromVariant({arg}, {i})" if p.data_type == "Variant" else arg
+            for i, (arg, p) in enumerate(
+                zip(self.scala_invocation_args, self.scala_signature.params)
+            )
+        ]
+        invocation_args = ", ".join(wrapped_args)
+        invoke_udf_func = f"func({invocation_args})"
+
+        # Always wrap the result in Utils.toVariant() to ensure all Scala UDFs return Variant
+        invoke_udf_func = f"Utils.toVariant({invoke_udf_func})"
 
-        return f"""import org.apache.spark.sql.connect.common.UdfPacket
-{map_return_imports}
-import java.io.{{ByteArrayInputStream, ObjectInputStream}}
-import java.nio.file.{{Files, Paths}}
+        return f"""
+import org.apache.spark.sql.connect.common.UdfPacket
+import com.snowflake.sas.scala.UdfPacketUtils._
+import com.snowflake.sas.scala.Utils
+import com.snowflake.snowpark_java.types.Variant
 
 object __RecreatedSparkUdf {{
-{map_return_formatter}
-    private lazy val func: ({udf_func_input_types}) => {udf_func_return_type} = {{
-        val importDirectory = System.getProperty("com.snowflake.import_directory")
-        val fPath = importDirectory + "{self.name}.bin"
-        val bytes = Files.readAllBytes(Paths.get(fPath))
-        val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
-        try {{
-            ois.readObject().asInstanceOf[UdfPacket].function.asInstanceOf[({udf_func_input_types}) => {udf_func_return_type}]
-        }} finally {{
-            ois.close()
-        }}
-    }}
+    private lazy val udfPacket: UdfPacket = Utils.deserializeUdfPacket("{self.name}.bin")
+    private lazy val func: ({udf_func_input_types}) => {udf_func_return_type} = udfPacket.function.asInstanceOf[({udf_func_input_types}) => {udf_func_return_type}]
 
-    def __wrapperFunc({wrapper_arg_and_input_types_str}): {wrapper_return_type} = {{
+    def __wrapperFunc({joined_wrapper_arg_and_input_types_str}): {wrapper_return_type} = {{
         {invoke_udf_func}
     }}
 }}
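To make the new template concrete, here is a rough sketch (not taken from the package) of the body it would generate for a hypothetical UDF whose Spark signature is (Int, Map[String, Long]) => Map[String, Long]; the file name "abc123.bin" and the argument names are illustrative placeholders:

```python
# Illustrative only: approximate Scala text the template above would emit for a
# hypothetical (Int, Map[String, Long]) => Map[String, Long] UDF. The map input
# is declared as Variant in the wrapper and recovered via udfPacket.fromVariant;
# the result is always wrapped in Utils.toVariant.
generated_body = """
import org.apache.spark.sql.connect.common.UdfPacket
import com.snowflake.sas.scala.UdfPacketUtils._
import com.snowflake.sas.scala.Utils
import com.snowflake.snowpark_java.types.Variant

object __RecreatedSparkUdf {
  private lazy val udfPacket: UdfPacket = Utils.deserializeUdfPacket("abc123.bin")
  private lazy val func: (Int, Any) => Map[String, Long] =
    udfPacket.function.asInstanceOf[(Int, Any) => Map[String, Long]]

  def __wrapperFunc(arg0: Int, arg1: Variant): Variant = {
    Utils.toVariant(func(arg0, udfPacket.fromVariant(arg1, 1)))
  }
}
"""
print(generated_body)
```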
@@ -260,70 +184,6 @@ $$
 $$;"""
 
 
-def build_scala_udf_imports(session, payload, udf_name, is_map_return) -> List[str]:
-    """
-    Build the list of imports needed for the Scala UDF.
-
-    This function:
-    1. Saves the UDF payload to a binary file in the session stage
-    2. Collects user-uploaded JAR files from the stage
-    3. Returns a list of all required JAR files for the UDF
-
-    Args:
-        session: Snowpark session
-        payload: Binary payload containing the serialized UDF
-        udf_name: Name of the UDF (used for the binary file name)
-        is_map_return: Indicates if the UDF returns a Map (affects imports)
-
-    Returns:
-        List of JAR file paths to be imported by the UDF
-    """
-    # Save pciudf._payload to a bin file:
-    import io
-
-    payload_as_stream = io.BytesIO(payload)
-    stage = session.get_session_stage()
-    stage_resource_path = stage + RESOURCE_PATH
-    closure_binary_file = stage_resource_path + "/" + udf_name + ".bin"
-    session.file.put_stream(
-        payload_as_stream,
-        closure_binary_file,
-        overwrite=True,
-    )
-
-    # Get a list of the jar files uploaded to the stage. We need to import the user's jar for the Scala UDF.
-    res = session.sql(rf"LIST {stage}/ PATTERN='.*\.jar';").collect()
-    user_jars = []
-    for row in res:
-        if RESOURCE_PATH not in row[0]:
-            # Remove the stage path since it is not properly formatted.
-            user_jars.append(row[0][row[0].find("/") :])
-
-    # Jars used when the return type is a Map.
-    map_jars = (
-        []
-        if not is_map_return
-        else [
-            f"{stage_resource_path}/json4s-core_2.12-3.7.0-M11.jar",
-            f"{stage_resource_path}/json4s-native_2.12-3.7.0-M11.jar",
-            f"{stage_resource_path}/paranamer-2.8.3.jar",
-        ]
-    )
-
-    # Format the user jars to be used in the IMPORTS clause of the stored procedure.
-    return (
-        [
-            closure_binary_file,
-            f"{stage_resource_path}/spark-connect-client-jvm_2.12-3.5.6.jar",
-            f"{stage_resource_path}/spark-common-utils_2.12-3.5.6.jar",
-            f"{stage_resource_path}/spark-sql_2.12-3.5.6.jar",
-            f"{stage_resource_path}/json4s-ast_2.12-3.7.0-M11.jar",
-        ]
-        + map_jars
-        + [f"{stage + jar}" for jar in user_jars]
-    )
-
-
 def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf:
     """
     Create a Scala UDF in Snowflake from a ProcessCommonInlineUserDefinedFunction object.
@@ -343,7 +203,13 @@ def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf
     Returns:
         A ScalaUdf object representing the created or cached Scala UDF.
     """
-    from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
+    from snowflake.snowpark_connect.resources_initializer import (
+        ensure_scala_udf_jars_uploaded,
+    )
+
+    # Lazily upload Scala UDF jars on-demand when a Scala UDF is actually created.
+    # This is thread-safe and will only upload once even if multiple threads call it.
+    ensure_scala_udf_jars_uploaded()
 
     function_name = pciudf._function_name
     # If a function name is not provided, hash the binary file and use the first ten characters as the function name.
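The "upload once, thread-safe" contract noted in that comment is the classic double-checked guard. A minimal sketch of the pattern, assuming nothing about the real ensure_scala_udf_jars_uploaded() beyond what the comment states:

```python
# Minimal sketch of an upload-once guard; the actual implementation in
# resources_initializer may differ.
import threading

_jars_uploaded = False
_upload_lock = threading.Lock()

def ensure_uploaded_once(upload_fn) -> None:
    """Run upload_fn at most once, even when called from many threads."""
    global _jars_uploaded
    if _jars_uploaded:  # fast path: a previous call already finished the upload
        return
    with _upload_lock:
        if not _jars_uploaded:  # re-check under the lock to avoid a duplicate upload
            upload_fn()
            _jars_uploaded = True
```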
@@ -353,11 +219,6 @@
         function_name = hashlib.sha256(pciudf._payload).hexdigest()[:10]
     udf_name = CREATE_SCALA_UDF_PREFIX + function_name
 
-    session = get_or_create_snowpark_session()
-    if udf_name in session._udfs:
-        cached_udf = session._udfs[udf_name]
-        return ScalaUdf(cached_udf.name, cached_udf.input_types, cached_udf.return_type)
-
     # In case the Scala UDF was created with `spark.udf.register`, the Spark Scala input types (from protobuf) are
     # stored in pciudf.scala_input_types.
     # We cannot rely solely on the inputTypes field from the Scala UDF or the Snowpark input types, since:
@@ -376,30 +237,40 @@
         param_name = "arg" + str(i)
         # Create the Scala arguments and input types string: "arg0: Type0, arg1: Type1, ...".
         scala_input_params.append(
-            Param(param_name, map_type_to_scala_type(input_type))
+            Param(param_name, _map_type_to_scala_type(input_type, is_input=True))
         )
         # Create the Snowflake SQL arguments and input types string: "arg0 TYPE0, arg1 TYPE1, ...".
-        sql_input_params.append(
-            Param(param_name, map_type_to_snowflake_type(input_type))
+        # For arrays and structs, use VARIANT type in SQL signature
+        is_snowpark_type = isinstance(input_type, snowpark_type.DataType)
+        is_array = (
+            is_snowpark_type and isinstance(input_type, snowpark_type.ArrayType)
+        ) or (not is_snowpark_type and input_type.WhichOneof("kind") == "array")
+        is_map = (
+            is_snowpark_type and isinstance(input_type, snowpark_type.MapType)
+        ) or (not is_snowpark_type and input_type.WhichOneof("kind") == "map")
+        sql_type = (
+            "VARIANT"
+            if is_array or is_map
+            else map_type_to_snowflake_type(input_type)
         )
+        sql_input_params.append(Param(param_name, sql_type))
         # In the case of Map input types, we need to cast the argument to the correct type in Scala.
-        # Snowflake SQL Scala can only handle MAP[VARCHAR, VARCHAR] as input types.
-        scala_invocation_args.append(
-            cast_scala_map_args_from_given_type(param_name, input_type)
-        )
+        scala_invocation_args.append(param_name)
+
+    scala_return_type = _map_type_to_scala_type(
+        pciudf._original_return_type, is_input=False
+    )
+    # All Scala UDFs now return VARIANT to ensure consistency and avoid type conversion issues.
+    # The actual type conversion is handled after the UDF is called.
+    from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 
-    scala_return_type = map_type_to_scala_type(pciudf._original_return_type)
-    # If the SQL return type is a MAP, change this to VARIANT because of issues with Scala UDFs.
-    sql_return_type = map_type_to_snowflake_type(pciudf._original_return_type)
-    imports = build_scala_udf_imports(
+    session = get_or_create_snowpark_session()
+    imports = build_jvm_udxf_imports(
         session,
         pciudf._payload,
         udf_name,
-        is_map_return=sql_return_type.startswith("MAP"),
-    )
-    sql_return_type = (
-        "VARIANT" if sql_return_type.startswith("MAP") else sql_return_type
     )
+    sql_return_type = "VARIANT"
 
     udf_def = ScalaUDFDef(
         name=udf_name,
@@ -418,21 +289,30 @@
     return ScalaUdf(udf_name, pciudf._input_types, pciudf._return_type)
 
 
-def map_type_to_scala_type(
-    t: Union[snowpark_type.DataType, types_proto.DataType]
+def _map_type_to_scala_type(
+    t: Union[snowpark_type.DataType, types_proto.DataType], is_input: bool = False
 ) -> str:
-    """Maps a Snowpark or Spark protobuf type to a Scala type string."""
+    """Maps a Snowpark or Spark protobuf type to a Scala type string.
+
+    Args:
+        t: The type to map
+        is_input: If True, maps array types to Variant (for UDF inputs).
+            If False, maps array types to Array[ElementType] (for UDF outputs).
+    """
     if not t:
         return "String"
     is_snowpark_type = isinstance(t, snowpark_type.DataType)
     condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
     match condition:
         case snowpark_type.ArrayType | "array":
-            return (
-                f"Array[{map_type_to_scala_type(t.element_type)}]"
-                if is_snowpark_type
-                else f"Array[{map_type_to_scala_type(t.array.element_type)}]"
-            )
+            if is_input:
+                return "Variant"
+            else:
+                return (
+                    f"Array[{_map_type_to_scala_type(t.element_type, is_input=False)}]"
+                    if is_snowpark_type
+                    else f"Array[{_map_type_to_scala_type(t.array.element_type, is_input=False)}]"
+                )
         case snowpark_type.BinaryType | "binary":
             return "Array[Byte]"
         case snowpark_type.BooleanType | "boolean":
@@ -453,16 +333,18 @@ def map_type_to_scala_type(
             return "Int"
         case snowpark_type.LongType | "long":
             return "Long"
-        case snowpark_type.MapType | "map":  # can also map to OBJECT in Snowflake
+        case snowpark_type.MapType | "map":
+            if is_input:
+                return "Variant"
             key_type = (
-                map_type_to_scala_type(t.key_type)
+                _map_type_to_scala_type(t.key_type)
                 if is_snowpark_type
-                else map_type_to_scala_type(t.map.key_type)
+                else _map_type_to_scala_type(t.map.key_type)
             )
             value_type = (
-                map_type_to_scala_type(t.value_type)
+                _map_type_to_scala_type(t.value_type)
                 if is_snowpark_type
-                else map_type_to_scala_type(t.map.value_type)
+                else _map_type_to_scala_type(t.map.value_type)
             )
             return f"Map[{key_type}, {value_type}]"
         case snowpark_type.NullType | "null":
@@ -471,126 +353,13 @@
             return "Short"
         case snowpark_type.StringType | "string" | "char" | "varchar":
             return "String"
+        case snowpark_type.StructType | "struct":
+            return "Variant"
         case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
             return "java.sql.Timestamp"
         case snowpark_type.VariantType:
             return "Variant"
         case _:
-            raise ValueError(f"Unsupported Snowpark type: {t}")
-
-
-def map_type_to_snowflake_type(
-    t: Union[snowpark_type.DataType, types_proto.DataType]
-) -> str:
-    """Maps a Snowpark or Spark protobuf type to a Snowflake type string."""
-    if not t:
-        return "VARCHAR"
-    is_snowpark_type = isinstance(t, snowpark_type.DataType)
-    condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
-    match condition:
-        case snowpark_type.ArrayType | "array":
-            return (
-                f"ARRAY({map_type_to_snowflake_type(t.element_type)})"
-                if is_snowpark_type
-                else f"ARRAY({map_type_to_snowflake_type(t.array.element_type)})"
-            )
-        case snowpark_type.BinaryType | "binary":
-            return "BINARY"
-        case snowpark_type.BooleanType | "boolean":
-            return "BOOLEAN"
-        case snowpark_type.ByteType | "byte":
-            return "TINYINT"
-        case snowpark_type.DateType | "date":
-            return "DATE"
-        case snowpark_type.DecimalType | "decimal":
-            return "NUMBER"
-        case snowpark_type.DoubleType | "double":
-            return "DOUBLE"
-        case snowpark_type.FloatType | "float":
-            return "FLOAT"
-        case snowpark_type.GeographyType:
-            return "GEOGRAPHY"
-        case snowpark_type.IntegerType | "integer":
-            return "INT"
-        case snowpark_type.LongType | "long":
-            return "BIGINT"
-        case snowpark_type.MapType | "map":
-            # Maps to OBJECT in Snowflake if key and value types are not specified.
-            key_type = (
-                map_type_to_snowflake_type(t.key_type)
-                if is_snowpark_type
-                else map_type_to_snowflake_type(t.map.key_type)
-            )
-            value_type = (
-                map_type_to_snowflake_type(t.value_type)
-                if is_snowpark_type
-                else map_type_to_snowflake_type(t.map.value_type)
-            )
-            return (
-                f"MAP({key_type}, {value_type})"
-                if key_type and value_type
-                else "OBJECT"
-            )
-        case snowpark_type.NullType | "null":
-            return "VARCHAR"
-        case snowpark_type.ShortType | "short":
-            return "SMALLINT"
-        case snowpark_type.StringType | "string" | "char" | "varchar":
-            return "VARCHAR"
-        case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
-            return "TIMESTAMP"
-        case snowpark_type.VariantType:
-            return "VARIANT"
-        case _:
-            raise ValueError(f"Unsupported Snowpark type: {t}")
-
-
-def cast_scala_map_args_from_given_type(
-    arg_name: str, input_type: Union[snowpark_type.DataType, types_proto.DataType]
-) -> str:
-    """If the input_type is a Map, cast the argument arg_name to a Map[key_type, value_type] in Scala."""
-    is_snowpark_type = isinstance(input_type, snowpark_type.DataType)
-
-    def convert_from_string_to_type(
-        arg_name: str, t: Union[snowpark_type.DataType, types_proto.DataType]
-    ) -> str:
-        """Convert the string argument arg_name to the specified type t in Scala."""
-        condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
-        match condition:
-            case snowpark_type.BinaryType | "binary":
-                return arg_name + ".getBytes()"
-            case snowpark_type.BooleanType | "boolean":
-                return arg_name + ".toBoolean"
-            case snowpark_type.ByteType | "byte":
-                return arg_name + ".getBytes().head"  # TODO: verify if this is correct
-            case snowpark_type.DateType | "date":
-                return f"java.sql.Date.valueOf({arg_name})"
-            case snowpark_type.DecimalType | "decimal":
-                return f"new BigDecimal({arg_name})"
-            case snowpark_type.DoubleType | "double":
-                return arg_name + ".toDouble"
-            case snowpark_type.FloatType | "float":
-                return arg_name + ".toFloat"
-            case snowpark_type.IntegerType | "integer":
-                return arg_name + ".toInt"
-            case snowpark_type.LongType | "long":
-                return arg_name + ".toLong"
-            case snowpark_type.ShortType | "short":
-                return arg_name + ".toShort"
-            case snowpark_type.StringType | "string" | "char" | "varchar":
-                return arg_name
-            case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
-                return "java.sql.Timestamp.valueOf({arg_name})"
-            case _:
-                raise ValueError(f"Unsupported Snowpark type: {t}")
-
-    if (is_snowpark_type and isinstance(input_type, snowpark_type.MapType)) or (
-        not is_snowpark_type and input_type.WhichOneof("kind") == "map"
-    ):
-        key_type = input_type.key_type if is_snowpark_type else input_type.map.key_type
-        value_type = (
-            input_type.value_type if is_snowpark_type else input_type.map.value_type
-        )
-        return f"{arg_name}.map {{ case (k, v) => ({convert_from_string_to_type('k', key_type)}, {convert_from_string_to_type('v', value_type)})}}"
-    else:
-        return arg_name
+            exception = ValueError(f"Unsupported Snowpark type: {t}")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
+            raise exception
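The net effect of the is_input flag introduced above: container inputs collapse to Variant (the SQL signature passes them as VARIANT), while outputs keep their element types. A standalone sketch of that asymmetry, using plain strings rather than the real Snowpark/protobuf type objects:

```python
# Standalone illustration of the is_input asymmetry in _map_type_to_scala_type;
# types are modeled as strings here, not Snowpark/protobuf objects.
def scala_type_sketch(kind: str, element: str = "Int", is_input: bool = False) -> str:
    if kind == "array":
        return "Variant" if is_input else f"Array[{element}]"
    if kind == "map":
        return "Variant" if is_input else f"Map[String, {element}]"
    if kind == "struct":
        return "Variant"  # structs are Variant in both directions
    return "String"

assert scala_type_sketch("array", is_input=True) == "Variant"      # UDF input
assert scala_type_sketch("array", is_input=False) == "Array[Int]"  # UDF output
```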
snowflake/snowpark_connect/utils/sequence.py (new file)
@@ -0,0 +1,21 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+import threading
+from collections import defaultdict
+
+from snowflake.snowpark_connect.utils.context import get_spark_session_id
+
+# per session number sequences to generate unique snowpark columns
+_session_sequences = defaultdict(int)
+
+_lock = threading.Lock()
+
+
+def next_unique_num():
+    session_id = get_spark_session_id()
+    with _lock:
+        next_num = _session_sequences[session_id]
+        _session_sequences[session_id] = next_num + 1
+    return next_num
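For context, the new sequence module gives each Spark session its own monotonically increasing counter for generating unique Snowpark column names. A behavior sketch, with the session id passed explicitly instead of read from get_spark_session_id():

```python
# Behavior sketch of utils/sequence.py: one counter per session, guarded by a lock.
import threading
from collections import defaultdict

_session_sequences = defaultdict(int)
_lock = threading.Lock()

def next_unique_num(session_id: str) -> int:
    with _lock:
        next_num = _session_sequences[session_id]
        _session_sequences[session_id] = next_num + 1
    return next_num

assert [next_unique_num("session-a") for _ in range(3)] == [0, 1, 2]
assert next_unique_num("session-b") == 0  # sessions never share numbers
```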