nshtrainer 1.0.0b43__py3-none-any.whl → 1.0.0b45__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/callbacks/__init__.py +4 -0
- nshtrainer/callbacks/metric_validation.py +90 -0
- nshtrainer/configs/__init__.py +4 -0
- nshtrainer/configs/callbacks/__init__.py +6 -0
- nshtrainer/configs/callbacks/metric_validation/__init__.py +21 -0
- nshtrainer/configs/trainer/__init__.py +4 -0
- nshtrainer/configs/trainer/_config/__init__.py +4 -0
- nshtrainer/configs/trainer/trainer/__init__.py +0 -2
- nshtrainer/nn/__init__.py +0 -1
- nshtrainer/nn/mlp.py +60 -60
- nshtrainer/trainer/_config.py +6 -0
- nshtrainer/trainer/trainer.py +0 -1
- {nshtrainer-1.0.0b43.dist-info → nshtrainer-1.0.0b45.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b43.dist-info → nshtrainer-1.0.0b45.dist-info}/RECORD +15 -13
- {nshtrainer-1.0.0b43.dist-info → nshtrainer-1.0.0b45.dist-info}/WHEEL +0 -0
nshtrainer/callbacks/__init__.py
CHANGED
@@ -40,6 +40,10 @@ from .log_epoch import LogEpochCallback as LogEpochCallback
|
|
40
40
|
from .log_epoch import LogEpochCallbackConfig as LogEpochCallbackConfig
|
41
41
|
from .lr_monitor import LearningRateMonitor as LearningRateMonitor
|
42
42
|
from .lr_monitor import LearningRateMonitorConfig as LearningRateMonitorConfig
|
43
|
+
from .metric_validation import MetricValidationCallback as MetricValidationCallback
|
44
|
+
from .metric_validation import (
|
45
|
+
MetricValidationCallbackConfig as MetricValidationCallbackConfig,
|
46
|
+
)
|
43
47
|
from .norm_logging import NormLoggingCallback as NormLoggingCallback
|
44
48
|
from .norm_logging import NormLoggingCallbackConfig as NormLoggingCallbackConfig
|
45
49
|
from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
|
@@ -0,0 +1,90 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from typing import Literal
|
5
|
+
|
6
|
+
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
7
|
+
from typing_extensions import final, override, assert_never
|
8
|
+
from lightning.pytorch import Trainer
|
9
|
+
from lightning.pytorch.callbacks import Callback
|
10
|
+
from ..metrics import MetricConfig
|
11
|
+
from .base import CallbackConfigBase, callback_registry
|
12
|
+
|
13
|
+
log = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
@final
|
17
|
+
@callback_registry.register
|
18
|
+
class MetricValidationCallbackConfig(CallbackConfigBase):
|
19
|
+
name: Literal["metric_validation"] = "metric_validation"
|
20
|
+
|
21
|
+
error_behavior: Literal["raise", "warn"] = "raise"
|
22
|
+
"""
|
23
|
+
Behavior when an error occurs during validation:
|
24
|
+
- "raise": Raise an error and stop the training.
|
25
|
+
- "warn": Log a warning and continue the training.
|
26
|
+
"""
|
27
|
+
|
28
|
+
validate_default_metric: bool = True
|
29
|
+
"""Whether to validate the default metric from the root config."""
|
30
|
+
|
31
|
+
metrics: list[MetricConfig] = []
|
32
|
+
"""List of metrics to validate."""
|
33
|
+
|
34
|
+
@override
|
35
|
+
def create_callbacks(self, trainer_config):
|
36
|
+
metrics = self.metrics.copy()
|
37
|
+
if (
|
38
|
+
self.validate_default_metric
|
39
|
+
and (default_metric := trainer_config.primary_metric) is not None
|
40
|
+
):
|
41
|
+
metrics.append(default_metric)
|
42
|
+
|
43
|
+
yield MetricValidationCallback(self, metrics)
|
44
|
+
|
45
|
+
|
46
|
+
class MetricValidationCallback(Callback):
|
47
|
+
def __init__(
|
48
|
+
self,
|
49
|
+
config: MetricValidationCallbackConfig,
|
50
|
+
metrics: list[MetricConfig],
|
51
|
+
):
|
52
|
+
super().__init__()
|
53
|
+
|
54
|
+
self.config = config
|
55
|
+
self.metrics = metrics
|
56
|
+
|
57
|
+
def _check_metrics(self, trainer: Trainer):
|
58
|
+
metric_names = ", ".join(metric.validation_monitor for metric in self.metrics)
|
59
|
+
log.info(f"Validating metrics: {metric_names}...")
|
60
|
+
logged_metrics = set(trainer.logged_metrics.keys())
|
61
|
+
|
62
|
+
invalid_metrics: list[str] = []
|
63
|
+
for metric in self.metrics:
|
64
|
+
if metric.validation_monitor not in logged_metrics:
|
65
|
+
invalid_metrics.append(metric.validation_monitor)
|
66
|
+
|
67
|
+
if invalid_metrics:
|
68
|
+
msg = (
|
69
|
+
f"The following metrics were not found in logged metrics: {invalid_metrics}\n"
|
70
|
+
f"List of logged metrics: {list(trainer.logged_metrics.keys())}"
|
71
|
+
)
|
72
|
+
match self.config.error_behavior:
|
73
|
+
case "raise":
|
74
|
+
raise MisconfigurationException(msg)
|
75
|
+
case "warn":
|
76
|
+
log.warning(msg)
|
77
|
+
case _:
|
78
|
+
assert_never(self.config.error_behavior)
|
79
|
+
|
80
|
+
@override
|
81
|
+
def on_sanity_check_end(self, trainer, pl_module):
|
82
|
+
super().on_sanity_check_end(trainer, pl_module)
|
83
|
+
|
84
|
+
self._check_metrics(trainer)
|
85
|
+
|
86
|
+
@override
|
87
|
+
def on_validation_end(self, trainer, pl_module):
|
88
|
+
super().on_validation_end(trainer, pl_module)
|
89
|
+
|
90
|
+
self._check_metrics(trainer)
|
nshtrainer/configs/__init__.py
CHANGED
@@ -39,6 +39,9 @@ from nshtrainer.callbacks import (
|
|
39
39
|
)
|
40
40
|
from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
|
41
41
|
from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
|
42
|
+
from nshtrainer.callbacks import (
|
43
|
+
MetricValidationCallbackConfig as MetricValidationCallbackConfig,
|
44
|
+
)
|
42
45
|
from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
|
43
46
|
from nshtrainer.callbacks import (
|
44
47
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
@@ -287,6 +290,7 @@ __all__ = [
|
|
287
290
|
"MPIEnvironmentPlugin",
|
288
291
|
"MPSAcceleratorConfig",
|
289
292
|
"MetricConfig",
|
293
|
+
"MetricValidationCallbackConfig",
|
290
294
|
"MishNonlinearityConfig",
|
291
295
|
"MixedPrecisionPluginConfig",
|
292
296
|
"NonlinearityConfig",
|
@@ -28,6 +28,9 @@ from nshtrainer.callbacks import (
|
|
28
28
|
)
|
29
29
|
from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
|
30
30
|
from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
|
31
|
+
from nshtrainer.callbacks import (
|
32
|
+
MetricValidationCallbackConfig as MetricValidationCallbackConfig,
|
33
|
+
)
|
31
34
|
from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
|
32
35
|
from nshtrainer.callbacks import (
|
33
36
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
@@ -65,6 +68,7 @@ from . import finite_checks as finite_checks
|
|
65
68
|
from . import gradient_skipping as gradient_skipping
|
66
69
|
from . import log_epoch as log_epoch
|
67
70
|
from . import lr_monitor as lr_monitor
|
71
|
+
from . import metric_validation as metric_validation
|
68
72
|
from . import norm_logging as norm_logging
|
69
73
|
from . import print_table as print_table
|
70
74
|
from . import rlp_sanity_checks as rlp_sanity_checks
|
@@ -91,6 +95,7 @@ __all__ = [
|
|
91
95
|
"LearningRateMonitorConfig",
|
92
96
|
"LogEpochCallbackConfig",
|
93
97
|
"MetricConfig",
|
98
|
+
"MetricValidationCallbackConfig",
|
94
99
|
"NormLoggingCallbackConfig",
|
95
100
|
"OnExceptionCheckpointCallbackConfig",
|
96
101
|
"PrintTableMetricsCallbackConfig",
|
@@ -110,6 +115,7 @@ __all__ = [
|
|
110
115
|
"gradient_skipping",
|
111
116
|
"log_epoch",
|
112
117
|
"lr_monitor",
|
118
|
+
"metric_validation",
|
113
119
|
"norm_logging",
|
114
120
|
"print_table",
|
115
121
|
"rlp_sanity_checks",
|
@@ -0,0 +1,21 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__codegen__ = True
|
4
|
+
|
5
|
+
from nshtrainer.callbacks.metric_validation import (
|
6
|
+
CallbackConfigBase as CallbackConfigBase,
|
7
|
+
)
|
8
|
+
from nshtrainer.callbacks.metric_validation import MetricConfig as MetricConfig
|
9
|
+
from nshtrainer.callbacks.metric_validation import (
|
10
|
+
MetricValidationCallbackConfig as MetricValidationCallbackConfig,
|
11
|
+
)
|
12
|
+
from nshtrainer.callbacks.metric_validation import (
|
13
|
+
callback_registry as callback_registry,
|
14
|
+
)
|
15
|
+
|
16
|
+
__all__ = [
|
17
|
+
"CallbackConfigBase",
|
18
|
+
"MetricConfig",
|
19
|
+
"MetricValidationCallbackConfig",
|
20
|
+
"callback_registry",
|
21
|
+
]
|
@@ -38,6 +38,9 @@ from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbac
|
|
38
38
|
from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
|
39
39
|
from nshtrainer.trainer._config import LoggerConfigBase as LoggerConfigBase
|
40
40
|
from nshtrainer.trainer._config import MetricConfig as MetricConfig
|
41
|
+
from nshtrainer.trainer._config import (
|
42
|
+
MetricValidationCallbackConfig as MetricValidationCallbackConfig,
|
43
|
+
)
|
41
44
|
from nshtrainer.trainer._config import (
|
42
45
|
NormLoggingCallbackConfig as NormLoggingCallbackConfig,
|
43
46
|
)
|
@@ -164,6 +167,7 @@ __all__ = [
|
|
164
167
|
"MPIEnvironmentPlugin",
|
165
168
|
"MPSAcceleratorConfig",
|
166
169
|
"MetricConfig",
|
170
|
+
"MetricValidationCallbackConfig",
|
167
171
|
"MixedPrecisionPluginConfig",
|
168
172
|
"NormLoggingCallbackConfig",
|
169
173
|
"OnExceptionCheckpointCallbackConfig",
|
@@ -34,6 +34,9 @@ from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbac
|
|
34
34
|
from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
|
35
35
|
from nshtrainer.trainer._config import LoggerConfigBase as LoggerConfigBase
|
36
36
|
from nshtrainer.trainer._config import MetricConfig as MetricConfig
|
37
|
+
from nshtrainer.trainer._config import (
|
38
|
+
MetricValidationCallbackConfig as MetricValidationCallbackConfig,
|
39
|
+
)
|
37
40
|
from nshtrainer.trainer._config import (
|
38
41
|
NormLoggingCallbackConfig as NormLoggingCallbackConfig,
|
39
42
|
)
|
@@ -77,6 +80,7 @@ __all__ = [
|
|
77
80
|
"LoggerConfig",
|
78
81
|
"LoggerConfigBase",
|
79
82
|
"MetricConfig",
|
83
|
+
"MetricValidationCallbackConfig",
|
80
84
|
"NormLoggingCallbackConfig",
|
81
85
|
"OnExceptionCheckpointCallbackConfig",
|
82
86
|
"PluginConfig",
|
@@ -4,14 +4,12 @@ __codegen__ = True
|
|
4
4
|
|
5
5
|
from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
|
6
6
|
from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
|
7
|
-
from nshtrainer.trainer.trainer import PluginConfigBase as PluginConfigBase
|
8
7
|
from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
|
9
8
|
from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig
|
10
9
|
|
11
10
|
__all__ = [
|
12
11
|
"AcceleratorConfigBase",
|
13
12
|
"EnvironmentConfig",
|
14
|
-
"PluginConfigBase",
|
15
13
|
"StrategyConfigBase",
|
16
14
|
"TrainerConfig",
|
17
15
|
]
|
nshtrainer/nn/__init__.py
CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from .mlp import MLP as MLP
|
4
4
|
from .mlp import MLPConfig as MLPConfig
|
5
|
-
from .mlp import MLPConfigDict as MLPConfigDict
|
6
5
|
from .mlp import ResidualSequential as ResidualSequential
|
7
6
|
from .mlp import custom_seed_context as custom_seed_context
|
8
7
|
from .module_dict import TypedModuleDict as TypedModuleDict
|
nshtrainer/nn/mlp.py
CHANGED
@@ -3,12 +3,12 @@ from __future__ import annotations
|
|
3
3
|
import contextlib
|
4
4
|
import copy
|
5
5
|
from collections.abc import Callable, Sequence
|
6
|
-
from typing import Literal, Protocol, runtime_checkable
|
6
|
+
from typing import Any, Literal, Protocol, runtime_checkable
|
7
7
|
|
8
8
|
import nshconfig as C
|
9
9
|
import torch
|
10
10
|
import torch.nn as nn
|
11
|
-
from typing_extensions import
|
11
|
+
from typing_extensions import deprecated, override
|
12
12
|
|
13
13
|
from .nonlinearity import NonlinearityConfig, NonlinearityConfigBase
|
14
14
|
|
@@ -26,29 +26,6 @@ class ResidualSequential(nn.Sequential):
|
|
26
26
|
return input + super().forward(input)
|
27
27
|
|
28
28
|
|
29
|
-
class MLPConfigDict(TypedDict):
|
30
|
-
bias: bool
|
31
|
-
"""Whether to include bias terms in the linear layers."""
|
32
|
-
|
33
|
-
no_bias_scalar: bool
|
34
|
-
"""Whether to exclude bias terms when the output dimension is 1."""
|
35
|
-
|
36
|
-
nonlinearity: NonlinearityConfig | None
|
37
|
-
"""Activation function to use between layers."""
|
38
|
-
|
39
|
-
ln: bool | Literal["pre", "post"]
|
40
|
-
"""Whether to apply layer normalization before or after the linear layers."""
|
41
|
-
|
42
|
-
dropout: float | None
|
43
|
-
"""Dropout probability to apply between layers."""
|
44
|
-
|
45
|
-
residual: bool
|
46
|
-
"""Whether to use residual connections between layers."""
|
47
|
-
|
48
|
-
seed: int | None
|
49
|
-
"""Random seed to use for initialization. If None, the default Torch behavior is used."""
|
50
|
-
|
51
|
-
|
52
29
|
class MLPConfig(C.Config):
|
53
30
|
bias: bool = True
|
54
31
|
"""Whether to include bias terms in the linear layers."""
|
@@ -71,8 +48,15 @@ class MLPConfig(C.Config):
|
|
71
48
|
seed: int | None = None
|
72
49
|
"""Random seed to use for initialization. If None, the default Torch behavior is used."""
|
73
50
|
|
74
|
-
|
75
|
-
|
51
|
+
@deprecated("Use `nt.nn.MLP(config=...)` instead.")
|
52
|
+
def create_module(
|
53
|
+
self,
|
54
|
+
dims: Sequence[int],
|
55
|
+
pre_layers: Sequence[nn.Module] = [],
|
56
|
+
post_layers: Sequence[nn.Module] = [],
|
57
|
+
linear_cls: LinearModuleConstructor = nn.Linear,
|
58
|
+
):
|
59
|
+
kwargs: dict[str, Any] = {
|
76
60
|
"bias": self.bias,
|
77
61
|
"no_bias_scalar": self.no_bias_scalar,
|
78
62
|
"nonlinearity": self.nonlinearity,
|
@@ -81,18 +65,9 @@ class MLPConfig(C.Config):
|
|
81
65
|
"residual": self.residual,
|
82
66
|
"seed": self.seed,
|
83
67
|
}
|
84
|
-
return kwargs
|
85
|
-
|
86
|
-
def create_module(
|
87
|
-
self,
|
88
|
-
dims: Sequence[int],
|
89
|
-
pre_layers: Sequence[nn.Module] = [],
|
90
|
-
post_layers: Sequence[nn.Module] = [],
|
91
|
-
linear_cls: LinearModuleConstructor = nn.Linear,
|
92
|
-
):
|
93
68
|
return MLP(
|
94
69
|
dims,
|
95
|
-
**
|
70
|
+
**kwargs,
|
96
71
|
pre_layers=pre_layers,
|
97
72
|
post_layers=post_layers,
|
98
73
|
linear_cls=linear_cls,
|
@@ -121,50 +96,73 @@ def MLP(
|
|
121
96
|
| nn.Module
|
122
97
|
| Callable[[], nn.Module]
|
123
98
|
| None = None,
|
124
|
-
bias: bool =
|
125
|
-
no_bias_scalar: bool =
|
126
|
-
ln: bool | Literal["pre", "post"] =
|
99
|
+
bias: bool | None = None,
|
100
|
+
no_bias_scalar: bool | None = None,
|
101
|
+
ln: bool | Literal["pre", "post"] | None = None,
|
127
102
|
dropout: float | None = None,
|
128
|
-
residual: bool =
|
103
|
+
residual: bool | None = None,
|
129
104
|
pre_layers: Sequence[nn.Module] = [],
|
130
105
|
post_layers: Sequence[nn.Module] = [],
|
131
106
|
linear_cls: LinearModuleConstructor = nn.Linear,
|
132
107
|
seed: int | None = None,
|
108
|
+
config: MLPConfig | None = None,
|
133
109
|
):
|
134
110
|
"""
|
135
111
|
Constructs a multi-layer perceptron (MLP) with the given dimensions and activation function.
|
136
112
|
|
137
113
|
Args:
|
138
114
|
dims (Sequence[int]): List of integers representing the dimensions of the MLP.
|
139
|
-
nonlinearity (Callable[[], nn.Module]): Activation function to use between layers.
|
140
|
-
activation (Callable[[], nn.Module]): Activation function to use between layers.
|
141
|
-
bias (bool, optional): Whether to include bias terms in the linear layers.
|
142
|
-
no_bias_scalar (bool, optional): Whether to exclude bias terms when the output dimension is 1.
|
143
|
-
ln (bool | Literal["pre", "post"], optional): Whether to apply layer normalization before or after the linear layers.
|
144
|
-
dropout (float | None, optional): Dropout probability to apply between layers.
|
145
|
-
residual (bool, optional): Whether to use residual connections between layers.
|
115
|
+
nonlinearity (Callable[[], nn.Module] | None, optional): Activation function to use between layers.
|
116
|
+
activation (Callable[[], nn.Module] | None, optional): Activation function to use between layers.
|
117
|
+
bias (bool | None, optional): Whether to include bias terms in the linear layers.
|
118
|
+
no_bias_scalar (bool | None, optional): Whether to exclude bias terms when the output dimension is 1.
|
119
|
+
ln (bool | Literal["pre", "post"] | None, optional): Whether to apply layer normalization before or after the linear layers.
|
120
|
+
dropout (float | None, optional): Dropout probability to apply between layers.
|
121
|
+
residual (bool | None, optional): Whether to use residual connections between layers.
|
146
122
|
pre_layers (Sequence[nn.Module], optional): List of layers to insert before the linear layers. Defaults to [].
|
147
123
|
post_layers (Sequence[nn.Module], optional): List of layers to insert after the linear layers. Defaults to [].
|
148
124
|
linear_cls (LinearModuleConstructor, optional): Linear module constructor to use. Defaults to nn.Linear.
|
149
|
-
seed (int | None, optional): Random seed to use for initialization. If None, the default Torch behavior is used.
|
125
|
+
seed (int | None, optional): Random seed to use for initialization. If None, the default Torch behavior is used.
|
126
|
+
config (MLPConfig | None, optional): Configuration object for the MLP. Parameters specified directly take precedence.
|
150
127
|
|
151
128
|
Returns:
|
152
129
|
nn.Sequential: The constructed MLP.
|
153
130
|
"""
|
154
131
|
|
155
|
-
|
132
|
+
# Resolve parameters: arg if not None, otherwise config value if config exists, otherwise default
|
133
|
+
resolved_bias = bias if bias is not None else (config.bias if config else True)
|
134
|
+
resolved_no_bias_scalar = (
|
135
|
+
no_bias_scalar
|
136
|
+
if no_bias_scalar is not None
|
137
|
+
else (config.no_bias_scalar if config else True)
|
138
|
+
)
|
139
|
+
resolved_nonlinearity = (
|
140
|
+
nonlinearity
|
141
|
+
if nonlinearity is not None
|
142
|
+
else (config.nonlinearity if config else None)
|
143
|
+
)
|
144
|
+
resolved_ln = ln if ln is not None else (config.ln if config else False)
|
145
|
+
resolved_dropout = (
|
146
|
+
dropout if dropout is not None else (config.dropout if config else None)
|
147
|
+
)
|
148
|
+
resolved_residual = (
|
149
|
+
residual if residual is not None else (config.residual if config else False)
|
150
|
+
)
|
151
|
+
resolved_seed = seed if seed is not None else (config.seed if config else None)
|
152
|
+
|
153
|
+
with custom_seed_context(resolved_seed):
|
156
154
|
if activation is None:
|
157
|
-
activation =
|
155
|
+
activation = resolved_nonlinearity
|
158
156
|
|
159
157
|
if len(dims) < 2:
|
160
158
|
raise ValueError("mlp requires at least 2 dimensions")
|
161
|
-
if
|
162
|
-
|
163
|
-
elif isinstance(
|
159
|
+
if resolved_ln is True:
|
160
|
+
resolved_ln = "pre"
|
161
|
+
elif isinstance(resolved_ln, str) and resolved_ln not in ("pre", "post"):
|
164
162
|
raise ValueError("ln must be a boolean or 'pre' or 'post'")
|
165
163
|
|
166
164
|
layers: list[nn.Module] = []
|
167
|
-
if
|
165
|
+
if resolved_ln == "pre":
|
168
166
|
layers.append(nn.LayerNorm(dims[0]))
|
169
167
|
|
170
168
|
layers.extend(pre_layers)
|
@@ -172,10 +170,12 @@ def MLP(
|
|
172
170
|
for i in range(len(dims) - 1):
|
173
171
|
in_features = dims[i]
|
174
172
|
out_features = dims[i + 1]
|
175
|
-
bias_ =
|
173
|
+
bias_ = resolved_bias and not (
|
174
|
+
resolved_no_bias_scalar and out_features == 1
|
175
|
+
)
|
176
176
|
layers.append(linear_cls(in_features, out_features, bias=bias_))
|
177
|
-
if
|
178
|
-
layers.append(nn.Dropout(
|
177
|
+
if resolved_dropout is not None:
|
178
|
+
layers.append(nn.Dropout(resolved_dropout))
|
179
179
|
if i < len(dims) - 2:
|
180
180
|
match activation:
|
181
181
|
case NonlinearityConfigBase():
|
@@ -192,8 +192,8 @@ def MLP(
|
|
192
192
|
|
193
193
|
layers.extend(post_layers)
|
194
194
|
|
195
|
-
if
|
195
|
+
if resolved_ln == "post":
|
196
196
|
layers.append(nn.LayerNorm(dims[-1]))
|
197
197
|
|
198
|
-
cls = ResidualSequential if
|
198
|
+
cls = ResidualSequential if resolved_residual else nn.Sequential
|
199
199
|
return cls(*layers)
|
nshtrainer/trainer/_config.py
CHANGED
@@ -40,6 +40,7 @@ from ..callbacks.base import CallbackConfigBase
|
|
40
40
|
from ..callbacks.debug_flag import DebugFlagCallbackConfig
|
41
41
|
from ..callbacks.log_epoch import LogEpochCallbackConfig
|
42
42
|
from ..callbacks.lr_monitor import LearningRateMonitorConfig
|
43
|
+
from ..callbacks.metric_validation import MetricValidationCallbackConfig
|
43
44
|
from ..callbacks.rlp_sanity_checks import RLPSanityChecksCallbackConfig
|
44
45
|
from ..callbacks.shared_parameters import SharedParametersCallbackConfig
|
45
46
|
from ..loggers import (
|
@@ -697,6 +698,10 @@ class TrainerConfig(C.Config):
|
|
697
698
|
- The trainer is running in fast_dev_run mode.
|
698
699
|
- The trainer is running a sanity check (which happens before starting the training routine).
|
699
700
|
"""
|
701
|
+
auto_validate_metrics: MetricValidationCallbackConfig | None = (
|
702
|
+
MetricValidationCallbackConfig()
|
703
|
+
)
|
704
|
+
"""If enabled, will automatically validate the metrics before starting the training routine."""
|
700
705
|
|
701
706
|
lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
|
702
707
|
"""
|
@@ -768,6 +773,7 @@ class TrainerConfig(C.Config):
|
|
768
773
|
yield self.shared_parameters
|
769
774
|
yield self.reduce_lr_on_plateau_sanity_checking
|
770
775
|
yield self.auto_set_debug_flag
|
776
|
+
yield self.auto_validate_metrics
|
771
777
|
yield from self.callbacks
|
772
778
|
|
773
779
|
def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
|
nshtrainer/trainer/trainer.py
CHANGED
@@ -25,7 +25,6 @@ from ..util.bf16 import is_bf16_supported_no_emulation
|
|
25
25
|
from ._config import LightningTrainerKwargs, TrainerConfig
|
26
26
|
from ._runtime_callback import RuntimeTrackerCallback, Stage
|
27
27
|
from .accelerator import AcceleratorConfigBase
|
28
|
-
from .plugin import PluginConfigBase
|
29
28
|
from .signal_connector import _SignalConnector
|
30
29
|
from .strategy import StrategyConfigBase
|
31
30
|
|
@@ -6,7 +6,7 @@ nshtrainer/_checkpoint/saver.py,sha256=rWl4d2lCTMU4_wt8yZFL2pFQaP9hj5sPgqHMPQ4zu
|
|
6
6
|
nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
|
7
7
|
nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
|
8
8
|
nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
|
9
|
-
nshtrainer/callbacks/__init__.py,sha256=
|
9
|
+
nshtrainer/callbacks/__init__.py,sha256=w80d6PGNu3wjUj9NiRGMqCX9NnXD5ZlvbY-DIK4zjPE,3766
|
10
10
|
nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
|
11
11
|
nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,3683
|
12
12
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
|
@@ -23,6 +23,7 @@ nshtrainer/callbacks/gradient_skipping.py,sha256=8g7oC7PF0LTAEzwiNoaS5tWOnkjk_EB
|
|
23
23
|
nshtrainer/callbacks/interval.py,sha256=UCzUzt3XCFVyQyCWL9lOrStkkxesvduNOYk8yMrGTTk,8116
|
24
24
|
nshtrainer/callbacks/log_epoch.py,sha256=B5Dm8XVZwCzKUhUWfT_5PDdDac993191OsbcxxuSVJE,1457
|
25
25
|
nshtrainer/callbacks/lr_monitor.py,sha256=qy_C0R40J0hBAukzBwng5FI2jJUpWuXOi5N6FU6ym3I,1210
|
26
|
+
nshtrainer/callbacks/metric_validation.py,sha256=tqUVS2n9QRT3v1_8jAGlYBFhLpA6Bm9pxOsfWhD3yZQ,2915
|
26
27
|
nshtrainer/callbacks/norm_logging.py,sha256=nVIDWe-ASl5zN830-ODR8QMCqI1ma-QPCIwoy0Wb-Nk,6390
|
27
28
|
nshtrainer/callbacks/print_table.py,sha256=VaS4JgI963do79laXK4lUkFQx8v6aRSy22W0zyal_LA,3035
|
28
29
|
nshtrainer/callbacks/rlp_sanity_checks.py,sha256=74BZvV2HLO__ucQXsLXb8eJLUZgRFUNJZ6TL9efMp74,10051
|
@@ -31,12 +32,12 @@ nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU
|
|
31
32
|
nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
|
32
33
|
nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
|
33
34
|
nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
|
34
|
-
nshtrainer/configs/__init__.py,sha256=
|
35
|
+
nshtrainer/configs/__init__.py,sha256=0BzCgE1iEJ0Ywmy__mqJZipLQtwZVdz6XK-gHbkA7GY,14650
|
35
36
|
nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
|
36
37
|
nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
|
37
38
|
nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
|
38
39
|
nshtrainer/configs/_hf_hub/__init__.py,sha256=ciFLbV-JV8SVzqo2SyythEuDMnk7gGfdIacB18QYnkY,511
|
39
|
-
nshtrainer/configs/callbacks/__init__.py,sha256=
|
40
|
+
nshtrainer/configs/callbacks/__init__.py,sha256=PB3Jg-8_vMhp-mCFw2_Tqt05drKwHK6Ovl9mb8NNiXs,4506
|
40
41
|
nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JvjSZtEoA28FC4u-QT3skQzBDVbN9eq07rn4u2ydW-E,377
|
41
42
|
nshtrainer/configs/callbacks/base/__init__.py,sha256=wT3RhXttLyf6RFWCIvsoiXcPdfGx5W309WBI18AI5os,278
|
42
43
|
nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=aGJ7vX14YamkMdwYAdPv6XrRnP0aZd5uZ5X0nSLc6IU,1475
|
@@ -52,6 +53,7 @@ nshtrainer/configs/callbacks/finite_checks/__init__.py,sha256=e-vx9Kn-noqw4wPvZw
|
|
52
53
|
nshtrainer/configs/callbacks/gradient_skipping/__init__.py,sha256=T3eVxxJfnYBrO9WfLiycn4TyWP4vaqJ57yp7Epkg7B4,485
|
53
54
|
nshtrainer/configs/callbacks/log_epoch/__init__.py,sha256=IQ5owYYvyk7fiQP1QXYtncRRJrESuq3rRFhab-II2uE,419
|
54
55
|
nshtrainer/configs/callbacks/lr_monitor/__init__.py,sha256=qejy1AnXNDHmsFuXRAXQQ5B0TcbKzvpaw-I4dv2AXIs,431
|
56
|
+
nshtrainer/configs/callbacks/metric_validation/__init__.py,sha256=_YV0EbISkforE_GDlTTVA6Nn2_l13zX3m1ggcbhnAvs,585
|
55
57
|
nshtrainer/configs/callbacks/norm_logging/__init__.py,sha256=j2LrnYEbDGLwJR2lk-jmh-4J_iLEs2HNEoepvJSFLAg,437
|
56
58
|
nshtrainer/configs/callbacks/print_table/__init__.py,sha256=t6fA_dBkUCszUXDJKEdnlBH4oEpfAQqcmAlatTFYIyQ,452
|
57
59
|
nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py,sha256=dlP14Wh-w8zG_B4EtNmCIFzVMhf6bXCJ1O9cJWmEFnA,482
|
@@ -80,8 +82,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
|
|
80
82
|
nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
|
81
83
|
nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
|
82
84
|
nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
|
83
|
-
nshtrainer/configs/trainer/__init__.py,sha256=
|
84
|
-
nshtrainer/configs/trainer/_config/__init__.py,sha256=
|
85
|
+
nshtrainer/configs/trainer/__init__.py,sha256=a8pzGVid52abAVARPbgjaN566H1ZM44FH_x95bsBaGE,7880
|
86
|
+
nshtrainer/configs/trainer/_config/__init__.py,sha256=6DXdtP-uH11TopQ7kzId9fco-wVkD7ZfevbBqDpN6TE,3817
|
85
87
|
nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
|
86
88
|
nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
|
87
89
|
nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=GuubcKrbXt4VjJRT8VpNUQqBtyuutre_CfkS0EWZ5_E,368
|
@@ -90,7 +92,7 @@ nshtrainer/configs/trainer/plugin/io/__init__.py,sha256=W6G67JnigB6d3MiwLrbSKgtI
|
|
90
92
|
nshtrainer/configs/trainer/plugin/layer_sync/__init__.py,sha256=SYDZk2M6sgpt4sEuoURuS8EKYmaqGcvYxETE9jvTrEE,431
|
91
93
|
nshtrainer/configs/trainer/plugin/precision/__init__.py,sha256=szlqSfK2XuWdkf72LQzQFv3SlWfKFdRUpBEYIxQ3TPs,1507
|
92
94
|
nshtrainer/configs/trainer/strategy/__init__.py,sha256=50whNloJVBq_bdbLaPQnPBTeS1Rcs8MwxTCYBj1kKa4,273
|
93
|
-
nshtrainer/configs/trainer/trainer/__init__.py,sha256=
|
95
|
+
nshtrainer/configs/trainer/trainer/__init__.py,sha256=DDuBRx0kVNMW0z_sqKTUt8-Ql7bOpargi4KcHHvDu_c,486
|
94
96
|
nshtrainer/configs/util/__init__.py,sha256=qXittS7f7MyaqJnjvFLKnKsyb6bXTD3dEV16jXVDaH4,2104
|
95
97
|
nshtrainer/configs/util/_environment_info/__init__.py,sha256=eB4E0Ck7XCeSC5gbUdA5thd7TXnjGCL0t8GZIFj7uCI,1644
|
96
98
|
nshtrainer/configs/util/config/__init__.py,sha256=nEFiDG3-dvvTytYn1tEkPFzp7fgaGRp2j7toSN7yRGs,501
|
@@ -117,8 +119,8 @@ nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,1035
|
|
117
119
|
nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
|
118
120
|
nshtrainer/model/mixins/debug.py,sha256=1LX9KzeFX9JDPs_a6YCdYDZXLhEk_5rBO2aCqlfBy7w,2087
|
119
121
|
nshtrainer/model/mixins/logger.py,sha256=27H99FuLaxc6_dDLG2pid4E_5E0-eLGnc2Ifpt0HYIM,6066
|
120
|
-
nshtrainer/nn/__init__.py,sha256=
|
121
|
-
nshtrainer/nn/mlp.py,sha256=
|
122
|
+
nshtrainer/nn/__init__.py,sha256=5Gg3nieGSC5_dXaI9KUVUUbM13hHexH9831m4hcf6no,1475
|
123
|
+
nshtrainer/nn/mlp.py,sha256=nYUgAISzuhC8sav6PloAdyz0PdEoikwppiXIuToEVdE,7550
|
122
124
|
nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
|
123
125
|
nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
|
124
126
|
nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
|
@@ -129,7 +131,7 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
|
|
129
131
|
nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
|
130
132
|
nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
|
131
133
|
nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
|
132
|
-
nshtrainer/trainer/_config.py,sha256=
|
134
|
+
nshtrainer/trainer/_config.py,sha256=pCBRtqIC_BzNPqthsDhd7L5_7DG5y8_uVq19lj1mtOM,33311
|
133
135
|
nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
|
134
136
|
nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
|
135
137
|
nshtrainer/trainer/plugin/__init__.py,sha256=UM8f70Ml3RGpsXeQr1Yh1yBcxxjFNGFGgJbniCn_rws,366
|
@@ -140,7 +142,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv
|
|
140
142
|
nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
|
141
143
|
nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
|
142
144
|
nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
|
143
|
-
nshtrainer/trainer/trainer.py,sha256=
|
145
|
+
nshtrainer/trainer/trainer.py,sha256=8wMe0qArbDfStS4UdmuKSC2aiAImR3mhj14_kCJiNSM,20797
|
144
146
|
nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
|
145
147
|
nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
|
146
148
|
nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
|
@@ -152,6 +154,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
|
|
152
154
|
nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
|
153
155
|
nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
|
154
156
|
nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
|
155
|
-
nshtrainer-1.0.
|
156
|
-
nshtrainer-1.0.
|
157
|
-
nshtrainer-1.0.
|
157
|
+
nshtrainer-1.0.0b45.dist-info/METADATA,sha256=_RPpe6F7DXpsQSmBF1GTc-E5VUfaC69fIYfoFhsip2s,988
|
158
|
+
nshtrainer-1.0.0b45.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
|
159
|
+
nshtrainer-1.0.0b45.dist-info/RECORD,,
|
File without changes
|