dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (48)
  1. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
  2. dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
  3. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/ETL_cleaning.py +20 -20
  5. ml_tools/ETL_engineering.py +23 -25
  6. ml_tools/GUI_tools.py +20 -20
  7. ml_tools/MICE_imputation.py +207 -5
  8. ml_tools/ML_callbacks.py +43 -26
  9. ml_tools/ML_configuration.py +788 -0
  10. ml_tools/ML_datasetmaster.py +303 -448
  11. ml_tools/ML_evaluation.py +351 -93
  12. ml_tools/ML_evaluation_multi.py +139 -42
  13. ml_tools/ML_inference.py +290 -209
  14. ml_tools/ML_models.py +33 -106
  15. ml_tools/ML_models_advanced.py +323 -0
  16. ml_tools/ML_optimization.py +12 -12
  17. ml_tools/ML_scaler.py +11 -11
  18. ml_tools/ML_sequence_datasetmaster.py +341 -0
  19. ml_tools/ML_sequence_evaluation.py +219 -0
  20. ml_tools/ML_sequence_inference.py +391 -0
  21. ml_tools/ML_sequence_models.py +139 -0
  22. ml_tools/ML_trainer.py +1604 -179
  23. ml_tools/ML_utilities.py +351 -4
  24. ml_tools/ML_vision_datasetmaster.py +1540 -0
  25. ml_tools/ML_vision_evaluation.py +284 -0
  26. ml_tools/ML_vision_inference.py +405 -0
  27. ml_tools/ML_vision_models.py +641 -0
  28. ml_tools/ML_vision_transformers.py +284 -0
  29. ml_tools/PSO_optimization.py +6 -6
  30. ml_tools/SQL.py +4 -4
  31. ml_tools/_keys.py +171 -0
  32. ml_tools/_schema.py +1 -1
  33. ml_tools/custom_logger.py +37 -14
  34. ml_tools/data_exploration.py +502 -93
  35. ml_tools/ensemble_evaluation.py +54 -11
  36. ml_tools/ensemble_inference.py +7 -33
  37. ml_tools/ensemble_learning.py +1 -1
  38. ml_tools/math_utilities.py +1 -1
  39. ml_tools/optimization_tools.py +2 -2
  40. ml_tools/path_manager.py +5 -5
  41. ml_tools/serde.py +2 -2
  42. ml_tools/utilities.py +192 -4
  43. dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
  44. ml_tools/RNN_forecast.py +0 -56
  45. ml_tools/keys.py +0 -87
  46. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
  47. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
  48. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_trainer.py CHANGED
@@ -1,79 +1,101 @@
1
- from typing import List, Literal, Union, Optional
1
+ from typing import List, Literal, Union, Optional, Callable, Dict, Any
2
2
  from pathlib import Path
3
3
  from torch.utils.data import DataLoader, Dataset
4
4
  import torch
5
5
  from torch import nn
6
6
  import numpy as np
7
+ from abc import ABC, abstractmethod
7
8
 
8
- from .ML_callbacks import Callback, History, TqdmProgressBar, ModelCheckpoint
9
+ from .path_manager import make_fullpath
10
+ from .ML_callbacks import _Callback, History, TqdmProgressBar, DragonModelCheckpoint, DragonEarlyStopping, DragonLRScheduler
9
11
  from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
10
12
  from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
13
+ from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
14
+ from .ML_sequence_evaluation import sequence_to_sequence_metrics, sequence_to_value_metrics
15
+ from .ML_configuration import (RegressionMetricsFormat,
16
+ MultiTargetRegressionMetricsFormat,
17
+ BinaryClassificationMetricsFormat,
18
+ MultiClassClassificationMetricsFormat,
19
+ BinaryImageClassificationMetricsFormat,
20
+ MultiClassImageClassificationMetricsFormat,
21
+ MultiLabelBinaryClassificationMetricsFormat,
22
+ BinarySegmentationMetricsFormat,
23
+ MultiClassSegmentationMetricsFormat,
24
+ SequenceValueMetricsFormat,
25
+ SequenceSequenceMetricsFormat,
26
+
27
+ FinalizeBinaryClassification,
28
+ FinalizeBinarySegmentation,
29
+ FinalizeBinaryImageClassification,
30
+ FinalizeMultiClassClassification,
31
+ FinalizeMultiClassImageClassification,
32
+ FinalizeMultiClassSegmentation,
33
+ FinalizeMultiLabelBinaryClassification,
34
+ FinalizeMultiTargetRegression,
35
+ FinalizeRegression,
36
+ FinalizeObjectDetection,
37
+ FinalizeSequencePrediction)
38
+
11
39
  from ._script_info import _script_info
12
- from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
40
+ from ._keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys, MLTaskKeys, MagicWords, DragonTrainerKeys
13
41
  from ._logger import _LOGGER
14
- from .path_manager import make_fullpath
15
42
 
16
43
 
17
44
  __all__ = [
18
- "MLTrainer"
45
+ "DragonTrainer",
46
+ "DragonDetectionTrainer",
47
+ "DragonSequenceTrainer"
19
48
  ]
20
49
 
21
-
22
- class MLTrainer:
23
- def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
24
- kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification"],
25
- criterion: nn.Module, optimizer: torch.optim.Optimizer,
26
- device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
27
- """
28
- Automates the training process of a PyTorch Model.
29
-
30
- Built-in Callbacks: `History`, `TqdmProgressBar`
31
-
32
- Args:
33
- model (nn.Module): The PyTorch model to train.
34
- train_dataset (Dataset): The training dataset.
35
- test_dataset (Dataset): The testing/validation dataset.
36
- kind (str): Can be 'regression', 'classification', 'multi_target_regression', or 'multi_label_classification'.
37
- criterion (nn.Module): The loss function.
38
- optimizer (torch.optim.Optimizer): The optimizer.
39
- device (str): The device to run training on ('cpu', 'cuda', 'mps').
40
- dataloader_workers (int): Subprocesses for data loading.
41
- callbacks (List[Callback] | None): A list of callbacks to use during training.
42
-
43
- Note:
44
- - For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`.
50
+ class _BaseDragonTrainer(ABC):
51
+ """
52
+ Abstract base class for Dragon Trainers.
45
53
 
46
- - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice.
47
-
48
- - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem.
49
- """
50
- if kind not in ["regression", "classification", "multi_target_regression", "multi_label_classification"]:
51
- raise ValueError(f"'{kind}' is not a valid task type.")
54
+ Handles the common training loop orchestration, checkpointing, callback
55
+ management, and device handling. Subclasses must implement the
56
+ task-specific logic (dataloaders, train/val steps, evaluation).
57
+ """
58
+ def __init__(self,
59
+ model: nn.Module,
60
+ optimizer: torch.optim.Optimizer,
61
+ device: Union[Literal['cuda', 'mps', 'cpu'],str],
62
+ dataloader_workers: int = 2,
63
+ checkpoint_callback: Optional[DragonModelCheckpoint] = None,
64
+ early_stopping_callback: Optional[DragonEarlyStopping] = None,
65
+ lr_scheduler_callback: Optional[DragonLRScheduler] = None,
66
+ extra_callbacks: Optional[List[_Callback]] = None):
52
67
 
53
68
  self.model = model
54
- self.train_dataset = train_dataset
55
- self.test_dataset = test_dataset
56
- self.kind = kind
57
- self.criterion = criterion
58
69
  self.optimizer = optimizer
59
70
  self.scheduler = None
60
71
  self.device = self._validate_device(device)
61
72
  self.dataloader_workers = dataloader_workers
62
73
 
63
- # Callback handler - History and TqdmProgressBar are added by default
74
+ # Callback handler
64
75
  default_callbacks = [History(), TqdmProgressBar()]
65
- user_callbacks = callbacks if callbacks is not None else []
76
+
77
+ self._checkpoint_callback = None
78
+ if checkpoint_callback:
79
+ default_callbacks.append(checkpoint_callback)
80
+ self._checkpoint_callback = checkpoint_callback
81
+ if early_stopping_callback:
82
+ default_callbacks.append(early_stopping_callback)
83
+ if lr_scheduler_callback:
84
+ default_callbacks.append(lr_scheduler_callback)
85
+
86
+ user_callbacks = extra_callbacks if extra_callbacks is not None else []
66
87
  self.callbacks = default_callbacks + user_callbacks
67
88
  self._set_trainer_on_callbacks()
68
89
 
69
90
  # Internal state
70
- self.train_loader = None
71
- self.test_loader = None
72
- self.history = {}
91
+ self.train_loader: Optional[DataLoader] = None
92
+ self.validation_loader: Optional[DataLoader] = None
93
+ self.history: Dict[str, List[Any]] = {}
73
94
  self.epoch = 0
74
95
  self.epochs = 0 # Total epochs for the fit run
75
96
  self.start_epoch = 1
76
97
  self.stop_training = False
98
+ self._batch_size = 10
77
99
 
78
100
  def _validate_device(self, device: str) -> torch.device:
79
101
  """Validates the selected device and returns a torch.device object."""
@@ -91,32 +113,10 @@ class MLTrainer:
91
113
  for callback in self.callbacks:
92
114
  callback.set_trainer(self)
93
115
 
94
- def _create_dataloaders(self, batch_size: int, shuffle: bool):
95
- """Initializes the DataLoaders."""
96
- # Ensure stability on MPS devices by setting num_workers to 0
97
- loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
98
-
99
- self.train_loader = DataLoader(
100
- dataset=self.train_dataset,
101
- batch_size=batch_size,
102
- shuffle=shuffle,
103
- num_workers=loader_workers,
104
- pin_memory=("cuda" in self.device.type),
105
- drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
106
- )
107
-
108
- self.test_loader = DataLoader(
109
- dataset=self.test_dataset,
110
- batch_size=batch_size,
111
- shuffle=False,
112
- num_workers=loader_workers,
113
- pin_memory=("cuda" in self.device.type)
114
- )
115
-
116
116
  def _load_checkpoint(self, path: Union[str, Path]):
117
117
  """Loads a training checkpoint to resume training."""
118
118
  p = make_fullpath(path, enforce="file")
119
- _LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
119
+ _LOGGER.info(f"Loading checkpoint from '{p.name}'...")
120
120
 
121
121
  try:
122
122
  checkpoint = torch.load(p, map_location=self.device)
@@ -127,7 +127,16 @@ class MLTrainer:
127
127
 
128
128
  self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
129
129
  self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
130
- self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
130
+ self.epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0)
131
+ self.start_epoch = self.epoch + 1 # Resume on the *next* epoch
132
+
133
+ # --- Load History ---
134
+ if PyTorchCheckpointKeys.HISTORY in checkpoint:
135
+ self.history = checkpoint[PyTorchCheckpointKeys.HISTORY]
136
+ _LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
137
+ else:
138
+ _LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
139
+ self.history = {} # Ensure it's at least an empty dict
131
140
 
132
141
  # --- Scheduler State Loading Logic ---
133
142
  scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
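
Per the loading logic above, a resumable checkpoint now carries the model and optimizer state, the last completed epoch, the training history, and optionally a scheduler state and best score. A hedged inspection sketch (the concrete key strings behind `PyTorchCheckpointKeys` are not shown in this diff):

    import torch
    from ml_tools._keys import PyTorchCheckpointKeys

    checkpoint = torch.load("checkpoint.pth", map_location="cpu")
    print(checkpoint[PyTorchCheckpointKeys.EPOCH])            # last completed epoch
    history = checkpoint.get(PyTorchCheckpointKeys.HISTORY)   # per-epoch logs, if saved
    if history is not None:
        print({key: len(values) for key, values in history.items()})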
@@ -157,7 +166,7 @@ class MLTrainer:
157
166
 
158
167
  # Restore callback states
159
168
  for cb in self.callbacks:
160
- if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
169
+ if isinstance(cb, DragonModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
161
170
  cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
162
171
  _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
163
172
 
@@ -168,7 +177,8 @@ class MLTrainer:
168
177
  raise
169
178
 
170
179
  def fit(self,
171
- epochs: int = 10,
180
+ save_dir: Union[str,Path],
181
+ epochs: int = 100,
172
182
  batch_size: int = 10,
173
183
  shuffle: bool = True,
174
184
  resume_from_checkpoint: Optional[Union[str, Path]] = None):
@@ -178,20 +188,15 @@ class MLTrainer:
178
188
  Returns the "History" callback dictionary.
179
189
 
180
190
  Args:
191
+ save_dir (str | Path): Directory to save the loss plot.
181
192
  epochs (int): The total number of epochs to train for.
182
193
  batch_size (int): The number of samples per batch.
183
194
  shuffle (bool): Whether to shuffle the training data at each epoch.
184
195
  resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
185
-
186
- Note:
187
- For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
188
- automatically aligns the model's output tensor with the target tensor's
189
- shape using `output.view_as(target)`. This handles the common case
190
- where a model outputs a shape of `[batch_size, 1]` and the target has a
191
- shape of `[batch_size]`.
192
196
  """
193
197
  self.epochs = epochs
194
- self._create_dataloaders(batch_size, shuffle)
198
+ self._batch_size = batch_size
199
+ self._create_dataloaders(self._batch_size, shuffle) # type: ignore
195
200
  self.model.to(self.device)
196
201
 
197
202
  if resume_from_checkpoint:
@@ -202,11 +207,19 @@ class MLTrainer:
202
207
 
203
208
  self._callbacks_hook('on_train_begin')
204
209
 
210
+ if not self.train_loader:
211
+ _LOGGER.error("Train loader is not initialized.")
212
+ raise ValueError()
213
+
214
+ if not self.validation_loader:
215
+ _LOGGER.error("Validation loader is not initialized.")
216
+ raise ValueError()
217
+
205
218
  for epoch in range(self.start_epoch, self.epochs + 1):
206
219
  self.epoch = epoch
207
- epoch_logs = {}
220
+ epoch_logs: Dict[str, Any] = {}
208
221
  self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
209
-
222
+
210
223
  train_logs = self._train_step()
211
224
  epoch_logs.update(train_logs)
212
225
 
@@ -220,11 +233,204 @@ class MLTrainer:
220
233
  break
221
234
 
222
235
  self._callbacks_hook('on_train_end')
236
+
237
+ # Training History
238
+ plot_losses(self.history, save_dir=save_dir)
239
+
223
240
  return self.history
241
+
242
+ def _callbacks_hook(self, method_name: str, *args, **kwargs):
243
+ """Calls the specified method on all callbacks."""
244
+ for callback in self.callbacks:
245
+ method = getattr(callback, method_name)
246
+ method(*args, **kwargs)
247
+
248
+ def to_cpu(self):
249
+ """
250
+ Moves the model to the CPU and updates the trainer's device setting.
251
+
252
+ This is useful for running operations that require the CPU.
253
+ """
254
+ self.device = torch.device('cpu')
255
+ self.model.to(self.device)
256
+ _LOGGER.info("Trainer and model moved to CPU.")
257
+
258
+ def to_device(self, device: str):
259
+ """
260
+ Moves the model to the specified device and updates the trainer's device setting.
261
+
262
+ Args:
263
+ device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
264
+ """
265
+ self.device = self._validate_device(device)
266
+ self.model.to(self.device)
267
+ _LOGGER.info(f"Trainer and model moved to {self.device}.")
268
+
269
+ def _load_model_state_for_finalizing(self, model_checkpoint: Union[Path, Literal['latest', 'current']]):
270
+ """
271
+ Private helper to load the correct model state_dict based on user's choice.
272
+ This is called by finalize_model_training() in subclasses.
273
+ """
274
+ if isinstance(model_checkpoint, Path):
275
+ self._load_checkpoint(path=model_checkpoint)
276
+ elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
277
+ path_to_latest = self._checkpoint_callback.best_checkpoint_path
278
+ self._load_checkpoint(path_to_latest)
279
+ elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
280
+ _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
281
+ raise ValueError()
282
+ elif model_checkpoint == MagicWords.CURRENT:
283
+ pass
284
+ else:
285
+ _LOGGER.error(f"Unknown 'model_checkpoint' received '{model_checkpoint}'.")
286
+ raise ValueError()
287
+
288
+ # --- Abstract Methods ---
289
+ # These must be implemented by subclasses
290
+
291
+ @abstractmethod
292
+ def _create_dataloaders(self, batch_size: int, shuffle: bool):
293
+ """Initializes the DataLoaders."""
294
+ raise NotImplementedError
295
+
296
+ @abstractmethod
297
+ def _train_step(self) -> Dict[str, float]:
298
+ """Runs a single training epoch."""
299
+ raise NotImplementedError
300
+
301
+ @abstractmethod
302
+ def _validation_step(self) -> Dict[str, float]:
303
+ """Runs a single validation epoch."""
304
+ raise NotImplementedError
305
+
306
+ @abstractmethod
307
+ def evaluate(self, *args, **kwargs):
308
+ """Runs the full model evaluation."""
309
+ raise NotImplementedError
310
+
311
+ @abstractmethod
312
+ def _evaluate(self, *args, **kwargs):
313
+ """Internal evaluation helper."""
314
+ raise NotImplementedError
315
+
316
+ @abstractmethod
317
+ def finalize_model_training(self, *args, **kwargs):
318
+ """Saves the finalized model for inference."""
319
+ raise NotImplementedError
320
+
321
+
322
+ # --- DragonTrainer ----
323
+ class DragonTrainer(_BaseDragonTrainer):
324
+ def __init__(self,
325
+ model: nn.Module,
326
+ train_dataset: Dataset,
327
+ validation_dataset: Dataset,
328
+ kind: Literal["regression", "binary classification", "multiclass classification",
329
+ "multitarget regression", "multilabel binary classification",
330
+ "binary segmentation", "multiclass segmentation", "binary image classification", "multiclass image classification"],
331
+ optimizer: torch.optim.Optimizer,
332
+ device: Union[Literal['cuda', 'mps', 'cpu'],str],
333
+ checkpoint_callback: Optional[DragonModelCheckpoint],
334
+ early_stopping_callback: Optional[DragonEarlyStopping],
335
+ lr_scheduler_callback: Optional[DragonLRScheduler],
336
+ extra_callbacks: Optional[List[_Callback]] = None,
337
+ criterion: Union[nn.Module,Literal["auto"]] = "auto",
338
+ dataloader_workers: int = 2):
339
+ """
340
+ Automates the training process of a PyTorch Model.
341
+
342
+ Built-in Callbacks: `History`, `TqdmProgressBar`
343
+
344
+ Args:
345
+ model (nn.Module): The PyTorch model to train.
346
+ train_dataset (Dataset): The training dataset.
347
+ validation_dataset (Dataset): The validation dataset.
348
+ kind (str): The task type; used to route loss selection, training, and evaluation to the correct process.
349
+ criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task
350
+ optimizer (torch.optim.Optimizer): The optimizer.
351
+ device (str): The device to run training on ('cpu', 'cuda', 'mps').
352
+ dataloader_workers (int): Subprocesses for data loading.
353
+ extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
354
+
355
+ Note:
356
+ - For **regression** and **multitarget regression** tasks, suggested criteria include `nn.MSELoss` or `nn.L1Loss`. The model should output as many values as there are targets.
357
+
358
+ - For **single-label, binary classification**, `nn.BCEWithLogitsLoss` is the standard choice. The model should output a single logit.
224
359
 
360
+ - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice. The model should output as many logits as there are classes.
361
+
362
+ - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem. The model should output 1 logit per binary target.
363
+
364
+ - For **binary segmentation** tasks, `nn.BCEWithLogitsLoss` is common. The model should output a single logit.
365
+
366
+ - For **multiclass segmentation** tasks, `nn.CrossEntropyLoss` is the standard. The model should output as many logits as there are classes.
367
+ """
368
+ # Call the base class constructor with common parameters
369
+ super().__init__(
370
+ model=model,
371
+ optimizer=optimizer,
372
+ device=device,
373
+ dataloader_workers=dataloader_workers,
374
+ checkpoint_callback=checkpoint_callback,
375
+ early_stopping_callback=early_stopping_callback,
376
+ lr_scheduler_callback=lr_scheduler_callback,
377
+ extra_callbacks=extra_callbacks
378
+ )
379
+
380
+ if kind not in [MLTaskKeys.REGRESSION,
381
+ MLTaskKeys.BINARY_CLASSIFICATION,
382
+ MLTaskKeys.MULTICLASS_CLASSIFICATION,
383
+ MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION,
384
+ MLTaskKeys.MULTITARGET_REGRESSION,
385
+ MLTaskKeys.BINARY_SEGMENTATION,
386
+ MLTaskKeys.MULTICLASS_SEGMENTATION,
387
+ MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
388
+ MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
389
+ raise ValueError(f"'{kind}' is not a valid task type.")
390
+
391
+ self.train_dataset = train_dataset
392
+ self.validation_dataset = validation_dataset
393
+ self.kind = kind
394
+ self._classification_threshold: float = 0.5
395
+
396
+ # loss function
397
+ if criterion == "auto":
398
+ if kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
399
+ self.criterion = nn.MSELoss()
400
+ elif kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
401
+ self.criterion = nn.BCEWithLogitsLoss()
402
+ elif kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
403
+ self.criterion = nn.CrossEntropyLoss()
404
+ else:
405
+ self.criterion = criterion
406
+
407
+ def _create_dataloaders(self, batch_size: int, shuffle: bool):
408
+ """Initializes the DataLoaders."""
409
+ # Ensure stability on MPS devices by setting num_workers to 0
410
+ loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
411
+
412
+ self.train_loader = DataLoader(
413
+ dataset=self.train_dataset,
414
+ batch_size=batch_size,
415
+ shuffle=shuffle,
416
+ num_workers=loader_workers,
417
+ pin_memory=("cuda" in self.device.type),
418
+ drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
419
+ )
420
+
421
+ self.validation_loader = DataLoader(
422
+ dataset=self.validation_dataset,
423
+ batch_size=batch_size,
424
+ shuffle=False,
425
+ num_workers=loader_workers,
426
+ pin_memory=("cuda" in self.device.type)
427
+ )
428
+
225
429
  def _train_step(self):
226
430
  self.model.train()
227
431
  running_loss = 0.0
432
+ total_samples = 0
433
+
228
434
  for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
229
435
  # Create a log dictionary for the batch
230
436
  batch_logs = {
@@ -238,9 +444,21 @@ class MLTrainer:
238
444
 
239
445
  output = self.model(features)
240
446
 
241
- # Apply shape correction only for single-target regression
242
- if self.kind == "regression":
243
- output = output.view_as(target)
447
+ # --- Label Type/Shape Correction ---
448
+ # Cast target to float for BCE-based losses
449
+ if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
450
+ target = target.float()
451
+
452
+ # Reshape output to match target for single-logit tasks
453
+ if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
454
+ # If model outputs [N, 1] and target is [N], squeeze output
455
+ if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
456
+ output = output.squeeze(1)
457
+
458
+ if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
459
+ # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
460
+ if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
461
+ output = output.squeeze(1)
244
462
 
245
463
  loss = self.criterion(output, target)
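
The shape alignment above can be checked in isolation. A tiny pure-PyTorch sketch of the `[N, 1]` to `[N]` squeeze that lets `nn.BCEWithLogitsLoss` consume a single-logit output against a flat target:

    import torch
    from torch import nn

    output = torch.randn(8, 1)                   # model logits, shape [N, 1]
    target = torch.randint(0, 2, (8,)).float()   # binary labels, shape [N]

    if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
        output = output.squeeze(1)               # -> shape [N], matching the target

    loss = nn.BCEWithLogitsLoss()(output, target)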
246
464
 
@@ -249,34 +467,58 @@ class MLTrainer:
249
467
 
250
468
  # Calculate batch loss and update running loss for the epoch
251
469
  batch_loss = loss.item()
252
- running_loss += batch_loss * features.size(0)
470
+ batch_size = features.size(0)
471
+ running_loss += batch_loss * batch_size # Accumulate total loss
472
+ total_samples += batch_size # total samples
253
473
 
254
474
  # Add the batch loss to the logs and call the end-of-batch hook
255
475
  batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
256
476
  self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
477
+
478
+ if total_samples == 0:
479
+ _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
480
+ return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
257
481
 
258
- return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore
482
+ return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples} # type: ignore
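
Dividing by an explicit `total_samples` instead of `len(dataset)` keeps the epoch loss a true sample-weighted mean whenever the loader processes fewer samples than the dataset holds (e.g. `drop_last=True`). The arithmetic in miniature:

    # mean loss and size of each batch actually processed this epoch
    batch_losses = [0.9, 0.7, 0.8]
    batch_sizes = [32, 32, 16]

    running_loss = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    total_samples = sum(batch_sizes)   # 80: divide by what was seen, not len(dataset)
    epoch_loss = running_loss / total_samples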
259
483
 
260
484
  def _validation_step(self):
261
485
  self.model.eval()
262
486
  running_loss = 0.0
487
+
263
488
  with torch.no_grad():
264
- for features, target in self.test_loader: # type: ignore
489
+ for features, target in self.validation_loader: # type: ignore
265
490
  features, target = features.to(self.device), target.to(self.device)
266
491
 
267
492
  output = self.model(features)
268
- # Apply shape correction only for single-target regression
269
- if self.kind == "regression":
270
- output = output.view_as(target)
493
+
494
+ # --- Label Type/Shape Correction ---
495
+ # Cast target to float for BCE-based losses
496
+ if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
497
+ target = target.float()
498
+
499
+ # Reshape output to match target for single-logit tasks
500
+ if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
501
+ # If model outputs [N, 1] and target is [N], squeeze output
502
+ if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
503
+ output = output.squeeze(1)
504
+
505
+ if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
506
+ # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
507
+ if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
508
+ output = output.squeeze(1)
271
509
 
272
510
  loss = self.criterion(output, target)
273
511
 
274
512
  running_loss += loss.item() * features.size(0)
513
+
514
+ if not self.validation_loader.dataset: # type: ignore
515
+ _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
516
+ return {PyTorchLogKeys.VAL_LOSS: 0.0}
275
517
 
276
- logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)} # type: ignore
518
+ logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)} # type: ignore
277
519
  return logs
278
520
 
279
- def _predict_for_eval(self, dataloader: DataLoader, classification_threshold: float = 0.5):
521
+ def _predict_for_eval(self, dataloader: DataLoader):
280
522
  """
281
523
  Private method to yield model predictions batch by batch for evaluation.
282
524
 
@@ -287,77 +529,301 @@ class MLTrainer:
287
529
  """
288
530
  self.model.eval()
289
531
  self.model.to(self.device)
532
+
290
533
  with torch.no_grad():
291
534
  for features, target in dataloader:
292
535
  features = features.to(self.device)
293
536
  output = self.model(features).cpu()
294
- y_true_batch = target.numpy()
295
537
 
296
538
  y_pred_batch = None
297
539
  y_prob_batch = None
540
+ y_true_batch = None
298
541
 
299
- if self.kind in ["regression", "multi_target_regression"]:
542
+ if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
300
543
  y_pred_batch = output.numpy()
544
+ y_true_batch = target.numpy()
545
+
546
+ elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
547
+ # Assumes model output is [N, 1] (a single logit)
548
+ # Squeeze output from [N, 1] to [N] if necessary
549
+ if output.ndim == 2 and output.shape[1] == 1:
550
+ output = output.squeeze(1)
551
+
552
+ probs_pos = torch.sigmoid(output) # Probability of positive class
553
+ preds = (probs_pos >= self._classification_threshold).int()
554
+ y_pred_batch = preds.numpy()
555
+ # For metrics (like ROC AUC), we often need probs for *both* classes
556
+ # Create an [N, 2] array: [prob_class_0, prob_class_1]
557
+ probs_neg = 1.0 - probs_pos
558
+ y_prob_batch = torch.stack([probs_neg, probs_pos], dim=1).numpy()
559
+ y_true_batch = target.numpy()
301
560
 
302
- elif self.kind == "classification":
561
+ elif self.kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
562
+ num_classes = output.shape[1]
563
+ if num_classes < 3:
564
+ # Optional: warn the user they are using the wrong kind
565
+ wrong_class = MLTaskKeys.MULTICLASS_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
566
+ recommended_class = MLTaskKeys.BINARY_CLASSIFICATION if self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION else MLTaskKeys.BINARY_IMAGE_CLASSIFICATION
567
+ _LOGGER.warning(f"'{wrong_class}' kind used with {num_classes} classes. Consider using '{recommended_class}' instead.")
568
+
303
569
  probs = torch.softmax(output, dim=1)
304
570
  preds = torch.argmax(probs, dim=1)
305
571
  y_pred_batch = preds.numpy()
306
572
  y_prob_batch = probs.numpy()
573
+ y_true_batch = target.numpy()
307
574
 
308
- elif self.kind == "multi_label_classification":
575
+ elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
309
576
  probs = torch.sigmoid(output)
310
- preds = (probs >= classification_threshold).int()
577
+ preds = (probs >= self._classification_threshold).int()
311
578
  y_pred_batch = preds.numpy()
312
579
  y_prob_batch = probs.numpy()
580
+ y_true_batch = target.numpy()
581
+
582
+ elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
583
+ # Assumes model output is [N, 1, H, W] (logits for positive class)
584
+ probs_pos = torch.sigmoid(output) # Shape [N, 1, H, W]
585
+ preds = (probs_pos >= self._classification_threshold).int() # Shape [N, 1, H, W]
586
+
587
+ # Squeeze preds to [N, H, W] (class indices 0 or 1)
588
+ y_pred_batch = preds.squeeze(1).numpy()
313
589
 
314
- yield y_pred_batch, y_prob_batch, y_true_batch
590
+ # Create [N, 2, H, W] probs for consistency
591
+ probs_neg = 1.0 - probs_pos
592
+ y_prob_batch = torch.cat([probs_neg, probs_pos], dim=1).numpy()
593
+
594
+ # Handle target shape [N, 1, H, W] -> [N, H, W]
595
+ if target.ndim == 4 and target.shape[1] == 1:
596
+ target = target.squeeze(1)
597
+ y_true_batch = target.numpy()
598
+
599
+ elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION:
600
+ # output shape [N, C, H, W]
601
+ probs = torch.softmax(output, dim=1)
602
+ preds = torch.argmax(probs, dim=1) # shape [N, H, W]
603
+ y_pred_batch = preds.numpy()
604
+ y_prob_batch = probs.numpy() # Probs are [N, C, H, W]
605
+
606
+ # Handle target shape [N, 1, H, W] -> [N, H, W]
607
+ if target.ndim == 4 and target.shape[1] == 1:
608
+ target = target.squeeze(1)
609
+ y_true_batch = target.numpy()
315
610
 
316
- def evaluate(self, save_dir: Union[str, Path], data: Optional[Union[DataLoader, Dataset]] = None, classification_threshold: float = 0.5):
611
+ yield y_pred_batch, y_prob_batch, y_true_batch
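
A standalone sketch of the binary-path convention used in the generator above: one sigmoid probability per sample, thresholded into hard predictions, then stacked into an `[N, 2]` array so metrics that expect per-class probabilities (e.g. ROC AUC) see both columns:

    import torch

    logits = torch.randn(8)                      # single-logit output after squeeze, [N]
    probs_pos = torch.sigmoid(logits)            # P(class 1)
    preds = (probs_pos >= 0.5).int()             # hard labels at threshold 0.5
    probs = torch.stack([1.0 - probs_pos, probs_pos], dim=1)  # [N, 2]: [P(0), P(1)]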
612
+
613
+ def evaluate(self,
614
+ save_dir: Union[str, Path],
615
+ model_checkpoint: Union[Path, Literal["latest", "current"]],
616
+ classification_threshold: Optional[float] = None,
617
+ test_data: Optional[Union[DataLoader, Dataset]] = None,
618
+ val_format_configuration: Optional[Union[
619
+ RegressionMetricsFormat,
620
+ MultiTargetRegressionMetricsFormat,
621
+ BinaryClassificationMetricsFormat,
622
+ MultiClassClassificationMetricsFormat,
623
+ BinaryImageClassificationMetricsFormat,
624
+ MultiClassImageClassificationMetricsFormat,
625
+ MultiLabelBinaryClassificationMetricsFormat,
626
+ BinarySegmentationMetricsFormat,
627
+ MultiClassSegmentationMetricsFormat
628
+ ]]=None,
629
+ test_format_configuration: Optional[Union[
630
+ RegressionMetricsFormat,
631
+ MultiTargetRegressionMetricsFormat,
632
+ BinaryClassificationMetricsFormat,
633
+ MultiClassClassificationMetricsFormat,
634
+ BinaryImageClassificationMetricsFormat,
635
+ MultiClassImageClassificationMetricsFormat,
636
+ MultiLabelBinaryClassificationMetricsFormat,
637
+ BinarySegmentationMetricsFormat,
638
+ MultiClassSegmentationMetricsFormat,
639
+ ]]=None):
317
640
  """
318
641
  Evaluates the model, routing to the correct evaluation function based on task `kind`.
319
642
 
320
643
  Args:
644
+ model_checkpoint (Path | 'latest' | 'current'):
645
+ - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
646
+ - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
647
+ - If 'current', use the current state of the trained model up to the latest trained epoch.
321
648
  save_dir (str | Path): Directory to save all reports and plots.
322
- data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
323
- classification_threshold (float): Probability threshold for multi-label tasks.
649
+ classification_threshold (float | None): Probability threshold for tasks with a binary component (binary classification, binary segmentation, multilabel binary classification).
650
+ test_data (DataLoader | Dataset | None): Optional test data for evaluating model performance. When provided, validation and test metrics are saved to separate subdirectories.
651
+ val_format_configuration (object): Optional configuration for metric format output for the validation set.
652
+ test_format_configuration (object): Optional configuration for metric format output for the test set.
653
+ """
654
+ # Validate model checkpoint
655
+ if isinstance(model_checkpoint, Path):
656
+ checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
657
+ elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
658
+ checkpoint_validated = model_checkpoint
659
+ else:
660
+ _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
661
+ raise ValueError()
662
+
663
+ # Validate classification threshold
664
+ if self.kind not in MLTaskKeys.ALL_BINARY_TASKS:
665
+ # dummy value for tasks that do not need it
666
+ threshold_validated = 0.5
667
+ elif classification_threshold is None:
668
+ # it should have been provided for binary tasks
669
+ _LOGGER.error(f"The classification threshold must be provided for '{self.kind}'.")
670
+ raise ValueError()
671
+ elif classification_threshold <= 0.0 or classification_threshold >= 1.0:
672
+ # Invalid float
673
+ _LOGGER.error(f"A classification threshold of {classification_threshold} is invalid. Must be in the range (0.0 - 1.0).")
674
+ raise ValueError()
675
+ else:
676
+ threshold_validated = classification_threshold
677
+
678
+ # Validate val configuration
679
+ if val_format_configuration is not None:
680
+ if not isinstance(val_format_configuration, (RegressionMetricsFormat,
681
+ MultiTargetRegressionMetricsFormat,
682
+ BinaryClassificationMetricsFormat,
683
+ MultiClassClassificationMetricsFormat,
684
+ BinaryImageClassificationMetricsFormat,
685
+ MultiClassImageClassificationMetricsFormat,
686
+ MultiLabelBinaryClassificationMetricsFormat,
687
+ BinarySegmentationMetricsFormat,
688
+ MultiClassSegmentationMetricsFormat)):
689
+ _LOGGER.error(f"Invalid 'format_configuration': '{type(val_format_configuration)}'.")
690
+ raise ValueError()
691
+ else:
692
+ val_configuration_validated = val_format_configuration
693
+ else: # config is None
694
+ val_configuration_validated = None
695
+
696
+ # Validate directory
697
+ save_path = make_fullpath(save_dir, make=True, enforce="directory")
698
+
699
+ # Validate test data and dispatch
700
+ if test_data is not None:
701
+ if not isinstance(test_data, (DataLoader, Dataset)):
702
+ _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
703
+ raise ValueError()
704
+ test_data_validated = test_data
705
+
706
+ validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
707
+ test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
708
+
709
+ # Dispatch validation set
710
+ _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
711
+ self._evaluate(save_dir=validation_metrics_path,
712
+ model_checkpoint=checkpoint_validated,
713
+ classification_threshold=threshold_validated,
714
+ data=None,
715
+ format_configuration=val_configuration_validated)
716
+
717
+ # Validate test configuration
718
+ if test_format_configuration is not None:
719
+ if not isinstance(test_format_configuration, (RegressionMetricsFormat,
720
+ MultiTargetRegressionMetricsFormat,
721
+ BinaryClassificationMetricsFormat,
722
+ MultiClassClassificationMetricsFormat,
723
+ BinaryImageClassificationMetricsFormat,
724
+ MultiClassImageClassificationMetricsFormat,
725
+ MultiLabelBinaryClassificationMetricsFormat,
726
+ BinarySegmentationMetricsFormat,
727
+ MultiClassSegmentationMetricsFormat)):
728
+ warning_message_type = f"Invalid test_format_configuration': '{type(test_format_configuration)}'."
729
+ if val_configuration_validated is not None:
730
+ warning_message_type += " 'val_format_configuration' will be used for the test set metrics output."
731
+ test_configuration_validated = val_configuration_validated
732
+ else:
733
+ warning_message_type += " Using default format."
734
+ test_configuration_validated = None
735
+ _LOGGER.warning(warning_message_type)
736
+ else:
737
+ test_configuration_validated = test_format_configuration
738
+ else: #config is None
739
+ test_configuration_validated = None
740
+
741
+ # Dispatch test set
742
+ _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
743
+ self._evaluate(save_dir=test_metrics_path,
744
+ model_checkpoint="current",
745
+ classification_threshold=threshold_validated,
746
+ data=test_data_validated,
747
+ format_configuration=test_configuration_validated)
748
+ else:
749
+ # Dispatch validation set
750
+ _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
751
+ self._evaluate(save_dir=save_path,
752
+ model_checkpoint=checkpoint_validated,
753
+ classification_threshold=threshold_validated,
754
+ data=None,
755
+ format_configuration=val_configuration_validated)
756
+
757
+ def _evaluate(self,
758
+ save_dir: Union[str, Path],
759
+ model_checkpoint: Union[Path, Literal["latest", "current"]],
760
+ classification_threshold: float,
761
+ data: Optional[Union[DataLoader, Dataset]],
762
+ format_configuration: Optional[Union[
763
+ RegressionMetricsFormat,
764
+ MultiTargetRegressionMetricsFormat,
765
+ BinaryClassificationMetricsFormat,
766
+ MultiClassClassificationMetricsFormat,
767
+ BinaryImageClassificationMetricsFormat,
768
+ MultiClassImageClassificationMetricsFormat,
769
+ MultiLabelBinaryClassificationMetricsFormat,
770
+ BinarySegmentationMetricsFormat,
771
+ MultiClassSegmentationMetricsFormat
772
+ ]]=None):
773
+ """
774
+ Internal evaluation helper: loads the requested checkpoint state, builds the evaluation DataLoader, and routes to the task-specific metrics function.
324
775
  """
325
- dataset_for_names = None
776
+ dataset_for_artifacts = None
326
777
  eval_loader = None
327
-
778
+
779
+ # set threshold
780
+ self._classification_threshold = classification_threshold
781
+
782
+ # load model checkpoint
783
+ if isinstance(model_checkpoint, Path):
784
+ self._load_checkpoint(path=model_checkpoint)
785
+ elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
786
+ path_to_latest = self._checkpoint_callback.best_checkpoint_path
787
+ self._load_checkpoint(path_to_latest)
788
+ elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
789
+ _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
790
+ raise ValueError()
791
+
792
+ # Dataloader
328
793
  if isinstance(data, DataLoader):
329
794
  eval_loader = data
330
795
  # Try to get the dataset from the loader for fetching target names
331
796
  if hasattr(data, 'dataset'):
332
- dataset_for_names = data.dataset
797
+ dataset_for_artifacts = data.dataset # type: ignore
333
798
  elif isinstance(data, Dataset):
334
799
  # Create a new loader from the provided dataset
335
800
  eval_loader = DataLoader(data,
336
- batch_size=32,
801
+ batch_size=self._batch_size,
337
802
  shuffle=False,
338
803
  num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
339
804
  pin_memory=(self.device.type == "cuda"))
340
- dataset_for_names = data
805
+ dataset_for_artifacts = data
341
806
  else: # data is None, use the trainer's default test dataset
342
- if self.test_dataset is None:
343
- _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
807
+ if self.validation_dataset is None:
808
+ _LOGGER.error("Cannot evaluate. No data provided and no validation dataset available in the trainer.")
344
809
  raise ValueError()
345
810
  # Create a fresh DataLoader from the test_dataset
346
- eval_loader = DataLoader(self.test_dataset,
347
- batch_size=32,
811
+ eval_loader = DataLoader(self.validation_dataset,
812
+ batch_size=self._batch_size,
348
813
  shuffle=False,
349
814
  num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
350
815
  pin_memory=(self.device.type == "cuda"))
351
- dataset_for_names = self.test_dataset
816
+
817
+ dataset_for_artifacts = self.validation_dataset
352
818
 
353
819
  if eval_loader is None:
354
820
  _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
355
821
  raise ValueError()
356
822
 
357
- print("\n--- Model Evaluation ---")
823
+ # print("\n--- Model Evaluation ---")
358
824
 
359
825
  all_preds, all_probs, all_true = [], [], []
360
- for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader, classification_threshold):
826
+ for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
361
827
  if y_pred_b is not None: all_preds.append(y_pred_b)
362
828
  if y_prob_b is not None: all_probs.append(y_prob_b)
363
829
  if y_true_b is not None: all_true.append(y_true_b)
@@ -371,24 +837,83 @@ class MLTrainer:
371
837
  y_prob = np.concatenate(all_probs) if all_probs else None
372
838
 
373
839
  # --- Routing Logic ---
374
- if self.kind == "regression":
375
- regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir)
376
-
377
- elif self.kind == "classification":
378
- classification_metrics(save_dir, y_true, y_pred, y_prob)
379
-
380
- elif self.kind == "multi_target_regression":
840
+ # Single-target regression
841
+ if self.kind == MLTaskKeys.REGRESSION:
842
+ # Check configuration
843
+ config = None
844
+ if format_configuration and isinstance(format_configuration, RegressionMetricsFormat):
845
+ config = format_configuration
846
+ elif format_configuration:
847
+ _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
848
+
849
+ regression_metrics(y_true=y_true.flatten(),
850
+ y_pred=y_pred.flatten(),
851
+ save_dir=save_dir,
852
+ config=config)
853
+
854
+ # single target classification
855
+ elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION,
856
+ MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
857
+ MLTaskKeys.MULTICLASS_CLASSIFICATION,
858
+ MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
859
+ # get the class map if it exists
860
+ try:
861
+ class_map = dataset_for_artifacts.class_map # type: ignore
862
+ except AttributeError:
863
+ _LOGGER.warning(f"Dataset has no 'class_map' attribute. Using generics.")
864
+ class_map = None
865
+ else:
866
+ if not isinstance(class_map, dict):
867
+ _LOGGER.warning(f"Dataset has a 'class_map' attribute, but it is not a dictionary: '{type(class_map)}'.")
868
+ class_map = None
869
+
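
A sketch of the `class_map` convention the lookup above relies on: any dataset exposing a `dict` attribute named `class_map` will have its labels used in the classification report (the name-to-index orientation shown is an assumption; only the attribute name and `dict` type are confirmed by this diff):

    import torch
    from torch.utils.data import Dataset

    class LabeledDataset(Dataset):
        def __init__(self, X, y):
            self.X, self.y = X, y
            self.class_map = {"negative": 0, "positive": 1}  # assumed orientation

        def __len__(self):
            return len(self.X)

        def __getitem__(self, idx):
            return self.X[idx], self.y[idx]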
870
+ # Check configuration
871
+ config = None
872
+ if format_configuration:
873
+ if self.kind == MLTaskKeys.BINARY_CLASSIFICATION and isinstance(format_configuration, BinaryClassificationMetricsFormat):
874
+ config = format_configuration
875
+ elif self.kind == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION and isinstance(format_configuration, BinaryImageClassificationMetricsFormat):
876
+ config = format_configuration
877
+ elif self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION and isinstance(format_configuration, MultiClassClassificationMetricsFormat):
878
+ config = format_configuration
879
+ elif self.kind == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION and isinstance(format_configuration, MultiClassImageClassificationMetricsFormat):
880
+ config = format_configuration
881
+ else:
882
+ _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
883
+
884
+ classification_metrics(save_dir=save_dir,
885
+ y_true=y_true,
886
+ y_pred=y_pred,
887
+ y_prob=y_prob,
888
+ class_map=class_map,
889
+ config=config)
890
+
891
+ # multitarget regression
892
+ elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION:
381
893
  try:
382
- target_names = dataset_for_names.target_names # type: ignore
894
+ target_names = dataset_for_artifacts.target_names # type: ignore
383
895
  except AttributeError:
384
896
  num_targets = y_true.shape[1]
385
897
  target_names = [f"target_{i}" for i in range(num_targets)]
386
898
  _LOGGER.warning(f"Dataset has no 'target_names' attribute. Using generic names.")
387
- multi_target_regression_metrics(y_true, y_pred, target_names, save_dir)
388
-
389
- elif self.kind == "multi_label_classification":
899
+
900
+ # Check configuration
901
+ config = None
902
+ if format_configuration and isinstance(format_configuration, MultiTargetRegressionMetricsFormat):
903
+ config = format_configuration
904
+ elif format_configuration:
905
+ _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
906
+
907
+ multi_target_regression_metrics(y_true=y_true,
908
+ y_pred=y_pred,
909
+ target_names=target_names,
910
+ save_dir=save_dir,
911
+ config=config)
912
+
913
+ # multi-label binary classification
914
+ elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
390
915
  try:
391
- target_names = dataset_for_names.target_names # type: ignore
916
+ target_names = dataset_for_artifacts.target_names # type: ignore
392
917
  except AttributeError:
393
918
  num_targets = y_true.shape[1]
394
919
  target_names = [f"label_{i}" for i in range(num_targets)]
@@ -397,10 +922,56 @@ class MLTrainer:
397
922
  if y_prob is None:
398
923
  _LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
399
924
  return
400
- multi_label_classification_metrics(y_true, y_prob, target_names, save_dir, classification_threshold)
925
+
926
+ # Check configuration
927
+ config = None
928
+ if format_configuration and isinstance(format_configuration, MultiLabelBinaryClassificationMetricsFormat):
929
+ config = format_configuration
930
+ elif format_configuration:
931
+ _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
401
932
 
402
- print("\n--- Training History ---")
403
- plot_losses(self.history, save_dir=save_dir)
933
+ multi_label_classification_metrics(y_true=y_true,
934
+ y_pred=y_pred,
935
+ y_prob=y_prob,
936
+ target_names=target_names,
937
+ save_dir=save_dir,
938
+ config=config)
939
+
940
+ # Segmentation tasks
941
+ elif self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
942
+ class_names = None
943
+ try:
944
+ # Try to get 'classes' from VisionDatasetMaker
945
+ if hasattr(dataset_for_artifacts, 'classes'):
946
+ class_names = dataset_for_artifacts.classes # type: ignore
947
+ # Fallback for Subset
948
+ elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'): # type: ignore
949
+ class_names = dataset_for_artifacts.dataset.classes # type: ignore
950
+ except AttributeError:
951
+ pass # class_names is still None
952
+
953
+ if class_names is None:
954
+ try:
955
+ # Fallback to 'target_names'
956
+ class_names = dataset_for_artifacts.target_names # type: ignore
957
+ except AttributeError:
958
+ # Fallback to inferring from labels
959
+ labels = np.unique(y_true)
960
+ class_names = [f"Class {i}" for i in labels]
961
+ _LOGGER.warning(f"Dataset has no 'classes' or 'target_names' attribute. Using generic names.")
962
+
963
+ # Check configuration
964
+ config = None
965
+ if format_configuration and isinstance(format_configuration, (BinarySegmentationMetricsFormat, MultiClassSegmentationMetricsFormat)):
966
+ config = format_configuration
967
+ elif format_configuration:
968
+ _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
969
+
970
+ segmentation_metrics(y_true=y_true,
971
+ y_pred=y_pred,
972
+ save_dir=save_dir,
973
+ class_names=class_names,
974
+ config=config)
404
975
 
405
976
  def explain(self,
406
977
  save_dir: Union[str,Path],
@@ -408,7 +979,7 @@ class MLTrainer:
408
979
  n_samples: int = 300,
409
980
  feature_names: Optional[List[str]] = None,
410
981
  target_names: Optional[List[str]] = None,
411
- explainer_type: Literal['deep', 'kernel'] = 'deep'):
982
+ explainer_type: Literal['deep', 'kernel'] = 'kernel'):
412
983
  """
413
984
  Explains model predictions using SHAP and saves all artifacts.
414
985
 
@@ -422,41 +993,59 @@ class MLTrainer:
422
993
  explain_dataset (Dataset | None): A specific dataset to explain.
423
994
  If None, the trainer's test dataset is used.
424
995
  n_samples (int): The number of samples to use for both background and explanation.
425
- feature_names (list[str] | None): Feature names.
996
+ feature_names (list[str] | None): Feature names. If None, they are extracted from the Dataset; an error is raised if extraction fails.
426
997
  target_names (list[str] | None): Target names for multi-target tasks.
427
998
  save_dir (str | Path): Directory to save all SHAP artifacts.
428
999
  explainer_type (Literal['deep', 'kernel']): The explainer to use.
429
- - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
1000
+ - 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
430
1001
  - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples' (< 100).
431
- """
432
- # Internal helper to create a dataloader and get a random sample
1002
+ """
1003
+ # memory efficient helper
433
1004
  def _get_random_sample(dataset: Dataset, num_samples: int):
1005
+ """
1006
+ Memory-efficiently samples data from a dataset.
1007
+ """
434
1008
  if dataset is None:
435
1009
  return None
436
1010
 
1011
+ dataset_len = len(dataset) # type: ignore
1012
+ if dataset_len == 0:
1013
+ return None
1014
+
437
1015
  # For MPS devices, num_workers must be 0 to ensure stability
438
1016
  loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
439
1017
 
1018
+ # Ensure batch_size is not larger than the dataset itself
1019
+ batch_size = min(num_samples, 64, dataset_len)
1020
+
440
1021
  loader = DataLoader(
441
1022
  dataset,
442
- batch_size=64,
443
- shuffle=False,
1023
+ batch_size=batch_size,
1024
+ shuffle=True, # Shuffle to get random samples
444
1025
  num_workers=loader_workers
445
1026
  )
446
1027
 
447
- all_features = [features for features, _ in loader]
448
- if not all_features:
1028
+ collected_features = []
1029
+ num_collected = 0
1030
+
1031
+ for features, _ in loader:
1032
+ collected_features.append(features)
1033
+ num_collected += features.size(0)
1034
+ if num_collected >= num_samples:
1035
+ break # Stop once we have enough samples
1036
+
1037
+ if not collected_features:
449
1038
  return None
450
1039
 
451
- full_data = torch.cat(all_features, dim=0)
1040
+ full_data = torch.cat(collected_features, dim=0)
452
1041
 
453
- if num_samples >= full_data.size(0):
454
- return full_data
1042
+ # If we collected more than needed, trim it down
1043
+ if full_data.size(0) > num_samples:
1044
+ return full_data[:num_samples]
455
1045
 
456
- rand_indices = torch.randperm(full_data.size(0))[:num_samples]
457
- return full_data[rand_indices]
458
-
459
- print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
1046
+ return full_data
1047
+
1048
+ # print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
460
1049
 
461
1050
  # 1. Get background data from the trainer's train_dataset
462
1051
  background_data = _get_random_sample(self.train_dataset, n_samples)
@@ -465,7 +1054,7 @@ class MLTrainer:
465
1054
  return
466
1055
 
467
1056
  # 2. Determine target dataset and get explanation instances
468
- target_dataset = explain_dataset if explain_dataset is not None else self.test_dataset
1057
+ target_dataset = explain_dataset if explain_dataset is not None else self.validation_dataset
469
1058
  instances_to_explain = _get_random_sample(target_dataset, n_samples)
470
1059
  if instances_to_explain is None:
471
1060
  _LOGGER.error("Explanation dataset is empty or invalid. Skipping SHAP analysis.")
@@ -474,17 +1063,17 @@ class MLTrainer:
474
1063
  # attempt to get feature names
475
1064
  if feature_names is None:
476
1065
  # _LOGGER.info("`feature_names` not provided. Attempting to extract from dataset...")
477
- if hasattr(target_dataset, "feature_names"):
1066
+ if hasattr(target_dataset, DatasetKeys.FEATURE_NAMES):
478
1067
  feature_names = target_dataset.feature_names # type: ignore
479
1068
  else:
480
- _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
1069
+ _LOGGER.error(f"Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
481
1070
  raise ValueError()
482
1071
 
483
1072
  # move model to device
484
1073
  self.model.to(self.device)
485
1074
 
486
1075
  # 3. Call the plotting function
487
- if self.kind in ["regression", "classification"]:
1076
+ if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
488
1077
  shap_summary_plot(
489
1078
  model=self.model,
490
1079
  background_data=background_data,
@@ -494,11 +1083,11 @@ class MLTrainer:
494
1083
  explainer_type=explainer_type,
495
1084
  device=self.device
496
1085
  )
497
- elif self.kind in ["multi_target_regression", "multi_label_classification"]:
1086
+ elif self.kind in [MLTaskKeys.MULTITARGET_REGRESSION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
498
1087
  # try to get target names
499
1088
  if target_names is None:
500
1089
  target_names = []
501
- if hasattr(target_dataset, 'target_names'):
1090
+ if hasattr(target_dataset, DatasetKeys.TARGET_NAMES):
502
1091
  target_names = target_dataset.target_names # type: ignore
503
1092
  else:
504
1093
  # Infer number of targets from the model's output layer
@@ -549,7 +1138,7 @@ class MLTrainer:
549
1138
  yield attention_weights
550
1139
 
551
1140
  def explain_attention(self, save_dir: Union[str, Path],
552
- feature_names: Optional[List[str]],
1141
+ feature_names: Optional[List[str]] = None,
553
1142
  explain_dataset: Optional[Dataset] = None,
554
1143
  plot_n_features: int = 10):
555
1144
  """
@@ -559,27 +1148,32 @@ class MLTrainer:
559
1148
 
560
1149
  Args:
561
1150
  save_dir (str | Path): Directory to save the plot and summary data.
562
- feature_names (List[str] | None): Names for the features for plot labeling. If not given, generic names will be used.
1151
+ feature_names (List[str] | None): Names for the features for plot labeling. If None, they are extracted from the Dataset; an error is raised if extraction fails.
563
1152
  explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's test dataset is used.
564
1153
  plot_n_features (int): Number of top features to plot.
565
1154
  """
566
1155
 
567
- print("\n--- Attention Analysis ---")
1156
+ # print("\n--- Attention Analysis ---")
568
1157
 
569
1158
  # --- Step 1: Check if the model supports this explanation ---
570
1159
  if not getattr(self.model, 'has_interpretable_attention', False):
571
- _LOGGER.warning(
572
- "Model is not flagged for interpretable attention analysis. "
573
- "Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
574
- )
1160
+ _LOGGER.warning("Model is not compatible with interpretable attention analysis. Skipping.")
575
1161
  return
576
1162
 
577
1163
  # --- Step 2: Set up the dataloader ---
578
- dataset_to_use = explain_dataset if explain_dataset is not None else self.test_dataset
1164
+ dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
579
1165
  if not isinstance(dataset_to_use, Dataset):
580
1166
  _LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
581
1167
  return
582
1168
 
1169
+ # Get feature names
1170
+ if feature_names is None:
1171
+ if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
1172
+ feature_names = dataset_to_use.feature_names # type: ignore
1173
+ else:
1174
+ _LOGGER.error(f"Could not extract `feature_names` from the dataset for attention plot. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
1175
+ raise ValueError()
1176
+
583
1177
  explain_loader = DataLoader(
584
1178
  dataset=dataset_to_use, batch_size=32, shuffle=False,
585
1179
  num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
@@ -602,34 +1196,865 @@ class MLTrainer:
             )
         else:
             _LOGGER.error("No attention weights were collected from the model.")
-
-    def _callbacks_hook(self, method_name: str, *args, **kwargs):
-        """Calls the specified method on all callbacks."""
-        for callback in self.callbacks:
-            method = getattr(callback, method_name)
-            method(*args, **kwargs)
-
-    def to_cpu(self):
-        """
-        Moves the model to the CPU and updates the trainer's device setting.
 
-        This is useful for running operations that require the CPU.
-        """
-        self.device = torch.device('cpu')
-        self.model.to(self.device)
-        _LOGGER.info("Trainer and model moved to CPU.")
-
-    def to_device(self, device: str):
+    def finalize_model_training(self,
+                                model_checkpoint: Union[Path, Literal['latest', 'current']],
+                                save_dir: Union[str, Path],
+                                finalize_config: Union[FinalizeRegression,
+                                                       FinalizeMultiTargetRegression,
+                                                       FinalizeBinaryClassification,
+                                                       FinalizeBinaryImageClassification,
+                                                       FinalizeMultiClassClassification,
+                                                       FinalizeMultiClassImageClassification,
+                                                       FinalizeBinarySegmentation,
+                                                       FinalizeMultiClassSegmentation,
+                                                       FinalizeMultiLabelBinaryClassification]):
         """
-        Moves the model to the specified device and updates the trainer's device setting.
+        Saves a finalized, "inference-ready" model state to a .pth file.
+
+        This method saves the model's `state_dict`, the final epoch number, and optional configuration for the task at hand.
 
         Args:
-            device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+            model_checkpoint (Path | "latest" | "current"):
+                - Path: Loads the model state from a specific checkpoint file.
+                - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                - "current": Uses the model's state as it is at the end of the `fit()` call.
+            save_dir (str | Path): The directory to save the finalized model.
+            finalize_config (object): A data class instance specific to the ML task containing task-specific metadata required for inference.
         """
-        self.device = self._validate_device(device)
-        self.model.to(self.device)
-        _LOGGER.info(f"Trainer and model moved to {self.device}.")
+        if self.kind == MLTaskKeys.REGRESSION and not isinstance(finalize_config, FinalizeRegression):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeRegression', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION and not isinstance(finalize_config, FinalizeMultiTargetRegression):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiTargetRegression', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.BINARY_CLASSIFICATION and not isinstance(finalize_config, FinalizeBinaryClassification):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeBinaryClassification', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION and not isinstance(finalize_config, FinalizeBinaryImageClassification):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeBinaryImageClassification', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION and not isinstance(finalize_config, FinalizeMultiClassClassification):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiClassClassification', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION and not isinstance(finalize_config, FinalizeMultiClassImageClassification):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiClassImageClassification', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.BINARY_SEGMENTATION and not isinstance(finalize_config, FinalizeBinarySegmentation):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeBinarySegmentation', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION and not isinstance(finalize_config, FinalizeMultiClassSegmentation):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiClassSegmentation', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+        elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION and not isinstance(finalize_config, FinalizeMultiLabelBinaryClassification):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiLabelBinaryClassification', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+
+        # handle save path
+        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+        full_path = dir_path / finalize_config.filename
+
+        # handle checkpoint
+        self._load_model_state_for_finalizing(model_checkpoint)
+
+        # Create finalized data
+        finalized_data = {
+            PyTorchCheckpointKeys.EPOCH: self.epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+        }
+
+        # Parse config
+        if finalize_config.target_name is not None:
+            finalized_data[PyTorchCheckpointKeys.TARGET_NAME] = finalize_config.target_name
+        if finalize_config.target_names is not None:
+            finalized_data[PyTorchCheckpointKeys.TARGET_NAMES] = finalize_config.target_names
+        if finalize_config.classification_threshold is not None:
+            finalized_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD] = finalize_config.classification_threshold
+        if finalize_config.class_map is not None:
+            finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map
+
+        # Save model file
+        torch.save(finalized_data, full_path)
+
+        _LOGGER.info(f"Finalized model file saved to '{full_path}'")
+
+
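Usage sketch (illustrative, not part of the diff): finalizing a fitted regression trainer. The `FinalizeRegression` import path and constructor fields are assumptions inferred from the attribute accesses above; `trainer` is an assumed fitted instance.

# Sketch only, under stated assumptions.
from pathlib import Path
from ml_tools.ML_configuration import FinalizeRegression  # assumed import path

config = FinalizeRegression(filename="regressor.pth",    # assumed field
                            target_name="price")         # assumed field
trainer.finalize_model_training(model_checkpoint="latest",  # best state from DragonModelCheckpoint
                                save_dir=Path("models/final"),
                                finalize_config=config)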
+# Object Detection Trainer
+class DragonDetectionTrainer(_BaseDragonTrainer):
+    def __init__(self, model: nn.Module,
+                 train_dataset: Dataset,
+                 validation_dataset: Dataset,
+                 collate_fn: Callable, optimizer: torch.optim.Optimizer,
+                 device: Union[Literal['cuda', 'mps', 'cpu'], str],
+                 checkpoint_callback: Optional[DragonModelCheckpoint],
+                 early_stopping_callback: Optional[DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 extra_callbacks: Optional[List[_Callback]] = None,
+                 dataloader_workers: int = 2):
+        """
+        Automates the training process of an Object Detection Model (e.g., DragonFastRCNN).
+
+        Built-in Callbacks: `History`, `TqdmProgressBar`
+
+        Args:
+            model (nn.Module): The PyTorch object detection model to train.
+            train_dataset (Dataset): The training dataset.
+            validation_dataset (Dataset): The testing/validation dataset.
+            collate_fn (Callable): The collate function from `ObjectDetectionDatasetMaker.collate_fn`.
+            optimizer (torch.optim.Optimizer): The optimizer.
+            device (str): The device to run training on ('cpu', 'cuda', 'mps').
+            dataloader_workers (int): Subprocesses for data loading.
+            checkpoint_callback (DragonModelCheckpoint | None): Callback to save the model.
+            early_stopping_callback (DragonEarlyStopping | None): Callback to stop training early.
+            lr_scheduler_callback (DragonLRScheduler | None): Callback to manage the LR scheduler.
+            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
+
+        ## Note:
+        This trainer is specialized. It does not take a `criterion` because object detection models like Faster R-CNN return a dictionary of losses directly from their forward pass during training.
+        """
+        # Call the base class constructor with common parameters
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            dataloader_workers=dataloader_workers,
+            checkpoint_callback=checkpoint_callback,
+            early_stopping_callback=early_stopping_callback,
+            lr_scheduler_callback=lr_scheduler_callback,
+            extra_callbacks=extra_callbacks
+        )
+
+        self.train_dataset = train_dataset
+        self.validation_dataset = validation_dataset
+        self.kind = MLTaskKeys.OBJECT_DETECTION
+        self.collate_fn = collate_fn
+        self.criterion = None  # Criterion is handled inside the model
+
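A runnable construction sketch under stated assumptions: the toy dataset, the tuple collate, and the torchvision model stand in for `ObjectDetectionDatasetMaker` and `DragonFastRCNN`, which are not shown in this section.

# Sketch only; all names outside this module are assumptions or torchvision API.
import torch
from torch.utils.data import Dataset
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from ml_tools.ML_trainer import DragonDetectionTrainer

class ToyDetectionSet(Dataset):
    """Two random images, each with one dummy box, in torchvision target format."""
    def __len__(self):
        return 2
    def __getitem__(self, idx):
        image = torch.rand(3, 64, 64)
        target = {"boxes": torch.tensor([[4.0, 4.0, 40.0, 40.0]]),
                  "labels": torch.tensor([1])}
        return image, target

def collate_fn(batch):
    # Detection batches are ragged, so keep images/targets as tuples.
    return tuple(zip(*batch))

model = fasterrcnn_resnet50_fpn(weights=None, num_classes=2)
trainer = DragonDetectionTrainer(
    model=model,
    train_dataset=ToyDetectionSet(),
    validation_dataset=ToyDetectionSet(),
    collate_fn=collate_fn,  # normally ObjectDetectionDatasetMaker.collate_fn
    optimizer=torch.optim.SGD(model.parameters(), lr=5e-3, momentum=0.9),
    device="cpu",
    checkpoint_callback=None,
    early_stopping_callback=None,
    lr_scheduler_callback=None,
)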
+    def _create_dataloaders(self, batch_size: int, shuffle: bool):
+        """Initializes the DataLoaders with the object detection collate_fn."""
+        # Ensure stability on MPS devices by setting num_workers to 0
+        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+        self.train_loader = DataLoader(
+            dataset=self.train_dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type),
+            collate_fn=self.collate_fn,  # Use the provided collate function
+            drop_last=True
+        )
+
+        self.validation_loader = DataLoader(
+            dataset=self.validation_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type),
+            collate_fn=self.collate_fn  # Use the provided collate function
+        )
+
+    def _train_step(self):
+        self.model.train()
+        running_loss = 0.0
+        total_samples = 0
+
+        for batch_idx, (images, targets) in enumerate(self.train_loader): # type: ignore
+            # images is a tuple of tensors, targets is a tuple of dicts
+            batch_size = len(images)
+
+            # Create a log dictionary for the batch
+            batch_logs = {
+                PyTorchLogKeys.BATCH_INDEX: batch_idx,
+                PyTorchLogKeys.BATCH_SIZE: batch_size
+            }
+            self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
+
+            # Move data to device
+            images = list(img.to(self.device) for img in images)
+            targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
+
+            self.optimizer.zero_grad()
+
+            # Model returns a loss dict when in train() mode and targets are passed
+            loss_dict = self.model(images, targets)
+
+            if not loss_dict:
+                # No losses returned, skip batch
+                _LOGGER.warning(f"Model returned no losses for batch {batch_idx}. Skipping.")
+                batch_logs[PyTorchLogKeys.BATCH_LOSS] = 0
+                self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+                continue
+
+            # Sum all losses
+            loss: torch.Tensor = sum(l for l in loss_dict.values()) # type: ignore
+
+            loss.backward()
+            self.optimizer.step()
+
+            # Calculate batch loss and update running loss for the epoch
+            batch_loss = loss.item()
+            running_loss += batch_loss * batch_size
+            total_samples += batch_size  # Accumulate total samples
+
+            # Add the batch loss to the logs and call the end-of-batch hook
+            batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss # type: ignore
+            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+        # Calculate loss using the correct denominator
+        if total_samples == 0:
+            _LOGGER.warning("No samples processed in _train_step. Returning 0 loss.")
+            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
+
+        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}
+
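The `sum(loss_dict.values())` line above leans on the torchvision detection contract: in `train()` mode the model consumes `(images, targets)` and returns a dict of named scalar losses. A standalone illustration using only torchvision:

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn

model = fasterrcnn_resnet50_fpn(weights=None, num_classes=2)
model.train()

images = [torch.rand(3, 64, 64)]
targets = [{"boxes": torch.tensor([[4.0, 4.0, 40.0, 40.0]]),
            "labels": torch.tensor([1])}]

loss_dict = model(images, targets)  # e.g. loss_classifier, loss_box_reg, ...
total = sum(loss for loss in loss_dict.values())  # single scalar for backward()
total.backward()
print({name: float(value) for name, value in loss_dict.items()})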
+    def _validation_step(self):
+        self.model.train()  # Set to train mode even for validation loss calculation
+        # as model internals (e.g., proposals) might differ, but we still need loss_dict.
+        # use torch.no_grad() to prevent gradient updates.
+        running_loss = 0.0
+        total_samples = 0
+
+        with torch.no_grad():
+            for images, targets in self.validation_loader: # type: ignore
+                batch_size = len(images)
+
+                # Move data to device
+                images = list(img.to(self.device) for img in images)
+                targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
+
+                # Get loss dict
+                loss_dict = self.model(images, targets)
+
+                if not loss_dict:
+                    _LOGGER.warning("Model returned no losses during validation step. Skipping batch.")
+                    continue  # Skip if no losses
+
+                # Sum all losses
+                loss: torch.Tensor = sum(l for l in loss_dict.values()) # type: ignore
+
+                running_loss += loss.item() * batch_size
+                total_samples += batch_size  # Accumulate total samples
+
+        # Calculate loss using the correct denominator
+        if total_samples == 0:
+            _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+            return {PyTorchLogKeys.VAL_LOSS: 0.0}
+
+        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / total_samples}
+        return logs
+
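`_validation_step` keeps the model in `train()` mode because torchvision-style detection models only produce the loss dict there; in `eval()` mode they return predictions instead. A minimal demonstration of the two contracts:

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn

model = fasterrcnn_resnet50_fpn(weights=None, num_classes=2)
images = [torch.rand(3, 64, 64)]
targets = [{"boxes": torch.tensor([[4.0, 4.0, 40.0, 40.0]]),
            "labels": torch.tensor([1])}]

model.train()           # loss dict is only produced in train() mode
with torch.no_grad():   # but no gradients are accumulated
    val_losses = model(images, targets)

model.eval()            # eval() mode returns per-image prediction dicts
with torch.no_grad():
    predictions = model(images)
print(sorted(val_losses), list(predictions[0].keys()))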
+    def evaluate(self,
+                 save_dir: Union[str, Path],
+                 model_checkpoint: Union[Path, Literal["latest", "current"]],
+                 test_data: Optional[Union[DataLoader, Dataset]] = None):
+        """
+        Evaluates the model using object detection mAP metrics.
+
+        Args:
+            save_dir (str | Path): Directory to save all reports and plots.
+            model_checkpoint (Path | "latest" | "current"):
+                - Path: Loads a specific checkpoint file. The state of the trained model is overwritten in place.
+                - "latest": Loads the best checkpoint saved by `DragonModelCheckpoint`, if one was provided. The state of the trained model is overwritten in place.
+                - "current": Uses the current state of the trained model up to the latest trained epoch.
+            test_data (DataLoader | Dataset | None): Optional test data to evaluate the model performance. Validation and test metrics will be saved to subdirectories.
+        """
+        # Validate model checkpoint
+        if isinstance(model_checkpoint, Path):
+            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+        elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+            checkpoint_validated = model_checkpoint
+        else:
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
+            raise ValueError()
+
+        # Validate directory
+        save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+        # Validate test data and dispatch
+        if test_data is not None:
+            if not isinstance(test_data, (DataLoader, Dataset)):
+                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                raise ValueError()
+            test_data_validated = test_data
+
+            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+            self._evaluate(save_dir=validation_metrics_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None)  # 'None' triggers use of self.validation_dataset
+
+            # Dispatch test set
+            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+            self._evaluate(save_dir=test_metrics_path,
+                           model_checkpoint="current",  # Use 'current' state after loading checkpoint once
+                           data=test_data_validated)
+        else:
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+            self._evaluate(save_dir=save_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None)  # 'None' triggers use of self.validation_dataset
 
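Usage sketch of the two dispatch paths (illustrative; `trainer` and `test_dataset` are assumed to exist):

from pathlib import Path

# Validation metrics only, written directly into 'reports/':
trainer.evaluate(save_dir="reports", model_checkpoint="latest")  # "latest" needs a checkpoint callback

# Validation plus held-out test metrics, written to the VALIDATION_METRICS_DIR
# and TEST_METRICS_DIR subdirectories defined by DragonTrainerKeys:
trainer.evaluate(save_dir="reports",
                 model_checkpoint=Path("checkpoints/best.pth"),
                 test_data=test_dataset)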
+    def _evaluate(self,
+                  save_dir: Union[str, Path],
+                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  data: Optional[Union[DataLoader, Dataset]]):
+        """
+        Private helper that evaluates the model using object detection mAP metrics.
+
+        Args:
+            save_dir (str | Path): Directory to save all reports and plots.
+            data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal validation dataset.
+            model_checkpoint (Path | "latest" | "current"):
+                - Path: Loads a specific checkpoint file. The state of the trained model is overwritten in place.
+                - "latest": Loads the best checkpoint saved by `DragonModelCheckpoint`, if one was provided. The state of the trained model is overwritten in place.
+                - "current": Uses the current state of the trained model up to the latest trained epoch.
+        """
+        dataset_for_artifacts = None
+        eval_loader = None
+
+        # load model checkpoint
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+            raise ValueError()
+
+        # Dataloader
+        if isinstance(data, DataLoader):
+            eval_loader = data
+            if hasattr(data, 'dataset'):
+                dataset_for_artifacts = data.dataset # type: ignore
+        elif isinstance(data, Dataset):
+            # Create a new loader from the provided dataset
+            eval_loader = DataLoader(data,
+                                     batch_size=self._batch_size,
+                                     shuffle=False,
+                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                     pin_memory=(self.device.type == "cuda"),
+                                     collate_fn=self.collate_fn)
+            dataset_for_artifacts = data
+        else:  # data is None, use the trainer's default validation dataset
+            if self.validation_dataset is None:
+                _LOGGER.error("Cannot evaluate. No data provided and no validation dataset available in the trainer.")
+                raise ValueError()
+            # Create a fresh DataLoader from the validation dataset
+            eval_loader = DataLoader(
+                self.validation_dataset,
+                batch_size=self._batch_size,
+                shuffle=False,
+                num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                pin_memory=(self.device.type == "cuda"),
+                collate_fn=self.collate_fn
+            )
+            dataset_for_artifacts = self.validation_dataset
+
+        if eval_loader is None:
+            _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
+            raise ValueError()
+
+        # print("\n--- Model Evaluation ---")
+
+        all_predictions = []
+        all_targets = []
+
+        self.model.eval()  # Set model to evaluation mode
+        self.model.to(self.device)
+
+        with torch.no_grad():
+            for images, targets in eval_loader:
+                # Move images to device
+                images = list(img.to(self.device) for img in images)
+
+                # Model returns predictions when in eval() mode
+                predictions = self.model(images)
+
+                # Move predictions and targets to CPU for aggregation
+                cpu_preds = [{k: v.to('cpu') for k, v in p.items()} for p in predictions]
+                cpu_targets = [{k: v.to('cpu') for k, v in t.items()} for t in targets]
+
+                all_predictions.extend(cpu_preds)
+                all_targets.extend(cpu_targets)
+
+        if not all_targets:
+            _LOGGER.error("Evaluation failed: No data was processed.")
+            return
+
+        # Get class names from the dataset for the report
+        class_names = None
+        try:
+            # Try to get 'classes' from ObjectDetectionDatasetMaker
+            if hasattr(dataset_for_artifacts, 'classes'):
+                class_names = dataset_for_artifacts.classes # type: ignore
+            # Fallback for Subset
+            elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'): # type: ignore
+                class_names = dataset_for_artifacts.dataset.classes # type: ignore
+        except AttributeError:
+            _LOGGER.warning("Could not find 'classes' attribute on dataset. Per-class metrics will not be named.")
+
+        # --- Routing Logic ---
+        object_detection_metrics(
+            preds=all_predictions,
+            targets=all_targets,
+            save_dir=save_dir,
+            class_names=class_names,
+            print_output=False
+        )
+
+    def finalize_model_training(self,
+                                save_dir: Union[str, Path],
+                                model_checkpoint: Union[Path, Literal['latest', 'current']],
+                                finalize_config: FinalizeObjectDetection):
+        """
+        Saves a finalized, "inference-ready" model state to a .pth file.
+
+        This method saves the model's `state_dict` and the final epoch number.
+
+        Args:
+            save_dir (str | Path): The directory to save the finalized model.
+            model_checkpoint (Path | "latest" | "current"):
+                - Path: Loads the model state from a specific checkpoint file.
+                - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                - "current": Uses the model's state as it is at the end of the `fit()` call.
+            finalize_config (FinalizeObjectDetection): A data class instance specific to the ML task containing task-specific metadata required for inference.
+        """
+        if not isinstance(finalize_config, FinalizeObjectDetection):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeObjectDetection', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+
+        # handle save path
+        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+        full_path = dir_path / finalize_config.filename
+
+        # handle checkpoint
+        self._load_model_state_for_finalizing(model_checkpoint)
+
+        # Create finalized data
+        finalized_data = {
+            PyTorchCheckpointKeys.EPOCH: self.epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+        }
+
+        if finalize_config.class_map is not None:
+            finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map
+
+        torch.save(finalized_data, full_path)
+
+        _LOGGER.info(f"Finalized model file saved to '{full_path}'")
+
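Usage sketch for detection finalization; the `FinalizeObjectDetection` import path and constructor fields are assumptions based on the attribute accesses above.

# Sketch only, under stated assumptions.
from ml_tools.ML_configuration import FinalizeObjectDetection  # assumed import path

config = FinalizeObjectDetection(filename="detector.pth",                 # assumed field
                                 class_map={0: "background", 1: "defect"})  # assumed field
trainer.finalize_model_training(save_dir="models/final",
                                model_checkpoint="latest",
                                finalize_config=config)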
+# --- DragonSequenceTrainer ----
+class DragonSequenceTrainer(_BaseDragonTrainer):
+    def __init__(self,
+                 model: nn.Module,
+                 train_dataset: Dataset,
+                 validation_dataset: Dataset,
+                 kind: Literal["sequence-to-sequence", "sequence-to-value"],
+                 optimizer: torch.optim.Optimizer,
+                 device: Union[Literal['cuda', 'mps', 'cpu'], str],
+                 checkpoint_callback: Optional[DragonModelCheckpoint],
+                 early_stopping_callback: Optional[DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 extra_callbacks: Optional[List[_Callback]] = None,
+                 criterion: Union[nn.Module, Literal["auto"]] = "auto",
+                 dataloader_workers: int = 2):
+        """
+        Automates the training process of a PyTorch Sequence Model.
+
+        Built-in Callbacks: `History`, `TqdmProgressBar`
+
+        Args:
+            model (nn.Module): The PyTorch model to train.
+            train_dataset (Dataset): The training dataset.
+            validation_dataset (Dataset): The validation dataset.
+            kind (str): Used to redirect to the correct process ('sequence-to-sequence' or 'sequence-to-value').
+            criterion (nn.Module | "auto"): The loss function to use. If "auto", it is inferred from the selected task.
+            optimizer (torch.optim.Optimizer): The optimizer.
+            device (str): The device to run training on ('cpu', 'cuda', 'mps').
+            dataloader_workers (int): Subprocesses for data loading.
+            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
+        """
+        # Call the base class constructor with common parameters
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            dataloader_workers=dataloader_workers,
+            checkpoint_callback=checkpoint_callback,
+            early_stopping_callback=early_stopping_callback,
+            lr_scheduler_callback=lr_scheduler_callback,
+            extra_callbacks=extra_callbacks
+        )
+
+        if kind not in [MLTaskKeys.SEQUENCE_SEQUENCE, MLTaskKeys.SEQUENCE_VALUE]:
+            raise ValueError(f"'{kind}' is not a valid task type for DragonSequenceTrainer.")
+
+        self.train_dataset = train_dataset
+        self.validation_dataset = validation_dataset
+        self.kind = kind
+
+        # try to validate against Dragon Sequence model
+        if hasattr(self.model, "prediction_mode"):
+            key_to_check: str = self.model.prediction_mode # type: ignore
+            if not key_to_check == self.kind:
+                _LOGGER.error(f"Trainer was set for '{self.kind}', but model architecture '{self.model}' is built for '{key_to_check}'.")
+                raise RuntimeError()
+
+        # loss function
+        if criterion == "auto":
+            # Both sequence tasks are treated as regression problems
+            self.criterion = nn.MSELoss()
+        else:
+            self.criterion = criterion
+
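A minimal, runnable construction sketch; the linear stand-in model and random data are assumptions, any module producing the expected output shape works:

import torch
from torch import nn
from torch.utils.data import TensorDataset
from ml_tools.ML_trainer import DragonSequenceTrainer

# 100 windows of 24 time steps each -> predict the next scalar value.
features = torch.randn(100, 24, 1)
targets = torch.randn(100)
dataset = TensorDataset(features, targets)

model = nn.Sequential(nn.Flatten(), nn.Linear(24, 1))  # stand-in sequence model
trainer = DragonSequenceTrainer(
    model=model,
    train_dataset=dataset,
    validation_dataset=dataset,
    kind="sequence-to-value",
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    device="cpu",
    checkpoint_callback=None,
    early_stopping_callback=None,
    lr_scheduler_callback=None,
)  # criterion="auto" resolves to nn.MSELoss()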
+    def _create_dataloaders(self, batch_size: int, shuffle: bool):
+        """Initializes the DataLoaders."""
+        # Ensure stability on MPS devices by setting num_workers to 0
+        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+        self.train_loader = DataLoader(
+            dataset=self.train_dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type),
+            drop_last=True  # Drops the last batch if incomplete, selecting a good batch size is key.
+        )
+
+        self.validation_loader = DataLoader(
+            dataset=self.validation_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=loader_workers,
+            pin_memory=("cuda" in self.device.type)
+        )
+
+    def _train_step(self):
+        self.model.train()
+        running_loss = 0.0
+        total_samples = 0
+
+        for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
+            # Create a log dictionary for the batch
+            batch_logs = {
+                PyTorchLogKeys.BATCH_INDEX: batch_idx,
+                PyTorchLogKeys.BATCH_SIZE: features.size(0)
+            }
+            self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
+
+            features, target = features.to(self.device), target.to(self.device)
+            self.optimizer.zero_grad()
+
+            output = self.model(features)
+
+            # --- Label Type/Shape Correction ---
+            # Ensure target is float for MSELoss
+            target = target.float()
+
+            # For seq-to-val, models might output [N, 1] but target is [N].
+            if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+                if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                    output = output.squeeze(1)
+
+            # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
+            elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+                if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
+                    output = output.squeeze(-1)
+
+            loss = self.criterion(output, target)
+
+            loss.backward()
+            self.optimizer.step()
+
+            # Calculate batch loss and update running loss for the epoch
+            batch_loss = loss.item()
+            batch_size = features.size(0)
+            running_loss += batch_loss * batch_size  # Accumulate total loss
+            total_samples += batch_size  # Accumulate total samples
+
+            # Add the batch loss to the logs and call the end-of-batch hook
+            batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
+            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+        if total_samples == 0:
+            _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
+            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
+
+        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples} # type: ignore
+
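The squeeze above matters because `nn.MSELoss` broadcasts mismatched shapes instead of failing. A short demonstration of the `[N, 1]` versus `[N]` pitfall:

import torch
from torch import nn

output = torch.zeros(4, 1)  # model output [N, 1]
target = torch.ones(4)      # target [N]

# Without the squeeze, broadcasting silently compares a [4, 4] grid
# (PyTorch emits a UserWarning about mismatched target size):
naive = nn.MSELoss()(output, target)
fixed = nn.MSELoss()(output.squeeze(1), target)  # correct elementwise loss
print(naive.item(), fixed.item())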
+    def _validation_step(self):
+        self.model.eval()
+        running_loss = 0.0
+
+        with torch.no_grad():
+            for features, target in self.validation_loader: # type: ignore
+                features, target = features.to(self.device), target.to(self.device)
+
+                output = self.model(features)
+
+                # --- Label Type/Shape Correction ---
+                target = target.float()
+
+                # For seq-to-val, models might output [N, 1] but target is [N].
+                if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+                    if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                        output = output.squeeze(1)
+
+                # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
+                elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+                    if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
+                        output = output.squeeze(-1)
+
+                loss = self.criterion(output, target)
+
+                running_loss += loss.item() * features.size(0)
+
+        if not self.validation_loader.dataset: # type: ignore
+            _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+            return {PyTorchLogKeys.VAL_LOSS: 0.0}
+
+        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)} # type: ignore
+        return logs
+
+    def _predict_for_eval(self, dataloader: DataLoader):
+        """
+        Private method to yield model predictions batch by batch for evaluation.
+
+        Yields:
+            tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
+                y_prob_batch is always None for sequence tasks.
+        """
+        self.model.eval()
+        self.model.to(self.device)
+
+        with torch.no_grad():
+            for features, target in dataloader:
+                features = features.to(self.device)
+                output = self.model(features).cpu()
+
+                y_pred_batch = output.numpy()
+                y_prob_batch = None  # Not applicable for sequence regression
+                y_true_batch = target.numpy()
+
+                yield y_pred_batch, y_prob_batch, y_true_batch
+
+    def evaluate(self,
+                 save_dir: Union[str, Path],
+                 model_checkpoint: Union[Path, Literal["latest", "current"]],
+                 test_data: Optional[Union[DataLoader, Dataset]] = None,
+                 val_format_configuration: Optional[Union[SequenceValueMetricsFormat,
+                                                          SequenceSequenceMetricsFormat]] = None,
+                 test_format_configuration: Optional[Union[SequenceValueMetricsFormat,
+                                                           SequenceSequenceMetricsFormat]] = None):
+        """
+        Evaluates the model, routing to the correct evaluation function.
+
+        Args:
+            model_checkpoint (Path | "latest" | "current"):
+                - Path: Loads a specific checkpoint file.
+                - "latest": Loads the best checkpoint saved by `DragonModelCheckpoint`.
+                - "current": Uses the current state of the trained model.
+            save_dir (str | Path): Directory to save all reports and plots.
+            test_data (DataLoader | Dataset | None): Optional test data.
+            val_format_configuration: Optional configuration for validation metrics.
+            test_format_configuration: Optional configuration for test metrics.
+        """
+        # Validate model checkpoint
+        if isinstance(model_checkpoint, Path):
+            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+        elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+            checkpoint_validated = model_checkpoint
+        else:
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.LATEST}', or '{MagicWords.CURRENT}'.")
+            raise ValueError()
+
+        # Validate val configuration
+        if val_format_configuration is not None:
+            if not isinstance(val_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
+                _LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
+                raise ValueError()
+
+        # Validate directory
+        save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+        # Validate test data and dispatch
+        if test_data is not None:
+            if not isinstance(test_data, (DataLoader, Dataset)):
+                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                raise ValueError()
+            test_data_validated = test_data
+
+            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+            self._evaluate(save_dir=validation_metrics_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None,
+                           format_configuration=val_format_configuration)
+
+            # Validate test configuration
+            test_configuration_validated = None
+            if test_format_configuration is not None:
+                if not isinstance(test_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
+                    warning_message_type = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
+                    if val_format_configuration is not None:
+                        warning_message_type += " 'val_format_configuration' will be used."
+                        test_configuration_validated = val_format_configuration
+                    else:
+                        warning_message_type += " Using default format."
+                    _LOGGER.warning(warning_message_type)
+                else:
+                    test_configuration_validated = test_format_configuration
+
+            # Dispatch test set
+            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+            self._evaluate(save_dir=test_metrics_path,
+                           model_checkpoint="current",
+                           data=test_data_validated,
+                           format_configuration=test_configuration_validated)
+        else:
+            # Dispatch validation set
+            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+            self._evaluate(save_dir=save_path,
+                           model_checkpoint=checkpoint_validated,
+                           data=None,
+                           format_configuration=val_format_configuration)
+
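Usage sketch (illustrative; `trainer` is an assumed fitted `DragonSequenceTrainer`):

# Validation split only, with default metric formatting:
trainer.evaluate(save_dir="reports/seq", model_checkpoint="current")

# Passing a test set writes validation and test metrics to separate
# subdirectories; if 'test_format_configuration' is missing or invalid,
# the validation configuration (or the default format) is used instead.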
+    def _evaluate(self,
+                  save_dir: Union[str, Path],
+                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  data: Optional[Union[DataLoader, Dataset]],
+                  format_configuration: object):
+        """
+        Private evaluation helper.
+        """
+        eval_loader = None
+
+        # load model checkpoint
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+            raise ValueError()
+
+        # Dataloader
+        if isinstance(data, DataLoader):
+            eval_loader = data
+        elif isinstance(data, Dataset):
+            # Create a new loader from the provided dataset
+            eval_loader = DataLoader(data,
+                                     batch_size=self._batch_size,
+                                     shuffle=False,
+                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                     pin_memory=(self.device.type == "cuda"))
+        else:  # data is None, use the trainer's default validation dataset
+            if self.validation_dataset is None:
+                _LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
+                raise ValueError()
+            eval_loader = DataLoader(self.validation_dataset,
+                                     batch_size=self._batch_size,
+                                     shuffle=False,
+                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                     pin_memory=(self.device.type == "cuda"))
+
+        if eval_loader is None:
+            _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
+            raise ValueError()
+
+        all_preds, _, all_true = [], [], []
+        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
+            if y_pred_b is not None: all_preds.append(y_pred_b)
+            if y_true_b is not None: all_true.append(y_true_b)
+
+        if not all_true:
+            _LOGGER.error("Evaluation failed: No data was processed.")
+            return
+
+        y_pred = np.concatenate(all_preds)
+        y_true = np.concatenate(all_true)
+
+        # --- Routing Logic ---
+        if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+            config = None
+            if format_configuration and isinstance(format_configuration, SequenceValueMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceValueMetricsFormat.")
+
+            sequence_to_value_metrics(y_true=y_true,
+                                      y_pred=y_pred,
+                                      save_dir=save_dir,
+                                      config=config)
+
+        elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+            config = None
+            if format_configuration and isinstance(format_configuration, SequenceSequenceMetricsFormat):
+                config = format_configuration
+            elif format_configuration:
+                _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceSequenceMetricsFormat.")
+
+            sequence_to_sequence_metrics(y_true=y_true,
+                                         y_pred=y_pred,
+                                         save_dir=save_dir,
+                                         config=config)
+
+    def finalize_model_training(self,
+                                save_dir: Union[str, Path],
+                                model_checkpoint: Union[Path, Literal['latest', 'current']],
+                                finalize_config: FinalizeSequencePrediction):
+        """
+        Saves a finalized, "inference-ready" model state to a .pth file.
+
+        This method saves the model's `state_dict` and the final epoch number.
+
+        Args:
+            save_dir (str | Path): The directory to save the finalized model.
+            model_checkpoint (Path | "latest" | "current"):
+                - Path: Loads the model state from a specific checkpoint file.
+                - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                - "current": Uses the model's state as it is at the end of the `fit()` call.
+            finalize_config (FinalizeSequencePrediction): A data class instance specific to the ML task containing task-specific metadata required for inference.
+        """
+        if not isinstance(finalize_config, FinalizeSequencePrediction):
+            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeSequencePrediction', but got {type(finalize_config).__name__}.")
+            raise TypeError()
+
+        # handle save path
+        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+        full_path = dir_path / finalize_config.filename
+
+        # handle checkpoint
+        self._load_model_state_for_finalizing(model_checkpoint)
+
+        # Create finalized data
+        finalized_data = {
+            PyTorchCheckpointKeys.EPOCH: self.epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+        }
+
+        if finalize_config.sequence_length is not None:
+            finalized_data[PyTorchCheckpointKeys.SEQUENCE_LENGTH] = finalize_config.sequence_length
+        if finalize_config.initial_sequence is not None:
+            finalized_data[PyTorchCheckpointKeys.INITIAL_SEQUENCE] = finalize_config.initial_sequence
+
+        torch.save(finalized_data, full_path)
+
+        _LOGGER.info(f"Finalized model file saved to '{full_path}'")
+
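Usage sketch; the `FinalizeSequencePrediction` import path and fields are assumptions inferred from the checkpoint keys written above.

# Sketch only, under stated assumptions.
from ml_tools.ML_configuration import FinalizeSequencePrediction  # assumed import path

config = FinalizeSequencePrediction(filename="forecaster.pth",  # assumed field
                                    sequence_length=24,         # assumed field
                                    initial_sequence=None)      # assumed field
trainer.finalize_model_training(save_dir="models/final",
                                model_checkpoint="current",
                                finalize_config=config)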
 
 def info():
     _script_info(__all__)