dragon-ml-toolbox 12.13.0__py3-none-any.whl → 14.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic. Click here for more details.

Files changed (35) hide show
  1. {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/METADATA +11 -2
  2. dragon_ml_toolbox-14.3.0.dist-info/RECORD +48 -0
  3. {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/MICE_imputation.py +207 -5
  5. ml_tools/ML_callbacks.py +40 -8
  6. ml_tools/ML_datasetmaster.py +200 -261
  7. ml_tools/ML_evaluation.py +29 -17
  8. ml_tools/ML_evaluation_multi.py +13 -10
  9. ml_tools/ML_inference.py +14 -5
  10. ml_tools/ML_models.py +135 -55
  11. ml_tools/ML_models_advanced.py +323 -0
  12. ml_tools/ML_optimization.py +49 -36
  13. ml_tools/ML_trainer.py +560 -30
  14. ml_tools/ML_utilities.py +302 -4
  15. ml_tools/ML_vision_datasetmaster.py +1352 -0
  16. ml_tools/ML_vision_evaluation.py +260 -0
  17. ml_tools/ML_vision_inference.py +428 -0
  18. ml_tools/ML_vision_models.py +627 -0
  19. ml_tools/ML_vision_transformers.py +58 -0
  20. ml_tools/PSO_optimization.py +5 -1
  21. ml_tools/_ML_vision_recipe.py +88 -0
  22. ml_tools/__init__.py +1 -0
  23. ml_tools/_schema.py +96 -0
  24. ml_tools/custom_logger.py +37 -14
  25. ml_tools/data_exploration.py +576 -138
  26. ml_tools/keys.py +51 -1
  27. ml_tools/math_utilities.py +1 -1
  28. ml_tools/optimization_tools.py +65 -86
  29. ml_tools/serde.py +78 -17
  30. ml_tools/utilities.py +192 -3
  31. dragon_ml_toolbox-12.13.0.dist-info/RECORD +0 -41
  32. ml_tools/ML_simple_optimization.py +0 -413
  33. {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/WHEEL +0 -0
  34. {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/licenses/LICENSE +0 -0
  35. {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_trainer.py CHANGED
@@ -1,26 +1,29 @@
1
- from typing import List, Literal, Union, Optional
1
+ from typing import List, Literal, Union, Optional, Callable, Dict, Any, Tuple
2
2
  from pathlib import Path
3
3
  from torch.utils.data import DataLoader, Dataset
4
4
  import torch
5
5
  from torch import nn
6
6
  import numpy as np
7
7
 
8
- from .ML_callbacks import Callback, History, TqdmProgressBar
8
+ from .ML_callbacks import Callback, History, TqdmProgressBar, ModelCheckpoint
9
9
  from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
10
10
  from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
11
11
  from ._script_info import _script_info
12
- from .keys import PyTorchLogKeys
12
+ from .keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys
13
13
  from ._logger import _LOGGER
14
+ from .path_manager import make_fullpath
15
+ from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
14
16
 
15
17
 
16
18
  __all__ = [
17
- "MLTrainer"
19
+ "MLTrainer",
20
+ "ObjectDetectionTrainer"
18
21
  ]
19
22
 
20
23
 
21
24
  class MLTrainer:
22
25
  def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
23
- kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification"],
26
+ kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification", "segmentation"],
24
27
  criterion: nn.Module, optimizer: torch.optim.Optimizer,
25
28
  device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
26
29
  """
@@ -32,7 +35,7 @@ class MLTrainer:
32
35
  model (nn.Module): The PyTorch model to train.
33
36
  train_dataset (Dataset): The training dataset.
34
37
  test_dataset (Dataset): The testing/validation dataset.
35
- kind (str): Can be 'regression', 'classification', 'multi_target_regression', or 'multi_label_classification'.
38
+ kind (str): Can be 'regression', 'classification', 'multi_target_regression', 'multi_label_classification', or 'segmentation'.
36
39
  criterion (nn.Module): The loss function.
37
40
  optimizer (torch.optim.Optimizer): The optimizer.
38
41
  device (str): The device to run training on ('cpu', 'cuda', 'mps').
@@ -45,8 +48,10 @@ class MLTrainer:
45
48
  - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice.
46
49
 
47
50
  - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem.
51
+
52
+ - For **segmentation** tasks, `nn.CrossEntropyLoss` (for multi-class) or `nn.BCEWithLogitsLoss` (for binary) are common.
48
53
  """
49
- if kind not in ["regression", "classification", "multi_target_regression", "multi_label_classification"]:
54
+ if kind not in ["regression", "classification", "multi_target_regression", "multi_label_classification", "segmentation"]:
50
55
  raise ValueError(f"'{kind}' is not a valid task type.")
51
56
 
52
57
  self.model = model
@@ -55,6 +60,7 @@ class MLTrainer:
55
60
  self.kind = kind
56
61
  self.criterion = criterion
57
62
  self.optimizer = optimizer
63
+ self.scheduler = None
58
64
  self.device = self._validate_device(device)
59
65
  self.dataloader_workers = dataloader_workers
60
66
 
@@ -70,7 +76,9 @@ class MLTrainer:
70
76
  self.history = {}
71
77
  self.epoch = 0
72
78
  self.epochs = 0 # Total epochs for the fit run
79
+ self.start_epoch = 1
73
80
  self.stop_training = False
81
+ self._batch_size = 10
74
82
 
75
83
  def _validate_device(self, device: str) -> torch.device:
76
84
  """Validates the selected device and returns a torch.device object."""
@@ -109,8 +117,66 @@ class MLTrainer:
109
117
  num_workers=loader_workers,
110
118
  pin_memory=("cuda" in self.device.type)
111
119
  )
120
+
121
+ def _load_checkpoint(self, path: Union[str, Path]):
122
+ """Loads a training checkpoint to resume training."""
123
+ p = make_fullpath(path, enforce="file")
124
+ _LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
125
+
126
+ try:
127
+ checkpoint = torch.load(p, map_location=self.device)
128
+
129
+ if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
130
+ _LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
131
+ raise KeyError()
112
132
 
113
- def fit(self, epochs: int = 10, batch_size: int = 10, shuffle: bool = True):
133
+ self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
134
+ self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
135
+ self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
136
+
137
+ # --- Scheduler State Loading Logic ---
138
+ scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
139
+ scheduler_object_exists = self.scheduler is not None
140
+
141
+ if scheduler_object_exists and scheduler_state_exists:
142
+ # Case 1: Both exist. Attempt to load.
143
+ try:
144
+ self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
145
+ scheduler_name = self.scheduler.__class__.__name__
146
+ _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
147
+ except Exception as e:
148
+ # Loading failed, likely a mismatch
149
+ scheduler_name = self.scheduler.__class__.__name__
150
+ _LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
151
+ raise e
152
+
153
+ elif scheduler_object_exists and not scheduler_state_exists:
154
+ # Case 2: Scheduler provided, but no state in checkpoint.
155
+ scheduler_name = self.scheduler.__class__.__name__
156
+ _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
157
+
158
+ elif not scheduler_object_exists and scheduler_state_exists:
159
+ # Case 3: State in checkpoint, but no scheduler provided.
160
+ _LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
161
+ raise ValueError()
162
+
163
+ # Restore callback states
164
+ for cb in self.callbacks:
165
+ if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
166
+ cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
167
+ _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
168
+
169
+ _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
170
+
171
+ except Exception as e:
172
+ _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
173
+ raise
174
+
175
+ def fit(self,
176
+ epochs: int = 10,
177
+ batch_size: int = 10,
178
+ shuffle: bool = True,
179
+ resume_from_checkpoint: Optional[Union[str, Path]] = None):
114
180
  """
115
181
  Starts the training-validation process of the model.
116
182
 
@@ -120,6 +186,7 @@ class MLTrainer:
120
186
  epochs (int): The total number of epochs to train for.
121
187
  batch_size (int): The number of samples per batch.
122
188
  shuffle (bool): Whether to shuffle the training data at each epoch.
189
+ resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
123
190
 
124
191
  Note:
125
192
  For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
@@ -129,18 +196,22 @@ class MLTrainer:
129
196
  shape of `[batch_size]`.
130
197
  """
131
198
  self.epochs = epochs
132
- self._create_dataloaders(batch_size, shuffle)
199
+ self._batch_size = batch_size
200
+ self._create_dataloaders(self._batch_size, shuffle)
133
201
  self.model.to(self.device)
134
202
 
203
+ if resume_from_checkpoint:
204
+ self._load_checkpoint(resume_from_checkpoint)
205
+
135
206
  # Reset stop_training flag on the trainer
136
207
  self.stop_training = False
137
208
 
138
- self.callbacks_hook('on_train_begin')
209
+ self._callbacks_hook('on_train_begin')
139
210
 
140
- for epoch in range(1, self.epochs + 1):
211
+ for epoch in range(self.start_epoch, self.epochs + 1):
141
212
  self.epoch = epoch
142
213
  epoch_logs = {}
143
- self.callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
214
+ self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
144
215
 
145
216
  train_logs = self._train_step()
146
217
  epoch_logs.update(train_logs)
@@ -148,13 +219,13 @@ class MLTrainer:
148
219
  val_logs = self._validation_step()
149
220
  epoch_logs.update(val_logs)
150
221
 
151
- self.callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
222
+ self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
152
223
 
153
224
  # Check the early stopping flag
154
225
  if self.stop_training:
155
226
  break
156
227
 
157
- self.callbacks_hook('on_train_end')
228
+ self._callbacks_hook('on_train_end')
158
229
  return self.history
159
230
 
160
231
  def _train_step(self):
@@ -166,7 +237,7 @@ class MLTrainer:
166
237
  PyTorchLogKeys.BATCH_INDEX: batch_idx,
167
238
  PyTorchLogKeys.BATCH_SIZE: features.size(0)
168
239
  }
169
- self.callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
240
+ self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
170
241
 
171
242
  features, target = features.to(self.device), target.to(self.device)
172
243
  self.optimizer.zero_grad()
@@ -188,7 +259,7 @@ class MLTrainer:
188
259
 
189
260
  # Add the batch loss to the logs and call the end-of-batch hook
190
261
  batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
191
- self.callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
262
+ self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
192
263
 
193
264
  return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore
194
265
 
@@ -226,25 +297,40 @@ class MLTrainer:
226
297
  for features, target in dataloader:
227
298
  features = features.to(self.device)
228
299
  output = self.model(features).cpu()
229
- y_true_batch = target.numpy()
230
300
 
231
301
  y_pred_batch = None
232
302
  y_prob_batch = None
303
+ y_true_batch = None
233
304
 
234
305
  if self.kind in ["regression", "multi_target_regression"]:
235
306
  y_pred_batch = output.numpy()
307
+ y_true_batch = target.numpy()
236
308
 
237
309
  elif self.kind == "classification":
238
310
  probs = torch.softmax(output, dim=1)
239
311
  preds = torch.argmax(probs, dim=1)
240
312
  y_pred_batch = preds.numpy()
241
313
  y_prob_batch = probs.numpy()
314
+ y_true_batch = target.numpy()
242
315
 
243
316
  elif self.kind == "multi_label_classification":
244
317
  probs = torch.sigmoid(output)
245
318
  preds = (probs >= classification_threshold).int()
246
319
  y_pred_batch = preds.numpy()
247
320
  y_prob_batch = probs.numpy()
321
+ y_true_batch = target.numpy()
322
+
323
+ elif self.kind == "segmentation":
324
+ # output shape [N, C, H, W]
325
+ probs = torch.softmax(output, dim=1)
326
+ preds = torch.argmax(probs, dim=1) # shape [N, H, W]
327
+ y_pred_batch = preds.numpy()
328
+ y_prob_batch = probs.numpy() # Probs are [N, C, H, W]
329
+
330
+ # Handle target shape [N, 1, H, W] -> [N, H, W]
331
+ if target.ndim == 4 and target.shape[1] == 1:
332
+ target = target.squeeze(1)
333
+ y_true_batch = target.numpy()
248
334
 
249
335
  yield y_pred_batch, y_prob_batch, y_true_batch
250
336
 
@@ -268,7 +354,7 @@ class MLTrainer:
268
354
  elif isinstance(data, Dataset):
269
355
  # Create a new loader from the provided dataset
270
356
  eval_loader = DataLoader(data,
271
- batch_size=32,
357
+ batch_size=self._batch_size,
272
358
  shuffle=False,
273
359
  num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
274
360
  pin_memory=(self.device.type == "cuda"))
@@ -279,10 +365,11 @@ class MLTrainer:
279
365
  raise ValueError()
280
366
  # Create a fresh DataLoader from the test_dataset
281
367
  eval_loader = DataLoader(self.test_dataset,
282
- batch_size=32,
368
+ batch_size=self._batch_size,
283
369
  shuffle=False,
284
370
  num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
285
371
  pin_memory=(self.device.type == "cuda"))
372
+
286
373
  dataset_for_names = self.test_dataset
287
374
 
288
375
  if eval_loader is None:
@@ -333,7 +420,31 @@ class MLTrainer:
333
420
  _LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
334
421
  return
335
422
  multi_label_classification_metrics(y_true, y_prob, target_names, save_dir, classification_threshold)
423
+
424
+ elif self.kind == "segmentation":
425
+ class_names = None
426
+ try:
427
+ # Try to get 'classes' from VisionDatasetMaker
428
+ if hasattr(dataset_for_names, 'classes'):
429
+ class_names = dataset_for_names.classes # type: ignore
430
+ # Fallback for Subset
431
+ elif hasattr(dataset_for_names, 'dataset') and hasattr(dataset_for_names.dataset, 'classes'): # type: ignore
432
+ class_names = dataset_for_names.dataset.classes # type: ignore
433
+ except AttributeError:
434
+ pass # class_names is still None
336
435
 
436
+ if class_names is None:
437
+ try:
438
+ # Fallback to 'target_names'
439
+ class_names = dataset_for_names.target_names # type: ignore
440
+ except AttributeError:
441
+ # Fallback to inferring from labels
442
+ labels = np.unique(y_true)
443
+ class_names = [f"Class {i}" for i in labels]
444
+ _LOGGER.warning(f"Dataset has no 'classes' or 'target_names' attribute. Using generic names.")
445
+
446
+ segmentation_metrics(y_true, y_pred, save_dir, class_names=class_names)
447
+
337
448
  print("\n--- Training History ---")
338
449
  plot_losses(self.history, save_dir=save_dir)
339
450
 
@@ -343,7 +454,7 @@ class MLTrainer:
343
454
  n_samples: int = 300,
344
455
  feature_names: Optional[List[str]] = None,
345
456
  target_names: Optional[List[str]] = None,
346
- explainer_type: Literal['deep', 'kernel'] = 'deep'):
457
+ explainer_type: Literal['deep', 'kernel'] = 'kernel'):
347
458
  """
348
459
  Explains model predictions using SHAP and saves all artifacts.
349
460
 
@@ -357,11 +468,11 @@ class MLTrainer:
357
468
  explain_dataset (Dataset | None): A specific dataset to explain.
358
469
  If None, the trainer's test dataset is used.
359
470
  n_samples (int): The number of samples to use for both background and explanation.
360
- feature_names (list[str] | None): Feature names.
471
+ feature_names (list[str] | None): Feature names. If None, the names will be extracted from the Dataset and raise an error on failure.
361
472
  target_names (list[str] | None): Target names for multi-target tasks.
362
473
  save_dir (str | Path): Directory to save all SHAP artifacts.
363
474
  explainer_type (Literal['deep', 'kernel']): The explainer to use.
364
- - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
475
+ - 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
365
476
  - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples'< 100.
366
477
  """
367
478
  # Internal helper to create a dataloader and get a random sample
@@ -409,10 +520,10 @@ class MLTrainer:
409
520
  # attempt to get feature names
410
521
  if feature_names is None:
411
522
  # _LOGGER.info("`feature_names` not provided. Attempting to extract from dataset...")
412
- if hasattr(target_dataset, "feature_names"):
523
+ if hasattr(target_dataset, DatasetKeys.FEATURE_NAMES):
413
524
  feature_names = target_dataset.feature_names # type: ignore
414
525
  else:
415
- _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
526
+ _LOGGER.error(f"Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
416
527
  raise ValueError()
417
528
 
418
529
  # move model to device
@@ -433,7 +544,7 @@ class MLTrainer:
433
544
  # try to get target names
434
545
  if target_names is None:
435
546
  target_names = []
436
- if hasattr(target_dataset, 'target_names'):
547
+ if hasattr(target_dataset, DatasetKeys.TARGET_NAMES):
437
548
  target_names = target_dataset.target_names # type: ignore
438
549
  else:
439
550
  # Infer number of targets from the model's output layer
@@ -484,7 +595,7 @@ class MLTrainer:
484
595
  yield attention_weights
485
596
 
486
597
  def explain_attention(self, save_dir: Union[str, Path],
487
- feature_names: Optional[List[str]],
598
+ feature_names: Optional[List[str]] = None,
488
599
  explain_dataset: Optional[Dataset] = None,
489
600
  plot_n_features: int = 10):
490
601
  """
@@ -494,7 +605,7 @@ class MLTrainer:
494
605
 
495
606
  Args:
496
607
  save_dir (str | Path): Directory to save the plot and summary data.
497
- feature_names (List[str] | None): Names for the features for plot labeling. If not given, generic names will be used.
608
+ feature_names (List[str] | None): Names for the features for plot labeling. If None, the names will be extracted from the Dataset and raise an error on failure.
498
609
  explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's test dataset is used.
499
610
  plot_n_features (int): Number of top features to plot.
500
611
  """
@@ -504,8 +615,7 @@ class MLTrainer:
504
615
  # --- Step 1: Check if the model supports this explanation ---
505
616
  if not getattr(self.model, 'has_interpretable_attention', False):
506
617
  _LOGGER.warning(
507
- "Model is not flagged for interpretable attention analysis. "
508
- "Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
618
+ "Model is not flagged for interpretable attention analysis. Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
509
619
  )
510
620
  return
511
621
 
@@ -515,6 +625,14 @@ class MLTrainer:
515
625
  _LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
516
626
  return
517
627
 
628
+ # Get feature names
629
+ if feature_names is None:
630
+ if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
631
+ feature_names = dataset_to_use.feature_names # type: ignore
632
+ else:
633
+ _LOGGER.error(f"Could not extract `feature_names` from the dataset for attention plot. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
634
+ raise ValueError()
635
+
518
636
  explain_loader = DataLoader(
519
637
  dataset=dataset_to_use, batch_size=32, shuffle=False,
520
638
  num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
@@ -538,11 +656,423 @@ class MLTrainer:
538
656
  else:
539
657
  _LOGGER.error("No attention weights were collected from the model.")
540
658
 
541
- def callbacks_hook(self, method_name: str, *args, **kwargs):
659
+ def _callbacks_hook(self, method_name: str, *args, **kwargs):
542
660
  """Calls the specified method on all callbacks."""
543
661
  for callback in self.callbacks:
544
662
  method = getattr(callback, method_name)
545
663
  method(*args, **kwargs)
664
+
665
+ def to_cpu(self):
666
+ """
667
+ Moves the model to the CPU and updates the trainer's device setting.
668
+
669
+ This is useful for running operations that require the CPU.
670
+ """
671
+ self.device = torch.device('cpu')
672
+ self.model.to(self.device)
673
+ _LOGGER.info("Trainer and model moved to CPU.")
674
+
675
+ def to_device(self, device: str):
676
+ """
677
+ Moves the model to the specified device and updates the trainer's device setting.
678
+
679
+ Args:
680
+ device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
681
+ """
682
+ self.device = self._validate_device(device)
683
+ self.model.to(self.device)
684
+ _LOGGER.info(f"Trainer and model moved to {self.device}.")
685
+
686
+
687
+ # Object Detection Trainer
688
+ class ObjectDetectionTrainer:
689
+ def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
690
+ collate_fn: Callable, optimizer: torch.optim.Optimizer,
691
+ device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
692
+ """
693
+ Automates the training process of an Object Detection Model (e.g., DragonFastRCNN).
694
+
695
+ Built-in Callbacks: `History`, `TqdmProgressBar`
696
+
697
+ Args:
698
+ model (nn.Module): The PyTorch object detection model to train.
699
+ train_dataset (Dataset): The training dataset.
700
+ test_dataset (Dataset): The testing/validation dataset.
701
+ collate_fn (Callable): The collate function from `ObjectDetectionDatasetMaker.collate_fn`.
702
+ optimizer (torch.optim.Optimizer): The optimizer.
703
+ device (str): The device to run training on ('cpu', 'cuda', 'mps').
704
+ dataloader_workers (int): Subprocesses for data loading.
705
+ callbacks (List[Callback] | None): A list of callbacks to use during training.
706
+
707
+ ## Note:
708
+ This trainer is specialized. It does not take a `criterion` because object detection models like Faster R-CNN return a dictionary of losses directly from their forward pass during training.
709
+ """
710
+ self.model = model
711
+ self.train_dataset = train_dataset
712
+ self.test_dataset = test_dataset
713
+ self.kind = "object_detection"
714
+ self.collate_fn = collate_fn
715
+ self.criterion = None # Criterion is handled inside the model
716
+ self.optimizer = optimizer
717
+ self.scheduler = None
718
+ self.device = self._validate_device(device)
719
+ self.dataloader_workers = dataloader_workers
720
+
721
+ # Callback handler - History and TqdmProgressBar are added by default
722
+ default_callbacks = [History(), TqdmProgressBar()]
723
+ user_callbacks = callbacks if callbacks is not None else []
724
+ self.callbacks = default_callbacks + user_callbacks
725
+ self._set_trainer_on_callbacks()
726
+
727
+ # Internal state
728
+ self.train_loader = None
729
+ self.test_loader = None
730
+ self.history = {}
731
+ self.epoch = 0
732
+ self.epochs = 0 # Total epochs for the fit run
733
+ self.start_epoch = 1
734
+ self.stop_training = False
735
+ self._batch_size = 10
736
+
737
+ def _validate_device(self, device: str) -> torch.device:
738
+ """Validates the selected device and returns a torch.device object."""
739
+ device_lower = device.lower()
740
+ if "cuda" in device_lower and not torch.cuda.is_available():
741
+ _LOGGER.warning("CUDA not available, switching to CPU.")
742
+ device = "cpu"
743
+ elif device_lower == "mps" and not torch.backends.mps.is_available():
744
+ _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
745
+ device = "cpu"
746
+ return torch.device(device)
747
+
748
+ def _set_trainer_on_callbacks(self):
749
+ """Gives each callback a reference to this trainer instance."""
750
+ for callback in self.callbacks:
751
+ callback.set_trainer(self)
752
+
753
+ def _create_dataloaders(self, batch_size: int, shuffle: bool):
754
+ """Initializes the DataLoaders with the object detection collate_fn."""
755
+ # Ensure stability on MPS devices by setting num_workers to 0
756
+ loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
757
+
758
+ self.train_loader = DataLoader(
759
+ dataset=self.train_dataset,
760
+ batch_size=batch_size,
761
+ shuffle=shuffle,
762
+ num_workers=loader_workers,
763
+ pin_memory=("cuda" in self.device.type),
764
+ collate_fn=self.collate_fn # Use the provided collate function
765
+ )
766
+
767
+ self.test_loader = DataLoader(
768
+ dataset=self.test_dataset,
769
+ batch_size=batch_size,
770
+ shuffle=False,
771
+ num_workers=loader_workers,
772
+ pin_memory=("cuda" in self.device.type),
773
+ collate_fn=self.collate_fn # Use the provided collate function
774
+ )
775
+
776
+ def _load_checkpoint(self, path: Union[str, Path]):
777
+ """Loads a training checkpoint to resume training."""
778
+ p = make_fullpath(path, enforce="file")
779
+ _LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
780
+
781
+ try:
782
+ checkpoint = torch.load(p, map_location=self.device)
783
+
784
+ if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
785
+ _LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
786
+ raise KeyError()
787
+
788
+ self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
789
+ self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
790
+ self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
791
+
792
+ # --- Scheduler State Loading Logic ---
793
+ scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
794
+ scheduler_object_exists = self.scheduler is not None
795
+
796
+ if scheduler_object_exists and scheduler_state_exists:
797
+ # Case 1: Both exist. Attempt to load.
798
+ try:
799
+ self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
800
+ scheduler_name = self.scheduler.__class__.__name__
801
+ _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
802
+ except Exception as e:
803
+ # Loading failed, likely a mismatch
804
+ scheduler_name = self.scheduler.__class__.__name__
805
+ _LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
806
+ raise e
807
+
808
+ elif scheduler_object_exists and not scheduler_state_exists:
809
+ # Case 2: Scheduler provided, but no state in checkpoint.
810
+ scheduler_name = self.scheduler.__class__.__name__
811
+ _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
812
+
813
+ elif not scheduler_object_exists and scheduler_state_exists:
814
+ # Case 3: State in checkpoint, but no scheduler provided.
815
+ _LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
816
+ raise ValueError()
817
+
818
+ # Restore callback states
819
+ for cb in self.callbacks:
820
+ if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
821
+ cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
822
+ _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
823
+
824
+ _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
825
+
826
+ except Exception as e:
827
+ _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
828
+ raise
829
+
830
+ def fit(self,
831
+ epochs: int = 10,
832
+ batch_size: int = 10,
833
+ shuffle: bool = True,
834
+ resume_from_checkpoint: Optional[Union[str, Path]] = None):
835
+ """
836
+ Starts the training-validation process of the model.
837
+
838
+ Returns the "History" callback dictionary.
839
+
840
+ Args:
841
+ epochs (int): The total number of epochs to train for.
842
+ batch_size (int): The number of samples per batch.
843
+ shuffle (bool): Whether to shuffle the training data at each epoch.
844
+ resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
845
+ """
846
+ self.epochs = epochs
847
+ self._batch_size = batch_size
848
+ self._create_dataloaders(self._batch_size, shuffle)
849
+ self.model.to(self.device)
850
+
851
+ if resume_from_checkpoint:
852
+ self._load_checkpoint(resume_from_checkpoint)
853
+
854
+ # Reset stop_training flag on the trainer
855
+ self.stop_training = False
856
+
857
+ self._callbacks_hook('on_train_begin')
858
+
859
+ for epoch in range(self.start_epoch, self.epochs + 1):
860
+ self.epoch = epoch
861
+ epoch_logs = {}
862
+ self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
863
+
864
+ train_logs = self._train_step()
865
+ epoch_logs.update(train_logs)
866
+
867
+ val_logs = self._validation_step()
868
+ epoch_logs.update(val_logs)
869
+
870
+ self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
871
+
872
+ # Check the early stopping flag
873
+ if self.stop_training:
874
+ break
875
+
876
+ self._callbacks_hook('on_train_end')
877
+ return self.history
878
+
879
def _train_step(self):
    """
    Runs one training epoch over `self.train_loader`.

    The model is called as `self.model(images, targets)` in train mode and is
    expected to return a dict of loss tensors; those are summed into a single
    scalar loss that is back-propagated. `on_batch_begin` / `on_batch_end`
    callbacks fire around every batch, including batches that are skipped
    because the model returned no losses.

    Returns:
        dict: `{PyTorchLogKeys.TRAIN_LOSS: mean_loss}` where `mean_loss` is the
        sample-weighted mean over the batches that actually produced a loss.
        It is `0.0` when no batch contributed (empty loader or every batch
        skipped).
    """
    self.model.train()
    running_loss = 0.0
    # Count only samples that contributed to `running_loss`. Dividing by
    # len(dataset) under-reports the loss whenever a batch is skipped or the
    # loader does not visit every sample (e.g. drop_last=True).
    samples_seen = 0

    for batch_idx, (images, targets) in enumerate(self.train_loader): # type: ignore
        # images is a tuple of tensors, targets is a tuple of dicts
        batch_size = len(images)

        # Create a log dictionary for the batch
        batch_logs = {
            PyTorchLogKeys.BATCH_INDEX: batch_idx,
            PyTorchLogKeys.BATCH_SIZE: batch_size
        }
        self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)

        # Move data to device
        images = [img.to(self.device) for img in images]
        targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]

        self.optimizer.zero_grad()

        # Model returns a loss dict when in train() mode and targets are passed
        loss_dict = self.model(images, targets)

        if not loss_dict:
            # No losses returned, skip batch (still close the batch for callbacks)
            _LOGGER.warning(f"Model returned no losses for batch {batch_idx}. Skipping.")
            batch_logs[PyTorchLogKeys.BATCH_LOSS] = 0
            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
            continue

        # Sum all losses into a single scalar tensor
        loss: torch.Tensor = sum(loss_dict.values()) # type: ignore

        loss.backward()
        self.optimizer.step()

        # Calculate batch loss and update the running totals for the epoch
        batch_loss = loss.item()
        running_loss += batch_loss * batch_size
        samples_seen += batch_size

        # Add the batch loss to the logs and call the end-of-batch hook
        batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss # type: ignore
        self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)

    epoch_loss = running_loss / samples_seen if samples_seen else 0.0
    return {PyTorchLogKeys.TRAIN_LOSS: epoch_loss}
924
+
925
def _validation_step(self):
    """
    Computes the validation loss over `self.test_loader`.

    The model is deliberately left in train() mode: it is called with targets
    so that it returns a loss dict, and `torch.no_grad()` prevents any
    gradient computation.
    NOTE(review): train() mode also keeps layers such as dropout / batch-norm
    in their training behaviour during this pass; confirm that is acceptable
    for the models used with this trainer.

    Returns:
        dict: `{PyTorchLogKeys.VAL_LOSS: mean_loss}` where `mean_loss` is the
        sample-weighted mean over the batches that actually produced a loss.
        It is `0.0` when no batch contributed (empty loader or every batch
        skipped).
    """
    self.model.train()  # Needed so the model returns a loss_dict (see docstring)
    running_loss = 0.0
    # Count only samples that contributed to `running_loss`; skipped batches
    # must not be averaged in as if their loss were zero.
    samples_seen = 0

    with torch.no_grad():
        for images, targets in self.test_loader: # type: ignore
            batch_size = len(images)

            # Move data to device
            images = [img.to(self.device) for img in images]
            targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]

            # Get loss dict
            loss_dict = self.model(images, targets)

            if not loss_dict:
                _LOGGER.warning("Model returned no losses during validation step. Skipping batch.")
                continue  # Skip if no losses

            # Sum all losses into a single scalar tensor
            loss: torch.Tensor = sum(loss_dict.values()) # type: ignore

            running_loss += loss.item() * batch_size
            samples_seen += batch_size

    val_loss = running_loss / samples_seen if samples_seen else 0.0
    logs = {PyTorchLogKeys.VAL_LOSS: val_loss}
    return logs
953
+
954
def evaluate(self, save_dir: Union[str, Path], data: Optional[Union[DataLoader, Dataset]] = None):
    """
    Evaluates the model using object detection mAP metrics.

    Predictions are collected with the model in eval() mode, moved to the CPU
    together with the targets, and handed to `object_detection_metrics`.
    The training history is then plotted with `plot_losses`.

    Args:
        save_dir (str | Path): Directory to save all reports and plots.
        data (DataLoader | Dataset | None): The data to evaluate on.
            - DataLoader: used as-is.
            - Dataset: wrapped in a new, non-shuffled DataLoader built with the trainer's settings.
            - Anything else (normally None): the trainer's internal test_dataset is used.

    Raises:
        ValueError: If no data is provided and the trainer has no test_dataset.
    """
    if isinstance(data, DataLoader):
        eval_loader = data
        dataset_for_names = getattr(data, 'dataset', None)
    else:
        if isinstance(data, Dataset):
            eval_dataset = data
        else:  # data is None, use the trainer's default test dataset
            if self.test_dataset is None:
                _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
                raise ValueError()
            eval_dataset = self.test_dataset

        # Create a fresh, non-shuffled DataLoader for the chosen dataset
        eval_loader = DataLoader(
            eval_dataset,
            batch_size=self._batch_size,
            shuffle=False,
            num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
            pin_memory=(self.device.type == "cuda"),
            collate_fn=self.collate_fn
        )
        dataset_for_names = eval_dataset

    print("\n--- Model Evaluation ---")

    all_predictions = []
    all_targets = []

    self.model.eval()  # Set model to evaluation mode
    self.model.to(self.device)

    with torch.no_grad():
        for images, targets in eval_loader:
            # Move images to device
            images = [img.to(self.device) for img in images]

            # Model returns predictions when in eval() mode
            predictions = self.model(images)

            # Move predictions and targets to CPU for aggregation
            all_predictions.extend({k: v.to('cpu') for k, v in p.items()} for p in predictions)
            all_targets.extend({k: v.to('cpu') for k, v in t.items()} for t in targets)

    if not all_targets:
        _LOGGER.error("Evaluation failed: No data was processed.")
        return

    # Get class names from the dataset for the report
    class_names = None
    if hasattr(dataset_for_names, 'classes'):
        # Dataset exposes 'classes' directly (e.g. ObjectDetectionDatasetMaker)
        class_names = dataset_for_names.classes # type: ignore
    elif hasattr(dataset_for_names, 'dataset') and hasattr(dataset_for_names.dataset, 'classes'): # type: ignore
        # Fallback for Subset: look on the wrapped dataset
        class_names = dataset_for_names.dataset.classes # type: ignore
    else:
        # The hasattr() guards never raise, so the warning has to be emitted
        # here rather than from an `except AttributeError` handler.
        _LOGGER.warning("Could not find 'classes' attribute on dataset. Per-class metrics will not be named.")

    # --- Routing Logic ---
    object_detection_metrics(
        preds=all_predictions,
        targets=all_targets,
        save_dir=save_dir,
        class_names=class_names,
        print_output=False
    )

    print("\n--- Training History ---")
    plot_losses(self.history, save_dir=save_dir)
1048
+
1049
+ def _callbacks_hook(self, method_name: str, *args, **kwargs):
1050
+ """Calls the specified method on all callbacks."""
1051
+ for callback in self.callbacks:
1052
+ method = getattr(callback, method_name)
1053
+ method(*args, **kwargs)
1054
+
1055
def to_cpu(self):
    """
    Switches the trainer's device to the CPU and relocates the model there.

    Handy before running operations that must execute on the CPU.
    """
    cpu_device = torch.device('cpu')
    self.device = cpu_device
    self.model.to(cpu_device)
    _LOGGER.info("Trainer and model moved to CPU.")
1064
+
1065
def to_device(self, device: str):
    """
    Switches the trainer's device setting and relocates the model to it.

    The requested device string is first passed through `self._validate_device`;
    whatever that returns becomes the trainer's device.

    Args:
        device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
    """
    target_device = self._validate_device(device)
    self.device = target_device
    self.model.to(target_device)
    _LOGGER.info(f"Trainer and model moved to {self.device}.")
1075
+
546
1076
 
547
1077
def info():
    """Passes this module's `__all__` list to `_script_info`."""
    exported_names = __all__
    _script_info(exported_names)