nshtrainer 0.12.1 → 0.13.0 (py3-none-any.whl)
This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
nshtrainer/callbacks/early_stopping.py

@@ -8,6 +8,7 @@ from lightning.pytorch.callbacks import EarlyStopping as _EarlyStopping
 from lightning.pytorch.utilities.rank_zero import rank_prefixed_message
 from typing_extensions import override
 
+from ..metrics._config import MetricConfig
 from .base import CallbackConfigBase
 
 log = logging.getLogger(__name__)
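The new import is the pivot of this release: `MetricConfig` replaces the loose monitor/mode fields below. Its definition is not part of this diff; judging only from the two attributes the hunks below actually use (`validation_monitor` and `mode`), a minimal stand-in might look like the following sketch, in which every other name and the "val/" prefix are assumptions.

# Hypothetical stand-in for nshtrainer.metrics._config.MetricConfig, inferred
# solely from the attributes this diff touches (validation_monitor and mode).
# The real class is not shown here; the field names and the "val/" prefix are
# assumptions for illustration only.
from dataclasses import dataclass
from typing import Literal


@dataclass
class MetricConfigSketch:
    name: str  # e.g. "loss", as logged during validation (assumed)
    mode: Literal["min", "max"]  # whether a lower or higher value is better

    @property
    def validation_monitor(self) -> str:
        # Assumed convention: validation metrics are logged under a "val/" prefix.
        return f"val/{self.name}"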
@@ -16,18 +17,12 @@ log = logging.getLogger(__name__)
 class EarlyStoppingConfig(CallbackConfigBase):
     name: Literal["early_stopping"] = "early_stopping"
 
-    … (1 line not preserved in this view)
+    metric: MetricConfig | None = None
     """
     The metric to monitor for early stopping.
     If None, the primary metric will be used.
     """
 
-    mode: Literal["min", "max"] | None = None
-    """
-    The mode for the metric to monitor for early stopping.
-    If None, the primary metric mode will be used.
-    """
-
     patience: int
     """
     Number of epochs with no improvement after which training will be stopped.
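A hedged usage sketch of the reworked surface: setting `metric` pins the monitored metric explicitly, while leaving it `None` defers to the run's primary metric. `MetricConfig`'s constructor arguments and the defaults of the remaining `EarlyStoppingConfig` fields are not shown in this diff, so they are assumed below.

# Hedged usage sketch for the new schema; the MetricConfig kwargs and the
# defaults of the other EarlyStoppingConfig fields are assumptions.
from nshtrainer.callbacks.early_stopping import EarlyStoppingConfig
from nshtrainer.metrics._config import MetricConfig

explicit = EarlyStoppingConfig(
    metric=MetricConfig(name="loss", mode="min"),  # assumed kwargs
    patience=10,
)
deferred = EarlyStoppingConfig(patience=10)  # metric=None: use the primary metric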
@@ -52,64 +47,30 @@ class EarlyStoppingConfig(CallbackConfigBase):
 
     @override
     def create_callbacks(self, root_config):
-        … (8 lines not preserved in this view)
-                "No primary metric is set, so `monitor` must be provided in `early_stopping`."
-            )
-        monitor = primary_metric.validation_monitor
-        mode = primary_metric.mode
-
-        if mode is None:
-            mode = "min"
-
-        yield EarlyStopping(
-            monitor=monitor,
-            mode=mode,
-            patience=self.patience,
-            min_delta=self.min_delta,
-            min_lr=self.min_lr,
-            strict=self.strict,
-        )
+        if (metric := self.metric) is None and (
+            metric := root_config.primary_metric
+        ) is None:
+            raise ValueError(
+                "Either `metric` or `root_config.primary_metric` must be set."
+            )
+
+        yield EarlyStopping(self, metric)
 
 
 class EarlyStopping(_EarlyStopping):
-    def __init__(
-        self
-        … (3 lines not preserved in this view)
-        patience: int = 3,
-        verbose: bool = True,
-        mode: str = "min",
-        strict: bool = True,
-        check_finite: bool = True,
-        stopping_threshold: float | None = None,
-        divergence_threshold: float | None = None,
-        check_on_train_epoch_end: bool | None = None,
-        log_rank_zero_only: bool = False,
-    ):
+    def __init__(self, config: EarlyStoppingConfig, metric: MetricConfig):
+        self.config = config
+        self.metric = metric
+        del config, metric
+
         super().__init__(
-            monitor,
-            … (1 line not preserved in this view)
-            patience,
-            … (2 lines not preserved in this view)
-            strict,
-            check_finite,
-            stopping_threshold,
-            divergence_threshold,
-            check_on_train_epoch_end,
-            log_rank_zero_only,
+            monitor=self.metric.validation_monitor,
+            mode=self.metric.mode,
+            patience=self.patience,
+            min_delta=self.min_delta,
+            strict=self.strict,
         )
 
-        self.min_lr = min_lr
-
     @override
     @staticmethod
     def _log_info(
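The rewritten `create_callbacks` relies on short-circuiting: the second assignment expression only runs when `self.metric` was `None`, so `metric` ends up bound to the first non-`None` of the two sources, and the `ValueError` fires only when both are missing. A self-contained illustration of the idiom, with names invented for the demo:

# Demonstrates the walrus-based fallback pattern used in create_callbacks.
# `pick` and its arguments are invented for this demo; only the control flow
# mirrors the diff.
def pick(explicit: str | None, default: str | None) -> str:
    if (choice := explicit) is None and (choice := default) is None:
        raise ValueError("Either `explicit` or `default` must be set.")
    return choice  # first non-None of the two sources


assert pick("adam", None) == "adam"  # explicit value wins
assert pick(None, "sgd") == "sgd"    # falls back to the default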
@@ -152,7 +113,7 @@ class EarlyStopping(_EarlyStopping):
     def _evaluate_stopping_criteria_min_lr(
         self, trainer: Trainer
     ) -> tuple[bool, str | None]:
-        if self.min_lr is None:
+        if self.config.min_lr is None:
             return False, None
 
         # Get the maximum LR across all param groups in all optimizers
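The code that actually collects the learning rates sits between these two hunks and is unchanged, so it is not shown here. Assuming it scans every param group of every optimizer attached to the trainer, as the comment says, the scan might look like this sketch (the helper name is hypothetical):

# Sketch of the "maximum LR across all param groups in all optimizers" scan
# referenced by the comment above; the real helper is not part of this diff.
from lightning.pytorch import Trainer


def _max_lr_across_optimizers(trainer: Trainer) -> float | None:
    lrs = [
        float(group["lr"])
        for optimizer in trainer.optimizers
        for group in optimizer.param_groups
    ]
    return max(lrs) if lrs else None  # None when no optimizer is configured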
@@ -167,13 +128,13 @@ class EarlyStopping(_EarlyStopping):
             return False, None
 
         # If the maximum LR is less than the minimum LR, stop training
-        if model_max_lr >= self.min_lr:
+        if model_max_lr >= self.config.min_lr:
             return False, None
 
         return True, (
             "Stopping threshold reached: "
             f"The maximum LR of the model across all param groups is {model_max_lr:.2e} "
-            f"which is less than the minimum LR {self.min_lr:.2e}"
+            f"which is less than the minimum LR {self.config.min_lr:.2e}"
         )
 
     def on_early_stopping(self, trainer: Trainer):
nshtrainer-0.12.1.dist-info/RECORD → nshtrainer-0.13.0.dist-info/RECORD

@@ -15,7 +15,7 @@ nshtrainer/callbacks/checkpoint/_base.py,sha256=yBUZ63P4a3m5jviIUGxBd2WXq_TigpKt
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
-nshtrainer/callbacks/early_stopping.py,sha256=
+nshtrainer/callbacks/early_stopping.py,sha256=2vhlXvp6gIFaJXjmPFVK-nbuBUva3200rve5s0krs2c,4618
 nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
 nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
 nshtrainer/callbacks/gradient_skipping.py,sha256=pqu5AELx4ctJxR2Y7YSSiGd5oGauVCTZFCEIIS6s88w,3665
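Each RECORD line follows the wheel format: the file path, then `sha256=` plus the urlsafe-base64 encoding of the file's SHA-256 digest with trailing `=` padding stripped, then the size in bytes (4618 for the new early_stopping.py). A sketch of how such an entry is derived (the helper name is invented):

# Builds a wheel RECORD line for one file: "path,sha256=<digest>,<size>".
# The function name is invented; the encoding rules follow the wheel spec.
import base64
import hashlib
from pathlib import Path


def record_entry(path: str) -> str:
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest())
    return f"{path},sha256={digest.rstrip(b'=').decode()},{len(data)}"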
@@ -82,6 +82,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.
-nshtrainer-0.
-nshtrainer-0.
+nshtrainer-0.13.0.dist-info/METADATA,sha256=1OsZvE4GYPu85jiPrBJffrqGvsmdUuQU96TGIYnlIoo,860
+nshtrainer-0.13.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.13.0.dist-info/RECORD,,
Files without changes are not shown.