snowpark-checkpoints-collectors 0.1.0rc2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_checkpoints_collector/__init__.py +22 -0
- snowflake/snowpark_checkpoints_collector/collection_common.py +160 -0
- snowflake/snowpark_checkpoints_collector/collection_result/model/__init__.py +24 -0
- snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result.py +91 -0
- snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py +69 -0
- snowflake/snowpark_checkpoints_collector/column_collection/__init__.py +22 -0
- snowflake/snowpark_checkpoints_collector/column_collection/column_collector_manager.py +253 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/__init__.py +75 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/array_column_collector.py +113 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/binary_column_collector.py +87 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/boolean_column_collector.py +71 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/column_collector_base.py +95 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/date_column_collector.py +74 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/day_time_interval_column_collector.py +67 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/decimal_column_collector.py +92 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/empty_column_collector.py +88 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/map_column_collector.py +120 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/null_column_collector.py +49 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/numeric_column_collector.py +108 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/string_column_collector.py +70 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/struct_column_collector.py +102 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_column_collector.py +75 -0
- snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_ntz_column_collector.py +75 -0
- snowflake/snowpark_checkpoints_collector/column_pandera_checks/__init__.py +20 -0
- snowflake/snowpark_checkpoints_collector/column_pandera_checks/pandera_column_checks_manager.py +223 -0
- snowflake/snowpark_checkpoints_collector/singleton.py +23 -0
- snowflake/snowpark_checkpoints_collector/snow_connection_model/__init__.py +20 -0
- snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py +172 -0
- snowflake/snowpark_checkpoints_collector/summary_stats_collector.py +366 -0
- snowflake/snowpark_checkpoints_collector/utils/checkpoint_name_utils.py +53 -0
- snowflake/snowpark_checkpoints_collector/utils/extra_config.py +112 -0
- snowflake/snowpark_checkpoints_collector/utils/file_utils.py +132 -0
- snowflake/snowpark_checkpoints_collector/utils/telemetry.py +889 -0
- snowpark_checkpoints_collectors-0.1.1.dist-info/METADATA +143 -0
- snowpark_checkpoints_collectors-0.1.1.dist-info/RECORD +37 -0
- {snowpark_checkpoints_collectors-0.1.0rc2.dist-info → snowpark_checkpoints_collectors-0.1.1.dist-info}/licenses/LICENSE +0 -25
- snowpark_checkpoints_collectors-0.1.0rc2.dist-info/METADATA +0 -347
- snowpark_checkpoints_collectors-0.1.0rc2.dist-info/RECORD +0 -4
- {snowpark_checkpoints_collectors-0.1.0rc2.dist-info → snowpark_checkpoints_collectors-0.1.1.dist-info}/WHEEL +0 -0
snowflake/snowpark_checkpoints_collector/__init__.py
ADDED
@@ -0,0 +1,22 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["collect_dataframe_checkpoint", "CheckpointMode"]
+
+from snowflake.snowpark_checkpoints_collector.summary_stats_collector import (
+    collect_dataframe_checkpoint,
+)
+
+from snowflake.snowpark_checkpoints_collector.collection_common import CheckpointMode
snowflake/snowpark_checkpoints_collector/collection_common.py
ADDED
@@ -0,0 +1,160 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import locale
+
+from enum import IntEnum
+
+
+class CheckpointMode(IntEnum):
+
+    """Enum class representing the collection mode."""
+
+    SCHEMA = 1
+    """Collect automatic schema inference"""
+    DATAFRAME = 2
+    """Export DataFrame as Parquet file to Snowflake"""
+
+
+# CONSTANTS
+ARRAY_COLUMN_TYPE = "array"
+BINARY_COLUMN_TYPE = "binary"
+BOOLEAN_COLUMN_TYPE = "boolean"
+BYTE_COLUMN_TYPE = "byte"
+DATE_COLUMN_TYPE = "date"
+DAYTIMEINTERVAL_COLUMN_TYPE = "daytimeinterval"
+DECIMAL_COLUMN_TYPE = "decimal"
+DOUBLE_COLUMN_TYPE = "double"
+FLOAT_COLUMN_TYPE = "float"
+INTEGER_COLUMN_TYPE = "integer"
+LONG_COLUMN_TYPE = "long"
+MAP_COLUMN_TYPE = "map"
+NULL_COLUMN_TYPE = "void"
+SHORT_COLUMN_TYPE = "short"
+STRING_COLUMN_TYPE = "string"
+STRUCT_COLUMN_TYPE = "struct"
+TIMESTAMP_COLUMN_TYPE = "timestamp"
+TIMESTAMP_NTZ_COLUMN_TYPE = "timestamp_ntz"
+
+PANDAS_BOOLEAN_DTYPE = "bool"
+PANDAS_DATETIME_DTYPE = "datetime64[ns]"
+PANDAS_FLOAT_DTYPE = "float64"
+PANDAS_INTEGER_DTYPE = "int64"
+PANDAS_OBJECT_DTYPE = "object"
+PANDAS_TIMEDELTA_DTYPE = "timedelta64[ns]"
+
+NUMERIC_TYPE_COLLECTION = [
+    BYTE_COLUMN_TYPE,
+    DOUBLE_COLUMN_TYPE,
+    FLOAT_COLUMN_TYPE,
+    INTEGER_COLUMN_TYPE,
+    LONG_COLUMN_TYPE,
+    SHORT_COLUMN_TYPE,
+]
+
+INTEGER_TYPE_COLLECTION = [
+    BYTE_COLUMN_TYPE,
+    INTEGER_COLUMN_TYPE,
+    LONG_COLUMN_TYPE,
+    SHORT_COLUMN_TYPE,
+]
+
+PANDAS_OBJECT_TYPE_COLLECTION = [
+    STRING_COLUMN_TYPE,
+    ARRAY_COLUMN_TYPE,
+    MAP_COLUMN_TYPE,
+    NULL_COLUMN_TYPE,
+    STRUCT_COLUMN_TYPE,
+]
+
+BETWEEN_CHECK_ERROR_MESSAGE_FORMAT = "Value must be between {} and {}"
+
+# SCHEMA CONTRACT KEYS CONSTANTS
+COLUMN_ALLOW_NULL_KEY = "allow_null"
+COLUMN_COUNT_KEY = "rows_count"
+COLUMN_DECIMAL_PRECISION_KEY = "decimal_precision"
+COLUMN_FALSE_COUNT_KEY = "false_count"
+COLUMN_FORMAT_KEY = "format"
+COLUMN_IS_NULLABLE_KEY = "nullable"
+COLUMN_IS_UNIQUE_SIZE_KEY = "is_unique_size"
+COLUMN_KEY_TYPE_KEY = "key_type"
+COLUMN_MARGIN_ERROR_KEY = "margin_error"
+COLUMN_MAX_KEY = "max"
+COLUMN_MAX_LENGTH_KEY = "max_length"
+COLUMN_MAX_SIZE_KEY = "max_size"
+COLUMN_MEAN_KEY = "mean"
+COLUMN_MEAN_SIZE_KEY = "mean_size"
+COLUMN_METADATA_KEY = "metadata"
+COLUMN_MIN_KEY = "min"
+COLUMN_MIN_LENGTH_KEY = "min_length"
+COLUMN_MIN_SIZE_KEY = "min_size"
+COLUMN_NAME_KEY = "name"
+COLUMN_NULL_COUNT_KEY = "null_count"
+COLUMN_NULL_VALUE_PROPORTION_KEY = "null_value_proportion"
+COLUMN_ROWS_NOT_NULL_COUNT_KEY = "rows_not_null_count"
+COLUMN_ROWS_NULL_COUNT_KEY = "rows_null_count"
+COLUMN_SIZE_KEY = "size"
+COLUMN_TRUE_COUNT_KEY = "true_count"
+COLUMN_TYPE_KEY = "type"
+COLUMN_VALUE_KEY = "value"
+COLUMN_VALUE_TYPE_KEY = "value_type"
+COLUMNS_KEY = "columns"
+
+DATAFRAME_CUSTOM_DATA_KEY = "custom_data"
+DATAFRAME_PANDERA_SCHEMA_KEY = "pandera_schema"
+
+PANDERA_COLUMN_TYPE_KEY = "dtype"
+
+CONTAINS_NULL_KEY = "containsNull"
+ELEMENT_TYPE_KEY = "elementType"
+FIELD_METADATA_KEY = "metadata"
+FIELDS_KEY = "fields"
+KEY_TYPE_KEY = "keyType"
+NAME_KEY = "name"
+VALUE_CONTAINS_NULL_KEY = "valueContainsNull"
+VALUE_TYPE_KEY = "valueType"
+
+# DIRECTORY AND FILE NAME CONSTANTS
+SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME = "snowpark-checkpoints-output"
+CHECKPOINT_JSON_OUTPUT_FILE_NAME_FORMAT = "{}.json"
+CHECKPOINT_PARQUET_OUTPUT_FILE_NAME_FORMAT = "{}.parquet"
+COLLECTION_RESULT_FILE_NAME = "checkpoint_collection_results.json"
+
+# MISC KEYS
+DECIMAL_TOKEN_KEY = "decimal_point"
+DOT_PARQUET_EXTENSION = ".parquet"
+DOT_IPYNB_EXTENSION = ".ipynb"
+UNKNOWN_SOURCE_FILE = "unknown"
+UNKNOWN_LINE_OF_CODE = -1
+BACKSLASH_TOKEN = "\\"
+SLASH_TOKEN = "/"
+PYSPARK_NONE_SIZE_VALUE = -1
+PANDAS_LONG_TYPE = "Int64"
+
+# ENVIRONMENT VARIABLES
+SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR = (
+    "SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH"
+)
+
+
+def get_decimal_token() -> str:
+    """Return the decimal token based on the local environment.
+
+    Returns:
+        str: The decimal token.
+
+    """
+    decimal_token = locale.localeconv()[DECIMAL_TOKEN_KEY]
+    return decimal_token
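`get_decimal_token` simply surfaces the active locale's decimal separator. A quick stdlib-only illustration (the `setlocale` call is illustrative, and locale availability varies by system):

    import locale

    # Under the "C" locale the decimal token is always ".".
    locale.setlocale(locale.LC_NUMERIC, "C")
    print(locale.localeconv()["decimal_point"])  # -> "."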
snowflake/snowpark_checkpoints_collector/collection_result/model/__init__.py
ADDED
@@ -0,0 +1,24 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["CollectionPointResult", "CollectionResult", "CollectionPointResultManager"]
+
+from snowflake.snowpark_checkpoints_collector.collection_result.model.collection_point_result import (
+    CollectionPointResult,
+    CollectionResult,
+)
+from snowflake.snowpark_checkpoints_collector.collection_result.model.collection_point_result_manager import (
+    CollectionPointResultManager,
+)
snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result.py
ADDED
@@ -0,0 +1,91 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from datetime import datetime
+from enum import Enum
+
+from snowflake.snowpark_checkpoints_collector.utils import file_utils
+
+
+TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"
+
+TIMESTAMP_KEY = "timestamp"
+FILE_KEY = "file"
+RESULT_KEY = "result"
+LINE_OF_CODE_KEY = "line_of_code"
+CHECKPOINT_NAME_KEY = "checkpoint_name"
+
+
+class CollectionResult(Enum):
+    FAIL = "FAIL"
+    PASS = "PASS"
+
+
+class CollectionPointResult:
+
+    """Class for checkpoint collection results.
+
+    Attributes:
+        _timestamp (timestamp): the timestamp when collection started.
+        _file_path (str): the full path where checkpoint is.
+        _line_of_code (int): the line of code where the checkpoint is.
+        _checkpoint_name (str): the checkpoint name.
+        _result (CollectionResult): the result status of the checkpoint collection point.
+
+    """
+
+    def __init__(
+        self,
+        file_path: str,
+        line_of_code: int,
+        checkpoint_name: str,
+    ) -> None:
+        """Init CollectionPointResult.
+
+        Args:
+            file_path (str): the full path where checkpoint is.
+            line_of_code (int): the line of code where the checkpoint is.
+            checkpoint_name (str): the checkpoint name.
+
+        """
+        self._timestamp = datetime.now()
+        self._file_path = file_path
+        self._line_of_code = line_of_code
+        self._checkpoint_name = checkpoint_name
+        self._result = None
+
+    @property
+    def result(self):
+        """Get the result status of the checkpoint collection point."""
+        return self._result
+
+    @result.setter
+    def result(self, value):
+        """Set the result status of the checkpoint collection point."""
+        self._result = value
+
+    def get_collection_result_data(self) -> dict[str, any]:
+        """Get the results of the checkpoint collection point."""
+        timestamp_with_format = self._timestamp.strftime(TIMESTAMP_FORMAT)
+        relative_path = file_utils.get_relative_file_path(self._file_path)
+
+        collection_point_result = {
+            TIMESTAMP_KEY: timestamp_with_format,
+            FILE_KEY: relative_path,
+            LINE_OF_CODE_KEY: self._line_of_code,
+            CHECKPOINT_NAME_KEY: self._checkpoint_name,
+            RESULT_KEY: self.result.value,
+        }
+
+        return collection_point_result
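A result object is created at the collection site, marked PASS or FAIL, and then flattened to a dict for serialization. A hedged sketch (the path, line number, and printed output are illustrative):

    # Illustrative: record and inspect a PASS result for a hypothetical checkpoint.
    result = CollectionPointResult(
        file_path="/project/pipeline.py",  # hypothetical path
        line_of_code=42,
        checkpoint_name="demo_checkpoint",
    )
    result.result = CollectionResult.PASS
    print(result.get_collection_result_data())
    # e.g. {'timestamp': '2025-01-01 12:00:00', 'file': 'pipeline.py',
    #       'line_of_code': 42, 'checkpoint_name': 'demo_checkpoint', 'result': 'PASS'}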
snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py
ADDED
@@ -0,0 +1,69 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+
+from typing import Optional
+
+from snowflake.snowpark_checkpoints_collector.collection_result.model import (
+    CollectionPointResult,
+)
+from snowflake.snowpark_checkpoints_collector.singleton import Singleton
+from snowflake.snowpark_checkpoints_collector.utils import file_utils
+
+
+RESULTS_KEY = "results"
+
+
+class CollectionPointResultManager(metaclass=Singleton):
+
+    """Class for manage the checkpoint collection results. It is a singleton.
+
+    Attributes:
+        result_collection (list[any]): the collection of the checkpoint results.
+        output_file_path (str): the full path of the output file.
+
+    """
+
+    def __init__(self, output_path: Optional[str] = None) -> None:
+        """Init CollectionPointResultManager."""
+        self.result_collection: list[any] = []
+        self.output_file_path = file_utils.get_output_file_path(output_path)
+
+    def add_result(self, result: CollectionPointResult) -> None:
+        """Add the CollectionPointResult result to the collection.
+
+        Args:
+            result (CollectionPointResult): the CollectionPointResult to add.
+
+        """
+        result_json = result.get_collection_result_data()
+        self.result_collection.append(result_json)
+        self._save_result()
+
+    def to_json(self) -> str:
+        """Convert to json the checkpoint results collected.
+
+        Returns:
+            str: the results as json string.
+
+        """
+        dict_object = {RESULTS_KEY: self.result_collection}
+        result_collection_json = json.dumps(dict_object)
+        return result_collection_json
+
+    def _save_result(self) -> None:
+        result_collection_json = self.to_json()
+        with open(self.output_file_path, "w") as f:
+            f.write(result_collection_json)
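Because `CollectionPointResultManager` uses the package's `Singleton` metaclass, every instantiation returns the same object, and each `add_result` call rewrites the output file with the full collection. A brief sketch (the output path is hypothetical, and `result` carries over from the previous example):

    # Illustrative: both names bind the same singleton instance.
    manager_a = CollectionPointResultManager(output_path="/tmp/results")  # hypothetical path
    manager_b = CollectionPointResultManager()
    assert manager_a is manager_b

    manager_a.add_result(result)  # persists immediately via _save_result()
    print(manager_b.to_json())    # '{"results": [...]}' including that entry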
snowflake/snowpark_checkpoints_collector/column_collection/__init__.py
ADDED
@@ -0,0 +1,22 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = [
+    "ColumnCollectorManager",
+]
+
+from snowflake.snowpark_checkpoints_collector.column_collection.column_collector_manager import (
+    ColumnCollectorManager,
+)
snowflake/snowpark_checkpoints_collector/column_collection/column_collector_manager.py
ADDED
@@ -0,0 +1,253 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pyspark.sql import DataFrame as SparkDataFrame
+from pyspark.sql.types import StructField
+
+from snowflake.snowpark_checkpoints_collector.collection_common import (
+    ARRAY_COLUMN_TYPE,
+    BINARY_COLUMN_TYPE,
+    BOOLEAN_COLUMN_TYPE,
+    BYTE_COLUMN_TYPE,
+    DATE_COLUMN_TYPE,
+    DAYTIMEINTERVAL_COLUMN_TYPE,
+    DECIMAL_COLUMN_TYPE,
+    DOUBLE_COLUMN_TYPE,
+    FLOAT_COLUMN_TYPE,
+    INTEGER_COLUMN_TYPE,
+    LONG_COLUMN_TYPE,
+    MAP_COLUMN_TYPE,
+    NULL_COLUMN_TYPE,
+    SHORT_COLUMN_TYPE,
+    STRING_COLUMN_TYPE,
+    STRUCT_COLUMN_TYPE,
+    TIMESTAMP_COLUMN_TYPE,
+    TIMESTAMP_NTZ_COLUMN_TYPE,
+)
+from snowflake.snowpark_checkpoints_collector.column_collection.model import (
+    ArrayColumnCollector,
+    BinaryColumnCollector,
+    BooleanColumnCollector,
+    DateColumnCollector,
+    DayTimeIntervalColumnCollector,
+    DecimalColumnCollector,
+    EmptyColumnCollector,
+    MapColumnCollector,
+    NullColumnCollector,
+    NumericColumnCollector,
+    StringColumnCollector,
+    StructColumnCollector,
+    TimestampColumnCollector,
+    TimestampNTZColumnCollector,
+)
+
+
+def collector_register(cls):
+    """Decorate a class with the collection type mechanism.
+
+    Args:
+        cls: The class to decorate.
+
+    Returns:
+        The class to decorate.
+
+    """
+    cls._collectors = {}
+    for method_name in dir(cls):
+        method = getattr(cls, method_name)
+        if hasattr(method, "_column_type"):
+            col_type_collection = method._column_type
+            for col_type in col_type_collection:
+                cls._collectors[col_type] = method_name
+    return cls
+
+
+def column_register(*args):
+    """Decorate a method to register it in the collection mechanism based on column type.
+
+    Args:
+        args: the column type to register.
+
+    Returns:
+        The wrapper.
+
+    """
+
+    def wrapper(func):
+        has_arguments = len(args) > 0
+        if has_arguments:
+            func._column_type = args
+        return func
+
+    return wrapper
+
+
+@collector_register
+class ColumnCollectorManager:
+
+    """Manage class for column collector based on type."""
+
+    def collect_column(
+        self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+    ) -> dict[str, any]:
+        """Collect the data of the column based on the column type.
+
+        Args:
+            clm_name (str): the name of the column.
+            struct_field (pyspark.sql.types.StructField): the struct field of the column type.
+            values (pyspark.sql.DataFrame): the column values as PySpark DataFrame.
+
+        Returns:
+            dict[str, any]: The data collected.
+
+        """
+        clm_type = struct_field.dataType.typeName()
+        if clm_type not in self._collectors:
+            return {}
+
+        func_name = self._collectors[clm_type]
+        func = getattr(self, func_name)
+        data = func(clm_name, struct_field, values)
+        return data
+
+    @column_register(ARRAY_COLUMN_TYPE)
+    def _collect_array_type_custom_data(
+        self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+    ) -> dict[str, any]:
+        column_collector = ArrayColumnCollector(clm_name, struct_field, values)
+        collected_data = column_collector.get_data()
+        return collected_data
+
+    @column_register(BINARY_COLUMN_TYPE)
+    def _collect_binary_type_custom_data(
+        self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+    ) -> dict[str, any]:
+        column_collector = BinaryColumnCollector(clm_name, struct_field, values)
+        collected_data = column_collector.get_data()
+        return collected_data
+
+    @column_register(BOOLEAN_COLUMN_TYPE)
+    def _collect_boolean_type_custom_data(
+        self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+    ) -> dict[str, any]:
+        column_collector = BooleanColumnCollector(clm_name, struct_field, values)
+        collected_data = column_collector.get_data()
+        return collected_data
+
+    @column_register(DATE_COLUMN_TYPE)
+    def _collect_date_type_custom_data(
+        self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+    ) -> dict[str, any]:
+        column_collector = DateColumnCollector(clm_name, struct_field, values)
+        collected_data = column_collector.get_data()
+        return collected_data
+
+    @column_register(DAYTIMEINTERVAL_COLUMN_TYPE)
+    def _collect_day_time_interval_type_custom_data(
+        self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+    ) -> dict[str, any]:
+        column_collector = DayTimeIntervalColumnCollector(
+            clm_name, struct_field, values
+        )
+        collected_data = column_collector.get_data()
+        return collected_data
+
+    @column_register(DECIMAL_COLUMN_TYPE)
+    def _collect_decimal_type_custom_data(
+        self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+    ) -> dict[str, any]:
+        column_collector = DecimalColumnCollector(clm_name, struct_field, values)
+        collected_data = column_collector.get_data()
+        return collected_data
+
+    @column_register(MAP_COLUMN_TYPE)
+    def _collect_map_type_custom_data(
+        self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+    ) -> dict[str, any]:
+        column_collector = MapColumnCollector(clm_name, struct_field, values)
+        collected_data = column_collector.get_data()
+        return collected_data
+
+    @column_register(NULL_COLUMN_TYPE)
+    def _collect_null_type_custom_data(
+        self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+    ) -> dict[str, any]:
+        column_collector = NullColumnCollector(clm_name, struct_field, values)
+        collected_data = column_collector.get_data()
+        return collected_data
+
+    @column_register(
+        BYTE_COLUMN_TYPE,
+        SHORT_COLUMN_TYPE,
+        INTEGER_COLUMN_TYPE,
+        LONG_COLUMN_TYPE,
+        FLOAT_COLUMN_TYPE,
+        DOUBLE_COLUMN_TYPE,
+    )
+    def _collect_numeric_type_custom_data(
+        self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+    ) -> dict[str, any]:
+        column_collector = NumericColumnCollector(clm_name, struct_field, values)
+        collected_data = column_collector.get_data()
+        return collected_data
+
+    @column_register(STRING_COLUMN_TYPE)
+    def _collect_string_type_custom_data(
+        self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+    ) -> dict[str, any]:
+        column_collector = StringColumnCollector(clm_name, struct_field, values)
+        collected_data = column_collector.get_data()
+        return collected_data
+
+    @column_register(STRUCT_COLUMN_TYPE)
+    def _collect_struct_type_custom_data(
+        self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+    ) -> dict[str, any]:
+        column_collector = StructColumnCollector(clm_name, struct_field, values)
+        collected_data = column_collector.get_data()
+        return collected_data
+
+    @column_register(TIMESTAMP_COLUMN_TYPE)
+    def _collect_timestamp_type_custom_data(
+        self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+    ) -> dict[str, any]:
+        column_collector = TimestampColumnCollector(clm_name, struct_field, values)
+        collected_data = column_collector.get_data()
+        return collected_data
+
+    @column_register(TIMESTAMP_NTZ_COLUMN_TYPE)
+    def _collect_timestampntz_type_custom_data(
+        self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+    ) -> dict[str, any]:
+        column_collector = TimestampNTZColumnCollector(clm_name, struct_field, values)
+        collected_data = column_collector.get_data()
+        return collected_data
+
+    def collect_empty_custom_data(
+        self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+    ) -> dict[str, any]:
+        """Collect the data of a empty column.
+
+        Args:
+            clm_name (str): the name of the column.
+            struct_field (pyspark.sql.types.StructField): the struct field of the column type.
+            values (pyspark.sql.DataFrame): the column values as PySpark DataFrame.
+
+        Returns:
+            dict[str, any]: The data collected.
+
+        """
+        column_collector = EmptyColumnCollector(clm_name, struct_field, values)
+        collected_data = column_collector.get_data()
+        return collected_data
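The `collector_register` / `column_register` pair implements a simple registry-dispatch pattern: each handler method tags itself with the Spark type names it handles, and the class decorator scans the class once to build a type-to-method map that `collect_column` consults at runtime. A stripped-down, self-contained sketch of the same pattern (all names here are illustrative, not part of the package):

    # Standalone sketch of the registry-dispatch pattern; hypothetical names.
    def register_handlers(cls):
        # Build a key -> method-name map from tagged methods.
        cls._handlers = {}
        for name in dir(cls):
            method = getattr(cls, name)
            for key in getattr(method, "_keys", ()):
                cls._handlers[key] = name
        return cls

    def handles(*keys):
        # Tag a method with the keys it should be dispatched for.
        def wrapper(func):
            func._keys = keys
            return func
        return wrapper

    @register_handlers
    class Dispatcher:
        def dispatch(self, key, value):
            if key not in self._handlers:
                return None
            return getattr(self, self._handlers[key])(value)

        @handles("integer", "long")
        def _on_numeric(self, value):
            return {"type": "numeric", "value": value}

    print(Dispatcher().dispatch("long", 7))  # {'type': 'numeric', 'value': 7}

The design keeps dispatch data on the class itself, so adding a new column type means adding one decorated method; nothing else in the manager changes. Unregistered types fall through to an empty result, which matches `collect_column` returning `{}` above.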