snowpark-checkpoints-validators 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_checkpoints/__init__.py +11 -1
- snowflake/snowpark_checkpoints/__version__.py +1 -1
- snowflake/snowpark_checkpoints/checkpoint.py +195 -97
- snowflake/snowpark_checkpoints/job_context.py +72 -29
- snowflake/snowpark_checkpoints/snowpark_sampler.py +26 -1
- snowflake/snowpark_checkpoints/spark_migration.py +39 -6
- snowflake/snowpark_checkpoints/utils/extra_config.py +10 -5
- snowflake/snowpark_checkpoints/utils/logging_utils.py +67 -0
- snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +48 -7
- snowflake/snowpark_checkpoints/utils/utils_checks.py +23 -2
- snowflake/snowpark_checkpoints/validation_result_metadata.py +30 -0
- {snowpark_checkpoints_validators-0.1.3.dist-info → snowpark_checkpoints_validators-0.2.0.dist-info}/METADATA +16 -4
- snowpark_checkpoints_validators-0.2.0.dist-info/RECORD +22 -0
- snowflake/snowpark_checkpoints/utils/checkpoint_logger.py +0 -52
- snowpark_checkpoints_validators-0.1.3.dist-info/RECORD +0 -22
- {snowpark_checkpoints_validators-0.1.3.dist-info → snowpark_checkpoints_validators-0.2.0.dist-info}/WHEEL +0 -0
- {snowpark_checkpoints_validators-0.1.3.dist-info → snowpark_checkpoints_validators-0.2.0.dist-info}/licenses/LICENSE +0 -0
snowflake/snowpark_checkpoints/__init__.py
@@ -13,16 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+
+
+# Add a NullHandler to prevent logging messages from being output to
+# sys.stderr if no logging configuration is provided.
+logging.getLogger(__name__).addHandler(logging.NullHandler())
+
+# ruff: noqa: E402
+
 from snowflake.snowpark_checkpoints.checkpoint import (
     check_dataframe_schema,
-    check_output_schema,
     check_input_schema,
+    check_output_schema,
     validate_dataframe_checkpoint,
 )
 from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
 from snowflake.snowpark_checkpoints.spark_migration import check_with_spark
 from snowflake.snowpark_checkpoints.utils.constants import CheckpointMode
 
+
 __all__ = [
     "check_with_spark",
     "SnowparkJobContext",
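The NullHandler added above follows the standard convention for library logging: the package stays silent unless the host application configures logging. A minimal sketch of how a consumer could opt in (the handler and level choices here are illustrative, not part of the package):

```python
import logging

# Route the library's log records somewhere visible; any standard handler works.
logging.basicConfig(level=logging.INFO)

# Adjust verbosity for this package alone, independent of the root logger.
logging.getLogger("snowflake.snowpark_checkpoints").setLevel(logging.DEBUG)
```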
snowflake/snowpark_checkpoints/checkpoint.py
@@ -14,11 +14,14 @@
 # limitations under the License.
 
 # Wrapper around pandera which logs to snowflake
-from typing import Any, Optional, Union
+
+import logging
+
+from typing import Any, Optional, Union, cast
 
 from pandas import DataFrame as PandasDataFrame
-from pandera import Check, DataFrameSchema
-from pandera_report import DataFrameValidator
+from pandera import Check, DataFrameModel, DataFrameSchema
+from pandera.errors import SchemaError, SchemaErrors
 
 from snowflake.snowpark import DataFrame as SnowparkDataFrame
 from snowflake.snowpark_checkpoints.errors import SchemaValidationError
@@ -27,13 +30,13 @@ from snowflake.snowpark_checkpoints.snowpark_sampler import (
     SamplingAdapter,
     SamplingStrategy,
 )
-from snowflake.snowpark_checkpoints.utils.checkpoint_logger import CheckpointLogger
 from snowflake.snowpark_checkpoints.utils.constants import (
     FAIL_STATUS,
     PASS_STATUS,
     CheckpointMode,
 )
 from snowflake.snowpark_checkpoints.utils.extra_config import is_checkpoint_enabled
+from snowflake.snowpark_checkpoints.utils.logging_utils import log
 from snowflake.snowpark_checkpoints.utils.pandera_check_manager import (
     PanderaCheckManager,
 )
@@ -47,6 +50,10 @@ from snowflake.snowpark_checkpoints.utils.utils_checks import (
 )
 
 
+LOGGER = logging.getLogger(__name__)
+
+
+@log
 def validate_dataframe_checkpoint(
     df: SnowparkDataFrame,
     checkpoint_name: str,
@@ -84,31 +91,45 @@ def validate_dataframe_checkpoint(
     """
     checkpoint_name = _replace_special_characters(checkpoint_name)
 
-    if is_checkpoint_enabled(checkpoint_name):
-    [… 20 removed lines (old 88-107) were not captured in this extract …]
+    if not is_checkpoint_enabled(checkpoint_name):
+        LOGGER.warning(
+            "Checkpoint '%s' is disabled. Skipping DataFrame checkpoint validation.",
+            checkpoint_name,
+        )
+        return None
+
+    LOGGER.info(
+        "Starting DataFrame checkpoint validation for checkpoint '%s'", checkpoint_name
+    )
+
+    if mode == CheckpointMode.SCHEMA:
+        result = _check_dataframe_schema_file(
+            df,
+            checkpoint_name,
+            job_context,
+            custom_checks,
+            skip_checks,
+            sample_frac,
+            sample_number,
+            sampling_strategy,
+            output_path,
+        )
+        return result
+
+    if mode == CheckpointMode.DATAFRAME:
+        if job_context is None:
             raise ValueError(
-                "
-Please use for schema validation use a 1 or for a full data validation use a 2 for schema validation."""
+                "No job context provided. Please provide one when using DataFrame mode validation."
             )
+        _check_compare_data(df, job_context, checkpoint_name, output_path)
+        return None
+
+    raise ValueError(
+        (
+            "Invalid validation mode. "
+            "Please use 1 for schema validation or 2 for full data validation."
+        ),
+    )
 
 
 def _check_dataframe_schema_file(
@@ -156,7 +177,7 @@ def _check_dataframe_schema_file(
 
     schema = _generate_schema(checkpoint_name, output_path)
 
-    return
+    return _check_dataframe_schema(
         df,
         schema,
         checkpoint_name,
@@ -170,6 +191,7 @@ def _check_dataframe_schema_file(
     )
 
 
+@log
 def check_dataframe_schema(
     df: SnowparkDataFrame,
     pandera_schema: DataFrameSchema,
@@ -212,6 +234,9 @@ def check_dataframe_schema(
 
     """
     checkpoint_name = _replace_special_characters(checkpoint_name)
+    LOGGER.info(
+        "Starting DataFrame schema validation for checkpoint '%s'", checkpoint_name
+    )
 
     if df is None:
         raise ValueError("DataFrame is required")
@@ -219,19 +244,25 @@ def check_dataframe_schema(
     if pandera_schema is None:
         raise ValueError("Schema is required")
 
-    if is_checkpoint_enabled(checkpoint_name):
-        return _check_dataframe_schema(
-            df,
-            pandera_schema,
+    if not is_checkpoint_enabled(checkpoint_name):
+        LOGGER.warning(
+            "Checkpoint '%s' is disabled. Skipping DataFrame schema validation.",
             checkpoint_name,
-            job_context,
-            custom_checks,
-            skip_checks,
-            sample_frac,
-            sample_number,
-            sampling_strategy,
-            output_path,
         )
+        return None
+
+    return _check_dataframe_schema(
+        df,
+        pandera_schema,
+        checkpoint_name,
+        job_context,
+        custom_checks,
+        skip_checks,
+        sample_frac,
+        sample_number,
+        sampling_strategy,
+        output_path,
+    )
 
 
 @report_telemetry(
@@ -259,17 +290,24 @@ def _check_dataframe_schema(
     pandera_schema_upper, sample_df = _process_sampling(
         df, pandera_schema, job_context, sample_frac, sample_number, sampling_strategy
    )
-
-    # Raises SchemaError on validation issues
-    validator = DataFrameValidator()
-    is_valid, validation_result = validator.validate(
-        pandera_schema_upper, sample_df, validity_flag=True
-    )
+    is_valid, validation_result = _validate(pandera_schema_upper, sample_df)
     if is_valid:
+        LOGGER.info(
+            "DataFrame schema validation passed for checkpoint '%s'",
+            checkpoint_name,
+        )
         if job_context is not None:
             job_context._mark_pass(checkpoint_name)
+        else:
+            LOGGER.warning(
+                "No job context provided. Skipping result recording into Snowflake.",
+            )
         _update_validation_result(checkpoint_name, PASS_STATUS, output_path)
     else:
+        LOGGER.error(
+            "DataFrame schema validation failed for checkpoint '%s'",
+            checkpoint_name,
+        )
        _update_validation_result(checkpoint_name, FAIL_STATUS, output_path)
        raise SchemaValidationError(
            "Snowpark DataFrame schema validation error",
@@ -282,6 +320,7 @@ def _check_dataframe_schema(
 
 
 @report_telemetry(params_list=["pandera_schema"])
+@log
 def check_output_schema(
     pandera_schema: DataFrameSchema,
     checkpoint_name: str,
@@ -318,11 +357,8 @@ def check_output_schema(
             function: The decorated function.
 
         """
-        _checkpoint_name = checkpoint_name
-        if checkpoint_name is None:
-            _checkpoint_name = snowpark_fn.__name__
-        _checkpoint_name = _replace_special_characters(_checkpoint_name)
 
+        @log(log_args=False)
         def wrapper(*args, **kwargs):
             """Wrapp a function to validate the schema of the output of a Snowpark function.
 
@@ -334,7 +370,25 @@ def check_output_schema(
             Any: The result of the Snowpark function.
 
             """
+            _checkpoint_name = checkpoint_name
+            if checkpoint_name is None:
+                LOGGER.warning(
+                    (
+                        "No checkpoint name provided for output schema validation. "
+                        "Using '%s' as the checkpoint name."
+                    ),
+                    snowpark_fn.__name__,
+                )
+                _checkpoint_name = snowpark_fn.__name__
+            _checkpoint_name = _replace_special_characters(_checkpoint_name)
+            LOGGER.info(
+                "Starting output schema validation for Snowpark function '%s' and checkpoint '%s'",
+                snowpark_fn.__name__,
+                _checkpoint_name,
+            )
+
             # Run the sampled data in snowpark
+            LOGGER.info("Running the Snowpark function '%s'", snowpark_fn.__name__)
             snowpark_results = snowpark_fn(*args, **kwargs)
             sampler = SamplingAdapter(
                 job_context, sample_frac, sample_number, sampling_strategy
@@ -342,22 +396,28 @@ def check_output_schema(
             sampler.process_args([snowpark_results])
             pandas_sample_args = sampler.get_sampled_pandas_args()
 
-            # Raises SchemaError on validation issues
-            validator = DataFrameValidator()
-            is_valid, validation_result = validator.validate(
-                pandera_schema, pandas_sample_args[0], validity_flag=True
+            is_valid, validation_result = _validate(
+                pandera_schema, pandas_sample_args[0]
             )
-            logger = CheckpointLogger().get_logger()
-            logger.info(
-                f"Checkpoint {_checkpoint_name} validation result:\n{validation_result}"
-            )
-
             if is_valid:
+                LOGGER.info(
+                    "Output schema validation passed for Snowpark function '%s' and checkpoint '%s'",
+                    snowpark_fn.__name__,
+                    _checkpoint_name,
+                )
                 if job_context is not None:
                     job_context._mark_pass(_checkpoint_name)
-
+                else:
+                    LOGGER.warning(
+                        "No job context provided. Skipping result recording into Snowflake.",
+                    )
                 _update_validation_result(_checkpoint_name, PASS_STATUS, output_path)
             else:
+                LOGGER.error(
+                    "Output schema validation failed for Snowpark function '%s' and checkpoint '%s'",
+                    snowpark_fn.__name__,
+                    _checkpoint_name,
+                )
                 _update_validation_result(_checkpoint_name, FAIL_STATUS, output_path)
                 raise SchemaValidationError(
                     "Snowpark output schema validation error",
@@ -365,7 +425,6 @@ def check_output_schema(
                     _checkpoint_name,
                     validation_result,
                 )
-
             return snowpark_results
 
         return wrapper
@@ -374,6 +433,7 @@ def check_output_schema(
 
 
 @report_telemetry(params_list=["pandera_schema"])
+@log
 def check_input_schema(
     pandera_schema: DataFrameSchema,
     checkpoint_name: str,
@@ -414,11 +474,8 @@ def check_input_schema(
         Callable: A wrapper function that performs schema validation before executing the original function.
 
         """
-        _checkpoint_name = checkpoint_name
-        if checkpoint_name is None:
-            _checkpoint_name = snowpark_fn.__name__
-        _checkpoint_name = _replace_special_characters(_checkpoint_name)
 
+        @log(log_args=False)
         def wrapper(*args, **kwargs):
             """Wrapp a function to validate the schema of the input of a Snowpark function.
 
@@ -429,6 +486,23 @@ def check_input_schema(
             Any: The result of the original function after input validation.
 
             """
+            _checkpoint_name = checkpoint_name
+            if checkpoint_name is None:
+                LOGGER.warning(
+                    (
+                        "No checkpoint name provided for input schema validation. "
+                        "Using '%s' as the checkpoint name."
+                    ),
+                    snowpark_fn.__name__,
+                )
+                _checkpoint_name = snowpark_fn.__name__
+            _checkpoint_name = _replace_special_characters(_checkpoint_name)
+            LOGGER.info(
+                "Starting input schema validation for Snowpark function '%s' and checkpoint '%s'",
+                snowpark_fn.__name__,
+                _checkpoint_name,
+            )
+
             # Run the sampled data in snowpark
             sampler = SamplingAdapter(
                 job_context, sample_frac, sample_number, sampling_strategy
@@ -436,47 +510,71 @@ def check_input_schema(
             sampler.process_args(args)
             pandas_sample_args = sampler.get_sampled_pandas_args()
 
+            LOGGER.info(
+                "Validating %s input argument(s) against a Pandera schema",
+                len(pandas_sample_args),
+            )
             # Raises SchemaError on validation issues
-            for arg in pandas_sample_args:
-                if isinstance(arg, PandasDataFrame):
-
-                    validator = DataFrameValidator()
-                    is_valid, validation_result = validator.validate(
-                        pandera_schema,
-                        arg,
-                        validity_flag=True,
+            for index, arg in enumerate(pandas_sample_args, start=1):
+                if not isinstance(arg, PandasDataFrame):
+                    LOGGER.info(
+                        "Arg %s: Skipping schema validation for non-DataFrame argument",
+                        index,
                     )
+                    continue
 
-                    logger = CheckpointLogger().get_logger()
-                    logger.info(
-                        f"Checkpoint {_checkpoint_name} validation result:\n{validation_result}"
+                is_valid, validation_result = _validate(
+                    pandera_schema,
+                    arg,
+                )
+                if is_valid:
+                    LOGGER.info(
+                        "Arg %s: Input schema validation passed",
+                        index,
                     )
-
-                    if is_valid:
-                        if job_context is not None:
-                            job_context._mark_pass(
-                                _checkpoint_name,
-                            )
-
-                            _update_validation_result(
-                                _checkpoint_name,
-                                PASS_STATUS,
-                                output_path,
-                            )
-                        else:
-                            _update_validation_result(
-                                _checkpoint_name,
-                                FAIL_STATUS,
-                                output_path,
-                            )
-                            raise SchemaValidationError(
-                                "Snowpark input schema validation error",
-                                job_context,
+                    if job_context is not None:
+                        job_context._mark_pass(
                             _checkpoint_name,
-                            validation_result,
                         )
+                    _update_validation_result(
+                        _checkpoint_name,
+                        PASS_STATUS,
+                        output_path,
+                    )
+                else:
+                    LOGGER.error(
+                        "Arg %s: Input schema validation failed",
+                        index,
+                    )
+                    _update_validation_result(
+                        _checkpoint_name,
+                        FAIL_STATUS,
+                        output_path,
+                    )
+                    raise SchemaValidationError(
+                        "Snowpark input schema validation error",
+                        job_context,
+                        _checkpoint_name,
+                        validation_result,
+                    )
             return snowpark_fn(*args, **kwargs)
 
         return wrapper
 
     return check_input_with_decorator
+
+
+def _validate(
+    schema: Union[type[DataFrameModel], DataFrameSchema],
+    df: PandasDataFrame,
+    lazy: bool = True,
+) -> tuple[bool, PandasDataFrame]:
+    if not isinstance(schema, DataFrameSchema):
+        schema = schema.to_schema()
+    is_valid = True
+    try:
+        df = schema.validate(df, lazy=lazy)
+    except (SchemaErrors, SchemaError) as schema_errors:
+        df = cast(PandasDataFrame, schema_errors.failure_cases)
+        is_valid = False
+    return is_valid, df
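The new `_validate` helper centralizes lazy Pandera validation with failure-case capture, replacing the removed pandera-report `DataFrameValidator`. The same pattern in isolation (schema and data below are made up for illustration):

```python
import pandas as pd
import pandera as pa
from pandera.errors import SchemaError, SchemaErrors

schema = pa.DataFrameSchema({"AGE": pa.Column(int, pa.Check.ge(0))})
df = pd.DataFrame({"AGE": [25, -3]})

try:
    # lazy=True collects every failure instead of raising on the first one.
    schema.validate(df, lazy=True)
except (SchemaErrors, SchemaError) as exc:
    # failure_cases is itself a DataFrame describing each failed check,
    # which is what _validate returns in place of the validated frame.
    print(exc.failure_cases)
```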
snowflake/snowpark_checkpoints/job_context.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+
 from datetime import datetime
 from typing import Optional
 
@@ -24,6 +26,10 @@ from snowflake.snowpark import Session
 from snowflake.snowpark_checkpoints.utils.constants import SCHEMA_EXECUTION_MODE
 
 
+LOGGER = logging.getLogger(__name__)
+RESULTS_TABLE = "SNOWPARK_CHECKPOINTS_REPORT"
+
+
 class SnowparkJobContext:
 
     """Class used to record migration results in Snowflake.
@@ -45,41 +51,78 @@ class SnowparkJobContext:
     ):
         self.log_results = log_results
         self.job_name = job_name
-        self.spark_session = spark_session or SparkSession.builder.getOrCreate()
+        self.spark_session = spark_session or self._create_pyspark_session()
         self.snowpark_session = snowpark_session
 
     def _mark_fail(
         self, message, checkpoint_name, data, execution_mode=SCHEMA_EXECUTION_MODE
     ):
-        if self.log_results:
-            session = self.snowpark_session
-            df = pd.DataFrame(
-                {
-                    "DATE": [datetime.now()],
-                    "JOB": [self.job_name],
-                    "STATUS": ["fail"],
-                    "CHECKPOINT": [checkpoint_name],
-                    "MESSAGE": [message],
-                    "DATA": [f"{data}"],
-                    "EXECUTION_MODE": [execution_mode],
-                }
+        if not self.log_results:
+            LOGGER.warning(
+                (
+                    "Recording of migration results into Snowflake is disabled. "
+                    "Failure result for checkpoint '%s' will not be recorded."
+                ),
+                checkpoint_name,
             )
-            report_df = session.createDataFrame(df)
-            report_df.write.mode("append").save_as_table("SNOWPARK_CHECKPOINTS_REPORT")
+            return
+
+        LOGGER.debug(
+            "Marking failure for checkpoint '%s' in '%s' mode with message '%s'",
+            checkpoint_name,
+            execution_mode,
+            message,
+        )
+
+        session = self.snowpark_session
+        df = pd.DataFrame(
+            {
+                "DATE": [datetime.now()],
+                "JOB": [self.job_name],
+                "STATUS": ["fail"],
+                "CHECKPOINT": [checkpoint_name],
+                "MESSAGE": [message],
+                "DATA": [f"{data}"],
+                "EXECUTION_MODE": [execution_mode],
+            }
+        )
+        report_df = session.createDataFrame(df)
+        LOGGER.info("Writing failure result to table: '%s'", RESULTS_TABLE)
+        report_df.write.mode("append").save_as_table(RESULTS_TABLE)
 
     def _mark_pass(self, checkpoint_name, execution_mode=SCHEMA_EXECUTION_MODE):
-        if self.log_results:
-            session = self.snowpark_session
-            df = pd.DataFrame(
-                {
-                    "DATE": [datetime.now()],
-                    "JOB": [self.job_name],
-                    "STATUS": ["pass"],
-                    "CHECKPOINT": [checkpoint_name],
-                    "MESSAGE": [""],
-                    "DATA": [""],
-                    "EXECUTION_MODE": [execution_mode],
-                }
+        if not self.log_results:
+            LOGGER.warning(
+                (
+                    "Recording of migration results into Snowflake is disabled. "
+                    "Pass result for checkpoint '%s' will not be recorded."
+                ),
+                checkpoint_name,
             )
-            report_df = session.createDataFrame(df)
-            report_df.write.mode("append").save_as_table("SNOWPARK_CHECKPOINTS_REPORT")
+            return
+
+        LOGGER.debug(
+            "Marking pass for checkpoint '%s' in '%s' mode",
+            checkpoint_name,
+            execution_mode,
+        )
+
+        session = self.snowpark_session
+        df = pd.DataFrame(
+            {
+                "DATE": [datetime.now()],
+                "JOB": [self.job_name],
+                "STATUS": ["pass"],
+                "CHECKPOINT": [checkpoint_name],
+                "MESSAGE": [""],
+                "DATA": [""],
+                "EXECUTION_MODE": [execution_mode],
+            }
+        )
+        report_df = session.createDataFrame(df)
+        LOGGER.info("Writing pass result to table: '%s'", RESULTS_TABLE)
+        report_df.write.mode("append").save_as_table(RESULTS_TABLE)
+
+    def _create_pyspark_session(self) -> SparkSession:
+        LOGGER.info("Creating a PySpark session")
+        return SparkSession.builder.getOrCreate()
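The reporting pattern used by `_mark_pass`/`_mark_fail` is a plain append into the `SNOWPARK_CHECKPOINTS_REPORT` table. Shown standalone (job and checkpoint values are placeholders, and `session` is assumed to be an active `snowflake.snowpark.Session`):

```python
from datetime import datetime

import pandas as pd

row = pd.DataFrame(
    {
        "DATE": [datetime.now()],
        "JOB": ["demo_job"],              # placeholder values
        "STATUS": ["pass"],
        "CHECKPOINT": ["demo_checkpoint"],
        "MESSAGE": [""],
        "DATA": [""],
        "EXECUTION_MODE": ["Schema"],     # the package uses its SCHEMA_EXECUTION_MODE constant
    }
)
# `session` is an assumed, already-connected Snowpark session.
report_df = session.createDataFrame(row)
report_df.write.mode("append").save_as_table("SNOWPARK_CHECKPOINTS_REPORT")
```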
snowflake/snowpark_checkpoints/snowpark_sampler.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+
 from typing import Optional
 
 import pandas
@@ -21,6 +23,9 @@ from snowflake.snowpark import DataFrame as SnowparkDataFrame
 from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
 
 
+LOGGER = logging.getLogger(__name__)
+
+
 class SamplingStrategy:
     RANDOM_SAMPLE = 1
     LIMIT = 2
@@ -52,23 +57,43 @@ class SamplingAdapter:
     def process_args(self, input_args):
         # create the intermediate pandas
         # data frame for the test data
+        LOGGER.info("Processing %s input argument(s) for sampling", len(input_args))
         for arg in input_args:
             if isinstance(arg, SnowparkDataFrame):
-                if arg.count() == 0:
+                df_count = arg.count()
+                if df_count == 0:
                     raise SamplingError(
                         "Input DataFrame is empty. Cannot sample from an empty DataFrame."
                     )
 
+                LOGGER.info("Sampling a Snowpark DataFrame with %s rows", df_count)
                 if self.sampling_strategy == SamplingStrategy.RANDOM_SAMPLE:
                     if self.sample_frac:
+                        LOGGER.info(
+                            "Applying random sampling with fraction %s",
+                            self.sample_frac,
+                        )
                         df_sample = arg.sample(frac=self.sample_frac).to_pandas()
                     else:
+                        LOGGER.info(
+                            "Applying random sampling with size %s", self.sample_number
+                        )
                         df_sample = arg.sample(n=self.sample_number).to_pandas()
                 else:
+                    LOGGER.info(
+                        "Applying limit sampling with size %s", self.sample_number
+                    )
                     df_sample = arg.limit(self.sample_number).to_pandas()
 
+                LOGGER.info(
+                    "Successfully sampled the DataFrame. Resulting DataFrame shape: %s",
+                    df_sample.shape,
+                )
                 self.pandas_sample_args.append(df_sample)
             else:
+                LOGGER.debug(
+                    "Argument is not a Snowpark DataFrame. No sampling is applied."
+                )
                 self.pandas_sample_args.append(arg)
 
     def get_sampled_pandas_args(self):
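For reference, the three sampling paths `process_args` chooses between map directly onto Snowpark DataFrame calls (assuming `df` is a `snowflake.snowpark.DataFrame`):

```python
# RANDOM_SAMPLE with a fraction of rows:
pandas_sample = df.sample(frac=0.1).to_pandas()

# RANDOM_SAMPLE with a fixed row count:
pandas_sample = df.sample(n=100).to_pandas()

# LIMIT: deterministic first-N rows rather than a random sample.
pandas_sample = df.limit(100).to_pandas()
```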
snowflake/snowpark_checkpoints/spark_migration.py
@@ -12,6 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import logging
+
 from typing import Callable, Optional, TypeVar
 
 import pandas as pd
@@ -27,6 +30,7 @@ from snowflake.snowpark_checkpoints.snowpark_sampler import (
     SamplingStrategy,
 )
 from snowflake.snowpark_checkpoints.utils.constants import FAIL_STATUS, PASS_STATUS
+from snowflake.snowpark_checkpoints.utils.logging_utils import log
 from snowflake.snowpark_checkpoints.utils.telemetry import STATUS_KEY, report_telemetry
 from snowflake.snowpark_checkpoints.utils.utils_checks import (
     _replace_special_characters,
@@ -35,8 +39,10 @@ from snowflake.snowpark_checkpoints.utils.utils_checks import (
 
 
 fn = TypeVar("F", bound=Callable)
+LOGGER = logging.getLogger(__name__)
 
 
+@log
 def check_with_spark(
     job_context: Optional[SnowparkJobContext],
     spark_function: fn,
@@ -67,12 +73,22 @@ def check_with_spark(
     """
 
     def check_with_spark_decorator(snowpark_fn):
-        _checkpoint_name = checkpoint_name
-        if checkpoint_name is None:
-            _checkpoint_name = snowpark_fn.__name__
-        _checkpoint_name = _replace_special_characters(_checkpoint_name)
-
+        @log(log_args=False)
         def wrapper(*args, **kwargs):
+            LOGGER.info(
+                "Starting output validation between Snowpark function '%s' and Spark function '%s'",
+                snowpark_fn.__name__,
+                spark_function.__name__,
+            )
+            _checkpoint_name = checkpoint_name
+            if checkpoint_name is None:
+                LOGGER.warning(
+                    "No checkpoint name provided. Using '%s' as the checkpoint name",
+                    snowpark_fn.__name__,
+                )
+                _checkpoint_name = snowpark_fn.__name__
+            _checkpoint_name = _replace_special_characters(_checkpoint_name)
+
             sampler = SamplingAdapter(
                 job_context,
                 sample_number=sample_number,
@@ -81,9 +97,14 @@ def check_with_spark(
             sampler.process_args(args)
             snowpark_sample_args = sampler.get_sampled_snowpark_args()
             pyspark_sample_args = sampler.get_sampled_spark_args()
+
             # Run the sampled data in snowpark
+            LOGGER.info("Running the Snowpark function with sampled args")
             snowpark_test_results = snowpark_fn(*snowpark_sample_args, **kwargs)
+            LOGGER.info("Running the Spark function with sampled args")
             spark_test_results = spark_function(*pyspark_sample_args, **kwargs)
+
+            LOGGER.info("Comparing the results of the Snowpark and Spark functions")
             result, exception = _assert_return(
                 snowpark_test_results,
                 spark_test_results,
@@ -92,7 +113,18 @@ def check_with_spark(
                 output_path,
             )
             if not result:
+                LOGGER.error(
+                    "Validation failed. The results of the Snowpark function '%s' and Spark function '%s' do not match",
+                    snowpark_fn.__name__,
+                    spark_function.__name__,
+                )
                 raise exception from None
+            LOGGER.info(
+                "Validation passed. The results of the Snowpark function '%s' and Spark function '%s' match",
+                snowpark_fn.__name__,
+                spark_function.__name__,
+            )
+
             # Run the original function in snowpark
             return snowpark_fn(*args, **kwargs)
 
@@ -126,6 +158,7 @@ def _assert_return(
     if isinstance(snowpark_results, SnowparkDataFrame) and isinstance(
         spark_results, SparkDataFrame
     ):
+        LOGGER.debug("Comparing two DataFrame results for equality")
         cmp = compare_spark_snowpark_dfs(spark_results, snowpark_results)
 
         if not cmp.empty:
@@ -137,7 +170,7 @@ def _assert_return(
         _update_validation_result(checkpoint_name, PASS_STATUS, output_path)
         return True, None
     else:
-
+        LOGGER.debug("Comparing two scalar results for equality")
        if snowpark_results != spark_results:
            exception_result = SparkMigrationError(
                "Return value difference:\n",
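Typical usage of the decorator whose internals changed above, sketched under the assumption that a configured `job_context` and a reference PySpark implementation already exist in the caller's code:

```python
from snowflake.snowpark_checkpoints import check_with_spark


def spark_original(df):
    # Reference PySpark implementation being migrated (illustrative logic).
    return df.filter("AGE >= 0")


@check_with_spark(job_context=job_context, spark_function=spark_original)
def snowpark_rewrite(df):
    # Snowpark rewrite whose sampled output is compared against Spark's.
    return df.filter("AGE >= 0")
```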
snowflake/snowpark_checkpoints/utils/extra_config.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import os
 
 from typing import Optional
@@ -22,6 +23,9 @@ from snowflake.snowpark_checkpoints.utils.constants import (
 )
 
 
+LOGGER = logging.getLogger(__name__)
+
+
 # noinspection DuplicatedCode
 def _get_checkpoint_contract_file_path() -> str:
     return os.environ.get(SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR, os.getcwd())
@@ -35,10 +39,14 @@ def _get_metadata():
         )
 
         path = _get_checkpoint_contract_file_path()
+        LOGGER.debug("Loading checkpoint metadata from '%s'", path)
         metadata = CheckpointMetadata(path)
         return True, metadata
 
     except ImportError:
+        LOGGER.debug(
+            "snowpark-checkpoints-configuration is not installed. Cannot get a checkpoint metadata instance."
+        )
         return False, None
 
 
@@ -56,8 +64,7 @@ def is_checkpoint_enabled(checkpoint_name: Optional[str] = None) -> bool:
     if enabled and checkpoint_name is not None:
         config = metadata.get_checkpoint(checkpoint_name)
         return config.enabled
-
-    return True
+    return True
 
 
 def get_checkpoint_file(checkpoint_name: str) -> Optional[str]:
@@ -78,7 +85,5 @@ def get_checkpoint_file(checkpoint_name: str) -> Optional[str]:
     enabled, metadata = _get_metadata()
     if enabled:
         config = metadata.get_checkpoint(checkpoint_name)
-
         return config.file
-
-    return None
+    return None
snowflake/snowpark_checkpoints/utils/logging_utils.py (new file)
@@ -0,0 +1,67 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from functools import wraps
+from typing import Callable, Optional, TypeVar
+
+from typing_extensions import ParamSpec
+
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def log(
+    _func: Optional[Callable[P, R]] = None,
+    *,
+    logger: Optional[logging.Logger] = None,
+    log_args: bool = True,
+) -> Callable[[Callable[P, R]], Callable[P, R]]:
+    """Log the function call and any exceptions that occur.
+
+    Args:
+        _func: The function to log.
+        logger: The logger to use for logging. If not provided, a logger will be created using the
+            function's module name.
+        log_args: Whether to log the arguments passed to the function.
+
+    Returns:
+        A decorator that logs the function call and any exceptions that occur.
+
+    """
+
+    def decorator(func: Callable[P, R]) -> Callable[P, R]:
+        @wraps(func)
+        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            _logger = logging.getLogger(func.__module__) if logger is None else logger
+            if log_args:
+                args_repr = [repr(a) for a in args]
+                kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]
+                formatted_args = ", ".join([*args_repr, *kwargs_repr])
+                _logger.debug("%s called with args %s", func.__name__, formatted_args)
+            try:
+                return func(*args, **kwargs)
+            except Exception:
+                _logger.exception("An error occurred in %s", func.__name__)
+                raise
+
+        return wrapper
+
+    # Handle the case where the decorator is used without parentheses
+    if _func is None:
+        return decorator
+    return decorator(_func)
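Both forms the new decorator supports, as a quick usage sketch:

```python
import logging

from snowflake.snowpark_checkpoints.utils.logging_utils import log

logging.basicConfig(level=logging.DEBUG)


@log  # bare form: DEBUG-logs the call and any exception
def add(a, b):
    return a + b


@log(log_args=False)  # parenthesized form: suppress argument logging
def handle_secret(token):
    return len(token)


add(1, b=2)  # emits: "add called with args 1, b=2"
```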
snowflake/snowpark_checkpoints/utils/pandera_check_manager.py
@@ -1,9 +1,25 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
 from datetime import datetime
 from typing import Optional
 
 from pandera import Check, DataFrameSchema
 
-from snowflake.snowpark_checkpoints.utils.checkpoint_logger import CheckpointLogger
 from snowflake.snowpark_checkpoints.utils.constants import (
     COLUMNS_KEY,
     DECIMAL_PRECISION_KEY,
@@ -28,6 +44,9 @@ from snowflake.snowpark_checkpoints.utils.supported_types import (
 )
 
 
+LOGGER = logging.getLogger(__name__)
+
+
 class PanderaCheckManager:
     def __init__(self, checkpoint_name: str, schema: DataFrameSchema):
         self.checkpoint_name = checkpoint_name
@@ -258,25 +277,28 @@ class PanderaCheckManager:
             ValueError: If the column name or type is not defined in the schema.
 
         """
-
-        for additional_check in custom_data.get(COLUMNS_KEY):
+        LOGGER.info("Adding checks for the checkpoint '%s'", self.checkpoint_name)
 
-            type = additional_check.get(TYPE_KEY, None)
+        for additional_check in custom_data.get(COLUMNS_KEY):
             name = additional_check.get(NAME_KEY, None)
-            is_nullable = additional_check.get(NULLABLE_KEY, False)
-
             if name is None:
                 raise ValueError(
                     f"Column name not defined in the schema {self.checkpoint_name}"
                 )
 
+            type = additional_check.get(TYPE_KEY, None)
             if type is None:
                 raise ValueError(f"Type not defined for column {name}")
 
             if self.schema.columns.get(name) is None:
-
+                LOGGER.warning(
+                    "Column '%s' was not found in the Pandera schema. Skipping checks for this column.",
+                    name,
+                )
                 continue
 
+            LOGGER.debug("Adding checks for column '%s' of type '%s'", name, type)
+
             if type in NumericTypes:
                 self._add_numeric_checks(name, additional_check)
 
@@ -289,7 +311,9 @@ class PanderaCheckManager:
             elif type == "datetime":
                 self._add_date_time_checks(name, additional_check)
 
+            is_nullable = additional_check.get(NULLABLE_KEY, False)
             if is_nullable:
+                LOGGER.debug("Column '%s' is nullable. Adding null checks.", name)
                 self._add_null_checks(name, additional_check)
 
         return self.schema
@@ -318,8 +342,19 @@ class PanderaCheckManager:
             if col in self.schema.columns:
 
                 if SKIP_ALL in checks_to_skip:
+                    LOGGER.info(
+                        "Skipping all checks for column '%s' in checkpoint '%s'",
+                        col,
+                        self.checkpoint_name,
+                    )
                     self.schema.columns[col].checks = {}
                 else:
+                    LOGGER.info(
+                        "Skipping checks %s for column '%s' in checkpoint '%s'",
+                        checks_to_skip,
+                        col,
+                        self.checkpoint_name,
+                    )
                     self.schema.columns[col].checks = [
                         check
                         for check in self.schema.columns[col].checks
@@ -350,6 +385,12 @@ class PanderaCheckManager:
         for col, checks in custom_checks.items():
 
             if col in self.schema.columns:
+                LOGGER.info(
+                    "Adding %s custom checks to column '%s' in checkpoint '%s'",
+                    len(checks),
+                    col,
+                    self.checkpoint_name,
+                )
                 col_schema = self.schema.columns[col]
                 col_schema.checks.extend(checks)
             else:
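The `custom_checks` and `skip_checks` structures these methods consume are per-column mappings. A sketch of their shapes (column names and check names are illustrative, inferred from the method bodies above rather than stated docs):

```python
import pandera as pa

# Extra Pandera checks appended to a column's existing checks.
custom_checks = {
    "COLUMN1": [pa.Check.between(0, 100), pa.Check.not_equal_to(-1)],
}

# Names of collected checks to drop from a column before validation.
skip_checks = {
    "COLUMN1": ["greater_than", "less_than"],
}
```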
snowflake/snowpark_checkpoints/utils/utils_checks.py
@@ -15,6 +15,7 @@
 
 import inspect
 import json
+import logging
 import os
 import re
 
@@ -58,6 +59,9 @@ from snowflake.snowpark_checkpoints.validation_result_metadata import (
 from snowflake.snowpark_checkpoints.validation_results import ValidationResult
 
 
+LOGGER = logging.getLogger(__name__)
+
+
 def _replace_special_characters(checkpoint_name: str) -> str:
     """Replace special characters in the checkpoint name with underscores.
 
@@ -147,6 +151,9 @@ def _generate_schema(
             constraints of the DataFrame.
 
     """
+    LOGGER.info(
+        "Generating Pandera DataFrameSchema for checkpoint: '%s'", checkpoint_name
+    )
     current_directory_path = output_path if output_path else os.getcwd()
 
     output_directory_path = os.path.join(
@@ -169,6 +176,7 @@ Please run the Snowpark checkpoint collector first."""
             f"Checkpoint {checkpoint_name} JSON file not found. Please run the Snowpark checkpoint collector first."
         )
 
+    LOGGER.info("Reading schema from file: '%s'", checkpoint_schema_file_path)
     with open(checkpoint_schema_file_path) as schema_file:
         checkpoint_schema_config = json.load(schema_file)
 
@@ -182,6 +190,10 @@ Please run the Snowpark checkpoint collector first."""
     schema = DataFrameSchema.from_json(schema_dict_str)
 
     if DATAFRAME_CUSTOM_DATA_KEY not in checkpoint_schema_config:
+        LOGGER.info(
+            "No custom data found in the JSON file for checkpoint: '%s'",
+            checkpoint_name,
+        )
         return schema
 
     custom_data = checkpoint_schema_config.get(DATAFRAME_CUSTOM_DATA_KEY)
@@ -221,7 +233,7 @@ def _check_compare_data(
         SchemaValidationError: If there is a data mismatch between the DataFrame and the checkpoint table.
 
     """
-    [… original line not captured in this extract …]
+    _, err = _compare_data(df, job_context, checkpoint_name, output_path)
     if err is not None:
         raise err
 
@@ -256,9 +268,18 @@ def _compare_data(
 
     """
     new_table_name = CHECKPOINT_TABLE_NAME_FORMAT.format(checkpoint_name)
-
+    LOGGER.info(
+        "Writing Snowpark DataFrame to table: '%s' for checkpoint: '%s'",
+        new_table_name,
+        checkpoint_name,
+    )
     df.write.save_as_table(table_name=new_table_name, mode="overwrite")
 
+    LOGGER.info(
+        "Comparing DataFrame to checkpoint table: '%s' for checkpoint: '%s'",
+        new_table_name,
+        checkpoint_name,
+    )
     expect_df = job_context.snowpark_session.sql(
         EXCEPT_HASH_AGG_QUERY, [checkpoint_name, new_table_name]
     )
snowflake/snowpark_checkpoints/validation_result_metadata.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import os
 
 from typing import Optional
@@ -28,6 +29,9 @@ from snowflake.snowpark_checkpoints.validation_results import (
 )
 
 
+LOGGER = logging.getLogger(__name__)
+
+
 class ValidationResultsMetadata(metaclass=Singleton):
 
     """ValidationResultsMetadata is a class that manages the loading, storing, and updating of validation results.
@@ -69,14 +73,26 @@ class ValidationResultsMetadata(metaclass=Singleton):
             SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME,
         )
 
+        LOGGER.debug(
+            "Setting validation results directory to: '%s'",
+            self.validation_results_directory,
+        )
+
         self.validation_results_file = os.path.join(
             self.validation_results_directory,
             VALIDATION_RESULTS_JSON_FILE_NAME,
        )
 
+        LOGGER.debug(
+            "Setting validation results file to: '%s'", self.validation_results_file
+        )
+
         self.validation_results = ValidationResults(results=[])
 
         if os.path.exists(self.validation_results_file):
+            LOGGER.info(
+                "Loading validation results from: '%s'", self.validation_results_file
+            )
             with open(self.validation_results_file) as file:
                 try:
                     validation_result_json = file.read()
@@ -87,6 +103,11 @@ class ValidationResultsMetadata(metaclass=Singleton):
                     raise Exception(
                         f"Error reading validation results file: {self.validation_results_file} \n {e}"
                     ) from None
+        else:
+            LOGGER.info(
+                "Validation results file not found: '%s'",
+                self.validation_results_file,
+            )
 
     def clean(self):
         """Clean the validation results list.
@@ -95,6 +116,7 @@ class ValidationResultsMetadata(metaclass=Singleton):
 
         """
         if not os.path.exists(self.validation_results_file):
+            LOGGER.info("Cleaning validation results...")
             self.validation_results.results = []
 
     def add_validation_result(self, validation_result: ValidationResult):
@@ -119,7 +141,15 @@ class ValidationResultsMetadata(metaclass=Singleton):
 
         """
         if not os.path.exists(self.validation_results_directory):
+            LOGGER.debug(
+                "Validation results directory '%s' does not exist. Creating it...",
+                self.validation_results_directory,
+            )
             os.makedirs(self.validation_results_directory)
 
         with open(self.validation_results_file, "w") as output_file:
             output_file.write(self.validation_results.model_dump_json())
+            LOGGER.info(
+                "Validation results successfully saved to: '%s'",
+                self.validation_results_file,
+            )
{snowpark_checkpoints_validators-0.1.3.dist-info → snowpark_checkpoints_validators-0.2.0.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: snowpark-checkpoints-validators
-Version: 0.1.3
+Version: 0.2.0
 Summary: Migration tools for Snowpark
 Project-URL: Bug Tracker, https://github.com/snowflakedb/snowpark-checkpoints/issues
 Project-URL: Source code, https://github.com/snowflakedb/snowpark-checkpoints/
@@ -26,11 +26,9 @@ Classifier: Topic :: Software Development :: Libraries
 Classifier: Topic :: Software Development :: Libraries :: Application Frameworks
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Requires-Python: <3.12,>=3.9
-Requires-Dist: pandera-report==0.1.2
 Requires-Dist: pandera[io]==0.20.4
-Requires-Dist: pyspark
 Requires-Dist: snowflake-connector-python[pandas]
-Requires-Dist: snowflake-snowpark-python
+Requires-Dist: snowflake-snowpark-python>=1.23.0
 Provides-Extra: development
 Requires-Dist: coverage>=7.6.7; extra == 'development'
 Requires-Dist: deepdiff==8.1.1; extra == 'development'
@@ -38,10 +36,13 @@ Requires-Dist: deepdiff>=8.0.0; extra == 'development'
 Requires-Dist: hatchling==1.25.0; extra == 'development'
 Requires-Dist: pre-commit>=4.0.1; extra == 'development'
 Requires-Dist: pyarrow>=18.0.0; extra == 'development'
+Requires-Dist: pyspark>=3.5.0; extra == 'development'
 Requires-Dist: pytest-cov>=6.0.0; extra == 'development'
 Requires-Dist: pytest>=8.3.3; extra == 'development'
 Requires-Dist: setuptools>=70.0.0; extra == 'development'
 Requires-Dist: twine==5.1.1; extra == 'development'
+Provides-Extra: pyspark
+Requires-Dist: pyspark>=3.5.0; extra == 'pyspark'
 Description-Content-Type: text/markdown
 
 # snowpark-checkpoints-validators
@@ -52,6 +53,17 @@ Description-Content-Type: text/markdown
 
 **snowpark-checkpoints-validators** is a package designed to validate Snowpark DataFrames against predefined schemas and checkpoints. This package ensures data integrity and consistency by performing schema and data validation checks at various stages of a Snowpark pipeline.
 
+---
+## Install the library
+```bash
+pip install snowpark-checkpoints-validators
+```
+This package requires PySpark to be installed in the same environment. If you do not have it, you can install PySpark alongside Snowpark Checkpoints by running the following command:
+```bash
+pip install "snowpark-checkpoints-validators[pyspark]"
+```
+---
+
 ## Features
 
 - Validate Snowpark DataFrames against predefined Pandera schemas.
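End to end, the validators package is driven like this (a minimal sketch: the connection setup and source table are assumed, and `SnowparkJobContext` accepts further optional arguments whose exact order is not shown in this diff):

```python
from snowflake.snowpark import Session
from snowflake.snowpark_checkpoints import (
    CheckpointMode,
    SnowparkJobContext,
    validate_dataframe_checkpoint,
)

session = Session.builder.getOrCreate()  # assumes connection parameters are configured
df = session.table("MY_TABLE")           # illustrative source DataFrame
job_context = SnowparkJobContext(session)

validate_dataframe_checkpoint(
    df,
    checkpoint_name="demo_checkpoint",
    job_context=job_context,
    mode=CheckpointMode.SCHEMA,  # 1 = schema validation; 2 = full data validation
)
```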
snowpark_checkpoints_validators-0.2.0.dist-info/RECORD (new file)
@@ -0,0 +1,22 @@
+snowflake/snowpark_checkpoints/__init__.py,sha256=p7fzH3f8foD5nhNJHZ00JT3ODTXJGGkWTd3xRKx-8aQ,1435
+snowflake/snowpark_checkpoints/__version__.py,sha256=ajnGza8ucK69-PA8wEbHmWZxDwd3bsTm74yMKiIWNHY,632
+snowflake/snowpark_checkpoints/checkpoint.py,sha256=i-iDRYbGvQHy9ipW7UxHVhJhQ9BXNSO-bsCcHyg3oLA,22056
+snowflake/snowpark_checkpoints/errors.py,sha256=9KjzRf8bjDZTTNL4LeySJAwuucDOyz0Ka7EFBKWFpyg,1821
+snowflake/snowpark_checkpoints/job_context.py,sha256=RMK0g0HrbDVrOAvai4PgsGvsAn_GIo9aFmh-tWlyieY,4183
+snowflake/snowpark_checkpoints/singleton.py,sha256=7AgIHQBXVRvPBBCkmBplzkdrrm-xVWf_N8svzA2vF8E,836
+snowflake/snowpark_checkpoints/snowpark_sampler.py,sha256=Qxv-8nRGuf-ab3GoSUt8_MNL0ppjoBIMOFIMkqmwN5I,4668
+snowflake/snowpark_checkpoints/spark_migration.py,sha256=s2HqomYx76Hqn71g9TleBeHI3t1nirgfPvkggqQQdts,10253
+snowflake/snowpark_checkpoints/validation_result_metadata.py,sha256=fm2lKxjYlzlL6qsiv2icR9k5o7YNd2OwvFhiqGYrTpo,5745
+snowflake/snowpark_checkpoints/validation_results.py,sha256=J8OcpNty6hQD8RbAy8xmA0UMbPWfXSmQnHYspWWSisk,1502
+snowflake/snowpark_checkpoints/utils/__init__.py,sha256=I4srmZ8G1q9DU6Suo1S91aVfNvETyisKH95uvLAvEJ0,609
+snowflake/snowpark_checkpoints/utils/constants.py,sha256=pgFttLDQ6fTa6obSdvivWBYClS21ap41YVDNGAS4sxY,4146
+snowflake/snowpark_checkpoints/utils/extra_config.py,sha256=LvOdIhvE450AV0wLVK5P_hANvcNzAv8pLNe7Ksr598U,2802
+snowflake/snowpark_checkpoints/utils/logging_utils.py,sha256=yyi6X5DqKeTg0HRhvsH6ymYp2P0wbnyKIzI2RzrQS7k,2278
+snowflake/snowpark_checkpoints/utils/pandera_check_manager.py,sha256=tQIozLO-2kM8WZ-gGKfRwmXBx1cDPaIZB0qIcArp8xA,16100
+snowflake/snowpark_checkpoints/utils/supported_types.py,sha256=GrMX2tHdSFnK7LlPbZx20UufD6Br6TNVRkkBwIxdPy0,1433
+snowflake/snowpark_checkpoints/utils/telemetry.py,sha256=_WOVo19BxcF6cpQDplID6BEOvgJfHTGK1JZI1-OI4uc,31370
+snowflake/snowpark_checkpoints/utils/utils_checks.py,sha256=ythgWkLstEkCae_TqtdPXJ1Jjbx9iTN8sLOl1ewKxzI,14191
+snowpark_checkpoints_validators-0.2.0.dist-info/METADATA,sha256=ixLNouygrcyBFCQK3D77nmAIKsWnPIV9gCYSP_rRi1I,11470
+snowpark_checkpoints_validators-0.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+snowpark_checkpoints_validators-0.2.0.dist-info/licenses/LICENSE,sha256=pmjhbh6uVhV5MBXOlou_UZgFP7CYVQITkCCdvfcS5lY,11340
+snowpark_checkpoints_validators-0.2.0.dist-info/RECORD,,
@@ -1,52 +0,0 @@
|
|
1
|
-
# Copyright 2025 Snowflake Inc.
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
3
|
-
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
-
# you may not use this file except in compliance with the License.
|
6
|
-
# You may obtain a copy of the License at
|
7
|
-
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
-
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
-
# See the License for the specific language governing permissions and
|
14
|
-
# limitations under the License.
|
15
|
-
|
16
|
-
import logging
|
17
|
-
import threading
|
18
|
-
|
19
|
-
|
20
|
-
class CheckpointLogger:
|
21
|
-
_instance = None
|
22
|
-
_lock = threading.Lock()
|
23
|
-
|
24
|
-
def __new__(cls, *args, **kwargs):
|
25
|
-
if not cls._instance:
|
26
|
-
with cls._lock:
|
27
|
-
if not cls._instance:
|
28
|
-
cls._instance = super().__new__(cls, *args, **kwargs)
|
29
|
-
cls._instance._initialize()
|
30
|
-
return cls._instance
|
31
|
-
|
32
|
-
def _initialize(self):
|
33
|
-
# Create formatter
|
34
|
-
formatter = logging.Formatter(
|
35
|
-
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
36
|
-
)
|
37
|
-
|
38
|
-
self.logger = logging.getLogger("CheckpointLogger")
|
39
|
-
self.logger.setLevel(logging.INFO)
|
40
|
-
|
41
|
-
# Create console handler and set level to debug
|
42
|
-
ch = logging.StreamHandler()
|
43
|
-
ch.setLevel(logging.DEBUG)
|
44
|
-
|
45
|
-
# Add formatter to ch
|
46
|
-
ch.setFormatter(formatter)
|
47
|
-
|
48
|
-
# Add ch to logger
|
49
|
-
self.logger.addHandler(ch)
|
50
|
-
|
51
|
-
def get_logger(self):
|
52
|
-
return self.logger
|
snowpark_checkpoints_validators-0.1.3.dist-info/RECORD (deleted)
@@ -1,22 +0,0 @@
-snowflake/snowpark_checkpoints/__init__.py,sha256=1_xzSopIHWpw1i3gQqWLN0wCfWWEefjr4cl1vl0xSdY,1211
-snowflake/snowpark_checkpoints/__version__.py,sha256=OfdAqrd8gnFI-pK7o_olRVrRKIWfQhQOoo_wR3u1s5s,632
-snowflake/snowpark_checkpoints/checkpoint.py,sha256=-y1iWdGxYGuTWdngOEXdA59MT33PCiM7cP1s3jJs9jE,18997
-snowflake/snowpark_checkpoints/errors.py,sha256=9KjzRf8bjDZTTNL4LeySJAwuucDOyz0Ka7EFBKWFpyg,1821
-snowflake/snowpark_checkpoints/job_context.py,sha256=7LdJ682lC8hCJOYUn-AVXq_Llv18R9oGdK2F-amYR_o,2990
-snowflake/snowpark_checkpoints/singleton.py,sha256=7AgIHQBXVRvPBBCkmBplzkdrrm-xVWf_N8svzA2vF8E,836
-snowflake/snowpark_checkpoints/snowpark_sampler.py,sha256=-t7cg-swMK0SaU7r8y90MLSDPXGlKprc6xdVxEs29sU,3632
-snowflake/snowpark_checkpoints/spark_migration.py,sha256=DzzgUZ-XlzIqCz-aWpBICP8mgnjk8UNoL8JsomadF-U,8832
-snowflake/snowpark_checkpoints/validation_result_metadata.py,sha256=mHCIq6-F37HK-jYBAPeVtax9eIwiCvQZxFPGWi4KvQc,4765
-snowflake/snowpark_checkpoints/validation_results.py,sha256=J8OcpNty6hQD8RbAy8xmA0UMbPWfXSmQnHYspWWSisk,1502
-snowflake/snowpark_checkpoints/utils/__init__.py,sha256=I4srmZ8G1q9DU6Suo1S91aVfNvETyisKH95uvLAvEJ0,609
-snowflake/snowpark_checkpoints/utils/checkpoint_logger.py,sha256=meGl5T3Avp4Qn0GEwkJi5GSLS4MDb7zTGbTOI-8bf1E,1592
-snowflake/snowpark_checkpoints/utils/constants.py,sha256=pgFttLDQ6fTa6obSdvivWBYClS21ap41YVDNGAS4sxY,4146
-snowflake/snowpark_checkpoints/utils/extra_config.py,sha256=pmGLYT7cu9WMKzQwcEPkgk1DMnnT1fREm45p19e79hk,2567
-snowflake/snowpark_checkpoints/utils/pandera_check_manager.py,sha256=ddTwXauuZdowIRwPMT61GWYCG4XGKOFkVyfZO49bc-8,14516
-snowflake/snowpark_checkpoints/utils/supported_types.py,sha256=GrMX2tHdSFnK7LlPbZx20UufD6Br6TNVRkkBwIxdPy0,1433
-snowflake/snowpark_checkpoints/utils/telemetry.py,sha256=_WOVo19BxcF6cpQDplID6BEOvgJfHTGK1JZI1-OI4uc,31370
-snowflake/snowpark_checkpoints/utils/utils_checks.py,sha256=o9HOBrDuTxSIgzZQHfsa9pMzzXRUsRAISI7L6OURouo,13528
-snowpark_checkpoints_validators-0.1.3.dist-info/METADATA,sha256=BSv42Vrlq07M2hIiKHSXJNxaH1O4mYQOJ1U4uApT9uA,11064
-snowpark_checkpoints_validators-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-snowpark_checkpoints_validators-0.1.3.dist-info/licenses/LICENSE,sha256=pmjhbh6uVhV5MBXOlou_UZgFP7CYVQITkCCdvfcS5lY,11340
-snowpark_checkpoints_validators-0.1.3.dist-info/RECORD,,
{snowpark_checkpoints_validators-0.1.3.dist-info → snowpark_checkpoints_validators-0.2.0.dist-info}/WHEEL: file without changes
{snowpark_checkpoints_validators-0.1.3.dist-info → snowpark_checkpoints_validators-0.2.0.dist-info}/licenses/LICENSE: file without changes