snowpark-connect 0.20.2__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of snowpark-connect might be problematic; see the package's advisory details for more information.

Files changed (67)
  1. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +3 -2
  2. snowflake/snowpark_connect/column_name_handler.py +6 -65
  3. snowflake/snowpark_connect/config.py +28 -14
  4. snowflake/snowpark_connect/dataframe_container.py +242 -0
  5. snowflake/snowpark_connect/execute_plan/map_execution_command.py +13 -23
  6. snowflake/snowpark_connect/execute_plan/map_execution_root.py +9 -5
  7. snowflake/snowpark_connect/expression/map_extension.py +2 -1
  8. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +8 -7
  9. snowflake/snowpark_connect/expression/map_unresolved_function.py +279 -43
  10. snowflake/snowpark_connect/expression/map_unresolved_star.py +8 -8
  11. snowflake/snowpark_connect/expression/map_update_fields.py +1 -1
  12. snowflake/snowpark_connect/expression/typer.py +6 -6
  13. snowflake/snowpark_connect/proto/control_pb2.py +17 -16
  14. snowflake/snowpark_connect/proto/control_pb2.pyi +17 -17
  15. snowflake/snowpark_connect/proto/control_pb2_grpc.py +12 -63
  16. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +15 -14
  17. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +19 -14
  18. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +27 -26
  19. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +74 -68
  20. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +5 -5
  21. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +25 -17
  22. snowflake/snowpark_connect/relation/map_aggregate.py +72 -47
  23. snowflake/snowpark_connect/relation/map_catalog.py +2 -2
  24. snowflake/snowpark_connect/relation/map_column_ops.py +207 -144
  25. snowflake/snowpark_connect/relation/map_crosstab.py +25 -6
  26. snowflake/snowpark_connect/relation/map_extension.py +81 -56
  27. snowflake/snowpark_connect/relation/map_join.py +72 -63
  28. snowflake/snowpark_connect/relation/map_local_relation.py +35 -20
  29. snowflake/snowpark_connect/relation/map_map_partitions.py +21 -16
  30. snowflake/snowpark_connect/relation/map_relation.py +22 -16
  31. snowflake/snowpark_connect/relation/map_row_ops.py +232 -146
  32. snowflake/snowpark_connect/relation/map_sample_by.py +15 -8
  33. snowflake/snowpark_connect/relation/map_show_string.py +42 -5
  34. snowflake/snowpark_connect/relation/map_sql.py +155 -78
  35. snowflake/snowpark_connect/relation/map_stats.py +88 -39
  36. snowflake/snowpark_connect/relation/map_subquery_alias.py +13 -14
  37. snowflake/snowpark_connect/relation/map_udtf.py +6 -9
  38. snowflake/snowpark_connect/relation/read/map_read.py +8 -3
  39. snowflake/snowpark_connect/relation/read/map_read_csv.py +7 -7
  40. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +7 -7
  41. snowflake/snowpark_connect/relation/read/map_read_json.py +7 -7
  42. snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -7
  43. snowflake/snowpark_connect/relation/read/map_read_socket.py +7 -3
  44. snowflake/snowpark_connect/relation/read/map_read_table.py +25 -16
  45. snowflake/snowpark_connect/relation/read/map_read_text.py +7 -7
  46. snowflake/snowpark_connect/relation/utils.py +11 -5
  47. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +15 -12
  48. snowflake/snowpark_connect/relation/write/map_write.py +199 -40
  49. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +3 -2
  50. snowflake/snowpark_connect/server.py +34 -4
  51. snowflake/snowpark_connect/type_mapping.py +2 -23
  52. snowflake/snowpark_connect/utils/cache.py +27 -22
  53. snowflake/snowpark_connect/utils/context.py +33 -17
  54. snowflake/snowpark_connect/utils/{attribute_handling.py → identifiers.py} +47 -0
  55. snowflake/snowpark_connect/utils/session.py +41 -34
  56. snowflake/snowpark_connect/utils/telemetry.py +1 -2
  57. snowflake/snowpark_connect/version.py +1 -1
  58. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/METADATA +5 -3
  59. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/RECORD +67 -64
  60. snowpark_connect-0.21.0.dist-info/licenses/LICENSE-binary +568 -0
  61. snowpark_connect-0.21.0.dist-info/licenses/NOTICE-binary +1533 -0
  62. {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-connect +0 -0
  63. {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-session +0 -0
  64. {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-submit +0 -0
  65. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/WHEEL +0 -0
  66. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
  67. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/top_level.txt +0 -0
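The hunks below appear to come from snowflake/snowpark_connect/relation/map_column_ops.py (the +207/-144 entry in the list above). Their common thread is that each relation mapper stops returning a bare snowpark.DataFrame carrying a private _column_map attribute and instead returns a DataFrameContainer from the new dataframe_container.py module. That module is not included in this excerpt, so the sketch below is reconstructed purely from the call sites visible in the diff; every name, type, and default in it is an assumption rather than the shipped implementation.

# Hypothetical sketch of DataFrameContainer, inferred only from call sites in
# the diff below; the real implementation lives in dataframe_container.py.
from __future__ import annotations

from typing import Any, Callable, Optional

from snowflake import snowpark
from snowflake.snowpark.types import StructType


class DataFrameContainer:
    """Bundles a snowpark.DataFrame with the Spark-to-Snowpark column map that
    was previously stashed on the DataFrame as the private _column_map attribute."""

    def __init__(
        self,
        dataframe: snowpark.DataFrame,
        column_map: Any = None,  # assumed: the ColumnNameMap used by the mappers
        table_name: Optional[str] = None,
        alias: Optional[str] = None,
        cached_schema_getter: Optional[Callable[[], StructType]] = None,
    ) -> None:
        self.dataframe = dataframe
        self.column_map = column_map
        self.table_name = table_name
        self.alias = alias
        self._cached_schema_getter = cached_schema_getter

    @classmethod
    def create_with_column_mapping(
        cls,
        dataframe: snowpark.DataFrame,
        spark_column_names: list[str],
        snowpark_column_names: list[str],
        snowpark_column_types: Optional[list] = None,
        column_metadata: Optional[dict] = None,
        column_qualifiers: Optional[list] = None,
        parent_column_name_map: Any = None,
        table_name: Optional[str] = None,
        alias: Optional[str] = None,
        cached_schema_getter: Optional[Callable[[], StructType]] = None,
    ) -> "DataFrameContainer":
        # Assumed behavior: build a ColumnNameMap from the parallel Spark and
        # Snowpark name lists, then wrap it together with the DataFrame.
        raise NotImplementedError("sketch only; see dataframe_container.py")

Callers unwrap the container with input_container.dataframe and input_container.column_map, as the hunks below show.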
@@ -30,10 +30,9 @@ from snowflake.snowpark.table_function import _ExplodeFunctionCall
  from snowflake.snowpark.types import DataType, StructField, StructType, _NumericType
  from snowflake.snowpark_connect.column_name_handler import (
  make_column_names_snowpark_compatible,
- set_schema_getter,
- with_column_map,
  )
  from snowflake.snowpark_connect.config import global_config
+ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.error.error_utils import SparkException
  from snowflake.snowpark_connect.expression.map_expression import (
  map_alias,
@@ -53,13 +52,13 @@ from snowflake.snowpark_connect.type_mapping import (
  )
  from snowflake.snowpark_connect.typed_column import TypedColumn
  from snowflake.snowpark_connect.utils import context
- from snowflake.snowpark_connect.utils.attribute_handling import (
- split_fully_qualified_spark_name,
- )
  from snowflake.snowpark_connect.utils.context import (
  clear_lca_alias_map,
  register_lca_alias,
  )
+ from snowflake.snowpark_connect.utils.identifiers import (
+ split_fully_qualified_spark_name,
+ )
  from snowflake.snowpark_connect.utils.udtf_helper import (
  TEST_FLAG_FORCE_CREATE_SPROC,
  create_apply_udtf_in_sproc,
@@ -68,20 +67,21 @@ from snowflake.snowpark_connect.utils.udtf_helper import (

  def map_drop(
  rel: relation_proto.Relation,
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
  """
  Drop columns from a DataFrame.

  The drop is a list of expressions that is applied to the DataFrame.
  """
- input_df: snowpark.DataFrame = map_relation(rel.drop.input)
+ input_container = map_relation(rel.drop.input)
+ input_df = input_container.dataframe
  typer = ExpressionTyper(input_df)
  columns_to_drop_with_names = []
  for exp in rel.drop.columns:
  if exp.WhichOneof("expr_type") == "unresolved_attribute":
  try:
  columns_to_drop_with_names.append(
- map_single_column_expression(exp, input_df._column_map, typer)
+ map_single_column_expression(exp, input_container.column_map, typer)
  )
  except AnalysisException as e:
  if "[COLUMN_NOT_FOUND]" in e.message:
@@ -91,8 +91,8 @@ def map_drop(
  columns_to_drop: list[Column] = [
  col[1].col for col in columns_to_drop_with_names
  ] + [
- snowpark_functions_col(c, input_df._column_map)
- for c in input_df._column_map.get_snowpark_column_names_from_spark_column_names(
+ snowpark_functions_col(c, input_container.column_map)
+ for c in input_container.column_map.get_snowpark_column_names_from_spark_column_names(
  list(rel.drop.column_names)
  )
  if c is not None
@@ -100,7 +100,7 @@ def map_drop(
  # Sometimes we get a drop query with only invalid names. In this case, we return
  # the input DataFrame.
  if len(columns_to_drop) == 0:
- return input_df
+ return input_container

  def _get_column_names_to_drop() -> list[str]:
  # more or less copied from Snowpark's DataFrame::drop
@@ -128,47 +128,52 @@ def map_drop(
  # Snowpark doesn't allow dropping all columns, so we have an EmptyDataFrame
  # object to handle these cases.
  try:
- new_columns_names = input_df._column_map.get_snowpark_columns_after_drop(
+ column_map = input_container.column_map
+ new_columns_names = column_map.get_snowpark_columns_after_drop(
  _get_column_names_to_drop()
  )
  result: snowpark.DataFrame = input_df.drop(*columns_to_drop)
- return with_column_map(
- result,
- input_df._column_map.get_spark_column_names_from_snowpark_column_names(
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=result,
+ spark_column_names=column_map.get_spark_column_names_from_snowpark_column_names(
  new_columns_names
  ),
  snowpark_column_names=new_columns_names,
- column_qualifiers=input_df._column_map.get_qualifiers_for_columns_after_drop(
+ column_qualifiers=column_map.get_qualifiers_for_columns_after_drop(
  _get_column_names_to_drop()
  ),
- parent_column_name_map=input_df._column_map,
+ parent_column_name_map=column_map,
  )
  except snowpark.exceptions.SnowparkColumnException:
  from snowflake.snowpark_connect.empty_dataframe import EmptyDataFrame

- return EmptyDataFrame()
+ return DataFrameContainer(EmptyDataFrame())


- def map_project(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_project(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  """
- Project column(s).
+ Project column(s) and return a container.

- Projections come in as expressions, which are mapped to `snowpark.Column`
- objects.
+ Projections come in as expressions, which are mapped to `snowpark.Column` objects.
  """
  if rel.project.HasField("input"):
- input_df = map_relation(rel.project.input)
+ input_container = map_relation(rel.project.input)
+ input_df = input_container.dataframe
  else:
  # Create a dataframe to represent a OneRowRelation AST node.
  # XXX: Snowflake does not support 0-column tables, so create a dummy column;
  # its name does not seem to show up anywhere.
  session = snowpark.Session.get_active_session()
- input_df = with_column_map(
- session.create_dataframe([None], ["__DUMMY"]),
- ["__DUMMY"],
- ["__DUMMY"],
+ input_container = DataFrameContainer.create_with_column_mapping(
+ dataframe=session.create_dataframe([None], ["__DUMMY"]),
+ spark_column_names=["__DUMMY"],
+ snowpark_column_names=["__DUMMY"],
  )
- context.set_df_before_projection(input_df)
+
+ input_df = input_container.dataframe
+ context.set_df_before_projection(input_container)
  expressions: list[expressions_proto.Expression] = rel.project.expressions
  if not expressions:
  # XXX: Snowflake does not support 0-column tables, so create a dummy column;
@@ -217,7 +222,7 @@ def map_project(rel: relation_proto.Relation) -> snowpark.DataFrame:
  )

  for exp in expressions:
- new_spark_names, mapper = map_expression(exp, input_df._column_map, typer)
+ new_spark_names, mapper = map_expression(exp, input_container.column_map, typer)
  if len(new_spark_names) == 1 and not isinstance(
  mapper.col, _ExplodeFunctionCall
  ):
@@ -238,7 +243,7 @@ def map_project(rel: relation_proto.Relation) -> snowpark.DataFrame:
  and not has_unresolved_star
  ):
  # Try to get the existing Snowpark column name for this Spark column
- existing_snowpark_name = input_df._column_map.get_snowpark_column_name_from_spark_column_name(
+ existing_snowpark_name = input_container.column_map.get_snowpark_column_name_from_spark_column_name(
  spark_name, allow_non_exists=True
  )

@@ -308,22 +313,28 @@ def map_project(rel: relation_proto.Relation) -> snowpark.DataFrame:
  result = result.toDF(*final_snowpark_columns)
  new_snowpark_columns = final_snowpark_columns

- return with_column_map(
- result,
- new_spark_columns,
- new_snowpark_columns,
- column_types,
- column_metadata=input_df._column_map.column_metadata,
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=result,
+ spark_column_names=new_spark_columns,
+ snowpark_column_names=new_snowpark_columns,
+ snowpark_column_types=column_types,
+ column_metadata=input_container.column_map.column_metadata,
  column_qualifiers=qualifiers,
- parent_column_name_map=input_df._column_map,
+ parent_column_name_map=input_container.column_map,
+ table_name=input_container.table_name,
+ alias=input_container.alias,
  )


- def map_sort(sort: relation_proto.Sort) -> snowpark.DataFrame:
+ def map_sort(
+ sort: relation_proto.Sort,
+ ) -> DataFrameContainer:
  """
- Implements DataFrame.sort().
+ Implements DataFrame.sort() and return a container.
+
  """
- input_df = map_relation(sort.input)
+ input_container = map_relation(sort.input)
+ input_df = input_container.dataframe
  cols = []
  ascending = [] # Ignored if all order values are set to "unspecified".
  order_specified = False
@@ -338,7 +349,7 @@ def map_sort(sort: relation_proto.Sort) -> snowpark.DataFrame:
  if (
  len(parsed_col_name) == 1
  and parsed_col_name[0].lower() == "all"
- and input_df._column_map.get_snowpark_column_name_from_spark_column_name(
+ and input_container.column_map.get_snowpark_column_name_from_spark_column_name(
  parsed_col_name[0], allow_non_exists=True
  )
  is None
@@ -354,7 +365,7 @@ def map_sort(sort: relation_proto.Sort) -> snowpark.DataFrame:
  direction=sort_order[0].direction,
  null_ordering=sort_order[0].null_ordering,
  )
- for col in input_df._column_map.get_spark_columns()
+ for col in input_container.column_map.get_spark_columns()
  ]

  for so in sort_order:
@@ -370,7 +381,7 @@ def map_sort(sort: relation_proto.Sort) -> snowpark.DataFrame:
  )
  else:
  _, typed_column = map_single_column_expression(
- so.child, input_df._column_map, typer
+ so.child, input_container.column_map, typer
  )
  col = typed_column.col

@@ -412,29 +423,35 @@ def map_sort(sort: relation_proto.Sort) -> snowpark.DataFrame:
  if not order_specified:
  ascending = None
  result = input_df.sort(cols, ascending=ascending)
- result._column_map = input_df._column_map
- result._table_name = input_df._table_name
- set_schema_getter(result, lambda: input_df.schema)
- return result
+
+ return DataFrameContainer(
+ result,
+ input_container.column_map,
+ input_container.table_name,
+ cached_schema_getter=lambda: input_df.schema,
+ )


- def map_to_df(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_to_df(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  """
- Transform the column names of the input DataFrame.
+ Transform the column names of the input DataFrame and return a container.
  """
- input_df: snowpark.DataFrame = map_relation(rel.to_df.input)
+ input_container = map_relation(rel.to_df.input)
+ input_df = input_container.dataframe
+
  new_column_names = list(rel.to_df.column_names)
- if len(new_column_names) != len(input_df._column_map.columns):
+ if len(new_column_names) != len(input_container.column_map.columns):
  # TODO: Check error type here
  raise ValueError(
  "Number of column names must match number of columns in DataFrame"
  )
-
  snowpark_new_column_names = make_column_names_snowpark_compatible(
  new_column_names, rel.common.plan_id
  )
-
  result = input_df.toDF(*snowpark_new_column_names)
+
  if result._select_statement is not None:
  # do not allow snowpark to flatten the to_df result
  # TODO: remove after SNOW-2203706 is fixed
@@ -448,27 +465,33 @@ def map_to_df(rel: relation_proto.Relation) -> snowpark.DataFrame:
  ]
  )

- set_schema_getter(result, _get_schema)
- result_with_column_map = with_column_map(
- result,
- new_column_names,
+ result_container = DataFrameContainer.create_with_column_mapping(
+ dataframe=result,
+ spark_column_names=new_column_names,
  snowpark_column_names=snowpark_new_column_names,
+ parent_column_name_map=input_container.column_map,
+ table_name=input_container.table_name,
+ alias=input_container.alias,
+ cached_schema_getter=_get_schema,
  )
- context.set_df_before_projection(result_with_column_map)
- return result_with_column_map
+ context.set_df_before_projection(result_container)
+ return result_container


- def map_to_schema(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_to_schema(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  """
  Transform the column names of the input DataFrame.
  """
- input_df: snowpark.DataFrame = map_relation(rel.to_schema.input)
+ input_container = map_relation(rel.to_schema.input)
+ input_df = input_container.dataframe
  new_column_names = [field.name for field in rel.to_schema.schema.struct.fields]
  snowpark_new_column_names = make_column_names_snowpark_compatible(
  new_column_names, rel.common.plan_id
  )
  count_case_insensitive_column_names = defaultdict()
- for key, value in input_df._column_map.spark_to_col.items():
+ for key, value in input_container.column_map.spark_to_col.items():
  count_case_insensitive_column_names[
  key.lower()
  ] = count_case_insensitive_column_names.get(key.lower(), 0) + len(value)
@@ -483,12 +506,12 @@ def map_to_schema(rel: relation_proto.Relation) -> snowpark.DataFrame:
  if field.name in already_existing_columns:
  if count_case_insensitive_column_names[field.name.lower()] > 1:
  raise AnalysisException(
- f"[AMBIGUOUS_COLUMN_OR_FIELD] Column or field `{field.name}` is ambiguous and has {len(input_df._column_map.spark_to_col[field.name])} matches."
+ f"[AMBIGUOUS_COLUMN_OR_FIELD] Column or field `{field.name}` is ambiguous and has {len(input_container.column_map.spark_to_col[field.name])} matches."
  )
  snowpark_name = None
- for name in input_df._column_map.spark_to_col:
+ for name in input_container.column_map.spark_to_col:
  if name.lower() == field.name.lower():
- snowpark_name = input_df._column_map.spark_to_col[name][
+ snowpark_name = input_container.column_map.spark_to_col[name][
  0
  ].snowpark_name
  break
@@ -516,10 +539,10 @@ def map_to_schema(rel: relation_proto.Relation) -> snowpark.DataFrame:
  # All columns already exist, we're doing a simple update.
  snowpark_new_column_names = []
  for column in new_column_names:
- for name in input_df._column_map.spark_to_col:
+ for name in input_container.column_map.spark_to_col:
  if name.lower() == column.lower():
  snowpark_new_column_names.append(
- input_df._column_map.spark_to_col[name][0].snowpark_name
+ input_container.column_map.spark_to_col[name][0].snowpark_name
  )
  result = input_df
  elif len(already_existing_columns) == 0:
@@ -540,16 +563,18 @@ def map_to_schema(rel: relation_proto.Relation) -> snowpark.DataFrame:
  # If the column doesn't already exist, append the new Snowpark name to columns_to_add
  if all(
  spark_column.lower() != name.lower()
- for name in input_df._column_map.spark_to_col
+ for name in input_container.column_map.spark_to_col
  ):
  columns_to_add.append(snowpark_column)
  new_snowpark_new_column_names.append(snowpark_column)
  else:
- for name in input_df._column_map.spark_to_col:
+ for name in input_container.column_map.spark_to_col:
  # If the column does exist, append the original Snowpark name, We don't need to add this column.
  if name.lower() == spark_column.lower():
  new_snowpark_new_column_names.append(
- input_df._column_map.spark_to_col[name][0].snowpark_name
+ input_container.column_map.spark_to_col[name][
+ 0
+ ].snowpark_name
  )
  # Add all columns introduced by the new schema.
  new_columns = [
@@ -581,21 +606,24 @@ def map_to_schema(rel: relation_proto.Relation) -> snowpark.DataFrame:
  column_metadata[field.name] = None
  else:
  column_metadata[field.name] = None
- return with_column_map(
- result_with_casting,
- new_column_names,
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=result_with_casting,
+ spark_column_names=new_column_names,
  snowpark_column_names=snowpark_new_column_names,
  snowpark_column_types=[field.datatype for field in snowpark_schema.fields],
  column_metadata=column_metadata,
- parent_column_name_map=input_df._column_map,
+ parent_column_name_map=input_container.column_map,
  )


- def map_with_columns_renamed(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_with_columns_renamed(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  """
- Rename columns in a DataFrame.
+ Rename columns in a DataFrame and return a container.
  """
- input_df: snowpark.DataFrame = map_relation(rel.with_columns_renamed.input)
+ input_container = map_relation(rel.with_columns_renamed.input)
+ input_df = input_container.dataframe
  rename_columns_map = dict(rel.with_columns_renamed.rename_columns_map)

  if not global_config.spark_sql_caseSensitive:
@@ -608,9 +636,11 @@ def map_with_columns_renamed(rel: relation_proto.Relation) -> snowpark.DataFrame
  k.lower(): v.lower() for k, v in rename_columns_map.items()
  }

+ column_map = input_container.column_map
+
  # re-construct the rename chains based on the input dataframe.
- if input_df._column_map.rename_chains:
- for key, value in input_df._column_map.rename_chains.items():
+ if input_container.column_map.rename_chains:
+ for key, value in input_container.column_map.rename_chains.items():
  if key in rename_columns_map:
  # This is to handle the case where the same column is renamed multiple times.
  # df.withColumnRenamed("a", "b").withColumnRenamed("a", "c")
@@ -628,19 +658,25 @@ def map_with_columns_renamed(rel: relation_proto.Relation) -> snowpark.DataFrame
  # This just copies the renames from previous computed dataframe
  rename_columns_map[key] = value

- existing_columns = input_df._column_map.get_spark_columns()
+ existing_columns = input_container.column_map.get_spark_columns()
+
+ def _column_exists_error(name: str) -> AnalysisException:
+ return AnalysisException(
+ f"[COLUMN_ALREADY_EXISTS] The column `{name}` already exists. Consider to choose another name or rename the existing column."
+ )

  # Validate for naming conflicts
  new_names_list = list(dict(rel.with_columns_renamed.rename_columns_map).values())
  seen = set()
  for new_name in new_names_list:
+ if column_map.has_spark_column(new_name):
+ # Spark doesn't allow reusing existing names, even if the result df will not contain duplicate columns
+ raise _column_exists_error(new_name)
  if (global_config.spark_sql_caseSensitive and new_name in seen) or (
  not global_config.spark_sql_caseSensitive
  and new_name.lower() in [s.lower() for s in seen]
  ):
- raise AnalysisException(
- f"[COLUMN_ALREADY_EXISTS] The column `{new_name}` already exists. Consider to choose another name or rename the existing column."
- )
+ raise _column_exists_error(new_name)
  seen.add(new_name)

  new_columns = []
@@ -656,25 +692,30 @@ def map_with_columns_renamed(rel: relation_proto.Relation) -> snowpark.DataFrame

  # Creating a new df to avoid updating the state of cached dataframe.
  new_df = input_df.select("*")
- result_df = with_column_map(
- new_df,
- new_columns,
- input_df._column_map.get_snowpark_columns(),
- column_qualifiers=input_df._column_map.get_qualifiers(),
- parent_column_name_map=input_df._column_map.get_parent_column_name_map(),
+ result_container = DataFrameContainer.create_with_column_mapping(
+ dataframe=new_df,
+ spark_column_names=new_columns,
+ snowpark_column_names=input_container.column_map.get_snowpark_columns(),
+ column_qualifiers=input_container.column_map.get_qualifiers(),
+ parent_column_name_map=input_container.column_map.get_parent_column_name_map(),
+ table_name=input_container.table_name,
+ alias=input_container.alias,
  )
- result_df._column_map.rename_chains = rename_columns_map
+ result_container.column_map.rename_chains = rename_columns_map

- return result_df
+ return result_container


- def map_with_columns(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_with_columns(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  """
- Add columns to a DataFrame.
+ Add columns to a DataFrame and return a container.
  """
- input_df: snowpark.DataFrame = map_relation(rel.with_columns.input)
+ input_container = map_relation(rel.with_columns.input)
+ input_df = input_container.dataframe
  with_columns = [
- map_alias(alias, input_df._column_map, ExpressionTyper(input_df))
+ map_alias(alias, input_container.column_map, ExpressionTyper(input_df))
  for alias in rel.with_columns.aliases
  ]
  # TODO: This list needs to contain all unique column names, but the code below doesn't
@@ -682,7 +723,7 @@ def map_with_columns(rel: relation_proto.Relation) -> snowpark.DataFrame:
  with_columns_names = []
  with_columns_exprs = []
  with_columns_types = []
- with_column_offset = len(input_df._column_map.get_spark_columns())
+ with_column_offset = len(input_container.column_map.get_spark_columns())
  new_spark_names = []
  seen_columns = set()
  for names_list, expr in with_columns:
@@ -690,7 +731,7 @@ def map_with_columns(rel: relation_proto.Relation) -> snowpark.DataFrame:
  len(names_list) == 1
  ), f"Expected single column name, got {len(names_list)}: {names_list}"
  name = names_list[0]
- name_normalized = input_df._column_map._normalized_spark_name(name)
+ name_normalized = input_container.column_map._normalized_spark_name(name)
  if name_normalized in seen_columns:
  raise ValueError(
  f"[COLUMN_ALREADY_EXISTS] The column `{name}` already exists."
@@ -698,11 +739,9 @@ def map_with_columns(rel: relation_proto.Relation) -> snowpark.DataFrame:
  seen_columns.add(name_normalized)
  # If the column name is already in the DataFrame, we replace it, so we use the
  # mapping to get the correct column name.
- if input_df._column_map.has_spark_column(name):
- all_instances_of_spark_column_name = (
- input_df._column_map.get_snowpark_column_names_from_spark_column_names(
- [name]
- )
+ if input_container.column_map.has_spark_column(name):
+ all_instances_of_spark_column_name = input_container.column_map.get_snowpark_column_names_from_spark_column_names(
+ [name]
  )
  if len(all_instances_of_spark_column_name) == 0:
  raise KeyError(f"Spark column name {name} does not exist")
@@ -729,7 +768,7 @@ def map_with_columns(rel: relation_proto.Relation) -> snowpark.DataFrame:
  new_spark_columns,
  new_snowpark_columns,
  qualifiers,
- ) = input_df._column_map.with_columns(new_spark_names, with_columns_names)
+ ) = input_container.column_map.with_columns(new_spark_names, with_columns_names)

  # dedup the change in columns at snowpark name level, this is required by the with columns functions
  with_columns_names_deduped = []
@@ -751,32 +790,39 @@ def map_with_columns(rel: relation_proto.Relation) -> snowpark.DataFrame:
  + list(zip(with_columns_names, with_columns_types))
  )

- column_metadata = input_df._column_map.column_metadata or {}
+ column_metadata = input_container.column_map.column_metadata or {}
  for alias in rel.with_columns.aliases:
  # this logic is triggered for df.withMetadata function.
  if alias.HasField("metadata") and len(alias.metadata.strip()) > 0:
  # spark sends list of alias names with only one element in the list with alias name.
  column_metadata[alias.name[0]] = json.loads(alias.metadata)

- return with_column_map(
- result,
- new_spark_columns,
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=result,
+ spark_column_names=new_spark_columns,
  snowpark_column_names=new_snowpark_columns,
  snowpark_column_types=[
  snowpark_name_to_type.get(n) for n in new_snowpark_columns
  ],
  column_metadata=column_metadata,
  column_qualifiers=qualifiers,
- parent_column_name_map=input_df._column_map,
+ parent_column_name_map=input_container.column_map,
+ table_name=input_container.table_name,
+ alias=input_container.alias,
  )


- def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_unpivot(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  # Spark API: df.unpivot([id_columns], [unpivot_columns], var_column, val_column)
  # Snowpark API: df.unpivot(val_column, var_column, [unpivot_columns])
  if rel.unpivot.HasField("values") and len(rel.unpivot.values.values) == 0:
  raise SparkException.unpivot_requires_value_columns()

+ input_container = map_relation(rel.unpivot.input)
+ input_df = input_container.dataframe
+
  def get_lease_common_ancestor_classes(types: list[snowpark.types.DataType]) -> set:
  mro_lists = [set(type.__class__.mro()) for type in types]
  common_ancestors = set.intersection(*mro_lists)
@@ -795,12 +841,15 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  type_column_list = [
  (
  f.datatype,
- df._column_map.get_spark_column_name_from_snowpark_column_name(
- snowpark_functions_col(f.name, df._column_map).get_name()
+ input_container.column_map.get_spark_column_name_from_snowpark_column_name(
+ snowpark_functions_col(
+ f.name, input_container.column_map
+ ).get_name()
  ),
  )
  for f in df.schema.fields
- if snowpark_functions_col(f.name, df._column_map).get_name() in col_names
+ if snowpark_functions_col(f.name, input_container.column_map).get_name()
+ in col_names
  ]
  type_iter, _ = zip(*type_column_list)
  type_list = list(type_iter)
@@ -837,7 +886,7 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  typer = ExpressionTyper(input_df)
  for id_col in relation.unpivot.ids:
  spark_name, typed_column = map_single_column_expression(
- id_col, df._column_map, typer
+ id_col, input_container.column_map, typer
  )
  id_col_names.append(typed_column.col.get_name())
  spark_columns.append(spark_name)
@@ -848,7 +897,7 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  unpivot_spark_names = []
  for v in relation.unpivot.values.values:
  spark_name, typed_column = map_single_column_expression(
- v, df._column_map, typer
+ v, input_container.column_map, typer
  )
  unpivot_col_names.append(typed_column.col.get_name())
  unpivot_spark_names.append(spark_name)
@@ -856,15 +905,19 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  if not rel.unpivot.HasField("values"):
  # When `values` is `None`, all non-id columns will be unpivoted.
  for snowpark_name, spark_name in zip(
- df._column_map.get_snowpark_columns(),
- df._column_map.get_spark_columns(),
+ input_container.column_map.get_snowpark_columns(),
+ input_container.column_map.get_spark_columns(),
  ):
  if (
- snowpark_functions_col(snowpark_name, df._column_map).get_name()
+ snowpark_functions_col(
+ snowpark_name, input_container.column_map
+ ).get_name()
  not in id_col_names
  ):
  unpivot_col_names.append(
- snowpark_functions_col(snowpark_name, df._column_map).get_name()
+ snowpark_functions_col(
+ snowpark_name, input_container.column_map
+ ).get_name()
  )
  unpivot_spark_names.append(spark_name)

@@ -872,7 +925,6 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  spark_columns.append(relation.unpivot.value_column_name)
  return spark_columns, id_col_names, unpivot_col_names, unpivot_spark_names

- input_df: snowpark.DataFrame = map_relation(rel.unpivot.input)
  (
  spark_columns,
  id_col_names,
@@ -899,27 +951,35 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  column_reverse_project = []
  snowpark_columns = []
  qualifiers = []
- for c in input_df._column_map.get_snowpark_columns():
- c_name = snowpark_functions_col(c, input_df._column_map).get_name()
+ for c in input_container.column_map.get_snowpark_columns():
+ c_name = snowpark_functions_col(c, input_container.column_map).get_name()
  if c_name in unpivot_col_names:
  if cast_type:
  column_project.append(
- snowpark_functions_col(c, input_df._column_map)
+ snowpark_functions_col(c, input_container.column_map)
  .cast("DOUBLE")
  .alias(c_name)
  )
  else:
- column_project.append(snowpark_functions_col(c, input_df._column_map))
+ column_project.append(
+ snowpark_functions_col(c, input_container.column_map)
+ )
  if c_name in id_col_names:
  id_col_alias = "SES" + generate_random_alphanumeric().upper()
  column_project.append(
- snowpark_functions_col(c, input_df._column_map).alias(id_col_alias)
+ snowpark_functions_col(c, input_container.column_map).alias(
+ id_col_alias
+ )
  )
  column_reverse_project.append(
- snowpark_functions_col(id_col_alias, input_df._column_map).alias(c)
+ snowpark_functions_col(id_col_alias, input_container.column_map).alias(
+ c
+ )
  )
  snowpark_columns.append(c)
- qualifiers.append(input_df._column_map.get_qualifier_for_spark_column(c))
+ qualifiers.append(
+ input_container.column_map.get_qualifier_for_spark_column(c)
+ )

  # Without the case when postprocessing, the result Spark dataframe is:
  # +---+------------+------+
@@ -941,7 +1001,7 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  if post_process_variable_column is None:
  post_process_variable_column = snowpark_fn.when(
  snowpark_functions_col(
- snowpark_variable_column_name, input_df._column_map
+ snowpark_variable_column_name, input_container.column_map
  )
  == unquote_if_quoted(snowpark_name),
  spark_name,
@@ -949,7 +1009,7 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  else:
  post_process_variable_column = post_process_variable_column.when(
  snowpark_functions_col(
- snowpark_variable_column_name, input_df._column_map
+ snowpark_variable_column_name, input_container.column_map
  )
  == unquote_if_quoted(snowpark_name),
  spark_name,
@@ -960,7 +1020,7 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  )
  snowpark_columns.append(snowpark_variable_column_name)
  column_reverse_project.append(
- snowpark_functions_col(snowpark_value_column_name, input_df._column_map)
+ snowpark_functions_col(snowpark_value_column_name, input_container.column_map)
  )
  snowpark_columns.append(snowpark_value_column_name)
  qualifiers.extend([[]] * 2)
@@ -975,20 +1035,23 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  )
  .select(*column_reverse_project)
  )
- return with_column_map(
- result,
- spark_columns,
- snowpark_columns,
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=result,
+ spark_column_names=spark_columns,
+ snowpark_column_names=snowpark_columns,
  column_qualifiers=qualifiers,
- parent_column_name_map=input_df._column_map,
+ parent_column_name_map=input_container.column_map,
  )


- def map_group_map(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_group_map(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  """
  Add columns to a DataFrame.
  """
- input_df: snowpark.DataFrame = map_relation(rel.group_map.input)
+ input_container = map_relation(rel.group_map.input)
+ input_df = input_container.dataframe
  grouping_expressions = rel.group_map.grouping_expressions
  snowpark_grouping_expressions: list[snowpark.Column] = []
  typer = ExpressionTyper(input_df)
@@ -996,7 +1059,7 @@ def map_group_map(rel: relation_proto.Relation) -> snowpark.DataFrame:
  qualifiers = []
  for exp in grouping_expressions:
  new_name, snowpark_column = map_single_column_expression(
- exp, input_df._column_map, typer
+ exp, input_container.column_map, typer
  )
  snowpark_grouping_expressions.append(snowpark_column.col)
  group_name_list.append(new_name)
@@ -1013,9 +1076,9 @@ def map_group_map(rel: relation_proto.Relation) -> snowpark.DataFrame:

  if not is_compatible_python or TEST_FLAG_FORCE_CREATE_SPROC:
  original_columns = None
- if input_df._column_map is not None:
+ if input_container.column_map is not None:
  original_columns = [
- column.spark_name for column in input_df._column_map.columns
+ column.spark_name for column in input_container.column_map.columns
  ]

  apply_udtf_temp_name = create_apply_udtf_in_sproc(
@@ -1044,12 +1107,12 @@ def map_group_map(rel: relation_proto.Relation) -> snowpark.DataFrame:
  )

  qualifiers.extend([[]] * (len(result.columns) - len(group_name_list)))
- return with_column_map(
- result,
- [field.name for field in output_type],
- result.columns,
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=result,
+ spark_column_names=[field.name for field in output_type],
+ snowpark_column_names=result.columns,
  column_qualifiers=qualifiers,
- parent_column_name_map=input_df._column_map,
+ parent_column_name_map=input_container.column_map,
  )

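Beyond the mechanical container migration, map_with_columns_renamed gains a new up-front check: a rename is rejected when its target name already exists on the input DataFrame (the has_spark_column branch above), per the in-diff comment that Spark disallows reusing an existing name even when the result would contain no duplicate columns. The snippet below is a minimal PySpark illustration of the client-visible behavior that check appears to mirror; it is not code from this package, and the expected error text is taken only from the message in the diff.

# Illustration only: the rename conflict that the new check in
# map_with_columns_renamed reports as [COLUMN_ALREADY_EXISTS].
from pyspark.errors import AnalysisException
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 2)], ["a", "b"])

try:
    # "b" already exists on df, so the rename map is rejected up front, even
    # though the renamed frame would have ended up as ["b", "c"] with no
    # duplicate column names.
    df.withColumnsRenamed({"a": "b", "b": "c"}).show()
except AnalysisException as err:
    # Expected to mention [COLUMN_ALREADY_EXISTS] for column `b`.
    print(err)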