snowpark-connect 0.20.2__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of snowpark-connect might be problematic.

Files changed (67)
  1. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +3 -2
  2. snowflake/snowpark_connect/column_name_handler.py +6 -65
  3. snowflake/snowpark_connect/config.py +28 -14
  4. snowflake/snowpark_connect/dataframe_container.py +242 -0
  5. snowflake/snowpark_connect/execute_plan/map_execution_command.py +13 -23
  6. snowflake/snowpark_connect/execute_plan/map_execution_root.py +9 -5
  7. snowflake/snowpark_connect/expression/map_extension.py +2 -1
  8. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +8 -7
  9. snowflake/snowpark_connect/expression/map_unresolved_function.py +279 -43
  10. snowflake/snowpark_connect/expression/map_unresolved_star.py +8 -8
  11. snowflake/snowpark_connect/expression/map_update_fields.py +1 -1
  12. snowflake/snowpark_connect/expression/typer.py +6 -6
  13. snowflake/snowpark_connect/proto/control_pb2.py +17 -16
  14. snowflake/snowpark_connect/proto/control_pb2.pyi +17 -17
  15. snowflake/snowpark_connect/proto/control_pb2_grpc.py +12 -63
  16. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +15 -14
  17. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +19 -14
  18. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +27 -26
  19. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +74 -68
  20. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +5 -5
  21. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +25 -17
  22. snowflake/snowpark_connect/relation/map_aggregate.py +72 -47
  23. snowflake/snowpark_connect/relation/map_catalog.py +2 -2
  24. snowflake/snowpark_connect/relation/map_column_ops.py +207 -144
  25. snowflake/snowpark_connect/relation/map_crosstab.py +25 -6
  26. snowflake/snowpark_connect/relation/map_extension.py +81 -56
  27. snowflake/snowpark_connect/relation/map_join.py +72 -63
  28. snowflake/snowpark_connect/relation/map_local_relation.py +35 -20
  29. snowflake/snowpark_connect/relation/map_map_partitions.py +21 -16
  30. snowflake/snowpark_connect/relation/map_relation.py +22 -16
  31. snowflake/snowpark_connect/relation/map_row_ops.py +232 -146
  32. snowflake/snowpark_connect/relation/map_sample_by.py +15 -8
  33. snowflake/snowpark_connect/relation/map_show_string.py +42 -5
  34. snowflake/snowpark_connect/relation/map_sql.py +155 -78
  35. snowflake/snowpark_connect/relation/map_stats.py +88 -39
  36. snowflake/snowpark_connect/relation/map_subquery_alias.py +13 -14
  37. snowflake/snowpark_connect/relation/map_udtf.py +6 -9
  38. snowflake/snowpark_connect/relation/read/map_read.py +8 -3
  39. snowflake/snowpark_connect/relation/read/map_read_csv.py +7 -7
  40. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +7 -7
  41. snowflake/snowpark_connect/relation/read/map_read_json.py +7 -7
  42. snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -7
  43. snowflake/snowpark_connect/relation/read/map_read_socket.py +7 -3
  44. snowflake/snowpark_connect/relation/read/map_read_table.py +25 -16
  45. snowflake/snowpark_connect/relation/read/map_read_text.py +7 -7
  46. snowflake/snowpark_connect/relation/utils.py +11 -5
  47. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +15 -12
  48. snowflake/snowpark_connect/relation/write/map_write.py +199 -40
  49. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +3 -2
  50. snowflake/snowpark_connect/server.py +34 -4
  51. snowflake/snowpark_connect/type_mapping.py +2 -23
  52. snowflake/snowpark_connect/utils/cache.py +27 -22
  53. snowflake/snowpark_connect/utils/context.py +33 -17
  54. snowflake/snowpark_connect/utils/{attribute_handling.py → identifiers.py} +47 -0
  55. snowflake/snowpark_connect/utils/session.py +41 -34
  56. snowflake/snowpark_connect/utils/telemetry.py +1 -2
  57. snowflake/snowpark_connect/version.py +1 -1
  58. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/METADATA +5 -3
  59. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/RECORD +67 -64
  60. snowpark_connect-0.21.0.dist-info/licenses/LICENSE-binary +568 -0
  61. snowpark_connect-0.21.0.dist-info/licenses/NOTICE-binary +1533 -0
  62. {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-connect +0 -0
  63. {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-session +0 -0
  64. {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-submit +0 -0
  65. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/WHEEL +0 -0
  66. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
  67. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/top_level.txt +0 -0
@@ -6,26 +6,41 @@ import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 
 import snowflake.snowpark.functions as fn
 from snowflake import snowpark
-from snowflake.snowpark_connect.column_name_handler import with_column_map
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 
 
 def map_crosstab(
     rel: relation_proto.Relation,
-) -> snowpark.DataFrame:
+) -> DataFrameContainer:
     """
     Perform a crosstab on the input DataFrame.
     """
-    input_df: snowpark.DataFrame = map_relation(rel.crosstab.input)
-    col1 = input_df._column_map.get_snowpark_column_name_from_spark_column_name(
+    input_container = map_relation(rel.crosstab.input)
+    input_df = input_container.dataframe
+
+    col1 = input_container.column_map.get_snowpark_column_name_from_spark_column_name(
         rel.crosstab.col1
     )
-    col2 = input_df._column_map.get_snowpark_column_name_from_spark_column_name(
+    col2 = input_container.column_map.get_snowpark_column_name_from_spark_column_name(
         rel.crosstab.col2
     )
     input_df = input_df.select(
         fn.col(col1).cast("string").alias(col1), fn.col(col2).cast("string").alias(col2)
     )
+
+    # Handle empty DataFrame case
+    if input_df.count() == 0:
+        # For empty DataFrame, return a DataFrame with just the first column name
+        result = input_df.select(
+            fn.lit(f"{rel.crosstab.col1}_{rel.crosstab.col2}").alias("c0")
+        )
+        return DataFrameContainer.create_with_column_mapping(
+            dataframe=result,
+            spark_column_names=[f"{rel.crosstab.col1}_{rel.crosstab.col2}"],
+            snowpark_column_names=["c0"],
+        )
+
     result: snowpark.DataFrame = input_df.crosstab(col1, col2)
     new_columns = [f"{rel.crosstab.col1}_{rel.crosstab.col2}"] + [
         (
@@ -45,4 +60,8 @@ def map_crosstab(
     result = result.rename(
         dict(zip(result.columns, [f"c{i}" for i in range(len(result.columns))]))
     )
-    return with_column_map(result, new_columns, result.columns)
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=new_columns,
+        snowpark_column_names=result.columns,
+    )
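The two hunks above illustrate the pattern repeated throughout this release: mapper functions stop attaching a `_column_map` attribute to `snowpark.DataFrame` objects via `with_column_map` and instead return a `DataFrameContainer` that bundles the frame with its Spark-to-Snowpark column mapping. Below is a minimal sketch of the new producer/consumer shape, using only the attribute and argument names visible in these hunks; the relation field `some_op` and the column name "a" are invented for illustration and are not part of the package.

# Sketch only; `rel.some_op` and the column name "a" are hypothetical.
from snowflake import snowpark
from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
from snowflake.snowpark_connect.relation.map_relation import map_relation


def example_mapper(rel) -> DataFrameContainer:
    container = map_relation(rel.some_op.input)  # mappers now return a container
    input_df: snowpark.DataFrame = container.dataframe
    snowpark_name = container.column_map.get_snowpark_column_name_from_spark_column_name("a")
    result = input_df.select(snowpark_name)
    return DataFrameContainer.create_with_column_mapping(
        dataframe=result,
        spark_column_names=["a"],
        snowpark_column_names=result.columns,
    )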
@@ -14,26 +14,29 @@ from snowflake import snowpark
 from snowflake.snowpark_connect.column_name_handler import (
     ColumnNameMap,
     make_column_names_snowpark_compatible,
-    with_column_map,
 )
 from snowflake.snowpark_connect.config import get_boolean_session_config_param
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.expression.map_expression import map_expression
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.typed_column import TypedColumn
-from snowflake.snowpark_connect.utils.attribute_handling import (
-    split_fully_qualified_spark_name,
-)
 from snowflake.snowpark_connect.utils.context import (
     get_sql_aggregate_function_count,
     push_outer_dataframe,
+    set_current_grouping_columns,
+)
+from snowflake.snowpark_connect.utils.identifiers import (
+    split_fully_qualified_spark_name,
 )
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
 
 
-def map_extension(rel: relation_proto.Relation) -> snowpark.DataFrame:
+def map_extension(
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
     """
     The Extension relation type contains any extensions we use for adding new
     functionality to Spark Connect.
@@ -46,7 +49,8 @@ def map_extension(rel: relation_proto.Relation) -> snowpark.DataFrame:
     match extension.WhichOneof("op"):
         case "rdd_map":
            rdd_map = extension.rdd_map
-            input_df: snowpark.DataFrame = map_relation(rdd_map.input)
+            result = map_relation(rdd_map.input)
+            input_df = result.dataframe
 
             column_name = "_RDD_"
             if len(input_df.columns) > 1:
@@ -67,32 +71,39 @@ def map_extension(rel: relation_proto.Relation) -> snowpark.DataFrame:
                 replace=True,
             )
             result = input_df.select(func(column_name).as_(column_name))
-            return with_column_map(result, [column_name], [column_name], [return_type])
+            return DataFrameContainer.create_with_column_mapping(
+                dataframe=result,
+                spark_column_names=[column_name],
+                snowpark_column_names=[column_name],
+                snowpark_column_types=[return_type],
+            )
         case "subquery_column_aliases":
             subquery_aliases = extension.subquery_column_aliases
             rel.extension.Unpack(subquery_aliases)
-            input_df: snowpark.DataFrame = map_relation(subquery_aliases.input)
-            snowpark_col_names = input_df._column_map.get_snowpark_columns()
+            result = map_relation(subquery_aliases.input)
+            input_df = result.dataframe
+            snowpark_col_names = result.column_map.get_snowpark_columns()
             if len(subquery_aliases.aliases) != len(snowpark_col_names):
                 raise AnalysisException(
                     "Number of column aliases does not match number of columns. "
                     f"Number of column aliases: {len(subquery_aliases.aliases)}; "
                     f"number of columns: {len(snowpark_col_names)}."
                 )
-            return with_column_map(
-                input_df,
-                subquery_aliases.aliases,
-                snowpark_col_names,
-                column_qualifiers=input_df._column_map.get_qualifiers(),
+            return DataFrameContainer.create_with_column_mapping(
+                dataframe=input_df,
+                spark_column_names=subquery_aliases.aliases,
+                snowpark_column_names=snowpark_col_names,
+                column_qualifiers=result.column_map.get_qualifiers(),
             )
         case "lateral_join":
             lateral_join = extension.lateral_join
-            left_df: snowpark.DataFrame = map_relation(lateral_join.left)
+            left_result = map_relation(lateral_join.left)
+            left_df = left_result.dataframe
 
             udtf_info = get_udtf_project(lateral_join.right)
             if udtf_info:
                 return handle_lateral_join_with_udtf(
-                    left_df, lateral_join.right, udtf_info
+                    left_result, lateral_join.right, udtf_info
                 )
 
             left_queries = left_df.queries["queries"]
@@ -101,8 +112,9 @@ def map_extension(rel: relation_proto.Relation) -> snowpark.DataFrame:
                     f"Unexpected number of queries: {len(left_queries)}"
                 )
             left_query = left_queries[0]
-            with push_outer_dataframe(left_df):
-                right_df: snowpark.DataFrame = map_relation(lateral_join.right)
+            with push_outer_dataframe(left_result):
+                right_result = map_relation(lateral_join.right)
+                right_df = right_result.dataframe
             right_queries = right_df.queries["queries"]
             if len(right_queries) != 1:
                 raise SnowparkConnectNotImplementedError(
@@ -112,14 +124,14 @@ def map_extension(rel: relation_proto.Relation) -> snowpark.DataFrame:
             input_df_sql = f"WITH __left AS ({left_query}) SELECT * FROM __left INNER JOIN LATERAL ({right_query})"
             session = snowpark.Session.get_active_session()
             input_df = session.sql(input_df_sql)
-            return with_column_map(
-                input_df,
-                left_df._column_map.get_spark_columns()
-                + right_df._column_map.get_spark_columns(),
-                left_df._column_map.get_snowpark_columns()
-                + right_df._column_map.get_snowpark_columns(),
-                column_qualifiers=left_df._column_map.get_qualifiers()
-                + right_df._column_map.get_qualifiers(),
+            return DataFrameContainer.create_with_column_mapping(
+                dataframe=input_df,
+                spark_column_names=left_result.column_map.get_spark_columns()
+                + right_result.column_map.get_spark_columns(),
+                snowpark_column_names=left_result.column_map.get_snowpark_columns()
+                + right_result.column_map.get_snowpark_columns(),
+                column_qualifiers=left_result.column_map.get_qualifiers()
+                + right_result.column_map.get_qualifiers(),
             )
 
         case "udtf_with_table_arguments":
@@ -165,13 +177,13 @@ def handle_udtf_with_table_arguments(
         raise ValueError(f"UDTF '{udtf_info.function_name}' not found.")
     _udtf_obj, udtf_spark_output_names = session._udtfs[udtf_name_lower]
 
-    table_dfs = []
+    table_containers = []
     for table_arg_info in udtf_info.table_arguments:
-        table_df = map_relation(table_arg_info.table_argument)
-        table_dfs.append((table_df, table_arg_info.table_argument_idx))
+        result = map_relation(table_arg_info.table_argument)
+        table_containers.append((result, table_arg_info.table_argument_idx))
 
-    if len(table_dfs) == 1:
-        base_df = table_dfs[0][0]
+    if len(table_containers) == 1:
+        base_df = table_containers[0][0].dataframe
     else:
         if not get_boolean_session_config_param(
             "spark.sql.tvf.allowMultipleTableArguments.enabled"
@@ -181,11 +193,11 @@ def handle_udtf_with_table_arguments(
                 "Please set `spark.sql.tvf.allowMultipleTableArguments.enabled` to `true`"
             )
 
-        base_df = table_dfs[0][0]
+        base_df = table_containers[0][0].dataframe
         first_table_col_count = len(base_df.columns)
 
-        for table_df, _ in table_dfs[1:]:
-            base_df = base_df.cross_join(table_df)
+        for table_container, _ in table_containers[1:]:
+            base_df = base_df.cross_join(table_container.dataframe)
 
         # Ensure deterministic ordering to match Spark's Cartesian product behavior
         # For two tables A and B, Spark produces: for each B row, iterate through A rows
@@ -206,9 +218,9 @@ def handle_udtf_with_table_arguments(
         scalar_args.append(typed_column.col)
 
     table_arg_variants = []
-    for table_df, table_arg_idx in table_dfs:
-        table_columns = table_df._column_map.get_snowpark_columns()
-        spark_columns = table_df._column_map.get_spark_columns()
+    for table_container, table_arg_idx in table_containers:
+        table_columns = table_container.column_map.get_snowpark_columns()
+        spark_columns = table_container.column_map.get_spark_columns()
 
         # Create a structure that supports both positional and named access
         # Format: {"__fields__": ["col1", "col2"], "__values__": [val1, val2]}
@@ -247,15 +259,15 @@ def handle_udtf_with_table_arguments(
 
     final_df = result_df.select(*udtf_output_columns)
 
-    return with_column_map(
-        final_df,
-        udtf_spark_output_names,
-        udtf_output_columns,
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=final_df,
+        spark_column_names=udtf_spark_output_names,
+        snowpark_column_names=udtf_output_columns,
     )
 
 
 def handle_lateral_join_with_udtf(
-    left_df: snowpark.DataFrame,
+    left_result: DataFrameContainer,
     udtf_relation: relation_proto.Relation,
     udtf_info: tuple[snowpark.udtf.UserDefinedTableFunction, list],
 ) -> snowpark.DataFrame:
@@ -269,7 +281,8 @@ def handle_lateral_join_with_udtf(
     _udtf_obj, udtf_spark_output_names = udtf_info
 
     typer = ExpressionTyper.dummy_typer(session)
-    left_column_map = left_df._column_map
+    left_column_map = left_result.column_map
+    left_df = left_result.dataframe
     table_func = snowpark_fn.table_function(_udtf_obj.name)
     udtf_args = [
         map_expression(arg_proto, left_column_map, typer)[1].col
@@ -278,11 +291,12 @@ def handle_lateral_join_with_udtf(
     udtf_args_variant = [snowpark_fn.to_variant(arg) for arg in udtf_args]
     result_df = left_df.join_table_function(table_func(*udtf_args_variant))
 
-    return with_column_map(
-        result_df,
-        left_df._column_map.get_spark_columns() + udtf_spark_output_names,
-        result_df.columns,
-        column_qualifiers=left_df._column_map.get_qualifiers()
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result_df,
+        spark_column_names=left_result.column_map.get_spark_columns()
+        + udtf_spark_output_names,
+        snowpark_column_names=result_df.columns,
+        column_qualifiers=left_result.column_map.get_qualifiers()
         + [[]] * len(udtf_spark_output_names),
     )
 
@@ -290,7 +304,8 @@ def handle_lateral_join_with_udtf(
 def map_aggregate(
     aggregate: snowflake_proto.Aggregate, plan_id: int
 ) -> snowpark.DataFrame:
-    input_df: snowpark.DataFrame = map_relation(aggregate.input)
+    input_container = map_relation(aggregate.input)
+    input_df: snowpark.DataFrame = input_container.dataframe
 
     # Detect the "GROUP BY ALL" case:
     # - it's a plain GROUP BY (not ROLLUP, CUBE, etc.)
@@ -307,7 +322,7 @@ def map_aggregate(
     if (
         len(parsed_col_name) == 1
        and parsed_col_name[0].lower() == "all"
-        and input_df._column_map.get_snowpark_column_name_from_spark_column_name(
+        and input_container.column_map.get_snowpark_column_name_from_spark_column_name(
            parsed_col_name[0], allow_non_exists=True
        )
        is None
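For context, the branch above fires for Spark's GROUP BY ALL shorthand, and only when the identifier `all` does not resolve to a real column (that is what the `allow_non_exists=True` lookup checks). A hypothetical PySpark example of the kind of query this covers, assuming an active SparkSession named `spark` (table and column names are invented):

# Hypothetical query; `employees`, `dept`, and `salary` are illustrative names only.
df = spark.sql("SELECT dept, SUM(salary) AS total FROM employees GROUP BY ALL")
# With no column literally named `all`, this behaves like GROUP BY dept.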
@@ -320,7 +335,9 @@ def map_aggregate(
     typer = ExpressionTyper(input_df)
 
     def _map_column(exp: expression_proto.Expression) -> tuple[str, TypedColumn]:
-        new_names, snowpark_column = map_expression(exp, input_df._column_map, typer)
+        new_names, snowpark_column = map_expression(
+            exp, input_container.column_map, typer
+        )
         if len(new_names) != 1:
             raise SnowparkConnectNotImplementedError(
                 "Multi-column aggregate expressions are not supported"
@@ -345,6 +362,10 @@ def map_aggregate(
     if not is_group_by_all:
         raw_groupings = [_map_column(exp) for exp in aggregate.grouping_expressions]
 
+    # Set the current grouping columns in context for grouping_id() function
+    grouping_spark_columns = [spark_name for spark_name, _ in raw_groupings]
+    set_current_grouping_columns(grouping_spark_columns)
+
     # Now create column name lists and assign aliases.
     # In case of GROUP BY ALL, even though groupings are a subset of aggregations,
     # they will have their own aliases so we can drop them later.
@@ -378,7 +399,7 @@ def map_aggregate(
         # TODO: What do we do about groupings?
         sets = (
             [
-                map_expression(exp, input_df._column_map, typer)[1].col
+                map_expression(exp, input_container.column_map, typer)[1].col
                 for exp in grouping_sets.grouping_set
             ]
             for grouping_sets in aggregate.grouping_sets
@@ -397,16 +418,20 @@ def map_aggregate(
     result = result.select(result.columns[-len(spark_columns) :])
 
     # Build a parent column map that includes groupings.
-    result = with_column_map(
-        result, spark_columns, snowpark_columns, snowpark_column_types
+    result_container = DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=spark_columns,
+        snowpark_column_names=snowpark_columns,
+        snowpark_column_types=snowpark_column_types,
     )
 
     # Drop the groupings.
     grouping_count = len(groupings)
-    return with_column_map(
+
+    return DataFrameContainer.create_with_column_mapping(
         result.drop(snowpark_columns[:grouping_count]),
         spark_columns[grouping_count:],
         snowpark_columns[grouping_count:],
         snowpark_column_types[grouping_count:],
-        parent_column_name_map=result._column_map,
+        parent_column_name_map=result_container.column_map,
     )
@@ -8,14 +8,10 @@ import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 
 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
-from snowflake.snowpark_connect.column_name_handler import (
-    ColumnNameMap,
-    JoinColumnNameMap,
-    set_schema_getter,
-    with_column_map,
-)
+from snowflake.snowpark_connect.column_name_handler import JoinColumnNameMap
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.constants import COLUMN_METADATA_COLLISION_KEY
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.error.error_utils import SparkException
 from snowflake.snowpark_connect.expression.map_expression import (
     map_single_column_expression,
@@ -38,15 +34,18 @@ from snowflake.snowpark_connect.utils.telemetry import (
 USING_COLUMN_NOT_FOUND_ERROR = "[UNRESOLVED_USING_COLUMN_FOR_JOIN] USING column `{0}` not found on the {1} side of the join. The {1}-side columns: {2}"
 
 
-def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
-    left_input: snowpark.DataFrame = map_relation(rel.join.left)
-    right_input: snowpark.DataFrame = map_relation(rel.join.right)
+def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
+    left_container: DataFrameContainer = map_relation(rel.join.left)
+    right_container: DataFrameContainer = map_relation(rel.join.right)
+
+    left_input: snowpark.DataFrame = left_container.dataframe
+    right_input: snowpark.DataFrame = right_container.dataframe
     is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
     using_columns = rel.join.using_columns
     if is_natural_join:
         rel.join.join_type -= NATURAL_JOIN_TYPE_BASE
-        left_spark_columns = left_input._column_map.get_spark_columns()
-        right_spark_columns = right_input._column_map.get_spark_columns()
+        left_spark_columns = left_container.column_map.get_spark_columns()
+        right_spark_columns = right_container.column_map.get_spark_columns()
         common_spark_columns = [
             x for x in left_spark_columns if x in right_spark_columns
         ]
@@ -79,8 +78,8 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
     if rel.join.HasField("join_condition"):
         assert not using_columns
 
-        left_columns = list(left_input._column_map.spark_to_col.keys())
-        right_columns = list(right_input._column_map.spark_to_col.keys())
+        left_columns = list(left_container.column_map.spark_to_col.keys())
+        right_columns = list(right_container.column_map.spark_to_col.keys())
 
         # All PySpark join types are in the format of JOIN_TYPE_XXX.
         # We remove the first 10 characters (JOIN_TYPE_) and replace all underscores with spaces to match the exception.
@@ -90,15 +89,15 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
         with push_sql_scope(), push_evaluating_join_condition(
             pyspark_join_type, left_columns, right_columns
         ):
-            if left_input._alias is not None:
-                set_sql_plan_name(left_input._alias, rel.join.left.common.plan_id)
-            if right_input._alias is not None:
-                set_sql_plan_name(right_input._alias, rel.join.right.common.plan_id)
+            if left_container.alias is not None:
+                set_sql_plan_name(left_container.alias, rel.join.left.common.plan_id)
+            if right_container.alias is not None:
+                set_sql_plan_name(right_container.alias, rel.join.right.common.plan_id)
             _, join_expression = map_single_column_expression(
                 rel.join.join_condition,
                 column_mapping=JoinColumnNameMap(
-                    left_input,
-                    right_input,
+                    left_container.column_map,
+                    right_container.column_map,
                 ),
                 typer=JoinExpressionTyper(left_input, right_input),
             )
@@ -111,7 +110,7 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
         )
     elif using_columns:
         if any(
-            left_input._column_map.get_snowpark_column_name_from_spark_column_name(
+            left_container.column_map.get_snowpark_column_name_from_spark_column_name(
                 c, allow_non_exists=True, return_first=True
             )
             is None
@@ -124,17 +123,17 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
                     next(
                         c
                         for c in using_columns
-                        if left_input._column_map.get_snowpark_column_name_from_spark_column_name(
+                        if left_container.column_map.get_snowpark_column_name_from_spark_column_name(
                             c, allow_non_exists=True, return_first=True
                         )
                         is None
                     ),
                     "left",
-                    left_input._column_map.get_spark_columns(),
+                    left_container.column_map.get_spark_columns(),
                 )
             )
         if any(
-            right_input._column_map.get_snowpark_column_name_from_spark_column_name(
+            right_container.column_map.get_snowpark_column_name_from_spark_column_name(
                 c, allow_non_exists=True, return_first=True
             )
             is None
@@ -147,26 +146,26 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
                     next(
                         c
                        for c in using_columns
-                        if right_input._column_map.get_snowpark_column_name_from_spark_column_name(
+                        if right_container.column_map.get_snowpark_column_name_from_spark_column_name(
                            c, allow_non_exists=True, return_first=True
                        )
                        is None
                    ),
                    "right",
-                    right_input._column_map.get_spark_columns(),
+                    right_container.column_map.get_spark_columns(),
                )
            )
 
         # Round trip the using columns through the column map to get the correct names
         # in order to support case sensitivity.
         # TODO: case_corrected_left_columns / case_corrected_right_columns may no longer be required as Snowpark dataframe preserves the column casing now.
-        case_corrected_left_columns = left_input._column_map.get_spark_column_names_from_snowpark_column_names(
-            left_input._column_map.get_snowpark_column_names_from_spark_column_names(
+        case_corrected_left_columns = left_container.column_map.get_spark_column_names_from_snowpark_column_names(
+            left_container.column_map.get_snowpark_column_names_from_spark_column_names(
                 list(using_columns), return_first=True
             )
         )
-        case_corrected_right_columns = right_input._column_map.get_spark_column_names_from_snowpark_column_names(
-            right_input._column_map.get_snowpark_column_names_from_spark_column_names(
+        case_corrected_right_columns = right_container.column_map.get_spark_column_names_from_snowpark_column_names(
+            right_container.column_map.get_snowpark_column_names_from_spark_column_names(
                 list(using_columns), return_first=True
             )
         )
@@ -177,12 +176,12 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
         snowpark_using_columns = [
             (
                 left_input[
-                    left_input._column_map.get_snowpark_column_name_from_spark_column_name(
+                    left_container.column_map.get_snowpark_column_name_from_spark_column_name(
                         lft, return_first=True
                     )
                 ],
                 right_input[
-                    right_input._column_map.get_snowpark_column_name_from_spark_column_name(
+                    right_container.column_map.get_snowpark_column_name_from_spark_column_name(
                         r, return_first=True
                     )
                 ],
@@ -231,45 +230,49 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
         # - LEFT SEMI JOIN: Returns left rows that have matches in right table (no right columns)
         # - LEFT ANTI JOIN: Returns left rows that have NO matches in right table (no right columns)
         # Both preserve only the columns from the left DataFrame without adding any columns from the right.
-        spark_cols_after_join: list[str] = left_input._column_map.get_spark_columns()
-        qualifiers = left_input._column_map.get_qualifiers()
+        spark_cols_after_join: list[str] = left_container.column_map.get_spark_columns()
+        qualifiers = left_container.column_map.get_qualifiers()
     else:
         # Add Spark columns and plan_ids from left DF
         spark_cols_after_join: list[str] = list(
-            left_input._column_map.get_spark_columns()
+            left_container.column_map.get_spark_columns()
         ) + [
             spark_col
-            for i, spark_col in enumerate(right_input._column_map.get_spark_columns())
+            for i, spark_col in enumerate(
+                right_container.column_map.get_spark_columns()
+            )
             if spark_col not in case_corrected_right_columns
             or spark_col
-            in right_input._column_map.get_spark_columns()[
+            in right_container.column_map.get_spark_columns()[
                 :i
             ]  # this is to make sure we only remove the column once
         ]
 
-        qualifiers = list(left_input._column_map.get_qualifiers()) + [
-            right_input._column_map.get_qualifier_for_spark_column(spark_col)
-            for i, spark_col in enumerate(right_input._column_map.get_spark_columns())
+        qualifiers = list(left_container.column_map.get_qualifiers()) + [
+            right_container.column_map.get_qualifier_for_spark_column(spark_col)
+            for i, spark_col in enumerate(
+                right_container.column_map.get_spark_columns()
+            )
            if spark_col not in case_corrected_right_columns
            or spark_col
-            in right_input._column_map.get_spark_columns()[
+            in right_container.column_map.get_spark_columns()[
                :i
            ]  # this is to make sure we only remove the column once]
        ]
 
     column_metadata = {}
-    if left_input._column_map.column_metadata:
-        column_metadata.update(left_input._column_map.column_metadata)
+    if left_container.column_map.column_metadata:
+        column_metadata.update(left_container.column_map.column_metadata)
 
-    if right_input._column_map.column_metadata:
-        for key, value in right_input._column_map.column_metadata.items():
+    if right_container.column_map.column_metadata:
+        for key, value in right_container.column_map.column_metadata.items():
             if key not in column_metadata:
                 column_metadata[key] = value
             else:
                 # In case of collision, use snowpark's column's expr_id as prefix.
                 # this is a temporary solution until SNOW-1926440 is resolved.
                 try:
-                    snowpark_name = right_input._column_map.get_snowpark_column_name_from_spark_column_name(
+                    snowpark_name = right_container.column_map.get_snowpark_column_name_from_spark_column_name(
                         key
                     )
                     expr_id = right_input[snowpark_name]._expression.expr_id
@@ -281,10 +284,10 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
                     # ignore any errors that happens while fetching the metadata
                     pass
 
-    result_df = with_column_map(
-        result,
-        spark_cols_after_join,
-        result.columns,
+    result_container = DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=spark_cols_after_join,
+        snowpark_column_names=result.columns,
         column_metadata=column_metadata,
         column_qualifiers=qualifiers,
     )
@@ -298,7 +301,7 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
         and rel.join.right.common.HasField("plan_id")
     ):
         right_plan_id = rel.join.right.common.plan_id
-        set_plan_id_map(right_plan_id, result_df)
+        set_plan_id_map(right_plan_id, result_container)
 
     # For FULL OUTER joins, we also need to map the left dataframe's plan_id
     # since both columns are replaced with a coalesced column
@@ -309,7 +312,7 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
         and rel.join.left.common.HasField("plan_id")
     ):
         left_plan_id = rel.join.left.common.plan_id
-        set_plan_id_map(left_plan_id, result_df)
+        set_plan_id_map(left_plan_id, result_container)
 
     if rel.join.using_columns:
         # When join 'using_columns', the 'join columns' should go first in result DF.
@@ -323,19 +326,25 @@ def map_join(rel: relation_proto.Relation) -> snowpark.DataFrame:
             remaining = [el for i, el in enumerate(lst) if i not in idxs_to_shift]
             return to_move + remaining
 
-        reordered_df = result_df.select(
-            [snowpark_fn.col(c) for c in reorder(result_df.columns)]
+        # Create reordered DataFrame
+        reordered_df = result_container.dataframe.select(
+            [snowpark_fn.col(c) for c in reorder(result_container.dataframe.columns)]
         )
-        reordered_df._column_map = ColumnNameMap(
-            spark_column_names=reorder(result_df._column_map.get_spark_columns()),
-            snowpark_column_names=reorder(result_df._column_map.get_snowpark_columns()),
+
+        # Create new container with reordered metadata
+        original_df = result_container.dataframe
+        return DataFrameContainer.create_with_column_mapping(
+            dataframe=reordered_df,
+            spark_column_names=reorder(result_container.column_map.get_spark_columns()),
+            snowpark_column_names=reorder(
+                result_container.column_map.get_snowpark_columns()
+            ),
             column_metadata=column_metadata,
             column_qualifiers=reorder(qualifiers),
+            table_name=result_container.table_name,
+            cached_schema_getter=lambda: snowpark.types.StructType(
+                reorder(original_df.schema.fields)
+            ),
         )
-        reordered_df._table_name = result_df._table_name
-        set_schema_getter(
-            reordered_df,
-            lambda: snowpark.types.StructType(reorder(result_df.schema.fields)),
-        )
-        return reordered_df
-    return result_df
+
+    return result_container
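A note on the `cached_schema_getter` argument in the last hunk: instead of letting the reordered select trigger a fresh schema lookup, the new container is handed a callable that rebuilds the schema lazily from the original frame's already-known fields. The internals of the new dataframe_container.py are not shown in this diff, so the following is only an illustration of that lazy-caching idea, not the package's actual implementation; all names below are invented.

# Illustration of a lazily cached schema; names here are not from the package.
from typing import Callable, Optional

from snowflake.snowpark.types import StructType


class LazySchemaHolder:
    def __init__(self, schema_getter: Callable[[], StructType]):
        self._schema_getter = schema_getter
        self._schema: Optional[StructType] = None

    @property
    def schema(self) -> StructType:
        # Invoke the getter on first access only, then reuse the cached result.
        if self._schema is None:
            self._schema = self._schema_getter()
        return self._schema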