snowpark-checkpoints-validators 0.3.2__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/PKG-INFO +1 -1
  2. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/pyproject.toml +1 -1
  3. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/__version__.py +1 -1
  4. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/checkpoint.py +24 -4
  5. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/snowpark_sampler.py +104 -3
  6. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/constants.py +14 -0
  7. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/utils_checks.py +87 -9
  8. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_df_check_custom_check_telemetry.json +1 -1
  9. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_df_check_from_file_telemetry.json +1 -1
  10. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_df_check_skip_check_telemetry.json +1 -1
  11. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_df_check_telemetry.json +1 -1
  12. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_input_telemetry.json +1 -1
  13. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_output_telemetry.json +1 -1
  14. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/test_pandera.py +7 -7
  15. snowpark_checkpoints_validators-0.4.0/test/unit/test_pandera_validations.py +130 -0
  16. snowpark_checkpoints_validators-0.4.0/test/unit/test_snowpark_sampler.py +117 -0
  17. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_utils_checks.py +18 -11
  18. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_validation_result_metadata.py +3 -3
  19. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/.gitignore +0 -0
  20. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/CHANGELOG.md +0 -0
  21. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/LICENSE +0 -0
  22. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/README.md +0 -0
  23. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/__init__.py +0 -0
  24. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/errors.py +0 -0
  25. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/io_utils/__init__.py +0 -0
  26. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/io_utils/io_default_strategy.py +0 -0
  27. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/io_utils/io_env_strategy.py +0 -0
  28. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/io_utils/io_file_manager.py +0 -0
  29. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/job_context.py +0 -0
  30. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/singleton.py +0 -0
  31. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/spark_migration.py +0 -0
  32. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/__init__.py +0 -0
  33. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/extra_config.py +0 -0
  34. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/logging_utils.py +0 -0
  35. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +0 -0
  36. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/supported_types.py +0 -0
  37. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/utils/telemetry.py +0 -0
  38. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/validation_result_metadata.py +0 -0
  39. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/src/snowflake/snowpark_checkpoints/validation_results.py +0 -0
  40. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/.coveragerc +0 -0
  41. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/e2eexample.py +0 -0
  42. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_compare_utils.py +0 -0
  43. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/df_mode_dataframe_mismatch_telemetry.json +0 -0
  44. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/df_mode_dataframe_telemetry.json +0 -0
  45. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/spark_checkpoint_df_fail_telemetry.json +0 -0
  46. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/spark_checkpoint_df_pass_telemetry.json +0 -0
  47. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/spark_checkpoint_limit_sample_telemetry.json +0 -0
  48. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/spark_checkpoint_random_sample_telemetry.json +0 -0
  49. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/spark_checkpoint_scalar_fail_telemetry.json +0 -0
  50. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/spark_checkpoint_scalar_passing_telemetry.json +0 -0
  51. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_df_check_fail_telemetry.json +0 -0
  52. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_input_fail_telemetry.json +0 -0
  53. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/telemetry_expected/test_output_fail_telemetry.json +0 -0
  54. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/test_parquet.py +0 -0
  55. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/integ/test_spark_checkpoint.py +0 -0
  56. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/io_utils/test_default_strategy.py +0 -0
  57. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_checkpoints.py +0 -0
  58. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_extra_config.py +0 -0
  59. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_job_context.py +0 -0
  60. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_logger.py +0 -0
  61. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_logging_utils.py +0 -0
  62. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_pandera_check_manager.py +0 -0
  63. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_spark_migration.py +0 -0
  64. {snowpark_checkpoints_validators-0.3.2 → snowpark_checkpoints_validators-0.4.0}/test/unit/test_telemetry.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: snowpark-checkpoints-validators
-Version: 0.3.2
+Version: 0.4.0
 Summary: Migration tools for Snowpark
 Project-URL: Bug Tracker, https://github.com/snowflakedb/snowpark-checkpoints/issues
 Project-URL: Source code, https://github.com/snowflakedb/snowpark-checkpoints/
pyproject.toml
@@ -127,7 +127,7 @@ exclude_lines = [

 [tool.hatch.envs.linter.scripts]
 check = [
-    'ruff check --fix .',
+    "echo 'Running linting checks...' && ruff check --config=../ruff.toml --statistics --verbose . || (echo '❌ LINTING FAILED: Please fix the above linting issues before proceeding. Use \"ruff check --config=../ruff.toml --fix .\" to auto-fix some issues, or fix them manually.' && exit 1)",
 ]

 [tool.hatch.envs.test.scripts]
src/snowflake/snowpark_checkpoints/__version__.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "0.3.2"
+__version__ = "0.4.0"
src/snowflake/snowpark_checkpoints/checkpoint.py
@@ -332,7 +332,7 @@ def _check_dataframe_schema(
     pandera_schema_upper, sample_df = _process_sampling(
         df, pandera_schema, job_context, sample_frac, sample_number, sampling_strategy
     )
-    is_valid, validation_result = _validate(pandera_schema_upper, sample_df)
+    is_valid, validation_result = validate(pandera_schema_upper, sample_df)
     if is_valid:
         LOGGER.info(
             "DataFrame schema validation passed for checkpoint '%s'",
@@ -438,7 +438,7 @@ def check_output_schema(
             sampler.process_args([snowpark_results])
             pandas_sample_args = sampler.get_sampled_pandas_args()

-            is_valid, validation_result = _validate(
+            is_valid, validation_result = validate(
                 pandera_schema, pandas_sample_args[0]
             )
             if is_valid:
@@ -565,7 +565,7 @@ def check_input_schema(
                     )
                     continue

-                is_valid, validation_result = _validate(
+                is_valid, validation_result = validate(
                     pandera_schema,
                     arg,
                 )
@@ -606,11 +606,31 @@ def check_input_schema(
     return check_input_with_decorator


-def _validate(
+def validate(
     schema: Union[type[DataFrameModel], DataFrameSchema],
     df: PandasDataFrame,
     lazy: bool = True,
 ) -> tuple[bool, PandasDataFrame]:
+    """Validate a Pandas DataFrame against a given Pandera schema.
+
+    Args:
+        schema (Union[type[DataFrameModel], DataFrameSchema]):
+            The schema to validate against. Can be a Pandera `DataFrameSchema` or
+            a `DataFrameModel` class.
+        df (PandasDataFrame):
+            The Pandas DataFrame to be validated.
+        lazy (bool, optional):
+            If `True`, collect all validation errors before raising an exception.
+            If `False`, raise an exception as soon as the first error is encountered.
+            Defaults to `True`.
+
+    Returns:
+        tuple[bool, PandasDataFrame]:
+            A tuple containing:
+            - A boolean indicating whether validation passed.
+            - The validated DataFrame if successful, or the failure cases DataFrame if not.
+
+    """
     if not isinstance(schema, DataFrameSchema):
         schema = schema.to_schema()
     is_valid = True
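The helper is now public as `validate`. A minimal usage sketch, assuming the 0.4.0 package is installed; the schema and data below are illustrative, not taken from the test suite:

import pandas as pd
from pandera import Check, Column, DataFrameSchema

from snowflake.snowpark_checkpoints.checkpoint import validate

# Illustrative schema: one integer column with a range check.
schema = DataFrameSchema({"COLUMN1": Column(int, Check.between(0, 10))})
df = pd.DataFrame({"COLUMN1": [1, 5, 9]})

# Returns (True, validated_df) on success; with lazy=True (the default),
# failures are collected and returned as (False, failure_cases_df).
is_valid, result = validate(schema, df)
assert is_valid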
src/snowflake/snowpark_checkpoints/snowpark_sampler.py
@@ -20,7 +20,21 @@ from typing import Optional
 import pandas

 from snowflake.snowpark import DataFrame as SnowparkDataFrame
+from snowflake.snowpark.types import (
+    BinaryType,
+    BooleanType,
+    DateType,
+    FloatType,
+    StringType,
+    TimestampType,
+)
 from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
+from snowflake.snowpark_checkpoints.utils.constants import (
+    INTEGER_TYPE_COLLECTION,
+    PANDAS_FLOAT_TYPE,
+    PANDAS_LONG_TYPE,
+    PANDAS_STRING_TYPE,
+)


 LOGGER = logging.getLogger(__name__)
@@ -73,17 +87,17 @@ class SamplingAdapter:
                         "Applying random sampling with fraction %s",
                         self.sample_frac,
                     )
-                    df_sample = arg.sample(frac=self.sample_frac).to_pandas()
+                    df_sample = to_pandas(arg.sample(frac=self.sample_frac))
                 else:
                     LOGGER.info(
                         "Applying random sampling with size %s", self.sample_number
                     )
-                    df_sample = arg.sample(n=self.sample_number).to_pandas()
+                    df_sample = to_pandas(arg.sample(n=self.sample_number))
             else:
                 LOGGER.info(
                     "Applying limit sampling with size %s", self.sample_number
                 )
-                df_sample = arg.limit(self.sample_number).to_pandas()
+                df_sample = to_pandas(arg.limit(self.sample_number))

             LOGGER.info(
                 "Successfully sampled the DataFrame. Resulting DataFrame shape: %s",
@@ -122,3 +136,90 @@ class SamplingAdapter:
             else:
                 pyspark_sample_args.append(arg)
         return pyspark_sample_args
+
+
+def to_pandas(sampled_df: SnowparkDataFrame) -> pandas.DataFrame:
+    """Convert a Snowpark DataFrame to a Pandas DataFrame, handling missing values and type conversions."""
+    LOGGER.debug("Converting Snowpark DataFrame to Pandas DataFrame")
+    pandas_df = sampled_df.toPandas()
+    for field in sampled_df.schema.fields:
+        is_snowpark_integer = field.datatype.typeName() in INTEGER_TYPE_COLLECTION
+        is_snowpark_string = isinstance(field.datatype, StringType)
+        is_snowpark_binary = isinstance(field.datatype, BinaryType)
+        is_snowpark_timestamp = isinstance(field.datatype, TimestampType)
+        is_snowpark_float = isinstance(field.datatype, FloatType)
+        is_snowpark_boolean = isinstance(field.datatype, BooleanType)
+        is_snowpark_date = isinstance(field.datatype, DateType)
+        if is_snowpark_integer:
+            LOGGER.debug(
+                "Converting Spark integer column '%s' to Pandas nullable '%s' type",
+                field.name,
+                PANDAS_LONG_TYPE,
+            )
+            pandas_df[field.name] = (
+                pandas_df[field.name].astype(PANDAS_LONG_TYPE).fillna(0)
+            )
+        elif is_snowpark_string or is_snowpark_binary:
+            LOGGER.debug(
+                "Converting Spark string column '%s' to Pandas nullable '%s' type",
+                field.name,
+                PANDAS_STRING_TYPE,
+            )
+            pandas_df[field.name] = (
+                pandas_df[field.name].astype(PANDAS_STRING_TYPE).fillna("")
+            )
+        elif is_snowpark_timestamp:
+            LOGGER.debug(
+                "Converting Spark timestamp column '%s' to UTC naive Pandas datetime",
+                field.name,
+            )
+            pandas_df[field.name] = convert_all_to_utc_naive(
+                pandas_df[field.name]
+            ).fillna(pandas.NaT)
+        elif is_snowpark_float:
+            LOGGER.debug(
+                "Converting Spark float column '%s' to Pandas nullable float",
+                field.name,
+            )
+            pandas_df[field.name] = (
+                pandas_df[field.name].astype(PANDAS_FLOAT_TYPE).fillna(0.0)
+            )
+        elif is_snowpark_boolean:
+            LOGGER.debug(
+                "Converting Spark boolean column '%s' to Pandas nullable boolean",
+                field.name,
+            )
+            pandas_df[field.name] = (
+                pandas_df[field.name].astype("boolean").fillna(False)
+            )
+        elif is_snowpark_date:
+            LOGGER.debug(
+                "Converting Spark date column '%s' to Pandas nullable datetime",
+                field.name,
+            )
+            pandas_df[field.name] = pandas_df[field.name].fillna(pandas.NaT)
+
+    return pandas_df
+
+
+def convert_all_to_utc_naive(series: pandas.Series) -> pandas.Series:
+    """Convert all timezone-aware or naive timestamps in a series to UTC naive.
+
+    Naive timestamps are assumed to be in UTC and localized accordingly.
+    Timezone-aware timestamps are converted to UTC and then made naive.
+
+    Args:
+        series (pandas.Series): A Pandas Series of `pd.Timestamp` objects,
+            either naive or timezone-aware.
+
+    Returns:
+        pandas.Series: A Series of UTC-normalized naive timestamps (`tzinfo=None`).
+
+    """
+
+    def convert(ts):
+        if ts.tz is None:
+            ts = ts.tz_localize("UTC")
+        return ts.tz_convert("UTC").tz_localize(None)
+
+    return series.apply(convert)
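A self-contained sketch of the normalization `convert_all_to_utc_naive` performs, runnable without a Snowflake connection; the timestamps are made up:

import pandas as pd

from snowflake.snowpark_checkpoints.snowpark_sampler import convert_all_to_utc_naive

series = pd.Series([
    pd.Timestamp("2024-01-01 10:00"),                   # naive, treated as UTC
    pd.Timestamp("2024-01-01 05:00", tz="US/Eastern"),  # aware, converted to UTC
])

# Both values normalize to the naive UTC timestamp 2024-01-01 10:00.
normalized = convert_all_to_utc_naive(series)
assert all(ts.tz is None for ts in normalized)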
src/snowflake/snowpark_checkpoints/utils/constants.py
@@ -133,3 +133,17 @@ VALIDATION_RESULTS_JSON_FILE_NAME: Final[str] = "checkpoint_validation_results.j
 SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR: Final[
     str
 ] = "SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH"
+
+BYTE_COLUMN_TYPE = "byte"
+INTEGER_COLUMN_TYPE = "integer"
+LONG_COLUMN_TYPE = "long"
+SHORT_COLUMN_TYPE = "short"
+PANDAS_FLOAT_TYPE = "float64"
+PANDAS_LONG_TYPE = "Int64"
+PANDAS_STRING_TYPE = "string"
+INTEGER_TYPE_COLLECTION = [
+    BYTE_COLUMN_TYPE,
+    INTEGER_COLUMN_TYPE,
+    LONG_COLUMN_TYPE,
+    SHORT_COLUMN_TYPE,
+]
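A quick pandas illustration of why PANDAS_LONG_TYPE is the nullable extension dtype "Int64" rather than numpy's "int64"; the values are illustrative:

import pandas as pd

s = pd.Series([1, None])  # NULLs arrive as NaN, so the dtype is float64

# s.astype("int64") would raise: numpy integers cannot hold NA values.
# The nullable "Int64" dtype can, which is what to_pandas() relies on
# before filling the gaps:
print(s.astype("Int64").fillna(0))  # dtype Int64, values [1, 0]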
src/snowflake/snowpark_checkpoints/utils/utils_checks.py
@@ -27,6 +27,9 @@ import numpy as np
 from pandera import DataFrameSchema

 from snowflake.snowpark import DataFrame as SnowparkDataFrame
+from snowflake.snowpark import Session
+from snowflake.snowpark.functions import col, expr
+from snowflake.snowpark.types import TimestampType
 from snowflake.snowpark_checkpoints.errors import SchemaValidationError
 from snowflake.snowpark_checkpoints.io_utils.io_file_manager import get_io_file_manager
 from snowflake.snowpark_checkpoints.job_context import SnowparkJobContext
@@ -42,7 +45,6 @@ from snowflake.snowpark_checkpoints.utils.constants import (
     DATAFRAME_EXECUTION_MODE,
     DATAFRAME_PANDERA_SCHEMA_KEY,
     DEFAULT_KEY,
-    EXCEPT_HASH_AGG_QUERY,
     FAIL_STATUS,
     PASS_STATUS,
     SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME,
@@ -120,14 +122,14 @@ def _process_sampling(
     pandera_schema_upper = pandera_schema
     new_columns: dict[Any, Any] = {}

-    for col in pandera_schema.columns:
-        new_columns[col.upper()] = pandera_schema.columns[col]
+    for column in pandera_schema.columns:
+        new_columns[column.upper()] = pandera_schema.columns[column]

     pandera_schema_upper = pandera_schema_upper.remove_columns(pandera_schema.columns)
     pandera_schema_upper = pandera_schema_upper.add_columns(new_columns)

     sample_df = sampler.get_sampled_pandas_args()[0]
-    sample_df.index = np.ones(sample_df.count().iloc[0])
+    sample_df.index = np.ones(sample_df.count().iloc[0], dtype=int)

     return pandera_schema_upper, sample_df
@@ -191,6 +193,7 @@ Please run the Snowpark checkpoint collector first."""
     schema_dict = checkpoint_schema_config.get(DATAFRAME_PANDERA_SCHEMA_KEY)
     schema_dict_str = json.dumps(schema_dict)
     schema = DataFrameSchema.from_json(schema_dict_str)
+    schema.coerce = False  # Disable coercion to ensure strict validation

     if DATAFRAME_CUSTOM_DATA_KEY not in checkpoint_schema_config:
         LOGGER.info(
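With `coerce=True`, pandera casts data to the schema dtype before checking, which can mask type drift; disabling it makes the checkpoint strict. A plain-pandera sketch with illustrative data:

import pandas as pd
from pandera import Column, DataFrameSchema

df = pd.DataFrame({"a": ["1", "2"]})  # strings, not integers

lenient = DataFrameSchema({"a": Column(int)}, coerce=True)
lenient.validate(df)  # passes: "1" and "2" are silently cast to int

strict = DataFrameSchema({"a": Column(int)}, coerce=False)
# strict.validate(df) raises a SchemaError for the dtype mismatch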
@@ -270,6 +273,7 @@ def _compare_data(
         SchemaValidationError: If there is a data mismatch between the DataFrame and the checkpoint table.

     """
+    df = convert_timestamps_to_utc_date(df)
     new_table_name = CHECKPOINT_TABLE_NAME_FORMAT.format(checkpoint_name)
     LOGGER.info(
         "Writing Snowpark DataFrame to table: '%s' for checkpoint: '%s'",
@@ -283,12 +287,12 @@ def _compare_data(
         new_table_name,
         checkpoint_name,
     )
-    expect_df = job_context.snowpark_session.sql(
-        EXCEPT_HASH_AGG_QUERY, [checkpoint_name, new_table_name]
-    )

-    if expect_df.count() != 0:
-        error_message = f"Data mismatch for checkpoint {checkpoint_name}"
+    session = job_context.snowpark_session
+    result = get_comparison_differences(session, checkpoint_name, new_table_name)
+    has_failed = result.get("spark_only_rows") or result.get("snowpark_only_rows")
+    if has_failed or result.get("error"):
+        error_message = f"Data mismatch for checkpoint {checkpoint_name}: {result}"
         job_context._mark_fail(
             error_message,
             checkpoint_name,
@@ -312,6 +316,80 @@ def _compare_data(
     return True, None


+def get_comparison_differences(
+    session: Session, spark_table: str, snowpark_table: str
+) -> dict:
+    """Compare two tables and return the differences."""
+    try:
+        spark_raw_schema = session.table(spark_table).schema.names
+        snowpark_raw_schema = session.table(snowpark_table).schema.names
+
+        spark_normalized = {
+            col_name.strip('"').upper(): col_name for col_name in spark_raw_schema
+        }
+        snowpark_normalized = {
+            col_name.strip('"').upper(): col_name for col_name in snowpark_raw_schema
+        }
+
+        common_cols = sorted(
+            list(
+                set(spark_normalized.keys()).intersection(
+                    set(snowpark_normalized.keys())
+                )
+            )
+        )
+
+        if not common_cols:
+            return {
+                "error": f"No common columns found between {spark_table} and {snowpark_table}",
+            }
+
+        cols_for_spark_selection = [
+            spark_normalized[norm_col_name] for norm_col_name in common_cols
+        ]
+        cols_for_snowpark_selection = [
+            snowpark_normalized[norm_col_name] for norm_col_name in common_cols
+        ]
+
+        spark_ordered = session.table(spark_table).select(
+            *[col(c) for c in cols_for_spark_selection]
+        )
+        snowpark_ordered = session.table(snowpark_table).select(
+            *[col(c) for c in cols_for_snowpark_selection]
+        )
+
+        spark_leftovers = spark_ordered.except_(snowpark_ordered).collect()
+        snowpark_leftovers = snowpark_ordered.except_(spark_ordered).collect()
+
+        spark_only_rows = [row.asDict() for row in spark_leftovers]
+        snowpark_only_rows = [row.asDict() for row in snowpark_leftovers]
+
+        return {
+            "spark_only_rows": spark_only_rows,
+            "snowpark_only_rows": snowpark_only_rows,
+        }
+
+    except Exception as e:
+        return {"error": f"An error occurred: {str(e)}"}
+
+
+def convert_timestamps_to_utc_date(df):
+    """Convert and normalize all Snowpark timestamp columns to UTC.
+
+    This function ensures timestamps are consistent across environments for reliable comparison.
+    """
+    new_cols = []
+    for field in df.schema.fields:
+        if isinstance(field.datatype, TimestampType):
+            utc_midnight_ts = expr(
+                f"convert_timezone('UTC', cast(to_date({field.name}) as timestamp_tz))"
+            ).alias(field.name)
+            new_cols.append(utc_midnight_ts)
+        else:
+            new_cols.append(col(field.name))
+    return df.select(new_cols)
+
+
 def _find_frame_in(stack: list[inspect.FrameInfo]) -> tuple:
     """Find a specific frame in the provided stack trace.
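The column-name normalization at the heart of `get_comparison_differences` can be exercised without a session; the quoted and unquoted column lists below are hypothetical:

# Hypothetical schema names as returned by session.table(...).schema.names.
spark_cols = ['"id"', 'NAME', '"Extra"']
snowpark_cols = ['ID', '"name"']

spark_normalized = {c.strip('"').upper(): c for c in spark_cols}
snowpark_normalized = {c.strip('"').upper(): c for c in snowpark_cols}

# Only columns present on both sides are compared, in a stable sorted order;
# the rows are then diffed in both directions with except_().
common = sorted(set(spark_normalized) & set(snowpark_normalized))
print(common)  # ['ID', 'NAME'] -- '"Extra"' is excluded from the comparison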
test/integ/telemetry_expected/test_df_check_custom_check_telemetry.json
@@ -1,6 +1,6 @@
 {
   "message": {
-    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"int8\", \"float64\"]}",
+    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"int64\", \"float64\"]}",
     "driver_type": "PythonConnector",
     "driver_version": "3.12.4",
     "event_name": "DataFrame_Validator_Schema",

test/integ/telemetry_expected/test_df_check_from_file_telemetry.json
@@ -1,6 +1,6 @@
 {
   "message": {
-    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"int8\", \"float64\"]}",
+    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"int64\", \"float64\"]}",
     "driver_type": "PythonConnector",
     "driver_version": "3.12.4",
     "event_name": "DataFrame_Validator_Schema",

test/integ/telemetry_expected/test_df_check_skip_check_telemetry.json
@@ -1,6 +1,6 @@
 {
   "message": {
-    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"int8\", \"float64\"]}",
+    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"int64\", \"float64\"]}",
     "driver_type": "PythonConnector",
     "driver_version": "3.12.4",
     "event_name": "DataFrame_Validator_Schema",

test/integ/telemetry_expected/test_df_check_telemetry.json
@@ -1,6 +1,6 @@
 {
   "message": {
-    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"int8\", \"float64\"]}",
+    "data": "{\"function\": \"_check_dataframe_schema\", \"mode\": 1, \"status\": true, \"schema_types\": [\"int64\", \"float64\"]}",
     "driver_type": "PythonConnector",
     "driver_version": "3.12.4",
     "event_name": "DataFrame_Validator_Schema",

test/integ/telemetry_expected/test_input_telemetry.json
@@ -1,6 +1,6 @@
 {
   "message": {
-    "data": "{\"function\": \"check_input_schema\", \"schema_types\": [\"int8\", \"float64\"]}",
+    "data": "{\"function\": \"check_input_schema\", \"schema_types\": [\"int64\", \"float64\"]}",
     "driver_type": "PythonConnector",
     "driver_version": "3.12.4",
     "event_name": "DataFrame_Validator_Schema",

test/integ/telemetry_expected/test_output_telemetry.json
@@ -1,6 +1,6 @@
 {
   "message": {
-    "data": "{\"function\": \"check_output_schema\", \"schema_types\": [\"int8\", \"float64\", \"float64\"]}",
+    "data": "{\"function\": \"check_output_schema\", \"schema_types\": [\"int64\", \"float64\", \"float64\"]}",
     "driver_type": "PythonConnector",
     "driver_version": "3.12.4",
     "event_name": "DataFrame_Validator_Schema",
test/integ/test_pandera.py
@@ -23,7 +23,7 @@ from unittest.mock import MagicMock, patch

 import pytest

-from numpy import int8
+from numpy import int8, int64
 from pandas import DataFrame as PandasDataFrame
 from pandera import Check, Column, DataFrameSchema
 from pytest import raises
@@ -78,7 +78,7 @@ def test_input(telemetry_output_path):

     in_schema = DataFrameSchema(
         {
-            "COLUMN1": Column(int8, Check(lambda x: 0 <= x <= 10, element_wise=True)),
+            "COLUMN1": Column(int64, Check(lambda x: 0 <= x <= 10, element_wise=True)),
             "COLUMN2": Column(float, Check(lambda x: x < -1.2, element_wise=True)),
         }
     )
@@ -161,7 +161,7 @@ def test_output(telemetry_output_path):
     out_schema = DataFrameSchema(
         {
             "COLUMN1": Column(
-                int8, Check.between(0, 10, include_max=True, include_min=True)
+                int64, Check.between(0, 10, include_max=True, include_min=True)
             ),
             "COLUMN2": Column(float, Check.less_than_or_equal_to(-1.2)),
             "COLUMN3": Column(float, Check.less_than(10)),
@@ -244,7 +244,7 @@ def test_df_check(telemetry_output_path):

     schema = DataFrameSchema(
         {
-            "COLUMN1": Column(int8, Check(lambda x: 0 <= x <= 10, element_wise=True)),
+            "COLUMN1": Column(int64, Check(lambda x: 0 <= x <= 10, element_wise=True)),
             "COLUMN2": Column(float, Check(lambda x: x < -1.2, element_wise=True)),
         }
     )
@@ -320,7 +320,7 @@ def test_df_check_from_file(telemetry_output_path):

     schema = DataFrameSchema(
         {
-            "COLUMN1": Column(int8, Check.between(0, 10)),
+            "COLUMN1": Column(int64, Check.between(0, 10)),
             "COLUMN2": Column(float, Check.between(-20.5, -1.0)),
         }
     )
@@ -409,7 +409,7 @@ def test_df_check_custom_check(telemetry_output_path):

     schema = DataFrameSchema(
         {
-            "COLUMN1": Column(int8, Check(lambda x: 0 <= x <= 10, element_wise=True)),
+            "COLUMN1": Column(int64, Check(lambda x: 0 <= x <= 10, element_wise=True)),
             "COLUMN2": Column(float, Check(lambda x: x < -1.2, element_wise=True)),
         }
     )
@@ -454,7 +454,7 @@ def test_df_check_skip_check(telemetry_output_path):

     schema = DataFrameSchema(
         {
-            "COLUMN1": Column(int8, Check.between(0, 10, element_wise=True)),
+            "COLUMN1": Column(int64, Check.between(0, 10, element_wise=True)),
             "COLUMN2": Column(
                 float,
                 [
test/unit/test_pandera_validations.py (new file)
@@ -0,0 +1,130 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from snowflake.snowpark_checkpoints.checkpoint import validate
+from pandera import DataFrameSchema, Column, Check
+import pandas as pd
+import pytz
+
+
+def test_pandera_validate_equivalent_dataframes():
+    schema = DataFrameSchema(
+        {
+            "a": Column(
+                int, checks=Check(lambda s: s > 0, element_wise=True), nullable=False
+            )
+        }
+    )
+    df = pd.DataFrame({"a": [1, 2, 3]})
+    result, validated_df = validate(schema, df)
+    assert result
+    pd.testing.assert_frame_equal(validated_df, df)
+
+
+def test_pandera_validate_object_vs_string():
+    schema = DataFrameSchema({"a": Column(str, nullable=False)})
+
+    df_object = pd.DataFrame({"a": pd.Series(["x", "y", "z"], dtype="object")})
+    result, validated_df = validate(schema, df_object)
+    assert result
+
+    df_int_as_string = pd.DataFrame({"a": ["1", "2", "3"]})
+    result, validated_df = validate(schema, df_int_as_string)
+    assert result
+
+    df_mixed = pd.DataFrame({"a": ["x", 1, "z"]})
+    result, validated_df = validate(schema, df_mixed)
+    assert not result
+
+
+def test_pandera_validate_int_vs_string():
+    schema = DataFrameSchema({"a": Column(int, nullable=False)})
+    df_valid_int = pd.DataFrame({"a": [1, 2, 3]})
+    result, _ = validate(schema, df_valid_int)
+    assert result
+
+    df_string_numbers = pd.DataFrame({"a": ["1", "2", "3"]})
+    result, failure_cases = validate(schema, df_string_numbers)
+    assert not result
+
+    df_mixed = pd.DataFrame({"a": [1, "2", 3]})
+    result, failure_cases = validate(schema, df_mixed)
+    assert not result
+
+
+def test_timestamp_ntz():
+    schema = DataFrameSchema({"ts": Column(pd.Timestamp, nullable=False)})
+
+    df = pd.DataFrame(
+        {
+            "ts": pd.to_datetime(
+                ["2024-01-01 10:00", "2024-01-02 11:00", "2024-01-03 12:00"]
+            )
+        }
+    )
+    result, validated_df = validate(schema, df)
+    assert result
+    assert validated_df["ts"].dt.tz is None
+
+
+def test_timestamp_utc_timezone():
+    schema = DataFrameSchema({"ts": Column(pd.Timestamp, nullable=False)})
+
+    df = pd.DataFrame(
+        {
+            "ts": pd.to_datetime(
+                [
+                    "2024-01-01 10:00+00:00",
+                    "2024-01-02 11:00+00:00",
+                    "2024-01-03 12:00+00:00",
+                ]
+            )
+        }
+    )
+
+    df["ts"] = df["ts"].dt.tz_convert("UTC").dt.tz_localize(None)
+
+    result, validated_df = validate(schema, df)
+    assert result
+    assert validated_df["ts"].dt.tz is None
+
+
+def convert_all_to_utc_naive(series: pd.Series) -> pd.Series:
+    def convert(ts):
+        if ts.tz is None:
+            ts = ts.tz_localize("UTC")
+        return ts.tz_convert("UTC").tz_localize(None)
+
+    return series.apply(convert)
+
+
+def test_timestamp_mixed_timezones_fails():
+    schema = DataFrameSchema({"ts": Column(pd.Timestamp, nullable=False)})
+    eastern = pytz.timezone("US/Eastern")
+    df = pd.DataFrame(
+        {
+            "ts": [
+                pd.Timestamp("2024-01-01 10:00"),
+                eastern.localize(pd.Timestamp("2024-01-02 11:00")),
+                pd.Timestamp("2024-01-03 12:00+00:00"),
+            ]
+        }
+    )
+
+    df["ts"] = convert_all_to_utc_naive(df["ts"])
+    result, validated_df = validate(schema, df)
+
+    assert result
+    assert validated_df["ts"].dt.tz is None
test/unit/test_snowpark_sampler.py (new file)
@@ -0,0 +1,117 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pandas as pd
+from snowflake.snowpark_checkpoints.snowpark_sampler import (
+    to_pandas,
+    convert_all_to_utc_naive,
+)
+from snowflake.snowpark.types import (
+    BinaryType,
+    FloatType,
+    StringType,
+    TimestampType,
+)
+from snowflake.snowpark_checkpoints.utils.constants import (
+    PANDAS_FLOAT_TYPE,
+    PANDAS_LONG_TYPE,
+    PANDAS_STRING_TYPE,
+)
+
+
+class DummyDataType:
+    def __init__(self, name):
+        self._name = name
+
+    def typeName(self):
+        return self._name
+
+
+class DummyField:
+    def __init__(self, name, datatype):
+        self.name = name
+        self.datatype = datatype
+        self.typeName = datatype
+
+
+class DummySchema:
+    def __init__(self, fields):
+        self.fields = fields
+
+
+class DummySnowparkDF:
+    def __init__(self, pandas_df, fields):
+        self._pandas_df = pandas_df
+        self.schema = DummySchema(fields)
+
+    def toPandas(self):
+        return self._pandas_df
+
+
+def test_to_pandas_integer_conversion():
+    df = pd.DataFrame({"int_col": [1, None]}, dtype="float")
+    fields = [DummyField("int_col", DummyDataType("integer"))]
+    sp_df = DummySnowparkDF(df, fields)
+
+    result = to_pandas(sp_df)
+    assert result["int_col"].dtype == PANDAS_LONG_TYPE
+    assert result["int_col"].iloc[1] == 0
+
+
+def test_to_pandas_string_and_binary_conversion():
+    df = pd.DataFrame({"str_col": ["a", None], "bin_col": ["b", None]})
+    fields = [
+        DummyField("str_col", StringType()),
+        DummyField("bin_col", BinaryType()),
+    ]
+    sp_df = DummySnowparkDF(df, fields)
+
+    result = to_pandas(sp_df)
+    assert result["str_col"].dtype == PANDAS_STRING_TYPE
+    assert result["bin_col"].dtype == PANDAS_STRING_TYPE
+
+
+def test_to_pandas_float_conversion():
+    df = pd.DataFrame({"float_col": [1.1, None]}, dtype="float")
+    fields = [DummyField("float_col", FloatType())]
+    sp_df = DummySnowparkDF(df, fields)
+
+    result = to_pandas(sp_df)
+    assert result["float_col"].dtype == PANDAS_FLOAT_TYPE
+
+
+def test_to_pandas_timestamp_conversion():
+    utc_ts = pd.Timestamp("2023-01-01 12:00:00", tz="UTC")
+    naive_ts = pd.Timestamp("2023-01-02 12:00:00")
+    df = pd.DataFrame({"ts_col": [utc_ts, naive_ts]})
+    fields = [DummyField("ts_col", TimestampType())]
+    sp_df = DummySnowparkDF(df, fields)
+
+    result = to_pandas(sp_df)
+    assert pd.api.types.is_datetime64_any_dtype(result["ts_col"])
+    assert result["ts_col"].iloc[0].tzinfo is None
+    assert result["ts_col"].iloc[1].tzinfo is None
+
+
+def test_convert_all_to_utc_naive_behavior():
+    utc_ts = pd.Timestamp("2024-01-01 10:00:00", tz="UTC")
+    naive_ts = pd.Timestamp("2024-01-01 12:00:00")
+    none_val = pd.NaT
+    series = pd.Series([utc_ts, naive_ts, none_val])
+
+    result = convert_all_to_utc_naive(series)
+    assert result[0].tzinfo is None
+    assert result[1].tzinfo is None
+    assert pd.isna(result[2])
test/unit/test_utils_checks.py
@@ -282,6 +282,7 @@ def test_compare_data_match():
     job_context = MagicMock(spec=SnowparkJobContext)
     session = MagicMock()
     job_context.snowpark_session = session
+    job_context.job_name = checkpoint_name

     # Mock session.sql to return an empty DataFrame (indicating no mismatch)
     session.sql.return_value.count.return_value = 0
@@ -289,7 +290,6 @@ def test_compare_data_match():
     checkpoint_name = "test_checkpoint"
     validation_status = PASS_STATUS
     output_path = "test_output_path/utils/"
-
     with (
         patch("os.getcwd", return_value="/mocked/path"),
         patch("os.path.exists", return_value=False),
@@ -298,6 +298,14 @@ def test_compare_data_match():
         patch(
             "snowflake.snowpark_checkpoints.utils.utils_checks._update_validation_result"
         ) as mock_update_validation_result,
+        patch(
+            "snowflake.snowpark_checkpoints.utils.utils_checks.convert_timestamps_to_utc_date",
+            return_value=df,
+        ),
+        patch(
+            "snowflake.snowpark_checkpoints.utils.utils_checks.get_comparison_differences",
+            return_value={},
+        ) as mock_get_comparison_differences,
     ):
         # Call the function
         _check_compare_data(df, job_context, checkpoint_name, output_path)
@@ -309,11 +317,7 @@ def test_compare_data_match():
     df.write.save_as_table.assert_called_once_with(
         table_name=new_checkpoint_name, mode=OVERWRITE_MODE
     )
-    calls = [
-        call(EXCEPT_HASH_AGG_QUERY, [checkpoint_name, new_checkpoint_name]),
-        call().count(),
-    ]
-    session.sql.assert_has_calls(calls)
+    mock_get_comparison_differences.assert_called_once()
     job_context._mark_pass.assert_called_once_with(
         checkpoint_name, DATAFRAME_EXECUTION_MODE
     )
@@ -344,6 +348,13 @@ def test_compare_data_mismatch():
         patch(
            "snowflake.snowpark_checkpoints.utils.utils_checks._update_validation_result"
         ) as mock_update_validation_result,
+        patch(
+            "snowflake.snowpark_checkpoints.utils.utils_checks.convert_timestamps_to_utc_date",
+            return_value=df,
+        ),
+        patch(
+            "snowflake.snowpark_checkpoints.utils.utils_checks.get_comparison_differences"
+        ) as mock_get_comparison_differences,
     ):
         # Call the function and expect a SchemaValidationError
         with raises(
@@ -359,11 +370,7 @@ def test_compare_data_mismatch():
     df.write.save_as_table.assert_called_once_with(
         table_name=new_checkpoint_name, mode=OVERWRITE_MODE
     )
-    calls = [
-        call(EXCEPT_HASH_AGG_QUERY, [checkpoint_name, new_checkpoint_name]),
-        call().count(),
-    ]
-    session.sql.assert_has_calls(calls)
+    mock_get_comparison_differences.assert_called_once()
     job_context._mark_fail.assert_called()
     job_context._mark_pass.assert_not_called()
test/unit/test_validation_result_metadata.py
@@ -16,7 +16,7 @@ from snowflake.snowpark_checkpoints.validation_results import (
 )
 from pandas import DataFrame as PandasDataFrame, testing as PandasTesting
 from pandera import DataFrameSchema, Column, Check
-from snowflake.snowpark_checkpoints.checkpoint import _validate
+from snowflake.snowpark_checkpoints.checkpoint import validate


 @fixture()
@@ -204,7 +204,7 @@ def test_clean_with_no_file():

 def test_validate_valid_schema(sample_data):
     df, valid_schema, _ = sample_data
-    is_valid, result = _validate(valid_schema, df)
+    is_valid, result = validate(valid_schema, df)
     assert is_valid
     assert isinstance(result, PandasDataFrame)
     PandasTesting.assert_frame_equal(result, df)
@@ -212,7 +212,7 @@ def test_validate_valid_schema(sample_data):

 def test_validate_invalid_schema(sample_data):
     df, _, invalid_schema = sample_data
-    is_valid, result = _validate(invalid_schema, df)
+    is_valid, result = validate(invalid_schema, df)
     assert not is_valid
     assert isinstance(result, PandasDataFrame)
     assert "failure_case" in result.columns