snowpark-checkpoints-validators 0.3.2__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/PKG-INFO +1 -1
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/pyproject.toml +1 -1
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/__version__.py +1 -1
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/checkpoint.py +24 -4
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/snowpark_sampler.py +104 -3
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/constants.py +14 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/utils_checks.py +87 -9
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_df_check_custom_check_telemetry.json +1 -1
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_df_check_from_file_telemetry.json +1 -1
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_df_check_skip_check_telemetry.json +1 -1
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_df_check_telemetry.json +1 -1
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_input_telemetry.json +1 -1
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_output_telemetry.json +1 -1
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/test_pandera.py +7 -7
- snowpark_checkpoints_validators-0.4.0/test/unit/test_pandera_validations.py +130 -0
- snowpark_checkpoints_validators-0.4.0/test/unit/test_snowpark_sampler.py +117 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_utils_checks.py +18 -11
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_validation_result_metadata.py +3 -3
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/.gitignore +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/CHANGELOG.md +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/LICENSE +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/README.md +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/__init__.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/errors.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/io_utils/__init__.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/io_utils/io_default_strategy.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/io_utils/io_env_strategy.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/io_utils/io_file_manager.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/job_context.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/singleton.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/spark_migration.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/__init__.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/extra_config.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/logging_utils.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/supported_types.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/telemetry.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/validation_result_metadata.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/validation_results.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/.coveragerc +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/e2eexample.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_compare_utils.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/df_mode_dataframe_mismatch_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/df_mode_dataframe_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/spark_checkpoint_df_fail_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/spark_checkpoint_df_pass_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/spark_checkpoint_limit_sample_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/spark_checkpoint_random_sample_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/spark_checkpoint_scalar_fail_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/spark_checkpoint_scalar_passing_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_df_check_fail_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_input_fail_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_output_fail_telemetry.json +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/test_parquet.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/test_spark_checkpoint.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/io_utils/test_default_strategy.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_checkpoints.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_extra_config.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_job_context.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_logger.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_logging_utils.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_pandera_check_manager.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_spark_migration.py +0 -0
- {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_telemetry.py +0 -0
{snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: snowpark-checkpoints-validators
-Version: 0.3.2
+Version: 0.4.0
 Summary: Migration tools for Snowpark
 Project-URL: Bug Tracker, https://github.com/snowflakedb/snowpark-checkpoints/issues
 Project-URL: Source code, https://github.com/snowflakedb/snowpark-checkpoints/
{snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/pyproject.toml
RENAMED
@@ -127,7 +127,7 @@ exclude_lines = [

 [tool.hatch.envs.linter.scripts]
 check = [
-    'ruff check --fix .',
+    "echo 'Running linting checks...' && ruff check --config=../ruff.toml --statistics --verbose . || (echo '❌ LINTING FAILED: Please fix the above linting issues before proceeding. Use \"ruff check --config=../ruff.toml --fix .\" to auto-fix some issues, or fix them manually.' && exit 1)",
 ]

 [tool.hatch.envs.test.scripts]
{snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/checkpoint.py
RENAMED
@@ -332,7 +332,7 @@ def _check_dataframe_schema(
     pandera_schema_upper, sample_df = _process_sampling(
         df, pandera_schema, job_context, sample_frac, sample_number, sampling_strategy
     )
-    is_valid, validation_result =
+    is_valid, validation_result = validate(pandera_schema_upper, sample_df)
     if is_valid:
         LOGGER.info(
             "DataFrame schema validation passed for checkpoint '%s'",
@@ -438,7 +438,7 @@ def check_output_schema(
             sampler.process_args([snowpark_results])
             pandas_sample_args = sampler.get_sampled_pandas_args()

-            is_valid, validation_result =
+            is_valid, validation_result = validate(
                 pandera_schema, pandas_sample_args[0]
             )
             if is_valid:
@@ -565,7 +565,7 @@ def check_input_schema(
                     )
                     continue

-                is_valid, validation_result =
+                is_valid, validation_result = validate(
                     pandera_schema,
                     arg,
                 )
@@ -606,11 +606,31 @@ def check_input_schema(
     return check_input_with_decorator


-def
+def validate(
     schema: Union[type[DataFrameModel], DataFrameSchema],
     df: PandasDataFrame,
     lazy: bool = True,
 ) -> tuple[bool, PandasDataFrame]:
+    """Validate a Pandas DataFrame against a given Pandera schema.
+
+    Args:
+        schema (Union[type[DataFrameModel], DataFrameSchema]):
+            The schema to validate against. Can be a Pandera `DataFrameSchema` or
+            a `DataFrameModel` class.
+        df (PandasDataFrame):
+            The Pandas DataFrame to be validated.
+        lazy (bool, optional):
+            If `True`, collect all validation errors before raising an exception.
+            If `False`, raise an exception as soon as the first error is encountered.
+            Defaults to `True`.
+
+    Returns:
+        tuple[bool, PandasDataFrame]:
+            A tuple containing:
+            - A boolean indicating whether validation passed.
+            - The validated DataFrame if successful, or the failure cases DataFrame if not.
+
+    """
     if not isinstance(schema, DataFrameSchema):
         schema = schema.to_schema()
     is_valid = True
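The helper behind the schema checks is now a public `validate` function, so the same Pandera validation used by the decorators can be invoked directly. A minimal sketch of a direct call (the schema and data below are illustrative, not taken from the package):

```python
import pandas as pd
from pandera import Check, Column, DataFrameSchema

from snowflake.snowpark_checkpoints.checkpoint import validate

# Hypothetical schema: a single positive-integer column.
schema = DataFrameSchema({"A": Column(int, Check.greater_than(0))})
df = pd.DataFrame({"A": [1, 2, 3]})

# Returns (True, validated_df) on success; with lazy=True (the default) a
# failing frame yields (False, failure_cases_df) instead of raising early.
is_valid, result = validate(schema, df)
assert is_valid
```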
{snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/snowpark_sampler.py
RENAMED
@@ -20,7 +20,21 @@ from typing import Optional
 import pandas

 from snowflake.snowpark import DataFrame as SnowparkDataFrame
+from snowflake.snowpark.types import (
+    BinaryType,
+    BooleanType,
+    DateType,
+    FloatType,
+    StringType,
+    TimestampType,
+)
 from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+from snowflake.snowpark_checkpoints.utils.constants import (
+    INTEGER_TYPE_COLLECTION,
+    PANDAS_FLOAT_TYPE,
+    PANDAS_LONG_TYPE,
+    PANDAS_STRING_TYPE,
+)


 LOGGER = logging.getLogger(__name__)
@@ -73,17 +87,17 @@ class SamplingAdapter:
                         "Applying random sampling with fraction %s",
                         self.sample_frac,
                     )
-                    df_sample = arg.sample(frac=self.sample_frac)
+                    df_sample = to_pandas(arg.sample(frac=self.sample_frac))
                 else:
                     LOGGER.info(
                         "Applying random sampling with size %s", self.sample_number
                     )
-                    df_sample = arg.sample(n=self.sample_number)
+                    df_sample = to_pandas(arg.sample(n=self.sample_number))
             else:
                 LOGGER.info(
                     "Applying limit sampling with size %s", self.sample_number
                 )
-                df_sample = arg.limit(self.sample_number)
+                df_sample = to_pandas(arg.limit(self.sample_number))

             LOGGER.info(
                 "Successfully sampled the DataFrame. Resulting DataFrame shape: %s",
@@ -122,3 +136,90 @@ class SamplingAdapter:
             else:
                 pyspark_sample_args.append(arg)
         return pyspark_sample_args
+
+
+def to_pandas(sampled_df: SnowparkDataFrame) -> pandas.DataFrame:
+    """Convert a Snowpark DataFrame to a Pandas DataFrame, handling missing values and type conversions."""
+    LOGGER.debug("Converting Snowpark DataFrame to Pandas DataFrame")
+    pandas_df = sampled_df.toPandas()
+    for field in sampled_df.schema.fields:
+        is_snowpark_integer = field.datatype.typeName() in INTEGER_TYPE_COLLECTION
+        is_snowpark_string = isinstance(field.datatype, StringType)
+        is_snowpark_binary = isinstance(field.datatype, BinaryType)
+        is_snowpark_timestamp = isinstance(field.datatype, TimestampType)
+        is_snowpark_float = isinstance(field.datatype, FloatType)
+        is_snowpark_boolean = isinstance(field.datatype, BooleanType)
+        is_snowpark_date = isinstance(field.datatype, DateType)
+        if is_snowpark_integer:
+            LOGGER.debug(
+                "Converting Spark integer column '%s' to Pandas nullable '%s' type",
+                field.name,
+                PANDAS_LONG_TYPE,
+            )
+            pandas_df[field.name] = (
+                pandas_df[field.name].astype(PANDAS_LONG_TYPE).fillna(0)
+            )
+        elif is_snowpark_string or is_snowpark_binary:
+            LOGGER.debug(
+                "Converting Spark string column '%s' to Pandas nullable '%s' type",
+                field.name,
+                PANDAS_STRING_TYPE,
+            )
+            pandas_df[field.name] = (
+                pandas_df[field.name].astype(PANDAS_STRING_TYPE).fillna("")
+            )
+        elif is_snowpark_timestamp:
+            LOGGER.debug(
+                "Converting Spark timestamp column '%s' to UTC naive Pandas datetime",
+                field.name,
+            )
+            pandas_df[field.name] = convert_all_to_utc_naive(
+                pandas_df[field.name]
+            ).fillna(pandas.NaT)
+        elif is_snowpark_float:
+            LOGGER.debug(
+                "Converting Spark float column '%s' to Pandas nullable float",
+                field.name,
+            )
+            pandas_df[field.name] = (
+                pandas_df[field.name].astype(PANDAS_FLOAT_TYPE).fillna(0.0)
+            )
+        elif is_snowpark_boolean:
+            LOGGER.debug(
+                "Converting Spark boolean column '%s' to Pandas nullable boolean",
+                field.name,
+            )
+            pandas_df[field.name] = (
+                pandas_df[field.name].astype("boolean").fillna(False)
+            )
+        elif is_snowpark_date:
+            LOGGER.debug(
+                "Converting Spark date column '%s' to Pandas nullable datetime",
+                field.name,
+            )
+            pandas_df[field.name] = pandas_df[field.name].fillna(pandas.NaT)
+
+    return pandas_df
+
+
+def convert_all_to_utc_naive(series: pandas.Series) -> pandas.Series:
+    """Convert all timezone-aware or naive timestamps in a series to UTC naive.
+
+    Naive timestamps are assumed to be in UTC and localized accordingly.
+    Timezone-aware timestamps are converted to UTC and then made naive.
+
+    Args:
+        series (pandas.Series): A Pandas Series of `pd.Timestamp` objects,
+            either naive or timezone-aware.
+
+    Returns:
+        pandas.Series: A Series of UTC-normalized naive timestamps (`tzinfo=None`).
+
+    """
+
+    def convert(ts):
+        if ts.tz is None:
+            ts = ts.tz_localize("UTC")
+        return ts.tz_convert("UTC").tz_localize(None)
+
+    return series.apply(convert)
{snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/constants.py
RENAMED
@@ -133,3 +133,17 @@ VALIDATION_RESULTS_JSON_FILE_NAME: Final[str] = "checkpoint_validation_results.j
 SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR: Final[
     str
 ] = "SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH"
+
+BYTE_COLUMN_TYPE = "byte"
+INTEGER_COLUMN_TYPE = "integer"
+LONG_COLUMN_TYPE = "long"
+SHORT_COLUMN_TYPE = "short"
+PANDAS_FLOAT_TYPE = "float64"
+PANDAS_LONG_TYPE = "Int64"
+PANDAS_STRING_TYPE = "string"
+INTEGER_TYPE_COLLECTION = [
+    BYTE_COLUMN_TYPE,
+    INTEGER_COLUMN_TYPE,
+    LONG_COLUMN_TYPE,
+    SHORT_COLUMN_TYPE,
+]
{snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/utils_checks.py
RENAMED
@@ -27,6 +27,9 @@ import numpy as np
 from pandera import DataFrameSchema

 from snowflake.snowpark import DataFrame as SnowparkDataFrame
+from snowflake.snowpark import Session
+from snowflake.snowpark.functions import col, expr
+from snowflake.snowpark.types import TimestampType
 from snowflake.snowpark_checkpoints.errors import SchemaValidationError
 from snowflake.snowpark_checkpoints.io_utils.io_file_manager import get_io_file_manager
 from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
@@ -42,7 +45,6 @@ from snowflake.snowpark_checkpoints.utils.constants import (
     DATAFRAME_EXECUTION_MODE,
     DATAFRAME_PANDERA_SCHEMA_KEY,
     DEFAULT_KEY,
-    EXCEPT_HASH_AGG_QUERY,
     FAIL_STATUS,
     PASS_STATUS,
     SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME,
@@ -120,14 +122,14 @@ def _process_sampling(
     pandera_schema_upper = pandera_schema
     new_columns: dict[Any, Any] = {}

-    for
-    new_columns[
+    for column in pandera_schema.columns:
+        new_columns[column.upper()] = pandera_schema.columns[column]

     pandera_schema_upper = pandera_schema_upper.remove_columns(pandera_schema.columns)
     pandera_schema_upper = pandera_schema_upper.add_columns(new_columns)

     sample_df = sampler.get_sampled_pandas_args()[0]
-    sample_df.index = np.ones(sample_df.count().iloc[0])
+    sample_df.index = np.ones(sample_df.count().iloc[0], dtype=int)

     return pandera_schema_upper, sample_df
@@ -191,6 +193,7 @@ Please run the Snowpark checkpoint collector first."""
     schema_dict = checkpoint_schema_config.get(DATAFRAME_PANDERA_SCHEMA_KEY)
     schema_dict_str = json.dumps(schema_dict)
     schema = DataFrameSchema.from_json(schema_dict_str)
+    schema.coerce = False  # Disable coercion to ensure strict validation

     if DATAFRAME_CUSTOM_DATA_KEY not in checkpoint_schema_config:
         LOGGER.info(
@@ -270,6 +273,7 @@ def _compare_data(
         SchemaValidationError: If there is a data mismatch between the DataFrame and the checkpoint table.

     """
+    df = convert_timestamps_to_utc_date(df)
     new_table_name = CHECKPOINT_TABLE_NAME_FORMAT.format(checkpoint_name)
     LOGGER.info(
         "Writing Snowpark DataFrame to table: '%s' for checkpoint: '%s'",
@@ -283,12 +287,12 @@ def _compare_data(
         new_table_name,
         checkpoint_name,
     )
-    expect_df = job_context.snowpark_session.sql(
-        EXCEPT_HASH_AGG_QUERY, [checkpoint_name, new_table_name]
-    )

-
-
+    session = job_context.snowpark_session
+    result = get_comparison_differences(session, checkpoint_name, new_table_name)
+    has_failed = result.get("spark_only_rows") or result.get("snowpark_only_rows")
+    if has_failed or result.get("error"):
+        error_message = f"Data mismatch for checkpoint {checkpoint_name}: {result}"
         job_context._mark_fail(
             error_message,
             checkpoint_name,
@@ -312,6 +316,80 @@ def _compare_data(
     return True, None


+def get_comparison_differences(
+    session: Session, spark_table: str, snowpark_table: str
+) -> dict:
+    """Compare two tables and return the differences."""
+    try:
+        spark_raw_schema = session.table(spark_table).schema.names
+        snowpark_raw_schema = session.table(snowpark_table).schema.names
+
+        spark_normalized = {
+            col_name.strip('"').upper(): col_name for col_name in spark_raw_schema
+        }
+        snowpark_normalized = {
+            col_name.strip('"').upper(): col_name for col_name in snowpark_raw_schema
+        }
+
+        common_cols = sorted(
+            list(
+                set(spark_normalized.keys()).intersection(
+                    set(snowpark_normalized.keys())
+                )
+            )
+        )
+
+        if not common_cols:
+            return {
+                "error": f"No common columns found between {spark_table} and {snowpark_table}",
+            }
+
+        cols_for_spark_selection = [
+            spark_normalized[norm_col_name] for norm_col_name in common_cols
+        ]
+        cols_for_snowpark_selection = [
+            snowpark_normalized[norm_col_name] for norm_col_name in common_cols
+        ]
+
+        spark_ordered = session.table(spark_table).select(
+            *[col(c) for c in cols_for_spark_selection]
+        )
+        snowpark_ordered = session.table(snowpark_table).select(
+            *[col(c) for c in cols_for_snowpark_selection]
+        )
+
+        spark_leftovers = spark_ordered.except_(snowpark_ordered).collect()
+        snowpark_leftovers = snowpark_ordered.except_(spark_ordered).collect()
+
+        spark_only_rows = [row.asDict() for row in spark_leftovers]
+        snowpark_only_rows = [row.asDict() for row in snowpark_leftovers]
+
+        return {
+            "spark_only_rows": spark_only_rows,
+            "snowpark_only_rows": snowpark_only_rows,
+        }
+
+    except Exception as e:
+        return {"error": f"An error occurred: {str(e)}"}
+
+
+def convert_timestamps_to_utc_date(df):
+    """Convert and normalize all Snowpark timestamp columns to UTC.
+
+    This function ensures timestamps are consistent across environments for reliable comparison.
+    """
+    new_cols = []
+    for field in df.schema.fields:
+        if isinstance(field.datatype, TimestampType):
+            utc_midnight_ts = expr(
+                f"convert_timezone('UTC', cast(to_date({field.name}) as timestamp_tz))"
+            ).alias(field.name)
+            new_cols.append(utc_midnight_ts)
+        else:
+            new_cols.append(col(field.name))
+    return df.select(new_cols)
+
+
 def _find_frame_in(stack: list[inspect.FrameInfo]) -> tuple:
     """Find a specific frame in the provided stack trace.

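The hash-aggregate SQL comparison is replaced by a symmetric set difference over the common columns (`except_` in both directions), so a mismatch now reports exactly which rows exist on only one side. A sketch of how `_compare_data` interprets the returned dictionary (the row values below are illustrative):

```python
# Shape of the dict returned by get_comparison_differences (illustrative values).
result = {
    "spark_only_rows": [{"COLUMN1": 5}],  # rows present only in the Spark-side table
    "snowpark_only_rows": [],             # rows present only in the Snowpark-side table
}

# Any one-sided rows, or an "error" key, count as a data mismatch.
has_failed = result.get("spark_only_rows") or result.get("snowpark_only_rows")
if has_failed or result.get("error"):
    print("Data mismatch:", result)
```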
{snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_df_check_custom_check_telemetry.json
RENAMED
@@ -1,6 +1,6 @@
 {
   "message": {
-    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"
+    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"int64\", \"float64\"]}",
     "driver_type": "PythonConnector",
     "driver_version": "3.12.4",
     "event_name": "DataFrame_Validator_Schema",
{snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_df_check_from_file_telemetry.json
RENAMED
@@ -1,6 +1,6 @@
 {
   "message": {
-    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"
+    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"int64\", \"float64\"]}",
     "driver_type": "PythonConnector",
     "driver_version": "3.12.4",
     "event_name": "DataFrame_Validator_Schema",
{snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_df_check_skip_check_telemetry.json
RENAMED
@@ -1,6 +1,6 @@
 {
   "message": {
-    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"
+    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"int64\", \"float64\"]}",
     "driver_type": "PythonConnector",
     "driver_version": "3.12.4",
     "event_name": "DataFrame_Validator_Schema",
{snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_df_check_telemetry.json
RENAMED
@@ -1,6 +1,6 @@
 {
   "message": {
-    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"
+    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"int64\", \"float64\"]}",
     "driver_type": "PythonConnector",
     "driver_version": "3.12.4",
     "event_name": "DataFrame_Validator_Schema",
{snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_input_telemetry.json
RENAMED
@@ -1,6 +1,6 @@
 {
   "message": {
-    "data": "{\"function\": \"check_input_schema\", \"schema_types\": [\"
+    "data": "{\"function\": \"check_input_schema\", \"schema_types\": [\"int64\", \"float64\"]}",
     "driver_type": "PythonConnector",
     "driver_version": "3.12.4",
     "event_name": "DataFrame_Validator_Schema",
{snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_output_telemetry.json
RENAMED
@@ -1,6 +1,6 @@
 {
   "message": {
-    "data": "{\"function\": \"check_output_schema\", \"schema_types\": [\"
+    "data": "{\"function\": \"check_output_schema\", \"schema_types\": [\"int64\", \"float64\", \"float64\"]}",
     "driver_type": "PythonConnector",
     "driver_version": "3.12.4",
     "event_name": "DataFrame_Validator_Schema",
{snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/test_pandera.py
RENAMED
@@ -23,7 +23,7 @@ from unittest.mock import MagicMock, patch

 import pytest

-from numpy import int8
+from numpy import int8, int64
 from pandas import DataFrame as PandasDataFrame
 from pandera import Check, Column, DataFrameSchema
 from pytest import raises
@@ -78,7 +78,7 @@ def test_input(telemetry_output_path):

     in_schema = DataFrameSchema(
         {
-            "COLUMN1": Column(
+            "COLUMN1": Column(int64, Check(lambda x: 0 <= x <= 10, element_wise=True)),
             "COLUMN2": Column(float, Check(lambda x: x < -1.2, element_wise=True)),
         }
     )
@@ -161,7 +161,7 @@ def test_output(telemetry_output_path):
     out_schema = DataFrameSchema(
         {
             "COLUMN1": Column(
-
+                int64, Check.between(0, 10, include_max=True, include_min=True)
             ),
             "COLUMN2": Column(float, Check.less_than_or_equal_to(-1.2)),
             "COLUMN3": Column(float, Check.less_than(10)),
@@ -244,7 +244,7 @@ def test_df_check(telemetry_output_path):

     schema = DataFrameSchema(
         {
-            "COLUMN1": Column(
+            "COLUMN1": Column(int64, Check(lambda x: 0 <= x <= 10, element_wise=True)),
             "COLUMN2": Column(float, Check(lambda x: x < -1.2, element_wise=True)),
         }
     )
@@ -320,7 +320,7 @@ def test_df_check_from_file(telemetry_output_path):

     schema = DataFrameSchema(
         {
-            "COLUMN1": Column(
+            "COLUMN1": Column(int64, Check.between(0, 10)),
             "COLUMN2": Column(float, Check.between(-20.5, -1.0)),
         }
     )
@@ -409,7 +409,7 @@ def test_df_check_custom_check(telemetry_output_path):

     schema = DataFrameSchema(
         {
-            "COLUMN1": Column(
+            "COLUMN1": Column(int64, Check(lambda x: 0 <= x <= 10, element_wise=True)),
             "COLUMN2": Column(float, Check(lambda x: x < -1.2, element_wise=True)),
         }
     )
@@ -454,7 +454,7 @@ def test_df_check_skip_check(telemetry_output_path):

     schema = DataFrameSchema(
         {
-            "COLUMN1": Column(
+            "COLUMN1": Column(int64, Check.between(0, 10, element_wise=True)),
             "COLUMN2": Column(
                 float,
                 [
snowpark_checkpoints_validators-0.4.0/test/unit/test_pandera_validations.py
ADDED
@@ -0,0 +1,130 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from snowflake.snowpark_checkpoints.checkpoint import validate
+from pandera import DataFrameSchema, Column, Check
+import pandas as pd
+import pytz
+
+
+def test_pandera_validate_equivalent_dataframes():
+    schema = DataFrameSchema(
+        {
+            "a": Column(
+                int, checks=Check(lambda s: s > 0, element_wise=True), nullable=False
+            )
+        }
+    )
+    df = pd.DataFrame({"a": [1, 2, 3]})
+    result, validated_df = validate(schema, df)
+    assert result
+    pd.testing.assert_frame_equal(validated_df, df)
+
+
+def test_pandera_validate_object_vs_string():
+    schema = DataFrameSchema({"a": Column(str, nullable=False)})
+
+    df_object = pd.DataFrame({"a": pd.Series(["x", "y", "z"], dtype="object")})
+    result, validated_df = validate(schema, df_object)
+    assert result
+
+    df_int_as_string = pd.DataFrame({"a": ["1", "2", "3"]})
+    result, validated_df = validate(schema, df_int_as_string)
+    assert result
+
+    df_mixed = pd.DataFrame({"a": ["x", 1, "z"]})
+    result, validated_df = validate(schema, df_mixed)
+    assert not result
+
+
+def test_pandera_validate_int_vs_string():
+    schema = DataFrameSchema({"a": Column(int, nullable=False)})
+    df_valid_int = pd.DataFrame({"a": [1, 2, 3]})
+    result, _ = validate(schema, df_valid_int)
+    assert result
+
+    df_string_numbers = pd.DataFrame({"a": ["1", "2", "3"]})
+    result, failure_cases = validate(schema, df_string_numbers)
+    assert not result
+
+    df_mixed = pd.DataFrame({"a": [1, "2", 3]})
+    result, failure_cases = validate(schema, df_mixed)
+    assert not result
+
+
+def test_timestamp_ntz():
+    schema = DataFrameSchema({"ts": Column(pd.Timestamp, nullable=False)})
+
+    df = pd.DataFrame(
+        {
+            "ts": pd.to_datetime(
+                ["2024-01-01 10:00", "2024-01-02 11:00", "2024-01-03 12:00"]
+            )
+        }
+    )
+    result, validated_df = validate(schema, df)
+    assert result
+    assert validated_df["ts"].dt.tz is None
+
+
+def test_timestamp_utc_timezone():
+    schema = DataFrameSchema({"ts": Column(pd.Timestamp, nullable=False)})
+
+    df = pd.DataFrame(
+        {
+            "ts": pd.to_datetime(
+                [
+                    "2024-01-01 10:00+00:00",
+                    "2024-01-02 11:00+00:00",
+                    "2024-01-03 12:00+00:00",
+                ]
+            )
+        }
+    )
+
+    df["ts"] = df["ts"].dt.tz_convert("UTC").dt.tz_localize(None)
+
+    result, validated_df = validate(schema, df)
+    assert result
+    assert validated_df["ts"].dt.tz is None
+
+
+def convert_all_to_utc_naive(series: pd.Series) -> pd.Series:
+    def convert(ts):
+        if ts.tz is None:
+            ts = ts.tz_localize("UTC")
+        return ts.tz_convert("UTC").tz_localize(None)
+
+    return series.apply(convert)
+
+
+def test_timestamp_mixed_timezones_fails():
+    schema = DataFrameSchema({"ts": Column(pd.Timestamp, nullable=False)})
+    eastern = pytz.timezone("US/Eastern")
+    df = pd.DataFrame(
+        {
+            "ts": [
+                pd.Timestamp("2024-01-01 10:00"),
+                eastern.localize(pd.Timestamp("2024-01-02 11:00")),
+                pd.Timestamp("2024-01-03 12:00+00:00"),
+            ]
+        }
+    )
+
+    df["ts"] = convert_all_to_utc_naive(df["ts"])
+    result, validated_df = validate(schema, df)
+
+    assert result
+    assert validated_df["ts"].dt.tz is None
snowpark_checkpoints_validators-0.4.0/test/unit/test_snowpark_sampler.py
ADDED
@@ -0,0 +1,117 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pandas as pd
+from snowflake.snowpark_checkpoints.snowpark_sampler import (
+    to_pandas,
+    convert_all_to_utc_naive,
+)
+from snowflake.snowpark.types import (
+    BinaryType,
+    FloatType,
+    StringType,
+    TimestampType,
+)
+from snowflake.snowpark_checkpoints.utils.constants import (
+    PANDAS_FLOAT_TYPE,
+    PANDAS_LONG_TYPE,
+    PANDAS_STRING_TYPE,
+)
+
+
+class DummyDataType:
+    def __init__(self, name):
+        self._name = name
+
+    def typeName(self):
+        return self._name
+
+
+class DummyField:
+    def __init__(self, name, datatype):
+        self.name = name
+        self.datatype = datatype
+        self.typeName = datatype
+
+
+class DummySchema:
+    def __init__(self, fields):
+        self.fields = fields
+
+
+class DummySnowparkDF:
+    def __init__(self, pandas_df, fields):
+        self._pandas_df = pandas_df
+        self.schema = DummySchema(fields)
+
+    def toPandas(self):
+        return self._pandas_df
+
+
+def test_to_pandas_integer_conversion():
+    df = pd.DataFrame({"int_col": [1, None]}, dtype="float")
+    fields = [DummyField("int_col", DummyDataType("integer"))]
+    sp_df = DummySnowparkDF(df, fields)
+
+    result = to_pandas(sp_df)
+    assert result["int_col"].dtype == PANDAS_LONG_TYPE
+    assert result["int_col"].iloc[1] == 0
+
+
+def test_to_pandas_string_and_binary_conversion():
+    df = pd.DataFrame({"str_col": ["a", None], "bin_col": ["b", None]})
+    fields = [
+        DummyField("str_col", StringType()),
+        DummyField("bin_col", BinaryType()),
+    ]
+    sp_df = DummySnowparkDF(df, fields)
+
+    result = to_pandas(sp_df)
+    assert result["str_col"].dtype == PANDAS_STRING_TYPE
+    assert result["bin_col"].dtype == PANDAS_STRING_TYPE
+
+
+def test_to_pandas_float_conversion():
+    df = pd.DataFrame({"float_col": [1.1, None]}, dtype="float")
+    fields = [DummyField("float_col", FloatType())]
+    sp_df = DummySnowparkDF(df, fields)
+
+    result = to_pandas(sp_df)
+    assert result["float_col"].dtype == PANDAS_FLOAT_TYPE
+
+
+def test_to_pandas_timestamp_conversion():
+    utc_ts = pd.Timestamp("2023-01-01 12:00:00", tz="UTC")
+    naive_ts = pd.Timestamp("2023-01-02 12:00:00")
+    df = pd.DataFrame({"ts_col": [utc_ts, naive_ts]})
+    fields = [DummyField("ts_col", TimestampType())]
+    sp_df = DummySnowparkDF(df, fields)
+
+    result = to_pandas(sp_df)
+    assert pd.api.types.is_datetime64_any_dtype(result["ts_col"])
+    assert result["ts_col"].iloc[0].tzinfo is None
+    assert result["ts_col"].iloc[1].tzinfo is None
+
+
+def test_convert_all_to_utc_naive_behavior():
+    utc_ts = pd.Timestamp("2024-01-01 10:00:00", tz="UTC")
+    naive_ts = pd.Timestamp("2024-01-01 12:00:00")
+    none_val = pd.NaT
+    series = pd.Series([utc_ts, naive_ts, none_val])
+
+    result = convert_all_to_utc_naive(series)
+    assert result[0].tzinfo is None
+    assert result[1].tzinfo is None
+    assert pd.isna(result[2])
{snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_utils_checks.py
RENAMED
@@ -282,6 +282,7 @@ def test_compare_data_match():
     job_context = MagicMock(spec=SnowparkJobContext)
     session = MagicMock()
     job_context.snowpark_session = session
+    job_context.job_name = checkpoint_name

     # Mock session.sql to return an empty DataFrame (indicating no mismatch)
     session.sql.return_value.count.return_value = 0
@@ -289,7 +290,6 @@ def test_compare_data_match():
     checkpoint_name = "test_checkpoint"
     validation_status = PASS_STATUS
     output_path = "test_output_path/utils/"
-
     with (
         patch("os.getcwd", return_value="/mocked/path"),
        patch("os.path.exists", return_value=False),
@@ -298,6 +298,14 @@ def test_compare_data_match():
         patch(
             "snowflake.snowpark_checkpoints.utils.utils_checks._update_validation_result"
         ) as mock_update_validation_result,
+        patch(
+            "snowflake.snowpark_checkpoints.utils.utils_checks.convert_timestamps_to_utc_date",
+            return_value=df,
+        ),
+        patch(
+            "snowflake.snowpark_checkpoints.utils.utils_checks.get_comparison_differences",
+            return_value={},
+        ) as mock_get_comparison_differences,
     ):
         # Call the function
         _check_compare_data(df, job_context, checkpoint_name, output_path)
@@ -309,11 +317,7 @@ def test_compare_data_match():
     df.write.save_as_table.assert_called_once_with(
         table_name=new_checkpoint_name, mode=OVERWRITE_MODE
     )
-
-        call(EXCEPT_HASH_AGG_QUERY, [checkpoint_name, new_checkpoint_name]),
-        call().count(),
-    ]
-    session.sql.assert_has_calls(calls)
+    mock_get_comparison_differences.assert_called_once()
     job_context._mark_pass.assert_called_once_with(
         checkpoint_name, DATAFRAME_EXECUTION_MODE
     )
@@ -344,6 +348,13 @@ def test_compare_data_mismatch():
         patch(
             "snowflake.snowpark_checkpoints.utils.utils_checks._update_validation_result"
         ) as mock_update_validation_result,
+        patch(
+            "snowflake.snowpark_checkpoints.utils.utils_checks.convert_timestamps_to_utc_date",
+            return_value=df,
+        ),
+        patch(
+            "snowflake.snowpark_checkpoints.utils.utils_checks.get_comparison_differences"
+        ) as mock_get_comparison_differences,
     ):
         # Call the function and expect a SchemaValidationError
         with raises(
@@ -359,11 +370,7 @@ def test_compare_data_mismatch():
     df.write.save_as_table.assert_called_once_with(
         table_name=new_checkpoint_name, mode=OVERWRITE_MODE
     )
-
-        call(EXCEPT_HASH_AGG_QUERY, [checkpoint_name, new_checkpoint_name]),
-        call().count(),
-    ]
-    session.sql.assert_has_calls(calls)
+    mock_get_comparison_differences.assert_called_once()
     job_context._mark_fail.assert_called()
     job_context._mark_pass.assert_not_called()

{snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_validation_result_metadata.py
RENAMED
@@ -16,7 +16,7 @@ from snowflake.snowpark_checkpoints.validation_results import (
 )
 from pandas import DataFrame as PandasDataFrame, testing as PandasTesting
 from pandera import DataFrameSchema, Column, Check
-from snowflake.snowpark_checkpoints.checkpoint import
+from snowflake.snowpark_checkpoints.checkpoint import validate


 @fixture()
@@ -204,7 +204,7 @@ def test_clean_with_no_file():

 def test_validate_valid_schema(sample_data):
     df, valid_schema, _ = sample_data
-    is_valid, result =
+    is_valid, result = validate(valid_schema, df)
     assert is_valid
     assert isinstance(result, PandasDataFrame)
     PandasTesting.assert_frame_equal(result, df)
@@ -212,7 +212,7 @@ def test_validate_valid_schema(sample_data):

 def test_validate_invalid_schema(sample_data):
     df, _, invalid_schema = sample_data
-    is_valid, result =
+    is_valid, result = validate(invalid_schema, df)
     assert not is_valid
     assert isinstance(result, PandasDataFrame)
     assert "failure_case" in result.columns