snowpark-connect 0.27.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +680 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +237 -23
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +123 -5
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +85 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2748 -746
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +196 -255
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +255 -45
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
- snowflake/snowpark_connect/relation/read/map_read_json.py +320 -85
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +142 -27
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +82 -5
- snowflake/snowpark_connect/relation/read/map_read_text.py +18 -3
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +36 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +110 -48
- snowflake/snowpark_connect/server.py +546 -456
- snowflake/snowpark_connect/server_common/__init__.py +500 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +187 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +125 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +303 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +248 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +101 -332
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +163 -22
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-submit +2 -2
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/METADATA +14 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/RECORD +129 -167
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import List, Union
|
|
8
|
+
|
|
9
|
+
import snowflake.snowpark.types as snowpark_type
|
|
10
|
+
import snowflake.snowpark_connect.includes.python.pyspark.sql.connect.proto.types_pb2 as types_proto
|
|
11
|
+
from snowflake import snowpark
|
|
12
|
+
from snowflake.snowpark_connect.error.error_codes import ErrorCodes
|
|
13
|
+
from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
|
|
14
|
+
from snowflake.snowpark_connect.resources_initializer import (
|
|
15
|
+
JSON_4S_JAR,
|
|
16
|
+
RESOURCE_PATH,
|
|
17
|
+
SAS_SCALA_UDF_JAR,
|
|
18
|
+
SCALA_REFLECT_JAR,
|
|
19
|
+
SPARK_COMMON_UTILS_JAR,
|
|
20
|
+
SPARK_CONNECT_CLIENT_JAR,
|
|
21
|
+
SPARK_SQL_JAR,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(frozen=True)
class Param:
    """
    One named argument of a function signature.

    The dataclass is frozen, so instances cannot be modified after creation.

    Attributes:
        name: Identifier of the parameter
        data_type: Type of the parameter, spelled as a string
    """

    name: str
    data_type: str
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class NullHandling(str, Enum):
    """
    Enumeration for UDF null handling behavior.

    Determines how the UDF behaves when input parameters contain null values.
    Members are also ``str`` instances (``str`` mixin), so they compare equal
    to their literal string value.

    Note: this enum is deliberately NOT decorated with ``@dataclass``. The
    decorator is not supported on ``Enum`` subclasses: it installs a
    field-based ``__eq__``/``__hash__``, and because an enum declares no
    dataclass fields every member would compare equal to (and hash the same
    as) every other member.
    """

    RETURNS_NULL_ON_NULL_INPUT = "RETURNS NULL ON NULL INPUT"
    CALLED_ON_NULL_INPUT = "CALLED ON NULL INPUT"
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass(frozen=True)
class ReturnType:
    """
    The type a function returns.

    The dataclass is frozen, so instances cannot be modified after creation.

    Attributes:
        data_type: Returned type, spelled as a string
    """

    data_type: str
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclass(frozen=True)
class Signature:
    """
    A function signature: its parameters together with its return type.

    The dataclass is frozen, so the two attributes cannot be rebound after
    creation.

    Attributes:
        params: The function's parameters, as a list of ``Param``
        returns: The function's return type
    """

    params: List[Param]
    returns: ReturnType
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def build_jvm_udxf_imports(
    session: snowpark.Session, payload: bytes, udf_name: str
) -> List[str]:
    """
    Upload the UDxF payload and assemble the import list for the JVM UDxF.

    The payload is written (overwriting any existing file) to
    ``<session stage><RESOURCE_PATH>/scala/bin/<udf_name>.bin``. The returned
    list starts with that file, continues with the bundled JARs located under
    ``<session stage><RESOURCE_PATH>``, and ends with the entries of
    ``session._artifact_jars``.

    Args:
        session: Snowpark session
        payload: Binary payload containing the serialized Scala UDF
        udf_name: Name of the Scala UDF (used for the binary file name)

    Returns:
        List of file paths to be imported by the UDxF
    """
    import io

    payload_stream = io.BytesIO(payload)
    resource_root = session.get_session_stage() + RESOURCE_PATH
    payload_location = resource_root + "/scala/bin/" + udf_name + ".bin"

    # Persist the serialized closure on the stage so the UDxF can load it.
    session.file.put_stream(payload_stream, payload_location, overwrite=True)

    bundled_jars = (
        SPARK_CONNECT_CLIENT_JAR,
        SPARK_COMMON_UTILS_JAR,
        SPARK_SQL_JAR,
        JSON_4S_JAR,
        SAS_SCALA_UDF_JAR,
        SCALA_REFLECT_JAR,  # Required for deserializing Scala lambdas
    )

    # Payload first, then the bundled JARs, then the session's artifact JARs.
    imports = [payload_location]
    imports.extend(f"{resource_root}/{jar}" for jar in bundled_jars)
    imports.extend(session._artifact_jars)
    return imports
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def map_type_to_java_type(
    t: Union[snowpark_type.DataType, types_proto.DataType]
) -> str:
    """
    Translate a Snowpark type or a Spark protobuf type into a Java type string.

    A falsy ``t`` (e.g. ``None``) yields ``"String"``. Snowpark types are
    matched on their exact class; protobuf types are matched on the name of
    the field set in their ``kind`` oneof. Array and map types recurse into
    their element / key / value types.

    Args:
        t: Snowpark ``DataType`` instance or Spark protobuf ``DataType`` message

    Returns:
        The Java type string, e.g. ``"Integer"``, ``"String[]"`` or
        ``"Map<String, Long>"``

    Raises:
        ValueError: If there is no Java mapping for the type; the exception is
            tagged with ``ErrorCodes.UNSUPPORTED_TYPE``
    """
    if not t:
        return "String"

    from_snowpark = isinstance(t, snowpark_type.DataType)
    kind = type(t) if from_snowpark else t.WhichOneof("kind")

    # Container types: resolve the nested types recursively.
    if kind in (snowpark_type.ArrayType, "array"):
        element = t.element_type if from_snowpark else t.array.element_type
        return f"{map_type_to_java_type(element)}[]"

    if kind in (snowpark_type.MapType, "map"):  # can also map to OBJECT in Snowflake
        holder = t if from_snowpark else t.map
        key_java = map_type_to_java_type(holder.key_type)
        value_java = map_type_to_java_type(holder.value_type)
        return f"Map<{key_java}, {value_java}>"

    # Everything else is a fixed one-to-one mapping.
    scalar_groups = (
        ("byte[]", (snowpark_type.BinaryType, "binary")),
        ("Boolean", (snowpark_type.BooleanType, "boolean")),
        ("Byte", (snowpark_type.ByteType, "byte")),
        ("java.sql.Date", (snowpark_type.DateType, "date")),
        ("java.math.BigDecimal", (snowpark_type.DecimalType, "decimal")),
        ("Double", (snowpark_type.DoubleType, "double")),
        ("Float", (snowpark_type.FloatType, "float")),
        ("Geography", (snowpark_type.GeographyType,)),
        ("Integer", (snowpark_type.IntegerType, "integer")),
        ("Long", (snowpark_type.LongType, "long")),
        # cannot set the return type to Null in Snowpark Java UDAFs
        ("String", (snowpark_type.NullType, "null")),
        ("Short", (snowpark_type.ShortType, "short")),
        ("String", (snowpark_type.StringType, "string", "char", "varchar")),
        ("Variant", (snowpark_type.StructType, "struct")),
        (
            "java.sql.Timestamp",
            (snowpark_type.TimestampType, "timestamp", "timestamp_ntz"),
        ),
        ("Variant", (snowpark_type.VariantType,)),
    )
    for java_name, accepted_kinds in scalar_groups:
        if kind in accepted_kinds:
            return java_name

    exception = ValueError(f"Unsupported Snowpark type: {t}")
    attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
    raise exception
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def cast_java_map_args_from_given_type(
|
|
188
|
+
arg_name: str, input_type: Union[snowpark_type.DataType, types_proto.DataType]
|
|
189
|
+
) -> str:
|
|
190
|
+
"""If the input_type is a Map or Struct, cast the argument arg_name to the correct type in Java."""
|
|
191
|
+
is_snowpark_type = isinstance(input_type, snowpark_type.DataType)
|
|
192
|
+
|
|
193
|
+
def convert_from_string_to_type(
|
|
194
|
+
arg_name: str, t: Union[snowpark_type.DataType, types_proto.DataType]
|
|
195
|
+
) -> str:
|
|
196
|
+
"""Convert the string argument arg_name to the specified type t in Java."""
|
|
197
|
+
condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
|
|
198
|
+
match condition:
|
|
199
|
+
case snowpark_type.BinaryType | "binary":
|
|
200
|
+
return arg_name + ".getBytes()"
|
|
201
|
+
case snowpark_type.BooleanType | "boolean":
|
|
202
|
+
return f"Boolean.valueOf({arg_name}"
|
|
203
|
+
case snowpark_type.ByteType | "byte":
|
|
204
|
+
return arg_name + ".getBytes()[0]" # TODO: verify if this is correct
|
|
205
|
+
case snowpark_type.DateType | "date":
|
|
206
|
+
return f"java.sql.Date.valueOf({arg_name})"
|
|
207
|
+
case snowpark_type.DecimalType | "decimal":
|
|
208
|
+
return f"new BigDecimal({arg_name})"
|
|
209
|
+
case snowpark_type.DoubleType | "double":
|
|
210
|
+
return f"Double.valueOf({arg_name}"
|
|
211
|
+
case snowpark_type.FloatType | "float":
|
|
212
|
+
return f"Float.valueOf({arg_name}"
|
|
213
|
+
case snowpark_type.IntegerType | "integer":
|
|
214
|
+
return f"Integer.valueOf({arg_name}"
|
|
215
|
+
case snowpark_type.LongType | "long":
|
|
216
|
+
return f"Long.valueOf({arg_name}"
|
|
217
|
+
case snowpark_type.ShortType | "short":
|
|
218
|
+
return f"Short.valueOf({arg_name}"
|
|
219
|
+
case snowpark_type.StringType | "string" | "char" | "varchar":
|
|
220
|
+
return arg_name
|
|
221
|
+
case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
|
|
222
|
+
return f"java.sql.Timestamp.valueOf({arg_name})" # todo add test
|
|
223
|
+
case _:
|
|
224
|
+
exception = ValueError(f"Unsupported Snowpark type: {t}")
|
|
225
|
+
attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
|
|
226
|
+
raise exception
|
|
227
|
+
|
|
228
|
+
if (is_snowpark_type and isinstance(input_type, snowpark_type.MapType)) or (
|
|
229
|
+
not is_snowpark_type and input_type.WhichOneof("kind") == "map"
|
|
230
|
+
):
|
|
231
|
+
key_type = input_type.key_type if is_snowpark_type else input_type.map.key_type
|
|
232
|
+
value_type = (
|
|
233
|
+
input_type.value_type if is_snowpark_type else input_type.map.value_type
|
|
234
|
+
)
|
|
235
|
+
key_converter = "{" + convert_from_string_to_type("e.getKey()", key_type) + "}"
|
|
236
|
+
value_converter = (
|
|
237
|
+
"{" + convert_from_string_to_type("e.getValue()", value_type) + "}"
|
|
238
|
+
)
|
|
239
|
+
return f"""
|
|
240
|
+
{arg_name}.entrySet()
|
|
241
|
+
.stream()
|
|
242
|
+
.collect(Collectors.toMap(
|
|
243
|
+
e -> {key_converter},
|
|
244
|
+
e -> {value_converter}
|
|
245
|
+
));
|
|
246
|
+
"""
|
|
247
|
+
else:
|
|
248
|
+
return arg_name
|