dragon-ml-toolbox 19.14.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.14.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1909
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
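The bulk of this release is a restructuring: each flat ml_tools/<name>.py module becomes an ml_tools/<name>/ package (an __init__.py plus private _*.py submodules), the old _core/ implementation files are deleted, and a few flat modules are removed or folded into the new packages (e.g. ML_models_advanced.py, ML_configuration_pytab.py). Below is a minimal, hypothetical upgrade-audit sketch (not part of the package) that only checks which ml_tools entry points still resolve after moving from 19.14 to 20.0; the module names are taken from the file list above, and whether each new __init__.py re-exports the same public names as the old flat module is an assumption not verified here.

    # Hypothetical audit script: confirm which ml_tools modules still import after the
    # 19.14 -> 20.0 package split. Attribute-level compatibility is not checked.
    import importlib

    modules_to_check = [
        "ml_tools.ML_trainer",              # ML_trainer.py replaced by ML_trainer/__init__.py
        "ml_tools.ML_models",               # ML_models.py replaced by ML_models/__init__.py
        "ml_tools.data_exploration",        # data_exploration.py -> data_exploration/__init__.py
        "ml_tools.ML_models_advanced",      # flat module removed in 20.0
        "ml_tools.ML_configuration_pytab",  # flat module removed in 20.0
    ]

    for name in modules_to_check:
        try:
            importlib.import_module(name)
            print(f"OK      {name}")
        except ImportError as exc:
            print(f"MISSING {name}: {exc}")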
ml_tools/ML_trainer/_base_trainer.py (new file)
@@ -0,0 +1,297 @@
+from typing import Literal, Union, Optional, Any
+from pathlib import Path
+from torch.utils.data import DataLoader
+import torch
+from torch import nn
+from abc import ABC, abstractmethod
+
+from ..ML_callbacks._base import _Callback, History, TqdmProgressBar
+from ..ML_callbacks._checkpoint import DragonModelCheckpoint
+from ..ML_callbacks._early_stop import _DragonEarlyStopping
+from ..ML_callbacks._scheduler import _DragonLRScheduler
+from ..ML_evaluation import plot_losses
+
+from ..path_manager import make_fullpath
+from ..keys._keys import PyTorchCheckpointKeys, MagicWords
+from .._core import get_logger
+
+
+_LOGGER = get_logger("DragonTrainer")
+
+
+__all__ = [
+    "_BaseDragonTrainer",
+]
+
+
+class _BaseDragonTrainer(ABC):
+    """
+    Abstract base class for Dragon Trainers.
+
+    Handles the common training loop orchestration, checkpointing, callback
+    management, and device handling. Subclasses must implement the
+    task-specific logic (dataloaders, train/val steps, evaluation).
+    """
+    def __init__(self,
+                 model: nn.Module,
+                 optimizer: torch.optim.Optimizer,
+                 device: Union[Literal['cuda', 'mps', 'cpu'], str],
+                 dataloader_workers: int = 2,
+                 checkpoint_callback: Optional[DragonModelCheckpoint] = None,
+                 early_stopping_callback: Optional[_DragonEarlyStopping] = None,
+                 lr_scheduler_callback: Optional[_DragonLRScheduler] = None,
+                 extra_callbacks: Optional[list[_Callback]] = None):
+
+        self.model = model
+        self.optimizer = optimizer
+        self.scheduler = None
+        self.device = self._validate_device(device)
+        self.dataloader_workers = dataloader_workers
+
+        # Callback handler
+        default_callbacks = [History(), TqdmProgressBar()]
+
+        self._checkpoint_callback = None
+        if checkpoint_callback:
+            default_callbacks.append(checkpoint_callback)
+            self._checkpoint_callback = checkpoint_callback
+        if early_stopping_callback:
+            default_callbacks.append(early_stopping_callback)
+        if lr_scheduler_callback:
+            default_callbacks.append(lr_scheduler_callback)
+
+        user_callbacks = extra_callbacks if extra_callbacks is not None else []
+        self.callbacks = default_callbacks + user_callbacks
+        self._set_trainer_on_callbacks()
+
+        # Internal state
+        self.train_loader: Optional[DataLoader] = None
+        self.validation_loader: Optional[DataLoader] = None
+        self.history: dict[str, list[Any]] = {}
+        self.epoch = 0
+        self.epochs = 0  # Total epochs for the fit run
+        self.start_epoch = 1
+        self.stop_training = False
+        self._batch_size = 10
+
+    def _validate_device(self, device: str) -> torch.device:
+        """Validates the selected device and returns a torch.device object."""
+        device_lower = device.lower()
+        if "cuda" in device_lower and not torch.cuda.is_available():
+            _LOGGER.warning("CUDA not available, switching to CPU.")
+            device = "cpu"
+        elif device_lower == "mps" and not torch.backends.mps.is_available():
+            _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
+            device = "cpu"
+        return torch.device(device)
+
+    def _set_trainer_on_callbacks(self):
+        """Gives each callback a reference to this trainer instance."""
+        for callback in self.callbacks:
+            callback.set_trainer(self)
+
+    def _load_checkpoint(self, path: Union[str, Path]):
+        """Loads a training checkpoint to resume training."""
+        p = make_fullpath(path, enforce="file")
+        _LOGGER.info(f"Loading checkpoint from '{p.name}'...")
+
+        try:
+            checkpoint = torch.load(p, map_location=self.device)
+
+            if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
+                _LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
+                raise KeyError()
+
+            self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
+            self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
+            self.epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0)
+            self.start_epoch = self.epoch + 1  # Resume on the *next* epoch
+
+            # --- Load History ---
+            if PyTorchCheckpointKeys.HISTORY in checkpoint:
+                self.history = checkpoint[PyTorchCheckpointKeys.HISTORY]
+                _LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
+            else:
+                _LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
+                self.history = {}  # Ensure it's at least an empty dict
+
+            # --- Scheduler State Loading Logic ---
+            scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
+            scheduler_object_exists = self.scheduler is not None
+
+            if scheduler_object_exists and scheduler_state_exists:
+                # Case 1: Both exist. Attempt to load.
+                try:
+                    self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE])  # type: ignore
+                    scheduler_name = self.scheduler.__class__.__name__
+                    _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
+                except Exception as e:
+                    # Loading failed, likely a mismatch
+                    scheduler_name = self.scheduler.__class__.__name__
+                    _LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
+                    raise e
+
+            elif scheduler_object_exists and not scheduler_state_exists:
+                # Case 2: Scheduler provided, but no state in checkpoint.
+                scheduler_name = self.scheduler.__class__.__name__
+                _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
+
+            elif not scheduler_object_exists and scheduler_state_exists:
+                # Case 3: State in checkpoint, but no scheduler provided.
+                _LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
+                raise ValueError()
+
+            # Restore callback states
+            for cb in self.callbacks:
+                if isinstance(cb, DragonModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
+                    cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
+                    _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
+
+            _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
+
+        except Exception as e:
+            _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
+            raise
+
+    def fit(self,
+            save_dir: Union[str, Path],
+            epochs: int = 100,
+            batch_size: int = 10,
+            shuffle: bool = True,
+            resume_from_checkpoint: Optional[Union[str, Path]] = None):
+        """
+        Starts the training-validation process of the model.
+
+        Returns the "History" callback dictionary.
+
+        Args:
+            save_dir (str | Path): Directory to save the loss plot.
+            epochs (int): The total number of epochs to train for.
+            batch_size (int): The number of samples per batch.
+            shuffle (bool): Whether to shuffle the training data at each epoch.
+            resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
+        """
+        self.epochs = epochs
+        self._batch_size = batch_size
+        self._create_dataloaders(self._batch_size, shuffle)  # type: ignore
+        self.model.to(self.device)
+
+        if resume_from_checkpoint:
+            self._load_checkpoint(resume_from_checkpoint)
+
+        # Reset stop_training flag on the trainer
+        self.stop_training = False
+
+        self._callbacks_hook('on_train_begin')
+
+        if not self.train_loader:
+            _LOGGER.error("Train loader is not initialized.")
+            raise ValueError()
+
+        if not self.validation_loader:
+            _LOGGER.error("Validation loader is not initialized.")
+            raise ValueError()
+
+        for epoch in range(self.start_epoch, self.epochs + 1):
+            self.epoch = epoch
+            epoch_logs: dict[str, Any] = {}
+            self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
+
+            train_logs = self._train_step()
+            epoch_logs.update(train_logs)
+
+            val_logs = self._validation_step()
+            epoch_logs.update(val_logs)
+
+            self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
+
+            # Check the early stopping flag
+            if self.stop_training:
+                break
+
+        self._callbacks_hook('on_train_end')
+
+        # Training History
+        plot_losses(self.history, save_dir=save_dir)
+
+        return self.history
+
+    def _callbacks_hook(self, method_name: str, *args, **kwargs):
+        """Calls the specified method on all callbacks."""
+        for callback in self.callbacks:
+            method = getattr(callback, method_name)
+            method(*args, **kwargs)
+
+    def to_cpu(self):
+        """
+        Moves the model to the CPU and updates the trainer's device setting.
+
+        This is useful for running operations that require the CPU.
+        """
+        self.device = torch.device('cpu')
+        self.model.to(self.device)
+        _LOGGER.info("Trainer and model moved to CPU.")
+
+    def to_device(self, device: str):
+        """
+        Moves the model to the specified device and updates the trainer's device setting.
+
+        Args:
+            device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+        """
+        self.device = self._validate_device(device)
+        self.model.to(self.device)
+        _LOGGER.info(f"Trainer and model moved to {self.device}.")
+
+    def _load_model_state_for_finalizing(self, model_checkpoint: Union[Path, Literal['best', 'current']]):
+        """
+        Private helper to load the correct model state_dict based on user's choice.
+        This is called by finalize_model_training() in subclasses.
+        """
+        if isinstance(model_checkpoint, Path):
+            self._load_checkpoint(path=model_checkpoint)
+        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
+            path_to_latest = self._checkpoint_callback.best_checkpoint_path
+            self._load_checkpoint(path_to_latest)
+        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
+            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
+            raise ValueError()
+        elif model_checkpoint == MagicWords.CURRENT:
+            pass
+        else:
+            _LOGGER.error(f"Unknown 'model_checkpoint' received '{model_checkpoint}'.")
+            raise ValueError()
+
+    # --- Abstract Methods ---
+    # These must be implemented by subclasses
+
+    @abstractmethod
+    def _create_dataloaders(self, batch_size: int, shuffle: bool):
+        """Initializes the DataLoaders."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def _train_step(self) -> dict[str, float]:
+        """Runs a single training epoch."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def _validation_step(self) -> dict[str, float]:
+        """Runs a single validation epoch."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def evaluate(self, *args, **kwargs):
+        """Runs the full model evaluation."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def _evaluate(self, *args, **kwargs):
+        """Internal evaluation helper."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def finalize_model_training(self, *args, **kwargs):
+        """Saves the finalized model for inference."""
+        raise NotImplementedError
+
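Since _BaseDragonTrainer is abstract, its fit() loop only works once a subclass supplies the hooks it calls: _create_dataloaders() populates train_loader and validation_loader, _train_step() and _validation_step() each run one epoch and return a metrics dict that is merged into epoch_logs and handed to the callbacks. The sketch below (not from the package) shows a minimal concrete trainer wired into that loop; the ToyRegressionTrainer class, its datasets, the MSE criterion, and the "train_loss"/"val_loss" keys are all hypothetical, the package's real trainers define their own metric keys, and the import path ml_tools.ML_trainer._base_trainer is the private module added in this diff.

    # Minimal sketch of a concrete trainer built on _BaseDragonTrainer (assumptions noted above).
    import torch
    from torch import nn
    from torch.utils.data import DataLoader

    from ml_tools.ML_trainer._base_trainer import _BaseDragonTrainer  # private module per this diff


    class ToyRegressionTrainer(_BaseDragonTrainer):
        def __init__(self, model, optimizer, device, train_ds, val_ds, **kwargs):
            super().__init__(model=model, optimizer=optimizer, device=device, **kwargs)
            self._train_ds = train_ds
            self._val_ds = val_ds
            self._criterion = nn.MSELoss()

        def _create_dataloaders(self, batch_size: int, shuffle: bool):
            # Called by fit() before the epoch loop; must set both loaders.
            self.train_loader = DataLoader(self._train_ds, batch_size=batch_size,
                                           shuffle=shuffle, num_workers=self.dataloader_workers)
            self.validation_loader = DataLoader(self._val_ds, batch_size=batch_size,
                                                shuffle=False, num_workers=self.dataloader_workers)

        def _train_step(self) -> dict[str, float]:
            self.model.train()
            total = 0.0
            for features, targets in self.train_loader:
                features, targets = features.to(self.device), targets.to(self.device)
                self.optimizer.zero_grad()
                loss = self._criterion(self.model(features), targets)
                loss.backward()
                self.optimizer.step()
                total += loss.item()
            return {"train_loss": total / max(len(self.train_loader), 1)}

        def _validation_step(self) -> dict[str, float]:
            self.model.eval()
            total = 0.0
            with torch.no_grad():
                for features, targets in self.validation_loader:
                    features, targets = features.to(self.device), targets.to(self.device)
                    total += self._criterion(self.model(features), targets).item()
            return {"val_loss": total / max(len(self.validation_loader), 1)}

        def evaluate(self, *args, **kwargs):
            return self._evaluate(*args, **kwargs)

        def _evaluate(self, *args, **kwargs):
            return self._validation_step()

        def finalize_model_training(self, *args, **kwargs):
            # A real implementation would call _load_model_state_for_finalizing()
            # and then export the model for inference.
            raise NotImplementedError


    # Usage sketch (plot_losses may expect the package's own metric keys):
    # trainer = ToyRegressionTrainer(model, optimizer, "cpu", train_ds, val_ds)
    # history = trainer.fit(save_dir="outputs", epochs=5, batch_size=32)
    # history = trainer.fit(save_dir="outputs", epochs=10,
    #                       resume_from_checkpoint="outputs/checkpoint.pth")

Keeping the orchestration (device handling, checkpoint resume, callback hooks, loss plotting) in the base class means each task-specific trainer in this release (tabular, sequence, detection) only has to implement the data loading and per-epoch step logic.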