hyper-connections 0.0.24__tar.gz → 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections-0.1.0/.github/workflows/test.yml +19 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.0}/PKG-INFO +1 -1
- {hyper_connections-0.0.24 → hyper_connections-0.1.0}/hyper_connections/hyper_connections_with_multi_input_streams.py +3 -3
- {hyper_connections-0.0.24 → hyper_connections-0.1.0}/pyproject.toml +1 -1
- hyper_connections-0.1.0/tests/test_hyper_connections.py +138 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.0}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.0}/.gitignore +0 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.0}/LICENSE +0 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.0}/README.md +0 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.0}/hyper-connections.png +0 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.0}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.0}/hyper_connections/hyper_connections.py +0 -0
- {hyper_connections-0.0.24 → hyper_connections-0.1.0}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# CI workflow: on every push, install the project with rye and run the
# pytest suite under tests/ (which includes the README examples).
name: Tests the examples in README

on: push

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      # check out the repository contents
      - uses: actions/checkout@v4

      - name: Install Python
        uses: actions/setup-python@v4

      # rye is used below to install dependencies and to run pytest
      - name: Install the latest version of rye
        uses: eifinger/setup-rye@v2

      # configure rye to install packages via uv rather than pip
      - name: Use UV instead of pip
        run: rye config --set-bool behavior.use-uv=true

      - name: Install dependencies
        run: rye sync

      - name: Run pytest
        run: rye run pytest tests/
|
|
@@ -145,7 +145,7 @@ class Residual(Module):
|
|
|
145
145
|
|
|
146
146
|
# hyper connection with multiple input streams
|
|
147
147
|
|
|
148
|
-
InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int`
|
|
148
|
+
InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int`] and `str` points to **kwargs[`str`] - `int` needs to be > 0, as 0 is the default input residual stream
|
|
149
149
|
|
|
150
150
|
class HyperConnections(Module):
|
|
151
151
|
@beartype
|
|
@@ -185,7 +185,7 @@ class HyperConnections(Module):
|
|
|
185
185
|
init_alpha0 = torch.zeros((num_residual_streams, 1))
|
|
186
186
|
init_alpha0[init_residual_index, 0] = 1.
|
|
187
187
|
|
|
188
|
-
self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1)
|
|
188
|
+
self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1, activation = act)
|
|
189
189
|
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
|
|
190
190
|
|
|
191
191
|
self.dynamic_beta = ProjActScale(dim, 1, activation = act, squeeze_output = True)
|
|
@@ -200,7 +200,7 @@ class HyperConnections(Module):
|
|
|
200
200
|
|
|
201
201
|
self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
|
|
202
202
|
self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _ , dim in additional_input_paths])
|
|
203
|
-
self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0])])
|
|
203
|
+
self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0]) for _ in additional_input_paths])
|
|
204
204
|
|
|
205
205
|
self.additional_input_paths = additional_input_paths
|
|
206
206
|
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
|
|
6
|
+
@pytest.mark.parametrize('disable', (False, True))
def test_readme(disable):
    """Run the README quick-start example end to end.

    A single ``nn.Linear`` branch is wrapped with a hyper connection over
    4 residual streams. The residual is expanded, passed through the wrapped
    branch and reduced back, and the result must have the same shape as a
    plain residual. Parametrized over ``disable`` so that both settings of
    the factory's ``disable`` flag are exercised.
    """
    batch, seq_len, dim, num_streams = 2, 1024, 512, 4

    layer = nn.Linear(dim, dim)

    # baseline: an ordinary (single-stream) residual update
    x = torch.randn(batch, seq_len, dim)
    x = layer(x) + x

    from hyper_connections import get_init_and_expand_reduce_stream_functions

    init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(num_streams, disable = disable)

    # wrap the branch; the wrapped module is then called like the plain branch
    wrapped_layer = init_hyper_conn(dim = dim, branch = layer)

    # expand -> branch -> reduce. In a real model, expand_stream runs before the
    # trunk and reduce_stream after it, rather than around a single layer.
    x = expand_stream(x)
    x = wrapped_layer(x)
    x = reduce_stream(x)

    assert x.shape == (batch, seq_len, dim)
|
|
42
|
+
|
|
43
|
+
def test_manual():
    """Drive a hyper connection by hand instead of handing it the branch.

    Called without ``branch``, the hyper connection returns two things: the
    input for the branch, and an ``add_residual`` callable that takes the
    branch output. The result is reduced back to a single stream, and its
    shape is checked.
    """
    batch, seq_len, dim, num_streams = 2, 1024, 512, 4

    layer = nn.Linear(dim, dim)

    # baseline: an ordinary (single-stream) residual update
    x = torch.randn(batch, seq_len, dim)
    x = layer(x) + x

    from hyper_connections import get_init_and_expand_reduce_stream_functions

    init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(num_streams)

    hyper_conn = init_hyper_conn(dim = dim)

    x = expand_stream(x)

    # split into branch input + a closure that merges the branch output back in
    layer_input, add_residual = hyper_conn(x)
    layer_output = layer(layer_input)
    x = add_residual(layer_output)

    # equivalent one-liner: x = hyper_conn.decorate_branch(layer)(x)

    # collapse the streams once the (here single-layer) trunk is done
    x = reduce_stream(x)

    assert x.shape == (batch, seq_len, dim)
|
|
82
|
+
|
|
83
|
+
@pytest.mark.parametrize('disable', (False, True))
def test_multi_input_hyper_connections(disable):
    """Hyper connection around one branch that reads three residual streams.

    The branch's ``forward`` takes three inputs:
    - the main residual, positionally;
    - a second residual, as positional argument 1;
    - a third residual, as the keyword argument ``third``.

    The two extra streams are registered via ``additional_input_paths``. The
    branch also returns an extra non-tensor value next to its output.
    """

    class ThreeInputBranch(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(512, 512)
            self.second_linear = nn.Linear(256, 512)
            self.third_linear = nn.Linear(128, 512)

        def forward(self, x, second, *, third):
            out = self.linear(x) + self.second_linear(second) + self.third_linear(third)
            return out, 3.

    branch = ThreeInputBranch()

    batch, seq_len = 3, 1024

    residual = torch.randn(batch, seq_len, 512)
    second_residual = torch.randn(batch, seq_len, 256)
    third_residual = torch.randn(batch, seq_len, 128)

    from hyper_connections.hyper_connections_with_multi_input_streams import HyperConnections

    init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = disable)

    hyper_conn = init_hyper_conn(
        dim = 512,
        branch = branch,
        additional_input_paths = [
            (1, 256),        # positional arg 1 of forward -> `second` (dim 256)
            ('third', 128),  # keyword arg `third` (dim 128)
        ],
        layer_index = 1,
    )

    # every input, including the extra ones, is expanded to 4 streams
    residual, second_residual, third_residual = map(expand_stream, (residual, second_residual, third_residual))

    # the second returned value is presumably the branch's extra output (the 3.)
    residual, _extra_output = hyper_conn(residual, second_residual, third = third_residual)

    residual = reduce_stream(residual)

    assert residual.shape == (batch, seq_len, 512)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|