snowpark-connect 0.32.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, exactly as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/column_name_handler.py +91 -40
- snowflake/snowpark_connect/column_qualifier.py +0 -4
- snowflake/snowpark_connect/config.py +9 -0
- snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
- snowflake/snowpark_connect/expression/literal.py +12 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +18 -4
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +150 -29
- snowflake/snowpark_connect/expression/map_unresolved_function.py +93 -55
- snowflake/snowpark_connect/relation/map_aggregate.py +156 -257
- snowflake/snowpark_connect/relation/map_column_ops.py +19 -0
- snowflake/snowpark_connect/relation/map_join.py +454 -252
- snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
- snowflake/snowpark_connect/relation/map_sql.py +335 -90
- snowflake/snowpark_connect/relation/read/map_read.py +9 -1
- snowflake/snowpark_connect/relation/read/map_read_csv.py +19 -2
- snowflake/snowpark_connect/relation/read/map_read_json.py +90 -2
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +3 -0
- snowflake/snowpark_connect/relation/read/map_read_text.py +4 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
- snowflake/snowpark_connect/relation/read/utils.py +41 -0
- snowflake/snowpark_connect/relation/utils.py +50 -2
- snowflake/snowpark_connect/relation/write/map_write.py +251 -292
- snowflake/snowpark_connect/resources_initializer.py +25 -13
- snowflake/snowpark_connect/server.py +9 -24
- snowflake/snowpark_connect/type_mapping.py +2 -0
- snowflake/snowpark_connect/typed_column.py +2 -2
- snowflake/snowpark_connect/utils/context.py +0 -14
- snowflake/snowpark_connect/utils/expression_transformer.py +163 -0
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +4 -1
- snowflake/snowpark_connect/utils/udf_helper.py +1 -0
- snowflake/snowpark_connect/utils/udtf_helper.py +3 -0
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/METADATA +4 -2
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/RECORD +43 -104
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_aggregate.py

@@ -2,9 +2,8 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 
-import re
+import copy
 from dataclasses import dataclass
-from typing import Optional
 
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 
@@ -25,10 +24,10 @@ from snowflake.snowpark_connect.expression.map_expression import (
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.typed_column import TypedColumn
+from snowflake.snowpark_connect.utils import expression_transformer
 from snowflake.snowpark_connect.utils.context import (
     get_is_evaluating_sql,
     set_current_grouping_columns,
-    temporary_pivot_expression,
 )
 
 
@@ -137,215 +136,118 @@ def map_pivot_aggregate(
         get_literal_field_and_name(lit)[0] for lit in rel.aggregate.pivot.values
     ]
 
-
-
-
-
-
-
-
-
+    if not pivot_values:
+        distinct_col_values = (
+            input_df_actual.select(pivot_column[1].col)
+            .distinct()
+            .sort(snowpark_fn.asc_nulls_first(pivot_column[1].col))
+            .collect()
+        )
+        pivot_values = [row[0] for row in distinct_col_values]
+
+    agg_expressions = columns.aggregation_expressions(unalias=True)
+
+    spark_col_names = []
+    aggregations = []
+    final_pivot_names = []
+    grouping_columns_qualifiers = []
+
+    pivot_col_name = pivot_column[1].col.get_name()
+
+    agg_columns = set()
+    for agg_expression in agg_expressions:
+        if hasattr(agg_expression, "_expr1"):
+            agg_columns = agg_columns.union(
+                agg_expression._expr1.dependent_column_names()
             )
-    for identifier in matched_identifiers:
-        mapped_col = input_container.column_map.spark_to_col.get(
-            identifier, None
-        )
-        if mapped_col:
-            used_columns.add(mapped_col[0].snowpark_name)
 
-
-
-
-
-
-
-
-
-
-    pivot_results = []
-    for i, agg_expr in enumerate(agg_expressions):
-        pivot_result = (
-            input_df_actual.select(*used_columns)
-            .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-            .agg(agg_expr)
-        )
-        for col_name in pivot_result.columns:
-            spark_names.append(
-                f"{pivot_column_name(col_name)}_{agg_metadata[i].spark_name}"
-            )
-        pivot_results.append(pivot_result)
-
-    result = pivot_results[0]
-    for pivot_result in pivot_results[1:]:
-        result = result.cross_join(pivot_result)
-
-    pivot_columns_per_agg = len(pivot_results[0].columns)
-    reordered_spark_names = []
-    reordered_snowpark_names = []
-    reordered_types = []
-    column_selectors = []
-
-    for pivot_idx in range(pivot_columns_per_agg):
-        for agg_idx in range(num_agg_functions):
-            current_pos = agg_idx * pivot_columns_per_agg + pivot_idx
-            if current_pos < len(spark_names):
-                idx = current_pos + 1  # 1-based indexing for Snowpark
-                reordered_spark_names.append(spark_names[current_pos])
-                reordered_snowpark_names.append(f"${idx}")
-                reordered_types.append(
-                    result.schema.fields[current_pos].datatype
-                )
-                column_selectors.append(snowpark_fn.col(f"${idx}"))
-
-    return DataFrameContainer.create_with_column_mapping(
-        dataframe=result.select(*column_selectors),
-        spark_column_names=reordered_spark_names,
-        snowpark_column_names=reordered_snowpark_names,
-        column_qualifiers=[
-            {ColumnQualifier.no_qualifier()} for _ in reordered_spark_names
-        ],
-        parent_column_name_map=input_container.column_map,
-        snowpark_column_types=reordered_types,
+    grouping_columns = columns.grouping_expressions()
+    if grouping_columns:
+        for col in grouping_columns:
+            snowpark_name = col.get_name()
+            spark_col_name = input_container.column_map.get_spark_column_name_from_snowpark_column_name(
+                snowpark_name
+            )
+            qualifiers = input_container.column_map.get_qualifiers_for_spark_column(
+                spark_col_name
            )
+            grouping_columns_qualifiers.append(qualifiers)
+            spark_col_names.append(spark_col_name)
+    elif get_is_evaluating_sql():
+        for col in input_container.column_map.get_snowpark_columns():
+            if col != pivot_col_name and col not in agg_columns:
+                grouping_columns.append(col)
+                spark_col_name = input_container.column_map.get_spark_column_name_from_snowpark_column_name(
+                    col
+                )
+                qualifiers = input_container.column_map.get_qualifiers_for_spark_column(
+                    spark_col_name
+                )
+                grouping_columns_qualifiers.append(qualifiers)
+                spark_col_names.append(spark_col_name)
+
+    for pv_value in pivot_values:
+        pv_is_null = False
+
+        if pv_value in (None, "NULL", "None"):
+            pv_value_spark = "null"
+            pv_is_null = True
         else:
-
-
-
-
+            pv_value_spark = str(pv_value)
+
+        for i, agg_expression in enumerate(agg_expressions):
+            agg_fun_expr = copy.deepcopy(agg_expression._expr1)
+
+            condition = (
+                snowpark_fn.is_null(pivot_column[1].col)
+                if pv_is_null
+                else (pivot_column[1].col == snowpark_fn.lit(pv_value))
             )
-    else:
-        result = (
-            input_df_actual.group_by(*columns.grouping_expressions())
-            .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-            .agg(*columns.aggregation_expressions(unalias=True))
-        )
 
-
+            expression_transformer.inject_condition_to_all_agg_functions(
+                agg_fun_expr, condition
+            )
 
-
-    total_pivot_columns = len(result.columns) - len(agg_name_list)
-    num_pivot_values = (
-        total_pivot_columns // len(columns.aggregation_columns)
-        if len(columns.aggregation_columns) > 0
-        else 1
-    )
+            curr_expression = Column(agg_fun_expr)
 
-
-
-
-
-        index = (col_index - len(agg_name_list)) // num_pivot_values
-        return columns.aggregation_columns[index].spark_name
-
-    spark_columns = []
-    for col in [
-        pivot_column_name(c, _get_agg_exp_alias_for_col(i))
-        for i, c in enumerate(result.columns)
-    ]:
-        spark_col = (
-            input_container.column_map.get_spark_column_name_from_snowpark_column_name(
-                col, allow_non_exists=True
+            spark_col_name = (
+                f"{pv_value_spark}_{columns.aggregation_columns[i].spark_name}"
+                if len(agg_expressions) > 1
+                else f"{pv_value_spark}"
             )
-        )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    reordered_snowpark_cols = []
-    reordered_spark_cols = []
-    column_indices = []  # 1-based indexing
-
-    for i in range(grouping_cols_count):
-        reordered_snowpark_cols.append(result.columns[i])
-        reordered_spark_cols.append(spark_columns[i])
-        column_indices.append(i + 1)
-
-    for pivot_idx in range(num_pivot_values):
-        for agg_idx in range(num_agg_functions):
-            current_pos = agg_idx * num_pivot_values + pivot_idx
-            if current_pos < len(pivot_cols):
-                reordered_snowpark_cols.append(pivot_cols[current_pos])
-                reordered_spark_cols.append(spark_pivot_cols[current_pos])
-                original_index = grouping_cols_count + current_pos
-                column_indices.append(original_index + 1)
-
-    reordered_result = result.select(
-        *[snowpark_fn.col(f"${idx}") for idx in column_indices]
+            snowpark_col_name = make_column_names_snowpark_compatible(
+                [spark_col_name],
+                rel.common.plan_id,
+                len(grouping_columns) + len(agg_expressions),
+            )[0]
+
+            curr_expression = curr_expression.alias(snowpark_col_name)
+
+            aggregations.append(curr_expression)
+            spark_col_names.append(spark_col_name)
+            final_pivot_names.append(snowpark_col_name)
+
+    result_df = (
+        input_df_actual.group_by(*grouping_columns)
+        .agg(*aggregations)
+        .select(*grouping_columns, *final_pivot_names)
     )
 
     return DataFrameContainer.create_with_column_mapping(
-        dataframe=reordered_result,
-        spark_column_names=reordered_spark_cols,
-        snowpark_column_names=reordered_snowpark_cols,
-        column_qualifiers=(
-            columns.get_qualifiers()[: len(agg_name_list)]
-            + [[]] * (len(reordered_spark_cols) - len(agg_name_list))
-        ),
-        parent_column_name_map=input_container.column_map,
         snowpark_column_types=[
-
+            result_df.schema.fields[idx].datatype
+            for idx, _ in enumerate(result_df.columns)
         ],
+        dataframe=result_df,
+        spark_column_names=spark_col_names,
+        snowpark_column_names=result_df.columns,
+        column_qualifiers=grouping_columns_qualifiers
+        + [set() for _ in final_pivot_names],
+        parent_column_name_map=input_container.column_map,
     )
 
 
-def pivot_column_name(snowpark_cname, opt_alias: Optional[str] = None) -> Optional[str]:
-    # For values that are used as pivoted columns, the input and output are in the following format (outermost double quotes are part of the input):
-
-    # 1. "'Java'" -> Java
-    # 2. "'""C++""'" -> "C++"
-    # 3. "'""""''Scala''""""'" -> ""'Scala'""
-
-    # As we can see:
-    # 1. the whole content is always nested in a double quote followed by a single quote ("'<content>'").
-    # 2. the string content is nested in single quotes ('<string_content>')
-    # 3. double quote is escaped by another double quote, this is snowflake behavior
-    # 4. if there is a single quote followed by a single quote, the first single quote needs to be preserved in the output
-
-    try:
-        # handling values that are used as pivoted columns
-        match = re.match(r'^"\'(.*)\'"$', snowpark_cname)
-        # extract the content between the outermost double quote followed by a single quote "'
-        content = match.group(1)
-        # convert the escaped double quote to the actual double quote
-        content = content.replace('""', '"')
-        escape_single_quote_placeholder = "__SAS_PLACEHOLDER_ESCAPE_SINGLE_QUOTE__"
-        # replace two consecutive single quotes in the content with a placeholder, the first single quote needs to be preserved
-        content = re.sub(r"''", escape_single_quote_placeholder, content)
-        # remove the solo single quotes, they are not part of the string content
-        content = re.sub(r"'", "", content)
-        # replace the placeholder with the single quote which we want to preserve
-        result = content.replace(escape_single_quote_placeholder, "'")
-        return f"{result}_{opt_alias}" if opt_alias else result
-    except Exception:
-        # fallback to the original logic, handling aliased column names
-        double_quote_list = re.findall(r'"(.*?)"', snowpark_cname)
-        spark_string = ""
-        for entry in list(filter(None, double_quote_list)):
-            if "'" in entry:
-                entry = entry.replace("'", "")
-                if len(entry) > 0:
-                    spark_string += entry
-            elif entry.isdigit() or re.compile(r"^\d+?\.\d+?$").match(entry):
-                # skip quoting digits or decimal numbers as column names.
-                spark_string += entry
-            else:
-                spark_string += '"' + entry + '"'
-        return snowpark_cname if spark_string == "" else spark_string
-
-
 @dataclass(frozen=True)
 class _ColumnMetadata:
     expression: snowpark.Column
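The net effect of the hunk above: rather than delegating to Snowpark's native `pivot`, `map_pivot_aggregate` now expands the pivot into one conditionally filtered aggregate per (pivot value, aggregate function) pair under a plain GROUP BY, discovering the pivot values from the data when none are given. A minimal PySpark sketch of the equivalence this relies on (the DataFrame and column names are hypothetical):

```python
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("2024", "Java", 10), ("2024", "Scala", 5), ("2025", "Java", 7)],
    ["year", "course", "earnings"],
)

# Native pivot, with values discovered from the data:
native = df.groupBy("year").pivot("course").agg(F.sum("earnings"))

# The rewrite's strategy: sort the distinct pivot values (nulls first),
# then emit one filtered aggregate per value under an ordinary GROUP BY.
values = [
    r[0]
    for r in df.select("course")
    .distinct()
    .orderBy(F.asc_nulls_first("course"))
    .collect()
]
manual = df.groupBy("year").agg(
    *[
        F.sum(F.when(F.col("course") == v, F.col("earnings"))).alias(str(v))
        for v in values
    ]
)
# native and manual produce the same rows and the same column layout.
```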
@@ -416,71 +318,68 @@ def map_aggregate_helper(
     typer = ExpressionTyper(input_df)
     schema_inferrable = True
 
-
-
-
-
-        )
-        alias = make_column_names_snowpark_compatible(
-            [new_name], rel.common.plan_id, len(groupings)
-        )[0]
-        groupings.append(
-            _ColumnMetadata(
-                snowpark_column.col
-                if skip_alias
-                else snowpark_column.col.alias(alias),
-                new_name,
-                None if skip_alias else alias,
-                None if pivot else snowpark_column.typ,
-                qualifiers=snowpark_column.get_qualifiers(),
-            )
-        )
-
-    grouping_cols = [g.spark_name for g in groupings]
-    set_current_grouping_columns(grouping_cols)
+    for exp in grouping_expressions:
+        new_name, snowpark_column = map_single_column_expression(
+            exp, input_container.column_map, typer
+        )
 
-
-            new_name,
-
+        alias = make_column_names_snowpark_compatible(
+            [new_name], rel.common.plan_id, len(groupings)
+        )[0]
+
+        groupings.append(
+            _ColumnMetadata(
+                snowpark_column.col if skip_alias else snowpark_column.col.alias(alias),
+                new_name,
+                None if skip_alias else alias,
+                None if pivot else snowpark_column.typ,
+                qualifiers=snowpark_column.get_qualifiers(),
             )
-
-            [new_name], rel.common.plan_id, len(groupings) + len(aggregations)
-        )[0]
+        )
 
-
-
-        ) -> DataType | None:
-            if pivot or not schema_inferrable:
-                return None
-            try:
-                return agg_exp.typ
-            except Exception:
-                # This type used for schema inference optimization purposes.
-                # typer may not be able to infer the type of some expressions
-                # in that case we return None, and the optimization will not be applied.
-                return None
-
-        agg_col_typ = type_agg_expr(snowpark_column, schema_inferrable)
-        if agg_col_typ is None:
-            schema_inferrable = False
-
-        aggregations.append(
-            _ColumnMetadata(
-                snowpark_column.col
-                if skip_alias
-                else snowpark_column.col.alias(alias),
-                new_name,
-                None if skip_alias else alias,
-                agg_col_typ,
-                qualifiers={ColumnQualifier.no_qualifier()},
-            )
-        )
+    grouping_cols = [g.spark_name for g in groupings]
+    set_current_grouping_columns(grouping_cols)
 
-
-
-
-
-
-
-
+    for exp in expressions:
+        new_name, snowpark_column = map_single_column_expression(
+            exp, input_container.column_map, typer
+        )
+        alias = make_column_names_snowpark_compatible(
+            [new_name], rel.common.plan_id, len(groupings) + len(aggregations)
+        )[0]
+
+        def type_agg_expr(
+            agg_exp: TypedColumn, schema_inferrable: bool
+        ) -> DataType | None:
+            if pivot or not schema_inferrable:
+                return None
+            try:
+                return agg_exp.typ
+            except Exception:
+                # This type used for schema inference optimization purposes.
+                # typer may not be able to infer the type of some expressions
+                # in that case we return None, and the optimization will not be applied.
+                return None
+
+        agg_col_typ = type_agg_expr(snowpark_column, schema_inferrable)
+        if agg_col_typ is None:
+            schema_inferrable = False
+
+        aggregations.append(
+            _ColumnMetadata(
+                snowpark_column.col if skip_alias else snowpark_column.col.alias(alias),
+                new_name,
+                None if skip_alias else alias,
+                agg_col_typ,
+                qualifiers=set(),
+            )
         )
+
+    return (
+        input_container,
+        _Columns(
+            grouping_columns=groupings,
+            aggregation_columns=aggregations,
+            can_infer_schema=schema_inferrable,
+        ),
+    )
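Note the `type_agg_expr` helper introduced above: typing an aggregate is treated as a best-effort optimization, and any expression the typer cannot type yields `None`, which flips `can_infer_schema` off for the whole aggregate instead of failing the query. The pattern in isolation (a sketch with illustrative names, not the package's API):

```python
from typing import Callable, Optional

def best_effort_type(get_type: Callable[[], str]) -> Optional[str]:
    # Typing is only an optimization: swallow failures, report "unknown".
    try:
        return get_type()
    except Exception:
        return None

def typeable() -> str:
    return "LongType"

def untypeable() -> str:
    raise ValueError("expression type not inferrable")

schema_inferrable = True
for getter in (typeable, untypeable):
    if best_effort_type(getter) is None:
        schema_inferrable = False  # one unknown type disables the fast path

print(schema_inferrable)  # False
```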
snowflake/snowpark_connect/relation/map_column_ops.py

@@ -288,6 +288,20 @@ def map_project(
                 alias_types = mapper.types
                 typed_alias = TypedColumn(aliased_col, lambda types=alias_types: types)
                 register_lca_alias(spark_name, typed_alias)
+
+                # Also register with the original qualified name if this is an alias of a column reference
+                # This handles ORDER BY referencing the original name: SELECT o.date AS order_date ... ORDER BY o.date
+                if (
+                    exp.alias.HasField("expr")
+                    and exp.alias.expr.WhichOneof("expr_type") == "unresolved_attribute"
+                ):
+                    original_name = (
+                        exp.alias.expr.unresolved_attribute.unparsed_identifier
+                    )
+                    if (
+                        original_name != spark_name
+                    ):  # Don't register twice with the same name
+                        register_lca_alias(original_name, typed_alias)
             else:
                 # Multi-column case ('select *', posexplode, explode, inline, etc.)
                 has_multi_column_alias = True
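The new registration covers queries where a later clause refers to the pre-alias column rather than the alias. A quick PySpark illustration of the query shape being handled (`orders` is a stand-in view):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.createDataFrame(
    [("2024-01-01",), ("2024-01-02",)], ["date"]
).createOrReplaceTempView("orders")

# Both spellings should resolve: the alias itself, and the original
# qualified column the alias was derived from.
spark.sql("SELECT o.date AS order_date FROM orders o ORDER BY order_date").show()
spark.sql("SELECT o.date AS order_date FROM orders o ORDER BY o.date").show()
```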
@@ -316,6 +330,11 @@ def map_project(
     final_snowpark_columns = make_column_names_snowpark_compatible(
         new_spark_columns, rel.common.plan_id
     )
+    # if there are duplicate snowpark column names, we need to disambiguate them by their index
+    if len(new_spark_columns) != len(set(new_spark_columns)):
+        result = result.select(
+            [f"${i}" for i in range(1, len(new_spark_columns) + 1)]
+        )
     result = result.toDF(*final_snowpark_columns)
     new_snowpark_columns = final_snowpark_columns
 
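Duplicate output names come from projections such as `SELECT a, a FROM t` or from self-joins. Renaming by name is ambiguous in that case, so the added code first pins every column to its 1-based position (Snowflake's `$1`, `$2`, ... references) and only then applies the rename. A hypothetical sketch of the situation in plain PySpark terms:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 2)], ["a", "b"]).select("a", "a")

print(df.columns)  # ['a', 'a'] -- two columns sharing one name
# Any by-name reference to "a" is now ambiguous; a positional select
# ($1, $2 on the Snowflake side) disambiguates before renaming:
renamed = df.toDF("a_1", "a_2")
print(renamed.columns)  # ['a_1', 'a_2']
```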