nshtrainer 1.0.0b32__py3-none-any.whl → 1.0.0b36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +1 -0
- nshtrainer/_hf_hub.py +8 -1
- nshtrainer/callbacks/__init__.py +10 -23
- nshtrainer/callbacks/actsave.py +6 -2
- nshtrainer/callbacks/base.py +3 -0
- nshtrainer/callbacks/checkpoint/__init__.py +0 -4
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +72 -2
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -2
- nshtrainer/callbacks/debug_flag.py +4 -2
- nshtrainer/callbacks/directory_setup.py +23 -21
- nshtrainer/callbacks/early_stopping.py +4 -2
- nshtrainer/callbacks/ema.py +29 -27
- nshtrainer/callbacks/finite_checks.py +21 -19
- nshtrainer/callbacks/gradient_skipping.py +29 -27
- nshtrainer/callbacks/log_epoch.py +4 -2
- nshtrainer/callbacks/lr_monitor.py +6 -1
- nshtrainer/callbacks/norm_logging.py +36 -34
- nshtrainer/callbacks/print_table.py +20 -18
- nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
- nshtrainer/callbacks/shared_parameters.py +9 -7
- nshtrainer/callbacks/timer.py +12 -10
- nshtrainer/callbacks/wandb_upload_code.py +4 -2
- nshtrainer/callbacks/wandb_watch.py +4 -2
- nshtrainer/configs/__init__.py +4 -8
- nshtrainer/configs/_hf_hub/__init__.py +2 -0
- nshtrainer/configs/callbacks/__init__.py +4 -8
- nshtrainer/configs/callbacks/actsave/__init__.py +2 -0
- nshtrainer/configs/callbacks/base/__init__.py +2 -0
- nshtrainer/configs/callbacks/checkpoint/__init__.py +4 -6
- nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +4 -0
- nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +4 -0
- nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -0
- nshtrainer/configs/callbacks/debug_flag/__init__.py +2 -0
- nshtrainer/configs/callbacks/directory_setup/__init__.py +2 -0
- nshtrainer/configs/callbacks/early_stopping/__init__.py +2 -0
- nshtrainer/configs/callbacks/ema/__init__.py +2 -0
- nshtrainer/configs/callbacks/finite_checks/__init__.py +2 -0
- nshtrainer/configs/callbacks/gradient_skipping/__init__.py +4 -0
- nshtrainer/configs/callbacks/log_epoch/__init__.py +2 -0
- nshtrainer/configs/callbacks/lr_monitor/__init__.py +2 -0
- nshtrainer/configs/callbacks/norm_logging/__init__.py +2 -0
- nshtrainer/configs/callbacks/print_table/__init__.py +2 -0
- nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +4 -0
- nshtrainer/configs/callbacks/shared_parameters/__init__.py +4 -0
- nshtrainer/configs/callbacks/timer/__init__.py +2 -0
- nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +4 -0
- nshtrainer/configs/callbacks/wandb_watch/__init__.py +2 -0
- nshtrainer/configs/trainer/__init__.py +2 -4
- nshtrainer/configs/trainer/_config/__init__.py +0 -8
- nshtrainer/data/datamodule.py +0 -2
- nshtrainer/model/base.py +0 -2
- nshtrainer/trainer/__init__.py +3 -2
- nshtrainer/trainer/_config.py +4 -42
- {nshtrainer-1.0.0b32.dist-info → nshtrainer-1.0.0b36.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b32.dist-info → nshtrainer-1.0.0b36.dist-info}/RECORD +57 -60
- nshtrainer/callbacks/checkpoint/time_checkpoint.py +0 -114
- nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +0 -19
- nshtrainer/util/hparams.py +0 -18
- {nshtrainer-1.0.0b32.dist-info → nshtrainer-1.0.0b36.dist-info}/WHEEL +0 -0
@@ -7,13 +7,47 @@ import torch
|
|
7
7
|
import torch.nn as nn
|
8
8
|
from lightning.pytorch import Callback, LightningModule, Trainer
|
9
9
|
from torch.optim import Optimizer
|
10
|
-
from typing_extensions import override
|
10
|
+
from typing_extensions import final, override
|
11
11
|
|
12
|
-
from .base import CallbackConfigBase
|
12
|
+
from .base import CallbackConfigBase, callback_registry
|
13
13
|
|
14
14
|
log = logging.getLogger(__name__)
|
15
15
|
|
16
16
|
|
17
|
+
@final
|
18
|
+
@callback_registry.register
|
19
|
+
class NormLoggingCallbackConfig(CallbackConfigBase):
|
20
|
+
name: Literal["norm_logging"] = "norm_logging"
|
21
|
+
|
22
|
+
log_grad_norm: bool | str | float = False
|
23
|
+
"""If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
|
24
|
+
log_grad_norm_per_param: bool | str | float = False
|
25
|
+
"""If enabled, will log the gradient norm for each model parameter to the logger."""
|
26
|
+
|
27
|
+
log_param_norm: bool | str | float = False
|
28
|
+
"""If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
|
29
|
+
log_param_norm_per_param: bool | str | float = False
|
30
|
+
"""If enabled, will log the parameter norm for each model parameter to the logger."""
|
31
|
+
|
32
|
+
def __bool__(self):
|
33
|
+
return any(
|
34
|
+
v
|
35
|
+
for v in (
|
36
|
+
self.log_grad_norm,
|
37
|
+
self.log_grad_norm_per_param,
|
38
|
+
self.log_param_norm,
|
39
|
+
self.log_param_norm_per_param,
|
40
|
+
)
|
41
|
+
)
|
42
|
+
|
43
|
+
@override
|
44
|
+
def create_callbacks(self, trainer_config):
|
45
|
+
if not self:
|
46
|
+
return
|
47
|
+
|
48
|
+
yield NormLoggingCallback(self)
|
49
|
+
|
50
|
+
|
17
51
|
def grad_norm(
|
18
52
|
module: nn.Module,
|
19
53
|
norm_type: float | int | str,
|
@@ -155,35 +189,3 @@ class NormLoggingCallback(Callback):
|
|
155
189
|
self._perform_norm_logging(
|
156
190
|
pl_module, optimizer, prefix=f"train/optimizer_{i}/"
|
157
191
|
)
|
158
|
-
|
159
|
-
|
160
|
-
class NormLoggingCallbackConfig(CallbackConfigBase):
|
161
|
-
name: Literal["norm_logging"] = "norm_logging"
|
162
|
-
|
163
|
-
log_grad_norm: bool | str | float = False
|
164
|
-
"""If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
|
165
|
-
log_grad_norm_per_param: bool | str | float = False
|
166
|
-
"""If enabled, will log the gradient norm for each model parameter to the logger."""
|
167
|
-
|
168
|
-
log_param_norm: bool | str | float = False
|
169
|
-
"""If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
|
170
|
-
log_param_norm_per_param: bool | str | float = False
|
171
|
-
"""If enabled, will log the parameter norm for each model parameter to the logger."""
|
172
|
-
|
173
|
-
def __bool__(self):
|
174
|
-
return any(
|
175
|
-
v
|
176
|
-
for v in (
|
177
|
-
self.log_grad_norm,
|
178
|
-
self.log_grad_norm_per_param,
|
179
|
-
self.log_param_norm,
|
180
|
-
self.log_param_norm_per_param,
|
181
|
-
)
|
182
|
-
)
|
183
|
-
|
184
|
-
@override
|
185
|
-
def create_callbacks(self, trainer_config):
|
186
|
-
if not self:
|
187
|
-
return
|
188
|
-
|
189
|
-
yield NormLoggingCallback(self)
|
@@ -9,13 +9,31 @@ from typing import Literal
|
|
9
9
|
import torch
|
10
10
|
from lightning.pytorch import LightningModule, Trainer
|
11
11
|
from lightning.pytorch.callbacks import Callback
|
12
|
-
from typing_extensions import override
|
12
|
+
from typing_extensions import final, override
|
13
13
|
|
14
|
-
from .base import CallbackConfigBase
|
14
|
+
from .base import CallbackConfigBase, callback_registry
|
15
15
|
|
16
16
|
log = logging.getLogger(__name__)
|
17
17
|
|
18
18
|
|
19
|
+
@final
|
20
|
+
@callback_registry.register
|
21
|
+
class PrintTableMetricsCallbackConfig(CallbackConfigBase):
|
22
|
+
"""Configuration class for PrintTableMetricsCallback."""
|
23
|
+
|
24
|
+
name: Literal["print_table_metrics"] = "print_table_metrics"
|
25
|
+
|
26
|
+
enabled: bool = True
|
27
|
+
"""Whether to enable the callback or not."""
|
28
|
+
|
29
|
+
metric_patterns: list[str] | None = None
|
30
|
+
"""List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""
|
31
|
+
|
32
|
+
@override
|
33
|
+
def create_callbacks(self, trainer_config):
|
34
|
+
yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)
|
35
|
+
|
36
|
+
|
19
37
|
class PrintTableMetricsCallback(Callback):
|
20
38
|
"""Prints a table with the metrics in columns on every epoch end."""
|
21
39
|
|
@@ -74,19 +92,3 @@ class PrintTableMetricsCallback(Callback):
|
|
74
92
|
table.add_row(*values)
|
75
93
|
|
76
94
|
return table
|
77
|
-
|
78
|
-
|
79
|
-
class PrintTableMetricsCallbackConfig(CallbackConfigBase):
|
80
|
-
"""Configuration class for PrintTableMetricsCallback."""
|
81
|
-
|
82
|
-
name: Literal["print_table_metrics"] = "print_table_metrics"
|
83
|
-
|
84
|
-
enabled: bool = True
|
85
|
-
"""Whether to enable the callback or not."""
|
86
|
-
|
87
|
-
metric_patterns: list[str] | None = None
|
88
|
-
"""List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""
|
89
|
-
|
90
|
-
@override
|
91
|
-
def create_callbacks(self, trainer_config):
|
92
|
-
yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)
|
@@ -11,13 +11,15 @@ from lightning.pytorch.utilities.types import (
|
|
11
11
|
LRSchedulerConfigType,
|
12
12
|
LRSchedulerTypeUnion,
|
13
13
|
)
|
14
|
-
from typing_extensions import Protocol, override, runtime_checkable
|
14
|
+
from typing_extensions import Protocol, final, override, runtime_checkable
|
15
15
|
|
16
|
-
from .base import CallbackConfigBase
|
16
|
+
from .base import CallbackConfigBase, callback_registry
|
17
17
|
|
18
18
|
log = logging.getLogger(__name__)
|
19
19
|
|
20
20
|
|
21
|
+
@final
|
22
|
+
@callback_registry.register
|
21
23
|
class RLPSanityChecksCallbackConfig(CallbackConfigBase):
|
22
24
|
"""
|
23
25
|
If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
|
@@ -7,18 +7,15 @@ from typing import Literal, Protocol, runtime_checkable
|
|
7
7
|
import torch.nn as nn
|
8
8
|
from lightning.pytorch import LightningModule, Trainer
|
9
9
|
from lightning.pytorch.callbacks import Callback
|
10
|
-
from typing_extensions import TypeAliasType, override
|
10
|
+
from typing_extensions import TypeAliasType, final, override
|
11
11
|
|
12
|
-
from .base import CallbackConfigBase
|
12
|
+
from .base import CallbackConfigBase, callback_registry
|
13
13
|
|
14
14
|
log = logging.getLogger(__name__)
|
15
15
|
|
16
16
|
|
17
|
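-
def _parameters_to_names(parameters: Iterable[nn.Parameter], model: nn.Module):
|
18
|
-
mapping = {id(p): n for n, p in model.named_parameters()}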
|
19
|
-
return [mapping[id(p)] for p in parameters]
|
20
|
-
|
21
|
-
|
17
|
+
@final
|
18
|
+
@callback_registry.register
|
22
19
|
class SharedParametersCallbackConfig(CallbackConfigBase):
|
23
20
|
"""A callback that allows scaling the gradients of shared parameters that
|
24
21
|
are registered in the ``self.shared_parameters`` list of the root module.
|
@@ -34,6 +31,11 @@ class SharedParametersCallbackConfig(CallbackConfigBase):
|
|
34
31
|
yield SharedParametersCallback(self)
|
35
32
|
|
36
33
|
|
34
|
+
def _parameters_to_names(parameters: Iterable[nn.Parameter], model: nn.Module):
|
35
|
+
mapping = {id(p): n for n, p in model.named_parameters()}
|
36
|
+
return [mapping[id(p)] for p in parameters]
|
37
|
+
|
38
|
+
|
37
39
|
SharedParametersList = TypeAliasType(
|
38
40
|
"SharedParametersList", list[tuple[nn.Parameter, int | float]]
|
39
41
|
)
|
nshtrainer/callbacks/timer.py
CHANGED
@@ -7,13 +7,23 @@ from typing import Any, Literal
|
|
7
7
|
from lightning.pytorch import LightningModule, Trainer
|
8
8
|
from lightning.pytorch.callbacks import Callback
|
9
9
|
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
10
|
-
from typing_extensions import override
|
10
|
+
from typing_extensions import final, override
|
11
11
|
|
12
|
-
from .base import CallbackConfigBase
|
12
|
+
from .base import CallbackConfigBase, callback_registry
|
13
13
|
|
14
14
|
log = logging.getLogger(__name__)
|
15
15
|
|
16
16
|
|
17
|
+
@final
|
18
|
+
@callback_registry.register
|
19
|
+
class EpochTimerCallbackConfig(CallbackConfigBase):
|
20
|
+
name: Literal["epoch_timer"] = "epoch_timer"
|
21
|
+
|
22
|
+
@override
|
23
|
+
def create_callbacks(self, trainer_config):
|
24
|
+
yield EpochTimerCallback()
|
25
|
+
|
26
|
+
|
17
27
|
class EpochTimerCallback(Callback):
|
18
28
|
def __init__(self):
|
19
29
|
super().__init__()
|
@@ -149,11 +159,3 @@ class EpochTimerCallback(Callback):
|
|
149
159
|
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
150
160
|
self._elapsed_time = state_dict["elapsed_time"]
|
151
161
|
self._total_batches = state_dict["total_batches"]
|
152
|
-
|
153
|
-
|
154
|
-
class EpochTimerCallbackConfig(CallbackConfigBase):
|
155
|
-
name: Literal["epoch_timer"] = "epoch_timer"
|
156
|
-
|
157
|
-
@override
|
158
|
-
def create_callbacks(self, trainer_config):
|
159
|
-
yield EpochTimerCallback()
|
@@ -9,13 +9,15 @@ from lightning.pytorch import LightningModule, Trainer
|
|
9
9
|
from lightning.pytorch.callbacks.callback import Callback
|
10
10
|
from lightning.pytorch.loggers import WandbLogger
|
11
11
|
from nshrunner._env import SNAPSHOT_DIR
|
12
|
-
from typing_extensions import override
|
12
|
+
from typing_extensions import final, override
|
13
13
|
|
14
|
-
from .base import CallbackConfigBase
|
14
|
+
from .base import CallbackConfigBase, callback_registry
|
15
15
|
|
16
16
|
log = logging.getLogger(__name__)
|
17
17
|
|
18
18
|
|
19
|
+
@final
|
20
|
+
@callback_registry.register
|
19
21
|
class WandbUploadCodeCallbackConfig(CallbackConfigBase):
|
20
22
|
name: Literal["wandb_upload_code"] = "wandb_upload_code"
|
21
23
|
|
@@ -7,13 +7,15 @@ import torch.nn as nn
|
|
7
7
|
from lightning.pytorch import LightningModule, Trainer
|
8
8
|
from lightning.pytorch.callbacks.callback import Callback
|
9
9
|
from lightning.pytorch.loggers import WandbLogger
|
10
|
-
from typing_extensions import override
|
10
|
+
from typing_extensions import final, override
|
11
11
|
|
12
|
-
from .base import CallbackConfigBase
|
12
|
+
from .base import CallbackConfigBase, callback_registry
|
13
13
|
|
14
14
|
log = logging.getLogger(__name__)
|
15
15
|
|
16
16
|
|
17
|
+
@final
|
18
|
+
@callback_registry.register
|
17
19
|
class WandbWatchCallbackConfig(CallbackConfigBase):
|
18
20
|
name: Literal["wandb_watch"] = "wandb_watch"
|
19
21
|
|
nshtrainer/configs/__init__.py
CHANGED
@@ -5,6 +5,7 @@ __codegen__ = True
|
|
5
5
|
from nshtrainer import MetricConfig as MetricConfig
|
6
6
|
from nshtrainer import TrainerConfig as TrainerConfig
|
7
7
|
from nshtrainer import accelerator_registry as accelerator_registry
|
8
|
+
from nshtrainer import callback_registry as callback_registry
|
8
9
|
from nshtrainer import plugin_registry as plugin_registry
|
9
10
|
from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
|
10
11
|
from nshtrainer._directory import DirectoryConfig as DirectoryConfig
|
@@ -13,6 +14,7 @@ from nshtrainer._hf_hub import (
|
|
13
14
|
HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
|
14
15
|
)
|
15
16
|
from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
|
17
|
+
from nshtrainer.callbacks import ActSaveConfig as ActSaveConfig
|
16
18
|
from nshtrainer.callbacks import (
|
17
19
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
18
20
|
)
|
@@ -35,6 +37,7 @@ from nshtrainer.callbacks import (
|
|
35
37
|
from nshtrainer.callbacks import (
|
36
38
|
LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
|
37
39
|
)
|
40
|
+
from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
|
38
41
|
from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
|
39
42
|
from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
|
40
43
|
from nshtrainer.callbacks import (
|
@@ -49,14 +52,10 @@ from nshtrainer.callbacks import (
|
|
49
52
|
from nshtrainer.callbacks import (
|
50
53
|
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
51
54
|
)
|
52
|
-
from nshtrainer.callbacks import (
|
53
|
-
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
54
|
-
)
|
55
55
|
from nshtrainer.callbacks import (
|
56
56
|
WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
|
57
57
|
)
|
58
58
|
from nshtrainer.callbacks import WandbWatchCallbackConfig as WandbWatchCallbackConfig
|
59
|
-
from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
|
60
59
|
from nshtrainer.callbacks.checkpoint._base import (
|
61
60
|
BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
|
62
61
|
)
|
@@ -106,9 +105,6 @@ from nshtrainer.trainer._config import (
|
|
106
105
|
from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
|
107
106
|
from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
|
108
107
|
from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
|
109
|
-
from nshtrainer.trainer._config import (
|
110
|
-
LearningRateMonitorConfig as LearningRateMonitorConfig,
|
111
|
-
)
|
112
108
|
from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
|
113
109
|
from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
|
114
110
|
from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
|
@@ -320,7 +316,6 @@ __all__ = [
|
|
320
316
|
"SwishNonlinearityConfig",
|
321
317
|
"TanhNonlinearityConfig",
|
322
318
|
"TensorboardLoggerConfig",
|
323
|
-
"TimeCheckpointCallbackConfig",
|
324
319
|
"TorchCheckpointIOPlugin",
|
325
320
|
"TorchElasticEnvironmentPlugin",
|
326
321
|
"TorchSyncBatchNormPlugin",
|
@@ -337,6 +332,7 @@ __all__ = [
|
|
337
332
|
"_directory",
|
338
333
|
"_hf_hub",
|
339
334
|
"accelerator_registry",
|
335
|
+
"callback_registry",
|
340
336
|
"callbacks",
|
341
337
|
"loggers",
|
342
338
|
"lr_scheduler",
|
@@ -7,9 +7,11 @@ from nshtrainer._hf_hub import (
|
|
7
7
|
HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
|
8
8
|
)
|
9
9
|
from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
|
10
|
+
from nshtrainer._hf_hub import callback_registry as callback_registry
|
10
11
|
|
11
12
|
__all__ = [
|
12
13
|
"CallbackConfigBase",
|
13
14
|
"HuggingFaceHubAutoCreateConfig",
|
14
15
|
"HuggingFaceHubConfig",
|
16
|
+
"callback_registry",
|
15
17
|
]
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
|
+
from nshtrainer.callbacks import ActSaveConfig as ActSaveConfig
|
5
6
|
from nshtrainer.callbacks import (
|
6
7
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
7
8
|
)
|
@@ -25,6 +26,7 @@ from nshtrainer.callbacks import (
|
|
25
26
|
from nshtrainer.callbacks import (
|
26
27
|
LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
|
27
28
|
)
|
29
|
+
from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
|
28
30
|
from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
|
29
31
|
from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
|
30
32
|
from nshtrainer.callbacks import (
|
@@ -39,14 +41,11 @@ from nshtrainer.callbacks import (
|
|
39
41
|
from nshtrainer.callbacks import (
|
40
42
|
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
41
43
|
)
|
42
|
-
from nshtrainer.callbacks import (
|
43
|
-
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
44
|
-
)
|
45
44
|
from nshtrainer.callbacks import (
|
46
45
|
WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
|
47
46
|
)
|
48
47
|
from nshtrainer.callbacks import WandbWatchCallbackConfig as WandbWatchCallbackConfig
|
49
|
-
from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
|
48
|
+
from nshtrainer.callbacks import callback_registry as callback_registry
|
50
49
|
from nshtrainer.callbacks.checkpoint._base import (
|
51
50
|
BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
|
52
51
|
)
|
@@ -54,9 +53,6 @@ from nshtrainer.callbacks.checkpoint._base import (
|
|
54
53
|
CheckpointMetadata as CheckpointMetadata,
|
55
54
|
)
|
56
55
|
from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
|
57
|
-
from nshtrainer.callbacks.lr_monitor import (
|
58
|
-
LearningRateMonitorConfig as LearningRateMonitorConfig,
|
59
|
-
)
|
60
56
|
|
61
57
|
from . import actsave as actsave
|
62
58
|
from . import base as base
|
@@ -100,11 +96,11 @@ __all__ = [
|
|
100
96
|
"PrintTableMetricsCallbackConfig",
|
101
97
|
"RLPSanityChecksCallbackConfig",
|
102
98
|
"SharedParametersCallbackConfig",
|
103
|
-
"TimeCheckpointCallbackConfig",
|
104
99
|
"WandbUploadCodeCallbackConfig",
|
105
100
|
"WandbWatchCallbackConfig",
|
106
101
|
"actsave",
|
107
102
|
"base",
|
103
|
+
"callback_registry",
|
108
104
|
"checkpoint",
|
109
105
|
"debug_flag",
|
110
106
|
"directory_setup",
|
@@ -4,8 +4,10 @@ __codegen__ = True
|
|
4
4
|
|
5
5
|
from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
|
6
6
|
from nshtrainer.callbacks.actsave import CallbackConfigBase as CallbackConfigBase
|
7
|
+
from nshtrainer.callbacks.actsave import callback_registry as callback_registry
|
7
8
|
|
8
9
|
__all__ = [
|
9
10
|
"ActSaveConfig",
|
10
11
|
"CallbackConfigBase",
|
12
|
+
"callback_registry",
|
11
13
|
]
|
@@ -3,7 +3,9 @@ from __future__ import annotations
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
5
|
from nshtrainer.callbacks.base import CallbackConfigBase as CallbackConfigBase
|
6
|
+
from nshtrainer.callbacks.base import callback_registry as callback_registry
|
6
7
|
|
7
8
|
__all__ = [
|
8
9
|
"CallbackConfigBase",
|
10
|
+
"callback_registry",
|
9
11
|
]
|
@@ -11,9 +11,6 @@ from nshtrainer.callbacks.checkpoint import (
|
|
11
11
|
from nshtrainer.callbacks.checkpoint import (
|
12
12
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
13
13
|
)
|
14
|
-
from nshtrainer.callbacks.checkpoint import (
|
15
|
-
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
16
|
-
)
|
17
14
|
from nshtrainer.callbacks.checkpoint._base import (
|
18
15
|
BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
|
19
16
|
)
|
@@ -24,12 +21,14 @@ from nshtrainer.callbacks.checkpoint._base import (
|
|
24
21
|
CheckpointMetadata as CheckpointMetadata,
|
25
22
|
)
|
26
23
|
from nshtrainer.callbacks.checkpoint.best_checkpoint import MetricConfig as MetricConfig
|
24
|
+
from nshtrainer.callbacks.checkpoint.best_checkpoint import (
|
25
|
+
callback_registry as callback_registry,
|
26
|
+
)
|
27
27
|
|
28
28
|
from . import _base as _base
|
29
29
|
from . import best_checkpoint as best_checkpoint
|
30
30
|
from . import last_checkpoint as last_checkpoint
|
31
31
|
from . import on_exception_checkpoint as on_exception_checkpoint
|
32
|
-
from . import time_checkpoint as time_checkpoint
|
33
32
|
|
34
33
|
__all__ = [
|
35
34
|
"BaseCheckpointCallbackConfig",
|
@@ -39,10 +38,9 @@ __all__ = [
|
|
39
38
|
"LastCheckpointCallbackConfig",
|
40
39
|
"MetricConfig",
|
41
40
|
"OnExceptionCheckpointCallbackConfig",
|
42
|
-
"TimeCheckpointCallbackConfig",
|
43
41
|
"_base",
|
44
42
|
"best_checkpoint",
|
43
|
+
"callback_registry",
|
45
44
|
"last_checkpoint",
|
46
45
|
"on_exception_checkpoint",
|
47
|
-
"time_checkpoint",
|
48
46
|
]
|
@@ -12,10 +12,14 @@ from nshtrainer.callbacks.checkpoint.best_checkpoint import (
|
|
12
12
|
CheckpointMetadata as CheckpointMetadata,
|
13
13
|
)
|
14
14
|
from nshtrainer.callbacks.checkpoint.best_checkpoint import MetricConfig as MetricConfig
|
15
|
+
from nshtrainer.callbacks.checkpoint.best_checkpoint import (
|
16
|
+
callback_registry as callback_registry,
|
17
|
+
)
|
15
18
|
|
16
19
|
__all__ = [
|
17
20
|
"BaseCheckpointCallbackConfig",
|
18
21
|
"BestCheckpointCallbackConfig",
|
19
22
|
"CheckpointMetadata",
|
20
23
|
"MetricConfig",
|
24
|
+
"callback_registry",
|
21
25
|
]
|
@@ -11,9 +11,13 @@ from nshtrainer.callbacks.checkpoint.last_checkpoint import (
|
|
11
11
|
from nshtrainer.callbacks.checkpoint.last_checkpoint import (
|
12
12
|
LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
|
13
13
|
)
|
14
|
+
from nshtrainer.callbacks.checkpoint.last_checkpoint import (
|
15
|
+
callback_registry as callback_registry,
|
16
|
+
)
|
14
17
|
|
15
18
|
__all__ = [
|
16
19
|
"BaseCheckpointCallbackConfig",
|
17
20
|
"CheckpointMetadata",
|
18
21
|
"LastCheckpointCallbackConfig",
|
22
|
+
"callback_registry",
|
19
23
|
]
|
@@ -8,8 +8,12 @@ from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
|
|
8
8
|
from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
|
9
9
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
10
10
|
)
|
11
|
+
from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
|
12
|
+
callback_registry as callback_registry,
|
13
|
+
)
|
11
14
|
|
12
15
|
__all__ = [
|
13
16
|
"CallbackConfigBase",
|
14
17
|
"OnExceptionCheckpointCallbackConfig",
|
18
|
+
"callback_registry",
|
15
19
|
]
|
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.debug_flag import CallbackConfigBase as CallbackConfig
|
|
6
6
|
from nshtrainer.callbacks.debug_flag import (
|
7
7
|
DebugFlagCallbackConfig as DebugFlagCallbackConfig,
|
8
8
|
)
|
9
|
+
from nshtrainer.callbacks.debug_flag import callback_registry as callback_registry
|
9
10
|
|
10
11
|
__all__ = [
|
11
12
|
"CallbackConfigBase",
|
12
13
|
"DebugFlagCallbackConfig",
|
14
|
+
"callback_registry",
|
13
15
|
]
|
@@ -8,8 +8,10 @@ from nshtrainer.callbacks.directory_setup import (
|
|
8
8
|
from nshtrainer.callbacks.directory_setup import (
|
9
9
|
DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
|
10
10
|
)
|
11
|
+
from nshtrainer.callbacks.directory_setup import callback_registry as callback_registry
|
11
12
|
|
12
13
|
__all__ = [
|
13
14
|
"CallbackConfigBase",
|
14
15
|
"DirectorySetupCallbackConfig",
|
16
|
+
"callback_registry",
|
15
17
|
]
|
@@ -7,9 +7,11 @@ from nshtrainer.callbacks.early_stopping import (
|
|
7
7
|
EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
|
8
8
|
)
|
9
9
|
from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
|
10
|
+
from nshtrainer.callbacks.early_stopping import callback_registry as callback_registry
|
10
11
|
|
11
12
|
__all__ = [
|
12
13
|
"CallbackConfigBase",
|
13
14
|
"EarlyStoppingCallbackConfig",
|
14
15
|
"MetricConfig",
|
16
|
+
"callback_registry",
|
15
17
|
]
|
@@ -4,8 +4,10 @@ __codegen__ = True
|
|
4
4
|
|
5
5
|
from nshtrainer.callbacks.ema import CallbackConfigBase as CallbackConfigBase
|
6
6
|
from nshtrainer.callbacks.ema import EMACallbackConfig as EMACallbackConfig
|
7
|
+
from nshtrainer.callbacks.ema import callback_registry as callback_registry
|
7
8
|
|
8
9
|
__all__ = [
|
9
10
|
"CallbackConfigBase",
|
10
11
|
"EMACallbackConfig",
|
12
|
+
"callback_registry",
|
11
13
|
]
|
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.finite_checks import CallbackConfigBase as CallbackCon
|
|
6
6
|
from nshtrainer.callbacks.finite_checks import (
|
7
7
|
FiniteChecksCallbackConfig as FiniteChecksCallbackConfig,
|
8
8
|
)
|
9
|
+
from nshtrainer.callbacks.finite_checks import callback_registry as callback_registry
|
9
10
|
|
10
11
|
__all__ = [
|
11
12
|
"CallbackConfigBase",
|
12
13
|
"FiniteChecksCallbackConfig",
|
14
|
+
"callback_registry",
|
13
15
|
]
|
@@ -8,8 +8,12 @@ from nshtrainer.callbacks.gradient_skipping import (
|
|
8
8
|
from nshtrainer.callbacks.gradient_skipping import (
|
9
9
|
GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
|
10
10
|
)
|
11
|
+
from nshtrainer.callbacks.gradient_skipping import (
|
12
|
+
callback_registry as callback_registry,
|
13
|
+
)
|
11
14
|
|
12
15
|
__all__ = [
|
13
16
|
"CallbackConfigBase",
|
14
17
|
"GradientSkippingCallbackConfig",
|
18
|
+
"callback_registry",
|
15
19
|
]
|
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.log_epoch import CallbackConfigBase as CallbackConfigB
|
|
6
6
|
from nshtrainer.callbacks.log_epoch import (
|
7
7
|
LogEpochCallbackConfig as LogEpochCallbackConfig,
|
8
8
|
)
|
9
|
+
from nshtrainer.callbacks.log_epoch import callback_registry as callback_registry
|
9
10
|
|
10
11
|
__all__ = [
|
11
12
|
"CallbackConfigBase",
|
12
13
|
"LogEpochCallbackConfig",
|
14
|
+
"callback_registry",
|
13
15
|
]
|
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.lr_monitor import CallbackConfigBase as CallbackConfig
|
|
6
6
|
from nshtrainer.callbacks.lr_monitor import (
|
7
7
|
LearningRateMonitorConfig as LearningRateMonitorConfig,
|
8
8
|
)
|
9
|
+
from nshtrainer.callbacks.lr_monitor import callback_registry as callback_registry
|
9
10
|
|
10
11
|
__all__ = [
|
11
12
|
"CallbackConfigBase",
|
12
13
|
"LearningRateMonitorConfig",
|
14
|
+
"callback_registry",
|
13
15
|
]
|
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.norm_logging import CallbackConfigBase as CallbackConf
|
|
6
6
|
from nshtrainer.callbacks.norm_logging import (
|
7
7
|
NormLoggingCallbackConfig as NormLoggingCallbackConfig,
|
8
8
|
)
|
9
|
+
from nshtrainer.callbacks.norm_logging import callback_registry as callback_registry
|
9
10
|
|
10
11
|
__all__ = [
|
11
12
|
"CallbackConfigBase",
|
12
13
|
"NormLoggingCallbackConfig",
|
14
|
+
"callback_registry",
|
13
15
|
]
|
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.print_table import CallbackConfigBase as CallbackConfi
|
|
6
6
|
from nshtrainer.callbacks.print_table import (
|
7
7
|
PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
|
8
8
|
)
|
9
|
+
from nshtrainer.callbacks.print_table import callback_registry as callback_registry
|
9
10
|
|
10
11
|
__all__ = [
|
11
12
|
"CallbackConfigBase",
|
12
13
|
"PrintTableMetricsCallbackConfig",
|
14
|
+
"callback_registry",
|
13
15
|
]
|
@@ -8,8 +8,12 @@ from nshtrainer.callbacks.rlp_sanity_checks import (
|
|
8
8
|
from nshtrainer.callbacks.rlp_sanity_checks import (
|
9
9
|
RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
|
10
10
|
)
|
11
|
+
from nshtrainer.callbacks.rlp_sanity_checks import (
|
12
|
+
callback_registry as callback_registry,
|
13
|
+
)
|
11
14
|
|
12
15
|
__all__ = [
|
13
16
|
"CallbackConfigBase",
|
14
17
|
"RLPSanityChecksCallbackConfig",
|
18
|
+
"callback_registry",
|
15
19
|
]
|
@@ -8,8 +8,12 @@ from nshtrainer.callbacks.shared_parameters import (
|
|
8
8
|
from nshtrainer.callbacks.shared_parameters import (
|
9
9
|
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
10
10
|
)
|
11
|
+
from nshtrainer.callbacks.shared_parameters import (
|
12
|
+
callback_registry as callback_registry,
|
13
|
+
)
|
11
14
|
|
12
15
|
__all__ = [
|
13
16
|
"CallbackConfigBase",
|
14
17
|
"SharedParametersCallbackConfig",
|
18
|
+
"callback_registry",
|
15
19
|
]
|
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.timer import CallbackConfigBase as CallbackConfigBase
|
|
6
6
|
from nshtrainer.callbacks.timer import (
|
7
7
|
EpochTimerCallbackConfig as EpochTimerCallbackConfig,
|
8
8
|
)
|
9
|
+
from nshtrainer.callbacks.timer import callback_registry as callback_registry
|
9
10
|
|
10
11
|
__all__ = [
|
11
12
|
"CallbackConfigBase",
|
12
13
|
"EpochTimerCallbackConfig",
|
14
|
+
"callback_registry",
|
13
15
|
]
|