snowpark-checkpoints-collectors 0.2.0rc1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. snowflake/snowpark_checkpoints_collector/__init__.py +30 -0
  2. snowflake/snowpark_checkpoints_collector/__version__.py +16 -0
  3. snowflake/snowpark_checkpoints_collector/collection_common.py +160 -0
  4. snowflake/snowpark_checkpoints_collector/collection_result/model/__init__.py +24 -0
  5. snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result.py +91 -0
  6. snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py +76 -0
  7. snowflake/snowpark_checkpoints_collector/column_collection/__init__.py +22 -0
  8. snowflake/snowpark_checkpoints_collector/column_collection/column_collector_manager.py +276 -0
  9. snowflake/snowpark_checkpoints_collector/column_collection/model/__init__.py +75 -0
  10. snowflake/snowpark_checkpoints_collector/column_collection/model/array_column_collector.py +113 -0
  11. snowflake/snowpark_checkpoints_collector/column_collection/model/binary_column_collector.py +87 -0
  12. snowflake/snowpark_checkpoints_collector/column_collection/model/boolean_column_collector.py +71 -0
  13. snowflake/snowpark_checkpoints_collector/column_collection/model/column_collector_base.py +95 -0
  14. snowflake/snowpark_checkpoints_collector/column_collection/model/date_column_collector.py +74 -0
  15. snowflake/snowpark_checkpoints_collector/column_collection/model/day_time_interval_column_collector.py +67 -0
  16. snowflake/snowpark_checkpoints_collector/column_collection/model/decimal_column_collector.py +92 -0
  17. snowflake/snowpark_checkpoints_collector/column_collection/model/empty_column_collector.py +88 -0
  18. snowflake/snowpark_checkpoints_collector/column_collection/model/map_column_collector.py +120 -0
  19. snowflake/snowpark_checkpoints_collector/column_collection/model/null_column_collector.py +49 -0
  20. snowflake/snowpark_checkpoints_collector/column_collection/model/numeric_column_collector.py +108 -0
  21. snowflake/snowpark_checkpoints_collector/column_collection/model/string_column_collector.py +70 -0
  22. snowflake/snowpark_checkpoints_collector/column_collection/model/struct_column_collector.py +102 -0
  23. snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_column_collector.py +75 -0
  24. snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_ntz_column_collector.py +75 -0
  25. snowflake/snowpark_checkpoints_collector/column_pandera_checks/__init__.py +20 -0
  26. snowflake/snowpark_checkpoints_collector/column_pandera_checks/pandera_column_checks_manager.py +241 -0
  27. snowflake/snowpark_checkpoints_collector/io_utils/__init__.py +26 -0
  28. snowflake/snowpark_checkpoints_collector/io_utils/io_default_strategy.py +61 -0
  29. snowflake/snowpark_checkpoints_collector/io_utils/io_env_strategy.py +142 -0
  30. snowflake/snowpark_checkpoints_collector/io_utils/io_file_manager.py +79 -0
  31. snowflake/snowpark_checkpoints_collector/singleton.py +23 -0
  32. snowflake/snowpark_checkpoints_collector/snow_connection_model/__init__.py +20 -0
  33. snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py +203 -0
  34. snowflake/snowpark_checkpoints_collector/summary_stats_collector.py +409 -0
  35. snowflake/snowpark_checkpoints_collector/utils/checkpoint_name_utils.py +53 -0
  36. snowflake/snowpark_checkpoints_collector/utils/extra_config.py +164 -0
  37. snowflake/snowpark_checkpoints_collector/utils/file_utils.py +137 -0
  38. snowflake/snowpark_checkpoints_collector/utils/logging_utils.py +67 -0
  39. snowflake/snowpark_checkpoints_collector/utils/telemetry.py +928 -0
  40. snowpark_checkpoints_collectors-0.3.0.dist-info/METADATA +159 -0
  41. snowpark_checkpoints_collectors-0.3.0.dist-info/RECORD +43 -0
  42. {snowpark_checkpoints_collectors-0.2.0rc1.dist-info → snowpark_checkpoints_collectors-0.3.0.dist-info}/licenses/LICENSE +0 -25
  43. snowpark_checkpoints_collectors-0.2.0rc1.dist-info/METADATA +0 -347
  44. snowpark_checkpoints_collectors-0.2.0rc1.dist-info/RECORD +0 -4
  45. {snowpark_checkpoints_collectors-0.2.0rc1.dist-info → snowpark_checkpoints_collectors-0.3.0.dist-info}/WHEEL +0 -0
snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py
@@ -0,0 +1,203 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import io
+ import logging
+ import os.path
+ import time
+
+ from pathlib import Path
+ from typing import Callable, Optional
+
+ from snowflake.snowpark import Session
+ from snowflake.snowpark_checkpoints_collector.collection_common import (
+     DOT_PARQUET_EXTENSION,
+ )
+ from snowflake.snowpark_checkpoints_collector.io_utils.io_file_manager import (
+     get_io_file_manager,
+ )
+
+
+ STAGE_NAME = "CHECKPOINT_STAGE"
+ CREATE_STAGE_STATEMENT_FORMAT = "CREATE TEMP STAGE IF NOT EXISTS {}"
+ REMOVE_STAGE_FOLDER_STATEMENT_FORMAT = "REMOVE {}"
+ STAGE_PATH_FORMAT = "'@{}/{}'"
+ PUT_FILE_IN_STAGE_STATEMENT_FORMAT = "PUT '{}' {} AUTO_COMPRESS=FALSE"
+ LOGGER = logging.getLogger(__name__)
+
+
+ class SnowConnection:
+
+     """Class for manage the Snowpark Connection.
+
+     Attributes:
+         session (Snowpark.Session): the Snowpark session.
+
+     """
+
+     def __init__(self, session: Optional[Session] = None) -> None:
+         """Init SnowConnection.
+
+         Args:
+             session (Snowpark.Session): the Snowpark session.
+
+         """
+         self.session = (
+             session if session is not None else self._create_snowpark_session()
+         )
+         self.stage_id = int(time.time())
+
+     def create_snowflake_table_from_local_parquet(
+         self,
+         table_name: str,
+         input_path: str,
+         stage_path: Optional[str] = None,
+     ) -> None:
+         """Upload to parquet files from the input path and create a table.
+
+         Args:
+             table_name (str): the name of the table to be created.
+             input_path (str): the input directory path.
+             stage_path: (str, optional): the stage path.
+
+         """
+         input_path = (
+             os.path.abspath(input_path)
+             if not os.path.isabs(input_path)
+             else str(Path(input_path).resolve())
+         )
+         folder = f"table_files_{int(time.time())}"
+         stage_path = stage_path if stage_path else folder
+         stage_name = f"{STAGE_NAME}_{self.stage_id}"
+         stage_directory_path = STAGE_PATH_FORMAT.format(stage_name, stage_path)
+
+         def is_parquet_file(file: str):
+             return file.endswith(DOT_PARQUET_EXTENSION)
+
+         try:
+             self.create_tmp_stage(stage_name)
+             self.load_files_to_stage(
+                 stage_name, stage_path, input_path, is_parquet_file
+             )
+             self.create_table_from_parquet(table_name, stage_directory_path)
+         finally:
+             LOGGER.info("Removing stage folder %s", stage_directory_path)
+             self.session.sql(
+                 REMOVE_STAGE_FOLDER_STATEMENT_FORMAT.format(stage_directory_path)
+             ).collect()
+
+     def create_tmp_stage(self, stage_name: str) -> None:
+         """Create a temp stage in Snowflake.
+
+         Args:
+             stage_name (str): the name of the stage.
+
+         """
+         create_stage_statement = CREATE_STAGE_STATEMENT_FORMAT.format(stage_name)
+         LOGGER.info("Creating temporal stage '%s'", stage_name)
+         self.session.sql(create_stage_statement).collect()
+
+     def load_files_to_stage(
+         self,
+         stage_name: str,
+         folder_name: str,
+         input_path: str,
+         filter_func: Optional[Callable] = None,
+     ) -> None:
+         """Load files to a stage in Snowflake.
+
+         Args:
+             stage_name (str): the name of the stage.
+             folder_name (str): the folder name.
+             input_path (str): the input directory path.
+             filter_func (Callable): the filter function to apply to the files.
+
+         """
+         LOGGER.info("Starting to load files to '%s'", stage_name)
+         input_path = (
+             os.path.abspath(input_path)
+             if not os.path.isabs(input_path)
+             else str(Path(input_path).resolve())
+         )
+
+         def filter_files(name: str):
+             return get_io_file_manager().file_exists(name) and (
+                 filter_func(name) if filter_func else True
+             )
+
+         target_dir = os.path.join(input_path, "**", "*")
+         LOGGER.debug("Searching for files in '%s'", input_path)
+         files_collection = get_io_file_manager().ls(target_dir, recursive=True)
+
+         files = [file for file in files_collection if filter_files(file)]
+         files_count = len(files)
+
+         if files_count == 0:
+             raise Exception(f"No files were found in the input directory: {input_path}")
+
+         LOGGER.debug("Found %s files in '%s'", files_count, input_path)
+
+         for file in files:
+             # if file is relative path, convert to absolute path
+             # if absolute path, then try to resolve as some Win32 paths are not in LPN.
+             file_full_path = (
+                 str(os.path.abspath(file))
+                 if not os.path.isabs(file)
+                 else str(Path(file).resolve())
+             )
+             new_file_path = file_full_path.replace(input_path, folder_name)
+             # as Posix to convert Windows dir to posix
+             new_file_path = Path(new_file_path).as_posix()
+             stage_file_path = STAGE_PATH_FORMAT.format(stage_name, new_file_path)
+             parquet_file = get_io_file_manager().read_bytes(file_full_path)
+             binary_parquet = io.BytesIO(parquet_file)
+             LOGGER.info("Loading file '%s' to %s", file_full_path, stage_file_path)
+             self.session.file.put_stream(binary_parquet, stage_file_path)
+
+     def create_table_from_parquet(
+         self, table_name: str, stage_directory_path: str
+     ) -> None:
+         """Create a table from a parquet file in Snowflake.
+
+         Args:
+             table_name (str): the name of the table.
+             stage_directory_path (str): the stage directory path.
+
+         Raise:
+             Exception: No parquet files were found in the stage
+
+         """
+         LOGGER.info("Starting to create table '%s' from parquet files", table_name)
+         parquet_files = self.session.sql(
+             f"LIST {stage_directory_path} PATTERN='.*{DOT_PARQUET_EXTENSION}'"
+         ).collect()
+         parquet_files_count = len(parquet_files)
+         if parquet_files_count == 0:
+             raise Exception(
+                 f"No parquet files were found in the stage: {stage_directory_path}"
+             )
+
+         LOGGER.info(
+             "Reading %s parquet files from %s",
+             parquet_files_count,
+             stage_directory_path,
+         )
+         dataframe = self.session.read.parquet(path=stage_directory_path)
+         LOGGER.info("Creating table '%s' from parquet files", table_name)
+         dataframe.write.save_as_table(table_name=table_name, mode="overwrite")
+
+     def _create_snowpark_session(self) -> Session:
+         LOGGER.info("Creating a Snowpark session using the default connection")
+         return Session.builder.getOrCreate()
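
Note: the class above is typically driven end to end through create_snowflake_table_from_local_parquet, which stages local parquet files in a temporary stage and then materializes them as a table. A minimal usage sketch, assuming a default Snowflake connection is available so that Session.builder.getOrCreate() can resolve one; the table name and directory below are illustrative only:

    from snowflake.snowpark_checkpoints_collector.snow_connection_model import (
        SnowConnection,
    )

    # Reuses an existing Snowpark session if one is passed in, otherwise builds one
    # from the default connection. Names and paths are placeholders.
    connection = SnowConnection()  # or SnowConnection(session=my_existing_session)
    connection.create_snowflake_table_from_local_parquet(
        table_name="DEMO_CHECKPOINT_TABLE",
        input_path="./checkpoint_output/demo_checkpoint",
    )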
snowflake/snowpark_checkpoints_collector/summary_stats_collector.py
@@ -0,0 +1,409 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import json
+ import logging
+ import os
+
+ from typing import Optional
+
+ import pandas
+ import pandera as pa
+
+ from pyspark.sql import DataFrame as SparkDataFrame
+ from pyspark.sql.functions import col
+ from pyspark.sql.types import DoubleType as SparkDoubleType
+ from pyspark.sql.types import StringType as SparkStringType
+ from pyspark.sql.types import StructField
+
+ from snowflake.snowpark_checkpoints_collector.collection_common import (
+     CHECKPOINT_JSON_OUTPUT_FILE_NAME_FORMAT,
+     COLUMNS_KEY,
+     DATAFRAME_CUSTOM_DATA_KEY,
+     DATAFRAME_PANDERA_SCHEMA_KEY,
+     DECIMAL_COLUMN_TYPE,
+     DOT_PARQUET_EXTENSION,
+     INTEGER_TYPE_COLLECTION,
+     NULL_COLUMN_TYPE,
+     PANDAS_LONG_TYPE,
+     PANDAS_OBJECT_TYPE_COLLECTION,
+     CheckpointMode,
+ )
+ from snowflake.snowpark_checkpoints_collector.collection_result.model import (
+     CollectionPointResult,
+     CollectionPointResultManager,
+     CollectionResult,
+ )
+ from snowflake.snowpark_checkpoints_collector.column_collection import (
+     ColumnCollectorManager,
+ )
+ from snowflake.snowpark_checkpoints_collector.column_pandera_checks import (
+     PanderaColumnChecksManager,
+ )
+ from snowflake.snowpark_checkpoints_collector.io_utils.io_file_manager import (
+     get_io_file_manager,
+ )
+ from snowflake.snowpark_checkpoints_collector.snow_connection_model import (
+     SnowConnection,
+ )
+ from snowflake.snowpark_checkpoints_collector.utils import (
+     checkpoint_name_utils,
+     file_utils,
+ )
+ from snowflake.snowpark_checkpoints_collector.utils.extra_config import (
+     get_checkpoint_mode,
+     get_checkpoint_sample,
+     is_checkpoint_enabled,
+ )
+ from snowflake.snowpark_checkpoints_collector.utils.logging_utils import log
+ from snowflake.snowpark_checkpoints_collector.utils.telemetry import report_telemetry
+
+
+ LOGGER = logging.getLogger(__name__)
+
+
+ @log
+ def collect_dataframe_checkpoint(
+     df: SparkDataFrame,
+     checkpoint_name: str,
+     sample: Optional[float] = None,
+     mode: Optional[CheckpointMode] = None,
+     output_path: Optional[str] = None,
+ ) -> None:
+     """Collect a DataFrame checkpoint.
+
+     Args:
+         df (SparkDataFrame): The input Spark DataFrame to collect.
+         checkpoint_name (str): The name of the checkpoint.
+         sample (float, optional): Fraction of DataFrame to sample for schema inference.
+             Defaults to 1.0.
+         mode (CheckpointMode): The mode to execution the collection.
+             Defaults to CheckpointMode.Schema
+         output_path (str, optional): The output path to save the checkpoint.
+             Defaults to Current working Directory.
+
+     Raises:
+         Exception: Invalid mode value.
+         Exception: Invalid checkpoint name. Checkpoint names must only contain alphanumeric characters
+         , underscores and dollar signs.
+
+     """
+     normalized_checkpoint_name = checkpoint_name_utils.normalize_checkpoint_name(
+         checkpoint_name
+     )
+     if normalized_checkpoint_name != checkpoint_name:
+         LOGGER.info(
+             "Checkpoint name '%s' was normalized to '%s'",
+             checkpoint_name,
+             normalized_checkpoint_name,
+         )
+     is_valid_checkpoint_name = checkpoint_name_utils.is_valid_checkpoint_name(
+         normalized_checkpoint_name
+     )
+     if not is_valid_checkpoint_name:
+         raise Exception(
+             f"Invalid checkpoint name: {normalized_checkpoint_name}. "
+             "Checkpoint names must only contain alphanumeric characters, underscores and dollar signs."
+         )
+     if not is_checkpoint_enabled(normalized_checkpoint_name):
+         LOGGER.info(
+             "Checkpoint '%s' is disabled. Skipping collection.",
+             normalized_checkpoint_name,
+         )
+         return
+
+     LOGGER.info("Starting to collect checkpoint '%s'", normalized_checkpoint_name)
+     LOGGER.debug("DataFrame size: %s rows", df.count())
+     LOGGER.debug("DataFrame schema: %s", df.schema)
+
+     collection_point_file_path = file_utils.get_collection_point_source_file_path()
+     collection_point_line_of_code = file_utils.get_collection_point_line_of_code()
+     collection_point_result = CollectionPointResult(
+         collection_point_file_path,
+         collection_point_line_of_code,
+         normalized_checkpoint_name,
+     )
+
+     try:
+         if _is_empty_dataframe_without_schema(df):
+             raise Exception(
+                 "It is not possible to collect an empty DataFrame without schema"
+             )
+
+         _mode = get_checkpoint_mode(normalized_checkpoint_name, mode)
+
+         if _mode == CheckpointMode.SCHEMA:
+             column_type_dict = _get_spark_column_types(df)
+             _sample = get_checkpoint_sample(normalized_checkpoint_name, sample)
+             LOGGER.info(
+                 "Collecting checkpoint in %s mode using sample value %s",
+                 CheckpointMode.SCHEMA.name,
+                 _sample,
+             )
+             _collect_dataframe_checkpoint_mode_schema(
+                 normalized_checkpoint_name,
+                 df,
+                 _sample,
+                 column_type_dict,
+                 output_path,
+             )
+         elif _mode == CheckpointMode.DATAFRAME:
+             LOGGER.info(
+                 "Collecting checkpoint in %s mode", CheckpointMode.DATAFRAME.name
+             )
+             snow_connection = SnowConnection()
+             _collect_dataframe_checkpoint_mode_dataframe(
+                 normalized_checkpoint_name, df, snow_connection, output_path
+             )
+         else:
+             raise Exception(f"Invalid mode value: {_mode}")
+
+         collection_point_result.result = CollectionResult.PASS
+         LOGGER.info(
+             "Checkpoint '%s' collected successfully", normalized_checkpoint_name
+         )
+
+     except Exception as err:
+         collection_point_result.result = CollectionResult.FAIL
+         error_message = str(err)
+         raise Exception(error_message) from err
+
+     finally:
+         collection_point_result_manager = CollectionPointResultManager(output_path)
+         collection_point_result_manager.add_result(collection_point_result)
+
+
+ @report_telemetry(params_list=["column_type_dict"])
+ def _collect_dataframe_checkpoint_mode_schema(
+     checkpoint_name: str,
+     df: SparkDataFrame,
+     sample: float,
+     column_type_dict: dict[str, any],
+     output_path: Optional[str] = None,
+ ) -> None:
+     sampled_df = df.sample(sample)
+     if sampled_df.isEmpty():
+         LOGGER.warning("Sampled DataFrame is empty. Collecting full DataFrame.")
+         sampled_df = df
+
+     pandas_df = _to_pandas(sampled_df)
+     is_empty_df_with_object_column = _is_empty_dataframe_with_object_column(df)
+     if is_empty_df_with_object_column:
+         LOGGER.debug(
+             "DataFrame is empty with object column. Skipping Pandera schema inference."
+         )
+         pandera_infer_schema = {}
+     else:
+         LOGGER.debug("Inferring Pandera schema from DataFrame")
+         pandera_infer_schema = pa.infer_schema(pandas_df)
+
+     column_name_collection = df.schema.names
+     columns_to_remove_from_pandera_schema_collection = []
+     column_custom_data_collection = []
+     column_collector_manager = ColumnCollectorManager()
+     column_pandera_checks_manager = PanderaColumnChecksManager()
+
+     for column_name in column_name_collection:
+         struct_field_column = column_type_dict[column_name]
+         column_type = struct_field_column.dataType.typeName()
+         LOGGER.info("Collecting column '%s' of type '%s'", column_name, column_type)
+         pyspark_column = df.select(col(column_name))
+
+         is_column_to_remove_from_pandera_schema = (
+             _is_column_to_remove_from_pandera_schema(column_type)
+         )
+         if is_column_to_remove_from_pandera_schema:
+             columns_to_remove_from_pandera_schema_collection.append(column_name)
+
+         is_empty_column = (
+             pyspark_column.dropna().isEmpty() and column_type is not NULL_COLUMN_TYPE
+         )
+         if is_empty_column:
+             LOGGER.debug("Column '%s' is empty.", column_name)
+             custom_data = column_collector_manager.collect_empty_custom_data(
+                 column_name, struct_field_column, pyspark_column
+             )
+             column_custom_data_collection.append(custom_data)
+             continue
+
+         pandera_column = pandera_infer_schema.columns[column_name]
+         pandera_column.checks = []
+         column_pandera_checks_manager.add_checks_column(
+             column_name, column_type, df, pandera_column
+         )
+
+         custom_data = column_collector_manager.collect_column(
+             column_name, struct_field_column, pyspark_column
+         )
+         column_custom_data_collection.append(custom_data)
+
+     pandera_infer_schema_dict = _get_pandera_infer_schema_as_dict(
+         pandera_infer_schema,
+         is_empty_df_with_object_column,
+         columns_to_remove_from_pandera_schema_collection,
+     )
+
+     dataframe_custom_column_data = {COLUMNS_KEY: column_custom_data_collection}
+     dataframe_schema_contract = {
+         DATAFRAME_PANDERA_SCHEMA_KEY: pandera_infer_schema_dict,
+         DATAFRAME_CUSTOM_DATA_KEY: dataframe_custom_column_data,
+     }
+
+     dataframe_schema_contract_json = json.dumps(dataframe_schema_contract)
+     _generate_json_checkpoint_file(
+         checkpoint_name, dataframe_schema_contract_json, output_path
+     )
+
+
+ def _get_spark_column_types(df: SparkDataFrame) -> dict[str, StructField]:
+     schema = df.schema
+     column_type_collection = {}
+     for field in schema.fields:
+         column_name = field.name
+         column_type_collection[column_name] = field
+     return column_type_collection
+
+
+ def _is_empty_dataframe_without_schema(df: SparkDataFrame) -> bool:
+     is_empty = df.isEmpty()
+     has_schema = len(df.schema.fields) > 0
+     return is_empty and not has_schema
+
+
+ def _is_empty_dataframe_with_object_column(df: SparkDataFrame):
+     is_empty = df.isEmpty()
+     if not is_empty:
+         return False
+
+     for field in df.schema.fields:
+         if field.dataType.typeName() in PANDAS_OBJECT_TYPE_COLLECTION:
+             return True
+
+     return False
+
+
+ def _is_column_to_remove_from_pandera_schema(column_type) -> bool:
+     is_decimal_type = column_type == DECIMAL_COLUMN_TYPE
+     return is_decimal_type
+
+
+ def _get_pandera_infer_schema_as_dict(
+     pandera_infer_schema, is_empty_df_with_string_column, columns_to_remove_collection
+ ) -> dict[str, any]:
+     if is_empty_df_with_string_column:
+         return {}
+
+     pandera_infer_schema_dict = json.loads(pandera_infer_schema.to_json())
+     for column in columns_to_remove_collection:
+         LOGGER.debug("Removing column '%s' from Pandera schema", column)
+         del pandera_infer_schema_dict[COLUMNS_KEY][column]
+
+     return pandera_infer_schema_dict
+
+
+ def _generate_json_checkpoint_file(
+     checkpoint_name, dataframe_schema_contract, output_path: Optional[str] = None
+ ) -> None:
+     checkpoint_file_name = CHECKPOINT_JSON_OUTPUT_FILE_NAME_FORMAT.format(
+         checkpoint_name
+     )
+     output_directory_path = file_utils.get_output_directory_path(output_path)
+     checkpoint_file_path = os.path.join(output_directory_path, checkpoint_file_name)
+     LOGGER.info("Writing DataFrame JSON schema file to '%s'", checkpoint_file_path)
+     get_io_file_manager().write(checkpoint_file_path, dataframe_schema_contract)
+
+
+ @report_telemetry(params_list=["df"])
+ def _collect_dataframe_checkpoint_mode_dataframe(
+     checkpoint_name: str,
+     df: SparkDataFrame,
+     snow_connection: SnowConnection,
+     output_path: Optional[str] = None,
+ ) -> None:
+     output_path = file_utils.get_output_directory_path(output_path)
+     parquet_directory = os.path.join(output_path, checkpoint_name)
+     generate_parquet_for_spark_df(df, parquet_directory)
+     _create_snowflake_table_from_parquet(
+         checkpoint_name, parquet_directory, snow_connection
+     )
+
+
+ def generate_parquet_for_spark_df(spark_df: SparkDataFrame, output_path: str) -> None:
+     """Generate a parquet file from a Spark DataFrame.
+
+     This function will convert Float to Double to avoid precision problems.
+     Spark parquet use IEEE 32-bit floating point values,
+     while Snowflake uses IEEE 64-bit floating point values.
+
+     Args:
+         spark_df: dataframe to be saved as parquet
+         output_path: path to save the parquet files.
+     returns: None
+
+     Raises:
+         Exception: No parquet files were generated.
+
+     """
+     new_cols = [
+         (
+             col(c).cast(SparkStringType()).cast(SparkDoubleType()).alias(c)
+             if t == "float"
+             else col(c)
+         )
+         for (c, t) in spark_df.dtypes
+     ]
+     converted_df = spark_df.select(new_cols)
+
+     if get_io_file_manager().folder_exists(output_path):
+         LOGGER.warning(
+             "Output directory '%s' already exists. Deleting it...", output_path
+         )
+         get_io_file_manager().remove_dir(output_path)
+
+     LOGGER.info("Writing DataFrame to parquet files at '%s'", output_path)
+     converted_df.write.parquet(output_path, mode="overwrite")
+
+     target_dir = os.path.join(output_path, "**", f"*{DOT_PARQUET_EXTENSION}")
+     parquet_files = get_io_file_manager().ls(target_dir, recursive=True)
+     parquet_files_count = len(parquet_files)
+     if parquet_files_count == 0:
+         raise Exception("No parquet files were generated.")
+     LOGGER.info(
+         "%s parquet files were written in '%s'",
+         parquet_files_count,
+         output_path,
+     )
+
+
+ def _create_snowflake_table_from_parquet(
+     table_name: str, input_path: str, snow_connection: SnowConnection
+ ) -> None:
+     snow_connection.create_snowflake_table_from_local_parquet(table_name, input_path)
+
+
+ def _to_pandas(sampled_df: SparkDataFrame) -> pandas.DataFrame:
+     LOGGER.debug("Converting Spark DataFrame to Pandas DataFrame")
+     pandas_df = sampled_df.toPandas()
+     for field in sampled_df.schema.fields:
+         has_nan = pandas_df[field.name].isna().any()
+         is_integer = field.dataType.typeName() in INTEGER_TYPE_COLLECTION
+         if has_nan and is_integer:
+             LOGGER.debug(
+                 "Converting column '%s' to '%s' type",
+                 field.name,
+                 PANDAS_LONG_TYPE,
+             )
+             pandas_df[field.name] = pandas_df[field.name].astype(PANDAS_LONG_TYPE)
+
+     return pandas_df
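
Note: the public entry point in this file is collect_dataframe_checkpoint, which normalizes and validates the checkpoint name and then either writes a JSON schema contract (SCHEMA mode) or writes the DataFrame to parquet and creates a Snowflake table from it (DATAFRAME mode). A minimal invocation sketch, assuming the function is re-exported from the package's top-level __init__.py (listed above but not shown in this diff) and using an illustrative Spark DataFrame:

    from pyspark.sql import SparkSession

    from snowflake.snowpark_checkpoints_collector import collect_dataframe_checkpoint
    from snowflake.snowpark_checkpoints_collector.collection_common import CheckpointMode

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])

    # Spaces and hyphens in the name are normalized to underscores; the result must
    # start with a letter or underscore and use only letters, digits, "_" and "$".
    collect_dataframe_checkpoint(
        df,
        checkpoint_name="demo_checkpoint",
        sample=1.0,                   # fraction sampled for schema inference
        mode=CheckpointMode.SCHEMA,   # CheckpointMode.DATAFRAME creates a table instead
        output_path="./checkpoint_output",
    )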
snowflake/snowpark_checkpoints_collector/utils/checkpoint_name_utils.py
@@ -0,0 +1,53 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re as regx
+
+
+ CHECKPOINT_NAME_REGEX_PATTERN = r"[a-zA-Z_][a-zA-Z0-9_$]+"
+ TRANSLATION_TABLE = str.maketrans({" ": "_", "-": "_"})
+
+
+ def normalize_checkpoint_name(checkpoint_name: str) -> str:
+     """Normalize the provided checkpoint name by replacing: the whitespace and hyphen tokens by underscore token.
+
+     Args:
+         checkpoint_name (str): The checkpoint name to normalize.
+
+     Returns:
+         str: the checkpoint name normalized.
+
+     """
+     normalized_checkpoint_name = checkpoint_name.translate(TRANSLATION_TABLE)
+     return normalized_checkpoint_name
+
+
+ def is_valid_checkpoint_name(checkpoint_name: str) -> bool:
+     """Check if the provided checkpoint name is valid.
+
+     A valid checkpoint name must:
+     - Start with a letter (a-z, A-Z) or an underscore (_)
+     - Be followed by any combination of letters, digits (0-9), underscores (_), and dollar signs ($).
+
+     Args:
+         checkpoint_name (str): The checkpoint name to validate.
+
+     Returns:
+         bool: True if the checkpoint name is valid; otherwise, False.
+
+     """
+     matched = regx.fullmatch(CHECKPOINT_NAME_REGEX_PATTERN, checkpoint_name)
+     is_valid = bool(matched)
+     return is_valid
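
Note: taken together, the two helpers above first normalize a user-supplied name (spaces and hyphens become underscores) and then validate it with a fullmatch against CHECKPOINT_NAME_REGEX_PATTERN; as written, the trailing + in the pattern means a valid name needs at least two characters. A small sketch of the expected behavior, with illustrative inputs:

    from snowflake.snowpark_checkpoints_collector.utils import checkpoint_name_utils

    name = checkpoint_name_utils.normalize_checkpoint_name("my checkpoint-1")
    print(name)                                                     # my_checkpoint_1
    print(checkpoint_name_utils.is_valid_checkpoint_name(name))     # True
    print(checkpoint_name_utils.is_valid_checkpoint_name("1bad"))   # False: starts with a digit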