dragon-ml-toolbox 19.14.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.14.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1909
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,11 @@
+ from .._core import _imprimir_disponibles
+
+ _GRUPOS = [
+     "DragonMICE",
+     "get_convergence_diagnostic",
+     "get_imputed_distributions",
+     "run_mice_pipeline",
+ ]
+
+ def info():
+     _imprimir_disponibles(_GRUPOS)
@@ -1,16 +1,24 @@
- from ._core._ML_callbacks import (
+ from ._early_stop import (
      DragonPatienceEarlyStopping,
      DragonPrecheltEarlyStopping,
+ )
+
+ from ._checkpoint import (
      DragonModelCheckpoint,
+ )
+
+ from ._scheduler import (
      DragonScheduler,
-     DragonReduceLROnPlateau,
-     info
+     DragonPlateauScheduler,
  )

+ from ._imprimir import info
+
+
  __all__ = [
      "DragonPatienceEarlyStopping",
      "DragonPrecheltEarlyStopping",
      "DragonModelCheckpoint",
      "DragonScheduler",
-     "DragonReduceLROnPlateau",
+     "DragonPlateauScheduler",
  ]
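Since the refactor only splits the old ML_callbacks module into submodules and renames the plateau scheduler, downstream imports stay at the package level. A minimal, hedged usage sketch follows; the trainer wiring is not part of this diff, and the DragonPlateauScheduler constructor lives in _scheduler.py, which is not shown here, so only the two callbacks whose signatures appear later in this diff are instantiated:

    from ml_tools.ML_callbacks import (
        DragonModelCheckpoint,
        DragonPatienceEarlyStopping,
        DragonPlateauScheduler,  # renamed from DragonReduceLROnPlateau in 19.x
    )

    # Constructor arguments follow the definitions added later in this diff;
    # how the callbacks are handed to a trainer depends on the trainer API.
    callbacks = [
        DragonModelCheckpoint(save_dir="checkpoints", monitor="Validation Loss"),
        DragonPatienceEarlyStopping(monitor="Validation Loss", patience=10),
    ]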
@@ -0,0 +1,101 @@
+ from tqdm.auto import tqdm
+
+ from ..keys._keys import PyTorchLogKeys
+
+
+ __all__ = [
+     "_Callback",
+     "History",
+     "TqdmProgressBar",
+ ]
+
+
+ class _Callback:
+     """
+     Abstract base class used to build new callbacks.
+
+     The methods of this class are automatically called by the Trainer at different
+     points during training. Subclasses can override these methods to implement
+     custom logic.
+     """
+     def __init__(self):
+         self.trainer = None
+
+     def set_trainer(self, trainer):
+         """This is called by the Trainer to associate itself with the callback."""
+         self.trainer = trainer
+
+     def on_train_begin(self, logs=None):
+         """Called at the beginning of training."""
+         pass
+
+     def on_train_end(self, logs=None):
+         """Called at the end of training."""
+         pass
+
+     def on_epoch_begin(self, epoch, logs=None):
+         """Called at the beginning of an epoch."""
+         pass
+
+     def on_epoch_end(self, epoch, logs=None):
+         """Called at the end of an epoch."""
+         pass
+
+     def on_batch_begin(self, batch, logs=None):
+         """Called at the beginning of a training batch."""
+         pass
+
+     def on_batch_end(self, batch, logs=None):
+         """Called at the end of a training batch."""
+         pass
+
+
+ class History(_Callback):
+     """
+     Callback that records events into a `history` dictionary.
+
+     This callback is automatically applied to every MyTrainer model.
+     The `history` attribute is a dictionary mapping metric names (e.g., 'val_loss')
+     to a list of metric values.
+     """
+     def on_train_begin(self, logs=None):
+         # Clear history at the beginning of training
+         self.trainer.history = {}  # type: ignore
+
+     def on_epoch_end(self, epoch, logs=None):
+         logs = logs or {}
+         for k, v in logs.items():
+             # Append new log values to the history dictionary
+             self.trainer.history.setdefault(k, []).append(v)  # type: ignore
+
+
+ class TqdmProgressBar(_Callback):
+     """Callback that provides a tqdm progress bar for training."""
+     def __init__(self):
+         self.epoch_bar = None
+         self.batch_bar = None
+
+     def on_train_begin(self, logs=None):
+         self.epochs = self.trainer.epochs  # type: ignore
+         self.epoch_bar = tqdm(total=self.epochs, desc="Training Progress")
+
+     def on_epoch_begin(self, epoch, logs=None):
+         total_batches = len(self.trainer.train_loader)  # type: ignore
+         self.batch_bar = tqdm(total=total_batches, desc=f"Epoch {epoch}/{self.epochs}", leave=False)
+
+     def on_batch_end(self, batch, logs=None):
+         self.batch_bar.update(1)  # type: ignore
+         if logs:
+             self.batch_bar.set_postfix(loss=f"{logs.get(PyTorchLogKeys.BATCH_LOSS, 0):.4f}")  # type: ignore
+
+     def on_epoch_end(self, epoch, logs=None):
+         self.batch_bar.close()  # type: ignore
+         self.epoch_bar.update(1)  # type: ignore
+         if logs:
+             train_loss_str = f"{logs.get(PyTorchLogKeys.TRAIN_LOSS, 0):.4f}"
+             val_loss_str = f"{logs.get(PyTorchLogKeys.VAL_LOSS, 0):.4f}"
+             self.epoch_bar.set_postfix_str(f"Train Loss: {train_loss_str}, Val Loss: {val_loss_str}")  # type: ignore
+
+     def on_train_end(self, logs=None):
+         self.epoch_bar.close()  # type: ignore
+
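The hook methods above are the whole callback contract, so a user-defined callback only needs to subclass _Callback and override the events it cares about. A minimal sketch under that assumption; EpochLogPrinter is an illustrative name, and the keys printed are simply whatever the trainer puts into `logs` (the same PyTorchLogKeys values used by History and TqdmProgressBar):

    class EpochLogPrinter(_Callback):
        """Illustrative callback: prints every numeric metric the trainer logged this epoch."""
        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            rendered = ", ".join(f"{key}={value:.4f}" for key, value in logs.items())
            print(f"epoch {epoch}: {rendered}")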
@@ -0,0 +1,232 @@
+ import numpy as np
+ import torch
+ from typing import Union, Literal
+ from pathlib import Path
+
+ from ..path_manager import make_fullpath
+ from ..keys._keys import PyTorchLogKeys, PyTorchCheckpointKeys
+ from .._core import get_logger
+
+ from ._base import _Callback
+
+
+ _LOGGER = get_logger("Checkpoint")
+
+
+ __all__ = [
+     "DragonModelCheckpoint",
+ ]
+
+
+ class DragonModelCheckpoint(_Callback):
+     """
+     Saves the model weights, optimizer state, LR scheduler state (if any), and epoch number to a directory, with automated filename generation and rotation.
+     """
+     def __init__(self,
+                  save_dir: Union[str, Path],
+                  monitor: Literal["Training Loss", "Validation Loss", "both"] = "Validation Loss",
+                  save_three_best: bool = True,
+                  mode: Literal['min', 'max'] = 'min',
+                  verbose: int = 0):
+         """
+         Args:
+             save_dir (str | Path): Directory where checkpoint files will be saved.
+             monitor (str): Metric to monitor. If "both", the sum of training loss and validation loss is used.
+             save_three_best (bool):
+                 - If True, keeps the top 3 best checkpoints found during training (based on the metric).
+                 - If False, keeps the 3 most recent checkpoints (rolling window).
+             mode (str): One of {'min', 'max'}.
+             verbose (int): Verbosity mode.
+         """
+         super().__init__()
+         self.save_dir = make_fullpath(save_dir, make=True, enforce="directory")
+
+         # Standardize monitor key
+         if monitor == "Training Loss":
+             std_monitor = PyTorchLogKeys.TRAIN_LOSS
+         elif monitor == "Validation Loss":
+             std_monitor = PyTorchLogKeys.VAL_LOSS
+         elif monitor == "both":
+             std_monitor = "both"
+         else:
+             _LOGGER.error(f"Unknown monitor key: {monitor}.")
+             raise ValueError()
+
+         self.monitor = std_monitor
+         self.save_three_best = save_three_best
+         self.verbose = verbose
+         self._latest_checkpoint_path = None
+         self._checkpoint_name = PyTorchCheckpointKeys.CHECKPOINT_NAME
+
+         # State variables
+         # stored as a list of dicts: [{'path': Path, 'score': float, 'epoch': int}]
+         self.best_checkpoints = []
+         # For rolling checkpoints (save_three_best=False)
+         self.recent_checkpoints = []
+
+         if mode not in ['min', 'max']:
+             _LOGGER.error(f"ModelCheckpoint mode {mode} is unknown. Use 'min' or 'max'.")
+             raise ValueError()
+         self.mode = mode
+
+         # Determine comparison operator
+         if self.mode == 'min':
+             self.monitor_op = np.less
+             self.best = np.inf
+         else:
+             self.monitor_op = np.greater
+             self.best = -np.inf
+
+     def on_train_begin(self, logs=None):
+         """Reset file tracking state when training starts.
+         NOTE: Do NOT reset self.best here if it differs from the default. This allows the Trainer to restore 'best' from a checkpoint before calling train()."""
+         self.best_checkpoints = []
+         self.recent_checkpoints = []
+
+         # Check if self.best is at its default initialization value
+         is_default_min = (self.mode == 'min' and self.best == np.inf)
+         is_default_max = (self.mode == 'max' and self.best == -np.inf)
+
+         # If it is NOT default, it means it was restored.
+         if not (is_default_min or is_default_max):
+             _LOGGER.debug(f"Resuming with best score: {self.best:.4f}")
+
+     def _get_metric_value(self, logs):
+         """Extracts or calculates the metric value based on configuration."""
+         if self.monitor == "both":
+             t_loss = logs.get(PyTorchLogKeys.TRAIN_LOSS)
+             v_loss = logs.get(PyTorchLogKeys.VAL_LOSS)
+             if t_loss is None or v_loss is None:
+                 return None
+             return t_loss + v_loss
+         else:
+             return logs.get(self.monitor)
+
+     def on_epoch_end(self, epoch, logs=None):
+         logs = logs or {}
+         current_score = self._get_metric_value(logs)
+
+         if current_score is None:
+             if self.verbose > 0:
+                 _LOGGER.warning(f"Epoch {epoch}: Metric '{self.monitor}' not found in logs. Skipping checkpoint.")
+             return
+
+         # 1. Update global best score (for logging/metadata)
+         if self.monitor_op(current_score, self.best):
+             if self.verbose > 0:
+                 # Only log an explicit "improvement" if we are beating the historical best
+                 old_best_str = f"{self.best:.4f}" if not np.isinf(self.best) else "inf"
+                 _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current_score:.4f}")
+             self.best = current_score
+
+         if self.save_three_best:
+             self._save_top_k_checkpoints(epoch, current_score)
+         else:
+             self._save_rolling_checkpoints(epoch, current_score)
+
+     def _save_checkpoint_file(self, epoch, current_score):
+         """Helper to physically save the file."""
+         self.save_dir.mkdir(parents=True, exist_ok=True)
+
+         # Create filename
+         score_str = f"{current_score:.4f}".replace('.', '_')
+         filename = f"epoch{epoch}_{self._checkpoint_name}-{score_str}.pth"
+         filepath = self.save_dir / filename
+
+         # Create checkpoint dict
+         checkpoint_data = {
+             PyTorchCheckpointKeys.EPOCH: epoch,
+             PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(),  # type: ignore
+             PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(),  # type: ignore
+             PyTorchCheckpointKeys.BEST_SCORE: current_score,
+             PyTorchCheckpointKeys.HISTORY: self.trainer.history,  # type: ignore
+         }
+
+         if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None:  # type: ignore
+             checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict()  # type: ignore
+
+         torch.save(checkpoint_data, filepath)
+         self._latest_checkpoint_path = filepath
+
+         return filepath
+
+     def _save_top_k_checkpoints(self, epoch, current_score):
+         """Logic for maintaining the top 3 best checkpoints."""
+
+         def sort_key(item): return item['score']
+
+         # Determine sort direction so that index 0 is BEST and index -1 is WORST
+         # Min mode (lower is better): ascending (reverse=False) -> [0.1, 0.5, 0.9] (0.1 is best)
+         # Max mode (higher is better): descending (reverse=True) -> [0.9, 0.5, 0.1] (0.9 is best)
+         is_reverse = (self.mode == 'max')
+
+         should_save = False
+
+         if len(self.best_checkpoints) < 3:
+             should_save = True
+         else:
+             # Sort the current list to identify the worst (last item)
+             self.best_checkpoints.sort(key=sort_key, reverse=is_reverse)
+             worst_entry = self.best_checkpoints[-1]
+
+             # Check if the current score is better than the worst in the list
+             # min mode: current < worst['score']
+             # max mode: current > worst['score']
+             if self.monitor_op(current_score, worst_entry['score']):
+                 should_save = True
+
+         if should_save:
+             filepath = self._save_checkpoint_file(epoch, current_score)
+
+             if self.verbose > 0:
+                 _LOGGER.info(f"Epoch {epoch}: {self.monitor} ({current_score:.4f}) is in top 3. Saving to {filepath.name}")
+
+             self.best_checkpoints.append({'path': filepath, 'score': current_score, 'epoch': epoch})
+
+             # Prune if > 3
+             if len(self.best_checkpoints) > 3:
+                 # Re-sort to ensure the worst is at the end
+                 self.best_checkpoints.sort(key=sort_key, reverse=is_reverse)
+
+                 # Evict the last one (worst)
+                 entry_to_delete = self.best_checkpoints.pop(-1)
+
+                 if entry_to_delete['path'].exists():
+                     if self.verbose > 0:
+                         _LOGGER.info(f" -> Deleting checkpoint outside top 3: {entry_to_delete['path'].name}")
+                     entry_to_delete['path'].unlink()
+
+     def _save_rolling_checkpoints(self, epoch, current_score):
+         """Saves the latest model and keeps only the 3 most recent ones."""
+         filepath = self._save_checkpoint_file(epoch, current_score)
+
+         if self.verbose > 0:
+             _LOGGER.info(f'Epoch {epoch}: saving rolling model to {filepath.name}')
+
+         self.recent_checkpoints.append(filepath)
+
+         # If we have more than 3 checkpoints, remove the oldest one
+         if len(self.recent_checkpoints) > 3:
+             file_to_delete = self.recent_checkpoints.pop(0)
+             if file_to_delete.exists():
+                 if self.verbose > 0:
+                     _LOGGER.info(f" -> Deleting old rolling checkpoint: {file_to_delete.name}")
+                 file_to_delete.unlink()
+
+     @property
+     def best_checkpoint_path(self):
+         # If tracking the top 3, return the absolute best among them
+         if self.save_three_best and self.best_checkpoints:
+             def sort_key(item): return item['score']
+             is_reverse = (self.mode == 'max')
+             # Sort best -> worst
+             sorted_bests = sorted(self.best_checkpoints, key=sort_key, reverse=is_reverse)
+             # Index 0 is always the best based on the logic above
+             return sorted_bests[0]['path']
+
+         elif self._latest_checkpoint_path:
+             return self._latest_checkpoint_path
+         else:
+             _LOGGER.error("No checkpoint paths saved.")
+             raise ValueError()
+
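A hedged usage sketch based only on the constructor and the best_checkpoint_path property above; how the callback gets attached to a trainer is outside this file and depends on the trainer API, and the directory name is illustrative:

    checkpoint_cb = DragonModelCheckpoint(
        save_dir="artifacts/checkpoints",  # created by make_fullpath if missing
        monitor="both",                    # rank checkpoints by train loss + val loss
        save_three_best=True,              # keep the 3 best scores, evict the worst
        mode="min",
        verbose=1,
    )

    # After training, the surviving file with the best monitored score:
    # best_path = checkpoint_cb.best_checkpoint_path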
@@ -0,0 +1,208 @@
+ import numpy as np
+ from collections import deque
+ from typing import Literal
+
+ from ..keys._keys import PyTorchLogKeys
+ from .._core import get_logger
+
+ from ._base import _Callback
+
+
+ _LOGGER = get_logger("EarlyStopping")
+
+
+ __all__ = [
+     "DragonPatienceEarlyStopping",
+     "DragonPrecheltEarlyStopping",
+ ]
+
+
+ class _DragonEarlyStopping(_Callback):
+     """
+     Base class for Early Stopping strategies.
+     Ensures type compatibility and shared logging logic.
+     """
+     def __init__(self,
+                  monitor: str,
+                  verbose: int = 1):
+         super().__init__()
+         self.monitor = monitor
+         self.verbose = verbose
+         self.stopped_epoch = 0
+
+     def _stop_training(self, epoch: int, reason: str):
+         """Helper to trigger the stop."""
+         self.stopped_epoch = epoch
+         self.trainer.stop_training = True  # type: ignore
+         if self.verbose > 0:
+             _LOGGER.info(f"Epoch {epoch}: Early stopping triggered. Reason: {reason}")
+
+
+ class DragonPatienceEarlyStopping(_DragonEarlyStopping):
+     """
+     Standard early stopping: tracks the minimum validation loss (or another metric) with a patience counter.
+     """
+     def __init__(self,
+                  monitor: Literal["Training Loss", "Validation Loss"] = "Validation Loss",
+                  min_delta: float = 0.0,
+                  patience: int = 10,
+                  mode: Literal['min', 'max'] = 'min',
+                  verbose: int = 1):
+         """
+         Args:
+             monitor (str): Metric to monitor.
+             min_delta (float): Minimum change to qualify as an improvement.
+             patience (int): Number of epochs with no improvement after which training will be stopped.
+             mode (str): One of {'min', 'max'}. In 'min' mode, training will stop when the quantity monitored has stopped decreasing; in 'max' mode it will stop when the quantity monitored has stopped increasing.
+             verbose (int): Verbosity mode.
+         """
+         # Standardize monitor key
+         if monitor == "Training Loss":
+             std_monitor = PyTorchLogKeys.TRAIN_LOSS
+         elif monitor == "Validation Loss":
+             std_monitor = PyTorchLogKeys.VAL_LOSS
+         else:
+             _LOGGER.error(f"Unknown monitor key: {monitor}.")
+             raise ValueError()
+
+         super().__init__(std_monitor, verbose)
+         self.patience = patience
+         self.min_delta = min_delta
+         self.wait = 0
+         self.mode = mode
+
+         if mode not in ['min', 'max']:
+             _LOGGER.error(f"EarlyStopping mode {mode} is unknown, choose one of ('min', 'max')")
+             raise ValueError()
+
+         # Determine the comparison operator
+         if self.mode == 'min':
+             self.monitor_op = np.less
+         elif self.mode == 'max':
+             self.monitor_op = np.greater
+         else:
+             # Raise an error for an unknown mode
+             _LOGGER.error(f"EarlyStopping mode {mode} is unknown, choose one of ('min', 'max')")
+             raise ValueError()
+
+         self.best = np.inf if self.monitor_op == np.less else -np.inf
+
+     def on_train_begin(self, logs=None):
+         self.wait = 0
+         self.best = np.inf if self.monitor_op == np.less else -np.inf
+
+     def on_epoch_end(self, epoch, logs=None):
+         current = logs.get(self.monitor)  # type: ignore
+         if current is None:
+             return
+
+         # Check improvement
+         if self.monitor_op == np.less:
+             is_improvement = self.monitor_op(current, self.best - self.min_delta)
+         else:
+             is_improvement = self.monitor_op(current, self.best + self.min_delta)
+
+         if is_improvement:
+             if self.verbose > 1:
+                 _LOGGER.info(f"EarlyStopping: {self.monitor} improved from {self.best:.4f} to {current:.4f}")
+             self.best = current
+             self.wait = 0
+         else:
+             self.wait += 1
+             if self.wait >= self.patience:
+                 self._stop_training(epoch, f"No improvement in {self.monitor} for {self.wait} epochs.")
+
+
+ class DragonPrecheltEarlyStopping(_DragonEarlyStopping):
+     """
+     Implements Prechelt's 'Progress-Modified GL' criterion.
+     Tracks the ratio between Generalization Loss (overfitting) and Training Progress.
+
+     References:
+         Prechelt, L. (1998). Early Stopping - But When?
+     """
+     def __init__(self,
+                  alpha: float = 0.75,
+                  window_size: int = 5,
+                  verbose: int = 1):
+         """
+         This early stopping strategy monitors both validation loss and training loss to determine the optimal stopping point.
+
+         Args:
+             alpha (float): The threshold for the stopping criterion.
+             window_size (int): The window size for calculating training progress.
+             verbose (int): Verbosity mode.
+
+         NOTE:
+
+         - **The Window Size (k)**:
+             - `5`: The empirical "gold standard." It is long enough to smooth out batch noise but short enough to react to convergence plateaus quickly.
+             - `10` to `20`: Use if the training curve is very jagged (e.g., noisy data, small batch sizes, high dropout, or Reinforcement Learning). A larger k value prevents premature stopping due to random volatility.
+         - **The threshold (alpha)**:
+             - `< 0.5`: Aggressive. Stops training very early.
+             - `0.75` to `0.80`: Prechelt found this range to be the most robust across different datasets. It typically yields the best trade-off between generalization and training cost.
+             - `1.0` to `1.2`: Useful for complex tasks (like Transformers) where training progress might dip temporarily before recovering. It risks slightly more overfitting but ensures potential is exhausted.
+         """
+         super().__init__(PyTorchLogKeys.VAL_LOSS, verbose)
+         self.train_monitor = PyTorchLogKeys.TRAIN_LOSS
+         self.alpha = alpha
+         self.k = window_size
+
+         self.best_val_loss = np.inf
+         self.train_strip = deque(maxlen=window_size)
+
+     def on_train_begin(self, logs=None):
+         self.best_val_loss = np.inf
+         self.train_strip.clear()
+
+     def on_epoch_end(self, epoch, logs=None):
+         val_loss = logs.get(self.monitor)  # type: ignore
+         train_loss = logs.get(self.train_monitor)  # type: ignore
+
+         if val_loss is None or train_loss is None:
+             return
+
+         # 1. Update Best Validation Loss
+         if val_loss < self.best_val_loss:
+             self.best_val_loss = val_loss
+
+         # 2. Update Training Strip
+         self.train_strip.append(train_loss)
+
+         # 3. Calculate Generalization Loss (GL)
+         # GL(t) = 100 * (E_val / E_opt - 1)
+         # Low GL is good. High GL means we are drifting away from the best val score (overfitting).
+         gl = 100 * ((val_loss / self.best_val_loss) - 1)
+
+         # 4. Calculate Progress (Pk)
+         # Pk(t) = 1000 * (Sum(strip) / (k * min(strip)) - 1)
+         # High Pk is good (training loss is still dropping fast). Low Pk means training has stalled.
+         if len(self.train_strip) < self.k:
+             # Not enough data for progress yet
+             return
+
+         strip_sum = sum(self.train_strip)
+         strip_min = min(self.train_strip)
+
+         # Avoid division by zero
+         if strip_min == 0:
+             pk = 0.1  # Arbitrary small number
+         else:
+             pk = 1000 * ((strip_sum / (self.k * strip_min)) - 1)
+
+         # 5. The Quotient Criterion
+         # Stop if GL / Pk > alpha
+         # Intuition: stop if overfitting is high AND progress is low.
+
+         # Avoid division by zero
+         if pk == 0:
+             pk = 1e-6
+
+         quotient = gl / pk
+
+         if self.verbose > 1:
+             _LOGGER.info(f"Epoch {epoch}: GL={gl:.3f} | Pk={pk:.3f} | Quotient={quotient:.3f} (Threshold={self.alpha})")
+
+         if quotient > self.alpha:
+             self._stop_training(epoch, f"Prechelt Criterion triggered. Generalization/Progress quotient ({quotient:.3f}) > alpha ({self.alpha}).")
+
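To make the quotient criterion above concrete, here is a small numeric check using the same formulas as the code, GL = 100*(E_val/E_opt - 1) and Pk = 1000*(sum(strip)/(k*min(strip)) - 1); all values are invented for illustration:

    val_loss, best_val_loss = 0.42, 0.40
    strip = [0.31, 0.30, 0.30, 0.29, 0.29]             # last k=5 training losses

    gl = 100 * ((val_loss / best_val_loss) - 1)        # 5.0   -> mild overfitting
    pk = 1000 * ((sum(strip) / (5 * min(strip))) - 1)  # ~27.6 -> training still progressing
    quotient = gl / pk                                 # ~0.18, below alpha=0.75 -> keep training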
@@ -0,0 +1,12 @@
+ from .._core import _imprimir_disponibles
+
+ _GRUPOS = [
+     "DragonPatienceEarlyStopping",
+     "DragonPrecheltEarlyStopping",
+     "DragonModelCheckpoint",
+     "DragonScheduler",
+     "DragonPlateauScheduler",
+ ]
+
+ def info():
+     _imprimir_disponibles(_GRUPOS)
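Each refactored subpackage ships an `info()` helper like the one above, re-exported from its `__init__`. Illustrative usage; the exact output format comes from `_imprimir_disponibles` in `ml_tools/_core`, which is not part of this diff:

    from ml_tools.ML_callbacks import info
    info()  # prints the public names listed in _GRUPOS for this subpackage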