snowpark-checkpoints-collectors 0.1.0rc2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_checkpoints_collector/__init__.py +22 -0
- snowflake/snowpark_checkpoints_collector/collection_common.py +160 -0
- snowflake/snowpark_checkpoints_collector/collection_result/model/__init__.py +24 -0
- snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result.py +91 -0
- snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py +69 -0
- snowflake/snowpark_checkpoints_collector/column_collection/__init__.py +22 -0
- snowflake/snowpark_checkpoints_collector/column_collection/column_collector_manager.py +253 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/__init__.py +75 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/array_column_collector.py +113 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/binary_column_collector.py +87 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/boolean_column_collector.py +71 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/column_collector_base.py +95 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/date_column_collector.py +74 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/day_time_interval_column_collector.py +67 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/decimal_column_collector.py +92 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/empty_column_collector.py +88 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/map_column_collector.py +120 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/null_column_collector.py +49 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/numeric_column_collector.py +108 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/string_column_collector.py +70 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/struct_column_collector.py +102 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_column_collector.py +75 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_ntz_column_collector.py +75 -0
- snowflake/snowpark_checkpoints_collector/column_pandera_checks/__init__.py +20 -0
- snowflake/snowpark_checkpoints_collector/column_pandera_checks/pandera_column_checks_manager.py +223 -0
- snowflake/snowpark_checkpoints_collector/singleton.py +23 -0
- snowflake/snowpark_checkpoints_collector/snow_connection_model/__init__.py +20 -0
- snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py +172 -0
- snowflake/snowpark_checkpoints_collector/summary_stats_collector.py +366 -0
- snowflake/snowpark_checkpoints_collector/utils/checkpoint_name_utils.py +53 -0
- snowflake/snowpark_checkpoints_collector/utils/extra_config.py +112 -0
- snowflake/snowpark_checkpoints_collector/utils/file_utils.py +132 -0
- snowflake/snowpark_checkpoints_collector/utils/telemetry.py +889 -0
- snowpark_checkpoints_collectors-0.1.1.dist-info/METADATA +143 -0
- snowpark_checkpoints_collectors-0.1.1.dist-info/RECORD +37 -0
- {snowpark_checkpoints_collectors-0.1.0rc2.dist-info → snowpark_checkpoints_collectors-0.1.1.dist-info}/licenses/LICENSE +0 -25
- snowpark_checkpoints_collectors-0.1.0rc2.dist-info/METADATA +0 -347
- snowpark_checkpoints_collectors-0.1.0rc2.dist-info/RECORD +0 -4
- {snowpark_checkpoints_collectors-0.1.0rc2.dist-info → snowpark_checkpoints_collectors-0.1.1.dist-info}/WHEEL +0 -0
snowflake/snowpark_checkpoints_collector/column_pandera_checks/__init__.py
ADDED
@@ -0,0 +1,20 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["PanderaColumnChecksManager"]

from snowflake.snowpark_checkpoints_collector.column_pandera_checks.pandera_column_checks_manager import (
    PanderaColumnChecksManager,
)

snowflake/snowpark_checkpoints_collector/column_pandera_checks/pandera_column_checks_manager.py
ADDED
@@ -0,0 +1,223 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pandas as pd

from pandera import Check, Column
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql.functions import col as spark_col
from pyspark.sql.functions import length as spark_length
from pyspark.sql.functions import max as spark_max
from pyspark.sql.functions import min as spark_min

from snowflake.snowpark_checkpoints_collector.collection_common import (
    BETWEEN_CHECK_ERROR_MESSAGE_FORMAT,
    BOOLEAN_COLUMN_TYPE,
    BYTE_COLUMN_TYPE,
    COLUMN_MAX_KEY,
    COLUMN_MIN_KEY,
    DAYTIMEINTERVAL_COLUMN_TYPE,
    DOUBLE_COLUMN_TYPE,
    FLOAT_COLUMN_TYPE,
    INTEGER_COLUMN_TYPE,
    LONG_COLUMN_TYPE,
    SHORT_COLUMN_TYPE,
    STRING_COLUMN_TYPE,
    TIMESTAMP_COLUMN_TYPE,
    TIMESTAMP_NTZ_COLUMN_TYPE,
)


def collector_register(cls):
    """Decorate a class with the checks mechanism.

    Args:
        cls: The class to decorate.

    Returns:
        The decorated class.

    """
    cls._collectors = {}
    for method_name in dir(cls):
        method = getattr(cls, method_name)
        if hasattr(method, "_column_type"):
            col_type_collection = method._column_type
            for col_type in col_type_collection:
                cls._collectors[col_type] = method_name
    return cls


def column_register(*args):
    """Decorate a method to register it in the checks mechanism based on column type.

    Args:
        args: the column types to register.

    Returns:
        The wrapper.

    """

    def wrapper(func):
        has_arguments = len(args) > 0
        if has_arguments:
            func._column_type = args
        return func

    return wrapper


@collector_register
class PanderaColumnChecksManager:

    """Manager class for Pandera column checks based on type."""

    def add_checks_column(
        self,
        clm_name: str,
        clm_type: str,
        pyspark_df: SparkDataFrame,
        pandera_column: Column,
    ) -> None:
        """Add checks to a Pandera column based on the column type.

        Args:
            clm_name (str): the name of the column.
            clm_type (str): the type of the column.
            pyspark_df (pyspark.sql.DataFrame): the DataFrame.
            pandera_column (pandera.Column): the Pandera column.

        """
        if clm_type not in self._collectors:
            return

        func_name = self._collectors[clm_type]
        func = getattr(self, func_name)
        func(clm_name, pyspark_df, pandera_column)

    @column_register(BOOLEAN_COLUMN_TYPE)
    def _add_boolean_type_checks(
        self, clm_name: str, pyspark_df: SparkDataFrame, pandera_column: Column
    ) -> None:
        pandera_column.checks.extend([Check.isin([True, False])])

    @column_register(DAYTIMEINTERVAL_COLUMN_TYPE)
    def _add_daytimeinterval_type_checks(
        self, clm_name: str, pyspark_df: SparkDataFrame, pandera_column: Column
    ) -> None:
        select_result = pyspark_df.select(
            spark_min(spark_col(clm_name)).alias(COLUMN_MIN_KEY),
            spark_max(spark_col(clm_name)).alias(COLUMN_MAX_KEY),
        ).collect()[0]

        min_value = pd.to_timedelta(select_result[COLUMN_MIN_KEY])
        max_value = pd.to_timedelta(select_result[COLUMN_MAX_KEY])

        pandera_column.checks.append(
            Check.between(
                min_value=min_value,
                max_value=max_value,
                include_max=True,
                include_min=True,
                title=BETWEEN_CHECK_ERROR_MESSAGE_FORMAT.format(min_value, max_value),
            )
        )

    @column_register(
        BYTE_COLUMN_TYPE,
        SHORT_COLUMN_TYPE,
        INTEGER_COLUMN_TYPE,
        LONG_COLUMN_TYPE,
        FLOAT_COLUMN_TYPE,
        DOUBLE_COLUMN_TYPE,
    )
    def _add_numeric_type_checks(
        self, clm_name: str, pyspark_df: SparkDataFrame, pandera_column: Column
    ) -> None:
        select_result = pyspark_df.select(
            spark_min(spark_col(clm_name)).alias(COLUMN_MIN_KEY),
            spark_max(spark_col(clm_name)).alias(COLUMN_MAX_KEY),
        ).collect()[0]

        min_value = select_result[COLUMN_MIN_KEY]
        max_value = select_result[COLUMN_MAX_KEY]

        pandera_column.checks.append(
            Check.between(
                min_value=min_value,
                max_value=max_value,
                include_max=True,
                include_min=True,
                title=BETWEEN_CHECK_ERROR_MESSAGE_FORMAT.format(min_value, max_value),
            )
        )

    @column_register(STRING_COLUMN_TYPE)
    def _add_string_type_checks(
        self, clm_name: str, pyspark_df: SparkDataFrame, pandera_column: Column
    ) -> None:
        select_result = pyspark_df.select(
            spark_min(spark_length(spark_col(clm_name))).alias(COLUMN_MIN_KEY),
            spark_max(spark_length(spark_col(clm_name))).alias(COLUMN_MAX_KEY),
        ).collect()[0]

        min_length = select_result[COLUMN_MIN_KEY]
        max_length = select_result[COLUMN_MAX_KEY]

        pandera_column.checks.append(Check.str_length(min_length, max_length))

    @column_register(TIMESTAMP_COLUMN_TYPE)
    def _add_timestamp_type_checks(
        self, clm_name: str, pyspark_df: SparkDataFrame, pandera_column: Column
    ) -> None:
        select_result = pyspark_df.select(
            spark_min(spark_col(clm_name)).alias(COLUMN_MIN_KEY),
            spark_max(spark_col(clm_name)).alias(COLUMN_MAX_KEY),
        ).collect()[0]

        min_value = pd.Timestamp(select_result[COLUMN_MIN_KEY])
        max_value = pd.Timestamp(select_result[COLUMN_MAX_KEY])

        pandera_column.checks.append(
            Check.between(
                min_value=min_value,
                max_value=max_value,
                include_max=True,
                include_min=True,
                title=BETWEEN_CHECK_ERROR_MESSAGE_FORMAT.format(min_value, max_value),
            )
        )

    @column_register(TIMESTAMP_NTZ_COLUMN_TYPE)
    def _add_timestamp_ntz_type_checks(
        self, clm_name: str, pyspark_df: SparkDataFrame, pandera_column: Column
    ) -> None:
        select_result = pyspark_df.select(
            spark_min(spark_col(clm_name)).alias(COLUMN_MIN_KEY),
            spark_max(spark_col(clm_name)).alias(COLUMN_MAX_KEY),
        ).collect()[0]

        min_value = pd.Timestamp(select_result[COLUMN_MIN_KEY])
        max_value = pd.Timestamp(select_result[COLUMN_MAX_KEY])

        pandera_column.checks.append(
            Check.between(
                min_value=min_value,
                max_value=max_value,
                include_max=True,
                include_min=True,
                title=BETWEEN_CHECK_ERROR_MESSAGE_FORMAT.format(min_value, max_value),
            )
        )
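
The two decorators above build a type-to-method registry at class-decoration time: column_register tags each handler with the Spark type names it covers, and collector_register gathers those tags into _collectors so that add_checks_column can dispatch on a type string. A rough usage sketch follows; it is illustrative only, assumes an active Spark session, and the DataFrame, column name, and Pandera column are made up:

    import pandera as pa
    from pyspark.sql import SparkSession

    from snowflake.snowpark_checkpoints_collector.collection_common import LONG_COLUMN_TYPE
    from snowflake.snowpark_checkpoints_collector.column_pandera_checks import (
        PanderaColumnChecksManager,
    )

    spark = SparkSession.builder.getOrCreate()
    # Integer literals become a Spark long column.
    df = spark.createDataFrame([(1,), (7,), (42,)], ["amount"])

    amount_column = pa.Column(int, name="amount")  # made-up Pandera column
    manager = PanderaColumnChecksManager()

    # The registry maps LONG_COLUMN_TYPE to _add_numeric_type_checks, which reads the
    # column's min/max from Spark and appends a Check.between to the Pandera column.
    manager.add_checks_column("amount", LONG_COLUMN_TYPE, df, amount_column)
    print(amount_column.checks)  # expect a between(1, 42) check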

snowflake/snowpark_checkpoints_collector/singleton.py
ADDED
@@ -0,0 +1,23 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class Singleton(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]
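
Singleton is a straightforward metaclass cache: the first instantiation of a class is stored in _instances, and every later call returns that same object. A small illustrative sketch (Config is a made-up class, not part of the package):

    class Config(metaclass=Singleton):
        def __init__(self):
            self.values = {}

    a = Config()
    b = Config()
    assert a is b  # both names refer to the single cached instance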

snowflake/snowpark_checkpoints_collector/snow_connection_model/__init__.py
ADDED
@@ -0,0 +1,20 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["SnowConnection"]

from snowflake.snowpark_checkpoints_collector.snow_connection_model.snow_connection import (
    SnowConnection,
)

snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py
ADDED
@@ -0,0 +1,172 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os.path
import time

from pathlib import Path
from typing import Callable, Optional

from snowflake.snowpark import Session
from snowflake.snowpark_checkpoints_collector.collection_common import (
    DOT_PARQUET_EXTENSION,
)


STAGE_NAME = "CHECKPOINT_STAGE"
CREATE_STAGE_STATEMENT_FORMAT = "CREATE TEMP STAGE IF NOT EXISTS {}"
REMOVE_STAGE_FOLDER_STATEMENT_FORMAT = "REMOVE {}"
STAGE_PATH_FORMAT = "'@{}/{}'"
PUT_FILE_IN_STAGE_STATEMENT_FORMAT = "PUT '{}' {} AUTO_COMPRESS=FALSE"


class SnowConnection:

    """Class for managing the Snowpark connection.

    Attributes:
        session (Snowpark.Session): the Snowpark session.

    """

    def __init__(self, session: Session = None) -> None:
        """Init SnowConnection.

        Args:
            session (Snowpark.Session): the Snowpark session.

        """
        self.session = session if session is not None else Session.builder.getOrCreate()
        self.stage_id = int(time.time())

    def create_snowflake_table_from_local_parquet(
        self,
        table_name: str,
        input_path: str,
        stage_path: Optional[str] = None,
    ) -> None:
        """Upload the parquet files from the input path and create a table.

        Args:
            table_name (str): the name of the table to be created.
            input_path (str): the input directory path.
            stage_path (str, optional): the stage path.

        """
        input_path = (
            os.path.abspath(input_path)
            if not os.path.isabs(input_path)
            else str(Path(input_path).resolve())
        )
        folder = f"table_files_{int(time.time())}"
        stage_path = stage_path if stage_path else folder
        stage_name = f"{STAGE_NAME}_{self.stage_id}"
        stage_directory_path = STAGE_PATH_FORMAT.format(stage_name, stage_path)

        def is_parquet_file(file: str):
            return file.endswith(DOT_PARQUET_EXTENSION)

        try:
            self.create_tmp_stage(stage_name)
            self.load_files_to_stage(
                stage_name, stage_path, input_path, is_parquet_file
            )
            self.create_table_from_parquet(table_name, stage_directory_path)

        finally:
            self.session.sql(
                REMOVE_STAGE_FOLDER_STATEMENT_FORMAT.format(stage_directory_path)
            ).collect()

    def create_tmp_stage(self, stage_name: str) -> None:
        """Create a temp stage in Snowflake.

        Args:
            stage_name (str): the name of the stage.

        """
        create_stage_statement = CREATE_STAGE_STATEMENT_FORMAT.format(stage_name)
        self.session.sql(create_stage_statement).collect()

    def load_files_to_stage(
        self,
        stage_name: str,
        folder_name: str,
        input_path: str,
        filter_func: Callable = None,
    ) -> None:
        """Load files to a stage in Snowflake.

        Args:
            stage_name (str): the name of the stage.
            folder_name (str): the folder name.
            input_path (str): the input directory path.
            filter_func (Callable): the filter function to apply to the files.

        """
        input_path = (
            os.path.abspath(input_path)
            if not os.path.isabs(input_path)
            else str(Path(input_path).resolve())
        )

        def filter_files(name: str):
            return os.path.isfile(name) and (filter_func(name) if filter_func else True)

        target_dir = os.path.join(input_path, "**", "*")
        files_collection = glob.glob(target_dir, recursive=True)

        files = [file for file in files_collection if filter_files(file)]

        if len(files) == 0:
            raise Exception(f"No files were found in the input directory: {input_path}")

        for file in files:
            # if file is a relative path, convert it to an absolute path
            # if it is an absolute path, try to resolve it, as some Win32 paths are not in LPN.
            file_full_path = (
                str(os.path.abspath(file))
                if not os.path.isabs(file)
                else str(Path(file).resolve())
            )
            # Snowflake requires URI format for the input of the PUT statement.
            normalize_file_path = Path(file_full_path).as_uri()
            new_file_path = file_full_path.replace(input_path, folder_name)
            # as_posix to convert a Windows path to POSIX form
            new_file_path = Path(new_file_path).as_posix()
            stage_file_path = STAGE_PATH_FORMAT.format(stage_name, new_file_path)
            put_statement = PUT_FILE_IN_STAGE_STATEMENT_FORMAT.format(
                normalize_file_path, stage_file_path
            )
            self.session.sql(put_statement).collect()

    def create_table_from_parquet(
        self, table_name: str, stage_directory_path: str
    ) -> None:
        """Create a table from parquet files in Snowflake.

        Args:
            table_name (str): the name of the table.
            stage_directory_path (str): the stage directory path.

        Raises:
            Exception: no parquet files were found in the stage.

        """
        files = self.session.sql(f"LIST {stage_directory_path}").collect()
        if len(files) == 0:
            raise Exception("No parquet files were found in the stage.")
        dataframe = self.session.read.parquet(path=stage_directory_path)
        dataframe.write.save_as_table(table_name=table_name, mode="overwrite")