hyper-connections 0.4.1__tar.gz → 0.4.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/PKG-INFO +3 -3
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/hyper_connections/hyper_connections.py +1 -1
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/hyper_connections/hyper_connections_channel_first.py +1 -1
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +1 -1
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/hyper_connections/hyper_connections_with_multi_input_streams.py +1 -1
- hyper_connections-0.4.3/hyper_connections/mHCv2.py +541 -0
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/hyper_connections/manifold_constrained_hyper_connections.py +5 -5
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/pyproject.toml +3 -3
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/tests/test_hyper_connections.py +46 -0
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/.github/workflows/test.yml +0 -0
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/.gitignore +0 -0
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/LICENSE +0 -0
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/README.md +0 -0
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/hyper-connections.png +0 -0
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/hyper_connections/residuals.py +0 -0
- {hyper_connections-0.4.1 → hyper_connections-0.4.3}/hyper_connections/vit.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.4.1
|
|
3
|
+
Version: 0.4.3
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -34,8 +34,8 @@ Classifier: License :: OSI Approved :: MIT License
|
|
|
34
34
|
Classifier: Programming Language :: Python :: 3.9
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
36
|
Requires-Python: >=3.9
|
|
37
|
-
Requires-Dist: einops>=0.8.
|
|
38
|
-
Requires-Dist: torch>=2.
|
|
37
|
+
Requires-Dist: einops>=0.8.1
|
|
38
|
+
Requires-Dist: torch>=2.5
|
|
39
39
|
Provides-Extra: examples
|
|
40
40
|
Description-Content-Type: text/markdown
|
|
41
41
|
|
|
@@ -41,7 +41,7 @@ def identity(t):
|
|
|
41
41
|
|
|
42
42
|
def get_expand_reduce_stream_functions(num_streams, disable = False):
|
|
43
43
|
|
|
44
|
-
if
|
|
44
|
+
if disable:
|
|
45
45
|
return (nn.Identity(), nn.Identity())
|
|
46
46
|
|
|
47
47
|
expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
|
|
@@ -42,7 +42,7 @@ def identity(t):
|
|
|
42
42
|
# main functions
|
|
43
43
|
|
|
44
44
|
def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
|
|
45
|
-
if
|
|
45
|
+
if disable:
|
|
46
46
|
return (nn.Identity(), nn.Identity())
|
|
47
47
|
|
|
48
48
|
expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
|
|
@@ -33,7 +33,7 @@ def default(v, d):
|
|
|
33
33
|
|
|
34
34
|
def get_expand_reduce_stream_functions(num_streams, disable = False):
|
|
35
35
|
|
|
36
|
-
if
|
|
36
|
+
if disable:
|
|
37
37
|
return (nn.Identity(), nn.Identity())
|
|
38
38
|
|
|
39
39
|
expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
|
|
@@ -0,0 +1,541 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
from functools import partial
|
|
5
|
+
from random import randrange
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch import nn, cat
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
from torch.nn import Module, Sequential
|
|
11
|
+
from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
12
|
+
|
|
13
|
+
from einops import rearrange, repeat, reduce, einsum
|
|
14
|
+
from einops.layers.torch import Rearrange, Reduce
|
|
15
|
+
|
|
16
|
+
"""
|
|
17
|
+
ein notation:
|
|
18
|
+
b - batch
|
|
19
|
+
d - feature dimension
|
|
20
|
+
s - residual streams
|
|
21
|
+
t - residual streams + num branch inputs
|
|
22
|
+
f - number of fractions (division of feature dimension space)
|
|
23
|
+
v - number of views for branch input
|
|
24
|
+
p - proposals
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
# helper functions
|
|
28
|
+
|
|
29
|
+
def exists(v):
    """Return True when `v` holds a value, i.e. is anything other than None."""
    is_missing = v is None
    return not is_missing
|
|
31
|
+
|
|
32
|
+
def divisible_by(num, den):
    """Return whether `num` is an exact multiple of `den` (zero remainder)."""
    remainder = num % den
    return remainder == 0
|
|
34
|
+
|
|
35
|
+
def default(v, d):
    """Return `v` unless it is None, in which case fall back to `d`."""
    if v is None:
        return d
    return v
|
|
37
|
+
|
|
38
|
+
def identity(t):
    """No-op function: hand the argument back unchanged (the very same object)."""
    result = t
    return result
|
|
40
|
+
|
|
41
|
+
def add(x, y):
    """Binary `x + y`; serves as the default `depth_residual_fn` of the hyper-connection module."""
    total = x + y
    return total
|
|
43
|
+
|
|
44
|
+
# sinkhorn
|
|
45
|
+
|
|
46
|
+
def l1norm(t, dim):
    """Rescale `t` so that its absolute values along `dim` sum to one (L1 normalization via F.normalize)."""
    return F.normalize(t, dim = dim, p = 1)
|
|
48
|
+
|
|
49
|
+
def sinkhorn_knopps(log_alpha, iters = 20):
    """
    Sinkhorn-Knopp iteration in probability space.

    Treats the last two (square) dimensions of `log_alpha` as logits, exponentiates
    them and alternately L1-normalizes columns (dim -2) then rows (dim -1) for
    `iters` rounds, pushing the matrix towards a doubly stochastic one.

    Work is done in float32; the result is cast back to the input dtype.
    """
    assert log_alpha.shape[-2] == log_alpha.shape[-1]

    orig_dtype = log_alpha.dtype
    logits = log_alpha.float()

    # subtract each column's max so exp() cannot overflow - a per-column constant factor
    # that the first column normalization removes again, hence no gradient through it
    logits = logits - logits.amax(dim = -2, keepdim = True).detach()

    probs = logits.exp()

    for _ in range(iters):
        probs = l1norm(l1norm(probs, dim = -2), dim = -1)

    return probs.to(orig_dtype)
|
|
64
|
+
|
|
65
|
+
def log_domain_sinkhorn_knopps(log_alpha, iters = 20):
    """
    Sinkhorn-Knopp iteration carried out in log space.

    Each of the `iters` rounds subtracts the logsumexp over columns (dim -2) and
    then over rows (dim -1) from the (square) logits; the matrix is exponentiated
    only once at the end.

    Work is done in float32; the result is cast back to the input dtype.
    """
    assert log_alpha.shape[-2] == log_alpha.shape[-1]

    orig_dtype = log_alpha.dtype
    log_p = log_alpha.float()

    for _ in range(iters):
        for norm_dim in (-2, -1):
            log_p = log_p - log_p.logsumexp(dim = norm_dim, keepdim = True)

    return log_p.exp().to(orig_dtype)
|
|
76
|
+
|
|
77
|
+
# main functions
|
|
78
|
+
|
|
79
|
+
def get_expand_reduce_stream_functions(
    num_streams,
    add_stream_embed = False,
    add_attn_pool_reduce_stream = False,
    dim = None,
    disable = False
):
    """
    Build the `(expand_fn, reduce_fn)` module pair that moves between a single
    residual and `num_streams` residual streams stacked on the second-to-last dim.

    expand_fn - `... d -> ... s d`: plain repetition, or a StreamEmbed (repeat plus a
                learned per-stream embedding) when `add_stream_embed` is set
    reduce_fn - `... s d -> ... d`: summation, or an AttentionPoolReduceStream when
                `add_attn_pool_reduce_stream` is set

    With `disable`, both are identity modules. `dim` must be given whenever either
    learned variant is requested.
    """
    if disable:
        return (nn.Identity(), nn.Identity())

    needs_dim = add_stream_embed or add_attn_pool_reduce_stream

    if needs_dim:
        assert exists(dim), '`dim` must be passed into get_init_and_expand_reduce_stream_functions for returning an expansion function with stream embeddings added'

    expand_fn = (
        StreamEmbed(num_streams, dim, expand_to_streams = True)
        if add_stream_embed
        else Reduce('... d -> ... s d', 'repeat', s = num_streams)
    )

    reduce_fn = (
        AttentionPoolReduceStream(dim)
        if add_attn_pool_reduce_stream
        else Reduce('... s d -> ... d', 'sum')
    )

    return expand_fn, reduce_fn
|
|
103
|
+
|
|
104
|
+
def get_init_and_expand_reduce_stream_functions(
    num_streams,
    num_fracs = 1,
    dim = None,
    add_stream_embed = False,
    add_attn_pool_reduce_stream = False,
    disable = None,
    sinkhorn_iters = 20,
    **kwargs
):
    """
    Convenience factory returning `(init_hyper_conn_fn, expand_fn, reduce_fn)`.

    init_hyper_conn_fn - a partial of ManifoldConstrainedHyperConnections (or of the
                         plain Residual when disabled) with `num_streams`, `num_fracs`,
                         `sinkhorn_iters`, any extra `**kwargs` and, if given, `dim` bound
    expand_fn, reduce_fn - as built by get_expand_reduce_stream_functions

    `disable` defaults to True exactly when there is a single stream and a single
    fraction.
    """
    disable = default(disable, num_streams == 1 and num_fracs == 1)

    hyper_conn_klass = Residual if disable else ManifoldConstrainedHyperConnections

    # note: `add_stream_embed` / `add_attn_pool_reduce_stream` are named parameters, so they
    # never end up in `kwargs` and are therefore not forwarded to the hyper connection class
    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs, sinkhorn_iters = sinkhorn_iters, **kwargs)

    expand_reduce_fns = get_expand_reduce_stream_functions(
        num_streams,
        add_stream_embed = add_stream_embed,
        add_attn_pool_reduce_stream = add_attn_pool_reduce_stream,
        dim = dim,
        disable = disable
    )

    if exists(dim):
        init_hyper_conn_fn = partial(init_hyper_conn_fn, dim = dim)

    return (init_hyper_conn_fn, *expand_reduce_fns)
|
|
132
|
+
|
|
133
|
+
# norms
|
|
134
|
+
|
|
135
|
+
class RMSNorm(Module):
    """
    RMS normalization over the last dimension.

    Input is L2-normalized, rescaled by sqrt(dim), then multiplied by a learned
    per-feature gain stored as an offset from one (`gamma` starts at zero, so the
    initial gain is exactly 1).
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        gain = self.gamma + 1
        return unit * self.scale * gain
|
|
143
|
+
|
|
144
|
+
# main classes
|
|
145
|
+
|
|
146
|
+
# residual base class
|
|
147
|
+
|
|
148
|
+
class Residual(Module):
    """
    Plain residual connection exposing the same interface as the hyper-connection
    module; it is what the init function builds when hyper connections are disabled.

    Extra positional / keyword arguments are accepted and ignored so the class can be
    constructed from the same partial as ManifoldConstrainedHyperConnections.
    """

    def __init__(
        self,
        *args,
        branch: Module | None = None,
        residual_transform: Module | None = None,
        **kwargs
    ):
        super().__init__()
        self.branch = branch
        self.residual_transform = default(residual_transform, nn.Identity())

    def width_connection(
        self,
        residuals
    ):
        # no mixing - the branch input is the residual itself, carried through untouched
        no_extra_kwargs = dict()
        return residuals, residuals, no_extra_kwargs

    def depth_connection(
        self,
        branch_output,
        residuals
    ):
        transformed = self.residual_transform(residuals)
        return branch_output + transformed

    def decorate_branch(
        self,
        branch: Callable
    ):
        assert not exists(self.branch), 'branch was already wrapped on init'

        def forward_and_add_residual(residual, *args, **kwargs):
            branch_input, add_residual = self.forward(residual)
            return add_residual(branch(branch_input, *args, **kwargs))

        return forward_and_add_residual

    def forward(
        self,
        residuals,
        *branch_args,
        **branch_kwargs
    ):
        branch_input, residuals, residual_kwargs = self.width_connection(residuals)

        def add_residual_fn(branch_out):
            # only the first leaf of a (possibly nested) branch output receives the
            # residual; all remaining leaves are passed through unchanged
            leaves, tree_spec = tree_flatten(branch_out)
            head, *rest = leaves
            head = self.depth_connection(head, residuals, **residual_kwargs)
            return tree_unflatten((head, *rest), tree_spec)

        if exists(self.branch):
            branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
            return add_residual_fn(branch_output)

        return branch_input, add_residual_fn
|
|
212
|
+
|
|
213
|
+
# hyper connection residual streams
|
|
214
|
+
|
|
215
|
+
class ManifoldConstrainedHyperConnections(Module):
    """
    Manifold-constrained hyper-connections over `num_residual_streams` residual streams,
    optionally generalized with frac-connections (`num_fracs` > 1).

    Residual streams are stacked on the second-to-last dimension (`b ... s d`). A call

    1. width connection - mixes the streams into the branch input view(s) with weights
       squashed by a sigmoid, and mixes the streams among themselves with a square
       matrix passed through `residual_mix_constraint_fn` (Sinkhorn-Knopp by default)
    2. runs the branch, if one was given on init
    3. depth connection - writes the branch output back into every stream weighted by
       `2 * sigmoid(beta)` and combines it with the mixed streams via `depth_residual_fn`
    """

    def __init__(
        self,
        num_residual_streams,
        *,
        dim,
        branch: Module | None = None,
        layer_index = None,
        dropout = 0.,
        residual_transform: Module | None = None, # to support resnet blocks where dimension is not equal to dimension out - usually a residual conv
        add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
        num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
        depth_residual_fn = add,
        num_fracs = 1, # https://arxiv.org/abs/2503.14125
        sinkhorn_iters = 20,
        log_domain_sinkhorn = False,
        residual_mix_constraint_fn: Callable | None = None,
        forward_method_names: tuple[str, ...] = (),
        num_dynamic_alpha_proposals = 1,
    ):
        """
        Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
        """
        super().__init__()

        self.branch = branch

        # frac-connections paper - num_fracs > 1 will be the `m` in their paper https://arxiv.org/abs/2503.14125

        assert num_fracs >= 1

        self.num_fracs = num_fracs
        self.has_fracs = num_fracs > 1

        self.split_fracs = Rearrange('b ... (f d) -> b ... f d', f = num_fracs)
        self.merge_fracs = Rearrange('b ... f d -> b ... (f d)')

        assert divisible_by(dim, num_fracs), f'feature dimension ({dim}) must be divisible by the `num_fracs` ({num_fracs})'

        dim //= num_fracs # effective dim handled in dimension is feature dimension divided by num fractions

        # they used layernorm in paper, but rmsnorm is fine given what we know now

        self.norm = RMSNorm(dim)

        assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'

        self.num_residual_streams = num_residual_streams
        init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given

        # handle the parameter dimensions, which may require (num_residuals x num_fractions) - generalizing hyper + frac connections

        num_residual_streams_fracs = num_residual_streams * num_fracs
        num_input_views_fracs = num_input_views * num_fracs

        # width num residual streams

        assert num_input_views >= 1
        self.num_input_views = num_input_views

        # number of dynamic alpha proposals, for averaging Hres across proposals

        assert num_dynamic_alpha_proposals >= 1
        self.has_dynamic_alpha_proposals = num_dynamic_alpha_proposals > 1
        self.num_dynamic_alpha_proposals = num_dynamic_alpha_proposals

        # width connection
        # the mixing columns are laid out as [branch input views * fracs | residual streams * fracs]

        init_alpha0 = torch.zeros((num_residual_streams_fracs, num_input_views_fracs))
        init_alpha0[init_residual_index, :] = 1.

        self.static_alpha = nn.Parameter(cat((init_alpha0, torch.eye(num_residual_streams_fracs)), dim = 1))

        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(num_dynamic_alpha_proposals, dim, num_residual_streams_fracs + num_input_views_fracs))

        self.pre_branch_scale = nn.Parameter(torch.ones(1) * 1e-2)
        self.residual_scale = nn.Parameter(torch.ones(1) * 1e-2)

        # depth connection related (beta)

        self.add_branch_out_to_residual = add_branch_out_to_residual

        if add_branch_out_to_residual:
            self.static_beta = nn.Parameter(torch.ones(num_residual_streams, num_fracs, 1))

            self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim, num_fracs))

            self.h_post_scale = nn.Parameter(torch.ones(()) * 1e-2)

        # Hres constraint related
        # by default is sinkhorn

        self.residual_mix_constraint_fn = default(
            residual_mix_constraint_fn,
            partial(sinkhorn_knopps if not log_domain_sinkhorn else log_domain_sinkhorn_knopps, iters = sinkhorn_iters)
        )

        # dropouts

        self.dropout = nn.Dropout(dropout)

        # maybe residual transform

        self.residual_transform = default(residual_transform, nn.Identity())

        # maybe custom depth connection residual function
        # this is to prepare for gating the addition of the branch outputs to the residual streams
        # needed for memory lanes a la RMT / LMM

        self.depth_residual_fn = depth_residual_fn

        # forwarding method names

        self.forward_method_names = forward_method_names

        for forward_method_name in self.forward_method_names:
            assert not hasattr(self, forward_method_name)

            fn = getattr(self.branch, forward_method_name)
            setattr(self, forward_method_name, fn)

    def width_connection(
        self,
        residuals
    ):
        """
        Returns `(branch_input, mixed_residuals, dict(beta = beta))`.
        `beta` is None when `add_branch_out_to_residual` is False.
        """
        streams, fracs, views = self.num_residual_streams, self.num_fracs, self.num_input_views

        residuals = self.residual_transform(residuals)

        # width connection

        # split out fractions - (b, ..., s, f, d)

        residuals = self.split_fracs(residuals)

        # norm

        normed = self.norm(residuals)

        # alpha for weighted sum of residuals going into branch

        dtype = residuals.dtype

        normed = normed.float()

        wc_weight = einsum(normed, self.dynamic_alpha_fn.float(), '... d, p d mix -> p ... mix')
        wc_weight = rearrange(wc_weight, '... s1 f2 mix -> ... (s1 f2) mix')

        # the learned scales must follow the same [views * fracs | streams * fracs] column layout
        # as `static_alpha` / `dynamic_alpha_fn` - the previous `(v n)` repetition of
        # [fracs | streams * fracs] had the wrong length whenever num_input_views > 1

        pre_branch_scale = repeat(self.pre_branch_scale.float(), '1 -> n', n = views * fracs)
        residual_scale = repeat(self.residual_scale.float(), '1 -> n', n = streams * fracs)
        alpha_scale = cat((pre_branch_scale, residual_scale))

        dynamic_alpha = wc_weight * alpha_scale

        alpha = dynamic_alpha + self.static_alpha.float()

        # the alpha is now split and "manifold constrained" with sinkhorn and sigmoid

        alpha_pre, alpha_residual = alpha[..., :views * fracs], alpha[..., views * fracs:]

        alpha_pre = alpha_pre.sigmoid()

        alpha_residual = self.residual_mix_constraint_fn(alpha_residual)

        alpha = cat((alpha_pre, alpha_residual), dim = -1)

        if self.has_dynamic_alpha_proposals:
            alpha = reduce(alpha, 'p ... -> ...', 'mean')
        else:
            alpha = rearrange(alpha, '1 ... -> ...')

        alpha = rearrange(alpha, '... (s f) t -> ... s f t', s = streams) # (batch, ..., streams, fracs, views * fracs + streams * fracs)

        # beta for weights from branch output back to residual streams

        beta = None

        if self.add_branch_out_to_residual:
            dc_weight = normed @ self.dynamic_beta_fn.float()

            dynamic_beta = dc_weight * self.h_post_scale.float()

            beta = dynamic_beta + self.static_beta.float()

            beta = beta.sigmoid() * 2 # for "H_post" manifold constraint

        mix_h = einsum(alpha, residuals.float(), '... s f tf, ... s f d -> ... tf d')

        mix_h = rearrange(mix_h, '... (t f) d -> ... t f d', f = fracs)

        if views == 1:
            branch_input, residuals = mix_h[..., 0, :, :], mix_h[..., 1:, :, :]
        else:
            branch_input, residuals = mix_h[..., :views, :, :], mix_h[..., views:, :, :]
            branch_input = rearrange(branch_input, 'b ... v f d -> v b ... f d')

        # maybe merge fractions back

        branch_input = self.merge_fracs(branch_input)

        residuals = rearrange(residuals, 'b ... s f d -> b ... s (f d)')

        branch_input, residuals = tuple(t.to(dtype) for t in (branch_input, residuals))

        if exists(beta):
            beta = beta.to(dtype)

        return branch_input, residuals, dict(beta = beta)

    def depth_connection(
        self,
        branch_output,
        residuals,
        *,
        beta
    ):
        """Distribute the branch output to every stream (weighted by `beta`) and combine it with the mixed streams."""
        assert self.add_branch_out_to_residual

        # maybe split fractions

        branch_output = self.split_fracs(branch_output)

        # 'depth' connection

        dtype = residuals.dtype

        output = einsum(branch_output.float(), beta.float(), 'b ... f1 d, b ... s f1 f2 -> b ... s f2 d')

        # merge back fractions

        output = self.merge_fracs(output)

        # combine with the (already width-mixed) residual streams

        residuals = self.depth_residual_fn(output.to(dtype), residuals)

        return self.dropout(residuals)

    def decorate_branch(
        self,
        branch: Callable
    ):
        assert not exists(self.branch), 'branch was already wrapped on init'

        def forward_and_add_residual(residual, *args, **kwargs):
            branch_input, add_residual = self.forward(residual)

            branch_output = branch(branch_input, *args, **kwargs)

            residual = add_residual(branch_output)

            return residual

        return forward_and_add_residual

    def forward(
        self,
        residuals,
        *branch_args,
        **branch_kwargs
    ):
        """
        With a branch given on init, returns the updated residual streams.
        Otherwise returns `(branch_input, add_residual_fn)` for the caller to run the branch.
        """

        branch_input, residuals, residual_kwargs = self.width_connection(residuals)

        def add_residual_fn(branch_out):

            if not self.add_branch_out_to_residual:
                return branch_out

            # only the first leaf of a (possibly nested) branch output goes through the depth connection
            (branch_out, *rest), tree_spec = tree_flatten(branch_out)

            branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)

            return tree_unflatten((branch_out, *rest), tree_spec)

        if not exists(self.branch):
            return branch_input, add_residual_fn

        branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)

        return add_residual_fn(branch_output)
|
|
500
|
+
|
|
501
|
+
# expose the stream factory helpers on the class as static methods, then the short alias

ManifoldConstrainedHyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
ManifoldConstrainedHyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)

mHC = ManifoldConstrainedHyperConnections
|
|
505
|
+
|
|
506
|
+
# stream embed
|
|
507
|
+
|
|
508
|
+
class StreamEmbed(Module):
    """
    Adds a learned per-stream embedding (zero-initialized, shape `(num_streams, dim)`).

    With `expand_to_streams`, a single residual `... d` is first repeated into
    `... s d`; otherwise the input must already broadcast against the
    `(num_streams, dim)` embedding.
    """

    def __init__(
        self,
        num_streams,
        dim,
        expand_to_streams = False
    ):
        super().__init__()
        self.num_streams = num_streams
        self.expand_to_streams = expand_to_streams
        self.stream_embed = nn.Parameter(torch.zeros(num_streams, dim))

    def forward(self, residuals):
        streams = (
            repeat(residuals, '... d -> ... s d', s = self.num_streams)
            if self.expand_to_streams
            else residuals
        )
        return streams + self.stream_embed
|
|
527
|
+
|
|
528
|
+
# attention pool - taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
|
|
529
|
+
|
|
530
|
+
class AttentionPoolReduceStream(Module):
    """
    Collapses the stream dimension (`... s d -> ... d`) with an attention-weighted sum.

    Per-feature logits come from a bias-free linear layer, softmaxed across the
    streams (dim -2), and used to weight the residual streams.
    """

    def __init__(self, dim):
        super().__init__()
        self.to_attn_logits = nn.Linear(dim, dim, bias = False)
        # identity init - the logits initially equal the residuals themselves
        self.to_attn_logits.weight.data.copy_(torch.eye(dim))

    def forward(self, residuals):
        weights = self.to_attn_logits(residuals).softmax(dim = -2)
        return einsum(residuals, weights, '... s d, ... s d -> ... d')
|
|
@@ -78,7 +78,7 @@ def get_expand_reduce_stream_functions(
|
|
|
78
78
|
dim = None,
|
|
79
79
|
disable = False
|
|
80
80
|
):
|
|
81
|
-
if
|
|
81
|
+
if disable:
|
|
82
82
|
return (nn.Identity(), nn.Identity())
|
|
83
83
|
|
|
84
84
|
if add_stream_embed:
|
|
@@ -213,7 +213,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
213
213
|
num_fracs = 1, # https://arxiv.org/abs/2503.14125
|
|
214
214
|
sinkhorn_iters = 20,
|
|
215
215
|
log_domain_sinkhorn = False,
|
|
216
|
-
|
|
216
|
+
residual_mix_constraint_fn: Callable | None = None,
|
|
217
217
|
forward_method_names: tuple[str, ...] = (),
|
|
218
218
|
num_dynamic_alpha_proposals = 1,
|
|
219
219
|
|
|
@@ -292,8 +292,8 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
292
292
|
# Hres constraint related
|
|
293
293
|
# by default is sinkhorn
|
|
294
294
|
|
|
295
|
-
self.
|
|
296
|
-
|
|
295
|
+
self.residual_mix_constraint_fn = default(
|
|
296
|
+
residual_mix_constraint_fn,
|
|
297
297
|
partial(sinkhorn_knopps if not log_domain_sinkhorn else log_domain_sinkhorn_knopps, iters = sinkhorn_iters)
|
|
298
298
|
)
|
|
299
299
|
|
|
@@ -378,7 +378,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
378
378
|
|
|
379
379
|
alpha_pre = alpha_pre.sigmoid()
|
|
380
380
|
|
|
381
|
-
alpha_residual = self.
|
|
381
|
+
alpha_residual = self.residual_mix_constraint_fn(alpha_residual)
|
|
382
382
|
|
|
383
383
|
alpha = cat((alpha_pre, alpha_residual), dim = -1)
|
|
384
384
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "hyper-connections"
|
|
3
|
-
version = "0.4.
|
|
3
|
+
version = "0.4.3"
|
|
4
4
|
description = "Hyper-Connections"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
|
@@ -23,8 +23,8 @@ classifiers=[
|
|
|
23
23
|
]
|
|
24
24
|
|
|
25
25
|
dependencies = [
|
|
26
|
-
"einops>=0.8.
|
|
27
|
-
"torch>=2.
|
|
26
|
+
"einops>=0.8.1",
|
|
27
|
+
"torch>=2.5",
|
|
28
28
|
]
|
|
29
29
|
|
|
30
30
|
[project.urls]
|
|
@@ -234,3 +234,49 @@ def test_mhc_vit(
|
|
|
234
234
|
|
|
235
235
|
preds = v(img) # (1, 1000)
|
|
236
236
|
assert preds.shape == (1, 1000)
|
|
237
|
+
|
|
238
|
+
@param('num_fracs', (1, 2))
@param('num_streams', (1, 3))
@param('disable', (False, True))
@param('add_attn_pool_reduce_stream', (False, True))
def test_mhcv2(
    num_fracs,
    num_streams,
    disable,
    add_attn_pool_reduce_stream
):
    import torch
    from torch import nn

    # a single branch layer, first used as an ordinary residual block

    branch = nn.Linear(512, 512)

    residual = torch.randn(2, 1024, 512)
    residual = branch(residual) + residual

    # now the same branch wrapped with mHC v2 over `num_streams` residual streams

    from hyper_connections.mHCv2 import get_init_and_expand_reduce_stream_functions

    init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(
        num_streams,
        dim = 512,
        num_fracs = num_fracs,
        disable = disable,
        add_attn_pool_reduce_stream = add_attn_pool_reduce_stream
    )

    # 1. wrap the branch
    hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)

    # 2. expand into `num_streams` streams - done once, before the trunk
    # 3. run the wrapped branch(es) as usual
    # 4. reduce the streams back into one residual - done once, after the trunk
    residual = reduce_stream(hyper_conn_branch(expand_stream(residual)))

    assert residual.shape == (2, 1024, 512)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|