dragon-ml-toolbox 14.8.0__py3-none-any.whl → 16.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dragon-ml-toolbox might be problematic.

Files changed (44)
  1. {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +9 -5
  2. dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
  3. ml_tools/ETL_cleaning.py +20 -20
  4. ml_tools/ETL_engineering.py +23 -25
  5. ml_tools/GUI_tools.py +20 -20
  6. ml_tools/MICE_imputation.py +3 -3
  7. ml_tools/ML_callbacks.py +43 -26
  8. ml_tools/ML_configuration.py +204 -11
  9. ml_tools/ML_datasetmaster.py +198 -280
  10. ml_tools/ML_evaluation.py +132 -41
  11. ml_tools/ML_evaluation_multi.py +96 -35
  12. ml_tools/ML_inference.py +249 -207
  13. ml_tools/ML_models.py +13 -102
  14. ml_tools/ML_models_advanced.py +1 -1
  15. ml_tools/ML_optimization.py +12 -12
  16. ml_tools/ML_scaler.py +11 -11
  17. ml_tools/ML_sequence_datasetmaster.py +341 -0
  18. ml_tools/ML_sequence_evaluation.py +215 -0
  19. ml_tools/ML_sequence_inference.py +391 -0
  20. ml_tools/ML_sequence_models.py +139 -0
  21. ml_tools/ML_trainer.py +1237 -354
  22. ml_tools/ML_utilities.py +1 -1
  23. ml_tools/ML_vision_datasetmaster.py +73 -67
  24. ml_tools/ML_vision_evaluation.py +26 -6
  25. ml_tools/ML_vision_inference.py +117 -140
  26. ml_tools/ML_vision_models.py +1 -1
  27. ml_tools/ML_vision_transformers.py +121 -40
  28. ml_tools/PSO_optimization.py +6 -6
  29. ml_tools/SQL.py +4 -4
  30. ml_tools/{keys.py → _keys.py} +43 -0
  31. ml_tools/_schema.py +1 -1
  32. ml_tools/ensemble_evaluation.py +1 -1
  33. ml_tools/ensemble_inference.py +7 -33
  34. ml_tools/ensemble_learning.py +1 -1
  35. ml_tools/optimization_tools.py +2 -2
  36. ml_tools/path_manager.py +5 -5
  37. ml_tools/utilities.py +1 -2
  38. dragon_ml_toolbox-14.8.0.dist-info/RECORD +0 -49
  39. ml_tools/RNN_forecast.py +0 -56
  40. ml_tools/_ML_vision_recipe.py +0 -88
  41. {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
  42. {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
  43. {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  44. {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_trainer.py CHANGED
@@ -1,80 +1,79 @@
- from typing import List, Literal, Union, Optional, Callable, Dict, Any, Tuple
+ from typing import List, Literal, Union, Optional, Callable, Dict, Any
  from pathlib import Path
  from torch.utils.data import DataLoader, Dataset
  import torch
  from torch import nn
  import numpy as np
+ from abc import ABC, abstractmethod

- from .ML_callbacks import Callback, History, TqdmProgressBar, ModelCheckpoint
+ from .path_manager import make_fullpath, sanitize_filename
+ from .ML_callbacks import _Callback, History, TqdmProgressBar, DragonModelCheckpoint, DragonEarlyStopping, DragonLRScheduler
  from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
  from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
+ from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
+ from .ML_sequence_evaluation import sequence_to_sequence_metrics, sequence_to_value_metrics
+ from .ML_configuration import (ClassificationMetricsFormat,
+                                MultiClassificationMetricsFormat,
+                                RegressionMetricsFormat,
+                                SegmentationMetricsFormat,
+                                SequenceValueMetricsFormat,
+                                SequenceSequenceMetricsFormat)
+
  from ._script_info import _script_info
- from .keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys
+ from ._keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys, MLTaskKeys, MagicWords, DragonTrainerKeys
  from ._logger import _LOGGER
- from .path_manager import make_fullpath
- from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
- from .ML_configuration import ClassificationMetricsFormat, MultiClassificationMetricsFormat


  __all__ = [
-     "MLTrainer",
-     "ObjectDetectionTrainer",
+     "DragonTrainer",
+     "DragonDetectionTrainer",
+     "DragonSequenceTrainer"
  ]

-
- class MLTrainer:
-     def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
-                  kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification", "segmentation"],
-                  criterion: nn.Module, optimizer: torch.optim.Optimizer,
-                  device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
-         """
-         Automates the training process of a PyTorch Model.
-
-         Built-in Callbacks: `History`, `TqdmProgressBar`
-
-         Args:
-             model (nn.Module): The PyTorch model to train.
-             train_dataset (Dataset): The training dataset.
-             test_dataset (Dataset): The testing/validation dataset.
-             kind (str): Can be 'regression', 'classification', 'multi_target_regression', 'multi_label_classification', or 'segmentation'.
-             criterion (nn.Module): The loss function.
-             optimizer (torch.optim.Optimizer): The optimizer.
-             device (str): The device to run training on ('cpu', 'cuda', 'mps').
-             dataloader_workers (int): Subprocesses for data loading.
-             callbacks (List[Callback] | None): A list of callbacks to use during training.
-
-         Note:
-             - For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`.
-
-             - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice.
+ class _BaseDragonTrainer(ABC):
+     """
+     Abstract base class for Dragon Trainers.

-             - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem.
-
-             - For **segmentation** tasks, `nn.CrossEntropyLoss` (for multi-class) or `nn.BCEWithLogitsLoss` (for binary) are common.
-         """
-         if kind not in ["regression", "classification", "multi_target_regression", "multi_label_classification", "segmentation"]:
-             raise ValueError(f"'{kind}' is not a valid task type.")
+     Handles the common training loop orchestration, checkpointing, callback
+     management, and device handling. Subclasses must implement the
+     task-specific logic (dataloaders, train/val steps, evaluation).
+     """
+     def __init__(self,
+                  model: nn.Module,
+                  optimizer: torch.optim.Optimizer,
+                  device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                  dataloader_workers: int = 2,
+                  checkpoint_callback: Optional[DragonModelCheckpoint] = None,
+                  early_stopping_callback: Optional[DragonEarlyStopping] = None,
+                  lr_scheduler_callback: Optional[DragonLRScheduler] = None,
+                  extra_callbacks: Optional[List[_Callback]] = None):

          self.model = model
-         self.train_dataset = train_dataset
-         self.test_dataset = test_dataset
-         self.kind = kind
-         self.criterion = criterion
          self.optimizer = optimizer
          self.scheduler = None
          self.device = self._validate_device(device)
          self.dataloader_workers = dataloader_workers

-         # Callback handler - History and TqdmProgressBar are added by default
+         # Callback handler
          default_callbacks = [History(), TqdmProgressBar()]
-         user_callbacks = callbacks if callbacks is not None else []
+
+         self._checkpoint_callback = None
+         if checkpoint_callback:
+             default_callbacks.append(checkpoint_callback)
+             self._checkpoint_callback = checkpoint_callback
+         if early_stopping_callback:
+             default_callbacks.append(early_stopping_callback)
+         if lr_scheduler_callback:
+             default_callbacks.append(lr_scheduler_callback)
+
+         user_callbacks = extra_callbacks if extra_callbacks is not None else []
          self.callbacks = default_callbacks + user_callbacks
          self._set_trainer_on_callbacks()

          # Internal state
-         self.train_loader = None
-         self.test_loader = None
-         self.history = {}
+         self.train_loader: Optional[DataLoader] = None
+         self.validation_loader: Optional[DataLoader] = None
+         self.history: Dict[str, List[Any]] = {}
          self.epoch = 0
          self.epochs = 0 # Total epochs for the fit run
          self.start_epoch = 1
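The constructor above promotes the checkpoint, early-stopping, and LR-scheduler callbacks to named arguments and merges them into the default callback list, replacing the single `callbacks` parameter of 14.8.0. A minimal standalone sketch of that wiring pattern, using hypothetical stand-in callback classes rather than the library's own:

```python
from typing import List, Optional

class Callback:
    """Stand-in for the library's _Callback base; hypothetical."""
    def set_trainer(self, trainer) -> None:
        self.trainer = trainer

class History(Callback): pass
class TqdmProgressBar(Callback): pass

def build_callbacks(checkpoint: Optional[Callback] = None,
                    early_stopping: Optional[Callback] = None,
                    lr_scheduler: Optional[Callback] = None,
                    extra: Optional[List[Callback]] = None) -> List[Callback]:
    # History and TqdmProgressBar are always installed, then the three
    # first-class callbacks (when given), then any user extras.
    callbacks: List[Callback] = [History(), TqdmProgressBar()]
    for cb in (checkpoint, early_stopping, lr_scheduler):
        if cb is not None:
            callbacks.append(cb)
    return callbacks + (extra or [])

print(len(build_callbacks()))  # 2: just the defaults
```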
@@ -97,32 +96,10 @@ class MLTrainer:
          for callback in self.callbacks:
              callback.set_trainer(self)

-     def _create_dataloaders(self, batch_size: int, shuffle: bool):
-         """Initializes the DataLoaders."""
-         # Ensure stability on MPS devices by setting num_workers to 0
-         loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
-
-         self.train_loader = DataLoader(
-             dataset=self.train_dataset,
-             batch_size=batch_size,
-             shuffle=shuffle,
-             num_workers=loader_workers,
-             pin_memory=("cuda" in self.device.type),
-             drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
-         )
-
-         self.test_loader = DataLoader(
-             dataset=self.test_dataset,
-             batch_size=batch_size,
-             shuffle=False,
-             num_workers=loader_workers,
-             pin_memory=("cuda" in self.device.type)
-         )
-
      def _load_checkpoint(self, path: Union[str, Path]):
          """Loads a training checkpoint to resume training."""
          p = make_fullpath(path, enforce="file")
-         _LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
+         _LOGGER.info(f"Loading checkpoint from '{p.name}'...")

          try:
              checkpoint = torch.load(p, map_location=self.device)
@@ -133,7 +110,16 @@ class MLTrainer:

              self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
              self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
-             self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
+             self.epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0)
+             self.start_epoch = self.epoch + 1 # Resume on the *next* epoch
+
+             # --- Load History ---
+             if PyTorchCheckpointKeys.HISTORY in checkpoint:
+                 self.history = checkpoint[PyTorchCheckpointKeys.HISTORY]
+                 _LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
+             else:
+                 _LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
+                 self.history = {} # Ensure it's at least an empty dict

              # --- Scheduler State Loading Logic ---
              scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
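The resume logic now also restores the loss history, so the final loss plot covers the whole run rather than only the resumed portion. A rough sketch of the checkpoint layout this implies; the actual key strings live behind `PyTorchCheckpointKeys` in `_keys.py`, so the literal names below are an assumption:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters())

# Hypothetical key names; dragon-ml-toolbox reads them via PyTorchCheckpointKeys.
checkpoint = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": 12,
    "history": {"train_loss": [0.9, 0.7], "val_loss": [1.0, 0.8]},
}
torch.save(checkpoint, "checkpoint.pth")

# On resume: epoch 12 was completed, so training restarts at epoch 13,
# and the restored history keeps the losses from the first 12 epochs.
restored = torch.load("checkpoint.pth")
start_epoch = restored.get("epoch", 0) + 1
```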
@@ -163,7 +149,7 @@ class MLTrainer:

              # Restore callback states
              for cb in self.callbacks:
-                 if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
+                 if isinstance(cb, DragonModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
                      cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
                      _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")

@@ -174,7 +160,8 @@ class MLTrainer:
              raise

      def fit(self,
-             epochs: int = 10,
+             save_dir: Union[str,Path],
+             epochs: int = 100,
              batch_size: int = 10,
              shuffle: bool = True,
              resume_from_checkpoint: Optional[Union[str, Path]] = None):
@@ -184,21 +171,15 @@ class MLTrainer:
          Returns the "History" callback dictionary.

          Args:
+             save_dir (str | Path): Directory to save the loss plot.
              epochs (int): The total number of epochs to train for.
              batch_size (int): The number of samples per batch.
              shuffle (bool): Whether to shuffle the training data at each epoch.
              resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
-
-         Note:
-             For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
-             automatically aligns the model's output tensor with the target tensor's
-             shape using `output.view_as(target)`. This handles the common case
-             where a model outputs a shape of `[batch_size, 1]` and the target has a
-             shape of `[batch_size]`.
          """
          self.epochs = epochs
          self._batch_size = batch_size
-         self._create_dataloaders(self._batch_size, shuffle)
+         self._create_dataloaders(self._batch_size, shuffle) # type: ignore
          self.model.to(self.device)

          if resume_from_checkpoint:
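`fit()` now takes a required `save_dir` as its first argument and writes the loss plot there when training ends. An illustrative call against the new signature; the trainer instance and paths are placeholders:

```python
# `trainer` is a DragonTrainer (or subclass) built as shown in this diff.
history = trainer.fit(
    save_dir="outputs/run_01",   # loss curves are written here via plot_losses()
    epochs=100,                  # new default, up from 10
    batch_size=32,
    shuffle=True,
    resume_from_checkpoint=None, # or a path to a training checkpoint (.pth)
)
# `history` maps log keys to per-epoch values; the exact key strings
# come from PyTorchLogKeys and are not spelled out in this diff.
```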
@@ -209,11 +190,19 @@ class MLTrainer:

          self._callbacks_hook('on_train_begin')

+         if not self.train_loader:
+             _LOGGER.error("Train loader is not initialized.")
+             raise ValueError()
+
+         if not self.validation_loader:
+             _LOGGER.error("Validation loader is not initialized.")
+             raise ValueError()
+
          for epoch in range(self.start_epoch, self.epochs + 1):
              self.epoch = epoch
-             epoch_logs = {}
+             epoch_logs: Dict[str, Any] = {}
              self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
-
+
              train_logs = self._train_step()
              epoch_logs.update(train_logs)

@@ -227,11 +216,185 @@ class MLTrainer:
                  break

          self._callbacks_hook('on_train_end')
+
+         # Training History
+         plot_losses(self.history, save_dir=save_dir)
+
          return self.history
+
+     def _callbacks_hook(self, method_name: str, *args, **kwargs):
+         """Calls the specified method on all callbacks."""
+         for callback in self.callbacks:
+             method = getattr(callback, method_name)
+             method(*args, **kwargs)
+
+     def to_cpu(self):
+         """
+         Moves the model to the CPU and updates the trainer's device setting.
+
+         This is useful for running operations that require the CPU.
+         """
+         self.device = torch.device('cpu')
+         self.model.to(self.device)
+         _LOGGER.info("Trainer and model moved to CPU.")
+
+     def to_device(self, device: str):
+         """
+         Moves the model to the specified device and updates the trainer's device setting.
+
+         Args:
+             device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+         """
+         self.device = self._validate_device(device)
+         self.model.to(self.device)
+         _LOGGER.info(f"Trainer and model moved to {self.device}.")
+
+     # --- Abstract Methods ---
+     # These must be implemented by subclasses
+
+     @abstractmethod
+     def _create_dataloaders(self, batch_size: int, shuffle: bool):
+         """Initializes the DataLoaders."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def _train_step(self) -> Dict[str, float]:
+         """Runs a single training epoch."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def _validation_step(self) -> Dict[str, float]:
+         """Runs a single validation epoch."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def evaluate(self, *args, **kwargs):
+         """Runs the full model evaluation."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def _evaluate(self, *args, **kwargs):
+         """Internal evaluation helper."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def finalize_model_training(self, *args, **kwargs):
+         """Saves the finalized model for inference."""
+         raise NotImplementedError
+
+
+ # --- DragonTrainer ----
+ class DragonTrainer(_BaseDragonTrainer):
+     def __init__(self,
+                  model: nn.Module,
+                  train_dataset: Dataset,
+                  validation_dataset: Dataset,
+                  kind: Literal["regression", "binary classification", "multiclass classification",
+                                "multitarget regression", "multilabel binary classification",
+                                "binary segmentation", "multiclass segmentation", "binary image classification", "multiclass image classification"],
+                  optimizer: torch.optim.Optimizer,
+                  device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                  checkpoint_callback: Optional[DragonModelCheckpoint],
+                  early_stopping_callback: Optional[DragonEarlyStopping],
+                  lr_scheduler_callback: Optional[DragonLRScheduler],
+                  extra_callbacks: Optional[List[_Callback]] = None,
+                  criterion: Union[nn.Module,Literal["auto"]] = "auto",
+                  dataloader_workers: int = 2):
+         """
+         Automates the training process of a PyTorch Model.
+
+         Built-in Callbacks: `History`, `TqdmProgressBar`
+
+         Args:
+             model (nn.Module): The PyTorch model to train.
+             train_dataset (Dataset): The training dataset.
+             validation_dataset (Dataset): The validation dataset.
+             kind (str): Used to redirect to the correct process.
+             criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task.
+             optimizer (torch.optim.Optimizer): The optimizer.
+             device (str): The device to run training on ('cpu', 'cuda', 'mps').
+             dataloader_workers (int): Subprocesses for data loading.
+             extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
+
+         Note:
+             - For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`. The model should output as many logits as existing targets.
+
+             - For **single-label, binary classification**, `nn.BCEWithLogitsLoss` is the standard choice. The model should output a single logit.

+             - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice. The model should output as many logits as existing classes.
+
+             - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem. The model should output 1 logit per binary target.
+
+             - For **binary segmentation** tasks, `nn.BCEWithLogitsLoss` is common. The model should output a single logit.
+
+             - For **multiclass segmentation** tasks, `nn.CrossEntropyLoss` is the standard. The model should output as many logits as existing classes.
+         """
+         # Call the base class constructor with common parameters
+         super().__init__(
+             model=model,
+             optimizer=optimizer,
+             device=device,
+             dataloader_workers=dataloader_workers,
+             checkpoint_callback=checkpoint_callback,
+             early_stopping_callback=early_stopping_callback,
+             lr_scheduler_callback=lr_scheduler_callback,
+             extra_callbacks=extra_callbacks
+         )
+
+         if kind not in [MLTaskKeys.REGRESSION,
+                         MLTaskKeys.BINARY_CLASSIFICATION,
+                         MLTaskKeys.MULTICLASS_CLASSIFICATION,
+                         MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION,
+                         MLTaskKeys.MULTITARGET_REGRESSION,
+                         MLTaskKeys.BINARY_SEGMENTATION,
+                         MLTaskKeys.MULTICLASS_SEGMENTATION,
+                         MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
+                         MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+             raise ValueError(f"'{kind}' is not a valid task type.")
+
+         self.train_dataset = train_dataset
+         self.validation_dataset = validation_dataset
+         self.kind = kind
+         self._classification_threshold: float = 0.5
+
+         # loss function
+         if criterion == "auto":
+             if kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
+                 self.criterion = nn.MSELoss()
+             elif kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
+                 self.criterion = nn.BCEWithLogitsLoss()
+             elif kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
+                 self.criterion = nn.CrossEntropyLoss()
+         else:
+             self.criterion = criterion
+
+     def _create_dataloaders(self, batch_size: int, shuffle: bool):
+         """Initializes the DataLoaders."""
+         # Ensure stability on MPS devices by setting num_workers to 0
+         loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+         self.train_loader = DataLoader(
+             dataset=self.train_dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             num_workers=loader_workers,
+             pin_memory=("cuda" in self.device.type),
+             drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
+         )
+
+         self.validation_loader = DataLoader(
+             dataset=self.validation_dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=loader_workers,
+             pin_memory=("cuda" in self.device.type)
+         )
+
      def _train_step(self):
          self.model.train()
          running_loss = 0.0
+         total_samples = 0
+
          for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
              # Create a log dictionary for the batch
              batch_logs = {
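Taken together, the renamed class, the richer task-kind vocabulary, and the `criterion="auto"` default change how a trainer is constructed. An illustrative build against the signature above; the toy model and data are placeholders, and passing `None` for the three callback slots assumes you do not want checkpointing, early stopping, or LR scheduling:

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset
from ml_tools.ML_trainer import DragonTrainer

# Toy tabular data: 16 features, binary targets.
X, y = torch.randn(64, 16), torch.randint(0, 2, (64,))
train_ds, val_ds = TensorDataset(X[:48], y[:48]), TensorDataset(X[48:], y[48:])

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))

trainer = DragonTrainer(
    model=model,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    kind="binary classification",   # one of the new task-kind strings
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    device="cpu",
    checkpoint_callback=None,       # or a DragonModelCheckpoint instance
    early_stopping_callback=None,   # or a DragonEarlyStopping instance
    lr_scheduler_callback=None,     # or a DragonLRScheduler instance
    criterion="auto",               # resolves to nn.BCEWithLogitsLoss() for this kind
)
```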
@@ -245,9 +408,21 @@ class MLTrainer:

              output = self.model(features)

-             # Apply shape correction only for single-target regression
-             if self.kind == "regression":
-                 output = output.view_as(target)
+             # --- Label Type/Shape Correction ---
+             # Cast target to float for BCE-based losses
+             if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
+                 target = target.float()
+
+             # Reshape output to match target for single-logit tasks
+             if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                 # If model outputs [N, 1] and target is [N], squeeze output
+                 if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                     output = output.squeeze(1)
+
+             if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                 # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
+                 if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
+                     output = output.squeeze(1)

              loss = self.criterion(output, target)

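The old blanket `output.view_as(target)` trick is replaced by explicit, guarded squeezes plus a float cast for BCE-style losses. A quick standalone demonstration of why both steps matter for `nn.BCEWithLogitsLoss`:

```python
import torch
from torch import nn

criterion = nn.BCEWithLogitsLoss()
output = torch.randn(8, 1)           # model emits one logit per sample: [N, 1]
target = torch.randint(0, 2, (8,))   # dataset yields integer labels: [N]

target = target.float()              # BCE losses require float targets
if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
    output = output.squeeze(1)       # [N, 1] -> [N], so shapes match exactly

loss = criterion(output, target)     # no silent broadcasting surprises
```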
@@ -256,34 +431,58 @@ class MLTrainer:

              # Calculate batch loss and update running loss for the epoch
              batch_loss = loss.item()
-             running_loss += batch_loss * features.size(0)
+             batch_size = features.size(0)
+             running_loss += batch_loss * batch_size # Accumulate total loss
+             total_samples += batch_size # total samples

              # Add the batch loss to the logs and call the end-of-batch hook
              batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
              self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+         if total_samples == 0:
+             _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
+             return {PyTorchLogKeys.TRAIN_LOSS: 0.0}

-         return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore
+         return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples} # type: ignore

      def _validation_step(self):
          self.model.eval()
          running_loss = 0.0
+
          with torch.no_grad():
-             for features, target in self.test_loader: # type: ignore
+             for features, target in self.validation_loader: # type: ignore
                  features, target = features.to(self.device), target.to(self.device)

                  output = self.model(features)
-                 # Apply shape correction only for single-target regression
-                 if self.kind == "regression":
-                     output = output.view_as(target)
+
+                 # --- Label Type/Shape Correction ---
+                 # Cast target to float for BCE-based losses
+                 if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
+                     target = target.float()
+
+                 # Reshape output to match target for single-logit tasks
+                 if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                     # If model outputs [N, 1] and target is [N], squeeze output
+                     if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                         output = output.squeeze(1)
+
+                 if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                     # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
+                     if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
+                         output = output.squeeze(1)

                  loss = self.criterion(output, target)

                  running_loss += loss.item() * features.size(0)
+
+         if not self.validation_loader.dataset: # type: ignore
+             _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+             return {PyTorchLogKeys.VAL_LOSS: 0.0}

-         logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)} # type: ignore
+         logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)} # type: ignore
          return logs

-     def _predict_for_eval(self, dataloader: DataLoader, classification_threshold: float = 0.5):
+     def _predict_for_eval(self, dataloader: DataLoader):
          """
          Private method to yield model predictions batch by batch for evaluation.

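The hunk above changes the train loss denominator from `len(dataset)` to the number of samples actually processed. With `drop_last=True` those counts differ, which previously understated the reported loss. A small worked example:

```python
# 100 samples, batch_size=32, drop_last=True -> 3 full batches are used (96 samples).
# Suppose each batch's mean loss is 0.5:
running_loss = 0.5 * 32 + 0.5 * 32 + 0.5 * 32   # 48.0, loss summed per sample
print(running_loss / 96)    # 0.50 -> correct per-sample mean (new behavior)
print(running_loss / 100)   # 0.48 -> biased low when dividing by len(dataset) (old behavior)
```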
@@ -294,6 +493,7 @@ class MLTrainer:
          """
          self.model.eval()
          self.model.to(self.device)
+
          with torch.no_grad():
              for features, target in dataloader:
                  features = features.to(self.device)
@@ -303,25 +503,64 @@ class MLTrainer:
                  y_prob_batch = None
                  y_true_batch = None

-                 if self.kind in ["regression", "multi_target_regression"]:
+                 if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
                      y_pred_batch = output.numpy()
                      y_true_batch = target.numpy()
+
+                 elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                     # Assumes model output is [N, 1] (a single logit)
+                     # Squeeze output from [N, 1] to [N] if necessary
+                     if output.ndim == 2 and output.shape[1] == 1:
+                         output = output.squeeze(1)
+
+                     probs_pos = torch.sigmoid(output) # Probability of positive class
+                     preds = (probs_pos >= self._classification_threshold).int()
+                     y_pred_batch = preds.numpy()
+                     # For metrics (like ROC AUC), we often need probs for *both* classes
+                     # Create an [N, 2] array: [prob_class_0, prob_class_1]
+                     probs_neg = 1.0 - probs_pos
+                     y_prob_batch = torch.stack([probs_neg, probs_pos], dim=1).numpy()
+                     y_true_batch = target.numpy()

-                 elif self.kind == "classification":
+                 elif self.kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+                     num_classes = output.shape[1]
+                     if num_classes < 3:
+                         # Optional: warn the user they are using the wrong kind
+                         wrong_class = MLTaskKeys.MULTICLASS_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
+                         recommended_class = MLTaskKeys.BINARY_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.BINARY_IMAGE_CLASSIFICATION
+                         _LOGGER.warning(f"'{wrong_class}' kind used with {num_classes} classes. Consider using '{recommended_class}' instead.")
+
                      probs = torch.softmax(output, dim=1)
                      preds = torch.argmax(probs, dim=1)
                      y_pred_batch = preds.numpy()
                      y_prob_batch = probs.numpy()
                      y_true_batch = target.numpy()

-                 elif self.kind == "multi_label_classification":
+                 elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
                      probs = torch.sigmoid(output)
-                     preds = (probs >= classification_threshold).int()
+                     preds = (probs >= self._classification_threshold).int()
                      y_pred_batch = preds.numpy()
                      y_prob_batch = probs.numpy()
                      y_true_batch = target.numpy()
+
+                 elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                     # Assumes model output is [N, 1, H, W] (logits for positive class)
+                     probs_pos = torch.sigmoid(output) # Shape [N, 1, H, W]
+                     preds = (probs_pos >= self._classification_threshold).int() # Shape [N, 1, H, W]
+
+                     # Squeeze preds to [N, H, W] (class indices 0 or 1)
+                     y_pred_batch = preds.squeeze(1).numpy()
+
+                     # Create [N, 2, H, W] probs for consistency
+                     probs_neg = 1.0 - probs_pos
+                     y_prob_batch = torch.cat([probs_neg, probs_pos], dim=1).numpy()
+
+                     # Handle target shape [N, 1, H, W] -> [N, H, W]
+                     if target.ndim == 4 and target.shape[1] == 1:
+                         target = target.squeeze(1)
+                     y_true_batch = target.numpy()

-                 elif self.kind == "segmentation":
+                 elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION:
                      # output shape [N, C, H, W]
                      probs = torch.softmax(output, dim=1)
                      preds = torch.argmax(probs, dim=1) # shape [N, H, W]
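For binary tasks the evaluator now fabricates a two-column probability array so downstream metrics (ROC AUC, calibration) can treat binary and multiclass outputs uniformly. A standalone illustration of that transform:

```python
import torch

logits = torch.tensor([1.2, -0.4, 0.0])   # one logit per sample, shape [N]
probs_pos = torch.sigmoid(logits)         # P(class 1)
probs_neg = 1.0 - probs_pos               # P(class 0)

# Stack into [N, 2]: column 0 = P(class 0), column 1 = P(class 1),
# matching what softmax over two logits would have produced.
y_prob = torch.stack([probs_neg, probs_pos], dim=1)
y_pred = (probs_pos >= 0.5).int()         # the threshold is configurable in 16.0.0
print(y_prob.shape, y_pred.tolist())      # torch.Size([3, 2]) [1, 0, 1]
```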
@@ -334,26 +573,161 @@ class MLTrainer:
                      y_true_batch = target.numpy()

                  yield y_pred_batch, y_prob_batch, y_true_batch
-
+
      def evaluate(self,
                   save_dir: Union[str, Path],
-                  data: Optional[Union[DataLoader, Dataset]] = None,
-                  format_configuration: Optional[Union[ClassificationMetricsFormat, MultiClassificationMetricsFormat]]=None):
+                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  classification_threshold: Optional[float] = None,
+                  test_data: Optional[Union[DataLoader, Dataset]] = None,
+                  val_format_configuration: Optional[Union[ClassificationMetricsFormat,
+                                                           MultiClassificationMetricsFormat,
+                                                           RegressionMetricsFormat,
+                                                           SegmentationMetricsFormat]]=None,
+                  test_format_configuration: Optional[Union[ClassificationMetricsFormat,
+                                                            MultiClassificationMetricsFormat,
+                                                            RegressionMetricsFormat,
+                                                            SegmentationMetricsFormat]]=None):
          """
          Evaluates the model, routing to the correct evaluation function based on task `kind`.

          Args:
+             model_checkpoint ('auto' | Path | None):
+                 - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+                 - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+                 - If 'current', use the current state of the trained model up the latest trained epoch.
              save_dir (str | Path): Directory to save all reports and plots.
-             data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
+             classification_threshold (float | None): Used for tasks using a binary approach (binary classification, binary segmentation, multilabel binary classification)
+             test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
+             val_format_configuration: Optional configuration for metric format output for the validation set.
+             test_format_configuration: Optional configuration for metric format output for the test set.
+         """
+         # Validate model checkpoint
+         if isinstance(model_checkpoint, Path):
+             checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+         elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+             checkpoint_validated = model_checkpoint
+         else:
+             _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
+             raise ValueError()
+
+         # Validate classification threshold
+         if self.kind not in MLTaskKeys.ALL_BINARY_TASKS:
+             # dummy value for tasks that do not need it
+             threshold_validated = 0.5
+         elif classification_threshold is None:
+             # it should have been provided for binary tasks
+             _LOGGER.error(f"The classification threshold must be provided for '{self.kind}'.")
+             raise ValueError()
+         elif classification_threshold <= 0.0 or classification_threshold >= 1.0:
+             # Invalid float
+             _LOGGER.error(f"A classification threshold of {classification_threshold} is invalid. Must be in the range (0.0 - 1.0).")
+             raise ValueError()
+         else:
+             threshold_validated = classification_threshold
+
+         # Validate val configuration
+         if val_format_configuration is not None:
+             if not isinstance(val_format_configuration, (ClassificationMetricsFormat,
+                                                          MultiClassificationMetricsFormat,
+                                                          RegressionMetricsFormat,
+                                                          SegmentationMetricsFormat)):
+                 _LOGGER.error(f"Invalid 'format_configuration': '{type(val_format_configuration)}'.")
+                 raise ValueError()
+             else:
+                 val_configuration_validated = val_format_configuration
+         else: # config is None
+             val_configuration_validated = None
+
+         # Validate directory
+         save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+         # Validate test data and dispatch
+         if test_data is not None:
+             if not isinstance(test_data, (DataLoader, Dataset)):
+                 _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                 raise ValueError()
+             test_data_validated = test_data
+
+             validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+             test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+             # Dispatch validation set
+             _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+             self._evaluate(save_dir=validation_metrics_path,
+                            model_checkpoint=checkpoint_validated,
+                            classification_threshold=threshold_validated,
+                            data=None,
+                            format_configuration=val_configuration_validated)
+
+             # Validate test configuration
+             if test_format_configuration is not None:
+                 if not isinstance(test_format_configuration, (ClassificationMetricsFormat,
+                                                               MultiClassificationMetricsFormat,
+                                                               RegressionMetricsFormat,
+                                                               SegmentationMetricsFormat)):
+                     warning_message_type = f"Invalid test_format_configuration': '{type(val_format_configuration)}'."
+                     if val_configuration_validated is not None:
+                         warning_message_type += " 'val_format_configuration' will be used for the test set metrics output."
+                         test_configuration_validated = val_configuration_validated
+                     else:
+                         warning_message_type += " Using default format."
+                         test_configuration_validated = None
+                     _LOGGER.warning(warning_message_type)
+                 else:
+                     test_configuration_validated = test_format_configuration
+             else: #config is None
+                 test_configuration_validated = None
+
+             # Dispatch test set
+             _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+             self._evaluate(save_dir=test_metrics_path,
+                            model_checkpoint="current",
+                            classification_threshold=threshold_validated,
+                            data=test_data_validated,
+                            format_configuration=test_configuration_validated)
+         else:
+             # Dispatch validation set
+             _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+             self._evaluate(save_dir=save_path,
+                            model_checkpoint=checkpoint_validated,
+                            classification_threshold=threshold_validated,
+                            data=None,
+                            format_configuration=val_configuration_validated)
+
+     def _evaluate(self,
+                   save_dir: Union[str, Path],
+                   model_checkpoint: Union[Path, Literal["latest", "current"]],
+                   classification_threshold: float,
+                   data: Optional[Union[DataLoader, Dataset]],
+                   format_configuration: Optional[Union[ClassificationMetricsFormat,
+                                                        MultiClassificationMetricsFormat,
+                                                        RegressionMetricsFormat,
+                                                        SegmentationMetricsFormat]]):
+         """
+         Changed to a private helper function.
          """
          dataset_for_names = None
          eval_loader = None
-
+
+         # set threshold
+         self._classification_threshold = classification_threshold
+
+         # load model checkpoint
+         if isinstance(model_checkpoint, Path):
+             self._load_checkpoint(path=model_checkpoint)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+             path_to_latest = self._checkpoint_callback.best_checkpoint_path
+             self._load_checkpoint(path_to_latest)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+             _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+             raise ValueError()
+
+         # Dataloader
          if isinstance(data, DataLoader):
              eval_loader = data
              # Try to get the dataset from the loader for fetching target names
              if hasattr(data, 'dataset'):
-                 dataset_for_names = data.dataset
+                 dataset_for_names = data.dataset # type: ignore
          elif isinstance(data, Dataset):
              # Create a new loader from the provided dataset
              eval_loader = DataLoader(data,
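The new `evaluate()` contract in one place: choose which weights to score (`"latest"` checkpoint, `"current"` in-memory state, or an explicit `Path`), supply a threshold for binary-style tasks, and optionally pass held-out test data so validation and test metrics land in separate subdirectories. An illustrative call; the trainer, dataset, and paths are placeholders:

```python
trainer.evaluate(
    save_dir="outputs/run_01/metrics",
    model_checkpoint="latest",        # requires a DragonModelCheckpoint callback
    classification_threshold=0.5,     # mandatory for binary-style kinds
    test_data=test_ds,                # Dataset or DataLoader; optional
    val_format_configuration=None,    # or a *MetricsFormat matching the task
    test_format_configuration=None,
)
```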
@@ -363,17 +737,17 @@ class MLTrainer:
                                       pin_memory=(self.device.type == "cuda"))
              dataset_for_names = data
          else: # data is None, use the trainer's default test dataset
-             if self.test_dataset is None:
+             if self.validation_dataset is None:
                  _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
                  raise ValueError()
              # Create a fresh DataLoader from the test_dataset
-             eval_loader = DataLoader(self.test_dataset,
+             eval_loader = DataLoader(self.validation_dataset,
                                       batch_size=self._batch_size,
                                       shuffle=False,
                                       num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                       pin_memory=(self.device.type == "cuda"))

-             dataset_for_names = self.test_dataset
+             dataset_for_names = self.validation_dataset

          if eval_loader is None:
              _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
@@ -396,34 +770,55 @@ class MLTrainer:
          y_prob = np.concatenate(all_probs) if all_probs else None

          # --- Routing Logic ---
-         if self.kind == "regression":
-             regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir)
-
-         elif self.kind == "classification":
-             # Parse configuration
+         if self.kind == MLTaskKeys.REGRESSION:
+             # Check configuration
+             config = None
+             if format_configuration and isinstance(format_configuration, RegressionMetricsFormat):
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected RegressionMetricsFormat.")
+
+             regression_metrics(y_true=y_true.flatten(),
+                                y_pred=y_pred.flatten(),
+                                save_dir=save_dir,
+                                config=config)
+
+         elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+             # Check configuration
+             config = None
              if format_configuration and isinstance(format_configuration, ClassificationMetricsFormat):
-                 classification_metrics(save_dir=save_dir,
-                                        y_true=y_true,
-                                        y_pred=y_pred,
-                                        y_prob=y_prob,
-                                        cmap=format_configuration.cmap,
-                                        class_map=format_configuration.class_map,
-                                        ROC_PR_line=format_configuration.ROC_PR_line,
-                                        calibration_bins=format_configuration.calibration_bins,
-                                        font_size=format_configuration.font_size)
-             else:
-                 classification_metrics(save_dir, y_true, y_pred, y_prob)
-
-         elif self.kind == "multi_target_regression":
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected ClassificationMetricsFormat.")
+
+             classification_metrics(save_dir=save_dir,
+                                    y_true=y_true,
+                                    y_pred=y_pred,
+                                    y_prob=y_prob,
+                                    config=config)
+
+         elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION:
              try:
                  target_names = dataset_for_names.target_names # type: ignore
              except AttributeError:
                  num_targets = y_true.shape[1]
                  target_names = [f"target_{i}" for i in range(num_targets)]
                  _LOGGER.warning(f"Dataset has no 'target_names' attribute. Using generic names.")
-             multi_target_regression_metrics(y_true, y_pred, target_names, save_dir)
+
+             # Check configuration
+             config = None
+             if format_configuration and isinstance(format_configuration, RegressionMetricsFormat):
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected RegressionMetricsFormat.")
+
+             multi_target_regression_metrics(y_true=y_true,
+                                             y_pred=y_pred,
+                                             target_names=target_names,
+                                             save_dir=save_dir,
+                                             config=config)

-         elif self.kind == "multi_label_classification":
+         elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
              try:
                  target_names = dataset_for_names.target_names # type: ignore
              except AttributeError:
@@ -435,19 +830,21 @@ class MLTrainer:
                  _LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
                  return

+             # Check configuration
+             config = None
              if format_configuration and isinstance(format_configuration, MultiClassificationMetricsFormat):
-                 multi_label_classification_metrics(y_true=y_true,
-                                                    y_prob=y_prob,
-                                                    target_names=target_names,
-                                                    save_dir=save_dir,
-                                                    threshold=format_configuration.threshold,
-                                                    ROC_PR_line=format_configuration.ROC_PR_line,
-                                                    cmap=format_configuration.cmap,
-                                                    font_size=format_configuration.font_size)
-             else:
-                 multi_label_classification_metrics(y_true, y_prob, target_names, save_dir)
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected MultiClassificationMetricsFormat.")
+
+             multi_label_classification_metrics(y_true=y_true,
+                                                y_pred=y_pred,
+                                                y_prob=y_prob,
+                                                target_names=target_names,
+                                                save_dir=save_dir,
+                                                config=config)

-         elif self.kind == "segmentation":
+         elif self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
              class_names = None
              try:
                  # Try to get 'classes' from VisionDatasetMaker
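Across every branch the per-plot keyword arguments are collapsed into a single `config` object. The removed 14.8.0 call site shows which fields the classification format carries (`cmap`, `class_map`, `ROC_PR_line`, `calibration_bins`, `font_size`). A hedged sketch of building one, assuming those attribute names map directly onto the constructor; verify against `ML_configuration.py`:

```python
from ml_tools.ML_configuration import ClassificationMetricsFormat

# Assumption: a dataclass-style constructor mirroring the attribute names
# read by the removed 14.8.0 call site.
fmt = ClassificationMetricsFormat(
    cmap="Blues",                              # confusion-matrix colormap
    class_map={"negative": 0, "positive": 1},  # label -> integer mapping
    ROC_PR_line="darkorange",                  # styling for ROC/PR curves
    calibration_bins=15,
    font_size=12,
)
```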
@@ -469,10 +866,18 @@ class MLTrainer:
                  class_names = [f"Class {i}" for i in labels]
                  _LOGGER.warning(f"Dataset has no 'classes' or 'target_names' attribute. Using generic names.")

-             segmentation_metrics(y_true, y_pred, save_dir, class_names=class_names)
-
-         # print("\n--- Training History ---")
-         plot_losses(self.history, save_dir=save_dir)
+             # Check configuration
+             config = None
+             if format_configuration and isinstance(format_configuration, SegmentationMetricsFormat):
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected SegmentationMetricsFormat.")
+
+             segmentation_metrics(y_true=y_true,
+                                  y_pred=y_pred,
+                                  save_dir=save_dir,
+                                  class_names=class_names,
+                                  config=config)

      def explain(self,
                  save_dir: Union[str,Path],
@@ -537,7 +942,7 @@ class MLTrainer:
              return

          # 2. Determine target dataset and get explanation instances
-         target_dataset = explain_dataset if explain_dataset is not None else self.test_dataset
+         target_dataset = explain_dataset if explain_dataset is not None else self.validation_dataset
          instances_to_explain = _get_random_sample(target_dataset, n_samples)
          if instances_to_explain is None:
              _LOGGER.error("Explanation dataset is empty or invalid. Skipping SHAP analysis.")
@@ -556,7 +961,7 @@ class MLTrainer:
          self.model.to(self.device)

          # 3. Call the plotting function
-         if self.kind in ["regression", "classification"]:
+         if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
              shap_summary_plot(
                  model=self.model,
                  background_data=background_data,
@@ -566,7 +971,7 @@ class MLTrainer:
                  explainer_type=explainer_type,
                  device=self.device
              )
-         elif self.kind in ["multi_target_regression", "multi_label_classification"]:
+         elif self.kind in [MLTaskKeys.MULTITARGET_REGRESSION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
              # try to get target names
              if target_names is None:
                  target_names = []
@@ -640,13 +1045,11 @@ class MLTrainer:

          # --- Step 1: Check if the model supports this explanation ---
          if not getattr(self.model, 'has_interpretable_attention', False):
-             _LOGGER.warning(
-                 "Model is not flagged for interpretable attention analysis. Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
-             )
+             _LOGGER.warning("Model is not compatible with interpretable attention analysis. Skipping.")
              return

          # --- Step 2: Set up the dataloader ---
-         dataset_to_use = explain_dataset if explain_dataset is not None else self.test_dataset
+         dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
          if not isinstance(dataset_to_use, Dataset):
              _LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
              return
@@ -681,40 +1084,111 @@ class MLTrainer:
                  )
              else:
                  _LOGGER.error("No attention weights were collected from the model.")
-
-     def _callbacks_hook(self, method_name: str, *args, **kwargs):
-         """Calls the specified method on all callbacks."""
-         for callback in self.callbacks:
-             method = getattr(callback, method_name)
-             method(*args, **kwargs)
-
-     def to_cpu(self):
-         """
-         Moves the model to the CPU and updates the trainer's device setting.

-         This is useful for running operations that require the CPU.
-         """
-         self.device = torch.device('cpu')
-         self.model.to(self.device)
-         _LOGGER.info("Trainer and model moved to CPU.")
-
-     def to_device(self, device: str):
+     def finalize_model_training(self,
+                                 save_dir: Union[str, Path],
+                                 filename: str,
+                                 model_checkpoint: Union[Path, Literal['latest', 'current']],
+                                 classification_threshold: Optional[float]=None,
+                                 class_map: Optional[Dict[str,int]]=None):
          """
-         Moves the model to the specified device and updates the trainer's device setting.
+         Saves a finalized, "inference-ready" model state to a .pth file.
+
+         This method saves the model's `state_dict`, the final epoch number, and
+         an optional classification threshold required for binary-based tasks (binary classification, binary segmentation,
+         multilabel binary classification).

          Args:
-             device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+             save_dir (str | Path): The directory to save the finalized model.
+             filename (str): The desired filename for the saved file.
+             model_checkpoint (Path | "latest" | "current"):
+                 - Path: Loads the model state from a specific checkpoint file.
+                 - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                 - "current": Uses the model's state as it is at the end of the `fit()` call.
+             classification_threshold (float, None):
+                 Required for `binary classification`, `binary segmentation`, and
+                 `multilabel binary classification`. This is the threshold (0.0-1.0)
+                 used to convert probabilities to class labels.
+             class_map (Dict[str, int] | None): Sets the class name mapping to translate predicted integer labels back into string names. (For Classification and Segmentation Tasks)
          """
-         self.device = self._validate_device(device)
-         self.model.to(self.device)
-         _LOGGER.info(f"Trainer and model moved to {self.device}.")
+         # handle save path
+         sanitized_filename = sanitize_filename(filename)
+         if not sanitized_filename.endswith(".pth"):
+             sanitized_filename = sanitized_filename + ".pth"
+
+         dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+         full_path = dir_path / sanitized_filename
+
+         # threshold required for binary tasks
+         if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
+             if classification_threshold is None:
+                 _LOGGER.error(f"A classification threshold is needed for binary-based classification tasks. If unknown, use '0.5' as a default.")
+                 raise ValueError()
+             elif not isinstance(classification_threshold, float):
+                 _LOGGER.error(f"The classification threshold must be a float value.")
+                 raise TypeError()
+             elif classification_threshold <= 0.0 or classification_threshold >= 1.0:
+                 _LOGGER.error(f"The classification threshold must be in the range (0.0 - 1.0).")
+         else:
+             classification_threshold = None
+
+         # handle checkpoint
+         if isinstance(model_checkpoint, Path):
+             self._load_checkpoint(path=model_checkpoint)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+             path_to_latest = self._checkpoint_callback.best_checkpoint_path
+             self._load_checkpoint(path_to_latest)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+             _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+             raise ValueError()
+         elif model_checkpoint == MagicWords.CURRENT:
+             pass
+         else:
+             _LOGGER.error(f"Unknown 'model_checkpoint' parameter received '{model_checkpoint}'.")
+
+         # Handle class map
+         if self.kind in [MLTaskKeys.BINARY_CLASSIFICATION,
+                          MLTaskKeys.MULTICLASS_CLASSIFICATION,
+                          MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
+                          MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION,
+                          MLTaskKeys.BINARY_SEGMENTATION,
+                          MLTaskKeys.MULTICLASS_SEGMENTATION]:
+             if class_map is None:
+                 _LOGGER.error(f"'class_map' is required for '{self.kind}'.")
+                 raise ValueError()
+         else:
+             class_map = None
+
+         # Create finalized data
+         finalized_data = {
+             PyTorchCheckpointKeys.EPOCH: self.epoch,
+             PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+         }
+
+         if classification_threshold is not None:
+             self._classification_threshold = classification_threshold
+             finalized_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD] = classification_threshold
+
+         if class_map is not None:
+             finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = class_map
+
+         torch.save(finalized_data, full_path)
+
+         _LOGGER.info(f"Finalized model weights saved to {full_path}.")


  # Object Detection Trainer
- class ObjectDetectionTrainer:
-     def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
+ class DragonDetectionTrainer(_BaseDragonTrainer):
+     def __init__(self, model: nn.Module,
+                  train_dataset: Dataset,
+                  validation_dataset: Dataset,
                   collate_fn: Callable, optimizer: torch.optim.Optimizer,
-                  device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
+                  device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                  checkpoint_callback: Optional[DragonModelCheckpoint],
+                  early_stopping_callback: Optional[DragonEarlyStopping],
+                  lr_scheduler_callback: Optional[DragonLRScheduler],
+                  extra_callbacks: Optional[List[_Callback]] = None,
+                  dataloader_workers: int = 2):
          """
          Automates the training process of an Object Detection Model (e.g., DragonFastRCNN).

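An illustrative `finalize_model_training` call for a binary classifier, following the signature introduced above; paths and names are placeholders:

```python
trainer.finalize_model_training(
    save_dir="outputs/run_01/final",
    filename="dragon_binary_clf",              # ".pth" is appended automatically
    model_checkpoint="latest",                 # or "current", or a Path
    classification_threshold=0.5,              # required for binary-style kinds
    class_map={"negative": 0, "positive": 1},  # required for classification/segmentation
)
```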
@@ -723,58 +1197,36 @@ class ObjectDetectionTrainer:
          Args:
              model (nn.Module): The PyTorch object detection model to train.
              train_dataset (Dataset): The training dataset.
-             test_dataset (Dataset): The testing/validation dataset.
+             validation_dataset (Dataset): The testing/validation dataset.
              collate_fn (Callable): The collate function from `ObjectDetectionDatasetMaker.collate_fn`.
              optimizer (torch.optim.Optimizer): The optimizer.
              device (str): The device to run training on ('cpu', 'cuda', 'mps').
              dataloader_workers (int): Subprocesses for data loading.
-             callbacks (List[Callback] | None): A list of callbacks to use during training.
+             checkpoint_callback (DragonModelCheckpoint | None): Callback to save the model.
+             early_stopping_callback (DragonEarlyStopping | None): Callback to stop training early.
+             lr_scheduler_callback (DragonLRScheduler | None): Callback to manage the LR scheduler.
+             extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.

          ## Note:
          This trainer is specialized. It does not take a `criterion` because object detection models like Faster R-CNN return a dictionary of losses directly from their forward pass during training.
          """
-         self.model = model
+         # Call the base class constructor with common parameters
+         super().__init__(
+             model=model,
+             optimizer=optimizer,
+             device=device,
+             dataloader_workers=dataloader_workers,
+             checkpoint_callback=checkpoint_callback,
+             early_stopping_callback=early_stopping_callback,
+             lr_scheduler_callback=lr_scheduler_callback,
+             extra_callbacks=extra_callbacks
+         )
+
          self.train_dataset = train_dataset
-         self.test_dataset = test_dataset
+         self.validation_dataset = validation_dataset # <-- Renamed
          self.kind = "object_detection"
          self.collate_fn = collate_fn
          self.criterion = None # Criterion is handled inside the model
-         self.optimizer = optimizer
-         self.scheduler = None
-         self.device = self._validate_device(device)
-         self.dataloader_workers = dataloader_workers
-
-         # Callback handler - History and TqdmProgressBar are added by default
-         default_callbacks = [History(), TqdmProgressBar()]
-         user_callbacks = callbacks if callbacks is not None else []
-         self.callbacks = default_callbacks + user_callbacks
-         self._set_trainer_on_callbacks()
-
-         # Internal state
-         self.train_loader = None
-         self.test_loader = None
-         self.history = {}
-         self.epoch = 0
-         self.epochs = 0 # Total epochs for the fit run
-         self.start_epoch = 1
-         self.stop_training = False
-         self._batch_size = 10
-
-     def _validate_device(self, device: str) -> torch.device:
-         """Validates the selected device and returns a torch.device object."""
-         device_lower = device.lower()
-         if "cuda" in device_lower and not torch.cuda.is_available():
-             _LOGGER.warning("CUDA not available, switching to CPU.")
-             device = "cpu"
-         elif device_lower == "mps" and not torch.backends.mps.is_available():
-             _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
-             device = "cpu"
-         return torch.device(device)
-
-     def _set_trainer_on_callbacks(self):
-         """Gives each callback a reference to this trainer instance."""
-         for callback in self.callbacks:
-             callback.set_trainer(self)

      def _create_dataloaders(self, batch_size: int, shuffle: bool):
          """Initializes the DataLoaders with the object detection collate_fn."""
@@ -786,125 +1238,25 @@ class ObjectDetectionTrainer:
786
1238
  batch_size=batch_size,
787
1239
  shuffle=shuffle,
788
1240
  num_workers=loader_workers,
789
- pin_memory=("cuda" in self.device.type),
790
- collate_fn=self.collate_fn # Use the provided collate function
1241
+ pin_memory=("cuda" in self.device.type),
1242
+ collate_fn=self.collate_fn, # Use the provided collate function
1243
+ drop_last=True
791
1244
  )
792
1245
 
793
- self.test_loader = DataLoader(
794
- dataset=self.test_dataset,
1246
+ self.validation_loader = DataLoader(
1247
+ dataset=self.validation_dataset,
795
1248
  batch_size=batch_size,
796
1249
  shuffle=False,
797
1250
  num_workers=loader_workers,
798
1251
  pin_memory=("cuda" in self.device.type),
799
1252
  collate_fn=self.collate_fn # Use the provided collate function
800
1253
  )
801
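The custom collate function matters here because detection samples are (image, target-dict) pairs of varying sizes that cannot be stacked into a single tensor. The package's real implementation lives in `ObjectDetectionDatasetMaker.collate_fn` (not shown in this diff); a minimal stand-in with the same contract would be:

```python
def collate_fn(batch):
    # batch is a list of (image_tensor, target_dict) pairs.
    # Images differ in HxW and each target holds a variable number of boxes,
    # so return tuples of per-sample items instead of stacked tensors.
    return tuple(zip(*batch))
```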
- 
-     def _load_checkpoint(self, path: Union[str, Path]):
-         """Loads a training checkpoint to resume training."""
-         p = make_fullpath(path, enforce="file")
-         _LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
- 
-         try:
-             checkpoint = torch.load(p, map_location=self.device)
- 
-             if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
-                 _LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
-                 raise KeyError()
- 
-             self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
-             self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
-             self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1  # Resume on the *next* epoch
- 
-             # --- Scheduler State Loading Logic ---
-             scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
-             scheduler_object_exists = self.scheduler is not None
- 
-             if scheduler_object_exists and scheduler_state_exists:
-                 # Case 1: Both exist. Attempt to load.
-                 try:
-                     self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE])  # type: ignore
-                     scheduler_name = self.scheduler.__class__.__name__
-                     _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
-                 except Exception as e:
-                     # Loading failed, likely a mismatch
-                     scheduler_name = self.scheduler.__class__.__name__
-                     _LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
-                     raise e
 
-             elif scheduler_object_exists and not scheduler_state_exists:
-                 # Case 2: Scheduler provided, but no state in checkpoint.
-                 scheduler_name = self.scheduler.__class__.__name__
-                 _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
- 
-             elif not scheduler_object_exists and scheduler_state_exists:
-                 # Case 3: State in checkpoint, but no scheduler provided.
-                 _LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
-                 raise ValueError()
- 
-             # Restore callback states
-             for cb in self.callbacks:
-                 if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
-                     cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
-                     _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
- 
-             _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
- 
-         except Exception as e:
-             _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
-             raise
- 
- 
-     def fit(self,
-             epochs: int = 10,
-             batch_size: int = 10,
-             shuffle: bool = True,
-             resume_from_checkpoint: Optional[Union[str, Path]] = None):
-         """
-         Starts the training-validation process of the model.
- 
-         Returns the "History" callback dictionary.
- 
-         Args:
-             epochs (int): The total number of epochs to train for.
-             batch_size (int): The number of samples per batch.
-             shuffle (bool): Whether to shuffle the training data at each epoch.
-             resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
-         """
-         self.epochs = epochs
-         self._batch_size = batch_size
-         self._create_dataloaders(self._batch_size, shuffle)
-         self.model.to(self.device)
- 
-         if resume_from_checkpoint:
-             self._load_checkpoint(resume_from_checkpoint)
- 
-         # Reset stop_training flag on the trainer
-         self.stop_training = False
- 
-         self._callbacks_hook('on_train_begin')
- 
-         for epoch in range(self.start_epoch, self.epochs + 1):
-             self.epoch = epoch
-             epoch_logs = {}
-             self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
- 
-             train_logs = self._train_step()
-             epoch_logs.update(train_logs)
- 
-             val_logs = self._validation_step()
-             epoch_logs.update(val_logs)
- 
-             self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
- 
-             # Check the early stopping flag
-             if self.stop_training:
-                 break
- 
-         self._callbacks_hook('on_train_end')
-         return self.history
- 
- 
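The plumbing removed above (device validation, callback wiring, checkpoint loading, and `fit()`) is not gone; it has been hoisted into the shared `_BaseDragonTrainer` parent imported at the top of the module. Assuming the base class keeps the `fit()` signature documented in the removed code, resuming an interrupted run would look like this sketch:

```python
# Sketch only: assumes fit() retains the epochs/batch_size/shuffle/resume_from_checkpoint
# parameters documented in the removed method above; the path is hypothetical.
history = trainer.fit(
    epochs=100,
    batch_size=16,
    shuffle=True,
    resume_from_checkpoint="checkpoints/epoch_42.pth",
)
```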
      def _train_step(self):
          self.model.train()
          running_loss = 0.0
+         total_samples = 0
+ 
          for batch_idx, (images, targets) in enumerate(self.train_loader):  # type: ignore
              # images is a tuple of tensors, targets is a tuple of dicts
              batch_size = len(images)
@@ -941,21 +1293,28 @@ class ObjectDetectionTrainer:
              # Calculate batch loss and update running loss for the epoch
              batch_loss = loss.item()
              running_loss += batch_loss * batch_size
+             total_samples += batch_size  # <-- Accumulate total samples
 
              # Add the batch loss to the logs and call the end-of-batch hook
              batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss  # type: ignore
              self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+ 
+         # Calculate the epoch loss using the correct denominator
+         if total_samples == 0:
+             _LOGGER.warning("No samples processed in _train_step. Returning 0 loss.")
+             return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
 
-         return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)}  # type: ignore
+         return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}
 
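The switch from `len(dataset)` to `total_samples` matters now that the training loader uses `drop_last=True`: samples in the dropped tail batch never contribute loss, so the dataset length is the wrong denominator. A quick illustration with made-up numbers:

```python
# Illustration only (numbers are hypothetical).
dataset_len = 103
batch_size = 10
full_batches = dataset_len // batch_size        # 10 batches survive drop_last=True
total_samples = full_batches * batch_size       # 100 samples actually processed
# Dividing the accumulated loss by dataset_len (103) would understate the mean
# per-sample loss; dividing by total_samples (100) matches what was processed.
```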
      def _validation_step(self):
          self.model.train()  # Set to train mode even for validation loss calculation
-                             # as model internals (e.g., proposals) might differ,
-                             # but we still need loss_dict.
-                             # We use torch.no_grad() to prevent gradient updates.
+                             # as model internals (e.g., proposals) might differ, but we still need loss_dict.
+                             # We use torch.no_grad() to prevent gradient updates.
          running_loss = 0.0
+         total_samples = 0
+ 
          with torch.no_grad():
-             for images, targets in self.test_loader:  # type: ignore
+             for images, targets in self.validation_loader:  # type: ignore
                  batch_size = len(images)
 
                  # Move data to device
@@ -973,25 +1332,105 @@ class ObjectDetectionTrainer:
                  loss: torch.Tensor = sum(l for l in loss_dict.values())  # type: ignore
 
                  running_loss += loss.item() * batch_size
+                 total_samples += batch_size  # <-- Accumulate total samples
+ 
+         # Calculate the loss using the correct denominator
+         if total_samples == 0:
+             _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+             return {PyTorchLogKeys.VAL_LOSS: 0.0}
 
-         logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)}  # type: ignore
+         logs = {PyTorchLogKeys.VAL_LOSS: running_loss / total_samples}
          return logs
+ 
+     def evaluate(self,
+                  save_dir: Union[str, Path],
+                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  test_data: Optional[Union[DataLoader, Dataset]] = None):
+         """
+         Evaluates the model using object detection mAP metrics.
 
-     def evaluate(self, save_dir: Union[str, Path], data: Optional[Union[DataLoader, Dataset]] = None):
+         Args:
+             save_dir (str | Path): Directory to save all reports and plots.
+             model_checkpoint (Path | "latest" | "current"):
+                 - Path: a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+                 - "latest": the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+                 - "current": use the current state of the trained model up to the latest trained epoch.
+             test_data (DataLoader | Dataset | None): Optional test data to evaluate the model performance. Validation and test metrics will be saved to subdirectories.
          """
+         # Validate model checkpoint
+         if isinstance(model_checkpoint, Path):
+             checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+         elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+             checkpoint_validated = model_checkpoint
+         else:
+             _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
+             raise ValueError()
+ 
+         # Validate directory
+         save_path = make_fullpath(save_dir, make=True, enforce="directory")
+ 
+         # Validate test data and dispatch
+         if test_data is not None:
+             if not isinstance(test_data, (DataLoader, Dataset)):
+                 _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                 raise ValueError()
+             test_data_validated = test_data
+ 
+             validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+             test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+ 
+             # Dispatch validation set
+             _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+             self._evaluate(save_dir=validation_metrics_path,
+                            model_checkpoint=checkpoint_validated,
+                            data=None)  # 'None' triggers use of self.validation_dataset
+ 
+             # Dispatch test set
+             _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+             self._evaluate(save_dir=test_metrics_path,
+                            model_checkpoint="current",  # Use 'current' state after loading the checkpoint once
+                            data=test_data_validated)
+         else:
+             # Dispatch validation set
+             _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+             self._evaluate(save_dir=save_path,
+                            model_checkpoint=checkpoint_validated,
+                            data=None)  # 'None' triggers use of self.validation_dataset
+ 
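A hypothetical call of the new dispatch (directory and variable names are illustrative, not from the package docs):

```python
# Evaluates the validation split, then the optional test split into subdirectories.
trainer.evaluate(
    save_dir="results/detection",   # hypothetical output directory
    model_checkpoint="latest",       # or a Path, or "current"
    test_data=test_ds,               # omit to evaluate the validation set only
)
```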
+     def _evaluate(self,
+                   save_dir: Union[str, Path],
+                   model_checkpoint: Union[Path, Literal["latest", "current"]],
+                   data: Optional[Union[DataLoader, Dataset]]):
+         """
+         Private evaluation helper.
          Evaluates the model using object detection mAP metrics.
 
          Args:
              save_dir (str | Path): Directory to save all reports and plots.
              data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal validation_dataset.
+             model_checkpoint (Path | "latest" | "current"):
+                 - Path: a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+                 - "latest": the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+                 - "current": use the current state of the trained model up to the latest trained epoch.
          """
          dataset_for_names = None
          eval_loader = None
+ 
+         # load model checkpoint
+         if isinstance(model_checkpoint, Path):
+             self._load_checkpoint(path=model_checkpoint)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+             path_to_latest = self._checkpoint_callback.best_checkpoint_path
+             self._load_checkpoint(path_to_latest)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+             _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+             raise ValueError()
 
+         # Dataloader
          if isinstance(data, DataLoader):
              eval_loader = data
              if hasattr(data, 'dataset'):
-                 dataset_for_names = data.dataset
+                 dataset_for_names = data.dataset  # type: ignore
          elif isinstance(data, Dataset):
              # Create a new loader from the provided dataset
              eval_loader = DataLoader(data,
@@ -1002,19 +1441,19 @@ class ObjectDetectionTrainer:
                                       collate_fn=self.collate_fn)
              dataset_for_names = data
          else:  # data is None, use the trainer's default validation dataset
-             if self.test_dataset is None:
+             if self.validation_dataset is None:
                  _LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
                  raise ValueError()
              # Create a fresh DataLoader from the validation_dataset
              eval_loader = DataLoader(
-                 self.test_dataset,
+                 self.validation_dataset,
                  batch_size=self._batch_size,
                  shuffle=False,
                  num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                  pin_memory=(self.device.type == "cuda"),
                  collate_fn=self.collate_fn
              )
-             dataset_for_names = self.test_dataset
+             dataset_for_names = self.validation_dataset
 
          if eval_loader is None:
              _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
@@ -1068,36 +1507,480 @@ class ObjectDetectionTrainer:
              class_names=class_names,
              print_output=False
          )
- 
-         # print("\n--- Training History ---")
-         plot_losses(self.history, save_dir=save_dir)
 
-     def _callbacks_hook(self, method_name: str, *args, **kwargs):
-         """Calls the specified method on all callbacks."""
-         for callback in self.callbacks:
-             method = getattr(callback, method_name)
-             method(*args, **kwargs)
+     def finalize_model_training(self, save_dir: Union[str, Path], filename: str, model_checkpoint: Union[Path, Literal['latest', 'current']]):
+         """
+         Saves a finalized, "inference-ready" model state to a .pth file.
+ 
+         This method saves the model's `state_dict` and the final epoch number.
+ 
+         Args:
+             save_dir (Union[str, Path]): The directory to save the finalized model.
+             filename (str): The desired filename for the model (e.g., "final_model.pth").
+             model_checkpoint (Union[Path, Literal["latest", "current"]]):
+                 - Path: Loads the model state from a specific checkpoint file.
+                 - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                 - "current": Uses the model's state as it is at the end of the `fit()` call.
+         """
+         # handle save path
+         sanitized_filename = sanitize_filename(filename)
+         if not sanitized_filename.endswith(".pth"):
+             sanitized_filename = sanitized_filename + ".pth"
 
-     def to_cpu(self):
+         dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+         full_path = dir_path / sanitized_filename
+ 
+         # handle checkpoint
+         if isinstance(model_checkpoint, Path):
+             self._load_checkpoint(path=model_checkpoint)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+             path_to_latest = self._checkpoint_callback.best_checkpoint_path
+             self._load_checkpoint(path_to_latest)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+             _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+             raise ValueError()
+         elif model_checkpoint == MagicWords.CURRENT:
+             pass
+         else:
+             _LOGGER.error(f"Unknown 'model_checkpoint' parameter received '{model_checkpoint}'.")
+             raise ValueError()
+ 
+         # Create finalized data
+         finalized_data = {
+             PyTorchCheckpointKeys.EPOCH: self.epoch,
+             PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+         }
+ 
+         torch.save(finalized_data, full_path)
+ 
+         _LOGGER.info(f"Finalized model weights saved to {full_path}.")
+ 
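Downstream consumption of the finalized file is straightforward. A hedged sketch (the key names follow the `PyTorchCheckpointKeys` constants used above, whose string values are not shown in this diff):

```python
# Sketch: reload a finalized detection model for inference.
payload = torch.load("final_model.pth", map_location="cpu")  # hypothetical path
model.load_state_dict(payload[PyTorchCheckpointKeys.MODEL_STATE])
model.eval()  # the finalized file stores weights plus the final epoch, nothing else
```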
+ # --- DragonSequenceTrainer ---
+ class DragonSequenceTrainer(_BaseDragonTrainer):
+     def __init__(self,
+                  model: nn.Module,
+                  train_dataset: Dataset,
+                  validation_dataset: Dataset,
+                  kind: Literal["sequence-to-sequence", "sequence-to-value"],
+                  optimizer: torch.optim.Optimizer,
+                  device: Union[Literal['cuda', 'mps', 'cpu'], str],
+                  checkpoint_callback: Optional[DragonModelCheckpoint],
+                  early_stopping_callback: Optional[DragonEarlyStopping],
+                  lr_scheduler_callback: Optional[DragonLRScheduler],
+                  extra_callbacks: Optional[List[_Callback]] = None,
+                  criterion: Union[nn.Module, Literal["auto"]] = "auto",
+                  dataloader_workers: int = 2):
          """
-         Moves the model to the CPU and updates the trainer's device setting.
+         Automates the training process of a PyTorch Sequence Model.
 
-         This is useful for running operations that require the CPU.
+         Built-in Callbacks: `History`, `TqdmProgressBar`
+ 
+         Args:
+             model (nn.Module): The PyTorch model to train.
+             train_dataset (Dataset): The training dataset.
+             validation_dataset (Dataset): The validation dataset.
+             kind (str): Used to redirect to the correct process ('sequence-to-sequence' or 'sequence-to-value').
+             criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task.
+             optimizer (torch.optim.Optimizer): The optimizer.
+             device (str): The device to run training on ('cpu', 'cuda', 'mps').
+             dataloader_workers (int): Subprocesses for data loading.
+             extra_callbacks (List[_Callback] | None): A list of extra callbacks to use during training.
          """
-         self.device = torch.device('cpu')
+         # Call the base class constructor with common parameters
+         super().__init__(
+             model=model,
+             optimizer=optimizer,
+             device=device,
+             dataloader_workers=dataloader_workers,
+             checkpoint_callback=checkpoint_callback,
+             early_stopping_callback=early_stopping_callback,
+             lr_scheduler_callback=lr_scheduler_callback,
+             extra_callbacks=extra_callbacks
+         )
+ 
+         if kind not in [MLTaskKeys.SEQUENCE_SEQUENCE, MLTaskKeys.SEQUENCE_VALUE]:
+             raise ValueError(f"'{kind}' is not a valid task type for DragonSequenceTrainer.")
+ 
+         self.train_dataset = train_dataset
+         self.validation_dataset = validation_dataset
+         self.kind = kind
+ 
+         # try to validate against a Dragon Sequence model
+         if hasattr(self.model, "prediction_mode"):
+             key_to_check: str = self.model.prediction_mode  # type: ignore
+             if key_to_check != self.kind:
+                 _LOGGER.error(f"Trainer was set for '{self.kind}', but model architecture '{self.model}' is built for '{key_to_check}'.")
+                 raise RuntimeError()
+ 
+         # loss function
+         if criterion == "auto":
+             # Both sequence tasks are treated as regression problems
+             self.criterion = nn.MSELoss()
+         else:
+             self.criterion = criterion
+ 
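A minimal construction sketch using the signature above; the model, datasets, and callback objects are placeholders for illustration:

```python
# Hypothetical setup; the objects passed in are assumptions.
seq_trainer = DragonSequenceTrainer(
    model=lstm_model,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    kind="sequence-to-value",           # or "sequence-to-sequence"
    optimizer=torch.optim.Adam(lstm_model.parameters(), lr=1e-3),
    device="cuda",
    checkpoint_callback=checkpoint_cb,  # DragonModelCheckpoint or None
    early_stopping_callback=None,
    lr_scheduler_callback=None,
    criterion="auto",                   # resolves to nn.MSELoss()
)
```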
+     def _create_dataloaders(self, batch_size: int, shuffle: bool):
+         """Initializes the DataLoaders."""
+         # Ensure stability on MPS devices by setting num_workers to 0
+         loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+ 
+         self.train_loader = DataLoader(
+             dataset=self.train_dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             num_workers=loader_workers,
+             pin_memory=("cuda" in self.device.type),
+             drop_last=True  # Drops the last batch if incomplete; selecting a good batch size is key.
+         )
+ 
+         self.validation_loader = DataLoader(
+             dataset=self.validation_dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=loader_workers,
+             pin_memory=("cuda" in self.device.type)
+         )
+ 
+     def _train_step(self):
+         self.model.train()
+         running_loss = 0.0
+         total_samples = 0
+ 
+         for batch_idx, (features, target) in enumerate(self.train_loader):  # type: ignore
+             # Create a log dictionary for the batch
+             batch_logs = {
+                 PyTorchLogKeys.BATCH_INDEX: batch_idx,
+                 PyTorchLogKeys.BATCH_SIZE: features.size(0)
+             }
+             self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
+ 
+             features, target = features.to(self.device), target.to(self.device)
+             self.optimizer.zero_grad()
+ 
+             output = self.model(features)
+ 
+             # --- Label Type/Shape Correction ---
+             # Ensure target is float for MSELoss
+             target = target.float()
+ 
+             # For seq-to-val, models might output [N, 1] but target is [N].
+             if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+                 if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                     output = output.squeeze(1)
+ 
+             # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
+             elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+                 if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
+                     output = output.squeeze(-1)
+ 
+             loss = self.criterion(output, target)
+ 
+             loss.backward()
+             self.optimizer.step()
+ 
+             # Calculate batch loss and update running loss for the epoch
+             batch_loss = loss.item()
+             batch_size = features.size(0)
+             running_loss += batch_loss * batch_size  # Accumulate total loss
+             total_samples += batch_size  # Accumulate total samples
+ 
+             # Add the batch loss to the logs and call the end-of-batch hook
+             batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
+             self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+ 
+         if total_samples == 0:
+             _LOGGER.warning("No samples processed in _train_step. Returning 0 loss.")
+             return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
+ 
+         return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}  # type: ignore
+ 
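The squeeze logic above exists because comparing a `[N, 1]` output against a `[N]` target broadcasts into an `[N, N]` matrix under `nn.MSELoss`, silently inflating the loss. A standalone illustration with hypothetical shapes:

```python
import torch
from torch import nn

output = torch.randn(32, 1)          # sequence-to-value model output: [N, 1]
target = torch.randn(32)             # target: [N]
output = output.squeeze(1)           # -> [N]; shapes now match
loss = nn.MSELoss()(output, target)  # scalar, as intended
# Without the squeeze, broadcasting would compare every prediction against
# every target element instead of pairing them one-to-one.
```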
+     def _validation_step(self):
+         self.model.eval()
+         running_loss = 0.0
+ 
+         with torch.no_grad():
+             for features, target in self.validation_loader:  # type: ignore
+                 features, target = features.to(self.device), target.to(self.device)
+ 
+                 output = self.model(features)
+ 
+                 # --- Label Type/Shape Correction ---
+                 target = target.float()
+ 
+                 # For seq-to-val, models might output [N, 1] but target is [N].
+                 if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+                     if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                         output = output.squeeze(1)
+ 
+                 # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
+                 elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+                     if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
+                         output = output.squeeze(-1)
+ 
+                 loss = self.criterion(output, target)
+ 
+                 running_loss += loss.item() * features.size(0)
+ 
+         if not self.validation_loader.dataset:  # type: ignore
+             _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+             return {PyTorchLogKeys.VAL_LOSS: 0.0}
+ 
+         logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)}  # type: ignore
+         return logs
+ 
+     def _predict_for_eval(self, dataloader: DataLoader):
+         """
+         Private method to yield model predictions batch by batch for evaluation.
+ 
+         Yields:
+             tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
+                 y_prob_batch is always None for sequence tasks.
+         """
+         self.model.eval()
          self.model.to(self.device)
-         _LOGGER.info("Trainer and model moved to CPU.")
+ 
+         with torch.no_grad():
+             for features, target in dataloader:
+                 features = features.to(self.device)
+                 output = self.model(features).cpu()
+ 
+                 y_pred_batch = output.numpy()
+                 y_prob_batch = None  # Not applicable for sequence regression
+                 y_true_batch = target.numpy()
+ 
+                 yield y_pred_batch, y_prob_batch, y_true_batch
+ 
+     def evaluate(self,
+                  save_dir: Union[str, Path],
+                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  test_data: Optional[Union[DataLoader, Dataset]] = None,
+                  val_format_configuration: Optional[Union[SequenceValueMetricsFormat,
+                                                           SequenceSequenceMetricsFormat]] = None,
+                  test_format_configuration: Optional[Union[SequenceValueMetricsFormat,
+                                                            SequenceSequenceMetricsFormat]] = None):
+         """
+         Evaluates the model, routing to the correct evaluation function.
+ 
+         Args:
+             model_checkpoint (Path | "latest" | "current"):
+                 - Path: a valid checkpoint for the model.
+                 - "latest": the latest checkpoint will be loaded.
+                 - "current": use the current state of the trained model.
+             save_dir (str | Path): Directory to save all reports and plots.
+             test_data (DataLoader | Dataset | None): Optional test data.
+             val_format_configuration: Optional configuration for validation metrics.
+             test_format_configuration: Optional configuration for test metrics.
+         """
+         # Validate model checkpoint
+         if isinstance(model_checkpoint, Path):
+             checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+         elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+             checkpoint_validated = model_checkpoint
+         else:
+             _LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.LATEST}', or '{MagicWords.CURRENT}'.")
+             raise ValueError()
+ 
+         # Validate val configuration
+         if val_format_configuration is not None:
+             if not isinstance(val_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
+                 _LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
+                 raise ValueError()
+ 
+         # Validate directory
+         save_path = make_fullpath(save_dir, make=True, enforce="directory")
+ 
+         # Validate test data and dispatch
+         if test_data is not None:
+             if not isinstance(test_data, (DataLoader, Dataset)):
+                 _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                 raise ValueError()
+             test_data_validated = test_data
 
-     def to_device(self, device: str):
+             validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+             test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+ 
+             # Dispatch validation set
+             _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+             self._evaluate(save_dir=validation_metrics_path,
+                            model_checkpoint=checkpoint_validated,
+                            data=None,
+                            format_configuration=val_format_configuration)
+ 
+             # Validate test configuration
+             test_configuration_validated = None
+             if test_format_configuration is not None:
+                 if not isinstance(test_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
+                     warning_message_type = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
+                     if val_format_configuration is not None:
+                         warning_message_type += " 'val_format_configuration' will be used."
+                         test_configuration_validated = val_format_configuration
+                     else:
+                         warning_message_type += " Using default format."
+                     _LOGGER.warning(warning_message_type)
+                 else:
+                     test_configuration_validated = test_format_configuration
+ 
+             # Dispatch test set
+             _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+             self._evaluate(save_dir=test_metrics_path,
+                            model_checkpoint="current",
+                            data=test_data_validated,
+                            format_configuration=test_configuration_validated)
+         else:
+             # Dispatch validation set
+             _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+             self._evaluate(save_dir=save_path,
+                            model_checkpoint=checkpoint_validated,
+                            data=None,
+                            format_configuration=val_format_configuration)
+ 
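A hedged usage sketch of this dispatcher; the constructor arguments of the format classes are not shown in this diff, so the configuration object below is only a placeholder:

```python
# Placeholder configuration; see SequenceValueMetricsFormat in ML_configuration
# for the fields it actually accepts.
seq_trainer.evaluate(
    save_dir="results/sequence",               # hypothetical directory
    model_checkpoint="latest",
    test_data=test_ds,
    val_format_configuration=my_value_format,  # SequenceValueMetricsFormat instance
)
```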
+     def _evaluate(self,
+                   save_dir: Union[str, Path],
+                   model_checkpoint: Union[Path, Literal["latest", "current"]],
+                   data: Optional[Union[DataLoader, Dataset]],
+                   format_configuration: Optional[Union[SequenceValueMetricsFormat,
+                                                        SequenceSequenceMetricsFormat]]):
          """
-         Moves the model to the specified device and updates the trainer's device setting.
+         Private evaluation helper.
+         """
+         eval_loader = None
+ 
+         # load model checkpoint
+         if isinstance(model_checkpoint, Path):
+             self._load_checkpoint(path=model_checkpoint)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+             path_to_latest = self._checkpoint_callback.best_checkpoint_path
+             self._load_checkpoint(path_to_latest)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+             _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+             raise ValueError()
+ 
+         # Dataloader
+         if isinstance(data, DataLoader):
+             eval_loader = data
+         elif isinstance(data, Dataset):
+             # Create a new loader from the provided dataset
+             eval_loader = DataLoader(data,
+                                      batch_size=self._batch_size,
+                                      shuffle=False,
+                                      num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                      pin_memory=(self.device.type == "cuda"))
+         else:  # data is None, use the trainer's default validation dataset
+             if self.validation_dataset is None:
+                 _LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
+                 raise ValueError()
+             eval_loader = DataLoader(self.validation_dataset,
+                                      batch_size=self._batch_size,
+                                      shuffle=False,
+                                      num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                      pin_memory=(self.device.type == "cuda"))
+ 
+         if eval_loader is None:
+             _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
+             raise ValueError()
+ 
+         all_preds, _, all_true = [], [], []
+         for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
+             if y_pred_b is not None: all_preds.append(y_pred_b)
+             if y_true_b is not None: all_true.append(y_true_b)
+ 
+         if not all_true:
+             _LOGGER.error("Evaluation failed: No data was processed.")
+             return
+ 
+         y_pred = np.concatenate(all_preds)
+         y_true = np.concatenate(all_true)
+ 
+         # --- Routing Logic ---
+         if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+             config = None
+             if format_configuration and isinstance(format_configuration, SequenceValueMetricsFormat):
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceValueMetricsFormat.")
+ 
+             sequence_to_value_metrics(y_true=y_true,
+                                       y_pred=y_pred,
+                                       save_dir=save_dir,
+                                       config=config)
+ 
+         elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+             config = None
+             if format_configuration and isinstance(format_configuration, SequenceSequenceMetricsFormat):
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceSequenceMetricsFormat.")
+ 
+             sequence_to_sequence_metrics(y_true=y_true,
+                                          y_pred=y_pred,
+                                          save_dir=save_dir,
+                                          config=config)
+ 
+     def finalize_model_training(self,
+                                 save_dir: Union[str, Path],
+                                 filename: str,
+                                 last_training_sequence: np.ndarray,
+                                 model_checkpoint: Union[Path, Literal['latest', 'current']]):
+         """
+         Saves a finalized, "inference-ready" model state to a .pth file.
+ 
+         This method saves the model's `state_dict`, the final epoch number, and the initial sequence used for forecasting.
 
          Args:
-             device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+             save_dir (Union[str, Path]): The directory to save the finalized model.
+             filename (str): The desired filename for the model (e.g., "final_model.pth").
+             last_training_sequence (np.ndarray): The last un-scaled sequence from the training data, used for forecasting.
+             model_checkpoint (Union[Path, Literal["latest", "current"]]):
+                 - Path: Loads the model state from a specific checkpoint file.
+                 - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                 - "current": Uses the model's state as it is at the end of the `fit()` call.
          """
-         self.device = self._validate_device(device)
-         self.model.to(self.device)
-         _LOGGER.info(f"Trainer and model moved to {self.device}.")
+         # handle save path
+         sanitized_filename = sanitize_filename(filename)
+         if not sanitized_filename.endswith(".pth"):
+             sanitized_filename = sanitized_filename + ".pth"
+ 
+         dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+         full_path = dir_path / sanitized_filename
+ 
+         # handle checkpoint
+         if isinstance(model_checkpoint, Path):
+             self._load_checkpoint(path=model_checkpoint)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+             path_to_latest = self._checkpoint_callback.best_checkpoint_path
+             self._load_checkpoint(path_to_latest)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+             _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+             raise ValueError()
+         elif model_checkpoint == MagicWords.CURRENT:
+             pass
+         else:
+             _LOGGER.error(f"Unknown 'model_checkpoint' parameter received '{model_checkpoint}'.")
+             raise ValueError()
+ 
+         # --- 1. Validate the provided initial sequence ---
+         if not isinstance(last_training_sequence, np.ndarray):
+             _LOGGER.error(f"'last_training_sequence' must be a numpy array. Got {type(last_training_sequence)}")
+             raise TypeError()
+         if last_training_sequence.ndim != 1:
+             _LOGGER.error(f"'last_training_sequence' must be a 1D array. Got {last_training_sequence.ndim} dimensions.")
+             raise ValueError()
+ 
+         # --- 2. Derive sequence_length from the array ---
+         sequence_length = len(last_training_sequence)
+         if sequence_length <= 0:
+             _LOGGER.error("'last_training_sequence' cannot be empty.")
+             raise ValueError()
+ 
+         # Create finalized data
+         finalized_data = {
+             PyTorchCheckpointKeys.EPOCH: self.epoch,
+             PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+             PyTorchCheckpointKeys.SEQUENCE_LENGTH: sequence_length,
+             PyTorchCheckpointKeys.INITIAL_SEQUENCE: last_training_sequence
+         }
+ 
+         torch.save(finalized_data, full_path)
+ 
+         _LOGGER.info(f"Finalized model weights saved to {full_path}.")
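The stored `INITIAL_SEQUENCE` and `SEQUENCE_LENGTH` exist so an inference wrapper can bootstrap autoregressive forecasting from the end of the training data. A conceptual sketch of the rolling-window idea (the real consumer is the sequence inference module added in this release; this is not its actual API, and `model` is a placeholder):

```python
import numpy as np
import torch

payload = torch.load("final_model.pth", map_location="cpu")  # hypothetical path
window = payload[PyTorchCheckpointKeys.INITIAL_SEQUENCE].astype("float32").copy()

model.load_state_dict(payload[PyTorchCheckpointKeys.MODEL_STATE])
model.eval()
forecast = []
with torch.no_grad():
    for _ in range(10):                               # forecast horizon (hypothetical)
        x = torch.from_numpy(window).view(1, -1, 1)   # [batch, seq_len, features]
        next_value = model(x).item()                  # assumes a scalar seq-to-value output
        forecast.append(next_value)
        window = np.append(window[1:], next_value)    # slide the window forward
```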
 
 
  def info():