dragon-ml-toolbox 14.8.0__py3-none-any.whl → 16.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +9 -5
- dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +204 -11
- ml_tools/ML_datasetmaster.py +198 -280
- ml_tools/ML_evaluation.py +132 -41
- ml_tools/ML_evaluation_multi.py +96 -35
- ml_tools/ML_inference.py +249 -207
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +215 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1237 -354
- ml_tools/ML_utilities.py +1 -1
- ml_tools/ML_vision_datasetmaster.py +73 -67
- ml_tools/ML_vision_evaluation.py +26 -6
- ml_tools/ML_vision_inference.py +117 -140
- ml_tools/ML_vision_models.py +1 -1
- ml_tools/ML_vision_transformers.py +121 -40
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +43 -0 (see the import note after this list)
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +1 -1
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.8.0.dist-info/RECORD +0 -49
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
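
One migration note on the list above: `ml_tools/keys.py` is renamed to the private module `ml_tools/_keys.py` (gaining 43 lines). Downstream code that imported constants from the old public module would need an update along these lines. This is a hypothetical before/after sketch; the key-class names are taken from the imports visible in the `ML_trainer.py` diff below, and their presence in the old `keys` module is an assumption:

    # Before (dragon-ml-toolbox 14.8.0), hypothetical downstream import:
    # from ml_tools.keys import PyTorchLogKeys, PyTorchCheckpointKeys

    # After (16.0.0): the module is now private, signaling it is not a stable public API
    from ml_tools._keys import PyTorchLogKeys, PyTorchCheckpointKeys, MLTaskKeys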
ml_tools/ML_trainer.py
CHANGED
|
@@ -1,80 +1,79 @@
|
|
|
1
|
-
from typing import List, Literal, Union, Optional, Callable, Dict, Any
|
|
1
|
+
from typing import List, Literal, Union, Optional, Callable, Dict, Any
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
from torch.utils.data import DataLoader, Dataset
|
|
4
4
|
import torch
|
|
5
5
|
from torch import nn
|
|
6
6
|
import numpy as np
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
7
8
|
|
|
8
|
-
from .
|
|
9
|
+
from .path_manager import make_fullpath, sanitize_filename
|
|
10
|
+
from .ML_callbacks import _Callback, History, TqdmProgressBar, DragonModelCheckpoint, DragonEarlyStopping, DragonLRScheduler
|
|
9
11
|
from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
|
|
10
12
|
from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
|
|
13
|
+
from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
|
|
14
|
+
from .ML_sequence_evaluation import sequence_to_sequence_metrics, sequence_to_value_metrics
|
|
15
|
+
from .ML_configuration import (ClassificationMetricsFormat,
|
|
16
|
+
MultiClassificationMetricsFormat,
|
|
17
|
+
RegressionMetricsFormat,
|
|
18
|
+
SegmentationMetricsFormat,
|
|
19
|
+
SequenceValueMetricsFormat,
|
|
20
|
+
SequenceSequenceMetricsFormat)
|
|
21
|
+
|
|
11
22
|
from ._script_info import _script_info
|
|
12
|
-
from .
|
|
23
|
+
from ._keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys, MLTaskKeys, MagicWords, DragonTrainerKeys
|
|
13
24
|
from ._logger import _LOGGER
|
|
14
|
-
from .path_manager import make_fullpath
|
|
15
|
-
from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
|
|
16
|
-
from .ML_configuration import ClassificationMetricsFormat, MultiClassificationMetricsFormat
|
|
17
25
|
|
|
18
26
|
|
|
19
27
|
__all__ = [
|
|
20
|
-
"
|
|
21
|
-
"
|
|
28
|
+
"DragonTrainer",
|
|
29
|
+
"DragonDetectionTrainer",
|
|
30
|
+
"DragonSequenceTrainer"
|
|
22
31
|
]
|
|
23
32
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification", "segmentation"],
|
|
28
|
-
criterion: nn.Module, optimizer: torch.optim.Optimizer,
|
|
29
|
-
device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
|
|
30
|
-
"""
|
|
31
|
-
Automates the training process of a PyTorch Model.
|
|
32
|
-
|
|
33
|
-
Built-in Callbacks: `History`, `TqdmProgressBar`
|
|
34
|
-
|
|
35
|
-
Args:
|
|
36
|
-
model (nn.Module): The PyTorch model to train.
|
|
37
|
-
train_dataset (Dataset): The training dataset.
|
|
38
|
-
test_dataset (Dataset): The testing/validation dataset.
|
|
39
|
-
kind (str): Can be 'regression', 'classification', 'multi_target_regression', 'multi_label_classification', or 'segmentation'.
|
|
40
|
-
criterion (nn.Module): The loss function.
|
|
41
|
-
optimizer (torch.optim.Optimizer): The optimizer.
|
|
42
|
-
device (str): The device to run training on ('cpu', 'cuda', 'mps').
|
|
43
|
-
dataloader_workers (int): Subprocesses for data loading.
|
|
44
|
-
callbacks (List[Callback] | None): A list of callbacks to use during training.
|
|
45
|
-
|
|
46
|
-
Note:
|
|
47
|
-
- For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`.
|
|
48
|
-
|
|
49
|
-
- For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice.
|
|
33
|
+
class _BaseDragonTrainer(ABC):
|
|
34
|
+
"""
|
|
35
|
+
Abstract base class for Dragon Trainers.
|
|
50
36
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
37
|
+
Handles the common training loop orchestration, checkpointing, callback
|
|
38
|
+
management, and device handling. Subclasses must implement the
|
|
39
|
+
task-specific logic (dataloaders, train/val steps, evaluation).
|
|
40
|
+
"""
|
|
41
|
+
def __init__(self,
|
|
42
|
+
model: nn.Module,
|
|
43
|
+
optimizer: torch.optim.Optimizer,
|
|
44
|
+
device: Union[Literal['cuda', 'mps', 'cpu'],str],
|
|
45
|
+
dataloader_workers: int = 2,
|
|
46
|
+
checkpoint_callback: Optional[DragonModelCheckpoint] = None,
|
|
47
|
+
early_stopping_callback: Optional[DragonEarlyStopping] = None,
|
|
48
|
+
lr_scheduler_callback: Optional[DragonLRScheduler] = None,
|
|
49
|
+
extra_callbacks: Optional[List[_Callback]] = None):
|
|
57
50
|
|
|
58
51
|
self.model = model
|
|
59
|
-
self.train_dataset = train_dataset
|
|
60
|
-
self.test_dataset = test_dataset
|
|
61
|
-
self.kind = kind
|
|
62
|
-
self.criterion = criterion
|
|
63
52
|
self.optimizer = optimizer
|
|
64
53
|
self.scheduler = None
|
|
65
54
|
self.device = self._validate_device(device)
|
|
66
55
|
self.dataloader_workers = dataloader_workers
|
|
67
56
|
|
|
68
|
-
# Callback handler
|
|
57
|
+
# Callback handler
|
|
69
58
|
default_callbacks = [History(), TqdmProgressBar()]
|
|
70
|
-
|
|
59
|
+
|
|
60
|
+
self._checkpoint_callback = None
|
|
61
|
+
if checkpoint_callback:
|
|
62
|
+
default_callbacks.append(checkpoint_callback)
|
|
63
|
+
self._checkpoint_callback = checkpoint_callback
|
|
64
|
+
if early_stopping_callback:
|
|
65
|
+
default_callbacks.append(early_stopping_callback)
|
|
66
|
+
if lr_scheduler_callback:
|
|
67
|
+
default_callbacks.append(lr_scheduler_callback)
|
|
68
|
+
|
|
69
|
+
user_callbacks = extra_callbacks if extra_callbacks is not None else []
|
|
71
70
|
self.callbacks = default_callbacks + user_callbacks
|
|
72
71
|
self._set_trainer_on_callbacks()
|
|
73
72
|
|
|
74
73
|
# Internal state
|
|
75
|
-
self.train_loader = None
|
|
76
|
-
self.
|
|
77
|
-
self.history = {}
|
|
74
|
+
self.train_loader: Optional[DataLoader] = None
|
|
75
|
+
self.validation_loader: Optional[DataLoader] = None
|
|
76
|
+
self.history: Dict[str, List[Any]] = {}
|
|
78
77
|
self.epoch = 0
|
|
79
78
|
self.epochs = 0 # Total epochs for the fit run
|
|
80
79
|
self.start_epoch = 1
|
|
@@ -97,32 +96,10 @@ class MLTrainer:
|
|
|
97
96
|
for callback in self.callbacks:
|
|
98
97
|
callback.set_trainer(self)
|
|
99
98
|
|
|
100
|
-
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
101
|
-
"""Initializes the DataLoaders."""
|
|
102
|
-
# Ensure stability on MPS devices by setting num_workers to 0
|
|
103
|
-
loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
|
|
104
|
-
|
|
105
|
-
self.train_loader = DataLoader(
|
|
106
|
-
dataset=self.train_dataset,
|
|
107
|
-
batch_size=batch_size,
|
|
108
|
-
shuffle=shuffle,
|
|
109
|
-
num_workers=loader_workers,
|
|
110
|
-
pin_memory=("cuda" in self.device.type),
|
|
111
|
-
drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
|
|
112
|
-
)
|
|
113
|
-
|
|
114
|
-
self.test_loader = DataLoader(
|
|
115
|
-
dataset=self.test_dataset,
|
|
116
|
-
batch_size=batch_size,
|
|
117
|
-
shuffle=False,
|
|
118
|
-
num_workers=loader_workers,
|
|
119
|
-
pin_memory=("cuda" in self.device.type)
|
|
120
|
-
)
|
|
121
|
-
|
|
122
99
|
def _load_checkpoint(self, path: Union[str, Path]):
|
|
123
100
|
"""Loads a training checkpoint to resume training."""
|
|
124
101
|
p = make_fullpath(path, enforce="file")
|
|
125
|
-
_LOGGER.info(f"Loading checkpoint from '{p.name}'
|
|
102
|
+
_LOGGER.info(f"Loading checkpoint from '{p.name}'...")
|
|
126
103
|
|
|
127
104
|
try:
|
|
128
105
|
checkpoint = torch.load(p, map_location=self.device)
|
|
@@ -133,7 +110,16 @@ class MLTrainer:
|
|
|
133
110
|
|
|
134
111
|
self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
|
|
135
112
|
self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
|
|
136
|
-
self.
|
|
113
|
+
self.epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0)
|
|
114
|
+
self.start_epoch = self.epoch + 1 # Resume on the *next* epoch
|
|
115
|
+
|
|
116
|
+
# --- Load History ---
|
|
117
|
+
if PyTorchCheckpointKeys.HISTORY in checkpoint:
|
|
118
|
+
self.history = checkpoint[PyTorchCheckpointKeys.HISTORY]
|
|
119
|
+
_LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
|
|
120
|
+
else:
|
|
121
|
+
_LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
|
|
122
|
+
self.history = {} # Ensure it's at least an empty dict
|
|
137
123
|
|
|
138
124
|
# --- Scheduler State Loading Logic ---
|
|
139
125
|
scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
|
|
@@ -163,7 +149,7 @@ class MLTrainer:
|
|
|
163
149
|
|
|
164
150
|
# Restore callback states
|
|
165
151
|
for cb in self.callbacks:
|
|
166
|
-
if isinstance(cb,
|
|
152
|
+
if isinstance(cb, DragonModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
|
|
167
153
|
cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
|
|
168
154
|
_LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
|
|
169
155
|
|
|
@@ -174,7 +160,8 @@ class MLTrainer:
|
|
|
174
160
|
raise
|
|
175
161
|
|
|
176
162
|
def fit(self,
|
|
177
|
-
|
|
163
|
+
save_dir: Union[str,Path],
|
|
164
|
+
epochs: int = 100,
|
|
178
165
|
batch_size: int = 10,
|
|
179
166
|
shuffle: bool = True,
|
|
180
167
|
resume_from_checkpoint: Optional[Union[str, Path]] = None):
|
|
@@ -184,21 +171,15 @@ class MLTrainer:
|
|
|
184
171
|
Returns the "History" callback dictionary.
|
|
185
172
|
|
|
186
173
|
Args:
|
|
174
|
+
save_dir (str | Path): Directory to save the loss plot.
|
|
187
175
|
epochs (int): The total number of epochs to train for.
|
|
188
176
|
batch_size (int): The number of samples per batch.
|
|
189
177
|
shuffle (bool): Whether to shuffle the training data at each epoch.
|
|
190
178
|
resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
|
|
191
|
-
|
|
192
|
-
Note:
|
|
193
|
-
For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
|
|
194
|
-
automatically aligns the model's output tensor with the target tensor's
|
|
195
|
-
shape using `output.view_as(target)`. This handles the common case
|
|
196
|
-
where a model outputs a shape of `[batch_size, 1]` and the target has a
|
|
197
|
-
shape of `[batch_size]`.
|
|
198
179
|
"""
|
|
199
180
|
self.epochs = epochs
|
|
200
181
|
self._batch_size = batch_size
|
|
201
|
-
self._create_dataloaders(self._batch_size, shuffle)
|
|
182
|
+
self._create_dataloaders(self._batch_size, shuffle) # type: ignore
|
|
202
183
|
self.model.to(self.device)
|
|
203
184
|
|
|
204
185
|
if resume_from_checkpoint:
|
|
@@ -209,11 +190,19 @@ class MLTrainer:
|
|
|
209
190
|
|
|
210
191
|
self._callbacks_hook('on_train_begin')
|
|
211
192
|
|
|
193
|
+
if not self.train_loader:
|
|
194
|
+
_LOGGER.error("Train loader is not initialized.")
|
|
195
|
+
raise ValueError()
|
|
196
|
+
|
|
197
|
+
if not self.validation_loader:
|
|
198
|
+
_LOGGER.error("Validation loader is not initialized.")
|
|
199
|
+
raise ValueError()
|
|
200
|
+
|
|
212
201
|
for epoch in range(self.start_epoch, self.epochs + 1):
|
|
213
202
|
self.epoch = epoch
|
|
214
|
-
epoch_logs = {}
|
|
203
|
+
epoch_logs: Dict[str, Any] = {}
|
|
215
204
|
self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
|
|
216
|
-
|
|
205
|
+
|
|
217
206
|
train_logs = self._train_step()
|
|
218
207
|
epoch_logs.update(train_logs)
|
|
219
208
|
|
|
@@ -227,11 +216,185 @@ class MLTrainer:
|
|
|
227
216
|
break
|
|
228
217
|
|
|
229
218
|
self._callbacks_hook('on_train_end')
|
|
219
|
+
|
|
220
|
+
# Training History
|
|
221
|
+
plot_losses(self.history, save_dir=save_dir)
|
|
222
|
+
|
|
230
223
|
return self.history
|
|
224
|
+
|
|
225
|
+
def _callbacks_hook(self, method_name: str, *args, **kwargs):
|
|
226
|
+
"""Calls the specified method on all callbacks."""
|
|
227
|
+
for callback in self.callbacks:
|
|
228
|
+
method = getattr(callback, method_name)
|
|
229
|
+
method(*args, **kwargs)
|
|
230
|
+
|
|
231
|
+
def to_cpu(self):
|
|
232
|
+
"""
|
|
233
|
+
Moves the model to the CPU and updates the trainer's device setting.
|
|
234
|
+
|
|
235
|
+
This is useful for running operations that require the CPU.
|
|
236
|
+
"""
|
|
237
|
+
self.device = torch.device('cpu')
|
|
238
|
+
self.model.to(self.device)
|
|
239
|
+
_LOGGER.info("Trainer and model moved to CPU.")
|
|
240
|
+
|
|
241
|
+
def to_device(self, device: str):
|
|
242
|
+
"""
|
|
243
|
+
Moves the model to the specified device and updates the trainer's device setting.
|
|
244
|
+
|
|
245
|
+
Args:
|
|
246
|
+
device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
|
|
247
|
+
"""
|
|
248
|
+
self.device = self._validate_device(device)
|
|
249
|
+
self.model.to(self.device)
|
|
250
|
+
_LOGGER.info(f"Trainer and model moved to {self.device}.")
|
|
251
|
+
|
|
252
|
+
# --- Abstract Methods ---
|
|
253
|
+
# These must be implemented by subclasses
|
|
254
|
+
|
|
255
|
+
@abstractmethod
|
|
256
|
+
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
257
|
+
"""Initializes the DataLoaders."""
|
|
258
|
+
raise NotImplementedError
|
|
259
|
+
|
|
260
|
+
@abstractmethod
|
|
261
|
+
def _train_step(self) -> Dict[str, float]:
|
|
262
|
+
"""Runs a single training epoch."""
|
|
263
|
+
raise NotImplementedError
|
|
264
|
+
|
|
265
|
+
@abstractmethod
|
|
266
|
+
def _validation_step(self) -> Dict[str, float]:
|
|
267
|
+
"""Runs a single validation epoch."""
|
|
268
|
+
raise NotImplementedError
|
|
269
|
+
|
|
270
|
+
@abstractmethod
|
|
271
|
+
def evaluate(self, *args, **kwargs):
|
|
272
|
+
"""Runs the full model evaluation."""
|
|
273
|
+
raise NotImplementedError
|
|
274
|
+
|
|
275
|
+
@abstractmethod
|
|
276
|
+
def _evaluate(self, *args, **kwargs):
|
|
277
|
+
"""Internal evaluation helper."""
|
|
278
|
+
raise NotImplementedError
|
|
279
|
+
|
|
280
|
+
@abstractmethod
|
|
281
|
+
def finalize_model_training(self, *args, **kwargs):
|
|
282
|
+
"""Saves the finalized model for inference."""
|
|
283
|
+
raise NotImplementedError
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
# --- DragonTrainer ----
|
|
287
|
+
class DragonTrainer(_BaseDragonTrainer):
|
|
288
|
+
def __init__(self,
|
|
289
|
+
model: nn.Module,
|
|
290
|
+
train_dataset: Dataset,
|
|
291
|
+
validation_dataset: Dataset,
|
|
292
|
+
kind: Literal["regression", "binary classification", "multiclass classification",
|
|
293
|
+
"multitarget regression", "multilabel binary classification",
|
|
294
|
+
"binary segmentation", "multiclass segmentation", "binary image classification", "multiclass image classification"],
|
|
295
|
+
optimizer: torch.optim.Optimizer,
|
|
296
|
+
device: Union[Literal['cuda', 'mps', 'cpu'],str],
|
|
297
|
+
checkpoint_callback: Optional[DragonModelCheckpoint],
|
|
298
|
+
early_stopping_callback: Optional[DragonEarlyStopping],
|
|
299
|
+
lr_scheduler_callback: Optional[DragonLRScheduler],
|
|
300
|
+
extra_callbacks: Optional[List[_Callback]] = None,
|
|
301
|
+
criterion: Union[nn.Module,Literal["auto"]] = "auto",
|
|
302
|
+
dataloader_workers: int = 2):
|
|
303
|
+
"""
|
|
304
|
+
Automates the training process of a PyTorch Model.
|
|
305
|
+
|
|
306
|
+
Built-in Callbacks: `History`, `TqdmProgressBar`
|
|
307
|
+
|
|
308
|
+
Args:
|
|
309
|
+
model (nn.Module): The PyTorch model to train.
|
|
310
|
+
train_dataset (Dataset): The training dataset.
|
|
311
|
+
validation_dataset (Dataset): The validation dataset.
|
|
312
|
+
kind (str): Used to redirect to the correct process.
|
|
313
|
+
criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task
|
|
314
|
+
optimizer (torch.optim.Optimizer): The optimizer.
|
|
315
|
+
device (str): The device to run training on ('cpu', 'cuda', 'mps').
|
|
316
|
+
dataloader_workers (int): Subprocesses for data loading.
|
|
317
|
+
extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
|
|
318
|
+
|
|
319
|
+
Note:
|
|
320
|
+
- For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`. The model should output as many logits as existing targets.
|
|
321
|
+
|
|
322
|
+
- For **single-label, binary classification**, `nn.BCEWithLogitsLoss` is the standard choice. The model should output a single logit.
|
|
231
323
|
|
|
324
|
+
- For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice. The model should output as many logits as existing classes.
|
|
325
|
+
|
|
326
|
+
- For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem. The model should output 1 logit per binary target.
|
|
327
|
+
|
|
328
|
+
- For **binary segmentation** tasks, `nn.BCEWithLogitsLoss` is common. The model should output a single logit.
|
|
329
|
+
|
|
330
|
+
- for **multiclass segmentation** tasks, `nn.CrossEntropyLoss` is the standard. The model should output as many logits as existing classes.
|
|
331
|
+
"""
|
|
332
|
+
# Call the base class constructor with common parameters
|
|
333
|
+
super().__init__(
|
|
334
|
+
model=model,
|
|
335
|
+
optimizer=optimizer,
|
|
336
|
+
device=device,
|
|
337
|
+
dataloader_workers=dataloader_workers,
|
|
338
|
+
checkpoint_callback=checkpoint_callback,
|
|
339
|
+
early_stopping_callback=early_stopping_callback,
|
|
340
|
+
lr_scheduler_callback=lr_scheduler_callback,
|
|
341
|
+
extra_callbacks=extra_callbacks
|
|
342
|
+
)
|
|
343
|
+
|
|
344
|
+
if kind not in [MLTaskKeys.REGRESSION,
|
|
345
|
+
MLTaskKeys.BINARY_CLASSIFICATION,
|
|
346
|
+
MLTaskKeys.MULTICLASS_CLASSIFICATION,
|
|
347
|
+
MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION,
|
|
348
|
+
MLTaskKeys.MULTITARGET_REGRESSION,
|
|
349
|
+
MLTaskKeys.BINARY_SEGMENTATION,
|
|
350
|
+
MLTaskKeys.MULTICLASS_SEGMENTATION,
|
|
351
|
+
MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
|
|
352
|
+
MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
|
|
353
|
+
raise ValueError(f"'{kind}' is not a valid task type.")
|
|
354
|
+
|
|
355
|
+
self.train_dataset = train_dataset
|
|
356
|
+
self.validation_dataset = validation_dataset
|
|
357
|
+
self.kind = kind
|
|
358
|
+
self._classification_threshold: float = 0.5
|
|
359
|
+
|
|
360
|
+
# loss function
|
|
361
|
+
if criterion == "auto":
|
|
362
|
+
if kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
|
|
363
|
+
self.criterion = nn.MSELoss()
|
|
364
|
+
elif kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
|
|
365
|
+
self.criterion = nn.BCEWithLogitsLoss()
|
|
366
|
+
elif kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
|
|
367
|
+
self.criterion = nn.CrossEntropyLoss()
|
|
368
|
+
else:
|
|
369
|
+
self.criterion = criterion
|
|
370
|
+
|
|
371
|
+
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
372
|
+
"""Initializes the DataLoaders."""
|
|
373
|
+
# Ensure stability on MPS devices by setting num_workers to 0
|
|
374
|
+
loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
|
|
375
|
+
|
|
376
|
+
self.train_loader = DataLoader(
|
|
377
|
+
dataset=self.train_dataset,
|
|
378
|
+
batch_size=batch_size,
|
|
379
|
+
shuffle=shuffle,
|
|
380
|
+
num_workers=loader_workers,
|
|
381
|
+
pin_memory=("cuda" in self.device.type),
|
|
382
|
+
drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
|
|
383
|
+
)
|
|
384
|
+
|
|
385
|
+
self.validation_loader = DataLoader(
|
|
386
|
+
dataset=self.validation_dataset,
|
|
387
|
+
batch_size=batch_size,
|
|
388
|
+
shuffle=False,
|
|
389
|
+
num_workers=loader_workers,
|
|
390
|
+
pin_memory=("cuda" in self.device.type)
|
|
391
|
+
)
|
|
392
|
+
|
|
232
393
|
def _train_step(self):
|
|
233
394
|
self.model.train()
|
|
234
395
|
running_loss = 0.0
|
|
396
|
+
total_samples = 0
|
|
397
|
+
|
|
235
398
|
for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
|
|
236
399
|
# Create a log dictionary for the batch
|
|
237
400
|
batch_logs = {
|
|
@@ -245,9 +408,21 @@ class MLTrainer:
|
|
|
245
408
|
|
|
246
409
|
output = self.model(features)
|
|
247
410
|
|
|
248
|
-
#
|
|
249
|
-
|
|
250
|
-
|
|
411
|
+
# --- Label Type/Shape Correction ---
|
|
412
|
+
# Cast target to float for BCE-based losses
|
|
413
|
+
if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
|
|
414
|
+
target = target.float()
|
|
415
|
+
|
|
416
|
+
# Reshape output to match target for single-logit tasks
|
|
417
|
+
if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
|
|
418
|
+
# If model outputs [N, 1] and target is [N], squeeze output
|
|
419
|
+
if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
|
|
420
|
+
output = output.squeeze(1)
|
|
421
|
+
|
|
422
|
+
if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
|
|
423
|
+
# If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
|
|
424
|
+
if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
|
|
425
|
+
output = output.squeeze(1)
|
|
251
426
|
|
|
252
427
|
loss = self.criterion(output, target)
|
|
253
428
|
|
|
@@ -256,34 +431,58 @@ class MLTrainer:
|
|
|
256
431
|
|
|
257
432
|
# Calculate batch loss and update running loss for the epoch
|
|
258
433
|
batch_loss = loss.item()
|
|
259
|
-
|
|
434
|
+
batch_size = features.size(0)
|
|
435
|
+
running_loss += batch_loss * batch_size # Accumulate total loss
|
|
436
|
+
total_samples += batch_size # total samples
|
|
260
437
|
|
|
261
438
|
# Add the batch loss to the logs and call the end-of-batch hook
|
|
262
439
|
batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
|
|
263
440
|
self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
|
|
441
|
+
|
|
442
|
+
if total_samples == 0:
|
|
443
|
+
_LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
|
|
444
|
+
return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
|
|
264
445
|
|
|
265
|
-
return {PyTorchLogKeys.TRAIN_LOSS: running_loss /
|
|
446
|
+
return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples} # type: ignore
|
|
266
447
|
|
|
267
448
|
def _validation_step(self):
|
|
268
449
|
self.model.eval()
|
|
269
450
|
running_loss = 0.0
|
|
451
|
+
|
|
270
452
|
with torch.no_grad():
|
|
271
|
-
for features, target in self.
|
|
453
|
+
for features, target in self.validation_loader: # type: ignore
|
|
272
454
|
features, target = features.to(self.device), target.to(self.device)
|
|
273
455
|
|
|
274
456
|
output = self.model(features)
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
457
|
+
|
|
458
|
+
# --- Label Type/Shape Correction ---
|
|
459
|
+
# Cast target to float for BCE-based losses
|
|
460
|
+
if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
|
|
461
|
+
target = target.float()
|
|
462
|
+
|
|
463
|
+
# Reshape output to match target for single-logit tasks
|
|
464
|
+
if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
|
|
465
|
+
# If model outputs [N, 1] and target is [N], squeeze output
|
|
466
|
+
if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
|
|
467
|
+
output = output.squeeze(1)
|
|
468
|
+
|
|
469
|
+
if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
|
|
470
|
+
# If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
|
|
471
|
+
if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
|
|
472
|
+
output = output.squeeze(1)
|
|
278
473
|
|
|
279
474
|
loss = self.criterion(output, target)
|
|
280
475
|
|
|
281
476
|
running_loss += loss.item() * features.size(0)
|
|
477
|
+
|
|
478
|
+
if not self.validation_loader.dataset: # type: ignore
|
|
479
|
+
_LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
|
|
480
|
+
return {PyTorchLogKeys.VAL_LOSS: 0.0}
|
|
282
481
|
|
|
283
|
-
logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.
|
|
482
|
+
logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)} # type: ignore
|
|
284
483
|
return logs
|
|
285
484
|
|
|
286
|
-
def _predict_for_eval(self, dataloader: DataLoader
|
|
485
|
+
def _predict_for_eval(self, dataloader: DataLoader):
|
|
287
486
|
"""
|
|
288
487
|
Private method to yield model predictions batch by batch for evaluation.
|
|
289
488
|
|
|
@@ -294,6 +493,7 @@ class MLTrainer:
|
|
|
294
493
|
"""
|
|
295
494
|
self.model.eval()
|
|
296
495
|
self.model.to(self.device)
|
|
496
|
+
|
|
297
497
|
with torch.no_grad():
|
|
298
498
|
for features, target in dataloader:
|
|
299
499
|
features = features.to(self.device)
|
|
@@ -303,25 +503,64 @@ class MLTrainer:
|
|
|
303
503
|
y_prob_batch = None
|
|
304
504
|
y_true_batch = None
|
|
305
505
|
|
|
306
|
-
if self.kind in [
|
|
506
|
+
if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
|
|
307
507
|
y_pred_batch = output.numpy()
|
|
308
508
|
y_true_batch = target.numpy()
|
|
509
|
+
|
|
510
|
+
elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
|
|
511
|
+
# Assumes model output is [N, 1] (a single logit)
|
|
512
|
+
# Squeeze output from [N, 1] to [N] if necessary
|
|
513
|
+
if output.ndim == 2 and output.shape[1] == 1:
|
|
514
|
+
output = output.squeeze(1)
|
|
515
|
+
|
|
516
|
+
probs_pos = torch.sigmoid(output) # Probability of positive class
|
|
517
|
+
preds = (probs_pos >= self._classification_threshold).int()
|
|
518
|
+
y_pred_batch = preds.numpy()
|
|
519
|
+
# For metrics (like ROC AUC), we often need probs for *both* classes
|
|
520
|
+
# Create an [N, 2] array: [prob_class_0, prob_class_1]
|
|
521
|
+
probs_neg = 1.0 - probs_pos
|
|
522
|
+
y_prob_batch = torch.stack([probs_neg, probs_pos], dim=1).numpy()
|
|
523
|
+
y_true_batch = target.numpy()
|
|
309
524
|
|
|
310
|
-
elif self.kind
|
|
525
|
+
elif self.kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
|
|
526
|
+
num_classes = output.shape[1]
|
|
527
|
+
if num_classes < 3:
|
|
528
|
+
# Optional: warn the user they are using the wrong kind
|
|
529
|
+
wrong_class = MLTaskKeys.MULTICLASS_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
|
|
530
|
+
recommended_class = MLTaskKeys.BINARY_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.BINARY_IMAGE_CLASSIFICATION
|
|
531
|
+
_LOGGER.warning(f"'{wrong_class}' kind used with {num_classes} classes. Consider using '{recommended_class}' instead.")
|
|
532
|
+
|
|
311
533
|
probs = torch.softmax(output, dim=1)
|
|
312
534
|
preds = torch.argmax(probs, dim=1)
|
|
313
535
|
y_pred_batch = preds.numpy()
|
|
314
536
|
y_prob_batch = probs.numpy()
|
|
315
537
|
y_true_batch = target.numpy()
|
|
316
538
|
|
|
317
|
-
elif self.kind ==
|
|
539
|
+
elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
|
|
318
540
|
probs = torch.sigmoid(output)
|
|
319
|
-
preds = (probs >=
|
|
541
|
+
preds = (probs >= self._classification_threshold).int()
|
|
320
542
|
y_pred_batch = preds.numpy()
|
|
321
543
|
y_prob_batch = probs.numpy()
|
|
322
544
|
y_true_batch = target.numpy()
|
|
545
|
+
|
|
546
|
+
elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
|
|
547
|
+
# Assumes model output is [N, 1, H, W] (logits for positive class)
|
|
548
|
+
probs_pos = torch.sigmoid(output) # Shape [N, 1, H, W]
|
|
549
|
+
preds = (probs_pos >= self._classification_threshold).int() # Shape [N, 1, H, W]
|
|
550
|
+
|
|
551
|
+
# Squeeze preds to [N, H, W] (class indices 0 or 1)
|
|
552
|
+
y_pred_batch = preds.squeeze(1).numpy()
|
|
553
|
+
|
|
554
|
+
# Create [N, 2, H, W] probs for consistency
|
|
555
|
+
probs_neg = 1.0 - probs_pos
|
|
556
|
+
y_prob_batch = torch.cat([probs_neg, probs_pos], dim=1).numpy()
|
|
557
|
+
|
|
558
|
+
# Handle target shape [N, 1, H, W] -> [N, H, W]
|
|
559
|
+
if target.ndim == 4 and target.shape[1] == 1:
|
|
560
|
+
target = target.squeeze(1)
|
|
561
|
+
y_true_batch = target.numpy()
|
|
323
562
|
|
|
324
|
-
elif self.kind ==
|
|
563
|
+
elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION:
|
|
325
564
|
# output shape [N, C, H, W]
|
|
326
565
|
probs = torch.softmax(output, dim=1)
|
|
327
566
|
preds = torch.argmax(probs, dim=1) # shape [N, H, W]
|
|
@@ -334,26 +573,161 @@ class MLTrainer:
|
|
|
334
573
|
y_true_batch = target.numpy()
|
|
335
574
|
|
|
336
575
|
yield y_pred_batch, y_prob_batch, y_true_batch
|
|
337
|
-
|
|
576
|
+
|
|
338
577
|
def evaluate(self,
|
|
339
578
|
save_dir: Union[str, Path],
|
|
340
|
-
|
|
341
|
-
|
|
579
|
+
model_checkpoint: Union[Path, Literal["latest", "current"]],
|
|
580
|
+
classification_threshold: Optional[float] = None,
|
|
581
|
+
test_data: Optional[Union[DataLoader, Dataset]] = None,
|
|
582
|
+
val_format_configuration: Optional[Union[ClassificationMetricsFormat,
|
|
583
|
+
MultiClassificationMetricsFormat,
|
|
584
|
+
RegressionMetricsFormat,
|
|
585
|
+
SegmentationMetricsFormat]]=None,
|
|
586
|
+
test_format_configuration: Optional[Union[ClassificationMetricsFormat,
|
|
587
|
+
MultiClassificationMetricsFormat,
|
|
588
|
+
RegressionMetricsFormat,
|
|
589
|
+
SegmentationMetricsFormat]]=None):
|
|
342
590
|
"""
|
|
343
591
|
Evaluates the model, routing to the correct evaluation function based on task `kind`.
|
|
344
592
|
|
|
345
593
|
Args:
|
|
594
|
+
model_checkpoint ('auto' | Path | None):
|
|
595
|
+
- Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
|
|
596
|
+
- If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
|
|
597
|
+
- If 'current', use the current state of the trained model up the latest trained epoch.
|
|
346
598
|
save_dir (str | Path): Directory to save all reports and plots.
|
|
347
|
-
|
|
599
|
+
classification_threshold (float | None): Used for tasks using a binary approach (binary classification, binary segmentation, multilabel binary classification)
|
|
600
|
+
test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
|
|
601
|
+
val_format_configuration: Optional configuration for metric format output for the validation set.
|
|
602
|
+
test_format_configuration: Optional configuration for metric format output for the test set.
|
|
603
|
+
"""
|
|
604
|
+
# Validate model checkpoint
|
|
605
|
+
if isinstance(model_checkpoint, Path):
|
|
606
|
+
checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
|
|
607
|
+
elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
|
|
608
|
+
checkpoint_validated = model_checkpoint
|
|
609
|
+
else:
|
|
610
|
+
_LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
|
|
611
|
+
raise ValueError()
|
|
612
|
+
|
|
613
|
+
# Validate classification threshold
|
|
614
|
+
if self.kind not in MLTaskKeys.ALL_BINARY_TASKS:
|
|
615
|
+
# dummy value for tasks that do not need it
|
|
616
|
+
threshold_validated = 0.5
|
|
617
|
+
elif classification_threshold is None:
|
|
618
|
+
# it should have been provided for binary tasks
|
|
619
|
+
_LOGGER.error(f"The classification threshold must be provided for '{self.kind}'.")
|
|
620
|
+
raise ValueError()
|
|
621
|
+
elif classification_threshold <= 0.0 or classification_threshold >= 1.0:
|
|
622
|
+
# Invalid float
|
|
623
|
+
_LOGGER.error(f"A classification threshold of {classification_threshold} is invalid. Must be in the range (0.0 - 1.0).")
|
|
624
|
+
raise ValueError()
|
|
625
|
+
else:
|
|
626
|
+
threshold_validated = classification_threshold
|
|
627
|
+
|
|
628
|
+
# Validate val configuration
|
|
629
|
+
if val_format_configuration is not None:
|
|
630
|
+
if not isinstance(val_format_configuration, (ClassificationMetricsFormat,
|
|
631
|
+
MultiClassificationMetricsFormat,
|
|
632
|
+
RegressionMetricsFormat,
|
|
633
|
+
SegmentationMetricsFormat)):
|
|
634
|
+
_LOGGER.error(f"Invalid 'format_configuration': '{type(val_format_configuration)}'.")
|
|
635
|
+
raise ValueError()
|
|
636
|
+
else:
|
|
637
|
+
val_configuration_validated = val_format_configuration
|
|
638
|
+
else: # config is None
|
|
639
|
+
val_configuration_validated = None
|
|
640
|
+
|
|
641
|
+
# Validate directory
|
|
642
|
+
save_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
643
|
+
|
|
644
|
+
# Validate test data and dispatch
|
|
645
|
+
if test_data is not None:
|
|
646
|
+
if not isinstance(test_data, (DataLoader, Dataset)):
|
|
647
|
+
_LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
|
|
648
|
+
raise ValueError()
|
|
649
|
+
test_data_validated = test_data
|
|
650
|
+
|
|
651
|
+
validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
|
|
652
|
+
test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
|
|
653
|
+
|
|
654
|
+
# Dispatch validation set
|
|
655
|
+
_LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
|
|
656
|
+
self._evaluate(save_dir=validation_metrics_path,
|
|
657
|
+
model_checkpoint=checkpoint_validated,
|
|
658
|
+
classification_threshold=threshold_validated,
|
|
659
|
+
data=None,
|
|
660
|
+
format_configuration=val_configuration_validated)
|
|
661
|
+
|
|
662
|
+
# Validate test configuration
|
|
663
|
+
if test_format_configuration is not None:
|
|
664
|
+
if not isinstance(test_format_configuration, (ClassificationMetricsFormat,
|
|
665
|
+
MultiClassificationMetricsFormat,
|
|
666
|
+
RegressionMetricsFormat,
|
|
667
|
+
SegmentationMetricsFormat)):
|
|
668
|
+
warning_message_type = f"Invalid test_format_configuration': '{type(val_format_configuration)}'."
|
|
669
|
+
if val_configuration_validated is not None:
|
|
670
|
+
warning_message_type += " 'val_format_configuration' will be used for the test set metrics output."
|
|
671
|
+
test_configuration_validated = val_configuration_validated
|
|
672
|
+
else:
|
|
673
|
+
warning_message_type += " Using default format."
|
|
674
|
+
test_configuration_validated = None
|
|
675
|
+
_LOGGER.warning(warning_message_type)
|
|
676
|
+
else:
|
|
677
|
+
test_configuration_validated = test_format_configuration
|
|
678
|
+
else: #config is None
|
|
679
|
+
test_configuration_validated = None
|
|
680
|
+
|
|
681
|
+
# Dispatch test set
|
|
682
|
+
_LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
|
|
683
|
+
self._evaluate(save_dir=test_metrics_path,
|
|
684
|
+
model_checkpoint="current",
|
|
685
|
+
classification_threshold=threshold_validated,
|
|
686
|
+
data=test_data_validated,
|
|
687
|
+
format_configuration=test_configuration_validated)
|
|
688
|
+
else:
|
|
689
|
+
# Dispatch validation set
|
|
690
|
+
_LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
|
|
691
|
+
self._evaluate(save_dir=save_path,
|
|
692
|
+
model_checkpoint=checkpoint_validated,
|
|
693
|
+
classification_threshold=threshold_validated,
|
|
694
|
+
data=None,
|
|
695
|
+
format_configuration=val_configuration_validated)
|
|
696
|
+
|
|
697
|
+
def _evaluate(self,
|
|
698
|
+
save_dir: Union[str, Path],
|
|
699
|
+
model_checkpoint: Union[Path, Literal["latest", "current"]],
|
|
700
|
+
classification_threshold: float,
|
|
701
|
+
data: Optional[Union[DataLoader, Dataset]],
|
|
702
|
+
format_configuration: Optional[Union[ClassificationMetricsFormat,
|
|
703
|
+
MultiClassificationMetricsFormat,
|
|
704
|
+
RegressionMetricsFormat,
|
|
705
|
+
SegmentationMetricsFormat]]):
|
|
706
|
+
"""
|
|
707
|
+
Changed to a private helper function.
|
|
348
708
|
"""
|
|
349
709
|
dataset_for_names = None
|
|
350
710
|
eval_loader = None
|
|
351
|
-
|
|
711
|
+
|
|
712
|
+
# set threshold
|
|
713
|
+
self._classification_threshold = classification_threshold
|
|
714
|
+
|
|
715
|
+
# load model checkpoint
|
|
716
|
+
if isinstance(model_checkpoint, Path):
|
|
717
|
+
self._load_checkpoint(path=model_checkpoint)
|
|
718
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
|
|
719
|
+
path_to_latest = self._checkpoint_callback.best_checkpoint_path
|
|
720
|
+
self._load_checkpoint(path_to_latest)
|
|
721
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
|
|
722
|
+
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
|
|
723
|
+
raise ValueError()
|
|
724
|
+
|
|
725
|
+
# Dataloader
|
|
352
726
|
if isinstance(data, DataLoader):
|
|
353
727
|
eval_loader = data
|
|
354
728
|
# Try to get the dataset from the loader for fetching target names
|
|
355
729
|
if hasattr(data, 'dataset'):
|
|
356
|
-
dataset_for_names = data.dataset
|
|
730
|
+
dataset_for_names = data.dataset # type: ignore
|
|
357
731
|
elif isinstance(data, Dataset):
|
|
358
732
|
# Create a new loader from the provided dataset
|
|
359
733
|
eval_loader = DataLoader(data,
|
|
@@ -363,17 +737,17 @@ class MLTrainer:
|
|
|
363
737
|
pin_memory=(self.device.type == "cuda"))
|
|
364
738
|
dataset_for_names = data
|
|
365
739
|
else: # data is None, use the trainer's default test dataset
|
|
366
|
-
if self.
|
|
740
|
+
if self.validation_dataset is None:
|
|
367
741
|
_LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
|
|
368
742
|
raise ValueError()
|
|
369
743
|
# Create a fresh DataLoader from the test_dataset
|
|
370
|
-
eval_loader = DataLoader(self.
|
|
744
|
+
eval_loader = DataLoader(self.validation_dataset,
|
|
371
745
|
batch_size=self._batch_size,
|
|
372
746
|
shuffle=False,
|
|
373
747
|
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
374
748
|
pin_memory=(self.device.type == "cuda"))
|
|
375
749
|
|
|
376
|
-
dataset_for_names = self.
|
|
750
|
+
dataset_for_names = self.validation_dataset
|
|
377
751
|
|
|
378
752
|
if eval_loader is None:
|
|
379
753
|
_LOGGER.error("Cannot evaluate. No valid data was provided or found.")
|
|
@@ -396,34 +770,55 @@ class MLTrainer:
|
|
|
396
770
|
y_prob = np.concatenate(all_probs) if all_probs else None
|
|
397
771
|
|
|
398
772
|
# --- Routing Logic ---
|
|
399
|
-
if self.kind ==
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
773
|
+
if self.kind == MLTaskKeys.REGRESSION:
|
|
774
|
+
# Check configuration
|
|
775
|
+
config = None
|
|
776
|
+
if format_configuration and isinstance(format_configuration, RegressionMetricsFormat):
|
|
777
|
+
config = format_configuration
|
|
778
|
+
elif format_configuration:
|
|
779
|
+
_LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected RegressionMetricsFormat.")
|
|
780
|
+
|
|
781
|
+
regression_metrics(y_true=y_true.flatten(),
|
|
782
|
+
y_pred=y_pred.flatten(),
|
|
783
|
+
save_dir=save_dir,
|
|
784
|
+
config=config)
|
|
785
|
+
|
|
786
|
+
elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
|
|
787
|
+
# Check configuration
|
|
788
|
+
config = None
|
|
404
789
|
if format_configuration and isinstance(format_configuration, ClassificationMetricsFormat):
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
elif self.kind == "multi_target_regression":
|
|
790
|
+
config = format_configuration
|
|
791
|
+
elif format_configuration:
|
|
792
|
+
_LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected ClassificationMetricsFormat.")
|
|
793
|
+
|
|
794
|
+
classification_metrics(save_dir=save_dir,
|
|
795
|
+
y_true=y_true,
|
|
796
|
+
y_pred=y_pred,
|
|
797
|
+
y_prob=y_prob,
|
|
798
|
+
config=config)
|
|
799
|
+
|
|
800
|
+
elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION:
|
|
418
801
|
try:
|
|
419
802
|
target_names = dataset_for_names.target_names # type: ignore
|
|
420
803
|
except AttributeError:
|
|
421
804
|
num_targets = y_true.shape[1]
|
|
422
805
|
target_names = [f"target_{i}" for i in range(num_targets)]
|
|
423
806
|
_LOGGER.warning(f"Dataset has no 'target_names' attribute. Using generic names.")
|
|
424
|
-
|
|
807
|
+
|
|
808
|
+
# Check configuration
|
|
809
|
+
config = None
|
|
810
|
+
if format_configuration and isinstance(format_configuration, RegressionMetricsFormat):
|
|
811
|
+
config = format_configuration
|
|
812
|
+
elif format_configuration:
|
|
813
|
+
_LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected RegressionMetricsFormat.")
|
|
814
|
+
|
|
815
|
+
multi_target_regression_metrics(y_true=y_true,
|
|
816
|
+
y_pred=y_pred,
|
|
817
|
+
target_names=target_names,
|
|
818
|
+
save_dir=save_dir,
|
|
819
|
+
config=config)
|
|
425
820
|
|
|
426
|
-
elif self.kind ==
|
|
821
|
+
elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
|
|
427
822
|
try:
|
|
428
823
|
target_names = dataset_for_names.target_names # type: ignore
|
|
429
824
|
except AttributeError:
|
|
@@ -435,19 +830,21 @@ class MLTrainer:
|
|
|
435
830
|
_LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
|
|
436
831
|
return
|
|
437
832
|
|
|
833
|
+
# Check configuration
|
|
834
|
+
config = None
|
|
438
835
|
if format_configuration and isinstance(format_configuration, MultiClassificationMetricsFormat):
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
836
|
+
config = format_configuration
|
|
837
|
+
elif format_configuration:
|
|
838
|
+
_LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected MultiClassificationMetricsFormat.")
|
|
839
|
+
|
|
840
|
+
multi_label_classification_metrics(y_true=y_true,
|
|
841
|
+
y_pred=y_pred,
|
|
842
|
+
y_prob=y_prob,
|
|
843
|
+
target_names=target_names,
|
|
844
|
+
save_dir=save_dir,
|
|
845
|
+
config=config)
|
|
449
846
|
|
|
450
|
-
elif self.kind
|
|
847
|
+
elif self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
|
|
451
848
|
class_names = None
|
|
452
849
|
try:
|
|
453
850
|
# Try to get 'classes' from VisionDatasetMaker
|
|
@@ -469,10 +866,18 @@ class MLTrainer:
|
|
|
469
866
|
class_names = [f"Class {i}" for i in labels]
|
|
470
867
|
_LOGGER.warning(f"Dataset has no 'classes' or 'target_names' attribute. Using generic names.")
|
|
471
868
|
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
869
|
+
# Check configuration
|
|
870
|
+
config = None
|
|
871
|
+
if format_configuration and isinstance(format_configuration, SegmentationMetricsFormat):
|
|
872
|
+
config = format_configuration
|
|
873
|
+
elif format_configuration:
|
|
874
|
+
_LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected SegmentationMetricsFormat.")
|
|
875
|
+
|
|
876
|
+
segmentation_metrics(y_true=y_true,
|
|
877
|
+
y_pred=y_pred,
|
|
878
|
+
save_dir=save_dir,
|
|
879
|
+
class_names=class_names,
|
|
880
|
+
config=config)
|
|
476
881
|
|
|
477
882
|
def explain(self,
|
|
478
883
|
save_dir: Union[str,Path],
|
|
@@ -537,7 +942,7 @@ class MLTrainer:
|
|
|
537
942
|
return
|
|
538
943
|
|
|
539
944
|
# 2. Determine target dataset and get explanation instances
|
|
540
|
-
target_dataset = explain_dataset if explain_dataset is not None else self.
|
|
945
|
+
target_dataset = explain_dataset if explain_dataset is not None else self.validation_dataset
|
|
541
946
|
instances_to_explain = _get_random_sample(target_dataset, n_samples)
|
|
542
947
|
if instances_to_explain is None:
|
|
543
948
|
_LOGGER.error("Explanation dataset is empty or invalid. Skipping SHAP analysis.")
|
|
@@ -556,7 +961,7 @@ class MLTrainer:
|
|
|
556
961
|
self.model.to(self.device)
|
|
557
962
|
|
|
558
963
|
# 3. Call the plotting function
|
|
559
|
-
if self.kind in [
|
|
964
|
+
if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
|
|
560
965
|
shap_summary_plot(
|
|
561
966
|
model=self.model,
|
|
562
967
|
background_data=background_data,
|
|
@@ -566,7 +971,7 @@ class MLTrainer:
|
|
|
566
971
|
explainer_type=explainer_type,
|
|
567
972
|
device=self.device
|
|
568
973
|
)
|
|
569
|
-
elif self.kind in [
|
|
974
|
+
elif self.kind in [MLTaskKeys.MULTITARGET_REGRESSION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
|
|
570
975
|
# try to get target names
|
|
571
976
|
if target_names is None:
|
|
572
977
|
target_names = []
|
|
@@ -640,13 +1045,11 @@ class MLTrainer:
|
|
|
640
1045
|
|
|
641
1046
|
# --- Step 1: Check if the model supports this explanation ---
|
|
642
1047
|
if not getattr(self.model, 'has_interpretable_attention', False):
|
|
643
|
-
_LOGGER.warning(
|
|
644
|
-
"Model is not flagged for interpretable attention analysis. Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
|
|
645
|
-
)
|
|
1048
|
+
_LOGGER.warning("Model is not compatible with interpretable attention analysis. Skipping.")
|
|
646
1049
|
return
|
|
647
1050
|
|
|
648
1051
|
# --- Step 2: Set up the dataloader ---
|
|
649
|
-
dataset_to_use = explain_dataset if explain_dataset is not None else self.
|
|
1052
|
+
dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
|
|
650
1053
|
if not isinstance(dataset_to_use, Dataset):
|
|
651
1054
|
_LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
|
|
652
1055
|
return
|
|
@@ -681,40 +1084,111 @@ class MLTrainer:
|
|
|
681
1084
|
)
|
|
682
1085
|
else:
|
|
683
1086
|
_LOGGER.error("No attention weights were collected from the model.")
|
|
684
|
-
|
|
685
|
-
def _callbacks_hook(self, method_name: str, *args, **kwargs):
|
|
686
|
-
"""Calls the specified method on all callbacks."""
|
|
687
|
-
for callback in self.callbacks:
|
|
688
|
-
method = getattr(callback, method_name)
|
|
689
|
-
method(*args, **kwargs)
|
|
690
|
-
|
|
691
|
-
def to_cpu(self):
|
|
692
|
-
"""
|
|
693
|
-
Moves the model to the CPU and updates the trainer's device setting.
|
|
694
1087
|
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
def to_device(self, device: str):
|
|
1088
|
+
def finalize_model_training(self,
|
|
1089
|
+
save_dir: Union[str, Path],
|
|
1090
|
+
filename: str,
|
|
1091
|
+
model_checkpoint: Union[Path, Literal['latest', 'current']],
|
|
1092
|
+
classification_threshold: Optional[float]=None,
|
|
1093
|
+
class_map: Optional[Dict[str,int]]=None):
|
|
702
1094
|
"""
|
|
703
|
-
|
|
1095
|
+
Saves a finalized, "inference-ready" model state to a .pth file.
|
|
1096
|
+
|
|
1097
|
+
This method saves the model's `state_dict`, the final epoch number, and
|
|
1098
|
+
an optional classification threshold required for binary-based tasks (binary classification, binary segmentation,
|
|
1099
|
+
multilabel binary classification).
|
|
704
1100
|
|
|
705
1101
|
Args:
|
|
706
|
-
|
|
1102
|
+
save_dir (str | Path): The directory to save the finalized model.
|
|
1103
|
+
filename (str): The desired filename for the saved file.
|
|
1104
|
+
model_checkpoint (Path | "latest" | "current"):
|
|
1105
|
+
- Path: Loads the model state from a specific checkpoint file.
|
|
1106
|
+
- "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
|
|
1107
|
+
- "current": Uses the model's state as it is at the end of the `fit()` call.
|
|
1108
|
+
classification_threshold (float, None):
|
|
1109
|
+
Required for `binary classification`, `binary segmentation`, and
|
|
1110
|
+
`multilabel binary classification`. This is the threshold (0.0-1.0)
|
|
1111
|
+
used to convert probabilities to class labels.
|
|
1112
|
+
class_map (Dict[str, int] | None): Sets the class name mapping to translate predicted integer labels back into string names. (For Classification and Segmentation Tasks)
|
|
707
1113
|
"""
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
1114
|
+
# handle save path
|
|
1115
|
+
sanitized_filename = sanitize_filename(filename)
|
|
1116
|
+
if not sanitized_filename.endswith(".pth"):
|
|
1117
|
+
sanitized_filename = sanitized_filename + ".pth"
|
|
1118
|
+
|
|
1119
|
+
dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
1120
|
+
full_path = dir_path / sanitized_filename
|
|
1121
|
+
|
|
1122
|
+
# threshold required for binary tasks
|
|
1123
|
+
if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
|
|
1124
|
+
if classification_threshold is None:
|
|
1125
|
+
_LOGGER.error(f"A classification threshold is needed for binary-based classification tasks. If unknown, use '0.5' as a default.")
|
|
1126
|
+
raise ValueError()
|
|
1127
|
+
elif not isinstance(classification_threshold, float):
|
|
1128
|
+
_LOGGER.error(f"The classification threshold must be a float value.")
|
|
1129
|
+
raise TypeError()
|
|
1130
|
+
elif classification_threshold <= 0.0 or classification_threshold >= 1.0:
|
|
1131
|
+
_LOGGER.error(f"The classification threshold must be in the range (0.0 - 1.0).")
|
|
1132
|
+
else:
|
|
1133
|
+
classification_threshold = None
|
|
1134
|
+
|
|
1135
|
+
# handle checkpoint
|
|
1136
|
+
if isinstance(model_checkpoint, Path):
|
|
1137
|
+
self._load_checkpoint(path=model_checkpoint)
|
|
1138
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
|
|
1139
|
+
path_to_latest = self._checkpoint_callback.best_checkpoint_path
|
|
1140
|
+
self._load_checkpoint(path_to_latest)
|
|
1141
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
|
|
1142
|
+
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
|
|
1143
|
+
raise ValueError()
|
|
1144
|
+
elif model_checkpoint == MagicWords.CURRENT:
|
|
1145
|
+
pass
|
|
1146
|
+
else:
|
|
1147
|
+
_LOGGER.error(f"Unknown 'model_checkpoint' parameter received '{model_checkpoint}'.")
|
|
1148
|
+
|
|
1149
|
+
# Handle class map
|
|
1150
|
+
if self.kind in [MLTaskKeys.BINARY_CLASSIFICATION,
|
|
1151
|
+
MLTaskKeys.MULTICLASS_CLASSIFICATION,
|
|
1152
|
+
MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
|
|
1153
|
+
MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION,
|
|
1154
|
+
MLTaskKeys.BINARY_SEGMENTATION,
|
|
1155
|
+
MLTaskKeys.MULTICLASS_SEGMENTATION]:
|
|
1156
|
+
if class_map is None:
|
|
1157
|
+
_LOGGER.error(f"'class_map' is required for '{self.kind}'.")
|
|
1158
|
+
raise ValueError()
|
|
1159
|
+
else:
|
|
1160
|
+
class_map = None
|
|
1161
|
+
|
|
1162
|
+
# Create finalized data
|
|
1163
|
+
finalized_data = {
|
|
1164
|
+
PyTorchCheckpointKeys.EPOCH: self.epoch,
|
|
1165
|
+
PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
|
|
1166
|
+
}
|
|
1167
|
+
|
|
1168
|
+
if classification_threshold is not None:
|
|
1169
|
+
self._classification_threshold = classification_threshold
|
|
1170
|
+
finalized_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD] = classification_threshold
|
|
1171
|
+
|
|
1172
|
+
if class_map is not None:
|
|
1173
|
+
finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = class_map
|
|
1174
|
+
|
|
1175
|
+
torch.save(finalized_data, full_path)
|
|
1176
|
+
|
|
1177
|
+
_LOGGER.info(f"Finalized model weights saved to {full_path}.")
|
|
 
 
 # Object Detection Trainer
-class ObjectDetectionTrainer:
-    def __init__(self, model: nn.Module,
+class DragonDetectionTrainer(_BaseDragonTrainer):
+    def __init__(self, model: nn.Module,
+                 train_dataset: Dataset,
+                 validation_dataset: Dataset,
                  collate_fn: Callable, optimizer: torch.optim.Optimizer,
-                 device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                 device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                 checkpoint_callback: Optional[DragonModelCheckpoint],
+                 early_stopping_callback: Optional[DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 extra_callbacks: Optional[List[_Callback]] = None,
+                 dataloader_workers: int = 2):
         """
         Automates the training process of an Object Detection Model (e.g., DragonFastRCNN).
 
@@ -723,58 +1197,36 @@ class ObjectDetectionTrainer:
         Args:
             model (nn.Module): The PyTorch object detection model to train.
             train_dataset (Dataset): The training dataset.
-
+            validation_dataset (Dataset): The testing/validation dataset.
             collate_fn (Callable): The collate function from `ObjectDetectionDatasetMaker.collate_fn`.
             optimizer (torch.optim.Optimizer): The optimizer.
             device (str): The device to run training on ('cpu', 'cuda', 'mps').
             dataloader_workers (int): Subprocesses for data loading.
-
+            checkpoint_callback (DragonModelCheckpoint | None): Callback to save the model.
+            early_stopping_callback (DragonEarlyStopping | None): Callback to stop training early.
+            lr_scheduler_callback (DragonLRScheduler | None): Callback to manage the LR scheduler.
+            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
 
         ## Note:
         This trainer is specialized. It does not take a `criterion` because object detection models like Faster R-CNN return a dictionary of losses directly from their forward pass during training.
         """
-
+        # Call the base class constructor with common parameters
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            dataloader_workers=dataloader_workers,
+            checkpoint_callback=checkpoint_callback,
+            early_stopping_callback=early_stopping_callback,
+            lr_scheduler_callback=lr_scheduler_callback,
+            extra_callbacks=extra_callbacks
+        )
+
         self.train_dataset = train_dataset
-        self.test_dataset = test_dataset
+        self.validation_dataset = validation_dataset  # <-- Renamed
         self.kind = "object_detection"
         self.collate_fn = collate_fn
         self.criterion = None  # Criterion is handled inside the model
-        self.optimizer = optimizer
-        self.scheduler = None
-        self.device = self._validate_device(device)
-        self.dataloader_workers = dataloader_workers
-
-        # Callback handler - History and TqdmProgressBar are added by default
-        default_callbacks = [History(), TqdmProgressBar()]
-        user_callbacks = callbacks if callbacks is not None else []
-        self.callbacks = default_callbacks + user_callbacks
-        self._set_trainer_on_callbacks()
-
-        # Internal state
-        self.train_loader = None
-        self.test_loader = None
-        self.history = {}
-        self.epoch = 0
-        self.epochs = 0  # Total epochs for the fit run
-        self.start_epoch = 1
-        self.stop_training = False
-        self._batch_size = 10
-
-    def _validate_device(self, device: str) -> torch.device:
-        """Validates the selected device and returns a torch.device object."""
-        device_lower = device.lower()
-        if "cuda" in device_lower and not torch.cuda.is_available():
-            _LOGGER.warning("CUDA not available, switching to CPU.")
-            device = "cpu"
-        elif device_lower == "mps" and not torch.backends.mps.is_available():
-            _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
-            device = "cpu"
-        return torch.device(device)
-
-    def _set_trainer_on_callbacks(self):
-        """Gives each callback a reference to this trainer instance."""
-        for callback in self.callbacks:
-            callback.set_trainer(self)
 
1231
|
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
780
1232
|
"""Initializes the DataLoaders with the object detection collate_fn."""
|
|
@@ -786,125 +1238,25 @@ class ObjectDetectionTrainer:
|
|
|
786
1238
|
batch_size=batch_size,
|
|
787
1239
|
shuffle=shuffle,
|
|
788
1240
|
num_workers=loader_workers,
|
|
789
|
-
pin_memory=("cuda" in self.device.type),
|
|
790
|
-
collate_fn=self.collate_fn # Use the provided collate function
|
|
1241
|
+
pin_memory=("cuda" in self.device.type),
|
|
1242
|
+
collate_fn=self.collate_fn, # Use the provided collate function
|
|
1243
|
+
drop_last=True
|
|
791
1244
|
)
|
|
792
1245
|
|
|
793
|
-
self.
|
|
794
|
-
dataset=self.
|
|
1246
|
+
self.validation_loader = DataLoader(
|
|
1247
|
+
dataset=self.validation_dataset,
|
|
795
1248
|
batch_size=batch_size,
|
|
796
1249
|
shuffle=False,
|
|
797
1250
|
num_workers=loader_workers,
|
|
798
1251
|
pin_memory=("cuda" in self.device.type),
|
|
799
1252
|
collate_fn=self.collate_fn # Use the provided collate function
|
|
800
1253
|
)
|
|
801
|
-
|
|
802
|
-
def _load_checkpoint(self, path: Union[str, Path]):
|
|
803
|
-
"""Loads a training checkpoint to resume training."""
|
|
804
|
-
p = make_fullpath(path, enforce="file")
|
|
805
|
-
_LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
|
|
806
|
-
|
|
807
|
-
try:
|
|
808
|
-
checkpoint = torch.load(p, map_location=self.device)
|
|
809
|
-
|
|
810
|
-
if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
|
|
811
|
-
_LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
|
|
812
|
-
raise KeyError()
|
|
813
|
-
|
|
814
|
-
self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
|
|
815
|
-
self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
|
|
816
|
-
self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
|
|
817
|
-
|
|
818
|
-
# --- Scheduler State Loading Logic ---
|
|
819
|
-
scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
|
|
820
|
-
scheduler_object_exists = self.scheduler is not None
|
|
821
|
-
|
|
822
|
-
if scheduler_object_exists and scheduler_state_exists:
|
|
823
|
-
# Case 1: Both exist. Attempt to load.
|
|
824
|
-
try:
|
|
825
|
-
self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
|
|
826
|
-
scheduler_name = self.scheduler.__class__.__name__
|
|
827
|
-
_LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
|
|
828
|
-
except Exception as e:
|
|
829
|
-
# Loading failed, likely a mismatch
|
|
830
|
-
scheduler_name = self.scheduler.__class__.__name__
|
|
831
|
-
_LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
|
|
832
|
-
raise e
|
|
833
1254
|
|
|
834
|
-
elif scheduler_object_exists and not scheduler_state_exists:
|
|
835
|
-
# Case 2: Scheduler provided, but no state in checkpoint.
|
|
836
|
-
scheduler_name = self.scheduler.__class__.__name__
|
|
837
|
-
_LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
|
|
838
|
-
|
|
839
|
-
elif not scheduler_object_exists and scheduler_state_exists:
|
|
840
|
-
# Case 3: State in checkpoint, but no scheduler provided.
|
|
841
|
-
_LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
|
|
842
|
-
raise ValueError()
|
|
843
|
-
|
|
844
|
-
# Restore callback states
|
|
845
|
-
for cb in self.callbacks:
|
|
846
|
-
if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
|
|
847
|
-
cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
|
|
848
|
-
_LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
|
|
849
|
-
|
|
850
|
-
_LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
|
|
851
|
-
|
|
852
|
-
except Exception as e:
|
|
853
|
-
_LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
|
|
854
|
-
raise
|
|
855
|
-
|
|
856
|
-
def fit(self,
|
|
857
|
-
epochs: int = 10,
|
|
858
|
-
batch_size: int = 10,
|
|
859
|
-
shuffle: bool = True,
|
|
860
|
-
resume_from_checkpoint: Optional[Union[str, Path]] = None):
|
|
861
|
-
"""
|
|
862
|
-
Starts the training-validation process of the model.
|
|
863
|
-
|
|
864
|
-
Returns the "History" callback dictionary.
|
|
865
|
-
|
|
866
|
-
Args:
|
|
867
|
-
epochs (int): The total number of epochs to train for.
|
|
868
|
-
batch_size (int): The number of samples per batch.
|
|
869
|
-
shuffle (bool): Whether to shuffle the training data at each epoch.
|
|
870
|
-
resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
|
|
871
|
-
"""
|
|
872
|
-
self.epochs = epochs
|
|
873
|
-
self._batch_size = batch_size
|
|
874
|
-
self._create_dataloaders(self._batch_size, shuffle)
|
|
875
|
-
self.model.to(self.device)
|
|
876
|
-
|
|
877
|
-
if resume_from_checkpoint:
|
|
878
|
-
self._load_checkpoint(resume_from_checkpoint)
|
|
879
|
-
|
|
880
|
-
# Reset stop_training flag on the trainer
|
|
881
|
-
self.stop_training = False
|
|
882
|
-
|
|
883
|
-
self._callbacks_hook('on_train_begin')
|
|
884
|
-
|
|
885
|
-
for epoch in range(self.start_epoch, self.epochs + 1):
|
|
886
|
-
self.epoch = epoch
|
|
887
|
-
epoch_logs = {}
|
|
888
|
-
self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
|
|
889
|
-
|
|
890
|
-
train_logs = self._train_step()
|
|
891
|
-
epoch_logs.update(train_logs)
|
|
892
|
-
|
|
893
|
-
val_logs = self._validation_step()
|
|
894
|
-
epoch_logs.update(val_logs)
|
|
895
|
-
|
|
896
|
-
self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
|
|
897
|
-
|
|
898
|
-
# Check the early stopping flag
|
|
899
|
-
if self.stop_training:
|
|
900
|
-
break
|
|
901
|
-
|
|
902
|
-
self._callbacks_hook('on_train_end')
|
|
903
|
-
return self.history
|
|
904
|
-
|
|
905
1255
|
def _train_step(self):
|
|
906
1256
|
self.model.train()
|
|
907
1257
|
running_loss = 0.0
|
|
1258
|
+
total_samples = 0
|
|
1259
|
+
|
|
908
1260
|
for batch_idx, (images, targets) in enumerate(self.train_loader): # type: ignore
|
|
909
1261
|
# images is a tuple of tensors, targets is a tuple of dicts
|
|
910
1262
|
batch_size = len(images)
|
|
@@ -941,21 +1293,28 @@ class ObjectDetectionTrainer:
|
|
|
941
1293
|
# Calculate batch loss and update running loss for the epoch
|
|
942
1294
|
batch_loss = loss.item()
|
|
943
1295
|
running_loss += batch_loss * batch_size
|
|
1296
|
+
total_samples += batch_size # <-- Accumulate total samples
|
|
944
1297
|
|
|
945
1298
|
# Add the batch loss to the logs and call the end-of-batch hook
|
|
946
1299
|
batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss # type: ignore
|
|
947
1300
|
self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
|
|
1301
|
+
|
|
1302
|
+
# Calculate loss using the correct denominator
|
|
1303
|
+
if total_samples == 0:
|
|
1304
|
+
_LOGGER.warning("No samples processed in _train_step. Returning 0 loss.")
|
|
1305
|
+
return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
|
|
948
1306
|
|
|
949
|
-
return {PyTorchLogKeys.TRAIN_LOSS: running_loss /
|
|
1307
|
+
return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}
|
|
950
1308
|
|
|
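The `total_samples` counter fixes the epoch-loss denominator: because `running_loss` accumulates `loss.item() * batch_size`, dividing by the samples actually seen stays exact even when `drop_last=True` discards a partial batch. A tiny arithmetic sketch with made-up numbers:

    batches = [(0.9, 4), (0.7, 4), (0.5, 4)]             # (mean batch loss, batch size)
    running_loss = sum(loss * n for loss, n in batches)  # 8.4
    total_samples = sum(n for _, n in batches)           # 12
    epoch_loss = running_loss / total_samples            # 0.7; len(dataset) would over-divide when batches are dropped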
     def _validation_step(self):
         self.model.train()  # Set to train mode even for validation loss calculation
-        # as model internals (e.g., proposals) might differ,
-        #
-        # We use torch.no_grad() to prevent gradient updates.
+        # as model internals (e.g., proposals) might differ, but we still need loss_dict.
+        # use torch.no_grad() to prevent gradient updates.
         running_loss = 0.0
+        total_samples = 0
+
         with torch.no_grad():
-            for images, targets in self.test_loader:  # type: ignore
+            for images, targets in self.validation_loader:  # type: ignore
                 batch_size = len(images)
 
                 # Move data to device
@@ -973,25 +1332,105 @@ class ObjectDetectionTrainer:
                 loss: torch.Tensor = sum(l for l in loss_dict.values())  # type: ignore
 
                 running_loss += loss.item() * batch_size
+                total_samples += batch_size  # <-- Accumulate total samples
+
+        # Calculate loss using the correct denominator
+        if total_samples == 0:
+            _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+            return {PyTorchLogKeys.VAL_LOSS: 0.0}
 
-        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)}
+        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / total_samples}
         return logs
+
+    def evaluate(self,
+                 save_dir: Union[str, Path],
+                 model_checkpoint: Union[Path, Literal["latest", "current"]],
+                 test_data: Optional[Union[DataLoader, Dataset]] = None):
+        """
+        Evaluates the model using object detection mAP metrics.
 
-
+        Args:
+            save_dir (str | Path): Directory to save all reports and plots.
+            model_checkpoint (Path | 'latest' | 'current'):
+                - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+                - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+                - If 'current', use the current state of the trained model up to the latest trained epoch.
+            test_data (DataLoader | Dataset | None): Optional test data to evaluate model performance. Validation and Test metrics will be saved to subdirectories.
         """
+        # Validate model checkpoint
+        if isinstance(model_checkpoint, Path):
+            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+        elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+            checkpoint_validated = model_checkpoint
+        else:
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
+            raise ValueError()
+
+        # Validate directory
+        save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+        # Validate test data and dispatch
+        if test_data is not None:
+            if not isinstance(test_data, (DataLoader, Dataset)):
+                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                raise ValueError()
+            test_data_validated = test_data
+
+            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+            self._evaluate(save_dir=validation_metrics_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None)  # 'None' triggers use of self.validation_dataset
+
+            # Dispatch test set
+            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+            self._evaluate(save_dir=test_metrics_path,
+                           model_checkpoint="current",  # Use 'current' state after loading checkpoint once
+                           data=test_data_validated)
+        else:
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+            self._evaluate(save_dir=save_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None)  # 'None' triggers use of self.validation_dataset
+
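A usage sketch for the new dispatching `evaluate`. Paths and dataset names are illustrative, and it assumes `fit()` keeps its previous signature now that it lives in `_BaseDragonTrainer`:

    trainer.fit(epochs=50, batch_size=8)
    trainer.evaluate(
        save_dir="reports/detection_run",   # placeholder directory
        model_checkpoint="latest",          # or "current", or a Path to a checkpoint file
        test_data=test_ds,                  # optional; omit to score only the validation set
    )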
+    def _evaluate(self,
+                  save_dir: Union[str, Path],
+                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  data: Optional[Union[DataLoader, Dataset]]):
+        """
+        Private evaluation helper.
         Evaluates the model using object detection mAP metrics.
 
         Args:
             save_dir (str | Path): Directory to save all reports and plots.
             data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
+            model_checkpoint (Path | 'latest' | 'current'):
+                - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+                - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+                - If 'current', use the current state of the trained model up to the latest trained epoch.
         """
         dataset_for_names = None
         eval_loader = None
+
+        # load model checkpoint
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+            raise ValueError()
 
+        # Dataloader
         if isinstance(data, DataLoader):
             eval_loader = data
             if hasattr(data, 'dataset'):
-                dataset_for_names = data.dataset
+                dataset_for_names = data.dataset  # type: ignore
         elif isinstance(data, Dataset):
             # Create a new loader from the provided dataset
             eval_loader = DataLoader(data,
@@ -1002,19 +1441,19 @@ class ObjectDetectionTrainer:
                                      collate_fn=self.collate_fn)
             dataset_for_names = data
         else:  # data is None, use the trainer's default test dataset
-            if self.test_dataset is None:
+            if self.validation_dataset is None:
                 _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
                 raise ValueError()
             # Create a fresh DataLoader from the test_dataset
             eval_loader = DataLoader(
-                self.test_dataset,
+                self.validation_dataset,
                 batch_size=self._batch_size,
                 shuffle=False,
                 num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                 pin_memory=(self.device.type == "cuda"),
                 collate_fn=self.collate_fn
             )
-            dataset_for_names = self.test_dataset
+            dataset_for_names = self.validation_dataset
 
         if eval_loader is None:
             _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
@@ -1068,36 +1507,480 @@ class ObjectDetectionTrainer:
             class_names=class_names,
             print_output=False
         )
-
-        # print("\n--- Training History ---")
-        plot_losses(self.history, save_dir=save_dir)
 
-    def
-        """
-
-
-
+    def finalize_model_training(self, save_dir: Union[str, Path], filename: str, model_checkpoint: Union[Path, Literal['latest', 'current']]):
+        """
+        Saves a finalized, "inference-ready" model state to a .pth file.
+
+        This method saves the model's `state_dict` and the final epoch number.
+
+        Args:
+            save_dir (Union[str, Path]): The directory to save the finalized model.
+            filename (str): The desired filename for the model (e.g., "final_model.pth").
+            model_checkpoint (Union[Path, Literal["latest", "current"]]):
+                - Path: Loads the model state from a specific checkpoint file.
+                - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                - "current": Uses the model's state as it is at the end of the `fit()` call.
+        """
+        # handle save path
+        sanitized_filename = sanitize_filename(filename)
+        if not sanitized_filename.endswith(".pth"):
+            sanitized_filename = sanitized_filename + ".pth"
 
-
+        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+        full_path = dir_path / sanitized_filename
+
+        # handle checkpoint
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+            raise ValueError()
+        elif model_checkpoint == MagicWords.CURRENT:
+            pass
+        else:
+            _LOGGER.error(f"Unknown 'model_checkpoint' parameter received '{model_checkpoint}'.")
+
+        # Create finalized data
+        finalized_data = {
+            PyTorchCheckpointKeys.EPOCH: self.epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+        }
+
+        torch.save(finalized_data, full_path)
+
+        _LOGGER.info(f"Finalized model weights saved to {full_path}.")
+
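A sketch of exporting the best detection weights once training is done; "latest" resolves through the `DragonModelCheckpoint` callback supplied at construction (names are illustrative):

    trainer.finalize_model_training(
        save_dir="artifacts",           # placeholder directory
        filename="detector_final",      # ".pth" is appended automatically
        model_checkpoint="latest",
    )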
+# --- DragonSequenceTrainer ----
+class DragonSequenceTrainer(_BaseDragonTrainer):
+    def __init__(self,
+                 model: nn.Module,
+                 train_dataset: Dataset,
+                 validation_dataset: Dataset,
+                 kind: Literal["sequence-to-sequence", "sequence-to-value"],
+                 optimizer: torch.optim.Optimizer,
+                 device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                 checkpoint_callback: Optional[DragonModelCheckpoint],
+                 early_stopping_callback: Optional[DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 extra_callbacks: Optional[List[_Callback]] = None,
+                 criterion: Union[nn.Module,Literal["auto"]] = "auto",
+                 dataloader_workers: int = 2):
         """
-
+        Automates the training process of a PyTorch Sequence Model.
 
-
+        Built-in Callbacks: `History`, `TqdmProgressBar`
+
+        Args:
+            model (nn.Module): The PyTorch model to train.
+            train_dataset (Dataset): The training dataset.
+            validation_dataset (Dataset): The validation dataset.
+            kind (str): Used to redirect to the correct process ('sequence-to-sequence' or 'sequence-to-value').
+            criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task.
+            optimizer (torch.optim.Optimizer): The optimizer.
+            device (str): The device to run training on ('cpu', 'cuda', 'mps').
+            dataloader_workers (int): Subprocesses for data loading.
+            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
         """
-
+        # Call the base class constructor with common parameters
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            dataloader_workers=dataloader_workers,
+            checkpoint_callback=checkpoint_callback,
+            early_stopping_callback=early_stopping_callback,
+            lr_scheduler_callback=lr_scheduler_callback,
+            extra_callbacks=extra_callbacks
+        )
+
+        if kind not in [MLTaskKeys.SEQUENCE_SEQUENCE, MLTaskKeys.SEQUENCE_VALUE]:
+            raise ValueError(f"'{kind}' is not a valid task type for DragonSequenceTrainer.")
+
+        self.train_dataset = train_dataset
+        self.validation_dataset = validation_dataset
+        self.kind = kind
+
+        # try to validate against Dragon Sequence model
+        if hasattr(self.model, "prediction_mode"):
+            key_to_check: str = self.model.prediction_mode  # type: ignore
+            if not key_to_check == self.kind:
+                _LOGGER.error(f"Trainer was set for '{self.kind}', but model architecture '{self.model}' is built for '{key_to_check}'.")
+                raise RuntimeError()
+
+        # loss function
+        if criterion == "auto":
+            # Both sequence tasks are treated as regression problems
+            self.criterion = nn.MSELoss()
+        else:
+            self.criterion = criterion
+
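Constructing the new sequence trainer, sketched with placeholder names. With `criterion="auto"` both sequence tasks fall back to `nn.MSELoss`, since they are treated as regression problems:

    import torch

    trainer = DragonSequenceTrainer(
        model=seq_model,                  # placeholder; a Dragon sequence model exposing prediction_mode="sequence-to-value"
        train_dataset=train_ds,           # placeholder Dataset
        validation_dataset=val_ds,        # placeholder Dataset
        kind="sequence-to-value",
        optimizer=torch.optim.Adam(seq_model.parameters(), lr=1e-3),
        device="cuda",
        checkpoint_callback=checkpoint_cb,
        early_stopping_callback=None,
        lr_scheduler_callback=None,
        criterion="auto",                 # resolves to nn.MSELoss()
    )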
+    def _create_dataloaders(self, batch_size: int, shuffle: bool):
+        """Initializes the DataLoaders."""
+        # Ensure stability on MPS devices by setting num_workers to 0
+        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+        self.train_loader = DataLoader(
+            dataset=self.train_dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type),
+            drop_last=True  # Drops the last batch if incomplete, selecting a good batch size is key.
+        )
+
+        self.validation_loader = DataLoader(
+            dataset=self.validation_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type)
+        )
+
+    def _train_step(self):
+        self.model.train()
+        running_loss = 0.0
+        total_samples = 0
+
+        for batch_idx, (features, target) in enumerate(self.train_loader):  # type: ignore
+            # Create a log dictionary for the batch
+            batch_logs = {
+                PyTorchLogKeys.BATCH_INDEX: batch_idx,
+                PyTorchLogKeys.BATCH_SIZE: features.size(0)
+            }
+            self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
+
+            features, target = features.to(self.device), target.to(self.device)
+            self.optimizer.zero_grad()
+
+            output = self.model(features)
+
+            # --- Label Type/Shape Correction ---
+            # Ensure target is float for MSELoss
+            target = target.float()
+
+            # For seq-to-val, models might output [N, 1] but target is [N].
+            if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+                if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                    output = output.squeeze(1)
+
+            # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
+            elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+                if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
+                    output = output.squeeze(-1)
+
+            loss = self.criterion(output, target)
+
+            loss.backward()
+            self.optimizer.step()
+
+            # Calculate batch loss and update running loss for the epoch
+            batch_loss = loss.item()
+            batch_size = features.size(0)
+            running_loss += batch_loss * batch_size  # Accumulate total loss
+            total_samples += batch_size  # total samples
+
+            # Add the batch loss to the logs and call the end-of-batch hook
+            batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
+            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+        if total_samples == 0:
+            _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
+            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
+
+        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}  # type: ignore
+
|
+
def _validation_step(self):
|
|
1697
|
+
self.model.eval()
|
|
1698
|
+
running_loss = 0.0
|
|
1699
|
+
|
|
1700
|
+
with torch.no_grad():
|
|
1701
|
+
for features, target in self.validation_loader: # type: ignore
|
|
1702
|
+
features, target = features.to(self.device), target.to(self.device)
|
|
1703
|
+
|
|
1704
|
+
output = self.model(features)
|
|
1705
|
+
|
|
1706
|
+
# --- Label Type/Shape Correction ---
|
|
1707
|
+
target = target.float()
|
|
1708
|
+
|
|
1709
|
+
# For seq-to-val, models might output [N, 1] but target is [N].
|
|
1710
|
+
if self.kind == MLTaskKeys.SEQUENCE_VALUE:
|
|
1711
|
+
if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
|
|
1712
|
+
output = output.squeeze(1)
|
|
1713
|
+
|
|
1714
|
+
# For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
|
|
1715
|
+
elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
|
|
1716
|
+
if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
|
|
1717
|
+
output = output.squeeze(-1)
|
|
1718
|
+
|
|
1719
|
+
loss = self.criterion(output, target)
|
|
1720
|
+
|
|
1721
|
+
running_loss += loss.item() * features.size(0)
|
|
1722
|
+
|
|
1723
|
+
if not self.validation_loader.dataset: # type: ignore
|
|
1724
|
+
_LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
|
|
1725
|
+
return {PyTorchLogKeys.VAL_LOSS: 0.0}
|
|
1726
|
+
|
|
1727
|
+
logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)} # type: ignore
|
|
1728
|
+
return logs
|
|
1729
|
+
|
|
1730
|
+
def _predict_for_eval(self, dataloader: DataLoader):
|
|
1731
|
+
"""
|
|
1732
|
+
Private method to yield model predictions batch by batch for evaluation.
|
|
1733
|
+
|
|
1734
|
+
Yields:
|
|
1735
|
+
tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
|
|
1736
|
+
y_prob_batch is always None for sequence tasks.
|
|
1737
|
+
"""
|
|
1738
|
+
self.model.eval()
|
|
1088
1739
|
self.model.to(self.device)
|
|
1089
|
-
|
|
1740
|
+
|
|
1741
|
+
with torch.no_grad():
|
|
1742
|
+
for features, target in dataloader:
|
|
1743
|
+
features = features.to(self.device)
|
|
1744
|
+
output = self.model(features).cpu()
|
|
1745
|
+
|
|
1746
|
+
y_pred_batch = output.numpy()
|
|
1747
|
+
y_prob_batch = None # Not applicable for sequence regression
|
|
1748
|
+
y_true_batch = target.numpy()
|
|
1749
|
+
|
|
1750
|
+
yield y_pred_batch, y_prob_batch, y_true_batch
|
|
1751
|
+
|
|
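The generator is consumed internally by `_evaluate` with a collect-then-concatenate pattern; a sketch of that flow (the loader name is a placeholder):

    import numpy as np

    all_preds, all_true = [], []
    for y_pred_b, _, y_true_b in trainer._predict_for_eval(eval_loader):
        all_preds.append(y_pred_b)    # one NumPy array per batch
        all_true.append(y_true_b)
    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_true)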
+    def evaluate(self,
+                 save_dir: Union[str, Path],
+                 model_checkpoint: Union[Path, Literal["latest", "current"]],
+                 test_data: Optional[Union[DataLoader, Dataset]] = None,
+                 val_format_configuration: Optional[Union[SequenceValueMetricsFormat,
+                                                          SequenceSequenceMetricsFormat]]=None,
+                 test_format_configuration: Optional[Union[SequenceValueMetricsFormat,
+                                                           SequenceSequenceMetricsFormat]]=None):
+        """
+        Evaluates the model, routing to the correct evaluation function.
+
+        Args:
+            model_checkpoint (Path | 'latest' | 'current'):
+                - Path to a valid checkpoint for the model.
+                - If 'latest', the latest checkpoint will be loaded.
+                - If 'current', use the current state of the trained model.
+            save_dir (str | Path): Directory to save all reports and plots.
+            test_data (DataLoader | Dataset | None): Optional test data.
+            val_format_configuration: Optional configuration for validation metrics.
+            test_format_configuration: Optional configuration for test metrics.
+        """
+        # Validate model checkpoint
+        if isinstance(model_checkpoint, Path):
+            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+        elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+            checkpoint_validated = model_checkpoint
+        else:
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.LATEST}', or '{MagicWords.CURRENT}'.")
+            raise ValueError()
+
+        # Validate val configuration
+        if val_format_configuration is not None:
+            if not isinstance(val_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
+                _LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
+                raise ValueError()
+
+        # Validate directory
+        save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+        # Validate test data and dispatch
+        if test_data is not None:
+            if not isinstance(test_data, (DataLoader, Dataset)):
+                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                raise ValueError()
+            test_data_validated = test_data
 
-
+            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+            self._evaluate(save_dir=validation_metrics_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None,
+                           format_configuration=val_format_configuration)
+
+            # Validate test configuration
+            test_configuration_validated = None
+            if test_format_configuration is not None:
+                if not isinstance(test_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
+                    warning_message_type = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
+                    if val_format_configuration is not None:
+                        warning_message_type += " 'val_format_configuration' will be used."
+                        test_configuration_validated = val_format_configuration
+                    else:
+                        warning_message_type += " Using default format."
+                    _LOGGER.warning(warning_message_type)
+                else:
+                    test_configuration_validated = test_format_configuration
+
+            # Dispatch test set
+            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+            self._evaluate(save_dir=test_metrics_path,
+                           model_checkpoint="current",
+                           data=test_data_validated,
+                           format_configuration=test_configuration_validated)
+        else:
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+            self._evaluate(save_dir=save_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None,
+                           format_configuration=val_format_configuration)
+
|
+
def _evaluate(self,
|
|
1837
|
+
save_dir: Union[str, Path],
|
|
1838
|
+
model_checkpoint: Union[Path, Literal["latest", "current"]],
|
|
1839
|
+
data: Optional[Union[DataLoader, Dataset]],
|
|
1840
|
+
format_configuration: Optional[Union[SequenceValueMetricsFormat,
|
|
1841
|
+
SequenceSequenceMetricsFormat]]):
|
|
1092
1842
|
"""
|
|
1093
|
-
|
|
1843
|
+
Private evaluation helper.
|
|
1844
|
+
"""
|
|
1845
|
+
eval_loader = None
|
|
1846
|
+
|
|
1847
|
+
# load model checkpoint
|
|
1848
|
+
if isinstance(model_checkpoint, Path):
|
|
1849
|
+
self._load_checkpoint(path=model_checkpoint)
|
|
1850
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
|
|
1851
|
+
path_to_latest = self._checkpoint_callback.best_checkpoint_path
|
|
1852
|
+
self._load_checkpoint(path_to_latest)
|
|
1853
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
|
|
1854
|
+
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
|
|
1855
|
+
raise ValueError()
|
|
1856
|
+
|
|
1857
|
+
# Dataloader
|
|
1858
|
+
if isinstance(data, DataLoader):
|
|
1859
|
+
eval_loader = data
|
|
1860
|
+
elif isinstance(data, Dataset):
|
|
1861
|
+
# Create a new loader from the provided dataset
|
|
1862
|
+
eval_loader = DataLoader(data,
|
|
1863
|
+
batch_size=self._batch_size,
|
|
1864
|
+
shuffle=False,
|
|
1865
|
+
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
1866
|
+
pin_memory=(self.device.type == "cuda"))
|
|
1867
|
+
else: # data is None, use the trainer's default validation dataset
|
|
1868
|
+
if self.validation_dataset is None:
|
|
1869
|
+
_LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
|
|
1870
|
+
raise ValueError()
|
|
1871
|
+
eval_loader = DataLoader(self.validation_dataset,
|
|
1872
|
+
batch_size=self._batch_size,
|
|
1873
|
+
shuffle=False,
|
|
1874
|
+
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
1875
|
+
pin_memory=(self.device.type == "cuda"))
|
|
1876
|
+
|
|
1877
|
+
if eval_loader is None:
|
|
1878
|
+
_LOGGER.error("Cannot evaluate. No valid data was provided or found.")
|
|
1879
|
+
raise ValueError()
|
|
1880
|
+
|
|
1881
|
+
all_preds, _, all_true = [], [], []
|
|
1882
|
+
for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
|
|
1883
|
+
if y_pred_b is not None: all_preds.append(y_pred_b)
|
|
1884
|
+
if y_true_b is not None: all_true.append(y_true_b)
|
|
1885
|
+
|
|
1886
|
+
if not all_true:
|
|
1887
|
+
_LOGGER.error("Evaluation failed: No data was processed.")
|
|
1888
|
+
return
|
|
1889
|
+
|
|
1890
|
+
y_pred = np.concatenate(all_preds)
|
|
1891
|
+
y_true = np.concatenate(all_true)
|
|
1892
|
+
|
|
1893
|
+
# --- Routing Logic ---
|
|
1894
|
+
if self.kind == MLTaskKeys.SEQUENCE_VALUE:
|
|
1895
|
+
config = None
|
|
1896
|
+
if format_configuration and isinstance(format_configuration, SequenceValueMetricsFormat):
|
|
1897
|
+
config = format_configuration
|
|
1898
|
+
elif format_configuration:
|
|
1899
|
+
_LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceValueMetricsFormat.")
|
|
1900
|
+
|
|
1901
|
+
sequence_to_value_metrics(y_true=y_true,
|
|
1902
|
+
y_pred=y_pred,
|
|
1903
|
+
save_dir=save_dir,
|
|
1904
|
+
config=config)
|
|
1905
|
+
|
|
1906
|
+
elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
|
|
1907
|
+
config = None
|
|
1908
|
+
if format_configuration and isinstance(format_configuration, SequenceSequenceMetricsFormat):
|
|
1909
|
+
config = format_configuration
|
|
1910
|
+
elif format_configuration:
|
|
1911
|
+
_LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceSequenceMetricsFormat.")
|
|
1912
|
+
|
|
1913
|
+
sequence_to_sequence_metrics(y_true=y_true,
|
|
1914
|
+
y_pred=y_pred,
|
|
1915
|
+
save_dir=save_dir,
|
|
1916
|
+
config=config)
|
|
1917
|
+
|
|
1918
|
+
def finalize_model_training(self,
|
|
1919
|
+
save_dir: Union[str, Path],
|
|
1920
|
+
filename: str,
|
|
1921
|
+
last_training_sequence: np.ndarray,
|
|
1922
|
+
model_checkpoint: Union[Path, Literal['latest', 'current']]):
|
|
1923
|
+
"""
|
|
1924
|
+
Saves a finalized, "inference-ready" model state to a .pth file.
|
|
1925
|
+
|
|
1926
|
+
This method saves the model's `state_dict` and the final epoch number.
|
|
1094
1927
|
|
|
1095
1928
|
Args:
|
|
1096
|
-
|
|
1929
|
+
save_dir (Union[str, Path]): The directory to save the finalized model.
|
|
1930
|
+
filename (str): The desired filename for the model (e.g., "final_model.pth").
|
|
1931
|
+
last_training_sequence (np.ndarray): The last un-scaled sequence from the training data, used for forecasting.
|
|
1932
|
+
model_checkpoint (Union[Path, Literal["latest", "current"]]):
|
|
1933
|
+
- Path: Loads the model state from a specific checkpoint file.
|
|
1934
|
+
- "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
|
|
1935
|
+
- "current": Uses the model's state as it is at the end of the `fit()` call.
|
|
1097
1936
|
"""
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
1937
|
+
# handle save path
|
|
1938
|
+
sanitized_filename = sanitize_filename(filename)
|
|
1939
|
+
if not sanitized_filename.endswith(".pth"):
|
|
1940
|
+
sanitized_filename = sanitized_filename + ".pth"
|
|
1941
|
+
|
|
1942
|
+
dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
1943
|
+
full_path = dir_path / sanitized_filename
|
|
1944
|
+
|
|
1945
|
+
# handle checkpoint
|
|
1946
|
+
if isinstance(model_checkpoint, Path):
|
|
1947
|
+
self._load_checkpoint(path=model_checkpoint)
|
|
1948
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
|
|
1949
|
+
path_to_latest = self._checkpoint_callback.best_checkpoint_path
|
|
1950
|
+
self._load_checkpoint(path_to_latest)
|
|
1951
|
+
elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
|
|
1952
|
+
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
|
|
1953
|
+
raise ValueError()
|
|
1954
|
+
elif model_checkpoint == MagicWords.CURRENT:
|
|
1955
|
+
pass
|
|
1956
|
+
else:
|
|
1957
|
+
_LOGGER.error(f"Unknown 'model_checkpoint' parameter received '{model_checkpoint}'.")
|
|
1958
|
+
|
|
1959
|
+
# --- 1. Validate the provided initial sequence ---
|
|
1960
|
+
if not isinstance(last_training_sequence, np.ndarray):
|
|
1961
|
+
_LOGGER.error(f"'last_training_sequence' must be a numpy array. Got {type(last_training_sequence)}")
|
|
1962
|
+
raise TypeError()
|
|
1963
|
+
if last_training_sequence.ndim != 1:
|
|
1964
|
+
_LOGGER.error(f"'last_training_sequence' must be a 1D array. Got {last_training_sequence.ndim} dimensions.")
|
|
1965
|
+
raise ValueError()
|
|
1966
|
+
|
|
1967
|
+
# --- 2. Derive sequence_length from the array ---
|
|
1968
|
+
sequence_length = len(last_training_sequence)
|
|
1969
|
+
if sequence_length <= 0:
|
|
1970
|
+
_LOGGER.error(f"Length of 'last_training_sequence' cannot be zero.")
|
|
1971
|
+
raise ValueError()
|
|
1972
|
+
|
|
1973
|
+
# Create finalized data
|
|
1974
|
+
finalized_data = {
|
|
1975
|
+
PyTorchCheckpointKeys.EPOCH: self.epoch,
|
|
1976
|
+
PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
|
|
1977
|
+
PyTorchCheckpointKeys.SEQUENCE_LENGTH: sequence_length,
|
|
1978
|
+
PyTorchCheckpointKeys.INITIAL_SEQUENCE: last_training_sequence
|
|
1979
|
+
}
|
|
1980
|
+
|
|
1981
|
+
torch.save(finalized_data, full_path)
|
|
1982
|
+
|
|
1983
|
+
_LOGGER.info(f"Finalized model weights saved to {full_path}.")
|
|
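A sketch of finalizing a forecaster (array values illustrative): the last un-scaled training window is stored alongside the weights so inference can seed autoregressive forecasting without touching the training data.

    import numpy as np

    last_window = np.asarray(train_series[-30:], dtype=np.float32)  # placeholder 1D window; its length becomes SEQUENCE_LENGTH
    trainer.finalize_model_training(
        save_dir="artifacts",
        filename="forecaster_final",    # ".pth" is appended automatically
        last_training_sequence=last_window,
        model_checkpoint="latest",
    )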
 
 
 def info():