snowpark-checkpoints-validators 0.1.4__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/PKG-INFO +1 -1
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/__init__.py +11 -1
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/__version__.py +1 -1
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/checkpoint.py +173 -82
- snowpark_checkpoints_validators-0.2.0/src/snowflake/snowpark_checkpoints/job_context.py +128 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/snowpark_sampler.py +26 -1
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/spark_migration.py +39 -6
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/utils/extra_config.py +10 -5
- snowpark_checkpoints_validators-0.2.0/src/snowflake/snowpark_checkpoints/utils/logging_utils.py +67 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +48 -7
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/utils/utils_checks.py +23 -2
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/validation_result_metadata.py +30 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/test_pandera.py +47 -18
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/test_parquet.py +84 -25
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/test_spark_checkpoint.py +40 -21
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/unit/test_extra_config.py +7 -1
- snowpark_checkpoints_validators-0.2.0/test/unit/test_job_context.py +49 -0
- snowpark_checkpoints_validators-0.2.0/test/unit/test_logger.py +134 -0
- snowpark_checkpoints_validators-0.2.0/test/unit/test_logging_utils.py +132 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/unit/test_validation_result_metadata.py +40 -0
- snowpark_checkpoints_validators-0.1.4/src/snowflake/snowpark_checkpoints/job_context.py +0 -85
- snowpark_checkpoints_validators-0.1.4/src/snowflake/snowpark_checkpoints/utils/checkpoint_logger.py +0 -52
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/.gitignore +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/CHANGELOG.md +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/LICENSE +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/README.md +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/pyproject.toml +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/errors.py +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/singleton.py +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/utils/__init__.py +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/utils/constants.py +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/utils/supported_types.py +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/utils/telemetry.py +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/src/snowflake/snowpark_checkpoints/validation_results.py +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/.coveragerc +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/e2eexample.py +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_compare_utils.py +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/df_mode_dataframe_mismatch_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/df_mode_dataframe_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/spark_checkpoint_df_fail_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/spark_checkpoint_df_pass_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/spark_checkpoint_limit_sample_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/spark_checkpoint_random_sample_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/spark_checkpoint_scalar_fail_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/spark_checkpoint_scalar_passing_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/test_df_check_custom_check_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/test_df_check_fail_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/test_df_check_from_file_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/test_df_check_skip_check_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/test_df_check_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/test_input_fail_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/test_input_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/test_output_fail_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/integ/telemetry_expected/test_output_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/unit/test_pandera_check_manager.py +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/unit/test_spark_migration.py +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/unit/test_telemetry.py +0 -0
- {snowpark_checkpoints_validators-0.1.4 → snowpark_checkpoints_validators-0.2.0}/test/unit/test_utils_checks.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: snowpark-checkpoints-validators
|
3
|
-
Version: 0.1.4
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: Migration tools for Snowpark
|
5
5
|
Project-URL: Bug Tracker, https://github.com/snowflakedb/snowpark-checkpoints/issues
|
6
6
|
Project-URL: Source code, https://github.com/snowflakedb/snowpark-checkpoints/
|
@@ -13,16 +13,26 @@
|
|
13
13
|
# See the License for the specific language governing permissions and
|
14
14
|
# limitations under the License.
|
15
15
|
|
16
|
+
import logging
|
17
|
+
|
18
|
+
|
19
|
+
# Add a NullHandler to prevent logging messages from being output to
|
20
|
+
# sys.stderr if no logging configuration is provided.
|
21
|
+
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
22
|
+
|
23
|
+
# ruff: noqa: E402
|
24
|
+
|
16
25
|
from snowflake.snowpark_checkpoints.checkpoint import (
|
17
26
|
check_dataframe_schema,
|
18
|
-
check_output_schema,
|
19
27
|
check_input_schema,
|
28
|
+
check_output_schema,
|
20
29
|
validate_dataframe_checkpoint,
|
21
30
|
)
|
22
31
|
from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
|
23
32
|
from snowflake.snowpark_checkpoints.spark_migration import check_with_spark
|
24
33
|
from snowflake.snowpark_checkpoints.utils.constants import CheckpointMode
|
25
34
|
|
35
|
+
|
26
36
|
__all__ = [
|
27
37
|
"check_with_spark",
|
28
38
|
"SnowparkJobContext",
|
@@ -14,6 +14,9 @@
|
|
14
14
|
# limitations under the License.
|
15
15
|
|
16
16
|
# Wrapper around pandera which logs to snowflake
|
17
|
+
|
18
|
+
import logging
|
19
|
+
|
17
20
|
from typing import Any, Optional, Union, cast
|
18
21
|
|
19
22
|
from pandas import DataFrame as PandasDataFrame
|
@@ -27,13 +30,13 @@ from snowflake.snowpark_checkpoints.snowpark_sampler import (
|
|
27
30
|
SamplingAdapter,
|
28
31
|
SamplingStrategy,
|
29
32
|
)
|
30
|
-
from snowflake.snowpark_checkpoints.utils.checkpoint_logger import CheckpointLogger
|
31
33
|
from snowflake.snowpark_checkpoints.utils.constants import (
|
32
34
|
FAIL_STATUS,
|
33
35
|
PASS_STATUS,
|
34
36
|
CheckpointMode,
|
35
37
|
)
|
36
38
|
from snowflake.snowpark_checkpoints.utils.extra_config import is_checkpoint_enabled
|
39
|
+
from snowflake.snowpark_checkpoints.utils.logging_utils import log
|
37
40
|
from snowflake.snowpark_checkpoints.utils.pandera_check_manager import (
|
38
41
|
PanderaCheckManager,
|
39
42
|
)
|
@@ -47,6 +50,10 @@ from snowflake.snowpark_checkpoints.utils.utils_checks import (
|
|
47
50
|
)
|
48
51
|
|
49
52
|
|
53
|
+
LOGGER = logging.getLogger(__name__)
|
54
|
+
|
55
|
+
|
56
|
+
@log
|
50
57
|
def validate_dataframe_checkpoint(
|
51
58
|
df: SnowparkDataFrame,
|
52
59
|
checkpoint_name: str,
|
@@ -84,31 +91,45 @@ def validate_dataframe_checkpoint(
|
|
84
91
|
"""
|
85
92
|
checkpoint_name = _replace_special_characters(checkpoint_name)
|
86
93
|
|
87
|
-
if is_checkpoint_enabled(checkpoint_name):
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
94
|
+
if not is_checkpoint_enabled(checkpoint_name):
|
95
|
+
LOGGER.warning(
|
96
|
+
"Checkpoint '%s' is disabled. Skipping DataFrame checkpoint validation.",
|
97
|
+
checkpoint_name,
|
98
|
+
)
|
99
|
+
return None
|
100
|
+
|
101
|
+
LOGGER.info(
|
102
|
+
"Starting DataFrame checkpoint validation for checkpoint '%s'", checkpoint_name
|
103
|
+
)
|
104
|
+
|
105
|
+
if mode == CheckpointMode.SCHEMA:
|
106
|
+
result = _check_dataframe_schema_file(
|
107
|
+
df,
|
108
|
+
checkpoint_name,
|
109
|
+
job_context,
|
110
|
+
custom_checks,
|
111
|
+
skip_checks,
|
112
|
+
sample_frac,
|
113
|
+
sample_number,
|
114
|
+
sampling_strategy,
|
115
|
+
output_path,
|
116
|
+
)
|
117
|
+
return result
|
118
|
+
|
119
|
+
if mode == CheckpointMode.DATAFRAME:
|
120
|
+
if job_context is None:
|
108
121
|
raise ValueError(
|
109
|
-
"
|
110
|
-
Please use for schema validation use a 1 or for a full data validation use a 2 for schema validation."""
|
122
|
+
"No job context provided. Please provide one when using DataFrame mode validation."
|
111
123
|
)
|
124
|
+
_check_compare_data(df, job_context, checkpoint_name, output_path)
|
125
|
+
return None
|
126
|
+
|
127
|
+
raise ValueError(
|
128
|
+
(
|
129
|
+
"Invalid validation mode. "
|
130
|
+
"Please use 1 for schema validation or 2 for full data validation."
|
131
|
+
),
|
132
|
+
)
|
112
133
|
|
113
134
|
|
114
135
|
def _check_dataframe_schema_file(
|
@@ -156,7 +177,7 @@ def _check_dataframe_schema_file(
|
|
156
177
|
|
157
178
|
schema = _generate_schema(checkpoint_name, output_path)
|
158
179
|
|
159
|
-
return
|
180
|
+
return _check_dataframe_schema(
|
160
181
|
df,
|
161
182
|
schema,
|
162
183
|
checkpoint_name,
|
@@ -170,6 +191,7 @@ def _check_dataframe_schema_file(
|
|
170
191
|
)
|
171
192
|
|
172
193
|
|
194
|
+
@log
|
173
195
|
def check_dataframe_schema(
|
174
196
|
df: SnowparkDataFrame,
|
175
197
|
pandera_schema: DataFrameSchema,
|
@@ -212,6 +234,9 @@ def check_dataframe_schema(
|
|
212
234
|
|
213
235
|
"""
|
214
236
|
checkpoint_name = _replace_special_characters(checkpoint_name)
|
237
|
+
LOGGER.info(
|
238
|
+
"Starting DataFrame schema validation for checkpoint '%s'", checkpoint_name
|
239
|
+
)
|
215
240
|
|
216
241
|
if df is None:
|
217
242
|
raise ValueError("DataFrame is required")
|
@@ -219,19 +244,25 @@ def check_dataframe_schema(
|
|
219
244
|
if pandera_schema is None:
|
220
245
|
raise ValueError("Schema is required")
|
221
246
|
|
222
|
-
if is_checkpoint_enabled(checkpoint_name):
|
223
|
-
|
224
|
-
|
225
|
-
pandera_schema,
|
247
|
+
if not is_checkpoint_enabled(checkpoint_name):
|
248
|
+
LOGGER.warning(
|
249
|
+
"Checkpoint '%s' is disabled. Skipping DataFrame schema validation.",
|
226
250
|
checkpoint_name,
|
227
|
-
job_context,
|
228
|
-
custom_checks,
|
229
|
-
skip_checks,
|
230
|
-
sample_frac,
|
231
|
-
sample_number,
|
232
|
-
sampling_strategy,
|
233
|
-
output_path,
|
234
251
|
)
|
252
|
+
return None
|
253
|
+
|
254
|
+
return _check_dataframe_schema(
|
255
|
+
df,
|
256
|
+
pandera_schema,
|
257
|
+
checkpoint_name,
|
258
|
+
job_context,
|
259
|
+
custom_checks,
|
260
|
+
skip_checks,
|
261
|
+
sample_frac,
|
262
|
+
sample_number,
|
263
|
+
sampling_strategy,
|
264
|
+
output_path,
|
265
|
+
)
|
235
266
|
|
236
267
|
|
237
268
|
@report_telemetry(
|
@@ -261,10 +292,22 @@ def _check_dataframe_schema(
|
|
261
292
|
)
|
262
293
|
is_valid, validation_result = _validate(pandera_schema_upper, sample_df)
|
263
294
|
if is_valid:
|
295
|
+
LOGGER.info(
|
296
|
+
"DataFrame schema validation passed for checkpoint '%s'",
|
297
|
+
checkpoint_name,
|
298
|
+
)
|
264
299
|
if job_context is not None:
|
265
300
|
job_context._mark_pass(checkpoint_name)
|
301
|
+
else:
|
302
|
+
LOGGER.warning(
|
303
|
+
"No job context provided. Skipping result recording into Snowflake.",
|
304
|
+
)
|
266
305
|
_update_validation_result(checkpoint_name, PASS_STATUS, output_path)
|
267
306
|
else:
|
307
|
+
LOGGER.error(
|
308
|
+
"DataFrame schema validation failed for checkpoint '%s'",
|
309
|
+
checkpoint_name,
|
310
|
+
)
|
268
311
|
_update_validation_result(checkpoint_name, FAIL_STATUS, output_path)
|
269
312
|
raise SchemaValidationError(
|
270
313
|
"Snowpark DataFrame schema validation error",
|
@@ -277,6 +320,7 @@ def _check_dataframe_schema(
|
|
277
320
|
|
278
321
|
|
279
322
|
@report_telemetry(params_list=["pandera_schema"])
|
323
|
+
@log
|
280
324
|
def check_output_schema(
|
281
325
|
pandera_schema: DataFrameSchema,
|
282
326
|
checkpoint_name: str,
|
@@ -313,11 +357,8 @@ def check_output_schema(
|
|
313
357
|
function: The decorated function.
|
314
358
|
|
315
359
|
"""
|
316
|
-
_checkpoint_name = checkpoint_name
|
317
|
-
if checkpoint_name is None:
|
318
|
-
_checkpoint_name = snowpark_fn.__name__
|
319
|
-
_checkpoint_name = _replace_special_characters(_checkpoint_name)
|
320
360
|
|
361
|
+
@log(log_args=False)
|
321
362
|
def wrapper(*args, **kwargs):
|
322
363
|
"""Wrapp a function to validate the schema of the output of a Snowpark function.
|
323
364
|
|
@@ -329,7 +370,25 @@ def check_output_schema(
|
|
329
370
|
Any: The result of the Snowpark function.
|
330
371
|
|
331
372
|
"""
|
373
|
+
_checkpoint_name = checkpoint_name
|
374
|
+
if checkpoint_name is None:
|
375
|
+
LOGGER.warning(
|
376
|
+
(
|
377
|
+
"No checkpoint name provided for output schema validation. "
|
378
|
+
"Using '%s' as the checkpoint name.",
|
379
|
+
),
|
380
|
+
snowpark_fn.__name__,
|
381
|
+
)
|
382
|
+
_checkpoint_name = snowpark_fn.__name__
|
383
|
+
_checkpoint_name = _replace_special_characters(_checkpoint_name)
|
384
|
+
LOGGER.info(
|
385
|
+
"Starting output schema validation for Snowpark function '%s' and checkpoint '%s'",
|
386
|
+
snowpark_fn.__name__,
|
387
|
+
_checkpoint_name,
|
388
|
+
)
|
389
|
+
|
332
390
|
# Run the sampled data in snowpark
|
391
|
+
LOGGER.info("Running the Snowpark function '%s'", snowpark_fn.__name__)
|
333
392
|
snowpark_results = snowpark_fn(*args, **kwargs)
|
334
393
|
sampler = SamplingAdapter(
|
335
394
|
job_context, sample_frac, sample_number, sampling_strategy
|
@@ -340,17 +399,25 @@ def check_output_schema(
|
|
340
399
|
is_valid, validation_result = _validate(
|
341
400
|
pandera_schema, pandas_sample_args[0]
|
342
401
|
)
|
343
|
-
logger = CheckpointLogger().get_logger()
|
344
|
-
logger.info(
|
345
|
-
f"Checkpoint {_checkpoint_name} validation result:\n{validation_result}"
|
346
|
-
)
|
347
|
-
|
348
402
|
if is_valid:
|
403
|
+
LOGGER.info(
|
404
|
+
"Output schema validation passed for Snowpark function '%s' and checkpoint '%s'",
|
405
|
+
snowpark_fn.__name__,
|
406
|
+
_checkpoint_name,
|
407
|
+
)
|
349
408
|
if job_context is not None:
|
350
409
|
job_context._mark_pass(_checkpoint_name)
|
351
|
-
|
410
|
+
else:
|
411
|
+
LOGGER.warning(
|
412
|
+
"No job context provided. Skipping result recording into Snowflake.",
|
413
|
+
)
|
352
414
|
_update_validation_result(_checkpoint_name, PASS_STATUS, output_path)
|
353
415
|
else:
|
416
|
+
LOGGER.error(
|
417
|
+
"Output schema validation failed for Snowpark function '%s' and checkpoint '%s'",
|
418
|
+
snowpark_fn.__name__,
|
419
|
+
_checkpoint_name,
|
420
|
+
)
|
354
421
|
_update_validation_result(_checkpoint_name, FAIL_STATUS, output_path)
|
355
422
|
raise SchemaValidationError(
|
356
423
|
"Snowpark output schema validation error",
|
@@ -358,7 +425,6 @@ def check_output_schema(
|
|
358
425
|
_checkpoint_name,
|
359
426
|
validation_result,
|
360
427
|
)
|
361
|
-
|
362
428
|
return snowpark_results
|
363
429
|
|
364
430
|
return wrapper
|
@@ -367,6 +433,7 @@ def check_output_schema(
|
|
367
433
|
|
368
434
|
|
369
435
|
@report_telemetry(params_list=["pandera_schema"])
|
436
|
+
@log
|
370
437
|
def check_input_schema(
|
371
438
|
pandera_schema: DataFrameSchema,
|
372
439
|
checkpoint_name: str,
|
@@ -407,11 +474,8 @@ def check_input_schema(
|
|
407
474
|
Callable: A wrapper function that performs schema validation before executing the original function.
|
408
475
|
|
409
476
|
"""
|
410
|
-
_checkpoint_name = checkpoint_name
|
411
|
-
if checkpoint_name is None:
|
412
|
-
_checkpoint_name = snowpark_fn.__name__
|
413
|
-
_checkpoint_name = _replace_special_characters(_checkpoint_name)
|
414
477
|
|
478
|
+
@log(log_args=False)
|
415
479
|
def wrapper(*args, **kwargs):
|
416
480
|
"""Wrapp a function to validate the schema of the input of a Snowpark function.
|
417
481
|
|
@@ -422,6 +486,23 @@ def check_input_schema(
|
|
422
486
|
Any: The result of the original function after input validation.
|
423
487
|
|
424
488
|
"""
|
489
|
+
_checkpoint_name = checkpoint_name
|
490
|
+
if checkpoint_name is None:
|
491
|
+
LOGGER.warning(
|
492
|
+
(
|
493
|
+
"No checkpoint name provided for input schema validation. "
|
494
|
+
"Using '%s' as the checkpoint name."
|
495
|
+
),
|
496
|
+
snowpark_fn.__name__,
|
497
|
+
)
|
498
|
+
_checkpoint_name = snowpark_fn.__name__
|
499
|
+
_checkpoint_name = _replace_special_characters(_checkpoint_name)
|
500
|
+
LOGGER.info(
|
501
|
+
"Starting input schema validation for Snowpark function '%s' and checkpoint '%s'",
|
502
|
+
snowpark_fn.__name__,
|
503
|
+
_checkpoint_name,
|
504
|
+
)
|
505
|
+
|
425
506
|
# Run the sampled data in snowpark
|
426
507
|
sampler = SamplingAdapter(
|
427
508
|
job_context, sample_frac, sample_number, sampling_strategy
|
@@ -429,43 +510,53 @@ def check_input_schema(
|
|
429
510
|
sampler.process_args(args)
|
430
511
|
pandas_sample_args = sampler.get_sampled_pandas_args()
|
431
512
|
|
513
|
+
LOGGER.info(
|
514
|
+
"Validating %s input argument(s) against a Pandera schema",
|
515
|
+
len(pandas_sample_args),
|
516
|
+
)
|
432
517
|
# Raises SchemaError on validation issues
|
433
|
-
for arg in pandas_sample_args:
|
434
|
-
if isinstance(arg, PandasDataFrame):
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
arg,
|
518
|
+
for index, arg in enumerate(pandas_sample_args, start=1):
|
519
|
+
if not isinstance(arg, PandasDataFrame):
|
520
|
+
LOGGER.info(
|
521
|
+
"Arg %s: Skipping schema validation for non-DataFrame argument",
|
522
|
+
index,
|
439
523
|
)
|
524
|
+
continue
|
440
525
|
|
441
|
-
|
442
|
-
|
443
|
-
|
526
|
+
is_valid, validation_result = _validate(
|
527
|
+
pandera_schema,
|
528
|
+
arg,
|
529
|
+
)
|
530
|
+
if is_valid:
|
531
|
+
LOGGER.info(
|
532
|
+
"Arg %s: Input schema validation passed",
|
533
|
+
index,
|
444
534
|
)
|
445
|
-
|
446
|
-
|
447
|
-
if job_context is not None:
|
448
|
-
job_context._mark_pass(
|
449
|
-
_checkpoint_name,
|
450
|
-
)
|
451
|
-
|
452
|
-
_update_validation_result(
|
453
|
-
_checkpoint_name,
|
454
|
-
PASS_STATUS,
|
455
|
-
output_path,
|
456
|
-
)
|
457
|
-
else:
|
458
|
-
_update_validation_result(
|
459
|
-
_checkpoint_name,
|
460
|
-
FAIL_STATUS,
|
461
|
-
output_path,
|
462
|
-
)
|
463
|
-
raise SchemaValidationError(
|
464
|
-
"Snowpark input schema validation error",
|
465
|
-
job_context,
|
535
|
+
if job_context is not None:
|
536
|
+
job_context._mark_pass(
|
466
537
|
_checkpoint_name,
|
467
|
-
validation_result,
|
468
538
|
)
|
539
|
+
_update_validation_result(
|
540
|
+
_checkpoint_name,
|
541
|
+
PASS_STATUS,
|
542
|
+
output_path,
|
543
|
+
)
|
544
|
+
else:
|
545
|
+
LOGGER.error(
|
546
|
+
"Arg %s: Input schema validation failed",
|
547
|
+
index,
|
548
|
+
)
|
549
|
+
_update_validation_result(
|
550
|
+
_checkpoint_name,
|
551
|
+
FAIL_STATUS,
|
552
|
+
output_path,
|
553
|
+
)
|
554
|
+
raise SchemaValidationError(
|
555
|
+
"Snowpark input schema validation error",
|
556
|
+
job_context,
|
557
|
+
_checkpoint_name,
|
558
|
+
validation_result,
|
559
|
+
)
|
469
560
|
return snowpark_fn(*args, **kwargs)
|
470
561
|
|
471
562
|
return wrapper
|
@@ -0,0 +1,128 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import logging
|
17
|
+
|
18
|
+
from datetime import datetime
|
19
|
+
from typing import Optional
|
20
|
+
|
21
|
+
import pandas as pd
|
22
|
+
|
23
|
+
from pyspark.sql import SparkSession
|
24
|
+
|
25
|
+
from snowflake.snowpark import Session
|
26
|
+
from snowflake.snowpark_checkpoints.utils.constants import SCHEMA_EXECUTION_MODE
|
27
|
+
|
28
|
+
|
29
|
+
LOGGER = logging.getLogger(__name__)
|
30
|
+
RESULTS_TABLE = "SNOWPARK_CHECKPOINTS_REPORT"
|
31
|
+
|
32
|
+
|
33
|
+
class SnowparkJobContext:
|
34
|
+
|
35
|
+
"""Class used to record migration results in Snowflake.
|
36
|
+
|
37
|
+
Args:
|
38
|
+
snowpark_session: A Snowpark session instance.
|
39
|
+
spark_session: A Spark session instance.
|
40
|
+
job_name: The name of the job.
|
41
|
+
log_results: Whether to log the migration results in Snowflake.
|
42
|
+
|
43
|
+
"""
|
44
|
+
|
45
|
+
def __init__(
|
46
|
+
self,
|
47
|
+
snowpark_session: Session,
|
48
|
+
spark_session: SparkSession = None,
|
49
|
+
job_name: Optional[str] = None,
|
50
|
+
log_results: Optional[bool] = True,
|
51
|
+
):
|
52
|
+
self.log_results = log_results
|
53
|
+
self.job_name = job_name
|
54
|
+
self.spark_session = spark_session or self._create_pyspark_session()
|
55
|
+
self.snowpark_session = snowpark_session
|
56
|
+
|
57
|
+
def _mark_fail(
|
58
|
+
self, message, checkpoint_name, data, execution_mode=SCHEMA_EXECUTION_MODE
|
59
|
+
):
|
60
|
+
if not self.log_results:
|
61
|
+
LOGGER.warning(
|
62
|
+
(
|
63
|
+
"Recording of migration results into Snowflake is disabled. "
|
64
|
+
"Failure result for checkpoint '%s' will not be recorded."
|
65
|
+
),
|
66
|
+
checkpoint_name,
|
67
|
+
)
|
68
|
+
return
|
69
|
+
|
70
|
+
LOGGER.debug(
|
71
|
+
"Marking failure for checkpoint '%s' in '%s' mode with message '%s'",
|
72
|
+
checkpoint_name,
|
73
|
+
execution_mode,
|
74
|
+
message,
|
75
|
+
)
|
76
|
+
|
77
|
+
session = self.snowpark_session
|
78
|
+
df = pd.DataFrame(
|
79
|
+
{
|
80
|
+
"DATE": [datetime.now()],
|
81
|
+
"JOB": [self.job_name],
|
82
|
+
"STATUS": ["fail"],
|
83
|
+
"CHECKPOINT": [checkpoint_name],
|
84
|
+
"MESSAGE": [message],
|
85
|
+
"DATA": [f"{data}"],
|
86
|
+
"EXECUTION_MODE": [execution_mode],
|
87
|
+
}
|
88
|
+
)
|
89
|
+
report_df = session.createDataFrame(df)
|
90
|
+
LOGGER.info("Writing failure result to table: '%s'", RESULTS_TABLE)
|
91
|
+
report_df.write.mode("append").save_as_table(RESULTS_TABLE)
|
92
|
+
|
93
|
+
def _mark_pass(self, checkpoint_name, execution_mode=SCHEMA_EXECUTION_MODE):
|
94
|
+
if not self.log_results:
|
95
|
+
LOGGER.warning(
|
96
|
+
(
|
97
|
+
"Recording of migration results into Snowflake is disabled. "
|
98
|
+
"Pass result for checkpoint '%s' will not be recorded."
|
99
|
+
),
|
100
|
+
checkpoint_name,
|
101
|
+
)
|
102
|
+
return
|
103
|
+
|
104
|
+
LOGGER.debug(
|
105
|
+
"Marking pass for checkpoint '%s' in '%s' mode",
|
106
|
+
checkpoint_name,
|
107
|
+
execution_mode,
|
108
|
+
)
|
109
|
+
|
110
|
+
session = self.snowpark_session
|
111
|
+
df = pd.DataFrame(
|
112
|
+
{
|
113
|
+
"DATE": [datetime.now()],
|
114
|
+
"JOB": [self.job_name],
|
115
|
+
"STATUS": ["pass"],
|
116
|
+
"CHECKPOINT": [checkpoint_name],
|
117
|
+
"MESSAGE": [""],
|
118
|
+
"DATA": [""],
|
119
|
+
"EXECUTION_MODE": [execution_mode],
|
120
|
+
}
|
121
|
+
)
|
122
|
+
report_df = session.createDataFrame(df)
|
123
|
+
LOGGER.info("Writing pass result to table: '%s'", RESULTS_TABLE)
|
124
|
+
report_df.write.mode("append").save_as_table(RESULTS_TABLE)
|
125
|
+
|
126
|
+
def _create_pyspark_session(self) -> SparkSession:
|
127
|
+
LOGGER.info("Creating a PySpark session")
|
128
|
+
return SparkSession.builder.getOrCreate()
|
@@ -13,6 +13,8 @@
|
|
13
13
|
# See the License for the specific language governing permissions and
|
14
14
|
# limitations under the License.
|
15
15
|
|
16
|
+
import logging
|
17
|
+
|
16
18
|
from typing import Optional
|
17
19
|
|
18
20
|
import pandas
|
@@ -21,6 +23,9 @@ from snowflake.snowpark import DataFrame as SnowparkDataFrame
|
|
21
23
|
from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
|
22
24
|
|
23
25
|
|
26
|
+
LOGGER = logging.getLogger(__name__)
|
27
|
+
|
28
|
+
|
24
29
|
class SamplingStrategy:
|
25
30
|
RANDOM_SAMPLE = 1
|
26
31
|
LIMIT = 2
|
@@ -52,23 +57,43 @@ class SamplingAdapter:
|
|
52
57
|
def process_args(self, input_args):
|
53
58
|
# create the intermediate pandas
|
54
59
|
# data frame for the test data
|
60
|
+
LOGGER.info("Processing %s input argument(s) for sampling", len(input_args))
|
55
61
|
for arg in input_args:
|
56
62
|
if isinstance(arg, SnowparkDataFrame):
|
57
|
-
|
63
|
+
df_count = arg.count()
|
64
|
+
if df_count == 0:
|
58
65
|
raise SamplingError(
|
59
66
|
"Input DataFrame is empty. Cannot sample from an empty DataFrame."
|
60
67
|
)
|
61
68
|
|
69
|
+
LOGGER.info("Sampling a Snowpark DataFrame with %s rows", df_count)
|
62
70
|
if self.sampling_strategy == SamplingStrategy.RANDOM_SAMPLE:
|
63
71
|
if self.sample_frac:
|
72
|
+
LOGGER.info(
|
73
|
+
"Applying random sampling with fraction %s",
|
74
|
+
self.sample_frac,
|
75
|
+
)
|
64
76
|
df_sample = arg.sample(frac=self.sample_frac).to_pandas()
|
65
77
|
else:
|
78
|
+
LOGGER.info(
|
79
|
+
"Applying random sampling with size %s", self.sample_number
|
80
|
+
)
|
66
81
|
df_sample = arg.sample(n=self.sample_number).to_pandas()
|
67
82
|
else:
|
83
|
+
LOGGER.info(
|
84
|
+
"Applying limit sampling with size %s", self.sample_number
|
85
|
+
)
|
68
86
|
df_sample = arg.limit(self.sample_number).to_pandas()
|
69
87
|
|
88
|
+
LOGGER.info(
|
89
|
+
"Successfully sampled the DataFrame. Resulting DataFrame shape: %s",
|
90
|
+
df_sample.shape,
|
91
|
+
)
|
70
92
|
self.pandas_sample_args.append(df_sample)
|
71
93
|
else:
|
94
|
+
LOGGER.debug(
|
95
|
+
"Argument is not a Snowpark DataFrame. No sampling is applied."
|
96
|
+
)
|
72
97
|
self.pandas_sample_args.append(arg)
|
73
98
|
|
74
99
|
def get_sampled_pandas_args(self):
|