snowpark-checkpoints-validators 0.2.0rc1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. snowflake/snowpark_checkpoints/__init__.py +44 -0
  2. snowflake/snowpark_checkpoints/__version__.py +16 -0
  3. snowflake/snowpark_checkpoints/checkpoint.py +580 -0
  4. snowflake/snowpark_checkpoints/errors.py +60 -0
  5. snowflake/snowpark_checkpoints/io_utils/__init__.py +26 -0
  6. snowflake/snowpark_checkpoints/io_utils/io_default_strategy.py +57 -0
  7. snowflake/snowpark_checkpoints/io_utils/io_env_strategy.py +133 -0
  8. snowflake/snowpark_checkpoints/io_utils/io_file_manager.py +76 -0
  9. snowflake/snowpark_checkpoints/job_context.py +128 -0
  10. snowflake/snowpark_checkpoints/singleton.py +23 -0
  11. snowflake/snowpark_checkpoints/snowpark_sampler.py +124 -0
  12. snowflake/snowpark_checkpoints/spark_migration.py +255 -0
  13. snowflake/snowpark_checkpoints/utils/__init__.py +14 -0
  14. snowflake/snowpark_checkpoints/utils/constants.py +134 -0
  15. snowflake/snowpark_checkpoints/utils/extra_config.py +132 -0
  16. snowflake/snowpark_checkpoints/utils/logging_utils.py +67 -0
  17. snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +399 -0
  18. snowflake/snowpark_checkpoints/utils/supported_types.py +65 -0
  19. snowflake/snowpark_checkpoints/utils/telemetry.py +939 -0
  20. snowflake/snowpark_checkpoints/utils/utils_checks.py +398 -0
  21. snowflake/snowpark_checkpoints/validation_result_metadata.py +159 -0
  22. snowflake/snowpark_checkpoints/validation_results.py +49 -0
  23. snowpark_checkpoints_validators-0.3.0.dist-info/METADATA +325 -0
  24. snowpark_checkpoints_validators-0.3.0.dist-info/RECORD +26 -0
  25. snowpark_checkpoints_validators-0.2.0rc1.dist-info/METADATA +0 -514
  26. snowpark_checkpoints_validators-0.2.0rc1.dist-info/RECORD +0 -4
  27. {snowpark_checkpoints_validators-0.2.0rc1.dist-info → snowpark_checkpoints_validators-0.3.0.dist-info}/WHEEL +0 -0
  28. {snowpark_checkpoints_validators-0.2.0rc1.dist-info → snowpark_checkpoints_validators-0.3.0.dist-info}/licenses/LICENSE +0 -0
snowflake/snowpark_checkpoints/io_utils/__init__.py
@@ -0,0 +1,26 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ __all__ = ["EnvStrategy", "IOFileManager", "IODefaultStrategy"]
+
+ from snowflake.snowpark_checkpoints.io_utils.io_env_strategy import (
+     EnvStrategy,
+ )
+ from snowflake.snowpark_checkpoints.io_utils.io_default_strategy import (
+     IODefaultStrategy,
+ )
+ from snowflake.snowpark_checkpoints.io_utils.io_file_manager import (
+     IOFileManager,
+ )
snowflake/snowpark_checkpoints/io_utils/io_default_strategy.py
@@ -0,0 +1,57 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import glob
+ import os
+
+ from pathlib import Path
+ from typing import Optional
+
+ from snowflake.snowpark_checkpoints.io_utils import EnvStrategy
+
+
+ class IODefaultStrategy(EnvStrategy):
+     def mkdir(self, path: str, exist_ok: bool = False) -> None:
+         os.makedirs(path, exist_ok=exist_ok)
+
+     def folder_exists(self, path: str) -> bool:
+         return os.path.isdir(path)
+
+     def file_exists(self, path: str) -> bool:
+         return os.path.isfile(path)
+
+     def write(self, file_path: str, file_content: str, overwrite: bool = True) -> None:
+         mode = "w" if overwrite else "x"
+         with open(file_path, mode) as file:
+             file.write(file_content)
+
+     def read(
+         self, file_path: str, mode: str = "r", encoding: Optional[str] = None
+     ) -> str:
+         with open(file_path, mode=mode, encoding=encoding) as file:
+             return file.read()
+
+     def read_bytes(self, file_path: str) -> bytes:
+         with open(file_path, mode="rb") as f:
+             return f.read()
+
+     def ls(self, path: str, recursive: bool = False) -> list[str]:
+         return glob.glob(path, recursive=recursive)
+
+     def getcwd(self) -> str:
+         return os.getcwd()
+
+     def telemetry_path_files(self, path: str) -> Path:
+         return Path(path)
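
The default strategy is a thin wrapper over the standard library, so its behavior can be shown end to end. A minimal sketch (the temporary directory and file names are illustrative only):

```python
import os
import tempfile

from snowflake.snowpark_checkpoints.io_utils import IODefaultStrategy

strategy = IODefaultStrategy()

with tempfile.TemporaryDirectory() as root:
    target = os.path.join(root, "checkpoints")
    strategy.mkdir(target, exist_ok=True)           # os.makedirs under the hood
    assert strategy.folder_exists(target)

    report = os.path.join(target, "report.json")
    strategy.write(report, '{"status": "ok"}')      # overwrite=False would open with mode "x"
    assert strategy.file_exists(report)
    assert strategy.read(report) == '{"status": "ok"}'

    # ls() delegates to glob.glob, so it takes a glob pattern, not a bare directory.
    assert strategy.ls(os.path.join(target, "*.json")) == [report]
```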
snowflake/snowpark_checkpoints/io_utils/io_env_strategy.py
@@ -0,0 +1,133 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+ from pathlib import Path
+ from typing import Optional
+
+
+ class EnvStrategy(ABC):
+
+     """An abstract base class that defines methods for file and directory operations.
+
+     Subclasses should implement these methods to provide environment-specific behavior.
+     """
+
+     @abstractmethod
+     def mkdir(self, path: str, exist_ok: bool = False) -> None:
+         """Create a directory.
+
+         Args:
+             path: The name of the directory to create.
+             exist_ok: If False, an error is raised if the directory already exists.
+
+         """
+
+     @abstractmethod
+     def folder_exists(self, path: str) -> bool:
+         """Check if a folder exists.
+
+         Args:
+             path: The path to the folder.
+
+         Returns:
+             bool: True if the folder exists, False otherwise.
+
+         """
+
+     @abstractmethod
+     def file_exists(self, path: str) -> bool:
+         """Check if a file exists.
+
+         Args:
+             path: The path to the file.
+
+         Returns:
+             bool: True if the file exists, False otherwise.
+
+         """
+
+     @abstractmethod
+     def write(self, file_path: str, file_content: str, overwrite: bool = True) -> None:
+         """Write content to a file.
+
+         Args:
+             file_path: The name of the file to write to.
+             file_content: The content to write to the file.
+             overwrite: If True, overwrite the file if it exists.
+
+         """
+
+     @abstractmethod
+     def read(
+         self, file_path: str, mode: str = "r", encoding: Optional[str] = None
+     ) -> str:
+         """Read content from a file.
+
+         Args:
+             file_path: The path to the file to read from.
+             mode: The mode in which to open the file.
+             encoding: The encoding to use for reading the file.
+
+         Returns:
+             str: The content of the file.
+
+         """
+
+     @abstractmethod
+     def read_bytes(self, file_path: str) -> bytes:
+         """Read binary content from a file.
+
+         Args:
+             file_path: The path to the file to read from.
+
+         Returns:
+             bytes: The binary content of the file.
+
+         """
+
+     @abstractmethod
+     def ls(self, path: str, recursive: bool = False) -> list[str]:
+         """List the contents of a directory.
+
+         Args:
+             path: The path to the directory.
+             recursive: If True, list the contents recursively.
+
+         Returns:
+             list[str]: A list of the contents of the directory.
+
+         """
+
+     @abstractmethod
+     def getcwd(self) -> str:
+         """Get the current working directory.
+
+         Returns:
+             str: The current working directory.
+
+         """
+
+     @abstractmethod
+     def telemetry_path_files(self, path: str) -> Path:
+         """Get the path to the telemetry files.
+
+         Args:
+             path: The path to the telemetry directory.
+
+         Returns:
+             Path: The path object representing the telemetry files.
+
+         """
snowflake/snowpark_checkpoints/io_utils/io_file_manager.py
@@ -0,0 +1,76 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from pathlib import Path
+ from typing import Optional
+
+ from snowflake.snowpark_checkpoints.io_utils import (
+     EnvStrategy,
+     IODefaultStrategy,
+ )
+ from snowflake.snowpark_checkpoints.singleton import Singleton
+
+
+ class IOFileManager(metaclass=Singleton):
+     def __init__(self, strategy: Optional[EnvStrategy] = None):
+         self.strategy = strategy or IODefaultStrategy()
+
+     def mkdir(self, path: str, exist_ok: bool = False) -> None:
+         return self.strategy.mkdir(path, exist_ok)
+
+     def folder_exists(self, path: str) -> bool:
+         return self.strategy.folder_exists(path)
+
+     def file_exists(self, path: str) -> bool:
+         return self.strategy.file_exists(path)
+
+     def write(self, file_path: str, file_content: str, overwrite: bool = True) -> None:
+         return self.strategy.write(file_path, file_content, overwrite)
+
+     def read(
+         self, file_path: str, mode: str = "r", encoding: Optional[str] = None
+     ) -> str:
+         return self.strategy.read(file_path, mode, encoding)
+
+     def read_bytes(self, file_path: str) -> bytes:
+         return self.strategy.read_bytes(file_path)
+
+     def ls(self, path: str, recursive: bool = False) -> list[str]:
+         return self.strategy.ls(path, recursive)
+
+     def getcwd(self) -> str:
+         return self.strategy.getcwd()
+
+     def telemetry_path_files(self, path: str) -> Path:
+         return self.strategy.telemetry_path_files(path)
+
+     def set_strategy(self, strategy: EnvStrategy):
+         """Set the strategy for file and directory operations.
+
+         Args:
+             strategy (EnvStrategy): The strategy to use for file and directory operations.
+
+         """
+         self.strategy = strategy
+
+
+ def get_io_file_manager():
+     """Get the singleton instance of IOFileManager.
+
+     Returns:
+         IOFileManager: The singleton instance of IOFileManager.
+
+     """
+     return IOFileManager()
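
Because IOFileManager uses the Singleton metaclass (see singleton.py below), get_io_file_manager() returns the same instance everywhere, and set_strategy() therefore swaps I/O behavior process-wide. A minimal sketch of that behavior:

```python
from snowflake.snowpark_checkpoints.io_utils import IODefaultStrategy, IOFileManager
from snowflake.snowpark_checkpoints.io_utils.io_file_manager import get_io_file_manager

manager = get_io_file_manager()
assert manager is get_io_file_manager()   # Singleton metaclass: same object every time
assert manager is IOFileManager()         # direct construction also returns the cached instance

# set_strategy() affects every caller that goes through the manager,
# e.g. swapping in the hypothetical InMemoryStrategy sketched above.
manager.set_strategy(IODefaultStrategy())
print(manager.getcwd())                   # delegates to os.getcwd()
```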
snowflake/snowpark_checkpoints/job_context.py
@@ -0,0 +1,128 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+
+ from datetime import datetime
+ from typing import Optional
+
+ import pandas as pd
+
+ from pyspark.sql import SparkSession
+
+ from snowflake.snowpark import Session
+ from snowflake.snowpark_checkpoints.utils.constants import SCHEMA_EXECUTION_MODE
+
+
+ LOGGER = logging.getLogger(__name__)
+ RESULTS_TABLE = "SNOWPARK_CHECKPOINTS_REPORT"
+
+
+ class SnowparkJobContext:
+
+     """Class used to record migration results in Snowflake.
+
+     Args:
+         snowpark_session: A Snowpark session instance.
+         spark_session: A Spark session instance.
+         job_name: The name of the job.
+         log_results: Whether to log the migration results in Snowflake.
+
+     """
+
+     def __init__(
+         self,
+         snowpark_session: Session,
+         spark_session: SparkSession = None,
+         job_name: Optional[str] = None,
+         log_results: Optional[bool] = True,
+     ):
+         self.log_results = log_results
+         self.job_name = job_name
+         self.spark_session = spark_session or self._create_pyspark_session()
+         self.snowpark_session = snowpark_session
+
+     def _mark_fail(
+         self, message, checkpoint_name, data, execution_mode=SCHEMA_EXECUTION_MODE
+     ):
+         if not self.log_results:
+             LOGGER.warning(
+                 (
+                     "Recording of migration results into Snowflake is disabled. "
+                     "Failure result for checkpoint '%s' will not be recorded."
+                 ),
+                 checkpoint_name,
+             )
+             return
+
+         LOGGER.debug(
+             "Marking failure for checkpoint '%s' in '%s' mode with message '%s'",
+             checkpoint_name,
+             execution_mode,
+             message,
+         )
+
+         session = self.snowpark_session
+         df = pd.DataFrame(
+             {
+                 "DATE": [datetime.now()],
+                 "JOB": [self.job_name],
+                 "STATUS": ["fail"],
+                 "CHECKPOINT": [checkpoint_name],
+                 "MESSAGE": [message],
+                 "DATA": [f"{data}"],
+                 "EXECUTION_MODE": [execution_mode],
+             }
+         )
+         report_df = session.createDataFrame(df)
+         LOGGER.info("Writing failure result to table: '%s'", RESULTS_TABLE)
+         report_df.write.mode("append").save_as_table(RESULTS_TABLE)
+
+     def _mark_pass(self, checkpoint_name, execution_mode=SCHEMA_EXECUTION_MODE):
+         if not self.log_results:
+             LOGGER.warning(
+                 (
+                     "Recording of migration results into Snowflake is disabled. "
+                     "Pass result for checkpoint '%s' will not be recorded."
+                 ),
+                 checkpoint_name,
+             )
+             return
+
+         LOGGER.debug(
+             "Marking pass for checkpoint '%s' in '%s' mode",
+             checkpoint_name,
+             execution_mode,
+         )
+
+         session = self.snowpark_session
+         df = pd.DataFrame(
+             {
+                 "DATE": [datetime.now()],
+                 "JOB": [self.job_name],
+                 "STATUS": ["pass"],
+                 "CHECKPOINT": [checkpoint_name],
+                 "MESSAGE": [""],
+                 "DATA": [""],
+                 "EXECUTION_MODE": [execution_mode],
+             }
+         )
+         report_df = session.createDataFrame(df)
+         LOGGER.info("Writing pass result to table: '%s'", RESULTS_TABLE)
+         report_df.write.mode("append").save_as_table(RESULTS_TABLE)
+
+     def _create_pyspark_session(self) -> SparkSession:
+         LOGGER.info("Creating a PySpark session")
+         return SparkSession.builder.getOrCreate()
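
A sketch of how the job context might be constructed. The connection parameters are placeholders; omitting spark_session makes the constructor spin up a local PySpark session via SparkSession.builder.getOrCreate(), so pyspark must be installed. The _mark_pass/_mark_fail helpers are private and normally invoked by the checkpoint validators; the direct call below only illustrates which table they append to:

```python
from snowflake.snowpark import Session
from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext

# Placeholder connection parameters; fill in real credentials.
snowpark_session = Session.builder.configs(
    {"account": "<account>", "user": "<user>", "password": "<password>"}
).create()

ctx = SnowparkJobContext(snowpark_session, job_name="migration_demo")

# Appends a row to SNOWPARK_CHECKPOINTS_REPORT (skipped when log_results=False).
ctx._mark_pass("demo_checkpoint")
```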
snowflake/snowpark_checkpoints/singleton.py
@@ -0,0 +1,23 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ class Singleton(type):
+     _instances = {}
+
+     def __call__(cls, *args, **kwargs):
+         if cls not in cls._instances:
+             cls._instances[cls] = super().__call__(*args, **kwargs)
+         return cls._instances[cls]
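
A quick illustration of the metaclass semantics: the first call constructs the instance, and every later call returns it, silently ignoring new constructor arguments. The Config class below is illustrative, not part of the package:

```python
class Config(metaclass=Singleton):
    def __init__(self, value=None):
        self.value = value


a = Config(value=1)
b = Config(value=2)   # __init__ is not re-run; the new argument is ignored
assert a is b
assert a.value == 1

# _instances is keyed by class, so each class (and each subclass)
# using the metaclass gets its own single cached instance.
```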
snowflake/snowpark_checkpoints/snowpark_sampler.py
@@ -0,0 +1,124 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+
+ from typing import Optional
+
+ import pandas
+
+ from snowflake.snowpark import DataFrame as SnowparkDataFrame
+ from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+
+
+ LOGGER = logging.getLogger(__name__)
+
+
+ class SamplingStrategy:
+     RANDOM_SAMPLE = 1
+     LIMIT = 2
+
+
+ class SamplingError(Exception):
+     pass
+
+
+ class SamplingAdapter:
+     def __init__(
+         self,
+         job_context: Optional[SnowparkJobContext],
+         sample_frac: Optional[float] = None,
+         sample_number: Optional[int] = None,
+         sampling_strategy: SamplingStrategy = SamplingStrategy.RANDOM_SAMPLE,
+     ):
+         self.pandas_sample_args = []
+         self.job_context = job_context
+         if sample_frac and not (0 <= sample_frac <= 1):
+             raise ValueError(
+                 f"'sample_size' value {sample_frac} is out of range (0 <= sample_size <= 1)"
+             )
+
+         self.sample_frac = sample_frac
+         self.sample_number = sample_number
+         self.sampling_strategy = sampling_strategy
+
+     def process_args(self, input_args):
+         # create the intermediate pandas
+         # data frame for the test data
+         LOGGER.info("Processing %s input argument(s) for sampling", len(input_args))
+         for arg in input_args:
+             if isinstance(arg, SnowparkDataFrame):
+                 df_count = arg.count()
+                 if df_count == 0:
+                     raise SamplingError(
+                         "Input DataFrame is empty. Cannot sample from an empty DataFrame."
+                     )
+
+                 LOGGER.info("Sampling a Snowpark DataFrame with %s rows", df_count)
+                 if self.sampling_strategy == SamplingStrategy.RANDOM_SAMPLE:
+                     if self.sample_frac:
+                         LOGGER.info(
+                             "Applying random sampling with fraction %s",
+                             self.sample_frac,
+                         )
+                         df_sample = arg.sample(frac=self.sample_frac).to_pandas()
+                     else:
+                         LOGGER.info(
+                             "Applying random sampling with size %s", self.sample_number
+                         )
+                         df_sample = arg.sample(n=self.sample_number).to_pandas()
+                 else:
+                     LOGGER.info(
+                         "Applying limit sampling with size %s", self.sample_number
+                     )
+                     df_sample = arg.limit(self.sample_number).to_pandas()
+
+                 LOGGER.info(
+                     "Successfully sampled the DataFrame. Resulting DataFrame shape: %s",
+                     df_sample.shape,
+                 )
+                 self.pandas_sample_args.append(df_sample)
+             else:
+                 LOGGER.debug(
+                     "Argument is not a Snowpark DataFrame. No sampling is applied."
+                 )
+                 self.pandas_sample_args.append(arg)
+
+     def get_sampled_pandas_args(self):
+         return self.pandas_sample_args
+
+     def get_sampled_snowpark_args(self):
+         if self.job_context is None:
+             raise SamplingError("Need a job context to compare with Spark")
+         snowpark_sample_args = []
+         for arg in self.pandas_sample_args:
+             if isinstance(arg, pandas.DataFrame):
+                 snowpark_df = self.job_context.snowpark_session.create_dataframe(arg)
+                 snowpark_sample_args.append(snowpark_df)
+             else:
+                 snowpark_sample_args.append(arg)
+         return snowpark_sample_args
+
+     def get_sampled_spark_args(self):
+         if self.job_context is None:
+             raise SamplingError("Need a job context to compare with Spark")
+         pyspark_sample_args = []
+         for arg in self.pandas_sample_args:
+             if isinstance(arg, pandas.DataFrame):
+                 pyspark_df = self.job_context.spark_session.createDataFrame(arg)
+                 pyspark_sample_args.append(pyspark_df)
+             else:
+                 pyspark_sample_args.append(arg)
+         return pyspark_sample_args
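
A sketch of the adapter's behavior, assuming an existing Snowpark session (the connection parameters are placeholders). Snowpark DataFrame arguments are sampled down to pandas; everything else passes through untouched:

```python
import pandas

from snowflake.snowpark import Session
from snowflake.snowpark_checkpoints.snowpark_sampler import SamplingAdapter

session = Session.builder.configs({"account": "<account>"}).create()  # placeholder
df = session.create_dataframe([(1, "a"), (2, "b")], schema=["id", "label"])

adapter = SamplingAdapter(job_context=None, sample_frac=0.5)
adapter.process_args([df, "not-a-dataframe"])

sampled = adapter.get_sampled_pandas_args()
assert isinstance(sampled[0], pandas.DataFrame)   # the Snowpark input was sampled
assert sampled[1] == "not-a-dataframe"            # other arguments pass through as-is

# Out-of-range fractions fail fast: SamplingAdapter(None, sample_frac=1.5) raises
# ValueError. Converting samples back via get_sampled_snowpark_args() or
# get_sampled_spark_args() raises SamplingError when job_context is None.
```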