hyper-connections 0.3.16__tar.gz → 0.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/PKG-INFO +1 -1
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/manifold_constrained_hyper_connections.py +23 -8
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/vit.py +4 -4
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/pyproject.toml +1 -1
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/tests/test_hyper_connections.py +11 -5
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/.github/workflows/test.yml +0 -0
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/.gitignore +0 -0
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/LICENSE +0 -0
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/README.md +0 -0
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper-connections.png +0 -0
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/hyper_connections.py +0 -0
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/hyper_connections_channel_first.py +0 -0
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
- {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/residuals.py +0 -0
|
@@ -21,6 +21,7 @@ s - residual streams
|
|
|
21
21
|
t - residual streams + num branch inputs
|
|
22
22
|
f - number of fractions (division of feature dimension space)
|
|
23
23
|
v - number of views for branch input
|
|
24
|
+
p - proposals
|
|
24
25
|
"""
|
|
25
26
|
|
|
26
27
|
# helper functions
|
|
@@ -212,7 +213,10 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
212
213
|
num_fracs = 1, # https://arxiv.org/abs/2503.14125
|
|
213
214
|
sinkhorn_iters = 20,
|
|
214
215
|
log_domain_sinkhorn = False,
|
|
216
|
+
residual_constraint_fn: Callable | None = None,
|
|
215
217
|
forward_method_names: tuple[str, ...] = (),
|
|
218
|
+
num_dynamic_alpha_proposals = 1,
|
|
219
|
+
|
|
216
220
|
):
|
|
217
221
|
"""
|
|
218
222
|
Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
|
|
@@ -256,6 +260,11 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
256
260
|
assert num_input_views >= 1
|
|
257
261
|
self.num_input_views = num_input_views
|
|
258
262
|
|
|
263
|
+
# number of dynamic alpha proposals, for averaging Hres across proposals
|
|
264
|
+
|
|
265
|
+
self.has_dynamic_alpha_proposals = num_dynamic_alpha_proposals > 1
|
|
266
|
+
self.num_dynamic_alpha_proposals = num_dynamic_alpha_proposals
|
|
267
|
+
|
|
259
268
|
# width connection
|
|
260
269
|
|
|
261
270
|
init_alpha0 = torch.zeros((num_residual_streams_fracs, num_input_views_fracs))
|
|
@@ -263,7 +272,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
263
272
|
|
|
264
273
|
self.static_alpha = nn.Parameter(cat((init_alpha0, torch.eye(num_residual_streams_fracs)), dim = 1))
|
|
265
274
|
|
|
266
|
-
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams_fracs + num_input_views_fracs))
|
|
275
|
+
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(num_dynamic_alpha_proposals, dim, num_residual_streams_fracs + num_input_views_fracs))
|
|
267
276
|
|
|
268
277
|
self.pre_branch_scale = nn.Parameter(torch.ones(1) * 1e-2)
|
|
269
278
|
self.residual_scale = nn.Parameter(torch.ones(1) * 1e-2)
|
|
@@ -280,10 +289,13 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
280
289
|
|
|
281
290
|
self.h_post_scale = nn.Parameter(torch.ones(()) * 1e-2)
|
|
282
291
|
|
|
283
|
-
#
|
|
292
|
+
# Hres constraint related
|
|
293
|
+
# by default is sinkhorn
|
|
284
294
|
|
|
285
|
-
self.
|
|
286
|
-
|
|
295
|
+
self.residual_constraint_fn = default(
|
|
296
|
+
residual_constraint_fn,
|
|
297
|
+
partial(sinkhorn_knopps if not log_domain_sinkhorn else log_domain_sinkhorn_knopps, iters = sinkhorn_iters)
|
|
298
|
+
)
|
|
287
299
|
|
|
288
300
|
# dropouts
|
|
289
301
|
|
|
@@ -346,7 +358,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
346
358
|
|
|
347
359
|
normed = normed.float()
|
|
348
360
|
|
|
349
|
-
wc_weight = normed
|
|
361
|
+
wc_weight = einsum(normed, self.dynamic_alpha_fn.float(), '... d, p d e -> p ... e')
|
|
350
362
|
|
|
351
363
|
pre_branch_scale = repeat(self.pre_branch_scale.float(), '1 -> s', s = self.num_fracs)
|
|
352
364
|
residual_scale = repeat(self.residual_scale.float(), '1 -> s', s = self.num_fracs * streams)
|
|
@@ -366,12 +378,15 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
366
378
|
|
|
367
379
|
alpha_pre = alpha_pre.sigmoid()
|
|
368
380
|
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
alpha_residual = sinkhorn_fn(alpha_residual, self.sinkhorn_iters)
|
|
381
|
+
alpha_residual = self.residual_constraint_fn(alpha_residual)
|
|
372
382
|
|
|
373
383
|
alpha = cat((alpha_pre, alpha_residual), dim = -1)
|
|
374
384
|
|
|
385
|
+
if self.has_dynamic_alpha_proposals:
|
|
386
|
+
alpha = reduce(alpha, 'p ... -> ...', 'mean')
|
|
387
|
+
else:
|
|
388
|
+
alpha = rearrange(alpha, '1 ... -> ...')
|
|
389
|
+
|
|
375
390
|
alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
|
|
376
391
|
|
|
377
392
|
# beta for weights from branch output back to residual streams
|
|
@@ -66,12 +66,12 @@ class Attention(Module):
|
|
|
66
66
|
return self.to_out(out)
|
|
67
67
|
|
|
68
68
|
class Transformer(Module):
|
|
69
|
-
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4):
|
|
69
|
+
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4, num_dynamic_alpha_proposals = 1):
|
|
70
70
|
super().__init__()
|
|
71
71
|
self.norm = nn.LayerNorm(dim)
|
|
72
72
|
self.layers = ModuleList([])
|
|
73
73
|
|
|
74
|
-
init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams)
|
|
74
|
+
init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams, num_dynamic_alpha_proposals = num_dynamic_alpha_proposals)
|
|
75
75
|
|
|
76
76
|
for _ in range(depth):
|
|
77
77
|
self.layers.append(ModuleList([
|
|
@@ -92,7 +92,7 @@ class Transformer(Module):
|
|
|
92
92
|
return self.norm(x)
|
|
93
93
|
|
|
94
94
|
class ViT(Module):
|
|
95
|
-
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4):
|
|
95
|
+
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4, num_dynamic_alpha_proposals = 1):
|
|
96
96
|
super().__init__()
|
|
97
97
|
image_height, image_width = pair(image_size)
|
|
98
98
|
patch_height, patch_width = pair(patch_size)
|
|
@@ -117,7 +117,7 @@ class ViT(Module):
|
|
|
117
117
|
|
|
118
118
|
self.dropout = nn.Dropout(emb_dropout)
|
|
119
119
|
|
|
120
|
-
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
|
120
|
+
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_residual_streams, num_dynamic_alpha_proposals)
|
|
121
121
|
|
|
122
122
|
self.pool = pool
|
|
123
123
|
self.to_latent = nn.Identity()
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import pytest
|
|
2
|
+
param = pytest.mark.parametrize
|
|
2
3
|
|
|
3
4
|
import torch
|
|
4
5
|
from torch import nn
|
|
5
6
|
|
|
6
|
-
@
|
|
7
|
-
@
|
|
8
|
-
@
|
|
7
|
+
@param('num_fracs', (1, 4))
|
|
8
|
+
@param('disable', (False, True))
|
|
9
|
+
@param('manifold_constrained', (False, True))
|
|
9
10
|
def test_readme(
|
|
10
11
|
num_fracs,
|
|
11
12
|
disable,
|
|
@@ -208,7 +209,11 @@ def test_mhc_dtype_restoration():
|
|
|
208
209
|
|
|
209
210
|
assert residual.dtype == torch.half
|
|
210
211
|
|
|
211
|
-
|
|
212
|
+
@param('num_dynamic_alpha_proposals', (1, 2))
|
|
213
|
+
def test_mhc_vit(
|
|
214
|
+
num_dynamic_alpha_proposals
|
|
215
|
+
):
|
|
216
|
+
|
|
212
217
|
from hyper_connections.vit import ViT
|
|
213
218
|
|
|
214
219
|
v = ViT(
|
|
@@ -221,7 +226,8 @@ def test_mhc_vit():
|
|
|
221
226
|
mlp_dim = 2048,
|
|
222
227
|
dropout = 0.1,
|
|
223
228
|
emb_dropout = 0.1,
|
|
224
|
-
num_residual_streams = 4
|
|
229
|
+
num_residual_streams = 4,
|
|
230
|
+
num_dynamic_alpha_proposals = num_dynamic_alpha_proposals
|
|
225
231
|
)
|
|
226
232
|
|
|
227
233
|
img = torch.randn(1, 3, 256, 256)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|