nshtrainer 0.29.0__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from . import _experimental as _experimental
2
2
  from . import callbacks as callbacks
3
+ from . import config as config
3
4
  from . import data as data
4
5
  from . import lr_scheduler as lr_scheduler
5
6
  from . import metrics as metrics
@@ -155,15 +155,15 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
155
155
  trainer.save_checkpoint(filepath, self.config.save_weights_only)
156
156
 
157
157
  if trainer.is_global_zero:
158
+ # Remove old checkpoints
159
+ self.remove_old_checkpoints(trainer)
160
+
158
161
  # Create the latest symlink
159
162
  if (symlink_filename := self.symlink_path()) is not None:
160
163
  symlink_path = self.dirpath / symlink_filename
161
164
  _link_checkpoint(filepath, symlink_path, metadata=True)
162
165
  log.debug(f"Created latest symlink: {symlink_path}")
163
166
 
164
- # Remove old checkpoints
165
- self.remove_old_checkpoints(trainer)
166
-
167
167
  # Barrier to ensure all processes have saved the checkpoint,
168
168
  # deleted the old checkpoints, and created the symlink before continuing
169
169
  trainer.strategy.barrier()
@@ -0,0 +1,3 @@
1
+ from .duration import Duration as Duration
2
+ from .duration import Epochs as Epochs
3
+ from .duration import Steps as Steps
@@ -0,0 +1,31 @@
1
+ import math
2
+ from typing import Annotated, Literal
3
+
4
+ import nshconfig as C
5
+
6
+
7
class Steps(C.Config):
    kind: Literal["steps"] = "steps"

    value: Annotated[int, C.Field(ge=0)]
    """Number of steps."""

    def to_steps(self, steps_per_epoch: int) -> "Steps":
        """Return this duration expressed in steps.

        The duration is already step-based, so no conversion is needed and
        the instance itself is returned unchanged.

        Args:
            steps_per_epoch: Ignored here; accepted so that ``Steps`` and
                ``Epochs`` expose the same ``to_steps`` signature.

        Returns:
            ``self``.
        """
        # Intentionally unused: a step count does not depend on epoch length.
        _ = steps_per_epoch
        return self
15
+
16
+
17
class Epochs(C.Config):
    kind: Literal["epochs"] = "epochs"

    value: Annotated[int | float, C.Field(ge=0)]
    """Number of epochs."""

    def to_steps(self, steps_per_epoch: int):
        """Convert this epoch-based duration into an equivalent ``Steps``.

        Args:
            steps_per_epoch: Number of steps that make up one epoch.

        Returns:
            A ``Steps`` whose value is ``self.value * steps_per_epoch``. A
            fractional product is rounded up to the next integer, so any
            non-zero fraction of an epoch yields at least one step.
        """
        value = self.value * steps_per_epoch
        if not isinstance(value, int):
            # The product of a float and an int can land a hair above an
            # exact integer (e.g. 0.07 * 100 == 7.000000000000001). A bare
            # ceil would then overshoot by a whole step, so snap values that
            # are within float noise of an integer first. Only a relative
            # tolerance is used: a tiny but genuinely positive product must
            # not snap down to 0 and is still rounded up to 1.
            nearest = round(value)
            if math.isclose(value, nearest, rel_tol=1e-12):
                value = int(nearest)
            else:
                value = int(math.ceil(value))

        return Steps(value=value)
29
+
30
+
31
# A training duration given either as a step count or as an epoch count.
# Validation selects the concrete model from the value of its `kind` field.
Duration = Annotated[
    Steps | Epochs,
    C.Field(discriminator="kind"),
]
@@ -2,11 +2,11 @@ import math
2
2
  import warnings
3
3
  from typing import Literal
4
4
 
5
- import nshconfig as C
6
5
  from torch.optim import Optimizer
7
6
  from torch.optim.lr_scheduler import LRScheduler
8
7
  from typing_extensions import override
9
8
 
9
+ from ..config import Duration
10
10
  from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
11
11
 
12
12
 
@@ -91,13 +91,13 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
91
91
  class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
92
92
  name: Literal["linear_warmup_cosine_decay"] = "linear_warmup_cosine_decay"
93
93
 
94
- warmup_epochs: int = C.Field(ge=0)
95
- r"""The number of epochs for the linear warmup phase.
96
- The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this number of epochs."""
94
+ warmup_duration: Duration
95
+ r"""The duration for the linear warmup phase.
96
+ The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this duration."""
97
97
 
98
- max_epochs: int = C.Field(gt=0)
99
- r"""The total number of epochs.
100
- The learning rate is decayed to `min_lr` over this number of epochs."""
98
+ max_duration: Duration
99
+ r"""The total duration.
100
+ The learning rate is decayed to `min_lr` over this duration."""
101
101
 
102
102
  warmup_start_lr_factor: float = 0.0
103
103
  r"""The initial learning rate for the linear warmup phase, as a factor of the initial learning rate.
@@ -121,11 +121,20 @@ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
121
121
  @override
122
122
  def create_scheduler_impl(self, optimizer, lightning_module, lr):
123
123
  num_steps_per_epoch = self.compute_num_steps_per_epoch(lightning_module)
124
- warmup_steps = self.warmup_epochs * num_steps_per_epoch
125
- max_steps = self.max_epochs * num_steps_per_epoch
124
+ warmup_steps = (
125
+ self.warmup_duration.to_steps(num_steps_per_epoch).value
126
+ * num_steps_per_epoch
127
+ )
128
+ max_steps = (
129
+ self.max_duration.to_steps(num_steps_per_epoch).value * num_steps_per_epoch
130
+ )
126
131
  warmup_start_lr = self.warmup_start_lr_factor * lr
127
132
  min_lr = self.min_lr_factor * lr
128
133
 
134
+ # Warmup and max steps should be at least 1.
135
+ warmup_steps = max(warmup_steps, 1)
136
+ max_steps = max(max_steps, 1)
137
+
129
138
  # Create the scheduler
130
139
  scheduler = LinearWarmupCosineAnnealingLR(
131
140
  optimizer=optimizer,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.29.0
3
+ Version: 0.30.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,4 +1,4 @@
1
- nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
1
+ nshtrainer/__init__.py,sha256=sUb2yNdkHHhrKWCeWA5QKIA1Xx3jkO1QGD5Pa-HvgbA,614
2
2
  nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
3
3
  nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=5D4PgKodzhLsmQvuF3xxkH49epKaegxi4wh_ImDTtns,4737
@@ -10,7 +10,7 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
10
10
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
11
11
  nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
12
12
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
13
- nshtrainer/callbacks/checkpoint/_base.py,sha256=MzMF7JtvR3A_7DAM2r4NGQSBDisA7krv6WlVk5rKABQ,6157
13
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=vvlwuD-20NozYVIolGGShmUdkkNYeuwN6xCoFnK4GiU,6157
14
14
  nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
15
15
  nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
16
16
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
@@ -25,6 +25,8 @@ nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_
25
25
  nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
26
26
  nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
27
27
  nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
28
+ nshtrainer/config/__init__.py,sha256=v9RtlM1Pqj_4fCDfskgxEtiGtbWH3Tj7lqNsKCDQ4gk,119
29
+ nshtrainer/config/duration.py,sha256=f_obz0eorkktI3HzAuIawABDkvuL4lDqCxcPb3UW7Q4,692
28
30
  nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
29
31
  nshtrainer/data/balanced_batch_sampler.py,sha256=dGBTDDtlBU6c-ZlVQOCnTW7SjTB5hczWsOWEdUWjvkA,4385
30
32
  nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
@@ -52,7 +54,7 @@ nshtrainer/loggers/tensorboard.py,sha256=wL2amRSdP68zbslZvBeM0ZQBnjF3hIKsz-_lBbd
52
54
  nshtrainer/loggers/wandb.py,sha256=FPwbf618AYmuPzHdhd1ZFhJ8qDjwTUiSe7cm7g3KCyM,5112
53
55
  nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzqutESHE,736
54
56
  nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
55
- nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=mn6cyizyI_stkXtg6zxIEGF9btIxMRWigUHUTlUYCSw,5221
57
+ nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=pmX5n7mmhSqPTz4Nu9g_JTsE9gzCkuU4V3GuAHUsDoA,5451
56
58
  nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
57
59
  nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
58
60
  nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
@@ -87,6 +89,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
87
89
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
88
90
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
89
91
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
90
- nshtrainer-0.29.0.dist-info/METADATA,sha256=EP3cdORGt4w_H0pX-whQJ5ULsO5HQXo3VlHp5bkfqfk,916
91
- nshtrainer-0.29.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
- nshtrainer-0.29.0.dist-info/RECORD,,
92
+ nshtrainer-0.30.0.dist-info/METADATA,sha256=lDudS-lD7exw8lNe_3vT13ysnk491QCkObXGLQtjhMk,916
93
+ nshtrainer-0.30.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
94
+ nshtrainer-0.30.0.dist-info/RECORD,,