snowpark-connect 0.27.0__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +717 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +309 -26
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/error_utils.py +28 -0
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +224 -15
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +86 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2964 -829
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-connect-client-jvm_2.12-3.5.6.jar → spark-connect-client-jvm_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-sql_2.12-3.5.6.jar → spark-sql_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +239 -256
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +326 -47
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +21 -6
- snowflake/snowpark_connect/relation/read/map_read_json.py +324 -86
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +146 -28
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +15 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +86 -6
- snowflake/snowpark_connect/relation/read/map_read_text.py +22 -4
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +42 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +171 -48
- snowflake/snowpark_connect/server.py +528 -473
- snowflake/snowpark_connect/server_common/__init__.py +503 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/type_support.py +130 -0
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +195 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +151 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +321 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +281 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +185 -340
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +192 -40
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +14 -4
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +16 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +139 -168
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
@@ -2,16 +2,21 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #

+import math
 import re
+import typing
 from collections.abc import MutableMapping, MutableSequence
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
 from contextvars import ContextVar
+from decimal import Decimal
 from functools import reduce
+from typing import Tuple

 import jpype
 import pandas
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
+import pyspark.sql.connect.proto.types_pb2 as types_proto
 import sqlglot
 from google.protobuf.any_pb2 import Any
 from pyspark.errors.exceptions.base import (
@@ -24,20 +29,28 @@ import snowflake.snowpark.functions as snowpark_fn
 import snowflake.snowpark_connect.proto.snowflake_expression_ext_pb2 as snowflake_exp_proto
 import snowflake.snowpark_connect.proto.snowflake_relation_ext_pb2 as snowflake_proto
 from snowflake import snowpark
+from snowflake.snowpark import Session, types as snowpark_types
 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     quote_name_without_upper_casing,
     unquote_if_quoted,
 )
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.snowpark._internal.utils import is_sql_select_statement, quote_name
+from snowflake.snowpark.functions import when_matched, when_not_matched
 from snowflake.snowpark_connect.config import (
+    auto_uppercase_column_identifiers,
     auto_uppercase_non_column_identifiers,
+    check_table_supports_operation,
     get_boolean_session_config_param,
     global_config,
+    record_table_metadata,
     set_config_param,
+    should_create_temporary_view_in_snowflake,
     unset_config_param,
 )
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
 from snowflake.snowpark_connect.expression.map_expression import (
     ColumnNameMap,
     map_single_column_expression,
@@ -51,14 +64,28 @@ from snowflake.snowpark_connect.relation.map_relation import (
     NATURAL_JOIN_TYPE_BASE,
     map_relation,
 )
-
+
+# Import from utils for consistency
+from snowflake.snowpark_connect.relation.utils import is_aggregate_function
+from snowflake.snowpark_connect.snowflake_session import (
+    SQL_PASS_THROUGH_MARKER,
+    calculate_checksum,
+)
+from snowflake.snowpark_connect.type_mapping import (
+    map_snowpark_to_pyspark_types,
+    snowpark_to_proto_type,
+)
 from snowflake.snowpark_connect.utils.context import (
     _accessing_temp_object,
     gen_sql_plan_id,
-
+    get_is_processing_aliased_relation,
+    get_spark_session_id,
     get_sql_plan,
     push_evaluating_sql_scope,
+    push_processed_view,
+    push_processing_aliased_relation_scope,
     push_sql_scope,
+    set_plan_id_map,
     set_sql_args,
     set_sql_plan_name,
 )
@@ -68,6 +95,7 @@ from snowflake.snowpark_connect.utils.telemetry import (
     telemetry,
 )

+from .. import column_name_handler
 from ..expression.map_sql_expression import (
     _window_specs,
     as_java_list,
@@ -75,7 +103,18 @@ from ..expression.map_sql_expression import (
     map_logical_plan_expression,
     sql_parser,
 )
-from ..
+from ..typed_column import TypedColumn
+from ..utils.identifiers import (
+    spark_to_sf_single_id,
+    spark_to_sf_single_id_with_unquoting,
+)
+from ..utils.temporary_view_helper import (
+    create_snowflake_temporary_view,
+    get_temp_view,
+    store_temporary_view_as_dataframe,
+    unregister_temp_view,
+)
+from .catalogs import SNOWFLAKE_CATALOG

 _ctes = ContextVar[dict[str, relation_proto.Relation]]("_ctes", default={})
 _cte_definitions = ContextVar[dict[str, any]]("_cte_definitions", default={})
@@ -84,6 +123,65 @@ _having_condition = ContextVar[expressions_proto.Expression | None](
 )


+def _map_value_to_literal_proto(
+    value: typing.Any, typ: snowpark_types.DataType
+) -> expressions_proto.Expression.Literal:
+    if isinstance(typ, snowpark_types.NullType):
+        return expressions_proto.Expression.Literal(null=value)
+    if isinstance(typ, snowpark_types.BinaryType):
+        return expressions_proto.Expression.Literal(binary=value)
+    if isinstance(typ, snowpark_types.BooleanType):
+        return expressions_proto.Expression.Literal(boolean=value)
+    if isinstance(typ, snowpark_types.ByteType):
+        return expressions_proto.Expression.Literal(byte=value)
+    if isinstance(typ, snowpark_types.ShortType):
+        return expressions_proto.Expression.Literal(short=value)
+    if isinstance(typ, snowpark_types.IntegerType):
+        return expressions_proto.Expression.Literal(integer=value)
+    if isinstance(typ, snowpark_types.LongType):
+        return expressions_proto.Expression.Literal(long=value)
+    if isinstance(typ, snowpark_types.FloatType):
+        return expressions_proto.Expression.Literal(float=value)
+    if isinstance(typ, snowpark_types.DoubleType):
+        return expressions_proto.Expression.Literal(double=value)
+    if isinstance(typ, snowpark_types.DecimalType):
+        return expressions_proto.Expression.Literal(
+            decimal=expressions_proto.Expression.Literal.Decimal(
+                value=value,
+                precision=typ.precision,
+                scale=typ.scale,
+            )
+        )
+    if isinstance(typ, snowpark_types.ArrayType):
+        element_type_proto = types_proto.DataType(
+            **snowpark_to_proto_type(typ.element_type)
+        )
+
+        return expressions_proto.Expression.Literal(
+            array=expressions_proto.Expression.Literal.Array(
+                element_type=element_type_proto,
+                elements=[
+                    _map_value_to_literal_proto(el, typ.element_type) for el in value
+                ],
+            )
+        )
+
+    if isinstance(typ, snowpark_types.StructType):
+        struct_type_proto = types_proto.DataType(**snowpark_to_proto_type(typ))
+
+        return expressions_proto.Expression.Literal(
+            struct=expressions_proto.Expression.Literal.Struct(
+                struct_type=struct_type_proto,
+                elements=[
+                    _map_value_to_literal_proto(v, typ.fields[i].datatype)
+                    for i, v in enumerate(value.values())
+                ],
+            )
+        )
+
+    return expressions_proto.Expression.Literal(string=str(value))
+
+
 def _is_sql_select_statement_helper(sql_string: str) -> bool:
     """
     Determine if a SQL string is a SELECT or CTE query statement, even when it starts with comments or whitespace.
@@ -130,6 +228,48 @@ def _push_cte_scope():
     _cte_definitions.reset(def_token)


+def _process_cte_relations(cte_relations):
+    """
+    Process CTE relations and register them in the current CTE scope.
+
+    This function extracts CTE definitions from CTE relations,
+    maps them to protobuf representations, and stores them for later reference.
+
+    Args:
+        cte_relations: Java list of CTE relations (tuples of name and SubqueryAlias)
+    """
+    for cte in as_java_list(cte_relations):
+        name = str(cte._1())
+        # Store the original CTE definition for re-evaluation
+        _cte_definitions.get()[name] = cte._2()
+        # Process CTE definition with a unique plan_id to ensure proper column naming
+        # Clear HAVING condition before processing each CTE to prevent leakage between CTEs
+        saved_having = _having_condition.get()
+        _having_condition.set(None)
+        try:
+            cte_plan_id = gen_sql_plan_id()
+            cte_proto = map_logical_plan_relation(cte._2(), cte_plan_id)
+            _ctes.get()[name] = cte_proto
+        finally:
+            _having_condition.set(saved_having)
+
+
+@contextmanager
+def _with_cte_scope(cte_relations):
+    """
+    Context manager that creates a CTE scope and processes CTE relations.
+
+    This combines _push_cte_scope() and _process_cte_relations() to handle
+    the common pattern of processing CTEs within a new scope.
+
+    Args:
+        cte_relations: Java list of CTE relations (tuples of name and SubqueryAlias)
+    """
+    with (_push_cte_scope()):
+        _process_cte_relations(cte_relations)
+        yield
+
+
 @contextmanager
 def _push_window_specs_scope():
     """
@@ -203,6 +343,9 @@ def _rename_columns(
 def _create_table_as_select(logical_plan, mode: str) -> None:
     # TODO: for as select create tables we'd map multi layer identifier here
     name = get_relation_identifier_name(logical_plan.name())
+    full_table_identifier = get_relation_identifier_name(
+        logical_plan.name(), is_multi_part=True
+    )
     comment = logical_plan.tableSpec().comment()

     container = execute_logical_plan(logical_plan.query())
@@ -223,9 +366,158 @@ def _create_table_as_select(logical_plan, mode: str) -> None:
         mode=mode,
     )

+    # Record table metadata for CREATE TABLE AS SELECT
+    # These are typically considered v2 tables and support RENAME COLUMN
+    record_table_metadata(
+        table_identifier=full_table_identifier,
+        table_type="v2",
+        data_source="default",
+        supports_column_rename=True,
+    )
+
+
+def _insert_into_table(logical_plan, session: Session) -> None:
+    df_container = execute_logical_plan(logical_plan.query())
+    df = df_container.dataframe
+    queries = df.queries["queries"]
+    if len(queries) != 1:
+        exception = SnowparkConnectNotImplementedError(
+            f"Unexpected number of queries: {len(queries)}"
+        )
+        attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+        raise exception
+
+    name = get_relation_identifier_name(logical_plan.table(), True)
+
+    user_columns = [
+        spark_to_sf_single_id(str(col), is_column=True)
+        for col in as_java_list(logical_plan.userSpecifiedCols())
+    ]
+    overwrite_str = "OVERWRITE" if logical_plan.overwrite() else ""
+    cols_str = "(" + ", ".join(user_columns) + ")" if user_columns else ""
+
+    # Extract partition spec if any
+    partition_spec = logical_plan.partitionSpec()
+    partition_map = as_java_map(partition_spec)
+
+    partition_columns = {}
+    for entry in partition_map.entrySet():
+        col_name = str(entry.getKey())
+        value_option = entry.getValue()
+        if value_option.isDefined():
+            partition_columns[col_name] = value_option.get()
+
+    target_table = session.table(name)
+    target_schema = target_table.schema
+
+    # Add partition columns to the dataframe
+    if partition_columns:
+        """
+        Spark sends them in the partition spec and the values won't be present in the values array.
+        As snowflake does not support static partitions in INSERT INTO statements,
+        we need to add the partition columns to the dataframe as literal columns.
+
+        ex: INSERT INTO TABLE test_table PARTITION (ds='2021-01-01', hr=10) VALUES ('k1', 100), ('k2', 200), ('k3', 300)
+
+        Spark sends: VALUES ('k1', 100), ('k2', 200), ('k3', 300) with partition spec (ds='2021-01-01', hr=10)
+        Snowflake expects: VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
+
+        We need to add the partition columns to the dataframe as literal columns.
+
+        ex: df = df.withColumn('ds', snowpark_fn.lit('2021-01-01'))
+            df = df.withColumn('hr', snowpark_fn.lit(10))
+
+        Then the final query will be:
+        INSERT INTO TABLE test_table VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
+        """
+        for partition_col, partition_value in partition_columns.items():
+
+            def _comparable_col_name(col: str) -> str:
+                name = col.upper() if auto_uppercase_column_identifiers() else col
+                return unquote_if_quoted(name)
+
+            comparable_target_schema = [
+                _comparable_col_name(col.name) for col in target_schema.fields
+            ]
+
+            if _comparable_col_name(partition_col) not in comparable_target_schema:
+                exception = AnalysisException(
+                    f"{partition_col} is not a valid partition column in table {name}."
+                )
+                attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+                raise exception
+            df = df.withColumn(partition_col, snowpark_fn.lit(partition_value))
+
+    expected_number_of_columns = (
+        len(user_columns) if user_columns else len(target_schema.fields)
+    )
+    if expected_number_of_columns != len(df.schema.fields):
+        reason = (
+            "too many data columns"
+            if len(df.schema.fields) > expected_number_of_columns
+            else "not enough data columns"
+        )
+        exception = AnalysisException(
+            f'[INSERT_COLUMN_ARITY_MISMATCH.{reason.replace(" ", "_").upper()}] Cannot write to {name}, the reason is {reason}:\n'
+            f'Table columns: {", ".join(target_schema.names)}.\n'
+            f'Data columns: {", ".join(df.schema.names)}.'
+        )
+        attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+        raise exception
+
+    try:
+        # Modify df with type conversions and struct field name mapping
+        modified_columns = []
+        for source_field, target_field in zip(df.schema.fields, target_schema.fields):
+            col_name = source_field.name
+
+            # Handle different type conversions
+            if isinstance(
+                target_field.datatype, snowpark.types.DecimalType
+            ) and isinstance(
+                source_field.datatype,
+                (snowpark.types.FloatType, snowpark.types.DoubleType),
+            ):
+                # Add CASE WHEN to convert NaN to NULL for DECIMAL targets
+                # Only apply this to floating-point source columns
+                modified_col = (
+                    snowpark_fn.when(
+                        snowpark_fn.equal_nan(snowpark_fn.col(col_name)),
+                        snowpark_fn.lit(None),
+                    )
+                    .otherwise(snowpark_fn.col(col_name))
+                    .alias(col_name)
+                )
+                modified_columns.append(modified_col)
+            elif (
+                isinstance(target_field.datatype, snowpark.types.StructType)
+                and source_field.datatype != target_field.datatype
+            ):
+                # Cast struct with field name mapping (e.g., col1,col2 -> i1,i2)
+                # This fixes INSERT INTO table with struct literals like (2, 3)
+                modified_col = (
+                    snowpark_fn.col(col_name)
+                    .cast(target_field.datatype, rename_fields=True)
+                    .alias(col_name)
+                )
+                modified_columns.append(modified_col)
+            else:
+                modified_columns.append(snowpark_fn.col(col_name))
+
+        df = df.select(modified_columns)
+    except Exception:
+        pass
+
+    queries = df.queries["queries"]
+    final_query = queries[0]
+    session.sql(
+        f"INSERT {overwrite_str} INTO {name} {cols_str} {final_query}",
+    ).collect()
+

 def _spark_field_to_sql(field: jpype.JObject, is_column: bool) -> str:
-    # Column names will be uppercased according to "snowpark.connect.sql.identifiers.auto-uppercase"
+    # Column names will be uppercased according to "snowpark.connect.sql.identifiers.auto-uppercase"
+    # if present, or to "spark.sql.caseSensitive".
     # and struct fields will be left as is. This should allow users to use the same names
     # in spark and Snowflake in most cases.
     if is_column:
@@ -300,6 +592,69 @@ def _remove_column_data_type(node):
     return node


+def _get_condition_from_action(action, column_mapping, typer):
+    condition = None
+    if action.condition().isDefined():
+        (_, condition_typed_col,) = map_single_column_expression(
+            map_logical_plan_expression(action.condition().get()),
+            column_mapping,
+            typer,
+        )
+        condition = condition_typed_col.col
+    return condition
+
+
+def _get_assignments_from_action(
+    action,
+    column_mapping_source,
+    column_mapping_target,
+    typer_source,
+    typer_target,
+):
+    assignments = dict()
+    if (
+        action.getClass().getSimpleName() == "InsertAction"
+        or action.getClass().getSimpleName() == "UpdateAction"
+    ):
+        incoming_assignments = as_java_list(action.assignments())
+        for assignment in incoming_assignments:
+            (_, key_typ_col) = map_single_column_expression(
+                map_logical_plan_expression(assignment.key()),
+                column_mapping=column_mapping_target,
+                typer=typer_target,
+            )
+            key_name = typer_target.df.select(key_typ_col.col).columns[0]
+
+            (_, val_typ_col) = map_single_column_expression(
+                map_logical_plan_expression(assignment.value()),
+                column_mapping=column_mapping_source,
+                typer=typer_source,
+            )
+
+            assignments[key_name] = val_typ_col.col
+    elif (
+        action.getClass().getSimpleName() == "InsertStarAction"
+        or action.getClass().getSimpleName() == "UpdateStarAction"
+    ):
+        if len(column_mapping_source.columns) != len(column_mapping_target.columns):
+            exception = ValueError(
+                "source and target must have the same number of columns for InsertStarAction or UpdateStarAction"
+            )
+            attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+            raise exception
+        for i, col in enumerate(column_mapping_target.columns):
+            if assignments.get(col.snowpark_name) is not None:
+                exception = SnowparkConnectNotImplementedError(
+                    "UpdateStarAction or InsertStarAction is not supported with duplicate columns."
+                )
+                attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+                raise exception
+            assignments[col.snowpark_name] = snowpark_fn.col(
+                column_mapping_source.columns[i].snowpark_name
+            )
+    return assignments
+
+
 def map_sql_to_pandas_df(
     sql_string: str,
     named_args: MutableMapping[str, expressions_proto.Expression.Literal],
@@ -311,7 +666,7 @@ def map_sql_to_pandas_df(
     returns a tuple of None for SELECT queries to enable lazy evaluation
     """

-    snowpark_connect_sql_passthrough =
+    snowpark_connect_sql_passthrough, sql_string = is_valid_passthrough_sql(sql_string)

     if not snowpark_connect_sql_passthrough:
         logical_plan = sql_parser().parsePlan(sql_string)
@@ -327,6 +682,7 @@ def map_sql_to_pandas_df(
         ) == "UnresolvedHint":
             logical_plan = logical_plan.child()

+        # TODO: Add support for temporary views for SQL cases such as ShowViews, ShowColumns ect. (Currently the cases are not compatible with Spark, returning raw Snowflake rows)
         match class_name:
             case "AddColumns":
                 # Handle ALTER TABLE ... ADD COLUMNS (col_name data_type) -> ADD COLUMN col_name data_type
@@ -397,9 +753,11 @@ def map_sql_to_pandas_df(
                     snowflake_sql = f"ALTER TABLE {table_name} ALTER COLUMN {column_name} {alter_clause}"
                     session.sql(snowflake_sql).collect()
                 else:
-
+                    exception = ValueError(
                         f"No alter operations found in AlterColumn logical plan for table {table_name}, column {column_name}"
                     )
+                    attach_custom_error_code(exception, ErrorCodes.INVALID_SQL_SYNTAX)
+                    raise exception
             case "CreateNamespace":
                 name = get_relation_identifier_name(logical_plan.name(), True)
                 previous_name = session.connection.schema
@@ -421,6 +779,9 @@ def map_sql_to_pandas_df(
                 )

                 name = get_relation_identifier_name(logical_plan.name())
+                full_table_identifier = get_relation_identifier_name(
+                    logical_plan.name(), is_multi_part=True
+                )
                 columns = ", ".join(
                     _spark_field_to_sql(f, True)
                     for f in logical_plan.tableSchema().fields()
@@ -431,10 +792,48 @@ def map_sql_to_pandas_df(
                     if comment_opt.isDefined()
                     else ""
                 )
+
+                # Extract data source for metadata tracking
+                data_source = "default"
+
+                with suppress(Exception):
+                    # Get data source from tableSpec.provider() (for USING clause)
+                    if hasattr(logical_plan, "tableSpec"):
+                        table_spec = logical_plan.tableSpec()
+                        if hasattr(table_spec, "provider"):
+                            provider_opt = table_spec.provider()
+                            if provider_opt.isDefined():
+                                data_source = str(provider_opt.get()).lower()
+                            else:
+                                # Fall back to checking properties for FORMAT
+                                table_properties = table_spec.properties()
+                                if not table_properties.isEmpty():
+                                    for prop in table_properties.get():
+                                        if str(prop.key()) == "FORMAT":
+                                            data_source = str(prop.value()).lower()
+                                            break
+
                 # NOTE: We are intentionally ignoring any FORMAT=... parameters here.
                 session.sql(
                     f"CREATE {replace_table} TABLE {if_not_exists}{name} ({columns}) {comment}"
                 ).collect()
+
+                # Record table metadata for Spark compatibility
+                # Tables created with explicit schema are considered v1 tables
+                # v1 tables with certain data sources don't support RENAME COLUMN in OSS Spark
+                supports_rename = data_source not in (
+                    "parquet",
+                    "csv",
+                    "json",
+                    "orc",
+                    "avro",
+                )
+                record_table_metadata(
+                    table_identifier=full_table_identifier,
+                    table_type="v1",
+                    data_source=data_source,
+                    supports_column_rename=supports_rename,
+                )
             case "CreateTableAsSelect":
                 mode = "ignore" if logical_plan.ignoreIfExists() else "errorifexists"
                 _create_table_as_select(logical_plan, mode=mode)
@@ -446,20 +845,62 @@ def map_sql_to_pandas_df(
                     f"CREATE TABLE {if_not_exists}{name} LIKE {source}"
                 ).collect()
             case "CreateTempViewUsing":
+                parsed_sql = sqlglot.parse_one(sql_string, dialect="spark")
+
+                spark_view_name = next(parsed_sql.find_all(sqlglot.exp.Table)).name
+
+                # extract ONLY top-level column definitions (not nested struct fields)
+                column_defs = []
+                schema_node = next(parsed_sql.find_all(sqlglot.exp.Schema), None)
+                if schema_node:
+                    for expr in schema_node.expressions:
+                        if isinstance(expr, sqlglot.exp.ColumnDef):
+                            column_defs.append(expr)
+
+                num_columns = len(column_defs)
+                if num_columns > 0:
+                    null_list_parts = []
+                    for col_def in column_defs:
+                        col_name = spark_to_sf_single_id(col_def.name, is_column=True)
+                        col_type = col_def.kind
+                        if col_type:
+                            null_list_parts.append(
+                                f"CAST(NULL AS {col_type.sql(dialect='snowflake')}) AS {col_name}"
+                            )
+                        else:
+                            null_list_parts.append(f"NULL AS {col_name}")
+                    null_list = ", ".join(null_list_parts)
+                else:
+                    null_list = "*"
+
                 empty_select = (
-                    " AS SELECT
+                    f" AS SELECT {null_list} WHERE 1 = 0"
                     if logical_plan.options().isEmpty()
                     and logical_plan.children().isEmpty()
                     else ""
                 )
-
-
-                    .transform(_normalize_identifiers)
+
+                transformed_sql = (
+                    parsed_sql.transform(_normalize_identifiers)
                     .transform(_remove_column_data_type)
                     .transform(_remove_file_format_property)
                 )
-                snowflake_sql =
+                snowflake_sql = transformed_sql.sql(dialect="snowflake")
                 session.sql(f"{snowflake_sql}{empty_select}").collect()
+                snowflake_view_name = spark_to_sf_single_id_with_unquoting(
+                    spark_view_name
+                )
+                temp_view = get_temp_view(snowflake_view_name)
+                if temp_view is not None and not logical_plan.replace():
+                    exception = AnalysisException(
+                        f"[TEMP_TABLE_OR_VIEW_ALREADY_EXISTS] Cannot create the temporary view `{spark_view_name}` because it already exists."
+                    )
+                    attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+                    raise exception
+                else:
+                    unregister_temp_view(
+                        spark_to_sf_single_id_with_unquoting(spark_view_name)
+                    )
             case "CreateView":
                 current_schema = session.connection.schema
                 if (
@@ -475,11 +916,13 @@ def map_sql_to_pandas_df(
                 df_container = execute_logical_plan(logical_plan.query())
                 df = df_container.dataframe
                 if _accessing_temp_object.get():
-
+                    exception = AnalysisException(
                         f"[INVALID_TEMP_OBJ_REFERENCE] Cannot create the persistent object `{CURRENT_CATALOG_NAME}`.`{current_schema}`.`{object_name}` "
                         "of the type VIEW because it references to a temporary object of the type VIEW. Please "
                         f"make the temporary object persistent, or make the persistent object `{CURRENT_CATALOG_NAME}`.`{current_schema}`.`{object_name}` temporary."
                     )
+                    attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+                    raise exception

                 name = get_relation_identifier_name(logical_plan.child())
                 comment = logical_plan.comment()
@@ -496,58 +939,143 @@ def map_sql_to_pandas_df(
                     else None,
                 )
             case "CreateViewCommand":
-
-
-
-
-
-
-
-
+                with push_processed_view(logical_plan.name().identifier()):
+                    df_container = execute_logical_plan(logical_plan.plan())
+                    df = df_container.dataframe
+                    user_specified_spark_column_names = [
+                        str(col._1())
+                        for col in as_java_list(logical_plan.userSpecifiedColumns())
+                    ]
+                    df_container = DataFrameContainer.create_with_column_mapping(
+                        dataframe=df,
+                        spark_column_names=user_specified_spark_column_names
+                        if user_specified_spark_column_names
+                        else df_container.column_map.get_spark_columns(),
+                        snowpark_column_names=df_container.column_map.get_snowpark_columns(),
+                        parent_column_name_map=df_container.column_map,
                     )
-                )

-
-
-
-
-
-
-                ),
-            ):
-                name = f"{global_config.spark_sql_globalTempDatabase}.{name}"
-                comment = logical_plan.comment()
-                maybe_comment = (
-                    _escape_sql_comment(str(comment.get()))
-                    if comment.isDefined()
-                    else None
-                )
+                    is_global = isinstance(
+                        logical_plan.viewType(),
+                        jpype.JClass(
+                            "org.apache.spark.sql.catalyst.analysis.GlobalTempView$"
+                        ),
+                    )

-
-
-
+                    def get_cached_view_name() -> str:
+                        if is_global:
+                            view_name = [
+                                global_config.spark_sql_globalTempDatabase,
+                                logical_plan.name().quotedString(),
+                            ]
+                        else:
+                            view_name = [logical_plan.name().quotedString()]
+                        view_name = [
+                            spark_to_sf_single_id_with_unquoting(part)
+                            for part in view_name
+                        ]
+                        return ".".join(view_name)
+
+                    def get_snowflake_view_name() -> list[str]:
+                        snowpark_view_name = str(logical_plan.name().identifier())
+                        snowpark_view_name = spark_to_sf_single_id(snowpark_view_name)
+                        return (
+                            [
+                                global_config.spark_sql_globalTempDatabase,
+                                snowpark_view_name,
+                            ]
+                            if is_global
+                            else [snowpark_view_name]
+                        )

-
-
-
-
-
-
-
-
-
+                    snowflake_view_name = get_snowflake_view_name()
+                    cached_view_name = get_cached_view_name()
+
+                    tmp_views = _get_current_temp_objects()
+                    tmp_views.add(
+                        (
+                            CURRENT_CATALOG_NAME,
+                            session.connection.schema,
+                            str(logical_plan.name().identifier()),
+                        )
                     )
+
+                    def _create_snowflake_temporary_view():
+                        comment = logical_plan.comment()
+                        maybe_comment = (
+                            _escape_sql_comment(str(comment.get()))
+                            if comment.isDefined()
+                            else None
+                        )
+
+                        renamed_df = _rename_columns(
+                            df,
+                            logical_plan.userSpecifiedColumns(),
+                            df_container.column_map,
+                        )
+
+                        create_snowflake_temporary_view(
+                            renamed_df,
+                            snowflake_view_name,
+                            cached_view_name,
+                            logical_plan.replace(),
+                            maybe_comment,
+                        )
+
+                    if should_create_temporary_view_in_snowflake():
+                        _create_snowflake_temporary_view()
+                    else:
+                        user_specified_spark_column_names = [
+                            str(col._1())
+                            for col in as_java_list(logical_plan.userSpecifiedColumns())
+                        ]
+                        spark_column_names = (
+                            user_specified_spark_column_names
+                            if user_specified_spark_column_names
+                            else df_container.column_map.get_spark_columns()
+                        )
+                        store_temporary_view_as_dataframe(
+                            df,
+                            df_container.column_map,
+                            spark_column_names,
+                            df_container.column_map.get_snowpark_columns(),
+                            cached_view_name,
+                            snowflake_view_name,
+                            logical_plan.replace(),
+                        )
             case "DescribeColumn":
-                name =
+                name = get_relation_identifier_name_without_uppercasing(
+                    logical_plan.column()
+                )
+                stored_temp_view = get_temp_view(name)
+                if stored_temp_view:
+                    return (
+                        SNOWFLAKE_CATALOG._list_columns_from_dataframe_container(
+                            stored_temp_view
+                        ),
+                        "",
+                    )
                 # todo double check if this is correct
+                name = get_relation_identifier_name(logical_plan.column())
                 rows = session.sql(f"DESCRIBE TABLE {name}").collect()
             case "DescribeNamespace":
                 name = get_relation_identifier_name(logical_plan.namespace(), True)
-                name = change_default_to_public(name)
                 rows = session.sql(f"DESCRIBE SCHEMA {name}").collect()
                 if not rows:
                     rows = None
             case "DescribeRelation":
+                name = get_relation_identifier_name_without_uppercasing(
+                    logical_plan.relation(), True
+                )
+                stored_temp_view = get_temp_view(name)
+                if stored_temp_view:
+                    return (
+                        SNOWFLAKE_CATALOG._list_columns_from_dataframe_container(
+                            stored_temp_view
+                        ),
+                        "",
+                    )
+
                 name = get_relation_identifier_name(logical_plan.relation(), True)
                 rows = session.sql(f"DESCRIBE TABLE {name}").collect()
                 if not rows:
@@ -598,9 +1126,11 @@ def map_sql_to_pandas_df(
                     del session._udtfs[func_name]
                 else:
                     if not logical_plan.ifExists():
-
+                        exception = ValueError(
                             f"Function {func_name} not found among registered UDFs or UDTFs."
                         )
+                        attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+                        raise exception
                 if snowpark_name != "":
                     argument_string = f"({', '.join(convert_sp_to_sf_type(arg) for arg in input_types)})"
                     session.sql(
@@ -615,9 +1145,13 @@ def map_sql_to_pandas_df(
                 if_exists = "IF EXISTS " if logical_plan.ifExists() else ""
                 session.sql(f"DROP TABLE {if_exists}{name}").collect()
             case "DropView":
-
-
-
+                temporary_view_name = get_relation_identifier_name_without_uppercasing(
+                    logical_plan.child()
+                )
+                if not unregister_temp_view(temporary_view_name):
+                    name = get_relation_identifier_name(logical_plan.child())
+                    if_exists = "IF EXISTS " if logical_plan.ifExists() else ""
+                    session.sql(f"DROP VIEW {if_exists}{name}").collect()
             case "ExplainCommand":
                 inner_plan = logical_plan.logicalPlan()
                 logical_plan_name = inner_plan.nodeName()
@@ -669,84 +1203,189 @@ def map_sql_to_pandas_df(
                     rows = session.sql(final_sql).collect()
                 else:
                     # TODO: Support other logical plans
-
+                    exception = SnowparkConnectNotImplementedError(
                         f"{logical_plan_name} is not supported yet with EXPLAIN."
                     )
+                    attach_custom_error_code(
+                        exception, ErrorCodes.UNSUPPORTED_OPERATION
+                    )
+                    raise exception
             case "InsertIntoStatement":
-
-
-
-
-
-
+                _insert_into_table(logical_plan, session)
+            case "MergeIntoTable":
+                source_df_container = map_relation(
+                    map_logical_plan_relation(logical_plan.sourceTable())
+                )
+                source_df = source_df_container.dataframe
+                plan_id = gen_sql_plan_id()
+                target_df_container = map_relation(
+                    map_logical_plan_relation(logical_plan.targetTable(), plan_id)
+                )
+                target_df = target_df_container.dataframe
+
+                if (
+                    logical_plan.targetTable().getClass().getSimpleName()
+                    == "UnresolvedRelation"
+                ):
+                    target_table_name = _spark_to_snowflake(
+                        logical_plan.targetTable().multipartIdentifier()
+                    )
+                else:
+                    target_table_name = _spark_to_snowflake(
+                        logical_plan.targetTable().child().multipartIdentifier()
                    )

-
+                target_table = session.table(target_table_name)
+                target_table_columns = target_table.columns
+                target_df_spark_names = []
+                for target_table_col, target_df_col in zip(
+                    target_table_columns, target_df_container.column_map.columns
+                ):
+                    target_df = target_df.with_column_renamed(
+                        target_df_col.snowpark_name,
+                        target_table_col,
+                    )
+                    target_df_spark_names.append(target_df_col.spark_name)
+                target_df_container = DataFrameContainer.create_with_column_mapping(
+                    dataframe=target_df,
+                    spark_column_names=target_df_spark_names,
+                    snowpark_column_names=target_table_columns,
+                )

-
-                    spark_to_sf_single_id(str(col), is_column=True)
-                    for col in as_java_list(logical_plan.userSpecifiedCols())
-                ]
-                overwrite_str = "OVERWRITE" if logical_plan.overwrite() else ""
-                cols_str = "(" + ", ".join(user_columns) + ")" if user_columns else ""
+                set_plan_id_map(plan_id, target_df_container)

-
-
-
+                joined_df_before_condition: snowpark.DataFrame = source_df.join(
+                    target_df
+                )

-
-
-
-
-
-                        col_name = source_field.name
-                        if isinstance(
-                            target_field.datatype, snowpark.types.DecimalType
-                        ) and isinstance(
-                            source_field.datatype,
-                            (snowpark.types.FloatType, snowpark.types.DoubleType),
-                        ):
-                            # Add CASE WHEN to convert NaN to NULL for DECIMAL targets
-                            # Only apply this to floating-point source columns
-                            modified_col = (
-                                snowpark_fn.when(
-                                    snowpark_fn.equal_nan(snowpark_fn.col(col_name)),
-                                    snowpark_fn.lit(None),
-                                )
-                                .otherwise(snowpark_fn.col(col_name))
-                                .alias(col_name)
-                            )
-                            modified_columns.append(modified_col)
-                        else:
-                            modified_columns.append(snowpark_fn.col(col_name))
+                column_mapping_for_conditions = column_name_handler.JoinColumnNameMap(
+                    source_df_container.column_map,
+                    target_df_container.column_map,
+                )
+                typer_for_expressions = ExpressionTyper(joined_df_before_condition)

-
-
-
-
-                final_query = queries[0]
-                session.sql(
-                    f"INSERT {overwrite_str} INTO {name} {cols_str} {final_query}",
-                ).collect()
-            case "MergeIntoTable":
-                raise UnsupportedOperationException(
-                    "[UNSUPPORTED_SQL_EXTENSION] The MERGE INTO command failed.\n"
-                    + "Reason: This command is a platform-specific SQL extension and is not part of the standard Apache Spark specification that this interface uses."
+                (_, merge_condition_typed_col,) = map_single_column_expression(
+                    map_logical_plan_expression(logical_plan.mergeCondition()),
+                    column_mapping=column_mapping_for_conditions,
+                    typer=typer_for_expressions,
                )
+
+                clauses = []
+
+                for matched_action in as_java_list(logical_plan.matchedActions()):
+                    condition = _get_condition_from_action(
+                        matched_action,
+                        column_mapping_for_conditions,
+                        typer_for_expressions,
+                    )
+                    if matched_action.getClass().getSimpleName() == "DeleteAction":
+                        clauses.append(when_matched(condition).delete())
+                    elif (
+                        matched_action.getClass().getSimpleName() == "UpdateAction"
+                        or matched_action.getClass().getSimpleName()
+                        == "UpdateStarAction"
+                    ):
+                        assignments = _get_assignments_from_action(
+                            matched_action,
+                            source_df_container.column_map,
+                            target_df_container.column_map,
+                            ExpressionTyper(source_df),
+                            ExpressionTyper(target_df),
+                        )
+                        clauses.append(when_matched(condition).update(assignments))
+
+                for not_matched_action in as_java_list(
+                    logical_plan.notMatchedActions()
+                ):
+                    condition = _get_condition_from_action(
+                        not_matched_action,
+                        column_mapping_for_conditions,
+                        typer_for_expressions,
+                    )
+                    if (
+                        not_matched_action.getClass().getSimpleName() == "InsertAction"
+                        or not_matched_action.getClass().getSimpleName()
+                        == "InsertStarAction"
+                    ):
+                        assignments = _get_assignments_from_action(
+                            not_matched_action,
+                            source_df_container.column_map,
+                            target_df_container.column_map,
+                            ExpressionTyper(source_df),
+                            ExpressionTyper(target_df),
+                        )
+                        clauses.append(when_not_matched(condition).insert(assignments))
+
+                if not as_java_list(logical_plan.notMatchedBySourceActions()).isEmpty():
+                    exception = SnowparkConnectNotImplementedError(
+                        "Snowflake does not support 'not matched by source' actions in MERGE statements."
+                    )
+                    attach_custom_error_code(
+                        exception, ErrorCodes.UNSUPPORTED_OPERATION
+                    )
+                    raise exception
+
+                target_table.merge(source_df, merge_condition_typed_col.col, clauses)
             case "DeleteFromTable":
-
-
-                    + "Reason: This command is a platform-specific SQL extension and is not part of the standard Apache Spark specification that this interface uses."
+                df_container = map_relation(
+                    map_logical_plan_relation(logical_plan.table())
                )
+                name = get_relation_identifier_name(logical_plan.table(), True)
+                table = session.table(name)
+                table_columns = table.columns
+                df = df_container.dataframe
+                spark_names = []
+                for table_col, df_col in zip(
+                    table_columns, df_container.column_map.columns
+                ):
+                    df = df.with_column_renamed(
+                        df_col.snowpark_name,
+                        table_col,
+                    )
+                    spark_names.append(df_col.spark_name)
+                df_container = DataFrameContainer.create_with_column_mapping(
+                    dataframe=df,
+                    spark_column_names=spark_names,
+                    snowpark_column_names=table_columns,
+                )
+                df = df_container.dataframe
+                (
+                    condition_column_name,
+                    condition_typed_col,
+                ) = map_single_column_expression(
+                    map_logical_plan_expression(logical_plan.condition()),
+                    df_container.column_map,
+                    ExpressionTyper(df),
+                )
+                table.delete(condition_typed_col.col)
             case "UpdateTable":
                 # Databricks/Delta-specific extension not supported by SAS.
                 # Provide an actionable, clear error.
-
+                exception = UnsupportedOperationException(
                     "[UNSUPPORTED_SQL_EXTENSION] The UPDATE TABLE command failed.\n"
                     + "Reason: This command is a platform-specific SQL extension and is not part of the standard Apache Spark specification that this interface uses."
                 )
+                attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+                raise exception
             case "RenameColumn":
-
+                full_table_identifier = get_relation_identifier_name(
+                    logical_plan.table(), True
+                )
+
+                # Check Spark compatibility for RENAME COLUMN operation
+                if not check_table_supports_operation(
+                    full_table_identifier, "rename_column"
+                ):
+                    exception = AnalysisException(
+                        f"ALTER TABLE RENAME COLUMN is not supported for table '{full_table_identifier}'. "
+                        f"This table was created as a v1 table with a data source that doesn't support column renaming. "
|
|
1382
|
+
f"To enable this operation, set 'snowpark.connect.enable_snowflake_extension_behavior' to 'true'."
|
|
1383
|
+
)
|
|
1384
|
+
attach_custom_error_code(
|
|
1385
|
+
exception, ErrorCodes.UNSUPPORTED_OPERATION
|
|
1386
|
+
)
|
|
1387
|
+
raise exception
|
|
1388
|
+
|
|
750
1389
|
column_obj = logical_plan.column()
|
|
751
1390
|
old_column_name = ".".join(
|
|
752
1391
|
spark_to_sf_single_id(str(part), is_column=True)
|
|
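The hunk above replaces the old hard error for Spark's MERGE INTO with a real translation: matched and not-matched actions from the parsed plan are turned into Snowpark merge clauses and executed through Table.merge, while "not matched by source" actions still raise because Snowflake MERGE has no equivalent clause. A minimal sketch of the Snowpark API shape this mapping targets (the session, table and column names here are illustrative, not taken from the diff):

    from snowflake.snowpark.functions import when_matched, when_not_matched

    target = session.table("TARGET_TBL")   # hypothetical tables
    source = session.table("SOURCE_TBL")
    target.merge(
        source,
        target["ID"] == source["ID"],  # merge condition
        [
            when_matched().update({"VAL": source["VAL"]}),       # UpdateAction / UpdateStarAction
            when_matched(source["VAL"].is_null()).delete(),      # DeleteAction with a condition
            when_not_matched().insert({"ID": source["ID"], "VAL": source["VAL"]}),  # InsertAction
        ],
    )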
@@ -756,7 +1395,7 @@ def map_sql_to_pandas_df(
 case_insensitive_name = next(
 (
 f.name
-for f in session.table(
+for f in session.table(full_table_identifier).schema.fields
 if f.name.lower() == old_column_name.lower()
 ),
 None,
@@ -768,7 +1407,7 @@ def map_sql_to_pandas_df(
 )

 # Pass through to Snowflake
-snowflake_sql = f"ALTER TABLE {
+snowflake_sql = f"ALTER TABLE {full_table_identifier} RENAME COLUMN {old_column_name} TO {new_column_name}"
 session.sql(snowflake_sql).collect()
 case "RenameTable":
 name = get_relation_identifier_name(logical_plan.child(), True)
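Per the new RenameColumn branch, ALTER TABLE ... RENAME COLUMN is gated on a per-table compatibility check and otherwise passed through to Snowflake with the fully qualified identifier. A hedged sketch of how a client would opt in, assuming the config key quoted in the error message above is settable like any other session config (table and column names are hypothetical):

    spark.conf.set("snowpark.connect.enable_snowflake_extension_behavior", "true")
    spark.sql("ALTER TABLE my_table RENAME COLUMN old_col TO new_col")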
@@ -786,30 +1425,31 @@ def map_sql_to_pandas_df(
 f"ALTER ICEBERG TABLE {name} RENAME TO {new_name}"
 ).collect()
 else:
+attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
 raise e
 case "ReplaceTableAsSelect":
 _create_table_as_select(logical_plan, mode="overwrite")
 case "ResetCommand":
 key = logical_plan.config().get()
-unset_config_param(
+unset_config_param(get_spark_session_id(), key, session)
 case "SetCatalogAndNamespace":
 # TODO: add catalog setting here
 name = get_relation_identifier_name(logical_plan.child(), True)
-name = change_default_to_public(name)
 session.sql(f"USE SCHEMA {name}").collect()
 case "SetCommand":
 kv_result_tuple = logical_plan.kv().get()
 key = kv_result_tuple._1()
 val = kv_result_tuple._2().get()
-set_config_param(
+set_config_param(get_spark_session_id(), key, val, session)
 case "SetNamespaceCommand":
 name = _spark_to_snowflake(logical_plan.namespace())
-name = change_default_to_public(name)
 session.sql(f"USE SCHEMA {name}").collect()
 case "SetNamespaceLocation" | "SetNamespaceProperties":
-
+exception = SnowparkConnectNotImplementedError(
 "Altering databases is not currently supported."
 )
+attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+raise exception
 case "ShowCreateTable":
 # Handle SHOW CREATE TABLE command
 # Spark: SHOW CREATE TABLE table_name
@@ -831,16 +1471,24 @@ def map_sql_to_pandas_df(
 case "ShowNamespaces":
 name = get_relation_identifier_name(logical_plan.namespace(), True)
 if name:
-
+exception = SnowparkConnectNotImplementedError(
 "'IN' clause is not supported while listing databases"
 )
+attach_custom_error_code(
+exception, ErrorCodes.UNSUPPORTED_OPERATION
+)
+raise exception
 if logical_plan.pattern().isDefined():
 # Snowflake SQL requires a "%" pattern.
 # Snowpark catalog requires a regex and does client-side filtering.
 # Spark, however, uses a regex-like pattern that treats '*' and '|' differently.
-
+exception = SnowparkConnectNotImplementedError(
 "'LIKE' clause is not supported while listing databases"
 )
+attach_custom_error_code(
+exception, ErrorCodes.UNSUPPORTED_OPERATION
+)
+raise exception
 rows = session.sql("SHOW SCHEMAS").collect()
 if not rows:
 rows = None
@@ -913,6 +1561,18 @@ def map_sql_to_pandas_df(
 if pattern and rows:
 rows = _filter_tables_by_pattern(rows, pattern)
 case "ShowColumns":
+name = get_relation_identifier_name_without_uppercasing(
+logical_plan.child(), True
+)
+stored_temp_view = get_temp_view(name)
+if stored_temp_view:
+return (
+SNOWFLAKE_CATALOG._list_columns_from_dataframe_container(
+stored_temp_view
+),
+"",
+)
+
 # Handle Spark SQL: SHOW COLUMNS IN table_name FROM database_name
 # Convert to Snowflake SQL: SHOW COLUMNS IN TABLE database_name.table_name

@@ -941,9 +1601,13 @@ def map_sql_to_pandas_df(
 spark_to_sf_single_id(str(db_and_table_name[0])).casefold()
 != db_name.casefold()
 ):
-
+exception = AnalysisException(
 f"database name is not matching:{db_name} and {db_and_table_name[0]}"
 )
+attach_custom_error_code(
+exception, ErrorCodes.INVALID_OPERATION
+)
+raise exception

 # Just table name
 snowflake_cmd = f"SHOW COLUMNS IN TABLE {table_name}"
@@ -981,6 +1645,51 @@ def map_sql_to_pandas_df(
 return pandas.DataFrame({"": [""]}), ""

 rows = session.sql(snowflake_sql).collect()
+case "RefreshTable":
+table_name_unquoted = ".".join(
+str(part)
+for part in as_java_list(logical_plan.child().multipartIdentifier())
+)
+SNOWFLAKE_CATALOG.refreshTable(table_name_unquoted)
+
+return pandas.DataFrame({"": [""]}), ""
+case "RepairTable":
+# No-Op: Snowflake doesn't have explicit partitions to repair.
+table_relation = logical_plan.child()
+db_and_table_name = as_java_list(table_relation.multipartIdentifier())
+multi_part_len = len(db_and_table_name)
+
+if multi_part_len == 1:
+table_name = db_and_table_name[0]
+db_name = None
+full_table_name = table_name
+else:
+db_name = db_and_table_name[0]
+table_name = db_and_table_name[1]
+full_table_name = db_name + "." + table_name
+
+df = SNOWFLAKE_CATALOG.tableExists(table_name, db_name)
+
+table_exist = df.iloc[0, 0]
+
+if not table_exist:
+exception = AnalysisException(
+f"[TABLE_OR_VIEW_NOT_FOUND] Table not found `{full_table_name}`."
+)
+attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+raise exception
+
+return pandas.DataFrame({"": [""]}), ""
+case "UnresolvedWith":
+child = logical_plan.child()
+child_class = str(child.getClass().getSimpleName())
+match child_class:
+case "InsertIntoStatement":
+with _with_cte_scope(logical_plan.cteRelations()):
+_insert_into_table(child, get_or_create_snowpark_session())
+case _:
+execute_logical_plan(logical_plan)
+return None, None
 case _:
 execute_logical_plan(logical_plan)
 return None, None
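Among the commands added above, the UnresolvedWith branch means an INSERT wrapped in a WITH clause is now executed by registering the CTEs in scope and delegating to _insert_into_table. Roughly the Spark SQL shape this enables, with illustrative table names and an assumed active Spark Connect session:

    spark.sql("""
        WITH recent AS (SELECT * FROM orders WHERE order_date >= '2024-01-01')
        INSERT INTO orders_archive SELECT * FROM recent
    """)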
@@ -1001,6 +1710,27 @@ def get_sql_passthrough() -> bool:
 return get_boolean_session_config_param("snowpark.connect.sql.passthrough")


+def is_valid_passthrough_sql(sql_stmt: str) -> Tuple[bool, str]:
+"""
+Checks if :param sql_stmt: should be executed as SQL pass-through. SQL pass-through can be detected in 1 of 2 ways:
+1) Either Spark config parameter "snowpark.connect.sql.passthrough" is set (legacy mode, to be deprecated)
+2) If :param sql_stmt: is created through SnowflakeSession and has correct marker + checksum
+"""
+if get_sql_passthrough():
+# legacy style pass-through, sql_stmt should be a whole, valid SQL statement
+return True, sql_stmt
+
+# check for new style, SnowflakeSession based SQL pass-through
+sql_parts = sql_stmt.split(" ", 2)
+if len(sql_parts) == 3:
+marker, checksum, sql = sql_parts
+if marker == SQL_PASS_THROUGH_MARKER and checksum == calculate_checksum(sql):
+return True, sql
+
+# Not a SQL pass-through
+return False, sql_stmt
+
+
 def change_default_to_public(name: str) -> str:
 """
 Change the namespace to PUBLIC when given name is DEFAULT
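is_valid_passthrough_sql recognizes two forms: the legacy config-driven mode, and a new per-statement mode where the text arrives as "<marker> <checksum> <sql>". A sketch of the round trip, reusing the package's own SQL_PASS_THROUGH_MARKER and calculate_checksum; how the client composes the string is an assumption here, only the split-and-verify side is shown in the diff:

    def wrap_for_passthrough(sql: str) -> str:
        # hypothetical client-side helper: prepend marker and checksum, space-separated
        return f"{SQL_PASS_THROUGH_MARKER} {calculate_checksum(sql)} {sql}"

    ok, stmt = is_valid_passthrough_sql(wrap_for_passthrough("SELECT 1"))
    assert ok and stmt == "SELECT 1"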
@@ -1015,6 +1745,76 @@ def change_default_to_public(name: str) -> str:
 return name


+def _preprocess_identifier_calls(sql_query: str) -> str:
+"""
+Pre-process SQL query to resolve IDENTIFIER() calls before Spark parsing.
+
+Transforms: IDENTIFIER('abs')(c2) -> abs(c2)
+Transforms: IDENTIFIER('COAL' || 'ESCE')(NULL, 1) -> COALESCE(NULL, 1)
+
+This preserves all function arguments in their original positions, eliminating
+the need to reconstruct them at the expression level.
+"""
+import re
+
+# Pattern to match IDENTIFIER(...) followed by optional function call arguments
+# This captures both the identifier expression and any trailing arguments
+# Note: We need to be careful about whitespace preservation
+identifier_pattern = r"IDENTIFIER\s*\(\s*([^)]+)\s*\)(\s*)(\([^)]*\))?"
+
+def resolve_identifier_match(match):
+identifier_expr_str = match.group(1).strip()
+whitespace = match.group(2) if match.group(2) else ""
+function_args = match.group(3) if match.group(3) else ""
+
+try:
+# Handle string concatenation FIRST: IDENTIFIER('COAL' || 'ESCE')
+# (Must check this before simple strings since it also starts/ends with quotes)
+if "||" in identifier_expr_str:
+# Parse basic string concatenation with proper quote handling
+parts = []
+split_parts = identifier_expr_str.split("||")
+for part in split_parts:
+part = part.strip()
+if part.startswith("'") and part.endswith("'"):
+unquoted = part[1:-1]  # Remove quotes from each part
+parts.append(unquoted)
+else:
+# Non-string parts - return original for safety
+return match.group(0)
+resolved_name = "".join(parts)  # Concatenate the unquoted parts
+
+# Handle simple string literals: IDENTIFIER('abs')
+elif identifier_expr_str.startswith("'") and identifier_expr_str.endswith(
+"'"
+):
+resolved_name = identifier_expr_str[1:-1]  # Remove quotes
+
+else:
+# Complex expressions not supported yet - return original
+return match.group(0)
+
+# Return resolved function call with preserved arguments and whitespace
+if function_args:
+# Function call case: IDENTIFIER('abs')(c1) -> abs(c1)
+result = f"{resolved_name}{function_args}"
+else:
+# Column reference case: IDENTIFIER('c1') FROM -> c1 FROM (preserve whitespace)
+result = f"{resolved_name}{whitespace}"
+return result
+
+except Exception:
+# Return original to avoid breaking the query
+return match.group(0)
+
+# Apply the transformation
+processed_query = re.sub(
+identifier_pattern, resolve_identifier_match, sql_query, flags=re.IGNORECASE
+)
+
+return processed_query
+
+
 def map_sql(
 rel: relation_proto.Relation,
 ) -> DataFrameContainer:
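The examples in the docstring translate directly into behaviour: only literal strings and literal-string concatenations are rewritten, and anything else is left untouched for later stages. Calls illustrating the three paths:

    _preprocess_identifier_calls("SELECT IDENTIFIER('abs')(c2) FROM t")
    # -> "SELECT abs(c2) FROM t"
    _preprocess_identifier_calls("SELECT IDENTIFIER('COAL' || 'ESCE')(NULL, 1)")
    # -> "SELECT COALESCE(NULL, 1)"
    _preprocess_identifier_calls("SELECT IDENTIFIER(:tbl) FROM t")
    # -> unchanged: a non-literal argument falls through to the original text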
@@ -1026,10 +1826,15 @@ def map_sql(
 In passthough mode as True, SAS calls session.sql() and not calling Spark Parser.
 This is to mitigate any issue not covered by spark logical plan to protobuf conversion.
 """
-snowpark_connect_sql_passthrough =
+snowpark_connect_sql_passthrough, sql_stmt = is_valid_passthrough_sql(rel.sql.query)

 if not snowpark_connect_sql_passthrough:
-
+# Changed from parseQuery to parsePlan as Spark parseQuery() call generating wrong logical plan for
+# query like this: SELECT cast('3.4' as decimal(38, 18)) UNION SELECT 'foo'
+# As such other place in this file we use parsePlan.
+# Main difference between parsePlan() and parseQuery() is, parsePlan() can be called for any SQL statement, while
+# parseQuery() can only be called for query statements.
+logical_plan = sql_parser().parsePlan(sql_stmt)

 parsed_pos_args = parse_pos_args(logical_plan, rel.sql.pos_args)
 set_sql_args(rel.sql.args, parsed_pos_args)
@@ -1037,7 +1842,7 @@ def map_sql(
 return execute_logical_plan(logical_plan)
 else:
 session = snowpark.Session.get_active_session()
-sql_df = session.sql(
+sql_df = session.sql(sql_stmt)
 columns = sql_df.columns
 return DataFrameContainer.create_with_column_mapping(
 dataframe=sql_df,
@@ -1112,7 +1917,19 @@ def map_logical_plan_relation(
 attr_parts = as_java_list(expr.nameParts())
 if len(attr_parts) == 1:
 attr_name = str(attr_parts[0])
-
+if attr_name in alias_map:
+# Check if the alias references an aggregate function
+# If so, don't substitute because you can't GROUP BY an aggregate
+aliased_expr = alias_map[attr_name]
+aliased_expr_class = str(
+aliased_expr.getClass().getSimpleName()
+)
+if aliased_expr_class == "UnresolvedFunction":
+func_name = str(aliased_expr.nameParts().head())
+if is_aggregate_function(func_name):
+return expr
+return aliased_expr
+return expr

 return expr

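This hook lets a GROUP BY key refer back to a SELECT alias, while refusing the substitution when the alias names an aggregate, since you cannot group by an aggregate. Illustrative Spark SQL, assuming a table t(cat, amount):

    spark.sql("SELECT upper(cat) AS category, count(*) FROM t GROUP BY category")
    # 'category' is substituted with upper(cat) before mapping
    spark.sql("SELECT sum(amount) AS amount FROM t GROUP BY amount")
    # 'amount' is NOT substituted: the alias points at an aggregate, so the
    # underlying column named amount is used instead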
@@ -1129,9 +1946,13 @@ def map_logical_plan_relation(
 group_type = snowflake_proto.Aggregate.GROUP_TYPE_CUBE
 case "GroupingSets":
 if not exp.userGivenGroupByExprs().isEmpty():
-
+exception = SnowparkConnectNotImplementedError(
 "User-defined group by expressions are not supported"
 )
+attach_custom_error_code(
+exception, ErrorCodes.UNSUPPORTED_OPERATION
+)
+raise exception
 group_type = (
 snowflake_proto.Aggregate.GROUP_TYPE_GROUPING_SETS
 )
@@ -1147,9 +1968,13 @@ def map_logical_plan_relation(

 if group_type != snowflake_proto.Aggregate.GROUP_TYPE_GROUPBY:
 if len(group_expression_list) != 1:
-
+exception = SnowparkConnectNotImplementedError(
 "Multiple grouping expressions are not supported"
 )
+attach_custom_error_code(
+exception, ErrorCodes.UNSUPPORTED_OPERATION
+)
+raise exception
 if group_type == snowflake_proto.Aggregate.GROUP_TYPE_GROUPING_SETS:
 group_expression_list = []  # TODO: exp.userGivenGroupByExprs()?
 else:
@@ -1281,38 +2106,89 @@ def map_logical_plan_relation(
 case "Pivot":
 pivot_column = map_logical_plan_expression(rel.pivotColumn())
 session = snowpark.Session.get_active_session()
-m = ColumnNameMap([], []
+m = ColumnNameMap([], [])

-
-
-
+pivot_columns = (
+[
+col
+for col in pivot_column.unresolved_function.arguments
+if col.HasField("unresolved_attribute")
+]
+if pivot_column.HasField("unresolved_function")
+else [pivot_column]
+)

-
+typer = ExpressionTyper.dummy_typer(session)

-
-
-
+expression_protos: list[expressions_proto.Expression] = []
+expressions: list[TypedColumn] = []
+aliases: list[str] = []
+
+for pivot_value in as_java_list(rel.pivotValues()):
+expr_proto = map_logical_plan_expression(pivot_value)
+alias, expr = map_single_column_expression(expr_proto, m, typer)
+
+expression_protos.append(expr_proto)
+expressions.append(expr)
+aliases.append(alias)
+
+resolved_pivot_values_row = (
+session.range(1)
+.select(*[expr.col for expr in expressions])
+.collect()[0]
+)
+resolved_pivot_values = [value for value in resolved_pivot_values_row]
+
+pivot_values = []
+for expr_proto, expr, alias, value in zip(
+expression_protos, expressions, aliases, resolved_pivot_values
+):
+literals_proto = (
+[
+_map_value_to_literal_proto(v, expr.typ.fields[i].datatype)
+for i, v in enumerate(value)
+]
+if isinstance(expr.typ, snowpark.types.StructType)
+else [_map_value_to_literal_proto(value, expr.typ)]
 )
-
-
-
+
+if len(pivot_columns) != len(literals_proto):
+raise AnalysisException(
+f"[PIVOT_VALUE_DATA_TYPE_MISMATCH] Number of pivot columns ({len(pivot_columns)}) does not match number of values ({len(literals_proto)})"
+)
+
+current_pivot_value_proto = (
+snowflake_proto.Aggregate.Pivot.PivotValue(
+values=literals_proto, alias=alias
+)
+if expr_proto.HasField("alias")
+else snowflake_proto.Aggregate.Pivot.PivotValue(
+values=literals_proto
+)
 )

+pivot_values.append(current_pivot_value_proto)
+
 aggregate_expressions = [
 map_logical_plan_expression(e) for e in as_java_list(rel.aggregates())
 ]

-
-
-
-
-
-
-
-
+any_proto = Any()
+any_proto.Pack(
+snowflake_proto.Extension(
+aggregate=snowflake_proto.Aggregate(
+input=map_logical_plan_relation(rel.child()),
+group_type=relation_proto.Aggregate.GroupType.GROUP_TYPE_PIVOT,
+aggregate_expressions=aggregate_expressions,
+having_condition=_having_condition.get(),
+pivot=snowflake_proto.Aggregate.Pivot(
+pivot_columns=pivot_columns,
+pivot_values=pivot_values,
+),
+)
 )
 )
-
+proto = relation_proto.Relation(extension=any_proto)
 case "PlanWithUnresolvedIdentifier":
 expr_proto = map_logical_plan_expression(rel.identifierExpr())
 session = snowpark.Session.get_active_session()
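The rewritten Pivot branch evaluates the pivot values once on a dummy one-row frame and then forwards the results as literal protos, so expressions in the IN list are resolved server-side before the pivot itself runs. A sketch of that resolution step in isolation (the values 2020 and 2021 are illustrative):

    from snowflake.snowpark.functions import lit

    row = session.range(1).select(lit(2020), lit(2021)).collect()[0]
    resolved_pivot_values = list(row)   # [2020, 2021], later wrapped into PivotValue protos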
@@ -1343,23 +2219,119 @@ def map_logical_plan_relation(
 )
 )
 case "Sort":
+# Process the input first
+input_proto = map_logical_plan_relation(rel.child())
+
+# Check if child is a Project - if so, build an alias map for ORDER BY resolution
+# This handles: SELECT o.date AS order_date ... ORDER BY o.date
+child_class = str(rel.child().getClass().getSimpleName())
+alias_map = {}
+
+if child_class == "Project":
+# Extract aliases from SELECT clause
+for proj_expr in list(as_java_list(rel.child().projectList())):
+if str(proj_expr.getClass().getSimpleName()) == "Alias":
+alias_name = str(proj_expr.name())
+child_expr = proj_expr.child()
+
+# Store mapping from original expression to alias name
+# Use string representation for matching
+expr_str = str(child_expr)
+alias_map[expr_str] = alias_name
+
+# Also handle UnresolvedAttribute specifically to get the qualified name
+if (
+str(child_expr.getClass().getSimpleName())
+== "UnresolvedAttribute"
+):
+# Get the qualified name like "o.date"
+name_parts = list(as_java_list(child_expr.nameParts()))
+qualified_name = ".".join(str(part) for part in name_parts)
+if qualified_name not in alias_map:
+alias_map[qualified_name] = alias_name
+
+# Process ORDER BY expressions, substituting aliases where needed
+order_list = []
+for order_expr in as_java_list(rel.order()):
+# Get the child expression from the SortOrder
+child_expr = order_expr.child()
+expr_class = str(child_expr.getClass().getSimpleName())
+
+# Check if this expression matches any aliased expression
+expr_str = str(child_expr)
+substituted = False
+
+if expr_str in alias_map:
+# Found a match - substitute with alias reference
+alias_name = alias_map[expr_str]
+# Create new UnresolvedAttribute for the alias
+UnresolvedAttribute = jpype.JClass(
+"org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute"
+)
+new_attr = UnresolvedAttribute.quoted(alias_name)
+
+# Create new SortOrder with substituted expression
+SortOrder = jpype.JClass(
+"org.apache.spark.sql.catalyst.expressions.SortOrder"
+)
+new_order = SortOrder(
+new_attr,
+order_expr.direction(),
+order_expr.nullOrdering(),
+order_expr.sameOrderExpressions(),
+)
+order_list.append(map_logical_plan_expression(new_order).sort_order)
+substituted = True
+elif expr_class == "UnresolvedAttribute":
+# Try matching on qualified name
+name_parts = list(as_java_list(child_expr.nameParts()))
+qualified_name = ".".join(str(part) for part in name_parts)
+if qualified_name in alias_map:
+alias_name = alias_map[qualified_name]
+UnresolvedAttribute = jpype.JClass(
+"org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute"
+)
+new_attr = UnresolvedAttribute.quoted(alias_name)
+
+SortOrder = jpype.JClass(
+"org.apache.spark.sql.catalyst.expressions.SortOrder"
+)
+new_order = SortOrder(
+new_attr,
+order_expr.direction(),
+order_expr.nullOrdering(),
+order_expr.sameOrderExpressions(),
+)
+order_list.append(
+map_logical_plan_expression(new_order).sort_order
+)
+substituted = True
+
+if not substituted:
+# No substitution needed - use original
+order_list.append(
+map_logical_plan_expression(order_expr).sort_order
+)
+
 proto = relation_proto.Relation(
 sort=relation_proto.Sort(
-input=
-order=
-map_logical_plan_expression(e).sort_order
-for e in as_java_list(rel.order())
-],
+input=input_proto,
+order=order_list,
 )
 )
 case "SubqueryAlias":
 alias = str(rel.alias())
-
-
-
-alias=alias,
-)
+# If the child is an UnresolvedRelation, we want to preserve the original plan id and save only aliased one
+process_aliased_relation = (
+str(rel.child().getClass().getSimpleName()) == "UnresolvedRelation"
 )
+with push_processing_aliased_relation_scope(process_aliased_relation):
+proto = relation_proto.Relation(
+subquery_alias=relation_proto.SubqueryAlias(
+input=map_logical_plan_relation(rel.child()),
+alias=alias,
+)
+)
 set_sql_plan_name(alias, plan_id)
 case "Union":
 children = as_java_list(rel.children())
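The new Sort handling mirrors what its inline comments describe: when the child is a Project, ORDER BY expressions that textually match an aliased projection are rewritten to reference the alias, so qualified names keep working after projection. The query from the comment, with table and column names assumed for illustration:

    spark.sql(
        "SELECT o.date AS order_date, o.total FROM orders o ORDER BY o.date"
    )
    # handled as if it were written ... ORDER BY order_date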
@@ -1381,12 +2353,14 @@ def map_logical_plan_relation(

 # Check for multi-column UNPIVOT which Snowflake doesn't support
 if len(value_column_names) > 1:
-
+exception = UnsupportedOperationException(
 f"Multi-column UNPIVOT is not supported. Snowflake SQL does not support unpivoting "
 f"multiple value columns ({', '.join(value_column_names)}) in a single operation. "
 f"Workaround: Use separate UNPIVOT operations for each value column and join the results, "
 f"or restructure your query to unpivot columns individually."
 )
+attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+raise exception

 values = []
 values_groups = as_java_list(rel.values().get())
@@ -1394,11 +2368,13 @@ def map_logical_plan_relation(
 # Check if we have multi-column groups in the IN clause
 if values_groups and len(as_java_list(values_groups[0])) > 1:
 group_sizes = [len(as_java_list(group)) for group in values_groups]
-
+exception = UnsupportedOperationException(
 f"Multi-column UNPIVOT is not supported. Snowflake SQL does not support unpivoting "
 f"multiple columns together in groups. Found groups with {max(group_sizes)} columns. "
 f"Workaround: Unpivot each column separately and then join/union the results as needed."
 )
+attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+raise exception

 for e1 in values_groups:
 for e in as_java_list(e1):
@@ -1444,9 +2420,11 @@ def map_logical_plan_relation(
 # Store the having condition in context and process the child aggregate
 child_relation = rel.child()
 if str(child_relation.getClass().getSimpleName()) != "Aggregate":
-
+exception = SnowparkConnectNotImplementedError(
 "UnresolvedHaving can only be applied to Aggregate relations"
 )
+attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+raise exception

 # Store having condition in a context variable for the Aggregate case to pick up
 having_condition = map_logical_plan_expression(rel.havingCondition())
@@ -1509,7 +2487,8 @@ def map_logical_plan_relation(
 )
 case "UnresolvedRelation":
 name = str(rel.name())
-
+if not get_is_processing_aliased_relation():
+set_sql_plan_name(name, plan_id)

 cte_proto = _ctes.get().get(name)
 if cte_proto is not None:
@@ -1530,10 +2509,16 @@ def map_logical_plan_relation(
 )

 # Re-evaluate the CTE definition with a fresh plan_id
-
-
-
-
+# Clear HAVING condition to prevent leakage from outer CTEs
+saved_having = _having_condition.get()
+_having_condition.set(None)
+try:
+fresh_plan_id = gen_sql_plan_id()
+fresh_cte_proto = map_logical_plan_relation(
+cte_definition, fresh_plan_id
+)
+finally:
+_having_condition.set(saved_having)

 # Use SubqueryColumnAliases to ensure consistent column names across CTE references
 # This is crucial for CTEs that reference other CTEs
@@ -1612,14 +2597,35 @@ def map_logical_plan_relation(
 .collect()[0]
 )

+def _parse_value(argument, place):
+if isinstance(argument, (Decimal, float)):
+return int(argument)
+elif isinstance(argument, str):
+try:
+value = float(argument)
+if value < 0:
+return math.ceil(value)
+return math.floor(float(argument))
+except ValueError:
+raise AnalysisException(
+f'[UNEXPECTED_INPUT_TYPE] Parameter {place} of function `range` requires the "BIGINT" type, however "{argument}" has the type "STRING"'
+)
+return argument
+
 start, step = 0, 1
 match args:
 case [_]:
 [end] = args
+end = _parse_value(end, 1)
 case [_, _]:
 [start, end] = args
+start = _parse_value(start, 1)
+end = _parse_value(end, 2)
 case [_, _, _]:
 [start, end, step] = args
+start = _parse_value(start, 1)
+end = _parse_value(end, 2)
+step = _parse_value(step, 3)

 proto = relation_proto.Relation(
 range=relation_proto.Range(
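For SQL range(), the inner _parse_value helper added above coerces numeric-looking arguments to integers and rejects non-numeric strings. Its behaviour, summarised as standalone calls for illustration (in the diff the helper is local to this branch):

    _parse_value(Decimal("5"), 1)   # -> 5
    _parse_value("7.9", 2)          # -> 7, floored
    _parse_value("-7.9", 2)         # -> -7, negative strings are rounded toward zero
    _parse_value("abc", 1)          # raises AnalysisException: [UNEXPECTED_INPUT_TYPE] ...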
@@ -1688,16 +2694,7 @@ def map_logical_plan_relation(
 ),
 )
 case "UnresolvedWith":
-with
-for cte in as_java_list(rel.cteRelations()):
-name = str(cte._1())
-# Store the original CTE definition for re-evaluation
-_cte_definitions.get()[name] = cte._2()
-# Process CTE definition with a unique plan_id to ensure proper column naming
-cte_plan_id = gen_sql_plan_id()
-cte_proto = map_logical_plan_relation(cte._2(), cte_plan_id)
-_ctes.get()[name] = cte_proto
-
+with _with_cte_scope(rel.cteRelations()):
 proto = map_logical_plan_relation(rel.child())
 case "LateralJoin":
 left = map_logical_plan_relation(rel.left())
@@ -1719,41 +2716,16 @@ def map_logical_plan_relation(
 _window_specs.get()[key] = window_spec
 proto = map_logical_plan_relation(rel.child())
 case "Generate":
-# Generate creates a nested Project relation (see lines 1785-1790) without
-# setting its plan_id field. When this Project is later processed by map_project
-# (map_column_ops.py), it uses rel.common.plan_id which defaults to 0 for unset
-# protobuf fields. This means all columns from the Generate operation (both exploded
-# columns and passthrough columns) will have plan_id=0 in their names.
-#
-# If Generate's child is a SubqueryAlias whose inner relation was processed
-# with a non-zero plan_id, there will be a mismatch between:
-# - The columns referenced in the Project (expecting plan_id from SubqueryAlias's child)
-# - The actual column names created by Generate's Project (using plan_id=0)
-
-# Therefore, when Generate has a SubqueryAlias child, we explicitly process the inner
-# relation with plan_id=0 to match what Generate's Project will use. This only applies when
-# the immediate child of Generate is a SubqueryAlias and preserves existing registrations (like CTEs),
-# so it won't affect other patterns.
-
 child_class = str(rel.child().getClass().getSimpleName())

 if child_class == "SubqueryAlias":
 alias = str(rel.child().alias())

-# Check if this alias was already registered during initial SQL parsing
 existing_plan_id = get_sql_plan(alias)

-if existing_plan_id is not None:
-# Use the existing plan_id to maintain consistency with prior registration
-used_plan_id = existing_plan_id
-else:
-# Use plan_id=0 to match what the nested Project will use (protobuf default)
-used_plan_id = 0
-set_sql_plan_name(alias, used_plan_id)
-
 # Process the inner child with the determined plan_id
 inner_child = map_logical_plan_relation(
-rel.child().child(), plan_id=
+rel.child().child(), plan_id=existing_plan_id
 )
 input_relation = relation_proto.Relation(
 subquery_alias=relation_proto.SubqueryAlias(
@@ -1771,19 +2743,19 @@ def map_logical_plan_relation(
 function_name = rel.generator().name().toString()
 func_arguments = [
 map_logical_plan_expression(e)
-for e in as_java_list(rel.generator().children())
+for e in list(as_java_list(rel.generator().children()))
 ]
 unresolved_fun_proto = expressions_proto.Expression.UnresolvedFunction(
 function_name=function_name, arguments=func_arguments
 )

-aliased_proto =
+aliased_proto = expressions_proto.Expression(
+unresolved_function=unresolved_fun_proto,
+)
 if generator_output_list.size() > 0:
 aliased_proto = expressions_proto.Expression(
 alias=expressions_proto.Expression.Alias(
-expr=
-unresolved_function=unresolved_fun_proto,
-),
+expr=aliased_proto,
 name=[attribute.name() for attribute in generator_output_list],
 )
 )
@@ -1837,28 +2809,67 @@ def map_logical_plan_relation(
 )
 proto = generator_dataframe_proto
 case other:
-
+exception = SnowparkConnectNotImplementedError(
+f"Unimplemented relation: {other}"
+)
+attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+raise exception

 proto.common.plan_id = plan_id

 return proto


+def _get_relation_identifier(name_obj) -> str:
+# IDENTIFIER(<table_name>), or IDENTIFIER(<method name>)
+expr_proto = map_logical_plan_expression(name_obj.identifierExpr())
+session = snowpark.Session.get_active_session()
+m = ColumnNameMap([], [], None)
+expr = map_single_column_expression(
+expr_proto, m, ExpressionTyper.dummy_typer(session)
+)
+return spark_to_sf_single_id(session.range(1).select(expr[1].col).collect()[0][0])
+
+
+def _create_temp_view_name(parts) -> str:
+return ".".join(
+quote_name_without_upper_casing(str(part)) for part in as_java_list(parts)
+)
+
+
+def get_relation_identifier_name_without_uppercasing(
+name_obj, is_multi_part: bool = False
+) -> str:
+if name_obj.getClass().getSimpleName() in (
+"PlanWithUnresolvedIdentifier",
+"ExpressionWithUnresolvedIdentifier",
+):
+return _get_relation_identifier(name_obj)
+elif is_multi_part:
+try:
+# Try multipartIdentifier first for full catalog.database.table
+return _create_temp_view_name(name_obj.multipartIdentifier())
+except AttributeError:
+# Fallback to nameParts if multipartIdentifier not available
+return _create_temp_view_name(name_obj.nameParts())
+else:
+return _create_temp_view_name(name_obj.nameParts())
+
+
 def get_relation_identifier_name(name_obj, is_multi_part: bool = False) -> str:
-if name_obj.getClass().getSimpleName()
-
-
-
-
-expr = map_single_column_expression(
-expr_proto, m, ExpressionTyper.dummy_typer(session)
-)
-name = spark_to_sf_single_id(
-session.range(1).select(expr[1].col).collect()[0][0]
-)
+if name_obj.getClass().getSimpleName() in (
+"PlanWithUnresolvedIdentifier",
+"ExpressionWithUnresolvedIdentifier",
+):
+return _get_relation_identifier(name_obj)
 else:
 if is_multi_part:
-
+try:
+# Try multipartIdentifier first for full catalog.database.table
+name = _spark_to_snowflake(name_obj.multipartIdentifier())
+except AttributeError:
+# Fallback to nameParts if multipartIdentifier not available
+name = _spark_to_snowflake(name_obj.nameParts())
 else:
 name = _spark_to_snowflake(name_obj.nameParts())
