hyper-connections 0.1.5__tar.gz → 0.1.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.1.5 → hyper_connections-0.1.7}/PKG-INFO +1 -1
- {hyper_connections-0.1.5 → hyper_connections-0.1.7}/hyper_connections/hyper_connections.py +23 -10
- {hyper_connections-0.1.5 → hyper_connections-0.1.7}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +5 -4
- {hyper_connections-0.1.5 → hyper_connections-0.1.7}/hyper_connections/hyper_connections_with_multi_input_streams.py +5 -8
- {hyper_connections-0.1.5 → hyper_connections-0.1.7}/pyproject.toml +1 -1
- {hyper_connections-0.1.5 → hyper_connections-0.1.7}/tests/test_hyper_connections.py +43 -0
- {hyper_connections-0.1.5 → hyper_connections-0.1.7}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.1.5 → hyper_connections-0.1.7}/.github/workflows/test.yml +0 -0
- {hyper_connections-0.1.5 → hyper_connections-0.1.7}/.gitignore +0 -0
- {hyper_connections-0.1.5 → hyper_connections-0.1.7}/LICENSE +0 -0
- {hyper_connections-0.1.5 → hyper_connections-0.1.7}/README.md +0 -0
- {hyper_connections-0.1.5 → hyper_connections-0.1.7}/hyper-connections.png +0 -0
- {hyper_connections-0.1.5 → hyper_connections-0.1.7}/hyper_connections/__init__.py +0 -0
hyper_connections/hyper_connections.py

@@ -11,6 +11,7 @@ import torch.nn.functional as F
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from einops import rearrange, repeat, reduce, einsum
+from einops.layers.torch import Reduce
 
 """
 ein notation:
@@ -35,11 +36,11 @@ def identity(t):
 
 def get_expand_reduce_stream_functions(num_streams, disable = False):
 
-    if disable:
-        return (
+    if num_streams == 1 or disable:
+        return (nn.Identity(), nn.Identity())
 
-    expand_fn =
-    reduce_fn =
+    expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
+    reduce_fn = Reduce(pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
 
     return expand_fn, reduce_fn
 
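For orientation, a minimal sketch of what the rewritten expand/reduce pair does to a batch. This is not part of the diff; it uses einops' functional repeat and reduce, which should behave like the Reduce layers above, with illustrative shapes:

    import torch
    from einops import repeat, reduce

    num_streams = 4
    x = torch.randn(2, 1024, 512)  # (batch, seq, dim)

    # expand: copy the batch once per stream, folding the stream axis into the batch
    expanded = repeat(x, 'b ... -> (b s) ...', s = num_streams)               # (8, 1024, 512)

    # reduce: unfold the streams and sum them back into a single residual
    reduced = reduce(expanded, '(b s) ... -> b ...', 'sum', s = num_streams)  # (2, 1024, 512)

    assert torch.allclose(reduced, num_streams * x)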
@@ -72,16 +73,18 @@ class Residual(Module):
         self,
         *args,
         branch: Module | None = None,
+        residual_transform: Module | None = None,
         **kwargs
     ):
         super().__init__()
         self.branch = branch
+        self.residual_transform = default(residual_transform, nn.Identity())
 
     def width_connection(self, residuals):
         return residuals, residuals, dict()
 
     def depth_connection(self, branch_output, residuals):
-        return branch_output + residuals
+        return branch_output + self.residual_transform(residuals)
 
     def decorate_branch(self, branch: Callable):
         assert not exists(self.branch), 'branch was already wrapped on init'
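To see why depth_connection now routes the residual through residual_transform: when the wrapped branch changes the feature dimension, the skip path needs its own projection before the addition can happen. A minimal sketch (the Linear sizes and names here are illustrative, not taken from the diff):

    import torch
    from torch import nn

    branch = nn.Linear(512, 256)              # dim-changing branch
    residual_transform = nn.Linear(512, 256)  # projects the skip path to match

    x = torch.randn(2, 1024, 512)

    # without the transform this would fail: (..., 256) + (..., 512)
    out = branch(x) + residual_transform(x)   # what the new depth_connection computes
    assert out.shape == (2, 1024, 256)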
@@ -127,7 +130,8 @@ class HyperConnections(Module):
         layer_index = None,
         tanh = True,
         channel_first = False,
-        dropout = 0
+        dropout = 0.,
+        residual_transform: Module | None = None, # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
     ):
         """
         Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
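The comment on the new kwarg points at the usual ResNet situation: a block whose output width differs from its input, where the skip path gets a 1x1 conv. A hedged sketch of how this could be wired up, assuming the package's get_init_and_expand_reduce_stream_functions helper (also used by the new test further down); the conv block and all sizes are illustrative:

    import torch
    from torch import nn
    from hyper_connections import get_init_and_expand_reduce_stream_functions

    init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)

    # a resnet-style block that widens channels 64 -> 128
    block = nn.Sequential(
        nn.Conv2d(64, 128, 3, padding = 1),
        nn.SiLU(),
        nn.Conv2d(128, 128, 3, padding = 1)
    )

    hyper_block = init_hyper_conn(
        dim = 64,
        branch = block,
        channel_first = True,                       # feature maps are (b, c, h, w)
        residual_transform = nn.Conv2d(64, 128, 1)  # 1x1 conv so the skip path matches the new width
    )

    x = expand_stream(torch.randn(2, 64, 32, 32))
    x = hyper_block(x)
    x = reduce_stream(x)
    assert x.shape == (2, 128, 32, 32)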
@@ -167,7 +171,14 @@
 
         self.channel_first = channel_first
 
+        # maybe residual transform
+
+        self.residual_transform = default(residual_transform, nn.Identity())
+
     def width_connection(self, residuals):
+
+        maybe_transformed_residuals = self.residual_transform(residuals)
+
         # width connection
 
         if self.channel_first:
@@ -196,7 +207,7 @@ class HyperConnections(Module):
         if self.channel_first:
             branch_input = rearrange(branch_input, 'b ... d -> b d ...')
 
-        return branch_input,
+        return branch_input, maybe_transformed_residuals, dict(beta = beta)
 
     def depth_connection(self, branch_output, residuals, *, beta):
         # 'depth' connection
@@ -204,13 +215,15 @@ class HyperConnections(Module):
         if self.channel_first:
             branch_output = rearrange(branch_output, 'b d ... -> b ... d')
 
-
-        output = rearrange(
+        output = einsum(branch_output, beta, 'b ... d, b ... s -> b ... s d')
+        output = rearrange(output, 'b ... s d -> (b s) ... d')
 
         if self.channel_first:
             output = rearrange(output, 'b ... d -> b d ...')
 
-
+        residuals = residuals + output
+
+        return self.dropout(residuals)
 
     def decorate_branch(self, branch: Callable):
         assert not exists(self.branch), 'branch was already wrapped on init'
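A shape-level sketch of the rewritten depth connection, with explicit axes in place of the ellipsis used in the diff (names and sizes are illustrative): each stream receives the branch output scaled by its own beta, and the stream axis is then folded back into the batch so it lines up with the expanded residual streams before the final addition and dropout.

    import torch
    from einops import einsum, rearrange

    batch, seq, dim, streams = 2, 1024, 512, 4

    branch_output = torch.randn(batch, seq, dim)   # single branch output
    beta = torch.randn(batch, seq, streams)        # per-stream depth weights

    # weight the branch output by each stream's beta
    out = einsum(branch_output, beta, 'b n d, b n s -> b n s d')

    # fold the stream axis back into the batch, matching the (b s) layout of the residual streams
    out = rearrange(out, 'b n s d -> (b s) n d')

    assert out.shape == (batch * streams, seq, dim)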
hyper_connections/hyper_connections_with_multi_branch_inputs.py

@@ -11,6 +11,7 @@ from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from einops import rearrange, repeat, reduce, einsum
+from einops.layers.torch import Reduce
 
 """
 ein notation:
@@ -41,11 +42,11 @@ def identity(t):
 # main functions
 
 def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
-    if disable:
-        return (
+    if num_streams == 1 or disable:
+        return (nn.Identity(), nn.Identity())
 
-    expand_fn =
-    reduce_fn =
+    expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
+    reduce_fn = Reduce(pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
 
     return expand_fn, reduce_fn
 
hyper_connections/hyper_connections_with_multi_input_streams.py

@@ -11,7 +11,7 @@ from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from einops import rearrange, repeat, reduce, einsum
-from einops.layers.torch import Rearrange
+from einops.layers.torch import Rearrange, Reduce
 
 """
 ein notation:
@@ -29,18 +29,15 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
-def identity(t):
-    return t
-
 # main functions
 
 def get_expand_reduce_stream_functions(num_streams, disable = False):
 
-    if disable:
-        return (
+    if num_streams == 1 or disable:
+        return (nn.Identity(), nn.Identity())
 
-    expand_fn =
-    reduce_fn =
+    expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
+    reduce_fn = Reduce(pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
 
     return expand_fn, reduce_fn
 
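All three modules make the same swap in the disabled path: the local identity helper is dropped in favour of nn.Identity(), which does the same thing to a tensor but is a proper Module, so it can sit inside nn.Sequential or a ModuleList like any other layer. A tiny sketch of the disabled round trip, with illustrative shapes:

    import torch
    from torch import nn

    expand_fn, reduce_fn = nn.Identity(), nn.Identity()

    x = torch.randn(2, 1024, 512)
    assert torch.equal(reduce_fn(expand_fn(x)), x)

    # being Modules, they compose in containers, unlike a bare function
    pipeline = nn.Sequential(expand_fn, nn.Linear(512, 512), reduce_fn)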
tests/test_hyper_connections.py

@@ -136,3 +136,46 @@ def test_multi_input_hyper_connections(disable):
     residual = reduce_stream(residual)
 
     assert residual.shape == (3, 1024, 512)
+
+@pytest.mark.parametrize('disable', (False, True))
+def test_residual_transform(disable):
+
+    # a single branch layer
+
+    branch = nn.Sequential(
+        nn.Linear(512, 512),
+        nn.SiLU(),
+        nn.Linear(512, 256)
+    )
+
+    residual_fn = nn.Linear(512, 256)
+
+    # before
+
+    residual = torch.randn(2, 1024, 512)
+
+    before_residual = branch(residual) + residual_fn(residual)
+
+    # after, say 4 streams in paper
+
+    from hyper_connections import get_init_and_expand_reduce_stream_functions
+
+    init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, disable = disable)
+
+    # 1. wrap your branch function
+
+    hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch, residual_transform = residual_fn)
+
+    # 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
+
+    residual = expand_stream(residual)
+
+    # 3. forward your residual as usual into the wrapped branch function(s)
+
+    residual = hyper_conn_branch(residual)
+
+    # 4. reduce 4 streams with a summation, this has to be done after your for-loop trunk. for transformer, unsure whether to do before or after final norm
+
+    after_residual = reduce_stream(residual)
+
+    assert before_residual.shape == after_residual.shape