hyper-connections 0.2.1__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.2.1 → hyper_connections-0.3.0}/PKG-INFO +13 -1
- {hyper_connections-0.2.1 → hyper_connections-0.3.0}/README.md +12 -0
- hyper_connections-0.3.0/hyper_connections/manifold_constrained_hyper_connections.py +541 -0
- {hyper_connections-0.2.1 → hyper_connections-0.3.0}/pyproject.toml +1 -1
- {hyper_connections-0.2.1 → hyper_connections-0.3.0}/tests/test_hyper_connections.py +7 -2
- {hyper_connections-0.2.1 → hyper_connections-0.3.0}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.2.1 → hyper_connections-0.3.0}/.github/workflows/test.yml +0 -0
- {hyper_connections-0.2.1 → hyper_connections-0.3.0}/.gitignore +0 -0
- {hyper_connections-0.2.1 → hyper_connections-0.3.0}/LICENSE +0 -0
- {hyper_connections-0.2.1 → hyper_connections-0.3.0}/hyper-connections.png +0 -0
- {hyper_connections-0.2.1 → hyper_connections-0.3.0}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.2.1 → hyper_connections-0.3.0}/hyper_connections/hyper_connections.py +0 -0
- {hyper_connections-0.2.1 → hyper_connections-0.3.0}/hyper_connections/hyper_connections_channel_first.py +0 -0
- {hyper_connections-0.2.1 → hyper_connections-0.3.0}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
- {hyper_connections-0.2.1 → hyper_connections-0.3.0}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
- {hyper_connections-0.2.1 → hyper_connections-0.3.0}/hyper_connections/residuals.py +0 -0
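The headline change in this release is the new `manifold_constrained_hyper_connections` module, implementing mHC (manifold-constrained hyper-connections). Below is a minimal usage sketch, assuming the new module mirrors the entry points of the existing `hyper_connections` package; only the function and class names visible in the diff below are taken from the source, while the toy branch and tensor shapes are illustrative:

```python
import torch
from torch import nn

# entry point from the new module added in 0.3.0
from hyper_connections.manifold_constrained_hyper_connections import get_init_and_expand_reduce_stream_functions

# 4 residual streams, as in the hyper-connections paper
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)

# wrap a toy branch (illustrative - any Module with matching dims works)
residual_fn = init_hyper_conn(dim = 512, branch = nn.Linear(512, 512))

x = torch.randn(2, 1024, 512)

x = expand_stream(x)   # (2, 1024, 512) -> (8, 1024, 512), batch replicated per stream
x = residual_fn(x)     # width connection -> branch -> depth connection
x = reduce_stream(x)   # streams summed back: (8, 1024, 512) -> (2, 1024, 512)

assert x.shape == (2, 1024, 512)
```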
{hyper_connections-0.2.1 → hyper_connections-0.3.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hyper-connections
-Version: 0.2.1
+Version: 0.3.0
 Summary: Hyper-Connections
 Project-URL: Homepage, https://pypi.org/project/hyper-connections/
 Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -177,3 +177,15 @@ get_init_and_expand_reduce_stream_functions(1, num_fracs = 4) # also allows you
     url = {https://api.semanticscholar.org/CorpusID:277104144}
 }
 ```
+
+```bibtex
+@misc{xie2025mhcmanifoldconstrainedhyperconnections,
+    title = {mHC: Manifold-Constrained Hyper-Connections},
+    author = {Zhenda Xie and Yixuan Wei and Huanqi Cao and Chenggang Zhao and Chengqi Deng and Jiashi Li and Damai Dai and Huazuo Gao and Jiang Chang and Liang Zhao and Shangyan Zhou and Zhean Xu and Zhengyan Zhang and Wangding Zeng and Shengding Hu and Yuqing Wang and Jingyang Yuan and Lean Wang and Wenfeng Liang},
+    year = {2025},
+    eprint = {2512.24880},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CL},
+    url = {https://arxiv.org/abs/2512.24880},
+}
+```
{hyper_connections-0.2.1 → hyper_connections-0.3.0}/README.md

@@ -136,3 +136,15 @@ get_init_and_expand_reduce_stream_functions(1, num_fracs = 4) # also allows you
     url = {https://api.semanticscholar.org/CorpusID:277104144}
 }
 ```
+
+```bibtex
+@misc{xie2025mhcmanifoldconstrainedhyperconnections,
+    title = {mHC: Manifold-Constrained Hyper-Connections},
+    author = {Zhenda Xie and Yixuan Wei and Huanqi Cao and Chenggang Zhao and Chengqi Deng and Jiashi Li and Damai Dai and Huazuo Gao and Jiang Chang and Liang Zhao and Shangyan Zhou and Zhean Xu and Zhengyan Zhang and Wangding Zeng and Shengding Hu and Yuqing Wang and Jingyang Yuan and Lean Wang and Wenfeng Liang},
+    year = {2025},
+    eprint = {2512.24880},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CL},
+    url = {https://arxiv.org/abs/2512.24880},
+}
+```
hyper_connections-0.3.0/hyper_connections/manifold_constrained_hyper_connections.py (new file)

@@ -0,0 +1,541 @@
+from __future__ import annotations
+from typing import Callable
+
+from functools import partial
+from random import randrange
+
+import torch
+from torch import nn, cat
+import torch.nn.functional as F
+from torch.nn import Module, Sequential
+from torch.utils._pytree import tree_flatten, tree_unflatten
+
+from einops import rearrange, repeat, reduce, einsum, pack, unpack
+from einops.layers.torch import Rearrange, Reduce
+
+"""
+ein notation:
+b - batch
+d - feature dimension
+s - residual streams
+t - residual streams + num branch inputs
+f - number of fractions (division of feature dimension space)
+v - number of views for branch input
+"""
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def divisible_by(num, den):
+    return (num % den) == 0
+
+def default(v, d):
+    return v if exists(v) else d
+
+def identity(t):
+    return t
+
+def add(x, y):
+    return x + y
+
+def pack_one_with_inverse(t, pattern):
+    packed, packed_shape = pack([t], pattern)
+
+    def inverse(out):
+        return unpack(out, packed_shape, pattern)[0]
+
+    return packed, inverse
+
+# sinkhorn
+
+def l1norm(t, dim):
+    return F.normalize(t, p = 1, dim = dim)
+
+def sinkhorn_knopps(log_alpha, iters = 20):
+    alpha = log_alpha.exp()
+
+    for _ in range(iters):
+        alpha = l1norm(alpha, dim = -2)
+        alpha = l1norm(alpha, dim = -1)
+
+    return alpha
+
+# main functions
+
+def get_expand_reduce_stream_functions(
+    num_streams,
+    add_stream_embed = False,
+    dim = None,
+    disable = False
+):
+    if num_streams == 1 or disable:
+        return (nn.Identity(), nn.Identity())
+
+    if add_stream_embed:
+        assert exists(dim), '`dim` must be passed into get_init_and_expand_reduce_stream_functions for returning an expansion function with stream embeddings added'
+
+        expand_fn = StreamEmbed(num_streams, dim, expand_to_streams = True)
+    else:
+        expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
+
+    reduce_fn = Reduce(pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
+
+    return expand_fn, reduce_fn
+
+def get_init_and_expand_reduce_stream_functions(
+    num_streams,
+    num_fracs = 1,
+    dim = None,
+    add_stream_embed = False,
+    disable = None
+):
+    disable = default(disable, num_streams == 1 and num_fracs == 1)
+
+    hyper_conn_klass = ManifoldConstrainedHyperConnections if not disable else Residual
+
+    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs)
+    expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, add_stream_embed = add_stream_embed, dim = dim, disable = disable)
+
+    if exists(dim):
+        init_hyper_conn_fn = partial(init_hyper_conn_fn, dim = dim)
+
+    return (init_hyper_conn_fn, *expand_reduce_fns)
+
+# norms
+
+class RMSNorm(Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.zeros(dim))
+
+    def forward(self, x):
+        return F.normalize(x, dim = -1) * self.scale * (self.gamma + 1)
+
+# main classes
+
+# residual base class
+
+class Residual(Module):
+    def __init__(
+        self,
+        *args,
+        branch: Module | None = None,
+        residual_transform: Module | None = None,
+        **kwargs
+    ):
+        super().__init__()
+        self.branch = branch
+        self.residual_transform = default(residual_transform, nn.Identity())
+
+    def width_connection(
+        self,
+        residuals
+    ):
+        return residuals, residuals, dict()
+
+    def depth_connection(
+        self,
+        branch_output,
+        residuals,
+
+    ):
+        return branch_output + self.residual_transform(residuals)
+
+    def decorate_branch(
+        self,
+        branch: Callable
+    ):
+        assert not exists(self.branch), 'branch was already wrapped on init'
+
+        def forward_and_add_residual(residual, *args, **kwargs):
+            branch_input, add_residual = self.forward(residual)
+
+            branch_output = branch(branch_input, *args, **kwargs)
+
+            residual = add_residual(branch_output)
+
+            return residual
+
+        return forward_and_add_residual
+
+    def forward(
+        self,
+        residuals,
+        *branch_args,
+        **branch_kwargs
+    ):
+
+        branch_input, residuals, residual_kwargs = self.width_connection(residuals)
+
+        def add_residual_fn(branch_out):
+            (branch_out, *rest), tree_spec = tree_flatten(branch_out)
+
+            branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
+
+            return tree_unflatten((branch_out, *rest), tree_spec)
+
+        if not exists(self.branch):
+            return branch_input, add_residual_fn
+
+        branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
+
+        return add_residual_fn(branch_output)
+
+# hyper connection residual streams
+
+class ManifoldConstrainedHyperConnections(Module):
+    def __init__(
+        self,
+        num_residual_streams,
+        *,
+        dim,
+        branch: Module | None = None,
+        layer_index = None,
+        channel_first = False,
+        dropout = 0.,
+        residual_transform: Module | None = None, # to support resnet blocks where dimension in is not equal to dimension out - usually a residual conv
+        add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
+        num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
+        depth_residual_fn = add,
+        num_fracs = 1, # https://arxiv.org/abs/2503.14125
+        sinkhorn_iters = 20
+    ):
+        """
+        Appendix J, Algorithm 2 in - https://arxiv.org/abs/2409.19606
+        """
+        super().__init__()
+
+        self.branch = branch
+
+        # frac-connections paper - num_fracs > 1 will be the `m` in their paper https://arxiv.org/abs/2503.14125
+
+        assert num_fracs >= 1
+
+        self.num_fracs = num_fracs
+        self.has_fracs = num_fracs > 1
+
+        self.split_fracs = Rearrange('b ... (f d) -> b ... f d', f = num_fracs)
+        self.merge_fracs = Rearrange('b ... f d -> b ... (f d)')
+
+        assert divisible_by(dim, num_fracs), f'feature dimension ({dim}) must be divisible by the `num_fracs` ({num_fracs})'
+
+        dim //= num_fracs # effective dim handled is the feature dimension divided by num fractions
+
+        # they used layernorm in paper, but rmsnorm is fine given what we know now
+
+        self.norm = RMSNorm(dim * num_residual_streams * num_fracs)
+
+        assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
+
+        self.num_residual_streams = num_residual_streams
+        init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
+
+        # handle the parameter dimensions, which may require (num_residuals x num_fractions) - generalizing hyper + frac connections
+
+        num_residual_streams_fracs = num_residual_streams * num_fracs
+        num_input_views_fracs = num_input_views * num_fracs
+
+        self.num_fracs = num_fracs
+
+        # width num residual streams
+
+        assert num_input_views >= 1
+        self.num_input_views = num_input_views
+
+        # width connection
+
+        init_alpha0 = torch.zeros((num_residual_streams_fracs, num_input_views_fracs))
+        init_alpha0[init_residual_index, :] = 1.
+
+        self.static_alpha = nn.Parameter(cat((init_alpha0, torch.eye(num_residual_streams_fracs)), dim = 1))
+
+        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams_fracs + num_input_views_fracs))
+
+        self.pre_branch_scale = nn.Parameter(torch.ones(1) * 1e-2)
+        self.residual_scale = nn.Parameter(torch.ones(1) * 1e-2)
+
+        # depth connection related (beta)
+
+        self.add_branch_out_to_residual = add_branch_out_to_residual
+
+        if add_branch_out_to_residual:
+            self.static_beta = nn.Parameter(torch.ones(num_residual_streams_fracs))
+
+            dynamic_beta_shape = (dim,) if num_fracs == 1 else (dim, num_fracs) # preserve backwards compat
+            self.dynamic_beta_fn = nn.Parameter(torch.zeros(dynamic_beta_shape))
+
+            self.h_post_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+        # sinkhorn related
+
+        self.sinkhorn_iters = sinkhorn_iters
+
+        # dropouts
+
+        self.dropout = nn.Dropout(dropout)
+
+        # channel first option
+
+        self.channel_first = channel_first
+
+        # maybe residual transform
+
+        self.residual_transform = default(residual_transform, nn.Identity())
+
+        # maybe custom depth connection residual function
+        # this is to prepare for gating the addition of the branch outputs to the residual streams
+        # needed for memory lanes a la RMT / LMM
+
+        self.depth_residual_fn = depth_residual_fn
+
+    def width_connection(
+        self,
+        residuals
+    ):
+        streams = self.num_residual_streams
+
+        maybe_transformed_residuals = self.residual_transform(residuals)
+
+        # width connection
+
+        # handle channel first
+
+        if self.channel_first:
+            residuals = rearrange(residuals, 'b d ... -> b ... d')
+
+        # split out fractions
+
+        residuals = self.split_fracs(residuals)
+
+        # split out streams
+
+        residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = streams)
+
+        # norm
+
+        flattened_residuals, inverse_pack = pack_one_with_inverse(residuals, 'b n *')
+
+        normed = self.norm(flattened_residuals) # they norm across flattened stream + dimension?
+
+        normed = inverse_pack(normed)
+
+        # alpha for weighted sum of residuals going into branch
+
+        wc_weight = normed @ self.dynamic_alpha_fn
+
+        pre_branch_scale = repeat(self.pre_branch_scale, '1 -> s', s = self.num_fracs)
+        residual_scale = repeat(self.residual_scale, '1 -> s', s = self.num_fracs * streams)
+        alpha_scale = cat((pre_branch_scale, residual_scale))
+
+        alpha_scale = repeat(alpha_scale, 'n -> (v n)', v = self.num_input_views)
+
+        dynamic_alpha = wc_weight * alpha_scale
+
+        static_alpha = rearrange(self.static_alpha, '(f s) d -> f s d', s = streams)
+
+        alpha = dynamic_alpha + static_alpha
+
+        alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
+
+        # the alpha is now split and "manifold constrained" with sinkhorn and sigmoid
+
+        alpha_pre, alpha_residual = alpha[..., :self.num_input_views], alpha[..., self.num_input_views:]
+
+        alpha_pre = alpha_pre.sigmoid()
+
+        alpha_residual = rearrange(alpha_residual, '... (v s1 s2) -> ... v s1 s2', v = self.num_input_views, s1 = streams)
+
+        alpha_residual = sinkhorn_knopps(alpha_residual, self.sinkhorn_iters)
+
+        alpha_residual = rearrange(alpha_residual, '... v s1 s2 -> ... (v s1 s2)')
+
+        alpha = cat((alpha_pre, alpha_residual), dim = -1)
+
+        # beta for weights from branch output back to residual streams
+
+        beta = None
+
+        if self.add_branch_out_to_residual:
+            dc_weight = normed @ self.dynamic_beta_fn
+
+            dc_weight = dc_weight.sigmoid() * 2 # sigmoid * 2 for "H_post", corresponding to dc weight in original paper
+
+            if not self.has_fracs:
+                dc_weight = rearrange(dc_weight, '... -> ... 1')
+
+            dynamic_beta = dc_weight * self.h_post_scale
+
+            static_beta = rearrange(self.static_beta, '... (s f) -> ... s f', s = streams)
+
+            beta = dynamic_beta + static_beta
+
+        mix_h = einsum(alpha, residuals, '... f1 s f2 t, ... f1 s d -> ... f2 t d')
+
+        if self.num_input_views == 1:
+            branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+        else:
+            branch_input, residuals = mix_h[..., :self.num_input_views, :], mix_h[..., self.num_input_views:, :]
+            branch_input = rearrange(branch_input, 'b ... v d -> v b ... d')
+
+        if self.channel_first:
+            branch_input = rearrange(branch_input, 'b ... d -> b d ...')
+
+        # maybe merge fractions back
+
+        branch_input = self.merge_fracs(branch_input)
+
+        return branch_input, maybe_transformed_residuals, dict(beta = beta)
+
+    def depth_connection(
+        self,
+        branch_output,
+        residuals,
+        *,
+        beta
+    ):
+        assert self.add_branch_out_to_residual
+
+        # maybe split fractions
+
+        branch_output = self.split_fracs(branch_output)
+
+        # 'depth' connection
+
+        if self.channel_first:
+            branch_output = rearrange(branch_output, 'b d ... -> b ... d')
+
+        output = einsum(branch_output, beta, 'b ... f1 d, b ... f1 s f2 -> b ... f2 s d')
+
+        output = rearrange(output, 'b ... s d -> (b s) ... d')
+
+        # merge back fractions
+
+        output = self.merge_fracs(output)
+
+        # channel first
+
+        if self.channel_first:
+            output = rearrange(output, 'b ... d -> b d ...')
+
+        residuals = self.depth_residual_fn(output, residuals)
+
+        return self.dropout(residuals)
+
+    def decorate_branch(
+        self,
+        branch: Callable
+    ):
+        assert not exists(self.branch), 'branch was already wrapped on init'
+
+        def forward_and_add_residual(residual, *args, **kwargs):
+            branch_input, add_residual = self.forward(residual)
+
+            branch_output = branch(branch_input, *args, **kwargs)
+
+            residual = add_residual(branch_output)
+
+            return residual
+
+        return forward_and_add_residual
+
+    def forward(
+        self,
+        residuals,
+        *branch_args,
+        **branch_kwargs
+    ):
+
+        branch_input, residuals, residual_kwargs = self.width_connection(residuals)
+
+        def add_residual_fn(branch_out):
+
+            if not self.add_branch_out_to_residual:
+                return branch_out
+
+            (branch_out, *rest), tree_spec = tree_flatten(branch_out)
+
+            branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
+
+            return tree_unflatten((branch_out, *rest), tree_spec)
+
+        if not exists(self.branch):
+            return branch_input, add_residual_fn
+
+        branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
+
+        return add_residual_fn(branch_output)
+
+ManifoldConstrainedHyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
+ManifoldConstrainedHyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
+
+# stream embed
+
+class StreamEmbed(Module):
+    def __init__(
+        self,
+        num_streams,
+        dim,
+        channel_first = False,
+        expand_to_streams = False
+    ):
+        super().__init__()
+        self.channel_first = channel_first
+        self.num_streams = num_streams
+
+        self.expand_to_streams = expand_to_streams
+        self.stream_embed = nn.Parameter(torch.zeros(num_streams, dim))
+
+    def forward(self, residuals):
+
+        if self.expand_to_streams:
+            residuals = repeat(residuals, 'b ... -> (b s) ...', s = self.num_streams)
+
+        if self.channel_first:
+            residuals = rearrange(residuals, '(b s) d ... -> b ... s d', s = self.num_streams)
+        else:
+            residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = self.num_streams)
+
+        residuals = residuals + self.stream_embed
+
+        if self.channel_first:
+            residuals = rearrange(residuals, 'b ... s d -> (b s) d ...', s = self.num_streams)
+        else:
+            residuals = rearrange(residuals, 'b ... s d -> (b s) ... d', s = self.num_streams)
+
+        return residuals
+
+# attention pool - taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
+
+class AttentionPoolReduceStream(Module):
+    def __init__(
+        self,
+        num_streams,
+        dim,
+        channel_first = False
+    ):
+        super().__init__()
+        self.num_streams = num_streams
+        self.channel_first = channel_first
+
+        self.to_attn_logits = nn.Linear(dim, dim, bias = False)
+        self.to_attn_logits.weight.data.copy_(torch.eye(dim))
+
+    def forward(self, residuals):
+
+        if self.channel_first:
+            residuals = rearrange(residuals, '(b s) d ... -> b ... s d', s = self.num_streams)
+        else:
+            residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = self.num_streams)
+
+        attn_logits = self.to_attn_logits(residuals)
+        attn = attn_logits.softmax(dim = -2)
+
+        residuals = reduce(residuals * attn, 'b ... s d -> b ... d', 'sum')
+
+        if self.channel_first:
+            residuals = rearrange(residuals, 'b ... d -> b d ...')
+
+        return residuals
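The "manifold constraint" in the class above comes down to the `sinkhorn_knopps` call in `width_connection`: the residual-mixing slice of `alpha` is alternately L1-normalized along rows and columns so it approaches a doubly stochastic matrix, while the branch-input slice only passes through a sigmoid. A standalone sketch of that projection, with the `l1norm` / `sinkhorn_knopps` logic copied from the new file and checked on a toy 4x4 matrix:

```python
import torch
import torch.nn.functional as F

def l1norm(t, dim):
    return F.normalize(t, p = 1, dim = dim)

def sinkhorn_knopps(log_alpha, iters = 20):
    alpha = log_alpha.exp()  # exponentiate so all entries are positive

    for _ in range(iters):
        alpha = l1norm(alpha, dim = -2)  # columns sum to ~1
        alpha = l1norm(alpha, dim = -1)  # rows sum to ~1

    return alpha

alpha = sinkhorn_knopps(torch.randn(4, 4))

# after enough iterations, both row and column sums are close to one
assert torch.allclose(alpha.sum(dim = -1), torch.ones(4), atol = 1e-2)
assert torch.allclose(alpha.sum(dim = -2), torch.ones(4), atol = 1e-2)
```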
{hyper_connections-0.2.1 → hyper_connections-0.3.0}/tests/test_hyper_connections.py

@@ -5,9 +5,11 @@ from torch import nn
 
 @pytest.mark.parametrize('num_fracs', (1, 4))
 @pytest.mark.parametrize('disable', (False, True))
+@pytest.mark.parametrize('manifold_constrained', (False, True))
 def test_readme(
     num_fracs,
-    disable
+    disable,
+    manifold_constrained
 ):
 
     # a single branch layer

@@ -22,7 +24,10 @@ def test_readme(
 
     # after, say 4 streams in paper
 
-
+    if manifold_constrained:
+        from hyper_connections.manifold_constrained_hyper_connections import get_init_and_expand_reduce_stream_functions
+    else:
+        from hyper_connections import get_init_and_expand_reduce_stream_functions
 
     init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, num_fracs = num_fracs, disable = disable)
 
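The test's `num_fracs` parameter exercises the frac-connections path (https://arxiv.org/abs/2503.14125), where the feature dimension is split into fractions instead of replicating whole streams. A sketch under the same assumptions as the example above, with one stream and four fractions; per the source, `get_expand_reduce_stream_functions` returns identities when `num_streams == 1`:

```python
import torch
from torch import nn

from hyper_connections.manifold_constrained_hyper_connections import get_init_and_expand_reduce_stream_functions

# 1 stream, 4 fractions - the feature dimension must be divisible by num_fracs
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(1, num_fracs = 4)

residual_fn = init_hyper_conn(dim = 512, branch = nn.Linear(512, 512))

x = torch.randn(2, 1024, 512)

# expand/reduce are nn.Identity() here, kept only for a uniform call pattern
x = reduce_stream(residual_fn(expand_stream(x)))

assert x.shape == (2, 1024, 512)
```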