dragon-ml-toolbox 14.3.1__py3-none-any.whl → 16.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of dragon-ml-toolbox has been flagged as a potentially problematic release.

Files changed (44)
  1. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +10 -5
  2. dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
  3. ml_tools/ETL_cleaning.py +20 -20
  4. ml_tools/ETL_engineering.py +23 -25
  5. ml_tools/GUI_tools.py +20 -20
  6. ml_tools/MICE_imputation.py +3 -3
  7. ml_tools/ML_callbacks.py +43 -26
  8. ml_tools/ML_configuration.py +309 -0
  9. ml_tools/ML_datasetmaster.py +220 -260
  10. ml_tools/ML_evaluation.py +317 -81
  11. ml_tools/ML_evaluation_multi.py +127 -36
  12. ml_tools/ML_inference.py +249 -207
  13. ml_tools/ML_models.py +13 -102
  14. ml_tools/ML_models_advanced.py +1 -1
  15. ml_tools/ML_optimization.py +12 -12
  16. ml_tools/ML_scaler.py +11 -11
  17. ml_tools/ML_sequence_datasetmaster.py +341 -0
  18. ml_tools/ML_sequence_evaluation.py +215 -0
  19. ml_tools/ML_sequence_inference.py +391 -0
  20. ml_tools/ML_sequence_models.py +139 -0
  21. ml_tools/ML_trainer.py +1247 -338
  22. ml_tools/ML_utilities.py +51 -2
  23. ml_tools/ML_vision_datasetmaster.py +262 -118
  24. ml_tools/ML_vision_evaluation.py +26 -6
  25. ml_tools/ML_vision_inference.py +117 -140
  26. ml_tools/ML_vision_models.py +15 -1
  27. ml_tools/ML_vision_transformers.py +233 -7
  28. ml_tools/PSO_optimization.py +6 -6
  29. ml_tools/SQL.py +4 -4
  30. ml_tools/{keys.py → _keys.py} +45 -1
  31. ml_tools/_schema.py +1 -1
  32. ml_tools/ensemble_evaluation.py +54 -11
  33. ml_tools/ensemble_inference.py +7 -33
  34. ml_tools/ensemble_learning.py +1 -1
  35. ml_tools/optimization_tools.py +2 -2
  36. ml_tools/path_manager.py +5 -5
  37. ml_tools/utilities.py +1 -2
  38. dragon_ml_toolbox-14.3.1.dist-info/RECORD +0 -48
  39. ml_tools/RNN_forecast.py +0 -56
  40. ml_tools/_ML_vision_recipe.py +0 -88
  41. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
  42. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
  43. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  44. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
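
The headline change in this release is the trainer API: MLTrainer and ObjectDetectionTrainer are replaced by DragonTrainer, DragonDetectionTrainer, and the new DragonSequenceTrainer; checkpointing, early stopping, and LR scheduling are now passed as dedicated callback arguments, and test_dataset is renamed validation_dataset. For orientation, a minimal migration sketch inferred only from the ml_tools/ML_trainer.py diff below — the toy model, tensors, and directory names are illustrative assumptions, not taken from the package:

# Hedged sketch of the 16.x DragonTrainer API, inferred from the diff below.
# The toy model and datasets are illustrative assumptions.
import torch
from torch import nn
from torch.utils.data import TensorDataset

from ml_tools.ML_trainer import DragonTrainer

X = torch.randn(256, 10)
y = (X.sum(dim=1) > 0).float()              # one binary target per sample
train_ds = TensorDataset(X[:200], y[:200])
val_ds = TensorDataset(X[200:], y[200:])    # was "test_dataset" in 14.x

model = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 1))

trainer = DragonTrainer(
    model=model,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    kind="binary classification",           # task names are now spaced strings
    optimizer=torch.optim.Adam(model.parameters()),
    device="cpu",
    checkpoint_callback=None,               # or a DragonModelCheckpoint instance
    early_stopping_callback=None,
    lr_scheduler_callback=None,
    criterion="auto",                       # resolves to nn.BCEWithLogitsLoss for this kind
)

# fit() now requires save_dir and writes the loss plot there itself.
history = trainer.fit(save_dir="output/run1", epochs=10, batch_size=32)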
ml_tools/ML_trainer.py CHANGED
@@ -1,79 +1,79 @@
-from typing import List, Literal, Union, Optional, Callable, Dict, Any, Tuple
+from typing import List, Literal, Union, Optional, Callable, Dict, Any
 from pathlib import Path
 from torch.utils.data import DataLoader, Dataset
 import torch
 from torch import nn
 import numpy as np
+from abc import ABC, abstractmethod
 
-from .ML_callbacks import Callback, History, TqdmProgressBar, ModelCheckpoint
+from .path_manager import make_fullpath, sanitize_filename
+from .ML_callbacks import _Callback, History, TqdmProgressBar, DragonModelCheckpoint, DragonEarlyStopping, DragonLRScheduler
 from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
 from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
+from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
+from .ML_sequence_evaluation import sequence_to_sequence_metrics, sequence_to_value_metrics
+from .ML_configuration import (ClassificationMetricsFormat,
+                               MultiClassificationMetricsFormat,
+                               RegressionMetricsFormat,
+                               SegmentationMetricsFormat,
+                               SequenceValueMetricsFormat,
+                               SequenceSequenceMetricsFormat)
+
 from ._script_info import _script_info
-from .keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys
+from ._keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys, MLTaskKeys, MagicWords, DragonTrainerKeys
 from ._logger import _LOGGER
-from .path_manager import make_fullpath
-from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
 
 
 __all__ = [
-    "MLTrainer",
-    "ObjectDetectionTrainer"
+    "DragonTrainer",
+    "DragonDetectionTrainer",
+    "DragonSequenceTrainer"
 ]
 
-
-class MLTrainer:
-    def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
-                 kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification", "segmentation"],
-                 criterion: nn.Module, optimizer: torch.optim.Optimizer,
-                 device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
-        """
-        Automates the training process of a PyTorch Model.
-
-        Built-in Callbacks: `History`, `TqdmProgressBar`
-
-        Args:
-            model (nn.Module): The PyTorch model to train.
-            train_dataset (Dataset): The training dataset.
-            test_dataset (Dataset): The testing/validation dataset.
-            kind (str): Can be 'regression', 'classification', 'multi_target_regression', 'multi_label_classification', or 'segmentation'.
-            criterion (nn.Module): The loss function.
-            optimizer (torch.optim.Optimizer): The optimizer.
-            device (str): The device to run training on ('cpu', 'cuda', 'mps').
-            dataloader_workers (int): Subprocesses for data loading.
-            callbacks (List[Callback] | None): A list of callbacks to use during training.
-
-        Note:
-            - For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`.
-
-            - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice.
+class _BaseDragonTrainer(ABC):
+    """
+    Abstract base class for Dragon Trainers.
 
-            - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem.
-
-            - For **segmentation** tasks, `nn.CrossEntropyLoss` (for multi-class) or `nn.BCEWithLogitsLoss` (for binary) are common.
-        """
-        if kind not in ["regression", "classification", "multi_target_regression", "multi_label_classification", "segmentation"]:
-            raise ValueError(f"'{kind}' is not a valid task type.")
+    Handles the common training loop orchestration, checkpointing, callback
+    management, and device handling. Subclasses must implement the
+    task-specific logic (dataloaders, train/val steps, evaluation).
+    """
+    def __init__(self,
+                 model: nn.Module,
+                 optimizer: torch.optim.Optimizer,
+                 device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                 dataloader_workers: int = 2,
+                 checkpoint_callback: Optional[DragonModelCheckpoint] = None,
+                 early_stopping_callback: Optional[DragonEarlyStopping] = None,
+                 lr_scheduler_callback: Optional[DragonLRScheduler] = None,
+                 extra_callbacks: Optional[List[_Callback]] = None):
 
         self.model = model
-        self.train_dataset = train_dataset
-        self.test_dataset = test_dataset
-        self.kind = kind
-        self.criterion = criterion
         self.optimizer = optimizer
         self.scheduler = None
         self.device = self._validate_device(device)
         self.dataloader_workers = dataloader_workers
 
-        # Callback handler - History and TqdmProgressBar are added by default
+        # Callback handler
         default_callbacks = [History(), TqdmProgressBar()]
-        user_callbacks = callbacks if callbacks is not None else []
+
+        self._checkpoint_callback = None
+        if checkpoint_callback:
+            default_callbacks.append(checkpoint_callback)
+            self._checkpoint_callback = checkpoint_callback
+        if early_stopping_callback:
+            default_callbacks.append(early_stopping_callback)
+        if lr_scheduler_callback:
+            default_callbacks.append(lr_scheduler_callback)
+
+        user_callbacks = extra_callbacks if extra_callbacks is not None else []
         self.callbacks = default_callbacks + user_callbacks
         self._set_trainer_on_callbacks()
 
         # Internal state
-        self.train_loader = None
-        self.test_loader = None
-        self.history = {}
+        self.train_loader: Optional[DataLoader] = None
+        self.validation_loader: Optional[DataLoader] = None
+        self.history: Dict[str, List[Any]] = {}
         self.epoch = 0
         self.epochs = 0 # Total epochs for the fit run
         self.start_epoch = 1
@@ -96,32 +96,10 @@ class MLTrainer:
         for callback in self.callbacks:
             callback.set_trainer(self)
 
-    def _create_dataloaders(self, batch_size: int, shuffle: bool):
-        """Initializes the DataLoaders."""
-        # Ensure stability on MPS devices by setting num_workers to 0
-        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
-
-        self.train_loader = DataLoader(
-            dataset=self.train_dataset,
-            batch_size=batch_size,
-            shuffle=shuffle,
-            num_workers=loader_workers,
-            pin_memory=("cuda" in self.device.type),
-            drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
-        )
-
-        self.test_loader = DataLoader(
-            dataset=self.test_dataset,
-            batch_size=batch_size,
-            shuffle=False,
-            num_workers=loader_workers,
-            pin_memory=("cuda" in self.device.type)
-        )
-
     def _load_checkpoint(self, path: Union[str, Path]):
         """Loads a training checkpoint to resume training."""
         p = make_fullpath(path, enforce="file")
-        _LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
+        _LOGGER.info(f"Loading checkpoint from '{p.name}'...")
 
         try:
             checkpoint = torch.load(p, map_location=self.device)
@@ -132,7 +110,16 @@ class MLTrainer:
 
             self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
             self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
-            self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
+            self.epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0)
+            self.start_epoch = self.epoch + 1 # Resume on the *next* epoch
+
+            # --- Load History ---
+            if PyTorchCheckpointKeys.HISTORY in checkpoint:
+                self.history = checkpoint[PyTorchCheckpointKeys.HISTORY]
+                _LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
+            else:
+                _LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
+                self.history = {} # Ensure it's at least an empty dict
 
             # --- Scheduler State Loading Logic ---
             scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
@@ -162,7 +149,7 @@ class MLTrainer:
 
             # Restore callback states
             for cb in self.callbacks:
-                if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
+                if isinstance(cb, DragonModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
                     cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
                     _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
 
@@ -173,7 +160,8 @@ class MLTrainer:
            raise
 
    def fit(self,
-            epochs: int = 10,
+            save_dir: Union[str,Path],
+            epochs: int = 100,
            batch_size: int = 10,
            shuffle: bool = True,
            resume_from_checkpoint: Optional[Union[str, Path]] = None):
@@ -183,21 +171,15 @@ class MLTrainer:
         Returns the "History" callback dictionary.
 
         Args:
+            save_dir (str | Path): Directory to save the loss plot.
             epochs (int): The total number of epochs to train for.
             batch_size (int): The number of samples per batch.
             shuffle (bool): Whether to shuffle the training data at each epoch.
             resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
-
-        Note:
-            For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
-            automatically aligns the model's output tensor with the target tensor's
-            shape using `output.view_as(target)`. This handles the common case
-            where a model outputs a shape of `[batch_size, 1]` and the target has a
-            shape of `[batch_size]`.
        """
        self.epochs = epochs
        self._batch_size = batch_size
-        self._create_dataloaders(self._batch_size, shuffle)
+        self._create_dataloaders(self._batch_size, shuffle) # type: ignore
        self.model.to(self.device)
 
        if resume_from_checkpoint:
@@ -208,11 +190,19 @@ class MLTrainer:
 
        self._callbacks_hook('on_train_begin')
 
+        if not self.train_loader:
+            _LOGGER.error("Train loader is not initialized.")
+            raise ValueError()
+
+        if not self.validation_loader:
+            _LOGGER.error("Validation loader is not initialized.")
+            raise ValueError()
+
        for epoch in range(self.start_epoch, self.epochs + 1):
            self.epoch = epoch
-            epoch_logs = {}
+            epoch_logs: Dict[str, Any] = {}
            self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
-
+
            train_logs = self._train_step()
            epoch_logs.update(train_logs)
 
@@ -226,11 +216,185 @@ class MLTrainer:
                break
 
        self._callbacks_hook('on_train_end')
+
+        # Training History
+        plot_losses(self.history, save_dir=save_dir)
+
        return self.history
+
+    def _callbacks_hook(self, method_name: str, *args, **kwargs):
+        """Calls the specified method on all callbacks."""
+        for callback in self.callbacks:
+            method = getattr(callback, method_name)
+            method(*args, **kwargs)
+
+    def to_cpu(self):
+        """
+        Moves the model to the CPU and updates the trainer's device setting.
+
+        This is useful for running operations that require the CPU.
+        """
+        self.device = torch.device('cpu')
+        self.model.to(self.device)
+        _LOGGER.info("Trainer and model moved to CPU.")
+
+    def to_device(self, device: str):
+        """
+        Moves the model to the specified device and updates the trainer's device setting.
+
+        Args:
+            device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+        """
+        self.device = self._validate_device(device)
+        self.model.to(self.device)
+        _LOGGER.info(f"Trainer and model moved to {self.device}.")
+
+    # --- Abstract Methods ---
+    # These must be implemented by subclasses
+
+    @abstractmethod
+    def _create_dataloaders(self, batch_size: int, shuffle: bool):
+        """Initializes the DataLoaders."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def _train_step(self) -> Dict[str, float]:
+        """Runs a single training epoch."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def _validation_step(self) -> Dict[str, float]:
+        """Runs a single validation epoch."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def evaluate(self, *args, **kwargs):
+        """Runs the full model evaluation."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def _evaluate(self, *args, **kwargs):
+        """Internal evaluation helper."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def finalize_model_training(self, *args, **kwargs):
+        """Saves the finalized model for inference."""
+        raise NotImplementedError
+
+
+# --- DragonTrainer ----
+class DragonTrainer(_BaseDragonTrainer):
+    def __init__(self,
+                 model: nn.Module,
+                 train_dataset: Dataset,
+                 validation_dataset: Dataset,
+                 kind: Literal["regression", "binary classification", "multiclass classification",
+                               "multitarget regression", "multilabel binary classification",
+                               "binary segmentation", "multiclass segmentation", "binary image classification", "multiclass image classification"],
+                 optimizer: torch.optim.Optimizer,
+                 device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                 checkpoint_callback: Optional[DragonModelCheckpoint],
+                 early_stopping_callback: Optional[DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 extra_callbacks: Optional[List[_Callback]] = None,
+                 criterion: Union[nn.Module,Literal["auto"]] = "auto",
+                 dataloader_workers: int = 2):
+        """
+        Automates the training process of a PyTorch Model.
+
+        Built-in Callbacks: `History`, `TqdmProgressBar`
+
+        Args:
+            model (nn.Module): The PyTorch model to train.
+            train_dataset (Dataset): The training dataset.
+            validation_dataset (Dataset): The validation dataset.
+            kind (str): Used to redirect to the correct process.
+            criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task.
+            optimizer (torch.optim.Optimizer): The optimizer.
+            device (str): The device to run training on ('cpu', 'cuda', 'mps').
+            dataloader_workers (int): Subprocesses for data loading.
+            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
+
+        Note:
+            - For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`. The model should output as many logits as existing targets.
+
+            - For **single-label, binary classification**, `nn.BCEWithLogitsLoss` is the standard choice. The model should output a single logit.
 
+            - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice. The model should output as many logits as existing classes.
+
+            - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem. The model should output 1 logit per binary target.
+
+            - For **binary segmentation** tasks, `nn.BCEWithLogitsLoss` is common. The model should output a single logit.
+
+            - For **multiclass segmentation** tasks, `nn.CrossEntropyLoss` is the standard. The model should output as many logits as existing classes.
+        """
+        # Call the base class constructor with common parameters
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            dataloader_workers=dataloader_workers,
+            checkpoint_callback=checkpoint_callback,
+            early_stopping_callback=early_stopping_callback,
+            lr_scheduler_callback=lr_scheduler_callback,
+            extra_callbacks=extra_callbacks
+        )
+
+        if kind not in [MLTaskKeys.REGRESSION,
+                        MLTaskKeys.BINARY_CLASSIFICATION,
+                        MLTaskKeys.MULTICLASS_CLASSIFICATION,
+                        MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION,
+                        MLTaskKeys.MULTITARGET_REGRESSION,
+                        MLTaskKeys.BINARY_SEGMENTATION,
+                        MLTaskKeys.MULTICLASS_SEGMENTATION,
+                        MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
+                        MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+            raise ValueError(f"'{kind}' is not a valid task type.")
+
+        self.train_dataset = train_dataset
+        self.validation_dataset = validation_dataset
+        self.kind = kind
+        self._classification_threshold: float = 0.5
+
+        # loss function
+        if criterion == "auto":
+            if kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
+                self.criterion = nn.MSELoss()
+            elif kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
+                self.criterion = nn.BCEWithLogitsLoss()
+            elif kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
+                self.criterion = nn.CrossEntropyLoss()
+        else:
+            self.criterion = criterion
+
+    def _create_dataloaders(self, batch_size: int, shuffle: bool):
+        """Initializes the DataLoaders."""
+        # Ensure stability on MPS devices by setting num_workers to 0
+        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+        self.train_loader = DataLoader(
+            dataset=self.train_dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type),
+            drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
+        )
+
+        self.validation_loader = DataLoader(
+            dataset=self.validation_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type)
+        )
+
    def _train_step(self):
        self.model.train()
        running_loss = 0.0
+        total_samples = 0
+
        for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
            # Create a log dictionary for the batch
            batch_logs = {
@@ -244,9 +408,21 @@ class MLTrainer:
 
            output = self.model(features)
 
-            # Apply shape correction only for single-target regression
-            if self.kind == "regression":
-                output = output.view_as(target)
+            # --- Label Type/Shape Correction ---
+            # Cast target to float for BCE-based losses
+            if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
+                target = target.float()
+
+            # Reshape output to match target for single-logit tasks
+            if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                # If model outputs [N, 1] and target is [N], squeeze output
+                if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                    output = output.squeeze(1)
+
+            if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
+                if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
+                    output = output.squeeze(1)
 
            loss = self.criterion(output, target)
 
@@ -255,34 +431,58 @@ class MLTrainer:
 
            # Calculate batch loss and update running loss for the epoch
            batch_loss = loss.item()
-            running_loss += batch_loss * features.size(0)
+            batch_size = features.size(0)
+            running_loss += batch_loss * batch_size # Accumulate total loss
+            total_samples += batch_size # total samples
 
            # Add the batch loss to the logs and call the end-of-batch hook
            batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+        if total_samples == 0:
+            _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
+            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
 
-        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore
+        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples} # type: ignore
 
    def _validation_step(self):
        self.model.eval()
        running_loss = 0.0
+
        with torch.no_grad():
-            for features, target in self.test_loader: # type: ignore
+            for features, target in self.validation_loader: # type: ignore
                features, target = features.to(self.device), target.to(self.device)
 
                output = self.model(features)
-                # Apply shape correction only for single-target regression
-                if self.kind == "regression":
-                    output = output.view_as(target)
+
+                # --- Label Type/Shape Correction ---
+                # Cast target to float for BCE-based losses
+                if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
+                    target = target.float()
+
+                # Reshape output to match target for single-logit tasks
+                if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                    # If model outputs [N, 1] and target is [N], squeeze output
+                    if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                        output = output.squeeze(1)
+
+                if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                    # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
+                    if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
+                        output = output.squeeze(1)
 
                loss = self.criterion(output, target)
 
                running_loss += loss.item() * features.size(0)
+
+        if not self.validation_loader.dataset: # type: ignore
+            _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+            return {PyTorchLogKeys.VAL_LOSS: 0.0}
 
-        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)} # type: ignore
+        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)} # type: ignore
        return logs
 
-    def _predict_for_eval(self, dataloader: DataLoader, classification_threshold: float = 0.5):
+    def _predict_for_eval(self, dataloader: DataLoader):
        """
        Private method to yield model predictions batch by batch for evaluation.
 
@@ -293,6 +493,7 @@ class MLTrainer:
        """
        self.model.eval()
        self.model.to(self.device)
+
        with torch.no_grad():
            for features, target in dataloader:
                features = features.to(self.device)
@@ -302,25 +503,64 @@ class MLTrainer:
                y_prob_batch = None
                y_true_batch = None
 
-                if self.kind in ["regression", "multi_target_regression"]:
+                if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
                    y_pred_batch = output.numpy()
                    y_true_batch = target.numpy()
+
+                elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                    # Assumes model output is [N, 1] (a single logit)
+                    # Squeeze output from [N, 1] to [N] if necessary
+                    if output.ndim == 2 and output.shape[1] == 1:
+                        output = output.squeeze(1)
+
+                    probs_pos = torch.sigmoid(output) # Probability of positive class
+                    preds = (probs_pos >= self._classification_threshold).int()
+                    y_pred_batch = preds.numpy()
+                    # For metrics (like ROC AUC), we often need probs for *both* classes
+                    # Create an [N, 2] array: [prob_class_0, prob_class_1]
+                    probs_neg = 1.0 - probs_pos
+                    y_prob_batch = torch.stack([probs_neg, probs_pos], dim=1).numpy()
+                    y_true_batch = target.numpy()
 
-                elif self.kind == "classification":
+                elif self.kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+                    num_classes = output.shape[1]
+                    if num_classes < 3:
+                        # Optional: warn the user they are using the wrong kind
+                        wrong_class = MLTaskKeys.MULTICLASS_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
+                        recommended_class = MLTaskKeys.BINARY_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.BINARY_IMAGE_CLASSIFICATION
+                        _LOGGER.warning(f"'{wrong_class}' kind used with {num_classes} classes. Consider using '{recommended_class}' instead.")
+
                    probs = torch.softmax(output, dim=1)
                    preds = torch.argmax(probs, dim=1)
                    y_pred_batch = preds.numpy()
                    y_prob_batch = probs.numpy()
                    y_true_batch = target.numpy()
 
-                elif self.kind == "multi_label_classification":
+                elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
                    probs = torch.sigmoid(output)
-                    preds = (probs >= classification_threshold).int()
+                    preds = (probs >= self._classification_threshold).int()
                    y_pred_batch = preds.numpy()
                    y_prob_batch = probs.numpy()
                    y_true_batch = target.numpy()
+
+                elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                    # Assumes model output is [N, 1, H, W] (logits for positive class)
+                    probs_pos = torch.sigmoid(output) # Shape [N, 1, H, W]
+                    preds = (probs_pos >= self._classification_threshold).int() # Shape [N, 1, H, W]
+
+                    # Squeeze preds to [N, H, W] (class indices 0 or 1)
+                    y_pred_batch = preds.squeeze(1).numpy()
+
+                    # Create [N, 2, H, W] probs for consistency
+                    probs_neg = 1.0 - probs_pos
+                    y_prob_batch = torch.cat([probs_neg, probs_pos], dim=1).numpy()
+
+                    # Handle target shape [N, 1, H, W] -> [N, H, W]
+                    if target.ndim == 4 and target.shape[1] == 1:
+                        target = target.squeeze(1)
+                    y_true_batch = target.numpy()
 
-                elif self.kind == "segmentation":
+                elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION:
                    # output shape [N, C, H, W]
                    probs = torch.softmax(output, dim=1)
                    preds = torch.argmax(probs, dim=1) # shape [N, H, W]
@@ -333,24 +573,161 @@ class MLTrainer:
                    y_true_batch = target.numpy()
 
                yield y_pred_batch, y_prob_batch, y_true_batch
-
-    def evaluate(self, save_dir: Union[str, Path], data: Optional[Union[DataLoader, Dataset]] = None, classification_threshold: float = 0.5):
+
+    def evaluate(self,
+                 save_dir: Union[str, Path],
+                 model_checkpoint: Union[Path, Literal["latest", "current"]],
+                 classification_threshold: Optional[float] = None,
+                 test_data: Optional[Union[DataLoader, Dataset]] = None,
+                 val_format_configuration: Optional[Union[ClassificationMetricsFormat,
+                                                          MultiClassificationMetricsFormat,
+                                                          RegressionMetricsFormat,
+                                                          SegmentationMetricsFormat]]=None,
+                 test_format_configuration: Optional[Union[ClassificationMetricsFormat,
+                                                           MultiClassificationMetricsFormat,
+                                                           RegressionMetricsFormat,
+                                                           SegmentationMetricsFormat]]=None):
        """
        Evaluates the model, routing to the correct evaluation function based on task `kind`.
 
        Args:
+            model_checkpoint (Path | "latest" | "current"):
+                - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+                - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+                - If 'current', use the current state of the trained model up to the latest trained epoch.
            save_dir (str | Path): Directory to save all reports and plots.
-            data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
-            classification_threshold (float): Probability threshold for multi-label tasks.
+            classification_threshold (float | None): Used for tasks taking a binary approach (binary classification, binary segmentation, multilabel binary classification).
+            test_data (DataLoader | Dataset | None): Optional test data to evaluate the model performance. Validation and test metrics will be saved to subdirectories.
+            val_format_configuration: Optional configuration for the metric format output for the validation set.
+            test_format_configuration: Optional configuration for the metric format output for the test set.
+        """
+        # Validate model checkpoint
+        if isinstance(model_checkpoint, Path):
+            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+        elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+            checkpoint_validated = model_checkpoint
+        else:
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
+            raise ValueError()
+
+        # Validate classification threshold
+        if self.kind not in MLTaskKeys.ALL_BINARY_TASKS:
+            # dummy value for tasks that do not need it
+            threshold_validated = 0.5
+        elif classification_threshold is None:
+            # it should have been provided for binary tasks
+            _LOGGER.error(f"The classification threshold must be provided for '{self.kind}'.")
+            raise ValueError()
+        elif classification_threshold <= 0.0 or classification_threshold >= 1.0:
+            # Invalid float
+            _LOGGER.error(f"A classification threshold of {classification_threshold} is invalid. Must be in the range (0.0 - 1.0).")
+            raise ValueError()
+        else:
+            threshold_validated = classification_threshold
+
+        # Validate val configuration
+        if val_format_configuration is not None:
+            if not isinstance(val_format_configuration, (ClassificationMetricsFormat,
+                                                         MultiClassificationMetricsFormat,
+                                                         RegressionMetricsFormat,
+                                                         SegmentationMetricsFormat)):
+                _LOGGER.error(f"Invalid 'format_configuration': '{type(val_format_configuration)}'.")
+                raise ValueError()
+            else:
+                val_configuration_validated = val_format_configuration
+        else: # config is None
+            val_configuration_validated = None
+
+        # Validate directory
+        save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+        # Validate test data and dispatch
+        if test_data is not None:
+            if not isinstance(test_data, (DataLoader, Dataset)):
+                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                raise ValueError()
+            test_data_validated = test_data
+
+            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+            self._evaluate(save_dir=validation_metrics_path,
+                           model_checkpoint=checkpoint_validated,
+                           classification_threshold=threshold_validated,
+                           data=None,
+                           format_configuration=val_configuration_validated)
+
+            # Validate test configuration
+            if test_format_configuration is not None:
+                if not isinstance(test_format_configuration, (ClassificationMetricsFormat,
+                                                              MultiClassificationMetricsFormat,
+                                                              RegressionMetricsFormat,
+                                                              SegmentationMetricsFormat)):
+                    warning_message_type = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
+                    if val_configuration_validated is not None:
+                        warning_message_type += " 'val_format_configuration' will be used for the test set metrics output."
+                        test_configuration_validated = val_configuration_validated
+                    else:
+                        warning_message_type += " Using default format."
+                        test_configuration_validated = None
+                    _LOGGER.warning(warning_message_type)
+                else:
+                    test_configuration_validated = test_format_configuration
+            else: # config is None
+                test_configuration_validated = None
+
+            # Dispatch test set
+            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+            self._evaluate(save_dir=test_metrics_path,
+                           model_checkpoint="current",
+                           classification_threshold=threshold_validated,
+                           data=test_data_validated,
+                           format_configuration=test_configuration_validated)
+        else:
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+            self._evaluate(save_dir=save_path,
+                           model_checkpoint=checkpoint_validated,
+                           classification_threshold=threshold_validated,
+                           data=None,
+                           format_configuration=val_configuration_validated)
+
+    def _evaluate(self,
+                  save_dir: Union[str, Path],
+                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  classification_threshold: float,
+                  data: Optional[Union[DataLoader, Dataset]],
+                  format_configuration: Optional[Union[ClassificationMetricsFormat,
+                                                       MultiClassificationMetricsFormat,
+                                                       RegressionMetricsFormat,
+                                                       SegmentationMetricsFormat]]):
+        """
+        Private evaluation helper.
        """
        dataset_for_names = None
        eval_loader = None
-
+
+        # set threshold
+        self._classification_threshold = classification_threshold
+
+        # load model checkpoint
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+            raise ValueError()
+
+        # Dataloader
        if isinstance(data, DataLoader):
            eval_loader = data
            # Try to get the dataset from the loader for fetching target names
            if hasattr(data, 'dataset'):
-                dataset_for_names = data.dataset
+                dataset_for_names = data.dataset # type: ignore
        elif isinstance(data, Dataset):
            # Create a new loader from the provided dataset
            eval_loader = DataLoader(data,
@@ -360,26 +737,26 @@ class MLTrainer:
                                     pin_memory=(self.device.type == "cuda"))
            dataset_for_names = data
        else: # data is None, use the trainer's default test dataset
-            if self.test_dataset is None:
+            if self.validation_dataset is None:
                _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
                raise ValueError()
            # Create a fresh DataLoader from the test_dataset
-            eval_loader = DataLoader(self.test_dataset,
+            eval_loader = DataLoader(self.validation_dataset,
                                     batch_size=self._batch_size,
                                     shuffle=False,
                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                     pin_memory=(self.device.type == "cuda"))
 
-            dataset_for_names = self.test_dataset
+            dataset_for_names = self.validation_dataset
 
        if eval_loader is None:
            _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
            raise ValueError()
 
-        print("\n--- Model Evaluation ---")
+        # print("\n--- Model Evaluation ---")
 
        all_preds, all_probs, all_true = [], [], []
-        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader, classification_threshold):
+        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
            if y_pred_b is not None: all_preds.append(y_pred_b)
            if y_prob_b is not None: all_probs.append(y_prob_b)
            if y_true_b is not None: all_true.append(y_true_b)
@@ -393,22 +770,55 @@ class MLTrainer:
        y_prob = np.concatenate(all_probs) if all_probs else None
 
        # --- Routing Logic ---
-        if self.kind == "regression":
-            regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir)
-
-        elif self.kind == "classification":
-            classification_metrics(save_dir, y_true, y_pred, y_prob)
-
-        elif self.kind == "multi_target_regression":
+        if self.kind == MLTaskKeys.REGRESSION:
+            # Check configuration
+            config = None
+            if format_configuration and isinstance(format_configuration, RegressionMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected RegressionMetricsFormat.")
+
+            regression_metrics(y_true=y_true.flatten(),
+                               y_pred=y_pred.flatten(),
+                               save_dir=save_dir,
+                               config=config)
+
+        elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+            # Check configuration
+            config = None
+            if format_configuration and isinstance(format_configuration, ClassificationMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected ClassificationMetricsFormat.")
+
+            classification_metrics(save_dir=save_dir,
+                                   y_true=y_true,
+                                   y_pred=y_pred,
+                                   y_prob=y_prob,
+                                   config=config)
+
+        elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION:
            try:
                target_names = dataset_for_names.target_names # type: ignore
            except AttributeError:
                num_targets = y_true.shape[1]
                target_names = [f"target_{i}" for i in range(num_targets)]
                _LOGGER.warning(f"Dataset has no 'target_names' attribute. Using generic names.")
-            multi_target_regression_metrics(y_true, y_pred, target_names, save_dir)
+
+            # Check configuration
+            config = None
+            if format_configuration and isinstance(format_configuration, RegressionMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected RegressionMetricsFormat.")
+
+            multi_target_regression_metrics(y_true=y_true,
+                                            y_pred=y_pred,
+                                            target_names=target_names,
+                                            save_dir=save_dir,
+                                            config=config)
 
-        elif self.kind == "multi_label_classification":
+        elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
            try:
                target_names = dataset_for_names.target_names # type: ignore
            except AttributeError:
@@ -419,9 +829,22 @@ class MLTrainer:
            if y_prob is None:
                _LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
                return
-            multi_label_classification_metrics(y_true, y_prob, target_names, save_dir, classification_threshold)
 
-        elif self.kind == "segmentation":
+            # Check configuration
+            config = None
+            if format_configuration and isinstance(format_configuration, MultiClassificationMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected MultiClassificationMetricsFormat.")
+
+            multi_label_classification_metrics(y_true=y_true,
+                                               y_pred=y_pred,
+                                               y_prob=y_prob,
+                                               target_names=target_names,
+                                               save_dir=save_dir,
+                                               config=config)
+
+        elif self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
            class_names = None
            try:
                # Try to get 'classes' from VisionDatasetMaker
@@ -443,10 +866,18 @@ class MLTrainer:
                class_names = [f"Class {i}" for i in labels]
                _LOGGER.warning(f"Dataset has no 'classes' or 'target_names' attribute. Using generic names.")
 
-            segmentation_metrics(y_true, y_pred, save_dir, class_names=class_names)
-
-        print("\n--- Training History ---")
-        plot_losses(self.history, save_dir=save_dir)
+            # Check configuration
+            config = None
+            if format_configuration and isinstance(format_configuration, SegmentationMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong configuration type: Received {type(format_configuration).__name__}, expected SegmentationMetricsFormat.")
+
+            segmentation_metrics(y_true=y_true,
+                                 y_pred=y_pred,
+                                 save_dir=save_dir,
+                                 class_names=class_names,
+                                 config=config)
 
    def explain(self,
                save_dir: Union[str,Path],
@@ -502,7 +933,7 @@ class MLTrainer:
            rand_indices = torch.randperm(full_data.size(0))[:num_samples]
            return full_data[rand_indices]
 
-        print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
+        # print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
 
        # 1. Get background data from the trainer's train_dataset
        background_data = _get_random_sample(self.train_dataset, n_samples)
@@ -511,7 +942,7 @@ class MLTrainer:
            return
 
        # 2. Determine target dataset and get explanation instances
-        target_dataset = explain_dataset if explain_dataset is not None else self.test_dataset
+        target_dataset = explain_dataset if explain_dataset is not None else self.validation_dataset
        instances_to_explain = _get_random_sample(target_dataset, n_samples)
        if instances_to_explain is None:
            _LOGGER.error("Explanation dataset is empty or invalid. Skipping SHAP analysis.")
@@ -530,7 +961,7 @@ class MLTrainer:
        self.model.to(self.device)
 
        # 3. Call the plotting function
-        if self.kind in ["regression", "classification"]:
+        if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
            shap_summary_plot(
                model=self.model,
                background_data=background_data,
543
- elif self.kind in ["multi_target_regression", "multi_label_classification"]:
974
+ elif self.kind in [MLTaskKeys.MULTITARGET_REGRESSION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
544
975
  # try to get target names
545
976
  if target_names is None:
546
977
  target_names = []
@@ -610,17 +1041,15 @@ class MLTrainer:
            plot_n_features (int): Number of top features to plot.
        """
 
-        print("\n--- Attention Analysis ---")
+        # print("\n--- Attention Analysis ---")
 
        # --- Step 1: Check if the model supports this explanation ---
        if not getattr(self.model, 'has_interpretable_attention', False):
-            _LOGGER.warning(
-                "Model is not flagged for interpretable attention analysis. Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
-            )
+            _LOGGER.warning("Model is not compatible with interpretable attention analysis. Skipping.")
            return
 
        # --- Step 2: Set up the dataloader ---
-        dataset_to_use = explain_dataset if explain_dataset is not None else self.test_dataset
+        dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
        if not isinstance(dataset_to_use, Dataset):
            _LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
            return
@@ -655,40 +1084,111 @@ class MLTrainer:
            )
        else:
            _LOGGER.error("No attention weights were collected from the model.")
-
-    def _callbacks_hook(self, method_name: str, *args, **kwargs):
-        """Calls the specified method on all callbacks."""
-        for callback in self.callbacks:
-            method = getattr(callback, method_name)
-            method(*args, **kwargs)
-
-    def to_cpu(self):
-        """
-        Moves the model to the CPU and updates the trainer's device setting.
 
-        This is useful for running operations that require the CPU.
-        """
-        self.device = torch.device('cpu')
-        self.model.to(self.device)
-        _LOGGER.info("Trainer and model moved to CPU.")
-
-    def to_device(self, device: str):
+    def finalize_model_training(self,
+                                save_dir: Union[str, Path],
+                                filename: str,
+                                model_checkpoint: Union[Path, Literal['latest', 'current']],
+                                classification_threshold: Optional[float]=None,
+                                class_map: Optional[Dict[str,int]]=None):
        """
-        Moves the model to the specified device and updates the trainer's device setting.
+        Saves a finalized, "inference-ready" model state to a .pth file.
+
+        This method saves the model's `state_dict`, the final epoch number, and
+        an optional classification threshold required for binary-based tasks (binary classification, binary segmentation,
+        multilabel binary classification).
 
        Args:
-            device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+            save_dir (str | Path): The directory to save the finalized model.
+            filename (str): The desired filename for the saved file.
+            model_checkpoint (Path | "latest" | "current"):
+                - Path: Loads the model state from a specific checkpoint file.
+                - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                - "current": Uses the model's state as it is at the end of the `fit()` call.
+            classification_threshold (float | None):
+                Required for `binary classification`, `binary segmentation`, and
+                `multilabel binary classification`. This is the threshold (0.0-1.0)
+                used to convert probabilities to class labels.
+            class_map (Dict[str, int] | None): Sets the class name mapping to translate predicted integer labels back into string names. (For classification and segmentation tasks.)
        """
-        self.device = self._validate_device(device)
-        self.model.to(self.device)
-        _LOGGER.info(f"Trainer and model moved to {self.device}.")
+        # handle save path
+        sanitized_filename = sanitize_filename(filename)
+        if not sanitized_filename.endswith(".pth"):
+            sanitized_filename = sanitized_filename + ".pth"
+
+        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+        full_path = dir_path / sanitized_filename
+
+        # threshold required for binary tasks
+        if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
+            if classification_threshold is None:
+                _LOGGER.error(f"A classification threshold is needed for binary-based classification tasks. If unknown, use '0.5' as a default.")
+                raise ValueError()
+            elif not isinstance(classification_threshold, float):
+                _LOGGER.error(f"The classification threshold must be a float value.")
+                raise TypeError()
+            elif classification_threshold <= 0.0 or classification_threshold >= 1.0:
+                _LOGGER.error(f"The classification threshold must be in the range (0.0 - 1.0).")
+        else:
+            classification_threshold = None
+
+        # handle checkpoint
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+            raise ValueError()
+        elif model_checkpoint == MagicWords.CURRENT:
+            pass
+        else:
+            _LOGGER.error(f"Unknown 'model_checkpoint' parameter received '{model_checkpoint}'.")
+
+        # Handle class map
+        if self.kind in [MLTaskKeys.BINARY_CLASSIFICATION,
+                         MLTaskKeys.MULTICLASS_CLASSIFICATION,
+                         MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
+                         MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION,
+                         MLTaskKeys.BINARY_SEGMENTATION,
+                         MLTaskKeys.MULTICLASS_SEGMENTATION]:
+            if class_map is None:
+                _LOGGER.error(f"'class_map' is required for '{self.kind}'.")
+                raise ValueError()
+        else:
+            class_map = None
+
+        # Create finalized data
+        finalized_data = {
+            PyTorchCheckpointKeys.EPOCH: self.epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+        }
+
+        if classification_threshold is not None:
+            self._classification_threshold = classification_threshold
+            finalized_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD] = classification_threshold
+
+        if class_map is not None:
+            finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = class_map
+
+        torch.save(finalized_data, full_path)
+
+        _LOGGER.info(f"Finalized model weights saved to {full_path}.")
 
 
 # Object Detection Trainer
-class ObjectDetectionTrainer:
-    def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
+class DragonDetectionTrainer(_BaseDragonTrainer):
+    def __init__(self, model: nn.Module,
+                 train_dataset: Dataset,
+                 validation_dataset: Dataset,
                 collate_fn: Callable, optimizer: torch.optim.Optimizer,
-                 device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
+                 device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                 checkpoint_callback: Optional[DragonModelCheckpoint],
+                 early_stopping_callback: Optional[DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 extra_callbacks: Optional[List[_Callback]] = None,
+                 dataloader_workers: int = 2):
        """
        Automates the training process of an Object Detection Model (e.g., DragonFastRCNN).
 
@@ -697,58 +1197,36 @@ class ObjectDetectionTrainer:
697
1197
  Args:
698
1198
  model (nn.Module): The PyTorch object detection model to train.
699
1199
  train_dataset (Dataset): The training dataset.
700
- test_dataset (Dataset): The testing/validation dataset.
1200
+ validation_dataset (Dataset): The testing/validation dataset.
701
1201
  collate_fn (Callable): The collate function from `ObjectDetectionDatasetMaker.collate_fn`.
702
1202
  optimizer (torch.optim.Optimizer): The optimizer.
703
1203
  device (str): The device to run training on ('cpu', 'cuda', 'mps').
704
1204
  dataloader_workers (int): Subprocesses for data loading.
705
- callbacks (List[Callback] | None): A list of callbacks to use during training.
1205
+ checkpoint_callback (DragonModelCheckpoint | None): Callback to save the model.
1206
+ early_stopping_callback (DragonEarlyStopping | None): Callback to stop training early.
1207
+ lr_scheduler_callback (DragonLRScheduler | None): Callback to manage the LR scheduler.
1208
+ extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
706
1209
 
707
1210
  ## Note:
708
1211
  This trainer is specialized. It does not take a `criterion` because object detection models like Faster R-CNN return a dictionary of losses directly from their forward pass during training.
709
1212
  """
- self.model = model
+ # Call the base class constructor with common parameters
+ super().__init__(
+ model=model,
+ optimizer=optimizer,
+ device=device,
+ dataloader_workers=dataloader_workers,
+ checkpoint_callback=checkpoint_callback,
+ early_stopping_callback=early_stopping_callback,
+ lr_scheduler_callback=lr_scheduler_callback,
+ extra_callbacks=extra_callbacks
+ )
+
  self.train_dataset = train_dataset
- self.test_dataset = test_dataset
+ self.validation_dataset = validation_dataset # <-- Renamed
  self.kind = "object_detection"
  self.collate_fn = collate_fn
  self.criterion = None # Criterion is handled inside the model
- self.optimizer = optimizer
- self.scheduler = None
- self.device = self._validate_device(device)
- self.dataloader_workers = dataloader_workers
-
- # Callback handler - History and TqdmProgressBar are added by default
- default_callbacks = [History(), TqdmProgressBar()]
- user_callbacks = callbacks if callbacks is not None else []
- self.callbacks = default_callbacks + user_callbacks
- self._set_trainer_on_callbacks()
-
- # Internal state
- self.train_loader = None
- self.test_loader = None
- self.history = {}
- self.epoch = 0
- self.epochs = 0 # Total epochs for the fit run
- self.start_epoch = 1
- self.stop_training = False
- self._batch_size = 10
-
- def _validate_device(self, device: str) -> torch.device:
- """Validates the selected device and returns a torch.device object."""
- device_lower = device.lower()
- if "cuda" in device_lower and not torch.cuda.is_available():
- _LOGGER.warning("CUDA not available, switching to CPU.")
- device = "cpu"
- elif device_lower == "mps" and not torch.backends.mps.is_available():
- _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
- device = "cpu"
- return torch.device(device)
-
- def _set_trainer_on_callbacks(self):
- """Gives each callback a reference to this trainer instance."""
- for callback in self.callbacks:
- callback.set_trainer(self)
 
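A minimal construction sketch for the refactored trainer, assuming a detection model and datasets prepared elsewhere in this package; the variable names, optimizer settings, and the elided checkpoint arguments are illustrative, not part of this diff:

    # Illustrative wiring of the new per-role callback parameters.
    checkpoint_cb = DragonModelCheckpoint(...)  # arguments elided; see ML_callbacks
    trainer = ObjectDetectionTrainer(
        model=detection_model,                  # e.g., a DragonFastRCNN instance
        train_dataset=train_ds,
        validation_dataset=val_ds,              # renamed from test_dataset
        collate_fn=ObjectDetectionDatasetMaker.collate_fn,
        optimizer=torch.optim.AdamW(detection_model.parameters(), lr=1e-4),
        device="cuda",
        checkpoint_callback=checkpoint_cb,
        early_stopping_callback=None,
        lr_scheduler_callback=None,
        extra_callbacks=None,
    )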
  def _create_dataloaders(self, batch_size: int, shuffle: bool):
  """Initializes the DataLoaders with the object detection collate_fn."""
@@ -760,125 +1238,25 @@ class ObjectDetectionTrainer:
  batch_size=batch_size,
  shuffle=shuffle,
  num_workers=loader_workers,
- pin_memory=("cuda" in self.device.type),
- collate_fn=self.collate_fn # Use the provided collate function
+ pin_memory=("cuda" in self.device.type),
+ collate_fn=self.collate_fn, # Use the provided collate function
+ drop_last=True
  )
 
- self.test_loader = DataLoader(
- dataset=self.test_dataset,
+ self.validation_loader = DataLoader(
+ dataset=self.validation_dataset,
  batch_size=batch_size,
  shuffle=False,
  num_workers=loader_workers,
  pin_memory=("cuda" in self.device.type),
  collate_fn=self.collate_fn # Use the provided collate function
  )
-
- def _load_checkpoint(self, path: Union[str, Path]):
- """Loads a training checkpoint to resume training."""
- p = make_fullpath(path, enforce="file")
- _LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
-
- try:
- checkpoint = torch.load(p, map_location=self.device)
-
- if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
- _LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
- raise KeyError()
-
- self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
- self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
- self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
-
- # --- Scheduler State Loading Logic ---
- scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
- scheduler_object_exists = self.scheduler is not None
-
- if scheduler_object_exists and scheduler_state_exists:
- # Case 1: Both exist. Attempt to load.
- try:
- self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
- scheduler_name = self.scheduler.__class__.__name__
- _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
- except Exception as e:
- # Loading failed, likely a mismatch
- scheduler_name = self.scheduler.__class__.__name__
- _LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
- raise e
-
- elif scheduler_object_exists and not scheduler_state_exists:
- # Case 2: Scheduler provided, but no state in checkpoint.
- scheduler_name = self.scheduler.__class__.__name__
- _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
-
- elif not scheduler_object_exists and scheduler_state_exists:
- # Case 3: State in checkpoint, but no scheduler provided.
- _LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
- raise ValueError()
-
- # Restore callback states
- for cb in self.callbacks:
- if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
- cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
- _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
-
- _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
-
- except Exception as e:
- _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
- raise
-
- def fit(self,
- epochs: int = 10,
- batch_size: int = 10,
- shuffle: bool = True,
- resume_from_checkpoint: Optional[Union[str, Path]] = None):
- """
- Starts the training-validation process of the model.
-
- Returns the "History" callback dictionary.
 
- Args:
- epochs (int): The total number of epochs to train for.
- batch_size (int): The number of samples per batch.
- shuffle (bool): Whether to shuffle the training data at each epoch.
- resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
- """
- self.epochs = epochs
- self._batch_size = batch_size
- self._create_dataloaders(self._batch_size, shuffle)
- self.model.to(self.device)
-
- if resume_from_checkpoint:
- self._load_checkpoint(resume_from_checkpoint)
-
- # Reset stop_training flag on the trainer
- self.stop_training = False
-
- self._callbacks_hook('on_train_begin')
-
- for epoch in range(self.start_epoch, self.epochs + 1):
- self.epoch = epoch
- epoch_logs = {}
- self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
-
- train_logs = self._train_step()
- epoch_logs.update(train_logs)
-
- val_logs = self._validation_step()
- epoch_logs.update(val_logs)
-
- self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
-
- # Check the early stopping flag
- if self.stop_training:
- break
-
- self._callbacks_hook('on_train_end')
- return self.history
-
  def _train_step(self):
  self.model.train()
  running_loss = 0.0
+ total_samples = 0
+
  for batch_idx, (images, targets) in enumerate(self.train_loader): # type: ignore
  # images is a tuple of tensors, targets is a tuple of dicts
  batch_size = len(images)
@@ -915,21 +1293,28 @@ class ObjectDetectionTrainer:
  # Calculate batch loss and update running loss for the epoch
  batch_loss = loss.item()
  running_loss += batch_loss * batch_size
+ total_samples += batch_size # <-- Accumulate total samples
 
  # Add the batch loss to the logs and call the end-of-batch hook
  batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss # type: ignore
  self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+ # Calculate loss using the correct denominator
+ if total_samples == 0:
+ _LOGGER.warning("No samples processed in _train_step. Returning 0 loss.")
+ return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
 
- return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore
+ return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}
 
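The epoch loss is now a sample-weighted mean guarded against an empty loader, rather than a division by len(self.train_loader.dataset). The same accumulation pattern in isolation; compute_batch_loss is a hypothetical stand-in for summing the model's loss dict:

    running_loss, total_samples = 0.0, 0
    for images, targets in loader:                        # loader yields (tuple_of_images, tuple_of_targets)
        batch = len(images)
        loss_value = compute_batch_loss(images, targets)  # hypothetical helper returning a float
        running_loss += loss_value * batch                # weight each batch by its true size
        total_samples += batch
    epoch_loss = running_loss / total_samples if total_samples else 0.0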
  def _validation_step(self):
  self.model.train() # Set to train mode even for validation loss calculation
- # as model internals (e.g., proposals) might differ,
- # but we still need loss_dict.
- # We use torch.no_grad() to prevent gradient updates.
+ # as model internals (e.g., proposals) might differ, but we still need loss_dict.
+ # use torch.no_grad() to prevent gradient updates.
  running_loss = 0.0
+ total_samples = 0
+
  with torch.no_grad():
- for images, targets in self.test_loader: # type: ignore
+ for images, targets in self.validation_loader: # type: ignore
  batch_size = len(images)
 
  # Move data to device
@@ -947,25 +1332,105 @@ class ObjectDetectionTrainer:
  loss: torch.Tensor = sum(l for l in loss_dict.values()) # type: ignore
 
  running_loss += loss.item() * batch_size
+ total_samples += batch_size # <-- Accumulate total samples
 
- logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)} # type: ignore
+ # Calculate loss using the correct denominator
+ if total_samples == 0:
+ _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+ return {PyTorchLogKeys.VAL_LOSS: 0.0}
+
+ logs = {PyTorchLogKeys.VAL_LOSS: running_loss / total_samples}
  return logs
+
+ def evaluate(self,
+ save_dir: Union[str, Path],
+ model_checkpoint: Union[Path, Literal["latest", "current"]],
+ test_data: Optional[Union[DataLoader, Dataset]] = None):
+ """
+ Evaluates the model using object detection mAP metrics.
 
- def evaluate(self, save_dir: Union[str, Path], data: Optional[Union[DataLoader, Dataset]] = None):
+ Args:
+ save_dir (str | Path): Directory to save all reports and plots.
+ model_checkpoint (Path | 'latest' | 'current'):
+ - Path: loads a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+ - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+ - If 'current', use the current state of the trained model up to the latest trained epoch.
+ test_data (DataLoader | Dataset | None): Optional test data to evaluate the model performance. Validation and test metrics will be saved to subdirectories.
  """
+ # Validate model checkpoint
+ if isinstance(model_checkpoint, Path):
+ checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+ elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+ checkpoint_validated = model_checkpoint
+ else:
+ _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
+ raise ValueError()
+
+ # Validate directory
+ save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+ # Validate test data and dispatch
+ if test_data is not None:
+ if not isinstance(test_data, (DataLoader, Dataset)):
+ _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+ raise ValueError()
+ test_data_validated = test_data
+
+ validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+ test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+ # Dispatch validation set
+ _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+ self._evaluate(save_dir=validation_metrics_path,
+ model_checkpoint=checkpoint_validated,
+ data=None) # 'None' triggers use of self.validation_dataset
+
+ # Dispatch test set
+ _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+ self._evaluate(save_dir=test_metrics_path,
+ model_checkpoint="current", # Use 'current' state after loading checkpoint once
+ data=test_data_validated)
+ else:
+ # Dispatch validation set
+ _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+ self._evaluate(save_dir=save_path,
+ model_checkpoint=checkpoint_validated,
+ data=None) # 'None' triggers use of self.validation_dataset
+
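A usage sketch of the two-stage dispatch above, assuming the trainer was built with a DragonModelCheckpoint callback and that fit() is provided by the shared base class (its definition is outside this hunk); paths and variable names are illustrative:

    trainer.fit(epochs=30, batch_size=8)
    trainer.evaluate(
        save_dir="reports/detection",   # illustrative output directory
        model_checkpoint="latest",      # load the best checkpoint before evaluating
        test_data=test_ds,              # optional; writes validation and test subdirectories
    )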
+ def _evaluate(self,
+ save_dir: Union[str, Path],
+ model_checkpoint: Union[Path, Literal["latest", "current"]],
+ data: Optional[Union[DataLoader, Dataset]]):
+ """
+ Private evaluation helper.
  Evaluates the model using object detection mAP metrics.
 
  Args:
  save_dir (str | Path): Directory to save all reports and plots.
  data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
+ model_checkpoint (Path | 'latest' | 'current'):
+ - Path: loads a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+ - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+ - If 'current', use the current state of the trained model up to the latest trained epoch.
  """
  dataset_for_names = None
  eval_loader = None
+
+ # load model checkpoint
+ if isinstance(model_checkpoint, Path):
+ self._load_checkpoint(path=model_checkpoint)
+ elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+ path_to_latest = self._checkpoint_callback.best_checkpoint_path
+ self._load_checkpoint(path_to_latest)
+ elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+ _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+ raise ValueError()
 
+ # Dataloader
  if isinstance(data, DataLoader):
  eval_loader = data
  if hasattr(data, 'dataset'):
- dataset_for_names = data.dataset
+ dataset_for_names = data.dataset # type: ignore
  elif isinstance(data, Dataset):
  # Create a new loader from the provided dataset
  eval_loader = DataLoader(data,
@@ -976,25 +1441,25 @@ class ObjectDetectionTrainer:
  collate_fn=self.collate_fn)
  dataset_for_names = data
  else: # data is None, use the trainer's default test dataset
- if self.test_dataset is None:
+ if self.validation_dataset is None:
  _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
  raise ValueError()
  # Create a fresh DataLoader from the test_dataset
  eval_loader = DataLoader(
- self.test_dataset,
+ self.validation_dataset,
  batch_size=self._batch_size,
  shuffle=False,
  num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
  pin_memory=(self.device.type == "cuda"),
  collate_fn=self.collate_fn
  )
- dataset_for_names = self.test_dataset
+ dataset_for_names = self.validation_dataset
 
  if eval_loader is None:
  _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
  raise ValueError()
 
- print("\n--- Model Evaluation ---")
+ # print("\n--- Model Evaluation ---")
 
  all_predictions = []
  all_targets = []
@@ -1042,36 +1507,480 @@ class ObjectDetectionTrainer:
  class_names=class_names,
  print_output=False
  )
-
- print("\n--- Training History ---")
- plot_losses(self.history, save_dir=save_dir)
 
- def _callbacks_hook(self, method_name: str, *args, **kwargs):
- """Calls the specified method on all callbacks."""
- for callback in self.callbacks:
- method = getattr(callback, method_name)
- method(*args, **kwargs)
+ def finalize_model_training(self, save_dir: Union[str, Path], filename: str, model_checkpoint: Union[Path, Literal['latest', 'current']]):
+ """
+ Saves a finalized, "inference-ready" model state to a .pth file.
+
+ This method saves the model's `state_dict` and the final epoch number.
+
+ Args:
+ save_dir (Union[str, Path]): The directory to save the finalized model.
+ filename (str): The desired filename for the model (e.g., "final_model.pth").
+ model_checkpoint (Union[Path, Literal["latest", "current"]]):
+ - Path: Loads the model state from a specific checkpoint file.
+ - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+ - "current": Uses the model's state as it is at the end of the `fit()` call.
+ """
+ # handle save path
+ sanitized_filename = sanitize_filename(filename)
+ if not sanitized_filename.endswith(".pth"):
+ sanitized_filename = sanitized_filename + ".pth"
 
- def to_cpu(self):
+ dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+ full_path = dir_path / sanitized_filename
+
+ # handle checkpoint
+ if isinstance(model_checkpoint, Path):
+ self._load_checkpoint(path=model_checkpoint)
+ elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+ path_to_latest = self._checkpoint_callback.best_checkpoint_path
+ self._load_checkpoint(path_to_latest)
+ elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+ _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+ raise ValueError()
+ elif model_checkpoint == MagicWords.CURRENT:
+ pass
+ else:
+ _LOGGER.error(f"Unknown 'model_checkpoint' parameter received '{model_checkpoint}'.")
+ raise ValueError()
+
+ # Create finalized data
+ finalized_data = {
+ PyTorchCheckpointKeys.EPOCH: self.epoch,
+ PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+ }
+
+ torch.save(finalized_data, full_path)
+
+ _LOGGER.info(f"Finalized model weights saved to {full_path}.")
+
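Reading the finalized file back is symmetric to saving it; a sketch assuming the same PyTorchCheckpointKeys used above (the filename is illustrative):

    payload = torch.load("final_detector.pth", map_location="cpu")
    detection_model.load_state_dict(payload[PyTorchCheckpointKeys.MODEL_STATE])
    detection_model.eval()  # inference-ready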
+ # --- DragonSequenceTrainer ----
+ class DragonSequenceTrainer(_BaseDragonTrainer):
+ def __init__(self,
+ model: nn.Module,
+ train_dataset: Dataset,
+ validation_dataset: Dataset,
+ kind: Literal["sequence-to-sequence", "sequence-to-value"],
+ optimizer: torch.optim.Optimizer,
+ device: Union[Literal['cuda', 'mps', 'cpu'], str],
+ checkpoint_callback: Optional[DragonModelCheckpoint],
+ early_stopping_callback: Optional[DragonEarlyStopping],
+ lr_scheduler_callback: Optional[DragonLRScheduler],
+ extra_callbacks: Optional[List[_Callback]] = None,
+ criterion: Union[nn.Module, Literal["auto"]] = "auto",
+ dataloader_workers: int = 2):
  """
- Moves the model to the CPU and updates the trainer's device setting.
+ Automates the training process of a PyTorch Sequence Model.
 
- This is useful for running operations that require the CPU.
+ Built-in Callbacks: `History`, `TqdmProgressBar`
+
+ Args:
+ model (nn.Module): The PyTorch model to train.
+ train_dataset (Dataset): The training dataset.
+ validation_dataset (Dataset): The validation dataset.
+ kind (str): Used to redirect to the correct process ('sequence-to-sequence' or 'sequence-to-value').
+ criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task.
+ optimizer (torch.optim.Optimizer): The optimizer.
+ device (str): The device to run training on ('cpu', 'cuda', 'mps').
+ dataloader_workers (int): Subprocesses for data loading.
+ checkpoint_callback (DragonModelCheckpoint | None): Callback to save the model.
+ early_stopping_callback (DragonEarlyStopping | None): Callback to stop training early.
+ lr_scheduler_callback (DragonLRScheduler | None): Callback to manage the LR scheduler.
+ extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
  """
- self.device = torch.device('cpu')
+ # Call the base class constructor with common parameters
+ super().__init__(
+ model=model,
+ optimizer=optimizer,
+ device=device,
+ dataloader_workers=dataloader_workers,
+ checkpoint_callback=checkpoint_callback,
+ early_stopping_callback=early_stopping_callback,
+ lr_scheduler_callback=lr_scheduler_callback,
+ extra_callbacks=extra_callbacks
+ )
+
+ if kind not in [MLTaskKeys.SEQUENCE_SEQUENCE, MLTaskKeys.SEQUENCE_VALUE]:
+ raise ValueError(f"'{kind}' is not a valid task type for DragonSequenceTrainer.")
+
+ self.train_dataset = train_dataset
+ self.validation_dataset = validation_dataset
+ self.kind = kind
+
+ # try to validate against a Dragon Sequence model
+ if hasattr(self.model, "prediction_mode"):
+ key_to_check: str = self.model.prediction_mode # type: ignore
+ if not key_to_check == self.kind:
+ _LOGGER.error(f"Trainer was set for '{self.kind}', but model architecture '{self.model}' is built for '{key_to_check}'.")
+ raise RuntimeError()
+
+ # loss function
+ if criterion == "auto":
+ # Both sequence tasks are treated as regression problems
+ self.criterion = nn.MSELoss()
+ else:
+ self.criterion = criterion
+
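A construction sketch for the sequence trainer, assuming a model that exposes the prediction_mode attribute checked above; dataset and optimizer setup is illustrative:

    trainer = DragonSequenceTrainer(
        model=seq_model,                # e.g., a model from ML_sequence_models
        train_dataset=train_ds,         # yields (features, target) pairs
        validation_dataset=val_ds,
        kind="sequence-to-value",
        optimizer=torch.optim.Adam(seq_model.parameters(), lr=1e-3),
        device="cpu",
        checkpoint_callback=None,
        early_stopping_callback=None,
        lr_scheduler_callback=None,
        criterion="auto",               # resolves to nn.MSELoss()
    )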
+ def _create_dataloaders(self, batch_size: int, shuffle: bool):
+ """Initializes the DataLoaders."""
+ # Ensure stability on MPS devices by setting num_workers to 0
+ loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+ self.train_loader = DataLoader(
+ dataset=self.train_dataset,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ num_workers=loader_workers,
+ pin_memory=("cuda" in self.device.type),
+ drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
+ )
+
+ self.validation_loader = DataLoader(
+ dataset=self.validation_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=loader_workers,
+ pin_memory=("cuda" in self.device.type)
+ )
+
+ def _train_step(self):
+ self.model.train()
+ running_loss = 0.0
+ total_samples = 0
+
+ for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
+ # Create a log dictionary for the batch
+ batch_logs = {
+ PyTorchLogKeys.BATCH_INDEX: batch_idx,
+ PyTorchLogKeys.BATCH_SIZE: features.size(0)
+ }
+ self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
+
+ features, target = features.to(self.device), target.to(self.device)
+ self.optimizer.zero_grad()
+
+ output = self.model(features)
+
+ # --- Label Type/Shape Correction ---
+ # Ensure target is float for MSELoss
+ target = target.float()
+
+ # For seq-to-val, models might output [N, 1] but target is [N].
+ if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+ if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+ output = output.squeeze(1)
+
+ # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
+ elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+ if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
+ output = output.squeeze(-1)
+
+ loss = self.criterion(output, target)
+
+ loss.backward()
+ self.optimizer.step()
+
+ # Calculate batch loss and update running loss for the epoch
+ batch_loss = loss.item()
+ batch_size = features.size(0)
+ running_loss += batch_loss * batch_size # Accumulate total loss
+ total_samples += batch_size # total samples
+
+ # Add the batch loss to the logs and call the end-of-batch hook
+ batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
+ self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+ if total_samples == 0:
+ _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
+ return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
+
+ return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples} # type: ignore
+
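The squeeze logic above can be checked in isolation; a self-contained sketch where the tensor shapes are assumptions chosen for illustration:

    import torch

    out_val = torch.zeros(4, 1)        # sequence-to-value output [N, 1]
    tgt_val = torch.zeros(4)           # target [N]
    assert out_val.squeeze(1).shape == tgt_val.shape

    out_seq = torch.zeros(4, 10, 1)    # sequence-to-sequence output [N, Seq, 1]
    tgt_seq = torch.zeros(4, 10)       # target [N, Seq]
    assert out_seq.squeeze(-1).shape == tgt_seq.shape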
+ def _validation_step(self):
+ self.model.eval()
+ running_loss = 0.0
+
+ with torch.no_grad():
+ for features, target in self.validation_loader: # type: ignore
+ features, target = features.to(self.device), target.to(self.device)
+
+ output = self.model(features)
+
+ # --- Label Type/Shape Correction ---
+ target = target.float()
+
+ # For seq-to-val, models might output [N, 1] but target is [N].
+ if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+ if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+ output = output.squeeze(1)
+
+ # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
+ elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+ if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
+ output = output.squeeze(-1)
+
+ loss = self.criterion(output, target)
+
+ running_loss += loss.item() * features.size(0)
+
+ if not self.validation_loader.dataset: # type: ignore
+ _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+ return {PyTorchLogKeys.VAL_LOSS: 0.0}
+
+ logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)} # type: ignore
+ return logs
+
+ def _predict_for_eval(self, dataloader: DataLoader):
+ """
+ Private method to yield model predictions batch by batch for evaluation.
+
+ Yields:
+ tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
+ y_prob_batch is always None for sequence tasks.
+ """
+ self.model.eval()
  self.model.to(self.device)
- _LOGGER.info("Trainer and model moved to CPU.")
+
+ with torch.no_grad():
+ for features, target in dataloader:
+ features = features.to(self.device)
+ output = self.model(features).cpu()
+
+ y_pred_batch = output.numpy()
+ y_prob_batch = None # Not applicable for sequence regression
+ y_true_batch = target.numpy()
+
+ yield y_pred_batch, y_prob_batch, y_true_batch
+
+ def evaluate(self,
+ save_dir: Union[str, Path],
+ model_checkpoint: Union[Path, Literal["latest", "current"]],
+ test_data: Optional[Union[DataLoader, Dataset]] = None,
+ val_format_configuration: Optional[Union[SequenceValueMetricsFormat,
+ SequenceSequenceMetricsFormat]]=None,
+ test_format_configuration: Optional[Union[SequenceValueMetricsFormat,
+ SequenceSequenceMetricsFormat]]=None):
+ """
+ Evaluates the model, routing to the correct evaluation function.
+
+ Args:
+ model_checkpoint (Path | 'latest' | 'current'):
+ - Path to a valid checkpoint for the model.
+ - If 'latest', the latest checkpoint will be loaded.
+ - If 'current', use the current state of the trained model.
+ save_dir (str | Path): Directory to save all reports and plots.
+ test_data (DataLoader | Dataset | None): Optional test data.
+ val_format_configuration: Optional configuration for validation metrics.
+ test_format_configuration: Optional configuration for test metrics.
+ """
+ # Validate model checkpoint
+ if isinstance(model_checkpoint, Path):
+ checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+ elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+ checkpoint_validated = model_checkpoint
+ else:
+ _LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.LATEST}', or '{MagicWords.CURRENT}'.")
+ raise ValueError()
+
+ # Validate val configuration
+ if val_format_configuration is not None:
+ if not isinstance(val_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
+ _LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
+ raise ValueError()
+
+ # Validate directory
+ save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+ # Validate test data and dispatch
+ if test_data is not None:
+ if not isinstance(test_data, (DataLoader, Dataset)):
+ _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+ raise ValueError()
+ test_data_validated = test_data
 
- def to_device(self, device: str):
+ validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+ test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+ # Dispatch validation set
+ _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+ self._evaluate(save_dir=validation_metrics_path,
+ model_checkpoint=checkpoint_validated,
+ data=None,
+ format_configuration=val_format_configuration)
+
+ # Validate test configuration
+ test_configuration_validated = None
+ if test_format_configuration is not None:
+ if not isinstance(test_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
+ warning_message_type = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
+ if val_format_configuration is not None:
+ warning_message_type += " 'val_format_configuration' will be used."
+ test_configuration_validated = val_format_configuration
+ else:
+ warning_message_type += " Using default format."
+ _LOGGER.warning(warning_message_type)
+ else:
+ test_configuration_validated = test_format_configuration
+
+ # Dispatch test set
+ _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+ self._evaluate(save_dir=test_metrics_path,
+ model_checkpoint="current",
+ data=test_data_validated,
+ format_configuration=test_configuration_validated)
+ else:
+ # Dispatch validation set
+ _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+ self._evaluate(save_dir=save_path,
+ model_checkpoint=checkpoint_validated,
+ data=None,
+ format_configuration=val_format_configuration)
+
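A dispatch sketch for the routing above; the format objects come from ML_configuration and their constructor arguments are not shown in this diff, so they are left elided:

    val_fmt = SequenceValueMetricsFormat(...)  # arguments elided; see ML_configuration
    trainer.evaluate(
        save_dir="reports/sequence",           # illustrative output directory
        model_checkpoint="current",
        test_data=None,                        # evaluate on the validation set only
        val_format_configuration=val_fmt,
    )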
+ def _evaluate(self,
+ save_dir: Union[str, Path],
+ model_checkpoint: Union[Path, Literal["latest", "current"]],
+ data: Optional[Union[DataLoader, Dataset]],
+ format_configuration: Optional[Union[SequenceValueMetricsFormat,
+ SequenceSequenceMetricsFormat]]):
  """
- Moves the model to the specified device and updates the trainer's device setting.
+ Private evaluation helper.
+ """
+ eval_loader = None
+
+ # load model checkpoint
+ if isinstance(model_checkpoint, Path):
+ self._load_checkpoint(path=model_checkpoint)
+ elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+ path_to_latest = self._checkpoint_callback.best_checkpoint_path
+ self._load_checkpoint(path_to_latest)
+ elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+ _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+ raise ValueError()
+
+ # Dataloader
+ if isinstance(data, DataLoader):
+ eval_loader = data
+ elif isinstance(data, Dataset):
+ # Create a new loader from the provided dataset
+ eval_loader = DataLoader(data,
+ batch_size=self._batch_size,
+ shuffle=False,
+ num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+ pin_memory=(self.device.type == "cuda"))
+ else: # data is None, use the trainer's default validation dataset
+ if self.validation_dataset is None:
+ _LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
+ raise ValueError()
+ eval_loader = DataLoader(self.validation_dataset,
+ batch_size=self._batch_size,
+ shuffle=False,
+ num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+ pin_memory=(self.device.type == "cuda"))
+
+ if eval_loader is None:
+ _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
+ raise ValueError()
+
+ all_preds, _, all_true = [], [], []
+ for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
+ if y_pred_b is not None: all_preds.append(y_pred_b)
+ if y_true_b is not None: all_true.append(y_true_b)
+
+ if not all_true:
+ _LOGGER.error("Evaluation failed: No data was processed.")
+ return
+
+ y_pred = np.concatenate(all_preds)
+ y_true = np.concatenate(all_true)
+
+ # --- Routing Logic ---
+ if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+ config = None
+ if format_configuration and isinstance(format_configuration, SequenceValueMetricsFormat):
+ config = format_configuration
+ elif format_configuration:
+ _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceValueMetricsFormat.")
+
+ sequence_to_value_metrics(y_true=y_true,
+ y_pred=y_pred,
+ save_dir=save_dir,
+ config=config)
+
+ elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+ config = None
+ if format_configuration and isinstance(format_configuration, SequenceSequenceMetricsFormat):
+ config = format_configuration
+ elif format_configuration:
+ _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceSequenceMetricsFormat.")
+
+ sequence_to_sequence_metrics(y_true=y_true,
+ y_pred=y_pred,
+ save_dir=save_dir,
+ config=config)
+
+ def finalize_model_training(self,
+ save_dir: Union[str, Path],
+ filename: str,
+ last_training_sequence: np.ndarray,
+ model_checkpoint: Union[Path, Literal['latest', 'current']]):
+ """
+ Saves a finalized, "inference-ready" model state to a .pth file.
+
+ This method saves the model's `state_dict` and the final epoch number, along with the sequence length and the initial sequence used for forecasting.
 
  Args:
- device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+ save_dir (Union[str, Path]): The directory to save the finalized model.
+ filename (str): The desired filename for the model (e.g., "final_model.pth").
+ last_training_sequence (np.ndarray): The last un-scaled sequence from the training data, used for forecasting.
+ model_checkpoint (Union[Path, Literal["latest", "current"]]):
+ - Path: Loads the model state from a specific checkpoint file.
+ - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+ - "current": Uses the model's state as it is at the end of the `fit()` call.
  """
- self.device = self._validate_device(device)
- self.model.to(self.device)
- _LOGGER.info(f"Trainer and model moved to {self.device}.")
+ # handle save path
+ sanitized_filename = sanitize_filename(filename)
+ if not sanitized_filename.endswith(".pth"):
+ sanitized_filename = sanitized_filename + ".pth"
+
+ dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+ full_path = dir_path / sanitized_filename
+
+ # handle checkpoint
+ if isinstance(model_checkpoint, Path):
+ self._load_checkpoint(path=model_checkpoint)
+ elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+ path_to_latest = self._checkpoint_callback.best_checkpoint_path
+ self._load_checkpoint(path_to_latest)
+ elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+ _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+ raise ValueError()
+ elif model_checkpoint == MagicWords.CURRENT:
+ pass
+ else:
+ _LOGGER.error(f"Unknown 'model_checkpoint' parameter received '{model_checkpoint}'.")
+ raise ValueError()
+
+ # --- 1. Validate the provided initial sequence ---
+ if not isinstance(last_training_sequence, np.ndarray):
+ _LOGGER.error(f"'last_training_sequence' must be a numpy array. Got {type(last_training_sequence)}")
+ raise TypeError()
+ if last_training_sequence.ndim != 1:
+ _LOGGER.error(f"'last_training_sequence' must be a 1D array. Got {last_training_sequence.ndim} dimensions.")
+ raise ValueError()
+
+ # --- 2. Derive sequence_length from the array ---
+ sequence_length = len(last_training_sequence)
+ if sequence_length <= 0:
+ _LOGGER.error("Length of 'last_training_sequence' cannot be zero.")
+ raise ValueError()
+
+ # Create finalized data
+ finalized_data = {
+ PyTorchCheckpointKeys.EPOCH: self.epoch,
+ PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+ PyTorchCheckpointKeys.SEQUENCE_LENGTH: sequence_length,
+ PyTorchCheckpointKeys.INITIAL_SEQUENCE: last_training_sequence
+ }
+
+ torch.save(finalized_data, full_path)
+
+ _LOGGER.info(f"Finalized model weights saved to {full_path}.")
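The finalized file can be read back with plain torch.load; a sketch using the same keys saved above (the filename is illustrative, and weights_only=False may be needed on newer PyTorch because the payload contains a numpy array):

    payload = torch.load("final_model.pth", map_location="cpu", weights_only=False)
    seq_model.load_state_dict(payload[PyTorchCheckpointKeys.MODEL_STATE])
    seed = payload[PyTorchCheckpointKeys.INITIAL_SEQUENCE]    # 1D un-scaled numpy array
    window = payload[PyTorchCheckpointKeys.SEQUENCE_LENGTH]   # equals len(seed)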
 
 
  def info():