snowpark-checkpoints-validators 0.2.0rc1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. snowflake/snowpark_checkpoints/__init__.py +44 -0
  2. snowflake/snowpark_checkpoints/__version__.py +16 -0
  3. snowflake/snowpark_checkpoints/checkpoint.py +580 -0
  4. snowflake/snowpark_checkpoints/errors.py +60 -0
  5. snowflake/snowpark_checkpoints/io_utils/__init__.py +26 -0
  6. snowflake/snowpark_checkpoints/io_utils/io_default_strategy.py +57 -0
  7. snowflake/snowpark_checkpoints/io_utils/io_env_strategy.py +133 -0
  8. snowflake/snowpark_checkpoints/io_utils/io_file_manager.py +76 -0
  9. snowflake/snowpark_checkpoints/job_context.py +128 -0
  10. snowflake/snowpark_checkpoints/singleton.py +23 -0
  11. snowflake/snowpark_checkpoints/snowpark_sampler.py +124 -0
  12. snowflake/snowpark_checkpoints/spark_migration.py +255 -0
  13. snowflake/snowpark_checkpoints/utils/__init__.py +14 -0
  14. snowflake/snowpark_checkpoints/utils/constants.py +134 -0
  15. snowflake/snowpark_checkpoints/utils/extra_config.py +132 -0
  16. snowflake/snowpark_checkpoints/utils/logging_utils.py +67 -0
  17. snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +399 -0
  18. snowflake/snowpark_checkpoints/utils/supported_types.py +65 -0
  19. snowflake/snowpark_checkpoints/utils/telemetry.py +939 -0
  20. snowflake/snowpark_checkpoints/utils/utils_checks.py +398 -0
  21. snowflake/snowpark_checkpoints/validation_result_metadata.py +159 -0
  22. snowflake/snowpark_checkpoints/validation_results.py +49 -0
  23. snowpark_checkpoints_validators-0.3.0.dist-info/METADATA +325 -0
  24. snowpark_checkpoints_validators-0.3.0.dist-info/RECORD +26 -0
  25. snowpark_checkpoints_validators-0.2.0rc1.dist-info/METADATA +0 -514
  26. snowpark_checkpoints_validators-0.2.0rc1.dist-info/RECORD +0 -4
  27. {snowpark_checkpoints_validators-0.2.0rc1.dist-info → snowpark_checkpoints_validators-0.3.0.dist-info}/WHEEL +0 -0
  28. {snowpark_checkpoints_validators-0.2.0rc1.dist-info → snowpark_checkpoints_validators-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,255 @@ snowflake/snowpark_checkpoints/spark_migration.py
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+
+ from typing import Callable, Optional, TypeVar
+
+ import pandas as pd
+
+ from pyspark.sql import DataFrame as SparkDataFrame
+
+ from snowflake.snowpark import DataFrame as SnowparkDataFrame
+ from snowflake.snowpark.types import PandasDataFrame
+ from snowflake.snowpark_checkpoints.errors import SparkMigrationError
+ from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+ from snowflake.snowpark_checkpoints.snowpark_sampler import (
+     SamplingAdapter,
+     SamplingStrategy,
+ )
+ from snowflake.snowpark_checkpoints.utils.constants import FAIL_STATUS, PASS_STATUS
+ from snowflake.snowpark_checkpoints.utils.logging_utils import log
+ from snowflake.snowpark_checkpoints.utils.telemetry import STATUS_KEY, report_telemetry
+ from snowflake.snowpark_checkpoints.utils.utils_checks import (
+     _replace_special_characters,
+     _update_validation_result,
+ )
+
+
+ fn = TypeVar("F", bound=Callable)
+ LOGGER = logging.getLogger(__name__)
+
+
+ @log
+ def check_with_spark(
+     job_context: Optional[SnowparkJobContext],
+     spark_function: fn,
+     checkpoint_name: str,
+     sample_number: Optional[int] = 100,
+     sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
+     output_path: Optional[str] = None,
+ ) -> Callable[[fn], fn]:
+     """Validate function output with Spark instance.
+
+     Will take the input snowpark dataframe of this function, sample data, convert
+     it to a Spark dataframe and then execute `spark_function`. Subsequently
+     the output of that function will be compared to the output of this function
+     for the same sample of data.
+
+     Args:
+         job_context (SnowparkJobContext): The job context containing configuration and details for the validation.
+         spark_function (fn): The equivalent PySpark function to compare against the Snowpark implementation.
+         checkpoint_name (str): A name for the checkpoint. Defaults to None.
+         sample_number (Optional[int], optional): The number of rows for validation. Defaults to 100.
+         sampling_strategy (Optional[SamplingStrategy], optional): The strategy used for sampling data.
+             Defaults to SamplingStrategy.RANDOM_SAMPLE.
+         output_path (Optional[str], optional): The path to store the validation results. Defaults to None.
+
+     Returns:
+         Callable[[fn], fn]: A decorator that wraps the original Snowpark function with validation logic.
+
+     """
+
+     def check_with_spark_decorator(snowpark_fn):
+         @log(log_args=False)
+         def wrapper(*args, **kwargs):
+             LOGGER.info(
+                 "Starting output validation between Snowpark function '%s' and Spark function '%s'",
+                 snowpark_fn.__name__,
+                 spark_function.__name__,
+             )
+             _checkpoint_name = checkpoint_name
+             if checkpoint_name is None:
+                 LOGGER.warning(
+                     "No checkpoint name provided. Using '%s' as the checkpoint name",
+                     snowpark_fn.__name__,
+                 )
+                 _checkpoint_name = snowpark_fn.__name__
+             _checkpoint_name = _replace_special_characters(_checkpoint_name)
+
+             sampler = SamplingAdapter(
+                 job_context,
+                 sample_number=sample_number,
+                 sampling_strategy=sampling_strategy,
+             )
+             sampler.process_args(args)
+             snowpark_sample_args = sampler.get_sampled_snowpark_args()
+             pyspark_sample_args = sampler.get_sampled_spark_args()
+
+             # Run the sampled data in snowpark
+             LOGGER.info("Running the Snowpark function with sampled args")
+             snowpark_test_results = snowpark_fn(*snowpark_sample_args, **kwargs)
+             LOGGER.info("Running the Spark function with sampled args")
+             spark_test_results = spark_function(*pyspark_sample_args, **kwargs)
+
+             LOGGER.info("Comparing the results of the Snowpark and Spark functions")
+             result, exception = _assert_return(
+                 snowpark_test_results,
+                 spark_test_results,
+                 job_context,
+                 _checkpoint_name,
+                 output_path,
+             )
+             if not result:
+                 LOGGER.error(
+                     "Validation failed. The results of the Snowpark function '%s' and Spark function '%s' do not match",
+                     snowpark_fn.__name__,
+                     spark_function.__name__,
+                 )
+                 raise exception from None
+             LOGGER.info(
+                 "Validation passed. The results of the Snowpark function '%s' and Spark function '%s' match",
+                 snowpark_fn.__name__,
+                 spark_function.__name__,
+             )
+
+             # Run the original function in snowpark
+             return snowpark_fn(*args, **kwargs)
+
+         return wrapper
+
+     return check_with_spark_decorator
+
+
+ @report_telemetry(
+     params_list=["snowpark_results", "spark_results"],
+     return_indexes=[(STATUS_KEY, 0)],
+     multiple_return=True,
+ )
+ def _assert_return(
+     snowpark_results, spark_results, job_context, checkpoint_name, output_path=None
+ ) -> tuple[bool, Optional[Exception]]:
+     """Assert and validate the results from Snowpark and Spark transformations.
+
+     Args:
+         snowpark_results (Any): Results from the Snowpark transformation.
+         spark_results (Any): Results from the Spark transformation to compare against.
+         job_context (Any): Additional context about the job. Defaults to None.
+         checkpoint_name (Any): Name of the checkpoint for logging. Defaults to None.
+         output_path (Optional[str], optional): The path to store the validation results. Defaults to None.
+
+     Raises:
+         AssertionError: If the Snowpark and Spark results do not match.
+         TypeError: If the results cannot be compared.
+
+     """
+     if isinstance(snowpark_results, SnowparkDataFrame) and isinstance(
+         spark_results, SparkDataFrame
+     ):
+         LOGGER.debug("Comparing two DataFrame results for equality")
+         cmp = compare_spark_snowpark_dfs(spark_results, snowpark_results)
+
+         if not cmp.empty:
+             exception_result = SparkMigrationError(
+                 "DataFrame difference:\n", job_context, checkpoint_name, cmp
+             )
+             return False, exception_result
+         job_context._mark_pass(checkpoint_name)
+         _update_validation_result(checkpoint_name, PASS_STATUS, output_path)
+         return True, None
+     else:
+         LOGGER.debug("Comparing two scalar results for equality")
+         if snowpark_results != spark_results:
+             exception_result = SparkMigrationError(
+                 "Return value difference:\n",
+                 job_context,
+                 checkpoint_name,
+                 f"{snowpark_results} != {spark_results}",
+             )
+             _update_validation_result(checkpoint_name, FAIL_STATUS, output_path)
+             return False, exception_result
+         job_context._mark_pass(checkpoint_name)
+         _update_validation_result(checkpoint_name, PASS_STATUS, output_path)
+         return True, None
+
+
+ def compare_spark_snowpark_dfs(
+     spark_df: SparkDataFrame, snowpark_df: SnowparkDataFrame
+ ) -> PandasDataFrame:
+     """Compare two dataframes for equality.
+
+     Args:
+         spark_df (SparkDataFrame): The Spark dataframe to compare.
+         snowpark_df (SnowparkDataFrame): The Snowpark dataframe to compare.
+
+     Returns: Pandas DataFrame containing the differences between the two dataframes.
+
+     """
+     snowpark_df = snowpark_df.to_pandas()
+     snowpark_df.columns = snowpark_df.columns.str.upper()
+     spark_df = spark_df.toPandas()
+     spark_df.columns = spark_df.columns.str.upper()
+     spark_cols = set(spark_df.columns)
+     snowpark_cols = set(snowpark_df.columns)
+     cmp = pd.DataFrame([])
+     left = spark_cols - snowpark_cols
+     right = snowpark_cols - spark_cols
+     if left != set():
+         cmp = _compare_dfs(spark_df, snowpark_df, "spark", "snowpark")
+     if right != set():
+         right_cmp = _compare_dfs(snowpark_df, spark_df, "snowpark", "spark")
+         cmp = right_cmp if cmp.empty else pd.concat([cmp, right_cmp], ignore_index=True)
+     if left == set() and right == set():
+         if spark_df.shape == snowpark_df.shape:
+             cmp = spark_df.compare(snowpark_df, result_names=("spark", "snowpark"))
+         else:
+             cmp = spark_df.merge(snowpark_df, indicator=True, how="outer").loc[
+                 lambda x: x["_merge"] != "both"
+             ]
+             cmp = cmp.replace(
+                 {"left_only": "spark_only", "right_only": "snowpark_only"}
+             )
+
+     return cmp
+
+
+ def _compare_dfs(
+     df_a: pd.DataFrame, df_b: pd.DataFrame, left_label: str, right_label: str
+ ) -> PandasDataFrame:
+     """Compare two dataframes for equality.
+
+     Args:
+         df_a (PandasDataFrame): The first dataframe to compare.
+         df_b (PandasDataFrame): The second dataframe to compare.
+         left_label (str): The label for the first dataframe.
+         right_label (str): The label for the second dataframe.
+
+     Returns: Pandas DataFrame containing the differences between the two dataframes.
+
+     """
+     df_a["side"] = "a"
+     df_b["side"] = "b"
+     a_only = [col for col in df_a.columns if col not in df_b.columns] + ["side"]
+     b_only = [col for col in df_b.columns if col not in df_a.columns] + ["side"]
+     cmp = (
+         df_a[a_only]
+         .merge(df_b[b_only], indicator=True, how="left")
+         .loc[lambda x: x["_merge"] != "both"]
+     )
+     cmp = cmp.replace(
+         {"left_only": f"{left_label}_only", "right_only": f"{right_label}_only"}
+     )
+     cmp = cmp.drop(columns="side")
+     return cmp
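
For orientation, a minimal usage sketch of the `check_with_spark` decorator added above. The session setup, the `SnowparkJobContext` constructor arguments, and the table and function names are illustrative assumptions, not API documented in this diff:

    from pyspark.sql import SparkSession
    from snowflake.snowpark import Session
    from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
    from snowflake.snowpark_checkpoints.spark_migration import check_with_spark

    snowpark_session = Session.builder.getOrCreate()
    spark_session = SparkSession.builder.getOrCreate()
    # Assumed constructor shape; see job_context.py in this wheel for the exact signature.
    job_context = SnowparkJobContext(snowpark_session, spark_session, "demo_job")

    def add_total_spark(df):
        # PySpark reference implementation to validate against.
        return df.withColumn("TOTAL", df.PRICE * 2)

    @check_with_spark(
        job_context=job_context,
        spark_function=add_total_spark,
        checkpoint_name="add_total",
        sample_number=50,
    )
    def add_total(df):
        # Snowpark implementation under migration.
        return df.with_column("TOTAL", df["PRICE"] * 2)

    # Samples the input DataFrame, runs both implementations on the sample, and
    # raises SparkMigrationError (via _assert_return) if the sampled outputs differ.
    result = add_total(snowpark_session.table("ORDERS"))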
@@ -0,0 +1,14 @@ snowflake/snowpark_checkpoints/utils/__init__.py
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,134 @@ snowflake/snowpark_checkpoints/utils/constants.py
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Skip type
+ from enum import IntEnum
+ from typing import Final
+
+
+ class CheckpointMode(IntEnum):
+
+     """Enum class representing the validation mode."""
+
+     SCHEMA = 1
+     """Validate against a schema file"""
+     DATAFRAME = 2
+     """Validate against a dataframe"""
+
+
+ # Execution mode
+ SCHEMA_EXECUTION_MODE: Final[str] = "Schema"
+ DATAFRAME_EXECUTION_MODE: Final[str] = "Dataframe"
+
+
+ # File position on stack
+ STACK_POSITION_CHECKPOINT: Final[int] = 6
+
+ # Validation status
+ PASS_STATUS: Final[str] = "PASS"
+ FAIL_STATUS: Final[str] = "FAIL"
+
+ # Validation result keys
+ DEFAULT_KEY: Final[str] = "default"
+
+
+ # Skip type
+ SKIP_ALL: Final[str] = "skip_all"
+
+ # Supported types
+ BOOLEAN_TYPE: Final[str] = "boolean"
+ BINARY_TYPE: Final[str] = "binary"
+ BYTE_TYPE: Final[str] = "byte"
+ CHAR_TYPE: Final[str] = "char"
+ DATE_TYPE: Final[str] = "date"
+ DAYTIMEINTERVAL_TYPE: Final[str] = "daytimeinterval"
+ DECIMAL_TYPE: Final[str] = "decimal"
+ DOUBLE_TYPE: Final[str] = "double"
+ FLOAT_TYPE: Final[str] = "float"
+ INTEGER_TYPE: Final[str] = "integer"
+ LONG_TYPE: Final[str] = "long"
+ SHORT_TYPE: Final[str] = "short"
+ STRING_TYPE: Final[str] = "string"
+ TIMESTAMP_TYPE: Final[str] = "timestamp"
+ TIMESTAMP_NTZ_TYPE: Final[str] = "timestamp_ntz"
+ VARCHAR_TYPE: Final[str] = "varchar"
+
+ # Pandas data types
+ PANDAS_BOOLEAN_DTYPE: Final[str] = "bool"
+ PANDAS_DATETIME_DTYPE: Final[str] = "datetime64[ns]"
+ PANDAS_FLOAT_DTYPE: Final[str] = "float64"
+ PANDAS_INTEGER_DTYPE: Final[str] = "int64"
+ PANDAS_OBJECT_DTYPE: Final[str] = "object"
+ PANDAS_TIMEDELTA_DTYPE: Final[str] = "timedelta64[ns]"
+
+ # Schemas keys
+ COLUMNS_KEY: Final[str] = "columns"
+ COUNT_KEY: Final[str] = "rows_count"
+ DECIMAL_PRECISION_KEY: Final[str] = "decimal_precision"
+ FALSE_COUNT_KEY: Final[str] = "false_count"
+ FORMAT_KEY: Final[str] = "format"
+ NAME_KEY: Final[str] = "name"
+ MARGIN_ERROR_KEY: Final[str] = "margin_error"
+ MAX_KEY: Final[str] = "max"
+ MEAN_KEY: Final[str] = "mean"
+ MIN_KEY: Final[str] = "min"
+ NULL_COUNT_KEY: Final[str] = "rows_null_count"
+ NULLABLE_KEY: Final[str] = "nullable"
+ ROWS_NOT_NULL_COUNT_KEY: Final[str] = "rows_not_null_count"
+ TRUE_COUNT_KEY: Final[str] = "true_count"
+ TYPE_KEY: Final[str] = "type"
+ ROWS_COUNT_KEY: Final[str] = "rows_count"
+ FORMAT_KEY: Final[str] = "format"
+
+ DATAFRAME_CUSTOM_DATA_KEY: Final[str] = "custom_data"
+ DATAFRAME_PANDERA_SCHEMA_KEY: Final[str] = "pandera_schema"
+
+ # Default values
+ DEFAULT_DATE_FORMAT: Final[str] = "%Y-%m-%d"
+
+ # SQL Column names
+ TABLE_NAME_COL: Final[str] = "TABLE_NAME"
+ CREATED_COL: Final[str] = "CREATED"
+
+ # SQL Table names
+ INFORMATION_SCHEMA_TABLE_NAME: Final[str] = "INFORMATION_SCHEMA"
+ TABLES_TABLE_NAME: Final[str] = "TABLES"
+
+ # SQL Query
+ EXCEPT_HASH_AGG_QUERY: Final[
+     str
+ ] = "SELECT HASH_AGG(*) FROM IDENTIFIER(:1) EXCEPT SELECT HASH_AGG(*) FROM IDENTIFIER(:2)"
+
+ # Table checkpoints name
+ CHECKPOINT_TABLE_NAME_FORMAT: Final[str] = "{}_snowpark"
+
+ # Write mode
+ OVERWRITE_MODE: Final[str] = "overwrite"
+
+ # Validation modes
+ VALIDATION_MODE_KEY: Final[str] = "validation_mode"
+ PIPELINES_KEY: Final[str] = "pipelines"
+
+ # File name
+ CHECKPOINT_JSON_OUTPUT_FILE_FORMAT_NAME: Final[str] = "{}.json"
+ CHECKPOINTS_JSON_FILE_NAME: Final[str] = "checkpoints.json"
+ SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME: Final[str] = "snowpark-checkpoints-output"
+ CHECKPOINT_PARQUET_OUTPUT_FILE_FORMAT_NAME: Final[str] = "{}.parquet"
+ VALIDATION_RESULTS_JSON_FILE_NAME: Final[str] = "checkpoint_validation_results.json"
+
+ # Environment variables
+ SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR: Final[
+     str
+ ] = "SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH"
@@ -0,0 +1,132 @@ snowflake/snowpark_checkpoints/utils/extra_config.py
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import os
+
+ from typing import Optional
+
+ from snowflake.snowpark_checkpoints.io_utils.io_file_manager import get_io_file_manager
+ from snowflake.snowpark_checkpoints.utils.constants import (
+     SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR,
+ )
+
+
+ LOGGER = logging.getLogger(__name__)
+
+
+ # noinspection DuplicatedCode
+ def _get_checkpoint_contract_file_path() -> str:
+     return os.environ.get(
+         SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR, get_io_file_manager().getcwd()
+     )
+
+
+ def _set_conf_io_strategy() -> None:
+     try:
+         from snowflake.snowpark_checkpoints.io_utils.io_default_strategy import (
+             IODefaultStrategy,
+         )
+         from snowflake.snowpark_checkpoints_configuration.io_utils.io_file_manager import (
+             EnvStrategy as ConfEnvStrategy,
+         )
+         from snowflake.snowpark_checkpoints_configuration.io_utils.io_file_manager import (
+             get_io_file_manager as get_conf_io_file_manager,
+         )
+
+         is_default_strategy = isinstance(
+             get_io_file_manager().strategy, IODefaultStrategy
+         )
+
+         if is_default_strategy:
+             return
+
+         class CustomConfEnvStrategy(ConfEnvStrategy):
+             def file_exists(self, path: str) -> bool:
+                 return get_io_file_manager().file_exists(path)
+
+             def read(
+                 self, file_path: str, mode: str = "r", encoding: Optional[str] = None
+             ) -> Optional[str]:
+                 return get_io_file_manager().read(file_path, mode, encoding)
+
+             def getcwd(self) -> str:
+                 return get_io_file_manager().getcwd()
+
+         get_conf_io_file_manager().set_strategy(CustomConfEnvStrategy())
+
+     except ImportError:
+         LOGGER.debug(
+             "snowpark-checkpoints-configuration is not installed. Cannot get a checkpoint metadata instance."
+         )
+
+
+ # noinspection DuplicatedCode
+ def _get_metadata():
+     try:
+         from snowflake.snowpark_checkpoints_configuration.checkpoint_metadata import (
+             CheckpointMetadata,
+         )
+
+         path = _get_checkpoint_contract_file_path()
+         _set_conf_io_strategy()
+         LOGGER.debug("Loading checkpoint metadata from '%s'", path)
+         metadata = CheckpointMetadata(path)
+         return True, metadata
+
+     except ImportError:
+         LOGGER.debug(
+             "snowpark-checkpoints-configuration is not installed. Cannot get a checkpoint metadata instance."
+         )
+         return False, None
+
+
+ def is_checkpoint_enabled(checkpoint_name: Optional[str] = None) -> bool:
+     """Check if a checkpoint is enabled.
+
+     Args:
+         checkpoint_name (Optional[str], optional): The name of the checkpoint.
+
+     Returns:
+         bool: True if the checkpoint is enabled, False otherwise.
+
+     """
+     enabled, metadata = _get_metadata()
+     if enabled and checkpoint_name is not None:
+         config = metadata.get_checkpoint(checkpoint_name)
+         return config.enabled
+     return True
+
+
+ def get_checkpoint_file(checkpoint_name: str) -> Optional[str]:
+     """Retrieve the configuration for a specified checkpoint.
+
+     This function fetches the checkpoint configuration if metadata is enabled.
+     It extracts the file name from the checkpoint metadata or
+     from the call stack if not explicitly provided in the metadata.
+
+     Args:
+         checkpoint_name (str): The name of the checkpoint to retrieve the configuration for.
+
+     Returns:
+         Optional[str]: The file name recorded for the checkpoint,
+         or None if metadata is not enabled.
+
+     """
+     enabled, metadata = _get_metadata()
+     if enabled:
+         config = metadata.get_checkpoint(checkpoint_name)
+         return config.file
+     return None
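
A short sketch of the fallback behavior coded above, assuming the optional snowpark-checkpoints-configuration package is not installed (the directory path is illustrative):

    import os

    from snowflake.snowpark_checkpoints.utils.constants import (
        SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR,
    )
    from snowflake.snowpark_checkpoints.utils.extra_config import (
        get_checkpoint_file,
        is_checkpoint_enabled,
    )

    # Point metadata loading at a specific directory instead of the current one.
    os.environ[SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR] = "/tmp/my_project"

    # Without the configuration package, _get_metadata() returns (False, None):
    # checkpoints default to enabled and no contract file is resolved.
    assert is_checkpoint_enabled("my_checkpoint") is True
    assert get_checkpoint_file("my_checkpoint") is None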
@@ -0,0 +1,67 @@ snowflake/snowpark_checkpoints/utils/logging_utils.py
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+
+ from functools import wraps
+ from typing import Callable, Optional, TypeVar
+
+ from typing_extensions import ParamSpec
+
+
+ P = ParamSpec("P")
+ R = TypeVar("R")
+
+
+ def log(
+     _func: Optional[Callable[P, R]] = None,
+     *,
+     logger: Optional[logging.Logger] = None,
+     log_args: bool = True,
+ ) -> Callable[[Callable[P, R]], Callable[P, R]]:
+     """Log the function call and any exceptions that occur.
+
+     Args:
+         _func: The function to log.
+         logger: The logger to use for logging. If not provided, a logger will be created using the
+             function's module name.
+         log_args: Whether to log the arguments passed to the function.
+
+     Returns:
+         A decorator that logs the function call and any exceptions that occur.
+
+     """
+
+     def decorator(func: Callable[P, R]) -> Callable[P, R]:
+         @wraps(func)
+         def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+             _logger = logging.getLogger(func.__module__) if logger is None else logger
+             if log_args:
+                 args_repr = [repr(a) for a in args]
+                 kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]
+                 formatted_args = ", ".join([*args_repr, *kwargs_repr])
+                 _logger.debug("%s called with args %s", func.__name__, formatted_args)
+             try:
+                 return func(*args, **kwargs)
+             except Exception:
+                 _logger.exception("An error occurred in %s", func.__name__)
+                 raise
+
+         return wrapper
+
+     # Handle the case where the decorator is used without parentheses
+     if _func is None:
+         return decorator
+     return decorator(_func)
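
Both decorator forms handled above can be used as follows; a minimal sketch with illustrative function names:

    import logging

    from snowflake.snowpark_checkpoints.utils.logging_utils import log

    logging.basicConfig(level=logging.DEBUG)

    @log  # bare form: logger derived from the function's module, arguments logged
    def add(a, b):
        return a + b

    @log(log_args=False)  # parenthesized form: argument logging suppressed
    def check_token(token):
        return len(token)

    add(1, 2)               # DEBUG: "add called with args 1, 2"
    check_token("hunter2")  # nothing logged on success; the argument is never recorded

Exceptions raised inside a decorated function are logged via `_logger.exception` and re-raised unchanged.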