snowpark-checkpoints-validators 0.2.0rc1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. snowflake/snowpark_checkpoints/__init__.py +44 -0
  2. snowflake/snowpark_checkpoints/__version__.py +16 -0
  3. snowflake/snowpark_checkpoints/checkpoint.py +580 -0
  4. snowflake/snowpark_checkpoints/errors.py +60 -0
  5. snowflake/snowpark_checkpoints/io_utils/__init__.py +26 -0
  6. snowflake/snowpark_checkpoints/io_utils/io_default_strategy.py +57 -0
  7. snowflake/snowpark_checkpoints/io_utils/io_env_strategy.py +133 -0
  8. snowflake/snowpark_checkpoints/io_utils/io_file_manager.py +76 -0
  9. snowflake/snowpark_checkpoints/job_context.py +128 -0
  10. snowflake/snowpark_checkpoints/singleton.py +23 -0
  11. snowflake/snowpark_checkpoints/snowpark_sampler.py +124 -0
  12. snowflake/snowpark_checkpoints/spark_migration.py +255 -0
  13. snowflake/snowpark_checkpoints/utils/__init__.py +14 -0
  14. snowflake/snowpark_checkpoints/utils/constants.py +134 -0
  15. snowflake/snowpark_checkpoints/utils/extra_config.py +132 -0
  16. snowflake/snowpark_checkpoints/utils/logging_utils.py +67 -0
  17. snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +399 -0
  18. snowflake/snowpark_checkpoints/utils/supported_types.py +65 -0
  19. snowflake/snowpark_checkpoints/utils/telemetry.py +939 -0
  20. snowflake/snowpark_checkpoints/utils/utils_checks.py +398 -0
  21. snowflake/snowpark_checkpoints/validation_result_metadata.py +159 -0
  22. snowflake/snowpark_checkpoints/validation_results.py +49 -0
  23. snowpark_checkpoints_validators-0.3.0.dist-info/METADATA +325 -0
  24. snowpark_checkpoints_validators-0.3.0.dist-info/RECORD +26 -0
  25. snowpark_checkpoints_validators-0.2.0rc1.dist-info/METADATA +0 -514
  26. snowpark_checkpoints_validators-0.2.0rc1.dist-info/RECORD +0 -4
  27. {snowpark_checkpoints_validators-0.2.0rc1.dist-info → snowpark_checkpoints_validators-0.3.0.dist-info}/WHEEL +0 -0
  28. {snowpark_checkpoints_validators-0.2.0rc1.dist-info → snowpark_checkpoints_validators-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,44 @@
1
+ # Copyright 2025 Snowflake Inc.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+
18
+
19
+ # Add a NullHandler to prevent logging messages from being output to
20
+ # sys.stderr if no logging configuration is provided.
21
+ logging.getLogger(__name__).addHandler(logging.NullHandler())
22
+
23
+ # ruff: noqa: E402
24
+
25
+ from snowflake.snowpark_checkpoints.checkpoint import (
26
+ check_dataframe_schema,
27
+ check_input_schema,
28
+ check_output_schema,
29
+ validate_dataframe_checkpoint,
30
+ )
31
+ from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
32
+ from snowflake.snowpark_checkpoints.spark_migration import check_with_spark
33
+ from snowflake.snowpark_checkpoints.utils.constants import CheckpointMode
34
+
35
+
36
+ __all__ = [
37
+ "check_with_spark",
38
+ "SnowparkJobContext",
39
+ "check_dataframe_schema",
40
+ "check_output_schema",
41
+ "check_input_schema",
42
+ "validate_dataframe_checkpoint",
43
+ "CheckpointMode",
44
+ ]
@@ -0,0 +1,16 @@
1
+ # Copyright 2025 Snowflake Inc.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ __version__ = "0.3.0"
@@ -0,0 +1,580 @@
1
+ # Copyright 2025 Snowflake Inc.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Wrapper around pandera which logs to snowflake
17
+
18
+ import logging
19
+
20
+ from typing import Any, Optional, Union, cast
21
+
22
+ from pandas import DataFrame as PandasDataFrame
23
+ from pandera import Check, DataFrameModel, DataFrameSchema
24
+ from pandera.errors import SchemaError, SchemaErrors
25
+
26
+ from snowflake.snowpark import DataFrame as SnowparkDataFrame
27
+ from snowflake.snowpark_checkpoints.errors import SchemaValidationError
28
+ from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
29
+ from snowflake.snowpark_checkpoints.snowpark_sampler import (
30
+ SamplingAdapter,
31
+ SamplingStrategy,
32
+ )
33
+ from snowflake.snowpark_checkpoints.utils.constants import (
34
+ FAIL_STATUS,
35
+ PASS_STATUS,
36
+ CheckpointMode,
37
+ )
38
+ from snowflake.snowpark_checkpoints.utils.extra_config import is_checkpoint_enabled
39
+ from snowflake.snowpark_checkpoints.utils.logging_utils import log
40
+ from snowflake.snowpark_checkpoints.utils.pandera_check_manager import (
41
+ PanderaCheckManager,
42
+ )
43
+ from snowflake.snowpark_checkpoints.utils.telemetry import STATUS_KEY, report_telemetry
44
+ from snowflake.snowpark_checkpoints.utils.utils_checks import (
45
+ _check_compare_data,
46
+ _generate_schema,
47
+ _process_sampling,
48
+ _replace_special_characters,
49
+ _update_validation_result,
50
+ )
51
+
52
+
53
+ LOGGER = logging.getLogger(__name__)
54
+
55
+
56
@log
def validate_dataframe_checkpoint(
    df: SnowparkDataFrame,
    checkpoint_name: str,
    job_context: Optional[SnowparkJobContext] = None,
    mode: Optional[CheckpointMode] = CheckpointMode.SCHEMA,
    custom_checks: Optional[dict[Any, Any]] = None,
    skip_checks: Optional[dict[Any, Any]] = None,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    output_path: Optional[str] = None,
) -> Union[tuple[bool, PandasDataFrame], None]:
    """Validate a Snowpark DataFrame against a named checkpoint.

    The checkpoint name is first passed through ``_replace_special_characters``.
    If the resulting checkpoint is disabled, a warning is logged and nothing is
    validated. Otherwise the work is dispatched on ``mode``:

    * ``CheckpointMode.SCHEMA``: delegate to ``_check_dataframe_schema_file``
      and return its result.
    * ``CheckpointMode.DATAFRAME``: call ``_check_compare_data``; its return
      value is not propagated, so this mode always returns ``None``.

    Args:
        df (SnowparkDataFrame): The DataFrame to validate.
        checkpoint_name (str): The name of the checkpoint to validate against.
        job_context (SnowparkJobContext, optional): The job context for the validation.
            Required for DATAFRAME mode.
        mode (CheckpointMode): The validation mode (SCHEMA or DATAFRAME). Defaults to SCHEMA.
        custom_checks (Optional[dict[Any, Any]], optional): Custom checks forwarded to the
            schema validation. Only used in SCHEMA mode.
        skip_checks (Optional[dict[Any, Any]], optional): Checks to skip, forwarded to the
            schema validation. Only used in SCHEMA mode.
        sample_frac (Optional[float], optional): Fraction of the DataFrame to sample for
            validation. Defaults to 1.0.
        sample_number (Optional[int], optional): Number of rows to sample for validation.
        sampling_strategy (Optional[SamplingStrategy], optional): Strategy to use for sampling.
            Defaults to RANDOM_SAMPLE.
        output_path (Optional[str], optional): The output path for the validation results.

    Returns:
        Union[tuple[bool, PandasDataFrame], None]: In SCHEMA mode, the tuple produced by the
            schema validation. ``None`` in DATAFRAME mode or when the checkpoint is disabled.

    Raises:
        ValueError: If ``mode`` is neither SCHEMA nor DATAFRAME, or if ``job_context`` is
            None in DATAFRAME mode.

    """
    sanitized_name = _replace_special_characters(checkpoint_name)

    if not is_checkpoint_enabled(sanitized_name):
        LOGGER.warning(
            "Checkpoint '%s' is disabled. Skipping DataFrame checkpoint validation.",
            sanitized_name,
        )
        return None

    LOGGER.info(
        "Starting DataFrame checkpoint validation for checkpoint '%s'", sanitized_name
    )

    if mode == CheckpointMode.SCHEMA:
        return _check_dataframe_schema_file(
            df,
            sanitized_name,
            job_context,
            custom_checks,
            skip_checks,
            sample_frac,
            sample_number,
            sampling_strategy,
            output_path,
        )

    # Anything that is not DATAFRAME at this point is an unsupported mode.
    is_dataframe_mode = mode == CheckpointMode.DATAFRAME
    if not is_dataframe_mode:
        raise ValueError(
            (
                "Invalid validation mode. "
                "Please use 1 for schema validation or 2 for full data validation."
            ),
        )

    if job_context is None:
        raise ValueError(
            "No job context provided. Please provide one when using DataFrame mode validation."
        )

    _check_compare_data(df, job_context, sanitized_name, output_path)
    return None
133
+
134
+
135
def _check_dataframe_schema_file(
    df: SnowparkDataFrame,
    checkpoint_name: str,
    job_context: Optional[SnowparkJobContext] = None,
    custom_checks: Optional[dict[Any, Any]] = None,
    skip_checks: Optional[dict[Any, Any]] = None,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    output_path: Optional[str] = None,
) -> tuple[bool, PandasDataFrame]:
    """Build the schema for a checkpoint and validate a DataFrame against it.

    The schema is obtained from ``_generate_schema(checkpoint_name, output_path)``
    and the actual validation is delegated to ``_check_dataframe_schema``.

    Args:
        df (SnowparkDataFrame): The DataFrame to be validated.
        checkpoint_name (str): The name of the checkpoint whose schema is generated.
        job_context (SnowparkJobContext, optional): Context for job-related operations.
            Defaults to None.
        custom_checks (dict[Any, Any], optional): Custom checks to be added to the schema.
            Defaults to None.
        skip_checks (dict[Any, Any], optional): Checks to be skipped.
            Defaults to None.
        sample_frac (float, optional): Fraction of data to sample.
            Defaults to 1.0.
        sample_number (int, optional): Number of rows to sample.
            Defaults to None.
        sampling_strategy (SamplingStrategy, optional): Strategy for sampling data.
            Defaults to SamplingStrategy.RANDOM_SAMPLE.
        output_path (str, optional): The output path for the validation results.

    Raises:
        ValueError: If ``df`` or ``checkpoint_name`` is None.

    Returns:
        tuple[bool, PandasDataFrame]: Whatever ``_check_dataframe_schema`` returns for the
            generated schema.

    """
    # Required inputs are checked in the same order as before: df first, then the name.
    if df is None:
        raise ValueError("DataFrame is required")
    if checkpoint_name is None:
        raise ValueError("Checkpoint name is required")

    return _check_dataframe_schema(
        df,
        _generate_schema(checkpoint_name, output_path),
        checkpoint_name,
        job_context,
        custom_checks,
        skip_checks,
        sample_frac,
        sample_number,
        sampling_strategy,
        output_path,
    )
192
+
193
+
194
@log
def check_dataframe_schema(
    df: SnowparkDataFrame,
    pandera_schema: DataFrameSchema,
    checkpoint_name: str,
    job_context: Optional[SnowparkJobContext] = None,
    custom_checks: Optional[dict[str, list[Check]]] = None,
    skip_checks: Optional[dict[Any, Any]] = None,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    output_path: Optional[str] = None,
) -> Union[tuple[bool, PandasDataFrame], None]:
    """Validate a DataFrame against a caller-supplied Pandera schema.

    The checkpoint name is passed through ``_replace_special_characters`` before
    anything else. The required arguments are then checked, the checkpoint's
    enabled state is consulted, and the validation itself is delegated to
    ``_check_dataframe_schema``.

    Args:
        df (SnowparkDataFrame): The DataFrame to be validated.
        pandera_schema (DataFrameSchema): The Pandera schema to validate against.
        checkpoint_name (str): The name of the checkpoint being validated.
        job_context (SnowparkJobContext, optional): Context for job-related operations.
            Defaults to None.
        custom_checks (dict[str, list[Check]], optional): Custom checks to be added to the schema.
            Defaults to None.
        skip_checks (dict[Any, Any], optional): Checks to be skipped.
            Defaults to None.
        sample_frac (float, optional): Fraction of data to sample.
            Defaults to 1.0.
        sample_number (int, optional): Number of rows to sample.
            Defaults to None.
        sampling_strategy (SamplingStrategy, optional): Strategy for sampling data.
            Defaults to SamplingStrategy.RANDOM_SAMPLE.
        output_path (str, optional): The output path for the validation results.

    Raises:
        ValueError: If ``df`` or ``pandera_schema`` is None.

    Returns:
        Union[tuple[bool, PandasDataFrame], None]: The tuple returned by the delegated
            validation, or None when the checkpoint is disabled.

    """
    sanitized_name = _replace_special_characters(checkpoint_name)
    LOGGER.info(
        "Starting DataFrame schema validation for checkpoint '%s'", sanitized_name
    )

    if df is None:
        raise ValueError("DataFrame is required")
    if pandera_schema is None:
        raise ValueError("Schema is required")

    checkpoint_enabled = is_checkpoint_enabled(sanitized_name)
    if not checkpoint_enabled:
        LOGGER.warning(
            "Checkpoint '%s' is disabled. Skipping DataFrame schema validation.",
            sanitized_name,
        )
        return None

    return _check_dataframe_schema(
        df,
        pandera_schema,
        sanitized_name,
        job_context,
        custom_checks,
        skip_checks,
        sample_frac,
        sample_number,
        sampling_strategy,
        output_path,
    )
266
+
267
+
268
@report_telemetry(
    params_list=["pandera_schema"],
    return_indexes=[(STATUS_KEY, 0)],
    multiple_return=True,
)
def _check_dataframe_schema(
    df: SnowparkDataFrame,
    pandera_schema: DataFrameSchema,
    checkpoint_name: str,
    job_context: SnowparkJobContext = None,
    custom_checks: Optional[dict[str, list[Check]]] = None,
    skip_checks: Optional[dict[Any, Any]] = None,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    output_path: Optional[str] = None,
) -> tuple[bool, PandasDataFrame]:
    """Sample a DataFrame, validate the sample and record the outcome.

    Args:
        df (SnowparkDataFrame): The DataFrame to be validated.
        pandera_schema (DataFrameSchema): The Pandera schema to validate against.
        checkpoint_name (str): The name of the checkpoint being validated.
        job_context (SnowparkJobContext, optional): When given, a successful validation
            is recorded through ``job_context._mark_pass``. Defaults to None.
        custom_checks (dict[str, list[Check]], optional): Passed to
            ``PanderaCheckManager.add_custom_checks``. Defaults to None.
        skip_checks (dict[Any, Any], optional): Passed to
            ``PanderaCheckManager.skip_checks_on_schema``. Defaults to None.
        sample_frac (float, optional): Fraction of data to sample. Defaults to 1.0.
        sample_number (int, optional): Number of rows to sample. Defaults to None.
        sampling_strategy (SamplingStrategy, optional): Strategy for sampling data.
            Defaults to SamplingStrategy.RANDOM_SAMPLE.
        output_path (str, optional): The output path for the validation results.

    Raises:
        SchemaValidationError: If the sampled data does not satisfy the schema. The
            FAIL status is written with ``_update_validation_result`` before raising.

    Returns:
        tuple[bool, PandasDataFrame]: ``(is_valid, validation_result)`` as produced by
            ``_validate``. Only returned when validation passed.

    """
    # The return values of the manager calls are not used here.
    # NOTE(review): presumably they adjust `pandera_schema` in place - confirm.
    check_manager = PanderaCheckManager(checkpoint_name, pandera_schema)
    check_manager.skip_checks_on_schema(skip_checks)
    check_manager.add_custom_checks(custom_checks)

    schema_for_sample, sampled_df = _process_sampling(
        df, pandera_schema, job_context, sample_frac, sample_number, sampling_strategy
    )
    is_valid, validation_result = _validate(schema_for_sample, sampled_df)

    if not is_valid:
        LOGGER.error(
            "DataFrame schema validation failed for checkpoint '%s'",
            checkpoint_name,
        )
        _update_validation_result(checkpoint_name, FAIL_STATUS, output_path)
        raise SchemaValidationError(
            "Snowpark DataFrame schema validation error",
            job_context,
            checkpoint_name,
            validation_result,
        )

    LOGGER.info(
        "DataFrame schema validation passed for checkpoint '%s'",
        checkpoint_name,
    )
    if job_context is None:
        LOGGER.warning(
            "No job context provided. Skipping result recording into Snowflake.",
        )
    else:
        job_context._mark_pass(checkpoint_name)
    _update_validation_result(checkpoint_name, PASS_STATUS, output_path)

    return (is_valid, validation_result)
320
+
321
+
322
@report_telemetry(params_list=["pandera_schema"])
@log
def check_output_schema(
    pandera_schema: DataFrameSchema,
    checkpoint_name: str,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    job_context: Optional[SnowparkJobContext] = None,
    output_path: Optional[str] = None,
):
    """Create a decorator that validates the output of a Snowpark function.

    The decorated function is executed with its original arguments. Its result
    is then sampled through ``SamplingAdapter`` and the sample is validated
    against ``pandera_schema``.

    Args:
        pandera_schema (DataFrameSchema): The Pandera schema to validate against.
        checkpoint_name (str): The name of the checkpoint. If None is passed, the
            decorated function's ``__name__`` is used instead and a warning is logged.
        sample_frac (Optional[float], optional): Fraction of data to sample.
            Defaults to 1.0.
        sample_number (Optional[int], optional): Number of rows to sample.
            Defaults to None.
        sampling_strategy (Optional[SamplingStrategy], optional): Strategy for sampling data.
            Defaults to SamplingStrategy.RANDOM_SAMPLE.
        job_context (SnowparkJobContext, optional): Context for job-related operations.
            Defaults to None.
        output_path (Optional[str], optional): The output path for the validation results.

    Returns:
        function: A decorator to apply to the Snowpark function.

    """

    def check_output_with_decorator(snowpark_fn):
        """Wrap ``snowpark_fn`` so that its output is schema-validated.

        Args:
            snowpark_fn (function): The Snowpark function to validate.

        Returns:
            function: The decorated function.

        """

        @log(log_args=False)
        def wrapper(*args, **kwargs):
            """Run the Snowpark function and validate the schema of its output.

            Args:
                *args: The arguments to the Snowpark function.
                **kwargs: The keyword arguments to the Snowpark function.

            Raises:
                SchemaValidationError: If the sampled output fails schema validation.

            Returns:
                Any: The result of the Snowpark function.

            """
            _checkpoint_name = checkpoint_name
            if checkpoint_name is None:
                # The message has to be a plain string. A trailing comma inside the
                # parentheses would turn it into a 1-tuple, which logging renders
                # with tuple punctuation around the text.
                LOGGER.warning(
                    (
                        "No checkpoint name provided for output schema validation. "
                        "Using '%s' as the checkpoint name."
                    ),
                    snowpark_fn.__name__,
                )
                _checkpoint_name = snowpark_fn.__name__
            _checkpoint_name = _replace_special_characters(_checkpoint_name)
            LOGGER.info(
                "Starting output schema validation for Snowpark function '%s' and checkpoint '%s'",
                snowpark_fn.__name__,
                _checkpoint_name,
            )

            # Run the function first, then sample its result for validation.
            LOGGER.info("Running the Snowpark function '%s'", snowpark_fn.__name__)
            snowpark_results = snowpark_fn(*args, **kwargs)
            sampler = SamplingAdapter(
                job_context, sample_frac, sample_number, sampling_strategy
            )
            sampler.process_args([snowpark_results])
            pandas_sample_args = sampler.get_sampled_pandas_args()

            is_valid, validation_result = _validate(
                pandera_schema, pandas_sample_args[0]
            )
            if is_valid:
                LOGGER.info(
                    "Output schema validation passed for Snowpark function '%s' and checkpoint '%s'",
                    snowpark_fn.__name__,
                    _checkpoint_name,
                )
                if job_context is not None:
                    job_context._mark_pass(_checkpoint_name)
                else:
                    LOGGER.warning(
                        "No job context provided. Skipping result recording into Snowflake.",
                    )
                _update_validation_result(_checkpoint_name, PASS_STATUS, output_path)
            else:
                LOGGER.error(
                    "Output schema validation failed for Snowpark function '%s' and checkpoint '%s'",
                    snowpark_fn.__name__,
                    _checkpoint_name,
                )
                # Persist the FAIL status before raising so the result file is updated.
                _update_validation_result(_checkpoint_name, FAIL_STATUS, output_path)
                raise SchemaValidationError(
                    "Snowpark output schema validation error",
                    job_context,
                    _checkpoint_name,
                    validation_result,
                )
            return snowpark_results

        return wrapper

    return check_output_with_decorator
433
+
434
+
435
@report_telemetry(params_list=["pandera_schema"])
@log
def check_input_schema(
    pandera_schema: DataFrameSchema,
    checkpoint_name: str,
    sample_frac: Optional[float] = 1.0,
    sample_number: Optional[int] = None,
    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
    job_context: Optional[SnowparkJobContext] = None,
    output_path: Optional[str] = None,
):
    """Create a decorator that validates input DataFrames before a function runs.

    Args:
        pandera_schema (DataFrameSchema): The Pandera schema to validate against.
        checkpoint_name (str): The name of the checkpoint. If None is passed, the
            decorated function's ``__name__`` is used instead and a warning is logged.
        sample_frac (Optional[float], optional): Fraction of data to sample.
            Defaults to 1.0.
        sample_number (Optional[int], optional): Number of rows to sample.
            Defaults to None.
        sampling_strategy (Optional[SamplingStrategy], optional): Strategy for sampling data.
            Defaults to SamplingStrategy.RANDOM_SAMPLE.
        job_context (SnowparkJobContext, optional): Context for job-related operations.
            Defaults to None.
        output_path (Optional[str], optional): The output path for the validation results.

    Returns:
        Callable: A decorator to apply to the Snowpark function.

    """

    def check_input_with_decorator(snowpark_fn):
        """Wrap ``snowpark_fn`` so that its positional inputs are schema-validated.

        Args:
            snowpark_fn (Callable): The function to be decorated with input schema validation.

        Returns:
            Callable: A wrapper that validates the inputs and then calls the original function.

        """

        @log(log_args=False)
        def wrapper(*args, **kwargs):
            """Validate the sampled positional arguments, then call the wrapped function.

            Only ``args`` are handed to the sampler. ``kwargs`` are forwarded to the
            wrapped function without being validated here.

            Raises:
                SchemaValidationError: If any sampled DataFrame argument fails validation.

            Returns:
                Any: The result of the original function after input validation.

            """
            resolved_name = checkpoint_name
            if resolved_name is None:
                LOGGER.warning(
                    (
                        "No checkpoint name provided for input schema validation. "
                        "Using '%s' as the checkpoint name."
                    ),
                    snowpark_fn.__name__,
                )
                resolved_name = snowpark_fn.__name__
            resolved_name = _replace_special_characters(resolved_name)
            LOGGER.info(
                "Starting input schema validation for Snowpark function '%s' and checkpoint '%s'",
                snowpark_fn.__name__,
                resolved_name,
            )

            # Sample the positional arguments and get them back as pandas objects.
            input_sampler = SamplingAdapter(
                job_context, sample_frac, sample_number, sampling_strategy
            )
            input_sampler.process_args(args)
            sampled_inputs = input_sampler.get_sampled_pandas_args()

            LOGGER.info(
                "Validating %s input argument(s) against a Pandera schema",
                len(sampled_inputs),
            )
            for position, candidate in enumerate(sampled_inputs, start=1):
                # Arguments that are not pandas DataFrames are skipped, not rejected.
                if not isinstance(candidate, PandasDataFrame):
                    LOGGER.info(
                        "Arg %s: Skipping schema validation for non-DataFrame argument",
                        position,
                    )
                    continue

                is_valid, validation_result = _validate(
                    pandera_schema,
                    candidate,
                )
                if not is_valid:
                    LOGGER.error(
                        "Arg %s: Input schema validation failed",
                        position,
                    )
                    # The FAIL status is written before the error is raised.
                    _update_validation_result(
                        resolved_name,
                        FAIL_STATUS,
                        output_path,
                    )
                    raise SchemaValidationError(
                        "Snowpark input schema validation error",
                        job_context,
                        resolved_name,
                        validation_result,
                    )

                LOGGER.info(
                    "Arg %s: Input schema validation passed",
                    position,
                )
                if job_context is not None:
                    job_context._mark_pass(
                        resolved_name,
                    )
                _update_validation_result(
                    resolved_name,
                    PASS_STATUS,
                    output_path,
                )
            return snowpark_fn(*args, **kwargs)

        return wrapper

    return check_input_with_decorator
565
+
566
+
567
def _validate(
    schema: Union[type[DataFrameModel], DataFrameSchema],
    df: PandasDataFrame,
    lazy: bool = True,
) -> tuple[bool, PandasDataFrame]:
    """Validate a pandas DataFrame and report the outcome without raising.

    Args:
        schema: A ``DataFrameSchema``, or an object exposing ``to_schema()`` that
            is converted to one first.
        df: The pandas DataFrame to validate.
        lazy: Forwarded to ``schema.validate``. Defaults to True.

    Returns:
        tuple[bool, PandasDataFrame]: ``(True, validated_df)`` when validation succeeds,
            or ``(False, failure_cases)`` when ``SchemaErrors`` or ``SchemaError`` is
            raised, where ``failure_cases`` is read from the caught exception.

    """
    validator = schema if isinstance(schema, DataFrameSchema) else schema.to_schema()
    try:
        validated_df = validator.validate(df, lazy=lazy)
    except (SchemaErrors, SchemaError) as validation_error:
        return False, cast(PandasDataFrame, validation_error.failure_cases)
    return True, validated_df
@@ -0,0 +1,60 @@
1
+ # Copyright 2025 Snowflake Inc.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional
17
+
18
+ from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
19
+
20
+
21
class SnowparkCheckpointError(Exception):
    """Base error for checkpoint failures.

    The exception text combines the job name, the checkpoint name, the message
    and the attached data. When a truthy job context is supplied, the failure
    is also reported through ``job_context._mark_fail`` during construction.
    """

    def __init__(
        self,
        message,
        job_context: Optional[SnowparkJobContext],
        checkpoint_name: str,
        data=None,
    ):
        """Build the error text and report the failure to the job context.

        Args:
            message: Description of the failure.
            job_context: Job context providing ``job_name`` and ``_mark_fail``.
                When falsy, "Unknown Job" is used as the job name and nothing
                is reported.
            checkpoint_name: Name of the checkpoint that failed.
            data: Extra detail appended to the error text. Defaults to None.

        """
        if job_context:
            job_name = job_context.job_name
        else:
            job_name = "Unknown Job"

        super().__init__(
            f"Job: {job_name} Checkpoint: {checkpoint_name}\n{message} \n {data}"
        )

        if not job_context:
            return
        job_context._mark_fail(
            message,
            checkpoint_name,
            data,
        )
39
+
40
+
41
class SparkMigrationError(SnowparkCheckpointError):
    """Checkpoint error subtype whose ``checkpoint_name`` is optional."""

    def __init__(
        self,
        message,
        job_context: Optional[SnowparkJobContext],
        checkpoint_name=None,
        data=None,
    ):
        """Forward all arguments unchanged to ``SnowparkCheckpointError``.

        Args:
            message: Description of the failure.
            job_context: Job context, or None.
            checkpoint_name: Name of the checkpoint. Defaults to None.
            data: Extra detail for the error text. Defaults to None.

        """
        super().__init__(
            message=message,
            job_context=job_context,
            checkpoint_name=checkpoint_name,
            data=data,
        )
50
+
51
+
52
class SchemaValidationError(SnowparkCheckpointError):
    """Checkpoint error subtype whose ``checkpoint_name`` is optional."""

    def __init__(
        self,
        message,
        job_context: Optional[SnowparkJobContext],
        checkpoint_name=None,
        data=None,
    ):
        """Forward all arguments unchanged to ``SnowparkCheckpointError``.

        Args:
            message: Description of the failure.
            job_context: Job context, or None.
            checkpoint_name: Name of the checkpoint. Defaults to None.
            data: Extra detail for the error text. Defaults to None.

        """
        super().__init__(
            message=message,
            job_context=job_context,
            checkpoint_name=checkpoint_name,
            data=data,
        )