dragon-ml-toolbox 19.11.0__py3-none-any.whl → 19.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.11.0.dist-info → dragon_ml_toolbox-19.12.1.dist-info}/METADATA +1 -1
- {dragon_ml_toolbox-19.11.0.dist-info → dragon_ml_toolbox-19.12.1.dist-info}/RECORD +16 -16
- ml_tools/ML_callbacks.py +8 -4
- ml_tools/_core/_IO_tools.py +8 -2
- ml_tools/_core/_ML_callbacks.py +461 -171
- ml_tools/_core/_ML_configuration.py +15 -6
- ml_tools/_core/_ML_finalize_handler.py +5 -4
- ml_tools/_core/_ML_trainer.py +50 -50
- ml_tools/_core/_keys.py +32 -1
- ml_tools/_core/_path_manager.py +111 -2
- ml_tools/keys.py +2 -0
- ml_tools/path_manager.py +5 -1
- {dragon_ml_toolbox-19.11.0.dist-info → dragon_ml_toolbox-19.12.1.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.11.0.dist-info → dragon_ml_toolbox-19.12.1.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.11.0.dist-info → dragon_ml_toolbox-19.12.1.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.11.0.dist-info → dragon_ml_toolbox-19.12.1.dist-info}/top_level.txt +0 -0
|
@@ -660,18 +660,27 @@ class DragonTrainingConfig(_BaseModelParams):
|
|
|
660
660
|
initial_learning_rate: float,
|
|
661
661
|
batch_size: int,
|
|
662
662
|
random_state: int = 101,
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
663
|
+
# early_stop_patience: Optional[int] = None,
|
|
664
|
+
# scheduler_patience: Optional[int] = None,
|
|
665
|
+
# scheduler_lr_factor: Optional[float] = None,
|
|
666
666
|
**kwargs: Any) -> None:
|
|
667
|
+
"""
|
|
668
|
+
Args:
|
|
669
|
+
validation_size (float): Proportion of data for validation set.
|
|
670
|
+
test_size (float): Proportion of data for test set.
|
|
671
|
+
initial_learning_rate (float): Starting learning rate.
|
|
672
|
+
batch_size (int): Number of samples per training batch.
|
|
673
|
+
random_state (int): Seed for reproducibility.
|
|
674
|
+
**kwargs: Additional training parameters as key-value pairs.
|
|
675
|
+
"""
|
|
667
676
|
self.validation_size = validation_size
|
|
668
677
|
self.test_size = test_size
|
|
669
678
|
self.initial_learning_rate = initial_learning_rate
|
|
670
679
|
self.batch_size = batch_size
|
|
671
680
|
self.random_state = random_state
|
|
672
|
-
self.early_stop_patience = early_stop_patience
|
|
673
|
-
self.scheduler_patience = scheduler_patience
|
|
674
|
-
self.scheduler_lr_factor = scheduler_lr_factor
|
|
681
|
+
# self.early_stop_patience = early_stop_patience
|
|
682
|
+
# self.scheduler_patience = scheduler_patience
|
|
683
|
+
# self.scheduler_lr_factor = scheduler_lr_factor
|
|
675
684
|
|
|
676
685
|
# Process kwargs with validation
|
|
677
686
|
for key, value in kwargs.items():
|
|
@@ -51,7 +51,7 @@ class FinalizedFileHandler:
|
|
|
51
51
|
self._initial_sequence: Optional[np.ndarray] = None
|
|
52
52
|
self._target_name: Optional[str] = None
|
|
53
53
|
self._target_names: Optional[list[str]] = None
|
|
54
|
-
self._model_state_dict: Optional[Any] = None
|
|
54
|
+
self._model_state_dict: Optional[dict[str, Any]] = None
|
|
55
55
|
|
|
56
56
|
# Set warning outputs
|
|
57
57
|
self._verbose: bool=True
|
|
@@ -90,7 +90,7 @@ class FinalizedFileHandler:
|
|
|
90
90
|
|
|
91
91
|
else:
|
|
92
92
|
# It is a dict, but missing the keys, assume it is the raw state dict
|
|
93
|
-
_LOGGER.
|
|
93
|
+
_LOGGER.warning(f"File '{pth_path.name}' does not have the required keys for a Dragon-ML finalized-file. Keys found:\n {list(pth_file_content.keys())}")
|
|
94
94
|
self._model_state_dict = pth_file_content
|
|
95
95
|
|
|
96
96
|
|
|
@@ -113,9 +113,10 @@ class FinalizedFileHandler:
|
|
|
113
113
|
return self._task
|
|
114
114
|
|
|
115
115
|
@property
|
|
116
|
-
def model_state_dict(self):
|
|
116
|
+
def model_state_dict(self) -> dict[str, Any]:
|
|
117
117
|
"""Returns the model state dictionary."""
|
|
118
|
-
|
|
118
|
+
# No need to check for None, as it is guaranteed to be set in __init__
|
|
119
|
+
return self._model_state_dict # type: ignore
|
|
119
120
|
|
|
120
121
|
@property
|
|
121
122
|
def epoch(self) -> Optional[int]:
|
ml_tools/_core/_ML_trainer.py
CHANGED
|
@@ -7,7 +7,7 @@ import numpy as np
|
|
|
7
7
|
from abc import ABC, abstractmethod
|
|
8
8
|
|
|
9
9
|
from ._path_manager import make_fullpath
|
|
10
|
-
from ._ML_callbacks import _Callback, History, TqdmProgressBar, DragonModelCheckpoint,
|
|
10
|
+
from ._ML_callbacks import _Callback, History, TqdmProgressBar, DragonModelCheckpoint, _DragonEarlyStopping, _DragonLRScheduler
|
|
11
11
|
from ._ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
|
|
12
12
|
from ._ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
|
|
13
13
|
from ._ML_vision_evaluation import segmentation_metrics, object_detection_metrics
|
|
@@ -66,8 +66,8 @@ class _BaseDragonTrainer(ABC):
|
|
|
66
66
|
device: Union[Literal['cuda', 'mps', 'cpu'],str],
|
|
67
67
|
dataloader_workers: int = 2,
|
|
68
68
|
checkpoint_callback: Optional[DragonModelCheckpoint] = None,
|
|
69
|
-
early_stopping_callback: Optional[
|
|
70
|
-
lr_scheduler_callback: Optional[
|
|
69
|
+
early_stopping_callback: Optional[_DragonEarlyStopping] = None,
|
|
70
|
+
lr_scheduler_callback: Optional[_DragonLRScheduler] = None,
|
|
71
71
|
extra_callbacks: Optional[List[_Callback]] = None):
|
|
72
72
|
|
|
73
73
|
self.model = model
|
|
@@ -271,18 +271,18 @@ class _BaseDragonTrainer(ABC):
|
|
|
271
271
|
self.model.to(self.device)
|
|
272
272
|
_LOGGER.info(f"Trainer and model moved to {self.device}.")
|
|
273
273
|
|
|
274
|
-
def _load_model_state_for_finalizing(self, model_checkpoint: Union[Path, Literal['
|
|
274
|
+
def _load_model_state_for_finalizing(self, model_checkpoint: Union[Path, Literal['best', 'current']]):
|
|
275
275
|
"""
|
|
276
276
|
Private helper to load the correct model state_dict based on user's choice.
|
|
277
277
|
This is called by finalize_model_training() in subclasses.
|
|
278
278
|
"""
|
|
279
279
|
if isinstance(model_checkpoint, Path):
|
|
280
280
|
self._load_checkpoint(path=model_checkpoint)
|
|
281
|
-
elif model_checkpoint == MagicWords.
|
|
281
|
+
elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
|
|
282
282
|
path_to_latest = self._checkpoint_callback.best_checkpoint_path
|
|
283
283
|
self._load_checkpoint(path_to_latest)
|
|
284
|
-
elif model_checkpoint == MagicWords.
|
|
285
|
-
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.
|
|
284
|
+
elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
|
|
285
|
+
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
|
|
286
286
|
raise ValueError()
|
|
287
287
|
elif model_checkpoint == MagicWords.CURRENT:
|
|
288
288
|
pass
|
|
@@ -336,8 +336,8 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
336
336
|
optimizer: torch.optim.Optimizer,
|
|
337
337
|
device: Union[Literal['cuda', 'mps', 'cpu'],str],
|
|
338
338
|
checkpoint_callback: Optional[DragonModelCheckpoint],
|
|
339
|
-
early_stopping_callback: Optional[
|
|
340
|
-
lr_scheduler_callback: Optional[
|
|
339
|
+
early_stopping_callback: Optional[_DragonEarlyStopping],
|
|
340
|
+
lr_scheduler_callback: Optional[_DragonLRScheduler],
|
|
341
341
|
extra_callbacks: Optional[List[_Callback]] = None,
|
|
342
342
|
criterion: Union[nn.Module,Literal["auto"]] = "auto",
|
|
343
343
|
dataloader_workers: int = 2):
|
|
@@ -634,7 +634,7 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
634
634
|
|
|
635
635
|
def evaluate(self,
|
|
636
636
|
save_dir: Union[str, Path],
|
|
637
|
-
model_checkpoint: Union[Path, Literal["
|
|
637
|
+
model_checkpoint: Union[Path, Literal["best", "current"]],
|
|
638
638
|
classification_threshold: Optional[float] = None,
|
|
639
639
|
test_data: Optional[Union[DataLoader, Dataset]] = None,
|
|
640
640
|
val_format_configuration: Optional[Union[
|
|
@@ -665,7 +665,7 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
665
665
|
Args:
|
|
666
666
|
model_checkpoint (Path | 'best' | 'current'):
|
|
667
667
|
- Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
|
|
668
|
-
- If '
|
|
668
|
+
- If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
|
|
669
669
|
- If 'current', use the current state of the trained model up to the latest trained epoch.
|
|
670
670
|
save_dir (str | Path): Directory to save all reports and plots.
|
|
671
671
|
classification_threshold (float | None): Used for tasks using a binary approach (binary classification, binary segmentation, multilabel binary classification)
|
|
@@ -676,10 +676,10 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
676
676
|
# Validate model checkpoint
|
|
677
677
|
if isinstance(model_checkpoint, Path):
|
|
678
678
|
checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
|
|
679
|
-
elif model_checkpoint in [MagicWords.
|
|
679
|
+
elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
|
|
680
680
|
checkpoint_validated = model_checkpoint
|
|
681
681
|
else:
|
|
682
|
-
_LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.
|
|
682
|
+
_LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
|
|
683
683
|
raise ValueError()
|
|
684
684
|
|
|
685
685
|
# Validate classification threshold
|
|
@@ -778,7 +778,7 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
778
778
|
|
|
779
779
|
def _evaluate(self,
|
|
780
780
|
save_dir: Union[str, Path],
|
|
781
|
-
model_checkpoint: Union[Path, Literal["
|
|
781
|
+
model_checkpoint: Union[Path, Literal["best", "current"]],
|
|
782
782
|
classification_threshold: float,
|
|
783
783
|
data: Optional[Union[DataLoader, Dataset]],
|
|
784
784
|
format_configuration: Optional[Union[
|
|
@@ -804,11 +804,11 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
804
804
|
# load model checkpoint
|
|
805
805
|
if isinstance(model_checkpoint, Path):
|
|
806
806
|
self._load_checkpoint(path=model_checkpoint)
|
|
807
|
-
elif model_checkpoint == MagicWords.
|
|
807
|
+
elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
|
|
808
808
|
path_to_latest = self._checkpoint_callback.best_checkpoint_path
|
|
809
809
|
self._load_checkpoint(path_to_latest)
|
|
810
|
-
elif model_checkpoint == MagicWords.
|
|
811
|
-
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.
|
|
810
|
+
elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
|
|
811
|
+
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
|
|
812
812
|
raise ValueError()
|
|
813
813
|
|
|
814
814
|
# Dataloader
|
|
@@ -1352,7 +1352,7 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
1352
1352
|
_LOGGER.error("No attention weights were collected from the model.")
|
|
1353
1353
|
|
|
1354
1354
|
def finalize_model_training(self,
|
|
1355
|
-
model_checkpoint: Union[Path, Literal['
|
|
1355
|
+
model_checkpoint: Union[Path, Literal['best', 'current']],
|
|
1356
1356
|
save_dir: Union[str, Path],
|
|
1357
1357
|
finalize_config: Union[FinalizeRegression,
|
|
1358
1358
|
FinalizeMultiTargetRegression,
|
|
@@ -1369,10 +1369,10 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
1369
1369
|
This method saves the model's `state_dict`, the final epoch number, and optional configuration for the task at hand.
|
|
1370
1370
|
|
|
1371
1371
|
Args:
|
|
1372
|
-
model_checkpoint (Path | "
|
|
1372
|
+
model_checkpoint (Path | "best" | "current"):
|
|
1373
1373
|
- Path: Loads the model state from a specific checkpoint file.
|
|
1374
|
-
- "
|
|
1375
|
-
- "current": Uses the model's state as it is
|
|
1374
|
+
- "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
|
|
1375
|
+
- "current": Uses the model's state as it is.
|
|
1376
1376
|
save_dir (str | Path): The directory to save the finalized model.
|
|
1377
1377
|
finalize_config (object): A data class instance specific to the ML task containing task-specific metadata required for inference.
|
|
1378
1378
|
"""
|
|
@@ -1442,8 +1442,8 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
|
|
|
1442
1442
|
collate_fn: Callable, optimizer: torch.optim.Optimizer,
|
|
1443
1443
|
device: Union[Literal['cuda', 'mps', 'cpu'],str],
|
|
1444
1444
|
checkpoint_callback: Optional[DragonModelCheckpoint],
|
|
1445
|
-
early_stopping_callback: Optional[
|
|
1446
|
-
lr_scheduler_callback: Optional[
|
|
1445
|
+
early_stopping_callback: Optional[_DragonEarlyStopping],
|
|
1446
|
+
lr_scheduler_callback: Optional[_DragonLRScheduler],
|
|
1447
1447
|
extra_callbacks: Optional[List[_Callback]] = None,
|
|
1448
1448
|
dataloader_workers: int = 2):
|
|
1449
1449
|
"""
|
|
@@ -1601,7 +1601,7 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
|
|
|
1601
1601
|
|
|
1602
1602
|
def evaluate(self,
|
|
1603
1603
|
save_dir: Union[str, Path],
|
|
1604
|
-
model_checkpoint: Union[Path, Literal["
|
|
1604
|
+
model_checkpoint: Union[Path, Literal["best", "current"]],
|
|
1605
1605
|
test_data: Optional[Union[DataLoader, Dataset]] = None):
|
|
1606
1606
|
"""
|
|
1607
1607
|
Evaluates the model using object detection mAP metrics.
|
|
@@ -1610,17 +1610,17 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
|
|
|
1610
1610
|
save_dir (str | Path): Directory to save all reports and plots.
|
|
1611
1611
|
model_checkpoint (Path | 'best' | 'current'):
|
|
1612
1612
|
- Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
|
|
1613
|
-
- If '
|
|
1613
|
+
- If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
|
|
1614
1614
|
- If 'current', use the current state of the trained model up to the latest trained epoch.
|
|
1615
1615
|
test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
|
|
1616
1616
|
"""
|
|
1617
1617
|
# Validate model checkpoint
|
|
1618
1618
|
if isinstance(model_checkpoint, Path):
|
|
1619
1619
|
checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
|
|
1620
|
-
elif model_checkpoint in [MagicWords.
|
|
1620
|
+
elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
|
|
1621
1621
|
checkpoint_validated = model_checkpoint
|
|
1622
1622
|
else:
|
|
1623
|
-
_LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.
|
|
1623
|
+
_LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
|
|
1624
1624
|
raise ValueError()
|
|
1625
1625
|
|
|
1626
1626
|
# Validate directory
|
|
@@ -1656,7 +1656,7 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
|
|
|
1656
1656
|
|
|
1657
1657
|
def _evaluate(self,
|
|
1658
1658
|
save_dir: Union[str, Path],
|
|
1659
|
-
model_checkpoint: Union[Path, Literal["
|
|
1659
|
+
model_checkpoint: Union[Path, Literal["best", "current"]],
|
|
1660
1660
|
data: Optional[Union[DataLoader, Dataset]]):
|
|
1661
1661
|
"""
|
|
1662
1662
|
Changed to a private helper method
|
|
@@ -1667,7 +1667,7 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
|
|
|
1667
1667
|
data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
|
|
1668
1668
|
model_checkpoint (Path | 'best' | 'current'):
|
|
1669
1669
|
- Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
|
|
1670
|
-
- If '
|
|
1670
|
+
- If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
|
|
1671
1671
|
- If 'current', use the current state of the trained model up to the latest trained epoch.
|
|
1672
1672
|
"""
|
|
1673
1673
|
dataset_for_artifacts = None
|
|
@@ -1676,11 +1676,11 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
|
|
|
1676
1676
|
# load model checkpoint
|
|
1677
1677
|
if isinstance(model_checkpoint, Path):
|
|
1678
1678
|
self._load_checkpoint(path=model_checkpoint)
|
|
1679
|
-
elif model_checkpoint == MagicWords.
|
|
1679
|
+
elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
|
|
1680
1680
|
path_to_latest = self._checkpoint_callback.best_checkpoint_path
|
|
1681
1681
|
self._load_checkpoint(path_to_latest)
|
|
1682
|
-
elif model_checkpoint == MagicWords.
|
|
1683
|
-
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.
|
|
1682
|
+
elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
|
|
1683
|
+
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
|
|
1684
1684
|
raise ValueError()
|
|
1685
1685
|
|
|
1686
1686
|
# Dataloader
|
|
@@ -1767,7 +1767,7 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
|
|
|
1767
1767
|
|
|
1768
1768
|
def finalize_model_training(self,
|
|
1769
1769
|
save_dir: Union[str, Path],
|
|
1770
|
-
model_checkpoint: Union[Path, Literal['
|
|
1770
|
+
model_checkpoint: Union[Path, Literal['best', 'current']],
|
|
1771
1771
|
finalize_config: FinalizeObjectDetection
|
|
1772
1772
|
):
|
|
1773
1773
|
"""
|
|
@@ -1777,10 +1777,10 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
|
|
|
1777
1777
|
|
|
1778
1778
|
Args:
|
|
1779
1779
|
save_dir (Union[str, Path]): The directory to save the finalized model.
|
|
1780
|
-
model_checkpoint (Union[Path, Literal["
|
|
1780
|
+
model_checkpoint (Union[Path, Literal["best", "current"]]):
|
|
1781
1781
|
- Path: Loads the model state from a specific checkpoint file.
|
|
1782
|
-
- "
|
|
1783
|
-
- "current": Uses the model's state as it is
|
|
1782
|
+
- "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
|
|
1783
|
+
- "current": Uses the model's state as it is.
|
|
1784
1784
|
finalize_config (FinalizeObjectDetection): A data class instance specific to the ML task containing task-specific metadata required for inference.
|
|
1785
1785
|
"""
|
|
1786
1786
|
if not isinstance(finalize_config, FinalizeObjectDetection):
|
|
@@ -1818,8 +1818,8 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
1818
1818
|
optimizer: torch.optim.Optimizer,
|
|
1819
1819
|
device: Union[Literal['cuda', 'mps', 'cpu'],str],
|
|
1820
1820
|
checkpoint_callback: Optional[DragonModelCheckpoint],
|
|
1821
|
-
early_stopping_callback: Optional[
|
|
1822
|
-
lr_scheduler_callback: Optional[
|
|
1821
|
+
early_stopping_callback: Optional[_DragonEarlyStopping],
|
|
1822
|
+
lr_scheduler_callback: Optional[_DragonLRScheduler],
|
|
1823
1823
|
extra_callbacks: Optional[List[_Callback]] = None,
|
|
1824
1824
|
criterion: Union[nn.Module,Literal["auto"]] = "auto",
|
|
1825
1825
|
dataloader_workers: int = 2):
|
|
@@ -2036,7 +2036,7 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
2036
2036
|
|
|
2037
2037
|
def evaluate(self,
|
|
2038
2038
|
save_dir: Union[str, Path],
|
|
2039
|
-
model_checkpoint: Union[Path, Literal["
|
|
2039
|
+
model_checkpoint: Union[Path, Literal["best", "current"]],
|
|
2040
2040
|
test_data: Optional[Union[DataLoader, Dataset]] = None,
|
|
2041
2041
|
val_format_configuration: Optional[Union[SequenceValueMetricsFormat,
|
|
2042
2042
|
SequenceSequenceMetricsFormat]]=None,
|
|
@@ -2048,7 +2048,7 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
2048
2048
|
Args:
|
|
2049
2049
|
model_checkpoint (Path | 'best' | 'current'):
|
|
2050
2050
|
- Path to a valid checkpoint for the model.
|
|
2051
|
-
- If '
|
|
2051
|
+
- If 'best', the best checkpoint will be loaded.
|
|
2052
2052
|
- If 'current', use the current state of the trained model.
|
|
2053
2053
|
save_dir (str | Path): Directory to save all reports and plots.
|
|
2054
2054
|
test_data (DataLoader | Dataset | None): Optional Test data.
|
|
@@ -2058,10 +2058,10 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
2058
2058
|
# Validate model checkpoint
|
|
2059
2059
|
if isinstance(model_checkpoint, Path):
|
|
2060
2060
|
checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
|
|
2061
|
-
elif model_checkpoint in [MagicWords.
|
|
2061
|
+
elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
|
|
2062
2062
|
checkpoint_validated = model_checkpoint
|
|
2063
2063
|
else:
|
|
2064
|
-
_LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.
|
|
2064
|
+
_LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.BEST}', or '{MagicWords.CURRENT}'.")
|
|
2065
2065
|
raise ValueError()
|
|
2066
2066
|
|
|
2067
2067
|
# Validate val configuration
|
|
@@ -2120,7 +2120,7 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
2120
2120
|
|
|
2121
2121
|
def _evaluate(self,
|
|
2122
2122
|
save_dir: Union[str, Path],
|
|
2123
|
-
model_checkpoint: Union[Path, Literal["
|
|
2123
|
+
model_checkpoint: Union[Path, Literal["best", "current"]],
|
|
2124
2124
|
data: Optional[Union[DataLoader, Dataset]],
|
|
2125
2125
|
format_configuration: object):
|
|
2126
2126
|
"""
|
|
@@ -2131,11 +2131,11 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
2131
2131
|
# load model checkpoint
|
|
2132
2132
|
if isinstance(model_checkpoint, Path):
|
|
2133
2133
|
self._load_checkpoint(path=model_checkpoint)
|
|
2134
|
-
elif model_checkpoint == MagicWords.
|
|
2134
|
+
elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
|
|
2135
2135
|
path_to_latest = self._checkpoint_callback.best_checkpoint_path
|
|
2136
2136
|
self._load_checkpoint(path_to_latest)
|
|
2137
|
-
elif model_checkpoint == MagicWords.
|
|
2138
|
-
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.
|
|
2137
|
+
elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
|
|
2138
|
+
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
|
|
2139
2139
|
raise ValueError()
|
|
2140
2140
|
|
|
2141
2141
|
# Dataloader
|
|
@@ -2273,7 +2273,7 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
2273
2273
|
|
|
2274
2274
|
def finalize_model_training(self,
|
|
2275
2275
|
save_dir: Union[str, Path],
|
|
2276
|
-
model_checkpoint: Union[Path, Literal['
|
|
2276
|
+
model_checkpoint: Union[Path, Literal['best', 'current']],
|
|
2277
2277
|
finalize_config: Union[FinalizeSequenceSequencePrediction, FinalizeSequenceValuePrediction]):
|
|
2278
2278
|
"""
|
|
2279
2279
|
Saves a finalized, "inference-ready" model state to a .pth file.
|
|
@@ -2282,10 +2282,10 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
2282
2282
|
|
|
2283
2283
|
Args:
|
|
2284
2284
|
save_dir (Union[str, Path]): The directory to save the finalized model.
|
|
2285
|
-
model_checkpoint (Union[Path, Literal["
|
|
2285
|
+
model_checkpoint (Union[Path, Literal["best", "current"]]):
|
|
2286
2286
|
- Path: Loads the model state from a specific checkpoint file.
|
|
2287
|
-
- "
|
|
2288
|
-
- "current": Uses the model's state as it is
|
|
2287
|
+
- "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
|
|
2288
|
+
- "current": Uses the model's state as it is.
|
|
2289
2289
|
finalize_config (FinalizeSequenceSequencePrediction | FinalizeSequenceValuePrediction): A data class instance specific to the ML task containing task-specific metadata required for inference.
|
|
2290
2290
|
"""
|
|
2291
2291
|
if self.kind == MLTaskKeys.SEQUENCE_SEQUENCE and not isinstance(finalize_config, FinalizeSequenceSequencePrediction):
|
ml_tools/_core/_keys.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
class MagicWords:
|
|
2
2
|
"""General purpose keys"""
|
|
3
|
-
|
|
3
|
+
BEST = "best"
|
|
4
4
|
CURRENT = "current"
|
|
5
5
|
RENAME = "rename"
|
|
6
6
|
UNKNOWN = "unknown"
|
|
@@ -200,6 +200,37 @@ class MLTaskKeys:
|
|
|
200
200
|
ALL_BINARY_TASKS = [BINARY_CLASSIFICATION, MULTILABEL_BINARY_CLASSIFICATION, BINARY_IMAGE_CLASSIFICATION, BINARY_SEGMENTATION]
|
|
201
201
|
|
|
202
202
|
|
|
203
|
+
class _PublicTaskKeys:
|
|
204
|
+
"""
|
|
205
|
+
Task keys used in the Dragon ML pipeline:
|
|
206
|
+
|
|
207
|
+
1. REGRESSION
|
|
208
|
+
2. MULTITARGET_REGRESSION
|
|
209
|
+
3. BINARY_CLASSIFICATION
|
|
210
|
+
4. MULTICLASS_CLASSIFICATION
|
|
211
|
+
5. MULTILABEL_BINARY_CLASSIFICATION
|
|
212
|
+
6. BINARY_IMAGE_CLASSIFICATION
|
|
213
|
+
7. MULTICLASS_IMAGE_CLASSIFICATION
|
|
214
|
+
8. BINARY_SEGMENTATION
|
|
215
|
+
9. MULTICLASS_SEGMENTATION
|
|
216
|
+
10. OBJECT_DETECTION
|
|
217
|
+
11. SEQUENCE_SEQUENCE
|
|
218
|
+
12. SEQUENCE_VALUE
|
|
219
|
+
"""
|
|
220
|
+
REGRESSION = MLTaskKeys.REGRESSION
|
|
221
|
+
MULTITARGET_REGRESSION = MLTaskKeys.MULTITARGET_REGRESSION
|
|
222
|
+
BINARY_CLASSIFICATION = MLTaskKeys.BINARY_CLASSIFICATION
|
|
223
|
+
MULTICLASS_CLASSIFICATION = MLTaskKeys.MULTICLASS_CLASSIFICATION
|
|
224
|
+
MULTILABEL_BINARY_CLASSIFICATION = MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION
|
|
225
|
+
BINARY_IMAGE_CLASSIFICATION = MLTaskKeys.BINARY_IMAGE_CLASSIFICATION
|
|
226
|
+
MULTICLASS_IMAGE_CLASSIFICATION = MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
|
|
227
|
+
BINARY_SEGMENTATION = MLTaskKeys.BINARY_SEGMENTATION
|
|
228
|
+
MULTICLASS_SEGMENTATION = MLTaskKeys.MULTICLASS_SEGMENTATION
|
|
229
|
+
OBJECT_DETECTION = MLTaskKeys.OBJECT_DETECTION
|
|
230
|
+
SEQUENCE_SEQUENCE = MLTaskKeys.SEQUENCE_SEQUENCE
|
|
231
|
+
SEQUENCE_VALUE = MLTaskKeys.SEQUENCE_VALUE
|
|
232
|
+
|
|
233
|
+
|
|
203
234
|
class DragonTrainerKeys:
|
|
204
235
|
VALIDATION_METRICS_DIR = "Validation_Metrics"
|
|
205
236
|
TEST_METRICS_DIR = "Test_Metrics"
|
ml_tools/_core/_path_manager.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
from pprint import pprint
|
|
2
1
|
from typing import Optional, List, Dict, Union, Literal
|
|
3
2
|
from pathlib import Path
|
|
4
3
|
import re
|
|
5
4
|
import sys
|
|
5
|
+
import shutil
|
|
6
6
|
|
|
7
7
|
from ._script_info import _script_info
|
|
8
8
|
from ._logger import get_logger
|
|
@@ -17,7 +17,9 @@ __all__ = [
|
|
|
17
17
|
"sanitize_filename",
|
|
18
18
|
"list_csv_paths",
|
|
19
19
|
"list_files_by_extension",
|
|
20
|
-
"list_subdirectories"
|
|
20
|
+
"list_subdirectories",
|
|
21
|
+
"clean_directory",
|
|
22
|
+
"safe_move",
|
|
21
23
|
]
|
|
22
24
|
|
|
23
25
|
|
|
@@ -542,5 +544,112 @@ def list_subdirectories(
|
|
|
542
544
|
return dir_map
|
|
543
545
|
|
|
544
546
|
|
|
547
|
+
def clean_directory(directory: Union[str, Path], verbose: bool = False) -> None:
|
|
548
|
+
"""
|
|
549
|
+
⚠️ DANGER: DESTRUCTIVE OPERATION ⚠️
|
|
550
|
+
|
|
551
|
+
Deletes all files and subdirectories inside the specified directory. It is designed to empty a folder, not delete the folder itself.
|
|
552
|
+
|
|
553
|
+
Safety: It skips hidden files and directories (those starting with a period '.'). This works for macOS/Linux hidden files and dot-config folders on Windows.
|
|
554
|
+
|
|
555
|
+
Args:
|
|
556
|
+
directory (str | Path): The directory path to clean.
|
|
557
|
+
verbose (bool): If True, prints the name of each top-level item deleted.
|
|
558
|
+
"""
|
|
559
|
+
target_dir = make_fullpath(directory, enforce="directory")
|
|
560
|
+
|
|
561
|
+
if verbose:
|
|
562
|
+
_LOGGER.warning(f"Starting cleanup of directory: {target_dir}")
|
|
563
|
+
|
|
564
|
+
for item in target_dir.iterdir():
|
|
565
|
+
# Safety Check: Skip hidden files/dirs
|
|
566
|
+
if item.name.startswith("."):
|
|
567
|
+
continue
|
|
568
|
+
|
|
569
|
+
try:
|
|
570
|
+
if item.is_file() or item.is_symlink():
|
|
571
|
+
item.unlink()
|
|
572
|
+
if verbose:
|
|
573
|
+
print(f" 🗑️ Deleted file: {item.name}")
|
|
574
|
+
elif item.is_dir():
|
|
575
|
+
shutil.rmtree(item)
|
|
576
|
+
if verbose:
|
|
577
|
+
print(f" 🗑️ Deleted directory: {item.name}")
|
|
578
|
+
except Exception as e:
|
|
579
|
+
_LOGGER.warning(f"Failed to delete item '{item.name}': {e}")
|
|
580
|
+
continue
|
|
581
|
+
|
|
582
|
+
|
|
583
|
+
def safe_move(
|
|
584
|
+
source: Union[str, Path],
|
|
585
|
+
final_destination: Union[str, Path],
|
|
586
|
+
rename: Optional[str] = None,
|
|
587
|
+
overwrite: bool = False
|
|
588
|
+
) -> Path:
|
|
589
|
+
"""
|
|
590
|
+
Moves a file or directory to a destination directory with safety checks.
|
|
591
|
+
|
|
592
|
+
Features:
|
|
593
|
+
- Supports optional renaming (sanitized automatically).
|
|
594
|
+
- PRESERVES file extensions during renaming (cannot be modified).
|
|
595
|
+
- Prevents accidental overwrites unless explicit.
|
|
596
|
+
|
|
597
|
+
Args:
|
|
598
|
+
source (str | Path): The file or directory to move.
|
|
599
|
+
final_destination (str | Path): The destination DIRECTORY where the item will be moved. It will be created if it does not exist.
|
|
600
|
+
rename (Optional[str]): If provided, the moved item will be renamed to this. Note: For files, the extension is strictly preserved.
|
|
601
|
+
overwrite (bool): If True, overwrites the destination path if it exists.
|
|
602
|
+
|
|
603
|
+
Returns:
|
|
604
|
+
Path: The new absolute path of the moved item.
|
|
605
|
+
"""
|
|
606
|
+
# 1. Validation and Setup
|
|
607
|
+
src_path = make_fullpath(source, make=False)
|
|
608
|
+
|
|
609
|
+
# Ensure destination directory exists
|
|
610
|
+
dest_dir_path = make_fullpath(final_destination, make=True, enforce="directory")
|
|
611
|
+
|
|
612
|
+
# 2. Determine Target Name
|
|
613
|
+
if rename:
|
|
614
|
+
sanitized_name = sanitize_filename(rename)
|
|
615
|
+
if src_path.is_file():
|
|
616
|
+
# Strict Extension Preservation
|
|
617
|
+
final_name = f"{sanitized_name}{src_path.suffix}"
|
|
618
|
+
else:
|
|
619
|
+
final_name = sanitized_name
|
|
620
|
+
else:
|
|
621
|
+
final_name = src_path.name
|
|
622
|
+
|
|
623
|
+
final_path = dest_dir_path / final_name
|
|
624
|
+
|
|
625
|
+
# 3. Safety Checks (Collision Detection)
|
|
626
|
+
if final_path.exists():
|
|
627
|
+
if not overwrite:
|
|
628
|
+
_LOGGER.error(f"Destination already exists: '{final_path}'. Use overwrite=True to force.")
|
|
629
|
+
raise FileExistsError()
|
|
630
|
+
|
|
631
|
+
# Smart Overwrite Handling
|
|
632
|
+
if final_path.is_dir():
|
|
633
|
+
if src_path.is_file():
|
|
634
|
+
_LOGGER.error(f"Cannot overwrite directory '{final_path}' with file '{src_path}'")
|
|
635
|
+
raise IsADirectoryError()
|
|
636
|
+
# If overwriting a directory, we must remove the old one first to avoid nesting/errors
|
|
637
|
+
shutil.rmtree(final_path)
|
|
638
|
+
else:
|
|
639
|
+
# Destination is a file
|
|
640
|
+
if src_path.is_dir():
|
|
641
|
+
_LOGGER.error(f"Cannot overwrite file '{final_path}' with directory '{src_path}'")
|
|
642
|
+
raise FileExistsError()
|
|
643
|
+
final_path.unlink()
|
|
644
|
+
|
|
645
|
+
# 4. Perform Move
|
|
646
|
+
try:
|
|
647
|
+
shutil.move(str(src_path), str(final_path))
|
|
648
|
+
return final_path
|
|
649
|
+
except Exception as e:
|
|
650
|
+
_LOGGER.exception(f"Failed to move '{src_path}' to '{final_path}'")
|
|
651
|
+
raise e
|
|
652
|
+
|
|
653
|
+
|
|
545
654
|
def info():
|
|
546
655
|
_script_info(__all__)
|
ml_tools/keys.py
CHANGED
|
@@ -2,10 +2,12 @@ from ._core._keys import (
|
|
|
2
2
|
PyTorchInferenceKeys as InferenceKeys,
|
|
3
3
|
_CheckpointCallbackKeys as CheckpointCallbackKeys,
|
|
4
4
|
_FinalizedFileKeys as FinalizedFileKeys,
|
|
5
|
+
_PublicTaskKeys as TaskKeys,
|
|
5
6
|
)
|
|
6
7
|
|
|
7
8
|
__all__ = [
|
|
8
9
|
"InferenceKeys",
|
|
9
10
|
"CheckpointCallbackKeys",
|
|
10
11
|
"FinalizedFileKeys",
|
|
12
|
+
"TaskKeys",
|
|
11
13
|
]
|
ml_tools/path_manager.py
CHANGED
|
@@ -5,6 +5,8 @@ from ._core._path_manager import (
|
|
|
5
5
|
list_csv_paths,
|
|
6
6
|
list_files_by_extension,
|
|
7
7
|
list_subdirectories,
|
|
8
|
+
clean_directory,
|
|
9
|
+
safe_move,
|
|
8
10
|
info
|
|
9
11
|
)
|
|
10
12
|
|
|
@@ -14,5 +16,7 @@ __all__ = [
|
|
|
14
16
|
"sanitize_filename",
|
|
15
17
|
"list_csv_paths",
|
|
16
18
|
"list_files_by_extension",
|
|
17
|
-
"list_subdirectories"
|
|
19
|
+
"list_subdirectories",
|
|
20
|
+
"clean_directory",
|
|
21
|
+
"safe_move",
|
|
18
22
|
]
|
|
File without changes
|
{dragon_ml_toolbox-19.11.0.dist-info → dragon_ml_toolbox-19.12.1.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|