dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
  2. dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
  3. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/ETL_cleaning.py +20 -20
  5. ml_tools/ETL_engineering.py +23 -25
  6. ml_tools/GUI_tools.py +20 -20
  7. ml_tools/MICE_imputation.py +207 -5
  8. ml_tools/ML_callbacks.py +43 -26
  9. ml_tools/ML_configuration.py +788 -0
  10. ml_tools/ML_datasetmaster.py +303 -448
  11. ml_tools/ML_evaluation.py +351 -93
  12. ml_tools/ML_evaluation_multi.py +139 -42
  13. ml_tools/ML_inference.py +290 -209
  14. ml_tools/ML_models.py +33 -106
  15. ml_tools/ML_models_advanced.py +323 -0
  16. ml_tools/ML_optimization.py +12 -12
  17. ml_tools/ML_scaler.py +11 -11
  18. ml_tools/ML_sequence_datasetmaster.py +341 -0
  19. ml_tools/ML_sequence_evaluation.py +219 -0
  20. ml_tools/ML_sequence_inference.py +391 -0
  21. ml_tools/ML_sequence_models.py +139 -0
  22. ml_tools/ML_trainer.py +1604 -179
  23. ml_tools/ML_utilities.py +351 -4
  24. ml_tools/ML_vision_datasetmaster.py +1540 -0
  25. ml_tools/ML_vision_evaluation.py +284 -0
  26. ml_tools/ML_vision_inference.py +405 -0
  27. ml_tools/ML_vision_models.py +641 -0
  28. ml_tools/ML_vision_transformers.py +284 -0
  29. ml_tools/PSO_optimization.py +6 -6
  30. ml_tools/SQL.py +4 -4
  31. ml_tools/_keys.py +171 -0
  32. ml_tools/_schema.py +1 -1
  33. ml_tools/custom_logger.py +37 -14
  34. ml_tools/data_exploration.py +502 -93
  35. ml_tools/ensemble_evaluation.py +54 -11
  36. ml_tools/ensemble_inference.py +7 -33
  37. ml_tools/ensemble_learning.py +1 -1
  38. ml_tools/math_utilities.py +1 -1
  39. ml_tools/optimization_tools.py +2 -2
  40. ml_tools/path_manager.py +5 -5
  41. ml_tools/serde.py +2 -2
  42. ml_tools/utilities.py +192 -4
  43. dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
  44. ml_tools/RNN_forecast.py +0 -56
  45. ml_tools/keys.py +0 -87
  46. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
  47. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
  48. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,788 @@
1
+ from typing import Union, Optional
2
+ import numpy as np
3
+
4
+ from ._script_info import _script_info
5
+ from ._logger import _LOGGER
6
+ from .path_manager import sanitize_filename
7
+
8
+
9
+ __all__ = [
10
+ "RegressionMetricsFormat",
11
+ "MultiTargetRegressionMetricsFormat",
12
+ "BinaryClassificationMetricsFormat",
13
+ "MultiClassClassificationMetricsFormat",
14
+ "BinaryImageClassificationMetricsFormat",
15
+ "MultiClassImageClassificationMetricsFormat",
16
+ "MultiLabelBinaryClassificationMetricsFormat",
17
+ "BinarySegmentationMetricsFormat",
18
+ "MultiClassSegmentationMetricsFormat",
19
+ "SequenceValueMetricsFormat",
20
+ "SequenceSequenceMetricsFormat",
21
+
22
+ "FinalizeBinaryClassification",
23
+ "FinalizeBinarySegmentation",
24
+ "FinalizeBinaryImageClassification",
25
+ "FinalizeMultiClassClassification",
26
+ "FinalizeMultiClassImageClassification",
27
+ "FinalizeMultiClassSegmentation",
28
+ "FinalizeMultiLabelBinaryClassification",
29
+ "FinalizeMultiTargetRegression",
30
+ "FinalizeRegression",
31
+ "FinalizeObjectDetection",
32
+ "FinalizeSequencePrediction"
33
+ ]
34
+
35
+ # --- Private base classes ---
36
+
37
class _BaseClassificationFormat:
    """
    [PRIVATE] Base configuration for single-label classification metrics.

    Plain value holder: the constructor stores its arguments as attributes
    without validating them.
    """
    def __init__(self,
                 cmap: str = "Blues",
                 ROC_PR_line: str = 'darkorange',
                 calibration_bins: int = 15,
                 font_size: int = 16) -> None:
        """
        Initializes the formatting configuration for single-label classification metrics.

        Args:
            cmap (str): The matplotlib colormap name for the confusion matrix
                and report heatmap. Defaults to "Blues".
                - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
                - Perceptually uniform sequential options: 'viridis', 'plasma', 'inferno'
                - Diverging options: 'coolwarm'

            ROC_PR_line (str): The color name or hex code for the line plotted
                on the ROC and Precision-Recall curves. Defaults to 'darkorange'.
                - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
                - Hex codes: '#FF6347', '#4682B4'

            calibration_bins (int): The number of bins to use when
                creating the calibration (reliability) plot. Defaults to 15.

            font_size (int): The base font size to apply to the plots. Defaults to 16.

        <br>

        ## [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
        """
        self.cmap = cmap
        self.ROC_PR_line = ROC_PR_line
        self.calibration_bins = calibration_bins
        self.font_size = font_size

    def __repr__(self) -> str:
        # `!r` quotes string values and escapes embedded quotes correctly.
        parts = [
            f"cmap={self.cmap!r}",
            f"ROC_PR_line={self.ROC_PR_line!r}",
            f"calibration_bins={self.calibration_bins}",
            f"font_size={self.font_size}"
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"
82
+
83
+
84
class _BaseMultiLabelFormat:
    """
    [PRIVATE] Base configuration for multi-label binary classification metrics.

    Plain value holder: the constructor stores its arguments as attributes
    without validating them.
    """
    def __init__(self,
                 ROC_PR_line: str = 'darkorange',
                 cmap: str = "Blues",
                 font_size: int = 16) -> None:
        """
        Initializes the formatting configuration for multi-label classification metrics.

        Args:
            ROC_PR_line (str): The color name or hex code for the line plotted
                on the ROC and Precision-Recall curves (one for each label).
                Defaults to 'darkorange'.
                - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
                - Hex codes: '#FF6347', '#4682B4'

            cmap (str): The matplotlib colormap name for the per-label
                confusion matrices. Defaults to "Blues".
                - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
                - Perceptually uniform sequential options: 'viridis', 'plasma', 'inferno'
                - Diverging options: 'coolwarm'

            font_size (int): The base font size to apply to the plots. Defaults to 16.

        <br>

        ## [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
        """
        self.cmap = cmap
        self.ROC_PR_line = ROC_PR_line
        self.font_size = font_size

    def __repr__(self) -> str:
        # `!r` quotes string values and escapes embedded quotes correctly.
        parts = [
            f"ROC_PR_line={self.ROC_PR_line!r}",
            f"cmap={self.cmap!r}",
            f"font_size={self.font_size}"
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"
124
+
125
+
126
class _BaseRegressionFormat:
    """
    [PRIVATE] Base configuration for regression metrics.

    Plain value holder: the constructor stores its arguments as attributes
    without validating them.
    """
    def __init__(self,
                 font_size: int = 16,
                 scatter_color: str = 'tab:blue',
                 scatter_alpha: float = 0.6,
                 ideal_line_color: str = 'k',
                 residual_line_color: str = 'red',
                 hist_bins: Union[int, str] = 'auto') -> None:
        """
        Initializes the formatting configuration for regression metrics.

        Args:
            font_size (int): The base font size to apply to the plots. Defaults to 16.
            scatter_color (str): Matplotlib color for the scatter plot points. Defaults to 'tab:blue'.
                - Common color names: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
            scatter_alpha (float): Alpha transparency for scatter plot points. Defaults to 0.6.
            ideal_line_color (str): Matplotlib color for the 'ideal' y=x line in the
                True vs. Predicted plot. Defaults to 'k' (black).
                - Common color names: 'k', 'red', 'darkgrey', '#FF6347'
            residual_line_color (str): Matplotlib color for the y=0 line in the
                Residual plot. Defaults to 'red'.
                - Common color names: 'red', 'blue', 'k', '#4682B4'
            hist_bins (int | str): The number of bins for the residuals histogram.
                Defaults to 'auto' to use seaborn's automatic bin selection.
                - Options: 'auto', 'sqrt', 10, 20

        <br>

        ## [Matplotlib Colors](https://matplotlib.org/stable/users/explain/colors/colors.html)
        """
        self.font_size = font_size
        self.scatter_color = scatter_color
        self.scatter_alpha = scatter_alpha
        self.ideal_line_color = ideal_line_color
        self.residual_line_color = residual_line_color
        self.hist_bins = hist_bins

    def __repr__(self) -> str:
        # `hist_bins` may be an int or a str; `!r` keeps that distinction
        # visible (20 vs 'auto') instead of quoting both.
        parts = [
            f"font_size={self.font_size}",
            f"scatter_color={self.scatter_color!r}",
            f"scatter_alpha={self.scatter_alpha}",
            f"ideal_line_color={self.ideal_line_color!r}",
            f"residual_line_color={self.residual_line_color!r}",
            f"hist_bins={self.hist_bins!r}"
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"
176
+
177
+
178
class _BaseSegmentationFormat:
    """
    [PRIVATE] Shared plot-formatting options for segmentation metrics.
    """
    def __init__(
        self,
        heatmap_cmap: str = 'viridis',
        cm_cmap: str = "Blues",
        font_size: int = 16,
    ) -> None:
        """
        Store the formatting options for segmentation metric plots.

        Args:
            heatmap_cmap (str): Matplotlib colormap name for the per-class
                metrics heatmap. Defaults to "viridis".
                - Sequential options: 'viridis', 'plasma', 'inferno', 'cividis'
                - Diverging options: 'coolwarm', 'bwr', 'seismic'
            cm_cmap (str): Matplotlib colormap name for the pixel-level
                confusion matrix. Defaults to "Blues".
                - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges'
            font_size (int): Base font size applied to the plots. Defaults to 16.

        <br>

        ## [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
        """
        # Values are kept exactly as given; nothing is validated here.
        self.heatmap_cmap = heatmap_cmap
        self.cm_cmap = cm_cmap
        self.font_size = font_size

    def __repr__(self) -> str:
        summary = (
            f"heatmap_cmap='{self.heatmap_cmap}', "
            f"cm_cmap='{self.cm_cmap}', "
            f"font_size={self.font_size}"
        )
        return f"{type(self).__name__}({summary})"
214
+
215
+
216
class _BaseSequenceValueFormat:
    """
    [PRIVATE] Base configuration for sequence to value metrics.

    Plain value holder: the constructor stores its arguments as attributes
    without validating them.
    """
    def __init__(self,
                 font_size: int = 16,
                 scatter_color: str = 'tab:blue',
                 scatter_alpha: float = 0.6,
                 ideal_line_color: str = 'k',
                 residual_line_color: str = 'red',
                 hist_bins: Union[int, str] = 'auto') -> None:
        """
        Initializes the formatting configuration for sequence to value metrics.

        Args:
            font_size (int): The base font size to apply to the plots. Defaults to 16.
            scatter_color (str): Matplotlib color for the scatter plot points. Defaults to 'tab:blue'.
                - Common color names: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
            scatter_alpha (float): Alpha transparency for scatter plot points. Defaults to 0.6.
            ideal_line_color (str): Matplotlib color for the 'ideal' y=x line in the
                True vs. Predicted plot. Defaults to 'k' (black).
                - Common color names: 'k', 'red', 'darkgrey', '#FF6347'
            residual_line_color (str): Matplotlib color for the y=0 line in the
                Residual plot. Defaults to 'red'.
                - Common color names: 'red', 'blue', 'k', '#4682B4'
            hist_bins (int | str): The number of bins for the residuals histogram.
                Defaults to 'auto' to use seaborn's automatic bin selection.
                - Options: 'auto', 'sqrt', 10, 20

        <br>

        ## [Matplotlib Colors](https://matplotlib.org/stable/users/explain/colors/colors.html)
        """
        self.font_size = font_size
        self.scatter_color = scatter_color
        self.scatter_alpha = scatter_alpha
        self.ideal_line_color = ideal_line_color
        self.residual_line_color = residual_line_color
        self.hist_bins = hist_bins

    def __repr__(self) -> str:
        # `hist_bins` may be an int or a str; `!r` keeps that distinction
        # visible (20 vs 'auto') instead of quoting both.
        parts = [
            f"font_size={self.font_size}",
            f"scatter_color={self.scatter_color!r}",
            f"scatter_alpha={self.scatter_alpha}",
            f"ideal_line_color={self.ideal_line_color!r}",
            f"residual_line_color={self.residual_line_color!r}",
            f"hist_bins={self.hist_bins!r}"
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"
266
+
267
+
268
class _BaseSequenceSequenceFormat:
    """
    [PRIVATE] Base configuration for sequence-to-sequence metrics.

    Plain value holder: the constructor stores its arguments as attributes
    without validating them.
    """
    def __init__(self,
                 font_size: int = 16,
                 plot_figsize: tuple[int, int] = (10, 6),
                 grid_style: str = '--',
                 rmse_color: str = 'tab:blue',
                 rmse_marker: str = 'o-',
                 mae_color: str = 'tab:orange',
                 mae_marker: str = 's--') -> None:
        """
        Initializes the formatting configuration for seq-to-seq metrics.

        Args:
            font_size (int): The base font size to apply to the plots. Defaults to 16.
            plot_figsize (Tuple[int, int]): Figure size for the plot. Defaults to (10, 6).
            grid_style (str): Matplotlib linestyle for the plot grid. Defaults to '--'.
                - Options: '--' (dashed), ':' (dotted), '-.' (dash-dot), '-' (solid)
            rmse_color (str): Matplotlib color for the RMSE line. Defaults to 'tab:blue'.
                - Common color names: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
            rmse_marker (str): Matplotlib marker style for the RMSE line. Defaults to 'o-'.
                - Options: 'o-' (circle), 's--' (square), '^:' (triangle), 'x' (x marker)
                - NOTE(review): these values look like matplotlib format strings
                  (marker plus linestyle); confirm against the plotting code.
            mae_color (str): Matplotlib color for the MAE line. Defaults to 'tab:orange'.
                - Common color names: 'tab:orange', 'purple', 'black', '#FF6347'
            mae_marker (str): Matplotlib marker style for the MAE line. Defaults to 's--'.
                - Options: 's--', 'o-', 'v:', '+' (plus marker)

        <br>

        ## [Matplotlib Colors](https://matplotlib.org/stable/users/explain/colors/colors.html)
        ## [Matplotlib Linestyles](https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html)
        ## [Matplotlib Markers](https://matplotlib.org/stable/api/markers_api.html)
        """
        self.font_size = font_size
        self.plot_figsize = plot_figsize
        self.grid_style = grid_style
        self.rmse_color = rmse_color
        self.rmse_marker = rmse_marker
        self.mae_color = mae_color
        self.mae_marker = mae_marker

    def __repr__(self) -> str:
        # Every constructor argument is listed, in signature order, so the
        # repr reflects the full configuration (markers included).
        parts = [
            f"font_size={self.font_size}",
            f"plot_figsize={self.plot_figsize}",
            f"grid_style={self.grid_style!r}",
            f"rmse_color={self.rmse_color!r}",
            f"rmse_marker={self.rmse_marker!r}",
            f"mae_color={self.mae_color!r}",
            f"mae_marker={self.mae_marker!r}"
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"
320
+
321
+ # --- Public API classes ---
322
+
323
+ # Regression
324
class RegressionMetricsFormat(_BaseRegressionFormat):
    """
    Plot-formatting options for single-target regression metrics.
    """
    def __init__(
        self,
        font_size: int = 16,
        scatter_color: str = 'tab:blue',
        scatter_alpha: float = 0.6,
        ideal_line_color: str = 'k',
        residual_line_color: str = 'red',
        hist_bins: Union[int, str] = 'auto',
    ) -> None:
        # Every option is handed unchanged to the shared regression base.
        super().__init__(
            font_size=font_size,
            scatter_color=scatter_color,
            scatter_alpha=scatter_alpha,
            ideal_line_color=ideal_line_color,
            residual_line_color=residual_line_color,
            hist_bins=hist_bins,
        )
341
+
342
+
343
+ # Multitarget regression
344
class MultiTargetRegressionMetricsFormat(_BaseRegressionFormat):
    """
    Plot-formatting options for multi-target regression metrics.
    """
    def __init__(
        self,
        font_size: int = 16,
        scatter_color: str = 'tab:blue',
        scatter_alpha: float = 0.6,
        ideal_line_color: str = 'k',
        residual_line_color: str = 'red',
        hist_bins: Union[int, str] = 'auto',
    ) -> None:
        # Every option is handed unchanged to the shared regression base.
        super().__init__(
            font_size=font_size,
            scatter_color=scatter_color,
            scatter_alpha=scatter_alpha,
            ideal_line_color=ideal_line_color,
            residual_line_color=residual_line_color,
            hist_bins=hist_bins,
        )
361
+
362
+
363
+ # Classification
364
class BinaryClassificationMetricsFormat(_BaseClassificationFormat):
    """
    Plot-formatting options for binary classification metrics.
    """
    def __init__(
        self,
        cmap: str = "Blues",
        ROC_PR_line: str = 'darkorange',
        calibration_bins: int = 15,
        font_size: int = 16,
    ) -> None:
        # Every option is handed unchanged to the shared classification base.
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
        )
377
+
378
+
379
class MultiClassClassificationMetricsFormat(_BaseClassificationFormat):
    """
    Plot-formatting options for multi-class classification metrics.
    """
    def __init__(
        self,
        cmap: str = "Blues",
        ROC_PR_line: str = 'darkorange',
        calibration_bins: int = 15,
        font_size: int = 16,
    ) -> None:
        # Every option is handed unchanged to the shared classification base.
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
        )
392
+
393
+
394
class BinaryImageClassificationMetricsFormat(_BaseClassificationFormat):
    """
    Plot-formatting options for binary image classification metrics.
    """
    def __init__(
        self,
        cmap: str = "Blues",
        ROC_PR_line: str = 'darkorange',
        calibration_bins: int = 15,
        font_size: int = 16,
    ) -> None:
        # Every option is handed unchanged to the shared classification base.
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
        )
407
+
408
+
409
class MultiClassImageClassificationMetricsFormat(_BaseClassificationFormat):
    """
    Plot-formatting options for multi-class image classification metrics.
    """
    def __init__(
        self,
        cmap: str = "Blues",
        ROC_PR_line: str = 'darkorange',
        calibration_bins: int = 15,
        font_size: int = 16,
    ) -> None:
        # Every option is handed unchanged to the shared classification base.
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
        )
422
+
423
+
424
+ # Multi-Label classification
425
class MultiLabelBinaryClassificationMetricsFormat(_BaseMultiLabelFormat):
    """
    Plot-formatting options for multi-label binary classification metrics.
    """
    def __init__(
        self,
        ROC_PR_line: str = 'darkorange',
        cmap: str = "Blues",
        font_size: int = 16,
    ) -> None:
        # Every option is handed unchanged to the shared multi-label base.
        super().__init__(
            ROC_PR_line=ROC_PR_line,
            cmap=cmap,
            font_size=font_size,
        )
436
+
437
+
438
+ # Segmentation
439
class BinarySegmentationMetricsFormat(_BaseSegmentationFormat):
    """
    Plot-formatting options for binary segmentation metrics.
    """
    def __init__(
        self,
        heatmap_cmap: str = 'viridis',
        cm_cmap: str = "Blues",
        font_size: int = 16,
    ) -> None:
        # Every option is handed unchanged to the shared segmentation base.
        super().__init__(
            heatmap_cmap=heatmap_cmap,
            cm_cmap=cm_cmap,
            font_size=font_size,
        )
450
+
451
+
452
class MultiClassSegmentationMetricsFormat(_BaseSegmentationFormat):
    """
    Plot-formatting options for multi-class segmentation metrics.
    """
    def __init__(
        self,
        heatmap_cmap: str = 'viridis',
        cm_cmap: str = "Blues",
        font_size: int = 16,
    ) -> None:
        # Every option is handed unchanged to the shared segmentation base.
        super().__init__(
            heatmap_cmap=heatmap_cmap,
            cm_cmap=cm_cmap,
            font_size=font_size,
        )
463
+
464
+
465
+ # Sequence
466
class SequenceValueMetricsFormat(_BaseSequenceValueFormat):
    """
    Plot-formatting options for sequence-to-value prediction metrics.
    """
    def __init__(
        self,
        font_size: int = 16,
        scatter_color: str = 'tab:blue',
        scatter_alpha: float = 0.6,
        ideal_line_color: str = 'k',
        residual_line_color: str = 'red',
        hist_bins: Union[int, str] = 'auto',
    ) -> None:
        # Every option is handed unchanged to the shared sequence-to-value base.
        super().__init__(
            font_size=font_size,
            scatter_color=scatter_color,
            scatter_alpha=scatter_alpha,
            ideal_line_color=ideal_line_color,
            residual_line_color=residual_line_color,
            hist_bins=hist_bins,
        )
483
+
484
+
485
class SequenceSequenceMetricsFormat(_BaseSequenceSequenceFormat):
    """
    Plot-formatting options for sequence-to-sequence prediction metrics.
    """
    def __init__(
        self,
        font_size: int = 16,
        plot_figsize: tuple[int, int] = (10, 6),
        grid_style: str = '--',
        rmse_color: str = 'tab:blue',
        rmse_marker: str = 'o-',
        mae_color: str = 'tab:orange',
        mae_marker: str = 's--',
    ):
        # Every option is handed unchanged to the shared seq-to-seq base.
        super().__init__(
            font_size=font_size,
            plot_figsize=plot_figsize,
            grid_style=grid_style,
            rmse_color=rmse_color,
            rmse_marker=rmse_marker,
            mae_color=mae_color,
            mae_marker=mae_marker,
        )
504
+
505
+
506
+ # -------- Finalize classes --------
507
class _FinalizeModelTraining:
    """
    Common parent of the `Finalize*` parameter classes.

    Not meant to be instantiated directly; use one of the task-specific subclasses.
    """
    def __init__(
        self,
        filename: str,
    ) -> None:
        # The filename goes through `_validate_string` with the ".pth" extension.
        checked_filename = _validate_string(string=filename, attribute_name="filename", extension=".pth")
        self.filename = checked_filename

        # Task-specific fields default to None; subclasses fill in the ones they use.
        self.target_name: Optional[str] = None
        self.target_names: Optional[list[str]] = None
        self.classification_threshold: Optional[float] = None
        self.class_map: Optional[dict[str,int]] = None
        self.initial_sequence: Optional[np.ndarray] = None
        self.sequence_length: Optional[int] = None
523
+
524
+
525
class FinalizeRegression(_FinalizeModelTraining):
    """Finalization parameters for a single-target regression model."""
    def __init__(
        self,
        filename: str,
        target_name: str,
    ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved; validated by the base class.
            target_name (str): Name of the target variable; must be a string.
        """
        super().__init__(filename=filename)
        checked_target = _validate_string(string=target_name, attribute_name="Target name")
        self.target_name = checked_target
539
+
540
+
541
class FinalizeMultiTargetRegression(_FinalizeModelTraining):
    """Finalization parameters for a multi-target regression model."""
    def __init__(
        self,
        filename: str,
        target_names: list[str],
    ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved; validated by the base class.
            target_names (list[str]): Names of the target variables; each must be a string.
        """
        super().__init__(filename=filename)
        # Each name is checked individually and the results are kept in input order.
        self.target_names = [
            _validate_string(string=name, attribute_name="All target names")
            for name in target_names
        ]
556
+
557
+
558
class FinalizeBinaryClassification(_FinalizeModelTraining):
    """Finalization parameters for a binary classification model."""
    def __init__(
        self,
        filename: str,
        target_name: str,
        classification_threshold: float,
        class_map: dict[str,int],
    ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved; validated by the base class.
            target_name (str): Name of the target variable.
            classification_threshold (float): Cutoff threshold for classifying as the positive class.
            class_map (dict[str,int]): Mapping from class names (str) to their
                integer representations (e.g., {'cat': 0, 'dog': 1}).
        """
        super().__init__(filename=filename)
        # Validation runs in this order: target name, threshold, class map.
        checked_target = _validate_string(string=target_name, attribute_name="Target name")
        self.target_name = checked_target
        checked_threshold = _validate_threshold(classification_threshold)
        self.classification_threshold = checked_threshold
        checked_map = _validate_class_map(class_map)
        self.class_map = checked_map
579
+
580
+
581
class FinalizeMultiClassClassification(_FinalizeModelTraining):
    """Finalization parameters for a multi-class classification model."""
    def __init__(
        self,
        filename: str,
        target_name: str,
        class_map: dict[str,int],
    ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved; validated by the base class.
            target_name (str): Name of the target variable.
            class_map (dict[str,int]): Mapping from class names (str) to their
                integer representations (e.g., {'cat': 0, 'dog': 1}).
        """
        super().__init__(filename=filename)
        checked_target = _validate_string(string=target_name, attribute_name="Target name")
        self.target_name = checked_target
        checked_map = _validate_class_map(class_map)
        self.class_map = checked_map
599
+
600
+
601
class FinalizeBinaryImageClassification(_FinalizeModelTraining):
    """Finalization parameters for a binary image classification model."""
    def __init__(
        self,
        filename: str,
        classification_threshold: float,
        class_map: dict[str,int],
    ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved; validated by the base class.
            classification_threshold (float): Cutoff threshold for
                classifying as the positive class.
            class_map (dict[str,int]): Mapping from class names (str) to their
                integer representations (e.g., {'cat': 0, 'dog': 1}).
        """
        super().__init__(filename=filename)
        checked_threshold = _validate_threshold(classification_threshold)
        self.classification_threshold = checked_threshold
        checked_map = _validate_class_map(class_map)
        self.class_map = checked_map
620
+
621
+
622
class FinalizeMultiClassImageClassification(_FinalizeModelTraining):
    """Finalization parameters for a multi-class image classification model."""
    def __init__(
        self,
        filename: str,
        class_map: dict[str,int],
    ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved; validated by the base class.
            class_map (dict[str,int]): Mapping from class names (str) to their
                integer representations (e.g., {'cat': 0, 'dog': 1}).
        """
        super().__init__(filename=filename)
        checked_map = _validate_class_map(class_map)
        self.class_map = checked_map
637
+
638
+
639
class FinalizeMultiLabelBinaryClassification(_FinalizeModelTraining):
    """Finalization parameters for a multi-label binary classification model."""
    def __init__(
        self,
        filename: str,
        target_names: list[str],
        classification_threshold: float,
    ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved; validated by the base class.
            target_names (list[str]): Names of the target variables; each must be a string.
            classification_threshold (float): Cutoff threshold for classifying as the positive class.
        """
        super().__init__(filename=filename)
        # Names are checked first, in input order; the threshold is checked afterwards.
        self.target_names = [
            _validate_string(string=name, attribute_name="All target names")
            for name in target_names
        ]
        checked_threshold = _validate_threshold(classification_threshold)
        self.classification_threshold = checked_threshold
657
+
658
+
659
class FinalizeBinarySegmentation(_FinalizeModelTraining):
    """Finalization parameters for a binary segmentation model."""
    def __init__(
        self,
        filename: str,
        class_map: dict[str,int],
        classification_threshold: float,
    ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved; validated by the base class.
            class_map (dict[str,int]): Mapping from class names (str) to their
                integer representations.
            classification_threshold (float): Cutoff threshold for classifying as the positive class (mask).
        """
        super().__init__(filename=filename)
        # The threshold is validated before the class map.
        checked_threshold = _validate_threshold(classification_threshold)
        self.classification_threshold = checked_threshold
        checked_map = _validate_class_map(class_map)
        self.class_map = checked_map
675
+
676
+
677
class FinalizeMultiClassSegmentation(_FinalizeModelTraining):
    """Finalization parameters for a multi-class segmentation model."""
    def __init__(
        self,
        filename: str,
        class_map: dict[str,int],
    ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved; validated by the base class.
            class_map (dict[str,int]): Mapping from class names (str) to their
                integer representations.
        """
        super().__init__(filename=filename)
        checked_map = _validate_class_map(class_map)
        self.class_map = checked_map
690
+
691
+
692
class FinalizeObjectDetection(_FinalizeModelTraining):
    """Finalization parameters for an object detection model."""
    def __init__(
        self,
        filename: str,
        class_map: dict[str,int],
    ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved; validated by the base class.
            class_map (dict[str,int]): Mapping from class names (str) to their
                integer representations.
        """
        super().__init__(filename=filename)
        checked_map = _validate_class_map(class_map)
        self.class_map = checked_map
705
+
706
+
707
class FinalizeSequencePrediction(_FinalizeModelTraining):
    """Parameters for finalizing a sequence prediction model."""
    def __init__(self,
                 filename: str,
                 last_training_sequence: np.ndarray,
                 ) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            last_training_sequence (np.ndarray): The last sequence from the training data, needed to start predictions.
                Accepted shapes are (N,) and (1, N); a (1, N) array is flattened to (N,).

        Raises:
            TypeError: If `last_training_sequence` is not a numpy array.
            ValueError: If the array is neither 1D nor of shape (1, N).
        """
        super().__init__(filename=filename)

        if not isinstance(last_training_sequence, np.ndarray):
            message = f"The last training sequence must be a 1D numpy array, got {type(last_training_sequence)}."
            _LOGGER.error(message)
            # The exception carries the same text as the log entry so callers
            # catching it can see the reason.
            raise TypeError(message)

        is_row_vector = last_training_sequence.ndim == 2 and last_training_sequence.shape[0] == 1

        if last_training_sequence.ndim == 1:
            # Already (N,): stored as given (no copy).
            self.initial_sequence = last_training_sequence
        elif is_row_vector:
            # (1, N) is flattened to (N,); `flatten` returns a copy.
            self.initial_sequence = last_training_sequence.flatten()
        else:
            # 0-D, (N, 1), (N, M) and 3D+ inputs are all rejected.
            message = f"The last training sequence must be a 1D numpy array, got shape {last_training_sequence.shape}."
            _LOGGER.error(message)
            raise ValueError(message)

        # Save the length of the validated 1D sequence
        self.sequence_length = len(self.initial_sequence)
744
+
745
+
746
def _validate_string(string: str, attribute_name: str, extension: Optional[str]=None) -> str:
    """Helper for finalize classes.

    Checks that `string` is a `str`. When `extension` is given (truthy), the
    value is passed through `sanitize_filename` and the extension is appended
    if the sanitized name does not already end with it. Otherwise the string
    is returned unchanged.

    Args:
        string (str): The value to validate.
        attribute_name (str): Label used in the error message.
        extension (str | None): File extension to enforce, e.g. ".pth".

    Returns:
        str: The validated (and, with `extension`, sanitized) string.

    Raises:
        TypeError: If `string` is not a `str`.
    """
    if not isinstance(string, str):
        message = f"{attribute_name} must be a string."
        _LOGGER.error(message)
        # The exception carries the same text as the log entry.
        raise TypeError(message)

    if not extension:
        return string

    safe_name = sanitize_filename(string)
    if not safe_name.endswith(extension):
        safe_name += extension
    return safe_name
761
+
762
def _validate_threshold(threshold: float) -> float:
    """Helper for finalize classes.

    Accepts only `float` values strictly between 0.0 and 1.0.

    Args:
        threshold (float): The classification threshold to check.

    Returns:
        float: The same threshold, unchanged.

    Raises:
        TypeError: If `threshold` is not a `float` (plain ints are rejected).
        ValueError: If `threshold` is not inside the open interval (0.0, 1.0).
    """
    if not isinstance(threshold, float):
        message = "Classification threshold must be a float."
        _LOGGER.error(message)
        raise TypeError(message)

    if threshold <= 0.0 or threshold >= 1.0:
        # The message states the bounds the check actually enforces.
        message = "Classification threshold must be in the open interval (0.0, 1.0)."
        _LOGGER.error(message)
        raise ValueError(message)

    return threshold
772
+
773
def _validate_class_map(map: dict[str,int]) -> dict[str,int]:
    """Helper for finalize classes.

    Checks that the class map is a dictionary whose keys are all strings and
    whose values are all integers, e.g. {'cat': 0, 'dog': 1}.

    Args:
        map (dict[str,int]): The class map to validate. The parameter name
            shadows the builtin `map`; it is kept for interface compatibility.

    Returns:
        dict[str,int]: The same dictionary object, unchanged.

    Raises:
        TypeError: If `map` is not a dict, a key is not a str, or a value is
            not an int.
    """
    is_valid = (
        isinstance(map, dict)
        and all(isinstance(key, str) for key in map)
        # bool is a subclass of int; True/False are not meaningful class indices.
        and all(isinstance(val, int) and not isinstance(val, bool) for val in map.values())
    )

    if not is_valid:
        message = "Class map must be a dictionary of string keys and integer values."
        _LOGGER.error(message)
        raise TypeError(message)

    return map
786
+
787
def info():
    """Pass this module's public names (`__all__`) to the `_script_info` helper."""
    public_names = __all__
    _script_info(public_names)