hyper-connections 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections.py +30 -4
- {hyper_connections-0.0.1.dist-info → hyper_connections-0.0.3.dist-info}/METADATA +45 -2
- hyper_connections-0.0.3.dist-info/RECORD +6 -0
- hyper_connections-0.0.1.dist-info/RECORD +0 -6
- {hyper_connections-0.0.1.dist-info → hyper_connections-0.0.3.dist-info}/WHEEL +0 -0
- {hyper_connections-0.0.1.dist-info → hyper_connections-0.0.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -6,6 +6,7 @@ import torch
|
|
|
6
6
|
from torch import nn
|
|
7
7
|
from torch.nn import Module
|
|
8
8
|
import torch.nn.functional as F
|
|
9
|
+
from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
9
10
|
|
|
10
11
|
from einops import rearrange, repeat, reduce, einsum
|
|
11
12
|
|
|
@@ -30,6 +31,7 @@ class HyperConnections(Module):
|
|
|
30
31
|
branch: Module | None = None,
|
|
31
32
|
layer_index = None,
|
|
32
33
|
tanh = True,
|
|
34
|
+
channel_first = False
|
|
33
35
|
):
|
|
34
36
|
"""
|
|
35
37
|
Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
|
|
@@ -56,6 +58,10 @@ class HyperConnections(Module):
|
|
|
56
58
|
self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
|
|
57
59
|
self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
|
|
58
60
|
|
|
61
|
+
# channel first option
|
|
62
|
+
|
|
63
|
+
self.channel_first = channel_first
|
|
64
|
+
|
|
59
65
|
@classmethod
|
|
60
66
|
def get_expand_reduce_stream_functions(cls, num_streams):
|
|
61
67
|
expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
|
|
@@ -66,31 +72,47 @@ class HyperConnections(Module):
|
|
|
66
72
|
def width_connection(self, residuals):
|
|
67
73
|
# width connection
|
|
68
74
|
|
|
75
|
+
if self.channel_first:
|
|
76
|
+
residuals = rearrange(residuals, 'b d ... -> b ... d')
|
|
77
|
+
|
|
69
78
|
residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = self.num_residual_streams)
|
|
70
79
|
|
|
71
80
|
normed = self.norm(residuals)
|
|
72
81
|
|
|
82
|
+
# alpha for weighted sum of residuals going into branch
|
|
83
|
+
|
|
73
84
|
wc_weight = self.act(normed @ self.dynamic_alpha_fn)
|
|
74
85
|
dynamic_alpha = wc_weight * self.dynamic_alpha_scale
|
|
75
86
|
alpha = dynamic_alpha + self.static_alpha
|
|
76
87
|
|
|
88
|
+
# beta for weights from branch output back to residual streams
|
|
89
|
+
|
|
77
90
|
dc_weight = self.act(normed @ self.dynamic_beta_fn)
|
|
78
91
|
dynamic_beta = dc_weight * self.dynamic_beta_scale
|
|
79
92
|
beta = dynamic_beta + self.static_beta
|
|
80
93
|
|
|
81
|
-
# width connection
|
|
82
|
-
|
|
83
94
|
mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
|
|
84
95
|
|
|
85
96
|
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
|
|
86
97
|
|
|
98
|
+
if self.channel_first:
|
|
99
|
+
branch_input = rearrange(branch_input, 'b ... d -> b d ...')
|
|
100
|
+
|
|
87
101
|
return branch_input, residuals, beta
|
|
88
102
|
|
|
89
103
|
def depth_connection(self, branch_output, residuals, beta):
|
|
90
104
|
# 'depth' connection
|
|
91
105
|
|
|
106
|
+
if self.channel_first:
|
|
107
|
+
branch_output = rearrange(branch_output, 'b d ... -> b ... d')
|
|
108
|
+
|
|
92
109
|
residuals = einsum(branch_output, beta, 'b ... d, b ... s -> b ... s d') + residuals
|
|
93
|
-
|
|
110
|
+
output = rearrange(residuals, 'b ... s d -> (b s) ... d')
|
|
111
|
+
|
|
112
|
+
if self.channel_first:
|
|
113
|
+
output = rearrange(output, 'b ... d -> b d ...')
|
|
114
|
+
|
|
115
|
+
return output
|
|
94
116
|
|
|
95
117
|
def forward(self, residuals, **branch_kwargs):
|
|
96
118
|
|
|
@@ -104,4 +126,8 @@ class HyperConnections(Module):
|
|
|
104
126
|
|
|
105
127
|
branch_output = self.branch(branch_input, **branch_kwargs)
|
|
106
128
|
|
|
107
|
-
|
|
129
|
+
(branch_output, *rest), tree_spec = tree_flatten(branch_output)
|
|
130
|
+
|
|
131
|
+
branch_output = add_residual_fn(branch_output)
|
|
132
|
+
|
|
133
|
+
return tree_unflatten((branch_output, *rest), tree_spec)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.0.1
|
|
3
|
+
Version: 0.0.3
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -45,7 +45,7 @@ Description-Content-Type: text/markdown
|
|
|
45
45
|
|
|
46
46
|
## Hyper Connections
|
|
47
47
|
|
|
48
|
-
Attempt to make
|
|
48
|
+
Attempt to make multiple residual streams, proposed in [Hyper-Connections paper](https://arxiv.org/abs/2409.19606) out of Bytedance AI lab, accessible as an easy to use library, as well as for following any new research in this direction.
|
|
49
49
|
|
|
50
50
|
## Install
|
|
51
51
|
|
|
@@ -92,6 +92,49 @@ residual = hyper_conn_branch(residual)
|
|
|
92
92
|
residual = reduce_stream(residual)
|
|
93
93
|
```
|
|
94
94
|
|
|
95
|
+
Or doing it manually, as in the paper
|
|
96
|
+
|
|
97
|
+
```python
|
|
98
|
+
import torch
|
|
99
|
+
from torch import nn
|
|
100
|
+
|
|
101
|
+
# a single branch layer
|
|
102
|
+
|
|
103
|
+
branch = nn.Linear(512, 512)
|
|
104
|
+
|
|
105
|
+
# before
|
|
106
|
+
|
|
107
|
+
residual = torch.randn(2, 1024, 512)
|
|
108
|
+
|
|
109
|
+
residual = branch(residual) + residual
|
|
110
|
+
|
|
111
|
+
# after, say 4 streams in paper
|
|
112
|
+
|
|
113
|
+
from hyper_connections import HyperConnections
|
|
114
|
+
|
|
115
|
+
expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
|
|
116
|
+
|
|
117
|
+
# 1. instantiate hyper connection with correct number of streams (4 in this case)
|
|
118
|
+
|
|
119
|
+
hyper_conn = HyperConnections(4, dim = 512)
|
|
120
|
+
|
|
121
|
+
# 2. expand to 4 streams
|
|
122
|
+
|
|
123
|
+
residual = expand_stream(residual)
|
|
124
|
+
|
|
125
|
+
# 3. forward your residual into hyper connection for the branch input + add residual function (learned betas)
|
|
126
|
+
|
|
127
|
+
branch_input, add_residual = hyper_conn(residual)
|
|
128
|
+
|
|
129
|
+
branch_output = branch(branch_input)
|
|
130
|
+
|
|
131
|
+
residual = add_residual(branch_output)
|
|
132
|
+
|
|
133
|
+
# 4. reduce 4 streams with a summation, this has to be done after your for loop trunk
|
|
134
|
+
|
|
135
|
+
residual = reduce_stream(residual)
|
|
136
|
+
```
|
|
137
|
+
|
|
95
138
|
## Citation
|
|
96
139
|
|
|
97
140
|
```bibtex
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
hyper_connections/__init__.py,sha256=xXx2Mb-dS1__UPzT-5VR1XZmyqKSSkT1DU6bAcK8jR0,73
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=91QtTtnpffmErIZvrnTtosSf4JgBqcyGvxftmka-EOw,4303
|
|
3
|
+
hyper_connections-0.0.3.dist-info/METADATA,sha256=8XKTmC6Ys10uOyotPtvL17v4uZyepkWoeMRVM4B_TSQ,4676
|
|
4
|
+
hyper_connections-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
+
hyper_connections-0.0.3.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
6
|
+
hyper_connections-0.0.3.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
hyper_connections/__init__.py,sha256=xXx2Mb-dS1__UPzT-5VR1XZmyqKSSkT1DU6bAcK8jR0,73
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=Nbv7_OZ8FkdRG1WfmqdQHf46GZxWY0G_h_p4lT_JW38,3450
|
|
3
|
-
hyper_connections-0.0.1.dist-info/METADATA,sha256=o_PVMP0Mm_sr7196WRYbh7O3wDp-nHVzAESgfWzi3FQ,3684
|
|
4
|
-
hyper_connections-0.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
-
hyper_connections-0.0.1.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
6
|
-
hyper_connections-0.0.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|