dragon-ml-toolbox 19.14.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.14.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1909
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
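
The pattern across this release is that each flat module is promoted to a subpackage: for example `ml_tools/ML_callbacks.py` becomes `ml_tools/ML_callbacks/__init__.py`, with the implementation split into private submodules (`_base.py`, `_checkpoint.py`, `_early_stop.py`, `_scheduler.py`), while the old `_core` mega-modules are deleted. Because the new `__init__.py` files re-export the public names (see the hunks below), existing import paths should keep working. A minimal sketch, assuming only the names listed in the new `__all__` definitions:

# 19.14.0: ml_tools/ML_callbacks.py        (flat module)
# 20.0.0:  ml_tools/ML_callbacks/           (package re-exporting the same public names)
from ml_tools.ML_callbacks import DragonModelCheckpoint, DragonPatienceEarlyStopping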

ml_tools/ML_callbacks/__init__.py  (+12 -4)
@@ -1,16 +1,24 @@
-from .
+from ._early_stop import (
     DragonPatienceEarlyStopping,
     DragonPrecheltEarlyStopping,
+)
+
+from ._checkpoint import (
     DragonModelCheckpoint,
+)
+
+from ._scheduler import (
     DragonScheduler,
-
-    info
+    DragonPlateauScheduler,
 )
 
+from ._imprimir import info
+
+
 __all__ = [
     "DragonPatienceEarlyStopping",
     "DragonPrecheltEarlyStopping",
     "DragonModelCheckpoint",
     "DragonScheduler",
-    "
+    "DragonPlateauScheduler",
 ]

ml_tools/ML_callbacks/_base.py  (new file, +101 -0)
@@ -0,0 +1,101 @@

from tqdm.auto import tqdm

from ..keys._keys import PyTorchLogKeys


__all__ = [
    "_Callback",
    "History",
    "TqdmProgressBar",
]


class _Callback:
    """
    Abstract base class used to build new callbacks.

    The methods of this class are automatically called by the Trainer at different
    points during training. Subclasses can override these methods to implement
    custom logic.
    """
    def __init__(self):
        self.trainer = None

    def set_trainer(self, trainer):
        """This is called by the Trainer to associate itself with the callback."""
        self.trainer = trainer

    def on_train_begin(self, logs=None):
        """Called at the beginning of training."""
        pass

    def on_train_end(self, logs=None):
        """Called at the end of training."""
        pass

    def on_epoch_begin(self, epoch, logs=None):
        """Called at the beginning of an epoch."""
        pass

    def on_epoch_end(self, epoch, logs=None):
        """Called at the end of an epoch."""
        pass

    def on_batch_begin(self, batch, logs=None):
        """Called at the beginning of a training batch."""
        pass

    def on_batch_end(self, batch, logs=None):
        """Called at the end of a training batch."""
        pass


class History(_Callback):
    """
    Callback that records events into a `history` dictionary.

    This callback is automatically applied to every MyTrainer model.
    The `history` attribute is a dictionary mapping metric names (e.g., 'val_loss')
    to a list of metric values.
    """
    def on_train_begin(self, logs=None):
        # Clear history at the beginning of training
        self.trainer.history = {}  # type: ignore

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        for k, v in logs.items():
            # Append new log values to the history dictionary
            self.trainer.history.setdefault(k, []).append(v)  # type: ignore


class TqdmProgressBar(_Callback):
    """Callback that provides a tqdm progress bar for training."""
    def __init__(self):
        self.epoch_bar = None
        self.batch_bar = None

    def on_train_begin(self, logs=None):
        self.epochs = self.trainer.epochs  # type: ignore
        self.epoch_bar = tqdm(total=self.epochs, desc="Training Progress")

    def on_epoch_begin(self, epoch, logs=None):
        total_batches = len(self.trainer.train_loader)  # type: ignore
        self.batch_bar = tqdm(total=total_batches, desc=f"Epoch {epoch}/{self.epochs}", leave=False)

    def on_batch_end(self, batch, logs=None):
        self.batch_bar.update(1)  # type: ignore
        if logs:
            self.batch_bar.set_postfix(loss=f"{logs.get(PyTorchLogKeys.BATCH_LOSS, 0):.4f}")  # type: ignore

    def on_epoch_end(self, epoch, logs=None):
        self.batch_bar.close()  # type: ignore
        self.epoch_bar.update(1)  # type: ignore
        if logs:
            train_loss_str = f"{logs.get(PyTorchLogKeys.TRAIN_LOSS, 0):.4f}"
            val_loss_str = f"{logs.get(PyTorchLogKeys.VAL_LOSS, 0):.4f}"
            self.epoch_bar.set_postfix_str(f"Train Loss: {train_loss_str}, Val Loss: {val_loss_str}")  # type: ignore

    def on_train_end(self, logs=None):
        self.epoch_bar.close()  # type: ignore
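
Since `_Callback` defines no-op hooks for every training event and `set_trainer` hands each callback a reference back to the trainer, a user-defined callback only needs to override the hooks it cares about. A minimal sketch, assuming the trainer calls these hooks with a `logs` dict as shown above (the `"val_loss"` key used here is a placeholder; the toolbox itself uses `PyTorchLogKeys` constants):

from ml_tools.ML_callbacks._base import _Callback

class ValLossPrinter(_Callback):
    """Illustrative custom callback: print the logged validation loss each epoch."""
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        val = logs.get("val_loss")  # placeholder key, see PyTorchLogKeys
        if val is not None:
            print(f"epoch {epoch}: val_loss={val:.4f}")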

ml_tools/ML_callbacks/_checkpoint.py  (new file, +232 -0)
@@ -0,0 +1,232 @@

import numpy as np
import torch
from typing import Union, Literal
from pathlib import Path

from ..path_manager import make_fullpath
from ..keys._keys import PyTorchLogKeys, PyTorchCheckpointKeys
from .._core import get_logger

from ._base import _Callback


_LOGGER = get_logger("Checkpoint")


__all__ = [
    "DragonModelCheckpoint",
]


class DragonModelCheckpoint(_Callback):
    """
    Saves the model weights, optimizer state, LR scheduler state (if any), and epoch number to a directory with automated filename generation and rotation.
    """
    def __init__(self,
                 save_dir: Union[str, Path],
                 monitor: Literal["Training Loss", "Validation Loss", "both"] = "Validation Loss",
                 save_three_best: bool = True,
                 mode: Literal['min', 'max'] = 'min',
                 verbose: int = 0):
        """
        Args:
            save_dir (str): Directory where checkpoint files will be saved.
            monitor (str): Metric to monitor. If "both", the sum of training loss and validation loss is used.
            save_three_best (bool):
                - If True, keeps the top 3 best checkpoints found during training (based on metric).
                - If False, keeps the 3 most recent checkpoints (rolling window).
            mode (str): One of {'min', 'max'}.
            verbose (int): Verbosity mode.
        """
        super().__init__()
        self.save_dir = make_fullpath(save_dir, make=True, enforce="directory")

        # Standardize monitor key
        if monitor == "Training Loss":
            std_monitor = PyTorchLogKeys.TRAIN_LOSS
        elif monitor == "Validation Loss":
            std_monitor = PyTorchLogKeys.VAL_LOSS
        elif monitor == "both":
            std_monitor = "both"
        else:
            _LOGGER.error(f"Unknown monitor key: {monitor}.")
            raise ValueError()

        self.monitor = std_monitor
        self.save_three_best = save_three_best
        self.verbose = verbose
        self._latest_checkpoint_path = None
        self._checkpoint_name = PyTorchCheckpointKeys.CHECKPOINT_NAME

        # State variables
        # stored as list of dicts: [{'path': Path, 'score': float, 'epoch': int}]
        self.best_checkpoints = []
        # For rolling check (save_three_best=False)
        self.recent_checkpoints = []

        if mode not in ['min', 'max']:
            _LOGGER.error(f"ModelCheckpoint mode {mode} is unknown. Use 'min' or 'max'.")
            raise ValueError()
        self.mode = mode

        # Determine comparison operator
        if self.mode == 'min':
            self.monitor_op = np.less
            self.best = np.inf
        else:
            self.monitor_op = np.greater
            self.best = -np.inf

    def on_train_begin(self, logs=None):
        """Reset file tracking state when training starts.
        NOTE: Do NOT reset self.best here if it differs from the default. This allows the Trainer to restore 'best' from a checkpoint before calling train()."""
        self.best_checkpoints = []
        self.recent_checkpoints = []

        # Check if self.best is at default initialization value
        is_default_min = (self.mode == 'min' and self.best == np.inf)
        is_default_max = (self.mode == 'max' and self.best == -np.inf)

        # If it is NOT default, it means it was restored.
        if not (is_default_min or is_default_max):
            _LOGGER.debug(f"Resuming with best score: {self.best:.4f}")

    def _get_metric_value(self, logs):
        """Extracts or calculates the metric value based on configuration."""
        if self.monitor == "both":
            t_loss = logs.get(PyTorchLogKeys.TRAIN_LOSS)
            v_loss = logs.get(PyTorchLogKeys.VAL_LOSS)
            if t_loss is None or v_loss is None:
                return None
            return t_loss + v_loss
        else:
            return logs.get(self.monitor)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        current_score = self._get_metric_value(logs)

        if current_score is None:
            if self.verbose > 0:
                _LOGGER.warning(f"Epoch {epoch}: Metric '{self.monitor}' not found in logs. Skipping checkpoint.")
            return

        # 1. Update global best score (for logging/metadata)
        if self.monitor_op(current_score, self.best):
            if self.verbose > 0:
                # Only log explicit "improvement" if we are beating the historical best
                old_best_str = f"{self.best:.4f}" if not np.isinf(self.best) else "inf"
                _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current_score:.4f}")
            self.best = current_score

        if self.save_three_best:
            self._save_top_k_checkpoints(epoch, current_score)
        else:
            self._save_rolling_checkpoints(epoch, current_score)

    def _save_checkpoint_file(self, epoch, current_score):
        """Helper to physically save the file."""
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Create filename
        score_str = f"{current_score:.4f}".replace('.', '_')
        filename = f"epoch{epoch}_{self._checkpoint_name}-{score_str}.pth"
        filepath = self.save_dir / filename

        # Create checkpoint dict
        checkpoint_data = {
            PyTorchCheckpointKeys.EPOCH: epoch,
            PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(),  # type: ignore
            PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(),  # type: ignore
            PyTorchCheckpointKeys.BEST_SCORE: current_score,
            PyTorchCheckpointKeys.HISTORY: self.trainer.history,  # type: ignore
        }

        if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None:  # type: ignore
            checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict()  # type: ignore

        torch.save(checkpoint_data, filepath)
        self._latest_checkpoint_path = filepath

        return filepath

    def _save_top_k_checkpoints(self, epoch, current_score):
        """Logic for maintaining the top 3 best checkpoints."""

        def sort_key(item): return item['score']

        # Determine sort direction so that Index 0 is BEST and Index -1 is WORST
        # Min mode (lower is better): Ascending (reverse=False) -> [0.1, 0.5, 0.9] (0.1 is best)
        # Max mode (higher is better): Descending (reverse=True) -> [0.9, 0.5, 0.1] (0.9 is best)
        is_reverse = (self.mode == 'max')

        should_save = False

        if len(self.best_checkpoints) < 3:
            should_save = True
        else:
            # Sort current list to identify the worst (last item)
            self.best_checkpoints.sort(key=sort_key, reverse=is_reverse)
            worst_entry = self.best_checkpoints[-1]

            # Check if current is better than the worst in the list
            # min mode: current < worst['score']
            # max mode: current > worst['score']
            if self.monitor_op(current_score, worst_entry['score']):
                should_save = True

        if should_save:
            filepath = self._save_checkpoint_file(epoch, current_score)

            if self.verbose > 0:
                _LOGGER.info(f"Epoch {epoch}: {self.monitor} ({current_score:.4f}) is in top 3. Saving to {filepath.name}")

            self.best_checkpoints.append({'path': filepath, 'score': current_score, 'epoch': epoch})

            # Prune if > 3
            if len(self.best_checkpoints) > 3:
                # Re-sort to ensure worst is at the end
                self.best_checkpoints.sort(key=sort_key, reverse=is_reverse)

                # Evict the last one (Worst)
                entry_to_delete = self.best_checkpoints.pop(-1)

                if entry_to_delete['path'].exists():
                    if self.verbose > 0:
                        _LOGGER.info(f" -> Deleting checkpoint outside top 3: {entry_to_delete['path'].name}")
                    entry_to_delete['path'].unlink()

    def _save_rolling_checkpoints(self, epoch, current_score):
        """Saves the latest model and keeps only the 3 most recent ones."""
        filepath = self._save_checkpoint_file(epoch, current_score)

        if self.verbose > 0:
            _LOGGER.info(f'Epoch {epoch}: saving rolling model to {filepath.name}')

        self.recent_checkpoints.append(filepath)

        # If we have more than 3 checkpoints, remove the oldest one
        if len(self.recent_checkpoints) > 3:
            file_to_delete = self.recent_checkpoints.pop(0)
            if file_to_delete.exists():
                if self.verbose > 0:
                    _LOGGER.info(f" -> Deleting old rolling checkpoint: {file_to_delete.name}")
                file_to_delete.unlink()

    @property
    def best_checkpoint_path(self):
        # If tracking top 3, return the absolute best among them
        if self.save_three_best and self.best_checkpoints:
            def sort_key(item): return item['score']
            is_reverse = (self.mode == 'max')
            # Sort Best -> Worst
            sorted_bests = sorted(self.best_checkpoints, key=sort_key, reverse=is_reverse)
            # Index 0 is always the best based on the logic above
            return sorted_bests[0]['path']

        elif self._latest_checkpoint_path:
            return self._latest_checkpoint_path
        else:
            _LOGGER.error("No checkpoint paths saved.")
            raise ValueError()
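
A minimal construction sketch for `DragonModelCheckpoint`, based only on the `__init__` signature and the `best_checkpoint_path` property in this hunk; how the callback is attached to a trainer is not shown here and is left as a comment:

from pathlib import Path
from ml_tools.ML_callbacks import DragonModelCheckpoint

ckpt = DragonModelCheckpoint(
    save_dir=Path("runs/exp1/checkpoints"),  # created if missing
    monitor="Validation Loss",               # or "Training Loss" / "both"
    save_three_best=True,                    # keep top-3 by score; False = 3 most recent
    mode="min",
    verbose=1,
)
# ... attach `ckpt` to a trainer and run training (the trainer API is outside this hunk) ...
# best_path = ckpt.best_checkpoint_path  # raises ValueError if no checkpoint was saved yet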

ml_tools/ML_callbacks/_early_stop.py  (new file, +208 -0)
@@ -0,0 +1,208 @@

import numpy as np
from collections import deque
from typing import Literal

from ..keys._keys import PyTorchLogKeys
from .._core import get_logger

from ._base import _Callback


_LOGGER = get_logger("EarlyStopping")


__all__ = [
    "DragonPatienceEarlyStopping",
    "DragonPrecheltEarlyStopping",
]


class _DragonEarlyStopping(_Callback):
    """
    Base class for Early Stopping strategies.
    Ensures type compatibility and shared logging logic.
    """
    def __init__(self,
                 monitor: str,
                 verbose: int = 1):
        super().__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.stopped_epoch = 0

    def _stop_training(self, epoch: int, reason: str):
        """Helper to trigger the stop."""
        self.stopped_epoch = epoch
        self.trainer.stop_training = True  # type: ignore
        if self.verbose > 0:
            _LOGGER.info(f"Epoch {epoch}: Early stopping triggered. Reason: {reason}")


class DragonPatienceEarlyStopping(_DragonEarlyStopping):
    """
    Standard early stopping: Tracks minimum validation loss (or other metric) with a patience counter.
    """
    def __init__(self,
                 monitor: Literal["Training Loss", "Validation Loss"] = "Validation Loss",
                 min_delta: float = 0.0,
                 patience: int = 10,
                 mode: Literal['min', 'max'] = 'min',
                 verbose: int = 1):
        """
        Args:
            monitor (str): Metric to monitor.
            min_delta (float): Minimum change to qualify as an improvement.
            patience (int): Number of epochs with no improvement after which training will be stopped.
            mode (str): One of {'min', 'max'}. In 'min' mode, training will stop when the quantity monitored has stopped decreasing; in 'max' mode it will stop when the quantity monitored has stopped increasing.
            verbose (int): Verbosity mode.
        """
        # standardize monitor key
        if monitor == "Training Loss":
            std_monitor = PyTorchLogKeys.TRAIN_LOSS
        elif monitor == "Validation Loss":
            std_monitor = PyTorchLogKeys.VAL_LOSS
        else:
            _LOGGER.error(f"Unknown monitor key: {monitor}.")
            raise ValueError()

        super().__init__(std_monitor, verbose)
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        self.mode = mode

        if mode not in ['min', 'max']:
            _LOGGER.error(f"EarlyStopping mode {mode} is unknown, choose one of ('min', 'max')")
            raise ValueError()

        # Determine the comparison operator
        if self.mode == 'min':
            self.monitor_op = np.less
        elif self.mode == 'max':
            self.monitor_op = np.greater
        else:
            # raise error for unknown mode
            _LOGGER.error(f"EarlyStopping mode {mode} is unknown, choose one of ('min', 'max')")
            raise ValueError()

        self.best = np.inf if self.monitor_op == np.less else -np.inf

    def on_train_begin(self, logs=None):
        self.wait = 0
        self.best = np.inf if self.monitor_op == np.less else -np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get(self.monitor)  # type: ignore
        if current is None:
            return

        # Check improvement
        if self.monitor_op == np.less:
            is_improvement = self.monitor_op(current, self.best - self.min_delta)
        else:
            is_improvement = self.monitor_op(current, self.best + self.min_delta)

        if is_improvement:
            if self.verbose > 1:
                _LOGGER.info(f"EarlyStopping: {self.monitor} improved from {self.best:.4f} to {current:.4f}")
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self._stop_training(epoch, f"No improvement in {self.monitor} for {self.wait} epochs.")


class DragonPrecheltEarlyStopping(_DragonEarlyStopping):
    """
    Implements Prechelt's 'Progress-Modified GL' criterion.
    Tracks the ratio between Generalization Loss (overfitting) and Training Progress.

    References:
        Prechelt, L. (1998). Early Stopping - But When?
    """
    def __init__(self,
                 alpha: float = 0.75,
                 window_size: int = 5,
                 verbose: int = 1):
        """
        This early stopping strategy monitors both validation loss and training loss to determine the optimal stopping point.

        Args:
            alpha (float): The threshold for the stopping criterion.
            window_size (int): The window size for calculating training progress.
            verbose (int): Verbosity mode.

        NOTE:

        - **The Window Size (k)**:
            - `5`: The empirical "gold standard." It is long enough to smooth out batch noise but short enough to react to convergence plateaus quickly.
            - `10` to `20`: Use if the training curve is very jagged (e.g., noisy data, small batch sizes, high dropout, or Reinforcement Learning). A larger k value prevents premature stopping due to random volatility.
        - **The threshold (alpha)**:
            - `< 0.5`: Aggressive. Stops training very early.
            - `0.75` to `0.80`: Prechelt found this range to be the most robust across different datasets. It typically yields the best trade-off between generalization and training cost.
            - `1.0` to `1.2`: Useful for complex tasks (like Transformers) where training progress might dip temporarily before recovering. It risks slightly more overfitting but ensures potential is exhausted.
        """
        super().__init__(PyTorchLogKeys.VAL_LOSS, verbose)
        self.train_monitor = PyTorchLogKeys.TRAIN_LOSS
        self.alpha = alpha
        self.k = window_size

        self.best_val_loss = np.inf
        self.train_strip = deque(maxlen=window_size)

    def on_train_begin(self, logs=None):
        self.best_val_loss = np.inf
        self.train_strip.clear()

    def on_epoch_end(self, epoch, logs=None):
        val_loss = logs.get(self.monitor)  # type: ignore
        train_loss = logs.get(self.train_monitor)  # type: ignore

        if val_loss is None or train_loss is None:
            return

        # 1. Update Best Validation Loss
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss

        # 2. Update Training Strip
        self.train_strip.append(train_loss)

        # 3. Calculate Generalization Loss (GL)
        # GL(t) = 100 * (E_val / E_opt - 1)
        # Low GL is good. High GL means we are drifting away from best val score (overfitting).
        gl = 100 * ((val_loss / self.best_val_loss) - 1)

        # 4. Calculate Progress (Pk)
        # Pk(t) = 1000 * (Sum(strip) / (k * min(strip)) - 1)
        # High Pk is good (training loss is still dropping fast). Low Pk means training has stalled.
        if len(self.train_strip) < self.k:
            # Not enough data for progress yet
            return

        strip_sum = sum(self.train_strip)
        strip_min = min(self.train_strip)

        # Avoid division by zero
        if strip_min == 0:
            pk = 0.1  # Arbitrary small number
        else:
            pk = 1000 * ((strip_sum / (self.k * strip_min)) - 1)

        # 5. The Quotient Criterion
        # Stop if GL / Pk > alpha
        # Intuition: Stop if Overfitting is high AND Progress is low.

        # Avoid division by zero
        if pk == 0:
            pk = 1e-6

        quotient = gl / pk

        if self.verbose > 1:
            _LOGGER.info(f"Epoch {epoch}: GL={gl:.3f} | Pk={pk:.3f} | Quotient={quotient:.3f} (Threshold={self.alpha})")

        if quotient > self.alpha:
            self._stop_training(epoch, f"Prechelt Criterion triggered. Generalization/Progress quotient ({quotient:.3f}) > alpha ({self.alpha}).")
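
To make the quotient criterion concrete, here is a small worked example following the GL and Pk formulas in the code above (the numbers are illustrative, not taken from the package): with best validation loss 0.200, current validation loss 0.230, and a k = 5 strip of recent training losses that has nearly flattened, the quotient exceeds the default alpha of 0.75 and training would stop.

# Worked example of the Prechelt quotient (illustrative values only)
best_val, val_loss = 0.200, 0.230
strip = [0.210, 0.205, 0.202, 0.201, 0.200]   # last k = 5 training losses
k, alpha = len(strip), 0.75

gl = 100 * (val_loss / best_val - 1)              # generalization loss: 15.0
pk = 1000 * (sum(strip) / (k * min(strip)) - 1)   # training progress: ~18.0
quotient = gl / pk                                # ~0.83 > alpha -> stop
print(f"GL={gl:.1f} Pk={pk:.1f} quotient={quotient:.2f} stop={quotient > alpha}")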