snowpark-checkpoints-validators 0.1.0rc3__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_checkpoints/__init__.py +34 -0
- snowflake/snowpark_checkpoints/__version__.py +16 -0
- snowflake/snowpark_checkpoints/checkpoint.py +482 -0
- snowflake/snowpark_checkpoints/errors.py +60 -0
- snowflake/snowpark_checkpoints/job_context.py +85 -0
- snowflake/snowpark_checkpoints/singleton.py +23 -0
- snowflake/snowpark_checkpoints/snowpark_sampler.py +99 -0
- snowflake/snowpark_checkpoints/spark_migration.py +222 -0
- snowflake/snowpark_checkpoints/utils/__init__.py +14 -0
- snowflake/snowpark_checkpoints/utils/checkpoint_logger.py +52 -0
- snowflake/snowpark_checkpoints/utils/constants.py +134 -0
- snowflake/snowpark_checkpoints/utils/extra_config.py +84 -0
- snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +358 -0
- snowflake/snowpark_checkpoints/utils/supported_types.py +65 -0
- snowflake/snowpark_checkpoints/utils/telemetry.py +900 -0
- snowflake/snowpark_checkpoints/utils/utils_checks.py +374 -0
- snowflake/snowpark_checkpoints/validation_result_metadata.py +125 -0
- snowflake/snowpark_checkpoints/validation_results.py +49 -0
- {snowpark_checkpoints_validators-0.1.0rc3.dist-info → snowpark_checkpoints_validators-0.1.2.dist-info}/METADATA +4 -7
- snowpark_checkpoints_validators-0.1.2.dist-info/RECORD +22 -0
- snowpark_checkpoints_validators-0.1.0rc3.dist-info/RECORD +0 -4
- {snowpark_checkpoints_validators-0.1.0rc3.dist-info → snowpark_checkpoints_validators-0.1.2.dist-info}/WHEEL +0 -0
- {snowpark_checkpoints_validators-0.1.0rc3.dist-info → snowpark_checkpoints_validators-0.1.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,358 @@
|
|
1
|
+
from datetime import datetime
from typing import Any, Optional

from pandera import Check, DataFrameSchema

from snowflake.snowpark_checkpoints.utils.checkpoint_logger import CheckpointLogger
from snowflake.snowpark_checkpoints.utils.constants import (
    COLUMNS_KEY,
    DECIMAL_PRECISION_KEY,
    DEFAULT_DATE_FORMAT,
    FALSE_COUNT_KEY,
    FORMAT_KEY,
    MARGIN_ERROR_KEY,
    MAX_KEY,
    MEAN_KEY,
    MIN_KEY,
    NAME_KEY,
    NULL_COUNT_KEY,
    NULLABLE_KEY,
    ROWS_COUNT_KEY,
    SKIP_ALL,
    TRUE_COUNT_KEY,
    TYPE_KEY,
)
from snowflake.snowpark_checkpoints.utils.supported_types import (
    BooleanTypes,
    NumericTypes,
)
|
29
|
+
|
30
|
+
|
31
|
+
class PanderaCheckManager:
    """Augment a Pandera ``DataFrameSchema`` with statistical checks for a checkpoint.

    The manager mutates the schema passed at construction time: each ``_add_*``
    helper appends ``pandera.Check`` objects to the checks of the matching column.
    """

    def __init__(self, checkpoint_name: str, schema: DataFrameSchema):
        # checkpoint_name is only used for error messages; schema is mutated in place.
        self.checkpoint_name = checkpoint_name
        self.schema = schema

    def _add_numeric_checks(self, col: str, additional_check: dict[str, Any]):
        """Add numeric checks to a specified column in the schema.

        Two checks may be appended:
        1. A "mean" check ensuring the series mean lies within
           ``MEAN_KEY`` +/- ``MARGIN_ERROR_KEY``.
        2. A "decimal_precision" check (only when ``DECIMAL_PRECISION_KEY`` is
           present) ensuring the number of decimal places of each value does not
           exceed the specified precision.

        Args:
            col (str): The name of the column to which the checks will be added.
            additional_check (dict[str, Any]): A dictionary that may contain
                ``MEAN_KEY``, ``MARGIN_ERROR_KEY`` and ``DECIMAL_PRECISION_KEY``.

        """
        mean = additional_check.get(MEAN_KEY, 0)
        std = additional_check.get(MARGIN_ERROR_KEY, 0)

        def check_mean(series):
            series_mean = series.mean()
            return mean - std <= series_mean <= mean + std

        self.schema.columns[col].checks.append(
            Check(check_mean, element_wise=False, name="mean")
        )

        if DECIMAL_PRECISION_KEY in additional_check:
            self.schema.columns[col].checks.append(
                Check(
                    # Count digits after the decimal point; integers count as 0.
                    lambda series: series.apply(
                        lambda x: len(str(x).split(".")[1]) if "." in str(x) else 0
                    )
                    <= additional_check[DECIMAL_PRECISION_KEY],
                    name="decimal_precision",
                )
            )

    def _add_boolean_checks(self, col: str, additional_check: dict[str, Any]):
        """Add boolean checks to the schema for a specified column.

        Computes the expected percentage of True and False values from the
        collected counts and appends checks verifying the observed percentages
        stay within the margin of error.

        Args:
            col (str): The name of the column to which the checks will be added.
            additional_check (dict[str, Any]): A dictionary that may contain
                ``TRUE_COUNT_KEY``, ``FALSE_COUNT_KEY``, ``ROWS_COUNT_KEY`` and
                ``MARGIN_ERROR_KEY``.

        Returns:
            None

        """
        count_of_true = additional_check.get(TRUE_COUNT_KEY, 0)
        count_of_false = additional_check.get(FALSE_COUNT_KEY, 0)
        rows_count = additional_check.get(ROWS_COUNT_KEY, 0)
        std = additional_check.get(MARGIN_ERROR_KEY, 0)
        # Guard against a missing or zero row count (previously ZeroDivisionError).
        percentage_true = count_of_true / rows_count if rows_count else 0.0
        percentage_false = count_of_false / rows_count if rows_count else 0.0

        # NOTE(review): because the conditional expression binds looser than the
        # comparison, these lambdas assert only the LOWER bound for non-empty
        # series; confirm whether an upper bound (pct <= expected + std) was
        # also intended before tightening. Behavior preserved as released.
        self.schema.columns[col].checks.extend(
            [
                Check(
                    lambda series: (
                        percentage_true - std
                        <= series.value_counts().get(True, 0) / series.count()
                        if series.count() > 0
                        else 1 <= percentage_true + std
                    ),
                ),
                Check(
                    lambda series: (
                        percentage_false - std
                        <= series.value_counts().get(False, 0) / series.count()
                        if series.count() > 0
                        else 1 <= percentage_false + std
                    ),
                ),
            ]
        )

    def _add_null_checks(self, col: str, additional_check: dict[str, Any]):
        """Add null checks to the schema for a specified column.

        Computes the expected percentage of null values from the collected
        counts and appends a check verifying the observed null percentage stays
        within the margin of error.

        Args:
            col (str): The name of the column to add null checks for.
            additional_check (dict[str, Any]): A dictionary that may contain
                ``NULL_COUNT_KEY``, ``ROWS_COUNT_KEY`` and ``MARGIN_ERROR_KEY``.

        """
        count_of_null = additional_check.get(NULL_COUNT_KEY, 0)
        rows_count = additional_check.get(ROWS_COUNT_KEY, 0)
        std = additional_check.get(MARGIN_ERROR_KEY, 0)
        # Guard against a missing or zero row count (previously ZeroDivisionError).
        percentage_null = count_of_null / rows_count if rows_count else 0.0

        # NOTE(review): same precedence caveat as _add_boolean_checks — only the
        # lower bound is asserted for non-empty series.
        self.schema.columns[col].checks.append(
            Check(
                lambda series: (
                    percentage_null - std <= series.isnull().sum() / series.count()
                    if series.count() > 0
                    else 1 <= percentage_null + std
                ),
            ),
        )

    def _add_date_time_checks(self, col: str, additional_check: dict[str, Any]):
        """Add datetime range checks to a specified column.

        Args:
            col (str): The name of the column to which the checks will be applied.
            additional_check (dict[str, Any]): A dictionary containing
                ``FORMAT_KEY`` (strptime format, defaults to
                ``DEFAULT_DATE_FORMAT``), ``MIN_KEY`` and ``MAX_KEY``.

        Depending on which bounds are present:
        - both min and max: an inclusive ``between`` check is added.
        - only min: a ``greater_than_or_equal_to`` check is added.
        - only max: a ``less_than_or_equal_to`` check is added.

        """
        date_format = additional_check.get(FORMAT_KEY, DEFAULT_DATE_FORMAT)

        min_value = additional_check.get(MIN_KEY, None)
        min_date = datetime.strptime(min_value, date_format) if min_value else None

        max_value = additional_check.get(MAX_KEY, None)
        max_date = datetime.strptime(max_value, date_format) if max_value else None

        if min_date and max_date:
            self.schema.columns[col].checks.append(
                Check.between(
                    min_date,
                    max_date,
                    include_max=True,
                    include_min=True,
                )
            )
        elif min_date:
            self.schema.columns[col].checks.append(
                Check.greater_than_or_equal_to(min_date)
            )
        elif max_date:
            self.schema.columns[col].checks.append(
                Check.less_than_or_equal_to(max_date)
            )

    def _add_date_checks(self, col: str, additional_check: dict[str, Any]):
        """Add date range checks to a specified column.

        Identical to :meth:`_add_date_time_checks` except the parsed bounds are
        converted to ``date`` objects.

        Args:
            col (str): The name of the column to which the checks will be applied.
            additional_check (dict[str, Any]): A dictionary containing
                ``FORMAT_KEY``, ``MIN_KEY`` and ``MAX_KEY``.

        """
        date_format = additional_check.get(FORMAT_KEY, DEFAULT_DATE_FORMAT)

        min_value = additional_check.get(MIN_KEY, None)
        min_date = datetime.strptime(min_value, date_format).date() if min_value else None

        max_value = additional_check.get(MAX_KEY, None)
        max_date = datetime.strptime(max_value, date_format).date() if max_value else None

        if min_date and max_date:
            self.schema.columns[col].checks.append(
                Check.between(
                    min_date,
                    max_date,
                    include_max=True,
                    include_min=True,
                )
            )
        elif min_date:
            self.schema.columns[col].checks.append(
                Check.greater_than_or_equal_to(min_date)
            )
        elif max_date:
            self.schema.columns[col].checks.append(
                Check.less_than_or_equal_to(max_date)
            )

    def proccess_checks(self, custom_data: dict) -> DataFrameSchema:
        """Process the checks defined in custom_data and apply them to the schema.

        Args:
            custom_data (dict): A dictionary whose ``COLUMNS_KEY`` entry maps to
                a list of column check definitions, each containing:
                - TYPE_KEY: The type of the column (e.g., numeric, boolean, date, datetime).
                - NAME_KEY: The name of the column.
                - NULLABLE_KEY: A boolean indicating if the column is nullable.

        Returns:
            DataFrameSchema: The updated schema with the applied checks.

        Raises:
            ValueError: If the column name or type is not defined in the schema.

        """
        logger = CheckpointLogger().get_logger()
        # Default to an empty list so a payload without COLUMNS_KEY is a no-op
        # instead of raising TypeError when iterating over None.
        for additional_check in custom_data.get(COLUMNS_KEY, []):

            col_type = additional_check.get(TYPE_KEY, None)
            name = additional_check.get(NAME_KEY, None)
            is_nullable = additional_check.get(NULLABLE_KEY, False)

            if name is None:
                raise ValueError(
                    f"Column name not defined in the schema {self.checkpoint_name}"
                )

            if col_type is None:
                raise ValueError(f"Type not defined for column {name}")

            # Columns absent from the schema are skipped (with a warning), not fatal.
            if self.schema.columns.get(name) is None:
                logger.warning(f"Column {name} not found in schema")
                continue

            if col_type in NumericTypes:
                self._add_numeric_checks(name, additional_check)
            elif col_type in BooleanTypes:
                self._add_boolean_checks(name, additional_check)
            elif col_type == "date":
                self._add_date_checks(name, additional_check)
            elif col_type == "datetime":
                self._add_date_time_checks(name, additional_check)

            if is_nullable:
                self._add_null_checks(name, additional_check)

        return self.schema

    # Backward-compatible, correctly-spelled alias for ``proccess_checks``.
    process_checks = proccess_checks

    def skip_checks_on_schema(
        self,
        skip_checks: Optional[dict[str, list[str]]] = None,
    ) -> DataFrameSchema:
        """Modify the schema by skipping specified checks on columns.

        Args:
            skip_checks : Optional[dict[str, list[str]]], optional
                A dictionary where keys are column names and values are lists of
                check names to skip. If ``SKIP_ALL`` is present in a column's
                list, all checks for that column are removed. If None, no checks
                are skipped.

        Returns:
            DataFrameSchema: The modified schema with specified checks skipped.

        """
        if not skip_checks:
            return self.schema

        for col, checks_to_skip in skip_checks.items():

            if col in self.schema.columns:

                if SKIP_ALL in checks_to_skip:
                    # Assign an empty list (not {}) so the attribute keeps the
                    # list type pandera expects for column checks.
                    self.schema.columns[col].checks = []
                else:
                    self.schema.columns[col].checks = [
                        check
                        for check in self.schema.columns[col].checks
                        if check.name not in checks_to_skip
                    ]

        return self.schema

    def add_custom_checks(
        self,
        custom_checks: Optional[dict[str, list[Check]]] = None,
    ):
        """Add custom checks to the managed Pandera DataFrameSchema.

        Args:
            custom_checks (Optional[dict[str, list[Check]]]): A dictionary where
                keys are column names and values are lists of checks to add for
                those columns.

        Returns:
            DataFrameSchema: The schema with the custom checks appended.

        Raises:
            ValueError: If a named column is not present in the schema.

        """
        if not custom_checks:
            return self.schema

        for col, checks in custom_checks.items():

            if col in self.schema.columns:
                col_schema = self.schema.columns[col]
                col_schema.checks.extend(checks)
            else:
                raise ValueError(f"Column {col} not found in schema")

        return self.schema
@@ -0,0 +1,65 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
from snowflake.snowpark_checkpoints.utils.constants import (
|
17
|
+
BINARY_TYPE,
|
18
|
+
BOOLEAN_TYPE,
|
19
|
+
BYTE_TYPE,
|
20
|
+
DATE_TYPE,
|
21
|
+
DECIMAL_TYPE,
|
22
|
+
DOUBLE_TYPE,
|
23
|
+
FLOAT_TYPE,
|
24
|
+
INTEGER_TYPE,
|
25
|
+
LONG_TYPE,
|
26
|
+
SHORT_TYPE,
|
27
|
+
STRING_TYPE,
|
28
|
+
TIMESTAMP_NTZ_TYPE,
|
29
|
+
TIMESTAMP_TYPE,
|
30
|
+
)
|
31
|
+
|
32
|
+
|
33
|
+
# Snowpark type-name groupings used when dispatching checkpoint checks.
NumericTypes = [
    BYTE_TYPE,
    SHORT_TYPE,
    INTEGER_TYPE,
    LONG_TYPE,
    FLOAT_TYPE,
    DOUBLE_TYPE,
    DECIMAL_TYPE,
]

StringTypes = [STRING_TYPE]

BinaryTypes = [BINARY_TYPE]

BooleanTypes = [BOOLEAN_TYPE]

DateTypes = [DATE_TYPE, TIMESTAMP_TYPE, TIMESTAMP_NTZ_TYPE]

# Every supported type, assembled from the groups above; member order matches
# the original flat listing (numeric, string, binary, boolean, date types).
SupportedTypes = NumericTypes + StringTypes + BinaryTypes + BooleanTypes + DateTypes
|