snowpark-connect 0.20.2__py3-none-any.whl → 0.22.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Note: the registry flags this release of snowpark-connect as potentially problematic.

Files changed (84)
  1. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +3 -2
  2. snowflake/snowpark_connect/column_name_handler.py +6 -65
  3. snowflake/snowpark_connect/config.py +47 -17
  4. snowflake/snowpark_connect/dataframe_container.py +242 -0
  5. snowflake/snowpark_connect/error/error_utils.py +25 -0
  6. snowflake/snowpark_connect/execute_plan/map_execution_command.py +13 -23
  7. snowflake/snowpark_connect/execute_plan/map_execution_root.py +9 -5
  8. snowflake/snowpark_connect/expression/map_extension.py +2 -1
  9. snowflake/snowpark_connect/expression/map_udf.py +4 -4
  10. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +8 -7
  11. snowflake/snowpark_connect/expression/map_unresolved_function.py +481 -170
  12. snowflake/snowpark_connect/expression/map_unresolved_star.py +8 -8
  13. snowflake/snowpark_connect/expression/map_update_fields.py +1 -1
  14. snowflake/snowpark_connect/expression/typer.py +6 -6
  15. snowflake/snowpark_connect/proto/control_pb2.py +17 -16
  16. snowflake/snowpark_connect/proto/control_pb2.pyi +17 -17
  17. snowflake/snowpark_connect/proto/control_pb2_grpc.py +12 -63
  18. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +15 -14
  19. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +19 -14
  20. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
  21. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +27 -26
  22. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +74 -68
  23. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
  24. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +5 -5
  25. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +25 -17
  26. snowflake/snowpark_connect/relation/map_aggregate.py +170 -61
  27. snowflake/snowpark_connect/relation/map_catalog.py +2 -2
  28. snowflake/snowpark_connect/relation/map_column_ops.py +227 -145
  29. snowflake/snowpark_connect/relation/map_crosstab.py +25 -6
  30. snowflake/snowpark_connect/relation/map_extension.py +81 -56
  31. snowflake/snowpark_connect/relation/map_join.py +72 -63
  32. snowflake/snowpark_connect/relation/map_local_relation.py +35 -20
  33. snowflake/snowpark_connect/relation/map_map_partitions.py +24 -17
  34. snowflake/snowpark_connect/relation/map_relation.py +22 -16
  35. snowflake/snowpark_connect/relation/map_row_ops.py +232 -146
  36. snowflake/snowpark_connect/relation/map_sample_by.py +15 -8
  37. snowflake/snowpark_connect/relation/map_show_string.py +42 -5
  38. snowflake/snowpark_connect/relation/map_sql.py +141 -237
  39. snowflake/snowpark_connect/relation/map_stats.py +88 -39
  40. snowflake/snowpark_connect/relation/map_subquery_alias.py +13 -14
  41. snowflake/snowpark_connect/relation/map_udtf.py +10 -13
  42. snowflake/snowpark_connect/relation/read/map_read.py +8 -3
  43. snowflake/snowpark_connect/relation/read/map_read_csv.py +7 -7
  44. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +7 -7
  45. snowflake/snowpark_connect/relation/read/map_read_json.py +19 -8
  46. snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -7
  47. snowflake/snowpark_connect/relation/read/map_read_socket.py +7 -3
  48. snowflake/snowpark_connect/relation/read/map_read_table.py +25 -16
  49. snowflake/snowpark_connect/relation/read/map_read_text.py +7 -7
  50. snowflake/snowpark_connect/relation/read/reader_config.py +1 -0
  51. snowflake/snowpark_connect/relation/utils.py +11 -5
  52. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +15 -12
  53. snowflake/snowpark_connect/relation/write/map_write.py +259 -56
  54. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +3 -2
  55. snowflake/snowpark_connect/server.py +43 -4
  56. snowflake/snowpark_connect/type_mapping.py +6 -23
  57. snowflake/snowpark_connect/utils/cache.py +27 -22
  58. snowflake/snowpark_connect/utils/context.py +33 -17
  59. snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
  60. snowflake/snowpark_connect/utils/{attribute_handling.py → identifiers.py} +47 -0
  61. snowflake/snowpark_connect/utils/session.py +41 -38
  62. snowflake/snowpark_connect/utils/telemetry.py +214 -63
  63. snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
  64. snowflake/snowpark_connect/version.py +1 -1
  65. snowflake/snowpark_decoder/__init__.py +0 -0
  66. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
  67. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
  68. snowflake/snowpark_decoder/dp_session.py +111 -0
  69. snowflake/snowpark_decoder/spark_decoder.py +76 -0
  70. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/METADATA +6 -4
  71. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/RECORD +83 -69
  72. snowpark_connect-0.22.1.dist-info/licenses/LICENSE-binary +568 -0
  73. snowpark_connect-0.22.1.dist-info/licenses/NOTICE-binary +1533 -0
  74. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/top_level.txt +1 -0
  75. spark/__init__.py +0 -0
  76. spark/connect/__init__.py +0 -0
  77. spark/connect/envelope_pb2.py +31 -0
  78. spark/connect/envelope_pb2.pyi +46 -0
  79. snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
  80. {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-connect +0 -0
  81. {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-session +0 -0
  82. {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-submit +0 -0
  83. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/WHEEL +0 -0
  84. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
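
The hunks below (which, judging from the functions involved — map_drop, map_project, map_sort, map_to_df, map_to_schema, map_with_columns_renamed, map_with_columns, map_unpivot, map_group_map — appear to come from snowflake/snowpark_connect/relation/map_column_ops.py) all follow one refactoring pattern: relation mappers that previously returned a bare snowpark.DataFrame carrying ad-hoc attributes (_column_map, _table_name, attached via the removed with_column_map and set_schema_getter helpers) now return the new DataFrameContainer from snowflake/snowpark_connect/dataframe_container.py. The sketch below is only an illustration of the container's shape as inferred from the call sites in this diff (dataframe, column_map, table_name, alias, cached_schema_getter, and the create_with_column_mapping factory); it is not the implementation shipped in the wheel.

from dataclasses import dataclass
from typing import Any, Callable, Optional

# Hypothetical, simplified stand-in for
# snowflake.snowpark_connect.dataframe_container.DataFrameContainer.
# Field names and the factory signature are inferred from call sites in this diff.
@dataclass
class DataFrameContainer:
    dataframe: Any                      # the wrapped snowpark.DataFrame
    column_map: Any = None              # Spark-name <-> Snowpark-name mapping (was df._column_map)
    table_name: Optional[str] = None    # was df._table_name
    alias: Optional[str] = None
    cached_schema_getter: Optional[Callable[[], Any]] = None  # was set_schema_getter(df, ...)

    @classmethod
    def create_with_column_mapping(
        cls,
        dataframe: Any,
        spark_column_names: list,
        snowpark_column_names: list,
        snowpark_column_types: Optional[list] = None,
        column_metadata: Optional[dict] = None,
        column_qualifiers: Optional[list] = None,
        parent_column_name_map: Any = None,
        table_name: Optional[str] = None,
        alias: Optional[str] = None,
        cached_schema_getter: Optional[Callable[[], Any]] = None,
    ) -> "DataFrameContainer":
        # In the real module this presumably builds a proper column-name map;
        # here the raw arguments are kept so the call sites below are easy to follow.
        column_map = {
            "spark": spark_column_names,
            "snowpark": snowpark_column_names,
            "types": snowpark_column_types,
            "metadata": column_metadata,
            "qualifiers": column_qualifiers,
            "parent": parent_column_name_map,
        }
        return cls(dataframe, column_map, table_name, alias, cached_schema_getter)

In the hunks that follow, every "return with_column_map(df, ...)" becomes "return DataFrameContainer.create_with_column_mapping(dataframe=df, ...)", and callers read container.dataframe and container.column_map instead of df and df._column_map.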
@@ -30,10 +30,9 @@ from snowflake.snowpark.table_function import _ExplodeFunctionCall
  from snowflake.snowpark.types import DataType, StructField, StructType, _NumericType
  from snowflake.snowpark_connect.column_name_handler import (
  make_column_names_snowpark_compatible,
- set_schema_getter,
- with_column_map,
  )
  from snowflake.snowpark_connect.config import global_config
+ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.error.error_utils import SparkException
  from snowflake.snowpark_connect.expression.map_expression import (
  map_alias,
@@ -53,13 +52,13 @@ from snowflake.snowpark_connect.type_mapping import (
  )
  from snowflake.snowpark_connect.typed_column import TypedColumn
  from snowflake.snowpark_connect.utils import context
- from snowflake.snowpark_connect.utils.attribute_handling import (
- split_fully_qualified_spark_name,
- )
  from snowflake.snowpark_connect.utils.context import (
  clear_lca_alias_map,
  register_lca_alias,
  )
+ from snowflake.snowpark_connect.utils.identifiers import (
+ split_fully_qualified_spark_name,
+ )
  from snowflake.snowpark_connect.utils.udtf_helper import (
  TEST_FLAG_FORCE_CREATE_SPROC,
  create_apply_udtf_in_sproc,
@@ -68,20 +67,21 @@ from snowflake.snowpark_connect.utils.udtf_helper import (

  def map_drop(
  rel: relation_proto.Relation,
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
  """
  Drop columns from a DataFrame.

  The drop is a list of expressions that is applied to the DataFrame.
  """
- input_df: snowpark.DataFrame = map_relation(rel.drop.input)
+ input_container = map_relation(rel.drop.input)
+ input_df = input_container.dataframe
  typer = ExpressionTyper(input_df)
  columns_to_drop_with_names = []
  for exp in rel.drop.columns:
  if exp.WhichOneof("expr_type") == "unresolved_attribute":
  try:
  columns_to_drop_with_names.append(
- map_single_column_expression(exp, input_df._column_map, typer)
+ map_single_column_expression(exp, input_container.column_map, typer)
  )
  except AnalysisException as e:
  if "[COLUMN_NOT_FOUND]" in e.message:
@@ -91,8 +91,8 @@ def map_drop(
  columns_to_drop: list[Column] = [
  col[1].col for col in columns_to_drop_with_names
  ] + [
- snowpark_functions_col(c, input_df._column_map)
- for c in input_df._column_map.get_snowpark_column_names_from_spark_column_names(
+ snowpark_functions_col(c, input_container.column_map)
+ for c in input_container.column_map.get_snowpark_column_names_from_spark_column_names(
  list(rel.drop.column_names)
  )
  if c is not None
@@ -100,7 +100,7 @@ def map_drop(
  # Sometimes we get a drop query with only invalid names. In this case, we return
  # the input DataFrame.
  if len(columns_to_drop) == 0:
- return input_df
+ return input_container

  def _get_column_names_to_drop() -> list[str]:
  # more or less copied from Snowpark's DataFrame::drop
@@ -128,47 +128,52 @@ def map_drop(
  # Snowpark doesn't allow dropping all columns, so we have an EmptyDataFrame
  # object to handle these cases.
  try:
- new_columns_names = input_df._column_map.get_snowpark_columns_after_drop(
+ column_map = input_container.column_map
+ new_columns_names = column_map.get_snowpark_columns_after_drop(
  _get_column_names_to_drop()
  )
  result: snowpark.DataFrame = input_df.drop(*columns_to_drop)
- return with_column_map(
- result,
- input_df._column_map.get_spark_column_names_from_snowpark_column_names(
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=result,
+ spark_column_names=column_map.get_spark_column_names_from_snowpark_column_names(
  new_columns_names
  ),
  snowpark_column_names=new_columns_names,
- column_qualifiers=input_df._column_map.get_qualifiers_for_columns_after_drop(
+ column_qualifiers=column_map.get_qualifiers_for_columns_after_drop(
  _get_column_names_to_drop()
  ),
- parent_column_name_map=input_df._column_map,
+ parent_column_name_map=column_map,
  )
  except snowpark.exceptions.SnowparkColumnException:
  from snowflake.snowpark_connect.empty_dataframe import EmptyDataFrame

- return EmptyDataFrame()
+ return DataFrameContainer(EmptyDataFrame())


- def map_project(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_project(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  """
- Project column(s).
+ Project column(s) and return a container.

- Projections come in as expressions, which are mapped to `snowpark.Column`
- objects.
+ Projections come in as expressions, which are mapped to `snowpark.Column` objects.
  """
  if rel.project.HasField("input"):
- input_df = map_relation(rel.project.input)
+ input_container = map_relation(rel.project.input)
+ input_df = input_container.dataframe
  else:
  # Create a dataframe to represent a OneRowRelation AST node.
  # XXX: Snowflake does not support 0-column tables, so create a dummy column;
  # its name does not seem to show up anywhere.
  session = snowpark.Session.get_active_session()
- input_df = with_column_map(
- session.create_dataframe([None], ["__DUMMY"]),
- ["__DUMMY"],
- ["__DUMMY"],
+ input_container = DataFrameContainer.create_with_column_mapping(
+ dataframe=session.create_dataframe([None], ["__DUMMY"]),
+ spark_column_names=["__DUMMY"],
+ snowpark_column_names=["__DUMMY"],
  )
- context.set_df_before_projection(input_df)
+
+ input_df = input_container.dataframe
+ context.set_df_before_projection(input_container)
  expressions: list[expressions_proto.Expression] = rel.project.expressions
  if not expressions:
  # XXX: Snowflake does not support 0-column tables, so create a dummy column;
@@ -217,7 +222,7 @@ def map_project(rel: relation_proto.Relation) -> snowpark.DataFrame:
  )

  for exp in expressions:
- new_spark_names, mapper = map_expression(exp, input_df._column_map, typer)
+ new_spark_names, mapper = map_expression(exp, input_container.column_map, typer)
  if len(new_spark_names) == 1 and not isinstance(
  mapper.col, _ExplodeFunctionCall
  ):
@@ -238,7 +243,7 @@ def map_project(rel: relation_proto.Relation) -> snowpark.DataFrame:
  and not has_unresolved_star
  ):
  # Try to get the existing Snowpark column name for this Spark column
- existing_snowpark_name = input_df._column_map.get_snowpark_column_name_from_spark_column_name(
+ existing_snowpark_name = input_container.column_map.get_snowpark_column_name_from_spark_column_name(
  spark_name, allow_non_exists=True
  )

@@ -308,22 +313,28 @@ def map_project(rel: relation_proto.Relation) -> snowpark.DataFrame:
  result = result.toDF(*final_snowpark_columns)
  new_snowpark_columns = final_snowpark_columns

- return with_column_map(
- result,
- new_spark_columns,
- new_snowpark_columns,
- column_types,
- column_metadata=input_df._column_map.column_metadata,
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=result,
+ spark_column_names=new_spark_columns,
+ snowpark_column_names=new_snowpark_columns,
+ snowpark_column_types=column_types,
+ column_metadata=input_container.column_map.column_metadata,
  column_qualifiers=qualifiers,
- parent_column_name_map=input_df._column_map,
+ parent_column_name_map=input_container.column_map,
+ table_name=input_container.table_name,
+ alias=input_container.alias,
  )


- def map_sort(sort: relation_proto.Sort) -> snowpark.DataFrame:
+ def map_sort(
+ sort: relation_proto.Sort,
+ ) -> DataFrameContainer:
  """
- Implements DataFrame.sort().
+ Implements DataFrame.sort() and return a container.
+
  """
- input_df = map_relation(sort.input)
+ input_container = map_relation(sort.input)
+ input_df = input_container.dataframe
  cols = []
  ascending = [] # Ignored if all order values are set to "unspecified".
  order_specified = False
@@ -338,7 +349,7 @@ def map_sort(sort: relation_proto.Sort) -> snowpark.DataFrame:
  if (
  len(parsed_col_name) == 1
  and parsed_col_name[0].lower() == "all"
- and input_df._column_map.get_snowpark_column_name_from_spark_column_name(
+ and input_container.column_map.get_snowpark_column_name_from_spark_column_name(
  parsed_col_name[0], allow_non_exists=True
  )
  is None
@@ -354,7 +365,7 @@ def map_sort(sort: relation_proto.Sort) -> snowpark.DataFrame:
  direction=sort_order[0].direction,
  null_ordering=sort_order[0].null_ordering,
  )
- for col in input_df._column_map.get_spark_columns()
+ for col in input_container.column_map.get_spark_columns()
  ]

  for so in sort_order:
@@ -370,7 +381,7 @@ def map_sort(sort: relation_proto.Sort) -> snowpark.DataFrame:
  )
  else:
  _, typed_column = map_single_column_expression(
- so.child, input_df._column_map, typer
+ so.child, input_container.column_map, typer
  )
  col = typed_column.col

@@ -412,29 +423,35 @@ def map_sort(sort: relation_proto.Sort) -> snowpark.DataFrame:
  if not order_specified:
  ascending = None
  result = input_df.sort(cols, ascending=ascending)
- result._column_map = input_df._column_map
- result._table_name = input_df._table_name
- set_schema_getter(result, lambda: input_df.schema)
- return result
+
+ return DataFrameContainer(
+ result,
+ input_container.column_map,
+ input_container.table_name,
+ cached_schema_getter=lambda: input_df.schema,
+ )


- def map_to_df(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_to_df(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  """
- Transform the column names of the input DataFrame.
+ Transform the column names of the input DataFrame and return a container.
  """
- input_df: snowpark.DataFrame = map_relation(rel.to_df.input)
+ input_container = map_relation(rel.to_df.input)
+ input_df = input_container.dataframe
+
  new_column_names = list(rel.to_df.column_names)
- if len(new_column_names) != len(input_df._column_map.columns):
+ if len(new_column_names) != len(input_container.column_map.columns):
  # TODO: Check error type here
  raise ValueError(
  "Number of column names must match number of columns in DataFrame"
  )
-
  snowpark_new_column_names = make_column_names_snowpark_compatible(
  new_column_names, rel.common.plan_id
  )
-
  result = input_df.toDF(*snowpark_new_column_names)
+
  if result._select_statement is not None:
  # do not allow snowpark to flatten the to_df result
  # TODO: remove after SNOW-2203706 is fixed
@@ -448,27 +465,33 @@ def map_to_df(rel: relation_proto.Relation) -> snowpark.DataFrame:
  ]
  )

- set_schema_getter(result, _get_schema)
- result_with_column_map = with_column_map(
- result,
- new_column_names,
+ result_container = DataFrameContainer.create_with_column_mapping(
+ dataframe=result,
+ spark_column_names=new_column_names,
  snowpark_column_names=snowpark_new_column_names,
+ parent_column_name_map=input_container.column_map,
+ table_name=input_container.table_name,
+ alias=input_container.alias,
+ cached_schema_getter=_get_schema,
  )
- context.set_df_before_projection(result_with_column_map)
- return result_with_column_map
+ context.set_df_before_projection(result_container)
+ return result_container


- def map_to_schema(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_to_schema(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  """
  Transform the column names of the input DataFrame.
  """
- input_df: snowpark.DataFrame = map_relation(rel.to_schema.input)
+ input_container = map_relation(rel.to_schema.input)
+ input_df = input_container.dataframe
  new_column_names = [field.name for field in rel.to_schema.schema.struct.fields]
  snowpark_new_column_names = make_column_names_snowpark_compatible(
  new_column_names, rel.common.plan_id
  )
  count_case_insensitive_column_names = defaultdict()
- for key, value in input_df._column_map.spark_to_col.items():
+ for key, value in input_container.column_map.spark_to_col.items():
  count_case_insensitive_column_names[
  key.lower()
  ] = count_case_insensitive_column_names.get(key.lower(), 0) + len(value)
@@ -483,12 +506,12 @@ def map_to_schema(rel: relation_proto.Relation) -> snowpark.DataFrame:
  if field.name in already_existing_columns:
  if count_case_insensitive_column_names[field.name.lower()] > 1:
  raise AnalysisException(
- f"[AMBIGUOUS_COLUMN_OR_FIELD] Column or field `{field.name}` is ambiguous and has {len(input_df._column_map.spark_to_col[field.name])} matches."
+ f"[AMBIGUOUS_COLUMN_OR_FIELD] Column or field `{field.name}` is ambiguous and has {len(input_container.column_map.spark_to_col[field.name])} matches."
  )
  snowpark_name = None
- for name in input_df._column_map.spark_to_col:
+ for name in input_container.column_map.spark_to_col:
  if name.lower() == field.name.lower():
- snowpark_name = input_df._column_map.spark_to_col[name][
+ snowpark_name = input_container.column_map.spark_to_col[name][
  0
  ].snowpark_name
  break
@@ -516,10 +539,10 @@ def map_to_schema(rel: relation_proto.Relation) -> snowpark.DataFrame:
  # All columns already exist, we're doing a simple update.
  snowpark_new_column_names = []
  for column in new_column_names:
- for name in input_df._column_map.spark_to_col:
+ for name in input_container.column_map.spark_to_col:
  if name.lower() == column.lower():
  snowpark_new_column_names.append(
- input_df._column_map.spark_to_col[name][0].snowpark_name
+ input_container.column_map.spark_to_col[name][0].snowpark_name
  )
  result = input_df
  elif len(already_existing_columns) == 0:
@@ -540,16 +563,18 @@ def map_to_schema(rel: relation_proto.Relation) -> snowpark.DataFrame:
  # If the column doesn't already exist, append the new Snowpark name to columns_to_add
  if all(
  spark_column.lower() != name.lower()
- for name in input_df._column_map.spark_to_col
+ for name in input_container.column_map.spark_to_col
  ):
  columns_to_add.append(snowpark_column)
  new_snowpark_new_column_names.append(snowpark_column)
  else:
- for name in input_df._column_map.spark_to_col:
+ for name in input_container.column_map.spark_to_col:
  # If the column does exist, append the original Snowpark name, We don't need to add this column.
  if name.lower() == spark_column.lower():
  new_snowpark_new_column_names.append(
- input_df._column_map.spark_to_col[name][0].snowpark_name
+ input_container.column_map.spark_to_col[name][
+ 0
+ ].snowpark_name
  )
  # Add all columns introduced by the new schema.
  new_columns = [
@@ -581,21 +606,24 @@ def map_to_schema(rel: relation_proto.Relation) -> snowpark.DataFrame:
  column_metadata[field.name] = None
  else:
  column_metadata[field.name] = None
- return with_column_map(
- result_with_casting,
- new_column_names,
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=result_with_casting,
+ spark_column_names=new_column_names,
  snowpark_column_names=snowpark_new_column_names,
  snowpark_column_types=[field.datatype for field in snowpark_schema.fields],
  column_metadata=column_metadata,
- parent_column_name_map=input_df._column_map,
+ parent_column_name_map=input_container.column_map,
  )


- def map_with_columns_renamed(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_with_columns_renamed(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  """
- Rename columns in a DataFrame.
+ Rename columns in a DataFrame and return a container.
  """
- input_df: snowpark.DataFrame = map_relation(rel.with_columns_renamed.input)
+ input_container = map_relation(rel.with_columns_renamed.input)
+ input_df = input_container.dataframe
  rename_columns_map = dict(rel.with_columns_renamed.rename_columns_map)

  if not global_config.spark_sql_caseSensitive:
@@ -608,9 +636,11 @@ def map_with_columns_renamed(rel: relation_proto.Relation) -> snowpark.DataFrame
  k.lower(): v.lower() for k, v in rename_columns_map.items()
  }

+ column_map = input_container.column_map
+
  # re-construct the rename chains based on the input dataframe.
- if input_df._column_map.rename_chains:
- for key, value in input_df._column_map.rename_chains.items():
+ if input_container.column_map.rename_chains:
+ for key, value in input_container.column_map.rename_chains.items():
  if key in rename_columns_map:
  # This is to handle the case where the same column is renamed multiple times.
  # df.withColumnRenamed("a", "b").withColumnRenamed("a", "c")
@@ -628,19 +658,44 @@ def map_with_columns_renamed(rel: relation_proto.Relation) -> snowpark.DataFrame
  # This just copies the renames from previous computed dataframe
  rename_columns_map[key] = value

- existing_columns = input_df._column_map.get_spark_columns()
+ existing_columns = input_container.column_map.get_spark_columns()
+
+ def _column_exists_error(name: str) -> AnalysisException:
+ return AnalysisException(
+ f"[COLUMN_ALREADY_EXISTS] The column `{name}` already exists. Consider to choose another name or rename the existing column."
+ )

  # Validate for naming conflicts
- new_names_list = list(dict(rel.with_columns_renamed.rename_columns_map).values())
+ rename_map = dict(rel.with_columns_renamed.rename_columns_map)
+ new_names_list = list(rename_map.values())
  seen = set()
  for new_name in new_names_list:
+ # Check if this new name conflicts with existing columns
+ # But allow renaming a column to a different case version of itself
+ is_case_insensitive_self_rename = False
+ if not global_config.spark_sql_caseSensitive:
+ # Find the source column(s) that map to this new name
+ source_columns = [
+ old_name
+ for old_name, new_name_candidate in rename_map.items()
+ if new_name_candidate == new_name
+ ]
+ # Check if any source column is the same as new name when case-insensitive
+ is_case_insensitive_self_rename = any(
+ source_col.lower() == new_name.lower() for source_col in source_columns
+ )
+
+ if (
+ column_map.has_spark_column(new_name)
+ and not is_case_insensitive_self_rename
+ ):
+ # Spark doesn't allow reusing existing names, even if the result df will not contain duplicate columns
+ raise _column_exists_error(new_name)
  if (global_config.spark_sql_caseSensitive and new_name in seen) or (
  not global_config.spark_sql_caseSensitive
  and new_name.lower() in [s.lower() for s in seen]
  ):
- raise AnalysisException(
- f"[COLUMN_ALREADY_EXISTS] The column `{new_name}` already exists. Consider to choose another name or rename the existing column."
- )
+ raise _column_exists_error(new_name)
  seen.add(new_name)

  new_columns = []
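
The validation block added in the hunk above tightens withColumnRenamed handling: a rename target that collides with an existing column now raises [COLUMN_ALREADY_EXISTS], unless (with spark.sql.caseSensitive disabled) the target is merely a case variant of the column being renamed. The following is a simplified, self-contained restatement of that rule for illustration only; the real code consults the column map and global_config.spark_sql_caseSensitive.

def validate_renames(existing_columns, rename_map, case_sensitive=False):
    # Simplified restatement of the conflict rule added above; illustrative only.
    norm = (lambda s: s) if case_sensitive else str.lower
    existing = {norm(c) for c in existing_columns}
    seen = set()
    for old_name, new_name in rename_map.items():
        # Renaming "a" to "A" is tolerated under case-insensitive resolution.
        self_rename = not case_sensitive and norm(old_name) == norm(new_name)
        if (norm(new_name) in existing and not self_rename) or norm(new_name) in seen:
            raise ValueError(
                f"[COLUMN_ALREADY_EXISTS] The column `{new_name}` already exists."
            )
        seen.add(norm(new_name))

validate_renames(["a", "b"], {"a": "A"})    # allowed: case-only self-rename
# validate_renames(["a", "b"], {"a": "b"})  # raises [COLUMN_ALREADY_EXISTS]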
@@ -656,25 +711,30 @@ def map_with_columns_renamed(rel: relation_proto.Relation) -> snowpark.DataFrame

  # Creating a new df to avoid updating the state of cached dataframe.
  new_df = input_df.select("*")
- result_df = with_column_map(
- new_df,
- new_columns,
- input_df._column_map.get_snowpark_columns(),
- column_qualifiers=input_df._column_map.get_qualifiers(),
- parent_column_name_map=input_df._column_map.get_parent_column_name_map(),
+ result_container = DataFrameContainer.create_with_column_mapping(
+ dataframe=new_df,
+ spark_column_names=new_columns,
+ snowpark_column_names=input_container.column_map.get_snowpark_columns(),
+ column_qualifiers=input_container.column_map.get_qualifiers(),
+ parent_column_name_map=input_container.column_map.get_parent_column_name_map(),
+ table_name=input_container.table_name,
+ alias=input_container.alias,
  )
- result_df._column_map.rename_chains = rename_columns_map
+ result_container.column_map.rename_chains = rename_columns_map

- return result_df
+ return result_container


- def map_with_columns(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_with_columns(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  """
- Add columns to a DataFrame.
+ Add columns to a DataFrame and return a container.
  """
- input_df: snowpark.DataFrame = map_relation(rel.with_columns.input)
+ input_container = map_relation(rel.with_columns.input)
+ input_df = input_container.dataframe
  with_columns = [
- map_alias(alias, input_df._column_map, ExpressionTyper(input_df))
+ map_alias(alias, input_container.column_map, ExpressionTyper(input_df))
  for alias in rel.with_columns.aliases
  ]
  # TODO: This list needs to contain all unique column names, but the code below doesn't
@@ -682,7 +742,7 @@ def map_with_columns(rel: relation_proto.Relation) -> snowpark.DataFrame:
  with_columns_names = []
  with_columns_exprs = []
  with_columns_types = []
- with_column_offset = len(input_df._column_map.get_spark_columns())
+ with_column_offset = len(input_container.column_map.get_spark_columns())
  new_spark_names = []
  seen_columns = set()
  for names_list, expr in with_columns:
@@ -690,7 +750,7 @@ def map_with_columns(rel: relation_proto.Relation) -> snowpark.DataFrame:
  len(names_list) == 1
  ), f"Expected single column name, got {len(names_list)}: {names_list}"
  name = names_list[0]
- name_normalized = input_df._column_map._normalized_spark_name(name)
+ name_normalized = input_container.column_map._normalized_spark_name(name)
  if name_normalized in seen_columns:
  raise ValueError(
  f"[COLUMN_ALREADY_EXISTS] The column `{name}` already exists."
@@ -698,11 +758,9 @@ def map_with_columns(rel: relation_proto.Relation) -> snowpark.DataFrame:
  seen_columns.add(name_normalized)
  # If the column name is already in the DataFrame, we replace it, so we use the
  # mapping to get the correct column name.
- if input_df._column_map.has_spark_column(name):
- all_instances_of_spark_column_name = (
- input_df._column_map.get_snowpark_column_names_from_spark_column_names(
- [name]
- )
+ if input_container.column_map.has_spark_column(name):
+ all_instances_of_spark_column_name = input_container.column_map.get_snowpark_column_names_from_spark_column_names(
+ [name]
  )
  if len(all_instances_of_spark_column_name) == 0:
  raise KeyError(f"Spark column name {name} does not exist")
@@ -729,7 +787,7 @@ def map_with_columns(rel: relation_proto.Relation) -> snowpark.DataFrame:
  new_spark_columns,
  new_snowpark_columns,
  qualifiers,
- ) = input_df._column_map.with_columns(new_spark_names, with_columns_names)
+ ) = input_container.column_map.with_columns(new_spark_names, with_columns_names)

  # dedup the change in columns at snowpark name level, this is required by the with columns functions
  with_columns_names_deduped = []
@@ -751,32 +809,39 @@ def map_with_columns(rel: relation_proto.Relation) -> snowpark.DataFrame:
  + list(zip(with_columns_names, with_columns_types))
  )

- column_metadata = input_df._column_map.column_metadata or {}
+ column_metadata = input_container.column_map.column_metadata or {}
  for alias in rel.with_columns.aliases:
  # this logic is triggered for df.withMetadata function.
  if alias.HasField("metadata") and len(alias.metadata.strip()) > 0:
  # spark sends list of alias names with only one element in the list with alias name.
  column_metadata[alias.name[0]] = json.loads(alias.metadata)

- return with_column_map(
- result,
- new_spark_columns,
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=result,
+ spark_column_names=new_spark_columns,
  snowpark_column_names=new_snowpark_columns,
  snowpark_column_types=[
  snowpark_name_to_type.get(n) for n in new_snowpark_columns
  ],
  column_metadata=column_metadata,
  column_qualifiers=qualifiers,
- parent_column_name_map=input_df._column_map,
+ parent_column_name_map=input_container.column_map,
+ table_name=input_container.table_name,
+ alias=input_container.alias,
  )


- def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_unpivot(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  # Spark API: df.unpivot([id_columns], [unpivot_columns], var_column, val_column)
  # Snowpark API: df.unpivot(val_column, var_column, [unpivot_columns])
  if rel.unpivot.HasField("values") and len(rel.unpivot.values.values) == 0:
  raise SparkException.unpivot_requires_value_columns()

+ input_container = map_relation(rel.unpivot.input)
+ input_df = input_container.dataframe
+
  def get_lease_common_ancestor_classes(types: list[snowpark.types.DataType]) -> set:
  mro_lists = [set(type.__class__.mro()) for type in types]
  common_ancestors = set.intersection(*mro_lists)
@@ -795,12 +860,15 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  type_column_list = [
  (
  f.datatype,
- df._column_map.get_spark_column_name_from_snowpark_column_name(
- snowpark_functions_col(f.name, df._column_map).get_name()
+ input_container.column_map.get_spark_column_name_from_snowpark_column_name(
+ snowpark_functions_col(
+ f.name, input_container.column_map
+ ).get_name()
  ),
  )
  for f in df.schema.fields
- if snowpark_functions_col(f.name, df._column_map).get_name() in col_names
+ if snowpark_functions_col(f.name, input_container.column_map).get_name()
+ in col_names
  ]
  type_iter, _ = zip(*type_column_list)
  type_list = list(type_iter)
@@ -837,7 +905,7 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  typer = ExpressionTyper(input_df)
  for id_col in relation.unpivot.ids:
  spark_name, typed_column = map_single_column_expression(
- id_col, df._column_map, typer
+ id_col, input_container.column_map, typer
  )
  id_col_names.append(typed_column.col.get_name())
  spark_columns.append(spark_name)
@@ -848,7 +916,7 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  unpivot_spark_names = []
  for v in relation.unpivot.values.values:
  spark_name, typed_column = map_single_column_expression(
- v, df._column_map, typer
+ v, input_container.column_map, typer
  )
  unpivot_col_names.append(typed_column.col.get_name())
  unpivot_spark_names.append(spark_name)
@@ -856,15 +924,19 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  if not rel.unpivot.HasField("values"):
  # When `values` is `None`, all non-id columns will be unpivoted.
  for snowpark_name, spark_name in zip(
- df._column_map.get_snowpark_columns(),
- df._column_map.get_spark_columns(),
+ input_container.column_map.get_snowpark_columns(),
+ input_container.column_map.get_spark_columns(),
  ):
  if (
- snowpark_functions_col(snowpark_name, df._column_map).get_name()
+ snowpark_functions_col(
+ snowpark_name, input_container.column_map
+ ).get_name()
  not in id_col_names
  ):
  unpivot_col_names.append(
- snowpark_functions_col(snowpark_name, df._column_map).get_name()
+ snowpark_functions_col(
+ snowpark_name, input_container.column_map
+ ).get_name()
  )
  unpivot_spark_names.append(spark_name)

@@ -872,7 +944,6 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  spark_columns.append(relation.unpivot.value_column_name)
  return spark_columns, id_col_names, unpivot_col_names, unpivot_spark_names

- input_df: snowpark.DataFrame = map_relation(rel.unpivot.input)
  (
  spark_columns,
  id_col_names,
@@ -899,27 +970,35 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  column_reverse_project = []
  snowpark_columns = []
  qualifiers = []
- for c in input_df._column_map.get_snowpark_columns():
- c_name = snowpark_functions_col(c, input_df._column_map).get_name()
+ for c in input_container.column_map.get_snowpark_columns():
+ c_name = snowpark_functions_col(c, input_container.column_map).get_name()
  if c_name in unpivot_col_names:
  if cast_type:
  column_project.append(
- snowpark_functions_col(c, input_df._column_map)
+ snowpark_functions_col(c, input_container.column_map)
  .cast("DOUBLE")
  .alias(c_name)
  )
  else:
- column_project.append(snowpark_functions_col(c, input_df._column_map))
+ column_project.append(
+ snowpark_functions_col(c, input_container.column_map)
+ )
  if c_name in id_col_names:
  id_col_alias = "SES" + generate_random_alphanumeric().upper()
  column_project.append(
- snowpark_functions_col(c, input_df._column_map).alias(id_col_alias)
+ snowpark_functions_col(c, input_container.column_map).alias(
+ id_col_alias
+ )
  )
  column_reverse_project.append(
- snowpark_functions_col(id_col_alias, input_df._column_map).alias(c)
+ snowpark_functions_col(id_col_alias, input_container.column_map).alias(
+ c
+ )
  )
  snowpark_columns.append(c)
- qualifiers.append(input_df._column_map.get_qualifier_for_spark_column(c))
+ qualifiers.append(
+ input_container.column_map.get_qualifier_for_spark_column(c)
+ )

  # Without the case when postprocessing, the result Spark dataframe is:
  # +---+------------+------+
@@ -941,7 +1020,7 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  if post_process_variable_column is None:
  post_process_variable_column = snowpark_fn.when(
  snowpark_functions_col(
- snowpark_variable_column_name, input_df._column_map
+ snowpark_variable_column_name, input_container.column_map
  )
  == unquote_if_quoted(snowpark_name),
  spark_name,
@@ -949,7 +1028,7 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  else:
  post_process_variable_column = post_process_variable_column.when(
  snowpark_functions_col(
- snowpark_variable_column_name, input_df._column_map
+ snowpark_variable_column_name, input_container.column_map
  )
  == unquote_if_quoted(snowpark_name),
  spark_name,
@@ -960,7 +1039,7 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  )
  snowpark_columns.append(snowpark_variable_column_name)
  column_reverse_project.append(
- snowpark_functions_col(snowpark_value_column_name, input_df._column_map)
+ snowpark_functions_col(snowpark_value_column_name, input_container.column_map)
  )
  snowpark_columns.append(snowpark_value_column_name)
  qualifiers.extend([[]] * 2)
@@ -975,20 +1054,23 @@ def map_unpivot(rel: relation_proto.Relation) -> snowpark.DataFrame:
  )
  .select(*column_reverse_project)
  )
- return with_column_map(
- result,
- spark_columns,
- snowpark_columns,
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=result,
+ spark_column_names=spark_columns,
+ snowpark_column_names=snowpark_columns,
  column_qualifiers=qualifiers,
- parent_column_name_map=input_df._column_map,
+ parent_column_name_map=input_container.column_map,
  )


- def map_group_map(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_group_map(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  """
  Add columns to a DataFrame.
  """
- input_df: snowpark.DataFrame = map_relation(rel.group_map.input)
+ input_container = map_relation(rel.group_map.input)
+ input_df = input_container.dataframe
  grouping_expressions = rel.group_map.grouping_expressions
  snowpark_grouping_expressions: list[snowpark.Column] = []
  typer = ExpressionTyper(input_df)
@@ -996,7 +1078,7 @@ def map_group_map(rel: relation_proto.Relation) -> snowpark.DataFrame:
  qualifiers = []
  for exp in grouping_expressions:
  new_name, snowpark_column = map_single_column_expression(
- exp, input_df._column_map, typer
+ exp, input_container.column_map, typer
  )
  snowpark_grouping_expressions.append(snowpark_column.col)
  group_name_list.append(new_name)
@@ -1013,9 +1095,9 @@ def map_group_map(rel: relation_proto.Relation) -> snowpark.DataFrame:

  if not is_compatible_python or TEST_FLAG_FORCE_CREATE_SPROC:
  original_columns = None
- if input_df._column_map is not None:
+ if input_container.column_map is not None:
  original_columns = [
- column.spark_name for column in input_df._column_map.columns
+ column.spark_name for column in input_container.column_map.columns
  ]

  apply_udtf_temp_name = create_apply_udtf_in_sproc(
@@ -1044,12 +1126,12 @@ def map_group_map(rel: relation_proto.Relation) -> snowpark.DataFrame:
  )

  qualifiers.extend([[]] * (len(result.columns) - len(group_name_list)))
- return with_column_map(
- result,
- [field.name for field in output_type],
- result.columns,
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=result,
+ spark_column_names=[field.name for field in output_type],
+ snowpark_column_names=result.columns,
  column_qualifiers=qualifiers,
- parent_column_name_map=input_df._column_map,
+ parent_column_name_map=input_container.column_map,
  )
