snowpark-connect 0.20.2__py3-none-any.whl → 0.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +3 -2
- snowflake/snowpark_connect/column_name_handler.py +6 -65
- snowflake/snowpark_connect/config.py +47 -17
- snowflake/snowpark_connect/dataframe_container.py +242 -0
- snowflake/snowpark_connect/error/error_utils.py +25 -0
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +13 -23
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +9 -5
- snowflake/snowpark_connect/expression/map_extension.py +2 -1
- snowflake/snowpark_connect/expression/map_udf.py +4 -4
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +8 -7
- snowflake/snowpark_connect/expression/map_unresolved_function.py +481 -170
- snowflake/snowpark_connect/expression/map_unresolved_star.py +8 -8
- snowflake/snowpark_connect/expression/map_update_fields.py +1 -1
- snowflake/snowpark_connect/expression/typer.py +6 -6
- snowflake/snowpark_connect/proto/control_pb2.py +17 -16
- snowflake/snowpark_connect/proto/control_pb2.pyi +17 -17
- snowflake/snowpark_connect/proto/control_pb2_grpc.py +12 -63
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +15 -14
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +19 -14
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +27 -26
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +74 -68
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +5 -5
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +25 -17
- snowflake/snowpark_connect/relation/map_aggregate.py +170 -61
- snowflake/snowpark_connect/relation/map_catalog.py +2 -2
- snowflake/snowpark_connect/relation/map_column_ops.py +227 -145
- snowflake/snowpark_connect/relation/map_crosstab.py +25 -6
- snowflake/snowpark_connect/relation/map_extension.py +81 -56
- snowflake/snowpark_connect/relation/map_join.py +72 -63
- snowflake/snowpark_connect/relation/map_local_relation.py +35 -20
- snowflake/snowpark_connect/relation/map_map_partitions.py +24 -17
- snowflake/snowpark_connect/relation/map_relation.py +22 -16
- snowflake/snowpark_connect/relation/map_row_ops.py +232 -146
- snowflake/snowpark_connect/relation/map_sample_by.py +15 -8
- snowflake/snowpark_connect/relation/map_show_string.py +42 -5
- snowflake/snowpark_connect/relation/map_sql.py +141 -237
- snowflake/snowpark_connect/relation/map_stats.py +88 -39
- snowflake/snowpark_connect/relation/map_subquery_alias.py +13 -14
- snowflake/snowpark_connect/relation/map_udtf.py +10 -13
- snowflake/snowpark_connect/relation/read/map_read.py +8 -3
- snowflake/snowpark_connect/relation/read/map_read_csv.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_json.py +19 -8
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_socket.py +7 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +25 -16
- snowflake/snowpark_connect/relation/read/map_read_text.py +7 -7
- snowflake/snowpark_connect/relation/read/reader_config.py +1 -0
- snowflake/snowpark_connect/relation/utils.py +11 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +15 -12
- snowflake/snowpark_connect/relation/write/map_write.py +259 -56
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +3 -2
- snowflake/snowpark_connect/server.py +43 -4
- snowflake/snowpark_connect/type_mapping.py +6 -23
- snowflake/snowpark_connect/utils/cache.py +27 -22
- snowflake/snowpark_connect/utils/context.py +33 -17
- snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
- snowflake/snowpark_connect/utils/{attribute_handling.py → identifiers.py} +47 -0
- snowflake/snowpark_connect/utils/session.py +41 -38
- snowflake/snowpark_connect/utils/telemetry.py +214 -63
- snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/__init__.py +0 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
- snowflake/snowpark_decoder/dp_session.py +111 -0
- snowflake/snowpark_decoder/spark_decoder.py +76 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/METADATA +6 -4
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/RECORD +83 -69
- snowpark_connect-0.22.1.dist-info/licenses/LICENSE-binary +568 -0
- snowpark_connect-0.22.1.dist-info/licenses/NOTICE-binary +1533 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/top_level.txt +1 -0
- spark/__init__.py +0 -0
- spark/connect/__init__.py +0 -0
- spark/connect/envelope_pb2.py +31 -0
- spark/connect/envelope_pb2.pyi +46 -0
- snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
- {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
snowflake/snowpark_connect/relation/map_crosstab.py

@@ -6,26 +6,41 @@ import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 
 import snowflake.snowpark.functions as fn
 from snowflake import snowpark
-from snowflake.snowpark_connect.
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 
 
 def map_crosstab(
     rel: relation_proto.Relation,
-) ->
+) -> DataFrameContainer:
     """
     Perform a crosstab on the input DataFrame.
     """
-
-
+    input_container = map_relation(rel.crosstab.input)
+    input_df = input_container.dataframe
+
+    col1 = input_container.column_map.get_snowpark_column_name_from_spark_column_name(
         rel.crosstab.col1
     )
-    col2 =
+    col2 = input_container.column_map.get_snowpark_column_name_from_spark_column_name(
         rel.crosstab.col2
     )
     input_df = input_df.select(
         fn.col(col1).cast("string").alias(col1), fn.col(col2).cast("string").alias(col2)
     )
+
+    # Handle empty DataFrame case
+    if input_df.count() == 0:
+        # For empty DataFrame, return a DataFrame with just the first column name
+        result = input_df.select(
+            fn.lit(f"{rel.crosstab.col1}_{rel.crosstab.col2}").alias("c0")
+        )
+        return DataFrameContainer.create_with_column_mapping(
+            dataframe=result,
+            spark_column_names=[f"{rel.crosstab.col1}_{rel.crosstab.col2}"],
+            snowpark_column_names=["c0"],
+        )
+
     result: snowpark.DataFrame = input_df.crosstab(col1, col2)
     new_columns = [f"{rel.crosstab.col1}_{rel.crosstab.col2}"] + [
         (
@@ -45,4 +60,8 @@ def map_crosstab(
     result = result.rename(
         dict(zip(result.columns, [f"c{i}" for i in range(len(result.columns))]))
    )
-    return
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=new_columns,
+        snowpark_column_names=result.columns,
+    )
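Every hunk in this release follows the same refactor: relation mappers that previously returned a bare snowpark.DataFrame, with the Spark-to-Snowpark column mapping attached out of band via with_column_map, now return a DataFrameContainer that bundles the dataframe and its column map into one value. Below is a minimal self-contained sketch of that pattern; the real class lives in snowflake/snowpark_connect/dataframe_container.py and carries more state (qualifiers, metadata, a cached schema), so every name and signature here is illustrative only:

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ColumnMap:
    # Parallel lists: spark_names[i] is the name the Spark client sees,
    # snowpark_names[i] is the physical Snowflake column backing it.
    spark_names: list[str]
    snowpark_names: list[str]

    def get_snowpark_column_name_from_spark_column_name(self, spark_name: str) -> str:
        return self.snowpark_names[self.spark_names.index(spark_name)]


@dataclass
class DataFrameContainer:
    dataframe: object  # a snowpark.DataFrame in the real code
    column_map: ColumnMap
    alias: str | None = None

    @classmethod
    def create_with_column_mapping(cls, dataframe, spark_column_names, snowpark_column_names, **_extra):
        # _extra stands in for the real keyword arguments seen in the hunks
        # (column_qualifiers, snowpark_column_types, column_metadata, ...).
        return cls(dataframe, ColumnMap(list(spark_column_names), list(snowpark_column_names)))


# A mapper now returns one value that keeps the frame and its name mapping in sync:
container = DataFrameContainer.create_with_column_mapping(
    dataframe="<snowpark.DataFrame>",
    spark_column_names=["id_dept", "c1"],
    snowpark_column_names=["c0", "c1"],
)
assert container.column_map.get_snowpark_column_name_from_spark_column_name("id_dept") == "c0"

Bundling the two removes the failure mode where a dataframe and its separately-tracked column map drift apart as relations are composed.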
snowflake/snowpark_connect/relation/map_extension.py

@@ -14,26 +14,29 @@ from snowflake import snowpark
 from snowflake.snowpark_connect.column_name_handler import (
     ColumnNameMap,
     make_column_names_snowpark_compatible,
-    with_column_map,
 )
 from snowflake.snowpark_connect.config import get_boolean_session_config_param
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.expression.map_expression import map_expression
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.typed_column import TypedColumn
-from snowflake.snowpark_connect.utils.attribute_handling import (
-    split_fully_qualified_spark_name,
-)
 from snowflake.snowpark_connect.utils.context import (
     get_sql_aggregate_function_count,
     push_outer_dataframe,
+    set_current_grouping_columns,
+)
+from snowflake.snowpark_connect.utils.identifiers import (
+    split_fully_qualified_spark_name,
 )
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
 
 
-def map_extension(
+def map_extension(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     The Extension relation type contains any extensions we use for adding new
     functionality to Spark Connect.
@@ -46,7 +49,8 @@ def map_extension(rel: relation_proto.Relation) -> snowpark.DataFrame:
     match extension.WhichOneof("op"):
         case "rdd_map":
             rdd_map = extension.rdd_map
-
+            result = map_relation(rdd_map.input)
+            input_df = result.dataframe
 
             column_name = "_RDD_"
             if len(input_df.columns) > 1:
@@ -67,32 +71,39 @@ def map_extension(rel: relation_proto.Relation) -> snowpark.DataFrame:
                 replace=True,
             )
             result = input_df.select(func(column_name).as_(column_name))
-            return
+            return DataFrameContainer.create_with_column_mapping(
+                dataframe=result,
+                spark_column_names=[column_name],
+                snowpark_column_names=[column_name],
+                snowpark_column_types=[return_type],
+            )
         case "subquery_column_aliases":
             subquery_aliases = extension.subquery_column_aliases
             rel.extension.Unpack(subquery_aliases)
-
-
+            result = map_relation(subquery_aliases.input)
+            input_df = result.dataframe
+            snowpark_col_names = result.column_map.get_snowpark_columns()
             if len(subquery_aliases.aliases) != len(snowpark_col_names):
                 raise AnalysisException(
                     "Number of column aliases does not match number of columns. "
                     f"Number of column aliases: {len(subquery_aliases.aliases)}; "
                     f"number of columns: {len(snowpark_col_names)}."
                 )
-            return
-                input_df,
-                subquery_aliases.aliases,
-                snowpark_col_names,
-                column_qualifiers=
+            return DataFrameContainer.create_with_column_mapping(
+                dataframe=input_df,
+                spark_column_names=subquery_aliases.aliases,
+                snowpark_column_names=snowpark_col_names,
+                column_qualifiers=result.column_map.get_qualifiers(),
             )
         case "lateral_join":
             lateral_join = extension.lateral_join
-
+            left_result = map_relation(lateral_join.left)
+            left_df = left_result.dataframe
 
             udtf_info = get_udtf_project(lateral_join.right)
             if udtf_info:
                 return handle_lateral_join_with_udtf(
-
+                    left_result, lateral_join.right, udtf_info
                 )
 
             left_queries = left_df.queries["queries"]
@@ -101,8 +112,9 @@ def map_extension(rel: relation_proto.Relation) -> snowpark.DataFrame:
                     f"Unexpected number of queries: {len(left_queries)}"
                 )
             left_query = left_queries[0]
-            with push_outer_dataframe(
-
+            with push_outer_dataframe(left_result):
+                right_result = map_relation(lateral_join.right)
+                right_df = right_result.dataframe
             right_queries = right_df.queries["queries"]
             if len(right_queries) != 1:
                 raise SnowparkConnectNotImplementedError(
@@ -112,14 +124,14 @@ def map_extension(rel: relation_proto.Relation) -> snowpark.DataFrame:
             input_df_sql = f"WITH __left AS ({left_query}) SELECT * FROM __left INNER JOIN LATERAL ({right_query})"
             session = snowpark.Session.get_active_session()
             input_df = session.sql(input_df_sql)
-            return
-                input_df,
-
-                +
-
-                +
-                column_qualifiers=
-                +
+            return DataFrameContainer.create_with_column_mapping(
+                dataframe=input_df,
+                spark_column_names=left_result.column_map.get_spark_columns()
+                + right_result.column_map.get_spark_columns(),
+                snowpark_column_names=left_result.column_map.get_snowpark_columns()
+                + right_result.column_map.get_snowpark_columns(),
+                column_qualifiers=left_result.column_map.get_qualifiers()
+                + right_result.column_map.get_qualifiers(),
             )
 
         case "udtf_with_table_arguments":
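The lateral_join case stitches the two child plans together at the SQL level rather than through DataFrame joins: each side must compile to exactly one query, the left query becomes a CTE, and the right query is evaluated per left row via INNER JOIN LATERAL. A sketch of the query shape taken from the f-string in the context lines above; both query strings here are made up for illustration:

def build_lateral_join_sql(left_query: str, right_query: str) -> str:
    # Same shape as the f-string in the lateral_join case: the left plan
    # becomes a CTE named __left, and the right plan runs once per __left row.
    return f"WITH __left AS ({left_query}) SELECT * FROM __left INNER JOIN LATERAL ({right_query})"


left = "SELECT id, tags FROM events"
right = "SELECT value FROM TABLE(FLATTEN(input => __left.tags))"
print(build_lateral_join_sql(left, right))
# WITH __left AS (SELECT id, tags FROM events) SELECT * FROM __left
#   INNER JOIN LATERAL (SELECT value FROM TABLE(FLATTEN(input => __left.tags)))

This is also why the code insists on exactly one query per side: a multi-query plan could not be inlined into a single CTE.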
@@ -165,13 +177,13 @@ def handle_udtf_with_table_arguments(
         raise ValueError(f"UDTF '{udtf_info.function_name}' not found.")
     _udtf_obj, udtf_spark_output_names = session._udtfs[udtf_name_lower]
 
-
+    table_containers = []
     for table_arg_info in udtf_info.table_arguments:
-
-
+        result = map_relation(table_arg_info.table_argument)
+        table_containers.append((result, table_arg_info.table_argument_idx))
 
-    if len(
-        base_df =
+    if len(table_containers) == 1:
+        base_df = table_containers[0][0].dataframe
     else:
         if not get_boolean_session_config_param(
             "spark.sql.tvf.allowMultipleTableArguments.enabled"
@@ -181,11 +193,11 @@ def handle_udtf_with_table_arguments(
                 "Please set `spark.sql.tvf.allowMultipleTableArguments.enabled` to `true`"
             )
 
-        base_df =
+        base_df = table_containers[0][0].dataframe
         first_table_col_count = len(base_df.columns)
 
-        for
-            base_df = base_df.cross_join(
+        for table_container, _ in table_containers[1:]:
+            base_df = base_df.cross_join(table_container.dataframe)
 
         # Ensure deterministic ordering to match Spark's Cartesian product behavior
         # For two tables A and B, Spark produces: for each B row, iterate through A rows
@@ -206,9 +218,9 @@ def handle_udtf_with_table_arguments(
         scalar_args.append(typed_column.col)
 
     table_arg_variants = []
-    for
-        table_columns =
-        spark_columns =
+    for table_container, table_arg_idx in table_containers:
+        table_columns = table_container.column_map.get_snowpark_columns()
+        spark_columns = table_container.column_map.get_spark_columns()
 
         # Create a structure that supports both positional and named access
        # Format: {"__fields__": ["col1", "col2"], "__values__": [val1, val2]}
@@ -247,15 +259,15 @@ def handle_udtf_with_table_arguments(
 
     final_df = result_df.select(*udtf_output_columns)
 
-    return
-        final_df,
-        udtf_spark_output_names,
-        udtf_output_columns,
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=final_df,
+        spark_column_names=udtf_spark_output_names,
+        snowpark_column_names=udtf_output_columns,
     )
 
 
 def handle_lateral_join_with_udtf(
-
+    left_result: DataFrameContainer,
     udtf_relation: relation_proto.Relation,
     udtf_info: tuple[snowpark.udtf.UserDefinedTableFunction, list],
 ) -> snowpark.DataFrame:
@@ -269,7 +281,8 @@ def handle_lateral_join_with_udtf(
     _udtf_obj, udtf_spark_output_names = udtf_info
 
     typer = ExpressionTyper.dummy_typer(session)
-    left_column_map =
+    left_column_map = left_result.column_map
+    left_df = left_result.dataframe
     table_func = snowpark_fn.table_function(_udtf_obj.name)
     udtf_args = [
         map_expression(arg_proto, left_column_map, typer)[1].col
@@ -278,11 +291,12 @@ def handle_lateral_join_with_udtf(
     udtf_args_variant = [snowpark_fn.to_variant(arg) for arg in udtf_args]
     result_df = left_df.join_table_function(table_func(*udtf_args_variant))
 
-    return
-        result_df,
-
-
-
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result_df,
+        spark_column_names=left_result.column_map.get_spark_columns()
+        + udtf_spark_output_names,
+        snowpark_column_names=result_df.columns,
+        column_qualifiers=left_result.column_map.get_qualifiers()
         + [[]] * len(udtf_spark_output_names),
     )
 
@@ -290,7 +304,8 @@ def handle_lateral_join_with_udtf(
 def map_aggregate(
     aggregate: snowflake_proto.Aggregate, plan_id: int
 ) -> snowpark.DataFrame:
-
+    input_container = map_relation(aggregate.input)
+    input_df: snowpark.DataFrame = input_container.dataframe
 
     # Detect the "GROUP BY ALL" case:
     # - it's a plain GROUP BY (not ROLLUP, CUBE, etc.)
@@ -307,7 +322,7 @@ def map_aggregate(
         if (
             len(parsed_col_name) == 1
             and parsed_col_name[0].lower() == "all"
-            and
+            and input_container.column_map.get_snowpark_column_name_from_spark_column_name(
                 parsed_col_name[0], allow_non_exists=True
             )
             is None
@@ -320,7 +335,9 @@ def map_aggregate(
     typer = ExpressionTyper(input_df)
 
     def _map_column(exp: expression_proto.Expression) -> tuple[str, TypedColumn]:
-        new_names, snowpark_column = map_expression(
+        new_names, snowpark_column = map_expression(
+            exp, input_container.column_map, typer
+        )
         if len(new_names) != 1:
             raise SnowparkConnectNotImplementedError(
                 "Multi-column aggregate expressions are not supported"
@@ -345,6 +362,10 @@ def map_aggregate(
     if not is_group_by_all:
         raw_groupings = [_map_column(exp) for exp in aggregate.grouping_expressions]
 
+        # Set the current grouping columns in context for grouping_id() function
+        grouping_spark_columns = [spark_name for spark_name, _ in raw_groupings]
+        set_current_grouping_columns(grouping_spark_columns)
+
     # Now create column name lists and assign aliases.
     # In case of GROUP BY ALL, even though groupings are a subset of aggregations,
     # they will have their own aliases so we can drop them later.
@@ -378,7 +399,7 @@ def map_aggregate(
         # TODO: What do we do about groupings?
         sets = (
             [
-                map_expression(exp,
+                map_expression(exp, input_container.column_map, typer)[1].col
                 for exp in grouping_sets.grouping_set
             ]
             for grouping_sets in aggregate.grouping_sets
@@ -397,16 +418,20 @@ def map_aggregate(
         result = result.select(result.columns[-len(spark_columns) :])
 
     # Build a parent column map that includes groupings.
-
-        result,
+    result_container = DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=spark_columns,
+        snowpark_column_names=snowpark_columns,
+        snowpark_column_types=snowpark_column_types,
     )
 
     # Drop the groupings.
     grouping_count = len(groupings)
-
+
+    return DataFrameContainer.create_with_column_mapping(
         result.drop(snowpark_columns[:grouping_count]),
         spark_columns[grouping_count:],
         snowpark_columns[grouping_count:],
         snowpark_column_types[grouping_count:],
-        parent_column_name_map=
+        parent_column_name_map=result_container.column_map,
     )
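A behavioral detail worth noting in the map_aggregate hunks: a grouping expression is treated as Spark's GROUP BY ALL only when it is a single unqualified identifier spelled "all" (case-insensitively) and the input has no real column of that name. A sketch of that disambiguation, with a plain membership test standing in for the column_map lookup the real code performs (whose exact case-sensitivity rules are not visible here):

def is_group_by_all(parsed_col_name: list[str], spark_columns: list[str]) -> bool:
    # Mirrors the condition in map_aggregate: one path segment, spelled "all"
    # in any case, that does not resolve to an existing column.
    return (
        len(parsed_col_name) == 1
        and parsed_col_name[0].lower() == "all"
        and parsed_col_name[0] not in spark_columns
    )


print(is_group_by_all(["ALL"], ["id", "amount"]))   # True: GROUP BY ALL semantics
print(is_group_by_all(["all"], ["all", "amount"]))  # False: group by the column named "all"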
snowflake/snowpark_connect/relation/map_join.py

@@ -8,14 +8,10 @@ import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 
 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
-from snowflake.snowpark_connect.column_name_handler import
-    ColumnNameMap,
-    JoinColumnNameMap,
-    set_schema_getter,
-    with_column_map,
-)
+from snowflake.snowpark_connect.column_name_handler import JoinColumnNameMap
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.constants import COLUMN_METADATA_COLLISION_KEY
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.error.error_utils import SparkException
 from snowflake.snowpark_connect.expression.map_expression import (
     map_single_column_expression,
@@ -38,15 +34,18 @@ from snowflake.snowpark_connect.utils.telemetry import (
 USING_COLUMN_NOT_FOUND_ERROR = "[UNRESOLVED_USING_COLUMN_FOR_JOIN] USING column `{0}` not found on the {1} side of the join. The {1}-side columns: {2}"
 
 
-def map_join(rel: relation_proto.Relation) ->
-
-
+def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
+    left_container: DataFrameContainer = map_relation(rel.join.left)
+    right_container: DataFrameContainer = map_relation(rel.join.right)
+
+    left_input: snowpark.DataFrame = left_container.dataframe
+    right_input: snowpark.DataFrame = right_container.dataframe
     is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
     using_columns = rel.join.using_columns
     if is_natural_join:
         rel.join.join_type -= NATURAL_JOIN_TYPE_BASE
-        left_spark_columns =
-        right_spark_columns =
+        left_spark_columns = left_container.column_map.get_spark_columns()
+        right_spark_columns = right_container.column_map.get_spark_columns()
         common_spark_columns = [
             x for x in left_spark_columns if x in right_spark_columns
         ]
@@ -79,8 +78,8 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
     if rel.join.HasField("join_condition"):
         assert not using_columns
 
-        left_columns = list(
-        right_columns = list(
+        left_columns = list(left_container.column_map.spark_to_col.keys())
+        right_columns = list(right_container.column_map.spark_to_col.keys())
 
         # All PySpark join types are in the format of JOIN_TYPE_XXX.
         # We remove the first 10 characters (JOIN_TYPE_) and replace all underscores with spaces to match the exception.
@@ -90,15 +89,15 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
         with push_sql_scope(), push_evaluating_join_condition(
             pyspark_join_type, left_columns, right_columns
         ):
-            if
-                set_sql_plan_name(
-            if
-                set_sql_plan_name(
+            if left_container.alias is not None:
+                set_sql_plan_name(left_container.alias, rel.join.left.common.plan_id)
+            if right_container.alias is not None:
+                set_sql_plan_name(right_container.alias, rel.join.right.common.plan_id)
             _, join_expression = map_single_column_expression(
                 rel.join.join_condition,
                 column_mapping=JoinColumnNameMap(
-
-
+                    left_container.column_map,
+                    right_container.column_map,
                 ),
                 typer=JoinExpressionTyper(left_input, right_input),
             )
@@ -111,7 +110,7 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
         )
     elif using_columns:
         if any(
-
+            left_container.column_map.get_snowpark_column_name_from_spark_column_name(
                 c, allow_non_exists=True, return_first=True
             )
             is None
@@ -124,17 +123,17 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
                 next(
                     c
                     for c in using_columns
-                    if
+                    if left_container.column_map.get_snowpark_column_name_from_spark_column_name(
                         c, allow_non_exists=True, return_first=True
                     )
                     is None
                ),
                 "left",
-
+                left_container.column_map.get_spark_columns(),
             )
         )
         if any(
-
+            right_container.column_map.get_snowpark_column_name_from_spark_column_name(
                 c, allow_non_exists=True, return_first=True
             )
             is None
@@ -147,26 +146,26 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
                 next(
                     c
                     for c in using_columns
-                    if
+                    if right_container.column_map.get_snowpark_column_name_from_spark_column_name(
                         c, allow_non_exists=True, return_first=True
                     )
                     is None
                 ),
                 "right",
-
+                right_container.column_map.get_spark_columns(),
             )
         )
 
         # Round trip the using columns through the column map to get the correct names
         # in order to support case sensitivity.
         # TODO: case_corrected_left_columns / case_corrected_right_columns may no longer be required as Snowpark dataframe preserves the column casing now.
-        case_corrected_left_columns =
-
+        case_corrected_left_columns = left_container.column_map.get_spark_column_names_from_snowpark_column_names(
+            left_container.column_map.get_snowpark_column_names_from_spark_column_names(
                 list(using_columns), return_first=True
             )
         )
-        case_corrected_right_columns =
-
+        case_corrected_right_columns = right_container.column_map.get_spark_column_names_from_snowpark_column_names(
+            right_container.column_map.get_snowpark_column_names_from_spark_column_names(
                 list(using_columns), return_first=True
            )
         )
@@ -177,12 +176,12 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
         snowpark_using_columns = [
             (
                 left_input[
-
+                    left_container.column_map.get_snowpark_column_name_from_spark_column_name(
                         lft, return_first=True
                     )
                 ],
                 right_input[
-
+                    right_container.column_map.get_snowpark_column_name_from_spark_column_name(
                         r, return_first=True
                     )
                 ],
@@ -231,45 +230,49 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
         # - LEFT SEMI JOIN: Returns left rows that have matches in right table (no right columns)
        # - LEFT ANTI JOIN: Returns left rows that have NO matches in right table (no right columns)
         # Both preserve only the columns from the left DataFrame without adding any columns from the right.
-        spark_cols_after_join: list[str] =
-        qualifiers =
+        spark_cols_after_join: list[str] = left_container.column_map.get_spark_columns()
+        qualifiers = left_container.column_map.get_qualifiers()
     else:
         # Add Spark columns and plan_ids from left DF
         spark_cols_after_join: list[str] = list(
-
+            left_container.column_map.get_spark_columns()
         ) + [
             spark_col
-            for i, spark_col in enumerate(
+            for i, spark_col in enumerate(
+                right_container.column_map.get_spark_columns()
+            )
             if spark_col not in case_corrected_right_columns
             or spark_col
-            in
+            in right_container.column_map.get_spark_columns()[
                 :i
             ] # this is to make sure we only remove the column once
         ]
 
-        qualifiers = list(
-
-            for i, spark_col in enumerate(
+        qualifiers = list(left_container.column_map.get_qualifiers()) + [
+            right_container.column_map.get_qualifier_for_spark_column(spark_col)
+            for i, spark_col in enumerate(
+                right_container.column_map.get_spark_columns()
+            )
            if spark_col not in case_corrected_right_columns
             or spark_col
-            in
+            in right_container.column_map.get_spark_columns()[
                 :i
             ] # this is to make sure we only remove the column once]
         ]
 
     column_metadata = {}
-    if
-        column_metadata.update(
+    if left_container.column_map.column_metadata:
+        column_metadata.update(left_container.column_map.column_metadata)
 
-    if
-        for key, value in
+    if right_container.column_map.column_metadata:
+        for key, value in right_container.column_map.column_metadata.items():
             if key not in column_metadata:
                 column_metadata[key] = value
             else:
                 # In case of collision, use snowpark's column's expr_id as prefix.
                 # this is a temporary solution until SNOW-1926440 is resolved.
                 try:
-                    snowpark_name =
+                    snowpark_name = right_container.column_map.get_snowpark_column_name_from_spark_column_name(
                         key
                     )
                     expr_id = right_input[snowpark_name]._expression.expr_id
@@ -281,10 +284,10 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
                     # ignore any errors that happens while fetching the metadata
                     pass
 
-
-        result,
-        spark_cols_after_join,
-        result.columns,
+    result_container = DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=spark_cols_after_join,
+        snowpark_column_names=result.columns,
         column_metadata=column_metadata,
         column_qualifiers=qualifiers,
     )
@@ -298,7 +301,7 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
         and rel.join.right.common.HasField("plan_id")
     ):
         right_plan_id = rel.join.right.common.plan_id
-        set_plan_id_map(right_plan_id,
+        set_plan_id_map(right_plan_id, result_container)
 
     # For FULL OUTER joins, we also need to map the left dataframe's plan_id
     # since both columns are replaced with a coalesced column
@@ -309,7 +312,7 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
         and rel.join.left.common.HasField("plan_id")
     ):
         left_plan_id = rel.join.left.common.plan_id
-        set_plan_id_map(left_plan_id,
+        set_plan_id_map(left_plan_id, result_container)
 
     if rel.join.using_columns:
         # When join 'using_columns', the 'join columns' should go first in result DF.
@@ -323,19 +326,25 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
             remaining = [el for i, el in enumerate(lst) if i not in idxs_to_shift]
             return to_move + remaining
 
-
-
+        # Create reordered DataFrame
+        reordered_df = result_container.dataframe.select(
+            [snowpark_fn.col(c) for c in reorder(result_container.dataframe.columns)]
         )
-
-
-
+
+        # Create new container with reordered metadata
+        original_df = result_container.dataframe
+        return DataFrameContainer.create_with_column_mapping(
+            dataframe=reordered_df,
+            spark_column_names=reorder(result_container.column_map.get_spark_columns()),
+            snowpark_column_names=reorder(
+                result_container.column_map.get_snowpark_columns()
+            ),
             column_metadata=column_metadata,
             column_qualifiers=reorder(qualifiers),
+            table_name=result_container.table_name,
+            cached_schema_getter=lambda: snowpark.types.StructType(
+                reorder(original_df.schema.fields)
+            ),
         )
-
-
-            reordered_df,
-            lambda: snowpark.types.StructType(reorder(result_df.schema.fields)),
-        )
-        return reordered_df
-    return result_df
+
+    return result_container