dragon-ml-toolbox 5.3.0__py3-none-any.whl → 6.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox has been flagged as potentially problematic.
- {dragon_ml_toolbox-5.3.0.dist-info → dragon_ml_toolbox-6.0.0.dist-info}/METADATA +9 -6
- {dragon_ml_toolbox-5.3.0.dist-info → dragon_ml_toolbox-6.0.0.dist-info}/RECORD +15 -14
- ml_tools/ML_callbacks.py +7 -7
- ml_tools/ML_evaluation.py +196 -106
- ml_tools/ML_trainer.py +17 -15
- ml_tools/PSO_optimization.py +5 -5
- ml_tools/ensemble_evaluation.py +639 -0
- ml_tools/ensemble_inference.py +10 -10
- ml_tools/ensemble_learning.py +47 -413
- ml_tools/keys.py +2 -2
- ml_tools/utilities.py +27 -3
- {dragon_ml_toolbox-5.3.0.dist-info → dragon_ml_toolbox-6.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-5.3.0.dist-info → dragon_ml_toolbox-6.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-5.3.0.dist-info → dragon_ml_toolbox-6.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-5.3.0.dist-info → dragon_ml_toolbox-6.0.0.dist-info}/top_level.txt +0 -0
{dragon_ml_toolbox-5.3.0.dist-info → dragon_ml_toolbox-6.0.0.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 5.3.0
+Version: 6.0.0
 Summary: A collection of tools for data science and machine learning projects.
 Author-email: Karl Loza <luigiloza@gmail.com>
 License-Expression: MIT
@@ -141,19 +141,22 @@ pip install "dragon-ml-toolbox[pytorch]"
 ```bash
 custom_logger
 data_exploration
-… (1 removed line not shown in source view)
+ensemble_evaluation
 ensemble_inference
+ensemble_learning
 ETL_engineering
-ML_datasetmaster
-ML_models
 ML_callbacks
+ML_datasetmaster
 ML_evaluation
-ML_trainer
 ML_inference
+ML_models
+ML_optimization
+ML_trainer
+optimization_tools
 path_manager
 PSO_optimization
-SQL
 RNN_forecast
+SQL
 utilities
 ```

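The README module list above is the only inventory of public modules in this diff; a quick, hedged way to smoke-test the 6.0.0 layout after upgrading (module names are taken from that list, everything else here is an assumption):

```python
# Minimal import check for the 6.0.0 module layout; module names come
# from the README list above, and this assumes each one is a submodule
# of the ml_tools package (as the RECORD file paths below suggest).
from ml_tools import ensemble_evaluation, ensemble_inference, ensemble_learning
from ml_tools import ML_optimization, optimization_tools
```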
{dragon_ml_toolbox-5.3.0.dist-info → dragon_ml_toolbox-6.0.0.dist-info}/RECORD CHANGED

@@ -1,16 +1,16 @@
-dragon_ml_toolbox-
-dragon_ml_toolbox-
+dragon_ml_toolbox-6.0.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
+dragon_ml_toolbox-6.0.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
 ml_tools/ETL_engineering.py,sha256=4wwZXi9_U7xfCY70jGBaKniOeZ0m75ppxWpQBd_DmLc,39369
 ml_tools/GUI_tools.py,sha256=n4ZZ5kEjwK5rkOCFJE41HeLFfjhpJVLUSzk9Kd9Kr_0,45410
 ml_tools/MICE_imputation.py,sha256=oFHg-OytOzPYTzBR_wIRHhP71cMn3aupDeT59ABsXlQ,11576
-ml_tools/ML_callbacks.py,sha256=
+ml_tools/ML_callbacks.py,sha256=FEJ80TSEtY0-hdnOsAWeVApQt1mdzTdOntqtoWmMAzE,13310
 ml_tools/ML_datasetmaster.py,sha256=bbKCNA_b_uDIfxP9YIYKZm-VSfUSD15LvegFxpE9DIQ,34315
-ml_tools/ML_evaluation.py,sha256=
+ml_tools/ML_evaluation.py,sha256=A7AlEjy4ZOcdQMh9M3TJIDvCOXqzAbhgLxyhli8S_WY,13593
 ml_tools/ML_inference.py,sha256=Fh-X2UQn3AznWBjf-7iPSxwE-EzkGQm1VEIRUAkURmE,5336
 ml_tools/ML_models.py,sha256=SJhKHGAN2VTBqzcHUOpFWuVZ2Y7U1M4P_axG_LNYWcI,6460
 ml_tools/ML_optimization.py,sha256=zGKpWW4SL1-3iiHglDP-dkuADL73T0kxs3Dc-Lyishs,9671
-ml_tools/ML_trainer.py,sha256=
-ml_tools/PSO_optimization.py,sha256=
+ml_tools/ML_trainer.py,sha256=1q_CDXuMfndRsPuNofUn2mg2TlhG6MYuGqjWxTDgN9c,15112
+ml_tools/PSO_optimization.py,sha256=9Y074d-B5h4Wvp9YPiy6KAeXM-Yv6Il3gWalKvOLVgo,22705
 ml_tools/RNN_forecast.py,sha256=2CyjBLSYYc3xLHxwLXUmP5Qv8AmV1OB_EndETNX1IBk,1956
 ml_tools/SQL.py,sha256=9zzS6AFEJM9aj6nE31hDe8S9TqLonk-J1amwZoiHNbk,10468
 ml_tools/VIF_factor.py,sha256=2nUMupfUoogf8o6ghoFZk_OwWhFXU0R3C9Gj0HOlI14,10415
@@ -19,14 +19,15 @@ ml_tools/_logger.py,sha256=TpgYguxO-CWYqqgLW0tqFjtwZ58PE_W2OCfWNGZr0n0,1175
 ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
 ml_tools/custom_logger.py,sha256=njM_0XPbQ1S-x5LeSQAaTo2if-XVOR_pQSGg4EDeiTU,4603
 ml_tools/data_exploration.py,sha256=P4f8OpRa7Q4i-11nkppxXw5Lx2lwlpn20GwWBbN_xbM,23901
-ml_tools/
-ml_tools/
+ml_tools/ensemble_evaluation.py,sha256=ywpBCvmVImocZAcKv52mSdQKKHdLswozknoev39l4Yo,24682
+ml_tools/ensemble_inference.py,sha256=rtU7eUaQne615n2g7IHZCJI-OvrBCcjxbTkEIvtCGFQ,9414
+ml_tools/ensemble_learning.py,sha256=dAyFgSTyvxJWjc_enJ_8EUoWwiekBeoNyJNxVY-kcUU,21868
 ml_tools/handle_excel.py,sha256=J9iwIqMZemoxK49J5osSwp9Ge0h9YTKyYGbOm53hcno,13007
-ml_tools/keys.py,sha256=
+ml_tools/keys.py,sha256=HtPG8-MWh89C32A7eIlfuuA-DLwkxGkoDfwR2TGN9CQ,1074
 ml_tools/optimization_tools.py,sha256=MuT4OG7_r1QqLUti-yYix7QeCpglezD0oe9BDCq0QXk,5086
 ml_tools/path_manager.py,sha256=Z8e7w3MPqQaN8xmTnKuXZS6CIW59BFwwqGhGc00sdp4,13692
-ml_tools/utilities.py,sha256=
-dragon_ml_toolbox-
-dragon_ml_toolbox-
-dragon_ml_toolbox-
-dragon_ml_toolbox-
+ml_tools/utilities.py,sha256=LqXXTovaHbA5AOKRk6Ru6DgAPAM0wPfYU70kUjYBryo,19231
+dragon_ml_toolbox-6.0.0.dist-info/METADATA,sha256=v7JMG994i_tGqZJmN87pWxswxJEGQTsH2m2fQ_qz0C0,6698
+dragon_ml_toolbox-6.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-6.0.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-6.0.0.dist-info/RECORD,,
ml_tools/ML_callbacks.py CHANGED

@@ -2,7 +2,7 @@ import numpy as np
 import torch
 from tqdm.auto import tqdm
 from .path_manager import make_fullpath
-from .keys import
+from .keys import PyTorchLogKeys
 from ._logger import _LOGGER
 from typing import Optional
 from ._script_info import _script_info
@@ -96,14 +96,14 @@ class TqdmProgressBar(Callback):
     def on_batch_end(self, batch, logs=None):
         self.batch_bar.update(1) # type: ignore
         if logs:
-            self.batch_bar.set_postfix(loss=f"{logs.get(
+            self.batch_bar.set_postfix(loss=f"{logs.get(PyTorchLogKeys.BATCH_LOSS, 0):.4f}") # type: ignore

     def on_epoch_end(self, epoch, logs=None):
         self.batch_bar.close() # type: ignore
         self.epoch_bar.update(1) # type: ignore
         if logs:
-            train_loss_str = f"{logs.get(
-            val_loss_str = f"{logs.get(
+            train_loss_str = f"{logs.get(PyTorchLogKeys.TRAIN_LOSS, 0):.4f}"
+            val_loss_str = f"{logs.get(PyTorchLogKeys.VAL_LOSS, 0):.4f}"
             self.epoch_bar.set_postfix_str(f"Train Loss: {train_loss_str}, Val Loss: {val_loss_str}") # type: ignore

     def on_train_end(self, logs=None):
@@ -124,7 +124,7 @@ class EarlyStopping(Callback):
         inferred from the name of the monitored quantity.
         verbose (int): Verbosity mode.
     """
-    def __init__(self, monitor: str=
+    def __init__(self, monitor: str=PyTorchLogKeys.VAL_LOSS, min_delta: float=0.0, patience: int=5, mode: Literal['auto', 'min', 'max']='auto', verbose: int=1):
         super().__init__()
         self.monitor = monitor
         self.patience = patience
@@ -201,8 +201,8 @@ class ModelCheckpoint(Callback):
         mode (str): One of {'auto', 'min', 'max'}.
         verbose (int): Verbosity mode.
     """
-    def __init__(self, save_dir: Union[str,Path], monitor: str =
-                 save_best_only: bool =
+    def __init__(self, save_dir: Union[str,Path], monitor: str = PyTorchLogKeys.VAL_LOSS,
+                 save_best_only: bool = True, mode: Literal['auto', 'min', 'max']= 'auto', verbose: int = 0):
         super().__init__()
         self.save_dir = make_fullpath(save_dir, make=True, enforce="directory")
         if not self.save_dir.is_dir():
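Every key the callbacks read now comes from a single `PyTorchLogKeys` class in `ml_tools/keys.py` (that file changed by only +2/−2 lines, consistent with renaming an existing container). The diff shows only the attribute names in use; the sketch below is a guess at the container itself, with string values inferred from the `'train_loss'`/`'val_loss'` keys that `plot_losses` in ML_evaluation.py reads:

```python
# Hypothetical sketch of PyTorchLogKeys in ml_tools/keys.py.
# Attribute names are taken from this diff; the string values
# and class layout are assumptions.
class PyTorchLogKeys:
    TRAIN_LOSS = "train_loss"    # epoch-level training loss
    VAL_LOSS = "val_loss"        # monitored by EarlyStopping and ModelCheckpoint
    BATCH_LOSS = "batch_loss"    # shown on the tqdm batch bar
    BATCH_INDEX = "batch_index"
    BATCH_SIZE = "batch_size"
```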
ml_tools/ML_evaluation.py CHANGED

@@ -1,6 +1,8 @@
 import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn.calibration import CalibrationDisplay
 from sklearn.metrics import (
     classification_report,
     ConfusionMatrixDisplay,
@@ -9,7 +11,9 @@ from sklearn.metrics import (
     mean_squared_error,
     mean_absolute_error,
     r2_score,
-    median_absolute_error
+    median_absolute_error,
+    precision_recall_curve,
+    average_precision_score
 )
 import torch
 import shap
@@ -28,13 +32,13 @@ __all__ = [
 ]


-def plot_losses(history: dict, save_dir: Optional[Union[str, Path]] = None):
+def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
     Plots training & validation loss curves from a history object.

     Args:
         history (dict): A dictionary containing 'train_loss' and 'val_loss'.
-        save_dir (str | Path
+        save_dir (str | Path): Directory to save the plot image.
     """
     train_loss = history.get('train_loss', [])
     val_loss = history.get('val_loss', [])
@@ -62,86 +66,123 @@ def plot_losses(history: dict, save_dir: Optional[Union[str, Path]] = None):
     ax.grid(True)
     plt.tight_layout()

-… (5 removed lines not shown in source view)
-    else:
-        plt.show()
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    save_path = save_dir_path / "loss_plot.svg"
+    plt.savefig(save_path)
+    _LOGGER.info(f"📉 Loss plot saved as '{save_path.name}'")
+
     plt.close(fig)


-def classification_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optional[np.ndarray] = None,
-                           cmap: str = "Blues"
+def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optional[np.ndarray] = None,
+                           cmap: str = "Blues"):
     """
-… (1 removed line not shown in source view)
+    Saves classification metrics and plots.

     Args:
         y_true (np.ndarray): Ground truth labels.
         y_pred (np.ndarray): Predicted labels.
         y_prob (np.ndarray, optional): Predicted probabilities for ROC curve.
         cmap (str): Colormap for the confusion matrix.
-        save_dir (str | Path
+        save_dir (str | Path): Directory to save plots.
     """
     print("--- Classification Report ---")
-… (2 removed lines not shown in source view)
+    # Generate report as both text and dictionary
+    report_text: str = classification_report(y_true, y_pred) # type: ignore
+    report_dict: dict = classification_report(y_true, y_pred, output_dict=True) # type: ignore
+    print(report_text)

-… (6 removed lines not shown in source view)
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    # Save text report
+    report_path = save_dir_path / "classification_report.txt"
+    report_path.write_text(report_text, encoding="utf-8")
+    _LOGGER.info(f"📝 Classification report saved as '{report_path.name}'")
+
+    # --- Save Classification Report Heatmap ---
+    try:
+        plt.figure(figsize=(8, 6), dpi=100)
+        sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T, annot=True, cmap='viridis', fmt='.2f')
+        plt.title("Classification Report")
+        plt.tight_layout()
+        heatmap_path = save_dir_path / "classification_report_heatmap.svg"
+        plt.savefig(heatmap_path)
+        _LOGGER.info(f"📊 Report heatmap saved as '{heatmap_path.name}'")
+        plt.close()
+    except Exception as e:
+        _LOGGER.error(f"❌ Could not generate classification report heatmap: {e}")

-… (8 removed lines not shown in source view)
+    # Save Confusion Matrix
+    fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
+    ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap, ax=ax_cm)
+    ax_cm.set_title("Confusion Matrix")
+    cm_path = save_dir_path / "confusion_matrix.svg"
+    plt.savefig(cm_path)
+    _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
+    plt.close(fig_cm)

-… (19 removed lines not shown in source view)
-    plt.
-… (8 removed lines not shown in source view)
+    # Plotting logic for ROC and PR Curves
+    if y_prob is not None and y_prob.ndim > 1 and y_prob.shape[1] >= 2:
+        # Use probabilities of the positive class
+        y_score = y_prob[:, 1]
+
+        # --- Save ROC Curve ---
+        fpr, tpr, _ = roc_curve(y_true, y_score)
+        auc = roc_auc_score(y_true, y_score)
+        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
+        ax_roc.plot([0, 1], [0, 1], 'k--')
+        ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
+        ax_roc.set_xlabel('False Positive Rate')
+        ax_roc.set_ylabel('True Positive Rate')
+        ax_roc.legend(loc='lower right')
+        ax_roc.grid(True)
+        roc_path = save_dir_path / "roc_curve.svg"
+        plt.savefig(roc_path)
+        _LOGGER.info(f"📈 ROC curve saved as '{roc_path.name}'")
+        plt.close(fig_roc)
+
+        # --- Save Precision-Recall Curve ---
+        precision, recall, _ = precision_recall_curve(y_true, y_score)
+        ap_score = average_precision_score(y_true, y_score)
+        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
+        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
+        ax_pr.set_title('Precision-Recall Curve')
+        ax_pr.set_xlabel('Recall')
+        ax_pr.set_ylabel('Precision')
+        ax_pr.legend(loc='lower left')
+        ax_pr.grid(True)
+        pr_path = save_dir_path / "pr_curve.svg"
+        plt.savefig(pr_path)
+        _LOGGER.info(f"📈 PR curve saved as '{pr_path.name}'")
+        plt.close(fig_pr)
+
+        # --- Save Calibration Plot ---
+        if y_prob.ndim > 1 and y_prob.shape[1] >= 2:
+            y_score = y_prob[:, 1] # Use probabilities of the positive class
+
+            fig_cal, ax_cal = plt.subplots(figsize=(8, 8), dpi=100)
+            CalibrationDisplay.from_predictions(y_true, y_score, n_bins=15, ax=ax_cal)
+
+            ax_cal.set_title('Calibration Plot (Reliability Curve)')
+            ax_cal.set_xlabel('Mean Predicted Probability')
+            ax_cal.set_ylabel('Fraction of Positives')
+            ax_cal.grid(True)
+            plt.tight_layout()
+
+            cal_path = save_dir_path / "calibration_plot.svg"
+            plt.savefig(cal_path)
+            _LOGGER.info(f"✅ Calibration plot saved as '{cal_path.name}'")
+            plt.close(fig_cal)


-def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Optional[Union[str, Path]] = None):
+def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[str, Path]):
     """
-… (1 removed line not shown in source view)
+    Saves regression metrics and plots.

     Args:
         y_true (np.ndarray): Ground truth values.
         y_pred (np.ndarray): Predicted values.
-        save_dir (str |
+        save_dir (str | Path): Directory to save plots and report.
     """
     rmse = np.sqrt(mean_squared_error(y_true, y_pred))
     mae = mean_absolute_error(y_true, y_pred)
@@ -158,44 +199,56 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Optiona
     report_string = "\n".join(report_lines)
     print(report_string)

-… (5 removed lines not shown in source view)
-    _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    # Save text report
+    report_path = save_dir_path / "regression_report.txt"
+    report_path.write_text(report_string)
+    _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")

-… (14 removed lines not shown in source view)
+    # Save residual plot
+    residuals = y_true - y_pred
+    fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
+    ax_res.scatter(y_pred, residuals, alpha=0.6)
+    ax_res.axhline(0, color='red', linestyle='--')
+    ax_res.set_xlabel("Predicted Values")
+    ax_res.set_ylabel("Residuals")
+    ax_res.set_title("Residual Plot")
+    ax_res.grid(True)
+    plt.tight_layout()
+    res_path = save_dir_path / "residual_plot.svg"
+    plt.savefig(res_path)
+    _LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
+    plt.close(fig_res)

-… (13 removed lines not shown in source view)
+    # Save true vs predicted plot
+    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
+    ax_tvp.scatter(y_true, y_pred, alpha=0.6)
+    ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'k--', lw=2)
+    ax_tvp.set_xlabel('True Values')
+    ax_tvp.set_ylabel('Predictions')
+    ax_tvp.set_title('True vs. Predicted Values')
+    ax_tvp.grid(True)
+    plt.tight_layout()
+    tvp_path = save_dir_path / "true_vs_predicted_plot.svg"
+    plt.savefig(tvp_path)
+    _LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
+    plt.close(fig_tvp)
+
+    # Save Histogram of Residuals
+    fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=100)
+    sns.histplot(residuals, kde=True, ax=ax_hist)
+    ax_hist.set_xlabel("Residual Value")
+    ax_hist.set_ylabel("Frequency")
+    ax_hist.set_title("Distribution of Residuals")
+    ax_hist.grid(True)
+    plt.tight_layout()
+    hist_path = save_dir_path / "residuals_histogram.svg"
+    plt.savefig(hist_path)
+    _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
+    plt.close(fig_hist)


-def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain: torch.Tensor,
+def shap_summary_plot(model, background_data: Union[torch.Tensor,np.ndarray], instances_to_explain: Union[torch.Tensor,np.ndarray],
                       feature_names: Optional[list[str]]=None, save_dir: Optional[Union[str, Path]] = None):
     """
     Calculates SHAP values and saves summary plots and data.
@@ -207,24 +260,54 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain
         feature_names (list of str | None): Names of the features for plot labeling.
         save_dir (str | Path | None): Directory to save SHAP artifacts. If None, dot plot is shown.
     """
+    # everything to numpy
+    if isinstance(background_data, np.ndarray):
+        background_data_np = background_data
+    else:
+        background_data_np = background_data.numpy()
+
+    if isinstance(instances_to_explain, np.ndarray):
+        instances_to_explain_np = instances_to_explain
+    else:
+        instances_to_explain_np = instances_to_explain.numpy()
+
+    # --- Data Validation Step ---
+    if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
+        _LOGGER.error("❌ Input data for SHAP contains NaN values. Aborting explanation.")
+        return
+
     print("\n--- SHAP Value Explanation ---")
-    print("Calculating SHAP values... ")

     model.eval()
     model.cpu()

-… (6 removed lines not shown in source view)
+    # 1. Summarize the background data.
+    # Summarize the background data using k-means. 10-50 clusters is a good starting point.
+    background_summary = shap.kmeans(background_data_np, 30)
+
+    # 2. Define a prediction function wrapper that SHAP can use. It must take a numpy array and return a numpy array.
+    def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
+        # Convert numpy data to torch tensor
+        x_torch = torch.from_numpy(x_np).float()
+        with torch.no_grad():
+            # Get model output
+            output = model(x_torch)
+        # Return as numpy array
+        return output.cpu().numpy().flatten()

+    # 3. Create the KernelExplainer
+    explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
+
+    print("Calculating SHAP values with KernelExplainer...")
+    shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
+
     if save_dir:
         save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+        plt.ioff()
+
         # Save Bar Plot
         bar_path = save_dir_path / "shap_bar_plot.svg"
-        shap.summary_plot(
+        shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
         plt.title("SHAP Feature Importance")
         plt.tight_layout()
         plt.savefig(bar_path)
@@ -233,7 +316,7 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain

         # Save Dot Plot
         dot_path = save_dir_path / "shap_dot_plot.svg"
-        shap.summary_plot(
+        shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
         plt.title("SHAP Feature Importance")
         plt.tight_layout()
         plt.savefig(dot_path)
@@ -242,18 +325,25 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain

         # Save Summary Data to CSV
         summary_path = save_dir_path / "shap_summary.csv"
-… (1 removed line not shown in source view)
+        # Ensure the array is 1D before creating the DataFrame
+        mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
+
         if feature_names is None:
             feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
+
         summary_df = pd.DataFrame({
             'feature': feature_names,
             'mean_abs_shap_value': mean_abs_shap
         }).sort_values('mean_abs_shap_value', ascending=False)
+
         summary_df.to_csv(summary_path, index=False)
+
         _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
+        plt.ion()
+
     else:
         _LOGGER.info("No save directory provided. Displaying SHAP dot plot.")
-        shap.summary_plot(
+        shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot")


 def info():
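Two of the changes above are breaking for callers: `classification_metrics` now takes `save_dir` as its first positional argument, and both it and `regression_metrics`/`plot_losses` require a save directory instead of optionally calling `plt.show()`. A minimal usage sketch against the new signatures; the arrays and directory names here are placeholders:

```python
import numpy as np
from ml_tools.ML_evaluation import classification_metrics, regression_metrics

y_true = np.array([0, 1, 1, 0, 1])
y_pred = np.array([0, 1, 0, 0, 1])
# One column per class; the positive class is column 1, which the
# ROC/PR/calibration plots slice out as y_prob[:, 1].
y_prob = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.7, 0.3], [0.1, 0.9]])

# save_dir moved to the first, required position.
classification_metrics("reports/classification", y_true, y_pred, y_prob)

# regression_metrics keeps its argument order, but save_dir is now required.
regression_metrics(np.array([1.0, 2.0, 3.0]), np.array([1.1, 1.9, 3.2]), "reports/regression")
```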
ml_tools/ML_trainer.py CHANGED

@@ -8,16 +8,16 @@ import numpy as np
 from .ML_callbacks import Callback, History, TqdmProgressBar
 from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot
 from ._script_info import _script_info
-from .keys import
+from .keys import PyTorchLogKeys
 from ._logger import _LOGGER


 __all__ = [
-    "MyTrainer"
+    "MLTrainer"
 ]


-class MyTrainer:
+class MLTrainer:
     def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
                  kind: Literal["regression", "classification"],
                  criterion: nn.Module, optimizer: torch.optim.Optimizer,
@@ -95,14 +95,16 @@ class MyTrainer:
             batch_size=batch_size,
             shuffle=shuffle,
             num_workers=loader_workers,
-            pin_memory=(self.device.type
+            pin_memory=("cuda" in self.device.type),
+            drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
         )
+
         self.test_loader = DataLoader(
             dataset=self.test_dataset,
             batch_size=batch_size,
             shuffle=False,
             num_workers=loader_workers,
-            pin_memory=(self.device.type
+            pin_memory=("cuda" in self.device.type)
         )

     def fit(self, epochs: int = 10, batch_size: int = 10, shuffle: bool = True):
@@ -159,8 +161,8 @@ class MyTrainer:
         for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
             # Create a log dictionary for the batch
             batch_logs = {
-… (2 removed lines not shown in source view)
+                PyTorchLogKeys.BATCH_INDEX: batch_idx,
+                PyTorchLogKeys.BATCH_SIZE: features.size(0)
             }
             self.callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)

@@ -178,11 +180,11 @@ class MyTrainer:
             running_loss += batch_loss * features.size(0)

             # Add the batch loss to the logs and call the end-of-batch hook
-            batch_logs[
+            batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
             self.callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)

         # Return the average loss for the entire epoch
-        return {
+        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore

     def _validation_step(self):
         self.model.eval()
@@ -195,7 +197,7 @@
                 output = output.view_as(target)
                 loss = self.criterion(output, target)
                 running_loss += loss.item() * features.size(0)
-        logs = {
+        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)} # type: ignore
         return logs

     def _predict_for_eval(self, dataloader: DataLoader):
@@ -230,14 +232,14 @@ class MyTrainer:

                 yield y_pred_batch, y_prob_batch, y_true_batch

-    def evaluate(self, save_dir:
+    def evaluate(self, save_dir: Union[str,Path], data: Optional[Union[DataLoader, Dataset]] = None):
         """
         Evaluates the model on the given data.

         Args:
             data (DataLoader | Dataset | None ): The data to evaluate on.
                 Can be a DataLoader or a Dataset. If None, defaults to the trainer's internal test_dataset.
-            save_dir (str | Path
+            save_dir (str | Path): Directory to save all reports and plots.
         """
         eval_loader = None
         if isinstance(data, DataLoader):
@@ -273,14 +275,14 @@ class MyTrainer:
         y_prob = np.concatenate(all_probs) if self.kind == "classification" else None

         if self.kind == "classification":
-            classification_metrics(y_true, y_pred, y_prob
+            classification_metrics(save_dir, y_true, y_pred, y_prob)
         else:
-            regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir
+            regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir)

         print("\n--- Training History ---")
         plot_losses(self.history, save_dir=save_dir)

-    def explain(self, explain_dataset: Optional[Dataset] = None, n_samples: int =
+    def explain(self, explain_dataset: Optional[Dataset] = None, n_samples: int = 1000,
                 feature_names: Optional[List[str]] = None, save_dir: Optional[Union[str,Path]] = None):
         """
         Explains model predictions using SHAP and saves all artifacts.
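In short, migrating a 5.3.0 training script means importing `MLTrainer` under its new name and passing a save directory to `evaluate()`. A sketch under the new API; it assumes the constructor arguments visible in this diff are the only required ones and uses stand-in data:

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset
from ml_tools.ML_trainer import MLTrainer  # renamed from MyTrainer in 5.3.0

X, y = torch.randn(64, 4), torch.randn(64, 1)
model = nn.Linear(4, 1)

trainer = MLTrainer(
    model=model,
    train_dataset=TensorDataset(X, y),
    test_dataset=TensorDataset(X, y),
    kind="regression",
    criterion=nn.MSELoss(),
    optimizer=torch.optim.Adam(model.parameters()),
)
trainer.fit(epochs=5, batch_size=16)       # the train loader now sets drop_last=True
trainer.evaluate(save_dir="reports/eval")  # save_dir is now required
```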
|