hyper-connections 0.4.7__tar.gz → 0.4.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/PKG-INFO +2 -1
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/mHCv2.py +24 -1
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/pyproject.toml +2 -1
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/tests/test_hyper_connections.py +6 -2
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/.github/workflows/test.yml +0 -0
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/.gitignore +0 -0
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/LICENSE +0 -0
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/README.md +0 -0
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper-connections.png +0 -0
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/hyper_connections.py +0 -0
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/hyper_connections_channel_first.py +0 -0
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/manifold_constrained_hyper_connections.py +0 -0
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/residuals.py +0 -0
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/triton_sinkhorn.py +0 -0
- {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/vit.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.4.7
|
|
3
|
+
Version: 0.4.9
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
|
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
36
|
Requires-Python: >=3.9
|
|
37
37
|
Requires-Dist: einops>=0.8.1
|
|
38
|
+
Requires-Dist: torch-einops-utils>=0.0.20
|
|
38
39
|
Requires-Dist: torch>=2.5
|
|
39
40
|
Provides-Extra: examples
|
|
40
41
|
Description-Content-Type: text/markdown
|
|
@@ -13,6 +13,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
13
13
|
from einops import rearrange, repeat, reduce, einsum
|
|
14
14
|
from einops.layers.torch import Rearrange, Reduce
|
|
15
15
|
|
|
16
|
+
from torch_einops_utils import pack_with_inverse
|
|
17
|
+
|
|
16
18
|
"""
|
|
17
19
|
ein notation:
|
|
18
20
|
b - batch
|
|
@@ -241,6 +243,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
241
243
|
forward_method_names: tuple[str, ...] = (),
|
|
242
244
|
num_dynamic_alpha_proposals = 1,
|
|
243
245
|
use_triton_sinkhorn = False,
|
|
246
|
+
mix_streams_before_norm = False, # whether to mix the residual streams before the norm (that then projects to Hpre, Hpost, Hresidual)
|
|
244
247
|
):
|
|
245
248
|
"""
|
|
246
249
|
Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
|
|
@@ -263,6 +266,16 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
263
266
|
|
|
264
267
|
dim //= num_fracs # effective dim handled in dimension is feature dimension divided by num fractions
|
|
265
268
|
|
|
269
|
+
# whether to mix the streams before the norm below
|
|
270
|
+
# this would be equivalent to separable depthwise convs from yesteryears (with a norm in between) - parameter efficient improv
|
|
271
|
+
|
|
272
|
+
self.maybe_mix_streams = None
|
|
273
|
+
|
|
274
|
+
if mix_streams_before_norm:
|
|
275
|
+
self.maybe_mix_streams = nn.Conv2d(num_residual_streams, num_residual_streams, 1, bias = False)
|
|
276
|
+
|
|
277
|
+
nn.init.dirac_(self.maybe_mix_streams.weight)
|
|
278
|
+
|
|
266
279
|
# they used layernorm in paper, but rmsnorm is fine given what we know now
|
|
267
280
|
|
|
268
281
|
self.norm = RMSNorm(dim)
|
|
@@ -370,9 +383,19 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
370
383
|
|
|
371
384
|
residuals = self.split_fracs(residuals)
|
|
372
385
|
|
|
386
|
+
# maybe mix streams
|
|
387
|
+
|
|
388
|
+
norm_input = residuals
|
|
389
|
+
|
|
390
|
+
if exists(self.maybe_mix_streams):
|
|
391
|
+
|
|
392
|
+
norm_input, inverse_pack_lead_dims = pack_with_inverse(norm_input, '* c h w')
|
|
393
|
+
norm_input = self.maybe_mix_streams(norm_input)
|
|
394
|
+
norm_input = inverse_pack_lead_dims(norm_input)
|
|
395
|
+
|
|
373
396
|
# norm
|
|
374
397
|
|
|
375
|
-
normed = self.norm(residuals)
|
|
398
|
+
normed = self.norm(norm_input)
|
|
376
399
|
|
|
377
400
|
# alpha for weighted sum of residuals going into branch
|
|
378
401
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "hyper-connections"
|
|
3
|
-
version = "0.4.7"
|
|
3
|
+
version = "0.4.9"
|
|
4
4
|
description = "Hyper-Connections"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
|
@@ -25,6 +25,7 @@ classifiers=[
|
|
|
25
25
|
dependencies = [
|
|
26
26
|
"einops>=0.8.1",
|
|
27
27
|
"torch>=2.5",
|
|
28
|
+
"torch-einops-utils>=0.0.20"
|
|
28
29
|
]
|
|
29
30
|
|
|
30
31
|
[project.urls]
|
|
@@ -227,7 +227,9 @@ def test_mhc_vit(
|
|
|
227
227
|
dropout = 0.1,
|
|
228
228
|
emb_dropout = 0.1,
|
|
229
229
|
num_residual_streams = 4,
|
|
230
|
-
|
|
230
|
+
mhc_kwargs = dict(
|
|
231
|
+
num_dynamic_alpha_proposals = num_dynamic_alpha_proposals
|
|
232
|
+
)
|
|
231
233
|
)
|
|
232
234
|
|
|
233
235
|
img = torch.randn(1, 3, 256, 256)
|
|
@@ -238,11 +240,13 @@ def test_mhc_vit(
|
|
|
238
240
|
@param('num_fracs', (1, 2))
|
|
239
241
|
@param('num_streams', (1, 3))
|
|
240
242
|
@param('disable', (False, True))
|
|
243
|
+
@param('mix_streams_before_norm', (False, True))
|
|
241
244
|
@param('add_attn_pool_reduce_stream', (False, True))
|
|
242
245
|
def test_mhcv2(
|
|
243
246
|
num_fracs,
|
|
244
247
|
num_streams,
|
|
245
248
|
disable,
|
|
249
|
+
mix_streams_before_norm,
|
|
246
250
|
add_attn_pool_reduce_stream
|
|
247
251
|
):
|
|
248
252
|
import torch
|
|
@@ -261,7 +265,7 @@ def test_mhcv2(
|
|
|
261
265
|
|
|
262
266
|
from hyper_connections.mHCv2 import get_init_and_expand_reduce_stream_functions
|
|
263
267
|
|
|
264
|
-
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(num_streams, dim = 512, num_fracs = num_fracs, disable = disable, add_attn_pool_reduce_stream = add_attn_pool_reduce_stream)
|
|
268
|
+
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(num_streams, dim = 512, num_fracs = num_fracs, mix_streams_before_norm = mix_streams_before_norm, disable = disable, add_attn_pool_reduce_stream = add_attn_pool_reduce_stream)
|
|
265
269
|
|
|
266
270
|
# 1. wrap your branch function
|
|
267
271
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|