snowpark-checkpoints-validators 0.2.0rc1__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_checkpoints/__init__.py +44 -0
- snowflake/snowpark_checkpoints/__version__.py +16 -0
- snowflake/snowpark_checkpoints/checkpoint.py +580 -0
- snowflake/snowpark_checkpoints/errors.py +60 -0
- snowflake/snowpark_checkpoints/job_context.py +128 -0
- snowflake/snowpark_checkpoints/singleton.py +23 -0
- snowflake/snowpark_checkpoints/snowpark_sampler.py +124 -0
- snowflake/snowpark_checkpoints/spark_migration.py +255 -0
- snowflake/snowpark_checkpoints/utils/__init__.py +14 -0
- snowflake/snowpark_checkpoints/utils/constants.py +134 -0
- snowflake/snowpark_checkpoints/utils/extra_config.py +89 -0
- snowflake/snowpark_checkpoints/utils/logging_utils.py +67 -0
- snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +399 -0
- snowflake/snowpark_checkpoints/utils/supported_types.py +65 -0
- snowflake/snowpark_checkpoints/utils/telemetry.py +900 -0
- snowflake/snowpark_checkpoints/utils/utils_checks.py +395 -0
- snowflake/snowpark_checkpoints/validation_result_metadata.py +155 -0
- snowflake/snowpark_checkpoints/validation_results.py +49 -0
- snowpark_checkpoints_validators-0.2.1.dist-info/METADATA +323 -0
- snowpark_checkpoints_validators-0.2.1.dist-info/RECORD +22 -0
- snowpark_checkpoints_validators-0.2.0rc1.dist-info/METADATA +0 -514
- snowpark_checkpoints_validators-0.2.0rc1.dist-info/RECORD +0 -4
- {snowpark_checkpoints_validators-0.2.0rc1.dist-info → snowpark_checkpoints_validators-0.2.1.dist-info}/WHEEL +0 -0
- {snowpark_checkpoints_validators-0.2.0rc1.dist-info → snowpark_checkpoints_validators-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,44 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import logging
|
17
|
+
|
18
|
+
|
19
|
+
# Add a NullHandler to prevent logging messages from being output to
|
20
|
+
# sys.stderr if no logging configuration is provided.
|
21
|
+
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
22
|
+
|
23
|
+
# ruff: noqa: E402
|
24
|
+
|
25
|
+
from snowflake.snowpark_checkpoints.checkpoint import (
|
26
|
+
check_dataframe_schema,
|
27
|
+
check_input_schema,
|
28
|
+
check_output_schema,
|
29
|
+
validate_dataframe_checkpoint,
|
30
|
+
)
|
31
|
+
from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
|
32
|
+
from snowflake.snowpark_checkpoints.spark_migration import check_with_spark
|
33
|
+
from snowflake.snowpark_checkpoints.utils.constants import CheckpointMode
|
34
|
+
|
35
|
+
|
36
|
+
__all__ = [
|
37
|
+
"check_with_spark",
|
38
|
+
"SnowparkJobContext",
|
39
|
+
"check_dataframe_schema",
|
40
|
+
"check_output_schema",
|
41
|
+
"check_input_schema",
|
42
|
+
"validate_dataframe_checkpoint",
|
43
|
+
"CheckpointMode",
|
44
|
+
]
|
@@ -0,0 +1,16 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
__version__ = "0.2.1"
|
@@ -0,0 +1,580 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
# Wrapper around pandera which logs to snowflake
|
17
|
+
|
18
|
+
import logging
|
19
|
+
|
20
|
+
from typing import Any, Optional, Union, cast
|
21
|
+
|
22
|
+
from pandas import DataFrame as PandasDataFrame
|
23
|
+
from pandera import Check, DataFrameModel, DataFrameSchema
|
24
|
+
from pandera.errors import SchemaError, SchemaErrors
|
25
|
+
|
26
|
+
from snowflake.snowpark import DataFrame as SnowparkDataFrame
|
27
|
+
from snowflake.snowpark_checkpoints.errors import SchemaValidationError
|
28
|
+
from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
|
29
|
+
from snowflake.snowpark_checkpoints.snowpark_sampler import (
|
30
|
+
SamplingAdapter,
|
31
|
+
SamplingStrategy,
|
32
|
+
)
|
33
|
+
from snowflake.snowpark_checkpoints.utils.constants import (
|
34
|
+
FAIL_STATUS,
|
35
|
+
PASS_STATUS,
|
36
|
+
CheckpointMode,
|
37
|
+
)
|
38
|
+
from snowflake.snowpark_checkpoints.utils.extra_config import is_checkpoint_enabled
|
39
|
+
from snowflake.snowpark_checkpoints.utils.logging_utils import log
|
40
|
+
from snowflake.snowpark_checkpoints.utils.pandera_check_manager import (
|
41
|
+
PanderaCheckManager,
|
42
|
+
)
|
43
|
+
from snowflake.snowpark_checkpoints.utils.telemetry import STATUS_KEY, report_telemetry
|
44
|
+
from snowflake.snowpark_checkpoints.utils.utils_checks import (
|
45
|
+
_check_compare_data,
|
46
|
+
_generate_schema,
|
47
|
+
_process_sampling,
|
48
|
+
_replace_special_characters,
|
49
|
+
_update_validation_result,
|
50
|
+
)
|
51
|
+
|
52
|
+
|
53
|
+
LOGGER = logging.getLogger(__name__)
|
54
|
+
|
55
|
+
|
56
|
+
@log
def validate_dataframe_checkpoint(
    df: SnowparkDataFrame,
    checkpoint_name: str,
    job_context: Optional[SnowparkJobContext] = None,
    mode: Optional[CheckpointMode] = CheckpointMode.SCHEMA,
    custom_checks: Optional[dict[Any, Any]] = None,
    skip_checks: Optional[dict[Any, Any]] = None,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    output_path: Optional[str] = None,
) -> Union[tuple[bool, PandasDataFrame], None]:
    """Validate a Snowpark DataFrame against a specified checkpoint.

    Dispatches to schema validation (SCHEMA mode) or full data comparison
    (DATAFRAME mode). Returns early without validating when the checkpoint
    is disabled in the configuration.

    Args:
        df (SnowparkDataFrame): The DataFrame to validate.
        checkpoint_name (str): The name of the checkpoint to validate against.
        job_context (SnowparkJobContext, optional): The job context for the validation.
            Required when mode is CheckpointMode.DATAFRAME.
        mode (CheckpointMode): The mode of validation (SCHEMA or DATAFRAME). Defaults to SCHEMA.
        custom_checks (Optional[dict[Any, Any]], optional): Custom checks to apply during validation.
        skip_checks (Optional[dict[Any, Any]], optional): Checks to skip during validation.
        sample_frac (Optional[float], optional): Fraction of the DataFrame to sample for validation.
            Defaults to 1.0.
        sample_number (Optional[int], optional): Number of rows to sample for validation.
        sampling_strategy (Optional[SamplingStrategy], optional): Strategy to use for sampling.
            Defaults to RANDOM_SAMPLE.
        output_path (Optional[str], optional): The output path for the validation results.

    Returns:
        Union[tuple[bool, PandasDataFrame], None]: A tuple containing a boolean indicating success
            and a Pandas DataFrame with validation results, or None if validation is not applicable
            (checkpoint disabled, or DATAFRAME mode which records results as a side effect).

    Raises:
        ValueError: If an invalid validation mode is provided or if job_context is None
            for DATAFRAME mode.

    """
    # Checkpoint names are normalized so they can be used as file/table identifiers.
    checkpoint_name = _replace_special_characters(checkpoint_name)

    if not is_checkpoint_enabled(checkpoint_name):
        LOGGER.warning(
            "Checkpoint '%s' is disabled. Skipping DataFrame checkpoint validation.",
            checkpoint_name,
        )
        return None

    LOGGER.info(
        "Starting DataFrame checkpoint validation for checkpoint '%s'", checkpoint_name
    )

    if mode == CheckpointMode.SCHEMA:
        result = _check_dataframe_schema_file(
            df,
            checkpoint_name,
            job_context,
            custom_checks,
            skip_checks,
            sample_frac,
            sample_number,
            sampling_strategy,
            output_path,
        )
        return result

    if mode == CheckpointMode.DATAFRAME:
        # Full data comparison runs inside Snowflake and needs a session/job context.
        if job_context is None:
            raise ValueError(
                "No job context provided. Please provide one when using DataFrame mode validation."
            )
        _check_compare_data(df, job_context, checkpoint_name, output_path)
        return None

    raise ValueError(
        (
            "Invalid validation mode. "
            "Please use 1 for schema validation or 2 for full data validation."
        ),
    )
|
133
|
+
|
134
|
+
|
135
|
+
def _check_dataframe_schema_file(
    df: SnowparkDataFrame,
    checkpoint_name: str,
    job_context: Optional[SnowparkJobContext] = None,
    custom_checks: Optional[dict[Any, Any]] = None,
    skip_checks: Optional[dict[Any, Any]] = None,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    output_path: Optional[str] = None,
) -> tuple[bool, PandasDataFrame]:
    """Load the schema for a checkpoint from file and validate the DataFrame against it.

    Args:
        df (SnowparkDataFrame): The DataFrame to be validated.
        checkpoint_name (str): The name of the checkpoint to retrieve the schema.
        job_context (SnowparkJobContext, optional): Context for job-related operations.
            Defaults to None.
        custom_checks (dict[Any, Any], optional): Custom checks to be added to the schema.
            Defaults to None.
        skip_checks (dict[Any, Any], optional): Checks to be skipped.
            Defaults to None.
        sample_frac (float, optional): Fraction of data to sample.
            Defaults to 1.0.
        sample_number (int, optional): Number of rows to sample.
            Defaults to None.
        sampling_strategy (SamplingStrategy, optional): Strategy for sampling data.
            Defaults to SamplingStrategy.RANDOM_SAMPLE.
        output_path (str, optional): The output path for the validation results.

    Raises:
        ValueError: If df or checkpoint_name is None.
        SchemaValidationError: If the DataFrame fails schema validation.

    Returns:
        tuple[bool, PandasDataFrame]: A tuple containing the validity flag and the Pandas DataFrame.

    """
    if df is None:
        raise ValueError("DataFrame is required")

    if checkpoint_name is None:
        raise ValueError("Checkpoint name is required")

    # Builds the Pandera schema from the JSON file previously collected for this checkpoint.
    schema = _generate_schema(checkpoint_name, output_path)

    return _check_dataframe_schema(
        df,
        schema,
        checkpoint_name,
        job_context,
        custom_checks,
        skip_checks,
        sample_frac,
        sample_number,
        sampling_strategy,
        output_path,
    )
|
192
|
+
|
193
|
+
|
194
|
+
@log
def check_dataframe_schema(
    df: SnowparkDataFrame,
    pandera_schema: DataFrameSchema,
    checkpoint_name: str,
    job_context: Optional[SnowparkJobContext] = None,
    custom_checks: Optional[dict[str, list[Check]]] = None,
    skip_checks: Optional[dict[Any, Any]] = None,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    output_path: Optional[str] = None,
) -> Union[tuple[bool, PandasDataFrame], None]:
    """Validate a DataFrame against a given Pandera schema using sampling techniques.

    Args:
        df (SnowparkDataFrame): The DataFrame to be validated.
        pandera_schema (DataFrameSchema): The Pandera schema to validate against.
        checkpoint_name (str): The name of the checkpoint used to record the result.
        job_context (SnowparkJobContext, optional): Context for job-related operations.
            Defaults to None.
        custom_checks (dict[str, list[Check]], optional): Custom checks to be added to the schema.
            Defaults to None.
        skip_checks (dict[Any, Any], optional): Checks to be skipped.
            Defaults to None.
        sample_frac (float, optional): Fraction of data to sample.
            Defaults to 1.0.
        sample_number (int, optional): Number of rows to sample.
            Defaults to None.
        sampling_strategy (SamplingStrategy, optional): Strategy for sampling data.
            Defaults to SamplingStrategy.RANDOM_SAMPLE.
        output_path (str, optional): The output path for the validation results.

    Raises:
        ValueError: If df or pandera_schema is None.
        SchemaValidationError: If the DataFrame fails schema validation.

    Returns:
        Union[tuple[bool, PandasDataFrame], None]: A tuple containing the validity flag
            and the Pandas DataFrame. If the validation for that checkpoint is disabled
            it returns None.

    """
    checkpoint_name = _replace_special_characters(checkpoint_name)
    LOGGER.info(
        "Starting DataFrame schema validation for checkpoint '%s'", checkpoint_name
    )

    if df is None:
        raise ValueError("DataFrame is required")

    if pandera_schema is None:
        raise ValueError("Schema is required")

    if not is_checkpoint_enabled(checkpoint_name):
        LOGGER.warning(
            "Checkpoint '%s' is disabled. Skipping DataFrame schema validation.",
            checkpoint_name,
        )
        return None

    return _check_dataframe_schema(
        df,
        pandera_schema,
        checkpoint_name,
        job_context,
        custom_checks,
        skip_checks,
        sample_frac,
        sample_number,
        sampling_strategy,
        output_path,
    )
|
266
|
+
|
267
|
+
|
268
|
+
@report_telemetry(
    params_list=["pandera_schema"],
    return_indexes=[(STATUS_KEY, 0)],
    multiple_return=True,
)
def _check_dataframe_schema(
    df: SnowparkDataFrame,
    pandera_schema: DataFrameSchema,
    checkpoint_name: str,
    job_context: Optional[SnowparkJobContext] = None,
    custom_checks: Optional[dict[str, list[Check]]] = None,
    skip_checks: Optional[dict[Any, Any]] = None,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    output_path: Optional[str] = None,
) -> tuple[bool, PandasDataFrame]:
    """Apply skip/custom checks, sample the DataFrame, and run Pandera validation.

    On success, marks the checkpoint as passed in the job context (when one is
    provided) and records a PASS result. On failure, records a FAIL result and
    raises SchemaValidationError.

    Args:
        df (SnowparkDataFrame): The DataFrame to be validated.
        pandera_schema (DataFrameSchema): The Pandera schema to validate against.
        checkpoint_name (str): The name of the checkpoint used to record the result.
        job_context (Optional[SnowparkJobContext]): Context for job-related operations.
        custom_checks (Optional[dict[str, list[Check]]]): Custom checks added to the schema.
        skip_checks (Optional[dict[Any, Any]]): Checks to be skipped.
        sample_frac (Optional[float]): Fraction of data to sample. Defaults to 1.0.
        sample_number (Optional[int]): Number of rows to sample.
        sampling_strategy (Optional[SamplingStrategy]): Strategy for sampling data.
        output_path (Optional[str]): The output path for the validation results.

    Raises:
        SchemaValidationError: If the sampled DataFrame fails schema validation.

    Returns:
        tuple[bool, PandasDataFrame]: The validity flag and either the validated
            data or the Pandera failure cases.

    """
    # Mutate the schema in place: drop the skipped checks, then attach the custom ones.
    pandera_check_manager = PanderaCheckManager(checkpoint_name, pandera_schema)
    pandera_check_manager.skip_checks_on_schema(skip_checks)
    pandera_check_manager.add_custom_checks(custom_checks)

    pandera_schema_upper, sample_df = _process_sampling(
        df, pandera_schema, job_context, sample_frac, sample_number, sampling_strategy
    )
    is_valid, validation_result = _validate(pandera_schema_upper, sample_df)
    if is_valid:
        LOGGER.info(
            "DataFrame schema validation passed for checkpoint '%s'",
            checkpoint_name,
        )
        if job_context is not None:
            job_context._mark_pass(checkpoint_name)
        else:
            LOGGER.warning(
                "No job context provided. Skipping result recording into Snowflake.",
            )
        _update_validation_result(checkpoint_name, PASS_STATUS, output_path)
    else:
        LOGGER.error(
            "DataFrame schema validation failed for checkpoint '%s'",
            checkpoint_name,
        )
        _update_validation_result(checkpoint_name, FAIL_STATUS, output_path)
        raise SchemaValidationError(
            "Snowpark DataFrame schema validation error",
            job_context,
            checkpoint_name,
            validation_result,
        )

    return (is_valid, validation_result)
|
320
|
+
|
321
|
+
|
322
|
+
@report_telemetry(params_list=["pandera_schema"])
|
323
|
+
@log
|
324
|
+
def check_output_schema(
|
325
|
+
pandera_schema: DataFrameSchema,
|
326
|
+
checkpoint_name: str,
|
327
|
+
sample_frac: Optional[float] = 1.0,
|
328
|
+
sample_number: Optional[int] = None,
|
329
|
+
sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
|
330
|
+
job_context: Optional[SnowparkJobContext] = None,
|
331
|
+
output_path: Optional[str] = None,
|
332
|
+
):
|
333
|
+
"""Decorate to validate the schema of the output of a Snowpark function.
|
334
|
+
|
335
|
+
Args:
|
336
|
+
pandera_schema (DataFrameSchema): The Pandera schema to validate against.
|
337
|
+
checkpoint_name (Optional[str], optional): The name of the checkpoint to retrieve the schema.
|
338
|
+
sample_frac (Optional[float], optional): Fraction of data to sample.
|
339
|
+
Defaults to 0.1.
|
340
|
+
sample_number (Optional[int], optional): Number of rows to sample.
|
341
|
+
Defaults to None.
|
342
|
+
sampling_strategy (Optional[SamplingStrategy], optional): Strategy for sampling data.
|
343
|
+
Defaults to SamplingStrategy.RANDOM_SAMPLE.
|
344
|
+
job_context (SnowparkJobContext, optional): Context for job-related operations.
|
345
|
+
Defaults to None.
|
346
|
+
output_path (Optional[str], optional): The output path for the validation results.
|
347
|
+
|
348
|
+
"""
|
349
|
+
|
350
|
+
def check_output_with_decorator(snowpark_fn):
|
351
|
+
"""Decorate to validate the schema of the output of a Snowpark function.
|
352
|
+
|
353
|
+
Args:
|
354
|
+
snowpark_fn (function): The Snowpark function to validate.
|
355
|
+
|
356
|
+
Returns:
|
357
|
+
function: The decorated function.
|
358
|
+
|
359
|
+
"""
|
360
|
+
|
361
|
+
@log(log_args=False)
|
362
|
+
def wrapper(*args, **kwargs):
|
363
|
+
"""Wrapp a function to validate the schema of the output of a Snowpark function.
|
364
|
+
|
365
|
+
Args:
|
366
|
+
*args: The arguments to the Snowpark function.
|
367
|
+
**kwargs: The keyword arguments to the Snowpark function.
|
368
|
+
|
369
|
+
Returns:
|
370
|
+
Any: The result of the Snowpark function.
|
371
|
+
|
372
|
+
"""
|
373
|
+
_checkpoint_name = checkpoint_name
|
374
|
+
if checkpoint_name is None:
|
375
|
+
LOGGER.warning(
|
376
|
+
(
|
377
|
+
"No checkpoint name provided for output schema validation. "
|
378
|
+
"Using '%s' as the checkpoint name.",
|
379
|
+
),
|
380
|
+
snowpark_fn.__name__,
|
381
|
+
)
|
382
|
+
_checkpoint_name = snowpark_fn.__name__
|
383
|
+
_checkpoint_name = _replace_special_characters(_checkpoint_name)
|
384
|
+
LOGGER.info(
|
385
|
+
"Starting output schema validation for Snowpark function '%s' and checkpoint '%s'",
|
386
|
+
snowpark_fn.__name__,
|
387
|
+
_checkpoint_name,
|
388
|
+
)
|
389
|
+
|
390
|
+
# Run the sampled data in snowpark
|
391
|
+
LOGGER.info("Running the Snowpark function '%s'", snowpark_fn.__name__)
|
392
|
+
snowpark_results = snowpark_fn(*args, **kwargs)
|
393
|
+
sampler = SamplingAdapter(
|
394
|
+
job_context, sample_frac, sample_number, sampling_strategy
|
395
|
+
)
|
396
|
+
sampler.process_args([snowpark_results])
|
397
|
+
pandas_sample_args = sampler.get_sampled_pandas_args()
|
398
|
+
|
399
|
+
is_valid, validation_result = _validate(
|
400
|
+
pandera_schema, pandas_sample_args[0]
|
401
|
+
)
|
402
|
+
if is_valid:
|
403
|
+
LOGGER.info(
|
404
|
+
"Output schema validation passed for Snowpark function '%s' and checkpoint '%s'",
|
405
|
+
snowpark_fn.__name__,
|
406
|
+
_checkpoint_name,
|
407
|
+
)
|
408
|
+
if job_context is not None:
|
409
|
+
job_context._mark_pass(_checkpoint_name)
|
410
|
+
else:
|
411
|
+
LOGGER.warning(
|
412
|
+
"No job context provided. Skipping result recording into Snowflake.",
|
413
|
+
)
|
414
|
+
_update_validation_result(_checkpoint_name, PASS_STATUS, output_path)
|
415
|
+
else:
|
416
|
+
LOGGER.error(
|
417
|
+
"Output schema validation failed for Snowpark function '%s' and checkpoint '%s'",
|
418
|
+
snowpark_fn.__name__,
|
419
|
+
_checkpoint_name,
|
420
|
+
)
|
421
|
+
_update_validation_result(_checkpoint_name, FAIL_STATUS, output_path)
|
422
|
+
raise SchemaValidationError(
|
423
|
+
"Snowpark output schema validation error",
|
424
|
+
job_context,
|
425
|
+
_checkpoint_name,
|
426
|
+
validation_result,
|
427
|
+
)
|
428
|
+
return snowpark_results
|
429
|
+
|
430
|
+
return wrapper
|
431
|
+
|
432
|
+
return check_output_with_decorator
|
433
|
+
|
434
|
+
|
435
|
+
@report_telemetry(params_list=["pandera_schema"])
|
436
|
+
@log
|
437
|
+
def check_input_schema(
|
438
|
+
pandera_schema: DataFrameSchema,
|
439
|
+
checkpoint_name: str,
|
440
|
+
sample_frac: Optional[float] = 1.0,
|
441
|
+
sample_number: Optional[int] = None,
|
442
|
+
sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
|
443
|
+
job_context: Optional[SnowparkJobContext] = None,
|
444
|
+
output_path: Optional[str] = None,
|
445
|
+
):
|
446
|
+
"""Decorate factory for validating input DataFrame schemas before function execution.
|
447
|
+
|
448
|
+
Args:
|
449
|
+
pandera_schema (DataFrameSchema): The Pandera schema to validate against.
|
450
|
+
checkpoint_name (Optional[str], optional): The name of the checkpoint to retrieve the schema.
|
451
|
+
sample_frac (Optional[float], optional): Fraction of data to sample.
|
452
|
+
Defaults to 0.1.
|
453
|
+
sample_number (Optional[int], optional): Number of rows to sample.
|
454
|
+
Defaults to None.
|
455
|
+
sampling_strategy (Optional[SamplingStrategy], optional): Strategy for sampling data.
|
456
|
+
Defaults to SamplingStrategy.RANDOM_SAMPLE.
|
457
|
+
job_context (SnowparkJobContext, optional): Context for job-related operations.
|
458
|
+
Defaults to None.
|
459
|
+
output_path (Optional[str], optional): The output path for the validation results.
|
460
|
+
|
461
|
+
|
462
|
+
"""
|
463
|
+
|
464
|
+
def check_input_with_decorator(snowpark_fn):
|
465
|
+
"""Decorate that validates input schemas for the decorated function.
|
466
|
+
|
467
|
+
Args:
|
468
|
+
snowpark_fn (Callable): The function to be decorated with input schema validation.
|
469
|
+
|
470
|
+
Raises:
|
471
|
+
SchemaValidationError: If input data fails schema validation.
|
472
|
+
|
473
|
+
Returns:
|
474
|
+
Callable: A wrapper function that performs schema validation before executing the original function.
|
475
|
+
|
476
|
+
"""
|
477
|
+
|
478
|
+
@log(log_args=False)
|
479
|
+
def wrapper(*args, **kwargs):
|
480
|
+
"""Wrapp a function to validate the schema of the input of a Snowpark function.
|
481
|
+
|
482
|
+
Raises:
|
483
|
+
SchemaValidationError: If any input DataFrame fails schema validation.
|
484
|
+
|
485
|
+
Returns:
|
486
|
+
Any: The result of the original function after input validation.
|
487
|
+
|
488
|
+
"""
|
489
|
+
_checkpoint_name = checkpoint_name
|
490
|
+
if checkpoint_name is None:
|
491
|
+
LOGGER.warning(
|
492
|
+
(
|
493
|
+
"No checkpoint name provided for input schema validation. "
|
494
|
+
"Using '%s' as the checkpoint name."
|
495
|
+
),
|
496
|
+
snowpark_fn.__name__,
|
497
|
+
)
|
498
|
+
_checkpoint_name = snowpark_fn.__name__
|
499
|
+
_checkpoint_name = _replace_special_characters(_checkpoint_name)
|
500
|
+
LOGGER.info(
|
501
|
+
"Starting input schema validation for Snowpark function '%s' and checkpoint '%s'",
|
502
|
+
snowpark_fn.__name__,
|
503
|
+
_checkpoint_name,
|
504
|
+
)
|
505
|
+
|
506
|
+
# Run the sampled data in snowpark
|
507
|
+
sampler = SamplingAdapter(
|
508
|
+
job_context, sample_frac, sample_number, sampling_strategy
|
509
|
+
)
|
510
|
+
sampler.process_args(args)
|
511
|
+
pandas_sample_args = sampler.get_sampled_pandas_args()
|
512
|
+
|
513
|
+
LOGGER.info(
|
514
|
+
"Validating %s input argument(s) against a Pandera schema",
|
515
|
+
len(pandas_sample_args),
|
516
|
+
)
|
517
|
+
# Raises SchemaError on validation issues
|
518
|
+
for index, arg in enumerate(pandas_sample_args, start=1):
|
519
|
+
if not isinstance(arg, PandasDataFrame):
|
520
|
+
LOGGER.info(
|
521
|
+
"Arg %s: Skipping schema validation for non-DataFrame argument",
|
522
|
+
index,
|
523
|
+
)
|
524
|
+
continue
|
525
|
+
|
526
|
+
is_valid, validation_result = _validate(
|
527
|
+
pandera_schema,
|
528
|
+
arg,
|
529
|
+
)
|
530
|
+
if is_valid:
|
531
|
+
LOGGER.info(
|
532
|
+
"Arg %s: Input schema validation passed",
|
533
|
+
index,
|
534
|
+
)
|
535
|
+
if job_context is not None:
|
536
|
+
job_context._mark_pass(
|
537
|
+
_checkpoint_name,
|
538
|
+
)
|
539
|
+
_update_validation_result(
|
540
|
+
_checkpoint_name,
|
541
|
+
PASS_STATUS,
|
542
|
+
output_path,
|
543
|
+
)
|
544
|
+
else:
|
545
|
+
LOGGER.error(
|
546
|
+
"Arg %s: Input schema validation failed",
|
547
|
+
index,
|
548
|
+
)
|
549
|
+
_update_validation_result(
|
550
|
+
_checkpoint_name,
|
551
|
+
FAIL_STATUS,
|
552
|
+
output_path,
|
553
|
+
)
|
554
|
+
raise SchemaValidationError(
|
555
|
+
"Snowpark input schema validation error",
|
556
|
+
job_context,
|
557
|
+
_checkpoint_name,
|
558
|
+
validation_result,
|
559
|
+
)
|
560
|
+
return snowpark_fn(*args, **kwargs)
|
561
|
+
|
562
|
+
return wrapper
|
563
|
+
|
564
|
+
return check_input_with_decorator
|
565
|
+
|
566
|
+
|
567
|
+
def _validate(
    schema: Union[type[DataFrameModel], DataFrameSchema],
    df: PandasDataFrame,
    lazy: bool = True,
) -> tuple[bool, PandasDataFrame]:
    """Run a Pandera validation of ``df`` against ``schema``.

    Returns a ``(passed, data)`` tuple: on success ``data`` is the validated
    DataFrame; on failure it is the Pandera failure-cases DataFrame.
    """
    # A DataFrameModel class must be converted to a DataFrameSchema before use.
    target_schema = (
        schema if isinstance(schema, DataFrameSchema) else schema.to_schema()
    )
    try:
        validated_df = target_schema.validate(df, lazy=lazy)
    except (SchemaErrors, SchemaError) as validation_errors:
        return False, cast(PandasDataFrame, validation_errors.failure_cases)
    return True, validated_df
|
@@ -0,0 +1,60 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
from typing import Optional
|
17
|
+
|
18
|
+
from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
|
19
|
+
|
20
|
+
|
21
|
+
class SnowparkCheckpointError(Exception):
    """Base exception for checkpoint failures.

    Builds a message identifying the job and checkpoint, and — when a job
    context is available — records the failure via ``job_context._mark_fail``.
    """

    def __init__(
        self,
        message,
        job_context: Optional[SnowparkJobContext],
        checkpoint_name: str,
        data=None,
    ):
        job_name = job_context.job_name if job_context else "Unknown Job"
        super().__init__(
            f"Job: {job_name} Checkpoint: {checkpoint_name}\n{message} \n {data}"
        )
        if job_context:
            job_context._mark_fail(message, checkpoint_name, data)
|
39
|
+
|
40
|
+
|
41
|
+
class SparkMigrationError(SnowparkCheckpointError):
    """Checkpoint error raised for Spark migration comparison failures."""

    def __init__(
        self,
        message,
        job_context: Optional[SnowparkJobContext],
        checkpoint_name=None,
        data=None,
    ):
        super().__init__(message, job_context, checkpoint_name, data)
|
50
|
+
|
51
|
+
|
52
|
+
class SchemaValidationError(SnowparkCheckpointError):
    """Checkpoint error raised when a DataFrame fails schema validation."""

    def __init__(
        self,
        message,
        job_context: Optional[SnowparkJobContext],
        checkpoint_name=None,
        data=None,
    ):
        super().__init__(message, job_context, checkpoint_name, data)
|