dragon-ml-toolbox 19.14.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions as released to their public registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in those registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.14.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1909
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
ml_tools/_core/_ML_callbacks.py (deleted)
@@ -1,702 +0,0 @@
- import numpy as np
- import torch
- from collections import deque
- from tqdm.auto import tqdm
- from typing import Union, Literal, Optional
- from pathlib import Path
-
- from ._path_manager import make_fullpath
- from ._keys import PyTorchLogKeys, PyTorchCheckpointKeys
- from ._logger import get_logger
- from ._script_info import _script_info
-
-
- _LOGGER = get_logger("Callbacks")
-
-
- __all__ = [
-     "History",
-     "TqdmProgressBar",
-     "DragonPatienceEarlyStopping",
-     "DragonPrecheltEarlyStopping",
-     "DragonModelCheckpoint",
-     "DragonScheduler",
-     "DragonReduceLROnPlateau"
- ]
-
-
- class _Callback:
-     """
-     Abstract base class used to build new callbacks.
-
-     The methods of this class are automatically called by the Trainer at different
-     points during training. Subclasses can override these methods to implement
-     custom logic.
-     """
-     def __init__(self):
-         self.trainer = None
-
-     def set_trainer(self, trainer):
-         """This is called by the Trainer to associate itself with the callback."""
-         self.trainer = trainer
-
-     def on_train_begin(self, logs=None):
-         """Called at the beginning of training."""
-         pass
-
-     def on_train_end(self, logs=None):
-         """Called at the end of training."""
-         pass
-
-     def on_epoch_begin(self, epoch, logs=None):
-         """Called at the beginning of an epoch."""
-         pass
-
-     def on_epoch_end(self, epoch, logs=None):
-         """Called at the end of an epoch."""
-         pass
-
-     def on_batch_begin(self, batch, logs=None):
-         """Called at the beginning of a training batch."""
-         pass
-
-     def on_batch_end(self, batch, logs=None):
-         """Called at the end of a training batch."""
-         pass
-
-
- class History(_Callback):
-     """
-     Callback that records events into a `history` dictionary.
-
-     This callback is automatically applied to every MyTrainer model.
-     The `history` attribute is a dictionary mapping metric names (e.g., 'val_loss')
-     to a list of metric values.
-     """
-     def on_train_begin(self, logs=None):
-         # Clear history at the beginning of training
-         self.trainer.history = {} # type: ignore
-
-     def on_epoch_end(self, epoch, logs=None):
-         logs = logs or {}
-         for k, v in logs.items():
-             # Append new log values to the history dictionary
-             self.trainer.history.setdefault(k, []).append(v) # type: ignore
-
-
- class TqdmProgressBar(_Callback):
-     """Callback that provides a tqdm progress bar for training."""
-     def __init__(self):
-         self.epoch_bar = None
-         self.batch_bar = None
-
-     def on_train_begin(self, logs=None):
-         self.epochs = self.trainer.epochs # type: ignore
-         self.epoch_bar = tqdm(total=self.epochs, desc="Training Progress")
-
-     def on_epoch_begin(self, epoch, logs=None):
-         total_batches = len(self.trainer.train_loader) # type: ignore
-         self.batch_bar = tqdm(total=total_batches, desc=f"Epoch {epoch}/{self.epochs}", leave=False)
-
-     def on_batch_end(self, batch, logs=None):
-         self.batch_bar.update(1) # type: ignore
-         if logs:
-             self.batch_bar.set_postfix(loss=f"{logs.get(PyTorchLogKeys.BATCH_LOSS, 0):.4f}") # type: ignore
-
-     def on_epoch_end(self, epoch, logs=None):
-         self.batch_bar.close() # type: ignore
-         self.epoch_bar.update(1) # type: ignore
-         if logs:
-             train_loss_str = f"{logs.get(PyTorchLogKeys.TRAIN_LOSS, 0):.4f}"
-             val_loss_str = f"{logs.get(PyTorchLogKeys.VAL_LOSS, 0):.4f}"
-             self.epoch_bar.set_postfix_str(f"Train Loss: {train_loss_str}, Val Loss: {val_loss_str}") # type: ignore
-
-     def on_train_end(self, logs=None):
-         self.epoch_bar.close() # type: ignore
-
-
- class _DragonEarlyStopping(_Callback):
-     """
-     Base class for Early Stopping strategies.
-     Ensures type compatibility and shared logging logic.
-     """
-     def __init__(self,
-                  monitor: str,
-                  verbose: int = 1):
-         super().__init__()
-         self.monitor = monitor
-         self.verbose = verbose
-         self.stopped_epoch = 0
-
-     def _stop_training(self, epoch: int, reason: str):
-         """Helper to trigger the stop."""
-         self.stopped_epoch = epoch
-         self.trainer.stop_training = True # type: ignore
-         if self.verbose > 0:
-             _LOGGER.info(f"Epoch {epoch}: Early stopping triggered. Reason: {reason}")
-
-
- class DragonPatienceEarlyStopping(_DragonEarlyStopping):
-     """
-     Standard early stopping: Tracks minimum validation loss (or other metric) with a patience counter.
-     """
-     def __init__(self,
-                  monitor: Literal["Training Loss", "Validation Loss"] = "Validation Loss",
-                  min_delta: float = 0.0,
-                  patience: int = 10,
-                  mode: Literal['min', 'max'] = 'min',
-                  verbose: int = 1):
-         """
-         Args:
-             monitor (str): Metric to monitor.
-             min_delta (float): Minimum change to qualify as an improvement.
-             patience (int): Number of epochs with no improvement after which training will be stopped.
-             mode (str): One of {'min', 'max'}. In 'min' mode, training will stop when the quantity monitored has stopped decreasing; in 'max' mode it will stop when the quantity monitored has stopped increasing.
-             verbose (int): Verbosity mode.
-         """
-         # standardize monitor key
-         if monitor == "Training Loss":
-             std_monitor = PyTorchLogKeys.TRAIN_LOSS
-         elif monitor == "Validation Loss":
-             std_monitor = PyTorchLogKeys.VAL_LOSS
-         else:
-             _LOGGER.error(f"Unknown monitor key: {monitor}.")
-             raise ValueError()
-
-         super().__init__(std_monitor, verbose)
-         self.patience = patience
-         self.min_delta = min_delta
-         self.wait = 0
-         self.mode = mode
-
-         if mode not in ['min', 'max']:
-             _LOGGER.error(f"EarlyStopping mode {mode} is unknown, choose one of ('min', 'max')")
-             raise ValueError()
-
-         # Determine the comparison operator
-         if self.mode == 'min':
-             self.monitor_op = np.less
-         elif self.mode == 'max':
-             self.monitor_op = np.greater
-         else:
-             # raise error for unknown mode
-             _LOGGER.error(f"EarlyStopping mode {mode} is unknown, choose one of ('min', 'max')")
-             raise ValueError()
-
-         self.best = np.inf if self.monitor_op == np.less else -np.inf
-
-     def on_train_begin(self, logs=None):
-         self.wait = 0
-         self.best = np.inf if self.monitor_op == np.less else -np.inf
-
-     def on_epoch_end(self, epoch, logs=None):
-         current = logs.get(self.monitor) # type: ignore
-         if current is None:
-             return
-
-         # Check improvement
-         if self.monitor_op == np.less:
-             is_improvement = self.monitor_op(current, self.best - self.min_delta)
-         else:
-             is_improvement = self.monitor_op(current, self.best + self.min_delta)
-
-         if is_improvement:
-             if self.verbose > 1:
-                 _LOGGER.info(f"EarlyStopping: {self.monitor} improved from {self.best:.4f} to {current:.4f}")
-             self.best = current
-             self.wait = 0
-         else:
-             self.wait += 1
-             if self.wait >= self.patience:
-                 self._stop_training(epoch, f"No improvement in {self.monitor} for {self.wait} epochs.")
-
-
- class DragonPrecheltEarlyStopping(_DragonEarlyStopping):
-     """
-     Implements Prechelt's 'Progress-Modified GL' criterion.
-     Tracks the ratio between Generalization Loss (overfitting) and Training Progress.
-
-     References:
-         Prechelt, L. (1998). Early Stopping - But When?
-     """
-     def __init__(self,
-                  alpha: float = 0.75,
-                  k: int = 5,
-                  verbose: int = 1):
-         """
-         This early stopping strategy monitors both validation loss and training loss to determine the optimal stopping point.
-
-         Args:
-             alpha (float): The threshold for the stopping criterion.
-             k (int): The window size for calculating training progress.
-             verbose (int): Verbosity mode.
-
-         NOTE:
-
-         - **The Strip Size (k)**:
-             - `5`: The empirical "gold standard." It is long enough to smooth out batch noise but short enough to react to convergence plateaus quickly.
-             - `10` to `20`: Use if the training curve is very jagged (e.g., noisy data, small batch sizes, high dropout, or Reinforcement Learning). A larger k value prevents premature stopping due to random volatility.
-         - **The threshold (alpha)**:
-             - `< 0.5`: Aggressive. Stops training very early.
-             - `0.75` to `0.80`: Prechelt found this range to be the most robust across different datasets. It typically yields the best trade-off between generalization and training cost.
-             - `1.0` to `1.2`: Useful for complex tasks (like Transformers) where training progress might dip temporarily before recovering. It risks slightly more overfitting but ensures potential is exhausted.
-         """
-         super().__init__(PyTorchLogKeys.VAL_LOSS, verbose)
-         self.train_monitor = PyTorchLogKeys.TRAIN_LOSS
-         self.alpha = alpha
-         self.k = k
-
-         self.best_val_loss = np.inf
-         self.train_strip = deque(maxlen=k)
-
-     def on_train_begin(self, logs=None):
-         self.best_val_loss = np.inf
-         self.train_strip.clear()
-
-     def on_epoch_end(self, epoch, logs=None):
-         val_loss = logs.get(self.monitor) # type: ignore
-         train_loss = logs.get(self.train_monitor) # type: ignore
-
-         if val_loss is None or train_loss is None:
-             return
-
-         # 1. Update Best Validation Loss
-         if val_loss < self.best_val_loss:
-             self.best_val_loss = val_loss
-
-         # 2. Update Training Strip
-         self.train_strip.append(train_loss)
-
-         # 3. Calculate Generalization Loss (GL)
-         # GL(t) = 100 * (E_val / E_opt - 1)
-         # Low GL is good. High GL means we are drifting away from best val score (overfitting).
-         gl = 100 * ((val_loss / self.best_val_loss) - 1)
-
-         # 4. Calculate Progress (Pk)
-         # Pk(t) = 1000 * (Sum(strip) / (k * min(strip)) - 1)
-         # High Pk is good (training loss is still dropping fast). Low Pk means training has stalled.
-         if len(self.train_strip) < self.k:
-             # Not enough data for progress yet
-             return
-
-         strip_sum = sum(self.train_strip)
-         strip_min = min(self.train_strip)
-
-         # Avoid division by zero
-         if strip_min == 0:
-             pk = 0.1 # Arbitrary small number
-         else:
-             pk = 1000 * ((strip_sum / (self.k * strip_min)) - 1)
-
-         # 5. The Quotient Criterion
-         # Stop if GL / Pk > alpha
-         # Intuition: Stop if Overfitting is high AND Progress is low.
-
-         # Avoid division by zero
-         if pk == 0:
-             pk = 1e-6
-
-         quotient = gl / pk
-
-         if self.verbose > 1:
-             _LOGGER.info(f"Epoch {epoch}: GL={gl:.3f} | Pk={pk:.3f} | Quotient={quotient:.3f} (Threshold={self.alpha})")
-
-         if quotient > self.alpha:
-             self._stop_training(epoch, f"Prechelt Criterion triggered. Generalization/Progress quotient ({quotient:.3f}) > alpha ({self.alpha}).")
-
-
- class DragonModelCheckpoint(_Callback):
-     """
-     Saves the model weights, optimizer state, LR scheduler state (if any), and epoch number to a directory with automated filename generation and rotation.
-     """
-     def __init__(self,
-                  save_dir: Union[str, Path],
-                  monitor: Literal["Training Loss", "Validation Loss", "both"] = "Validation Loss",
-                  save_three_best: bool = True,
-                  mode: Literal['min', 'max'] = 'min',
-                  verbose: int = 0):
-         """
-         Args:
-             save_dir (str): Directory where checkpoint files will be saved.
-             monitor (str): Metric to monitor. If "both", the sum of training loss and validation loss is used.
-             save_three_best (bool):
-                 - If True, keeps the top 3 best checkpoints found during training (based on metric).
-                 - If False, keeps the 3 most recent checkpoints (rolling window).
-             mode (str): One of {'min', 'max'}.
-             verbose (int): Verbosity mode.
-         """
-         super().__init__()
-         self.save_dir = make_fullpath(save_dir, make=True, enforce="directory")
-
-         # Standardize monitor key
-         if monitor == "Training Loss":
-             std_monitor = PyTorchLogKeys.TRAIN_LOSS
-         elif monitor == "Validation Loss":
-             std_monitor = PyTorchLogKeys.VAL_LOSS
-         elif monitor == "both":
-             std_monitor = "both"
-         else:
-             _LOGGER.error(f"Unknown monitor key: {monitor}.")
-             raise ValueError()
-
-         self.monitor = std_monitor
-         self.save_three_best = save_three_best
-         self.verbose = verbose
-         self._latest_checkpoint_path = None
-         self._checkpoint_name = PyTorchCheckpointKeys.CHECKPOINT_NAME
-
-         # State variables
-         # stored as list of dicts: [{'path': Path, 'score': float, 'epoch': int}]
-         self.best_checkpoints = []
-         # For rolling check (save_three_best=False)
-         self.recent_checkpoints = []
-
-         if mode not in ['min', 'max']:
-             _LOGGER.error(f"ModelCheckpoint mode {mode} is unknown. Use 'min' or 'max'.")
-             raise ValueError()
-         self.mode = mode
-
-         # Determine comparison operator
-         if self.mode == 'min':
-             self.monitor_op = np.less
-             self.best = np.inf
-         else:
-             self.monitor_op = np.greater
-             self.best = -np.inf
-
-     def on_train_begin(self, logs=None):
-         """Reset file tracking state when training starts.
-         NOTE: Do NOT reset self.best here if it differs from the default. This allows the Trainer to restore 'best' from a checkpoint before calling train()."""
-         self.best_checkpoints = []
-         self.recent_checkpoints = []
-
-         # Check if self.best is at default initialization value
-         is_default_min = (self.mode == 'min' and self.best == np.inf)
-         is_default_max = (self.mode == 'max' and self.best == -np.inf)
-
-         # If it is NOT default, it means it was restored.
-         if not (is_default_min or is_default_max):
-             _LOGGER.debug(f"Resuming with best score: {self.best:.4f}")
-
-     def _get_metric_value(self, logs):
-         """Extracts or calculates the metric value based on configuration."""
-         if self.monitor == "both":
-             t_loss = logs.get(PyTorchLogKeys.TRAIN_LOSS)
-             v_loss = logs.get(PyTorchLogKeys.VAL_LOSS)
-             if t_loss is None or v_loss is None:
-                 return None
-             return t_loss + v_loss
-         else:
-             return logs.get(self.monitor)
-
-     def on_epoch_end(self, epoch, logs=None):
-         logs = logs or {}
-         current_score = self._get_metric_value(logs)
-
-         if current_score is None:
-             if self.verbose > 0:
-                 _LOGGER.warning(f"Epoch {epoch}: Metric '{self.monitor}' not found in logs. Skipping checkpoint.")
-             return
-
-         # 1. Update global best score (for logging/metadata)
-         if self.monitor_op(current_score, self.best):
-             if self.verbose > 0:
-                 # Only log explicit "improvement" if we are beating the historical best
-                 old_best_str = f"{self.best:.4f}" if not np.isinf(self.best) else "inf"
-                 _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current_score:.4f}")
-             self.best = current_score
-
-         if self.save_three_best:
-             self._save_top_k_checkpoints(epoch, current_score)
-         else:
-             self._save_rolling_checkpoints(epoch, current_score)
-
-     def _save_checkpoint_file(self, epoch, current_score):
-         """Helper to physically save the file."""
-         self.save_dir.mkdir(parents=True, exist_ok=True)
-
-         # Create filename
-         score_str = f"{current_score:.4f}".replace('.', '_')
-         filename = f"epoch{epoch}_{self._checkpoint_name}-{score_str}.pth"
-         filepath = self.save_dir / filename
-
-         # Create checkpoint dict
-         checkpoint_data = {
-             PyTorchCheckpointKeys.EPOCH: epoch,
-             PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
-             PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
-             PyTorchCheckpointKeys.BEST_SCORE: current_score,
-             PyTorchCheckpointKeys.HISTORY: self.trainer.history, # type: ignore
-         }
-
-         if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
-             checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
-
-         torch.save(checkpoint_data, filepath)
-         self._latest_checkpoint_path = filepath
-
-         return filepath
-
-     def _save_top_k_checkpoints(self, epoch, current_score):
-         """Logic for maintaining the top 3 best checkpoints."""
-
-         def sort_key(item): return item['score']
-
-         # Determine sort direction so that Index 0 is BEST and Index -1 is WORST
-         # Min mode (lower is better): Ascending (reverse=False) -> [0.1, 0.5, 0.9] (0.1 is best)
-         # Max mode (higher is better): Descending (reverse=True) -> [0.9, 0.5, 0.1] (0.9 is best)
-         is_reverse = (self.mode == 'max')
-
-         should_save = False
-
-         if len(self.best_checkpoints) < 3:
-             should_save = True
-         else:
-             # Sort current list to identify the worst (last item)
-             self.best_checkpoints.sort(key=sort_key, reverse=is_reverse)
-             worst_entry = self.best_checkpoints[-1]
-
-             # Check if current is better than the worst in the list
-             # min mode: current < worst['score']
-             # max mode: current > worst['score']
-             if self.monitor_op(current_score, worst_entry['score']):
-                 should_save = True
-
-         if should_save:
-             filepath = self._save_checkpoint_file(epoch, current_score)
-
-             if self.verbose > 0:
-                 _LOGGER.info(f"Epoch {epoch}: {self.monitor} ({current_score:.4f}) is in top 3. Saving to {filepath.name}")
-
-             self.best_checkpoints.append({'path': filepath, 'score': current_score, 'epoch': epoch})
-
-             # Prune if > 3
-             if len(self.best_checkpoints) > 3:
-                 # Re-sort to ensure worst is at the end
-                 self.best_checkpoints.sort(key=sort_key, reverse=is_reverse)
-
-                 # Evict the last one (Worst)
-                 entry_to_delete = self.best_checkpoints.pop(-1)
-
-                 if entry_to_delete['path'].exists():
-                     if self.verbose > 0:
-                         _LOGGER.info(f" -> Deleting checkpoint outside top 3: {entry_to_delete['path'].name}")
-                     entry_to_delete['path'].unlink()
-
-     def _save_rolling_checkpoints(self, epoch, current_score):
-         """Saves the latest model and keeps only the 3 most recent ones."""
-         filepath = self._save_checkpoint_file(epoch, current_score)
-
-         if self.verbose > 0:
-             _LOGGER.info(f'Epoch {epoch}: saving rolling model to {filepath.name}')
-
-         self.recent_checkpoints.append(filepath)
-
-         # If we have more than 3 checkpoints, remove the oldest one
-         if len(self.recent_checkpoints) > 3:
-             file_to_delete = self.recent_checkpoints.pop(0)
-             if file_to_delete.exists():
-                 if self.verbose > 0:
-                     _LOGGER.info(f" -> Deleting old rolling checkpoint: {file_to_delete.name}")
-                 file_to_delete.unlink()
-
-     @property
-     def best_checkpoint_path(self):
-         # If tracking top 3, return the absolute best among them
-         if self.save_three_best and self.best_checkpoints:
-             def sort_key(item): return item['score']
-             is_reverse = (self.mode == 'max')
-             # Sort Best -> Worst
-             sorted_bests = sorted(self.best_checkpoints, key=sort_key, reverse=is_reverse)
-             # Index 0 is always the best based on the logic above
-             return sorted_bests[0]['path']
-
-         elif self._latest_checkpoint_path:
-             return self._latest_checkpoint_path
-         else:
-             _LOGGER.error("No checkpoint paths saved.")
-             raise ValueError()
-
-
- class _DragonLRScheduler(_Callback):
-     """
-     Base class for Dragon LR Schedulers.
-     Handles common logic like logging and attaching to the trainer.
-     """
-     def __init__(self):
-         super().__init__()
-         self.scheduler = None
-         self.previous_lr = None
-
-     def set_trainer(self, trainer):
-         """Associates the callback with the trainer."""
-         super().set_trainer(trainer)
-         # Note: Subclasses must ensure self.scheduler is set before or during this call
-         # if they want to register it immediately.
-         if self.scheduler:
-             self.trainer.scheduler = self.scheduler # type: ignore
-
-     def on_train_begin(self, logs=None):
-         """Store the initial learning rate."""
-         if not self.trainer.optimizer: # type: ignore
-             _LOGGER.warning("No optimizer found in trainer. LRScheduler cannot track learning rate.")
-             return
-         self.previous_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
-
-     def _check_and_log_lr(self, epoch, logs, verbose: bool):
-         """Helper to log LR changes and update history."""
-         if not self.trainer.optimizer: # type: ignore
-             return
-
-         current_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
-
-         # Log change
-         if self.previous_lr is not None and current_lr != self.previous_lr:
-             if verbose:
-                 print(f" > Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
-             self.previous_lr = current_lr
-
-         # Log to dictionary
-         logs[PyTorchLogKeys.LEARNING_RATE] = current_lr
-
-         # Log to history
-         if hasattr(self.trainer, 'history'):
-             self.trainer.history.setdefault(PyTorchLogKeys.LEARNING_RATE, []).append(current_lr) # type: ignore
-
-
- class DragonScheduler(_DragonLRScheduler):
-     """
-     Callback for standard PyTorch Learning Rate Schedulers.
-
-     Compatible with: StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR, etc.
-
-     NOT Compatible with: ReduceLROnPlateau (Use `DragonReduceLROnPlateau` instead).
-     """
-     def __init__(self, scheduler, verbose: bool=True):
-         """
-         Args:
-             scheduler: An initialized PyTorch learning rate scheduler instance.
-             verbose (bool): If True, logs learning rate changes to console.
-         """
-         super().__init__()
-         if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
-             raise ValueError(
-                 "DragonLRScheduler does not support 'ReduceLROnPlateau'. "
-                 "Please use the `DragonReduceLROnPlateau` callback instead."
-             )
-         self.scheduler = scheduler
-         self.verbose = verbose
-
-     def set_trainer(self, trainer):
-         super().set_trainer(trainer)
-         # Explicitly register the scheduler again to be safe
-         self.trainer.scheduler = self.scheduler # type: ignore
-         if self.verbose:
-             _LOGGER.info(f"Registered LR Scheduler: {self.scheduler.__class__.__name__}")
-
-     def on_epoch_end(self, epoch, logs=None):
-         logs = logs or {}
-
-         # Standard step (no metrics needed)
-         self.scheduler.step()
-
-         self._check_and_log_lr(epoch, logs, self.verbose)
-
-
- class DragonReduceLROnPlateau(_DragonLRScheduler):
-     """
-     Specific callback for `torch.optim.lr_scheduler.ReduceLROnPlateau`. Reduces learning rate when a monitored metric has stopped improving.
-
-     This wrapper initializes the scheduler internally using the Trainer's optimizer, simplifying the setup process.
-     """
-     def __init__(self,
-                  monitor: Literal["Training Loss", "Validation Loss"] = "Validation Loss",
-                  mode: Literal['min', 'max'] = 'min',
-                  factor: float = 0.1,
-                  patience: int = 5,
-                  threshold: float = 1e-4,
-                  threshold_mode: Literal['rel', 'abs'] = 'rel',
-                  cooldown: int = 0,
-                  min_lr: float = 0,
-                  eps: float = 1e-8,
-                  verbose: bool = True):
-         """
-         Args:
-             monitor ("Training Loss", "Validation Loss"): Metric to monitor.
-             mode ('min', 'max'): One of 'min', 'max'.
-             factor (float): Factor by which the learning rate will be reduced. new_lr = lr * factor.
-             patience (int): Number of epochs with no improvement after which learning rate will be reduced.
-             threshold (float): Threshold for measuring the new optimum.
-             threshold_mode ('rel', 'abs'): One of 'rel', 'abs'.
-             cooldown (int): Number of epochs to wait before resuming normal operation after lr has been reduced.
-             min_lr (float or list): A scalar or a list of scalars.
-             eps (float): Minimal decay applied to lr.
-             verbose (bool): If True, logs learning rate changes to console.
-         """
-         super().__init__()
-
-         # Standardize monitor key
-         if monitor == "Training Loss":
-             std_monitor = PyTorchLogKeys.TRAIN_LOSS
-         elif monitor == "Validation Loss":
-             std_monitor = PyTorchLogKeys.VAL_LOSS
-         else:
-             _LOGGER.error(f"Unknown monitor key: {monitor}.")
-             raise ValueError()
-
-         self.monitor = std_monitor
-         self.verbose = verbose
-
-         # Config storage for delayed initialization
-         self.config = {
-             'mode': mode,
-             'factor': factor,
-             'patience': patience,
-             'threshold': threshold,
-             'threshold_mode': threshold_mode,
-             'cooldown': cooldown,
-             'min_lr': min_lr,
-             'eps': eps,
-         }
-
-     def set_trainer(self, trainer):
-         """
-         Initializes the ReduceLROnPlateau scheduler using the trainer's optimizer and registers it.
-         """
-         super().set_trainer(trainer)
-
-         if not hasattr(self.trainer, 'optimizer'):
-             _LOGGER.error("Trainer has no optimizer. Cannot initialize ReduceLROnPlateau.")
-             raise ValueError()
-
-         # Initialize the actual scheduler with the optimizer
-         if self.verbose:
-             _LOGGER.info(f"Initializing ReduceLROnPlateau monitoring '{self.monitor}'")
-
-         self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-             optimizer=self.trainer.optimizer, # type: ignore
-             **self.config
-         )
-
-         # Register with trainer for checkpointing
-         self.trainer.scheduler = self.scheduler # type: ignore
-
-     def on_epoch_end(self, epoch, logs=None):
-         logs = logs or {}
-
-         metric_val = logs.get(self.monitor)
-
-         if metric_val is None:
-             _LOGGER.warning(f"DragonReduceLROnPlateau could not find metric '{self.monitor}' in logs. Scheduler step skipped.")
-             # Still log LR to keep history consistent
-             self._check_and_log_lr(epoch, logs, self.verbose)
-             return
-
-         # Step with metric
-         self.scheduler.step(metric_val)
-
-         self._check_and_log_lr(epoch, logs, self.verbose)
-
-
- def info():
-     _script_info(__all__)
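For reference, the stopping rule that the removed DragonPrecheltEarlyStopping callback implemented (GL/Pk quotient versus alpha) can be restated as a small standalone function. This is an editor's sketch with illustrative names and plain Python floats; it is not part of the package or of this diff, and it does not reflect the 20.0.0 API.

from typing import Sequence

def prechelt_should_stop(val_losses: Sequence[float],
                         train_losses: Sequence[float],
                         alpha: float = 0.75,
                         k: int = 5) -> bool:
    """Return True when GL(t) / Pk(t) exceeds alpha at the latest epoch."""
    if len(train_losses) < k:
        return False  # not enough epochs yet to form a full training strip
    best_val = min(val_losses)
    gl = 100.0 * (val_losses[-1] / best_val - 1.0)   # generalization loss
    strip = list(train_losses)[-k:]                  # last k training losses
    strip_min = min(strip)
    pk = 0.1 if strip_min == 0 else 1000.0 * (sum(strip) / (k * strip_min) - 1.0)
    if pk == 0:
        pk = 1e-6                                    # avoid division by zero
    return gl / pk > alpha

# Validation loss drifting up while training loss has plateaued -> stop (prints True).
val_hist = [1.0, 0.8, 0.7, 0.71, 0.73, 0.75, 0.78]
train_hist = [1.0, 0.7, 0.540, 0.539, 0.5385, 0.538, 0.5378]
print(prechelt_should_stop(val_hist, train_hist))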