nshtrainer 1.0.0b33__py3-none-any.whl → 1.0.0b37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. nshtrainer/__init__.py +1 -0
  2. nshtrainer/_directory.py +3 -1
  3. nshtrainer/_hf_hub.py +8 -1
  4. nshtrainer/callbacks/__init__.py +10 -23
  5. nshtrainer/callbacks/actsave.py +6 -2
  6. nshtrainer/callbacks/base.py +3 -0
  7. nshtrainer/callbacks/checkpoint/__init__.py +0 -4
  8. nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
  9. nshtrainer/callbacks/checkpoint/last_checkpoint.py +72 -2
  10. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -2
  11. nshtrainer/callbacks/debug_flag.py +4 -2
  12. nshtrainer/callbacks/directory_setup.py +23 -21
  13. nshtrainer/callbacks/early_stopping.py +4 -2
  14. nshtrainer/callbacks/ema.py +29 -27
  15. nshtrainer/callbacks/finite_checks.py +21 -19
  16. nshtrainer/callbacks/gradient_skipping.py +29 -27
  17. nshtrainer/callbacks/log_epoch.py +4 -2
  18. nshtrainer/callbacks/lr_monitor.py +6 -1
  19. nshtrainer/callbacks/norm_logging.py +36 -34
  20. nshtrainer/callbacks/print_table.py +20 -18
  21. nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
  22. nshtrainer/callbacks/shared_parameters.py +9 -7
  23. nshtrainer/callbacks/timer.py +12 -10
  24. nshtrainer/callbacks/wandb_upload_code.py +4 -2
  25. nshtrainer/callbacks/wandb_watch.py +4 -2
  26. nshtrainer/configs/__init__.py +16 -12
  27. nshtrainer/configs/_hf_hub/__init__.py +2 -0
  28. nshtrainer/configs/callbacks/__init__.py +4 -8
  29. nshtrainer/configs/callbacks/actsave/__init__.py +2 -0
  30. nshtrainer/configs/callbacks/base/__init__.py +2 -0
  31. nshtrainer/configs/callbacks/checkpoint/__init__.py +4 -6
  32. nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +4 -0
  33. nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +4 -0
  34. nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -0
  35. nshtrainer/configs/callbacks/debug_flag/__init__.py +2 -0
  36. nshtrainer/configs/callbacks/directory_setup/__init__.py +2 -0
  37. nshtrainer/configs/callbacks/early_stopping/__init__.py +2 -0
  38. nshtrainer/configs/callbacks/ema/__init__.py +2 -0
  39. nshtrainer/configs/callbacks/finite_checks/__init__.py +2 -0
  40. nshtrainer/configs/callbacks/gradient_skipping/__init__.py +4 -0
  41. nshtrainer/configs/callbacks/log_epoch/__init__.py +2 -0
  42. nshtrainer/configs/callbacks/lr_monitor/__init__.py +2 -0
  43. nshtrainer/configs/callbacks/norm_logging/__init__.py +2 -0
  44. nshtrainer/configs/callbacks/print_table/__init__.py +2 -0
  45. nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +4 -0
  46. nshtrainer/configs/callbacks/shared_parameters/__init__.py +4 -0
  47. nshtrainer/configs/callbacks/timer/__init__.py +2 -0
  48. nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +4 -0
  49. nshtrainer/configs/callbacks/wandb_watch/__init__.py +2 -0
  50. nshtrainer/configs/loggers/__init__.py +6 -4
  51. nshtrainer/configs/loggers/actsave/__init__.py +4 -2
  52. nshtrainer/configs/loggers/base/__init__.py +11 -0
  53. nshtrainer/configs/loggers/csv/__init__.py +4 -2
  54. nshtrainer/configs/loggers/tensorboard/__init__.py +4 -2
  55. nshtrainer/configs/loggers/wandb/__init__.py +4 -2
  56. nshtrainer/configs/lr_scheduler/__init__.py +4 -2
  57. nshtrainer/configs/lr_scheduler/base/__init__.py +11 -0
  58. nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +4 -0
  59. nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +4 -0
  60. nshtrainer/configs/nn/__init__.py +4 -2
  61. nshtrainer/configs/nn/mlp/__init__.py +2 -2
  62. nshtrainer/configs/nn/nonlinearity/__init__.py +4 -2
  63. nshtrainer/configs/optimizer/__init__.py +2 -0
  64. nshtrainer/configs/trainer/__init__.py +4 -6
  65. nshtrainer/configs/trainer/_config/__init__.py +2 -10
  66. nshtrainer/loggers/__init__.py +3 -8
  67. nshtrainer/loggers/actsave.py +5 -2
  68. nshtrainer/loggers/{_base.py → base.py} +4 -1
  69. nshtrainer/loggers/csv.py +5 -3
  70. nshtrainer/loggers/tensorboard.py +5 -3
  71. nshtrainer/loggers/wandb.py +5 -3
  72. nshtrainer/lr_scheduler/__init__.py +2 -2
  73. nshtrainer/lr_scheduler/{_base.py → base.py} +3 -0
  74. nshtrainer/lr_scheduler/linear_warmup_cosine.py +56 -54
  75. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +4 -2
  76. nshtrainer/nn/__init__.py +1 -1
  77. nshtrainer/nn/mlp.py +4 -4
  78. nshtrainer/nn/nonlinearity.py +37 -33
  79. nshtrainer/optimizer.py +8 -2
  80. nshtrainer/trainer/__init__.py +3 -2
  81. nshtrainer/trainer/_config.py +6 -44
  82. {nshtrainer-1.0.0b33.dist-info → nshtrainer-1.0.0b37.dist-info}/METADATA +1 -1
  83. nshtrainer-1.0.0b37.dist-info/RECORD +156 -0
  84. nshtrainer/callbacks/checkpoint/time_checkpoint.py +0 -114
  85. nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +0 -19
  86. nshtrainer/configs/loggers/_base/__init__.py +0 -9
  87. nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -9
  88. nshtrainer-1.0.0b33.dist-info/RECORD +0 -158
  89. {nshtrainer-1.0.0b33.dist-info → nshtrainer-1.0.0b37.dist-info}/WHEEL +0 -0
nshtrainer/callbacks/norm_logging.py
@@ -7,13 +7,47 @@ import torch
 import torch.nn as nn
 from lightning.pytorch import Callback, LightningModule, Trainer
 from torch.optim import Optimizer
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
+class NormLoggingCallbackConfig(CallbackConfigBase):
+    name: Literal["norm_logging"] = "norm_logging"
+
+    log_grad_norm: bool | str | float = False
+    """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
+    log_grad_norm_per_param: bool | str | float = False
+    """If enabled, will log the gradient norm for each model parameter to the logger."""
+
+    log_param_norm: bool | str | float = False
+    """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
+    log_param_norm_per_param: bool | str | float = False
+    """If enabled, will log the parameter norm for each model parameter to the logger."""
+
+    def __bool__(self):
+        return any(
+            v
+            for v in (
+                self.log_grad_norm,
+                self.log_grad_norm_per_param,
+                self.log_param_norm,
+                self.log_param_norm_per_param,
+            )
+        )
+
+    @override
+    def create_callbacks(self, trainer_config):
+        if not self:
+            return
+
+        yield NormLoggingCallback(self)
+
+
 def grad_norm(
     module: nn.Module,
     norm_type: float | int | str,
@@ -155,35 +189,3 @@ class NormLoggingCallback(Callback):
             self._perform_norm_logging(
                 pl_module, optimizer, prefix=f"train/optimizer_{i}/"
             )
-
-
-class NormLoggingCallbackConfig(CallbackConfigBase):
-    name: Literal["norm_logging"] = "norm_logging"
-
-    log_grad_norm: bool | str | float = False
-    """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
-    log_grad_norm_per_param: bool | str | float = False
-    """If enabled, will log the gradient norm for each model parameter to the logger."""
-
-    log_param_norm: bool | str | float = False
-    """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
-    log_param_norm_per_param: bool | str | float = False
-    """If enabled, will log the parameter norm for each model parameter to the logger."""
-
-    def __bool__(self):
-        return any(
-            v
-            for v in (
-                self.log_grad_norm,
-                self.log_grad_norm_per_param,
-                self.log_param_norm,
-                self.log_param_norm_per_param,
-            )
-        )
-
-    @override
-    def create_callbacks(self, trainer_config):
-        if not self:
-            return
-
-        yield NormLoggingCallback(self)
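This file shows the release's central refactor in full: each callback config class moves ahead of its callback, is marked `@final`, and registers itself with the new `callback_registry` imported from `.base`. The registry implementation itself is not part of this diff; the sketch below is only a guess at the shape such a decorator implies, assuming it indexes config classes by their `name` discriminator (`ConfigRegistry`, `lookup`, and `EchoCallbackConfig` are hypothetical names, not nshtrainer API):

from typing import Literal, TypeVar

C = TypeVar("C", bound=type)


class ConfigRegistry:
    """Hypothetical sketch of the registry pattern behind callback_registry."""

    def __init__(self) -> None:
        # Maps each config's `name` discriminator to its class.
        self._configs: dict[str, type] = {}

    def register(self, cls: C) -> C:
        # Used as a class decorator, exactly like @callback_registry.register
        # in the hunks above.
        self._configs[getattr(cls, "name")] = cls
        return cls

    def lookup(self, name: str) -> type:
        return self._configs[name]


callback_registry = ConfigRegistry()


@callback_registry.register
class EchoCallbackConfig:
    name: Literal["echo"] = "echo"


assert callback_registry.lookup("echo") is EchoCallbackConfig

A registry keyed on the `name` literal would let a discriminated union of callback configs be resolved from serialized settings, which fits the codegen'd `configs` package that re-exports the registry below.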
nshtrainer/callbacks/print_table.py
@@ -9,13 +9,31 @@ from typing import Literal
 import torch
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
+class PrintTableMetricsCallbackConfig(CallbackConfigBase):
+    """Configuration class for PrintTableMetricsCallback."""
+
+    name: Literal["print_table_metrics"] = "print_table_metrics"
+
+    enabled: bool = True
+    """Whether to enable the callback or not."""
+
+    metric_patterns: list[str] | None = None
+    """List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""
+
+    @override
+    def create_callbacks(self, trainer_config):
+        yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)
+
+
 class PrintTableMetricsCallback(Callback):
     """Prints a table with the metrics in columns on every epoch end."""

@@ -74,19 +92,3 @@ class PrintTableMetricsCallback(Callback):
             table.add_row(*values)

         return table
-
-
-class PrintTableMetricsCallbackConfig(CallbackConfigBase):
-    """Configuration class for PrintTableMetricsCallback."""
-
-    name: Literal["print_table_metrics"] = "print_table_metrics"
-
-    enabled: bool = True
-    """Whether to enable the callback or not."""
-
-    metric_patterns: list[str] | None = None
-    """List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""
-
-    @override
-    def create_callbacks(self, trainer_config):
-        yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)
nshtrainer/callbacks/rlp_sanity_checks.py
@@ -11,13 +11,15 @@ from lightning.pytorch.utilities.types import (
     LRSchedulerConfigType,
     LRSchedulerTypeUnion,
 )
-from typing_extensions import Protocol, override, runtime_checkable
+from typing_extensions import Protocol, final, override, runtime_checkable

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
 class RLPSanityChecksCallbackConfig(CallbackConfigBase):
     """
     If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
nshtrainer/callbacks/shared_parameters.py
@@ -7,18 +7,15 @@ from typing import Literal, Protocol, runtime_checkable
 import torch.nn as nn
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from typing_extensions import TypeAliasType, override
+from typing_extensions import TypeAliasType, final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


-def _parameters_to_names(parameters: Iterable[nn.Parameter], model: nn.Module):
-    mapping = {id(p): n for n, p in model.named_parameters()}
-    return [mapping[id(p)] for p in parameters]
-
-
+@final
+@callback_registry.register
 class SharedParametersCallbackConfig(CallbackConfigBase):
     """A callback that allows scaling the gradients of shared parameters that
     are registered in the ``self.shared_parameters`` list of the root module.
@@ -34,6 +31,11 @@ class SharedParametersCallbackConfig(CallbackConfigBase):
         yield SharedParametersCallback(self)


+def _parameters_to_names(parameters: Iterable[nn.Parameter], model: nn.Module):
+    mapping = {id(p): n for n, p in model.named_parameters()}
+    return [mapping[id(p)] for p in parameters]
+
+
 SharedParametersList = TypeAliasType(
     "SharedParametersList", list[tuple[nn.Parameter, int | float]]
 )
nshtrainer/callbacks/timer.py
@@ -7,13 +7,23 @@ from typing import Any, Literal
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.utilities.types import STEP_OUTPUT
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
+class EpochTimerCallbackConfig(CallbackConfigBase):
+    name: Literal["epoch_timer"] = "epoch_timer"
+
+    @override
+    def create_callbacks(self, trainer_config):
+        yield EpochTimerCallback()
+
+
 class EpochTimerCallback(Callback):
     def __init__(self):
         super().__init__()
@@ -149,11 +159,3 @@ class EpochTimerCallback(Callback):
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         self._elapsed_time = state_dict["elapsed_time"]
         self._total_batches = state_dict["total_batches"]
-
-
-class EpochTimerCallbackConfig(CallbackConfigBase):
-    name: Literal["epoch_timer"] = "epoch_timer"
-
-    @override
-    def create_callbacks(self, trainer_config):
-        yield EpochTimerCallback()
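All the callback files in this group share the same factory contract: the config object is both the switch and the factory. `create_callbacks(trainer_config)` is a generator, and configs like `NormLoggingCallbackConfig` define `__bool__` so an all-defaults instance yields nothing. A minimal sketch of how these compose, assuming keyword construction works as it does for nshtrainer's pydantic-style configs (`collect_callbacks` and `trainer_config` are hypothetical, not package API):

# Sketch only: `trainer_config` stands in for whatever nshtrainer passes to
# create_callbacks, and collect_callbacks is a hypothetical helper.
from nshtrainer.callbacks.norm_logging import NormLoggingCallbackConfig
from nshtrainer.callbacks.timer import EpochTimerCallbackConfig


def collect_callbacks(configs, trainer_config):
    # Each config decides for itself whether it produces a callback;
    # a disabled config's generator simply yields nothing.
    for config in configs:
        yield from config.create_callbacks(trainer_config)


configs = [
    NormLoggingCallbackConfig(),  # falsy: every log_* field defaults to False
    NormLoggingCallbackConfig(log_grad_norm=2.0),  # truthy: logs the 2-norm
    EpochTimerCallbackConfig(),  # unconditionally yields its callback
]
# list(collect_callbacks(configs, trainer_config)) would contain one
# NormLoggingCallback and one EpochTimerCallback; the all-defaults
# NormLoggingCallbackConfig short-circuits in create_callbacks via __bool__.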
nshtrainer/callbacks/wandb_upload_code.py
@@ -9,13 +9,15 @@ from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.loggers import WandbLogger
 from nshrunner._env import SNAPSHOT_DIR
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
 class WandbUploadCodeCallbackConfig(CallbackConfigBase):
     name: Literal["wandb_upload_code"] = "wandb_upload_code"

nshtrainer/callbacks/wandb_watch.py
@@ -7,13 +7,15 @@ import torch.nn as nn
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.loggers import WandbLogger
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
 class WandbWatchCallbackConfig(CallbackConfigBase):
     name: Literal["wandb_watch"] = "wandb_watch"

nshtrainer/configs/__init__.py
@@ -5,6 +5,7 @@ __codegen__ = True
 from nshtrainer import MetricConfig as MetricConfig
 from nshtrainer import TrainerConfig as TrainerConfig
 from nshtrainer import accelerator_registry as accelerator_registry
+from nshtrainer import callback_registry as callback_registry
 from nshtrainer import plugin_registry as plugin_registry
 from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
 from nshtrainer._directory import DirectoryConfig as DirectoryConfig
@@ -13,6 +14,7 @@ from nshtrainer._hf_hub import (
     HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
 )
 from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
+from nshtrainer.callbacks import ActSaveConfig as ActSaveConfig
 from nshtrainer.callbacks import (
     BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
 )
@@ -35,6 +37,7 @@ from nshtrainer.callbacks import (
 from nshtrainer.callbacks import (
     LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
 )
+from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
 from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
 from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
 from nshtrainer.callbacks import (
@@ -49,36 +52,34 @@ from nshtrainer.callbacks import (
 from nshtrainer.callbacks import (
     SharedParametersCallbackConfig as SharedParametersCallbackConfig,
 )
-from nshtrainer.callbacks import (
-    TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
-)
 from nshtrainer.callbacks import (
     WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
 )
 from nshtrainer.callbacks import WandbWatchCallbackConfig as WandbWatchCallbackConfig
-from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
 from nshtrainer.callbacks.checkpoint._base import (
     BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
 )
 from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
-from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
 from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
 from nshtrainer.loggers import LoggerConfig as LoggerConfig
+from nshtrainer.loggers import LoggerConfigBase as LoggerConfigBase
 from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
 from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
+from nshtrainer.loggers import logger_registry as logger_registry
 from nshtrainer.lr_scheduler import (
     LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
 )
 from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
 from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
 from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
-from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
+from nshtrainer.lr_scheduler.base import lr_scheduler_registry as lr_scheduler_registry
 from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
 from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
 from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
 from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
 from nshtrainer.nn import MLPConfig as MLPConfig
 from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
+from nshtrainer.nn import NonlinearityConfigBase as NonlinearityConfigBase
 from nshtrainer.nn import PReLUConfig as PReLUConfig
 from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
 from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
@@ -91,9 +92,11 @@ from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
 from nshtrainer.nn.nonlinearity import (
     SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
 )
+from nshtrainer.nn.nonlinearity import nonlinearity_registry as nonlinearity_registry
 from nshtrainer.optimizer import AdamWConfig as AdamWConfig
 from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
 from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
+from nshtrainer.optimizer import optimizer_registry as optimizer_registry
 from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
 from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
 from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
@@ -106,9 +109,6 @@ from nshtrainer.trainer._config import (
 from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
 from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
 from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
-from nshtrainer.trainer._config import (
-    LearningRateMonitorConfig as LearningRateMonitorConfig,
-)
 from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
 from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
@@ -229,8 +229,6 @@ __all__ = [
     "AdvancedProfilerConfig",
     "AsyncCheckpointIOPlugin",
     "BaseCheckpointCallbackConfig",
-    "BaseLoggerConfig",
-    "BaseNonlinearityConfig",
     "BaseProfilerConfig",
     "BestCheckpointCallbackConfig",
     "BitsandbytesPluginConfig",
@@ -284,6 +282,7 @@ __all__ = [
     "LinearWarmupCosineDecayLRSchedulerConfig",
     "LogEpochCallbackConfig",
     "LoggerConfig",
+    "LoggerConfigBase",
     "MLPConfig",
     "MPIEnvironmentPlugin",
     "MPSAcceleratorConfig",
@@ -291,6 +290,7 @@ __all__ = [
     "MishNonlinearityConfig",
     "MixedPrecisionPluginConfig",
     "NonlinearityConfig",
+    "NonlinearityConfigBase",
     "NormLoggingCallbackConfig",
     "OnExceptionCheckpointCallbackConfig",
     "OptimizerConfig",
@@ -320,7 +320,6 @@ __all__ = [
     "SwishNonlinearityConfig",
     "TanhNonlinearityConfig",
     "TensorboardLoggerConfig",
-    "TimeCheckpointCallbackConfig",
     "TorchCheckpointIOPlugin",
     "TorchElasticEnvironmentPlugin",
     "TorchSyncBatchNormPlugin",
@@ -337,12 +336,17 @@ __all__ = [
     "_directory",
     "_hf_hub",
     "accelerator_registry",
+    "callback_registry",
     "callbacks",
+    "logger_registry",
     "loggers",
     "lr_scheduler",
+    "lr_scheduler_registry",
     "metrics",
     "nn",
+    "nonlinearity_registry",
     "optimizer",
+    "optimizer_registry",
     "plugin_registry",
     "profiler",
     "trainer",
nshtrainer/configs/_hf_hub/__init__.py
@@ -7,9 +7,11 @@ from nshtrainer._hf_hub import (
     HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
 )
 from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
+from nshtrainer._hf_hub import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
     "HuggingFaceHubAutoCreateConfig",
     "HuggingFaceHubConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/__init__.py
@@ -2,6 +2,7 @@ from __future__ import annotations

 __codegen__ = True

+from nshtrainer.callbacks import ActSaveConfig as ActSaveConfig
 from nshtrainer.callbacks import (
     BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
 )
@@ -25,6 +26,7 @@ from nshtrainer.callbacks import (
 from nshtrainer.callbacks import (
     LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
 )
+from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
 from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
 from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
 from nshtrainer.callbacks import (
@@ -39,14 +41,11 @@ from nshtrainer.callbacks import (
 from nshtrainer.callbacks import (
     SharedParametersCallbackConfig as SharedParametersCallbackConfig,
 )
-from nshtrainer.callbacks import (
-    TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
-)
 from nshtrainer.callbacks import (
     WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
 )
 from nshtrainer.callbacks import WandbWatchCallbackConfig as WandbWatchCallbackConfig
-from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
+from nshtrainer.callbacks import callback_registry as callback_registry
 from nshtrainer.callbacks.checkpoint._base import (
     BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
 )
@@ -54,9 +53,6 @@ from nshtrainer.callbacks.checkpoint._base import (
     CheckpointMetadata as CheckpointMetadata,
 )
 from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
-from nshtrainer.callbacks.lr_monitor import (
-    LearningRateMonitorConfig as LearningRateMonitorConfig,
-)

 from . import actsave as actsave
 from . import base as base
@@ -100,11 +96,11 @@ __all__ = [
     "PrintTableMetricsCallbackConfig",
     "RLPSanityChecksCallbackConfig",
     "SharedParametersCallbackConfig",
-    "TimeCheckpointCallbackConfig",
     "WandbUploadCodeCallbackConfig",
     "WandbWatchCallbackConfig",
     "actsave",
     "base",
+    "callback_registry",
     "checkpoint",
     "debug_flag",
     "directory_setup",
nshtrainer/configs/callbacks/actsave/__init__.py
@@ -4,8 +4,10 @@ __codegen__ = True

 from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
 from nshtrainer.callbacks.actsave import CallbackConfigBase as CallbackConfigBase
+from nshtrainer.callbacks.actsave import callback_registry as callback_registry

 __all__ = [
     "ActSaveConfig",
     "CallbackConfigBase",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/base/__init__.py
@@ -3,7 +3,9 @@ from __future__ import annotations
 __codegen__ = True

 from nshtrainer.callbacks.base import CallbackConfigBase as CallbackConfigBase
+from nshtrainer.callbacks.base import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/checkpoint/__init__.py
@@ -11,9 +11,6 @@ from nshtrainer.callbacks.checkpoint import (
 from nshtrainer.callbacks.checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
-from nshtrainer.callbacks.checkpoint import (
-    TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
-)
 from nshtrainer.callbacks.checkpoint._base import (
     BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
 )
@@ -24,12 +21,14 @@ from nshtrainer.callbacks.checkpoint._base import (
     CheckpointMetadata as CheckpointMetadata,
 )
 from nshtrainer.callbacks.checkpoint.best_checkpoint import MetricConfig as MetricConfig
+from nshtrainer.callbacks.checkpoint.best_checkpoint import (
+    callback_registry as callback_registry,
+)

 from . import _base as _base
 from . import best_checkpoint as best_checkpoint
 from . import last_checkpoint as last_checkpoint
 from . import on_exception_checkpoint as on_exception_checkpoint
-from . import time_checkpoint as time_checkpoint

 __all__ = [
     "BaseCheckpointCallbackConfig",
@@ -39,10 +38,9 @@ __all__ = [
     "LastCheckpointCallbackConfig",
     "MetricConfig",
     "OnExceptionCheckpointCallbackConfig",
-    "TimeCheckpointCallbackConfig",
     "_base",
     "best_checkpoint",
+    "callback_registry",
     "last_checkpoint",
     "on_exception_checkpoint",
-    "time_checkpoint",
 ]
nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py
@@ -12,10 +12,14 @@ from nshtrainer.callbacks.checkpoint.best_checkpoint import (
     CheckpointMetadata as CheckpointMetadata,
 )
 from nshtrainer.callbacks.checkpoint.best_checkpoint import MetricConfig as MetricConfig
+from nshtrainer.callbacks.checkpoint.best_checkpoint import (
+    callback_registry as callback_registry,
+)

 __all__ = [
     "BaseCheckpointCallbackConfig",
     "BestCheckpointCallbackConfig",
     "CheckpointMetadata",
     "MetricConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py
@@ -11,9 +11,13 @@ from nshtrainer.callbacks.checkpoint.last_checkpoint import (
 from nshtrainer.callbacks.checkpoint.last_checkpoint import (
     LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
 )
+from nshtrainer.callbacks.checkpoint.last_checkpoint import (
+    callback_registry as callback_registry,
+)

 __all__ = [
     "BaseCheckpointCallbackConfig",
     "CheckpointMetadata",
     "LastCheckpointCallbackConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py
@@ -8,8 +8,12 @@ from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
 from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
+from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
+    callback_registry as callback_registry,
+)

 __all__ = [
     "CallbackConfigBase",
     "OnExceptionCheckpointCallbackConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/debug_flag/__init__.py
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.debug_flag import CallbackConfigBase as CallbackConfigBase
 from nshtrainer.callbacks.debug_flag import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
+from nshtrainer.callbacks.debug_flag import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
     "DebugFlagCallbackConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/directory_setup/__init__.py
@@ -8,8 +8,10 @@ from nshtrainer.callbacks.directory_setup import (
 from nshtrainer.callbacks.directory_setup import (
     DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
 )
+from nshtrainer.callbacks.directory_setup import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
     "DirectorySetupCallbackConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/early_stopping/__init__.py
@@ -7,9 +7,11 @@ from nshtrainer.callbacks.early_stopping import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
 from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
+from nshtrainer.callbacks.early_stopping import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
     "EarlyStoppingCallbackConfig",
     "MetricConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/ema/__init__.py
@@ -4,8 +4,10 @@ __codegen__ = True

 from nshtrainer.callbacks.ema import CallbackConfigBase as CallbackConfigBase
 from nshtrainer.callbacks.ema import EMACallbackConfig as EMACallbackConfig
+from nshtrainer.callbacks.ema import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
     "EMACallbackConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/finite_checks/__init__.py
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.finite_checks import CallbackConfigBase as CallbackConfigBase
 from nshtrainer.callbacks.finite_checks import (
     FiniteChecksCallbackConfig as FiniteChecksCallbackConfig,
 )
+from nshtrainer.callbacks.finite_checks import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
     "FiniteChecksCallbackConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/gradient_skipping/__init__.py
@@ -8,8 +8,12 @@ from nshtrainer.callbacks.gradient_skipping import (
 from nshtrainer.callbacks.gradient_skipping import (
     GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
 )
+from nshtrainer.callbacks.gradient_skipping import (
+    callback_registry as callback_registry,
+)

 __all__ = [
     "CallbackConfigBase",
     "GradientSkippingCallbackConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/log_epoch/__init__.py
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.log_epoch import CallbackConfigBase as CallbackConfigBase
 from nshtrainer.callbacks.log_epoch import (
     LogEpochCallbackConfig as LogEpochCallbackConfig,
 )
+from nshtrainer.callbacks.log_epoch import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
     "LogEpochCallbackConfig",
+    "callback_registry",
 ]
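Every codegen'd configs submodule above now re-exports `callback_registry`, which suggests the registry is meant as a public extension point. A hedged sketch of registering a user-defined callback config the same way the built-ins do; the class, its `name` literal, and the yielded callback are hypothetical, while the decorator usage and the `create_callbacks` signature mirror the hunks above:

from typing import Literal

from lightning.pytorch.callbacks import LearningRateMonitor
from typing_extensions import final, override

from nshtrainer.callbacks.base import CallbackConfigBase, callback_registry


@final
@callback_registry.register
class MyCustomCallbackConfig(CallbackConfigBase):
    """Hypothetical user-defined config, registered like the built-ins."""

    name: Literal["my_custom"] = "my_custom"

    @override
    def create_callbacks(self, trainer_config):
        # Yield zero or more Lightning callbacks, mirroring the built-in configs.
        yield LearningRateMonitor()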