nshtrainer 0.11.13__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,6 +13,7 @@ from .checkpoint import (
13
13
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
14
14
  )
15
15
  from .early_stopping import EarlyStopping as EarlyStopping
16
+ from .early_stopping import EarlyStoppingConfig as EarlyStoppingConfig
16
17
  from .ema import EMA as EMA
17
18
  from .ema import EMAConfig as EMAConfig
18
19
  from .finite_checks import FiniteChecksCallback as FiniteChecksCallback
@@ -34,7 +35,8 @@ from .wandb_watch import WandbWatchCallback as WandbWatchCallback
34
35
  from .wandb_watch import WandbWatchConfig as WandbWatchConfig
35
36
 
36
37
  CallbackConfig = Annotated[
37
- ThroughputMonitorConfig
38
+ EarlyStoppingConfig
39
+ | ThroughputMonitorConfig
38
40
  | EpochTimerConfig
39
41
  | PrintTableMetricsConfig
40
42
  | FiniteChecksConfig
@@ -55,7 +55,7 @@ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
55
55
 
56
56
  @override
57
57
  def default_filename(self):
58
- return f"epoch{{epoch:03d}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"
58
+ return f"epoch{{epoch}}-step{{step}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"
59
59
 
60
60
  @override
61
61
  def topk_sort_key(self, metadata: CheckpointMetadata):
@@ -28,7 +28,7 @@ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
28
28
 
29
29
  @override
30
30
  def default_filename(self):
31
- return "epoch{epoch:03d}-step{step:07d}"
31
+ return "epoch{epoch}-step{step}"
32
32
 
33
33
  @override
34
34
  def topk_sort_key(self, metadata: CheckpointMetadata):
@@ -1,5 +1,6 @@
1
1
  import logging
2
2
  import math
3
+ from typing import Literal
3
4
 
4
5
  from lightning.fabric.utilities.rank_zero import _get_rank
5
6
  from lightning.pytorch import Trainer
@@ -7,9 +8,76 @@ from lightning.pytorch.callbacks import EarlyStopping as _EarlyStopping
7
8
  from lightning.pytorch.utilities.rank_zero import rank_prefixed_message
8
9
  from typing_extensions import override
9
10
 
11
+ from .base import CallbackConfigBase
12
+
10
13
  log = logging.getLogger(__name__)
11
14
 
12
15
 
16
+ class EarlyStoppingConfig(CallbackConfigBase):
17
+ name: Literal["early_stopping"] = "early_stopping"
18
+
19
+ monitor: str | None = None
20
+ """
21
+ The metric to monitor for early stopping.
22
+ If None, the primary metric will be used.
23
+ """
24
+
25
+ mode: Literal["min", "max"] | None = None
26
+ """
27
+ The mode for the metric to monitor for early stopping.
28
+ If None, the primary metric mode will be used.
29
+ """
30
+
31
+ patience: int
32
+ """
33
+ Number of epochs with no improvement after which training will be stopped.
34
+ """
35
+
36
+ min_delta: float = 1.0e-8
37
+ """
38
+ Minimum change in the monitored quantity to qualify as an improvement.
39
+ """
40
+
41
+ min_lr: float | None = None
42
+ """
43
+ Minimum learning rate. If the learning rate of the model is less than this value,
44
+ the training will be stopped.
45
+ """
46
+
47
+ strict: bool = True
48
+ """
49
+ Whether to enforce that the monitored quantity must improve by at least `min_delta`
50
+ to qualify as an improvement.
51
+ """
52
+
53
+ @override
54
+ def create_callbacks(self, root_config):
55
+ monitor = self.monitor
56
+ mode = self.mode
57
+ if monitor is None:
58
+ assert mode is None, "If `monitor` is not provided, `mode` must be None."
59
+
60
+ primary_metric = root_config.primary_metric
61
+ if primary_metric is None:
62
+ raise ValueError(
63
+ "No primary metric is set, so `monitor` must be provided in `early_stopping`."
64
+ )
65
+ monitor = primary_metric.validation_monitor
66
+ mode = primary_metric.mode
67
+
68
+ if mode is None:
69
+ mode = "min"
70
+
71
+ yield EarlyStopping(
72
+ monitor=monitor,
73
+ mode=mode,
74
+ patience=self.patience,
75
+ min_delta=self.min_delta,
76
+ min_lr=self.min_lr,
77
+ strict=self.strict,
78
+ )
79
+
80
+
13
81
  class EarlyStopping(_EarlyStopping):
14
82
  def __init__(
15
83
  self,
@@ -38,6 +38,7 @@ from .._checkpoint.loader import CheckpointLoadingConfig
38
38
  from ..callbacks import (
39
39
  BestCheckpointCallbackConfig,
40
40
  CallbackConfig,
41
+ EarlyStoppingConfig,
41
42
  LastCheckpointCallbackConfig,
42
43
  OnExceptionCheckpointCallbackConfig,
43
44
  WandbWatchConfig,
@@ -1043,73 +1044,6 @@ class LightningTrainerKwargs(TypedDict, total=False):
1043
1044
  """
1044
1045
 
1045
1046
 
1046
- class EarlyStoppingConfig(CallbackConfigBase):
1047
- monitor: str | None = None
1048
- """
1049
- The metric to monitor for early stopping.
1050
- If None, the primary metric will be used.
1051
- """
1052
-
1053
- mode: Literal["min", "max"] | None = None
1054
- """
1055
- The mode for the metric to monitor for early stopping.
1056
- If None, the primary metric mode will be used.
1057
- """
1058
-
1059
- patience: int
1060
- """
1061
- Number of epochs with no improvement after which training will be stopped.
1062
- """
1063
-
1064
- min_delta: float = 1.0e-8
1065
- """
1066
- Minimum change in the monitored quantity to qualify as an improvement.
1067
- """
1068
-
1069
- min_lr: float | None = None
1070
- """
1071
- Minimum learning rate. If the learning rate of the model is less than this value,
1072
- the training will be stopped.
1073
- """
1074
-
1075
- strict: bool = True
1076
- """
1077
- Whether to enforce that the monitored quantity must improve by at least `min_delta`
1078
- to qualify as an improvement.
1079
- """
1080
-
1081
- @override
1082
- def create_callbacks(self, root_config: "BaseConfig"):
1083
- from ..callbacks.early_stopping import EarlyStopping
1084
-
1085
- monitor = self.monitor
1086
- mode = self.mode
1087
- if monitor is None:
1088
- assert mode is None, "If `monitor` is not provided, `mode` must be None."
1089
-
1090
- primary_metric = root_config.primary_metric
1091
- if primary_metric is None:
1092
- raise ValueError(
1093
- "No primary metric is set, so `monitor` must be provided in `early_stopping`."
1094
- )
1095
- monitor = primary_metric.validation_monitor
1096
- mode = primary_metric.mode
1097
-
1098
- if mode is None:
1099
- mode = "min"
1100
-
1101
- return [
1102
- EarlyStopping(
1103
- monitor=monitor,
1104
- mode=mode,
1105
- patience=self.patience,
1106
- min_delta=self.min_delta,
1107
- min_lr=self.min_lr,
1108
- strict=self.strict,
1109
- )
1110
- ]
1111
-
1112
-
1113
1047
  class SanityCheckingConfig(C.Config):
1114
1048
  reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
1115
1049
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.11.13
3
+ Version: 0.12.1
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -6,16 +6,16 @@ nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ
6
6
  nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
7
7
  nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
8
8
  nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
9
- nshtrainer/callbacks/__init__.py,sha256=k-DbpIlH2t5-oR3gHGHr8KiyCd_Twers4PcIUM1noqQ,2262
9
+ nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
10
10
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
11
11
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
12
12
  nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
13
13
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
14
14
  nshtrainer/callbacks/checkpoint/_base.py,sha256=yBUZ63P4a3m5jviIUGxBd2WXq_TigpKt3w4KdzqrzLs,6094
15
- nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=qajV_GxeUg0GXeOtiimmPabMJnkNu_I1prZb2ksPOG8,2156
16
- nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=ctwl2bmHC79enpg9wi-iHWQYIkP-iQIeyEvJUUJ5AW8,1105
15
+ nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
16
+ nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
17
17
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
18
- nshtrainer/callbacks/early_stopping.py,sha256=LGn3rdbvkFfUo9kwMzK4eMGlPAqD9uFdowDx6VdfozQ,3761
18
+ nshtrainer/callbacks/early_stopping.py,sha256=m-YIWwQAmp1KUNYwJEiuQyNRyX-0pmzBbAQrkH1BLYI,5664
19
19
  nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
20
20
  nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
21
21
  nshtrainer/callbacks/gradient_skipping.py,sha256=pqu5AELx4ctJxR2Y7YSSiGd5oGauVCTZFCEIIS6s88w,3665
@@ -54,7 +54,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
54
54
  nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
55
55
  nshtrainer/model/__init__.py,sha256=RlGW5a46DZcqK6cYICYxDaKpZIEj-8zLxoMrl432tno,1429
56
56
  nshtrainer/model/base.py,sha256=AXRfEsFAT0Ln7zjYVPU5NgtHS_c8FZM-M4pyLamO7OA,17516
57
- nshtrainer/model/config.py,sha256=3qSaXNjpKnxF60LbfpONsVhjtrteLEdHkjb5KAhAnIk,52757
57
+ nshtrainer/model/config.py,sha256=gvZRv8Gow6ViF0t8OIoR_IzxLn8GzVmbpvoVzpdykFc,50871
58
58
  nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
59
59
  nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
60
60
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -82,6 +82,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
82
82
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
83
83
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
84
84
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
85
- nshtrainer-0.11.13.dist-info/METADATA,sha256=wKFqCeZ6hxeHznkFksP3-kqF6vhG7ErudiM-auKKEJE,861
86
- nshtrainer-0.11.13.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
87
- nshtrainer-0.11.13.dist-info/RECORD,,
85
+ nshtrainer-0.12.1.dist-info/METADATA,sha256=5pwyF0q1y7DxG3HNgRPJ3vyVG3XuctB9KyR2XlrW9m0,860
86
+ nshtrainer-0.12.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
87
+ nshtrainer-0.12.1.dist-info/RECORD,,