nshtrainer 1.0.0b36__py3-none-any.whl → 1.0.0b39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_directory.py +3 -1
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +13 -12
- nshtrainer/configs/__init__.py +12 -4
- nshtrainer/configs/loggers/__init__.py +6 -4
- nshtrainer/configs/loggers/actsave/__init__.py +4 -2
- nshtrainer/configs/loggers/base/__init__.py +11 -0
- nshtrainer/configs/loggers/csv/__init__.py +4 -2
- nshtrainer/configs/loggers/tensorboard/__init__.py +4 -2
- nshtrainer/configs/loggers/wandb/__init__.py +4 -2
- nshtrainer/configs/lr_scheduler/__init__.py +4 -2
- nshtrainer/configs/lr_scheduler/base/__init__.py +11 -0
- nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +4 -0
- nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +4 -0
- nshtrainer/configs/nn/__init__.py +4 -2
- nshtrainer/configs/nn/mlp/__init__.py +2 -2
- nshtrainer/configs/nn/nonlinearity/__init__.py +4 -2
- nshtrainer/configs/optimizer/__init__.py +2 -0
- nshtrainer/configs/trainer/__init__.py +2 -2
- nshtrainer/configs/trainer/_config/__init__.py +2 -2
- nshtrainer/loggers/__init__.py +3 -8
- nshtrainer/loggers/actsave.py +5 -2
- nshtrainer/loggers/{_base.py → base.py} +13 -1
- nshtrainer/loggers/csv.py +5 -3
- nshtrainer/loggers/tensorboard.py +5 -3
- nshtrainer/loggers/wandb.py +5 -3
- nshtrainer/lr_scheduler/__init__.py +2 -2
- nshtrainer/lr_scheduler/{_base.py → base.py} +3 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +56 -54
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +4 -2
- nshtrainer/nn/__init__.py +1 -1
- nshtrainer/nn/mlp.py +4 -4
- nshtrainer/nn/nonlinearity.py +37 -33
- nshtrainer/optimizer.py +8 -2
- nshtrainer/trainer/_config.py +2 -2
- {nshtrainer-1.0.0b36.dist-info → nshtrainer-1.0.0b39.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b36.dist-info → nshtrainer-1.0.0b39.dist-info}/RECORD +37 -37
- nshtrainer/configs/loggers/_base/__init__.py +0 -9
- nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -9
- {nshtrainer-1.0.0b36.dist-info → nshtrainer-1.0.0b39.dist-info}/WHEEL +0 -0
nshtrainer/lr_scheduler/__init__.py
CHANGED
@@ -5,8 +5,8 @@ from typing import Annotated
 import nshconfig as C
 from typing_extensions import TypeAliasType
 
-from ._base import LRSchedulerConfigBase as LRSchedulerConfigBase
-from ._base import LRSchedulerMetadata as LRSchedulerMetadata
+from .base import LRSchedulerConfigBase as LRSchedulerConfigBase
+from .base import LRSchedulerMetadata as LRSchedulerMetadata
 from .linear_warmup_cosine import (
     LinearWarmupCosineAnnealingLR as LinearWarmupCosineAnnealingLR,
 )
nshtrainer/lr_scheduler/{_base.py → base.py}
CHANGED
@@ -94,3 +94,6 @@ class LRSchedulerConfigBase(C.Config, ABC):
         # ^ This is a hack to trigger the computation of the estimated stepping batches
         # and make sure that the `trainer.num_training_batches` attribute is set.
         return math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches)
+
+
+lr_scheduler_registry = C.Registry(LRSchedulerConfigBase, discriminator="name")
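The new `lr_scheduler_registry` replaces hard-coded config unions with name-discriminated registration. A minimal sketch of how a downstream config could hook into it, assuming `C.Registry.register` works as a plain class decorator the way the built-in configs below use it; the `ExponentialLRConfig` name and its `gamma` field are hypothetical and not part of nshtrainer:

```python
# Hypothetical third-party scheduler config; names and fields are illustrative only.
from typing import Literal

from torch.optim.lr_scheduler import ExponentialLR
from typing_extensions import final, override

from nshtrainer.lr_scheduler.base import (
    LRSchedulerConfigBase,
    LRSchedulerMetadata,
    lr_scheduler_registry,
)


@final
@lr_scheduler_registry.register
class ExponentialLRConfig(LRSchedulerConfigBase):
    name: Literal["exponential"] = "exponential"

    gamma: float = 0.95
    """Multiplicative LR decay factor applied every step."""

    @override
    def metadata(self) -> LRSchedulerMetadata:
        # Step the scheduler once per optimizer step, like the built-in configs.
        return {"interval": "step"}

    @override
    def create_scheduler_impl(self, optimizer, lightning_module):
        return ExponentialLR(optimizer, gamma=self.gamma)
```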
nshtrainer/lr_scheduler/linear_warmup_cosine.py
CHANGED
@@ -6,10 +6,64 @@ from typing import Literal
 
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
-from typing_extensions import override
+from typing_extensions import final, override
 
 from ..util.config import DurationConfig
-from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
+from .base import LRSchedulerConfigBase, LRSchedulerMetadata, lr_scheduler_registry
+
+
+@final
+@lr_scheduler_registry.register
+class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
+    name: Literal["linear_warmup_cosine_decay"] = "linear_warmup_cosine_decay"
+
+    warmup_duration: DurationConfig
+    r"""The duration for the linear warmup phase.
+    The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this duration."""
+
+    max_duration: DurationConfig
+    r"""The total duration.
+    The learning rate is decayed to `min_lr` over this duration."""
+
+    warmup_start_lr_factor: float = 0.0
+    r"""The initial learning rate for the linear warmup phase, as a factor of the initial learning rate.
+    The learning rate is linearly increased from this value to the initial learning rate over `warmup_epochs` epochs."""
+
+    min_lr_factor: float = 0.0
+    r"""The minimum learning rate, as a factor of the initial learning rate.
+    The learning rate is decayed to this value over `max_epochs` epochs."""
+
+    annealing: bool = False
+    r"""Whether to restart the learning rate schedule after `max_epochs` epochs.
+    If `False`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be set to `min_lr` for all subsequent epochs.
+    If `True`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be increased back to the initial learning rate over `max_epochs` epochs, and so on (this is called a cosine annealing schedule)."""
+
+    @override
+    def metadata(self) -> LRSchedulerMetadata:
+        return {
+            "interval": "step",
+        }
+
+    @override
+    def create_scheduler_impl(self, optimizer, lightning_module):
+        num_steps_per_epoch = self.compute_num_steps_per_epoch(lightning_module)
+        warmup_steps = self.warmup_duration.to_steps(num_steps_per_epoch).value
+        max_steps = self.max_duration.to_steps(num_steps_per_epoch).value
+
+        # Warmup and max steps should be at least 1.
+        warmup_steps = max(warmup_steps, 1)
+        max_steps = max(max_steps, 1)
+
+        # Create the scheduler
+        scheduler = LinearWarmupCosineAnnealingLR(
+            optimizer=optimizer,
+            warmup_epochs=warmup_steps,
+            max_epochs=max_steps,
+            warmup_start_lr_factor=self.warmup_start_lr_factor,
+            eta_min_factor=self.min_lr_factor,
+            should_restart=self.annealing,
+        )
+        return scheduler
 
 
 class LinearWarmupCosineAnnealingLR(LRScheduler):
@@ -89,55 +143,3 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
             + self.eta_min_factor * base_lr
             for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
         ]
-
-
-class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
-    name: Literal["linear_warmup_cosine_decay"] = "linear_warmup_cosine_decay"
-
-    warmup_duration: DurationConfig
-    r"""The duration for the linear warmup phase.
-    The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this duration."""
-
-    max_duration: DurationConfig
-    r"""The total duration.
-    The learning rate is decayed to `min_lr` over this duration."""
-
-    warmup_start_lr_factor: float = 0.0
-    r"""The initial learning rate for the linear warmup phase, as a factor of the initial learning rate.
-    The learning rate is linearly increased from this value to the initial learning rate over `warmup_epochs` epochs."""
-
-    min_lr_factor: float = 0.0
-    r"""The minimum learning rate, as a factor of the initial learning rate.
-    The learning rate is decayed to this value over `max_epochs` epochs."""
-
-    annealing: bool = False
-    r"""Whether to restart the learning rate schedule after `max_epochs` epochs.
-    If `False`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be set to `min_lr` for all subsequent epochs.
-    If `True`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be increased back to the initial learning rate over `max_epochs` epochs, and so on (this is called a cosine annealing schedule)."""
-
-    @override
-    def metadata(self) -> LRSchedulerMetadata:
-        return {
-            "interval": "step",
-        }
-
-    @override
-    def create_scheduler_impl(self, optimizer, lightning_module):
-        num_steps_per_epoch = self.compute_num_steps_per_epoch(lightning_module)
-        warmup_steps = self.warmup_duration.to_steps(num_steps_per_epoch).value
-        max_steps = self.max_duration.to_steps(num_steps_per_epoch).value
-
-        # Warmup and max steps should be at least 1.
-        warmup_steps = max(warmup_steps, 1)
-        max_steps = max(max_steps, 1)
-
-        # Create the scheduler
-        scheduler = LinearWarmupCosineAnnealingLR(
-            optimizer=optimizer,
-            warmup_epochs=warmup_steps,
-            max_epochs=max_steps,
-            warmup_start_lr_factor=self.warmup_start_lr_factor,
-            eta_min_factor=self.min_lr_factor,
-            should_restart=self.annealing,
-        )
-        return scheduler
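The relocated `LinearWarmupCosineDecayLRSchedulerConfig` describes a linear warmup followed by a cosine decay. The standalone sketch below shows the learning-rate multiplier such a schedule produces at each step; it is a schematic of the behaviour the config's fields describe, not nshtrainer's exact implementation:

```python
import math


def lr_factor(step: int, warmup_steps: int, max_steps: int,
              warmup_start_lr_factor: float = 0.0,
              min_lr_factor: float = 0.0,
              annealing: bool = False) -> float:
    """Multiplier applied to the base learning rate at a given step."""
    if step < warmup_steps:
        # Linear warmup from `warmup_start_lr_factor` up to 1.0.
        return warmup_start_lr_factor + (1.0 - warmup_start_lr_factor) * step / warmup_steps
    decay_steps = max(max_steps - warmup_steps, 1)
    progress = (step - warmup_steps) / decay_steps
    if not annealing:
        # Decay once, then hold at `min_lr_factor`.
        progress = min(progress, 1.0)
    # Cosine curve from 1.0 down to `min_lr_factor` (and back up when annealing).
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr_factor + (1.0 - min_lr_factor) * cosine


# e.g. warmup over 10 steps, decay over 100 steps total:
# [lr_factor(s, 10, 100) for s in (0, 5, 10, 55, 100)] -> 0.0, 0.5, 1.0, ~0.5, 0.0
```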
nshtrainer/lr_scheduler/reduce_lr_on_plateau.py
CHANGED
@@ -4,12 +4,14 @@ from typing import Literal
 
 from lightning.pytorch.utilities.types import LRSchedulerConfigType
 from torch.optim.lr_scheduler import ReduceLROnPlateau
-from typing_extensions import override
+from typing_extensions import final, override
 
 from ..metrics._config import MetricConfig
-from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
+from .base import LRSchedulerConfigBase, LRSchedulerMetadata, lr_scheduler_registry
 
 
+@final
+@lr_scheduler_registry.register
 class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
     """Reduce learning rate when a metric has stopped improving."""
 
nshtrainer/nn/__init__.py
CHANGED
@@ -6,12 +6,12 @@ from .mlp import MLPConfigDict as MLPConfigDict
 from .mlp import ResidualSequential as ResidualSequential
 from .module_dict import TypedModuleDict as TypedModuleDict
 from .module_list import TypedModuleList as TypedModuleList
-from .nonlinearity import BaseNonlinearityConfig as BaseNonlinearityConfig
 from .nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
 from .nonlinearity import GELUNonlinearityConfig as GELUNonlinearityConfig
 from .nonlinearity import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
 from .nonlinearity import MishNonlinearityConfig as MishNonlinearityConfig
 from .nonlinearity import NonlinearityConfig as NonlinearityConfig
+from .nonlinearity import NonlinearityConfigBase as NonlinearityConfigBase
 from .nonlinearity import PReLUConfig as PReLUConfig
 from .nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
 from .nonlinearity import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
nshtrainer/nn/mlp.py
CHANGED
@@ -9,7 +9,7 @@ import torch
 import torch.nn as nn
 from typing_extensions import TypedDict, override
 
-from .nonlinearity import BaseNonlinearityConfig, NonlinearityConfig
+from .nonlinearity import NonlinearityConfig, NonlinearityConfigBase
 
 
 @runtime_checkable
@@ -92,11 +92,11 @@ class MLPConfig(C.Config):
 
 def MLP(
     dims: Sequence[int],
-    activation: BaseNonlinearityConfig
+    activation: NonlinearityConfigBase
     | nn.Module
     | Callable[[], nn.Module]
    | None = None,
-    nonlinearity: BaseNonlinearityConfig
+    nonlinearity: NonlinearityConfigBase
     | nn.Module
     | Callable[[], nn.Module]
    | None = None,
@@ -153,7 +153,7 @@ def MLP(
             layers.append(nn.Dropout(dropout))
         if i < len(dims) - 2:
             match activation:
-                case BaseNonlinearityConfig():
+                case NonlinearityConfigBase():
                     layers.append(activation.create_module())
                 case nn.Module():
                     # In this case, we create a deep copy of the module to avoid sharing parameters (if any).
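With the base class renamed to `NonlinearityConfigBase`, config instances still flow through `MLP` via the `case NonlinearityConfigBase():` branch. A usage sketch, assuming `MLP` and `ReLUNonlinearityConfig` are importable from `nshtrainer.nn` (the `__init__.py` hunk above shows the latter re-export; the former is assumed) and that any `MLP` arguments not shown in the hunks keep their defaults:

```python
# Usage sketch only; import locations and unlisted MLP defaults are assumptions.
import torch

from nshtrainer.nn import MLP, ReLUNonlinearityConfig

# A ReLUNonlinearityConfig is a NonlinearityConfigBase, so MLP's
# `case NonlinearityConfigBase():` branch calls `create_module()` on it
# between hidden layers.
model = MLP([32, 64, 1], activation=ReLUNonlinearityConfig())
out = model(torch.randn(8, 32))
print(out.shape)  # torch.Size([8, 1])
```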
nshtrainer/nn/nonlinearity.py
CHANGED
@@ -7,10 +7,10 @@ import nshconfig as C
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing_extensions import final, override
+from typing_extensions import TypeAliasType, final, override
 
 
-class BaseNonlinearityConfig(C.Config, ABC):
+class NonlinearityConfigBase(C.Config, ABC):
     @abstractmethod
     def create_module(self) -> nn.Module: ...
 
@@ -18,8 +18,12 @@ class BaseNonlinearityConfig(C.Config, ABC):
     def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
 
 
+nonlinearity_registry = C.Registry(NonlinearityConfigBase, discriminator="name")
+
+
 @final
-class ReLUNonlinearityConfig(BaseNonlinearityConfig):
+@nonlinearity_registry.register
+class ReLUNonlinearityConfig(NonlinearityConfigBase):
     name: Literal["relu"] = "relu"
 
     @override
@@ -31,7 +35,8 @@ class ReLUNonlinearityConfig(BaseNonlinearityConfig):
 
 
 @final
-class SigmoidNonlinearityConfig(BaseNonlinearityConfig):
+@nonlinearity_registry.register
+class SigmoidNonlinearityConfig(NonlinearityConfigBase):
     name: Literal["sigmoid"] = "sigmoid"
 
     @override
@@ -43,7 +48,8 @@ class SigmoidNonlinearityConfig(BaseNonlinearityConfig):
 
 
 @final
-class TanhNonlinearityConfig(BaseNonlinearityConfig):
+@nonlinearity_registry.register
+class TanhNonlinearityConfig(NonlinearityConfigBase):
     name: Literal["tanh"] = "tanh"
 
     @override
@@ -55,7 +61,8 @@ class TanhNonlinearityConfig(BaseNonlinearityConfig):
 
 
 @final
-class SoftmaxNonlinearityConfig(BaseNonlinearityConfig):
+@nonlinearity_registry.register
+class SoftmaxNonlinearityConfig(NonlinearityConfigBase):
     name: Literal["softmax"] = "softmax"
 
     dim: int = -1
@@ -70,7 +77,8 @@ class SoftmaxNonlinearityConfig(BaseNonlinearityConfig):
 
 
 @final
-class SoftplusNonlinearityConfig(BaseNonlinearityConfig):
+@nonlinearity_registry.register
+class SoftplusNonlinearityConfig(NonlinearityConfigBase):
     name: Literal["softplus"] = "softplus"
 
     beta: float = 1.0
@@ -88,7 +96,8 @@ class SoftplusNonlinearityConfig(BaseNonlinearityConfig):
 
 
 @final
-class SoftsignNonlinearityConfig(BaseNonlinearityConfig):
+@nonlinearity_registry.register
+class SoftsignNonlinearityConfig(NonlinearityConfigBase):
     name: Literal["softsign"] = "softsign"
 
     @override
@@ -100,7 +109,8 @@ class SoftsignNonlinearityConfig(BaseNonlinearityConfig):
 
 
 @final
-class ELUNonlinearityConfig(BaseNonlinearityConfig):
+@nonlinearity_registry.register
+class ELUNonlinearityConfig(NonlinearityConfigBase):
     name: Literal["elu"] = "elu"
 
     alpha: float = 1.0
@@ -115,7 +125,8 @@ class ELUNonlinearityConfig(BaseNonlinearityConfig):
 
 
 @final
-class LeakyReLUNonlinearityConfig(BaseNonlinearityConfig):
+@nonlinearity_registry.register
+class LeakyReLUNonlinearityConfig(NonlinearityConfigBase):
     name: Literal["leaky_relu"] = "leaky_relu"
 
     negative_slope: float = 1.0e-2
@@ -130,7 +141,8 @@ class LeakyReLUNonlinearityConfig(BaseNonlinearityConfig):
 
 
 @final
-class PReLUConfig(BaseNonlinearityConfig):
+@nonlinearity_registry.register
+class PReLUConfig(NonlinearityConfigBase):
     name: Literal["prelu"] = "prelu"
 
     num_parameters: int = 1
@@ -152,7 +164,8 @@ class PReLUConfig(BaseNonlinearityConfig):
 
 
 @final
-class GELUNonlinearityConfig(BaseNonlinearityConfig):
+@nonlinearity_registry.register
+class GELUNonlinearityConfig(NonlinearityConfigBase):
     name: Literal["gelu"] = "gelu"
 
     approximate: Literal["tanh", "none"] = "none"
@@ -167,7 +180,8 @@ class GELUNonlinearityConfig(BaseNonlinearityConfig):
 
 
 @final
-class SwishNonlinearityConfig(BaseNonlinearityConfig):
+@nonlinearity_registry.register
+class SwishNonlinearityConfig(NonlinearityConfigBase):
     name: Literal["swish"] = "swish"
 
     @override
@@ -179,7 +193,8 @@ class SwishNonlinearityConfig(BaseNonlinearityConfig):
 
 
 @final
-class SiLUNonlinearityConfig(BaseNonlinearityConfig):
+@nonlinearity_registry.register
+class SiLUNonlinearityConfig(NonlinearityConfigBase):
     name: Literal["silu"] = "silu"
 
     @override
@@ -191,7 +206,8 @@ class SiLUNonlinearityConfig(BaseNonlinearityConfig):
 
 
 @final
-class MishNonlinearityConfig(BaseNonlinearityConfig):
+@nonlinearity_registry.register
+class MishNonlinearityConfig(NonlinearityConfigBase):
     name: Literal["mish"] = "mish"
 
     @override
@@ -210,7 +226,8 @@ class SwiGLU(nn.SiLU):
 
 
 @final
-class SwiGLUNonlinearityConfig(BaseNonlinearityConfig):
+@nonlinearity_registry.register
+class SwiGLUNonlinearityConfig(NonlinearityConfigBase):
     name: Literal["swiglu"] = "swiglu"
 
     @override
@@ -222,20 +239,7 @@ class SwiGLUNonlinearityConfig(BaseNonlinearityConfig):
         return input * F.silu(gate)
 
 
-NonlinearityConfig = Annotated[
-    ReLUNonlinearityConfig
-    | SigmoidNonlinearityConfig
-    | TanhNonlinearityConfig
-    | SoftmaxNonlinearityConfig
-    | SoftplusNonlinearityConfig
-    | SoftsignNonlinearityConfig
-    | ELUNonlinearityConfig
-    | LeakyReLUNonlinearityConfig
-    | PReLUConfig
-    | GELUNonlinearityConfig
-    | SwishNonlinearityConfig
-    | SiLUNonlinearityConfig
-    | MishNonlinearityConfig
-    | SwiGLUNonlinearityConfig,
-    C.Field(discriminator="name"),
-]
+NonlinearityConfig = TypeAliasType(
+    "NonlinearityConfig",
+    Annotated[NonlinearityConfigBase, nonlinearity_registry.DynamicResolution()],
+)
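The registry plus `TypeAliasType` removes the hard-coded discriminated union, so additional nonlinearities can be registered without editing `NonlinearityConfig` itself. A minimal sketch of a hypothetical third-party registration, mirroring the built-in pattern above (`HardswishNonlinearityConfig` and its `"hardswish"` name are invented for illustration and are not part of nshtrainer):

```python
# Hypothetical registration of an extra nonlinearity, following the pattern above.
from typing import Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing_extensions import final, override

from nshtrainer.nn.nonlinearity import NonlinearityConfigBase, nonlinearity_registry


@final
@nonlinearity_registry.register
class HardswishNonlinearityConfig(NonlinearityConfigBase):
    name: Literal["hardswish"] = "hardswish"

    @override
    def create_module(self) -> nn.Module:
        return nn.Hardswish()

    @override
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return F.hardswish(x)
```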
nshtrainer/optimizer.py
CHANGED
@@ -7,7 +7,7 @@ from typing import Annotated, Any, Literal
 import nshconfig as C
 import torch.nn as nn
 from torch.optim import Optimizer
-from typing_extensions import TypeAliasType, override
+from typing_extensions import TypeAliasType, final, override
 
 
 class OptimizerConfigBase(C.Config, ABC):
@@ -18,6 +18,11 @@ class OptimizerConfigBase(C.Config, ABC):
     ) -> Optimizer: ...
 
 
+optimizer_registry = C.Registry(OptimizerConfigBase, discriminator="name")
+
+
+@final
+@optimizer_registry.register
 class AdamWConfig(OptimizerConfigBase):
     name: Literal["adamw"] = "adamw"
 
@@ -58,5 +63,6 @@ class AdamWConfig(OptimizerConfigBase):
 
 
 OptimizerConfig = TypeAliasType(
-    "OptimizerConfig",
+    "OptimizerConfig",
+    Annotated[OptimizerConfigBase, optimizer_registry.DynamicResolution()],
 )
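`OptimizerConfig` now resolves through `optimizer_registry.DynamicResolution()` instead of naming `AdamWConfig` directly. A sketch of how the alias might be consumed in a parent config; the exact validation behaviour of `DynamicResolution`, the `lr` keyword on `AdamWConfig`, and the `TrainingConfig` name are all assumptions made for illustration:

```python
# Consumption sketch; field names beyond those shown in the diff are assumptions.
import nshconfig as C

from nshtrainer.optimizer import AdamWConfig, OptimizerConfig


class TrainingConfig(C.Config):
    optimizer: OptimizerConfig
    """Any registered OptimizerConfigBase subclass, discriminated by its `name` field."""


cfg = TrainingConfig(optimizer=AdamWConfig(lr=1e-3))
print(cfg.optimizer.name)  # "adamw"
```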
nshtrainer/trainer/_config.py
CHANGED
@@ -48,8 +48,8 @@ from ..loggers import (
     TensorboardLoggerConfig,
     WandbLoggerConfig,
 )
-from ..loggers._base import BaseLoggerConfig
 from ..loggers.actsave import ActSaveLoggerConfig
+from ..loggers.base import LoggerConfigBase
 from ..metrics._config import MetricConfig
 from ..profiler import ProfilerConfig
 from ..util._environment_info import EnvironmentConfig
@@ -770,7 +770,7 @@ class TrainerConfig(C.Config):
         yield self.auto_set_debug_flag
         yield from self.callbacks
 
-    def _nshtrainer_all_logger_configs(self) -> Iterable[BaseLoggerConfig | None]:
+    def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
         # Disable all loggers if barebones mode is enabled
         if self.barebones:
             return
{nshtrainer-1.0.0b36.dist-info → nshtrainer-1.0.0b39.dist-info}/RECORD
CHANGED
@@ -3,7 +3,7 @@ nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
 nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
 nshtrainer/_checkpoint/metadata.py,sha256=PHy-54Cg-o3OtCffAqrVv6ZVMU7zhRo_-sZiSEEno1Y,5019
 nshtrainer/_checkpoint/saver.py,sha256=LOP8jjKF0Dw9x9H-BKrLMWlEp1XTan2DUK0zQUCWw5U,1360
-nshtrainer/_directory.py,sha256=
+nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
 nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
 nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
 nshtrainer/callbacks/__init__.py,sha256=4giOYT8A709UOLRtQEt16QbOAFUHCjJ_aLB7ITTwXJI,3577
@@ -12,7 +12,7 @@ nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,
 nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
 nshtrainer/callbacks/checkpoint/_base.py,sha256=ZVEUVl5kjCSSe69Q0rMUbKBNNUog0pxBwWkeyuxG2w0,6304
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=2CQuhPJ3Fi7lDw7z-J8kXXXuDU8-4HcU48oZxR49apk,2667
-nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=
+nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
 nshtrainer/callbacks/debug_flag.py,sha256=96fuP0C7C6dSs1GiMeUYzzs0X3Q4Pjt9JVWg3b75fU4,1748
 nshtrainer/callbacks/directory_setup.py,sha256=wPas_Ren8ANejogmIdKhqqgj4ulxz9AS_8xVIAfRXa0,2565
@@ -30,7 +30,7 @@ nshtrainer/callbacks/shared_parameters.py,sha256=s94jJTAIbDGukYJu6l247QonVOCudGC
 nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU,4731
 nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
 nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
-nshtrainer/configs/__init__.py,sha256=
+nshtrainer/configs/__init__.py,sha256=MZfcSKhnjtVObBvVv9lu8L2cFTLINP5zcTQvWnz8jdk,14505
 nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
 nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
 nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
@@ -58,29 +58,29 @@ nshtrainer/configs/callbacks/shared_parameters/__init__.py,sha256=AU7_bSnSRSlj16
 nshtrainer/configs/callbacks/timer/__init__.py,sha256=cOUtbsl0_OhCO0fIcBfLuIF6FEGBHQu7AvQFzwVznWQ,413
 nshtrainer/configs/callbacks/wandb_upload_code/__init__.py,sha256=CJeCc9OCu5F39lWiY5aIc4WxQlgBvB-8cga6cQtw0GQ,482
 nshtrainer/configs/callbacks/wandb_watch/__init__.py,sha256=dzz1oavL1BwELE33xus45_avBEAZDeB6xtcb6CsOEos,431
-nshtrainer/configs/loggers/__init__.py,sha256=
-nshtrainer/configs/loggers/
-nshtrainer/configs/loggers/
-nshtrainer/configs/loggers/csv/__init__.py,sha256=
-nshtrainer/configs/loggers/tensorboard/__init__.py,sha256=
-nshtrainer/configs/loggers/wandb/__init__.py,sha256=
-nshtrainer/configs/lr_scheduler/__init__.py,sha256=
-nshtrainer/configs/lr_scheduler/
-nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py,sha256=
-nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py,sha256=
+nshtrainer/configs/loggers/__init__.py,sha256=GT7PO7UM3Mo87N616mGucc2ZRyGP8nQWBd_VJ_8RGXo,1337
+nshtrainer/configs/loggers/actsave/__init__.py,sha256=J7SnbD-zxUynWSskJezooFyBZdnhgTWyybRvwn9gzy4,377
+nshtrainer/configs/loggers/base/__init__.py,sha256=HLUfEDbjaAXqzsFmQbjdciIWzR1st1gRLKTCFvUFEX0,262
+nshtrainer/configs/loggers/csv/__init__.py,sha256=gawaDX92JObGSmBqYpfNHWMHBwVOofS694W-1Y2GWDU,353
+nshtrainer/configs/loggers/tensorboard/__init__.py,sha256=phzm-TnBkdkibTgoOxIIcAliqL3zU8gSNK61Mwxs1CM,410
+nshtrainer/configs/loggers/wandb/__init__.py,sha256=TDcD5WZSKenc2mgIXhwz2l96l8P_Ur3N5CzEol5AKGw,746
+nshtrainer/configs/lr_scheduler/__init__.py,sha256=xtiUx0isxA82-uXMn4-KmPnDCfbUkpAnd2_pFupAAKQ,1137
+nshtrainer/configs/lr_scheduler/base/__init__.py,sha256=6Cx8r4rdxeSYxc_z0o7drKCblGJU_zzqrOoYlWYR5qY,305
+nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py,sha256=5ZMLDO9VL6SNU6pF-62lDnpmqix3_Ol9DdEwiuOPYlA,675
+nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py,sha256=w-vq8UbRGPX8DZVWCMC5eIrbvVc_guxjj7Du9AaeKCw,609
 nshtrainer/configs/metrics/__init__.py,sha256=mK_xgXJDAyGY4K_x_Zo4aj36kjT45d850keuUe3U1rY,200
 nshtrainer/configs/metrics/_config/__init__.py,sha256=XDDvDPWULd_vd3lrgF2KGAVR2LVDuhdQvy-fF2ImarI,159
-nshtrainer/configs/nn/__init__.py,sha256=
-nshtrainer/configs/nn/mlp/__init__.py,sha256=
-nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=
-nshtrainer/configs/optimizer/__init__.py,sha256=
+nshtrainer/configs/nn/__init__.py,sha256=tkFG2Hb0oL_AmWP3_0WkDN2zI5PkVfrgwXhaAII7CZw,2072
+nshtrainer/configs/nn/mlp/__init__.py,sha256=O6kQ6utZNJPG9Fax5pRdZcHa3J-XFKKdXcc_PQg0jk0,347
+nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=LCTbTyelCMABVw505CGQ4UpEGlAnIhflSLFqwAQXLQA,2155
+nshtrainer/configs/optimizer/__init__.py,sha256=itIDIHQvGm50eZ7JLyNElahnNUMPJ__4PMmTjc0RQ6o,444
 nshtrainer/configs/profiler/__init__.py,sha256=2ssaIpfVnvcbfNvZ-JeKp1Cx4NO1LknkVqTm1hu7Lvw,768
 nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRDm1CKqjwUOQNbQjD4,176
 nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
 nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
 nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
-nshtrainer/configs/trainer/__init__.py,sha256=
-nshtrainer/configs/trainer/_config/__init__.py,sha256=
+nshtrainer/configs/trainer/__init__.py,sha256=jYCp4Q9uvutA6NYqfthbREMg09-obD3gHtzEI2Ta-hU,7729
+nshtrainer/configs/trainer/_config/__init__.py,sha256=uof_oJfhwjB1pft7KsRdk_RvNj-tE8wcDBEM7X5qtNc,3666
 nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
 nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
 nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=GuubcKrbXt4VjJRT8VpNUQqBtyuutre_CfkS0EWZ5_E,368
@@ -99,16 +99,16 @@ nshtrainer/data/__init__.py,sha256=K4i3Tw4g9EOK2zlMMbidi99y0SyI4w8P7_XUf1n42Ts,2
 nshtrainer/data/balanced_batch_sampler.py,sha256=r1cBKRXKHD8E1Ax6tj-FUbE-z1qpbO58mQ9VrK9uLnc,5481
 nshtrainer/data/datamodule.py,sha256=lSOgH32nysJWa6Y7ba1QyOdUV0DVVdO98qokP8wigjk,4138
 nshtrainer/data/transform.py,sha256=qd0lIocO59Fk_m90xyOHgFezbymd1mRwly8nbYIfHGc,2263
-nshtrainer/loggers/__init__.py,sha256
-nshtrainer/loggers/
-nshtrainer/loggers/
-nshtrainer/loggers/csv.py,sha256=
-nshtrainer/loggers/tensorboard.py,sha256=
-nshtrainer/loggers/wandb.py,sha256=
-nshtrainer/lr_scheduler/__init__.py,sha256=
-nshtrainer/lr_scheduler/
-nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=
-nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=
+nshtrainer/loggers/__init__.py,sha256=Ddd3JJXVzew_ZpwHA9kGnGmvq4OwhItwghDL5PzNhDc,614
+nshtrainer/loggers/actsave.py,sha256=wgNrpBB6wQM7qff8iLDb_sQnbiAcYHRmH56pcEJPB3o,1409
+nshtrainer/loggers/base.py,sha256=ON92XbwTSgadQOSyw5PiRRFzyH6uJ-xLtE0nB3cbgPc,1205
+nshtrainer/loggers/csv.py,sha256=xJ8mSmw4vJwinIfqhF6t2HWmh_1dXEYyLfGuXwL7WHo,1160
+nshtrainer/loggers/tensorboard.py,sha256=E7iO_fDt9bfH02hBL430bXPLljOo5iGgq2QyPqmx2gQ,2324
+nshtrainer/loggers/wandb.py,sha256=KZXAUWrrmdX_L8rqej77oUHaM0JxZRM8y9z6JP9PISw,6856
+nshtrainer/lr_scheduler/__init__.py,sha256=daMMK3erUcNXGGd_nZB8AWu3ZTYqfS1RSWeK4FV2udw,851
+nshtrainer/lr_scheduler/base.py,sha256=062fGcH5sYeEKwoY55RydCTvfPwTnyZHCi049a3nMbM,3805
+nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=MsoXgCcWTKsrkNZiGnKS6yC-slRuleuwFxeM_lmG_pQ,5560
+nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=v9T0GpvOoHV30atFB0MwExHgHcTpMCYxbMRoPjPBjt8,2938
 nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
 nshtrainer/metrics/_config.py,sha256=XIRokFM8PHrhBa3w2R6BM6a4es3ncsoBqE_LqXQFsFE,1223
 nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
@@ -116,19 +116,19 @@ nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,1035
 nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
 nshtrainer/model/mixins/debug.py,sha256=1LX9KzeFX9JDPs_a6YCdYDZXLhEk_5rBO2aCqlfBy7w,2087
 nshtrainer/model/mixins/logger.py,sha256=27H99FuLaxc6_dDLG2pid4E_5E0-eLGnc2Ifpt0HYIM,6066
-nshtrainer/nn/__init__.py,sha256=
-nshtrainer/nn/mlp.py,sha256=
+nshtrainer/nn/__init__.py,sha256=7KCs-GDOynCXAIdwkgAQacc0p3FHLEION50UtrvgAOc,1463
+nshtrainer/nn/mlp.py,sha256=ZbkLyOc08stgIugvu1G5_h66DYtxAFDnboikBaJvvZ8,5988
 nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
 nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
-nshtrainer/nn/nonlinearity.py,sha256=
-nshtrainer/optimizer.py,sha256=
+nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
+nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
 nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
 nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
 nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
 nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
 nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
 nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
-nshtrainer/trainer/_config.py,sha256=
+nshtrainer/trainer/_config.py,sha256=QDy6sINVDGEqfHfPTWXSN-06EoEuMSVscHn8fCRTvr0,32981
 nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
 nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
 nshtrainer/trainer/plugin/__init__.py,sha256=UM8f70Ml3RGpsXeQr1Yh1yBcxxjFNGFGgJbniCn_rws,366
@@ -151,6 +151,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.0.
-nshtrainer-1.0.
-nshtrainer-1.0.
+nshtrainer-1.0.0b39.dist-info/METADATA,sha256=zzE6nHlj-clB3HJs5_-bBePCHSOrtTkZTi9z_NrSeRY,988
+nshtrainer-1.0.0b39.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
+nshtrainer-1.0.0b39.dist-info/RECORD,,
{nshtrainer-1.0.0b36.dist-info → nshtrainer-1.0.0b39.dist-info}/WHEEL
File without changes
|