nshtrainer 1.0.0b39__py3-none-any.whl → 1.0.0b41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/configs/.gitattributes +1 -0
- nshtrainer/nn/__init__.py +1 -0
- nshtrainer/nn/mlp.py +70 -45
- nshtrainer/nn/tests/test_mlp.py +55 -0
- {nshtrainer-1.0.0b39.dist-info → nshtrainer-1.0.0b41.dist-info}/METADATA +2 -2
- {nshtrainer-1.0.0b39.dist-info → nshtrainer-1.0.0b41.dist-info}/RECORD +7 -5
- {nshtrainer-1.0.0b39.dist-info → nshtrainer-1.0.0b41.dist-info}/WHEEL +1 -1
nshtrainer/configs/.gitattributes
ADDED
@@ -0,0 +1 @@
+* linguist-generated=true
nshtrainer/nn/__init__.py
CHANGED
@@ -4,6 +4,7 @@ from .mlp import MLP as MLP
 from .mlp import MLPConfig as MLPConfig
 from .mlp import MLPConfigDict as MLPConfigDict
 from .mlp import ResidualSequential as ResidualSequential
+from .mlp import custom_seed_context as custom_seed_context
 from .module_dict import TypedModuleDict as TypedModuleDict
 from .module_list import TypedModuleList as TypedModuleList
 from .nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
nshtrainer/nn/mlp.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import contextlib
 import copy
 from collections.abc import Callable, Sequence
 from typing import Literal, Protocol, runtime_checkable
@@ -44,6 +45,9 @@ class MLPConfigDict(TypedDict):
     residual: bool
     """Whether to use residual connections between layers."""
 
+    seed: int | None
+    """Random seed to use for initialization. If None, the default Torch behavior is used."""
+
 
 class MLPConfig(C.Config):
     bias: bool = True
@@ -64,15 +68,20 @@ class MLPConfig(C.Config):
     residual: bool = False
     """Whether to use residual connections between layers."""
 
+    seed: int | None = None
+    """Random seed to use for initialization. If None, the default Torch behavior is used."""
+
     def to_kwargs(self) -> MLPConfigDict:
-        return {
+        kwargs: MLPConfigDict = {
             "bias": self.bias,
             "no_bias_scalar": self.no_bias_scalar,
             "nonlinearity": self.nonlinearity,
             "ln": self.ln,
             "dropout": self.dropout,
             "residual": self.residual,
+            "seed": self.seed,
         }
+        return kwargs
 
     def create_module(
         self,
@@ -90,6 +99,18 @@ class MLPConfig(C.Config):
         )
 
 
+@contextlib.contextmanager
+def custom_seed_context(seed: int | None):
+    with contextlib.ExitStack() as stack:
+        if seed is not None:
+            stack.enter_context(
+                torch.random.fork_rng(devices=range(torch.cuda.device_count()))
+            )
+            torch.manual_seed(seed)
+
+        yield
+
+
 def MLP(
     dims: Sequence[int],
     activation: NonlinearityConfigBase
@@ -108,6 +129,7 @@ def MLP(
     pre_layers: Sequence[nn.Module] = [],
     post_layers: Sequence[nn.Module] = [],
     linear_cls: LinearModuleConstructor = nn.Linear,
+    seed: int | None = None,
 ):
     """
     Constructs a multi-layer perceptron (MLP) with the given dimensions and activation function.
@@ -123,52 +145,55 @@
         residual (bool, optional): Whether to use residual connections between layers. Defaults to False.
         pre_layers (Sequence[nn.Module], optional): List of layers to insert before the linear layers. Defaults to [].
        post_layers (Sequence[nn.Module], optional): List of layers to insert after the linear layers. Defaults to [].
+        linear_cls (LinearModuleConstructor, optional): Linear module constructor to use. Defaults to nn.Linear.
+        seed (int | None, optional): Random seed to use for initialization. If None, the default Torch behavior is used. Defaults to None.
 
     Returns:
         nn.Sequential: The constructed MLP.
     """
 
-    [44 removed lines: the previous MLP construction body, elided in this diff view; the same logic reappears below wrapped in the new seed context]
+    with custom_seed_context(seed):
+        if activation is None:
+            activation = nonlinearity
+
+        if len(dims) < 2:
+            raise ValueError("mlp requires at least 2 dimensions")
+        if ln is True:
+            ln = "pre"
+        elif isinstance(ln, str) and ln not in ("pre", "post"):
+            raise ValueError("ln must be a boolean or 'pre' or 'post'")
+
+        layers: list[nn.Module] = []
+        if ln == "pre":
+            layers.append(nn.LayerNorm(dims[0]))
+
+        layers.extend(pre_layers)
+
+        for i in range(len(dims) - 1):
+            in_features = dims[i]
+            out_features = dims[i + 1]
+            bias_ = bias and not (no_bias_scalar and out_features == 1)
+            layers.append(linear_cls(in_features, out_features, bias=bias_))
+            if dropout is not None:
+                layers.append(nn.Dropout(dropout))
+            if i < len(dims) - 2:
+                match activation:
+                    case NonlinearityConfigBase():
+                        layers.append(activation.create_module())
+                    case nn.Module():
+                        # In this case, we create a deep copy of the module to avoid sharing parameters (if any).
+                        layers.append(copy.deepcopy(activation))
+                    case Callable():
+                        layers.append(activation())
+                    case _:
+                        raise ValueError(
+                            "Either `nonlinearity` or `activation` must be provided"
+                        )
+
+        layers.extend(post_layers)
+
+        if ln == "post":
+            layers.append(nn.LayerNorm(dims[-1]))
+
+        cls = ResidualSequential if residual else nn.Sequential
+        return cls(*layers)
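The net effect of the mlp.py change is a new optional seed argument on MLP (with a matching seed field on MLPConfig) and a reusable custom_seed_context helper that forks the torch RNG state before seeding. A minimal usage sketch based on the API added above; the dimensions and seed values are illustrative only:

import torch

from nshtrainer.nn import MLP, custom_seed_context

# Two MLPs built with the same seed get identical initial weights.
mlp_a = MLP([10, 20, 5], activation=torch.nn.ReLU(), seed=42)
mlp_b = MLP([10, 20, 5], activation=torch.nn.ReLU(), seed=42)
assert torch.allclose(mlp_a[0].weight, mlp_b[0].weight)

# The helper can also wrap other initialization code; passing None makes it a no-op.
with custom_seed_context(42):
    layer = torch.nn.Linear(8, 8)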
nshtrainer/nn/tests/test_mlp.py
ADDED
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from typing import cast
+
+import pytest
+import torch
+
+from nshtrainer.nn.mlp import MLP
+
+
+def test_mlp_seed_reproducibility():
+    """Test that the seed parameter in MLP ensures reproducible weights."""
+
+    # Test dimensions
+    dims = [10, 20, 5]
+
+    # Create two MLPs with the same seed
+    seed1 = 42
+    mlp1 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
+    mlp2 = MLP(dims, activation=torch.nn.ReLU(), seed=seed1)
+
+    # Create an MLP with a different seed
+    seed2 = 123
+    mlp3 = MLP(dims, activation=torch.nn.ReLU(), seed=seed2)
+
+    # Check first layer weights
+    layer1_weights1 = cast(torch.Tensor, mlp1[0].weight)
+    layer1_weights2 = cast(torch.Tensor, mlp2[0].weight)
+    layer1_weights3 = cast(torch.Tensor, mlp3[0].weight)
+
+    # Same seed should produce identical weights
+    assert torch.allclose(layer1_weights1, layer1_weights2)
+
+    # Different seeds should produce different weights
+    assert not torch.allclose(layer1_weights1, layer1_weights3)
+
+    # Check second layer weights
+    layer2_weights1 = cast(torch.Tensor, mlp1[2].weight)
+    layer2_weights2 = cast(torch.Tensor, mlp2[2].weight)
+    layer2_weights3 = cast(torch.Tensor, mlp3[2].weight)
+
+    # Same seed should produce identical weights for all layers
+    assert torch.allclose(layer2_weights1, layer2_weights2)
+
+    # Different seeds should produce different weights for all layers
+    assert not torch.allclose(layer2_weights1, layer2_weights3)
+
+    # Test that not providing a seed gives different results each time
+    mlp4 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
+    mlp5 = MLP(dims, activation=torch.nn.ReLU(), seed=None)
+
+    # Without seeds, weights should be different
+    assert not torch.allclose(
+        cast(torch.Tensor, mlp4[0].weight), cast(torch.Tensor, mlp5[0].weight)
+    )
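The test above covers weight reproducibility; one property it does not assert is that custom_seed_context leaves the caller's global RNG state untouched, which follows from its use of torch.random.fork_rng. A small sketch of that behavior, assuming a CPU-only check:

import torch

from nshtrainer.nn.mlp import custom_seed_context

state_before = torch.random.get_rng_state()
with custom_seed_context(0):
    _ = torch.randn(3)  # draws from the forked, seeded generator
state_after = torch.random.get_rng_state()

# fork_rng restores the previous CPU (and selected CUDA) RNG state on exit.
assert torch.equal(state_before, state_after)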
{nshtrainer-1.0.0b39.dist-info → nshtrainer-1.0.0b41.dist-info}/RECORD
RENAMED
@@ -30,6 +30,7 @@ nshtrainer/callbacks/shared_parameters.py,sha256=s94jJTAIbDGukYJu6l247QonVOCudGC
 nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU,4731
 nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
 nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
+nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
 nshtrainer/configs/__init__.py,sha256=MZfcSKhnjtVObBvVv9lu8L2cFTLINP5zcTQvWnz8jdk,14505
 nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
 nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
@@ -116,11 +117,12 @@ nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,1035
 nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
 nshtrainer/model/mixins/debug.py,sha256=1LX9KzeFX9JDPs_a6YCdYDZXLhEk_5rBO2aCqlfBy7w,2087
 nshtrainer/model/mixins/logger.py,sha256=27H99FuLaxc6_dDLG2pid4E_5E0-eLGnc2Ifpt0HYIM,6066
-nshtrainer/nn/__init__.py,sha256=
-nshtrainer/nn/mlp.py,sha256=
+nshtrainer/nn/__init__.py,sha256=0FgeoaLYtRiSLT8fdPigLD8t-d8DKR8IQDw16JA9lT4,1523
+nshtrainer/nn/mlp.py,sha256=_a8rJJniSCvM08gyQGO-5MUoO18U9_FSGGn3tZL2_U4,7101
 nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
 nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
 nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
+nshtrainer/nn/tests/test_mlp.py,sha256=xBPiHlBvOCn67EbpzzKL-2FU7ikGxHT3i6CMSp1wk7M,1840
 nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
 nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
 nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
@@ -151,6 +153,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.0.
-nshtrainer-1.0.
-nshtrainer-1.0.
+nshtrainer-1.0.0b41.dist-info/METADATA,sha256=DL9HgN6RP8X8v0sCdTr2IjRSwIBY96NZXe15m5V4y4c,988
+nshtrainer-1.0.0b41.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
+nshtrainer-1.0.0b41.dist-info/RECORD,,
|