nshtrainer-0.18.2-py3-none-any.whl → nshtrainer-0.19.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- nshtrainer/nn/nonlinearity.py (0.18.2)
+++ nshtrainer/nn/nonlinearity.py (0.19.0)
@@ -4,15 +4,19 @@ from typing import Annotated, Literal
 import nshconfig as C
 import torch
 import torch.nn as nn
-from typing_extensions import override
+import torch.nn.functional as F
+from typing_extensions import final, override
 
 
 class BaseNonlinearityConfig(C.Config, ABC):
     @abstractmethod
-    def create_module(self) -> nn.Module:
-        pass
+    def create_module(self) -> nn.Module: ...
+
+    @abstractmethod
+    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
 
 
+@final
 class ReLUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["relu"] = "relu"
 
@@ -20,7 +24,11 @@ class ReLUNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.ReLU()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.relu(x)
 
+
+@final
 class SigmoidNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["sigmoid"] = "sigmoid"
 
@@ -28,7 +36,11 @@ class SigmoidNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.Sigmoid()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.sigmoid(x)
+
 
+@final
 class TanhNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["tanh"] = "tanh"
 
@@ -36,23 +48,44 @@ class TanhNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.Tanh()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.tanh(x)
+
 
+@final
 class SoftmaxNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["softmax"] = "softmax"
 
+    dim: int = -1
+    """The dimension to apply the softmax function."""
+
     @override
     def create_module(self) -> nn.Module:
-        return nn.Softmax(dim=1)
+        return nn.Softmax(dim=self.dim)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.softmax(x, dim=self.dim)
 
 
+@final
 class SoftplusNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["softplus"] = "softplus"
 
+    beta: float = 1.0
+    """The beta parameter in the softplus function."""
+
+    threshold: float = 20.0
+    """Values above this revert to a linear function."""
+
     @override
     def create_module(self) -> nn.Module:
-        return nn.Softplus()
+        return nn.Softplus(beta=self.beta, threshold=self.threshold)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.softplus(x, beta=self.beta, threshold=self.threshold)
 
 
+@final
 class SoftsignNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["softsign"] = "softsign"
 
@@ -60,44 +93,78 @@ class SoftsignNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.Softsign()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.softsign(x)
+
 
+@final
 class ELUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["elu"] = "elu"
 
+    alpha: float = 1.0
+    """The alpha parameter in the ELU function."""
+
     @override
     def create_module(self) -> nn.Module:
         return nn.ELU()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.elu(x, alpha=self.alpha)
+
 
+@final
 class LeakyReLUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["leaky_relu"] = "leaky_relu"
 
-    negative_slope: float | None = None
+    negative_slope: float = 1.0e-2
+    """The negative slope of the leaky ReLU function."""
 
     @override
     def create_module(self) -> nn.Module:
-        kwargs = {}
-        if self.negative_slope is not None:
-            kwargs["negative_slope"] = self.negative_slope
-        return nn.LeakyReLU(**kwargs)
+        return nn.LeakyReLU(negative_slope=self.negative_slope)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.leaky_relu(x, negative_slope=self.negative_slope)
 
 
+@final
 class PReLUConfig(BaseNonlinearityConfig):
     name: Literal["prelu"] = "prelu"
 
+    num_parameters: int = 1
+    """The number of :math:`a` to learn.
+    Although it takes an int as input, there is only two values are legitimate:
+    1, or the number of channels at input."""
+
+    init: float = 0.25
+    """The initial value of :math:`a`."""
+
     @override
     def create_module(self) -> nn.Module:
-        return nn.PReLU()
+        return nn.PReLU(num_parameters=self.num_parameters, init=self.init)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError(
+            "PReLU requires learnable parameters and cannot be called directly."
+        )
 
 
+@final
 class GELUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["gelu"] = "gelu"
 
+    approximate: Literal["tanh", "none"] = "none"
+    """The gelu approximation algorithm to use."""
+
     @override
     def create_module(self) -> nn.Module:
-        return nn.GELU()
+        return nn.GELU(approximate=self.approximate)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.gelu(x, approximate=self.approximate)
 
 
+@final
 class SwishNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["swish"] = "swish"
 
@@ -105,7 +172,11 @@ class SwishNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.SiLU()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.silu(x)
+
 
+@final
 class SiLUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["silu"] = "silu"
 
@@ -113,7 +184,11 @@ class SiLUNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.SiLU()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.silu(x)
 
+
+@final
 class MishNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["mish"] = "mish"
 
@@ -121,6 +196,9 @@ class MishNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return nn.Mish()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.mish(x)
+
 
 class SwiGLU(nn.SiLU):
     @override
@@ -129,6 +207,7 @@ class SwiGLU(nn.SiLU):
         return input * super().forward(gate)
 
 
+@final
 class SwiGLUNonlinearityConfig(BaseNonlinearityConfig):
     name: Literal["swiglu"] = "swiglu"
 
@@ -136,6 +215,10 @@ class SwiGLUNonlinearityConfig(BaseNonlinearityConfig):
     def create_module(self) -> nn.Module:
         return SwiGLU()
 
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        input, gate = x.chunk(2, dim=-1)
+        return input * F.silu(gate)
+
 
 NonlinearityConfig = Annotated[
     ReLUNonlinearityConfig
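
For orientation, here is a minimal usage sketch; it is not taken from the package itself and only assumes that the classes above are importable from nshtrainer.nn.nonlinearity (the module path listed in the wheel's RECORD). It shows the two ways a nonlinearity config can be applied in 0.19.0: materializing an nn.Module via create_module(), or calling the config directly through the new __call__ method.

    import torch
    from nshtrainer.nn.nonlinearity import GELUNonlinearityConfig

    # New in 0.19.0: configs expose hyperparameters such as `approximate`.
    config = GELUNonlinearityConfig(approximate="tanh")
    x = torch.randn(4, 8)

    # Option 1: build an nn.Module, as in 0.18.x.
    y_module = config.create_module()(x)

    # Option 2: apply the config directly (functional path added in 0.19.0).
    y_direct = config(x)

    # Both paths use the same hyperparameters, so the results agree.
    assert torch.allclose(y_module, y_direct)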
--- nshtrainer-0.18.2.dist-info/METADATA
+++ nshtrainer-0.19.0.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.18.2
+Version: 0.19.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
--- nshtrainer-0.18.2.dist-info/RECORD
+++ nshtrainer-0.19.0.dist-info/RECORD
@@ -69,7 +69,7 @@ nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,142
 nshtrainer/nn/mlp.py,sha256=V0FrScpIUdg_IgIO8GMtIsGEtmHjwF14i2IWxmZrsqg,5952
 nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,2329
 nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
-nshtrainer/nn/nonlinearity.py,sha256=owtU4kh4G98psD0axOJWVfBhm-OtJVgFM-TXSHmbNPU,3625
+nshtrainer/nn/nonlinearity.py,sha256=4sYE4MN5zojc-go1k0PYtqssVRuXrM7D4tbpIXp5K-E,6078
 nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
 nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
 nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
@@ -85,6 +85,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.18.2.dist-info/METADATA,sha256=vev96DaxCnqJOAvvGrGOJ37OpWNFLrCdtGPN-kpnvO4,935
-nshtrainer-0.18.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.18.2.dist-info/RECORD,,
+nshtrainer-0.19.0.dist-info/METADATA,sha256=VLb38BSORQBx6g_SfGnbdBWa37N9xCtZ-JI45ATouzY,935
+nshtrainer-0.19.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.19.0.dist-info/RECORD,,