snowpark-connect 0.20.2__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (67)
  1. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +3 -2
  2. snowflake/snowpark_connect/column_name_handler.py +6 -65
  3. snowflake/snowpark_connect/config.py +28 -14
  4. snowflake/snowpark_connect/dataframe_container.py +242 -0
  5. snowflake/snowpark_connect/execute_plan/map_execution_command.py +13 -23
  6. snowflake/snowpark_connect/execute_plan/map_execution_root.py +9 -5
  7. snowflake/snowpark_connect/expression/map_extension.py +2 -1
  8. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +8 -7
  9. snowflake/snowpark_connect/expression/map_unresolved_function.py +279 -43
  10. snowflake/snowpark_connect/expression/map_unresolved_star.py +8 -8
  11. snowflake/snowpark_connect/expression/map_update_fields.py +1 -1
  12. snowflake/snowpark_connect/expression/typer.py +6 -6
  13. snowflake/snowpark_connect/proto/control_pb2.py +17 -16
  14. snowflake/snowpark_connect/proto/control_pb2.pyi +17 -17
  15. snowflake/snowpark_connect/proto/control_pb2_grpc.py +12 -63
  16. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +15 -14
  17. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +19 -14
  18. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +27 -26
  19. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +74 -68
  20. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +5 -5
  21. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +25 -17
  22. snowflake/snowpark_connect/relation/map_aggregate.py +72 -47
  23. snowflake/snowpark_connect/relation/map_catalog.py +2 -2
  24. snowflake/snowpark_connect/relation/map_column_ops.py +207 -144
  25. snowflake/snowpark_connect/relation/map_crosstab.py +25 -6
  26. snowflake/snowpark_connect/relation/map_extension.py +81 -56
  27. snowflake/snowpark_connect/relation/map_join.py +72 -63
  28. snowflake/snowpark_connect/relation/map_local_relation.py +35 -20
  29. snowflake/snowpark_connect/relation/map_map_partitions.py +21 -16
  30. snowflake/snowpark_connect/relation/map_relation.py +22 -16
  31. snowflake/snowpark_connect/relation/map_row_ops.py +232 -146
  32. snowflake/snowpark_connect/relation/map_sample_by.py +15 -8
  33. snowflake/snowpark_connect/relation/map_show_string.py +42 -5
  34. snowflake/snowpark_connect/relation/map_sql.py +155 -78
  35. snowflake/snowpark_connect/relation/map_stats.py +88 -39
  36. snowflake/snowpark_connect/relation/map_subquery_alias.py +13 -14
  37. snowflake/snowpark_connect/relation/map_udtf.py +6 -9
  38. snowflake/snowpark_connect/relation/read/map_read.py +8 -3
  39. snowflake/snowpark_connect/relation/read/map_read_csv.py +7 -7
  40. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +7 -7
  41. snowflake/snowpark_connect/relation/read/map_read_json.py +7 -7
  42. snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -7
  43. snowflake/snowpark_connect/relation/read/map_read_socket.py +7 -3
  44. snowflake/snowpark_connect/relation/read/map_read_table.py +25 -16
  45. snowflake/snowpark_connect/relation/read/map_read_text.py +7 -7
  46. snowflake/snowpark_connect/relation/utils.py +11 -5
  47. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +15 -12
  48. snowflake/snowpark_connect/relation/write/map_write.py +199 -40
  49. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +3 -2
  50. snowflake/snowpark_connect/server.py +34 -4
  51. snowflake/snowpark_connect/type_mapping.py +2 -23
  52. snowflake/snowpark_connect/utils/cache.py +27 -22
  53. snowflake/snowpark_connect/utils/context.py +33 -17
  54. snowflake/snowpark_connect/utils/{attribute_handling.py → identifiers.py} +47 -0
  55. snowflake/snowpark_connect/utils/session.py +41 -34
  56. snowflake/snowpark_connect/utils/telemetry.py +1 -2
  57. snowflake/snowpark_connect/version.py +1 -1
  58. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/METADATA +5 -3
  59. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/RECORD +67 -64
  60. snowpark_connect-0.21.0.dist-info/licenses/LICENSE-binary +568 -0
  61. snowpark_connect-0.21.0.dist-info/licenses/NOTICE-binary +1533 -0
  62. {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-connect +0 -0
  63. {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-session +0 -0
  64. {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-submit +0 -0
  65. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/WHEEL +0 -0
  66. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
  67. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/top_level.txt +0 -0
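
The most visible structural change in this release is the new snowflake/snowpark_connect/dataframe_container.py module (+242 lines): the relation mappers in the diff below now return a DataFrameContainer instead of a bare snowpark.DataFrame carrying monkey-patched _column_map / _table_name / schema-getter attributes. As orientation only, the following is a minimal, hypothetical sketch of the container's shape, inferred solely from how it is constructed and read in the hunks below; dataframe_container.py itself is not part of this excerpt, so anything beyond the attribute names used below is an assumption.

# Hypothetical sketch of DataFrameContainer, inferred from its usage in the diff below.
# The real class lives in snowflake/snowpark_connect/dataframe_container.py (not shown here).
from dataclasses import dataclass
from typing import Callable, Optional

from snowflake import snowpark
from snowflake.snowpark.types import StructType
from snowflake.snowpark_connect.column_name_handler import ColumnNameMap


@dataclass
class DataFrameContainer:
    # Field order follows the positional constructor calls in the hunks below.
    dataframe: snowpark.DataFrame
    column_map: ColumnNameMap
    table_name: Optional[str] = None
    alias: Optional[str] = None
    cached_schema_getter: Optional[Callable[[], StructType]] = None

    @classmethod
    def create_with_column_mapping(
        cls,
        dataframe: snowpark.DataFrame,
        spark_column_names: list,
        snowpark_column_names: list,
        snowpark_column_types: Optional[list] = None,
        column_metadata: Optional[dict] = None,
        parent_column_name_map: Optional[ColumnNameMap] = None,
    ) -> "DataFrameContainer":
        # Assumed behavior: build a ColumnNameMap from the Spark/Snowpark name pairs
        # and wrap the dataframe; the actual implementation is in the new module.
        ...
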
@@ -20,12 +20,9 @@ from snowflake.snowpark.types import (
     NullType,
     ShortType,
 )
-from snowflake.snowpark_connect.column_name_handler import (
-    schema_getter,
-    set_schema_getter,
-    with_column_map,
-)
+from snowflake.snowpark_connect.column_name_handler import ColumnNameMap, schema_getter
 from snowflake.snowpark_connect.config import global_config
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
 from snowflake.snowpark_connect.expression.map_expression import (
     map_single_column_expression,
@@ -39,13 +36,14 @@ from snowflake.snowpark_connect.utils.telemetry import (
 
 def map_deduplicate(
     rel: relation_proto.Relation,
-) -> snowpark.DataFrame:
+) -> DataFrameContainer:
     """
     Deduplicate a DataFrame based on a Relation's deduplicate.
 
     The deduplicate is a list of columns that is applied to the DataFrame.
     """
-    input_df: snowpark.DataFrame = map_relation(rel.deduplicate.input)
+    input_container = map_relation(rel.deduplicate.input)
+    input_df = input_container.dataframe
 
     if (
         rel.deduplicate.HasField("within_watermark")
@@ -62,23 +60,29 @@ def map_deduplicate(
         result: snowpark.DataFrame = input_df.drop_duplicates()
     else:
         result: snowpark.DataFrame = input_df.drop_duplicates(
-            *input_df._column_map.get_snowpark_column_names_from_spark_column_names(
+            *input_container.column_map.get_snowpark_column_names_from_spark_column_names(
                 list(rel.deduplicate.column_names)
             )
         )
-    result._column_map = input_df._column_map
-    result._table_name = input_df._table_name
-    set_schema_getter(result, lambda: input_df.schema)
-    return result
 
+    return DataFrameContainer(
+        result,
+        input_container.column_map,
+        input_container.table_name,
+        input_container.alias,
+        cached_schema_getter=lambda: input_df.schema,
+    )
 
-def map_dropna(rel: relation_proto.Relation) -> snowpark.DataFrame:
+
+def map_dropna(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     Drop NA values from the input DataFrame.
-
-
     """
-    input_df: snowpark.DataFrame = map_relation(rel.drop_na.input)
+    input_container = map_relation(rel.drop_na.input)
+    input_df = input_container.dataframe
+
     if rel.drop_na.HasField("min_non_nulls"):
         thresh = rel.drop_na.min_non_nulls
         how = "all"
@@ -89,7 +93,9 @@ def map_dropna(rel: relation_proto.Relation) -> snowpark.DataFrame:
         columns: list[str] = [
             # Use the mapping to get the Snowpark internal column name
             # TODO: Verify the behavior of duplicate column names with dropna
-            input_df._column_map.get_snowpark_column_name_from_spark_column_name(c)
+            input_container.column_map.get_snowpark_column_name_from_spark_column_name(
+                c
+            )
             for c in rel.drop_na.cols
         ]
         result: snowpark.DataFrame = input_df.dropna(
@@ -97,22 +103,32 @@ def map_dropna(rel: relation_proto.Relation) -> snowpark.DataFrame:
         )
     else:
         result: snowpark.DataFrame = input_df.dropna(how=how, thresh=thresh)
-    result._column_map = input_df._column_map
-    result._table_name = input_df._table_name
-    set_schema_getter(result, lambda: input_df.schema)
-    return result
 
+    return DataFrameContainer(
+        result,
+        input_container.column_map,
+        input_container.table_name,
+        input_container.alias,
+        cached_schema_getter=lambda: input_df.schema,
+    )
 
-def map_fillna(rel: relation_proto.Relation) -> snowpark.DataFrame:
+
+def map_fillna(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     Fill NA values in the DataFrame.
 
     The `fill_value` is a scalar value that will be used to replace NaN values.
     """
-    input_df: snowpark.DataFrame = map_relation(rel.fill_na.input)
+    input_container = map_relation(rel.fill_na.input)
+    input_df = input_container.dataframe
+
     if len(rel.fill_na.cols) > 0:
         columns: list[str] = [
-            input_df._column_map.get_snowpark_column_name_from_spark_column_name(c)
+            input_container.column_map.get_snowpark_column_name_from_spark_column_name(
+                c
+            )
             for c in rel.fill_na.cols
         ]
         values = [get_literal_field_and_name(v)[0] for v in rel.fill_na.values]
@@ -142,26 +158,50 @@ def map_fillna(rel: relation_proto.Relation) -> snowpark.DataFrame:
             for field in input_df.schema.fields
         }
         result = input_df.fillna(fill_value, include_decimal=True)
-    result._column_map = input_df._column_map
-    result._table_name = input_df._table_name
-    set_schema_getter(result, lambda: input_df.schema)
-    return result
 
+    return DataFrameContainer(
+        result,
+        input_container.column_map,
+        input_container.table_name,
+        input_container.alias,
+        cached_schema_getter=lambda: input_df.schema,
+    )
 
-def map_union(rel: relation_proto.Relation) -> snowpark.DataFrame:
+
+def map_union(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     Union two DataFrames together.
 
     The two DataFrames must have the same schema.
     """
-    left_df: snowpark.DataFrame = map_relation(rel.set_op.left_input)
-    right_df: snowpark.DataFrame = map_relation(rel.set_op.right_input)
+    left_result = map_relation(rel.set_op.left_input)
+    right_result = map_relation(rel.set_op.right_input)
+    left_df = left_result.dataframe
+    right_df = right_result.dataframe
+    allow_missing_columns = bool(rel.set_op.allow_missing_columns)
 
     # workaround for unstructured type vs structured type
-    left_dtypes = [field.datatype for field in left_df.schema.fields]
-    right_dtypes = [field.datatype for field in right_df.schema.fields]
+    # Use cached schema if available to avoid triggering extra queries
+    if (
+        hasattr(left_result, "cached_schema_getter")
+        and left_result.cached_schema_getter is not None
+    ):
+        left_schema = left_result.cached_schema_getter()
+    else:
+        left_schema = left_df.schema
 
-    allow_missing_columns = bool(rel.set_op.allow_missing_columns)
+    if (
+        hasattr(right_result, "cached_schema_getter")
+        and right_result.cached_schema_getter is not None
+    ):
+        right_schema = right_result.cached_schema_getter()
+    else:
+        right_schema = right_df.schema
+
+    left_dtypes = [field.datatype for field in left_schema.fields]
+    right_dtypes = [field.datatype for field in right_schema.fields]
 
     spark_sql_ansi_enabled = global_config.spark_sql_ansi_enabled
     if left_dtypes != right_dtypes and not rel.set_op.by_name:
@@ -212,13 +252,22 @@ def map_union(rel: relation_proto.Relation) -> snowpark.DataFrame:
             target_right_dtypes.append(right_type)
 
     def cast_columns(
-        df: snowpark.DataFrame,
+        df_container: DataFrameContainer,
         df_dtypes: list[snowpark.types.DataType],
         target_dtypes: list[snowpark.types.DataType],
+        column_map: ColumnNameMap,
     ):
+        df: snowpark.DataFrame = df_container.dataframe
        if df_dtypes == target_dtypes:
-            return df
-        df_schema = df.schema  # Get current schema
+            return df_container
+        # Use cached schema if available to avoid triggering extra queries
+        if (
+            hasattr(df_container, "cached_schema_getter")
+            and df_container.cached_schema_getter is not None
+        ):
+            df_schema = df_container.cached_schema_getter()
+        else:
+            df_schema = df.schema  # Get current schema
         new_columns = []
 
         for i, field in enumerate(df_schema.fields):
@@ -232,38 +281,46 @@ def map_union(rel: relation_proto.Relation) -> snowpark.DataFrame:
                 new_columns.append(df[col_name])
 
         new_df = df.select(new_columns)
-        return with_column_map(
-            new_df,
-            df._column_map.get_spark_columns(),
-            df._column_map.get_snowpark_columns(),
-            target_dtypes,
-            df._column_map.column_metadata,
-            parent_column_name_map=df._column_map,
+        return DataFrameContainer.create_with_column_mapping(
+            dataframe=new_df,
+            spark_column_names=column_map.get_spark_columns(),
+            snowpark_column_names=column_map.get_snowpark_columns(),
+            snowpark_column_types=target_dtypes,
+            column_metadata=column_map.column_metadata,
+            parent_column_name_map=column_map,
         )
 
-    left_df = cast_columns(left_df, left_dtypes, target_left_dtypes)
-    right_df = cast_columns(right_df, right_dtypes, target_right_dtypes)
+    left_result = cast_columns(
+        left_result,
+        left_dtypes,
+        target_left_dtypes,
+        left_result.column_map,
+    )
+    right_result = cast_columns(
+        right_result,
+        right_dtypes,
+        target_right_dtypes,
+        right_result.column_map,
+    )
+    left_df = left_result.dataframe
+    right_df = right_result.dataframe
 
     # Save the column names so that we can restore them after the union.
-    left_df_columns = left_df.columns
-
-    result: snowpark.DataFrame = None
+    left_df_columns = left_result.dataframe.columns
 
     if rel.set_op.by_name:
         # To use unionByName, we need to have the same column names.
         # We rename the columns back to their originals using the map
-        left_column_map = left_df._column_map
-        left_table_name = left_df._table_name
+        left_column_map = left_result.column_map
+        left_table_name = left_result.table_name
         left_schema_getter = schema_getter(left_df)
-        right_column_map = right_df._column_map
-
+        right_column_map = right_result.column_map
         columns_to_restore: dict[str, tuple[str, str]] = {}
 
         for column in right_df.columns:
             spark_name = (
                 right_column_map.get_spark_column_name_from_snowpark_column_name(column)
             )
-
             right_df = right_df.withColumnRenamed(column, spark_name)
             columns_to_restore[spark_name.upper()] = (spark_name, column)
 
@@ -271,11 +328,10 @@ def map_union(rel: relation_proto.Relation) -> snowpark.DataFrame:
             spark_name = (
                 left_column_map.get_spark_column_name_from_snowpark_column_name(column)
             )
-
             left_df = left_df.withColumnRenamed(column, spark_name)
             columns_to_restore[spark_name.upper()] = (spark_name, column)
 
-        result = left_df.union_all_by_name(
+        result = left_df.unionAllByName(
             right_df, allow_missing_columns=allow_missing_columns
         )
 
@@ -296,41 +352,42 @@ def map_union(rel: relation_proto.Relation) -> snowpark.DataFrame:
            right_df_col_metadata = right_column_map.column_metadata or {}
            merged_column_metadata = left_df_col_metadata | right_df_col_metadata
 
-            return with_column_map(
+            return DataFrameContainer.create_with_column_mapping(
                result,
-                spark_columns,
-                snowpark_columns,
+                spark_column_names=spark_columns,
+                snowpark_column_names=snowpark_columns,
                column_metadata=merged_column_metadata,
            )
 
        for i in range(len(left_df_columns)):
            result = result.withColumnRenamed(result.columns[i], left_df_columns[i])
 
-        result._column_map = left_column_map
-        result._table_name = left_table_name
-        set_schema_getter(result, left_schema_getter)
+        return DataFrameContainer(
+            result,
+            column_map=left_column_map,
+            table_name=left_table_name,
+            cached_schema_getter=left_schema_getter,
+        )
    elif rel.set_op.is_all:
        result = left_df.unionAll(right_df)
-        result._column_map = left_df._column_map
-        result._table_name = left_df._table_name
-        set_schema_getter(result, lambda: left_df.schema)
+        return DataFrameContainer(
+            result,
+            column_map=left_result.column_map,
+            cached_schema_getter=lambda: left_df.schema,
+        )
    else:
        result = left_df.union(right_df)
-        result._column_map = left_df._column_map
-        result._table_name = left_df._table_name
-        set_schema_getter(result, lambda: left_df.schema)
-
-        # union operation does not preserve column qualifiers
-        return with_column_map(
-            result,
-            result._column_map.get_spark_columns(),
-            result._column_map.get_snowpark_columns(),
-            column_metadata=result._column_map.column_metadata,
-            parent_column_name_map=result._column_map,
-        )
+        # union operation does not preserve column qualifiers
+        return DataFrameContainer(
+            result,
+            column_map=left_result.column_map,
+            cached_schema_getter=lambda: left_df.schema,
+        )
 
 
-def map_intersect(rel: relation_proto.Relation) -> snowpark.DataFrame:
+def map_intersect(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     Return a new DataFrame containing rows in both DataFrames:
 
@@ -363,36 +420,36 @@ def map_union(rel: relation_proto.Relation) -> snowpark.DataFrame:
     | b| 3|
     +---+---+
     """
-    left_df: snowpark.DataFrame = map_relation(rel.set_op.left_input)
-    right_df: snowpark.DataFrame = map_relation(rel.set_op.right_input)
+    left_result = map_relation(rel.set_op.left_input)
+    right_result = map_relation(rel.set_op.right_input)
+    left_df = left_result.dataframe
+    right_df = right_result.dataframe
 
     if rel.set_op.is_all:
         left_df_with_row_number = utils.get_df_with_partition_row_number(
-            left_df, rel.set_op.left_input.common.plan_id, "left_row_number"
+            left_result, rel.set_op.left_input.common.plan_id, "left_row_number"
         )
         right_df_with_row_number = utils.get_df_with_partition_row_number(
-            right_df, rel.set_op.right_input.common.plan_id, "right_row_number"
+            right_result, rel.set_op.right_input.common.plan_id, "right_row_number"
        )
 
        result: snowpark.DataFrame = left_df_with_row_number.intersect(
            right_df_with_row_number
-        ).select(*left_df._column_map.get_snowpark_columns())
+        ).select(*left_result.column_map.get_snowpark_columns())
    else:
        result: snowpark.DataFrame = left_df.intersect(right_df)
 
-    # the result df keeps the column map of the original left_df
-    result = with_column_map(
-        result,
-        left_df._column_map.get_spark_columns(),
-        left_df._column_map.get_snowpark_columns(),
-        column_metadata=left_df._column_map.column_metadata,
+    return DataFrameContainer(
+        dataframe=result,
+        column_map=left_result.column_map,
+        table_name=left_result.table_name,
+        cached_schema_getter=lambda: left_df.schema,
     )
-    result._table_name = left_df._table_name
-    set_schema_getter(result, lambda: left_df.schema)
-    return result
 
 
-def map_except(rel: relation_proto.Relation) -> snowpark.DataFrame:
+def map_except(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     Return a new DataFrame containing rows in the left DataFrame but not in the right DataFrame.
 
@@ -426,8 +483,10 @@ def map_except(rel: relation_proto.Relation) -> snowpark.DataFrame:
     | c| 4|
     +---+---+
     """
-    left_df: snowpark.DataFrame = map_relation(rel.set_op.left_input)
-    right_df: snowpark.DataFrame = map_relation(rel.set_op.right_input)
+    left_result = map_relation(rel.set_op.left_input)
+    right_result = map_relation(rel.set_op.right_input)
+    left_df = left_result.dataframe
+    right_df = right_result.dataframe
 
     if rel.set_op.is_all:
         # Snowflake except removes all duplicated rows. In order to handle the case,
@@ -453,91 +512,107 @@ def map_except(rel: relation_proto.Relation) -> snowpark.DataFrame:
         # +---+---+------------+
         # at the end we will do a select to exclude the row number column
         left_df_with_row_number = utils.get_df_with_partition_row_number(
-            left_df, rel.set_op.left_input.common.plan_id, "left_row_number"
+            left_result, rel.set_op.left_input.common.plan_id, "left_row_number"
         )
         right_df_with_row_number = utils.get_df_with_partition_row_number(
-            right_df, rel.set_op.right_input.common.plan_id, "right_row_number"
+            right_result, rel.set_op.right_input.common.plan_id, "right_row_number"
         )
 
         # Perform except use left_df_with_row_number and right_df_with_row_number,
         # and drop the row number column after except.
         result_df = left_df_with_row_number.except_(right_df_with_row_number).select(
-            *left_df._column_map.get_snowpark_columns()
+            *left_result.column_map.get_snowpark_columns()
         )
     else:
         result_df = left_df.except_(right_df)
 
     # the result df keeps the column map of the original left_df
     # union operation does not preserve column qualifiers
-    result_df = with_column_map(
-        result_df,
-        left_df._column_map.get_spark_columns(),
-        left_df._column_map.get_snowpark_columns(),
-        column_metadata=left_df._column_map.column_metadata,
+    return DataFrameContainer(
+        dataframe=result_df,
+        column_map=left_result.column_map,
+        table_name=left_result.table_name,
+        cached_schema_getter=lambda: left_df.schema,
    )
-    result_df._table_name = left_df._table_name
-    set_schema_getter(result_df, lambda: left_df.schema)
-    return result_df
 
 
 def map_filter(
     rel: relation_proto.Relation,
-) -> snowpark.DataFrame:
+) -> DataFrameContainer:
     """
     Filter a DataFrame based on a Relation's filter.
 
     The filter is a SQL expression that is applied to the DataFrame.
     """
-    input_df = map_relation(rel.filter.input)
+    input_container = map_relation(rel.filter.input)
+    input_df = input_container.dataframe
+
     typer = ExpressionTyper(input_df)
     _, condition = map_single_column_expression(
-        rel.filter.condition, input_df._column_map, typer
+        rel.filter.condition, input_container.column_map, typer
     )
     result = input_df.filter(condition.col)
-    result._column_map = input_df._column_map
-    result._alias = input_df._alias
-    result._table_name = input_df._table_name
-    set_schema_getter(result, lambda: input_df.schema)
-    return result
+
+    return DataFrameContainer(
+        result,
+        input_container.column_map,
+        input_container.table_name,
+        input_container.alias,
+        cached_schema_getter=lambda: input_df.schema,
+    )
 
 
 def map_limit(
     rel: relation_proto.Relation,
-) -> snowpark.DataFrame:
+) -> DataFrameContainer:
     """
     Limit a DataFrame based on a Relation's limit.
 
     The limit is an integer that is applied to the DataFrame.
     """
-    input_df: snowpark.DataFrame = map_relation(rel.limit.input)
+    input_container = map_relation(rel.limit.input)
+    input_df = input_container.dataframe
+
     result: snowpark.DataFrame = input_df.limit(rel.limit.limit)
-    result._column_map = input_df._column_map
-    result._table_name = input_df._table_name
-    set_schema_getter(result, lambda: input_df.schema)
-    return result
+
+    return DataFrameContainer(
+        result,
+        column_map=input_container.column_map,
+        table_name=input_container.table_name,
+        alias=input_container.alias,
+        cached_schema_getter=lambda: input_df.schema,
+    )
 
 
 def map_offset(
     rel: relation_proto.Relation,
-) -> snowpark.DataFrame:
+) -> DataFrameContainer:
     """
     Offset a DataFrame based on a Relation's offset.
 
     The offset is an integer that is applied to the DataFrame.
     """
-    input_df: snowpark.DataFrame = map_relation(rel.offset.input)
+    input_container = map_relation(rel.offset.input)
+    input_df = input_container.dataframe
+
     # TODO: This is a terrible way to have to do this, but Snowpark does not
     # support offset without limit.
     result: snowpark.DataFrame = input_df.limit(
         input_df.count(), offset=rel.offset.offset
     )
-    result._column_map = input_df._column_map
-    result._table_name = input_df._table_name
-    set_schema_getter(result, lambda: input_df.schema)
-    return result
+
+    return DataFrameContainer(
+        result,
+        column_map=input_container.column_map,
+        table_name=input_container.table_name,
+        alias=input_container.alias,
+        cached_schema_getter=lambda: input_df.schema,
+    )
 
 
-def map_replace(rel: relation_proto.Relation) -> snowpark.DataFrame:
+def map_replace(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     Replace values in the DataFrame.
 
@@ -545,10 +620,11 @@ def map_replace(rel: relation_proto.Relation) -> snowpark.DataFrame:
     values to replace. The values in the dictionary are the values to replace
     and the keys are the values to replace them with.
     """
-    input_df: snowpark.DataFrame = map_relation(rel.replace.input)
+    result = map_relation(rel.replace.input)
+    input_df = result.dataframe
     ordered_columns = input_df.columns
-    column_map = input_df._column_map
-    table_name = input_df._table_name
+    column_map = result.column_map
+    table_name = result.table_name
     # note that seems like spark connect always send number values as double in rel.replace.replacements.
     to_replace = [
         get_literal_field_and_name(i.old_value)[0] for i in rel.replace.replacements
@@ -647,7 +723,7 @@ def map_replace(rel: relation_proto.Relation) -> snowpark.DataFrame:
 
     if len(rel.replace.cols) > 0:
         columns: list[str] = [
-            input_df._column_map.get_snowpark_column_name_from_spark_column_name(c)
+            column_map.get_snowpark_column_name_from_spark_column_name(c)
             for c in rel.replace.cols
         ]
         for c in columns:
@@ -657,18 +733,19 @@ def map_replace(rel: relation_proto.Relation) -> snowpark.DataFrame:
         input_df = input_df.with_column(c, replace_case_expr(c, to_replace, values))
 
     result = input_df.select(*[col(c) for c in ordered_columns])
-    result._column_map = column_map
-    result._table_name = table_name
-    return result
+
+    return DataFrameContainer(result, column_map=column_map, table_name=table_name)
 
 
 def map_sample(
     rel: relation_proto.Relation,
-) -> snowpark.DataFrame:
+) -> DataFrameContainer:
     """
     Sample a DataFrame based on a Relation's sample.
     """
-    input_df: snowpark.DataFrame = map_relation(rel.sample.input)
+    input_container = map_relation(rel.sample.input)
+    input_df = input_container.dataframe
+
     frac = rel.sample.upper_bound - rel.sample.lower_bound
     if frac < 0 or frac > 1:
         raise IllegalArgumentException("Sample fraction must be between 0 and 1")
@@ -691,26 +768,35 @@ def map_sample(
         )
     else:
         result: snowpark.DataFrame = input_df.sample(frac=frac)
-    result._column_map = input_df._column_map
-    result._table_name = input_df._table_name
-    set_schema_getter(result, lambda: input_df.schema)
-    return result
+    return DataFrameContainer(
+        result,
+        column_map=input_container.column_map,
+        table_name=input_container.table_name,
+        alias=input_container.alias,
+        cached_schema_getter=lambda: input_df.schema,
+    )
 
 
 def map_tail(
     rel: relation_proto.Relation,
-) -> snowpark.DataFrame:
+) -> DataFrameContainer:
     """
     Tail a DataFrame based on a Relation's tail.
 
     The tail is an integer that is applied to the DataFrame.
     """
-    input_df: snowpark.DataFrame = map_relation(rel.tail.input)
+    input_container = map_relation(rel.tail.input)
+    input_df = input_container.dataframe
+
     num_rows = input_df.count()
     result: snowpark.DataFrame = input_df.limit(
         num_rows, offset=max(0, num_rows - rel.tail.limit)
     )
-    result._column_map = input_df._column_map
-    result._table_name = input_df._table_name
-    set_schema_getter(result, lambda: input_df.schema)
-    return result
+
+    return DataFrameContainer(
+        result,
+        column_map=input_container.column_map,
+        table_name=input_container.table_name,
+        alias=input_container.alias,
+        cached_schema_getter=lambda: input_df.schema,
+    )
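
A pattern worth noting in the hunks above: wherever a mapper builds a new container it forwards a cached_schema_getter (usually lambda: input_df.schema), and consumers such as map_union prefer that getter over dataframe.schema so that, per the diff's own comments, reading the schema does not trigger extra queries. A condensed, hypothetical restatement of that consumer-side fallback, using the same attribute names as the diff (the helper name resolve_schema is invented for illustration):

def resolve_schema(container):
    # Prefer the schema getter attached by the upstream mapper; otherwise ask
    # Snowpark for the schema, which the diff's comments note can trigger extra queries.
    getter = getattr(container, "cached_schema_getter", None)
    return getter() if getter is not None else container.dataframe.schema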