snowpark-connect 0.32.0__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/column_name_handler.py +92 -27
- snowflake/snowpark_connect/column_qualifier.py +0 -4
- snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
- snowflake/snowpark_connect/expression/map_sql_expression.py +12 -4
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +58 -21
- snowflake/snowpark_connect/expression/map_unresolved_function.py +62 -27
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
- snowflake/snowpark_connect/relation/map_aggregate.py +2 -4
- snowflake/snowpark_connect/relation/map_column_ops.py +5 -0
- snowflake/snowpark_connect/relation/map_join.py +218 -146
- snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
- snowflake/snowpark_connect/relation/map_sql.py +102 -16
- snowflake/snowpark_connect/relation/read/map_read_json.py +87 -2
- snowflake/snowpark_connect/relation/utils.py +46 -0
- snowflake/snowpark_connect/relation/write/map_write.py +186 -275
- snowflake/snowpark_connect/resources_initializer.py +25 -13
- snowflake/snowpark_connect/server.py +9 -24
- snowflake/snowpark_connect/type_mapping.py +2 -0
- snowflake/snowpark_connect/typed_column.py +2 -2
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +8 -1
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/METADATA +3 -1
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/RECORD +35 -93
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_join.py

@@ -1,15 +1,21 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-
+import dataclasses
+from enum import Enum
 from functools import reduce
+from typing import Optional

 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 from pyspark.errors import AnalysisException

 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
-from snowflake.
+from snowflake.snowpark.types import StructField, StructType
+from snowflake.snowpark_connect.column_name_handler import (
+    JoinColumnNameMap,
+    make_unique_snowpark_name,
+)
 from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.constants import COLUMN_METADATA_COLLISION_KEY
@@ -43,6 +49,25 @@ from snowflake.snowpark_connect.utils.telemetry import (
 USING_COLUMN_NOT_FOUND_ERROR = "[UNRESOLVED_USING_COLUMN_FOR_JOIN] USING column `{0}` not found on the {1} side of the join. The {1}-side columns: {2}"


+class ConditionType(Enum):
+    USING_COLUMNS = 1
+    JOIN_CONDITION = 2
+    NO_CONDITION = 3
+
+
+@dataclasses.dataclass
+class JoinInfo:
+    join_type: str
+    condition_type: ConditionType
+    join_columns: Optional[list[str]]
+
+    def has_join_condition(self) -> bool:
+        return self.condition_type == ConditionType.JOIN_CONDITION
+
+    def is_using_columns(self):
+        return self.condition_type == ConditionType.USING_COLUMNS
+
+
 def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
     left_container: DataFrameContainer = map_relation(rel.join.left)
     right_container: DataFrameContainer = map_relation(rel.join.right)
@@ -54,48 +79,11 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:

     left_input: snowpark.DataFrame = left_container.dataframe
     right_input: snowpark.DataFrame = right_container.dataframe
-    is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
-    using_columns = rel.join.using_columns
-    if is_natural_join:
-        rel.join.join_type -= NATURAL_JOIN_TYPE_BASE
-        left_spark_columns = left_container.column_map.get_spark_columns()
-        right_spark_columns = right_container.column_map.get_spark_columns()
-        common_spark_columns = [
-            x for x in left_spark_columns if x in right_spark_columns
-        ]
-        using_columns = common_spark_columns
-
-    match rel.join.join_type:
-        case relation_proto.Join.JOIN_TYPE_UNSPECIFIED:
-            # TODO: Understand what UNSPECIFIED Join type is
-            exception = SnowparkConnectNotImplementedError("Unspecified Join Type")
-            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
-            raise exception
-        case relation_proto.Join.JOIN_TYPE_INNER:
-            join_type = "inner"
-        case relation_proto.Join.JOIN_TYPE_FULL_OUTER:
-            join_type = "full_outer"
-        case relation_proto.Join.JOIN_TYPE_LEFT_OUTER:
-            join_type = "left"
-        case relation_proto.Join.JOIN_TYPE_RIGHT_OUTER:
-            join_type = "right"
-        case relation_proto.Join.JOIN_TYPE_LEFT_ANTI:
-            join_type = "leftanti"
-        case relation_proto.Join.JOIN_TYPE_LEFT_SEMI:
-            join_type = "leftsemi"
-        case relation_proto.Join.JOIN_TYPE_CROSS:
-            join_type = "cross"
-        case other:
-            exception = SnowparkConnectNotImplementedError(f"Other Join Type: {other}")
-            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
-            raise exception
-
-    # This handles case sensitivity for using_columns
-    case_corrected_right_columns: list[str] = []

-
-
+    join_info = _get_join_info(rel, left_container, right_container)
+    join_type = join_info.join_type

+    if join_info.has_join_condition():
         left_columns = list(left_container.column_map.spark_to_col.keys())
         right_columns = list(right_container.column_map.spark_to_col.keys())

@@ -122,72 +110,42 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
         result: snowpark.DataFrame = left_input.join(
             right=right_input,
             on=join_expression.col,
-            how=join_type,
+            how="inner" if join_info.join_type == "cross" else join_info.join_type,
             lsuffix="_left",
             rsuffix="_right",
         )
-    elif
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        ),
-        "left",
-        left_container.column_map.get_spark_columns(),
+    elif join_info.is_using_columns():
+        # TODO: disambiguate snowpark columns for all join condition types
+        # disambiguation temporarily done only for using_columns/natural joins to reduce changes
+        left_container, right_container = _disambiguate_snowpark_columns(
+            left_container, right_container
+        )
+        left_input = left_container.dataframe
+        right_input = right_container.dataframe
+
+        join_columns = join_info.join_columns
+
+        def _validate_using_column(
+            column: str, container: DataFrameContainer, side: str
+        ) -> None:
+            if (
+                container.column_map.get_snowpark_column_name_from_spark_column_name(
+                    column, allow_non_exists=True, return_first=True
                 )
-
-
-
-
-
-
-            )
-            is None
-            for c in using_columns
-        ):
-            exception = AnalysisException(
-                USING_COLUMN_NOT_FOUND_ERROR.format(
-                    next(
-                        c
-                        for c in using_columns
-                        if right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                            c, allow_non_exists=True, return_first=True
-                        )
-                        is None
-                    ),
-                    "right",
-                    right_container.column_map.get_spark_columns(),
+                is None
+            ):
+                exception = AnalysisException(
+                    USING_COLUMN_NOT_FOUND_ERROR.format(
+                        column, side, container.column_map.get_spark_columns()
+                    )
                 )
-
-
-
+                attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
+                raise exception
+
+        for col in join_columns:
+            _validate_using_column(col, left_container, "left")
+            _validate_using_column(col, right_container, "right")

-        # Round trip the using columns through the column map to get the correct names
-        # in order to support case sensitivity.
-        # TODO: case_corrected_left_columns / case_corrected_right_columns may no longer be required as Snowpark dataframe preserves the column casing now.
-        case_corrected_left_columns = left_container.column_map.get_spark_column_names_from_snowpark_column_names(
-            left_container.column_map.get_snowpark_column_names_from_spark_column_names(
-                list(using_columns), return_first=True
-            )
-        )
-        case_corrected_right_columns = right_container.column_map.get_spark_column_names_from_snowpark_column_names(
-            right_container.column_map.get_snowpark_column_names_from_spark_column_names(
-                list(using_columns), return_first=True
-            )
-        )
-        using_columns = zip(case_corrected_left_columns, case_corrected_right_columns)
         # We cannot assume that Snowpark will have the same names for left and right columns,
         # so we convert ["a", "b"] into (left["a"] == right["a"] & left["b"] == right["b"]),
         # then drop right["a"] and right["b"].
@@ -195,16 +153,16 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
             (
                 left_input[
                     left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-
+                        spark_name, return_first=True
                     )
                 ],
                 right_input[
                     right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-
+                        spark_name, return_first=True
                    )
                ],
            )
-            for
+            for spark_name in join_columns
        ]
        joined_df = left_input.join(
            right=right_input,
@@ -240,10 +198,19 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
            exception = SparkException.implicit_cartesian_product("inner")
            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
            raise exception
-
-
-
-
+        # For outer joins without a condition, we need to use a TRUE condition
+        # to match Spark's behavior.
+        if join_type in ["left", "right", "full_outer"]:
+            result: snowpark.DataFrame = left_input.join(
+                right=right_input,
+                on=snowpark_fn.lit(True),
+                how=join_type,
+            )
+        else:
+            result: snowpark.DataFrame = left_input.join(
+                right=right_input,
+                how=join_type,
+            )

    if join_type in ["leftanti", "leftsemi"]:
        # Join types that only return columns from the left side:
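The hunk above handles joins that arrive with no join condition at all: outer joins are given an explicit always-true predicate so that Snowpark keeps Spark's "every row pairs with every row" semantics. A minimal sketch of the resulting Snowpark calls, assuming two already-mapped Snowpark DataFrames named left_input and right_input (illustrative only, not code from the package):

    import snowflake.snowpark.functions as snowpark_fn

    # Outer join with no user-supplied condition: join on a constant TRUE
    # predicate so every left row matches every right row, as Spark does.
    result = left_input.join(right_input, on=snowpark_fn.lit(True), how="left")

    # Non-outer joins with no condition are passed through without an ON clause.
    result = left_input.join(right_input, how="inner")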
@@ -253,39 +220,26 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
        spark_cols_after_join: list[str] = left_container.column_map.get_spark_columns()
        qualifiers = left_container.column_map.get_qualifiers()
    else:
-
-
-
-
-            spark_col
-            for i, spark_col in enumerate(
-                right_container.column_map.get_spark_columns()
+        if not join_info.is_using_columns():
+            spark_cols_after_join: list[str] = (
+                left_container.column_map.get_spark_columns()
+                + right_container.column_map.get_spark_columns()
            )
-
-
-
-                :i
-            ]  # this is to make sure we only remove the column once
-        ]
-
-        qualifiers: list[set[ColumnQualifier]] = list(
-            left_container.column_map.get_qualifiers()
-        ) + [
-            {right_container.column_map.get_qualifier_for_spark_column(spark_col)}
-            for i, spark_col in enumerate(
-                right_container.column_map.get_spark_columns()
+            qualifiers: list[set[ColumnQualifier]] = (
+                left_container.column_map.get_qualifiers()
+                + right_container.column_map.get_qualifiers()
            )
-
-
-
-
-
-
-
-
-
-        column_metadata.update(left_container.column_map.column_metadata)
+        else:
+            # get columns after join
+            joined_columns = left_container.column_map.get_columns_after_join(
+                right_container.column_map, join_info.join_columns
+            )
+            spark_cols_after_join: list[str] = [c.spark_name for c in joined_columns]
+            qualifiers: list[set[ColumnQualifier]] = [
+                c.qualifiers for c in joined_columns
+            ]

+    column_metadata = dict(left_container.column_map.column_metadata or {})
    if right_container.column_map.column_metadata:
        for key, value in right_container.column_map.column_metadata.items():
            if key not in column_metadata:
@@ -318,7 +272,7 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
    # After a USING join, references to the right dataframe's columns should resolve
    # to the result dataframe that contains the merged columns
    if (
-
+        join_info.is_using_columns()
        and rel.join.right.HasField("common")
        and rel.join.right.common.HasField("plan_id")
    ):
@@ -328,7 +282,7 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
    # For FULL OUTER joins, we also need to map the left dataframe's plan_id
    # since both columns are replaced with a coalesced column
    if (
-
+        join_info.is_using_columns()
        and join_type == "full_outer"
        and rel.join.left.HasField("common")
        and rel.join.left.common.HasField("plan_id")
@@ -336,12 +290,12 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
        left_plan_id = rel.join.left.common.plan_id
        set_plan_id_map(left_plan_id, result_container)

-    if
+    if join_info.is_using_columns():
        # When join 'using_columns', the 'join columns' should go first in result DF.
-
-
-
-
+        # we're only shifting left side columns, since we dropped the right-side ones
+        idxs_to_shift = left_container.column_map.get_column_indexes(
+            join_info.join_columns
+        )

        def reorder(lst: list) -> list:
            to_move = [lst[i] for i in idxs_to_shift]
@@ -370,3 +324,121 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
    )

    return result_container
+
+
+def _get_join_info(
+    rel: relation_proto.Relation, left: DataFrameContainer, right: DataFrameContainer
+) -> JoinInfo:
+    """
+    Gathers basic information about the join, and performs basic assertions
+    """
+
+    is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
+    join_columns = rel.join.using_columns
+    if is_natural_join:
+        rel.join.join_type -= NATURAL_JOIN_TYPE_BASE
+        left_spark_columns = left.column_map.get_spark_columns()
+        right_spark_columns = right.column_map.get_spark_columns()
+        common_spark_columns = [
+            x for x in left_spark_columns if x in right_spark_columns
+        ]
+        join_columns = common_spark_columns
+
+    match rel.join.join_type:
+        case relation_proto.Join.JOIN_TYPE_UNSPECIFIED:
+            # TODO: Understand what UNSPECIFIED Join type is
+            exception = SnowparkConnectNotImplementedError("Unspecified Join Type")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
+        case relation_proto.Join.JOIN_TYPE_INNER:
+            join_type = "inner"
+        case relation_proto.Join.JOIN_TYPE_FULL_OUTER:
+            join_type = "full_outer"
+        case relation_proto.Join.JOIN_TYPE_LEFT_OUTER:
+            join_type = "left"
+        case relation_proto.Join.JOIN_TYPE_RIGHT_OUTER:
+            join_type = "right"
+        case relation_proto.Join.JOIN_TYPE_LEFT_ANTI:
+            join_type = "leftanti"
+        case relation_proto.Join.JOIN_TYPE_LEFT_SEMI:
+            join_type = "leftsemi"
+        case relation_proto.Join.JOIN_TYPE_CROSS:
+            join_type = "cross"
+        case other:
+            exception = SnowparkConnectNotImplementedError(f"Other Join Type: {other}")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
+
+    has_join_condition = rel.join.HasField("join_condition")
+    is_using_columns = bool(join_columns)
+
+    if has_join_condition:
+        assert not is_using_columns
+
+    condition_type = ConditionType.NO_CONDITION
+    if has_join_condition:
+        condition_type = ConditionType.JOIN_CONDITION
+    elif is_using_columns:
+        condition_type = ConditionType.USING_COLUMNS
+
+    return JoinInfo(join_type, condition_type, join_columns)
+
+
+def _disambiguate_snowpark_columns(
+    left: DataFrameContainer, right: DataFrameContainer
+) -> tuple[DataFrameContainer, DataFrameContainer]:
+    conflicting_snowpark_columns = left.column_map.get_conflicting_snowpark_columns(
+        right.column_map
+    )
+
+    if not conflicting_snowpark_columns:
+        return left, right
+
+    # rename and create new containers
+    return _disambiguate_container(
+        left, conflicting_snowpark_columns
+    ), _disambiguate_container(right, conflicting_snowpark_columns)
+
+
+def _disambiguate_container(
+    container: DataFrameContainer, conflicting_snowpark_columns: set[str]
+) -> DataFrameContainer:
+    column_map = container.column_map
+    disambiguated_columns = []
+    disambiguated_snowpark_names = []
+    for c in column_map.columns:
+        if c.snowpark_name in conflicting_snowpark_columns:
+            # alias snowpark column with a new unique name
+            new_name = make_unique_snowpark_name(c.spark_name)
+            disambiguated_snowpark_names.append(new_name)
+            disambiguated_columns.append(
+                snowpark_fn.col(c.snowpark_name).alias(new_name)
+            )
+        else:
+            disambiguated_snowpark_names.append(c.snowpark_name)
+            disambiguated_columns.append(snowpark_fn.col(c.snowpark_name))
+
+    disambiguated_df = container.dataframe.select(*disambiguated_columns)
+
+    def _get_new_schema():
+        old_schema = container.dataframe.schema
+        if not old_schema.fields:
+            return StructType([])
+
+        new_fields = []
+        for i, name in enumerate(disambiguated_snowpark_names):
+            f = old_schema.fields[i]
+            new_fields.append(
+                StructField(name, f.datatype, nullable=f.nullable, _is_column=True)
+            )
+        return StructType(new_fields)
+
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=disambiguated_df,
+        spark_column_names=column_map.get_spark_columns(),
+        snowpark_column_names=disambiguated_snowpark_names,
+        column_metadata=column_map.column_metadata,
+        column_qualifiers=column_map.get_qualifiers(),
+        table_name=container.table_name,
+        cached_schema_getter=_get_new_schema,
+    )
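For orientation, a minimal PySpark Connect client sketch of the join shapes this rework targets; the connection URL, DataFrame contents, and column names are assumptions for illustration and are not taken from the package:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    left = spark.createDataFrame([(1, "a")], ["id", "left_val"])
    right = spark.createDataFrame([(1, "b")], ["id", "right_val"])

    # USING-column join: the shared "id" column is validated on both sides,
    # appears once in the result, and is moved to the front.
    left.join(right, on="id", how="full_outer").show()

    # Natural joins in SQL go through the same USING/natural code path.
    left.createOrReplaceTempView("l")
    right.createOrReplaceTempView("r")
    spark.sql("SELECT * FROM l NATURAL JOIN r").show()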
snowflake/snowpark_connect/relation/map_row_ops.py

@@ -45,6 +45,61 @@ from snowflake.snowpark_connect.utils.telemetry import (
 )


+def cast_columns(
+    df_container: DataFrameContainer,
+    df_dtypes: list[snowpark.types.DataType],
+    target_dtypes: list[snowpark.types.DataType],
+    column_map: ColumnNameMap,
+):
+    df: snowpark.DataFrame = df_container.dataframe
+    if df_dtypes == target_dtypes:
+        return df_container
+    # Use cached schema if available to avoid triggering extra queries
+    if (
+        hasattr(df_container, "cached_schema_getter")
+        and df_container.cached_schema_getter is not None
+    ):
+        df_schema = df_container.cached_schema_getter()
+    else:
+        df_schema = df.schema  # Get current schema
+    new_columns = []
+
+    for i, field in enumerate(df_schema.fields):
+        col_name = field.name
+        current_type = field.datatype
+        target_type = target_dtypes[i]
+
+        if current_type != target_type:
+            new_columns.append(df[col_name].cast(target_type).alias(col_name))
+        else:
+            new_columns.append(df[col_name])
+
+    new_df = df.select(new_columns)
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=new_df,
+        spark_column_names=column_map.get_spark_columns(),
+        snowpark_column_names=column_map.get_snowpark_columns(),
+        snowpark_column_types=target_dtypes,
+        column_metadata=column_map.column_metadata,
+        parent_column_name_map=column_map,
+    )
+
+
+def get_schema_from_result(
+    result: DataFrameContainer,
+) -> StructType:
+    """
+    Get schema from a DataFrameContainer, using cached schema if available to avoid extra queries.
+    """
+    if (
+        hasattr(result, "cached_schema_getter")
+        and result.cached_schema_getter is not None
+    ):
+        return result.cached_schema_getter()
+    else:
+        return result.dataframe.schema
+
+
 def map_deduplicate(
     rel: relation_proto.Relation,
 ) -> DataFrameContainer:
@@ -205,21 +260,8 @@ def map_union(

    # workaround for unstructured type vs structured type
    # Use cached schema if available to avoid triggering extra queries
-
-
-        and left_result.cached_schema_getter is not None
-    ):
-        left_schema = left_result.cached_schema_getter()
-    else:
-        left_schema = left_df.schema
-
-    if (
-        hasattr(right_result, "cached_schema_getter")
-        and right_result.cached_schema_getter is not None
-    ):
-        right_schema = right_result.cached_schema_getter()
-    else:
-        right_schema = right_df.schema
+    left_schema = get_schema_from_result(left_result)
+    right_schema = get_schema_from_result(right_result)

    left_dtypes = [field.datatype for field in left_schema.fields]
    right_dtypes = [field.datatype for field in right_schema.fields]
@@ -257,6 +299,29 @@ def map_union(
                    # Union of any type with null type is of the other type
                    target_left_dtypes.append(other_t)
                    target_right_dtypes.append(other_t)
+                case (snowpark.types.DecimalType(), snowpark.types.DecimalType()):
+                    # Widen decimal types to accommodate both sides
+                    # Calculate the maximum scale and maximum integer digits
+                    left_integer_digits = left_type.precision - left_type.scale
+                    right_integer_digits = right_type.precision - right_type.scale
+
+                    # The common type needs to accommodate:
+                    # - The maximum number of digits after the decimal point (scale)
+                    # - The maximum number of digits before the decimal point (integer digits)
+                    common_scale = max(left_type.scale, right_type.scale)
+                    common_integer_digits = max(
+                        left_integer_digits, right_integer_digits
+                    )
+                    common_precision = min(38, common_scale + common_integer_digits)
+
+                    # Ensure scale doesn't exceed precision
+                    common_scale = min(common_scale, common_precision)
+
+                    common_type = snowpark.types.DecimalType(
+                        common_precision, common_scale
+                    )
+                    target_left_dtypes.append(common_type)
+                    target_right_dtypes.append(common_type)
                case (snowpark.types.BooleanType(), _) | (
                    _,
                    snowpark.types.BooleanType(),
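As a worked illustration of the decimal-widening rule added above (the numbers are made up, not taken from the package's tests): unioning DECIMAL(10, 2) with DECIMAL(12, 6) keeps the larger scale and the larger count of integer digits, then caps precision at 38.

    # Illustrative restatement of the widening rule, not the package's code.
    def widen_decimal(left_precision, left_scale, right_precision, right_scale):
        scale = max(left_scale, right_scale)
        integer_digits = max(left_precision - left_scale, right_precision - right_scale)
        precision = min(38, scale + integer_digits)
        return precision, min(scale, precision)

    assert widen_decimal(10, 2, 12, 6) == (14, 6)    # DECIMAL(14, 6)
    assert widen_decimal(38, 0, 10, 10) == (38, 10)  # precision capped at 38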
@@ -272,49 +337,24 @@ def map_union(
                    raise exception
                    target_left_dtypes.append(left_type)
                    target_right_dtypes.append(right_type)
+                case (
+                    snowpark.types.TimestampType()
+                    | snowpark.types.DateType()
+                    | snowpark.types._NumericType(),
+                    snowpark.types.StringType(),
+                ) | (
+                    snowpark.types.StringType(),
+                    snowpark.types.TimestampType()
+                    | snowpark.types.DateType()
+                    | snowpark.types._NumericType(),
+                ) if not spark_sql_ansi_enabled:
+                    common_type = snowpark.types.StringType()
+                    target_left_dtypes.append(common_type)
+                    target_right_dtypes.append(common_type)
                case _:
                    target_left_dtypes.append(left_type)
                    target_right_dtypes.append(right_type)

-    def cast_columns(
-        df_container: DataFrameContainer,
-        df_dtypes: list[snowpark.types.DataType],
-        target_dtypes: list[snowpark.types.DataType],
-        column_map: ColumnNameMap,
-    ):
-        df: snowpark.DataFrame = df_container.dataframe
-        if df_dtypes == target_dtypes:
-            return df_container
-        # Use cached schema if available to avoid triggering extra queries
-        if (
-            hasattr(df_container, "cached_schema_getter")
-            and df_container.cached_schema_getter is not None
-        ):
-            df_schema = df_container.cached_schema_getter()
-        else:
-            df_schema = df.schema  # Get current schema
-        new_columns = []
-
-        for i, field in enumerate(df_schema.fields):
-            col_name = field.name
-            current_type = field.datatype
-            target_type = target_dtypes[i]
-
-            if current_type != target_type:
-                new_columns.append(df[col_name].cast(target_type).alias(col_name))
-            else:
-                new_columns.append(df[col_name])
-
-        new_df = df.select(new_columns)
-        return DataFrameContainer.create_with_column_mapping(
-            dataframe=new_df,
-            spark_column_names=column_map.get_spark_columns(),
-            snowpark_column_names=column_map.get_snowpark_columns(),
-            snowpark_column_types=target_dtypes,
-            column_metadata=column_map.column_metadata,
-            parent_column_name_map=column_map,
-        )
-
    left_result = cast_columns(
        left_result,
        left_dtypes,
@@ -527,6 +567,48 @@ def map_except(
    left_df = left_result.dataframe
    right_df = right_result.dataframe

+    # workaround for unstructured type vs structured type
+    # Use cached schema if available to avoid triggering extra queries
+    left_schema = get_schema_from_result(left_result)
+    right_schema = get_schema_from_result(right_result)
+
+    left_dtypes = [field.datatype for field in left_schema.fields]
+    right_dtypes = [field.datatype for field in right_schema.fields]
+
+    if left_dtypes != right_dtypes and not rel.set_op.by_name:
+        if len(left_dtypes) != len(right_dtypes):
+            exception = AnalysisException("UNION: the number of columns must match")
+            attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+            raise exception
+        target_left_dtypes, target_right_dtypes = [], []
+        for left_type, right_type in zip(left_dtypes, right_dtypes):
+            match (left_type, right_type):
+                case (snowpark.types._NumericType(), snowpark.types.StringType()) | (
+                    snowpark.types.StringType(),
+                    snowpark.types._NumericType(),
+                ):
+                    common_type = snowpark.types.StringType()
+                    target_left_dtypes.append(common_type)
+                    target_right_dtypes.append(common_type)
+                case _:
+                    target_left_dtypes.append(left_type)
+                    target_right_dtypes.append(right_type)
+
+        left_result = cast_columns(
+            left_result,
+            left_dtypes,
+            target_left_dtypes,
+            left_result.column_map,
+        )
+        right_result = cast_columns(
+            right_result,
+            right_dtypes,
+            target_right_dtypes,
+            right_result.column_map,
+        )
+        left_df = left_result.dataframe
+        right_df = right_result.dataframe
+
    if rel.set_op.is_all:
        # Snowflake except removes all duplicated rows. In order to handle the case,
        # we add a partition row number column to the df to make duplicated rows unique to
|