snowpark-checkpoints-validators 0.1.0rc1__tar.gz → 0.1.0rc2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/.gitignore +3 -0
  2. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/PKG-INFO +120 -52
  3. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/README.md +119 -51
  4. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/src/snowflake/snowpark_checkpoints/__init__.py +2 -0
  5. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/src/snowflake/snowpark_checkpoints/checkpoint.py +90 -89
  6. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/src/snowflake/snowpark_checkpoints/errors.py +1 -1
  7. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/src/snowflake/snowpark_checkpoints/job_context.py +14 -3
  8. snowpark_checkpoints_validators-0.1.0rc2/src/snowflake/snowpark_checkpoints/singleton.py +12 -0
  9. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/src/snowflake/snowpark_checkpoints/spark_migration.py +5 -11
  10. snowpark_checkpoints_validators-0.1.0rc1/src/snowflake/snowpark_checkpoints/utils/constant.py → snowpark_checkpoints_validators-0.1.0rc2/src/snowflake/snowpark_checkpoints/utils/constants.py +9 -0
  11. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/src/snowflake/snowpark_checkpoints/utils/extra_config.py +1 -1
  12. snowpark_checkpoints_validators-0.1.0rc2/src/snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +358 -0
  13. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/src/snowflake/snowpark_checkpoints/utils/supported_types.py +1 -1
  14. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/src/snowflake/snowpark_checkpoints/utils/telemetry.py +290 -103
  15. snowpark_checkpoints_validators-0.1.0rc2/src/snowflake/snowpark_checkpoints/utils/utils_checks.py +361 -0
  16. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/src/snowflake/snowpark_checkpoints/validation_result_metadata.py +16 -12
  17. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_compare_utils.py +54 -0
  18. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/df_mode_dataframe_mismatch_telemetry.json +17 -0
  19. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/df_mode_dataframe_telemetry.json +17 -0
  20. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/spark_checkpoint_df_fail_telemetry.json +17 -0
  21. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/spark_checkpoint_df_pass_telemetry.json +17 -0
  22. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/spark_checkpoint_limit_sample_telemetry.json +17 -0
  23. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/spark_checkpoint_random_sample_telemetry.json +17 -0
  24. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/spark_checkpoint_scalar_fail_telemetry.json +17 -0
  25. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/spark_checkpoint_scalar_passing_telemetry.json +17 -0
  26. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/test_df_check_custom_check_telemetry.json +17 -0
  27. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/test_df_check_fail_telemetry.json +17 -0
  28. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/test_df_check_from_file_telemetry.json +17 -0
  29. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/test_df_check_skip_check_telemetry.json +17 -0
  30. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/test_df_check_telemetry.json +17 -0
  31. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/test_input_fail_telemetry.json +17 -0
  32. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/test_input_telemetry.json +17 -0
  33. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/test_output_fail_telemetry.json +17 -0
  34. snowpark_checkpoints_validators-0.1.0rc2/test/integ/telemetry_expected/test_output_telemetry.json +17 -0
  35. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/test/integ/test_pandera.py +185 -22
  36. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/test/integ/test_parquet.py +34 -5
  37. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/test/integ/test_spark_checkpoint.py +45 -6
  38. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/test/unit/test_extra_config.py +1 -1
  39. snowpark_checkpoints_validators-0.1.0rc2/test/unit/test_pandera_check_manager.py +785 -0
  40. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/test/unit/test_utils_checks.py +33 -376
  41. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/test/unit/test_validation_result_metadata.py +44 -1
  42. snowpark_checkpoints_validators-0.1.0rc1/src/snowflake/snowpark_checkpoints/utils/utils_checks.py +0 -560
  43. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/CHANGELOG.md +0 -0
  44. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/LICENSE +0 -0
  45. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/pyproject.toml +0 -0
  46. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/src/snowflake/snowpark_checkpoints/snowpark_sampler.py +0 -0
  47. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/src/snowflake/snowpark_checkpoints/utils/__init__.py +0 -0
  48. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/src/snowflake/snowpark_checkpoints/utils/checkpoint_logger.py +0 -0
  49. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/src/snowflake/snowpark_checkpoints/validation_results.py +0 -0
  50. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/test/.coveragerc +0 -0
  51. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/test/integ/e2eexample.py +0 -0
  52. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/test/unit/test_spark_migration.py +0 -0
  53. {snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/test/unit/test_telemetry.py +0 -0
{snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/.gitignore
@@ -4,10 +4,13 @@
 
 # demos
 snowpark-checkpoints-output/
+Demos/Demos/
+Demos/snowpark-checkpoints-output/
 
 # env
 wheelvenv/
 
+
 # version
 !__version__.py
 
{snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: snowpark-checkpoints-validators
-Version: 0.1.0rc1
+Version: 0.1.0rc2
 Summary: Migration tools for Snowpark
 Project-URL: Bug Tracker, https://github.com/snowflakedb/snowpark-checkpoints/issues
 Project-URL: Source code, https://github.com/snowflakedb/snowpark-checkpoints/
@@ -244,7 +244,7 @@ Requires-Dist: setuptools>=70.0.0; extra == 'development'
 Requires-Dist: twine==5.1.1; extra == 'development'
 Description-Content-Type: text/markdown
 
-# Snowpark Checkpoints Validators
+# snowpark-checkpoints-validators
 
 ---
 **NOTE**
@@ -270,9 +270,16 @@ This package is on Private Preview.
 The `validate_dataframe_checkpoint` function validates a Snowpark DataFrame against a checkpoint schema file or dataframe.
 
 ```python
-from snowflake.snowpark_checkpoints.checkpoint import validate_dataframe_checkpoint
+from snowflake.snowpark import DataFrame as SnowparkDataFrame
+from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+from snowflake.snowpark_checkpoints.utils.constant import (
+    CheckpointMode,
+)
+from snowflake.snowpark_checkpoints.spark_migration import SamplingStrategy
+from typing import Any, Optional
 
-validate_dataframe_checkpoint(
+# Signature of the function
+def validate_dataframe_checkpoint(
     df: SnowparkDataFrame,
     checkpoint_name: str,
     job_context: Optional[SnowparkJobContext] = None,
@@ -283,16 +290,17 @@ validate_dataframe_checkpoint(
     sample_number: Optional[int] = None,
     sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
     output_path: Optional[str] = None,
-)
+):
+    ...
 ```
 
-- `df`: Snowpark DataFrame to validate.
-- `checkpoint_name`: Name of the checkpoint schema file or DataFrame.
+- `df`: Snowpark dataframe to validate.
+- `checkpoint_name`: Name of the checkpoint schema file or dataframe.
 - `job_context`: Snowpark job context.
 - `mode`: Checkpoint mode (schema or data).
 - `custom_checks`: Custom checks to perform.
 - `skip_checks`: Checks to skip.
-- `sample_frac`: Fraction of the DataFrame to sample.
+- `sample_frac`: Fraction of the dataframe to sample.
 - `sample_number`: Number of rows to sample.
 - `sampling_strategy`: Sampling strategy to use.
 - `output_path`: Output path for the checkpoint report.
@@ -301,16 +309,24 @@ validate_dataframe_checkpoint(
 
 ```python
 from snowflake.snowpark import Session
-from snowflake.snowpark import DataFrame as SnowparkDataFrame
+from snowflake.snowpark_checkpoints.utils.constant import (
+    CheckpointMode,
+)
 from snowflake.snowpark_checkpoints.checkpoint import validate_dataframe_checkpoint
+from snowflake.snowpark_checkpoints.spark_migration import SamplingStrategy
+from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+from pyspark.sql import SparkSession
 
 session = Session.builder.getOrCreate()
+job_context = SnowparkJobContext(
+    session, SparkSession.builder.getOrCreate(), "job_context", True
+)
 df = session.read.format("csv").load("data.csv")
 
 validate_dataframe_checkpoint(
     df,
     "schema_checkpoint",
-    job_context=session,
+    job_context=job_context,
     mode=CheckpointMode.SCHEMA,
     sample_frac=0.1,
     sampling_strategy=SamplingStrategy.RANDOM_SAMPLE
@@ -319,22 +335,24 @@ validate_dataframe_checkpoint(
 
 ### Check with Spark Decorator
 
-The `check_with_spark` decorator converts any Snowpark DataFrame arguments to a function, samples them, and converts them to PySpark DataFrames. It then executes a provided Spark function and compares the outputs between the two implementations.
+The `check_with_spark` decorator converts any Snowpark dataframe arguments to a function, samples them, and converts them to PySpark dataframe. It then executes a provided Spark function and compares the outputs between the two implementations.
 
 ```python
-from snowflake.snowpark_checkpoints.spark_migration import check_with_spark
+from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+from snowflake.snowpark_checkpoints.spark_migration import SamplingStrategy
+from typing import Callable, Optional, TypeVar
 
-@check_with_spark(
+fn = TypeVar("F", bound=Callable)
+
+# Signature of the decorator
+def check_with_spark(
     job_context: Optional[SnowparkJobContext],
-    spark_function: Callable,
+    spark_function: fn,
     checkpoint_name: str,
     sample_number: Optional[int] = 100,
     sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
-    check_dtypes: Optional[bool] = False,
-    check_with_precision: Optional[bool] = False,
     output_path: Optional[str] = None,
-)
-def snowpark_fn(df: SnowparkDataFrame):
+) -> Callable[[fn], fn]:
     ...
 ```
 
@@ -343,8 +361,6 @@ def snowpark_fn(df: SnowparkDataFrame):
 - `checkpoint_name`: Name of the check.
 - `sample_number`: Number of rows to sample.
 - `sampling_strategy`: Sampling strategy to use.
-- `check_dtypes`: Check data types.
-- `check_with_precision`: Check with precision.
 - `output_path`: Output path for the checkpoint report.
 
 ### Usage Example
@@ -353,52 +369,63 @@ def snowpark_fn(df: SnowparkDataFrame):
 from snowflake.snowpark import Session
 from snowflake.snowpark import DataFrame as SnowparkDataFrame
 from snowflake.snowpark_checkpoints.spark_migration import check_with_spark
+from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+from pyspark.sql import DataFrame as SparkDataFrame, SparkSession
 
 session = Session.builder.getOrCreate()
-df = session.read.format("csv").load("data.csv")
+job_context = SnowparkJobContext(
+    session, SparkSession.builder.getOrCreate(), "job_context", True
+)
+
+def my_spark_scalar_fn(df: SparkDataFrame):
+    return df.count()
 
 @check_with_spark(
-    job_context=session,
-    spark_function=lambda df: df.withColumn("COLUMN1", df["COLUMN1"] + 1),
-    checkpoint_name="Check_Column1_Increment",
-    sample_number=100,
-    sampling_strategy=SamplingStrategy.RANDOM_SAMPLE,
+    job_context=job_context,
+    spark_function=my_spark_scalar_fn,
+    checkpoint_name="count_checkpoint",
 )
-def increment_column1(df: SnowparkDataFrame):
-    return df.with_column("COLUMN1", df["COLUMN1"] + 1)
+def my_snowpark_scalar_fn(df: SnowparkDataFrame):
+    return df.count()
 
-increment_column1(df)
+df = job_context.snowpark_session.create_dataframe(
+    [[1, 2], [3, 4]], schema=["a", "b"]
+)
+count = my_snowpark_scalar_fn(df)
 ```
 
 ### Pandera Snowpark Decorators
 
-The decorators `@check_input_schema` and `@check_output_schema` allow for sampled schema validation of Snowpark DataFrames in the input arguments or in the return value.
+The decorators `@check_input_schema` and `@check_output_schema` allow for sampled schema validation of Snowpark dataframes in the input arguments or in the return value.
 
 ```python
-from snowflake.snowpark_checkpoints.checkpoint import check_input_schema, check_output_schema
+from snowflake.snowpark_checkpoints.spark_migration import SamplingStrategy
+from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+from pandera import DataFrameSchema
+from typing import Optional
 
-@check_input_schema(
+# Signature of the decorator
+def check_input_schema(
     pandera_schema: DataFrameSchema,
     checkpoint_name: str,
     sample_frac: Optional[float] = 1.0,
     sample_number: Optional[int] = None,
     sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
-    job_context: Optional[SnowparkJobContext],
+    job_context: Optional[SnowparkJobContext] = None,
    output_path: Optional[str] = None,
-)
-def snowpark_fn(df: SnowparkDataFrame):
+):
     ...
 
-@check_output_schema(
+# Signature of the decorator
+def check_output_schema(
     pandera_schema: DataFrameSchema,
     checkpoint_name: str,
     sample_frac: Optional[float] = 1.0,
     sample_number: Optional[int] = None,
     sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
-    job_context: Optional[SnowparkJobContext],
+    job_context: Optional[SnowparkJobContext] = None,
     output_path: Optional[str] = None,
-)
-def snowpark_fn(df: SnowparkDataFrame):
+):
     ...
 ```
 
@@ -412,28 +439,71 @@ def snowpark_fn(df: SnowparkDataFrame):
 
 ### Usage Example
 
-The following will result in a Pandera `SchemaError`:
+#### Check Input Schema Example
+```python
+from pandas import DataFrame as PandasDataFrame
+from pandera import DataFrameSchema, Column, Check
+from snowflake.snowpark import Session
+from snowflake.snowpark import DataFrame as SnowparkDataFrame
+from snowflake.snowpark_checkpoints.checkpoint import check_input_schema
+from numpy import int8
+
+df = PandasDataFrame(
+    {
+        "COLUMN1": [1, 4, 0, 10, 9],
+        "COLUMN2": [-1.3, -1.4, -2.9, -10.1, -20.4],
+    }
+)
 
+in_schema = DataFrameSchema(
+    {
+        "COLUMN1": Column(int8, Check(lambda x: 0 <= x <= 10, element_wise=True)),
+        "COLUMN2": Column(float, Check(lambda x: x < -1.2, element_wise=True)),
+    }
+)
+
+@check_input_schema(in_schema, "input_schema_checkpoint")
+def preprocessor(dataframe: SnowparkDataFrame):
+    dataframe = dataframe.withColumn(
+        "COLUMN3", dataframe["COLUMN1"] + dataframe["COLUMN2"]
+    )
+    return dataframe
+
+session = Session.builder.getOrCreate()
+sp_dataframe = session.create_dataframe(df)
+
+preprocessed_dataframe = preprocessor(sp_dataframe)
+```
+
+#### Check Input Schema Example
 ```python
 from pandas import DataFrame as PandasDataFrame
 from pandera import DataFrameSchema, Column, Check
 from snowflake.snowpark import Session
 from snowflake.snowpark import DataFrame as SnowparkDataFrame
 from snowflake.snowpark_checkpoints.checkpoint import check_output_schema
+from numpy import int8
 
-df = PandasDataFrame({
-    "COLUMN1": [1, 4, 0, 10, 9],
-    "COLUMN2": [-1.3, -1.4, -2.9, -10.1, -20.4],
-})
+df = PandasDataFrame(
+    {
+        "COLUMN1": [1, 4, 0, 10, 9],
+        "COLUMN2": [-1.3, -1.4, -2.9, -10.1, -20.4],
+    }
+)
 
-out_schema = DataFrameSchema({
-    "COLUMN1": Column(int8, Check(lambda x: 0 <= x <= 10, element_wise=True)),
-    "COLUMN2": Column(float, Check(lambda x: x < -1.2)),
-})
+out_schema = DataFrameSchema(
+    {
+        "COLUMN1": Column(int8, Check.between(0, 10, include_max=True, include_min=True)),
+        "COLUMN2": Column(float, Check.less_than_or_equal_to(-1.2)),
+        "COLUMN3": Column(float, Check.less_than(10)),
+    }
+)
 
 @check_output_schema(out_schema, "output_schema_checkpoint")
 def preprocessor(dataframe: SnowparkDataFrame):
-    return dataframe.with_column("COLUMN1", lit('Some bad data yo'))
+    return dataframe.with_column(
+        "COLUMN3", dataframe["COLUMN1"] + dataframe["COLUMN2"]
+    )
 
 session = Session.builder.getOrCreate()
 sp_dataframe = session.create_dataframe(df)
@@ -441,6 +511,4 @@ sp_dataframe = session.create_dataframe(df)
 preprocessed_dataframe = preprocessor(sp_dataframe)
 ```
 
-## License
-
-This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for more details.
+------
{snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/README.md
@@ -1,4 +1,4 @@
-# Snowpark Checkpoints Validators
+# snowpark-checkpoints-validators
 
 ---
 **NOTE**
@@ -24,9 +24,16 @@ This package is on Private Preview.
 The `validate_dataframe_checkpoint` function validates a Snowpark DataFrame against a checkpoint schema file or dataframe.
 
 ```python
-from snowflake.snowpark_checkpoints.checkpoint import validate_dataframe_checkpoint
+from snowflake.snowpark import DataFrame as SnowparkDataFrame
+from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+from snowflake.snowpark_checkpoints.utils.constant import (
+    CheckpointMode,
+)
+from snowflake.snowpark_checkpoints.spark_migration import SamplingStrategy
+from typing import Any, Optional
 
-validate_dataframe_checkpoint(
+# Signature of the function
+def validate_dataframe_checkpoint(
     df: SnowparkDataFrame,
     checkpoint_name: str,
     job_context: Optional[SnowparkJobContext] = None,
@@ -37,16 +44,17 @@ validate_dataframe_checkpoint(
     sample_number: Optional[int] = None,
     sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
     output_path: Optional[str] = None,
-)
+):
+    ...
 ```
 
-- `df`: Snowpark DataFrame to validate.
-- `checkpoint_name`: Name of the checkpoint schema file or DataFrame.
+- `df`: Snowpark dataframe to validate.
+- `checkpoint_name`: Name of the checkpoint schema file or dataframe.
 - `job_context`: Snowpark job context.
 - `mode`: Checkpoint mode (schema or data).
 - `custom_checks`: Custom checks to perform.
 - `skip_checks`: Checks to skip.
-- `sample_frac`: Fraction of the DataFrame to sample.
+- `sample_frac`: Fraction of the dataframe to sample.
 - `sample_number`: Number of rows to sample.
 - `sampling_strategy`: Sampling strategy to use.
 - `output_path`: Output path for the checkpoint report.
@@ -55,16 +63,24 @@ validate_dataframe_checkpoint(
 
 ```python
 from snowflake.snowpark import Session
-from snowflake.snowpark import DataFrame as SnowparkDataFrame
+from snowflake.snowpark_checkpoints.utils.constant import (
+    CheckpointMode,
+)
 from snowflake.snowpark_checkpoints.checkpoint import validate_dataframe_checkpoint
+from snowflake.snowpark_checkpoints.spark_migration import SamplingStrategy
+from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+from pyspark.sql import SparkSession
 
 session = Session.builder.getOrCreate()
+job_context = SnowparkJobContext(
+    session, SparkSession.builder.getOrCreate(), "job_context", True
+)
 df = session.read.format("csv").load("data.csv")
 
 validate_dataframe_checkpoint(
     df,
     "schema_checkpoint",
-    job_context=session,
+    job_context=job_context,
     mode=CheckpointMode.SCHEMA,
     sample_frac=0.1,
     sampling_strategy=SamplingStrategy.RANDOM_SAMPLE
@@ -73,22 +89,24 @@ validate_dataframe_checkpoint(
 
 ### Check with Spark Decorator
 
-The `check_with_spark` decorator converts any Snowpark DataFrame arguments to a function, samples them, and converts them to PySpark DataFrames. It then executes a provided Spark function and compares the outputs between the two implementations.
+The `check_with_spark` decorator converts any Snowpark dataframe arguments to a function, samples them, and converts them to PySpark dataframe. It then executes a provided Spark function and compares the outputs between the two implementations.
 
 ```python
-from snowflake.snowpark_checkpoints.spark_migration import check_with_spark
+from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+from snowflake.snowpark_checkpoints.spark_migration import SamplingStrategy
+from typing import Callable, Optional, TypeVar
 
-@check_with_spark(
+fn = TypeVar("F", bound=Callable)
+
+# Signature of the decorator
+def check_with_spark(
     job_context: Optional[SnowparkJobContext],
-    spark_function: Callable,
+    spark_function: fn,
     checkpoint_name: str,
     sample_number: Optional[int] = 100,
     sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
-    check_dtypes: Optional[bool] = False,
-    check_with_precision: Optional[bool] = False,
     output_path: Optional[str] = None,
-)
-def snowpark_fn(df: SnowparkDataFrame):
+) -> Callable[[fn], fn]:
     ...
 ```
 
@@ -97,8 +115,6 @@ def snowpark_fn(df: SnowparkDataFrame):
 - `checkpoint_name`: Name of the check.
 - `sample_number`: Number of rows to sample.
 - `sampling_strategy`: Sampling strategy to use.
-- `check_dtypes`: Check data types.
-- `check_with_precision`: Check with precision.
 - `output_path`: Output path for the checkpoint report.
 
 ### Usage Example
@@ -107,52 +123,63 @@ def snowpark_fn(df: SnowparkDataFrame):
 from snowflake.snowpark import Session
 from snowflake.snowpark import DataFrame as SnowparkDataFrame
 from snowflake.snowpark_checkpoints.spark_migration import check_with_spark
+from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+from pyspark.sql import DataFrame as SparkDataFrame, SparkSession
 
 session = Session.builder.getOrCreate()
-df = session.read.format("csv").load("data.csv")
+job_context = SnowparkJobContext(
+    session, SparkSession.builder.getOrCreate(), "job_context", True
+)
+
+def my_spark_scalar_fn(df: SparkDataFrame):
+    return df.count()
 
 @check_with_spark(
-    job_context=session,
-    spark_function=lambda df: df.withColumn("COLUMN1", df["COLUMN1"] + 1),
-    checkpoint_name="Check_Column1_Increment",
-    sample_number=100,
-    sampling_strategy=SamplingStrategy.RANDOM_SAMPLE,
+    job_context=job_context,
+    spark_function=my_spark_scalar_fn,
+    checkpoint_name="count_checkpoint",
 )
-def increment_column1(df: SnowparkDataFrame):
-    return df.with_column("COLUMN1", df["COLUMN1"] + 1)
+def my_snowpark_scalar_fn(df: SnowparkDataFrame):
+    return df.count()
 
-increment_column1(df)
+df = job_context.snowpark_session.create_dataframe(
+    [[1, 2], [3, 4]], schema=["a", "b"]
+)
+count = my_snowpark_scalar_fn(df)
 ```
 
 ### Pandera Snowpark Decorators
 
-The decorators `@check_input_schema` and `@check_output_schema` allow for sampled schema validation of Snowpark DataFrames in the input arguments or in the return value.
+The decorators `@check_input_schema` and `@check_output_schema` allow for sampled schema validation of Snowpark dataframes in the input arguments or in the return value.
 
 ```python
-from snowflake.snowpark_checkpoints.checkpoint import check_input_schema, check_output_schema
+from snowflake.snowpark_checkpoints.spark_migration import SamplingStrategy
+from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+from pandera import DataFrameSchema
+from typing import Optional
 
-@check_input_schema(
+# Signature of the decorator
+def check_input_schema(
     pandera_schema: DataFrameSchema,
     checkpoint_name: str,
     sample_frac: Optional[float] = 1.0,
     sample_number: Optional[int] = None,
     sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
-    job_context: Optional[SnowparkJobContext],
+    job_context: Optional[SnowparkJobContext] = None,
     output_path: Optional[str] = None,
-)
-def snowpark_fn(df: SnowparkDataFrame):
+):
     ...
 
-@check_output_schema(
+# Signature of the decorator
+def check_output_schema(
     pandera_schema: DataFrameSchema,
     checkpoint_name: str,
     sample_frac: Optional[float] = 1.0,
     sample_number: Optional[int] = None,
     sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
-    job_context: Optional[SnowparkJobContext],
+    job_context: Optional[SnowparkJobContext] = None,
     output_path: Optional[str] = None,
-)
-def snowpark_fn(df: SnowparkDataFrame):
+):
     ...
 ```
 
@@ -166,28 +193,71 @@ def snowpark_fn(df: SnowparkDataFrame):
 
 ### Usage Example
 
-The following will result in a Pandera `SchemaError`:
+#### Check Input Schema Example
+```python
+from pandas import DataFrame as PandasDataFrame
+from pandera import DataFrameSchema, Column, Check
+from snowflake.snowpark import Session
+from snowflake.snowpark import DataFrame as SnowparkDataFrame
+from snowflake.snowpark_checkpoints.checkpoint import check_input_schema
+from numpy import int8
+
+df = PandasDataFrame(
+    {
+        "COLUMN1": [1, 4, 0, 10, 9],
+        "COLUMN2": [-1.3, -1.4, -2.9, -10.1, -20.4],
+    }
+)
 
+in_schema = DataFrameSchema(
+    {
+        "COLUMN1": Column(int8, Check(lambda x: 0 <= x <= 10, element_wise=True)),
+        "COLUMN2": Column(float, Check(lambda x: x < -1.2, element_wise=True)),
+    }
+)
+
+@check_input_schema(in_schema, "input_schema_checkpoint")
+def preprocessor(dataframe: SnowparkDataFrame):
+    dataframe = dataframe.withColumn(
+        "COLUMN3", dataframe["COLUMN1"] + dataframe["COLUMN2"]
+    )
+    return dataframe
+
+session = Session.builder.getOrCreate()
+sp_dataframe = session.create_dataframe(df)
+
+preprocessed_dataframe = preprocessor(sp_dataframe)
+```
+
+#### Check Input Schema Example
 ```python
 from pandas import DataFrame as PandasDataFrame
 from pandera import DataFrameSchema, Column, Check
 from snowflake.snowpark import Session
 from snowflake.snowpark import DataFrame as SnowparkDataFrame
 from snowflake.snowpark_checkpoints.checkpoint import check_output_schema
+from numpy import int8
 
-df = PandasDataFrame({
-    "COLUMN1": [1, 4, 0, 10, 9],
-    "COLUMN2": [-1.3, -1.4, -2.9, -10.1, -20.4],
-})
+df = PandasDataFrame(
+    {
+        "COLUMN1": [1, 4, 0, 10, 9],
+        "COLUMN2": [-1.3, -1.4, -2.9, -10.1, -20.4],
+    }
+)
 
-out_schema = DataFrameSchema({
-    "COLUMN1": Column(int8, Check(lambda x: 0 <= x <= 10, element_wise=True)),
-    "COLUMN2": Column(float, Check(lambda x: x < -1.2)),
-})
+out_schema = DataFrameSchema(
+    {
+        "COLUMN1": Column(int8, Check.between(0, 10, include_max=True, include_min=True)),
+        "COLUMN2": Column(float, Check.less_than_or_equal_to(-1.2)),
+        "COLUMN3": Column(float, Check.less_than(10)),
+    }
+)
 
 @check_output_schema(out_schema, "output_schema_checkpoint")
 def preprocessor(dataframe: SnowparkDataFrame):
-    return dataframe.with_column("COLUMN1", lit('Some bad data yo'))
+    return dataframe.with_column(
+        "COLUMN3", dataframe["COLUMN1"] + dataframe["COLUMN2"]
+    )
 
 session = Session.builder.getOrCreate()
 sp_dataframe = session.create_dataframe(df)
@@ -195,6 +265,4 @@ sp_dataframe = session.create_dataframe(df)
 preprocessed_dataframe = preprocessor(sp_dataframe)
 ```
 
-## License
-
-This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for more details.
+------
{snowpark_checkpoints_validators-0.1.0rc1 → snowpark_checkpoints_validators-0.1.0rc2}/src/snowflake/snowpark_checkpoints/__init__.py
@@ -10,6 +10,7 @@ from snowflake.snowpark_checkpoints.checkpoint import (
 )
 from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
 from snowflake.snowpark_checkpoints.spark_migration import check_with_spark
+from snowflake.snowpark_checkpoints.utils.constants import CheckpointMode
 
 __all__ = [
     "check_with_spark",
@@ -18,4 +19,5 @@ __all__ = [
     "check_output_schema",
     "check_input_schema",
     "validate_dataframe_checkpoint",
+    "CheckpointMode",
 ]
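
For reference, the `__init__.py` change above re-exports `CheckpointMode` from the package root. A minimal usage sketch under that assumption follows; the checkpoint name and sample dataframe are illustrative placeholders, not values from the package.

```python
# Minimal sketch, assuming the rc2 re-export shown in the __init__.py hunks.
# The checkpoint name and dataframe contents below are hypothetical.
from snowflake.snowpark import Session
from snowflake.snowpark_checkpoints import CheckpointMode, validate_dataframe_checkpoint

session = Session.builder.getOrCreate()
df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])

# CheckpointMode.SCHEMA selects schema validation, per the mode parameter
# documented in the README diff above.
validate_dataframe_checkpoint(
    df,
    "schema_checkpoint",
    mode=CheckpointMode.SCHEMA,
)
```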