nshtrainer 1.0.0b32__py3-none-any.whl → 1.0.0b36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. nshtrainer/__init__.py +1 -0
  2. nshtrainer/_hf_hub.py +8 -1
  3. nshtrainer/callbacks/__init__.py +10 -23
  4. nshtrainer/callbacks/actsave.py +6 -2
  5. nshtrainer/callbacks/base.py +3 -0
  6. nshtrainer/callbacks/checkpoint/__init__.py +0 -4
  7. nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
  8. nshtrainer/callbacks/checkpoint/last_checkpoint.py +72 -2
  9. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -2
  10. nshtrainer/callbacks/debug_flag.py +4 -2
  11. nshtrainer/callbacks/directory_setup.py +23 -21
  12. nshtrainer/callbacks/early_stopping.py +4 -2
  13. nshtrainer/callbacks/ema.py +29 -27
  14. nshtrainer/callbacks/finite_checks.py +21 -19
  15. nshtrainer/callbacks/gradient_skipping.py +29 -27
  16. nshtrainer/callbacks/log_epoch.py +4 -2
  17. nshtrainer/callbacks/lr_monitor.py +6 -1
  18. nshtrainer/callbacks/norm_logging.py +36 -34
  19. nshtrainer/callbacks/print_table.py +20 -18
  20. nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
  21. nshtrainer/callbacks/shared_parameters.py +9 -7
  22. nshtrainer/callbacks/timer.py +12 -10
  23. nshtrainer/callbacks/wandb_upload_code.py +4 -2
  24. nshtrainer/callbacks/wandb_watch.py +4 -2
  25. nshtrainer/configs/__init__.py +4 -8
  26. nshtrainer/configs/_hf_hub/__init__.py +2 -0
  27. nshtrainer/configs/callbacks/__init__.py +4 -8
  28. nshtrainer/configs/callbacks/actsave/__init__.py +2 -0
  29. nshtrainer/configs/callbacks/base/__init__.py +2 -0
  30. nshtrainer/configs/callbacks/checkpoint/__init__.py +4 -6
  31. nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +4 -0
  32. nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +4 -0
  33. nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -0
  34. nshtrainer/configs/callbacks/debug_flag/__init__.py +2 -0
  35. nshtrainer/configs/callbacks/directory_setup/__init__.py +2 -0
  36. nshtrainer/configs/callbacks/early_stopping/__init__.py +2 -0
  37. nshtrainer/configs/callbacks/ema/__init__.py +2 -0
  38. nshtrainer/configs/callbacks/finite_checks/__init__.py +2 -0
  39. nshtrainer/configs/callbacks/gradient_skipping/__init__.py +4 -0
  40. nshtrainer/configs/callbacks/log_epoch/__init__.py +2 -0
  41. nshtrainer/configs/callbacks/lr_monitor/__init__.py +2 -0
  42. nshtrainer/configs/callbacks/norm_logging/__init__.py +2 -0
  43. nshtrainer/configs/callbacks/print_table/__init__.py +2 -0
  44. nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +4 -0
  45. nshtrainer/configs/callbacks/shared_parameters/__init__.py +4 -0
  46. nshtrainer/configs/callbacks/timer/__init__.py +2 -0
  47. nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +4 -0
  48. nshtrainer/configs/callbacks/wandb_watch/__init__.py +2 -0
  49. nshtrainer/configs/trainer/__init__.py +2 -4
  50. nshtrainer/configs/trainer/_config/__init__.py +0 -8
  51. nshtrainer/data/datamodule.py +0 -2
  52. nshtrainer/model/base.py +0 -2
  53. nshtrainer/trainer/__init__.py +3 -2
  54. nshtrainer/trainer/_config.py +4 -42
  55. {nshtrainer-1.0.0b32.dist-info → nshtrainer-1.0.0b36.dist-info}/METADATA +1 -1
  56. {nshtrainer-1.0.0b32.dist-info → nshtrainer-1.0.0b36.dist-info}/RECORD +57 -60
  57. nshtrainer/callbacks/checkpoint/time_checkpoint.py +0 -114
  58. nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +0 -19
  59. nshtrainer/util/hparams.py +0 -18
  60. {nshtrainer-1.0.0b32.dist-info → nshtrainer-1.0.0b36.dist-info}/WHEEL +0 -0

nshtrainer/callbacks/norm_logging.py

@@ -7,13 +7,47 @@ import torch
 import torch.nn as nn
 from lightning.pytorch import Callback, LightningModule, Trainer
 from torch.optim import Optimizer
-from typing_extensions import override
+from typing_extensions import final, override
 
-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry
 
 log = logging.getLogger(__name__)
 
 
+@final
+@callback_registry.register
+class NormLoggingCallbackConfig(CallbackConfigBase):
+    name: Literal["norm_logging"] = "norm_logging"
+
+    log_grad_norm: bool | str | float = False
+    """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
+    log_grad_norm_per_param: bool | str | float = False
+    """If enabled, will log the gradient norm for each model parameter to the logger."""
+
+    log_param_norm: bool | str | float = False
+    """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
+    log_param_norm_per_param: bool | str | float = False
+    """If enabled, will log the parameter norm for each model parameter to the logger."""
+
+    def __bool__(self):
+        return any(
+            v
+            for v in (
+                self.log_grad_norm,
+                self.log_grad_norm_per_param,
+                self.log_param_norm,
+                self.log_param_norm_per_param,
+            )
+        )
+
+    @override
+    def create_callbacks(self, trainer_config):
+        if not self:
+            return
+
+        yield NormLoggingCallback(self)
+
+
 def grad_norm(
     module: nn.Module,
     norm_type: float | int | str,
@@ -155,35 +189,3 @@ class NormLoggingCallback(Callback):
         self._perform_norm_logging(
             pl_module, optimizer, prefix=f"train/optimizer_{i}/"
         )
-
-
-class NormLoggingCallbackConfig(CallbackConfigBase):
-    name: Literal["norm_logging"] = "norm_logging"
-
-    log_grad_norm: bool | str | float = False
-    """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
-    log_grad_norm_per_param: bool | str | float = False
-    """If enabled, will log the gradient norm for each model parameter to the logger."""
-
-    log_param_norm: bool | str | float = False
-    """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
-    log_param_norm_per_param: bool | str | float = False
-    """If enabled, will log the parameter norm for each model parameter to the logger."""
-
-    def __bool__(self):
-        return any(
-            v
-            for v in (
-                self.log_grad_norm,
-                self.log_grad_norm_per_param,
-                self.log_param_norm,
-                self.log_param_norm_per_param,
-            )
-        )
-
-    @override
-    def create_callbacks(self, trainer_config):
-        if not self:
-            return
-
-        yield NormLoggingCallback(self)
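
The same refactor repeats across the callback modules below and is shown in full only here: each `*CallbackConfig` class moves above the callback it configures, is marked `@final`, and registers itself with `callback_registry` at import time via a decorator. As a minimal sketch of how a decorator-based registry like this can work, keyed on each config's `name` tag (illustrative only; the real `callback_registry` lives in `nshtrainer/callbacks/base.py` and its internals are not part of this diff):

# Sketch, not nshtrainer's actual implementation.
class CallbackRegistry:
    def __init__(self) -> None:
        self._by_name: dict[str, type] = {}

    def register(self, config_cls: type) -> type:
        # Used as `@callback_registry.register`: record the class under its
        # `name` tag, then return it unchanged so the decorator is transparent.
        self._by_name[config_cls.name] = config_cls
        return config_cls

    def lookup(self, name: str) -> type:
        return self._by_name[name]

callback_registry = CallbackRegistry()

@callback_registry.register
class NormLoggingConfigSketch:
    name = "norm_logging"

assert callback_registry.lookup("norm_logging") is NormLoggingConfigSketch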

nshtrainer/callbacks/print_table.py

@@ -9,13 +9,31 @@ from typing import Literal
 import torch
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from typing_extensions import override
+from typing_extensions import final, override
 
-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry
 
 log = logging.getLogger(__name__)
 
 
+@final
+@callback_registry.register
+class PrintTableMetricsCallbackConfig(CallbackConfigBase):
+    """Configuration class for PrintTableMetricsCallback."""
+
+    name: Literal["print_table_metrics"] = "print_table_metrics"
+
+    enabled: bool = True
+    """Whether to enable the callback or not."""
+
+    metric_patterns: list[str] | None = None
+    """List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""
+
+    @override
+    def create_callbacks(self, trainer_config):
+        yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)
+
+
 class PrintTableMetricsCallback(Callback):
     """Prints a table with the metrics in columns on every epoch end."""
 
@@ -74,19 +92,3 @@ class PrintTableMetricsCallback(Callback):
         table.add_row(*values)
 
         return table
-
-
-class PrintTableMetricsCallbackConfig(CallbackConfigBase):
-    """Configuration class for PrintTableMetricsCallback."""
-
-    name: Literal["print_table_metrics"] = "print_table_metrics"
-
-    enabled: bool = True
-    """Whether to enable the callback or not."""
-
-    metric_patterns: list[str] | None = None
-    """List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""
-
-    @override
-    def create_callbacks(self, trainer_config):
-        yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)

nshtrainer/callbacks/rlp_sanity_checks.py

@@ -11,13 +11,15 @@ from lightning.pytorch.utilities.types import (
     LRSchedulerConfigType,
     LRSchedulerTypeUnion,
 )
-from typing_extensions import Protocol, override, runtime_checkable
+from typing_extensions import Protocol, final, override, runtime_checkable
 
-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry
 
 log = logging.getLogger(__name__)
 
 
+@final
+@callback_registry.register
 class RLPSanityChecksCallbackConfig(CallbackConfigBase):
     """
     If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:

nshtrainer/callbacks/shared_parameters.py

@@ -7,18 +7,15 @@ from typing import Literal, Protocol, runtime_checkable
 import torch.nn as nn
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from typing_extensions import TypeAliasType, override
+from typing_extensions import TypeAliasType, final, override
 
-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry
 
 log = logging.getLogger(__name__)
 
 
-def _parameters_to_names(parameters: Iterable[nn.Parameter], model: nn.Module):
-    mapping = {id(p): n for n, p in model.named_parameters()}
-    return [mapping[id(p)] for p in parameters]
-
-
+@final
+@callback_registry.register
 class SharedParametersCallbackConfig(CallbackConfigBase):
     """A callback that allows scaling the gradients of shared parameters that
     are registered in the ``self.shared_parameters`` list of the root module.
@@ -34,6 +31,11 @@ class SharedParametersCallbackConfig(CallbackConfigBase):
         yield SharedParametersCallback(self)
 
 
+def _parameters_to_names(parameters: Iterable[nn.Parameter], model: nn.Module):
+    mapping = {id(p): n for n, p in model.named_parameters()}
+    return [mapping[id(p)] for p in parameters]
+
+
 SharedParametersList = TypeAliasType(
     "SharedParametersList", list[tuple[nn.Parameter, int | float]]
 )

nshtrainer/callbacks/timer.py

@@ -7,13 +7,23 @@ from typing import Any, Literal
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.utilities.types import STEP_OUTPUT
-from typing_extensions import override
+from typing_extensions import final, override
 
-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry
 
 log = logging.getLogger(__name__)
 
 
+@final
+@callback_registry.register
+class EpochTimerCallbackConfig(CallbackConfigBase):
+    name: Literal["epoch_timer"] = "epoch_timer"
+
+    @override
+    def create_callbacks(self, trainer_config):
+        yield EpochTimerCallback()
+
+
 class EpochTimerCallback(Callback):
     def __init__(self):
         super().__init__()
@@ -149,11 +159,3 @@ class EpochTimerCallback(Callback):
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         self._elapsed_time = state_dict["elapsed_time"]
         self._total_batches = state_dict["total_batches"]
-
-
-class EpochTimerCallbackConfig(CallbackConfigBase):
-    name: Literal["epoch_timer"] = "epoch_timer"
-
-    @override
-    def create_callbacks(self, trainer_config):
-        yield EpochTimerCallback()
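
Each registered config also exposes `create_callbacks`, a generator hook that yields zero or more Lightning callbacks: `EpochTimerCallbackConfig` above always yields one, while `NormLoggingCallbackConfig` defines `__bool__` and yields nothing when every flag is left `False`. A hedged sketch of how a consumer might flatten a list of such configs (`collect_callbacks` is an illustrative helper, not nshtrainer's actual API):

from lightning.pytorch import Callback

def collect_callbacks(configs, trainer_config) -> list[Callback]:
    callbacks: list[Callback] = []
    for config in configs:
        # Each create_callbacks call is a generator; extend() drains it,
        # so a disabled config simply contributes no callbacks.
        callbacks.extend(config.create_callbacks(trainer_config))
    return callbacks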

nshtrainer/callbacks/wandb_upload_code.py

@@ -9,13 +9,15 @@ from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.loggers import WandbLogger
 from nshrunner._env import SNAPSHOT_DIR
-from typing_extensions import override
+from typing_extensions import final, override
 
-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry
 
 log = logging.getLogger(__name__)
 
 
+@final
+@callback_registry.register
 class WandbUploadCodeCallbackConfig(CallbackConfigBase):
     name: Literal["wandb_upload_code"] = "wandb_upload_code"
 

nshtrainer/callbacks/wandb_watch.py

@@ -7,13 +7,15 @@ import torch.nn as nn
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.loggers import WandbLogger
-from typing_extensions import override
+from typing_extensions import final, override
 
-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry
 
 log = logging.getLogger(__name__)
 
 
+@final
+@callback_registry.register
 class WandbWatchCallbackConfig(CallbackConfigBase):
     name: Literal["wandb_watch"] = "wandb_watch"
 

nshtrainer/configs/__init__.py

@@ -5,6 +5,7 @@ __codegen__ = True
 from nshtrainer import MetricConfig as MetricConfig
 from nshtrainer import TrainerConfig as TrainerConfig
 from nshtrainer import accelerator_registry as accelerator_registry
+from nshtrainer import callback_registry as callback_registry
 from nshtrainer import plugin_registry as plugin_registry
 from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
 from nshtrainer._directory import DirectoryConfig as DirectoryConfig
@@ -13,6 +14,7 @@ from nshtrainer._hf_hub import (
     HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
 )
 from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
+from nshtrainer.callbacks import ActSaveConfig as ActSaveConfig
 from nshtrainer.callbacks import (
     BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
 )
@@ -35,6 +37,7 @@ from nshtrainer.callbacks import (
 from nshtrainer.callbacks import (
     LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
 )
+from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
 from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
 from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
 from nshtrainer.callbacks import (
@@ -49,14 +52,10 @@ from nshtrainer.callbacks import (
 from nshtrainer.callbacks import (
     SharedParametersCallbackConfig as SharedParametersCallbackConfig,
 )
-from nshtrainer.callbacks import (
-    TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
-)
 from nshtrainer.callbacks import (
     WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
 )
 from nshtrainer.callbacks import WandbWatchCallbackConfig as WandbWatchCallbackConfig
-from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
 from nshtrainer.callbacks.checkpoint._base import (
     BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
 )
@@ -106,9 +105,6 @@ from nshtrainer.trainer._config import (
 from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
 from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
 from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
-from nshtrainer.trainer._config import (
-    LearningRateMonitorConfig as LearningRateMonitorConfig,
-)
 from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
 from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
@@ -320,7 +316,6 @@ __all__ = [
     "SwishNonlinearityConfig",
     "TanhNonlinearityConfig",
     "TensorboardLoggerConfig",
-    "TimeCheckpointCallbackConfig",
     "TorchCheckpointIOPlugin",
     "TorchElasticEnvironmentPlugin",
     "TorchSyncBatchNormPlugin",
@@ -337,6 +332,7 @@ __all__ = [
     "_directory",
     "_hf_hub",
     "accelerator_registry",
+    "callback_registry",
     "callbacks",
     "loggers",
     "lr_scheduler",

nshtrainer/configs/_hf_hub/__init__.py

@@ -7,9 +7,11 @@ from nshtrainer._hf_hub import (
     HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
 )
 from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
+from nshtrainer._hf_hub import callback_registry as callback_registry
 
 __all__ = [
     "CallbackConfigBase",
     "HuggingFaceHubAutoCreateConfig",
     "HuggingFaceHubConfig",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/__init__.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 __codegen__ = True
 
+from nshtrainer.callbacks import ActSaveConfig as ActSaveConfig
 from nshtrainer.callbacks import (
     BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
 )
@@ -25,6 +26,7 @@ from nshtrainer.callbacks import (
 from nshtrainer.callbacks import (
     LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
 )
+from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
 from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
 from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
 from nshtrainer.callbacks import (
@@ -39,14 +41,11 @@ from nshtrainer.callbacks import (
 from nshtrainer.callbacks import (
     SharedParametersCallbackConfig as SharedParametersCallbackConfig,
 )
-from nshtrainer.callbacks import (
-    TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
-)
 from nshtrainer.callbacks import (
     WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
 )
 from nshtrainer.callbacks import WandbWatchCallbackConfig as WandbWatchCallbackConfig
-from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
+from nshtrainer.callbacks import callback_registry as callback_registry
 from nshtrainer.callbacks.checkpoint._base import (
     BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
 )
@@ -54,9 +53,6 @@ from nshtrainer.callbacks.checkpoint._base import (
     CheckpointMetadata as CheckpointMetadata,
 )
 from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
-from nshtrainer.callbacks.lr_monitor import (
-    LearningRateMonitorConfig as LearningRateMonitorConfig,
-)
 
 from . import actsave as actsave
 from . import base as base
@@ -100,11 +96,11 @@ __all__ = [
     "PrintTableMetricsCallbackConfig",
     "RLPSanityChecksCallbackConfig",
     "SharedParametersCallbackConfig",
-    "TimeCheckpointCallbackConfig",
     "WandbUploadCodeCallbackConfig",
     "WandbWatchCallbackConfig",
     "actsave",
     "base",
+    "callback_registry",
     "checkpoint",
     "debug_flag",
     "directory_setup",

nshtrainer/configs/callbacks/actsave/__init__.py

@@ -4,8 +4,10 @@ __codegen__ = True
 
 from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
 from nshtrainer.callbacks.actsave import CallbackConfigBase as CallbackConfigBase
+from nshtrainer.callbacks.actsave import callback_registry as callback_registry
 
 __all__ = [
     "ActSaveConfig",
     "CallbackConfigBase",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/base/__init__.py

@@ -3,7 +3,9 @@ from __future__ import annotations
 __codegen__ = True
 
 from nshtrainer.callbacks.base import CallbackConfigBase as CallbackConfigBase
+from nshtrainer.callbacks.base import callback_registry as callback_registry
 
 __all__ = [
     "CallbackConfigBase",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/checkpoint/__init__.py

@@ -11,9 +11,6 @@ from nshtrainer.callbacks.checkpoint import (
 from nshtrainer.callbacks.checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
-from nshtrainer.callbacks.checkpoint import (
-    TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
-)
 from nshtrainer.callbacks.checkpoint._base import (
     BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
 )
@@ -24,12 +21,14 @@ from nshtrainer.callbacks.checkpoint._base import (
     CheckpointMetadata as CheckpointMetadata,
 )
 from nshtrainer.callbacks.checkpoint.best_checkpoint import MetricConfig as MetricConfig
+from nshtrainer.callbacks.checkpoint.best_checkpoint import (
+    callback_registry as callback_registry,
+)
 
 from . import _base as _base
 from . import best_checkpoint as best_checkpoint
 from . import last_checkpoint as last_checkpoint
 from . import on_exception_checkpoint as on_exception_checkpoint
-from . import time_checkpoint as time_checkpoint
 
 __all__ = [
     "BaseCheckpointCallbackConfig",
@@ -39,10 +38,9 @@ __all__ = [
     "LastCheckpointCallbackConfig",
     "MetricConfig",
     "OnExceptionCheckpointCallbackConfig",
-    "TimeCheckpointCallbackConfig",
     "_base",
     "best_checkpoint",
+    "callback_registry",
     "last_checkpoint",
     "on_exception_checkpoint",
-    "time_checkpoint",
 ]

nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py

@@ -12,10 +12,14 @@ from nshtrainer.callbacks.checkpoint.best_checkpoint import (
     CheckpointMetadata as CheckpointMetadata,
 )
 from nshtrainer.callbacks.checkpoint.best_checkpoint import MetricConfig as MetricConfig
+from nshtrainer.callbacks.checkpoint.best_checkpoint import (
+    callback_registry as callback_registry,
+)
 
 __all__ = [
     "BaseCheckpointCallbackConfig",
     "BestCheckpointCallbackConfig",
     "CheckpointMetadata",
     "MetricConfig",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py

@@ -11,9 +11,13 @@ from nshtrainer.callbacks.checkpoint.last_checkpoint import (
 from nshtrainer.callbacks.checkpoint.last_checkpoint import (
     LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
 )
+from nshtrainer.callbacks.checkpoint.last_checkpoint import (
+    callback_registry as callback_registry,
+)
 
 __all__ = [
     "BaseCheckpointCallbackConfig",
     "CheckpointMetadata",
     "LastCheckpointCallbackConfig",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py

@@ -8,8 +8,12 @@ from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
 from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
+from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
+    callback_registry as callback_registry,
+)
 
 __all__ = [
     "CallbackConfigBase",
     "OnExceptionCheckpointCallbackConfig",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/debug_flag/__init__.py

@@ -6,8 +6,10 @@ from nshtrainer.callbacks.debug_flag import CallbackConfigBase as CallbackConfigBase
 from nshtrainer.callbacks.debug_flag import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
+from nshtrainer.callbacks.debug_flag import callback_registry as callback_registry
 
 __all__ = [
     "CallbackConfigBase",
     "DebugFlagCallbackConfig",
+    "callback_registry",
 ]
@@ -8,8 +8,10 @@ from nshtrainer.callbacks.directory_setup import (
8
8
  from nshtrainer.callbacks.directory_setup import (
9
9
  DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
10
10
  )
11
+ from nshtrainer.callbacks.directory_setup import callback_registry as callback_registry
11
12
 
12
13
  __all__ = [
13
14
  "CallbackConfigBase",
14
15
  "DirectorySetupCallbackConfig",
16
+ "callback_registry",
15
17
  ]

nshtrainer/configs/callbacks/early_stopping/__init__.py

@@ -7,9 +7,11 @@ from nshtrainer.callbacks.early_stopping import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
 from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
+from nshtrainer.callbacks.early_stopping import callback_registry as callback_registry
 
 __all__ = [
     "CallbackConfigBase",
     "EarlyStoppingCallbackConfig",
     "MetricConfig",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/ema/__init__.py

@@ -4,8 +4,10 @@ __codegen__ = True
 
 from nshtrainer.callbacks.ema import CallbackConfigBase as CallbackConfigBase
 from nshtrainer.callbacks.ema import EMACallbackConfig as EMACallbackConfig
+from nshtrainer.callbacks.ema import callback_registry as callback_registry
 
 __all__ = [
     "CallbackConfigBase",
     "EMACallbackConfig",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/finite_checks/__init__.py

@@ -6,8 +6,10 @@ from nshtrainer.callbacks.finite_checks import CallbackConfigBase as CallbackConfigBase
 from nshtrainer.callbacks.finite_checks import (
     FiniteChecksCallbackConfig as FiniteChecksCallbackConfig,
 )
+from nshtrainer.callbacks.finite_checks import callback_registry as callback_registry
 
 __all__ = [
     "CallbackConfigBase",
     "FiniteChecksCallbackConfig",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/gradient_skipping/__init__.py

@@ -8,8 +8,12 @@ from nshtrainer.callbacks.gradient_skipping import (
 from nshtrainer.callbacks.gradient_skipping import (
     GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
 )
+from nshtrainer.callbacks.gradient_skipping import (
+    callback_registry as callback_registry,
+)
 
 __all__ = [
     "CallbackConfigBase",
     "GradientSkippingCallbackConfig",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/log_epoch/__init__.py

@@ -6,8 +6,10 @@ from nshtrainer.callbacks.log_epoch import CallbackConfigBase as CallbackConfigBase
 from nshtrainer.callbacks.log_epoch import (
     LogEpochCallbackConfig as LogEpochCallbackConfig,
 )
+from nshtrainer.callbacks.log_epoch import callback_registry as callback_registry
 
 __all__ = [
     "CallbackConfigBase",
     "LogEpochCallbackConfig",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/lr_monitor/__init__.py

@@ -6,8 +6,10 @@ from nshtrainer.callbacks.lr_monitor import CallbackConfigBase as CallbackConfigBase
 from nshtrainer.callbacks.lr_monitor import (
     LearningRateMonitorConfig as LearningRateMonitorConfig,
 )
+from nshtrainer.callbacks.lr_monitor import callback_registry as callback_registry
 
 __all__ = [
     "CallbackConfigBase",
     "LearningRateMonitorConfig",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/norm_logging/__init__.py

@@ -6,8 +6,10 @@ from nshtrainer.callbacks.norm_logging import CallbackConfigBase as CallbackConfigBase
 from nshtrainer.callbacks.norm_logging import (
     NormLoggingCallbackConfig as NormLoggingCallbackConfig,
 )
+from nshtrainer.callbacks.norm_logging import callback_registry as callback_registry
 
 __all__ = [
     "CallbackConfigBase",
     "NormLoggingCallbackConfig",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/print_table/__init__.py

@@ -6,8 +6,10 @@ from nshtrainer.callbacks.print_table import CallbackConfigBase as CallbackConfigBase
 from nshtrainer.callbacks.print_table import (
     PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
 )
+from nshtrainer.callbacks.print_table import callback_registry as callback_registry
 
 __all__ = [
     "CallbackConfigBase",
     "PrintTableMetricsCallbackConfig",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py

@@ -8,8 +8,12 @@ from nshtrainer.callbacks.rlp_sanity_checks import (
 from nshtrainer.callbacks.rlp_sanity_checks import (
     RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
 )
+from nshtrainer.callbacks.rlp_sanity_checks import (
+    callback_registry as callback_registry,
+)
 
 __all__ = [
     "CallbackConfigBase",
     "RLPSanityChecksCallbackConfig",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/shared_parameters/__init__.py

@@ -8,8 +8,12 @@ from nshtrainer.callbacks.shared_parameters import (
 from nshtrainer.callbacks.shared_parameters import (
     SharedParametersCallbackConfig as SharedParametersCallbackConfig,
 )
+from nshtrainer.callbacks.shared_parameters import (
+    callback_registry as callback_registry,
+)
 
 __all__ = [
     "CallbackConfigBase",
     "SharedParametersCallbackConfig",
+    "callback_registry",
 ]

nshtrainer/configs/callbacks/timer/__init__.py

@@ -6,8 +6,10 @@ from nshtrainer.callbacks.timer import CallbackConfigBase as CallbackConfigBase
 from nshtrainer.callbacks.timer import (
     EpochTimerCallbackConfig as EpochTimerCallbackConfig,
 )
+from nshtrainer.callbacks.timer import callback_registry as callback_registry
 
 __all__ = [
     "CallbackConfigBase",
     "EpochTimerCallbackConfig",
+    "callback_registry",
 ]