snowpark-checkpoints-validators 0.2.0rc1__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_checkpoints/__init__.py +44 -0
- snowflake/snowpark_checkpoints/__version__.py +16 -0
- snowflake/snowpark_checkpoints/checkpoint.py +580 -0
- snowflake/snowpark_checkpoints/errors.py +60 -0
- snowflake/snowpark_checkpoints/job_context.py +128 -0
- snowflake/snowpark_checkpoints/singleton.py +23 -0
- snowflake/snowpark_checkpoints/snowpark_sampler.py +124 -0
- snowflake/snowpark_checkpoints/spark_migration.py +255 -0
- snowflake/snowpark_checkpoints/utils/__init__.py +14 -0
- snowflake/snowpark_checkpoints/utils/constants.py +134 -0
- snowflake/snowpark_checkpoints/utils/extra_config.py +89 -0
- snowflake/snowpark_checkpoints/utils/logging_utils.py +67 -0
- snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +399 -0
- snowflake/snowpark_checkpoints/utils/supported_types.py +65 -0
- snowflake/snowpark_checkpoints/utils/telemetry.py +900 -0
- snowflake/snowpark_checkpoints/utils/utils_checks.py +395 -0
- snowflake/snowpark_checkpoints/validation_result_metadata.py +155 -0
- snowflake/snowpark_checkpoints/validation_results.py +49 -0
- snowpark_checkpoints_validators-0.2.1.dist-info/METADATA +323 -0
- snowpark_checkpoints_validators-0.2.1.dist-info/RECORD +22 -0
- snowpark_checkpoints_validators-0.2.0rc1.dist-info/METADATA +0 -514
- snowpark_checkpoints_validators-0.2.0rc1.dist-info/RECORD +0 -4
- {snowpark_checkpoints_validators-0.2.0rc1.dist-info → snowpark_checkpoints_validators-0.2.1.dist-info}/WHEEL +0 -0
- {snowpark_checkpoints_validators-0.2.0rc1.dist-info → snowpark_checkpoints_validators-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,128 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from datetime import datetime
+from typing import Optional
+
+import pandas as pd
+
+from pyspark.sql import SparkSession
+
+from snowflake.snowpark import Session
+from snowflake.snowpark_checkpoints.utils.constants import SCHEMA_EXECUTION_MODE
+
+
+LOGGER = logging.getLogger(__name__)
+RESULTS_TABLE = "SNOWPARK_CHECKPOINTS_REPORT"
+
+
+class SnowparkJobContext:
+
+    """Class used to record migration results in Snowflake.
+
+    Args:
+        snowpark_session: A Snowpark session instance.
+        spark_session: A Spark session instance.
+        job_name: The name of the job.
+        log_results: Whether to log the migration results in Snowflake.
+
+    """
+
+    def __init__(
+        self,
+        snowpark_session: Session,
+        spark_session: SparkSession = None,
+        job_name: Optional[str] = None,
+        log_results: Optional[bool] = True,
+    ):
+        self.log_results = log_results
+        self.job_name = job_name
+        self.spark_session = spark_session or self._create_pyspark_session()
+        self.snowpark_session = snowpark_session
+
+    def _mark_fail(
+        self, message, checkpoint_name, data, execution_mode=SCHEMA_EXECUTION_MODE
+    ):
+        if not self.log_results:
+            LOGGER.warning(
+                (
+                    "Recording of migration results into Snowflake is disabled. "
+                    "Failure result for checkpoint '%s' will not be recorded."
+                ),
+                checkpoint_name,
+            )
+            return
+
+        LOGGER.debug(
+            "Marking failure for checkpoint '%s' in '%s' mode with message '%s'",
+            checkpoint_name,
+            execution_mode,
+            message,
+        )
+
+        session = self.snowpark_session
+        df = pd.DataFrame(
+            {
+                "DATE": [datetime.now()],
+                "JOB": [self.job_name],
+                "STATUS": ["fail"],
+                "CHECKPOINT": [checkpoint_name],
+                "MESSAGE": [message],
+                "DATA": [f"{data}"],
+                "EXECUTION_MODE": [execution_mode],
+            }
+        )
+        report_df = session.createDataFrame(df)
+        LOGGER.info("Writing failure result to table: '%s'", RESULTS_TABLE)
+        report_df.write.mode("append").save_as_table(RESULTS_TABLE)
+
+    def _mark_pass(self, checkpoint_name, execution_mode=SCHEMA_EXECUTION_MODE):
+        if not self.log_results:
+            LOGGER.warning(
+                (
+                    "Recording of migration results into Snowflake is disabled. "
+                    "Pass result for checkpoint '%s' will not be recorded."
+                ),
+                checkpoint_name,
+            )
+            return
+
+        LOGGER.debug(
+            "Marking pass for checkpoint '%s' in '%s' mode",
+            checkpoint_name,
+            execution_mode,
+        )
+
+        session = self.snowpark_session
+        df = pd.DataFrame(
+            {
+                "DATE": [datetime.now()],
+                "JOB": [self.job_name],
+                "STATUS": ["pass"],
+                "CHECKPOINT": [checkpoint_name],
+                "MESSAGE": [""],
+                "DATA": [""],
+                "EXECUTION_MODE": [execution_mode],
+            }
+        )
+        report_df = session.createDataFrame(df)
+        LOGGER.info("Writing pass result to table: '%s'", RESULTS_TABLE)
+        report_df.write.mode("append").save_as_table(RESULTS_TABLE)
+
+    def _create_pyspark_session(self) -> SparkSession:
+        LOGGER.info("Creating a PySpark session")
+        return SparkSession.builder.getOrCreate()
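
The job context above pairs a Snowpark session with a PySpark session (created on demand when none is supplied) and appends one row per checkpoint result to the SNOWPARK_CHECKPOINTS_REPORT table. A minimal usage sketch; the connection settings here are placeholders, not part of the package:

from snowflake.snowpark import Session
from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext

# Hypothetical connection configuration; supply your own account parameters.
snowpark_session = Session.builder.configs({"connection_name": "default"}).create()

# spark_session is omitted, so the context falls back to
# SparkSession.builder.getOrCreate().
job_context = SnowparkJobContext(
    snowpark_session=snowpark_session,
    job_name="migration_smoke_test",
    log_results=True,  # False skips writing to SNOWPARK_CHECKPOINTS_REPORT
)
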
@@ -0,0 +1,23 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class Singleton(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super().__call__(*args, **kwargs)
+        return cls._instances[cls]
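
Singleton is a metaclass: any class declared with metaclass=Singleton hands back the same instance on every construction, and constructor arguments after the first instantiation are ignored. A short illustrative sketch with a hypothetical Config class:

from snowflake.snowpark_checkpoints.singleton import Singleton

class Config(metaclass=Singleton):
    def __init__(self, value: int = 0):
        self.value = value

a = Config(1)
b = Config(2)  # __call__ returns the cached instance; the argument 2 is never used
assert a is b and a.value == 1
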
@@ -0,0 +1,124 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from typing import Optional
+
+import pandas
+
+from snowflake.snowpark import DataFrame as SnowparkDataFrame
+from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+
+
+LOGGER = logging.getLogger(__name__)
+
+
+class SamplingStrategy:
+    RANDOM_SAMPLE = 1
+    LIMIT = 2
+
+
+class SamplingError(Exception):
+    pass
+
+
+class SamplingAdapter:
+    def __init__(
+        self,
+        job_context: Optional[SnowparkJobContext],
+        sample_frac: Optional[float] = None,
+        sample_number: Optional[int] = None,
+        sampling_strategy: SamplingStrategy = SamplingStrategy.RANDOM_SAMPLE,
+    ):
+        self.pandas_sample_args = []
+        self.job_context = job_context
+        if sample_frac and not (0 <= sample_frac <= 1):
+            raise ValueError(
+                f"'sample_size' value {sample_frac} is out of range (0 <= sample_size <= 1)"
+            )
+
+        self.sample_frac = sample_frac
+        self.sample_number = sample_number
+        self.sampling_strategy = sampling_strategy
+
+    def process_args(self, input_args):
+        # create the intermediate pandas
+        # data frame for the test data
+        LOGGER.info("Processing %s input argument(s) for sampling", len(input_args))
+        for arg in input_args:
+            if isinstance(arg, SnowparkDataFrame):
+                df_count = arg.count()
+                if df_count == 0:
+                    raise SamplingError(
+                        "Input DataFrame is empty. Cannot sample from an empty DataFrame."
+                    )
+
+                LOGGER.info("Sampling a Snowpark DataFrame with %s rows", df_count)
+                if self.sampling_strategy == SamplingStrategy.RANDOM_SAMPLE:
+                    if self.sample_frac:
+                        LOGGER.info(
+                            "Applying random sampling with fraction %s",
+                            self.sample_frac,
+                        )
+                        df_sample = arg.sample(frac=self.sample_frac).to_pandas()
+                    else:
+                        LOGGER.info(
+                            "Applying random sampling with size %s", self.sample_number
+                        )
+                        df_sample = arg.sample(n=self.sample_number).to_pandas()
+                else:
+                    LOGGER.info(
+                        "Applying limit sampling with size %s", self.sample_number
+                    )
+                    df_sample = arg.limit(self.sample_number).to_pandas()
+
+                LOGGER.info(
+                    "Successfully sampled the DataFrame. Resulting DataFrame shape: %s",
+                    df_sample.shape,
+                )
+                self.pandas_sample_args.append(df_sample)
+            else:
+                LOGGER.debug(
+                    "Argument is not a Snowpark DataFrame. No sampling is applied."
+                )
+                self.pandas_sample_args.append(arg)
+
+    def get_sampled_pandas_args(self):
+        return self.pandas_sample_args
+
+    def get_sampled_snowpark_args(self):
+        if self.job_context is None:
+            raise SamplingError("Need a job context to compare with Spark")
+        snowpark_sample_args = []
+        for arg in self.pandas_sample_args:
+            if isinstance(arg, pandas.DataFrame):
+                snowpark_df = self.job_context.snowpark_session.create_dataframe(arg)
+                snowpark_sample_args.append(snowpark_df)
+            else:
+                snowpark_sample_args.append(arg)
+        return snowpark_sample_args
+
+    def get_sampled_spark_args(self):
+        if self.job_context is None:
+            raise SamplingError("Need a job context to compare with Spark")
+        pyspark_sample_args = []
+        for arg in self.pandas_sample_args:
+            if isinstance(arg, pandas.DataFrame):
+                pyspark_df = self.job_context.spark_session.createDataFrame(arg)
+                pyspark_sample_args.append(pyspark_df)
+            else:
+                pyspark_sample_args.append(arg)
+        return pyspark_sample_args
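
SamplingAdapter walks a tuple of positional arguments, downsamples every Snowpark DataFrame to pandas (random fraction, random row count, or LIMIT), and passes everything else through untouched; the cached pandas samples can then be rehydrated as Snowpark or PySpark DataFrames for a side-by-side run. A hedged sketch of that flow, where df and job_context are assumed to exist:

sampler = SamplingAdapter(
    job_context,
    sample_number=100,
    sampling_strategy=SamplingStrategy.RANDOM_SAMPLE,
)
sampler.process_args((df, "a_literal_arg"))  # only df is sampled

snowpark_args = sampler.get_sampled_snowpark_args()  # [Snowpark DataFrame, "a_literal_arg"]
spark_args = sampler.get_sampled_spark_args()        # [PySpark DataFrame, "a_literal_arg"]
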
@@ -0,0 +1,255 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from typing import Callable, Optional, TypeVar
+
+import pandas as pd
+
+from pyspark.sql import DataFrame as SparkDataFrame
+
+from snowflake.snowpark import DataFrame as SnowparkDataFrame
+from snowflake.snowpark.types import PandasDataFrame
+from snowflake.snowpark_checkpoints.errors import SparkMigrationError
+from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+from snowflake.snowpark_checkpoints.snowpark_sampler import (
+    SamplingAdapter,
+    SamplingStrategy,
+)
+from snowflake.snowpark_checkpoints.utils.constants import FAIL_STATUS, PASS_STATUS
+from snowflake.snowpark_checkpoints.utils.logging_utils import log
+from snowflake.snowpark_checkpoints.utils.telemetry import STATUS_KEY, report_telemetry
+from snowflake.snowpark_checkpoints.utils.utils_checks import (
+    _replace_special_characters,
+    _update_validation_result,
+)
+
+
+fn = TypeVar("F", bound=Callable)
+LOGGER = logging.getLogger(__name__)
+
+
+@log
+def check_with_spark(
+    job_context: Optional[SnowparkJobContext],
+    spark_function: fn,
+    checkpoint_name: str,
+    sample_number: Optional[int] = 100,
+    sampling_strategy: Optional[SamplingStrategy] = SamplingStrategy.RANDOM_SAMPLE,
+    output_path: Optional[str] = None,
+) -> Callable[[fn], fn]:
+    """Validate function output with Spark instance.
+
+    Will take the input snowpark dataframe of this function, sample data, convert
+    it to a Spark dataframe and then execute `spark_function`. Subsequently
+    the output of that function will be compared to the output of this function
+    for the same sample of data.
+
+    Args:
+        job_context (SnowparkJobContext): The job context containing configuration and details for the validation.
+        spark_function (fn): The equivalent PySpark function to compare against the Snowpark implementation.
+        checkpoint_name (str): A name for the checkpoint. Defaults to None.
+        sample_number (Optional[int], optional): The number of rows for validation. Defaults to 100.
+        sampling_strategy (Optional[SamplingStrategy], optional): The strategy used for sampling data.
+            Defaults to SamplingStrategy.RANDOM_SAMPLE.
+        output_path (Optional[str], optional): The path to store the validation results. Defaults to None.
+
+    Returns:
+        Callable[[fn], fn]: A decorator that wraps the original Snowpark function with validation logic.
+
+    """
+
+    def check_with_spark_decorator(snowpark_fn):
+        @log(log_args=False)
+        def wrapper(*args, **kwargs):
+            LOGGER.info(
+                "Starting output validation between Snowpark function '%s' and Spark function '%s'",
+                snowpark_fn.__name__,
+                spark_function.__name__,
+            )
+            _checkpoint_name = checkpoint_name
+            if checkpoint_name is None:
+                LOGGER.warning(
+                    "No checkpoint name provided. Using '%s' as the checkpoint name",
+                    snowpark_fn.__name__,
+                )
+                _checkpoint_name = snowpark_fn.__name__
+            _checkpoint_name = _replace_special_characters(_checkpoint_name)
+
+            sampler = SamplingAdapter(
+                job_context,
+                sample_number=sample_number,
+                sampling_strategy=sampling_strategy,
+            )
+            sampler.process_args(args)
+            snowpark_sample_args = sampler.get_sampled_snowpark_args()
+            pyspark_sample_args = sampler.get_sampled_spark_args()
+
+            # Run the sampled data in snowpark
+            LOGGER.info("Running the Snowpark function with sampled args")
+            snowpark_test_results = snowpark_fn(*snowpark_sample_args, **kwargs)
+            LOGGER.info("Running the Spark function with sampled args")
+            spark_test_results = spark_function(*pyspark_sample_args, **kwargs)
+
+            LOGGER.info("Comparing the results of the Snowpark and Spark functions")
+            result, exception = _assert_return(
+                snowpark_test_results,
+                spark_test_results,
+                job_context,
+                _checkpoint_name,
+                output_path,
+            )
+            if not result:
+                LOGGER.error(
+                    "Validation failed. The results of the Snowpark function '%s' and Spark function '%s' do not match",
+                    snowpark_fn.__name__,
+                    spark_function.__name__,
+                )
+                raise exception from None
+            LOGGER.info(
+                "Validation passed. The results of the Snowpark function '%s' and Spark function '%s' match",
+                snowpark_fn.__name__,
+                spark_function.__name__,
+            )
+
+            # Run the original function in snowpark
+            return snowpark_fn(*args, **kwargs)
+
+        return wrapper
+
+    return check_with_spark_decorator
+
+
+@report_telemetry(
+    params_list=["snowpark_results", "spark_results"],
+    return_indexes=[(STATUS_KEY, 0)],
+    multiple_return=True,
+)
+def _assert_return(
+    snowpark_results, spark_results, job_context, checkpoint_name, output_path=None
+) -> tuple[bool, Optional[Exception]]:
+    """Assert and validate the results from Snowpark and Spark transformations.
+
+    Args:
+        snowpark_results (Any): Results from the Snowpark transformation.
+        spark_results (Any): Results from the Spark transformation to compare against.
+        job_context (Any): Additional context about the job. Defaults to None.
+        checkpoint_name (Any): Name of the checkpoint for logging. Defaults to None.
+        output_path (Optional[str], optional): The path to store the validation results. Defaults to None.
+
+    Raises:
+        AssertionError: If the Snowpark and Spark results do not match.
+        TypeError: If the results cannot be compared.
+
+    """
+    if isinstance(snowpark_results, SnowparkDataFrame) and isinstance(
+        spark_results, SparkDataFrame
+    ):
+        LOGGER.debug("Comparing two DataFrame results for equality")
+        cmp = compare_spark_snowpark_dfs(spark_results, snowpark_results)
+
+        if not cmp.empty:
+            exception_result = SparkMigrationError(
+                "DataFrame difference:\n", job_context, checkpoint_name, cmp
+            )
+            return False, exception_result
+        job_context._mark_pass(checkpoint_name)
+        _update_validation_result(checkpoint_name, PASS_STATUS, output_path)
+        return True, None
+    else:
+        LOGGER.debug("Comparing two scalar results for equality")
+        if snowpark_results != spark_results:
+            exception_result = SparkMigrationError(
+                "Return value difference:\n",
+                job_context,
+                checkpoint_name,
+                f"{snowpark_results} != {spark_results}",
+            )
+            _update_validation_result(checkpoint_name, FAIL_STATUS, output_path)
+            return False, exception_result
+        job_context._mark_pass(checkpoint_name)
+        _update_validation_result(checkpoint_name, PASS_STATUS, output_path)
+        return True, None
+
+
+def compare_spark_snowpark_dfs(
+    spark_df: SparkDataFrame, snowpark_df: SnowparkDataFrame
+) -> PandasDataFrame:
+    """Compare two dataframes for equality.
+
+    Args:
+        spark_df (SparkDataFrame): The Spark dataframe to compare.
+        snowpark_df (SnowparkDataFrame): The Snowpark dataframe to compare.
+
+    Returns: Pandas DataFrame containing the differences between the two dataframes.
+
+    """
+    snowpark_df = snowpark_df.to_pandas()
+    snowpark_df.columns = snowpark_df.columns.str.upper()
+    spark_df = spark_df.toPandas()
+    spark_df.columns = spark_df.columns.str.upper()
+    spark_cols = set(spark_df.columns)
+    snowpark_cols = set(snowpark_df.columns)
+    cmp = pd.DataFrame([])
+    left = spark_cols - snowpark_cols
+    right = snowpark_cols - spark_cols
+    if left != set():
+        cmp = _compare_dfs(spark_df, snowpark_df, "spark", "snowpark")
+    if right != set():
+        right_cmp = _compare_dfs(snowpark_df, spark_df, "snowpark", "spark")
+        cmp = right_cmp if cmp.empty else pd.concat([cmp, right_cmp], ignore_index=True)
+    if left == set() and right == set():
+        if spark_df.shape == snowpark_df.shape:
+            cmp = spark_df.compare(snowpark_df, result_names=("spark", "snowpark"))
+        else:
+            cmp = spark_df.merge(snowpark_df, indicator=True, how="outer").loc[
+                lambda x: x["_merge"] != "both"
+            ]
+            cmp = cmp.replace(
+                {"left_only": "spark_only", "right_only": "snowpark_only"}
+            )
+
+    return cmp
+
+
+def _compare_dfs(
+    df_a: pd.DataFrame, df_b: pd.DataFrame, left_label: str, right_label: str
+) -> PandasDataFrame:
+    """Compare two dataframes for equality.
+
+    Args:
+        df_a (PandasDataFrame): The first dataframe to compare.
+        df_b (PandasDataFrame): The second dataframe to compare.
+        left_label (str): The label for the first dataframe.
+        right_label (str): The label for the second dataframe.
+
+    :return: Pandas DataFrame containing the differences between the two dataframes.
+
+    """
+    df_a["side"] = "a"
+    df_b["side"] = "b"
+    a_only = [col for col in df_a.columns if col not in df_b.columns] + ["side"]
+    b_only = [col for col in df_b.columns if col not in df_a.columns] + ["side"]
+    cmp = (
+        df_a[a_only]
+        .merge(df_b[b_only], indicator=True, how="left")
+        .loc[lambda x: x["_merge"] != "both"]
+    )
+    cmp = cmp.replace(
+        {"left_only": f"{left_label}_only", "right_only": f"{right_label}_only"}
+    )
+    cmp = cmp.drop(columns="side")
+    return cmp
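
check_with_spark is the user-facing entry point of this module: it decorates a Snowpark function, runs it and its PySpark counterpart on the same sample, and raises SparkMigrationError on any mismatch before executing the original function on the full input. A hedged usage sketch; the two functions and the column names are hypothetical, and job_context and snowpark_df are assumed to exist:

import snowflake.snowpark.functions as snow_funcs
import pyspark.sql.functions as spark_funcs
from snowflake.snowpark_checkpoints.spark_migration import check_with_spark

def original_spark_fn(df):
    # Existing PySpark logic being migrated.
    return df.withColumn("FEE", spark_funcs.col("AMOUNT") * 0.1)

@check_with_spark(
    job_context=job_context,
    spark_function=original_spark_fn,
    checkpoint_name="fee_computation",
    sample_number=100,
)
def new_snowpark_fn(df):
    # Candidate Snowpark rewrite, validated against original_spark_fn.
    return df.with_column("FEE", snow_funcs.col("AMOUNT") * 0.1)

result_df = new_snowpark_fn(snowpark_df)  # raises SparkMigrationError on divergence
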
@@ -0,0 +1,14 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,134 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Skip type
+from enum import IntEnum
+from typing import Final
+
+
+class CheckpointMode(IntEnum):
+
+    """Enum class representing the validation mode."""
+
+    SCHEMA = 1
+    """Validate against a schema file"""
+    DATAFRAME = 2
+    """Validate against a dataframe"""
+
+
+# Execution mode
+SCHEMA_EXECUTION_MODE: Final[str] = "Schema"
+DATAFRAME_EXECUTION_MODE: Final[str] = "Dataframe"
+
+
+# File position on stack
+STACK_POSITION_CHECKPOINT: Final[int] = 6
+
+# Validation status
+PASS_STATUS: Final[str] = "PASS"
+FAIL_STATUS: Final[str] = "FAIL"
+
+# Validation result keys
+DEFAULT_KEY: Final[str] = "default"
+
+
+# Skip type
+SKIP_ALL: Final[str] = "skip_all"
+
+# Supported types
+BOOLEAN_TYPE: Final[str] = "boolean"
+BINARY_TYPE: Final[str] = "binary"
+BYTE_TYPE: Final[str] = "byte"
+CHAR_TYPE: Final[str] = "char"
+DATE_TYPE: Final[str] = "date"
+DAYTIMEINTERVAL_TYPE: Final[str] = "daytimeinterval"
+DECIMAL_TYPE: Final[str] = "decimal"
+DOUBLE_TYPE: Final[str] = "double"
+FLOAT_TYPE: Final[str] = "float"
+INTEGER_TYPE: Final[str] = "integer"
+LONG_TYPE: Final[str] = "long"
+SHORT_TYPE: Final[str] = "short"
+STRING_TYPE: Final[str] = "string"
+TIMESTAMP_TYPE: Final[str] = "timestamp"
+TIMESTAMP_NTZ_TYPE: Final[str] = "timestamp_ntz"
+VARCHAR_TYPE: Final[str] = "varchar"
+
+# Pandas data types
+PANDAS_BOOLEAN_DTYPE: Final[str] = "bool"
+PANDAS_DATETIME_DTYPE: Final[str] = "datetime64[ns]"
+PANDAS_FLOAT_DTYPE: Final[str] = "float64"
+PANDAS_INTEGER_DTYPE: Final[str] = "int64"
+PANDAS_OBJECT_DTYPE: Final[str] = "object"
+PANDAS_TIMEDELTA_DTYPE: Final[str] = "timedelta64[ns]"
+
+# Schemas keys
+COLUMNS_KEY: Final[str] = "columns"
+COUNT_KEY: Final[str] = "rows_count"
+DECIMAL_PRECISION_KEY: Final[str] = "decimal_precision"
+FALSE_COUNT_KEY: Final[str] = "false_count"
+FORMAT_KEY: Final[str] = "format"
+NAME_KEY: Final[str] = "name"
+MARGIN_ERROR_KEY: Final[str] = "margin_error"
+MAX_KEY: Final[str] = "max"
+MEAN_KEY: Final[str] = "mean"
+MIN_KEY: Final[str] = "min"
+NULL_COUNT_KEY: Final[str] = "rows_null_count"
+NULLABLE_KEY: Final[str] = "nullable"
+ROWS_NOT_NULL_COUNT_KEY: Final[str] = "rows_not_null_count"
+TRUE_COUNT_KEY: Final[str] = "true_count"
+TYPE_KEY: Final[str] = "type"
+ROWS_COUNT_KEY: Final[str] = "rows_count"
+FORMAT_KEY: Final[str] = "format"
+
+DATAFRAME_CUSTOM_DATA_KEY: Final[str] = "custom_data"
+DATAFRAME_PANDERA_SCHEMA_KEY: Final[str] = "pandera_schema"
+
+# Default values
+DEFAULT_DATE_FORMAT: Final[str] = "%Y-%m-%d"
+
+# SQL Column names
+TABLE_NAME_COL: Final[str] = "TABLE_NAME"
+CREATED_COL: Final[str] = "CREATED"
+
+# SQL Table names
+INFORMATION_SCHEMA_TABLE_NAME: Final[str] = "INFORMATION_SCHEMA"
+TABLES_TABLE_NAME: Final[str] = "TABLES"
+
+# SQL Query
+EXCEPT_HASH_AGG_QUERY: Final[
+    str
+] = "SELECT HASH_AGG(*) FROM IDENTIFIER(:1) EXCEPT SELECT HASH_AGG(*) FROM IDENTIFIER(:2)"
+
+# Table checkpoints name
+CHECKPOINT_TABLE_NAME_FORMAT: Final[str] = "{}_snowpark"
+
+# Write mode
+OVERWRITE_MODE: Final[str] = "overwrite"
+
+# Validation modes
+VALIDATION_MODE_KEY: Final[str] = "validation_mode"
+PIPELINES_KEY: Final[str] = "pipelines"
+
+# File name
+CHECKPOINT_JSON_OUTPUT_FILE_FORMAT_NAME: Final[str] = "{}.json"
+CHECKPOINTS_JSON_FILE_NAME: Final[str] = "checkpoints.json"
+SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME: Final[str] = "snowpark-checkpoints-output"
+CHECKPOINT_PARQUET_OUTPUT_FILE_FORMAT_NAME: Final[str] = "{}.parquet"
+VALIDATION_RESULTS_JSON_FILE_NAME: Final[str] = "checkpoint_validation_results.json"
+
+# Environment variables
+SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR: Final[
+    str
+] = "SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH"