snowpark-connect 0.20.2__py3-none-any.whl → 0.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of snowpark-connect has been flagged as potentially problematic.
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +3 -2
- snowflake/snowpark_connect/column_name_handler.py +6 -65
- snowflake/snowpark_connect/config.py +47 -17
- snowflake/snowpark_connect/dataframe_container.py +242 -0
- snowflake/snowpark_connect/error/error_utils.py +25 -0
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +13 -23
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +9 -5
- snowflake/snowpark_connect/expression/map_extension.py +2 -1
- snowflake/snowpark_connect/expression/map_udf.py +4 -4
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +8 -7
- snowflake/snowpark_connect/expression/map_unresolved_function.py +481 -170
- snowflake/snowpark_connect/expression/map_unresolved_star.py +8 -8
- snowflake/snowpark_connect/expression/map_update_fields.py +1 -1
- snowflake/snowpark_connect/expression/typer.py +6 -6
- snowflake/snowpark_connect/proto/control_pb2.py +17 -16
- snowflake/snowpark_connect/proto/control_pb2.pyi +17 -17
- snowflake/snowpark_connect/proto/control_pb2_grpc.py +12 -63
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +15 -14
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +19 -14
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +27 -26
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +74 -68
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +5 -5
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +25 -17
- snowflake/snowpark_connect/relation/map_aggregate.py +170 -61
- snowflake/snowpark_connect/relation/map_catalog.py +2 -2
- snowflake/snowpark_connect/relation/map_column_ops.py +227 -145
- snowflake/snowpark_connect/relation/map_crosstab.py +25 -6
- snowflake/snowpark_connect/relation/map_extension.py +81 -56
- snowflake/snowpark_connect/relation/map_join.py +72 -63
- snowflake/snowpark_connect/relation/map_local_relation.py +35 -20
- snowflake/snowpark_connect/relation/map_map_partitions.py +24 -17
- snowflake/snowpark_connect/relation/map_relation.py +22 -16
- snowflake/snowpark_connect/relation/map_row_ops.py +232 -146
- snowflake/snowpark_connect/relation/map_sample_by.py +15 -8
- snowflake/snowpark_connect/relation/map_show_string.py +42 -5
- snowflake/snowpark_connect/relation/map_sql.py +141 -237
- snowflake/snowpark_connect/relation/map_stats.py +88 -39
- snowflake/snowpark_connect/relation/map_subquery_alias.py +13 -14
- snowflake/snowpark_connect/relation/map_udtf.py +10 -13
- snowflake/snowpark_connect/relation/read/map_read.py +8 -3
- snowflake/snowpark_connect/relation/read/map_read_csv.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_json.py +19 -8
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_socket.py +7 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +25 -16
- snowflake/snowpark_connect/relation/read/map_read_text.py +7 -7
- snowflake/snowpark_connect/relation/read/reader_config.py +1 -0
- snowflake/snowpark_connect/relation/utils.py +11 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +15 -12
- snowflake/snowpark_connect/relation/write/map_write.py +259 -56
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +3 -2
- snowflake/snowpark_connect/server.py +43 -4
- snowflake/snowpark_connect/type_mapping.py +6 -23
- snowflake/snowpark_connect/utils/cache.py +27 -22
- snowflake/snowpark_connect/utils/context.py +33 -17
- snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
- snowflake/snowpark_connect/utils/{attribute_handling.py → identifiers.py} +47 -0
- snowflake/snowpark_connect/utils/session.py +41 -38
- snowflake/snowpark_connect/utils/telemetry.py +214 -63
- snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/__init__.py +0 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
- snowflake/snowpark_decoder/dp_session.py +111 -0
- snowflake/snowpark_decoder/spark_decoder.py +76 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/METADATA +6 -4
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/RECORD +83 -69
- snowpark_connect-0.22.1.dist-info/licenses/LICENSE-binary +568 -0
- snowpark_connect-0.22.1.dist-info/licenses/NOTICE-binary +1533 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/top_level.txt +1 -0
- spark/__init__.py +0 -0
- spark/connect/__init__.py +0 -0
- spark/connect/envelope_pb2.py +31 -0
- spark/connect/envelope_pb2.pyi +46 -0
- snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
- {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
snowflake/snowpark_connect/relation/map_aggregate.py

```diff
@@ -4,15 +4,19 @@

 import re
 from dataclasses import dataclass
+from typing import Optional

 import pyspark.sql.connect.proto.relations_pb2 as relation_proto

+import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
+from snowflake.snowpark import Column
+from snowflake.snowpark._internal.analyzer.unary_expression import Alias
 from snowflake.snowpark.types import DataType
 from snowflake.snowpark_connect.column_name_handler import (
     make_column_names_snowpark_compatible,
-    with_column_map,
 )
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
 from snowflake.snowpark_connect.expression.map_expression import (
     map_single_column_expression,
@@ -20,111 +24,171 @@ from snowflake.snowpark_connect.expression.map_expression import (
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.typed_column import TypedColumn
-from snowflake.snowpark_connect.utils.context import
+from snowflake.snowpark_connect.utils.context import (
+    get_is_evaluating_sql,
+    set_current_grouping_columns,
+    temporary_pivot_expression,
+)


-def map_group_by_aggregate(
+def map_group_by_aggregate(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     Groups the DataFrame using the specified columns.

     Aggregations come in as expressions, which are mapped to `snowpark.Column`
     objects.
     """
-
+    input_df_container, columns = map_aggregate_helper(rel)
+    input_df_actual = input_df_container.dataframe
+
     if len(columns.grouping_expressions()) == 0:
-        result =
+        result = input_df_actual.agg(*columns.aggregation_expressions())
     else:
-        result =
+        result = input_df_actual.group_by(*columns.grouping_expressions()).agg(
             *columns.aggregation_expressions()
         )
-    return
-        result,
-        columns.spark_names(),
-        columns.snowpark_names(),
-        columns.data_types(),
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=columns.spark_names(),
+        snowpark_column_names=columns.snowpark_names(),
+        snowpark_column_types=columns.data_types(),
         column_qualifiers=columns.get_qualifiers(),
-        parent_column_name_map=
+        parent_column_name_map=input_df_container.column_map,
     )


-def map_rollup_aggregate(
+def map_rollup_aggregate(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     Create a multidimensional rollup for the current DataFrame using the specified columns.

     Aggregations come in as expressions, which are mapped to `snowpark.Column`
     objects.
     """
-
+    input_container, columns = map_aggregate_helper(rel)
+    input_df_actual = input_container.dataframe
+
     if len(columns.grouping_expressions()) == 0:
-        result =
+        result = input_df_actual.agg(*columns.aggregation_expressions())
     else:
-        result =
+        result = input_df_actual.rollup(*columns.grouping_expressions()).agg(
             *columns.aggregation_expressions()
         )
-    return
-        result,
-        columns.spark_names(),
-        columns.snowpark_names(),
-        columns.data_types(),
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=columns.spark_names(),
+        snowpark_column_names=columns.snowpark_names(),
+        snowpark_column_types=columns.data_types(),
         column_qualifiers=columns.get_qualifiers(),
-        parent_column_name_map=
+        parent_column_name_map=input_container.column_map,
     )


-def map_cube_aggregate(
+def map_cube_aggregate(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     Create a multidimensional cube for the current DataFrame using the specified columns.

     Aggregations come in as expressions, which are mapped to `snowpark.Column`
     objects.
     """
-
+    input_container, columns = map_aggregate_helper(rel)
+    input_df_actual = input_container.dataframe
+
     if len(columns.grouping_expressions()) == 0:
-        result =
+        result = input_df_actual.agg(*columns.aggregation_expressions())
     else:
-        result =
+        result = input_df_actual.cube(*columns.grouping_expressions()).agg(
             *columns.aggregation_expressions()
         )
-    return
-        result,
-        columns.spark_names(),
-        columns.snowpark_names(),
-        columns.data_types(),
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=columns.spark_names(),
+        snowpark_column_names=columns.snowpark_names(),
+        snowpark_column_types=columns.data_types(),
         column_qualifiers=columns.get_qualifiers(),
-        parent_column_name_map=
+        parent_column_name_map=input_container.column_map,
     )


-def map_pivot_aggregate(
+def map_pivot_aggregate(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     Pivots a column of the current DataFrame and performs the specified aggregation.

     There are 2 versions of the pivot function: one that requires the caller to specify the list of the distinct values
     to pivot on and one that does not.
     """
-
+    input_container, columns = map_aggregate_helper(rel, pivot=True, skip_alias=True)
+    input_df_actual = input_container.dataframe
+
     pivot_column = map_single_column_expression(
-        rel.aggregate.pivot.col,
+        rel.aggregate.pivot.col,
+        input_container.column_map,
+        ExpressionTyper(input_df_actual),
     )
     pivot_values = [
         get_literal_field_and_name(lit)[0] for lit in rel.aggregate.pivot.values
     ]

+    used_columns = {pivot_column[1].col._expression.name}
+    if get_is_evaluating_sql():
+        # When evaluating SQL spark doesn't trim columns from the result
+        used_columns = {"*"}
+    else:
+        for expression in rel.aggregate.aggregate_expressions:
+            matched_identifiers = re.findall(
+                r'unparsed_identifier: "(.*)"', expression.__str__()
+            )
+            for identifier in matched_identifiers:
+                mapped_col = input_container.column_map.spark_to_col.get(
+                    identifier, None
+                )
+                if mapped_col:
+                    used_columns.add(mapped_col[0].snowpark_name)
+
     if len(columns.grouping_expressions()) == 0:
-        result =
-
-
+        result = (
+            input_df_actual.select(*used_columns)
+            .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
+            .agg(*columns.aggregation_expressions(unalias=True))
+        )
     else:
         result = (
-
+            input_df_actual.group_by(*columns.grouping_expressions())
             .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-            .agg(*columns.aggregation_expressions())
+            .agg(*columns.aggregation_expressions(unalias=True))
         )

+    agg_name_list = [c.spark_name for c in columns.grouping_columns]
+
+    # Calculate number of pivot values for proper Spark-compatible indexing
+    total_pivot_columns = len(result.columns) - len(agg_name_list)
+    num_pivot_values = (
+        total_pivot_columns // len(columns.aggregation_columns)
+        if len(columns.aggregation_columns) > 0
+        else 1
+    )
+
+    def _get_agg_exp_alias_for_col(col_index: int) -> Optional[str]:
+        if col_index < len(agg_name_list) or len(columns.aggregation_columns) <= 1:
+            return None
+        else:
+            index = (col_index - len(agg_name_list)) // num_pivot_values
+            return columns.aggregation_columns[index].spark_name
+
     spark_columns = []
-    for col in [
+    for col in [
+        pivot_column_name(c, _get_agg_exp_alias_for_col(i))
+        for i, c in enumerate(result.columns)
+    ]:
         spark_col = (
-
+            input_container.column_map.get_spark_column_name_from_snowpark_column_name(
                 col, allow_non_exists=True
             )
         )
```
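The recurring change in the hunk above is that `map_group_by_aggregate`, `map_rollup_aggregate`, `map_cube_aggregate`, and `map_pivot_aggregate` stop returning loose tuples of (result, Spark names, Snowpark names, types, ...) and instead return a `DataFrameContainer` built via `create_with_column_mapping`. A minimal sketch of the idea, using simplified stand-ins (`ColumnMapSketch` and `DataFrameContainerSketch` are illustrative, not the package's real implementation in `dataframe_container.py`):

```python
from dataclasses import dataclass, field


@dataclass
class ColumnMapSketch:
    # hypothetical stand-in: maps Spark-visible names to Snowpark names
    spark_to_snowpark: dict = field(default_factory=dict)


@dataclass
class DataFrameContainerSketch:
    # hypothetical stand-in for DataFrameContainer: a Snowpark DataFrame
    # bundled with the column mapping the mapper functions now thread through
    dataframe: object
    column_map: ColumnMapSketch

    @classmethod
    def create_with_column_mapping(cls, dataframe, spark_column_names,
                                   snowpark_column_names, **_extras):
        mapping = dict(zip(spark_column_names, snowpark_column_names))
        return cls(dataframe, ColumnMapSketch(mapping))
```

Carrying the mapping inside the container is what lets the pivot code above consult `input_container.column_map` without threading four parallel values through every call site.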
map_aggregate.py (continued):

```diff
@@ -132,22 +196,57 @@ def map_pivot_aggregate(rel: relation_proto.Relation) -> snowpark.DataFrame:
         if spark_col is not None:
             spark_columns.append(spark_col)
         else:
-
+            # Handle NULL column names to match Spark behavior (lowercase 'null')
+            if col == "NULL":
+                spark_columns.append(col.lower())
+            else:
+                spark_columns.append(col)
+
+    grouping_cols_count = len(agg_name_list)
+    pivot_cols = result.columns[grouping_cols_count:]
+    spark_pivot_cols = spark_columns[grouping_cols_count:]
+
+    num_agg_functions = len(columns.aggregation_columns)
+    num_pivot_values = len(pivot_cols) // num_agg_functions
+
+    reordered_snowpark_cols = []
+    reordered_spark_cols = []
+    column_indices = []  # 1-based indexing
+
+    for i in range(grouping_cols_count):
+        reordered_snowpark_cols.append(result.columns[i])
+        reordered_spark_cols.append(spark_columns[i])
+        column_indices.append(i + 1)
+
+    for pivot_idx in range(num_pivot_values):
+        for agg_idx in range(num_agg_functions):
+            current_pos = agg_idx * num_pivot_values + pivot_idx
+            if current_pos < len(pivot_cols):
+                reordered_snowpark_cols.append(pivot_cols[current_pos])
+                reordered_spark_cols.append(spark_pivot_cols[current_pos])
+                original_index = grouping_cols_count + current_pos
+                column_indices.append(original_index + 1)
+
+    reordered_result = result.select(
+        *[snowpark_fn.col(f"${idx}") for idx in column_indices]
+    )

-
-
-
-
-        result.columns,
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=reordered_result,
+        spark_column_names=reordered_spark_cols,
+        snowpark_column_names=[f"${idx}" for idx in column_indices],
         column_qualifiers=(
             columns.get_qualifiers()[: len(agg_name_list)]
-            + [[]] * (len(
+            + [[]] * (len(reordered_spark_cols) - len(agg_name_list))
         ),
-        parent_column_name_map=
+        parent_column_name_map=input_container.column_map,
+        snowpark_column_types=[
+            result.schema.fields[idx - 1].datatype for idx in column_indices
+        ],
     )


-def
+def pivot_column_name(snowpark_cname, opt_alias: Optional[str] = None) -> Optional[str]:
     # For values that are used as pivoted columns, the input and output are in the following format (outermost double quotes are part of the input):

     # 1. "'Java'" -> Java
```
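The reordering block in this hunk exists because Snowpark's `pivot(...).agg(...)` emits the pivoted columns grouped by aggregation function (aggregation-major), while Spark groups them by pivot value (value-major). A standalone sketch of the same index arithmetic, with hypothetical column names:

```python
# Hypothetical Snowpark-style output for 2 aggregations over pivot values
# 2023 and 2024: aggregation-major order.
pivot_cols = ["2023_sum", "2024_sum", "2023_avg", "2024_avg"]
num_agg_functions = 2
num_pivot_values = len(pivot_cols) // num_agg_functions  # 2

reordered = []
for pivot_idx in range(num_pivot_values):
    for agg_idx in range(num_agg_functions):
        # same arithmetic as current_pos in the diff above
        reordered.append(pivot_cols[agg_idx * num_pivot_values + pivot_idx])

print(reordered)  # ['2023_sum', '2023_avg', '2024_sum', '2024_avg'] (value-major)
```

The `column_indices` list in the diff records the same permutation as 1-based `$n` positional references, so the reordering happens in a single `select`.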
map_aggregate.py (continued):

```diff
@@ -162,7 +261,7 @@ def string_parser(s):

     try:
         # handling values that are used as pivoted columns
-        match = re.match(r'^"\'(.*)\'"$',
+        match = re.match(r'^"\'(.*)\'"$', snowpark_cname)
         # extract the content between the outermost double quote followed by a single quote "'
         content = match.group(1)
         # convert the escaped double quote to the actual double quote
@@ -174,10 +273,10 @@ def string_parser(s):
         content = re.sub(r"'", "", content)
         # replace the placeholder with the single quote which we want to preserve
         result = content.replace(escape_single_quote_placeholder, "'")
-        return result
+        return f"{result}_{opt_alias}" if opt_alias else result
     except Exception:
         # fallback to the original logic, handling aliased column names
-        double_quote_list = re.findall(r'"(.*?)"',
+        double_quote_list = re.findall(r'"(.*?)"', snowpark_cname)
         spark_string = ""
         for entry in list(filter(None, double_quote_list)):
             if "'" in entry:
@@ -189,7 +288,7 @@ def string_parser(s):
                 spark_string += entry
             else:
                 spark_string += '"' + entry + '"'
-        return
+        return snowpark_cname if spark_string == "" else spark_string


 @dataclass(frozen=True)
```
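`pivot_column_name` first tries the quoted-literal path: Snowpark names pivoted columns like `"'Java'"` (the outer double quotes are part of the name), and the regex strips both layers of quoting. A quick sketch of that regex on the example inputs documented in the function; the `opt_alias` suffix introduced in this release applies only when more than one aggregation is pivoted:

```python
import re

# the documented input format: outer double quotes are part of the name
for raw in ['"\'Java\'"', '"\'Scala\'"']:
    content = re.match(r'^"\'(.*)\'"$', raw).group(1)
    print(content)  # Java, then Scala

# with an alias the new code returns f"{result}_{opt_alias}", e.g. a
# hypothetical "Java_sum(earnings)" when sum(earnings) is one of two aggregations
```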
map_aggregate.py (continued):

```diff
@@ -210,8 +309,14 @@ class _Columns:
     def grouping_expressions(self) -> list[snowpark.Column]:
         return [col.expression for col in self.grouping_columns]

-    def aggregation_expressions(self) -> list[snowpark.Column]:
-
+    def aggregation_expressions(self, unalias: bool = False) -> list[snowpark.Column]:
+        def _unalias(col: snowpark.Column) -> snowpark.Column:
+            if unalias and hasattr(col, "_expr1") and isinstance(col._expr1, Alias):
+                return _unalias(Column(col._expr1.child))
+            else:
+                return col
+
+        return [_unalias(col.expression) for col in self.aggregation_columns]

     def expressions(self) -> list[snowpark.Column]:
         return self.grouping_expressions() + self.aggregation_expressions()
```
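`aggregation_expressions(unalias=True)` is called from the pivot paths above so that alias wrappers are peeled off aggregation expressions before they reach `.agg()`; the recursion stops once the expression is no longer an `Alias`. A toy sketch of the same shape, using stand-in classes rather than Snowpark's internal `Alias` node:

```python
class Expr:
    pass


class AliasSketch(Expr):
    # stand-in for snowpark._internal.analyzer.unary_expression.Alias
    def __init__(self, child: Expr):
        self.child = child


def unalias(expr: Expr) -> Expr:
    # strip nested alias wrappers recursively, mirroring _unalias above
    return unalias(expr.child) if isinstance(expr, AliasSketch) else expr


leaf = Expr()
assert unalias(AliasSketch(AliasSketch(leaf))) is leaf
```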
map_aggregate.py (continued):

```diff
@@ -246,7 +351,8 @@ class _Columns:
 def map_aggregate_helper(
     rel: relation_proto.Relation, pivot: bool = False, skip_alias: bool = False
 ):
-
+    input_container = map_relation(rel.aggregate.input)
+    input_df = input_container.dataframe
     grouping_expressions = rel.aggregate.grouping_expressions
     expressions = rel.aggregate.aggregate_expressions
     groupings: list[_ColumnMetadata] = []
@@ -258,7 +364,7 @@ def map_aggregate_helper(
     with temporary_pivot_expression(pivot):
         for exp in grouping_expressions:
             new_name, snowpark_column = map_single_column_expression(
-                exp,
+                exp, input_container.column_map, typer
             )
             alias = make_column_names_snowpark_compatible(
                 [new_name], rel.common.plan_id, len(groupings)
@@ -275,9 +381,12 @@ def map_aggregate_helper(
                 )
             )

+        grouping_cols = [g.spark_name for g in groupings]
+        set_current_grouping_columns(grouping_cols)
+
         for exp in expressions:
             new_name, snowpark_column = map_single_column_expression(
-                exp,
+                exp, input_container.column_map, typer
             )
             alias = make_column_names_snowpark_compatible(
                 [new_name], rel.common.plan_id, len(groupings) + len(aggregations)
@@ -313,7 +422,7 @@ def map_aggregate_helper(
     )

     return (
-
+        input_container,
         _Columns(
             grouping_columns=groupings,
             aggregation_columns=aggregations,
```
snowflake/snowpark_connect/relation/map_catalog.py

```diff
@@ -7,7 +7,7 @@ import re
 import pandas
 import pyspark.sql.connect.proto.catalog_pb2 as catalog_proto

-from snowflake import
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.catalogs import CATALOGS
 from snowflake.snowpark_connect.relation.catalogs.utils import (
     CURRENT_CATALOG_NAME,
@@ -22,7 +22,7 @@ from snowflake.snowpark_connect.utils.telemetry import (

 def map_catalog(
     rel: catalog_proto.Catalog,
-) ->
+) -> DataFrameContainer | pandas.DataFrame:
     match rel.WhichOneof("cat_type"):
         # Database related APIs
         case "current_database":
```
|