snowpark-checkpoints-collectors 0.1.0rc3__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. snowflake/snowpark_checkpoints_collector/__init__.py +22 -0
  2. snowflake/snowpark_checkpoints_collector/__version__.py +16 -0
  3. snowflake/snowpark_checkpoints_collector/collection_common.py +160 -0
  4. snowflake/snowpark_checkpoints_collector/collection_result/model/__init__.py +24 -0
  5. snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result.py +91 -0
  6. snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py +69 -0
  7. snowflake/snowpark_checkpoints_collector/column_collection/__init__.py +22 -0
  8. snowflake/snowpark_checkpoints_collector/column_collection/column_collector_manager.py +253 -0
  9. snowflake/snowpark_checkpoints_collector/column_collection/model/__init__.py +75 -0
  10. snowflake/snowpark_checkpoints_collector/column_collection/model/array_column_collector.py +113 -0
  11. snowflake/snowpark_checkpoints_collector/column_collection/model/binary_column_collector.py +87 -0
  12. snowflake/snowpark_checkpoints_collector/column_collection/model/boolean_column_collector.py +71 -0
  13. snowflake/snowpark_checkpoints_collector/column_collection/model/column_collector_base.py +95 -0
  14. snowflake/snowpark_checkpoints_collector/column_collection/model/date_column_collector.py +74 -0
  15. snowflake/snowpark_checkpoints_collector/column_collection/model/day_time_interval_column_collector.py +67 -0
  16. snowflake/snowpark_checkpoints_collector/column_collection/model/decimal_column_collector.py +92 -0
  17. snowflake/snowpark_checkpoints_collector/column_collection/model/empty_column_collector.py +88 -0
  18. snowflake/snowpark_checkpoints_collector/column_collection/model/map_column_collector.py +120 -0
  19. snowflake/snowpark_checkpoints_collector/column_collection/model/null_column_collector.py +49 -0
  20. snowflake/snowpark_checkpoints_collector/column_collection/model/numeric_column_collector.py +108 -0
  21. snowflake/snowpark_checkpoints_collector/column_collection/model/string_column_collector.py +70 -0
  22. snowflake/snowpark_checkpoints_collector/column_collection/model/struct_column_collector.py +102 -0
  23. snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_column_collector.py +75 -0
  24. snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_ntz_column_collector.py +75 -0
  25. snowflake/snowpark_checkpoints_collector/column_pandera_checks/__init__.py +20 -0
  26. snowflake/snowpark_checkpoints_collector/column_pandera_checks/pandera_column_checks_manager.py +223 -0
  27. snowflake/snowpark_checkpoints_collector/singleton.py +23 -0
  28. snowflake/snowpark_checkpoints_collector/snow_connection_model/__init__.py +20 -0
  29. snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py +172 -0
  30. snowflake/snowpark_checkpoints_collector/summary_stats_collector.py +366 -0
  31. snowflake/snowpark_checkpoints_collector/utils/checkpoint_name_utils.py +53 -0
  32. snowflake/snowpark_checkpoints_collector/utils/extra_config.py +112 -0
  33. snowflake/snowpark_checkpoints_collector/utils/file_utils.py +132 -0
  34. snowflake/snowpark_checkpoints_collector/utils/telemetry.py +889 -0
  35. {snowpark_checkpoints_collectors-0.1.0rc3.dist-info → snowpark_checkpoints_collectors-0.1.2.dist-info}/METADATA +4 -6
  36. snowpark_checkpoints_collectors-0.1.2.dist-info/RECORD +38 -0
  37. {snowpark_checkpoints_collectors-0.1.0rc3.dist-info → snowpark_checkpoints_collectors-0.1.2.dist-info}/licenses/LICENSE +0 -25
  38. snowpark_checkpoints_collectors-0.1.0rc3.dist-info/RECORD +0 -4
  39. {snowpark_checkpoints_collectors-0.1.0rc3.dist-info → snowpark_checkpoints_collectors-0.1.2.dist-info}/WHEEL +0 -0
snowflake/snowpark_checkpoints_collector/summary_stats_collector.py
@@ -0,0 +1,366 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import glob
+import json
+import os
+import shutil
+
+from typing import Optional
+
+import pandas
+import pandera as pa
+
+from pyspark.sql import DataFrame as SparkDataFrame
+from pyspark.sql.functions import col
+from pyspark.sql.types import DoubleType as SparkDoubleType
+from pyspark.sql.types import StringType as SparkStringType
+from pyspark.sql.types import StructField
+
+from snowflake.snowpark_checkpoints_collector.collection_common import (
+    CHECKPOINT_JSON_OUTPUT_FILE_NAME_FORMAT,
+    COLUMNS_KEY,
+    DATAFRAME_CUSTOM_DATA_KEY,
+    DATAFRAME_PANDERA_SCHEMA_KEY,
+    DECIMAL_COLUMN_TYPE,
+    DOT_PARQUET_EXTENSION,
+    INTEGER_TYPE_COLLECTION,
+    NULL_COLUMN_TYPE,
+    PANDAS_LONG_TYPE,
+    PANDAS_OBJECT_TYPE_COLLECTION,
+    CheckpointMode,
+)
+from snowflake.snowpark_checkpoints_collector.collection_result.model import (
+    CollectionPointResult,
+    CollectionPointResultManager,
+    CollectionResult,
+)
+from snowflake.snowpark_checkpoints_collector.column_collection import (
+    ColumnCollectorManager,
+)
+from snowflake.snowpark_checkpoints_collector.column_pandera_checks import (
+    PanderaColumnChecksManager,
+)
+from snowflake.snowpark_checkpoints_collector.snow_connection_model import (
+    SnowConnection,
+)
+from snowflake.snowpark_checkpoints_collector.utils import (
+    checkpoint_name_utils,
+    file_utils,
+)
+from snowflake.snowpark_checkpoints_collector.utils.extra_config import (
+    get_checkpoint_mode,
+    get_checkpoint_sample,
+    is_checkpoint_enabled,
+)
+from snowflake.snowpark_checkpoints_collector.utils.telemetry import report_telemetry
+
+
+def collect_dataframe_checkpoint(
+    df: SparkDataFrame,
+    checkpoint_name: str,
+    sample: Optional[float] = None,
+    mode: Optional[CheckpointMode] = None,
+    output_path: Optional[str] = None,
+) -> None:
+    """Collect a DataFrame checkpoint.
+
+    Args:
+        df (SparkDataFrame): The input Spark DataFrame to collect.
+        checkpoint_name (str): The name of the checkpoint.
+        sample (float, optional): Fraction of the DataFrame to sample for schema inference.
+            Defaults to 1.0.
+        mode (CheckpointMode, optional): The mode in which to execute the collection.
+            Defaults to CheckpointMode.SCHEMA.
+        output_path (str, optional): The output path to save the checkpoint.
+            Defaults to the current working directory.
+
+    Raises:
+        Exception: Invalid mode value.
+        Exception: Invalid checkpoint name. Checkpoint names must only contain
+            alphanumeric characters and underscores.
+
+    """
+    try:
+        normalized_checkpoint_name = checkpoint_name_utils.normalize_checkpoint_name(
+            checkpoint_name
+        )
+        is_valid_checkpoint_name = checkpoint_name_utils.is_valid_checkpoint_name(
+            normalized_checkpoint_name
+        )
+        if not is_valid_checkpoint_name:
+            raise Exception(
+                f"Invalid checkpoint name: {checkpoint_name}. Checkpoint names must only contain alphanumeric "
+                f"characters and underscores."
+            )
+
+        if is_checkpoint_enabled(normalized_checkpoint_name):
+            collection_point_file_path = (
+                file_utils.get_collection_point_source_file_path()
+            )
+            collection_point_line_of_code = (
+                file_utils.get_collection_point_line_of_code()
+            )
+            collection_point_result = CollectionPointResult(
+                collection_point_file_path,
+                collection_point_line_of_code,
+                normalized_checkpoint_name,
+            )
+
+            try:
+                _sample = get_checkpoint_sample(normalized_checkpoint_name, sample)
+
+                if _is_empty_dataframe_without_schema(df):
+                    raise Exception(
+                        "It is not possible to collect an empty DataFrame without schema"
+                    )
+
+                _mode = get_checkpoint_mode(normalized_checkpoint_name, mode)
+
+                if _mode == CheckpointMode.SCHEMA:
+                    column_type_dict = _get_spark_column_types(df)
+                    _collect_dataframe_checkpoint_mode_schema(
+                        normalized_checkpoint_name,
+                        df,
+                        _sample,
+                        column_type_dict,
+                        output_path,
+                    )
+                elif _mode == CheckpointMode.DATAFRAME:
+                    snow_connection = SnowConnection()
+                    _collect_dataframe_checkpoint_mode_dataframe(
+                        normalized_checkpoint_name, df, snow_connection, output_path
+                    )
+                else:
+                    raise Exception("Invalid mode value.")
+
+                collection_point_result.result = CollectionResult.PASS
+
+            except Exception as err:
+                collection_point_result.result = CollectionResult.FAIL
+                error_message = str(err)
+                raise Exception(error_message) from err
+
+            finally:
+                collection_point_result_manager = CollectionPointResultManager(
+                    output_path
+                )
+                collection_point_result_manager.add_result(collection_point_result)
+
+    except Exception as err:
+        error_message = str(err)
+        raise Exception(error_message) from err
+
+
+@report_telemetry(params_list=["column_type_dict"])
+def _collect_dataframe_checkpoint_mode_schema(
+    checkpoint_name: str,
+    df: SparkDataFrame,
+    sample: float,
+    column_type_dict: dict[str, any],
+    output_path: Optional[str] = None,
+) -> None:
+    sampled_df = df.sample(sample)
+    if sampled_df.isEmpty():
+        sampled_df = df
+    pandas_df = _to_pandas(sampled_df)
+    is_empty_df_with_object_column = _is_empty_dataframe_with_object_column(df)
+    pandera_infer_schema = (
+        pa.infer_schema(pandas_df) if not is_empty_df_with_object_column else {}
+    )
+
+    column_name_collection = df.schema.names
+    columns_to_remove_from_pandera_schema_collection = []
+    column_custom_data_collection = []
+    column_collector_manager = ColumnCollectorManager()
+    column_pandera_checks_manager = PanderaColumnChecksManager()
+
+    for column_name in column_name_collection:
+        struct_field_column = column_type_dict[column_name]
+        column_type = struct_field_column.dataType.typeName()
+        pyspark_column = df.select(col(column_name))
+
+        is_empty_column = (
+            pyspark_column.dropna().isEmpty() and column_type != NULL_COLUMN_TYPE
+        )
+        is_column_to_remove_from_pandera_schema = (
+            _is_column_to_remove_from_pandera_schema(column_type)
+        )
+
+        if is_column_to_remove_from_pandera_schema:
+            columns_to_remove_from_pandera_schema_collection.append(column_name)
+
+        if is_empty_column:
+            custom_data = column_collector_manager.collect_empty_custom_data(
+                column_name, struct_field_column, pyspark_column
+            )
+            column_custom_data_collection.append(custom_data)
+            continue
+
+        pandera_column = pandera_infer_schema.columns[column_name]
+        pandera_column.checks = []
+        column_pandera_checks_manager.add_checks_column(
+            column_name, column_type, df, pandera_column
+        )
+
+        custom_data = column_collector_manager.collect_column(
+            column_name, struct_field_column, pyspark_column
+        )
+        column_custom_data_collection.append(custom_data)
+
+    pandera_infer_schema_dict = _get_pandera_infer_schema_as_dict(
+        pandera_infer_schema,
+        is_empty_df_with_object_column,
+        columns_to_remove_from_pandera_schema_collection,
+    )
+
+    dataframe_custom_column_data = {COLUMNS_KEY: column_custom_data_collection}
+    dataframe_schema_contract = {
+        DATAFRAME_PANDERA_SCHEMA_KEY: pandera_infer_schema_dict,
+        DATAFRAME_CUSTOM_DATA_KEY: dataframe_custom_column_data,
+    }
+
+    dataframe_schema_contract_json = json.dumps(dataframe_schema_contract)
+    _generate_json_checkpoint_file(
+        checkpoint_name, dataframe_schema_contract_json, output_path
+    )
+
+
+def _get_spark_column_types(df: SparkDataFrame) -> dict[str, StructField]:
+    schema = df.schema
+    column_type_collection = {}
+    for field in schema.fields:
+        column_name = field.name
+        column_type_collection[column_name] = field
+    return column_type_collection
+
+
+def _is_empty_dataframe_without_schema(df: SparkDataFrame) -> bool:
+    is_empty = df.isEmpty()
+    has_schema = len(df.schema.fields) > 0
+    return is_empty and not has_schema
+
+
+def _is_empty_dataframe_with_object_column(df: SparkDataFrame):
+    is_empty = df.isEmpty()
+    if not is_empty:
+        return False
+
+    for field in df.schema.fields:
+        if field.dataType.typeName() in PANDAS_OBJECT_TYPE_COLLECTION:
+            return True
+
+    return False
+
+
+def _is_column_to_remove_from_pandera_schema(column_type) -> bool:
+    is_decimal_type = column_type == DECIMAL_COLUMN_TYPE
+    return is_decimal_type
+
+
+def _get_pandera_infer_schema_as_dict(
+    pandera_infer_schema, is_empty_df_with_string_column, columns_to_remove_collection
+) -> dict[str, any]:
+    if is_empty_df_with_string_column:
+        return {}
+
+    pandera_infer_schema_dict = json.loads(pandera_infer_schema.to_json())
+    for column in columns_to_remove_collection:
+        del pandera_infer_schema_dict[COLUMNS_KEY][column]
+
+    return pandera_infer_schema_dict
+
+
+def _generate_json_checkpoint_file(
+    checkpoint_name, dataframe_schema_contract, output_path: Optional[str] = None
+) -> None:
+    checkpoint_file_name = CHECKPOINT_JSON_OUTPUT_FILE_NAME_FORMAT.format(
+        checkpoint_name
+    )
+    output_directory_path = file_utils.get_output_directory_path(output_path)
+    checkpoint_file_path = os.path.join(output_directory_path, checkpoint_file_name)
+    with open(checkpoint_file_path, "w") as f:
+        f.write(dataframe_schema_contract)
+
+
+@report_telemetry(params_list=["df"])
+def _collect_dataframe_checkpoint_mode_dataframe(
+    checkpoint_name: str,
+    df: SparkDataFrame,
+    snow_connection: SnowConnection,
+    output_path: Optional[str] = None,
+) -> None:
+    output_path = file_utils.get_output_directory_path(output_path)
+    parquet_directory = os.path.join(output_path, checkpoint_name)
+    generate_parquet_for_spark_df(df, parquet_directory)
+    _create_snowflake_table_from_parquet(
+        checkpoint_name, parquet_directory, snow_connection
+    )
+
+
+def generate_parquet_for_spark_df(spark_df: SparkDataFrame, output_path: str) -> None:
+    """Generate a parquet file from a Spark DataFrame.
+
+    This function converts float columns to double to avoid precision problems:
+    Spark parquet uses IEEE 32-bit floating-point values, while Snowflake uses
+    IEEE 64-bit floating-point values.
+
+    Args:
+        spark_df: DataFrame to be saved as parquet.
+        output_path: path to save the parquet files.
+
+    Raises:
+        Exception: No parquet files were generated.
+
+    """
+    new_cols = [
+        (
+            col(c).cast(SparkStringType()).cast(SparkDoubleType()).alias(c)
+            if t == "float"
+            else col(c)
+        )
+        for (c, t) in spark_df.dtypes
+    ]
+    converted_df = spark_df.select(new_cols)
+
+    if os.path.exists(output_path):
+        shutil.rmtree(output_path)
+
+    converted_df.write.parquet(output_path, mode="overwrite")
+
+    target_dir = os.path.join(output_path, "**", f"*{DOT_PARQUET_EXTENSION}")
+    files = glob.glob(target_dir, recursive=True)
+    if len(files) == 0:
+        raise Exception("No parquet files were generated.")
+
+
+def _create_snowflake_table_from_parquet(
+    table_name: str, input_path: str, snow_connection: SnowConnection
+) -> None:
+    snow_connection.create_snowflake_table_from_local_parquet(table_name, input_path)
+
+
+def _to_pandas(sampled_df: SparkDataFrame) -> pandas.DataFrame:
+    pandas_df = sampled_df.toPandas()
+    for field in sampled_df.schema.fields:
+        has_nan = pandas_df[field.name].isna().any()
+        is_integer = field.dataType.typeName() in INTEGER_TYPE_COLLECTION
+        if has_nan and is_integer:
+            pandas_df[field.name] = pandas_df[field.name].astype(PANDAS_LONG_TYPE)
+
+    return pandas_df
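
For orientation, a minimal usage sketch of the entry point added above (editorial, not part of the diff). It assumes a local SparkSession, that collect_dataframe_checkpoint is re-exported from the package's top-level __init__.py, and an illustrative output path.

from pyspark.sql import SparkSession

from snowflake.snowpark_checkpoints_collector import collect_dataframe_checkpoint
from snowflake.snowpark_checkpoints_collector.collection_common import CheckpointMode

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])

# Schema mode samples the DataFrame, infers a pandera schema plus per-column
# custom data, and writes a checkpoint JSON file under the output path.
collect_dataframe_checkpoint(
    df,
    checkpoint_name="demo checkpoint",  # normalized to "demo_checkpoint"
    sample=1.0,
    mode=CheckpointMode.SCHEMA,
    output_path="/tmp/checkpoints",  # hypothetical path
)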
snowflake/snowpark_checkpoints_collector/utils/checkpoint_name_utils.py
@@ -0,0 +1,53 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re as regx
+
+
+CHECKPOINT_NAME_REGEX_PATTERN = r"[a-zA-Z_][a-zA-Z0-9_]+"
+TRANSLATION_TABLE = str.maketrans({" ": "_", "-": "_"})
+
+
+def normalize_checkpoint_name(checkpoint_name: str) -> str:
+    """Normalize the provided checkpoint name by replacing whitespace and hyphen tokens with underscores.
+
+    Args:
+        checkpoint_name (str): The checkpoint name to normalize.
+
+    Returns:
+        str: the normalized checkpoint name.
+
+    """
+    normalized_checkpoint_name = checkpoint_name.translate(TRANSLATION_TABLE)
+    return normalized_checkpoint_name
+
+
+def is_valid_checkpoint_name(checkpoint_name: str) -> bool:
+    """Check if the provided checkpoint name is valid.
+
+    A valid checkpoint name must:
+    - Start with a letter (a-z, A-Z) or an underscore (_)
+    - Be followed by one or more letters, digits (0-9) or underscores (_).
+
+    Args:
+        checkpoint_name (str): The checkpoint name to validate.
+
+    Returns:
+        bool: True if the checkpoint name is valid; otherwise, False.
+
+    """
+    matched = regx.fullmatch(CHECKPOINT_NAME_REGEX_PATTERN, checkpoint_name)
+    is_valid = bool(matched)
+    return is_valid
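
A quick illustration of how these two helpers compose (editorial sketch; the names are illustrative):

from snowflake.snowpark_checkpoints_collector.utils import checkpoint_name_utils

# Whitespace and hyphens become underscores; validation then enforces the
# identifier-like pattern above.
name = checkpoint_name_utils.normalize_checkpoint_name("my checkpoint-1")
print(name)                                                    # my_checkpoint_1
print(checkpoint_name_utils.is_valid_checkpoint_name(name))    # True
print(checkpoint_name_utils.is_valid_checkpoint_name("1bad"))  # False: starts with a digit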
snowflake/snowpark_checkpoints_collector/utils/extra_config.py
@@ -0,0 +1,112 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from typing import Optional
+
+from snowflake.snowpark_checkpoints_collector.collection_common import (
+    SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR,
+    CheckpointMode,
+)
+
+
+# noinspection DuplicatedCode
+def _get_checkpoint_contract_file_path() -> str:
+    return os.environ.get(SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR, os.getcwd())
+
+
+# noinspection DuplicatedCode
+def _get_metadata():
+    try:
+        from snowflake.snowpark_checkpoints_configuration.checkpoint_metadata import (
+            CheckpointMetadata,
+        )
+
+        path = _get_checkpoint_contract_file_path()
+        metadata = CheckpointMetadata(path)
+        return True, metadata
+
+    except ImportError:
+        return False, None
+
+
+def is_checkpoint_enabled(checkpoint_name: str) -> bool:
+    """Check if a checkpoint is enabled.
+
+    Args:
+        checkpoint_name (str): The name of the checkpoint.
+
+    Returns:
+        bool: True if the checkpoint is enabled, False otherwise.
+
+    """
+    enabled, metadata = _get_metadata()
+    if enabled:
+        config = metadata.get_checkpoint(checkpoint_name)
+        return config.enabled
+    else:
+        return True
+
+
+def get_checkpoint_sample(
+    checkpoint_name: str, sample: Optional[float] = None
+) -> float:
+    """Get the sample for a checkpoint.
+
+    The value is resolved in this order: first, the sample passed as an argument;
+    second, the sample from the checkpoint configuration; third, the default
+    sample value of 1.0.
+
+    Args:
+        checkpoint_name (str): The name of the checkpoint.
+        sample (float, optional): The sample value passed as an argument.
+
+    Returns:
+        float: the sample for that specific checkpoint.
+
+    """
+    default_sample = 1.0
+
+    enabled, metadata = _get_metadata()
+    if enabled:
+        config = metadata.get_checkpoint(checkpoint_name)
+        default_sample = config.sample if config.sample is not None else default_sample
+
+    return sample if sample is not None else default_sample
+
+
+def get_checkpoint_mode(
+    checkpoint_name: str, mode: Optional[CheckpointMode] = None
+) -> CheckpointMode:
+    """Get the mode for a checkpoint.
+
+    The value is resolved in this order: first, the mode passed as an argument;
+    second, the mode from the checkpoint configuration; third, the default mode,
+    CheckpointMode.SCHEMA.
+
+    Args:
+        checkpoint_name (str): The name of the checkpoint.
+        mode (CheckpointMode, optional): The mode value passed as an argument.
+
+    Returns:
+        CheckpointMode: the mode for that specific checkpoint.
+
+    """
+    default_mode = CheckpointMode.SCHEMA
+
+    enabled, metadata = _get_metadata()
+    if enabled:
+        config = metadata.get_checkpoint(checkpoint_name)
+        default_mode = config.mode if config.mode is not None else default_mode
+
+    return mode if mode is not None else default_mode
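
The precedence these getters implement, sketched below (editorial). This assumes the optional snowpark-checkpoints-configuration package is not installed, so _get_metadata returns (False, None) and only argument-over-default resolution applies.

from snowflake.snowpark_checkpoints_collector.collection_common import CheckpointMode
from snowflake.snowpark_checkpoints_collector.utils.extra_config import (
    get_checkpoint_mode,
    get_checkpoint_sample,
)

# Argument beats configuration, which beats the default.
print(get_checkpoint_sample("my_checkpoint"))        # 1.0 (default)
print(get_checkpoint_sample("my_checkpoint", 0.25))  # 0.25 (argument wins)
print(get_checkpoint_mode("my_checkpoint"))          # CheckpointMode.SCHEMA (default)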
snowflake/snowpark_checkpoints_collector/utils/file_utils.py
@@ -0,0 +1,132 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+import os
+import tempfile
+
+from typing import Optional
+
+from snowflake.snowpark_checkpoints_collector.collection_common import (
+    COLLECTION_RESULT_FILE_NAME,
+    DOT_IPYNB_EXTENSION,
+    SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME,
+    UNKNOWN_LINE_OF_CODE,
+    UNKNOWN_SOURCE_FILE,
+)
+
+
+def get_output_file_path(out_path: Optional[str] = None) -> str:
+    """Get the output file path.
+
+    Args:
+        out_path (Optional[str], optional): the output path. Defaults to None.
+
+    Returns:
+        str: returns the output file path.
+
+    """
+    output_directory_path = get_output_directory_path(out_path)
+    output_file_path = os.path.join(output_directory_path, COLLECTION_RESULT_FILE_NAME)
+    return output_file_path
+
+
+def get_relative_file_path(path: str) -> str:
+    """Get the relative file path.
+
+    Args:
+        path (str): a file path.
+
+    Returns:
+        str: returns the relative file path of the given file.
+
+    """
+    relative_file_path = os.path.relpath(path)
+    return relative_file_path
+
+
+def get_output_directory_path(output_path: Optional[str] = None) -> str:
+    """Get the output directory path.
+
+    Args:
+        output_path (Optional[str], optional): the base output path.
+            Defaults to the current working directory.
+
+    Returns:
+        str: returns the output directory path.
+
+    """
+    current_working_directory_path = output_path if output_path else os.getcwd()
+    checkpoints_output_directory_path = os.path.join(
+        current_working_directory_path, SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME
+    )
+    os.makedirs(checkpoints_output_directory_path, exist_ok=True)
+    return checkpoints_output_directory_path
+
+
+def get_collection_point_source_file_path() -> str:
+    """Get the path of the source file where the collection point is.
+
+    Returns:
+        str: returns the path of the source file where the collection point is.
+
+    """
+    try:
+        collection_point_file_path = inspect.stack()[2].filename
+        is_temporal_file_path = _is_temporal_path(collection_point_file_path)
+        if is_temporal_file_path:
+            ipynb_file_path_collection = _get_ipynb_file_path_collection()
+            if len(ipynb_file_path_collection) == 1:
+                collection_point_file_path = ipynb_file_path_collection[0]
+            else:
+                collection_point_file_path = UNKNOWN_SOURCE_FILE
+
+        return collection_point_file_path
+
+    except Exception:
+        return UNKNOWN_SOURCE_FILE
+
+
+def get_collection_point_line_of_code() -> int:
+    """Get the line of code in the source file where the collection point is.
+
+    Returns:
+        int: returns the line of code of the source file where the collection point is.
+
+    """
+    try:
+        collection_point_file_path = inspect.stack()[2].filename
+        collection_point_line_of_code = inspect.stack()[2].lineno
+        is_temporal_file_path = _is_temporal_path(collection_point_file_path)
+        if is_temporal_file_path:
+            collection_point_line_of_code = UNKNOWN_LINE_OF_CODE
+        return collection_point_line_of_code
+
+    except Exception:
+        return UNKNOWN_LINE_OF_CODE
+
+
+def _is_temporal_path(path: str) -> bool:
+    # Notebook cells execute from files under the system temp directory, so a
+    # path rooted there cannot be reported as the user's source file.
+    temporal_directory_path = tempfile.gettempdir()
+    is_temporal_path = path.startswith(temporal_directory_path)
+    return is_temporal_path
+
+
+def _get_ipynb_file_path_collection() -> list[str]:
+    current_working_directory_path = os.getcwd()
+    cwd_file_name_collection = os.listdir(current_working_directory_path)
+    ipynb_file_path_collection = []
+    for file_name in cwd_file_name_collection:
+        is_ipynb_file = file_name.endswith(DOT_IPYNB_EXTENSION)
+        if is_ipynb_file:
+            file_path = os.path.join(current_working_directory_path, file_name)
+            ipynb_file_path_collection.append(file_path)
+
+    return ipynb_file_path_collection
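
A short sketch of the path helpers above (editorial; the base path is illustrative, and the actual directory and file names come from the SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME and COLLECTION_RESULT_FILE_NAME constants in collection_common):

from snowflake.snowpark_checkpoints_collector.utils import file_utils

# Both helpers create the checkpoints output subdirectory if it is missing.
out_dir = file_utils.get_output_directory_path("/tmp/demo")  # hypothetical base path
result_file = file_utils.get_output_file_path("/tmp/demo")
print(out_dir)      # /tmp/demo/<SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME>
print(result_file)  # <out_dir>/<COLLECTION_RESULT_FILE_NAME>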