hyper-connections: hyper_connections-0.0.20-py3-none-any.whl → hyper_connections-0.0.21-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections.py +12 -1
- hyper_connections/hyper_connections_with_multi_branch_inputs.py +2 -2
- {hyper_connections-0.0.20.dist-info → hyper_connections-0.0.21.dist-info}/METADATA +8 -1
- hyper_connections-0.0.21.dist-info/RECORD +7 -0
- hyper_connections-0.0.20.dist-info/RECORD +0 -7
- {hyper_connections-0.0.20.dist-info → hyper_connections-0.0.21.dist-info}/WHEEL +0 -0
- {hyper_connections-0.0.20.dist-info → hyper_connections-0.0.21.dist-info}/licenses/LICENSE +0 -0
|
@@ -31,6 +31,17 @@ def default(v, d):
|
|
|
31
31
|
def identity(t):
|
|
32
32
|
return t
|
|
33
33
|
|
|
34
|
+
# norms
|
|
35
|
+
|
|
36
|
+
class RMSNorm(Module):
|
|
37
|
+
def __init__(self, dim):
|
|
38
|
+
super().__init__()
|
|
39
|
+
self.scale = dim ** 0.5
|
|
40
|
+
self.gamma = nn.Parameter(torch.zeros(dim))
|
|
41
|
+
|
|
42
|
+
def forward(self, x):
|
|
43
|
+
return F.normalize(x, dim = -1) * self.scale * (self.gamma + 1)
|
|
44
|
+
|
|
34
45
|
# main classes
|
|
35
46
|
|
|
36
47
|
# residual base class
|
|
@@ -108,7 +119,7 @@ class HyperConnections(Module):
|
|
|
108
119
|
|
|
109
120
|
self.act = nn.Tanh() if tanh else nn.Identity()
|
|
110
121
|
|
|
111
|
-
self.norm =
|
|
122
|
+
self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
|
|
112
123
|
|
|
113
124
|
assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
|
|
114
125
|
|
|
@@ -22,7 +22,7 @@ br - branch functions
|
|
|
22
22
|
t - residual streams + num branch inputs
|
|
23
23
|
"""
|
|
24
24
|
|
|
25
|
-
from hyper_connections.hyper_connections import Residual, StreamEmbed
|
|
25
|
+
from hyper_connections.hyper_connections import Residual, StreamEmbed, RMSNorm
|
|
26
26
|
|
|
27
27
|
# helper functions
|
|
28
28
|
|
|
@@ -74,7 +74,7 @@ class HyperConnections(Module):
|
|
|
74
74
|
|
|
75
75
|
self.act = nn.Tanh() if tanh else nn.Identity()
|
|
76
76
|
|
|
77
|
-
self.norm =
|
|
77
|
+
self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
|
|
78
78
|
|
|
79
79
|
self.num_residual_streams = num_residual_streams
|
|
80
80
|
self.num_branch_inputs = num_branch_inputs
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.0.20
|
|
3
|
+
Version: 0.0.21
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -155,3 +155,10 @@ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
|
155
155
|
url = {https://api.semanticscholar.org/CorpusID:272987528}
|
|
156
156
|
}
|
|
157
157
|
```
|
|
158
|
+
|
|
159
|
+
```bibtex
|
|
160
|
+
@misc{Rubin2024,
|
|
161
|
+
author = {Ohad Rubin},
|
|
162
|
+
url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
|
|
163
|
+
}
|
|
164
|
+
```
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
hyper_connections/__init__.py,sha256=wJxbrEXRGmOIjPw8fWP-cUq6CE8bvx95mIlhWifNvYc,135
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=l64d-qB8m188RuYfzgWOoIzmXnGMjIIQAvZZBQyUOGs,9755
|
|
3
|
+
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=-9K-IS-zbjMVaFWr_29okNdup3YATJTxeypUcD0Syoc,7559
|
|
4
|
+
hyper_connections-0.0.21.dist-info/METADATA,sha256=ZBcfPopgUUOK6JeUIEGMSGgaeWGG3bvYoFOLTeskGoA,5288
|
|
5
|
+
hyper_connections-0.0.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
+
hyper_connections-0.0.21.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
7
|
+
hyper_connections-0.0.21.dist-info/RECORD,,
|
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
hyper_connections/__init__.py,sha256=wJxbrEXRGmOIjPw8fWP-cUq6CE8bvx95mIlhWifNvYc,135
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=ElPtieRLvVKaVg2Attx1k6esKq1SY2X4AVZbZmsQAOM,9486
|
|
3
|
-
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=HbLpt79xcMv_os6brMvDd90t2GOPceliE1YFusR2eJI,7553
|
|
4
|
-
hyper_connections-0.0.20.dist-info/METADATA,sha256=erA-d7KNNdzPY76x8IWKd2trv2WuBO9-C2DtH-SoQ_Y,5076
|
|
5
|
-
hyper_connections-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
-
hyper_connections-0.0.20.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
7
|
-
hyper_connections-0.0.20.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|