nshtrainer 0.29.0__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in their respective public registries.
- nshtrainer/__init__.py +1 -0
- nshtrainer/callbacks/checkpoint/_base.py +3 -3
- nshtrainer/config/__init__.py +3 -0
- nshtrainer/config/duration.py +31 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +18 -9
- {nshtrainer-0.29.0.dist-info → nshtrainer-0.30.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.29.0.dist-info → nshtrainer-0.30.0.dist-info}/RECORD +8 -6
- {nshtrainer-0.29.0.dist-info → nshtrainer-0.30.0.dist-info}/WHEEL +0 -0
nshtrainer/__init__.py
CHANGED
|
@@ -155,15 +155,15 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
|
155
155
|
trainer.save_checkpoint(filepath, self.config.save_weights_only)
|
|
156
156
|
|
|
157
157
|
if trainer.is_global_zero:
|
|
158
|
+
# Remove old checkpoints
|
|
159
|
+
self.remove_old_checkpoints(trainer)
|
|
160
|
+
|
|
158
161
|
# Create the latest symlink
|
|
159
162
|
if (symlink_filename := self.symlink_path()) is not None:
|
|
160
163
|
symlink_path = self.dirpath / symlink_filename
|
|
161
164
|
_link_checkpoint(filepath, symlink_path, metadata=True)
|
|
162
165
|
log.debug(f"Created latest symlink: {symlink_path}")
|
|
163
166
|
|
|
164
|
-
# Remove old checkpoints
|
|
165
|
-
self.remove_old_checkpoints(trainer)
|
|
166
|
-
|
|
167
167
|
# Barrier to ensure all processes have saved the checkpoint,
|
|
168
168
|
# deleted the old checkpoints, and created the symlink before continuing
|
|
169
169
|
trainer.strategy.barrier()
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import Annotated, Literal
|
|
3
|
+
|
|
4
|
+
import nshconfig as C
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Steps(C.Config):
|
|
8
|
+
kind: Literal["steps"] = "steps"
|
|
9
|
+
|
|
10
|
+
value: Annotated[int, C.Field(ge=0)]
|
|
11
|
+
"""Number of steps."""
|
|
12
|
+
|
|
13
|
+
def to_steps(self, steps_per_epoch: int):
|
|
14
|
+
return self
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Epochs(C.Config):
|
|
18
|
+
kind: Literal["epochs"] = "epochs"
|
|
19
|
+
|
|
20
|
+
value: Annotated[int | float, C.Field(ge=0)]
|
|
21
|
+
"""Number of epochs."""
|
|
22
|
+
|
|
23
|
+
def to_steps(self, steps_per_epoch: int):
|
|
24
|
+
value = self.value * steps_per_epoch
|
|
25
|
+
if not isinstance(value, int):
|
|
26
|
+
value = int(math.ceil(value))
|
|
27
|
+
|
|
28
|
+
return Steps(value=value)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
Duration = Annotated[Steps | Epochs, C.Field(discriminator="kind")]
|
|
@@ -2,11 +2,11 @@ import math
|
|
|
2
2
|
import warnings
|
|
3
3
|
from typing import Literal
|
|
4
4
|
|
|
5
|
-
import nshconfig as C
|
|
6
5
|
from torch.optim import Optimizer
|
|
7
6
|
from torch.optim.lr_scheduler import LRScheduler
|
|
8
7
|
from typing_extensions import override
|
|
9
8
|
|
|
9
|
+
from ..config import Duration
|
|
10
10
|
from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
|
|
11
11
|
|
|
12
12
|
|
|
@@ -91,13 +91,13 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
|
|
|
91
91
|
class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
|
|
92
92
|
name: Literal["linear_warmup_cosine_decay"] = "linear_warmup_cosine_decay"
|
|
93
93
|
|
|
94
|
-
|
|
95
|
-
r"""The
|
|
96
|
-
The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this
|
|
94
|
+
warmup_duration: Duration
|
|
95
|
+
r"""The duration for the linear warmup phase.
|
|
96
|
+
The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this duration."""
|
|
97
97
|
|
|
98
|
-
|
|
99
|
-
r"""The total
|
|
100
|
-
The learning rate is decayed to `min_lr` over this
|
|
98
|
+
max_duration: Duration
|
|
99
|
+
r"""The total duration.
|
|
100
|
+
The learning rate is decayed to `min_lr` over this duration."""
|
|
101
101
|
|
|
102
102
|
warmup_start_lr_factor: float = 0.0
|
|
103
103
|
r"""The initial learning rate for the linear warmup phase, as a factor of the initial learning rate.
|
|
@@ -121,11 +121,20 @@ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
|
|
|
121
121
|
@override
|
|
122
122
|
def create_scheduler_impl(self, optimizer, lightning_module, lr):
|
|
123
123
|
num_steps_per_epoch = self.compute_num_steps_per_epoch(lightning_module)
|
|
124
|
-
warmup_steps =
|
|
125
|
-
|
|
124
|
+
warmup_steps = (
|
|
125
|
+
self.warmup_duration.to_steps(num_steps_per_epoch).value
|
|
126
|
+
* num_steps_per_epoch
|
|
127
|
+
)
|
|
128
|
+
max_steps = (
|
|
129
|
+
self.max_duration.to_steps(num_steps_per_epoch).value * num_steps_per_epoch
|
|
130
|
+
)
|
|
126
131
|
warmup_start_lr = self.warmup_start_lr_factor * lr
|
|
127
132
|
min_lr = self.min_lr_factor * lr
|
|
128
133
|
|
|
134
|
+
# Warmup and max steps should be at least 1.
|
|
135
|
+
warmup_steps = max(warmup_steps, 1)
|
|
136
|
+
max_steps = max(max_steps, 1)
|
|
137
|
+
|
|
129
138
|
# Create the scheduler
|
|
130
139
|
scheduler = LinearWarmupCosineAnnealingLR(
|
|
131
140
|
optimizer=optimizer,
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
nshtrainer/__init__.py,sha256=
|
|
1
|
+
nshtrainer/__init__.py,sha256=sUb2yNdkHHhrKWCeWA5QKIA1Xx3jkO1QGD5Pa-HvgbA,614
|
|
2
2
|
nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
|
|
3
3
|
nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
|
|
4
4
|
nshtrainer/_checkpoint/metadata.py,sha256=5D4PgKodzhLsmQvuF3xxkH49epKaegxi4wh_ImDTtns,4737
|
|
@@ -10,7 +10,7 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
|
|
|
10
10
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
11
11
|
nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
|
|
12
12
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
|
|
13
|
-
nshtrainer/callbacks/checkpoint/_base.py,sha256=
|
|
13
|
+
nshtrainer/callbacks/checkpoint/_base.py,sha256=vvlwuD-20NozYVIolGGShmUdkkNYeuwN6xCoFnK4GiU,6157
|
|
14
14
|
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
|
|
15
15
|
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
|
|
16
16
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
|
|
@@ -25,6 +25,8 @@ nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_
|
|
|
25
25
|
nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
|
|
26
26
|
nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
|
|
27
27
|
nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
|
|
28
|
+
nshtrainer/config/__init__.py,sha256=v9RtlM1Pqj_4fCDfskgxEtiGtbWH3Tj7lqNsKCDQ4gk,119
|
|
29
|
+
nshtrainer/config/duration.py,sha256=f_obz0eorkktI3HzAuIawABDkvuL4lDqCxcPb3UW7Q4,692
|
|
28
30
|
nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
|
|
29
31
|
nshtrainer/data/balanced_batch_sampler.py,sha256=dGBTDDtlBU6c-ZlVQOCnTW7SjTB5hczWsOWEdUWjvkA,4385
|
|
30
32
|
nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
|
|
@@ -52,7 +54,7 @@ nshtrainer/loggers/tensorboard.py,sha256=wL2amRSdP68zbslZvBeM0ZQBnjF3hIKsz-_lBbd
|
|
|
52
54
|
nshtrainer/loggers/wandb.py,sha256=FPwbf618AYmuPzHdhd1ZFhJ8qDjwTUiSe7cm7g3KCyM,5112
|
|
53
55
|
nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzqutESHE,736
|
|
54
56
|
nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
|
|
55
|
-
nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=
|
|
57
|
+
nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=pmX5n7mmhSqPTz4Nu9g_JTsE9gzCkuU4V3GuAHUsDoA,5451
|
|
56
58
|
nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
|
|
57
59
|
nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
|
|
58
60
|
nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
|
|
@@ -87,6 +89,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
87
89
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
88
90
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
89
91
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
90
|
-
nshtrainer-0.
|
|
91
|
-
nshtrainer-0.
|
|
92
|
-
nshtrainer-0.
|
|
92
|
+
nshtrainer-0.30.0.dist-info/METADATA,sha256=lDudS-lD7exw8lNe_3vT13ysnk491QCkObXGLQtjhMk,916
|
|
93
|
+
nshtrainer-0.30.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
94
|
+
nshtrainer-0.30.0.dist-info/RECORD,,
|
|
File without changes
|