snowpark-checkpoints-collectors 0.2.0rc1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_checkpoints_collector/__init__.py +30 -0
- snowflake/snowpark_checkpoints_collector/__version__.py +16 -0
- snowflake/snowpark_checkpoints_collector/collection_common.py +160 -0
- snowflake/snowpark_checkpoints_collector/collection_result/model/__init__.py +24 -0
- snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result.py +91 -0
- snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py +76 -0
- snowflake/snowpark_checkpoints_collector/column_collection/__init__.py +22 -0
- snowflake/snowpark_checkpoints_collector/column_collection/column_collector_manager.py +276 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/__init__.py +75 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/array_column_collector.py +113 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/binary_column_collector.py +87 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/boolean_column_collector.py +71 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/column_collector_base.py +95 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/date_column_collector.py +74 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/day_time_interval_column_collector.py +67 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/decimal_column_collector.py +92 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/empty_column_collector.py +88 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/map_column_collector.py +120 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/null_column_collector.py +49 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/numeric_column_collector.py +108 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/string_column_collector.py +70 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/struct_column_collector.py +102 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_column_collector.py +75 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_ntz_column_collector.py +75 -0
- snowflake/snowpark_checkpoints_collector/column_pandera_checks/__init__.py +20 -0
- snowflake/snowpark_checkpoints_collector/column_pandera_checks/pandera_column_checks_manager.py +241 -0
- snowflake/snowpark_checkpoints_collector/io_utils/__init__.py +26 -0
- snowflake/snowpark_checkpoints_collector/io_utils/io_default_strategy.py +61 -0
- snowflake/snowpark_checkpoints_collector/io_utils/io_env_strategy.py +142 -0
- snowflake/snowpark_checkpoints_collector/io_utils/io_file_manager.py +79 -0
- snowflake/snowpark_checkpoints_collector/singleton.py +23 -0
- snowflake/snowpark_checkpoints_collector/snow_connection_model/__init__.py +20 -0
- snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py +203 -0
- snowflake/snowpark_checkpoints_collector/summary_stats_collector.py +409 -0
- snowflake/snowpark_checkpoints_collector/utils/checkpoint_name_utils.py +53 -0
- snowflake/snowpark_checkpoints_collector/utils/extra_config.py +164 -0
- snowflake/snowpark_checkpoints_collector/utils/file_utils.py +137 -0
- snowflake/snowpark_checkpoints_collector/utils/logging_utils.py +67 -0
- snowflake/snowpark_checkpoints_collector/utils/telemetry.py +928 -0
- snowpark_checkpoints_collectors-0.3.0.dist-info/METADATA +159 -0
- snowpark_checkpoints_collectors-0.3.0.dist-info/RECORD +43 -0
- {snowpark_checkpoints_collectors-0.2.0rc1.dist-info → snowpark_checkpoints_collectors-0.3.0.dist-info}/licenses/LICENSE +0 -25
- snowpark_checkpoints_collectors-0.2.0rc1.dist-info/METADATA +0 -347
- snowpark_checkpoints_collectors-0.2.0rc1.dist-info/RECORD +0 -4
- {snowpark_checkpoints_collectors-0.2.0rc1.dist-info → snowpark_checkpoints_collectors-0.3.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,20 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
__all__ = ["PanderaColumnChecksManager"]
|
17
|
+
|
18
|
+
from snowflake.snowpark_checkpoints_collector.column_pandera_checks.pandera_column_checks_manager import (
|
19
|
+
PanderaColumnChecksManager,
|
20
|
+
)
|
snowflake/snowpark_checkpoints_collector/column_pandera_checks/pandera_column_checks_manager.py
ADDED
@@ -0,0 +1,241 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import logging
|
17
|
+
|
18
|
+
import pandas as pd
|
19
|
+
|
20
|
+
from pandera import Check, Column
|
21
|
+
from pyspark.sql import DataFrame as SparkDataFrame
|
22
|
+
from pyspark.sql.functions import col as spark_col
|
23
|
+
from pyspark.sql.functions import length as spark_length
|
24
|
+
from pyspark.sql.functions import max as spark_max
|
25
|
+
from pyspark.sql.functions import min as spark_min
|
26
|
+
|
27
|
+
from snowflake.snowpark_checkpoints_collector.collection_common import (
|
28
|
+
BETWEEN_CHECK_ERROR_MESSAGE_FORMAT,
|
29
|
+
BOOLEAN_COLUMN_TYPE,
|
30
|
+
BYTE_COLUMN_TYPE,
|
31
|
+
COLUMN_MAX_KEY,
|
32
|
+
COLUMN_MIN_KEY,
|
33
|
+
DAYTIMEINTERVAL_COLUMN_TYPE,
|
34
|
+
DOUBLE_COLUMN_TYPE,
|
35
|
+
FLOAT_COLUMN_TYPE,
|
36
|
+
INTEGER_COLUMN_TYPE,
|
37
|
+
LONG_COLUMN_TYPE,
|
38
|
+
SHORT_COLUMN_TYPE,
|
39
|
+
STRING_COLUMN_TYPE,
|
40
|
+
TIMESTAMP_COLUMN_TYPE,
|
41
|
+
TIMESTAMP_NTZ_COLUMN_TYPE,
|
42
|
+
)
|
43
|
+
|
44
|
+
|
45
|
+
LOGGER = logging.getLogger(__name__)
|
46
|
+
|
47
|
+
|
48
|
+
def collector_register(cls):
    """Decorate a class with the checks mechanism.

    Scans the class for methods tagged with a ``_column_type`` attribute and
    records them in ``cls._collectors``, keyed by column type.

    Args:
        cls: The class to decorate.

    Returns:
        The class to decorate.

    """
    LOGGER.debug("Starting to register checks from class %s", cls.__name__)
    registry = {}
    for attribute_name in dir(cls):
        attribute = getattr(cls, attribute_name)
        registered_types = getattr(attribute, "_column_type", None)
        if registered_types is None:
            continue
        for column_type in registered_types:
            registry[column_type] = attribute_name
            LOGGER.debug(
                "Registered check '%s' for column type '%s'",
                attribute_name,
                column_type,
            )
    cls._collectors = registry
    return cls
|
70
|
+
|
71
|
+
|
72
|
+
def column_register(*args):
    """Decorate a method to register it in the checks mechanism based on column type.

    Args:
        args: the column type to register.

    Returns:
        The wrapper.

    """

    def wrapper(func):
        # Tag the function only when at least one column type was supplied;
        # collector_register later looks for this attribute.
        if args:
            func._column_type = args
        return func

    return wrapper
|
90
|
+
|
91
|
+
|
92
|
+
@collector_register
class PanderaColumnChecksManager:

    """Manage class for Pandera column checks based on type."""

    def add_checks_column(
        self,
        clm_name: str,
        clm_type: str,
        pyspark_df: SparkDataFrame,
        pandera_column: Column,
    ) -> None:
        """Add checks to Pandera column based on the column type.

        Args:
            clm_name (str): the name of the column.
            clm_type (str): the type of the column.
            pyspark_df (pyspark.sql.DataFrame): the DataFrame.
            pandera_column (pandera.Column): the Pandera column.

        """
        if clm_type not in self._collectors:
            LOGGER.debug(
                "No Pandera checks found for column '%s' of type '%s'. Skipping checks for this column.",
                clm_name,
                clm_type,
            )
            return

        func_name = self._collectors[clm_type]
        func = getattr(self, func_name)
        LOGGER.debug(
            "Adding Pandera checks to column '%s' of type '%s'", clm_name, clm_type
        )
        func(clm_name, pyspark_df, pandera_column)

    def _get_min_max(self, clm_name: str, pyspark_df: SparkDataFrame) -> tuple:
        """Return the (min, max) values of the column computed on the Spark side."""
        select_result = pyspark_df.select(
            spark_min(spark_col(clm_name)).alias(COLUMN_MIN_KEY),
            spark_max(spark_col(clm_name)).alias(COLUMN_MAX_KEY),
        ).collect()[0]
        return select_result[COLUMN_MIN_KEY], select_result[COLUMN_MAX_KEY]

    def _append_between_check(
        self, pandera_column: Column, min_value, max_value
    ) -> None:
        """Append an inclusive between(min, max) check to the Pandera column."""
        pandera_column.checks.append(
            Check.between(
                min_value=min_value,
                max_value=max_value,
                include_max=True,
                include_min=True,
                title=BETWEEN_CHECK_ERROR_MESSAGE_FORMAT.format(min_value, max_value),
            )
        )

    @column_register(BOOLEAN_COLUMN_TYPE)
    def _add_boolean_type_checks(
        self, clm_name: str, pyspark_df: SparkDataFrame, pandera_column: Column
    ) -> None:
        """Restrict boolean columns to the values True and False."""
        pandera_column.checks.extend([Check.isin([True, False])])

    @column_register(DAYTIMEINTERVAL_COLUMN_TYPE)
    def _add_daytimeinterval_type_checks(
        self, clm_name: str, pyspark_df: SparkDataFrame, pandera_column: Column
    ) -> None:
        """Bound day-time interval columns by the observed min/max timedeltas."""
        min_raw, max_raw = self._get_min_max(clm_name, pyspark_df)
        self._append_between_check(
            pandera_column, pd.to_timedelta(min_raw), pd.to_timedelta(max_raw)
        )

    @column_register(
        BYTE_COLUMN_TYPE,
        SHORT_COLUMN_TYPE,
        INTEGER_COLUMN_TYPE,
        LONG_COLUMN_TYPE,
        FLOAT_COLUMN_TYPE,
        DOUBLE_COLUMN_TYPE,
    )
    def _add_numeric_type_checks(
        self, clm_name: str, pyspark_df: SparkDataFrame, pandera_column: Column
    ) -> None:
        """Bound numeric columns by the observed min/max values."""
        min_value, max_value = self._get_min_max(clm_name, pyspark_df)
        self._append_between_check(pandera_column, min_value, max_value)

    @column_register(STRING_COLUMN_TYPE)
    def _add_string_type_checks(
        self, clm_name: str, pyspark_df: SparkDataFrame, pandera_column: Column
    ) -> None:
        """Bound string columns by the observed min/max string lengths."""
        select_result = pyspark_df.select(
            spark_min(spark_length(spark_col(clm_name))).alias(COLUMN_MIN_KEY),
            spark_max(spark_length(spark_col(clm_name))).alias(COLUMN_MAX_KEY),
        ).collect()[0]

        min_length = select_result[COLUMN_MIN_KEY]
        max_length = select_result[COLUMN_MAX_KEY]

        pandera_column.checks.append(Check.str_length(min_length, max_length))

    @column_register(TIMESTAMP_COLUMN_TYPE)
    def _add_timestamp_type_checks(
        self, clm_name: str, pyspark_df: SparkDataFrame, pandera_column: Column
    ) -> None:
        """Bound timestamp columns by the observed min/max timestamps."""
        min_raw, max_raw = self._get_min_max(clm_name, pyspark_df)
        self._append_between_check(
            pandera_column, pd.Timestamp(min_raw), pd.Timestamp(max_raw)
        )

    @column_register(TIMESTAMP_NTZ_COLUMN_TYPE)
    def _add_timestamp_ntz_type_checks(
        self, clm_name: str, pyspark_df: SparkDataFrame, pandera_column: Column
    ) -> None:
        """Bound timestamp_ntz columns by the observed min/max timestamps."""
        min_raw, max_raw = self._get_min_max(clm_name, pyspark_df)
        self._append_between_check(
            pandera_column, pd.Timestamp(min_raw), pd.Timestamp(max_raw)
        )
|
@@ -0,0 +1,26 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
__all__ = ["EnvStrategy", "IOFileManager", "IODefaultStrategy"]
|
17
|
+
|
18
|
+
from snowflake.snowpark_checkpoints_collector.io_utils.io_env_strategy import (
|
19
|
+
EnvStrategy,
|
20
|
+
)
|
21
|
+
from snowflake.snowpark_checkpoints_collector.io_utils.io_default_strategy import (
|
22
|
+
IODefaultStrategy,
|
23
|
+
)
|
24
|
+
from snowflake.snowpark_checkpoints_collector.io_utils.io_file_manager import (
|
25
|
+
IOFileManager,
|
26
|
+
)
|
@@ -0,0 +1,61 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import glob
|
17
|
+
import os
|
18
|
+
import shutil
|
19
|
+
|
20
|
+
from pathlib import Path
|
21
|
+
from typing import Optional
|
22
|
+
|
23
|
+
from snowflake.snowpark_checkpoints_collector.io_utils import EnvStrategy
|
24
|
+
|
25
|
+
|
26
|
+
class IODefaultStrategy(EnvStrategy):

    """EnvStrategy implementation backed by the local filesystem."""

    def mkdir(self, path: str, exist_ok: bool = False) -> None:
        """Create the directory *path*, including any missing parents."""
        os.makedirs(path, exist_ok=exist_ok)

    def folder_exists(self, path: str) -> bool:
        """Return True when *path* names an existing directory."""
        return os.path.isdir(path)

    def file_exists(self, path: str) -> bool:
        """Return True when *path* names an existing regular file."""
        return os.path.isfile(path)

    def write(self, file_path: str, file_content: str, overwrite: bool = True) -> None:
        """Write *file_content* to *file_path*, failing if it exists and overwrite is False."""
        open_mode = "w" if overwrite else "x"
        with open(file_path, open_mode) as target:
            target.write(file_content)

    def read(
        self, file_path: str, mode: str = "r", encoding: Optional[str] = None
    ) -> str:
        """Return the content of *file_path* opened with *mode* and *encoding*."""
        with open(file_path, mode=mode, encoding=encoding) as source:
            content = source.read()
        return content

    def read_bytes(self, file_path: str) -> bytes:
        """Return the raw bytes stored in *file_path*."""
        with open(file_path, mode="rb") as source:
            return source.read()

    def ls(self, path: str, recursive: bool = False) -> list[str]:
        """Return the paths matching the glob pattern *path*."""
        return glob.glob(path, recursive=recursive)

    def getcwd(self) -> str:
        """Return the current working directory."""
        return os.getcwd()

    def remove_dir(self, path: str) -> None:
        """Delete the directory *path* together with all of its contents."""
        shutil.rmtree(path)

    def telemetry_path_files(self, path: str) -> Path:
        """Return *path* wrapped as a pathlib.Path for telemetry file access."""
        return Path(path)
|
@@ -0,0 +1,142 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
from abc import ABC, abstractmethod
|
17
|
+
from pathlib import Path
|
18
|
+
from typing import Optional
|
19
|
+
|
20
|
+
|
21
|
+
class EnvStrategy(ABC):

    """Abstract interface for file and directory operations.

    Concrete subclasses supply the environment-specific behavior (e.g. a
    local filesystem implementation).
    """

    @abstractmethod
    def mkdir(self, path: str, exist_ok: bool = False) -> None:
        """Create a directory.

        Args:
            path: The name of the directory to create.
            exist_ok: When False, raise an error if the directory already exists.

        """

    @abstractmethod
    def folder_exists(self, path: str) -> bool:
        """Tell whether a folder exists.

        Args:
            path: The path to the folder.

        Returns:
            bool: True if the folder exists, False otherwise.

        """

    @abstractmethod
    def file_exists(self, path: str) -> bool:
        """Tell whether a file exists.

        Args:
            path: The path to the file.

        Returns:
            bool: True if the file exists, False otherwise.

        """

    @abstractmethod
    def write(self, file_path: str, file_content: str, overwrite: bool = True) -> None:
        """Write content to a file.

        Args:
            file_path: The name of the file to write to.
            file_content: The content to write to the file.
            overwrite: When True, replace the file if it already exists.

        """

    @abstractmethod
    def read(
        self, file_path: str, mode: str = "r", encoding: Optional[str] = None
    ) -> str:
        """Read content from a file.

        Args:
            file_path: The path of the file to read from.
            mode: The mode used to open the file.
            encoding: The encoding used when reading the file.

        Returns:
            str: The content of the file.

        """

    @abstractmethod
    def read_bytes(self, file_path: str) -> bytes:
        """Read binary content from a file.

        Args:
            file_path: The path of the file to read from.

        Returns:
            bytes: The binary content of the file.

        """

    @abstractmethod
    def ls(self, path: str, recursive: bool = False) -> list[str]:
        """List the contents of a directory.

        Args:
            path: The path to the directory.
            recursive: When True, list the contents recursively.

        Returns:
            list[str]: A list of the contents of the directory.

        """

    @abstractmethod
    def getcwd(self) -> str:
        """Return the current working directory.

        Returns:
            str: The current working directory.

        """

    @abstractmethod
    def remove_dir(self, path: str) -> None:
        """Remove a directory together with all its contents.

        Args:
            path: The path of the directory to remove.

        """

    @abstractmethod
    def telemetry_path_files(self, path: str) -> Path:
        """Return the path to the telemetry files.

        Args:
            path: The path to the telemetry directory.

        Returns:
            Path: The path object representing the telemetry files.

        """
|
@@ -0,0 +1,79 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
from pathlib import Path
|
17
|
+
from typing import Optional
|
18
|
+
|
19
|
+
from snowflake.snowpark_checkpoints_collector.io_utils import (
|
20
|
+
EnvStrategy,
|
21
|
+
IODefaultStrategy,
|
22
|
+
)
|
23
|
+
from snowflake.snowpark_checkpoints_collector.singleton import Singleton
|
24
|
+
|
25
|
+
|
26
|
+
class IOFileManager(metaclass=Singleton):

    """Singleton facade that forwards file operations to a pluggable EnvStrategy."""

    def __init__(self, strategy: Optional[EnvStrategy] = None):
        # Default to the local-filesystem strategy when none is provided.
        self.strategy = strategy if strategy is not None else IODefaultStrategy()

    def mkdir(self, path: str, exist_ok: bool = False) -> None:
        """Create a directory via the active strategy."""
        return self.strategy.mkdir(path, exist_ok)

    def folder_exists(self, path: str) -> bool:
        """Check whether a folder exists via the active strategy."""
        return self.strategy.folder_exists(path)

    def file_exists(self, path: str) -> bool:
        """Check whether a file exists via the active strategy."""
        return self.strategy.file_exists(path)

    def write(self, file_path: str, file_content: str, overwrite: bool = True) -> None:
        """Write content to a file via the active strategy."""
        return self.strategy.write(file_path, file_content, overwrite)

    def read(
        self, file_path: str, mode: str = "r", encoding: Optional[str] = None
    ) -> str:
        """Read content from a file via the active strategy."""
        return self.strategy.read(file_path, mode, encoding)

    def read_bytes(self, file_path: str) -> bytes:
        """Read binary content from a file via the active strategy."""
        return self.strategy.read_bytes(file_path)

    def ls(self, path: str, recursive: bool = False) -> list[str]:
        """List directory contents via the active strategy."""
        return self.strategy.ls(path, recursive)

    def getcwd(self) -> str:
        """Return the current working directory via the active strategy."""
        return self.strategy.getcwd()

    def remove_dir(self, path: str) -> None:
        """Remove a directory and its contents via the active strategy."""
        return self.strategy.remove_dir(path)

    def telemetry_path_files(self, path: str) -> Path:
        """Return the telemetry files path via the active strategy."""
        return self.strategy.telemetry_path_files(path)

    def set_strategy(self, strategy: EnvStrategy):
        """Set the strategy for file and directory operations.

        Args:
            strategy (EnvStrategy): The strategy to use for file and directory operations.

        """
        self.strategy = strategy
|
70
|
+
|
71
|
+
|
72
|
+
def get_io_file_manager():
    """Return the process-wide IOFileManager instance.

    IOFileManager uses the Singleton metaclass, so this always yields the
    same object.

    Returns:
        IOFileManager: The singleton instance of IOFileManager.

    """
    return IOFileManager()
|
@@ -0,0 +1,23 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
|
17
|
+
class Singleton(type):

    """Metaclass that caches exactly one instance per class."""

    # Maps each class to its single cached instance.
    _instances: dict = {}

    def __call__(cls, *args, **kwargs):
        # EAFP: return the cached instance, creating it on first use.
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance
|
@@ -0,0 +1,20 @@
|
|
1
|
+
# Copyright 2025 Snowflake Inc.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
__all__ = ["SnowConnection"]
|
17
|
+
|
18
|
+
from snowflake.snowpark_checkpoints_collector.snow_connection_model.snow_connection import (
|
19
|
+
SnowConnection,
|
20
|
+
)
|