nshtrainer-1.0.0b33-py3-none-any.whl → nshtrainer-1.0.0b37-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. nshtrainer/__init__.py +1 -0
  2. nshtrainer/_directory.py +3 -1
  3. nshtrainer/_hf_hub.py +8 -1
  4. nshtrainer/callbacks/__init__.py +10 -23
  5. nshtrainer/callbacks/actsave.py +6 -2
  6. nshtrainer/callbacks/base.py +3 -0
  7. nshtrainer/callbacks/checkpoint/__init__.py +0 -4
  8. nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
  9. nshtrainer/callbacks/checkpoint/last_checkpoint.py +72 -2
  10. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -2
  11. nshtrainer/callbacks/debug_flag.py +4 -2
  12. nshtrainer/callbacks/directory_setup.py +23 -21
  13. nshtrainer/callbacks/early_stopping.py +4 -2
  14. nshtrainer/callbacks/ema.py +29 -27
  15. nshtrainer/callbacks/finite_checks.py +21 -19
  16. nshtrainer/callbacks/gradient_skipping.py +29 -27
  17. nshtrainer/callbacks/log_epoch.py +4 -2
  18. nshtrainer/callbacks/lr_monitor.py +6 -1
  19. nshtrainer/callbacks/norm_logging.py +36 -34
  20. nshtrainer/callbacks/print_table.py +20 -18
  21. nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
  22. nshtrainer/callbacks/shared_parameters.py +9 -7
  23. nshtrainer/callbacks/timer.py +12 -10
  24. nshtrainer/callbacks/wandb_upload_code.py +4 -2
  25. nshtrainer/callbacks/wandb_watch.py +4 -2
  26. nshtrainer/configs/__init__.py +16 -12
  27. nshtrainer/configs/_hf_hub/__init__.py +2 -0
  28. nshtrainer/configs/callbacks/__init__.py +4 -8
  29. nshtrainer/configs/callbacks/actsave/__init__.py +2 -0
  30. nshtrainer/configs/callbacks/base/__init__.py +2 -0
  31. nshtrainer/configs/callbacks/checkpoint/__init__.py +4 -6
  32. nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +4 -0
  33. nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +4 -0
  34. nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -0
  35. nshtrainer/configs/callbacks/debug_flag/__init__.py +2 -0
  36. nshtrainer/configs/callbacks/directory_setup/__init__.py +2 -0
  37. nshtrainer/configs/callbacks/early_stopping/__init__.py +2 -0
  38. nshtrainer/configs/callbacks/ema/__init__.py +2 -0
  39. nshtrainer/configs/callbacks/finite_checks/__init__.py +2 -0
  40. nshtrainer/configs/callbacks/gradient_skipping/__init__.py +4 -0
  41. nshtrainer/configs/callbacks/log_epoch/__init__.py +2 -0
  42. nshtrainer/configs/callbacks/lr_monitor/__init__.py +2 -0
  43. nshtrainer/configs/callbacks/norm_logging/__init__.py +2 -0
  44. nshtrainer/configs/callbacks/print_table/__init__.py +2 -0
  45. nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +4 -0
  46. nshtrainer/configs/callbacks/shared_parameters/__init__.py +4 -0
  47. nshtrainer/configs/callbacks/timer/__init__.py +2 -0
  48. nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +4 -0
  49. nshtrainer/configs/callbacks/wandb_watch/__init__.py +2 -0
  50. nshtrainer/configs/loggers/__init__.py +6 -4
  51. nshtrainer/configs/loggers/actsave/__init__.py +4 -2
  52. nshtrainer/configs/loggers/base/__init__.py +11 -0
  53. nshtrainer/configs/loggers/csv/__init__.py +4 -2
  54. nshtrainer/configs/loggers/tensorboard/__init__.py +4 -2
  55. nshtrainer/configs/loggers/wandb/__init__.py +4 -2
  56. nshtrainer/configs/lr_scheduler/__init__.py +4 -2
  57. nshtrainer/configs/lr_scheduler/base/__init__.py +11 -0
  58. nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +4 -0
  59. nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +4 -0
  60. nshtrainer/configs/nn/__init__.py +4 -2
  61. nshtrainer/configs/nn/mlp/__init__.py +2 -2
  62. nshtrainer/configs/nn/nonlinearity/__init__.py +4 -2
  63. nshtrainer/configs/optimizer/__init__.py +2 -0
  64. nshtrainer/configs/trainer/__init__.py +4 -6
  65. nshtrainer/configs/trainer/_config/__init__.py +2 -10
  66. nshtrainer/loggers/__init__.py +3 -8
  67. nshtrainer/loggers/actsave.py +5 -2
  68. nshtrainer/loggers/{_base.py → base.py} +4 -1
  69. nshtrainer/loggers/csv.py +5 -3
  70. nshtrainer/loggers/tensorboard.py +5 -3
  71. nshtrainer/loggers/wandb.py +5 -3
  72. nshtrainer/lr_scheduler/__init__.py +2 -2
  73. nshtrainer/lr_scheduler/{_base.py → base.py} +3 -0
  74. nshtrainer/lr_scheduler/linear_warmup_cosine.py +56 -54
  75. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +4 -2
  76. nshtrainer/nn/__init__.py +1 -1
  77. nshtrainer/nn/mlp.py +4 -4
  78. nshtrainer/nn/nonlinearity.py +37 -33
  79. nshtrainer/optimizer.py +8 -2
  80. nshtrainer/trainer/__init__.py +3 -2
  81. nshtrainer/trainer/_config.py +6 -44
  82. {nshtrainer-1.0.0b33.dist-info → nshtrainer-1.0.0b37.dist-info}/METADATA +1 -1
  83. nshtrainer-1.0.0b37.dist-info/RECORD +156 -0
  84. nshtrainer/callbacks/checkpoint/time_checkpoint.py +0 -114
  85. nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +0 -19
  86. nshtrainer/configs/loggers/_base/__init__.py +0 -9
  87. nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -9
  88. nshtrainer-1.0.0b33.dist-info/RECORD +0 -158
  89. {nshtrainer-1.0.0b33.dist-info → nshtrainer-1.0.0b37.dist-info}/WHEEL +0 -0
nshtrainer/lr_scheduler/{_base.py → base.py} RENAMED
@@ -94,3 +94,6 @@ class LRSchedulerConfigBase(C.Config, ABC):
          # ^ This is a hack to trigger the computation of the estimated stepping batches
          # and make sure that the `trainer.num_training_batches` attribute is set.
          return math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches)
+
+
+ lr_scheduler_registry = C.Registry(LRSchedulerConfigBase, discriminator="name")
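The new lr_scheduler_registry turns LRSchedulerConfigBase into an open extension point: configs register themselves by name instead of being enumerated in a closed union. A minimal downstream sketch, assuming nshconfig's Registry.register decorator behaves as used in the hunks below; the ExponentialDecayLRSchedulerConfig class and its fields are hypothetical, not part of nshtrainer:

from typing import Literal

from torch.optim.lr_scheduler import ExponentialLR
from typing_extensions import final, override

from nshtrainer.lr_scheduler.base import (
    LRSchedulerConfigBase,
    LRSchedulerMetadata,
    lr_scheduler_registry,
)


@final
@lr_scheduler_registry.register
class ExponentialDecayLRSchedulerConfig(LRSchedulerConfigBase):
    # Hypothetical example config; not shipped with nshtrainer.
    name: Literal["exponential_decay"] = "exponential_decay"

    gamma: float = 0.99
    """Multiplicative decay factor applied at every scheduler step."""

    @override
    def metadata(self) -> LRSchedulerMetadata:
        return {"interval": "step"}

    @override
    def create_scheduler_impl(self, optimizer, lightning_module):
        # torch's ExponentialLR multiplies the learning rate by `gamma` each step.
        return ExponentialLR(optimizer, gamma=self.gamma)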
nshtrainer/lr_scheduler/linear_warmup_cosine.py CHANGED
@@ -6,10 +6,64 @@ from typing import Literal
 
  from torch.optim import Optimizer
  from torch.optim.lr_scheduler import LRScheduler
- from typing_extensions import override
+ from typing_extensions import final, override
 
  from ..util.config import DurationConfig
- from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
+ from .base import LRSchedulerConfigBase, LRSchedulerMetadata, lr_scheduler_registry
+
+
+ @final
+ @lr_scheduler_registry.register
+ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
+     name: Literal["linear_warmup_cosine_decay"] = "linear_warmup_cosine_decay"
+
+     warmup_duration: DurationConfig
+     r"""The duration for the linear warmup phase.
+     The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this duration."""
+
+     max_duration: DurationConfig
+     r"""The total duration.
+     The learning rate is decayed to `min_lr` over this duration."""
+
+     warmup_start_lr_factor: float = 0.0
+     r"""The initial learning rate for the linear warmup phase, as a factor of the initial learning rate.
+     The learning rate is linearly increased from this value to the initial learning rate over `warmup_epochs` epochs."""
+
+     min_lr_factor: float = 0.0
+     r"""The minimum learning rate, as a factor of the initial learning rate.
+     The learning rate is decayed to this value over `max_epochs` epochs."""
+
+     annealing: bool = False
+     r"""Whether to restart the learning rate schedule after `max_epochs` epochs.
+     If `False`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be set to `min_lr` for all subsequent epochs.
+     If `True`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be increased back to the initial learning rate over `max_epochs` epochs, and so on (this is called a cosine annealing schedule)."""
+
+     @override
+     def metadata(self) -> LRSchedulerMetadata:
+         return {
+             "interval": "step",
+         }
+
+     @override
+     def create_scheduler_impl(self, optimizer, lightning_module):
+         num_steps_per_epoch = self.compute_num_steps_per_epoch(lightning_module)
+         warmup_steps = self.warmup_duration.to_steps(num_steps_per_epoch).value
+         max_steps = self.max_duration.to_steps(num_steps_per_epoch).value
+
+         # Warmup and max steps should be at least 1.
+         warmup_steps = max(warmup_steps, 1)
+         max_steps = max(max_steps, 1)
+
+         # Create the scheduler
+         scheduler = LinearWarmupCosineAnnealingLR(
+             optimizer=optimizer,
+             warmup_epochs=warmup_steps,
+             max_epochs=max_steps,
+             warmup_start_lr_factor=self.warmup_start_lr_factor,
+             eta_min_factor=self.min_lr_factor,
+             should_restart=self.annealing,
+         )
+         return scheduler
 
 
  class LinearWarmupCosineAnnealingLR(LRScheduler):
@@ -89,55 +143,3 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
              + self.eta_min_factor * base_lr
              for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
          ]
-
-
- class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
-     name: Literal["linear_warmup_cosine_decay"] = "linear_warmup_cosine_decay"
-
-     warmup_duration: DurationConfig
-     r"""The duration for the linear warmup phase.
-     The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this duration."""
-
-     max_duration: DurationConfig
-     r"""The total duration.
-     The learning rate is decayed to `min_lr` over this duration."""
-
-     warmup_start_lr_factor: float = 0.0
-     r"""The initial learning rate for the linear warmup phase, as a factor of the initial learning rate.
-     The learning rate is linearly increased from this value to the initial learning rate over `warmup_epochs` epochs."""
-
-     min_lr_factor: float = 0.0
-     r"""The minimum learning rate, as a factor of the initial learning rate.
-     The learning rate is decayed to this value over `max_epochs` epochs."""
-
-     annealing: bool = False
-     r"""Whether to restart the learning rate schedule after `max_epochs` epochs.
-     If `False`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be set to `min_lr` for all subsequent epochs.
-     If `True`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be increased back to the initial learning rate over `max_epochs` epochs, and so on (this is called a cosine annealing schedule)."""
-
-     @override
-     def metadata(self) -> LRSchedulerMetadata:
-         return {
-             "interval": "step",
-         }
-
-     @override
-     def create_scheduler_impl(self, optimizer, lightning_module):
-         num_steps_per_epoch = self.compute_num_steps_per_epoch(lightning_module)
-         warmup_steps = self.warmup_duration.to_steps(num_steps_per_epoch).value
-         max_steps = self.max_duration.to_steps(num_steps_per_epoch).value
-
-         # Warmup and max steps should be at least 1.
-         warmup_steps = max(warmup_steps, 1)
-         max_steps = max(max_steps, 1)
-
-         # Create the scheduler
-         scheduler = LinearWarmupCosineAnnealingLR(
-             optimizer=optimizer,
-             warmup_epochs=warmup_steps,
-             max_epochs=max_steps,
-             warmup_start_lr_factor=self.warmup_start_lr_factor,
-             eta_min_factor=self.min_lr_factor,
-             should_restart=self.annealing,
-         )
-         return scheduler
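The net change above is a move, not a rewrite: the config class is now defined next to the imports, registered with lr_scheduler_registry, and removed from the bottom of the module. For reference, the curve implied by warmup_start_lr_factor, min_lr_factor, and the warmup/decay durations is linear warmup followed by cosine decay. A rough standalone approximation of that shape for illustration (my own sketch, not nshtrainer's LinearWarmupCosineAnnealingLR, and ignoring the annealing/restart case):

import math


def approx_lr(step: int, base_lr: float, warmup_steps: int, max_steps: int,
              warmup_start_lr_factor: float = 0.0, min_lr_factor: float = 0.0) -> float:
    # Linear warmup from base_lr * warmup_start_lr_factor up to base_lr.
    if step < warmup_steps:
        start = base_lr * warmup_start_lr_factor
        return start + (base_lr - start) * step / max(warmup_steps, 1)
    # Cosine decay from base_lr down to base_lr * min_lr_factor.
    min_lr = base_lr * min_lr_factor
    progress = min((step - warmup_steps) / max(max_steps - warmup_steps, 1), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))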
nshtrainer/lr_scheduler/reduce_lr_on_plateau.py CHANGED
@@ -4,12 +4,14 @@ from typing import Literal
 
  from lightning.pytorch.utilities.types import LRSchedulerConfigType
  from torch.optim.lr_scheduler import ReduceLROnPlateau
- from typing_extensions import override
+ from typing_extensions import final, override
 
  from ..metrics._config import MetricConfig
- from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
+ from .base import LRSchedulerConfigBase, LRSchedulerMetadata, lr_scheduler_registry
 
 
+ @final
+ @lr_scheduler_registry.register
  class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
      """Reduce learning rate when a metric has stopped improving."""
 
nshtrainer/nn/__init__.py CHANGED
@@ -6,12 +6,12 @@ from .mlp import MLPConfigDict as MLPConfigDict
  from .mlp import ResidualSequential as ResidualSequential
  from .module_dict import TypedModuleDict as TypedModuleDict
  from .module_list import TypedModuleList as TypedModuleList
- from .nonlinearity import BaseNonlinearityConfig as BaseNonlinearityConfig
  from .nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
  from .nonlinearity import GELUNonlinearityConfig as GELUNonlinearityConfig
  from .nonlinearity import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
  from .nonlinearity import MishNonlinearityConfig as MishNonlinearityConfig
  from .nonlinearity import NonlinearityConfig as NonlinearityConfig
+ from .nonlinearity import NonlinearityConfigBase as NonlinearityConfigBase
  from .nonlinearity import PReLUConfig as PReLUConfig
  from .nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
  from .nonlinearity import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
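The net effect of this hunk is a public rename: BaseNonlinearityConfig disappears from nshtrainer.nn and NonlinearityConfigBase takes its place. A minimal migration sketch, assuming no compatibility alias is kept for the old name:

# 1.0.0b33
# from nshtrainer.nn import BaseNonlinearityConfig as NonlinearityBase

# 1.0.0b37
from nshtrainer.nn import NonlinearityConfigBase as NonlinearityBase


def is_nonlinearity_config(obj: object) -> bool:
    # isinstance checks against the old base class should switch to the new name.
    return isinstance(obj, NonlinearityBase)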
nshtrainer/nn/mlp.py CHANGED
@@ -9,7 +9,7 @@ import torch
  import torch.nn as nn
  from typing_extensions import TypedDict, override
 
- from .nonlinearity import BaseNonlinearityConfig, NonlinearityConfig
+ from .nonlinearity import NonlinearityConfig, NonlinearityConfigBase
 
 
  @runtime_checkable
@@ -92,11 +92,11 @@ class MLPConfig(C.Config):
 
  def MLP(
      dims: Sequence[int],
-     activation: BaseNonlinearityConfig
+     activation: NonlinearityConfigBase
      | nn.Module
      | Callable[[], nn.Module]
      | None = None,
-     nonlinearity: BaseNonlinearityConfig
+     nonlinearity: NonlinearityConfigBase
      | nn.Module
      | Callable[[], nn.Module]
      | None = None,
@@ -153,7 +153,7 @@ def MLP(
          layers.append(nn.Dropout(dropout))
      if i < len(dims) - 2:
          match activation:
-             case BaseNonlinearityConfig():
+             case NonlinearityConfigBase():
                  layers.append(activation.create_module())
              case nn.Module():
                  # In this case, we create a deep copy of the module to avoid sharing parameters (if any).
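Since MLP now pattern-matches on NonlinearityConfigBase, any registered nonlinearity config (including user-defined ones) can be passed directly as the activation. A usage sketch, assuming the remaining MLP keyword arguments keep their defaults:

import torch

from nshtrainer.nn.mlp import MLP
from nshtrainer.nn.nonlinearity import GELUNonlinearityConfig

# Presumably builds Linear layers 32 -> 64 -> 64 -> 1, with GELU modules created
# from the config (via create_module) between the non-final layers.
model = MLP([32, 64, 64, 1], activation=GELUNonlinearityConfig())
out = model(torch.randn(8, 32))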
nshtrainer/nn/nonlinearity.py CHANGED
@@ -7,10 +7,10 @@ import nshconfig as C
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from typing_extensions import final, override
+ from typing_extensions import TypeAliasType, final, override
 
 
- class BaseNonlinearityConfig(C.Config, ABC):
+ class NonlinearityConfigBase(C.Config, ABC):
      @abstractmethod
      def create_module(self) -> nn.Module: ...
 
@@ -18,8 +18,12 @@ class BaseNonlinearityConfig(C.Config, ABC):
      def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
 
 
+ nonlinearity_registry = C.Registry(NonlinearityConfigBase, discriminator="name")
+
+
  @final
- class ReLUNonlinearityConfig(BaseNonlinearityConfig):
+ @nonlinearity_registry.register
+ class ReLUNonlinearityConfig(NonlinearityConfigBase):
      name: Literal["relu"] = "relu"
 
      @override
@@ -31,7 +35,8 @@ class ReLUNonlinearityConfig(BaseNonlinearityConfig):
 
 
  @final
- class SigmoidNonlinearityConfig(BaseNonlinearityConfig):
+ @nonlinearity_registry.register
+ class SigmoidNonlinearityConfig(NonlinearityConfigBase):
      name: Literal["sigmoid"] = "sigmoid"
 
      @override
@@ -43,7 +48,8 @@ class SigmoidNonlinearityConfig(BaseNonlinearityConfig):
 
 
  @final
- class TanhNonlinearityConfig(BaseNonlinearityConfig):
+ @nonlinearity_registry.register
+ class TanhNonlinearityConfig(NonlinearityConfigBase):
      name: Literal["tanh"] = "tanh"
 
      @override
@@ -55,7 +61,8 @@ class TanhNonlinearityConfig(BaseNonlinearityConfig):
 
 
  @final
- class SoftmaxNonlinearityConfig(BaseNonlinearityConfig):
+ @nonlinearity_registry.register
+ class SoftmaxNonlinearityConfig(NonlinearityConfigBase):
      name: Literal["softmax"] = "softmax"
 
      dim: int = -1
@@ -70,7 +77,8 @@ class SoftmaxNonlinearityConfig(BaseNonlinearityConfig):
 
 
  @final
- class SoftplusNonlinearityConfig(BaseNonlinearityConfig):
+ @nonlinearity_registry.register
+ class SoftplusNonlinearityConfig(NonlinearityConfigBase):
      name: Literal["softplus"] = "softplus"
 
      beta: float = 1.0
@@ -88,7 +96,8 @@ class SoftplusNonlinearityConfig(BaseNonlinearityConfig):
 
 
  @final
- class SoftsignNonlinearityConfig(BaseNonlinearityConfig):
+ @nonlinearity_registry.register
+ class SoftsignNonlinearityConfig(NonlinearityConfigBase):
      name: Literal["softsign"] = "softsign"
 
      @override
@@ -100,7 +109,8 @@ class SoftsignNonlinearityConfig(BaseNonlinearityConfig):
 
 
  @final
- class ELUNonlinearityConfig(BaseNonlinearityConfig):
+ @nonlinearity_registry.register
+ class ELUNonlinearityConfig(NonlinearityConfigBase):
      name: Literal["elu"] = "elu"
 
      alpha: float = 1.0
@@ -115,7 +125,8 @@ class ELUNonlinearityConfig(BaseNonlinearityConfig):
 
 
  @final
- class LeakyReLUNonlinearityConfig(BaseNonlinearityConfig):
+ @nonlinearity_registry.register
+ class LeakyReLUNonlinearityConfig(NonlinearityConfigBase):
      name: Literal["leaky_relu"] = "leaky_relu"
 
      negative_slope: float = 1.0e-2
@@ -130,7 +141,8 @@ class LeakyReLUNonlinearityConfig(BaseNonlinearityConfig):
 
 
  @final
- class PReLUConfig(BaseNonlinearityConfig):
+ @nonlinearity_registry.register
+ class PReLUConfig(NonlinearityConfigBase):
      name: Literal["prelu"] = "prelu"
 
      num_parameters: int = 1
@@ -152,7 +164,8 @@ class PReLUConfig(BaseNonlinearityConfig):
 
 
  @final
- class GELUNonlinearityConfig(BaseNonlinearityConfig):
+ @nonlinearity_registry.register
+ class GELUNonlinearityConfig(NonlinearityConfigBase):
      name: Literal["gelu"] = "gelu"
 
      approximate: Literal["tanh", "none"] = "none"
@@ -167,7 +180,8 @@ class GELUNonlinearityConfig(BaseNonlinearityConfig):
 
 
  @final
- class SwishNonlinearityConfig(BaseNonlinearityConfig):
+ @nonlinearity_registry.register
+ class SwishNonlinearityConfig(NonlinearityConfigBase):
      name: Literal["swish"] = "swish"
 
      @override
@@ -179,7 +193,8 @@ class SwishNonlinearityConfig(BaseNonlinearityConfig):
 
 
  @final
- class SiLUNonlinearityConfig(BaseNonlinearityConfig):
+ @nonlinearity_registry.register
+ class SiLUNonlinearityConfig(NonlinearityConfigBase):
      name: Literal["silu"] = "silu"
 
      @override
@@ -191,7 +206,8 @@ class SiLUNonlinearityConfig(BaseNonlinearityConfig):
 
 
  @final
- class MishNonlinearityConfig(BaseNonlinearityConfig):
+ @nonlinearity_registry.register
+ class MishNonlinearityConfig(NonlinearityConfigBase):
      name: Literal["mish"] = "mish"
 
      @override
@@ -210,7 +226,8 @@ class SwiGLU(nn.SiLU):
 
 
  @final
- class SwiGLUNonlinearityConfig(BaseNonlinearityConfig):
+ @nonlinearity_registry.register
+ class SwiGLUNonlinearityConfig(NonlinearityConfigBase):
      name: Literal["swiglu"] = "swiglu"
 
      @override
@@ -222,20 +239,7 @@ class SwiGLUNonlinearityConfig(BaseNonlinearityConfig):
          return input * F.silu(gate)
 
 
- NonlinearityConfig = Annotated[
-     ReLUNonlinearityConfig
-     | SigmoidNonlinearityConfig
-     | TanhNonlinearityConfig
-     | SoftmaxNonlinearityConfig
-     | SoftplusNonlinearityConfig
-     | SoftsignNonlinearityConfig
-     | ELUNonlinearityConfig
-     | LeakyReLUNonlinearityConfig
-     | PReLUConfig
-     | GELUNonlinearityConfig
-     | SwishNonlinearityConfig
-     | SiLUNonlinearityConfig
-     | MishNonlinearityConfig
-     | SwiGLUNonlinearityConfig,
-     C.Field(discriminator="name"),
- ]
+ NonlinearityConfig = TypeAliasType(
+     "NonlinearityConfig",
+     Annotated[NonlinearityConfigBase, nonlinearity_registry.DynamicResolution()],
+ )
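With the closed Annotated union replaced by registry-backed resolution, adding a nonlinearity no longer requires editing the NonlinearityConfig alias itself. A sketch of what a user-defined entry might look like, following the pattern above; the HardswishNonlinearityConfig class is hypothetical, not shipped with nshtrainer:

from typing import Literal

import torch.nn as nn
from typing_extensions import final, override

from nshtrainer.nn.nonlinearity import NonlinearityConfigBase, nonlinearity_registry


@final
@nonlinearity_registry.register
class HardswishNonlinearityConfig(NonlinearityConfigBase):
    # Hypothetical example; once registered, name="hardswish" should resolve
    # wherever a NonlinearityConfig is accepted.
    name: Literal["hardswish"] = "hardswish"

    @override
    def create_module(self) -> nn.Module:
        return nn.Hardswish()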
nshtrainer/optimizer.py CHANGED
@@ -7,7 +7,7 @@ from typing import Annotated, Any, Literal
  import nshconfig as C
  import torch.nn as nn
  from torch.optim import Optimizer
- from typing_extensions import TypeAliasType, override
+ from typing_extensions import TypeAliasType, final, override
 
 
  class OptimizerConfigBase(C.Config, ABC):
@@ -18,6 +18,11 @@ class OptimizerConfigBase(C.Config, ABC):
      ) -> Optimizer: ...
 
 
+ optimizer_registry = C.Registry(OptimizerConfigBase, discriminator="name")
+
+
+ @final
+ @optimizer_registry.register
  class AdamWConfig(OptimizerConfigBase):
      name: Literal["adamw"] = "adamw"
 
@@ -58,5 +63,6 @@ class AdamWConfig(OptimizerConfigBase):
 
 
  OptimizerConfig = TypeAliasType(
-     "OptimizerConfig", Annotated[AdamWConfig, C.Field(discriminator="name")]
+     "OptimizerConfig",
+     Annotated[OptimizerConfigBase, optimizer_registry.DynamicResolution()],
  )
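The same registry pattern now applies to optimizers: OptimizerConfig resolves against optimizer_registry instead of being hard-wired to AdamWConfig. A hypothetical user-defined optimizer config following the decorators above (SGDConfig and its fields are illustrative only, and the abstract factory method of OptimizerConfigBase is not named in this hunk, so it is left as a comment):

from typing import Literal

from typing_extensions import final

from nshtrainer.optimizer import OptimizerConfigBase, optimizer_registry


@final
@optimizer_registry.register
class SGDConfig(OptimizerConfigBase):
    # Hypothetical example config; not part of nshtrainer.
    name: Literal["sgd"] = "sgd"

    lr: float = 1.0e-2
    momentum: float = 0.9

    # The abstract factory method of OptimizerConfigBase (its exact name is not
    # visible in this hunk) would be implemented here to return
    # torch.optim.SGD(params, lr=self.lr, momentum=self.momentum).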
nshtrainer/trainer/__init__.py CHANGED
@@ -1,6 +1,7 @@
  from __future__ import annotations
 
+ from ..callbacks import callback_registry as callback_registry
  from ._config import TrainerConfig as TrainerConfig
- from ._config import accelerator_registry as accelerator_registry
- from ._config import plugin_registry as plugin_registry
+ from .accelerator import accelerator_registry as accelerator_registry
+ from .plugin import plugin_registry as plugin_registry
  from .trainer import Trainer as Trainer
nshtrainer/trainer/_config.py CHANGED
@@ -37,7 +37,6 @@ from ..callbacks (
      OnExceptionCheckpointCallbackConfig,
  )
  from ..callbacks.base import CallbackConfigBase
- from ..callbacks.checkpoint.time_checkpoint import TimeCheckpointCallbackConfig
  from ..callbacks.debug_flag import DebugFlagCallbackConfig
  from ..callbacks.log_epoch import LogEpochCallbackConfig
  from ..callbacks.lr_monitor import LearningRateMonitorConfig
@@ -49,14 +48,14 @@ from ..loggers (
      TensorboardLoggerConfig,
      WandbLoggerConfig,
  )
- from ..loggers._base import BaseLoggerConfig
  from ..loggers.actsave import ActSaveLoggerConfig
+ from ..loggers.base import LoggerConfigBase
  from ..metrics._config import MetricConfig
  from ..profiler import ProfilerConfig
  from ..util._environment_info import EnvironmentConfig
- from .accelerator import AcceleratorConfig, AcceleratorLiteral, accelerator_registry
- from .plugin import PluginConfig, plugin_registry
- from .strategy import StrategyConfig
+ from .accelerator import AcceleratorConfig, AcceleratorLiteral
+ from .plugin import PluginConfig
+ from .strategy import StrategyConfig, StrategyLiteral
 
  log = logging.getLogger(__name__)
 
@@ -70,46 +69,12 @@ class GradientClippingConfig(C.Config):
      """Norm type to use for gradient clipping."""
 
 
- StrategyLiteral = TypeAliasType(
-     "StrategyLiteral",
-     Literal[
-         "auto",
-         "ddp",
-         "ddp_find_unused_parameters_false",
-         "ddp_find_unused_parameters_true",
-         "ddp_spawn",
-         "ddp_spawn_find_unused_parameters_false",
-         "ddp_spawn_find_unused_parameters_true",
-         "ddp_fork",
-         "ddp_fork_find_unused_parameters_false",
-         "ddp_fork_find_unused_parameters_true",
-         "ddp_notebook",
-         "dp",
-         "deepspeed",
-         "deepspeed_stage_1",
-         "deepspeed_stage_1_offload",
-         "deepspeed_stage_2",
-         "deepspeed_stage_2_offload",
-         "deepspeed_stage_3",
-         "deepspeed_stage_3_offload",
-         "deepspeed_stage_3_offload_nvme",
-         "fsdp",
-         "fsdp_cpu_offload",
-         "single_xla",
-         "xla_fsdp",
-         "xla",
-         "single_tpu",
-     ],
- )
-
-
 
  CheckpointCallbackConfig = TypeAliasType(
      "CheckpointCallbackConfig",
      Annotated[
          BestCheckpointCallbackConfig
          | LastCheckpointCallbackConfig
-         | OnExceptionCheckpointCallbackConfig
-         | TimeCheckpointCallbackConfig,
+         | OnExceptionCheckpointCallbackConfig,
          C.Field(discriminator="name"),
      ],
  )
@@ -123,7 +88,6 @@ class CheckpointSavingConfig(CallbackConfigBase):
          BestCheckpointCallbackConfig(throw_on_no_metric=False),
          LastCheckpointCallbackConfig(),
          OnExceptionCheckpointCallbackConfig(),
-         TimeCheckpointCallbackConfig(interval=timedelta(hours=12)),
      ]
      """Checkpoint callback configurations."""
 
@@ -397,8 +361,6 @@ class SanityCheckingConfig(C.Config):
      """
 
 
- @plugin_registry.rebuild_on_registers
- @accelerator_registry.rebuild_on_registers
  class TrainerConfig(C.Config):
      # region Active Run Configuration
      id: str = C.Field(default_factory=lambda: TrainerConfig.generate_id())
@@ -808,7 +770,7 @@ class TrainerConfig(C.Config):
          yield self.auto_set_debug_flag
          yield from self.callbacks
 
-     def _nshtrainer_all_logger_configs(self) -> Iterable[BaseLoggerConfig | None]:
+     def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
          # Disable all loggers if barebones mode is enabled
          if self.barebones:
              return
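Taken together, the trainer changes move the registries into their own modules, drop TimeCheckpointCallbackConfig, and follow the _base-module renames. Code importing the old paths from 1.0.0b33 needs to move with them; a small migration sketch, assuming the old modules are removed rather than re-exported as aliases:

# 1.0.0b33
# from nshtrainer.loggers._base import BaseLoggerConfig
# from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase
# from nshtrainer.trainer._config import accelerator_registry, plugin_registry

# 1.0.0b37
from nshtrainer.loggers.base import LoggerConfigBase
from nshtrainer.lr_scheduler.base import LRSchedulerConfigBase
from nshtrainer.trainer import accelerator_registry, callback_registry, plugin_registry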
{nshtrainer-1.0.0b33.dist-info → nshtrainer-1.0.0b37.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 1.0.0b33
+ Version: 1.0.0b37
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com