nshtrainer 0.6.0__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/nn/__init__.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
from .mlp import MLP as MLP
|
|
2
|
+
from .mlp import MLPConfig as MLPConfig
|
|
3
|
+
from .mlp import MLPConfigDict as MLPConfigDict
|
|
2
4
|
from .mlp import ResidualSequential as ResidualSequential
|
|
3
5
|
from .module_dict import TypedModuleDict as TypedModuleDict
|
|
4
6
|
from .module_list import TypedModuleList as TypedModuleList
|
nshtrainer/nn/mlp.py
CHANGED
|
@@ -2,9 +2,10 @@ import copy
|
|
|
2
2
|
from collections.abc import Callable, Sequence
|
|
3
3
|
from typing import Literal, Protocol, runtime_checkable
|
|
4
4
|
|
|
5
|
+
import nshconfig as C
|
|
5
6
|
import torch
|
|
6
7
|
import torch.nn as nn
|
|
7
|
-
from typing_extensions import override
|
|
8
|
+
from typing_extensions import TypedDict, override
|
|
8
9
|
|
|
9
10
|
from .nonlinearity import BaseNonlinearityConfig
|
|
10
11
|
|
|
@@ -22,6 +23,71 @@ class ResidualSequential(nn.Sequential):
|
|
|
22
23
|
return input + super().forward(input)
|
|
23
24
|
|
|
24
25
|
|
|
26
|
+
class MLPConfigDict(TypedDict):
|
|
27
|
+
bias: bool
|
|
28
|
+
"""Whether to include bias terms in the linear layers."""
|
|
29
|
+
|
|
30
|
+
no_bias_scalar: bool
|
|
31
|
+
"""Whether to exclude bias terms when the output dimension is 1."""
|
|
32
|
+
|
|
33
|
+
nonlinearity: BaseNonlinearityConfig | None
|
|
34
|
+
"""Activation function to use between layers."""
|
|
35
|
+
|
|
36
|
+
ln: bool | Literal["pre", "post"]
|
|
37
|
+
"""Whether to apply layer normalization before or after the linear layers."""
|
|
38
|
+
|
|
39
|
+
dropout: float | None
|
|
40
|
+
"""Dropout probability to apply between layers."""
|
|
41
|
+
|
|
42
|
+
residual: bool
|
|
43
|
+
"""Whether to use residual connections between layers."""
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class MLPConfig(C.Config):
|
|
47
|
+
bias: bool = True
|
|
48
|
+
"""Whether to include bias terms in the linear layers."""
|
|
49
|
+
|
|
50
|
+
no_bias_scalar: bool = True
|
|
51
|
+
"""Whether to exclude bias terms when the output dimension is 1."""
|
|
52
|
+
|
|
53
|
+
nonlinearity: BaseNonlinearityConfig | None = None
|
|
54
|
+
"""Activation function to use between layers."""
|
|
55
|
+
|
|
56
|
+
ln: bool | Literal["pre", "post"] = False
|
|
57
|
+
"""Whether to apply layer normalization before or after the linear layers."""
|
|
58
|
+
|
|
59
|
+
dropout: float | None = None
|
|
60
|
+
"""Dropout probability to apply between layers."""
|
|
61
|
+
|
|
62
|
+
residual: bool = False
|
|
63
|
+
"""Whether to use residual connections between layers."""
|
|
64
|
+
|
|
65
|
+
def to_kwargs(self) -> MLPConfigDict:
|
|
66
|
+
return {
|
|
67
|
+
"bias": self.bias,
|
|
68
|
+
"no_bias_scalar": self.no_bias_scalar,
|
|
69
|
+
"nonlinearity": self.nonlinearity,
|
|
70
|
+
"ln": self.ln,
|
|
71
|
+
"dropout": self.dropout,
|
|
72
|
+
"residual": self.residual,
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
def create_module(
|
|
76
|
+
self,
|
|
77
|
+
dims: Sequence[int],
|
|
78
|
+
pre_layers: Sequence[nn.Module] = [],
|
|
79
|
+
post_layers: Sequence[nn.Module] = [],
|
|
80
|
+
linear_cls: LinearModuleConstructor = nn.Linear,
|
|
81
|
+
):
|
|
82
|
+
return MLP(
|
|
83
|
+
dims,
|
|
84
|
+
**self.to_kwargs(),
|
|
85
|
+
pre_layers=pre_layers,
|
|
86
|
+
post_layers=post_layers,
|
|
87
|
+
linear_cls=linear_cls,
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
|
|
25
91
|
def MLP(
|
|
26
92
|
dims: Sequence[int],
|
|
27
93
|
activation: BaseNonlinearityConfig
|
|
@@ -41,8 +41,8 @@ nshtrainer/model/modules/logger.py,sha256=XEeo3QrplTNKZqfl6iWZf3fze3R4YOeOvs-RKV
|
|
|
41
41
|
nshtrainer/model/modules/profiler.py,sha256=rQ_jRMcM1Z2AIROZlRnBRHM5rkTpq67afZPD6CIRfXs,825
|
|
42
42
|
nshtrainer/model/modules/rlp_sanity_checks.py,sha256=o6gUceFwsuDHmL8eLOYuT3JGXFzq_qc4awl2RWaBygU,8900
|
|
43
43
|
nshtrainer/model/modules/shared_parameters.py,sha256=mD5wrlBE3c025vzVdTpnSyC8yxzuI-aUWMmPhqPT0a0,2694
|
|
44
|
-
nshtrainer/nn/__init__.py,sha256=
|
|
45
|
-
nshtrainer/nn/mlp.py,sha256=
|
|
44
|
+
nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,1427
|
|
45
|
+
nshtrainer/nn/mlp.py,sha256=tX1VdtdzB0dyMkV0oEGCcob9hsYrXEIeusXfb-bJ5lQ,5940
|
|
46
46
|
nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,2329
|
|
47
47
|
nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
|
|
48
48
|
nshtrainer/nn/nonlinearity.py,sha256=owtU4kh4G98psD0axOJWVfBhm-OtJVgFM-TXSHmbNPU,3625
|
|
@@ -59,6 +59,6 @@ nshtrainer/util/seed.py,sha256=HEXgVs-wldByahOysKwq7506OHxdYTEgmP-tDQVAEkQ,287
|
|
|
59
59
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
60
60
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
61
61
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
62
|
-
nshtrainer-0.6.
|
|
63
|
-
nshtrainer-0.6.
|
|
64
|
-
nshtrainer-0.6.
|
|
62
|
+
nshtrainer-0.6.1.dist-info/METADATA,sha256=A99lygdq2iZY6oebqY5iOFcU71GX0t_cQN3cwKUlCVg,788
|
|
63
|
+
nshtrainer-0.6.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
64
|
+
nshtrainer-0.6.1.dist-info/RECORD,,
|
|
File without changes
|