dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
- dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +788 -0
- ml_tools/ML_datasetmaster.py +303 -448
- ml_tools/ML_evaluation.py +351 -93
- ml_tools/ML_evaluation_multi.py +139 -42
- ml_tools/ML_inference.py +290 -209
- ml_tools/ML_models.py +33 -106
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +219 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1604 -179
- ml_tools/ML_utilities.py +351 -4
- ml_tools/ML_vision_datasetmaster.py +1540 -0
- ml_tools/ML_vision_evaluation.py +284 -0
- ml_tools/ML_vision_inference.py +405 -0
- ml_tools/ML_vision_models.py +641 -0
- ml_tools/ML_vision_transformers.py +284 -0
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/_keys.py +171 -0
- ml_tools/_schema.py +1 -1
- ml_tools/custom_logger.py +37 -14
- ml_tools/data_exploration.py +502 -93
- ml_tools/ensemble_evaluation.py +54 -11
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/math_utilities.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/serde.py +2 -2
- ml_tools/utilities.py +192 -4
- dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/keys.py +0 -87
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_trainer.py
CHANGED

@@ -1,79 +1,101 @@
-from typing import List, Literal, Union, Optional
+from typing import List, Literal, Union, Optional, Callable, Dict, Any
 from pathlib import Path
 from torch.utils.data import DataLoader, Dataset
 import torch
 from torch import nn
 import numpy as np
+from abc import ABC, abstractmethod

-from .
+from .path_manager import make_fullpath
+from .ML_callbacks import _Callback, History, TqdmProgressBar, DragonModelCheckpoint, DragonEarlyStopping, DragonLRScheduler
 from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
 from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
+from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
+from .ML_sequence_evaluation import sequence_to_sequence_metrics, sequence_to_value_metrics
+from .ML_configuration import (RegressionMetricsFormat,
+                               MultiTargetRegressionMetricsFormat,
+                               BinaryClassificationMetricsFormat,
+                               MultiClassClassificationMetricsFormat,
+                               BinaryImageClassificationMetricsFormat,
+                               MultiClassImageClassificationMetricsFormat,
+                               MultiLabelBinaryClassificationMetricsFormat,
+                               BinarySegmentationMetricsFormat,
+                               MultiClassSegmentationMetricsFormat,
+                               SequenceValueMetricsFormat,
+                               SequenceSequenceMetricsFormat,
+
+                               FinalizeBinaryClassification,
+                               FinalizeBinarySegmentation,
+                               FinalizeBinaryImageClassification,
+                               FinalizeMultiClassClassification,
+                               FinalizeMultiClassImageClassification,
+                               FinalizeMultiClassSegmentation,
+                               FinalizeMultiLabelBinaryClassification,
+                               FinalizeMultiTargetRegression,
+                               FinalizeRegression,
+                               FinalizeObjectDetection,
+                               FinalizeSequencePrediction)
+
 from ._script_info import _script_info
-from .
+from ._keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys, MLTaskKeys, MagicWords, DragonTrainerKeys
 from ._logger import _LOGGER
-from .path_manager import make_fullpath


 __all__ = [
-    "
+    "DragonTrainer",
+    "DragonDetectionTrainer",
+    "DragonSequenceTrainer"
 ]

-
-
-
-                 kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification"],
-                 criterion: nn.Module, optimizer: torch.optim.Optimizer,
-                 device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
-        """
-        Automates the training process of a PyTorch Model.
-
-        Built-in Callbacks: `History`, `TqdmProgressBar`
-
-        Args:
-            model (nn.Module): The PyTorch model to train.
-            train_dataset (Dataset): The training dataset.
-            test_dataset (Dataset): The testing/validation dataset.
-            kind (str): Can be 'regression', 'classification', 'multi_target_regression', or 'multi_label_classification'.
-            criterion (nn.Module): The loss function.
-            optimizer (torch.optim.Optimizer): The optimizer.
-            device (str): The device to run training on ('cpu', 'cuda', 'mps').
-            dataloader_workers (int): Subprocesses for data loading.
-            callbacks (List[Callback] | None): A list of callbacks to use during training.
-
-        Note:
-            - For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`.
+class _BaseDragonTrainer(ABC):
+    """
+    Abstract base class for Dragon Trainers.

-
-
-
-
-
-
+    Handles the common training loop orchestration, checkpointing, callback
+    management, and device handling. Subclasses must implement the
+    task-specific logic (dataloaders, train/val steps, evaluation).
+    """
+    def __init__(self,
+                 model: nn.Module,
+                 optimizer: torch.optim.Optimizer,
+                 device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                 dataloader_workers: int = 2,
+                 checkpoint_callback: Optional[DragonModelCheckpoint] = None,
+                 early_stopping_callback: Optional[DragonEarlyStopping] = None,
+                 lr_scheduler_callback: Optional[DragonLRScheduler] = None,
+                 extra_callbacks: Optional[List[_Callback]] = None):

         self.model = model
-        self.train_dataset = train_dataset
-        self.test_dataset = test_dataset
-        self.kind = kind
-        self.criterion = criterion
         self.optimizer = optimizer
         self.scheduler = None
         self.device = self._validate_device(device)
         self.dataloader_workers = dataloader_workers

-        # Callback handler
+        # Callback handler
         default_callbacks = [History(), TqdmProgressBar()]
-
+
+        self._checkpoint_callback = None
+        if checkpoint_callback:
+            default_callbacks.append(checkpoint_callback)
+            self._checkpoint_callback = checkpoint_callback
+        if early_stopping_callback:
+            default_callbacks.append(early_stopping_callback)
+        if lr_scheduler_callback:
+            default_callbacks.append(lr_scheduler_callback)
+
+        user_callbacks = extra_callbacks if extra_callbacks is not None else []
         self.callbacks = default_callbacks + user_callbacks
         self._set_trainer_on_callbacks()

         # Internal state
-        self.train_loader = None
-        self.
-        self.history = {}
+        self.train_loader: Optional[DataLoader] = None
+        self.validation_loader: Optional[DataLoader] = None
+        self.history: Dict[str, List[Any]] = {}
         self.epoch = 0
         self.epochs = 0 # Total epochs for the fit run
         self.start_epoch = 1
         self.stop_training = False
+        self._batch_size = 10

     def _validate_device(self, device: str) -> torch.device:
         """Validates the selected device and returns a torch.device object."""
@@ -91,32 +113,10 @@ class MLTrainer:
         for callback in self.callbacks:
             callback.set_trainer(self)

-    def _create_dataloaders(self, batch_size: int, shuffle: bool):
-        """Initializes the DataLoaders."""
-        # Ensure stability on MPS devices by setting num_workers to 0
-        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
-
-        self.train_loader = DataLoader(
-            dataset=self.train_dataset,
-            batch_size=batch_size,
-            shuffle=shuffle,
-            num_workers=loader_workers,
-            pin_memory=("cuda" in self.device.type),
-            drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
-        )
-
-        self.test_loader = DataLoader(
-            dataset=self.test_dataset,
-            batch_size=batch_size,
-            shuffle=False,
-            num_workers=loader_workers,
-            pin_memory=("cuda" in self.device.type)
-        )
-
     def _load_checkpoint(self, path: Union[str, Path]):
         """Loads a training checkpoint to resume training."""
         p = make_fullpath(path, enforce="file")
-        _LOGGER.info(f"Loading checkpoint from '{p.name}'
+        _LOGGER.info(f"Loading checkpoint from '{p.name}'...")

         try:
             checkpoint = torch.load(p, map_location=self.device)
@@ -127,7 +127,16 @@ class MLTrainer:

             self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
             self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
-            self.
+            self.epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0)
+            self.start_epoch = self.epoch + 1 # Resume on the *next* epoch
+
+            # --- Load History ---
+            if PyTorchCheckpointKeys.HISTORY in checkpoint:
+                self.history = checkpoint[PyTorchCheckpointKeys.HISTORY]
+                _LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
+            else:
+                _LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
+                self.history = {} # Ensure it's at least an empty dict

             # --- Scheduler State Loading Logic ---
             scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
@@ -157,7 +166,7 @@ class MLTrainer:

             # Restore callback states
             for cb in self.callbacks:
-                if isinstance(cb,
+                if isinstance(cb, DragonModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
                     cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
                     _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")

@@ -168,7 +177,8 @@ class MLTrainer:
             raise

     def fit(self,
-
+            save_dir: Union[str,Path],
+            epochs: int = 100,
             batch_size: int = 10,
             shuffle: bool = True,
             resume_from_checkpoint: Optional[Union[str, Path]] = None):
@@ -178,20 +188,15 @@ class MLTrainer:
         Returns the "History" callback dictionary.

         Args:
+            save_dir (str | Path): Directory to save the loss plot.
             epochs (int): The total number of epochs to train for.
             batch_size (int): The number of samples per batch.
             shuffle (bool): Whether to shuffle the training data at each epoch.
             resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
-
-        Note:
-            For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
-            automatically aligns the model's output tensor with the target tensor's
-            shape using `output.view_as(target)`. This handles the common case
-            where a model outputs a shape of `[batch_size, 1]` and the target has a
-            shape of `[batch_size]`.
         """
         self.epochs = epochs
-        self.
+        self._batch_size = batch_size
+        self._create_dataloaders(self._batch_size, shuffle) # type: ignore
         self.model.to(self.device)

         if resume_from_checkpoint:
@@ -202,11 +207,19 @@ class MLTrainer:

         self._callbacks_hook('on_train_begin')

+        if not self.train_loader:
+            _LOGGER.error("Train loader is not initialized.")
+            raise ValueError()
+
+        if not self.validation_loader:
+            _LOGGER.error("Validation loader is not initialized.")
+            raise ValueError()
+
         for epoch in range(self.start_epoch, self.epochs + 1):
             self.epoch = epoch
-            epoch_logs = {}
+            epoch_logs: Dict[str, Any] = {}
             self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
-
+
             train_logs = self._train_step()
             epoch_logs.update(train_logs)

@@ -220,11 +233,204 @@ class MLTrainer:
                 break

         self._callbacks_hook('on_train_end')
+
+        # Training History
+        plot_losses(self.history, save_dir=save_dir)
+
         return self.history
+
+    def _callbacks_hook(self, method_name: str, *args, **kwargs):
+        """Calls the specified method on all callbacks."""
+        for callback in self.callbacks:
+            method = getattr(callback, method_name)
+            method(*args, **kwargs)
+
+    def to_cpu(self):
+        """
+        Moves the model to the CPU and updates the trainer's device setting.
+
+        This is useful for running operations that require the CPU.
+        """
+        self.device = torch.device('cpu')
+        self.model.to(self.device)
+        _LOGGER.info("Trainer and model moved to CPU.")
+
+    def to_device(self, device: str):
+        """
+        Moves the model to the specified device and updates the trainer's device setting.
+
+        Args:
+            device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+        """
+        self.device = self._validate_device(device)
+        self.model.to(self.device)
+        _LOGGER.info(f"Trainer and model moved to {self.device}.")
+
+    def _load_model_state_for_finalizing(self, model_checkpoint: Union[Path, Literal['latest', 'current']]):
+        """
+        Private helper to load the correct model state_dict based on user's choice.
+        This is called by finalize_model_training() in subclasses.
+        """
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+            raise ValueError()
+        elif model_checkpoint == MagicWords.CURRENT:
+            pass
+        else:
+            _LOGGER.error(f"Unknown 'model_checkpoint' received '{model_checkpoint}'.")
+            raise ValueError()
+
+    # --- Abstract Methods ---
+    # These must be implemented by subclasses
+
+    @abstractmethod
+    def _create_dataloaders(self, batch_size: int, shuffle: bool):
+        """Initializes the DataLoaders."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def _train_step(self) -> Dict[str, float]:
+        """Runs a single training epoch."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def _validation_step(self) -> Dict[str, float]:
+        """Runs a single validation epoch."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def evaluate(self, *args, **kwargs):
+        """Runs the full model evaluation."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def _evaluate(self, *args, **kwargs):
+        """Internal evaluation helper."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def finalize_model_training(self, *args, **kwargs):
+        """Saves the finalized model for inference."""
+        raise NotImplementedError
+
+
+# --- DragonTrainer ----
+class DragonTrainer(_BaseDragonTrainer):
+    def __init__(self,
+                 model: nn.Module,
+                 train_dataset: Dataset,
+                 validation_dataset: Dataset,
+                 kind: Literal["regression", "binary classification", "multiclass classification",
+                               "multitarget regression", "multilabel binary classification",
+                               "binary segmentation", "multiclass segmentation", "binary image classification", "multiclass image classification"],
+                 optimizer: torch.optim.Optimizer,
+                 device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                 checkpoint_callback: Optional[DragonModelCheckpoint],
+                 early_stopping_callback: Optional[DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 extra_callbacks: Optional[List[_Callback]] = None,
+                 criterion: Union[nn.Module,Literal["auto"]] = "auto",
+                 dataloader_workers: int = 2):
+        """
+        Automates the training process of a PyTorch Model.
+
+        Built-in Callbacks: `History`, `TqdmProgressBar`
+
+        Args:
+            model (nn.Module): The PyTorch model to train.
+            train_dataset (Dataset): The training dataset.
+            validation_dataset (Dataset): The validation dataset.
+            kind (str): Used to redirect to the correct process.
+            criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task
+            optimizer (torch.optim.Optimizer): The optimizer.
+            device (str): The device to run training on ('cpu', 'cuda', 'mps').
+            dataloader_workers (int): Subprocesses for data loading.
+            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
+
+        Note:
+            - For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`. The model should output as many logits as existing targets.
+
+            - For **single-label, binary classification**, `nn.BCEWithLogitsLoss` is the standard choice. The model should output a single logit.

+            - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice. The model should output as many logits as existing classes.
+
+            - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem. The model should output 1 logit per binary target.
+
+            - For **binary segmentation** tasks, `nn.BCEWithLogitsLoss` is common. The model should output a single logit.
+
+            - for **multiclass segmentation** tasks, `nn.CrossEntropyLoss` is the standard. The model should output as many logits as existing classes.
+        """
+        # Call the base class constructor with common parameters
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            dataloader_workers=dataloader_workers,
+            checkpoint_callback=checkpoint_callback,
+            early_stopping_callback=early_stopping_callback,
+            lr_scheduler_callback=lr_scheduler_callback,
+            extra_callbacks=extra_callbacks
+        )
+
+        if kind not in [MLTaskKeys.REGRESSION,
+                        MLTaskKeys.BINARY_CLASSIFICATION,
+                        MLTaskKeys.MULTICLASS_CLASSIFICATION,
+                        MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION,
+                        MLTaskKeys.MULTITARGET_REGRESSION,
+                        MLTaskKeys.BINARY_SEGMENTATION,
+                        MLTaskKeys.MULTICLASS_SEGMENTATION,
+                        MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
+                        MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+            raise ValueError(f"'{kind}' is not a valid task type.")
+
+        self.train_dataset = train_dataset
+        self.validation_dataset = validation_dataset
+        self.kind = kind
+        self._classification_threshold: float = 0.5
+
+        # loss function
+        if criterion == "auto":
+            if kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
+                self.criterion = nn.MSELoss()
+            elif kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
+                self.criterion = nn.BCEWithLogitsLoss()
+            elif kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
+                self.criterion = nn.CrossEntropyLoss()
+        else:
+            self.criterion = criterion
+
+    def _create_dataloaders(self, batch_size: int, shuffle: bool):
+        """Initializes the DataLoaders."""
+        # Ensure stability on MPS devices by setting num_workers to 0
+        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+        self.train_loader = DataLoader(
+            dataset=self.train_dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type),
+            drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
+        )
+
+        self.validation_loader = DataLoader(
+            dataset=self.validation_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type)
+        )
+
     def _train_step(self):
         self.model.train()
         running_loss = 0.0
+        total_samples = 0
+
         for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
             # Create a log dictionary for the batch
             batch_logs = {
@@ -238,9 +444,21 @@ class MLTrainer:

             output = self.model(features)

-            #
-
-
+            # --- Label Type/Shape Correction ---
+            # Cast target to float for BCE-based losses
+            if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
+                target = target.float()
+
+            # Reshape output to match target for single-logit tasks
+            if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                # If model outputs [N, 1] and target is [N], squeeze output
+                if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                    output = output.squeeze(1)
+
+            if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
+                if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
+                    output = output.squeeze(1)

             loss = self.criterion(output, target)

@@ -249,34 +467,58 @@ class MLTrainer:

             # Calculate batch loss and update running loss for the epoch
             batch_loss = loss.item()
-
+            batch_size = features.size(0)
+            running_loss += batch_loss * batch_size # Accumulate total loss
+            total_samples += batch_size # total samples

             # Add the batch loss to the logs and call the end-of-batch hook
             batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
             self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+        if total_samples == 0:
+            _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
+            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}

-        return {PyTorchLogKeys.TRAIN_LOSS: running_loss /
+        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples} # type: ignore

     def _validation_step(self):
         self.model.eval()
         running_loss = 0.0
+
         with torch.no_grad():
-            for features, target in self.
+            for features, target in self.validation_loader: # type: ignore
                 features, target = features.to(self.device), target.to(self.device)

                 output = self.model(features)
-
-
-
+
+                # --- Label Type/Shape Correction ---
+                # Cast target to float for BCE-based losses
+                if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
+                    target = target.float()
+
+                # Reshape output to match target for single-logit tasks
+                if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                    # If model outputs [N, 1] and target is [N], squeeze output
+                    if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                        output = output.squeeze(1)
+
+                if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                    # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
+                    if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
+                        output = output.squeeze(1)

                 loss = self.criterion(output, target)

                 running_loss += loss.item() * features.size(0)
+
+        if not self.validation_loader.dataset: # type: ignore
+            _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+            return {PyTorchLogKeys.VAL_LOSS: 0.0}

-        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.
+        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)} # type: ignore
         return logs

-    def _predict_for_eval(self, dataloader: DataLoader
+    def _predict_for_eval(self, dataloader: DataLoader):
         """
         Private method to yield model predictions batch by batch for evaluation.

@@ -287,77 +529,301 @@ class MLTrainer:
         """
         self.model.eval()
         self.model.to(self.device)
+
         with torch.no_grad():
             for features, target in dataloader:
                 features = features.to(self.device)
                 output = self.model(features).cpu()
-                y_true_batch = target.numpy()

                 y_pred_batch = None
                 y_prob_batch = None
+                y_true_batch = None

-                if self.kind in [
+                if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
                     y_pred_batch = output.numpy()
+                    y_true_batch = target.numpy()
+
+                elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                    # Assumes model output is [N, 1] (a single logit)
+                    # Squeeze output from [N, 1] to [N] if necessary
+                    if output.ndim == 2 and output.shape[1] == 1:
+                        output = output.squeeze(1)
+
+                    probs_pos = torch.sigmoid(output) # Probability of positive class
+                    preds = (probs_pos >= self._classification_threshold).int()
+                    y_pred_batch = preds.numpy()
+                    # For metrics (like ROC AUC), we often need probs for *both* classes
+                    # Create an [N, 2] array: [prob_class_0, prob_class_1]
+                    probs_neg = 1.0 - probs_pos
+                    y_prob_batch = torch.stack([probs_neg, probs_pos], dim=1).numpy()
+                    y_true_batch = target.numpy()

-                elif self.kind
+                elif self.kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+                    num_classes = output.shape[1]
+                    if num_classes < 3:
+                        # Optional: warn the user they are using the wrong kind
+                        wrong_class = MLTaskKeys.MULTICLASS_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
+                        recommended_class = MLTaskKeys.BINARY_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.BINARY_IMAGE_CLASSIFICATION
+                        _LOGGER.warning(f"'{wrong_class}' kind used with {num_classes} classes. Consider using '{recommended_class}' instead.")
+
                     probs = torch.softmax(output, dim=1)
                     preds = torch.argmax(probs, dim=1)
                     y_pred_batch = preds.numpy()
                     y_prob_batch = probs.numpy()
+                    y_true_batch = target.numpy()

-                elif self.kind ==
+                elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
                     probs = torch.sigmoid(output)
-                    preds = (probs >=
+                    preds = (probs >= self._classification_threshold).int()
                     y_pred_batch = preds.numpy()
                     y_prob_batch = probs.numpy()
+                    y_true_batch = target.numpy()
+
+                elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                    # Assumes model output is [N, 1, H, W] (logits for positive class)
+                    probs_pos = torch.sigmoid(output) # Shape [N, 1, H, W]
+                    preds = (probs_pos >= self._classification_threshold).int() # Shape [N, 1, H, W]
+
+                    # Squeeze preds to [N, H, W] (class indices 0 or 1)
+                    y_pred_batch = preds.squeeze(1).numpy()

-
+                    # Create [N, 2, H, W] probs for consistency
+                    probs_neg = 1.0 - probs_pos
+                    y_prob_batch = torch.cat([probs_neg, probs_pos], dim=1).numpy()
+
+                    # Handle target shape [N, 1, H, W] -> [N, H, W]
+                    if target.ndim == 4 and target.shape[1] == 1:
+                        target = target.squeeze(1)
+                    y_true_batch = target.numpy()
+
+                elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION:
+                    # output shape [N, C, H, W]
+                    probs = torch.softmax(output, dim=1)
+                    preds = torch.argmax(probs, dim=1) # shape [N, H, W]
+                    y_pred_batch = preds.numpy()
+                    y_prob_batch = probs.numpy() # Probs are [N, C, H, W]
+
+                    # Handle target shape [N, 1, H, W] -> [N, H, W]
+                    if target.ndim == 4 and target.shape[1] == 1:
+                        target = target.squeeze(1)
+                    y_true_batch = target.numpy()

-
+                yield y_pred_batch, y_prob_batch, y_true_batch
+
+    def evaluate(self,
+                 save_dir: Union[str, Path],
+                 model_checkpoint: Union[Path, Literal["latest", "current"]],
+                 classification_threshold: Optional[float] = None,
+                 test_data: Optional[Union[DataLoader, Dataset]] = None,
+                 val_format_configuration: Optional[Union[
+                     RegressionMetricsFormat,
+                     MultiTargetRegressionMetricsFormat,
+                     BinaryClassificationMetricsFormat,
+                     MultiClassClassificationMetricsFormat,
+                     BinaryImageClassificationMetricsFormat,
+                     MultiClassImageClassificationMetricsFormat,
+                     MultiLabelBinaryClassificationMetricsFormat,
+                     BinarySegmentationMetricsFormat,
+                     MultiClassSegmentationMetricsFormat
+                 ]]=None,
+                 test_format_configuration: Optional[Union[
+                     RegressionMetricsFormat,
+                     MultiTargetRegressionMetricsFormat,
+                     BinaryClassificationMetricsFormat,
+                     MultiClassClassificationMetricsFormat,
+                     BinaryImageClassificationMetricsFormat,
+                     MultiClassImageClassificationMetricsFormat,
+                     MultiLabelBinaryClassificationMetricsFormat,
+                     BinarySegmentationMetricsFormat,
+                     MultiClassSegmentationMetricsFormat,
+                 ]]=None):
         """
         Evaluates the model, routing to the correct evaluation function based on task `kind`.

         Args:
+            model_checkpoint ('auto' | Path | None):
+                - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+                - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+                - If 'current', use the current state of the trained model up the latest trained epoch.
             save_dir (str | Path): Directory to save all reports and plots.
-
-
+            classification_threshold (float | None): Used for tasks using a binary approach (binary classification, binary segmentation, multilabel binary classification)
+            test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
+            val_format_configuration (object): Optional configuration for metric format output for the validation set.
+            test_format_configuration (object): Optional configuration for metric format output for the test set.
+        """
+        # Validate model checkpoint
+        if isinstance(model_checkpoint, Path):
+            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+        elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+            checkpoint_validated = model_checkpoint
+        else:
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
+            raise ValueError()
+
+        # Validate classification threshold
+        if self.kind not in MLTaskKeys.ALL_BINARY_TASKS:
+            # dummy value for tasks that do not need it
+            threshold_validated = 0.5
+        elif classification_threshold is None:
+            # it should have been provided for binary tasks
+            _LOGGER.error(f"The classification threshold must be provided for '{self.kind}'.")
+            raise ValueError()
+        elif classification_threshold <= 0.0 or classification_threshold >= 1.0:
+            # Invalid float
+            _LOGGER.error(f"A classification threshold of {classification_threshold} is invalid. Must be in the range (0.0 - 1.0).")
+            raise ValueError()
+        else:
+            threshold_validated = classification_threshold
+
+        # Validate val configuration
+        if val_format_configuration is not None:
+            if not isinstance(val_format_configuration, (RegressionMetricsFormat,
+                                                          MultiTargetRegressionMetricsFormat,
+                                                          BinaryClassificationMetricsFormat,
+                                                          MultiClassClassificationMetricsFormat,
+                                                          BinaryImageClassificationMetricsFormat,
+                                                          MultiClassImageClassificationMetricsFormat,
+                                                          MultiLabelBinaryClassificationMetricsFormat,
+                                                          BinarySegmentationMetricsFormat,
+                                                          MultiClassSegmentationMetricsFormat)):
+                _LOGGER.error(f"Invalid 'format_configuration': '{type(val_format_configuration)}'.")
+                raise ValueError()
+            else:
+                val_configuration_validated = val_format_configuration
+        else: # config is None
+            val_configuration_validated = None
+
+        # Validate directory
+        save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+        # Validate test data and dispatch
+        if test_data is not None:
+            if not isinstance(test_data, (DataLoader, Dataset)):
+                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                raise ValueError()
+            test_data_validated = test_data
+
+            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+            self._evaluate(save_dir=validation_metrics_path,
+                           model_checkpoint=checkpoint_validated,
+                           classification_threshold=threshold_validated,
+                           data=None,
+                           format_configuration=val_configuration_validated)
+
+            # Validate test configuration
+            if test_format_configuration is not None:
+                if not isinstance(test_format_configuration, (RegressionMetricsFormat,
+                                                               MultiTargetRegressionMetricsFormat,
+                                                               BinaryClassificationMetricsFormat,
+                                                               MultiClassClassificationMetricsFormat,
+                                                               BinaryImageClassificationMetricsFormat,
+                                                               MultiClassImageClassificationMetricsFormat,
+                                                               MultiLabelBinaryClassificationMetricsFormat,
+                                                               BinarySegmentationMetricsFormat,
+                                                               MultiClassSegmentationMetricsFormat)):
+                    warning_message_type = f"Invalid test_format_configuration': '{type(test_format_configuration)}'."
+                    if val_configuration_validated is not None:
+                        warning_message_type += " 'val_format_configuration' will be used for the test set metrics output."
+                        test_configuration_validated = val_configuration_validated
+                    else:
+                        warning_message_type += " Using default format."
+                        test_configuration_validated = None
+                    _LOGGER.warning(warning_message_type)
+                else:
+                    test_configuration_validated = test_format_configuration
+            else: #config is None
+                test_configuration_validated = None
+
+            # Dispatch test set
+            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+            self._evaluate(save_dir=test_metrics_path,
+                           model_checkpoint="current",
+                           classification_threshold=threshold_validated,
+                           data=test_data_validated,
+                           format_configuration=test_configuration_validated)
+        else:
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+            self._evaluate(save_dir=save_path,
+                           model_checkpoint=checkpoint_validated,
+                           classification_threshold=threshold_validated,
+                           data=None,
+                           format_configuration=val_configuration_validated)
+
+    def _evaluate(self,
+                  save_dir: Union[str, Path],
+                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  classification_threshold: float,
+                  data: Optional[Union[DataLoader, Dataset]],
+                  format_configuration: Optional[Union[
+                      RegressionMetricsFormat,
+                      MultiTargetRegressionMetricsFormat,
+                      BinaryClassificationMetricsFormat,
+                      MultiClassClassificationMetricsFormat,
+                      BinaryImageClassificationMetricsFormat,
+                      MultiClassImageClassificationMetricsFormat,
+                      MultiLabelBinaryClassificationMetricsFormat,
+                      BinarySegmentationMetricsFormat,
+                      MultiClassSegmentationMetricsFormat
+                  ]]=None):
+        """
+        Changed to a private helper function.
         """
-
+        dataset_for_artifacts = None
         eval_loader = None
-
+
+        # set threshold
+        self._classification_threshold = classification_threshold
+
+        # load model checkpoint
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+            raise ValueError()
+
+        # Dataloader
         if isinstance(data, DataLoader):
             eval_loader = data
             # Try to get the dataset from the loader for fetching target names
             if hasattr(data, 'dataset'):
-
+                dataset_for_artifacts = data.dataset # type: ignore
         elif isinstance(data, Dataset):
             # Create a new loader from the provided dataset
             eval_loader = DataLoader(data,
-                                     batch_size=
+                                     batch_size=self._batch_size,
                                      shuffle=False,
                                      num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                      pin_memory=(self.device.type == "cuda"))
-
+            dataset_for_artifacts = data
         else: # data is None, use the trainer's default test dataset
-            if self.
-                _LOGGER.error("Cannot evaluate. No data provided and no
+            if self.validation_dataset is None:
+                _LOGGER.error("Cannot evaluate. No data provided and no validation dataset available in the trainer.")
                 raise ValueError()
             # Create a fresh DataLoader from the test_dataset
-            eval_loader = DataLoader(self.
-                                     batch_size=
+            eval_loader = DataLoader(self.validation_dataset,
+                                     batch_size=self._batch_size,
                                      shuffle=False,
                                      num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                      pin_memory=(self.device.type == "cuda"))
-
+
+            dataset_for_artifacts = self.validation_dataset

         if eval_loader is None:
             _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
             raise ValueError()

-        print("\n--- Model Evaluation ---")
+        # print("\n--- Model Evaluation ---")

         all_preds, all_probs, all_true = [], [], []
-        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader
+        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
             if y_pred_b is not None: all_preds.append(y_pred_b)
             if y_prob_b is not None: all_probs.append(y_prob_b)
             if y_true_b is not None: all_true.append(y_true_b)
@@ -371,24 +837,83 @@ class MLTrainer:
         y_prob = np.concatenate(all_probs) if all_probs else None

         # --- Routing Logic ---
-
-
-
-
-
-
-
+        # Single-target regression
+        if self.kind == MLTaskKeys.REGRESSION:
+            # Check configuration
+            config = None
+            if format_configuration and isinstance(format_configuration, RegressionMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
+
+            regression_metrics(y_true=y_true.flatten(),
+                               y_pred=y_pred.flatten(),
+                               save_dir=save_dir,
+                               config=config)
+
+        # single target classification
+        elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION,
+                           MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
+                           MLTaskKeys.MULTICLASS_CLASSIFICATION,
+                           MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+            # get the class map if it exists
+            try:
+                class_map = dataset_for_artifacts.class_map # type: ignore
+            except AttributeError:
+                _LOGGER.warning(f"Dataset has no 'class_map' attribute. Using generics.")
+                class_map = None
+            else:
+                if not isinstance(class_map, dict):
+                    _LOGGER.warning(f"Dataset has a 'class_map' attribute, but it is not a dictionary: '{type(class_map)}'.")
+                    class_map = None
+
+            # Check configuration
+            config = None
+            if format_configuration:
+                if self.kind == MLTaskKeys.BINARY_CLASSIFICATION and isinstance(format_configuration, BinaryClassificationMetricsFormat):
+                    config = format_configuration
+                elif self.kind == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION and isinstance(format_configuration, BinaryImageClassificationMetricsFormat):
+                    config = format_configuration
+                elif self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION and isinstance(format_configuration, MultiClassClassificationMetricsFormat):
+                    config = format_configuration
+                elif self.kind == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION and isinstance(format_configuration, MultiClassImageClassificationMetricsFormat):
+                    config = format_configuration
+                else:
+                    _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
+
+            classification_metrics(save_dir=save_dir,
+                                   y_true=y_true,
+                                   y_pred=y_pred,
+                                   y_prob=y_prob,
+                                   class_map=class_map,
+                                   config=config)
+
+        # multitarget regression
+        elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION:
             try:
-                target_names =
+                target_names = dataset_for_artifacts.target_names # type: ignore
             except AttributeError:
                 num_targets = y_true.shape[1]
                 target_names = [f"target_{i}" for i in range(num_targets)]
                 _LOGGER.warning(f"Dataset has no 'target_names' attribute. Using generic names.")
-
-
-
+
+            # Check configuration
+            config = None
+            if format_configuration and isinstance(format_configuration, MultiTargetRegressionMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
+
+            multi_target_regression_metrics(y_true=y_true,
+                                            y_pred=y_pred,
+                                            target_names=target_names,
+                                            save_dir=save_dir,
+                                            config=config)
+
+        # multi-label binary classification
+        elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
             try:
-                target_names =
+                target_names = dataset_for_artifacts.target_names # type: ignore
             except AttributeError:
                 num_targets = y_true.shape[1]
                 target_names = [f"label_{i}" for i in range(num_targets)]
@@ -397,10 +922,56 @@ class MLTrainer:
             if y_prob is None:
                 _LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
                 return
-
+
+            # Check configuration
+            config = None
+            if format_configuration and isinstance(format_configuration, MultiLabelBinaryClassificationMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")

-
-
+            multi_label_classification_metrics(y_true=y_true,
+                                               y_pred=y_pred,
+                                               y_prob=y_prob,
+                                               target_names=target_names,
+                                               save_dir=save_dir,
+                                               config=config)
+
+        # Segmentation tasks
+        elif self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
+            class_names = None
+            try:
+                # Try to get 'classes' from VisionDatasetMaker
+                if hasattr(dataset_for_artifacts, 'classes'):
+                    class_names = dataset_for_artifacts.classes # type: ignore
+                # Fallback for Subset
+                elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'): # type: ignore
+                    class_names = dataset_for_artifacts.dataset.classes # type: ignore
+            except AttributeError:
+                pass # class_names is still None
+
+            if class_names is None:
+                try:
+                    # Fallback to 'target_names'
+                    class_names = dataset_for_artifacts.target_names # type: ignore
+                except AttributeError:
+                    # Fallback to inferring from labels
+                    labels = np.unique(y_true)
+                    class_names = [f"Class {i}" for i in labels]
+                    _LOGGER.warning(f"Dataset has no 'classes' or 'target_names' attribute. Using generic names.")
+
+            # Check configuration
+            config = None
+            if format_configuration and isinstance(format_configuration, (BinarySegmentationMetricsFormat, MultiClassSegmentationMetricsFormat)):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
+
+            segmentation_metrics(y_true=y_true,
+                                 y_pred=y_pred,
+                                 save_dir=save_dir,
+                                 class_names=class_names,
+                                 config=config)

     def explain(self,
                 save_dir: Union[str,Path],
@@ -408,7 +979,7 @@ class MLTrainer:
                 n_samples: int = 300,
                 feature_names: Optional[List[str]] = None,
                 target_names: Optional[List[str]] = None,
-                explainer_type: Literal['deep', 'kernel'] = '
+                explainer_type: Literal['deep', 'kernel'] = 'kernel'):
         """
         Explains model predictions using SHAP and saves all artifacts.

@@ -422,41 +993,59 @@ class MLTrainer:
             explain_dataset (Dataset | None): A specific dataset to explain.
                 If None, the trainer's test dataset is used.
             n_samples (int): The number of samples to use for both background and explanation.
-            feature_names (list[str] | None): Feature names.
+            feature_names (list[str] | None): Feature names. If None, the names will be extracted from the Dataset and raise an error on failure.
             target_names (list[str] | None): Target names for multi-target tasks.
             save_dir (str | Path): Directory to save all SHAP artifacts.
             explainer_type (Literal['deep', 'kernel']): The explainer to use.
-                - 'deep':
+                - 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
                 - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples'< 100.
-        """
-        #
+        """
+        # memory efficient helper
         def _get_random_sample(dataset: Dataset, num_samples: int):
+            """
+            Memory-efficiently samples data from a dataset.
+            """
             if dataset is None:
                 return None

+            dataset_len = len(dataset) # type: ignore
+            if dataset_len == 0:
+                return None
+
             # For MPS devices, num_workers must be 0 to ensure stability
             loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers

+            # Ensure batch_size is not larger than the dataset itself
+            batch_size = min(num_samples, 64, dataset_len)
+
             loader = DataLoader(
                 dataset,
-                batch_size=
-                shuffle=
+                batch_size=batch_size,
+                shuffle=True, # Shuffle to get random samples
                 num_workers=loader_workers
             )

-
-
+            collected_features = []
+            num_collected = 0
+
+            for features, _ in loader:
+                collected_features.append(features)
+                num_collected += features.size(0)
+                if num_collected >= num_samples:
+                    break # Stop once we have enough samples
+
+            if not collected_features:
                 return None

-            full_data = torch.cat(
+            full_data = torch.cat(collected_features, dim=0)

-
-
+            # If we collected more than needed, trim it down
+            if full_data.size(0) > num_samples:
+                return full_data[:num_samples]

-
-
-
-        print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
+            return full_data
+
+        # print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")

         # 1. Get background data from the trainer's train_dataset
         background_data = _get_random_sample(self.train_dataset, n_samples)
@@ -465,7 +1054,7 @@ class MLTrainer:
             return

         # 2. Determine target dataset and get explanation instances
-        target_dataset = explain_dataset if explain_dataset is not None else self.
+        target_dataset = explain_dataset if explain_dataset is not None else self.validation_dataset
         instances_to_explain = _get_random_sample(target_dataset, n_samples)
         if instances_to_explain is None:
             _LOGGER.error("Explanation dataset is empty or invalid. Skipping SHAP analysis.")
@@ -474,17 +1063,17 @@ class MLTrainer:
         # attempt to get feature names
         if feature_names is None:
             # _LOGGER.info("`feature_names` not provided. Attempting to extract from dataset...")
-            if hasattr(target_dataset,
+            if hasattr(target_dataset, DatasetKeys.FEATURE_NAMES):
                 feature_names = target_dataset.feature_names # type: ignore
             else:
-                _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a
+                _LOGGER.error(f"Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
                 raise ValueError()

         # move model to device
         self.model.to(self.device)

         # 3. Call the plotting function
-        if self.kind in [
+        if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
             shap_summary_plot(
                 model=self.model,
                 background_data=background_data,
@@ -494,11 +1083,11 @@ class MLTrainer:
|
|
|
494
1083
|
explainer_type=explainer_type,
|
|
495
1084
|
device=self.device
|
|
496
1085
|
)
|
|
497
|
-
elif self.kind in [
|
|
1086
|
+
elif self.kind in [MLTaskKeys.MULTITARGET_REGRESSION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
|
|
498
1087
|
# try to get target names
|
|
499
1088
|
if target_names is None:
|
|
500
1089
|
target_names = []
|
|
501
|
-
if hasattr(target_dataset,
|
|
1090
|
+
if hasattr(target_dataset, DatasetKeys.TARGET_NAMES):
|
|
502
1091
|
target_names = target_dataset.target_names # type: ignore
|
|
503
1092
|
else:
|
|
504
1093
|
# Infer number of targets from the model's output layer
|
|
@@ -549,7 +1138,7 @@ class MLTrainer:
|
|
|
549
1138
|
yield attention_weights
|
|
550
1139
|
|
|
551
1140
|
def explain_attention(self, save_dir: Union[str, Path],
|
|
552
|
-
feature_names: Optional[List[str]],
|
|
1141
|
+
feature_names: Optional[List[str]] = None,
|
|
553
1142
|
explain_dataset: Optional[Dataset] = None,
|
|
554
1143
|
plot_n_features: int = 10):
|
|
555
1144
|
"""
|
|
@@ -559,27 +1148,32 @@ class MLTrainer:
|
|
|
559
1148
|
|
|
560
1149
|
Args:
|
|
561
1150
|
save_dir (str | Path): Directory to save the plot and summary data.
|
|
562
|
-
feature_names (List[str] | None): Names for the features for plot labeling. If
|
|
1151
|
+
feature_names (List[str] | None): Names for the features for plot labeling. If None, the names will be extracted from the Dataset and raise an error on failure.
|
|
563
1152
|
explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's test dataset is used.
|
|
564
1153
|
plot_n_features (int): Number of top features to plot.
|
|
565
1154
|
"""
|
|
566
1155
|
|
|
567
|
-
print("\n--- Attention Analysis ---")
|
|
1156
|
+
# print("\n--- Attention Analysis ---")
|
|
568
1157
|
|
|
569
1158
|
# --- Step 1: Check if the model supports this explanation ---
|
|
570
1159
|
if not getattr(self.model, 'has_interpretable_attention', False):
|
|
571
|
-
_LOGGER.warning(
|
|
572
|
-
"Model is not flagged for interpretable attention analysis. "
|
|
573
|
-
"Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
|
|
574
|
-
)
|
|
1160
|
+
_LOGGER.warning("Model is not compatible with interpretable attention analysis. Skipping.")
|
|
575
1161
|
return
|
|
576
1162
|
|
|
577
1163
|
# --- Step 2: Set up the dataloader ---
|
|
578
|
-
dataset_to_use = explain_dataset if explain_dataset is not None else self.
|
|
1164
|
+
dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
|
|
579
1165
|
if not isinstance(dataset_to_use, Dataset):
|
|
580
1166
|
_LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
|
|
581
1167
|
return
|
|
582
1168
|
|
|
1169
|
+
# Get feature names
|
|
1170
|
+
if feature_names is None:
|
|
1171
|
+
if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
|
|
1172
|
+
feature_names = dataset_to_use.feature_names # type: ignore
|
|
1173
|
+
else:
|
|
1174
|
+
_LOGGER.error(f"Could not extract `feature_names` from the dataset for attention plot. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
|
|
1175
|
+
raise ValueError()
|
|
1176
|
+
|
|
583
1177
|
explain_loader = DataLoader(
|
|
584
1178
|
dataset=dataset_to_use, batch_size=32, shuffle=False,
|
|
585
1179
|
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
@@ -602,34 +1196,865 @@ class MLTrainer:
             )
         else:
             _LOGGER.error("No attention weights were collected from the model.")
-
-    def _callbacks_hook(self, method_name: str, *args, **kwargs):
-        """Calls the specified method on all callbacks."""
-        for callback in self.callbacks:
-            method = getattr(callback, method_name)
-            method(*args, **kwargs)
-
-    def to_cpu(self):
-        """
-        Moves the model to the CPU and updates the trainer's device setting.
 
-
-
-
-
-
-
-
+    def finalize_model_training(self,
+                                model_checkpoint: Union[Path, Literal['latest', 'current']],
+                                save_dir: Union[str, Path],
+                                finalize_config: Union[FinalizeRegression,
+                                                       FinalizeMultiTargetRegression,
+                                                       FinalizeBinaryClassification,
+                                                       FinalizeBinaryImageClassification,
+                                                       FinalizeMultiClassClassification,
+                                                       FinalizeMultiClassImageClassification,
+                                                       FinalizeBinarySegmentation,
+                                                       FinalizeMultiClassSegmentation,
+                                                       FinalizeMultiLabelBinaryClassification]):
         """
-
+        Saves a finalized, "inference-ready" model state to a .pth file.
+
+        This method saves the model's `state_dict`, the final epoch number, and optional configuration for the task at hand.
 
         Args:
-
+            model_checkpoint (Path | "latest" | "current"):
+                - Path: Loads the model state from a specific checkpoint file.
+                - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                - "current": Uses the model's state as it is at the end of the `fit()` call.
+            save_dir (str | Path): The directory to save the finalized model.
+            finalize_config (object): A data class instance specific to the ML task containing task-specific metadata required for inference.
         """
-        self.
-
-
+        if self.kind == MLTaskKeys.REGRESSION and not isinstance(finalize_config, FinalizeRegression):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeRegression', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION and not isinstance(finalize_config, FinalizeMultiTargetRegression):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiTargetRegression', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.BINARY_CLASSIFICATION and not isinstance(finalize_config, FinalizeBinaryClassification):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeBinaryClassification', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION and not isinstance(finalize_config, FinalizeBinaryImageClassification):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeBinaryImageClassification', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION and not isinstance(finalize_config, FinalizeMultiClassClassification):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiClassClassification', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION and not isinstance(finalize_config, FinalizeMultiClassImageClassification):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiClassImageClassification', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.BINARY_SEGMENTATION and not isinstance(finalize_config, FinalizeBinarySegmentation):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeBinarySegmentation', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION and not isinstance(finalize_config, FinalizeMultiClassSegmentation):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiClassSegmentation', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION and not isinstance(finalize_config, FinalizeMultiLabelBinaryClassification):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiLabelBinaryClassification', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+
+        # handle save path
+        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+        full_path = dir_path / finalize_config.filename
+
+        # handle checkpoint
+        self._load_model_state_for_finalizing(model_checkpoint)
+
+        # Create finalized data
+        finalized_data = {
+            PyTorchCheckpointKeys.EPOCH: self.epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+        }
+
+        # Parse config
+        if finalize_config.target_name is not None:
+            finalized_data[PyTorchCheckpointKeys.TARGET_NAME] = finalize_config.target_name
+        if finalize_config.target_names is not None:
+            finalized_data[PyTorchCheckpointKeys.TARGET_NAMES] = finalize_config.target_names
+        if finalize_config.classification_threshold is not None:
+            finalized_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD] = finalize_config.classification_threshold
+        if finalize_config.class_map is not None:
+            finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map
+
+        # Save model file
+        torch.save(finalized_data, full_path)
+
+        _LOGGER.info(f"Finalized model file saved to '{full_path}'")
+
+
+# Object Detection Trainer
+class DragonDetectionTrainer(_BaseDragonTrainer):
+    def __init__(self, model: nn.Module,
+                 train_dataset: Dataset,
+                 validation_dataset: Dataset,
+                 collate_fn: Callable, optimizer: torch.optim.Optimizer,
+                 device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                 checkpoint_callback: Optional[DragonModelCheckpoint],
+                 early_stopping_callback: Optional[DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 extra_callbacks: Optional[List[_Callback]] = None,
+                 dataloader_workers: int = 2):
+        """
+        Automates the training process of an Object Detection Model (e.g., DragonFastRCNN).
+
+        Built-in Callbacks: `History`, `TqdmProgressBar`
+
+        Args:
+            model (nn.Module): The PyTorch object detection model to train.
+            train_dataset (Dataset): The training dataset.
+            validation_dataset (Dataset): The testing/validation dataset.
+            collate_fn (Callable): The collate function from `ObjectDetectionDatasetMaker.collate_fn`.
+            optimizer (torch.optim.Optimizer): The optimizer.
+            device (str): The device to run training on ('cpu', 'cuda', 'mps').
+            dataloader_workers (int): Subprocesses for data loading.
+            checkpoint_callback (DragonModelCheckpoint | None): Callback to save the model.
+            early_stopping_callback (DragonEarlyStopping | None): Callback to stop training early.
+            lr_scheduler_callback (DragonLRScheduler | None): Callback to manage the LR scheduler.
+            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
+
+        ## Note:
+        This trainer is specialized. It does not take a `criterion` because object detection models like Faster R-CNN return a dictionary of losses directly from their forward pass during training.
+        """
+        # Call the base class constructor with common parameters
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            dataloader_workers=dataloader_workers,
+            checkpoint_callback=checkpoint_callback,
+            early_stopping_callback=early_stopping_callback,
+            lr_scheduler_callback=lr_scheduler_callback,
+            extra_callbacks=extra_callbacks
+        )
+
+        self.train_dataset = train_dataset
+        self.validation_dataset = validation_dataset # <-- Renamed
+        self.kind = MLTaskKeys.OBJECT_DETECTION
+        self.collate_fn = collate_fn
+        self.criterion = None # Criterion is handled inside the model
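The trainer requires an explicit `collate_fn` because detection batches are ragged: every image has its own number of boxes, so the default stacking collate cannot be used. A minimal stand-in for such a function is sketched below; the package's own `ObjectDetectionDatasetMaker.collate_fn` is not shown in this diff, so this is only an assumed illustration of the (images, targets) tuple layout the training loop consumes:

import torch

def detection_collate(batch):
    # Each dataset item is (image_tensor, target_dict); keep them grouped as tuples
    # instead of stacking, because image sizes and box counts vary per sample.
    images, targets = tuple(zip(*batch))
    return images, targets

# Illustrative batch of two samples with different numbers of boxes
sample1 = (torch.randn(3, 224, 224),
           {"boxes": torch.tensor([[0., 0., 10., 10.]]), "labels": torch.tensor([1])})
sample2 = (torch.randn(3, 224, 224),
           {"boxes": torch.tensor([[5., 5., 50., 60.], [1., 2., 3., 4.]]), "labels": torch.tensor([2, 1])})
images, targets = detection_collate([sample1, sample2])
print(len(images), [t["boxes"].shape for t in targets])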
+
+    def _create_dataloaders(self, batch_size: int, shuffle: bool):
+        """Initializes the DataLoaders with the object detection collate_fn."""
+        # Ensure stability on MPS devices by setting num_workers to 0
+        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+        self.train_loader = DataLoader(
+            dataset=self.train_dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type),
+            collate_fn=self.collate_fn, # Use the provided collate function
+            drop_last=True
+        )
+
+        self.validation_loader = DataLoader(
+            dataset=self.validation_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type),
+            collate_fn=self.collate_fn # Use the provided collate function
+        )
+
+    def _train_step(self):
+        self.model.train()
+        running_loss = 0.0
+        total_samples = 0
+
+        for batch_idx, (images, targets) in enumerate(self.train_loader): # type: ignore
+            # images is a tuple of tensors, targets is a tuple of dicts
+            batch_size = len(images)
+
+            # Create a log dictionary for the batch
+            batch_logs = {
+                PyTorchLogKeys.BATCH_INDEX: batch_idx,
+                PyTorchLogKeys.BATCH_SIZE: batch_size
+            }
+            self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
+
+            # Move data to device
+            images = list(img.to(self.device) for img in images)
+            targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
+
+            self.optimizer.zero_grad()
+
+            # Model returns a loss dict when in train() mode and targets are passed
+            loss_dict = self.model(images, targets)
+
+            if not loss_dict:
+                # No losses returned, skip batch
+                _LOGGER.warning(f"Model returned no losses for batch {batch_idx}. Skipping.")
+                batch_logs[PyTorchLogKeys.BATCH_LOSS] = 0
+                self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+                continue
+
+            # Sum all losses
+            loss: torch.Tensor = sum(l for l in loss_dict.values()) # type: ignore
+
+            loss.backward()
+            self.optimizer.step()
+
+            # Calculate batch loss and update running loss for the epoch
+            batch_loss = loss.item()
+            running_loss += batch_loss * batch_size
+            total_samples += batch_size # <-- Accumulate total samples
+
+            # Add the batch loss to the logs and call the end-of-batch hook
+            batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss # type: ignore
+            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+        # Calculate loss using the correct denominator
+        if total_samples == 0:
+            _LOGGER.warning("No samples processed in _train_step. Returning 0 loss.")
+            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
+
+        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}
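In train mode, torchvision-style detection models return a dict of partial losses rather than a single scalar, and the step above simply sums the dict values before backpropagating. A tiny sketch of that summing pattern, with an illustrative loss dict (the keys mimic Faster R-CNN's but are assumptions here):

import torch

# Illustrative loss dict as a detection model's train-mode forward pass might return it
loss_dict = {
    "loss_classifier": torch.tensor(0.7, requires_grad=True),
    "loss_box_reg": torch.tensor(0.4, requires_grad=True),
    "loss_objectness": torch.tensor(0.1, requires_grad=True),
    "loss_rpn_box_reg": torch.tensor(0.05, requires_grad=True),
}

total_loss = sum(loss_dict.values())  # single scalar to backpropagate
total_loss.backward()
print(float(total_loss))  # 1.25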
+
+    def _validation_step(self):
+        self.model.train() # Set to train mode even for validation loss calculation
+        # as model internals (e.g., proposals) might differ, but we still need loss_dict.
+        # use torch.no_grad() to prevent gradient updates.
+        running_loss = 0.0
+        total_samples = 0
+
+        with torch.no_grad():
+            for images, targets in self.validation_loader: # type: ignore
+                batch_size = len(images)
+
+                # Move data to device
+                images = list(img.to(self.device) for img in images)
+                targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
+
+                # Get loss dict
+                loss_dict = self.model(images, targets)
+
+                if not loss_dict:
+                    _LOGGER.warning("Model returned no losses during validation step. Skipping batch.")
+                    continue # Skip if no losses
+
+                # Sum all losses
+                loss: torch.Tensor = sum(l for l in loss_dict.values()) # type: ignore
+
+                running_loss += loss.item() * batch_size
+                total_samples += batch_size # <-- Accumulate total samples
+
+        # Calculate loss using the correct denominator
+        if total_samples == 0:
+            _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+            return {PyTorchLogKeys.VAL_LOSS: 0.0}
+
+        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / total_samples}
+        return logs
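The validation step keeps the detection model in `train()` mode so the loss dict is still produced, while `torch.no_grad()` prevents any gradient tracking; switching to `eval()` is what flips the model into returning predictions. A self-contained sketch of that contract, using a toy stand-in model rather than the package's or torchvision's classes:

import torch
from torch import nn

class ToyDetector(nn.Module):
    """Stand-in that mimics the detection-model contract used by the trainer."""
    def forward(self, images, targets=None):
        if self.training:
            # Training/validation-loss path: a dict of losses
            return {"loss_box_reg": torch.tensor(0.3), "loss_classifier": torch.tensor(0.6)}
        # Inference path: one prediction dict per image
        return [{"boxes": torch.zeros(0, 4), "scores": torch.zeros(0),
                 "labels": torch.zeros(0, dtype=torch.long)} for _ in images]

model = ToyDetector()
images = [torch.randn(3, 64, 64)]
dummy_targets = [{"boxes": torch.zeros(0, 4), "labels": torch.zeros(0, dtype=torch.long)}]

model.train()                 # keep train mode to obtain the loss dict
with torch.no_grad():         # but do not track gradients during validation
    val_losses = model(images, targets=dummy_targets)

model.eval()                  # eval mode switches the model to returning predictions
predictions = model(images)
print(val_losses, predictions[0]["boxes"].shape)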
+
+    def evaluate(self,
+                 save_dir: Union[str, Path],
+                 model_checkpoint: Union[Path, Literal["latest", "current"]],
+                 test_data: Optional[Union[DataLoader, Dataset]] = None):
+        """
+        Evaluates the model using object detection mAP metrics.
+
+        Args:
+            save_dir (str | Path): Directory to save all reports and plots.
+            model_checkpoint ('auto' | Path | None):
+                - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+                - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+                - If 'current', use the current state of the trained model up the latest trained epoch.
+            test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
+        """
+        # Validate model checkpoint
+        if isinstance(model_checkpoint, Path):
+            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+        elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+            checkpoint_validated = model_checkpoint
+        else:
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
+            raise ValueError()
+
+        # Validate directory
+        save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+        # Validate test data and dispatch
+        if test_data is not None:
+            if not isinstance(test_data, (DataLoader, Dataset)):
+                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                raise ValueError()
+            test_data_validated = test_data
+
+            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+            self._evaluate(save_dir=validation_metrics_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None) # 'None' triggers use of self.test_dataset
+
+            # Dispatch test set
+            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+            self._evaluate(save_dir=test_metrics_path,
+                           model_checkpoint="current", # Use 'current' state after loading checkpoint once
+                           data=test_data_validated)
+        else:
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+            self._evaluate(save_dir=save_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None) # 'None' triggers use of self.test_dataset
 
+    def _evaluate(self,
+                  save_dir: Union[str, Path],
+                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  data: Optional[Union[DataLoader, Dataset]]):
+        """
+        Changed to a private helper method
+        Evaluates the model using object detection mAP metrics.
+
+        Args:
+            save_dir (str | Path): Directory to save all reports and plots.
+            data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
+            model_checkpoint ('auto' | Path | None):
+                - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+                - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+                - If 'current', use the current state of the trained model up the latest trained epoch.
+        """
+        dataset_for_artifacts = None
+        eval_loader = None
+
+        # load model checkpoint
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+            raise ValueError()
+
+        # Dataloader
+        if isinstance(data, DataLoader):
+            eval_loader = data
+            if hasattr(data, 'dataset'):
+                dataset_for_artifacts = data.dataset # type: ignore
+        elif isinstance(data, Dataset):
+            # Create a new loader from the provided dataset
+            eval_loader = DataLoader(data,
+                                     batch_size=self._batch_size,
+                                     shuffle=False,
+                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                     pin_memory=(self.device.type == "cuda"),
+                                     collate_fn=self.collate_fn)
+            dataset_for_artifacts = data
+        else: # data is None, use the trainer's default test dataset
+            if self.validation_dataset is None:
+                _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
+                raise ValueError()
+            # Create a fresh DataLoader from the test_dataset
+            eval_loader = DataLoader(
+                self.validation_dataset,
+                batch_size=self._batch_size,
+                shuffle=False,
+                num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                pin_memory=(self.device.type == "cuda"),
+                collate_fn=self.collate_fn
+            )
+            dataset_for_artifacts = self.validation_dataset
+
+        if eval_loader is None:
+            _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
+            raise ValueError()
+
+        # print("\n--- Model Evaluation ---")
+
+        all_predictions = []
+        all_targets = []
+
+        self.model.eval() # Set model to evaluation mode
+        self.model.to(self.device)
+
+        with torch.no_grad():
+            for images, targets in eval_loader:
+                # Move images to device
+                images = list(img.to(self.device) for img in images)
+
+                # Model returns predictions when in eval() mode
+                predictions = self.model(images)
+
+                # Move predictions and targets to CPU for aggregation
+                cpu_preds = [{k: v.to('cpu') for k, v in p.items()} for p in predictions]
+                cpu_targets = [{k: v.to('cpu') for k, v in t.items()} for t in targets]
+
+                all_predictions.extend(cpu_preds)
+                all_targets.extend(cpu_targets)
+
+        if not all_targets:
+            _LOGGER.error("Evaluation failed: No data was processed.")
+            return
+
+        # Get class names from the dataset for the report
+        class_names = None
+        try:
+            # Try to get 'classes' from ObjectDetectionDatasetMaker
+            if hasattr(dataset_for_artifacts, 'classes'):
+                class_names = dataset_for_artifacts.classes # type: ignore
+            # Fallback for Subset
+            elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'): # type: ignore
+                class_names = dataset_for_artifacts.dataset.classes # type: ignore
+        except AttributeError:
+            _LOGGER.warning("Could not find 'classes' attribute on dataset. Per-class metrics will not be named.")
+            pass # class_names is still None
+
+        # --- Routing Logic ---
+        object_detection_metrics(
+            preds=all_predictions,
+            targets=all_targets,
+            save_dir=save_dir,
+            class_names=class_names,
+            print_output=False
+        )
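After aggregation, predictions and targets are lists of per-image dicts on CPU, which is the layout common mAP implementations expect. The sketch below computes mAP with torchmetrics for a single illustrative image; torchmetrics is an assumption here and is not confirmed by this diff, and the package routes to its own `object_detection_metrics` helper whose internals are not shown:

import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

# One predicted box and one ground-truth box for a single image (illustrative values)
preds = [{
    "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([1]),
}]
targets = [{
    "boxes": torch.tensor([[12.0, 11.0, 48.0, 52.0]]),
    "labels": torch.tensor([1]),
}]

metric = MeanAveragePrecision()
metric.update(preds, targets)
results = metric.compute()
print(results["map"], results["map_50"])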
+
+    def finalize_model_training(self,
+                                save_dir: Union[str, Path],
+                                model_checkpoint: Union[Path, Literal['latest', 'current']],
+                                finalize_config: FinalizeObjectDetection
+                                ):
+        """
+        Saves a finalized, "inference-ready" model state to a .pth file.
+
+        This method saves the model's `state_dict` and the final epoch number.
+
+        Args:
+            save_dir (Union[str, Path]): The directory to save the finalized model.
+            model_checkpoint (Union[Path, Literal["latest", "current"]]):
+                - Path: Loads the model state from a specific checkpoint file.
+                - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                - "current": Uses the model's state as it is at the end of the `fit()` call.
+            finalize_config (FinalizeObjectDetection): A data class instance specific to the ML task containing task-specific metadata required for inference.
+        """
+        if not isinstance(finalize_config, FinalizeObjectDetection):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeObjectDetection', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+
+        # handle save path
+        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+        full_path = dir_path / finalize_config.filename
+
+        # handle checkpoint
+        self._load_model_state_for_finalizing(model_checkpoint)
+
+        # Create finalized data
+        finalized_data = {
+            PyTorchCheckpointKeys.EPOCH: self.epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+        }
+
+        if finalize_config.class_map is not None:
+            finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map
+
+        torch.save(finalized_data, full_path)
+
+        _LOGGER.info(f"Finalized model file saved to '{full_path}'")
+
+# --- DragonSequenceTrainer ----
+class DragonSequenceTrainer(_BaseDragonTrainer):
+    def __init__(self,
+                 model: nn.Module,
+                 train_dataset: Dataset,
+                 validation_dataset: Dataset,
+                 kind: Literal["sequence-to-sequence", "sequence-to-value"],
+                 optimizer: torch.optim.Optimizer,
+                 device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                 checkpoint_callback: Optional[DragonModelCheckpoint],
+                 early_stopping_callback: Optional[DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 extra_callbacks: Optional[List[_Callback]] = None,
+                 criterion: Union[nn.Module,Literal["auto"]] = "auto",
+                 dataloader_workers: int = 2):
+        """
+        Automates the training process of a PyTorch Sequence Model.
+
+        Built-in Callbacks: `History`, `TqdmProgressBar`
+
+        Args:
+            model (nn.Module): The PyTorch model to train.
+            train_dataset (Dataset): The training dataset.
+            validation_dataset (Dataset): The validation dataset.
+            kind (str): Used to redirect to the correct process ('sequence-to-sequence' or 'sequence-to-value').
+            criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task
+            optimizer (torch.optim.Optimizer): The optimizer.
+            device (str): The device to run training on ('cpu', 'cuda', 'mps').
+            dataloader_workers (int): Subprocesses for data loading.
+            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
+        """
+        # Call the base class constructor with common parameters
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            dataloader_workers=dataloader_workers,
+            checkpoint_callback=checkpoint_callback,
+            early_stopping_callback=early_stopping_callback,
+            lr_scheduler_callback=lr_scheduler_callback,
+            extra_callbacks=extra_callbacks
+        )
+
+        if kind not in [MLTaskKeys.SEQUENCE_SEQUENCE, MLTaskKeys.SEQUENCE_VALUE]:
+            raise ValueError(f"'{kind}' is not a valid task type for DragonSequenceTrainer.")
+
+        self.train_dataset = train_dataset
+        self.validation_dataset = validation_dataset
+        self.kind = kind
+
+        # try to validate against Dragon Sequence model
+        if hasattr(self.model, "prediction_mode"):
+            key_to_check: str = self.model.prediction_mode # type: ignore
+            if not key_to_check == self.kind:
+                _LOGGER.error(f"Trainer was set for '{self.kind}', but model architecture '{self.model}' is built for '{key_to_check}'.")
+                raise RuntimeError()
+
+        # loss function
+        if criterion == "auto":
+            # Both sequence tasks are treated as regression problems
+            self.criterion = nn.MSELoss()
+        else:
+            self.criterion = criterion
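Both sequence tasks default to MSE because they are treated as regression over continuous values. The sketch below shows what a sequence-to-value batch can look like (sliding windows of past values predicting the next value) and why the trainer later squeezes an [N, 1] output down to [N]; the windowing itself lives in the package's sequence dataset module, which is not shown here, so the shapes, model, and names are purely illustrative:

import torch
from torch import nn

# Illustrative sliding windows: 32 sequences of length 12 with 1 feature each,
# each predicting the single next value (sequence-to-value).
features = torch.randn(32, 12, 1)
targets = torch.randn(32)

class TinySeqToValue(nn.Module):
    """Minimal recurrent regressor standing in for the package's sequence models."""
    def __init__(self):
        super().__init__()
        self.rnn = nn.GRU(input_size=1, hidden_size=16, batch_first=True)
        self.head = nn.Linear(16, 1)

    def forward(self, x):
        out, _ = self.rnn(x)
        return self.head(out[:, -1, :])  # [N, 1]

model = TinySeqToValue()
criterion = nn.MSELoss()              # the trainer's "auto" default for sequence tasks
output = model(features).squeeze(1)   # [N, 1] -> [N] to match the target shape
loss = criterion(output, targets)
print(loss.item())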
+
+    def _create_dataloaders(self, batch_size: int, shuffle: bool):
+        """Initializes the DataLoaders."""
+        # Ensure stability on MPS devices by setting num_workers to 0
+        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+        self.train_loader = DataLoader(
+            dataset=self.train_dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type),
+            drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
+        )
+
+        self.validation_loader = DataLoader(
+            dataset=self.validation_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type)
+        )
+
+    def _train_step(self):
+        self.model.train()
+        running_loss = 0.0
+        total_samples = 0
+
+        for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
+            # Create a log dictionary for the batch
+            batch_logs = {
+                PyTorchLogKeys.BATCH_INDEX: batch_idx,
+                PyTorchLogKeys.BATCH_SIZE: features.size(0)
+            }
+            self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
+
+            features, target = features.to(self.device), target.to(self.device)
+            self.optimizer.zero_grad()
+
+            output = self.model(features)
+
+            # --- Label Type/Shape Correction ---
+            # Ensure target is float for MSELoss
+            target = target.float()
+
+            # For seq-to-val, models might output [N, 1] but target is [N].
+            if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+                if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                    output = output.squeeze(1)
+
+            # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
+            elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+                if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
+                    output = output.squeeze(-1)
+
+            loss = self.criterion(output, target)
+
+            loss.backward()
+            self.optimizer.step()
+
+            # Calculate batch loss and update running loss for the epoch
+            batch_loss = loss.item()
+            batch_size = features.size(0)
+            running_loss += batch_loss * batch_size # Accumulate total loss
+            total_samples += batch_size # total samples
+
+            # Add the batch loss to the logs and call the end-of-batch hook
+            batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
+            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+        if total_samples == 0:
+            _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
+            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
+
+        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples} # type: ignore
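The epoch loss returned above is a sample-weighted average: each batch contributes its mean loss times its batch size, and the sum is divided by the total number of samples, which stays correct when the last batch is smaller. A tiny numeric sketch of the difference versus a naive mean of batch losses (values are illustrative):

# Sample-weighted epoch loss: weight each batch's mean loss by its size.
batch_losses = [(0.50, 32), (0.40, 32), (0.80, 8)]  # (mean loss, batch size)

running_loss = sum(loss * size for loss, size in batch_losses)
total_samples = sum(size for _, size in batch_losses)

epoch_loss = running_loss / total_samples
print(round(epoch_loss, 4))  # 0.4889, not the naive mean of 0.5667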
+
+    def _validation_step(self):
+        self.model.eval()
+        running_loss = 0.0
+
+        with torch.no_grad():
+            for features, target in self.validation_loader: # type: ignore
+                features, target = features.to(self.device), target.to(self.device)
+
+                output = self.model(features)
+
+                # --- Label Type/Shape Correction ---
+                target = target.float()
+
+                # For seq-to-val, models might output [N, 1] but target is [N].
+                if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+                    if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                        output = output.squeeze(1)
+
+                # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
+                elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+                    if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
+                        output = output.squeeze(-1)
+
+                loss = self.criterion(output, target)
+
+                running_loss += loss.item() * features.size(0)
+
+        if not self.validation_loader.dataset: # type: ignore
+            _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+            return {PyTorchLogKeys.VAL_LOSS: 0.0}
+
+        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)} # type: ignore
+        return logs
+
+    def _predict_for_eval(self, dataloader: DataLoader):
+        """
+        Private method to yield model predictions batch by batch for evaluation.
+
+        Yields:
+            tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
+                y_prob_batch is always None for sequence tasks.
+        """
+        self.model.eval()
+        self.model.to(self.device)
+
+        with torch.no_grad():
+            for features, target in dataloader:
+                features = features.to(self.device)
+                output = self.model(features).cpu()
+
+                y_pred_batch = output.numpy()
+                y_prob_batch = None # Not applicable for sequence regression
+                y_true_batch = target.numpy()
+
+                yield y_pred_batch, y_prob_batch, y_true_batch
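Yielding predictions batch by batch keeps only one batch on the device at a time; the consumer collects the NumPy arrays and concatenates them once at the end. A standalone sketch of that pattern with a stand-in model and synthetic data:

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(4, 1)
loader = DataLoader(TensorDataset(torch.randn(100, 4), torch.randn(100)), batch_size=16)

def predict_batches(model, dataloader, device="cpu"):
    """Yield (predictions, targets) per batch as NumPy arrays."""
    model.eval()
    model.to(device)
    with torch.no_grad():
        for features, target in dataloader:
            output = model(features.to(device)).cpu()
            yield output.numpy(), target.numpy()

preds, trues = [], []
for y_pred_b, y_true_b in predict_batches(model, loader):
    preds.append(y_pred_b)
    trues.append(y_true_b)

y_pred = np.concatenate(preds)
y_true = np.concatenate(trues)
print(y_pred.shape, y_true.shape)  # (100, 1) (100,)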
+
+    def evaluate(self,
+                 save_dir: Union[str, Path],
+                 model_checkpoint: Union[Path, Literal["latest", "current"]],
+                 test_data: Optional[Union[DataLoader, Dataset]] = None,
+                 val_format_configuration: Optional[Union[SequenceValueMetricsFormat,
+                                                          SequenceSequenceMetricsFormat]]=None,
+                 test_format_configuration: Optional[Union[SequenceValueMetricsFormat,
+                                                           SequenceSequenceMetricsFormat]]=None):
+        """
+        Evaluates the model, routing to the correct evaluation function.
+
+        Args:
+            model_checkpoint ('auto' | Path | None):
+                - Path to a valid checkpoint for the model.
+                - If 'latest', the latest checkpoint will be loaded.
+                - If 'current', use the current state of the trained model.
+            save_dir (str | Path): Directory to save all reports and plots.
+            test_data (DataLoader | Dataset | None): Optional Test data.
+            val_format_configuration: Optional configuration for validation metrics.
+            test_format_configuration: Optional configuration for test metrics.
+        """
+        # Validate model checkpoint
+        if isinstance(model_checkpoint, Path):
+            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+        elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+            checkpoint_validated = model_checkpoint
+        else:
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.LATEST}', or '{MagicWords.CURRENT}'.")
+            raise ValueError()
+
+        # Validate val configuration
+        if val_format_configuration is not None:
+            if not isinstance(val_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
+                _LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
+                raise ValueError()
+
+        # Validate directory
+        save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+        # Validate test data and dispatch
+        if test_data is not None:
+            if not isinstance(test_data, (DataLoader, Dataset)):
+                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                raise ValueError()
+            test_data_validated = test_data
+
+            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+            self._evaluate(save_dir=validation_metrics_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None,
+                           format_configuration=val_format_configuration)
+
+            # Validate test configuration
+            test_configuration_validated = None
+            if test_format_configuration is not None:
+                if not isinstance(test_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
+                    warning_message_type = f"Invalid test_format_configuration': '{type(test_format_configuration)}'."
+                    if val_format_configuration is not None:
+                        warning_message_type += " 'val_format_configuration' will be used."
+                        test_configuration_validated = val_format_configuration
+                    else:
+                        warning_message_type += " Using default format."
+                    _LOGGER.warning(warning_message_type)
+                else:
+                    test_configuration_validated = test_format_configuration
+
+            # Dispatch test set
+            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+            self._evaluate(save_dir=test_metrics_path,
+                           model_checkpoint="current",
+                           data=test_data_validated,
+                           format_configuration=test_configuration_validated)
+        else:
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+            self._evaluate(save_dir=save_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None,
+                           format_configuration=val_format_configuration)
+
+    def _evaluate(self,
+                  save_dir: Union[str, Path],
+                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  data: Optional[Union[DataLoader, Dataset]],
+                  format_configuration: object):
+        """
+        Private evaluation helper.
+        """
+        eval_loader = None
+
+        # load model checkpoint
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+            raise ValueError()
+
+        # Dataloader
+        if isinstance(data, DataLoader):
+            eval_loader = data
+        elif isinstance(data, Dataset):
+            # Create a new loader from the provided dataset
+            eval_loader = DataLoader(data,
+                                     batch_size=self._batch_size,
+                                     shuffle=False,
+                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                     pin_memory=(self.device.type == "cuda"))
+        else: # data is None, use the trainer's default validation dataset
+            if self.validation_dataset is None:
+                _LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
+                raise ValueError()
+            eval_loader = DataLoader(self.validation_dataset,
+                                     batch_size=self._batch_size,
+                                     shuffle=False,
+                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                     pin_memory=(self.device.type == "cuda"))
+
+        if eval_loader is None:
+            _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
+            raise ValueError()
+
+        all_preds, _, all_true = [], [], []
+        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
+            if y_pred_b is not None: all_preds.append(y_pred_b)
+            if y_true_b is not None: all_true.append(y_true_b)
+
+        if not all_true:
+            _LOGGER.error("Evaluation failed: No data was processed.")
+            return
+
+        y_pred = np.concatenate(all_preds)
+        y_true = np.concatenate(all_true)
+
+        # --- Routing Logic ---
+        if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+            config = None
+            if format_configuration and isinstance(format_configuration, SequenceValueMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceValueMetricsFormat.")
+
+            sequence_to_value_metrics(y_true=y_true,
+                                      y_pred=y_pred,
+                                      save_dir=save_dir,
+                                      config=config)
+
+        elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+            config = None
+            if format_configuration and isinstance(format_configuration, SequenceSequenceMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceSequenceMetricsFormat.")
+
+            sequence_to_sequence_metrics(y_true=y_true,
+                                         y_pred=y_pred,
+                                         save_dir=save_dir,
+                                         config=config)
+
+    def finalize_model_training(self,
+                                save_dir: Union[str, Path],
+                                model_checkpoint: Union[Path, Literal['latest', 'current']],
+                                finalize_config: FinalizeSequencePrediction):
+        """
+        Saves a finalized, "inference-ready" model state to a .pth file.
+
+        This method saves the model's `state_dict` and the final epoch number.
+
+        Args:
+            save_dir (Union[str, Path]): The directory to save the finalized model.
+            model_checkpoint (Union[Path, Literal["latest", "current"]]):
+                - Path: Loads the model state from a specific checkpoint file.
+                - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                - "current": Uses the model's state as it is at the end of the `fit()` call.
+            finalize_config (FinalizeSequencePrediction): A data class instance specific to the ML task containing task-specific metadata required for inference.
+        """
+        if not isinstance(finalize_config, FinalizeSequencePrediction):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeSequencePrediction', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+
+        # handle save path
+        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+        full_path = dir_path / finalize_config.filename
+
+        # handle checkpoint
+        self._load_model_state_for_finalizing(model_checkpoint)
+
+        # Create finalized data
+        finalized_data = {
+            PyTorchCheckpointKeys.EPOCH: self.epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+        }
+
+        if finalize_config.sequence_length is not None:
+            finalized_data[PyTorchCheckpointKeys.SEQUENCE_LENGTH] = finalize_config.sequence_length
+        if finalize_config.initial_sequence is not None:
+            finalized_data[PyTorchCheckpointKeys.INITIAL_SEQUENCE] = finalize_config.initial_sequence
+
+        torch.save(finalized_data, full_path)
+
+        _LOGGER.info(f"Finalized model file saved to '{full_path}'")
+
 
 def info():
     _script_info(__all__)