snowpark-connect 0.20.2__py3-none-any.whl → 0.22.1__py3-none-any.whl

This diff compares the published contents of the two package versions as they appear in their public registry. It is provided for informational purposes only.

This release of snowpark-connect is flagged as potentially problematic.

Files changed (84)
  1. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +3 -2
  2. snowflake/snowpark_connect/column_name_handler.py +6 -65
  3. snowflake/snowpark_connect/config.py +47 -17
  4. snowflake/snowpark_connect/dataframe_container.py +242 -0
  5. snowflake/snowpark_connect/error/error_utils.py +25 -0
  6. snowflake/snowpark_connect/execute_plan/map_execution_command.py +13 -23
  7. snowflake/snowpark_connect/execute_plan/map_execution_root.py +9 -5
  8. snowflake/snowpark_connect/expression/map_extension.py +2 -1
  9. snowflake/snowpark_connect/expression/map_udf.py +4 -4
  10. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +8 -7
  11. snowflake/snowpark_connect/expression/map_unresolved_function.py +481 -170
  12. snowflake/snowpark_connect/expression/map_unresolved_star.py +8 -8
  13. snowflake/snowpark_connect/expression/map_update_fields.py +1 -1
  14. snowflake/snowpark_connect/expression/typer.py +6 -6
  15. snowflake/snowpark_connect/proto/control_pb2.py +17 -16
  16. snowflake/snowpark_connect/proto/control_pb2.pyi +17 -17
  17. snowflake/snowpark_connect/proto/control_pb2_grpc.py +12 -63
  18. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +15 -14
  19. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +19 -14
  20. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
  21. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +27 -26
  22. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +74 -68
  23. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
  24. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +5 -5
  25. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +25 -17
  26. snowflake/snowpark_connect/relation/map_aggregate.py +170 -61
  27. snowflake/snowpark_connect/relation/map_catalog.py +2 -2
  28. snowflake/snowpark_connect/relation/map_column_ops.py +227 -145
  29. snowflake/snowpark_connect/relation/map_crosstab.py +25 -6
  30. snowflake/snowpark_connect/relation/map_extension.py +81 -56
  31. snowflake/snowpark_connect/relation/map_join.py +72 -63
  32. snowflake/snowpark_connect/relation/map_local_relation.py +35 -20
  33. snowflake/snowpark_connect/relation/map_map_partitions.py +24 -17
  34. snowflake/snowpark_connect/relation/map_relation.py +22 -16
  35. snowflake/snowpark_connect/relation/map_row_ops.py +232 -146
  36. snowflake/snowpark_connect/relation/map_sample_by.py +15 -8
  37. snowflake/snowpark_connect/relation/map_show_string.py +42 -5
  38. snowflake/snowpark_connect/relation/map_sql.py +141 -237
  39. snowflake/snowpark_connect/relation/map_stats.py +88 -39
  40. snowflake/snowpark_connect/relation/map_subquery_alias.py +13 -14
  41. snowflake/snowpark_connect/relation/map_udtf.py +10 -13
  42. snowflake/snowpark_connect/relation/read/map_read.py +8 -3
  43. snowflake/snowpark_connect/relation/read/map_read_csv.py +7 -7
  44. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +7 -7
  45. snowflake/snowpark_connect/relation/read/map_read_json.py +19 -8
  46. snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -7
  47. snowflake/snowpark_connect/relation/read/map_read_socket.py +7 -3
  48. snowflake/snowpark_connect/relation/read/map_read_table.py +25 -16
  49. snowflake/snowpark_connect/relation/read/map_read_text.py +7 -7
  50. snowflake/snowpark_connect/relation/read/reader_config.py +1 -0
  51. snowflake/snowpark_connect/relation/utils.py +11 -5
  52. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +15 -12
  53. snowflake/snowpark_connect/relation/write/map_write.py +259 -56
  54. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +3 -2
  55. snowflake/snowpark_connect/server.py +43 -4
  56. snowflake/snowpark_connect/type_mapping.py +6 -23
  57. snowflake/snowpark_connect/utils/cache.py +27 -22
  58. snowflake/snowpark_connect/utils/context.py +33 -17
  59. snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
  60. snowflake/snowpark_connect/utils/{attribute_handling.py → identifiers.py} +47 -0
  61. snowflake/snowpark_connect/utils/session.py +41 -38
  62. snowflake/snowpark_connect/utils/telemetry.py +214 -63
  63. snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
  64. snowflake/snowpark_connect/version.py +1 -1
  65. snowflake/snowpark_decoder/__init__.py +0 -0
  66. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
  67. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
  68. snowflake/snowpark_decoder/dp_session.py +111 -0
  69. snowflake/snowpark_decoder/spark_decoder.py +76 -0
  70. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/METADATA +6 -4
  71. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/RECORD +83 -69
  72. snowpark_connect-0.22.1.dist-info/licenses/LICENSE-binary +568 -0
  73. snowpark_connect-0.22.1.dist-info/licenses/NOTICE-binary +1533 -0
  74. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/top_level.txt +1 -0
  75. spark/__init__.py +0 -0
  76. spark/connect/__init__.py +0 -0
  77. spark/connect/envelope_pb2.py +31 -0
  78. spark/connect/envelope_pb2.pyi +46 -0
  79. snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
  80. {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-connect +0 -0
  81. {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-session +0 -0
  82. {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-submit +0 -0
  83. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/WHEEL +0 -0
  84. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
snowflake/snowpark_connect/relation/map_aggregate.py

@@ -4,15 +4,19 @@
 
 import re
 from dataclasses import dataclass
+from typing import Optional
 
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 
+import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
+from snowflake.snowpark import Column
+from snowflake.snowpark._internal.analyzer.unary_expression import Alias
 from snowflake.snowpark.types import DataType
 from snowflake.snowpark_connect.column_name_handler import (
     make_column_names_snowpark_compatible,
-    with_column_map,
 )
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
 from snowflake.snowpark_connect.expression.map_expression import (
     map_single_column_expression,
@@ -20,111 +24,171 @@ from snowflake.snowpark_connect.expression.map_expression import (
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.typed_column import TypedColumn
-from snowflake.snowpark_connect.utils.context import temporary_pivot_expression
+from snowflake.snowpark_connect.utils.context import (
+    get_is_evaluating_sql,
+    set_current_grouping_columns,
+    temporary_pivot_expression,
+)
 
 
-def map_group_by_aggregate(rel: relation_proto.Relation) -> snowpark.DataFrame:
+def map_group_by_aggregate(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     Groups the DataFrame using the specified columns.
 
     Aggregations come in as expressions, which are mapped to `snowpark.Column`
     objects.
     """
-    input_df, columns = map_aggregate_helper(rel)
+    input_df_container, columns = map_aggregate_helper(rel)
+    input_df_actual = input_df_container.dataframe
+
     if len(columns.grouping_expressions()) == 0:
-        result = input_df.agg(*columns.aggregation_expressions())
+        result = input_df_actual.agg(*columns.aggregation_expressions())
     else:
-        result = input_df.group_by(*columns.grouping_expressions()).agg(
+        result = input_df_actual.group_by(*columns.grouping_expressions()).agg(
             *columns.aggregation_expressions()
         )
-    return with_column_map(
-        result,
-        columns.spark_names(),
-        columns.snowpark_names(),
-        columns.data_types(),
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=columns.spark_names(),
+        snowpark_column_names=columns.snowpark_names(),
+        snowpark_column_types=columns.data_types(),
         column_qualifiers=columns.get_qualifiers(),
-        parent_column_name_map=input_df._column_map,
+        parent_column_name_map=input_df_container.column_map,
     )
 
 
-def map_rollup_aggregate(rel: relation_proto.Relation) -> snowpark.DataFrame:
+def map_rollup_aggregate(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     Create a multidimensional rollup for the current DataFrame using the specified columns.
 
     Aggregations come in as expressions, which are mapped to `snowpark.Column`
     objects.
     """
-    input_df, columns = map_aggregate_helper(rel)
+    input_container, columns = map_aggregate_helper(rel)
+    input_df_actual = input_container.dataframe
+
     if len(columns.grouping_expressions()) == 0:
-        result = input_df.agg(*columns.aggregation_expressions())
+        result = input_df_actual.agg(*columns.aggregation_expressions())
     else:
-        result = input_df.rollup(*columns.grouping_expressions()).agg(
+        result = input_df_actual.rollup(*columns.grouping_expressions()).agg(
             *columns.aggregation_expressions()
         )
-    return with_column_map(
-        result,
-        columns.spark_names(),
-        columns.snowpark_names(),
-        columns.data_types(),
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=columns.spark_names(),
+        snowpark_column_names=columns.snowpark_names(),
+        snowpark_column_types=columns.data_types(),
         column_qualifiers=columns.get_qualifiers(),
-        parent_column_name_map=input_df._column_map,
+        parent_column_name_map=input_container.column_map,
     )
 
 
-def map_cube_aggregate(rel: relation_proto.Relation) -> snowpark.DataFrame:
+def map_cube_aggregate(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     Create a multidimensional cube for the current DataFrame using the specified columns.
 
     Aggregations come in as expressions, which are mapped to `snowpark.Column`
     objects.
     """
-    input_df, columns = map_aggregate_helper(rel)
+    input_container, columns = map_aggregate_helper(rel)
+    input_df_actual = input_container.dataframe
+
     if len(columns.grouping_expressions()) == 0:
-        result = input_df.agg(*columns.aggregation_expressions())
+        result = input_df_actual.agg(*columns.aggregation_expressions())
     else:
-        result = input_df.cube(*columns.grouping_expressions()).agg(
+        result = input_df_actual.cube(*columns.grouping_expressions()).agg(
             *columns.aggregation_expressions()
         )
-    return with_column_map(
-        result,
-        columns.spark_names(),
-        columns.snowpark_names(),
-        columns.data_types(),
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=columns.spark_names(),
+        snowpark_column_names=columns.snowpark_names(),
+        snowpark_column_types=columns.data_types(),
         column_qualifiers=columns.get_qualifiers(),
-        parent_column_name_map=input_df._column_map,
+        parent_column_name_map=input_container.column_map,
     )
 
 
-def map_pivot_aggregate(rel: relation_proto.Relation) -> snowpark.DataFrame:
+def map_pivot_aggregate(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     Pivots a column of the current DataFrame and performs the specified aggregation.
 
     There are 2 versions of the pivot function: one that requires the caller to specify the list of the distinct values
     to pivot on and one that does not.
     """
-    input_df, columns = map_aggregate_helper(rel, pivot=True, skip_alias=True)
+    input_container, columns = map_aggregate_helper(rel, pivot=True, skip_alias=True)
+    input_df_actual = input_container.dataframe
+
     pivot_column = map_single_column_expression(
-        rel.aggregate.pivot.col, input_df._column_map, ExpressionTyper(input_df)
+        rel.aggregate.pivot.col,
+        input_container.column_map,
+        ExpressionTyper(input_df_actual),
     )
     pivot_values = [
         get_literal_field_and_name(lit)[0] for lit in rel.aggregate.pivot.values
     ]
 
+    used_columns = {pivot_column[1].col._expression.name}
+    if get_is_evaluating_sql():
+        # When evaluating SQL spark doesn't trim columns from the result
+        used_columns = {"*"}
+    else:
+        for expression in rel.aggregate.aggregate_expressions:
+            matched_identifiers = re.findall(
+                r'unparsed_identifier: "(.*)"', expression.__str__()
+            )
+            for identifier in matched_identifiers:
+                mapped_col = input_container.column_map.spark_to_col.get(
+                    identifier, None
+                )
+                if mapped_col:
+                    used_columns.add(mapped_col[0].snowpark_name)
+
     if len(columns.grouping_expressions()) == 0:
-        result = input_df.pivot(
-            pivot_column[1].col, pivot_values if pivot_values else None
-        ).agg(*columns.aggregation_expressions())
+        result = (
+            input_df_actual.select(*used_columns)
+            .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
+            .agg(*columns.aggregation_expressions(unalias=True))
+        )
     else:
         result = (
-            input_df.group_by(*columns.grouping_expressions())
+            input_df_actual.group_by(*columns.grouping_expressions())
             .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-            .agg(*columns.aggregation_expressions())
+            .agg(*columns.aggregation_expressions(unalias=True))
         )
 
+    agg_name_list = [c.spark_name for c in columns.grouping_columns]
+
+    # Calculate number of pivot values for proper Spark-compatible indexing
+    total_pivot_columns = len(result.columns) - len(agg_name_list)
+    num_pivot_values = (
+        total_pivot_columns // len(columns.aggregation_columns)
+        if len(columns.aggregation_columns) > 0
+        else 1
+    )
+
+    def _get_agg_exp_alias_for_col(col_index: int) -> Optional[str]:
+        if col_index < len(agg_name_list) or len(columns.aggregation_columns) <= 1:
+            return None
+        else:
+            index = (col_index - len(agg_name_list)) // num_pivot_values
+            return columns.aggregation_columns[index].spark_name
+
     spark_columns = []
-    for col in [string_parser(s) for s in result.columns]:
+    for col in [
+        pivot_column_name(c, _get_agg_exp_alias_for_col(i))
+        for i, c in enumerate(result.columns)
+    ]:
         spark_col = (
-            input_df._column_map.get_spark_column_name_from_snowpark_column_name(
+            input_container.column_map.get_spark_column_name_from_snowpark_column_name(
                 col, allow_non_exists=True
             )
         )
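
Editor's note: the docstring above distinguishes the two forms of pivot, with and without an explicit value list. A minimal client-side PySpark sketch of the two call shapes this mapper receives (the connection URL, data, and column names here are invented for illustration):

# Hedged illustration: standard PySpark usage exercising both pivot code paths.
# Requires a Spark Connect endpoint; all values below are made up.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
df = spark.createDataFrame(
    [(2012, "dotNET", 10000), (2012, "Java", 20000), (2013, "dotNET", 48000)],
    ["year", "course", "earnings"],
)

# Explicit pivot values: rel.aggregate.pivot.values is populated upstream,
# so pivot_values above is non-empty.
df.groupBy("year").pivot("course", ["dotNET", "Java"]).agg(F.sum("earnings")).show()

# No explicit values: pivot_values is empty, the mapper passes None,
# and the distinct values are discovered from the data.
df.groupBy("year").pivot("course").agg(F.sum("earnings")).show()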
@@ -132,22 +196,57 @@ def map_pivot_aggregate(rel: relation_proto.Relation) -> snowpark.DataFrame:
         if spark_col is not None:
             spark_columns.append(spark_col)
         else:
-            spark_columns.append(col)
+            # Handle NULL column names to match Spark behavior (lowercase 'null')
+            if col == "NULL":
+                spark_columns.append(col.lower())
+            else:
+                spark_columns.append(col)
+
+    grouping_cols_count = len(agg_name_list)
+    pivot_cols = result.columns[grouping_cols_count:]
+    spark_pivot_cols = spark_columns[grouping_cols_count:]
+
+    num_agg_functions = len(columns.aggregation_columns)
+    num_pivot_values = len(pivot_cols) // num_agg_functions
+
+    reordered_snowpark_cols = []
+    reordered_spark_cols = []
+    column_indices = []  # 1-based indexing
+
+    for i in range(grouping_cols_count):
+        reordered_snowpark_cols.append(result.columns[i])
+        reordered_spark_cols.append(spark_columns[i])
+        column_indices.append(i + 1)
+
+    for pivot_idx in range(num_pivot_values):
+        for agg_idx in range(num_agg_functions):
+            current_pos = agg_idx * num_pivot_values + pivot_idx
+            if current_pos < len(pivot_cols):
+                reordered_snowpark_cols.append(pivot_cols[current_pos])
+                reordered_spark_cols.append(spark_pivot_cols[current_pos])
+                original_index = grouping_cols_count + current_pos
+                column_indices.append(original_index + 1)
+
+    reordered_result = result.select(
+        *[snowpark_fn.col(f"${idx}") for idx in column_indices]
+    )
 
-    agg_name_list = [c.spark_name for c in columns.grouping_columns]
-    return with_column_map(
-        result,
-        agg_name_list + spark_columns[len(agg_name_list) :],
-        result.columns,
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=reordered_result,
+        spark_column_names=reordered_spark_cols,
+        snowpark_column_names=[f"${idx}" for idx in column_indices],
         column_qualifiers=(
             columns.get_qualifiers()[: len(agg_name_list)]
-            + [[]] * (len(spark_columns) - len(agg_name_list))
+            + [[]] * (len(reordered_spark_cols) - len(agg_name_list))
         ),
-        parent_column_name_map=input_df._column_map,
+        parent_column_name_map=input_container.column_map,
+        snowpark_column_types=[
+            result.schema.fields[idx - 1].datatype for idx in column_indices
+        ],
     )
 
 
-def string_parser(s):
+def pivot_column_name(snowpark_cname, opt_alias: Optional[str] = None) -> Optional[str]:
     # For values that are used as pivoted columns, the input and output are in the following format (outermost double quotes are part of the input):
 
     # 1. "'Java'" -> Java
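
Editor's note: the reordering block above rearranges the pivoted columns from the aggregation-major layout the loop reads (each aggregation's block of pivot values together) into the pivot-value-major order Spark expects, then selects them by 1-based position ($1, $2, ...). A pure-Python sketch of just the index arithmetic, under assumed sizes (one grouping column, two aggregations, three pivot values; all names illustrative):

# Assumed sizes for the example.
grouping_cols_count = 1
num_agg_functions = 2
num_pivot_values = 3

# Original layout (1-based): [g1, p1_agg1, p2_agg1, p3_agg1, p1_agg2, p2_agg2, p3_agg2]
column_indices = list(range(1, grouping_cols_count + 1))  # grouping columns first
for pivot_idx in range(num_pivot_values):
    for agg_idx in range(num_agg_functions):
        current_pos = agg_idx * num_pivot_values + pivot_idx
        column_indices.append(grouping_cols_count + current_pos + 1)

# Target layout: [g1, p1_agg1, p1_agg2, p2_agg1, p2_agg2, p3_agg1, p3_agg2]
print(column_indices)  # [1, 2, 5, 3, 6, 4, 7]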
@@ -162,7 +261,7 @@ def string_parser(s):
 
     try:
         # handling values that are used as pivoted columns
-        match = re.match(r'^"\'(.*)\'"$', s)
+        match = re.match(r'^"\'(.*)\'"$', snowpark_cname)
         # extract the content between the outermost double quote followed by a single quote "'
         content = match.group(1)
         # convert the escaped double quote to the actual double quote
@@ -174,10 +273,10 @@ def string_parser(s):
         content = re.sub(r"'", "", content)
         # replace the placeholder with the single quote which we want to preserve
         result = content.replace(escape_single_quote_placeholder, "'")
-        return result
+        return f"{result}_{opt_alias}" if opt_alias else result
     except Exception:
         # fallback to the original logic, handling aliased column names
-        double_quote_list = re.findall(r'"(.*?)"', s)
+        double_quote_list = re.findall(r'"(.*?)"', snowpark_cname)
         spark_string = ""
         for entry in list(filter(None, double_quote_list)):
             if "'" in entry:
@@ -189,7 +288,7 @@ def string_parser(s):
                 spark_string += entry
             else:
                 spark_string += '"' + entry + '"'
-    return s if spark_string == "" else spark_string
+    return snowpark_cname if spark_string == "" else spark_string
 
 
 @dataclass(frozen=True)
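
Editor's note: the comments in pivot_column_name above document how pivoted column names are unwrapped (e.g. "'Java'" -> Java). A standalone sketch of just the quoted-literal case using the same regex; the helper name is invented here, and the escape-placeholder handling of the real function is omitted:

import re

def parse_pivoted_value(snowpark_cname: str) -> str:
    # Strip the outer double quotes and inner single quotes that surround
    # pivoted literal values, e.g. "'Java'" -> Java.
    match = re.match(r'^"\'(.*)\'"$', snowpark_cname)
    return match.group(1) if match else snowpark_cname

print(parse_pivoted_value('"\'Java\'"'))    # Java
print(parse_pivoted_value('"\'dotNET\'"'))  # dotNET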
@@ -210,8 +309,14 @@ class _Columns:
     def grouping_expressions(self) -> list[snowpark.Column]:
         return [col.expression for col in self.grouping_columns]
 
-    def aggregation_expressions(self) -> list[snowpark.Column]:
-        return [col.expression for col in self.aggregation_columns]
+    def aggregation_expressions(self, unalias: bool = False) -> list[snowpark.Column]:
+        def _unalias(col: snowpark.Column) -> snowpark.Column:
+            if unalias and hasattr(col, "_expr1") and isinstance(col._expr1, Alias):
+                return _unalias(Column(col._expr1.child))
+            else:
+                return col
+
+        return [_unalias(col.expression) for col in self.aggregation_columns]
 
     def expressions(self) -> list[snowpark.Column]:
         return self.grouping_expressions() + self.aggregation_expressions()
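
Editor's note: the new unalias flag recursively peels Alias wrappers off an aggregation Column (via the private _expr1 attribute) before it is handed to pivot. A toy sketch of the same recursion pattern with stand-in classes; FakeAlias and FakeColumn are invented here to keep the example self-contained and are not Snowpark types:

from dataclasses import dataclass

@dataclass
class FakeAlias:
    # Stand-in for an alias node wrapping a child expression.
    child: object

@dataclass
class FakeColumn:
    # Stand-in for a column whose underlying expression is _expr1.
    _expr1: object

def unalias(col):
    # Peel alias wrappers recursively, mirroring the _unalias helper above.
    if isinstance(col._expr1, FakeAlias):
        return unalias(FakeColumn(col._expr1.child))
    return col

aliased_twice = FakeColumn(FakeAlias(FakeAlias("sum(earnings)")))
print(unalias(aliased_twice)._expr1)  # sum(earnings)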
@@ -246,7 +351,8 @@ class _Columns:
 def map_aggregate_helper(
     rel: relation_proto.Relation, pivot: bool = False, skip_alias: bool = False
 ):
-    input_df = map_relation(rel.aggregate.input)
+    input_container = map_relation(rel.aggregate.input)
+    input_df = input_container.dataframe
     grouping_expressions = rel.aggregate.grouping_expressions
     expressions = rel.aggregate.aggregate_expressions
     groupings: list[_ColumnMetadata] = []
@@ -258,7 +364,7 @@ def map_aggregate_helper(
     with temporary_pivot_expression(pivot):
         for exp in grouping_expressions:
             new_name, snowpark_column = map_single_column_expression(
-                exp, input_df._column_map, typer
+                exp, input_container.column_map, typer
             )
             alias = make_column_names_snowpark_compatible(
                 [new_name], rel.common.plan_id, len(groupings)
@@ -275,9 +381,12 @@ def map_aggregate_helper(
                 )
             )
 
+        grouping_cols = [g.spark_name for g in groupings]
+        set_current_grouping_columns(grouping_cols)
+
         for exp in expressions:
             new_name, snowpark_column = map_single_column_expression(
-                exp, input_df._column_map, typer
+                exp, input_container.column_map, typer
             )
             alias = make_column_names_snowpark_compatible(
                 [new_name], rel.common.plan_id, len(groupings) + len(aggregations)
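
Editor's note: map_aggregate_helper now records the Spark-side grouping column names via set_current_grouping_columns (imported from utils.context) before mapping the aggregate expressions, presumably so downstream expression mapping can consult them. That helper's implementation is not part of this diff; a common, purely hypothetical way to build such request-scoped get/set helpers is contextvars:

# Hypothetical sketch only, not the package's actual utils.context code.
from contextvars import ContextVar

_current_grouping_columns: ContextVar[list[str]] = ContextVar(
    "current_grouping_columns", default=[]
)

def set_current_grouping_columns(cols: list[str]) -> None:
    _current_grouping_columns.set(cols)

def get_current_grouping_columns() -> list[str]:
    return _current_grouping_columns.get()

set_current_grouping_columns(["year", "course"])
print(get_current_grouping_columns())  # ['year', 'course']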
@@ -313,7 +422,7 @@ def map_aggregate_helper(
         )
 
     return (
-        input_df,
+        input_container,
         _Columns(
             grouping_columns=groupings,
             aggregation_columns=aggregations,
snowflake/snowpark_connect/relation/map_catalog.py

@@ -7,7 +7,7 @@ import re
 import pandas
 import pyspark.sql.connect.proto.catalog_pb2 as catalog_proto
 
-from snowflake import snowpark
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.catalogs import CATALOGS
 from snowflake.snowpark_connect.relation.catalogs.utils import (
     CURRENT_CATALOG_NAME,
@@ -22,7 +22,7 @@ from snowflake.snowpark_connect.utils.telemetry import (
 
 def map_catalog(
     rel: catalog_proto.Catalog,
-) -> pandas.DataFrame | snowpark.DataFrame:
+) -> DataFrameContainer | pandas.DataFrame:
     match rel.WhichOneof("cat_type"):
         # Database related APIs
         case "current_database":