snowpark-connect 0.27.0__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +717 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +309 -26
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/error_utils.py +28 -0
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +224 -15
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +86 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2964 -829
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-connect-client-jvm_2.12-3.5.6.jar → spark-connect-client-jvm_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-sql_2.12-3.5.6.jar → spark-sql_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +239 -256
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +326 -47
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +21 -6
- snowflake/snowpark_connect/relation/read/map_read_json.py +324 -86
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +146 -28
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +15 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +86 -6
- snowflake/snowpark_connect/relation/read/map_read_text.py +22 -4
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +42 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +171 -48
- snowflake/snowpark_connect/server.py +528 -473
- snowflake/snowpark_connect/server_common/__init__.py +503 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/type_support.py +130 -0
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +195 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +151 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +321 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +281 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +185 -340
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +192 -40
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +14 -4
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +16 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +139 -168
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
```diff
--- a/snowflake/snowpark_connect/utils/scala_udf_utils.py
+++ b/snowflake/snowpark_connect/utils/scala_udf_utils.py
@@ -15,14 +15,23 @@ Key components:
 - Type mapping functions for different type systems
 - UDF creation and management utilities
 """
-import re
 from dataclasses import dataclass
-from enum import Enum
 from typing import List, Union
 
 import snowflake.snowpark.types as snowpark_type
 import snowflake.snowpark_connect.includes.python.pyspark.sql.connect.proto.types_pb2 as types_proto
-from snowflake.snowpark_connect.
+from snowflake.snowpark_connect.config import get_scala_version
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
+from snowflake.snowpark_connect.type_mapping import map_type_to_snowflake_type
+from snowflake.snowpark_connect.utils.jvm_udf_utils import (
+    NullHandling,
+    Param,
+    ReturnType,
+    Signature,
+    build_jvm_udxf_imports,
+)
+from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.udf_utils import (
     ProcessCommonInlineUserDefinedFunction,
@@ -59,58 +68,6 @@ class ScalaUdf:
         self._return_type = return_type
 
 
-@dataclass(frozen=True)
-class Param:
-    """
-    Represents a function parameter with name and data type.
-
-    Attributes:
-        name: Parameter name
-        data_type: Parameter data type as a string
-    """
-
-    name: str
-    data_type: str
-
-
-@dataclass(frozen=True)
-class NullHandling(str, Enum):
-    """
-    Enumeration for UDF null handling behavior.
-
-    Determines how the UDF behaves when input parameters contain null values.
-    """
-
-    RETURNS_NULL_ON_NULL_INPUT = "RETURNS NULL ON NULL INPUT"
-    CALLED_ON_NULL_INPUT = "CALLED ON NULL INPUT"
-
-
-@dataclass(frozen=True)
-class ReturnType:
-    """
-    Represents the return type of a function.
-
-    Attributes:
-        data_type: Return data type as a string
-    """
-
-    data_type: str
-
-
-@dataclass(frozen=True)
-class Signature:
-    """
-    Represents a function signature with parameters and return type.
-
-    Attributes:
-        params: List of function parameters
-        returns: Function return type
-    """
-
-    params: List[Param]
-    returns: ReturnType
-
-
 @dataclass(frozen=True)
 class ScalaUDFDef:
     """
@@ -147,76 +104,58 @@ class ScalaUDFDef:
         String containing the complete Scala code for the UDF body
         """
         # Convert Array to Seq for Scala compatibility in function signatures.
-
-
-
-
-
-
+        # Replace each "Variant" type with "Any" in the function signature since fromVariant returns Any
+        udf_func_input_types = ", ".join(
+            "Any"
+            if p.data_type == "Variant"
+            else p.data_type  # .replace("Array", "Seq")
+            for p in self.scala_signature.params
         )
-
-
-        # when the original UDF function is invoked.
-        wrapper_arg_and_input_types_str = re.sub(
-            pattern=r"Map\[\w+,\s\w+\]",
-            repl="Map[String, String]",
-            string=joined_wrapper_arg_and_input_types_str,
-        )
-        invocation_args = ", ".join(self.scala_invocation_args)
-
-        # Cannot directly return a map from a Scala UDF due to issues with non-String values. Snowflake SQL Scala only
-        # supports Map[String, String] as input types. Therefore, we convert the map to a JSON string before returning.
-        # This is processed as a Variant by SQL.
-        udf_func_return_type = self.scala_signature.returns.data_type
-        is_map_return = udf_func_return_type.startswith("Map")
-        wrapper_return_type = "String" if is_map_return else udf_func_return_type
-
-        # Need to call the map to JSON string converter when a map is returned by the user's function.
-        invoke_udf_func = (
-            f"write(func({invocation_args}))"
-            if is_map_return
-            else f"func({invocation_args})"
+        udf_func_return_type = self.scala_signature.returns.data_type.replace(
+            "Array", "Seq"
         )
 
-        #
-
-
-
-
-
-            import org.json4s._
-            import org.json4s.native.Serialization._
-            import org.json4s.native.Serialization
-        """
-        )
-        map_return_formatter = (
-            ""
-            if not is_map_return
-            else """
-            implicit val formats = Serialization.formats(NoTypeHints)
-        """
+        # Create the Scala arguments and input types string: "arg0: Type0, arg1: Type1, ...".
+        joined_wrapper_arg_and_input_types_str = ", ".join(
+            f"{scala_type.name}: { scala_type.data_type if snowflake_type.data_type != 'VARIANT' else 'Variant'}"
+            for (scala_type, snowflake_type) in zip(
+                self.scala_signature.params, self.signature.params
+            )
         )
 
-        return
-
-
-
+        # All Scala UDFs return Variant to ensure consistency and avoid type conversion issues.
+        wrapper_return_type = "Variant"
+        wrapped_args = [
+            f"UdfPacketUtils.fromVariant{f'[{scala_param.data_type}]' if scala_param.data_type != 'Variant' else '' }({arg if scala_param.data_type != 'Variant' else f'udfPacket, {arg}, {i}'})"
+            if param.data_type == "VARIANT"
+            else arg
+            for i, ((arg, param), scala_param) in enumerate(
+                zip(
+                    zip(self.scala_invocation_args, self.signature.params),
+                    self.scala_signature.params,
+                )
+            )
+        ]
+        invocation_args = ", ".join(wrapped_args)
+        invoke_udf_func = f"func({invocation_args})"
+
+        # Always wrap the result in Utils.toVariant() to ensure all Scala UDFs return Variant
+        invoke_udf_func = f"Utils.toVariant({invoke_udf_func}, udfPacket)"
+
+        return f"""
+import org.apache.spark.sql.connect.common.UdfPacket
+import com.snowflake.sas.scala.UdfPacketUtils._
+import com.snowflake.sas.scala.UdfPacketUtils
+import com.snowflake.sas.scala.Utils
+import com.snowflake.snowpark_java.types.Variant
 
 object __RecreatedSparkUdf {{
-
-
-
-
-    val bytes = Files.readAllBytes(Paths.get(fPath))
-    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
-    try {{
-        ois.readObject().asInstanceOf[UdfPacket].function.asInstanceOf[({udf_func_input_types}) => {udf_func_return_type}]
-    }} finally {{
-        ois.close()
-    }}
-}}
+    import com.snowflake.sas.scala.FromVariantConverter._
+
+    private lazy val udfPacket: UdfPacket = Utils.deserializeUdfPacket("{self.name}.bin")
+    private lazy val func: ({udf_func_input_types}) => {udf_func_return_type} = udfPacket.function.asInstanceOf[({udf_func_input_types}) => {udf_func_return_type}]
 
-    def __wrapperFunc({
+    def __wrapperFunc({joined_wrapper_arg_and_input_types_str}): {wrapper_return_type} = {{
         {invoke_udf_func}
     }}
 }}
```
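To make the new codegen path concrete, here is a minimal sketch (not taken from the package) of the Scala wrapper text it emits for a hypothetical `(Int, String) => Int` UDF. The name `SC_UDF_ab12cd34ef` and the `arg0`/`arg1` parameters are illustrative placeholders; the template shape follows the f-string in the hunk above.

```python
# Render the Variant-based Scala wrapper for a hypothetical UDF.
udf_name = "SC_UDF_ab12cd34ef"          # hypothetical generated name
input_types = "Int, String"             # udf_func_input_types
signature = "arg0: Int, arg1: String"   # joined_wrapper_arg_and_input_types_str
invoke = "Utils.toVariant(func(arg0, arg1), udfPacket)"

wrapper = f"""
import org.apache.spark.sql.connect.common.UdfPacket
import com.snowflake.sas.scala.UdfPacketUtils._
import com.snowflake.sas.scala.UdfPacketUtils
import com.snowflake.sas.scala.Utils
import com.snowflake.snowpark_java.types.Variant

object __RecreatedSparkUdf {{
    import com.snowflake.sas.scala.FromVariantConverter._

    private lazy val udfPacket: UdfPacket = Utils.deserializeUdfPacket("{udf_name}.bin")
    private lazy val func: ({input_types}) => Int =
        udfPacket.function.asInstanceOf[({input_types}) => Int]

    def __wrapperFunc({signature}): Variant = {{
        {invoke}
    }}
}}
"""
print(wrapper)
```

Compared with the removed version, the wrapper no longer deserializes the closure inline or special-cases Map returns via json4s; everything funnels through `Utils.toVariant`.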
```diff
@@ -245,13 +184,15 @@ object __RecreatedSparkUdf {{
         # Handler and imports
         imports_sql = f"IMPORTS = ({', '.join(quote_single(x) for x in self.imports)})"
 
+        scala_version = get_scala_version()
+
         return f"""
 CREATE OR REPLACE TEMPORARY FUNCTION {self.name}({args})
 RETURNS {ret_type}
 LANGUAGE SCALA
 {self.null_handling.value}
-RUNTIME_VERSION =
-PACKAGES = ('com.snowflake:
+RUNTIME_VERSION = {scala_version}
+PACKAGES = ('com.snowflake:snowpark_{scala_version}:latest')
 {imports_sql}
 HANDLER = '__RecreatedSparkUdf.__wrapperFunc'
 AS
```
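Rendered with `get_scala_version()` returning `"2.13"`, the template above produces DDL along these lines. This is a sketch: the function name, argument list, and stage import paths are hypothetical placeholders, not values from the package.

```python
# Sketch of the DDL the template above renders for a hypothetical UDF.
scala_version = "2.13"
ddl = f"""
CREATE OR REPLACE TEMPORARY FUNCTION SC_UDF_ab12cd34ef(arg0 INT, arg1 VARCHAR)
RETURNS VARIANT
LANGUAGE SCALA
CALLED ON NULL INPUT
RUNTIME_VERSION = {scala_version}
PACKAGES = ('com.snowflake:snowpark_{scala_version}:latest')
IMPORTS = ('@stage/path/SC_UDF_ab12cd34ef.bin', '@stage/path/sas-scala-udf_2.13-0.2.0.jar')
HANDLER = '__RecreatedSparkUdf.__wrapperFunc'
AS
$$
// generated Scala wrapper body goes here
$$;"""
print(ddl)
```

Deriving `RUNTIME_VERSION` and the `snowpark_{scala_version}` package coordinate from one config value is what lets this release support Scala 2.13 alongside the 2.12 jars that were dropped from the wheel.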
```diff
@@ -260,70 +201,6 @@ $$
 $$;"""
 
 
-def build_scala_udf_imports(session, payload, udf_name, is_map_return) -> List[str]:
-    """
-    Build the list of imports needed for the Scala UDF.
-
-    This function:
-    1. Saves the UDF payload to a binary file in the session stage
-    2. Collects user-uploaded JAR files from the stage
-    3. Returns a list of all required JAR files for the UDF
-
-    Args:
-        session: Snowpark session
-        payload: Binary payload containing the serialized UDF
-        udf_name: Name of the UDF (used for the binary file name)
-        is_map_return: Indicates if the UDF returns a Map (affects imports)
-
-    Returns:
-        List of JAR file paths to be imported by the UDF
-    """
-    # Save pciudf._payload to a bin file:
-    import io
-
-    payload_as_stream = io.BytesIO(payload)
-    stage = session.get_session_stage()
-    stage_resource_path = stage + RESOURCE_PATH
-    closure_binary_file = stage_resource_path + "/" + udf_name + ".bin"
-    session.file.put_stream(
-        payload_as_stream,
-        closure_binary_file,
-        overwrite=True,
-    )
-
-    # Get a list of the jar files uploaded to the stage. We need to import the user's jar for the Scala UDF.
-    res = session.sql(rf"LIST {stage}/ PATTERN='.*\.jar';").collect()
-    user_jars = []
-    for row in res:
-        if RESOURCE_PATH not in row[0]:
-            # Remove the stage path since it is not properly formatted.
-            user_jars.append(row[0][row[0].find("/") :])
-
-    # Jars used when the return type is a Map.
-    map_jars = (
-        []
-        if not is_map_return
-        else [
-            f"{stage_resource_path}/json4s-core_2.12-3.7.0-M11.jar",
-            f"{stage_resource_path}/json4s-native_2.12-3.7.0-M11.jar",
-            f"{stage_resource_path}/paranamer-2.8.3.jar",
-        ]
-    )
-
-    # Format the user jars to be used in the IMPORTS clause of the stored procedure.
-    return (
-        [
-            closure_binary_file,
-            f"{stage_resource_path}/spark-connect-client-jvm_2.12-3.5.6.jar",
-            f"{stage_resource_path}/spark-common-utils_2.12-3.5.6.jar",
-            f"{stage_resource_path}/spark-sql_2.12-3.5.6.jar",
-            f"{stage_resource_path}/json4s-ast_2.12-3.7.0-M11.jar",
-        ]
-        + map_jars
-        + [f"{stage + jar}" for jar in user_jars]
-    )
-
-
 def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf:
     """
     Create a Scala UDF in Snowflake from a ProcessCommonInlineUserDefinedFunction object.
@@ -343,7 +220,13 @@ def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf:
     Returns:
         A ScalaUdf object representing the created or cached Scala UDF.
     """
-    from snowflake.snowpark_connect.
+    from snowflake.snowpark_connect.resources_initializer import (
+        ensure_scala_udf_jars_uploaded,
+    )
+
+    # Lazily upload Scala UDF jars on-demand when a Scala UDF is actually created.
+    # This is thread-safe and will only upload once even if multiple threads call it.
+    ensure_scala_udf_jars_uploaded()
 
     function_name = pciudf._function_name
     # If a function name is not provided, hash the binary file and use the first ten characters as the function name.
```
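The comment describes `ensure_scala_udf_jars_uploaded` as thread-safe and once-only. A generic sketch of how such a guard is commonly built with double-checked locking (an illustration, not the package's actual implementation):

```python
import threading

_upload_lock = threading.Lock()
_uploaded = False


def _ensure_uploaded_once(upload) -> None:
    """Run upload() at most once, even under concurrent callers."""
    global _uploaded
    if _uploaded:              # fast path: skip the lock after the first success
        return
    with _upload_lock:
        if not _uploaded:      # re-check under the lock to avoid a race
            upload()
            _uploaded = True
```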
```diff
@@ -353,11 +236,6 @@ def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf:
         function_name = hashlib.sha256(pciudf._payload).hexdigest()[:10]
     udf_name = CREATE_SCALA_UDF_PREFIX + function_name
 
-    session = get_or_create_snowpark_session()
-    if udf_name in session._udfs:
-        cached_udf = session._udfs[udf_name]
-        return ScalaUdf(cached_udf.name, cached_udf.input_types, cached_udf.return_type)
-
     # In case the Scala UDF was created with `spark.udf.register`, the Spark Scala input types (from protobuf) are
     # stored in pciudf.scala_input_types.
     # We cannot rely solely on the inputTypes field from the Scala UDF or the Snowpark input types, since:
@@ -368,38 +246,59 @@ def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf:
         pciudf._scala_input_types if pciudf._scala_input_types else pciudf._input_types
     )
 
+    scala_return_type = _map_type_to_scala_type(
+        pciudf._original_return_type, is_input=False
+    )
     scala_input_params: List[Param] = []
     sql_input_params: List[Param] = []
     scala_invocation_args: List[str] = []  # arguments passed into the udf function
-
+
+    session = get_or_create_snowpark_session()
+    imports = build_jvm_udxf_imports(
+        session,
+        pciudf._payload,
+        udf_name,
+    )
+
+    # If input_types is empty (length 0), it doesn't necessarily mean there are no arguments.
+    # We need to inspect the UdfPacket to determine the actual number of arguments.
+    if (
+        input_types is None or len(input_types) == 0
+    ) and pciudf._called_from == "register_udf":
+        args_scala = _get_input_arg_types_if_udfpacket_input_types_empty(
+            session, imports, udf_name
+        )
+        for i, arg in enumerate(args_scala):
+            param_name = "arg" + str(i)
+            scala_input_params.append(Param(param_name, arg))
+            sql_input_params.append(Param(param_name, "VARIANT"))
+            scala_invocation_args.append(param_name)
+    elif input_types:
         for i, input_type in enumerate(input_types):
             param_name = "arg" + str(i)
             # Create the Scala arguments and input types string: "arg0: Type0, arg1: Type1, ...".
             scala_input_params.append(
-                Param(param_name,
+                Param(param_name, _map_type_to_scala_type(input_type, is_input=True))
             )
             # Create the Snowflake SQL arguments and input types string: "arg0 TYPE0, arg1 TYPE1, ...".
-
-
+            # For arrays and structs, use VARIANT type in SQL signature
+            is_snowpark_type = isinstance(input_type, snowpark_type.DataType)
+            is_array = (
+                is_snowpark_type and isinstance(input_type, snowpark_type.ArrayType)
+            ) or (not is_snowpark_type and input_type.WhichOneof("kind") == "array")
+            is_map = (
+                is_snowpark_type and isinstance(input_type, snowpark_type.MapType)
+            ) or (not is_snowpark_type and input_type.WhichOneof("kind") == "map")
+            sql_type = (
+                "VARIANT"
+                if is_array or is_map
+                else map_type_to_snowflake_type(input_type)
             )
+            sql_input_params.append(Param(param_name, sql_type))
             # In the case of Map input types, we need to cast the argument to the correct type in Scala.
-
-            scala_invocation_args.append(
-                cast_scala_map_args_from_given_type(param_name, input_type)
-            )
+            scala_invocation_args.append(param_name)
 
-
-    # If the SQL return type is a MAP, change this to VARIANT because of issues with Scala UDFs.
-    sql_return_type = map_type_to_snowflake_type(pciudf._original_return_type)
-    imports = build_scala_udf_imports(
-        session,
-        pciudf._payload,
-        udf_name,
-        is_map_return=sql_return_type.startswith("MAP"),
-    )
-    sql_return_type = (
-        "VARIANT" if sql_return_type.startswith("MAP") else sql_return_type
-    )
+    sql_return_type = "VARIANT"
 
     udf_def = ScalaUDFDef(
         name=udf_name,
```
|
|
|
418
317
|
return ScalaUdf(udf_name, pciudf._input_types, pciudf._return_type)
|
|
419
318
|
|
|
420
319
|
|
|
421
|
-
def
|
|
422
|
-
|
|
320
|
+
def _ensure_input_types_udf_created(session, imports: List[str], udf_name: str) -> str:
|
|
321
|
+
"""
|
|
322
|
+
Create a UDF for getting input types with a unique name based on the UDF name.
|
|
323
|
+
|
|
324
|
+
This UDF uses reflection to inspect a serialized UdfPacket
|
|
325
|
+
and determine the actual input parameter types.
|
|
326
|
+
|
|
327
|
+
Returns:
|
|
328
|
+
The name of the created UDF.
|
|
329
|
+
"""
|
|
330
|
+
|
|
331
|
+
def quote_single(s: str) -> str:
|
|
332
|
+
return "'" + s + "'"
|
|
333
|
+
|
|
334
|
+
scala_version = get_scala_version()
|
|
335
|
+
udf_helper_name = f"__SC_INPUT_ARGS_UDF_{udf_name}"
|
|
336
|
+
imports_sql = f"IMPORTS = ({', '.join(quote_single(x) for x in imports)})"
|
|
337
|
+
create_udf_sql = f"""
|
|
338
|
+
CREATE OR REPLACE TEMPORARY FUNCTION {udf_helper_name}(udf_bin_file VARCHAR)
|
|
339
|
+
RETURNS STRING
|
|
340
|
+
LANGUAGE SCALA
|
|
341
|
+
PACKAGES = ('com.snowflake:snowpark_{scala_version}:latest')
|
|
342
|
+
RUNTIME_VERSION = {scala_version}
|
|
343
|
+
{imports_sql}
|
|
344
|
+
HANDLER = 'com.snowflake.sas.scala.handlers.InputTypesUdf.getInputArgTypesWithReflection';"""
|
|
345
|
+
logger.info(f"Creating UDF for input type inspection: {create_udf_sql}")
|
|
346
|
+
session.sql(create_udf_sql).collect()
|
|
347
|
+
return udf_helper_name
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def _get_input_arg_types_if_udfpacket_input_types_empty(
|
|
351
|
+
session, imports: List[str], udf_name: str
|
|
352
|
+
) -> list[str]:
|
|
353
|
+
"""
|
|
354
|
+
Get the number of input arguments from a UdfPacket by calling a Scala UDF.
|
|
355
|
+
|
|
356
|
+
This is used when the input_types list is empty (length 0), which doesn't necessarily
|
|
357
|
+
mean there are no arguments. The UDF uses reflection to inspect the
|
|
358
|
+
serialized function and determine the actual parameters.
|
|
359
|
+
"""
|
|
360
|
+
udf_helper_name = _ensure_input_types_udf_created(session, imports, udf_name)
|
|
361
|
+
result = session.sql(f"SELECT {udf_helper_name}('{udf_name}.bin')").collect()
|
|
362
|
+
args = str(result[0][0])
|
|
363
|
+
num_args = len(args.split(", "))
|
|
364
|
+
logger.info(f"UDF has {num_args} input arguments")
|
|
365
|
+
return [arg for arg in args.split(", ") if arg]
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def _map_type_to_scala_type(
|
|
369
|
+
t: Union[snowpark_type.DataType, types_proto.DataType], is_input: bool = False
|
|
423
370
|
) -> str:
|
|
424
|
-
"""Maps a Snowpark or Spark protobuf type to a Scala type string.
|
|
371
|
+
"""Maps a Snowpark or Spark protobuf type to a Scala type string.
|
|
372
|
+
|
|
373
|
+
Args:
|
|
374
|
+
t: The type to map
|
|
375
|
+
is_input: If True, maps array types to Variant (for UDF inputs).
|
|
376
|
+
If False, maps array types to Array[ElementType] (for UDF outputs).
|
|
377
|
+
"""
|
|
425
378
|
if not t:
|
|
426
379
|
return "String"
|
|
427
380
|
is_snowpark_type = isinstance(t, snowpark_type.DataType)
|
|
428
381
|
condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
|
|
429
382
|
match condition:
|
|
430
383
|
case snowpark_type.ArrayType | "array":
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
384
|
+
if is_input:
|
|
385
|
+
return "Variant"
|
|
386
|
+
else:
|
|
387
|
+
return (
|
|
388
|
+
f"Array[{_map_type_to_scala_type(t.element_type, is_input=False)}]"
|
|
389
|
+
if is_snowpark_type
|
|
390
|
+
else f"Array[{_map_type_to_scala_type(t.array.element_type, is_input=False)}]"
|
|
391
|
+
)
|
|
436
392
|
case snowpark_type.BinaryType | "binary":
|
|
437
393
|
return "Array[Byte]"
|
|
438
394
|
case snowpark_type.BooleanType | "boolean":
|
|
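The two new helpers round-trip through SQL: a temporary Scala function reflects over the serialized `UdfPacket` staged as `{udf_name}.bin` and returns the parameter types as a single comma-separated string, which the Python side splits. A sketch of the parsing half, with a hypothetical reflected result:

```python
# Hypothetical string returned by the reflection UDF for a two-argument function.
reflected = "Int, String"

arg_types = [arg for arg in reflected.split(", ") if arg]
assert arg_types == ["Int", "String"]

# Each reflected Scala type becomes a Param on the Scala side, while the SQL
# signature declares every such argument as VARIANT (see the hunk above).
params = [("arg" + str(i), t) for i, t in enumerate(arg_types)]
assert params == [("arg0", "Int"), ("arg1", "String")]
```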
```diff
@@ -453,16 +409,18 @@ def map_type_to_scala_type(
             return "Int"
         case snowpark_type.LongType | "long":
             return "Long"
-        case snowpark_type.MapType | "map":
+        case snowpark_type.MapType | "map":
+            if is_input:
+                return "Variant"
             key_type = (
-
+                _map_type_to_scala_type(t.key_type)
                 if is_snowpark_type
-                else
+                else _map_type_to_scala_type(t.map.key_type)
             )
             value_type = (
-
+                _map_type_to_scala_type(t.value_type)
                 if is_snowpark_type
-                else
+                else _map_type_to_scala_type(t.map.value_type)
             )
             return f"Map[{key_type}, {value_type}]"
         case snowpark_type.NullType | "null":
@@ -471,126 +429,13 @@ def map_type_to_scala_type(
             return "Short"
         case snowpark_type.StringType | "string" | "char" | "varchar":
             return "String"
+        case snowpark_type.StructType | "struct":
+            return "Variant"
         case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
             return "java.sql.Timestamp"
         case snowpark_type.VariantType:
             return "Variant"
         case _:
-
-
-
-def map_type_to_snowflake_type(
-    t: Union[snowpark_type.DataType, types_proto.DataType]
-) -> str:
-    """Maps a Snowpark or Spark protobuf type to a Snowflake type string."""
-    if not t:
-        return "VARCHAR"
-    is_snowpark_type = isinstance(t, snowpark_type.DataType)
-    condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
-    match condition:
-        case snowpark_type.ArrayType | "array":
-            return (
-                f"ARRAY({map_type_to_snowflake_type(t.element_type)})"
-                if is_snowpark_type
-                else f"ARRAY({map_type_to_snowflake_type(t.array.element_type)})"
-            )
-        case snowpark_type.BinaryType | "binary":
-            return "BINARY"
-        case snowpark_type.BooleanType | "boolean":
-            return "BOOLEAN"
-        case snowpark_type.ByteType | "byte":
-            return "TINYINT"
-        case snowpark_type.DateType | "date":
-            return "DATE"
-        case snowpark_type.DecimalType | "decimal":
-            return "NUMBER"
-        case snowpark_type.DoubleType | "double":
-            return "DOUBLE"
-        case snowpark_type.FloatType | "float":
-            return "FLOAT"
-        case snowpark_type.GeographyType:
-            return "GEOGRAPHY"
-        case snowpark_type.IntegerType | "integer":
-            return "INT"
-        case snowpark_type.LongType | "long":
-            return "BIGINT"
-        case snowpark_type.MapType | "map":
-            # Maps to OBJECT in Snowflake if key and value types are not specified.
-            key_type = (
-                map_type_to_snowflake_type(t.key_type)
-                if is_snowpark_type
-                else map_type_to_snowflake_type(t.map.key_type)
-            )
-            value_type = (
-                map_type_to_snowflake_type(t.value_type)
-                if is_snowpark_type
-                else map_type_to_snowflake_type(t.map.value_type)
-            )
-            return (
-                f"MAP({key_type}, {value_type})"
-                if key_type and value_type
-                else "OBJECT"
-            )
-        case snowpark_type.NullType | "null":
-            return "VARCHAR"
-        case snowpark_type.ShortType | "short":
-            return "SMALLINT"
-        case snowpark_type.StringType | "string" | "char" | "varchar":
-            return "VARCHAR"
-        case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
-            return "TIMESTAMP"
-        case snowpark_type.VariantType:
-            return "VARIANT"
-        case _:
-            raise ValueError(f"Unsupported Snowpark type: {t}")
-
-
-def cast_scala_map_args_from_given_type(
-    arg_name: str, input_type: Union[snowpark_type.DataType, types_proto.DataType]
-) -> str:
-    """If the input_type is a Map, cast the argument arg_name to a Map[key_type, value_type] in Scala."""
-    is_snowpark_type = isinstance(input_type, snowpark_type.DataType)
-
-    def convert_from_string_to_type(
-        arg_name: str, t: Union[snowpark_type.DataType, types_proto.DataType]
-    ) -> str:
-        """Convert the string argument arg_name to the specified type t in Scala."""
-        condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
-        match condition:
-            case snowpark_type.BinaryType | "binary":
-                return arg_name + ".getBytes()"
-            case snowpark_type.BooleanType | "boolean":
-                return arg_name + ".toBoolean"
-            case snowpark_type.ByteType | "byte":
-                return arg_name + ".getBytes().head"  # TODO: verify if this is correct
-            case snowpark_type.DateType | "date":
-                return f"java.sql.Date.valueOf({arg_name})"
-            case snowpark_type.DecimalType | "decimal":
-                return f"new BigDecimal({arg_name})"
-            case snowpark_type.DoubleType | "double":
-                return arg_name + ".toDouble"
-            case snowpark_type.FloatType | "float":
-                return arg_name + ".toFloat"
-            case snowpark_type.IntegerType | "integer":
-                return arg_name + ".toInt"
-            case snowpark_type.LongType | "long":
-                return arg_name + ".toLong"
-            case snowpark_type.ShortType | "short":
-                return arg_name + ".toShort"
-            case snowpark_type.StringType | "string" | "char" | "varchar":
-                return arg_name
-            case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
-                return "java.sql.Timestamp.valueOf({arg_name})"
-            case _:
-                raise ValueError(f"Unsupported Snowpark type: {t}")
-
-    if (is_snowpark_type and isinstance(input_type, snowpark_type.MapType)) or (
-        not is_snowpark_type and input_type.WhichOneof("kind") == "map"
-    ):
-        key_type = input_type.key_type if is_snowpark_type else input_type.map.key_type
-        value_type = (
-            input_type.value_type if is_snowpark_type else input_type.map.value_type
-        )
-        return f"{arg_name}.map {{ case (k, v) => ({convert_from_string_to_type('k', key_type)}, {convert_from_string_to_type('v', value_type)})}}"
-    else:
-        return arg_name
+            exception = ValueError(f"Unsupported Snowpark type: {t}")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
+            raise exception
```
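With the new `is_input` flag, the same logical type now maps differently for inputs and outputs. A few illustrative expectations follow; this is a sketch that assumes `_map_type_to_scala_type` is importable from the module path shown in the file list above.

```python
import snowflake.snowpark.types as T

from snowflake.snowpark_connect.utils.scala_udf_utils import _map_type_to_scala_type

# Inputs arrive as Variant so they can be decoded with fromVariant...
assert _map_type_to_scala_type(T.ArrayType(T.IntegerType()), is_input=True) == "Variant"
assert (
    _map_type_to_scala_type(T.MapType(T.StringType(), T.LongType()), is_input=True)
    == "Variant"
)

# ...while outputs keep their structured Scala types.
assert _map_type_to_scala_type(T.ArrayType(T.IntegerType()), is_input=False) == "Array[Int]"
assert (
    _map_type_to_scala_type(T.MapType(T.StringType(), T.LongType()), is_input=False)
    == "Map[String, Long]"
)

# Structs collapse to Variant in either direction.
assert _map_type_to_scala_type(T.StructType(), is_input=True) == "Variant"
```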
```diff
--- /dev/null
+++ b/snowflake/snowpark_connect/utils/sequence.py
@@ -0,0 +1,21 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+import threading
+from collections import defaultdict
+
+from snowflake.snowpark_connect.utils.context import get_spark_session_id
+
+# per session number sequences to generate unique snowpark columns
+_session_sequences = defaultdict(int)
+
+_lock = threading.Lock()
+
+
+def next_unique_num():
+    session_id = get_spark_session_id()
+    with _lock:
+        next_num = _session_sequences[session_id]
+        _session_sequences[session_id] = next_num + 1
+        return next_num
```