nshtrainer-1.0.0b44-py3-none-any.whl → nshtrainer-1.0.0b45-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/callbacks/metric_validation.py CHANGED
@@ -5,8 +5,8 @@ from typing import Literal
 
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from typing_extensions import final, override, assert_never
-
-from .._callback import NTCallbackBase
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import Callback
 from ..metrics import MetricConfig
 from .base import CallbackConfigBase, callback_registry
 
@@ -43,33 +43,48 @@ class MetricValidationCallbackConfig(CallbackConfigBase):
         yield MetricValidationCallback(self, metrics)
 
 
-class MetricValidationCallback(NTCallbackBase):
+class MetricValidationCallback(Callback):
     def __init__(
-        self, config: MetricValidationCallbackConfig, metrics: list[MetricConfig]
+        self,
+        config: MetricValidationCallbackConfig,
+        metrics: list[MetricConfig],
     ):
         super().__init__()
 
         self.config = config
         self.metrics = metrics
 
-    @override
-    def on_sanity_check_end(self, trainer, pl_module):
-        super().on_sanity_check_end(trainer, pl_module)
-
-        log.debug("Validating metrics...")
+    def _check_metrics(self, trainer: Trainer):
+        metric_names = ", ".join(metric.validation_monitor for metric in self.metrics)
+        log.info(f"Validating metrics: {metric_names}...")
         logged_metrics = set(trainer.logged_metrics.keys())
-        for metric in self.metrics:
-            if metric.validation_monitor in logged_metrics:
-                continue
 
+        invalid_metrics: list[str] = []
+        for metric in self.metrics:
+            if metric.validation_monitor not in logged_metrics:
+                invalid_metrics.append(metric.validation_monitor)
+
+        if invalid_metrics:
+            msg = (
+                f"The following metrics were not found in logged metrics: {invalid_metrics}\n"
+                f"List of logged metrics: {list(trainer.logged_metrics.keys())}"
+            )
             match self.config.error_behavior:
                 case "raise":
-                    raise MisconfigurationException(
-                        f"Metric '{metric.validation_monitor}' not found in logged metrics."
-                    )
+                    raise MisconfigurationException(msg)
                 case "warn":
-                    log.warning(
-                        f"Metric '{metric.validation_monitor}' not found in logged metrics."
-                    )
+                    log.warning(msg)
                 case _:
                     assert_never(self.config.error_behavior)
+
+    @override
+    def on_sanity_check_end(self, trainer, pl_module):
+        super().on_sanity_check_end(trainer, pl_module)
+
+        self._check_metrics(trainer)
+
+    @override
+    def on_validation_end(self, trainer, pl_module):
+        super().on_validation_end(trainer, pl_module)
+
+        self._check_metrics(trainer)
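With this change, MetricValidationCallback subclasses the stock Lightning Callback rather than the internal NTCallbackBase, collects every missing metric before reporting instead of stopping at the first, and runs the check both after the sanity check and after each validation epoch. A minimal usage sketch follows; it assumes MetricValidationCallbackConfig accepts error_behavior as a keyword argument, while only the validation_monitor attribute of MetricConfig is confirmed by this diff.

    # Sketch (not from the package docs): wiring the callback up manually,
    # outside the usual CallbackConfigBase.create_callbacks() flow above.
    from lightning.pytorch import Trainer

    config = MetricValidationCallbackConfig(error_behavior="warn")  # or "raise"
    metrics: list[MetricConfig] = []  # populate with your MetricConfig instances

    trainer = Trainer(callbacks=[MetricValidationCallback(config, metrics)])
    # After the sanity check and after every validation epoch, the callback
    # compares each metric.validation_monitor against trainer.logged_metrics
    # and warns or raises (per error_behavior) with the full list of
    # missing metrics.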
nshtrainer/configs/trainer/trainer/__init__.py CHANGED
@@ -4,14 +4,12 @@ __codegen__ = True
 
 from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
 from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
-from nshtrainer.trainer.trainer import PluginConfigBase as PluginConfigBase
 from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
 from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig
 
 __all__ = [
     "AcceleratorConfigBase",
     "EnvironmentConfig",
-    "PluginConfigBase",
     "StrategyConfigBase",
     "TrainerConfig",
 ]
nshtrainer/nn/__init__.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 from .mlp import MLP as MLP
 from .mlp import MLPConfig as MLPConfig
-from .mlp import MLPConfigDict as MLPConfigDict
 from .mlp import ResidualSequential as ResidualSequential
 from .mlp import custom_seed_context as custom_seed_context
 from .module_dict import TypedModuleDict as TypedModuleDict
nshtrainer/nn/mlp.py CHANGED
@@ -3,12 +3,12 @@ from __future__ import annotations
 import contextlib
 import copy
 from collections.abc import Callable, Sequence
-from typing import Literal, Protocol, runtime_checkable
+from typing import Any, Literal, Protocol, runtime_checkable
 
 import nshconfig as C
 import torch
 import torch.nn as nn
-from typing_extensions import TypedDict, override
+from typing_extensions import deprecated, override
 
 from .nonlinearity import NonlinearityConfig, NonlinearityConfigBase
 
@@ -26,29 +26,6 @@ class ResidualSequential(nn.Sequential):
         return input + super().forward(input)
 
 
-class MLPConfigDict(TypedDict):
-    bias: bool
-    """Whether to include bias terms in the linear layers."""
-
-    no_bias_scalar: bool
-    """Whether to exclude bias terms when the output dimension is 1."""
-
-    nonlinearity: NonlinearityConfig | None
-    """Activation function to use between layers."""
-
-    ln: bool | Literal["pre", "post"]
-    """Whether to apply layer normalization before or after the linear layers."""
-
-    dropout: float | None
-    """Dropout probability to apply between layers."""
-
-    residual: bool
-    """Whether to use residual connections between layers."""
-
-    seed: int | None
-    """Random seed to use for initialization. If None, the default Torch behavior is used."""
-
-
 class MLPConfig(C.Config):
     bias: bool = True
     """Whether to include bias terms in the linear layers."""
@@ -71,8 +48,15 @@ class MLPConfig(C.Config):
     seed: int | None = None
     """Random seed to use for initialization. If None, the default Torch behavior is used."""
 
-    def to_kwargs(self) -> MLPConfigDict:
-        kwargs: MLPConfigDict = {
+    @deprecated("Use `nt.nn.MLP(config=...)` instead.")
+    def create_module(
+        self,
+        dims: Sequence[int],
+        pre_layers: Sequence[nn.Module] = [],
+        post_layers: Sequence[nn.Module] = [],
+        linear_cls: LinearModuleConstructor = nn.Linear,
+    ):
+        kwargs: dict[str, Any] = {
             "bias": self.bias,
             "no_bias_scalar": self.no_bias_scalar,
             "nonlinearity": self.nonlinearity,
@@ -81,18 +65,9 @@ class MLPConfig(C.Config):
             "residual": self.residual,
             "seed": self.seed,
         }
-        return kwargs
-
-    def create_module(
-        self,
-        dims: Sequence[int],
-        pre_layers: Sequence[nn.Module] = [],
-        post_layers: Sequence[nn.Module] = [],
-        linear_cls: LinearModuleConstructor = nn.Linear,
-    ):
         return MLP(
             dims,
-            **self.to_kwargs(),
+            **kwargs,
             pre_layers=pre_layers,
             post_layers=post_layers,
             linear_cls=linear_cls,
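MLPConfig.to_kwargs() and the MLPConfigDict TypedDict it returned are removed outright, and create_module survives only as a deprecated shim over MLP. A migration sketch, assuming the nt import alias used in the deprecation message; nonlinearity is passed explicitly because its default is not shown in this diff:

    import nshtrainer as nt  # alias assumed from the deprecation message above

    config = nt.nn.MLPConfig(nonlinearity=None, dropout=0.1)

    # Before (still works, but now emits a DeprecationWarning):
    mlp = config.create_module([64, 128, 1])

    # After:
    mlp = nt.nn.MLP([64, 128, 1], config=config)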
@@ -121,50 +96,73 @@ def MLP(
     | nn.Module
     | Callable[[], nn.Module]
     | None = None,
-    bias: bool = True,
-    no_bias_scalar: bool = True,
-    ln: bool | Literal["pre", "post"] = False,
+    bias: bool | None = None,
+    no_bias_scalar: bool | None = None,
+    ln: bool | Literal["pre", "post"] | None = None,
     dropout: float | None = None,
-    residual: bool = False,
+    residual: bool | None = None,
     pre_layers: Sequence[nn.Module] = [],
     post_layers: Sequence[nn.Module] = [],
     linear_cls: LinearModuleConstructor = nn.Linear,
     seed: int | None = None,
+    config: MLPConfig | None = None,
 ):
     """
     Constructs a multi-layer perceptron (MLP) with the given dimensions and activation function.
 
     Args:
         dims (Sequence[int]): List of integers representing the dimensions of the MLP.
-        nonlinearity (Callable[[], nn.Module]): Activation function to use between layers.
-        activation (Callable[[], nn.Module]): Activation function to use between layers.
-        bias (bool, optional): Whether to include bias terms in the linear layers. Defaults to True.
-        no_bias_scalar (bool, optional): Whether to exclude bias terms when the output dimension is 1. Defaults to True.
-        ln (bool | Literal["pre", "post"], optional): Whether to apply layer normalization before or after the linear layers. Defaults to False.
-        dropout (float | None, optional): Dropout probability to apply between layers. Defaults to None.
-        residual (bool, optional): Whether to use residual connections between layers. Defaults to False.
+        nonlinearity (Callable[[], nn.Module] | None, optional): Activation function to use between layers.
+        activation (Callable[[], nn.Module] | None, optional): Activation function to use between layers.
+        bias (bool | None, optional): Whether to include bias terms in the linear layers.
+        no_bias_scalar (bool | None, optional): Whether to exclude bias terms when the output dimension is 1.
+        ln (bool | Literal["pre", "post"] | None, optional): Whether to apply layer normalization before or after the linear layers.
+        dropout (float | None, optional): Dropout probability to apply between layers.
+        residual (bool | None, optional): Whether to use residual connections between layers.
         pre_layers (Sequence[nn.Module], optional): List of layers to insert before the linear layers. Defaults to [].
         post_layers (Sequence[nn.Module], optional): List of layers to insert after the linear layers. Defaults to [].
         linear_cls (LinearModuleConstructor, optional): Linear module constructor to use. Defaults to nn.Linear.
-        seed (int | None, optional): Random seed to use for initialization. If None, the default Torch behavior is used. Defaults to None.
+        seed (int | None, optional): Random seed to use for initialization. If None, the default Torch behavior is used.
+        config (MLPConfig | None, optional): Configuration object for the MLP. Parameters specified directly take precedence.
 
     Returns:
         nn.Sequential: The constructed MLP.
     """
 
-    with custom_seed_context(seed):
+    # Resolve parameters: arg if not None, otherwise config value if config exists, otherwise default
+    resolved_bias = bias if bias is not None else (config.bias if config else True)
+    resolved_no_bias_scalar = (
+        no_bias_scalar
+        if no_bias_scalar is not None
+        else (config.no_bias_scalar if config else True)
+    )
+    resolved_nonlinearity = (
+        nonlinearity
+        if nonlinearity is not None
+        else (config.nonlinearity if config else None)
+    )
+    resolved_ln = ln if ln is not None else (config.ln if config else False)
+    resolved_dropout = (
+        dropout if dropout is not None else (config.dropout if config else None)
+    )
+    resolved_residual = (
+        residual if residual is not None else (config.residual if config else False)
+    )
+    resolved_seed = seed if seed is not None else (config.seed if config else None)
+
+    with custom_seed_context(resolved_seed):
         if activation is None:
-            activation = nonlinearity
+            activation = resolved_nonlinearity
 
         if len(dims) < 2:
             raise ValueError("mlp requires at least 2 dimensions")
-        if ln is True:
-            ln = "pre"
-        elif isinstance(ln, str) and ln not in ("pre", "post"):
+        if resolved_ln is True:
+            resolved_ln = "pre"
+        elif isinstance(resolved_ln, str) and resolved_ln not in ("pre", "post"):
             raise ValueError("ln must be a boolean or 'pre' or 'post'")
 
         layers: list[nn.Module] = []
-        if ln == "pre":
+        if resolved_ln == "pre":
             layers.append(nn.LayerNorm(dims[0]))
 
         layers.extend(pre_layers)
@@ -172,10 +170,12 @@ def MLP(
         for i in range(len(dims) - 1):
             in_features = dims[i]
             out_features = dims[i + 1]
-            bias_ = bias and not (no_bias_scalar and out_features == 1)
+            bias_ = resolved_bias and not (
+                resolved_no_bias_scalar and out_features == 1
+            )
             layers.append(linear_cls(in_features, out_features, bias=bias_))
-            if dropout is not None:
-                layers.append(nn.Dropout(dropout))
+            if resolved_dropout is not None:
+                layers.append(nn.Dropout(resolved_dropout))
             if i < len(dims) - 2:
                 match activation:
                     case NonlinearityConfigBase():
@@ -192,8 +192,8 @@
 
         layers.extend(post_layers)
 
-        if ln == "post":
+        if resolved_ln == "post":
             layers.append(nn.LayerNorm(dims[-1]))
 
-        cls = ResidualSequential if residual else nn.Sequential
+        cls = ResidualSequential if resolved_residual else nn.Sequential
         return cls(*layers)
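The resolution rule in the rewritten MLP: an explicit keyword argument wins if it is not None, otherwise the config's value is used when a config is given, otherwise the old default applies. Because None is the "unspecified" sentinel, an explicit False still overrides the config. A sketch of the precedence, under the same assumed nt alias:

    import nshtrainer as nt  # alias assumed

    config = nt.nn.MLPConfig(nonlinearity=None, dropout=0.5, residual=True)

    # dropout=0.1 overrides config.dropout=0.5; residual and bias are not
    # passed, so residual comes from the config (True) and bias falls back
    # to the built-in default (True).
    mlp = nt.nn.MLP([32, 64, 1], dropout=0.1, config=config)

    # residual=False is not None, so it beats config.residual=True.
    plain = nt.nn.MLP([32, 64, 1], residual=False, config=config)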
nshtrainer/trainer/trainer.py CHANGED
@@ -25,7 +25,6 @@ from ..util.bf16 import is_bf16_supported_no_emulation
 from ._config import LightningTrainerKwargs, TrainerConfig
 from ._runtime_callback import RuntimeTrackerCallback, Stage
 from .accelerator import AcceleratorConfigBase
-from .plugin import PluginConfigBase
 from .signal_connector import _SignalConnector
 from .strategy import StrategyConfigBase
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.0.0b44
+Version: 1.0.0b45
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -23,7 +23,7 @@ nshtrainer/callbacks/gradient_skipping.py,sha256=8g7oC7PF0LTAEzwiNoaS5tWOnkjk_EB
 nshtrainer/callbacks/interval.py,sha256=UCzUzt3XCFVyQyCWL9lOrStkkxesvduNOYk8yMrGTTk,8116
 nshtrainer/callbacks/log_epoch.py,sha256=B5Dm8XVZwCzKUhUWfT_5PDdDac993191OsbcxxuSVJE,1457
 nshtrainer/callbacks/lr_monitor.py,sha256=qy_C0R40J0hBAukzBwng5FI2jJUpWuXOi5N6FU6ym3I,1210
-nshtrainer/callbacks/metric_validation.py,sha256=4bMMHVQ7rBbveDiowZS7Wwr77rE8HrerIbo3n9OddPA,2406
+nshtrainer/callbacks/metric_validation.py,sha256=tqUVS2n9QRT3v1_8jAGlYBFhLpA6Bm9pxOsfWhD3yZQ,2915
 nshtrainer/callbacks/norm_logging.py,sha256=nVIDWe-ASl5zN830-ODR8QMCqI1ma-QPCIwoy0Wb-Nk,6390
 nshtrainer/callbacks/print_table.py,sha256=VaS4JgI963do79laXK4lUkFQx8v6aRSy22W0zyal_LA,3035
 nshtrainer/callbacks/rlp_sanity_checks.py,sha256=74BZvV2HLO__ucQXsLXb8eJLUZgRFUNJZ6TL9efMp74,10051
@@ -92,7 +92,7 @@ nshtrainer/configs/trainer/plugin/io/__init__.py,sha256=W6G67JnigB6d3MiwLrbSKgtI
 nshtrainer/configs/trainer/plugin/layer_sync/__init__.py,sha256=SYDZk2M6sgpt4sEuoURuS8EKYmaqGcvYxETE9jvTrEE,431
 nshtrainer/configs/trainer/plugin/precision/__init__.py,sha256=szlqSfK2XuWdkf72LQzQFv3SlWfKFdRUpBEYIxQ3TPs,1507
 nshtrainer/configs/trainer/strategy/__init__.py,sha256=50whNloJVBq_bdbLaPQnPBTeS1Rcs8MwxTCYBj1kKa4,273
-nshtrainer/configs/trainer/trainer/__init__.py,sha256=QnuhMQNAa1nSVN2o50_WeKAQG_qkNlkeoq9zTjjwmTI,586
+nshtrainer/configs/trainer/trainer/__init__.py,sha256=DDuBRx0kVNMW0z_sqKTUt8-Ql7bOpargi4KcHHvDu_c,486
 nshtrainer/configs/util/__init__.py,sha256=qXittS7f7MyaqJnjvFLKnKsyb6bXTD3dEV16jXVDaH4,2104
 nshtrainer/configs/util/_environment_info/__init__.py,sha256=eB4E0Ck7XCeSC5gbUdA5thd7TXnjGCL0t8GZIFj7uCI,1644
 nshtrainer/configs/util/config/__init__.py,sha256=nEFiDG3-dvvTytYn1tEkPFzp7fgaGRp2j7toSN7yRGs,501
@@ -119,8 +119,8 @@ nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,1035
 nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
 nshtrainer/model/mixins/debug.py,sha256=1LX9KzeFX9JDPs_a6YCdYDZXLhEk_5rBO2aCqlfBy7w,2087
 nshtrainer/model/mixins/logger.py,sha256=27H99FuLaxc6_dDLG2pid4E_5E0-eLGnc2Ifpt0HYIM,6066
-nshtrainer/nn/__init__.py,sha256=0FgeoaLYtRiSLT8fdPigLD8t-d8DKR8IQDw16JA9lT4,1523
-nshtrainer/nn/mlp.py,sha256=_a8rJJniSCvM08gyQGO-5MUoO18U9_FSGGn3tZL2_U4,7101
+nshtrainer/nn/__init__.py,sha256=5Gg3nieGSC5_dXaI9KUVUUbM13hHexH9831m4hcf6no,1475
+nshtrainer/nn/mlp.py,sha256=nYUgAISzuhC8sav6PloAdyz0PdEoikwppiXIuToEVdE,7550
 nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
 nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
 nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
@@ -142,7 +142,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv
 nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
 nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
 nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
-nshtrainer/trainer/trainer.py,sha256=ed_Pn-yQCb9BqaHXo2wVhkt2CSfGNEzMAM6RsDoTo-I,20834
+nshtrainer/trainer/trainer.py,sha256=8wMe0qArbDfStS4UdmuKSC2aiAImR3mhj14_kCJiNSM,20797
 nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
 nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
@@ -154,6 +154,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.0.0b44.dist-info/METADATA,sha256=u_dApZgfGst9vUiKBgnFQhGB0pBeULPOeGlaQ5-CPnI,988
-nshtrainer-1.0.0b44.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
-nshtrainer-1.0.0b44.dist-info/RECORD,,
+nshtrainer-1.0.0b45.dist-info/METADATA,sha256=_RPpe6F7DXpsQSmBF1GTc-E5VUfaC69fIYfoFhsip2s,988
+nshtrainer-1.0.0b45.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
+nshtrainer-1.0.0b45.dist-info/RECORD,,