snowpark-checkpoints-validators 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,4 +13,4 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- __version__ = "0.3.2"
+ __version__ = "0.4.0"
@@ -332,7 +332,7 @@ def _check_dataframe_schema(
  pandera_schema_upper, sample_df = _process_sampling(
  df, pandera_schema, job_context, sample_frac, sample_number, sampling_strategy
  )
- is_valid, validation_result = _validate(pandera_schema_upper, sample_df)
+ is_valid, validation_result = validate(pandera_schema_upper, sample_df)
  if is_valid:
  LOGGER.info(
  "DataFrame schema validation passed for checkpoint '%s'",
@@ -438,7 +438,7 @@ def check_output_schema(
  sampler.process_args([snowpark_results])
  pandas_sample_args = sampler.get_sampled_pandas_args()

- is_valid, validation_result = _validate(
+ is_valid, validation_result = validate(
  pandera_schema, pandas_sample_args[0]
  )
  if is_valid:
@@ -565,7 +565,7 @@ def check_input_schema(
  )
  continue

- is_valid, validation_result = _validate(
+ is_valid, validation_result = validate(
  pandera_schema,
  arg,
  )
@@ -606,11 +606,31 @@ def check_input_schema(
  return check_input_with_decorator


- def _validate(
+ def validate(
  schema: Union[type[DataFrameModel], DataFrameSchema],
  df: PandasDataFrame,
  lazy: bool = True,
  ) -> tuple[bool, PandasDataFrame]:
+ """Validate a Pandas DataFrame against a given Pandera schema.
+
+ Args:
+ schema (Union[type[DataFrameModel], DataFrameSchema]):
+ The schema to validate against. Can be a Pandera `DataFrameSchema` or
+ a `DataFrameModel` class.
+ df (PandasDataFrame):
+ The Pandas DataFrame to be validated.
+ lazy (bool, optional):
+ If `True`, collect all validation errors before raising an exception.
+ If `False`, raise an exception as soon as the first error is encountered.
+ Defaults to `True`.
+
+ Returns:
+ tuple[bool, PandasDataFrame]:
+ A tuple containing:
+ - A boolean indicating whether validation passed.
+ - The validated DataFrame if successful, or the failure cases DataFrame if not.
+
+ """
  if not isinstance(schema, DataFrameSchema):
  schema = schema.to_schema()
  is_valid = True
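
The helper formerly named _validate is now exposed as validate and documents its contract in a docstring. Below is a minimal usage sketch; the import path is inferred from the checkpoint.py entry in the RECORD and should be read as an assumption rather than something this diff states.

    # Hypothetical usage of the newly public helper. The import path is an
    # assumption inferred from the RECORD entry for checkpoint.py.
    import pandas as pd
    import pandera as pa

    from snowflake.snowpark_checkpoints.checkpoint import validate

    schema = pa.DataFrameSchema({"ID": pa.Column(int), "NAME": pa.Column(str)})
    df = pd.DataFrame({"ID": [1, 2], "NAME": ["a", "b"]})

    # lazy=True (the default) collects every failure before reporting.
    is_valid, result = validate(schema, df)
    # is_valid is True and result is the validated DataFrame, or
    # is_valid is False and result holds the Pandera failure cases.
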
@@ -20,7 +20,21 @@ from typing import Optional
  import pandas

  from snowflake.snowpark import DataFrame as SnowparkDataFrame
+ from snowflake.snowpark.types import (
+ BinaryType,
+ BooleanType,
+ DateType,
+ FloatType,
+ StringType,
+ TimestampType,
+ )
  from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+ from snowflake.snowpark_checkpoints.utils.constants import (
+ INTEGER_TYPE_COLLECTION,
+ PANDAS_FLOAT_TYPE,
+ PANDAS_LONG_TYPE,
+ PANDAS_STRING_TYPE,
+ )


  LOGGER = logging.getLogger(__name__)
@@ -73,17 +87,17 @@ class SamplingAdapter:
  "Applying random sampling with fraction %s",
  self.sample_frac,
  )
- df_sample = arg.sample(frac=self.sample_frac).to_pandas()
+ df_sample = to_pandas(arg.sample(frac=self.sample_frac))
  else:
  LOGGER.info(
  "Applying random sampling with size %s", self.sample_number
  )
- df_sample = arg.sample(n=self.sample_number).to_pandas()
+ df_sample = to_pandas(arg.sample(n=self.sample_number))
  else:
  LOGGER.info(
  "Applying limit sampling with size %s", self.sample_number
  )
- df_sample = arg.limit(self.sample_number).to_pandas()
+ df_sample = to_pandas(arg.limit(self.sample_number))

  LOGGER.info(
  "Successfully sampled the DataFrame. Resulting DataFrame shape: %s",
@@ -122,3 +136,90 @@ class SamplingAdapter:
  else:
  pyspark_sample_args.append(arg)
  return pyspark_sample_args
+
+
+ def to_pandas(sampled_df: SnowparkDataFrame) -> pandas.DataFrame:
+ """Convert a Snowpark DataFrame to a Pandas DataFrame, handling missing values and type conversions."""
+ LOGGER.debug("Converting Snowpark DataFrame to Pandas DataFrame")
+ pandas_df = sampled_df.toPandas()
+ for field in sampled_df.schema.fields:
+ is_snowpark_integer = field.datatype.typeName() in INTEGER_TYPE_COLLECTION
+ is_snowpark_string = isinstance(field.datatype, StringType)
+ is_snowpark_binary = isinstance(field.datatype, BinaryType)
+ is_snowpark_timestamp = isinstance(field.datatype, TimestampType)
+ is_snowpark_float = isinstance(field.datatype, FloatType)
+ is_snowpark_boolean = isinstance(field.datatype, BooleanType)
+ is_snowpark_date = isinstance(field.datatype, DateType)
+ if is_snowpark_integer:
+ LOGGER.debug(
+ "Converting Spark integer column '%s' to Pandas nullable '%s' type",
+ field.name,
+ PANDAS_LONG_TYPE,
+ )
+ pandas_df[field.name] = (
+ pandas_df[field.name].astype(PANDAS_LONG_TYPE).fillna(0)
+ )
+ elif is_snowpark_string or is_snowpark_binary:
+ LOGGER.debug(
+ "Converting Spark string column '%s' to Pandas nullable '%s' type",
+ field.name,
+ PANDAS_STRING_TYPE,
+ )
+ pandas_df[field.name] = (
+ pandas_df[field.name].astype(PANDAS_STRING_TYPE).fillna("")
+ )
+ elif is_snowpark_timestamp:
+ LOGGER.debug(
+ "Converting Spark timestamp column '%s' to UTC naive Pandas datetime",
+ field.name,
+ )
+ pandas_df[field.name] = convert_all_to_utc_naive(
+ pandas_df[field.name]
+ ).fillna(pandas.NaT)
+ elif is_snowpark_float:
+ LOGGER.debug(
+ "Converting Spark float column '%s' to Pandas nullable float",
+ field.name,
+ )
+ pandas_df[field.name] = (
+ pandas_df[field.name].astype(PANDAS_FLOAT_TYPE).fillna(0.0)
+ )
+ elif is_snowpark_boolean:
+ LOGGER.debug(
+ "Converting Spark boolean column '%s' to Pandas nullable boolean",
+ field.name,
+ )
+ pandas_df[field.name] = (
+ pandas_df[field.name].astype("boolean").fillna(False)
+ )
+ elif is_snowpark_date:
+ LOGGER.debug(
+ "Converting Spark date column '%s' to Pandas nullable datetime",
+ field.name,
+ )
+ pandas_df[field.name] = pandas_df[field.name].fillna(pandas.NaT)
+
+ return pandas_df
+
+
+ def convert_all_to_utc_naive(series: pandas.Series) -> pandas.Series:
+ """Convert all timezone-aware or naive timestamps in a series to UTC naive.
+
+ Naive timestamps are assumed to be in UTC and localized accordingly.
+ Timezone-aware timestamps are converted to UTC and then made naive.
+
+ Args:
+ series (pandas.Series): A Pandas Series of `pd.Timestamp` objects,
+ either naive or timezone-aware.
+
+ Returns:
+ pandas.Series: A Series of UTC-normalized naive timestamps (`tzinfo=None`).
+
+ """
+
+ def convert(ts):
+ if ts.tz is None:
+ ts = ts.tz_localize("UTC")
+ return ts.tz_convert("UTC").tz_localize(None)
+
+ return series.apply(convert)
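
The sampler now routes conversions through the new to_pandas helper, which coerces integer, string, binary, float, boolean, date, and timestamp columns to nullable Pandas dtypes and fills nulls. The sketch below isolates the timestamp rule; convert_all_to_utc_naive is reproduced from the hunk above so the snippet runs without a Snowpark session.

    # Standalone sketch of the timestamp normalization: naive values are assumed
    # to be UTC, aware values are converted to UTC, and the result is tz-naive.
    import pandas as pd

    def convert_all_to_utc_naive(series: pd.Series) -> pd.Series:
        def convert(ts):
            if ts.tz is None:
                ts = ts.tz_localize("UTC")
            return ts.tz_convert("UTC").tz_localize(None)
        return series.apply(convert)

    mixed = pd.Series([
        pd.Timestamp("2024-01-01 12:00:00"),                   # naive, treated as UTC
        pd.Timestamp("2024-01-01 12:00:00", tz="US/Eastern"),  # aware, shifted to UTC
    ])
    print(convert_all_to_utc_naive(mixed))
    # 0   2024-01-01 12:00:00
    # 1   2024-01-01 17:00:00
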
@@ -133,3 +133,17 @@ VALIDATION_RESULTS_JSON_FILE_NAME: Final[str] = "checkpoint_validation_results.j
  SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR: Final[
  str
  ] = "SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH"
+
+ BYTE_COLUMN_TYPE = "byte"
+ INTEGER_COLUMN_TYPE = "integer"
+ LONG_COLUMN_TYPE = "long"
+ SHORT_COLUMN_TYPE = "short"
+ PANDAS_FLOAT_TYPE = "float64"
+ PANDAS_LONG_TYPE = "Int64"
+ PANDAS_STRING_TYPE = "string"
+ INTEGER_TYPE_COLLECTION = [
+ BYTE_COLUMN_TYPE,
+ INTEGER_COLUMN_TYPE,
+ LONG_COLUMN_TYPE,
+ SHORT_COLUMN_TYPE,
+ ]
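
The new constants pair the Snowpark integer type names with the nullable Pandas dtypes that to_pandas targets. A small pandas-only sketch of the integer routing, using the values defined in this hunk:

    # Sketch of the dtype routing these constants enable; mirrors the integer
    # branch of to_pandas in snowpark_sampler.py.
    import pandas as pd

    INTEGER_TYPE_COLLECTION = ["byte", "integer", "long", "short"]
    PANDAS_LONG_TYPE = "Int64"

    snowpark_type_name = "short"    # e.g. field.datatype.typeName()
    raw = pd.Series([1, None, 3])   # contains a null, so a plain int64 cast would fail

    if snowpark_type_name in INTEGER_TYPE_COLLECTION:
        raw = raw.astype(PANDAS_LONG_TYPE).fillna(0)

    print(raw.dtype)                # Int64 (nullable), null replaced by 0
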
@@ -27,6 +27,9 @@ import numpy as np
  from pandera import DataFrameSchema

  from snowflake.snowpark import DataFrame as SnowparkDataFrame
+ from snowflake.snowpark import Session
+ from snowflake.snowpark.functions import col, expr
+ from snowflake.snowpark.types import TimestampType
  from snowflake.snowpark_checkpoints.errors import SchemaValidationError
  from snowflake.snowpark_checkpoints.io_utils.io_file_manager import get_io_file_manager
  from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
@@ -42,7 +45,6 @@ from snowflake.snowpark_checkpoints.utils.constants import (
  DATAFRAME_EXECUTION_MODE,
  DATAFRAME_PANDERA_SCHEMA_KEY,
  DEFAULT_KEY,
- EXCEPT_HASH_AGG_QUERY,
  FAIL_STATUS,
  PASS_STATUS,
  SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME,
@@ -120,14 +122,14 @@ def _process_sampling(
  pandera_schema_upper = pandera_schema
  new_columns: dict[Any, Any] = {}

- for col in pandera_schema.columns:
- new_columns[col.upper()] = pandera_schema.columns[col]
+ for column in pandera_schema.columns:
+ new_columns[column.upper()] = pandera_schema.columns[column]

  pandera_schema_upper = pandera_schema_upper.remove_columns(pandera_schema.columns)
  pandera_schema_upper = pandera_schema_upper.add_columns(new_columns)

  sample_df = sampler.get_sampled_pandas_args()[0]
- sample_df.index = np.ones(sample_df.count().iloc[0])
+ sample_df.index = np.ones(sample_df.count().iloc[0], dtype=int)

  return pandera_schema_upper, sample_df
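
The index assignment in _process_sampling now pins the synthetic index to integers; without dtype=int, np.ones returns floats and the index values become 1.0 instead of 1. A tiny before-and-after illustration:

    # Before/after sketch of the index fix in _process_sampling.
    import numpy as np
    import pandas as pd

    sample_df = pd.DataFrame({"A": [10, 20, 30]})
    n_rows = sample_df.count().iloc[0]            # 3

    sample_df.index = np.ones(n_rows)             # float index: 1.0, 1.0, 1.0
    print(sample_df.index.dtype)                  # float64

    sample_df.index = np.ones(n_rows, dtype=int)  # integer index: 1, 1, 1
    print(sample_df.index.dtype)                  # int64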
 
@@ -191,6 +193,7 @@ Please run the Snowpark checkpoint collector first."""
  schema_dict = checkpoint_schema_config.get(DATAFRAME_PANDERA_SCHEMA_KEY)
  schema_dict_str = json.dumps(schema_dict)
  schema = DataFrameSchema.from_json(schema_dict_str)
+ schema.coerce = False # Disable coercion to ensure strict validation

  if DATAFRAME_CUSTOM_DATA_KEY not in checkpoint_schema_config:
  LOGGER.info(
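
Loading the schema from JSON now explicitly disables coercion, so Pandera reports dtype mismatches instead of silently casting values to the expected type. The sketch below shows the effect with an illustrative schema and data that are not taken from the package.

    # Sketch: with coercion disabled, a dtype mismatch is reported rather than
    # silently cast away. Schema and data here are illustrative only.
    import pandas as pd
    import pandera as pa

    schema = pa.DataFrameSchema({"ID": pa.Column(int)})
    schema.coerce = False                     # same switch as in the hunk above

    df = pd.DataFrame({"ID": ["1", "2"]})     # strings, not ints

    try:
        schema.validate(df, lazy=True)
    except pa.errors.SchemaErrors as exc:
        print(exc.failure_cases)              # reports the wrong dtype for ID
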
@@ -270,6 +273,7 @@ def _compare_data(
  SchemaValidationError: If there is a data mismatch between the DataFrame and the checkpoint table.

  """
+ df = convert_timestamps_to_utc_date(df)
  new_table_name = CHECKPOINT_TABLE_NAME_FORMAT.format(checkpoint_name)
  LOGGER.info(
  "Writing Snowpark DataFrame to table: '%s' for checkpoint: '%s'",
@@ -283,12 +287,12 @@ def _compare_data(
  new_table_name,
  checkpoint_name,
  )
- expect_df = job_context.snowpark_session.sql(
- EXCEPT_HASH_AGG_QUERY, [checkpoint_name, new_table_name]
- )

- if expect_df.count() != 0:
- error_message = f"Data mismatch for checkpoint {checkpoint_name}"
+ session = job_context.snowpark_session
+ result = get_comparison_differences(session, checkpoint_name, new_table_name)
+ has_failed = result.get("spark_only_rows") or result.get("snowpark_only_rows")
+ if has_failed or result.get("error"):
+ error_message = f"Data mismatch for checkpoint {checkpoint_name}: {result}"
  job_context._mark_fail(
  error_message,
  checkpoint_name,
@@ -312,6 +316,80 @@ def _compare_data(
  return True, None


+ def get_comparison_differences(
+ session: Session, spark_table: str, snowpark_table: str
+ ) -> dict:
+ """Compare two tables and return the differences."""
+ try:
+ spark_raw_schema = session.table(spark_table).schema.names
+ snowpark_raw_schema = session.table(snowpark_table).schema.names
+
+ spark_normalized = {
+ col_name.strip('"').upper(): col_name for col_name in spark_raw_schema
+ }
+ snowpark_normalized = {
+ col_name.strip('"').upper(): col_name for col_name in snowpark_raw_schema
+ }
+
+ common_cols = sorted(
+ list(
+ set(spark_normalized.keys()).intersection(
+ set(snowpark_normalized.keys())
+ )
+ )
+ )
+
+ if not common_cols:
+ return {
+ "error": f"No common columns found between {spark_table} and {snowpark_table}",
+ }
+
+ cols_for_spark_selection = [
+ spark_normalized[norm_col_name] for norm_col_name in common_cols
+ ]
+ cols_for_snowpark_selection = [
+ snowpark_normalized[norm_col_name] for norm_col_name in common_cols
+ ]
+
+ spark_ordered = session.table(spark_table).select(
+ *[col(c) for c in cols_for_spark_selection]
+ )
+ snowpark_ordered = session.table(snowpark_table).select(
+ *[col(c) for c in cols_for_snowpark_selection]
+ )
+
+ spark_leftovers = spark_ordered.except_(snowpark_ordered).collect()
+ snowpark_leftovers = snowpark_ordered.except_(spark_ordered).collect()
+
+ spark_only_rows = [row.asDict() for row in spark_leftovers]
+ snowpark_only_rows = [row.asDict() for row in snowpark_leftovers]
+
+ return {
+ "spark_only_rows": spark_only_rows,
+ "snowpark_only_rows": snowpark_only_rows,
+ }
+
+ except Exception as e:
+ return {"error": f"An error occurred: {str(e)}"}
+
+
+ def convert_timestamps_to_utc_date(df):
+ """Convert and normalize all Snowpark timestamp columns to UTC.
+
+ This function ensures timestamps are consistent across environments for reliable comparison.
+ """
+ new_cols = []
+ for field in df.schema.fields:
+ if isinstance(field.datatype, TimestampType):
+ utc_midnight_ts = expr(
+ f"convert_timezone('UTC', cast(to_date({field.name}) as timestamp_tz))"
+ ).alias(field.name)
+ new_cols.append(utc_midnight_ts)
+ else:
+ new_cols.append(col(field.name))
+ return df.select(new_cols)
+
+
  def _find_frame_in(stack: list[inspect.FrameInfo]) -> tuple:
  """Find a specific frame in the provided stack trace.
 
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: snowpark-checkpoints-validators
- Version: 0.3.2
+ Version: 0.4.0
  Summary: Migration tools for Snowpark
  Project-URL: Bug Tracker, https://github.com/snowflakedb/snowpark-checkpoints/issues
  Project-URL: Source code, https://github.com/snowflakedb/snowpark-checkpoints/
@@ -1,10 +1,10 @@
  snowflake/snowpark_checkpoints/__init__.py,sha256=CfKakKzrSymSDP9zGSE2iK4RAHcHZSfL-zEG_8GnHnc,1509
- snowflake/snowpark_checkpoints/__version__.py,sha256=1W0aBeLTL5Svy-qrNkZc6gAKtQLDbncpMyN2SlnJhoU,632
- snowflake/snowpark_checkpoints/checkpoint.py,sha256=pU-HdpoS4SYzJU0qEaFzS5QBUE8K55Sn8K27zJe9_xM,24187
+ snowflake/snowpark_checkpoints/__version__.py,sha256=mZG_4eaVJdzo54iJo1tR3khnIA6lKjmN2lUgMoangNY,632
+ snowflake/snowpark_checkpoints/checkpoint.py,sha256=4IzS_wuONVQdxUnCixymb1HJr3eeiiGEzXfvi74I1Qc,25002
  snowflake/snowpark_checkpoints/errors.py,sha256=9KjzRf8bjDZTTNL4LeySJAwuucDOyz0Ka7EFBKWFpyg,1821
  snowflake/snowpark_checkpoints/job_context.py,sha256=RMK0g0HrbDVrOAvai4PgsGvsAn_GIo9aFmh-tWlyieY,4183
  snowflake/snowpark_checkpoints/singleton.py,sha256=7AgIHQBXVRvPBBCkmBplzkdrrm-xVWf_N8svzA2vF8E,836
- snowflake/snowpark_checkpoints/snowpark_sampler.py,sha256=Qxv-8nRGuf-ab3GoSUt8_MNL0ppjoBIMOFIMkqmwN5I,4668
+ snowflake/snowpark_checkpoints/snowpark_sampler.py,sha256=soew7FBnWqGp6VeBEFDakNbyjJD1imVJepGf6UmbFew,8426
  snowflake/snowpark_checkpoints/spark_migration.py,sha256=s2HqomYx76Hqn71g9TleBeHI3t1nirgfPvkggqQQdts,10253
  snowflake/snowpark_checkpoints/validation_result_metadata.py,sha256=5C8f1g-Grs2ydpXiZBLGt5n9cvEHBaw2-CDeb2vnhpg,5847
  snowflake/snowpark_checkpoints/validation_results.py,sha256=J8OcpNty6hQD8RbAy8xmA0UMbPWfXSmQnHYspWWSisk,1502
@@ -13,14 +13,14 @@ snowflake/snowpark_checkpoints/io_utils/io_default_strategy.py,sha256=VMfdqj4uDg
  snowflake/snowpark_checkpoints/io_utils/io_env_strategy.py,sha256=ltG_rxm0CkJFXpskOf__ByZw-C6B9LtycqlyB9EmaJI,3569
  snowflake/snowpark_checkpoints/io_utils/io_file_manager.py,sha256=YHrxRBzTlhIUrSFrsoWkRY_Qa-TXgDWglr00T98Tc5g,2485
  snowflake/snowpark_checkpoints/utils/__init__.py,sha256=I4srmZ8G1q9DU6Suo1S91aVfNvETyisKH95uvLAvEJ0,609
- snowflake/snowpark_checkpoints/utils/constants.py,sha256=M3vLdvKiVOhHMo0oPu4P42Wn_v6UDqmK6wHOGuoG6sY,4179
+ snowflake/snowpark_checkpoints/utils/constants.py,sha256=SscPXRhTKfT2moChXheMDJBs1A8YWKvjNuQkwV8FT38,4501
  snowflake/snowpark_checkpoints/utils/extra_config.py,sha256=xOYaG6MfsUCAHI0C_7qWF_m96xcLIZWwrgxY4UlpaZI,4325
  snowflake/snowpark_checkpoints/utils/logging_utils.py,sha256=yyi6X5DqKeTg0HRhvsH6ymYp2P0wbnyKIzI2RzrQS7k,2278
  snowflake/snowpark_checkpoints/utils/pandera_check_manager.py,sha256=tQIozLO-2kM8WZ-gGKfRwmXBx1cDPaIZB0qIcArp8xA,16100
  snowflake/snowpark_checkpoints/utils/supported_types.py,sha256=GrMX2tHdSFnK7LlPbZx20UufD6Br6TNVRkkBwIxdPy0,1433
  snowflake/snowpark_checkpoints/utils/telemetry.py,sha256=GfuyIaI3QG4a4_qWwyJHvWRM0GENunNexuEJ6IgscF4,32684
- snowflake/snowpark_checkpoints/utils/utils_checks.py,sha256=oQ1c4n-uAA2kFIpWIRPWhbCW8e-wwOIL8qDqLvr5Fok,14398
- snowpark_checkpoints_validators-0.3.2.dist-info/METADATA,sha256=COJncHytOF0_orQJPFUkPgcNKMaQWk5l-TYVb2nQBMg,12676
- snowpark_checkpoints_validators-0.3.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- snowpark_checkpoints_validators-0.3.2.dist-info/licenses/LICENSE,sha256=pmjhbh6uVhV5MBXOlou_UZgFP7CYVQITkCCdvfcS5lY,11340
- snowpark_checkpoints_validators-0.3.2.dist-info/RECORD,,
+ snowflake/snowpark_checkpoints/utils/utils_checks.py,sha256=5-EdkNnCjCYfwzdDLgVg0GykbsueXaGYhh5pOE1j0Z8,17325
+ snowpark_checkpoints_validators-0.4.0.dist-info/METADATA,sha256=LtNR7bV-MskVmJ-4CzqEWcFia5_wNT8cJV8JEbeHy5s,12676
+ snowpark_checkpoints_validators-0.4.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ snowpark_checkpoints_validators-0.4.0.dist-info/licenses/LICENSE,sha256=pmjhbh6uVhV5MBXOlou_UZgFP7CYVQITkCCdvfcS5lY,11340
+ snowpark_checkpoints_validators-0.4.0.dist-info/RECORD,,