hyper-connections 0.0.23__tar.gz → 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections-0.1.0/.github/workflows/test.yml +19 -0
- {hyper_connections-0.0.23 → hyper_connections-0.1.0}/PKG-INFO +1 -1
- {hyper_connections-0.0.23 → hyper_connections-0.1.0}/hyper_connections/hyper_connections_with_multi_input_streams.py +9 -5
- {hyper_connections-0.0.23 → hyper_connections-0.1.0}/pyproject.toml +1 -1
- hyper_connections-0.1.0/tests/test_hyper_connections.py +138 -0
- {hyper_connections-0.0.23 → hyper_connections-0.1.0}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.23 → hyper_connections-0.1.0}/.gitignore +0 -0
- {hyper_connections-0.0.23 → hyper_connections-0.1.0}/LICENSE +0 -0
- {hyper_connections-0.0.23 → hyper_connections-0.1.0}/README.md +0 -0
- {hyper_connections-0.0.23 → hyper_connections-0.1.0}/hyper-connections.png +0 -0
- {hyper_connections-0.0.23 → hyper_connections-0.1.0}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.0.23 → hyper_connections-0.1.0}/hyper_connections/hyper_connections.py +0 -0
- {hyper_connections-0.0.23 → hyper_connections-0.1.0}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
name: Tests the examples in README
|
|
2
|
+
on: push
|
|
3
|
+
|
|
4
|
+
jobs:
|
|
5
|
+
test:
|
|
6
|
+
runs-on: ubuntu-latest
|
|
7
|
+
steps:
|
|
8
|
+
- uses: actions/checkout@v4
|
|
9
|
+
- name: Install Python
|
|
10
|
+
uses: actions/setup-python@v4
|
|
11
|
+
- name: Install the latest version of rye
|
|
12
|
+
uses: eifinger/setup-rye@v2
|
|
13
|
+
- name: Use UV instead of pip
|
|
14
|
+
run: rye config --set-bool behavior.use-uv=true
|
|
15
|
+
- name: Install dependencies
|
|
16
|
+
run: |
|
|
17
|
+
rye sync
|
|
18
|
+
- name: Run pytest
|
|
19
|
+
run: rye run pytest tests/
|
|
@@ -145,7 +145,7 @@ class Residual(Module):
|
|
|
145
145
|
|
|
146
146
|
# hyper connection with multiple input streams
|
|
147
147
|
|
|
148
|
-
InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int`
|
|
148
|
+
InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int`] and `str` points to **kwargs[`str`] - `int` needs to be > 0, as 0 is the default input residual stream
|
|
149
149
|
|
|
150
150
|
class HyperConnections(Module):
|
|
151
151
|
@beartype
|
|
@@ -185,7 +185,7 @@ class HyperConnections(Module):
|
|
|
185
185
|
init_alpha0 = torch.zeros((num_residual_streams, 1))
|
|
186
186
|
init_alpha0[init_residual_index, 0] = 1.
|
|
187
187
|
|
|
188
|
-
self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1)
|
|
188
|
+
self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1, activation = act)
|
|
189
189
|
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
|
|
190
190
|
|
|
191
191
|
self.dynamic_beta = ProjActScale(dim, 1, activation = act, squeeze_output = True)
|
|
@@ -196,9 +196,11 @@ class HyperConnections(Module):
|
|
|
196
196
|
additional_input_paths = default(additional_input_paths, [])
|
|
197
197
|
additional_input_paths = [one_path if isinstance(one_path, tuple) else (one_path, dim) for one_path in additional_input_paths]
|
|
198
198
|
|
|
199
|
+
assert all([isinstance(path, str) or path > 0 for (path, _) in additional_input_paths])
|
|
200
|
+
|
|
199
201
|
self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
|
|
200
202
|
self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _ , dim in additional_input_paths])
|
|
201
|
-
self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0])])
|
|
203
|
+
self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0]) for _ in additional_input_paths])
|
|
202
204
|
|
|
203
205
|
self.additional_input_paths = additional_input_paths
|
|
204
206
|
|
|
@@ -247,12 +249,14 @@ class HyperConnections(Module):
|
|
|
247
249
|
|
|
248
250
|
# take care of additional inputs
|
|
249
251
|
|
|
252
|
+
branch_args = list(branch_args)
|
|
253
|
+
|
|
250
254
|
for (path, *_), norm, proj, learned_static in zip(self.additional_input_paths, self.additional_norms, self.additional_to_dynamic_input, self.additional_static_input):
|
|
251
255
|
|
|
252
256
|
# get the residual streams from additional arguments
|
|
253
257
|
|
|
254
258
|
if isinstance(path, int):
|
|
255
|
-
additional_residuals = branch_args[path]
|
|
259
|
+
additional_residuals = branch_args[path - 1]
|
|
256
260
|
elif isinstance(path, str):
|
|
257
261
|
additional_residuals = branch_kwargs[path]
|
|
258
262
|
|
|
@@ -280,7 +284,7 @@ class HyperConnections(Module):
|
|
|
280
284
|
# set back transformed residual
|
|
281
285
|
|
|
282
286
|
if isinstance(path, int):
|
|
283
|
-
branch_args[path] = additional_residuals
|
|
287
|
+
branch_args[path - 1] = additional_residuals
|
|
284
288
|
elif isinstance(path, str):
|
|
285
289
|
branch_kwargs[path] = additional_residuals
|
|
286
290
|
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
|
|
6
|
+
@pytest.mark.parametrize('disable', (False, True))
|
|
7
|
+
def test_readme(disable):
|
|
8
|
+
|
|
9
|
+
# a single branch layer
|
|
10
|
+
|
|
11
|
+
branch = nn.Linear(512, 512)
|
|
12
|
+
|
|
13
|
+
# before
|
|
14
|
+
|
|
15
|
+
residual = torch.randn(2, 1024, 512)
|
|
16
|
+
|
|
17
|
+
residual = branch(residual) + residual
|
|
18
|
+
|
|
19
|
+
# after, say 4 streams in paper
|
|
20
|
+
|
|
21
|
+
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
22
|
+
|
|
23
|
+
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, disable = disable)
|
|
24
|
+
|
|
25
|
+
# 1. wrap your branch function
|
|
26
|
+
|
|
27
|
+
hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)
|
|
28
|
+
|
|
29
|
+
# 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
|
|
30
|
+
|
|
31
|
+
residual = expand_stream(residual)
|
|
32
|
+
|
|
33
|
+
# 3. forward your residual as usual into the wrapped branch function(s)
|
|
34
|
+
|
|
35
|
+
residual = hyper_conn_branch(residual)
|
|
36
|
+
|
|
37
|
+
# 4. reduce 4 streams with a summation, this has to be done after your for-loop trunk. for transformer, unsure whether to do before or after final norm
|
|
38
|
+
|
|
39
|
+
residual = reduce_stream(residual)
|
|
40
|
+
|
|
41
|
+
assert residual.shape == (2, 1024, 512)
|
|
42
|
+
|
|
43
|
+
def test_manual():
|
|
44
|
+
# a single branch layer
|
|
45
|
+
|
|
46
|
+
branch = nn.Linear(512, 512)
|
|
47
|
+
|
|
48
|
+
# before
|
|
49
|
+
|
|
50
|
+
residual = torch.randn(2, 1024, 512)
|
|
51
|
+
|
|
52
|
+
residual = branch(residual) + residual
|
|
53
|
+
|
|
54
|
+
# after, say 4 streams in paper
|
|
55
|
+
|
|
56
|
+
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
57
|
+
|
|
58
|
+
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
|
|
59
|
+
|
|
60
|
+
# 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
|
|
61
|
+
|
|
62
|
+
hyper_conn = init_hyper_conn(dim = 512)
|
|
63
|
+
|
|
64
|
+
# 2. expand to 4 streams
|
|
65
|
+
|
|
66
|
+
residual = expand_stream(residual)
|
|
67
|
+
|
|
68
|
+
# 3. forward your residual into hyper connection for the branch input + add residual function (learned betas)
|
|
69
|
+
|
|
70
|
+
branch_input, add_residual = hyper_conn(residual)
|
|
71
|
+
|
|
72
|
+
branch_output = branch(branch_input)
|
|
73
|
+
|
|
74
|
+
residual = add_residual(branch_output)
|
|
75
|
+
|
|
76
|
+
# or you can do it in one line as so -> residual = hyper_conn.decorate_branch(branch)(residual)
|
|
77
|
+
|
|
78
|
+
# 4. reduce 4 streams with a summation, this has to be done after your for loop trunk
|
|
79
|
+
|
|
80
|
+
residual = reduce_stream(residual)
|
|
81
|
+
assert residual.shape == (2, 1024, 512)
|
|
82
|
+
|
|
83
|
+
@pytest.mark.parametrize('disable', (False, True))
|
|
84
|
+
def test_multi_input_hyper_connections(disable):
|
|
85
|
+
|
|
86
|
+
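# a single branch module that consumes three inputs: the main residual, a second positional input, and a keyword-only `third` input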
|
|
87
|
+
|
|
88
|
+
class CustomModule(nn.Module):
|
|
89
|
+
def __init__(self):
|
|
90
|
+
super().__init__()
|
|
91
|
+
self.linear = nn.Linear(512, 512)
|
|
92
|
+
self.second_linear = nn.Linear(256, 512)
|
|
93
|
+
self.third_linear = nn.Linear(128, 512)
|
|
94
|
+
|
|
95
|
+
def forward(self, x, second, *, third):
|
|
96
|
+
return self.linear(x) + self.second_linear(second) + self.third_linear(third), 3.
|
|
97
|
+
|
|
98
|
+
branch = CustomModule()
|
|
99
|
+
|
|
100
|
+
# before
|
|
101
|
+
|
|
102
|
+
residual = torch.randn(3, 1024, 512)
|
|
103
|
+
second_residual = torch.randn(3, 1024, 256)
|
|
104
|
+
third_residual = torch.randn(3, 1024, 128)
|
|
105
|
+
|
|
106
|
+
# residual = branch(residual, second_residual, third = third_residual)[0] + residual
|
|
107
|
+
|
|
108
|
+
# after, say 4 streams in paper
|
|
109
|
+
|
|
110
|
+
from hyper_connections.hyper_connections_with_multi_input_streams import HyperConnections
|
|
111
|
+
|
|
112
|
+
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = disable)
|
|
113
|
+
|
|
114
|
+
# 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
|
|
115
|
+
|
|
116
|
+
hyper_conn = init_hyper_conn(
|
|
117
|
+
dim = 512,
|
|
118
|
+
branch = branch,
|
|
119
|
+
additional_input_paths = [
|
|
120
|
+
(1, 256), # points at second residual stream: positional arg 1 of the call (position 0 is the main residual)
|
|
121
|
+
('third', 128) # points at third residual stream, keyword argument 'third'
|
|
122
|
+
],
|
|
123
|
+
layer_index = 1,
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
# 2. expand to 4 streams
|
|
127
|
+
|
|
128
|
+
residual = expand_stream(residual)
|
|
129
|
+
second_residual = expand_stream(second_residual)
|
|
130
|
+
third_residual = expand_stream(third_residual)
|
|
131
|
+
|
|
132
|
+
# 3. forward all residual streams through the hyper connection wrapping the branch - returns the updated residual along with the rest of the branch outputs
|
|
133
|
+
|
|
134
|
+
residual, rest_output = hyper_conn(residual, second_residual, third = third_residual)
|
|
135
|
+
|
|
136
|
+
residual = reduce_stream(residual)
|
|
137
|
+
|
|
138
|
+
assert residual.shape == (3, 1024, 512)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|