snowpark-connect 0.22.1__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of snowpark-connect may be problematic.

Files changed (46)
  1. snowflake/snowpark_connect/config.py +0 -11
  2. snowflake/snowpark_connect/error/error_utils.py +7 -0
  3. snowflake/snowpark_connect/error/exceptions.py +4 -0
  4. snowflake/snowpark_connect/expression/function_defaults.py +207 -0
  5. snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
  6. snowflake/snowpark_connect/expression/literal.py +14 -12
  7. snowflake/snowpark_connect/expression/map_cast.py +20 -4
  8. snowflake/snowpark_connect/expression/map_expression.py +18 -2
  9. snowflake/snowpark_connect/expression/map_extension.py +12 -2
  10. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
  11. snowflake/snowpark_connect/expression/map_unresolved_function.py +69 -10
  12. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
  13. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
  14. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
  15. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
  16. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
  17. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
  18. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
  19. snowflake/snowpark_connect/relation/map_aggregate.py +57 -5
  20. snowflake/snowpark_connect/relation/map_column_ops.py +6 -5
  21. snowflake/snowpark_connect/relation/map_extension.py +65 -31
  22. snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
  23. snowflake/snowpark_connect/relation/map_row_ops.py +2 -0
  24. snowflake/snowpark_connect/relation/map_sql.py +22 -5
  25. snowflake/snowpark_connect/relation/read/map_read.py +2 -1
  26. snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
  27. snowflake/snowpark_connect/relation/read/reader_config.py +9 -0
  28. snowflake/snowpark_connect/relation/write/map_write.py +243 -68
  29. snowflake/snowpark_connect/server.py +25 -5
  30. snowflake/snowpark_connect/type_mapping.py +2 -2
  31. snowflake/snowpark_connect/utils/env_utils.py +55 -0
  32. snowflake/snowpark_connect/utils/session.py +21 -0
  33. snowflake/snowpark_connect/version.py +1 -1
  34. snowflake/snowpark_decoder/spark_decoder.py +1 -1
  35. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/METADATA +2 -2
  36. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/RECORD +44 -39
  37. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
  38. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
  39. {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-connect +0 -0
  40. {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-session +0 -0
  41. {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-submit +0 -0
  42. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/WHEEL +0 -0
  43. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE-binary +0 -0
  44. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE.txt +0 -0
  45. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/NOTICE-binary +0 -0
  46. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/top_level.txt +0 -0
@@ -153,11 +153,63 @@ def map_pivot_aggregate(
         used_columns.add(mapped_col[0].snowpark_name)
 
     if len(columns.grouping_expressions()) == 0:
-        result = (
-            input_df_actual.select(*used_columns)
-            .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-            .agg(*columns.aggregation_expressions(unalias=True))
-        )
+        # Snowpark doesn't support multiple aggregations in pivot without groupBy
+        # So we need to perform each aggregation separately and then combine results
+        if len(columns.aggregation_expressions(unalias=True)) > 1:
+            agg_expressions = columns.aggregation_expressions(unalias=True)
+            agg_metadata = columns.aggregation_columns
+            num_agg_functions = len(agg_expressions)
+
+            spark_names = []
+            pivot_results = []
+            for i, agg_expr in enumerate(agg_expressions):
+                pivot_result = (
+                    input_df_actual.select(*used_columns)
+                    .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
+                    .agg(agg_expr)
+                )
+                for col_name in pivot_result.columns:
+                    spark_names.append(
+                        f"{pivot_column_name(col_name)}_{agg_metadata[i].spark_name}"
+                    )
+                pivot_results.append(pivot_result)
+
+            result = pivot_results[0]
+            for pivot_result in pivot_results[1:]:
+                result = result.cross_join(pivot_result)
+
+            pivot_columns_per_agg = len(pivot_results[0].columns)
+            reordered_spark_names = []
+            reordered_snowpark_names = []
+            reordered_types = []
+            column_selectors = []
+
+            for pivot_idx in range(pivot_columns_per_agg):
+                for agg_idx in range(num_agg_functions):
+                    current_pos = agg_idx * pivot_columns_per_agg + pivot_idx
+                    if current_pos < len(spark_names):
+                        idx = current_pos + 1  # 1-based indexing for Snowpark
+                        reordered_spark_names.append(spark_names[current_pos])
+                        reordered_snowpark_names.append(f"${idx}")
+                        reordered_types.append(
+                            result.schema.fields[current_pos].datatype
+                        )
+                        column_selectors.append(snowpark_fn.col(f"${idx}"))
+
+            return DataFrameContainer.create_with_column_mapping(
+                dataframe=result.select(*column_selectors),
+                spark_column_names=reordered_spark_names,
+                snowpark_column_names=reordered_snowpark_names,
+                column_qualifiers=[[]] * len(reordered_spark_names),
+                parent_column_name_map=input_container.column_map,
+                snowpark_column_types=reordered_types,
+            )
+        else:
+            result = (
+                input_df_actual.select(*used_columns)
+                .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
+                .agg(*columns.aggregation_expressions(unalias=True))
+            )
     else:
         result = (
             input_df_actual.group_by(*columns.grouping_expressions())
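
For context, a minimal PySpark sketch of the client-side pattern this branch now handles: a pivot with more than one aggregate function and no grouping columns. The SparkSession `spark`, the sample data, and the column names are illustrative only, not taken from this package.

from pyspark.sql import functions as F

df = spark.createDataFrame(
    [("2023", "A", 10, 1.5), ("2023", "B", 20, 2.5)],
    ["year", "category", "qty", "price"],
)

# Two aggregates under a single pivot with no groupBy columns: the code above
# now runs each aggregate as its own pivot, cross-joins the results, and
# reorders the columns into Spark's "<pivot_value>_<agg>" layout.
result = df.groupBy().pivot("category").agg(F.sum("qty"), F.avg("price"))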
@@ -266,6 +266,7 @@ def map_project(
 
         aliased_col = mapper.col.alias(snowpark_column)
         select_list.append(aliased_col)
+
         new_snowpark_columns.append(snowpark_column)
         new_spark_columns.append(spark_name)
         column_types.extend(mapper.types)
@@ -422,6 +423,7 @@ def map_sort(
     # TODO: sort.isglobal.
     if not order_specified:
         ascending = None
+
     result = input_df.sort(cols, ascending=ascending)
 
     return DataFrameContainer(
@@ -1075,14 +1077,12 @@ def map_group_map(
     snowpark_grouping_expressions: list[snowpark.Column] = []
     typer = ExpressionTyper(input_df)
     group_name_list: list[str] = []
-    qualifiers = []
     for exp in grouping_expressions:
         new_name, snowpark_column = map_single_column_expression(
             exp, input_container.column_map, typer
         )
         snowpark_grouping_expressions.append(snowpark_column.col)
         group_name_list.append(new_name)
-        qualifiers.append(snowpark_column.get_qualifiers())
     if rel.group_map.func.python_udf is None:
         raise ValueError("group_map relation without python udf is not supported")
 
@@ -1124,13 +1124,14 @@ def map_group_map(
     result = input_df.group_by(*snowpark_grouping_expressions).apply_in_pandas(
         callable_func, output_type
     )
-
-    qualifiers.extend([[]] * (len(result.columns) - len(group_name_list)))
+    # The UDTF `apply_in_pandas` generates a new table whose output schema
+    # can be entirely different from that of the input Snowpark DataFrame.
+    # As a result, the output DataFrame should not use qualifiers based on the input group by columns.
     return DataFrameContainer.create_with_column_mapping(
         dataframe=result,
         spark_column_names=[field.name for field in output_type],
         snowpark_column_names=result.columns,
-        column_qualifiers=qualifiers,
+        column_qualifiers=None,
         parent_column_name_map=input_container.column_map,
     )
 
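The qualifier change above reflects that applyInPandas can return a completely new schema. A short PySpark sketch of that situation; the SparkSession `spark`, the sample data, and the output schema are made up for illustration:

import pandas as pd

df = spark.createDataFrame([(1, 10.0), (1, 12.0), (2, 7.0)], ["id", "value"])

def summarize(pdf: pd.DataFrame) -> pd.DataFrame:
    # The returned columns ("id", "n_rows") need not match the input columns,
    # so qualifiers derived from the input grouping columns would be meaningless.
    return pd.DataFrame({"id": [pdf["id"].iloc[0]], "n_rows": [len(pdf)]})

out = df.groupBy("id").applyInPandas(summarize, schema="id long, n_rows long")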
@@ -347,6 +347,13 @@ def map_aggregate(
     raw_groupings: list[tuple[str, TypedColumn]] = []
     raw_aggregations: list[tuple[str, TypedColumn]] = []
 
+    if not is_group_by_all:
+        raw_groupings = [_map_column(exp) for exp in aggregate.grouping_expressions]
+
+        # Set the current grouping columns in context for grouping_id() function
+        grouping_spark_columns = [spark_name for spark_name, _ in raw_groupings]
+        set_current_grouping_columns(grouping_spark_columns)
+
     agg_count = get_sql_aggregate_function_count()
     for exp in aggregate.aggregate_expressions:
         col = _map_column(exp)
@@ -359,13 +366,6 @@ def map_aggregate(
         else:
             agg_count = new_agg_count
 
-    if not is_group_by_all:
-        raw_groupings = [_map_column(exp) for exp in aggregate.grouping_expressions]
-
-        # Set the current grouping columns in context for grouping_id() function
-        grouping_spark_columns = [spark_name for spark_name, _ in raw_groupings]
-        set_current_grouping_columns(grouping_spark_columns)
-
     # Now create column name lists and assign aliases.
     # In case of GROUP BY ALL, even though groupings are a subset of aggregations,
     # they will have their own aliases so we can drop them later.
@@ -374,23 +374,31 @@ def map_aggregate(
     snowpark_columns: list[str] = []
     snowpark_column_types: list[snowpark_types.DataType] = []
 
-    def _add_column(spark_name: str, snowpark_column: TypedColumn) -> snowpark.Column:
-        alias = make_column_names_snowpark_compatible(
-            [spark_name], plan_id, len(spark_columns)
-        )[0]
+    # Use grouping columns directly without aliases
+    groupings = [col.col for _, col in raw_groupings]
+
+    # Create aliases only for aggregation columns
+    aggregations = []
+    for i, (spark_name, snowpark_column) in enumerate(raw_aggregations):
+        alias = make_column_names_snowpark_compatible([spark_name], plan_id, i)[0]
 
         spark_columns.append(spark_name)
         snowpark_columns.append(alias)
         snowpark_column_types.append(snowpark_column.typ)
 
-        return snowpark_column.col.alias(alias)
-
-    groupings = [_add_column(name, col) for name, col in raw_groupings]
-    aggregations = [_add_column(name, col) for name, col in raw_aggregations]
+        aggregations.append(snowpark_column.col.alias(alias))
 
     match aggregate.group_type:
         case snowflake_proto.Aggregate.GROUP_TYPE_GROUPBY:
-            result = input_df.group_by(groupings)
+            if groupings:
+                # Normal GROUP BY with explicit grouping columns
+                result = input_df.group_by(groupings)
+            else:
+                # No explicit GROUP BY - this is an aggregate over the entire table
+                # Use a dummy constant that will be excluded from the final result
+                result = input_df.with_column(
+                    "__dummy_group__", snowpark_fn.lit(1)
+                ).group_by("__dummy_group__")
         case snowflake_proto.Aggregate.GROUP_TYPE_ROLLUP:
            result = input_df.rollup(groupings)
         case snowflake_proto.Aggregate.GROUP_TYPE_CUBE:
@@ -410,28 +418,54 @@ def map_aggregate(
                 f"Unsupported GROUP BY type: {other}"
             )
 
-    result = result.agg(*aggregations)
+    result = result.agg(*aggregations, exclude_grouping_columns=True)
+
+    # If we added a dummy grouping column, make sure it's excluded
+    if not groupings and "__dummy_group__" in result.columns:
+        result = result.drop("__dummy_group__")
+
+    # Apply HAVING condition if present
+    if aggregate.HasField("having_condition"):
+        from snowflake.snowpark_connect.expression.hybrid_column_map import (
+            create_hybrid_column_map_for_having,
+        )
+
+        # Create aggregated DataFrame column map
+        aggregated_column_map = DataFrameContainer.create_with_column_mapping(
+            dataframe=result,
+            spark_column_names=spark_columns,
+            snowpark_column_names=snowpark_columns,
+            snowpark_column_types=snowpark_column_types,
+        ).column_map
+
+        # Create hybrid column map that can resolve both input and aggregate contexts
+        hybrid_map = create_hybrid_column_map_for_having(
+            input_df=input_df,
+            input_column_map=input_container.column_map,
+            aggregated_df=result,
+            aggregated_column_map=aggregated_column_map,
+            aggregate_expressions=list(aggregate.aggregate_expressions),
+            grouping_expressions=list(aggregate.grouping_expressions),
+            spark_columns=spark_columns,
+            raw_aggregations=raw_aggregations,
+        )
+
+        # Map the HAVING condition using hybrid resolution
+        _, having_column = hybrid_map.resolve_expression(aggregate.having_condition)
+
+        # Apply the HAVING filter
+        result = result.filter(having_column.col)
 
     if aggregate.group_type == snowflake_proto.Aggregate.GROUP_TYPE_GROUPING_SETS:
         # Immediately drop extra columns. Unlike other GROUP BY operations,
         # grouping sets don't allow ORDER BY with columns that aren't in the aggregate list.
-        result = result.select(result.columns[-len(spark_columns) :])
+        result = result.select(result.columns[-len(aggregations) :])
 
-    # Build a parent column map that includes groupings.
-    result_container = DataFrameContainer.create_with_column_mapping(
+    # Return only aggregation columns in the column map
+    return DataFrameContainer.create_with_column_mapping(
         dataframe=result,
         spark_column_names=spark_columns,
         snowpark_column_names=snowpark_columns,
         snowpark_column_types=snowpark_column_types,
-    )
-
-    # Drop the groupings.
-    grouping_count = len(groupings)
-
-    return DataFrameContainer.create_with_column_mapping(
-        result.drop(snowpark_columns[:grouping_count]),
-        spark_columns[grouping_count:],
-        snowpark_columns[grouping_count:],
-        snowpark_column_types[grouping_count:],
-        parent_column_name_map=result_container.column_map,
+        parent_column_name_map=input_df._column_map,
     )
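
A hedged example of the kind of query the new HAVING path is meant to serve: the predicate references an aggregate that is not in the SELECT list, so it has to be resolved against both the input columns and the aggregated columns. The table and column names are placeholders, assuming an existing SparkSession `spark`.

result = spark.sql(
    """
    SELECT category, SUM(qty) AS total_qty
    FROM sales
    GROUP BY category
    HAVING AVG(price) > 10
    """
)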
@@ -4,6 +4,7 @@
 
 import json
 import re
+from json import JSONDecodeError
 
 import numpy as np
 import pyarrow as pa
@@ -19,6 +20,7 @@ from snowflake.snowpark_connect.column_name_handler import (
 )
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.type_mapping import (
+    get_python_sql_utils_class,
     map_json_schema_to_snowpark,
     map_pyarrow_to_snowpark_types,
     map_simple_types,
@@ -34,7 +36,12 @@ def parse_local_relation_schema_string(rel: relation_proto.Relation):
     # schema_str can be a dict, or just a type string, e.g. INTEGER.
     schema_str = rel.local_relation.schema
     assert schema_str
-    schema_dict = json.loads(schema_str)
+    try:
+        schema_dict = json.loads(schema_str)
+    except JSONDecodeError:
+        # Legacy Scala clients send unparsed struct type strings like "struct<id:bigint,a:int,b:double>"
+        spark_datatype = get_python_sql_utils_class().parseDataType(schema_str)
+        schema_dict = json.loads(spark_datatype.json())
 
     column_metadata = {}
     if isinstance(schema_dict, dict):
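
The fallback exists because clients send schemas in two encodings. A minimal sketch of both, using only the standard library; the schema strings are illustrative:

import json

json_schema = '{"type":"struct","fields":[{"name":"id","type":"long","nullable":true,"metadata":{}}]}'
ddl_schema = "struct<id:bigint,a:int,b:double>"

json.loads(json_schema)       # modern clients: valid JSON, parses directly
try:
    json.loads(ddl_schema)    # legacy Scala clients: not JSON
except json.JSONDecodeError:
    pass                      # handled by the parseDataType fallback above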
@@ -2,6 +2,7 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 
+
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 from pyspark.errors.exceptions.base import AnalysisException, IllegalArgumentException
@@ -551,6 +552,7 @@ def map_filter(
     _, condition = map_single_column_expression(
         rel.filter.condition, input_container.column_map, typer
     )
+
     result = input_df.filter(condition.col)
 
     return DataFrameContainer(
@@ -77,6 +77,9 @@ from ..expression.map_sql_expression import (
 from ..utils.identifiers import spark_to_sf_single_id
 
 _ctes = ContextVar[dict[str, relation_proto.Relation]]("_ctes", default={})
+_having_condition = ContextVar[expressions_proto.Expression | None](
+    "_having_condition", default=None
+)
 
 
 def _is_sql_select_statement_helper(sql_string: str) -> bool:
@@ -1146,6 +1149,7 @@ def map_logical_plan_relation(
                     grouping_expressions=grouping_expressions,
                     aggregate_expressions=aggregate_expressions,
                     grouping_sets=grouping_sets,
+                    having_condition=_having_condition.get(),
                 )
             )
         )
@@ -1389,12 +1393,25 @@ def map_logical_plan_relation(
                 )
             )
         case "UnresolvedHaving":
-            proto = relation_proto.Relation(
-                filter=relation_proto.Filter(
-                    input=map_logical_plan_relation(rel.child()),
-                    condition=map_logical_plan_expression(rel.havingCondition()),
+            # Store the having condition in context and process the child aggregate
+            child_relation = rel.child()
+            if str(child_relation.getClass().getSimpleName()) != "Aggregate":
+                raise SnowparkConnectNotImplementedError(
+                    "UnresolvedHaving can only be applied to Aggregate relations"
                 )
-            )
+
+            # Store having condition in a context variable for the Aggregate case to pick up
+            having_condition = map_logical_plan_expression(rel.havingCondition())
+
+            # Store in thread-local context (similar to how _ctes works)
+            token = _having_condition.set(having_condition)
+
+            try:
+                # Recursively call map_logical_plan_relation on the child Aggregate
+                # The Aggregate case will pick up the having condition from context
+                proto = map_logical_plan_relation(child_relation, plan_id)
+            finally:
+                _having_condition.reset(token)
         case "UnresolvedHint":
             proto = relation_proto.Relation(
                 hint=relation_proto.Hint(
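
The set/reset around the recursive call is the standard contextvars pattern. A generic, self-contained sketch of the same idea; the names and payloads here are illustrative, not the package's actual types:

from contextvars import ContextVar

_having: ContextVar[object | None] = ContextVar("_having", default=None)

def handle_having(condition, child):
    token = _having.set(condition)      # make the condition visible to nested calls
    try:
        return handle_child(child)      # nested code reads it via _having.get()
    finally:
        _having.reset(token)            # always restore the previous value

def handle_child(child):
    return (child, _having.get())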
@@ -95,7 +95,8 @@ def map_read(
     if len(rel.read.data_source.paths) > 0:
         # Normalize paths to ensure consistent behavior
         clean_source_paths = [
-            str(Path(path)) for path in rel.read.data_source.paths
+            path.rstrip("/") if is_cloud_path(path) else str(Path(path))
+            for path in rel.read.data_source.paths
         ]
 
         result = _read_file(
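
The split between cloud and local paths matters because pathlib collapses the double slash in a URI scheme, while a plain rstrip keeps it intact. A quick illustration (POSIX behavior; the bucket path is a placeholder):

from pathlib import Path

print(str(Path("s3://bucket/data/")))    # 's3:/bucket/data'  -- scheme mangled
print("s3://bucket/data/".rstrip("/"))   # 's3://bucket/data' -- scheme preserved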
@@ -54,10 +54,17 @@ def map_read_parquet(
     if len(paths) == 1:
         df = _read_parquet_with_partitions(session, reader, paths[0])
     else:
+        is_merge_schema = options.config.get("mergeschema")
         df = _read_parquet_with_partitions(session, reader, paths[0])
+        schema_cols = df.columns
         for p in paths[1:]:
             reader._user_schema = None
-            df = df.union_all(_read_parquet_with_partitions(session, reader, p))
+            df = df.union_all_by_name(
+                _read_parquet_with_partitions(session, reader, p),
+                allow_missing_columns=True,
+            )
+        if not is_merge_schema:
+            df = df.select(*schema_cols)
 
     renamed_df, snowpark_column_names = rename_columns_as_snowflake_standard(
         df, rel.common.plan_id
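
This mirrors Spark's mergeSchema reader option: with it, columns from every file are kept; without it, the result is trimmed back to the first file's columns. A hedged PySpark sketch with placeholder paths and an existing SparkSession `spark`:

merged = spark.read.option("mergeSchema", "true").parquet("/data/part1", "/data/part2")
first_file_schema = spark.read.parquet("/data/part1", "/data/part2")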
@@ -398,3 +398,12 @@ class ParquetReaderConfig(ReaderWriterConfig):
             ),
             options,
         )
+
+    def convert_to_snowpark_args(self) -> dict[str, Any]:
+        snowpark_args = super().convert_to_snowpark_args()
+
+        # Should be determined by spark.sql.parquet.binaryAsString, but currently Snowpark Connect only supports
+        # the default value (false). TODO: Add support for spark.sql.parquet.binaryAsString equal to "true".
+        snowpark_args["BINARY_AS_TEXT"] = False
+
+        return snowpark_args
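
For reference, the Spark configuration this option tracks; per the comment above, only the default value is currently honored (assuming an existing SparkSession `spark`):

spark.conf.set("spark.sql.parquet.binaryAsString", "false")  # default; "true" is not yet supported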