nshtrainer 1.0.0b36__py3-none-any.whl → 1.0.0b37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. nshtrainer/_directory.py +3 -1
  2. nshtrainer/configs/__init__.py +12 -4
  3. nshtrainer/configs/loggers/__init__.py +6 -4
  4. nshtrainer/configs/loggers/actsave/__init__.py +4 -2
  5. nshtrainer/configs/loggers/base/__init__.py +11 -0
  6. nshtrainer/configs/loggers/csv/__init__.py +4 -2
  7. nshtrainer/configs/loggers/tensorboard/__init__.py +4 -2
  8. nshtrainer/configs/loggers/wandb/__init__.py +4 -2
  9. nshtrainer/configs/lr_scheduler/__init__.py +4 -2
  10. nshtrainer/configs/lr_scheduler/base/__init__.py +11 -0
  11. nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +4 -0
  12. nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +4 -0
  13. nshtrainer/configs/nn/__init__.py +4 -2
  14. nshtrainer/configs/nn/mlp/__init__.py +2 -2
  15. nshtrainer/configs/nn/nonlinearity/__init__.py +4 -2
  16. nshtrainer/configs/optimizer/__init__.py +2 -0
  17. nshtrainer/configs/trainer/__init__.py +2 -2
  18. nshtrainer/configs/trainer/_config/__init__.py +2 -2
  19. nshtrainer/loggers/__init__.py +3 -8
  20. nshtrainer/loggers/actsave.py +5 -2
  21. nshtrainer/loggers/{_base.py → base.py} +4 -1
  22. nshtrainer/loggers/csv.py +5 -3
  23. nshtrainer/loggers/tensorboard.py +5 -3
  24. nshtrainer/loggers/wandb.py +5 -3
  25. nshtrainer/lr_scheduler/__init__.py +2 -2
  26. nshtrainer/lr_scheduler/{_base.py → base.py} +3 -0
  27. nshtrainer/lr_scheduler/linear_warmup_cosine.py +56 -54
  28. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +4 -2
  29. nshtrainer/nn/__init__.py +1 -1
  30. nshtrainer/nn/mlp.py +4 -4
  31. nshtrainer/nn/nonlinearity.py +37 -33
  32. nshtrainer/optimizer.py +8 -2
  33. nshtrainer/trainer/_config.py +2 -2
  34. {nshtrainer-1.0.0b36.dist-info → nshtrainer-1.0.0b37.dist-info}/METADATA +1 -1
  35. {nshtrainer-1.0.0b36.dist-info → nshtrainer-1.0.0b37.dist-info}/RECORD +36 -36
  36. nshtrainer/configs/loggers/_base/__init__.py +0 -9
  37. nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -9
  38. {nshtrainer-1.0.0b36.dist-info → nshtrainer-1.0.0b37.dist-info}/WHEEL +0 -0
nshtrainer/_directory.py CHANGED
@@ -81,7 +81,9 @@ class DirectoryConfig(C.Config):
81
81
 
82
82
  # Save to nshtrainer/{id}/log/{logger name}
83
83
  log_dir = self.resolve_subdirectory(run_id, "log")
84
- log_dir = log_dir / logger.name
84
+ log_dir = log_dir / getattr(logger, "name")
85
+ # ^ NOTE: Logger must have a `name` attribute, as this is
86
+ # the discriminator for the logger registry
85
87
  log_dir.mkdir(exist_ok=True)
86
88
 
87
89
  return log_dir
@@ -60,24 +60,26 @@ from nshtrainer.callbacks.checkpoint._base import (
60
60
  BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
61
61
  )
62
62
  from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
63
- from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
64
63
  from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
65
64
  from nshtrainer.loggers import LoggerConfig as LoggerConfig
65
+ from nshtrainer.loggers import LoggerConfigBase as LoggerConfigBase
66
66
  from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
67
67
  from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
68
+ from nshtrainer.loggers import logger_registry as logger_registry
68
69
  from nshtrainer.lr_scheduler import (
69
70
  LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
70
71
  )
71
72
  from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
72
73
  from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
73
74
  from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
74
- from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
75
+ from nshtrainer.lr_scheduler.base import lr_scheduler_registry as lr_scheduler_registry
75
76
  from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
76
77
  from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
77
78
  from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
78
79
  from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
79
80
  from nshtrainer.nn import MLPConfig as MLPConfig
80
81
  from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
82
+ from nshtrainer.nn import NonlinearityConfigBase as NonlinearityConfigBase
81
83
  from nshtrainer.nn import PReLUConfig as PReLUConfig
82
84
  from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
83
85
  from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
@@ -90,9 +92,11 @@ from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
90
92
  from nshtrainer.nn.nonlinearity import (
91
93
  SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
92
94
  )
95
+ from nshtrainer.nn.nonlinearity import nonlinearity_registry as nonlinearity_registry
93
96
  from nshtrainer.optimizer import AdamWConfig as AdamWConfig
94
97
  from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
95
98
  from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
99
+ from nshtrainer.optimizer import optimizer_registry as optimizer_registry
96
100
  from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
97
101
  from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
98
102
  from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
@@ -225,8 +229,6 @@ __all__ = [
225
229
  "AdvancedProfilerConfig",
226
230
  "AsyncCheckpointIOPlugin",
227
231
  "BaseCheckpointCallbackConfig",
228
- "BaseLoggerConfig",
229
- "BaseNonlinearityConfig",
230
232
  "BaseProfilerConfig",
231
233
  "BestCheckpointCallbackConfig",
232
234
  "BitsandbytesPluginConfig",
@@ -280,6 +282,7 @@ __all__ = [
280
282
  "LinearWarmupCosineDecayLRSchedulerConfig",
281
283
  "LogEpochCallbackConfig",
282
284
  "LoggerConfig",
285
+ "LoggerConfigBase",
283
286
  "MLPConfig",
284
287
  "MPIEnvironmentPlugin",
285
288
  "MPSAcceleratorConfig",
@@ -287,6 +290,7 @@ __all__ = [
287
290
  "MishNonlinearityConfig",
288
291
  "MixedPrecisionPluginConfig",
289
292
  "NonlinearityConfig",
293
+ "NonlinearityConfigBase",
290
294
  "NormLoggingCallbackConfig",
291
295
  "OnExceptionCheckpointCallbackConfig",
292
296
  "OptimizerConfig",
@@ -334,11 +338,15 @@ __all__ = [
334
338
  "accelerator_registry",
335
339
  "callback_registry",
336
340
  "callbacks",
341
+ "logger_registry",
337
342
  "loggers",
338
343
  "lr_scheduler",
344
+ "lr_scheduler_registry",
339
345
  "metrics",
340
346
  "nn",
347
+ "nonlinearity_registry",
341
348
  "optimizer",
349
+ "optimizer_registry",
342
350
  "plugin_registry",
343
351
  "profiler",
344
352
  "trainer",
@@ -3,11 +3,12 @@ from __future__ import annotations
3
3
  __codegen__ = True
4
4
 
5
5
  from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
6
- from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
7
6
  from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
8
7
  from nshtrainer.loggers import LoggerConfig as LoggerConfig
8
+ from nshtrainer.loggers import LoggerConfigBase as LoggerConfigBase
9
9
  from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
10
10
  from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
11
+ from nshtrainer.loggers import logger_registry as logger_registry
11
12
  from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
12
13
  from nshtrainer.loggers.wandb import (
13
14
  WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
@@ -16,25 +17,26 @@ from nshtrainer.loggers.wandb import (
16
17
  WandbWatchCallbackConfig as WandbWatchCallbackConfig,
17
18
  )
18
19
 
19
- from . import _base as _base
20
20
  from . import actsave as actsave
21
+ from . import base as base
21
22
  from . import csv as csv
22
23
  from . import tensorboard as tensorboard
23
24
  from . import wandb as wandb
24
25
 
25
26
  __all__ = [
26
27
  "ActSaveLoggerConfig",
27
- "BaseLoggerConfig",
28
28
  "CSVLoggerConfig",
29
29
  "CallbackConfigBase",
30
30
  "LoggerConfig",
31
+ "LoggerConfigBase",
31
32
  "TensorboardLoggerConfig",
32
33
  "WandbLoggerConfig",
33
34
  "WandbUploadCodeCallbackConfig",
34
35
  "WandbWatchCallbackConfig",
35
- "_base",
36
36
  "actsave",
37
+ "base",
37
38
  "csv",
39
+ "logger_registry",
38
40
  "tensorboard",
39
41
  "wandb",
40
42
  ]
@@ -3,9 +3,11 @@ from __future__ import annotations
3
3
  __codegen__ = True
4
4
 
5
5
  from nshtrainer.loggers.actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
6
- from nshtrainer.loggers.actsave import BaseLoggerConfig as BaseLoggerConfig
6
+ from nshtrainer.loggers.actsave import LoggerConfigBase as LoggerConfigBase
7
+ from nshtrainer.loggers.actsave import logger_registry as logger_registry
7
8
 
8
9
  __all__ = [
9
10
  "ActSaveLoggerConfig",
10
- "BaseLoggerConfig",
11
+ "LoggerConfigBase",
12
+ "logger_registry",
11
13
  ]
@@ -0,0 +1,11 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from nshtrainer.loggers.base import LoggerConfigBase as LoggerConfigBase
6
+ from nshtrainer.loggers.base import logger_registry as logger_registry
7
+
8
+ __all__ = [
9
+ "LoggerConfigBase",
10
+ "logger_registry",
11
+ ]
@@ -2,10 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from nshtrainer.loggers.csv import BaseLoggerConfig as BaseLoggerConfig
6
5
  from nshtrainer.loggers.csv import CSVLoggerConfig as CSVLoggerConfig
6
+ from nshtrainer.loggers.csv import LoggerConfigBase as LoggerConfigBase
7
+ from nshtrainer.loggers.csv import logger_registry as logger_registry
7
8
 
8
9
  __all__ = [
9
- "BaseLoggerConfig",
10
10
  "CSVLoggerConfig",
11
+ "LoggerConfigBase",
12
+ "logger_registry",
11
13
  ]
@@ -2,12 +2,14 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from nshtrainer.loggers.tensorboard import BaseLoggerConfig as BaseLoggerConfig
5
+ from nshtrainer.loggers.tensorboard import LoggerConfigBase as LoggerConfigBase
6
6
  from nshtrainer.loggers.tensorboard import (
7
7
  TensorboardLoggerConfig as TensorboardLoggerConfig,
8
8
  )
9
+ from nshtrainer.loggers.tensorboard import logger_registry as logger_registry
9
10
 
10
11
  __all__ = [
11
- "BaseLoggerConfig",
12
+ "LoggerConfigBase",
12
13
  "TensorboardLoggerConfig",
14
+ "logger_registry",
13
15
  ]
@@ -2,8 +2,8 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from nshtrainer.loggers.wandb import BaseLoggerConfig as BaseLoggerConfig
6
5
  from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
6
+ from nshtrainer.loggers.wandb import LoggerConfigBase as LoggerConfigBase
7
7
  from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
8
8
  from nshtrainer.loggers.wandb import (
9
9
  WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
@@ -11,11 +11,13 @@ from nshtrainer.loggers.wandb import (
11
11
  from nshtrainer.loggers.wandb import (
12
12
  WandbWatchCallbackConfig as WandbWatchCallbackConfig,
13
13
  )
14
+ from nshtrainer.loggers.wandb import logger_registry as logger_registry
14
15
 
15
16
  __all__ = [
16
- "BaseLoggerConfig",
17
17
  "CallbackConfigBase",
18
+ "LoggerConfigBase",
18
19
  "WandbLoggerConfig",
19
20
  "WandbUploadCodeCallbackConfig",
20
21
  "WandbWatchCallbackConfig",
22
+ "logger_registry",
21
23
  ]
@@ -8,12 +8,13 @@ from nshtrainer.lr_scheduler import (
8
8
  from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
9
9
  from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
10
10
  from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
11
+ from nshtrainer.lr_scheduler.base import lr_scheduler_registry as lr_scheduler_registry
11
12
  from nshtrainer.lr_scheduler.linear_warmup_cosine import (
12
13
  DurationConfig as DurationConfig,
13
14
  )
14
15
  from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricConfig
15
16
 
16
- from . import _base as _base
17
+ from . import base as base
17
18
  from . import linear_warmup_cosine as linear_warmup_cosine
18
19
  from . import reduce_lr_on_plateau as reduce_lr_on_plateau
19
20
 
@@ -24,7 +25,8 @@ __all__ = [
24
25
  "LinearWarmupCosineDecayLRSchedulerConfig",
25
26
  "MetricConfig",
26
27
  "ReduceLROnPlateauConfig",
27
- "_base",
28
+ "base",
28
29
  "linear_warmup_cosine",
30
+ "lr_scheduler_registry",
29
31
  "reduce_lr_on_plateau",
30
32
  ]
@@ -0,0 +1,11 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from nshtrainer.lr_scheduler.base import LRSchedulerConfigBase as LRSchedulerConfigBase
6
+ from nshtrainer.lr_scheduler.base import lr_scheduler_registry as lr_scheduler_registry
7
+
8
+ __all__ = [
9
+ "LRSchedulerConfigBase",
10
+ "lr_scheduler_registry",
11
+ ]
@@ -11,9 +11,13 @@ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
11
11
  from nshtrainer.lr_scheduler.linear_warmup_cosine import (
12
12
  LRSchedulerConfigBase as LRSchedulerConfigBase,
13
13
  )
14
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
15
+ lr_scheduler_registry as lr_scheduler_registry,
16
+ )
14
17
 
15
18
  __all__ = [
16
19
  "DurationConfig",
17
20
  "LRSchedulerConfigBase",
18
21
  "LinearWarmupCosineDecayLRSchedulerConfig",
22
+ "lr_scheduler_registry",
19
23
  ]
@@ -9,9 +9,13 @@ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricC
9
9
  from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
10
10
  ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
11
11
  )
12
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
13
+ lr_scheduler_registry as lr_scheduler_registry,
14
+ )
12
15
 
13
16
  __all__ = [
14
17
  "LRSchedulerConfigBase",
15
18
  "MetricConfig",
16
19
  "ReduceLROnPlateauConfig",
20
+ "lr_scheduler_registry",
17
21
  ]
@@ -2,13 +2,13 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
6
5
  from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
7
6
  from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
8
7
  from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
9
8
  from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
10
9
  from nshtrainer.nn import MLPConfig as MLPConfig
11
10
  from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
11
+ from nshtrainer.nn import NonlinearityConfigBase as NonlinearityConfigBase
12
12
  from nshtrainer.nn import PReLUConfig as PReLUConfig
13
13
  from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
14
14
  from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
@@ -21,18 +21,19 @@ from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
21
21
  from nshtrainer.nn.nonlinearity import (
22
22
  SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
23
23
  )
24
+ from nshtrainer.nn.nonlinearity import nonlinearity_registry as nonlinearity_registry
24
25
 
25
26
  from . import mlp as mlp
26
27
  from . import nonlinearity as nonlinearity
27
28
 
28
29
  __all__ = [
29
- "BaseNonlinearityConfig",
30
30
  "ELUNonlinearityConfig",
31
31
  "GELUNonlinearityConfig",
32
32
  "LeakyReLUNonlinearityConfig",
33
33
  "MLPConfig",
34
34
  "MishNonlinearityConfig",
35
35
  "NonlinearityConfig",
36
+ "NonlinearityConfigBase",
36
37
  "PReLUConfig",
37
38
  "ReLUNonlinearityConfig",
38
39
  "SiLUNonlinearityConfig",
@@ -45,4 +46,5 @@ __all__ = [
45
46
  "TanhNonlinearityConfig",
46
47
  "mlp",
47
48
  "nonlinearity",
49
+ "nonlinearity_registry",
48
50
  ]
@@ -2,12 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from nshtrainer.nn.mlp import BaseNonlinearityConfig as BaseNonlinearityConfig
6
5
  from nshtrainer.nn.mlp import MLPConfig as MLPConfig
7
6
  from nshtrainer.nn.mlp import NonlinearityConfig as NonlinearityConfig
7
+ from nshtrainer.nn.mlp import NonlinearityConfigBase as NonlinearityConfigBase
8
8
 
9
9
  __all__ = [
10
- "BaseNonlinearityConfig",
11
10
  "MLPConfig",
12
11
  "NonlinearityConfig",
12
+ "NonlinearityConfigBase",
13
13
  ]
@@ -2,7 +2,6 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from nshtrainer.nn.nonlinearity import BaseNonlinearityConfig as BaseNonlinearityConfig
6
5
  from nshtrainer.nn.nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
7
6
  from nshtrainer.nn.nonlinearity import GELUNonlinearityConfig as GELUNonlinearityConfig
8
7
  from nshtrainer.nn.nonlinearity import (
@@ -10,6 +9,7 @@ from nshtrainer.nn.nonlinearity import (
10
9
  )
11
10
  from nshtrainer.nn.nonlinearity import MishNonlinearityConfig as MishNonlinearityConfig
12
11
  from nshtrainer.nn.nonlinearity import NonlinearityConfig as NonlinearityConfig
12
+ from nshtrainer.nn.nonlinearity import NonlinearityConfigBase as NonlinearityConfigBase
13
13
  from nshtrainer.nn.nonlinearity import PReLUConfig as PReLUConfig
14
14
  from nshtrainer.nn.nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
15
15
  from nshtrainer.nn.nonlinearity import (
@@ -32,14 +32,15 @@ from nshtrainer.nn.nonlinearity import (
32
32
  SwishNonlinearityConfig as SwishNonlinearityConfig,
33
33
  )
34
34
  from nshtrainer.nn.nonlinearity import TanhNonlinearityConfig as TanhNonlinearityConfig
35
+ from nshtrainer.nn.nonlinearity import nonlinearity_registry as nonlinearity_registry
35
36
 
36
37
  __all__ = [
37
- "BaseNonlinearityConfig",
38
38
  "ELUNonlinearityConfig",
39
39
  "GELUNonlinearityConfig",
40
40
  "LeakyReLUNonlinearityConfig",
41
41
  "MishNonlinearityConfig",
42
42
  "NonlinearityConfig",
43
+ "NonlinearityConfigBase",
43
44
  "PReLUConfig",
44
45
  "ReLUNonlinearityConfig",
45
46
  "SiLUNonlinearityConfig",
@@ -50,4 +51,5 @@ __all__ = [
50
51
  "SwiGLUNonlinearityConfig",
51
52
  "SwishNonlinearityConfig",
52
53
  "TanhNonlinearityConfig",
54
+ "nonlinearity_registry",
53
55
  ]
@@ -5,9 +5,11 @@ __codegen__ = True
5
5
  from nshtrainer.optimizer import AdamWConfig as AdamWConfig
6
6
  from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
7
7
  from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
8
+ from nshtrainer.optimizer import optimizer_registry as optimizer_registry
8
9
 
9
10
  __all__ = [
10
11
  "AdamWConfig",
11
12
  "OptimizerConfig",
12
13
  "OptimizerConfigBase",
14
+ "optimizer_registry",
13
15
  ]
@@ -8,7 +8,6 @@ from nshtrainer.trainer import callback_registry as callback_registry
8
8
  from nshtrainer.trainer import plugin_registry as plugin_registry
9
9
  from nshtrainer.trainer._config import AcceleratorConfig as AcceleratorConfig
10
10
  from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
11
- from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
12
11
  from nshtrainer.trainer._config import (
13
12
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
14
13
  )
@@ -37,6 +36,7 @@ from nshtrainer.trainer._config import (
37
36
  )
38
37
  from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbackConfig
39
38
  from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
39
+ from nshtrainer.trainer._config import LoggerConfigBase as LoggerConfigBase
40
40
  from nshtrainer.trainer._config import MetricConfig as MetricConfig
41
41
  from nshtrainer.trainer._config import (
42
42
  NormLoggingCallbackConfig as NormLoggingCallbackConfig,
@@ -133,7 +133,6 @@ __all__ = [
133
133
  "AcceleratorConfigBase",
134
134
  "ActSaveLoggerConfig",
135
135
  "AsyncCheckpointIOPlugin",
136
- "BaseLoggerConfig",
137
136
  "BestCheckpointCallbackConfig",
138
137
  "BitsandbytesPluginConfig",
139
138
  "CPUAcceleratorConfig",
@@ -161,6 +160,7 @@ __all__ = [
161
160
  "LightningEnvironmentPlugin",
162
161
  "LogEpochCallbackConfig",
163
162
  "LoggerConfig",
163
+ "LoggerConfigBase",
164
164
  "MPIEnvironmentPlugin",
165
165
  "MPSAcceleratorConfig",
166
166
  "MetricConfig",
@@ -4,7 +4,6 @@ __codegen__ = True
4
4
 
5
5
  from nshtrainer.trainer._config import AcceleratorConfig as AcceleratorConfig
6
6
  from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
7
- from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
8
7
  from nshtrainer.trainer._config import (
9
8
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
10
9
  )
@@ -33,6 +32,7 @@ from nshtrainer.trainer._config import (
33
32
  )
34
33
  from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbackConfig
35
34
  from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
35
+ from nshtrainer.trainer._config import LoggerConfigBase as LoggerConfigBase
36
36
  from nshtrainer.trainer._config import MetricConfig as MetricConfig
37
37
  from nshtrainer.trainer._config import (
38
38
  NormLoggingCallbackConfig as NormLoggingCallbackConfig,
@@ -59,7 +59,6 @@ from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
59
59
  __all__ = [
60
60
  "AcceleratorConfig",
61
61
  "ActSaveLoggerConfig",
62
- "BaseLoggerConfig",
63
62
  "BestCheckpointCallbackConfig",
64
63
  "CSVLoggerConfig",
65
64
  "CallbackConfig",
@@ -76,6 +75,7 @@ __all__ = [
76
75
  "LearningRateMonitorConfig",
77
76
  "LogEpochCallbackConfig",
78
77
  "LoggerConfig",
78
+ "LoggerConfigBase",
79
79
  "MetricConfig",
80
80
  "NormLoggingCallbackConfig",
81
81
  "OnExceptionCheckpointCallbackConfig",
@@ -5,19 +5,14 @@ from typing import Annotated
5
5
  import nshconfig as C
6
6
  from typing_extensions import TypeAliasType
7
7
 
8
- from ._base import BaseLoggerConfig as BaseLoggerConfig
9
8
  from .actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
9
+ from .base import LoggerConfigBase as LoggerConfigBase
10
+ from .base import logger_registry as logger_registry
10
11
  from .csv import CSVLoggerConfig as CSVLoggerConfig
11
12
  from .tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
12
13
  from .wandb import WandbLoggerConfig as WandbLoggerConfig
13
14
 
14
15
  LoggerConfig = TypeAliasType(
15
16
  "LoggerConfig",
16
- Annotated[
17
- CSVLoggerConfig
18
- | TensorboardLoggerConfig
19
- | WandbLoggerConfig
20
- | ActSaveLoggerConfig,
21
- C.Field(discriminator="name"),
22
- ],
17
+ Annotated[LoggerConfigBase, logger_registry.DynamicResolution()],
23
18
  )
@@ -5,11 +5,14 @@ from typing import Any, Literal
5
5
 
6
6
  import numpy as np
7
7
  from lightning.pytorch.loggers import Logger
8
+ from typing_extensions import final
8
9
 
9
- from ._base import BaseLoggerConfig
10
+ from .base import LoggerConfigBase, logger_registry
10
11
 
11
12
 
12
- class ActSaveLoggerConfig(BaseLoggerConfig):
13
+ @final
14
+ @logger_registry.register
15
+ class ActSaveLoggerConfig(LoggerConfigBase):
13
16
  name: Literal["actsave"] = "actsave"
14
17
 
15
18
  def create_logger(self, trainer_config):
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
10
10
  from ..trainer._config import TrainerConfig
11
11
 
12
12
 
13
- class BaseLoggerConfig(C.Config, ABC):
13
+ class LoggerConfigBase(C.Config, ABC):
14
14
  enabled: bool = True
15
15
  """Enable this logger."""
16
16
 
@@ -29,3 +29,6 @@ class BaseLoggerConfig(C.Config, ABC):
29
29
 
30
30
  def __bool__(self):
31
31
  return self.enabled
32
+
33
+
34
+ logger_registry = C.Registry(LoggerConfigBase, discriminator="name")
nshtrainer/loggers/csv.py CHANGED
@@ -2,12 +2,14 @@ from __future__ import annotations
2
2
 
3
3
  from typing import Literal
4
4
 
5
- from typing_extensions import override
5
+ from typing_extensions import final, override
6
6
 
7
- from ._base import BaseLoggerConfig
7
+ from .base import LoggerConfigBase, logger_registry
8
8
 
9
9
 
10
- class CSVLoggerConfig(BaseLoggerConfig):
10
+ @final
11
+ @logger_registry.register
12
+ class CSVLoggerConfig(LoggerConfigBase):
11
13
  name: Literal["csv"] = "csv"
12
14
 
13
15
  enabled: bool = True
@@ -4,9 +4,9 @@ import logging
4
4
  from typing import Literal
5
5
 
6
6
  import nshconfig as C
7
- from typing_extensions import override
7
+ from typing_extensions import final, override
8
8
 
9
- from ._base import BaseLoggerConfig
9
+ from .base import LoggerConfigBase, logger_registry
10
10
 
11
11
  log = logging.getLogger(__name__)
12
12
 
@@ -30,7 +30,9 @@ def _tensorboard_available():
30
30
  return False
31
31
 
32
32
 
33
- class TensorboardLoggerConfig(BaseLoggerConfig):
33
+ @final
34
+ @logger_registry.register
35
+ class TensorboardLoggerConfig(LoggerConfigBase):
34
36
  name: Literal["tensorboard"] = "tensorboard"
35
37
 
36
38
  enabled: bool = C.Field(default_factory=lambda: _tensorboard_available())
@@ -7,12 +7,12 @@ from typing import TYPE_CHECKING, Literal
7
7
  import nshconfig as C
8
8
  from lightning.pytorch import Callback, LightningModule, Trainer
9
9
  from packaging import version
10
- from typing_extensions import assert_never, override
10
+ from typing_extensions import assert_never, final, override
11
11
 
12
12
  from ..callbacks.base import CallbackConfigBase
13
13
  from ..callbacks.wandb_upload_code import WandbUploadCodeCallbackConfig
14
14
  from ..callbacks.wandb_watch import WandbWatchCallbackConfig
15
- from ._base import BaseLoggerConfig
15
+ from .base import LoggerConfigBase, logger_registry
16
16
 
17
17
  if TYPE_CHECKING:
18
18
  from ..trainer._config import TrainerConfig
@@ -73,7 +73,9 @@ class FinishWandbOnTeardownCallback(Callback):
73
73
  wandb.finish()
74
74
 
75
75
 
76
- class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
76
+ @final
77
+ @logger_registry.register
78
+ class WandbLoggerConfig(CallbackConfigBase, LoggerConfigBase):
77
79
  name: Literal["wandb"] = "wandb"
78
80
 
79
81
  enabled: bool = C.Field(default_factory=lambda: _wandb_available())
@@ -5,8 +5,8 @@ from typing import Annotated
5
5
  import nshconfig as C
6
6
  from typing_extensions import TypeAliasType
7
7
 
8
- from ._base import LRSchedulerConfigBase as LRSchedulerConfigBase
9
- from ._base import LRSchedulerMetadata as LRSchedulerMetadata
8
+ from .base import LRSchedulerConfigBase as LRSchedulerConfigBase
9
+ from .base import LRSchedulerMetadata as LRSchedulerMetadata
10
10
  from .linear_warmup_cosine import (
11
11
  LinearWarmupCosineAnnealingLR as LinearWarmupCosineAnnealingLR,
12
12
  )
@@ -94,3 +94,6 @@ class LRSchedulerConfigBase(C.Config, ABC):
94
94
  # ^ This is a hack to trigger the computation of the estimated stepping batches
95
95
  # and make sure that the `trainer.num_training_batches` attribute is set.
96
96
  return math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches)
97
+
98
+
99
+ lr_scheduler_registry = C.Registry(LRSchedulerConfigBase, discriminator="name")
@@ -6,10 +6,64 @@ from typing import Literal
6
6
 
7
7
  from torch.optim import Optimizer
8
8
  from torch.optim.lr_scheduler import LRScheduler
9
- from typing_extensions import override
9
+ from typing_extensions import final, override
10
10
 
11
11
  from ..util.config import DurationConfig
12
- from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
12
+ from .base import LRSchedulerConfigBase, LRSchedulerMetadata, lr_scheduler_registry
13
+
14
+
15
+ @final
16
+ @lr_scheduler_registry.register
17
+ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
18
+ name: Literal["linear_warmup_cosine_decay"] = "linear_warmup_cosine_decay"
19
+
20
+ warmup_duration: DurationConfig
21
+ r"""The duration for the linear warmup phase.
22
+ The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this duration."""
23
+
24
+ max_duration: DurationConfig
25
+ r"""The total duration.
26
+ The learning rate is decayed to `min_lr` over this duration."""
27
+
28
+ warmup_start_lr_factor: float = 0.0
29
+ r"""The initial learning rate for the linear warmup phase, as a factor of the initial learning rate.
30
+ The learning rate is linearly increased from this value to the initial learning rate over `warmup_epochs` epochs."""
31
+
32
+ min_lr_factor: float = 0.0
33
+ r"""The minimum learning rate, as a factor of the initial learning rate.
34
+ The learning rate is decayed to this value over `max_epochs` epochs."""
35
+
36
+ annealing: bool = False
37
+ r"""Whether to restart the learning rate schedule after `max_epochs` epochs.
38
+ If `False`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be set to `min_lr` for all subsequent epochs.
39
+ If `True`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be increased back to the initial learning rate over `max_epochs` epochs, and so on (this is called a cosine annealing schedule)."""
40
+
41
+ @override
42
+ def metadata(self) -> LRSchedulerMetadata:
43
+ return {
44
+ "interval": "step",
45
+ }
46
+
47
+ @override
48
+ def create_scheduler_impl(self, optimizer, lightning_module):
49
+ num_steps_per_epoch = self.compute_num_steps_per_epoch(lightning_module)
50
+ warmup_steps = self.warmup_duration.to_steps(num_steps_per_epoch).value
51
+ max_steps = self.max_duration.to_steps(num_steps_per_epoch).value
52
+
53
+ # Warmup and max steps should be at least 1.
54
+ warmup_steps = max(warmup_steps, 1)
55
+ max_steps = max(max_steps, 1)
56
+
57
+ # Create the scheduler
58
+ scheduler = LinearWarmupCosineAnnealingLR(
59
+ optimizer=optimizer,
60
+ warmup_epochs=warmup_steps,
61
+ max_epochs=max_steps,
62
+ warmup_start_lr_factor=self.warmup_start_lr_factor,
63
+ eta_min_factor=self.min_lr_factor,
64
+ should_restart=self.annealing,
65
+ )
66
+ return scheduler
13
67
 
14
68
 
15
69
  class LinearWarmupCosineAnnealingLR(LRScheduler):
@@ -89,55 +143,3 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
89
143
  + self.eta_min_factor * base_lr
90
144
  for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
91
145
  ]
92
-
93
-
94
- class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
95
- name: Literal["linear_warmup_cosine_decay"] = "linear_warmup_cosine_decay"
96
-
97
- warmup_duration: DurationConfig
98
- r"""The duration for the linear warmup phase.
99
- The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this duration."""
100
-
101
- max_duration: DurationConfig
102
- r"""The total duration.
103
- The learning rate is decayed to `min_lr` over this duration."""
104
-
105
- warmup_start_lr_factor: float = 0.0
106
- r"""The initial learning rate for the linear warmup phase, as a factor of the initial learning rate.
107
- The learning rate is linearly increased from this value to the initial learning rate over `warmup_epochs` epochs."""
108
-
109
- min_lr_factor: float = 0.0
110
- r"""The minimum learning rate, as a factor of the initial learning rate.
111
- The learning rate is decayed to this value over `max_epochs` epochs."""
112
-
113
- annealing: bool = False
114
- r"""Whether to restart the learning rate schedule after `max_epochs` epochs.
115
- If `False`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be set to `min_lr` for all subsequent epochs.
116
- If `True`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be increased back to the initial learning rate over `max_epochs` epochs, and so on (this is called a cosine annealing schedule)."""
117
-
118
- @override
119
- def metadata(self) -> LRSchedulerMetadata:
120
- return {
121
- "interval": "step",
122
- }
123
-
124
- @override
125
- def create_scheduler_impl(self, optimizer, lightning_module):
126
- num_steps_per_epoch = self.compute_num_steps_per_epoch(lightning_module)
127
- warmup_steps = self.warmup_duration.to_steps(num_steps_per_epoch).value
128
- max_steps = self.max_duration.to_steps(num_steps_per_epoch).value
129
-
130
- # Warmup and max steps should be at least 1.
131
- warmup_steps = max(warmup_steps, 1)
132
- max_steps = max(max_steps, 1)
133
-
134
- # Create the scheduler
135
- scheduler = LinearWarmupCosineAnnealingLR(
136
- optimizer=optimizer,
137
- warmup_epochs=warmup_steps,
138
- max_epochs=max_steps,
139
- warmup_start_lr_factor=self.warmup_start_lr_factor,
140
- eta_min_factor=self.min_lr_factor,
141
- should_restart=self.annealing,
142
- )
143
- return scheduler
@@ -4,12 +4,14 @@ from typing import Literal
4
4
 
5
5
  from lightning.pytorch.utilities.types import LRSchedulerConfigType
6
6
  from torch.optim.lr_scheduler import ReduceLROnPlateau
7
- from typing_extensions import override
7
+ from typing_extensions import final, override
8
8
 
9
9
  from ..metrics._config import MetricConfig
10
- from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
10
+ from .base import LRSchedulerConfigBase, LRSchedulerMetadata, lr_scheduler_registry
11
11
 
12
12
 
13
+ @final
14
+ @lr_scheduler_registry.register
13
15
  class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
14
16
  """Reduce learning rate when a metric has stopped improving."""
15
17
 
nshtrainer/nn/__init__.py CHANGED
@@ -6,12 +6,12 @@ from .mlp import MLPConfigDict as MLPConfigDict
6
6
  from .mlp import ResidualSequential as ResidualSequential
7
7
  from .module_dict import TypedModuleDict as TypedModuleDict
8
8
  from .module_list import TypedModuleList as TypedModuleList
9
- from .nonlinearity import BaseNonlinearityConfig as BaseNonlinearityConfig
10
9
  from .nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
11
10
  from .nonlinearity import GELUNonlinearityConfig as GELUNonlinearityConfig
12
11
  from .nonlinearity import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
13
12
  from .nonlinearity import MishNonlinearityConfig as MishNonlinearityConfig
14
13
  from .nonlinearity import NonlinearityConfig as NonlinearityConfig
14
+ from .nonlinearity import NonlinearityConfigBase as NonlinearityConfigBase
15
15
  from .nonlinearity import PReLUConfig as PReLUConfig
16
16
  from .nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
17
17
  from .nonlinearity import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
nshtrainer/nn/mlp.py CHANGED
@@ -9,7 +9,7 @@ import torch
9
9
  import torch.nn as nn
10
10
  from typing_extensions import TypedDict, override
11
11
 
12
- from .nonlinearity import BaseNonlinearityConfig, NonlinearityConfig
12
+ from .nonlinearity import NonlinearityConfig, NonlinearityConfigBase
13
13
 
14
14
 
15
15
  @runtime_checkable
@@ -92,11 +92,11 @@ class MLPConfig(C.Config):
92
92
 
93
93
  def MLP(
94
94
  dims: Sequence[int],
95
- activation: BaseNonlinearityConfig
95
+ activation: NonlinearityConfigBase
96
96
  | nn.Module
97
97
  | Callable[[], nn.Module]
98
98
  | None = None,
99
- nonlinearity: BaseNonlinearityConfig
99
+ nonlinearity: NonlinearityConfigBase
100
100
  | nn.Module
101
101
  | Callable[[], nn.Module]
102
102
  | None = None,
@@ -153,7 +153,7 @@ def MLP(
153
153
  layers.append(nn.Dropout(dropout))
154
154
  if i < len(dims) - 2:
155
155
  match activation:
156
- case BaseNonlinearityConfig():
156
+ case NonlinearityConfigBase():
157
157
  layers.append(activation.create_module())
158
158
  case nn.Module():
159
159
  # In this case, we create a deep copy of the module to avoid sharing parameters (if any).
@@ -7,10 +7,10 @@ import nshconfig as C
7
7
  import torch
8
8
  import torch.nn as nn
9
9
  import torch.nn.functional as F
10
- from typing_extensions import final, override
10
+ from typing_extensions import TypeAliasType, final, override
11
11
 
12
12
 
13
- class BaseNonlinearityConfig(C.Config, ABC):
13
+ class NonlinearityConfigBase(C.Config, ABC):
14
14
  @abstractmethod
15
15
  def create_module(self) -> nn.Module: ...
16
16
 
@@ -18,8 +18,12 @@ class BaseNonlinearityConfig(C.Config, ABC):
18
18
  def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
19
19
 
20
20
 
21
+ nonlinearity_registry = C.Registry(NonlinearityConfigBase, discriminator="name")
22
+
23
+
21
24
  @final
22
- class ReLUNonlinearityConfig(BaseNonlinearityConfig):
25
+ @nonlinearity_registry.register
26
+ class ReLUNonlinearityConfig(NonlinearityConfigBase):
23
27
  name: Literal["relu"] = "relu"
24
28
 
25
29
  @override
@@ -31,7 +35,8 @@ class ReLUNonlinearityConfig(BaseNonlinearityConfig):
31
35
 
32
36
 
33
37
  @final
34
- class SigmoidNonlinearityConfig(BaseNonlinearityConfig):
38
+ @nonlinearity_registry.register
39
+ class SigmoidNonlinearityConfig(NonlinearityConfigBase):
35
40
  name: Literal["sigmoid"] = "sigmoid"
36
41
 
37
42
  @override
@@ -43,7 +48,8 @@ class SigmoidNonlinearityConfig(BaseNonlinearityConfig):
43
48
 
44
49
 
45
50
  @final
46
- class TanhNonlinearityConfig(BaseNonlinearityConfig):
51
+ @nonlinearity_registry.register
52
+ class TanhNonlinearityConfig(NonlinearityConfigBase):
47
53
  name: Literal["tanh"] = "tanh"
48
54
 
49
55
  @override
@@ -55,7 +61,8 @@ class TanhNonlinearityConfig(BaseNonlinearityConfig):
55
61
 
56
62
 
57
63
  @final
58
- class SoftmaxNonlinearityConfig(BaseNonlinearityConfig):
64
+ @nonlinearity_registry.register
65
+ class SoftmaxNonlinearityConfig(NonlinearityConfigBase):
59
66
  name: Literal["softmax"] = "softmax"
60
67
 
61
68
  dim: int = -1
@@ -70,7 +77,8 @@ class SoftmaxNonlinearityConfig(BaseNonlinearityConfig):
70
77
 
71
78
 
72
79
  @final
73
- class SoftplusNonlinearityConfig(BaseNonlinearityConfig):
80
+ @nonlinearity_registry.register
81
+ class SoftplusNonlinearityConfig(NonlinearityConfigBase):
74
82
  name: Literal["softplus"] = "softplus"
75
83
 
76
84
  beta: float = 1.0
@@ -88,7 +96,8 @@ class SoftplusNonlinearityConfig(BaseNonlinearityConfig):
88
96
 
89
97
 
90
98
  @final
91
- class SoftsignNonlinearityConfig(BaseNonlinearityConfig):
99
+ @nonlinearity_registry.register
100
+ class SoftsignNonlinearityConfig(NonlinearityConfigBase):
92
101
  name: Literal["softsign"] = "softsign"
93
102
 
94
103
  @override
@@ -100,7 +109,8 @@ class SoftsignNonlinearityConfig(BaseNonlinearityConfig):
100
109
 
101
110
 
102
111
  @final
103
- class ELUNonlinearityConfig(BaseNonlinearityConfig):
112
+ @nonlinearity_registry.register
113
+ class ELUNonlinearityConfig(NonlinearityConfigBase):
104
114
  name: Literal["elu"] = "elu"
105
115
 
106
116
  alpha: float = 1.0
@@ -115,7 +125,8 @@ class ELUNonlinearityConfig(BaseNonlinearityConfig):
115
125
 
116
126
 
117
127
  @final
118
- class LeakyReLUNonlinearityConfig(BaseNonlinearityConfig):
128
+ @nonlinearity_registry.register
129
+ class LeakyReLUNonlinearityConfig(NonlinearityConfigBase):
119
130
  name: Literal["leaky_relu"] = "leaky_relu"
120
131
 
121
132
  negative_slope: float = 1.0e-2
@@ -130,7 +141,8 @@ class LeakyReLUNonlinearityConfig(BaseNonlinearityConfig):
130
141
 
131
142
 
132
143
  @final
133
- class PReLUConfig(BaseNonlinearityConfig):
144
+ @nonlinearity_registry.register
145
+ class PReLUConfig(NonlinearityConfigBase):
134
146
  name: Literal["prelu"] = "prelu"
135
147
 
136
148
  num_parameters: int = 1
@@ -152,7 +164,8 @@ class PReLUConfig(BaseNonlinearityConfig):
152
164
 
153
165
 
154
166
  @final
155
- class GELUNonlinearityConfig(BaseNonlinearityConfig):
167
+ @nonlinearity_registry.register
168
+ class GELUNonlinearityConfig(NonlinearityConfigBase):
156
169
  name: Literal["gelu"] = "gelu"
157
170
 
158
171
  approximate: Literal["tanh", "none"] = "none"
@@ -167,7 +180,8 @@ class GELUNonlinearityConfig(BaseNonlinearityConfig):
167
180
 
168
181
 
169
182
  @final
170
- class SwishNonlinearityConfig(BaseNonlinearityConfig):
183
+ @nonlinearity_registry.register
184
+ class SwishNonlinearityConfig(NonlinearityConfigBase):
171
185
  name: Literal["swish"] = "swish"
172
186
 
173
187
  @override
@@ -179,7 +193,8 @@ class SwishNonlinearityConfig(BaseNonlinearityConfig):
179
193
 
180
194
 
181
195
  @final
182
- class SiLUNonlinearityConfig(BaseNonlinearityConfig):
196
+ @nonlinearity_registry.register
197
+ class SiLUNonlinearityConfig(NonlinearityConfigBase):
183
198
  name: Literal["silu"] = "silu"
184
199
 
185
200
  @override
@@ -191,7 +206,8 @@ class SiLUNonlinearityConfig(BaseNonlinearityConfig):
191
206
 
192
207
 
193
208
  @final
194
- class MishNonlinearityConfig(BaseNonlinearityConfig):
209
+ @nonlinearity_registry.register
210
+ class MishNonlinearityConfig(NonlinearityConfigBase):
195
211
  name: Literal["mish"] = "mish"
196
212
 
197
213
  @override
@@ -210,7 +226,8 @@ class SwiGLU(nn.SiLU):
210
226
 
211
227
 
212
228
  @final
213
- class SwiGLUNonlinearityConfig(BaseNonlinearityConfig):
229
+ @nonlinearity_registry.register
230
+ class SwiGLUNonlinearityConfig(NonlinearityConfigBase):
214
231
  name: Literal["swiglu"] = "swiglu"
215
232
 
216
233
  @override
@@ -222,20 +239,7 @@ class SwiGLUNonlinearityConfig(BaseNonlinearityConfig):
222
239
  return input * F.silu(gate)
223
240
 
224
241
 
225
- NonlinearityConfig = Annotated[
226
- ReLUNonlinearityConfig
227
- | SigmoidNonlinearityConfig
228
- | TanhNonlinearityConfig
229
- | SoftmaxNonlinearityConfig
230
- | SoftplusNonlinearityConfig
231
- | SoftsignNonlinearityConfig
232
- | ELUNonlinearityConfig
233
- | LeakyReLUNonlinearityConfig
234
- | PReLUConfig
235
- | GELUNonlinearityConfig
236
- | SwishNonlinearityConfig
237
- | SiLUNonlinearityConfig
238
- | MishNonlinearityConfig
239
- | SwiGLUNonlinearityConfig,
240
- C.Field(discriminator="name"),
241
- ]
242
+ NonlinearityConfig = TypeAliasType(
243
+ "NonlinearityConfig",
244
+ Annotated[NonlinearityConfigBase, nonlinearity_registry.DynamicResolution()],
245
+ )
nshtrainer/optimizer.py CHANGED
@@ -7,7 +7,7 @@ from typing import Annotated, Any, Literal
7
7
  import nshconfig as C
8
8
  import torch.nn as nn
9
9
  from torch.optim import Optimizer
10
- from typing_extensions import TypeAliasType, override
10
+ from typing_extensions import TypeAliasType, final, override
11
11
 
12
12
 
13
13
  class OptimizerConfigBase(C.Config, ABC):
@@ -18,6 +18,11 @@ class OptimizerConfigBase(C.Config, ABC):
18
18
  ) -> Optimizer: ...
19
19
 
20
20
 
21
+ optimizer_registry = C.Registry(OptimizerConfigBase, discriminator="name")
22
+
23
+
24
+ @final
25
+ @optimizer_registry.register
21
26
  class AdamWConfig(OptimizerConfigBase):
22
27
  name: Literal["adamw"] = "adamw"
23
28
 
@@ -58,5 +63,6 @@ class AdamWConfig(OptimizerConfigBase):
58
63
 
59
64
 
60
65
  OptimizerConfig = TypeAliasType(
61
- "OptimizerConfig", Annotated[AdamWConfig, C.Field(discriminator="name")]
66
+ "OptimizerConfig",
67
+ Annotated[OptimizerConfigBase, optimizer_registry.DynamicResolution()],
62
68
  )
@@ -48,8 +48,8 @@ from ..loggers import (
48
48
  TensorboardLoggerConfig,
49
49
  WandbLoggerConfig,
50
50
  )
51
- from ..loggers._base import BaseLoggerConfig
52
51
  from ..loggers.actsave import ActSaveLoggerConfig
52
+ from ..loggers.base import LoggerConfigBase
53
53
  from ..metrics._config import MetricConfig
54
54
  from ..profiler import ProfilerConfig
55
55
  from ..util._environment_info import EnvironmentConfig
@@ -770,7 +770,7 @@ class TrainerConfig(C.Config):
770
770
  yield self.auto_set_debug_flag
771
771
  yield from self.callbacks
772
772
 
773
- def _nshtrainer_all_logger_configs(self) -> Iterable[BaseLoggerConfig | None]:
773
+ def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
774
774
  # Disable all loggers if barebones mode is enabled
775
775
  if self.barebones:
776
776
  return
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 1.0.0b36
3
+ Version: 1.0.0b37
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -3,7 +3,7 @@ nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
3
3
  nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=PHy-54Cg-o3OtCffAqrVv6ZVMU7zhRo_-sZiSEEno1Y,5019
5
5
  nshtrainer/_checkpoint/saver.py,sha256=LOP8jjKF0Dw9x9H-BKrLMWlEp1XTan2DUK0zQUCWw5U,1360
6
- nshtrainer/_directory.py,sha256=p2uk1FnISFEpMqlDevKhoWhQsCEtvHUPg459K-86QA8,3053
6
+ nshtrainer/_directory.py,sha256=xY8Z9POZJw0Uh56yqffZbnNZvdA_tnWCucT31dhwFCM,3183
7
7
  nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
8
8
  nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
9
9
  nshtrainer/callbacks/__init__.py,sha256=4giOYT8A709UOLRtQEt16QbOAFUHCjJ_aLB7ITTwXJI,3577
@@ -30,7 +30,7 @@ nshtrainer/callbacks/shared_parameters.py,sha256=s94jJTAIbDGukYJu6l247QonVOCudGC
30
30
  nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU,4731
31
31
  nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
32
32
  nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
33
- nshtrainer/configs/__init__.py,sha256=OevZEZxb4H8imadSQXK9huqdYUF4SrJPfNU_2fpMBvI,14084
33
+ nshtrainer/configs/__init__.py,sha256=MZfcSKhnjtVObBvVv9lu8L2cFTLINP5zcTQvWnz8jdk,14505
34
34
  nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
35
35
  nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
36
36
  nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
@@ -58,29 +58,29 @@ nshtrainer/configs/callbacks/shared_parameters/__init__.py,sha256=AU7_bSnSRSlj16
58
58
  nshtrainer/configs/callbacks/timer/__init__.py,sha256=cOUtbsl0_OhCO0fIcBfLuIF6FEGBHQu7AvQFzwVznWQ,413
59
59
  nshtrainer/configs/callbacks/wandb_upload_code/__init__.py,sha256=CJeCc9OCu5F39lWiY5aIc4WxQlgBvB-8cga6cQtw0GQ,482
60
60
  nshtrainer/configs/callbacks/wandb_watch/__init__.py,sha256=dzz1oavL1BwELE33xus45_avBEAZDeB6xtcb6CsOEos,431
61
- nshtrainer/configs/loggers/__init__.py,sha256=5wTekL79mQxit8f1K3AMllvb0mKertTzOKfC3gpE2Zk,1251
62
- nshtrainer/configs/loggers/_base/__init__.py,sha256=HxPPPePsEjlNuhnjsMgYIl0rwj_iqNKKOBTEk_zIOsM,169
63
- nshtrainer/configs/loggers/actsave/__init__.py,sha256=2lZQ4bpbjwd4MuUE_Z_PGbmQjjGtWCZUCtXqKO4dTSc,280
64
- nshtrainer/configs/loggers/csv/__init__.py,sha256=M3QGF5GKiRGENy3re6LJKpa4A4RThy1FlmaFuR4cPyo,260
65
- nshtrainer/configs/loggers/tensorboard/__init__.py,sha256=FbkYXnSohIX6JN5XyI-9y91IJv_T3VB3IwmpagXAnM4,309
66
- nshtrainer/configs/loggers/wandb/__init__.py,sha256=76qb0HhWojf0Ub1x9OkMjtzeXxE67KysBGa-MBbJyC4,651
67
- nshtrainer/configs/lr_scheduler/__init__.py,sha256=8ORO-QC12SjZ2F_reMoDgr8-O8nxZxX0IKU4fl-cC3A,1023
68
- nshtrainer/configs/lr_scheduler/_base/__init__.py,sha256=fvGjkUJ1K2RVXjXror22QOtEa-xWFJz2Cz3HrBC5XfA,189
69
- nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py,sha256=i8LeZh0c4wqtZ1ehZb2LCq7kwOL0OyswMMOnwyI6R04,533
70
- nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py,sha256=lpXEFZY4cM3znZqYG9IZ1xNNtzttt8VVspSuOz0fb-k,467
61
+ nshtrainer/configs/loggers/__init__.py,sha256=GT7PO7UM3Mo87N616mGucc2ZRyGP8nQWBd_VJ_8RGXo,1337
62
+ nshtrainer/configs/loggers/actsave/__init__.py,sha256=J7SnbD-zxUynWSskJezooFyBZdnhgTWyybRvwn9gzy4,377
63
+ nshtrainer/configs/loggers/base/__init__.py,sha256=HLUfEDbjaAXqzsFmQbjdciIWzR1st1gRLKTCFvUFEX0,262
64
+ nshtrainer/configs/loggers/csv/__init__.py,sha256=gawaDX92JObGSmBqYpfNHWMHBwVOofS694W-1Y2GWDU,353
65
+ nshtrainer/configs/loggers/tensorboard/__init__.py,sha256=phzm-TnBkdkibTgoOxIIcAliqL3zU8gSNK61Mwxs1CM,410
66
+ nshtrainer/configs/loggers/wandb/__init__.py,sha256=TDcD5WZSKenc2mgIXhwz2l96l8P_Ur3N5CzEol5AKGw,746
67
+ nshtrainer/configs/lr_scheduler/__init__.py,sha256=xtiUx0isxA82-uXMn4-KmPnDCfbUkpAnd2_pFupAAKQ,1137
68
+ nshtrainer/configs/lr_scheduler/base/__init__.py,sha256=6Cx8r4rdxeSYxc_z0o7drKCblGJU_zzqrOoYlWYR5qY,305
69
+ nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py,sha256=5ZMLDO9VL6SNU6pF-62lDnpmqix3_Ol9DdEwiuOPYlA,675
70
+ nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py,sha256=w-vq8UbRGPX8DZVWCMC5eIrbvVc_guxjj7Du9AaeKCw,609
71
71
  nshtrainer/configs/metrics/__init__.py,sha256=mK_xgXJDAyGY4K_x_Zo4aj36kjT45d850keuUe3U1rY,200
72
72
  nshtrainer/configs/metrics/_config/__init__.py,sha256=XDDvDPWULd_vd3lrgF2KGAVR2LVDuhdQvy-fF2ImarI,159
73
- nshtrainer/configs/nn/__init__.py,sha256=3hVc81Gs9AJYVkrwJkQ_ye7tLU2HOLdBj-mMkXx2c_I,1957
74
- nshtrainer/configs/nn/mlp/__init__.py,sha256=eMECrgz-My9mFS7lpWVI3dj1ApB-E7xwfmNc37hUsPI,347
75
- nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=Gjr2HCx8jJTcfu7sLgn54o2ucGKaBea4encm4AWpKNY,2040
76
- nshtrainer/configs/optimizer/__init__.py,sha256=IMEsEbiVFXSkj6WmDjNjmKQuRspphs5xZnYZ2gYE39Y,344
73
+ nshtrainer/configs/nn/__init__.py,sha256=tkFG2Hb0oL_AmWP3_0WkDN2zI5PkVfrgwXhaAII7CZw,2072
74
+ nshtrainer/configs/nn/mlp/__init__.py,sha256=O6kQ6utZNJPG9Fax5pRdZcHa3J-XFKKdXcc_PQg0jk0,347
75
+ nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=LCTbTyelCMABVw505CGQ4UpEGlAnIhflSLFqwAQXLQA,2155
76
+ nshtrainer/configs/optimizer/__init__.py,sha256=itIDIHQvGm50eZ7JLyNElahnNUMPJ__4PMmTjc0RQ6o,444
77
77
  nshtrainer/configs/profiler/__init__.py,sha256=2ssaIpfVnvcbfNvZ-JeKp1Cx4NO1LknkVqTm1hu7Lvw,768
78
78
  nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRDm1CKqjwUOQNbQjD4,176
79
79
  nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
80
80
  nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
81
81
  nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
82
- nshtrainer/configs/trainer/__init__.py,sha256=8Z4E1IeJHtDW8fpDxJkiC9CgDqKrTBIR5VMK1q4DYy4,7729
83
- nshtrainer/configs/trainer/_config/__init__.py,sha256=t72kmUn60UtjpD6H38XzKbEs50gU2dS1IH0u-RnHZ04,3666
82
+ nshtrainer/configs/trainer/__init__.py,sha256=jYCp4Q9uvutA6NYqfthbREMg09-obD3gHtzEI2Ta-hU,7729
83
+ nshtrainer/configs/trainer/_config/__init__.py,sha256=uof_oJfhwjB1pft7KsRdk_RvNj-tE8wcDBEM7X5qtNc,3666
84
84
  nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
85
85
  nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
86
86
  nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=GuubcKrbXt4VjJRT8VpNUQqBtyuutre_CfkS0EWZ5_E,368
@@ -99,16 +99,16 @@ nshtrainer/data/__init__.py,sha256=K4i3Tw4g9EOK2zlMMbidi99y0SyI4w8P7_XUf1n42Ts,2
99
99
  nshtrainer/data/balanced_batch_sampler.py,sha256=r1cBKRXKHD8E1Ax6tj-FUbE-z1qpbO58mQ9VrK9uLnc,5481
100
100
  nshtrainer/data/datamodule.py,sha256=lSOgH32nysJWa6Y7ba1QyOdUV0DVVdO98qokP8wigjk,4138
101
101
  nshtrainer/data/transform.py,sha256=qd0lIocO59Fk_m90xyOHgFezbymd1mRwly8nbYIfHGc,2263
102
- nshtrainer/loggers/__init__.py,sha256=-y8B-9TF6vJdZUQewJNDcZ2aOv04FEUFtKwaiDobIO0,670
103
- nshtrainer/loggers/_base.py,sha256=nw4AZzJP3Z-fljgQlgq7FkuMkPmYKTsXj7OfJJSmtXI,811
104
- nshtrainer/loggers/actsave.py,sha256=23Kre-mq-Y9Iw1SRyGmHnHK1bc_0gTWFXJViv9bkVz0,1324
105
- nshtrainer/loggers/csv.py,sha256=Deh5gm3oROJbQzigV4SHni5JRwSrBdm-4YD3yrcGnHo,1104
106
- nshtrainer/loggers/tensorboard.py,sha256=jP9V4nlq_MXUaoD6xv1Cws2ioft83Lm8yUJhhGhuUrQ,2268
107
- nshtrainer/loggers/wandb.py,sha256=EjKQQznLSUSCWO7uIviz9g0dVW4ZLxb_8UVhY4vR7r0,6800
108
- nshtrainer/lr_scheduler/__init__.py,sha256=BGnO-okUTZOtF15-UmQ05U4oEatSF5VNs3YeidNEWn4,853
109
- nshtrainer/lr_scheduler/_base.py,sha256=EhA2f_WiZ79RcXL2nJbwCwNK620c8ugEVUmJ8CcVpsE,3723
110
- nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=gvUuv031lvWdXboDeH7iAF3ZgNPQK40bQwfmqb11TNk,5492
111
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=vXH5S26ESHO_LPPqW8aDC3S5NGoZYkXeFjAOgttaUX8,2870
102
+ nshtrainer/loggers/__init__.py,sha256=Ddd3JJXVzew_ZpwHA9kGnGmvq4OwhItwghDL5PzNhDc,614
103
+ nshtrainer/loggers/actsave.py,sha256=wgNrpBB6wQM7qff8iLDb_sQnbiAcYHRmH56pcEJPB3o,1409
104
+ nshtrainer/loggers/base.py,sha256=1-HoPmOiyXevQvMLXboiKe-4GOE1V5SvjURohOHakVc,882
105
+ nshtrainer/loggers/csv.py,sha256=xJ8mSmw4vJwinIfqhF6t2HWmh_1dXEYyLfGuXwL7WHo,1160
106
+ nshtrainer/loggers/tensorboard.py,sha256=E7iO_fDt9bfH02hBL430bXPLljOo5iGgq2QyPqmx2gQ,2324
107
+ nshtrainer/loggers/wandb.py,sha256=KZXAUWrrmdX_L8rqej77oUHaM0JxZRM8y9z6JP9PISw,6856
108
+ nshtrainer/lr_scheduler/__init__.py,sha256=daMMK3erUcNXGGd_nZB8AWu3ZTYqfS1RSWeK4FV2udw,851
109
+ nshtrainer/lr_scheduler/base.py,sha256=062fGcH5sYeEKwoY55RydCTvfPwTnyZHCi049a3nMbM,3805
110
+ nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=MsoXgCcWTKsrkNZiGnKS6yC-slRuleuwFxeM_lmG_pQ,5560
111
+ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=v9T0GpvOoHV30atFB0MwExHgHcTpMCYxbMRoPjPBjt8,2938
112
112
  nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
113
113
  nshtrainer/metrics/_config.py,sha256=XIRokFM8PHrhBa3w2R6BM6a4es3ncsoBqE_LqXQFsFE,1223
114
114
  nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
@@ -116,19 +116,19 @@ nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,1035
116
116
  nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
117
117
  nshtrainer/model/mixins/debug.py,sha256=1LX9KzeFX9JDPs_a6YCdYDZXLhEk_5rBO2aCqlfBy7w,2087
118
118
  nshtrainer/model/mixins/logger.py,sha256=27H99FuLaxc6_dDLG2pid4E_5E0-eLGnc2Ifpt0HYIM,6066
119
- nshtrainer/nn/__init__.py,sha256=sANhrZpeN5syLKOsmXMwhaFl2SBFPWcLaEe1EH22TWQ,1463
120
- nshtrainer/nn/mlp.py,sha256=2W8bzE96DzCMzGm6WPiPhNFQfhqaoG3GXPn_oKBnlUM,5988
119
+ nshtrainer/nn/__init__.py,sha256=7KCs-GDOynCXAIdwkgAQacc0p3FHLEION50UtrvgAOc,1463
120
+ nshtrainer/nn/mlp.py,sha256=ZbkLyOc08stgIugvu1G5_h66DYtxAFDnboikBaJvvZ8,5988
121
121
  nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
122
122
  nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
123
- nshtrainer/nn/nonlinearity.py,sha256=mp5XvXRHURB6jwuZ0YyTj5ZoHJYNJNgO2aLtUY1D-2Y,6114
124
- nshtrainer/optimizer.py,sha256=wmSRpSoU59rstj2RBoifQ15ZwRInYpm0tDBQZ1gqOfE,1596
123
+ nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
124
+ nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
125
125
  nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
126
126
  nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
127
127
  nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
128
128
  nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
129
129
  nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
130
130
  nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
131
- nshtrainer/trainer/_config.py,sha256=SPg3WXjF3ufhnr27sTHQLq23hdebnW6CTWa8AJkRG0A,32982
131
+ nshtrainer/trainer/_config.py,sha256=QDy6sINVDGEqfHfPTWXSN-06EoEuMSVscHn8fCRTvr0,32981
132
132
  nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
133
133
  nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
134
134
  nshtrainer/trainer/plugin/__init__.py,sha256=UM8f70Ml3RGpsXeQr1Yh1yBcxxjFNGFGgJbniCn_rws,366
@@ -151,6 +151,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
151
151
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
152
152
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
153
153
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
154
- nshtrainer-1.0.0b36.dist-info/METADATA,sha256=R9O2SnflaNiDkxtoOPD_YFCXIgnEl8YjkhbEU5CbWHQ,988
155
- nshtrainer-1.0.0b36.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
156
- nshtrainer-1.0.0b36.dist-info/RECORD,,
154
+ nshtrainer-1.0.0b37.dist-info/METADATA,sha256=ObMgpZ_qJLmBAkeRDN7ufTuRSTltiB_LYPFTphNvWks,988
155
+ nshtrainer-1.0.0b37.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
156
+ nshtrainer-1.0.0b37.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- from __future__ import annotations
2
-
3
- __codegen__ = True
4
-
5
- from nshtrainer.loggers._base import BaseLoggerConfig as BaseLoggerConfig
6
-
7
- __all__ = [
8
- "BaseLoggerConfig",
9
- ]
@@ -1,9 +0,0 @@
1
- from __future__ import annotations
2
-
3
- __codegen__ = True
4
-
5
- from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase as LRSchedulerConfigBase
6
-
7
- __all__ = [
8
- "LRSchedulerConfigBase",
9
- ]