dragon-ml-toolbox 2.4.0__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.0.0.dist-info}/METADATA +7 -4
- dragon_ml_toolbox-3.0.0.dist-info/RECORD +25 -0
- ml_tools/ETL_engineering.py +8 -7
- ml_tools/GUI_tools.py +24 -25
- ml_tools/MICE_imputation.py +8 -4
- ml_tools/ML_callbacks.py +341 -0
- ml_tools/ML_evaluation.py +255 -0
- ml_tools/ML_trainer.py +344 -0
- ml_tools/ML_tutorial.py +300 -0
- ml_tools/PSO_optimization.py +27 -20
- ml_tools/RNN_forecast.py +49 -0
- ml_tools/VIF_factor.py +6 -5
- ml_tools/datasetmaster.py +601 -527
- ml_tools/ensemble_learning.py +12 -9
- ml_tools/handle_excel.py +9 -10
- ml_tools/logger.py +45 -8
- ml_tools/utilities.py +18 -1
- dragon_ml_toolbox-2.4.0.dist-info/RECORD +0 -22
- ml_tools/trainer.py +0 -346
- ml_tools/vision_helpers.py +0 -231
- {dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.0.0.dist-info}/top_level.txt +0 -0
- /ml_tools/{pytorch_models.py → _pytorch_models.py} +0 -0
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
from sklearn.metrics import (
|
|
5
|
+
classification_report,
|
|
6
|
+
ConfusionMatrixDisplay,
|
|
7
|
+
roc_curve,
|
|
8
|
+
roc_auc_score,
|
|
9
|
+
mean_squared_error,
|
|
10
|
+
mean_absolute_error,
|
|
11
|
+
r2_score,
|
|
12
|
+
median_absolute_error
|
|
13
|
+
)
|
|
14
|
+
import torch
|
|
15
|
+
import shap
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from .utilities import make_fullpath
|
|
18
|
+
from .logger import _LOGGER
|
|
19
|
+
from typing import Union, Optional
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["plot_losses", "classification_metrics", "regression_metrics", "shap_summary_plot"]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def plot_losses(history: dict, save_dir: Optional[Union[str, Path]] = None):
    """
    Plots training & validation loss curves from a history object.

    Only the series that actually contain values are drawn; each one is
    plotted against epochs numbered from 1.

    Args:
        history (dict): A dictionary containing 'train_loss' and 'val_loss'
            sequences of per-epoch loss values (either may be missing).
        save_dir (str | Path | None): Directory to save the plot image
            ("loss_plot.svg"). If None, the plot is shown instead.
    """
    train_loss = history.get('train_loss')
    val_loss = history.get('val_loss')
    train_loss = [] if train_loss is None else train_loss
    val_loss = [] if val_loss is None else val_loss

    # len() instead of truthiness so NumPy arrays are accepted as well as lists
    # (`not array` raises for arrays with more than one element).
    if len(train_loss) == 0 and len(val_loss) == 0:
        _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
        return

    fig, ax = plt.subplots(figsize=(10, 5), dpi=100)

    # Plot each series only if data for it exists.
    for values, label in ((train_loss, 'Training Loss'), (val_loss, 'Validation Loss')):
        if len(values) > 0:
            epochs = range(1, len(values) + 1)
            ax.plot(epochs, values, 'o-', label=label)

    ax.set_title('Training and Validation Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)
    fig.tight_layout()

    if save_dir:
        save_dir_path = make_fullpath(save_dir, make=True)
        save_path = save_dir_path / "loss_plot.svg"
        # Save this figure explicitly rather than whatever pyplot considers "current".
        fig.savefig(save_path)
        _LOGGER.info(f"Loss plot saved as '{save_path.name}'")
    else:
        plt.show()
    plt.close(fig)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _positive_class_probs(y_prob: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Return column 1 of a binary (n_samples, 2) probability array; None for anything else."""
    if y_prob is None:
        return None
    y_prob = np.asarray(y_prob)
    if y_prob.ndim == 2 and y_prob.shape[1] == 2:
        return y_prob[:, 1]
    return None


def _plot_roc_curve(y_true: np.ndarray, pos_prob: np.ndarray):
    """Draw a labelled ROC curve (AUC in the legend) on a new figure and return that figure."""
    fpr, tpr, _ = roc_curve(y_true, pos_prob)
    auc = roc_auc_score(y_true, pos_prob)
    fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
    ax.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
    ax.plot([0, 1], [0, 1], 'k--')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.legend(loc='lower right')
    ax.grid(True)
    return fig


def classification_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optional[np.ndarray] = None,
                           cmap: str = "Blues", save_dir: Optional[Union[str, Path]] = None):
    """
    Displays and optionally saves classification metrics and plots.

    Prints a text classification report and draws a confusion matrix. A ROC
    curve is drawn only for binary problems, i.e. when `y_prob` has exactly two
    columns; column 1 is treated as the positive class.

    Args:
        y_true (np.ndarray): Ground truth labels.
        y_pred (np.ndarray): Predicted labels.
        y_prob (np.ndarray, optional): Predicted class probabilities,
            shape (n_samples, n_classes). Used for the ROC curve.
        cmap (str): Colormap for the confusion matrix.
        save_dir (str | Path | None): Directory to save the report and plots.
            If None, plots are shown not saved.
    """
    print("--- Classification Report ---")
    report: str = classification_report(y_true, y_pred)  # type: ignore
    print(report)

    # roc_curve only supports binary targets; slicing column 1 of a multiclass
    # probability matrix would make it raise, so ROC is restricted to 2 columns.
    pos_prob = _positive_class_probs(y_prob)
    if y_prob is not None and pos_prob is None:
        _LOGGER.info("ROC curve skipped: it requires binary probabilities with exactly 2 columns.")

    if save_dir:
        save_dir_path = make_fullpath(save_dir, make=True)

        # Save text report
        report_path = save_dir_path / "classification_report.txt"
        report_path.write_text(report, encoding="utf-8")
        _LOGGER.info(f"Classification report saved as '{report_path.name}'")

        # Save Confusion Matrix
        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
        ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap, ax=ax_cm)
        ax_cm.set_title("Confusion Matrix")
        cm_path = save_dir_path / "confusion_matrix.svg"
        fig_cm.savefig(cm_path)
        _LOGGER.info(f"Confusion matrix saved as '{cm_path.name}'")
        plt.close(fig_cm)

        # Save ROC Curve
        if pos_prob is not None:
            fig_roc = _plot_roc_curve(y_true, pos_prob)
            roc_path = save_dir_path / "roc_curve.svg"
            fig_roc.savefig(roc_path)
            _LOGGER.info(f"ROC curve saved as '{roc_path.name}'")
            plt.close(fig_roc)
    else:
        # Show plots if not saving
        ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap)
        plt.show()
        if pos_prob is not None:
            fig_roc = _plot_roc_curve(y_true, pos_prob)
            plt.show()
            plt.close(fig_roc)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Optional[Union[str, Path]] = None):
    """
    Displays regression metrics and optionally saves plots and report.

    Prints RMSE, MAE, median absolute error and R². When `save_dir` is given,
    the text report, a residual plot and a true-vs-predicted plot are written
    there. No plots are produced when `save_dir` is None.

    Args:
        y_true (np.ndarray): Ground truth values.
        y_pred (np.ndarray): Predicted values.
        save_dir (str | Path | None): Directory to save plots and report.
    """
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    medae = median_absolute_error(y_true, y_pred)

    report_lines = [
        "--- Regression Report ---",
        f"  Root Mean Squared Error (RMSE): {rmse:.4f}",
        f"  Mean Absolute Error (MAE): {mae:.4f}",
        f"  Median Absolute Error (MedAE): {medae:.4f}",
        f"  Coefficient of Determination (R²): {r2:.4f}"
    ]
    report_string = "\n".join(report_lines)
    print(report_string)

    if save_dir:
        save_dir_path = make_fullpath(save_dir, make=True)

        # Save text report. Explicit UTF-8 because the report contains "R²",
        # which the platform default encoding is not guaranteed to represent.
        report_path = save_dir_path / "regression_report.txt"
        report_path.write_text(report_string, encoding="utf-8")
        _LOGGER.info(f"Regression report saved as '{report_path.name}'")

        # Flatten for plotting so (n,) vs (n, 1) inputs pair up element-wise
        # instead of broadcasting into an (n, n) residual matrix.
        true_flat = np.asarray(y_true).ravel()
        pred_flat = np.asarray(y_pred).ravel()

        # Save residual plot
        residuals = true_flat - pred_flat
        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
        ax_res.scatter(pred_flat, residuals, alpha=0.6)
        ax_res.axhline(0, color='red', linestyle='--')
        ax_res.set_xlabel("Predicted Values")
        ax_res.set_ylabel("Residuals")
        ax_res.set_title("Residual Plot")
        ax_res.grid(True)
        fig_res.tight_layout()
        res_path = save_dir_path / "residual_plot.svg"
        fig_res.savefig(res_path)
        _LOGGER.info(f"Residual plot saved as '{res_path.name}'")
        plt.close(fig_res)

        # Save true vs predicted plot
        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
        ax_tvp.scatter(true_flat, pred_flat, alpha=0.6)
        lo, hi = true_flat.min(), true_flat.max()
        ax_tvp.plot([lo, hi], [lo, hi], 'k--', lw=2)  # ideal y = x reference line
        ax_tvp.set_xlabel('True Values')
        ax_tvp.set_ylabel('Predictions')
        ax_tvp.set_title('True vs. Predicted Values')
        ax_tvp.grid(True)
        fig_tvp.tight_layout()
        tvp_path = save_dir_path / "true_vs_predicted_plot.svg"
        fig_tvp.savefig(tvp_path)
        _LOGGER.info(f"True vs. Predicted plot saved as '{tvp_path.name}'")
        plt.close(fig_tvp)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain: torch.Tensor,
                      feature_names: Optional[list[str]]=None, save_dir: Optional[Union[str, Path]] = None):
    """
    Calculates SHAP values and saves summary plots and data.

    The model is put in eval mode and moved to CPU, and `shap.DeepExplainer`
    is used. If SHAP returns one array per model output, the array for output
    index 1 is used when there are several outputs, otherwise the only array.

    Args:
        model (nn.Module): The trained PyTorch model.
        background_data (torch.Tensor): A sample of data for the explainer background.
        instances_to_explain (torch.Tensor): The specific data instances to explain.
        feature_names (list of str | None): Names of the features for plot labeling.
        save_dir (str | Path | None): Directory to save SHAP artifacts
            (bar plot, dot plot, mean-|SHAP| CSV). If None, dot plot is shown.
    """
    print("\n--- SHAP Value Explanation ---")
    print("Calculating SHAP values... ")

    model.eval()
    model.cpu()
    # Inputs must live on the same device as the (now CPU) model.
    background_data = background_data.cpu()
    instances_to_explain = instances_to_explain.cpu()

    explainer = shap.DeepExplainer(model, background_data)
    shap_values = explainer.shap_values(instances_to_explain)

    if isinstance(shap_values, list):
        if len(shap_values) > 1:
            shap_values_for_plot = shap_values[1]
            _LOGGER.info("Using SHAP values for the positive class (class 1) for plots.")
        else:
            # Single-output list: indexing [1] would raise IndexError.
            shap_values_for_plot = shap_values[0]
    else:
        shap_values_for_plot = shap_values

    # shap's plotting API documents NumPy arrays / DataFrames as `features`.
    plot_features = instances_to_explain.detach().numpy()

    if save_dir:
        save_dir_path = make_fullpath(save_dir, make=True)

        # Save Bar Plot
        bar_path = save_dir_path / "shap_bar_plot.svg"
        shap.summary_plot(shap_values_for_plot, plot_features, feature_names=feature_names, plot_type="bar", show=False)
        plt.title("SHAP Feature Importance")
        plt.tight_layout()
        plt.savefig(bar_path)
        _LOGGER.info(f"SHAP bar plot saved as '{bar_path.name}'")
        plt.close()

        # Save Dot Plot
        dot_path = save_dir_path / "shap_dot_plot.svg"
        shap.summary_plot(shap_values_for_plot, plot_features, feature_names=feature_names, plot_type="dot", show=False)
        plt.title("SHAP Feature Importance")
        plt.tight_layout()
        plt.savefig(dot_path)
        _LOGGER.info(f"SHAP dot plot saved as '{dot_path.name}'")
        plt.close()

        # Save Summary Data to CSV
        summary_path = save_dir_path / "shap_summary.csv"
        # NOTE(review): assumes shap_values_for_plot is (n_samples, n_features);
        # a 3-D per-output array would need an output index chosen first.
        mean_abs_shap = np.abs(shap_values_for_plot).mean(axis=0)
        if feature_names is None:
            feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
        summary_df = pd.DataFrame({
            'feature': feature_names,
            'mean_abs_shap_value': mean_abs_shap
        }).sort_values('mean_abs_shap_value', ascending=False)
        summary_df.to_csv(summary_path, index=False)
        _LOGGER.info(f"SHAP summary data saved as '{summary_path.name}'")
    else:
        _LOGGER.info("No save directory provided. Displaying SHAP dot plot.")
        shap.summary_plot(shap_values_for_plot, plot_features, feature_names=feature_names, plot_type="dot")
|
ml_tools/ML_trainer.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
from typing import List, Literal, Union, Optional
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from torch.utils.data import DataLoader, Dataset
|
|
4
|
+
import torch
|
|
5
|
+
from torch import nn
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from .ML_callbacks import Callback, History, TqdmProgressBar
|
|
9
|
+
from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot
|
|
10
|
+
from .utilities import _script_info, LogKeys
|
|
11
|
+
from .logger import _LOGGER
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["MyTrainer"]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class MyTrainer:
    """Drives training, evaluation and SHAP explanation of a PyTorch model."""

    def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
                 kind: Literal["regression", "classification"],
                 criterion: nn.Module, optimizer: torch.optim.Optimizer,
                 device: Union[Literal['cuda', 'mps', 'cpu'], str], dataloader_workers: int = 2,
                 callbacks: Optional[List["Callback"]] = None):
        """
        Automates the training process of a PyTorch Model.

        Built-in Callbacks: `History`, `TqdmProgressBar`

        Args:
            model (nn.Module): The PyTorch model to train.
            train_dataset (Dataset): The training dataset.
            test_dataset (Dataset): The testing/validation dataset.
            kind (str): The type of task, 'regression' or 'classification'.
            criterion (nn.Module): The loss function.
            optimizer (torch.optim.Optimizer): The optimizer.
            device (str): The device to run training on ('cpu', 'cuda', 'mps').
                Case-insensitive; falls back to CPU when CUDA/MPS is unavailable.
            dataloader_workers (int): Subprocesses for data loading. Defaults to 2
                (always 0 on MPS).
            callbacks (List[Callback] | None): Extra callbacks, run after the built-in ones.

        Raises:
            TypeError: If `kind` is not 'regression' or 'classification'.

        Note:
            For **regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`.

            For **classification** tasks, `nn.CrossEntropyLoss` (multi-class) or `nn.BCEWithLogitsLoss` (binary) are common choices.
            With `nn.BCEWithLogitsLoss`, a single-logit output is reshaped to the target's shape
            (target cast to float) for the loss, and `predict` converts it to two-column
            probabilities with a sigmoid.
        """
        if kind not in ["regression", "classification"]:
            raise TypeError("Kind must be 'regression' or 'classification'.")

        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.kind = kind
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = self._validate_device(device)
        self.dataloader_workers = dataloader_workers

        # Callback handler - History and TqdmProgressBar are added by default
        default_callbacks = [History(), TqdmProgressBar()]
        user_callbacks = callbacks if callbacks is not None else []
        self.callbacks = default_callbacks + user_callbacks
        self._set_trainer_on_callbacks()

        # Internal state
        self.train_loader = None
        self.test_loader = None
        self.history = {}
        self.epoch = 0
        self.epochs = 0  # Total epochs for the fit run
        self.stop_training = False

    def _validate_device(self, device: str) -> torch.device:
        """Validates the selected device (case-insensitive) and returns a torch.device object."""
        # torch.device() only accepts lowercase type names, so normalise once and use that.
        device_lower = device.lower()
        if "cuda" in device_lower and not torch.cuda.is_available():
            _LOGGER.warning("CUDA not available, switching to CPU.")
            device_lower = "cpu"
        elif device_lower.startswith("mps") and not torch.backends.mps.is_available():
            _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
            device_lower = "cpu"
        return torch.device(device_lower)

    def _set_trainer_on_callbacks(self):
        """Gives each callback a reference to this trainer instance."""
        for callback in self.callbacks:
            callback.set_trainer(self)

    def _make_loader(self, dataset, batch_size: int, shuffle: bool = False, sampler=None) -> DataLoader:
        """Builds a DataLoader with the trainer's device-dependent worker and pinning settings."""
        # Ensure stability on MPS devices by setting num_workers to 0
        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=loader_workers,
            pin_memory=(self.device.type == "cuda")
        )

    def _create_dataloaders(self, batch_size: int, shuffle: bool):
        """Initializes the training (optionally shuffled) and test (ordered) DataLoaders."""
        self.train_loader = self._make_loader(self.train_dataset, batch_size, shuffle=shuffle)
        self.test_loader = self._make_loader(self.test_dataset, batch_size, shuffle=False)

    def _compute_loss(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Aligns model output/target shapes where the criterion requires it, then applies the criterion."""
        if isinstance(self.criterion, (nn.MSELoss, nn.L1Loss)):
            # e.g. (N, 1) model output vs (N,) target
            output = output.view_as(target)
        elif isinstance(self.criterion, nn.BCEWithLogitsLoss):
            # BCEWithLogitsLoss needs identical shapes and a floating-point target.
            if output.numel() == target.numel():
                output = output.reshape(target.shape)
            target = target.to(output.dtype)
        return self.criterion(output, target)

    def fit(self, epochs: int = 10, batch_size: int = 32, shuffle: bool = True):
        """
        Starts the training-validation process of the model.

        Args:
            epochs (int): The total number of epochs to train for.
            batch_size (int): The number of samples per batch.
            shuffle (bool): Whether to shuffle the training data at each epoch.

        Returns:
            dict: `self.history` after training.
        """
        self.epochs = epochs
        self._create_dataloaders(batch_size, shuffle)
        self.model.to(self.device)

        # Reset stop_training flag on the trainer
        self.stop_training = False

        self.callbacks_hook('on_train_begin')

        for epoch in range(1, self.epochs + 1):
            self.epoch = epoch
            epoch_logs = {}
            self.callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)

            epoch_logs.update(self._train_step())
            epoch_logs.update(self._validation_step())

            self.callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)

            # A callback (e.g. early stopping) may have requested a stop.
            if self.stop_training:
                break

        self.callbacks_hook('on_train_end')
        return self.history

    def _train_step(self):
        """Runs one training epoch and returns the sample-weighted mean training loss."""
        self.model.train()
        running_loss = 0.0
        for batch_idx, (features, target) in enumerate(self.train_loader):  # type: ignore
            batch_logs = {
                LogKeys.BATCH_INDEX: batch_idx,
                LogKeys.BATCH_SIZE: features.size(0)
            }
            self.callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)

            features, target = features.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(features)
            loss = self._compute_loss(output, target)
            loss.backward()
            self.optimizer.step()

            # Weight by batch size so the epoch mean is exact even with a short last batch.
            batch_loss = loss.item()
            running_loss += batch_loss * features.size(0)

            batch_logs[LogKeys.BATCH_LOSS] = batch_loss
            self.callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)

        return {LogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)}  # type: ignore

    def _validation_step(self):
        """Evaluates on the test loader without gradients and returns the mean validation loss."""
        self.model.eval()
        running_loss = 0.0
        with torch.no_grad():
            for features, target in self.test_loader:  # type: ignore
                features, target = features.to(self.device), target.to(self.device)
                output = self.model(features)
                loss = self._compute_loss(output, target)
                running_loss += loss.item() * features.size(0)
        return {LogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)}  # type: ignore

    def predict(self, dataloader: DataLoader):
        """
        Yields model predictions batch by batch, avoids loading all predictions into memory at once.

        For classification, a multi-column output is converted with softmax/argmax.
        A single-logit output (binary head) is converted with a sigmoid into
        two-column probabilities [1 - p, p], predicting class 1 when p >= 0.5.

        Args:
            dataloader (DataLoader): The dataloader to predict on.

        Yields:
            tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
                y_prob_batch is None for regression tasks.
        """
        self.model.eval()
        self.model.to(self.device)
        for features, target in dataloader:
            # Scope no_grad to the forward pass only: a `with` block spanning the
            # `yield` would leave autograd disabled in the caller between batches.
            with torch.no_grad():
                output = self.model(features.to(self.device)).cpu()
            y_true_batch = target.numpy()

            if self.kind == "classification":
                if output.ndim == 1 or output.size(1) == 1:
                    # Softmax over a single column is always 1.0, so use a sigmoid.
                    pos = torch.sigmoid(output.reshape(-1))
                    probs = torch.stack([1.0 - pos, pos], dim=1)
                    preds = (pos >= 0.5).long()
                else:
                    probs = nn.functional.softmax(output, dim=1)
                    preds = torch.argmax(probs, dim=1)
                y_pred_batch = preds.numpy()
                y_prob_batch = probs.numpy()
            else:
                y_pred_batch = output.numpy()
                y_prob_batch = None

            yield y_pred_batch, y_prob_batch, y_true_batch

    def evaluate(self, data: Optional[Union[DataLoader, Dataset]] = None, save_dir: Optional[Union[str, Path]] = None):
        """
        Evaluates the model on the given data.

        Args:
            data (DataLoader | Dataset | None ): The data to evaluate on.
                Can be a DataLoader or a Dataset. If None, defaults to the trainer's internal test_dataset.
            save_dir (str | Path | None): Directory to save all reports and plots. If None, metrics are shown but not saved.

        Raises:
            ValueError: If neither a DataLoader nor a Dataset is available.
        """
        if isinstance(data, DataLoader):
            eval_loader = data
        else:
            # Determine which dataset to use (the one passed in, or the default test_dataset)
            dataset_to_use = data if data is not None else self.test_dataset
            if not isinstance(dataset_to_use, Dataset):
                raise ValueError("Cannot evaluate. No valid DataLoader or Dataset was provided, "
                                 "and no test_dataset is available in the trainer.")
            eval_loader = self._make_loader(dataset_to_use, batch_size=32, shuffle=False)  # sensible eval default

        print("\n--- Model Evaluation ---")

        # Collect results from the predict generator
        all_preds, all_probs, all_true = [], [], []
        for y_pred_b, y_prob_b, y_true_b in self.predict(eval_loader):
            all_preds.append(y_pred_b)
            if y_prob_b is not None:
                all_probs.append(y_prob_b)
            all_true.append(y_true_b)

        y_pred = np.concatenate(all_preds)
        y_true = np.concatenate(all_true)
        y_prob = np.concatenate(all_probs) if self.kind == "classification" else None

        if self.kind == "classification":
            classification_metrics(y_true, y_pred, y_prob, save_dir=save_dir)
        else:
            regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir=save_dir)

        print("\n--- Training History ---")
        plot_losses(self.history, save_dir=save_dir)

    def explain(self, explain_dataset: Optional[Dataset] = None, n_samples: int = 100,
                feature_names: Optional[List[str]] = None, save_dir: Optional[Union[str, Path]] = None):
        """
        Explains model predictions using SHAP and saves all artifacts.

        The background data is automatically sampled from the trainer's training dataset.

        Args:
            explain_dataset (Dataset, optional): A specific dataset to explain.
                If None, the trainer's test dataset is used.
            n_samples (int): The number of samples to use for both background and explanation.
            feature_names (List[str], optional): Names for the features.
            save_dir (str | Path, optional): Directory to save all SHAP artifacts.
        """
        def _get_random_sample(dataset: Dataset, num_samples: int):
            """Return up to `num_samples` randomly chosen feature rows, or None if nothing is available."""
            if dataset is None:
                return None
            total = len(dataset)  # type: ignore
            if total == 0:
                return None
            # Choose indices first so only the sampled rows are loaded,
            # instead of materialising the whole dataset and then sub-sampling.
            if num_samples >= total:
                indices = list(range(total))
            else:
                indices = torch.randperm(total)[:num_samples].tolist()
            loader = self._make_loader(dataset, batch_size=64, sampler=indices)
            all_features = [features for features, _ in loader]
            if not all_features:
                return None
            return torch.cat(all_features, dim=0)

        print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")

        # 1. Get background data from the trainer's train_dataset
        background_data = _get_random_sample(self.train_dataset, n_samples)
        if background_data is None:
            print("Warning: Trainer's train_dataset is empty or invalid. Skipping SHAP analysis.")
            return

        # 2. Determine target dataset and get explanation instances
        target_dataset = explain_dataset if explain_dataset is not None else self.test_dataset
        instances_to_explain = _get_random_sample(target_dataset, n_samples)
        if instances_to_explain is None:
            print("Warning: Explanation dataset is empty or invalid. Skipping SHAP analysis.")
            return

        # 3. Call the plotting function
        shap_summary_plot(
            model=self.model,
            background_data=background_data,
            instances_to_explain=instances_to_explain,
            feature_names=feature_names,
            save_dir=save_dir
        )

    def callbacks_hook(self, method_name: str, *args, **kwargs):
        """Calls the specified method on all callbacks, in registration order."""
        for callback in self.callbacks:
            method = getattr(callback, method_name)
            method(*args, **kwargs)
|
|
342
|
+
|
|
343
|
+
def info():
    """Hand this module's public names (``__all__``) to the shared ``_script_info`` helper."""
    exported_names = __all__
    _script_info(exported_names)
|