nshtrainer 1.0.0b43__py3-none-any.whl → 1.0.0b45__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
nshtrainer/callbacks/__init__.py CHANGED
@@ -40,6 +40,10 @@ from .log_epoch import LogEpochCallback as LogEpochCallback
 from .log_epoch import LogEpochCallbackConfig as LogEpochCallbackConfig
 from .lr_monitor import LearningRateMonitor as LearningRateMonitor
 from .lr_monitor import LearningRateMonitorConfig as LearningRateMonitorConfig
+from .metric_validation import MetricValidationCallback as MetricValidationCallback
+from .metric_validation import (
+    MetricValidationCallbackConfig as MetricValidationCallbackConfig,
+)
 from .norm_logging import NormLoggingCallback as NormLoggingCallback
 from .norm_logging import NormLoggingCallbackConfig as NormLoggingCallbackConfig
 from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
nshtrainer/callbacks/metric_validation.py ADDED
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+import logging
+from typing import Literal
+
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from typing_extensions import final, override, assert_never
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import Callback
+from ..metrics import MetricConfig
+from .base import CallbackConfigBase, callback_registry
+
+log = logging.getLogger(__name__)
+
+
+@final
+@callback_registry.register
+class MetricValidationCallbackConfig(CallbackConfigBase):
+    name: Literal["metric_validation"] = "metric_validation"
+
+    error_behavior: Literal["raise", "warn"] = "raise"
+    """
+    Behavior when an error occurs during validation:
+    - "raise": Raise an error and stop the training.
+    - "warn": Log a warning and continue the training.
+    """
+
+    validate_default_metric: bool = True
+    """Whether to validate the default metric from the root config."""
+
+    metrics: list[MetricConfig] = []
+    """List of metrics to validate."""
+
+    @override
+    def create_callbacks(self, trainer_config):
+        metrics = self.metrics.copy()
+        if (
+            self.validate_default_metric
+            and (default_metric := trainer_config.primary_metric) is not None
+        ):
+            metrics.append(default_metric)
+
+        yield MetricValidationCallback(self, metrics)
+
+
+class MetricValidationCallback(Callback):
+    def __init__(
+        self,
+        config: MetricValidationCallbackConfig,
+        metrics: list[MetricConfig],
+    ):
+        super().__init__()
+
+        self.config = config
+        self.metrics = metrics
+
+    def _check_metrics(self, trainer: Trainer):
+        metric_names = ", ".join(metric.validation_monitor for metric in self.metrics)
+        log.info(f"Validating metrics: {metric_names}...")
+        logged_metrics = set(trainer.logged_metrics.keys())
+
+        invalid_metrics: list[str] = []
+        for metric in self.metrics:
+            if metric.validation_monitor not in logged_metrics:
+                invalid_metrics.append(metric.validation_monitor)
+
+        if invalid_metrics:
+            msg = (
+                f"The following metrics were not found in logged metrics: {invalid_metrics}\n"
+                f"List of logged metrics: {list(trainer.logged_metrics.keys())}"
+            )
+            match self.config.error_behavior:
+                case "raise":
+                    raise MisconfigurationException(msg)
+                case "warn":
+                    log.warning(msg)
+                case _:
+                    assert_never(self.config.error_behavior)
+
+    @override
+    def on_sanity_check_end(self, trainer, pl_module):
+        super().on_sanity_check_end(trainer, pl_module)
+
+        self._check_metrics(trainer)
+
+    @override
+    def on_validation_end(self, trainer, pl_module):
+        super().on_validation_end(trainer, pl_module)
+
+        self._check_metrics(trainer)
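The new `MetricValidationCallback` checks, at the end of the sanity check and after each validation run, that every configured metric actually appears in `trainer.logged_metrics`, and either raises or warns per `error_behavior`. A minimal usage sketch using only the fields visible above (all values here are illustrative):

```python
from nshtrainer.callbacks import MetricValidationCallbackConfig

# Default: also validate the root config's primary metric; raise on mismatch.
strict = MetricValidationCallbackConfig()

# Softer variant: warn instead of raising, and skip the default-metric check.
lenient = MetricValidationCallbackConfig(
    error_behavior="warn",
    validate_default_metric=False,
)
```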
nshtrainer/configs/__init__.py CHANGED
@@ -39,6 +39,9 @@ from nshtrainer.callbacks import (
 )
 from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
 from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
+from nshtrainer.callbacks import (
+    MetricValidationCallbackConfig as MetricValidationCallbackConfig,
+)
 from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
 from nshtrainer.callbacks import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
@@ -287,6 +290,7 @@ __all__ = [
     "MPIEnvironmentPlugin",
     "MPSAcceleratorConfig",
     "MetricConfig",
+    "MetricValidationCallbackConfig",
     "MishNonlinearityConfig",
     "MixedPrecisionPluginConfig",
     "NonlinearityConfig",
nshtrainer/configs/callbacks/__init__.py CHANGED
@@ -28,6 +28,9 @@ from nshtrainer.callbacks import (
 )
 from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
 from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
+from nshtrainer.callbacks import (
+    MetricValidationCallbackConfig as MetricValidationCallbackConfig,
+)
 from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
 from nshtrainer.callbacks import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
@@ -65,6 +68,7 @@ from . import finite_checks as finite_checks
 from . import gradient_skipping as gradient_skipping
 from . import log_epoch as log_epoch
 from . import lr_monitor as lr_monitor
+from . import metric_validation as metric_validation
 from . import norm_logging as norm_logging
 from . import print_table as print_table
 from . import rlp_sanity_checks as rlp_sanity_checks
@@ -91,6 +95,7 @@ __all__ = [
     "LearningRateMonitorConfig",
     "LogEpochCallbackConfig",
     "MetricConfig",
+    "MetricValidationCallbackConfig",
     "NormLoggingCallbackConfig",
     "OnExceptionCheckpointCallbackConfig",
     "PrintTableMetricsCallbackConfig",
@@ -110,6 +115,7 @@ __all__ = [
     "gradient_skipping",
     "log_epoch",
     "lr_monitor",
+    "metric_validation",
     "norm_logging",
     "print_table",
     "rlp_sanity_checks",
nshtrainer/configs/callbacks/metric_validation/__init__.py ADDED
@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from nshtrainer.callbacks.metric_validation import (
+    CallbackConfigBase as CallbackConfigBase,
+)
+from nshtrainer.callbacks.metric_validation import MetricConfig as MetricConfig
+from nshtrainer.callbacks.metric_validation import (
+    MetricValidationCallbackConfig as MetricValidationCallbackConfig,
+)
+from nshtrainer.callbacks.metric_validation import (
+    callback_registry as callback_registry,
+)
+
+__all__ = [
+    "CallbackConfigBase",
+    "MetricConfig",
+    "MetricValidationCallbackConfig",
+    "callback_registry",
+]
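This generated module only re-exports the runtime definitions, so the same config class is importable from either path; a quick sanity check of that equivalence (assuming both modules import cleanly):

```python
from nshtrainer.callbacks import MetricValidationCallbackConfig
from nshtrainer.configs.callbacks.metric_validation import (
    MetricValidationCallbackConfig as GeneratedConfig,
)

# Both names should resolve to the same class object, since the configs
# tree re-exports the class defined in nshtrainer.callbacks.metric_validation.
assert MetricValidationCallbackConfig is GeneratedConfig
```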
nshtrainer/configs/trainer/__init__.py CHANGED
@@ -38,6 +38,9 @@ from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbac
 from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
 from nshtrainer.trainer._config import LoggerConfigBase as LoggerConfigBase
 from nshtrainer.trainer._config import MetricConfig as MetricConfig
+from nshtrainer.trainer._config import (
+    MetricValidationCallbackConfig as MetricValidationCallbackConfig,
+)
 from nshtrainer.trainer._config import (
     NormLoggingCallbackConfig as NormLoggingCallbackConfig,
 )
@@ -164,6 +167,7 @@ __all__ = [
     "MPIEnvironmentPlugin",
     "MPSAcceleratorConfig",
     "MetricConfig",
+    "MetricValidationCallbackConfig",
     "MixedPrecisionPluginConfig",
     "NormLoggingCallbackConfig",
     "OnExceptionCheckpointCallbackConfig",
nshtrainer/configs/trainer/_config/__init__.py CHANGED
@@ -34,6 +34,9 @@ from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbac
 from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
 from nshtrainer.trainer._config import LoggerConfigBase as LoggerConfigBase
 from nshtrainer.trainer._config import MetricConfig as MetricConfig
+from nshtrainer.trainer._config import (
+    MetricValidationCallbackConfig as MetricValidationCallbackConfig,
+)
 from nshtrainer.trainer._config import (
     NormLoggingCallbackConfig as NormLoggingCallbackConfig,
 )
@@ -77,6 +80,7 @@ __all__ = [
     "LoggerConfig",
     "LoggerConfigBase",
     "MetricConfig",
+    "MetricValidationCallbackConfig",
     "NormLoggingCallbackConfig",
     "OnExceptionCheckpointCallbackConfig",
     "PluginConfig",
nshtrainer/configs/trainer/trainer/__init__.py CHANGED
@@ -4,14 +4,12 @@ __codegen__ = True
 
 from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
 from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
-from nshtrainer.trainer.trainer import PluginConfigBase as PluginConfigBase
 from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
 from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig
 
 __all__ = [
     "AcceleratorConfigBase",
     "EnvironmentConfig",
-    "PluginConfigBase",
     "StrategyConfigBase",
     "TrainerConfig",
 ]
nshtrainer/nn/__init__.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 from .mlp import MLP as MLP
 from .mlp import MLPConfig as MLPConfig
-from .mlp import MLPConfigDict as MLPConfigDict
 from .mlp import ResidualSequential as ResidualSequential
 from .mlp import custom_seed_context as custom_seed_context
 from .module_dict import TypedModuleDict as TypedModuleDict
nshtrainer/nn/mlp.py CHANGED
@@ -3,12 +3,12 @@ from __future__ import annotations
 import contextlib
 import copy
 from collections.abc import Callable, Sequence
-from typing import Literal, Protocol, runtime_checkable
+from typing import Any, Literal, Protocol, runtime_checkable
 
 import nshconfig as C
 import torch
 import torch.nn as nn
-from typing_extensions import TypedDict, override
+from typing_extensions import deprecated, override
 
 from .nonlinearity import NonlinearityConfig, NonlinearityConfigBase
 
@@ -26,29 +26,6 @@ class ResidualSequential(nn.Sequential):
         return input + super().forward(input)
 
 
-class MLPConfigDict(TypedDict):
-    bias: bool
-    """Whether to include bias terms in the linear layers."""
-
-    no_bias_scalar: bool
-    """Whether to exclude bias terms when the output dimension is 1."""
-
-    nonlinearity: NonlinearityConfig | None
-    """Activation function to use between layers."""
-
-    ln: bool | Literal["pre", "post"]
-    """Whether to apply layer normalization before or after the linear layers."""
-
-    dropout: float | None
-    """Dropout probability to apply between layers."""
-
-    residual: bool
-    """Whether to use residual connections between layers."""
-
-    seed: int | None
-    """Random seed to use for initialization. If None, the default Torch behavior is used."""
-
-
 class MLPConfig(C.Config):
     bias: bool = True
     """Whether to include bias terms in the linear layers."""
@@ -71,8 +48,15 @@ class MLPConfig(C.Config):
     seed: int | None = None
     """Random seed to use for initialization. If None, the default Torch behavior is used."""
 
-    def to_kwargs(self) -> MLPConfigDict:
-        kwargs: MLPConfigDict = {
+    @deprecated("Use `nt.nn.MLP(config=...)` instead.")
+    def create_module(
+        self,
+        dims: Sequence[int],
+        pre_layers: Sequence[nn.Module] = [],
+        post_layers: Sequence[nn.Module] = [],
+        linear_cls: LinearModuleConstructor = nn.Linear,
+    ):
+        kwargs: dict[str, Any] = {
            "bias": self.bias,
            "no_bias_scalar": self.no_bias_scalar,
            "nonlinearity": self.nonlinearity,
@@ -81,18 +65,9 @@ class MLPConfig(C.Config):
            "residual": self.residual,
            "seed": self.seed,
         }
-        return kwargs
-
-    def create_module(
-        self,
-        dims: Sequence[int],
-        pre_layers: Sequence[nn.Module] = [],
-        post_layers: Sequence[nn.Module] = [],
-        linear_cls: LinearModuleConstructor = nn.Linear,
-    ):
         return MLP(
             dims,
-            **self.to_kwargs(),
+            **kwargs,
             pre_layers=pre_layers,
             post_layers=post_layers,
             linear_cls=linear_cls,
@@ -121,50 +96,73 @@ def MLP(
     | nn.Module
     | Callable[[], nn.Module]
     | None = None,
-    bias: bool = True,
-    no_bias_scalar: bool = True,
-    ln: bool | Literal["pre", "post"] = False,
+    bias: bool | None = None,
+    no_bias_scalar: bool | None = None,
+    ln: bool | Literal["pre", "post"] | None = None,
     dropout: float | None = None,
-    residual: bool = False,
+    residual: bool | None = None,
     pre_layers: Sequence[nn.Module] = [],
     post_layers: Sequence[nn.Module] = [],
     linear_cls: LinearModuleConstructor = nn.Linear,
     seed: int | None = None,
+    config: MLPConfig | None = None,
 ):
     """
     Constructs a multi-layer perceptron (MLP) with the given dimensions and activation function.
 
     Args:
         dims (Sequence[int]): List of integers representing the dimensions of the MLP.
-        nonlinearity (Callable[[], nn.Module]): Activation function to use between layers.
-        activation (Callable[[], nn.Module]): Activation function to use between layers.
-        bias (bool, optional): Whether to include bias terms in the linear layers. Defaults to True.
-        no_bias_scalar (bool, optional): Whether to exclude bias terms when the output dimension is 1. Defaults to True.
-        ln (bool | Literal["pre", "post"], optional): Whether to apply layer normalization before or after the linear layers. Defaults to False.
-        dropout (float | None, optional): Dropout probability to apply between layers. Defaults to None.
-        residual (bool, optional): Whether to use residual connections between layers. Defaults to False.
+        nonlinearity (Callable[[], nn.Module] | None, optional): Activation function to use between layers.
+        activation (Callable[[], nn.Module] | None, optional): Activation function to use between layers.
+        bias (bool | None, optional): Whether to include bias terms in the linear layers.
+        no_bias_scalar (bool | None, optional): Whether to exclude bias terms when the output dimension is 1.
+        ln (bool | Literal["pre", "post"] | None, optional): Whether to apply layer normalization before or after the linear layers.
+        dropout (float | None, optional): Dropout probability to apply between layers.
+        residual (bool | None, optional): Whether to use residual connections between layers.
         pre_layers (Sequence[nn.Module], optional): List of layers to insert before the linear layers. Defaults to [].
         post_layers (Sequence[nn.Module], optional): List of layers to insert after the linear layers. Defaults to [].
        linear_cls (LinearModuleConstructor, optional): Linear module constructor to use. Defaults to nn.Linear.
-        seed (int | None, optional): Random seed to use for initialization. If None, the default Torch behavior is used. Defaults to None.
+        seed (int | None, optional): Random seed to use for initialization. If None, the default Torch behavior is used.
+        config (MLPConfig | None, optional): Configuration object for the MLP. Parameters specified directly take precedence.
 
     Returns:
         nn.Sequential: The constructed MLP.
     """
 
-    with custom_seed_context(seed):
+    # Resolve parameters: arg if not None, otherwise config value if config exists, otherwise default
+    resolved_bias = bias if bias is not None else (config.bias if config else True)
+    resolved_no_bias_scalar = (
+        no_bias_scalar
+        if no_bias_scalar is not None
+        else (config.no_bias_scalar if config else True)
+    )
+    resolved_nonlinearity = (
+        nonlinearity
+        if nonlinearity is not None
+        else (config.nonlinearity if config else None)
+    )
+    resolved_ln = ln if ln is not None else (config.ln if config else False)
+    resolved_dropout = (
+        dropout if dropout is not None else (config.dropout if config else None)
+    )
+    resolved_residual = (
+        residual if residual is not None else (config.residual if config else False)
+    )
+    resolved_seed = seed if seed is not None else (config.seed if config else None)
+
+    with custom_seed_context(resolved_seed):
         if activation is None:
-            activation = nonlinearity
+            activation = resolved_nonlinearity
 
         if len(dims) < 2:
             raise ValueError("mlp requires at least 2 dimensions")
-        if ln is True:
-            ln = "pre"
-        elif isinstance(ln, str) and ln not in ("pre", "post"):
+        if resolved_ln is True:
+            resolved_ln = "pre"
+        elif isinstance(resolved_ln, str) and resolved_ln not in ("pre", "post"):
             raise ValueError("ln must be a boolean or 'pre' or 'post'")
 
         layers: list[nn.Module] = []
-        if ln == "pre":
+        if resolved_ln == "pre":
            layers.append(nn.LayerNorm(dims[0]))
 
         layers.extend(pre_layers)
@@ -172,10 +170,12 @@ def MLP(
         for i in range(len(dims) - 1):
             in_features = dims[i]
             out_features = dims[i + 1]
-            bias_ = bias and not (no_bias_scalar and out_features == 1)
+            bias_ = resolved_bias and not (
+                resolved_no_bias_scalar and out_features == 1
+            )
             layers.append(linear_cls(in_features, out_features, bias=bias_))
-            if dropout is not None:
-                layers.append(nn.Dropout(dropout))
+            if resolved_dropout is not None:
+                layers.append(nn.Dropout(resolved_dropout))
             if i < len(dims) - 2:
                 match activation:
                     case NonlinearityConfigBase():
@@ -192,8 +192,8 @@ def MLP(
 
         layers.extend(post_layers)
 
-        if ln == "post":
+        if resolved_ln == "post":
             layers.append(nn.LayerNorm(dims[-1]))
 
-        cls = ResidualSequential if residual else nn.Sequential
+        cls = ResidualSequential if resolved_residual else nn.Sequential
         return cls(*layers)
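With `to_kwargs`/`MLPConfigDict` gone, `MLP()` now accepts an optional `config=` and resolves each `None` argument from it, falling back to the old defaults; explicitly passed arguments win. A sketch of the intended call pattern (hyperparameter values are illustrative):

```python
import torch.nn as nn

from nshtrainer.nn import MLP, MLPConfig

config = MLPConfig(dropout=0.1, residual=False)

mlp = MLP(
    [64, 64, 1],
    activation=nn.ReLU,  # the activation may also come from config.nonlinearity
    config=config,
    dropout=0.2,  # explicit argument overrides config.dropout (0.1)
)

# config.create_module([...]) still works but now emits a deprecation
# warning pointing to MLP(config=...), per the @deprecated decorator above.
```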
nshtrainer/trainer/_config.py CHANGED
@@ -40,6 +40,7 @@ from ..callbacks.base import CallbackConfigBase
 from ..callbacks.debug_flag import DebugFlagCallbackConfig
 from ..callbacks.log_epoch import LogEpochCallbackConfig
 from ..callbacks.lr_monitor import LearningRateMonitorConfig
+from ..callbacks.metric_validation import MetricValidationCallbackConfig
 from ..callbacks.rlp_sanity_checks import RLPSanityChecksCallbackConfig
 from ..callbacks.shared_parameters import SharedParametersCallbackConfig
 from ..loggers import (
@@ -697,6 +698,10 @@ class TrainerConfig(C.Config):
    - The trainer is running in fast_dev_run mode.
    - The trainer is running a sanity check (which happens before starting the training routine).
    """
+    auto_validate_metrics: MetricValidationCallbackConfig | None = (
+        MetricValidationCallbackConfig()
+    )
+    """If enabled, will automatically validate the metrics before starting the training routine."""
 
     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
@@ -768,6 +773,7 @@ class TrainerConfig(C.Config):
         yield self.shared_parameters
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
+        yield self.auto_validate_metrics
         yield from self.callbacks
 
     def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
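`TrainerConfig` now enables the callback by default via the new `auto_validate_metrics` field and yields it alongside the other built-in callback configs. A sketch of tuning or disabling it (assuming the remaining `TrainerConfig` fields keep their defaults):

```python
from nshtrainer.callbacks import MetricValidationCallbackConfig
from nshtrainer.trainer._config import TrainerConfig

# Downgrade missing-metric errors to warnings...
config = TrainerConfig(
    auto_validate_metrics=MetricValidationCallbackConfig(error_behavior="warn"),
)

# ...or opt out of the automatic validation entirely.
config = TrainerConfig(auto_validate_metrics=None)
```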
nshtrainer/trainer/trainer.py CHANGED
@@ -25,7 +25,6 @@ from ..util.bf16 import is_bf16_supported_no_emulation
 from ._config import LightningTrainerKwargs, TrainerConfig
 from ._runtime_callback import RuntimeTrackerCallback, Stage
 from .accelerator import AcceleratorConfigBase
-from .plugin import PluginConfigBase
 from .signal_connector import _SignalConnector
 from .strategy import StrategyConfigBase
 
nshtrainer-1.0.0b45.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.0.0b43
+Version: 1.0.0b45
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
nshtrainer-1.0.0b45.dist-info/RECORD CHANGED
@@ -6,7 +6,7 @@ nshtrainer/_checkpoint/saver.py,sha256=rWl4d2lCTMU4_wt8yZFL2pFQaP9hj5sPgqHMPQ4zu
 nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
 nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
 nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
-nshtrainer/callbacks/__init__.py,sha256=4giOYT8A709UOLRtQEt16QbOAFUHCjJ_aLB7ITTwXJI,3577
+nshtrainer/callbacks/__init__.py,sha256=w80d6PGNu3wjUj9NiRGMqCX9NnXD5ZlvbY-DIK4zjPE,3766
 nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
 nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,3683
 nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
@@ -23,6 +23,7 @@ nshtrainer/callbacks/gradient_skipping.py,sha256=8g7oC7PF0LTAEzwiNoaS5tWOnkjk_EB
 nshtrainer/callbacks/interval.py,sha256=UCzUzt3XCFVyQyCWL9lOrStkkxesvduNOYk8yMrGTTk,8116
 nshtrainer/callbacks/log_epoch.py,sha256=B5Dm8XVZwCzKUhUWfT_5PDdDac993191OsbcxxuSVJE,1457
 nshtrainer/callbacks/lr_monitor.py,sha256=qy_C0R40J0hBAukzBwng5FI2jJUpWuXOi5N6FU6ym3I,1210
+nshtrainer/callbacks/metric_validation.py,sha256=tqUVS2n9QRT3v1_8jAGlYBFhLpA6Bm9pxOsfWhD3yZQ,2915
 nshtrainer/callbacks/norm_logging.py,sha256=nVIDWe-ASl5zN830-ODR8QMCqI1ma-QPCIwoy0Wb-Nk,6390
 nshtrainer/callbacks/print_table.py,sha256=VaS4JgI963do79laXK4lUkFQx8v6aRSy22W0zyal_LA,3035
 nshtrainer/callbacks/rlp_sanity_checks.py,sha256=74BZvV2HLO__ucQXsLXb8eJLUZgRFUNJZ6TL9efMp74,10051
@@ -31,12 +32,12 @@ nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU
 nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
 nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
 nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
-nshtrainer/configs/__init__.py,sha256=MZfcSKhnjtVObBvVv9lu8L2cFTLINP5zcTQvWnz8jdk,14505
+nshtrainer/configs/__init__.py,sha256=0BzCgE1iEJ0Ywmy__mqJZipLQtwZVdz6XK-gHbkA7GY,14650
 nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
 nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
 nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
 nshtrainer/configs/_hf_hub/__init__.py,sha256=ciFLbV-JV8SVzqo2SyythEuDMnk7gGfdIacB18QYnkY,511
-nshtrainer/configs/callbacks/__init__.py,sha256=jSWkbsdiu9vdGWTzqkDf-Bo9dXr9RengeNZLzWUhi7Y,4283
+nshtrainer/configs/callbacks/__init__.py,sha256=PB3Jg-8_vMhp-mCFw2_Tqt05drKwHK6Ovl9mb8NNiXs,4506
 nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JvjSZtEoA28FC4u-QT3skQzBDVbN9eq07rn4u2ydW-E,377
 nshtrainer/configs/callbacks/base/__init__.py,sha256=wT3RhXttLyf6RFWCIvsoiXcPdfGx5W309WBI18AI5os,278
 nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=aGJ7vX14YamkMdwYAdPv6XrRnP0aZd5uZ5X0nSLc6IU,1475
@@ -52,6 +53,7 @@ nshtrainer/configs/callbacks/finite_checks/__init__.py,sha256=e-vx9Kn-noqw4wPvZw
 nshtrainer/configs/callbacks/gradient_skipping/__init__.py,sha256=T3eVxxJfnYBrO9WfLiycn4TyWP4vaqJ57yp7Epkg7B4,485
 nshtrainer/configs/callbacks/log_epoch/__init__.py,sha256=IQ5owYYvyk7fiQP1QXYtncRRJrESuq3rRFhab-II2uE,419
 nshtrainer/configs/callbacks/lr_monitor/__init__.py,sha256=qejy1AnXNDHmsFuXRAXQQ5B0TcbKzvpaw-I4dv2AXIs,431
+nshtrainer/configs/callbacks/metric_validation/__init__.py,sha256=_YV0EbISkforE_GDlTTVA6Nn2_l13zX3m1ggcbhnAvs,585
 nshtrainer/configs/callbacks/norm_logging/__init__.py,sha256=j2LrnYEbDGLwJR2lk-jmh-4J_iLEs2HNEoepvJSFLAg,437
 nshtrainer/configs/callbacks/print_table/__init__.py,sha256=t6fA_dBkUCszUXDJKEdnlBH4oEpfAQqcmAlatTFYIyQ,452
 nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py,sha256=dlP14Wh-w8zG_B4EtNmCIFzVMhf6bXCJ1O9cJWmEFnA,482
@@ -80,8 +82,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
 nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
 nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
 nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
-nshtrainer/configs/trainer/__init__.py,sha256=jYCp4Q9uvutA6NYqfthbREMg09-obD3gHtzEI2Ta-hU,7729
-nshtrainer/configs/trainer/_config/__init__.py,sha256=uof_oJfhwjB1pft7KsRdk_RvNj-tE8wcDBEM7X5qtNc,3666
+nshtrainer/configs/trainer/__init__.py,sha256=a8pzGVid52abAVARPbgjaN566H1ZM44FH_x95bsBaGE,7880
+nshtrainer/configs/trainer/_config/__init__.py,sha256=6DXdtP-uH11TopQ7kzId9fco-wVkD7ZfevbBqDpN6TE,3817
 nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
 nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
 nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=GuubcKrbXt4VjJRT8VpNUQqBtyuutre_CfkS0EWZ5_E,368
@@ -90,7 +92,7 @@ nshtrainer/configs/trainer/plugin/io/__init__.py,sha256=W6G67JnigB6d3MiwLrbSKgtI
 nshtrainer/configs/trainer/plugin/layer_sync/__init__.py,sha256=SYDZk2M6sgpt4sEuoURuS8EKYmaqGcvYxETE9jvTrEE,431
 nshtrainer/configs/trainer/plugin/precision/__init__.py,sha256=szlqSfK2XuWdkf72LQzQFv3SlWfKFdRUpBEYIxQ3TPs,1507
 nshtrainer/configs/trainer/strategy/__init__.py,sha256=50whNloJVBq_bdbLaPQnPBTeS1Rcs8MwxTCYBj1kKa4,273
-nshtrainer/configs/trainer/trainer/__init__.py,sha256=QnuhMQNAa1nSVN2o50_WeKAQG_qkNlkeoq9zTjjwmTI,586
+nshtrainer/configs/trainer/trainer/__init__.py,sha256=DDuBRx0kVNMW0z_sqKTUt8-Ql7bOpargi4KcHHvDu_c,486
 nshtrainer/configs/util/__init__.py,sha256=qXittS7f7MyaqJnjvFLKnKsyb6bXTD3dEV16jXVDaH4,2104
 nshtrainer/configs/util/_environment_info/__init__.py,sha256=eB4E0Ck7XCeSC5gbUdA5thd7TXnjGCL0t8GZIFj7uCI,1644
 nshtrainer/configs/util/config/__init__.py,sha256=nEFiDG3-dvvTytYn1tEkPFzp7fgaGRp2j7toSN7yRGs,501
@@ -117,8 +119,8 @@ nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,1035
 nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
 nshtrainer/model/mixins/debug.py,sha256=1LX9KzeFX9JDPs_a6YCdYDZXLhEk_5rBO2aCqlfBy7w,2087
 nshtrainer/model/mixins/logger.py,sha256=27H99FuLaxc6_dDLG2pid4E_5E0-eLGnc2Ifpt0HYIM,6066
-nshtrainer/nn/__init__.py,sha256=0FgeoaLYtRiSLT8fdPigLD8t-d8DKR8IQDw16JA9lT4,1523
-nshtrainer/nn/mlp.py,sha256=_a8rJJniSCvM08gyQGO-5MUoO18U9_FSGGn3tZL2_U4,7101
+nshtrainer/nn/__init__.py,sha256=5Gg3nieGSC5_dXaI9KUVUUbM13hHexH9831m4hcf6no,1475
+nshtrainer/nn/mlp.py,sha256=nYUgAISzuhC8sav6PloAdyz0PdEoikwppiXIuToEVdE,7550
 nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
 nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
 nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
@@ -129,7 +131,7 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
 nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
 nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
 nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
-nshtrainer/trainer/_config.py,sha256=QDy6sINVDGEqfHfPTWXSN-06EoEuMSVscHn8fCRTvr0,32981
+nshtrainer/trainer/_config.py,sha256=pCBRtqIC_BzNPqthsDhd7L5_7DG5y8_uVq19lj1mtOM,33311
 nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
 nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
 nshtrainer/trainer/plugin/__init__.py,sha256=UM8f70Ml3RGpsXeQr1Yh1yBcxxjFNGFGgJbniCn_rws,366
@@ -140,7 +142,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv
 nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
 nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
 nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
-nshtrainer/trainer/trainer.py,sha256=ed_Pn-yQCb9BqaHXo2wVhkt2CSfGNEzMAM6RsDoTo-I,20834
+nshtrainer/trainer/trainer.py,sha256=8wMe0qArbDfStS4UdmuKSC2aiAImR3mhj14_kCJiNSM,20797
 nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
 nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
@@ -152,6 +154,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.0.0b43.dist-info/METADATA,sha256=ZE3l6CN34ptFgx3SDPfKIgjdV2s3J8qdP729eb58vzo,988
-nshtrainer-1.0.0b43.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
-nshtrainer-1.0.0b43.dist-info/RECORD,,
+nshtrainer-1.0.0b45.dist-info/METADATA,sha256=_RPpe6F7DXpsQSmBF1GTc-E5VUfaC69fIYfoFhsip2s,988
+nshtrainer-1.0.0b45.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
+nshtrainer-1.0.0b45.dist-info/RECORD,,