nshtrainer 1.0.0b33__py3-none-any.whl → 1.0.0b37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89) hide show
  1. nshtrainer/__init__.py +1 -0
  2. nshtrainer/_directory.py +3 -1
  3. nshtrainer/_hf_hub.py +8 -1
  4. nshtrainer/callbacks/__init__.py +10 -23
  5. nshtrainer/callbacks/actsave.py +6 -2
  6. nshtrainer/callbacks/base.py +3 -0
  7. nshtrainer/callbacks/checkpoint/__init__.py +0 -4
  8. nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
  9. nshtrainer/callbacks/checkpoint/last_checkpoint.py +72 -2
  10. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -2
  11. nshtrainer/callbacks/debug_flag.py +4 -2
  12. nshtrainer/callbacks/directory_setup.py +23 -21
  13. nshtrainer/callbacks/early_stopping.py +4 -2
  14. nshtrainer/callbacks/ema.py +29 -27
  15. nshtrainer/callbacks/finite_checks.py +21 -19
  16. nshtrainer/callbacks/gradient_skipping.py +29 -27
  17. nshtrainer/callbacks/log_epoch.py +4 -2
  18. nshtrainer/callbacks/lr_monitor.py +6 -1
  19. nshtrainer/callbacks/norm_logging.py +36 -34
  20. nshtrainer/callbacks/print_table.py +20 -18
  21. nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
  22. nshtrainer/callbacks/shared_parameters.py +9 -7
  23. nshtrainer/callbacks/timer.py +12 -10
  24. nshtrainer/callbacks/wandb_upload_code.py +4 -2
  25. nshtrainer/callbacks/wandb_watch.py +4 -2
  26. nshtrainer/configs/__init__.py +16 -12
  27. nshtrainer/configs/_hf_hub/__init__.py +2 -0
  28. nshtrainer/configs/callbacks/__init__.py +4 -8
  29. nshtrainer/configs/callbacks/actsave/__init__.py +2 -0
  30. nshtrainer/configs/callbacks/base/__init__.py +2 -0
  31. nshtrainer/configs/callbacks/checkpoint/__init__.py +4 -6
  32. nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +4 -0
  33. nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +4 -0
  34. nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -0
  35. nshtrainer/configs/callbacks/debug_flag/__init__.py +2 -0
  36. nshtrainer/configs/callbacks/directory_setup/__init__.py +2 -0
  37. nshtrainer/configs/callbacks/early_stopping/__init__.py +2 -0
  38. nshtrainer/configs/callbacks/ema/__init__.py +2 -0
  39. nshtrainer/configs/callbacks/finite_checks/__init__.py +2 -0
  40. nshtrainer/configs/callbacks/gradient_skipping/__init__.py +4 -0
  41. nshtrainer/configs/callbacks/log_epoch/__init__.py +2 -0
  42. nshtrainer/configs/callbacks/lr_monitor/__init__.py +2 -0
  43. nshtrainer/configs/callbacks/norm_logging/__init__.py +2 -0
  44. nshtrainer/configs/callbacks/print_table/__init__.py +2 -0
  45. nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +4 -0
  46. nshtrainer/configs/callbacks/shared_parameters/__init__.py +4 -0
  47. nshtrainer/configs/callbacks/timer/__init__.py +2 -0
  48. nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +4 -0
  49. nshtrainer/configs/callbacks/wandb_watch/__init__.py +2 -0
  50. nshtrainer/configs/loggers/__init__.py +6 -4
  51. nshtrainer/configs/loggers/actsave/__init__.py +4 -2
  52. nshtrainer/configs/loggers/base/__init__.py +11 -0
  53. nshtrainer/configs/loggers/csv/__init__.py +4 -2
  54. nshtrainer/configs/loggers/tensorboard/__init__.py +4 -2
  55. nshtrainer/configs/loggers/wandb/__init__.py +4 -2
  56. nshtrainer/configs/lr_scheduler/__init__.py +4 -2
  57. nshtrainer/configs/lr_scheduler/base/__init__.py +11 -0
  58. nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +4 -0
  59. nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +4 -0
  60. nshtrainer/configs/nn/__init__.py +4 -2
  61. nshtrainer/configs/nn/mlp/__init__.py +2 -2
  62. nshtrainer/configs/nn/nonlinearity/__init__.py +4 -2
  63. nshtrainer/configs/optimizer/__init__.py +2 -0
  64. nshtrainer/configs/trainer/__init__.py +4 -6
  65. nshtrainer/configs/trainer/_config/__init__.py +2 -10
  66. nshtrainer/loggers/__init__.py +3 -8
  67. nshtrainer/loggers/actsave.py +5 -2
  68. nshtrainer/loggers/{_base.py → base.py} +4 -1
  69. nshtrainer/loggers/csv.py +5 -3
  70. nshtrainer/loggers/tensorboard.py +5 -3
  71. nshtrainer/loggers/wandb.py +5 -3
  72. nshtrainer/lr_scheduler/__init__.py +2 -2
  73. nshtrainer/lr_scheduler/{_base.py → base.py} +3 -0
  74. nshtrainer/lr_scheduler/linear_warmup_cosine.py +56 -54
  75. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +4 -2
  76. nshtrainer/nn/__init__.py +1 -1
  77. nshtrainer/nn/mlp.py +4 -4
  78. nshtrainer/nn/nonlinearity.py +37 -33
  79. nshtrainer/optimizer.py +8 -2
  80. nshtrainer/trainer/__init__.py +3 -2
  81. nshtrainer/trainer/_config.py +6 -44
  82. {nshtrainer-1.0.0b33.dist-info → nshtrainer-1.0.0b37.dist-info}/METADATA +1 -1
  83. nshtrainer-1.0.0b37.dist-info/RECORD +156 -0
  84. nshtrainer/callbacks/checkpoint/time_checkpoint.py +0 -114
  85. nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +0 -19
  86. nshtrainer/configs/loggers/_base/__init__.py +0 -9
  87. nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -9
  88. nshtrainer-1.0.0b33.dist-info/RECORD +0 -158
  89. {nshtrainer-1.0.0b33.dist-info → nshtrainer-1.0.0b37.dist-info}/WHEEL +0 -0
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.lr_monitor import CallbackConfigBase as CallbackConfig
6
6
  from nshtrainer.callbacks.lr_monitor import (
7
7
  LearningRateMonitorConfig as LearningRateMonitorConfig,
8
8
  )
9
+ from nshtrainer.callbacks.lr_monitor import callback_registry as callback_registry
9
10
 
10
11
  __all__ = [
11
12
  "CallbackConfigBase",
12
13
  "LearningRateMonitorConfig",
14
+ "callback_registry",
13
15
  ]
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.norm_logging import CallbackConfigBase as CallbackConf
6
6
  from nshtrainer.callbacks.norm_logging import (
7
7
  NormLoggingCallbackConfig as NormLoggingCallbackConfig,
8
8
  )
9
+ from nshtrainer.callbacks.norm_logging import callback_registry as callback_registry
9
10
 
10
11
  __all__ = [
11
12
  "CallbackConfigBase",
12
13
  "NormLoggingCallbackConfig",
14
+ "callback_registry",
13
15
  ]
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.print_table import CallbackConfigBase as CallbackConfi
6
6
  from nshtrainer.callbacks.print_table import (
7
7
  PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
8
8
  )
9
+ from nshtrainer.callbacks.print_table import callback_registry as callback_registry
9
10
 
10
11
  __all__ = [
11
12
  "CallbackConfigBase",
12
13
  "PrintTableMetricsCallbackConfig",
14
+ "callback_registry",
13
15
  ]
@@ -8,8 +8,12 @@ from nshtrainer.callbacks.rlp_sanity_checks import (
8
8
  from nshtrainer.callbacks.rlp_sanity_checks import (
9
9
  RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
10
10
  )
11
+ from nshtrainer.callbacks.rlp_sanity_checks import (
12
+ callback_registry as callback_registry,
13
+ )
11
14
 
12
15
  __all__ = [
13
16
  "CallbackConfigBase",
14
17
  "RLPSanityChecksCallbackConfig",
18
+ "callback_registry",
15
19
  ]
@@ -8,8 +8,12 @@ from nshtrainer.callbacks.shared_parameters import (
8
8
  from nshtrainer.callbacks.shared_parameters import (
9
9
  SharedParametersCallbackConfig as SharedParametersCallbackConfig,
10
10
  )
11
+ from nshtrainer.callbacks.shared_parameters import (
12
+ callback_registry as callback_registry,
13
+ )
11
14
 
12
15
  __all__ = [
13
16
  "CallbackConfigBase",
14
17
  "SharedParametersCallbackConfig",
18
+ "callback_registry",
15
19
  ]
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.timer import CallbackConfigBase as CallbackConfigBase
6
6
  from nshtrainer.callbacks.timer import (
7
7
  EpochTimerCallbackConfig as EpochTimerCallbackConfig,
8
8
  )
9
+ from nshtrainer.callbacks.timer import callback_registry as callback_registry
9
10
 
10
11
  __all__ = [
11
12
  "CallbackConfigBase",
12
13
  "EpochTimerCallbackConfig",
14
+ "callback_registry",
13
15
  ]
@@ -8,8 +8,12 @@ from nshtrainer.callbacks.wandb_upload_code import (
8
8
  from nshtrainer.callbacks.wandb_upload_code import (
9
9
  WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
10
10
  )
11
+ from nshtrainer.callbacks.wandb_upload_code import (
12
+ callback_registry as callback_registry,
13
+ )
11
14
 
12
15
  __all__ = [
13
16
  "CallbackConfigBase",
14
17
  "WandbUploadCodeCallbackConfig",
18
+ "callback_registry",
15
19
  ]
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.wandb_watch import CallbackConfigBase as CallbackConfi
6
6
  from nshtrainer.callbacks.wandb_watch import (
7
7
  WandbWatchCallbackConfig as WandbWatchCallbackConfig,
8
8
  )
9
+ from nshtrainer.callbacks.wandb_watch import callback_registry as callback_registry
9
10
 
10
11
  __all__ = [
11
12
  "CallbackConfigBase",
12
13
  "WandbWatchCallbackConfig",
14
+ "callback_registry",
13
15
  ]
@@ -3,11 +3,12 @@ from __future__ import annotations
3
3
  __codegen__ = True
4
4
 
5
5
  from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
6
- from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
7
6
  from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
8
7
  from nshtrainer.loggers import LoggerConfig as LoggerConfig
8
+ from nshtrainer.loggers import LoggerConfigBase as LoggerConfigBase
9
9
  from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
10
10
  from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
11
+ from nshtrainer.loggers import logger_registry as logger_registry
11
12
  from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
12
13
  from nshtrainer.loggers.wandb import (
13
14
  WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
@@ -16,25 +17,26 @@ from nshtrainer.loggers.wandb import (
16
17
  WandbWatchCallbackConfig as WandbWatchCallbackConfig,
17
18
  )
18
19
 
19
- from . import _base as _base
20
20
  from . import actsave as actsave
21
+ from . import base as base
21
22
  from . import csv as csv
22
23
  from . import tensorboard as tensorboard
23
24
  from . import wandb as wandb
24
25
 
25
26
  __all__ = [
26
27
  "ActSaveLoggerConfig",
27
- "BaseLoggerConfig",
28
28
  "CSVLoggerConfig",
29
29
  "CallbackConfigBase",
30
30
  "LoggerConfig",
31
+ "LoggerConfigBase",
31
32
  "TensorboardLoggerConfig",
32
33
  "WandbLoggerConfig",
33
34
  "WandbUploadCodeCallbackConfig",
34
35
  "WandbWatchCallbackConfig",
35
- "_base",
36
36
  "actsave",
37
+ "base",
37
38
  "csv",
39
+ "logger_registry",
38
40
  "tensorboard",
39
41
  "wandb",
40
42
  ]
@@ -3,9 +3,11 @@ from __future__ import annotations
3
3
  __codegen__ = True
4
4
 
5
5
  from nshtrainer.loggers.actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
6
- from nshtrainer.loggers.actsave import BaseLoggerConfig as BaseLoggerConfig
6
+ from nshtrainer.loggers.actsave import LoggerConfigBase as LoggerConfigBase
7
+ from nshtrainer.loggers.actsave import logger_registry as logger_registry
7
8
 
8
9
  __all__ = [
9
10
  "ActSaveLoggerConfig",
10
- "BaseLoggerConfig",
11
+ "LoggerConfigBase",
12
+ "logger_registry",
11
13
  ]
@@ -0,0 +1,11 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from nshtrainer.loggers.base import LoggerConfigBase as LoggerConfigBase
6
+ from nshtrainer.loggers.base import logger_registry as logger_registry
7
+
8
+ __all__ = [
9
+ "LoggerConfigBase",
10
+ "logger_registry",
11
+ ]
@@ -2,10 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from nshtrainer.loggers.csv import BaseLoggerConfig as BaseLoggerConfig
6
5
  from nshtrainer.loggers.csv import CSVLoggerConfig as CSVLoggerConfig
6
+ from nshtrainer.loggers.csv import LoggerConfigBase as LoggerConfigBase
7
+ from nshtrainer.loggers.csv import logger_registry as logger_registry
7
8
 
8
9
  __all__ = [
9
- "BaseLoggerConfig",
10
10
  "CSVLoggerConfig",
11
+ "LoggerConfigBase",
12
+ "logger_registry",
11
13
  ]
@@ -2,12 +2,14 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from nshtrainer.loggers.tensorboard import BaseLoggerConfig as BaseLoggerConfig
5
+ from nshtrainer.loggers.tensorboard import LoggerConfigBase as LoggerConfigBase
6
6
  from nshtrainer.loggers.tensorboard import (
7
7
  TensorboardLoggerConfig as TensorboardLoggerConfig,
8
8
  )
9
+ from nshtrainer.loggers.tensorboard import logger_registry as logger_registry
9
10
 
10
11
  __all__ = [
11
- "BaseLoggerConfig",
12
+ "LoggerConfigBase",
12
13
  "TensorboardLoggerConfig",
14
+ "logger_registry",
13
15
  ]
@@ -2,8 +2,8 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from nshtrainer.loggers.wandb import BaseLoggerConfig as BaseLoggerConfig
6
5
  from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
6
+ from nshtrainer.loggers.wandb import LoggerConfigBase as LoggerConfigBase
7
7
  from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
8
8
  from nshtrainer.loggers.wandb import (
9
9
  WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
@@ -11,11 +11,13 @@ from nshtrainer.loggers.wandb import (
11
11
  from nshtrainer.loggers.wandb import (
12
12
  WandbWatchCallbackConfig as WandbWatchCallbackConfig,
13
13
  )
14
+ from nshtrainer.loggers.wandb import logger_registry as logger_registry
14
15
 
15
16
  __all__ = [
16
- "BaseLoggerConfig",
17
17
  "CallbackConfigBase",
18
+ "LoggerConfigBase",
18
19
  "WandbLoggerConfig",
19
20
  "WandbUploadCodeCallbackConfig",
20
21
  "WandbWatchCallbackConfig",
22
+ "logger_registry",
21
23
  ]
@@ -8,12 +8,13 @@ from nshtrainer.lr_scheduler import (
8
8
  from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
9
9
  from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
10
10
  from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
11
+ from nshtrainer.lr_scheduler.base import lr_scheduler_registry as lr_scheduler_registry
11
12
  from nshtrainer.lr_scheduler.linear_warmup_cosine import (
12
13
  DurationConfig as DurationConfig,
13
14
  )
14
15
  from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricConfig
15
16
 
16
- from . import _base as _base
17
+ from . import base as base
17
18
  from . import linear_warmup_cosine as linear_warmup_cosine
18
19
  from . import reduce_lr_on_plateau as reduce_lr_on_plateau
19
20
 
@@ -24,7 +25,8 @@ __all__ = [
24
25
  "LinearWarmupCosineDecayLRSchedulerConfig",
25
26
  "MetricConfig",
26
27
  "ReduceLROnPlateauConfig",
27
- "_base",
28
+ "base",
28
29
  "linear_warmup_cosine",
30
+ "lr_scheduler_registry",
29
31
  "reduce_lr_on_plateau",
30
32
  ]
@@ -0,0 +1,11 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from nshtrainer.lr_scheduler.base import LRSchedulerConfigBase as LRSchedulerConfigBase
6
+ from nshtrainer.lr_scheduler.base import lr_scheduler_registry as lr_scheduler_registry
7
+
8
+ __all__ = [
9
+ "LRSchedulerConfigBase",
10
+ "lr_scheduler_registry",
11
+ ]
@@ -11,9 +11,13 @@ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
11
11
  from nshtrainer.lr_scheduler.linear_warmup_cosine import (
12
12
  LRSchedulerConfigBase as LRSchedulerConfigBase,
13
13
  )
14
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
15
+ lr_scheduler_registry as lr_scheduler_registry,
16
+ )
14
17
 
15
18
  __all__ = [
16
19
  "DurationConfig",
17
20
  "LRSchedulerConfigBase",
18
21
  "LinearWarmupCosineDecayLRSchedulerConfig",
22
+ "lr_scheduler_registry",
19
23
  ]
@@ -9,9 +9,13 @@ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricC
9
9
  from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
10
10
  ReduceLROnPlateauConfig as ReduceLROnPlateauConfig,
11
11
  )
12
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
13
+ lr_scheduler_registry as lr_scheduler_registry,
14
+ )
12
15
 
13
16
  __all__ = [
14
17
  "LRSchedulerConfigBase",
15
18
  "MetricConfig",
16
19
  "ReduceLROnPlateauConfig",
20
+ "lr_scheduler_registry",
17
21
  ]
@@ -2,13 +2,13 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
6
5
  from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
7
6
  from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
8
7
  from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
9
8
  from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
10
9
  from nshtrainer.nn import MLPConfig as MLPConfig
11
10
  from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
11
+ from nshtrainer.nn import NonlinearityConfigBase as NonlinearityConfigBase
12
12
  from nshtrainer.nn import PReLUConfig as PReLUConfig
13
13
  from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
14
14
  from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
@@ -21,18 +21,19 @@ from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
21
21
  from nshtrainer.nn.nonlinearity import (
22
22
  SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
23
23
  )
24
+ from nshtrainer.nn.nonlinearity import nonlinearity_registry as nonlinearity_registry
24
25
 
25
26
  from . import mlp as mlp
26
27
  from . import nonlinearity as nonlinearity
27
28
 
28
29
  __all__ = [
29
- "BaseNonlinearityConfig",
30
30
  "ELUNonlinearityConfig",
31
31
  "GELUNonlinearityConfig",
32
32
  "LeakyReLUNonlinearityConfig",
33
33
  "MLPConfig",
34
34
  "MishNonlinearityConfig",
35
35
  "NonlinearityConfig",
36
+ "NonlinearityConfigBase",
36
37
  "PReLUConfig",
37
38
  "ReLUNonlinearityConfig",
38
39
  "SiLUNonlinearityConfig",
@@ -45,4 +46,5 @@ __all__ = [
45
46
  "TanhNonlinearityConfig",
46
47
  "mlp",
47
48
  "nonlinearity",
49
+ "nonlinearity_registry",
48
50
  ]
@@ -2,12 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from nshtrainer.nn.mlp import BaseNonlinearityConfig as BaseNonlinearityConfig
6
5
  from nshtrainer.nn.mlp import MLPConfig as MLPConfig
7
6
  from nshtrainer.nn.mlp import NonlinearityConfig as NonlinearityConfig
7
+ from nshtrainer.nn.mlp import NonlinearityConfigBase as NonlinearityConfigBase
8
8
 
9
9
  __all__ = [
10
- "BaseNonlinearityConfig",
11
10
  "MLPConfig",
12
11
  "NonlinearityConfig",
12
+ "NonlinearityConfigBase",
13
13
  ]
@@ -2,7 +2,6 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
- from nshtrainer.nn.nonlinearity import BaseNonlinearityConfig as BaseNonlinearityConfig
6
5
  from nshtrainer.nn.nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
7
6
  from nshtrainer.nn.nonlinearity import GELUNonlinearityConfig as GELUNonlinearityConfig
8
7
  from nshtrainer.nn.nonlinearity import (
@@ -10,6 +9,7 @@ from nshtrainer.nn.nonlinearity import (
10
9
  )
11
10
  from nshtrainer.nn.nonlinearity import MishNonlinearityConfig as MishNonlinearityConfig
12
11
  from nshtrainer.nn.nonlinearity import NonlinearityConfig as NonlinearityConfig
12
+ from nshtrainer.nn.nonlinearity import NonlinearityConfigBase as NonlinearityConfigBase
13
13
  from nshtrainer.nn.nonlinearity import PReLUConfig as PReLUConfig
14
14
  from nshtrainer.nn.nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
15
15
  from nshtrainer.nn.nonlinearity import (
@@ -32,14 +32,15 @@ from nshtrainer.nn.nonlinearity import (
32
32
  SwishNonlinearityConfig as SwishNonlinearityConfig,
33
33
  )
34
34
  from nshtrainer.nn.nonlinearity import TanhNonlinearityConfig as TanhNonlinearityConfig
35
+ from nshtrainer.nn.nonlinearity import nonlinearity_registry as nonlinearity_registry
35
36
 
36
37
  __all__ = [
37
- "BaseNonlinearityConfig",
38
38
  "ELUNonlinearityConfig",
39
39
  "GELUNonlinearityConfig",
40
40
  "LeakyReLUNonlinearityConfig",
41
41
  "MishNonlinearityConfig",
42
42
  "NonlinearityConfig",
43
+ "NonlinearityConfigBase",
43
44
  "PReLUConfig",
44
45
  "ReLUNonlinearityConfig",
45
46
  "SiLUNonlinearityConfig",
@@ -50,4 +51,5 @@ __all__ = [
50
51
  "SwiGLUNonlinearityConfig",
51
52
  "SwishNonlinearityConfig",
52
53
  "TanhNonlinearityConfig",
54
+ "nonlinearity_registry",
53
55
  ]
@@ -5,9 +5,11 @@ __codegen__ = True
5
5
  from nshtrainer.optimizer import AdamWConfig as AdamWConfig
6
6
  from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
7
7
  from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
8
+ from nshtrainer.optimizer import optimizer_registry as optimizer_registry
8
9
 
9
10
  __all__ = [
10
11
  "AdamWConfig",
11
12
  "OptimizerConfig",
12
13
  "OptimizerConfigBase",
14
+ "optimizer_registry",
13
15
  ]
@@ -4,10 +4,10 @@ __codegen__ = True
4
4
 
5
5
  from nshtrainer.trainer import TrainerConfig as TrainerConfig
6
6
  from nshtrainer.trainer import accelerator_registry as accelerator_registry
7
+ from nshtrainer.trainer import callback_registry as callback_registry
7
8
  from nshtrainer.trainer import plugin_registry as plugin_registry
8
9
  from nshtrainer.trainer._config import AcceleratorConfig as AcceleratorConfig
9
10
  from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
10
- from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
11
11
  from nshtrainer.trainer._config import (
12
12
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
13
13
  )
@@ -36,6 +36,7 @@ from nshtrainer.trainer._config import (
36
36
  )
37
37
  from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbackConfig
38
38
  from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
39
+ from nshtrainer.trainer._config import LoggerConfigBase as LoggerConfigBase
39
40
  from nshtrainer.trainer._config import MetricConfig as MetricConfig
40
41
  from nshtrainer.trainer._config import (
41
42
  NormLoggingCallbackConfig as NormLoggingCallbackConfig,
@@ -55,9 +56,6 @@ from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
55
56
  from nshtrainer.trainer._config import (
56
57
  TensorboardLoggerConfig as TensorboardLoggerConfig,
57
58
  )
58
- from nshtrainer.trainer._config import (
59
- TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
60
- )
61
59
  from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
62
60
  from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
63
61
  from nshtrainer.trainer.accelerator import (
@@ -135,7 +133,6 @@ __all__ = [
135
133
  "AcceleratorConfigBase",
136
134
  "ActSaveLoggerConfig",
137
135
  "AsyncCheckpointIOPlugin",
138
- "BaseLoggerConfig",
139
136
  "BestCheckpointCallbackConfig",
140
137
  "BitsandbytesPluginConfig",
141
138
  "CPUAcceleratorConfig",
@@ -163,6 +160,7 @@ __all__ = [
163
160
  "LightningEnvironmentPlugin",
164
161
  "LogEpochCallbackConfig",
165
162
  "LoggerConfig",
163
+ "LoggerConfigBase",
166
164
  "MPIEnvironmentPlugin",
167
165
  "MPSAcceleratorConfig",
168
166
  "MetricConfig",
@@ -179,7 +177,6 @@ __all__ = [
179
177
  "StrategyConfig",
180
178
  "StrategyConfigBase",
181
179
  "TensorboardLoggerConfig",
182
- "TimeCheckpointCallbackConfig",
183
180
  "TorchCheckpointIOPlugin",
184
181
  "TorchElasticEnvironmentPlugin",
185
182
  "TorchSyncBatchNormPlugin",
@@ -193,6 +190,7 @@ __all__ = [
193
190
  "_config",
194
191
  "accelerator",
195
192
  "accelerator_registry",
193
+ "callback_registry",
196
194
  "plugin",
197
195
  "plugin_registry",
198
196
  "strategy",
@@ -4,7 +4,6 @@ __codegen__ = True
4
4
 
5
5
  from nshtrainer.trainer._config import AcceleratorConfig as AcceleratorConfig
6
6
  from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
7
- from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
8
7
  from nshtrainer.trainer._config import (
9
8
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
10
9
  )
@@ -33,6 +32,7 @@ from nshtrainer.trainer._config import (
33
32
  )
34
33
  from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbackConfig
35
34
  from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
35
+ from nshtrainer.trainer._config import LoggerConfigBase as LoggerConfigBase
36
36
  from nshtrainer.trainer._config import MetricConfig as MetricConfig
37
37
  from nshtrainer.trainer._config import (
38
38
  NormLoggingCallbackConfig as NormLoggingCallbackConfig,
@@ -53,18 +53,12 @@ from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
53
53
  from nshtrainer.trainer._config import (
54
54
  TensorboardLoggerConfig as TensorboardLoggerConfig,
55
55
  )
56
- from nshtrainer.trainer._config import (
57
- TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
58
- )
59
56
  from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
60
57
  from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
61
- from nshtrainer.trainer._config import accelerator_registry as accelerator_registry
62
- from nshtrainer.trainer._config import plugin_registry as plugin_registry
63
58
 
64
59
  __all__ = [
65
60
  "AcceleratorConfig",
66
61
  "ActSaveLoggerConfig",
67
- "BaseLoggerConfig",
68
62
  "BestCheckpointCallbackConfig",
69
63
  "CSVLoggerConfig",
70
64
  "CallbackConfig",
@@ -81,6 +75,7 @@ __all__ = [
81
75
  "LearningRateMonitorConfig",
82
76
  "LogEpochCallbackConfig",
83
77
  "LoggerConfig",
78
+ "LoggerConfigBase",
84
79
  "MetricConfig",
85
80
  "NormLoggingCallbackConfig",
86
81
  "OnExceptionCheckpointCallbackConfig",
@@ -91,9 +86,6 @@ __all__ = [
91
86
  "SharedParametersCallbackConfig",
92
87
  "StrategyConfig",
93
88
  "TensorboardLoggerConfig",
94
- "TimeCheckpointCallbackConfig",
95
89
  "TrainerConfig",
96
90
  "WandbLoggerConfig",
97
- "accelerator_registry",
98
- "plugin_registry",
99
91
  ]
@@ -5,19 +5,14 @@ from typing import Annotated
5
5
  import nshconfig as C
6
6
  from typing_extensions import TypeAliasType
7
7
 
8
- from ._base import BaseLoggerConfig as BaseLoggerConfig
9
8
  from .actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
9
+ from .base import LoggerConfigBase as LoggerConfigBase
10
+ from .base import logger_registry as logger_registry
10
11
  from .csv import CSVLoggerConfig as CSVLoggerConfig
11
12
  from .tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
12
13
  from .wandb import WandbLoggerConfig as WandbLoggerConfig
13
14
 
14
15
  LoggerConfig = TypeAliasType(
15
16
  "LoggerConfig",
16
- Annotated[
17
- CSVLoggerConfig
18
- | TensorboardLoggerConfig
19
- | WandbLoggerConfig
20
- | ActSaveLoggerConfig,
21
- C.Field(discriminator="name"),
22
- ],
17
+ Annotated[LoggerConfigBase, logger_registry.DynamicResolution()],
23
18
  )
@@ -5,11 +5,14 @@ from typing import Any, Literal
5
5
 
6
6
  import numpy as np
7
7
  from lightning.pytorch.loggers import Logger
8
+ from typing_extensions import final
8
9
 
9
- from ._base import BaseLoggerConfig
10
+ from .base import LoggerConfigBase, logger_registry
10
11
 
11
12
 
12
- class ActSaveLoggerConfig(BaseLoggerConfig):
13
+ @final
14
+ @logger_registry.register
15
+ class ActSaveLoggerConfig(LoggerConfigBase):
13
16
  name: Literal["actsave"] = "actsave"
14
17
 
15
18
  def create_logger(self, trainer_config):
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
10
10
  from ..trainer._config import TrainerConfig
11
11
 
12
12
 
13
- class BaseLoggerConfig(C.Config, ABC):
13
+ class LoggerConfigBase(C.Config, ABC):
14
14
  enabled: bool = True
15
15
  """Enable this logger."""
16
16
 
@@ -29,3 +29,6 @@ class BaseLoggerConfig(C.Config, ABC):
29
29
 
30
30
  def __bool__(self):
31
31
  return self.enabled
32
+
33
+
34
+ logger_registry = C.Registry(LoggerConfigBase, discriminator="name")
nshtrainer/loggers/csv.py CHANGED
@@ -2,12 +2,14 @@ from __future__ import annotations
2
2
 
3
3
  from typing import Literal
4
4
 
5
- from typing_extensions import override
5
+ from typing_extensions import final, override
6
6
 
7
- from ._base import BaseLoggerConfig
7
+ from .base import LoggerConfigBase, logger_registry
8
8
 
9
9
 
10
- class CSVLoggerConfig(BaseLoggerConfig):
10
+ @final
11
+ @logger_registry.register
12
+ class CSVLoggerConfig(LoggerConfigBase):
11
13
  name: Literal["csv"] = "csv"
12
14
 
13
15
  enabled: bool = True
@@ -4,9 +4,9 @@ import logging
4
4
  from typing import Literal
5
5
 
6
6
  import nshconfig as C
7
- from typing_extensions import override
7
+ from typing_extensions import final, override
8
8
 
9
- from ._base import BaseLoggerConfig
9
+ from .base import LoggerConfigBase, logger_registry
10
10
 
11
11
  log = logging.getLogger(__name__)
12
12
 
@@ -30,7 +30,9 @@ def _tensorboard_available():
30
30
  return False
31
31
 
32
32
 
33
- class TensorboardLoggerConfig(BaseLoggerConfig):
33
+ @final
34
+ @logger_registry.register
35
+ class TensorboardLoggerConfig(LoggerConfigBase):
34
36
  name: Literal["tensorboard"] = "tensorboard"
35
37
 
36
38
  enabled: bool = C.Field(default_factory=lambda: _tensorboard_available())
@@ -7,12 +7,12 @@ from typing import TYPE_CHECKING, Literal
7
7
  import nshconfig as C
8
8
  from lightning.pytorch import Callback, LightningModule, Trainer
9
9
  from packaging import version
10
- from typing_extensions import assert_never, override
10
+ from typing_extensions import assert_never, final, override
11
11
 
12
12
  from ..callbacks.base import CallbackConfigBase
13
13
  from ..callbacks.wandb_upload_code import WandbUploadCodeCallbackConfig
14
14
  from ..callbacks.wandb_watch import WandbWatchCallbackConfig
15
- from ._base import BaseLoggerConfig
15
+ from .base import LoggerConfigBase, logger_registry
16
16
 
17
17
  if TYPE_CHECKING:
18
18
  from ..trainer._config import TrainerConfig
@@ -73,7 +73,9 @@ class FinishWandbOnTeardownCallback(Callback):
73
73
  wandb.finish()
74
74
 
75
75
 
76
- class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
76
+ @final
77
+ @logger_registry.register
78
+ class WandbLoggerConfig(CallbackConfigBase, LoggerConfigBase):
77
79
  name: Literal["wandb"] = "wandb"
78
80
 
79
81
  enabled: bool = C.Field(default_factory=lambda: _wandb_available())
@@ -5,8 +5,8 @@ from typing import Annotated
5
5
  import nshconfig as C
6
6
  from typing_extensions import TypeAliasType
7
7
 
8
- from ._base import LRSchedulerConfigBase as LRSchedulerConfigBase
9
- from ._base import LRSchedulerMetadata as LRSchedulerMetadata
8
+ from .base import LRSchedulerConfigBase as LRSchedulerConfigBase
9
+ from .base import LRSchedulerMetadata as LRSchedulerMetadata
10
10
  from .linear_warmup_cosine import (
11
11
  LinearWarmupCosineAnnealingLR as LinearWarmupCosineAnnealingLR,
12
12
  )