snowpark-checkpoints-validators 0.1.0rc3__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. snowflake/snowpark_checkpoints/__init__.py +34 -0
  2. snowflake/snowpark_checkpoints/__version__.py +16 -0
  3. snowflake/snowpark_checkpoints/checkpoint.py +482 -0
  4. snowflake/snowpark_checkpoints/errors.py +60 -0
  5. snowflake/snowpark_checkpoints/job_context.py +85 -0
  6. snowflake/snowpark_checkpoints/singleton.py +23 -0
  7. snowflake/snowpark_checkpoints/snowpark_sampler.py +99 -0
  8. snowflake/snowpark_checkpoints/spark_migration.py +222 -0
  9. snowflake/snowpark_checkpoints/utils/__init__.py +14 -0
  10. snowflake/snowpark_checkpoints/utils/checkpoint_logger.py +52 -0
  11. snowflake/snowpark_checkpoints/utils/constants.py +134 -0
  12. snowflake/snowpark_checkpoints/utils/extra_config.py +84 -0
  13. snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +358 -0
  14. snowflake/snowpark_checkpoints/utils/supported_types.py +65 -0
  15. snowflake/snowpark_checkpoints/utils/telemetry.py +900 -0
  16. snowflake/snowpark_checkpoints/utils/utils_checks.py +374 -0
  17. snowflake/snowpark_checkpoints/validation_result_metadata.py +125 -0
  18. snowflake/snowpark_checkpoints/validation_results.py +49 -0
  19. {snowpark_checkpoints_validators-0.1.0rc3.dist-info → snowpark_checkpoints_validators-0.1.2.dist-info}/METADATA +4 -7
  20. snowpark_checkpoints_validators-0.1.2.dist-info/RECORD +22 -0
  21. snowpark_checkpoints_validators-0.1.0rc3.dist-info/RECORD +0 -4
  22. {snowpark_checkpoints_validators-0.1.0rc3.dist-info → snowpark_checkpoints_validators-0.1.2.dist-info}/WHEEL +0 -0
  23. {snowpark_checkpoints_validators-0.1.0rc3.dist-info → snowpark_checkpoints_validators-0.1.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,358 @@
1
+ from datetime import datetime
2
+ from typing import Optional
3
+
4
+ from pandera import Check, DataFrameSchema
5
+
6
+ from snowflake.snowpark_checkpoints.utils.checkpoint_logger import CheckpointLogger
7
+ from snowflake.snowpark_checkpoints.utils.constants import (
8
+ COLUMNS_KEY,
9
+ DECIMAL_PRECISION_KEY,
10
+ DEFAULT_DATE_FORMAT,
11
+ FALSE_COUNT_KEY,
12
+ FORMAT_KEY,
13
+ MARGIN_ERROR_KEY,
14
+ MAX_KEY,
15
+ MEAN_KEY,
16
+ MIN_KEY,
17
+ NAME_KEY,
18
+ NULL_COUNT_KEY,
19
+ NULLABLE_KEY,
20
+ ROWS_COUNT_KEY,
21
+ SKIP_ALL,
22
+ TRUE_COUNT_KEY,
23
+ TYPE_KEY,
24
+ )
25
+ from snowflake.snowpark_checkpoints.utils.supported_types import (
26
+ BooleanTypes,
27
+ NumericTypes,
28
+ )
29
+
30
+
31
+ class PanderaCheckManager:
32
+ def __init__(self, checkpoint_name: str, schema: DataFrameSchema):
33
+ self.checkpoint_name = checkpoint_name
34
+ self.schema = schema
35
+
36
+ def _add_numeric_checks(self, col: str, additional_check: dict[str, any]):
37
+ """Add numeric checks to a specified column in the schema.
38
+
39
+ This method adds two types of checks to the specified column:
40
+ 1. A mean check that ensures the mean of the column values is within the specified margin of error.
41
+ 2. A decimal precision check that ensures the number of decimal places in the column values does not
42
+ exceed the specified precision.
43
+
44
+ Args:
45
+ col (str): The name of the column to which the checks will be added.
46
+ additional_check (dict[str, any]): A dictionary containing the following keys:
47
+ - MEAN_KEY: The expected mean value for the column.
48
+ - MARGIN_ERROR_KEY: The acceptable margin of error for the mean check.
49
+ - DECIMAL_PRECISION_KEY: The maximum number of decimal places allowed for the column values.
50
+
51
+ """
52
+ mean = additional_check.get(MEAN_KEY, 0)
53
+ std = additional_check.get(MARGIN_ERROR_KEY, 0)
54
+
55
+ def check_mean(series):
56
+ series_mean = series.mean()
57
+ return mean - std <= series_mean <= mean + std
58
+
59
+ self.schema.columns[col].checks.append(
60
+ Check(check_mean, element_wise=False, name="mean")
61
+ )
62
+
63
+ if DECIMAL_PRECISION_KEY in additional_check:
64
+ self.schema.columns[col].checks.append(
65
+ Check(
66
+ lambda series: series.apply(
67
+ lambda x: len(str(x).split(".")[1]) if "." in str(x) else 0
68
+ )
69
+ <= additional_check[DECIMAL_PRECISION_KEY],
70
+ name="decimal_precision",
71
+ )
72
+ )
73
+
74
+ def _add_boolean_checks(self, col: str, additional_check: dict[str, any]):
75
+ """Add boolean checks to the schema for a specified column.
76
+
77
+ This method extends the checks for a given column in the schema by adding
78
+ boolean checks based on the provided additional_check dictionary. It calculates
79
+ the percentage of True and False values in the column and ensures that these
80
+ percentages fall within a specified margin of error.
81
+
82
+ Args:
83
+ col (str): The name of the column to which the checks will be added.
84
+ additional_check (dict[str, any]): A dictionary containing the following keys:
85
+ - TRUE_COUNT_KEY: The count of True values in the column.
86
+ - FALSE_COUNT_KEY: The count of False values in the column.
87
+ - ROWS_COUNT_KEY: The total number of rows in the column.
88
+ - MARGIN_ERROR_KEY: The acceptable margin of error for the percentage checks.
89
+
90
+ Returns:
91
+ None
92
+
93
+ """
94
+ count_of_true = additional_check.get(TRUE_COUNT_KEY, 0)
95
+ count_of_false = additional_check.get(FALSE_COUNT_KEY, 0)
96
+ rows_count = additional_check.get(ROWS_COUNT_KEY, 0)
97
+ std = additional_check.get(MARGIN_ERROR_KEY, 0)
98
+ percentage_true = count_of_true / rows_count
99
+ percentage_false = count_of_false / rows_count
100
+
101
+ self.schema.columns[col].checks.extend(
102
+ [
103
+ Check(
104
+ lambda series: (
105
+ percentage_true - std
106
+ <= series.value_counts().get(True, 0) / series.count()
107
+ if series.count() > 0
108
+ else 1 <= percentage_true + std
109
+ ),
110
+ ),
111
+ Check(
112
+ lambda series: (
113
+ percentage_false - std
114
+ <= series.value_counts().get(False, 0) / series.count()
115
+ if series.count() > 0
116
+ else 1 <= percentage_false + std
117
+ ),
118
+ ),
119
+ ]
120
+ )
121
+
122
+ def _add_null_checks(self, col: str, additional_check: dict[str, any]):
123
+ """Add null checks to the schema for a specified column.
124
+
125
+ This method calculates the percentage of null values in the column and
126
+ appends a check to the schema that ensures the percentage of null values
127
+ in the series is within an acceptable margin of error.
128
+
129
+ Args:
130
+ col (str): The name of the column to add null checks for.
131
+ additional_check (dict[str, any]): A dictionary containing additional
132
+ check parameters:
133
+ - NULL_COUNT_KEY (str): The key for the count of null values.
134
+ - ROWS_COUNT_KEY (str): The key for the total number of rows.
135
+ - MARGIN_ERROR_KEY (str): The key for the margin of error.
136
+
137
+ Raises:
138
+ KeyError: If any of the required keys are missing from additional_check.
139
+
140
+ """
141
+ count_of_null = additional_check.get(NULL_COUNT_KEY, 0)
142
+ rows_count = additional_check.get(ROWS_COUNT_KEY, 0)
143
+ std = additional_check.get(MARGIN_ERROR_KEY, 0)
144
+ percentage_null = count_of_null / rows_count
145
+
146
+ self.schema.columns[col].checks.append(
147
+ Check(
148
+ lambda series: (
149
+ percentage_null - std <= series.isnull().sum() / series.count()
150
+ if series.count() > 0
151
+ else 1 <= percentage_null + std
152
+ ),
153
+ ),
154
+ )
155
+
156
+ def _add_date_time_checks(self, col: str, additional_check: dict[str, any]):
157
+ """Add date and time checks to a specified column in the given DataFrameSchema.
158
+
159
+ Args:
160
+ schema (DataFrameSchema): The schema to which the checks will be added.
161
+ col (str): The name of the column to which the checks will be applied.
162
+ additional_check (dict[str, Any]): A dictionary containing additional check parameters.
163
+ - FORMAT_KEY (str): The key for the date format string in the dictionary.
164
+ - MIN_KEY (str): The key for the minimum date value in the dictionary.
165
+ - MAX_KEY (str): The key for the maximum date value in the dictionary.
166
+
167
+ The function will add the following checks based on the provided additional_check dictionary:
168
+ - If both min_date and max_date are provided, a between check is added.
169
+ - If only min_date is provided, a greater_than_or_equal_to check is added.
170
+ - If only max_date is provided, a less_than_or_equal_to check is added.
171
+
172
+ """
173
+ format = additional_check.get(FORMAT_KEY, DEFAULT_DATE_FORMAT)
174
+
175
+ min = additional_check.get(MIN_KEY, None)
176
+ min_date = datetime.strptime(min, format) if min else None
177
+
178
+ max = additional_check.get(MAX_KEY, None)
179
+ max_date = datetime.strptime(max, format) if max else None
180
+
181
+ if min_date and max_date:
182
+ self.schema.columns[col].checks.append(
183
+ Check.between(
184
+ min_date,
185
+ max_date,
186
+ include_max=True,
187
+ include_min=True,
188
+ )
189
+ )
190
+ elif min_date:
191
+ self.schema.columns[col].checks.append(
192
+ Check.greater_than_or_equal_to(min_date)
193
+ )
194
+ elif max_date:
195
+ self.schema.columns[col].checks.append(
196
+ Check.less_than_or_equal_to(max_date)
197
+ )
198
+
199
+ def _add_date_checks(self, col: str, additional_check: dict[str, any]):
200
+ """Add date and time checks to a specified column in the given DataFrameSchema.
201
+
202
+ Args:
203
+ schema (DataFrameSchema): The schema to which the checks will be added.
204
+ col (str): The name of the column to which the checks will be applied.
205
+ additional_check (dict[str, Any]): A dictionary containing additional check parameters.
206
+ - FORMAT_KEY (str): The key for the date format string in the dictionary.
207
+ - MIN_KEY (str): The key for the minimum date value in the dictionary.
208
+ - MAX_KEY (str): The key for the maximum date value in the dictionary.
209
+
210
+ The function will add the following checks based on the provided additional_check dictionary:
211
+ - If both min_date and max_date are provided, a between check is added.
212
+ - If only min_date is provided, a greater_than_or_equal_to check is added.
213
+ - If only max_date is provided, a less_than_or_equal_to check is added.
214
+
215
+ """
216
+ format = additional_check.get(FORMAT_KEY, DEFAULT_DATE_FORMAT)
217
+
218
+ min = additional_check.get(MIN_KEY, None)
219
+ min_date = datetime.strptime(min, format).date() if min else None
220
+
221
+ max = additional_check.get(MAX_KEY, None)
222
+ max_date = datetime.strptime(max, format).date() if max else None
223
+
224
+ if min_date and max_date:
225
+ self.schema.columns[col].checks.append(
226
+ Check.between(
227
+ min_date,
228
+ max_date,
229
+ include_max=True,
230
+ include_min=True,
231
+ )
232
+ )
233
+ elif min_date:
234
+ self.schema.columns[col].checks.append(
235
+ Check.greater_than_or_equal_to(min_date)
236
+ )
237
+ elif max_date:
238
+ self.schema.columns[col].checks.append(
239
+ Check.less_than_or_equal_to(max_date)
240
+ )
241
+
242
+ def proccess_checks(self, custom_data: dict) -> DataFrameSchema:
243
+ """Process the checks defined in the custom_data dictionary and applies them to the schema.
244
+
245
+ Args:
246
+ custom_data (dict): A dictionary containing the custom checks to be applied. The dictionary
247
+ should have a key corresponding to COLUMNS_KEY, which maps to a list of
248
+ column check definitions. Each column check definition should include
249
+ the following keys:
250
+ - TYPE_KEY: The type of the column (e.g., numeric, boolean, date, datetime).
251
+ - NAME_KEY: The name of the column.
252
+ - NULLABLE_KEY: A boolean indicating if the column is nullable.
253
+
254
+ Returns:
255
+ DataFrameSchema: The updated schema with the applied checks.
256
+
257
+ Raises:
258
+ ValueError: If the column name or type is not defined in the schema.
259
+
260
+ """
261
+ logger = CheckpointLogger().get_logger()
262
+ for additional_check in custom_data.get(COLUMNS_KEY):
263
+
264
+ type = additional_check.get(TYPE_KEY, None)
265
+ name = additional_check.get(NAME_KEY, None)
266
+ is_nullable = additional_check.get(NULLABLE_KEY, False)
267
+
268
+ if name is None:
269
+ raise ValueError(
270
+ f"Column name not defined in the schema {self.checkpoint_name}"
271
+ )
272
+
273
+ if type is None:
274
+ raise ValueError(f"Type not defined for column {name}")
275
+
276
+ if self.schema.columns.get(name) is None:
277
+ logger.warning(f"Column {name} not found in schema")
278
+ continue
279
+
280
+ if type in NumericTypes:
281
+ self._add_numeric_checks(name, additional_check)
282
+
283
+ elif type in BooleanTypes:
284
+ self._add_boolean_checks(name, additional_check)
285
+
286
+ elif type == "date":
287
+ self._add_date_checks(name, additional_check)
288
+
289
+ elif type == "datetime":
290
+ self._add_date_time_checks(name, additional_check)
291
+
292
+ if is_nullable:
293
+ self._add_null_checks(name, additional_check)
294
+
295
+ return self.schema
296
+
297
+ def skip_checks_on_schema(
298
+ self,
299
+ skip_checks: Optional[dict[str, list[str]]] = None,
300
+ ) -> DataFrameSchema:
301
+ """Modify the schema by skipping specified checks on columns.
302
+
303
+ Args:
304
+ skip_checks : Optional[dict[str, list[str]]], optional
305
+ A dictionary where keys are column names and values are lists of check names to skip.
306
+ If the special key 'SKIP_ALL' is present in the list of checks for a column, all checks
307
+ for that column will be skipped. If None, no checks will be skipped.
308
+
309
+ Returns:
310
+ DataFrameSchema: The modified schema with specified checks skipped.
311
+
312
+ """
313
+ if not skip_checks:
314
+ return self.schema
315
+
316
+ for col, checks_to_skip in skip_checks.items():
317
+
318
+ if col in self.schema.columns:
319
+
320
+ if SKIP_ALL in checks_to_skip:
321
+ self.schema.columns[col].checks = {}
322
+ else:
323
+ self.schema.columns[col].checks = [
324
+ check
325
+ for check in self.schema.columns[col].checks
326
+ if check.name not in checks_to_skip
327
+ ]
328
+
329
+ return self.schema
330
+
331
+ def add_custom_checks(
332
+ self,
333
+ custom_checks: Optional[dict[str, list[Check]]] = None,
334
+ ):
335
+ """Add custom checks to a Pandera DataFrameSchema.
336
+
337
+ Args:
338
+ schema (DataFrameSchema): The Pandera DataFrameSchema object to modify.
339
+ custom_checks (Optional[dict[str, list[Check]]]): A dictionary where keys are column names
340
+ and values are lists of checks to add for
341
+ those columns.
342
+
343
+ Returns:
344
+ None
345
+
346
+ """
347
+ if not custom_checks:
348
+ return self.schema
349
+
350
+ for col, checks in custom_checks.items():
351
+
352
+ if col in self.schema.columns:
353
+ col_schema = self.schema.columns[col]
354
+ col_schema.checks.extend(checks)
355
+ else:
356
+ raise ValueError(f"Column {col} not found in schema")
357
+
358
+ return self.schema
@@ -0,0 +1,65 @@
1
+ # Copyright 2025 Snowflake Inc.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from snowflake.snowpark_checkpoints.utils.constants import (
17
+ BINARY_TYPE,
18
+ BOOLEAN_TYPE,
19
+ BYTE_TYPE,
20
+ DATE_TYPE,
21
+ DECIMAL_TYPE,
22
+ DOUBLE_TYPE,
23
+ FLOAT_TYPE,
24
+ INTEGER_TYPE,
25
+ LONG_TYPE,
26
+ SHORT_TYPE,
27
+ STRING_TYPE,
28
+ TIMESTAMP_NTZ_TYPE,
29
+ TIMESTAMP_TYPE,
30
+ )
31
+
32
+
33
+ NumericTypes = [
34
+ BYTE_TYPE,
35
+ SHORT_TYPE,
36
+ INTEGER_TYPE,
37
+ LONG_TYPE,
38
+ FLOAT_TYPE,
39
+ DOUBLE_TYPE,
40
+ DECIMAL_TYPE,
41
+ ]
42
+
43
+ StringTypes = [STRING_TYPE]
44
+
45
+ BinaryTypes = [BINARY_TYPE]
46
+
47
+ BooleanTypes = [BOOLEAN_TYPE]
48
+
49
+ DateTypes = [DATE_TYPE, TIMESTAMP_TYPE, TIMESTAMP_NTZ_TYPE]
50
+
51
+ SupportedTypes = [
52
+ BYTE_TYPE,
53
+ SHORT_TYPE,
54
+ INTEGER_TYPE,
55
+ LONG_TYPE,
56
+ FLOAT_TYPE,
57
+ DOUBLE_TYPE,
58
+ DECIMAL_TYPE,
59
+ STRING_TYPE,
60
+ BINARY_TYPE,
61
+ BOOLEAN_TYPE,
62
+ DATE_TYPE,
63
+ TIMESTAMP_TYPE,
64
+ TIMESTAMP_NTZ_TYPE,
65
+ ]