snowpark-connect 0.31.0__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/column_name_handler.py +143 -105
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/dataframe_container.py +3 -2
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +4 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
- snowflake/snowpark_connect/expression/map_expression.py +5 -4
- snowflake/snowpark_connect/expression/map_extension.py +12 -6
- snowflake/snowpark_connect/expression/map_sql_expression.py +50 -7
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +62 -25
- snowflake/snowpark_connect/expression/map_unresolved_function.py +924 -127
- snowflake/snowpark_connect/expression/map_unresolved_star.py +9 -7
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +4 -1
- snowflake/snowpark_connect/relation/map_aggregate.py +6 -5
- snowflake/snowpark_connect/relation/map_column_ops.py +9 -3
- snowflake/snowpark_connect/relation/map_extension.py +10 -9
- snowflake/snowpark_connect/relation/map_join.py +219 -144
- snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
- snowflake/snowpark_connect/relation/map_sql.py +134 -16
- snowflake/snowpark_connect/relation/map_subquery_alias.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_json.py +87 -2
- snowflake/snowpark_connect/relation/read/map_read_table.py +6 -3
- snowflake/snowpark_connect/relation/utils.py +46 -0
- snowflake/snowpark_connect/relation/write/map_write.py +215 -289
- snowflake/snowpark_connect/resources_initializer.py +25 -13
- snowflake/snowpark_connect/server.py +10 -26
- snowflake/snowpark_connect/type_mapping.py +38 -3
- snowflake/snowpark_connect/typed_column.py +8 -6
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +27 -4
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +1 -1
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/METADATA +7 -2
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/RECORD +46 -105
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
- {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_join.py
@@ -1,15 +1,22 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-
+import dataclasses
+from enum import Enum
 from functools import reduce
+from typing import Optional
 
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 from pyspark.errors import AnalysisException
 
 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
-from snowflake.
+from snowflake.snowpark.types import StructField, StructType
+from snowflake.snowpark_connect.column_name_handler import (
+    JoinColumnNameMap,
+    make_unique_snowpark_name,
+)
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.constants import COLUMN_METADATA_COLLISION_KEY
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
@@ -42,6 +49,25 @@ from snowflake.snowpark_connect.utils.telemetry import (
 USING_COLUMN_NOT_FOUND_ERROR = "[UNRESOLVED_USING_COLUMN_FOR_JOIN] USING column `{0}` not found on the {1} side of the join. The {1}-side columns: {2}"
 
 
+class ConditionType(Enum):
+    USING_COLUMNS = 1
+    JOIN_CONDITION = 2
+    NO_CONDITION = 3
+
+
+@dataclasses.dataclass
+class JoinInfo:
+    join_type: str
+    condition_type: ConditionType
+    join_columns: Optional[list[str]]
+
+    def has_join_condition(self) -> bool:
+        return self.condition_type == ConditionType.JOIN_CONDITION
+
+    def is_using_columns(self):
+        return self.condition_type == ConditionType.USING_COLUMNS
+
+
 def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
     left_container: DataFrameContainer = map_relation(rel.join.left)
     right_container: DataFrameContainer = map_relation(rel.join.right)
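Illustrative sketch (not from the package): how the ConditionType/JoinInfo pair added above classifies a join, using the classes exactly as defined in this hunk.

info = JoinInfo(
    join_type="left",
    condition_type=ConditionType.USING_COLUMNS,
    join_columns=["id"],
)
# USING/natural joins carry a column list but no explicit condition expression
assert info.is_using_columns()
assert not info.has_join_condition()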
@@ -53,48 +79,11 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
 
     left_input: snowpark.DataFrame = left_container.dataframe
     right_input: snowpark.DataFrame = right_container.dataframe
-    is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
-    using_columns = rel.join.using_columns
-    if is_natural_join:
-        rel.join.join_type -= NATURAL_JOIN_TYPE_BASE
-        left_spark_columns = left_container.column_map.get_spark_columns()
-        right_spark_columns = right_container.column_map.get_spark_columns()
-        common_spark_columns = [
-            x for x in left_spark_columns if x in right_spark_columns
-        ]
-        using_columns = common_spark_columns
-
-    match rel.join.join_type:
-        case relation_proto.Join.JOIN_TYPE_UNSPECIFIED:
-            # TODO: Understand what UNSPECIFIED Join type is
-            exception = SnowparkConnectNotImplementedError("Unspecified Join Type")
-            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
-            raise exception
-        case relation_proto.Join.JOIN_TYPE_INNER:
-            join_type = "inner"
-        case relation_proto.Join.JOIN_TYPE_FULL_OUTER:
-            join_type = "full_outer"
-        case relation_proto.Join.JOIN_TYPE_LEFT_OUTER:
-            join_type = "left"
-        case relation_proto.Join.JOIN_TYPE_RIGHT_OUTER:
-            join_type = "right"
-        case relation_proto.Join.JOIN_TYPE_LEFT_ANTI:
-            join_type = "leftanti"
-        case relation_proto.Join.JOIN_TYPE_LEFT_SEMI:
-            join_type = "leftsemi"
-        case relation_proto.Join.JOIN_TYPE_CROSS:
-            join_type = "cross"
-        case other:
-            exception = SnowparkConnectNotImplementedError(f"Other Join Type: {other}")
-            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
-            raise exception
-
-    # This handles case sensitivity for using_columns
-    case_corrected_right_columns: list[str] = []
 
-
-
+    join_info = _get_join_info(rel, left_container, right_container)
+    join_type = join_info.join_type
 
+    if join_info.has_join_condition():
         left_columns = list(left_container.column_map.spark_to_col.keys())
         right_columns = list(right_container.column_map.spark_to_col.keys())
 
@@ -121,72 +110,42 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
         result: snowpark.DataFrame = left_input.join(
             right=right_input,
             on=join_expression.col,
-            how=join_type,
+            how="inner" if join_info.join_type == "cross" else join_info.join_type,
             lsuffix="_left",
             rsuffix="_right",
         )
-    elif
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    ),
-                    "left",
-                    left_container.column_map.get_spark_columns(),
+    elif join_info.is_using_columns():
+        # TODO: disambiguate snowpark columns for all join condition types
+        # disambiguation temporarily done only for using_columns/natural joins to reduce changes
+        left_container, right_container = _disambiguate_snowpark_columns(
+            left_container, right_container
+        )
+        left_input = left_container.dataframe
+        right_input = right_container.dataframe
+
+        join_columns = join_info.join_columns
+
+        def _validate_using_column(
+            column: str, container: DataFrameContainer, side: str
+        ) -> None:
+            if (
+                container.column_map.get_snowpark_column_name_from_spark_column_name(
+                    column, allow_non_exists=True, return_first=True
                 )
-
-
-
-
-
-
-            )
-            is None
-            for c in using_columns
-        ):
-            exception = AnalysisException(
-                USING_COLUMN_NOT_FOUND_ERROR.format(
-                    next(
-                        c
-                        for c in using_columns
-                        if right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                            c, allow_non_exists=True, return_first=True
-                        )
-                        is None
-                    ),
-                    "right",
-                    right_container.column_map.get_spark_columns(),
+                is None
+            ):
+                exception = AnalysisException(
+                    USING_COLUMN_NOT_FOUND_ERROR.format(
+                        column, side, container.column_map.get_spark_columns()
+                    )
                 )
-
-
-
+                attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
+                raise exception
+
+        for col in join_columns:
+            _validate_using_column(col, left_container, "left")
+            _validate_using_column(col, right_container, "right")
 
-        # Round trip the using columns through the column map to get the correct names
-        # in order to support case sensitivity.
-        # TODO: case_corrected_left_columns / case_corrected_right_columns may no longer be required as Snowpark dataframe preserves the column casing now.
-        case_corrected_left_columns = left_container.column_map.get_spark_column_names_from_snowpark_column_names(
-            left_container.column_map.get_snowpark_column_names_from_spark_column_names(
-                list(using_columns), return_first=True
-            )
-        )
-        case_corrected_right_columns = right_container.column_map.get_spark_column_names_from_snowpark_column_names(
-            right_container.column_map.get_snowpark_column_names_from_spark_column_names(
-                list(using_columns), return_first=True
-            )
-        )
-        using_columns = zip(case_corrected_left_columns, case_corrected_right_columns)
         # We cannot assume that Snowpark will have the same names for left and right columns,
         # so we convert ["a", "b"] into (left["a"] == right["a"] & left["b"] == right["b"]),
         # then drop right["a"] and right["b"].
@@ -194,16 +153,16 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
             (
                 left_input[
                     left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-
+                        spark_name, return_first=True
                     )
                 ],
                 right_input[
                     right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-
+                        spark_name, return_first=True
                    )
                 ],
             )
-            for
+            for spark_name in join_columns
         ]
         joined_df = left_input.join(
             right=right_input,
@@ -239,10 +198,19 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
             exception = SparkException.implicit_cartesian_product("inner")
             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
             raise exception
-
-
-
-
+        # For outer joins without a condition, we need to use a TRUE condition
+        # to match Spark's behavior.
+        if join_type in ["left", "right", "full_outer"]:
+            result: snowpark.DataFrame = left_input.join(
+                right=right_input,
+                on=snowpark_fn.lit(True),
+                how=join_type,
+            )
+        else:
+            result: snowpark.DataFrame = left_input.join(
+                right=right_input,
+                how=join_type,
+            )
 
     if join_type in ["leftanti", "leftsemi"]:
         # Join types that only return columns from the left side:
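Illustrative sketch (not from the package): the hunk above expresses a condition-less outer join as a join on a literal TRUE condition so the result matches Spark's behavior for joins written without an "on" clause. Assuming left_df and right_df are existing Snowpark DataFrames in an already-configured session:

from snowflake.snowpark.functions import lit

# Joining on lit(True) pairs every left row with every right row while still
# using outer-join semantics, mirroring Spark's df.join(other, how="left")
# when no join condition is supplied.
result = left_df.join(right_df, on=lit(True), how="left")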
@@ -252,37 +220,26 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
         spark_cols_after_join: list[str] = left_container.column_map.get_spark_columns()
         qualifiers = left_container.column_map.get_qualifiers()
     else:
-
-
-
-
-            spark_col
-            for i, spark_col in enumerate(
-                right_container.column_map.get_spark_columns()
+        if not join_info.is_using_columns():
+            spark_cols_after_join: list[str] = (
+                left_container.column_map.get_spark_columns()
+                + right_container.column_map.get_spark_columns()
             )
-
-
-
-                :i
-            ] # this is to make sure we only remove the column once
-        ]
-
-        qualifiers = list(left_container.column_map.get_qualifiers()) + [
-            right_container.column_map.get_qualifier_for_spark_column(spark_col)
-            for i, spark_col in enumerate(
-                right_container.column_map.get_spark_columns()
+            qualifiers: list[set[ColumnQualifier]] = (
+                left_container.column_map.get_qualifiers()
+                + right_container.column_map.get_qualifiers()
             )
-
-
-
-
-
-
-
-
-
-    column_metadata.update(left_container.column_map.column_metadata)
+        else:
+            # get columns after join
+            joined_columns = left_container.column_map.get_columns_after_join(
+                right_container.column_map, join_info.join_columns
+            )
+            spark_cols_after_join: list[str] = [c.spark_name for c in joined_columns]
+            qualifiers: list[set[ColumnQualifier]] = [
+                c.qualifiers for c in joined_columns
+            ]
 
+    column_metadata = dict(left_container.column_map.column_metadata or {})
     if right_container.column_map.column_metadata:
         for key, value in right_container.column_map.column_metadata.items():
             if key not in column_metadata:
@@ -315,7 +272,7 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
     # After a USING join, references to the right dataframe's columns should resolve
     # to the result dataframe that contains the merged columns
     if (
-
+        join_info.is_using_columns()
         and rel.join.right.HasField("common")
         and rel.join.right.common.HasField("plan_id")
     ):
@@ -325,7 +282,7 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
     # For FULL OUTER joins, we also need to map the left dataframe's plan_id
     # since both columns are replaced with a coalesced column
     if (
-
+        join_info.is_using_columns()
         and join_type == "full_outer"
         and rel.join.left.HasField("common")
         and rel.join.left.common.HasField("plan_id")
@@ -333,12 +290,12 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
         left_plan_id = rel.join.left.common.plan_id
         set_plan_id_map(left_plan_id, result_container)
 
-    if
+    if join_info.is_using_columns():
         # When join 'using_columns', the 'join columns' should go first in result DF.
-
-
-
-
+        # we're only shifting left side columns, since we dropped the right-side ones
+        idxs_to_shift = left_container.column_map.get_column_indexes(
+            join_info.join_columns
+        )
 
         def reorder(lst: list) -> list:
             to_move = [lst[i] for i in idxs_to_shift]
@@ -367,3 +324,121 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
     )
 
     return result_container
+
+
+def _get_join_info(
+    rel: relation_proto.Relation, left: DataFrameContainer, right: DataFrameContainer
+) -> JoinInfo:
+    """
+    Gathers basic information about the join, and performs basic assertions
+    """
+
+    is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
+    join_columns = rel.join.using_columns
+    if is_natural_join:
+        rel.join.join_type -= NATURAL_JOIN_TYPE_BASE
+        left_spark_columns = left.column_map.get_spark_columns()
+        right_spark_columns = right.column_map.get_spark_columns()
+        common_spark_columns = [
+            x for x in left_spark_columns if x in right_spark_columns
+        ]
+        join_columns = common_spark_columns
+
+    match rel.join.join_type:
+        case relation_proto.Join.JOIN_TYPE_UNSPECIFIED:
+            # TODO: Understand what UNSPECIFIED Join type is
+            exception = SnowparkConnectNotImplementedError("Unspecified Join Type")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
+        case relation_proto.Join.JOIN_TYPE_INNER:
+            join_type = "inner"
+        case relation_proto.Join.JOIN_TYPE_FULL_OUTER:
+            join_type = "full_outer"
+        case relation_proto.Join.JOIN_TYPE_LEFT_OUTER:
+            join_type = "left"
+        case relation_proto.Join.JOIN_TYPE_RIGHT_OUTER:
+            join_type = "right"
+        case relation_proto.Join.JOIN_TYPE_LEFT_ANTI:
+            join_type = "leftanti"
+        case relation_proto.Join.JOIN_TYPE_LEFT_SEMI:
+            join_type = "leftsemi"
+        case relation_proto.Join.JOIN_TYPE_CROSS:
+            join_type = "cross"
+        case other:
+            exception = SnowparkConnectNotImplementedError(f"Other Join Type: {other}")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
+
+    has_join_condition = rel.join.HasField("join_condition")
+    is_using_columns = bool(join_columns)
+
+    if has_join_condition:
+        assert not is_using_columns
+
+    condition_type = ConditionType.NO_CONDITION
+    if has_join_condition:
+        condition_type = ConditionType.JOIN_CONDITION
+    elif is_using_columns:
+        condition_type = ConditionType.USING_COLUMNS
+
+    return JoinInfo(join_type, condition_type, join_columns)
+
+
+def _disambiguate_snowpark_columns(
+    left: DataFrameContainer, right: DataFrameContainer
+) -> tuple[DataFrameContainer, DataFrameContainer]:
+    conflicting_snowpark_columns = left.column_map.get_conflicting_snowpark_columns(
+        right.column_map
+    )
+
+    if not conflicting_snowpark_columns:
+        return left, right
+
+    # rename and create new containers
+    return _disambiguate_container(
+        left, conflicting_snowpark_columns
+    ), _disambiguate_container(right, conflicting_snowpark_columns)
+
+
+def _disambiguate_container(
+    container: DataFrameContainer, conflicting_snowpark_columns: set[str]
+) -> DataFrameContainer:
+    column_map = container.column_map
+    disambiguated_columns = []
+    disambiguated_snowpark_names = []
+    for c in column_map.columns:
+        if c.snowpark_name in conflicting_snowpark_columns:
+            # alias snowpark column with a new unique name
+            new_name = make_unique_snowpark_name(c.spark_name)
+            disambiguated_snowpark_names.append(new_name)
+            disambiguated_columns.append(
+                snowpark_fn.col(c.snowpark_name).alias(new_name)
+            )
+        else:
+            disambiguated_snowpark_names.append(c.snowpark_name)
+            disambiguated_columns.append(snowpark_fn.col(c.snowpark_name))
+
+    disambiguated_df = container.dataframe.select(*disambiguated_columns)
+
+    def _get_new_schema():
+        old_schema = container.dataframe.schema
+        if not old_schema.fields:
+            return StructType([])
+
+        new_fields = []
+        for i, name in enumerate(disambiguated_snowpark_names):
+            f = old_schema.fields[i]
+            new_fields.append(
+                StructField(name, f.datatype, nullable=f.nullable, _is_column=True)
+            )
+        return StructType(new_fields)
+
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=disambiguated_df,
+        spark_column_names=column_map.get_spark_columns(),
+        snowpark_column_names=disambiguated_snowpark_names,
+        column_metadata=column_map.column_metadata,
+        column_qualifiers=column_map.get_qualifiers(),
+        table_name=container.table_name,
+        cached_schema_getter=_get_new_schema,
+    )
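Illustrative sketch (not from the package): _disambiguate_container re-aliases physical Snowpark names that appear on both sides of a USING/natural join so the joined result has no ambiguous columns, while the Spark-visible names stay unchanged. A self-contained, plain-Python sketch of that renaming idea; the helper _unique_name below is a hypothetical stand-in for make_unique_snowpark_name:

import itertools

_counter = itertools.count()

def _unique_name(name: str) -> str:
    # hypothetical stand-in for make_unique_snowpark_name
    return f"{name}_{next(_counter)}"

def disambiguate(names: list[str], conflicts: set[str]) -> list[str]:
    # rename only the physical names that collide with the other side
    return [_unique_name(n) if n in conflicts else n for n in names]

left, right = ["ID", "NAME"], ["ID", "CITY"]
conflicts = set(left) & set(right)       # {"ID"}
print(disambiguate(left, conflicts))     # ['ID_0', 'NAME']
print(disambiguate(right, conflicts))    # ['ID_1', 'CITY']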
snowflake/snowpark_connect/relation/map_row_ops.py
@@ -45,6 +45,61 @@ from snowflake.snowpark_connect.utils.telemetry import (
 )
 
 
+def cast_columns(
+    df_container: DataFrameContainer,
+    df_dtypes: list[snowpark.types.DataType],
+    target_dtypes: list[snowpark.types.DataType],
+    column_map: ColumnNameMap,
+):
+    df: snowpark.DataFrame = df_container.dataframe
+    if df_dtypes == target_dtypes:
+        return df_container
+    # Use cached schema if available to avoid triggering extra queries
+    if (
+        hasattr(df_container, "cached_schema_getter")
+        and df_container.cached_schema_getter is not None
+    ):
+        df_schema = df_container.cached_schema_getter()
+    else:
+        df_schema = df.schema  # Get current schema
+    new_columns = []
+
+    for i, field in enumerate(df_schema.fields):
+        col_name = field.name
+        current_type = field.datatype
+        target_type = target_dtypes[i]
+
+        if current_type != target_type:
+            new_columns.append(df[col_name].cast(target_type).alias(col_name))
+        else:
+            new_columns.append(df[col_name])
+
+    new_df = df.select(new_columns)
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=new_df,
+        spark_column_names=column_map.get_spark_columns(),
+        snowpark_column_names=column_map.get_snowpark_columns(),
+        snowpark_column_types=target_dtypes,
+        column_metadata=column_map.column_metadata,
+        parent_column_name_map=column_map,
+    )
+
+
+def get_schema_from_result(
+    result: DataFrameContainer,
+) -> StructType:
+    """
+    Get schema from a DataFrameContainer, using cached schema if available to avoid extra queries.
+    """
+    if (
+        hasattr(result, "cached_schema_getter")
+        and result.cached_schema_getter is not None
+    ):
+        return result.cached_schema_getter()
+    else:
+        return result.dataframe.schema
+
+
 def map_deduplicate(
     rel: relation_proto.Relation,
 ) -> DataFrameContainer:
@@ -205,21 +260,8 @@ def map_union(
 
     # workaround for unstructured type vs structured type
     # Use cached schema if available to avoid triggering extra queries
-
-
-        and left_result.cached_schema_getter is not None
-    ):
-        left_schema = left_result.cached_schema_getter()
-    else:
-        left_schema = left_df.schema
-
-    if (
-        hasattr(right_result, "cached_schema_getter")
-        and right_result.cached_schema_getter is not None
-    ):
-        right_schema = right_result.cached_schema_getter()
-    else:
-        right_schema = right_df.schema
+    left_schema = get_schema_from_result(left_result)
+    right_schema = get_schema_from_result(right_result)
 
     left_dtypes = [field.datatype for field in left_schema.fields]
     right_dtypes = [field.datatype for field in right_schema.fields]
@@ -257,6 +299,29 @@ def map_union(
                 # Union of any type with null type is of the other type
                 target_left_dtypes.append(other_t)
                 target_right_dtypes.append(other_t)
+            case (snowpark.types.DecimalType(), snowpark.types.DecimalType()):
+                # Widen decimal types to accommodate both sides
+                # Calculate the maximum scale and maximum integer digits
+                left_integer_digits = left_type.precision - left_type.scale
+                right_integer_digits = right_type.precision - right_type.scale
+
+                # The common type needs to accommodate:
+                # - The maximum number of digits after the decimal point (scale)
+                # - The maximum number of digits before the decimal point (integer digits)
+                common_scale = max(left_type.scale, right_type.scale)
+                common_integer_digits = max(
+                    left_integer_digits, right_integer_digits
+                )
+                common_precision = min(38, common_scale + common_integer_digits)
+
+                # Ensure scale doesn't exceed precision
+                common_scale = min(common_scale, common_precision)
+
+                common_type = snowpark.types.DecimalType(
+                    common_precision, common_scale
+                )
+                target_left_dtypes.append(common_type)
+                target_right_dtypes.append(common_type)
             case (snowpark.types.BooleanType(), _) | (
                 _,
                 snowpark.types.BooleanType(),
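Worked example (not from the package): the decimal widening above keeps the larger scale and the larger count of integer digits, capped at Snowflake's maximum precision of 38. In plain Python:

# DECIMAL(10, 2) unioned with DECIMAL(8, 5)
left_precision, left_scale = 10, 2
right_precision, right_scale = 8, 5

common_scale = max(left_scale, right_scale)                        # 5
common_integer_digits = max(
    left_precision - left_scale, right_precision - right_scale     # max(8, 3) -> 8
)
common_precision = min(38, common_scale + common_integer_digits)   # 13
common_scale = min(common_scale, common_precision)                 # 5
print(common_precision, common_scale)                              # 13 5 -> DECIMAL(13, 5)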
@@ -272,49 +337,24 @@ def map_union(
                 raise exception
                 target_left_dtypes.append(left_type)
                 target_right_dtypes.append(right_type)
+            case (
+                snowpark.types.TimestampType()
+                | snowpark.types.DateType()
+                | snowpark.types._NumericType(),
+                snowpark.types.StringType(),
+            ) | (
+                snowpark.types.StringType(),
+                snowpark.types.TimestampType()
+                | snowpark.types.DateType()
+                | snowpark.types._NumericType(),
+            ) if not spark_sql_ansi_enabled:
+                common_type = snowpark.types.StringType()
+                target_left_dtypes.append(common_type)
+                target_right_dtypes.append(common_type)
             case _:
                 target_left_dtypes.append(left_type)
                 target_right_dtypes.append(right_type)
 
-    def cast_columns(
-        df_container: DataFrameContainer,
-        df_dtypes: list[snowpark.types.DataType],
-        target_dtypes: list[snowpark.types.DataType],
-        column_map: ColumnNameMap,
-    ):
-        df: snowpark.DataFrame = df_container.dataframe
-        if df_dtypes == target_dtypes:
-            return df_container
-        # Use cached schema if available to avoid triggering extra queries
-        if (
-            hasattr(df_container, "cached_schema_getter")
-            and df_container.cached_schema_getter is not None
-        ):
-            df_schema = df_container.cached_schema_getter()
-        else:
-            df_schema = df.schema  # Get current schema
-        new_columns = []
-
-        for i, field in enumerate(df_schema.fields):
-            col_name = field.name
-            current_type = field.datatype
-            target_type = target_dtypes[i]
-
-            if current_type != target_type:
-                new_columns.append(df[col_name].cast(target_type).alias(col_name))
-            else:
-                new_columns.append(df[col_name])
-
-        new_df = df.select(new_columns)
-        return DataFrameContainer.create_with_column_mapping(
-            dataframe=new_df,
-            spark_column_names=column_map.get_spark_columns(),
-            snowpark_column_names=column_map.get_snowpark_columns(),
-            snowpark_column_types=target_dtypes,
-            column_metadata=column_map.column_metadata,
-            parent_column_name_map=column_map,
-        )
-
     left_result = cast_columns(
         left_result,
         left_dtypes,
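Illustrative sketch (not from the package): the case added above coerces date/timestamp/numeric vs. string column pairs to string only when ANSI mode (spark_sql_ansi_enabled) is off. A plain-Python sketch of that decision, using string labels in place of the Snowpark type objects:

def union_types(left: str, right: str, ansi_enabled: bool) -> tuple[str, str]:
    coercible = {"numeric", "date", "timestamp"}
    if not ansi_enabled and (
        (left in coercible and right == "string")
        or (left == "string" and right in coercible)
    ):
        return "string", "string"
    return left, right

print(union_types("date", "string", ansi_enabled=False))  # ('string', 'string')
print(union_types("date", "string", ansi_enabled=True))   # ('date', 'string')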
@@ -527,6 +567,48 @@ def map_except(
     left_df = left_result.dataframe
     right_df = right_result.dataframe
 
+    # workaround for unstructured type vs structured type
+    # Use cached schema if available to avoid triggering extra queries
+    left_schema = get_schema_from_result(left_result)
+    right_schema = get_schema_from_result(right_result)
+
+    left_dtypes = [field.datatype for field in left_schema.fields]
+    right_dtypes = [field.datatype for field in right_schema.fields]
+
+    if left_dtypes != right_dtypes and not rel.set_op.by_name:
+        if len(left_dtypes) != len(right_dtypes):
+            exception = AnalysisException("UNION: the number of columns must match")
+            attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+            raise exception
+        target_left_dtypes, target_right_dtypes = [], []
+        for left_type, right_type in zip(left_dtypes, right_dtypes):
+            match (left_type, right_type):
+                case (snowpark.types._NumericType(), snowpark.types.StringType()) | (
+                    snowpark.types.StringType(),
+                    snowpark.types._NumericType(),
+                ):
+                    common_type = snowpark.types.StringType()
+                    target_left_dtypes.append(common_type)
+                    target_right_dtypes.append(common_type)
+                case _:
+                    target_left_dtypes.append(left_type)
+                    target_right_dtypes.append(right_type)
+
+        left_result = cast_columns(
+            left_result,
+            left_dtypes,
+            target_left_dtypes,
+            left_result.column_map,
+        )
+        right_result = cast_columns(
+            right_result,
+            right_dtypes,
+            target_right_dtypes,
+            right_result.column_map,
+        )
+        left_df = left_result.dataframe
+        right_df = right_result.dataframe
+
     if rel.set_op.is_all:
         # Snowflake except removes all duplicated rows. In order to handle the case,
         # we add a partition row number column to the df to make duplicated rows unique to