snowpark-checkpoints-collectors 0.1.0rc2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_checkpoints_collector/__init__.py +22 -0
- snowflake/snowpark_checkpoints_collector/collection_common.py +160 -0
- snowflake/snowpark_checkpoints_collector/collection_result/model/__init__.py +24 -0
- snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result.py +91 -0
- snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py +69 -0
- snowflake/snowpark_checkpoints_collector/column_collection/__init__.py +22 -0
- snowflake/snowpark_checkpoints_collector/column_collection/column_collector_manager.py +253 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/__init__.py +75 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/array_column_collector.py +113 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/binary_column_collector.py +87 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/boolean_column_collector.py +71 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/column_collector_base.py +95 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/date_column_collector.py +74 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/day_time_interval_column_collector.py +67 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/decimal_column_collector.py +92 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/empty_column_collector.py +88 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/map_column_collector.py +120 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/null_column_collector.py +49 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/numeric_column_collector.py +108 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/string_column_collector.py +70 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/struct_column_collector.py +102 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_column_collector.py +75 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_ntz_column_collector.py +75 -0
- snowflake/snowpark_checkpoints_collector/column_pandera_checks/__init__.py +20 -0
- snowflake/snowpark_checkpoints_collector/column_pandera_checks/pandera_column_checks_manager.py +223 -0
- snowflake/snowpark_checkpoints_collector/singleton.py +23 -0
- snowflake/snowpark_checkpoints_collector/snow_connection_model/__init__.py +20 -0
- snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py +172 -0
- snowflake/snowpark_checkpoints_collector/summary_stats_collector.py +366 -0
- snowflake/snowpark_checkpoints_collector/utils/checkpoint_name_utils.py +53 -0
- snowflake/snowpark_checkpoints_collector/utils/extra_config.py +112 -0
- snowflake/snowpark_checkpoints_collector/utils/file_utils.py +132 -0
- snowflake/snowpark_checkpoints_collector/utils/telemetry.py +889 -0
- snowpark_checkpoints_collectors-0.1.1.dist-info/METADATA +143 -0
- snowpark_checkpoints_collectors-0.1.1.dist-info/RECORD +37 -0
- {snowpark_checkpoints_collectors-0.1.0rc2.dist-info → snowpark_checkpoints_collectors-0.1.1.dist-info}/licenses/LICENSE +0 -25
- snowpark_checkpoints_collectors-0.1.0rc2.dist-info/METADATA +0 -347
- snowpark_checkpoints_collectors-0.1.0rc2.dist-info/RECORD +0 -4
- {snowpark_checkpoints_collectors-0.1.0rc2.dist-info → snowpark_checkpoints_collectors-0.1.1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,366 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
import glob
|
16
|
+
import json
|
17
|
+
import os
|
18
|
+
import shutil
|
19
|
+
|
20
|
+
from typing import Optional
|
21
|
+
|
22
|
+
import pandas
|
23
|
+
import pandera as pa
|
24
|
+
|
25
|
+
from pyspark.sql import DataFrame as SparkDataFrame
|
26
|
+
from pyspark.sql.functions import col
|
27
|
+
from pyspark.sql.types import DoubleType as SparkDoubleType
|
28
|
+
from pyspark.sql.types import StringType as SparkStringType
|
29
|
+
from pyspark.sql.types import StructField
|
30
|
+
|
31
|
+
from snowflake.snowpark_checkpoints_collector.collection_common import (
|
32
|
+
CHECKPOINT_JSON_OUTPUT_FILE_NAME_FORMAT,
|
33
|
+
COLUMNS_KEY,
|
34
|
+
DATAFRAME_CUSTOM_DATA_KEY,
|
35
|
+
DATAFRAME_PANDERA_SCHEMA_KEY,
|
36
|
+
DECIMAL_COLUMN_TYPE,
|
37
|
+
DOT_PARQUET_EXTENSION,
|
38
|
+
INTEGER_TYPE_COLLECTION,
|
39
|
+
NULL_COLUMN_TYPE,
|
40
|
+
PANDAS_LONG_TYPE,
|
41
|
+
PANDAS_OBJECT_TYPE_COLLECTION,
|
42
|
+
CheckpointMode,
|
43
|
+
)
|
44
|
+
from snowflake.snowpark_checkpoints_collector.collection_result.model import (
|
45
|
+
CollectionPointResult,
|
46
|
+
CollectionPointResultManager,
|
47
|
+
CollectionResult,
|
48
|
+
)
|
49
|
+
from snowflake.snowpark_checkpoints_collector.column_collection import (
|
50
|
+
ColumnCollectorManager,
|
51
|
+
)
|
52
|
+
from snowflake.snowpark_checkpoints_collector.column_pandera_checks import (
|
53
|
+
PanderaColumnChecksManager,
|
54
|
+
)
|
55
|
+
from snowflake.snowpark_checkpoints_collector.snow_connection_model import (
|
56
|
+
SnowConnection,
|
57
|
+
)
|
58
|
+
from snowflake.snowpark_checkpoints_collector.utils import (
|
59
|
+
checkpoint_name_utils,
|
60
|
+
file_utils,
|
61
|
+
)
|
62
|
+
from snowflake.snowpark_checkpoints_collector.utils.extra_config import (
|
63
|
+
get_checkpoint_mode,
|
64
|
+
get_checkpoint_sample,
|
65
|
+
is_checkpoint_enabled,
|
66
|
+
)
|
67
|
+
from snowflake.snowpark_checkpoints_collector.utils.telemetry import report_telemetry
|
68
|
+
|
69
|
+
|
70
|
+
def collect_dataframe_checkpoint(
    df: SparkDataFrame,
    checkpoint_name: str,
    sample: Optional[float] = None,
    mode: Optional[CheckpointMode] = None,
    output_path: Optional[str] = None,
) -> None:
    """Collect a DataFrame checkpoint.

    The checkpoint name is normalized (spaces and hyphens become underscores)
    and validated. If the checkpoint is enabled, the DataFrame is collected
    in the resolved mode, and the outcome (PASS/FAIL) is recorded through
    ``CollectionPointResultManager`` whether or not collection succeeds.

    Args:
        df (SparkDataFrame): The input Spark DataFrame to collect.
        checkpoint_name (str): The name of the checkpoint.
        sample (float, optional): Fraction of DataFrame to sample for schema inference.
            Defaults to the checkpoint configuration value, or 1.0.
        mode (CheckpointMode): The mode used to execute the collection.
            Defaults to the checkpoint configuration value, or CheckpointMode.SCHEMA.
        output_path (str, optional): The output path to save the checkpoint.
            Defaults to the current working directory.

    Raises:
        Exception: Invalid mode value.
        Exception: Invalid checkpoint name. Checkpoint names must only contain alphanumeric characters and underscores.
        Exception: Any error raised while collecting is re-raised unchanged
            (original type and traceback preserved).

    """
    normalized_checkpoint_name = checkpoint_name_utils.normalize_checkpoint_name(
        checkpoint_name
    )
    if not checkpoint_name_utils.is_valid_checkpoint_name(normalized_checkpoint_name):
        raise Exception(
            f"Invalid checkpoint name: {checkpoint_name}. Checkpoint names must only contain alphanumeric "
            "characters and underscores."
        )

    if not is_checkpoint_enabled(normalized_checkpoint_name):
        return

    # These helpers inspect the call stack at a fixed depth, so they must be
    # called directly from this function (not from a nested helper).
    collection_point_file_path = file_utils.get_collection_point_source_file_path()
    collection_point_line_of_code = file_utils.get_collection_point_line_of_code()
    collection_point_result = CollectionPointResult(
        collection_point_file_path,
        collection_point_line_of_code,
        normalized_checkpoint_name,
    )

    try:
        _sample = get_checkpoint_sample(normalized_checkpoint_name, sample)

        if _is_empty_dataframe_without_schema(df):
            raise Exception(
                "It is not possible to collect an empty DataFrame without schema"
            )

        _mode = get_checkpoint_mode(normalized_checkpoint_name, mode)

        if _mode == CheckpointMode.SCHEMA:
            column_type_dict = _get_spark_column_types(df)
            _collect_dataframe_checkpoint_mode_schema(
                normalized_checkpoint_name,
                df,
                _sample,
                column_type_dict,
                output_path,
            )
        elif _mode == CheckpointMode.DATAFRAME:
            snow_connection = SnowConnection()
            _collect_dataframe_checkpoint_mode_dataframe(
                normalized_checkpoint_name, df, snow_connection, output_path
            )
        else:
            raise Exception("Invalid mode value.")

        collection_point_result.result = CollectionResult.PASS

    except Exception:
        collection_point_result.result = CollectionResult.FAIL
        # Re-raise as-is: wrapping in a bare Exception would hide the
        # original type from callers and duplicate the traceback.
        raise

    finally:
        # Record the outcome for both the success and the failure path.
        collection_point_result_manager = CollectionPointResultManager(output_path)
        collection_point_result_manager.add_result(collection_point_result)
|
167
|
+
|
168
|
+
|
169
|
+
@report_telemetry(params_list=["column_type_dict"])
def _collect_dataframe_checkpoint_mode_schema(
    checkpoint_name: str,
    df: SparkDataFrame,
    sample: float,
    column_type_dict: dict[str, StructField],
    output_path: Optional[str] = None,
) -> None:
    """Collect the schema contract of ``df`` (SCHEMA mode) and write it as JSON.

    The contract combines a pandera schema inferred from a pandas sample of
    ``df`` with per-column custom data gathered by ``ColumnCollectorManager``.

    Args:
        checkpoint_name: Normalized checkpoint name, used for the output file name.
        df: The Spark DataFrame to collect.
        sample: Fraction passed to ``df.sample``. If the sample comes back
            empty, the full DataFrame is used instead.
        column_type_dict: Mapping of column name to its ``StructField`` in ``df``.
        output_path: Base output directory; ``None`` means the current
            working directory.

    """
    sampled_df = df.sample(sample)
    if sampled_df.isEmpty():
        # A small DataFrame or low fraction can yield an empty sample; fall
        # back to the whole DataFrame so inference still sees rows.
        sampled_df = df
    pandas_df = _to_pandas(sampled_df)
    is_empty_df_with_object_column = _is_empty_dataframe_with_object_column(df)
    # In the empty-with-object-column case no pandera schema is inferred;
    # ``{}`` is a placeholder and the final pandera section is also ``{}``.
    pandera_infer_schema = (
        pa.infer_schema(pandas_df) if not is_empty_df_with_object_column else {}
    )

    column_name_collection = df.schema.names
    columns_to_remove_from_pandera_schema_collection = []
    column_custom_data_collection = []
    column_collector_manager = ColumnCollectorManager()
    column_pandera_checks_manager = PanderaColumnChecksManager()

    for column_name in column_name_collection:
        struct_field_column = column_type_dict[column_name]
        column_type = struct_field_column.dataType.typeName()
        pyspark_column = df.select(col(column_name))

        # Use ``!=`` (value equality): ``is not`` compares object identity,
        # which is not guaranteed for equal strings. The cheap type test runs
        # first so null-typed columns skip the Spark job.
        is_empty_column = (
            column_type != NULL_COLUMN_TYPE and pyspark_column.dropna().isEmpty()
        )
        is_column_to_remove_from_pandera_schema = (
            _is_column_to_remove_from_pandera_schema(column_type)
        )

        if is_column_to_remove_from_pandera_schema:
            columns_to_remove_from_pandera_schema_collection.append(column_name)

        if is_empty_column:
            custom_data = column_collector_manager.collect_empty_custom_data(
                column_name, struct_field_column, pyspark_column
            )
            column_custom_data_collection.append(custom_data)
            continue

        # Only touch pandera columns when a real schema was inferred; the
        # ``{}`` placeholder has no ``columns`` attribute.
        if not is_empty_df_with_object_column:
            pandera_column = pandera_infer_schema.columns[column_name]
            pandera_column.checks = []
            column_pandera_checks_manager.add_checks_column(
                column_name, column_type, df, pandera_column
            )

        custom_data = column_collector_manager.collect_column(
            column_name, struct_field_column, pyspark_column
        )
        column_custom_data_collection.append(custom_data)

    pandera_infer_schema_dict = _get_pandera_infer_schema_as_dict(
        pandera_infer_schema,
        is_empty_df_with_object_column,
        columns_to_remove_from_pandera_schema_collection,
    )

    dataframe_custom_column_data = {COLUMNS_KEY: column_custom_data_collection}
    dataframe_schema_contract = {
        DATAFRAME_PANDERA_SCHEMA_KEY: pandera_infer_schema_dict,
        DATAFRAME_CUSTOM_DATA_KEY: dataframe_custom_column_data,
    }

    dataframe_schema_contract_json = json.dumps(dataframe_schema_contract)
    _generate_json_checkpoint_file(
        checkpoint_name, dataframe_schema_contract_json, output_path
    )
|
241
|
+
|
242
|
+
|
243
|
+
def _get_spark_column_types(df: SparkDataFrame) -> dict[str, StructField]:
    """Index the top-level fields of ``df``'s schema by column name.

    Args:
        df: The Spark DataFrame whose schema is read.

    Returns:
        dict[str, StructField]: Column name mapped to its ``StructField``.
        For duplicate names, the last field wins.

    """
    return {struct_field.name: struct_field for struct_field in df.schema.fields}
|
250
|
+
|
251
|
+
|
252
|
+
def _is_empty_dataframe_without_schema(df: SparkDataFrame) -> bool:
    """Tell whether ``df`` has neither columns nor rows.

    Args:
        df: The Spark DataFrame to inspect.

    Returns:
        bool: True when the schema has no fields and the DataFrame is empty.

    """
    if df.schema.fields:
        # Any column means a schema exists: the answer is already False, so
        # skip ``isEmpty()``, which has to evaluate the DataFrame.
        return False
    return df.isEmpty()
|
256
|
+
|
257
|
+
|
258
|
+
def _is_empty_dataframe_with_object_column(df: SparkDataFrame) -> bool:
    """Tell whether ``df`` is empty and has a column of a pandas object type.

    A column counts as object-typed when its Spark ``typeName()`` is in
    ``PANDAS_OBJECT_TYPE_COLLECTION``.

    Args:
        df: The Spark DataFrame to inspect.

    Returns:
        bool: True when ``df`` has no rows and at least one object-typed column.

    """
    has_object_column = any(
        field.dataType.typeName() in PANDAS_OBJECT_TYPE_COLLECTION
        for field in df.schema.fields
    )
    # Scan the schema first (cheap); only evaluate the DataFrame with
    # ``isEmpty()`` when that can change the answer.
    return has_object_column and df.isEmpty()
|
268
|
+
|
269
|
+
|
270
|
+
def _is_column_to_remove_from_pandera_schema(column_type) -> bool:
    """Tell whether a column must be dropped from the inferred pandera schema.

    Currently only decimal columns (``DECIMAL_COLUMN_TYPE``) are dropped.

    Args:
        column_type: The Spark ``typeName()`` of the column.

    Returns:
        bool: True for decimal columns, False otherwise.

    """
    return DECIMAL_COLUMN_TYPE == column_type
|
273
|
+
|
274
|
+
|
275
|
+
def _get_pandera_infer_schema_as_dict(
    pandera_infer_schema, is_empty_df_with_string_column, columns_to_remove_collection
) -> dict[str, object]:
    """Serialize the inferred pandera schema to a dict, minus excluded columns.

    Args:
        pandera_infer_schema: The pandera schema inferred from the sample
            (unused when ``is_empty_df_with_string_column`` is True).
        is_empty_df_with_string_column: True when the DataFrame is empty and
            has an object-typed column; no pandera schema is emitted then.
        columns_to_remove_collection: Column names to drop from the
            serialized schema's ``COLUMNS_KEY`` section.

    Returns:
        dict[str, object]: The JSON-decoded schema, or ``{}`` in the
        empty-with-object-column case.

    """
    if is_empty_df_with_string_column:
        return {}

    pandera_infer_schema_dict = json.loads(pandera_infer_schema.to_json())
    schema_columns = pandera_infer_schema_dict[COLUMNS_KEY]
    for column in columns_to_remove_collection:
        # Tolerate names absent from the serialized schema instead of raising KeyError.
        schema_columns.pop(column, None)

    return pandera_infer_schema_dict
|
286
|
+
|
287
|
+
|
288
|
+
def _generate_json_checkpoint_file(
    checkpoint_name, dataframe_schema_contract, output_path: Optional[str] = None
) -> None:
    """Write an already-serialized schema contract to its checkpoint file.

    The file name comes from ``CHECKPOINT_JSON_OUTPUT_FILE_NAME_FORMAT``. The
    file is placed in the checkpoints output directory (created if missing)
    and any existing file is overwritten.

    Args:
        checkpoint_name: Name substituted into the file-name format.
        dataframe_schema_contract: The JSON text to write.
        output_path: Base output directory; ``None`` means the current
            working directory.

    """
    json_file_name = CHECKPOINT_JSON_OUTPUT_FILE_NAME_FORMAT.format(checkpoint_name)
    json_file_path = os.path.join(
        file_utils.get_output_directory_path(output_path), json_file_name
    )
    with open(json_file_path, "w") as json_file:
        json_file.write(dataframe_schema_contract)
|
298
|
+
|
299
|
+
|
300
|
+
@report_telemetry(params_list=["df"])
def _collect_dataframe_checkpoint_mode_dataframe(
    checkpoint_name: str,
    df: SparkDataFrame,
    snow_connection: SnowConnection,
    output_path: Optional[str] = None,
) -> None:
    """Collect ``df`` in DATAFRAME mode.

    ``df`` is written as parquet into a ``<checkpoints dir>/<checkpoint_name>``
    directory, and that directory is then handed to ``snow_connection`` to
    create a Snowflake table named after the checkpoint.

    Args:
        checkpoint_name: Normalized checkpoint name (directory and table name).
        df: The Spark DataFrame to collect.
        snow_connection: Connection used to create the Snowflake table.
        output_path: Base output directory; ``None`` means the current
            working directory.

    """
    checkpoints_directory = file_utils.get_output_directory_path(output_path)
    checkpoint_parquet_directory = os.path.join(checkpoints_directory, checkpoint_name)
    generate_parquet_for_spark_df(df, checkpoint_parquet_directory)
    _create_snowflake_table_from_parquet(
        checkpoint_name, checkpoint_parquet_directory, snow_connection
    )
|
313
|
+
|
314
|
+
|
315
|
+
def generate_parquet_for_spark_df(spark_df: SparkDataFrame, output_path: str) -> None:
    """Write a Spark DataFrame as parquet files under ``output_path``.

    Columns whose dtype is ``float`` are cast to string and then to double
    before writing, to avoid precision problems: Spark parquet uses IEEE
    32-bit floating point values, while Snowflake uses IEEE 64-bit floating
    point values. Other columns are written unchanged.

    Any existing ``output_path`` directory is removed before writing.

    Args:
        spark_df: DataFrame to be saved as parquet.
        output_path: Directory that receives the parquet files.

    Raises:
        Exception: No parquet files were generated.

    """
    projected_columns = []
    for column_name, column_dtype in spark_df.dtypes:
        projected = col(column_name)
        if column_dtype == "float":
            # The intermediate string cast keeps the float's printed
            # representation instead of widening its binary value.
            projected = (
                projected.cast(SparkStringType())
                .cast(SparkDoubleType())
                .alias(column_name)
            )
        projected_columns.append(projected)
    converted_df = spark_df.select(projected_columns)

    if os.path.exists(output_path):
        shutil.rmtree(output_path)

    converted_df.write.parquet(output_path, mode="overwrite")

    parquet_glob_pattern = os.path.join(output_path, "**", f"*{DOT_PARQUET_EXTENSION}")
    written_parquet_files = glob.glob(parquet_glob_pattern, recursive=True)
    if not written_parquet_files:
        raise Exception("No parquet files were generated.")
|
350
|
+
|
351
|
+
|
352
|
+
def _create_snowflake_table_from_parquet(
    table_name: str, input_path: str, snow_connection: SnowConnection
) -> None:
    """Create the Snowflake table ``table_name`` from local parquet files.

    This is a thin delegation to
    ``snow_connection.create_snowflake_table_from_local_parquet``.

    Args:
        table_name: Name of the table to create.
        input_path: Local directory holding the parquet files.
        snow_connection: Connection that performs the operation.

    """
    create_table_from_parquet = snow_connection.create_snowflake_table_from_local_parquet
    create_table_from_parquet(table_name, input_path)
|
356
|
+
|
357
|
+
|
358
|
+
def _to_pandas(sampled_df: SparkDataFrame) -> pandas.DataFrame:
    """Convert a Spark DataFrame to pandas, re-typing integer columns that hold nulls.

    An integer-typed Spark column (``typeName()`` in ``INTEGER_TYPE_COLLECTION``)
    that contains missing values in the pandas result is cast to
    ``PANDAS_LONG_TYPE``. (Presumably a nullable integer dtype, so the column
    does not stay float — TODO confirm against collection_common.)

    Args:
        sampled_df: The (possibly sampled) Spark DataFrame to convert.

    Returns:
        pandas.DataFrame: The converted frame.

    """
    converted = sampled_df.toPandas()
    for spark_field in sampled_df.schema.fields:
        column_name = spark_field.name
        column_has_nulls = converted[column_name].isna().any()
        if column_has_nulls and spark_field.dataType.typeName() in INTEGER_TYPE_COLLECTION:
            converted[column_name] = converted[column_name].astype(PANDAS_LONG_TYPE)
    return converted
|
@@ -0,0 +1,53 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import re as regx
|
17
|
+
|
18
|
+
|
19
|
+
# A valid checkpoint name starts with a letter or underscore and is followed by
# any combination (including none) of letters, digits and underscores. The
# quantifier is ``*`` so single-character names such as "a" are accepted.
CHECKPOINT_NAME_REGEX_PATTERN = r"[a-zA-Z_][a-zA-Z0-9_]*"
# Normalization map: spaces and hyphens are replaced by underscores.
TRANSLATION_TABLE = str.maketrans({" ": "_", "-": "_"})
|
21
|
+
|
22
|
+
|
23
|
+
def normalize_checkpoint_name(checkpoint_name: str) -> str:
    """Replace every space and hyphen in a checkpoint name with an underscore.

    Args:
        checkpoint_name (str): The checkpoint name to normalize.

    Returns:
        str: The checkpoint name with spaces and hyphens turned into underscores.

    """
    return checkpoint_name.translate(TRANSLATION_TABLE)
|
35
|
+
|
36
|
+
|
37
|
+
def is_valid_checkpoint_name(checkpoint_name: str) -> bool:
    """Validate a checkpoint name against ``CHECKPOINT_NAME_REGEX_PATTERN``.

    The whole name must match: a letter (a-z, A-Z) or an underscore (_) first,
    followed by letters, digits (0-9) and underscores (_).

    Args:
        checkpoint_name (str): The checkpoint name to validate.

    Returns:
        bool: True if the checkpoint name is valid; otherwise, False.

    """
    return regx.fullmatch(CHECKPOINT_NAME_REGEX_PATTERN, checkpoint_name) is not None
|
@@ -0,0 +1,112 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
import os
|
16
|
+
|
17
|
+
from typing import Optional
|
18
|
+
|
19
|
+
from snowflake.snowpark_checkpoints_collector.collection_common import (
|
20
|
+
SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR,
|
21
|
+
CheckpointMode,
|
22
|
+
)
|
23
|
+
|
24
|
+
|
25
|
+
# noinspection DuplicatedCode
def _get_checkpoint_contract_file_path() -> str:
    """Return the checkpoint contract path.

    The value of the ``SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR``
    environment variable is used when set; otherwise the current working
    directory is returned.
    """
    current_directory = os.getcwd()
    return os.environ.get(
        SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR, current_directory
    )
|
28
|
+
|
29
|
+
|
30
|
+
# noinspection DuplicatedCode
|
31
|
+
def _get_metadata():
|
32
|
+
try:
|
33
|
+
from snowflake.snowpark_checkpoints_configuration.checkpoint_metadata import (
|
34
|
+
CheckpointMetadata,
|
35
|
+
)
|
36
|
+
|
37
|
+
path = _get_checkpoint_contract_file_path()
|
38
|
+
metadata = CheckpointMetadata(path)
|
39
|
+
return True, metadata
|
40
|
+
|
41
|
+
except ImportError:
|
42
|
+
return False, None
|
43
|
+
|
44
|
+
|
45
|
+
def is_checkpoint_enabled(checkpoint_name: str) -> bool:
    """Report whether a checkpoint should be collected.

    Without the configuration package every checkpoint is enabled; with it,
    the checkpoint's configured ``enabled`` flag decides.

    Args:
        checkpoint_name (str): The name of the checkpoint.

    Returns:
        bool: True if the checkpoint is enabled, False otherwise.

    """
    has_metadata, metadata = _get_metadata()
    if not has_metadata:
        return True
    return metadata.get_checkpoint(checkpoint_name).enabled
|
61
|
+
|
62
|
+
|
63
|
+
def get_checkpoint_sample(
    checkpoint_name: str, sample: Optional[float] = None
) -> float:
    """Resolve the sampling fraction for a checkpoint.

    Precedence: the ``sample`` argument, then the sample from the checkpoint
    configuration (when the configuration package is available), then 1.0.

    Args:
        checkpoint_name (str): The name of the checkpoint.
        sample (float, optional): The value passed to the function.

    Returns:
        float: The sample to use for that specific checkpoint.

    """
    has_metadata, metadata = _get_metadata()
    configured_sample = None
    if has_metadata:
        configured_sample = metadata.get_checkpoint(checkpoint_name).sample

    if sample is not None:
        return sample
    if configured_sample is not None:
        return configured_sample
    return 1.0
|
87
|
+
|
88
|
+
|
89
|
+
def get_checkpoint_mode(
    checkpoint_name: str, mode: Optional[CheckpointMode] = None
) -> CheckpointMode:
    """Resolve the collection mode for a checkpoint.

    Precedence: the ``mode`` argument, then the mode from the checkpoint
    configuration (when the configuration package is available), then
    ``CheckpointMode.SCHEMA``.

    Args:
        checkpoint_name (str): The name of the checkpoint.
        mode (CheckpointMode, optional): The value passed to the function.

    Returns:
        CheckpointMode: The mode to use for that specific checkpoint.

    """
    has_metadata, metadata = _get_metadata()
    configured_mode = (
        metadata.get_checkpoint(checkpoint_name).mode if has_metadata else None
    )

    if mode is not None:
        return mode
    return configured_mode if configured_mode is not None else CheckpointMode.SCHEMA
|
@@ -0,0 +1,132 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
import inspect
|
16
|
+
import os
|
17
|
+
import tempfile
|
18
|
+
|
19
|
+
from typing import Optional
|
20
|
+
|
21
|
+
from snowflake.snowpark_checkpoints_collector.collection_common import (
|
22
|
+
COLLECTION_RESULT_FILE_NAME,
|
23
|
+
DOT_IPYNB_EXTENSION,
|
24
|
+
SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME,
|
25
|
+
UNKNOWN_LINE_OF_CODE,
|
26
|
+
UNKNOWN_SOURCE_FILE,
|
27
|
+
)
|
28
|
+
|
29
|
+
|
30
|
+
def get_output_file_path(out_path: Optional[str] = None) -> str:
    """Build the path of the collection result file.

    The checkpoints output directory is created if it does not exist yet.

    Args:
        out_path (Optional[str], optional): Base output directory. Defaults to
            None, meaning the current working directory.

    Returns:
        str: ``<checkpoints output directory>/COLLECTION_RESULT_FILE_NAME``.

    """
    return os.path.join(get_output_directory_path(out_path), COLLECTION_RESULT_FILE_NAME)
|
43
|
+
|
44
|
+
|
45
|
+
def get_relative_file_path(path: str) -> str:
    """Express ``path`` relative to the current working directory.

    Args:
        path (str): A file path (absolute or relative).

    Returns:
        str: ``path`` relative to the current working directory.

    """
    return os.path.relpath(path, start=os.curdir)
|
57
|
+
|
58
|
+
|
59
|
+
def get_output_directory_path(output_path: Optional[str] = None) -> str:
    """Return the checkpoints output directory, creating it when missing.

    Args:
        output_path (Optional[str], optional): Base directory. When falsy
            (None or an empty string), the current working directory is used.

    Returns:
        str: ``<base>/SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME``.

    """
    base_directory = output_path or os.getcwd()
    checkpoints_directory = os.path.join(
        base_directory, SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME
    )
    os.makedirs(checkpoints_directory, exist_ok=True)
    return checkpoints_directory
|
72
|
+
|
73
|
+
|
74
|
+
def get_collection_point_source_file_path() -> str:
    """Get the path of the source file containing the collection point.

    Frame 0 of the stack is this function, frame 1 is its caller (e.g.
    ``collect_dataframe_checkpoint``), and frame 2 is that caller's caller,
    i.e. the user code that invoked the collector. This function must
    therefore be called directly from the collector entry point.

    If that file lives in the system temp directory, the current working
    directory is searched for ``.ipynb`` files instead. A single match is
    returned; otherwise ``UNKNOWN_SOURCE_FILE`` is returned.

    Returns:
        str: The source file path, or ``UNKNOWN_SOURCE_FILE`` when it cannot
        be determined (including on any error).

    """
    try:
        # context=0: only the filename is needed, so skip reading a source
        # line for every frame on the stack.
        collection_point_file_path = inspect.stack(0)[2].filename
        if _is_temporal_path(collection_point_file_path):
            ipynb_file_path_collection = _get_ipynb_file_path_collection()
            if len(ipynb_file_path_collection) == 1:
                collection_point_file_path = ipynb_file_path_collection[0]
            else:
                collection_point_file_path = UNKNOWN_SOURCE_FILE

        return collection_point_file_path

    except Exception:
        # Best effort: locating the caller must never break collection.
        return UNKNOWN_SOURCE_FILE
|
95
|
+
|
96
|
+
|
97
|
+
def get_collection_point_line_of_code() -> int:
    """Find the line number of the collection point in its source file.

    Uses the same stack frame as ``get_collection_point_source_file_path``
    (frame 2: the caller of the collector entry point), so it must be called
    directly from that entry point.

    Returns:
        int: The line number, or ``UNKNOWN_LINE_OF_CODE`` when the caller's
        file lives in the system temp directory or on any error.

    """
    try:
        # Walk the stack once, without reading source context for every frame.
        caller_frame_info = inspect.stack(0)[2]
        if _is_temporal_path(caller_frame_info.filename):
            return UNKNOWN_LINE_OF_CODE
        return caller_frame_info.lineno

    except Exception:
        # Best effort: locating the caller must never break collection.
        return UNKNOWN_LINE_OF_CODE
|
114
|
+
|
115
|
+
|
116
|
+
def _is_temporal_path(path: str) -> bool:
|
117
|
+
temporal_directory_path = tempfile.gettempdir()
|
118
|
+
is_temporal_path = path.startswith(temporal_directory_path)
|
119
|
+
return is_temporal_path
|
120
|
+
|
121
|
+
|
122
|
+
def _get_ipynb_file_path_collection() -> list[str]:
    """List notebook files in the current working directory.

    Returns:
        list[str]: The cwd joined with each directory entry whose name ends
        with ``DOT_IPYNB_EXTENSION``, in ``os.listdir`` order (not recursive).

    """
    working_directory = os.getcwd()
    return [
        os.path.join(working_directory, entry_name)
        for entry_name in os.listdir(working_directory)
        if entry_name.endswith(DOT_IPYNB_EXTENSION)
    ]
|