dragon-ml-toolbox 14.3.1__py3-none-any.whl → 14.8.0__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package as they appear in their registry. It is provided for informational purposes only and reflects the changes between those published versions.
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/METADATA +2 -1
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/RECORD +17 -16
- ml_tools/ML_configuration.py +116 -0
- ml_tools/ML_datasetmaster.py +42 -0
- ml_tools/ML_evaluation.py +208 -63
- ml_tools/ML_evaluation_multi.py +40 -10
- ml_tools/ML_trainer.py +38 -12
- ml_tools/ML_utilities.py +50 -1
- ml_tools/ML_vision_datasetmaster.py +198 -60
- ml_tools/ML_vision_models.py +15 -1
- ml_tools/ML_vision_transformers.py +151 -6
- ml_tools/ensemble_evaluation.py +53 -10
- ml_tools/keys.py +2 -1
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/top_level.txt +0 -0
{dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 14.3.1
+Version: 14.8.0
 Summary: A collection of tools for data science and machine learning projects.
 Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
 License-Expression: MIT

@@ -141,6 +141,7 @@ ETL_cleaning
 ETL_engineering
 math_utilities
 ML_callbacks
+ML_configuration
 ML_datasetmaster
 ML_evaluation_multi
 ML_evaluation
{dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/RECORD
CHANGED

@@ -1,25 +1,26 @@
-dragon_ml_toolbox-14.3.1.dist-info/licenses/LICENSE,sha256=…
-dragon_ml_toolbox-14.3.1.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=…
+dragon_ml_toolbox-14.8.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+dragon_ml_toolbox-14.8.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=gkOdNDbKYpIJezwSo2CEnISkLeYfYHv9t8b5K2-P69A,2687
 ml_tools/ETL_cleaning.py,sha256=2VBRllV8F-ZiPylPp8Az2gwn5ztgazN0BH5OKnRUhV0,20402
 ml_tools/ETL_engineering.py,sha256=KfYqgsxupAx6e_TxwO1LZXeu5mFkIhVXJrNjP3CzIZc,54927
 ml_tools/GUI_tools.py,sha256=Va6ig-dHULPVRwQYYtH3fvY5XPIoqRcJpRW8oXC55Hw,45413
 ml_tools/MICE_imputation.py,sha256=KLJXGQLKJ6AuWWttAG-LCCaxpS-ygM4dXPiguHDaL6Y,20815
 ml_tools/ML_callbacks.py,sha256=elD2Yr030sv_6gX_m9GVd6HTyrbmt34nFS8lrgS4HtM,15808
-ml_tools/ML_datasetmaster.py,sha256=…
-ml_tools/ML_evaluation.py,sha256=…
-ml_tools/ML_evaluation_multi.py,sha256=…
+ml_tools/ML_configuration.py,sha256=tXkm2q57bl2kK0Iqpx1G7s1pEURBL_UMmqD8mlsGPs4,4689
+ml_tools/ML_datasetmaster.py,sha256=Zi5jBnBI_U6tD8mpCVL5bQcsqsGEMAzMsCVI_wFD2QU,30175
+ml_tools/ML_evaluation.py,sha256=EvlgFeMQeZ1RSEMtNd-nv7W0d0SVcR4n6cwW5UG16DU,25358
+ml_tools/ML_evaluation_multi.py,sha256=bQZ2gJY-dBzKQxvtd-B6wVaGBdFpQGVBr7tQZFokp5E,17166
 ml_tools/ML_inference.py,sha256=YJ953bhNWsdlPRtJQh3h2ACfMIgp8dQ9KtL9Azar-5s,23489
 ml_tools/ML_models.py,sha256=PqOcNlws7vCJMbiVCKqlPuktxvskZVUHG3VfU-Yshf8,31415
 ml_tools/ML_models_advanced.py,sha256=vk3PZBSu3DVso2S1rKTxxdS43XG8Q5FnasIL3-rMajc,12410
 ml_tools/ML_optimization.py,sha256=P0zkhKAwTpkorIBtR0AOIDcyexo5ngmvFUzo3DfNO-E,22692
 ml_tools/ML_scaler.py,sha256=tw6onj9o8_kk3FQYb930HUzvv1zsFZe2YZJdF3LtHkU,7538
-ml_tools/ML_trainer.py,sha256=…
-ml_tools/ML_utilities.py,sha256=…
-ml_tools/ML_vision_datasetmaster.py,sha256=…
+ml_tools/ML_trainer.py,sha256=salZxfv3RWRCiinp5S9xeUsHysMbMQ52EecR8GyEbaM,51461
+ml_tools/ML_utilities.py,sha256=eYe2N-65FTzaOHF5gmiJl-HmicyzhqcdvlDiIivr5_g,22993
+ml_tools/ML_vision_datasetmaster.py,sha256=VHZo0gzgrXrfGcHA34WKD3gGfhlxMrOXbNdYhXb6p6M,64462
 ml_tools/ML_vision_evaluation.py,sha256=t12R7i1RkOCt9zu1_lxSBr8OH6A6Get0k8ftDLctn6I,10486
 ml_tools/ML_vision_inference.py,sha256=He3KV3VJAm8PwO-fOq4b9VO8UXFr-GmpuCnoHXf4VZI,20588
-ml_tools/ML_vision_models.py,sha256=…
-ml_tools/ML_vision_transformers.py,sha256=…
+ml_tools/ML_vision_models.py,sha256=WqiRN9JAjv--BcwkDrooXAs4Qo26JHPCHh3JSPm4kMI,26226
+ml_tools/ML_vision_transformers.py,sha256=h332O9BjDMgxrBc0I-bJwJODWlcp7nJHbX1QS2etwBk,7738
 ml_tools/PSO_optimization.py,sha256=T-HWHMRJUnPvPwixdU5jif3_rnnI36TzcL8u3oSCwuA,22960
 ml_tools/RNN_forecast.py,sha256=Qa2KoZfdAvSjZ4yE78N4BFXtr3tTr0Gx7tQJZPotsh0,1967
 ml_tools/SQL.py,sha256=vXLPGfVVg8bfkbBE3HVfyEclVbdJy0TBhuQONtMwSCQ,11234

@@ -32,17 +33,17 @@ ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
 ml_tools/constants.py,sha256=3br5Rk9cL2IUo638eJuMOGdbGQaWssaUecYEvSeRBLM,3322
 ml_tools/custom_logger.py,sha256=TGc0Ww2Xlqj2XE3q4bP43hV7T3qnb5ci9f0pYHXF5TY,11226
 ml_tools/data_exploration.py,sha256=bwHzFJ-IAo5GN3T53F-1J_pXUg8VHS91sG_90utAsfg,69911
-ml_tools/ensemble_evaluation.py,sha256=…
+ml_tools/ensemble_evaluation.py,sha256=2sJ3jD6yBNPRNwSokyaLKqKHi0QhF13ChoFe5yd4zwg,28368
 ml_tools/ensemble_inference.py,sha256=0yLmLNj45RVVoSCLH1ZYJG9IoAhTkWUqEZmLOQTFGTY,9348
 ml_tools/ensemble_learning.py,sha256=vsIED7nlheYI4w2SBzP6SC1AnNeMfn-2A1Gqw5EfxsM,21964
 ml_tools/handle_excel.py,sha256=pfdAPb9ywegFkM9T54bRssDOsX-K7rSeV0RaMz7lEAo,14006
-ml_tools/keys.py,sha256=…
+ml_tools/keys.py,sha256=-OiL9G0RIOKQk6BwETKIP3LWz2s5-x6lZW2YitJa4mY,3330
 ml_tools/math_utilities.py,sha256=xeKq1quR_3DYLgowcp4Uam_4s3JltUyOnqMOGuAiYWU,8802
 ml_tools/optimization_tools.py,sha256=TYFQ2nSnp7xxs-VyoZISWgnGJghFbsWasHjruegyJRs,12763
 ml_tools/path_manager.py,sha256=CyDU16pOKmC82jPubqJPT6EBt-u-3rGVbxyPIZCvDDY,18432
 ml_tools/serde.py,sha256=c8uDYjYry_VrLvoG4ixqDj5pij88lVn6Tu4NHcPkwDU,6943
 ml_tools/utilities.py,sha256=aWqvYzmxlD74PD5Yqu1VuTekDJeYLQrmPIU_VeVyRp0,22526
-dragon_ml_toolbox-14.3.1.dist-info/METADATA,sha256=…
-dragon_ml_toolbox-14.3.1.dist-info/WHEEL,sha256=…
-dragon_ml_toolbox-14.3.1.dist-info/top_level.txt,sha256=…
-dragon_ml_toolbox-14.3.1.dist-info/RECORD,,
+dragon_ml_toolbox-14.8.0.dist-info/METADATA,sha256=9OndkhzBGS_XzlCPuHH88wIgndT2jhWN4fydXTGJg-8,6492
+dragon_ml_toolbox-14.8.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-14.8.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-14.8.0.dist-info/RECORD,,
ml_tools/ML_configuration.py
ADDED

@@ -0,0 +1,116 @@
+from typing import Optional
+from ._script_info import _script_info
+
+
+__all__ = [
+    "ClassificationMetricsFormat",
+    "MultiClassificationMetricsFormat"
+]
+
+
+class ClassificationMetricsFormat:
+    """
+    Optional configuration for classification tasks, use in the '.evaluate()' method of the MLTrainer.
+    """
+    def __init__(self,
+                 cmap: str="Blues",
+                 class_map: Optional[dict[str,int]]=None,
+                 ROC_PR_line: str='darkorange',
+                 calibration_bins: int=15,
+                 font_size: int=16) -> None:
+        """
+        Initializes the formatting configuration for single-label classification metrics.
+
+        Args:
+            cmap (str): The matplotlib colormap name for the confusion matrix
+                and report heatmap. Defaults to "Blues".
+                - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
+                - Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
+
+            class_map (dict[str,int] | None): A dictionary mapping
+                class string names to their integer indices (e.g., {'cat': 0, 'dog': 1}).
+                This is used to label the axes of the confusion matrix and classification
+                report correctly. Defaults to None.
+
+            ROC_PR_line (str): The color name or hex code for the line plotted
+                on the ROC and Precision-Recall curves. Defaults to 'darkorange'.
+                - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
+                - Hex codes: '#FF6347', '#4682B4'
+
+            calibration_bins (int): The number of bins to use when
+                creating the calibration (reliability) plot. Defaults to 15.
+
+            font_size (int): The base font size to apply to the plots. Defaults to 16.
+
+        <br>
+
+        ## [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
+        """
+        self.cmap = cmap
+        self.class_map = class_map
+        self.ROC_PR_line = ROC_PR_line
+        self.calibration_bins = calibration_bins
+        self.font_size = font_size
+
+    def __repr__(self) -> str:
+        parts = [
+            f"cmap='{self.cmap}'",
+            f"class_map={self.class_map}",
+            f"ROC_PR_line='{self.ROC_PR_line}'",
+            f"calibration_bins={self.calibration_bins}",
+            f"font_size={self.font_size}"
+        ]
+        return f"ClassificationMetricsFormat({', '.join(parts)})"
+
+
+class MultiClassificationMetricsFormat:
+    """
+    Optional configuration for multi-label classification tasks, use in the '.evaluate()' method of the MLTrainer.
+    """
+    def __init__(self,
+                 threshold: float=0.5,
+                 ROC_PR_line: str='darkorange',
+                 cmap: str = "Blues",
+                 font_size: int = 16) -> None:
+        """
+        Initializes the formatting configuration for multi-label classification metrics.
+
+        Args:
+            threshold (float): The probability threshold (0.0 to 1.0) used
+                to convert sigmoid outputs into binary (0 or 1) predictions for
+                calculating the confusion matrix and overall metrics. Defaults to 0.5.
+
+            ROC_PR_line (str): The color name or hex code for the line plotted
+                on the ROC and Precision-Recall curves (one for each label).
+                Defaults to 'darkorange'.
+                - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
+                - Hex codes: '#FF6347', '#4682B4'
+
+            cmap (str): The matplotlib colormap name for the per-label
+                confusion matrices. Defaults to "Blues".
+                - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
+                - Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
+
+            font_size (int): The base font size to apply to the plots. Defaults to 16.
+
+        <br>
+
+        ## [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
+        """
+        self.threshold = threshold
+        self.cmap = cmap
+        self.ROC_PR_line = ROC_PR_line
+        self.font_size = font_size
+
+    def __repr__(self) -> str:
+        parts = [
+            f"threshold={self.threshold}",
+            f"ROC_PR_line='{self.ROC_PR_line}'",
+            f"cmap='{self.cmap}'",
+            f"font_size={self.font_size}"
+        ]
+        return f"MultiClassificationMetricsFormat({', '.join(parts)})"
+
+
+def info():
+    _script_info(__all__)
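For orientation, a minimal usage sketch of the two new configuration classes. The class names and values below are invented for illustration, and the exact keyword that MLTrainer's `.evaluate()` expects for these objects is not shown in this diff:

    from ml_tools.ML_configuration import (
        ClassificationMetricsFormat,
        MultiClassificationMetricsFormat,
    )

    # Single-label task: named confusion-matrix axes, custom curve color.
    fmt = ClassificationMetricsFormat(
        cmap="Greens",
        class_map={"cat": 0, "dog": 1},   # illustrative classes
        ROC_PR_line="crimson",
        calibration_bins=10,
    )
    print(fmt)
    # ClassificationMetricsFormat(cmap='Greens', class_map={'cat': 0, 'dog': 1}, ROC_PR_line='crimson', calibration_bins=10, font_size=16)

    # Multi-label task: tighter decision threshold for the per-label confusion matrices.
    multi_fmt = MultiClassificationMetricsFormat(threshold=0.4)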
ml_tools/ML_datasetmaster.py
CHANGED

@@ -333,7 +333,20 @@ class DatasetMaker(_BaseDatasetMaker):
         # --- 5. Create Datasets ---
         self._train_ds = _PytorchDataset(X_train_final, y_train, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
         self._test_ds = _PytorchDataset(X_test_final, y_test, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
+
+    def __repr__(self) -> str:
+        s = f"<{self.__class__.__name__} (ID: '{self.id}')>\n"
+        s += f" Target: {self.target_names[0]}\n"
+        s += f" Features: {self.number_of_features}\n"
+        s += f" Scaler: {'Fitted' if self.scaler else 'None'}\n"
 
+        if self._train_ds:
+            s += f" Train Samples: {len(self._train_ds)}\n" # type: ignore
+        if self._test_ds:
+            s += f" Test Samples: {len(self._test_ds)}\n" # type: ignore
+
+        return s
+
 
 # --- Multi-Target Class ---
 class DatasetMakerMulti(_BaseDatasetMaker):

@@ -448,6 +461,19 @@ class DatasetMakerMulti(_BaseDatasetMaker):
         self._train_ds = _PytorchDataset(X_train_final, y_train, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
         self._test_ds = _PytorchDataset(X_test_final, y_test, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
 
+    def __repr__(self) -> str:
+        s = f"<{self.__class__.__name__} (ID: '{self.id}')>\n"
+        s += f" Targets: {self.number_of_targets}\n"
+        s += f" Features: {self.number_of_features}\n"
+        s += f" Scaler: {'Fitted' if self.scaler else 'None'}\n"
+
+        if self._train_ds:
+            s += f" Train Samples: {len(self._train_ds)}\n" # type: ignore
+        if self._test_ds:
+            s += f" Test Samples: {len(self._test_ds)}\n" # type: ignore
+
+        return s
+
 
 # --- Private Base Class ---
 class _BaseMaker(ABC):

@@ -654,6 +680,22 @@ class SequenceMaker(_BaseMaker):
             _LOGGER.error("Windows have not been generated. Call .generate_windows() first.")
             raise RuntimeError()
         return self._train_dataset, self._test_dataset
+
+    def __repr__(self) -> str:
+        s = f"<{self.__class__.__name__}>:\n"
+        s += f" Sequence Length (Window): {self.sequence_length}\n"
+        s += f" Total Data Points: {len(self.sequence)}\n"
+        s += " --- Status ---\n"
+        s += f" Split: {self._is_split}\n"
+        s += f" Normalized: {self._is_normalized}\n"
+        s += f" Windows Generated: {self._are_windows_generated}\n"
+
+        if self._are_windows_generated:
+            train_len = len(self._train_dataset) if self._train_dataset else 0 # type: ignore
+            test_len = len(self._test_dataset) if self._test_dataset else 0 # type: ignore
+            s += f" Datasets (Train/Test): {train_len} / {test_len} windows\n"
+
+        return s
 
 
 def info():
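All three `__repr__` additions build the summary string incrementally and gate the sample counts on dataset availability; a self-contained sketch of the pattern, with the class and values invented for illustration:

    class _ReprDemo:
        def __init__(self) -> None:
            self.id = "demo"
            self.number_of_features = 12
            self._train_ds = list(range(800))   # stand-in for a train dataset
            self._test_ds = list(range(200))    # stand-in for a test dataset

        def __repr__(self) -> str:
            s = f"<{self.__class__.__name__} (ID: '{self.id}')>\n"
            s += f" Features: {self.number_of_features}\n"
            if self._train_ds:
                s += f" Train Samples: {len(self._train_ds)}\n"
            if self._test_ds:
                s += f" Test Samples: {len(self._test_ds)}\n"
            return s

    print(_ReprDemo())
    # <_ReprDemo (ID: 'demo')>
    #  Features: 12
    #  Train Samples: 800
    #  Test Samples: 200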
ml_tools/ML_evaluation.py
CHANGED

@@ -21,7 +21,7 @@ from pathlib import Path
 from typing import Union, Optional, List, Literal
 import warnings
 
-from .path_manager import make_fullpath
+from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
 from .keys import SHAPKeys, PyTorchLogKeys

@@ -35,6 +35,8 @@ __all__ = [
     "plot_attention_importance"
 ]
 
+DPI_value = 250
+
 
 def plot_losses(history: dict, save_dir: Union[str, Path]):
     """

@@ -48,10 +50,10 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
     val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
 
     if not train_loss and not val_loss:
-        …
+        _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
         return
 
-    fig, ax = plt.subplots(figsize=(10, 5), dpi=…)
+    fig, ax = plt.subplots(figsize=(10, 5), dpi=DPI_value)
 
     # Plot training loss only if data for it exists
     if train_loss:
@@ -78,8 +80,15 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
     plt.close(fig)
 
 
-def classification_metrics(save_dir: Union[str, Path], …
-    …
+def classification_metrics(save_dir: Union[str, Path],
+                           y_true: np.ndarray,
+                           y_pred: np.ndarray,
+                           y_prob: Optional[np.ndarray] = None,
+                           cmap: str = "Blues",
+                           class_map: Optional[dict[str,int]]=None,
+                           ROC_PR_line: str='darkorange',
+                           calibration_bins: int=15,
+                           font_size: int=16):
     """
     Saves classification metrics and plots.

@@ -89,12 +98,31 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pred: np.ndarray, …
         y_prob (np.ndarray, optional): Predicted probabilities for ROC curve.
         cmap (str): Colormap for the confusion matrix.
         save_dir (str | Path): Directory to save plots.
+        class_map (dict[str, int], None): A map of {class_name: index} used to order and label the confusion matrix.
     """
-    …
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': font_size})
+
+    # print("--- Classification Report ---")
+
+    # --- Parse class_map ---
+    map_labels = None
+    map_display_labels = None
+    if class_map:
+        # Sort the map by its values (the indices) to ensure correct order
+        try:
+            sorted_items = sorted(class_map.items(), key=lambda item: item[1])
+            map_labels = [item[1] for item in sorted_items]
+            map_display_labels = [item[0] for item in sorted_items]
+        except Exception as e:
+            _LOGGER.warning(f"Could not parse 'class_map': {e}")
+            map_labels = None
+            map_display_labels = None
+
     # Generate report as both text and dictionary
-    report_text: str = classification_report(y_true, y_pred) # type: ignore
-    report_dict: dict = classification_report(y_true, y_pred, output_dict=True) # type: ignore
-    print(report_text)
+    report_text: str = classification_report(y_true, y_pred, labels=map_labels, target_names=map_display_labels) # type: ignore
+    report_dict: dict = classification_report(y_true, y_pred, output_dict=True, labels=map_labels, target_names=map_display_labels) # type: ignore
+    # print(report_text)
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     # Save text report

@@ -104,8 +132,15 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pred: np.ndarray, …
 
     # --- Save Classification Report Heatmap ---
     try:
-        plt.figure(figsize=(8, 6), dpi=…)
-        sns.…
+        plt.figure(figsize=(8, 6), dpi=DPI_value)
+        sns.set_theme(font_scale=1.2) # Scale seaborn font
+        sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T,
+                    annot=True,
+                    cmap=cmap,
+                    fmt='.2f',
+                    vmin=0.0,
+                    vmax=1.0)
+        sns.set_theme(font_scale=1.0) # Reset seaborn scale
         plt.title("Classification Report")
         plt.tight_layout()
         heatmap_path = save_dir_path / "classification_report_heatmap.svg"
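The class_map handling added above is a plain value-sort of the dict; a standalone sketch of the transformation it performs (mapping invented for illustration):

    class_map = {"cat": 0, "dog": 1, "bird": 2}

    # Sort by index so label order matches the model's class indices.
    sorted_items = sorted(class_map.items(), key=lambda item: item[1])
    map_labels = [idx for _, idx in sorted_items]              # [0, 1, 2]
    map_display_labels = [name for name, _ in sorted_items]    # ['cat', 'dog', 'bird']

Both lists then feed scikit-learn's classification_report and ConfusionMatrixDisplay via labels= and target_names=/display_labels=, which is what keeps axis labels aligned with class indices.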
@@ -114,69 +149,179 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pred: np.ndarray, …
         plt.close()
     except Exception as e:
         _LOGGER.error(f"Could not generate classification report heatmap: {e}")
 
+    # --- labels for Confusion Matrix ---
+    plot_labels = map_labels
+    plot_display_labels = map_display_labels
+
     # Save Confusion Matrix
-    fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=…)
-    ConfusionMatrixDisplay.from_predictions(y_true, …
+    fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+    disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
+                                                    y_pred,
+                                                    cmap=cmap,
+                                                    ax=ax_cm,
+                                                    normalize='true',
+                                                    labels=plot_labels,
+                                                    display_labels=plot_display_labels)
+
+    disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+    # Turn off gridlines
+    ax_cm.grid(False)
+
+    # Manually update font size of cell texts
+    for text in ax_cm.texts:
+        text.set_fontsize(font_size)
+
+    fig_cm.tight_layout()
+
     ax_cm.set_title("Confusion Matrix")
     cm_path = save_dir_path / "confusion_matrix.svg"
     plt.savefig(cm_path)
     _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
     plt.close(fig_cm)
 
-    …
-    ax_pr.set_title('Precision-Recall Curve')
-    ax_pr.set_xlabel('Recall')
-    ax_pr.set_ylabel('Precision')
-    ax_pr.legend(loc='lower left')
-    ax_pr.grid(True)
-    pr_path = save_dir_path / "pr_curve.svg"
-    plt.savefig(pr_path)
-    _LOGGER.info(f"📈 PR curve saved as '{pr_path.name}'")
-    plt.close(fig_pr)
-    …
+    # Plotting logic for ROC, PR, and Calibration Curves
+    if y_prob is not None and y_prob.ndim == 2:
+        num_classes = y_prob.shape[1]
+
+        # --- Determine which classes to loop over ---
+        class_indices_to_plot = []
+        plot_titles = []
+        save_suffixes = []
+
+        if num_classes == 2:
+            # Binary case: Only plot for the positive class (index 1)
+            class_indices_to_plot = [1]
+            plot_titles = [""] # No extra title
+            save_suffixes = [""] # No extra suffix
+            _LOGGER.info("Generating binary classification plots (ROC, PR, Calibration).")
+
+        elif num_classes > 2:
+            _LOGGER.info(f"Generating One-vs-Rest plots for {num_classes} classes.")
+            # Multiclass case: Plot for every class (One-vs-Rest)
+            class_indices_to_plot = list(range(num_classes))
+
+            # --- Use class_map names if available ---
+            use_generic_names = True
+            if map_display_labels and len(map_display_labels) == num_classes:
+                try:
+                    # Ensure labels are safe for filenames
+                    safe_names = [sanitize_filename(name) for name in map_display_labels]
+                    plot_titles = [f" ({name} vs. Rest)" for name in map_display_labels]
+                    save_suffixes = [f"_{safe_names[i]}" for i in class_indices_to_plot]
+                    use_generic_names = False
+                except Exception as e:
+                    _LOGGER.warning(f"Failed to use 'class_map' for plot titles: {e}. Reverting to generic names.")
+                    use_generic_names = True
+
+            if use_generic_names:
+                plot_titles = [f" (Class {i} vs. Rest)" for i in class_indices_to_plot]
+                save_suffixes = [f"_class_{i}" for i in class_indices_to_plot]
+
+        else:
+            # Should not happen, but good to check
+            _LOGGER.warning(f"Probability array has invalid shape {y_prob.shape}. Skipping ROC/PR/Calibration plots.")
+
+        # --- Loop and generate plots ---
+        for i, class_index in enumerate(class_indices_to_plot):
+            plot_title = plot_titles[i]
+            save_suffix = save_suffixes[i]
+
+            # Get scores for the current class
+            y_score = y_prob[:, class_index]
+
+            # Binarize y_true for the current class
+            y_true_binary = (y_true == class_index).astype(int)
+
+            # --- Save ROC Curve ---
+            fpr, tpr, _ = roc_curve(y_true_binary, y_score)
+
+            # Calculate AUC.
+            # Note: For multiclass, roc_auc_score(y_true, y_prob, multi_class='ovr') could average, but plotting individual curves is more informative.
+            # Here we calculate the specific AUC for the binarized problem.
+            auc = roc_auc_score(y_true_binary, y_score)
+
+            fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+            ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
+            ax_roc.plot([0, 1], [0, 1], 'k--')
+            ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
+            ax_roc.set_xlabel('False Positive Rate')
+            ax_roc.set_ylabel('True Positive Rate')
+            ax_roc.legend(loc='lower right')
+            ax_roc.grid(True)
+            roc_path = save_dir_path / f"roc_curve{save_suffix}.svg"
+            plt.savefig(roc_path)
+            plt.close(fig_roc)
+
+            # --- Save Precision-Recall Curve ---
+            precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
+            ap_score = average_precision_score(y_true_binary, y_score)
+            fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+            ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=ROC_PR_line)
+            ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
+            ax_pr.set_xlabel('Recall')
+            ax_pr.set_ylabel('Precision')
+            ax_pr.legend(loc='lower left')
+            ax_pr.grid(True)
+            pr_path = save_dir_path / f"pr_curve{save_suffix}.svg"
+            plt.savefig(pr_path)
+            plt.close(fig_pr)
+
+            # --- Save Calibration Plot ---
+            fig_cal, ax_cal = plt.subplots(figsize=(8, 8), dpi=DPI_value)
+
+            # --- Step 1: Get binned data *without* plotting ---
+            with plt.ioff(): # Suppress showing the temporary plot
+                fig_temp, ax_temp = plt.subplots()
+                cal_display_temp = CalibrationDisplay.from_predictions(
+                    y_true_binary, # Use binarized labels
+                    y_score,
+                    n_bins=calibration_bins,
+                    ax=ax_temp,
+                    name="temp" # Add a name to suppress potential warnings
+                )
+                # Get the x, y coordinates of the binned data
+                line_x, line_y = cal_display_temp.line_.get_data() # type: ignore
+                plt.close(fig_temp) # Close the temporary plot
+
+            # --- Step 2: Build the plot from scratch ---
+            ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
+
+            sns.regplot(
+                x=line_x,
+                y=line_y,
+                ax=ax_cal,
+                scatter=False,
+                label=f"Calibration Curve ({calibration_bins} bins)",
+                line_kws={
+                    'color': ROC_PR_line,
+                    'linestyle': '--',
+                    'linewidth': 2,
+                }
+            )
+
+            ax_cal.set_title(f'Reliability Curve{plot_title}')
             ax_cal.set_xlabel('Mean Predicted Probability')
             ax_cal.set_ylabel('Fraction of Positives')
+
+            # --- Step 3: Set final limits *after* plotting ---
+            ax_cal.set_ylim(0.0, 1.0)
+            ax_cal.set_xlim(0.0, 1.0)
+
+            ax_cal.legend(loc='lower right')
             ax_cal.grid(True)
             plt.tight_layout()
 
-    cal_path = save_dir_path / "calibration_plot.svg"
+            cal_path = save_dir_path / f"calibration_plot{save_suffix}.svg"
             plt.savefig(cal_path)
-    _LOGGER.info(f"📈 Calibration plot saved as '{cal_path.name}'")
             plt.close(fig_cal)
+
+        _LOGGER.info(f"📈 Saved {len(class_indices_to_plot)} sets of ROC, Precision-Recall, and Calibration plots.")
+
+    # restore RC params
+    plt.rcParams.update(original_rc_params)
 
 
 def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[str, Path]):
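Taken together, the widened signature can be exercised as follows; a minimal sketch with invented data and an assumed output directory. For a 3-column y_prob it emits one ROC, PR, and calibration plot per class (One-vs-Rest), suffixed with the sanitized class names:

    import numpy as np
    from ml_tools.ML_evaluation import classification_metrics

    rng = np.random.default_rng(0)
    y_true = rng.integers(0, 3, size=200)            # toy labels
    y_prob = rng.dirichlet(np.ones(3), size=200)     # rows sum to 1, shape (200, 3)
    y_pred = y_prob.argmax(axis=1)

    classification_metrics(
        save_dir="reports/eval",                     # assumed path
        y_true=y_true,
        y_pred=y_pred,
        y_prob=y_prob,                               # enables ROC/PR/calibration plots
        class_map={"cat": 0, "dog": 1, "bird": 2},   # orders and names the matrix axes
        ROC_PR_line="cornflowerblue",
    )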
@@ -211,7 +356,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[str, Path]):
 
     # Save residual plot
     residuals = y_true - y_pred
-    fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=…)
+    fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
     ax_res.scatter(y_pred, residuals, alpha=0.6)
     ax_res.axhline(0, color='red', linestyle='--')
     ax_res.set_xlabel("Predicted Values")

@@ -225,7 +370,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[str, Path]):
     plt.close(fig_res)
 
     # Save true vs predicted plot
-    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=…)
+    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
     ax_tvp.scatter(y_true, y_pred, alpha=0.6)
     ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'k--', lw=2)
     ax_tvp.set_xlabel('True Values')

@@ -239,7 +384,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[str, Path]):
     plt.close(fig_tvp)
 
     # Save Histogram of Residuals
-    fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=…)
+    fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=DPI_value)
     sns.histplot(residuals, kde=True, ax=ax_hist)
     ax_hist.set_xlabel("Residual Value")
     ax_hist.set_ylabel("Frequency")
@@ -276,7 +421,7 @@ def shap_summary_plot(model,
         slow and memory-intensive.
     """
 
-    …
+    _LOGGER.info(f"📊 Running SHAP Value Explanation Using {explainer_type.upper()} Explainer")
 
     model.eval()
     # model.cpu() # Run explanations on CPU

@@ -348,9 +493,9 @@ def shap_summary_plot(model,
         _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
         raise ValueError()
 
-    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
+    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1: # type: ignore
         # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
-        shap_values = shap_values.squeeze(-1)
+        shap_values = shap_values.squeeze(-1) # type: ignore
 
     # --- 3. Plotting and Saving ---
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

@@ -455,7 +600,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option…
     # --- Step 3: Create and save the plot for top N features ---
     plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)
 
-    plt.figure(figsize=(10, 8), dpi=…)
+    plt.figure(figsize=(10, 8), dpi=DPI_value)
 
     # Create horizontal bar plot with error bars
     plt.barh(