snowpark-connect 0.33.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/column_name_handler.py +42 -56
- snowflake/snowpark_connect/config.py +9 -0
- snowflake/snowpark_connect/expression/literal.py +12 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +6 -0
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +147 -63
- snowflake/snowpark_connect/expression/map_unresolved_function.py +31 -28
- snowflake/snowpark_connect/relation/map_aggregate.py +156 -255
- snowflake/snowpark_connect/relation/map_column_ops.py +14 -0
- snowflake/snowpark_connect/relation/map_join.py +364 -234
- snowflake/snowpark_connect/relation/map_sql.py +309 -150
- snowflake/snowpark_connect/relation/read/map_read.py +9 -1
- snowflake/snowpark_connect/relation/read/map_read_csv.py +19 -2
- snowflake/snowpark_connect/relation/read/map_read_json.py +3 -0
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +3 -0
- snowflake/snowpark_connect/relation/read/map_read_text.py +4 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
- snowflake/snowpark_connect/relation/read/utils.py +41 -0
- snowflake/snowpark_connect/relation/utils.py +4 -2
- snowflake/snowpark_connect/relation/write/map_write.py +65 -17
- snowflake/snowpark_connect/utils/context.py +0 -14
- snowflake/snowpark_connect/utils/expression_transformer.py +163 -0
- snowflake/snowpark_connect/utils/session.py +0 -4
- snowflake/snowpark_connect/utils/udf_helper.py +1 -0
- snowflake/snowpark_connect/utils/udtf_helper.py +3 -0
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/METADATA +2 -2
- {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/RECORD +35 -38
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +0 -16
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +0 -1281
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +0 -203
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +0 -202
- {snowpark_connect-0.33.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.33.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.33.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_aggregate.py

@@ -2,9 +2,8 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #

-import
+import copy
 from dataclasses import dataclass
-from typing import Optional

 import pyspark.sql.connect.proto.relations_pb2 as relation_proto

@@ -25,10 +24,10 @@ from snowflake.snowpark_connect.expression.map_expression import (
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.typed_column import TypedColumn
+from snowflake.snowpark_connect.utils import expression_transformer
 from snowflake.snowpark_connect.utils.context import (
     get_is_evaluating_sql,
     set_current_grouping_columns,
-    temporary_pivot_expression,
 )

@@ -137,213 +136,118 @@ def map_pivot_aggregate(
         get_literal_field_and_name(lit)[0] for lit in rel.aggregate.pivot.values
     ]

-
-
-
-
-
-
-
-
+    if not pivot_values:
+        distinct_col_values = (
+            input_df_actual.select(pivot_column[1].col)
+            .distinct()
+            .sort(snowpark_fn.asc_nulls_first(pivot_column[1].col))
+            .collect()
+        )
+        pivot_values = [row[0] for row in distinct_col_values]
+
+    agg_expressions = columns.aggregation_expressions(unalias=True)
+
+    spark_col_names = []
+    aggregations = []
+    final_pivot_names = []
+    grouping_columns_qualifiers = []
+
+    pivot_col_name = pivot_column[1].col.get_name()
+
+    agg_columns = set()
+    for agg_expression in agg_expressions:
+        if hasattr(agg_expression, "_expr1"):
+            agg_columns = agg_columns.union(
+                agg_expression._expr1.dependent_column_names()
             )
-        for identifier in matched_identifiers:
-            mapped_col = input_container.column_map.spark_to_col.get(
-                identifier, None
-            )
-            if mapped_col:
-                used_columns.add(mapped_col[0].snowpark_name)

-
-
-
-
-
-
-
-
-
-        pivot_results = []
-        for i, agg_expr in enumerate(agg_expressions):
-            pivot_result = (
-                input_df_actual.select(*used_columns)
-                .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-                .agg(agg_expr)
-            )
-            for col_name in pivot_result.columns:
-                spark_names.append(
-                    f"{pivot_column_name(col_name)}_{agg_metadata[i].spark_name}"
-                )
-            pivot_results.append(pivot_result)
-
-        result = pivot_results[0]
-        for pivot_result in pivot_results[1:]:
-            result = result.cross_join(pivot_result)
-
-        pivot_columns_per_agg = len(pivot_results[0].columns)
-        reordered_spark_names = []
-        reordered_snowpark_names = []
-        reordered_types = []
-        column_selectors = []
-
-        for pivot_idx in range(pivot_columns_per_agg):
-            for agg_idx in range(num_agg_functions):
-                current_pos = agg_idx * pivot_columns_per_agg + pivot_idx
-                if current_pos < len(spark_names):
-                    idx = current_pos + 1 # 1-based indexing for Snowpark
-                    reordered_spark_names.append(spark_names[current_pos])
-                    reordered_snowpark_names.append(f"${idx}")
-                    reordered_types.append(
-                        result.schema.fields[current_pos].datatype
-                    )
-                    column_selectors.append(snowpark_fn.col(f"${idx}"))
-
-        return DataFrameContainer.create_with_column_mapping(
-            dataframe=result.select(*column_selectors),
-            spark_column_names=reordered_spark_names,
-            snowpark_column_names=reordered_snowpark_names,
-            column_qualifiers=[set() for _ in reordered_spark_names],
-            parent_column_name_map=input_container.column_map,
-            snowpark_column_types=reordered_types,
+    grouping_columns = columns.grouping_expressions()
+    if grouping_columns:
+        for col in grouping_columns:
+            snowpark_name = col.get_name()
+            spark_col_name = input_container.column_map.get_spark_column_name_from_snowpark_column_name(
+                snowpark_name
+            )
+            qualifiers = input_container.column_map.get_qualifiers_for_spark_column(
+                spark_col_name
             )
+            grouping_columns_qualifiers.append(qualifiers)
+            spark_col_names.append(spark_col_name)
+    elif get_is_evaluating_sql():
+        for col in input_container.column_map.get_snowpark_columns():
+            if col != pivot_col_name and col not in agg_columns:
+                grouping_columns.append(col)
+                spark_col_name = input_container.column_map.get_spark_column_name_from_snowpark_column_name(
+                    col
+                )
+                qualifiers = input_container.column_map.get_qualifiers_for_spark_column(
+                    spark_col_name
+                )
+                grouping_columns_qualifiers.append(qualifiers)
+                spark_col_names.append(spark_col_name)
+
+    for pv_value in pivot_values:
+        pv_is_null = False
+
+        if pv_value in (None, "NULL", "None"):
+            pv_value_spark = "null"
+            pv_is_null = True
         else:
-
-
-
-
+            pv_value_spark = str(pv_value)
+
+        for i, agg_expression in enumerate(agg_expressions):
+            agg_fun_expr = copy.deepcopy(agg_expression._expr1)
+
+            condition = (
+                snowpark_fn.is_null(pivot_column[1].col)
+                if pv_is_null
+                else (pivot_column[1].col == snowpark_fn.lit(pv_value))
             )
-    else:
-        result = (
-            input_df_actual.group_by(*columns.grouping_expressions())
-            .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-            .agg(*columns.aggregation_expressions(unalias=True))
-        )

-
+            expression_transformer.inject_condition_to_all_agg_functions(
+                agg_fun_expr, condition
+            )

-
-    total_pivot_columns = len(result.columns) - len(agg_name_list)
-    num_pivot_values = (
-        total_pivot_columns // len(columns.aggregation_columns)
-        if len(columns.aggregation_columns) > 0
-        else 1
-    )
+            curr_expression = Column(agg_fun_expr)

-
-
-
-
-        index = (col_index - len(agg_name_list)) // num_pivot_values
-        return columns.aggregation_columns[index].spark_name
-
-    spark_columns = []
-    for col in [
-        pivot_column_name(c, _get_agg_exp_alias_for_col(i))
-        for i, c in enumerate(result.columns)
-    ]:
-        spark_col = (
-            input_container.column_map.get_spark_column_name_from_snowpark_column_name(
-                col, allow_non_exists=True
+            spark_col_name = (
+                f"{pv_value_spark}_{columns.aggregation_columns[i].spark_name}"
+                if len(agg_expressions) > 1
+                else f"{pv_value_spark}"
             )
-            )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    reordered_snowpark_cols = []
-    reordered_spark_cols = []
-    column_indices = [] # 1-based indexing
-
-    for i in range(grouping_cols_count):
-        reordered_snowpark_cols.append(result.columns[i])
-        reordered_spark_cols.append(spark_columns[i])
-        column_indices.append(i + 1)
-
-    for pivot_idx in range(num_pivot_values):
-        for agg_idx in range(num_agg_functions):
-            current_pos = agg_idx * num_pivot_values + pivot_idx
-            if current_pos < len(pivot_cols):
-                reordered_snowpark_cols.append(pivot_cols[current_pos])
-                reordered_spark_cols.append(spark_pivot_cols[current_pos])
-                original_index = grouping_cols_count + current_pos
-                column_indices.append(original_index + 1)
-
-    reordered_result = result.select(
-        *[snowpark_fn.col(f"${idx}") for idx in column_indices]
+            snowpark_col_name = make_column_names_snowpark_compatible(
+                [spark_col_name],
+                rel.common.plan_id,
+                len(grouping_columns) + len(agg_expressions),
+            )[0]
+
+            curr_expression = curr_expression.alias(snowpark_col_name)
+
+            aggregations.append(curr_expression)
+            spark_col_names.append(spark_col_name)
+            final_pivot_names.append(snowpark_col_name)
+
+    result_df = (
+        input_df_actual.group_by(*grouping_columns)
+        .agg(*aggregations)
+        .select(*grouping_columns, *final_pivot_names)
     )

     return DataFrameContainer.create_with_column_mapping(
-        dataframe=
-        spark_column_names=
-        snowpark_column_names=
-        column_qualifiers=(
-            columns.get_qualifiers()[: len(agg_name_list)]
-            + [[]] * (len(reordered_spark_cols) - len(agg_name_list))
-        ),
-        parent_column_name_map=input_container.column_map,
+        dataframe=result_df,
+        spark_column_names=spark_col_names,
+        snowpark_column_names=result_df.columns,
         snowpark_column_types=[
-
+            result_df.schema.fields[idx].datatype
+            for idx, _ in enumerate(result_df.columns)
         ],
+        column_qualifiers=grouping_columns_qualifiers
+        + [set() for _ in final_pivot_names],
+        parent_column_name_map=input_container.column_map,
     )


-def pivot_column_name(snowpark_cname, opt_alias: Optional[str] = None) -> Optional[str]:
-    # For values that are used as pivoted columns, the input and output are in the following format (outermost double quotes are part of the input):
-
-    # 1. "'Java'" -> Java
-    # 2. "'""C++""'" -> "C++"
-    # 3. "'""""''Scala''""""'" -> ""'Scala'""
-
-    # As we can see:
-    # 1. the whole content is always nested in a double quote followed by a single quote ("'<content>'").
-    # 2. the string content is nested in single quotes ('<string_content>')
-    # 3. double quote is escased by another double quote, this is snowflake behavior
-    # 4. if there is a single quote followed by a single quote, the first single quote needs to be preserved in the output
-
-    try:
-        # handling values that are used as pivoted columns
-        match = re.match(r'^"\'(.*)\'"$', snowpark_cname)
-        # extract the content between the outermost double quote followed by a single quote "'
-        content = match.group(1)
-        # convert the escaped double quote to the actual double quote
-        content = content.replace('""', '"')
-        escape_single_quote_placeholder = "__SAS_PLACEHOLDER_ESCAPE_SINGLE_QUOTE__"
-        # replace two consecutive single quote in the content with a placeholder, the first single quote needs to be preserved
-        content = re.sub(r"''", escape_single_quote_placeholder, content)
-        # remove the solo single quote, they are not part of the string content
-        content = re.sub(r"'", "", content)
-        # replace the placeholder with the single quote which we want to preserve
-        result = content.replace(escape_single_quote_placeholder, "'")
-        return f"{result}_{opt_alias}" if opt_alias else result
-    except Exception:
-        # fallback to the original logic, handling aliased column names
-        double_quote_list = re.findall(r'"(.*?)"', snowpark_cname)
-        spark_string = ""
-        for entry in list(filter(None, double_quote_list)):
-            if "'" in entry:
-                entry = entry.replace("'", "")
-                if len(entry) > 0:
-                    spark_string += entry
-            elif entry.isdigit() or re.compile(r"^\d+?\.\d+?$").match(entry):
-                # skip quoting digits or decimal numbers as column names.
-                spark_string += entry
-            else:
-                spark_string += '"' + entry + '"'
-        return snowpark_cname if spark_string == "" else spark_string
-
-
 @dataclass(frozen=True)
 class _ColumnMetadata:
     expression: snowpark.Column
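The rewritten map_pivot_aggregate above drops Snowpark's native .pivot() call in favor of a single group_by whose aggregations are made conditional on each pivot value. The sketch below illustrates that conditional-aggregation pattern with public Snowpark functions only; the helper name, the fixed sum() aggregate, and the use of iff() are assumptions for illustration (the real code injects the condition into the parsed aggregate expression via expression_transformer.inject_condition_to_all_agg_functions).

import snowflake.snowpark.functions as fn


def pivot_by_conditional_agg(df, group_col, pivot_col, value_col, pivot_values):
    # Hypothetical helper: emulate PIVOT with one conditional aggregation per value.
    aggs = []
    for pv in pivot_values:
        # NULL pivot values need an IS NULL check rather than an equality comparison.
        condition = (
            fn.is_null(fn.col(pivot_col))
            if pv is None
            else fn.col(pivot_col) == fn.lit(pv)
        )
        # Aggregate only the rows matching this pivot value; other rows contribute NULL.
        aggs.append(
            fn.sum(fn.iff(condition, fn.col(value_col), fn.lit(None))).alias(
                "null" if pv is None else str(pv)
            )
        )
    # A single GROUP BY replaces the per-aggregation .pivot()/.cross_join() reshuffling
    # of the old implementation.
    return df.group_by(group_col).agg(*aggs)

With pivot_values such as ["Java", "Scala"] this yields one aggregated column per value, which mirrors how the new code derives output column names from the pivot value (suffixed with the aggregation's Spark name when there is more than one aggregate).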
@@ -414,71 +318,68 @@ def map_aggregate_helper(
     typer = ExpressionTyper(input_df)
     schema_inferrable = True

-
-
-
-
-        )
-        alias = make_column_names_snowpark_compatible(
-            [new_name], rel.common.plan_id, len(groupings)
-        )[0]
-        groupings.append(
-            _ColumnMetadata(
-                snowpark_column.col
-                if skip_alias
-                else snowpark_column.col.alias(alias),
-                new_name,
-                None if skip_alias else alias,
-                None if pivot else snowpark_column.typ,
-                qualifiers=snowpark_column.get_qualifiers(),
-            )
-        )
-
-    grouping_cols = [g.spark_name for g in groupings]
-    set_current_grouping_columns(grouping_cols)
+    for exp in grouping_expressions:
+        new_name, snowpark_column = map_single_column_expression(
+            exp, input_container.column_map, typer
+        )

-
-                new_name,
-
+        alias = make_column_names_snowpark_compatible(
+            [new_name], rel.common.plan_id, len(groupings)
+        )[0]
+
+        groupings.append(
+            _ColumnMetadata(
+                snowpark_column.col if skip_alias else snowpark_column.col.alias(alias),
+                new_name,
+                None if skip_alias else alias,
+                None if pivot else snowpark_column.typ,
+                qualifiers=snowpark_column.get_qualifiers(),
             )
-
-            [new_name], rel.common.plan_id, len(groupings) + len(aggregations)
-        )[0]
+        )

-
-
-        ) -> DataType | None:
-            if pivot or not schema_inferrable:
-                return None
-            try:
-                return agg_exp.typ
-            except Exception:
-                # This type used for schema inference optimization purposes.
-                # typer may not be able to infer the type of some expressions
-                # in that case we return None, and the optimization will not be applied.
-                return None
-
-        agg_col_typ = type_agg_expr(snowpark_column, schema_inferrable)
-        if agg_col_typ is None:
-            schema_inferrable = False
-
-        aggregations.append(
-            _ColumnMetadata(
-                snowpark_column.col
-                if skip_alias
-                else snowpark_column.col.alias(alias),
-                new_name,
-                None if skip_alias else alias,
-                agg_col_typ,
-                qualifiers=set(),
-            )
-        )
+    grouping_cols = [g.spark_name for g in groupings]
+    set_current_grouping_columns(grouping_cols)

-
-
-
-
-
-
-
+    for exp in expressions:
+        new_name, snowpark_column = map_single_column_expression(
+            exp, input_container.column_map, typer
+        )
+        alias = make_column_names_snowpark_compatible(
+            [new_name], rel.common.plan_id, len(groupings) + len(aggregations)
+        )[0]
+
+        def type_agg_expr(
+            agg_exp: TypedColumn, schema_inferrable: bool
+        ) -> DataType | None:
+            if pivot or not schema_inferrable:
+                return None
+            try:
+                return agg_exp.typ
+            except Exception:
+                # This type used for schema inference optimization purposes.
+                # typer may not be able to infer the type of some expressions
+                # in that case we return None, and the optimization will not be applied.
+                return None
+
+        agg_col_typ = type_agg_expr(snowpark_column, schema_inferrable)
+        if agg_col_typ is None:
+            schema_inferrable = False
+
+        aggregations.append(
+            _ColumnMetadata(
+                snowpark_column.col if skip_alias else snowpark_column.col.alias(alias),
+                new_name,
+                None if skip_alias else alias,
+                agg_col_typ,
+                qualifiers=set(),
+            )
         )
+
+    return (
+        input_container,
+        _Columns(
+            grouping_columns=groupings,
+            aggregation_columns=aggregations,
+            can_infer_schema=schema_inferrable,
+        ),
+    )
snowflake/snowpark_connect/relation/map_column_ops.py

@@ -288,6 +288,20 @@ def map_project(
                 alias_types = mapper.types
                 typed_alias = TypedColumn(aliased_col, lambda types=alias_types: types)
                 register_lca_alias(spark_name, typed_alias)
+
+                # Also register with the original qualified name if this is an alias of a column reference
+                # This handles ORDER BY referencing the original name: SELECT o.date AS order_date ... ORDER BY o.date
+                if (
+                    exp.alias.HasField("expr")
+                    and exp.alias.expr.WhichOneof("expr_type") == "unresolved_attribute"
+                ):
+                    original_name = (
+                        exp.alias.expr.unresolved_attribute.unparsed_identifier
+                    )
+                    if (
+                        original_name != spark_name
+                    ): # Don't register twice with the same name
+                        register_lca_alias(original_name, typed_alias)
             else:
                 # Multi-column case ('select *', posexplode, explode, inline, etc.)
                 has_multi_column_alias = True