hyper-connections 0.0.19__tar.gz → 0.0.21__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.0.19 → hyper_connections-0.0.21}/PKG-INFO +8 -1
- {hyper_connections-0.0.19 → hyper_connections-0.0.21}/README.md +7 -0
- {hyper_connections-0.0.19 → hyper_connections-0.0.21}/hyper_connections/hyper_connections.py +22 -1
- {hyper_connections-0.0.19 → hyper_connections-0.0.21}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +12 -2
- {hyper_connections-0.0.19 → hyper_connections-0.0.21}/pyproject.toml +1 -1
- {hyper_connections-0.0.19 → hyper_connections-0.0.21}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.19 → hyper_connections-0.0.21}/.gitignore +0 -0
- {hyper_connections-0.0.19 → hyper_connections-0.0.21}/LICENSE +0 -0
- {hyper_connections-0.0.19 → hyper_connections-0.0.21}/hyper-connections.png +0 -0
- {hyper_connections-0.0.19 → hyper_connections-0.0.21}/hyper_connections/__init__.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.0.19
|
|
3
|
+
Version: 0.0.21
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -155,3 +155,10 @@ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
|
155
155
|
url = {https://api.semanticscholar.org/CorpusID:272987528}
|
|
156
156
|
}
|
|
157
157
|
```
|
|
158
|
+
|
|
159
|
+
```bibtex
|
|
160
|
+
@misc{Rubin2024,
|
|
161
|
+
author = {Ohad Rubin},
|
|
162
|
+
url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
|
|
163
|
+
}
|
|
164
|
+
```
|
|
@@ -112,3 +112,10 @@ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
|
112
112
|
url = {https://api.semanticscholar.org/CorpusID:272987528}
|
|
113
113
|
}
|
|
114
114
|
```
|
|
115
|
+
|
|
116
|
+
```bibtex
|
|
117
|
+
@misc{Rubin2024,
|
|
118
|
+
author = {Ohad Rubin},
|
|
119
|
+
url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
|
|
120
|
+
}
|
|
121
|
+
```
|
{hyper_connections-0.0.19 → hyper_connections-0.0.21}/hyper_connections/hyper_connections.py
RENAMED
|
@@ -12,6 +12,14 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
12
12
|
|
|
13
13
|
from einops import rearrange, repeat, reduce, einsum
|
|
14
14
|
|
|
15
|
+
"""
|
|
16
|
+
ein notation:
|
|
17
|
+
b - batch
|
|
18
|
+
d - feature dimension
|
|
19
|
+
s - residual streams
|
|
20
|
+
t - residual streams + num branch inputs
|
|
21
|
+
"""
|
|
22
|
+
|
|
15
23
|
# helper functions
|
|
16
24
|
|
|
17
25
|
def exists(v):
|
|
@@ -23,6 +31,17 @@ def default(v, d):
|
|
|
23
31
|
def identity(t):
|
|
24
32
|
return t
|
|
25
33
|
|
|
34
|
+
# norms
|
|
35
|
+
|
|
36
|
+
class RMSNorm(Module):
|
|
37
|
+
def __init__(self, dim):
|
|
38
|
+
super().__init__()
|
|
39
|
+
self.scale = dim ** 0.5
|
|
40
|
+
self.gamma = nn.Parameter(torch.zeros(dim))
|
|
41
|
+
|
|
42
|
+
def forward(self, x):
|
|
43
|
+
return F.normalize(x, dim = -1) * self.scale * (self.gamma + 1)
|
|
44
|
+
|
|
26
45
|
# main classes
|
|
27
46
|
|
|
28
47
|
# residual base class
|
|
@@ -100,7 +119,9 @@ class HyperConnections(Module):
|
|
|
100
119
|
|
|
101
120
|
self.act = nn.Tanh() if tanh else nn.Identity()
|
|
102
121
|
|
|
103
|
-
self.norm =
|
|
122
|
+
self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
|
|
123
|
+
|
|
124
|
+
assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
|
|
104
125
|
|
|
105
126
|
self.num_residual_streams = num_residual_streams
|
|
106
127
|
init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
|
|
@@ -12,7 +12,17 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
12
12
|
|
|
13
13
|
from einops import rearrange, repeat, reduce, einsum
|
|
14
14
|
|
|
15
|
-
|
|
15
|
+
"""
|
|
16
|
+
ein notation:
|
|
17
|
+
b - batch
|
|
18
|
+
d - feature dimension
|
|
19
|
+
s - residual streams
|
|
20
|
+
i - branch inputs
|
|
21
|
+
br - branch functions
|
|
22
|
+
t - residual streams + num branch inputs
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from hyper_connections.hyper_connections import Residual, StreamEmbed, RMSNorm
|
|
16
26
|
|
|
17
27
|
# helper functions
|
|
18
28
|
|
|
@@ -64,7 +74,7 @@ class HyperConnections(Module):
|
|
|
64
74
|
|
|
65
75
|
self.act = nn.Tanh() if tanh else nn.Identity()
|
|
66
76
|
|
|
67
|
-
self.norm =
|
|
77
|
+
self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
|
|
68
78
|
|
|
69
79
|
self.num_residual_streams = num_residual_streams
|
|
70
80
|
self.num_branch_inputs = num_branch_inputs
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|