dragon-ml-toolbox 14.7.0__py3-none-any.whl → 16.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +9 -5
  2. dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
  3. ml_tools/ETL_cleaning.py +20 -20
  4. ml_tools/ETL_engineering.py +23 -25
  5. ml_tools/GUI_tools.py +20 -20
  6. ml_tools/MICE_imputation.py +3 -3
  7. ml_tools/ML_callbacks.py +43 -26
  8. ml_tools/ML_configuration.py +704 -24
  9. ml_tools/ML_datasetmaster.py +235 -280
  10. ml_tools/ML_evaluation.py +144 -39
  11. ml_tools/ML_evaluation_multi.py +103 -35
  12. ml_tools/ML_inference.py +290 -208
  13. ml_tools/ML_models.py +13 -102
  14. ml_tools/ML_models_advanced.py +1 -1
  15. ml_tools/ML_optimization.py +12 -12
  16. ml_tools/ML_scaler.py +11 -11
  17. ml_tools/ML_sequence_datasetmaster.py +341 -0
  18. ml_tools/ML_sequence_evaluation.py +219 -0
  19. ml_tools/ML_sequence_inference.py +391 -0
  20. ml_tools/ML_sequence_models.py +139 -0
  21. ml_tools/ML_trainer.py +1342 -386
  22. ml_tools/ML_utilities.py +1 -1
  23. ml_tools/ML_vision_datasetmaster.py +120 -72
  24. ml_tools/ML_vision_evaluation.py +30 -6
  25. ml_tools/ML_vision_inference.py +129 -152
  26. ml_tools/ML_vision_models.py +1 -1
  27. ml_tools/ML_vision_transformers.py +121 -40
  28. ml_tools/PSO_optimization.py +6 -6
  29. ml_tools/SQL.py +4 -4
  30. ml_tools/{keys.py → _keys.py} +45 -0
  31. ml_tools/_schema.py +1 -1
  32. ml_tools/ensemble_evaluation.py +1 -1
  33. ml_tools/ensemble_inference.py +7 -33
  34. ml_tools/ensemble_learning.py +1 -1
  35. ml_tools/optimization_tools.py +2 -2
  36. ml_tools/path_manager.py +5 -5
  37. ml_tools/utilities.py +1 -2
  38. dragon_ml_toolbox-14.7.0.dist-info/RECORD +0 -49
  39. ml_tools/RNN_forecast.py +0 -56
  40. ml_tools/_ML_vision_recipe.py +0 -88
  41. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
  42. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
  43. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  44. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
@@ -1,20 +1,45 @@
1
- from typing import Optional
1
+ from typing import Union, Optional
2
+ import numpy as np
3
+
2
4
  from ._script_info import _script_info
5
+ from ._logger import _LOGGER
6
+ from .path_manager import sanitize_filename
3
7
 
4
8
 
5
9
  __all__ = [
6
- "ClassificationMetricsFormat",
7
- "MultiClassificationMetricsFormat"
10
+ "RegressionMetricsFormat",
11
+ "MultiTargetRegressionMetricsFormat",
12
+ "BinaryClassificationMetricsFormat",
13
+ "MultiClassClassificationMetricsFormat",
14
+ "BinaryImageClassificationMetricsFormat",
15
+ "MultiClassImageClassificationMetricsFormat",
16
+ "MultiLabelBinaryClassificationMetricsFormat",
17
+ "BinarySegmentationMetricsFormat",
18
+ "MultiClassSegmentationMetricsFormat",
19
+ "SequenceValueMetricsFormat",
20
+ "SequenceSequenceMetricsFormat",
21
+
22
+ "FinalizeBinaryClassification",
23
+ "FinalizeBinarySegmentation",
24
+ "FinalizeBinaryImageClassification",
25
+ "FinalizeMultiClassClassification",
26
+ "FinalizeMultiClassImageClassification",
27
+ "FinalizeMultiClassSegmentation",
28
+ "FinalizeMultiLabelBinaryClassification",
29
+ "FinalizeMultiTargetRegression",
30
+ "FinalizeRegression",
31
+ "FinalizeObjectDetection",
32
+ "FinalizeSequencePrediction"
8
33
  ]
9
34
 
35
+ # --- Private base classes ---
10
36
 
11
- class ClassificationMetricsFormat:
37
+ class _BaseClassificationFormat:
12
38
  """
13
- Optional configuration for classification tasks, use in the '.evaluate()' method of the MLTrainer.
39
+ [PRIVATE] Base configuration for single-label classification metrics.
14
40
  """
15
41
  def __init__(self,
16
42
  cmap: str="Blues",
17
- class_map: Optional[dict[str,int]]=None,
18
43
  ROC_PR_line: str='darkorange',
19
44
  calibration_bins: int=15,
20
45
  font_size: int=16) -> None:
@@ -27,11 +52,6 @@ class ClassificationMetricsFormat:
27
52
  - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
28
53
  - Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
29
54
 
30
- class_map (dict[str,int] | None): A dictionary mapping
31
- class string names to their integer indices (e.g., {'cat': 0, 'dog': 1}).
32
- This is used to label the axes of the confusion matrix and classification
33
- report correctly. Defaults to None.
34
-
35
55
  ROC_PR_line (str): The color name or hex code for the line plotted
36
56
  on the ROC and Precision-Recall curves. Defaults to 'darkorange'.
37
57
  - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
@@ -41,9 +61,12 @@ class ClassificationMetricsFormat:
41
61
  creating the calibration (reliability) plot. Defaults to 15.
42
62
 
43
63
  font_size (int): The base font size to apply to the plots. Defaults to 16.
64
+
65
+ <br>
66
+
67
+ ## [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
44
68
  """
45
69
  self.cmap = cmap
46
- self.class_map = class_map
47
70
  self.ROC_PR_line = ROC_PR_line
48
71
  self.calibration_bins = calibration_bins
49
72
  self.font_size = font_size
@@ -51,20 +74,18 @@ class ClassificationMetricsFormat:
51
74
  def __repr__(self) -> str:
52
75
  parts = [
53
76
  f"cmap='{self.cmap}'",
54
- f"class_map={self.class_map}",
55
77
  f"ROC_PR_line='{self.ROC_PR_line}'",
56
78
  f"calibration_bins={self.calibration_bins}",
57
79
  f"font_size={self.font_size}"
58
80
  ]
59
- return f"ClassificationMetricsFormat({', '.join(parts)})"
81
+ return f"{self.__class__.__name__}({', '.join(parts)})"
60
82
 
61
83
 
62
- class MultiClassificationMetricsFormat:
84
+ class _BaseMultiLabelFormat:
63
85
  """
64
- Optional configuration for multi-label classification tasks, use in the '.evaluate()' method of the MLTrainer.
86
+ [PRIVATE] Base configuration for multi-label binary classification metrics.
65
87
  """
66
88
  def __init__(self,
67
- threshold: float=0.5,
68
89
  ROC_PR_line: str='darkorange',
69
90
  cmap: str = "Blues",
70
91
  font_size: int = 16) -> None:
@@ -72,10 +93,6 @@ class MultiClassificationMetricsFormat:
72
93
  Initializes the formatting configuration for multi-label classification metrics.
73
94
 
74
95
  Args:
75
- threshold (float): The probability threshold (0.0 to 1.0) used
76
- to convert sigmoid outputs into binary (0 or 1) predictions for
77
- calculating the confusion matrix and overall metrics. Defaults to 0.5.
78
-
79
96
  ROC_PR_line (str): The color name or hex code for the line plotted
80
97
  on the ROC and Precision-Recall curves (one for each label).
81
98
  Defaults to 'darkorange'.
@@ -88,21 +105,684 @@ class MultiClassificationMetricsFormat:
88
105
  - Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
89
106
 
90
107
  font_size (int): The base font size to apply to the plots. Defaults to 16.
108
+
109
+ <br>
110
+
111
+ ## [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
91
112
  """
92
- self.threshold = threshold
93
113
  self.cmap = cmap
94
114
  self.ROC_PR_line = ROC_PR_line
95
115
  self.font_size = font_size
96
116
 
97
117
  def __repr__(self) -> str:
98
118
  parts = [
99
- f"threshold={self.threshold}",
100
119
  f"ROC_PR_line='{self.ROC_PR_line}'",
101
120
  f"cmap='{self.cmap}'",
102
121
  f"font_size={self.font_size}"
103
122
  ]
104
- return f"MultiClassificationMetricsFormat({', '.join(parts)})"
123
+ return f"{self.__class__.__name__}({', '.join(parts)})"
124
+
125
+
126
+ class _BaseRegressionFormat:
127
+ """
128
+ [PRIVATE] Base configuration for regression metrics.
129
+ """
130
+ def __init__(self,
131
+ font_size: int=16,
132
+ scatter_color: str='tab:blue',
133
+ scatter_alpha: float=0.6,
134
+ ideal_line_color: str='k',
135
+ residual_line_color: str='red',
136
+ hist_bins: Union[int, str] = 'auto') -> None:
137
+ """
138
+ Initializes the formatting configuration for regression metrics.
139
+
140
+ Args:
141
+ font_size (int): The base font size to apply to the plots. Defaults to 16.
142
+ scatter_color (str): Matplotlib color for the scatter plot points. Defaults to 'tab:blue'.
143
+ - Common color names: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
144
+ scatter_alpha (float): Alpha transparency for scatter plot points. Defaults to 0.6.
145
+ ideal_line_color (str): Matplotlib color for the 'ideal' y=x line in the
146
+ True vs. Predicted plot. Defaults to 'k' (black).
147
+ - Common color names: 'k', 'red', 'darkgrey', '#FF6347'
148
+ residual_line_color (str): Matplotlib color for the y=0 line in the
149
+ Residual plot. Defaults to 'red'.
150
+ - Common color names: 'red', 'blue', 'k', '#4682B4'
151
+ hist_bins (int | str): The number of bins for the residuals histogram.
152
+ Defaults to 'auto' to use seaborn's automatic bin selection.
153
+ - Options: 'auto', 'sqrt', 10, 20
154
+
155
+ <br>
156
+
157
+ ## [Matplotlib Colors](https://matplotlib.org/stable/users/explain/colors/colors.html)
158
+ """
159
+ self.font_size = font_size
160
+ self.scatter_color = scatter_color
161
+ self.scatter_alpha = scatter_alpha
162
+ self.ideal_line_color = ideal_line_color
163
+ self.residual_line_color = residual_line_color
164
+ self.hist_bins = hist_bins
165
+
166
+ def __repr__(self) -> str:
167
+ parts = [
168
+ f"font_size={self.font_size}",
169
+ f"scatter_color='{self.scatter_color}'",
170
+ f"scatter_alpha={self.scatter_alpha}",
171
+ f"ideal_line_color='{self.ideal_line_color}'",
172
+ f"residual_line_color='{self.residual_line_color}'",
173
+ f"hist_bins='{self.hist_bins}'"
174
+ ]
175
+ return f"{self.__class__.__name__}({', '.join(parts)})"
176
+
177
+
178
+ class _BaseSegmentationFormat:
179
+ """
180
+ [PRIVATE] Base configuration for segmentation metrics.
181
+ """
182
+ def __init__(self,
183
+ heatmap_cmap: str = 'viridis',
184
+ cm_cmap: str = "Blues",
185
+ font_size: int = 16) -> None:
186
+ """
187
+ Initializes the formatting configuration for segmentation metrics.
188
+
189
+ Args:
190
+ heatmap_cmap (str): The matplotlib colormap name for the per-class
191
+ metrics heatmap. Defaults to "viridis".
192
+ - Sequential options: 'viridis', 'plasma', 'inferno', 'cividis'
193
+ - Diverging options: 'coolwarm', 'bwr', 'seismic'
194
+ cm_cmap (str): The matplotlib colormap name for the pixel-level
195
+ confusion matrix. Defaults to "Blues".
196
+ - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges'
197
+ font_size (int): The base font size to apply to the plots. Defaults to 16.
198
+
199
+ <br>
200
+
201
+ ## [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
202
+ """
203
+ self.heatmap_cmap = heatmap_cmap
204
+ self.cm_cmap = cm_cmap
205
+ self.font_size = font_size
206
+
207
+ def __repr__(self) -> str:
208
+ parts = [
209
+ f"heatmap_cmap='{self.heatmap_cmap}'",
210
+ f"cm_cmap='{self.cm_cmap}'",
211
+ f"font_size={self.font_size}"
212
+ ]
213
+ return f"{self.__class__.__name__}({', '.join(parts)})"
214
+
215
+
216
+ class _BaseSequenceValueFormat:
217
+ """
218
+ [PRIVATE] Base configuration for sequence to value metrics.
219
+ """
220
+ def __init__(self,
221
+ font_size: int=16,
222
+ scatter_color: str='tab:blue',
223
+ scatter_alpha: float=0.6,
224
+ ideal_line_color: str='k',
225
+ residual_line_color: str='red',
226
+ hist_bins: Union[int, str] = 'auto') -> None:
227
+ """
228
+ Initializes the formatting configuration for sequence to value metrics.
229
+
230
+ Args:
231
+ font_size (int): The base font size to apply to the plots. Defaults to 16.
232
+ scatter_color (str): Matplotlib color for the scatter plot points. Defaults to 'tab:blue'.
233
+ - Common color names: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
234
+ scatter_alpha (float): Alpha transparency for scatter plot points. Defaults to 0.6.
235
+ ideal_line_color (str): Matplotlib color for the 'ideal' y=x line in the
236
+ True vs. Predicted plot. Defaults to 'k' (black).
237
+ - Common color names: 'k', 'red', 'darkgrey', '#FF6347'
238
+ residual_line_color (str): Matplotlib color for the y=0 line in the
239
+ Residual plot. Defaults to 'red'.
240
+ - Common color names: 'red', 'blue', 'k', '#4682B4'
241
+ hist_bins (int | str): The number of bins for the residuals histogram.
242
+ Defaults to 'auto' to use seaborn's automatic bin selection.
243
+ - Options: 'auto', 'sqrt', 10, 20
244
+
245
+ <br>
246
+
247
+ ## [Matplotlib Colors](https://matplotlib.org/stable/users/explain/colors/colors.html)
248
+ """
249
+ self.font_size = font_size
250
+ self.scatter_color = scatter_color
251
+ self.scatter_alpha = scatter_alpha
252
+ self.ideal_line_color = ideal_line_color
253
+ self.residual_line_color = residual_line_color
254
+ self.hist_bins = hist_bins
255
+
256
+ def __repr__(self) -> str:
257
+ parts = [
258
+ f"font_size={self.font_size}",
259
+ f"scatter_color='{self.scatter_color}'",
260
+ f"scatter_alpha={self.scatter_alpha}",
261
+ f"ideal_line_color='{self.ideal_line_color}'",
262
+ f"residual_line_color='{self.residual_line_color}'",
263
+ f"hist_bins='{self.hist_bins}'"
264
+ ]
265
+ return f"{self.__class__.__name__}({', '.join(parts)})"
266
+
267
+
268
+ class _BaseSequenceSequenceFormat:
269
+ """
270
+ [PRIVATE] Base configuration for sequence-to-sequence metrics.
271
+ """
272
+ def __init__(self,
273
+ font_size: int = 16,
274
+ plot_figsize: tuple[int, int] = (10, 6),
275
+ grid_style: str = '--',
276
+ rmse_color: str = 'tab:blue',
277
+ rmse_marker: str = 'o-',
278
+ mae_color: str = 'tab:orange',
279
+ mae_marker: str = 's--'):
280
+ """
281
+ Initializes the formatting configuration for seq-to-seq metrics.
282
+
283
+ Args:
284
+ font_size (int): The base font size to apply to the plots. Defaults to 16.
285
+ plot_figsize (Tuple[int, int]): Figure size for the plot. Defaults to (10, 6).
286
+ grid_style (str): Matplotlib linestyle for the plot grid. Defaults to '--'.
287
+ - Options: '--' (dashed), ':' (dotted), '-.' (dash-dot), '-' (solid)
288
+ rmse_color (str): Matplotlib color for the RMSE line. Defaults to 'tab:blue'.
289
+ - Common color names: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
290
+ rmse_marker (str): Matplotlib marker style for the RMSE line. Defaults to 'o-'.
291
+ - Options: 'o-' (circle), 's--' (square), '^:' (triangle), 'x' (x marker)
292
+ mae_color (str): Matplotlib color for the MAE line. Defaults to 'tab:orange'.
293
+ - Common color names: 'tab:orange', 'purple', 'black', '#FF6347'
294
+ mae_marker (str): Matplotlib marker style for the MAE line. Defaults to 's--'.
295
+ - Options: 's--', 'o-', 'v:', '+' (plus marker)
296
+
297
+ <br>
298
+
299
+ ## [Matplotlib Colors](https://matplotlib.org/stable/users/explain/colors/colors.html)
300
+ ## [Matplotlib Linestyles](https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html)
301
+ ## [Matplotlib Markers](https://matplotlib.org/stable/api/markers_api.html)
302
+ """
303
+ self.font_size = font_size
304
+ self.plot_figsize = plot_figsize
305
+ self.grid_style = grid_style
306
+ self.rmse_color = rmse_color
307
+ self.rmse_marker = rmse_marker
308
+ self.mae_color = mae_color
309
+ self.mae_marker = mae_marker
310
+
311
+ def __repr__(self) -> str:
312
+ parts = [
313
+ f"font_size={self.font_size}",
314
+ f"plot_figsize={self.plot_figsize}",
315
+ f"grid_style='{self.grid_style}'",
316
+ f"rmse_color='{self.rmse_color}'",
317
+ f"mae_color='{self.mae_color}'"
318
+ ]
319
+ return f"{self.__class__.__name__}({', '.join(parts)})"
320
+
321
+ # --- Public API classes ---
322
+
323
+ # Regression
324
class RegressionMetricsFormat(_BaseRegressionFormat):
    """
    Configuration for single-target regression.

    Every argument is forwarded unchanged to `_BaseRegressionFormat`.
    """
    def __init__(self,
                 font_size: int = 16,
                 scatter_color: str = 'tab:blue',
                 scatter_alpha: float = 0.6,
                 ideal_line_color: str = 'k',
                 residual_line_color: str = 'red',
                 hist_bins: Union[int, str] = 'auto') -> None:
        """
        Args:
            font_size (int): Base font size for the plots.
            scatter_color (str): Matplotlib color for the scatter plot points.
            scatter_alpha (float): Alpha transparency for the scatter plot points.
            ideal_line_color (str): Matplotlib color for the 'ideal' y=x line.
            residual_line_color (str): Matplotlib color for the y=0 line in the residual plot.
            hist_bins (int | str): Number of bins (or a strategy such as 'auto')
                for the residuals histogram.
        """
        super().__init__(
            font_size=font_size,
            scatter_color=scatter_color,
            scatter_alpha=scatter_alpha,
            ideal_line_color=ideal_line_color,
            residual_line_color=residual_line_color,
            hist_bins=hist_bins,
        )
341
+
342
+
343
+ # Multitarget regression
344
class MultiTargetRegressionMetricsFormat(_BaseRegressionFormat):
    """
    Configuration for multi-target regression.

    Every argument is forwarded unchanged to `_BaseRegressionFormat`.
    """
    def __init__(self,
                 font_size: int = 16,
                 scatter_color: str = 'tab:blue',
                 scatter_alpha: float = 0.6,
                 ideal_line_color: str = 'k',
                 residual_line_color: str = 'red',
                 hist_bins: Union[int, str] = 'auto') -> None:
        """
        Args:
            font_size (int): Base font size for the plots.
            scatter_color (str): Matplotlib color for the scatter plot points.
            scatter_alpha (float): Alpha transparency for the scatter plot points.
            ideal_line_color (str): Matplotlib color for the 'ideal' y=x line.
            residual_line_color (str): Matplotlib color for the y=0 line in the residual plot.
            hist_bins (int | str): Number of bins (or a strategy such as 'auto')
                for the residuals histogram.
        """
        super().__init__(
            font_size=font_size,
            scatter_color=scatter_color,
            scatter_alpha=scatter_alpha,
            ideal_line_color=ideal_line_color,
            residual_line_color=residual_line_color,
            hist_bins=hist_bins,
        )
361
+
362
+
363
+ # Classification
364
class BinaryClassificationMetricsFormat(_BaseClassificationFormat):
    """
    Configuration for binary classification.

    Every argument is forwarded unchanged to `_BaseClassificationFormat`.
    """
    def __init__(self,
                 cmap: str = "Blues",
                 ROC_PR_line: str = 'darkorange',
                 calibration_bins: int = 15,
                 font_size: int = 16) -> None:
        """
        Args:
            cmap (str): Matplotlib colormap name (e.g. 'Blues', 'viridis').
            ROC_PR_line (str): Color name or hex code for the line plotted on
                the ROC and Precision-Recall curves.
            calibration_bins (int): Bin count for the calibration (reliability) plot.
            font_size (int): Base font size for the plots.
        """
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
        )
377
+
378
+
379
class MultiClassClassificationMetricsFormat(_BaseClassificationFormat):
    """
    Configuration for multi-class classification.

    Every argument is forwarded unchanged to `_BaseClassificationFormat`.
    """
    def __init__(self,
                 cmap: str = "Blues",
                 ROC_PR_line: str = 'darkorange',
                 calibration_bins: int = 15,
                 font_size: int = 16) -> None:
        """
        Args:
            cmap (str): Matplotlib colormap name (e.g. 'Blues', 'viridis').
            ROC_PR_line (str): Color name or hex code for the line plotted on
                the ROC and Precision-Recall curves.
            calibration_bins (int): Bin count for the calibration (reliability) plot.
            font_size (int): Base font size for the plots.
        """
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
        )
392
+
393
+
394
class BinaryImageClassificationMetricsFormat(_BaseClassificationFormat):
    """
    Configuration for binary image classification.

    Every argument is forwarded unchanged to `_BaseClassificationFormat`.
    """
    def __init__(self,
                 cmap: str = "Blues",
                 ROC_PR_line: str = 'darkorange',
                 calibration_bins: int = 15,
                 font_size: int = 16) -> None:
        """
        Args:
            cmap (str): Matplotlib colormap name (e.g. 'Blues', 'viridis').
            ROC_PR_line (str): Color name or hex code for the line plotted on
                the ROC and Precision-Recall curves.
            calibration_bins (int): Bin count for the calibration (reliability) plot.
            font_size (int): Base font size for the plots.
        """
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
        )
407
+
408
+
409
class MultiClassImageClassificationMetricsFormat(_BaseClassificationFormat):
    """
    Configuration for multi-class image classification.

    Every argument is forwarded unchanged to `_BaseClassificationFormat`.
    """
    def __init__(self,
                 cmap: str = "Blues",
                 ROC_PR_line: str = 'darkorange',
                 calibration_bins: int = 15,
                 font_size: int = 16) -> None:
        """
        Args:
            cmap (str): Matplotlib colormap name (e.g. 'Blues', 'viridis').
            ROC_PR_line (str): Color name or hex code for the line plotted on
                the ROC and Precision-Recall curves.
            calibration_bins (int): Bin count for the calibration (reliability) plot.
            font_size (int): Base font size for the plots.
        """
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
        )
422
+
423
+
424
+ # Multi-Label classification
425
class MultiLabelBinaryClassificationMetricsFormat(_BaseMultiLabelFormat):
    """
    Configuration for multi-label binary classification.

    Every argument is forwarded unchanged to `_BaseMultiLabelFormat`.
    """
    def __init__(self,
                 ROC_PR_line: str = 'darkorange',
                 cmap: str = "Blues",
                 font_size: int = 16) -> None:
        """
        Args:
            ROC_PR_line (str): Color name or hex code for the line plotted on
                the ROC and Precision-Recall curves (one for each label).
            cmap (str): Matplotlib colormap name (e.g. 'Blues', 'viridis').
            font_size (int): Base font size for the plots.
        """
        super().__init__(
            ROC_PR_line=ROC_PR_line,
            cmap=cmap,
            font_size=font_size,
        )
436
+
437
+
438
+ # Segmentation
439
class BinarySegmentationMetricsFormat(_BaseSegmentationFormat):
    """
    Configuration for binary segmentation.

    Every argument is forwarded unchanged to `_BaseSegmentationFormat`.
    """
    def __init__(self,
                 heatmap_cmap: str = 'viridis',
                 cm_cmap: str = "Blues",
                 font_size: int = 16) -> None:
        """
        Args:
            heatmap_cmap (str): Matplotlib colormap name for the per-class metrics heatmap.
            cm_cmap (str): Matplotlib colormap name for the pixel-level confusion matrix.
            font_size (int): Base font size for the plots.
        """
        super().__init__(
            heatmap_cmap=heatmap_cmap,
            cm_cmap=cm_cmap,
            font_size=font_size,
        )
450
+
451
+
452
class MultiClassSegmentationMetricsFormat(_BaseSegmentationFormat):
    """
    Configuration for multi-class segmentation.

    Every argument is forwarded unchanged to `_BaseSegmentationFormat`.
    """
    def __init__(self,
                 heatmap_cmap: str = 'viridis',
                 cm_cmap: str = "Blues",
                 font_size: int = 16) -> None:
        """
        Args:
            heatmap_cmap (str): Matplotlib colormap name for the per-class metrics heatmap.
            cm_cmap (str): Matplotlib colormap name for the pixel-level confusion matrix.
            font_size (int): Base font size for the plots.
        """
        super().__init__(
            heatmap_cmap=heatmap_cmap,
            cm_cmap=cm_cmap,
            font_size=font_size,
        )
463
+
464
+
465
+ # Sequence
466
class SequenceValueMetricsFormat(_BaseSequenceValueFormat):
    """
    Configuration for sequence-to-value prediction.

    Every argument is forwarded unchanged to `_BaseSequenceValueFormat`.
    """
    def __init__(self,
                 font_size: int = 16,
                 scatter_color: str = 'tab:blue',
                 scatter_alpha: float = 0.6,
                 ideal_line_color: str = 'k',
                 residual_line_color: str = 'red',
                 hist_bins: Union[int, str] = 'auto') -> None:
        """
        Args:
            font_size (int): Base font size for the plots.
            scatter_color (str): Matplotlib color for the scatter plot points.
            scatter_alpha (float): Alpha transparency for the scatter plot points.
            ideal_line_color (str): Matplotlib color for the 'ideal' y=x line.
            residual_line_color (str): Matplotlib color for the y=0 line in the residual plot.
            hist_bins (int | str): Number of bins (or a strategy such as 'auto')
                for the residuals histogram.
        """
        super().__init__(
            font_size=font_size,
            scatter_color=scatter_color,
            scatter_alpha=scatter_alpha,
            ideal_line_color=ideal_line_color,
            residual_line_color=residual_line_color,
            hist_bins=hist_bins,
        )
483
+
484
+
485
class SequenceSequenceMetricsFormat(_BaseSequenceSequenceFormat):
    """
    Configuration for sequence-to-sequence prediction.

    Every argument is forwarded unchanged to `_BaseSequenceSequenceFormat`.
    """
    def __init__(self,
                 font_size: int = 16,
                 plot_figsize: tuple[int, int] = (10, 6),
                 grid_style: str = '--',
                 rmse_color: str = 'tab:blue',
                 rmse_marker: str = 'o-',
                 mae_color: str = 'tab:orange',
                 mae_marker: str = 's--') -> None:
        """
        Args:
            font_size (int): Base font size for the plots.
            plot_figsize (tuple[int, int]): Figure size for the plot.
            grid_style (str): Matplotlib linestyle for the plot grid.
            rmse_color (str): Matplotlib color for the RMSE line.
            rmse_marker (str): Matplotlib marker style for the RMSE line.
            mae_color (str): Matplotlib color for the MAE line.
            mae_marker (str): Matplotlib marker style for the MAE line.
        """
        super().__init__(
            font_size=font_size,
            plot_figsize=plot_figsize,
            grid_style=grid_style,
            rmse_color=rmse_color,
            rmse_marker=rmse_marker,
            mae_color=mae_color,
            mae_marker=mae_marker,
        )
504
+
505
+
506
+ # -------- Finalize classes --------
507
class _FinalizeModelTraining:
    """
    Base class for finalizing model training.

    This class is not intended to be instantiated directly. Instead, use one of its specific subclasses.

    The constructor validates the file name and initializes every optional
    field to None; subclasses fill in the fields relevant to their task.
    """
    def __init__(self, filename: str) -> None:
        """
        Args:
            filename (str): The name of the file to be saved. It is passed
                through `_validate_string` with the ".pth" extension.
        """
        # Validated first, so an invalid file name aborts construction immediately.
        self.filename = _validate_string(
            string=filename,
            attribute_name="filename",
            extension=".pth",
        )

        # Task-specific fields; each subclass overwrites only the ones it uses.
        self.target_name: Optional[str] = None
        self.target_names: Optional[list[str]] = None
        self.classification_threshold: Optional[float] = None
        self.class_map: Optional[dict[str, int]] = None
        self.initial_sequence: Optional[np.ndarray] = None
        self.sequence_length: Optional[int] = None
523
+
524
+
525
class FinalizeRegression(_FinalizeModelTraining):
    """Parameters for finalizing a single-target regression model."""
    def __init__(self, filename: str, target_name: str) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            target_name (str): The name of the target variable.
        """
        super().__init__(filename=filename)
        validated_target = _validate_string(string=target_name, attribute_name="Target name")
        self.target_name = validated_target
539
+
540
+
541
class FinalizeMultiTargetRegression(_FinalizeModelTraining):
    """Parameters for finalizing a multi-target regression model."""
    def __init__(self, filename: str, target_names: list[str]) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            target_names (list[str]): A list of names for the target variables.
        """
        super().__init__(filename=filename)
        # Each name is checked individually, in order; the first invalid one raises.
        self.target_names = [
            _validate_string(string=name, attribute_name="All target names")
            for name in target_names
        ]
556
+
557
+
558
class FinalizeBinaryClassification(_FinalizeModelTraining):
    """Parameters for finalizing a binary classification model."""
    def __init__(self,
                 filename: str,
                 target_name: str,
                 classification_threshold: float,
                 class_map: dict[str, int]) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            target_name (str): The name of the target variable.
            classification_threshold (float): The cutoff threshold for classifying as the positive class.
            class_map (dict[str,int]): A dictionary mapping class names (str)
                to their integer representations (e.g., {'cat': 0, 'dog': 1}).
        """
        super().__init__(filename=filename)
        # Validation order matters for which error is reported first:
        # target name, then threshold, then class map.
        validated_target = _validate_string(string=target_name, attribute_name="Target name")
        self.target_name = validated_target
        validated_threshold = _validate_threshold(classification_threshold)
        self.classification_threshold = validated_threshold
        self.class_map = _validate_class_map(class_map)
579
+
580
+
581
class FinalizeMultiClassClassification(_FinalizeModelTraining):
    """Parameters for finalizing a multi-class classification model."""
    def __init__(self,
                 filename: str,
                 target_name: str,
                 class_map: dict[str, int]) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            target_name (str): The name of the target variable.
            class_map (dict[str,int]): A dictionary mapping class names (str)
                to their integer representations (e.g., {'cat': 0, 'dog': 1}).
        """
        super().__init__(filename=filename)
        validated_target = _validate_string(string=target_name, attribute_name="Target name")
        self.target_name = validated_target
        self.class_map = _validate_class_map(class_map)
599
+
600
+
601
class FinalizeBinaryImageClassification(_FinalizeModelTraining):
    """Parameters for finalizing a binary image classification model."""
    def __init__(self,
                 filename: str,
                 classification_threshold: float,
                 class_map: dict[str, int]) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            classification_threshold (float): The cutoff threshold for
                classifying as the positive class.
            class_map (dict[str,int]): A dictionary mapping class names (str)
                to their integer representations (e.g., {'cat': 0, 'dog': 1}).
        """
        super().__init__(filename=filename)
        # The threshold is validated before the class map.
        validated_threshold = _validate_threshold(classification_threshold)
        self.classification_threshold = validated_threshold
        self.class_map = _validate_class_map(class_map)
620
+
621
+
622
class FinalizeMultiClassImageClassification(_FinalizeModelTraining):
    """Parameters for finalizing a multi-class image classification model."""
    def __init__(self, filename: str, class_map: dict[str, int]) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            class_map (dict[str,int]): A dictionary mapping class names (str)
                to their integer representations (e.g., {'cat': 0, 'dog': 1}).
        """
        super().__init__(filename=filename)
        validated_map = _validate_class_map(class_map)
        self.class_map = validated_map
637
+
638
+
639
class FinalizeMultiLabelBinaryClassification(_FinalizeModelTraining):
    """Parameters for finalizing a multi-label binary classification model."""
    def __init__(self,
                 filename: str,
                 target_names: list[str],
                 classification_threshold: float) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            target_names (list[str]): A list of names for the target variables.
            classification_threshold (float): The cutoff threshold for classifying as the positive class.
        """
        super().__init__(filename=filename)
        # Names are validated one by one (in order) before the threshold is checked.
        self.target_names = [
            _validate_string(string=name, attribute_name="All target names")
            for name in target_names
        ]
        self.classification_threshold = _validate_threshold(classification_threshold)
657
+
658
+
659
class FinalizeBinarySegmentation(_FinalizeModelTraining):
    """Parameters for finalizing a binary segmentation model."""
    def __init__(self,
                 filename: str,
                 class_map: dict[str, int],
                 classification_threshold: float) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            class_map (dict[str,int]): A dictionary mapping class names (str)
                to their integer representations.
            classification_threshold (float): The cutoff threshold for classifying as the positive class (mask).
        """
        super().__init__(filename=filename)
        # The threshold is validated before the class map, even though
        # `class_map` comes first in the signature.
        validated_threshold = _validate_threshold(classification_threshold)
        self.classification_threshold = validated_threshold
        self.class_map = _validate_class_map(class_map)
675
+
676
+
677
class FinalizeMultiClassSegmentation(_FinalizeModelTraining):
    """Parameters for finalizing a multi-class segmentation model."""
    def __init__(self, filename: str, class_map: dict[str, int]) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            class_map (dict[str,int]): A dictionary mapping class names (str)
                to their integer representations.
        """
        super().__init__(filename=filename)
        validated_map = _validate_class_map(class_map)
        self.class_map = validated_map
690
+
691
+
692
class FinalizeObjectDetection(_FinalizeModelTraining):
    """Parameters for finalizing an object detection model."""
    def __init__(self, filename: str, class_map: dict[str, int]) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            class_map (dict[str,int]): A dictionary mapping class names (str)
                to their integer representations.
        """
        super().__init__(filename=filename)
        validated_map = _validate_class_map(class_map)
        self.class_map = validated_map
705
+
706
+
707
class FinalizeSequencePrediction(_FinalizeModelTraining):
    """Parameters for finalizing a sequence prediction model."""
    def __init__(self,
                 filename: str,
                 last_training_sequence: np.ndarray) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            last_training_sequence (np.ndarray): The last sequence from the training data, needed to start predictions.
                Accepted shapes are (N,) and (1, N); the latter is flattened to (N,).

        Raises:
            TypeError: If `last_training_sequence` is not a numpy array.
            ValueError: If the array is neither 1D nor of shape (1, N).
        """
        super().__init__(filename=filename)

        if not isinstance(last_training_sequence, np.ndarray):
            _LOGGER.error(f"The last training sequence must be a 1D numpy array, got {type(last_training_sequence)}.")
            raise TypeError()

        dims = last_training_sequence.ndim
        # A 2D array is accepted only when it is a single row, i.e. shape (1, N).
        is_single_row = dims == 2 and last_training_sequence.shape[0] == 1

        if dims == 1:
            sequence = last_training_sequence
        elif is_single_row:
            sequence = last_training_sequence.flatten()
        else:
            # Any other shape (multi-row 2D, or any other number of dimensions) is rejected.
            _LOGGER.error(f"The last training sequence must be a 1D numpy array, got shape {last_training_sequence.shape}.")
            raise ValueError()

        self.initial_sequence = sequence
        # Length of the validated 1D sequence.
        self.sequence_length = len(sequence)
744
+
745
+
746
+ def _validate_string(string: str, attribute_name: str, extension: Optional[str]=None) -> str:
747
+ """Helper for finalize classes"""
748
+ if not isinstance(string, str):
749
+ _LOGGER.error(f"{attribute_name} must be a string.")
750
+ raise TypeError()
751
+
752
+ if extension:
753
+ safe_name = sanitize_filename(string)
754
+
755
+ if not safe_name.endswith(extension):
756
+ safe_name += extension
757
+ else:
758
+ safe_name = string
759
+
760
+ return safe_name
761
+
762
+ def _validate_threshold(threshold: float):
763
+ """Helper for finalize classes"""
764
+ if not isinstance(threshold, float):
765
+ _LOGGER.error(f"Classification threshold must be a float.")
766
+ raise TypeError()
767
+ elif threshold <= 0.0 or threshold >= 1.0:
768
+ _LOGGER.error(f"Classification threshold must be in the range [0.1, 0.9]")
769
+ raise ValueError()
770
+
771
+ return threshold
105
772
 
773
+ def _validate_class_map(map: dict[str,int]):
774
+ """Helper for finalize classes"""
775
+ validated_map = None
776
+ if isinstance(map, dict):
777
+ if all( [isinstance(key, str) for key in map.keys()] ):
778
+ if all( [isinstance(val, str) for val in map.values()] ):
779
+ validated_map = map
780
+
781
+ if validated_map is None:
782
+ _LOGGER.error(f"Class map must be a dictionary of string keys and integer values.")
783
+ raise TypeError()
784
+ else:
785
+ return validated_map
106
786
 
107
787
def info():
    """Pass this module's `__all__` list to `_script_info`."""
    _script_info(__all__)