dragon-ml-toolbox 2.4.0__py3-none-any.whl → 3.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.1.0.dist-info}/METADATA +7 -4
- dragon_ml_toolbox-3.1.0.dist-info/RECORD +25 -0
- ml_tools/ETL_engineering.py +49 -19
- ml_tools/GUI_tools.py +24 -25
- ml_tools/MICE_imputation.py +8 -4
- ml_tools/ML_callbacks.py +341 -0
- ml_tools/ML_evaluation.py +255 -0
- ml_tools/ML_trainer.py +344 -0
- ml_tools/ML_tutorial.py +300 -0
- ml_tools/PSO_optimization.py +27 -20
- ml_tools/RNN_forecast.py +49 -0
- ml_tools/VIF_factor.py +6 -5
- ml_tools/data_exploration.py +2 -2
- ml_tools/datasetmaster.py +601 -527
- ml_tools/ensemble_learning.py +12 -9
- ml_tools/handle_excel.py +9 -10
- ml_tools/logger.py +45 -8
- ml_tools/utilities.py +18 -1
- dragon_ml_toolbox-2.4.0.dist-info/RECORD +0 -22
- ml_tools/trainer.py +0 -346
- ml_tools/vision_helpers.py +0 -231
- {dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.1.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.1.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.1.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.1.0.dist-info}/top_level.txt +0 -0
- /ml_tools/{pytorch_models.py → _pytorch_models.py} +0 -0
ml_tools/ML_callbacks.py
ADDED
@@ -0,0 +1,341 @@
import numpy as np
import torch
from tqdm.auto import tqdm
from .utilities import make_fullpath, LogKeys
from .logger import _LOGGER
from typing import Optional


__all__ = [
    "Callback",
    "History",
    "TqdmProgressBar",
    "EarlyStopping",
    "ModelCheckpoint",
    "LRScheduler"
]


class Callback:
    """
    Abstract base class used to build new callbacks.

    The methods of this class are automatically called by the Trainer at different
    points during training. Subclasses can override these methods to implement
    custom logic.
    """
    def __init__(self):
        self.trainer = None

    def set_trainer(self, trainer):
        """This is called by the Trainer to associate itself with the callback."""
        self.trainer = trainer

    def on_train_begin(self, logs=None):
        """Called at the beginning of training."""
        pass

    def on_train_end(self, logs=None):
        """Called at the end of training."""
        pass

    def on_epoch_begin(self, epoch, logs=None):
        """Called at the beginning of an epoch."""
        pass

    def on_epoch_end(self, epoch, logs=None):
        """Called at the end of an epoch."""
        pass

    def on_batch_begin(self, batch, logs=None):
        """Called at the beginning of a training batch."""
        pass

    def on_batch_end(self, batch, logs=None):
        """Called at the end of a training batch."""
        pass


class History(Callback):
    """
    Callback that records events into a `history` dictionary.

    This callback is automatically applied to every MyTrainer model.
    The `history` attribute is a dictionary mapping metric names (e.g., 'val_loss')
    to a list of metric values.
    """
    def on_train_begin(self, logs=None):
        # Clear history at the beginning of training
        self.trainer.history = {} # type: ignore

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        for k, v in logs.items():
            # Append new log values to the history dictionary
            self.trainer.history.setdefault(k, []).append(v) # type: ignore


class TqdmProgressBar(Callback):
    """Callback that provides a tqdm progress bar for training."""
    def __init__(self):
        self.epoch_bar = None
        self.batch_bar = None

    def on_train_begin(self, logs=None):
        self.epochs = self.trainer.epochs # type: ignore
        self.epoch_bar = tqdm(total=self.epochs, desc="Training Progress")

    def on_epoch_begin(self, epoch, logs=None):
        total_batches = len(self.trainer.train_loader) # type: ignore
        self.batch_bar = tqdm(total=total_batches, desc=f"Epoch {epoch}/{self.epochs}", leave=False)

    def on_batch_end(self, batch, logs=None):
        self.batch_bar.update(1) # type: ignore
        if logs:
            self.batch_bar.set_postfix(loss=f"{logs.get(LogKeys.BATCH_LOSS, 0):.4f}") # type: ignore

    def on_epoch_end(self, epoch, logs=None):
        self.batch_bar.close() # type: ignore
        self.epoch_bar.update(1) # type: ignore
        if logs:
            train_loss_str = f"{logs.get(LogKeys.TRAIN_LOSS, 0):.4f}"
            val_loss_str = f"{logs.get(LogKeys.VAL_LOSS, 0):.4f}"
            self.epoch_bar.set_postfix_str(f"Train Loss: {train_loss_str}, Val Loss: {val_loss_str}") # type: ignore

    def on_train_end(self, logs=None):
        self.epoch_bar.close() # type: ignore


class EarlyStopping(Callback):
    """
    Stop training when a monitored metric has stopped improving.

    Args:
        monitor (str): Quantity to be monitored. Defaults to 'val_loss'.
        min_delta (float): Minimum change in the monitored quantity to qualify as an improvement.
        patience (int): Number of epochs with no improvement after which training will be stopped.
        mode (str): One of {'auto', 'min', 'max'}. In 'min' mode, training will stop when the quantity
            monitored has stopped decreasing; in 'max' mode it will stop when the quantity
            monitored has stopped increasing; in 'auto' mode, the direction is automatically
            inferred from the name of the monitored quantity.
        verbose (int): Verbosity mode.
    """
    def __init__(self, monitor: str=LogKeys.VAL_LOSS, min_delta=0.0, patience=3, mode='auto', verbose=1):
        super().__init__()
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_epoch = 0
        self.verbose = verbose

        if mode not in ['auto', 'min', 'max']:
            raise ValueError(f"EarlyStopping mode {mode} is unknown, choose one of ('auto', 'min', 'max')")
        self.mode = mode

        # Determine the comparison operator based on the mode
        if self.mode == 'min':
            self.monitor_op = np.less
        elif self.mode == 'max':
            self.monitor_op = np.greater
        else:  # auto mode
            if 'acc' in self.monitor.lower():
                self.monitor_op = np.greater
            else:  # Default to min mode for loss or other metrics
                self.monitor_op = np.less

        self.best = np.inf if self.monitor_op == np.less else -np.inf

    def on_train_begin(self, logs=None):
        # Reset state at the beginning of training
        self.wait = 0
        self.stopped_epoch = 0
        self.best = np.inf if self.monitor_op == np.less else -np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get(self.monitor) # type: ignore
        if current is None:
            return

        # Determine the comparison threshold based on the mode
        if self.monitor_op == np.less:
            # In 'min' mode the metric must be smaller than 'best' by at least 'min_delta':
            # current < self.best - self.min_delta
            is_improvement = self.monitor_op(current, self.best - self.min_delta)
        else:
            # In 'max' mode the metric must be greater than 'best' by at least 'min_delta':
            # current > self.best + self.min_delta
            is_improvement = self.monitor_op(current, self.best + self.min_delta)

        if is_improvement:
            if self.verbose > 1:
                _LOGGER.info(f"EarlyStopping: {self.monitor} improved from {self.best:.4f} to {current:.4f}")
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.trainer.stop_training = True # type: ignore
                if self.verbose > 0:
                    print("")
                    _LOGGER.info(f"Epoch {epoch+1}: early stopping after {self.wait} epochs with no improvement.")


class ModelCheckpoint(Callback):
    """
    Saves the model to a directory with automated filename generation and rotation.
    The filename includes the epoch and score.

    - If `save_best_only` is True, it saves the single best model, deleting the
      previous best.
    - If `save_best_only` is False, it keeps the 3 most recent checkpoints,
      deleting the oldest ones automatically.

    Args:
        save_dir (str): Directory where checkpoint files will be saved.
        monitor (str): Metric to monitor for `save_best_only=True`.
        save_best_only (bool): If true, save only the best model.
        mode (str): One of {'auto', 'min', 'max'}.
        verbose (int): Verbosity mode.
    """
    def __init__(self, save_dir: str, monitor: str = LogKeys.VAL_LOSS,
                 save_best_only: bool = False, mode: str = 'auto', verbose: int = 1):
        super().__init__()
        self.save_dir = make_fullpath(save_dir, make=True)
        if not self.save_dir.is_dir():
            _LOGGER.error(f"{save_dir} is not a valid directory.")
            raise IOError()

        self.monitor = monitor
        self.save_best_only = save_best_only
        self.verbose = verbose

        # State variables to be managed during training
        self.saved_checkpoints = []
        self.last_best_filepath = None

        if mode not in ['auto', 'min', 'max']:
            raise ValueError(f"ModelCheckpoint mode {mode} is unknown.")
        self.mode = mode

        if self.mode == 'min':
            self.monitor_op = np.less
        elif self.mode == 'max':
            self.monitor_op = np.greater
        else:
            self.monitor_op = np.less if 'loss' in self.monitor else np.greater

        self.best = np.inf if self.monitor_op == np.less else -np.inf

    def on_train_begin(self, logs=None):
        """Reset state when training starts."""
        self.best = np.inf if self.monitor_op == np.less else -np.inf
        self.saved_checkpoints = []
        self.last_best_filepath = None

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.save_dir.mkdir(parents=True, exist_ok=True)

        if self.save_best_only:
            self._save_best_model(epoch, logs)
        else:
            self._save_rolling_checkpoints(epoch, logs)

    def _save_best_model(self, epoch, logs):
        """Saves a single best model and deletes the previous one."""
        current = logs.get(self.monitor)
        if current is None:
            return

        if self.monitor_op(current, self.best):
            old_best_str = f"{self.best:.4f}" if self.best not in [np.inf, -np.inf] else "inf"

            # Create a descriptive filename
            filename = f"epoch_{epoch}-{self.monitor}_{current:.4f}.pth"
            new_filepath = self.save_dir / filename

            if self.verbose > 0:
                print("")
                _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current:.4f}, saving model to {new_filepath}")

            # Save the new best model
            torch.save(self.trainer.model.state_dict(), new_filepath) # type: ignore

            # Delete the old best model file
            if self.last_best_filepath and self.last_best_filepath.exists():
                self.last_best_filepath.unlink()

            # Update state
            self.best = current
            self.last_best_filepath = new_filepath

    def _save_rolling_checkpoints(self, epoch, logs):
        """Saves the latest model and keeps only the 3 most recent checkpoints."""
        filename = f"epoch_{epoch}.pth"
        filepath = self.save_dir / filename

        if self.verbose > 0:
            print("")
            _LOGGER.info(f'Epoch {epoch}: saving model to {filepath}')
        torch.save(self.trainer.model.state_dict(), filepath) # type: ignore

        self.saved_checkpoints.append(filepath)

        # If we have more than 3 checkpoints, remove the oldest one
        if len(self.saved_checkpoints) > 3:
            file_to_delete = self.saved_checkpoints.pop(0)
            if file_to_delete.exists():
                if self.verbose > 0:
                    _LOGGER.info(f" -> Deleting old checkpoint: {file_to_delete.name}")
                file_to_delete.unlink()


class LRScheduler(Callback):
    """
    Callback to manage a PyTorch learning rate scheduler.

    This callback automatically calls the scheduler's `step()` method at the
    end of each epoch. It also logs a message when the learning rate changes.

    Args:
        scheduler: An initialized PyTorch learning rate scheduler.
        monitor (str, optional): The metric to monitor for schedulers that
            require it, like `ReduceLROnPlateau`.
            Should match a key in the logs (e.g., 'val_loss').
    """
    def __init__(self, scheduler, monitor: Optional[str] = None):
        super().__init__()
        self.scheduler = scheduler
        self.monitor = monitor
        self.previous_lr = None

    def on_train_begin(self, logs=None):
        """Store the initial learning rate."""
        self.previous_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore

    def on_epoch_end(self, epoch, logs=None):
        """Step the scheduler and log any change in learning rate."""
        # For schedulers that need a metric (e.g., val_loss)
        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            if self.monitor is None:
                raise ValueError("LRScheduler needs a `monitor` metric for ReduceLROnPlateau.")

            metric_val = logs.get(self.monitor) # type: ignore
            if metric_val is not None:
                self.scheduler.step(metric_val)
            else:
                print("")
                _LOGGER.warning(f"LRScheduler could not find metric '{self.monitor}' in logs.")

        # For all other schedulers
        else:
            self.scheduler.step()

        # Log the change if the LR was updated
        current_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
        if current_lr != self.previous_lr:
            print("")
            _LOGGER.info(f"Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
            self.previous_lr = current_lr
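
For orientation, a minimal sketch of how these callbacks could be combined. The callback constructors match the code above; the trainer class, its constructor, and its fit() signature are assumed from the ml_tools/ML_trainer.py filename in this release and are not shown in this diff.

import torch
from torch import nn
from ml_tools.ML_callbacks import EarlyStopping, ModelCheckpoint, LRScheduler

# Tiny stand-in model and optimizer so the snippet is self-contained.
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2)

callbacks = [
    # Stop once 'val_loss' has not improved for 5 consecutive epochs.
    EarlyStopping(monitor='val_loss', patience=5, mode='min'),
    # Keep only the single best checkpoint, named epoch_<n>-val_loss_<score>.pth.
    ModelCheckpoint(save_dir='checkpoints', monitor='val_loss', save_best_only=True),
    # ReduceLROnPlateau requires a monitored metric; other schedulers do not.
    LRScheduler(scheduler, monitor='val_loss'),
]

# Hypothetical trainer wiring (the real trainer API from ml_tools/ML_trainer.py may differ):
# trainer = MLTrainer(model=model, train_loader=..., validation_loader=...,
#                     optimizer=optimizer, callbacks=callbacks)
# trainer.fit(epochs=50)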
ml_tools/ML_evaluation.py
ADDED
@@ -0,0 +1,255 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import (
    classification_report,
    ConfusionMatrixDisplay,
    roc_curve,
    roc_auc_score,
    mean_squared_error,
    mean_absolute_error,
    r2_score,
    median_absolute_error
)
import torch
import shap
from pathlib import Path
from .utilities import make_fullpath
from .logger import _LOGGER
from typing import Union, Optional


__all__ = [
    "plot_losses",
    "classification_metrics",
    "regression_metrics",
    "shap_summary_plot"
]


def plot_losses(history: dict, save_dir: Optional[Union[str, Path]] = None):
    """
    Plots training & validation loss curves from a history object.

    Args:
        history (dict): A dictionary containing 'train_loss' and 'val_loss'.
        save_dir (str | Path | None): Directory to save the plot image.
    """
    train_loss = history.get('train_loss', [])
    val_loss = history.get('val_loss', [])

    if not train_loss and not val_loss:
        print("Warning: Loss history is empty or incomplete. Cannot plot.")
        return

    fig, ax = plt.subplots(figsize=(10, 5), dpi=100)

    # Plot training loss only if data for it exists
    if train_loss:
        epochs = range(1, len(train_loss) + 1)
        ax.plot(epochs, train_loss, 'o-', label='Training Loss')

    # Plot validation loss only if data for it exists
    if val_loss:
        epochs = range(1, len(val_loss) + 1)
        ax.plot(epochs, val_loss, 'o-', label='Validation Loss')

    ax.set_title('Training and Validation Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)
    plt.tight_layout()

    if save_dir:
        save_dir_path = make_fullpath(save_dir, make=True)
        save_path = save_dir_path / "loss_plot.svg"
        plt.savefig(save_path)
        _LOGGER.info(f"Loss plot saved as '{save_path.name}'")
    else:
        plt.show()
    plt.close(fig)


def classification_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optional[np.ndarray] = None,
                           cmap: str = "Blues", save_dir: Optional[Union[str, Path]] = None):
    """
    Displays and optionally saves classification metrics and plots.

    Args:
        y_true (np.ndarray): Ground truth labels.
        y_pred (np.ndarray): Predicted labels.
        y_prob (np.ndarray, optional): Predicted probabilities for ROC curve.
        cmap (str): Colormap for the confusion matrix.
        save_dir (str | Path | None): Directory to save plots. If None, plots are shown, not saved.
    """
    print("--- Classification Report ---")
    report: str = classification_report(y_true, y_pred) # type: ignore
    print(report)

    if save_dir:
        save_dir_path = make_fullpath(save_dir, make=True)
        # Save text report
        report_path = save_dir_path / "classification_report.txt"
        report_path.write_text(report, encoding="utf-8")
        _LOGGER.info(f"Classification report saved as '{report_path.name}'")

        # Save Confusion Matrix
        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
        ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap, ax=ax_cm)
        ax_cm.set_title("Confusion Matrix")
        cm_path = save_dir_path / "confusion_matrix.svg"
        plt.savefig(cm_path)
        _LOGGER.info(f"Confusion matrix saved as '{cm_path.name}'")
        plt.close(fig_cm)

        # Save ROC Curve
        if y_prob is not None and y_prob.ndim > 1 and y_prob.shape[1] >= 2:
            fpr, tpr, _ = roc_curve(y_true, y_prob[:, 1])
            auc = roc_auc_score(y_true, y_prob[:, 1])
            fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
            ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
            ax_roc.plot([0, 1], [0, 1], 'k--')
            ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
            ax_roc.set_xlabel('False Positive Rate')
            ax_roc.set_ylabel('True Positive Rate')
            ax_roc.legend(loc='lower right')
            ax_roc.grid(True)
            roc_path = save_dir_path / "roc_curve.svg"
            plt.savefig(roc_path)
            _LOGGER.info(f"ROC curve saved as '{roc_path.name}'")
            plt.close(fig_roc)
    else:
        # Show plots if not saving
        ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap)
        plt.show()
        if y_prob is not None and y_prob.ndim > 1 and y_prob.shape[1] >= 2:
            fpr, tpr, _ = roc_curve(y_true, y_prob[:, 1])
            auc = roc_auc_score(y_true, y_prob[:, 1])
            plt.figure()
            plt.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
            plt.plot([0, 1], [0, 1], 'k--')
            plt.title('ROC Curve')
            plt.show()


def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Optional[Union[str, Path]] = None):
    """
    Displays regression metrics and optionally saves plots and report.

    Args:
        y_true (np.ndarray): Ground truth values.
        y_pred (np.ndarray): Predicted values.
        save_dir (str | Path | None): Directory to save plots and report.
    """
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    medae = median_absolute_error(y_true, y_pred)

    report_lines = [
        "--- Regression Report ---",
        f" Root Mean Squared Error (RMSE): {rmse:.4f}",
        f" Mean Absolute Error (MAE): {mae:.4f}",
        f" Median Absolute Error (MedAE): {medae:.4f}",
        f" Coefficient of Determination (R²): {r2:.4f}"
    ]
    report_string = "\n".join(report_lines)
    print(report_string)

    if save_dir:
        save_dir_path = make_fullpath(save_dir, make=True)
        # Save text report
        report_path = save_dir_path / "regression_report.txt"
        report_path.write_text(report_string)
        _LOGGER.info(f"Regression report saved as '{report_path.name}'")

        # Save residual plot
        residuals = y_true - y_pred
        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
        ax_res.scatter(y_pred, residuals, alpha=0.6)
        ax_res.axhline(0, color='red', linestyle='--')
        ax_res.set_xlabel("Predicted Values")
        ax_res.set_ylabel("Residuals")
        ax_res.set_title("Residual Plot")
        ax_res.grid(True)
        plt.tight_layout()
        res_path = save_dir_path / "residual_plot.svg"
        plt.savefig(res_path)
        _LOGGER.info(f"Residual plot saved as '{res_path.name}'")
        plt.close(fig_res)

        # Save true vs predicted plot
        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
        ax_tvp.scatter(y_true, y_pred, alpha=0.6)
        ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'k--', lw=2)
        ax_tvp.set_xlabel('True Values')
        ax_tvp.set_ylabel('Predictions')
        ax_tvp.set_title('True vs. Predicted Values')
        ax_tvp.grid(True)
        plt.tight_layout()
        tvp_path = save_dir_path / "true_vs_predicted_plot.svg"
        plt.savefig(tvp_path)
        _LOGGER.info(f"True vs. Predicted plot saved as '{tvp_path.name}'")
        plt.close(fig_tvp)


def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain: torch.Tensor,
                      feature_names: Optional[list[str]]=None, save_dir: Optional[Union[str, Path]] = None):
    """
    Calculates SHAP values and saves summary plots and data.

    Args:
        model (nn.Module): The trained PyTorch model.
        background_data (torch.Tensor): A sample of data for the explainer background.
        instances_to_explain (torch.Tensor): The specific data instances to explain.
        feature_names (list of str | None): Names of the features for plot labeling.
        save_dir (str | Path | None): Directory to save SHAP artifacts. If None, dot plot is shown.
    """
    print("\n--- SHAP Value Explanation ---")
    print("Calculating SHAP values... ")

    model.eval()
    model.cpu()

    explainer = shap.DeepExplainer(model, background_data)
    shap_values = explainer.shap_values(instances_to_explain)

    shap_values_for_plot = shap_values[1] if isinstance(shap_values, list) else shap_values
    if isinstance(shap_values, list):
        _LOGGER.info("Using SHAP values for the positive class (class 1) for plots.")

    if save_dir:
        save_dir_path = make_fullpath(save_dir, make=True)
        # Save Bar Plot
        bar_path = save_dir_path / "shap_bar_plot.svg"
        shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="bar", show=False)
        plt.title("SHAP Feature Importance")
        plt.tight_layout()
        plt.savefig(bar_path)
        _LOGGER.info(f"SHAP bar plot saved as '{bar_path.name}'")
        plt.close()

        # Save Dot Plot
        dot_path = save_dir_path / "shap_dot_plot.svg"
        shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="dot", show=False)
        plt.title("SHAP Feature Importance")
        plt.tight_layout()
        plt.savefig(dot_path)
        _LOGGER.info(f"SHAP dot plot saved as '{dot_path.name}'")
        plt.close()

        # Save Summary Data to CSV
        summary_path = save_dir_path / "shap_summary.csv"
        mean_abs_shap = np.abs(shap_values_for_plot).mean(axis=0)
        if feature_names is None:
            feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
        summary_df = pd.DataFrame({
            'feature': feature_names,
            'mean_abs_shap_value': mean_abs_shap
        }).sort_values('mean_abs_shap_value', ascending=False)
        summary_df.to_csv(summary_path, index=False)
        _LOGGER.info(f"SHAP summary data saved as '{summary_path.name}'")
    else:
        _LOGGER.info("No save directory provided. Displaying SHAP dot plot.")
        shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="dot")
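
A short usage sketch of the new evaluation helpers with synthetic data. The function signatures come from the module above; the history keys, the eval_artifacts output directory, and the random data are assumptions for illustration only.

import numpy as np
from ml_tools.ML_evaluation import plot_losses, classification_metrics, regression_metrics

# Loss curves, e.g. as recorded by the History callback under 'train_loss'/'val_loss'.
history = {
    "train_loss": [0.9, 0.6, 0.45, 0.4],
    "val_loss": [0.95, 0.7, 0.55, 0.52],
}
plot_losses(history, save_dir="eval_artifacts")

# Binary classification: probabilities of shape (n_samples, 2) so the ROC branch
# (which reads column 1) is exercised.
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=200)
y_prob = rng.random((200, 2))
y_prob = y_prob / y_prob.sum(axis=1, keepdims=True)
y_pred = (y_prob[:, 1] > 0.5).astype(int)
classification_metrics(y_true, y_pred, y_prob=y_prob, save_dir="eval_artifacts")

# Regression: text report plus residual and true-vs-predicted plots.
y_true_reg = rng.normal(size=200)
y_pred_reg = y_true_reg + rng.normal(scale=0.3, size=200)
regression_metrics(y_true_reg, y_pred_reg, save_dir="eval_artifacts")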