snowpark-connect 0.33.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of snowpark-connect might be problematic.
Files changed (39)
  1. snowflake/snowpark_connect/column_name_handler.py +42 -56
  2. snowflake/snowpark_connect/config.py +9 -0
  3. snowflake/snowpark_connect/expression/literal.py +12 -12
  4. snowflake/snowpark_connect/expression/map_sql_expression.py +6 -0
  5. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +147 -63
  6. snowflake/snowpark_connect/expression/map_unresolved_function.py +31 -28
  7. snowflake/snowpark_connect/relation/map_aggregate.py +156 -255
  8. snowflake/snowpark_connect/relation/map_column_ops.py +14 -0
  9. snowflake/snowpark_connect/relation/map_join.py +364 -234
  10. snowflake/snowpark_connect/relation/map_sql.py +309 -150
  11. snowflake/snowpark_connect/relation/read/map_read.py +9 -1
  12. snowflake/snowpark_connect/relation/read/map_read_csv.py +19 -2
  13. snowflake/snowpark_connect/relation/read/map_read_json.py +3 -0
  14. snowflake/snowpark_connect/relation/read/map_read_parquet.py +3 -0
  15. snowflake/snowpark_connect/relation/read/map_read_text.py +4 -0
  16. snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
  17. snowflake/snowpark_connect/relation/read/utils.py +41 -0
  18. snowflake/snowpark_connect/relation/utils.py +4 -2
  19. snowflake/snowpark_connect/relation/write/map_write.py +65 -17
  20. snowflake/snowpark_connect/utils/context.py +0 -14
  21. snowflake/snowpark_connect/utils/expression_transformer.py +163 -0
  22. snowflake/snowpark_connect/utils/session.py +0 -4
  23. snowflake/snowpark_connect/utils/udf_helper.py +1 -0
  24. snowflake/snowpark_connect/utils/udtf_helper.py +3 -0
  25. snowflake/snowpark_connect/version.py +1 -1
  26. {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/METADATA +2 -2
  27. {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/RECORD +35 -38
  28. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +0 -16
  29. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +0 -1281
  30. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +0 -203
  31. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +0 -202
  32. {snowpark_connect-0.33.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-connect +0 -0
  33. {snowpark_connect-0.33.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-session +0 -0
  34. {snowpark_connect-0.33.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-submit +0 -0
  35. {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/WHEEL +0 -0
  36. {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE-binary +0 -0
  37. {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE.txt +0 -0
  38. {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/NOTICE-binary +0 -0
  39. {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/top_level.txt +0 -0
@@ -2,9 +2,8 @@
  # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
  #

- import re
+ import copy
  from dataclasses import dataclass
- from typing import Optional

  import pyspark.sql.connect.proto.relations_pb2 as relation_proto

@@ -25,10 +24,10 @@ from snowflake.snowpark_connect.expression.map_expression import (
  from snowflake.snowpark_connect.expression.typer import ExpressionTyper
  from snowflake.snowpark_connect.relation.map_relation import map_relation
  from snowflake.snowpark_connect.typed_column import TypedColumn
+ from snowflake.snowpark_connect.utils import expression_transformer
  from snowflake.snowpark_connect.utils.context import (
      get_is_evaluating_sql,
      set_current_grouping_columns,
-     temporary_pivot_expression,
  )


@@ -137,213 +136,118 @@ def map_pivot_aggregate(
          get_literal_field_and_name(lit)[0] for lit in rel.aggregate.pivot.values
      ]

-     used_columns = {pivot_column[1].col._expression.name}
-     if get_is_evaluating_sql():
-         # When evaluating SQL spark doesn't trim columns from the result
-         used_columns = {"*"}
-     else:
-         for expression in rel.aggregate.aggregate_expressions:
-             matched_identifiers = re.findall(
-                 r'unparsed_identifier: "(.*)"', expression.__str__()
+     if not pivot_values:
+         distinct_col_values = (
+             input_df_actual.select(pivot_column[1].col)
+             .distinct()
+             .sort(snowpark_fn.asc_nulls_first(pivot_column[1].col))
+             .collect()
+         )
+         pivot_values = [row[0] for row in distinct_col_values]
+
+     agg_expressions = columns.aggregation_expressions(unalias=True)
+
+     spark_col_names = []
+     aggregations = []
+     final_pivot_names = []
+     grouping_columns_qualifiers = []
+
+     pivot_col_name = pivot_column[1].col.get_name()
+
+     agg_columns = set()
+     for agg_expression in agg_expressions:
+         if hasattr(agg_expression, "_expr1"):
+             agg_columns = agg_columns.union(
+                 agg_expression._expr1.dependent_column_names()
              )
-             for identifier in matched_identifiers:
-                 mapped_col = input_container.column_map.spark_to_col.get(
-                     identifier, None
-                 )
-                 if mapped_col:
-                     used_columns.add(mapped_col[0].snowpark_name)

-     if len(columns.grouping_expressions()) == 0:
-         # Snowpark doesn't support multiple aggregations in pivot without groupBy
-         # So we need to perform each aggregation separately and then combine results
-         if len(columns.aggregation_expressions(unalias=True)) > 1:
-             agg_expressions = columns.aggregation_expressions(unalias=True)
-             agg_metadata = columns.aggregation_columns
-             num_agg_functions = len(agg_expressions)
-
-             spark_names = []
-             pivot_results = []
-             for i, agg_expr in enumerate(agg_expressions):
-                 pivot_result = (
-                     input_df_actual.select(*used_columns)
-                     .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-                     .agg(agg_expr)
-                 )
-                 for col_name in pivot_result.columns:
-                     spark_names.append(
-                         f"{pivot_column_name(col_name)}_{agg_metadata[i].spark_name}"
-                     )
-                 pivot_results.append(pivot_result)
-
-             result = pivot_results[0]
-             for pivot_result in pivot_results[1:]:
-                 result = result.cross_join(pivot_result)
-
-             pivot_columns_per_agg = len(pivot_results[0].columns)
-             reordered_spark_names = []
-             reordered_snowpark_names = []
-             reordered_types = []
-             column_selectors = []
-
-             for pivot_idx in range(pivot_columns_per_agg):
-                 for agg_idx in range(num_agg_functions):
-                     current_pos = agg_idx * pivot_columns_per_agg + pivot_idx
-                     if current_pos < len(spark_names):
-                         idx = current_pos + 1  # 1-based indexing for Snowpark
-                         reordered_spark_names.append(spark_names[current_pos])
-                         reordered_snowpark_names.append(f"${idx}")
-                         reordered_types.append(
-                             result.schema.fields[current_pos].datatype
-                         )
-                         column_selectors.append(snowpark_fn.col(f"${idx}"))
-
-             return DataFrameContainer.create_with_column_mapping(
-                 dataframe=result.select(*column_selectors),
-                 spark_column_names=reordered_spark_names,
-                 snowpark_column_names=reordered_snowpark_names,
-                 column_qualifiers=[set() for _ in reordered_spark_names],
-                 parent_column_name_map=input_container.column_map,
-                 snowpark_column_types=reordered_types,
+     grouping_columns = columns.grouping_expressions()
+     if grouping_columns:
+         for col in grouping_columns:
+             snowpark_name = col.get_name()
+             spark_col_name = input_container.column_map.get_spark_column_name_from_snowpark_column_name(
+                 snowpark_name
+             )
+             qualifiers = input_container.column_map.get_qualifiers_for_spark_column(
+                 spark_col_name
              )
+             grouping_columns_qualifiers.append(qualifiers)
+             spark_col_names.append(spark_col_name)
+     elif get_is_evaluating_sql():
+         for col in input_container.column_map.get_snowpark_columns():
+             if col != pivot_col_name and col not in agg_columns:
+                 grouping_columns.append(col)
+                 spark_col_name = input_container.column_map.get_spark_column_name_from_snowpark_column_name(
+                     col
+                 )
+                 qualifiers = input_container.column_map.get_qualifiers_for_spark_column(
+                     spark_col_name
+                 )
+                 grouping_columns_qualifiers.append(qualifiers)
+                 spark_col_names.append(spark_col_name)
+
+     for pv_value in pivot_values:
+         pv_is_null = False
+
+         if pv_value in (None, "NULL", "None"):
+             pv_value_spark = "null"
+             pv_is_null = True
          else:
-             result = (
-                 input_df_actual.select(*used_columns)
-                 .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-                 .agg(*columns.aggregation_expressions(unalias=True))
+             pv_value_spark = str(pv_value)
+
+         for i, agg_expression in enumerate(agg_expressions):
+             agg_fun_expr = copy.deepcopy(agg_expression._expr1)
+
+             condition = (
+                 snowpark_fn.is_null(pivot_column[1].col)
+                 if pv_is_null
+                 else (pivot_column[1].col == snowpark_fn.lit(pv_value))
              )
-     else:
-         result = (
-             input_df_actual.group_by(*columns.grouping_expressions())
-             .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-             .agg(*columns.aggregation_expressions(unalias=True))
-         )

-     agg_name_list = [c.spark_name for c in columns.grouping_columns]
+             expression_transformer.inject_condition_to_all_agg_functions(
+                 agg_fun_expr, condition
+             )

-     # Calculate number of pivot values for proper Spark-compatible indexing
-     total_pivot_columns = len(result.columns) - len(agg_name_list)
-     num_pivot_values = (
-         total_pivot_columns // len(columns.aggregation_columns)
-         if len(columns.aggregation_columns) > 0
-         else 1
-     )
+             curr_expression = Column(agg_fun_expr)

-     def _get_agg_exp_alias_for_col(col_index: int) -> Optional[str]:
-         if col_index < len(agg_name_list) or len(columns.aggregation_columns) <= 1:
-             return None
-         else:
-             index = (col_index - len(agg_name_list)) // num_pivot_values
-             return columns.aggregation_columns[index].spark_name
-
-     spark_columns = []
-     for col in [
-         pivot_column_name(c, _get_agg_exp_alias_for_col(i))
-         for i, c in enumerate(result.columns)
-     ]:
-         spark_col = (
-             input_container.column_map.get_spark_column_name_from_snowpark_column_name(
-                 col, allow_non_exists=True
+             spark_col_name = (
+                 f"{pv_value_spark}_{columns.aggregation_columns[i].spark_name}"
+                 if len(agg_expressions) > 1
+                 else f"{pv_value_spark}"
              )
-         )

-         if spark_col is not None:
-             spark_columns.append(spark_col)
-         else:
-             # Handle NULL column names to match Spark behavior (lowercase 'null')
-             if col == "NULL":
-                 spark_columns.append(col.lower())
-             else:
-                 spark_columns.append(col)
-
-     grouping_cols_count = len(agg_name_list)
-     pivot_cols = result.columns[grouping_cols_count:]
-     spark_pivot_cols = spark_columns[grouping_cols_count:]
-
-     num_agg_functions = len(columns.aggregation_columns)
-     num_pivot_values = len(pivot_cols) // num_agg_functions
-
-     reordered_snowpark_cols = []
-     reordered_spark_cols = []
-     column_indices = []  # 1-based indexing
-
-     for i in range(grouping_cols_count):
-         reordered_snowpark_cols.append(result.columns[i])
-         reordered_spark_cols.append(spark_columns[i])
-         column_indices.append(i + 1)
-
-     for pivot_idx in range(num_pivot_values):
-         for agg_idx in range(num_agg_functions):
-             current_pos = agg_idx * num_pivot_values + pivot_idx
-             if current_pos < len(pivot_cols):
-                 reordered_snowpark_cols.append(pivot_cols[current_pos])
-                 reordered_spark_cols.append(spark_pivot_cols[current_pos])
-                 original_index = grouping_cols_count + current_pos
-                 column_indices.append(original_index + 1)
-
-     reordered_result = result.select(
-         *[snowpark_fn.col(f"${idx}") for idx in column_indices]
+             snowpark_col_name = make_column_names_snowpark_compatible(
+                 [spark_col_name],
+                 rel.common.plan_id,
+                 len(grouping_columns) + len(agg_expressions),
+             )[0]
+
+             curr_expression = curr_expression.alias(snowpark_col_name)
+
+             aggregations.append(curr_expression)
+             spark_col_names.append(spark_col_name)
+             final_pivot_names.append(snowpark_col_name)
+
+     result_df = (
+         input_df_actual.group_by(*grouping_columns)
+         .agg(*aggregations)
+         .select(*grouping_columns, *final_pivot_names)
      )

      return DataFrameContainer.create_with_column_mapping(
-         dataframe=reordered_result,
-         spark_column_names=reordered_spark_cols,
-         snowpark_column_names=[f"${idx}" for idx in column_indices],
-         column_qualifiers=(
-             columns.get_qualifiers()[: len(agg_name_list)]
-             + [[]] * (len(reordered_spark_cols) - len(agg_name_list))
-         ),
-         parent_column_name_map=input_container.column_map,
+         dataframe=result_df,
+         spark_column_names=spark_col_names,
+         snowpark_column_names=result_df.columns,
          snowpark_column_types=[
-             result.schema.fields[idx - 1].datatype for idx in column_indices
+             result_df.schema.fields[idx].datatype
+             for idx, _ in enumerate(result_df.columns)
          ],
+         column_qualifiers=grouping_columns_qualifiers
+         + [set() for _ in final_pivot_names],
+         parent_column_name_map=input_container.column_map,
      )


- def pivot_column_name(snowpark_cname, opt_alias: Optional[str] = None) -> Optional[str]:
-     # For values that are used as pivoted columns, the input and output are in the following format (outermost double quotes are part of the input):
-
-     # 1. "'Java'" -> Java
-     # 2. "'""C++""'" -> "C++"
-     # 3. "'""""''Scala''""""'" -> ""'Scala'""
-
-     # As we can see:
-     # 1. the whole content is always nested in a double quote followed by a single quote ("'<content>'").
-     # 2. the string content is nested in single quotes ('<string_content>')
-     # 3. double quote is escased by another double quote, this is snowflake behavior
-     # 4. if there is a single quote followed by a single quote, the first single quote needs to be preserved in the output
-
-     try:
-         # handling values that are used as pivoted columns
-         match = re.match(r'^"\'(.*)\'"$', snowpark_cname)
-         # extract the content between the outermost double quote followed by a single quote "'
-         content = match.group(1)
-         # convert the escaped double quote to the actual double quote
-         content = content.replace('""', '"')
-         escape_single_quote_placeholder = "__SAS_PLACEHOLDER_ESCAPE_SINGLE_QUOTE__"
-         # replace two consecutive single quote in the content with a placeholder, the first single quote needs to be preserved
-         content = re.sub(r"''", escape_single_quote_placeholder, content)
-         # remove the solo single quote, they are not part of the string content
-         content = re.sub(r"'", "", content)
-         # replace the placeholder with the single quote which we want to preserve
-         result = content.replace(escape_single_quote_placeholder, "'")
-         return f"{result}_{opt_alias}" if opt_alias else result
-     except Exception:
-         # fallback to the original logic, handling aliased column names
-         double_quote_list = re.findall(r'"(.*?)"', snowpark_cname)
-         spark_string = ""
-         for entry in list(filter(None, double_quote_list)):
-             if "'" in entry:
-                 entry = entry.replace("'", "")
-                 if len(entry) > 0:
-                     spark_string += entry
-             elif entry.isdigit() or re.compile(r"^\d+?\.\d+?$").match(entry):
-                 # skip quoting digits or decimal numbers as column names.
-                 spark_string += entry
-             else:
-                 spark_string += '"' + entry + '"'
-         return snowpark_cname if spark_string == "" else spark_string
-
-
  @dataclass(frozen=True)
  class _ColumnMetadata:
      expression: snowpark.Column
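
For context on the map_pivot_aggregate hunk above: the 1.0.0 code drops the pivot()-then-reorder strategy and instead turns every pivot value into a conditionally filtered aggregate over a single group_by. A minimal plain-PySpark sketch of that equivalence (illustrative data and column names, not code from the package):

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("2024", "Java", 20000), ("2024", "Scala", 15000), ("2025", "Java", 25000)],
    ["year", "course", "earnings"],
)

# Native pivot, roughly what the old code path delegated to Snowpark's pivot():
pivoted = df.groupBy("year").pivot("course", ["Java", "Scala"]).agg(F.sum("earnings"))

# Conditional aggregation, the shape the new code path builds explicitly:
# one aggregate per pivot value, with the pivot predicate folded into its input.
manual = df.groupBy("year").agg(
    F.sum(F.when(F.col("course") == "Java", F.col("earnings"))).alias("Java"),
    F.sum(F.when(F.col("course") == "Scala", F.col("earnings"))).alias("Scala"),
)

Both produce the same rows. In the diff, the per-value predicate (pivot_col == value, or IS NULL for the null bucket) appears to be injected into a deep copy of each aggregate expression via expression_transformer.inject_condition_to_all_agg_functions before the single group_by(...).agg(...) call.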
@@ -414,71 +318,68 @@ def map_aggregate_helper(
      typer = ExpressionTyper(input_df)
      schema_inferrable = True

-     with temporary_pivot_expression(pivot):
-         for exp in grouping_expressions:
-             new_name, snowpark_column = map_single_column_expression(
-                 exp, input_container.column_map, typer
-             )
-             alias = make_column_names_snowpark_compatible(
-                 [new_name], rel.common.plan_id, len(groupings)
-             )[0]
-             groupings.append(
-                 _ColumnMetadata(
-                     snowpark_column.col
-                     if skip_alias
-                     else snowpark_column.col.alias(alias),
-                     new_name,
-                     None if skip_alias else alias,
-                     None if pivot else snowpark_column.typ,
-                     qualifiers=snowpark_column.get_qualifiers(),
-                 )
-             )
-
-         grouping_cols = [g.spark_name for g in groupings]
-         set_current_grouping_columns(grouping_cols)
+     for exp in grouping_expressions:
+         new_name, snowpark_column = map_single_column_expression(
+             exp, input_container.column_map, typer
+         )

-         for exp in expressions:
-             new_name, snowpark_column = map_single_column_expression(
-                 exp, input_container.column_map, typer
+         alias = make_column_names_snowpark_compatible(
+             [new_name], rel.common.plan_id, len(groupings)
+         )[0]
+
+         groupings.append(
+             _ColumnMetadata(
+                 snowpark_column.col if skip_alias else snowpark_column.col.alias(alias),
+                 new_name,
+                 None if skip_alias else alias,
+                 None if pivot else snowpark_column.typ,
+                 qualifiers=snowpark_column.get_qualifiers(),
             )
-             alias = make_column_names_snowpark_compatible(
-                 [new_name], rel.common.plan_id, len(groupings) + len(aggregations)
-             )[0]
+         )

-             def type_agg_expr(
-                 agg_exp: TypedColumn, schema_inferrable: bool
-             ) -> DataType | None:
-                 if pivot or not schema_inferrable:
-                     return None
-                 try:
-                     return agg_exp.typ
-                 except Exception:
-                     # This type used for schema inference optimization purposes.
-                     # typer may not be able to infer the type of some expressions
-                     # in that case we return None, and the optimization will not be applied.
-                     return None
-
-             agg_col_typ = type_agg_expr(snowpark_column, schema_inferrable)
-             if agg_col_typ is None:
-                 schema_inferrable = False
-
-             aggregations.append(
-                 _ColumnMetadata(
-                     snowpark_column.col
-                     if skip_alias
-                     else snowpark_column.col.alias(alias),
-                     new_name,
-                     None if skip_alias else alias,
-                     agg_col_typ,
-                     qualifiers=set(),
-                 )
-             )
+     grouping_cols = [g.spark_name for g in groupings]
+     set_current_grouping_columns(grouping_cols)

-         return (
-             input_container,
-             _Columns(
-                 grouping_columns=groupings,
-                 aggregation_columns=aggregations,
-                 can_infer_schema=schema_inferrable,
-             ),
+     for exp in expressions:
+         new_name, snowpark_column = map_single_column_expression(
+             exp, input_container.column_map, typer
+         )
+         alias = make_column_names_snowpark_compatible(
+             [new_name], rel.common.plan_id, len(groupings) + len(aggregations)
+         )[0]
+
+         def type_agg_expr(
+             agg_exp: TypedColumn, schema_inferrable: bool
+         ) -> DataType | None:
+             if pivot or not schema_inferrable:
+                 return None
+             try:
+                 return agg_exp.typ
+             except Exception:
+                 # This type used for schema inference optimization purposes.
+                 # typer may not be able to infer the type of some expressions
+                 # in that case we return None, and the optimization will not be applied.
+                 return None
+
+         agg_col_typ = type_agg_expr(snowpark_column, schema_inferrable)
+         if agg_col_typ is None:
+             schema_inferrable = False
+
+         aggregations.append(
+             _ColumnMetadata(
+                 snowpark_column.col if skip_alias else snowpark_column.col.alias(alias),
+                 new_name,
+                 None if skip_alias else alias,
+                 agg_col_typ,
+                 qualifiers=set(),
+             )
         )
+
+     return (
+         input_container,
+         _Columns(
+             grouping_columns=groupings,
+             aggregation_columns=aggregations,
+             can_infer_schema=schema_inferrable,
+         ),
+     )
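
A note on the aliasing in map_aggregate_helper above: each grouping and aggregation expression is aliased to a generated name via make_column_names_snowpark_compatible and the (Spark name, alias) pair is kept in _ColumnMetadata, presumably because Spark-style names such as sum(earnings) would need quoting as Snowflake identifiers. A rough sketch of that idea only, where make_compatible_name is a hypothetical stand-in for the real helper:

import re
from pyspark.sql import functions as F

def make_compatible_name(spark_name: str, plan_id: int, index: int) -> str:
    # Hypothetical stand-in: replace characters that would force quoting in a
    # Snowflake identifier and make the result unique per plan and position.
    return f"{re.sub(r'[^A-Za-z0-9_]', '_', spark_name)}_{plan_id}_{index}"

spark_name = "sum(earnings)"                                  # name Spark would report
alias = make_compatible_name(spark_name, plan_id=7, index=0)  # e.g. "sum_earnings__7_0"
aggregation = F.sum("earnings").alias(alias)                  # what actually reaches .agg()
# Keeping (spark_name, alias) together lets the output columns be renamed back
# to Spark-style names for the client.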
@@ -288,6 +288,20 @@ def map_project(
              alias_types = mapper.types
              typed_alias = TypedColumn(aliased_col, lambda types=alias_types: types)
              register_lca_alias(spark_name, typed_alias)
+
+             # Also register with the original qualified name if this is an alias of a column reference
+             # This handles ORDER BY referencing the original name: SELECT o.date AS order_date ... ORDER BY o.date
+             if (
+                 exp.alias.HasField("expr")
+                 and exp.alias.expr.WhichOneof("expr_type") == "unresolved_attribute"
+             ):
+                 original_name = (
+                     exp.alias.expr.unresolved_attribute.unparsed_identifier
+                 )
+                 if (
+                     original_name != spark_name
+                 ):  # Don't register twice with the same name
+                     register_lca_alias(original_name, typed_alias)
          else:
              # Multi-column case ('select *', posexplode, explode, inline, etc.)
              has_multi_column_alias = True
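
The map_project hunk above also registers the alias under its original qualified name (a second register_lca_alias call), so that, per the diff comment, an ORDER BY may reference either form. An illustrative reproduction of that scenario in plain PySpark (assumed table and data, not code from the package):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.createDataFrame([(1, "2024-01-02")], ["id", "date"]).createOrReplaceTempView("orders")

# Both statements should resolve; the second relies on the aliased column staying
# reachable under its original qualified name, which is what the extra
# register_lca_alias(original_name, typed_alias) call is there to provide.
spark.sql("SELECT o.date AS order_date FROM orders o ORDER BY order_date").show()
spark.sql("SELECT o.date AS order_date FROM orders o ORDER BY o.date").show()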