dragon-ml-toolbox 14.3.1__py3-none-any.whl → 16.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +10 -5
- dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +309 -0
- ml_tools/ML_datasetmaster.py +220 -260
- ml_tools/ML_evaluation.py +317 -81
- ml_tools/ML_evaluation_multi.py +127 -36
- ml_tools/ML_inference.py +249 -207
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +215 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1247 -338
- ml_tools/ML_utilities.py +51 -2
- ml_tools/ML_vision_datasetmaster.py +262 -118
- ml_tools/ML_vision_evaluation.py +26 -6
- ml_tools/ML_vision_inference.py +117 -140
- ml_tools/ML_vision_models.py +15 -1
- ml_tools/ML_vision_transformers.py +233 -7
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +45 -1
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +54 -11
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.3.1.dist-info/RECORD +0 -48
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
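Of note in the list above: ml_tools/keys.py is renamed to the private module ml_tools/_keys.py, so any downstream code importing from it directly must update the import path. A minimal sketch of the change (the import shown is hypothetical downstream usage; the exported names are taken from the ML_trainer.py diff below and should be confirmed against the release):

    # Before (14.3.1) - hypothetical downstream import
    from ml_tools.keys import PyTorchCheckpointKeys

    # After (16.0.0) - the module is now private
    from ml_tools._keys import PyTorchCheckpointKeys, MLTaskKeys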
ml_tools/ML_trainer.py
CHANGED
|
@@ -1,79 +1,79 @@
|
|
|
1
|
-
from typing import List, Literal, Union, Optional, Callable, Dict, Any
|
|
1
|
+
from typing import List, Literal, Union, Optional, Callable, Dict, Any
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
from torch.utils.data import DataLoader, Dataset
|
|
4
4
|
import torch
|
|
5
5
|
from torch import nn
|
|
6
6
|
import numpy as np
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
7
8
|
|
|
8
|
-
from .
|
|
9
|
+
from .path_manager import make_fullpath, sanitize_filename
|
|
10
|
+
from .ML_callbacks import _Callback, History, TqdmProgressBar, DragonModelCheckpoint, DragonEarlyStopping, DragonLRScheduler
|
|
9
11
|
from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
|
|
10
12
|
from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
|
|
13
|
+
from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
|
|
14
|
+
from .ML_sequence_evaluation import sequence_to_sequence_metrics, sequence_to_value_metrics
|
|
15
|
+
from .ML_configuration import (ClassificationMetricsFormat,
|
|
16
|
+
MultiClassificationMetricsFormat,
|
|
17
|
+
RegressionMetricsFormat,
|
|
18
|
+
SegmentationMetricsFormat,
|
|
19
|
+
SequenceValueMetricsFormat,
|
|
20
|
+
SequenceSequenceMetricsFormat)
|
|
21
|
+
|
|
11
22
|
from ._script_info import _script_info
|
|
12
|
-
from .
|
|
23
|
+
from ._keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys, MLTaskKeys, MagicWords, DragonTrainerKeys
|
|
13
24
|
from ._logger import _LOGGER
|
|
14
|
-
from .path_manager import make_fullpath
|
|
15
|
-
from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
|
|
16
25
|
|
|
17
26
|
|
|
18
27
|
__all__ = [
|
|
19
|
-
"
|
|
20
|
-
"
|
|
28
|
+
"DragonTrainer",
|
|
29
|
+
"DragonDetectionTrainer",
|
|
30
|
+
"DragonSequenceTrainer"
|
|
21
31
|
]
|
|
22
32
|
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification", "segmentation"],
|
|
27
|
-
criterion: nn.Module, optimizer: torch.optim.Optimizer,
|
|
28
|
-
device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
|
|
29
|
-
"""
|
|
30
|
-
Automates the training process of a PyTorch Model.
|
|
31
|
-
|
|
32
|
-
Built-in Callbacks: `History`, `TqdmProgressBar`
|
|
33
|
-
|
|
34
|
-
Args:
|
|
35
|
-
model (nn.Module): The PyTorch model to train.
|
|
36
|
-
train_dataset (Dataset): The training dataset.
|
|
37
|
-
test_dataset (Dataset): The testing/validation dataset.
|
|
38
|
-
kind (str): Can be 'regression', 'classification', 'multi_target_regression', 'multi_label_classification', or 'segmentation'.
|
|
39
|
-
criterion (nn.Module): The loss function.
|
|
40
|
-
optimizer (torch.optim.Optimizer): The optimizer.
|
|
41
|
-
device (str): The device to run training on ('cpu', 'cuda', 'mps').
|
|
42
|
-
dataloader_workers (int): Subprocesses for data loading.
|
|
43
|
-
callbacks (List[Callback] | None): A list of callbacks to use during training.
|
|
44
|
-
|
|
45
|
-
Note:
|
|
46
|
-
- For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`.
|
|
47
|
-
|
|
48
|
-
- For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice.
|
|
33
|
+
class _BaseDragonTrainer(ABC):
|
|
34
|
+
"""
|
|
35
|
+
Abstract base class for Dragon Trainers.
|
|
49
36
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
37
|
+
Handles the common training loop orchestration, checkpointing, callback
|
|
38
|
+
management, and device handling. Subclasses must implement the
|
|
39
|
+
task-specific logic (dataloaders, train/val steps, evaluation).
|
|
40
|
+
"""
|
|
41
|
+
def __init__(self,
|
|
42
|
+
model: nn.Module,
|
|
43
|
+
optimizer: torch.optim.Optimizer,
|
|
44
|
+
device: Union[Literal['cuda', 'mps', 'cpu'],str],
|
|
45
|
+
dataloader_workers: int = 2,
|
|
46
|
+
checkpoint_callback: Optional[DragonModelCheckpoint] = None,
|
|
47
|
+
early_stopping_callback: Optional[DragonEarlyStopping] = None,
|
|
48
|
+
lr_scheduler_callback: Optional[DragonLRScheduler] = None,
|
|
49
|
+
extra_callbacks: Optional[List[_Callback]] = None):
|
|
56
50
|
|
|
57
51
|
self.model = model
|
|
58
|
-
self.train_dataset = train_dataset
|
|
59
|
-
self.test_dataset = test_dataset
|
|
60
|
-
self.kind = kind
|
|
61
|
-
self.criterion = criterion
|
|
62
52
|
self.optimizer = optimizer
|
|
63
53
|
self.scheduler = None
|
|
64
54
|
self.device = self._validate_device(device)
|
|
65
55
|
self.dataloader_workers = dataloader_workers
|
|
66
56
|
|
|
67
|
-
# Callback handler
|
|
57
|
+
# Callback handler
|
|
68
58
|
default_callbacks = [History(), TqdmProgressBar()]
|
|
69
|
-
|
|
59
|
+
|
|
60
|
+
self._checkpoint_callback = None
|
|
61
|
+
if checkpoint_callback:
|
|
62
|
+
default_callbacks.append(checkpoint_callback)
|
|
63
|
+
self._checkpoint_callback = checkpoint_callback
|
|
64
|
+
if early_stopping_callback:
|
|
65
|
+
default_callbacks.append(early_stopping_callback)
|
|
66
|
+
if lr_scheduler_callback:
|
|
67
|
+
default_callbacks.append(lr_scheduler_callback)
|
|
68
|
+
|
|
69
|
+
user_callbacks = extra_callbacks if extra_callbacks is not None else []
|
|
70
70
|
self.callbacks = default_callbacks + user_callbacks
|
|
71
71
|
self._set_trainer_on_callbacks()
|
|
72
72
|
|
|
73
73
|
# Internal state
|
|
74
|
-
self.train_loader = None
|
|
75
|
-
self.
|
|
76
|
-
self.history = {}
|
|
74
|
+
self.train_loader: Optional[DataLoader] = None
|
|
75
|
+
self.validation_loader: Optional[DataLoader] = None
|
|
76
|
+
self.history: Dict[str, List[Any]] = {}
|
|
77
77
|
self.epoch = 0
|
|
78
78
|
self.epochs = 0 # Total epochs for the fit run
|
|
79
79
|
self.start_epoch = 1
|
|
@@ -96,32 +96,10 @@ class MLTrainer:
|
|
|
96
96
|
for callback in self.callbacks:
|
|
97
97
|
callback.set_trainer(self)
|
|
98
98
|
|
|
99
|
-
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
100
|
-
"""Initializes the DataLoaders."""
|
|
101
|
-
# Ensure stability on MPS devices by setting num_workers to 0
|
|
102
|
-
loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
|
|
103
|
-
|
|
104
|
-
self.train_loader = DataLoader(
|
|
105
|
-
dataset=self.train_dataset,
|
|
106
|
-
batch_size=batch_size,
|
|
107
|
-
shuffle=shuffle,
|
|
108
|
-
num_workers=loader_workers,
|
|
109
|
-
pin_memory=("cuda" in self.device.type),
|
|
110
|
-
drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
|
|
111
|
-
)
|
|
112
|
-
|
|
113
|
-
self.test_loader = DataLoader(
|
|
114
|
-
dataset=self.test_dataset,
|
|
115
|
-
batch_size=batch_size,
|
|
116
|
-
shuffle=False,
|
|
117
|
-
num_workers=loader_workers,
|
|
118
|
-
pin_memory=("cuda" in self.device.type)
|
|
119
|
-
)
|
|
120
|
-
|
|
121
99
|
def _load_checkpoint(self, path: Union[str, Path]):
|
|
122
100
|
"""Loads a training checkpoint to resume training."""
|
|
123
101
|
p = make_fullpath(path, enforce="file")
|
|
124
|
-
_LOGGER.info(f"Loading checkpoint from '{p.name}'
|
|
102
|
+
_LOGGER.info(f"Loading checkpoint from '{p.name}'...")
|
|
125
103
|
|
|
126
104
|
try:
|
|
127
105
|
checkpoint = torch.load(p, map_location=self.device)
|
|
@@ -132,7 +110,16 @@ class MLTrainer:
|
|
|
132
110
|
|
|
133
111
|
self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
|
|
134
112
|
self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
|
|
135
|
-
self.
|
|
113
|
+
self.epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0)
|
|
114
|
+
self.start_epoch = self.epoch + 1 # Resume on the *next* epoch
|
|
115
|
+
|
|
116
|
+
# --- Load History ---
|
|
117
|
+
if PyTorchCheckpointKeys.HISTORY in checkpoint:
|
|
118
|
+
self.history = checkpoint[PyTorchCheckpointKeys.HISTORY]
|
|
119
|
+
_LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
|
|
120
|
+
else:
|
|
121
|
+
_LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
|
|
122
|
+
self.history = {} # Ensure it's at least an empty dict
|
|
136
123
|
|
|
137
124
|
# --- Scheduler State Loading Logic ---
|
|
138
125
|
scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
|
|
@@ -162,7 +149,7 @@ class MLTrainer:
|
|
|
162
149
|
|
|
163
150
|
# Restore callback states
|
|
164
151
|
for cb in self.callbacks:
|
|
165
|
-
if isinstance(cb,
|
|
152
|
+
if isinstance(cb, DragonModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
|
|
166
153
|
cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
|
|
167
154
|
_LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
|
|
168
155
|
|
|
@@ -173,7 +160,8 @@ class MLTrainer:
|
|
|
173
160
|
raise
|
|
174
161
|
|
|
175
162
|
def fit(self,
|
|
176
|
-
|
|
163
|
+
save_dir: Union[str,Path],
|
|
164
|
+
epochs: int = 100,
|
|
177
165
|
batch_size: int = 10,
|
|
178
166
|
shuffle: bool = True,
|
|
179
167
|
resume_from_checkpoint: Optional[Union[str, Path]] = None):
|
|
@@ -183,21 +171,15 @@ class MLTrainer:
|
|
|
183
171
|
Returns the "History" callback dictionary.
|
|
184
172
|
|
|
185
173
|
Args:
|
|
174
|
+
save_dir (str | Path): Directory to save the loss plot.
|
|
186
175
|
epochs (int): The total number of epochs to train for.
|
|
187
176
|
batch_size (int): The number of samples per batch.
|
|
188
177
|
shuffle (bool): Whether to shuffle the training data at each epoch.
|
|
189
178
|
resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
|
|
190
|
-
|
|
191
|
-
Note:
|
|
192
|
-
For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
|
|
193
|
-
automatically aligns the model's output tensor with the target tensor's
|
|
194
|
-
shape using `output.view_as(target)`. This handles the common case
|
|
195
|
-
where a model outputs a shape of `[batch_size, 1]` and the target has a
|
|
196
|
-
shape of `[batch_size]`.
|
|
197
179
|
"""
|
|
198
180
|
self.epochs = epochs
|
|
199
181
|
self._batch_size = batch_size
|
|
200
|
-
self._create_dataloaders(self._batch_size, shuffle)
|
|
182
|
+
self._create_dataloaders(self._batch_size, shuffle) # type: ignore
|
|
201
183
|
self.model.to(self.device)
|
|
202
184
|
|
|
203
185
|
if resume_from_checkpoint:
|
|
@@ -208,11 +190,19 @@ class MLTrainer:
|
|
|
208
190
|
|
|
209
191
|
self._callbacks_hook('on_train_begin')
|
|
210
192
|
|
|
193
|
+
if not self.train_loader:
|
|
194
|
+
_LOGGER.error("Train loader is not initialized.")
|
|
195
|
+
raise ValueError()
|
|
196
|
+
|
|
197
|
+
if not self.validation_loader:
|
|
198
|
+
_LOGGER.error("Validation loader is not initialized.")
|
|
199
|
+
raise ValueError()
|
|
200
|
+
|
|
211
201
|
for epoch in range(self.start_epoch, self.epochs + 1):
|
|
212
202
|
self.epoch = epoch
|
|
213
|
-
epoch_logs = {}
|
|
203
|
+
epoch_logs: Dict[str, Any] = {}
|
|
214
204
|
self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
|
|
215
|
-
|
|
205
|
+
|
|
216
206
|
train_logs = self._train_step()
|
|
217
207
|
epoch_logs.update(train_logs)
|
|
218
208
|
|
|
@@ -226,11 +216,185 @@ class MLTrainer:
|
|
|
226
216
|
break
|
|
227
217
|
|
|
228
218
|
self._callbacks_hook('on_train_end')
|
|
219
|
+
|
|
220
|
+
# Training History
|
|
221
|
+
plot_losses(self.history, save_dir=save_dir)
|
|
222
|
+
|
|
229
223
|
return self.history
|
|
224
|
+
|
|
225
|
+
def _callbacks_hook(self, method_name: str, *args, **kwargs):
|
|
226
|
+
"""Calls the specified method on all callbacks."""
|
|
227
|
+
for callback in self.callbacks:
|
|
228
|
+
method = getattr(callback, method_name)
|
|
229
|
+
method(*args, **kwargs)
|
|
230
|
+
|
|
231
|
+
def to_cpu(self):
|
|
232
|
+
"""
|
|
233
|
+
Moves the model to the CPU and updates the trainer's device setting.
|
|
234
|
+
|
|
235
|
+
This is useful for running operations that require the CPU.
|
|
236
|
+
"""
|
|
237
|
+
self.device = torch.device('cpu')
|
|
238
|
+
self.model.to(self.device)
|
|
239
|
+
_LOGGER.info("Trainer and model moved to CPU.")
|
|
240
|
+
|
|
241
|
+
def to_device(self, device: str):
|
|
242
|
+
"""
|
|
243
|
+
Moves the model to the specified device and updates the trainer's device setting.
|
|
244
|
+
|
|
245
|
+
Args:
|
|
246
|
+
device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
|
|
247
|
+
"""
|
|
248
|
+
self.device = self._validate_device(device)
|
|
249
|
+
self.model.to(self.device)
|
|
250
|
+
_LOGGER.info(f"Trainer and model moved to {self.device}.")
|
|
251
|
+
|
|
252
|
+
# --- Abstract Methods ---
|
|
253
|
+
# These must be implemented by subclasses
|
|
254
|
+
|
|
255
|
+
@abstractmethod
|
|
256
|
+
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
257
|
+
"""Initializes the DataLoaders."""
|
|
258
|
+
raise NotImplementedError
|
|
259
|
+
|
|
260
|
+
@abstractmethod
|
|
261
|
+
def _train_step(self) -> Dict[str, float]:
|
|
262
|
+
"""Runs a single training epoch."""
|
|
263
|
+
raise NotImplementedError
|
|
264
|
+
|
|
265
|
+
@abstractmethod
|
|
266
|
+
def _validation_step(self) -> Dict[str, float]:
|
|
267
|
+
"""Runs a single validation epoch."""
|
|
268
|
+
raise NotImplementedError
|
|
269
|
+
|
|
270
|
+
@abstractmethod
|
|
271
|
+
def evaluate(self, *args, **kwargs):
|
|
272
|
+
"""Runs the full model evaluation."""
|
|
273
|
+
raise NotImplementedError
|
|
274
|
+
|
|
275
|
+
@abstractmethod
|
|
276
|
+
def _evaluate(self, *args, **kwargs):
|
|
277
|
+
"""Internal evaluation helper."""
|
|
278
|
+
raise NotImplementedError
|
|
279
|
+
|
|
280
|
+
@abstractmethod
|
|
281
|
+
def finalize_model_training(self, *args, **kwargs):
|
|
282
|
+
"""Saves the finalized model for inference."""
|
|
283
|
+
raise NotImplementedError
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
# --- DragonTrainer ----
|
|
287
|
+
class DragonTrainer(_BaseDragonTrainer):
|
|
288
|
+
def __init__(self,
|
|
289
|
+
model: nn.Module,
|
|
290
|
+
train_dataset: Dataset,
|
|
291
|
+
validation_dataset: Dataset,
|
|
292
|
+
kind: Literal["regression", "binary classification", "multiclass classification",
|
|
293
|
+
"multitarget regression", "multilabel binary classification",
|
|
294
|
+
"binary segmentation", "multiclass segmentation", "binary image classification", "multiclass image classification"],
|
|
295
|
+
optimizer: torch.optim.Optimizer,
|
|
296
|
+
device: Union[Literal['cuda', 'mps', 'cpu'],str],
|
|
297
|
+
checkpoint_callback: Optional[DragonModelCheckpoint],
|
|
298
|
+
early_stopping_callback: Optional[DragonEarlyStopping],
|
|
299
|
+
lr_scheduler_callback: Optional[DragonLRScheduler],
|
|
300
|
+
extra_callbacks: Optional[List[_Callback]] = None,
|
|
301
|
+
criterion: Union[nn.Module,Literal["auto"]] = "auto",
|
|
302
|
+
dataloader_workers: int = 2):
|
|
303
|
+
"""
|
|
304
|
+
Automates the training process of a PyTorch Model.
|
|
305
|
+
|
|
306
|
+
Built-in Callbacks: `History`, `TqdmProgressBar`
|
|
307
|
+
|
|
308
|
+
Args:
|
|
309
|
+
model (nn.Module): The PyTorch model to train.
|
|
310
|
+
train_dataset (Dataset): The training dataset.
|
|
311
|
+
validation_dataset (Dataset): The validation dataset.
|
|
312
|
+
kind (str): Used to redirect to the correct process.
|
|
313
|
+
criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task
|
|
314
|
+
optimizer (torch.optim.Optimizer): The optimizer.
|
|
315
|
+
device (str): The device to run training on ('cpu', 'cuda', 'mps').
|
|
316
|
+
dataloader_workers (int): Subprocesses for data loading.
|
|
317
|
+
extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
|
|
318
|
+
|
|
319
|
+
Note:
|
|
320
|
+
- For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`. The model should output as many logits as existing targets.
|
|
321
|
+
|
|
322
|
+
- For **single-label, binary classification**, `nn.BCEWithLogitsLoss` is the standard choice. The model should output a single logit.
|
|
230
323
|
|
|
324
|
+
- For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice. The model should output as many logits as existing classes.
|
|
325
|
+
|
|
326
|
+
- For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem. The model should output 1 logit per binary target.
|
|
327
|
+
|
|
328
|
+
- For **binary segmentation** tasks, `nn.BCEWithLogitsLoss` is common. The model should output a single logit.
|
|
329
|
+
|
|
330
|
+
- for **multiclass segmentation** tasks, `nn.CrossEntropyLoss` is the standard. The model should output as many logits as existing classes.
|
|
331
|
+
"""
|
|
332
|
+
# Call the base class constructor with common parameters
|
|
333
|
+
super().__init__(
|
|
334
|
+
model=model,
|
|
335
|
+
optimizer=optimizer,
|
|
336
|
+
device=device,
|
|
337
|
+
dataloader_workers=dataloader_workers,
|
|
338
|
+
checkpoint_callback=checkpoint_callback,
|
|
339
|
+
early_stopping_callback=early_stopping_callback,
|
|
340
|
+
lr_scheduler_callback=lr_scheduler_callback,
|
|
341
|
+
extra_callbacks=extra_callbacks
|
|
342
|
+
)
|
|
343
|
+
|
|
344
|
+
if kind not in [MLTaskKeys.REGRESSION,
|
|
345
|
+
MLTaskKeys.BINARY_CLASSIFICATION,
|
|
346
|
+
MLTaskKeys.MULTICLASS_CLASSIFICATION,
|
|
347
|
+
MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION,
|
|
348
|
+
MLTaskKeys.MULTITARGET_REGRESSION,
|
|
349
|
+
MLTaskKeys.BINARY_SEGMENTATION,
|
|
350
|
+
MLTaskKeys.MULTICLASS_SEGMENTATION,
|
|
351
|
+
MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
|
|
352
|
+
MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
|
|
353
|
+
raise ValueError(f"'{kind}' is not a valid task type.")
|
|
354
|
+
|
|
355
|
+
self.train_dataset = train_dataset
|
|
356
|
+
self.validation_dataset = validation_dataset
|
|
357
|
+
self.kind = kind
|
|
358
|
+
self._classification_threshold: float = 0.5
|
|
359
|
+
|
|
360
|
+
# loss function
|
|
361
|
+
if criterion == "auto":
|
|
362
|
+
if kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
|
|
363
|
+
self.criterion = nn.MSELoss()
|
|
364
|
+
elif kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
|
|
365
|
+
self.criterion = nn.BCEWithLogitsLoss()
|
|
366
|
+
elif kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
|
|
367
|
+
self.criterion = nn.CrossEntropyLoss()
|
|
368
|
+
else:
|
|
369
|
+
self.criterion = criterion
|
|
370
|
+
|
|
371
|
+
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
372
|
+
"""Initializes the DataLoaders."""
|
|
373
|
+
# Ensure stability on MPS devices by setting num_workers to 0
|
|
374
|
+
loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
|
|
375
|
+
|
|
376
|
+
self.train_loader = DataLoader(
|
|
377
|
+
dataset=self.train_dataset,
|
|
378
|
+
batch_size=batch_size,
|
|
379
|
+
shuffle=shuffle,
|
|
380
|
+
num_workers=loader_workers,
|
|
381
|
+
pin_memory=("cuda" in self.device.type),
|
|
382
|
+
drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
|
|
383
|
+
)
|
|
384
|
+
|
|
385
|
+
self.validation_loader = DataLoader(
|
|
386
|
+
dataset=self.validation_dataset,
|
|
387
|
+
batch_size=batch_size,
|
|
388
|
+
shuffle=False,
|
|
389
|
+
num_workers=loader_workers,
|
|
390
|
+
pin_memory=("cuda" in self.device.type)
|
|
391
|
+
)
|
|
392
|
+
|
|
231
393
|
def _train_step(self):
|
|
232
394
|
self.model.train()
|
|
233
395
|
running_loss = 0.0
|
|
396
|
+
total_samples = 0
|
|
397
|
+
|
|
234
398
|
for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
|
|
235
399
|
# Create a log dictionary for the batch
|
|
236
400
|
batch_logs = {
|
|
@@ -244,9 +408,21 @@ class MLTrainer:
|
|
|
244
408
|
|
|
245
409
|
output = self.model(features)
|
|
246
410
|
|
|
247
|
-
#
|
|
248
|
-
|
|
249
|
-
|
|
411
|
+
# --- Label Type/Shape Correction ---
|
|
412
|
+
# Cast target to float for BCE-based losses
|
|
413
|
+
if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
|
|
414
|
+
target = target.float()
|
|
415
|
+
|
|
416
|
+
# Reshape output to match target for single-logit tasks
|
|
417
|
+
if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
|
|
418
|
+
# If model outputs [N, 1] and target is [N], squeeze output
|
|
419
|
+
if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
|
|
420
|
+
output = output.squeeze(1)
|
|
421
|
+
|
|
422
|
+
if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
|
|
423
|
+
# If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
|
|
424
|
+
if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
|
|
425
|
+
output = output.squeeze(1)
|
|
250
426
|
|
|
251
427
|
loss = self.criterion(output, target)
|
|
252
428
|
|
|
@@ -255,34 +431,58 @@ class MLTrainer:
|
|
|
255
431
|
|
|
256
432
|
# Calculate batch loss and update running loss for the epoch
|
|
257
433
|
batch_loss = loss.item()
|
|
258
|
-
|
|
434
|
+
batch_size = features.size(0)
|
|
435
|
+
running_loss += batch_loss * batch_size # Accumulate total loss
|
|
436
|
+
total_samples += batch_size # total samples
|
|
259
437
|
|
|
260
438
|
# Add the batch loss to the logs and call the end-of-batch hook
|
|
261
439
|
batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
|
|
262
440
|
self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
|
|
441
|
+
|
|
442
|
+
if total_samples == 0:
|
|
443
|
+
_LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
|
|
444
|
+
return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
|
|
263
445
|
|
|
264
|
-
return {PyTorchLogKeys.TRAIN_LOSS: running_loss /
|
|
446
|
+
return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples} # type: ignore
|
|
265
447
|
|
|
266
448
|
def _validation_step(self):
|
|
267
449
|
self.model.eval()
|
|
268
450
|
running_loss = 0.0
|
|
451
|
+
|
|
269
452
|
with torch.no_grad():
|
|
270
|
-
for features, target in self.
|
|
453
|
+
for features, target in self.validation_loader: # type: ignore
|
|
271
454
|
features, target = features.to(self.device), target.to(self.device)
|
|
272
455
|
|
|
273
456
|
output = self.model(features)
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
457
|
+
|
|
458
|
+
# --- Label Type/Shape Correction ---
|
|
459
|
+
# Cast target to float for BCE-based losses
|
|
460
|
+
if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
|
|
461
|
+
target = target.float()
|
|
462
|
+
|
|
463
|
+
# Reshape output to match target for single-logit tasks
|
|
464
|
+
if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
|
|
465
|
+
# If model outputs [N, 1] and target is [N], squeeze output
|
|
466
|
+
if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
|
|
467
|
+
output = output.squeeze(1)
|
|
468
|
+
|
|
469
|
+
if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
|
|
470
|
+
# If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
|
|
471
|
+
if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
|
|
472
|
+
output = output.squeeze(1)
|
|
277
473
|
|
|
278
474
|
loss = self.criterion(output, target)
|
|
279
475
|
|
|
280
476
|
running_loss += loss.item() * features.size(0)
|
|
477
|
+
|
|
478
|
+
if not self.validation_loader.dataset: # type: ignore
|
|
479
|
+
_LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
|
|
480
|
+
return {PyTorchLogKeys.VAL_LOSS: 0.0}
|
|
281
481
|
|
|
282
|
-
logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.
|
|
482
|
+
logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)} # type: ignore
|
|
283
483
|
return logs
|
|
284
484
|
|
|
285
|
-
def _predict_for_eval(self, dataloader: DataLoader
|
|
485
|
+
def _predict_for_eval(self, dataloader: DataLoader):
|
|
286
486
|
"""
|
|
287
487
|
Private method to yield model predictions batch by batch for evaluation.
|
|
288
488
|
|
|
@@ -293,6 +493,7 @@ class MLTrainer:
|
|
|
293
493
|
"""
|
|
294
494
|
self.model.eval()
|
|
295
495
|
self.model.to(self.device)
|
|
496
|
+
|
|
296
497
|
with torch.no_grad():
|
|
297
498
|
for features, target in dataloader:
|
|
298
499
|
features = features.to(self.device)
|
|
@@ -302,25 +503,64 @@ class MLTrainer:
|
|
|
302
503
|
y_prob_batch = None
|
|
303
504
|
y_true_batch = None
|
|
304
505
|
|
|
305
|
-
if self.kind in [
|
|
506
|
+
if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
|
|
306
507
|
y_pred_batch = output.numpy()
|
|
307
508
|
y_true_batch = target.numpy()
|
|
509
|
+
|
|
510
|
+
elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
|
|
511
|
+
# Assumes model output is [N, 1] (a single logit)
|
|
512
|
+
# Squeeze output from [N, 1] to [N] if necessary
|
|
513
|
+
if output.ndim == 2 and output.shape[1] == 1:
|
|
514
|
+
output = output.squeeze(1)
|
|
515
|
+
|
|
516
|
+
probs_pos = torch.sigmoid(output) # Probability of positive class
|
|
517
|
+
preds = (probs_pos >= self._classification_threshold).int()
|
|
518
|
+
y_pred_batch = preds.numpy()
|
|
519
|
+
# For metrics (like ROC AUC), we often need probs for *both* classes
|
|
520
|
+
# Create an [N, 2] array: [prob_class_0, prob_class_1]
|
|
521
|
+
probs_neg = 1.0 - probs_pos
|
|
522
|
+
y_prob_batch = torch.stack([probs_neg, probs_pos], dim=1).numpy()
|
|
523
|
+
y_true_batch = target.numpy()
|
|
308
524
|
|
|
309
|
-
elif self.kind
|
|
525
|
+
elif self.kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
|
|
526
|
+
num_classes = output.shape[1]
|
|
527
|
+
if num_classes < 3:
|
|
528
|
+
# Optional: warn the user they are using the wrong kind
|
|
529
|
+
wrong_class = MLTaskKeys.MULTICLASS_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
|
|
530
|
+
recommended_class = MLTaskKeys.BINARY_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.BINARY_IMAGE_CLASSIFICATION
|
|
531
|
+
_LOGGER.warning(f"'{wrong_class}' kind used with {num_classes} classes. Consider using '{recommended_class}' instead.")
|
|
532
|
+
|
|
310
533
|
probs = torch.softmax(output, dim=1)
|
|
311
534
|
preds = torch.argmax(probs, dim=1)
|
|
312
535
|
y_pred_batch = preds.numpy()
|
|
313
536
|
y_prob_batch = probs.numpy()
|
|
314
537
|
y_true_batch = target.numpy()
|
|
315
538
|
|
|
316
|
-
elif self.kind ==
|
|
539
|
+
elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
|
|
317
540
|
probs = torch.sigmoid(output)
|
|
318
|
-
preds = (probs >=
|
|
541
|
+
preds = (probs >= self._classification_threshold).int()
|
|
319
542
|
y_pred_batch = preds.numpy()
|
|
320
543
|
y_prob_batch = probs.numpy()
|
|
321
544
|
y_true_batch = target.numpy()
|
|
545
|
+
|
|
546
|
+
elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
|
|
547
|
+
# Assumes model output is [N, 1, H, W] (logits for positive class)
|
|
548
|
+
probs_pos = torch.sigmoid(output) # Shape [N, 1, H, W]
|
|
549
|
+
preds = (probs_pos >= self._classification_threshold).int() # Shape [N, 1, H, W]
|
|
550
|
+
|
|
551
|
+
# Squeeze preds to [N, H, W] (class indices 0 or 1)
|
|
552
|
+
y_pred_batch = preds.squeeze(1).numpy()
|
|
553
|
+
|
|
554
|
+
# Create [N, 2, H, W] probs for consistency
|
|
555
|
+
probs_neg = 1.0 - probs_pos
|
|
556
|
+
y_prob_batch = torch.cat([probs_neg, probs_pos], dim=1).numpy()
|
|
557
|
+
|
|
558
|
+
# Handle target shape [N, 1, H, W] -> [N, H, W]
|
|
559
|
+
if target.ndim == 4 and target.shape[1] == 1:
|
|
560
|
+
target = target.squeeze(1)
|
|
561
|
+
y_true_batch = target.numpy()
|
|
322
562
|
|
|
323
|
-
elif self.kind ==
|
|
563
|
+
elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION:
|
|
324
564
|
# output shape [N, C, H, W]
|
|
325
565
|
probs = torch.softmax(output, dim=1)
|
|
326
566
|
preds = torch.argmax(probs, dim=1) # shape [N, H, W]
|
|
@@ -333,24 +573,161 @@ class MLTrainer:
|
|
|
333
573
|
y_true_batch = target.numpy()
|
|
334
574
|
|
|
335
575
|
yield y_pred_batch, y_prob_batch, y_true_batch
|
|
336
|
-
|
|
337
|
-
def evaluate(self,
|
|
576
|
+
|
|
577
|
+
def evaluate(self,
|
|
578
|
+
save_dir: Union[str, Path],
|
|
579
|
+
model_checkpoint: Union[Path, Literal["latest", "current"]],
|
|
580
|
+
classification_threshold: Optional[float] = None,
|
|
581
|
+
test_data: Optional[Union[DataLoader, Dataset]] = None,
|
|
582
|
+
val_format_configuration: Optional[Union[ClassificationMetricsFormat,
|
|
583
|
+
MultiClassificationMetricsFormat,
|
|
584
|
+
RegressionMetricsFormat,
|
|
585
|
+
SegmentationMetricsFormat]]=None,
|
|
586
|
+
test_format_configuration: Optional[Union[ClassificationMetricsFormat,
|
|
587
|
+
MultiClassificationMetricsFormat,
|
|
588
|
+
RegressionMetricsFormat,
|
|
589
|
+
SegmentationMetricsFormat]]=None):
|
|
338
590
|
"""
|
|
339
591
|
Evaluates the model, routing to the correct evaluation function based on task `kind`.
|
|
340
592
|
|
|
341
593
|
Args:
|
|
594
|
+
model_checkpoint ('auto' | Path | None):
|
|
595
|
+
- Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
|
|
596
|
+
- If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
|
|
597
|
+
- If 'current', use the current state of the trained model up the latest trained epoch.
|
|
342
598
|
save_dir (str | Path): Directory to save all reports and plots.
|
|
343
|
-
|
|
344
|
-
|
|
599
|
+
classification_threshold (float | None): Used for tasks using a binary approach (binary classification, binary segmentation, multilabel binary classification)
|
|
600
|
+
test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
|
|
601
|
+
val_format_configuration: Optional configuration for metric format output for the validation set.
|
|
602
|
+
test_format_configuration: Optional configuration for metric format output for the test set.
|
|
603
|
+
"""
|
|
604
|
+
# Validate model checkpoint
|
|
605
|
+
if isinstance(model_checkpoint, Path):
|
|
606
|
+
checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
|
|
607
|
+
elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
|
|
608
|
+
checkpoint_validated = model_checkpoint
|
|
609
|
+
else:
|
|
610
|
+
_LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
|
|
611
|
+
raise ValueError()
|
|
612
|
+
|
|
613
|
+
# Validate classification threshold
|
|
614
|
+
if self.kind not in MLTaskKeys.ALL_BINARY_TASKS:
|
|
615
|
+
# dummy value for tasks that do not need it
|
|
616
|
+
threshold_validated = 0.5
|
|
617
|
+
elif classification_threshold is None:
|
|
618
|
+
# it should have been provided for binary tasks
|
|
619
|
+
_LOGGER.error(f"The classification threshold must be provided for '{self.kind}'.")
|
|
620
|
+
raise ValueError()
|
|
621
|
+
elif classification_threshold <= 0.0 or classification_threshold >= 1.0:
|
|
622
|
+
# Invalid float
|
|
623
|
+
_LOGGER.error(f"A classification threshold of {classification_threshold} is invalid. Must be in the range (0.0 - 1.0).")
|
|
624
|
+
raise ValueError()
|
|
625
|
+
else:
|
|
626
|
+
threshold_validated = classification_threshold
|
|
627
|
+
|
|
628
|
+
# Validate val configuration
|
|
629
|
+
if val_format_configuration is not None:
|
|
630
|
+
if not isinstance(val_format_configuration, (ClassificationMetricsFormat,
|
|
631
|
+
MultiClassificationMetricsFormat,
|
|
632
|
+
RegressionMetricsFormat,
|
|
633
|
+
SegmentationMetricsFormat)):
|
|
634
|
+
_LOGGER.error(f"Invalid 'format_configuration': '{type(val_format_configuration)}'.")
|
|
635
|
+
raise ValueError()
|
|
636
|
+
else:
|
|
637
|
+
val_configuration_validated = val_format_configuration
|
|
638
|
+
else: # config is None
|
|
639
|
+
val_configuration_validated = None
|
|
640
|
+
|
|
641
|
+
# Validate directory
|
|
642
|
+
save_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
643
|
+
|
|
644
|
+
# Validate test data and dispatch
|
|
645
|
+
if test_data is not None:
|
|
646
|
+
if not isinstance(test_data, (DataLoader, Dataset)):
|
|
647
|
+
_LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
|
|
648
|
+
raise ValueError()
|
|
649
|
+
test_data_validated = test_data
|
|
650
|
+
|
|
651
|
+
validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
|
|
652
|
+
test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
|
|
653
|
+
|
|
654
|
+
# Dispatch validation set
|
|
655
|
+
_LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
|
|
656
|
+
self._evaluate(save_dir=validation_metrics_path,
|
|
657
|
+
model_checkpoint=checkpoint_validated,
|
|
658
|
+
classification_threshold=threshold_validated,
|
|
659
|
+
data=None,
|
|
660
|
+
format_configuration=val_configuration_validated)
|
|
661
|
+
|
|
662
|
+
# Validate test configuration
|
|
663
|
+
if test_format_configuration is not None:
|
|
664
|
+
if not isinstance(test_format_configuration, (ClassificationMetricsFormat,
|
|
665
|
+
MultiClassificationMetricsFormat,
|
|
666
|
+
RegressionMetricsFormat,
|
|
667
|
+
SegmentationMetricsFormat)):
|
|
668
|
+
warning_message_type = f"Invalid test_format_configuration': '{type(val_format_configuration)}'."
|
|
669
|
+
if val_configuration_validated is not None:
|
|
670
|
+
warning_message_type += " 'val_format_configuration' will be used for the test set metrics output."
|
|
671
|
+
test_configuration_validated = val_configuration_validated
|
|
672
|
+
else:
|
|
673
|
+
warning_message_type += " Using default format."
|
|
674
|
+
test_configuration_validated = None
|
|
675
|
+
_LOGGER.warning(warning_message_type)
|
|
676
|
+
else:
|
|
677
|
+
test_configuration_validated = test_format_configuration
|
|
678
|
+
else: #config is None
|
|
679
|
+
test_configuration_validated = None
|
|
680
|
+
|
|
681
|
+
# Dispatch test set
|
|
682
|
+
_LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
|
|
683
|
+
self._evaluate(save_dir=test_metrics_path,
|
|
684
|
+
model_checkpoint="current",
|
|
685
|
+
classification_threshold=threshold_validated,
|
|
686
|
+
data=test_data_validated,
|
|
687
|
+
format_configuration=test_configuration_validated)
|
|
688
|
+
else:
|
|
689
|
+
# Dispatch validation set
|
|
690
|
+
_LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
|
|
691
|
+
self._evaluate(save_dir=save_path,
|
|
692
|
+
model_checkpoint=checkpoint_validated,
|
|
693
|
+
classification_threshold=threshold_validated,
|
|
694
|
+
data=None,
|
|
695
|
+
format_configuration=val_configuration_validated)
|
|
696
|
+
|
|
697
|
+
def _evaluate(self,
|
|
698
|
+
save_dir: Union[str, Path],
|
|
699
|
+
model_checkpoint: Union[Path, Literal["latest", "current"]],
|
|
700
|
+
classification_threshold: float,
|
|
701
|
+
data: Optional[Union[DataLoader, Dataset]],
|
|
702
|
+
format_configuration: Optional[Union[ClassificationMetricsFormat,
|
|
703
|
+
MultiClassificationMetricsFormat,
|
|
704
|
+
RegressionMetricsFormat,
|
|
705
|
+
SegmentationMetricsFormat]]):
|
|
706
|
+
"""
|
|
707
|
+
Changed to a private helper function.
|
|
345
708
|
"""
|
|
346
709
|
dataset_for_names = None
|
|
347
710
|
eval_loader = None
|
|
348
|
-
|
|
711
|
+
|
|
712
|
+
# set threshold
|
|
713
|
+
self._classification_threshold = classification_threshold
|
|
714
|
+
|
|
715
|
+
# load model checkpoint
|
|
716
|
+
if isinstance(model_checkpoint, Path):
|
|
717
|
+
self._load_checkpoint(path=model_checkpoint)
|
|
718
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
|
|
719
|
+
path_to_latest = self._checkpoint_callback.best_checkpoint_path
|
|
720
|
+
self._load_checkpoint(path_to_latest)
|
|
721
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
|
|
722
|
+
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
|
|
723
|
+
raise ValueError()
|
|
724
|
+
|
|
725
|
+
# Dataloader
|
|
349
726
|
if isinstance(data, DataLoader):
|
|
350
727
|
eval_loader = data
|
|
351
728
|
# Try to get the dataset from the loader for fetching target names
|
|
352
729
|
if hasattr(data, 'dataset'):
|
|
353
|
-
dataset_for_names = data.dataset
|
|
730
|
+
dataset_for_names = data.dataset # type: ignore
|
|
354
731
|
elif isinstance(data, Dataset):
|
|
355
732
|
# Create a new loader from the provided dataset
|
|
356
733
|
eval_loader = DataLoader(data,
|
|
@@ -360,26 +737,26 @@ class MLTrainer:
|
|
|
360
737
|
pin_memory=(self.device.type == "cuda"))
|
|
361
738
|
dataset_for_names = data
|
|
362
739
|
else: # data is None, use the trainer's default test dataset
|
|
363
|
-
if self.
|
|
740
|
+
if self.validation_dataset is None:
|
|
364
741
|
_LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
|
|
365
742
|
raise ValueError()
|
|
366
743
|
# Create a fresh DataLoader from the test_dataset
|
|
367
|
-
eval_loader = DataLoader(self.
|
|
744
|
+
eval_loader = DataLoader(self.validation_dataset,
|
|
368
745
|
batch_size=self._batch_size,
|
|
369
746
|
shuffle=False,
|
|
370
747
|
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
371
748
|
pin_memory=(self.device.type == "cuda"))
|
|
372
749
|
|
|
373
|
-
dataset_for_names = self.
|
|
750
|
+
dataset_for_names = self.validation_dataset
|
|
374
751
|
|
|
375
752
|
if eval_loader is None:
|
|
376
753
|
_LOGGER.error("Cannot evaluate. No valid data was provided or found.")
|
|
377
754
|
raise ValueError()
|
|
378
755
|
|
|
379
|
-
print("\n--- Model Evaluation ---")
|
|
756
|
+
# print("\n--- Model Evaluation ---")
|
|
380
757
|
|
|
381
758
|
all_preds, all_probs, all_true = [], [], []
|
|
382
|
-
for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader
|
|
759
|
+
for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
|
|
383
760
|
if y_pred_b is not None: all_preds.append(y_pred_b)
|
|
384
761
|
if y_prob_b is not None: all_probs.append(y_prob_b)
|
|
385
762
|
if y_true_b is not None: all_true.append(y_true_b)
|
|
@@ -393,22 +770,55 @@ class MLTrainer:
|
|
|
393
770
|
y_prob = np.concatenate(all_probs) if all_probs else None
|
|
394
771
|
|
|
395
772
|
# --- Routing Logic ---
|
|
396
|
-
if self.kind ==
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
773
|
+
if self.kind == MLTaskKeys.REGRESSION:
|
|
774
|
+
# Check configuration
|
|
775
|
+
config = None
|
|
776
|
+
if format_configuration and isinstance(format_configuration, RegressionMetricsFormat):
|
|
777
|
+
config = format_configuration
|
|
778
|
+
elif format_configuration:
|
|
779
|
+
_LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected RegressionMetricsFormat.")
|
|
780
|
+
|
|
781
|
+
regression_metrics(y_true=y_true.flatten(),
|
|
782
|
+
y_pred=y_pred.flatten(),
|
|
783
|
+
save_dir=save_dir,
|
|
784
|
+
config=config)
|
|
785
|
+
|
|
786
|
+
elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
|
|
787
|
+
# Check configuration
|
|
788
|
+
config = None
|
|
789
|
+
if format_configuration and isinstance(format_configuration, ClassificationMetricsFormat):
|
|
790
|
+
config = format_configuration
|
|
791
|
+
elif format_configuration:
|
|
792
|
+
_LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected ClassificationMetricsFormat.")
|
|
793
|
+
|
|
794
|
+
classification_metrics(save_dir=save_dir,
|
|
795
|
+
y_true=y_true,
|
|
796
|
+
y_pred=y_pred,
|
|
797
|
+
y_prob=y_prob,
|
|
798
|
+
config=config)
|
|
799
|
+
|
|
800
|
+
elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION:
|
|
403
801
|
try:
|
|
404
802
|
target_names = dataset_for_names.target_names # type: ignore
|
|
405
803
|
except AttributeError:
|
|
406
804
|
num_targets = y_true.shape[1]
|
|
407
805
|
target_names = [f"target_{i}" for i in range(num_targets)]
|
|
408
806
|
_LOGGER.warning(f"Dataset has no 'target_names' attribute. Using generic names.")
|
|
409
|
-
|
|
807
|
+
|
|
808
|
+
# Check configuration
|
|
809
|
+
config = None
|
|
810
|
+
if format_configuration and isinstance(format_configuration, RegressionMetricsFormat):
|
|
811
|
+
config = format_configuration
|
|
812
|
+
elif format_configuration:
|
|
813
|
+
_LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected RegressionMetricsFormat.")
|
|
814
|
+
|
|
815
|
+
multi_target_regression_metrics(y_true=y_true,
|
|
816
|
+
y_pred=y_pred,
|
|
817
|
+
target_names=target_names,
|
|
818
|
+
save_dir=save_dir,
|
|
819
|
+
config=config)
|
|
410
820
|
|
|
411
|
-
elif self.kind ==
|
|
821
|
+
elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
|
|
412
822
|
try:
|
|
413
823
|
target_names = dataset_for_names.target_names # type: ignore
|
|
414
824
|
except AttributeError:
|
|
@@ -419,9 +829,22 @@ class MLTrainer:
|
|
|
419
829
|
if y_prob is None:
|
|
420
830
|
_LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
|
|
421
831
|
return
|
|
422
|
-
multi_label_classification_metrics(y_true, y_prob, target_names, save_dir, classification_threshold)
|
|
423
832
|
|
|
424
|
-
|
|
833
|
+
# Check configuration
|
|
834
|
+
config = None
|
|
835
|
+
if format_configuration and isinstance(format_configuration, MultiClassificationMetricsFormat):
|
|
836
|
+
config = format_configuration
|
|
837
|
+
elif format_configuration:
|
|
838
|
+
_LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected MultiClassificationMetricsFormat.")
|
|
839
|
+
|
|
840
|
+
multi_label_classification_metrics(y_true=y_true,
|
|
841
|
+
y_pred=y_pred,
|
|
842
|
+
y_prob=y_prob,
|
|
843
|
+
target_names=target_names,
|
|
844
|
+
save_dir=save_dir,
|
|
845
|
+
config=config)
|
|
846
|
+
|
|
847
|
+
elif self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
|
|
425
848
|
class_names = None
|
|
426
849
|
try:
|
|
427
850
|
# Try to get 'classes' from VisionDatasetMaker
|
|
@@ -443,10 +866,18 @@ class MLTrainer:
|
|
|
443
866
|
class_names = [f"Class {i}" for i in labels]
|
|
444
867
|
_LOGGER.warning(f"Dataset has no 'classes' or 'target_names' attribute. Using generic names.")
|
|
445
868
|
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
869
|
+
# Check configuration
|
|
870
|
+
config = None
|
|
871
|
+
if format_configuration and isinstance(format_configuration, SegmentationMetricsFormat):
|
|
872
|
+
config = format_configuration
|
|
873
|
+
elif format_configuration:
|
|
874
|
+
_LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected SegmentationMetricsFormat.")
|
|
875
|
+
|
|
876
|
+
segmentation_metrics(y_true=y_true,
|
|
877
|
+
y_pred=y_pred,
|
|
878
|
+
save_dir=save_dir,
|
|
879
|
+
class_names=class_names,
|
|
880
|
+
config=config)
|
|
450
881
|
|
|
451
882
|
def explain(self,
|
|
452
883
|
save_dir: Union[str,Path],
|
|
@@ -502,7 +933,7 @@ class MLTrainer:
|
|
|
502
933
|
rand_indices = torch.randperm(full_data.size(0))[:num_samples]
|
|
503
934
|
return full_data[rand_indices]
|
|
504
935
|
|
|
505
|
-
print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
|
|
936
|
+
# print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
|
|
506
937
|
|
|
507
938
|
# 1. Get background data from the trainer's train_dataset
|
|
508
939
|
background_data = _get_random_sample(self.train_dataset, n_samples)
|
|
@@ -511,7 +942,7 @@ class MLTrainer:
|
|
|
511
942
|
return
|
|
512
943
|
|
|
513
944
|
# 2. Determine target dataset and get explanation instances
|
|
514
|
-
target_dataset = explain_dataset if explain_dataset is not None else self.
|
|
945
|
+
target_dataset = explain_dataset if explain_dataset is not None else self.validation_dataset
|
|
515
946
|
instances_to_explain = _get_random_sample(target_dataset, n_samples)
|
|
516
947
|
if instances_to_explain is None:
|
|
517
948
|
_LOGGER.error("Explanation dataset is empty or invalid. Skipping SHAP analysis.")
|
|
@@ -530,7 +961,7 @@ class MLTrainer:
|
|
|
530
961
|
self.model.to(self.device)
|
|
531
962
|
|
|
532
963
|
# 3. Call the plotting function
|
|
533
|
-
if self.kind in [
|
|
964
|
+
if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
|
|
534
965
|
shap_summary_plot(
|
|
535
966
|
model=self.model,
|
|
536
967
|
background_data=background_data,
|
|
@@ -540,7 +971,7 @@ class MLTrainer:
|
|
|
540
971
|
explainer_type=explainer_type,
|
|
541
972
|
device=self.device
|
|
542
973
|
)
|
|
543
|
-
elif self.kind in [
|
|
974
|
+
elif self.kind in [MLTaskKeys.MULTITARGET_REGRESSION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
|
|
544
975
|
# try to get target names
|
|
545
976
|
if target_names is None:
|
|
546
977
|
target_names = []
|
|
@@ -610,17 +1041,15 @@ class MLTrainer:
|
|
|
610
1041
|
plot_n_features (int): Number of top features to plot.
|
|
611
1042
|
"""
|
|
612
1043
|
|
|
613
|
-
print("\n--- Attention Analysis ---")
|
|
1044
|
+
# print("\n--- Attention Analysis ---")
|
|
614
1045
|
|
|
615
1046
|
# --- Step 1: Check if the model supports this explanation ---
|
|
616
1047
|
if not getattr(self.model, 'has_interpretable_attention', False):
|
|
617
|
-
_LOGGER.warning(
|
|
618
|
-
"Model is not flagged for interpretable attention analysis. Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
|
|
619
|
-
)
|
|
1048
|
+
_LOGGER.warning("Model is not compatible with interpretable attention analysis. Skipping.")
|
|
620
1049
|
return
|
|
621
1050
|
|
|
622
1051
|
# --- Step 2: Set up the dataloader ---
|
|
623
|
-
dataset_to_use = explain_dataset if explain_dataset is not None else self.
|
|
1052
|
+
dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
|
|
624
1053
|
if not isinstance(dataset_to_use, Dataset):
|
|
625
1054
|
_LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
|
|
626
1055
|
return
|
|
@@ -655,40 +1084,111 @@ class MLTrainer:
|
|
|
655
1084
|
)
|
|
656
1085
|
else:
|
|
657
1086
|
_LOGGER.error("No attention weights were collected from the model.")
|
|
658
|
-
|
|
659
|
-
def _callbacks_hook(self, method_name: str, *args, **kwargs):
|
|
660
|
-
"""Calls the specified method on all callbacks."""
|
|
661
|
-
for callback in self.callbacks:
|
|
662
|
-
method = getattr(callback, method_name)
|
|
663
|
-
method(*args, **kwargs)
|
|
664
|
-
|
|
665
|
-
def to_cpu(self):
|
|
666
|
-
"""
|
|
667
|
-
Moves the model to the CPU and updates the trainer's device setting.
|
|
668
1087
|
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
def to_device(self, device: str):
|
|
1088
|
+
def finalize_model_training(self,
|
|
1089
|
+
save_dir: Union[str, Path],
|
|
1090
|
+
filename: str,
|
|
1091
|
+
model_checkpoint: Union[Path, Literal['latest', 'current']],
|
|
1092
|
+
classification_threshold: Optional[float]=None,
|
|
1093
|
+
class_map: Optional[Dict[str,int]]=None):
|
|
676
1094
|
"""
|
|
677
|
-
|
|
1095
|
+
Saves a finalized, "inference-ready" model state to a .pth file.
|
|
1096
|
+
|
|
1097
|
+
This method saves the model's `state_dict`, the final epoch number, and
|
|
1098
|
+
an optional classification threshold required for binary-based tasks (binary classification, binary segmentation,
|
|
1099
|
+
multilabel binary classification).
|
|
678
1100
|
|
|
679
1101
|
Args:
|
|
680
|
-
|
|
1102
|
+
save_dir (str | Path): The directory to save the finalized model.
|
|
1103
|
+
filename (str): The desired filename for the saved file.
|
|
1104
|
+
model_checkpoint (Path | "latest" | "current"):
|
|
1105
|
+
- Path: Loads the model state from a specific checkpoint file.
|
|
1106
|
+
- "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
|
|
1107
|
+
- "current": Uses the model's state as it is at the end of the `fit()` call.
|
|
1108
|
+
classification_threshold (float, None):
|
|
1109
|
+
Required for `binary classification`, `binary segmentation`, and
|
|
1110
|
+
`multilabel binary classification`. This is the threshold (0.0-1.0)
|
|
1111
|
+
used to convert probabilities to class labels.
|
|
1112
|
+
class_map (Dict[str, int] | None): Sets the class name mapping to translate predicted integer labels back into string names. (For Classification and Segmentation Tasks)
|
|
681
1113
|
"""
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
1114
|
+
# handle save path
|
|
1115
|
+
sanitized_filename = sanitize_filename(filename)
|
|
1116
|
+
if not sanitized_filename.endswith(".pth"):
|
|
1117
|
+
sanitized_filename = sanitized_filename + ".pth"
|
|
1118
|
+
|
|
1119
|
+
dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
1120
|
+
full_path = dir_path / sanitized_filename
|
|
1121
|
+
|
|
1122
|
+
# threshold required for binary tasks
|
|
1123
|
+
if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
|
|
1124
|
+
if classification_threshold is None:
|
|
1125
|
+
_LOGGER.error(f"A classification threshold is needed for binary-based classification tasks. If unknown, use '0.5' as a default.")
|
|
1126
|
+
raise ValueError()
|
|
1127
|
+
elif not isinstance(classification_threshold, float):
|
|
1128
|
+
_LOGGER.error(f"The classification threshold must be a float value.")
|
|
1129
|
+
raise TypeError()
|
|
1130
|
+
elif classification_threshold <= 0.0 or classification_threshold >= 1.0:
|
|
1131
|
+
_LOGGER.error(f"The classification threshold must be in the range (0.0 - 1.0).")
|
|
1132
|
+
else:
|
|
1133
|
+
classification_threshold = None
|
|
1134
|
+
|
|
1135
|
+
# handle checkpoint
|
|
1136
|
+
if isinstance(model_checkpoint, Path):
|
|
1137
|
+
self._load_checkpoint(path=model_checkpoint)
|
|
1138
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
|
|
1139
|
+
path_to_latest = self._checkpoint_callback.best_checkpoint_path
|
|
1140
|
+
self._load_checkpoint(path_to_latest)
|
|
1141
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
|
|
1142
|
+
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
|
|
1143
|
+
raise ValueError()
|
|
1144
|
+
elif model_checkpoint == MagicWords.CURRENT:
|
|
1145
|
+
pass
|
|
1146
|
+
else:
|
|
1147
|
+
_LOGGER.error(f"Unknown 'model_checkpoint' parameter received '{model_checkpoint}'.")
|
|
1148
|
+
|
|
1149
|
+
# Handle class map
|
|
1150
|
+
if self.kind in [MLTaskKeys.BINARY_CLASSIFICATION,
|
|
1151
|
+
MLTaskKeys.MULTICLASS_CLASSIFICATION,
|
|
1152
|
+
MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
|
|
1153
|
+
MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION,
|
|
1154
|
+
MLTaskKeys.BINARY_SEGMENTATION,
|
|
1155
|
+
MLTaskKeys.MULTICLASS_SEGMENTATION]:
|
|
1156
|
+
if class_map is None:
|
|
1157
|
+
_LOGGER.error(f"'class_map' is required for '{self.kind}'.")
|
|
1158
|
+
raise ValueError()
|
|
1159
|
+
else:
|
|
1160
|
+
class_map = None
|
|
1161
|
+
|
|
1162
|
+
# Create finalized data
|
|
1163
|
+
finalized_data = {
|
|
1164
|
+
PyTorchCheckpointKeys.EPOCH: self.epoch,
|
|
1165
|
+
PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
|
|
1166
|
+
}
|
|
1167
|
+
|
|
1168
|
+
if classification_threshold is not None:
|
|
1169
|
+
self._classification_threshold = classification_threshold
|
|
1170
|
+
finalized_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD] = classification_threshold
|
|
1171
|
+
|
|
1172
|
+
if class_map is not None:
|
|
1173
|
+
finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = class_map
|
|
1174
|
+
|
|
1175
|
+
torch.save(finalized_data, full_path)
|
|
1176
|
+
|
|
1177
|
+
_LOGGER.info(f"Finalized model weights saved to {full_path}.")
|
|
685
1178
|
|
|
686
1179
|
|
|
687
1180
|
 # Object Detection Trainer
-class ObjectDetectionTrainer:
-    def __init__(self, model: nn.Module,
+class DragonDetectionTrainer(_BaseDragonTrainer):
+    def __init__(self, model: nn.Module,
+                 train_dataset: Dataset,
+                 validation_dataset: Dataset,
                  collate_fn: Callable, optimizer: torch.optim.Optimizer,
-                 device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                 device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                 checkpoint_callback: Optional[DragonModelCheckpoint],
+                 early_stopping_callback: Optional[DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 extra_callbacks: Optional[List[_Callback]] = None,
+                 dataloader_workers: int = 2):
         """
         Automates the training process of an Object Detection Model (e.g., DragonFastRCNN).

@@ -697,58 +1197,36 @@ class ObjectDetectionTrainer:
         Args:
             model (nn.Module): The PyTorch object detection model to train.
             train_dataset (Dataset): The training dataset.
-
+            validation_dataset (Dataset): The testing/validation dataset.
             collate_fn (Callable): The collate function from `ObjectDetectionDatasetMaker.collate_fn`.
             optimizer (torch.optim.Optimizer): The optimizer.
             device (str): The device to run training on ('cpu', 'cuda', 'mps').
             dataloader_workers (int): Subprocesses for data loading.
-
+            checkpoint_callback (DragonModelCheckpoint | None): Callback to save the model.
+            early_stopping_callback (DragonEarlyStopping | None): Callback to stop training early.
+            lr_scheduler_callback (DragonLRScheduler | None): Callback to manage the LR scheduler.
+            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.

         ## Note:
         This trainer is specialized. It does not take a `criterion` because object detection models like Faster R-CNN return a dictionary of losses directly from their forward pass during training.
         """
-
+        # Call the base class constructor with common parameters
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            dataloader_workers=dataloader_workers,
+            checkpoint_callback=checkpoint_callback,
+            early_stopping_callback=early_stopping_callback,
+            lr_scheduler_callback=lr_scheduler_callback,
+            extra_callbacks=extra_callbacks
+        )
+
         self.train_dataset = train_dataset
-        self.test_dataset = test_dataset
+        self.validation_dataset = validation_dataset  # <-- Renamed
         self.kind = "object_detection"
         self.collate_fn = collate_fn
         self.criterion = None  # Criterion is handled inside the model
-        self.optimizer = optimizer
-        self.scheduler = None
-        self.device = self._validate_device(device)
-        self.dataloader_workers = dataloader_workers
-
-        # Callback handler - History and TqdmProgressBar are added by default
-        default_callbacks = [History(), TqdmProgressBar()]
-        user_callbacks = callbacks if callbacks is not None else []
-        self.callbacks = default_callbacks + user_callbacks
-        self._set_trainer_on_callbacks()
-
-        # Internal state
-        self.train_loader = None
-        self.test_loader = None
-        self.history = {}
-        self.epoch = 0
-        self.epochs = 0  # Total epochs for the fit run
-        self.start_epoch = 1
-        self.stop_training = False
-        self._batch_size = 10
-
-    def _validate_device(self, device: str) -> torch.device:
-        """Validates the selected device and returns a torch.device object."""
-        device_lower = device.lower()
-        if "cuda" in device_lower and not torch.cuda.is_available():
-            _LOGGER.warning("CUDA not available, switching to CPU.")
-            device = "cpu"
-        elif device_lower == "mps" and not torch.backends.mps.is_available():
-            _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
-            device = "cpu"
-        return torch.device(device)
-
-    def _set_trainer_on_callbacks(self):
-        """Gives each callback a reference to this trainer instance."""
-        for callback in self.callbacks:
-            callback.set_trainer(self)

     def _create_dataloaders(self, batch_size: int, shuffle: bool):
         """Initializes the DataLoaders with the object detection collate_fn."""

@@ -760,125 +1238,25 @@ class ObjectDetectionTrainer:
             batch_size=batch_size,
             shuffle=shuffle,
             num_workers=loader_workers,
-            pin_memory=("cuda" in self.device.type),
-            collate_fn=self.collate_fn  # Use the provided collate function
+            pin_memory=("cuda" in self.device.type),
+            collate_fn=self.collate_fn,  # Use the provided collate function
+            drop_last=True
         )

-        self.test_loader = DataLoader(
-            dataset=self.test_dataset,
+        self.validation_loader = DataLoader(
+            dataset=self.validation_dataset,
             batch_size=batch_size,
             shuffle=False,
             num_workers=loader_workers,
             pin_memory=("cuda" in self.device.type),
             collate_fn=self.collate_fn  # Use the provided collate function
         )
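For reference, the tuple-of-tensors / tuple-of-dicts batching these loaders rely on is the standard torchvision-style collate; the toolbox ships its own via `ObjectDetectionDatasetMaker.collate_fn`, so this is only an illustrative sketch:

def detection_collate(batch):
    # batch: list of (image_tensor, target_dict) pairs from the dataset;
    # zip(*batch) regroups it into (tuple_of_images, tuple_of_target_dicts).
    return tuple(zip(*batch))
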
-
-    def _load_checkpoint(self, path: Union[str, Path]):
-        """Loads a training checkpoint to resume training."""
-        p = make_fullpath(path, enforce="file")
-        _LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
-
-        try:
-            checkpoint = torch.load(p, map_location=self.device)
-
-            if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
-                _LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
-                raise KeyError()
-
-            self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
-            self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
-            self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1  # Resume on the *next* epoch
-
-            # --- Scheduler State Loading Logic ---
-            scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
-            scheduler_object_exists = self.scheduler is not None
-
-            if scheduler_object_exists and scheduler_state_exists:
-                # Case 1: Both exist. Attempt to load.
-                try:
-                    self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE])  # type: ignore
-                    scheduler_name = self.scheduler.__class__.__name__
-                    _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
-                except Exception as e:
-                    # Loading failed, likely a mismatch
-                    scheduler_name = self.scheduler.__class__.__name__
-                    _LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
-                    raise e
-
-            elif scheduler_object_exists and not scheduler_state_exists:
-                # Case 2: Scheduler provided, but no state in checkpoint.
-                scheduler_name = self.scheduler.__class__.__name__
-                _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
-
-            elif not scheduler_object_exists and scheduler_state_exists:
-                # Case 3: State in checkpoint, but no scheduler provided.
-                _LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
-                raise ValueError()
-
-            # Restore callback states
-            for cb in self.callbacks:
-                if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
-                    cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
-                    _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
-
-            _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
-
-        except Exception as e:
-            _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
-            raise
-
-    def fit(self,
-            epochs: int = 10,
-            batch_size: int = 10,
-            shuffle: bool = True,
-            resume_from_checkpoint: Optional[Union[str, Path]] = None):
-        """
-        Starts the training-validation process of the model.
-
-        Returns the "History" callback dictionary.

-        Args:
-            epochs (int): The total number of epochs to train for.
-            batch_size (int): The number of samples per batch.
-            shuffle (bool): Whether to shuffle the training data at each epoch.
-            resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
-        """
-        self.epochs = epochs
-        self._batch_size = batch_size
-        self._create_dataloaders(self._batch_size, shuffle)
-        self.model.to(self.device)
-
-        if resume_from_checkpoint:
-            self._load_checkpoint(resume_from_checkpoint)
-
-        # Reset stop_training flag on the trainer
-        self.stop_training = False
-
-        self._callbacks_hook('on_train_begin')
-
-        for epoch in range(self.start_epoch, self.epochs + 1):
-            self.epoch = epoch
-            epoch_logs = {}
-            self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
-
-            train_logs = self._train_step()
-            epoch_logs.update(train_logs)
-
-            val_logs = self._validation_step()
-            epoch_logs.update(val_logs)
-
-            self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
-
-            # Check the early stopping flag
-            if self.stop_training:
-                break
-
-        self._callbacks_hook('on_train_end')
-        return self.history
-
     def _train_step(self):
         self.model.train()
         running_loss = 0.0
+        total_samples = 0
+
         for batch_idx, (images, targets) in enumerate(self.train_loader):  # type: ignore
             # images is a tuple of tensors, targets is a tuple of dicts
             batch_size = len(images)

@@ -915,21 +1293,28 @@ class ObjectDetectionTrainer:
             # Calculate batch loss and update running loss for the epoch
             batch_loss = loss.item()
             running_loss += batch_loss * batch_size
+            total_samples += batch_size  # <-- Accumulate total samples

             # Add the batch loss to the logs and call the end-of-batch hook
             batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss  # type: ignore
             self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+        # Calculate loss using the correct denominator
+        if total_samples == 0:
+            _LOGGER.warning("No samples processed in _train_step. Returning 0 loss.")
+            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}

-        return {PyTorchLogKeys.TRAIN_LOSS: running_loss /
+        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}

1309
|
def _validation_step(self):
|
|
926
1310
|
self.model.train() # Set to train mode even for validation loss calculation
|
|
927
|
-
# as model internals (e.g., proposals) might differ,
|
|
928
|
-
#
|
|
929
|
-
# We use torch.no_grad() to prevent gradient updates.
|
|
1311
|
+
# as model internals (e.g., proposals) might differ, but we still need loss_dict.
|
|
1312
|
+
# use torch.no_grad() to prevent gradient updates.
|
|
930
1313
|
running_loss = 0.0
|
|
1314
|
+
total_samples = 0
|
|
1315
|
+
|
|
931
1316
|
with torch.no_grad():
|
|
932
|
-
for images, targets in self.
|
|
1317
|
+
for images, targets in self.validation_loader: # type: ignore
|
|
933
1318
|
batch_size = len(images)
|
|
934
1319
|
|
|
935
1320
|
# Move data to device
|
|
@@ -947,25 +1332,105 @@ class ObjectDetectionTrainer:
|
|
|
947
1332
|
loss: torch.Tensor = sum(l for l in loss_dict.values()) # type: ignore
|
|
948
1333
|
|
|
949
1334
|
running_loss += loss.item() * batch_size
|
|
1335
|
+
total_samples += batch_size # <-- Accumulate total samples
|
|
950
1336
|
|
|
951
|
-
|
|
1337
|
+
# Calculate loss using the correct denominator
|
|
1338
|
+
if total_samples == 0:
|
|
1339
|
+
_LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
|
|
1340
|
+
return {PyTorchLogKeys.VAL_LOSS: 0.0}
|
|
1341
|
+
|
|
1342
|
+
logs = {PyTorchLogKeys.VAL_LOSS: running_loss / total_samples}
|
|
952
1343
|
return logs
|
|
1344
|
+
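Sketch of the loss-dict reduction used in both steps above: torchvision-style detection models return a dict of named losses in train mode, and the total loss is their sum (the key names and values here are illustrative):

import torch

loss_dict = {"loss_classifier": torch.tensor(0.6), "loss_box_reg": torch.tensor(0.3)}
loss = sum(l for l in loss_dict.values())  # total detection loss as a tensor
print(loss.item())
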
+    def evaluate(self,
+                 save_dir: Union[str, Path],
+                 model_checkpoint: Union[Path, Literal["latest", "current"]],
+                 test_data: Optional[Union[DataLoader, Dataset]] = None):
+        """
+        Evaluates the model using object detection mAP metrics.

-
+        Args:
+            save_dir (str | Path): Directory to save all reports and plots.
+            model_checkpoint (Path | 'latest' | 'current'):
+                - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+                - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+                - If 'current', use the current state of the trained model up to the latest trained epoch.
+            test_data (DataLoader | Dataset | None): Optional test data to evaluate model performance. Validation and test metrics will be saved to subdirectories.
         """
+        # Validate model checkpoint
+        if isinstance(model_checkpoint, Path):
+            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+        elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+            checkpoint_validated = model_checkpoint
+        else:
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
+            raise ValueError()
+
+        # Validate directory
+        save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+        # Validate test data and dispatch
+        if test_data is not None:
+            if not isinstance(test_data, (DataLoader, Dataset)):
+                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                raise ValueError()
+            test_data_validated = test_data
+
+            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+            self._evaluate(save_dir=validation_metrics_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None)  # 'None' triggers use of self.validation_dataset
+
+            # Dispatch test set
+            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+            self._evaluate(save_dir=test_metrics_path,
+                           model_checkpoint="current",  # Use 'current' state after loading checkpoint once
+                           data=test_data_validated)
+        else:
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+            self._evaluate(save_dir=save_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None)  # 'None' triggers use of self.validation_dataset
+
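A hypothetical call pattern for the dispatcher above (`trainer` and `test_ds` are placeholders for a fitted DragonDetectionTrainer and a prepared Dataset, not objects defined here):

# Validation metrics land in the VALIDATION_METRICS_DIR subfolder, test metrics
# in TEST_METRICS_DIR, and the checkpoint is loaded once before both passes.
trainer.evaluate(save_dir="reports", model_checkpoint="latest", test_data=test_ds)
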
+    def _evaluate(self,
+                  save_dir: Union[str, Path],
+                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  data: Optional[Union[DataLoader, Dataset]]):
+        """
+        Changed to a private helper method.
         Evaluates the model using object detection mAP metrics.

         Args:
             save_dir (str | Path): Directory to save all reports and plots.
             data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
+            model_checkpoint (Path | 'latest' | 'current'):
+                - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+                - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+                - If 'current', use the current state of the trained model up to the latest trained epoch.
         """
         dataset_for_names = None
         eval_loader = None
+
+        # load model checkpoint
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+            raise ValueError()

+        # Dataloader
         if isinstance(data, DataLoader):
             eval_loader = data
             if hasattr(data, 'dataset'):
-                dataset_for_names = data.dataset
+                dataset_for_names = data.dataset  # type: ignore
         elif isinstance(data, Dataset):
             # Create a new loader from the provided dataset
             eval_loader = DataLoader(data,

@@ -976,25 +1441,25 @@ class ObjectDetectionTrainer:
                                      collate_fn=self.collate_fn)
             dataset_for_names = data
         else:  # data is None, use the trainer's default test dataset
-            if self.test_dataset is None:
+            if self.validation_dataset is None:
                 _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
                 raise ValueError()
             # Create a fresh DataLoader from the test_dataset
             eval_loader = DataLoader(
-                self.test_dataset,
+                self.validation_dataset,
                 batch_size=self._batch_size,
                 shuffle=False,
                 num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                 pin_memory=(self.device.type == "cuda"),
                 collate_fn=self.collate_fn
             )
-            dataset_for_names = self.test_dataset
+            dataset_for_names = self.validation_dataset

         if eval_loader is None:
             _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
             raise ValueError()

-        print("\n--- Model Evaluation ---")
+        # print("\n--- Model Evaluation ---")

         all_predictions = []
         all_targets = []

@@ -1042,36 +1507,480 @@ class ObjectDetectionTrainer:
             class_names=class_names,
             print_output=False
         )
-
-        print("\n--- Training History ---")
-        plot_losses(self.history, save_dir=save_dir)

-    def
-    """
-
-
-
+    def finalize_model_training(self, save_dir: Union[str, Path], filename: str, model_checkpoint: Union[Path, Literal['latest', 'current']]):
+        """
+        Saves a finalized, "inference-ready" model state to a .pth file.
+
+        This method saves the model's `state_dict` and the final epoch number.
+
+        Args:
+            save_dir (Union[str, Path]): The directory to save the finalized model.
+            filename (str): The desired filename for the model (e.g., "final_model.pth").
+            model_checkpoint (Union[Path, Literal["latest", "current"]]):
+                - Path: Loads the model state from a specific checkpoint file.
+                - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                - "current": Uses the model's state as it is at the end of the `fit()` call.
+        """
+        # handle save path
+        sanitized_filename = sanitize_filename(filename)
+        if not sanitized_filename.endswith(".pth"):
+            sanitized_filename = sanitized_filename + ".pth"

-
+        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+        full_path = dir_path / sanitized_filename
+
+        # handle checkpoint
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+            raise ValueError()
+        elif model_checkpoint == MagicWords.CURRENT:
+            pass
+        else:
+            _LOGGER.error(f"Unknown 'model_checkpoint' parameter received '{model_checkpoint}'.")
+
+        # Create finalized data
+        finalized_data = {
+            PyTorchCheckpointKeys.EPOCH: self.epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+        }
+
+        torch.save(finalized_data, full_path)
+
+        _LOGGER.info(f"Finalized model weights saved to {full_path}.")
+
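Taken together, a hypothetical end-to-end flow for the detection trainer (all object names are placeholders; the `fit` signature shown mirrors the removed per-class implementation and is assumed to carry over to `_BaseDragonTrainer`):

# Placeholders: any torchvision-style detection model and paired datasets.
trainer = DragonDetectionTrainer(
    model=model, train_dataset=train_ds, validation_dataset=val_ds,
    collate_fn=collate_fn, optimizer=optimizer, device="cuda",
    checkpoint_callback=checkpoint_cb, early_stopping_callback=None,
    lr_scheduler_callback=None,
)
trainer.fit(epochs=20, batch_size=4)
trainer.evaluate(save_dir="reports", model_checkpoint="latest")
trainer.finalize_model_training(save_dir="artifacts", filename="detector", model_checkpoint="latest")
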
+# --- DragonSequenceTrainer ----
+class DragonSequenceTrainer(_BaseDragonTrainer):
+    def __init__(self,
+                 model: nn.Module,
+                 train_dataset: Dataset,
+                 validation_dataset: Dataset,
+                 kind: Literal["sequence-to-sequence", "sequence-to-value"],
+                 optimizer: torch.optim.Optimizer,
+                 device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                 checkpoint_callback: Optional[DragonModelCheckpoint],
+                 early_stopping_callback: Optional[DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 extra_callbacks: Optional[List[_Callback]] = None,
+                 criterion: Union[nn.Module, Literal["auto"]] = "auto",
+                 dataloader_workers: int = 2):
         """
-
+        Automates the training process of a PyTorch Sequence Model.

-
+        Built-in Callbacks: `History`, `TqdmProgressBar`
+
+        Args:
+            model (nn.Module): The PyTorch model to train.
+            train_dataset (Dataset): The training dataset.
+            validation_dataset (Dataset): The validation dataset.
+            kind (str): Used to redirect to the correct process ('sequence-to-sequence' or 'sequence-to-value').
+            criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task.
+            optimizer (torch.optim.Optimizer): The optimizer.
+            device (str): The device to run training on ('cpu', 'cuda', 'mps').
+            dataloader_workers (int): Subprocesses for data loading.
+            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
         """
-
+        # Call the base class constructor with common parameters
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            dataloader_workers=dataloader_workers,
+            checkpoint_callback=checkpoint_callback,
+            early_stopping_callback=early_stopping_callback,
+            lr_scheduler_callback=lr_scheduler_callback,
+            extra_callbacks=extra_callbacks
+        )
+
+        if kind not in [MLTaskKeys.SEQUENCE_SEQUENCE, MLTaskKeys.SEQUENCE_VALUE]:
+            raise ValueError(f"'{kind}' is not a valid task type for DragonSequenceTrainer.")
+
+        self.train_dataset = train_dataset
+        self.validation_dataset = validation_dataset
+        self.kind = kind
+
+        # try to validate against Dragon Sequence model
+        if hasattr(self.model, "prediction_mode"):
+            key_to_check: str = self.model.prediction_mode  # type: ignore
+            if not key_to_check == self.kind:
+                _LOGGER.error(f"Trainer was set for '{self.kind}', but model architecture '{self.model}' is built for '{key_to_check}'.")
+                raise RuntimeError()
+
+        # loss function
+        if criterion == "auto":
+            # Both sequence tasks are treated as regression problems
+            self.criterion = nn.MSELoss()
+        else:
+            self.criterion = criterion
+
+    def _create_dataloaders(self, batch_size: int, shuffle: bool):
+        """Initializes the DataLoaders."""
+        # Ensure stability on MPS devices by setting num_workers to 0
+        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+        self.train_loader = DataLoader(
+            dataset=self.train_dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type),
+            drop_last=True  # Drops the last batch if incomplete; selecting a good batch size is key.
+        )
+
+        self.validation_loader = DataLoader(
+            dataset=self.validation_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type)
+        )
+
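Self-contained sketch of what `drop_last=True` does to the training loader above (made-up toy data): a trailing partial batch is silently discarded, which is why the comment stresses choosing a batch size that divides the dataset well:

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10).float())
loader = DataLoader(ds, batch_size=4, drop_last=True)
print(len(loader))  # 2 batches of 4; the trailing partial batch of 2 is dropped
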
+    def _train_step(self):
+        self.model.train()
+        running_loss = 0.0
+        total_samples = 0
+
+        for batch_idx, (features, target) in enumerate(self.train_loader):  # type: ignore
+            # Create a log dictionary for the batch
+            batch_logs = {
+                PyTorchLogKeys.BATCH_INDEX: batch_idx,
+                PyTorchLogKeys.BATCH_SIZE: features.size(0)
+            }
+            self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
+
+            features, target = features.to(self.device), target.to(self.device)
+            self.optimizer.zero_grad()
+
+            output = self.model(features)
+
+            # --- Label Type/Shape Correction ---
+            # Ensure target is float for MSELoss
+            target = target.float()
+
+            # For seq-to-val, models might output [N, 1] but target is [N].
+            if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+                if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                    output = output.squeeze(1)
+
+            # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
+            elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+                if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
+                    output = output.squeeze(-1)
+
+            loss = self.criterion(output, target)
+
+            loss.backward()
+            self.optimizer.step()
+
+            # Calculate batch loss and update running loss for the epoch
+            batch_loss = loss.item()
+            batch_size = features.size(0)
+            running_loss += batch_loss * batch_size  # Accumulate total loss
+            total_samples += batch_size  # total samples
+
+            # Add the batch loss to the logs and call the end-of-batch hook
+            batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
+            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+        if total_samples == 0:
+            _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
+            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
+
+        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}  # type: ignore
+
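Self-contained sketch of the shape correction above: a sequence-to-value head may emit [N, 1] while targets are [N], and squeezing aligns them before nn.MSELoss (random tensors stand in for real batches):

import torch
from torch import nn

output = torch.randn(8, 1)   # sequence-to-value head output
target = torch.randn(8)      # ground-truth values
if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
    output = output.squeeze(1)  # [8, 1] -> [8]
loss = nn.MSELoss()(output, target.float())
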
+    def _validation_step(self):
+        self.model.eval()
+        running_loss = 0.0
+
+        with torch.no_grad():
+            for features, target in self.validation_loader:  # type: ignore
+                features, target = features.to(self.device), target.to(self.device)
+
+                output = self.model(features)
+
+                # --- Label Type/Shape Correction ---
+                target = target.float()
+
+                # For seq-to-val, models might output [N, 1] but target is [N].
+                if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+                    if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                        output = output.squeeze(1)
+
+                # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
+                elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+                    if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
+                        output = output.squeeze(-1)
+
+                loss = self.criterion(output, target)
+
+                running_loss += loss.item() * features.size(0)
+
+        if not self.validation_loader.dataset:  # type: ignore
+            _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+            return {PyTorchLogKeys.VAL_LOSS: 0.0}
+
+        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)}  # type: ignore
+        return logs
+
+    def _predict_for_eval(self, dataloader: DataLoader):
+        """
+        Private method to yield model predictions batch by batch for evaluation.
+
+        Yields:
+            tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
+                   y_prob_batch is always None for sequence tasks.
+        """
+        self.model.eval()
         self.model.to(self.device)
-
+
+        with torch.no_grad():
+            for features, target in dataloader:
+                features = features.to(self.device)
+                output = self.model(features).cpu()
+
+                y_pred_batch = output.numpy()
+                y_prob_batch = None  # Not applicable for sequence regression
+                y_true_batch = target.numpy()
+
+                yield y_pred_batch, y_prob_batch, y_true_batch
+
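Sketch of how a generator like `_predict_for_eval` is consumed downstream (a fake generator stands in for the real one): per-batch arrays are collected and stitched into full prediction/target arrays, mirroring the `_evaluate` body below:

import numpy as np

def fake_predict():
    # Yields (y_pred_batch, y_prob_batch, y_true_batch) like _predict_for_eval.
    for _ in range(3):
        yield np.zeros((4, 2)), None, np.ones((4, 2))

all_preds, all_true = [], []
for y_pred_b, _, y_true_b in fake_predict():
    all_preds.append(y_pred_b)
    all_true.append(y_true_b)
y_pred, y_true = np.concatenate(all_preds), np.concatenate(all_true)  # (12, 2) each
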
+    def evaluate(self,
+                 save_dir: Union[str, Path],
+                 model_checkpoint: Union[Path, Literal["latest", "current"]],
+                 test_data: Optional[Union[DataLoader, Dataset]] = None,
+                 val_format_configuration: Optional[Union[SequenceValueMetricsFormat,
+                                                          SequenceSequenceMetricsFormat]] = None,
+                 test_format_configuration: Optional[Union[SequenceValueMetricsFormat,
+                                                           SequenceSequenceMetricsFormat]] = None):
+        """
+        Evaluates the model, routing to the correct evaluation function.
+
+        Args:
+            model_checkpoint (Path | 'latest' | 'current'):
+                - Path to a valid checkpoint for the model.
+                - If 'latest', the latest checkpoint will be loaded.
+                - If 'current', use the current state of the trained model.
+            save_dir (str | Path): Directory to save all reports and plots.
+            test_data (DataLoader | Dataset | None): Optional test data.
+            val_format_configuration: Optional configuration for validation metrics.
+            test_format_configuration: Optional configuration for test metrics.
+        """
+        # Validate model checkpoint
+        if isinstance(model_checkpoint, Path):
+            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+        elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+            checkpoint_validated = model_checkpoint
+        else:
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.LATEST}', or '{MagicWords.CURRENT}'.")
+            raise ValueError()
+
+        # Validate val configuration
+        if val_format_configuration is not None:
+            if not isinstance(val_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
+                _LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
+                raise ValueError()
+
+        # Validate directory
+        save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+        # Validate test data and dispatch
+        if test_data is not None:
+            if not isinstance(test_data, (DataLoader, Dataset)):
+                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                raise ValueError()
+            test_data_validated = test_data
+
+            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+            self._evaluate(save_dir=validation_metrics_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None,
+                           format_configuration=val_format_configuration)
+
+            # Validate test configuration
+            test_configuration_validated = None
+            if test_format_configuration is not None:
+                if not isinstance(test_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
+                    warning_message_type = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
+                    if val_format_configuration is not None:
+                        warning_message_type += " 'val_format_configuration' will be used."
+                        test_configuration_validated = val_format_configuration
+                    else:
+                        warning_message_type += " Using default format."
+                    _LOGGER.warning(warning_message_type)
+                else:
+                    test_configuration_validated = test_format_configuration
+
+            # Dispatch test set
+            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+            self._evaluate(save_dir=test_metrics_path,
+                           model_checkpoint="current",  # Use 'current' state after loading checkpoint once
+                           data=test_data_validated,
+                           format_configuration=test_configuration_validated)
+        else:
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+            self._evaluate(save_dir=save_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None,
+                           format_configuration=val_format_configuration)
+
+    def _evaluate(self,
+                  save_dir: Union[str, Path],
+                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  data: Optional[Union[DataLoader, Dataset]],
+                  format_configuration: Optional[Union[SequenceValueMetricsFormat,
+                                                       SequenceSequenceMetricsFormat]]):
         """
-
+        Private evaluation helper.
+        """
+        eval_loader = None
+
+        # load model checkpoint
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+            raise ValueError()
+
+        # Dataloader
+        if isinstance(data, DataLoader):
+            eval_loader = data
+        elif isinstance(data, Dataset):
+            # Create a new loader from the provided dataset
+            eval_loader = DataLoader(data,
+                                     batch_size=self._batch_size,
+                                     shuffle=False,
+                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                     pin_memory=(self.device.type == "cuda"))
+        else:  # data is None, use the trainer's default validation dataset
+            if self.validation_dataset is None:
+                _LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
+                raise ValueError()
+            eval_loader = DataLoader(self.validation_dataset,
+                                     batch_size=self._batch_size,
+                                     shuffle=False,
+                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                     pin_memory=(self.device.type == "cuda"))
+
+        if eval_loader is None:
+            _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
+            raise ValueError()
+
+        all_preds, _, all_true = [], [], []
+        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
+            if y_pred_b is not None: all_preds.append(y_pred_b)
+            if y_true_b is not None: all_true.append(y_true_b)
+
+        if not all_true:
+            _LOGGER.error("Evaluation failed: No data was processed.")
+            return
+
+        y_pred = np.concatenate(all_preds)
+        y_true = np.concatenate(all_true)
+
+        # --- Routing Logic ---
+        if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+            config = None
+            if format_configuration and isinstance(format_configuration, SequenceValueMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceValueMetricsFormat.")
+
+            sequence_to_value_metrics(y_true=y_true,
+                                      y_pred=y_pred,
+                                      save_dir=save_dir,
+                                      config=config)
+
+        elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+            config = None
+            if format_configuration and isinstance(format_configuration, SequenceSequenceMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceSequenceMetricsFormat.")
+
+            sequence_to_sequence_metrics(y_true=y_true,
+                                         y_pred=y_pred,
+                                         save_dir=save_dir,
+                                         config=config)
+
+    def finalize_model_training(self,
+                                save_dir: Union[str, Path],
+                                filename: str,
+                                last_training_sequence: np.ndarray,
+                                model_checkpoint: Union[Path, Literal['latest', 'current']]):
+        """
+        Saves a finalized, "inference-ready" model state to a .pth file.
+
+        This method saves the model's `state_dict` and the final epoch number.

         Args:
-
+            save_dir (Union[str, Path]): The directory to save the finalized model.
+            filename (str): The desired filename for the model (e.g., "final_model.pth").
+            last_training_sequence (np.ndarray): The last un-scaled sequence from the training data, used for forecasting.
+            model_checkpoint (Union[Path, Literal["latest", "current"]]):
+                - Path: Loads the model state from a specific checkpoint file.
+                - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                - "current": Uses the model's state as it is at the end of the `fit()` call.
         """
-
-
-
+        # handle save path
+        sanitized_filename = sanitize_filename(filename)
+        if not sanitized_filename.endswith(".pth"):
+            sanitized_filename = sanitized_filename + ".pth"
+
+        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+        full_path = dir_path / sanitized_filename
+
+        # handle checkpoint
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+            raise ValueError()
+        elif model_checkpoint == MagicWords.CURRENT:
+            pass
+        else:
+            _LOGGER.error(f"Unknown 'model_checkpoint' parameter received '{model_checkpoint}'.")
+
+        # --- 1. Validate the provided initial sequence ---
+        if not isinstance(last_training_sequence, np.ndarray):
+            _LOGGER.error(f"'last_training_sequence' must be a numpy array. Got {type(last_training_sequence)}")
+            raise TypeError()
+        if last_training_sequence.ndim != 1:
+            _LOGGER.error(f"'last_training_sequence' must be a 1D array. Got {last_training_sequence.ndim} dimensions.")
+            raise ValueError()
+
+        # --- 2. Derive sequence_length from the array ---
+        sequence_length = len(last_training_sequence)
+        if sequence_length <= 0:
+            _LOGGER.error("Length of 'last_training_sequence' cannot be zero.")
+            raise ValueError()
+
+        # Create finalized data
+        finalized_data = {
+            PyTorchCheckpointKeys.EPOCH: self.epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+            PyTorchCheckpointKeys.SEQUENCE_LENGTH: sequence_length,
+            PyTorchCheckpointKeys.INITIAL_SEQUENCE: last_training_sequence
+        }
+
+        torch.save(finalized_data, full_path)
+
+        _LOGGER.info(f"Finalized model weights saved to {full_path}.")

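Self-contained sketch of the initial-sequence validation above (the 48-step window is a made-up value): a 1D float array whose length defines the window the inference side will forecast from:

import numpy as np

last_training_sequence = np.arange(48, dtype=np.float32)  # hypothetical window
assert isinstance(last_training_sequence, np.ndarray)
assert last_training_sequence.ndim == 1 and len(last_training_sequence) > 0
sequence_length = len(last_training_sequence)  # stored under SEQUENCE_LENGTH
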


 def info():
|