snowpark-checkpoints-collectors 0.1.0rc2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the public registry.
Files changed (39)
  1. snowflake/snowpark_checkpoints_collector/__init__.py +22 -0
  2. snowflake/snowpark_checkpoints_collector/collection_common.py +160 -0
  3. snowflake/snowpark_checkpoints_collector/collection_result/model/__init__.py +24 -0
  4. snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result.py +91 -0
  5. snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py +69 -0
  6. snowflake/snowpark_checkpoints_collector/column_collection/__init__.py +22 -0
  7. snowflake/snowpark_checkpoints_collector/column_collection/column_collector_manager.py +253 -0
  8. snowflake/snowpark_checkpoints_collector/column_collection/model/__init__.py +75 -0
  9. snowflake/snowpark_checkpoints_collector/column_collection/model/array_column_collector.py +113 -0
  10. snowflake/snowpark_checkpoints_collector/column_collection/model/binary_column_collector.py +87 -0
  11. snowflake/snowpark_checkpoints_collector/column_collection/model/boolean_column_collector.py +71 -0
  12. snowflake/snowpark_checkpoints_collector/column_collection/model/column_collector_base.py +95 -0
  13. snowflake/snowpark_checkpoints_collector/column_collection/model/date_column_collector.py +74 -0
  14. snowflake/snowpark_checkpoints_collector/column_collection/model/day_time_interval_column_collector.py +67 -0
  15. snowflake/snowpark_checkpoints_collector/column_collection/model/decimal_column_collector.py +92 -0
  16. snowflake/snowpark_checkpoints_collector/column_collection/model/empty_column_collector.py +88 -0
  17. snowflake/snowpark_checkpoints_collector/column_collection/model/map_column_collector.py +120 -0
  18. snowflake/snowpark_checkpoints_collector/column_collection/model/null_column_collector.py +49 -0
  19. snowflake/snowpark_checkpoints_collector/column_collection/model/numeric_column_collector.py +108 -0
  20. snowflake/snowpark_checkpoints_collector/column_collection/model/string_column_collector.py +70 -0
  21. snowflake/snowpark_checkpoints_collector/column_collection/model/struct_column_collector.py +102 -0
  22. snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_column_collector.py +75 -0
  23. snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_ntz_column_collector.py +75 -0
  24. snowflake/snowpark_checkpoints_collector/column_pandera_checks/__init__.py +20 -0
  25. snowflake/snowpark_checkpoints_collector/column_pandera_checks/pandera_column_checks_manager.py +223 -0
  26. snowflake/snowpark_checkpoints_collector/singleton.py +23 -0
  27. snowflake/snowpark_checkpoints_collector/snow_connection_model/__init__.py +20 -0
  28. snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py +172 -0
  29. snowflake/snowpark_checkpoints_collector/summary_stats_collector.py +366 -0
  30. snowflake/snowpark_checkpoints_collector/utils/checkpoint_name_utils.py +53 -0
  31. snowflake/snowpark_checkpoints_collector/utils/extra_config.py +112 -0
  32. snowflake/snowpark_checkpoints_collector/utils/file_utils.py +132 -0
  33. snowflake/snowpark_checkpoints_collector/utils/telemetry.py +889 -0
  34. snowpark_checkpoints_collectors-0.1.1.dist-info/METADATA +143 -0
  35. snowpark_checkpoints_collectors-0.1.1.dist-info/RECORD +37 -0
  36. {snowpark_checkpoints_collectors-0.1.0rc2.dist-info → snowpark_checkpoints_collectors-0.1.1.dist-info}/licenses/LICENSE +0 -25
  37. snowpark_checkpoints_collectors-0.1.0rc2.dist-info/METADATA +0 -347
  38. snowpark_checkpoints_collectors-0.1.0rc2.dist-info/RECORD +0 -4
  39. {snowpark_checkpoints_collectors-0.1.0rc2.dist-info → snowpark_checkpoints_collectors-0.1.1.dist-info}/WHEEL +0 -0
snowflake/snowpark_checkpoints_collector/__init__.py
@@ -0,0 +1,22 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ __all__ = ["collect_dataframe_checkpoint", "CheckpointMode"]
+
+ from snowflake.snowpark_checkpoints_collector.summary_stats_collector import (
+     collect_dataframe_checkpoint,
+ )
+
+ from snowflake.snowpark_checkpoints_collector.collection_common import CheckpointMode
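These two exports are the package's entire public surface. A minimal usage sketch follows; the argument names are an assumption, since the signature of collect_dataframe_checkpoint lives in summary_stats_collector.py and is not reproduced in this hunk:

from pyspark.sql import SparkSession

from snowflake.snowpark_checkpoints_collector import (
    CheckpointMode,
    collect_dataframe_checkpoint,
)

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])

# Assumed call shape: infer and record the schema of df under the
# checkpoint name "demo-initial-checkpoint".
collect_dataframe_checkpoint(
    df,
    checkpoint_name="demo-initial-checkpoint",
    mode=CheckpointMode.SCHEMA,
)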
snowflake/snowpark_checkpoints_collector/collection_common.py
@@ -0,0 +1,160 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import locale
+
+ from enum import IntEnum
+
+
+ class CheckpointMode(IntEnum):
+
+     """Enum class representing the collection mode."""
+
+     SCHEMA = 1
+     """Collect automatic schema inference"""
+     DATAFRAME = 2
+     """Export DataFrame as Parquet file to Snowflake"""
+
+
+ # CONSTANTS
+ ARRAY_COLUMN_TYPE = "array"
+ BINARY_COLUMN_TYPE = "binary"
+ BOOLEAN_COLUMN_TYPE = "boolean"
+ BYTE_COLUMN_TYPE = "byte"
+ DATE_COLUMN_TYPE = "date"
+ DAYTIMEINTERVAL_COLUMN_TYPE = "daytimeinterval"
+ DECIMAL_COLUMN_TYPE = "decimal"
+ DOUBLE_COLUMN_TYPE = "double"
+ FLOAT_COLUMN_TYPE = "float"
+ INTEGER_COLUMN_TYPE = "integer"
+ LONG_COLUMN_TYPE = "long"
+ MAP_COLUMN_TYPE = "map"
+ NULL_COLUMN_TYPE = "void"
+ SHORT_COLUMN_TYPE = "short"
+ STRING_COLUMN_TYPE = "string"
+ STRUCT_COLUMN_TYPE = "struct"
+ TIMESTAMP_COLUMN_TYPE = "timestamp"
+ TIMESTAMP_NTZ_COLUMN_TYPE = "timestamp_ntz"
+
+ PANDAS_BOOLEAN_DTYPE = "bool"
+ PANDAS_DATETIME_DTYPE = "datetime64[ns]"
+ PANDAS_FLOAT_DTYPE = "float64"
+ PANDAS_INTEGER_DTYPE = "int64"
+ PANDAS_OBJECT_DTYPE = "object"
+ PANDAS_TIMEDELTA_DTYPE = "timedelta64[ns]"
+
+ NUMERIC_TYPE_COLLECTION = [
+     BYTE_COLUMN_TYPE,
+     DOUBLE_COLUMN_TYPE,
+     FLOAT_COLUMN_TYPE,
+     INTEGER_COLUMN_TYPE,
+     LONG_COLUMN_TYPE,
+     SHORT_COLUMN_TYPE,
+ ]
+
+ INTEGER_TYPE_COLLECTION = [
+     BYTE_COLUMN_TYPE,
+     INTEGER_COLUMN_TYPE,
+     LONG_COLUMN_TYPE,
+     SHORT_COLUMN_TYPE,
+ ]
+
+ PANDAS_OBJECT_TYPE_COLLECTION = [
+     STRING_COLUMN_TYPE,
+     ARRAY_COLUMN_TYPE,
+     MAP_COLUMN_TYPE,
+     NULL_COLUMN_TYPE,
+     STRUCT_COLUMN_TYPE,
+ ]
+
+ BETWEEN_CHECK_ERROR_MESSAGE_FORMAT = "Value must be between {} and {}"
+
+ # SCHEMA CONTRACT KEYS CONSTANTS
+ COLUMN_ALLOW_NULL_KEY = "allow_null"
+ COLUMN_COUNT_KEY = "rows_count"
+ COLUMN_DECIMAL_PRECISION_KEY = "decimal_precision"
+ COLUMN_FALSE_COUNT_KEY = "false_count"
+ COLUMN_FORMAT_KEY = "format"
+ COLUMN_IS_NULLABLE_KEY = "nullable"
+ COLUMN_IS_UNIQUE_SIZE_KEY = "is_unique_size"
+ COLUMN_KEY_TYPE_KEY = "key_type"
+ COLUMN_MARGIN_ERROR_KEY = "margin_error"
+ COLUMN_MAX_KEY = "max"
+ COLUMN_MAX_LENGTH_KEY = "max_length"
+ COLUMN_MAX_SIZE_KEY = "max_size"
+ COLUMN_MEAN_KEY = "mean"
+ COLUMN_MEAN_SIZE_KEY = "mean_size"
+ COLUMN_METADATA_KEY = "metadata"
+ COLUMN_MIN_KEY = "min"
+ COLUMN_MIN_LENGTH_KEY = "min_length"
+ COLUMN_MIN_SIZE_KEY = "min_size"
+ COLUMN_NAME_KEY = "name"
+ COLUMN_NULL_COUNT_KEY = "null_count"
+ COLUMN_NULL_VALUE_PROPORTION_KEY = "null_value_proportion"
+ COLUMN_ROWS_NOT_NULL_COUNT_KEY = "rows_not_null_count"
+ COLUMN_ROWS_NULL_COUNT_KEY = "rows_null_count"
+ COLUMN_SIZE_KEY = "size"
+ COLUMN_TRUE_COUNT_KEY = "true_count"
+ COLUMN_TYPE_KEY = "type"
+ COLUMN_VALUE_KEY = "value"
+ COLUMN_VALUE_TYPE_KEY = "value_type"
+ COLUMNS_KEY = "columns"
+
+ DATAFRAME_CUSTOM_DATA_KEY = "custom_data"
+ DATAFRAME_PANDERA_SCHEMA_KEY = "pandera_schema"
+
+ PANDERA_COLUMN_TYPE_KEY = "dtype"
+
+ CONTAINS_NULL_KEY = "containsNull"
+ ELEMENT_TYPE_KEY = "elementType"
+ FIELD_METADATA_KEY = "metadata"
+ FIELDS_KEY = "fields"
+ KEY_TYPE_KEY = "keyType"
+ NAME_KEY = "name"
+ VALUE_CONTAINS_NULL_KEY = "valueContainsNull"
+ VALUE_TYPE_KEY = "valueType"
+
+ # DIRECTORY AND FILE NAME CONSTANTS
+ SNOWPARK_CHECKPOINTS_OUTPUT_DIRECTORY_NAME = "snowpark-checkpoints-output"
+ CHECKPOINT_JSON_OUTPUT_FILE_NAME_FORMAT = "{}.json"
+ CHECKPOINT_PARQUET_OUTPUT_FILE_NAME_FORMAT = "{}.parquet"
+ COLLECTION_RESULT_FILE_NAME = "checkpoint_collection_results.json"
+
+ # MISC KEYS
+ DECIMAL_TOKEN_KEY = "decimal_point"
+ DOT_PARQUET_EXTENSION = ".parquet"
+ DOT_IPYNB_EXTENSION = ".ipynb"
+ UNKNOWN_SOURCE_FILE = "unknown"
+ UNKNOWN_LINE_OF_CODE = -1
+ BACKSLASH_TOKEN = "\\"
+ SLASH_TOKEN = "/"
+ PYSPARK_NONE_SIZE_VALUE = -1
+ PANDAS_LONG_TYPE = "Int64"
+
+ # ENVIRONMENT VARIABLES
+ SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR = (
+     "SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH"
+ )
+
+
+ def get_decimal_token() -> str:
+     """Return the decimal token based on the local environment.
+
+     Returns:
+         str: The decimal token.
+
+     """
+     decimal_token = locale.localeconv()[DECIMAL_TOKEN_KEY]
+     return decimal_token
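get_decimal_token simply surfaces the decimal separator of the active locale via the standard library. A stdlib-only illustration of the underlying call:

import locale

# The POSIX "C" locale uses "." as the decimal separator.
locale.setlocale(locale.LC_NUMERIC, "C")
print(locale.localeconv()["decimal_point"])  # -> "."

# A comma-based locale (e.g. "de_DE.UTF-8", if installed on the host)
# would report "," instead; named locales are platform-dependent.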
snowflake/snowpark_checkpoints_collector/collection_result/model/__init__.py
@@ -0,0 +1,24 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ __all__ = ["CollectionPointResult", "CollectionResult", "CollectionPointResultManager"]
+
+ from snowflake.snowpark_checkpoints_collector.collection_result.model.collection_point_result import (
+     CollectionPointResult,
+     CollectionResult,
+ )
+ from snowflake.snowpark_checkpoints_collector.collection_result.model.collection_point_result_manager import (
+     CollectionPointResultManager,
+ )
snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result.py
@@ -0,0 +1,91 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from datetime import datetime
+ from enum import Enum
+
+ from snowflake.snowpark_checkpoints_collector.utils import file_utils
+
+
+ TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"
+
+ TIMESTAMP_KEY = "timestamp"
+ FILE_KEY = "file"
+ RESULT_KEY = "result"
+ LINE_OF_CODE_KEY = "line_of_code"
+ CHECKPOINT_NAME_KEY = "checkpoint_name"
+
+
+ class CollectionResult(Enum):
+     FAIL = "FAIL"
+     PASS = "PASS"
+
+
+ class CollectionPointResult:
+
+     """Class for checkpoint collection results.
+
+     Attributes:
+         _timestamp (timestamp): the timestamp when collection started.
+         _file_path (str): the full path where checkpoint is.
+         _line_of_code (int): the line of code where the checkpoint is.
+         _checkpoint_name (str): the checkpoint name.
+         _result (CollectionResult): the result status of the checkpoint collection point.
+
+     """
+
+     def __init__(
+         self,
+         file_path: str,
+         line_of_code: int,
+         checkpoint_name: str,
+     ) -> None:
+         """Init CollectionPointResult.
+
+         Args:
+             file_path (str): the full path where checkpoint is.
+             line_of_code (int): the line of code where the checkpoint is.
+             checkpoint_name (str): the checkpoint name.
+
+         """
+         self._timestamp = datetime.now()
+         self._file_path = file_path
+         self._line_of_code = line_of_code
+         self._checkpoint_name = checkpoint_name
+         self._result = None
+
+     @property
+     def result(self):
+         """Get the result status of the checkpoint collection point."""
+         return self._result
+
+     @result.setter
+     def result(self, value):
+         """Set the result status of the checkpoint collection point."""
+         self._result = value
+
+     def get_collection_result_data(self) -> dict[str, any]:
+         """Get the results of the checkpoint collection point."""
+         timestamp_with_format = self._timestamp.strftime(TIMESTAMP_FORMAT)
+         relative_path = file_utils.get_relative_file_path(self._file_path)
+
+         collection_point_result = {
+             TIMESTAMP_KEY: timestamp_with_format,
+             FILE_KEY: relative_path,
+             LINE_OF_CODE_KEY: self._line_of_code,
+             CHECKPOINT_NAME_KEY: self._checkpoint_name,
+             RESULT_KEY: self.result.value,
+         }
+
+         return collection_point_result
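On its own the class is a small record type: the constructor stamps the current time, the result status is assigned once collection finishes, and get_collection_result_data flattens everything into a JSON-ready dict. A sketch of that lifecycle, with an illustrative file path (the relative-path conversion is delegated to file_utils):

result = CollectionPointResult(
    file_path="/work/pipeline/job.py",  # hypothetical script location
    line_of_code=42,
    checkpoint_name="demo-checkpoint",
)
result.result = CollectionResult.PASS

data = result.get_collection_result_data()
# e.g. {'timestamp': '2025-01-28 12:00:00', 'file': 'pipeline/job.py',
#       'line_of_code': 42, 'checkpoint_name': 'demo-checkpoint',
#       'result': 'PASS'}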
snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py
@@ -0,0 +1,69 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import json
+
+ from typing import Optional
+
+ from snowflake.snowpark_checkpoints_collector.collection_result.model import (
+     CollectionPointResult,
+ )
+ from snowflake.snowpark_checkpoints_collector.singleton import Singleton
+ from snowflake.snowpark_checkpoints_collector.utils import file_utils
+
+
+ RESULTS_KEY = "results"
+
+
+ class CollectionPointResultManager(metaclass=Singleton):
+
+     """Class for manage the checkpoint collection results. It is a singleton.
+
+     Attributes:
+         result_collection (list[any]): the collection of the checkpoint results.
+         output_file_path (str): the full path of the output file.
+
+     """
+
+     def __init__(self, output_path: Optional[str] = None) -> None:
+         """Init CollectionPointResultManager."""
+         self.result_collection: list[any] = []
+         self.output_file_path = file_utils.get_output_file_path(output_path)
+
+     def add_result(self, result: CollectionPointResult) -> None:
+         """Add the CollectionPointResult result to the collection.
+
+         Args:
+             result (CollectionPointResult): the CollectionPointResult to add.
+
+         """
+         result_json = result.get_collection_result_data()
+         self.result_collection.append(result_json)
+         self._save_result()
+
+     def to_json(self) -> str:
+         """Convert to json the checkpoint results collected.
+
+         Returns:
+             str: the results as json string.
+
+         """
+         dict_object = {RESULTS_KEY: self.result_collection}
+         result_collection_json = json.dumps(dict_object)
+         return result_collection_json
+
+     def _save_result(self) -> None:
+         result_collection_json = self.to_json()
+         with open(self.output_file_path, "w") as f:
+             f.write(result_collection_json)
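The metaclass=Singleton argument is what makes the manager process-wide. singleton.py appears in the file list above but not in this excerpt, so the sketch below is only an assumption of the conventional pattern such a metaclass implements:

class Singleton(type):
    # Cache one instance per class; later constructor calls return it.
    _instances: dict = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]

# Under this pattern every construction yields the same manager, so
# results collected anywhere in a run land in one output file:
#   assert CollectionPointResultManager() is CollectionPointResultManager()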
snowflake/snowpark_checkpoints_collector/column_collection/__init__.py
@@ -0,0 +1,22 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ __all__ = [
+     "ColumnCollectorManager",
+ ]
+
+ from snowflake.snowpark_checkpoints_collector.column_collection.column_collector_manager import (
+     ColumnCollectorManager,
+ )
snowflake/snowpark_checkpoints_collector/column_collection/column_collector_manager.py
@@ -0,0 +1,253 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from pyspark.sql import DataFrame as SparkDataFrame
+ from pyspark.sql.types import StructField
+
+ from snowflake.snowpark_checkpoints_collector.collection_common import (
+     ARRAY_COLUMN_TYPE,
+     BINARY_COLUMN_TYPE,
+     BOOLEAN_COLUMN_TYPE,
+     BYTE_COLUMN_TYPE,
+     DATE_COLUMN_TYPE,
+     DAYTIMEINTERVAL_COLUMN_TYPE,
+     DECIMAL_COLUMN_TYPE,
+     DOUBLE_COLUMN_TYPE,
+     FLOAT_COLUMN_TYPE,
+     INTEGER_COLUMN_TYPE,
+     LONG_COLUMN_TYPE,
+     MAP_COLUMN_TYPE,
+     NULL_COLUMN_TYPE,
+     SHORT_COLUMN_TYPE,
+     STRING_COLUMN_TYPE,
+     STRUCT_COLUMN_TYPE,
+     TIMESTAMP_COLUMN_TYPE,
+     TIMESTAMP_NTZ_COLUMN_TYPE,
+ )
+ from snowflake.snowpark_checkpoints_collector.column_collection.model import (
+     ArrayColumnCollector,
+     BinaryColumnCollector,
+     BooleanColumnCollector,
+     DateColumnCollector,
+     DayTimeIntervalColumnCollector,
+     DecimalColumnCollector,
+     EmptyColumnCollector,
+     MapColumnCollector,
+     NullColumnCollector,
+     NumericColumnCollector,
+     StringColumnCollector,
+     StructColumnCollector,
+     TimestampColumnCollector,
+     TimestampNTZColumnCollector,
+ )
+
+
+ def collector_register(cls):
+     """Decorate a class with the collection type mechanism.
+
+     Args:
+         cls: The class to decorate.
+
+     Returns:
+         The class to decorate.
+
+     """
+     cls._collectors = {}
+     for method_name in dir(cls):
+         method = getattr(cls, method_name)
+         if hasattr(method, "_column_type"):
+             col_type_collection = method._column_type
+             for col_type in col_type_collection:
+                 cls._collectors[col_type] = method_name
+     return cls
+
+
+ def column_register(*args):
+     """Decorate a method to register it in the collection mechanism based on column type.
+
+     Args:
+         args: the column type to register.
+
+     Returns:
+         The wrapper.
+
+     """
+
+     def wrapper(func):
+         has_arguments = len(args) > 0
+         if has_arguments:
+             func._column_type = args
+         return func
+
+     return wrapper
+
+
+ @collector_register
+ class ColumnCollectorManager:
+
+     """Manage class for column collector based on type."""
+
+     def collect_column(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         """Collect the data of the column based on the column type.
+
+         Args:
+             clm_name (str): the name of the column.
+             struct_field (pyspark.sql.types.StructField): the struct field of the column type.
+             values (pyspark.sql.DataFrame): the column values as PySpark DataFrame.
+
+         Returns:
+             dict[str, any]: The data collected.
+
+         """
+         clm_type = struct_field.dataType.typeName()
+         if clm_type not in self._collectors:
+             return {}
+
+         func_name = self._collectors[clm_type]
+         func = getattr(self, func_name)
+         data = func(clm_name, struct_field, values)
+         return data
+
+     @column_register(ARRAY_COLUMN_TYPE)
+     def _collect_array_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = ArrayColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(BINARY_COLUMN_TYPE)
+     def _collect_binary_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = BinaryColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(BOOLEAN_COLUMN_TYPE)
+     def _collect_boolean_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = BooleanColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(DATE_COLUMN_TYPE)
+     def _collect_date_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = DateColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(DAYTIMEINTERVAL_COLUMN_TYPE)
+     def _collect_day_time_interval_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = DayTimeIntervalColumnCollector(
+             clm_name, struct_field, values
+         )
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(DECIMAL_COLUMN_TYPE)
+     def _collect_decimal_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = DecimalColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(MAP_COLUMN_TYPE)
+     def _collect_map_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = MapColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(NULL_COLUMN_TYPE)
+     def _collect_null_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = NullColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(
+         BYTE_COLUMN_TYPE,
+         SHORT_COLUMN_TYPE,
+         INTEGER_COLUMN_TYPE,
+         LONG_COLUMN_TYPE,
+         FLOAT_COLUMN_TYPE,
+         DOUBLE_COLUMN_TYPE,
+     )
+     def _collect_numeric_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = NumericColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(STRING_COLUMN_TYPE)
+     def _collect_string_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = StringColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(STRUCT_COLUMN_TYPE)
+     def _collect_struct_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = StructColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(TIMESTAMP_COLUMN_TYPE)
+     def _collect_timestamp_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = TimestampColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     @column_register(TIMESTAMP_NTZ_COLUMN_TYPE)
+     def _collect_timestampntz_type_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         column_collector = TimestampNTZColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
+
+     def collect_empty_custom_data(
+         self, clm_name: str, struct_field: StructField, values: SparkDataFrame
+     ) -> dict[str, any]:
+         """Collect the data of a empty column.
+
+         Args:
+             clm_name (str): the name of the column.
+             struct_field (pyspark.sql.types.StructField): the struct field of the column type.
+             values (pyspark.sql.DataFrame): the column values as PySpark DataFrame.
+
+         Returns:
+             dict[str, any]: The data collected.
+
+         """
+         column_collector = EmptyColumnCollector(clm_name, struct_field, values)
+         collected_data = column_collector.get_data()
+         return collected_data
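The decorator pair at the top of this file is a type-keyed dispatch table: column_register tags each handler with the PySpark type names it serves, and collector_register walks the decorated class once to build the _collectors mapping that collect_column consults, with unregistered types falling through to an empty dict. A stripped-down, self-contained sketch of the same mechanism (demo names are hypothetical):

def collector_register(cls):
    # Build a {type_name: method_name} table from tagged methods.
    cls._collectors = {}
    for method_name in dir(cls):
        method = getattr(cls, method_name)
        if hasattr(method, "_column_type"):
            for col_type in method._column_type:
                cls._collectors[col_type] = method_name
    return cls


def column_register(*args):
    # Tag the handler with the column type names it serves.
    def wrapper(func):
        if len(args) > 0:
            func._column_type = args
        return func

    return wrapper


@collector_register
class DemoManager:
    @column_register("integer", "long")
    def _collect_numeric(self, clm_name):
        return {"name": clm_name, "kind": "numeric"}

    def collect_column(self, clm_type, clm_name):
        if clm_type not in self._collectors:
            return {}  # unregistered types yield no custom data
        return getattr(self, self._collectors[clm_type])(clm_name)


print(DemoManager().collect_column("long", "age"))
# -> {'name': 'age', 'kind': 'numeric'}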