snowpark-checkpoints-collectors 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,10 +13,18 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ import logging
+
+
+ # Add a NullHandler to prevent logging messages from being output to
+ # sys.stderr if no logging configuration is provided.
+ logging.getLogger(__name__).addHandler(logging.NullHandler())
+
+ # ruff: noqa: E402
+
  __all__ = ["collect_dataframe_checkpoint", "CheckpointMode"]

+ from snowflake.snowpark_checkpoints_collector.collection_common import CheckpointMode
  from snowflake.snowpark_checkpoints_collector.summary_stats_collector import (
  collect_dataframe_checkpoint,
  )
-
- from snowflake.snowpark_checkpoints_collector.collection_common import CheckpointMode
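The hunk above attaches a `NullHandler` to the package logger, so the new log records added throughout 0.2.0 stay silent unless the host application configures logging itself. A minimal sketch of how a consumer could surface them (the format string and level are illustrative choices, not prescribed by the package):

```python
import logging

# Route the collector's records to the root handler; without explicit
# configuration the package-level NullHandler discards them.
logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s: %(message)s")
logging.getLogger("snowflake.snowpark_checkpoints_collector").setLevel(logging.DEBUG)
```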
@@ -13,4 +13,4 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- __version__ = "0.1.3"
+ __version__ = "0.2.0"
@@ -12,7 +12,9 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+
  import json
+ import logging

  from typing import Optional

@@ -24,6 +26,7 @@ from snowflake.snowpark_checkpoints_collector.utils import file_utils


  RESULTS_KEY = "results"
+ LOGGER = logging.getLogger(__name__)


  class CollectionPointResultManager(metaclass=Singleton):
@@ -49,6 +52,7 @@ class CollectionPointResultManager(metaclass=Singleton):

  """
  result_json = result.get_collection_result_data()
+ LOGGER.debug("Adding a new collection result: %s", result_json)
  self.result_collection.append(result_json)
  self._save_result()

@@ -65,5 +69,6 @@ class CollectionPointResultManager(metaclass=Singleton):

  def _save_result(self) -> None:
  result_collection_json = self.to_json()
+ LOGGER.info("Saving collection results to '%s'", self.output_file_path)
  with open(self.output_file_path, "w") as f:
  f.write(result_collection_json)
@@ -12,6 +12,9 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+
+ import logging
+
  from pyspark.sql import DataFrame as SparkDataFrame
  from pyspark.sql.types import StructField

@@ -53,6 +56,9 @@ from snowflake.snowpark_checkpoints_collector.column_collection.model import (
  )


+ LOGGER = logging.getLogger(__name__)
+
+
  def collector_register(cls):
  """Decorate a class with the collection type mechanism.

@@ -63,6 +69,7 @@ def collector_register(cls):
  The class to decorate.

  """
+ LOGGER.debug("Starting to register collectors from class %s", cls.__name__)
  cls._collectors = {}
  for method_name in dir(cls):
  method = getattr(cls, method_name)
@@ -70,6 +77,11 @@ def collector_register(cls):
  col_type_collection = method._column_type
  for col_type in col_type_collection:
  cls._collectors[col_type] = method_name
+ LOGGER.debug(
+ "Registered collector '%s' for column type '%s'",
+ method_name,
+ col_type,
+ )
  return cls


@@ -114,10 +126,21 @@ class ColumnCollectorManager:
  """
  clm_type = struct_field.dataType.typeName()
  if clm_type not in self._collectors:
+ LOGGER.debug(
+ "No collectors found for column '%s' of type '%s'. Skipping collection for this column.",
+ clm_name,
+ clm_type,
+ )
  return {}

  func_name = self._collectors[clm_type]
  func = getattr(self, func_name)
+ LOGGER.debug(
+ "Collecting custom data for column '%s' of type '%s' using collector method '%s'",
+ clm_name,
+ clm_type,
+ func_name,
+ )
  data = func(clm_name, struct_field, values)
  return data

@@ -12,6 +12,9 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+
+ import logging
+
  import pandas as pd

  from pandera import Check, Column
@@ -39,6 +42,9 @@ from snowflake.snowpark_checkpoints_collector.collection_common import (
  )


+ LOGGER = logging.getLogger(__name__)
+
+
  def collector_register(cls):
  """Decorate a class with the checks mechanism.

@@ -49,6 +55,7 @@ def collector_register(cls):
  The class to decorate.

  """
+ LOGGER.debug("Starting to register checks from class %s", cls.__name__)
  cls._collectors = {}
  for method_name in dir(cls):
  method = getattr(cls, method_name)
@@ -56,6 +63,9 @@ def collector_register(cls):
  col_type_collection = method._column_type
  for col_type in col_type_collection:
  cls._collectors[col_type] = method_name
+ LOGGER.debug(
+ "Registered check '%s' for column type '%s'", method_name, col_type
+ )
  return cls


@@ -101,10 +111,18 @@ class PanderaColumnChecksManager:

  """
  if clm_type not in self._collectors:
+ LOGGER.debug(
+ "No Pandera checks found for column '%s' of type '%s'. Skipping checks for this column.",
+ clm_name,
+ clm_type,
+ )
  return

  func_name = self._collectors[clm_type]
  func = getattr(self, func_name)
+ LOGGER.debug(
+ "Adding Pandera checks to column '%s' of type '%s'", clm_name, clm_type
+ )
  func(clm_name, pyspark_df, pandera_column)

  @column_register(BOOLEAN_COLUMN_TYPE)
@@ -12,7 +12,9 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+
  import glob
+ import logging
  import os.path
  import time

@@ -30,6 +32,7 @@ CREATE_STAGE_STATEMENT_FORMAT = "CREATE TEMP STAGE IF NOT EXISTS {}"
  REMOVE_STAGE_FOLDER_STATEMENT_FORMAT = "REMOVE {}"
  STAGE_PATH_FORMAT = "'@{}/{}'"
  PUT_FILE_IN_STAGE_STATEMENT_FORMAT = "PUT '{}' {} AUTO_COMPRESS=FALSE"
+ LOGGER = logging.getLogger(__name__)


  class SnowConnection:
@@ -41,14 +44,16 @@ class SnowConnection:

  """

- def __init__(self, session: Session = None) -> None:
+ def __init__(self, session: Optional[Session] = None) -> None:
  """Init SnowConnection.

  Args:
  session (Snowpark.Session): the Snowpark session.

  """
- self.session = session if session is not None else Session.builder.getOrCreate()
+ self.session = (
+ session if session is not None else self._create_snowpark_session()
+ )
  self.stage_id = int(time.time())

  def create_snowflake_table_from_local_parquet(
@@ -84,8 +89,8 @@ class SnowConnection:
  stage_name, stage_path, input_path, is_parquet_file
  )
  self.create_table_from_parquet(table_name, stage_directory_path)
-
  finally:
+ LOGGER.info("Removing stage folder %s", stage_directory_path)
  self.session.sql(
  REMOVE_STAGE_FOLDER_STATEMENT_FORMAT.format(stage_directory_path)
  ).collect()
@@ -98,6 +103,7 @@ class SnowConnection:

  """
  create_stage_statement = CREATE_STAGE_STATEMENT_FORMAT.format(stage_name)
+ LOGGER.info("Creating temporal stage '%s'", stage_name)
  self.session.sql(create_stage_statement).collect()

  def load_files_to_stage(
@@ -105,7 +111,7 @@ class SnowConnection:
  stage_name: str,
  folder_name: str,
  input_path: str,
- filter_func: Callable = None,
+ filter_func: Optional[Callable] = None,
  ) -> None:
  """Load files to a stage in Snowflake.

@@ -116,6 +122,7 @@ class SnowConnection:
  filter_func (Callable): the filter function to apply to the files.

  """
+ LOGGER.info("Starting to load files to '%s'", stage_name)
  input_path = (
  os.path.abspath(input_path)
  if not os.path.isabs(input_path)
@@ -126,16 +133,20 @@ class SnowConnection:
  return os.path.isfile(name) and (filter_func(name) if filter_func else True)

  target_dir = os.path.join(input_path, "**", "*")
+ LOGGER.debug("Searching for files in '%s'", input_path)
  files_collection = glob.glob(target_dir, recursive=True)

  files = [file for file in files_collection if filter_files(file)]
+ files_count = len(files)

- if len(files) == 0:
+ if files_count == 0:
  raise Exception(f"No files were found in the input directory: {input_path}")

+ LOGGER.debug("Found %s files in '%s'", files_count, input_path)
+
  for file in files:
  # if file is relative path, convert to absolute path
- # if absolute path, then try to resolve as some Win32 paths are not in LPN.
+ # if absolute path, then try to resolve as some Win32 paths are not in LPN.
  file_full_path = (
  str(os.path.abspath(file))
  if not os.path.isabs(file)
@@ -150,6 +161,7 @@ class SnowConnection:
  put_statement = PUT_FILE_IN_STAGE_STATEMENT_FORMAT.format(
  normalize_file_path, stage_file_path
  )
+ LOGGER.info("Loading file '%s' to %s", file_full_path, stage_file_path)
  self.session.sql(put_statement).collect()

  def create_table_from_parquet(
@@ -165,8 +177,25 @@ class SnowConnection:
  Exception: No parquet files were found in the stage

  """
- files = self.session.sql(f"LIST {stage_directory_path}").collect()
- if len(files) == 0:
- raise Exception("No parquet files were found in the stage.")
+ LOGGER.info("Starting to create table '%s' from parquet files", table_name)
+ parquet_files = self.session.sql(
+ f"LIST {stage_directory_path} PATTERN='.*{DOT_PARQUET_EXTENSION}'"
+ ).collect()
+ parquet_files_count = len(parquet_files)
+ if parquet_files_count == 0:
+ raise Exception(
+ f"No parquet files were found in the stage: {stage_directory_path}"
+ )
+
+ LOGGER.info(
+ "Reading %s parquet files from %s",
+ parquet_files_count,
+ stage_directory_path,
+ )
  dataframe = self.session.read.parquet(path=stage_directory_path)
+ LOGGER.info("Creating table '%s' from parquet files", table_name)
  dataframe.write.save_as_table(table_name=table_name, mode="overwrite")
+
+ def _create_snowpark_session(self) -> Session:
+ LOGGER.info("Creating a Snowpark session using the default connection")
+ return Session.builder.getOrCreate()
@@ -12,8 +12,10 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+
  import glob
  import json
+ import logging
  import os
  import shutil

@@ -64,9 +66,14 @@ from snowflake.snowpark_checkpoints_collector.utils.extra_config import (
  get_checkpoint_sample,
  is_checkpoint_enabled,
  )
+ from snowflake.snowpark_checkpoints_collector.utils.logging_utils import log
  from snowflake.snowpark_checkpoints_collector.utils.telemetry import report_telemetry


+ LOGGER = logging.getLogger(__name__)
+
+
+ @log
  def collect_dataframe_checkpoint(
  df: SparkDataFrame,
  checkpoint_name: str,
@@ -91,80 +98,90 @@ def collect_dataframe_checkpoint(
  Exception: Invalid checkpoint name. Checkpoint names must only contain alphanumeric characters and underscores.

  """
- try:
- normalized_checkpoint_name = checkpoint_name_utils.normalize_checkpoint_name(
- checkpoint_name
+ normalized_checkpoint_name = checkpoint_name_utils.normalize_checkpoint_name(
+ checkpoint_name
+ )
+ if normalized_checkpoint_name != checkpoint_name:
+ LOGGER.info(
+ "Checkpoint name '%s' was normalized to '%s'",
+ checkpoint_name,
+ normalized_checkpoint_name,
+ )
+ is_valid_checkpoint_name = checkpoint_name_utils.is_valid_checkpoint_name(
+ normalized_checkpoint_name
+ )
+ if not is_valid_checkpoint_name:
+ raise Exception(
+ f"Invalid checkpoint name: {normalized_checkpoint_name}. "
+ f"Checkpoint names must only contain alphanumeric characters and underscores."
  )
- is_valid_checkpoint_name = checkpoint_name_utils.is_valid_checkpoint_name(
- normalized_checkpoint_name
+ if not is_checkpoint_enabled(normalized_checkpoint_name):
+ LOGGER.info(
+ "Checkpoint '%s' is disabled. Skipping collection.",
+ normalized_checkpoint_name,
  )
- if not is_valid_checkpoint_name:
+ return
+
+ LOGGER.info("Starting to collect checkpoint '%s'", normalized_checkpoint_name)
+ LOGGER.debug("DataFrame size: %s rows", df.count())
+ LOGGER.debug("DataFrame schema: %s", df.schema)
+
+ collection_point_file_path = file_utils.get_collection_point_source_file_path()
+ collection_point_line_of_code = file_utils.get_collection_point_line_of_code()
+ collection_point_result = CollectionPointResult(
+ collection_point_file_path,
+ collection_point_line_of_code,
+ normalized_checkpoint_name,
+ )
+
+ try:
+ if _is_empty_dataframe_without_schema(df):
  raise Exception(
- f"Invalid checkpoint name: {checkpoint_name}. Checkpoint names must only contain alphanumeric "
- f"characters and underscores."
+ "It is not possible to collect an empty DataFrame without schema"
  )

- if is_checkpoint_enabled(normalized_checkpoint_name):
+ _mode = get_checkpoint_mode(normalized_checkpoint_name, mode)

- collection_point_file_path = (
- file_utils.get_collection_point_source_file_path()
- )
- collection_point_line_of_code = (
- file_utils.get_collection_point_line_of_code()
+ if _mode == CheckpointMode.SCHEMA:
+ column_type_dict = _get_spark_column_types(df)
+ _sample = get_checkpoint_sample(normalized_checkpoint_name, sample)
+ LOGGER.info(
+ "Collecting checkpoint in %s mode using sample value %s",
+ CheckpointMode.SCHEMA.name,
+ _sample,
  )
- collection_point_result = CollectionPointResult(
- collection_point_file_path,
- collection_point_line_of_code,
+ _collect_dataframe_checkpoint_mode_schema(
  normalized_checkpoint_name,
+ df,
+ _sample,
+ column_type_dict,
+ output_path,
  )
+ elif _mode == CheckpointMode.DATAFRAME:
+ LOGGER.info(
+ "Collecting checkpoint in %s mode", CheckpointMode.DATAFRAME.name
+ )
+ snow_connection = SnowConnection()
+ _collect_dataframe_checkpoint_mode_dataframe(
+ normalized_checkpoint_name, df, snow_connection, output_path
+ )
+ else:
+ raise Exception(f"Invalid mode value: {_mode}")

- try:
-
- _sample = get_checkpoint_sample(normalized_checkpoint_name, sample)
-
- if _is_empty_dataframe_without_schema(df):
- raise Exception(
- "It is not possible to collect an empty DataFrame without schema"
- )
-
- _mode = get_checkpoint_mode(normalized_checkpoint_name, mode)
-
- if _mode == CheckpointMode.SCHEMA:
- column_type_dict = _get_spark_column_types(df)
- _collect_dataframe_checkpoint_mode_schema(
- normalized_checkpoint_name,
- df,
- _sample,
- column_type_dict,
- output_path,
- )
-
- elif _mode == CheckpointMode.DATAFRAME:
- snow_connection = SnowConnection()
- _collect_dataframe_checkpoint_mode_dataframe(
- normalized_checkpoint_name, df, snow_connection, output_path
- )
-
- else:
- raise Exception("Invalid mode value.")
-
- collection_point_result.result = CollectionResult.PASS
-
- except Exception as err:
- collection_point_result.result = CollectionResult.FAIL
- error_message = str(err)
- raise Exception(error_message) from err
-
- finally:
- collection_point_result_manager = CollectionPointResultManager(
- output_path
- )
- collection_point_result_manager.add_result(collection_point_result)
+ collection_point_result.result = CollectionResult.PASS
+ LOGGER.info(
+ "Checkpoint '%s' collected successfully", normalized_checkpoint_name
+ )

  except Exception as err:
+ collection_point_result.result = CollectionResult.FAIL
  error_message = str(err)
  raise Exception(error_message) from err

+ finally:
+ collection_point_result_manager = CollectionPointResultManager(output_path)
+ collection_point_result_manager.add_result(collection_point_result)
+

  @report_telemetry(params_list=["column_type_dict"])
  def _collect_dataframe_checkpoint_mode_schema(
@@ -176,12 +193,19 @@ def _collect_dataframe_checkpoint_mode_schema(
  ) -> None:
  sampled_df = df.sample(sample)
  if sampled_df.isEmpty():
+ LOGGER.warning("Sampled DataFrame is empty. Collecting full DataFrame.")
  sampled_df = df
+
  pandas_df = _to_pandas(sampled_df)
  is_empty_df_with_object_column = _is_empty_dataframe_with_object_column(df)
- pandera_infer_schema = (
- pa.infer_schema(pandas_df) if not is_empty_df_with_object_column else {}
- )
+ if is_empty_df_with_object_column:
+ LOGGER.debug(
+ "DataFrame is empty with object column. Skipping Pandera schema inference."
+ )
+ pandera_infer_schema = {}
+ else:
+ LOGGER.debug("Inferring Pandera schema from DataFrame")
+ pandera_infer_schema = pa.infer_schema(pandas_df)

  column_name_collection = df.schema.names
  columns_to_remove_from_pandera_schema_collection = []
@@ -192,19 +216,20 @@
  for column_name in column_name_collection:
  struct_field_column = column_type_dict[column_name]
  column_type = struct_field_column.dataType.typeName()
+ LOGGER.info("Collecting column '%s' of type '%s'", column_name, column_type)
  pyspark_column = df.select(col(column_name))

- is_empty_column = (
- pyspark_column.dropna().isEmpty() and column_type is not NULL_COLUMN_TYPE
- )
  is_column_to_remove_from_pandera_schema = (
  _is_column_to_remove_from_pandera_schema(column_type)
  )
-
  if is_column_to_remove_from_pandera_schema:
  columns_to_remove_from_pandera_schema_collection.append(column_name)

+ is_empty_column = (
+ pyspark_column.dropna().isEmpty() and column_type is not NULL_COLUMN_TYPE
+ )
  if is_empty_column:
+ LOGGER.debug("Column '%s' is empty.", column_name)
  custom_data = column_collector_manager.collect_empty_custom_data(
  column_name, struct_field_column, pyspark_column
  )
@@ -280,6 +305,7 @@ def _get_pandera_infer_schema_as_dict(

  pandera_infer_schema_dict = json.loads(pandera_infer_schema.to_json())
  for column in columns_to_remove_collection:
+ LOGGER.debug("Removing column '%s' from Pandera schema", column)
  del pandera_infer_schema_dict[COLUMNS_KEY][column]

  return pandera_infer_schema_dict
@@ -293,6 +319,7 @@ def _generate_json_checkpoint_file(
  )
  output_directory_path = file_utils.get_output_directory_path(output_path)
  checkpoint_file_path = os.path.join(output_directory_path, checkpoint_file_name)
+ LOGGER.info("Writing DataFrame JSON schema file to '%s'", checkpoint_file_path)
  with open(checkpoint_file_path, "w") as f:
  f.write(dataframe_schema_contract)

@@ -339,14 +366,24 @@ def generate_parquet_for_spark_df(spark_df: SparkDataFrame, output_path: str) ->
  converted_df = spark_df.select(new_cols)

  if os.path.exists(output_path):
+ LOGGER.warning(
+ "Output directory '%s' already exists. Deleting it...", output_path
+ )
  shutil.rmtree(output_path)

+ LOGGER.info("Writing DataFrame to parquet files at '%s'", output_path)
  converted_df.write.parquet(output_path, mode="overwrite")

  target_dir = os.path.join(output_path, "**", f"*{DOT_PARQUET_EXTENSION}")
- files = glob.glob(target_dir, recursive=True)
- if len(files) == 0:
+ parquet_files = glob.glob(target_dir, recursive=True)
+ parquet_files_count = len(parquet_files)
+ if parquet_files_count == 0:
  raise Exception("No parquet files were generated.")
+ LOGGER.info(
+ "%s parquet files were written in '%s'",
+ parquet_files_count,
+ output_path,
+ )


  def _create_snowflake_table_from_parquet(
@@ -356,11 +393,17 @@ def _create_snowflake_table_from_parquet(


  def _to_pandas(sampled_df: SparkDataFrame) -> pandas.DataFrame:
+ LOGGER.debug("Converting Spark DataFrame to Pandas DataFrame")
  pandas_df = sampled_df.toPandas()
  for field in sampled_df.schema.fields:
  has_nan = pandas_df[field.name].isna().any()
  is_integer = field.dataType.typeName() in INTEGER_TYPE_COLLECTION
  if has_nan and is_integer:
+ LOGGER.debug(
+ "Converting column '%s' to '%s' type",
+ field.name,
+ PANDAS_LONG_TYPE,
+ )
  pandas_df[field.name] = pandas_df[field.name].astype(PANDAS_LONG_TYPE)

  return pandas_df
@@ -12,6 +12,8 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+
+ import logging
  import os

  from typing import Optional
@@ -22,6 +24,8 @@ from snowflake.snowpark_checkpoints_collector.collection_common import (
  )


+ LOGGER = logging.getLogger(__name__)
+
  # noinspection DuplicatedCode
  def _get_checkpoint_contract_file_path() -> str:
  return os.environ.get(SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR, os.getcwd())
@@ -35,10 +39,14 @@ def _get_metadata():
  )

  path = _get_checkpoint_contract_file_path()
+ LOGGER.debug("Loading checkpoint metadata from '%s'", path)
  metadata = CheckpointMetadata(path)
  return True, metadata

  except ImportError:
+ LOGGER.debug(
+ "snowpark-checkpoints-configuration is not installed. Cannot get a checkpoint metadata instance."
+ )
  return False, None


@@ -56,8 +64,7 @@ def is_checkpoint_enabled(checkpoint_name: str) -> bool:
  if enabled:
  config = metadata.get_checkpoint(checkpoint_name)
  return config.enabled
- else:
- return True
+ return True


  def get_checkpoint_sample(
@@ -0,0 +1,67 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+
+ from functools import wraps
+ from typing import Callable, Optional, TypeVar
+
+ from typing_extensions import ParamSpec
+
+
+ P = ParamSpec("P")
+ R = TypeVar("R")
+
+
+ def log(
+ _func: Optional[Callable[P, R]] = None,
+ *,
+ logger: Optional[logging.Logger] = None,
+ log_args: bool = True,
+ ) -> Callable[[Callable[P, R]], Callable[P, R]]:
+ """Log the function call and any exceptions that occur.
+
+ Args:
+ _func: The function to log.
+ logger: The logger to use for logging. If not provided, a logger will be created using the
+ function's module name.
+ log_args: Whether to log the arguments passed to the function.
+
+ Returns:
+ A decorator that logs the function call and any exceptions that occur.
+
+ """
+
+ def decorator(func: Callable[P, R]) -> Callable[P, R]:
+ @wraps(func)
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+ _logger = logging.getLogger(func.__module__) if logger is None else logger
+ if log_args:
+ args_repr = [repr(a) for a in args]
+ kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]
+ formatted_args = ", ".join([*args_repr, *kwargs_repr])
+ _logger.debug("%s called with args %s", func.__name__, formatted_args)
+ try:
+ return func(*args, **kwargs)
+ except Exception:
+ _logger.exception("An error occurred in %s", func.__name__)
+ raise
+
+ return wrapper
+
+ # Handle the case where the decorator is used without parentheses
+ if _func is None:
+ return decorator
+ return decorator(_func)
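The new `log` decorator above supports both the bare form and the parameterized form. A small illustrative sketch of how code could apply it; the `normalize` and `run_collection` functions and the `my_app.collectors` logger name are hypothetical, not part of the package:

```python
import logging

from snowflake.snowpark_checkpoints_collector.utils.logging_utils import log


@log  # bare form: records go to a logger named after the function's module
def normalize(name: str) -> str:
    return name.strip().lower()


@log(logger=logging.getLogger("my_app.collectors"), log_args=False)  # parameterized form
def run_collection(checkpoint_name: str) -> None:
    # Arguments are not logged here because log_args=False.
    ...
```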
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: snowpark-checkpoints-collectors
- Version: 0.1.3
+ Version: 0.2.0
  Summary: Snowpark column and table statistics collection
  Project-URL: Bug Tracker, https://github.com/snowflakedb/snowpark-checkpoints/issues
  Project-URL: Source code, https://github.com/snowflakedb/snowpark-checkpoints/
@@ -27,19 +27,21 @@ Classifier: Topic :: Software Development :: Libraries :: Application Frameworks
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
  Requires-Python: <3.12,>=3.9
  Requires-Dist: pandera[io]==0.20.4
- Requires-Dist: pyspark
  Requires-Dist: snowflake-connector-python
- Requires-Dist: snowflake-snowpark-python==1.26.0
+ Requires-Dist: snowflake-snowpark-python>=1.23.0
  Provides-Extra: development
  Requires-Dist: coverage>=7.6.7; extra == 'development'
  Requires-Dist: deepdiff>=8.0.0; extra == 'development'
  Requires-Dist: hatchling==1.25.0; extra == 'development'
  Requires-Dist: pre-commit>=4.0.1; extra == 'development'
  Requires-Dist: pyarrow>=18.0.0; extra == 'development'
+ Requires-Dist: pyspark>=3.5.0; extra == 'development'
  Requires-Dist: pytest-cov>=6.0.0; extra == 'development'
  Requires-Dist: pytest>=8.3.3; extra == 'development'
  Requires-Dist: setuptools>=70.0.0; extra == 'development'
  Requires-Dist: twine==5.1.1; extra == 'development'
+ Provides-Extra: pyspark
+ Requires-Dist: pyspark>=3.5.0; extra == 'pyspark'
  Description-Content-Type: text/markdown

  # snowpark-checkpoints-collectors
@@ -50,6 +52,18 @@ Description-Content-Type: text/markdown
  ---

  **snowpark-checkpoints-collector** package offers a function for extracting information from PySpark dataframes. We can then use that data to validate against the converted Snowpark dataframes to ensure that behavioral equivalence has been achieved.
+
+ ---
+ ## Install the library
+ ```bash
+ pip install snowpark-checkpoints-collectors
+ ```
+ This package requires PySpark to be installed in the same environment. If you do not have it, you can install PySpark alongside Snowpark Checkpoints by running the following command:
+ ```bash
+ pip install "snowpark-checkpoints-collectors[pyspark]"
+ ```
+ ---
+
  ## Features

  - Schema inference collected data mode (Schema): This is the default mode, which leverages Pandera schema inference to obtain the metadata and checks that will be evaluated for the specified dataframe. This mode also collects custom data from columns of the DataFrame based on the PySpark type.
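The README hunk above describes the package's single entry point for extracting information from a PySpark DataFrame. A hedged usage sketch of that entry point (the DataFrame, sample fraction, and output path are illustrative; parameter names are taken from the diffed signature and should be checked against the released API):

```python
from pyspark.sql import SparkSession

from snowflake.snowpark_checkpoints_collector import CheckpointMode, collect_dataframe_checkpoint

# Hypothetical PySpark DataFrame used only for illustration.
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])

# Collect a schema-mode checkpoint for the DataFrame.
collect_dataframe_checkpoint(
    df,
    checkpoint_name="demo_initial_df",
    sample=0.5,
    mode=CheckpointMode.SCHEMA,
    output_path="./checkpoints",
)
```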
@@ -1,13 +1,13 @@
- snowflake/snowpark_checkpoints_collector/__init__.py,sha256=yf_DmREHUwtC8y_boY8iaQC3qaKi1miEb5kytllrAaw,874
- snowflake/snowpark_checkpoints_collector/__version__.py,sha256=OfdAqrd8gnFI-pK7o_olRVrRKIWfQhQOoo_wR3u1s5s,632
+ snowflake/snowpark_checkpoints_collector/__init__.py,sha256=GIESlH2W6g_qdcnyRqw9yjsvEkt0aniFvGixKlF4K7A,1096
+ snowflake/snowpark_checkpoints_collector/__version__.py,sha256=ajnGza8ucK69-PA8wEbHmWZxDwd3bsTm74yMKiIWNHY,632
  snowflake/snowpark_checkpoints_collector/collection_common.py,sha256=ff5vYffrTRjoJXZQvVQBaOlegAUj_vXBbl1IZidz8Qo,4510
  snowflake/snowpark_checkpoints_collector/singleton.py,sha256=7AgIHQBXVRvPBBCkmBplzkdrrm-xVWf_N8svzA2vF8E,836
- snowflake/snowpark_checkpoints_collector/summary_stats_collector.py,sha256=cvG1C9rLyF4w3Fybr3o_cno6mEHbXsbU17D_y2RrNck,12823
+ snowflake/snowpark_checkpoints_collector/summary_stats_collector.py,sha256=_U-gfBjk2QU_dDyJPGKekfzuP1Stkx-FyTuZiecvt6M,14572
  snowflake/snowpark_checkpoints_collector/collection_result/model/__init__.py,sha256=jZzx29WzrjH7C_6ZsBGoe4PxbW_oM4uIjySS1axIM34,1000
  snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result.py,sha256=8xD9zGnFJ7Rz9RUXIys7JnV3kQD4mk8QwNOTxAihSjQ,2908
- snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py,sha256=4rFBPUdjjf-SuqEaz0_lxBv8szEWI6N1x48P6zDbqVw,2360
+ snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py,sha256=6XbjHiehEm_RN_9y2MRlr0MaSgk3cWTczwZEYqUHCpM,2565
  snowflake/snowpark_checkpoints_collector/column_collection/__init__.py,sha256=hpTh1V7hqBSHxNUqISwfxdz-NLD-7oZEMLXDUuRsoOU,783
- snowflake/snowpark_checkpoints_collector/column_collection/column_collector_manager.py,sha256=_8CjfN0Q6g0g_hkvx6zBMat0RNAqQ89xfkid0MPLsRE,8961
+ snowflake/snowpark_checkpoints_collector/column_collection/column_collector_manager.py,sha256=Vav_vbiipHFIAdHxeQG4ZK1BAmWTi_18hBnVeIeXFRs,9670
  snowflake/snowpark_checkpoints_collector/column_collection/model/__init__.py,sha256=d0WNMeayDyUKYFLLaVAMIC5Qt-DoWoWgOjj2ygJaHWA,2919
  snowflake/snowpark_checkpoints_collector/column_collection/model/array_column_collector.py,sha256=10ITldLcri_3LoQaqrZJMUwvpcgs5gQy3-BFKQB77EA,4268
  snowflake/snowpark_checkpoints_collector/column_collection/model/binary_column_collector.py,sha256=TuvKnwCIyoc3B9DfSeckGk6-bLLrDVDZdW8NDFkitMI,3255
@@ -25,14 +25,15 @@ snowflake/snowpark_checkpoints_collector/column_collection/model/struct_column_c
  snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_column_collector.py,sha256=FAWxRUX25ep2XhhagsBVuNmB3QUMA1xMfNTVkaHilbY,2572
  snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_ntz_column_collector.py,sha256=glUUnCLgTbGiPLpF2pSZ11KCgKSpHDRt5uhi1ZT9bxA,2578
  snowflake/snowpark_checkpoints_collector/column_pandera_checks/__init__.py,sha256=JNZPOYx8rUTONGz_d7xyfAvEC2_umHmGkJLoNSATLs4,793
- snowflake/snowpark_checkpoints_collector/column_pandera_checks/pandera_column_checks_manager.py,sha256=uugv4Pyq0wpYvJRFyQmJR1SvnXjlqBNHTLIDiTLTLhA,7311
+ snowflake/snowpark_checkpoints_collector/column_pandera_checks/pandera_column_checks_manager.py,sha256=X1Mm37DKt-WZ5AegvoUA3itU1nBUxvhBxvjO85QqcGE,7893
  snowflake/snowpark_checkpoints_collector/snow_connection_model/__init__.py,sha256=kLjZId-aGCljK7lF6yeEw-syEqeTOJDxdXfpv9YxvZA,755
- snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py,sha256=QH3kPQ5rHS9CV7f-djw0mhM7KT99cFNYXpjU6ADJHuo,6047
+ snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py,sha256=odKGTzc0xov8WOgJSR6WmVs0IT-f6O4YoaLqH6CbbFo,7263
  snowflake/snowpark_checkpoints_collector/utils/checkpoint_name_utils.py,sha256=WExQaZ4oL4otDCtM8kyGbf0Gn_v1a-tzM5j1p0wVDVg,1767
- snowflake/snowpark_checkpoints_collector/utils/extra_config.py,sha256=xkXFH1PIS0Mtzpu-LrcOKBjzCbptp2zWqgGN9X1P_A0,3393
+ snowflake/snowpark_checkpoints_collector/utils/extra_config.py,sha256=t8WakSiHA3sgnXxz0WXE7q2MG7czWlnSYB5XR9swIhs,3643
  snowflake/snowpark_checkpoints_collector/utils/file_utils.py,sha256=deetkhQZOB0GUxQJvUHw4Ridp_rNYiCqmK9li3uwBL0,4324
+ snowflake/snowpark_checkpoints_collector/utils/logging_utils.py,sha256=yyi6X5DqKeTg0HRhvsH6ymYp2P0wbnyKIzI2RzrQS7k,2278
  snowflake/snowpark_checkpoints_collector/utils/telemetry.py,sha256=7S0yFE3Zq96SEGmVuVbpYc_wtXIQUpL--6KfGoxwJcA,30837
- snowpark_checkpoints_collectors-0.1.3.dist-info/METADATA,sha256=gfG0BmaLZS39w6mhL2nQ5qP9XrAxTU4hBgst0iZTaCk,5559
- snowpark_checkpoints_collectors-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- snowpark_checkpoints_collectors-0.1.3.dist-info/licenses/LICENSE,sha256=DVQuDIgE45qn836wDaWnYhSdxoLXgpRRKH4RuTjpRZQ,10174
- snowpark_checkpoints_collectors-0.1.3.dist-info/RECORD,,
+ snowpark_checkpoints_collectors-0.2.0.dist-info/METADATA,sha256=LPo0O5OEDHGXHKla-KDJioKIX8bqwBPbgP6BS8ufnQA,6003
+ snowpark_checkpoints_collectors-0.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ snowpark_checkpoints_collectors-0.2.0.dist-info/licenses/LICENSE,sha256=DVQuDIgE45qn836wDaWnYhSdxoLXgpRRKH4RuTjpRZQ,10174
+ snowpark_checkpoints_collectors-0.2.0.dist-info/RECORD,,