snowpark-connect 0.20.2__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff compares publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.

This version of snowpark-connect has been flagged as a potentially problematic release.

Files changed (67)
  1. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +3 -2
  2. snowflake/snowpark_connect/column_name_handler.py +6 -65
  3. snowflake/snowpark_connect/config.py +28 -14
  4. snowflake/snowpark_connect/dataframe_container.py +242 -0
  5. snowflake/snowpark_connect/execute_plan/map_execution_command.py +13 -23
  6. snowflake/snowpark_connect/execute_plan/map_execution_root.py +9 -5
  7. snowflake/snowpark_connect/expression/map_extension.py +2 -1
  8. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +8 -7
  9. snowflake/snowpark_connect/expression/map_unresolved_function.py +279 -43
  10. snowflake/snowpark_connect/expression/map_unresolved_star.py +8 -8
  11. snowflake/snowpark_connect/expression/map_update_fields.py +1 -1
  12. snowflake/snowpark_connect/expression/typer.py +6 -6
  13. snowflake/snowpark_connect/proto/control_pb2.py +17 -16
  14. snowflake/snowpark_connect/proto/control_pb2.pyi +17 -17
  15. snowflake/snowpark_connect/proto/control_pb2_grpc.py +12 -63
  16. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +15 -14
  17. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +19 -14
  18. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +27 -26
  19. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +74 -68
  20. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +5 -5
  21. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +25 -17
  22. snowflake/snowpark_connect/relation/map_aggregate.py +72 -47
  23. snowflake/snowpark_connect/relation/map_catalog.py +2 -2
  24. snowflake/snowpark_connect/relation/map_column_ops.py +207 -144
  25. snowflake/snowpark_connect/relation/map_crosstab.py +25 -6
  26. snowflake/snowpark_connect/relation/map_extension.py +81 -56
  27. snowflake/snowpark_connect/relation/map_join.py +72 -63
  28. snowflake/snowpark_connect/relation/map_local_relation.py +35 -20
  29. snowflake/snowpark_connect/relation/map_map_partitions.py +21 -16
  30. snowflake/snowpark_connect/relation/map_relation.py +22 -16
  31. snowflake/snowpark_connect/relation/map_row_ops.py +232 -146
  32. snowflake/snowpark_connect/relation/map_sample_by.py +15 -8
  33. snowflake/snowpark_connect/relation/map_show_string.py +42 -5
  34. snowflake/snowpark_connect/relation/map_sql.py +155 -78
  35. snowflake/snowpark_connect/relation/map_stats.py +88 -39
  36. snowflake/snowpark_connect/relation/map_subquery_alias.py +13 -14
  37. snowflake/snowpark_connect/relation/map_udtf.py +6 -9
  38. snowflake/snowpark_connect/relation/read/map_read.py +8 -3
  39. snowflake/snowpark_connect/relation/read/map_read_csv.py +7 -7
  40. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +7 -7
  41. snowflake/snowpark_connect/relation/read/map_read_json.py +7 -7
  42. snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -7
  43. snowflake/snowpark_connect/relation/read/map_read_socket.py +7 -3
  44. snowflake/snowpark_connect/relation/read/map_read_table.py +25 -16
  45. snowflake/snowpark_connect/relation/read/map_read_text.py +7 -7
  46. snowflake/snowpark_connect/relation/utils.py +11 -5
  47. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +15 -12
  48. snowflake/snowpark_connect/relation/write/map_write.py +199 -40
  49. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +3 -2
  50. snowflake/snowpark_connect/server.py +34 -4
  51. snowflake/snowpark_connect/type_mapping.py +2 -23
  52. snowflake/snowpark_connect/utils/cache.py +27 -22
  53. snowflake/snowpark_connect/utils/context.py +33 -17
  54. snowflake/snowpark_connect/utils/{attribute_handling.py → identifiers.py} +47 -0
  55. snowflake/snowpark_connect/utils/session.py +41 -34
  56. snowflake/snowpark_connect/utils/telemetry.py +1 -2
  57. snowflake/snowpark_connect/version.py +1 -1
  58. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/METADATA +5 -3
  59. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/RECORD +67 -64
  60. snowpark_connect-0.21.0.dist-info/licenses/LICENSE-binary +568 -0
  61. snowpark_connect-0.21.0.dist-info/licenses/NOTICE-binary +1533 -0
  62. {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-connect +0 -0
  63. {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-session +0 -0
  64. {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-submit +0 -0
  65. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/WHEEL +0 -0
  66. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
  67. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/top_level.txt +0 -0
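
Most of the hunks below share one refactor: relation mappers that used to return a bare snowpark.DataFrame (with the Spark-to-Snowpark column map attached as the private _column_map attribute) now return a DataFrameContainer that carries the dataframe and its column map together. As a rough orientation, the following is a minimal sketch of that shape, inferred only from the call sites visible in this diff (create_with_column_mapping, .dataframe, .column_map, alias); the real class in snowflake/snowpark_connect/dataframe_container.py is richer, and its internals are not shown here.

    # Illustrative sketch only -- not the real dataframe_container.py. The attribute
    # and method names mirror call sites in this diff; everything else is assumed.
    from dataclasses import dataclass
    from typing import Any, Optional


    @dataclass
    class ColumnMapSketch:
        # Parallel lists of Spark-facing names and quoted Snowpark names, as implied
        # by get_spark_columns()/get_snowpark_columns() in the hunks below.
        spark_names: list
        snowpark_names: list

        def get_spark_columns(self):
            return list(self.spark_names)

        def get_snowpark_column_name_from_spark_column_name(self, name):
            return self.snowpark_names[self.spark_names.index(name)]


    @dataclass
    class DataFrameContainerSketch:
        dataframe: Any                  # the wrapped snowpark.DataFrame
        column_map: ColumnMapSketch
        alias: Optional[str] = None     # subquery alias (see the map_subquery_alias.py hunk)

        @classmethod
        def create_with_column_mapping(cls, *, dataframe, spark_column_names,
                                       snowpark_column_names, **_extra):
            # _extra stands in for the other keyword args seen in the diff
            # (snowpark_column_types, column_metadata, column_qualifiers, ...).
            return cls(dataframe,
                       ColumnMapSketch(list(spark_column_names),
                                       list(snowpark_column_names)))

The keyword-only style of create_with_column_mapping also explains why the read mappers below now pass snowpark_column_types and related arguments by name instead of positionally.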

snowflake/snowpark_connect/relation/map_stats.py

@@ -7,30 +7,34 @@ import ast
  import numpy as np
  import pandas
  import pyspark.sql.connect.proto.relations_pb2 as relation_proto
+ from pyspark.errors.exceptions.base import AnalysisException

  import snowflake.snowpark.functions as fn
  import snowflake.snowpark.types as snowpark_types
  from snowflake import snowpark
  from snowflake.snowpark.exceptions import SnowparkSQLException
- from snowflake.snowpark_connect.column_name_handler import with_column_map
+ from snowflake.snowpark_connect.config import get_boolean_session_config_param
+ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.relation.map_relation import map_relation
  from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session


  def map_corr(
  rel: relation_proto.Relation,
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
  """
  Find the correlation of two columns in the input DataFrame.

  Returns a pandas DataFrame because the correlation of two columns produces
  a scalar value.
  """
- input_df: snowpark.DataFrame = map_relation(rel.corr.input)
- col1 = input_df._column_map.get_snowpark_column_name_from_spark_column_name(
+ input_container = map_relation(rel.corr.input)
+ input_df = input_container.dataframe
+
+ col1 = input_container.column_map.get_snowpark_column_name_from_spark_column_name(
  rel.corr.col1
  )
- col2 = input_df._column_map.get_snowpark_column_name_from_spark_column_name(
+ col2 = input_container.column_map.get_snowpark_column_name_from_spark_column_name(
  rel.corr.col2
  )
  # TODO: Handle method, Snowpark does not support this yet.
@@ -44,18 +48,20 @@ def map_corr(

  def map_cov(
  rel: relation_proto.Relation,
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
  """
  Find the covariance of two columns in the input DataFrame.

  Returns a pandas DataFrame because the corvariance of two columns produces
  a scalar value.
  """
- input_df: snowpark.DataFrame = map_relation(rel.cov.input)
- col1 = input_df._column_map.get_snowpark_column_name_from_spark_column_name(
+ input_container = map_relation(rel.cov.input)
+ input_df = input_container.dataframe
+
+ col1 = input_container.column_map.get_snowpark_column_name_from_spark_column_name(
  rel.cov.col1
  )
- col2 = input_df._column_map.get_snowpark_column_name_from_spark_column_name(
+ col2 = input_container.column_map.get_snowpark_column_name_from_spark_column_name(
  rel.cov.col2
  )
  result: float = input_df.cov(col1, col2)
@@ -64,15 +70,40 @@ def map_cov(

  def map_approx_quantile(
  rel: relation_proto.Relation,
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
  """
  Find one or more approximate quantiles in the input DataFrame.

  Returns a pandas DataFrame because the approximate quantile produces a
  list of scalar values.
  """
- input_df: snowpark.DataFrame = map_relation(rel.approx_quantile.input)
- cols = input_df._column_map.get_snowpark_column_names_from_spark_column_names(
+ input_container = map_relation(rel.approx_quantile.input)
+ input_df = input_container.dataframe
+
+ snowflake_compatible = get_boolean_session_config_param(
+ "enable_snowflake_extension_behavior"
+ )
+
+ if not snowflake_compatible:
+ # When Snowflake extension behavior is disabled, validate that all requested columns exist
+ requested_spark_cols = list(rel.approx_quantile.cols)
+ available_spark_cols = input_container.column_map.get_spark_columns()
+
+ for col_name in requested_spark_cols:
+ if col_name not in available_spark_cols:
+ # Find suggestions for the unresolved column
+ suggestions = [c for c in available_spark_cols if c != col_name]
+ suggestion_text = (
+ f" Did you mean one of the following? [`{'`, `'.join(suggestions)}`]."
+ if suggestions
+ else ""
+ )
+
+ raise AnalysisException(
+ f"[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `{col_name}` cannot be resolved.{suggestion_text}"
+ )
+
+ cols = input_container.column_map.get_snowpark_column_names_from_spark_column_names(
  list(rel.approx_quantile.cols)
  )
  quantile = list(rel.approx_quantile.probabilities)
@@ -84,7 +115,7 @@ def map_approx_quantile(

  def map_describe(
  rel: relation_proto.Relation,
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
  """
  Computes basic statistics for numeric columns, which includes count, mean, stddev, min, and max.
  If no columns are provided, this function computes statistics for all numerical or string columns.
@@ -92,15 +123,19 @@

  Returns a new DataFrame that provides basic statistics for the given DataFrame
  """
+ input_container = map_relation(rel.describe.input)
+ input_df = input_container.dataframe
+
  session = get_or_create_snowpark_session()
- input_df: snowpark.DataFrame = map_relation(rel.describe.input)
  spark_cols = (
  list(rel.describe.cols)
  if rel.describe.cols
- else input_df._column_map.get_spark_columns()
+ else input_container.column_map.get_spark_columns()
  )
  cols = [
- input_df._column_map.get_snowpark_column_name_from_spark_column_name(column)
+ input_container.column_map.get_snowpark_column_name_from_spark_column_name(
+ column
+ )
  for column in spark_cols
  ]

@@ -121,13 +156,13 @@ def map_describe(
  if stat == row.SUMMARY:
  ordered_statistics.append(row)
  ordered_desc_df = session.create_dataframe(ordered_statistics)
- return _build_column_map_helper(ordered_desc_df, input_df)
+ return _build_column_map_helper_container(ordered_desc_df, input_container)


  # TODO: track missing Snowpark feature
  def map_summary(
  rel: relation_proto.Relation,
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
  """
  Computes specified statistics for numeric or string columns. Available statistics are: count, mean, stddev, min,
  max, arbitrary approximate percentiles specified as a percentage (e.g., 75%), count_distinct, and
@@ -137,12 +172,14 @@
  Returns a new DataFrame that provides specified statistics for the given DataFrame.
  """
  session = get_or_create_snowpark_session()
- input_df: snowpark.DataFrame = map_relation(rel.summary.input)
+ result = map_relation(rel.summary.input)
+ input_container: DataFrameContainer = result
+ input_df = input_container.dataframe

  numeric_and_string_spark_cols = [
  column
  for field, column in zip(
- input_df.schema.fields, input_df._column_map.get_spark_columns()
+ input_df.schema.fields, input_container.column_map.get_spark_columns()
  )
  if isinstance(
  field.datatype, (snowpark_types._NumericType, snowpark_types.StringType)
@@ -151,7 +188,9 @@

  # this is intentional to trigger ambigous column name is two columns of same name are provided
  numeric_and_string_snowpark_cols = [
- input_df._column_map.get_snowpark_column_name_from_spark_column_name(column)
+ input_container.column_map.get_snowpark_column_name_from_spark_column_name(
+ column
+ )
  for column in numeric_and_string_spark_cols
  ]

@@ -221,9 +260,11 @@ def map_summary(
  # Modified quantile results, inserting [None, None, None] for string columns
  numeric_index = iter(approx_quantile_values)
  approx_quantile_values_including_string_columns = [
- [str(value) for value in next(numeric_index)]
- if col in eligible_columns
- else [None] * len(quantiles)
+ (
+ [str(value) for value in next(numeric_index)]
+ if col in eligible_columns
+ else [None] * len(quantiles)
+ )
  for col in input_df.columns
  ]

@@ -248,22 +289,24 @@ def map_summary(
  spark_col_names = ["summary"]
  spark_col_names.extend(numeric_and_string_spark_cols)

- return with_column_map(
- ordered_summary_df,
- spark_col_names,
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=ordered_summary_df,
+ spark_column_names=spark_col_names,
  snowpark_column_names=ordered_summary_df.columns,
  )


- def map_freq_items(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_freq_items(rel: relation_proto.Relation) -> DataFrameContainer:
  """
  Returns an approximation of the most frequent values in the input, along with their approximate frequencies.
  """
+ input_container = map_relation(rel.freq_items.input)
+ input_df = input_container.dataframe
+
  session = get_or_create_snowpark_session()
- input_df: snowpark.DataFrame = map_relation(rel.freq_items.input)
  support = rel.freq_items.support
  spark_col_names = []
- cols = input_df._column_map.get_snowpark_column_names_from_spark_column_names(
+ cols = input_container.column_map.get_snowpark_column_names_from_spark_column_names(
  list(rel.freq_items.cols)
  )
  approx_top_k_df = input_df.select(
@@ -289,11 +332,14 @@ def map_freq_items(rel: relation_proto.Relation) -> snowpark.DataFrame:

  for sp_col_name in cols:
  spark_col_names.append(
- f"{input_df._column_map.get_spark_column_name_from_snowpark_column_name(sp_col_name)}_freqItems"
+ f"{input_container.column_map.get_spark_column_name_from_snowpark_column_name(sp_col_name)}_freqItems"
  )
  approx_top_k_df = session.createDataFrame([filtered_values], spark_col_names)
- return with_column_map(
- approx_top_k_df, spark_col_names, snowpark_column_names=spark_col_names
+
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=approx_top_k_df,
+ spark_column_names=spark_col_names,
+ snowpark_column_names=spark_col_names,
  )


@@ -306,19 +352,22 @@ def add_stat_to_df(
  return summary_df.union(session.createDataFrame(df_data, summary_df.schema))


- def _build_column_map_helper(
+ def _build_column_map_helper_container(
  desc_df: snowpark.DataFrame,
- input_df: snowpark.DataFrame,
- ) -> snowpark.DataFrame:
+ input_container: DataFrameContainer,
+ ) -> DataFrameContainer:
+ """Container version of _build_column_map_helper."""
  spark_col_names = ["summary"]
  for i, sp_col_name in enumerate(desc_df.columns):
  if i != 0:
  spark_col_names.append(
- input_df._column_map.get_spark_column_name_from_snowpark_column_name(
+ input_container.column_map.get_spark_column_name_from_snowpark_column_name(
  sp_col_name
  )
  )

- return with_column_map(
- desc_df, spark_col_names, snowpark_column_names=desc_df.columns
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=desc_df,
+ spark_column_names=spark_col_names,
+ snowpark_column_names=desc_df.columns,
  )
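
One user-visible effect of the new validation in map_approx_quantile: with enable_snowflake_extension_behavior disabled, a misspelled column now raises a Spark-style AnalysisException instead of surfacing a Snowflake error; when the extension behavior is enabled, resolution is left to Snowflake as before. A hedged client-side sketch follows (the connection URL, data, and exact error text are illustrative, not taken from this package's docs):

    # Hypothetical repro against a running Snowpark Connect server; the endpoint,
    # data, and the exact error wording are assumptions for illustration.
    from pyspark.sql import SparkSession
    from pyspark.errors.exceptions.base import AnalysisException

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    df = spark.createDataFrame([(1, 2.0), (2, 3.5)], ["id", "price"])

    try:
        df.approxQuantile("prices", [0.5], 0.01)  # typo: "prices" instead of "price"
    except AnalysisException as e:
        # Expected to resemble:
        # [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with
        # name `prices` cannot be resolved. Did you mean one of the following?
        # [`id`, `price`].
        print(e)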

snowflake/snowpark_connect/relation/map_subquery_alias.py

@@ -4,29 +4,28 @@

  import pyspark.sql.connect.proto.relations_pb2 as relation_proto

- from snowflake import snowpark
- from snowflake.snowpark_connect.column_name_handler import with_column_map
+ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.relation.map_relation import map_relation


- def map_alias(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_alias(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  """
  Returns an aliased dataframe in which the columns can now be referenced to using col(<df alias>, <column name>).
  """
  alias: str = rel.subquery_alias.alias
  # we set reuse_parsed_plan=False because we need new expr_id for the attributes (output columns) in aliased snowpark dataframe
  # reuse_parsed_plan will lead to ambiguous column name for operations like joining two dataframes that are aliased from the same dataframe
- input_df: snowpark.DataFrame = map_relation(
- rel.subquery_alias.input, reuse_parsed_plan=False
- )
- input_df._alias = alias
- qualifiers = [[alias]] * len(input_df._column_map.columns)
+ input_container = map_relation(rel.subquery_alias.input, reuse_parsed_plan=False)
+ qualifiers = [[alias]] * len(input_container.column_map.columns)

- return with_column_map(
- input_df,
- input_df._column_map.get_spark_columns(),
- input_df._column_map.get_snowpark_columns(),
- column_metadata=input_df._column_map.column_metadata,
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=input_container.dataframe,
+ spark_column_names=input_container.column_map.get_spark_columns(),
+ snowpark_column_names=input_container.column_map.get_snowpark_columns(),
+ column_metadata=input_container.column_map.column_metadata,
  column_qualifiers=qualifiers,
- parent_column_name_map=input_df._column_map.get_parent_column_name_map(),
+ parent_column_name_map=input_container.column_map.get_parent_column_name_map(),
+ alias=alias,
  )
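
For context on the reuse_parsed_plan=False comment in the map_subquery_alias.py hunk above: joining two aliases of the same dataframe is the case that needs fresh expression IDs per alias. A hedged client-side sketch of that pattern (table and column names are made up, and the session is assumed to be the one from the earlier sketch):

    # Hypothetical Spark Connect client code; table and column names are illustrative.
    from pyspark.sql.functions import col

    orders = spark.table("orders")
    a = orders.alias("a")
    b = orders.alias("b")

    # Without distinct expression IDs per alias, a self-join like this would report
    # ambiguous column names; the [["a"]] / [["b"]] qualifiers built in map_alias
    # let col("a.customer_id") and col("b.customer_id") resolve independently.
    pairs = a.join(b, col("a.customer_id") == col("b.customer_id"))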

snowflake/snowpark_connect/relation/map_udtf.py

@@ -8,7 +8,6 @@ import pyspark.sql.connect.proto.relations_pb2 as relation_proto
  import pyspark.sql.connect.proto.types_pb2 as types_proto
  from pyspark.errors.exceptions.base import PySparkTypeError, PythonException

- from snowflake import snowpark
  from snowflake.snowpark.functions import col, parse_json
  from snowflake.snowpark.types import (
  ArrayType,
@@ -17,14 +16,12 @@ from snowflake.snowpark.types import (
  StructType,
  VariantType,
  )
- from snowflake.snowpark_connect.column_name_handler import (
- ColumnNameMap,
- with_column_map,
- )
+ from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
  from snowflake.snowpark_connect.config import (
  get_boolean_session_config_param,
  global_config,
  )
+ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.expression.map_expression import (
  map_single_column_expression,
  )
@@ -203,7 +200,7 @@ def is_arrow_enabled_in_udtf() -> bool: # REINSTATED

  def map_common_inline_user_defined_table_function(
  rel: relation_proto.CommonInlineUserDefinedTableFunction,
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
  udtf_check(rel)
  session = get_or_create_snowpark_session()
  python_udft = rel.python_udtf
@@ -280,9 +277,9 @@

  snowpark_columns = [f.name for f in output_schema.fields]

- return with_column_map(
- df,
- spark_column_names,
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=df,
+ spark_column_names=spark_column_names,
  snowpark_column_names=snowpark_columns,
  snowpark_column_types=snowpark_column_types,
  )

snowflake/snowpark_connect/relation/read/map_read.py

@@ -7,12 +7,14 @@ import json
  import logging
  import os
  import re
+ from pathlib import Path

  import pyspark.sql.connect.proto.relations_pb2 as relation_proto

  from snowflake import snowpark
  from snowflake.snowpark.types import StructType
  from snowflake.snowpark_connect.config import global_config
+ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.relation.io_utils import (
  convert_file_prefix_path,
  is_cloud_path,
@@ -38,7 +40,7 @@ logger = logging.getLogger("snowflake_connect_server")

  def map_read(
  rel: relation_proto.Relation,
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
  """
  Read a file into a Snowpark DataFrame.

@@ -91,9 +93,9 @@
  telemetry.report_io_read(read_format, options)
  session: snowpark.Session = get_or_create_snowpark_session()
  if len(rel.read.data_source.paths) > 0:
- # Clean up trailing slashes from source paths to ensure consistent behavior
+ # Normalize paths to ensure consistent behavior
  clean_source_paths = [
- path.rstrip("/") for path in rel.read.data_source.paths
+ str(Path(path)) for path in rel.read.data_source.paths
  ]

  result = _read_file(
@@ -284,6 +286,9 @@ def upload_files_if_needed(
  # overwrite=True will not remove all stale files in the target prefix

  remove_command = f"REMOVE {target}/"
+ assert (
+ "//" not in remove_command
+ ), f"Remove command {remove_command} contains double slash"
  session.sql(remove_command).collect()

  try:
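
The switch from path.rstrip("/") to str(Path(path)) normalizes more than trailing slashes; it also collapses redundant separators, which lines up with the new assertion that the REMOVE command contains no double slash. A small sketch of the difference on a POSIX system (the example paths are made up):

    # POSIX behavior shown; on Windows, Path would render backslash separators instead.
    from pathlib import Path

    paths = ["data/input/", "data//input", "data/input"]

    print([p.rstrip("/") for p in paths])   # ['data/input', 'data//input', 'data/input']
    print([str(Path(p)) for p in paths])    # ['data/input', 'data/input', 'data/input']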

snowflake/snowpark_connect/relation/read/map_read_csv.py

@@ -10,7 +10,7 @@ import snowflake.snowpark.functions as snowpark_fn
  from snowflake import snowpark
  from snowflake.snowpark.dataframe_reader import DataFrameReader
  from snowflake.snowpark.types import StringType, StructField, StructType
- from snowflake.snowpark_connect.column_name_handler import with_column_map
+ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.relation.read.map_read import CsvReaderConfig
  from snowflake.snowpark_connect.relation.read.utils import (
  get_spark_column_names_from_snowpark_columns,
@@ -27,7 +27,7 @@ def map_read_csv(
  session: snowpark.Session,
  paths: list[str],
  options: CsvReaderConfig,
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
  """
  Read a CSV file into a Snowpark DataFrame.

@@ -73,11 +73,11 @@
  renamed_df, snowpark_column_names = rename_columns_as_snowflake_standard(
  df, rel.common.plan_id
  )
- return with_column_map(
- renamed_df,
- spark_column_names,
- snowpark_column_names,
- [f.datatype for f in df.schema.fields],
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=renamed_df,
+ spark_column_names=spark_column_names,
+ snowpark_column_names=snowpark_column_names,
+ snowpark_column_types=[f.datatype for f in df.schema.fields],
  )

snowflake/snowpark_connect/relation/read/map_read_jdbc.py

@@ -8,7 +8,7 @@ import pyspark.sql.connect.proto.relations_pb2 as relation_proto

  from snowflake import snowpark
  from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
- from snowflake.snowpark_connect.column_name_handler import with_column_map
+ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.relation.read.jdbc_read_dbapi import JdbcDataFrameReader
  from snowflake.snowpark_connect.relation.read.utils import (
  Connection,
@@ -46,7 +46,7 @@ def map_read_jdbc(
  rel: relation_proto.Relation,
  session: snowpark.Session,
  options: dict[str, str],
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
  """
  Read a table data or query data from a JDBC external datasource into a Snowpark DataFrame.
  """
@@ -98,11 +98,11 @@
  renamed_df, snowpark_cols = rename_columns_as_snowflake_standard(
  df, rel.common.plan_id
  )
- return with_column_map(
- renamed_df,
- true_names,
- snowpark_cols,
- [f.datatype for f in df.schema.fields],
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=renamed_df,
+ spark_column_names=true_names,
+ snowpark_column_names=snowpark_cols,
+ snowpark_column_types=[f.datatype for f in df.schema.fields],
  )
  except Exception as e:
  raise Exception(f"Error accessing JDBC datasource for read: {e}")

snowflake/snowpark_connect/relation/read/map_read_json.py

@@ -22,7 +22,7 @@ from snowflake.snowpark.types import (
  StructField,
  StructType,
  )
- from snowflake.snowpark_connect.column_name_handler import with_column_map
+ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.relation.read.map_read import JsonReaderConfig
  from snowflake.snowpark_connect.relation.read.utils import (
  get_spark_column_names_from_snowpark_columns,
@@ -43,7 +43,7 @@ def map_read_json(
  session: snowpark.Session,
  paths: list[str],
  options: JsonReaderConfig,
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
  """
  Read a JSON file into a Snowpark DataFrame.

@@ -105,11 +105,11 @@
  renamed_df, snowpark_column_names = rename_columns_as_snowflake_standard(
  df, rel.common.plan_id
  )
- return with_column_map(
- renamed_df,
- spark_column_names,
- snowpark_column_names,
- [f.datatype for f in df.schema.fields],
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=renamed_df,
+ spark_column_names=spark_column_names,
+ snowpark_column_names=snowpark_column_names,
+ snowpark_column_types=[f.datatype for f in df.schema.fields],
  )

snowflake/snowpark_connect/relation/read/map_read_parquet.py

@@ -21,7 +21,7 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
  )
  from snowflake.snowpark.column import METADATA_FILENAME
  from snowflake.snowpark.types import DataType, DoubleType, IntegerType, StringType
- from snowflake.snowpark_connect.column_name_handler import with_column_map
+ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.relation.read.reader_config import ReaderWriterConfig
  from snowflake.snowpark_connect.relation.read.utils import (
  rename_columns_as_snowflake_standard,
@@ -37,7 +37,7 @@ def map_read_parquet(
  session: snowpark.Session,
  paths: list[str],
  options: ReaderWriterConfig,
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
  """Read a Parquet file into a Snowpark DataFrame."""

  if rel.read.is_streaming is True:
@@ -62,11 +62,11 @@
  renamed_df, snowpark_column_names = rename_columns_as_snowflake_standard(
  df, rel.common.plan_id
  )
- return with_column_map(
- renamed_df,
- [analyzer_utils.unquote_if_quoted(c) for c in df.columns],
- snowpark_column_names,
- [f.datatype for f in df.schema.fields],
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=renamed_df,
+ spark_column_names=[analyzer_utils.unquote_if_quoted(c) for c in df.columns],
+ snowpark_column_names=snowpark_column_names,
+ snowpark_column_types=[f.datatype for f in df.schema.fields],
  )

snowflake/snowpark_connect/relation/read/map_read_socket.py

@@ -8,7 +8,7 @@ import pandas
  import pyspark.sql.connect.proto.relations_pb2 as relation_proto

  from snowflake import snowpark
- from snowflake.snowpark_connect.column_name_handler import with_column_map
+ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.utils.telemetry import (
  SnowparkConnectNotImplementedError,
  )
@@ -24,7 +24,7 @@ def map_read_socket(
  rel: relation_proto.Relation,
  session: snowpark.Session,
  options: dict[str, str],
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
  if rel.read.is_streaming is True:
  global full_data
  host = options.get("host", None)
@@ -50,7 +50,11 @@
  pandas.DataFrame({snowpark_cname: dataframe_data.split("\n")})
  )
  spark_cname = "value"
- return with_column_map(df, [spark_cname], [snowpark_cname])
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=df,
+ spark_column_names=[spark_cname],
+ snowpark_column_names=[snowpark_cname],
+ )
  except OSError as e:
  raise Exception(f"Error connecting to {host}:{port} - {e}")
  else:

snowflake/snowpark_connect/relation/read/map_read_table.py

@@ -11,12 +11,13 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
  unquote_if_quoted,
  )
  from snowflake.snowpark.exceptions import SnowparkSQLException
- from snowflake.snowpark_connect.column_name_handler import with_column_map
- from snowflake.snowpark_connect.config import auto_uppercase_dml
+ from snowflake.snowpark_connect.column_name_handler import ALREADY_QUOTED
+ from snowflake.snowpark_connect.config import auto_uppercase_non_column_identifiers
+ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.relation.read.utils import (
  rename_columns_as_snowflake_standard,
  )
- from snowflake.snowpark_connect.utils.attribute_handling import (
+ from snowflake.snowpark_connect.utils.identifiers import (
  split_fully_qualified_spark_name,
  )
  from snowflake.snowpark_connect.utils.session import _get_current_snowpark_session
@@ -27,9 +28,16 @@ from snowflake.snowpark_connect.utils.telemetry import (

  def post_process_df(
  df: snowpark.DataFrame, plan_id: int, source_table_name: str = None
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
+ def _lower_or_unquote(string):
+ return (
+ string[1:-1].replace('""', '"')
+ if ALREADY_QUOTED.match(string)
+ else string.lower()
+ )
+
  try:
- true_names = list(map(lambda x: unquote_if_quoted(x).lower(), df.columns))
+ true_names = list(map(lambda x: _lower_or_unquote(x), df.columns))
  renamed_df, snowpark_column_names = rename_columns_as_snowflake_standard(
  df, plan_id
  )
@@ -44,11 +52,11 @@ def post_process_df(
  if current_schema:
  name_parts = [unquote_if_quoted(current_schema)] + name_parts

- return with_column_map(
- renamed_df,
- true_names,
- snowpark_column_names,
- [f.datatype for f in df.schema.fields],
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=renamed_df,
+ spark_column_names=true_names,
+ snowpark_column_names=snowpark_column_names,
+ snowpark_column_types=[f.datatype for f in df.schema.fields],
  column_qualifiers=[name_parts] * len(true_names)
  if source_table_name
  else None,
@@ -66,19 +74,18 @@ def post_process_df(

  def get_table_from_name(
  table_name: str, session: snowpark.Session, plan_id: int
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
+ """Get table from name returning a container."""
  snowpark_name = ".".join(
  quote_name_without_upper_casing(part)
  for part in split_fully_qualified_spark_name(table_name)
  )

- if auto_uppercase_dml():
+ if auto_uppercase_non_column_identifiers():
  snowpark_name = snowpark_name.upper()

  df = session.read.table(snowpark_name)
- post_processed_df = post_process_df(df, plan_id, table_name)
- post_processed_df._table_name = table_name
- return post_processed_df
+ return post_process_df(df, plan_id, table_name)


  def get_table_from_query(
@@ -88,7 +95,9 @@ def get_table_from_query(
  return post_process_df(df, plan_id)


- def map_read_table(rel: relation_proto.Relation) -> snowpark.DataFrame:
+ def map_read_table(
+ rel: relation_proto.Relation,
+ ) -> DataFrameContainer:
  """
  Read a table into a Snowpark DataFrame.
  """