nshtrainer 0.29.1__py3-none-any.whl → 0.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/__init__.py CHANGED
@@ -1,5 +1,6 @@
  from . import _experimental as _experimental
  from . import callbacks as callbacks
+ from . import config as config
  from . import data as data
  from . import lr_scheduler as lr_scheduler
  from . import metrics as metrics

nshtrainer/config/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .duration import Duration as Duration
+ from .duration import Epochs as Epochs
+ from .duration import Steps as Steps

nshtrainer/config/duration.py ADDED
@@ -0,0 +1,31 @@
+ import math
+ from typing import Annotated, Literal
+
+ import nshconfig as C
+
+
+ class Steps(C.Config):
+     kind: Literal["steps"] = "steps"
+
+     value: Annotated[int, C.Field(ge=0)]
+     """Number of steps."""
+
+     def to_steps(self, steps_per_epoch: int):
+         return self
+
+
+ class Epochs(C.Config):
+     kind: Literal["epochs"] = "epochs"
+
+     value: Annotated[int | float, C.Field(ge=0)]
+     """Number of epochs."""
+
+     def to_steps(self, steps_per_epoch: int):
+         value = self.value * steps_per_epoch
+         if not isinstance(value, int):
+             value = int(math.ceil(value))
+
+         return Steps(value=value)
+
+
+ Duration = Annotated[Steps | Epochs, C.Field(discriminator="kind")]
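
For orientation (not part of the diff itself), a minimal sketch of how these duration configs convert to steps, assuming nshconfig's C.Config supports keyword construction like a pydantic model; the numbers are illustrative:

from nshtrainer.config import Epochs, Steps

# An epoch-based duration is converted by multiplying by steps_per_epoch and
# rounding up any fractional result: 2.5 epochs * 101 steps/epoch -> 253 steps.
warmup = Epochs(value=2.5)
warmup_steps = warmup.to_steps(steps_per_epoch=101)  # Steps with value == 253

# A step-based duration is already expressed in steps, so to_steps returns itself.
total = Steps(value=1000)
total_steps = total.to_steps(steps_per_epoch=101)  # value == 1000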

nshtrainer/lr_scheduler/linear_warmup_cosine.py CHANGED
@@ -2,11 +2,11 @@ import math
  import warnings
  from typing import Literal

- import nshconfig as C
  from torch.optim import Optimizer
  from torch.optim.lr_scheduler import LRScheduler
  from typing_extensions import override

+ from ..config import Duration
  from ._base import LRSchedulerConfigBase, LRSchedulerMetadata

@@ -91,13 +91,13 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
  class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
      name: Literal["linear_warmup_cosine_decay"] = "linear_warmup_cosine_decay"

-     warmup_epochs: int = C.Field(ge=0)
-     r"""The number of epochs for the linear warmup phase.
-     The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this number of epochs."""
+     warmup_duration: Duration
+     r"""The duration for the linear warmup phase.
+     The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this duration."""

-     max_epochs: int = C.Field(gt=0)
-     r"""The total number of epochs.
-     The learning rate is decayed to `min_lr` over this number of epochs."""
+     max_duration: Duration
+     r"""The total duration.
+     The learning rate is decayed to `min_lr` over this duration."""

      warmup_start_lr_factor: float = 0.0
      r"""The initial learning rate for the linear warmup phase, as a factor of the initial learning rate.

@@ -121,11 +121,15 @@ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
      @override
      def create_scheduler_impl(self, optimizer, lightning_module, lr):
          num_steps_per_epoch = self.compute_num_steps_per_epoch(lightning_module)
-         warmup_steps = self.warmup_epochs * num_steps_per_epoch
-         max_steps = self.max_epochs * num_steps_per_epoch
+         warmup_steps = self.warmup_duration.to_steps(num_steps_per_epoch).value
+         max_steps = self.max_duration.to_steps(num_steps_per_epoch).value

          warmup_start_lr = self.warmup_start_lr_factor * lr
          min_lr = self.min_lr_factor * lr

+         # Warmup and max steps should be at least 1.
+         warmup_steps = max(warmup_steps, 1)
+         max_steps = max(max_steps, 1)
+
          # Create the scheduler
          scheduler = LinearWarmupCosineAnnealingLR(
              optimizer=optimizer,
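
To make the new step computation concrete, a small sketch of the arithmetic create_scheduler_impl now performs; the steps-per-epoch value is illustrative, and inside the trainer it comes from compute_num_steps_per_epoch:

from nshtrainer.config import Epochs

steps_per_epoch = 250  # illustrative value

warmup_steps = Epochs(value=0.5).to_steps(steps_per_epoch).value  # ceil(0.5 * 250) = 125
max_steps = Epochs(value=20).to_steps(steps_per_epoch).value      # 20 * 250 = 5000

# As in the hunk above, both counts are then clamped to at least 1, so even a
# zero-length duration still yields a valid schedule.
warmup_steps = max(warmup_steps, 1)
max_steps = max(max_steps, 1)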

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 0.29.1
+ Version: 0.30.1
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com

@@ -1,4 +1,4 @@
- nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
+ nshtrainer/__init__.py,sha256=sUb2yNdkHHhrKWCeWA5QKIA1Xx3jkO1QGD5Pa-HvgbA,614
  nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
  nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
  nshtrainer/_checkpoint/metadata.py,sha256=5D4PgKodzhLsmQvuF3xxkH49epKaegxi4wh_ImDTtns,4737
@@ -25,6 +25,8 @@ nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_
  nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
  nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
  nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
+ nshtrainer/config/__init__.py,sha256=v9RtlM1Pqj_4fCDfskgxEtiGtbWH3Tj7lqNsKCDQ4gk,119
+ nshtrainer/config/duration.py,sha256=f_obz0eorkktI3HzAuIawABDkvuL4lDqCxcPb3UW7Q4,692
  nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
  nshtrainer/data/balanced_batch_sampler.py,sha256=dGBTDDtlBU6c-ZlVQOCnTW7SjTB5hczWsOWEdUWjvkA,4385
  nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
@@ -52,7 +54,7 @@ nshtrainer/loggers/tensorboard.py,sha256=wL2amRSdP68zbslZvBeM0ZQBnjF3hIKsz-_lBbd
  nshtrainer/loggers/wandb.py,sha256=FPwbf618AYmuPzHdhd1ZFhJ8qDjwTUiSe7cm7g3KCyM,5112
  nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzqutESHE,736
  nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
- nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=mn6cyizyI_stkXtg6zxIEGF9btIxMRWigUHUTlUYCSw,5221
+ nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=Fyontbfu4k2932xZenE63QL4CrVGWANXdTeq63dUko0,5347
  nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
  nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
  nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
@@ -87,6 +89,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
- nshtrainer-0.29.1.dist-info/METADATA,sha256=Qck1QY1pNnjQH9zLMyAMKVVvYMovEeIyP5zV7VlZios,916
- nshtrainer-0.29.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- nshtrainer-0.29.1.dist-info/RECORD,,
+ nshtrainer-0.30.1.dist-info/METADATA,sha256=LV0wQlmotpfC3qO76dFVCbS26bEl-9YMiTetEeqVQsU,916
+ nshtrainer-0.30.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ nshtrainer-0.30.1.dist-info/RECORD,,