dragon-ml-toolbox 6.4.1__py3-none-any.whl → 8.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic. Click here for more details.
- {dragon_ml_toolbox-6.4.1.dist-info → dragon_ml_toolbox-8.0.0.dist-info}/METADATA +4 -1
- {dragon_ml_toolbox-6.4.1.dist-info → dragon_ml_toolbox-8.0.0.dist-info}/RECORD +14 -11
- ml_tools/ML_datasetmaster.py +285 -438
- ml_tools/ML_evaluation.py +119 -51
- ml_tools/ML_evaluation_multi.py +296 -0
- ml_tools/ML_inference.py +251 -31
- ml_tools/ML_models.py +468 -47
- ml_tools/ML_scaler.py +197 -0
- ml_tools/ML_trainer.py +246 -73
- ml_tools/_ML_optimization_multi.py +231 -0
- {dragon_ml_toolbox-6.4.1.dist-info → dragon_ml_toolbox-8.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-6.4.1.dist-info → dragon_ml_toolbox-8.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-6.4.1.dist-info → dragon_ml_toolbox-8.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-6.4.1.dist-info → dragon_ml_toolbox-8.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_scaler.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.utils.data import Dataset, DataLoader
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Union, List, Optional
|
|
5
|
+
from ._logger import _LOGGER
|
|
6
|
+
from ._script_info import _script_info
|
|
7
|
+
from .path_manager import make_fullpath
|
|
8
|
+
|
|
9
|
+
# Public API of this module.
__all__ = ["PytorchScaler"]
|
|
12
|
+
|
|
13
|
+
class PytorchScaler:
    """
    Standardizes continuous features in a PyTorch dataset by subtracting the
    mean and dividing by the standard deviation.

    The scaler is fitted on a training dataset and can then be saved and
    loaded for consistent transformation during inference.
    """
    def __init__(self,
                 mean: Optional[torch.Tensor] = None,
                 std: Optional[torch.Tensor] = None,
                 continuous_feature_indices: Optional[List[int]] = None):
        """
        Initializes the scaler.

        Args:
            mean (torch.Tensor, optional): The mean of the features to scale.
            std (torch.Tensor, optional): The standard deviation of the features.
            continuous_feature_indices (List[int], optional): The column indices of the features to standardize.
        """
        self.mean_ = mean
        self.std_ = std
        self.continuous_feature_indices = continuous_feature_indices

    def _is_fitted(self) -> bool:
        """Returns True when mean, std and the feature indices are all available."""
        return (self.mean_ is not None
                and self.std_ is not None
                and self.continuous_feature_indices is not None)

    @staticmethod
    def _float_copy(data: torch.Tensor) -> torch.Tensor:
        """
        Returns a copy of `data` that can hold standardized values.

        Floating-point tensors are cloned. Any other dtype (int, bool) is
        converted to the default floating dtype. Writing scaled values back
        into an integer tensor would silently truncate them.
        """
        if data.is_floating_point():
            return data.clone()
        return data.to(torch.get_default_dtype())

    @staticmethod
    def _accumulate_moments(feature_batches, continuous_feature_indices):
        """
        Streams over feature batches and merges per-batch statistics.

        Each batch's mean and sum of squared deviations (M2) are merged into
        the running totals with the pairwise update of Chan, Golub & LeVeque.
        All arithmetic is done in float64 on the CPU. The shortcut
        ``E[X^2] - E[X]^2`` in float32 suffers catastrophic cancellation when
        a feature's mean is large relative to its spread.

        Args:
            feature_batches (Iterable[torch.Tensor]): Feature tensors indexed as
                ``[rows, columns]``.
            continuous_feature_indices (List[int]): Columns to accumulate.

        Returns:
            tuple: ``(count, mean, m2, device)``.
                - ``mean`` and ``m2`` are float64 CPU tensors, or ``None`` if no
                  rows were seen.
                - ``device`` is the device of the first batch, or ``None`` if
                  there were no batches.
        """
        count = 0
        mean = None
        m2 = None
        device = None
        for features in feature_batches:
            if device is None:
                device = features.device
            batch = features[:, continuous_feature_indices].detach().to(device="cpu", dtype=torch.float64)
            batch_count = batch.size(0)
            if batch_count == 0:
                continue
            batch_mean = batch.mean(dim=0)
            batch_m2 = ((batch - batch_mean) ** 2).sum(dim=0)
            if mean is None:
                count, mean, m2 = batch_count, batch_mean, batch_m2
                continue
            # Pairwise merge of (count, mean, m2) with the new batch's statistics.
            total = count + batch_count
            delta = batch_mean - mean
            mean = mean + delta * (batch_count / total)
            m2 = m2 + batch_m2 + delta ** 2 * (count * batch_count / total)
            count = total
        return count, mean, m2, device

    @classmethod
    def fit(cls, dataset: Dataset, continuous_feature_indices: List[int], batch_size: int = 64) -> 'PytorchScaler':
        """
        Fits the scaler by computing the mean and std dev from a dataset in a
        single streaming pass over its batches.

        Statistics are accumulated in float64 with a numerically stable
        pairwise merge. They are then stored in the default floating dtype on
        the device of the first batch. The standard deviation is the
        population value (divides by N), with the variance clamped to at
        least 1e-8.

        Args:
            dataset (Dataset): The PyTorch Dataset to fit on. It must yield
                `(features, target)` pairs; batched features are indexed as
                `[rows, columns]`.
            continuous_feature_indices (List[int]): The column indices of the
                features to standardize.
            batch_size (int): The batch size for iterating through the dataset.

        Returns:
            PytorchScaler: A new, fitted instance of the scaler. An unfitted
            instance is returned if no indices are given or the dataset is empty.
        """
        if not continuous_feature_indices:
            _LOGGER.warning("⚠️ No continuous feature indices provided. Scaler will not be fitted.")
            return cls()

        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        num_continuous_features = len(continuous_feature_indices)

        count, mean64, m2, device = cls._accumulate_moments(
            (features for features, _ in loader), continuous_feature_indices)

        if count == 0:
            _LOGGER.warning("⚠️ Dataset is empty. Scaler cannot be fitted.")
            return cls(continuous_feature_indices=continuous_feature_indices)

        # Store statistics in the default dtype on the device the data came from.
        out_dtype = torch.get_default_dtype()
        mean = mean64.to(device=device, dtype=out_dtype)  # type: ignore

        if count < 2:
            _LOGGER.warning("⚠️ Only one sample found. Standard deviation cannot be calculated and is set to 1.")
            std = torch.ones_like(mean)
        else:
            # Population variance, clamped for numerical stability.
            var = m2 / count  # type: ignore
            std = torch.sqrt(torch.clamp(var, min=1e-8)).to(device=device, dtype=out_dtype)

        _LOGGER.info(f"Scaler fitted on {count} samples for {num_continuous_features} continuous features.")
        return cls(mean=mean, std=std, continuous_feature_indices=continuous_feature_indices)

    def transform(self, data: torch.Tensor) -> torch.Tensor:
        """
        Applies standardization to the specified continuous features.

        The feature indices address the last dimension, so both a single 1D
        sample and batched tensors are accepted. Non-floating inputs are
        converted to the default floating dtype so that scaled values are not
        truncated. The input tensor is never modified.

        Args:
            data (torch.Tensor): The input data tensor.

        Returns:
            torch.Tensor: The transformed data tensor. If the scaler is not
            fitted, `data` itself is returned unchanged.
        """
        if not self._is_fitted():
            _LOGGER.warning("⚠️ Scaler has not been fitted. Returning original data.")
            return data

        result = self._float_copy(data)
        # Ensure mean and std are on the same device as the data
        mean = self.mean_.to(result.device)  # type: ignore
        std = self.std_.to(result.device)  # type: ignore
        cols = self.continuous_feature_indices

        # Epsilon on std prevents division by zero.
        result[..., cols] = (result[..., cols] - mean) / (std + 1e-8)
        return result

    def inverse_transform(self, data: torch.Tensor) -> torch.Tensor:
        """
        Applies the inverse of the standardization transformation.

        Uses the same indexing and dtype handling as `transform`, so that
        `inverse_transform(transform(x))` recovers `x` (up to floating-point
        error).

        Args:
            data (torch.Tensor): The scaled data tensor.

        Returns:
            torch.Tensor: The original-scale data tensor. If the scaler is not
            fitted, `data` itself is returned unchanged.
        """
        if not self._is_fitted():
            _LOGGER.warning("⚠️ Scaler has not been fitted. Returning original data.")
            return data

        result = self._float_copy(data)
        mean = self.mean_.to(result.device)  # type: ignore
        std = self.std_.to(result.device)  # type: ignore
        cols = self.continuous_feature_indices

        # Mirror of transform: the same epsilon is re-applied.
        result[..., cols] = result[..., cols] * (std + 1e-8) + mean
        return result

    def save(self, filepath: Union[str, Path]):
        """
        Saves the scaler's state (mean, std, indices) to a .pth file.

        Args:
            filepath (str | Path): The path to save the file.
        """
        path_obj = make_fullpath(filepath)
        state = {
            'mean': self.mean_,
            'std': self.std_,
            'continuous_feature_indices': self.continuous_feature_indices
        }
        torch.save(state, path_obj)
        _LOGGER.info(f"✅ PytorchScaler state saved to '{path_obj.name}'.")

    @staticmethod
    def load(filepath: Union[str, Path]) -> 'PytorchScaler':
        """
        Loads a scaler's state from a .pth file.

        Tensors are mapped to the CPU so that a scaler saved from a GPU session
        can be loaded on any machine. `transform` moves them to the data's
        device on use.

        Args:
            filepath (str | Path): The path to the saved scaler file.

        Returns:
            PytorchScaler: An instance of the scaler with the loaded state.
        """
        path_obj = make_fullpath(filepath, enforce="file")
        state = torch.load(path_obj, map_location="cpu")
        _LOGGER.info(f"✅ PytorchScaler state loaded from '{path_obj.name}'.")
        return PytorchScaler(
            mean=state['mean'],
            std=state['std'],
            continuous_feature_indices=state['continuous_feature_indices']
        )

    def __repr__(self) -> str:
        """Returns the developer-friendly string representation of the scaler."""
        # Uses the same fitted check as transform(), so an instance holding only indices is "not fitted".
        if self._is_fitted() and self.continuous_feature_indices:
            num_features = len(self.continuous_feature_indices)  # type: ignore
            return f"PytorchScaler(fitted for {num_features} features)"
        return "PytorchScaler(not fitted)"
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def info():
    """Reports script information for the names exported in ``__all__`` (delegates to ``_script_info``)."""
    _script_info(__all__)
|
ml_tools/ML_trainer.py
CHANGED
|
@@ -6,7 +6,8 @@ from torch import nn
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
|
|
8
8
|
from .ML_callbacks import Callback, History, TqdmProgressBar
|
|
9
|
-
from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot
|
|
9
|
+
from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
|
|
10
|
+
from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
|
|
10
11
|
from ._script_info import _script_info
|
|
11
12
|
from .keys import PyTorchLogKeys
|
|
12
13
|
from ._logger import _LOGGER
|
|
@@ -19,7 +20,7 @@ __all__ = [
|
|
|
19
20
|
|
|
20
21
|
class MLTrainer:
|
|
21
22
|
def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
|
|
22
|
-
kind: Literal["regression", "classification"],
|
|
23
|
+
kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification"],
|
|
23
24
|
criterion: nn.Module, optimizer: torch.optim.Optimizer,
|
|
24
25
|
device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
|
|
25
26
|
"""
|
|
@@ -31,20 +32,22 @@ class MLTrainer:
|
|
|
31
32
|
model (nn.Module): The PyTorch model to train.
|
|
32
33
|
train_dataset (Dataset): The training dataset.
|
|
33
34
|
test_dataset (Dataset): The testing/validation dataset.
|
|
34
|
-
kind (str):
|
|
35
|
+
kind (str): Can be 'regression', 'classification', 'multi_target_regression', or 'multi_label_classification'.
|
|
35
36
|
criterion (nn.Module): The loss function.
|
|
36
37
|
optimizer (torch.optim.Optimizer): The optimizer.
|
|
37
38
|
device (str): The device to run training on ('cpu', 'cuda', 'mps').
|
|
38
|
-
dataloader_workers (int): Subprocesses for data loading.
|
|
39
|
+
dataloader_workers (int): Subprocesses for data loading.
|
|
39
40
|
callbacks (List[Callback] | None): A list of callbacks to use during training.
|
|
40
41
|
|
|
41
42
|
Note:
|
|
42
|
-
For **regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`.
|
|
43
|
-
|
|
44
|
-
For **classification** tasks, `nn.CrossEntropyLoss`
|
|
43
|
+
- For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`.
|
|
44
|
+
|
|
45
|
+
- For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice.
|
|
46
|
+
|
|
47
|
+
- For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem.
|
|
45
48
|
"""
|
|
46
|
-
if kind not in ["regression", "classification"]:
|
|
47
|
-
raise
|
|
49
|
+
if kind not in ["regression", "classification", "multi_target_regression", "multi_label_classification"]:
|
|
50
|
+
raise ValueError(f"'{kind}' is not a valid task type.")
|
|
48
51
|
|
|
49
52
|
self.model = model
|
|
50
53
|
self.train_dataset = train_dataset
|
|
@@ -157,7 +160,6 @@ class MLTrainer:
|
|
|
157
160
|
def _train_step(self):
|
|
158
161
|
self.model.train()
|
|
159
162
|
running_loss = 0.0
|
|
160
|
-
# Enumerate to get batch index
|
|
161
163
|
for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
|
|
162
164
|
# Create a log dictionary for the batch
|
|
163
165
|
batch_logs = {
|
|
@@ -168,22 +170,26 @@ class MLTrainer:
|
|
|
168
170
|
|
|
169
171
|
features, target = features.to(self.device), target.to(self.device)
|
|
170
172
|
self.optimizer.zero_grad()
|
|
173
|
+
|
|
171
174
|
output = self.model(features)
|
|
172
|
-
|
|
175
|
+
|
|
176
|
+
# Apply shape correction only for single-target regression
|
|
177
|
+
if self.kind == "regression":
|
|
173
178
|
output = output.view_as(target)
|
|
179
|
+
|
|
174
180
|
loss = self.criterion(output, target)
|
|
181
|
+
|
|
175
182
|
loss.backward()
|
|
176
183
|
self.optimizer.step()
|
|
177
184
|
|
|
178
185
|
# Calculate batch loss and update running loss for the epoch
|
|
179
186
|
batch_loss = loss.item()
|
|
180
187
|
running_loss += batch_loss * features.size(0)
|
|
181
|
-
|
|
188
|
+
|
|
182
189
|
# Add the batch loss to the logs and call the end-of-batch hook
|
|
183
190
|
batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
|
|
184
191
|
self.callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
|
|
185
192
|
|
|
186
|
-
# Return the average loss for the entire epoch
|
|
187
193
|
return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore
|
|
188
194
|
|
|
189
195
|
def _validation_step(self):
|
|
@@ -192,25 +198,27 @@ class MLTrainer:
|
|
|
192
198
|
with torch.no_grad():
|
|
193
199
|
for features, target in self.test_loader: # type: ignore
|
|
194
200
|
features, target = features.to(self.device), target.to(self.device)
|
|
201
|
+
|
|
195
202
|
output = self.model(features)
|
|
196
|
-
|
|
203
|
+
# Apply shape correction only for single-target regression
|
|
204
|
+
if self.kind == "regression":
|
|
197
205
|
output = output.view_as(target)
|
|
206
|
+
|
|
198
207
|
loss = self.criterion(output, target)
|
|
208
|
+
|
|
199
209
|
running_loss += loss.item() * features.size(0)
|
|
210
|
+
|
|
200
211
|
logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)} # type: ignore
|
|
201
212
|
return logs
|
|
202
213
|
|
|
203
|
-
def _predict_for_eval(self, dataloader: DataLoader):
|
|
214
|
+
def _predict_for_eval(self, dataloader: DataLoader, classification_threshold: float = 0.5):
|
|
204
215
|
"""
|
|
205
216
|
Private method to yield model predictions batch by batch for evaluation.
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
Args:
|
|
209
|
-
dataloader (DataLoader): The dataloader to predict on.
|
|
210
|
-
|
|
217
|
+
|
|
211
218
|
Yields:
|
|
212
219
|
tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
|
|
213
|
-
|
|
220
|
+
|
|
221
|
+
- y_prob_batch is None for regression tasks.
|
|
214
222
|
"""
|
|
215
223
|
self.model.eval()
|
|
216
224
|
self.model.to(self.device)
|
|
@@ -220,81 +228,135 @@ class MLTrainer:
|
|
|
220
228
|
output = self.model(features).cpu()
|
|
221
229
|
y_true_batch = target.numpy()
|
|
222
230
|
|
|
223
|
-
|
|
224
|
-
|
|
231
|
+
y_pred_batch = None
|
|
232
|
+
y_prob_batch = None
|
|
233
|
+
|
|
234
|
+
if self.kind in ["regression", "multi_target_regression"]:
|
|
235
|
+
y_pred_batch = output.numpy()
|
|
236
|
+
|
|
237
|
+
elif self.kind == "classification":
|
|
238
|
+
probs = torch.softmax(output, dim=1)
|
|
225
239
|
preds = torch.argmax(probs, dim=1)
|
|
226
240
|
y_pred_batch = preds.numpy()
|
|
227
241
|
y_prob_batch = probs.numpy()
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
242
|
+
|
|
243
|
+
elif self.kind == "multi_label_classification":
|
|
244
|
+
probs = torch.sigmoid(output)
|
|
245
|
+
preds = (probs >= classification_threshold).int()
|
|
246
|
+
y_pred_batch = preds.numpy()
|
|
247
|
+
y_prob_batch = probs.numpy()
|
|
248
|
+
|
|
233
249
|
yield y_pred_batch, y_prob_batch, y_true_batch
|
|
234
|
-
|
|
235
|
-
def evaluate(self, save_dir: Union[str,Path], data: Optional[Union[DataLoader, Dataset]] = None):
|
|
250
|
+
|
|
251
|
+
def evaluate(self, save_dir: Union[str, Path], data: Optional[Union[DataLoader, Dataset]] = None, classification_threshold: float = 0.5):
    """
    Evaluates the model, routing to the correct evaluation function based on task `kind`.

    Args:
        save_dir (str | Path): Directory to save all reports and plots.
        data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
        classification_threshold (float): Probability threshold for multi-label tasks.

    Raises:
        ValueError: If `data` is None and the trainer has no `test_dataset`,
            or if `data` is neither a DataLoader, a Dataset, nor None.
    """
    def _make_loader(dataset: Dataset) -> DataLoader:
        # Fixed order for reproducible reports; no worker subprocesses on MPS.
        return DataLoader(dataset,
                          batch_size=32,
                          shuffle=False,
                          num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                          pin_memory=(self.device.type == "cuda"))

    def _target_names(prefix: str, y: np.ndarray):
        # Prefer names stored on the dataset, then on the dataset it wraps (e.g. a torch Subset).
        for source in (dataset_for_names, getattr(dataset_for_names, 'dataset', None)):
            names = getattr(source, 'target_names', None)
            if names is not None:
                return names
        _LOGGER.warning("⚠️ Dataset has no 'target_names' attribute. Using generic names.")
        num_targets = y.shape[1] if y.ndim > 1 else 1
        return [f"{prefix}_{i}" for i in range(num_targets)]

    dataset_for_names = None
    eval_loader = None

    if isinstance(data, DataLoader):
        eval_loader = data
        # Try to get the dataset from the loader for fetching target names
        if hasattr(data, 'dataset'):
            dataset_for_names = data.dataset
    elif isinstance(data, Dataset):
        # Create a new loader from the provided dataset
        eval_loader = _make_loader(data)
        dataset_for_names = data
    elif data is None:
        # Use the trainer's default test dataset
        if self.test_dataset is None:
            raise ValueError("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
        eval_loader = _make_loader(self.test_dataset)
        dataset_for_names = self.test_dataset

    # Reached for inputs that are neither a DataLoader, a Dataset nor None,
    # instead of silently evaluating on the trainer's test_dataset.
    if eval_loader is None:
        raise ValueError("Cannot evaluate. No valid data was provided or found.")

    print("\n--- Model Evaluation ---")

    all_preds, all_probs, all_true = [], [], []
    for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader, classification_threshold):
        if y_pred_b is not None:
            all_preds.append(y_pred_b)
        if y_prob_b is not None:
            all_probs.append(y_prob_b)
        if y_true_b is not None:
            all_true.append(y_true_b)

    if not all_true:
        _LOGGER.error("❌ Evaluation failed: No data was processed.")
        return

    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_true)
    y_prob = np.concatenate(all_probs) if all_probs else None

    # --- Routing Logic ---
    if self.kind == "regression":
        regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir)

    elif self.kind == "classification":
        classification_metrics(save_dir, y_true, y_pred, y_prob)

    elif self.kind == "multi_target_regression":
        target_names = _target_names("target", y_true)
        multi_target_regression_metrics(y_true, y_pred, target_names, save_dir)

    elif self.kind == "multi_label_classification":
        target_names = _target_names("label", y_true)
        if y_prob is None:
            _LOGGER.error("❌ Evaluation for multi_label_classification requires probabilities (y_prob).")
            return
        multi_label_classification_metrics(y_true, y_prob, target_names, save_dir, classification_threshold)

    print("\n--- Training History ---")
    plot_losses(self.history, save_dir=save_dir)
|
|
284
337
|
|
|
285
|
-
def explain(self,
|
|
286
|
-
|
|
338
|
+
def explain(self,
|
|
339
|
+
save_dir: Union[str,Path],
|
|
340
|
+
explain_dataset: Optional[Dataset] = None,
|
|
341
|
+
n_samples: int = 1000,
|
|
342
|
+
feature_names: Optional[List[str]] = None,
|
|
343
|
+
target_names: Optional[List[str]] = None):
|
|
287
344
|
"""
|
|
288
345
|
Explains model predictions using SHAP and saves all artifacts.
|
|
289
346
|
|
|
290
347
|
The background data is automatically sampled from the trainer's training dataset.
|
|
348
|
+
|
|
349
|
+
This method automatically routes to the appropriate SHAP summary plot
|
|
350
|
+
function based on the task. If `feature_names` or `target_names` (multi-target) are not provided,
|
|
351
|
+
it will attempt to extract them from the dataset.
|
|
291
352
|
|
|
292
353
|
Args:
|
|
293
|
-
explain_dataset (Dataset
|
|
354
|
+
explain_dataset (Dataset | None): A specific dataset to explain.
|
|
294
355
|
If None, the trainer's test dataset is used.
|
|
295
356
|
n_samples (int): The number of samples to use for both background and explanation.
|
|
296
|
-
feature_names (
|
|
297
|
-
|
|
357
|
+
feature_names (list[str] | None): Feature names.
|
|
358
|
+
target_names (list[str] | None): Target names
|
|
359
|
+
save_dir (str | Path): Directory to save all SHAP artifacts.
|
|
298
360
|
"""
|
|
299
361
|
# Internal helper to create a dataloader and get a random sample
|
|
300
362
|
def _get_random_sample(dataset: Dataset, num_samples: int):
|
|
@@ -328,26 +390,137 @@ class MLTrainer:
|
|
|
328
390
|
# 1. Get background data from the trainer's train_dataset
|
|
329
391
|
background_data = _get_random_sample(self.train_dataset, n_samples)
|
|
330
392
|
if background_data is None:
|
|
331
|
-
|
|
393
|
+
_LOGGER.error("❌ Trainer's train_dataset is empty or invalid. Skipping SHAP analysis.")
|
|
332
394
|
return
|
|
333
395
|
|
|
334
396
|
# 2. Determine target dataset and get explanation instances
|
|
335
397
|
target_dataset = explain_dataset if explain_dataset is not None else self.test_dataset
|
|
336
398
|
instances_to_explain = _get_random_sample(target_dataset, n_samples)
|
|
337
399
|
if instances_to_explain is None:
|
|
338
|
-
|
|
400
|
+
_LOGGER.error("❌ Explanation dataset is empty or invalid. Skipping SHAP analysis.")
|
|
339
401
|
return
|
|
402
|
+
|
|
403
|
+
# attempt to get feature names
|
|
404
|
+
if feature_names is None:
|
|
405
|
+
# _LOGGER.info("`feature_names` not provided. Attempting to extract from dataset...")
|
|
406
|
+
if hasattr(target_dataset, "feature_names"):
|
|
407
|
+
feature_names = target_dataset.feature_names # type: ignore
|
|
408
|
+
else:
|
|
409
|
+
try:
|
|
410
|
+
# Handle PyTorch Subset
|
|
411
|
+
feature_names = target_dataset.dataset.feature_names # type: ignore
|
|
412
|
+
except AttributeError:
|
|
413
|
+
_LOGGER.error("❌ Could not extract `feature_names` from the dataset.")
|
|
414
|
+
raise ValueError("`feature_names` must be provided if the dataset object does not have a `feature_names` attribute.")
|
|
340
415
|
|
|
341
416
|
# 3. Call the plotting function
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
417
|
+
if self.kind in ["regression", "classification"]:
|
|
418
|
+
shap_summary_plot(
|
|
419
|
+
model=self.model,
|
|
420
|
+
background_data=background_data,
|
|
421
|
+
instances_to_explain=instances_to_explain,
|
|
422
|
+
feature_names=feature_names,
|
|
423
|
+
save_dir=save_dir
|
|
424
|
+
)
|
|
425
|
+
elif self.kind in ["multi_target_regression", "multi_label_classification"]:
|
|
426
|
+
# try to get target names
|
|
427
|
+
if target_names is None:
|
|
428
|
+
target_names = []
|
|
429
|
+
if hasattr(target_dataset, 'target_names'):
|
|
430
|
+
target_names = target_dataset.target_names # type: ignore
|
|
431
|
+
else:
|
|
432
|
+
# Infer number of targets from the model's output layer
|
|
433
|
+
try:
|
|
434
|
+
num_targets = self.model.output_layer.out_features # type: ignore
|
|
435
|
+
target_names = [f"target_{i}" for i in range(num_targets)] # type: ignore
|
|
436
|
+
_LOGGER.warning("Dataset has no 'target_names' attribute. Using generic names.")
|
|
437
|
+
except AttributeError:
|
|
438
|
+
_LOGGER.error("Cannot determine target names for multi-target SHAP plot. Skipping.")
|
|
439
|
+
return
|
|
440
|
+
|
|
441
|
+
multi_target_shap_summary_plot(
|
|
442
|
+
model=self.model,
|
|
443
|
+
background_data=background_data,
|
|
444
|
+
instances_to_explain=instances_to_explain,
|
|
445
|
+
feature_names=feature_names, # type: ignore
|
|
446
|
+
target_names=target_names, # type: ignore
|
|
447
|
+
save_dir=save_dir
|
|
448
|
+
)
|
|
449
|
+
|
|
450
|
+
def _attention_helper(self, dataloader: DataLoader):
|
|
451
|
+
"""
|
|
452
|
+
Private method to yield model attention weights batch by batch for evaluation.
|
|
453
|
+
|
|
454
|
+
Args:
|
|
455
|
+
dataloader (DataLoader): The dataloader to predict on.
|
|
456
|
+
|
|
457
|
+
Yields:
|
|
458
|
+
(torch.Tensor): Attention weights
|
|
459
|
+
"""
|
|
460
|
+
self.model.eval()
|
|
461
|
+
self.model.to(self.device)
|
|
462
|
+
|
|
463
|
+
with torch.no_grad():
|
|
464
|
+
for features, target in dataloader:
|
|
465
|
+
features = features.to(self.device)
|
|
466
|
+
attention_weights = None
|
|
467
|
+
|
|
468
|
+
# Get model output
|
|
469
|
+
# Unpack logits and weights from the special forward method
|
|
470
|
+
_output, attention_weights = self.model.forward_attention(features) # type: ignore
|
|
471
|
+
|
|
472
|
+
if attention_weights is not None:
|
|
473
|
+
attention_weights = attention_weights.cpu()
|
|
474
|
+
|
|
475
|
+
yield attention_weights
|
|
349
476
|
|
|
477
|
+
def explain_attention(self, save_dir: Union[str, Path], feature_names: Optional[List[str]] = None, explain_dataset: Optional[Dataset] = None):
    """
    Generates and saves a feature importance plot based on attention weights.

    This method only works for models with a `forward_attention` method.

    Args:
        save_dir (str | Path): Directory to save the plot and summary data.
        feature_names (List[str] | None): Names for the features for plot labeling.
            If None, a `feature_names` attribute is looked up on the dataset, then on
            the dataset it wraps (e.g. a `torch.utils.data.Subset`). If neither has
            one, None is passed to the plotting function.
        explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's test dataset is used.
    """
    print("\n--- Attention Analysis ---")

    # --- Step 1: Check if the model supports this explanation ---
    if not hasattr(self.model, 'forward_attention'):
        _LOGGER.error("❌ Model does not have a `forward_attention` method. Skipping attention explanation.")
        return

    # --- Step 2: Set up the dataloader ---
    dataset_to_use = explain_dataset if explain_dataset is not None else self.test_dataset
    if not isinstance(dataset_to_use, Dataset):
        _LOGGER.error("❌ The explanation dataset is empty or invalid. Skipping attention analysis.")
        return

    # Best-effort feature-name lookup, the same places MLTrainer.explain checks.
    if feature_names is None:
        for source in (dataset_to_use, getattr(dataset_to_use, 'dataset', None)):
            feature_names = getattr(source, 'feature_names', None)
            if feature_names is not None:
                break

    explain_loader = DataLoader(
        dataset=dataset_to_use, batch_size=32, shuffle=False,
        num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
        pin_memory=("cuda" in self.device.type)
    )

    # --- Step 3: Collect weights ---
    all_weights = [weights for weights in self._attention_helper(explain_loader) if weights is not None]

    # --- Step 4: Call the plotting function ---
    if all_weights:
        plot_attention_importance(
            weights=all_weights,
            feature_names=feature_names,
            save_dir=save_dir
        )
    else:
        _LOGGER.error("❌ No attention weights were collected from the model.")
|
|
523
|
+
|
|
351
524
|
def callbacks_hook(self, method_name: str, *args, **kwargs):
|
|
352
525
|
"""Calls the specified method on all callbacks."""
|
|
353
526
|
for callback in self.callbacks:
|