dragon-ml-toolbox 19.10.0__py3-none-any.whl → 19.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,7 +7,7 @@ import numpy as np
 from abc import ABC, abstractmethod
 
 from ._path_manager import make_fullpath
-from ._ML_callbacks import _Callback, History, TqdmProgressBar, DragonModelCheckpoint, DragonEarlyStopping, DragonLRScheduler
+from ._ML_callbacks import _Callback, History, TqdmProgressBar, DragonModelCheckpoint, _DragonEarlyStopping, _DragonLRScheduler
 from ._ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
 from ._ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
 from ._ML_vision_evaluation import segmentation_metrics, object_detection_metrics
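
In this hunk, the early-stopping and learning-rate-scheduler callback classes are re-imported under underscore-prefixed names, which by Python convention marks them as internal to the package. A purely illustrative before/after of the import (the same `._ML_callbacks` module shown above; nothing else about these classes is visible in this diff):

    # 19.10.0
    from ._ML_callbacks import DragonEarlyStopping, DragonLRScheduler
    # 19.12.0 -- same classes, now flagged as internal by the leading underscore
    from ._ML_callbacks import _DragonEarlyStopping, _DragonLRScheduler
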
@@ -66,8 +66,8 @@ class _BaseDragonTrainer(ABC):
                  device: Union[Literal['cuda', 'mps', 'cpu'],str],
                  dataloader_workers: int = 2,
                  checkpoint_callback: Optional[DragonModelCheckpoint] = None,
-                 early_stopping_callback: Optional[DragonEarlyStopping] = None,
-                 lr_scheduler_callback: Optional[DragonLRScheduler] = None,
+                 early_stopping_callback: Optional[_DragonEarlyStopping] = None,
+                 lr_scheduler_callback: Optional[_DragonLRScheduler] = None,
                  extra_callbacks: Optional[List[_Callback]] = None):
 
         self.model = model
@@ -271,18 +271,18 @@ class _BaseDragonTrainer(ABC):
         self.model.to(self.device)
         _LOGGER.info(f"Trainer and model moved to {self.device}.")
 
-    def _load_model_state_for_finalizing(self, model_checkpoint: Union[Path, Literal['latest', 'current']]):
+    def _load_model_state_for_finalizing(self, model_checkpoint: Union[Path, Literal['best', 'current']]):
         """
         Private helper to load the correct model state_dict based on user's choice.
         This is called by finalize_model_training() in subclasses.
         """
         if isinstance(model_checkpoint, Path):
             self._load_checkpoint(path=model_checkpoint)
-        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
             path_to_latest = self._checkpoint_callback.best_checkpoint_path
             self._load_checkpoint(path_to_latest)
-        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
-            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
             raise ValueError()
         elif model_checkpoint == MagicWords.CURRENT:
             pass
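
The `MagicWords` constants used above are defined elsewhere in the package and are not part of this diff. A minimal sketch of what the relevant members are assumed to look like, inferred from the `Literal['best', 'current']` type hints (hypothetical, for orientation only):

    class MagicWords:
        BEST = "best"        # replaces the former LATEST / "latest" keyword
        CURRENT = "current"  # keep the model state exactly as it is after training
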
@@ -336,8 +336,8 @@ class DragonTrainer(_BaseDragonTrainer):
                  optimizer: torch.optim.Optimizer,
                  device: Union[Literal['cuda', 'mps', 'cpu'],str],
                  checkpoint_callback: Optional[DragonModelCheckpoint],
-                 early_stopping_callback: Optional[DragonEarlyStopping],
-                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 early_stopping_callback: Optional[_DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[_DragonLRScheduler],
                  extra_callbacks: Optional[List[_Callback]] = None,
                  criterion: Union[nn.Module,Literal["auto"]] = "auto",
                  dataloader_workers: int = 2):
@@ -634,7 +634,7 @@ class DragonTrainer(_BaseDragonTrainer):
 
     def evaluate(self,
                  save_dir: Union[str, Path],
-                 model_checkpoint: Union[Path, Literal["latest", "current"]],
+                 model_checkpoint: Union[Path, Literal["best", "current"]],
                  classification_threshold: Optional[float] = None,
                  test_data: Optional[Union[DataLoader, Dataset]] = None,
                  val_format_configuration: Optional[Union[
@@ -665,7 +665,7 @@ class DragonTrainer(_BaseDragonTrainer):
         Args:
             model_checkpoint (Path | 'best' | 'current'):
             - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
-            - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+            - If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
             - If 'current', use the current state of the trained model up to the latest trained epoch.
             save_dir (str | Path): Directory to save all reports and plots.
             classification_threshold (float | None): Used for tasks using a binary approach (binary classification, binary segmentation, multilabel binary classification)
@@ -676,10 +676,10 @@ class DragonTrainer(_BaseDragonTrainer):
         # Validate model checkpoint
         if isinstance(model_checkpoint, Path):
             checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
-        elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+        elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
             checkpoint_validated = model_checkpoint
         else:
-            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
             raise ValueError()
 
         # Validate classification threshold
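
Taken together, these hunks rename the "latest" keyword to "best" throughout the public `evaluate()` API. A minimal caller-side sketch, assuming an already-fitted `DragonTrainer` instance named `trainer` (the directory and threshold are placeholder values, not part of this diff):

    from pathlib import Path

    trainer.evaluate(
        save_dir=Path("reports/eval"),
        model_checkpoint="best",       # was "latest" in 19.10.0
        classification_threshold=0.5,  # only relevant for binary-style tasks
    )
    # An explicit checkpoint file is still accepted:
    # trainer.evaluate(save_dir=Path("reports/eval"), model_checkpoint=Path("checkpoints/epoch_12.pth"))
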
@@ -778,7 +778,7 @@ class DragonTrainer(_BaseDragonTrainer):
 
     def _evaluate(self,
                   save_dir: Union[str, Path],
-                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  model_checkpoint: Union[Path, Literal["best", "current"]],
                   classification_threshold: float,
                   data: Optional[Union[DataLoader, Dataset]],
                   format_configuration: Optional[Union[
@@ -804,11 +804,11 @@ class DragonTrainer(_BaseDragonTrainer):
         # load model checkpoint
         if isinstance(model_checkpoint, Path):
             self._load_checkpoint(path=model_checkpoint)
-        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
             path_to_latest = self._checkpoint_callback.best_checkpoint_path
             self._load_checkpoint(path_to_latest)
-        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
-            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
             raise ValueError()
 
         # Dataloader
@@ -1352,7 +1352,7 @@ class DragonTrainer(_BaseDragonTrainer):
            _LOGGER.error("No attention weights were collected from the model.")
 
     def finalize_model_training(self,
-                                model_checkpoint: Union[Path, Literal['latest', 'current']],
+                                model_checkpoint: Union[Path, Literal['best', 'current']],
                                 save_dir: Union[str, Path],
                                 finalize_config: Union[FinalizeRegression,
                                                        FinalizeMultiTargetRegression,
@@ -1369,10 +1369,10 @@ class DragonTrainer(_BaseDragonTrainer):
         This method saves the model's `state_dict`, the final epoch number, and optional configuration for the task at hand.
 
         Args:
-            model_checkpoint (Path | "latest" | "current"):
+            model_checkpoint (Path | "best" | "current"):
             - Path: Loads the model state from a specific checkpoint file.
-            - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
-            - "current": Uses the model's state as it is at the end of the `fit()` call.
+            - "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+            - "current": Uses the model's state as it is.
             save_dir (str | Path): The directory to save the finalized model.
             finalize_config (object): A data class instance specific to the ML task containing task-specific metadata required for inference.
         """
@@ -1442,8 +1442,8 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
                  collate_fn: Callable, optimizer: torch.optim.Optimizer,
                  device: Union[Literal['cuda', 'mps', 'cpu'],str],
                  checkpoint_callback: Optional[DragonModelCheckpoint],
-                 early_stopping_callback: Optional[DragonEarlyStopping],
-                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 early_stopping_callback: Optional[_DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[_DragonLRScheduler],
                  extra_callbacks: Optional[List[_Callback]] = None,
                  dataloader_workers: int = 2):
         """
@@ -1601,7 +1601,7 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
 
     def evaluate(self,
                  save_dir: Union[str, Path],
-                 model_checkpoint: Union[Path, Literal["latest", "current"]],
+                 model_checkpoint: Union[Path, Literal["best", "current"]],
                  test_data: Optional[Union[DataLoader, Dataset]] = None):
         """
         Evaluates the model using object detection mAP metrics.
@@ -1610,17 +1610,17 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
             save_dir (str | Path): Directory to save all reports and plots.
             model_checkpoint (Path | 'best' | 'current'):
             - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
-            - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+            - If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
             - If 'current', use the current state of the trained model up to the latest trained epoch.
             test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
         """
         # Validate model checkpoint
         if isinstance(model_checkpoint, Path):
             checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
-        elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+        elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
             checkpoint_validated = model_checkpoint
         else:
-            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.LATEST}', or the string '{MagicWords.CURRENT}'.")
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
             raise ValueError()
 
         # Validate directory
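
The object-detection trainer follows the same pattern. A hypothetical caller-side sketch, assuming a fitted `DragonDetectionTrainer` named `det_trainer` (the directory name is a placeholder):

    det_trainer.evaluate(
        save_dir="reports/detection_eval",
        model_checkpoint="best",   # was "latest" in 19.10.0
        test_data=None,            # None falls back to the trainer's internal test dataset
    )
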
@@ -1656,7 +1656,7 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
 
     def _evaluate(self,
                   save_dir: Union[str, Path],
-                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  model_checkpoint: Union[Path, Literal["best", "current"]],
                   data: Optional[Union[DataLoader, Dataset]]):
         """
         Changed to a private helper method
@@ -1667,7 +1667,7 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
             data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
             model_checkpoint (Path | 'best' | 'current'):
             - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
-            - If 'latest', the latest checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+            - If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
             - If 'current', use the current state of the trained model up to the latest trained epoch.
         """
         dataset_for_artifacts = None
@@ -1676,11 +1676,11 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
         # load model checkpoint
         if isinstance(model_checkpoint, Path):
             self._load_checkpoint(path=model_checkpoint)
-        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
             path_to_latest = self._checkpoint_callback.best_checkpoint_path
             self._load_checkpoint(path_to_latest)
-        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
-            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
             raise ValueError()
 
         # Dataloader
@@ -1767,7 +1767,7 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
 
     def finalize_model_training(self,
                                 save_dir: Union[str, Path],
-                                model_checkpoint: Union[Path, Literal['latest', 'current']],
+                                model_checkpoint: Union[Path, Literal['best', 'current']],
                                 finalize_config: FinalizeObjectDetection
                                 ):
         """
@@ -1777,10 +1777,10 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
 
         Args:
             save_dir (Union[str, Path]): The directory to save the finalized model.
-            model_checkpoint (Union[Path, Literal["latest", "current"]]):
+            model_checkpoint (Union[Path, Literal["best", "current"]]):
             - Path: Loads the model state from a specific checkpoint file.
-            - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
-            - "current": Uses the model's state as it is at the end of the `fit()` call.
+            - "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+            - "current": Uses the model's state as it is.
             finalize_config (FinalizeObjectDetection): A data class instance specific to the ML task containing task-specific metadata required for inference.
         """
         if not isinstance(finalize_config, FinalizeObjectDetection):
@@ -1818,8 +1818,8 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
                  optimizer: torch.optim.Optimizer,
                  device: Union[Literal['cuda', 'mps', 'cpu'],str],
                  checkpoint_callback: Optional[DragonModelCheckpoint],
-                 early_stopping_callback: Optional[DragonEarlyStopping],
-                 lr_scheduler_callback: Optional[DragonLRScheduler],
+                 early_stopping_callback: Optional[_DragonEarlyStopping],
+                 lr_scheduler_callback: Optional[_DragonLRScheduler],
                  extra_callbacks: Optional[List[_Callback]] = None,
                  criterion: Union[nn.Module,Literal["auto"]] = "auto",
                  dataloader_workers: int = 2):
@@ -2036,7 +2036,7 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
 
     def evaluate(self,
                  save_dir: Union[str, Path],
-                 model_checkpoint: Union[Path, Literal["latest", "current"]],
+                 model_checkpoint: Union[Path, Literal["best", "current"]],
                  test_data: Optional[Union[DataLoader, Dataset]] = None,
                  val_format_configuration: Optional[Union[SequenceValueMetricsFormat,
                                                           SequenceSequenceMetricsFormat]]=None,
@@ -2048,7 +2048,7 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
         Args:
             model_checkpoint (Path | 'best' | 'current'):
             - Path to a valid checkpoint for the model.
-            - If 'latest', the latest checkpoint will be loaded.
+            - If 'best', the best checkpoint will be loaded.
             - If 'current', use the current state of the trained model.
             save_dir (str | Path): Directory to save all reports and plots.
             test_data (DataLoader | Dataset | None): Optional Test data.
@@ -2058,10 +2058,10 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
         # Validate model checkpoint
         if isinstance(model_checkpoint, Path):
             checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
-        elif model_checkpoint in [MagicWords.LATEST, MagicWords.CURRENT]:
+        elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
             checkpoint_validated = model_checkpoint
         else:
-            _LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.LATEST}', or '{MagicWords.CURRENT}'.")
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.BEST}', or '{MagicWords.CURRENT}'.")
             raise ValueError()
 
         # Validate val configuration
@@ -2120,7 +2120,7 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
 
     def _evaluate(self,
                   save_dir: Union[str, Path],
-                  model_checkpoint: Union[Path, Literal["latest", "current"]],
+                  model_checkpoint: Union[Path, Literal["best", "current"]],
                   data: Optional[Union[DataLoader, Dataset]],
                   format_configuration: object):
         """
@@ -2131,11 +2131,11 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
         # load model checkpoint
         if isinstance(model_checkpoint, Path):
             self._load_checkpoint(path=model_checkpoint)
-        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback:
+        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
             path_to_latest = self._checkpoint_callback.best_checkpoint_path
             self._load_checkpoint(path_to_latest)
-        elif model_checkpoint == MagicWords.LATEST and self._checkpoint_callback is None:
-            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.LATEST}' but no checkpoint callback was found.")
+        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
             raise ValueError()
 
         # Dataloader
@@ -2273,7 +2273,7 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
 
     def finalize_model_training(self,
                                 save_dir: Union[str, Path],
-                                model_checkpoint: Union[Path, Literal['latest', 'current']],
+                                model_checkpoint: Union[Path, Literal['best', 'current']],
                                 finalize_config: Union[FinalizeSequenceSequencePrediction, FinalizeSequenceValuePrediction]):
         """
         Saves a finalized, "inference-ready" model state to a .pth file.
@@ -2282,10 +2282,10 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
 
         Args:
             save_dir (Union[str, Path]): The directory to save the finalized model.
-            model_checkpoint (Union[Path, Literal["latest", "current"]]):
+            model_checkpoint (Union[Path, Literal["best", "current"]]):
             - Path: Loads the model state from a specific checkpoint file.
-            - "latest": Loads the best model state saved by the `DragonModelCheckpoint` callback.
-            - "current": Uses the model's state as it is at the end of the `fit()` call.
+            - "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+            - "current": Uses the model's state as it is.
             finalize_config (FinalizeSequenceSequencePrediction | FinalizeSequenceValuePrediction): A data class instance specific to the ML task containing task-specific metadata required for inference.
         """
         if self.kind == MLTaskKeys.SEQUENCE_SEQUENCE and not isinstance(finalize_config, FinalizeSequenceSequencePrediction):
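
Finally, the sequence trainer receives the same keyword rename in both `evaluate()` and `finalize_model_training()`. A hypothetical end-of-training sketch, assuming a fitted `DragonSequenceTrainer` named `seq_trainer` on a sequence-to-value task (directories and the elided config fields are placeholders):

    seq_trainer.evaluate(
        save_dir="reports/seq_eval",
        model_checkpoint="best",   # was "latest" in 19.10.0
    )

    seq_finalize_cfg = ...  # a FinalizeSequenceValuePrediction instance; fields not shown in this diff
    seq_trainer.finalize_model_training(
        save_dir="finalized_seq_model",
        model_checkpoint="best",
        finalize_config=seq_finalize_cfg,
    )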