nshtrainer 1.0.0b33__py3-none-any.whl → 1.0.0b37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +1 -0
- nshtrainer/_directory.py +3 -1
- nshtrainer/_hf_hub.py +8 -1
- nshtrainer/callbacks/__init__.py +10 -23
- nshtrainer/callbacks/actsave.py +6 -2
- nshtrainer/callbacks/base.py +3 -0
- nshtrainer/callbacks/checkpoint/__init__.py +0 -4
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +72 -2
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -2
- nshtrainer/callbacks/debug_flag.py +4 -2
- nshtrainer/callbacks/directory_setup.py +23 -21
- nshtrainer/callbacks/early_stopping.py +4 -2
- nshtrainer/callbacks/ema.py +29 -27
- nshtrainer/callbacks/finite_checks.py +21 -19
- nshtrainer/callbacks/gradient_skipping.py +29 -27
- nshtrainer/callbacks/log_epoch.py +4 -2
- nshtrainer/callbacks/lr_monitor.py +6 -1
- nshtrainer/callbacks/norm_logging.py +36 -34
- nshtrainer/callbacks/print_table.py +20 -18
- nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
- nshtrainer/callbacks/shared_parameters.py +9 -7
- nshtrainer/callbacks/timer.py +12 -10
- nshtrainer/callbacks/wandb_upload_code.py +4 -2
- nshtrainer/callbacks/wandb_watch.py +4 -2
- nshtrainer/configs/__init__.py +16 -12
- nshtrainer/configs/_hf_hub/__init__.py +2 -0
- nshtrainer/configs/callbacks/__init__.py +4 -8
- nshtrainer/configs/callbacks/actsave/__init__.py +2 -0
- nshtrainer/configs/callbacks/base/__init__.py +2 -0
- nshtrainer/configs/callbacks/checkpoint/__init__.py +4 -6
- nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +4 -0
- nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +4 -0
- nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -0
- nshtrainer/configs/callbacks/debug_flag/__init__.py +2 -0
- nshtrainer/configs/callbacks/directory_setup/__init__.py +2 -0
- nshtrainer/configs/callbacks/early_stopping/__init__.py +2 -0
- nshtrainer/configs/callbacks/ema/__init__.py +2 -0
- nshtrainer/configs/callbacks/finite_checks/__init__.py +2 -0
- nshtrainer/configs/callbacks/gradient_skipping/__init__.py +4 -0
- nshtrainer/configs/callbacks/log_epoch/__init__.py +2 -0
- nshtrainer/configs/callbacks/lr_monitor/__init__.py +2 -0
- nshtrainer/configs/callbacks/norm_logging/__init__.py +2 -0
- nshtrainer/configs/callbacks/print_table/__init__.py +2 -0
- nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +4 -0
- nshtrainer/configs/callbacks/shared_parameters/__init__.py +4 -0
- nshtrainer/configs/callbacks/timer/__init__.py +2 -0
- nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +4 -0
- nshtrainer/configs/callbacks/wandb_watch/__init__.py +2 -0
- nshtrainer/configs/loggers/__init__.py +6 -4
- nshtrainer/configs/loggers/actsave/__init__.py +4 -2
- nshtrainer/configs/loggers/base/__init__.py +11 -0
- nshtrainer/configs/loggers/csv/__init__.py +4 -2
- nshtrainer/configs/loggers/tensorboard/__init__.py +4 -2
- nshtrainer/configs/loggers/wandb/__init__.py +4 -2
- nshtrainer/configs/lr_scheduler/__init__.py +4 -2
- nshtrainer/configs/lr_scheduler/base/__init__.py +11 -0
- nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +4 -0
- nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +4 -0
- nshtrainer/configs/nn/__init__.py +4 -2
- nshtrainer/configs/nn/mlp/__init__.py +2 -2
- nshtrainer/configs/nn/nonlinearity/__init__.py +4 -2
- nshtrainer/configs/optimizer/__init__.py +2 -0
- nshtrainer/configs/trainer/__init__.py +4 -6
- nshtrainer/configs/trainer/_config/__init__.py +2 -10
- nshtrainer/loggers/__init__.py +3 -8
- nshtrainer/loggers/actsave.py +5 -2
- nshtrainer/loggers/{_base.py → base.py} +4 -1
- nshtrainer/loggers/csv.py +5 -3
- nshtrainer/loggers/tensorboard.py +5 -3
- nshtrainer/loggers/wandb.py +5 -3
- nshtrainer/lr_scheduler/__init__.py +2 -2
- nshtrainer/lr_scheduler/{_base.py → base.py} +3 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +56 -54
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +4 -2
- nshtrainer/nn/__init__.py +1 -1
- nshtrainer/nn/mlp.py +4 -4
- nshtrainer/nn/nonlinearity.py +37 -33
- nshtrainer/optimizer.py +8 -2
- nshtrainer/trainer/__init__.py +3 -2
- nshtrainer/trainer/_config.py +6 -44
- {nshtrainer-1.0.0b33.dist-info → nshtrainer-1.0.0b37.dist-info}/METADATA +1 -1
- nshtrainer-1.0.0b37.dist-info/RECORD +156 -0
- nshtrainer/callbacks/checkpoint/time_checkpoint.py +0 -114
- nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +0 -19
- nshtrainer/configs/loggers/_base/__init__.py +0 -9
- nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -9
- nshtrainer-1.0.0b33.dist-info/RECORD +0 -158
- {nshtrainer-1.0.0b33.dist-info → nshtrainer-1.0.0b37.dist-info}/WHEEL +0 -0
nshtrainer/callbacks/norm_logging.py
CHANGED
@@ -7,13 +7,47 @@ import torch
 import torch.nn as nn
 from lightning.pytorch import Callback, LightningModule, Trainer
 from torch.optim import Optimizer
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
+class NormLoggingCallbackConfig(CallbackConfigBase):
+    name: Literal["norm_logging"] = "norm_logging"
+
+    log_grad_norm: bool | str | float = False
+    """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
+    log_grad_norm_per_param: bool | str | float = False
+    """If enabled, will log the gradient norm for each model parameter to the logger."""
+
+    log_param_norm: bool | str | float = False
+    """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
+    log_param_norm_per_param: bool | str | float = False
+    """If enabled, will log the parameter norm for each model parameter to the logger."""
+
+    def __bool__(self):
+        return any(
+            v
+            for v in (
+                self.log_grad_norm,
+                self.log_grad_norm_per_param,
+                self.log_param_norm,
+                self.log_param_norm_per_param,
+            )
+        )
+
+    @override
+    def create_callbacks(self, trainer_config):
+        if not self:
+            return
+
+        yield NormLoggingCallback(self)
+
+
 def grad_norm(
     module: nn.Module,
     norm_type: float | int | str,
@@ -155,35 +189,3 @@ class NormLoggingCallback(Callback):
             self._perform_norm_logging(
                 pl_module, optimizer, prefix=f"train/optimizer_{i}/"
             )
-
-
-class NormLoggingCallbackConfig(CallbackConfigBase):
-    name: Literal["norm_logging"] = "norm_logging"
-
-    log_grad_norm: bool | str | float = False
-    """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
-    log_grad_norm_per_param: bool | str | float = False
-    """If enabled, will log the gradient norm for each model parameter to the logger."""
-
-    log_param_norm: bool | str | float = False
-    """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
-    log_param_norm_per_param: bool | str | float = False
-    """If enabled, will log the parameter norm for each model parameter to the logger."""
-
-    def __bool__(self):
-        return any(
-            v
-            for v in (
-                self.log_grad_norm,
-                self.log_grad_norm_per_param,
-                self.log_param_norm,
-                self.log_param_norm_per_param,
-            )
-        )
-
-    @override
-    def create_callbacks(self, trainer_config):
-        if not self:
-            return
-
-        yield NormLoggingCallback(self)
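This is the pattern repeated across all the callback modules in this release: the config class is marked `@final`, self-registers on the new `callback_registry`, and moves above the callback it constructs instead of trailing it. The `__bool__` override makes an all-defaults config falsy, so `create_callbacks` yields nothing unless at least one flag is set. A minimal usage sketch, assuming keyword construction works for these config classes and passing `None` for the trainer config purely for illustration (the body shown above never reads it):

```python
from nshtrainer.callbacks.norm_logging import NormLoggingCallbackConfig

# All four flags default to False, so the config is falsy and
# create_callbacks() yields no callback at all.
config = NormLoggingCallbackConfig()
assert not config
assert list(config.create_callbacks(None)) == []

# Setting any one flag makes the config truthy; a float here
# presumably selects the norm order (p=2), per the
# bool | str | float annotation.
config = NormLoggingCallbackConfig(log_grad_norm=2.0)
(callback,) = config.create_callbacks(None)
```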
nshtrainer/callbacks/print_table.py
CHANGED
@@ -9,13 +9,31 @@ from typing import Literal
 import torch
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
+class PrintTableMetricsCallbackConfig(CallbackConfigBase):
+    """Configuration class for PrintTableMetricsCallback."""
+
+    name: Literal["print_table_metrics"] = "print_table_metrics"
+
+    enabled: bool = True
+    """Whether to enable the callback or not."""
+
+    metric_patterns: list[str] | None = None
+    """List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""
+
+    @override
+    def create_callbacks(self, trainer_config):
+        yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)
+
+
 class PrintTableMetricsCallback(Callback):
     """Prints a table with the metrics in columns on every epoch end."""

@@ -74,19 +92,3 @@ class PrintTableMetricsCallback(Callback):
             table.add_row(*values)

         return table
-
-
-class PrintTableMetricsCallbackConfig(CallbackConfigBase):
-    """Configuration class for PrintTableMetricsCallback."""
-
-    name: Literal["print_table_metrics"] = "print_table_metrics"
-
-    enabled: bool = True
-    """Whether to enable the callback or not."""
-
-    metric_patterns: list[str] | None = None
-    """List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""
-
-    @override
-    def create_callbacks(self, trainer_config):
-        yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)
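Same move for the table-printing callback. A brief construction sketch under the same assumptions (keyword construction; `None` for the trainer-config argument, which the `create_callbacks` body above ignores):

```python
from nshtrainer.callbacks.print_table import PrintTableMetricsCallbackConfig

# Restrict the per-epoch table to metrics matching these patterns;
# the default of None displays every logged metric.
config = PrintTableMetricsCallbackConfig(metric_patterns=["loss", "lr"])
(callback,) = config.create_callbacks(None)
```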
nshtrainer/callbacks/rlp_sanity_checks.py
CHANGED
@@ -11,13 +11,15 @@ from lightning.pytorch.utilities.types import (
     LRSchedulerConfigType,
     LRSchedulerTypeUnion,
 )
-from typing_extensions import Protocol, override, runtime_checkable
+from typing_extensions import Protocol, final, override, runtime_checkable

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
 class RLPSanityChecksCallbackConfig(CallbackConfigBase):
     """
     If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
nshtrainer/callbacks/shared_parameters.py
CHANGED
@@ -7,18 +7,15 @@ from typing import Literal, Protocol, runtime_checkable
 import torch.nn as nn
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from typing_extensions import TypeAliasType, override
+from typing_extensions import TypeAliasType, final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


-def _parameters_to_names(parameters: Iterable[nn.Parameter], model: nn.Module):
-    mapping = {id(p): n for n, p in model.named_parameters()}
-    return [mapping[id(p)] for p in parameters]
-
-
+@final
+@callback_registry.register
 class SharedParametersCallbackConfig(CallbackConfigBase):
     """A callback that allows scaling the gradients of shared parameters that
     are registered in the ``self.shared_parameters`` list of the root module.
@@ -34,6 +31,11 @@ class SharedParametersCallbackConfig(CallbackConfigBase):
         yield SharedParametersCallback(self)


+def _parameters_to_names(parameters: Iterable[nn.Parameter], model: nn.Module):
+    mapping = {id(p): n for n, p in model.named_parameters()}
+    return [mapping[id(p)] for p in parameters]
+
+
 SharedParametersList = TypeAliasType(
     "SharedParametersList", list[tuple[nn.Parameter, int | float]]
 )
nshtrainer/callbacks/timer.py
CHANGED
@@ -7,13 +7,23 @@ from typing import Any, Literal
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.utilities.types import STEP_OUTPUT
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
+class EpochTimerCallbackConfig(CallbackConfigBase):
+    name: Literal["epoch_timer"] = "epoch_timer"
+
+    @override
+    def create_callbacks(self, trainer_config):
+        yield EpochTimerCallback()
+
+
 class EpochTimerCallback(Callback):
     def __init__(self):
         super().__init__()
@@ -149,11 +159,3 @@ class EpochTimerCallback(Callback):
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         self._elapsed_time = state_dict["elapsed_time"]
         self._total_batches = state_dict["total_batches"]
-
-
-class EpochTimerCallbackConfig(CallbackConfigBase):
-    name: Literal["epoch_timer"] = "epoch_timer"
-
-    @override
-    def create_callbacks(self, trainer_config):
-        yield EpochTimerCallback()
nshtrainer/callbacks/wandb_upload_code.py
CHANGED
@@ -9,13 +9,15 @@ from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.loggers import WandbLogger
 from nshrunner._env import SNAPSHOT_DIR
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
 class WandbUploadCodeCallbackConfig(CallbackConfigBase):
     name: Literal["wandb_upload_code"] = "wandb_upload_code"

nshtrainer/callbacks/wandb_watch.py
CHANGED
@@ -7,13 +7,15 @@ import torch.nn as nn
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.loggers import WandbLogger
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
 class WandbWatchCallbackConfig(CallbackConfigBase):
     name: Literal["wandb_watch"] = "wandb_watch"

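Every callback config above now self-registers via `@callback_registry.register`, which plausibly lets a config's `Literal` `name` discriminator (e.g. `"norm_logging"`, `"epoch_timer"`) be resolved back to its config class. nshtrainer's actual registry implementation is not part of this diff; the sketch below only illustrates the decorator pattern, and the `Registry` class, its `lookup` method, and the simplified config are all hypothetical stand-ins:

```python
from typing import Literal


class Registry:
    """Hypothetical stand-in for nshtrainer's registry objects."""

    def __init__(self) -> None:
        self._entries: dict[str, type] = {}

    def register(self, cls):
        # Registered configs carry a Literal "name" field whose default
        # (e.g. "epoch_timer") serves as the lookup key.
        self._entries[cls.name] = cls
        return cls  # returned unchanged, so it works as a bare decorator

    def lookup(self, name: str) -> type:
        return self._entries[name]


callback_registry = Registry()


@callback_registry.register
class EpochTimerCallbackConfig:  # simplified stand-in for the real config
    name: Literal["epoch_timer"] = "epoch_timer"


assert callback_registry.lookup("epoch_timer") is EpochTimerCallbackConfig
```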
nshtrainer/configs/__init__.py
CHANGED
@@ -5,6 +5,7 @@ __codegen__ = True
 from nshtrainer import MetricConfig as MetricConfig
 from nshtrainer import TrainerConfig as TrainerConfig
 from nshtrainer import accelerator_registry as accelerator_registry
+from nshtrainer import callback_registry as callback_registry
 from nshtrainer import plugin_registry as plugin_registry
 from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
 from nshtrainer._directory import DirectoryConfig as DirectoryConfig
@@ -13,6 +14,7 @@ from nshtrainer._hf_hub import (
     HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
 )
 from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
+from nshtrainer.callbacks import ActSaveConfig as ActSaveConfig
 from nshtrainer.callbacks import (
     BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
 )
@@ -35,6 +37,7 @@ from nshtrainer.callbacks import (
 from nshtrainer.callbacks import (
     LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
 )
+from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
 from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
 from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
 from nshtrainer.callbacks import (
@@ -49,36 +52,34 @@ from nshtrainer.callbacks import (
 from nshtrainer.callbacks import (
     SharedParametersCallbackConfig as SharedParametersCallbackConfig,
 )
-from nshtrainer.callbacks import (
-    TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
-)
 from nshtrainer.callbacks import (
     WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
 )
 from nshtrainer.callbacks import WandbWatchCallbackConfig as WandbWatchCallbackConfig
-from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
 from nshtrainer.callbacks.checkpoint._base import (
     BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
 )
 from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
-from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
 from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
 from nshtrainer.loggers import LoggerConfig as LoggerConfig
+from nshtrainer.loggers import LoggerConfigBase as LoggerConfigBase
 from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
 from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
+from nshtrainer.loggers import logger_registry as logger_registry
 from nshtrainer.lr_scheduler import (
     LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
 )
 from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
 from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
 from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
-from nshtrainer.
+from nshtrainer.lr_scheduler.base import lr_scheduler_registry as lr_scheduler_registry
 from nshtrainer.nn import ELUNonlinearityConfig as ELUNonlinearityConfig
 from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
 from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
 from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
 from nshtrainer.nn import MLPConfig as MLPConfig
 from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
+from nshtrainer.nn import NonlinearityConfigBase as NonlinearityConfigBase
 from nshtrainer.nn import PReLUConfig as PReLUConfig
 from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
 from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
@@ -91,9 +92,11 @@ from nshtrainer.nn import TanhNonlinearityConfig as TanhNonlinearityConfig
 from nshtrainer.nn.nonlinearity import (
     SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
 )
+from nshtrainer.nn.nonlinearity import nonlinearity_registry as nonlinearity_registry
 from nshtrainer.optimizer import AdamWConfig as AdamWConfig
 from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
 from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
+from nshtrainer.optimizer import optimizer_registry as optimizer_registry
 from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
 from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
 from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
@@ -106,9 +109,6 @@ from nshtrainer.trainer._config import (
 from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
 from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
 from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
-from nshtrainer.trainer._config import (
-    LearningRateMonitorConfig as LearningRateMonitorConfig,
-)
 from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
 from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
@@ -229,8 +229,6 @@ __all__ = [
     "AdvancedProfilerConfig",
     "AsyncCheckpointIOPlugin",
     "BaseCheckpointCallbackConfig",
-    "BaseLoggerConfig",
-    "BaseNonlinearityConfig",
     "BaseProfilerConfig",
     "BestCheckpointCallbackConfig",
     "BitsandbytesPluginConfig",
@@ -284,6 +282,7 @@ __all__ = [
     "LinearWarmupCosineDecayLRSchedulerConfig",
     "LogEpochCallbackConfig",
     "LoggerConfig",
+    "LoggerConfigBase",
     "MLPConfig",
     "MPIEnvironmentPlugin",
     "MPSAcceleratorConfig",
@@ -291,6 +290,7 @@ __all__ = [
     "MishNonlinearityConfig",
     "MixedPrecisionPluginConfig",
     "NonlinearityConfig",
+    "NonlinearityConfigBase",
     "NormLoggingCallbackConfig",
     "OnExceptionCheckpointCallbackConfig",
     "OptimizerConfig",
@@ -320,7 +320,6 @@ __all__ = [
     "SwishNonlinearityConfig",
     "TanhNonlinearityConfig",
     "TensorboardLoggerConfig",
-    "TimeCheckpointCallbackConfig",
     "TorchCheckpointIOPlugin",
     "TorchElasticEnvironmentPlugin",
     "TorchSyncBatchNormPlugin",
@@ -337,12 +336,17 @@ __all__ = [
     "_directory",
     "_hf_hub",
     "accelerator_registry",
+    "callback_registry",
     "callbacks",
+    "logger_registry",
     "loggers",
     "lr_scheduler",
+    "lr_scheduler_registry",
     "metrics",
     "nn",
+    "nonlinearity_registry",
     "optimizer",
+    "optimizer_registry",
     "plugin_registry",
     "profiler",
     "trainer",
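Net effect on the public `nshtrainer.configs` namespace: the per-domain registries are now re-exported alongside the existing `accelerator_registry` and `plugin_registry`, the `Base*` spellings give way to `*ConfigBase`, and the time-based checkpoint config is gone. Per the import and `__all__` changes above, all of the following resolve against 1.0.0b37:

```python
from nshtrainer.configs import (
    LoggerConfigBase,        # replaces BaseLoggerConfig
    NonlinearityConfigBase,  # replaces BaseNonlinearityConfig
    callback_registry,
    logger_registry,
    lr_scheduler_registry,
    nonlinearity_registry,
    optimizer_registry,
)

# Removed in this release; this import now fails:
# from nshtrainer.configs import TimeCheckpointCallbackConfig
```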
nshtrainer/configs/_hf_hub/__init__.py
CHANGED
@@ -7,9 +7,11 @@ from nshtrainer._hf_hub import (
     HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
 )
 from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
+from nshtrainer._hf_hub import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
     "HuggingFaceHubAutoCreateConfig",
     "HuggingFaceHubConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/__init__.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations

 __codegen__ = True

+from nshtrainer.callbacks import ActSaveConfig as ActSaveConfig
 from nshtrainer.callbacks import (
     BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
 )
@@ -25,6 +26,7 @@ from nshtrainer.callbacks import (
 from nshtrainer.callbacks import (
     LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
 )
+from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
 from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
 from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
 from nshtrainer.callbacks import (
@@ -39,14 +41,11 @@ from nshtrainer.callbacks import (
 from nshtrainer.callbacks import (
     SharedParametersCallbackConfig as SharedParametersCallbackConfig,
 )
-from nshtrainer.callbacks import (
-    TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
-)
 from nshtrainer.callbacks import (
     WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
 )
 from nshtrainer.callbacks import WandbWatchCallbackConfig as WandbWatchCallbackConfig
-from nshtrainer.callbacks
+from nshtrainer.callbacks import callback_registry as callback_registry
 from nshtrainer.callbacks.checkpoint._base import (
     BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
 )
@@ -54,9 +53,6 @@ from nshtrainer.callbacks.checkpoint._base import (
     CheckpointMetadata as CheckpointMetadata,
 )
 from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
-from nshtrainer.callbacks.lr_monitor import (
-    LearningRateMonitorConfig as LearningRateMonitorConfig,
-)

 from . import actsave as actsave
 from . import base as base
@@ -100,11 +96,11 @@ __all__ = [
     "PrintTableMetricsCallbackConfig",
     "RLPSanityChecksCallbackConfig",
     "SharedParametersCallbackConfig",
-    "TimeCheckpointCallbackConfig",
     "WandbUploadCodeCallbackConfig",
     "WandbWatchCallbackConfig",
     "actsave",
     "base",
+    "callback_registry",
     "checkpoint",
     "debug_flag",
     "directory_setup",
nshtrainer/configs/callbacks/actsave/__init__.py
CHANGED
@@ -4,8 +4,10 @@ __codegen__ = True

 from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
 from nshtrainer.callbacks.actsave import CallbackConfigBase as CallbackConfigBase
+from nshtrainer.callbacks.actsave import callback_registry as callback_registry

 __all__ = [
     "ActSaveConfig",
     "CallbackConfigBase",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/base/__init__.py
CHANGED
@@ -3,7 +3,9 @@ from __future__ import annotations
 __codegen__ = True

 from nshtrainer.callbacks.base import CallbackConfigBase as CallbackConfigBase
+from nshtrainer.callbacks.base import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/checkpoint/__init__.py
CHANGED
@@ -11,9 +11,6 @@ from nshtrainer.callbacks.checkpoint import (
 from nshtrainer.callbacks.checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
-from nshtrainer.callbacks.checkpoint import (
-    TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
-)
 from nshtrainer.callbacks.checkpoint._base import (
     BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
 )
@@ -24,12 +21,14 @@ from nshtrainer.callbacks.checkpoint._base import (
     CheckpointMetadata as CheckpointMetadata,
 )
 from nshtrainer.callbacks.checkpoint.best_checkpoint import MetricConfig as MetricConfig
+from nshtrainer.callbacks.checkpoint.best_checkpoint import (
+    callback_registry as callback_registry,
+)

 from . import _base as _base
 from . import best_checkpoint as best_checkpoint
 from . import last_checkpoint as last_checkpoint
 from . import on_exception_checkpoint as on_exception_checkpoint
-from . import time_checkpoint as time_checkpoint

 __all__ = [
     "BaseCheckpointCallbackConfig",
@@ -39,10 +38,9 @@ __all__ = [
     "LastCheckpointCallbackConfig",
     "MetricConfig",
     "OnExceptionCheckpointCallbackConfig",
-    "TimeCheckpointCallbackConfig",
     "_base",
     "best_checkpoint",
+    "callback_registry",
     "last_checkpoint",
     "on_exception_checkpoint",
-    "time_checkpoint",
 ]
nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py
CHANGED
@@ -12,10 +12,14 @@ from nshtrainer.callbacks.checkpoint.best_checkpoint import (
     CheckpointMetadata as CheckpointMetadata,
 )
 from nshtrainer.callbacks.checkpoint.best_checkpoint import MetricConfig as MetricConfig
+from nshtrainer.callbacks.checkpoint.best_checkpoint import (
+    callback_registry as callback_registry,
+)

 __all__ = [
     "BaseCheckpointCallbackConfig",
     "BestCheckpointCallbackConfig",
     "CheckpointMetadata",
     "MetricConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py
CHANGED
@@ -11,9 +11,13 @@ from nshtrainer.callbacks.checkpoint.last_checkpoint import (
 from nshtrainer.callbacks.checkpoint.last_checkpoint import (
     LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
 )
+from nshtrainer.callbacks.checkpoint.last_checkpoint import (
+    callback_registry as callback_registry,
+)

 __all__ = [
     "BaseCheckpointCallbackConfig",
     "CheckpointMetadata",
     "LastCheckpointCallbackConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py
CHANGED
@@ -8,8 +8,12 @@ from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
 from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
+from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
+    callback_registry as callback_registry,
+)

 __all__ = [
     "CallbackConfigBase",
     "OnExceptionCheckpointCallbackConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/debug_flag/__init__.py
CHANGED
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.debug_flag import CallbackConfigBase as CallbackConfig
 from nshtrainer.callbacks.debug_flag import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
+from nshtrainer.callbacks.debug_flag import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
     "DebugFlagCallbackConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/directory_setup/__init__.py
CHANGED
@@ -8,8 +8,10 @@ from nshtrainer.callbacks.directory_setup import (
 from nshtrainer.callbacks.directory_setup import (
     DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
 )
+from nshtrainer.callbacks.directory_setup import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
     "DirectorySetupCallbackConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/early_stopping/__init__.py
CHANGED
@@ -7,9 +7,11 @@ from nshtrainer.callbacks.early_stopping import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
 from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
+from nshtrainer.callbacks.early_stopping import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
     "EarlyStoppingCallbackConfig",
     "MetricConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/ema/__init__.py
CHANGED
@@ -4,8 +4,10 @@ __codegen__ = True

 from nshtrainer.callbacks.ema import CallbackConfigBase as CallbackConfigBase
 from nshtrainer.callbacks.ema import EMACallbackConfig as EMACallbackConfig
+from nshtrainer.callbacks.ema import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
     "EMACallbackConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/finite_checks/__init__.py
CHANGED
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.finite_checks import CallbackConfigBase as CallbackCon
 from nshtrainer.callbacks.finite_checks import (
     FiniteChecksCallbackConfig as FiniteChecksCallbackConfig,
 )
+from nshtrainer.callbacks.finite_checks import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
     "FiniteChecksCallbackConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/gradient_skipping/__init__.py
CHANGED
@@ -8,8 +8,12 @@ from nshtrainer.callbacks.gradient_skipping import (
 from nshtrainer.callbacks.gradient_skipping import (
     GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
 )
+from nshtrainer.callbacks.gradient_skipping import (
+    callback_registry as callback_registry,
+)

 __all__ = [
     "CallbackConfigBase",
     "GradientSkippingCallbackConfig",
+    "callback_registry",
 ]
nshtrainer/configs/callbacks/log_epoch/__init__.py
CHANGED
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.log_epoch import CallbackConfigBase as CallbackConfigB
 from nshtrainer.callbacks.log_epoch import (
     LogEpochCallbackConfig as LogEpochCallbackConfig,
 )
+from nshtrainer.callbacks.log_epoch import callback_registry as callback_registry

 __all__ = [
     "CallbackConfigBase",
     "LogEpochCallbackConfig",
+    "callback_registry",
 ]
|