hyper-connections 0.0.24__tar.gz → 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections-0.1.1/.github/workflows/test.yml +19 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.1}/PKG-INFO +1 -4
- {hyper_connections-0.0.24 → hyper_connections-0.1.1}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -3
- {hyper_connections-0.0.24 → hyper_connections-0.1.1}/hyper_connections/hyper_connections_with_multi_input_streams.py +3 -7
- {hyper_connections-0.0.24 → hyper_connections-0.1.1}/pyproject.toml +3 -9
- hyper_connections-0.1.1/tests/test_hyper_connections.py +138 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.1}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.1}/.gitignore +0 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.1}/LICENSE +0 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.1}/README.md +0 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.1}/hyper-connections.png +0 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.1}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.1}/hyper_connections/hyper_connections.py +0 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
name: Tests the examples in README
|
|
2
|
+
on: push
|
|
3
|
+
|
|
4
|
+
jobs:
|
|
5
|
+
test:
|
|
6
|
+
runs-on: ubuntu-latest
|
|
7
|
+
steps:
|
|
8
|
+
- uses: actions/checkout@v4
|
|
9
|
+
- name: Install Python
|
|
10
|
+
uses: actions/setup-python@v4
|
|
11
|
+
- name: Install the latest version of rye
|
|
12
|
+
uses: eifinger/setup-rye@v2
|
|
13
|
+
- name: Use UV instead of pip
|
|
14
|
+
run: rye config --set-bool behavior.use-uv=true
|
|
15
|
+
- name: Install dependencies
|
|
16
|
+
run: |
|
|
17
|
+
rye sync
|
|
18
|
+
- name: Run pytest
|
|
19
|
+
run: rye run pytest tests/
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.0.24
|
|
3
|
+
Version: 0.1.1
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -34,12 +34,9 @@ Classifier: License :: OSI Approved :: MIT License
|
|
|
34
34
|
Classifier: Programming Language :: Python :: 3.9
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
36
|
Requires-Python: >=3.9
|
|
37
|
-
Requires-Dist: beartype
|
|
38
37
|
Requires-Dist: einops>=0.8.0
|
|
39
38
|
Requires-Dist: torch>=2.3
|
|
40
39
|
Provides-Extra: examples
|
|
41
|
-
Provides-Extra: test
|
|
42
|
-
Requires-Dist: pytest; extra == 'test'
|
|
43
40
|
Description-Content-Type: text/markdown
|
|
44
41
|
|
|
45
42
|
<img src="./hyper-connections.png" width="450px"></img>
|
|
@@ -12,8 +12,6 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
12
12
|
|
|
13
13
|
from einops import rearrange, repeat, reduce, einsum
|
|
14
14
|
|
|
15
|
-
from beartype import beartype
|
|
16
|
-
|
|
17
15
|
"""
|
|
18
16
|
ein notation:
|
|
19
17
|
b - batch
|
|
@@ -65,7 +63,6 @@ def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = Fals
|
|
|
65
63
|
# hyper connection residual streams
|
|
66
64
|
|
|
67
65
|
class HyperConnections(Module):
|
|
68
|
-
@beartype
|
|
69
66
|
def __init__(
|
|
70
67
|
self,
|
|
71
68
|
num_residual_streams,
|
|
@@ -13,8 +13,6 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
13
13
|
from einops import rearrange, repeat, reduce, einsum
|
|
14
14
|
from einops.layers.torch import Rearrange
|
|
15
15
|
|
|
16
|
-
from beartype import beartype
|
|
17
|
-
|
|
18
16
|
"""
|
|
19
17
|
ein notation:
|
|
20
18
|
b - batch
|
|
@@ -95,7 +93,6 @@ class ProjActScale(Module):
|
|
|
95
93
|
# residual base class
|
|
96
94
|
|
|
97
95
|
class Residual(Module):
|
|
98
|
-
@beartype
|
|
99
96
|
def __init__(
|
|
100
97
|
self,
|
|
101
98
|
*args,
|
|
@@ -145,10 +142,9 @@ class Residual(Module):
|
|
|
145
142
|
|
|
146
143
|
# hyper connection with multiple input streams
|
|
147
144
|
|
|
148
|
-
InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int`
|
|
145
|
+
InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int`] and `str` points to **kwargs[`str`] - `int` needs to be > 0, as 0 is the default input residual stream
|
|
149
146
|
|
|
150
147
|
class HyperConnections(Module):
|
|
151
|
-
@beartype
|
|
152
148
|
def __init__(
|
|
153
149
|
self,
|
|
154
150
|
num_residual_streams,
|
|
@@ -185,7 +181,7 @@ class HyperConnections(Module):
|
|
|
185
181
|
init_alpha0 = torch.zeros((num_residual_streams, 1))
|
|
186
182
|
init_alpha0[init_residual_index, 0] = 1.
|
|
187
183
|
|
|
188
|
-
self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1)
|
|
184
|
+
self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1, activation = act)
|
|
189
185
|
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
|
|
190
186
|
|
|
191
187
|
self.dynamic_beta = ProjActScale(dim, 1, activation = act, squeeze_output = True)
|
|
@@ -200,7 +196,7 @@ class HyperConnections(Module):
|
|
|
200
196
|
|
|
201
197
|
self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
|
|
202
198
|
self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _ , dim in additional_input_paths])
|
|
203
|
-
self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0])])
|
|
199
|
+
self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0]) for _ in additional_input_paths])
|
|
204
200
|
|
|
205
201
|
self.additional_input_paths = additional_input_paths
|
|
206
202
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "hyper-connections"
|
|
3
|
-
version = "0.0.24"
|
|
3
|
+
version = "0.1.1"
|
|
4
4
|
description = "Hyper-Connections"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
|
@@ -23,7 +23,6 @@ classifiers=[
|
|
|
23
23
|
]
|
|
24
24
|
|
|
25
25
|
dependencies = [
|
|
26
|
-
"beartype",
|
|
27
26
|
"einops>=0.8.0",
|
|
28
27
|
"torch>=2.3",
|
|
29
28
|
]
|
|
@@ -34,14 +33,9 @@ Repository = "https://github.com/lucidrains/hyper-connections"
|
|
|
34
33
|
|
|
35
34
|
[project.optional-dependencies]
|
|
36
35
|
examples = []
|
|
37
|
-
test = [
|
|
38
|
-
"pytest"
|
|
39
|
-
]
|
|
40
36
|
|
|
41
37
|
[tool.pytest.ini_options]
|
|
42
|
-
pythonpath = [
|
|
43
|
-
"."
|
|
44
|
-
]
|
|
38
|
+
pythonpath = ["."]
|
|
45
39
|
|
|
46
40
|
[build-system]
|
|
47
41
|
requires = ["hatchling"]
|
|
@@ -49,7 +43,7 @@ build-backend = "hatchling.build"
|
|
|
49
43
|
|
|
50
44
|
[tool.rye]
|
|
51
45
|
managed = true
|
|
52
|
-
dev-dependencies = []
|
|
46
|
+
dev-dependencies = ["pytest>=8.2.0"]
|
|
53
47
|
|
|
54
48
|
[tool.hatch.metadata]
|
|
55
49
|
allow-direct-references = true
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
|
|
6
|
+
@pytest.mark.parametrize('disable', (False, True))
|
|
7
|
+
def test_readme(disable):
|
|
8
|
+
|
|
9
|
+
# a single branch layer
|
|
10
|
+
|
|
11
|
+
branch = nn.Linear(512, 512)
|
|
12
|
+
|
|
13
|
+
# before
|
|
14
|
+
|
|
15
|
+
residual = torch.randn(2, 1024, 512)
|
|
16
|
+
|
|
17
|
+
residual = branch(residual) + residual
|
|
18
|
+
|
|
19
|
+
# after, say 4 streams in paper
|
|
20
|
+
|
|
21
|
+
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
22
|
+
|
|
23
|
+
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, disable = disable)
|
|
24
|
+
|
|
25
|
+
# 1. wrap your branch function
|
|
26
|
+
|
|
27
|
+
hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)
|
|
28
|
+
|
|
29
|
+
# 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
|
|
30
|
+
|
|
31
|
+
residual = expand_stream(residual)
|
|
32
|
+
|
|
33
|
+
# 3. forward your residual as usual into the wrapped branch function(s)
|
|
34
|
+
|
|
35
|
+
residual = hyper_conn_branch(residual)
|
|
36
|
+
|
|
37
|
+
# 4. reduce 4 streams with a summation, this has to be done after your for-loop trunk. for transformer, unsure whether to do before or after final norm
|
|
38
|
+
|
|
39
|
+
residual = reduce_stream(residual)
|
|
40
|
+
|
|
41
|
+
assert residual.shape == (2, 1024, 512)
|
|
42
|
+
|
|
43
|
+
def test_manual():
|
|
44
|
+
# a single branch layer
|
|
45
|
+
|
|
46
|
+
branch = nn.Linear(512, 512)
|
|
47
|
+
|
|
48
|
+
# before
|
|
49
|
+
|
|
50
|
+
residual = torch.randn(2, 1024, 512)
|
|
51
|
+
|
|
52
|
+
residual = branch(residual) + residual
|
|
53
|
+
|
|
54
|
+
# after, say 4 streams in paper
|
|
55
|
+
|
|
56
|
+
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
57
|
+
|
|
58
|
+
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
|
|
59
|
+
|
|
60
|
+
# 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
|
|
61
|
+
|
|
62
|
+
hyper_conn = init_hyper_conn(dim = 512)
|
|
63
|
+
|
|
64
|
+
# 2. expand to 4 streams
|
|
65
|
+
|
|
66
|
+
residual = expand_stream(residual)
|
|
67
|
+
|
|
68
|
+
# 3. forward your residual into hyper connection for the branch input + add residual function (learned betas)
|
|
69
|
+
|
|
70
|
+
branch_input, add_residual = hyper_conn(residual)
|
|
71
|
+
|
|
72
|
+
branch_output = branch(branch_input)
|
|
73
|
+
|
|
74
|
+
residual = add_residual(branch_output)
|
|
75
|
+
|
|
76
|
+
# or you can do it in one line as so -> residual = hyper_conn.decorate_branch(branch)(residual)
|
|
77
|
+
|
|
78
|
+
# 4. reduce 4 streams with a summation, this has to be done after your for loop trunk
|
|
79
|
+
|
|
80
|
+
residual = reduce_stream(residual)
|
|
81
|
+
assert residual.shape == (2, 1024, 512)
|
|
82
|
+
|
|
83
|
+
@pytest.mark.parametrize('disable', (False, True))
|
|
84
|
+
def test_multi_input_hyper_connections(disable):
|
|
85
|
+
|
|
86
|
+
# two branch layers
|
|
87
|
+
|
|
88
|
+
class CustomModule(nn.Module):
|
|
89
|
+
def __init__(self):
|
|
90
|
+
super().__init__()
|
|
91
|
+
self.linear = nn.Linear(512, 512)
|
|
92
|
+
self.second_linear = nn.Linear(256, 512)
|
|
93
|
+
self.third_linear = nn.Linear(128, 512)
|
|
94
|
+
|
|
95
|
+
def forward(self, x, second, *, third):
|
|
96
|
+
return self.linear(x) + self.second_linear(second) + self.third_linear(third), 3.
|
|
97
|
+
|
|
98
|
+
branch = CustomModule()
|
|
99
|
+
|
|
100
|
+
# before
|
|
101
|
+
|
|
102
|
+
residual = torch.randn(3, 1024, 512)
|
|
103
|
+
second_residual = torch.randn(3, 1024, 256)
|
|
104
|
+
third_residual = torch.randn(3, 1024, 128)
|
|
105
|
+
|
|
106
|
+
# residual = branch1(residual) + branch2(residual) + residual
|
|
107
|
+
|
|
108
|
+
# after, say 4 streams in paper
|
|
109
|
+
|
|
110
|
+
from hyper_connections.hyper_connections_with_multi_input_streams import HyperConnections
|
|
111
|
+
|
|
112
|
+
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = disable)
|
|
113
|
+
|
|
114
|
+
# 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
|
|
115
|
+
|
|
116
|
+
hyper_conn = init_hyper_conn(
|
|
117
|
+
dim = 512,
|
|
118
|
+
branch = branch,
|
|
119
|
+
additional_input_paths = [
|
|
120
|
+
(1, 256), # points at second residual stream, first arg
|
|
121
|
+
('third', 128) # points at third residual stream, keyword argument 'third'
|
|
122
|
+
],
|
|
123
|
+
layer_index = 1,
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
# 2. expand to 4 streams
|
|
127
|
+
|
|
128
|
+
residual = expand_stream(residual)
|
|
129
|
+
second_residual = expand_stream(second_residual)
|
|
130
|
+
third_residual = expand_stream(third_residual)
|
|
131
|
+
|
|
132
|
+
# 3. forward your residual into hyper connection for the branch input + add residual function (learned betas)
|
|
133
|
+
|
|
134
|
+
residual, rest_output = hyper_conn(residual, second_residual, third = third_residual)
|
|
135
|
+
|
|
136
|
+
residual = reduce_stream(residual)
|
|
137
|
+
|
|
138
|
+
assert residual.shape == (3, 1024, 512)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|