snowpark-connect 0.27.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  3. snowflake/snowpark_connect/client/__init__.py +15 -0
  4. snowflake/snowpark_connect/client/error_utils.py +30 -0
  5. snowflake/snowpark_connect/client/exceptions.py +36 -0
  6. snowflake/snowpark_connect/client/query_results.py +90 -0
  7. snowflake/snowpark_connect/client/server.py +717 -0
  8. snowflake/snowpark_connect/client/utils/__init__.py +10 -0
  9. snowflake/snowpark_connect/client/utils/session.py +85 -0
  10. snowflake/snowpark_connect/column_name_handler.py +404 -243
  11. snowflake/snowpark_connect/column_qualifier.py +43 -0
  12. snowflake/snowpark_connect/config.py +309 -26
  13. snowflake/snowpark_connect/constants.py +2 -0
  14. snowflake/snowpark_connect/dataframe_container.py +102 -8
  15. snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
  16. snowflake/snowpark_connect/error/error_codes.py +50 -0
  17. snowflake/snowpark_connect/error/error_utils.py +172 -23
  18. snowflake/snowpark_connect/error/exceptions.py +13 -4
  19. snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
  20. snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
  21. snowflake/snowpark_connect/execute_plan/utils.py +5 -1
  22. snowflake/snowpark_connect/expression/error_utils.py +28 -0
  23. snowflake/snowpark_connect/expression/function_defaults.py +9 -2
  24. snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
  25. snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
  26. snowflake/snowpark_connect/expression/literal.py +37 -13
  27. snowflake/snowpark_connect/expression/map_cast.py +224 -15
  28. snowflake/snowpark_connect/expression/map_expression.py +80 -27
  29. snowflake/snowpark_connect/expression/map_extension.py +322 -12
  30. snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
  31. snowflake/snowpark_connect/expression/map_udf.py +86 -20
  32. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
  33. snowflake/snowpark_connect/expression/map_unresolved_function.py +2964 -829
  34. snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
  35. snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
  36. snowflake/snowpark_connect/expression/map_window_function.py +18 -3
  37. snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
  39. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/{spark-connect-client-jvm_2.12-3.5.6.jar → spark-connect-client-jvm_2.13-3.5.6.jar} +0 -0
  43. snowflake/snowpark_connect/includes/jars/{spark-sql_2.12-3.5.6.jar → spark-sql_2.13-3.5.6.jar} +0 -0
  44. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
  45. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
  46. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  47. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  48. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
  49. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
  50. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
  51. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
  52. snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
  53. snowflake/snowpark_connect/relation/io_utils.py +110 -10
  54. snowflake/snowpark_connect/relation/map_aggregate.py +239 -256
  55. snowflake/snowpark_connect/relation/map_catalog.py +5 -1
  56. snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
  57. snowflake/snowpark_connect/relation/map_extension.py +263 -29
  58. snowflake/snowpark_connect/relation/map_join.py +683 -442
  59. snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
  60. snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
  61. snowflake/snowpark_connect/relation/map_relation.py +48 -19
  62. snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
  63. snowflake/snowpark_connect/relation/map_show_string.py +13 -6
  64. snowflake/snowpark_connect/relation/map_sql.py +1233 -222
  65. snowflake/snowpark_connect/relation/map_stats.py +48 -9
  66. snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
  67. snowflake/snowpark_connect/relation/map_udtf.py +14 -4
  68. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
  69. snowflake/snowpark_connect/relation/read/map_read.py +134 -43
  70. snowflake/snowpark_connect/relation/read/map_read_csv.py +326 -47
  71. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +21 -6
  72. snowflake/snowpark_connect/relation/read/map_read_json.py +324 -86
  73. snowflake/snowpark_connect/relation/read/map_read_parquet.py +146 -28
  74. snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
  75. snowflake/snowpark_connect/relation/read/map_read_socket.py +15 -3
  76. snowflake/snowpark_connect/relation/read/map_read_table.py +86 -6
  77. snowflake/snowpark_connect/relation/read/map_read_text.py +22 -4
  78. snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
  79. snowflake/snowpark_connect/relation/read/reader_config.py +42 -3
  80. snowflake/snowpark_connect/relation/read/utils.py +50 -5
  81. snowflake/snowpark_connect/relation/stage_locator.py +91 -55
  82. snowflake/snowpark_connect/relation/utils.py +128 -5
  83. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
  84. snowflake/snowpark_connect/relation/write/map_write.py +929 -319
  85. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
  86. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  87. snowflake/snowpark_connect/resources_initializer.py +171 -48
  88. snowflake/snowpark_connect/server.py +528 -473
  89. snowflake/snowpark_connect/server_common/__init__.py +503 -0
  90. snowflake/snowpark_connect/snowflake_session.py +65 -0
  91. snowflake/snowpark_connect/start_server.py +53 -5
  92. snowflake/snowpark_connect/type_mapping.py +349 -27
  93. snowflake/snowpark_connect/type_support.py +130 -0
  94. snowflake/snowpark_connect/typed_column.py +9 -7
  95. snowflake/snowpark_connect/utils/artifacts.py +9 -8
  96. snowflake/snowpark_connect/utils/cache.py +49 -27
  97. snowflake/snowpark_connect/utils/concurrent.py +36 -1
  98. snowflake/snowpark_connect/utils/context.py +195 -37
  99. snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
  100. snowflake/snowpark_connect/utils/env_utils.py +5 -1
  101. snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
  102. snowflake/snowpark_connect/utils/identifiers.py +137 -3
  103. snowflake/snowpark_connect/utils/io_utils.py +57 -1
  104. snowflake/snowpark_connect/utils/java_stored_procedure.py +151 -0
  105. snowflake/snowpark_connect/utils/java_udaf_utils.py +321 -0
  106. snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
  107. snowflake/snowpark_connect/utils/jvm_udf_utils.py +281 -0
  108. snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
  109. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
  110. snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
  111. snowflake/snowpark_connect/utils/profiling.py +25 -8
  112. snowflake/snowpark_connect/utils/scala_udf_utils.py +185 -340
  113. snowflake/snowpark_connect/utils/sequence.py +21 -0
  114. snowflake/snowpark_connect/utils/session.py +64 -28
  115. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
  116. snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
  117. snowflake/snowpark_connect/utils/telemetry.py +192 -40
  118. snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
  119. snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
  120. snowflake/snowpark_connect/utils/udf_cache.py +117 -41
  121. snowflake/snowpark_connect/utils/udf_helper.py +39 -37
  122. snowflake/snowpark_connect/utils/udf_utils.py +133 -14
  123. snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
  124. snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
  125. snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
  126. snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
  127. snowflake/snowpark_connect/version.py +1 -1
  128. snowflake/snowpark_decoder/dp_session.py +6 -2
  129. snowflake/snowpark_decoder/spark_decoder.py +12 -0
  130. {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +14 -4
  131. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +16 -7
  132. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +139 -168
  133. snowflake/snowpark_connect/hidden_column.py +0 -39
  134. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  135. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  136. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  137. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  138. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  139. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  140. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  141. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  142. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  143. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  144. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  145. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  146. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  147. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  148. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  149. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  150. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  151. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  152. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  153. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  154. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  155. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  156. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  157. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  158. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  159. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  160. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  161. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  162. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  163. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  164. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  165. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  166. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  167. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  168. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  169. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  170. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  171. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  172. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  173. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  174. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  175. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  176. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  177. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  178. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  179. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  180. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  181. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  182. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  183. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  184. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  185. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  186. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  187. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  188. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  189. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  190. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  191. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  192. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  193. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  194. {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
  195. {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
  196. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
  197. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
  198. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  199. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
  200. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
@@ -7,20 +7,23 @@ import pyspark.sql.connect.proto.types_pb2 as types_proto
7
7
 
8
8
  import snowflake.snowpark.functions as snowpark_fn
9
9
  from snowflake import snowpark
10
- from snowflake.snowpark.types import MapType, StructType, VariantType
10
+ from snowflake.snowpark.types import ArrayType, MapType, StructType, VariantType
11
11
  from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
12
12
  from snowflake.snowpark_connect.config import global_config
13
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
14
+ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
13
15
  from snowflake.snowpark_connect.expression.typer import ExpressionTyper
14
16
  from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
15
17
  from snowflake.snowpark_connect.typed_column import TypedColumn
18
+ from snowflake.snowpark_connect.utils.context import get_grouping_by_scala_udf_key
16
19
  from snowflake.snowpark_connect.utils.external_udxf_cache import (
17
20
  cache_external_udf,
18
21
  get_external_udf_from_cache,
19
22
  )
23
+ from snowflake.snowpark_connect.utils.java_stored_procedure import create_java_udf
20
24
  from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
21
25
  from snowflake.snowpark_connect.utils.udf_helper import (
22
26
  SnowparkUDF,
23
- gen_input_types,
24
27
  infer_snowpark_arguments,
25
28
  process_udf_in_sproc,
26
29
  require_creating_udf_in_sproc,
@@ -53,8 +56,14 @@ def cache_external_udf_wrapper(from_register_udf: bool):
53
56
  session._udfs[udf_proto.function_name.lower()] = cached_udf
54
57
  case "python_udf":
55
58
  pass
59
+ case "java_udf":
60
+ session._udfs[udf_proto.function_name.lower()] = cached_udf
56
61
  case _:
57
- raise ValueError(f"Unsupported UDF type: {function_type}")
62
+ exception = ValueError(f"Unsupported UDF type: {function_type}")
63
+ attach_custom_error_code(
64
+ exception, ErrorCodes.UNSUPPORTED_OPERATION
65
+ )
66
+ raise exception
58
67
 
59
68
  return cached_udf
60
69
 
@@ -94,13 +103,43 @@ def register_udf(
94
103
  match udf_proto.WhichOneof("function"):
95
104
  case "python_udf":
96
105
  output_type = udf_proto.python_udf.output_type
106
+ processed_return_type, original_return_type = process_udf_return_type(
107
+ output_type
108
+ )
97
109
  case "scalar_scala_udf":
110
+ # For Scala UDFs, always use VariantType as the processed type since all Scala UDFs
111
+ # return Variant. The actual type conversion happens after the UDF call.
98
112
  output_type = udf_proto.scalar_scala_udf.outputType
113
+ original_return_type = proto_to_snowpark_type(output_type)
114
+ processed_return_type = VariantType()
115
+ case "java_udf":
116
+ has_output_type = udf_proto.java_udf.HasField("output_type")
117
+ session = get_or_create_snowpark_session()
118
+ java_udf = create_java_udf(
119
+ session,
120
+ udf_proto.function_name,
121
+ udf_proto.java_udf.class_name,
122
+ )
123
+ original_return_type = java_udf._return_type
124
+ if has_output_type:
125
+ original_return_type = proto_to_snowpark_type(
126
+ udf_proto.java_udf.output_type
127
+ )
128
+ udf = SnowparkUDF(
129
+ name=java_udf.name,
130
+ input_types=java_udf._input_types,
131
+ return_type=java_udf._return_type,
132
+ original_return_type=original_return_type,
133
+ cast_to_original_return_type=True,
134
+ )
135
+ session._udfs[udf_proto.function_name.lower()] = udf
136
+ return udf
99
137
  case _:
100
- raise ValueError(
138
+ exception = ValueError(
101
139
  f"Unsupported UDF type: {udf_proto.WhichOneof('function')}"
102
140
  )
103
- processed_return_type, original_return_type = process_udf_return_type(output_type)
141
+ attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
142
+ raise exception
104
143
  session = get_or_create_snowpark_session()
105
144
  kwargs = {
106
145
  "common_inline_user_defined_function": udf_proto,
@@ -116,11 +155,15 @@ def register_udf(
116
155
  else:
117
156
  udf_processor = ProcessCommonInlineUserDefinedFunction(**kwargs)
118
157
  udf = udf_processor.create_udf()
158
+ is_scala_udf = udf_proto.WhichOneof("function") == "scalar_scala_udf"
159
+
119
160
  udf = SnowparkUDF(
120
161
  name=udf.name,
121
162
  input_types=udf._input_types,
122
163
  return_type=udf._return_type,
123
164
  original_return_type=original_return_type,
165
+ cast_to_original_return_type=is_scala_udf
166
+ or udf._return_type == VariantType(),
124
167
  )
125
168
  session._udfs[udf_proto.function_name.lower()] = udf
126
169
  # scala udfs can be also accessed using `udf.name`
@@ -136,19 +179,22 @@ def map_common_inline_user_defined_udf(
136
179
  ) -> tuple[str, TypedColumn]:
137
180
  udf_proto = exp.common_inline_user_defined_function
138
181
  udf_check(udf_proto)
139
- snowpark_udf_arg_names, snowpark_udf_args = infer_snowpark_arguments(
182
+ snowpark_udf_arg_names, snowpark_udf_typed_args = infer_snowpark_arguments(
140
183
  udf_proto, column_mapping, typer
141
184
  )
142
- input_types = gen_input_types(snowpark_udf_args, typer)
185
+ input_types = [a.typ for a in snowpark_udf_typed_args]
143
186
  match udf_proto.WhichOneof("function"):
144
187
  case "python_udf":
145
188
  processed_return_type, original_return_type = process_udf_return_type(
146
189
  udf_proto.python_udf.output_type
147
190
  )
148
191
  case "scalar_scala_udf":
149
- processed_return_type, original_return_type = process_udf_return_type(
192
+ # For Scala UDFs, always use VariantType as the processed type since all Scala UDFs
193
+ # return Variant. The actual type conversion happens after the UDF call.
194
+ original_return_type = proto_to_snowpark_type(
150
195
  udf_proto.scalar_scala_udf.outputType
151
196
  )
197
+ processed_return_type = VariantType()
152
198
 
153
199
  @cache_external_udf_wrapper(from_register_udf=False)
154
200
  def get_snowpark_udf(
@@ -178,24 +224,44 @@ def map_common_inline_user_defined_udf(
178
224
  return snowpark_udf
179
225
 
180
226
  snowpark_udf = get_snowpark_udf(udf_proto)
181
- udf_call_expr = snowpark_fn.call_udf(snowpark_udf.name, *snowpark_udf_args)
227
+ # Determine if we need to cast the result back to the original type
228
+ is_scala_udf = udf_proto.WhichOneof("function") == "scalar_scala_udf"
182
229
 
183
- # If the original return type was MapType or StructType but we converted it to VariantType,
184
- # we need to parse the JSON result back to the original type
185
- if isinstance(original_return_type, (MapType, StructType)) and isinstance(
230
+ # For structured types (arrays, structs, maps), use to_variant instead of cast
231
+ # to ensure proper conversion to VARIANT type for Scala UDFS
232
+ converted_args = []
233
+ for tc in snowpark_udf_typed_args:
234
+ if is_scala_udf and isinstance(tc.typ, (ArrayType, StructType, MapType)):
235
+ converted_args.append(snowpark_fn.to_variant(tc.col))
236
+ else:
237
+ converted_args.append(tc.col)
238
+
239
+ udf_call_expr = snowpark_fn.call_udf(snowpark_udf.name, *converted_args)
240
+
241
+ # For Scala UDFs, always cast from Variant to the original type
242
+ # For Python UDFs, only cast if the original type was MapType or StructType
243
+ if is_scala_udf:
244
+ # All Scala UDFs return Variant, so we always need to cast back to the original type
245
+ result_expr = snowpark_fn.cast(udf_call_expr, original_return_type)
246
+ result_type = original_return_type
247
+
248
+ elif isinstance(original_return_type, (MapType, StructType)) and isinstance(
186
249
  processed_return_type, VariantType
187
250
  ):
188
- # Parse JSON and cast back to original type
251
+ # Parse JSON and cast back to original type for Python UDFs
189
252
  result_expr = snowpark_fn.parse_json(udf_call_expr).cast(original_return_type)
190
253
  result_type = original_return_type
191
254
  else:
192
255
  result_expr = udf_call_expr
193
256
  result_type = snowpark_udf.return_type
194
257
 
195
- return (
196
- f"{udf_proto.function_name}({', '.join(snowpark_udf_arg_names)})",
197
- TypedColumn(
198
- result_expr,
199
- lambda: [result_type],
200
- ),
201
- )
258
+ name = f"{udf_proto.function_name}({', '.join(snowpark_udf_arg_names)})"
259
+ if get_grouping_by_scala_udf_key() and not isinstance(
260
+ original_return_type, StructType
261
+ ):
262
+ name = (
263
+ "value"
264
+ if global_config.spark_sql_legacy_dataset_nameNonStructGroupingKeyAsValue
265
+ else "key"
266
+ )
267
+ return (name, TypedColumn(result_expr, lambda: [result_type]))