snowpark-connect 0.27.0__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +717 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +309 -26
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/error_utils.py +28 -0
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +224 -15
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +86 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2964 -829
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-connect-client-jvm_2.12-3.5.6.jar → spark-connect-client-jvm_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-sql_2.12-3.5.6.jar → spark-sql_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +239 -256
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +326 -47
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +21 -6
- snowflake/snowpark_connect/relation/read/map_read_json.py +324 -86
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +146 -28
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +15 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +86 -6
- snowflake/snowpark_connect/relation/read/map_read_text.py +22 -4
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +42 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +171 -48
- snowflake/snowpark_connect/server.py +528 -473
- snowflake/snowpark_connect/server_common/__init__.py +503 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/type_support.py +130 -0
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +195 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +151 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +321 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +281 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +185 -340
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +192 -40
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +14 -4
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +16 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +139 -168
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
@@ -14,14 +14,20 @@ from pyspark.errors.exceptions.base import AnalysisException
 import snowflake.snowpark.functions as snowpark_fn
 import snowflake.snowpark_connect.tcm as tcm
 import snowflake.snowpark_connect.utils.udf_utils as udf_utils
-from snowflake.snowpark import
+from snowflake.snowpark import Session
 from snowflake.snowpark.types import DataType, _parse_datatype_json_value
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.config import global_config
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
 from snowflake.snowpark_connect.expression.map_expression import (
     map_single_column_expression,
 )
+from snowflake.snowpark_connect.expression.map_unresolved_star import (
+    map_unresolved_star_as_single_column,
+)
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
+from snowflake.snowpark_connect.typed_column import TypedColumn
 from snowflake.snowpark_connect.utils.context import (
     get_is_aggregate_function,
     get_is_evaluating_join_condition,
@@ -38,6 +44,7 @@ class SnowparkUDF(NamedTuple):
     return_type: DataType
     input_types: list[DataType]
     original_return_type: DataType | None
+    cast_to_original_return_type: bool = False


 def require_creating_udf_in_sproc(

@@ -184,6 +191,7 @@ def parse_return_type(return_type_json_str) -> Optional[DataType]:


 def create(session, called_from, return_type_json_str, input_types_json_str, input_column_names_json_str, udf_name, replace, udf_packages, udf_imports, b64_str, original_return_type):
+    session._use_scoped_temp_objects = False
     import snowflake.snowpark.context as context
     context._use_structured_type_semantics = True
     context._is_snowpark_connect_compatible_mode = True
@@ -227,25 +235,15 @@ def _check_supported_udf(
         case "python_udf":
             pass
         case "java_udf":
-
-                get_or_create_snowpark_session,
-            )
-
-            session = get_or_create_snowpark_session()
-            if udf_proto.java_udf.class_name not in session._cached_java_udfs:
-                raise AnalysisException(
-                    f"Can not load class {udf_proto.java_udf.class_name}"
-                )
-            else:
-                raise ValueError(
-                    "Function type java_udf not supported for common inline user-defined function"
-                )
+            pass
         case "scalar_scala_udf":
             pass
         case _ as function_type:
-
+            exception = ValueError(
                 f"Function type {function_type} not supported for common inline user-defined function"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception


 def _aggregate_function_check(
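
This hunk and the two below it repeat one refactoring pattern: instead of raising a bare exception, the code now constructs it, tags it via `attach_custom_error_code(exception, ErrorCodes.<CODE>)`, and only then raises it. The body of `attach_custom_error_code` is not visible in this diff (it lives in `error/error_utils.py`, listed above with +172 -23), so the sketch below is a hedged guess at the mechanics, with a hypothetical attribute name:

```python
# Hedged sketch only: the real attach_custom_error_code is not shown in this
# diff; the attribute name "custom_error_code" is an assumption.
from enum import Enum


class ErrorCodes(Enum):  # stand-in for snowflake.snowpark_connect.error.error_codes
    UNSUPPORTED_OPERATION = "unsupported_operation"
    INVALID_OPERATION = "invalid_operation"


def attach_custom_error_code(exc: Exception, code: ErrorCodes) -> None:
    # Record the code on the exception instance so upstream handlers and
    # telemetry can read it without parsing the message text.
    exc.custom_error_code = code


try:
    exception = ValueError("Function type java_udf not supported")
    attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
    raise exception
except ValueError as caught:
    print(caught.custom_error_code)  # ErrorCodes.UNSUPPORTED_OPERATION
```

Carrying the code on the object rather than in the message presumably lets the server keep Spark-compatible error text while still reporting a stable, machine-readable code.
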
@@ -253,9 +251,11 @@ def _aggregate_function_check(
 ):
     name, is_aggregate_function = get_is_aggregate_function()
     if not udf_proto.deterministic and name != "default" and is_aggregate_function:
-
+        exception = AnalysisException(
             f"[AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION] Non-deterministic expression {name}({udf_proto.function_name}) should not appear in the arguments of an aggregate function."
         )
+        attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+        raise exception


 def _join_checks(snowpark_udf_arg_names: list[str]):
@@ -282,49 +282,51 @@ def _join_checks(snowpark_udf_arg_names: list[str]):
         and is_left_evaluable
         and is_right_evaluable
     ):
-
+        exception = AnalysisException(
             f"Detected implicit cartesian product for {is_evaluating_join_condition[0]} join between logical plans. \n"
             f"Join condition is missing or trivial. \n"
             f"Either: use the CROSS JOIN syntax to allow cartesian products between those relations, or; "
             f"enable implicit cartesian products by setting the configuration variable spark.sql.crossJoin.enabled=True."
         )
+        attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+        raise exception
     if (
         is_evaluating_join_condition[0] != "INNER"
         and is_evaluating_join_condition[1]
         and is_left_evaluable
         and is_right_evaluable
     ):
-
+        exception = AnalysisException(
             f"[UNSUPPORTED_FEATURE.PYTHON_UDF_IN_ON_CLAUSE] The feature is not supported: "
             f"Python UDF in the ON clause of a {is_evaluating_join_condition[0]} JOIN. "
             f"In case of an INNNER JOIN consider rewriting to a CROSS JOIN with a WHERE clause."
         )
+        attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+        raise exception


 def infer_snowpark_arguments(
     udf_proto: expressions_proto.CommonInlineUserDefinedFunction,
     column_mapping: ColumnNameMap,
     typer: ExpressionTyper,
-) -> tuple[list[str], list[
-    snowpark_udf_args: list[
+) -> tuple[list[str], list[TypedColumn]]:
+    snowpark_udf_args: list[TypedColumn] = []
     snowpark_udf_arg_names: list[str] = []
     for arg_exp in udf_proto.arguments:
-
-
-
-
-
-
-
+        # Handle unresolved_star expressions specially
+        if arg_exp.HasField("unresolved_star"):
+            # Use map_unresolved_star_as_struct to expand star into a single combined column
+            spark_name, typed_column = map_unresolved_star_as_single_column(
+                arg_exp, column_mapping, typer
+            )
+            snowpark_udf_args.append(typed_column)
+            snowpark_udf_arg_names.append(spark_name)
+        else:
+            (
+                snowpark_udf_arg_name,
+                snowpark_udf_arg,
+            ) = map_single_column_expression(arg_exp, column_mapping, typer)
+            snowpark_udf_args.append(snowpark_udf_arg)
+            snowpark_udf_arg_names.append(snowpark_udf_arg_name)
     _join_checks(snowpark_udf_arg_names)
     return snowpark_udf_arg_names, snowpark_udf_args
-
-
-def gen_input_types(
-    snowpark_udf_args: list[Column],
-    typer: ExpressionTyper,
-):
-    input_types = []
-    for udf_arg in snowpark_udf_args:
-        input_types.extend(typer.type(udf_arg))
-    return input_types
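
The rewritten `infer_snowpark_arguments` branches on which variant of the Expression protobuf's oneof is populated: a `*` argument (`unresolved_star`) is expanded into one combined column via `map_unresolved_star_as_single_column`, while every other argument still goes through `map_single_column_expression`. For readers unfamiliar with the proto API it relies on, here is a minimal `HasField` illustration, assuming the `spark.connect` protos that ship with pyspark:

```python
# Assumes pyspark with its bundled Spark Connect protos installed.
from pyspark.sql.connect.proto import expressions_pb2

expr = expressions_pb2.Expression()
expr.unresolved_star.SetInParent()  # populate the star variant of the oneof

print(expr.HasField("unresolved_star"))  # True
print(expr.WhichOneof("expr_type"))      # "unresolved_star"
```
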
@@ -103,7 +103,7 @@ class ProcessCommonInlineUserDefinedFunction:
                 )
             case _:
                 raise ValueError(
-                    f"Function type {self._function_type} not supported for common inline user-defined function"
+                    f"[snowpark_connect::unsupported_operation] Function type {self._function_type} not supported for common inline user-defined function"
                 )

     @property

@@ -112,7 +112,7 @@ class ProcessCommonInlineUserDefinedFunction:
             return self._snowpark_udf_args
         else:
             raise ValueError(
-                "Column mapping is not provided, cannot get snowpark udf args"
+                "[snowpark_connect::internal_error] Column mapping is not provided, cannot get snowpark udf args"
             )

     @property

@@ -121,7 +121,7 @@ class ProcessCommonInlineUserDefinedFunction:
             return self._snowpark_udf_arg_names
         else:
             raise ValueError(
-                "Column mapping is not provided, cannot get snowpark udf arg names"
+                "[snowpark_connect::internal_error] Column mapping is not provided, cannot get snowpark udf arg names"
             )

     def _create_python_udf(self):
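
These hunks, and many below, prepend a `[snowpark_connect::<error_code>]` tag to user-facing messages where no exception object is available to carry a structured code. The tag grammar is simple enough to reproduce; the helper below is purely illustrative (the diff writes the prefixes inline as string literals):

```python
# Illustrative only: the package inlines these prefixes as literals.
def tag(code: str, message: str) -> str:
    return f"[snowpark_connect::{code}] {message}"


print(tag("internal_error", "Column mapping is not provided, cannot get snowpark udf args"))
# [snowpark_connect::internal_error] Column mapping is not provided, cannot get snowpark udf args
```
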
@@ -148,7 +148,12 @@ class ProcessCommonInlineUserDefinedFunction:

         # Change directory to the one containing the UDF imported files
         import_path = sys._xoptions["snowflake_import_directory"]
-
+        if os.name == "nt":
+            import tempfile
+
+            tmp_path = os.path.join(tempfile.gettempdir(), f"sas-{os.getpid()}")
+        else:
+            tmp_path = f"/tmp/sas-{os.getpid()}"
         os.makedirs(tmp_path, exist_ok=True)
         os.chdir(tmp_path)
         shutil.copytree(import_path, tmp_path, dirs_exist_ok=True)
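
This hunk replaces a hard-coded scratch path with a Windows-aware branch: `os.name == "nt"` has no `/tmp`, so the code falls back to `tempfile.gettempdir()`. Extracted into a standalone function for illustration:

```python
import os
import tempfile


def per_process_scratch_dir(prefix: str = "sas") -> str:
    # Mirrors the branch above: Windows ("nt") has no /tmp, so use the
    # platform temp dir; elsewhere keep the existing /tmp location.
    base = tempfile.gettempdir() if os.name == "nt" else "/tmp"
    return os.path.join(base, f"{prefix}-{os.getpid()}")


path = per_process_scratch_dir()
os.makedirs(path, exist_ok=True)
print(path)  # e.g. /tmp/sas-12345, or C:\Users\...\Temp\sas-12345 on Windows
```

Using `tempfile.gettempdir()` unconditionally would behave the same on most POSIX systems; the diff presumably keeps the literal `/tmp` to avoid changing existing server-side behavior.
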
@@ -176,14 +181,6 @@ class ProcessCommonInlineUserDefinedFunction:
                 tar_ref.extractall(archive[: -len(".archive")])
             os.remove(archive)

-        def callable_func(*args, **kwargs):
-            import_staged_files()
-            return original_callable(*args, **kwargs)
-
-        callable_func.__signature__ = inspect.signature(original_callable)
-        if hasattr(original_callable, "__annotations__"):
-            callable_func.__annotations__ = original_callable.__annotations__
-
         if self._udf_packages:
             packages = [p.strip() for p in self._udf_packages.strip("[]").split(",")]
         else:

@@ -193,13 +190,109 @@ class ProcessCommonInlineUserDefinedFunction:
         else:
             imports = []

+        def callable_func(*args, **kwargs):
+            if imports:
+                import_staged_files()
+            return original_callable(*args, **kwargs)
+
+        callable_func.__signature__ = inspect.signature(original_callable)
+        if hasattr(original_callable, "__annotations__"):
+            callable_func.__annotations__ = original_callable.__annotations__
+
         update_none_input_types()

+        struct_positions = [
+            i
+            for i, t in enumerate(self._input_types or [])
+            if isinstance(t, StructType)
+        ]
+
+        if struct_positions:
+
+            class StructRowProxy:
+                """Row-like object supporting positional and named access for PySpark compatibility."""
+
+                def __init__(self, fields, values) -> None:
+                    self._fields = fields
+                    self._values = values
+                    self._field_to_index = {field: i for i, field in enumerate(fields)}
+
+                def __getitem__(self, key):
+                    if isinstance(key, int):
+                        return self._values[key]
+                    elif isinstance(key, str):
+                        if key in self._field_to_index:
+                            return self._values[self._field_to_index[key]]
+                        raise KeyError(f"Field '{key}' not found in struct")
+                    else:
+                        raise TypeError(f"Invalid key type: {type(key)}")
+
+                def __getattr__(self, name):
+                    if name.startswith("_"):
+                        raise AttributeError(f"Attribute '{name}' not found")
+                    if name in self._field_to_index:
+                        return self._values[self._field_to_index[name]]
+                    raise AttributeError(f"Attribute '{name}' not found")
+
+                def __len__(self):
+                    return len(self._values)
+
+                def __iter__(self):
+                    return iter(self._values)
+
+                def __repr__(self):
+                    field_values = [
+                        f"{field}={repr(value)}"
+                        for field, value in zip(self._fields, self._values)
+                    ]
+                    return f"Row({', '.join(field_values)})"
+
+                def asDict(self):
+                    """Convert to dict (like PySpark Row.asDict())."""
+                    return dict(zip(self._fields, self._values))
+
+            def convert_to_row(arg):
+                """Convert dict to StructRowProxy. Only called for struct positions."""
+                if isinstance(arg, dict) and arg:
+                    fields = list(arg.keys())
+                    values = [arg[k] for k in fields]
+                    return StructRowProxy(fields, values)
+                return arg
+
+            def convert_from_row(result):
+                """Convert StructRowProxy back to dict for serialization."""
+                if isinstance(result, StructRowProxy):
+                    return result.asDict()
+                return result
+
+            def struct_input_wrapper(*args, **kwargs):
+                if struct_positions:
+                    processed_args = []
+                    for i, arg in enumerate(args):
+                        if i in struct_positions:
+                            processed_args.append(convert_to_row(arg))
+                        else:
+                            processed_args.append(arg)
+
+                    processed_kwargs = {k: convert_to_row(v) for k, v in kwargs.items()}
+                    result = callable_func(*tuple(processed_args), **processed_kwargs)
+                    # Convert any StructRowProxy in return value back to dict for serialization
+                    return convert_from_row(result)
+                return callable_func(*args, **kwargs)
+
         needs_struct_conversion = isinstance(self._original_return_type, StructType)

+        # Use callable_func directly when there are no struct inputs to avoid closure issues.
+        # struct_input_wrapper captures convert_to_row in its closure, but convert_to_row is only
+        # defined when struct_positions is truthy. Cloudpickle serializes all closure variables,
+        # so using struct_input_wrapper without struct positions would fail during serialization.
+        updated_callable_func = (
+            struct_input_wrapper if struct_positions else callable_func
+        )
+
         if not needs_struct_conversion:
             return snowpark_fn.udf(
-                create_null_safe_wrapper(
+                create_null_safe_wrapper(updated_callable_func),
                 return_type=self._return_type,
                 input_types=self._input_types,
                 name=self._udf_name,
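
`StructRowProxy` exists because Snowflake hands struct-typed UDF arguments to Python as plain dicts, while PySpark UDFs receive `Row` objects that allow positional, keyed, and attribute access. A condensed standalone copy of the class above (trimmed to the access paths) shows the intended behavior:

```python
class StructRowProxy:
    """Trimmed copy of the proxy defined in the hunk above."""

    def __init__(self, fields, values):
        self._fields = list(fields)
        self._values = list(values)
        self._field_to_index = {f: i for i, f in enumerate(self._fields)}

    def __getitem__(self, key):
        if isinstance(key, int):
            return self._values[key]
        return self._values[self._field_to_index[key]]

    def __getattr__(self, name):
        if name.startswith("_") or name not in self._field_to_index:
            raise AttributeError(name)
        return self._values[self._field_to_index[name]]

    def asDict(self):
        return dict(zip(self._fields, self._values))


row = StructRowProxy(["id", "name"], [7, "ada"])
print(row[0], row["name"], row.name)  # 7 ada ada  (positional, keyed, attribute access)
print(row.asDict())                   # {'id': 7, 'name': 'ada'}
```

The `updated_callable_func` selection at the end of the hunk is the subtle part: cloudpickle serializes a function together with its closure, so `struct_input_wrapper`, which closes over `convert_to_row`, is only safe to ship when `struct_positions` is non-empty and those names actually exist.
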
@@ -225,7 +318,21 @@ class ProcessCommonInlineUserDefinedFunction:
         field_names = [field.name for field in self._original_return_type.fields]

         def struct_wrapper(*args):
+            if struct_positions:
+                processed_args = []
+                for i, arg in enumerate(args):
+                    if i in struct_positions:
+                        processed_args.append(convert_to_row(arg))
+                    else:
+                        processed_args.append(arg)
+                args = tuple(processed_args)
+
             result = callable_func(*args)
+
+            # Convert StructRowProxy back to dict for serialization
+            if struct_positions:
+                result = convert_from_row(result)
+
             if isinstance(result, (tuple, list)):
                 # Convert tuple/list to dict using struct field names
                 if len(result) == len(field_names):
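
`struct_wrapper` now applies the same dict-to-proxy conversion on the way in and proxy-to-dict on the way out; the context below the hunk then maps tuple/list returns onto the declared struct field names. That last conversion reduces to a `zip`:

```python
field_names = ["id", "name"]
result = (7, "ada")

# Same shape check the wrapper performs before converting.
if isinstance(result, (tuple, list)) and len(result) == len(field_names):
    result = dict(zip(field_names, result))

print(result)  # {'id': 7, 'name': 'ada'}
```
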
@@ -283,6 +390,18 @@ class ProcessCommonInlineUserDefinedFunction:
             case "python_udf":
                 return self._create_python_udf()
             case "scalar_scala_udf":
+                from snowflake.snowpark_connect.utils.context import (
+                    get_is_aggregate_function,
+                )
+
+                name, is_aggregate_function = get_is_aggregate_function()
+                if is_aggregate_function and name.lower() == "reduce":
+                    # Handling of Scala Reduce function requires usage of Java UDAF
+                    from snowflake.snowpark_connect.utils.java_udaf_utils import (
+                        create_java_udaf_for_reduce_scala_function,
+                    )
+
+                    return create_java_udaf_for_reduce_scala_function(self)
                 from snowflake.snowpark_connect.utils.scala_udf_utils import (
                     create_scala_udf,
                 )

@@ -290,5 +409,5 @@ class ProcessCommonInlineUserDefinedFunction:
                 return create_scala_udf(self)
             case _:
                 raise ValueError(
-                    f"Function type {self._function_type} not supported for common inline user-defined function"
+                    f"[snowpark_connect::unsupported_operation] Function type {self._function_type} not supported for common inline user-defined function"
                 )

@@ -16,6 +16,8 @@ import snowflake.snowpark_connect.tcm as tcm
 from snowflake import snowpark
 from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
 from snowflake.snowpark.types import DataType, StructType, _parse_datatype_json_value
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
 from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
 from snowflake.snowpark_connect.utils import pandas_udtf_utils, udtf_utils
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session

@@ -37,7 +39,9 @@ def udtf_check(
     udtf_proto: relation_proto.CommonInlineUserDefinedTableFunction,
 ) -> None:
     if udtf_proto.WhichOneof("function") != "python_udtf":
-
+        exception = ValueError(f"Not python udtf {udtf_proto.function}")
+        attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+        raise exception


 def require_creating_udtf_in_sproc(

@@ -149,6 +153,7 @@ def parse_types(types_json_str) -> Optional[list[DataType]]:
         return json.loads(types_json_str)

 def create(session, b64_str, expected_types_json_str, output_schema_json_str, packages, imports, is_arrow_enabled, is_spark_compatible_udtf_mode_enabled, called_from):
+    session._use_scoped_temp_objects = False
     import snowflake.snowpark.context as context
     context._use_structured_type_semantics = True
     context._is_snowpark_connect_compatible_mode = True

@@ -253,6 +258,7 @@ from snowflake.snowpark.types import _parse_datatype_json_value
 {inline_udtf_utils_py_code}

 def create(session, b64_str, spark_column_names_json_str, input_schema_json_str, return_schema_json_str):
+    session._use_scoped_temp_objects = False
     import snowflake.snowpark.context as context
     context._use_structured_type_semantics = True
     context._is_snowpark_connect_compatible_mode = True

@@ -326,6 +332,7 @@ from snowflake.snowpark.types import _parse_datatype_json_value
 from pyspark.serializers import CloudPickleSerializer

 def create(session, func_info_json):
+    session._use_scoped_temp_objects = False
     import snowflake.snowpark.context as context
     context._use_structured_type_semantics = True
     context._is_snowpark_connect_compatible_mode = True

@@ -32,14 +32,16 @@ def create_udtf(
     udtf = udtf_proto.python_udtf
     callable_func = CloudPickleSerializer().loads(udtf.command)

-
+    original_func = callable_func.eval
+    func_signature = inspect.signature(original_func)
     # Set all input types to VariantType regardless of type hints so that we can pass all arguments as VariantType.
     # Otherwise, we will run into issues with type mismatches. This only applies for UDTF registration.
     # We subtract one here since UDTF functions are class methods and always have "self" as the first parameter.
     input_types = [VariantType()] * (len(func_signature.parameters) - 1)

-
-
+    if imports:
+        # Wrapp callable to allow reading imported files
+        callable_func = artifacts_reader_wrapper(callable_func)

     if is_arrow_enabled:
         callable_func = spark_compatible_udtf_wrapper_with_arrow(
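
The UDTF hunk names the deserialized handler's `eval` and measures its arity with `inspect.signature`. Because `eval` is looked up on the class rather than a bound instance, `self` shows up as a parameter and is subtracted before building the `VariantType` input list. A toy UDTF class (hypothetical, for illustration) makes the count concrete:

```python
import inspect


class MyUDTF:  # hypothetical handler, not from the package
    def eval(self, a, b, c):
        yield (a + b + c,)


func_signature = inspect.signature(MyUDTF.eval)
print(list(func_signature.parameters))     # ['self', 'a', 'b', 'c']
print(len(func_signature.parameters) - 1)  # 3 -> three VariantType inputs
```
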
@@ -48,7 +50,7 @@ def create_udtf(
     elif is_spark_compatible_udtf_mode_enabled:
         callable_func = spark_compatible_udtf_wrapper(callable_func, expected_types)
     else:
-        callable_func.process =
+        callable_func.process = original_func
     if hasattr(callable_func, "terminate"):
         callable_func.end_partition = callable_func.terminate

@@ -107,7 +109,9 @@ def create_udtf(
             imports=imports,
         )
         case _:
-            raise NotImplementedError(
+            raise NotImplementedError(
+                f"[snowpark_connect::unsupported_operation] {called_from}"
+            )


 def artifacts_reader_wrapper(user_udtf_cls: type) -> type:

@@ -127,7 +131,12 @@ def artifacts_reader_wrapper(user_udtf_cls: type) -> type:

         # Change directory to the one containing the UDF imported files
         import_path = sys._xoptions["snowflake_import_directory"]
-
+        if os.name == "nt":
+            import tempfile
+
+            tmp_path = os.path.join(tempfile.gettempdir(), f"sas-{os.getpid()}")
+        else:
+            tmp_path = f"/tmp/sas-{os.getpid()}"
         os.makedirs(tmp_path, exist_ok=True)
         os.chdir(tmp_path)
         shutil.copytree(import_path, tmp_path, dirs_exist_ok=True)

@@ -195,17 +204,19 @@ def _create_convert_table_argument_to_row():
                 # Named access: row["col1"], row["col2"]
                 if key in self._field_to_index:
                     return self._values[self._field_to_index[key]]
-                raise KeyError(key)
+                raise KeyError(f"[snowpark_connect::invalid_operation] {key}")
             else:
-                raise TypeError(
+                raise TypeError(
+                    f"[snowpark_connect::type_mismatch] Invalid key type: {type(key)}"
+                )

         def __getattr__(self, name):
            # Attribute access: row.col1, row.col2
            if name.startswith("_"):
-                raise AttributeError(name)
+                raise AttributeError(f"[snowpark_connect::invalid_operation] {name}")
            if name in self._field_to_index:
                return self._values[self._field_to_index[name]]
-            raise AttributeError(name)
+            raise AttributeError(f"[snowpark_connect::invalid_operation] {name}")

         def __len__(self):
             return len(self._values)

@@ -279,7 +290,9 @@ def spark_compatible_udtf_wrapper(
             return val
         if isinstance(val, datetime.datetime):
             return val.date()
-        raise AttributeError(
+        raise AttributeError(
+            f"[snowpark_connect::invalid_input] Invalid date value {val}"
+        )

     def _coerce_to_binary(val: object, target_type_name: str = "byte") -> bytes | None:
         if target_type_name == "binary":
@@ -343,7 +356,9 @@ def spark_compatible_udtf_wrapper(
     def _coerce_to_timestamp(val: object) -> datetime.datetime | None:
         if isinstance(val, datetime.datetime):
             return val
-        raise AttributeError(
+        raise AttributeError(
+            f"[snowpark_connect::invalid_input] Invalid time stamp value {val}"
+        )

     SCALAR_COERCERS = {
         "bool": _coerce_to_bool,

@@ -447,7 +462,7 @@ def spark_compatible_udtf_wrapper(

         if not isinstance(raw_row_tuple, (tuple, list)):
             raise TypeError(
-                f"[UDTF_INVALID_OUTPUT_ROW_TYPE] return value should be an iterable object containing tuples, but got {type(raw_row_tuple)}"
+                f"[snowpark_connect::type_mismatch] [UDTF_INVALID_OUTPUT_ROW_TYPE] return value should be an iterable object containing tuples, but got {type(raw_row_tuple)}"
             )

         if len(raw_row_tuple) != len(expected_types):

@@ -467,7 +482,7 @@ def spark_compatible_udtf_wrapper(
                 and val is not None
             ):
                 raise RuntimeError(
-                    f"[UNEXPECTED_TUPLE_WITH_STRUCT] Expected a struct for column at position {i}, but got a primitive value of type {type(val)}"
+                    f"[snowpark_connect::type_mismatch] [UNEXPECTED_TUPLE_WITH_STRUCT] Expected a struct for column at position {i}, but got a primitive value of type {type(val)}"
                 )

         coerced_row_list = [None] * len(expected_types)

@@ -533,7 +548,7 @@ def spark_compatible_udtf_wrapper_with_arrow(
                 return pa.map_(key_type, value_type)
             case _, _:
                 raise TypeError(
-                    f"[UDTF_ARROW_TYPE_CAST_ERROR] Unsupported Python scalar type for Arrow conversion: {target_py_type}"
+                    f"[snowpark_connect::unsupported_type] [UDTF_ARROW_TYPE_CAST_ERROR] Unsupported Python scalar type for Arrow conversion: {target_py_type}"
                 )
     elif kind == "array":
         element_type_info = type_marker

@@ -543,7 +558,7 @@ def spark_compatible_udtf_wrapper_with_arrow(
         struct_fields_info = type_marker
         if not isinstance(struct_fields_info, dict):
             raise TypeError(
-                f"[UDTF_ARROW_TYPE_CAST_ERROR] Invalid struct definition for Arrow: expected dict, got {type(struct_fields_info)}"
+                f"[snowpark_connect::type_mismatch] [UDTF_ARROW_TYPE_CAST_ERROR] Invalid struct definition for Arrow: expected dict, got {type(struct_fields_info)}"
             )
         fields = []
         for field_name, field_type_info in struct_fields_info.items():

@@ -552,7 +567,7 @@ def spark_compatible_udtf_wrapper_with_arrow(
         return pa.struct(fields)
     else:
         raise TypeError(
-            f"[UDTF_ARROW_TYPE_CAST_ERROR] Unsupported data kind for Arrow conversion: {kind}"
+            f"[snowpark_connect::unsupported_type] [UDTF_ARROW_TYPE_CAST_ERROR] Unsupported data kind for Arrow conversion: {kind}"
         )

     def _convert_to_arrow_value(

@@ -576,7 +591,7 @@ def spark_compatible_udtf_wrapper_with_arrow(
         ]
         if not isinstance(obj, (list, tuple)):
             raise TypeError(
-                f"[UDTF_ARROW_TYPE_CAST_ERROR] Expected list or tuple for Arrow array type, got {type(obj).__name__}"
+                f"[snowpark_connect::type_mismatch] [UDTF_ARROW_TYPE_CAST_ERROR] Expected list or tuple for Arrow array type, got {type(obj).__name__}"
             )
         element_type = arrow_type.value_type
         return [_convert_to_arrow_value(e, element_type, "array") for e in obj]

@@ -584,7 +599,7 @@ def spark_compatible_udtf_wrapper_with_arrow(
     if pa.types.is_map(arrow_type):
         if not isinstance(obj, dict):
             raise TypeError(
-                f"[UDTF_ARROW_TYPE_CAST_ERROR] Expected dict for Arrow map type, got {type(obj).__name__}"
+                f"[snowpark_connect::type_mismatch] [UDTF_ARROW_TYPE_CAST_ERROR] Expected dict for Arrow map type, got {type(obj).__name__}"
            )
         key_type = arrow_type.key_type
         value_type = arrow_type.item_type

@@ -610,7 +625,7 @@ def spark_compatible_udtf_wrapper_with_arrow(
     else:
         # If the UDTF yields a list/tuple (or anything not a dict) for a struct column, it's an error.
         raise TypeError(
-            f"[UDTF_ARROW_TYPE_CAST_ERROR] Expected a dictionary for Arrow struct type column, but got {type(obj).__name__}"
+            f"[snowpark_connect::type_mismatch] [UDTF_ARROW_TYPE_CAST_ERROR] Expected a dictionary for Arrow struct type column, but got {type(obj).__name__}"
         )

     # Check if a scalar type is expected and if obj is a collection; if so, error out.

@@ -622,7 +637,7 @@ def spark_compatible_udtf_wrapper_with_arrow(
     ):
         if isinstance(obj, (list, tuple, dict)):
             raise TypeError(
-                f"[UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert Python collection type {type(obj).__name__} to scalar Arrow type {arrow_type}"
+                f"[snowpark_connect::type_mismatch] [UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert Python collection type {type(obj).__name__} to scalar Arrow type {arrow_type}"
             )

     if pa.types.is_boolean(arrow_type):

@@ -638,7 +653,7 @@ def spark_compatible_udtf_wrapper_with_arrow(
             elif obj == 1:
                 return True
             raise TypeError(
-                f"[UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {obj} to Arrow boolean"
+                f"[snowpark_connect::type_mismatch] [UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {obj} to Arrow boolean"
             )
         if isinstance(obj, str):
             v_str = obj.strip().lower()

@@ -647,7 +662,7 @@ def spark_compatible_udtf_wrapper_with_arrow(
             if v_str == "false":
                 return False
         raise TypeError(
-            f"[UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow boolean"
+            f"[snowpark_connect::type_mismatch] [UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow boolean"
         )

     if pa.types.is_integer(arrow_type):

@@ -663,7 +678,7 @@ def spark_compatible_udtf_wrapper_with_arrow(
         except ValueError:
             pass
         raise TypeError(
-            f"[UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow integer"
+            f"[snowpark_connect::type_mismatch] [UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow integer"
         )

     if pa.types.is_floating(arrow_type):

@@ -675,7 +690,7 @@ def spark_compatible_udtf_wrapper_with_arrow(
         except ValueError:
             pass
         raise TypeError(
-            f"[UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow float"
+            f"[snowpark_connect::type_mismatch] [UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow float"
         )

     if pa.types.is_string(arrow_type):

@@ -687,7 +702,7 @@ def spark_compatible_udtf_wrapper_with_arrow(
         if isinstance(obj, str):
             return obj
         raise TypeError(
-            f"[UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow string"
+            f"[snowpark_connect::type_mismatch] [UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow string"
         )

     if pa.types.is_binary(arrow_type) or pa.types.is_fixed_size_binary(arrow_type):

@@ -698,21 +713,21 @@ def spark_compatible_udtf_wrapper_with_arrow(
         if isinstance(obj, int):
             return bytearray([obj])
         raise TypeError(
-            f"[UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow binary"
+            f"[snowpark_connect::type_mismatch] [UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow binary"
         )

     if pa.types.is_date(arrow_type):
         if isinstance(obj, datetime.date):
             return obj
         raise TypeError(
-            f"[UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow date. Expected datetime.date."
+            f"[snowpark_connect::type_mismatch] [UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow date. Expected datetime.date."
         )

     if pa.types.is_timestamp(arrow_type):
         if isinstance(obj, datetime.datetime):
             return obj
         raise TypeError(
-            f"[UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow timestamp. Expected datetime.datetime."
+            f"[snowpark_connect::type_mismatch] [UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow timestamp. Expected datetime.datetime."
         )

     if pa.types.is_decimal(arrow_type):

@@ -727,11 +742,11 @@ def spark_compatible_udtf_wrapper_with_arrow(
             pass

         raise TypeError(
-            f"[UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow decimal. Expected decimal.Decimal or compatible int/str."
+            f"[snowpark_connect::type_mismatch] [UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert {type(obj).__name__} to Arrow decimal. Expected decimal.Decimal or compatible int/str."
         )

     raise TypeError(
-        f"[UDTF_ARROW_TYPE_CAST_ERROR] Unsupported type conversion for {type(obj).__name__} to Arrow type {arrow_type}"
+        f"[snowpark_connect::unsupported_operation] [UDTF_ARROW_TYPE_CAST_ERROR] Unsupported type conversion for {type(obj).__name__} to Arrow type {arrow_type}"
     )

     class WrappedUDTF:
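
Every `[UDTF_ARROW_TYPE_CAST_ERROR]` hunk above sits inside `_convert_to_arrow_value`, which dispatches on the target Arrow type through the `pyarrow.types` predicates before coercing a Python value. The predicates are the real pyarrow API; the function below is a simplified stand-in for that dispatch style, not the package's implementation:

```python
import datetime

import pyarrow as pa


def coerce_scalar(obj, arrow_type):
    # Simplified stand-in illustrating predicate-based dispatch.
    if pa.types.is_boolean(arrow_type):
        if isinstance(obj, bool):
            return obj
        if obj in (0, 1):
            return bool(obj)
        raise TypeError(f"[snowpark_connect::type_mismatch] Cannot convert {obj!r} to Arrow boolean")
    if pa.types.is_integer(arrow_type):
        return int(obj)
    if pa.types.is_date(arrow_type):
        if isinstance(obj, datetime.date):
            return obj
        raise TypeError("[snowpark_connect::type_mismatch] Expected datetime.date")
    raise TypeError(f"[snowpark_connect::unsupported_type] Unsupported Arrow type {arrow_type}")


print(coerce_scalar(1, pa.bool_()))                           # True
print(coerce_scalar("42", pa.int64()))                        # 42
print(coerce_scalar(datetime.date(2024, 1, 1), pa.date32()))  # 2024-01-01
```
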