dragon-ml-toolbox 13.0.0__py3-none-any.whl → 14.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/METADATA +12 -2
  2. dragon_ml_toolbox-14.7.0.dist-info/RECORD +49 -0
  3. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/MICE_imputation.py +207 -5
  5. ml_tools/ML_configuration.py +108 -0
  6. ml_tools/ML_datasetmaster.py +241 -260
  7. ml_tools/ML_evaluation.py +229 -76
  8. ml_tools/ML_evaluation_multi.py +45 -16
  9. ml_tools/ML_inference.py +0 -1
  10. ml_tools/ML_models.py +135 -55
  11. ml_tools/ML_models_advanced.py +323 -0
  12. ml_tools/ML_optimization.py +49 -36
  13. ml_tools/ML_trainer.py +498 -29
  14. ml_tools/ML_utilities.py +351 -4
  15. ml_tools/ML_vision_datasetmaster.py +1492 -0
  16. ml_tools/ML_vision_evaluation.py +260 -0
  17. ml_tools/ML_vision_inference.py +428 -0
  18. ml_tools/ML_vision_models.py +641 -0
  19. ml_tools/ML_vision_transformers.py +203 -0
  20. ml_tools/PSO_optimization.py +5 -1
  21. ml_tools/_ML_vision_recipe.py +88 -0
  22. ml_tools/__init__.py +1 -0
  23. ml_tools/_schema.py +96 -0
  24. ml_tools/custom_logger.py +37 -14
  25. ml_tools/data_exploration.py +576 -138
  26. ml_tools/ensemble_evaluation.py +53 -10
  27. ml_tools/keys.py +43 -1
  28. ml_tools/math_utilities.py +1 -1
  29. ml_tools/optimization_tools.py +65 -86
  30. ml_tools/serde.py +78 -17
  31. ml_tools/utilities.py +192 -3
  32. dragon_ml_toolbox-13.0.0.dist-info/RECORD +0 -41
  33. ml_tools/ML_simple_optimization.py +0 -413
  34. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/WHEEL +0 -0
  35. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/licenses/LICENSE +0 -0
  36. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_trainer.py CHANGED
@@ -1,4 +1,4 @@
- from typing import List, Literal, Union, Optional
+ from typing import List, Literal, Union, Optional, Callable, Dict, Any, Tuple
  from pathlib import Path
  from torch.utils.data import DataLoader, Dataset
  import torch
@@ -9,19 +9,22 @@ from .ML_callbacks import Callback, History, TqdmProgressBar, ModelCheckpoint
  from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
  from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
  from ._script_info import _script_info
- from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
+ from .keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys
  from ._logger import _LOGGER
  from .path_manager import make_fullpath
+ from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
+ from .ML_configuration import ClassificationMetricsFormat, MultiClassificationMetricsFormat
 
 
  __all__ = [
-     "MLTrainer"
+     "MLTrainer",
+     "ObjectDetectionTrainer",
  ]
 
 
  class MLTrainer:
      def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
-                  kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification"],
+                  kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification", "segmentation"],
                   criterion: nn.Module, optimizer: torch.optim.Optimizer,
                   device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
          """
@@ -33,7 +36,7 @@ class MLTrainer:
              model (nn.Module): The PyTorch model to train.
              train_dataset (Dataset): The training dataset.
              test_dataset (Dataset): The testing/validation dataset.
-             kind (str): Can be 'regression', 'classification', 'multi_target_regression', or 'multi_label_classification'.
+             kind (str): Can be 'regression', 'classification', 'multi_target_regression', 'multi_label_classification', or 'segmentation'.
              criterion (nn.Module): The loss function.
              optimizer (torch.optim.Optimizer): The optimizer.
              device (str): The device to run training on ('cpu', 'cuda', 'mps').
@@ -46,8 +49,10 @@ class MLTrainer:
          - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice.
 
          - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem.
+
+         - For **segmentation** tasks, `nn.CrossEntropyLoss` (for multi-class) or `nn.BCEWithLogitsLoss` (for binary) are common.
          """
-         if kind not in ["regression", "classification", "multi_target_regression", "multi_label_classification"]:
+         if kind not in ["regression", "classification", "multi_target_regression", "multi_label_classification", "segmentation"]:
              raise ValueError(f"'{kind}' is not a valid task type.")
 
          self.model = model
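The new "segmentation" task kind slots into the existing trainer API. A minimal sketch, where the model, datasets, and learning rate are illustrative placeholders, not part of this diff:

```python
import torch
from torch import nn
from ml_tools.ML_trainer import MLTrainer

# Masks are class-index tensors, so CrossEntropyLoss fits multi-class
# segmentation; nn.BCEWithLogitsLoss would be the binary counterpart.
trainer = MLTrainer(
    model=seg_model,                  # placeholder nn.Module emitting [N, C, H, W] logits
    train_dataset=train_ds,           # placeholder dataset yielding (image, mask) pairs
    test_dataset=test_ds,
    kind="segmentation",
    criterion=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(seg_model.parameters(), lr=1e-3),
    device="cuda",
)
trainer.fit(epochs=20, batch_size=8)
```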
@@ -74,6 +79,7 @@ class MLTrainer:
          self.epochs = 0 # Total epochs for the fit run
          self.start_epoch = 1
          self.stop_training = False
+         self._batch_size = 10
 
      def _validate_device(self, device: str) -> torch.device:
          """Validates the selected device and returns a torch.device object."""
@@ -191,7 +197,8 @@ class MLTrainer:
              shape of `[batch_size]`.
          """
          self.epochs = epochs
-         self._create_dataloaders(batch_size, shuffle)
+         self._batch_size = batch_size
+         self._create_dataloaders(self._batch_size, shuffle)
          self.model.to(self.device)
 
          if resume_from_checkpoint:
@@ -291,36 +298,53 @@ class MLTrainer:
          for features, target in dataloader:
              features = features.to(self.device)
              output = self.model(features).cpu()
-             y_true_batch = target.numpy()
 
              y_pred_batch = None
              y_prob_batch = None
+             y_true_batch = None
 
              if self.kind in ["regression", "multi_target_regression"]:
                  y_pred_batch = output.numpy()
+                 y_true_batch = target.numpy()
 
              elif self.kind == "classification":
                  probs = torch.softmax(output, dim=1)
                  preds = torch.argmax(probs, dim=1)
                  y_pred_batch = preds.numpy()
                  y_prob_batch = probs.numpy()
+                 y_true_batch = target.numpy()
 
              elif self.kind == "multi_label_classification":
                  probs = torch.sigmoid(output)
                  preds = (probs >= classification_threshold).int()
                  y_pred_batch = preds.numpy()
                  y_prob_batch = probs.numpy()
+                 y_true_batch = target.numpy()
+
+             elif self.kind == "segmentation":
+                 # output shape [N, C, H, W]
+                 probs = torch.softmax(output, dim=1)
+                 preds = torch.argmax(probs, dim=1) # shape [N, H, W]
+                 y_pred_batch = preds.numpy()
+                 y_prob_batch = probs.numpy() # Probs are [N, C, H, W]
+
+                 # Handle target shape [N, 1, H, W] -> [N, H, W]
+                 if target.ndim == 4 and target.shape[1] == 1:
+                     target = target.squeeze(1)
+                 y_true_batch = target.numpy()
 
              yield y_pred_batch, y_prob_batch, y_true_batch
 
-     def evaluate(self, save_dir: Union[str, Path], data: Optional[Union[DataLoader, Dataset]] = None, classification_threshold: float = 0.5):
+     def evaluate(self,
+                  save_dir: Union[str, Path],
+                  data: Optional[Union[DataLoader, Dataset]] = None,
+                  format_configuration: Optional[Union[ClassificationMetricsFormat, MultiClassificationMetricsFormat]]=None):
          """
          Evaluates the model, routing to the correct evaluation function based on task `kind`.
 
          Args:
              save_dir (str | Path): Directory to save all reports and plots.
              data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
-             classification_threshold (float): Probability threshold for multi-label tasks.
          """
          dataset_for_names = None
          eval_loader = None
@@ -333,7 +357,7 @@ class MLTrainer:
          elif isinstance(data, Dataset):
              # Create a new loader from the provided dataset
              eval_loader = DataLoader(data,
-                                      batch_size=32,
+                                      batch_size=self._batch_size,
                                       shuffle=False,
                                       num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                       pin_memory=(self.device.type == "cuda"))
@@ -344,20 +368,21 @@ class MLTrainer:
                  raise ValueError()
              # Create a fresh DataLoader from the test_dataset
              eval_loader = DataLoader(self.test_dataset,
-                                      batch_size=32,
+                                      batch_size=self._batch_size,
                                       shuffle=False,
                                       num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                       pin_memory=(self.device.type == "cuda"))
+
              dataset_for_names = self.test_dataset
 
          if eval_loader is None:
              _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
              raise ValueError()
 
-         print("\n--- Model Evaluation ---")
+         # print("\n--- Model Evaluation ---")
 
          all_preds, all_probs, all_true = [], [], []
-         for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader, classification_threshold):
+         for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
              if y_pred_b is not None: all_preds.append(y_pred_b)
              if y_prob_b is not None: all_probs.append(y_prob_b)
              if y_true_b is not None: all_true.append(y_true_b)
@@ -375,7 +400,19 @@ class MLTrainer:
              regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir)
 
          elif self.kind == "classification":
-             classification_metrics(save_dir, y_true, y_pred, y_prob)
+             # Parse configuration
+             if format_configuration and isinstance(format_configuration, ClassificationMetricsFormat):
+                 classification_metrics(save_dir=save_dir,
+                                        y_true=y_true,
+                                        y_pred=y_pred,
+                                        y_prob=y_prob,
+                                        cmap=format_configuration.cmap,
+                                        class_map=format_configuration.class_map,
+                                        ROC_PR_line=format_configuration.ROC_PR_line,
+                                        calibration_bins=format_configuration.calibration_bins,
+                                        font_size=format_configuration.font_size)
+             else:
+                 classification_metrics(save_dir, y_true, y_pred, y_prob)
 
          elif self.kind == "multi_target_regression":
              try:
@@ -397,9 +434,44 @@ class MLTrainer:
              if y_prob is None:
                  _LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
                  return
-             multi_label_classification_metrics(y_true, y_prob, target_names, save_dir, classification_threshold)
+
+             if format_configuration and isinstance(format_configuration, MultiClassificationMetricsFormat):
+                 multi_label_classification_metrics(y_true=y_true,
+                                                    y_prob=y_prob,
+                                                    target_names=target_names,
+                                                    save_dir=save_dir,
+                                                    threshold=format_configuration.threshold,
+                                                    ROC_PR_line=format_configuration.ROC_PR_line,
+                                                    cmap=format_configuration.cmap,
+                                                    font_size=format_configuration.font_size)
+             else:
+                 multi_label_classification_metrics(y_true, y_prob, target_names, save_dir)
+
+         elif self.kind == "segmentation":
+             class_names = None
+             try:
+                 # Try to get 'classes' from VisionDatasetMaker
+                 if hasattr(dataset_for_names, 'classes'):
+                     class_names = dataset_for_names.classes # type: ignore
+                 # Fallback for Subset
+                 elif hasattr(dataset_for_names, 'dataset') and hasattr(dataset_for_names.dataset, 'classes'): # type: ignore
+                     class_names = dataset_for_names.dataset.classes # type: ignore
+             except AttributeError:
+                 pass # class_names is still None
 
-         print("\n--- Training History ---")
+             if class_names is None:
+                 try:
+                     # Fallback to 'target_names'
+                     class_names = dataset_for_names.target_names # type: ignore
+                 except AttributeError:
+                     # Fallback to inferring from labels
+                     labels = np.unique(y_true)
+                     class_names = [f"Class {i}" for i in labels]
+                     _LOGGER.warning(f"Dataset has no 'classes' or 'target_names' attribute. Using generic names.")
+
+             segmentation_metrics(y_true, y_pred, save_dir, class_names=class_names)
+
+         # print("\n--- Training History ---")
          plot_losses(self.history, save_dir=save_dir)
 
      def explain(self,
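The `classification_threshold` argument is gone from `evaluate`; thresholds and plot styling now travel in a format object. A hedged sketch, assuming `ClassificationMetricsFormat` is constructed from the same fields the trainer reads above (its exact constructor is not shown in this diff):

```python
from ml_tools.ML_configuration import ClassificationMetricsFormat

# Field values are illustrative; only the field names appear in the diff.
fmt = ClassificationMetricsFormat(
    cmap="Blues",                              # confusion-matrix colormap
    class_map={0: "negative", 1: "positive"},  # assumed index-to-name mapping
    ROC_PR_line="darkorange",                  # ROC/PR line style; type not shown in the diff
    calibration_bins=15,
    font_size=12,
)
trainer.evaluate(save_dir="reports/", format_configuration=fmt)
```

Without a format object, `evaluate` falls back to the positional `classification_metrics(save_dir, y_true, y_pred, y_prob)` call, so existing callers keep working.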
@@ -408,7 +480,7 @@
                  n_samples: int = 300,
                  feature_names: Optional[List[str]] = None,
                  target_names: Optional[List[str]] = None,
-                 explainer_type: Literal['deep', 'kernel'] = 'deep'):
+                 explainer_type: Literal['deep', 'kernel'] = 'kernel'):
          """
          Explains model predictions using SHAP and saves all artifacts.
 
@@ -422,11 +494,11 @@ class MLTrainer:
              explain_dataset (Dataset | None): A specific dataset to explain.
                  If None, the trainer's test dataset is used.
              n_samples (int): The number of samples to use for both background and explanation.
-             feature_names (list[str] | None): Feature names.
+             feature_names (list[str] | None): Feature names. If None, the names will be extracted from the Dataset and raise an error on failure.
              target_names (list[str] | None): Target names for multi-target tasks.
              save_dir (str | Path): Directory to save all SHAP artifacts.
              explainer_type (Literal['deep', 'kernel']): The explainer to use.
-                 - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
+                 - 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
                  - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples'< 100.
          """
          # Internal helper to create a dataloader and get a random sample
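Note that the default `explainer_type` flips from 'deep' to 'kernel' here, which silently changes both runtime and results for callers relying on the old default. A sketch of pinning the behavior explicitly (paths and sample counts are illustrative):

```python
# Keep the pre-14.x behavior: fast DeepExplainer with a larger sample budget.
trainer.explain(save_dir="shap_deep/", n_samples=300, explainer_type="deep")

# New default: model-agnostic KernelExplainer; keep n_samples small as the docstring advises.
trainer.explain(save_dir="shap_kernel/", n_samples=50)
```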
@@ -456,7 +528,7 @@ class MLTrainer:
              rand_indices = torch.randperm(full_data.size(0))[:num_samples]
              return full_data[rand_indices]
 
-         print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
+         # print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
 
          # 1. Get background data from the trainer's train_dataset
          background_data = _get_random_sample(self.train_dataset, n_samples)
@@ -474,10 +546,10 @@ class MLTrainer:
          # attempt to get feature names
          if feature_names is None:
              # _LOGGER.info("`feature_names` not provided. Attempting to extract from dataset...")
-             if hasattr(target_dataset, "feature_names"):
+             if hasattr(target_dataset, DatasetKeys.FEATURE_NAMES):
                  feature_names = target_dataset.feature_names # type: ignore
              else:
-                 _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
+                 _LOGGER.error(f"Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
                  raise ValueError()
 
          # move model to device
@@ -498,7 +570,7 @@ class MLTrainer:
          # try to get target names
          if target_names is None:
              target_names = []
-             if hasattr(target_dataset, 'target_names'):
+             if hasattr(target_dataset, DatasetKeys.TARGET_NAMES):
                  target_names = target_dataset.target_names # type: ignore
              else:
                  # Infer number of targets from the model's output layer
@@ -549,7 +621,7 @@ class MLTrainer:
              yield attention_weights
 
      def explain_attention(self, save_dir: Union[str, Path],
-                           feature_names: Optional[List[str]],
+                           feature_names: Optional[List[str]] = None,
                            explain_dataset: Optional[Dataset] = None,
                            plot_n_features: int = 10):
          """
@@ -559,18 +631,17 @@ class MLTrainer:
 
          Args:
              save_dir (str | Path): Directory to save the plot and summary data.
-             feature_names (List[str] | None): Names for the features for plot labeling. If not given, generic names will be used.
+             feature_names (List[str] | None): Names for the features for plot labeling. If None, the names will be extracted from the Dataset and raise an error on failure.
              explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's test dataset is used.
              plot_n_features (int): Number of top features to plot.
          """
 
-         print("\n--- Attention Analysis ---")
+         # print("\n--- Attention Analysis ---")
 
          # --- Step 1: Check if the model supports this explanation ---
          if not getattr(self.model, 'has_interpretable_attention', False):
              _LOGGER.warning(
-                 "Model is not flagged for interpretable attention analysis. "
-                 "Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
+                 "Model is not flagged for interpretable attention analysis. Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
              )
              return
 
@@ -580,6 +651,14 @@ class MLTrainer:
              _LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
              return
 
+         # Get feature names
+         if feature_names is None:
+             if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
+                 feature_names = dataset_to_use.feature_names # type: ignore
+             else:
+                 _LOGGER.error(f"Could not extract `feature_names` from the dataset for attention plot. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
+                 raise ValueError()
+
          explain_loader = DataLoader(
              dataset=dataset_to_use, batch_size=32, shuffle=False,
              num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
@@ -629,7 +708,397 @@ class MLTrainer:
          self.device = self._validate_device(device)
          self.model.to(self.device)
          _LOGGER.info(f"Trainer and model moved to {self.device}.")
+
+
+ # Object Detection Trainer
+ class ObjectDetectionTrainer:
+     def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
+                  collate_fn: Callable, optimizer: torch.optim.Optimizer,
+                  device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
+         """
+         Automates the training process of an Object Detection Model (e.g., DragonFastRCNN).
+
+         Built-in Callbacks: `History`, `TqdmProgressBar`
+
+         Args:
+             model (nn.Module): The PyTorch object detection model to train.
+             train_dataset (Dataset): The training dataset.
+             test_dataset (Dataset): The testing/validation dataset.
+             collate_fn (Callable): The collate function from `ObjectDetectionDatasetMaker.collate_fn`.
+             optimizer (torch.optim.Optimizer): The optimizer.
+             device (str): The device to run training on ('cpu', 'cuda', 'mps').
+             dataloader_workers (int): Subprocesses for data loading.
+             callbacks (List[Callback] | None): A list of callbacks to use during training.
+
+         ## Note:
+         This trainer is specialized. It does not take a `criterion` because object detection models like Faster R-CNN return a dictionary of losses directly from their forward pass during training.
+         """
+         self.model = model
+         self.train_dataset = train_dataset
+         self.test_dataset = test_dataset
+         self.kind = "object_detection"
+         self.collate_fn = collate_fn
+         self.criterion = None # Criterion is handled inside the model
+         self.optimizer = optimizer
+         self.scheduler = None
+         self.device = self._validate_device(device)
+         self.dataloader_workers = dataloader_workers
+
+         # Callback handler - History and TqdmProgressBar are added by default
+         default_callbacks = [History(), TqdmProgressBar()]
+         user_callbacks = callbacks if callbacks is not None else []
+         self.callbacks = default_callbacks + user_callbacks
+         self._set_trainer_on_callbacks()
+
+         # Internal state
+         self.train_loader = None
+         self.test_loader = None
+         self.history = {}
+         self.epoch = 0
+         self.epochs = 0 # Total epochs for the fit run
+         self.start_epoch = 1
+         self.stop_training = False
+         self._batch_size = 10
+
+     def _validate_device(self, device: str) -> torch.device:
+         """Validates the selected device and returns a torch.device object."""
+         device_lower = device.lower()
+         if "cuda" in device_lower and not torch.cuda.is_available():
+             _LOGGER.warning("CUDA not available, switching to CPU.")
+             device = "cpu"
+         elif device_lower == "mps" and not torch.backends.mps.is_available():
+             _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
+             device = "cpu"
+         return torch.device(device)
+
+     def _set_trainer_on_callbacks(self):
+         """Gives each callback a reference to this trainer instance."""
+         for callback in self.callbacks:
+             callback.set_trainer(self)
+
+     def _create_dataloaders(self, batch_size: int, shuffle: bool):
+         """Initializes the DataLoaders with the object detection collate_fn."""
+         # Ensure stability on MPS devices by setting num_workers to 0
+         loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+         self.train_loader = DataLoader(
+             dataset=self.train_dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             num_workers=loader_workers,
+             pin_memory=("cuda" in self.device.type),
+             collate_fn=self.collate_fn # Use the provided collate function
+         )
+
+         self.test_loader = DataLoader(
+             dataset=self.test_dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=loader_workers,
+             pin_memory=("cuda" in self.device.type),
+             collate_fn=self.collate_fn # Use the provided collate function
+         )
+
+     def _load_checkpoint(self, path: Union[str, Path]):
+         """Loads a training checkpoint to resume training."""
+         p = make_fullpath(path, enforce="file")
+         _LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
+
+         try:
+             checkpoint = torch.load(p, map_location=self.device)
+
+             if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
+                 _LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
+                 raise KeyError()
+
+             self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
+             self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
+             self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
+
+             # --- Scheduler State Loading Logic ---
+             scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
+             scheduler_object_exists = self.scheduler is not None
+
+             if scheduler_object_exists and scheduler_state_exists:
+                 # Case 1: Both exist. Attempt to load.
+                 try:
+                     self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
+                     scheduler_name = self.scheduler.__class__.__name__
+                     _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
+                 except Exception as e:
+                     # Loading failed, likely a mismatch
+                     scheduler_name = self.scheduler.__class__.__name__
+                     _LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
+                     raise e
+
+             elif scheduler_object_exists and not scheduler_state_exists:
+                 # Case 2: Scheduler provided, but no state in checkpoint.
+                 scheduler_name = self.scheduler.__class__.__name__
+                 _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
+
+             elif not scheduler_object_exists and scheduler_state_exists:
+                 # Case 3: State in checkpoint, but no scheduler provided.
+                 _LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
+                 raise ValueError()
+
+             # Restore callback states
+             for cb in self.callbacks:
+                 if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
+                     cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
+                     _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
+
+             _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
+
+         except Exception as e:
+             _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
+             raise
+
+     def fit(self,
+             epochs: int = 10,
+             batch_size: int = 10,
+             shuffle: bool = True,
+             resume_from_checkpoint: Optional[Union[str, Path]] = None):
+         """
+         Starts the training-validation process of the model.
+
+         Returns the "History" callback dictionary.
+
+         Args:
+             epochs (int): The total number of epochs to train for.
+             batch_size (int): The number of samples per batch.
+             shuffle (bool): Whether to shuffle the training data at each epoch.
+             resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
+         """
+         self.epochs = epochs
+         self._batch_size = batch_size
+         self._create_dataloaders(self._batch_size, shuffle)
+         self.model.to(self.device)
+
+         if resume_from_checkpoint:
+             self._load_checkpoint(resume_from_checkpoint)
+
+         # Reset stop_training flag on the trainer
+         self.stop_training = False
+
+         self._callbacks_hook('on_train_begin')
+
+         for epoch in range(self.start_epoch, self.epochs + 1):
+             self.epoch = epoch
+             epoch_logs = {}
+             self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
+
+             train_logs = self._train_step()
+             epoch_logs.update(train_logs)
+
+             val_logs = self._validation_step()
+             epoch_logs.update(val_logs)
+
+             self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
+
+             # Check the early stopping flag
+             if self.stop_training:
+                 break
+
+         self._callbacks_hook('on_train_end')
+         return self.history
+
+     def _train_step(self):
+         self.model.train()
+         running_loss = 0.0
+         for batch_idx, (images, targets) in enumerate(self.train_loader): # type: ignore
+             # images is a tuple of tensors, targets is a tuple of dicts
+             batch_size = len(images)
+
+             # Create a log dictionary for the batch
+             batch_logs = {
+                 PyTorchLogKeys.BATCH_INDEX: batch_idx,
+                 PyTorchLogKeys.BATCH_SIZE: batch_size
+             }
+             self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
+
+             # Move data to device
+             images = list(img.to(self.device) for img in images)
+             targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
+
+             self.optimizer.zero_grad()
+
+             # Model returns a loss dict when in train() mode and targets are passed
+             loss_dict = self.model(images, targets)
+
+             if not loss_dict:
+                 # No losses returned, skip batch
+                 _LOGGER.warning(f"Model returned no losses for batch {batch_idx}. Skipping.")
+                 batch_logs[PyTorchLogKeys.BATCH_LOSS] = 0
+                 self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+                 continue
+
+             # Sum all losses
+             loss: torch.Tensor = sum(l for l in loss_dict.values()) # type: ignore
+
+             loss.backward()
+             self.optimizer.step()
+
+             # Calculate batch loss and update running loss for the epoch
+             batch_loss = loss.item()
+             running_loss += batch_loss * batch_size
+
+             # Add the batch loss to the logs and call the end-of-batch hook
+             batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss # type: ignore
+             self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+         return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore
+
+     def _validation_step(self):
+         self.model.train() # Set to train mode even for validation loss calculation
+                            # as model internals (e.g., proposals) might differ,
+                            # but we still need loss_dict.
+                            # We use torch.no_grad() to prevent gradient updates.
+         running_loss = 0.0
+         with torch.no_grad():
+             for images, targets in self.test_loader: # type: ignore
+                 batch_size = len(images)
+
+                 # Move data to device
+                 images = list(img.to(self.device) for img in images)
+                 targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
+
+                 # Get loss dict
+                 loss_dict = self.model(images, targets)
+
+                 if not loss_dict:
+                     _LOGGER.warning("Model returned no losses during validation step. Skipping batch.")
+                     continue # Skip if no losses
+
+                 # Sum all losses
+                 loss: torch.Tensor = sum(l for l in loss_dict.values()) # type: ignore
+
+                 running_loss += loss.item() * batch_size
+
+         logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)} # type: ignore
+         return logs
+
+     def evaluate(self, save_dir: Union[str, Path], data: Optional[Union[DataLoader, Dataset]] = None):
+         """
+         Evaluates the model using object detection mAP metrics.
+
+         Args:
+             save_dir (str | Path): Directory to save all reports and plots.
+             data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
+         """
+         dataset_for_names = None
+         eval_loader = None
+
+         if isinstance(data, DataLoader):
+             eval_loader = data
+             if hasattr(data, 'dataset'):
+                 dataset_for_names = data.dataset
+         elif isinstance(data, Dataset):
+             # Create a new loader from the provided dataset
+             eval_loader = DataLoader(data,
+                                      batch_size=self._batch_size,
+                                      shuffle=False,
+                                      num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                      pin_memory=(self.device.type == "cuda"),
+                                      collate_fn=self.collate_fn)
+             dataset_for_names = data
+         else: # data is None, use the trainer's default test dataset
+             if self.test_dataset is None:
+                 _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
+                 raise ValueError()
+             # Create a fresh DataLoader from the test_dataset
+             eval_loader = DataLoader(
+                 self.test_dataset,
+                 batch_size=self._batch_size,
+                 shuffle=False,
+                 num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                 pin_memory=(self.device.type == "cuda"),
+                 collate_fn=self.collate_fn
+             )
+             dataset_for_names = self.test_dataset
+
+         if eval_loader is None:
+             _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
+             raise ValueError()
+
+         # print("\n--- Model Evaluation ---")
+
+         all_predictions = []
+         all_targets = []
+
+         self.model.eval() # Set model to evaluation mode
+         self.model.to(self.device)
+
+         with torch.no_grad():
+             for images, targets in eval_loader:
+                 # Move images to device
+                 images = list(img.to(self.device) for img in images)
+
+                 # Model returns predictions when in eval() mode
+                 predictions = self.model(images)
+
+                 # Move predictions and targets to CPU for aggregation
+                 cpu_preds = [{k: v.to('cpu') for k, v in p.items()} for p in predictions]
+                 cpu_targets = [{k: v.to('cpu') for k, v in t.items()} for t in targets]
+
+                 all_predictions.extend(cpu_preds)
+                 all_targets.extend(cpu_targets)
+
+         if not all_targets:
+             _LOGGER.error("Evaluation failed: No data was processed.")
+             return
+
+         # Get class names from the dataset for the report
+         class_names = None
+         try:
+             # Try to get 'classes' from ObjectDetectionDatasetMaker
+             if hasattr(dataset_for_names, 'classes'):
+                 class_names = dataset_for_names.classes # type: ignore
+             # Fallback for Subset
+             elif hasattr(dataset_for_names, 'dataset') and hasattr(dataset_for_names.dataset, 'classes'): # type: ignore
+                 class_names = dataset_for_names.dataset.classes # type: ignore
+         except AttributeError:
+             _LOGGER.warning("Could not find 'classes' attribute on dataset. Per-class metrics will not be named.")
+             pass # class_names is still None
+
+         # --- Routing Logic ---
+         object_detection_metrics(
+             preds=all_predictions,
+             targets=all_targets,
+             save_dir=save_dir,
+             class_names=class_names,
+             print_output=False
+         )
+
+         # print("\n--- Training History ---")
+         plot_losses(self.history, save_dir=save_dir)
+
+     def _callbacks_hook(self, method_name: str, *args, **kwargs):
+         """Calls the specified method on all callbacks."""
+         for callback in self.callbacks:
+             method = getattr(callback, method_name)
+             method(*args, **kwargs)
+
+     def to_cpu(self):
+         """
+         Moves the model to the CPU and updates the trainer's device setting.
+
+         This is useful for running operations that require the CPU.
+         """
+         self.device = torch.device('cpu')
+         self.model.to(self.device)
+         _LOGGER.info("Trainer and model moved to CPU.")
 
+     def to_device(self, device: str):
+         """
+         Moves the model to the specified device and updates the trainer's device setting.
+
+         Args:
+             device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+         """
+         self.device = self._validate_device(device)
+         self.model.to(self.device)
+         _LOGGER.info(f"Trainer and model moved to {self.device}.")
+
 
  def info():
      _script_info(__all__)
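A hedged end-to-end sketch of the new `ObjectDetectionTrainer`. The constructor and method names follow the added code above; the model, dataset maker, and optimizer settings are placeholders (`ObjectDetectionDatasetMaker` is only referenced in the docstring, and its API is not part of this diff):

```python
import torch
from ml_tools.ML_trainer import ObjectDetectionTrainer

# detector: e.g., a DragonFastRCNN instance returning a loss dict in train mode.
# maker: assumed ObjectDetectionDatasetMaker exposing datasets and collate_fn.
trainer = ObjectDetectionTrainer(
    model=detector,
    train_dataset=train_ds,
    test_dataset=test_ds,
    collate_fn=maker.collate_fn,   # keeps variable-size box targets intact per image
    optimizer=torch.optim.SGD(detector.parameters(), lr=0.005, momentum=0.9),
    device="cuda",
)
trainer.fit(epochs=12, batch_size=4)
trainer.evaluate(save_dir="reports/")   # writes mAP reports via object_detection_metrics
```

Note the design choice visible in the diff: `_validation_step` keeps the model in `train()` mode under `torch.no_grad()` so the torchvision-style loss dict is still returned, while `evaluate` switches to `eval()` to obtain per-image predictions.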