snowpark-connect 0.27.0__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +717 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +309 -26
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/error_utils.py +28 -0
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +224 -15
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +86 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2964 -829
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-connect-client-jvm_2.12-3.5.6.jar → spark-connect-client-jvm_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-sql_2.12-3.5.6.jar → spark-sql_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +239 -256
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +326 -47
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +21 -6
- snowflake/snowpark_connect/relation/read/map_read_json.py +324 -86
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +146 -28
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +15 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +86 -6
- snowflake/snowpark_connect/relation/read/map_read_text.py +22 -4
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +42 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +171 -48
- snowflake/snowpark_connect/server.py +528 -473
- snowflake/snowpark_connect/server_common/__init__.py +503 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/type_support.py +130 -0
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +195 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +151 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +321 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +281 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +185 -340
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +192 -40
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +14 -4
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +16 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +139 -168
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
@@ -8,6 +8,7 @@ import typing
 from contextlib import suppress
 from datetime import datetime
 from functools import cache
+from typing import Union

 import jpype
 import pyarrow as pa
@@ -18,6 +19,7 @@ from pyspark.errors.exceptions.base import AnalysisException
 from pyspark.sql.connect.proto import expressions_pb2

 from snowflake import snowpark
+from snowflake.snowpark import types as snowpark_type
 from snowflake.snowpark._internal.utils import quote_name
 from snowflake.snowpark.types import TimestampTimeZone, TimestampType
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
@@ -29,8 +31,17 @@ from snowflake.snowpark_connect.constants import (
 from snowflake.snowpark_connect.date_time_format_mapping import (
     convert_spark_format_to_snowflake,
 )
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
 from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
-from snowflake.snowpark_connect.
+from snowflake.snowpark_connect.expression.map_sql_expression import (
+    _INTERVAL_DAYTIME_PATTERN_RE,
+    _INTERVAL_YEARMONTH_PATTERN_RE,
+)
+from snowflake.snowpark_connect.utils.context import (
+    get_is_evaluating_sql,
+    get_jpype_jclass_lock,
+)
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
@@ -57,12 +68,14 @@ SNOWPARK_TYPE_NAME_TO_PYSPARK_TYPE_NAME = {

 @cache
 def _get_struct_type_class():
-    return jpype.JClass("org.apache.spark.sql.types.StructType")
+    with get_jpype_jclass_lock():
+        return jpype.JClass("org.apache.spark.sql.types.StructType")


 @cache
 def get_python_sql_utils_class():
-    return jpype.JClass("org.apache.spark.sql.api.python.PythonSQLUtils")
+    with get_jpype_jclass_lock():
+        return jpype.JClass("org.apache.spark.sql.api.python.PythonSQLUtils")


 def _parse_ddl_with_spark_scala(ddl_string: str) -> pyspark.sql.types.DataType:
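Both class lookups are now serialized behind a shared lock, presumably because concurrent first-time `jpype.JClass` resolution from multiple handler threads is unsafe in this server. A minimal standalone sketch of the same pattern, with `_JCLASS_LOCK` as a hypothetical stand-in for whatever `get_jpype_jclass_lock()` returns:

```python
import threading
from functools import cache

_JCLASS_LOCK = threading.Lock()  # stand-in for get_jpype_jclass_lock()

@cache
def load_jclass(name: str):
    import jpype
    # Serialize first-time class resolution; @cache makes later calls lock-free.
    with _JCLASS_LOCK:
        return jpype.JClass(name)
```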
@@ -93,7 +106,7 @@ def _parse_ddl_with_spark_scala(ddl_string: str) -> pyspark.sql.types.DataType:

 def snowpark_to_proto_type(
     data_type: snowpark.types.DataType,
-    column_name_map: ColumnNameMap,
+    column_name_map: ColumnNameMap | None = None,
     df: snowpark.DataFrame = None,  # remove this param after SNOW-1857090
     depth: int = 0,
 ) -> dict[str, types_proto.DataType]:
@@ -193,7 +206,7 @@
         # For attributes inside struct type (depth > 0), they don't get renamed as normal dataframe column names. Thus no need to do the conversion from snowpark column name to spark column name.
         spark_name = (
             column_name_map.get_spark_column_name(index)
-            if depth == 0
+            if depth == 0 and column_name_map
             else field.name
         )

@@ -274,10 +287,24 @@
         case snowpark.types.VariantType:
             # For now we are returning a string type for variant types.
             return {"string": types_proto.DataType.String()}
+        case snowpark.types.YearMonthIntervalType:
+            return {
+                "year_month_interval": types_proto.DataType.YearMonthInterval(
+                    start_field=data_type.start_field, end_field=data_type.end_field
+                )
+            }
+        case snowpark.types.DayTimeIntervalType:
+            return {
+                "day_time_interval": types_proto.DataType.DayTimeInterval(
+                    start_field=data_type.start_field, end_field=data_type.end_field
+                )
+            }
         case _:
-            raise SnowparkConnectNotImplementedError(
+            exception = SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type: {data_type}"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception


 def cast_to_match_snowpark_type(
@@ -317,7 +344,9 @@
             with suppress(TypeError):
                 date = datetime.strptime(content, format)
                 return date
-            raise ValueError(f"Date casting error for {str(content)}")
+            exception = ValueError(f"Date casting error for {str(content)}")
+            attach_custom_error_code(exception, ErrorCodes.INVALID_CAST)
+            raise exception
         case snowpark.types.ShortType:
             return int(content)
         case snowpark.types.StringType:
@@ -328,10 +357,32 @@
             return str(content)
         case snowpark.types.TimestampType:
             return str(content)
+        case snowpark.types.YearMonthIntervalType:
+            if isinstance(content, (int, float)):
+                total_months = int(content)
+                years = total_months // 12
+                months = total_months % 12
+                return f"INTERVAL '{years}-{months}' YEAR TO MONTH"
+            elif isinstance(content, str) and content.startswith(("+", "-")):
+                # Handle Snowflake's native interval format (e.g., "+11-08" or "-2-3")
+                # Convert to Spark's format: "INTERVAL 'Y-M' YEAR TO MONTH"
+                sign = content[0]
+                interval_part = content[1:]  # Remove sign
+                if sign == "-":
+                    return f"INTERVAL '-{interval_part}' YEAR TO MONTH"
+                else:
+                    return f"INTERVAL '{interval_part}' YEAR TO MONTH"
+            return str(content)
+        case snowpark.types.DayTimeIntervalType:
+            return str(content)
+        case snowpark.types.MapType:
+            return content
         case _:
-            raise SnowparkConnectNotImplementedError(
+            exception = SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type in casting: {data_type}"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception


 def snowpark_to_iceberg_type(data_type: snowpark.types.DataType) -> str:
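The year-month branch above turns a raw month count into Spark's interval literal. A minimal sketch of that arithmetic in isolation (helper name is illustrative only):

```python
# Standalone sketch of the month-count conversion in the hunk above:
# 20 months -> 1 year, 8 months -> "INTERVAL '1-8' YEAR TO MONTH".
def format_year_month(total_months: int) -> str:
    years = total_months // 12   # whole years
    months = total_months % 12   # months left over
    return f"INTERVAL '{years}-{months}' YEAR TO MONTH"

assert format_year_month(20) == "INTERVAL '1-8' YEAR TO MONTH"
assert format_year_month(7) == "INTERVAL '0-7' YEAR TO MONTH"
```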
@@ -364,9 +415,11 @@
         case snowpark.types.TimestampType:
             return "timestamp"
         case _:
-            raise SnowparkConnectNotImplementedError(
+            exception = SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type for iceber: {data_type}"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception


 def proto_to_snowpark_type(
@@ -411,6 +464,18 @@
             # For UDT types, return the underlying SQL type
             logger.debug("Returning underlying sql type for udt")
             return proto_to_snowpark_type(data_type.udt.sql_type)
+        case "year_month_interval":
+            # Preserve start_field and end_field from protobuf
+            return snowpark.types.YearMonthIntervalType(
+                start_field=data_type.year_month_interval.start_field,
+                end_field=data_type.year_month_interval.end_field,
+            )
+        case "day_time_interval":
+            # Preserve start_field and end_field from protobuf
+            return snowpark.types.DayTimeIntervalType(
+                start_field=data_type.day_time_interval.start_field,
+                end_field=data_type.day_time_interval.end_field,
+            )
         case _:
             return map_simple_types(data_type.WhichOneof("kind"))

@@ -441,9 +506,11 @@
                     )
                 )
             else:
-                raise AnalysisException(
+                exception = AnalysisException(
                     f"Unsupported arrow type {pa_type} for snowpark ArrayType."
                 )
+                attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
+                raise exception
         case snowpark.types.BinaryType:
             return pa.binary()
         case snowpark.types.BooleanType:
@@ -484,9 +551,11 @@
                     ),
                 )
             else:
-                raise AnalysisException(
+                exception = AnalysisException(
                     f"Unsupported arrow type {pa_type} for snowpark MapType."
                 )
+                attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
+                raise exception
         case snowpark.types.NullType:
             return pa.string()
         case snowpark.types.ShortType:
@@ -503,30 +572,44 @@
                         pa.field(
                             field.name if not rename_struct_columns else str(i),
                             map_snowpark_types_to_pyarrow_types(
-                                field.datatype,
+                                field.datatype,
+                                pa_type[i].type,
                             ),
-                            nullable=
+                            nullable=True,
                         )
                         for i, field in enumerate(snowpark_type.fields)
                     ]
                 )
             else:
-                raise AnalysisException(
+                exception = AnalysisException(
                     f"Unsupported arrow type {pa_type} for snowpark StructType."
                 )
+                attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
+                raise exception
         case snowpark.types.TimestampType:
-            unit = pa_type.unit
-            tz = pa_type.tz
+            # Check if pa_type has unit attribute (it should be a timestamp type)
+            unit = pa_type.unit if hasattr(pa_type, "unit") else "us"
+            tz = pa_type.tz if hasattr(pa_type, "tz") else None
+
+            # Spark truncates nanosecond precision to microseconds
             if unit == "ns":
-                # Spark truncates nanosecond precision to microseconds
                 unit = "us"
+
             return pa.timestamp(unit, tz=tz)
         case snowpark.types.VariantType:
             return pa.string()
+        case snowpark.types.YearMonthIntervalType:
+            # Return string type so formatted intervals are preserved in display
+            return pa.string()
+        case snowpark.types.DayTimeIntervalType:
+            # Return string type so formatted intervals are preserved in display
+            return pa.string()
         case _:
-            raise SnowparkConnectNotImplementedError(
+            exception = SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type: {snowpark_type}"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception


 def map_pyarrow_to_snowpark_types(pa_type: pa.DataType) -> snowpark.types.DataType:
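The rewritten timestamp branch reads the unit and timezone off the incoming Arrow type defensively and downgrades nanoseconds, since Spark timestamps carry at most microsecond precision. A standalone sketch of the rule (function name is illustrative):

```python
import pyarrow as pa

def to_spark_compatible_timestamp(pa_type: pa.DataType) -> pa.DataType:
    # Fall back to microseconds / no timezone when the hint is not a timestamp.
    unit = pa_type.unit if hasattr(pa_type, "unit") else "us"
    tz = pa_type.tz if hasattr(pa_type, "tz") else None
    if unit == "ns":  # Spark truncates nanoseconds to microseconds
        unit = "us"
    return pa.timestamp(unit, tz=tz)

assert to_spark_compatible_timestamp(pa.timestamp("ns", tz="UTC")) == pa.timestamp("us", tz="UTC")
```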
@@ -595,10 +678,15 @@ def map_pyarrow_to_snowpark_types(pa_type: pa.DataType) -> snowpark.types.DataTy
         return snowpark.types.TimestampType()
     elif pa.types.is_null(pa_type):
         return snowpark.types.NullType()
+    elif pa.types.is_duration(pa_type):
+        # Map PyArrow duration[us] to DayTimeIntervalType
+        return snowpark.types.DayTimeIntervalType()
     else:
-        raise SnowparkConnectNotImplementedError(
+        exception = SnowparkConnectNotImplementedError(
             f"Unsupported PyArrow data type: {pa_type}"
         )
+        attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+        raise exception


 def map_pyspark_types_to_snowpark_types(
@@ -676,9 +764,19 @@
         return snowpark.types.TimestampType()
     if isinstance(type_to_map, pyspark.sql.types.TimestampNTZType):
         return snowpark.types.TimestampType(timezone=TimestampTimeZone.NTZ)
-    raise SnowparkConnectNotImplementedError(
+    if isinstance(type_to_map, pyspark.sql.types.YearMonthIntervalType):
+        return snowpark.types.YearMonthIntervalType(
+            type_to_map.startField, type_to_map.endField
+        )
+    if isinstance(type_to_map, pyspark.sql.types.DayTimeIntervalType):
+        return snowpark.types.DayTimeIntervalType(
+            type_to_map.startField, type_to_map.endField
+        )
+    exception = SnowparkConnectNotImplementedError(
         f"Unsupported spark data type: {type_to_map}"
     )
+    attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+    raise exception


 def map_snowpark_to_pyspark_types(
@@ -743,7 +841,81 @@
         if type_to_map.tz == snowpark.types.TimestampTimeZone.NTZ:
             return pyspark.sql.types.TimestampNTZType()
         return pyspark.sql.types.TimestampType()
-
+    if isinstance(type_to_map, snowpark.types.YearMonthIntervalType):
+        return pyspark.sql.types.YearMonthIntervalType(
+            type_to_map.start_field, type_to_map.end_field
+        )
+    if isinstance(type_to_map, snowpark.types.DayTimeIntervalType):
+        return pyspark.sql.types.DayTimeIntervalType(
+            type_to_map.start_field, type_to_map.end_field
+        )
+    exception = SnowparkConnectNotImplementedError(
+        f"Unsupported data type: {type_to_map}"
+    )
+    attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+    raise exception
+
+
+def map_pyspark_types_to_pyarrow_types(
+    pyspark_type: pyspark.sql.types.DataType,
+) -> pa.DataType:
+    """
+    Map a PySpark data type to a PyArrow data type.
+
+    This function converts PySpark types to PyArrow types for generating
+    Parquet metadata files with correct schema structure.
+
+    Args:
+        pyspark_type: PySpark data type to convert
+
+    Returns:
+        Corresponding PyArrow data type
+    """
+    if isinstance(pyspark_type, pyspark.sql.types.StringType):
+        return pa.string()
+    elif isinstance(pyspark_type, pyspark.sql.types.LongType):
+        return pa.int64()
+    elif isinstance(pyspark_type, pyspark.sql.types.IntegerType):
+        return pa.int32()
+    elif isinstance(pyspark_type, pyspark.sql.types.ShortType):
+        return pa.int16()
+    elif isinstance(pyspark_type, pyspark.sql.types.ByteType):
+        return pa.int8()
+    elif isinstance(pyspark_type, pyspark.sql.types.DoubleType):
+        return pa.float64()
+    elif isinstance(pyspark_type, pyspark.sql.types.FloatType):
+        return pa.float32()
+    elif isinstance(pyspark_type, pyspark.sql.types.BooleanType):
+        return pa.bool_()
+    elif isinstance(pyspark_type, pyspark.sql.types.DateType):
+        return pa.date32()
+    elif isinstance(pyspark_type, pyspark.sql.types.TimestampType):
+        return pa.timestamp("us")
+    elif isinstance(pyspark_type, pyspark.sql.types.TimestampNTZType):
+        return pa.timestamp("us")
+    elif isinstance(pyspark_type, pyspark.sql.types.BinaryType):
+        return pa.binary()
+    elif isinstance(pyspark_type, pyspark.sql.types.DecimalType):
+        return pa.decimal128(pyspark_type.precision, pyspark_type.scale)
+    elif isinstance(pyspark_type, pyspark.sql.types.ArrayType):
+        element_type = map_pyspark_types_to_pyarrow_types(pyspark_type.elementType)
+        return pa.list_(element_type)
+    elif isinstance(pyspark_type, pyspark.sql.types.MapType):
+        key_type = map_pyspark_types_to_pyarrow_types(pyspark_type.keyType)
+        value_type = map_pyspark_types_to_pyarrow_types(pyspark_type.valueType)
+        return pa.map_(key_type, value_type)
+    elif isinstance(pyspark_type, pyspark.sql.types.StructType):
+        fields = [
+            pa.field(
+                f.name,
+                map_pyspark_types_to_pyarrow_types(f.dataType),
+                nullable=f.nullable,
+            )
+            for f in pyspark_type.fields
+        ]
+        return pa.struct(fields)
+    else:
+        return pa.string()  # Default fallback


 def map_simple_types(simple_type: str) -> snowpark.types.DataType:
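A hedged usage sketch of the new helper on a nested schema (assuming it is exposed from the type-mapping module; all names below are only illustrative):

```python
import pyarrow as pa
import pyspark.sql.types as T

schema = T.StructType([
    T.StructField("id", T.LongType(), nullable=False),
    T.StructField("tags", T.ArrayType(T.StringType()), nullable=True),
])
# map_pyspark_types_to_pyarrow_types(schema) should produce the equivalent of:
expected = pa.struct([
    pa.field("id", pa.int64(), nullable=False),
    pa.field("tags", pa.list_(pa.string()), nullable=True),
])
```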
@@ -785,18 +957,51 @@
             return snowpark.types.TimestampType(snowpark.types.TimestampTimeZone.NTZ)
         case "timestamp_ltz":
             return snowpark.types.TimestampType(snowpark.types.TimestampTimeZone.LTZ)
+        case "year_month_interval":
+            return snowpark.types.YearMonthIntervalType()
         case "day_time_interval":
-
-
-            return snowpark.types.
+            return snowpark.types.DayTimeIntervalType()
+        case type_name if _INTERVAL_YEARMONTH_PATTERN_RE.match(type_name):
+            return snowpark.types.YearMonthIntervalType()
+        case type_name if _INTERVAL_DAYTIME_PATTERN_RE.match(type_name):
+            return snowpark.types.DayTimeIntervalType()
+        # Year-Month interval cases
+        case "interval year":
+            return snowpark.types.YearMonthIntervalType(0)
+        case "interval month":
+            return snowpark.types.YearMonthIntervalType(1)
+        case "interval year to month":
+            return snowpark.types.YearMonthIntervalType(0, 1)
+        case "interval day":
+            return snowpark.types.DayTimeIntervalType(0)
+        case "interval hour":
+            return snowpark.types.DayTimeIntervalType(1)
+        case "interval minute":
+            return snowpark.types.DayTimeIntervalType(2)
+        case "interval second":
+            return snowpark.types.DayTimeIntervalType(3)
+        case "interval day to hour":
+            return snowpark.types.DayTimeIntervalType(0, 1)
+        case "interval day to minute":
+            return snowpark.types.DayTimeIntervalType(0, 2)
+        case "interval day to second":
+            return snowpark.types.DayTimeIntervalType(0, 3)
+        case "interval hour to minute":
+            return snowpark.types.DayTimeIntervalType(1, 2)
+        case "interval hour to second":
+            return snowpark.types.DayTimeIntervalType(1, 3)
+        case "interval minute to second":
+            return snowpark.types.DayTimeIntervalType(2, 3)
         case _:
             if simple_type.startswith("decimal"):
                 precision = int(simple_type.split("(")[1].split(",")[0])
                 scale = int(simple_type.split(",")[1].split(")")[0])
                 return snowpark.types.DecimalType(precision, scale)
-            raise SnowparkConnectNotImplementedError(
+            exception = SnowparkConnectNotImplementedError(
                 f"Unsupported simple type: {simple_type}"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception


 def map_json_schema_to_snowpark(
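The integer arguments in the added cases are Spark's interval field codes, which the PySpark type classes expose as constants; a quick runnable check against any PySpark 3.5 install:

```python
from pyspark.sql.types import DayTimeIntervalType, YearMonthIntervalType

# "interval day to second" above maps to DayTimeIntervalType(0, 3):
assert (DayTimeIntervalType.DAY, DayTimeIntervalType.HOUR,
        DayTimeIntervalType.MINUTE, DayTimeIntervalType.SECOND) == (0, 1, 2, 3)
# "interval year to month" maps to YearMonthIntervalType(0, 1):
assert (YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH) == (0, 1)
```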
@@ -937,9 +1142,11 @@
             lit_value, _ = get_literal_field_and_name(arguments.literal)
             return convert_spark_format_to_snowflake(lit_value, timestamp_input_type)
         case other:
-            raise SnowparkConnectNotImplementedError(
+            exception = SnowparkConnectNotImplementedError(
                 f"Unsupported expression type {other} in timestamp format argument"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception


 def map_spark_number_format_expression(
@@ -958,9 +1165,11 @@
         case "literal":
             lit_value, _ = get_literal_field_and_name(arguments.literal)
         case other:
-            raise SnowparkConnectNotImplementedError(
+            exception = SnowparkConnectNotImplementedError(
                 f"Unsupported expression type {other} in number format argument"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception

     return _map_spark_to_snowflake_number_format(lit_value)

@@ -988,3 +1197,116 @@
         return ret

     return NUMBER_FORMAT_RE.sub(_replace, spark_format)
+
+
+def map_type_to_snowflake_type(
+    t: Union[snowpark_type.DataType, types_proto.DataType]
+) -> str:
+    """Maps a Snowpark or Spark protobuf type to a Snowflake type string."""
+    if not t:
+        return "VARCHAR"
+    is_snowpark_type = isinstance(t, snowpark_type.DataType)
+    condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
+    match condition:
+        case snowpark_type.ArrayType | "array":
+            return (
+                f"ARRAY({map_type_to_snowflake_type(t.element_type)})"
+                if is_snowpark_type
+                else f"ARRAY({map_type_to_snowflake_type(t.array.element_type)})"
+            )
+        case snowpark_type.BinaryType | "binary":
+            return "BINARY"
+        case snowpark_type.BooleanType | "boolean":
+            return "BOOLEAN"
+        case snowpark_type.ByteType | "byte":
+            return "TINYINT"
+        case snowpark_type.DateType | "date":
+            return "DATE"
+        case snowpark_type.DecimalType | "decimal":
+            return "NUMBER"
+        case snowpark_type.DoubleType | "double":
+            return "DOUBLE"
+        case snowpark_type.FloatType | "float":
+            return "FLOAT"
+        case snowpark_type.GeographyType:
+            return "GEOGRAPHY"
+        case snowpark_type.IntegerType | "integer":
+            return "INT"
+        case snowpark_type.LongType | "long":
+            return "BIGINT"
+        case snowpark_type.MapType | "map":
+            # Maps to OBJECT in Snowflake if key and value types are not specified.
+            key_type = (
+                map_type_to_snowflake_type(t.key_type)
+                if is_snowpark_type
+                else map_type_to_snowflake_type(t.map.key_type)
+            )
+            value_type = (
+                map_type_to_snowflake_type(t.value_type)
+                if is_snowpark_type
+                else map_type_to_snowflake_type(t.map.value_type)
+            )
+            return (
+                f"MAP({key_type}, {value_type})"
+                if key_type and value_type
+                else "OBJECT"
+            )
+        case snowpark_type.NullType | "null":
+            return "VARCHAR"
+        case snowpark_type.ShortType | "short":
+            return "SMALLINT"
+        case snowpark_type.StringType | "string" | "char" | "varchar":
+            return "VARCHAR"
+        case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
+            return "TIMESTAMP"
+        case snowpark_type.StructType | "struct":
+            return "VARIANT"
+        case snowpark_type.VariantType:
+            return "VARIANT"
+        case _:
+            exception = ValueError(f"Unsupported Snowpark type: {t}")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
+            raise exception
+
+
+def merge_different_types(
+    type1: snowpark_type.DataType,
+    type2: snowpark_type.DataType,
+) -> snowpark_type.DataType:
+    """
+    Merge two different Snowpark data types.
+    """
+    # If one type is NullType, return the other
+    if isinstance(type1, snowpark_type.NullType) or type1 is None:
+        return type2
+    if isinstance(type2, snowpark_type.NullType) or type2 is None:
+        return type1
+
+    if type1 == type2:
+        return type2
+    # Define type hierarchy - from narrowest to widest scope
+    # Each set contains types that can be merged to a common type
+    numeric_type_hierarchy = [
+        # Numeric hierarchy: byte -> short -> int -> long -> decimal -> float -> double
+        snowpark_type.ByteType,
+        snowpark_type.ShortType,
+        snowpark_type.IntegerType,
+        snowpark_type.LongType,
+        snowpark_type.DecimalType,
+        snowpark_type.FloatType,
+        snowpark_type.DoubleType,
+    ]
+
+    type1_index = next(
+        (i for i, t in enumerate(numeric_type_hierarchy) if isinstance(type1, t)), -1
+    )
+    type2_index = next(
+        (i for i, t in enumerate(numeric_type_hierarchy) if isinstance(type2, t)), -1
+    )
+
+    if type1_index >= 0 and type2_index >= 0:
+        broader_index = max(type1_index, type2_index)
+        return numeric_type_hierarchy[broader_index]()
+
+    # No common type found, default to StringType
+    return snowpark_type.StringType()
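A minimal standalone sketch of the widening rule `merge_different_types` applies, reimplemented here so it runs without the package (only the hierarchy lookup is shown):

```python
from snowflake.snowpark import types as snowpark_type

hierarchy = [
    snowpark_type.ByteType, snowpark_type.ShortType, snowpark_type.IntegerType,
    snowpark_type.LongType, snowpark_type.DecimalType, snowpark_type.FloatType,
    snowpark_type.DoubleType,
]

def widen(t1, t2):
    # Position of each type in the hierarchy; -1 means non-numeric.
    i1 = next((i for i, t in enumerate(hierarchy) if isinstance(t1, t)), -1)
    i2 = next((i for i, t in enumerate(hierarchy) if isinstance(t2, t)), -1)
    if i1 >= 0 and i2 >= 0:
        return hierarchy[max(i1, i2)]()  # the wider numeric type wins
    return snowpark_type.StringType()    # no common numeric type

assert isinstance(widen(snowpark_type.IntegerType(), snowpark_type.DoubleType()),
                  snowpark_type.DoubleType)
assert isinstance(widen(snowpark_type.BooleanType(), snowpark_type.DateType()),
                  snowpark_type.StringType)
```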
@@ -0,0 +1,130 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+import threading
+
+from snowflake import snowpark
+from snowflake.snowpark.types import (
+    ArrayType,
+    ByteType,
+    DataType,
+    DecimalType,
+    IntegerType,
+    LongType,
+    MapType,
+    ShortType,
+    StructField,
+    StructType,
+    _IntegralType,
+)
+
+_integral_types_conversion_enabled: bool = False
+_client_mode_lock = threading.Lock()
+
+
+def set_integral_types_conversion(enabled: bool) -> None:
+    global _integral_types_conversion_enabled
+
+    with _client_mode_lock:
+        if _integral_types_conversion_enabled == enabled:
+            return
+
+        _integral_types_conversion_enabled = enabled
+
+        if enabled:
+            snowpark.context._integral_type_default_precision = {
+                LongType: 19,
+                IntegerType: 10,
+                ShortType: 5,
+                ByteType: 3,
+            }
+        else:
+            snowpark.context._integral_type_default_precision = {}
+
+
+def set_integral_types_for_client_default(is_python_client: bool) -> None:
+    """
+    Set integral types based on client type when config is 'client_default'.
+    """
+    from snowflake.snowpark_connect.config import global_config
+
+    config_key = "snowpark.connect.integralTypesEmulation"
+    if global_config.get(config_key) != "client_default":
+        return
+
+    # if client mode matches, no action needed (no lock overhead)
+    if _integral_types_conversion_enabled == (not is_python_client):
+        return
+
+    set_integral_types_conversion(not is_python_client)
+
+
+def emulate_integral_types(t: DataType) -> DataType:
+    """
+    Map LongType based on precision attribute to appropriate integral types.
+
+    Mappings:
+    - _IntegralType with precision=19 -> LongType
+    - _IntegralType with precision=10 -> IntegerType
+    - _IntegralType with precision=5 -> ShortType
+    - _IntegralType with precision=3 -> ByteType
+    - _IntegralType with other precision -> DecimalType(precision, 0)
+
+    This conversion is controlled by the 'snowpark.connect.integralTypesEmulation' config.
+    When disabled, the function returns the input type unchanged.
+
+    Args:
+        t: The DataType to transform
+
+    Returns:
+        The transformed DataType with integral type conversions applied based on precision.
+    """
+    global _integral_types_conversion_enabled
+
+    with _client_mode_lock:
+        enabled = _integral_types_conversion_enabled
+    if not enabled:
+        return t
+    if isinstance(t, _IntegralType):
+        precision = getattr(t, "_precision", None)
+
+        if precision is None:
+            return t
+        elif precision == 19:
+            return LongType()
+        elif precision == 10:
+            return IntegerType()
+        elif precision == 5:
+            return ShortType()
+        elif precision == 3:
+            return ByteType()
+        else:
+            return DecimalType(precision, 0)
+
+    elif isinstance(t, StructType):
+        new_fields = [
+            StructField(
+                field.name,
+                emulate_integral_types(field.datatype),
+                field.nullable,
+                _is_column=field._is_column,
+            )
+            for field in t.fields
+        ]
+        return StructType(new_fields)
+
+    elif isinstance(t, ArrayType):
+        return ArrayType(
+            emulate_integral_types(t.element_type),
+            t.contains_null,
+        )
+
+    elif isinstance(t, MapType):
+        return MapType(
+            emulate_integral_types(t.key_type),
+            emulate_integral_types(t.value_type),
+            t.value_contains_null,
+        )
+
+    return t
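A hedged illustration of the precision-driven narrowing in the new file: the `_precision` attribute is read with `getattr`, and is presumably set by the Snowpark client when it describes a Snowflake NUMBER(p, 0) column; the values below just exercise the mapping table from the docstring.

```python
from snowflake.snowpark.types import LongType

t = LongType()
t._precision = 10   # NUMBER(10, 0) column surfaced as LongType
# emulate_integral_types(t) -> IntegerType()
t._precision = 12   # precision outside the 3/5/10/19 table
# emulate_integral_types(t) -> DecimalType(12, 0)
```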