dragon-ml-toolbox 14.7.0__py3-none-any.whl → 16.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/METADATA +9 -5
  2. dragon_ml_toolbox-16.2.1.dist-info/RECORD +51 -0
  3. ml_tools/ETL_cleaning.py +20 -20
  4. ml_tools/ETL_engineering.py +23 -25
  5. ml_tools/GUI_tools.py +20 -20
  6. ml_tools/MICE_imputation.py +3 -3
  7. ml_tools/ML_callbacks.py +43 -26
  8. ml_tools/ML_configuration.py +726 -32
  9. ml_tools/ML_datasetmaster.py +235 -280
  10. ml_tools/ML_evaluation.py +160 -42
  11. ml_tools/ML_evaluation_multi.py +103 -35
  12. ml_tools/ML_inference.py +290 -208
  13. ml_tools/ML_models.py +13 -102
  14. ml_tools/ML_models_advanced.py +1 -1
  15. ml_tools/ML_optimization.py +12 -12
  16. ml_tools/ML_scaler.py +11 -11
  17. ml_tools/ML_sequence_datasetmaster.py +341 -0
  18. ml_tools/ML_sequence_evaluation.py +219 -0
  19. ml_tools/ML_sequence_inference.py +391 -0
  20. ml_tools/ML_sequence_models.py +139 -0
  21. ml_tools/ML_trainer.py +1342 -386
  22. ml_tools/ML_utilities.py +1 -1
  23. ml_tools/ML_vision_datasetmaster.py +120 -72
  24. ml_tools/ML_vision_evaluation.py +30 -6
  25. ml_tools/ML_vision_inference.py +129 -152
  26. ml_tools/ML_vision_models.py +1 -1
  27. ml_tools/ML_vision_transformers.py +121 -40
  28. ml_tools/PSO_optimization.py +6 -6
  29. ml_tools/SQL.py +4 -4
  30. ml_tools/{keys.py → _keys.py} +45 -0
  31. ml_tools/_schema.py +1 -1
  32. ml_tools/ensemble_evaluation.py +1 -1
  33. ml_tools/ensemble_inference.py +7 -33
  34. ml_tools/ensemble_learning.py +1 -1
  35. ml_tools/optimization_tools.py +2 -2
  36. ml_tools/path_manager.py +5 -5
  37. ml_tools/utilities.py +1 -2
  38. dragon_ml_toolbox-14.7.0.dist-info/RECORD +0 -49
  39. ml_tools/RNN_forecast.py +0 -56
  40. ml_tools/_ML_vision_recipe.py +0 -88
  41. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/WHEEL +0 -0
  42. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE +0 -0
  43. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  44. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/top_level.txt +0 -0
ml_tools/ML_trainer.py CHANGED
@@ -1,80 +1,96 @@
- from typing import List, Literal, Union, Optional, Callable, Dict, Any, Tuple
+ from typing import List, Literal, Union, Optional, Callable, Dict, Any
  from pathlib import Path
  from torch.utils.data import DataLoader, Dataset
  import torch
  from torch import nn
  import numpy as np
+ from abc import ABC, abstractmethod
 
- from .ML_callbacks import Callback, History, TqdmProgressBar, ModelCheckpoint
+ from .path_manager import make_fullpath
+ from .ML_callbacks import _Callback, History, TqdmProgressBar, DragonModelCheckpoint, DragonEarlyStopping, DragonLRScheduler
  from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
  from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
+ from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
+ from .ML_sequence_evaluation import sequence_to_sequence_metrics, sequence_to_value_metrics
+ from .ML_configuration import (RegressionMetricsFormat,
+                                MultiTargetRegressionMetricsFormat,
+                                BinaryClassificationMetricsFormat,
+                                MultiClassClassificationMetricsFormat,
+                                BinaryImageClassificationMetricsFormat,
+                                MultiClassImageClassificationMetricsFormat,
+                                MultiLabelBinaryClassificationMetricsFormat,
+                                BinarySegmentationMetricsFormat,
+                                MultiClassSegmentationMetricsFormat,
+                                SequenceValueMetricsFormat,
+                                SequenceSequenceMetricsFormat,
+
+                                FinalizeBinaryClassification,
+                                FinalizeBinarySegmentation,
+                                FinalizeBinaryImageClassification,
+                                FinalizeMultiClassClassification,
+                                FinalizeMultiClassImageClassification,
+                                FinalizeMultiClassSegmentation,
+                                FinalizeMultiLabelBinaryClassification,
+                                FinalizeMultiTargetRegression,
+                                FinalizeRegression,
+                                FinalizeObjectDetection,
+                                FinalizeSequencePrediction)
+
  from ._script_info import _script_info
- from .keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys
+ from ._keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys, MLTaskKeys, MagicWords, DragonTrainerKeys
  from ._logger import _LOGGER
- from .path_manager import make_fullpath
- from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
- from .ML_configuration import ClassificationMetricsFormat, MultiClassificationMetricsFormat
 
 
  __all__ = [
-     "MLTrainer",
-     "ObjectDetectionTrainer",
+     "DragonTrainer",
+     "DragonDetectionTrainer",
+     "DragonSequenceTrainer"
  ]
 
-
- class MLTrainer:
-     def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
-                  kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification", "segmentation"],
-                  criterion: nn.Module, optimizer: torch.optim.Optimizer,
-                  device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
-         """
-         Automates the training process of a PyTorch Model.
-
-         Built-in Callbacks: `History`, `TqdmProgressBar`
-
-         Args:
-             model (nn.Module): The PyTorch model to train.
-             train_dataset (Dataset): The training dataset.
-             test_dataset (Dataset): The testing/validation dataset.
-             kind (str): Can be 'regression', 'classification', 'multi_target_regression', 'multi_label_classification', or 'segmentation'.
-             criterion (nn.Module): The loss function.
-             optimizer (torch.optim.Optimizer): The optimizer.
-             device (str): The device to run training on ('cpu', 'cuda', 'mps').
-             dataloader_workers (int): Subprocesses for data loading.
-             callbacks (List[Callback] | None): A list of callbacks to use during training.
-
-         Note:
-             - For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`.
-
-             - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice.
+ class _BaseDragonTrainer(ABC):
+     """
+     Abstract base class for Dragon Trainers.
 
-             - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem.
-
-             - For **segmentation** tasks, `nn.CrossEntropyLoss` (for multi-class) or `nn.BCEWithLogitsLoss` (for binary) are common.
-         """
-         if kind not in ["regression", "classification", "multi_target_regression", "multi_label_classification", "segmentation"]:
-             raise ValueError(f"'{kind}' is not a valid task type.")
+     Handles the common training loop orchestration, checkpointing, callback
+     management, and device handling. Subclasses must implement the
+     task-specific logic (dataloaders, train/val steps, evaluation).
+     """
+     def __init__(self,
+                  model: nn.Module,
+                  optimizer: torch.optim.Optimizer,
+                  device: Union[Literal['cuda', 'mps', 'cpu'], str],
+                  dataloader_workers: int = 2,
+                  checkpoint_callback: Optional[DragonModelCheckpoint] = None,
+                  early_stopping_callback: Optional[DragonEarlyStopping] = None,
+                  lr_scheduler_callback: Optional[DragonLRScheduler] = None,
+                  extra_callbacks: Optional[List[_Callback]] = None):
 
          self.model = model
-         self.train_dataset = train_dataset
-         self.test_dataset = test_dataset
-         self.kind = kind
-         self.criterion = criterion
          self.optimizer = optimizer
          self.scheduler = None
         self.device = self._validate_device(device)
          self.dataloader_workers = dataloader_workers
 
-         # Callback handler - History and TqdmProgressBar are added by default
+         # Callback handler
          default_callbacks = [History(), TqdmProgressBar()]
-         user_callbacks = callbacks if callbacks is not None else []
+
+         self._checkpoint_callback = None
+         if checkpoint_callback:
+             default_callbacks.append(checkpoint_callback)
+             self._checkpoint_callback = checkpoint_callback
+         if early_stopping_callback:
+             default_callbacks.append(early_stopping_callback)
+         if lr_scheduler_callback:
+             default_callbacks.append(lr_scheduler_callback)
+
+         user_callbacks = extra_callbacks if extra_callbacks is not None else []
          self.callbacks = default_callbacks + user_callbacks
          self._set_trainer_on_callbacks()
 
          # Internal state
-         self.train_loader = None
-         self.test_loader = None
-         self.history = {}
+         self.train_loader: Optional[DataLoader] = None
+         self.validation_loader: Optional[DataLoader] = None
+         self.history: Dict[str, List[Any]] = {}
          self.epoch = 0
          self.epochs = 0  # Total epochs for the fit run
          self.start_epoch = 1
@@ -97,32 +113,10 @@ class MLTrainer:
          for callback in self.callbacks:
              callback.set_trainer(self)
 
-     def _create_dataloaders(self, batch_size: int, shuffle: bool):
-         """Initializes the DataLoaders."""
-         # Ensure stability on MPS devices by setting num_workers to 0
-         loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
-
-         self.train_loader = DataLoader(
-             dataset=self.train_dataset,
-             batch_size=batch_size,
-             shuffle=shuffle,
-             num_workers=loader_workers,
-             pin_memory=("cuda" in self.device.type),
-             drop_last=True  # Drops the last batch if incomplete, selecting a good batch size is key.
-         )
-
-         self.test_loader = DataLoader(
-             dataset=self.test_dataset,
-             batch_size=batch_size,
-             shuffle=False,
-             num_workers=loader_workers,
-             pin_memory=("cuda" in self.device.type)
-         )
-
      def _load_checkpoint(self, path: Union[str, Path]):
          """Loads a training checkpoint to resume training."""
          p = make_fullpath(path, enforce="file")
-         _LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
+         _LOGGER.info(f"Loading checkpoint from '{p.name}'...")
 
          try:
              checkpoint = torch.load(p, map_location=self.device)
@@ -133,7 +127,16 @@ class MLTrainer:
 
              self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
              self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
-             self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1  # Resume on the *next* epoch
+             self.epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0)
+             self.start_epoch = self.epoch + 1  # Resume on the *next* epoch
+
+             # --- Load History ---
+             if PyTorchCheckpointKeys.HISTORY in checkpoint:
+                 self.history = checkpoint[PyTorchCheckpointKeys.HISTORY]
+                 _LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
+             else:
+                 _LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
+                 self.history = {}  # Ensure it's at least an empty dict
 
              # --- Scheduler State Loading Logic ---
              scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
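For reference, a minimal sketch of the checkpoint layout that `_load_checkpoint` now consumes, inferred only from the keys read in this diff. The concrete string values behind `PyTorchCheckpointKeys` live in `ml_tools/_keys.py` and are not shown here, and `DragonModelCheckpoint` normally writes this file for you; the model, optimizer, and values below are stand-ins:

    import torch
    from torch import nn
    from ml_tools._keys import PyTorchCheckpointKeys

    model = nn.Linear(4, 1)                                    # stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)   # stand-in optimizer

    checkpoint = {
        PyTorchCheckpointKeys.MODEL_STATE: model.state_dict(),
        PyTorchCheckpointKeys.OPTIMIZER_STATE: optimizer.state_dict(),
        PyTorchCheckpointKeys.EPOCH: 10,
        PyTorchCheckpointKeys.HISTORY: {},  # optional: restored if present, else a new history starts
        # PyTorchCheckpointKeys.SCHEDULER_STATE and .BEST_SCORE may also be present
    }
    torch.save(checkpoint, "checkpoint.pth")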
@@ -163,7 +166,7 @@ class MLTrainer:
 
          # Restore callback states
          for cb in self.callbacks:
-             if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
+             if isinstance(cb, DragonModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
                  cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
                  _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
 
@@ -174,7 +177,8 @@ class MLTrainer:
              raise
 
      def fit(self,
-             epochs: int = 10,
+             save_dir: Union[str, Path],
+             epochs: int = 100,
              batch_size: int = 10,
              shuffle: bool = True,
              resume_from_checkpoint: Optional[Union[str, Path]] = None):
@@ -184,21 +188,15 @@ class MLTrainer:
          Returns the "History" callback dictionary.
 
          Args:
+             save_dir (str | Path): Directory to save the loss plot.
              epochs (int): The total number of epochs to train for.
              batch_size (int): The number of samples per batch.
              shuffle (bool): Whether to shuffle the training data at each epoch.
              resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
-
-         Note:
-             For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
-             automatically aligns the model's output tensor with the target tensor's
-             shape using `output.view_as(target)`. This handles the common case
-             where a model outputs a shape of `[batch_size, 1]` and the target has a
-             shape of `[batch_size]`.
          """
          self.epochs = epochs
          self._batch_size = batch_size
-         self._create_dataloaders(self._batch_size, shuffle)
+         self._create_dataloaders(self._batch_size, shuffle)  # type: ignore
          self.model.to(self.device)
 
          if resume_from_checkpoint:
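A usage sketch of the new `fit()` signature, using only the parameters shown above (the directory name is a placeholder; `save_dir` is new in this version and receives the loss plot written at the end of the run):

    history = trainer.fit(
        save_dir="training_output",    # placeholder path for the loss plot
        epochs=100,
        batch_size=32,
        shuffle=True,
        resume_from_checkpoint=None,   # or a Path to a saved checkpoint
    )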
@@ -209,11 +207,19 @@ class MLTrainer:
 
          self._callbacks_hook('on_train_begin')
 
+         if not self.train_loader:
+             _LOGGER.error("Train loader is not initialized.")
+             raise ValueError()
+
+         if not self.validation_loader:
+             _LOGGER.error("Validation loader is not initialized.")
+             raise ValueError()
+
          for epoch in range(self.start_epoch, self.epochs + 1):
              self.epoch = epoch
-             epoch_logs = {}
+             epoch_logs: Dict[str, Any] = {}
              self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
-
+
              train_logs = self._train_step()
              epoch_logs.update(train_logs)
 
@@ -227,11 +233,204 @@ class MLTrainer:
                  break
 
          self._callbacks_hook('on_train_end')
+
+         # Training History
+         plot_losses(self.history, save_dir=save_dir)
+
          return self.history
+
+     def _callbacks_hook(self, method_name: str, *args, **kwargs):
+         """Calls the specified method on all callbacks."""
+         for callback in self.callbacks:
+             method = getattr(callback, method_name)
+             method(*args, **kwargs)
+
+     def to_cpu(self):
+         """
+         Moves the model to the CPU and updates the trainer's device setting.
+
+         This is useful for running operations that require the CPU.
+         """
+         self.device = torch.device('cpu')
+         self.model.to(self.device)
+         _LOGGER.info("Trainer and model moved to CPU.")
+
+     def to_device(self, device: str):
+         """
+         Moves the model to the specified device and updates the trainer's device setting.
+
+         Args:
+             device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+         """
+         self.device = self._validate_device(device)
+         self.model.to(self.device)
+         _LOGGER.info(f"Trainer and model moved to {self.device}.")
+
+     def _load_model_state_for_finalizing(self, model_checkpoint: Union[Path, Literal['latest', 'current']]):
+         """
+         Private helper to load the correct model state_dict based on user's choice.
+         This is called by finalize_model_training() in subclasses.
+         """
+         if isinstance(model_checkpoint, Path):
+             self._load_checkpoint(path=model_checkpoint)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+             path_to_latest = self._checkpoint_callback.best_checkpoint_path
+             self._load_checkpoint(path_to_latest)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+             _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+             raise ValueError()
+         elif model_checkpoint == MagicWords.CURRENT:
+             pass
+         else:
+             _LOGGER.error(f"Unknown 'model_checkpoint' received '{model_checkpoint}'.")
+             raise ValueError()
+
+     # --- Abstract Methods ---
+     # These must be implemented by subclasses
+
+     @abstractmethod
+     def _create_dataloaders(self, batch_size: int, shuffle: bool):
+         """Initializes the DataLoaders."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def _train_step(self) -> Dict[str, float]:
+         """Runs a single training epoch."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def _validation_step(self) -> Dict[str, float]:
+         """Runs a single validation epoch."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def evaluate(self, *args, **kwargs):
+         """Runs the full model evaluation."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def _evaluate(self, *args, **kwargs):
+         """Internal evaluation helper."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def finalize_model_training(self, *args, **kwargs):
+         """Saves the finalized model for inference."""
+         raise NotImplementedError
+
+
+ # --- DragonTrainer ----
+ class DragonTrainer(_BaseDragonTrainer):
+     def __init__(self,
+                  model: nn.Module,
+                  train_dataset: Dataset,
+                  validation_dataset: Dataset,
+                  kind: Literal["regression", "binary classification", "multiclass classification",
+                                "multitarget regression", "multilabel binary classification",
+                                "binary segmentation", "multiclass segmentation", "binary image classification", "multiclass image classification"],
+                  optimizer: torch.optim.Optimizer,
+                  device: Union[Literal['cuda', 'mps', 'cpu'], str],
+                  checkpoint_callback: Optional[DragonModelCheckpoint],
+                  early_stopping_callback: Optional[DragonEarlyStopping],
+                  lr_scheduler_callback: Optional[DragonLRScheduler],
+                  extra_callbacks: Optional[List[_Callback]] = None,
+                  criterion: Union[nn.Module, Literal["auto"]] = "auto",
+                  dataloader_workers: int = 2):
+         """
+         Automates the training process of a PyTorch Model.
+
+         Built-in Callbacks: `History`, `TqdmProgressBar`
+
+         Args:
+             model (nn.Module): The PyTorch model to train.
+             train_dataset (Dataset): The training dataset.
+             validation_dataset (Dataset): The validation dataset.
+             kind (str): Used to redirect to the correct process.
+             criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task.
+             optimizer (torch.optim.Optimizer): The optimizer.
+             device (str): The device to run training on ('cpu', 'cuda', 'mps').
+             dataloader_workers (int): Subprocesses for data loading.
+             extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
+
+         Note:
+             - For **regression** and **multi_target_regression** tasks, suggested criteria include `nn.MSELoss` or `nn.L1Loss`. The model should output as many logits as existing targets.
+
+             - For **single-label, binary classification**, `nn.BCEWithLogitsLoss` is the standard choice. The model should output a single logit.
 
+             - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice. The model should output as many logits as existing classes.
+
+             - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem. The model should output 1 logit per binary target.
+
+             - For **binary segmentation** tasks, `nn.BCEWithLogitsLoss` is common. The model should output a single logit.
+
+             - For **multiclass segmentation** tasks, `nn.CrossEntropyLoss` is the standard. The model should output as many logits as existing classes.
+         """
+         # Call the base class constructor with common parameters
+         super().__init__(
+             model=model,
+             optimizer=optimizer,
+             device=device,
+             dataloader_workers=dataloader_workers,
+             checkpoint_callback=checkpoint_callback,
+             early_stopping_callback=early_stopping_callback,
+             lr_scheduler_callback=lr_scheduler_callback,
+             extra_callbacks=extra_callbacks
+         )
+
+         if kind not in [MLTaskKeys.REGRESSION,
+                         MLTaskKeys.BINARY_CLASSIFICATION,
+                         MLTaskKeys.MULTICLASS_CLASSIFICATION,
+                         MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION,
+                         MLTaskKeys.MULTITARGET_REGRESSION,
+                         MLTaskKeys.BINARY_SEGMENTATION,
+                         MLTaskKeys.MULTICLASS_SEGMENTATION,
+                         MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
+                         MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+             raise ValueError(f"'{kind}' is not a valid task type.")
+
+         self.train_dataset = train_dataset
+         self.validation_dataset = validation_dataset
+         self.kind = kind
+         self._classification_threshold: float = 0.5
+
+         # loss function
+         if criterion == "auto":
+             if kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
+                 self.criterion = nn.MSELoss()
+             elif kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
+                 self.criterion = nn.BCEWithLogitsLoss()
+             elif kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
+                 self.criterion = nn.CrossEntropyLoss()
+         else:
+             self.criterion = criterion
+
+     def _create_dataloaders(self, batch_size: int, shuffle: bool):
+         """Initializes the DataLoaders."""
+         # Ensure stability on MPS devices by setting num_workers to 0
+         loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+         self.train_loader = DataLoader(
+             dataset=self.train_dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             num_workers=loader_workers,
+             pin_memory=("cuda" in self.device.type),
+             drop_last=True  # Drops the last batch if incomplete, selecting a good batch size is key.
+         )
+
+         self.validation_loader = DataLoader(
+             dataset=self.validation_dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=loader_workers,
+             pin_memory=("cuda" in self.device.type)
+         )
+
      def _train_step(self):
          self.model.train()
          running_loss = 0.0
+         total_samples = 0
+
          for batch_idx, (features, target) in enumerate(self.train_loader):  # type: ignore
              # Create a log dictionary for the batch
              batch_logs = {
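A hedged construction sketch for the new `DragonTrainer`, using only parameters visible in this diff. The model and datasets are placeholders, the three callback slots are now explicit parameters (pass None to omit them), and this assumes the `MLTaskKeys` constants match the Literal strings in the signature, so `criterion="auto"` resolves to `nn.BCEWithLogitsLoss` for a `"binary classification"` task:

    import torch
    from ml_tools.ML_trainer import DragonTrainer

    trainer = DragonTrainer(
        model=model,                      # your nn.Module (placeholder)
        train_dataset=train_ds,           # placeholder Dataset
        validation_dataset=val_ds,        # placeholder Dataset
        kind="binary classification",
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
        device="cuda",
        checkpoint_callback=None,         # or a DragonModelCheckpoint
        early_stopping_callback=None,     # or a DragonEarlyStopping
        lr_scheduler_callback=None,       # or a DragonLRScheduler
        criterion="auto",                 # inferred per task; nn.BCEWithLogitsLoss here
    )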
@@ -245,9 +444,21 @@ class MLTrainer:
 
              output = self.model(features)
 
-             # Apply shape correction only for single-target regression
-             if self.kind == "regression":
-                 output = output.view_as(target)
+             # --- Label Type/Shape Correction ---
+             # Cast target to float for BCE-based losses
+             if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
+                 target = target.float()
+
+             # Reshape output to match target for single-logit tasks
+             if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                 # If model outputs [N, 1] and target is [N], squeeze output
+                 if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                     output = output.squeeze(1)
+
+             if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                 # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
+                 if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
+                     output = output.squeeze(1)
 
              loss = self.criterion(output, target)
 
@@ -256,34 +467,58 @@ class MLTrainer:
 
              # Calculate batch loss and update running loss for the epoch
              batch_loss = loss.item()
-             running_loss += batch_loss * features.size(0)
+             batch_size = features.size(0)
+             running_loss += batch_loss * batch_size  # Accumulate total loss
+             total_samples += batch_size  # total samples
 
              # Add the batch loss to the logs and call the end-of-batch hook
              batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
              self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+         if total_samples == 0:
+             _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
+             return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
 
-         return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)}  # type: ignore
+         return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}  # type: ignore
 
      def _validation_step(self):
          self.model.eval()
          running_loss = 0.0
+
          with torch.no_grad():
-             for features, target in self.test_loader:  # type: ignore
+             for features, target in self.validation_loader:  # type: ignore
                  features, target = features.to(self.device), target.to(self.device)
 
                  output = self.model(features)
-                 # Apply shape correction only for single-target regression
-                 if self.kind == "regression":
-                     output = output.view_as(target)
+
+                 # --- Label Type/Shape Correction ---
+                 # Cast target to float for BCE-based losses
+                 if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
+                     target = target.float()
+
+                 # Reshape output to match target for single-logit tasks
+                 if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                     # If model outputs [N, 1] and target is [N], squeeze output
+                     if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                         output = output.squeeze(1)
+
+                 if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                     # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
+                     if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
+                         output = output.squeeze(1)
 
                  loss = self.criterion(output, target)
 
                  running_loss += loss.item() * features.size(0)
+
+         if not self.validation_loader.dataset:  # type: ignore
+             _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+             return {PyTorchLogKeys.VAL_LOSS: 0.0}
 
-         logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)}  # type: ignore
+         logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)}  # type: ignore
          return logs
 
-     def _predict_for_eval(self, dataloader: DataLoader, classification_threshold: float = 0.5):
+     def _predict_for_eval(self, dataloader: DataLoader):
          """
          Private method to yield model predictions batch by batch for evaluation.
 
@@ -294,6 +529,7 @@ class MLTrainer:
          """
          self.model.eval()
          self.model.to(self.device)
+
          with torch.no_grad():
              for features, target in dataloader:
                  features = features.to(self.device)
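For orientation, `_predict_for_eval` is a generator; a minimal sketch of how its batches are consumed inside the trainer, mirroring the `np.concatenate` aggregation visible in the routing hunk further down:

    all_preds, all_probs, all_true = [], [], []
    for y_pred_batch, y_prob_batch, y_true_batch in self._predict_for_eval(eval_loader):
        all_preds.append(y_pred_batch)
        if y_prob_batch is not None:
            all_probs.append(y_prob_batch)
        all_true.append(y_true_batch)
    y_pred = np.concatenate(all_preds)
    y_prob = np.concatenate(all_probs) if all_probs else None
    y_true = np.concatenate(all_true)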
@@ -303,25 +539,64 @@ class MLTrainer:
                  y_prob_batch = None
                  y_true_batch = None
 
-                 if self.kind in ["regression", "multi_target_regression"]:
+                 if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
                      y_pred_batch = output.numpy()
                      y_true_batch = target.numpy()
+
+                 elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                     # Assumes model output is [N, 1] (a single logit)
+                     # Squeeze output from [N, 1] to [N] if necessary
+                     if output.ndim == 2 and output.shape[1] == 1:
+                         output = output.squeeze(1)
+
+                     probs_pos = torch.sigmoid(output)  # Probability of positive class
+                     preds = (probs_pos >= self._classification_threshold).int()
+                     y_pred_batch = preds.numpy()
+                     # For metrics (like ROC AUC), we often need probs for *both* classes
+                     # Create an [N, 2] array: [prob_class_0, prob_class_1]
+                     probs_neg = 1.0 - probs_pos
+                     y_prob_batch = torch.stack([probs_neg, probs_pos], dim=1).numpy()
+                     y_true_batch = target.numpy()
 
-                 elif self.kind == "classification":
+                 elif self.kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+                     num_classes = output.shape[1]
+                     if num_classes < 3:
+                         # Optional: warn the user they are using the wrong kind
+                         wrong_class = MLTaskKeys.MULTICLASS_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
+                         recommended_class = MLTaskKeys.BINARY_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.BINARY_IMAGE_CLASSIFICATION
+                         _LOGGER.warning(f"'{wrong_class}' kind used with {num_classes} classes. Consider using '{recommended_class}' instead.")
+
                      probs = torch.softmax(output, dim=1)
                      preds = torch.argmax(probs, dim=1)
                      y_pred_batch = preds.numpy()
                      y_prob_batch = probs.numpy()
                      y_true_batch = target.numpy()
 
-                 elif self.kind == "multi_label_classification":
+                 elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
                      probs = torch.sigmoid(output)
-                     preds = (probs >= classification_threshold).int()
+                     preds = (probs >= self._classification_threshold).int()
                      y_pred_batch = preds.numpy()
                      y_prob_batch = probs.numpy()
                      y_true_batch = target.numpy()
+
+                 elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                     # Assumes model output is [N, 1, H, W] (logits for positive class)
+                     probs_pos = torch.sigmoid(output)  # Shape [N, 1, H, W]
+                     preds = (probs_pos >= self._classification_threshold).int()  # Shape [N, 1, H, W]
+
+                     # Squeeze preds to [N, H, W] (class indices 0 or 1)
+                     y_pred_batch = preds.squeeze(1).numpy()
+
+                     # Create [N, 2, H, W] probs for consistency
+                     probs_neg = 1.0 - probs_pos
+                     y_prob_batch = torch.cat([probs_neg, probs_pos], dim=1).numpy()
+
+                     # Handle target shape [N, 1, H, W] -> [N, H, W]
+                     if target.ndim == 4 and target.shape[1] == 1:
+                         target = target.squeeze(1)
+                     y_true_batch = target.numpy()
 
-                 elif self.kind == "segmentation":
+                 elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION:
                      # output shape [N, C, H, W]
                      probs = torch.softmax(output, dim=1)
                      preds = torch.argmax(probs, dim=1)  # shape [N, H, W]
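A self-contained sketch of the single-logit binary conversion this hunk performs (the 0.5 threshold here stands in for `self._classification_threshold`):

    import torch

    logits = torch.tensor([[0.3], [-1.2], [2.0]])             # model output [N, 1]
    probs_pos = torch.sigmoid(logits.squeeze(1))              # [N] positive-class probabilities
    preds = (probs_pos >= 0.5).int()                          # hard labels at the threshold
    probs = torch.stack([1.0 - probs_pos, probs_pos], dim=1)  # [N, 2] array for ROC/PR metrics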
@@ -334,26 +609,192 @@ class MLTrainer:
                      y_true_batch = target.numpy()
 
                  yield y_pred_batch, y_prob_batch, y_true_batch
-
+
      def evaluate(self,
                   save_dir: Union[str, Path],
-                  data: Optional[Union[DataLoader, Dataset]] = None,
-                  format_configuration: Optional[Union[ClassificationMetricsFormat, MultiClassificationMetricsFormat]]=None):
+                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  classification_threshold: Optional[float] = None,
+                  test_data: Optional[Union[DataLoader, Dataset]] = None,
+                  val_format_configuration: Optional[Union[
+                      RegressionMetricsFormat,
+                      MultiTargetRegressionMetricsFormat,
+                      BinaryClassificationMetricsFormat,
+                      MultiClassClassificationMetricsFormat,
+                      BinaryImageClassificationMetricsFormat,
+                      MultiClassImageClassificationMetricsFormat,
+                      MultiLabelBinaryClassificationMetricsFormat,
+                      BinarySegmentationMetricsFormat,
+                      MultiClassSegmentationMetricsFormat
+                  ]]=None,
+                  test_format_configuration: Optional[Union[
+                      RegressionMetricsFormat,
+                      MultiTargetRegressionMetricsFormat,
+                      BinaryClassificationMetricsFormat,
+                      MultiClassClassificationMetricsFormat,
+                      BinaryImageClassificationMetricsFormat,
+                      MultiClassImageClassificationMetricsFormat,
+                      MultiLabelBinaryClassificationMetricsFormat,
+                      BinarySegmentationMetricsFormat,
+                      MultiClassSegmentationMetricsFormat,
+                  ]]=None):
          """
          Evaluates the model, routing to the correct evaluation function based on task `kind`.
 
          Args:
+             model_checkpoint (Path | "latest" | "current"):
+                 - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+                 - If "latest", the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+                 - If "current", use the current state of the trained model up to the latest trained epoch.
              save_dir (str | Path): Directory to save all reports and plots.
-             data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
+             classification_threshold (float | None): Used for tasks using a binary approach (binary classification, binary segmentation, multilabel binary classification).
+             test_data (DataLoader | Dataset | None): Optional test data to evaluate the model performance. Validation and test metrics will be saved to subdirectories.
+             val_format_configuration (object): Optional configuration for metric format output for the validation set.
+             test_format_configuration (object): Optional configuration for metric format output for the test set.
+         """
+         # Validate model checkpoint
+         if isinstance(model_checkpoint, Path):
+             checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+         elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+             checkpoint_validated = model_checkpoint
+         else:
+             _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
+             raise ValueError()
+
+         # Validate classification threshold
+         if self.kind not in MLTaskKeys.ALL_BINARY_TASKS:
+             # dummy value for tasks that do not need it
+             threshold_validated = 0.5
+         elif classification_threshold is None:
+             # it should have been provided for binary tasks
+             _LOGGER.error(f"The classification threshold must be provided for '{self.kind}'.")
+             raise ValueError()
+         elif classification_threshold <= 0.0 or classification_threshold >= 1.0:
+             # Invalid float
+             _LOGGER.error(f"A classification threshold of {classification_threshold} is invalid. Must be in the range (0.0 - 1.0).")
+             raise ValueError()
+         else:
+             threshold_validated = classification_threshold
+
+         # Validate val configuration
+         if val_format_configuration is not None:
+             if not isinstance(val_format_configuration, (RegressionMetricsFormat,
+                                                          MultiTargetRegressionMetricsFormat,
+                                                          BinaryClassificationMetricsFormat,
+                                                          MultiClassClassificationMetricsFormat,
+                                                          BinaryImageClassificationMetricsFormat,
+                                                          MultiClassImageClassificationMetricsFormat,
+                                                          MultiLabelBinaryClassificationMetricsFormat,
+                                                          BinarySegmentationMetricsFormat,
+                                                          MultiClassSegmentationMetricsFormat)):
+                 _LOGGER.error(f"Invalid 'format_configuration': '{type(val_format_configuration)}'.")
+                 raise ValueError()
+             else:
+                 val_configuration_validated = val_format_configuration
+         else:  # config is None
+             val_configuration_validated = None
+
+         # Validate directory
+         save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+         # Validate test data and dispatch
+         if test_data is not None:
+             if not isinstance(test_data, (DataLoader, Dataset)):
+                 _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                 raise ValueError()
+             test_data_validated = test_data
+
+             validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+             test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+             # Dispatch validation set
+             _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+             self._evaluate(save_dir=validation_metrics_path,
+                            model_checkpoint=checkpoint_validated,
+                            classification_threshold=threshold_validated,
+                            data=None,
+                            format_configuration=val_configuration_validated)
+
+             # Validate test configuration
+             if test_format_configuration is not None:
+                 if not isinstance(test_format_configuration, (RegressionMetricsFormat,
+                                                               MultiTargetRegressionMetricsFormat,
+                                                               BinaryClassificationMetricsFormat,
+                                                               MultiClassClassificationMetricsFormat,
+                                                               BinaryImageClassificationMetricsFormat,
+                                                               MultiClassImageClassificationMetricsFormat,
+                                                               MultiLabelBinaryClassificationMetricsFormat,
+                                                               BinarySegmentationMetricsFormat,
+                                                               MultiClassSegmentationMetricsFormat)):
+                     warning_message_type = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
+                     if val_configuration_validated is not None:
+                         warning_message_type += " 'val_format_configuration' will be used for the test set metrics output."
+                         test_configuration_validated = val_configuration_validated
+                     else:
+                         warning_message_type += " Using default format."
+                         test_configuration_validated = None
+                     _LOGGER.warning(warning_message_type)
+                 else:
+                     test_configuration_validated = test_format_configuration
+             else:  # config is None
+                 test_configuration_validated = None
+
+             # Dispatch test set
+             _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+             self._evaluate(save_dir=test_metrics_path,
+                            model_checkpoint="current",
+                            classification_threshold=threshold_validated,
+                            data=test_data_validated,
+                            format_configuration=test_configuration_validated)
+         else:
+             # Dispatch validation set
+             _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+             self._evaluate(save_dir=save_path,
+                            model_checkpoint=checkpoint_validated,
+                            classification_threshold=threshold_validated,
+                            data=None,
+                            format_configuration=val_configuration_validated)
+
+     def _evaluate(self,
+                   save_dir: Union[str, Path],
+                   model_checkpoint: Union[Path, Literal["latest", "current"]],
+                   classification_threshold: float,
+                   data: Optional[Union[DataLoader, Dataset]],
+                   format_configuration: Optional[Union[
+                       RegressionMetricsFormat,
+                       MultiTargetRegressionMetricsFormat,
+                       BinaryClassificationMetricsFormat,
+                       MultiClassClassificationMetricsFormat,
+                       BinaryImageClassificationMetricsFormat,
+                       MultiClassImageClassificationMetricsFormat,
+                       MultiLabelBinaryClassificationMetricsFormat,
+                       BinarySegmentationMetricsFormat,
+                       MultiClassSegmentationMetricsFormat
+                   ]]=None):
+         """
+         Internal evaluation helper behind evaluate().
          """
-         dataset_for_names = None
+         dataset_for_artifacts = None
          eval_loader = None
-
+
+         # set threshold
+         self._classification_threshold = classification_threshold
+
+         # load model checkpoint
+         if isinstance(model_checkpoint, Path):
+             self._load_checkpoint(path=model_checkpoint)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+             path_to_latest = self._checkpoint_callback.best_checkpoint_path
+             self._load_checkpoint(path_to_latest)
+         elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+             _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+             raise ValueError()
+
+         # Dataloader
          if isinstance(data, DataLoader):
              eval_loader = data
              # Try to get the dataset from the loader for fetching target names
              if hasattr(data, 'dataset'):
-                 dataset_for_names = data.dataset
+                 dataset_for_artifacts = data.dataset  # type: ignore
          elif isinstance(data, Dataset):
              # Create a new loader from the provided dataset
              eval_loader = DataLoader(data,
@@ -361,19 +802,19 @@ class MLTrainer:
                                       shuffle=False,
                                       num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                       pin_memory=(self.device.type == "cuda"))
-             dataset_for_names = data
+             dataset_for_artifacts = data
          else:  # data is None, use the trainer's default test dataset
-             if self.test_dataset is None:
-                 _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
+             if self.validation_dataset is None:
+                 _LOGGER.error("Cannot evaluate. No data provided and no validation dataset available in the trainer.")
                  raise ValueError()
              # Create a fresh DataLoader from the test_dataset
-             eval_loader = DataLoader(self.test_dataset,
+             eval_loader = DataLoader(self.validation_dataset,
                                       batch_size=self._batch_size,
                                       shuffle=False,
                                       num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                       pin_memory=(self.device.type == "cuda"))
 
-             dataset_for_names = self.test_dataset
+             dataset_for_artifacts = self.validation_dataset
 
          if eval_loader is None:
              _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
@@ -396,36 +837,83 @@ class MLTrainer:
          y_prob = np.concatenate(all_probs) if all_probs else None
 
          # --- Routing Logic ---
-         if self.kind == "regression":
-             regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir)
-
-         elif self.kind == "classification":
-             # Parse configuration
-             if format_configuration and isinstance(format_configuration, ClassificationMetricsFormat):
-                 classification_metrics(save_dir=save_dir,
-                                        y_true=y_true,
-                                        y_pred=y_pred,
-                                        y_prob=y_prob,
-                                        cmap=format_configuration.cmap,
-                                        class_map=format_configuration.class_map,
-                                        ROC_PR_line=format_configuration.ROC_PR_line,
-                                        calibration_bins=format_configuration.calibration_bins,
-                                        font_size=format_configuration.font_size)
+         # Single-target regression
+         if self.kind == MLTaskKeys.REGRESSION:
+             # Check configuration
+             config = None
+             if format_configuration and isinstance(format_configuration, RegressionMetricsFormat):
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
+
+             regression_metrics(y_true=y_true.flatten(),
+                                y_pred=y_pred.flatten(),
+                                save_dir=save_dir,
+                                config=config)
+
+         # single target classification
+         elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION,
+                            MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
+                            MLTaskKeys.MULTICLASS_CLASSIFICATION,
+                            MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+             # get the class map if it exists
+             try:
+                 class_map = dataset_for_artifacts.class_map  # type: ignore
+             except AttributeError:
+                 _LOGGER.warning(f"Dataset has no 'class_map' attribute. Using generics.")
+                 class_map = None
              else:
-                 classification_metrics(save_dir, y_true, y_pred, y_prob)
-
-         elif self.kind == "multi_target_regression":
+                 if not isinstance(class_map, dict):
+                     _LOGGER.warning(f"Dataset has a 'class_map' attribute, but it is not a dictionary: '{type(class_map)}'.")
+                     class_map = None
+
+             # Check configuration
+             config = None
+             if format_configuration:
+                 if self.kind == MLTaskKeys.BINARY_CLASSIFICATION and isinstance(format_configuration, BinaryClassificationMetricsFormat):
+                     config = format_configuration
+                 elif self.kind == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION and isinstance(format_configuration, BinaryImageClassificationMetricsFormat):
+                     config = format_configuration
+                 elif self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION and isinstance(format_configuration, MultiClassClassificationMetricsFormat):
+                     config = format_configuration
+                 elif self.kind == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION and isinstance(format_configuration, MultiClassImageClassificationMetricsFormat):
+                     config = format_configuration
+                 else:
+                     _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
+
+             classification_metrics(save_dir=save_dir,
+                                    y_true=y_true,
+                                    y_pred=y_pred,
+                                    y_prob=y_prob,
+                                    class_map=class_map,
+                                    config=config)
+
+         # multitarget regression
+         elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION:
              try:
-                 target_names = dataset_for_names.target_names  # type: ignore
+                 target_names = dataset_for_artifacts.target_names  # type: ignore
              except AttributeError:
                  num_targets = y_true.shape[1]
                  target_names = [f"target_{i}" for i in range(num_targets)]
                  _LOGGER.warning(f"Dataset has no 'target_names' attribute. Using generic names.")
-             multi_target_regression_metrics(y_true, y_pred, target_names, save_dir)
-
-         elif self.kind == "multi_label_classification":
+
+             # Check configuration
+             config = None
+             if format_configuration and isinstance(format_configuration, MultiTargetRegressionMetricsFormat):
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
+
+             multi_target_regression_metrics(y_true=y_true,
+                                             y_pred=y_pred,
+                                             target_names=target_names,
+                                             save_dir=save_dir,
+                                             config=config)
+
+         # multi-label binary classification
+         elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
              try:
-                 target_names = dataset_for_names.target_names  # type: ignore
+                 target_names = dataset_for_artifacts.target_names  # type: ignore
              except AttributeError:
                  num_targets = y_true.shape[1]
                  target_names = [f"label_{i}" for i in range(num_targets)]
@@ -435,44 +923,55 @@ class MLTrainer:
              _LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
              return
 
-         if format_configuration and isinstance(format_configuration, MultiClassificationMetricsFormat):
-             multi_label_classification_metrics(y_true=y_true,
-                                                y_prob=y_prob,
-                                                target_names=target_names,
-                                                save_dir=save_dir,
-                                                threshold=format_configuration.threshold,
-                                                ROC_PR_line=format_configuration.ROC_PR_line,
-                                                cmap=format_configuration.cmap,
-                                                font_size=format_configuration.font_size)
-         else:
-             multi_label_classification_metrics(y_true, y_prob, target_names, save_dir)
-
-         elif self.kind == "segmentation":
+             # Check configuration
+             config = None
+             if format_configuration and isinstance(format_configuration, MultiLabelBinaryClassificationMetricsFormat):
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
+
+             multi_label_classification_metrics(y_true=y_true,
+                                                y_pred=y_pred,
+                                                y_prob=y_prob,
+                                                target_names=target_names,
+                                                save_dir=save_dir,
+                                                config=config)
+
+         # Segmentation tasks
+         elif self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
              class_names = None
              try:
                  # Try to get 'classes' from VisionDatasetMaker
-                 if hasattr(dataset_for_names, 'classes'):
-                     class_names = dataset_for_names.classes  # type: ignore
+                 if hasattr(dataset_for_artifacts, 'classes'):
+                     class_names = dataset_for_artifacts.classes  # type: ignore
                  # Fallback for Subset
-                 elif hasattr(dataset_for_names, 'dataset') and hasattr(dataset_for_names.dataset, 'classes'):  # type: ignore
-                     class_names = dataset_for_names.dataset.classes  # type: ignore
+                 elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'):  # type: ignore
+                     class_names = dataset_for_artifacts.dataset.classes  # type: ignore
              except AttributeError:
                  pass  # class_names is still None
 
             if class_names is None:
                  try:
                      # Fallback to 'target_names'
-                     class_names = dataset_for_names.target_names  # type: ignore
+                     class_names = dataset_for_artifacts.target_names  # type: ignore
                  except AttributeError:
                      # Fallback to inferring from labels
                      labels = np.unique(y_true)
                      class_names = [f"Class {i}" for i in labels]
                      _LOGGER.warning(f"Dataset has no 'classes' or 'target_names' attribute. Using generic names.")
 
-             segmentation_metrics(y_true, y_pred, save_dir, class_names=class_names)
-
-         # print("\n--- Training History ---")
-         plot_losses(self.history, save_dir=save_dir)
+             # Check configuration
+             config = None
+             if format_configuration and isinstance(format_configuration, (BinarySegmentationMetricsFormat, MultiClassSegmentationMetricsFormat)):
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
+
+             segmentation_metrics(y_true=y_true,
+                                  y_pred=y_pred,
+                                  save_dir=save_dir,
+                                  class_names=class_names,
+                                  config=config)
 
      def explain(self,
                  save_dir: Union[str, Path],
@@ -500,34 +999,52 @@ class MLTrainer:
              explainer_type (Literal['deep', 'kernel']): The explainer to use.
                  - 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
                  - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples' < 100.
-         """
-         # Internal helper to create a dataloader and get a random sample
+         """
+         # memory efficient helper
          def _get_random_sample(dataset: Dataset, num_samples: int):
+             """
+             Memory-efficiently samples data from a dataset.
+             """
              if dataset is None:
                  return None
 
+             dataset_len = len(dataset)  # type: ignore
+             if dataset_len == 0:
+                 return None
+
              # For MPS devices, num_workers must be 0 to ensure stability
              loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
 
+             # Ensure batch_size is not larger than the dataset itself
+             batch_size = min(num_samples, 64, dataset_len)
+
             loader = DataLoader(
                  dataset,
-                 batch_size=64,
-                 shuffle=False,
+                 batch_size=batch_size,
+                 shuffle=True,  # Shuffle to get random samples
                  num_workers=loader_workers
              )
 
-             all_features = [features for features, _ in loader]
-             if not all_features:
+             collected_features = []
+             num_collected = 0
+
+             for features, _ in loader:
+                 collected_features.append(features)
+                 num_collected += features.size(0)
+                 if num_collected >= num_samples:
+                     break  # Stop once we have enough samples
+
+             if not collected_features:
                  return None
 
-             full_data = torch.cat(all_features, dim=0)
+             full_data = torch.cat(collected_features, dim=0)
 
-             if num_samples >= full_data.size(0):
-                 return full_data
+             # If we collected more than needed, trim it down
+             if full_data.size(0) > num_samples:
+                 return full_data[:num_samples]
 
-             rand_indices = torch.randperm(full_data.size(0))[:num_samples]
-             return full_data[rand_indices]
-
+             return full_data
+
          # print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
 
          # 1. Get background data from the trainer's train_dataset
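A hedged usage sketch for `explain()`, limited to the parameters visible in this diff (the directory is a placeholder, and `explain_dataset` falls back to the trainer's validation dataset when omitted):

    trainer.explain(
        save_dir="shap_output",    # placeholder directory
        n_samples=200,             # instances sampled for SHAP
        explainer_type="deep",     # 'deep' (fast) or 'kernel' (very slow; use n_samples < 100)
    )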
@@ -537,7 +1054,7 @@ class MLTrainer:
              return
 
          # 2. Determine target dataset and get explanation instances
-         target_dataset = explain_dataset if explain_dataset is not None else self.test_dataset
+         target_dataset = explain_dataset if explain_dataset is not None else self.validation_dataset
          instances_to_explain = _get_random_sample(target_dataset, n_samples)
          if instances_to_explain is None:
              _LOGGER.error("Explanation dataset is empty or invalid. Skipping SHAP analysis.")
@@ -556,7 +1073,7 @@ class MLTrainer:
          self.model.to(self.device)
 
          # 3. Call the plotting function
-         if self.kind in ["regression", "classification"]:
+         if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
              shap_summary_plot(
                  model=self.model,
                  background_data=background_data,
@@ -566,7 +1083,7 @@ class MLTrainer:
                  explainer_type=explainer_type,
                  device=self.device
              )
-         elif self.kind in ["multi_target_regression", "multi_label_classification"]:
+         elif self.kind in [MLTaskKeys.MULTITARGET_REGRESSION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
              # try to get target names
              if target_names is None:
                  target_names = []
@@ -640,13 +1157,11 @@ class MLTrainer:
 
          # --- Step 1: Check if the model supports this explanation ---
          if not getattr(self.model, 'has_interpretable_attention', False):
-             _LOGGER.warning(
-                 "Model is not flagged for interpretable attention analysis. Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
-             )
+             _LOGGER.warning("Model is not compatible with interpretable attention analysis. Skipping.")
              return
 
          # --- Step 2: Set up the dataloader ---
-         dataset_to_use = explain_dataset if explain_dataset is not None else self.test_dataset
+         dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
          if not isinstance(dataset_to_use, Dataset):
              _LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
              return
@@ -681,40 +1196,101 @@ class MLTrainer:
             )
         else:
             _LOGGER.error("No attention weights were collected from the model.")
-
-    def _callbacks_hook(self, method_name: str, *args, **kwargs):
-        """Calls the specified method on all callbacks."""
-        for callback in self.callbacks:
-            method = getattr(callback, method_name)
-            method(*args, **kwargs)
-
-    def to_cpu(self):
-        """
-        Moves the model to the CPU and updates the trainer's device setting.
 
-        This is useful for running operations that require the CPU.
-        """
-        self.device = torch.device('cpu')
-        self.model.to(self.device)
-        _LOGGER.info("Trainer and model moved to CPU.")
-
-    def to_device(self, device: str):
+    def finalize_model_training(self,
+                                model_checkpoint: Union[Path, Literal['latest', 'current']],
+                                save_dir: Union[str, Path],
+                                finalize_config: Union[FinalizeRegression,
+                                                       FinalizeMultiTargetRegression,
+                                                       FinalizeBinaryClassification,
+                                                       FinalizeBinaryImageClassification,
+                                                       FinalizeMultiClassClassification,
+                                                       FinalizeMultiClassImageClassification,
+                                                       FinalizeBinarySegmentation,
+                                                       FinalizeMultiClassSegmentation,
+                                                       FinalizeMultiLabelBinaryClassification]):
         """
-        Moves the model to the specified device and updates the trainer's device setting.
+        Saves a finalized, "inference-ready" model state to a .pth file.
+
+        This method saves the model's `state_dict`, the final epoch number, and optional configuration for the task at hand.
 
         Args:
-            device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+            model_checkpoint (Path | "latest" | "current"):
+                - Path: Loads the model state from a specific checkpoint file.
+                - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                - "current": Uses the model's state as it is at the end of the `fit()` call.
+            save_dir (str | Path): The directory to save the finalized model.
+            finalize_config (object): A data class instance specific to the ML task containing task-specific metadata required for inference.
         """
-        self.device = self._validate_device(device)
-        self.model.to(self.device)
-        _LOGGER.info(f"Trainer and model moved to {self.device}.")
+        if self.kind == MLTaskKeys.REGRESSION and not isinstance(finalize_config, FinalizeRegression):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeRegression', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION and not isinstance(finalize_config, FinalizeMultiTargetRegression):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiTargetRegression', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.BINARY_CLASSIFICATION and not isinstance(finalize_config, FinalizeBinaryClassification):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeBinaryClassification', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION and not isinstance(finalize_config, FinalizeBinaryImageClassification):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeBinaryImageClassification', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION and not isinstance(finalize_config, FinalizeMultiClassClassification):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiClassClassification', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION and not isinstance(finalize_config, FinalizeMultiClassImageClassification):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiClassImageClassification', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.BINARY_SEGMENTATION and not isinstance(finalize_config, FinalizeBinarySegmentation):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeBinarySegmentation', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION and not isinstance(finalize_config, FinalizeMultiClassSegmentation):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiClassSegmentation', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION and not isinstance(finalize_config, FinalizeMultiLabelBinaryClassification):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiLabelBinaryClassification', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+
+        # handle save path
+        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+        full_path = dir_path / finalize_config.filename
+
+        # handle checkpoint
+        self._load_model_state_for_finalizing(model_checkpoint)
+
+        # Create finalized data
+        finalized_data = {
+            PyTorchCheckpointKeys.EPOCH: self.epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+        }
+
+        # Parse config
+        if finalize_config.target_name is not None:
+            finalized_data[PyTorchCheckpointKeys.TARGET_NAME] = finalize_config.target_name
+        if finalize_config.target_names is not None:
+            finalized_data[PyTorchCheckpointKeys.TARGET_NAMES] = finalize_config.target_names
+        if finalize_config.classification_threshold is not None:
+            finalized_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD] = finalize_config.classification_threshold
+        if finalize_config.class_map is not None:
+            finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map
+
+        # Save model file
+        torch.save(finalized_data, full_path)
+
+        _LOGGER.info(f"Finalized model file saved to '{full_path}'")
 
 
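
A minimal usage sketch of the new finalize flow for a regression task follows; the `FinalizeRegression` constructor fields shown (`filename`, `target_name`) and the import path are inferred from the attributes read above and may differ from the actual dataclass in ML_configuration:

    from pathlib import Path
    from ml_tools.ML_configuration import FinalizeRegression  # import path assumed

    config = FinalizeRegression(
        filename="regression_model.pth",  # becomes save_dir / filename
        target_name="target",             # stored under PyTorchCheckpointKeys.TARGET_NAME
    )
    trainer.finalize_model_training(
        model_checkpoint="latest",        # best state from the DragonModelCheckpoint callback
        save_dir=Path("outputs/final"),
        finalize_config=config,
    )
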
 # Object Detection Trainer
-class ObjectDetectionTrainer:
-    def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
+class DragonDetectionTrainer(_BaseDragonTrainer):
+    def __init__(self, model: nn.Module,
+                 train_dataset: Dataset,
+                 validation_dataset: Dataset,
                  collate_fn: Callable, optimizer: torch.optim.Optimizer,
-                 device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
+                 device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                 checkpoint_callback: Optional[DragonModelCheckpoint],
+                 early_stopping_callback: Optional[DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 extra_callbacks: Optional[List[_Callback]] = None,
+                 dataloader_workers: int = 2):
         """
         Automates the training process of an Object Detection Model (e.g., DragonFastRCNN).
 
@@ -723,58 +1299,36 @@ class ObjectDetectionTrainer:
         Args:
             model (nn.Module): The PyTorch object detection model to train.
             train_dataset (Dataset): The training dataset.
-            test_dataset (Dataset): The testing/validation dataset.
+            validation_dataset (Dataset): The validation dataset.
             collate_fn (Callable): The collate function from `ObjectDetectionDatasetMaker.collate_fn`.
             optimizer (torch.optim.Optimizer): The optimizer.
             device (str): The device to run training on ('cpu', 'cuda', 'mps').
             dataloader_workers (int): Subprocesses for data loading.
-            callbacks (List[Callback] | None): A list of callbacks to use during training.
+            checkpoint_callback (DragonModelCheckpoint | None): Callback to save the model.
+            early_stopping_callback (DragonEarlyStopping | None): Callback to stop training early.
+            lr_scheduler_callback (DragonLRScheduler | None): Callback to manage the LR scheduler.
+            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
 
         ## Note:
         This trainer is specialized. It does not take a `criterion` because object detection models like Faster R-CNN return a dictionary of losses directly from their forward pass during training.
         """
-        self.model = model
+        # Call the base class constructor with common parameters
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            dataloader_workers=dataloader_workers,
+            checkpoint_callback=checkpoint_callback,
+            early_stopping_callback=early_stopping_callback,
+            lr_scheduler_callback=lr_scheduler_callback,
+            extra_callbacks=extra_callbacks
+        )
+
         self.train_dataset = train_dataset
-        self.test_dataset = test_dataset
-        self.kind = "object_detection"
+        self.validation_dataset = validation_dataset  # <-- Renamed
+        self.kind = MLTaskKeys.OBJECT_DETECTION
         self.collate_fn = collate_fn
         self.criterion = None  # Criterion is handled inside the model
-        self.optimizer = optimizer
-        self.scheduler = None
-        self.device = self._validate_device(device)
-        self.dataloader_workers = dataloader_workers
-
-        # Callback handler - History and TqdmProgressBar are added by default
-        default_callbacks = [History(), TqdmProgressBar()]
-        user_callbacks = callbacks if callbacks is not None else []
-        self.callbacks = default_callbacks + user_callbacks
-        self._set_trainer_on_callbacks()
-
-        # Internal state
-        self.train_loader = None
-        self.test_loader = None
-        self.history = {}
-        self.epoch = 0
-        self.epochs = 0  # Total epochs for the fit run
-        self.start_epoch = 1
-        self.stop_training = False
-        self._batch_size = 10
-
-    def _validate_device(self, device: str) -> torch.device:
-        """Validates the selected device and returns a torch.device object."""
-        device_lower = device.lower()
-        if "cuda" in device_lower and not torch.cuda.is_available():
-            _LOGGER.warning("CUDA not available, switching to CPU.")
-            device = "cpu"
-        elif device_lower == "mps" and not torch.backends.mps.is_available():
-            _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
-            device = "cpu"
-        return torch.device(device)
-
-    def _set_trainer_on_callbacks(self):
-        """Gives each callback a reference to this trainer instance."""
-        for callback in self.callbacks:
-            callback.set_trainer(self)
 
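
For readability, a minimal construction sketch for the new signature (the model, datasets, and dataset maker are placeholders assumed to exist; every callback slot accepts None):

    import torch

    trainer = DragonDetectionTrainer(
        model=model,                      # e.g. a DragonFastRCNN instance
        train_dataset=train_ds,
        validation_dataset=val_ds,
        collate_fn=maker.collate_fn,      # from ObjectDetectionDatasetMaker
        optimizer=torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9),
        device="cuda",
        checkpoint_callback=None,         # or a DragonModelCheckpoint instance
        early_stopping_callback=None,
        lr_scheduler_callback=None,
    )
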
     def _create_dataloaders(self, batch_size: int, shuffle: bool):
         """Initializes the DataLoaders with the object detection collate_fn."""
@@ -786,125 +1340,25 @@ class ObjectDetectionTrainer:
             batch_size=batch_size,
             shuffle=shuffle,
             num_workers=loader_workers,
-            pin_memory=("cuda" in self.device.type),
-            collate_fn=self.collate_fn # Use the provided collate function
+            pin_memory=("cuda" in self.device.type),
+            collate_fn=self.collate_fn,  # Use the provided collate function
+            drop_last=True
         )
 
-        self.test_loader = DataLoader(
-            dataset=self.test_dataset,
+        self.validation_loader = DataLoader(
+            dataset=self.validation_dataset,
             batch_size=batch_size,
             shuffle=False,
             num_workers=loader_workers,
             pin_memory=("cuda" in self.device.type),
             collate_fn=self.collate_fn # Use the provided collate function
         )
-
-    def _load_checkpoint(self, path: Union[str, Path]):
-        """Loads a training checkpoint to resume training."""
-        p = make_fullpath(path, enforce="file")
-        _LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
-
-        try:
-            checkpoint = torch.load(p, map_location=self.device)
-
-            if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
-                _LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
-                raise KeyError()
-
-            self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
-            self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
-            self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
-
-            # --- Scheduler State Loading Logic ---
-            scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
-            scheduler_object_exists = self.scheduler is not None
-
-            if scheduler_object_exists and scheduler_state_exists:
-                # Case 1: Both exist. Attempt to load.
-                try:
-                    self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
-                    scheduler_name = self.scheduler.__class__.__name__
-                    _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
-                except Exception as e:
-                    # Loading failed, likely a mismatch
-                    scheduler_name = self.scheduler.__class__.__name__
-                    _LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
-                    raise e
-
-            elif scheduler_object_exists and not scheduler_state_exists:
-                # Case 2: Scheduler provided, but no state in checkpoint.
-                scheduler_name = self.scheduler.__class__.__name__
-                _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
-
-            elif not scheduler_object_exists and scheduler_state_exists:
-                # Case 3: State in checkpoint, but no scheduler provided.
-                _LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
-                raise ValueError()
-
-            # Restore callback states
-            for cb in self.callbacks:
-                if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
-                    cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
-                    _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
-
-            _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
-
-        except Exception as e:
-            _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
-            raise
-
-    def fit(self,
-            epochs: int = 10,
-            batch_size: int = 10,
-            shuffle: bool = True,
-            resume_from_checkpoint: Optional[Union[str, Path]] = None):
-        """
-        Starts the training-validation process of the model.
-
-        Returns the "History" callback dictionary.
-
-        Args:
-            epochs (int): The total number of epochs to train for.
-            batch_size (int): The number of samples per batch.
-            shuffle (bool): Whether to shuffle the training data at each epoch.
-            resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
-        """
-        self.epochs = epochs
-        self._batch_size = batch_size
-        self._create_dataloaders(self._batch_size, shuffle)
-        self.model.to(self.device)
-
-        if resume_from_checkpoint:
-            self._load_checkpoint(resume_from_checkpoint)
-
-        # Reset stop_training flag on the trainer
-        self.stop_training = False
-
-        self._callbacks_hook('on_train_begin')
-
-        for epoch in range(self.start_epoch, self.epochs + 1):
-            self.epoch = epoch
-            epoch_logs = {}
-            self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
-
-            train_logs = self._train_step()
-            epoch_logs.update(train_logs)
-
-            val_logs = self._validation_step()
-            epoch_logs.update(val_logs)
-
-            self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
-
-            # Check the early stopping flag
-            if self.stop_training:
-                break
 
-        self._callbacks_hook('on_train_end')
-        return self.history
-
     def _train_step(self):
         self.model.train()
         running_loss = 0.0
+        total_samples = 0
+
         for batch_idx, (images, targets) in enumerate(self.train_loader): # type: ignore
             # images is a tuple of tensors, targets is a tuple of dicts
             batch_size = len(images)
@@ -941,21 +1395,28 @@ class ObjectDetectionTrainer:
             # Calculate batch loss and update running loss for the epoch
             batch_loss = loss.item()
             running_loss += batch_loss * batch_size
+            total_samples += batch_size  # <-- Accumulate total samples
 
             # Add the batch loss to the logs and call the end-of-batch hook
             batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss # type: ignore
             self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+        # Calculate loss using the correct denominator
+        if total_samples == 0:
+            _LOGGER.warning("No samples processed in _train_step. Returning 0 loss.")
+            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
 
-        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore
+        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}
 
     def _validation_step(self):
         self.model.train()  # Set to train mode even for validation loss calculation
-                            # as model internals (e.g., proposals) might differ,
-                            # but we still need loss_dict.
-                            # We use torch.no_grad() to prevent gradient updates.
+                            # as model internals (e.g., proposals) might differ, but we still need loss_dict.
+                            # use torch.no_grad() to prevent gradient updates.
         running_loss = 0.0
+        total_samples = 0
+
         with torch.no_grad():
-            for images, targets in self.test_loader: # type: ignore
+            for images, targets in self.validation_loader: # type: ignore
                 batch_size = len(images)
 
                 # Move data to device
@@ -973,25 +1434,105 @@ class ObjectDetectionTrainer:
                 loss: torch.Tensor = sum(l for l in loss_dict.values()) # type: ignore
 
                 running_loss += loss.item() * batch_size
+                total_samples += batch_size  # <-- Accumulate total samples
 
-        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)} # type: ignore
+        # Calculate loss using the correct denominator
+        if total_samples == 0:
+            _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+            return {PyTorchLogKeys.VAL_LOSS: 0.0}
+
+        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / total_samples}
         return logs
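
The switch from `len(dataset)` to `total_samples` as the denominator matters once the train loader uses `drop_last=True`: samples in the dropped remainder would otherwise deflate the epoch mean. A self-contained check of the arithmetic:

    # Two processed batches of 4; a remainder of 2 was dropped (drop_last=True).
    batch_sizes = [4, 4]           # samples actually processed
    batch_losses = [0.50, 0.70]    # mean loss per batch

    running_loss = sum(l * n for l, n in zip(batch_losses, batch_sizes))
    total_samples = sum(batch_sizes)

    print(running_loss / total_samples)  # 0.60 -> true mean over seen samples
    print(running_loss / 10)             # 0.48 -> wrong if len(dataset) == 10
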
+
+    def evaluate(self,
+                 save_dir: Union[str, Path],
+                 model_checkpoint: Union[Path, Literal["latest", "current"]],
+                 test_data: Optional[Union[DataLoader, Dataset]] = None):
+        """
+        Evaluates the model using object detection mAP metrics.
 
-    def evaluate(self, save_dir: Union[str, Path], data: Optional[Union[DataLoader, Dataset]] = None):
+        Args:
+            save_dir (str | Path): Directory to save all reports and plots.
+            model_checkpoint (Path | "latest" | "current"):
+                - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+                - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+                - If 'current', use the current state of the trained model up to the latest trained epoch.
+            test_data (DataLoader | Dataset | None): Optional test data to evaluate model performance. Validation and test metrics will be saved to subdirectories.
+        """
+        # Validate model checkpoint
+        if isinstance(model_checkpoint, Path):
+            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+        elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+            checkpoint_validated = model_checkpoint
+        else:
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
+            raise ValueError()
+
+        # Validate directory
+        save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+        # Validate test data and dispatch
+        if test_data is not None:
+            if not isinstance(test_data, (DataLoader, Dataset)):
+                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                raise ValueError()
+            test_data_validated = test_data
+
+            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+            self._evaluate(save_dir=validation_metrics_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None)  # 'None' triggers use of self.validation_dataset
+
+            # Dispatch test set
+            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+            self._evaluate(save_dir=test_metrics_path,
+                           model_checkpoint="current",  # Use 'current' state after loading checkpoint once
+                           data=test_data_validated)
+        else:
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+            self._evaluate(save_dir=save_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None)  # 'None' triggers use of self.validation_dataset
+
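
A short usage sketch of the two-stage dispatch above (`trainer` is a fitted DragonDetectionTrainer; paths and `test_ds` are placeholders):

    # Validation metrics only, using the best checkpoint
    # (requires a DragonModelCheckpoint callback):
    trainer.evaluate(save_dir="outputs/eval", model_checkpoint="latest")

    # Validation plus held-out test metrics, written to subdirectories:
    trainer.evaluate(save_dir="outputs/eval",
                     model_checkpoint="latest",
                     test_data=test_ds)  # a Dataset or DataLoader
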
1502
+ def _evaluate(self,
1503
+ save_dir: Union[str, Path],
1504
+ model_checkpoint: Union[Path, Literal["latest", "current"]],
1505
+ data: Optional[Union[DataLoader, Dataset]]):
981
1506
  """
1507
+ Changed to a private helper method
982
1508
  Evaluates the model using object detection mAP metrics.
983
1509
 
984
1510
  Args:
985
1511
  save_dir (str | Path): Directory to save all reports and plots.
986
1512
  data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
1513
+ model_checkpoint ('auto' | Path | None):
1514
+ - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
1515
+ - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
1516
+ - If 'current', use the current state of the trained model up the latest trained epoch.
987
1517
  """
988
- dataset_for_names = None
1518
+ dataset_for_artifacts = None
989
1519
  eval_loader = None
1520
+
1521
+ # load model checkpoint
1522
+ if isinstance(model_checkpoint, Path):
1523
+ self._load_checkpoint(path=model_checkpoint)
1524
+ elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
1525
+ path_to_latest = self._checkpoint_callback.best_checkpoint_path
1526
+ self._load_checkpoint(path_to_latest)
1527
+ elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
1528
+ _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
1529
+ raise ValueError()
990
1530
 
1531
+ # Dataloader
991
1532
  if isinstance(data, DataLoader):
992
1533
  eval_loader = data
993
1534
  if hasattr(data, 'dataset'):
994
- dataset_for_names = data.dataset
1535
+ dataset_for_artifacts = data.dataset # type: ignore
995
1536
  elif isinstance(data, Dataset):
996
1537
  # Create a new loader from the provided dataset
997
1538
  eval_loader = DataLoader(data,
@@ -1000,21 +1541,21 @@ class ObjectDetectionTrainer:
                                      num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                      pin_memory=(self.device.type == "cuda"),
                                      collate_fn=self.collate_fn)
-            dataset_for_names = data
+            dataset_for_artifacts = data
         else: # data is None, use the trainer's default test dataset
-            if self.test_dataset is None:
+            if self.validation_dataset is None:
                 _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
                 raise ValueError()
             # Create a fresh DataLoader from the test_dataset
             eval_loader = DataLoader(
-                self.test_dataset,
+                self.validation_dataset,
                 batch_size=self._batch_size,
                 shuffle=False,
                 num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                 pin_memory=(self.device.type == "cuda"),
                 collate_fn=self.collate_fn
             )
-            dataset_for_names = self.test_dataset
+            dataset_for_artifacts = self.validation_dataset
 
         if eval_loader is None:
             _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
@@ -1051,11 +1592,11 @@ class ObjectDetectionTrainer:
         class_names = None
         try:
             # Try to get 'classes' from ObjectDetectionDatasetMaker
-            if hasattr(dataset_for_names, 'classes'):
-                class_names = dataset_for_names.classes # type: ignore
+            if hasattr(dataset_for_artifacts, 'classes'):
+                class_names = dataset_for_artifacts.classes # type: ignore
             # Fallback for Subset
-            elif hasattr(dataset_for_names, 'dataset') and hasattr(dataset_for_names.dataset, 'classes'): # type: ignore
-                class_names = dataset_for_names.dataset.classes # type: ignore
+            elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'): # type: ignore
+                class_names = dataset_for_artifacts.dataset.classes # type: ignore
         except AttributeError:
             _LOGGER.warning("Could not find 'classes' attribute on dataset. Per-class metrics will not be named.")
             pass  # class_names is still None
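
The Subset fallback exists because torch.utils.data.Subset wraps a dataset without forwarding its attributes. A self-contained illustration:

    import torch
    from torch.utils.data import Subset, TensorDataset

    base = TensorDataset(torch.zeros(4, 3))
    base.classes = ["cat", "dog"]        # mimic ObjectDetectionDatasetMaker
    subset = Subset(base, [0, 1])

    print(hasattr(subset, "classes"))    # False -> first branch misses
    print(subset.dataset.classes)        # ['cat', 'dog'] -> fallback succeeds
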
@@ -1068,36 +1609,451 @@ class ObjectDetectionTrainer:
             class_names=class_names,
             print_output=False
         )
-
-        # print("\n--- Training History ---")
-        plot_losses(self.history, save_dir=save_dir)
 
-    def _callbacks_hook(self, method_name: str, *args, **kwargs):
-        """Calls the specified method on all callbacks."""
-        for callback in self.callbacks:
-            method = getattr(callback, method_name)
-            method(*args, **kwargs)
+    def finalize_model_training(self,
+                                save_dir: Union[str, Path],
+                                model_checkpoint: Union[Path, Literal['latest', 'current']],
+                                finalize_config: FinalizeObjectDetection
+                                ):
+        """
+        Saves a finalized, "inference-ready" model state to a .pth file.
+
+        This method saves the model's `state_dict` and the final epoch number.
+
+        Args:
+            save_dir (Union[str, Path]): The directory to save the finalized model.
+            model_checkpoint (Union[Path, Literal["latest", "current"]]):
+                - Path: Loads the model state from a specific checkpoint file.
+                - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                - "current": Uses the model's state as it is at the end of the `fit()` call.
+            finalize_config (FinalizeObjectDetection): A data class instance specific to the ML task containing task-specific metadata required for inference.
+        """
+        if not isinstance(finalize_config, FinalizeObjectDetection):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeObjectDetection', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+
+        # handle save path
+        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+        full_path = dir_path / finalize_config.filename
+
+        # handle checkpoint
+        self._load_model_state_for_finalizing(model_checkpoint)
+
+        # Create finalized data
+        finalized_data = {
+            PyTorchCheckpointKeys.EPOCH: self.epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+        }
+
+        if finalize_config.class_map is not None:
+            finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map
+
+        torch.save(finalized_data, full_path)
+
+        _LOGGER.info(f"Finalized model file saved to '{full_path}'")
+
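
The finalized file is a plain dictionary written with torch.save, so it can be inspected directly. A sketch, assuming PyTorchCheckpointKeys is importable from ml_tools._keys (per the keys.py → _keys.py rename in this release) and using a placeholder path:

    import torch
    from ml_tools._keys import PyTorchCheckpointKeys

    payload = torch.load("outputs/final/detector.pth", map_location="cpu")
    state_dict = payload[PyTorchCheckpointKeys.MODEL_STATE]
    epoch = payload[PyTorchCheckpointKeys.EPOCH]
    class_map = payload.get(PyTorchCheckpointKeys.CLASS_MAP)  # only if configured
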
+# --- DragonSequenceTrainer ---
+class DragonSequenceTrainer(_BaseDragonTrainer):
+    def __init__(self,
+                 model: nn.Module,
+                 train_dataset: Dataset,
+                 validation_dataset: Dataset,
+                 kind: Literal["sequence-to-sequence", "sequence-to-value"],
+                 optimizer: torch.optim.Optimizer,
+                 device: Union[Literal['cuda', 'mps', 'cpu'],str],
+                 checkpoint_callback: Optional[DragonModelCheckpoint],
+                 early_stopping_callback: Optional[DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 extra_callbacks: Optional[List[_Callback]] = None,
+                 criterion: Union[nn.Module, Literal["auto"]] = "auto",
+                 dataloader_workers: int = 2):
+        """
+        Automates the training process of a PyTorch Sequence Model.
+
+        Built-in Callbacks: `History`, `TqdmProgressBar`
+
+        Args:
+            model (nn.Module): The PyTorch model to train.
+            train_dataset (Dataset): The training dataset.
+            validation_dataset (Dataset): The validation dataset.
+            kind (str): Used to redirect to the correct process ('sequence-to-sequence' or 'sequence-to-value').
+            criterion (nn.Module | "auto"): The loss function to use. If "auto", it is inferred from the selected task.
+            optimizer (torch.optim.Optimizer): The optimizer.
+            device (str): The device to run training on ('cpu', 'cuda', 'mps').
+            dataloader_workers (int): Subprocesses for data loading.
+            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
+        """
+        # Call the base class constructor with common parameters
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            dataloader_workers=dataloader_workers,
+            checkpoint_callback=checkpoint_callback,
+            early_stopping_callback=early_stopping_callback,
+            lr_scheduler_callback=lr_scheduler_callback,
+            extra_callbacks=extra_callbacks
+        )
+
+        if kind not in [MLTaskKeys.SEQUENCE_SEQUENCE, MLTaskKeys.SEQUENCE_VALUE]:
+            raise ValueError(f"'{kind}' is not a valid task type for DragonSequenceTrainer.")
+
+        self.train_dataset = train_dataset
+        self.validation_dataset = validation_dataset
+        self.kind = kind
+
+        # try to validate against Dragon Sequence model
+        if hasattr(self.model, "prediction_mode"):
+            key_to_check: str = self.model.prediction_mode # type: ignore
+            if not key_to_check == self.kind:
+                _LOGGER.error(f"Trainer was set for '{self.kind}', but model architecture '{self.model}' is built for '{key_to_check}'.")
+                raise RuntimeError()
+
+        # loss function
+        if criterion == "auto":
+            # Both sequence tasks are treated as regression problems
+            self.criterion = nn.MSELoss()
+        else:
+            self.criterion = criterion
+
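
A minimal construction sketch (the model and datasets are placeholders, e.g. from ML_sequence_datasetmaster; callback slots accept None):

    import torch

    trainer = DragonSequenceTrainer(
        model=model,                 # a sequence model, ideally exposing prediction_mode
        train_dataset=train_ds,
        validation_dataset=val_ds,
        kind="sequence-to-value",    # must match model.prediction_mode if present
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
        device="cpu",
        checkpoint_callback=None,
        early_stopping_callback=None,
        lr_scheduler_callback=None,
        criterion="auto",            # resolves to nn.MSELoss()
    )
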
+    def _create_dataloaders(self, batch_size: int, shuffle: bool):
+        """Initializes the DataLoaders."""
+        # Ensure stability on MPS devices by setting num_workers to 0
+        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+        self.train_loader = DataLoader(
+            dataset=self.train_dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type),
+            drop_last=True  # Drops the last batch if incomplete; selecting a good batch size is key.
+        )
+
+        self.validation_loader = DataLoader(
+            dataset=self.validation_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type)
+        )
+
+    def _train_step(self):
+        self.model.train()
+        running_loss = 0.0
+        total_samples = 0
+
+        for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
+            # Create a log dictionary for the batch
+            batch_logs = {
+                PyTorchLogKeys.BATCH_INDEX: batch_idx,
+                PyTorchLogKeys.BATCH_SIZE: features.size(0)
+            }
+            self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
+
+            features, target = features.to(self.device), target.to(self.device)
+            self.optimizer.zero_grad()
 
-    def to_cpu(self):
+            output = self.model(features)
+
+            # --- Label Type/Shape Correction ---
+            # Ensure target is float for MSELoss
+            target = target.float()
+
+            # For seq-to-val, models might output [N, 1] but target is [N].
+            if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+                if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                    output = output.squeeze(1)
+
+            # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
+            elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+                if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
+                    output = output.squeeze(-1)
+
+            loss = self.criterion(output, target)
+
+            loss.backward()
+            self.optimizer.step()
+
+            # Calculate batch loss and update running loss for the epoch
+            batch_loss = loss.item()
+            batch_size = features.size(0)
+            running_loss += batch_loss * batch_size  # Accumulate total loss
+            total_samples += batch_size  # total samples
+
+            # Add the batch loss to the logs and call the end-of-batch hook
+            batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
+            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+        if total_samples == 0:
+            _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
+            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
+
+        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples} # type: ignore
+
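
The squeeze logic above is easy to verify on dummy tensors:

    import torch

    # sequence-to-value: model emits [N, 1], targets come as [N]
    out = torch.zeros(8, 1)
    print(out.squeeze(1).shape)   # torch.Size([8]) -> matches target

    # sequence-to-sequence: model emits [N, Seq, 1], targets come as [N, Seq]
    out = torch.zeros(8, 24, 1)
    print(out.squeeze(-1).shape)  # torch.Size([8, 24]) -> matches target
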
+    def _validation_step(self):
+        self.model.eval()
+        running_loss = 0.0
+
+        with torch.no_grad():
+            for features, target in self.validation_loader: # type: ignore
+                features, target = features.to(self.device), target.to(self.device)
+
+                output = self.model(features)
+
+                # --- Label Type/Shape Correction ---
+                target = target.float()
+
+                # For seq-to-val, models might output [N, 1] but target is [N].
+                if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+                    if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                        output = output.squeeze(1)
+
+                # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
+                elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+                    if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
+                        output = output.squeeze(-1)
+
+                loss = self.criterion(output, target)
+
+                running_loss += loss.item() * features.size(0)
+
+        if not self.validation_loader.dataset: # type: ignore
+            _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+            return {PyTorchLogKeys.VAL_LOSS: 0.0}
+
+        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)} # type: ignore
+        return logs
+
+    def _predict_for_eval(self, dataloader: DataLoader):
         """
-        Moves the model to the CPU and updates the trainer's device setting.
+        Private method to yield model predictions batch by batch for evaluation.
 
-        This is useful for running operations that require the CPU.
+        Yields:
+            tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
+                   y_prob_batch is always None for sequence tasks.
         """
-        self.device = torch.device('cpu')
+        self.model.eval()
         self.model.to(self.device)
-        _LOGGER.info("Trainer and model moved to CPU.")
+
+        with torch.no_grad():
+            for features, target in dataloader:
+                features = features.to(self.device)
+                output = self.model(features).cpu()
+
+                y_pred_batch = output.numpy()
+                y_prob_batch = None  # Not applicable for sequence regression
+                y_true_batch = target.numpy()
+
+                yield y_pred_batch, y_prob_batch, y_true_batch
+
+    def evaluate(self,
+                 save_dir: Union[str, Path],
+                 model_checkpoint: Union[Path, Literal["latest", "current"]],
+                 test_data: Optional[Union[DataLoader, Dataset]] = None,
+                 val_format_configuration: Optional[Union[SequenceValueMetricsFormat,
+                                                          SequenceSequenceMetricsFormat]] = None,
+                 test_format_configuration: Optional[Union[SequenceValueMetricsFormat,
+                                                           SequenceSequenceMetricsFormat]] = None):
+        """
+        Evaluates the model, routing to the correct evaluation function.
+
+        Args:
+            model_checkpoint (Path | "latest" | "current"):
+                - Path to a valid checkpoint for the model.
+                - If 'latest', the latest checkpoint will be loaded.
+                - If 'current', use the current state of the trained model.
+            save_dir (str | Path): Directory to save all reports and plots.
+            test_data (DataLoader | Dataset | None): Optional test data.
+            val_format_configuration: Optional configuration for validation metrics.
+            test_format_configuration: Optional configuration for test metrics.
+        """
+        # Validate model checkpoint
+        if isinstance(model_checkpoint, Path):
+            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+        elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+            checkpoint_validated = model_checkpoint
+        else:
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.LATEST}', or '{MagicWords.CURRENT}'.")
+            raise ValueError()
+
+        # Validate val configuration
+        if val_format_configuration is not None:
+            if not isinstance(val_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
+                _LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
+                raise ValueError()
+
+        # Validate directory
+        save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+        # Validate test data and dispatch
+        if test_data is not None:
+            if not isinstance(test_data, (DataLoader, Dataset)):
+                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                raise ValueError()
+            test_data_validated = test_data
 
-    def to_device(self, device: str):
+            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+            self._evaluate(save_dir=validation_metrics_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None,
+                           format_configuration=val_format_configuration)
+
+            # Validate test configuration
+            test_configuration_validated = None
+            if test_format_configuration is not None:
+                if not isinstance(test_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
+                    warning_message_type = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
+                    if val_format_configuration is not None:
+                        warning_message_type += " 'val_format_configuration' will be used."
+                        test_configuration_validated = val_format_configuration
+                    else:
+                        warning_message_type += " Using default format."
+                    _LOGGER.warning(warning_message_type)
+                else:
+                    test_configuration_validated = test_format_configuration
+
+            # Dispatch test set
+            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+            self._evaluate(save_dir=test_metrics_path,
+                           model_checkpoint="current",
+                           data=test_data_validated,
+                           format_configuration=test_configuration_validated)
+        else:
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+            self._evaluate(save_dir=save_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None,
+                           format_configuration=val_format_configuration)
+
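
A usage sketch of the routing and fallback above (`trainer` is a fitted DragonSequenceTrainer; `test_ds` is a placeholder):

    # Validation + test evaluation with default metric formatting:
    trainer.evaluate(save_dir="outputs/seq_eval",
                     model_checkpoint="latest",
                     test_data=test_ds)

    # A SequenceValueMetricsFormat / SequenceSequenceMetricsFormat instance may be
    # passed for either split; if the test one is invalid, the validation
    # configuration is reused (see the warning logic above).
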
+    def _evaluate(self,
+                  save_dir: Union[str, Path],
+                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  data: Optional[Union[DataLoader, Dataset]],
+                  format_configuration: object):
         """
-        Moves the model to the specified device and updates the trainer's device setting.
+        Private evaluation helper.
+        """
+        eval_loader = None
+
+        # load model checkpoint
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+            raise ValueError()
+
+        # Dataloader
+        if isinstance(data, DataLoader):
+            eval_loader = data
+        elif isinstance(data, Dataset):
+            # Create a new loader from the provided dataset
+            eval_loader = DataLoader(data,
+                                     batch_size=self._batch_size,
+                                     shuffle=False,
+                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                     pin_memory=(self.device.type == "cuda"))
+        else:  # data is None, use the trainer's default validation dataset
+            if self.validation_dataset is None:
+                _LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
+                raise ValueError()
+            eval_loader = DataLoader(self.validation_dataset,
+                                     batch_size=self._batch_size,
+                                     shuffle=False,
+                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                     pin_memory=(self.device.type == "cuda"))
+
+        if eval_loader is None:
+            _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
+            raise ValueError()
+
+        all_preds, _, all_true = [], [], []
+        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
+            if y_pred_b is not None: all_preds.append(y_pred_b)
+            if y_true_b is not None: all_true.append(y_true_b)
+
+        if not all_true:
+            _LOGGER.error("Evaluation failed: No data was processed.")
+            return
+
+        y_pred = np.concatenate(all_preds)
+        y_true = np.concatenate(all_true)
+
+        # --- Routing Logic ---
+        if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+            config = None
+            if format_configuration and isinstance(format_configuration, SequenceValueMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceValueMetricsFormat.")
+
+            sequence_to_value_metrics(y_true=y_true,
+                                      y_pred=y_pred,
+                                      save_dir=save_dir,
+                                      config=config)
+
+        elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+            config = None
+            if format_configuration and isinstance(format_configuration, SequenceSequenceMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceSequenceMetricsFormat.")
+
+            sequence_to_sequence_metrics(y_true=y_true,
+                                         y_pred=y_pred,
+                                         save_dir=save_dir,
+                                         config=config)
+
+    def finalize_model_training(self,
+                                save_dir: Union[str, Path],
+                                model_checkpoint: Union[Path, Literal['latest', 'current']],
+                                finalize_config: FinalizeSequencePrediction):
+        """
+        Saves a finalized, "inference-ready" model state to a .pth file.
+
+        This method saves the model's `state_dict` and the final epoch number.
 
         Args:
-            device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+            save_dir (Union[str, Path]): The directory to save the finalized model.
+            model_checkpoint (Union[Path, Literal["latest", "current"]]):
+                - Path: Loads the model state from a specific checkpoint file.
+                - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                - "current": Uses the model's state as it is at the end of the `fit()` call.
+            finalize_config (FinalizeSequencePrediction): A data class instance specific to the ML task containing task-specific metadata required for inference.
         """
-        self.device = self._validate_device(device)
-        self.model.to(self.device)
-        _LOGGER.info(f"Trainer and model moved to {self.device}.")
+        if not isinstance(finalize_config, FinalizeSequencePrediction):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeSequencePrediction', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+
+        # handle save path
+        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+        full_path = dir_path / finalize_config.filename
+
+        # handle checkpoint
+        self._load_model_state_for_finalizing(model_checkpoint)
+
+        # Create finalized data
+        finalized_data = {
+            PyTorchCheckpointKeys.EPOCH: self.epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+        }
+
+        if finalize_config.sequence_length is not None:
+            finalized_data[PyTorchCheckpointKeys.SEQUENCE_LENGTH] = finalize_config.sequence_length
+        if finalize_config.initial_sequence is not None:
+            finalized_data[PyTorchCheckpointKeys.INITIAL_SEQUENCE] = finalize_config.initial_sequence
+
+        torch.save(finalized_data, full_path)
+
+        _LOGGER.info(f"Finalized model file saved to '{full_path}'")
 
 
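
A closing usage sketch for the sequence finalizer; the `FinalizeSequencePrediction` fields shown are inferred from the attributes read above, and the import path is an assumption:

    from ml_tools.ML_configuration import FinalizeSequencePrediction

    config = FinalizeSequencePrediction(
        filename="forecaster.pth",
        sequence_length=24,       # stored for inference-time window sizing
        initial_sequence=None,    # optional warm-start sequence
    )
    trainer.finalize_model_training(save_dir="outputs/final",
                                    model_checkpoint="latest",
                                    finalize_config=config)
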
  def info():