dragon-ml-toolbox 14.7.0__py3-none-any.whl → 16.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +9 -5
- dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +704 -24
- ml_tools/ML_datasetmaster.py +235 -280
- ml_tools/ML_evaluation.py +144 -39
- ml_tools/ML_evaluation_multi.py +103 -35
- ml_tools/ML_inference.py +290 -208
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +219 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1342 -386
- ml_tools/ML_utilities.py +1 -1
- ml_tools/ML_vision_datasetmaster.py +120 -72
- ml_tools/ML_vision_evaluation.py +30 -6
- ml_tools/ML_vision_inference.py +129 -152
- ml_tools/ML_vision_models.py +1 -1
- ml_tools/ML_vision_transformers.py +121 -40
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +45 -0
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +1 -1
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.7.0.dist-info/RECORD +0 -49
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_trainer.py
CHANGED
|
@@ -1,80 +1,96 @@
|
|
|
1
|
-
from typing import List, Literal, Union, Optional, Callable, Dict, Any
|
|
1
|
+
from typing import List, Literal, Union, Optional, Callable, Dict, Any
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
from torch.utils.data import DataLoader, Dataset
|
|
4
4
|
import torch
|
|
5
5
|
from torch import nn
|
|
6
6
|
import numpy as np
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
7
8
|
|
|
8
|
-
from .
|
|
9
|
+
from .path_manager import make_fullpath
|
|
10
|
+
from .ML_callbacks import _Callback, History, TqdmProgressBar, DragonModelCheckpoint, DragonEarlyStopping, DragonLRScheduler
|
|
9
11
|
from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
|
|
10
12
|
from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
|
|
13
|
+
from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
|
|
14
|
+
from .ML_sequence_evaluation import sequence_to_sequence_metrics, sequence_to_value_metrics
|
|
15
|
+
from .ML_configuration import (RegressionMetricsFormat,
|
|
16
|
+
MultiTargetRegressionMetricsFormat,
|
|
17
|
+
BinaryClassificationMetricsFormat,
|
|
18
|
+
MultiClassClassificationMetricsFormat,
|
|
19
|
+
BinaryImageClassificationMetricsFormat,
|
|
20
|
+
MultiClassImageClassificationMetricsFormat,
|
|
21
|
+
MultiLabelBinaryClassificationMetricsFormat,
|
|
22
|
+
BinarySegmentationMetricsFormat,
|
|
23
|
+
MultiClassSegmentationMetricsFormat,
|
|
24
|
+
SequenceValueMetricsFormat,
|
|
25
|
+
SequenceSequenceMetricsFormat,
|
|
26
|
+
|
|
27
|
+
FinalizeBinaryClassification,
|
|
28
|
+
FinalizeBinarySegmentation,
|
|
29
|
+
FinalizeBinaryImageClassification,
|
|
30
|
+
FinalizeMultiClassClassification,
|
|
31
|
+
FinalizeMultiClassImageClassification,
|
|
32
|
+
FinalizeMultiClassSegmentation,
|
|
33
|
+
FinalizeMultiLabelBinaryClassification,
|
|
34
|
+
FinalizeMultiTargetRegression,
|
|
35
|
+
FinalizeRegression,
|
|
36
|
+
FinalizeObjectDetection,
|
|
37
|
+
FinalizeSequencePrediction)
|
|
38
|
+
|
|
11
39
|
from ._script_info import _script_info
|
|
12
|
-
from .
|
|
40
|
+
from ._keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys, MLTaskKeys, MagicWords, DragonTrainerKeys
|
|
13
41
|
from ._logger import _LOGGER
|
|
14
|
-
from .path_manager import make_fullpath
|
|
15
|
-
from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
|
|
16
|
-
from .ML_configuration import ClassificationMetricsFormat, MultiClassificationMetricsFormat
|
|
17
42
|
|
|
18
43
|
|
|
19
44
|
__all__ = [
|
|
20
|
-
"
|
|
21
|
-
"
|
|
45
|
+
"DragonTrainer",
|
|
46
|
+
"DragonDetectionTrainer",
|
|
47
|
+
"DragonSequenceTrainer"
|
|
22
48
|
]
|
|
23
49
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification", "segmentation"],
|
|
28
|
-
criterion: nn.Module, optimizer: torch.optim.Optimizer,
|
|
29
|
-
device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
|
|
30
|
-
"""
|
|
31
|
-
Automates the training process of a PyTorch Model.
|
|
32
|
-
|
|
33
|
-
Built-in Callbacks: `History`, `TqdmProgressBar`
|
|
34
|
-
|
|
35
|
-
Args:
|
|
36
|
-
model (nn.Module): The PyTorch model to train.
|
|
37
|
-
train_dataset (Dataset): The training dataset.
|
|
38
|
-
test_dataset (Dataset): The testing/validation dataset.
|
|
39
|
-
kind (str): Can be 'regression', 'classification', 'multi_target_regression', 'multi_label_classification', or 'segmentation'.
|
|
40
|
-
criterion (nn.Module): The loss function.
|
|
41
|
-
optimizer (torch.optim.Optimizer): The optimizer.
|
|
42
|
-
device (str): The device to run training on ('cpu', 'cuda', 'mps').
|
|
43
|
-
dataloader_workers (int): Subprocesses for data loading.
|
|
44
|
-
callbacks (List[Callback] | None): A list of callbacks to use during training.
|
|
45
|
-
|
|
46
|
-
Note:
|
|
47
|
-
- For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`.
|
|
48
|
-
|
|
49
|
-
- For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice.
|
|
50
|
+
class _BaseDragonTrainer(ABC):
|
|
51
|
+
"""
|
|
52
|
+
Abstract base class for Dragon Trainers.
|
|
50
53
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
54
|
+
Handles the common training loop orchestration, checkpointing, callback
|
|
55
|
+
management, and device handling. Subclasses must implement the
|
|
56
|
+
task-specific logic (dataloaders, train/val steps, evaluation).
|
|
57
|
+
"""
|
|
58
|
+
def __init__(self,
|
|
59
|
+
model: nn.Module,
|
|
60
|
+
optimizer: torch.optim.Optimizer,
|
|
61
|
+
device: Union[Literal['cuda', 'mps', 'cpu'],str],
|
|
62
|
+
dataloader_workers: int = 2,
|
|
63
|
+
checkpoint_callback: Optional[DragonModelCheckpoint] = None,
|
|
64
|
+
early_stopping_callback: Optional[DragonEarlyStopping] = None,
|
|
65
|
+
lr_scheduler_callback: Optional[DragonLRScheduler] = None,
|
|
66
|
+
extra_callbacks: Optional[List[_Callback]] = None):
|
|
57
67
|
|
|
58
68
|
self.model = model
|
|
59
|
-
self.train_dataset = train_dataset
|
|
60
|
-
self.test_dataset = test_dataset
|
|
61
|
-
self.kind = kind
|
|
62
|
-
self.criterion = criterion
|
|
63
69
|
self.optimizer = optimizer
|
|
64
70
|
self.scheduler = None
|
|
65
71
|
self.device = self._validate_device(device)
|
|
66
72
|
self.dataloader_workers = dataloader_workers
|
|
67
73
|
|
|
68
|
-
# Callback handler
|
|
74
|
+
# Callback handler
|
|
69
75
|
default_callbacks = [History(), TqdmProgressBar()]
|
|
70
|
-
|
|
76
|
+
|
|
77
|
+
self._checkpoint_callback = None
|
|
78
|
+
if checkpoint_callback:
|
|
79
|
+
default_callbacks.append(checkpoint_callback)
|
|
80
|
+
self._checkpoint_callback = checkpoint_callback
|
|
81
|
+
if early_stopping_callback:
|
|
82
|
+
default_callbacks.append(early_stopping_callback)
|
|
83
|
+
if lr_scheduler_callback:
|
|
84
|
+
default_callbacks.append(lr_scheduler_callback)
|
|
85
|
+
|
|
86
|
+
user_callbacks = extra_callbacks if extra_callbacks is not None else []
|
|
71
87
|
self.callbacks = default_callbacks + user_callbacks
|
|
72
88
|
self._set_trainer_on_callbacks()
|
|
73
89
|
|
|
74
90
|
# Internal state
|
|
75
|
-
self.train_loader = None
|
|
76
|
-
self.
|
|
77
|
-
self.history = {}
|
|
91
|
+
self.train_loader: Optional[DataLoader] = None
|
|
92
|
+
self.validation_loader: Optional[DataLoader] = None
|
|
93
|
+
self.history: Dict[str, List[Any]] = {}
|
|
78
94
|
self.epoch = 0
|
|
79
95
|
self.epochs = 0 # Total epochs for the fit run
|
|
80
96
|
self.start_epoch = 1
|
|
@@ -97,32 +113,10 @@ class MLTrainer:
|
|
|
97
113
|
for callback in self.callbacks:
|
|
98
114
|
callback.set_trainer(self)
|
|
99
115
|
|
|
100
|
-
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
101
|
-
"""Initializes the DataLoaders."""
|
|
102
|
-
# Ensure stability on MPS devices by setting num_workers to 0
|
|
103
|
-
loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
|
|
104
|
-
|
|
105
|
-
self.train_loader = DataLoader(
|
|
106
|
-
dataset=self.train_dataset,
|
|
107
|
-
batch_size=batch_size,
|
|
108
|
-
shuffle=shuffle,
|
|
109
|
-
num_workers=loader_workers,
|
|
110
|
-
pin_memory=("cuda" in self.device.type),
|
|
111
|
-
drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
|
|
112
|
-
)
|
|
113
|
-
|
|
114
|
-
self.test_loader = DataLoader(
|
|
115
|
-
dataset=self.test_dataset,
|
|
116
|
-
batch_size=batch_size,
|
|
117
|
-
shuffle=False,
|
|
118
|
-
num_workers=loader_workers,
|
|
119
|
-
pin_memory=("cuda" in self.device.type)
|
|
120
|
-
)
|
|
121
|
-
|
|
122
116
|
def _load_checkpoint(self, path: Union[str, Path]):
|
|
123
117
|
"""Loads a training checkpoint to resume training."""
|
|
124
118
|
p = make_fullpath(path, enforce="file")
|
|
125
|
-
_LOGGER.info(f"Loading checkpoint from '{p.name}'
|
|
119
|
+
_LOGGER.info(f"Loading checkpoint from '{p.name}'...")
|
|
126
120
|
|
|
127
121
|
try:
|
|
128
122
|
checkpoint = torch.load(p, map_location=self.device)
|
|
@@ -133,7 +127,16 @@ class MLTrainer:
|
|
|
133
127
|
|
|
134
128
|
self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
|
|
135
129
|
self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
|
|
136
|
-
self.
|
|
130
|
+
self.epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0)
|
|
131
|
+
self.start_epoch = self.epoch + 1 # Resume on the *next* epoch
|
|
132
|
+
|
|
133
|
+
# --- Load History ---
|
|
134
|
+
if PyTorchCheckpointKeys.HISTORY in checkpoint:
|
|
135
|
+
self.history = checkpoint[PyTorchCheckpointKeys.HISTORY]
|
|
136
|
+
_LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
|
|
137
|
+
else:
|
|
138
|
+
_LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
|
|
139
|
+
self.history = {} # Ensure it's at least an empty dict
|
|
137
140
|
|
|
138
141
|
# --- Scheduler State Loading Logic ---
|
|
139
142
|
scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
|
|
@@ -163,7 +166,7 @@ class MLTrainer:
|
|
|
163
166
|
|
|
164
167
|
# Restore callback states
|
|
165
168
|
for cb in self.callbacks:
|
|
166
|
-
if isinstance(cb,
|
|
169
|
+
if isinstance(cb, DragonModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
|
|
167
170
|
cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
|
|
168
171
|
_LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
|
|
169
172
|
|
|
@@ -174,7 +177,8 @@ class MLTrainer:
|
|
|
174
177
|
raise
|
|
175
178
|
|
|
176
179
|
def fit(self,
|
|
177
|
-
|
|
180
|
+
save_dir: Union[str,Path],
|
|
181
|
+
epochs: int = 100,
|
|
178
182
|
batch_size: int = 10,
|
|
179
183
|
shuffle: bool = True,
|
|
180
184
|
resume_from_checkpoint: Optional[Union[str, Path]] = None):
|
|
@@ -184,21 +188,15 @@ class MLTrainer:
|
|
|
184
188
|
Returns the "History" callback dictionary.
|
|
185
189
|
|
|
186
190
|
Args:
|
|
191
|
+
save_dir (str | Path): Directory to save the loss plot.
|
|
187
192
|
epochs (int): The total number of epochs to train for.
|
|
188
193
|
batch_size (int): The number of samples per batch.
|
|
189
194
|
shuffle (bool): Whether to shuffle the training data at each epoch.
|
|
190
195
|
resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
|
|
191
|
-
|
|
192
|
-
Note:
|
|
193
|
-
For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
|
|
194
|
-
automatically aligns the model's output tensor with the target tensor's
|
|
195
|
-
shape using `output.view_as(target)`. This handles the common case
|
|
196
|
-
where a model outputs a shape of `[batch_size, 1]` and the target has a
|
|
197
|
-
shape of `[batch_size]`.
|
|
198
196
|
"""
|
|
199
197
|
self.epochs = epochs
|
|
200
198
|
self._batch_size = batch_size
|
|
201
|
-
self._create_dataloaders(self._batch_size, shuffle)
|
|
199
|
+
self._create_dataloaders(self._batch_size, shuffle) # type: ignore
|
|
202
200
|
self.model.to(self.device)
|
|
203
201
|
|
|
204
202
|
if resume_from_checkpoint:
|
|
@@ -209,11 +207,19 @@ class MLTrainer:
|
|
|
209
207
|
|
|
210
208
|
self._callbacks_hook('on_train_begin')
|
|
211
209
|
|
|
210
|
+
if not self.train_loader:
|
|
211
|
+
_LOGGER.error("Train loader is not initialized.")
|
|
212
|
+
raise ValueError()
|
|
213
|
+
|
|
214
|
+
if not self.validation_loader:
|
|
215
|
+
_LOGGER.error("Validation loader is not initialized.")
|
|
216
|
+
raise ValueError()
|
|
217
|
+
|
|
212
218
|
for epoch in range(self.start_epoch, self.epochs + 1):
|
|
213
219
|
self.epoch = epoch
|
|
214
|
-
epoch_logs = {}
|
|
220
|
+
epoch_logs: Dict[str, Any] = {}
|
|
215
221
|
self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
|
|
216
|
-
|
|
222
|
+
|
|
217
223
|
train_logs = self._train_step()
|
|
218
224
|
epoch_logs.update(train_logs)
|
|
219
225
|
|
|
@@ -227,11 +233,204 @@ class MLTrainer:
|
|
|
227
233
|
break
|
|
228
234
|
|
|
229
235
|
self._callbacks_hook('on_train_end')
|
|
236
|
+
|
|
237
|
+
# Training History
|
|
238
|
+
plot_losses(self.history, save_dir=save_dir)
|
|
239
|
+
|
|
230
240
|
return self.history
|
|
241
|
+
|
|
242
|
+
def _callbacks_hook(self, method_name: str, *args, **kwargs):
|
|
243
|
+
"""Calls the specified method on all callbacks."""
|
|
244
|
+
for callback in self.callbacks:
|
|
245
|
+
method = getattr(callback, method_name)
|
|
246
|
+
method(*args, **kwargs)
|
|
247
|
+
|
|
248
|
+
def to_cpu(self):
|
|
249
|
+
"""
|
|
250
|
+
Moves the model to the CPU and updates the trainer's device setting.
|
|
251
|
+
|
|
252
|
+
This is useful for running operations that require the CPU.
|
|
253
|
+
"""
|
|
254
|
+
self.device = torch.device('cpu')
|
|
255
|
+
self.model.to(self.device)
|
|
256
|
+
_LOGGER.info("Trainer and model moved to CPU.")
|
|
257
|
+
|
|
258
|
+
def to_device(self, device: str):
|
|
259
|
+
"""
|
|
260
|
+
Moves the model to the specified device and updates the trainer's device setting.
|
|
261
|
+
|
|
262
|
+
Args:
|
|
263
|
+
device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
|
|
264
|
+
"""
|
|
265
|
+
self.device = self._validate_device(device)
|
|
266
|
+
self.model.to(self.device)
|
|
267
|
+
_LOGGER.info(f"Trainer and model moved to {self.device}.")
|
|
268
|
+
|
|
269
|
+
def _load_model_state_for_finalizing(self, model_checkpoint: Union[Path, Literal['latest', 'current']]):
|
|
270
|
+
"""
|
|
271
|
+
Private helper to load the correct model state_dict based on user's choice.
|
|
272
|
+
This is called by finalize_model_training() in subclasses.
|
|
273
|
+
"""
|
|
274
|
+
if isinstance(model_checkpoint, Path):
|
|
275
|
+
self._load_checkpoint(path=model_checkpoint)
|
|
276
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
|
|
277
|
+
path_to_latest = self._checkpoint_callback.best_checkpoint_path
|
|
278
|
+
self._load_checkpoint(path_to_latest)
|
|
279
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
|
|
280
|
+
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
|
|
281
|
+
raise ValueError()
|
|
282
|
+
elif model_checkpoint == MagicWords.CURRENT:
|
|
283
|
+
pass
|
|
284
|
+
else:
|
|
285
|
+
_LOGGER.error(f"Unknown 'model_checkpoint' received '{model_checkpoint}'.")
|
|
286
|
+
raise ValueError()
|
|
287
|
+
|
|
288
|
+
# --- Abstract Methods ---
|
|
289
|
+
# These must be implemented by subclasses
|
|
290
|
+
|
|
291
|
+
@abstractmethod
|
|
292
|
+
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
293
|
+
"""Initializes the DataLoaders."""
|
|
294
|
+
raise NotImplementedError
|
|
295
|
+
|
|
296
|
+
@abstractmethod
|
|
297
|
+
def _train_step(self) -> Dict[str, float]:
|
|
298
|
+
"""Runs a single training epoch."""
|
|
299
|
+
raise NotImplementedError
|
|
300
|
+
|
|
301
|
+
@abstractmethod
|
|
302
|
+
def _validation_step(self) -> Dict[str, float]:
|
|
303
|
+
"""Runs a single validation epoch."""
|
|
304
|
+
raise NotImplementedError
|
|
305
|
+
|
|
306
|
+
@abstractmethod
|
|
307
|
+
def evaluate(self, *args, **kwargs):
|
|
308
|
+
"""Runs the full model evaluation."""
|
|
309
|
+
raise NotImplementedError
|
|
310
|
+
|
|
311
|
+
@abstractmethod
|
|
312
|
+
def _evaluate(self, *args, **kwargs):
|
|
313
|
+
"""Internal evaluation helper."""
|
|
314
|
+
raise NotImplementedError
|
|
315
|
+
|
|
316
|
+
@abstractmethod
|
|
317
|
+
def finalize_model_training(self, *args, **kwargs):
|
|
318
|
+
"""Saves the finalized model for inference."""
|
|
319
|
+
raise NotImplementedError
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
# --- DragonTrainer ----
|
|
323
|
+
class DragonTrainer(_BaseDragonTrainer):
|
|
324
|
+
def __init__(self,
|
|
325
|
+
model: nn.Module,
|
|
326
|
+
train_dataset: Dataset,
|
|
327
|
+
validation_dataset: Dataset,
|
|
328
|
+
kind: Literal["regression", "binary classification", "multiclass classification",
|
|
329
|
+
"multitarget regression", "multilabel binary classification",
|
|
330
|
+
"binary segmentation", "multiclass segmentation", "binary image classification", "multiclass image classification"],
|
|
331
|
+
optimizer: torch.optim.Optimizer,
|
|
332
|
+
device: Union[Literal['cuda', 'mps', 'cpu'],str],
|
|
333
|
+
checkpoint_callback: Optional[DragonModelCheckpoint],
|
|
334
|
+
early_stopping_callback: Optional[DragonEarlyStopping],
|
|
335
|
+
lr_scheduler_callback: Optional[DragonLRScheduler],
|
|
336
|
+
extra_callbacks: Optional[List[_Callback]] = None,
|
|
337
|
+
criterion: Union[nn.Module,Literal["auto"]] = "auto",
|
|
338
|
+
dataloader_workers: int = 2):
|
|
339
|
+
"""
|
|
340
|
+
Automates the training process of a PyTorch Model.
|
|
341
|
+
|
|
342
|
+
Built-in Callbacks: `History`, `TqdmProgressBar`
|
|
343
|
+
|
|
344
|
+
Args:
|
|
345
|
+
model (nn.Module): The PyTorch model to train.
|
|
346
|
+
train_dataset (Dataset): The training dataset.
|
|
347
|
+
validation_dataset (Dataset): The validation dataset.
|
|
348
|
+
kind (str): Used to redirect to the correct process.
|
|
349
|
+
criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task
|
|
350
|
+
optimizer (torch.optim.Optimizer): The optimizer.
|
|
351
|
+
device (str): The device to run training on ('cpu', 'cuda', 'mps').
|
|
352
|
+
dataloader_workers (int): Subprocesses for data loading.
|
|
353
|
+
extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
|
|
354
|
+
|
|
355
|
+
Note:
|
|
356
|
+
- For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`. The model should output as many logits as existing targets.
|
|
357
|
+
|
|
358
|
+
- For **single-label, binary classification**, `nn.BCEWithLogitsLoss` is the standard choice. The model should output a single logit.
|
|
231
359
|
|
|
360
|
+
- For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice. The model should output as many logits as existing classes.
|
|
361
|
+
|
|
362
|
+
- For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem. The model should output 1 logit per binary target.
|
|
363
|
+
|
|
364
|
+
- For **binary segmentation** tasks, `nn.BCEWithLogitsLoss` is common. The model should output a single logit.
|
|
365
|
+
|
|
366
|
+
- for **multiclass segmentation** tasks, `nn.CrossEntropyLoss` is the standard. The model should output as many logits as existing classes.
|
|
367
|
+
"""
|
|
368
|
+
# Call the base class constructor with common parameters
|
|
369
|
+
super().__init__(
|
|
370
|
+
model=model,
|
|
371
|
+
optimizer=optimizer,
|
|
372
|
+
device=device,
|
|
373
|
+
dataloader_workers=dataloader_workers,
|
|
374
|
+
checkpoint_callback=checkpoint_callback,
|
|
375
|
+
early_stopping_callback=early_stopping_callback,
|
|
376
|
+
lr_scheduler_callback=lr_scheduler_callback,
|
|
377
|
+
extra_callbacks=extra_callbacks
|
|
378
|
+
)
|
|
379
|
+
|
|
380
|
+
if kind not in [MLTaskKeys.REGRESSION,
|
|
381
|
+
MLTaskKeys.BINARY_CLASSIFICATION,
|
|
382
|
+
MLTaskKeys.MULTICLASS_CLASSIFICATION,
|
|
383
|
+
MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION,
|
|
384
|
+
MLTaskKeys.MULTITARGET_REGRESSION,
|
|
385
|
+
MLTaskKeys.BINARY_SEGMENTATION,
|
|
386
|
+
MLTaskKeys.MULTICLASS_SEGMENTATION,
|
|
387
|
+
MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
|
|
388
|
+
MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
|
|
389
|
+
raise ValueError(f"'{kind}' is not a valid task type.")
|
|
390
|
+
|
|
391
|
+
self.train_dataset = train_dataset
|
|
392
|
+
self.validation_dataset = validation_dataset
|
|
393
|
+
self.kind = kind
|
|
394
|
+
self._classification_threshold: float = 0.5
|
|
395
|
+
|
|
396
|
+
# loss function
|
|
397
|
+
if criterion == "auto":
|
|
398
|
+
if kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
|
|
399
|
+
self.criterion = nn.MSELoss()
|
|
400
|
+
elif kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
|
|
401
|
+
self.criterion = nn.BCEWithLogitsLoss()
|
|
402
|
+
elif kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
|
|
403
|
+
self.criterion = nn.CrossEntropyLoss()
|
|
404
|
+
else:
|
|
405
|
+
self.criterion = criterion
|
|
406
|
+
|
|
407
|
+
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
408
|
+
"""Initializes the DataLoaders."""
|
|
409
|
+
# Ensure stability on MPS devices by setting num_workers to 0
|
|
410
|
+
loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
|
|
411
|
+
|
|
412
|
+
self.train_loader = DataLoader(
|
|
413
|
+
dataset=self.train_dataset,
|
|
414
|
+
batch_size=batch_size,
|
|
415
|
+
shuffle=shuffle,
|
|
416
|
+
num_workers=loader_workers,
|
|
417
|
+
pin_memory=("cuda" in self.device.type),
|
|
418
|
+
drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
|
|
419
|
+
)
|
|
420
|
+
|
|
421
|
+
self.validation_loader = DataLoader(
|
|
422
|
+
dataset=self.validation_dataset,
|
|
423
|
+
batch_size=batch_size,
|
|
424
|
+
shuffle=False,
|
|
425
|
+
num_workers=loader_workers,
|
|
426
|
+
pin_memory=("cuda" in self.device.type)
|
|
427
|
+
)
|
|
428
|
+
|
|
232
429
|
def _train_step(self):
|
|
233
430
|
self.model.train()
|
|
234
431
|
running_loss = 0.0
|
|
432
|
+
total_samples = 0
|
|
433
|
+
|
|
235
434
|
for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
|
|
236
435
|
# Create a log dictionary for the batch
|
|
237
436
|
batch_logs = {
|
|
@@ -245,9 +444,21 @@ class MLTrainer:
|
|
|
245
444
|
|
|
246
445
|
output = self.model(features)
|
|
247
446
|
|
|
248
|
-
#
|
|
249
|
-
|
|
250
|
-
|
|
447
|
+
# --- Label Type/Shape Correction ---
|
|
448
|
+
# Cast target to float for BCE-based losses
|
|
449
|
+
if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
|
|
450
|
+
target = target.float()
|
|
451
|
+
|
|
452
|
+
# Reshape output to match target for single-logit tasks
|
|
453
|
+
if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
|
|
454
|
+
# If model outputs [N, 1] and target is [N], squeeze output
|
|
455
|
+
if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
|
|
456
|
+
output = output.squeeze(1)
|
|
457
|
+
|
|
458
|
+
if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
|
|
459
|
+
# If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
|
|
460
|
+
if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
|
|
461
|
+
output = output.squeeze(1)
|
|
251
462
|
|
|
252
463
|
loss = self.criterion(output, target)
|
|
253
464
|
|
|
@@ -256,34 +467,58 @@ class MLTrainer:
|
|
|
256
467
|
|
|
257
468
|
# Calculate batch loss and update running loss for the epoch
|
|
258
469
|
batch_loss = loss.item()
|
|
259
|
-
|
|
470
|
+
batch_size = features.size(0)
|
|
471
|
+
running_loss += batch_loss * batch_size # Accumulate total loss
|
|
472
|
+
total_samples += batch_size # total samples
|
|
260
473
|
|
|
261
474
|
# Add the batch loss to the logs and call the end-of-batch hook
|
|
262
475
|
batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
|
|
263
476
|
self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
|
|
477
|
+
|
|
478
|
+
if total_samples == 0:
|
|
479
|
+
_LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
|
|
480
|
+
return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
|
|
264
481
|
|
|
265
|
-
return {PyTorchLogKeys.TRAIN_LOSS: running_loss /
|
|
482
|
+
return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples} # type: ignore
|
|
266
483
|
|
|
267
484
|
def _validation_step(self):
|
|
268
485
|
self.model.eval()
|
|
269
486
|
running_loss = 0.0
|
|
487
|
+
|
|
270
488
|
with torch.no_grad():
|
|
271
|
-
for features, target in self.
|
|
489
|
+
for features, target in self.validation_loader: # type: ignore
|
|
272
490
|
features, target = features.to(self.device), target.to(self.device)
|
|
273
491
|
|
|
274
492
|
output = self.model(features)
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
493
|
+
|
|
494
|
+
# --- Label Type/Shape Correction ---
|
|
495
|
+
# Cast target to float for BCE-based losses
|
|
496
|
+
if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
|
|
497
|
+
target = target.float()
|
|
498
|
+
|
|
499
|
+
# Reshape output to match target for single-logit tasks
|
|
500
|
+
if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
|
|
501
|
+
# If model outputs [N, 1] and target is [N], squeeze output
|
|
502
|
+
if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
|
|
503
|
+
output = output.squeeze(1)
|
|
504
|
+
|
|
505
|
+
if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
|
|
506
|
+
# If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
|
|
507
|
+
if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
|
|
508
|
+
output = output.squeeze(1)
|
|
278
509
|
|
|
279
510
|
loss = self.criterion(output, target)
|
|
280
511
|
|
|
281
512
|
running_loss += loss.item() * features.size(0)
|
|
513
|
+
|
|
514
|
+
if not self.validation_loader.dataset: # type: ignore
|
|
515
|
+
_LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
|
|
516
|
+
return {PyTorchLogKeys.VAL_LOSS: 0.0}
|
|
282
517
|
|
|
283
|
-
logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.
|
|
518
|
+
logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)} # type: ignore
|
|
284
519
|
return logs
|
|
285
520
|
|
|
286
|
-
def _predict_for_eval(self, dataloader: DataLoader
|
|
521
|
+
def _predict_for_eval(self, dataloader: DataLoader):
|
|
287
522
|
"""
|
|
288
523
|
Private method to yield model predictions batch by batch for evaluation.
|
|
289
524
|
|
|
@@ -294,6 +529,7 @@ class MLTrainer:
|
|
|
294
529
|
"""
|
|
295
530
|
self.model.eval()
|
|
296
531
|
self.model.to(self.device)
|
|
532
|
+
|
|
297
533
|
with torch.no_grad():
|
|
298
534
|
for features, target in dataloader:
|
|
299
535
|
features = features.to(self.device)
|
|
@@ -303,25 +539,64 @@ class MLTrainer:
|
|
|
303
539
|
y_prob_batch = None
|
|
304
540
|
y_true_batch = None
|
|
305
541
|
|
|
306
|
-
if self.kind in [
|
|
542
|
+
if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
|
|
307
543
|
y_pred_batch = output.numpy()
|
|
308
544
|
y_true_batch = target.numpy()
|
|
545
|
+
|
|
546
|
+
elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
|
|
547
|
+
# Assumes model output is [N, 1] (a single logit)
|
|
548
|
+
# Squeeze output from [N, 1] to [N] if necessary
|
|
549
|
+
if output.ndim == 2 and output.shape[1] == 1:
|
|
550
|
+
output = output.squeeze(1)
|
|
551
|
+
|
|
552
|
+
probs_pos = torch.sigmoid(output) # Probability of positive class
|
|
553
|
+
preds = (probs_pos >= self._classification_threshold).int()
|
|
554
|
+
y_pred_batch = preds.numpy()
|
|
555
|
+
# For metrics (like ROC AUC), we often need probs for *both* classes
|
|
556
|
+
# Create an [N, 2] array: [prob_class_0, prob_class_1]
|
|
557
|
+
probs_neg = 1.0 - probs_pos
|
|
558
|
+
y_prob_batch = torch.stack([probs_neg, probs_pos], dim=1).numpy()
|
|
559
|
+
y_true_batch = target.numpy()
|
|
309
560
|
|
|
310
|
-
elif self.kind
|
|
561
|
+
elif self.kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
|
|
562
|
+
num_classes = output.shape[1]
|
|
563
|
+
if num_classes < 3:
|
|
564
|
+
# Optional: warn the user they are using the wrong kind
|
|
565
|
+
wrong_class = MLTaskKeys.MULTICLASS_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
|
|
566
|
+
recommended_class = MLTaskKeys.BINARY_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.BINARY_IMAGE_CLASSIFICATION
|
|
567
|
+
_LOGGER.warning(f"'{wrong_class}' kind used with {num_classes} classes. Consider using '{recommended_class}' instead.")
|
|
568
|
+
|
|
311
569
|
probs = torch.softmax(output, dim=1)
|
|
312
570
|
preds = torch.argmax(probs, dim=1)
|
|
313
571
|
y_pred_batch = preds.numpy()
|
|
314
572
|
y_prob_batch = probs.numpy()
|
|
315
573
|
y_true_batch = target.numpy()
|
|
316
574
|
|
|
317
|
-
elif self.kind ==
|
|
575
|
+
elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
|
|
318
576
|
probs = torch.sigmoid(output)
|
|
319
|
-
preds = (probs >=
|
|
577
|
+
preds = (probs >= self._classification_threshold).int()
|
|
320
578
|
y_pred_batch = preds.numpy()
|
|
321
579
|
y_prob_batch = probs.numpy()
|
|
322
580
|
y_true_batch = target.numpy()
|
|
581
|
+
|
|
582
|
+
elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
|
|
583
|
+
# Assumes model output is [N, 1, H, W] (logits for positive class)
|
|
584
|
+
probs_pos = torch.sigmoid(output) # Shape [N, 1, H, W]
|
|
585
|
+
preds = (probs_pos >= self._classification_threshold).int() # Shape [N, 1, H, W]
|
|
586
|
+
|
|
587
|
+
# Squeeze preds to [N, H, W] (class indices 0 or 1)
|
|
588
|
+
y_pred_batch = preds.squeeze(1).numpy()
|
|
589
|
+
|
|
590
|
+
# Create [N, 2, H, W] probs for consistency
|
|
591
|
+
probs_neg = 1.0 - probs_pos
|
|
592
|
+
y_prob_batch = torch.cat([probs_neg, probs_pos], dim=1).numpy()
|
|
593
|
+
|
|
594
|
+
# Handle target shape [N, 1, H, W] -> [N, H, W]
|
|
595
|
+
if target.ndim == 4 and target.shape[1] == 1:
|
|
596
|
+
target = target.squeeze(1)
|
|
597
|
+
y_true_batch = target.numpy()
|
|
323
598
|
|
|
324
|
-
elif self.kind ==
|
|
599
|
+
elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION:
|
|
325
600
|
# output shape [N, C, H, W]
|
|
326
601
|
probs = torch.softmax(output, dim=1)
|
|
327
602
|
preds = torch.argmax(probs, dim=1) # shape [N, H, W]
|
|
@@ -334,26 +609,192 @@ class MLTrainer:
|
|
|
334
609
|
y_true_batch = target.numpy()
|
|
335
610
|
|
|
336
611
|
yield y_pred_batch, y_prob_batch, y_true_batch
|
|
337
|
-
|
|
612
|
+
|
|
338
613
|
def evaluate(self,
|
|
339
614
|
save_dir: Union[str, Path],
|
|
340
|
-
|
|
341
|
-
|
|
615
|
+
model_checkpoint: Union[Path, Literal["latest", "current"]],
|
|
616
|
+
classification_threshold: Optional[float] = None,
|
|
617
|
+
test_data: Optional[Union[DataLoader, Dataset]] = None,
|
|
618
|
+
val_format_configuration: Optional[Union[
|
|
619
|
+
RegressionMetricsFormat,
|
|
620
|
+
MultiTargetRegressionMetricsFormat,
|
|
621
|
+
BinaryClassificationMetricsFormat,
|
|
622
|
+
MultiClassClassificationMetricsFormat,
|
|
623
|
+
BinaryImageClassificationMetricsFormat,
|
|
624
|
+
MultiClassImageClassificationMetricsFormat,
|
|
625
|
+
MultiLabelBinaryClassificationMetricsFormat,
|
|
626
|
+
BinarySegmentationMetricsFormat,
|
|
627
|
+
MultiClassSegmentationMetricsFormat
|
|
628
|
+
]]=None,
|
|
629
|
+
test_format_configuration: Optional[Union[
|
|
630
|
+
RegressionMetricsFormat,
|
|
631
|
+
MultiTargetRegressionMetricsFormat,
|
|
632
|
+
BinaryClassificationMetricsFormat,
|
|
633
|
+
MultiClassClassificationMetricsFormat,
|
|
634
|
+
BinaryImageClassificationMetricsFormat,
|
|
635
|
+
MultiClassImageClassificationMetricsFormat,
|
|
636
|
+
MultiLabelBinaryClassificationMetricsFormat,
|
|
637
|
+
BinarySegmentationMetricsFormat,
|
|
638
|
+
MultiClassSegmentationMetricsFormat,
|
|
639
|
+
]]=None):
|
|
342
640
|
"""
|
|
343
641
|
Evaluates the model, routing to the correct evaluation function based on task `kind`.
|
|
344
642
|
|
|
345
643
|
Args:
|
|
644
|
+
model_checkpoint ('auto' | Path | None):
|
|
645
|
+
- Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
|
|
646
|
+
- If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
|
|
647
|
+
- If 'current', use the current state of the trained model up the latest trained epoch.
|
|
346
648
|
save_dir (str | Path): Directory to save all reports and plots.
|
|
347
|
-
|
|
649
|
+
classification_threshold (float | None): Used for tasks using a binary approach (binary classification, binary segmentation, multilabel binary classification)
|
|
650
|
+
test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
|
|
651
|
+
val_format_configuration (object): Optional configuration for metric format output for the validation set.
|
|
652
|
+
test_format_configuration (object): Optional configuration for metric format output for the test set.
|
|
653
|
+
"""
|
|
654
|
+
# Validate model checkpoint
|
|
655
|
+
if isinstance(model_checkpoint, Path):
|
|
656
|
+
checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
|
|
657
|
+
elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
|
|
658
|
+
checkpoint_validated = model_checkpoint
|
|
659
|
+
else:
|
|
660
|
+
_LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
|
|
661
|
+
raise ValueError()
|
|
662
|
+
|
|
663
|
+
# Validate classification threshold
|
|
664
|
+
if self.kind not in MLTaskKeys.ALL_BINARY_TASKS:
|
|
665
|
+
# dummy value for tasks that do not need it
|
|
666
|
+
threshold_validated = 0.5
|
|
667
|
+
elif classification_threshold is None:
|
|
668
|
+
# it should have been provided for binary tasks
|
|
669
|
+
_LOGGER.error(f"The classification threshold must be provided for '{self.kind}'.")
|
|
670
|
+
raise ValueError()
|
|
671
|
+
elif classification_threshold <= 0.0 or classification_threshold >= 1.0:
|
|
672
|
+
# Invalid float
|
|
673
|
+
_LOGGER.error(f"A classification threshold of {classification_threshold} is invalid. Must be in the range (0.0 - 1.0).")
|
|
674
|
+
raise ValueError()
|
|
675
|
+
else:
|
|
676
|
+
threshold_validated = classification_threshold
|
|
677
|
+
|
|
678
|
+
# Validate val configuration
|
|
679
|
+
if val_format_configuration is not None:
|
|
680
|
+
if not isinstance(val_format_configuration, (RegressionMetricsFormat,
|
|
681
|
+
MultiTargetRegressionMetricsFormat,
|
|
682
|
+
BinaryClassificationMetricsFormat,
|
|
683
|
+
MultiClassClassificationMetricsFormat,
|
|
684
|
+
BinaryImageClassificationMetricsFormat,
|
|
685
|
+
MultiClassImageClassificationMetricsFormat,
|
|
686
|
+
MultiLabelBinaryClassificationMetricsFormat,
|
|
687
|
+
BinarySegmentationMetricsFormat,
|
|
688
|
+
MultiClassSegmentationMetricsFormat)):
|
|
689
|
+
_LOGGER.error(f"Invalid 'format_configuration': '{type(val_format_configuration)}'.")
|
|
690
|
+
raise ValueError()
|
|
691
|
+
else:
|
|
692
|
+
val_configuration_validated = val_format_configuration
|
|
693
|
+
else: # config is None
|
|
694
|
+
val_configuration_validated = None
|
|
695
|
+
|
|
696
|
+
# Validate directory
|
|
697
|
+
save_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
698
|
+
|
|
699
|
+
# Validate test data and dispatch
|
|
700
|
+
if test_data is not None:
|
|
701
|
+
if not isinstance(test_data, (DataLoader, Dataset)):
|
|
702
|
+
_LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
|
|
703
|
+
raise ValueError()
|
|
704
|
+
test_data_validated = test_data
|
|
705
|
+
|
|
706
|
+
validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
|
|
707
|
+
test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
|
|
708
|
+
|
|
709
|
+
# Dispatch validation set
|
|
710
|
+
_LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
|
|
711
|
+
self._evaluate(save_dir=validation_metrics_path,
|
|
712
|
+
model_checkpoint=checkpoint_validated,
|
|
713
|
+
classification_threshold=threshold_validated,
|
|
714
|
+
data=None,
|
|
715
|
+
format_configuration=val_configuration_validated)
|
|
716
|
+
|
|
717
|
+
# Validate test configuration
|
|
718
|
+
if test_format_configuration is not None:
|
|
719
|
+
if not isinstance(test_format_configuration, (RegressionMetricsFormat,
|
|
720
|
+
MultiTargetRegressionMetricsFormat,
|
|
721
|
+
BinaryClassificationMetricsFormat,
|
|
722
|
+
MultiClassClassificationMetricsFormat,
|
|
723
|
+
BinaryImageClassificationMetricsFormat,
|
|
724
|
+
MultiClassImageClassificationMetricsFormat,
|
|
725
|
+
MultiLabelBinaryClassificationMetricsFormat,
|
|
726
|
+
BinarySegmentationMetricsFormat,
|
|
727
|
+
MultiClassSegmentationMetricsFormat)):
|
|
728
|
+
warning_message_type = f"Invalid test_format_configuration': '{type(test_format_configuration)}'."
|
|
729
|
+
if val_configuration_validated is not None:
|
|
730
|
+
warning_message_type += " 'val_format_configuration' will be used for the test set metrics output."
|
|
731
|
+
test_configuration_validated = val_configuration_validated
|
|
732
|
+
else:
|
|
733
|
+
warning_message_type += " Using default format."
|
|
734
|
+
test_configuration_validated = None
|
|
735
|
+
_LOGGER.warning(warning_message_type)
|
|
736
|
+
else:
|
|
737
|
+
test_configuration_validated = test_format_configuration
|
|
738
|
+
else: #config is None
|
|
739
|
+
test_configuration_validated = None
|
|
740
|
+
|
|
741
|
+
# Dispatch test set
|
|
742
|
+
_LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
|
|
743
|
+
self._evaluate(save_dir=test_metrics_path,
|
|
744
|
+
model_checkpoint="current",
|
|
745
|
+
classification_threshold=threshold_validated,
|
|
746
|
+
data=test_data_validated,
|
|
747
|
+
format_configuration=test_configuration_validated)
|
|
748
|
+
else:
|
|
749
|
+
# Dispatch validation set
|
|
750
|
+
_LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
|
|
751
|
+
self._evaluate(save_dir=save_path,
|
|
752
|
+
model_checkpoint=checkpoint_validated,
|
|
753
|
+
classification_threshold=threshold_validated,
|
|
754
|
+
data=None,
|
|
755
|
+
format_configuration=val_configuration_validated)
|
|
756
|
+
|
|
757
|
+
def _evaluate(self,
|
|
758
|
+
save_dir: Union[str, Path],
|
|
759
|
+
model_checkpoint: Union[Path, Literal["latest", "current"]],
|
|
760
|
+
classification_threshold: float,
|
|
761
|
+
data: Optional[Union[DataLoader, Dataset]],
|
|
762
|
+
format_configuration: Optional[Union[
|
|
763
|
+
RegressionMetricsFormat,
|
|
764
|
+
MultiTargetRegressionMetricsFormat,
|
|
765
|
+
BinaryClassificationMetricsFormat,
|
|
766
|
+
MultiClassClassificationMetricsFormat,
|
|
767
|
+
BinaryImageClassificationMetricsFormat,
|
|
768
|
+
MultiClassImageClassificationMetricsFormat,
|
|
769
|
+
MultiLabelBinaryClassificationMetricsFormat,
|
|
770
|
+
BinarySegmentationMetricsFormat,
|
|
771
|
+
MultiClassSegmentationMetricsFormat
|
|
772
|
+
]]=None):
|
|
773
|
+
"""
|
|
774
|
+
Changed to a private helper function.
|
|
348
775
|
"""
|
|
349
|
-
|
|
776
|
+
dataset_for_artifacts = None
|
|
350
777
|
eval_loader = None
|
|
351
|
-
|
|
778
|
+
|
|
779
|
+
# set threshold
|
|
780
|
+
self._classification_threshold = classification_threshold
|
|
781
|
+
|
|
782
|
+
# load model checkpoint
|
|
783
|
+
if isinstance(model_checkpoint, Path):
|
|
784
|
+
self._load_checkpoint(path=model_checkpoint)
|
|
785
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
|
|
786
|
+
path_to_latest = self._checkpoint_callback.best_checkpoint_path
|
|
787
|
+
self._load_checkpoint(path_to_latest)
|
|
788
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
|
|
789
|
+
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
|
|
790
|
+
raise ValueError()
|
|
791
|
+
|
|
792
|
+
# Dataloader
|
|
352
793
|
if isinstance(data, DataLoader):
|
|
353
794
|
eval_loader = data
|
|
354
795
|
# Try to get the dataset from the loader for fetching target names
|
|
355
796
|
if hasattr(data, 'dataset'):
|
|
356
|
-
|
|
797
|
+
dataset_for_artifacts = data.dataset # type: ignore
|
|
357
798
|
elif isinstance(data, Dataset):
|
|
358
799
|
# Create a new loader from the provided dataset
|
|
359
800
|
eval_loader = DataLoader(data,
|
|
@@ -361,19 +802,19 @@ class MLTrainer:
|
|
|
361
802
|
shuffle=False,
|
|
362
803
|
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
363
804
|
pin_memory=(self.device.type == "cuda"))
|
|
364
|
-
|
|
805
|
+
dataset_for_artifacts = data
|
|
365
806
|
else: # data is None, use the trainer's default test dataset
|
|
366
|
-
if self.
|
|
367
|
-
_LOGGER.error("Cannot evaluate. No data provided and no
|
|
807
|
+
if self.validation_dataset is None:
|
|
808
|
+
_LOGGER.error("Cannot evaluate. No data provided and no validation dataset available in the trainer.")
|
|
368
809
|
raise ValueError()
|
|
369
810
|
# Create a fresh DataLoader from the test_dataset
|
|
370
|
-
eval_loader = DataLoader(self.
|
|
811
|
+
eval_loader = DataLoader(self.validation_dataset,
|
|
371
812
|
batch_size=self._batch_size,
|
|
372
813
|
shuffle=False,
|
|
373
814
|
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
374
815
|
pin_memory=(self.device.type == "cuda"))
|
|
375
816
|
|
|
376
|
-
|
|
817
|
+
dataset_for_artifacts = self.validation_dataset
|
|
377
818
|
|
|
378
819
|
if eval_loader is None:
|
|
379
820
|
_LOGGER.error("Cannot evaluate. No valid data was provided or found.")
|
|
@@ -396,36 +837,83 @@ class MLTrainer:
|
|
|
396
837
|
y_prob = np.concatenate(all_probs) if all_probs else None
|
|
397
838
|
|
|
398
839
|
# --- Routing Logic ---
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
840
|
+
# Single-target regression
|
|
841
|
+
if self.kind == MLTaskKeys.REGRESSION:
|
|
842
|
+
# Check configuration
|
|
843
|
+
config = None
|
|
844
|
+
if format_configuration and isinstance(format_configuration, RegressionMetricsFormat):
|
|
845
|
+
config = format_configuration
|
|
846
|
+
elif format_configuration:
|
|
847
|
+
_LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
|
|
848
|
+
|
|
849
|
+
regression_metrics(y_true=y_true.flatten(),
|
|
850
|
+
y_pred=y_pred.flatten(),
|
|
851
|
+
save_dir=save_dir,
|
|
852
|
+
config=config)
|
|
853
|
+
|
|
854
|
+
# single target classification
|
|
855
|
+
elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION,
|
|
856
|
+
MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
|
|
857
|
+
MLTaskKeys.MULTICLASS_CLASSIFICATION,
|
|
858
|
+
MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
|
|
859
|
+
# get the class map if it exists
|
|
860
|
+
try:
|
|
861
|
+
class_map = dataset_for_artifacts.class_map # type: ignore
|
|
862
|
+
except AttributeError:
|
|
863
|
+
_LOGGER.warning(f"Dataset has no 'class_map' attribute. Using generics.")
|
|
864
|
+
class_map = None
|
|
414
865
|
else:
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
866
|
+
if not isinstance(class_map, dict):
|
|
867
|
+
_LOGGER.warning(f"Dataset has a 'class_map' attribute, but it is not a dictionary: '{type(class_map)}'.")
|
|
868
|
+
class_map = None
|
|
869
|
+
|
|
870
|
+
# Check configuration
|
|
871
|
+
config = None
|
|
872
|
+
if format_configuration:
|
|
873
|
+
if self.kind == MLTaskKeys.BINARY_CLASSIFICATION and isinstance(format_configuration, BinaryClassificationMetricsFormat):
|
|
874
|
+
config = format_configuration
|
|
875
|
+
elif self.kind == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION and isinstance(format_configuration, BinaryImageClassificationMetricsFormat):
|
|
876
|
+
config = format_configuration
|
|
877
|
+
elif self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION and isinstance(format_configuration, MultiClassClassificationMetricsFormat):
|
|
878
|
+
config = format_configuration
|
|
879
|
+
elif self.kind == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION and isinstance(format_configuration, MultiClassImageClassificationMetricsFormat):
|
|
880
|
+
config = format_configuration
|
|
881
|
+
else:
|
|
882
|
+
_LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
|
|
883
|
+
|
|
884
|
+
classification_metrics(save_dir=save_dir,
|
|
885
|
+
y_true=y_true,
|
|
886
|
+
y_pred=y_pred,
|
|
887
|
+
y_prob=y_prob,
|
|
888
|
+
class_map=class_map,
|
|
889
|
+
config=config)
|
|
890
|
+
|
|
891
|
+
# multitarget regression
|
|
892
|
+
elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION:
|
|
418
893
|
try:
|
|
419
|
-
target_names =
|
|
894
|
+
target_names = dataset_for_artifacts.target_names # type: ignore
|
|
420
895
|
except AttributeError:
|
|
421
896
|
num_targets = y_true.shape[1]
|
|
422
897
|
target_names = [f"target_{i}" for i in range(num_targets)]
|
|
423
898
|
_LOGGER.warning(f"Dataset has no 'target_names' attribute. Using generic names.")
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
899
|
+
|
|
900
|
+
# Check configuration
|
|
901
|
+
config = None
|
|
902
|
+
if format_configuration and isinstance(format_configuration, MultiTargetRegressionMetricsFormat):
|
|
903
|
+
config = format_configuration
|
|
904
|
+
elif format_configuration:
|
|
905
|
+
_LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
|
|
906
|
+
|
|
907
|
+
multi_target_regression_metrics(y_true=y_true,
|
|
908
|
+
y_pred=y_pred,
|
|
909
|
+
target_names=target_names,
|
|
910
|
+
save_dir=save_dir,
|
|
911
|
+
config=config)
|
|
912
|
+
|
|
913
|
+
# multi-label binary classification
|
|
914
|
+
elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
|
|
427
915
|
try:
|
|
428
|
-
target_names =
|
|
916
|
+
target_names = dataset_for_artifacts.target_names # type: ignore
|
|
429
917
|
except AttributeError:
|
|
430
918
|
num_targets = y_true.shape[1]
|
|
431
919
|
target_names = [f"label_{i}" for i in range(num_targets)]
|
|
@@ -435,44 +923,55 @@ class MLTrainer:
|
|
|
435
923
|
_LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
|
|
436
924
|
return
|
|
437
925
|
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
926
|
+
# Check configuration
|
|
927
|
+
config = None
|
|
928
|
+
if format_configuration and isinstance(format_configuration, MultiLabelBinaryClassificationMetricsFormat):
|
|
929
|
+
config = format_configuration
|
|
930
|
+
elif format_configuration:
|
|
931
|
+
_LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
|
|
932
|
+
|
|
933
|
+
multi_label_classification_metrics(y_true=y_true,
|
|
934
|
+
y_pred=y_pred,
|
|
935
|
+
y_prob=y_prob,
|
|
936
|
+
target_names=target_names,
|
|
937
|
+
save_dir=save_dir,
|
|
938
|
+
config=config)
|
|
939
|
+
|
|
940
|
+
# Segmentation tasks
|
|
941
|
+
elif self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
|
|
451
942
|
class_names = None
|
|
452
943
|
try:
|
|
453
944
|
# Try to get 'classes' from VisionDatasetMaker
|
|
454
|
-
if hasattr(
|
|
455
|
-
class_names =
|
|
945
|
+
if hasattr(dataset_for_artifacts, 'classes'):
|
|
946
|
+
class_names = dataset_for_artifacts.classes # type: ignore
|
|
456
947
|
# Fallback for Subset
|
|
457
|
-
elif hasattr(
|
|
458
|
-
class_names =
|
|
948
|
+
elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'): # type: ignore
|
|
949
|
+
class_names = dataset_for_artifacts.dataset.classes # type: ignore
|
|
459
950
|
except AttributeError:
|
|
460
951
|
pass # class_names is still None
|
|
461
952
|
|
|
462
953
|
if class_names is None:
|
|
463
954
|
try:
|
|
464
955
|
# Fallback to 'target_names'
|
|
465
|
-
class_names =
|
|
956
|
+
class_names = dataset_for_artifacts.target_names # type: ignore
|
|
466
957
|
except AttributeError:
|
|
467
958
|
# Fallback to inferring from labels
|
|
468
959
|
labels = np.unique(y_true)
|
|
469
960
|
class_names = [f"Class {i}" for i in labels]
|
|
470
961
|
_LOGGER.warning(f"Dataset has no 'classes' or 'target_names' attribute. Using generic names.")
|
|
471
962
|
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
963
|
+
# Check configuration
|
|
964
|
+
config = None
|
|
965
|
+
if format_configuration and isinstance(format_configuration, (BinarySegmentationMetricsFormat, MultiClassSegmentationMetricsFormat)):
|
|
966
|
+
config = format_configuration
|
|
967
|
+
elif format_configuration:
|
|
968
|
+
_LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
|
|
969
|
+
|
|
970
|
+
segmentation_metrics(y_true=y_true,
|
|
971
|
+
y_pred=y_pred,
|
|
972
|
+
save_dir=save_dir,
|
|
973
|
+
class_names=class_names,
|
|
974
|
+
config=config)
|
|
476
975
|
|
|
477
976
|
def explain(self,
|
|
478
977
|
save_dir: Union[str,Path],
|
|
@@ -500,34 +999,52 @@ class MLTrainer:
|
|
|
500
999
|
explainer_type (Literal['deep', 'kernel']): The explainer to use.
|
|
501
1000
|
- 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
|
|
502
1001
|
- 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples'< 100.
|
|
503
|
-
"""
|
|
504
|
-
#
|
|
1002
|
+
"""
|
|
1003
|
+
# memory efficient helper
|
|
505
1004
|
def _get_random_sample(dataset: Dataset, num_samples: int):
|
|
1005
|
+
"""
|
|
1006
|
+
Memory-efficiently samples data from a dataset.
|
|
1007
|
+
"""
|
|
506
1008
|
if dataset is None:
|
|
507
1009
|
return None
|
|
508
1010
|
|
|
1011
|
+
dataset_len = len(dataset) # type: ignore
|
|
1012
|
+
if dataset_len == 0:
|
|
1013
|
+
return None
|
|
1014
|
+
|
|
509
1015
|
# For MPS devices, num_workers must be 0 to ensure stability
|
|
510
1016
|
loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
|
|
511
1017
|
|
|
1018
|
+
# Ensure batch_size is not larger than the dataset itself
|
|
1019
|
+
batch_size = min(num_samples, 64, dataset_len)
|
|
1020
|
+
|
|
512
1021
|
loader = DataLoader(
|
|
513
1022
|
dataset,
|
|
514
|
-
batch_size=
|
|
515
|
-
shuffle=
|
|
1023
|
+
batch_size=batch_size,
|
|
1024
|
+
shuffle=True, # Shuffle to get random samples
|
|
516
1025
|
num_workers=loader_workers
|
|
517
1026
|
)
|
|
518
1027
|
|
|
519
|
-
|
|
520
|
-
|
|
1028
|
+
collected_features = []
|
|
1029
|
+
num_collected = 0
|
|
1030
|
+
|
|
1031
|
+
for features, _ in loader:
|
|
1032
|
+
collected_features.append(features)
|
|
1033
|
+
num_collected += features.size(0)
|
|
1034
|
+
if num_collected >= num_samples:
|
|
1035
|
+
break # Stop once we have enough samples
|
|
1036
|
+
|
|
1037
|
+
if not collected_features:
|
|
521
1038
|
return None
|
|
522
1039
|
|
|
523
|
-
full_data = torch.cat(
|
|
1040
|
+
full_data = torch.cat(collected_features, dim=0)
|
|
524
1041
|
|
|
525
|
-
|
|
526
|
-
|
|
1042
|
+
# If we collected more than needed, trim it down
|
|
1043
|
+
if full_data.size(0) > num_samples:
|
|
1044
|
+
return full_data[:num_samples]
|
|
527
1045
|
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
1046
|
+
return full_data
|
|
1047
|
+
|
|
531
1048
|
# print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
|
|
532
1049
|
|
|
533
1050
|
# 1. Get background data from the trainer's train_dataset
|
|
@@ -537,7 +1054,7 @@ class MLTrainer:
|
|
|
537
1054
|
return
|
|
538
1055
|
|
|
539
1056
|
# 2. Determine target dataset and get explanation instances
|
|
540
|
-
target_dataset = explain_dataset if explain_dataset is not None else self.
|
|
1057
|
+
target_dataset = explain_dataset if explain_dataset is not None else self.validation_dataset
|
|
541
1058
|
instances_to_explain = _get_random_sample(target_dataset, n_samples)
|
|
542
1059
|
if instances_to_explain is None:
|
|
543
1060
|
_LOGGER.error("Explanation dataset is empty or invalid. Skipping SHAP analysis.")
|
|
@@ -556,7 +1073,7 @@ class MLTrainer:
|
|
|
556
1073
|
self.model.to(self.device)
|
|
557
1074
|
|
|
558
1075
|
# 3. Call the plotting function
|
|
559
|
-
if self.kind in [
|
|
1076
|
+
if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
|
|
560
1077
|
shap_summary_plot(
|
|
561
1078
|
model=self.model,
|
|
562
1079
|
background_data=background_data,
|
|
@@ -566,7 +1083,7 @@ class MLTrainer:
|
|
|
566
1083
|
explainer_type=explainer_type,
|
|
567
1084
|
device=self.device
|
|
568
1085
|
)
|
|
569
|
-
elif self.kind in [
|
|
1086
|
+
elif self.kind in [MLTaskKeys.MULTITARGET_REGRESSION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
|
|
570
1087
|
# try to get target names
|
|
571
1088
|
if target_names is None:
|
|
572
1089
|
target_names = []
|
|
@@ -640,13 +1157,11 @@ class MLTrainer:
|
|
|
640
1157
|
|
|
641
1158
|
# --- Step 1: Check if the model supports this explanation ---
|
|
642
1159
|
if not getattr(self.model, 'has_interpretable_attention', False):
|
|
643
|
-
_LOGGER.warning(
|
|
644
|
-
"Model is not flagged for interpretable attention analysis. Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
|
|
645
|
-
)
|
|
1160
|
+
_LOGGER.warning("Model is not compatible with interpretable attention analysis. Skipping.")
|
|
646
1161
|
return
|
|
647
1162
|
|
|
648
1163
|
# --- Step 2: Set up the dataloader ---
|
|
649
|
-
dataset_to_use = explain_dataset if explain_dataset is not None else self.
|
|
1164
|
+
dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
|
|
650
1165
|
if not isinstance(dataset_to_use, Dataset):
|
|
651
1166
|
_LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
|
|
652
1167
|
return
|
|
@@ -681,40 +1196,101 @@ class MLTrainer:
|
|
|
681
1196
|
)
|
|
682
1197
|
else:
|
|
683
1198
|
_LOGGER.error("No attention weights were collected from the model.")
|
|
684
|
-
|
|
685
|
-
def _callbacks_hook(self, method_name: str, *args, **kwargs):
|
|
686
|
-
"""Calls the specified method on all callbacks."""
|
|
687
|
-
for callback in self.callbacks:
|
|
688
|
-
method = getattr(callback, method_name)
|
|
689
|
-
method(*args, **kwargs)
|
|
690
|
-
|
|
691
|
-
def to_cpu(self):
|
|
692
|
-
"""
|
|
693
|
-
Moves the model to the CPU and updates the trainer's device setting.
|
|
694
1199
|
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
1200
|
+
def finalize_model_training(self,
|
|
1201
|
+
model_checkpoint: Union[Path, Literal['latest', 'current']],
|
|
1202
|
+
save_dir: Union[str, Path],
|
|
1203
|
+
finalize_config: Union[FinalizeRegression,
|
|
1204
|
+
FinalizeMultiTargetRegression,
|
|
1205
|
+
FinalizeBinaryClassification,
|
|
1206
|
+
FinalizeBinaryImageClassification,
|
|
1207
|
+
FinalizeMultiClassClassification,
|
|
1208
|
+
FinalizeMultiClassImageClassification,
|
|
1209
|
+
FinalizeBinarySegmentation,
|
|
1210
|
+
FinalizeMultiClassSegmentation,
|
|
1211
|
+
FinalizeMultiLabelBinaryClassification]):
|
|
702
1212
|
"""
|
|
703
|
-
|
|
1213
|
+
Saves a finalized, "inference-ready" model state to a .pth file.
|
|
1214
|
+
|
|
1215
|
+
This method saves the model's `state_dict`, the final epoch number, and optional configuration for the task at hand.
|
|
704
1216
|
|
|
705
1217
|
Args:
|
|
706
|
-
|
|
1218
|
+
model_checkpoint (Path | "latest" | "current"):
|
|
1219
|
+
- Path: Loads the model state from a specific checkpoint file.
|
|
1220
|
+
- "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
|
|
1221
|
+
- "current": Uses the model's state as it is at the end of the `fit()` call.
|
|
1222
|
+
save_dir (str | Path): The directory to save the finalized model.
|
|
1223
|
+
finalize_config (object): A data class instance specific to the ML task containing task-specific metadata required for inference.
|
|
707
1224
|
"""
|
|
708
|
-
self.
|
|
709
|
-
|
|
710
|
-
|
|
1225
|
+
if self.kind == MLTaskKeys.REGRESSION and not isinstance(finalize_config, FinalizeRegression):
|
|
1226
|
+
_LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeRegression', but got {type(finalize_config).__name__}.")
|
|
1227
|
+
raise TypeError()
|
|
1228
|
+
elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION and not isinstance(finalize_config, FinalizeMultiTargetRegression):
|
|
1229
|
+
_LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiTargetRegression', but got {type(finalize_config).__name__}.")
|
|
1230
|
+
raise TypeError()
|
|
1231
|
+
elif self.kind == MLTaskKeys.BINARY_CLASSIFICATION and not isinstance(finalize_config, FinalizeBinaryClassification):
|
|
1232
|
+
_LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeBinaryClassification', but got {type(finalize_config).__name__}.")
|
|
1233
|
+
raise TypeError()
|
|
1234
|
+
elif self.kind == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION and not isinstance(finalize_config, FinalizeBinaryImageClassification):
|
|
1235
|
+
_LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeBinaryImageClassification', but got {type(finalize_config).__name__}.")
|
|
1236
|
+
raise TypeError()
|
|
1237
|
+
elif self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION and not isinstance(finalize_config, FinalizeMultiClassClassification):
|
|
1238
|
+
_LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiClassClassification', but got {type(finalize_config).__name__}.")
|
|
1239
|
+
raise TypeError()
|
|
1240
|
+
elif self.kind == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION and not isinstance(finalize_config, FinalizeMultiClassImageClassification):
|
|
1241
|
+
_LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiClassImageClassification', but got {type(finalize_config).__name__}.")
|
|
1242
|
+
raise TypeError()
|
|
1243
|
+
elif self.kind == MLTaskKeys.BINARY_SEGMENTATION and not isinstance(finalize_config, FinalizeBinarySegmentation):
|
|
1244
|
+
_LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeBinarySegmentation', but got {type(finalize_config).__name__}.")
|
|
1245
|
+
raise TypeError()
|
|
1246
|
+
elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION and not isinstance(finalize_config, FinalizeMultiClassSegmentation):
|
|
1247
|
+
_LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiClassSegmentation', but got {type(finalize_config).__name__}.")
|
|
1248
|
+
raise TypeError()
|
|
1249
|
+
elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION and not isinstance(finalize_config, FinalizeMultiLabelBinaryClassification):
|
|
1250
|
+
_LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiLabelBinaryClassification', but got {type(finalize_config).__name__}.")
|
|
1251
|
+
raise TypeError()
|
|
1252
|
+
|
|
1253
|
+
# handle save path
|
|
1254
|
+
dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
1255
|
+
full_path = dir_path / finalize_config.filename
|
|
1256
|
+
|
|
1257
|
+
# handle checkpoint
|
|
1258
|
+
self._load_model_state_for_finalizing(model_checkpoint)
|
|
1259
|
+
|
|
1260
|
+
# Create finalized data
|
|
1261
|
+
finalized_data = {
|
|
1262
|
+
PyTorchCheckpointKeys.EPOCH: self.epoch,
|
|
1263
|
+
PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
|
|
1264
|
+
}
|
|
1265
|
+
|
|
1266
|
+
# Parse config
|
|
1267
|
+
if finalize_config.target_name is not None:
|
|
1268
|
+
finalized_data[PyTorchCheckpointKeys.TARGET_NAME] = finalize_config.target_name
|
|
1269
|
+
if finalize_config.target_names is not None:
|
|
1270
|
+
finalized_data[PyTorchCheckpointKeys.TARGET_NAMES] = finalize_config.target_names
|
|
1271
|
+
if finalize_config.classification_threshold is not None:
|
|
1272
|
+
finalized_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD] = finalize_config.classification_threshold
|
|
1273
|
+
if finalize_config.class_map is not None:
|
|
1274
|
+
finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map
|
|
1275
|
+
|
|
1276
|
+
# Save model file
|
|
1277
|
+
torch.save(finalized_data, full_path)
|
|
1278
|
+
|
|
1279
|
+
_LOGGER.info(f"Finalized model file saved to '{full_path}'")
|
|
711
1280
|
|
|
712
1281
|
|
|
713
1282
|
# Object Detection Trainer
|
|
714
|
-
class
|
|
715
|
-
def __init__(self, model: nn.Module,
|
|
1283
|
+
class DragonDetectionTrainer(_BaseDragonTrainer):
|
|
1284
|
+
def __init__(self, model: nn.Module,
|
|
1285
|
+
train_dataset: Dataset,
|
|
1286
|
+
validation_dataset: Dataset,
|
|
716
1287
|
collate_fn: Callable, optimizer: torch.optim.Optimizer,
|
|
717
|
-
device: Union[Literal['cuda', 'mps', 'cpu'],str],
|
|
1288
|
+
device: Union[Literal['cuda', 'mps', 'cpu'],str],
|
|
1289
|
+
checkpoint_callback: Optional[DragonModelCheckpoint],
|
|
1290
|
+
early_stopping_callback: Optional[DragonEarlyStopping],
|
|
1291
|
+
lr_scheduler_callback: Optional[DragonLRScheduler],
|
|
1292
|
+
extra_callbacks: Optional[List[_Callback]] = None,
|
|
1293
|
+
dataloader_workers: int = 2):
|
|
718
1294
|
"""
|
|
719
1295
|
Automates the training process of an Object Detection Model (e.g., DragonFastRCNN).
|
|
720
1296
|
|
|
@@ -723,58 +1299,36 @@ class ObjectDetectionTrainer:
|
|
|
723
1299
|
Args:
|
|
724
1300
|
model (nn.Module): The PyTorch object detection model to train.
|
|
725
1301
|
train_dataset (Dataset): The training dataset.
|
|
726
|
-
|
|
1302
|
+
validation_dataset (Dataset): The testing/validation dataset.
|
|
727
1303
|
collate_fn (Callable): The collate function from `ObjectDetectionDatasetMaker.collate_fn`.
|
|
728
1304
|
optimizer (torch.optim.Optimizer): The optimizer.
|
|
729
1305
|
device (str): The device to run training on ('cpu', 'cuda', 'mps').
|
|
730
1306
|
dataloader_workers (int): Subprocesses for data loading.
|
|
731
|
-
|
|
1307
|
+
checkpoint_callback (DragonModelCheckpoint | None): Callback to save the model.
|
|
1308
|
+
early_stopping_callback (DragonEarlyStopping | None): Callback to stop training early.
|
|
1309
|
+
lr_scheduler_callback (DragonLRScheduler | None): Callback to manage the LR scheduler.
|
|
1310
|
+
extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
|
|
732
1311
|
|
|
733
1312
|
## Note:
|
|
734
1313
|
This trainer is specialized. It does not take a `criterion` because object detection models like Faster R-CNN return a dictionary of losses directly from their forward pass during training.
|
|
735
1314
|
"""
|
|
736
|
-
|
|
1315
|
+
# Call the base class constructor with common parameters
|
|
1316
|
+
super().__init__(
|
|
1317
|
+
model=model,
|
|
1318
|
+
optimizer=optimizer,
|
|
1319
|
+
device=device,
|
|
1320
|
+
dataloader_workers=dataloader_workers,
|
|
1321
|
+
checkpoint_callback=checkpoint_callback,
|
|
1322
|
+
early_stopping_callback=early_stopping_callback,
|
|
1323
|
+
lr_scheduler_callback=lr_scheduler_callback,
|
|
1324
|
+
extra_callbacks=extra_callbacks
|
|
1325
|
+
)
|
|
1326
|
+
|
|
737
1327
|
self.train_dataset = train_dataset
|
|
738
|
-
self.
|
|
739
|
-
self.kind =
|
|
1328
|
+
self.validation_dataset = validation_dataset # <-- Renamed
|
|
1329
|
+
self.kind = MLTaskKeys.OBJECT_DETECTION
|
|
740
1330
|
self.collate_fn = collate_fn
|
|
741
1331
|
self.criterion = None # Criterion is handled inside the model
|
|
742
|
-
self.optimizer = optimizer
|
|
743
|
-
self.scheduler = None
|
|
744
|
-
self.device = self._validate_device(device)
|
|
745
|
-
self.dataloader_workers = dataloader_workers
|
|
746
|
-
|
|
747
|
-
# Callback handler - History and TqdmProgressBar are added by default
|
|
748
|
-
default_callbacks = [History(), TqdmProgressBar()]
|
|
749
|
-
user_callbacks = callbacks if callbacks is not None else []
|
|
750
|
-
self.callbacks = default_callbacks + user_callbacks
|
|
751
|
-
self._set_trainer_on_callbacks()
|
|
752
|
-
|
|
753
|
-
# Internal state
|
|
754
|
-
self.train_loader = None
|
|
755
|
-
self.test_loader = None
|
|
756
|
-
self.history = {}
|
|
757
|
-
self.epoch = 0
|
|
758
|
-
self.epochs = 0 # Total epochs for the fit run
|
|
759
|
-
self.start_epoch = 1
|
|
760
|
-
self.stop_training = False
|
|
761
|
-
self._batch_size = 10
|
|
762
|
-
|
|
763
|
-
def _validate_device(self, device: str) -> torch.device:
|
|
764
|
-
"""Validates the selected device and returns a torch.device object."""
|
|
765
|
-
device_lower = device.lower()
|
|
766
|
-
if "cuda" in device_lower and not torch.cuda.is_available():
|
|
767
|
-
_LOGGER.warning("CUDA not available, switching to CPU.")
|
|
768
|
-
device = "cpu"
|
|
769
|
-
elif device_lower == "mps" and not torch.backends.mps.is_available():
|
|
770
|
-
_LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
|
|
771
|
-
device = "cpu"
|
|
772
|
-
return torch.device(device)
|
|
773
|
-
|
|
774
|
-
def _set_trainer_on_callbacks(self):
|
|
775
|
-
"""Gives each callback a reference to this trainer instance."""
|
|
776
|
-
for callback in self.callbacks:
|
|
777
|
-
callback.set_trainer(self)
|
|
778
1332
|
|
|
779
1333
|
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
780
1334
|
"""Initializes the DataLoaders with the object detection collate_fn."""
|
|
@@ -786,125 +1340,25 @@ class ObjectDetectionTrainer:
|
|
|
786
1340
|
batch_size=batch_size,
|
|
787
1341
|
shuffle=shuffle,
|
|
788
1342
|
num_workers=loader_workers,
|
|
789
|
-
pin_memory=("cuda" in self.device.type),
|
|
790
|
-
collate_fn=self.collate_fn # Use the provided collate function
|
|
1343
|
+
pin_memory=("cuda" in self.device.type),
|
|
1344
|
+
collate_fn=self.collate_fn, # Use the provided collate function
|
|
1345
|
+
drop_last=True
|
|
791
1346
|
)
|
|
792
1347
|
|
|
793
|
-
self.
|
|
794
|
-
dataset=self.
|
|
1348
|
+
self.validation_loader = DataLoader(
|
|
1349
|
+
dataset=self.validation_dataset,
|
|
795
1350
|
batch_size=batch_size,
|
|
796
1351
|
shuffle=False,
|
|
797
1352
|
num_workers=loader_workers,
|
|
798
1353
|
pin_memory=("cuda" in self.device.type),
|
|
799
1354
|
collate_fn=self.collate_fn # Use the provided collate function
|
|
800
1355
|
)
|
|
801
|
-
|
|
802
|
-
def _load_checkpoint(self, path: Union[str, Path]):
|
|
803
|
-
"""Loads a training checkpoint to resume training."""
|
|
804
|
-
p = make_fullpath(path, enforce="file")
|
|
805
|
-
_LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
|
|
806
|
-
|
|
807
|
-
try:
|
|
808
|
-
checkpoint = torch.load(p, map_location=self.device)
|
|
809
|
-
|
|
810
|
-
if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
|
|
811
|
-
_LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
|
|
812
|
-
raise KeyError()
|
|
813
|
-
|
|
814
|
-
self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
|
|
815
|
-
self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
|
|
816
|
-
self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
|
|
817
|
-
|
|
818
|
-
# --- Scheduler State Loading Logic ---
|
|
819
|
-
scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
|
|
820
|
-
scheduler_object_exists = self.scheduler is not None
|
|
821
|
-
|
|
822
|
-
if scheduler_object_exists and scheduler_state_exists:
|
|
823
|
-
# Case 1: Both exist. Attempt to load.
|
|
824
|
-
try:
|
|
825
|
-
self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
|
|
826
|
-
scheduler_name = self.scheduler.__class__.__name__
|
|
827
|
-
_LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
|
|
828
|
-
except Exception as e:
|
|
829
|
-
# Loading failed, likely a mismatch
|
|
830
|
-
scheduler_name = self.scheduler.__class__.__name__
|
|
831
|
-
_LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
|
|
832
|
-
raise e
|
|
833
|
-
|
|
834
|
-
elif scheduler_object_exists and not scheduler_state_exists:
|
|
835
|
-
# Case 2: Scheduler provided, but no state in checkpoint.
|
|
836
|
-
scheduler_name = self.scheduler.__class__.__name__
|
|
837
|
-
_LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
|
|
838
|
-
|
|
839
|
-
elif not scheduler_object_exists and scheduler_state_exists:
|
|
840
|
-
# Case 3: State in checkpoint, but no scheduler provided.
|
|
841
|
-
_LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
|
|
842
|
-
raise ValueError()
|
|
843
|
-
|
|
844
|
-
# Restore callback states
|
|
845
|
-
for cb in self.callbacks:
|
|
846
|
-
if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
|
|
847
|
-
cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
|
|
848
|
-
_LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
|
|
849
|
-
|
|
850
|
-
_LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
|
|
851
|
-
|
|
852
|
-
except Exception as e:
|
|
853
|
-
_LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
|
|
854
|
-
raise
|
|
855
|
-
|
|
856
|
-
def fit(self,
|
|
857
|
-
epochs: int = 10,
|
|
858
|
-
batch_size: int = 10,
|
|
859
|
-
shuffle: bool = True,
|
|
860
|
-
resume_from_checkpoint: Optional[Union[str, Path]] = None):
|
|
861
|
-
"""
|
|
862
|
-
Starts the training-validation process of the model.
|
|
863
|
-
|
|
864
|
-
Returns the "History" callback dictionary.
|
|
865
|
-
|
|
866
|
-
Args:
|
|
867
|
-
epochs (int): The total number of epochs to train for.
|
|
868
|
-
batch_size (int): The number of samples per batch.
|
|
869
|
-
shuffle (bool): Whether to shuffle the training data at each epoch.
|
|
870
|
-
resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
|
|
871
|
-
"""
|
|
872
|
-
self.epochs = epochs
|
|
873
|
-
self._batch_size = batch_size
|
|
874
|
-
self._create_dataloaders(self._batch_size, shuffle)
|
|
875
|
-
self.model.to(self.device)
|
|
876
|
-
|
|
877
|
-
if resume_from_checkpoint:
|
|
878
|
-
self._load_checkpoint(resume_from_checkpoint)
|
|
879
|
-
|
|
880
|
-
# Reset stop_training flag on the trainer
|
|
881
|
-
self.stop_training = False
|
|
882
|
-
|
|
883
|
-
self._callbacks_hook('on_train_begin')
|
|
884
|
-
|
|
885
|
-
for epoch in range(self.start_epoch, self.epochs + 1):
|
|
886
|
-
self.epoch = epoch
|
|
887
|
-
epoch_logs = {}
|
|
888
|
-
self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
|
|
889
|
-
|
|
890
|
-
train_logs = self._train_step()
|
|
891
|
-
epoch_logs.update(train_logs)
|
|
892
|
-
|
|
893
|
-
val_logs = self._validation_step()
|
|
894
|
-
epoch_logs.update(val_logs)
|
|
895
|
-
|
|
896
|
-
self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
|
|
897
|
-
|
|
898
|
-
# Check the early stopping flag
|
|
899
|
-
if self.stop_training:
|
|
900
|
-
break
|
|
901
1356
|
|
|
902
|
-
self._callbacks_hook('on_train_end')
|
|
903
|
-
return self.history
|
|
904
|
-
|
|
905
1357
|
def _train_step(self):
|
|
906
1358
|
self.model.train()
|
|
907
1359
|
running_loss = 0.0
|
|
1360
|
+
total_samples = 0
|
|
1361
|
+
|
|
908
1362
|
for batch_idx, (images, targets) in enumerate(self.train_loader): # type: ignore
|
|
909
1363
|
# images is a tuple of tensors, targets is a tuple of dicts
|
|
910
1364
|
batch_size = len(images)
|
|
@@ -941,21 +1395,28 @@ class ObjectDetectionTrainer:
|
|
|
941
1395
|
# Calculate batch loss and update running loss for the epoch
|
|
942
1396
|
batch_loss = loss.item()
|
|
943
1397
|
running_loss += batch_loss * batch_size
|
|
1398
|
+
total_samples += batch_size # <-- Accumulate total samples
|
|
944
1399
|
|
|
945
1400
|
# Add the batch loss to the logs and call the end-of-batch hook
|
|
946
1401
|
batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss # type: ignore
|
|
947
1402
|
self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
|
|
1403
|
+
|
|
1404
|
+
# Calculate loss using the correct denominator
|
|
1405
|
+
if total_samples == 0:
|
|
1406
|
+
_LOGGER.warning("No samples processed in _train_step. Returning 0 loss.")
|
|
1407
|
+
return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
|
|
948
1408
|
|
|
949
|
-
return {PyTorchLogKeys.TRAIN_LOSS: running_loss /
|
|
1409
|
+
return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}
|
|
950
1410
|
|
|
951
1411
|
def _validation_step(self):
|
|
952
1412
|
self.model.train() # Set to train mode even for validation loss calculation
|
|
953
|
-
# as model internals (e.g., proposals) might differ,
|
|
954
|
-
#
|
|
955
|
-
# We use torch.no_grad() to prevent gradient updates.
|
|
1413
|
+
# as model internals (e.g., proposals) might differ, but we still need loss_dict.
|
|
1414
|
+
# use torch.no_grad() to prevent gradient updates.
|
|
956
1415
|
running_loss = 0.0
|
|
1416
|
+
total_samples = 0
|
|
1417
|
+
|
|
957
1418
|
with torch.no_grad():
|
|
958
|
-
for images, targets in self.
|
|
1419
|
+
for images, targets in self.validation_loader: # type: ignore
|
|
959
1420
|
batch_size = len(images)
|
|
960
1421
|
|
|
961
1422
|
# Move data to device
|
|
@@ -973,25 +1434,105 @@ class ObjectDetectionTrainer:
|
|
|
973
1434
|
loss: torch.Tensor = sum(l for l in loss_dict.values()) # type: ignore
|
|
974
1435
|
|
|
975
1436
|
running_loss += loss.item() * batch_size
|
|
1437
|
+
total_samples += batch_size # <-- Accumulate total samples
|
|
976
1438
|
|
|
977
|
-
|
|
1439
|
+
# Calculate loss using the correct denominator
|
|
1440
|
+
if total_samples == 0:
|
|
1441
|
+
_LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
|
|
1442
|
+
return {PyTorchLogKeys.VAL_LOSS: 0.0}
|
|
1443
|
+
|
|
1444
|
+
logs = {PyTorchLogKeys.VAL_LOSS: running_loss / total_samples}
|
|
978
1445
|
return logs
|
|
1446
|
+
|
|
1447
|
+
def evaluate(self,
|
|
1448
|
+
save_dir: Union[str, Path],
|
|
1449
|
+
model_checkpoint: Union[Path, Literal["latest", "current"]],
|
|
1450
|
+
test_data: Optional[Union[DataLoader, Dataset]] = None):
|
|
1451
|
+
"""
|
|
1452
|
+
Evaluates the model using object detection mAP metrics.
|
|
979
1453
|
|
|
980
|
-
|
|
1454
|
+
Args:
|
|
1455
|
+
save_dir (str | Path): Directory to save all reports and plots.
|
|
1456
|
+
model_checkpoint ('auto' | Path | None):
|
|
1457
|
+
- Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
|
|
1458
|
+
- If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
|
|
1459
|
+
- If 'current', use the current state of the trained model up the latest trained epoch.
|
|
1460
|
+
test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
|
|
1461
|
+
"""
|
|
1462
|
+
# Validate model checkpoint
|
|
1463
|
+
if isinstance(model_checkpoint, Path):
|
|
1464
|
+
checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
|
|
1465
|
+
elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
|
|
1466
|
+
checkpoint_validated = model_checkpoint
|
|
1467
|
+
else:
|
|
1468
|
+
_LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
|
|
1469
|
+
raise ValueError()
|
|
1470
|
+
|
|
1471
|
+
# Validate directory
|
|
1472
|
+
save_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
1473
|
+
|
|
1474
|
+
# Validate test data and dispatch
|
|
1475
|
+
if test_data is not None:
|
|
1476
|
+
if not isinstance(test_data, (DataLoader, Dataset)):
|
|
1477
|
+
_LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
|
|
1478
|
+
raise ValueError()
|
|
1479
|
+
test_data_validated = test_data
|
|
1480
|
+
|
|
1481
|
+
validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
|
|
1482
|
+
test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
|
|
1483
|
+
|
|
1484
|
+
# Dispatch validation set
|
|
1485
|
+
_LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
|
|
1486
|
+
self._evaluate(save_dir=validation_metrics_path,
|
|
1487
|
+
model_checkpoint=checkpoint_validated,
|
|
1488
|
+
data=None) # 'None' triggers use of self.test_dataset
|
|
1489
|
+
|
|
1490
|
+
# Dispatch test set
|
|
1491
|
+
_LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
|
|
1492
|
+
self._evaluate(save_dir=test_metrics_path,
|
|
1493
|
+
model_checkpoint="current", # Use 'current' state after loading checkpoint once
|
|
1494
|
+
data=test_data_validated)
|
|
1495
|
+
else:
|
|
1496
|
+
# Dispatch validation set
|
|
1497
|
+
_LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
|
|
1498
|
+
self._evaluate(save_dir=save_path,
|
|
1499
|
+
model_checkpoint=checkpoint_validated,
|
|
1500
|
+
data=None) # 'None' triggers use of self.test_dataset
|
|
1501
|
+
|
|
1502
|
+
def _evaluate(self,
|
|
1503
|
+
save_dir: Union[str, Path],
|
|
1504
|
+
model_checkpoint: Union[Path, Literal["latest", "current"]],
|
|
1505
|
+
data: Optional[Union[DataLoader, Dataset]]):
|
|
981
1506
|
"""
|
|
1507
|
+
Changed to a private helper method
|
|
982
1508
|
Evaluates the model using object detection mAP metrics.
|
|
983
1509
|
|
|
984
1510
|
Args:
|
|
985
1511
|
save_dir (str | Path): Directory to save all reports and plots.
|
|
986
1512
|
data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
|
|
1513
|
+
model_checkpoint ('auto' | Path | None):
|
|
1514
|
+
- Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
|
|
1515
|
+
- If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
|
|
1516
|
+
- If 'current', use the current state of the trained model up the latest trained epoch.
|
|
987
1517
|
"""
|
|
988
|
-
|
|
1518
|
+
dataset_for_artifacts = None
|
|
989
1519
|
eval_loader = None
|
|
1520
|
+
|
|
1521
|
+
# load model checkpoint
|
|
1522
|
+
if isinstance(model_checkpoint, Path):
|
|
1523
|
+
self._load_checkpoint(path=model_checkpoint)
|
|
1524
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
|
|
1525
|
+
path_to_latest = self._checkpoint_callback.best_checkpoint_path
|
|
1526
|
+
self._load_checkpoint(path_to_latest)
|
|
1527
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
|
|
1528
|
+
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
|
|
1529
|
+
raise ValueError()
|
|
990
1530
|
|
|
1531
|
+
# Dataloader
|
|
991
1532
|
if isinstance(data, DataLoader):
|
|
992
1533
|
eval_loader = data
|
|
993
1534
|
if hasattr(data, 'dataset'):
|
|
994
|
-
|
|
1535
|
+
dataset_for_artifacts = data.dataset # type: ignore
|
|
995
1536
|
elif isinstance(data, Dataset):
|
|
996
1537
|
# Create a new loader from the provided dataset
|
|
997
1538
|
eval_loader = DataLoader(data,
|
|
@@ -1000,21 +1541,21 @@ class ObjectDetectionTrainer:
|
|
|
1000
1541
|
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
1001
1542
|
pin_memory=(self.device.type == "cuda"),
|
|
1002
1543
|
collate_fn=self.collate_fn)
|
|
1003
|
-
|
|
1544
|
+
dataset_for_artifacts = data
|
|
1004
1545
|
else: # data is None, use the trainer's default test dataset
|
|
1005
|
-
if self.
|
|
1546
|
+
if self.validation_dataset is None:
|
|
1006
1547
|
_LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
|
|
1007
1548
|
raise ValueError()
|
|
1008
1549
|
# Create a fresh DataLoader from the test_dataset
|
|
1009
1550
|
eval_loader = DataLoader(
|
|
1010
|
-
self.
|
|
1551
|
+
self.validation_dataset,
|
|
1011
1552
|
batch_size=self._batch_size,
|
|
1012
1553
|
shuffle=False,
|
|
1013
1554
|
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
1014
1555
|
pin_memory=(self.device.type == "cuda"),
|
|
1015
1556
|
collate_fn=self.collate_fn
|
|
1016
1557
|
)
|
|
1017
|
-
|
|
1558
|
+
dataset_for_artifacts = self.validation_dataset
|
|
1018
1559
|
|
|
1019
1560
|
if eval_loader is None:
|
|
1020
1561
|
_LOGGER.error("Cannot evaluate. No valid data was provided or found.")
|
|
@@ -1051,11 +1592,11 @@ class ObjectDetectionTrainer:
|
|
|
1051
1592
|
class_names = None
|
|
1052
1593
|
try:
|
|
1053
1594
|
# Try to get 'classes' from ObjectDetectionDatasetMaker
|
|
1054
|
-
if hasattr(
|
|
1055
|
-
class_names =
|
|
1595
|
+
if hasattr(dataset_for_artifacts, 'classes'):
|
|
1596
|
+
class_names = dataset_for_artifacts.classes # type: ignore
|
|
1056
1597
|
# Fallback for Subset
|
|
1057
|
-
elif hasattr(
|
|
1058
|
-
class_names =
|
|
1598
|
+
elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'): # type: ignore
|
|
1599
|
+
class_names = dataset_for_artifacts.dataset.classes # type: ignore
|
|
1059
1600
|
except AttributeError:
|
|
1060
1601
|
_LOGGER.warning("Could not find 'classes' attribute on dataset. Per-class metrics will not be named.")
|
|
1061
1602
|
pass # class_names is still None
|
|
@@ -1068,36 +1609,451 @@ class ObjectDetectionTrainer:
|
|
|
1068
1609
|
class_names=class_names,
|
|
1069
1610
|
print_output=False
|
|
1070
1611
|
)
|
|
1071
|
-
|
|
1072
|
-
# print("\n--- Training History ---")
|
|
1073
|
-
plot_losses(self.history, save_dir=save_dir)
|
|
1074
1612
|
|
|
1075
|
-
def
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1613
|
+
def finalize_model_training(self,
|
|
1614
|
+
save_dir: Union[str, Path],
|
|
1615
|
+
model_checkpoint: Union[Path, Literal['latest', 'current']],
|
|
1616
|
+
finalize_config: FinalizeObjectDetection
|
|
1617
|
+
):
|
|
1618
|
+
"""
|
|
1619
|
+
Saves a finalized, "inference-ready" model state to a .pth file.
|
|
1620
|
+
|
|
1621
|
+
This method saves the model's `state_dict` and the final epoch number.
|
|
1622
|
+
|
|
1623
|
+
Args:
|
|
1624
|
+
save_dir (Union[str, Path]): The directory to save the finalized model.
|
|
1625
|
+
model_checkpoint (Union[Path, Literal["latest", "current"]]):
|
|
1626
|
+
- Path: Loads the model state from a specific checkpoint file.
|
|
1627
|
+
- "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
|
|
1628
|
+
- "current": Uses the model's state as it is at the end of the `fit()` call.
|
|
1629
|
+
finalize_config (FinalizeObjectDetection): A data class instance specific to the ML task containing task-specific metadata required for inference.
|
|
1630
|
+
"""
|
|
1631
|
+
if not isinstance(finalize_config, FinalizeObjectDetection):
|
|
1632
|
+
_LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeObjectDetection', but got {type(finalize_config).__name__}.")
|
|
1633
|
+
raise TypeError()
|
|
1634
|
+
|
|
1635
|
+
# handle save path
|
|
1636
|
+
dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
1637
|
+
full_path = dir_path / finalize_config.filename
|
|
1638
|
+
|
|
1639
|
+
# handle checkpoint
|
|
1640
|
+
self._load_model_state_for_finalizing(model_checkpoint)
|
|
1641
|
+
|
|
1642
|
+
# Create finalized data
|
|
1643
|
+
finalized_data = {
|
|
1644
|
+
PyTorchCheckpointKeys.EPOCH: self.epoch,
|
|
1645
|
+
PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
|
|
1646
|
+
}
|
|
1647
|
+
|
|
1648
|
+
if finalize_config.class_map is not None:
|
|
1649
|
+
finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map
|
|
1650
|
+
|
|
1651
|
+
torch.save(finalized_data, full_path)
|
|
1652
|
+
|
|
1653
|
+
_LOGGER.info(f"Finalized model file saved to '{full_path}'")
|
|
1654
|
+
|
|
1655
|
+
# --- DragonSequenceTrainer ----
|
|
1656
|
+
class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
1657
|
+
def __init__(self,
|
|
1658
|
+
model: nn.Module,
|
|
1659
|
+
train_dataset: Dataset,
|
|
1660
|
+
validation_dataset: Dataset,
|
|
1661
|
+
kind: Literal["sequence-to-sequence", "sequence-to-value"],
|
|
1662
|
+
optimizer: torch.optim.Optimizer,
|
|
1663
|
+
device: Union[Literal['cuda', 'mps', 'cpu'],str],
|
|
1664
|
+
checkpoint_callback: Optional[DragonModelCheckpoint],
|
|
1665
|
+
early_stopping_callback: Optional[DragonEarlyStopping],
|
|
1666
|
+
lr_scheduler_callback: Optional[DragonLRScheduler],
|
|
1667
|
+
extra_callbacks: Optional[List[_Callback]] = None,
|
|
1668
|
+
criterion: Union[nn.Module,Literal["auto"]] = "auto",
|
|
1669
|
+
dataloader_workers: int = 2):
|
|
1670
|
+
"""
|
|
1671
|
+
Automates the training process of a PyTorch Sequence Model.
|
|
1672
|
+
|
|
1673
|
+
Built-in Callbacks: `History`, `TqdmProgressBar`
|
|
1674
|
+
|
|
1675
|
+
Args:
|
|
1676
|
+
model (nn.Module): The PyTorch model to train.
|
|
1677
|
+
train_dataset (Dataset): The training dataset.
|
|
1678
|
+
validation_dataset (Dataset): The validation dataset.
|
|
1679
|
+
kind (str): Used to redirect to the correct process ('sequence-to-sequence' or 'sequence-to-value').
|
|
1680
|
+
criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task
|
|
1681
|
+
optimizer (torch.optim.Optimizer): The optimizer.
|
|
1682
|
+
device (str): The device to run training on ('cpu', 'cuda', 'mps').
|
|
1683
|
+
dataloader_workers (int): Subprocesses for data loading.
|
|
1684
|
+
extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
|
|
1685
|
+
"""
|
|
1686
|
+
# Call the base class constructor with common parameters
|
|
1687
|
+
super().__init__(
|
|
1688
|
+
model=model,
|
|
1689
|
+
optimizer=optimizer,
|
|
1690
|
+
device=device,
|
|
1691
|
+
dataloader_workers=dataloader_workers,
|
|
1692
|
+
checkpoint_callback=checkpoint_callback,
|
|
1693
|
+
early_stopping_callback=early_stopping_callback,
|
|
1694
|
+
lr_scheduler_callback=lr_scheduler_callback,
|
|
1695
|
+
extra_callbacks=extra_callbacks
|
|
1696
|
+
)
|
|
1697
|
+
|
|
1698
|
+
if kind not in [MLTaskKeys.SEQUENCE_SEQUENCE, MLTaskKeys.SEQUENCE_VALUE]:
|
|
1699
|
+
raise ValueError(f"'{kind}' is not a valid task type for DragonSequenceTrainer.")
|
|
1700
|
+
|
|
1701
|
+
self.train_dataset = train_dataset
|
|
1702
|
+
self.validation_dataset = validation_dataset
|
|
1703
|
+
self.kind = kind
|
|
1704
|
+
|
|
1705
|
+
# try to validate against Dragon Sequence model
|
|
1706
|
+
if hasattr(self.model, "prediction_mode"):
|
|
1707
|
+
key_to_check: str = self.model.prediction_mode # type: ignore
|
|
1708
|
+
if not key_to_check == self.kind:
|
|
1709
|
+
_LOGGER.error(f"Trainer was set for '{self.kind}', but model architecture '{self.model}' is built for '{key_to_check}'.")
|
|
1710
|
+
raise RuntimeError()
|
|
1711
|
+
|
|
1712
|
+
# loss function
|
|
1713
|
+
if criterion == "auto":
|
|
1714
|
+
# Both sequence tasks are treated as regression problems
|
|
1715
|
+
self.criterion = nn.MSELoss()
|
|
1716
|
+
else:
|
|
1717
|
+
self.criterion = criterion
|
|
1718
|
+
|
|
1719
|
+
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
1720
|
+
"""Initializes the DataLoaders."""
|
|
1721
|
+
# Ensure stability on MPS devices by setting num_workers to 0
|
|
1722
|
+
loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
|
|
1723
|
+
|
|
1724
|
+
self.train_loader = DataLoader(
|
|
1725
|
+
dataset=self.train_dataset,
|
|
1726
|
+
batch_size=batch_size,
|
|
1727
|
+
shuffle=shuffle,
|
|
1728
|
+
num_workers=loader_workers,
|
|
1729
|
+
pin_memory=("cuda" in self.device.type),
|
|
1730
|
+
drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
|
|
1731
|
+
)
|
|
1732
|
+
|
|
1733
|
+
self.validation_loader = DataLoader(
|
|
1734
|
+
dataset=self.validation_dataset,
|
|
1735
|
+
batch_size=batch_size,
|
|
1736
|
+
shuffle=False,
|
|
1737
|
+
num_workers=loader_workers,
|
|
1738
|
+
pin_memory=("cuda" in self.device.type)
|
|
1739
|
+
)
|
|
1740
|
+
|
|
1741
|
+
def _train_step(self):
|
|
1742
|
+
self.model.train()
|
|
1743
|
+
running_loss = 0.0
|
|
1744
|
+
total_samples = 0
|
|
1745
|
+
|
|
1746
|
+
for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
|
|
1747
|
+
# Create a log dictionary for the batch
|
|
1748
|
+
batch_logs = {
|
|
1749
|
+
PyTorchLogKeys.BATCH_INDEX: batch_idx,
|
|
1750
|
+
PyTorchLogKeys.BATCH_SIZE: features.size(0)
|
|
1751
|
+
}
|
|
1752
|
+
self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
|
|
1753
|
+
|
|
1754
|
+
features, target = features.to(self.device), target.to(self.device)
|
|
1755
|
+
self.optimizer.zero_grad()
|
|
1080
1756
|
|
|
1081
|
-
|
|
1757
|
+
output = self.model(features)
|
|
1758
|
+
|
|
1759
|
+
# --- Label Type/Shape Correction ---
|
|
1760
|
+
# Ensure target is float for MSELoss
|
|
1761
|
+
target = target.float()
|
|
1762
|
+
|
|
1763
|
+
# For seq-to-val, models might output [N, 1] but target is [N].
|
|
1764
|
+
if self.kind == MLTaskKeys.SEQUENCE_VALUE:
|
|
1765
|
+
if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
|
|
1766
|
+
output = output.squeeze(1)
|
|
1767
|
+
|
|
1768
|
+
# For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
|
|
1769
|
+
elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
|
|
1770
|
+
if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
|
|
1771
|
+
output = output.squeeze(-1)
|
|
1772
|
+
|
|
1773
|
+
loss = self.criterion(output, target)
|
|
1774
|
+
|
|
1775
|
+
loss.backward()
|
|
1776
|
+
self.optimizer.step()
|
|
1777
|
+
|
|
1778
|
+
# Calculate batch loss and update running loss for the epoch
|
|
1779
|
+
batch_loss = loss.item()
|
|
1780
|
+
batch_size = features.size(0)
|
|
1781
|
+
running_loss += batch_loss * batch_size # Accumulate total loss
|
|
1782
|
+
total_samples += batch_size # total samples
|
|
1783
|
+
|
|
1784
|
+
# Add the batch loss to the logs and call the end-of-batch hook
|
|
1785
|
+
batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
|
|
1786
|
+
self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
|
|
1787
|
+
|
|
1788
|
+
if total_samples == 0:
|
|
1789
|
+
_LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
|
|
1790
|
+
return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
|
|
1791
|
+
|
|
1792
|
+
return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples} # type: ignore
|
|
1793
|
+
|
|
1794
|
+
def _validation_step(self):
|
|
1795
|
+
self.model.eval()
|
|
1796
|
+
running_loss = 0.0
|
|
1797
|
+
|
|
1798
|
+
with torch.no_grad():
|
|
1799
|
+
for features, target in self.validation_loader: # type: ignore
|
|
1800
|
+
features, target = features.to(self.device), target.to(self.device)
|
|
1801
|
+
|
|
1802
|
+
output = self.model(features)
|
|
1803
|
+
|
|
1804
|
+
# --- Label Type/Shape Correction ---
|
|
1805
|
+
target = target.float()
|
|
1806
|
+
|
|
1807
|
+
# For seq-to-val, models might output [N, 1] but target is [N].
|
|
1808
|
+
if self.kind == MLTaskKeys.SEQUENCE_VALUE:
|
|
1809
|
+
if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
|
|
1810
|
+
output = output.squeeze(1)
|
|
1811
|
+
|
|
1812
|
+
# For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
|
|
1813
|
+
elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
|
|
1814
|
+
if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
|
|
1815
|
+
output = output.squeeze(-1)
|
|
1816
|
+
|
|
1817
|
+
loss = self.criterion(output, target)
|
|
1818
|
+
|
|
1819
|
+
running_loss += loss.item() * features.size(0)
|
|
1820
|
+
|
|
1821
|
+
if not self.validation_loader.dataset: # type: ignore
|
|
1822
|
+
_LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
|
|
1823
|
+
return {PyTorchLogKeys.VAL_LOSS: 0.0}
|
|
1824
|
+
|
|
1825
|
+
logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)} # type: ignore
|
|
1826
|
+
return logs
|
|
1827
|
+
|
|
1828
|
+
def _predict_for_eval(self, dataloader: DataLoader):
|
|
1082
1829
|
"""
|
|
1083
|
-
|
|
1830
|
+
Private method to yield model predictions batch by batch for evaluation.
|
|
1084
1831
|
|
|
1085
|
-
|
|
1832
|
+
Yields:
|
|
1833
|
+
tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
|
|
1834
|
+
y_prob_batch is always None for sequence tasks.
|
|
1086
1835
|
"""
|
|
1087
|
-
self.
|
|
1836
|
+
self.model.eval()
|
|
1088
1837
|
self.model.to(self.device)
|
|
1089
|
-
|
|
1838
|
+
|
|
1839
|
+
with torch.no_grad():
|
|
1840
|
+
for features, target in dataloader:
|
|
1841
|
+
features = features.to(self.device)
|
|
1842
|
+
output = self.model(features).cpu()
|
|
1843
|
+
|
|
1844
|
+
y_pred_batch = output.numpy()
|
|
1845
|
+
y_prob_batch = None # Not applicable for sequence regression
|
|
1846
|
+
y_true_batch = target.numpy()
|
|
1847
|
+
|
|
1848
|
+
yield y_pred_batch, y_prob_batch, y_true_batch
|
|
1849
|
+
|
|
1850
|
+
def evaluate(self,
             save_dir: Union[str, Path],
             model_checkpoint: Union[Path, Literal["latest", "current"]],
             test_data: Optional[Union[DataLoader, Dataset]] = None,
             val_format_configuration: Optional[Union[SequenceValueMetricsFormat,
                                                      SequenceSequenceMetricsFormat]] = None,
             test_format_configuration: Optional[Union[SequenceValueMetricsFormat,
                                                       SequenceSequenceMetricsFormat]] = None):
    """
    Evaluates the model on the validation set and, optionally, on a test set.

    Each evaluation is delegated to `_evaluate`, which routes to the
    sequence-to-value or sequence-to-sequence metrics according to `self.kind`.

    Args:
        save_dir (str | Path): Directory to save all reports and plots (created if missing).
            When `test_data` is given, validation and test metrics are written to the
            `DragonTrainerKeys.VALIDATION_METRICS_DIR` and `DragonTrainerKeys.TEST_METRICS_DIR`
            subdirectories. Otherwise validation metrics are written directly to `save_dir`.
        model_checkpoint (Path | 'latest' | 'current'):
            - Path to a valid checkpoint file for the model.
            - If 'latest', the checkpoint tracked by the checkpoint callback is loaded.
            - If 'current', use the current state of the trained model.
        test_data (DataLoader | Dataset | None): Optional test data.
        val_format_configuration: Optional configuration for validation metrics.
        test_format_configuration: Optional configuration for test metrics. If its type is
            invalid, a warning is logged and `val_format_configuration` (when given) or the
            default format is used instead.

    Raises:
        ValueError: If `model_checkpoint`, `val_format_configuration` or `test_data` is invalid.
    """
    # Validate model checkpoint
    if isinstance(model_checkpoint, Path):
        checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
    elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
        checkpoint_validated = model_checkpoint
    else:
        _LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.LATEST}', or '{MagicWords.CURRENT}'.")
        raise ValueError()

    valid_formats = (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)

    # Validate val configuration (an invalid type is a hard error for the validation set)
    if val_format_configuration is not None and not isinstance(val_format_configuration, valid_formats):
        _LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
        raise ValueError()

    # Validate directory
    save_path = make_fullpath(save_dir, make=True, enforce="directory")

    if test_data is None:
        # Validation only: metrics go directly into save_dir.
        _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
        self._evaluate(save_dir=save_path,
                       model_checkpoint=checkpoint_validated,
                       data=None,
                       format_configuration=val_format_configuration)
        return

    # Validate test data
    if not isinstance(test_data, (DataLoader, Dataset)):
        _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
        raise ValueError()

    validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
    test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR

    # Dispatch validation set
    _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
    self._evaluate(save_dir=validation_metrics_path,
                   model_checkpoint=checkpoint_validated,
                   data=None,
                   format_configuration=val_format_configuration)

    # Resolve the test configuration: an invalid type falls back instead of raising.
    test_configuration_validated = None
    if test_format_configuration is not None:
        if isinstance(test_format_configuration, valid_formats):
            test_configuration_validated = test_format_configuration
        else:
            warning_message_type = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
            if val_format_configuration is not None:
                warning_message_type += " 'val_format_configuration' will be used."
                test_configuration_validated = val_format_configuration
            else:
                warning_message_type += " Using default format."
            _LOGGER.warning(warning_message_type)

    # Dispatch test set. The validation pass above already put the selected
    # model state in place, so "current" is used here.
    _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
    self._evaluate(save_dir=test_metrics_path,
                   model_checkpoint="current",
                   data=test_data,
                   format_configuration=test_configuration_validated)
|
|
1933
|
+
|
|
1934
|
+
def _evaluate(self,
              save_dir: Union[str, Path],
              model_checkpoint: Union[Path, Literal["latest", "current"]],
              data: Optional[Union[DataLoader, Dataset]],
              format_configuration: object):
    """
    Private evaluation helper.

    Loads the requested model state, runs inference over the evaluation data and
    writes task-specific metrics via `sequence_to_value_metrics` or
    `sequence_to_sequence_metrics`, depending on `self.kind`.

    Args:
        save_dir: Directory handed to the metrics function for its reports and plots.
        model_checkpoint: One of:
            - a checkpoint Path to load;
            - 'latest', which loads the checkpoint callback's `best_checkpoint_path`;
            - 'current', which keeps the in-memory model state.
        data: One of:
            - a DataLoader, used as-is;
            - a Dataset, wrapped in a new non-shuffled DataLoader;
            - None, to use `self.validation_dataset`.
        format_configuration: Optional metrics-format object. It is used only when it
            matches the task type. Otherwise a warning is logged and the default format is used.

    Raises:
        ValueError: If 'latest' is requested without a checkpoint callback, or if no
            evaluation data is available.
    """
    # Load the requested model state ('current' leaves the model untouched).
    if isinstance(model_checkpoint, Path):
        self._load_checkpoint(path=model_checkpoint)
    elif model_checkpoint == MagicWords.LATEST:
        if self._checkpoint_callback is None:
            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
            raise ValueError()
        path_to_latest = self._checkpoint_callback.best_checkpoint_path
        self._load_checkpoint(path_to_latest)

    # Build the evaluation loader.
    if isinstance(data, DataLoader):
        eval_loader = data
    else:
        if isinstance(data, Dataset):
            # Create a new loader from the provided dataset
            dataset = data
        else:  # data is None, use the trainer's default validation dataset
            if self.validation_dataset is None:
                _LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
                raise ValueError()
            dataset = self.validation_dataset
        eval_loader = DataLoader(dataset,
                                 batch_size=self._batch_size,
                                 shuffle=False,
                                 # MPS does not work well with worker processes; load in the main process.
                                 num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                 pin_memory=(self.device.type == "cuda"))

    # Collect predictions and targets (probabilities are unused for sequence tasks).
    all_preds, all_true = [], []
    for y_pred_b, _, y_true_b in self._predict_for_eval(eval_loader):
        if y_pred_b is not None:
            all_preds.append(y_pred_b)
        if y_true_b is not None:
            all_true.append(y_true_b)

    # np.concatenate raises on an empty list, so both collections must be non-empty.
    if not all_true or not all_preds:
        _LOGGER.error("Evaluation failed: No data was processed.")
        return

    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_true)

    # --- Routing Logic ---
    if self.kind == MLTaskKeys.SEQUENCE_VALUE:
        config = None
        if format_configuration and isinstance(format_configuration, SequenceValueMetricsFormat):
            config = format_configuration
        elif format_configuration:
            _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceValueMetricsFormat.")

        sequence_to_value_metrics(y_true=y_true,
                                  y_pred=y_pred,
                                  save_dir=save_dir,
                                  config=config)

    elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
        config = None
        if format_configuration and isinstance(format_configuration, SequenceSequenceMetricsFormat):
            config = format_configuration
        elif format_configuration:
            _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceSequenceMetricsFormat.")

        sequence_to_sequence_metrics(y_true=y_true,
                                     y_pred=y_pred,
                                     save_dir=save_dir,
                                     config=config)

    else:
        # Previously an unknown task kind produced no metrics without any notice.
        _LOGGER.error(f"Evaluation is not supported for task '{self.kind}'. No metrics were saved.")
|
|
2014
|
+
|
|
2015
|
+
def finalize_model_training(self,
                            save_dir: Union[str, Path],
                            model_checkpoint: Union[Path, Literal['latest', 'current']],
                            finalize_config: FinalizeSequencePrediction):
    """
    Write an inference-ready snapshot of the model to a `.pth` file.

    The saved dictionary always holds the final epoch number and the model's
    `state_dict`. The sequence length and initial sequence from `finalize_config`
    are added only when they are not None.

    Args:
        save_dir (Union[str, Path]): Target directory; created if it does not exist.
        model_checkpoint (Union[Path, Literal["latest", "current"]]):
            - Path: Loads the model state from a specific checkpoint file.
            - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
            - "current": Uses the model's state as it is at the end of the `fit()` call.
        finalize_config (FinalizeSequencePrediction): Task-specific metadata for inference;
            its `filename` names the output file inside `save_dir`.

    Raises:
        TypeError: If `finalize_config` is not a `FinalizeSequencePrediction`.
    """
    if not isinstance(finalize_config, FinalizeSequencePrediction):
        _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeSequencePrediction', but got {type(finalize_config).__name__}.")
        raise TypeError()

    # Resolve (and create) the output directory before touching the model.
    output_file = make_fullpath(save_dir, make=True, enforce="directory") / finalize_config.filename

    # Put the model into the requested state before reading its weights.
    self._load_model_state_for_finalizing(model_checkpoint)

    payload = {
        PyTorchCheckpointKeys.EPOCH: self.epoch,
        PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
    }

    # Optional sequence metadata: keep only the entries that were provided.
    optional_fields = (
        (PyTorchCheckpointKeys.SEQUENCE_LENGTH, finalize_config.sequence_length),
        (PyTorchCheckpointKeys.INITIAL_SEQUENCE, finalize_config.initial_sequence),
    )
    payload.update((key, value) for key, value in optional_fields if value is not None)

    torch.save(payload, output_file)

    _LOGGER.info(f"Finalized model file saved to '{output_file}'")
|
|
1101
2057
|
|
|
1102
2058
|
|
|
1103
2059
|
def info():
|