dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1901
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
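The pattern running through this release is structural: each flat `ml_tools/<name>.py` module becomes an `ml_tools/<name>/` package whose `__init__.py` takes over the old file (see the rename entries above), with the implementation split into private `_*.py` submodules (plus a small `_imprimir.py` per package) and the old `_core/_*.py` internals removed. Assuming each new `__init__.py` re-exports the package's previous public API — the small deltas on the rename entries suggest it does — import paths written against 19.x keep working. A minimal sketch of the pattern, using an illustrative symbol name:

```python
# Hedged sketch of the 19.x -> 20.0.0 layout change. `save_dataframe` is an
# illustrative symbol name only; the re-exported names are not shown in this diff.

# 19.13.0: flat module
#   ml_tools/utilities.py
#
# 20.0.0: package under the same import path
#   ml_tools/utilities/__init__.py            (assumed to re-export the public API)
#   ml_tools/utilities/_utility_save_load.py
#   ml_tools/utilities/_utility_tools.py
#   ml_tools/utilities/_imprimir.py

from ml_tools.utilities import save_dataframe  # unchanged for callers, if __init__ re-exports
```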
ml_tools/ML_trainer/_base_trainer.py

```diff
@@ -0,0 +1,297 @@
+from typing import Literal, Union, Optional, Any
+from pathlib import Path
+from torch.utils.data import DataLoader
+import torch
+from torch import nn
+from abc import ABC, abstractmethod
+
+from ..ML_callbacks._base import _Callback, History, TqdmProgressBar
+from ..ML_callbacks._checkpoint import DragonModelCheckpoint
+from ..ML_callbacks._early_stop import _DragonEarlyStopping
+from ..ML_callbacks._scheduler import _DragonLRScheduler
+from ..ML_evaluation import plot_losses
+
+from ..path_manager import make_fullpath
+from ..keys._keys import PyTorchCheckpointKeys, MagicWords
+from .._core import get_logger
+
+
+_LOGGER = get_logger("DragonTrainer")
+
+
+__all__ = [
+    "_BaseDragonTrainer",
+]
+
+
+class _BaseDragonTrainer(ABC):
+    """
+    Abstract base class for Dragon Trainers.
+
+    Handles the common training loop orchestration, checkpointing, callback
+    management, and device handling. Subclasses must implement the
+    task-specific logic (dataloaders, train/val steps, evaluation).
+    """
+    def __init__(self,
+                 model: nn.Module,
+                 optimizer: torch.optim.Optimizer,
+                 device: Union[Literal['cuda', 'mps', 'cpu'], str],
+                 dataloader_workers: int = 2,
+                 checkpoint_callback: Optional[DragonModelCheckpoint] = None,
+                 early_stopping_callback: Optional[_DragonEarlyStopping] = None,
+                 lr_scheduler_callback: Optional[_DragonLRScheduler] = None,
+                 extra_callbacks: Optional[list[_Callback]] = None):
+
+        self.model = model
+        self.optimizer = optimizer
+        self.scheduler = None
+        self.device = self._validate_device(device)
+        self.dataloader_workers = dataloader_workers
+
+        # Callback handler
+        default_callbacks = [History(), TqdmProgressBar()]
+
+        self._checkpoint_callback = None
+        if checkpoint_callback:
+            default_callbacks.append(checkpoint_callback)
+            self._checkpoint_callback = checkpoint_callback
+        if early_stopping_callback:
+            default_callbacks.append(early_stopping_callback)
+        if lr_scheduler_callback:
+            default_callbacks.append(lr_scheduler_callback)
+
+        user_callbacks = extra_callbacks if extra_callbacks is not None else []
+        self.callbacks = default_callbacks + user_callbacks
+        self._set_trainer_on_callbacks()
+
+        # Internal state
+        self.train_loader: Optional[DataLoader] = None
+        self.validation_loader: Optional[DataLoader] = None
+        self.history: dict[str, list[Any]] = {}
+        self.epoch = 0
+        self.epochs = 0  # Total epochs for the fit run
+        self.start_epoch = 1
+        self.stop_training = False
+        self._batch_size = 10
+
+    def _validate_device(self, device: str) -> torch.device:
+        """Validates the selected device and returns a torch.device object."""
+        device_lower = device.lower()
+        if "cuda" in device_lower and not torch.cuda.is_available():
+            _LOGGER.warning("CUDA not available, switching to CPU.")
+            device = "cpu"
+        elif device_lower == "mps" and not torch.backends.mps.is_available():
+            _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
+            device = "cpu"
+        return torch.device(device)
+
+    def _set_trainer_on_callbacks(self):
+        """Gives each callback a reference to this trainer instance."""
+        for callback in self.callbacks:
+            callback.set_trainer(self)
+
+    def _load_checkpoint(self, path: Union[str, Path]):
+        """Loads a training checkpoint to resume training."""
+        p = make_fullpath(path, enforce="file")
+        _LOGGER.info(f"Loading checkpoint from '{p.name}'...")
+
+        try:
+            checkpoint = torch.load(p, map_location=self.device)
+
+            if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
+                _LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
+                raise KeyError()
+
+            self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
+            self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
+            self.epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0)
+            self.start_epoch = self.epoch + 1  # Resume on the *next* epoch
+
+            # --- Load History ---
+            if PyTorchCheckpointKeys.HISTORY in checkpoint:
+                self.history = checkpoint[PyTorchCheckpointKeys.HISTORY]
+                _LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
+            else:
+                _LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
+                self.history = {}  # Ensure it's at least an empty dict
+
+            # --- Scheduler State Loading Logic ---
+            scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
+            scheduler_object_exists = self.scheduler is not None
+
+            if scheduler_object_exists and scheduler_state_exists:
+                # Case 1: Both exist. Attempt to load.
+                try:
+                    self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE])  # type: ignore
+                    scheduler_name = self.scheduler.__class__.__name__
+                    _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
+                except Exception as e:
+                    # Loading failed, likely a mismatch
+                    scheduler_name = self.scheduler.__class__.__name__
+                    _LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
+                    raise e
+
+            elif scheduler_object_exists and not scheduler_state_exists:
+                # Case 2: Scheduler provided, but no state in checkpoint.
+                scheduler_name = self.scheduler.__class__.__name__
+                _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
+
+            elif not scheduler_object_exists and scheduler_state_exists:
+                # Case 3: State in checkpoint, but no scheduler provided.
+                _LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
+                raise ValueError()
+
+            # Restore callback states
+            for cb in self.callbacks:
+                if isinstance(cb, DragonModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
+                    cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
+                    _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
+
+            _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
+
+        except Exception as e:
+            _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
+            raise
+
+    def fit(self,
+            save_dir: Union[str, Path],
+            epochs: int = 100,
+            batch_size: int = 10,
+            shuffle: bool = True,
+            resume_from_checkpoint: Optional[Union[str, Path]] = None):
+        """
+        Starts the training-validation process of the model.
+
+        Returns the "History" callback dictionary.
+
+        Args:
+            save_dir (str | Path): Directory to save the loss plot.
+            epochs (int): The total number of epochs to train for.
+            batch_size (int): The number of samples per batch.
+            shuffle (bool): Whether to shuffle the training data at each epoch.
+            resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
+        """
+        self.epochs = epochs
+        self._batch_size = batch_size
+        self._create_dataloaders(self._batch_size, shuffle)  # type: ignore
+        self.model.to(self.device)
+
+        if resume_from_checkpoint:
+            self._load_checkpoint(resume_from_checkpoint)
+
+        # Reset stop_training flag on the trainer
+        self.stop_training = False
+
+        self._callbacks_hook('on_train_begin')
+
+        if not self.train_loader:
+            _LOGGER.error("Train loader is not initialized.")
+            raise ValueError()
+
+        if not self.validation_loader:
+            _LOGGER.error("Validation loader is not initialized.")
+            raise ValueError()
+
+        for epoch in range(self.start_epoch, self.epochs + 1):
+            self.epoch = epoch
+            epoch_logs: dict[str, Any] = {}
+            self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
+
+            train_logs = self._train_step()
+            epoch_logs.update(train_logs)
+
+            val_logs = self._validation_step()
+            epoch_logs.update(val_logs)
+
+            self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
+
+            # Check the early stopping flag
+            if self.stop_training:
+                break
+
+        self._callbacks_hook('on_train_end')
+
+        # Training History
+        plot_losses(self.history, save_dir=save_dir)
+
+        return self.history
+
+    def _callbacks_hook(self, method_name: str, *args, **kwargs):
+        """Calls the specified method on all callbacks."""
+        for callback in self.callbacks:
+            method = getattr(callback, method_name)
+            method(*args, **kwargs)
+
+    def to_cpu(self):
+        """
+        Moves the model to the CPU and updates the trainer's device setting.
+
+        This is useful for running operations that require the CPU.
+        """
+        self.device = torch.device('cpu')
+        self.model.to(self.device)
+        _LOGGER.info("Trainer and model moved to CPU.")
+
+    def to_device(self, device: str):
+        """
+        Moves the model to the specified device and updates the trainer's device setting.
+
+        Args:
+            device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+        """
+        self.device = self._validate_device(device)
+        self.model.to(self.device)
+        _LOGGER.info(f"Trainer and model moved to {self.device}.")
+
+    def _load_model_state_for_finalizing(self, model_checkpoint: Union[Path, Literal['best', 'current']]):
+        """
+        Private helper to load the correct model state_dict based on user's choice.
+        This is called by finalize_model_training() in subclasses.
+        """
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
+            raise ValueError()
+        elif model_checkpoint == MagicWords.CURRENT:
+            pass
+        else:
+            _LOGGER.error(f"Unknown 'model_checkpoint' received '{model_checkpoint}'.")
+            raise ValueError()
+
+    # --- Abstract Methods ---
+    # These must be implemented by subclasses
+
+    @abstractmethod
+    def _create_dataloaders(self, batch_size: int, shuffle: bool):
+        """Initializes the DataLoaders."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def _train_step(self) -> dict[str, float]:
+        """Runs a single training epoch."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def _validation_step(self) -> dict[str, float]:
+        """Runs a single validation epoch."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def evaluate(self, *args, **kwargs):
+        """Runs the full model evaluation."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def _evaluate(self, *args, **kwargs):
+        """Internal evaluation helper."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def finalize_model_training(self, *args, **kwargs):
+        """Saves the finalized model for inference."""
+        raise NotImplementedError
+
```
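Because `_BaseDragonTrainer` is abstract, each concrete trainer (e.g., the trainers in `_dragon_trainer.py` above) supplies the dataloaders and the per-epoch steps, while `fit()` drives the loop: it builds the loaders, moves the model to the validated device, fires the callback hooks around each epoch, honors the `stop_training` flag set by early stopping, and plots `self.history` at the end. Below is a minimal sketch of a concrete subclass, assuming the import path from the RECORD above; the log keys `'train_loss'`/`'val_loss'` and the constructor shape are illustrative assumptions, not the toolbox's actual trainers.

```python
# Minimal sketch of a concrete trainer built on _BaseDragonTrainer.
# Assumptions: the import path below matches the RECORD; 'train_loss'/'val_loss'
# are illustrative log keys; this is NOT the toolbox's DragonTrainer.
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from ml_tools.ML_trainer._base_trainer import _BaseDragonTrainer  # assumed path


class TinyRegressionTrainer(_BaseDragonTrainer):
    def __init__(self, model, optimizer, device, train_dataset: Dataset,
                 validation_dataset: Dataset, criterion: nn.Module, **kwargs):
        super().__init__(model=model, optimizer=optimizer, device=device, **kwargs)
        self._train_ds = train_dataset
        self._val_ds = validation_dataset
        self.criterion = criterion

    def _create_dataloaders(self, batch_size: int, shuffle: bool):
        # fit() calls this first; the base class reads these two attributes.
        self.train_loader = DataLoader(self._train_ds, batch_size=batch_size,
                                       shuffle=shuffle, num_workers=self.dataloader_workers)
        self.validation_loader = DataLoader(self._val_ds, batch_size=batch_size,
                                            shuffle=False, num_workers=self.dataloader_workers)

    def _train_step(self) -> dict[str, float]:
        self.model.train()
        total = 0.0
        for features, target in self.train_loader:  # type: ignore
            features, target = features.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(features), target)
            loss.backward()
            self.optimizer.step()
            total += loss.item()
        return {"train_loss": total / len(self.train_loader)}  # assumed log key

    def _validation_step(self) -> dict[str, float]:
        self.model.eval()
        total = 0.0
        with torch.no_grad():
            for features, target in self.validation_loader:  # type: ignore
                features, target = features.to(self.device), target.to(self.device)
                total += self.criterion(self.model(features), target).item()
        return {"val_loss": total / len(self.validation_loader)}  # assumed log key

    # Evaluation/finalization stubs; the real trainers implement these fully.
    def evaluate(self, *args, **kwargs): ...
    def _evaluate(self, *args, **kwargs): ...
    def finalize_model_training(self, *args, **kwargs): ...


if __name__ == "__main__":
    # Illustrative usage with toy tensors; the save_dir path is a placeholder.
    from torch.utils.data import TensorDataset
    X, y = torch.randn(64, 8), torch.randn(64, 1)
    net = nn.Linear(8, 1)
    trainer = TinyRegressionTrainer(
        model=net,
        optimizer=torch.optim.Adam(net.parameters()),
        device="cpu",
        train_dataset=TensorDataset(X[:48], y[:48]),
        validation_dataset=TensorDataset(X[48:], y[48:]),
        criterion=nn.MSELoss(),
    )
    history = trainer.fit(save_dir="runs/exp1", epochs=5, batch_size=16)
```

Whatever `_train_step` and `_validation_step` return is merged into `epoch_logs`, which is exactly what the checkpoint, early-stopping, and scheduler callbacks see in `on_epoch_end`.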