snowpark-checkpoints-validators 0.1.0rc3__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. snowflake/snowpark_checkpoints/__init__.py +34 -0
  2. snowflake/snowpark_checkpoints/__version__.py +16 -0
  3. snowflake/snowpark_checkpoints/checkpoint.py +482 -0
  4. snowflake/snowpark_checkpoints/errors.py +60 -0
  5. snowflake/snowpark_checkpoints/job_context.py +85 -0
  6. snowflake/snowpark_checkpoints/singleton.py +23 -0
  7. snowflake/snowpark_checkpoints/snowpark_sampler.py +99 -0
  8. snowflake/snowpark_checkpoints/spark_migration.py +222 -0
  9. snowflake/snowpark_checkpoints/utils/__init__.py +14 -0
  10. snowflake/snowpark_checkpoints/utils/checkpoint_logger.py +52 -0
  11. snowflake/snowpark_checkpoints/utils/constants.py +134 -0
  12. snowflake/snowpark_checkpoints/utils/extra_config.py +84 -0
  13. snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +358 -0
  14. snowflake/snowpark_checkpoints/utils/supported_types.py +65 -0
  15. snowflake/snowpark_checkpoints/utils/telemetry.py +900 -0
  16. snowflake/snowpark_checkpoints/utils/utils_checks.py +374 -0
  17. snowflake/snowpark_checkpoints/validation_result_metadata.py +125 -0
  18. snowflake/snowpark_checkpoints/validation_results.py +49 -0
  19. {snowpark_checkpoints_validators-0.1.0rc3.dist-info → snowpark_checkpoints_validators-0.1.2.dist-info}/METADATA +4 -7
  20. snowpark_checkpoints_validators-0.1.2.dist-info/RECORD +22 -0
  21. snowpark_checkpoints_validators-0.1.0rc3.dist-info/RECORD +0 -4
  22. {snowpark_checkpoints_validators-0.1.0rc3.dist-info → snowpark_checkpoints_validators-0.1.2.dist-info}/WHEEL +0 -0
  23. {snowpark_checkpoints_validators-0.1.0rc3.dist-info → snowpark_checkpoints_validators-0.1.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,34 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from snowflake.snowpark_checkpoints.checkpoint import (
+     check_dataframe_schema,
+     check_output_schema,
+     check_input_schema,
+     validate_dataframe_checkpoint,
+ )
+ from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+ from snowflake.snowpark_checkpoints.spark_migration import check_with_spark
+ from snowflake.snowpark_checkpoints.utils.constants import CheckpointMode
+
+ __all__ = [
+     "check_with_spark",
+     "SnowparkJobContext",
+     "check_dataframe_schema",
+     "check_output_schema",
+     "check_input_schema",
+     "validate_dataframe_checkpoint",
+     "CheckpointMode",
+ ]
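
The new __init__.py establishes the package's public API via __all__. For orientation, a consumer-side import of that full surface would look like the following sketch, using only the names the diff itself exports:

    from snowflake.snowpark_checkpoints import (
        CheckpointMode,
        SnowparkJobContext,
        check_dataframe_schema,
        check_input_schema,
        check_output_schema,
        check_with_spark,
        validate_dataframe_checkpoint,
    )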
@@ -0,0 +1,16 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ __version__ = "0.1.2"
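
Aside from metadata, the substantive change in this file pair is the version constant itself. A trivial sketch of confirming the installed build, assuming the 0.1.2 wheel above is what got installed:

    from snowflake.snowpark_checkpoints.__version__ import __version__

    assert __version__ == "0.1.2"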
@@ -0,0 +1,482 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Wrapper around pandera which logs to snowflake
+ from typing import Any, Optional, Union
+
+ from pandas import DataFrame as PandasDataFrame
+ from pandera import Check, DataFrameSchema
+ from pandera_report import DataFrameValidator
+
+ from snowflake.snowpark import DataFrame as SnowparkDataFrame
+ from snowflake.snowpark_checkpoints.errors import SchemaValidationError
+ from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+ from snowflake.snowpark_checkpoints.snowpark_sampler import (
+     SamplingAdapter,
+     SamplingStrategy,
+ )
+ from snowflake.snowpark_checkpoints.utils.checkpoint_logger import CheckpointLogger
+ from snowflake.snowpark_checkpoints.utils.constants import (
+     FAIL_STATUS,
+     PASS_STATUS,
+     CheckpointMode,
+ )
+ from snowflake.snowpark_checkpoints.utils.extra_config import is_checkpoint_enabled
+ from snowflake.snowpark_checkpoints.utils.pandera_check_manager import (
+     PanderaCheckManager,
+ )
+ from snowflake.snowpark_checkpoints.utils.telemetry import STATUS_KEY, report_telemetry
+ from snowflake.snowpark_checkpoints.utils.utils_checks import (
+     _check_compare_data,
+     _generate_schema,
+     _process_sampling,
+     _replace_special_characters,
+     _update_validation_result,
+ )
+
+
+ def validate_dataframe_checkpoint(
+     df: SnowparkDataFrame,
+     checkpoint_name: str,
+     job_context: Optional[SnowparkJobContext] = None,
+     mode: Optional[CheckpointMode] = CheckpointMode.SCHEMA,
+     custom_checks: Optional[dict[Any, Any]] = None,
+     skip_checks: Optional[dict[Any, Any]] = None,
+     sample_frac: Optional[float] = 1.0,
+     sample_number: Optional[int] = None,
+     sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
+     output_path: Optional[str] = None,
+ ) -> Union[tuple[bool, PandasDataFrame], None]:
+     """Validate a Snowpark DataFrame against a specified checkpoint.
+
+     Args:
+         df (SnowparkDataFrame): The DataFrame to validate.
+         checkpoint_name (str): The name of the checkpoint to validate against.
+         job_context (SnowparkJobContext, optional): The job context for the validation. Required for DATAFRAME mode.
+         mode (CheckpointMode): The mode of validation (e.g., SCHEMA, DATAFRAME). Defaults to SCHEMA.
+         custom_checks (Optional[dict[Any, Any]], optional): Custom checks to apply during validation.
+         skip_checks (Optional[dict[Any, Any]], optional): Checks to skip during validation.
+         sample_frac (Optional[float], optional): Fraction of the DataFrame to sample for validation. Defaults to 1.0.
+         sample_number (Optional[int], optional): Number of rows to sample for validation.
+         sampling_strategy (Optional[SamplingStrategy], optional): Strategy to use for sampling.
+             Defaults to RANDOM_SAMPLE.
+         output_path (Optional[str], optional): The output path for the validation results.
+
+     Returns:
+         Union[tuple[bool, PandasDataFrame], None]: A tuple containing a boolean indicating success
+         and a Pandas DataFrame with validation results, or None if validation is not applicable.
+
+     Raises:
+         ValueError: If an invalid validation mode is provided or if job_context is None for DATAFRAME mode.
+
+     """
+     checkpoint_name = _replace_special_characters(checkpoint_name)
+
+     if is_checkpoint_enabled(checkpoint_name):
+
+         if mode == CheckpointMode.SCHEMA:
+             return _check_dataframe_schema_file(
+                 df,
+                 checkpoint_name,
+                 job_context,
+                 custom_checks,
+                 skip_checks,
+                 sample_frac,
+                 sample_number,
+                 sampling_strategy,
+                 output_path,
+             )
+         elif mode == CheckpointMode.DATAFRAME:
+             if job_context is None:
+                 raise ValueError(
+                     "Connectionless mode is not supported for Parquet validation"
+                 )
+             _check_compare_data(df, job_context, checkpoint_name, output_path)
+         else:
+             raise ValueError(
+                 """Invalid validation mode.
+                 Please use 1 for schema validation or 2 for full data validation."""
+             )
+
+
+ def _check_dataframe_schema_file(
+     df: SnowparkDataFrame,
+     checkpoint_name: str,
+     job_context: Optional[SnowparkJobContext] = None,
+     custom_checks: Optional[dict[Any, Any]] = None,
+     skip_checks: Optional[dict[Any, Any]] = None,
+     sample_frac: Optional[float] = 1.0,
+     sample_number: Optional[int] = None,
+     sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
+     output_path: Optional[str] = None,
+ ) -> tuple[bool, PandasDataFrame]:
+     """Generate and check the schema for a given DataFrame based on a checkpoint name.
+
+     Args:
+         df (SnowparkDataFrame): The DataFrame to be validated.
+         checkpoint_name (str): The name of the checkpoint to retrieve the schema.
+         job_context (SnowparkJobContext, optional): Context for job-related operations.
+             Defaults to None.
+         custom_checks (dict[Any, Any], optional): Custom checks to be added to the schema.
+             Defaults to None.
+         skip_checks (dict[Any, Any], optional): Checks to be skipped.
+             Defaults to None.
+         sample_frac (float, optional): Fraction of data to sample.
+             Defaults to 1.0.
+         sample_number (int, optional): Number of rows to sample.
+             Defaults to None.
+         sampling_strategy (SamplingStrategy, optional): Strategy for sampling data.
+             Defaults to SamplingStrategy.RANDOM_SAMPLE.
+         output_path (str, optional): The output path for the validation results.
+
+     Raises:
+         SchemaValidationError: If the DataFrame fails schema validation.
+
+     Returns:
+         tuple[bool, PandasDataFrame]: A tuple containing the validity flag and the Pandas DataFrame.
+
+     """
+     if df is None:
+         raise ValueError("DataFrame is required")
+
+     if checkpoint_name is None:
+         raise ValueError("Checkpoint name is required")
+
+     schema = _generate_schema(checkpoint_name, output_path)
+
+     return check_dataframe_schema(
+         df,
+         schema,
+         checkpoint_name,
+         job_context,
+         custom_checks,
+         skip_checks,
+         sample_frac,
+         sample_number,
+         sampling_strategy,
+         output_path,
+     )
+
+
+ def check_dataframe_schema(
+     df: SnowparkDataFrame,
+     pandera_schema: DataFrameSchema,
+     checkpoint_name: str,
+     job_context: Optional[SnowparkJobContext] = None,
+     custom_checks: Optional[dict[str, list[Check]]] = None,
+     skip_checks: Optional[dict[Any, Any]] = None,
+     sample_frac: Optional[float] = 1.0,
+     sample_number: Optional[int] = None,
+     sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
+     output_path: Optional[str] = None,
+ ) -> Union[tuple[bool, PandasDataFrame], None]:
+     """Validate a DataFrame against a given Pandera schema using sampling techniques.
+
+     Args:
+         df (SnowparkDataFrame): The DataFrame to be validated.
+         pandera_schema (DataFrameSchema): The Pandera schema to validate against.
+         checkpoint_name (str): The name of the checkpoint used to retrieve the
+             schema and to record the validation results.
+         job_context (SnowparkJobContext, optional): Context for job-related operations.
+             Defaults to None.
+         custom_checks (dict[Any, Any], optional): Custom checks to be added to the schema.
+             Defaults to None.
+         skip_checks (dict[Any, Any], optional): Checks to be skipped.
+             Defaults to None.
+         sample_frac (float, optional): Fraction of data to sample.
+             Defaults to 1.0.
+         sample_number (int, optional): Number of rows to sample.
+             Defaults to None.
+         sampling_strategy (SamplingStrategy, optional): Strategy for sampling data.
+             Defaults to SamplingStrategy.RANDOM_SAMPLE.
+         output_path (str, optional): The output path for the validation results.
+
+     Raises:
+         SchemaValidationError: If the DataFrame fails schema validation.
+
+     Returns:
+         Union[tuple[bool, PandasDataFrame], None]: A tuple containing the validity flag and the Pandas DataFrame.
+             If validation for this checkpoint is disabled, returns None.
+
+     """
+     checkpoint_name = _replace_special_characters(checkpoint_name)
+
+     if df is None:
+         raise ValueError("DataFrame is required")
+
+     if pandera_schema is None:
+         raise ValueError("Schema is required")
+
+     if is_checkpoint_enabled(checkpoint_name):
+         return _check_dataframe_schema(
+             df,
+             pandera_schema,
+             checkpoint_name,
+             job_context,
+             custom_checks,
+             skip_checks,
+             sample_frac,
+             sample_number,
+             sampling_strategy,
+             output_path,
+         )
+
+
+ @report_telemetry(
+     params_list=["pandera_schema"],
+     return_indexes=[(STATUS_KEY, 0)],
+     multiple_return=True,
+ )
+ def _check_dataframe_schema(
+     df: SnowparkDataFrame,
+     pandera_schema: DataFrameSchema,
+     checkpoint_name: str,
+     job_context: SnowparkJobContext = None,
+     custom_checks: Optional[dict[str, list[Check]]] = None,
+     skip_checks: Optional[dict[Any, Any]] = None,
+     sample_frac: Optional[float] = 1.0,
+     sample_number: Optional[int] = None,
+     sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
+     output_path: Optional[str] = None,
+ ) -> tuple[bool, PandasDataFrame]:
+
+     pandera_check_manager = PanderaCheckManager(checkpoint_name, pandera_schema)
+     pandera_check_manager.skip_checks_on_schema(skip_checks)
+     pandera_check_manager.add_custom_checks(custom_checks)
+
+     pandera_schema_upper, sample_df = _process_sampling(
+         df, pandera_schema, job_context, sample_frac, sample_number, sampling_strategy
+     )
+
+     # Raises SchemaError on validation issues
+     validator = DataFrameValidator()
+     is_valid, validation_result = validator.validate(
+         pandera_schema_upper, sample_df, validity_flag=True
+     )
+     if is_valid:
+         if job_context is not None:
+             job_context._mark_pass(checkpoint_name)
+         _update_validation_result(checkpoint_name, PASS_STATUS, output_path)
+     else:
+         _update_validation_result(checkpoint_name, FAIL_STATUS, output_path)
+         raise SchemaValidationError(
+             "Snowpark DataFrame schema validation error",
+             job_context,
+             checkpoint_name,
+             validation_result,
+         )
+
+     return (is_valid, validation_result)
+
+
+ @report_telemetry(params_list=["pandera_schema"])
+ def check_output_schema(
+     pandera_schema: DataFrameSchema,
+     checkpoint_name: str,
+     sample_frac: Optional[float] = 1.0,
+     sample_number: Optional[int] = None,
+     sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
+     job_context: Optional[SnowparkJobContext] = None,
+     output_path: Optional[str] = None,
+ ):
+     """Decorator to validate the schema of the output of a Snowpark function.
+
+     Args:
+         pandera_schema (DataFrameSchema): The Pandera schema to validate against.
+         checkpoint_name (Optional[str], optional): The name of the checkpoint to retrieve the schema.
+         sample_frac (Optional[float], optional): Fraction of data to sample.
+             Defaults to 1.0.
+         sample_number (Optional[int], optional): Number of rows to sample.
+             Defaults to None.
+         sampling_strategy (Optional[SamplingStrategy], optional): Strategy for sampling data.
+             Defaults to SamplingStrategy.RANDOM_SAMPLE.
+         job_context (SnowparkJobContext, optional): Context for job-related operations.
+             Defaults to None.
+         output_path (Optional[str], optional): The output path for the validation results.
+
+     """
+
+     def check_output_with_decorator(snowpark_fn):
+         """Decorator that validates the schema of the output of a Snowpark function.
+
+         Args:
+             snowpark_fn (function): The Snowpark function to validate.
+
+         Returns:
+             function: The decorated function.
+
+         """
+         _checkpoint_name = checkpoint_name
+         if checkpoint_name is None:
+             _checkpoint_name = snowpark_fn.__name__
+         _checkpoint_name = _replace_special_characters(_checkpoint_name)
+
+         def wrapper(*args, **kwargs):
+             """Wrap a function to validate the schema of the output of a Snowpark function.
+
+             Args:
+                 *args: The arguments to the Snowpark function.
+                 **kwargs: The keyword arguments to the Snowpark function.
+
+             Returns:
+                 Any: The result of the Snowpark function.
+
+             """
+             # Run the sampled data in snowpark
+             snowpark_results = snowpark_fn(*args, **kwargs)
+             sampler = SamplingAdapter(
+                 job_context, sample_frac, sample_number, sampling_strategy
+             )
+             sampler.process_args([snowpark_results])
+             pandas_sample_args = sampler.get_sampled_pandas_args()
+
+             # Raises SchemaError on validation issues
+             validator = DataFrameValidator()
+             is_valid, validation_result = validator.validate(
+                 pandera_schema, pandas_sample_args[0], validity_flag=True
+             )
+             logger = CheckpointLogger().get_logger()
+             logger.info(
+                 f"Checkpoint {_checkpoint_name} validation result:\n{validation_result}"
+             )
+
+             if is_valid:
+                 if job_context is not None:
+                     job_context._mark_pass(_checkpoint_name)
+
+                 _update_validation_result(_checkpoint_name, PASS_STATUS, output_path)
+             else:
+                 _update_validation_result(_checkpoint_name, FAIL_STATUS, output_path)
+                 raise SchemaValidationError(
+                     "Snowpark output schema validation error",
+                     job_context,
+                     _checkpoint_name,
+                     validation_result,
+                 )
+
+             return snowpark_results
+
+         return wrapper
+
+     return check_output_with_decorator
+
+
+ @report_telemetry(params_list=["pandera_schema"])
+ def check_input_schema(
+     pandera_schema: DataFrameSchema,
+     checkpoint_name: str,
+     sample_frac: Optional[float] = 1.0,
+     sample_number: Optional[int] = None,
+     sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
+     job_context: Optional[SnowparkJobContext] = None,
+     output_path: Optional[str] = None,
+ ):
+     """Decorator factory for validating input DataFrame schemas before function execution.
+
+     Args:
+         pandera_schema (DataFrameSchema): The Pandera schema to validate against.
+         checkpoint_name (Optional[str], optional): The name of the checkpoint to retrieve the schema.
+         sample_frac (Optional[float], optional): Fraction of data to sample.
+             Defaults to 1.0.
+         sample_number (Optional[int], optional): Number of rows to sample.
+             Defaults to None.
+         sampling_strategy (Optional[SamplingStrategy], optional): Strategy for sampling data.
+             Defaults to SamplingStrategy.RANDOM_SAMPLE.
+         job_context (SnowparkJobContext, optional): Context for job-related operations.
+             Defaults to None.
+         output_path (Optional[str], optional): The output path for the validation results.
+
+
+     """
+
+     def check_input_with_decorator(snowpark_fn):
+         """Decorator that validates input schemas for the decorated function.
+
+         Args:
+             snowpark_fn (Callable): The function to be decorated with input schema validation.
+
+         Raises:
+             SchemaValidationError: If input data fails schema validation.
+
+         Returns:
+             Callable: A wrapper function that performs schema validation before executing the original function.
+
+         """
+         _checkpoint_name = checkpoint_name
+         if checkpoint_name is None:
+             _checkpoint_name = snowpark_fn.__name__
+         _checkpoint_name = _replace_special_characters(_checkpoint_name)
+
+         def wrapper(*args, **kwargs):
+             """Wrap a function to validate the schema of the input of a Snowpark function.
+
+             Raises:
+                 SchemaValidationError: If any input DataFrame fails schema validation.
+
+             Returns:
+                 Any: The result of the original function after input validation.
+
+             """
+             # Run the sampled data in snowpark
+             sampler = SamplingAdapter(
+                 job_context, sample_frac, sample_number, sampling_strategy
+             )
+             sampler.process_args(args)
+             pandas_sample_args = sampler.get_sampled_pandas_args()
+
+             # Raises SchemaError on validation issues
+             for arg in pandas_sample_args:
+                 if isinstance(arg, PandasDataFrame):
+
+                     validator = DataFrameValidator()
+                     is_valid, validation_result = validator.validate(
+                         pandera_schema,
+                         arg,
+                         validity_flag=True,
+                     )
+
+                     logger = CheckpointLogger().get_logger()
+                     logger.info(
+                         f"Checkpoint {_checkpoint_name} validation result:\n{validation_result}"
+                     )
+
+                     if is_valid:
+                         if job_context is not None:
+                             job_context._mark_pass(
+                                 _checkpoint_name,
+                             )
+
+                         _update_validation_result(
+                             _checkpoint_name,
+                             PASS_STATUS,
+                             output_path,
+                         )
+                     else:
+                         _update_validation_result(
+                             _checkpoint_name,
+                             FAIL_STATUS,
+                             output_path,
+                         )
+                         raise SchemaValidationError(
+                             "Snowpark input schema validation error",
+                             job_context,
+                             _checkpoint_name,
+                             validation_result,
+                         )
+             return snowpark_fn(*args, **kwargs)
+
+         return wrapper
+
+     return check_input_with_decorator
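
The checkpoint.py module above supports two usage styles: direct calls (validate_dataframe_checkpoint, check_dataframe_schema) and decorators (check_input_schema, check_output_schema). A minimal sketch of both follows; the session setup, the SALES table, and the checkpoint names are illustrative assumptions rather than part of this diff, and SCHEMA mode presumes a schema was previously collected under the same checkpoint name:

    from pandera import Check, Column, DataFrameSchema
    from snowflake.snowpark import Session
    from snowflake.snowpark_checkpoints import (
        CheckpointMode,
        check_output_schema,
        validate_dataframe_checkpoint,
    )

    session = Session.builder.getOrCreate()  # assumes configured credentials
    df = session.table("SALES")  # hypothetical table

    # Direct style: SCHEMA mode loads the collected schema for this checkpoint
    # name and validates a sample of the DataFrame against it.
    validate_dataframe_checkpoint(df, "sales_checkpoint", mode=CheckpointMode.SCHEMA)

    # Decorator style: the function's output is sampled and validated against
    # an inline pandera schema; a failure raises SchemaValidationError.
    out_schema = DataFrameSchema({"AMOUNT": Column(float, Check.ge(0))})

    @check_output_schema(out_schema, "sales_output_checkpoint")
    def double_amount(df):
        return df.with_column("AMOUNT", df["AMOUNT"] * 2)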
@@ -0,0 +1,60 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+
+
+ class SnowparkCheckpointError(Exception):
+     def __init__(
+         self,
+         message,
+         job_context: Optional[SnowparkJobContext],
+         checkpoint_name: str,
+         data=None,
+     ):
+         job_name = job_context.job_name if job_context else "Unknown Job"
+         super().__init__(
+             f"Job: {job_name} Checkpoint: {checkpoint_name}\n{message} \n {data}"
+         )
+         if job_context:
+             job_context._mark_fail(
+                 message,
+                 checkpoint_name,
+                 data,
+             )
+
+
+ class SparkMigrationError(SnowparkCheckpointError):
+     def __init__(
+         self,
+         message,
+         job_context: Optional[SnowparkJobContext],
+         checkpoint_name=None,
+         data=None,
+     ):
+         super().__init__(message, job_context, checkpoint_name, data)
+
+
+ class SchemaValidationError(SnowparkCheckpointError):
+     def __init__(
+         self,
+         message,
+         job_context: Optional[SnowparkJobContext],
+         checkpoint_name=None,
+         data=None,
+     ):
+         super().__init__(message, job_context, checkpoint_name, data)
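
One detail worth noting in this hierarchy: SnowparkCheckpointError's constructor calls job_context._mark_fail(...) whenever a job context is present, so the failure is recorded in Snowflake as a side effect of constructing the exception, before any handler runs. Continuing the hypothetical df and checkpoint from the earlier sketch:

    from snowflake.snowpark_checkpoints import validate_dataframe_checkpoint
    from snowflake.snowpark_checkpoints.errors import SchemaValidationError

    try:
        validate_dataframe_checkpoint(df, "sales_checkpoint")
    except SchemaValidationError as err:
        # The failure row was already persisted by the exception's constructor
        # (when a job context was supplied); here we only report it.
        print(err)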
@@ -0,0 +1,85 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from datetime import datetime
+ from typing import Optional
+
+ import pandas as pd
+
+ from pyspark.sql import SparkSession
+
+ from snowflake.snowpark import Session
+ from snowflake.snowpark_checkpoints.utils.constants import SCHEMA_EXECUTION_MODE
+
+
+ class SnowparkJobContext:
+
+     """Class used to record migration results in Snowflake.
+
+     Args:
+         snowpark_session: A Snowpark session instance.
+         spark_session: A Spark session instance.
+         job_name: The name of the job.
+         log_results: Whether to log the migration results in Snowflake.
+
+     """
+
+     def __init__(
+         self,
+         snowpark_session: Session,
+         spark_session: SparkSession = None,
+         job_name: Optional[str] = None,
+         log_results: Optional[bool] = True,
+     ):
+         self.log_results = log_results
+         self.job_name = job_name
+         self.spark_session = spark_session or SparkSession.builder.getOrCreate()
+         self.snowpark_session = snowpark_session
+
+     def _mark_fail(
+         self, message, checkpoint_name, data, execution_mode=SCHEMA_EXECUTION_MODE
+     ):
+         if self.log_results:
+             session = self.snowpark_session
+             df = pd.DataFrame(
+                 {
+                     "DATE": [datetime.now()],
+                     "JOB": [self.job_name],
+                     "STATUS": ["fail"],
+                     "CHECKPOINT": [checkpoint_name],
+                     "MESSAGE": [message],
+                     "DATA": [f"{data}"],
+                     "EXECUTION_MODE": [execution_mode],
+                 }
+             )
+             report_df = session.createDataFrame(df)
+             report_df.write.mode("append").save_as_table("SNOWPARK_CHECKPOINTS_REPORT")
+
+     def _mark_pass(self, checkpoint_name, execution_mode=SCHEMA_EXECUTION_MODE):
+         if self.log_results:
+             session = self.snowpark_session
+             df = pd.DataFrame(
+                 {
+                     "DATE": [datetime.now()],
+                     "JOB": [self.job_name],
+                     "STATUS": ["pass"],
+                     "CHECKPOINT": [checkpoint_name],
+                     "MESSAGE": [""],
+                     "DATA": [""],
+                     "EXECUTION_MODE": [execution_mode],
+                 }
+             )
+             report_df = session.createDataFrame(df)
+             report_df.write.mode("append").save_as_table("SNOWPARK_CHECKPOINTS_REPORT")
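
Both _mark_pass and _mark_fail append a single row to the SNOWPARK_CHECKPOINTS_REPORT table when log_results is True, which is how per-job pass/fail history accumulates. A sketch of wiring a context through a validation, with an illustrative job name and the same hypothetical table as above:

    from pyspark.sql import SparkSession
    from snowflake.snowpark import Session
    from snowflake.snowpark_checkpoints import (
        SnowparkJobContext,
        validate_dataframe_checkpoint,
    )

    snowpark_session = Session.builder.getOrCreate()  # assumes configured credentials
    job_context = SnowparkJobContext(
        snowpark_session,
        SparkSession.builder.getOrCreate(),
        job_name="nightly_migration",  # illustrative
        log_results=True,  # pass/fail rows go to SNOWPARK_CHECKPOINTS_REPORT
    )

    df = snowpark_session.table("SALES")  # hypothetical table
    validate_dataframe_checkpoint(df, "sales_checkpoint", job_context=job_context)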