snowpark-connect 0.32.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of snowpark-connect might be problematic.

Files changed (106)
  1. snowflake/snowpark_connect/column_name_handler.py +91 -40
  2. snowflake/snowpark_connect/column_qualifier.py +0 -4
  3. snowflake/snowpark_connect/config.py +9 -0
  4. snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
  5. snowflake/snowpark_connect/expression/literal.py +12 -12
  6. snowflake/snowpark_connect/expression/map_sql_expression.py +18 -4
  7. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +150 -29
  8. snowflake/snowpark_connect/expression/map_unresolved_function.py +93 -55
  9. snowflake/snowpark_connect/relation/map_aggregate.py +156 -257
  10. snowflake/snowpark_connect/relation/map_column_ops.py +19 -0
  11. snowflake/snowpark_connect/relation/map_join.py +454 -252
  12. snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
  13. snowflake/snowpark_connect/relation/map_sql.py +335 -90
  14. snowflake/snowpark_connect/relation/read/map_read.py +9 -1
  15. snowflake/snowpark_connect/relation/read/map_read_csv.py +19 -2
  16. snowflake/snowpark_connect/relation/read/map_read_json.py +90 -2
  17. snowflake/snowpark_connect/relation/read/map_read_parquet.py +3 -0
  18. snowflake/snowpark_connect/relation/read/map_read_text.py +4 -0
  19. snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
  20. snowflake/snowpark_connect/relation/read/utils.py +41 -0
  21. snowflake/snowpark_connect/relation/utils.py +50 -2
  22. snowflake/snowpark_connect/relation/write/map_write.py +251 -292
  23. snowflake/snowpark_connect/resources_initializer.py +25 -13
  24. snowflake/snowpark_connect/server.py +9 -24
  25. snowflake/snowpark_connect/type_mapping.py +2 -0
  26. snowflake/snowpark_connect/typed_column.py +2 -2
  27. snowflake/snowpark_connect/utils/context.py +0 -14
  28. snowflake/snowpark_connect/utils/expression_transformer.py +163 -0
  29. snowflake/snowpark_connect/utils/sequence.py +21 -0
  30. snowflake/snowpark_connect/utils/session.py +4 -1
  31. snowflake/snowpark_connect/utils/udf_helper.py +1 -0
  32. snowflake/snowpark_connect/utils/udtf_helper.py +3 -0
  33. snowflake/snowpark_connect/version.py +1 -1
  34. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/METADATA +4 -2
  35. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/RECORD +43 -104
  36. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  37. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  39. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  46. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  47. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  48. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  52. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  53. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  54. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  55. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  56. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  57. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  58. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  59. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  60. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  61. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  62. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  63. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  64. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  65. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  66. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  67. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  68. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  69. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  70. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  71. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  72. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  73. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  74. snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
  75. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  76. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  77. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  78. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  79. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  80. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  81. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
  82. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  83. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  84. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  85. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  86. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  87. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  88. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  89. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  90. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  91. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  92. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  93. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  94. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  95. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  96. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  97. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  98. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  99. {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-connect +0 -0
  100. {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-session +0 -0
  101. {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-submit +0 -0
  102. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/WHEEL +0 -0
  103. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE-binary +0 -0
  104. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE.txt +0 -0
  105. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/NOTICE-binary +0 -0
  106. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/top_level.txt +0 -0

snowflake/snowpark_connect/relation/map_aggregate.py
@@ -2,9 +2,8 @@
  # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
  #

- import re
+ import copy
  from dataclasses import dataclass
- from typing import Optional

  import pyspark.sql.connect.proto.relations_pb2 as relation_proto

@@ -25,10 +24,10 @@ from snowflake.snowpark_connect.expression.map_expression import (
  from snowflake.snowpark_connect.expression.typer import ExpressionTyper
  from snowflake.snowpark_connect.relation.map_relation import map_relation
  from snowflake.snowpark_connect.typed_column import TypedColumn
+ from snowflake.snowpark_connect.utils import expression_transformer
  from snowflake.snowpark_connect.utils.context import (
      get_is_evaluating_sql,
      set_current_grouping_columns,
-     temporary_pivot_expression,
  )


@@ -137,215 +136,118 @@ def map_pivot_aggregate(
          get_literal_field_and_name(lit)[0] for lit in rel.aggregate.pivot.values
      ]

-     used_columns = {pivot_column[1].col._expression.name}
-     if get_is_evaluating_sql():
-         # When evaluating SQL spark doesn't trim columns from the result
-         used_columns = {"*"}
-     else:
-         for expression in rel.aggregate.aggregate_expressions:
-             matched_identifiers = re.findall(
-                 r'unparsed_identifier: "(.*)"', expression.__str__()
+     if not pivot_values:
+         distinct_col_values = (
+             input_df_actual.select(pivot_column[1].col)
+             .distinct()
+             .sort(snowpark_fn.asc_nulls_first(pivot_column[1].col))
+             .collect()
+         )
+         pivot_values = [row[0] for row in distinct_col_values]
+
+     agg_expressions = columns.aggregation_expressions(unalias=True)
+
+     spark_col_names = []
+     aggregations = []
+     final_pivot_names = []
+     grouping_columns_qualifiers = []
+
+     pivot_col_name = pivot_column[1].col.get_name()
+
+     agg_columns = set()
+     for agg_expression in agg_expressions:
+         if hasattr(agg_expression, "_expr1"):
+             agg_columns = agg_columns.union(
+                 agg_expression._expr1.dependent_column_names()
              )
-             for identifier in matched_identifiers:
-                 mapped_col = input_container.column_map.spark_to_col.get(
-                     identifier, None
-                 )
-                 if mapped_col:
-                     used_columns.add(mapped_col[0].snowpark_name)

-     if len(columns.grouping_expressions()) == 0:
-         # Snowpark doesn't support multiple aggregations in pivot without groupBy
-         # So we need to perform each aggregation separately and then combine results
-         if len(columns.aggregation_expressions(unalias=True)) > 1:
-             agg_expressions = columns.aggregation_expressions(unalias=True)
-             agg_metadata = columns.aggregation_columns
-             num_agg_functions = len(agg_expressions)
-
-             spark_names = []
-             pivot_results = []
-             for i, agg_expr in enumerate(agg_expressions):
-                 pivot_result = (
-                     input_df_actual.select(*used_columns)
-                     .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-                     .agg(agg_expr)
-                 )
-                 for col_name in pivot_result.columns:
-                     spark_names.append(
-                         f"{pivot_column_name(col_name)}_{agg_metadata[i].spark_name}"
-                     )
-                 pivot_results.append(pivot_result)
-
-             result = pivot_results[0]
-             for pivot_result in pivot_results[1:]:
-                 result = result.cross_join(pivot_result)
-
-             pivot_columns_per_agg = len(pivot_results[0].columns)
-             reordered_spark_names = []
-             reordered_snowpark_names = []
-             reordered_types = []
-             column_selectors = []
-
-             for pivot_idx in range(pivot_columns_per_agg):
-                 for agg_idx in range(num_agg_functions):
-                     current_pos = agg_idx * pivot_columns_per_agg + pivot_idx
-                     if current_pos < len(spark_names):
-                         idx = current_pos + 1  # 1-based indexing for Snowpark
-                         reordered_spark_names.append(spark_names[current_pos])
-                         reordered_snowpark_names.append(f"${idx}")
-                         reordered_types.append(
-                             result.schema.fields[current_pos].datatype
-                         )
-                         column_selectors.append(snowpark_fn.col(f"${idx}"))
-
-             return DataFrameContainer.create_with_column_mapping(
-                 dataframe=result.select(*column_selectors),
-                 spark_column_names=reordered_spark_names,
-                 snowpark_column_names=reordered_snowpark_names,
-                 column_qualifiers=[
-                     {ColumnQualifier.no_qualifier()} for _ in reordered_spark_names
-                 ],
-                 parent_column_name_map=input_container.column_map,
-                 snowpark_column_types=reordered_types,
+     grouping_columns = columns.grouping_expressions()
+     if grouping_columns:
+         for col in grouping_columns:
+             snowpark_name = col.get_name()
+             spark_col_name = input_container.column_map.get_spark_column_name_from_snowpark_column_name(
+                 snowpark_name
+             )
+             qualifiers = input_container.column_map.get_qualifiers_for_spark_column(
+                 spark_col_name
              )
+             grouping_columns_qualifiers.append(qualifiers)
+             spark_col_names.append(spark_col_name)
+     elif get_is_evaluating_sql():
+         for col in input_container.column_map.get_snowpark_columns():
+             if col != pivot_col_name and col not in agg_columns:
+                 grouping_columns.append(col)
+                 spark_col_name = input_container.column_map.get_spark_column_name_from_snowpark_column_name(
+                     col
+                 )
+                 qualifiers = input_container.column_map.get_qualifiers_for_spark_column(
+                     spark_col_name
+                 )
+                 grouping_columns_qualifiers.append(qualifiers)
+                 spark_col_names.append(spark_col_name)
+
+     for pv_value in pivot_values:
+         pv_is_null = False
+
+         if pv_value in (None, "NULL", "None"):
+             pv_value_spark = "null"
+             pv_is_null = True
          else:
-             result = (
-                 input_df_actual.select(*used_columns)
-                 .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-                 .agg(*columns.aggregation_expressions(unalias=True))
+             pv_value_spark = str(pv_value)
+
+         for i, agg_expression in enumerate(agg_expressions):
+             agg_fun_expr = copy.deepcopy(agg_expression._expr1)
+
+             condition = (
+                 snowpark_fn.is_null(pivot_column[1].col)
+                 if pv_is_null
+                 else (pivot_column[1].col == snowpark_fn.lit(pv_value))
             )
-     else:
-         result = (
-             input_df_actual.group_by(*columns.grouping_expressions())
-             .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-             .agg(*columns.aggregation_expressions(unalias=True))
-         )

-     agg_name_list = [c.spark_name for c in columns.grouping_columns]
+             expression_transformer.inject_condition_to_all_agg_functions(
+                 agg_fun_expr, condition
+             )

-     # Calculate number of pivot values for proper Spark-compatible indexing
-     total_pivot_columns = len(result.columns) - len(agg_name_list)
-     num_pivot_values = (
-         total_pivot_columns // len(columns.aggregation_columns)
-         if len(columns.aggregation_columns) > 0
-         else 1
-     )
+             curr_expression = Column(agg_fun_expr)

-     def _get_agg_exp_alias_for_col(col_index: int) -> Optional[str]:
-         if col_index < len(agg_name_list) or len(columns.aggregation_columns) <= 1:
-             return None
-         else:
-             index = (col_index - len(agg_name_list)) // num_pivot_values
-             return columns.aggregation_columns[index].spark_name
-
-     spark_columns = []
-     for col in [
-         pivot_column_name(c, _get_agg_exp_alias_for_col(i))
-         for i, c in enumerate(result.columns)
-     ]:
-         spark_col = (
-             input_container.column_map.get_spark_column_name_from_snowpark_column_name(
-                 col, allow_non_exists=True
+             spark_col_name = (
+                 f"{pv_value_spark}_{columns.aggregation_columns[i].spark_name}"
+                 if len(agg_expressions) > 1
+                 else f"{pv_value_spark}"
             )
-         )

-         if spark_col is not None:
-             spark_columns.append(spark_col)
-         else:
-             # Handle NULL column names to match Spark behavior (lowercase 'null')
-             if col == "NULL":
-                 spark_columns.append(col.lower())
-             else:
-                 spark_columns.append(col)
-
-     grouping_cols_count = len(agg_name_list)
-     pivot_cols = result.columns[grouping_cols_count:]
-     spark_pivot_cols = spark_columns[grouping_cols_count:]
-
-     num_agg_functions = len(columns.aggregation_columns)
-     num_pivot_values = len(pivot_cols) // num_agg_functions
-
-     reordered_snowpark_cols = []
-     reordered_spark_cols = []
-     column_indices = []  # 1-based indexing
-
-     for i in range(grouping_cols_count):
-         reordered_snowpark_cols.append(result.columns[i])
-         reordered_spark_cols.append(spark_columns[i])
-         column_indices.append(i + 1)
-
-     for pivot_idx in range(num_pivot_values):
-         for agg_idx in range(num_agg_functions):
-             current_pos = agg_idx * num_pivot_values + pivot_idx
-             if current_pos < len(pivot_cols):
-                 reordered_snowpark_cols.append(pivot_cols[current_pos])
-                 reordered_spark_cols.append(spark_pivot_cols[current_pos])
-                 original_index = grouping_cols_count + current_pos
-                 column_indices.append(original_index + 1)
-
-     reordered_result = result.select(
-         *[snowpark_fn.col(f"${idx}") for idx in column_indices]
+             snowpark_col_name = make_column_names_snowpark_compatible(
+                 [spark_col_name],
+                 rel.common.plan_id,
+                 len(grouping_columns) + len(agg_expressions),
+             )[0]
+
+             curr_expression = curr_expression.alias(snowpark_col_name)
+
+             aggregations.append(curr_expression)
+             spark_col_names.append(spark_col_name)
+             final_pivot_names.append(snowpark_col_name)
+
+     result_df = (
+         input_df_actual.group_by(*grouping_columns)
+         .agg(*aggregations)
+         .select(*grouping_columns, *final_pivot_names)
     )

      return DataFrameContainer.create_with_column_mapping(
-         dataframe=reordered_result,
-         spark_column_names=reordered_spark_cols,
-         snowpark_column_names=[f"${idx}" for idx in column_indices],
-         column_qualifiers=(
-             columns.get_qualifiers()[: len(agg_name_list)]
-             + [[]] * (len(reordered_spark_cols) - len(agg_name_list))
-         ),
-         parent_column_name_map=input_container.column_map,
+         dataframe=result_df,
+         spark_column_names=spark_col_names,
+         snowpark_column_names=result_df.columns,
          snowpark_column_types=[
-             result.schema.fields[idx - 1].datatype for idx in column_indices
+             result_df.schema.fields[idx].datatype
+             for idx, _ in enumerate(result_df.columns)
          ],
+         column_qualifiers=grouping_columns_qualifiers
+         + [set() for _ in final_pivot_names],
+         parent_column_name_map=input_container.column_map,
     )


- def pivot_column_name(snowpark_cname, opt_alias: Optional[str] = None) -> Optional[str]:
-     # For values that are used as pivoted columns, the input and output are in the following format (outermost double quotes are part of the input):
-
-     # 1. "'Java'" -> Java
-     # 2. "'""C++""'" -> "C++"
-     # 3. "'""""''Scala''""""'" -> ""'Scala'""
-
-     # As we can see:
-     # 1. the whole content is always nested in a double quote followed by a single quote ("'<content>'").
-     # 2. the string content is nested in single quotes ('<string_content>')
-     # 3. double quote is escased by another double quote, this is snowflake behavior
-     # 4. if there is a single quote followed by a single quote, the first single quote needs to be preserved in the output
-
-     try:
-         # handling values that are used as pivoted columns
-         match = re.match(r'^"\'(.*)\'"$', snowpark_cname)
-         # extract the content between the outermost double quote followed by a single quote "'
-         content = match.group(1)
-         # convert the escaped double quote to the actual double quote
-         content = content.replace('""', '"')
-         escape_single_quote_placeholder = "__SAS_PLACEHOLDER_ESCAPE_SINGLE_QUOTE__"
-         # replace two consecutive single quote in the content with a placeholder, the first single quote needs to be preserved
-         content = re.sub(r"''", escape_single_quote_placeholder, content)
-         # remove the solo single quote, they are not part of the string content
-         content = re.sub(r"'", "", content)
-         # replace the placeholder with the single quote which we want to preserve
-         result = content.replace(escape_single_quote_placeholder, "'")
-         return f"{result}_{opt_alias}" if opt_alias else result
-     except Exception:
-         # fallback to the original logic, handling aliased column names
-         double_quote_list = re.findall(r'"(.*?)"', snowpark_cname)
-         spark_string = ""
-         for entry in list(filter(None, double_quote_list)):
-             if "'" in entry:
-                 entry = entry.replace("'", "")
-                 if len(entry) > 0:
-                     spark_string += entry
-             elif entry.isdigit() or re.compile(r"^\d+?\.\d+?$").match(entry):
-                 # skip quoting digits or decimal numbers as column names.
-                 spark_string += entry
-             else:
-                 spark_string += '"' + entry + '"'
-         return snowpark_cname if spark_string == "" else spark_string
-
-
  @dataclass(frozen=True)
  class _ColumnMetadata:
      expression: snowpark.Column
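
The rewritten map_pivot_aggregate stops delegating to Snowpark's pivot: when no pivot values are supplied it collects the distinct values of the pivot column, then builds one conditionally filtered copy of each aggregate per pivot value and runs a plain group_by/agg. The sketch below illustrates the underlying equivalence in plain PySpark; it is not the package's code, and the session, data, and column names (year, course, earnings) are made up for illustration.

# Minimal sketch: pivot expressed as conditional aggregation (assumes a local
# PySpark session; columns "year", "course", "earnings" are hypothetical).
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [(2012, "Java", 20000), (2012, "dotNET", 5000), (2013, "Java", 30000)],
    ["year", "course", "earnings"],
)
pivot_values = ["Java", "dotNET"]

# Built-in pivot.
pivoted = df.groupBy("year").pivot("course", pivot_values).agg(F.sum("earnings"))

# Same result as one aggregate per pivot value: the WHEN condition plays the
# role of the injected condition, and sum() ignores the NULLs it produces.
emulated = df.groupBy("year").agg(
    *[
        F.sum(F.when(F.col("course") == v, F.col("earnings"))).alias(str(v))
        for v in pivot_values
    ]
)
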
@@ -416,71 +318,68 @@ def map_aggregate_helper(
      typer = ExpressionTyper(input_df)
      schema_inferrable = True

-     with temporary_pivot_expression(pivot):
-         for exp in grouping_expressions:
-             new_name, snowpark_column = map_single_column_expression(
-                 exp, input_container.column_map, typer
-             )
-             alias = make_column_names_snowpark_compatible(
-                 [new_name], rel.common.plan_id, len(groupings)
-             )[0]
-             groupings.append(
-                 _ColumnMetadata(
-                     snowpark_column.col
-                     if skip_alias
-                     else snowpark_column.col.alias(alias),
-                     new_name,
-                     None if skip_alias else alias,
-                     None if pivot else snowpark_column.typ,
-                     qualifiers=snowpark_column.get_qualifiers(),
-                 )
-             )
-
-         grouping_cols = [g.spark_name for g in groupings]
-         set_current_grouping_columns(grouping_cols)
+     for exp in grouping_expressions:
+         new_name, snowpark_column = map_single_column_expression(
+             exp, input_container.column_map, typer
+         )

-         for exp in expressions:
-             new_name, snowpark_column = map_single_column_expression(
-                 exp, input_container.column_map, typer
+         alias = make_column_names_snowpark_compatible(
+             [new_name], rel.common.plan_id, len(groupings)
+         )[0]
+
+         groupings.append(
+             _ColumnMetadata(
+                 snowpark_column.col if skip_alias else snowpark_column.col.alias(alias),
+                 new_name,
+                 None if skip_alias else alias,
+                 None if pivot else snowpark_column.typ,
+                 qualifiers=snowpark_column.get_qualifiers(),
             )
-             alias = make_column_names_snowpark_compatible(
-                 [new_name], rel.common.plan_id, len(groupings) + len(aggregations)
-             )[0]
+         )

-             def type_agg_expr(
-                 agg_exp: TypedColumn, schema_inferrable: bool
-             ) -> DataType | None:
-                 if pivot or not schema_inferrable:
-                     return None
-                 try:
-                     return agg_exp.typ
-                 except Exception:
-                     # This type used for schema inference optimization purposes.
-                     # typer may not be able to infer the type of some expressions
-                     # in that case we return None, and the optimization will not be applied.
-                     return None
-
-             agg_col_typ = type_agg_expr(snowpark_column, schema_inferrable)
-             if agg_col_typ is None:
-                 schema_inferrable = False
-
-             aggregations.append(
-                 _ColumnMetadata(
-                     snowpark_column.col
-                     if skip_alias
-                     else snowpark_column.col.alias(alias),
-                     new_name,
-                     None if skip_alias else alias,
-                     agg_col_typ,
-                     qualifiers={ColumnQualifier.no_qualifier()},
-                 )
-             )
+     grouping_cols = [g.spark_name for g in groupings]
+     set_current_grouping_columns(grouping_cols)

-         return (
-             input_container,
-             _Columns(
-                 grouping_columns=groupings,
-                 aggregation_columns=aggregations,
-                 can_infer_schema=schema_inferrable,
-             ),
+     for exp in expressions:
+         new_name, snowpark_column = map_single_column_expression(
+             exp, input_container.column_map, typer
+         )
+         alias = make_column_names_snowpark_compatible(
+             [new_name], rel.common.plan_id, len(groupings) + len(aggregations)
+         )[0]
+
+         def type_agg_expr(
+             agg_exp: TypedColumn, schema_inferrable: bool
+         ) -> DataType | None:
+             if pivot or not schema_inferrable:
+                 return None
+             try:
+                 return agg_exp.typ
+             except Exception:
+                 # This type used for schema inference optimization purposes.
+                 # typer may not be able to infer the type of some expressions
+                 # in that case we return None, and the optimization will not be applied.
+                 return None
+
+         agg_col_typ = type_agg_expr(snowpark_column, schema_inferrable)
+         if agg_col_typ is None:
+             schema_inferrable = False
+
+         aggregations.append(
+             _ColumnMetadata(
+                 snowpark_column.col if skip_alias else snowpark_column.col.alias(alias),
+                 new_name,
+                 None if skip_alias else alias,
+                 agg_col_typ,
+                 qualifiers=set(),
+             )
         )
+
+     return (
+         input_container,
+         _Columns(
+             grouping_columns=groupings,
+             aggregation_columns=aggregations,
+             can_infer_schema=schema_inferrable,
+         ),
+     )

snowflake/snowpark_connect/relation/map_column_ops.py
@@ -288,6 +288,20 @@ def map_project(
                  alias_types = mapper.types
                  typed_alias = TypedColumn(aliased_col, lambda types=alias_types: types)
                  register_lca_alias(spark_name, typed_alias)
+
+                 # Also register with the original qualified name if this is an alias of a column reference
+                 # This handles ORDER BY referencing the original name: SELECT o.date AS order_date ... ORDER BY o.date
+                 if (
+                     exp.alias.HasField("expr")
+                     and exp.alias.expr.WhichOneof("expr_type") == "unresolved_attribute"
+                 ):
+                     original_name = (
+                         exp.alias.expr.unresolved_attribute.unparsed_identifier
+                     )
+                     if (
+                         original_name != spark_name
+                     ):  # Don't register twice with the same name
+                         register_lca_alias(original_name, typed_alias)
              else:
                  # Multi-column case ('select *', posexplode, explode, inline, etc.)
                  has_multi_column_alias = True
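
For context, the new registration targets queries where the projection renames a qualified column but the ORDER BY still uses the original name. A hedged, self-contained PySpark illustration of that query shape (the orders data and column names are made up):

# Hypothetical data; the point is only the SELECT ... AS / ORDER BY shape.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.createDataFrame(
    [("2024-01-01", 10.0), ("2024-01-02", 7.5)], ["date", "amount"]
).createOrReplaceTempView("orders")

# o.date is renamed to order_date in the projection, yet ORDER BY refers to
# o.date; registering the alias under both names lets this resolve.
spark.sql(
    "SELECT o.date AS order_date, o.amount FROM orders o ORDER BY o.date"
).show()
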
@@ -316,6 +330,11 @@ def map_project(
      final_snowpark_columns = make_column_names_snowpark_compatible(
          new_spark_columns, rel.common.plan_id
      )
+     # if there are duplicate snowpark column names, we need to disambiguate them by their index
+     if len(new_spark_columns) != len(set(new_spark_columns)):
+         result = result.select(
+             [f"${i}" for i in range(1, len(new_spark_columns) + 1)]
+         )
      result = result.toDF(*final_snowpark_columns)
      new_snowpark_columns = final_snowpark_columns
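
The "$1"-style selection above leans on Snowflake's positional column references to pick columns by index when intermediate names collide, before toDF assigns the final unique names. A rough sketch of that pattern with the Snowpark API, assuming an already-created snowflake.snowpark DataFrame (the helper below is illustrative, not part of the package):

# Illustrative helper, assuming `df` comes from an existing snowflake.snowpark
# Session (connection setup omitted). "$1", "$2", ... address columns by
# position, so the select works even if two columns currently share a name.
from typing import List

from snowflake.snowpark import DataFrame
from snowflake.snowpark.functions import col


def rename_by_position(df: DataFrame, new_names: List[str]) -> DataFrame:
    positional = df.select([col(f"${i}") for i in range(1, len(new_names) + 1)])
    return positional.to_df(*new_names)
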