snowpark-connect 0.28.1__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release.

This version of snowpark-connect might be problematic.

Files changed (47)
  1. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  2. snowflake/snowpark_connect/client.py +65 -0
  3. snowflake/snowpark_connect/column_name_handler.py +6 -0
  4. snowflake/snowpark_connect/config.py +33 -5
  5. snowflake/snowpark_connect/execute_plan/map_execution_root.py +21 -19
  6. snowflake/snowpark_connect/expression/map_extension.py +277 -1
  7. snowflake/snowpark_connect/expression/map_sql_expression.py +107 -2
  8. snowflake/snowpark_connect/expression/map_unresolved_function.py +425 -269
  9. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  10. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  11. snowflake/snowpark_connect/relation/io_utils.py +21 -1
  12. snowflake/snowpark_connect/relation/map_column_ops.py +9 -4
  13. snowflake/snowpark_connect/relation/map_extension.py +21 -4
  14. snowflake/snowpark_connect/relation/map_join.py +8 -0
  15. snowflake/snowpark_connect/relation/map_map_partitions.py +7 -8
  16. snowflake/snowpark_connect/relation/map_relation.py +1 -3
  17. snowflake/snowpark_connect/relation/map_row_ops.py +116 -15
  18. snowflake/snowpark_connect/relation/map_show_string.py +14 -6
  19. snowflake/snowpark_connect/relation/map_sql.py +39 -5
  20. snowflake/snowpark_connect/relation/map_stats.py +1 -1
  21. snowflake/snowpark_connect/relation/read/map_read.py +22 -3
  22. snowflake/snowpark_connect/relation/read/map_read_csv.py +119 -29
  23. snowflake/snowpark_connect/relation/read/map_read_json.py +57 -36
  24. snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -1
  25. snowflake/snowpark_connect/relation/read/map_read_text.py +6 -1
  26. snowflake/snowpark_connect/relation/read/metadata_utils.py +159 -0
  27. snowflake/snowpark_connect/relation/stage_locator.py +85 -53
  28. snowflake/snowpark_connect/relation/write/map_write.py +67 -4
  29. snowflake/snowpark_connect/server.py +29 -16
  30. snowflake/snowpark_connect/type_mapping.py +75 -3
  31. snowflake/snowpark_connect/utils/context.py +0 -14
  32. snowflake/snowpark_connect/utils/describe_query_cache.py +6 -3
  33. snowflake/snowpark_connect/utils/io_utils.py +36 -0
  34. snowflake/snowpark_connect/utils/session.py +4 -0
  35. snowflake/snowpark_connect/utils/telemetry.py +30 -5
  36. snowflake/snowpark_connect/utils/udf_cache.py +37 -7
  37. snowflake/snowpark_connect/version.py +1 -1
  38. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/METADATA +3 -2
  39. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/RECORD +47 -45
  40. {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-connect +0 -0
  41. {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-session +0 -0
  42. {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-submit +0 -0
  43. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/WHEEL +0 -0
  44. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/LICENSE-binary +0 -0
  45. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/LICENSE.txt +0 -0
  46. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/NOTICE-binary +0 -0
  47. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/read/map_read.py

@@ -46,6 +46,9 @@ def map_read(
 
     Currently, the supported read formats are `csv`, `json` and `parquet`.
     """
+
+    materialize_df = True
+
     match rel.read.WhichOneof("read_type"):
         case "named_table":
             return map_read_table_or_file(rel)
@@ -99,6 +102,10 @@ def map_read(
                 for path in rel.read.data_source.paths
             ]
 
+            # JSON already materializes the table internally
+            if read_format == "json":
+                materialize_df = False
+
             result = _read_file(
                 clean_source_paths, options, read_format, rel, schema, session
             )
@@ -159,7 +166,9 @@ def map_read(
             raise SnowparkConnectNotImplementedError(f"Unsupported read type: {other}")
 
     return df_cache_map_put_if_absent(
-        (get_session_id(), rel.common.plan_id), lambda: result, materialize=True
+        (get_session_id(), rel.common.plan_id),
+        lambda: result,
+        materialize=materialize_df,
     )
@@ -205,6 +214,15 @@ def _get_supported_read_file_format(unparsed_identifier: str) -> str | None:
     return None
 
 
+def _quote_stage_path(stage_path: str) -> str:
+    """
+    Quote stage paths to escape any special characters.
+    """
+    if stage_path.startswith("@"):
+        return f"'{stage_path}'"
+    return stage_path
+
+
 def _read_file(
     clean_source_paths: list[str],
     options: dict,
@@ -218,6 +236,7 @@ def _read_file(
         session,
     )
     upload_files_if_needed(paths, clean_source_paths, session, read_format)
+    paths = [_quote_stage_path(path) for path in paths]
     match read_format:
         case "csv":
             from snowflake.snowpark_connect.relation.read.map_read_csv import (
@@ -285,8 +304,8 @@ def upload_files_if_needed(
 
     def _upload_dir(target: str, source: str) -> None:
         # overwrite=True will not remove all stale files in the target prefix
-
-        remove_command = f"REMOVE {target}/"
+        # Quote the target path to allow special characters.
+        remove_command = f"REMOVE '{target}/'"
         assert (
             "//" not in remove_command
         ), f"Remove command {remove_command} contains double slash"
snowflake/snowpark_connect/relation/read/map_read_csv.py

@@ -3,6 +3,7 @@
 #
 
 import copy
+from typing import Any
 
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 
@@ -12,10 +13,15 @@ from snowflake.snowpark.dataframe_reader import DataFrameReader
 from snowflake.snowpark.types import StringType, StructField, StructType
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.read.map_read import CsvReaderConfig
+from snowflake.snowpark_connect.relation.read.metadata_utils import (
+    add_filename_metadata_to_reader,
+    get_non_metadata_fields,
+)
 from snowflake.snowpark_connect.relation.read.utils import (
     get_spark_column_names_from_snowpark_columns,
     rename_columns_as_snowflake_standard,
 )
+from snowflake.snowpark_connect.utils.io_utils import cached_file_format
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
@@ -42,21 +48,39 @@ def map_read_csv(
         )
     else:
         snowpark_options = options.convert_to_snowpark_args()
+    parse_header = snowpark_options.get("PARSE_HEADER", False)
+    file_format_options = _parse_csv_snowpark_options(snowpark_options)
+    file_format = cached_file_format(session, "csv", file_format_options)
+
+    snowpark_read_options = dict()
+    snowpark_read_options["FORMAT_NAME"] = file_format
+    snowpark_read_options["ENFORCE_EXISTING_FILE_FORMAT"] = True
+    snowpark_read_options["INFER_SCHEMA"] = snowpark_options.get(
+        "INFER_SCHEMA", False
+    )
+    snowpark_read_options["PATTERN"] = snowpark_options.get("PATTERN", None)
+
     raw_options = rel.read.data_source.options
+
     if schema is None or (
-        snowpark_options.get("PARSE_HEADER", False)
-        and raw_options.get("enforceSchema", "True").lower() == "false"
+        parse_header and raw_options.get("enforceSchema", "True").lower() == "false"
     ):  # Schema has to equals to header's format
-        reader = session.read.options(snowpark_options)
+        reader = add_filename_metadata_to_reader(
+            session.read.options(snowpark_options), raw_options
+        )
     else:
-        reader = session.read.options(snowpark_options).schema(schema)
+        reader = add_filename_metadata_to_reader(
+            session.read.options(snowpark_options).schema(schema), raw_options
+        )
     df = read_data(
         reader,
         schema,
         session,
         paths[0],
-        snowpark_options,
+        file_format_options,
+        snowpark_read_options,
         raw_options,
+        parse_header,
     )
     if len(paths) > 1:
         # TODO: figure out if this is what Spark does.
@@ -81,15 +105,65 @@ def map_read_csv(
     )
 
 
+_csv_file_format_allowed_options = {
+    "COMPRESSION",
+    "RECORD_DELIMITER",
+    "FIELD_DELIMITER",
+    "MULTI_LINE",
+    "FILE_EXTENSION",
+    "PARSE_HEADER",
+    "SKIP_HEADER",
+    "SKIP_BLANK_LINES",
+    "DATE_FORMAT",
+    "TIME_FORMAT",
+    "TIMESTAMP_FORMAT",
+    "BINARY_FORMAT",
+    "ESCAPE",
+    "ESCAPE_UNENCLOSED_FIELD",
+    "TRIM_SPACE",
+    "FIELD_OPTIONALLY_ENCLOSED_BY",
+    "NULL_IF",
+    "ERROR_ON_COLUMN_COUNT_MISMATCH",
+    "REPLACE_INVALID_CHARACTERS",
+    "EMPTY_FIELD_AS_NULL",
+    "SKIP_BYTE_ORDER_MARK",
+    "ENCODING",
+}
+
+
+def _parse_csv_snowpark_options(snowpark_options: dict[str, Any]) -> dict[str, Any]:
+    file_format_options = dict()
+    for key, value in snowpark_options.items():
+        upper_key = key.upper()
+        if upper_key in _csv_file_format_allowed_options:
+            file_format_options[upper_key] = value
+
+    # This option has to be removed, because we cannot use at the same time predefined file format and parse_header option
+    # Such combination causes snowpark to raise SQL compilation error: Invalid file format "PARSE_HEADER" is only allowed for CSV INFER_SCHEMA and MATCH_BY_COLUMN_NAME
+    parse_header = file_format_options.get("PARSE_HEADER", False)
+    if parse_header:
+        file_format_options["SKIP_HEADER"] = 1
+        del file_format_options["PARSE_HEADER"]
+
+    return file_format_options
+
+
 def get_header_names(
     session: snowpark.Session,
     path: list[str],
-    snowpark_options: dict,
+    file_format_options: dict,
+    snowpark_read_options: dict,
 ) -> list[str]:
-    snowpark_options_no_header = copy.copy(snowpark_options)
-    snowpark_options_no_header["PARSE_HEADER"] = False
+    no_header_file_format_options = copy.copy(file_format_options)
+    no_header_file_format_options["PARSE_HEADER"] = False
+    no_header_file_format_options.pop("SKIP_HEADER", None)
+
+    file_format = cached_file_format(session, "csv", no_header_file_format_options)
+    no_header_snowpark_read_options = copy.copy(snowpark_read_options)
+    no_header_snowpark_read_options["FORMAT_NAME"] = file_format
+    no_header_snowpark_read_options.pop("INFER_SCHEMA", None)
 
-    header_df = session.read.options(snowpark_options_no_header).csv(path).limit(1)
+    header_df = session.read.options(no_header_snowpark_read_options).csv(path).limit(1)
     header_data = header_df.collect()[0]
     return [
         f'"{header_data[i]}"'
@@ -103,40 +177,56 @@ def read_data(
     schema: snowpark.types.StructType | None,
     session: snowpark.Session,
     path: list[str],
-    snowpark_options: dict,
+    file_format_options: dict,
+    snowpark_read_options: dict,
     raw_options: dict,
+    parse_header: bool,
 ) -> snowpark.DataFrame:
     df = reader.csv(path)
     filename = path.strip("/").split("/")[-1]
+    non_metadata_fields = get_non_metadata_fields(df.schema.fields)
+
     if schema is not None:
-        if len(schema.fields) != len(df.schema.fields):
+        if len(schema.fields) != len(non_metadata_fields):
             raise Exception(f"csv load from {filename} failed.")
         if raw_options.get("enforceSchema", "True").lower() == "false":
             for i in range(len(schema.fields)):
                 if (
-                    schema.fields[i].name != df.schema.fields[i].name
-                    and f'"{schema.fields[i].name}"' != df.schema.fields[i].name
+                    schema.fields[i].name != non_metadata_fields[i].name
+                    and f'"{schema.fields[i].name}"' != non_metadata_fields[i].name
                 ):
                     raise Exception("CSV header does not conform to the schema")
         return df
 
-    headers = get_header_names(session, path, snowpark_options)
-
+    headers = get_header_names(
+        session, path, file_format_options, snowpark_read_options
+    )
+
+    df_schema_fields = non_metadata_fields
+    if len(headers) == len(df_schema_fields) and parse_header:
+        return df.select(
+            [
+                snowpark_fn.col(df_schema_fields[i].name).alias(headers[i])
+                for i in range(len(headers))
+            ]
        )
     # Handle mismatch in column count between header and data
-    if (
-        len(df.schema.fields) == 1
-        and df.schema.fields[0].name.upper() == "C1"
-        and snowpark_options.get("PARSE_HEADER") is True
-        and len(headers) != len(df.schema.fields)
+    elif (
+        len(df_schema_fields) == 1
+        and df_schema_fields[0].name.upper() == "C1"
+        and parse_header
+        and len(headers) != len(df_schema_fields)
     ):
-        df = (
-            session.read.options(snowpark_options)
-            .schema(StructType([StructField(h, StringType(), True) for h in headers]))
-            .csv(path)
+        df = reader.schema(
+            StructType([StructField(h, StringType(), True) for h in headers])
+        ).csv(path)
+    elif not parse_header and len(headers) != len(df_schema_fields):
+        return df.select([df_schema_fields[i].name for i in range(len(headers))])
+    elif parse_header and len(headers) != len(df_schema_fields):
+        return df.select(
+            [
+                snowpark_fn.col(df_schema_fields[i].name).alias(headers[i])
+                for i in range(len(headers))
+            ]
         )
-    elif snowpark_options.get("PARSE_HEADER") is False and len(headers) != len(
-        df.schema.fields
-    ):
-        return df.select([df.schema.fields[i].name for i in range(len(headers))])
-
     return df
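
Note on the CSV changes above: Spark-style reader options are now split into Snowflake file-format options (the allow-listed keys in _csv_file_format_allowed_options, turned into a cached named file format via cached_file_format) and per-read options (FORMAT_NAME, ENFORCE_EXISTING_FILE_FORMAT, INFER_SCHEMA, PATTERN); PARSE_HEADER is rewritten as SKIP_HEADER=1 because a predefined file format cannot carry it, and the header row is then handled separately through get_header_names. A standalone sketch of the splitting step, using only a small subset of the real allow-list for brevity:

    # Standalone sketch of the option split introduced in map_read_csv.py
    # (subset of the real allow-list; for illustration only).
    ALLOWED_FILE_FORMAT_KEYS = {"FIELD_DELIMITER", "ENCODING", "PARSE_HEADER", "SKIP_HEADER"}

    def split_csv_options(snowpark_options: dict) -> dict:
        file_format_options = {
            key.upper(): value
            for key, value in snowpark_options.items()
            if key.upper() in ALLOWED_FILE_FORMAT_KEYS
        }
        # PARSE_HEADER cannot be part of a predefined file format, so it is
        # rewritten as SKIP_HEADER=1 and the header row is handled separately.
        if file_format_options.pop("PARSE_HEADER", False):
            file_format_options["SKIP_HEADER"] = 1
        return file_format_options

    print(split_csv_options({"field_delimiter": ";", "PARSE_HEADER": True}))
    # {'FIELD_DELIMITER': ';', 'SKIP_HEADER': 1}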
snowflake/snowpark_connect/relation/read/map_read_json.py

@@ -2,9 +2,12 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 
+import concurrent.futures
 import copy
 import json
+import os
 import typing
+import uuid
 from contextlib import suppress
 from datetime import datetime
 
@@ -26,6 +29,9 @@ from snowflake.snowpark.types import (
 )
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.read.map_read import JsonReaderConfig
+from snowflake.snowpark_connect.relation.read.metadata_utils import (
+    add_filename_metadata_to_reader,
+)
 from snowflake.snowpark_connect.relation.read.utils import (
     get_spark_column_names_from_snowpark_columns,
     rename_columns_as_snowflake_standard,
@@ -63,19 +69,26 @@ def map_read_json(
         )
     else:
         snowpark_options = options.convert_to_snowpark_args()
+    raw_options = rel.read.data_source.options
     snowpark_options["infer_schema"] = True
 
     rows_to_infer_schema = snowpark_options.pop("rowstoinferschema", 1000)
     dropFieldIfAllNull = snowpark_options.pop("dropfieldifallnull", False)
     batch_size = snowpark_options.pop("batchsize", 1000)
 
-    reader = session.read.options(snowpark_options)
+    reader = add_filename_metadata_to_reader(
+        session.read.options(snowpark_options), raw_options
+    )
 
     df = reader.json(paths[0])
     if len(paths) > 1:
         # TODO: figure out if this is what Spark does.
         for p in paths[1:]:
-            df = df.union_all(session.read.options(snowpark_options).json(p))
+            df = df.union_all(
+                add_filename_metadata_to_reader(
+                    session.read.options(snowpark_options), raw_options
+                ).json(p)
+            )
 
     if schema is None:
         schema = copy.deepcopy(df.schema)
@@ -253,20 +266,20 @@ def merge_row_schema(
     return schema
 
 
-def union_data_into_df(
-    result_df: snowpark.DataFrame,
-    data: typing.List[Row],
-    schema: StructType,
+def insert_data_chunk(
     session: snowpark.Session,
-) -> snowpark.DataFrame:
-    current_df = session.create_dataframe(
+    data: list[Row],
+    schema: StructType,
+    table_name: str,
+) -> None:
+    df = session.create_dataframe(
         data=data,
         schema=schema,
     )
-    if result_df is None:
-        return current_df
 
-    return result_df.union(current_df)
+    df.write.mode("append").save_as_table(
+        table_name, table_type="temp", table_exists=True
+    )
 
 
 def construct_dataframe_by_schema(
@@ -276,39 +289,47 @@ def construct_dataframe_by_schema(
     snowpark_options: dict,
     batch_size: int = 1000,
 ) -> snowpark.DataFrame:
-    result = None
+    table_name = "__sas_json_read_temp_" + uuid.uuid4().hex
+
+    # We can have more workers than CPU count, this is an IO-intensive task
+    max_workers = min(16, os.cpu_count() * 2)
 
     current_data = []
     progress = 0
-    for row in rows:
-        current_data.append(construct_row_by_schema(row, schema, snowpark_options))
-        if len(current_data) >= batch_size:
+
+    # Initialize the temp table
+    session.create_dataframe([], schema=schema).write.mode("append").save_as_table(
+        table_name, table_type="temp", table_exists=False
+    )
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as exc:
+        for row in rows:
+            current_data.append(construct_row_by_schema(row, schema, snowpark_options))
+            if len(current_data) >= batch_size:
+                progress += len(current_data)
+                exc.submit(
+                    insert_data_chunk,
+                    session,
+                    copy.deepcopy(current_data),
+                    schema,
+                    table_name,
+                )
+
+                logger.info(f"JSON reader: finished processing {progress} rows")
+                current_data.clear()
+
+        if len(current_data) > 0:
             progress += len(current_data)
-            result = union_data_into_df(
-                result,
-                current_data,
-                schema,
+            exc.submit(
+                insert_data_chunk,
                 session,
+                copy.deepcopy(current_data),
+                schema,
+                table_name,
            )
-
            logger.info(f"JSON reader: finished processing {progress} rows")
-            current_data = []
-
-    if len(current_data) > 0:
-        progress += len(current_data)
-        result = union_data_into_df(
-            result,
-            current_data,
-            schema,
-            session,
-        )
-
-        logger.info(f"JSON reader: finished processing {progress} rows")
-        current_data = []
 
-    if result is None:
-        raise ValueError("Dataframe cannot be empty")
-    return result
+    return session.table(table_name)
 
 
 def construct_row_by_schema(
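
Note on the JSON changes above: construct_dataframe_by_schema no longer chains DataFrame unions; each batch of constructed rows is appended to a temporary table from a ThreadPoolExecutor (insert_data_chunk), and the final result is read back with session.table(...). A generic, self-contained sketch of that batch-and-submit pattern, with a plain-Python stand-in for the Snowpark insert:

    # Generic sketch of the batching pattern used above; insert_chunk stands in
    # for df.write.mode("append").save_as_table(...).
    import concurrent.futures
    import copy

    def insert_chunk(batch: list) -> None:
        print(f"inserting {len(batch)} rows")

    def load_rows(rows, batch_size: int = 1000, max_workers: int = 4) -> None:
        buffer: list = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
            for row in rows:
                buffer.append(row)
                if len(buffer) >= batch_size:
                    # Copy before clearing so the worker never sees the buffer
                    # being reused for the next batch.
                    pool.submit(insert_chunk, copy.deepcopy(buffer))
                    buffer.clear()
            if buffer:
                pool.submit(insert_chunk, copy.deepcopy(buffer))
        # Leaving the "with" block waits for all submitted inserts to finish.

    load_rows(range(2500), batch_size=1000)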
snowflake/snowpark_connect/relation/read/map_read_parquet.py

@@ -22,6 +22,9 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
 from snowflake.snowpark.column import METADATA_FILENAME
 from snowflake.snowpark.types import DataType, DoubleType, IntegerType, StringType
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+from snowflake.snowpark_connect.relation.read.metadata_utils import (
+    add_filename_metadata_to_reader,
+)
 from snowflake.snowpark_connect.relation.read.reader_config import ReaderWriterConfig
 from snowflake.snowpark_connect.relation.read.utils import (
     rename_columns_as_snowflake_standard,
@@ -46,10 +49,13 @@ def map_read_parquet(
     )
 
     snowpark_options = options.convert_to_snowpark_args()
+    raw_options = rel.read.data_source.options
     assert schema is None, "Read PARQUET does not support user schema"
     assert len(paths) > 0, "Read PARQUET expects at least one path"
 
-    reader = session.read.options(snowpark_options)
+    reader = add_filename_metadata_to_reader(
+        session.read.options(snowpark_options), raw_options
+    )
 
     if len(paths) == 1:
         df = _read_parquet_with_partitions(session, reader, paths[0])
snowflake/snowpark_connect/relation/read/map_read_text.py

@@ -43,7 +43,12 @@ def read_text(
 ) -> snowpark.DataFrame:
     # TODO: handle stage name with double quotes
     files_paths = get_file_paths_from_stage(path, session)
-    stage_name = path.split("/")[0]
+    # Remove matching quotes from both ends of the path to get the stage name, if present.
+    if path and len(path) > 1 and path[0] == path[-1] and path[0] in ('"', "'"):
+        unquoted_path = path[1:-1]
+    else:
+        unquoted_path = path
+    stage_name = unquoted_path.split("/")[0]
     line_sep = options.get("lineSep") or "\n"
     column_name = (
         schema[0].name if schema is not None and len(schema.fields) > 0 else '"value"'
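
Note on the text-reader change above: because read paths may now arrive wrapped in single or double quotes (see the quoting introduced in map_read.py), read_text strips a matching pair of quotes before taking the stage name from the first path segment. A tiny standalone sketch of that rule:

    # Standalone re-implementation of the quote-stripping added to read_text.
    def stage_name_from_path(path: str) -> str:
        if path and len(path) > 1 and path[0] == path[-1] and path[0] in ('"', "'"):
            path = path[1:-1]
        return path.split("/")[0]

    print(stage_name_from_path("'@my_stage/folder/file.txt'"))  # @my_stage
    print(stage_name_from_path("@my_stage/folder/file.txt"))    # @my_stage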
snowflake/snowpark_connect/relation/read/metadata_utils.py (new file)

@@ -0,0 +1,159 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+"""
+Utilities for handling internal metadata columns in file-based DataFrames.
+"""
+
+import os
+
+import pandas
+from pyspark.errors.exceptions.base import AnalysisException
+
+from snowflake import snowpark
+from snowflake.snowpark.column import METADATA_FILENAME
+from snowflake.snowpark.functions import col
+from snowflake.snowpark.types import StructField
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+
+# Constant for the metadata filename column name
+METADATA_FILENAME_COLUMN = "METADATA$FILENAME"
+
+
+def add_filename_metadata_to_reader(
+    reader: snowpark.DataFrameReader,
+    options: dict | None = None,
+) -> snowpark.DataFrameReader:
+    """
+    Add filename metadata to a DataFrameReader based on configuration.
+
+    Args:
+        reader: Snowpark DataFrameReader instance
+        options: Dictionary of options to check for metadata configuration
+
+    Returns:
+        DataFrameReader with filename metadata enabled if configured, otherwise unchanged
+    """
+    # NOTE: SNOWPARK_POPULATE_FILE_METADATA_DEFAULT is an internal environment variable
+    # used only for CI testing to verify no metadata columns leak in regular file operations.
+    # This environment variable should NOT be exposed to end users. Users should only use snowpark.populateFileMetadata
+    # to enable metadata population.
+    metadata_default = os.environ.get(
+        "SNOWPARK_POPULATE_FILE_METADATA_DEFAULT", "false"
+    )
+
+    populate_metadata = (
+        options.get("snowpark.populateFileMetadata", metadata_default)
+        if options
+        else metadata_default
+    ).lower() == "true"
+
+    if populate_metadata:
+        return reader.with_metadata(METADATA_FILENAME)
+    else:
+        return reader
+
+
+def get_non_metadata_fields(schema_fields: list[StructField]) -> list[StructField]:
+    """
+    Filter out METADATA$FILENAME fields from a list of schema fields.
+
+    Args:
+        schema_fields: List of StructField objects from a DataFrame schema
+
+    Returns:
+        List of StructField objects excluding METADATA$FILENAME
+    """
+    return [field for field in schema_fields if field.name != METADATA_FILENAME_COLUMN]
+
+
+def get_non_metadata_column_names(schema_fields: list[StructField]) -> list[str]:
+    """
+    Get column names from schema fields, excluding METADATA$FILENAME.
+
+    Args:
+        schema_fields: List of StructField objects from a DataFrame schema
+
+    Returns:
+        List of column names (strings) excluding METADATA$FILENAME
+    """
+    return [
+        field.name for field in schema_fields if field.name != METADATA_FILENAME_COLUMN
+    ]
+
+
+def filter_metadata_column_name(column_names: list[str]) -> list[str]:
+    """
+    Get column names from column_names, excluding METADATA$FILENAME.
+
+    Returns:
+        List of column names (strings) excluding METADATA$FILENAME
+    """
+    return [
+        col_name for col_name in column_names if col_name != METADATA_FILENAME_COLUMN
+    ]
+
+
+def filter_metadata_columns(
+    result_container: DataFrameContainer | pandas.DataFrame | None,
+) -> DataFrameContainer | pandas.DataFrame | None:
+    """
+    Filter METADATA$FILENAME from DataFrame container for execution and write operations.
+
+    Args:
+        result_container: DataFrameContainer or pandas DataFrame to filter
+
+    Returns:
+        Filtered container (callers can access dataframe via container.dataframe)
+    """
+    # Handle pandas DataFrame case - return as-is
+    if isinstance(result_container, pandas.DataFrame):
+        return result_container
+
+    if result_container is None:
+        return None
+
+    result_df = result_container.dataframe
+    if not isinstance(result_df, snowpark.DataFrame):
+        return result_container
+
+    df_columns = result_container.column_map.get_snowpark_columns()
+    has_metadata_filename = any(name == METADATA_FILENAME_COLUMN for name in df_columns)
+
+    if not has_metadata_filename:
+        return result_container
+
+    non_metadata_columns = filter_metadata_column_name(df_columns)
+
+    if len(non_metadata_columns) == 0:
+        # DataFrame contains only metadata columns (METADATA$FILENAME), no actual data columns remaining.
+        # We don't have a way to return an empty dataframe.
+        raise AnalysisException(
+            "[DATAFRAME_MISSING_DATA_COLUMNS] Cannot perform operation on DataFrame that contains no data columns."
+        )
+
+    filtered_df = result_df.select([col(name) for name in non_metadata_columns])
+
+    original_spark_columns = result_container.column_map.get_spark_columns()
+    original_snowpark_columns = result_container.column_map.get_snowpark_columns()
+
+    filtered_spark_columns = []
+    filtered_snowpark_columns = []
+
+    for i, colname in enumerate(df_columns):
+        if colname != METADATA_FILENAME_COLUMN:
+            filtered_spark_columns.append(original_spark_columns[i])
+            filtered_snowpark_columns.append(original_snowpark_columns[i])
+
+    new_container = DataFrameContainer.create_with_column_mapping(
+        dataframe=filtered_df,
+        spark_column_names=filtered_spark_columns,
+        snowpark_column_names=filtered_snowpark_columns,
+        column_metadata=result_container.column_map.column_metadata,
+        table_name=result_container.table_name,
+        alias=result_container.alias,
+        partition_hint=result_container.partition_hint,
+    )
+
+    return new_container
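
Note on the new metadata_utils module above: METADATA$FILENAME population is opt-in via the snowpark.populateFileMetadata reader option (the SNOWPARK_POPULATE_FILE_METADATA_DEFAULT environment variable is CI-only and not meant for end users). A hedged usage sketch of add_filename_metadata_to_reader; the session setup and stage path are placeholders and assume an already-configured default Snowflake connection:

    # Hedged usage sketch; session setup and stage path are placeholders.
    from snowflake.snowpark import Session
    from snowflake.snowpark_connect.relation.read.metadata_utils import (
        add_filename_metadata_to_reader,
    )

    session = Session.builder.getOrCreate()  # assumes a configured default connection

    # With the option set to "true" the reader is wrapped with
    # with_metadata(METADATA_FILENAME), so each row carries METADATA$FILENAME;
    # otherwise the reader is returned unchanged.
    reader = add_filename_metadata_to_reader(
        session.read, {"snowpark.populateFileMetadata": "true"}
    )
    df = reader.parquet("@my_stage/data/")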