nshtrainer 0.12.0__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/callbacks/__init__.py +3 -1
- nshtrainer/callbacks/early_stopping.py +68 -0
- nshtrainer/model/config.py +1 -67
- {nshtrainer-0.12.0.dist-info → nshtrainer-0.12.1.dist-info}/METADATA +1 -1
- {nshtrainer-0.12.0.dist-info → nshtrainer-0.12.1.dist-info}/RECORD +6 -6
- {nshtrainer-0.12.0.dist-info → nshtrainer-0.12.1.dist-info}/WHEEL +0 -0
nshtrainer/callbacks/__init__.py
CHANGED
|
@@ -13,6 +13,7 @@ from .checkpoint import (
|
|
|
13
13
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
|
14
14
|
)
|
|
15
15
|
from .early_stopping import EarlyStopping as EarlyStopping
|
|
16
|
+
from .early_stopping import EarlyStoppingConfig as EarlyStoppingConfig
|
|
16
17
|
from .ema import EMA as EMA
|
|
17
18
|
from .ema import EMAConfig as EMAConfig
|
|
18
19
|
from .finite_checks import FiniteChecksCallback as FiniteChecksCallback
|
|
@@ -34,7 +35,8 @@ from .wandb_watch import WandbWatchCallback as WandbWatchCallback
|
|
|
34
35
|
from .wandb_watch import WandbWatchConfig as WandbWatchConfig
|
|
35
36
|
|
|
36
37
|
CallbackConfig = Annotated[
|
|
37
|
-
|
|
38
|
+
EarlyStoppingConfig
|
|
39
|
+
| ThroughputMonitorConfig
|
|
38
40
|
| EpochTimerConfig
|
|
39
41
|
| PrintTableMetricsConfig
|
|
40
42
|
| FiniteChecksConfig
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import math
|
|
3
|
+
from typing import Literal
|
|
3
4
|
|
|
4
5
|
from lightning.fabric.utilities.rank_zero import _get_rank
|
|
5
6
|
from lightning.pytorch import Trainer
|
|
@@ -7,9 +8,76 @@ from lightning.pytorch.callbacks import EarlyStopping as _EarlyStopping
|
|
|
7
8
|
from lightning.pytorch.utilities.rank_zero import rank_prefixed_message
|
|
8
9
|
from typing_extensions import override
|
|
9
10
|
|
|
11
|
+
from .base import CallbackConfigBase
|
|
12
|
+
|
|
10
13
|
log = logging.getLogger(__name__)
|
|
11
14
|
|
|
12
15
|
|
|
16
|
+
class EarlyStoppingConfig(CallbackConfigBase):
|
|
17
|
+
name: Literal["early_stopping"] = "early_stopping"
|
|
18
|
+
|
|
19
|
+
monitor: str | None = None
|
|
20
|
+
"""
|
|
21
|
+
The metric to monitor for early stopping.
|
|
22
|
+
If None, the primary metric will be used.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
mode: Literal["min", "max"] | None = None
|
|
26
|
+
"""
|
|
27
|
+
The mode for the metric to monitor for early stopping.
|
|
28
|
+
If None, the primary metric mode will be used.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
patience: int
|
|
32
|
+
"""
|
|
33
|
+
Number of epochs with no improvement after which training will be stopped.
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
min_delta: float = 1.0e-8
|
|
37
|
+
"""
|
|
38
|
+
Minimum change in the monitored quantity to qualify as an improvement.
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
min_lr: float | None = None
|
|
42
|
+
"""
|
|
43
|
+
Minimum learning rate. If the learning rate of the model is less than this value,
|
|
44
|
+
the training will be stopped.
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
strict: bool = True
|
|
48
|
+
"""
|
|
49
|
+
Whether to enforce that the monitored quantity must improve by at least `min_delta`
|
|
50
|
+
to qualify as an improvement.
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
@override
|
|
54
|
+
def create_callbacks(self, root_config):
|
|
55
|
+
monitor = self.monitor
|
|
56
|
+
mode = self.mode
|
|
57
|
+
if monitor is None:
|
|
58
|
+
assert mode is None, "If `monitor` is not provided, `mode` must be None."
|
|
59
|
+
|
|
60
|
+
primary_metric = root_config.primary_metric
|
|
61
|
+
if primary_metric is None:
|
|
62
|
+
raise ValueError(
|
|
63
|
+
"No primary metric is set, so `monitor` must be provided in `early_stopping`."
|
|
64
|
+
)
|
|
65
|
+
monitor = primary_metric.validation_monitor
|
|
66
|
+
mode = primary_metric.mode
|
|
67
|
+
|
|
68
|
+
if mode is None:
|
|
69
|
+
mode = "min"
|
|
70
|
+
|
|
71
|
+
yield EarlyStopping(
|
|
72
|
+
monitor=monitor,
|
|
73
|
+
mode=mode,
|
|
74
|
+
patience=self.patience,
|
|
75
|
+
min_delta=self.min_delta,
|
|
76
|
+
min_lr=self.min_lr,
|
|
77
|
+
strict=self.strict,
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
|
|
13
81
|
class EarlyStopping(_EarlyStopping):
|
|
14
82
|
def __init__(
|
|
15
83
|
self,
|
nshtrainer/model/config.py
CHANGED
|
@@ -38,6 +38,7 @@ from .._checkpoint.loader import CheckpointLoadingConfig
|
|
|
38
38
|
from ..callbacks import (
|
|
39
39
|
BestCheckpointCallbackConfig,
|
|
40
40
|
CallbackConfig,
|
|
41
|
+
EarlyStoppingConfig,
|
|
41
42
|
LastCheckpointCallbackConfig,
|
|
42
43
|
OnExceptionCheckpointCallbackConfig,
|
|
43
44
|
WandbWatchConfig,
|
|
@@ -1043,73 +1044,6 @@ class LightningTrainerKwargs(TypedDict, total=False):
|
|
|
1043
1044
|
"""
|
|
1044
1045
|
|
|
1045
1046
|
|
|
1046
|
-
class EarlyStoppingConfig(CallbackConfigBase):
|
|
1047
|
-
monitor: str | None = None
|
|
1048
|
-
"""
|
|
1049
|
-
The metric to monitor for early stopping.
|
|
1050
|
-
If None, the primary metric will be used.
|
|
1051
|
-
"""
|
|
1052
|
-
|
|
1053
|
-
mode: Literal["min", "max"] | None = None
|
|
1054
|
-
"""
|
|
1055
|
-
The mode for the metric to monitor for early stopping.
|
|
1056
|
-
If None, the primary metric mode will be used.
|
|
1057
|
-
"""
|
|
1058
|
-
|
|
1059
|
-
patience: int
|
|
1060
|
-
"""
|
|
1061
|
-
Number of epochs with no improvement after which training will be stopped.
|
|
1062
|
-
"""
|
|
1063
|
-
|
|
1064
|
-
min_delta: float = 1.0e-8
|
|
1065
|
-
"""
|
|
1066
|
-
Minimum change in the monitored quantity to qualify as an improvement.
|
|
1067
|
-
"""
|
|
1068
|
-
|
|
1069
|
-
min_lr: float | None = None
|
|
1070
|
-
"""
|
|
1071
|
-
Minimum learning rate. If the learning rate of the model is less than this value,
|
|
1072
|
-
the training will be stopped.
|
|
1073
|
-
"""
|
|
1074
|
-
|
|
1075
|
-
strict: bool = True
|
|
1076
|
-
"""
|
|
1077
|
-
Whether to enforce that the monitored quantity must improve by at least `min_delta`
|
|
1078
|
-
to qualify as an improvement.
|
|
1079
|
-
"""
|
|
1080
|
-
|
|
1081
|
-
@override
|
|
1082
|
-
def create_callbacks(self, root_config: "BaseConfig"):
|
|
1083
|
-
from ..callbacks.early_stopping import EarlyStopping
|
|
1084
|
-
|
|
1085
|
-
monitor = self.monitor
|
|
1086
|
-
mode = self.mode
|
|
1087
|
-
if monitor is None:
|
|
1088
|
-
assert mode is None, "If `monitor` is not provided, `mode` must be None."
|
|
1089
|
-
|
|
1090
|
-
primary_metric = root_config.primary_metric
|
|
1091
|
-
if primary_metric is None:
|
|
1092
|
-
raise ValueError(
|
|
1093
|
-
"No primary metric is set, so `monitor` must be provided in `early_stopping`."
|
|
1094
|
-
)
|
|
1095
|
-
monitor = primary_metric.validation_monitor
|
|
1096
|
-
mode = primary_metric.mode
|
|
1097
|
-
|
|
1098
|
-
if mode is None:
|
|
1099
|
-
mode = "min"
|
|
1100
|
-
|
|
1101
|
-
return [
|
|
1102
|
-
EarlyStopping(
|
|
1103
|
-
monitor=monitor,
|
|
1104
|
-
mode=mode,
|
|
1105
|
-
patience=self.patience,
|
|
1106
|
-
min_delta=self.min_delta,
|
|
1107
|
-
min_lr=self.min_lr,
|
|
1108
|
-
strict=self.strict,
|
|
1109
|
-
)
|
|
1110
|
-
]
|
|
1111
|
-
|
|
1112
|
-
|
|
1113
1047
|
class SanityCheckingConfig(C.Config):
|
|
1114
1048
|
reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
|
|
1115
1049
|
"""
|
|
@@ -6,7 +6,7 @@ nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ
|
|
|
6
6
|
nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
|
|
7
7
|
nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
|
|
8
8
|
nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
|
|
9
|
-
nshtrainer/callbacks/__init__.py,sha256=
|
|
9
|
+
nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
|
|
10
10
|
nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
|
|
11
11
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
12
12
|
nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
|
|
@@ -15,7 +15,7 @@ nshtrainer/callbacks/checkpoint/_base.py,sha256=yBUZ63P4a3m5jviIUGxBd2WXq_TigpKt
|
|
|
15
15
|
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
|
|
16
16
|
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
|
|
17
17
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
|
|
18
|
-
nshtrainer/callbacks/early_stopping.py,sha256=
|
|
18
|
+
nshtrainer/callbacks/early_stopping.py,sha256=m-YIWwQAmp1KUNYwJEiuQyNRyX-0pmzBbAQrkH1BLYI,5664
|
|
19
19
|
nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
|
|
20
20
|
nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
|
|
21
21
|
nshtrainer/callbacks/gradient_skipping.py,sha256=pqu5AELx4ctJxR2Y7YSSiGd5oGauVCTZFCEIIS6s88w,3665
|
|
@@ -54,7 +54,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
|
|
|
54
54
|
nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
|
|
55
55
|
nshtrainer/model/__init__.py,sha256=RlGW5a46DZcqK6cYICYxDaKpZIEj-8zLxoMrl432tno,1429
|
|
56
56
|
nshtrainer/model/base.py,sha256=AXRfEsFAT0Ln7zjYVPU5NgtHS_c8FZM-M4pyLamO7OA,17516
|
|
57
|
-
nshtrainer/model/config.py,sha256=
|
|
57
|
+
nshtrainer/model/config.py,sha256=gvZRv8Gow6ViF0t8OIoR_IzxLn8GzVmbpvoVzpdykFc,50871
|
|
58
58
|
nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
|
|
59
59
|
nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
|
|
60
60
|
nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
|
|
@@ -82,6 +82,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
82
82
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
83
83
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
84
84
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
85
|
-
nshtrainer-0.12.
|
|
86
|
-
nshtrainer-0.12.
|
|
87
|
-
nshtrainer-0.12.
|
|
85
|
+
nshtrainer-0.12.1.dist-info/METADATA,sha256=5pwyF0q1y7DxG3HNgRPJ3vyVG3XuctB9KyR2XlrW9m0,860
|
|
86
|
+
nshtrainer-0.12.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
87
|
+
nshtrainer-0.12.1.dist-info/RECORD,,
|
|
File without changes
|