snowpark-connect 0.27.0__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +717 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +309 -26
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/error_utils.py +28 -0
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +224 -15
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +86 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2964 -829
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-connect-client-jvm_2.12-3.5.6.jar → spark-connect-client-jvm_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-sql_2.12-3.5.6.jar → spark-sql_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +239 -256
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +326 -47
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +21 -6
- snowflake/snowpark_connect/relation/read/map_read_json.py +324 -86
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +146 -28
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +15 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +86 -6
- snowflake/snowpark_connect/relation/read/map_read_text.py +22 -4
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +42 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +171 -48
- snowflake/snowpark_connect/server.py +528 -473
- snowflake/snowpark_connect/server_common/__init__.py +503 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/type_support.py +130 -0
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +195 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +151 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +321 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +281 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +185 -340
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +192 -40
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +14 -4
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +16 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +139 -168
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -7,20 +7,23 @@ import pyspark.sql.connect.proto.types_pb2 as types_proto
|
|
|
7
7
|
|
|
8
8
|
import snowflake.snowpark.functions as snowpark_fn
|
|
9
9
|
from snowflake import snowpark
|
|
10
|
-
from snowflake.snowpark.types import MapType, StructType, VariantType
|
|
10
|
+
from snowflake.snowpark.types import ArrayType, MapType, StructType, VariantType
|
|
11
11
|
from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
|
|
12
12
|
from snowflake.snowpark_connect.config import global_config
|
|
13
|
+
from snowflake.snowpark_connect.error.error_codes import ErrorCodes
|
|
14
|
+
from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
|
|
13
15
|
from snowflake.snowpark_connect.expression.typer import ExpressionTyper
|
|
14
16
|
from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
|
|
15
17
|
from snowflake.snowpark_connect.typed_column import TypedColumn
|
|
18
|
+
from snowflake.snowpark_connect.utils.context import get_grouping_by_scala_udf_key
|
|
16
19
|
from snowflake.snowpark_connect.utils.external_udxf_cache import (
|
|
17
20
|
cache_external_udf,
|
|
18
21
|
get_external_udf_from_cache,
|
|
19
22
|
)
|
|
23
|
+
from snowflake.snowpark_connect.utils.java_stored_procedure import create_java_udf
|
|
20
24
|
from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
|
|
21
25
|
from snowflake.snowpark_connect.utils.udf_helper import (
|
|
22
26
|
SnowparkUDF,
|
|
23
|
-
gen_input_types,
|
|
24
27
|
infer_snowpark_arguments,
|
|
25
28
|
process_udf_in_sproc,
|
|
26
29
|
require_creating_udf_in_sproc,
|
|
@@ -53,8 +56,14 @@ def cache_external_udf_wrapper(from_register_udf: bool):
|
|
|
53
56
|
session._udfs[udf_proto.function_name.lower()] = cached_udf
|
|
54
57
|
case "python_udf":
|
|
55
58
|
pass
|
|
59
|
+
case "java_udf":
|
|
60
|
+
session._udfs[udf_proto.function_name.lower()] = cached_udf
|
|
56
61
|
case _:
|
|
57
|
-
|
|
62
|
+
exception = ValueError(f"Unsupported UDF type: {function_type}")
|
|
63
|
+
attach_custom_error_code(
|
|
64
|
+
exception, ErrorCodes.UNSUPPORTED_OPERATION
|
|
65
|
+
)
|
|
66
|
+
raise exception
|
|
58
67
|
|
|
59
68
|
return cached_udf
|
|
60
69
|
|
|
@@ -94,13 +103,43 @@ def register_udf(
|
|
|
94
103
|
match udf_proto.WhichOneof("function"):
|
|
95
104
|
case "python_udf":
|
|
96
105
|
output_type = udf_proto.python_udf.output_type
|
|
106
|
+
processed_return_type, original_return_type = process_udf_return_type(
|
|
107
|
+
output_type
|
|
108
|
+
)
|
|
97
109
|
case "scalar_scala_udf":
|
|
110
|
+
# For Scala UDFs, always use VariantType as the processed type since all Scala UDFs
|
|
111
|
+
# return Variant. The actual type conversion happens after the UDF call.
|
|
98
112
|
output_type = udf_proto.scalar_scala_udf.outputType
|
|
113
|
+
original_return_type = proto_to_snowpark_type(output_type)
|
|
114
|
+
processed_return_type = VariantType()
|
|
115
|
+
case "java_udf":
|
|
116
|
+
has_output_type = udf_proto.java_udf.HasField("output_type")
|
|
117
|
+
session = get_or_create_snowpark_session()
|
|
118
|
+
java_udf = create_java_udf(
|
|
119
|
+
session,
|
|
120
|
+
udf_proto.function_name,
|
|
121
|
+
udf_proto.java_udf.class_name,
|
|
122
|
+
)
|
|
123
|
+
original_return_type = java_udf._return_type
|
|
124
|
+
if has_output_type:
|
|
125
|
+
original_return_type = proto_to_snowpark_type(
|
|
126
|
+
udf_proto.java_udf.output_type
|
|
127
|
+
)
|
|
128
|
+
udf = SnowparkUDF(
|
|
129
|
+
name=java_udf.name,
|
|
130
|
+
input_types=java_udf._input_types,
|
|
131
|
+
return_type=java_udf._return_type,
|
|
132
|
+
original_return_type=original_return_type,
|
|
133
|
+
cast_to_original_return_type=True,
|
|
134
|
+
)
|
|
135
|
+
session._udfs[udf_proto.function_name.lower()] = udf
|
|
136
|
+
return udf
|
|
99
137
|
case _:
|
|
100
|
-
|
|
138
|
+
exception = ValueError(
|
|
101
139
|
f"Unsupported UDF type: {udf_proto.WhichOneof('function')}"
|
|
102
140
|
)
|
|
103
|
-
|
|
141
|
+
attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
|
|
142
|
+
raise exception
|
|
104
143
|
session = get_or_create_snowpark_session()
|
|
105
144
|
kwargs = {
|
|
106
145
|
"common_inline_user_defined_function": udf_proto,
|
|
@@ -116,11 +155,15 @@ def register_udf(
|
|
|
116
155
|
else:
|
|
117
156
|
udf_processor = ProcessCommonInlineUserDefinedFunction(**kwargs)
|
|
118
157
|
udf = udf_processor.create_udf()
|
|
158
|
+
is_scala_udf = udf_proto.WhichOneof("function") == "scalar_scala_udf"
|
|
159
|
+
|
|
119
160
|
udf = SnowparkUDF(
|
|
120
161
|
name=udf.name,
|
|
121
162
|
input_types=udf._input_types,
|
|
122
163
|
return_type=udf._return_type,
|
|
123
164
|
original_return_type=original_return_type,
|
|
165
|
+
cast_to_original_return_type=is_scala_udf
|
|
166
|
+
or udf._return_type == VariantType(),
|
|
124
167
|
)
|
|
125
168
|
session._udfs[udf_proto.function_name.lower()] = udf
|
|
126
169
|
# scala udfs can be also accessed using `udf.name`
|
|
@@ -136,19 +179,22 @@ def map_common_inline_user_defined_udf(
|
|
|
136
179
|
) -> tuple[str, TypedColumn]:
|
|
137
180
|
udf_proto = exp.common_inline_user_defined_function
|
|
138
181
|
udf_check(udf_proto)
|
|
139
|
-
snowpark_udf_arg_names,
|
|
182
|
+
snowpark_udf_arg_names, snowpark_udf_typed_args = infer_snowpark_arguments(
|
|
140
183
|
udf_proto, column_mapping, typer
|
|
141
184
|
)
|
|
142
|
-
input_types =
|
|
185
|
+
input_types = [a.typ for a in snowpark_udf_typed_args]
|
|
143
186
|
match udf_proto.WhichOneof("function"):
|
|
144
187
|
case "python_udf":
|
|
145
188
|
processed_return_type, original_return_type = process_udf_return_type(
|
|
146
189
|
udf_proto.python_udf.output_type
|
|
147
190
|
)
|
|
148
191
|
case "scalar_scala_udf":
|
|
149
|
-
|
|
192
|
+
# For Scala UDFs, always use VariantType as the processed type since all Scala UDFs
|
|
193
|
+
# return Variant. The actual type conversion happens after the UDF call.
|
|
194
|
+
original_return_type = proto_to_snowpark_type(
|
|
150
195
|
udf_proto.scalar_scala_udf.outputType
|
|
151
196
|
)
|
|
197
|
+
processed_return_type = VariantType()
|
|
152
198
|
|
|
153
199
|
@cache_external_udf_wrapper(from_register_udf=False)
|
|
154
200
|
def get_snowpark_udf(
|
|
@@ -178,24 +224,44 @@ def map_common_inline_user_defined_udf(
|
|
|
178
224
|
return snowpark_udf
|
|
179
225
|
|
|
180
226
|
snowpark_udf = get_snowpark_udf(udf_proto)
|
|
181
|
-
|
|
227
|
+
# Determine if we need to cast the result back to the original type
|
|
228
|
+
is_scala_udf = udf_proto.WhichOneof("function") == "scalar_scala_udf"
|
|
182
229
|
|
|
183
|
-
#
|
|
184
|
-
#
|
|
185
|
-
|
|
230
|
+
# For structured types (arrays, structs, maps), use to_variant instead of cast
|
|
231
|
+
# to ensure proper conversion to VARIANT type for Scala UDFS
|
|
232
|
+
converted_args = []
|
|
233
|
+
for tc in snowpark_udf_typed_args:
|
|
234
|
+
if is_scala_udf and isinstance(tc.typ, (ArrayType, StructType, MapType)):
|
|
235
|
+
converted_args.append(snowpark_fn.to_variant(tc.col))
|
|
236
|
+
else:
|
|
237
|
+
converted_args.append(tc.col)
|
|
238
|
+
|
|
239
|
+
udf_call_expr = snowpark_fn.call_udf(snowpark_udf.name, *converted_args)
|
|
240
|
+
|
|
241
|
+
# For Scala UDFs, always cast from Variant to the original type
|
|
242
|
+
# For Python UDFs, only cast if the original type was MapType or StructType
|
|
243
|
+
if is_scala_udf:
|
|
244
|
+
# All Scala UDFs return Variant, so we always need to cast back to the original type
|
|
245
|
+
result_expr = snowpark_fn.cast(udf_call_expr, original_return_type)
|
|
246
|
+
result_type = original_return_type
|
|
247
|
+
|
|
248
|
+
elif isinstance(original_return_type, (MapType, StructType)) and isinstance(
|
|
186
249
|
processed_return_type, VariantType
|
|
187
250
|
):
|
|
188
|
-
# Parse JSON and cast back to original type
|
|
251
|
+
# Parse JSON and cast back to original type for Python UDFs
|
|
189
252
|
result_expr = snowpark_fn.parse_json(udf_call_expr).cast(original_return_type)
|
|
190
253
|
result_type = original_return_type
|
|
191
254
|
else:
|
|
192
255
|
result_expr = udf_call_expr
|
|
193
256
|
result_type = snowpark_udf.return_type
|
|
194
257
|
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
258
|
+
name = f"{udf_proto.function_name}({', '.join(snowpark_udf_arg_names)})"
|
|
259
|
+
if get_grouping_by_scala_udf_key() and not isinstance(
|
|
260
|
+
original_return_type, StructType
|
|
261
|
+
):
|
|
262
|
+
name = (
|
|
263
|
+
"value"
|
|
264
|
+
if global_config.spark_sql_legacy_dataset_nameNonStructGroupingKeyAsValue
|
|
265
|
+
else "key"
|
|
266
|
+
)
|
|
267
|
+
return (name, TypedColumn(result_expr, lambda: [result_type]))
|