dragon-ml-toolbox 14.7.0__py3-none-any.whl → 16.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/METADATA +9 -5
- dragon_ml_toolbox-16.2.1.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +726 -32
- ml_tools/ML_datasetmaster.py +235 -280
- ml_tools/ML_evaluation.py +160 -42
- ml_tools/ML_evaluation_multi.py +103 -35
- ml_tools/ML_inference.py +290 -208
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +219 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1342 -386
- ml_tools/ML_utilities.py +1 -1
- ml_tools/ML_vision_datasetmaster.py +120 -72
- ml_tools/ML_vision_evaluation.py +30 -6
- ml_tools/ML_vision_inference.py +129 -152
- ml_tools/ML_vision_models.py +1 -1
- ml_tools/ML_vision_transformers.py +121 -40
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +45 -0
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +1 -1
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.7.0.dist-info/RECORD +0 -49
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/top_level.txt +0 -0
ml_tools/ML_configuration.py
CHANGED
|
@@ -1,20 +1,45 @@
|
|
|
1
|
-
from typing import Optional
|
|
1
|
+
from typing import Union, Optional
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
2
4
|
from ._script_info import _script_info
|
|
5
|
+
from ._logger import _LOGGER
|
|
6
|
+
from .path_manager import sanitize_filename
|
|
3
7
|
|
|
4
8
|
|
|
5
9
|
__all__ = [
|
|
6
|
-
"
|
|
7
|
-
"
|
|
10
|
+
"RegressionMetricsFormat",
|
|
11
|
+
"MultiTargetRegressionMetricsFormat",
|
|
12
|
+
"BinaryClassificationMetricsFormat",
|
|
13
|
+
"MultiClassClassificationMetricsFormat",
|
|
14
|
+
"BinaryImageClassificationMetricsFormat",
|
|
15
|
+
"MultiClassImageClassificationMetricsFormat",
|
|
16
|
+
"MultiLabelBinaryClassificationMetricsFormat",
|
|
17
|
+
"BinarySegmentationMetricsFormat",
|
|
18
|
+
"MultiClassSegmentationMetricsFormat",
|
|
19
|
+
"SequenceValueMetricsFormat",
|
|
20
|
+
"SequenceSequenceMetricsFormat",
|
|
21
|
+
|
|
22
|
+
"FinalizeBinaryClassification",
|
|
23
|
+
"FinalizeBinarySegmentation",
|
|
24
|
+
"FinalizeBinaryImageClassification",
|
|
25
|
+
"FinalizeMultiClassClassification",
|
|
26
|
+
"FinalizeMultiClassImageClassification",
|
|
27
|
+
"FinalizeMultiClassSegmentation",
|
|
28
|
+
"FinalizeMultiLabelBinaryClassification",
|
|
29
|
+
"FinalizeMultiTargetRegression",
|
|
30
|
+
"FinalizeRegression",
|
|
31
|
+
"FinalizeObjectDetection",
|
|
32
|
+
"FinalizeSequencePrediction"
|
|
8
33
|
]
|
|
9
34
|
|
|
35
|
+
# --- Private base classes ---
|
|
10
36
|
|
|
11
|
-
class
|
|
37
|
+
class _BaseClassificationFormat:
|
|
12
38
|
"""
|
|
13
|
-
|
|
39
|
+
[PRIVATE] Base configuration for single-label classification metrics.
|
|
14
40
|
"""
|
|
15
41
|
def __init__(self,
|
|
16
|
-
cmap: str="
|
|
17
|
-
class_map: Optional[dict[str,int]]=None,
|
|
42
|
+
cmap: str="BuGn",
|
|
18
43
|
ROC_PR_line: str='darkorange',
|
|
19
44
|
calibration_bins: int=15,
|
|
20
45
|
font_size: int=16) -> None:
|
|
@@ -27,11 +52,6 @@ class ClassificationMetricsFormat:
|
|
|
27
52
|
- Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
|
|
28
53
|
- Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
|
|
29
54
|
|
|
30
|
-
class_map (dict[str,int] | None): A dictionary mapping
|
|
31
|
-
class string names to their integer indices (e.g., {'cat': 0, 'dog': 1}).
|
|
32
|
-
This is used to label the axes of the confusion matrix and classification
|
|
33
|
-
report correctly. Defaults to None.
|
|
34
|
-
|
|
35
55
|
ROC_PR_line (str): The color name or hex code for the line plotted
|
|
36
56
|
on the ROC and Precision-Recall curves. Defaults to 'darkorange'.
|
|
37
57
|
- Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
|
|
@@ -41,9 +61,16 @@ class ClassificationMetricsFormat:
|
|
|
41
61
|
creating the calibration (reliability) plot. Defaults to 15.
|
|
42
62
|
|
|
43
63
|
font_size (int): The base font size to apply to the plots. Defaults to 16.
|
|
64
|
+
|
|
65
|
+
<br>
|
|
66
|
+
|
|
67
|
+
### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
|
|
68
|
+
|
|
69
|
+
<br>
|
|
70
|
+
|
|
71
|
+
### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
|
|
44
72
|
"""
|
|
45
73
|
self.cmap = cmap
|
|
46
|
-
self.class_map = class_map
|
|
47
74
|
self.ROC_PR_line = ROC_PR_line
|
|
48
75
|
self.calibration_bins = calibration_bins
|
|
49
76
|
self.font_size = font_size
|
|
@@ -51,58 +78,725 @@ class ClassificationMetricsFormat:
|
|
|
51
78
|
def __repr__(self) -> str:
|
|
52
79
|
parts = [
|
|
53
80
|
f"cmap='{self.cmap}'",
|
|
54
|
-
f"class_map={self.class_map}",
|
|
55
81
|
f"ROC_PR_line='{self.ROC_PR_line}'",
|
|
56
82
|
f"calibration_bins={self.calibration_bins}",
|
|
57
83
|
f"font_size={self.font_size}"
|
|
58
84
|
]
|
|
59
|
-
return f"
|
|
85
|
+
return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
60
86
|
|
|
61
87
|
|
|
62
|
-
class
|
|
88
|
+
class _BaseMultiLabelFormat:
|
|
63
89
|
"""
|
|
64
|
-
|
|
90
|
+
[PRIVATE] Base configuration for multi-label binary classification metrics.
|
|
65
91
|
"""
|
|
66
92
|
def __init__(self,
|
|
67
|
-
|
|
93
|
+
cmap: str = "BuGn",
|
|
68
94
|
ROC_PR_line: str='darkorange',
|
|
69
|
-
cmap: str = "Blues",
|
|
70
95
|
font_size: int = 16) -> None:
|
|
71
96
|
"""
|
|
72
97
|
Initializes the formatting configuration for multi-label classification metrics.
|
|
73
98
|
|
|
74
99
|
Args:
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
100
|
+
cmap (str): The matplotlib colormap name for the per-label
|
|
101
|
+
confusion matrices. Defaults to "Blues".
|
|
102
|
+
- Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
|
|
103
|
+
- Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
|
|
104
|
+
|
|
79
105
|
ROC_PR_line (str): The color name or hex code for the line plotted
|
|
80
106
|
on the ROC and Precision-Recall curves (one for each label).
|
|
81
107
|
Defaults to 'darkorange'.
|
|
82
108
|
- Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
|
|
83
109
|
- Hex codes: '#FF6347', '#4682B4'
|
|
84
110
|
|
|
85
|
-
cmap (str): The matplotlib colormap name for the per-label
|
|
86
|
-
confusion matrices. Defaults to "Blues".
|
|
87
|
-
- Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
|
|
88
|
-
- Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
|
|
89
|
-
|
|
90
111
|
font_size (int): The base font size to apply to the plots. Defaults to 16.
|
|
112
|
+
|
|
113
|
+
<br>
|
|
114
|
+
|
|
115
|
+
### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
|
|
116
|
+
|
|
117
|
+
<br>
|
|
118
|
+
|
|
119
|
+
### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
|
|
91
120
|
"""
|
|
92
|
-
self.threshold = threshold
|
|
93
121
|
self.cmap = cmap
|
|
94
122
|
self.ROC_PR_line = ROC_PR_line
|
|
95
123
|
self.font_size = font_size
|
|
96
124
|
|
|
97
125
|
def __repr__(self) -> str:
|
|
98
126
|
parts = [
|
|
99
|
-
f"threshold={self.threshold}",
|
|
100
|
-
f"ROC_PR_line='{self.ROC_PR_line}'",
|
|
101
127
|
f"cmap='{self.cmap}'",
|
|
128
|
+
f"ROC_PR_line='{self.ROC_PR_line}'",
|
|
129
|
+
f"font_size={self.font_size}"
|
|
130
|
+
]
|
|
131
|
+
return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class _BaseRegressionFormat:
|
|
135
|
+
"""
|
|
136
|
+
[PRIVATE] Base configuration for regression metrics.
|
|
137
|
+
"""
|
|
138
|
+
def __init__(self,
|
|
139
|
+
font_size: int=16,
|
|
140
|
+
scatter_color: str='tab:blue',
|
|
141
|
+
scatter_alpha: float=0.6,
|
|
142
|
+
ideal_line_color: str='k',
|
|
143
|
+
residual_line_color: str='red',
|
|
144
|
+
hist_bins: Union[int, str] = 'auto') -> None:
|
|
145
|
+
"""
|
|
146
|
+
Initializes the formatting configuration for regression metrics.
|
|
147
|
+
|
|
148
|
+
Args:
|
|
149
|
+
font_size (int): The base font size to apply to the plots. Defaults to 16.
|
|
150
|
+
scatter_color (str): Matplotlib color for the scatter plot points. Defaults to 'tab:blue'.
|
|
151
|
+
- Common color names: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
|
|
152
|
+
scatter_alpha (float): Alpha transparency for scatter plot points. Defaults to 0.6.
|
|
153
|
+
ideal_line_color (str): Matplotlib color for the 'ideal' y=x line in the
|
|
154
|
+
True vs. Predicted plot. Defaults to 'k' (black).
|
|
155
|
+
- Common color names: 'k', 'red', 'darkgrey', '#FF6347'
|
|
156
|
+
residual_line_color (str): Matplotlib color for the y=0 line in the
|
|
157
|
+
Residual plot. Defaults to 'red'.
|
|
158
|
+
- Common color names: 'red', 'blue', 'k', '#4682B4'
|
|
159
|
+
hist_bins (int | str): The number of bins for the residuals histogram.
|
|
160
|
+
Defaults to 'auto' to use seaborn's automatic bin selection.
|
|
161
|
+
- Options: 'auto', 'sqrt', 10, 20
|
|
162
|
+
|
|
163
|
+
<br>
|
|
164
|
+
|
|
165
|
+
### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
|
|
166
|
+
"""
|
|
167
|
+
self.font_size = font_size
|
|
168
|
+
self.scatter_color = scatter_color
|
|
169
|
+
self.scatter_alpha = scatter_alpha
|
|
170
|
+
self.ideal_line_color = ideal_line_color
|
|
171
|
+
self.residual_line_color = residual_line_color
|
|
172
|
+
self.hist_bins = hist_bins
|
|
173
|
+
|
|
174
|
+
def __repr__(self) -> str:
|
|
175
|
+
parts = [
|
|
176
|
+
f"font_size={self.font_size}",
|
|
177
|
+
f"scatter_color='{self.scatter_color}'",
|
|
178
|
+
f"scatter_alpha={self.scatter_alpha}",
|
|
179
|
+
f"ideal_line_color='{self.ideal_line_color}'",
|
|
180
|
+
f"residual_line_color='{self.residual_line_color}'",
|
|
181
|
+
f"hist_bins='{self.hist_bins}'"
|
|
182
|
+
]
|
|
183
|
+
return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class _BaseSegmentationFormat:
|
|
187
|
+
"""
|
|
188
|
+
[PRIVATE] Base configuration for segmentation metrics.
|
|
189
|
+
"""
|
|
190
|
+
def __init__(self,
|
|
191
|
+
heatmap_cmap: str = "BuGn",
|
|
192
|
+
cm_cmap: str = "Purples",
|
|
193
|
+
font_size: int = 16) -> None:
|
|
194
|
+
"""
|
|
195
|
+
Initializes the formatting configuration for segmentation metrics.
|
|
196
|
+
|
|
197
|
+
Args:
|
|
198
|
+
heatmap_cmap (str): The matplotlib colormap name for the per-class
|
|
199
|
+
metrics heatmap. Defaults to "viridis".
|
|
200
|
+
- Sequential options: 'viridis', 'plasma', 'inferno', 'cividis'
|
|
201
|
+
- Diverging options: 'coolwarm', 'bwr', 'seismic'
|
|
202
|
+
cm_cmap (str): The matplotlib colormap name for the pixel-level
|
|
203
|
+
confusion matrix. Defaults to "Blues".
|
|
204
|
+
- Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges'
|
|
205
|
+
font_size (int): The base font size to apply to the plots. Defaults to 16.
|
|
206
|
+
|
|
207
|
+
<br>
|
|
208
|
+
|
|
209
|
+
### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
|
|
210
|
+
"""
|
|
211
|
+
self.heatmap_cmap = heatmap_cmap
|
|
212
|
+
self.cm_cmap = cm_cmap
|
|
213
|
+
self.font_size = font_size
|
|
214
|
+
|
|
215
|
+
def __repr__(self) -> str:
|
|
216
|
+
parts = [
|
|
217
|
+
f"heatmap_cmap='{self.heatmap_cmap}'",
|
|
218
|
+
f"cm_cmap='{self.cm_cmap}'",
|
|
102
219
|
f"font_size={self.font_size}"
|
|
103
220
|
]
|
|
104
|
-
return f"
|
|
221
|
+
return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class _BaseSequenceValueFormat:
|
|
225
|
+
"""
|
|
226
|
+
[PRIVATE] Base configuration for sequence to value metrics.
|
|
227
|
+
"""
|
|
228
|
+
def __init__(self,
|
|
229
|
+
font_size: int=16,
|
|
230
|
+
scatter_color: str='tab:blue',
|
|
231
|
+
scatter_alpha: float=0.6,
|
|
232
|
+
ideal_line_color: str='k',
|
|
233
|
+
residual_line_color: str='red',
|
|
234
|
+
hist_bins: Union[int, str] = 'auto') -> None:
|
|
235
|
+
"""
|
|
236
|
+
Initializes the formatting configuration for sequence to value metrics.
|
|
237
|
+
|
|
238
|
+
Args:
|
|
239
|
+
font_size (int): The base font size to apply to the plots. Defaults to 16.
|
|
240
|
+
scatter_color (str): Matplotlib color for the scatter plot points. Defaults to 'tab:blue'.
|
|
241
|
+
- Common color names: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
|
|
242
|
+
scatter_alpha (float): Alpha transparency for scatter plot points. Defaults to 0.6.
|
|
243
|
+
ideal_line_color (str): Matplotlib color for the 'ideal' y=x line in the
|
|
244
|
+
True vs. Predicted plot. Defaults to 'k' (black).
|
|
245
|
+
- Common color names: 'k', 'red', 'darkgrey', '#FF6347'
|
|
246
|
+
residual_line_color (str): Matplotlib color for the y=0 line in the
|
|
247
|
+
Residual plot. Defaults to 'red'.
|
|
248
|
+
- Common color names: 'red', 'blue', 'k', '#4682B4'
|
|
249
|
+
hist_bins (int | str): The number of bins for the residuals histogram.
|
|
250
|
+
Defaults to 'auto' to use seaborn's automatic bin selection.
|
|
251
|
+
- Options: 'auto', 'sqrt', 10, 20
|
|
252
|
+
|
|
253
|
+
<br>
|
|
254
|
+
|
|
255
|
+
### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
|
|
256
|
+
"""
|
|
257
|
+
self.font_size = font_size
|
|
258
|
+
self.scatter_color = scatter_color
|
|
259
|
+
self.scatter_alpha = scatter_alpha
|
|
260
|
+
self.ideal_line_color = ideal_line_color
|
|
261
|
+
self.residual_line_color = residual_line_color
|
|
262
|
+
self.hist_bins = hist_bins
|
|
263
|
+
|
|
264
|
+
def __repr__(self) -> str:
|
|
265
|
+
parts = [
|
|
266
|
+
f"font_size={self.font_size}",
|
|
267
|
+
f"scatter_color='{self.scatter_color}'",
|
|
268
|
+
f"scatter_alpha={self.scatter_alpha}",
|
|
269
|
+
f"ideal_line_color='{self.ideal_line_color}'",
|
|
270
|
+
f"residual_line_color='{self.residual_line_color}'",
|
|
271
|
+
f"hist_bins='{self.hist_bins}'"
|
|
272
|
+
]
|
|
273
|
+
return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
class _BaseSequenceSequenceFormat:
|
|
277
|
+
"""
|
|
278
|
+
[PRIVATE] Base configuration for sequence-to-sequence metrics.
|
|
279
|
+
"""
|
|
280
|
+
def __init__(self,
|
|
281
|
+
font_size: int = 16,
|
|
282
|
+
plot_figsize: tuple[int, int] = (10, 6),
|
|
283
|
+
grid_style: str = '--',
|
|
284
|
+
rmse_color: str = 'tab:blue',
|
|
285
|
+
rmse_marker: str = 'o-',
|
|
286
|
+
mae_color: str = 'tab:orange',
|
|
287
|
+
mae_marker: str = 's--'):
|
|
288
|
+
"""
|
|
289
|
+
Initializes the formatting configuration for seq-to-seq metrics.
|
|
290
|
+
|
|
291
|
+
Args:
|
|
292
|
+
font_size (int): The base font size to apply to the plots. Defaults to 16.
|
|
293
|
+
plot_figsize (Tuple[int, int]): Figure size for the plot. Defaults to (10, 6).
|
|
294
|
+
grid_style (str): Matplotlib linestyle for the plot grid. Defaults to '--'.
|
|
295
|
+
- Options: '--' (dashed), ':' (dotted), '-.' (dash-dot), '-' (solid)
|
|
296
|
+
rmse_color (str): Matplotlib color for the RMSE line. Defaults to 'tab:blue'.
|
|
297
|
+
- Common color names: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
|
|
298
|
+
rmse_marker (str): Matplotlib marker style for the RMSE line. Defaults to 'o-'.
|
|
299
|
+
- Options: 'o-' (circle), 's--' (square), '^:' (triangle), 'x' (x marker)
|
|
300
|
+
mae_color (str): Matplotlib color for the MAE line. Defaults to 'tab:orange'.
|
|
301
|
+
- Common color names: 'tab:orange', 'purple', 'black', '#FF6347'
|
|
302
|
+
mae_marker (str): Matplotlib marker style for the MAE line. Defaults to 's--'.
|
|
303
|
+
- Options: 's--', 'o-', 'v:', '+' (plus marker)
|
|
304
|
+
|
|
305
|
+
<br>
|
|
306
|
+
|
|
307
|
+
### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
|
|
308
|
+
|
|
309
|
+
<br>
|
|
310
|
+
|
|
311
|
+
### [Matplotlib Linestyles](https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html)
|
|
312
|
+
|
|
313
|
+
<br>
|
|
314
|
+
|
|
315
|
+
### [Matplotlib Markers](https://matplotlib.org/stable/api/markers_api.html)
|
|
316
|
+
"""
|
|
317
|
+
self.font_size = font_size
|
|
318
|
+
self.plot_figsize = plot_figsize
|
|
319
|
+
self.grid_style = grid_style
|
|
320
|
+
self.rmse_color = rmse_color
|
|
321
|
+
self.rmse_marker = rmse_marker
|
|
322
|
+
self.mae_color = mae_color
|
|
323
|
+
self.mae_marker = mae_marker
|
|
324
|
+
|
|
325
|
+
def __repr__(self) -> str:
|
|
326
|
+
parts = [
|
|
327
|
+
f"font_size={self.font_size}",
|
|
328
|
+
f"plot_figsize={self.plot_figsize}",
|
|
329
|
+
f"grid_style='{self.grid_style}'",
|
|
330
|
+
f"rmse_color='{self.rmse_color}'",
|
|
331
|
+
f"mae_color='{self.mae_color}'"
|
|
332
|
+
]
|
|
333
|
+
return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
334
|
+
|
|
335
|
+
# --- Public API classes ---
|
|
336
|
+
|
|
337
|
+
# Regression
|
|
338
|
+
class RegressionMetricsFormat(_BaseRegressionFormat):
|
|
339
|
+
"""
|
|
340
|
+
Configuration for single-target regression.
|
|
341
|
+
"""
|
|
342
|
+
def __init__(self,
|
|
343
|
+
font_size: int=16,
|
|
344
|
+
scatter_color: str='tab:blue',
|
|
345
|
+
scatter_alpha: float=0.6,
|
|
346
|
+
ideal_line_color: str='k',
|
|
347
|
+
residual_line_color: str='red',
|
|
348
|
+
hist_bins: Union[int, str] = 'auto') -> None:
|
|
349
|
+
super().__init__(font_size=font_size,
|
|
350
|
+
scatter_color=scatter_color,
|
|
351
|
+
scatter_alpha=scatter_alpha,
|
|
352
|
+
ideal_line_color=ideal_line_color,
|
|
353
|
+
residual_line_color=residual_line_color,
|
|
354
|
+
hist_bins=hist_bins)
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
# Multitarget regression
|
|
358
|
+
class MultiTargetRegressionMetricsFormat(_BaseRegressionFormat):
|
|
359
|
+
"""
|
|
360
|
+
Configuration for multi-target regression.
|
|
361
|
+
"""
|
|
362
|
+
def __init__(self,
|
|
363
|
+
font_size: int=16,
|
|
364
|
+
scatter_color: str='tab:blue',
|
|
365
|
+
scatter_alpha: float=0.6,
|
|
366
|
+
ideal_line_color: str='k',
|
|
367
|
+
residual_line_color: str='red',
|
|
368
|
+
hist_bins: Union[int, str] = 'auto') -> None:
|
|
369
|
+
super().__init__(font_size=font_size,
|
|
370
|
+
scatter_color=scatter_color,
|
|
371
|
+
scatter_alpha=scatter_alpha,
|
|
372
|
+
ideal_line_color=ideal_line_color,
|
|
373
|
+
residual_line_color=residual_line_color,
|
|
374
|
+
hist_bins=hist_bins)
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
# Classification
|
|
378
|
+
class BinaryClassificationMetricsFormat(_BaseClassificationFormat):
|
|
379
|
+
"""
|
|
380
|
+
Configuration for binary classification.
|
|
381
|
+
"""
|
|
382
|
+
def __init__(self,
|
|
383
|
+
cmap: str="BuGn",
|
|
384
|
+
ROC_PR_line: str='darkorange',
|
|
385
|
+
calibration_bins: int=15,
|
|
386
|
+
font_size: int=16) -> None:
|
|
387
|
+
super().__init__(cmap=cmap,
|
|
388
|
+
ROC_PR_line=ROC_PR_line,
|
|
389
|
+
calibration_bins=calibration_bins,
|
|
390
|
+
font_size=font_size)
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
class MultiClassClassificationMetricsFormat(_BaseClassificationFormat):
|
|
394
|
+
"""
|
|
395
|
+
Configuration for multi-class classification.
|
|
396
|
+
"""
|
|
397
|
+
def __init__(self,
|
|
398
|
+
cmap: str="BuGn",
|
|
399
|
+
ROC_PR_line: str='darkorange',
|
|
400
|
+
calibration_bins: int=15,
|
|
401
|
+
font_size: int=16) -> None:
|
|
402
|
+
super().__init__(cmap=cmap,
|
|
403
|
+
ROC_PR_line=ROC_PR_line,
|
|
404
|
+
calibration_bins=calibration_bins,
|
|
405
|
+
font_size=font_size)
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
class BinaryImageClassificationMetricsFormat(_BaseClassificationFormat):
|
|
409
|
+
"""
|
|
410
|
+
Configuration for binary image classification.
|
|
411
|
+
"""
|
|
412
|
+
def __init__(self,
|
|
413
|
+
cmap: str="BuGn",
|
|
414
|
+
ROC_PR_line: str='darkorange',
|
|
415
|
+
calibration_bins: int=15,
|
|
416
|
+
font_size: int=16) -> None:
|
|
417
|
+
super().__init__(cmap=cmap,
|
|
418
|
+
ROC_PR_line=ROC_PR_line,
|
|
419
|
+
calibration_bins=calibration_bins,
|
|
420
|
+
font_size=font_size)
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
class MultiClassImageClassificationMetricsFormat(_BaseClassificationFormat):
|
|
424
|
+
"""
|
|
425
|
+
Configuration for multi-class image classification.
|
|
426
|
+
"""
|
|
427
|
+
def __init__(self,
|
|
428
|
+
cmap: str="BuGn",
|
|
429
|
+
ROC_PR_line: str='darkorange',
|
|
430
|
+
calibration_bins: int=15,
|
|
431
|
+
font_size: int=16) -> None:
|
|
432
|
+
super().__init__(cmap=cmap,
|
|
433
|
+
ROC_PR_line=ROC_PR_line,
|
|
434
|
+
calibration_bins=calibration_bins,
|
|
435
|
+
font_size=font_size)
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
# Multi-Label classification
|
|
439
|
+
class MultiLabelBinaryClassificationMetricsFormat(_BaseMultiLabelFormat):
|
|
440
|
+
"""
|
|
441
|
+
Configuration for multi-label binary classification.
|
|
442
|
+
"""
|
|
443
|
+
def __init__(self,
|
|
444
|
+
cmap: str = "BuGn",
|
|
445
|
+
ROC_PR_line: str='darkorange',
|
|
446
|
+
font_size: int = 16) -> None:
|
|
447
|
+
super().__init__(cmap=cmap,
|
|
448
|
+
ROC_PR_line=ROC_PR_line,
|
|
449
|
+
font_size=font_size)
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
# Segmentation
|
|
453
|
+
class BinarySegmentationMetricsFormat(_BaseSegmentationFormat):
|
|
454
|
+
"""
|
|
455
|
+
Configuration for binary segmentation.
|
|
456
|
+
"""
|
|
457
|
+
def __init__(self,
|
|
458
|
+
heatmap_cmap: str = "BuGn",
|
|
459
|
+
cm_cmap: str = "Purples",
|
|
460
|
+
font_size: int = 16) -> None:
|
|
461
|
+
super().__init__(heatmap_cmap=heatmap_cmap,
|
|
462
|
+
cm_cmap=cm_cmap,
|
|
463
|
+
font_size=font_size)
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
class MultiClassSegmentationMetricsFormat(_BaseSegmentationFormat):
|
|
467
|
+
"""
|
|
468
|
+
Configuration for multi-class segmentation.
|
|
469
|
+
"""
|
|
470
|
+
def __init__(self,
|
|
471
|
+
heatmap_cmap: str = "BuGn",
|
|
472
|
+
cm_cmap: str = "Purples",
|
|
473
|
+
font_size: int = 16) -> None:
|
|
474
|
+
super().__init__(heatmap_cmap=heatmap_cmap,
|
|
475
|
+
cm_cmap=cm_cmap,
|
|
476
|
+
font_size=font_size)
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
# Sequence
|
|
480
|
+
class SequenceValueMetricsFormat(_BaseSequenceValueFormat):
|
|
481
|
+
"""
|
|
482
|
+
Configuration for sequence-to-value prediction.
|
|
483
|
+
"""
|
|
484
|
+
def __init__(self,
|
|
485
|
+
font_size: int=16,
|
|
486
|
+
scatter_color: str='tab:blue',
|
|
487
|
+
scatter_alpha: float=0.6,
|
|
488
|
+
ideal_line_color: str='k',
|
|
489
|
+
residual_line_color: str='red',
|
|
490
|
+
hist_bins: Union[int, str] = 'auto') -> None:
|
|
491
|
+
super().__init__(font_size=font_size,
|
|
492
|
+
scatter_color=scatter_color,
|
|
493
|
+
scatter_alpha=scatter_alpha,
|
|
494
|
+
ideal_line_color=ideal_line_color,
|
|
495
|
+
residual_line_color=residual_line_color,
|
|
496
|
+
hist_bins=hist_bins)
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
class SequenceSequenceMetricsFormat(_BaseSequenceSequenceFormat):
|
|
500
|
+
"""
|
|
501
|
+
Configuration for sequence-to-sequence prediction.
|
|
502
|
+
"""
|
|
503
|
+
def __init__(self,
|
|
504
|
+
font_size: int = 16,
|
|
505
|
+
plot_figsize: tuple[int, int] = (10, 6),
|
|
506
|
+
grid_style: str = '--',
|
|
507
|
+
rmse_color: str = 'tab:blue',
|
|
508
|
+
rmse_marker: str = 'o-',
|
|
509
|
+
mae_color: str = 'tab:orange',
|
|
510
|
+
mae_marker: str = 's--'):
|
|
511
|
+
super().__init__(font_size=font_size,
|
|
512
|
+
plot_figsize=plot_figsize,
|
|
513
|
+
grid_style=grid_style,
|
|
514
|
+
rmse_color=rmse_color,
|
|
515
|
+
rmse_marker=rmse_marker,
|
|
516
|
+
mae_color=mae_color,
|
|
517
|
+
mae_marker=mae_marker)
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
# -------- Finalize classes --------
|
|
521
|
+
class _FinalizeModelTraining:
|
|
522
|
+
"""
|
|
523
|
+
Base class for finalizing model training.
|
|
524
|
+
|
|
525
|
+
This class is not intended to be instantiated directly. Instead, use one of its specific subclasses.
|
|
526
|
+
"""
|
|
527
|
+
def __init__(self,
|
|
528
|
+
filename: str,
|
|
529
|
+
) -> None:
|
|
530
|
+
self.filename = _validate_string(string=filename, attribute_name="filename", extension=".pth")
|
|
531
|
+
self.target_name: Optional[str] = None
|
|
532
|
+
self.target_names: Optional[list[str]] = None
|
|
533
|
+
self.classification_threshold: Optional[float] = None
|
|
534
|
+
self.class_map: Optional[dict[str,int]] = None
|
|
535
|
+
self.initial_sequence: Optional[np.ndarray] = None
|
|
536
|
+
self.sequence_length: Optional[int] = None
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
class FinalizeRegression(_FinalizeModelTraining):
|
|
540
|
+
"""Parameters for finalizing a single-target regression model."""
|
|
541
|
+
def __init__(self,
|
|
542
|
+
filename: str,
|
|
543
|
+
target_name: str,
|
|
544
|
+
) -> None:
|
|
545
|
+
"""Initializes the finalization parameters.
|
|
546
|
+
|
|
547
|
+
Args:
|
|
548
|
+
filename (str): The name of the file to be saved.
|
|
549
|
+
target_name (str): The name of the target variable.
|
|
550
|
+
"""
|
|
551
|
+
super().__init__(filename=filename)
|
|
552
|
+
self.target_name = _validate_string(string=target_name, attribute_name="Target name")
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
class FinalizeMultiTargetRegression(_FinalizeModelTraining):
|
|
556
|
+
"""Parameters for finalizing a multi-target regression model."""
|
|
557
|
+
def __init__(self,
|
|
558
|
+
filename: str,
|
|
559
|
+
target_names: list[str],
|
|
560
|
+
) -> None:
|
|
561
|
+
"""Initializes the finalization parameters.
|
|
562
|
+
|
|
563
|
+
Args:
|
|
564
|
+
filename (str): The name of the file to be saved.
|
|
565
|
+
target_names (list[str]): A list of names for the target variables.
|
|
566
|
+
"""
|
|
567
|
+
super().__init__(filename=filename)
|
|
568
|
+
safe_names = [_validate_string(string=target_name, attribute_name="All target names") for target_name in target_names]
|
|
569
|
+
self.target_names = safe_names
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
class FinalizeBinaryClassification(_FinalizeModelTraining):
|
|
573
|
+
"""Parameters for finalizing a binary classification model."""
|
|
574
|
+
def __init__(self,
|
|
575
|
+
filename: str,
|
|
576
|
+
target_name: str,
|
|
577
|
+
classification_threshold: float,
|
|
578
|
+
class_map: dict[str,int]
|
|
579
|
+
) -> None:
|
|
580
|
+
"""Initializes the finalization parameters.
|
|
581
|
+
|
|
582
|
+
Args:
|
|
583
|
+
filename (str): The name of the file to be saved.
|
|
584
|
+
target_name (str): The name of the target variable.
|
|
585
|
+
classification_threshold (float): The cutoff threshold for classifying as the positive class.
|
|
586
|
+
class_map (dict[str,int]): A dictionary mapping class names (str)
|
|
587
|
+
to their integer representations (e.g., {'cat': 0, 'dog': 1}).
|
|
588
|
+
"""
|
|
589
|
+
super().__init__(filename=filename)
|
|
590
|
+
self.target_name = _validate_string(string=target_name, attribute_name="Target name")
|
|
591
|
+
self.classification_threshold = _validate_threshold(classification_threshold)
|
|
592
|
+
self.class_map = _validate_class_map(class_map)
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
class FinalizeMultiClassClassification(_FinalizeModelTraining):
|
|
596
|
+
"""Parameters for finalizing a multi-class classification model."""
|
|
597
|
+
def __init__(self,
|
|
598
|
+
filename: str,
|
|
599
|
+
target_name: str,
|
|
600
|
+
class_map: dict[str,int]
|
|
601
|
+
) -> None:
|
|
602
|
+
"""Initializes the finalization parameters.
|
|
603
|
+
|
|
604
|
+
Args:
|
|
605
|
+
filename (str): The name of the file to be saved.
|
|
606
|
+
target_name (str): The name of the target variable.
|
|
607
|
+
class_map (dict[str,int]): A dictionary mapping class names (str)
|
|
608
|
+
to their integer representations (e.g., {'cat': 0, 'dog': 1}).
|
|
609
|
+
"""
|
|
610
|
+
super().__init__(filename=filename)
|
|
611
|
+
self.target_name = _validate_string(string=target_name, attribute_name="Target name")
|
|
612
|
+
self.class_map = _validate_class_map(class_map)
|
|
613
|
+
|
|
614
|
+
|
|
615
|
+
class FinalizeBinaryImageClassification(_FinalizeModelTraining):
|
|
616
|
+
"""Parameters for finalizing a binary image classification model."""
|
|
617
|
+
def __init__(self,
|
|
618
|
+
filename: str,
|
|
619
|
+
classification_threshold: float,
|
|
620
|
+
class_map: dict[str,int]
|
|
621
|
+
) -> None:
|
|
622
|
+
"""Initializes the finalization parameters.
|
|
623
|
+
|
|
624
|
+
Args:
|
|
625
|
+
filename (str): The name of the file to be saved.
|
|
626
|
+
classification_threshold (float): The cutoff threshold for
|
|
627
|
+
classifying as the positive class.
|
|
628
|
+
class_map (dict[str,int]): A dictionary mapping class names (str)
|
|
629
|
+
to their integer representations (e.g., {'cat': 0, 'dog': 1}).
|
|
630
|
+
"""
|
|
631
|
+
super().__init__(filename=filename)
|
|
632
|
+
self.classification_threshold = _validate_threshold(classification_threshold)
|
|
633
|
+
self.class_map = _validate_class_map(class_map)
|
|
634
|
+
|
|
635
|
+
|
|
636
|
+
class FinalizeMultiClassImageClassification(_FinalizeModelTraining):
|
|
637
|
+
"""Parameters for finalizing a multi-class image classification model."""
|
|
638
|
+
def __init__(self,
|
|
639
|
+
filename: str,
|
|
640
|
+
class_map: dict[str,int]
|
|
641
|
+
) -> None:
|
|
642
|
+
"""Initializes the finalization parameters.
|
|
643
|
+
|
|
644
|
+
Args:
|
|
645
|
+
filename (str): The name of the file to be saved.
|
|
646
|
+
class_map (dict[str,int]): A dictionary mapping class names (str)
|
|
647
|
+
to their integer representations (e.g., {'cat': 0, 'dog': 1}).
|
|
648
|
+
"""
|
|
649
|
+
super().__init__(filename=filename)
|
|
650
|
+
self.class_map = _validate_class_map(class_map)
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
class FinalizeMultiLabelBinaryClassification(_FinalizeModelTraining):
|
|
654
|
+
"""Parameters for finalizing a multi-label binary classification model."""
|
|
655
|
+
def __init__(self,
|
|
656
|
+
filename: str,
|
|
657
|
+
target_names: list[str],
|
|
658
|
+
classification_threshold: float,
|
|
659
|
+
) -> None:
|
|
660
|
+
"""Initializes the finalization parameters.
|
|
661
|
+
|
|
662
|
+
Args:
|
|
663
|
+
filename (str): The name of the file to be saved.
|
|
664
|
+
target_names (list[str]): A list of names for the target variables.
|
|
665
|
+
classification_threshold (float): The cutoff threshold for classifying as the positive class.
|
|
666
|
+
"""
|
|
667
|
+
super().__init__(filename=filename)
|
|
668
|
+
safe_names = [_validate_string(string=target_name, attribute_name="All target names") for target_name in target_names]
|
|
669
|
+
self.target_names = safe_names
|
|
670
|
+
self.classification_threshold = _validate_threshold(classification_threshold)
|
|
671
|
+
|
|
672
|
+
|
|
673
|
+
class FinalizeBinarySegmentation(_FinalizeModelTraining):
    """Parameters for finalizing a binary segmentation model."""
    def __init__(self,
                 filename: str,
                 class_map: dict[str,int],
                 classification_threshold: float,
                 ) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            class_map (dict[str,int]): A dictionary mapping class names (str)
                to their integer representations (e.g., {'background': 0, 'mask': 1}).
            classification_threshold (float): The cutoff threshold for classifying as the positive class (mask).
                Must be a float strictly between 0.0 and 1.0.
        """
        super().__init__(filename=filename)
        self.classification_threshold = _validate_threshold(classification_threshold)
        self.class_map = _validate_class_map(class_map)
|
|
689
|
+
|
|
690
|
+
|
|
691
|
+
class FinalizeMultiClassSegmentation(_FinalizeModelTraining):
    """Parameters for finalizing a multi-class segmentation model."""
    def __init__(self,
                 filename: str,
                 class_map: dict[str,int]
                 ) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            class_map (dict[str,int]): A dictionary mapping class names (str)
                to their integer representations (e.g., {'background': 0, 'road': 1}).
        """
        super().__init__(filename=filename)
        self.class_map = _validate_class_map(class_map)
|
|
704
|
+
|
|
705
|
+
|
|
706
|
+
class FinalizeObjectDetection(_FinalizeModelTraining):
    """Parameters for finalizing an object detection model."""
    def __init__(self,
                 filename: str,
                 class_map: dict[str,int]
                 ) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            class_map (dict[str,int]): A dictionary mapping class names (str)
                to their integer representations (e.g., {'car': 1, 'person': 2}).
        """
        super().__init__(filename=filename)
        self.class_map = _validate_class_map(class_map)
|
|
719
|
+
|
|
720
|
+
|
|
721
|
+
class FinalizeSequencePrediction(_FinalizeModelTraining):
    """Parameters for finalizing a sequence prediction model."""
    def __init__(self,
                 filename: str,
                 last_training_sequence: np.ndarray,
                 ) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            last_training_sequence (np.ndarray): The last sequence from the training data, needed to start predictions.
                Accepted shapes are (N,) and (1, N); the latter is flattened to (N,). N must be at least 1.

        Raises:
            TypeError: If `last_training_sequence` is not a numpy array.
            ValueError: If the array has any other shape (0-D, (N, 1), (N, M), 3D or more) or is empty.
        """
        super().__init__(filename=filename)

        if not isinstance(last_training_sequence, np.ndarray):
            _LOGGER.error(f"The last training sequence must be a 1D numpy array, got {type(last_training_sequence)}.")
            raise TypeError()

        if last_training_sequence.ndim == 1:
            # Already 1D, (N,). Used as-is.
            sequence = last_training_sequence
        elif last_training_sequence.ndim == 2 and last_training_sequence.shape[0] == 1:
            # Shape (1, N): flatten to (N,).
            sequence = last_training_sequence.flatten()
        else:
            # 0-D scalars, 2D arrays shaped (N, 1) / (N, M), and 3D+ arrays are all unsupported.
            _LOGGER.error(f"The last training sequence must be a 1D numpy array, got shape {last_training_sequence.shape}.")
            raise ValueError()

        # An empty seed sequence would leave nothing to start predictions from.
        if sequence.size == 0:
            _LOGGER.error("The last training sequence must not be empty.")
            raise ValueError()

        self.initial_sequence = sequence
        # Save the length of the validated 1D sequence
        self.sequence_length = len(sequence)
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
def _validate_string(string: str, attribute_name: str, extension: Optional[str]=None) -> str:
|
|
761
|
+
"""Helper for finalize classes"""
|
|
762
|
+
if not isinstance(string, str):
|
|
763
|
+
_LOGGER.error(f"{attribute_name} must be a string.")
|
|
764
|
+
raise TypeError()
|
|
765
|
+
|
|
766
|
+
if extension:
|
|
767
|
+
safe_name = sanitize_filename(string)
|
|
768
|
+
|
|
769
|
+
if not safe_name.endswith(extension):
|
|
770
|
+
safe_name += extension
|
|
771
|
+
else:
|
|
772
|
+
safe_name = string
|
|
773
|
+
|
|
774
|
+
return safe_name
|
|
775
|
+
|
|
776
|
+
def _validate_threshold(threshold: float):
|
|
777
|
+
"""Helper for finalize classes"""
|
|
778
|
+
if not isinstance(threshold, float):
|
|
779
|
+
_LOGGER.error(f"Classification threshold must be a float.")
|
|
780
|
+
raise TypeError()
|
|
781
|
+
elif threshold <= 0.0 or threshold >= 1.0:
|
|
782
|
+
_LOGGER.error(f"Classification threshold must be in the range [0.1, 0.9]")
|
|
783
|
+
raise ValueError()
|
|
784
|
+
|
|
785
|
+
return threshold
|
|
105
786
|
|
|
787
|
+
def _validate_class_map(map: dict[str,int]):
|
|
788
|
+
"""Helper for finalize classes"""
|
|
789
|
+
validated_map = None
|
|
790
|
+
if isinstance(map, dict):
|
|
791
|
+
if all( [isinstance(key, str) for key in map.keys()] ):
|
|
792
|
+
if all( [isinstance(val, str) for val in map.values()] ):
|
|
793
|
+
validated_map = map
|
|
794
|
+
|
|
795
|
+
if validated_map is None:
|
|
796
|
+
_LOGGER.error(f"Class map must be a dictionary of string keys and integer values.")
|
|
797
|
+
raise TypeError()
|
|
798
|
+
else:
|
|
799
|
+
return validated_map
|
|
106
800
|
|
|
107
801
|
def info():
    """Report this module's public API (the names in `__all__`) through the shared `_script_info` helper."""
    public_names = __all__
    _script_info(public_names)
|