dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1901
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
ml_tools/_core/_ML_callbacks.py
DELETED
|
@@ -1,702 +0,0 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
import torch
|
|
3
|
-
from collections import deque
|
|
4
|
-
from tqdm.auto import tqdm
|
|
5
|
-
from typing import Union, Literal, Optional
|
|
6
|
-
from pathlib import Path
|
|
7
|
-
|
|
8
|
-
from ._path_manager import make_fullpath
|
|
9
|
-
from ._keys import PyTorchLogKeys, PyTorchCheckpointKeys
|
|
10
|
-
from ._logger import get_logger
|
|
11
|
-
from ._script_info import _script_info
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
_LOGGER = get_logger("Callbacks")
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
__all__ = [
|
|
18
|
-
"History",
|
|
19
|
-
"TqdmProgressBar",
|
|
20
|
-
"DragonPatienceEarlyStopping",
|
|
21
|
-
"DragonPrecheltEarlyStopping",
|
|
22
|
-
"DragonModelCheckpoint",
|
|
23
|
-
"DragonScheduler",
|
|
24
|
-
"DragonReduceLROnPlateau"
|
|
25
|
-
]
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
class _Callback:
|
|
29
|
-
"""
|
|
30
|
-
Abstract base class used to build new callbacks.
|
|
31
|
-
|
|
32
|
-
The methods of this class are automatically called by the Trainer at different
|
|
33
|
-
points during training. Subclasses can override these methods to implement
|
|
34
|
-
custom logic.
|
|
35
|
-
"""
|
|
36
|
-
def __init__(self):
|
|
37
|
-
self.trainer = None
|
|
38
|
-
|
|
39
|
-
def set_trainer(self, trainer):
|
|
40
|
-
"""This is called by the Trainer to associate itself with the callback."""
|
|
41
|
-
self.trainer = trainer
|
|
42
|
-
|
|
43
|
-
def on_train_begin(self, logs=None):
|
|
44
|
-
"""Called at the beginning of training."""
|
|
45
|
-
pass
|
|
46
|
-
|
|
47
|
-
def on_train_end(self, logs=None):
|
|
48
|
-
"""Called at the end of training."""
|
|
49
|
-
pass
|
|
50
|
-
|
|
51
|
-
def on_epoch_begin(self, epoch, logs=None):
|
|
52
|
-
"""Called at the beginning of an epoch."""
|
|
53
|
-
pass
|
|
54
|
-
|
|
55
|
-
def on_epoch_end(self, epoch, logs=None):
|
|
56
|
-
"""Called at the end of an epoch."""
|
|
57
|
-
pass
|
|
58
|
-
|
|
59
|
-
def on_batch_begin(self, batch, logs=None):
|
|
60
|
-
"""Called at the beginning of a training batch."""
|
|
61
|
-
pass
|
|
62
|
-
|
|
63
|
-
def on_batch_end(self, batch, logs=None):
|
|
64
|
-
"""Called at the end of a training batch."""
|
|
65
|
-
pass
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
class History(_Callback):
|
|
69
|
-
"""
|
|
70
|
-
Callback that records events into a `history` dictionary.
|
|
71
|
-
|
|
72
|
-
This callback is automatically applied to every MyTrainer model.
|
|
73
|
-
The `history` attribute is a dictionary mapping metric names (e.g., 'val_loss')
|
|
74
|
-
to a list of metric values.
|
|
75
|
-
"""
|
|
76
|
-
def on_train_begin(self, logs=None):
|
|
77
|
-
# Clear history at the beginning of training
|
|
78
|
-
self.trainer.history = {} # type: ignore
|
|
79
|
-
|
|
80
|
-
def on_epoch_end(self, epoch, logs=None):
|
|
81
|
-
logs = logs or {}
|
|
82
|
-
for k, v in logs.items():
|
|
83
|
-
# Append new log values to the history dictionary
|
|
84
|
-
self.trainer.history.setdefault(k, []).append(v) # type: ignore
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
class TqdmProgressBar(_Callback):
|
|
88
|
-
"""Callback that provides a tqdm progress bar for training."""
|
|
89
|
-
def __init__(self):
|
|
90
|
-
self.epoch_bar = None
|
|
91
|
-
self.batch_bar = None
|
|
92
|
-
|
|
93
|
-
def on_train_begin(self, logs=None):
|
|
94
|
-
self.epochs = self.trainer.epochs # type: ignore
|
|
95
|
-
self.epoch_bar = tqdm(total=self.epochs, desc="Training Progress")
|
|
96
|
-
|
|
97
|
-
def on_epoch_begin(self, epoch, logs=None):
|
|
98
|
-
total_batches = len(self.trainer.train_loader) # type: ignore
|
|
99
|
-
self.batch_bar = tqdm(total=total_batches, desc=f"Epoch {epoch}/{self.epochs}", leave=False)
|
|
100
|
-
|
|
101
|
-
def on_batch_end(self, batch, logs=None):
|
|
102
|
-
self.batch_bar.update(1) # type: ignore
|
|
103
|
-
if logs:
|
|
104
|
-
self.batch_bar.set_postfix(loss=f"{logs.get(PyTorchLogKeys.BATCH_LOSS, 0):.4f}") # type: ignore
|
|
105
|
-
|
|
106
|
-
def on_epoch_end(self, epoch, logs=None):
|
|
107
|
-
self.batch_bar.close() # type: ignore
|
|
108
|
-
self.epoch_bar.update(1) # type: ignore
|
|
109
|
-
if logs:
|
|
110
|
-
train_loss_str = f"{logs.get(PyTorchLogKeys.TRAIN_LOSS, 0):.4f}"
|
|
111
|
-
val_loss_str = f"{logs.get(PyTorchLogKeys.VAL_LOSS, 0):.4f}"
|
|
112
|
-
self.epoch_bar.set_postfix_str(f"Train Loss: {train_loss_str}, Val Loss: {val_loss_str}") # type: ignore
|
|
113
|
-
|
|
114
|
-
def on_train_end(self, logs=None):
|
|
115
|
-
self.epoch_bar.close() # type: ignore
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
class _DragonEarlyStopping(_Callback):
|
|
119
|
-
"""
|
|
120
|
-
Base class for Early Stopping strategies.
|
|
121
|
-
Ensures type compatibility and shared logging logic.
|
|
122
|
-
"""
|
|
123
|
-
def __init__(self,
|
|
124
|
-
monitor: str,
|
|
125
|
-
verbose: int = 1):
|
|
126
|
-
super().__init__()
|
|
127
|
-
self.monitor = monitor
|
|
128
|
-
self.verbose = verbose
|
|
129
|
-
self.stopped_epoch = 0
|
|
130
|
-
|
|
131
|
-
def _stop_training(self, epoch: int, reason: str):
|
|
132
|
-
"""Helper to trigger the stop."""
|
|
133
|
-
self.stopped_epoch = epoch
|
|
134
|
-
self.trainer.stop_training = True # type: ignore
|
|
135
|
-
if self.verbose > 0:
|
|
136
|
-
_LOGGER.info(f"Epoch {epoch}: Early stopping triggered. Reason: {reason}")
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
class DragonPatienceEarlyStopping(_DragonEarlyStopping):
|
|
140
|
-
"""
|
|
141
|
-
Standard early stopping: Tracks minimum validation loss (or other metric) with a patience counter.
|
|
142
|
-
"""
|
|
143
|
-
def __init__(self,
|
|
144
|
-
monitor: Literal["Training Loss", "Validation Loss"] = "Validation Loss",
|
|
145
|
-
min_delta: float = 0.0,
|
|
146
|
-
patience: int = 10,
|
|
147
|
-
mode: Literal['min', 'max'] = 'min',
|
|
148
|
-
verbose: int = 1):
|
|
149
|
-
"""
|
|
150
|
-
Args:
|
|
151
|
-
monitor (str): Metric to monitor.
|
|
152
|
-
min_delta (float): Minimum change to qualify as an improvement.
|
|
153
|
-
patience (int): Number of epochs with no improvement after which training will be stopped.
|
|
154
|
-
mode (str): One of {'min', 'max'}. In 'min' mode, training will stop when the quantity monitored has stopped decreasing; in 'max' mode it will stop when the quantity monitored has stopped increasing.
|
|
155
|
-
verbose (int): Verbosity mode.
|
|
156
|
-
"""
|
|
157
|
-
# standardize monitor key
|
|
158
|
-
if monitor == "Training Loss":
|
|
159
|
-
std_monitor = PyTorchLogKeys.TRAIN_LOSS
|
|
160
|
-
elif monitor == "Validation Loss":
|
|
161
|
-
std_monitor = PyTorchLogKeys.VAL_LOSS
|
|
162
|
-
else:
|
|
163
|
-
_LOGGER.error(f"Unknown monitor key: {monitor}.")
|
|
164
|
-
raise ValueError()
|
|
165
|
-
|
|
166
|
-
super().__init__(std_monitor, verbose)
|
|
167
|
-
self.patience = patience
|
|
168
|
-
self.min_delta = min_delta
|
|
169
|
-
self.wait = 0
|
|
170
|
-
self.mode = mode
|
|
171
|
-
|
|
172
|
-
if mode not in ['min', 'max']:
|
|
173
|
-
_LOGGER.error(f"EarlyStopping mode {mode} is unknown, choose one of ('min', 'max')")
|
|
174
|
-
raise ValueError()
|
|
175
|
-
|
|
176
|
-
# Determine the comparison operator
|
|
177
|
-
if self.mode == 'min':
|
|
178
|
-
self.monitor_op = np.less
|
|
179
|
-
elif self.mode == 'max':
|
|
180
|
-
self.monitor_op = np.greater
|
|
181
|
-
else:
|
|
182
|
-
# raise error for unknown mode
|
|
183
|
-
_LOGGER.error(f"EarlyStopping mode {mode} is unknown, choose one of ('min', 'max')")
|
|
184
|
-
raise ValueError()
|
|
185
|
-
|
|
186
|
-
self.best = np.inf if self.monitor_op == np.less else -np.inf
|
|
187
|
-
|
|
188
|
-
def on_train_begin(self, logs=None):
|
|
189
|
-
self.wait = 0
|
|
190
|
-
self.best = np.inf if self.monitor_op == np.less else -np.inf
|
|
191
|
-
|
|
192
|
-
def on_epoch_end(self, epoch, logs=None):
|
|
193
|
-
current = logs.get(self.monitor) # type: ignore
|
|
194
|
-
if current is None:
|
|
195
|
-
return
|
|
196
|
-
|
|
197
|
-
# Check improvement
|
|
198
|
-
if self.monitor_op == np.less:
|
|
199
|
-
is_improvement = self.monitor_op(current, self.best - self.min_delta)
|
|
200
|
-
else:
|
|
201
|
-
is_improvement = self.monitor_op(current, self.best + self.min_delta)
|
|
202
|
-
|
|
203
|
-
if is_improvement:
|
|
204
|
-
if self.verbose > 1:
|
|
205
|
-
_LOGGER.info(f"EarlyStopping: {self.monitor} improved from {self.best:.4f} to {current:.4f}")
|
|
206
|
-
self.best = current
|
|
207
|
-
self.wait = 0
|
|
208
|
-
else:
|
|
209
|
-
self.wait += 1
|
|
210
|
-
if self.wait >= self.patience:
|
|
211
|
-
self._stop_training(epoch, f"No improvement in {self.monitor} for {self.wait} epochs.")
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
class DragonPrecheltEarlyStopping(_DragonEarlyStopping):
|
|
215
|
-
"""
|
|
216
|
-
Implements Prechelt's 'Progress-Modified GL' criterion.
|
|
217
|
-
Tracks the ratio between Generalization Loss (overfitting) and Training Progress.
|
|
218
|
-
|
|
219
|
-
References:
|
|
220
|
-
Prechelt, L. (1998). Early Stopping - But When?
|
|
221
|
-
"""
|
|
222
|
-
def __init__(self,
|
|
223
|
-
alpha: float = 0.75,
|
|
224
|
-
k: int = 5,
|
|
225
|
-
verbose: int = 1):
|
|
226
|
-
"""
|
|
227
|
-
This early stopping strategy monitors both validation loss and training loss to determine the optimal stopping point.
|
|
228
|
-
|
|
229
|
-
Args:
|
|
230
|
-
alpha (float): The threshold for the stopping criterion.
|
|
231
|
-
k (int): The window size for calculating training progress.
|
|
232
|
-
verbose (int): Verbosity mode.
|
|
233
|
-
|
|
234
|
-
NOTE:
|
|
235
|
-
|
|
236
|
-
- **The Strip Size (k)**:
|
|
237
|
-
- `5`: The empirical "gold standard." It is long enough to smooth out batch noise but short enough to react to convergence plateaus quickly.
|
|
238
|
-
- `10` to `20`: Use if the training curve is very jagged (e.g., noisy data, small batch sizes, high dropout, or Reinforcement Learning). A larger k value prevents premature stopping due to random volatility.
|
|
239
|
-
- **The threshold (alpha)**:
|
|
240
|
-
- `< 0.5`: Aggressive. Stops training very early.
|
|
241
|
-
- `0.75` to `0.80`: Prechelt found this range to be the most robust across different datasets. It typically yields the best trade-off between generalization and training cost.
|
|
242
|
-
- `1.0` to `1.2`: Useful for complex tasks (like Transformers) where training progress might dip temporarily before recovering. It risks slightly more overfitting but ensures potential is exhausted.
|
|
243
|
-
"""
|
|
244
|
-
super().__init__(PyTorchLogKeys.VAL_LOSS, verbose)
|
|
245
|
-
self.train_monitor = PyTorchLogKeys.TRAIN_LOSS
|
|
246
|
-
self.alpha = alpha
|
|
247
|
-
self.k = k
|
|
248
|
-
|
|
249
|
-
self.best_val_loss = np.inf
|
|
250
|
-
self.train_strip = deque(maxlen=k)
|
|
251
|
-
|
|
252
|
-
def on_train_begin(self, logs=None):
|
|
253
|
-
self.best_val_loss = np.inf
|
|
254
|
-
self.train_strip.clear()
|
|
255
|
-
|
|
256
|
-
def on_epoch_end(self, epoch, logs=None):
|
|
257
|
-
val_loss = logs.get(self.monitor) # type: ignore
|
|
258
|
-
train_loss = logs.get(self.train_monitor) # type: ignore
|
|
259
|
-
|
|
260
|
-
if val_loss is None or train_loss is None:
|
|
261
|
-
return
|
|
262
|
-
|
|
263
|
-
# 1. Update Best Validation Loss
|
|
264
|
-
if val_loss < self.best_val_loss:
|
|
265
|
-
self.best_val_loss = val_loss
|
|
266
|
-
|
|
267
|
-
# 2. Update Training Strip
|
|
268
|
-
self.train_strip.append(train_loss)
|
|
269
|
-
|
|
270
|
-
# 3. Calculate Generalization Loss (GL)
|
|
271
|
-
# GL(t) = 100 * (E_val / E_opt - 1)
|
|
272
|
-
# Low GL is good. High GL means we are drifting away from best val score (overfitting).
|
|
273
|
-
gl = 100 * ((val_loss / self.best_val_loss) - 1)
|
|
274
|
-
|
|
275
|
-
# 4. Calculate Progress (Pk)
|
|
276
|
-
# Pk(t) = 1000 * (Sum(strip) / (k * min(strip)) - 1)
|
|
277
|
-
# High Pk is good (training loss is still dropping fast). Low Pk means training has stalled.
|
|
278
|
-
if len(self.train_strip) < self.k:
|
|
279
|
-
# Not enough data for progress yet
|
|
280
|
-
return
|
|
281
|
-
|
|
282
|
-
strip_sum = sum(self.train_strip)
|
|
283
|
-
strip_min = min(self.train_strip)
|
|
284
|
-
|
|
285
|
-
# Avoid division by zero
|
|
286
|
-
if strip_min == 0:
|
|
287
|
-
pk = 0.1 # Arbitrary small number
|
|
288
|
-
else:
|
|
289
|
-
pk = 1000 * ((strip_sum / (self.k * strip_min)) - 1)
|
|
290
|
-
|
|
291
|
-
# 5. The Quotient Criterion
|
|
292
|
-
# Stop if GL / Pk > alpha
|
|
293
|
-
# Intuition: Stop if Overfitting is high AND Progress is low.
|
|
294
|
-
|
|
295
|
-
# Avoid division by zero
|
|
296
|
-
if pk == 0:
|
|
297
|
-
pk = 1e-6
|
|
298
|
-
|
|
299
|
-
quotient = gl / pk
|
|
300
|
-
|
|
301
|
-
if self.verbose > 1:
|
|
302
|
-
_LOGGER.info(f"Epoch {epoch}: GL={gl:.3f} | Pk={pk:.3f} | Quotient={quotient:.3f} (Threshold={self.alpha})")
|
|
303
|
-
|
|
304
|
-
if quotient > self.alpha:
|
|
305
|
-
self._stop_training(epoch, f"Prechelt Criterion triggered. Generalization/Progress quotient ({quotient:.3f}) > alpha ({self.alpha}).")
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
class DragonModelCheckpoint(_Callback):
|
|
309
|
-
"""
|
|
310
|
-
Saves the model weights, optimizer state, LR scheduler state (if any), and epoch number to a directory with automated filename generation and rotation.
|
|
311
|
-
"""
|
|
312
|
-
def __init__(self,
|
|
313
|
-
save_dir: Union[str, Path],
|
|
314
|
-
monitor: Literal["Training Loss", "Validation Loss", "both"] = "Validation Loss",
|
|
315
|
-
save_three_best: bool = True,
|
|
316
|
-
mode: Literal['min', 'max'] = 'min',
|
|
317
|
-
verbose: int = 0):
|
|
318
|
-
"""
|
|
319
|
-
Args:
|
|
320
|
-
save_dir (str): Directory where checkpoint files will be saved.
|
|
321
|
-
monitor (str): Metric to monitor. If "both", the sum of training loss and validation loss is used.
|
|
322
|
-
save_three_best (bool):
|
|
323
|
-
- If True, keeps the top 3 best checkpoints found during training (based on metric).
|
|
324
|
-
- If False, keeps the 3 most recent checkpoints (rolling window).
|
|
325
|
-
mode (str): One of {'min', 'max'}.
|
|
326
|
-
verbose (int): Verbosity mode.
|
|
327
|
-
"""
|
|
328
|
-
super().__init__()
|
|
329
|
-
self.save_dir = make_fullpath(save_dir, make=True, enforce="directory")
|
|
330
|
-
|
|
331
|
-
# Standardize monitor key
|
|
332
|
-
if monitor == "Training Loss":
|
|
333
|
-
std_monitor = PyTorchLogKeys.TRAIN_LOSS
|
|
334
|
-
elif monitor == "Validation Loss":
|
|
335
|
-
std_monitor = PyTorchLogKeys.VAL_LOSS
|
|
336
|
-
elif monitor == "both":
|
|
337
|
-
std_monitor = "both"
|
|
338
|
-
else:
|
|
339
|
-
_LOGGER.error(f"Unknown monitor key: {monitor}.")
|
|
340
|
-
raise ValueError()
|
|
341
|
-
|
|
342
|
-
self.monitor = std_monitor
|
|
343
|
-
self.save_three_best = save_three_best
|
|
344
|
-
self.verbose = verbose
|
|
345
|
-
self._latest_checkpoint_path = None
|
|
346
|
-
self._checkpoint_name = PyTorchCheckpointKeys.CHECKPOINT_NAME
|
|
347
|
-
|
|
348
|
-
# State variables
|
|
349
|
-
# stored as list of dicts: [{'path': Path, 'score': float, 'epoch': int}]
|
|
350
|
-
self.best_checkpoints = []
|
|
351
|
-
# For rolling check (save_three_best=False)
|
|
352
|
-
self.recent_checkpoints = []
|
|
353
|
-
|
|
354
|
-
if mode not in ['min', 'max']:
|
|
355
|
-
_LOGGER.error(f"ModelCheckpoint mode {mode} is unknown. Use 'min' or 'max'.")
|
|
356
|
-
raise ValueError()
|
|
357
|
-
self.mode = mode
|
|
358
|
-
|
|
359
|
-
# Determine comparison operator
|
|
360
|
-
if self.mode == 'min':
|
|
361
|
-
self.monitor_op = np.less
|
|
362
|
-
self.best = np.inf
|
|
363
|
-
else:
|
|
364
|
-
self.monitor_op = np.greater
|
|
365
|
-
self.best = -np.inf
|
|
366
|
-
|
|
367
|
-
def on_train_begin(self, logs=None):
|
|
368
|
-
"""Reset file tracking state when training starts.
|
|
369
|
-
NOTE: Do nOT reset self.best here if it differs from the default. This allows the Trainer to restore 'best' from a checkpoint before calling train()."""
|
|
370
|
-
self.best_checkpoints = []
|
|
371
|
-
self.recent_checkpoints = []
|
|
372
|
-
|
|
373
|
-
# Check if self.best is at default initialization value
|
|
374
|
-
is_default_min = (self.mode == 'min' and self.best == np.inf)
|
|
375
|
-
is_default_max = (self.mode == 'max' and self.best == -np.inf)
|
|
376
|
-
|
|
377
|
-
# If it is NOT default, it means it was restored.
|
|
378
|
-
if not (is_default_min or is_default_max):
|
|
379
|
-
_LOGGER.debug(f"Resuming with best score: {self.best:.4f}")
|
|
380
|
-
|
|
381
|
-
def _get_metric_value(self, logs):
|
|
382
|
-
"""Extracts or calculates the metric value based on configuration."""
|
|
383
|
-
if self.monitor == "both":
|
|
384
|
-
t_loss = logs.get(PyTorchLogKeys.TRAIN_LOSS)
|
|
385
|
-
v_loss = logs.get(PyTorchLogKeys.VAL_LOSS)
|
|
386
|
-
if t_loss is None or v_loss is None:
|
|
387
|
-
return None
|
|
388
|
-
return t_loss + v_loss
|
|
389
|
-
else:
|
|
390
|
-
return logs.get(self.monitor)
|
|
391
|
-
|
|
392
|
-
def on_epoch_end(self, epoch, logs=None):
|
|
393
|
-
logs = logs or {}
|
|
394
|
-
current_score = self._get_metric_value(logs)
|
|
395
|
-
|
|
396
|
-
if current_score is None:
|
|
397
|
-
if self.verbose > 0:
|
|
398
|
-
_LOGGER.warning(f"Epoch {epoch}: Metric '{self.monitor}' not found in logs. Skipping checkpoint.")
|
|
399
|
-
return
|
|
400
|
-
|
|
401
|
-
# 1. Update global best score (for logging/metadata)
|
|
402
|
-
if self.monitor_op(current_score, self.best):
|
|
403
|
-
if self.verbose > 0:
|
|
404
|
-
# Only log explicit "improvement" if we are beating the historical best
|
|
405
|
-
old_best_str = f"{self.best:.4f}" if not np.isinf(self.best) else "inf"
|
|
406
|
-
_LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current_score:.4f}")
|
|
407
|
-
self.best = current_score
|
|
408
|
-
|
|
409
|
-
if self.save_three_best:
|
|
410
|
-
self._save_top_k_checkpoints(epoch, current_score)
|
|
411
|
-
else:
|
|
412
|
-
self._save_rolling_checkpoints(epoch, current_score)
|
|
413
|
-
|
|
414
|
-
def _save_checkpoint_file(self, epoch, current_score):
|
|
415
|
-
"""Helper to physically save the file."""
|
|
416
|
-
self.save_dir.mkdir(parents=True, exist_ok=True)
|
|
417
|
-
|
|
418
|
-
# Create filename
|
|
419
|
-
score_str = f"{current_score:.4f}".replace('.', '_')
|
|
420
|
-
filename = f"epoch{epoch}_{self._checkpoint_name}-{score_str}.pth"
|
|
421
|
-
filepath = self.save_dir / filename
|
|
422
|
-
|
|
423
|
-
# Create checkpoint dict
|
|
424
|
-
checkpoint_data = {
|
|
425
|
-
PyTorchCheckpointKeys.EPOCH: epoch,
|
|
426
|
-
PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
|
|
427
|
-
PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
|
|
428
|
-
PyTorchCheckpointKeys.BEST_SCORE: current_score,
|
|
429
|
-
PyTorchCheckpointKeys.HISTORY: self.trainer.history, # type: ignore
|
|
430
|
-
}
|
|
431
|
-
|
|
432
|
-
if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
|
|
433
|
-
checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
|
|
434
|
-
|
|
435
|
-
torch.save(checkpoint_data, filepath)
|
|
436
|
-
self._latest_checkpoint_path = filepath
|
|
437
|
-
|
|
438
|
-
return filepath
|
|
439
|
-
|
|
440
|
-
def _save_top_k_checkpoints(self, epoch, current_score):
    """Maintain at most the 3 best-scoring checkpoints on disk."""
    # Sort direction is chosen so index 0 is always BEST and index -1 WORST:
    #   'min' mode (lower is better)  -> ascending
    #   'max' mode (higher is better) -> descending
    descending = (self.mode == 'max')
    by_score = lambda entry: entry['score']

    if len(self.best_checkpoints) < 3:
        qualifies = True
    else:
        # After sorting, the worst retained entry sits at the tail.
        self.best_checkpoints.sort(key=by_score, reverse=descending)
        tail = self.best_checkpoints[-1]
        # monitor_op encodes the comparison direction:
        #   min mode: current < tail['score'] ; max mode: current > tail['score']
        qualifies = self.monitor_op(current_score, tail['score'])

    if not qualifies:
        return

    filepath = self._save_checkpoint_file(epoch, current_score)

    if self.verbose > 0:
        _LOGGER.info(f"Epoch {epoch}: {self.monitor} ({current_score:.4f}) is in top 3. Saving to {filepath.name}")

    self.best_checkpoints.append({'path': filepath, 'score': current_score, 'epoch': epoch})

    # Evict the worst entry once the list exceeds the limit of 3.
    if len(self.best_checkpoints) > 3:
        self.best_checkpoints.sort(key=by_score, reverse=descending)
        evicted = self.best_checkpoints.pop(-1)

        if evicted['path'].exists():
            if self.verbose > 0:
                _LOGGER.info(f" -> Deleting checkpoint outside top 3: {evicted['path'].name}")
            evicted['path'].unlink()
|
|
485
|
-
|
|
486
|
-
def _save_rolling_checkpoints(self, epoch, current_score):
    """Persist the newest checkpoint and retain only the 3 most recent files."""
    new_path = self._save_checkpoint_file(epoch, current_score)

    if self.verbose > 0:
        _LOGGER.info(f'Epoch {epoch}: saving rolling model to {new_path.name}')

    self.recent_checkpoints.append(new_path)

    # Drop the oldest file once the rolling window exceeds 3 entries.
    if len(self.recent_checkpoints) > 3:
        stale = self.recent_checkpoints.pop(0)
        if stale.exists():
            if self.verbose > 0:
                _LOGGER.info(f" -> Deleting old rolling checkpoint: {stale.name}")
            stale.unlink()
|
|
502
|
-
|
|
503
|
-
@property
def best_checkpoint_path(self):
    """Path of the best checkpoint tracked so far, or the latest one saved.

    Raises:
        ValueError: If no checkpoint has been saved yet.
    """
    # When tracking the top 3, return the single best entry among them.
    # min/max both return the FIRST among tied scores, matching the
    # stable-sort-then-index-0 behavior.
    if self.save_three_best and self.best_checkpoints:
        pick = max if self.mode == 'max' else min
        champion = pick(self.best_checkpoints, key=lambda entry: entry['score'])
        return champion['path']

    if self._latest_checkpoint_path:
        return self._latest_checkpoint_path

    _LOGGER.error("No checkpoint paths saved.")
    raise ValueError()
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
class _DragonLRScheduler(_Callback):
    """
    Shared foundation for Dragon LR scheduler callbacks.

    Centralizes trainer registration and learning-rate tracking/logging so
    concrete schedulers only implement their own stepping logic.
    """
    def __init__(self):
        super().__init__()
        # Concrete subclasses assign the actual torch scheduler instance.
        self.scheduler = None
        # Last observed learning rate; used to detect changes between epochs.
        self.previous_lr = None

    def set_trainer(self, trainer):
        """Associates the callback with the trainer."""
        super().set_trainer(trainer)
        # Subclasses that already own a scheduler get it registered here;
        # others must register once the scheduler actually exists.
        if self.scheduler:
            self.trainer.scheduler = self.scheduler  # type: ignore

    def on_train_begin(self, logs=None):
        """Capture the optimizer's starting learning rate."""
        optimizer = self.trainer.optimizer  # type: ignore
        if not optimizer:
            _LOGGER.warning("No optimizer found in trainer. LRScheduler cannot track learning rate.")
            return
        self.previous_lr = optimizer.param_groups[0]['lr']

    def _check_and_log_lr(self, epoch, logs, verbose: bool):
        """Detect LR changes, then record the current LR in logs and history."""
        optimizer = self.trainer.optimizer  # type: ignore
        if not optimizer:
            return

        lr_now = optimizer.param_groups[0]['lr']

        # Announce any change relative to the previously observed value.
        if self.previous_lr is not None and lr_now != self.previous_lr:
            if verbose:
                print(f" > Epoch {epoch}: Learning rate changed to {lr_now:.6f}")
            self.previous_lr = lr_now

        # Record in the per-epoch logs dictionary.
        logs[PyTorchLogKeys.LEARNING_RATE] = lr_now

        # Mirror the value into the trainer's history, when it keeps one.
        if hasattr(self.trainer, 'history'):
            self.trainer.history.setdefault(PyTorchLogKeys.LEARNING_RATE, []).append(lr_now)  # type: ignore
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
class DragonScheduler(_DragonLRScheduler):
    """
    Callback for standard PyTorch Learning Rate Schedulers.

    Compatible with: StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR, etc.

    NOT Compatible with: ReduceLROnPlateau (Use `DragonReduceLROnPlateau` instead).
    """
    def __init__(self, scheduler, verbose: bool=True):
        """
        Args:
            scheduler: An initialized PyTorch learning rate scheduler instance.
            verbose (bool): If True, logs learning rate changes to console.

        Raises:
            ValueError: If `scheduler` is a `ReduceLROnPlateau` instance, which
                needs a monitored metric and must use `DragonReduceLROnPlateau`.
        """
        super().__init__()
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # Fix: the message previously referenced "DragonLRScheduler",
            # which is not this class's public name.
            raise ValueError(
                "DragonScheduler does not support 'ReduceLROnPlateau'. "
                "Please use the `DragonReduceLROnPlateau` callback instead."
            )
        self.scheduler = scheduler
        self.verbose = verbose

    def set_trainer(self, trainer):
        """Attach to the trainer and register the scheduler for checkpointing."""
        super().set_trainer(trainer)
        # Explicitly register the scheduler again to be safe
        self.trainer.scheduler = self.scheduler  # type: ignore
        if self.verbose:
            _LOGGER.info(f"Registered LR Scheduler: {self.scheduler.__class__.__name__}")

    def on_epoch_end(self, epoch, logs=None):
        """Step the scheduler once per epoch and record the resulting LR."""
        logs = logs or {}

        # Standard schedulers step without needing a monitored metric.
        self.scheduler.step()

        self._check_and_log_lr(epoch, logs, self.verbose)
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
class DragonReduceLROnPlateau(_DragonLRScheduler):
    """
    Specific callback for `torch.optim.lr_scheduler.ReduceLROnPlateau`. Reduces learning rate when a monitored metric has stopped improving.

    This wrapper initializes the scheduler internally using the Trainer's optimizer, simplifying the setup process.
    """
    def __init__(self,
                 monitor: Literal["Training Loss", "Validation Loss"] = "Validation Loss",
                 mode: Literal['min', 'max'] = 'min',
                 factor: float = 0.1,
                 patience: int = 5,
                 threshold: float = 1e-4,
                 threshold_mode: Literal['rel', 'abs'] = 'rel',
                 cooldown: int = 0,
                 min_lr: float = 0,
                 eps: float = 1e-8,
                 verbose: bool = True):
        """
        Args:
            monitor ("Training Loss", "Validation Loss"): Metric to monitor.
            mode ('min', 'max'): One of 'min', 'max'.
            factor (float): Factor by which the learning rate will be reduced. new_lr = lr * factor.
            patience (int): Number of epochs with no improvement after which learning rate will be reduced.
            threshold (float): Threshold for measuring the new optimum.
            threshold_mode ('rel', 'abs'): One of 'rel', 'abs'.
            cooldown (int): Number of epochs to wait before resuming normal operation after lr has been reduced.
            min_lr (float or list): A scalar or a list of scalars.
            eps (float): Minimal decay applied to lr.
            verbose (bool): If True, logs learning rate changes to console.
        """
        super().__init__()

        # Translate the human-readable monitor label into the internal log key.
        monitor_keys = {
            "Training Loss": PyTorchLogKeys.TRAIN_LOSS,
            "Validation Loss": PyTorchLogKeys.VAL_LOSS,
        }
        if monitor not in monitor_keys:
            _LOGGER.error(f"Unknown monitor key: {monitor}.")
            raise ValueError()

        self.monitor = monitor_keys[monitor]
        self.verbose = verbose

        # The scheduler itself needs the optimizer, which only exists once a
        # trainer is attached; keep its kwargs until `set_trainer` runs.
        self.config = {
            'mode': mode,
            'factor': factor,
            'patience': patience,
            'threshold': threshold,
            'threshold_mode': threshold_mode,
            'cooldown': cooldown,
            'min_lr': min_lr,
            'eps': eps,
        }

    def set_trainer(self, trainer):
        """
        Initializes the ReduceLROnPlateau scheduler using the trainer's optimizer and registers it.
        """
        super().set_trainer(trainer)

        if not hasattr(self.trainer, 'optimizer'):
            _LOGGER.error("Trainer has no optimizer. Cannot initialize ReduceLROnPlateau.")
            raise ValueError()

        if self.verbose:
            _LOGGER.info(f"Initializing ReduceLROnPlateau monitoring '{self.monitor}'")

        # Deferred construction: the optimizer is only available now.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=self.trainer.optimizer,  # type: ignore
            **self.config
        )

        # Expose the scheduler on the trainer so checkpoints can capture it.
        self.trainer.scheduler = self.scheduler  # type: ignore

    def on_epoch_end(self, epoch, logs=None):
        """Step the scheduler with the monitored metric, then record the LR."""
        logs = logs or {}

        metric_val = logs.get(self.monitor)

        if metric_val is None:
            _LOGGER.warning(f"DragonReduceLROnPlateau could not find metric '{self.monitor}' in logs. Scheduler step skipped.")
            # Still log LR to keep history consistent
            self._check_and_log_lr(epoch, logs, self.verbose)
            return

        # Plateau scheduling requires the metric value.
        self.scheduler.step(metric_val)

        self._check_and_log_lr(epoch, logs, self.verbose)
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
def info():
    """Print this module's public API (`__all__`) via the shared helper."""
    _script_info(__all__)
|