snowpark-connect 0.22.1__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of snowpark-connect has been flagged as potentially problematic.

Files changed (42)
  1. snowflake/snowpark_connect/config.py +0 -11
  2. snowflake/snowpark_connect/error/error_utils.py +7 -0
  3. snowflake/snowpark_connect/error/exceptions.py +4 -0
  4. snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
  5. snowflake/snowpark_connect/expression/literal.py +9 -12
  6. snowflake/snowpark_connect/expression/map_cast.py +20 -4
  7. snowflake/snowpark_connect/expression/map_expression.py +8 -1
  8. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
  9. snowflake/snowpark_connect/expression/map_unresolved_function.py +66 -6
  10. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
  11. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
  12. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
  13. snowflake/snowpark_connect/relation/map_aggregate.py +57 -5
  14. snowflake/snowpark_connect/relation/map_column_ops.py +38 -6
  15. snowflake/snowpark_connect/relation/map_extension.py +58 -24
  16. snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
  17. snowflake/snowpark_connect/relation/map_row_ops.py +30 -1
  18. snowflake/snowpark_connect/relation/map_sql.py +22 -5
  19. snowflake/snowpark_connect/relation/read/map_read.py +2 -1
  20. snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
  21. snowflake/snowpark_connect/relation/read/reader_config.py +9 -0
  22. snowflake/snowpark_connect/relation/read/utils.py +7 -6
  23. snowflake/snowpark_connect/relation/utils.py +170 -1
  24. snowflake/snowpark_connect/relation/write/map_write.py +243 -68
  25. snowflake/snowpark_connect/server.py +25 -5
  26. snowflake/snowpark_connect/type_mapping.py +2 -2
  27. snowflake/snowpark_connect/utils/env_utils.py +55 -0
  28. snowflake/snowpark_connect/utils/session.py +21 -0
  29. snowflake/snowpark_connect/version.py +1 -1
  30. snowflake/snowpark_decoder/spark_decoder.py +1 -1
  31. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/METADATA +2 -2
  32. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/RECORD +40 -40
  33. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
  34. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
  35. {snowpark_connect-0.22.1.data → snowpark_connect-0.23.0.data}/scripts/snowpark-connect +0 -0
  36. {snowpark_connect-0.22.1.data → snowpark_connect-0.23.0.data}/scripts/snowpark-session +0 -0
  37. {snowpark_connect-0.22.1.data → snowpark_connect-0.23.0.data}/scripts/snowpark-submit +0 -0
  38. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/WHEEL +0 -0
  39. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE-binary +0 -0
  40. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE.txt +0 -0
  41. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/NOTICE-binary +0 -0
  42. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/top_level.txt +0 -0
@@ -8,7 +8,10 @@ import typing
  import pandas
  import pyspark.sql.connect.proto.common_pb2 as common_proto
  import pyspark.sql.connect.proto.types_pb2 as types_proto
- from snowflake.core.exceptions import NotFoundError
+ from pyspark.sql.connect.client.core import Retrying
+ from snowflake.core.exceptions import APIError, NotFoundError
+ from snowflake.core.schema import Schema
+ from snowflake.core.table import Table, TableColumn

  from snowflake.snowpark import functions
  from snowflake.snowpark._internal.analyzer.analyzer_utils import (
@@ -22,6 +25,7 @@ from snowflake.snowpark_connect.config import (
      global_config,
  )
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+ from snowflake.snowpark_connect.error.exceptions import MaxRetryExceeded
  from snowflake.snowpark_connect.relation.catalogs.abstract_spark_catalog import (
      AbstractSparkCatalog,
      _get_current_snowflake_schema,
@@ -39,6 +43,37 @@ from snowflake.snowpark_connect.utils.telemetry import (
  from snowflake.snowpark_connect.utils.udf_cache import cached_udf


+ def _is_retryable_api_error(e: Exception) -> bool:
+     """
+     Determine if an APIError should be retried.
+
+     Only retry on server errors, rate limiting, and transient network issues.
+     Don't retry on client errors like authentication, authorization, or validation failures.
+     """
+     if not isinstance(e, APIError):
+         return False
+
+     # Check if the error has a status_code attribute
+     if hasattr(e, "status_code"):
+         # Retry on server errors (5xx), rate limiting (429), and some client errors (400)
+         # 400 can be transient in some cases (like the original error trace shows)
+         return e.status_code in [400, 429, 500, 502, 503, 504]
+
+     # For APIErrors without explicit status codes, check the message
+     error_msg = str(e).lower()
+     retryable_patterns = [
+         "timeout",
+         "connection",
+         "network",
+         "unavailable",
+         "temporary",
+         "rate limit",
+         "throttle",
+     ]
+
+     return any(pattern in error_msg for pattern in retryable_patterns)
+
+
  def _normalize_identifier(identifier: str | None) -> str | None:
      if identifier is None:
          return None
@@ -73,10 +108,25 @@ class SnowflakeCatalog(AbstractSparkCatalog):
          )
          sp_catalog = get_or_create_snowpark_session().catalog

-         dbs = sp_catalog.list_schemas(
-             database=sf_quote(sf_database),
-             pattern=_normalize_identifier(sf_schema),
-         )
+         dbs: list[Schema] | None = None
+         for attempt in Retrying(
+             max_retries=5,
+             initial_backoff=100, # 100ms
+             max_backoff=5000, # 5 s
+             backoff_multiplier=2.0,
+             jitter=100,
+             min_jitter_threshold=200,
+             can_retry=_is_retryable_api_error,
+         ):
+             with attempt:
+                 dbs = sp_catalog.list_schemas(
+                     database=sf_quote(sf_database),
+                     pattern=_normalize_identifier(sf_schema),
+                 )
+         if dbs is None:
+             raise MaxRetryExceeded(
+                 f"Failed to fetch databases {f'with pattern {pattern} ' if pattern is not None else ''}after all retry attempts"
+             )
          names: list[str] = list()
          catalogs: list[str] = list()
          descriptions: list[str | None] = list()
@@ -112,9 +162,24 @@ class SnowflakeCatalog(AbstractSparkCatalog):
          )
          sp_catalog = get_or_create_snowpark_session().catalog

-         db = sp_catalog.get_schema(
-             schema=sf_quote(sf_schema), database=sf_quote(sf_database)
-         )
+         db: Schema | None = None
+         for attempt in Retrying(
+             max_retries=5,
+             initial_backoff=100, # 100ms
+             max_backoff=5000, # 5 s
+             backoff_multiplier=2.0,
+             jitter=100,
+             min_jitter_threshold=200,
+             can_retry=_is_retryable_api_error,
+         ):
+             with attempt:
+                 db = sp_catalog.get_schema(
+                     schema=sf_quote(sf_schema), database=sf_quote(sf_database)
+                 )
+         if db is None:
+             raise MaxRetryExceeded(
+                 f"Failed to fetch database {spark_dbName} after all retry attempts"
+             )

          name = unquote_if_quoted(db.name)
          return pandas.DataFrame(
@@ -241,11 +306,27 @@ class SnowflakeCatalog(AbstractSparkCatalog):
                  "Calling into another catalog is not currently supported"
              )

-         table = sp_catalog.get_table(
-             database=sf_quote(sf_database),
-             schema=sf_quote(sf_schema),
-             table_name=sf_quote(table_name),
-         )
+         table: Table | None = None
+         for attempt in Retrying(
+             max_retries=5,
+             initial_backoff=100, # 100ms
+             max_backoff=5000, # 5 s
+             backoff_multiplier=2.0,
+             jitter=100,
+             min_jitter_threshold=200,
+             can_retry=_is_retryable_api_error,
+         ):
+             with attempt:
+                 table = sp_catalog.get_table(
+                     database=sf_quote(sf_database),
+                     schema=sf_quote(sf_schema),
+                     table_name=sf_quote(table_name),
+                 )
+
+         if table is None:
+             raise MaxRetryExceeded(
+                 f"Failed to fetch table {spark_tableName} after all retry attempts"
+             )

          return pandas.DataFrame(
              {
@@ -286,6 +367,7 @@ class SnowflakeCatalog(AbstractSparkCatalog):
      ) -> pandas.DataFrame:
          """List all columns in a table/view, optionally database name filter can be provided."""
          sp_catalog = get_or_create_snowpark_session().catalog
+         columns: list[TableColumn] | None = None
          if spark_dbName is None:
              catalog, sf_database, sf_schema, sf_table = _process_multi_layer_identifier(
                  spark_tableName
@@ -294,15 +376,39 @@
                  raise SnowparkConnectNotImplementedError(
                      "Calling into another catalog is not currently supported"
                  )
-             columns = sp_catalog.list_columns(
-                 database=sf_quote(sf_database),
-                 schema=sf_quote(sf_schema),
-                 table_name=sf_quote(sf_table),
-             )
+             for attempt in Retrying(
+                 max_retries=5,
+                 initial_backoff=100, # 100ms
+                 max_backoff=5000, # 5 s
+                 backoff_multiplier=2.0,
+                 jitter=100,
+                 min_jitter_threshold=200,
+                 can_retry=_is_retryable_api_error,
+             ):
+                 with attempt:
+                     columns = sp_catalog.list_columns(
+                         database=sf_quote(sf_database),
+                         schema=sf_quote(sf_schema),
+                         table_name=sf_quote(sf_table),
+                     )
          else:
-             columns = sp_catalog.list_columns(
-                 schema=sf_quote(spark_dbName),
-                 table_name=sf_quote(spark_tableName),
+             for attempt in Retrying(
+                 max_retries=5,
+                 initial_backoff=100, # 100ms
+                 max_backoff=5000, # 5 s
+                 backoff_multiplier=2.0,
+                 jitter=100,
+                 min_jitter_threshold=200,
+                 can_retry=_is_retryable_api_error,
+             ):
+                 with attempt:
+                     columns = sp_catalog.list_columns(
+                         schema=sf_quote(spark_dbName),
+                         table_name=sf_quote(spark_tableName),
+                     )
+         if columns is None:
+             raise MaxRetryExceeded(
+                 f"Failed to fetch columns of {spark_tableName} after all retry attempts"
              )
          names: list[str] = list()
          descriptions: list[str | None] = list()
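
Note on the retry logic above: the catalog calls (list_schemas, get_schema, get_table, list_columns) are now wrapped in pyspark's Retrying loop with exponential backoff, gated by _is_retryable_api_error. A minimal, self-contained sketch of that classification behaviour follows; it is illustration only, FakeAPIError is a stand-in for snowflake.core.exceptions.APIError so the snippet runs offline, and the real predicate keys off hasattr(e, "status_code") rather than a None check.

    class FakeAPIError(Exception):
        """Illustration only: stands in for snowflake.core.exceptions.APIError."""
        def __init__(self, message="", status_code=None):
            super().__init__(message)
            self.status_code = status_code

    def is_retryable(e):
        # Same classification idea as _is_retryable_api_error above, slightly simplified.
        if not isinstance(e, FakeAPIError):
            return False
        if e.status_code is not None:
            return e.status_code in (400, 429, 500, 502, 503, 504)
        msg = str(e).lower()
        return any(p in msg for p in ("timeout", "connection", "network",
                                      "unavailable", "temporary", "rate limit", "throttle"))

    assert is_retryable(FakeAPIError(status_code=503))      # server error -> retried
    assert not is_retryable(FakeAPIError(status_code=401))  # auth error -> surfaced immediately
    assert is_retryable(FakeAPIError("connection reset"))   # transient message -> retried
    assert not is_retryable(ValueError("bad input"))        # not an APIError -> surfaced
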
@@ -153,11 +153,63 @@ def map_pivot_aggregate(
          used_columns.add(mapped_col[0].snowpark_name)

      if len(columns.grouping_expressions()) == 0:
-         result = (
-             input_df_actual.select(*used_columns)
-             .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-             .agg(*columns.aggregation_expressions(unalias=True))
-         )
+         # Snowpark doesn't support multiple aggregations in pivot without groupBy
+         # So we need to perform each aggregation separately and then combine results
+         if len(columns.aggregation_expressions(unalias=True)) > 1:
+             agg_expressions = columns.aggregation_expressions(unalias=True)
+             agg_metadata = columns.aggregation_columns
+             num_agg_functions = len(agg_expressions)
+
+             spark_names = []
+             pivot_results = []
+             for i, agg_expr in enumerate(agg_expressions):
+                 pivot_result = (
+                     input_df_actual.select(*used_columns)
+                     .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
+                     .agg(agg_expr)
+                 )
+                 for col_name in pivot_result.columns:
+                     spark_names.append(
+                         f"{pivot_column_name(col_name)}_{agg_metadata[i].spark_name}"
+                     )
+                 pivot_results.append(pivot_result)
+
+             result = pivot_results[0]
+             for pivot_result in pivot_results[1:]:
+                 result = result.cross_join(pivot_result)
+
+             pivot_columns_per_agg = len(pivot_results[0].columns)
+             reordered_spark_names = []
+             reordered_snowpark_names = []
+             reordered_types = []
+             column_selectors = []
+
+             for pivot_idx in range(pivot_columns_per_agg):
+                 for agg_idx in range(num_agg_functions):
+                     current_pos = agg_idx * pivot_columns_per_agg + pivot_idx
+                     if current_pos < len(spark_names):
+                         idx = current_pos + 1 # 1-based indexing for Snowpark
+                         reordered_spark_names.append(spark_names[current_pos])
+                         reordered_snowpark_names.append(f"${idx}")
+                         reordered_types.append(
+                             result.schema.fields[current_pos].datatype
+                         )
+                         column_selectors.append(snowpark_fn.col(f"${idx}"))
+
+             return DataFrameContainer.create_with_column_mapping(
+                 dataframe=result.select(*column_selectors),
+                 spark_column_names=reordered_spark_names,
+                 snowpark_column_names=reordered_snowpark_names,
+                 column_qualifiers=[[]] * len(reordered_spark_names),
+                 parent_column_name_map=input_container.column_map,
+                 snowpark_column_types=reordered_types,
+             )
+         else:
+             result = (
+                 input_df_actual.select(*used_columns)
+                 .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
+                 .agg(*columns.aggregation_expressions(unalias=True))
+             )
      else:
          result = (
              input_df_actual.group_by(*columns.grouping_expressions())
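
The pivot branch above targets Spark usage like the sketch below (illustrative only; assumes an active Snowpark Connect session bound to spark). With no grouping columns and more than one aggregation, each aggregation is now pivoted separately, the per-aggregation results are cross-joined, and columns are reordered and renamed as <pivot_value>_<aggregation>.

    from pyspark.sql import functions as F

    sales = spark.createDataFrame(
        [("2024", "A", 10, 1.5), ("2024", "B", 20, 2.5)],
        ["year", "cat", "qty", "price"],
    )
    # Two aggregations in a pivot without groupBy columns.
    wide = sales.groupBy().pivot("cat", ["A", "B"]).agg(F.sum("qty"), F.avg("price"))
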
@@ -6,10 +6,12 @@ import ast
  import json
  import sys
  from collections import defaultdict
+ from copy import copy

  import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
  import pyspark.sql.connect.proto.relations_pb2 as relation_proto
  import pyspark.sql.connect.proto.types_pb2 as types_proto
+ from pyspark.errors import PySparkValueError
  from pyspark.errors.exceptions.base import AnalysisException
  from pyspark.serializers import CloudPickleSerializer

@@ -44,6 +46,7 @@ from snowflake.snowpark_connect.expression.typer import ExpressionTyper
  from snowflake.snowpark_connect.relation.map_relation import map_relation
  from snowflake.snowpark_connect.relation.utils import (
      TYPE_MAP_FOR_TO_SCHEMA,
+     can_sort_be_flattened,
      snowpark_functions_col,
  )
  from snowflake.snowpark_connect.type_mapping import (
@@ -266,6 +269,7 @@ def map_project(

          aliased_col = mapper.col.alias(snowpark_column)
          select_list.append(aliased_col)
+
          new_snowpark_columns.append(snowpark_column)
          new_spark_columns.append(spark_name)
          column_types.extend(mapper.types)
@@ -342,6 +346,12 @@ def map_sort(

      sort_order = sort.order

+     if not sort_order:
+         raise PySparkValueError(
+             error_class="CANNOT_BE_EMPTY",
+             message="At least one column must be specified.",
+         )
+
      if len(sort_order) == 1:
          parsed_col_name = split_fully_qualified_spark_name(
              sort_order[0].child.unresolved_attribute.unparsed_identifier
@@ -422,7 +432,30 @@
      # TODO: sort.isglobal.
      if not order_specified:
          ascending = None
-     result = input_df.sort(cols, ascending=ascending)
+
+     select_statement = getattr(input_df, "_select_statement", None)
+     sort_expressions = [c._expression for c in cols]
+     if (
+         can_sort_be_flattened(select_statement, *sort_expressions)
+         and input_df._ops_after_agg is None
+     ):
+         # "flattened" order by that will allow using dropped columns
+         new = copy(select_statement)
+         new.from_ = select_statement.from_.to_subqueryable()
+         new.pre_actions = new.from_.pre_actions
+         new.post_actions = new.from_.post_actions
+         new.order_by = sort_expressions + (select_statement.order_by or [])
+         new.column_states = select_statement.column_states
+         new._merge_projection_complexity_with_subquery = False
+         new.df_ast_ids = (
+             select_statement.df_ast_ids.copy()
+             if select_statement.df_ast_ids is not None
+             else None
+         )
+         new.attributes = select_statement.attributes
+         result = input_df._with_plan(new)
+     else:
+         result = input_df.sort(cols, ascending=ascending)

      return DataFrameContainer(
          result,
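
A usage sketch of what the flattened ORDER BY above is for (illustrative; assumes an active session spark): Spark lets you sort by a column that an earlier projection dropped, and the flattened plan keeps the ORDER BY in the same SELECT so that column stays addressable instead of being hidden behind a subquery.

    from pyspark.sql import functions as F

    people = spark.createDataFrame([("a", 30), ("b", 25)], ["name", "age"])
    names_by_age = people.select("name").orderBy(F.col("age"))  # "age" was dropped by the select
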
@@ -1075,14 +1108,12 @@ def map_group_map(
      snowpark_grouping_expressions: list[snowpark.Column] = []
      typer = ExpressionTyper(input_df)
      group_name_list: list[str] = []
-     qualifiers = []
      for exp in grouping_expressions:
          new_name, snowpark_column = map_single_column_expression(
              exp, input_container.column_map, typer
          )
          snowpark_grouping_expressions.append(snowpark_column.col)
          group_name_list.append(new_name)
-         qualifiers.append(snowpark_column.get_qualifiers())
      if rel.group_map.func.python_udf is None:
          raise ValueError("group_map relation without python udf is not supported")

@@ -1124,13 +1155,14 @@
      result = input_df.group_by(*snowpark_grouping_expressions).apply_in_pandas(
          callable_func, output_type
      )
-
-     qualifiers.extend([[]] * (len(result.columns) - len(group_name_list)))
+     # The UDTF `apply_in_pandas` generates a new table whose output schema
+     # can be entirely different from that of the input Snowpark DataFrame.
+     # As a result, the output DataFrame should not use qualifiers based on the input group by columns.
      return DataFrameContainer.create_with_column_mapping(
          dataframe=result,
          spark_column_names=[field.name for field in output_type],
          snowpark_column_names=result.columns,
-         column_qualifiers=qualifiers,
+         column_qualifiers=None,
          parent_column_name_map=input_container.column_map,
      )

@@ -374,23 +374,31 @@ def map_aggregate(
      snowpark_columns: list[str] = []
      snowpark_column_types: list[snowpark_types.DataType] = []

-     def _add_column(spark_name: str, snowpark_column: TypedColumn) -> snowpark.Column:
-         alias = make_column_names_snowpark_compatible(
-             [spark_name], plan_id, len(spark_columns)
-         )[0]
+     # Use grouping columns directly without aliases
+     groupings = [col.col for _, col in raw_groupings]
+
+     # Create aliases only for aggregation columns
+     aggregations = []
+     for i, (spark_name, snowpark_column) in enumerate(raw_aggregations):
+         alias = make_column_names_snowpark_compatible([spark_name], plan_id, i)[0]

          spark_columns.append(spark_name)
          snowpark_columns.append(alias)
          snowpark_column_types.append(snowpark_column.typ)

-         return snowpark_column.col.alias(alias)
-
-     groupings = [_add_column(name, col) for name, col in raw_groupings]
-     aggregations = [_add_column(name, col) for name, col in raw_aggregations]
+         aggregations.append(snowpark_column.col.alias(alias))

      match aggregate.group_type:
          case snowflake_proto.Aggregate.GROUP_TYPE_GROUPBY:
-             result = input_df.group_by(groupings)
+             if groupings:
+                 # Normal GROUP BY with explicit grouping columns
+                 result = input_df.group_by(groupings)
+             else:
+                 # No explicit GROUP BY - this is an aggregate over the entire table
+                 # Use a dummy constant that will be excluded from the final result
+                 result = input_df.with_column(
+                     "__dummy_group__", snowpark_fn.lit(1)
+                 ).group_by("__dummy_group__")
          case snowflake_proto.Aggregate.GROUP_TYPE_ROLLUP:
              result = input_df.rollup(groupings)
          case snowflake_proto.Aggregate.GROUP_TYPE_CUBE:
@@ -410,28 +418,54 @@
                  f"Unsupported GROUP BY type: {other}"
              )

-     result = result.agg(*aggregations)
+     result = result.agg(*aggregations, exclude_grouping_columns=True)
+
+     # If we added a dummy grouping column, make sure it's excluded
+     if not groupings and "__dummy_group__" in result.columns:
+         result = result.drop("__dummy_group__")
+
+     # Apply HAVING condition if present
+     if aggregate.HasField("having_condition"):
+         from snowflake.snowpark_connect.expression.hybrid_column_map import (
+             create_hybrid_column_map_for_having,
+         )
+
+         # Create aggregated DataFrame column map
+         aggregated_column_map = DataFrameContainer.create_with_column_mapping(
+             dataframe=result,
+             spark_column_names=spark_columns,
+             snowpark_column_names=snowpark_columns,
+             snowpark_column_types=snowpark_column_types,
+         ).column_map
+
+         # Create hybrid column map that can resolve both input and aggregate contexts
+         hybrid_map = create_hybrid_column_map_for_having(
+             input_df=input_df,
+             input_column_map=input_container.column_map,
+             aggregated_df=result,
+             aggregated_column_map=aggregated_column_map,
+             aggregate_expressions=list(aggregate.aggregate_expressions),
+             grouping_expressions=list(aggregate.grouping_expressions),
+             spark_columns=spark_columns,
+             raw_aggregations=raw_aggregations,
+         )
+
+         # Map the HAVING condition using hybrid resolution
+         _, having_column = hybrid_map.resolve_expression(aggregate.having_condition)
+
+         # Apply the HAVING filter
+         result = result.filter(having_column.col)

      if aggregate.group_type == snowflake_proto.Aggregate.GROUP_TYPE_GROUPING_SETS:
          # Immediately drop extra columns. Unlike other GROUP BY operations,
          # grouping sets don't allow ORDER BY with columns that aren't in the aggregate list.
-         result = result.select(result.columns[-len(spark_columns) :])
+         result = result.select(result.columns[-len(aggregations) :])

-     # Build a parent column map that includes groupings.
-     result_container = DataFrameContainer.create_with_column_mapping(
+     # Return only aggregation columns in the column map
+     return DataFrameContainer.create_with_column_mapping(
          dataframe=result,
          spark_column_names=spark_columns,
          snowpark_column_names=snowpark_columns,
          snowpark_column_types=snowpark_column_types,
-     )
-
-     # Drop the groupings.
-     grouping_count = len(groupings)
-
-     return DataFrameContainer.create_with_column_mapping(
-         result.drop(snowpark_columns[:grouping_count]),
-         spark_columns[grouping_count:],
-         snowpark_columns[grouping_count:],
-         snowpark_column_types[grouping_count:],
-         parent_column_name_map=result_container.column_map,
+         parent_column_name_map=input_df._column_map,
      )
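
Two client-side patterns the aggregate changes above target (illustrative only; assumes an active session spark): a global aggregate with no GROUP BY, which is now executed by grouping on a dummy constant that is dropped afterwards, and a HAVING clause, whose predicate is resolved through the new hybrid column map because it can reference both input columns and aggregate expressions.

    from pyspark.sql import functions as F

    df = spark.createDataFrame([("a", 1), ("a", 2), ("b", 5)], ["k", "v"])
    df.createOrReplaceTempView("t")

    # Global aggregate: no grouping expressions at all.
    totals = df.agg(F.sum("v").alias("total"), F.count("*").alias("n"))

    # HAVING over an aggregate expression.
    filtered = spark.sql("SELECT k, SUM(v) AS s FROM t GROUP BY k HAVING SUM(v) > 2")
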
@@ -4,6 +4,7 @@

  import json
  import re
+ from json import JSONDecodeError

  import numpy as np
  import pyarrow as pa
@@ -19,6 +20,7 @@ from snowflake.snowpark_connect.column_name_handler import (
  )
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.type_mapping import (
+     get_python_sql_utils_class,
      map_json_schema_to_snowpark,
      map_pyarrow_to_snowpark_types,
      map_simple_types,
@@ -34,7 +36,12 @@ def parse_local_relation_schema_string(rel: relation_proto.Relation):
      # schema_str can be a dict, or just a type string, e.g. INTEGER.
      schema_str = rel.local_relation.schema
      assert schema_str
-     schema_dict = json.loads(schema_str)
+     try:
+         schema_dict = json.loads(schema_str)
+     except JSONDecodeError:
+         # Legacy scala clients sends unparsed struct type strings like "struct<id:bigint,a:int,b:double>"
+         spark_datatype = get_python_sql_utils_class().parseDataType(schema_str)
+         schema_dict = json.loads(spark_datatype.json())

      column_metadata = {}
      if isinstance(schema_dict, dict):
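
A quick, runnable illustration of the fallback above: legacy Scala clients send a DDL-style schema string rather than JSON, so json.loads raises JSONDecodeError and the code falls back to Spark's DataType parser.

    import json

    legacy_schema = "struct<id:bigint,a:int,b:double>"  # not valid JSON
    try:
        json.loads(legacy_schema)
    except json.JSONDecodeError:
        # This is the branch the new except clause handles; the real code re-parses the
        # string with Spark's DataType utilities and converts the result back to JSON.
        print("falling back to Spark's DataType parser")
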
@@ -1,6 +1,7 @@
  #
  # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
  #
+ from copy import copy

  import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
  import pyspark.sql.connect.proto.relations_pb2 as relation_proto
@@ -8,6 +9,7 @@ from pyspark.errors.exceptions.base import AnalysisException, IllegalArgumentExc

  import snowflake.snowpark_connect.relation.utils as utils
  from snowflake import snowpark
+ from snowflake.snowpark._internal.analyzer.binary_expression import And
  from snowflake.snowpark.functions import col, expr as snowpark_expr
  from snowflake.snowpark.types import (
      BooleanType,
@@ -29,6 +31,7 @@ from snowflake.snowpark_connect.expression.map_expression import (
  )
  from snowflake.snowpark_connect.expression.typer import ExpressionTyper
  from snowflake.snowpark_connect.relation.map_relation import map_relation
+ from snowflake.snowpark_connect.relation.utils import can_filter_be_flattened
  from snowflake.snowpark_connect.utils.telemetry import (
      SnowparkConnectNotImplementedError,
  )
@@ -551,7 +554,33 @@ def map_filter(
      _, condition = map_single_column_expression(
          rel.filter.condition, input_container.column_map, typer
      )
-     result = input_df.filter(condition.col)
+
+     select_statement = getattr(input_df, "_select_statement", None)
+     condition_exp = condition.col._expression
+     if (
+         can_filter_be_flattened(select_statement, condition_exp)
+         and input_df._ops_after_agg is None
+     ):
+         new = copy(select_statement)
+         new.from_ = select_statement.from_.to_subqueryable()
+         new.pre_actions = new.from_.pre_actions
+         new.post_actions = new.from_.post_actions
+         new.column_states = select_statement.column_states
+         new.where = (
+             And(select_statement.where, condition_exp)
+             if select_statement.where is not None
+             else condition_exp
+         )
+         new._merge_projection_complexity_with_subquery = False
+         new.df_ast_ids = (
+             select_statement.df_ast_ids.copy()
+             if select_statement.df_ast_ids is not None
+             else None
+         )
+         new.attributes = select_statement.attributes
+         result = input_df._with_plan(new)
+     else:
+         result = input_df.filter(condition.col)

      return DataFrameContainer(
          result,
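
Illustrative usage for the filter flattening above (assumes an active session spark): when can_filter_be_flattened allows it, chained filters are merged into a single WHERE clause (existing predicate AND new predicate) on the same SELECT instead of nesting one subquery per .filter() call.

    from pyspark.sql import functions as F

    df = spark.range(100).select((F.col("id") * 2).alias("doubled"), "id")
    narrowed = df.filter(F.col("doubled") > 10).filter(F.col("id") < 40)
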
@@ -77,6 +77,9 @@ from ..expression.map_sql_expression import (
  from ..utils.identifiers import spark_to_sf_single_id

  _ctes = ContextVar[dict[str, relation_proto.Relation]]("_ctes", default={})
+ _having_condition = ContextVar[expressions_proto.Expression | None](
+     "_having_condition", default=None
+ )


  def _is_sql_select_statement_helper(sql_string: str) -> bool:
@@ -1146,6 +1149,7 @@ def map_logical_plan_relation(
                          grouping_expressions=grouping_expressions,
                          aggregate_expressions=aggregate_expressions,
                          grouping_sets=grouping_sets,
+                         having_condition=_having_condition.get(),
                      )
                  )
              )
@@ -1389,12 +1393,25 @@
              )
          )
      case "UnresolvedHaving":
-         proto = relation_proto.Relation(
-             filter=relation_proto.Filter(
-                 input=map_logical_plan_relation(rel.child()),
-                 condition=map_logical_plan_expression(rel.havingCondition()),
+         # Store the having condition in context and process the child aggregate
+         child_relation = rel.child()
+         if str(child_relation.getClass().getSimpleName()) != "Aggregate":
+             raise SnowparkConnectNotImplementedError(
+                 "UnresolvedHaving can only be applied to Aggregate relations"
              )
-         )
+
+         # Store having condition in a context variable for the Aggregate case to pick up
+         having_condition = map_logical_plan_expression(rel.havingCondition())
+
+         # Store in thread-local context (similar to how _ctes works)
+         token = _having_condition.set(having_condition)
+
+         try:
+             # Recursively call map_logical_plan_relation on the child Aggregate
+             # The Aggregate case will pick up the having condition from context
+             proto = map_logical_plan_relation(child_relation, plan_id)
+         finally:
+             _having_condition.reset(token)
      case "UnresolvedHint":
          proto = relation_proto.Relation(
              hint=relation_proto.Hint(
@@ -95,7 +95,8 @@ def map_read(
      if len(rel.read.data_source.paths) > 0:
          # Normalize paths to ensure consistent behavior
          clean_source_paths = [
-             str(Path(path)) for path in rel.read.data_source.paths
+             path.rstrip("/") if is_cloud_path(path) else str(Path(path))
+             for path in rel.read.data_source.paths
          ]

          result = _read_file(
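
Why cloud paths are special-cased in the hunk above (runnable illustration): pathlib collapses the double slash after the URI scheme, so only local paths go through Path(), while cloud URIs just have a trailing slash stripped. The sketch uses PurePosixPath to stay platform-independent.

    from pathlib import PurePosixPath

    print(str(PurePosixPath("s3://bucket/data/")))  # 's3:/bucket/data' -- scheme gets mangled
    print("s3://bucket/data/".rstrip("/"))          # 's3://bucket/data' -- what the new code does
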
@@ -54,10 +54,17 @@ def map_read_parquet(
      if len(paths) == 1:
          df = _read_parquet_with_partitions(session, reader, paths[0])
      else:
+         is_merge_schema = options.config.get("mergeschema")
          df = _read_parquet_with_partitions(session, reader, paths[0])
+         schema_cols = df.columns
          for p in paths[1:]:
              reader._user_schema = None
-             df = df.union_all(_read_parquet_with_partitions(session, reader, p))
+             df = df.union_all_by_name(
+                 _read_parquet_with_partitions(session, reader, p),
+                 allow_missing_columns=True,
+             )
+         if not is_merge_schema:
+             df = df.select(*schema_cols)

      renamed_df, snowpark_column_names = rename_columns_as_snowflake_standard(
          df, rel.common.plan_id
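
Client-side reading this targets (illustrative; the paths are hypothetical and an active session spark is assumed): with mergeSchema enabled, multiple Parquet files are now unioned by column name with missing columns allowed; without the option, the result is trimmed back to the first file's columns.

    df_merged = (
        spark.read.option("mergeSchema", "true")
        .parquet("path/to/part1", "path/to/part2")
    )
    df_default = spark.read.parquet("path/to/part1", "path/to/part2")
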
@@ -398,3 +398,12 @@ class ParquetReaderConfig(ReaderWriterConfig):
              ),
              options,
          )
+
+     def convert_to_snowpark_args(self) -> dict[str, Any]:
+         snowpark_args = super().convert_to_snowpark_args()
+
+         # Should be determined by spark.sql.parquet.binaryAsString, but currently Snowpark Connect only supports
+         # the default value (false). TODO: Add support for spark.sql.parquet.binaryAsString equal to "true".
+         snowpark_args["BINARY_AS_TEXT"] = False
+
+         return snowpark_args