dragon-ml-toolbox 12.13.0__py3-none-any.whl → 14.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic. Click here for more details.
- {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/METADATA +11 -2
- dragon_ml_toolbox-14.3.0.dist-info/RECORD +48 -0
- {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_callbacks.py +40 -8
- ml_tools/ML_datasetmaster.py +200 -261
- ml_tools/ML_evaluation.py +29 -17
- ml_tools/ML_evaluation_multi.py +13 -10
- ml_tools/ML_inference.py +14 -5
- ml_tools/ML_models.py +135 -55
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_optimization.py +49 -36
- ml_tools/ML_trainer.py +560 -30
- ml_tools/ML_utilities.py +302 -4
- ml_tools/ML_vision_datasetmaster.py +1352 -0
- ml_tools/ML_vision_evaluation.py +260 -0
- ml_tools/ML_vision_inference.py +428 -0
- ml_tools/ML_vision_models.py +627 -0
- ml_tools/ML_vision_transformers.py +58 -0
- ml_tools/PSO_optimization.py +5 -1
- ml_tools/_ML_vision_recipe.py +88 -0
- ml_tools/__init__.py +1 -0
- ml_tools/_schema.py +96 -0
- ml_tools/custom_logger.py +37 -14
- ml_tools/data_exploration.py +576 -138
- ml_tools/keys.py +51 -1
- ml_tools/math_utilities.py +1 -1
- ml_tools/optimization_tools.py +65 -86
- ml_tools/serde.py +78 -17
- ml_tools/utilities.py +192 -3
- dragon_ml_toolbox-12.13.0.dist-info/RECORD +0 -41
- ml_tools/ML_simple_optimization.py +0 -413
- {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_trainer.py
CHANGED
|
@@ -1,26 +1,29 @@
|
|
|
1
|
-
from typing import List, Literal, Union, Optional
|
|
1
|
+
from typing import List, Literal, Union, Optional, Callable, Dict, Any, Tuple
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
from torch.utils.data import DataLoader, Dataset
|
|
4
4
|
import torch
|
|
5
5
|
from torch import nn
|
|
6
6
|
import numpy as np
|
|
7
7
|
|
|
8
|
-
from .ML_callbacks import Callback, History, TqdmProgressBar
|
|
8
|
+
from .ML_callbacks import Callback, History, TqdmProgressBar, ModelCheckpoint
|
|
9
9
|
from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
|
|
10
10
|
from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
|
|
11
11
|
from ._script_info import _script_info
|
|
12
|
-
from .keys import PyTorchLogKeys
|
|
12
|
+
from .keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys
|
|
13
13
|
from ._logger import _LOGGER
|
|
14
|
+
from .path_manager import make_fullpath
|
|
15
|
+
from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
__all__ = [
|
|
17
|
-
"MLTrainer"
|
|
19
|
+
"MLTrainer",
|
|
20
|
+
"ObjectDetectionTrainer"
|
|
18
21
|
]
|
|
19
22
|
|
|
20
23
|
|
|
21
24
|
class MLTrainer:
|
|
22
25
|
def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
|
|
23
|
-
kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification"],
|
|
26
|
+
kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification", "segmentation"],
|
|
24
27
|
criterion: nn.Module, optimizer: torch.optim.Optimizer,
|
|
25
28
|
device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
|
|
26
29
|
"""
|
|
@@ -32,7 +35,7 @@ class MLTrainer:
|
|
|
32
35
|
model (nn.Module): The PyTorch model to train.
|
|
33
36
|
train_dataset (Dataset): The training dataset.
|
|
34
37
|
test_dataset (Dataset): The testing/validation dataset.
|
|
35
|
-
kind (str): Can be 'regression', 'classification', 'multi_target_regression', or '
|
|
38
|
+
kind (str): Can be 'regression', 'classification', 'multi_target_regression', 'multi_label_classification', or 'segmentation'.
|
|
36
39
|
criterion (nn.Module): The loss function.
|
|
37
40
|
optimizer (torch.optim.Optimizer): The optimizer.
|
|
38
41
|
device (str): The device to run training on ('cpu', 'cuda', 'mps').
|
|
@@ -45,8 +48,10 @@ class MLTrainer:
|
|
|
45
48
|
- For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice.
|
|
46
49
|
|
|
47
50
|
- For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem.
|
|
51
|
+
|
|
52
|
+
- For **segmentation** tasks, `nn.CrossEntropyLoss` (for multi-class) or `nn.BCEWithLogitsLoss` (for binary) are common.
|
|
48
53
|
"""
|
|
49
|
-
if kind not in ["regression", "classification", "multi_target_regression", "multi_label_classification"]:
|
|
54
|
+
if kind not in ["regression", "classification", "multi_target_regression", "multi_label_classification", "segmentation"]:
|
|
50
55
|
raise ValueError(f"'{kind}' is not a valid task type.")
|
|
51
56
|
|
|
52
57
|
self.model = model
|
|
@@ -55,6 +60,7 @@ class MLTrainer:
|
|
|
55
60
|
self.kind = kind
|
|
56
61
|
self.criterion = criterion
|
|
57
62
|
self.optimizer = optimizer
|
|
63
|
+
self.scheduler = None
|
|
58
64
|
self.device = self._validate_device(device)
|
|
59
65
|
self.dataloader_workers = dataloader_workers
|
|
60
66
|
|
|
@@ -70,7 +76,9 @@ class MLTrainer:
|
|
|
70
76
|
self.history = {}
|
|
71
77
|
self.epoch = 0
|
|
72
78
|
self.epochs = 0 # Total epochs for the fit run
|
|
79
|
+
self.start_epoch = 1
|
|
73
80
|
self.stop_training = False
|
|
81
|
+
self._batch_size = 10
|
|
74
82
|
|
|
75
83
|
def _validate_device(self, device: str) -> torch.device:
|
|
76
84
|
"""Validates the selected device and returns a torch.device object."""
|
|
@@ -109,8 +117,66 @@ class MLTrainer:
|
|
|
109
117
|
num_workers=loader_workers,
|
|
110
118
|
pin_memory=("cuda" in self.device.type)
|
|
111
119
|
)
|
|
120
|
+
|
|
121
|
+
def _load_checkpoint(self, path: Union[str, Path]):
|
|
122
|
+
"""Loads a training checkpoint to resume training."""
|
|
123
|
+
p = make_fullpath(path, enforce="file")
|
|
124
|
+
_LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
|
|
125
|
+
|
|
126
|
+
try:
|
|
127
|
+
checkpoint = torch.load(p, map_location=self.device)
|
|
128
|
+
|
|
129
|
+
if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
|
|
130
|
+
_LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
|
|
131
|
+
raise KeyError()
|
|
112
132
|
|
|
113
|
-
|
|
133
|
+
self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
|
|
134
|
+
self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
|
|
135
|
+
self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
|
|
136
|
+
|
|
137
|
+
# --- Scheduler State Loading Logic ---
|
|
138
|
+
scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
|
|
139
|
+
scheduler_object_exists = self.scheduler is not None
|
|
140
|
+
|
|
141
|
+
if scheduler_object_exists and scheduler_state_exists:
|
|
142
|
+
# Case 1: Both exist. Attempt to load.
|
|
143
|
+
try:
|
|
144
|
+
self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
|
|
145
|
+
scheduler_name = self.scheduler.__class__.__name__
|
|
146
|
+
_LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
|
|
147
|
+
except Exception as e:
|
|
148
|
+
# Loading failed, likely a mismatch
|
|
149
|
+
scheduler_name = self.scheduler.__class__.__name__
|
|
150
|
+
_LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
|
|
151
|
+
raise e
|
|
152
|
+
|
|
153
|
+
elif scheduler_object_exists and not scheduler_state_exists:
|
|
154
|
+
# Case 2: Scheduler provided, but no state in checkpoint.
|
|
155
|
+
scheduler_name = self.scheduler.__class__.__name__
|
|
156
|
+
_LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
|
|
157
|
+
|
|
158
|
+
elif not scheduler_object_exists and scheduler_state_exists:
|
|
159
|
+
# Case 3: State in checkpoint, but no scheduler provided.
|
|
160
|
+
_LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
|
|
161
|
+
raise ValueError()
|
|
162
|
+
|
|
163
|
+
# Restore callback states
|
|
164
|
+
for cb in self.callbacks:
|
|
165
|
+
if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
|
|
166
|
+
cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
|
|
167
|
+
_LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
|
|
168
|
+
|
|
169
|
+
_LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
|
|
170
|
+
|
|
171
|
+
except Exception as e:
|
|
172
|
+
_LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
|
|
173
|
+
raise
|
|
174
|
+
|
|
175
|
+
def fit(self,
|
|
176
|
+
epochs: int = 10,
|
|
177
|
+
batch_size: int = 10,
|
|
178
|
+
shuffle: bool = True,
|
|
179
|
+
resume_from_checkpoint: Optional[Union[str, Path]] = None):
|
|
114
180
|
"""
|
|
115
181
|
Starts the training-validation process of the model.
|
|
116
182
|
|
|
@@ -120,6 +186,7 @@ class MLTrainer:
|
|
|
120
186
|
epochs (int): The total number of epochs to train for.
|
|
121
187
|
batch_size (int): The number of samples per batch.
|
|
122
188
|
shuffle (bool): Whether to shuffle the training data at each epoch.
|
|
189
|
+
resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
|
|
123
190
|
|
|
124
191
|
Note:
|
|
125
192
|
For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
|
|
@@ -129,18 +196,22 @@ class MLTrainer:
|
|
|
129
196
|
shape of `[batch_size]`.
|
|
130
197
|
"""
|
|
131
198
|
self.epochs = epochs
|
|
132
|
-
self.
|
|
199
|
+
self._batch_size = batch_size
|
|
200
|
+
self._create_dataloaders(self._batch_size, shuffle)
|
|
133
201
|
self.model.to(self.device)
|
|
134
202
|
|
|
203
|
+
if resume_from_checkpoint:
|
|
204
|
+
self._load_checkpoint(resume_from_checkpoint)
|
|
205
|
+
|
|
135
206
|
# Reset stop_training flag on the trainer
|
|
136
207
|
self.stop_training = False
|
|
137
208
|
|
|
138
|
-
self.
|
|
209
|
+
self._callbacks_hook('on_train_begin')
|
|
139
210
|
|
|
140
|
-
for epoch in range(
|
|
211
|
+
for epoch in range(self.start_epoch, self.epochs + 1):
|
|
141
212
|
self.epoch = epoch
|
|
142
213
|
epoch_logs = {}
|
|
143
|
-
self.
|
|
214
|
+
self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
|
|
144
215
|
|
|
145
216
|
train_logs = self._train_step()
|
|
146
217
|
epoch_logs.update(train_logs)
|
|
@@ -148,13 +219,13 @@ class MLTrainer:
|
|
|
148
219
|
val_logs = self._validation_step()
|
|
149
220
|
epoch_logs.update(val_logs)
|
|
150
221
|
|
|
151
|
-
self.
|
|
222
|
+
self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
|
|
152
223
|
|
|
153
224
|
# Check the early stopping flag
|
|
154
225
|
if self.stop_training:
|
|
155
226
|
break
|
|
156
227
|
|
|
157
|
-
self.
|
|
228
|
+
self._callbacks_hook('on_train_end')
|
|
158
229
|
return self.history
|
|
159
230
|
|
|
160
231
|
def _train_step(self):
|
|
@@ -166,7 +237,7 @@ class MLTrainer:
|
|
|
166
237
|
PyTorchLogKeys.BATCH_INDEX: batch_idx,
|
|
167
238
|
PyTorchLogKeys.BATCH_SIZE: features.size(0)
|
|
168
239
|
}
|
|
169
|
-
self.
|
|
240
|
+
self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
|
|
170
241
|
|
|
171
242
|
features, target = features.to(self.device), target.to(self.device)
|
|
172
243
|
self.optimizer.zero_grad()
|
|
@@ -188,7 +259,7 @@ class MLTrainer:
|
|
|
188
259
|
|
|
189
260
|
# Add the batch loss to the logs and call the end-of-batch hook
|
|
190
261
|
batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
|
|
191
|
-
self.
|
|
262
|
+
self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
|
|
192
263
|
|
|
193
264
|
return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore
|
|
194
265
|
|
|
@@ -226,25 +297,40 @@ class MLTrainer:
|
|
|
226
297
|
for features, target in dataloader:
|
|
227
298
|
features = features.to(self.device)
|
|
228
299
|
output = self.model(features).cpu()
|
|
229
|
-
y_true_batch = target.numpy()
|
|
230
300
|
|
|
231
301
|
y_pred_batch = None
|
|
232
302
|
y_prob_batch = None
|
|
303
|
+
y_true_batch = None
|
|
233
304
|
|
|
234
305
|
if self.kind in ["regression", "multi_target_regression"]:
|
|
235
306
|
y_pred_batch = output.numpy()
|
|
307
|
+
y_true_batch = target.numpy()
|
|
236
308
|
|
|
237
309
|
elif self.kind == "classification":
|
|
238
310
|
probs = torch.softmax(output, dim=1)
|
|
239
311
|
preds = torch.argmax(probs, dim=1)
|
|
240
312
|
y_pred_batch = preds.numpy()
|
|
241
313
|
y_prob_batch = probs.numpy()
|
|
314
|
+
y_true_batch = target.numpy()
|
|
242
315
|
|
|
243
316
|
elif self.kind == "multi_label_classification":
|
|
244
317
|
probs = torch.sigmoid(output)
|
|
245
318
|
preds = (probs >= classification_threshold).int()
|
|
246
319
|
y_pred_batch = preds.numpy()
|
|
247
320
|
y_prob_batch = probs.numpy()
|
|
321
|
+
y_true_batch = target.numpy()
|
|
322
|
+
|
|
323
|
+
elif self.kind == "segmentation":
|
|
324
|
+
# output shape [N, C, H, W]
|
|
325
|
+
probs = torch.softmax(output, dim=1)
|
|
326
|
+
preds = torch.argmax(probs, dim=1) # shape [N, H, W]
|
|
327
|
+
y_pred_batch = preds.numpy()
|
|
328
|
+
y_prob_batch = probs.numpy() # Probs are [N, C, H, W]
|
|
329
|
+
|
|
330
|
+
# Handle target shape [N, 1, H, W] -> [N, H, W]
|
|
331
|
+
if target.ndim == 4 and target.shape[1] == 1:
|
|
332
|
+
target = target.squeeze(1)
|
|
333
|
+
y_true_batch = target.numpy()
|
|
248
334
|
|
|
249
335
|
yield y_pred_batch, y_prob_batch, y_true_batch
|
|
250
336
|
|
|
@@ -268,7 +354,7 @@ class MLTrainer:
|
|
|
268
354
|
elif isinstance(data, Dataset):
|
|
269
355
|
# Create a new loader from the provided dataset
|
|
270
356
|
eval_loader = DataLoader(data,
|
|
271
|
-
batch_size=
|
|
357
|
+
batch_size=self._batch_size,
|
|
272
358
|
shuffle=False,
|
|
273
359
|
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
274
360
|
pin_memory=(self.device.type == "cuda"))
|
|
@@ -279,10 +365,11 @@ class MLTrainer:
|
|
|
279
365
|
raise ValueError()
|
|
280
366
|
# Create a fresh DataLoader from the test_dataset
|
|
281
367
|
eval_loader = DataLoader(self.test_dataset,
|
|
282
|
-
batch_size=
|
|
368
|
+
batch_size=self._batch_size,
|
|
283
369
|
shuffle=False,
|
|
284
370
|
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
285
371
|
pin_memory=(self.device.type == "cuda"))
|
|
372
|
+
|
|
286
373
|
dataset_for_names = self.test_dataset
|
|
287
374
|
|
|
288
375
|
if eval_loader is None:
|
|
@@ -333,7 +420,31 @@ class MLTrainer:
|
|
|
333
420
|
_LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
|
|
334
421
|
return
|
|
335
422
|
multi_label_classification_metrics(y_true, y_prob, target_names, save_dir, classification_threshold)
|
|
423
|
+
|
|
424
|
+
elif self.kind == "segmentation":
|
|
425
|
+
class_names = None
|
|
426
|
+
try:
|
|
427
|
+
# Try to get 'classes' from VisionDatasetMaker
|
|
428
|
+
if hasattr(dataset_for_names, 'classes'):
|
|
429
|
+
class_names = dataset_for_names.classes # type: ignore
|
|
430
|
+
# Fallback for Subset
|
|
431
|
+
elif hasattr(dataset_for_names, 'dataset') and hasattr(dataset_for_names.dataset, 'classes'): # type: ignore
|
|
432
|
+
class_names = dataset_for_names.dataset.classes # type: ignore
|
|
433
|
+
except AttributeError:
|
|
434
|
+
pass # class_names is still None
|
|
336
435
|
|
|
436
|
+
if class_names is None:
|
|
437
|
+
try:
|
|
438
|
+
# Fallback to 'target_names'
|
|
439
|
+
class_names = dataset_for_names.target_names # type: ignore
|
|
440
|
+
except AttributeError:
|
|
441
|
+
# Fallback to inferring from labels
|
|
442
|
+
labels = np.unique(y_true)
|
|
443
|
+
class_names = [f"Class {i}" for i in labels]
|
|
444
|
+
_LOGGER.warning(f"Dataset has no 'classes' or 'target_names' attribute. Using generic names.")
|
|
445
|
+
|
|
446
|
+
segmentation_metrics(y_true, y_pred, save_dir, class_names=class_names)
|
|
447
|
+
|
|
337
448
|
print("\n--- Training History ---")
|
|
338
449
|
plot_losses(self.history, save_dir=save_dir)
|
|
339
450
|
|
|
@@ -343,7 +454,7 @@ class MLTrainer:
|
|
|
343
454
|
n_samples: int = 300,
|
|
344
455
|
feature_names: Optional[List[str]] = None,
|
|
345
456
|
target_names: Optional[List[str]] = None,
|
|
346
|
-
explainer_type: Literal['deep', 'kernel'] = '
|
|
457
|
+
explainer_type: Literal['deep', 'kernel'] = 'kernel'):
|
|
347
458
|
"""
|
|
348
459
|
Explains model predictions using SHAP and saves all artifacts.
|
|
349
460
|
|
|
@@ -357,11 +468,11 @@ class MLTrainer:
|
|
|
357
468
|
explain_dataset (Dataset | None): A specific dataset to explain.
|
|
358
469
|
If None, the trainer's test dataset is used.
|
|
359
470
|
n_samples (int): The number of samples to use for both background and explanation.
|
|
360
|
-
feature_names (list[str] | None): Feature names.
|
|
471
|
+
feature_names (list[str] | None): Feature names. If None, the names will be extracted from the Dataset and raise an error on failure.
|
|
361
472
|
target_names (list[str] | None): Target names for multi-target tasks.
|
|
362
473
|
save_dir (str | Path): Directory to save all SHAP artifacts.
|
|
363
474
|
explainer_type (Literal['deep', 'kernel']): The explainer to use.
|
|
364
|
-
- 'deep':
|
|
475
|
+
- 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
|
|
365
476
|
- 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples'< 100.
|
|
366
477
|
"""
|
|
367
478
|
# Internal helper to create a dataloader and get a random sample
|
|
@@ -409,10 +520,10 @@ class MLTrainer:
|
|
|
409
520
|
# attempt to get feature names
|
|
410
521
|
if feature_names is None:
|
|
411
522
|
# _LOGGER.info("`feature_names` not provided. Attempting to extract from dataset...")
|
|
412
|
-
if hasattr(target_dataset,
|
|
523
|
+
if hasattr(target_dataset, DatasetKeys.FEATURE_NAMES):
|
|
413
524
|
feature_names = target_dataset.feature_names # type: ignore
|
|
414
525
|
else:
|
|
415
|
-
_LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a
|
|
526
|
+
_LOGGER.error(f"Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
|
|
416
527
|
raise ValueError()
|
|
417
528
|
|
|
418
529
|
# move model to device
|
|
@@ -433,7 +544,7 @@ class MLTrainer:
|
|
|
433
544
|
# try to get target names
|
|
434
545
|
if target_names is None:
|
|
435
546
|
target_names = []
|
|
436
|
-
if hasattr(target_dataset,
|
|
547
|
+
if hasattr(target_dataset, DatasetKeys.TARGET_NAMES):
|
|
437
548
|
target_names = target_dataset.target_names # type: ignore
|
|
438
549
|
else:
|
|
439
550
|
# Infer number of targets from the model's output layer
|
|
@@ -484,7 +595,7 @@ class MLTrainer:
|
|
|
484
595
|
yield attention_weights
|
|
485
596
|
|
|
486
597
|
def explain_attention(self, save_dir: Union[str, Path],
|
|
487
|
-
feature_names: Optional[List[str]],
|
|
598
|
+
feature_names: Optional[List[str]] = None,
|
|
488
599
|
explain_dataset: Optional[Dataset] = None,
|
|
489
600
|
plot_n_features: int = 10):
|
|
490
601
|
"""
|
|
@@ -494,7 +605,7 @@ class MLTrainer:
|
|
|
494
605
|
|
|
495
606
|
Args:
|
|
496
607
|
save_dir (str | Path): Directory to save the plot and summary data.
|
|
497
|
-
feature_names (List[str] | None): Names for the features for plot labeling. If
|
|
608
|
+
feature_names (List[str] | None): Names for the features for plot labeling. If None, the names will be extracted from the Dataset and raise an error on failure.
|
|
498
609
|
explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's test dataset is used.
|
|
499
610
|
plot_n_features (int): Number of top features to plot.
|
|
500
611
|
"""
|
|
@@ -504,8 +615,7 @@ class MLTrainer:
|
|
|
504
615
|
# --- Step 1: Check if the model supports this explanation ---
|
|
505
616
|
if not getattr(self.model, 'has_interpretable_attention', False):
|
|
506
617
|
_LOGGER.warning(
|
|
507
|
-
"Model is not flagged for interpretable attention analysis. "
|
|
508
|
-
"Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
|
|
618
|
+
"Model is not flagged for interpretable attention analysis. Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
|
|
509
619
|
)
|
|
510
620
|
return
|
|
511
621
|
|
|
@@ -515,6 +625,14 @@ class MLTrainer:
|
|
|
515
625
|
_LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
|
|
516
626
|
return
|
|
517
627
|
|
|
628
|
+
# Get feature names
|
|
629
|
+
if feature_names is None:
|
|
630
|
+
if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
|
|
631
|
+
feature_names = dataset_to_use.feature_names # type: ignore
|
|
632
|
+
else:
|
|
633
|
+
_LOGGER.error(f"Could not extract `feature_names` from the dataset for attention plot. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
|
|
634
|
+
raise ValueError()
|
|
635
|
+
|
|
518
636
|
explain_loader = DataLoader(
|
|
519
637
|
dataset=dataset_to_use, batch_size=32, shuffle=False,
|
|
520
638
|
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
@@ -538,11 +656,423 @@ class MLTrainer:
|
|
|
538
656
|
else:
|
|
539
657
|
_LOGGER.error("No attention weights were collected from the model.")
|
|
540
658
|
|
|
541
|
-
def
|
|
659
|
+
def _callbacks_hook(self, method_name: str, *args, **kwargs):
|
|
542
660
|
"""Calls the specified method on all callbacks."""
|
|
543
661
|
for callback in self.callbacks:
|
|
544
662
|
method = getattr(callback, method_name)
|
|
545
663
|
method(*args, **kwargs)
|
|
664
|
+
|
|
665
|
+
def to_cpu(self):
|
|
666
|
+
"""
|
|
667
|
+
Moves the model to the CPU and updates the trainer's device setting.
|
|
668
|
+
|
|
669
|
+
This is useful for running operations that require the CPU.
|
|
670
|
+
"""
|
|
671
|
+
self.device = torch.device('cpu')
|
|
672
|
+
self.model.to(self.device)
|
|
673
|
+
_LOGGER.info("Trainer and model moved to CPU.")
|
|
674
|
+
|
|
675
|
+
def to_device(self, device: str):
|
|
676
|
+
"""
|
|
677
|
+
Moves the model to the specified device and updates the trainer's device setting.
|
|
678
|
+
|
|
679
|
+
Args:
|
|
680
|
+
device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
|
|
681
|
+
"""
|
|
682
|
+
self.device = self._validate_device(device)
|
|
683
|
+
self.model.to(self.device)
|
|
684
|
+
_LOGGER.info(f"Trainer and model moved to {self.device}.")
|
|
685
|
+
|
|
686
|
+
|
|
687
|
+
# Object Detection Trainer
|
|
688
|
+
class ObjectDetectionTrainer:
|
|
689
|
+
def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
|
|
690
|
+
collate_fn: Callable, optimizer: torch.optim.Optimizer,
|
|
691
|
+
device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
|
|
692
|
+
"""
|
|
693
|
+
Automates the training process of an Object Detection Model (e.g., DragonFastRCNN).
|
|
694
|
+
|
|
695
|
+
Built-in Callbacks: `History`, `TqdmProgressBar`
|
|
696
|
+
|
|
697
|
+
Args:
|
|
698
|
+
model (nn.Module): The PyTorch object detection model to train.
|
|
699
|
+
train_dataset (Dataset): The training dataset.
|
|
700
|
+
test_dataset (Dataset): The testing/validation dataset.
|
|
701
|
+
collate_fn (Callable): The collate function from `ObjectDetectionDatasetMaker.collate_fn`.
|
|
702
|
+
optimizer (torch.optim.Optimizer): The optimizer.
|
|
703
|
+
device (str): The device to run training on ('cpu', 'cuda', 'mps').
|
|
704
|
+
dataloader_workers (int): Subprocesses for data loading.
|
|
705
|
+
callbacks (List[Callback] | None): A list of callbacks to use during training.
|
|
706
|
+
|
|
707
|
+
## Note:
|
|
708
|
+
This trainer is specialized. It does not take a `criterion` because object detection models like Faster R-CNN return a dictionary of losses directly from their forward pass during training.
|
|
709
|
+
"""
|
|
710
|
+
self.model = model
|
|
711
|
+
self.train_dataset = train_dataset
|
|
712
|
+
self.test_dataset = test_dataset
|
|
713
|
+
self.kind = "object_detection"
|
|
714
|
+
self.collate_fn = collate_fn
|
|
715
|
+
self.criterion = None # Criterion is handled inside the model
|
|
716
|
+
self.optimizer = optimizer
|
|
717
|
+
self.scheduler = None
|
|
718
|
+
self.device = self._validate_device(device)
|
|
719
|
+
self.dataloader_workers = dataloader_workers
|
|
720
|
+
|
|
721
|
+
# Callback handler - History and TqdmProgressBar are added by default
|
|
722
|
+
default_callbacks = [History(), TqdmProgressBar()]
|
|
723
|
+
user_callbacks = callbacks if callbacks is not None else []
|
|
724
|
+
self.callbacks = default_callbacks + user_callbacks
|
|
725
|
+
self._set_trainer_on_callbacks()
|
|
726
|
+
|
|
727
|
+
# Internal state
|
|
728
|
+
self.train_loader = None
|
|
729
|
+
self.test_loader = None
|
|
730
|
+
self.history = {}
|
|
731
|
+
self.epoch = 0
|
|
732
|
+
self.epochs = 0 # Total epochs for the fit run
|
|
733
|
+
self.start_epoch = 1
|
|
734
|
+
self.stop_training = False
|
|
735
|
+
self._batch_size = 10
|
|
736
|
+
|
|
737
|
+
def _validate_device(self, device: str) -> torch.device:
|
|
738
|
+
"""Validates the selected device and returns a torch.device object."""
|
|
739
|
+
device_lower = device.lower()
|
|
740
|
+
if "cuda" in device_lower and not torch.cuda.is_available():
|
|
741
|
+
_LOGGER.warning("CUDA not available, switching to CPU.")
|
|
742
|
+
device = "cpu"
|
|
743
|
+
elif device_lower == "mps" and not torch.backends.mps.is_available():
|
|
744
|
+
_LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
|
|
745
|
+
device = "cpu"
|
|
746
|
+
return torch.device(device)
|
|
747
|
+
|
|
748
|
+
def _set_trainer_on_callbacks(self):
|
|
749
|
+
"""Gives each callback a reference to this trainer instance."""
|
|
750
|
+
for callback in self.callbacks:
|
|
751
|
+
callback.set_trainer(self)
|
|
752
|
+
|
|
753
|
+
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
754
|
+
"""Initializes the DataLoaders with the object detection collate_fn."""
|
|
755
|
+
# Ensure stability on MPS devices by setting num_workers to 0
|
|
756
|
+
loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
|
|
757
|
+
|
|
758
|
+
self.train_loader = DataLoader(
|
|
759
|
+
dataset=self.train_dataset,
|
|
760
|
+
batch_size=batch_size,
|
|
761
|
+
shuffle=shuffle,
|
|
762
|
+
num_workers=loader_workers,
|
|
763
|
+
pin_memory=("cuda" in self.device.type),
|
|
764
|
+
collate_fn=self.collate_fn # Use the provided collate function
|
|
765
|
+
)
|
|
766
|
+
|
|
767
|
+
self.test_loader = DataLoader(
|
|
768
|
+
dataset=self.test_dataset,
|
|
769
|
+
batch_size=batch_size,
|
|
770
|
+
shuffle=False,
|
|
771
|
+
num_workers=loader_workers,
|
|
772
|
+
pin_memory=("cuda" in self.device.type),
|
|
773
|
+
collate_fn=self.collate_fn # Use the provided collate function
|
|
774
|
+
)
|
|
775
|
+
|
|
776
|
+
def _load_checkpoint(self, path: Union[str, Path]):
|
|
777
|
+
"""Loads a training checkpoint to resume training."""
|
|
778
|
+
p = make_fullpath(path, enforce="file")
|
|
779
|
+
_LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
|
|
780
|
+
|
|
781
|
+
try:
|
|
782
|
+
checkpoint = torch.load(p, map_location=self.device)
|
|
783
|
+
|
|
784
|
+
if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
|
|
785
|
+
_LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
|
|
786
|
+
raise KeyError()
|
|
787
|
+
|
|
788
|
+
self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
|
|
789
|
+
self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
|
|
790
|
+
self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
|
|
791
|
+
|
|
792
|
+
# --- Scheduler State Loading Logic ---
|
|
793
|
+
scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
|
|
794
|
+
scheduler_object_exists = self.scheduler is not None
|
|
795
|
+
|
|
796
|
+
if scheduler_object_exists and scheduler_state_exists:
|
|
797
|
+
# Case 1: Both exist. Attempt to load.
|
|
798
|
+
try:
|
|
799
|
+
self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
|
|
800
|
+
scheduler_name = self.scheduler.__class__.__name__
|
|
801
|
+
_LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
|
|
802
|
+
except Exception as e:
|
|
803
|
+
# Loading failed, likely a mismatch
|
|
804
|
+
scheduler_name = self.scheduler.__class__.__name__
|
|
805
|
+
_LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
|
|
806
|
+
raise e
|
|
807
|
+
|
|
808
|
+
elif scheduler_object_exists and not scheduler_state_exists:
|
|
809
|
+
# Case 2: Scheduler provided, but no state in checkpoint.
|
|
810
|
+
scheduler_name = self.scheduler.__class__.__name__
|
|
811
|
+
_LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
|
|
812
|
+
|
|
813
|
+
elif not scheduler_object_exists and scheduler_state_exists:
|
|
814
|
+
# Case 3: State in checkpoint, but no scheduler provided.
|
|
815
|
+
_LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
|
|
816
|
+
raise ValueError()
|
|
817
|
+
|
|
818
|
+
# Restore callback states
|
|
819
|
+
for cb in self.callbacks:
|
|
820
|
+
if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
|
|
821
|
+
cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
|
|
822
|
+
_LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
|
|
823
|
+
|
|
824
|
+
_LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
|
|
825
|
+
|
|
826
|
+
except Exception as e:
|
|
827
|
+
_LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
|
|
828
|
+
raise
|
|
829
|
+
|
|
830
|
+
def fit(self,
        epochs: int = 10,
        batch_size: int = 10,
        shuffle: bool = True,
        resume_from_checkpoint: Optional[Union[str, Path]] = None):
    """
    Run the training/validation loop and hand back the training history.

    Every epoch calls ``self._train_step()`` and then ``self._validation_step()``; the
    keys returned by both are merged into one log dict that is passed to the
    ``on_epoch_end`` callbacks. The loop stops after the last epoch or as soon as a
    callback raises ``self.stop_training``.

    Returns the "History" callback dictionary (``self.history``).

    Args:
        epochs (int): Epoch number to train up to (inclusive). Training starts at
            ``self.start_epoch``, which resuming from a checkpoint may move forward.
        batch_size (int): The number of samples per batch.
        shuffle (bool): Whether to shuffle the training data at each epoch.
        resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
    """
    self.epochs = epochs
    self._batch_size = batch_size
    self._create_dataloaders(batch_size, shuffle)
    self.model.to(self.device)

    if resume_from_checkpoint:
        self._load_checkpoint(resume_from_checkpoint)

    # A previous run may have left the early-stopping flag raised.
    self.stop_training = False
    self._callbacks_hook('on_train_begin')

    first_epoch, last_epoch = self.start_epoch, self.epochs
    for epoch in range(first_epoch, last_epoch + 1):
        self.epoch = epoch
        logs = {}
        self._callbacks_hook('on_epoch_begin', epoch, logs=logs)

        # Train first, then validate; both contribute keys to the same epoch log.
        for step in (self._train_step, self._validation_step):
            logs.update(step())

        self._callbacks_hook('on_epoch_end', epoch, logs=logs)

        # Early stopping is signalled by a callback via the trainer flag.
        if self.stop_training:
            break

    self._callbacks_hook('on_train_end')
    return self.history
|
|
878
|
+
|
|
879
|
+
def _train_step(self):
    """
    Run one training epoch over ``self.train_loader``.

    Each batch is unpacked as ``(images, targets)``: ``images`` is a sequence of image
    tensors and ``targets`` a sequence of dicts of tensors. Both are moved to
    ``self.device``. The model is called in ``train()`` mode with the targets and returns
    a dict of losses, which are summed and back-propagated. Batches for which the model
    returns no losses are skipped (logged with a batch loss of 0).

    Returns:
        dict: ``{PyTorchLogKeys.TRAIN_LOSS: mean_loss}``. ``mean_loss`` is the
        sample-weighted mean of the summed batch losses over the batches that actually
        produced losses. It is NaN, with a warning logged, if no batch did (for example,
        an empty loader).
    """
    self.model.train()
    running_loss = 0.0
    # Count only samples that contributed a loss. Dividing by len(dataset) would bias the
    # mean whenever batches are skipped or dropped by the loader.
    samples_seen = 0

    for batch_idx, (images, targets) in enumerate(self.train_loader):  # type: ignore
        # images is a tuple of tensors, targets is a tuple of dicts
        batch_size = len(images)

        batch_logs = {
            PyTorchLogKeys.BATCH_INDEX: batch_idx,
            PyTorchLogKeys.BATCH_SIZE: batch_size,
        }
        self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)

        images = [img.to(self.device) for img in images]
        targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]

        self.optimizer.zero_grad()

        # Model returns a loss dict when in train() mode and targets are passed
        loss_dict = self.model(images, targets)

        if not loss_dict:
            _LOGGER.warning(f"Model returned no losses for batch {batch_idx}. Skipping.")
            batch_logs[PyTorchLogKeys.BATCH_LOSS] = 0
            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
            continue

        # Total loss is the plain sum of all loss terms returned by the model.
        loss: torch.Tensor = sum(loss_dict.values())  # type: ignore

        loss.backward()
        self.optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss * batch_size
        samples_seen += batch_size

        batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss  # type: ignore
        self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)

    if samples_seen == 0:
        # Avoid ZeroDivisionError and avoid reporting a misleading "perfect" 0.0 loss.
        _LOGGER.warning("No training batch produced a loss this epoch; reporting NaN training loss.")
        return {PyTorchLogKeys.TRAIN_LOSS: float('nan')}

    return {PyTorchLogKeys.TRAIN_LOSS: running_loss / samples_seen}
|
|
924
|
+
|
|
925
|
+
def _validation_step(self):
    """
    Compute the mean validation loss over ``self.test_loader``.

    The model only returns a loss dict in ``train()`` mode when targets are passed, so the
    model is kept in ``train()`` mode and run under ``torch.no_grad()`` (no gradients, no
    parameter updates). BatchNorm layers are temporarily switched to ``eval()``, so the
    validation forward passes use, and do not modify, their running statistics. They are
    restored to ``train()`` afterwards, even if an exception is raised.

    Returns:
        dict: ``{PyTorchLogKeys.VAL_LOSS: mean_loss}``. ``mean_loss`` is the
        sample-weighted mean of the summed batch losses over batches that produced losses.
        It is NaN, with a warning logged, if no batch did. This keeps loss-monitoring
        callbacks from treating an empty validation pass as a perfect score.
    """
    # Set to train mode even for validation loss calculation, as model internals
    # (e.g., proposals) might differ, but we still need loss_dict.
    self.model.train()

    # In train() mode BatchNorm updates running_mean/running_var on every forward pass,
    # even under no_grad(). Hold those layers in eval() so validation data cannot leak
    # into the model's normalization statistics.
    norm_types = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
                  torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm)
    frozen_norms = [m for m in self.model.modules() if isinstance(m, norm_types)]
    for module in frozen_norms:
        module.eval()

    running_loss = 0.0
    samples_seen = 0  # only samples whose batch produced a loss
    try:
        with torch.no_grad():
            for images, targets in self.test_loader:  # type: ignore
                batch_size = len(images)

                images = [img.to(self.device) for img in images]
                targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]

                loss_dict = self.model(images, targets)

                if not loss_dict:
                    _LOGGER.warning("Model returned no losses during validation step. Skipping batch.")
                    continue

                # Total loss is the plain sum of all loss terms returned by the model.
                loss: torch.Tensor = sum(loss_dict.values())  # type: ignore
                running_loss += loss.item() * batch_size
                samples_seen += batch_size
    finally:
        # The whole model was put in train() mode above, so train() is the state to restore.
        for module in frozen_norms:
            module.train()

    if samples_seen == 0:
        _LOGGER.warning("No validation batch produced a loss; reporting NaN validation loss.")
        return {PyTorchLogKeys.VAL_LOSS: float('nan')}

    return {PyTorchLogKeys.VAL_LOSS: running_loss / samples_seen}
|
|
953
|
+
|
|
954
|
+
def evaluate(self, save_dir: Union[str, Path], data: Optional[Union[DataLoader, Dataset]] = None):
    """
    Evaluates the model using object detection mAP metrics, then plots the training history.

    The model is put in ``eval()`` mode and run over the evaluation data without
    gradients. Predictions and targets are gathered on the CPU and passed to
    ``object_detection_metrics``; afterwards ``plot_losses`` is called on ``self.history``.

    Args:
        save_dir (str | Path): Directory to save all reports and plots.
        data (DataLoader | Dataset | None): The data to evaluate on. A ``DataLoader`` is
            used as-is. A ``Dataset`` is wrapped in a new non-shuffling ``DataLoader`` that
            uses the trainer's collate function. If None, defaults to the trainer's
            internal test_dataset.

    Raises:
        ValueError: If ``data`` is None and the trainer has no ``test_dataset``, or if
            ``data`` is not a ``DataLoader``, a ``Dataset`` or None.
    """
    def _make_loader(dataset):
        # Builds a deterministic (unshuffled) evaluation loader. Worker processes are
        # disabled on MPS, and pinned memory is only requested for CUDA.
        # NOTE(review): relies on self._batch_size, which fit() sets; confirm it is
        # initialised if evaluate() can be called before fit().
        return DataLoader(dataset,
                          batch_size=self._batch_size,
                          shuffle=False,
                          num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                          pin_memory=(self.device.type == "cuda"),
                          collate_fn=self.collate_fn)

    if isinstance(data, DataLoader):
        eval_loader = data
        dataset_for_names = getattr(data, 'dataset', None)
    elif isinstance(data, Dataset):
        eval_loader = _make_loader(data)
        dataset_for_names = data
    elif data is None:
        # Use the trainer's default test dataset
        if self.test_dataset is None:
            _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
            raise ValueError()
        eval_loader = _make_loader(self.test_dataset)
        dataset_for_names = self.test_dataset
    else:
        # Reject anything else explicitly instead of silently evaluating on test_dataset.
        _LOGGER.error(f"Cannot evaluate. Unsupported data type '{type(data).__name__}'; "
                      f"expected a DataLoader, a Dataset, or None.")
        raise ValueError()

    print("\n--- Model Evaluation ---")

    all_predictions = []
    all_targets = []

    self.model.eval()  # Set model to evaluation mode
    self.model.to(self.device)

    with torch.no_grad():
        for images, targets in eval_loader:
            images = [img.to(self.device) for img in images]

            # Model returns a list of prediction dicts when in eval() mode
            predictions = self.model(images)

            # Move predictions and targets to CPU for aggregation
            all_predictions.extend({k: v.to('cpu') for k, v in p.items()} for p in predictions)
            all_targets.extend({k: v.to('cpu') for k, v in t.items()} for t in targets)

    if not all_targets:
        _LOGGER.error("Evaluation failed: No data was processed.")
        return

    # Class names for the report: prefer `classes` on the dataset itself, then fall back to
    # a wrapped dataset (e.g. a Subset keeps the original under `.dataset`).
    class_names = getattr(dataset_for_names, 'classes', None)
    if class_names is None:
        class_names = getattr(getattr(dataset_for_names, 'dataset', None), 'classes', None)
    if class_names is None:
        _LOGGER.warning("Could not find 'classes' attribute on dataset. Per-class metrics will not be named.")

    object_detection_metrics(
        preds=all_predictions,
        targets=all_targets,
        save_dir=save_dir,
        class_names=class_names,
        print_output=False
    )

    print("\n--- Training History ---")
    plot_losses(self.history, save_dir=save_dir)
|
|
1048
|
+
|
|
1049
|
+
def _callbacks_hook(self, method_name: str, *args, **kwargs):
    """
    Dispatch ``method_name`` to every registered callback, in registration order.

    Positional and keyword arguments are forwarded unchanged. A callback that does not
    define ``method_name`` raises ``AttributeError``.
    """
    # Resolve each hook lazily, right before it is invoked.
    hooks = (getattr(cb, method_name) for cb in self.callbacks)
    for hook in hooks:
        hook(*args, **kwargs)
|
|
1054
|
+
|
|
1055
|
+
def to_cpu(self):
    """
    Relocate the model to the CPU and record the CPU as the trainer's device.

    Handy when a subsequent operation has to run on the CPU.
    """
    cpu = torch.device('cpu')
    self.device = cpu
    self.model.to(cpu)
    _LOGGER.info("Trainer and model moved to CPU.")
|
|
1064
|
+
|
|
1065
|
+
def to_device(self, device: str):
    """
    Relocate the model to ``device`` and store it as the trainer's device.

    The device string is validated via ``self._validate_device`` before use.

    Args:
        device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
    """
    target = self._validate_device(device)
    self.device = target
    self.model.to(target)
    _LOGGER.info(f"Trainer and model moved to {target}.")
|
|
1075
|
+
|
|
546
1076
|
|
|
547
1077
|
def info():
    """Report this module's public names (``__all__``) via ``_script_info``."""
    public_names = __all__
    _script_info(public_names)
|