snowpark-checkpoints-collectors 0.2.0rc1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. snowflake/snowpark_checkpoints_collector/__init__.py +30 -0
  2. snowflake/snowpark_checkpoints_collector/__version__.py +16 -0
  3. snowflake/snowpark_checkpoints_collector/collection_common.py +160 -0
  4. snowflake/snowpark_checkpoints_collector/collection_result/model/__init__.py +24 -0
  5. snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result.py +91 -0
  6. snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py +76 -0
  7. snowflake/snowpark_checkpoints_collector/column_collection/__init__.py +22 -0
  8. snowflake/snowpark_checkpoints_collector/column_collection/column_collector_manager.py +276 -0
  9. snowflake/snowpark_checkpoints_collector/column_collection/model/__init__.py +75 -0
  10. snowflake/snowpark_checkpoints_collector/column_collection/model/array_column_collector.py +113 -0
  11. snowflake/snowpark_checkpoints_collector/column_collection/model/binary_column_collector.py +87 -0
  12. snowflake/snowpark_checkpoints_collector/column_collection/model/boolean_column_collector.py +71 -0
  13. snowflake/snowpark_checkpoints_collector/column_collection/model/column_collector_base.py +95 -0
  14. snowflake/snowpark_checkpoints_collector/column_collection/model/date_column_collector.py +74 -0
  15. snowflake/snowpark_checkpoints_collector/column_collection/model/day_time_interval_column_collector.py +67 -0
  16. snowflake/snowpark_checkpoints_collector/column_collection/model/decimal_column_collector.py +92 -0
  17. snowflake/snowpark_checkpoints_collector/column_collection/model/empty_column_collector.py +88 -0
  18. snowflake/snowpark_checkpoints_collector/column_collection/model/map_column_collector.py +120 -0
  19. snowflake/snowpark_checkpoints_collector/column_collection/model/null_column_collector.py +49 -0
  20. snowflake/snowpark_checkpoints_collector/column_collection/model/numeric_column_collector.py +108 -0
  21. snowflake/snowpark_checkpoints_collector/column_collection/model/string_column_collector.py +70 -0
  22. snowflake/snowpark_checkpoints_collector/column_collection/model/struct_column_collector.py +102 -0
  23. snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_column_collector.py +75 -0
  24. snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_ntz_column_collector.py +75 -0
  25. snowflake/snowpark_checkpoints_collector/column_pandera_checks/__init__.py +20 -0
  26. snowflake/snowpark_checkpoints_collector/column_pandera_checks/pandera_column_checks_manager.py +241 -0
  27. snowflake/snowpark_checkpoints_collector/io_utils/__init__.py +26 -0
  28. snowflake/snowpark_checkpoints_collector/io_utils/io_default_strategy.py +61 -0
  29. snowflake/snowpark_checkpoints_collector/io_utils/io_env_strategy.py +142 -0
  30. snowflake/snowpark_checkpoints_collector/io_utils/io_file_manager.py +79 -0
  31. snowflake/snowpark_checkpoints_collector/singleton.py +23 -0
  32. snowflake/snowpark_checkpoints_collector/snow_connection_model/__init__.py +20 -0
  33. snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py +203 -0
  34. snowflake/snowpark_checkpoints_collector/summary_stats_collector.py +409 -0
  35. snowflake/snowpark_checkpoints_collector/utils/checkpoint_name_utils.py +53 -0
  36. snowflake/snowpark_checkpoints_collector/utils/extra_config.py +164 -0
  37. snowflake/snowpark_checkpoints_collector/utils/file_utils.py +137 -0
  38. snowflake/snowpark_checkpoints_collector/utils/logging_utils.py +67 -0
  39. snowflake/snowpark_checkpoints_collector/utils/telemetry.py +928 -0
  40. snowpark_checkpoints_collectors-0.3.0.dist-info/METADATA +159 -0
  41. snowpark_checkpoints_collectors-0.3.0.dist-info/RECORD +43 -0
  42. {snowpark_checkpoints_collectors-0.2.0rc1.dist-info → snowpark_checkpoints_collectors-0.3.0.dist-info}/licenses/LICENSE +0 -25
  43. snowpark_checkpoints_collectors-0.2.0rc1.dist-info/METADATA +0 -347
  44. snowpark_checkpoints_collectors-0.2.0rc1.dist-info/RECORD +0 -4
  45. {snowpark_checkpoints_collectors-0.2.0rc1.dist-info → snowpark_checkpoints_collectors-0.3.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,30 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+
+
+ # Add a NullHandler to prevent logging messages from being output to
+ # sys.stderr if no logging configuration is provided.
+ logging.getLogger(__name__).addHandler(logging.NullHandler())
+
+ # ruff: noqa: E402
+
+ __all__ = ["collect_dataframe_checkpoint", "CheckpointMode"]
+
+ from snowflake.snowpark_checkpoints_collector.collection_common import CheckpointMode
+ from snowflake.snowpark_checkpoints_collector.summary_stats_collector import (
+     collect_dataframe_checkpoint,
+ )
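
The new `__init__.py` re-exports the two public entry points of the collector. A minimal usage sketch follows; it is hedged, since the full `collect_dataframe_checkpoint` signature lives in `summary_stats_collector.py` and is not reproduced in this diff, so the keyword names beyond the DataFrame and checkpoint name are assumptions.

```python
# Hedged sketch of how the re-exported public API is typically called.
# Keyword names other than the DataFrame and checkpoint name are assumptions;
# see summary_stats_collector.py for the real signature.
from pyspark.sql import SparkSession

from snowflake.snowpark_checkpoints_collector import (
    CheckpointMode,
    collect_dataframe_checkpoint,
)

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])

# Collect an inferred schema contract for this DataFrame at a named checkpoint.
collect_dataframe_checkpoint(
    df,
    checkpoint_name="demo_initial_checkpoint",
    mode=CheckpointMode.SCHEMA,  # assumption: keyword matches the enum's purpose
)
```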
@@ -0,0 +1,16 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ __version__ = "0.3.0"
@@ -0,0 +1,160 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import locale
+
+ from enum import IntEnum
+
+
+ class CheckpointMode(IntEnum):
+
+     """Enum class representing the collection mode."""
+
+     SCHEMA = 1
+     """Collect automatic schema inference"""
+     DATAFRAME = 2
+     """Export DataFrame as Parquet file to Snowflake"""
+
+
+ # CONSTANTS
+ ARRAY_COLUMN_TYPE = "array"
+ BINARY_COLUMN_TYPE = "binary"
+ BOOLEAN_COLUMN_TYPE = "boolean"
+ BYTE_COLUMN_TYPE = "byte"
+ DATE_COLUMN_TYPE = "date"
+ DAYTIMEINTERVAL_COLUMN_TYPE = "daytimeinterval"
+ DECIMAL_COLUMN_TYPE = "decimal"
+ DOUBLE_COLUMN_TYPE = "double"
+ FLOAT_COLUMN_TYPE = "float"
+ INTEGER_COLUMN_TYPE = "integer"
+ LONG_COLUMN_TYPE = "long"
+ MAP_COLUMN_TYPE = "map"
+ NULL_COLUMN_TYPE = "void"
+ SHORT_COLUMN_TYPE = "short"
+ STRING_COLUMN_TYPE = "string"
+ STRUCT_COLUMN_TYPE = "struct"
+ TIMESTAMP_COLUMN_TYPE = "timestamp"
+ TIMESTAMP_NTZ_COLUMN_TYPE = "timestamp_ntz"
+
+ PANDAS_BOOLEAN_DTYPE = "bool"
+ PANDAS_DATETIME_DTYPE = "datetime64[ns]"
+ PANDAS_FLOAT_DTYPE = "float64"
+ PANDAS_INTEGER_DTYPE = "int64"
+ PANDAS_OBJECT_DTYPE = "object"
+ PANDAS_TIMEDELTA_DTYPE = "timedelta64[ns]"
+
+ NUMERIC_TYPE_COLLECTION = [
+     BYTE_COLUMN_TYPE,
+     DOUBLE_COLUMN_TYPE,
+     FLOAT_COLUMN_TYPE,
+     INTEGER_COLUMN_TYPE,
+     LONG_COLUMN_TYPE,
+     SHORT_COLUMN_TYPE,
+ ]
+
+ INTEGER_TYPE_COLLECTION = [
+     BYTE_COLUMN_TYPE,
+     INTEGER_COLUMN_TYPE,
+     LONG_COLUMN_TYPE,
+     SHORT_COLUMN_TYPE,
+ ]
+
+ PANDAS_OBJECT_TYPE_COLLECTION = [
+     STRING_COLUMN_TYPE,
+     ARRAY_COLUMN_TYPE,
+     MAP_COLUMN_TYPE,
+     NULL_COLUMN_TYPE,
+     STRUCT_COLUMN_TYPE,
+ ]
+
+ BETWEEN_CHECK_ERROR_MESSAGE_FORMAT = "Value must be between {} and {}"
+
+ # SCHEMA CONTRACT KEYS CONSTANTS
+ COLUMN_ALLOW_NULL_KEY = "allow_null"
+ COLUMN_COUNT_KEY = "rows_count"
+ COLUMN_DECIMAL_PRECISION_KEY = "decimal_precision"
+ COLUMN_FALSE_COUNT_KEY = "false_count"
+ COLUMN_FORMAT_KEY = "format"
+ COLUMN_IS_NULLABLE_KEY = "nullable"
+ COLUMN_IS_UNIQUE_SIZE_KEY = "is_unique_size"
+ COLUMN_KEY_TYPE_KEY = "key_type"
+ COLUMN_MARGIN_ERROR_KEY = "margin_error"
+ COLUMN_MAX_KEY = "max"
+ COLUMN_MAX_LENGTH_KEY = "max_length"
+ COLUMN_MAX_SIZE_KEY = "max_size"
+ COLUMN_MEAN_KEY = "mean"
+ COLUMN_MEAN_SIZE_KEY = "mean_size"
+ COLUMN_METADATA_KEY = "metadata"
+ COLUMN_MIN_KEY = "min"
+ COLUMN_MIN_LENGTH_KEY = "min_length"
+ COLUMN_MIN_SIZE_KEY = "min_size"
+ COLUMN_NAME_KEY = "name"
+ COLUMN_NULL_COUNT_KEY = "null_count"
+ COLUMN_NULL_VALUE_PROPORTION_KEY = "null_value_proportion"
+ COLUMN_ROWS_NOT_NULL_COUNT_KEY = "rows_not_null_count"
+ COLUMN_ROWS_NULL_COUNT_KEY = "rows_null_count"
+ COLUMN_SIZE_KEY = "size"
+ COLUMN_TRUE_COUNT_KEY = "true_count"
+ COLUMN_TYPE_KEY = "type"
+ COLUMN_VALUE_KEY = "value"
+ COLUMN_VALUE_TYPE_KEY = "value_type"
+ COLUMNS_KEY = "columns"
+
+ DATAFRAME_CUSTOM_DATA_KEY = "custom_data"
+ DATAFRAME_PANDERA_SCHEMA_KEY = "pandera_schema"
+
+ PANDERA_COLUMN_TYPE_KEY = "dtype"
+
+ CONTAINS_NULL_KEY = "containsNull"
+ ELEMENT_TYPE_KEY = "elementType"
+ FIELD_METADATA_KEY = "metadata"
+ FIELDS_KEY = "fields"
+ KEY_TYPE_KEY = "keyType"
+ NAME_KEY = "name"
+ VALUE_CONTAINS_NULL_KEY = "valueContainsNull"
+ VALUE_TYPE_KEY = "valueType"
+
+ # DIRECTORY AND FILE NAME CONSTANTS
+ SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME = "snowpark-checkpoints-output"
+ CHECKPOINT_JSON_OUTPUT_FILE_NAME_FORMAT = "{}.json"
+ CHECKPOINT_PARQUET_OUTPUT_FILE_NAME_FORMAT = "{}.parquet"
+ COLLECTION_RESULT_FILE_NAME = "checkpoint_collection_results.json"
+
+ # MISC KEYS
+ DECIMAL_TOKEN_KEY = "decimal_point"
+ DOT_PARQUET_EXTENSION = ".parquet"
+ DOT_IPYNB_EXTENSION = ".ipynb"
+ UNKNOWN_SOURCE_FILE = "unknown"
+ UNKNOWN_LINE_OF_CODE = -1
+ BACKSLASH_TOKEN = "\\"
+ SLASH_TOKEN = "/"
+ PYSPARK_NONE_SIZE_VALUE = -1
+ PANDAS_LONG_TYPE = "Int64"
+
+ # ENVIRONMENT VARIABLES
+ SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR = (
+     "SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH"
+ )
+
+
+ def get_decimal_token() -> str:
+     """Return the decimal token based on the local environment.
+
+     Returns:
+         str: The decimal token.
+
+     """
+     decimal_token = locale.localeconv()[DECIMAL_TOKEN_KEY]
+     return decimal_token
@@ -0,0 +1,24 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ __all__ = ["CollectionPointResult", "CollectionResult", "CollectionPointResultManager"]
+
+ from snowflake.snowpark_checkpoints_collector.collection_result.model.collection_point_result import (
+     CollectionPointResult,
+     CollectionResult,
+ )
+ from snowflake.snowpark_checkpoints_collector.collection_result.model.collection_point_result_manager import (
+     CollectionPointResultManager,
+ )
@@ -0,0 +1,91 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from datetime import datetime
+ from enum import Enum
+
+ from snowflake.snowpark_checkpoints_collector.utils import file_utils
+
+
+ TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"
+
+ TIMESTAMP_KEY = "timestamp"
+ FILE_KEY = "file"
+ RESULT_KEY = "result"
+ LINE_OF_CODE_KEY = "line_of_code"
+ CHECKPOINT_NAME_KEY = "checkpoint_name"
+
+
+ class CollectionResult(Enum):
+     FAIL = "FAIL"
+     PASS = "PASS"
+
+
+ class CollectionPointResult:
+
+     """Class for checkpoint collection results.
+
+     Attributes:
+         _timestamp (timestamp): the timestamp when collection started.
+         _file_path (str): the full path where checkpoint is.
+         _line_of_code (int): the line of code where the checkpoint is.
+         _checkpoint_name (str): the checkpoint name.
+         _result (CollectionResult): the result status of the checkpoint collection point.
+
+     """
+
+     def __init__(
+         self,
+         file_path: str,
+         line_of_code: int,
+         checkpoint_name: str,
+     ) -> None:
+         """Init CollectionPointResult.
+
+         Args:
+             file_path (str): the full path where checkpoint is.
+             line_of_code (int): the line of code where the checkpoint is.
+             checkpoint_name (str): the checkpoint name.
+
+         """
+         self._timestamp = datetime.now()
+         self._file_path = file_path
+         self._line_of_code = line_of_code
+         self._checkpoint_name = checkpoint_name
+         self._result = None
+
+     @property
+     def result(self):
+         """Get the result status of the checkpoint collection point."""
+         return self._result
+
+     @result.setter
+     def result(self, value):
+         """Set the result status of the checkpoint collection point."""
+         self._result = value
+
+     def get_collection_result_data(self) -> dict[str, any]:
+         """Get the results of the checkpoint collection point."""
+         timestamp_with_format = self._timestamp.strftime(TIMESTAMP_FORMAT)
+         relative_path = file_utils.get_relative_file_path(self._file_path)
+
+         collection_point_result = {
+             TIMESTAMP_KEY: timestamp_with_format,
+             FILE_KEY: relative_path,
+             LINE_OF_CODE_KEY: self._line_of_code,
+             CHECKPOINT_NAME_KEY: self._checkpoint_name,
+             RESULT_KEY: self.result.value,
+         }
+
+         return collection_point_result
@@ -0,0 +1,76 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import logging
+
+ from typing import Optional
+
+ from snowflake.snowpark_checkpoints_collector.collection_result.model import (
+     CollectionPointResult,
+ )
+ from snowflake.snowpark_checkpoints_collector.io_utils.io_file_manager import (
+     get_io_file_manager,
+ )
+ from snowflake.snowpark_checkpoints_collector.singleton import Singleton
+ from snowflake.snowpark_checkpoints_collector.utils import file_utils
+
+
+ RESULTS_KEY = "results"
+ LOGGER = logging.getLogger(__name__)
+
+
+ class CollectionPointResultManager(metaclass=Singleton):
+
+     """Class for manage the checkpoint collection results. It is a singleton.
+
+     Attributes:
+         result_collection (list[any]): the collection of the checkpoint results.
+         output_file_path (str): the full path of the output file.
+
+     """
+
+     def __init__(self, output_path: Optional[str] = None) -> None:
+         """Init CollectionPointResultManager."""
+         self.result_collection: list[any] = []
+         self.output_file_path = file_utils.get_output_file_path(output_path)
+
+     def add_result(self, result: CollectionPointResult) -> None:
+         """Add the CollectionPointResult result to the collection.
+
+         Args:
+             result (CollectionPointResult): the CollectionPointResult to add.
+
+         """
+         result_json = result.get_collection_result_data()
+         LOGGER.debug("Adding a new collection result: %s", result_json)
+         self.result_collection.append(result_json)
+         self._save_result()
+
+     def to_json(self) -> str:
+         """Convert to json the checkpoint results collected.
+
+         Returns:
+             str: the results as json string.
+
+         """
+         dict_object = {RESULTS_KEY: self.result_collection}
+         result_collection_json = json.dumps(dict_object)
+         return result_collection_json
+
+     def _save_result(self) -> None:
+         result_collection_json = self.to_json()
+         LOGGER.info("Saving collection results to '%s'", self.output_file_path)
+         get_io_file_manager().write(self.output_file_path, result_collection_json)
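
Taken together, each entry that `add_result` appends is the dict built by `get_collection_result_data`, and `_save_result` rewrites `checkpoint_collection_results.json` on every call. A sketch of the persisted shape, inferred from the keys defined in these two files; the concrete values below are illustrative only.

```python
# Illustrative shape of checkpoint_collection_results.json, inferred from the
# keys in collection_point_result.py and RESULTS_KEY above; values are made up.
example_results_file = {
    "results": [
        {
            "timestamp": "2025-01-01 12:00:00",
            "file": "demo/pipeline_job.py",
            "line_of_code": 42,
            "checkpoint_name": "demo_initial_checkpoint",
            "result": "PASS",
        }
    ]
}
```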
@@ -0,0 +1,22 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ __all__ = [
+     "ColumnCollectorManager",
+ ]
+
+ from snowflake.snowpark_checkpoints_collector.column_collection.column_collector_manager import (
+     ColumnCollectorManager,
+ )
@@ -0,0 +1,276 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+
+ from pyspark.sql import DataFrame as SparkDataFrame
+ from pyspark.sql.types import StructField
+
+ from snowflake.snowpark_checkpoints_collector.collection_common import (
+     ARRAY_COLUMN_TYPE,
+     BINARY_COLUMN_TYPE,
+     BOOLEAN_COLUMN_TYPE,
+     BYTE_COLUMN_TYPE,
+     DATE_COLUMN_TYPE,
+     DAYTIMEINTERVAL_COLUMN_TYPE,
+     DECIMAL_COLUMN_TYPE,
+     DOUBLE_COLUMN_TYPE,
+     FLOAT_COLUMN_TYPE,
+     INTEGER_COLUMN_TYPE,
+     LONG_COLUMN_TYPE,
+     MAP_COLUMN_TYPE,
+     NULL_COLUMN_TYPE,
+     SHORT_COLUMN_TYPE,
+     STRING_COLUMN_TYPE,
+     STRUCT_COLUMN_TYPE,
+     TIMESTAMP_COLUMN_TYPE,
+     TIMESTAMP_NTZ_COLUMN_TYPE,
+ )
+ from snowflake.snowpark_checkpoints_collector.column_collection.model import (
+     ArrayColumnCollector,
+     BinaryColumnCollector,
+     BooleanColumnCollector,
+     DateColumnCollector,
+     DayTimeIntervalColumnCollector,
+     DecimalColumnCollector,
+     EmptyColumnCollector,
+     MapColumnCollector,
+     NullColumnCollector,
+     NumericColumnCollector,
+     StringColumnCollector,
+     StructColumnCollector,
+     TimestampColumnCollector,
+     TimestampNTZColumnCollector,
+ )
+
+
+ LOGGER = logging.getLogger(__name__)
+
+
+ def collector_register(cls):
+     """Decorate a class with the collection type mechanism.
+
+     Args:
+         cls: The class to decorate.
+
+     Returns:
+         The class to decorate.
+
+     """
+     LOGGER.debug("Starting to register collectors from class %s", cls.__name__)
+     cls._collectors = {}
+     for method_name in dir(cls):
+         method = getattr(cls, method_name)
+         if hasattr(method, "_column_type"):
+             col_type_collection = method._column_type
+             for col_type in col_type_collection:
+                 cls._collectors[col_type] = method_name
+                 LOGGER.debug(
+                     "Registered collector '%s' for column type '%s'",
+                     method_name,
+                     col_type,
+                 )
+     return cls
+
+
+ def column_register(*args):
+     """Decorate a method to register it in the collection mechanism based on column type.
+
+     Args:
+         args: the column type to register.
+
+     Returns:
+         The wrapper.
+
+     """
+
+     def wrapper(func):
+         has_arguments = len(args) > 0
+         if has_arguments:
+             func._column_type = args
+         return func
+
+     return wrapper
+
+
+ @collector_register
+ class ColumnCollectorManager:
+
+     """Manage class for column collector based on type."""
+
+     def collect_column(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         """Collect the data of the column based on the column type.
+
+         Args:
+             clm_name (str): the name of the column.
+             struct_field (pyspark.sql.types.StructField): the struct field of the column type.
+             values (pyspark.sql.DataFrame): the column values as PySpark DataFrame.
+
+         Returns:
+             dict[str, any]: The data collected.
+
+         """
+         clm_type = struct_field.dataType.typeName()
+         if clm_type not in self._collectors:
+             LOGGER.debug(
+                 "No collectors found for column '%s' of type '%s'. Skipping collection for this column.",
+                 clm_name,
+                 clm_type,
+             )
+             return {}
+
+         func_name = self._collectors[clm_type]
+         func = getattr(self, func_name)
+         LOGGER.debug(
+             "Collecting custom data for column '%s' of type '%s' using collector method '%s'",
+             clm_name,
+             clm_type,
+             func_name,
+         )
+         data = func(clm_name, struct_field, values)
+         return data
+
+     @column_register(ARRAY_COLUMN_TYPE)
+     def _collect_array_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = ArrayColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(BINARY_COLUMN_TYPE)
+     def _collect_binary_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = BinaryColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(BOOLEAN_COLUMN_TYPE)
+     def _collect_boolean_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = BooleanColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(DATE_COLUMN_TYPE)
+     def _collect_date_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = DateColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(DAYTIMEINTERVAL_COLUMN_TYPE)
+     def _collect_day_time_interval_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = DayTimeIntervalColumnCollector(
+             clm_name, struct_field, values
+         )
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(DECIMAL_COLUMN_TYPE)
+     def _collect_decimal_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = DecimalColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(MAP_COLUMN_TYPE)
+     def _collect_map_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = MapColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(NULL_COLUMN_TYPE)
+     def _collect_null_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = NullColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(
+         BYTE_COLUMN_TYPE,
+         SHORT_COLUMN_TYPE,
+         INTEGER_COLUMN_TYPE,
+         LONG_COLUMN_TYPE,
+         FLOAT_COLUMN_TYPE,
+         DOUBLE_COLUMN_TYPE,
+     )
+     def _collect_numeric_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = NumericColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(STRING_COLUMN_TYPE)
+     def _collect_string_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = StringColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(STRUCT_COLUMN_TYPE)
+     def _collect_struct_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = StructColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(TIMESTAMP_COLUMN_TYPE)
+     def _collect_timestamp_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = TimestampColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(TIMESTAMP_NTZ_COLUMN_TYPE)
+     def _collect_timestampntz_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = TimestampNTZColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     def collect_empty_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         """Collect the data of a empty column.
+
+         Args:
+             clm_name (str): the name of the column.
+             struct_field (pyspark.sql.types.StructField): the struct field of the column type.
+             values (pyspark.sql.DataFrame): the column values as PySpark DataFrame.
+
+         Returns:
+             dict[str, any]: The data collected.
+
+         """
+         column_collector = EmptyColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
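
The `collector_register` / `column_register` pair added here implements a small dispatch-by-type registry: the method decorator tags each handler with the column type names it accepts, and the class decorator walks `dir(cls)` once to build the `_collectors` lookup that `collect_column` consults. A stripped-down standalone sketch of the same pattern follows; the names are illustrative and independent of PySpark, not part of the package.

```python
# Minimal standalone version of the registration pattern used above.
# Names here are illustrative; the real code dispatches on Spark type names.
def collector_register(cls):
    # Build a {column_type: handler_method_name} registry on the class.
    cls._collectors = {}
    for method_name in dir(cls):
        method = getattr(cls, method_name)
        if hasattr(method, "_column_type"):
            for col_type in method._column_type:
                cls._collectors[col_type] = method_name
    return cls


def column_register(*column_types):
    # Tag a method with the column types it handles.
    def wrapper(func):
        if column_types:
            func._column_type = column_types
        return func

    return wrapper


@collector_register
class DemoManager:
    def collect(self, col_type: str) -> str:
        handler_name = self._collectors.get(col_type)
        return getattr(self, handler_name)() if handler_name else "unsupported"

    @column_register("string")
    def _collect_string(self) -> str:
        return "string handler"


print(DemoManager().collect("string"))  # -> "string handler"
print(DemoManager().collect("binary"))  # -> "unsupported"
```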