hyper-connections 0.3.3__tar.gz → 0.3.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/PKG-INFO +3 -1
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/README.md +2 -0
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/hyper_connections/manifold_constrained_hyper_connections.py +8 -14
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/pyproject.toml +1 -1
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/.github/workflows/test.yml +0 -0
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/.gitignore +0 -0
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/LICENSE +0 -0
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/hyper-connections.png +0 -0
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/hyper_connections/hyper_connections.py +0 -0
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/hyper_connections/hyper_connections_channel_first.py +0 -0
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/hyper_connections/residuals.py +0 -0
- {hyper_connections-0.3.3 → hyper_connections-0.3.5}/tests/test_hyper_connections.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.3.3
|
|
3
|
+
Version: 0.3.5
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -45,6 +45,8 @@ Description-Content-Type: text/markdown
|
|
|
45
45
|
|
|
46
46
|
Attempt to make multiple residual streams, proposed in [Hyper-Connections paper](https://arxiv.org/abs/2409.19606) out of Bytedance AI lab, accessible as an easy to use library, as well as for following any new research in this direction.
|
|
47
47
|
|
|
48
|
+
[Write up on mHC from Subhadip Mitra](https://subhadipmitra.com/blog/2026/deepseek-mhc-manifold-constrained-hyper-connections/)
|
|
49
|
+
|
|
48
50
|
## Install
|
|
49
51
|
|
|
50
52
|
```bash
|
|
@@ -4,6 +4,8 @@
|
|
|
4
4
|
|
|
5
5
|
Attempt to make multiple residual streams, proposed in [Hyper-Connections paper](https://arxiv.org/abs/2409.19606) out of Bytedance AI lab, accessible as an easy to use library, as well as for following any new research in this direction.
|
|
6
6
|
|
|
7
|
+
[Write up on mHC from Subhadip Mitra](https://subhadipmitra.com/blog/2026/deepseek-mhc-manifold-constrained-hyper-connections/)
|
|
8
|
+
|
|
7
9
|
## Install
|
|
8
10
|
|
|
9
11
|
```bash
|
|
@@ -10,7 +10,7 @@ import torch.nn.functional as F
|
|
|
10
10
|
from torch.nn import Module, Sequential
|
|
11
11
|
from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
12
12
|
|
|
13
|
-
from einops import rearrange, repeat, reduce, einsum
|
|
13
|
+
from einops import rearrange, repeat, reduce, einsum
|
|
14
14
|
from einops.layers.torch import Rearrange, Reduce
|
|
15
15
|
|
|
16
16
|
"""
|
|
@@ -40,14 +40,6 @@ def identity(t):
|
|
|
40
40
|
def add(x, y):
|
|
41
41
|
return x + y
|
|
42
42
|
|
|
43
|
-
def pack_one_with_inverse(t, pattern):
|
|
44
|
-
packed, packed_shape = pack([t], pattern)
|
|
45
|
-
|
|
46
|
-
def inverse(out):
|
|
47
|
-
return unpack(out, packed_shape, pattern)[0]
|
|
48
|
-
|
|
49
|
-
return packed, inverse
|
|
50
|
-
|
|
51
43
|
# sinkhorn
|
|
52
44
|
|
|
53
45
|
def l1norm(t, dim):
|
|
@@ -91,13 +83,15 @@ def get_init_and_expand_reduce_stream_functions(
|
|
|
91
83
|
num_fracs = 1,
|
|
92
84
|
dim = None,
|
|
93
85
|
add_stream_embed = False,
|
|
94
|
-
disable = None
|
|
86
|
+
disable = None,
|
|
87
|
+
sinkhorn_iters = 20,
|
|
88
|
+
**kwargs
|
|
95
89
|
):
|
|
96
90
|
disable = default(disable, num_streams == 1 and num_fracs == 1)
|
|
97
91
|
|
|
98
92
|
hyper_conn_klass = ManifoldConstrainedHyperConnections if not disable else Residual
|
|
99
93
|
|
|
100
|
-
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs)
|
|
94
|
+
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs, sinkhorn_iters = sinkhorn_iters, **kwargs)
|
|
101
95
|
expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, add_stream_embed = add_stream_embed, dim = dim, disable = disable)
|
|
102
96
|
|
|
103
97
|
if exists(dim):
|
|
@@ -318,11 +312,11 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
318
312
|
|
|
319
313
|
# norm
|
|
320
314
|
|
|
321
|
-
|
|
315
|
+
normed = rearrange(residuals, 'b ... f s d -> b ... (f s d)')
|
|
322
316
|
|
|
323
|
-
normed = self.norm(
|
|
317
|
+
normed = self.norm(normed)
|
|
324
318
|
|
|
325
|
-
normed =
|
|
319
|
+
normed = rearrange(normed, 'b ... (f s d) -> b ... f s d', f = self.num_fracs, s = streams)
|
|
326
320
|
|
|
327
321
|
# alpha for weighted sum of residuals going into branch
|
|
328
322
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|