snowpark-checkpoints-validators 0.1.0rc3__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_checkpoints/__init__.py +34 -0
- snowflake/snowpark_checkpoints/__version__.py +16 -0
- snowflake/snowpark_checkpoints/checkpoint.py +482 -0
- snowflake/snowpark_checkpoints/errors.py +60 -0
- snowflake/snowpark_checkpoints/job_context.py +85 -0
- snowflake/snowpark_checkpoints/singleton.py +23 -0
- snowflake/snowpark_checkpoints/snowpark_sampler.py +99 -0
- snowflake/snowpark_checkpoints/spark_migration.py +222 -0
- snowflake/snowpark_checkpoints/utils/__init__.py +14 -0
- snowflake/snowpark_checkpoints/utils/checkpoint_logger.py +52 -0
- snowflake/snowpark_checkpoints/utils/constants.py +134 -0
- snowflake/snowpark_checkpoints/utils/extra_config.py +84 -0
- snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +358 -0
- snowflake/snowpark_checkpoints/utils/supported_types.py +65 -0
- snowflake/snowpark_checkpoints/utils/telemetry.py +900 -0
- snowflake/snowpark_checkpoints/utils/utils_checks.py +374 -0
- snowflake/snowpark_checkpoints/validation_result_metadata.py +125 -0
- snowflake/snowpark_checkpoints/validation_results.py +49 -0
- {snowpark_checkpoints_validators-0.1.0rc3.dist-info → snowpark_checkpoints_validators-0.1.2.dist-info}/METADATA +4 -7
- snowpark_checkpoints_validators-0.1.2.dist-info/RECORD +22 -0
- snowpark_checkpoints_validators-0.1.0rc3.dist-info/RECORD +0 -4
- {snowpark_checkpoints_validators-0.1.0rc3.dist-info → snowpark_checkpoints_validators-0.1.2.dist-info}/WHEEL +0 -0
- {snowpark_checkpoints_validators-0.1.0rc3.dist-info → snowpark_checkpoints_validators-0.1.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,34 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
from snowflake.snowpark_checkpoints.checkpoint import (
|
17
|
+
check_dataframe_schema,
|
18
|
+
check_output_schema,
|
19
|
+
check_input_schema,
|
20
|
+
validate_dataframe_checkpoint,
|
21
|
+
)
|
22
|
+
from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
|
23
|
+
from snowflake.snowpark_checkpoints.spark_migration import check_with_spark
|
24
|
+
from snowflake.snowpark_checkpoints.utils.constants import CheckpointMode
|
25
|
+
|
26
|
+
# Public API of the package: the names imported above, grouped by the
# module they come from (order preserved).
__all__ = [
    # spark_migration
    "check_with_spark",
    # job_context
    "SnowparkJobContext",
    # checkpoint
    "check_dataframe_schema",
    "check_output_schema",
    "check_input_schema",
    "validate_dataframe_checkpoint",
    # utils.constants
    "CheckpointMode",
]
|
@@ -0,0 +1,16 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
# Version string of the snowpark-checkpoints-validators package.
__version__ = "0.1.2"
|
@@ -0,0 +1,482 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
# Wrapper around pandera which logs to snowflake
|
17
|
+
from typing import Any, Optional, Union
|
18
|
+
|
19
|
+
from pandas import DataFrame as PandasDataFrame
|
20
|
+
from pandera import Check, DataFrameSchema
|
21
|
+
from pandera_report import DataFrameValidator
|
22
|
+
|
23
|
+
from snowflake.snowpark import DataFrame as SnowparkDataFrame
|
24
|
+
from snowflake.snowpark_checkpoints.errors import SchemaValidationError
|
25
|
+
from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
|
26
|
+
from snowflake.snowpark_checkpoints.snowpark_sampler import (
|
27
|
+
SamplingAdapter,
|
28
|
+
SamplingStrategy,
|
29
|
+
)
|
30
|
+
from snowflake.snowpark_checkpoints.utils.checkpoint_logger import CheckpointLogger
|
31
|
+
from snowflake.snowpark_checkpoints.utils.constants import (
|
32
|
+
FAIL_STATUS,
|
33
|
+
PASS_STATUS,
|
34
|
+
CheckpointMode,
|
35
|
+
)
|
36
|
+
from snowflake.snowpark_checkpoints.utils.extra_config import is_checkpoint_enabled
|
37
|
+
from snowflake.snowpark_checkpoints.utils.pandera_check_manager import (
|
38
|
+
PanderaCheckManager,
|
39
|
+
)
|
40
|
+
from snowflake.snowpark_checkpoints.utils.telemetry import STATUS_KEY, report_telemetry
|
41
|
+
from snowflake.snowpark_checkpoints.utils.utils_checks import (
|
42
|
+
_check_compare_data,
|
43
|
+
_generate_schema,
|
44
|
+
_process_sampling,
|
45
|
+
_replace_special_characters,
|
46
|
+
_update_validation_result,
|
47
|
+
)
|
48
|
+
|
49
|
+
|
50
|
+
def validate_dataframe_checkpoint(
    df: SnowparkDataFrame,
    checkpoint_name: str,
    job_context: Optional[SnowparkJobContext] = None,
    mode: Optional[CheckpointMode] = CheckpointMode.SCHEMA,
    custom_checks: Optional[dict[Any, Any]] = None,
    skip_checks: Optional[dict[Any, Any]] = None,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    output_path: Optional[str] = None,
) -> Union[tuple[bool, PandasDataFrame], None]:
    """Validate a Snowpark DataFrame against a specified checkpoint.

    Args:
        df (SnowparkDataFrame): The DataFrame to validate.
        checkpoint_name (str): The name of the checkpoint to validate against.
            It is normalized with ``_replace_special_characters`` first.
        job_context (SnowparkJobContext, optional): The job context for the validation.
            Required for DATAFRAME mode.
        mode (CheckpointMode): The mode of validation (SCHEMA or DATAFRAME).
            Defaults to SCHEMA.
        custom_checks (Optional[dict[Any, Any]], optional): Custom checks to apply during
            validation. Only used in SCHEMA mode.
        skip_checks (Optional[dict[Any, Any]], optional): Checks to skip during validation.
            Only used in SCHEMA mode.
        sample_frac (Optional[float], optional): Fraction of the DataFrame to sample for
            validation. Only used in SCHEMA mode. Defaults to 1.0.
        sample_number (Optional[int], optional): Number of rows to sample for validation.
            Only used in SCHEMA mode.
        sampling_strategy (Optional[SamplingStrategy], optional): Strategy to use for sampling.
            Only used in SCHEMA mode. Defaults to RANDOM_SAMPLE.
        output_path (Optional[str], optional): The output path for the validation results.

    Returns:
        Union[tuple[bool, PandasDataFrame], None]: In SCHEMA mode, a tuple containing a
        boolean indicating success and a Pandas DataFrame with validation results.
        None in DATAFRAME mode or when the checkpoint is disabled.

    Raises:
        ValueError: If an invalid validation mode is provided, or if job_context is None
            in DATAFRAME mode.

    """
    checkpoint_name = _replace_special_characters(checkpoint_name)

    if not is_checkpoint_enabled(checkpoint_name):
        # Disabled checkpoints are skipped entirely.
        return None

    if mode == CheckpointMode.SCHEMA:
        return _check_dataframe_schema_file(
            df,
            checkpoint_name,
            job_context,
            custom_checks,
            skip_checks,
            sample_frac,
            sample_number,
            sampling_strategy,
            output_path,
        )

    if mode == CheckpointMode.DATAFRAME:
        # Full data comparison needs a Snowpark session from the job context.
        if job_context is None:
            raise ValueError(
                "Connectionless mode is not supported for Parquet validation"
            )
        # The comparison outcome is not returned to the caller in this mode.
        _check_compare_data(df, job_context, checkpoint_name, output_path)
        return None

    raise ValueError(
        "Invalid validation mode. "
        "Please use 1 for schema validation or 2 for full data validation."
    )
|
112
|
+
|
113
|
+
|
114
|
+
def _check_dataframe_schema_file(
    df: SnowparkDataFrame,
    checkpoint_name: str,
    job_context: Optional[SnowparkJobContext] = None,
    custom_checks: Optional[dict[Any, Any]] = None,
    skip_checks: Optional[dict[Any, Any]] = None,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    output_path: Optional[str] = None,
) -> Union[tuple[bool, PandasDataFrame], None]:
    """Build the schema for a checkpoint and validate a DataFrame against it.

    The Pandera schema is obtained from ``_generate_schema(checkpoint_name, output_path)``
    and the actual validation is delegated to ``check_dataframe_schema``.

    Args:
        df (SnowparkDataFrame): The DataFrame to be validated.
        checkpoint_name (str): The name of the checkpoint to retrieve the schema.
        job_context (SnowparkJobContext, optional): Context for job-related operations.
            Defaults to None.
        custom_checks (dict[Any, Any], optional): Custom checks to be added to the schema.
            Defaults to None.
        skip_checks (dict[Any, Any], optional): Checks to be skipped.
            Defaults to None.
        sample_frac (float, optional): Fraction of data to sample.
            Defaults to 1.0.
        sample_number (int, optional): Number of rows to sample.
            Defaults to None.
        sampling_strategy (SamplingStrategy, optional): Strategy for sampling data.
            Defaults to SamplingStrategy.RANDOM_SAMPLE.
        output_path (str, optional): The output path for the validation results.

    Raises:
        ValueError: If df or checkpoint_name is None.
        SchemaValidationError: If the DataFrame fails schema validation.

    Returns:
        Union[tuple[bool, PandasDataFrame], None]: A tuple containing the validity flag
        and the Pandas DataFrame with the validation result, or None when
        ``check_dataframe_schema`` skips a disabled checkpoint.

    """
    if df is None:
        raise ValueError("DataFrame is required")

    if checkpoint_name is None:
        raise ValueError("Checkpoint name is required")

    schema = _generate_schema(checkpoint_name, output_path)

    return check_dataframe_schema(
        df,
        schema,
        checkpoint_name,
        job_context,
        custom_checks,
        skip_checks,
        sample_frac,
        sample_number,
        sampling_strategy,
        output_path,
    )
|
171
|
+
|
172
|
+
|
173
|
+
def check_dataframe_schema(
    df: SnowparkDataFrame,
    pandera_schema: DataFrameSchema,
    checkpoint_name: str,
    job_context: Optional[SnowparkJobContext] = None,
    custom_checks: Optional[dict[str, list[Check]]] = None,
    skip_checks: Optional[dict[Any, Any]] = None,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    output_path: Optional[str] = None,
) -> Union[tuple[bool, PandasDataFrame], None]:
    """Validate a DataFrame against a given Pandera schema using sampling techniques.

    Args:
        df (SnowparkDataFrame): The DataFrame to be validated.
        pandera_schema (DataFrameSchema): The Pandera schema to validate against.
        checkpoint_name (str): The name of the checkpoint. It is normalized with
            ``_replace_special_characters`` before use.
        job_context (SnowparkJobContext, optional): Context for job-related operations.
            Defaults to None.
        custom_checks (dict[str, list[Check]], optional): Custom checks to be added to
            the schema. Defaults to None.
        skip_checks (dict[Any, Any], optional): Checks to be skipped.
            Defaults to None.
        sample_frac (float, optional): Fraction of data to sample.
            Defaults to 1.0.
        sample_number (int, optional): Number of rows to sample.
            Defaults to None.
        sampling_strategy (SamplingStrategy, optional): Strategy for sampling data.
            Defaults to SamplingStrategy.RANDOM_SAMPLE.
        output_path (str, optional): The output path for the validation results.

    Raises:
        ValueError: If df or pandera_schema is None.
        SchemaValidationError: If the DataFrame fails schema validation.

    Returns:
        Union[tuple[bool, PandasDataFrame], None]: A tuple containing the validity flag
        and the Pandas DataFrame with the validation result.
        If the validation for that checkpoint is disabled it returns None.

    """
    checkpoint_name = _replace_special_characters(checkpoint_name)

    if df is None:
        raise ValueError("DataFrame is required")

    if pandera_schema is None:
        raise ValueError("Schema is required")

    if not is_checkpoint_enabled(checkpoint_name):
        # Disabled checkpoints are not validated.
        return None

    return _check_dataframe_schema(
        df,
        pandera_schema,
        checkpoint_name,
        job_context,
        custom_checks,
        skip_checks,
        sample_frac,
        sample_number,
        sampling_strategy,
        output_path,
    )
|
235
|
+
|
236
|
+
|
237
|
+
@report_telemetry(
    params_list=["pandera_schema"],
    return_indexes=[(STATUS_KEY, 0)],
    multiple_return=True,
)
def _check_dataframe_schema(
    df: SnowparkDataFrame,
    pandera_schema: DataFrameSchema,
    checkpoint_name: str,
    job_context: Optional[SnowparkJobContext] = None,
    custom_checks: Optional[dict[str, list[Check]]] = None,
    skip_checks: Optional[dict[Any, Any]] = None,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    output_path: Optional[str] = None,
) -> tuple[bool, PandasDataFrame]:
    """Validate a sample of ``df`` against ``pandera_schema`` and record the outcome.

    Skip and custom checks are applied through ``PanderaCheckManager``. The DataFrame
    is sampled with ``_process_sampling``, and the sample is validated with
    pandera_report's ``DataFrameValidator``. The result is written via
    ``_update_validation_result``. On success it is also reported to
    ``job_context._mark_pass`` when a job context is given.

    Args:
        df (SnowparkDataFrame): The DataFrame to be validated.
        pandera_schema (DataFrameSchema): The Pandera schema to validate against.
        checkpoint_name (str): The (already normalized) checkpoint name.
        job_context (SnowparkJobContext, optional): Context for job-related operations.
            Defaults to None.
        custom_checks (dict[str, list[Check]], optional): Custom checks to add.
        skip_checks (dict[Any, Any], optional): Checks to skip.
        sample_frac (float, optional): Fraction of data to sample. Defaults to 1.0.
        sample_number (int, optional): Number of rows to sample. Defaults to None.
        sampling_strategy (SamplingStrategy, optional): Strategy for sampling data.
            Defaults to SamplingStrategy.RANDOM_SAMPLE.
        output_path (str, optional): The output path for the validation results.

    Returns:
        tuple[bool, PandasDataFrame]: The validity flag (always True when the function
        returns, since failures raise) and the validation result.

    Raises:
        SchemaValidationError: If the sampled DataFrame does not satisfy the schema.

    """
    # The return values of these calls are not used, so they presumably modify
    # pandera_schema in place — verify against PanderaCheckManager.
    pandera_check_manager = PanderaCheckManager(checkpoint_name, pandera_schema)
    pandera_check_manager.skip_checks_on_schema(skip_checks)
    pandera_check_manager.add_custom_checks(custom_checks)

    # _process_sampling yields the schema to validate with (presumably adjusted,
    # e.g. upper-cased column names, per the variable name — verify) together
    # with the sampled pandas DataFrame.
    pandera_schema_upper, sample_df = _process_sampling(
        df, pandera_schema, job_context, sample_frac, sample_number, sampling_strategy
    )

    # With validity_flag=True the outcome is reported through is_valid; a failed
    # validation is turned into a SchemaValidationError below.
    validator = DataFrameValidator()
    is_valid, validation_result = validator.validate(
        pandera_schema_upper, sample_df, validity_flag=True
    )
    if is_valid:
        if job_context is not None:
            job_context._mark_pass(checkpoint_name)
        _update_validation_result(checkpoint_name, PASS_STATUS, output_path)
    else:
        _update_validation_result(checkpoint_name, FAIL_STATUS, output_path)
        raise SchemaValidationError(
            "Snowpark DataFrame schema validation error",
            job_context,
            checkpoint_name,
            validation_result,
        )

    return (is_valid, validation_result)
|
282
|
+
|
283
|
+
|
284
|
+
@report_telemetry(params_list=["pandera_schema"])
def check_output_schema(
    pandera_schema: DataFrameSchema,
    checkpoint_name: Optional[str] = None,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    job_context: Optional[SnowparkJobContext] = None,
    output_path: Optional[str] = None,
):
    """Decorate to validate the schema of the output of a Snowpark function.

    Args:
        pandera_schema (DataFrameSchema): The Pandera schema to validate against.
        checkpoint_name (Optional[str], optional): The name of the checkpoint. When None,
            the decorated function's ``__name__`` is used. Defaults to None.
        sample_frac (Optional[float], optional): Fraction of data to sample.
            Defaults to 1.0.
        sample_number (Optional[int], optional): Number of rows to sample.
            Defaults to None.
        sampling_strategy (Optional[SamplingStrategy], optional): Strategy for sampling data.
            Defaults to SamplingStrategy.RANDOM_SAMPLE.
        job_context (SnowparkJobContext, optional): Context for job-related operations.
            Defaults to None.
        output_path (Optional[str], optional): The output path for the validation results.

    Returns:
        Callable: A decorator that wraps a Snowpark function and validates its output.

    """
    from functools import wraps

    def check_output_with_decorator(snowpark_fn):
        """Decorate to validate the schema of the output of a Snowpark function.

        Args:
            snowpark_fn (function): The Snowpark function to validate.

        Returns:
            function: The decorated function. Its ``__name__``/``__doc__`` are copied
            from ``snowpark_fn`` so stacked decorators see the original name.

        """
        _checkpoint_name = checkpoint_name
        if checkpoint_name is None:
            _checkpoint_name = snowpark_fn.__name__
        _checkpoint_name = _replace_special_characters(_checkpoint_name)

        @wraps(snowpark_fn)
        def wrapper(*args, **kwargs):
            """Wrap a function to validate the schema of the output of a Snowpark function.

            Args:
                *args: The arguments to the Snowpark function.
                **kwargs: The keyword arguments to the Snowpark function.

            Raises:
                SchemaValidationError: If the sampled output fails schema validation.

            Returns:
                Any: The result of the Snowpark function.

            """
            # Execute the wrapped function first, then sample its result.
            snowpark_results = snowpark_fn(*args, **kwargs)
            sampler = SamplingAdapter(
                job_context, sample_frac, sample_number, sampling_strategy
            )
            sampler.process_args([snowpark_results])
            pandas_sample_args = sampler.get_sampled_pandas_args()

            # With validity_flag=True the outcome is reported through is_valid;
            # a failed validation is raised below as SchemaValidationError.
            validator = DataFrameValidator()
            is_valid, validation_result = validator.validate(
                pandera_schema, pandas_sample_args[0], validity_flag=True
            )
            logger = CheckpointLogger().get_logger()
            logger.info(
                f"Checkpoint {_checkpoint_name} validation result:\n{validation_result}"
            )

            if is_valid:
                if job_context is not None:
                    job_context._mark_pass(_checkpoint_name)

                _update_validation_result(_checkpoint_name, PASS_STATUS, output_path)
            else:
                _update_validation_result(_checkpoint_name, FAIL_STATUS, output_path)
                raise SchemaValidationError(
                    "Snowpark output schema validation error",
                    job_context,
                    _checkpoint_name,
                    validation_result,
                )

            return snowpark_results

        return wrapper

    return check_output_with_decorator
|
374
|
+
|
375
|
+
|
376
|
+
@report_telemetry(params_list=["pandera_schema"])
def check_input_schema(
    pandera_schema: DataFrameSchema,
    checkpoint_name: Optional[str] = None,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    job_context: Optional[SnowparkJobContext] = None,
    output_path: Optional[str] = None,
):
    """Decorate factory for validating input DataFrame schemas before function execution.

    Args:
        pandera_schema (DataFrameSchema): The Pandera schema to validate against.
        checkpoint_name (Optional[str], optional): The name of the checkpoint. When None,
            the decorated function's ``__name__`` is used. Defaults to None.
        sample_frac (Optional[float], optional): Fraction of data to sample.
            Defaults to 1.0.
        sample_number (Optional[int], optional): Number of rows to sample.
            Defaults to None.
        sampling_strategy (Optional[SamplingStrategy], optional): Strategy for sampling data.
            Defaults to SamplingStrategy.RANDOM_SAMPLE.
        job_context (SnowparkJobContext, optional): Context for job-related operations.
            Defaults to None.
        output_path (Optional[str], optional): The output path for the validation results.

    Returns:
        Callable: A decorator that validates the decorated function's positional
        DataFrame arguments before calling it.

    """
    from functools import wraps

    def check_input_with_decorator(snowpark_fn):
        """Decorate that validates input schemas for the decorated function.

        Args:
            snowpark_fn (Callable): The function to be decorated with input schema validation.

        Raises:
            SchemaValidationError: If input data fails schema validation.

        Returns:
            Callable: A wrapper function that performs schema validation before executing
            the original function. Its ``__name__``/``__doc__`` are copied from
            ``snowpark_fn`` so stacked decorators see the original name.

        """
        _checkpoint_name = checkpoint_name
        if checkpoint_name is None:
            _checkpoint_name = snowpark_fn.__name__
        _checkpoint_name = _replace_special_characters(_checkpoint_name)

        @wraps(snowpark_fn)
        def wrapper(*args, **kwargs):
            """Wrap a function to validate the schema of the input of a Snowpark function.

            Raises:
                SchemaValidationError: If any input DataFrame fails schema validation.

            Returns:
                Any: The result of the original function after input validation.

            """
            # Sample the positional arguments; keyword arguments are not inspected.
            sampler = SamplingAdapter(
                job_context, sample_frac, sample_number, sampling_strategy
            )
            sampler.process_args(args)
            pandas_sample_args = sampler.get_sampled_pandas_args()

            # Only arguments that came back as pandas DataFrames are validated.
            # NOTE: a pass/fail result is recorded separately for each of them.
            for arg in pandas_sample_args:
                if not isinstance(arg, PandasDataFrame):
                    continue

                # With validity_flag=True the outcome is reported through is_valid;
                # a failed validation is raised below as SchemaValidationError.
                validator = DataFrameValidator()
                is_valid, validation_result = validator.validate(
                    pandera_schema,
                    arg,
                    validity_flag=True,
                )

                logger = CheckpointLogger().get_logger()
                # Log the resolved/normalized name, consistent with the status
                # updates below (the raw argument may be None).
                logger.info(
                    f"Checkpoint {_checkpoint_name} validation result:\n{validation_result}"
                )

                if is_valid:
                    if job_context is not None:
                        job_context._mark_pass(
                            _checkpoint_name,
                        )

                    _update_validation_result(
                        _checkpoint_name,
                        PASS_STATUS,
                        output_path,
                    )
                else:
                    _update_validation_result(
                        _checkpoint_name,
                        FAIL_STATUS,
                        output_path,
                    )
                    raise SchemaValidationError(
                        "Snowpark input schema validation error",
                        job_context,
                        _checkpoint_name,
                        validation_result,
                    )
            return snowpark_fn(*args, **kwargs)

        return wrapper

    return check_input_with_decorator
|
@@ -0,0 +1,60 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
from typing import Optional
|
17
|
+
|
18
|
+
from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
|
19
|
+
|
20
|
+
|
21
|
+
class SnowparkCheckpointError(Exception):
    """Base exception for snowpark-checkpoints failures.

    The exception text combines the job name, the checkpoint name, the message
    and the optional data payload. If a job context is supplied, the failure is
    also reported through ``job_context._mark_fail`` while the exception is
    being constructed.

    Args:
        message: Description of the failure.
        job_context: Optional job context. Its ``job_name`` appears in the text
            ("Unknown Job" when no context is given).
        checkpoint_name: Name of the checkpoint that failed.
        data: Optional payload (e.g. a validation result) appended to the text.

    """

    def __init__(
        self,
        message,
        job_context: Optional[SnowparkJobContext],
        checkpoint_name: str,
        data=None,
    ):
        if job_context:
            job_name = job_context.job_name
        else:
            job_name = "Unknown Job"

        full_message = (
            f"Job: {job_name} Checkpoint: {checkpoint_name}\n{message} \n {data}"
        )
        super().__init__(full_message)

        # Raising this error doubles as failure reporting for the job.
        if job_context:
            job_context._mark_fail(message, checkpoint_name, data)
|
39
|
+
|
40
|
+
|
41
|
+
class SparkMigrationError(SnowparkCheckpointError):
    """Checkpoint error intended (per its name) for Spark migration checks.

    Behaves like SnowparkCheckpointError, except ``checkpoint_name`` is optional
    and defaults to None.
    """

    def __init__(
        self,
        message,
        job_context: Optional[SnowparkJobContext],
        checkpoint_name=None,
        data=None,
    ):
        super().__init__(
            message=message,
            job_context=job_context,
            checkpoint_name=checkpoint_name,
            data=data,
        )
|
50
|
+
|
51
|
+
|
52
|
+
class SchemaValidationError(SnowparkCheckpointError):
    """Raised when a DataFrame does not satisfy its Pandera schema.

    Behaves like SnowparkCheckpointError, except ``checkpoint_name`` is optional
    and defaults to None.
    """

    def __init__(
        self,
        message,
        job_context: Optional[SnowparkJobContext],
        checkpoint_name=None,
        data=None,
    ):
        super().__init__(
            message=message,
            job_context=job_context,
            checkpoint_name=checkpoint_name,
            data=data,
        )
|
@@ -0,0 +1,85 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
from datetime import datetime
|
17
|
+
from typing import Optional
|
18
|
+
|
19
|
+
import pandas as pd
|
20
|
+
|
21
|
+
from pyspark.sql import SparkSession
|
22
|
+
|
23
|
+
from snowflake.snowpark import Session
|
24
|
+
from snowflake.snowpark_checkpoints.utils.constants import SCHEMA_EXECUTION_MODE
|
25
|
+
|
26
|
+
|
27
|
+
class SnowparkJobContext:
    """Record checkpoint validation results in Snowflake.

    Each recorded result is appended as one row to the
    ``SNOWPARK_CHECKPOINTS_REPORT`` table through the Snowpark session.

    Args:
        snowpark_session: A Snowpark session instance used to write results.
        spark_session: A Spark session instance. When omitted, one is obtained
            with ``SparkSession.builder.getOrCreate()``.
        job_name: The name of the job, stored in every result row.
        log_results: Whether to log the results in Snowflake. When falsy,
            ``_mark_pass`` and ``_mark_fail`` write nothing.

    """

    def __init__(
        self,
        snowpark_session: Session,
        spark_session: Optional[SparkSession] = None,
        job_name: Optional[str] = None,
        log_results: Optional[bool] = True,
    ):
        self.log_results = log_results
        self.job_name = job_name
        self.spark_session = spark_session or SparkSession.builder.getOrCreate()
        self.snowpark_session = snowpark_session

    def _mark_fail(
        self, message, checkpoint_name, data, execution_mode=SCHEMA_EXECUTION_MODE
    ):
        """Append a "fail" row; ``data`` is stored as its string form."""
        if not self.log_results:
            return
        self._append_report_row(
            status="fail",
            checkpoint_name=checkpoint_name,
            message=message,
            data=f"{data}",
            execution_mode=execution_mode,
        )

    def _mark_pass(self, checkpoint_name, execution_mode=SCHEMA_EXECUTION_MODE):
        """Append a "pass" row with empty MESSAGE and DATA."""
        if not self.log_results:
            return
        self._append_report_row(
            status="pass",
            checkpoint_name=checkpoint_name,
            message="",
            data="",
            execution_mode=execution_mode,
        )

    def _append_report_row(self, status, checkpoint_name, message, data, execution_mode):
        """Write a single result row to SNOWPARK_CHECKPOINTS_REPORT (append mode)."""
        # Column order matches the report table layout used by both callers.
        row = pd.DataFrame(
            {
                "DATE": [datetime.now()],
                "JOB": [self.job_name],
                "STATUS": [status],
                "CHECKPOINT": [checkpoint_name],
                "MESSAGE": [message],
                "DATA": [data],
                "EXECUTION_MODE": [execution_mode],
            }
        )
        writer = self.snowpark_session.createDataFrame(row).write
        writer.mode("append").save_as_table("SNOWPARK_CHECKPOINTS_REPORT")
|