hyper-connections 0.1.6__tar.gz → 0.1.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.1.6 → hyper_connections-0.1.7}/PKG-INFO +1 -1
- {hyper_connections-0.1.6 → hyper_connections-0.1.7}/hyper_connections/hyper_connections.py +18 -6
- {hyper_connections-0.1.6 → hyper_connections-0.1.7}/pyproject.toml +1 -1
- {hyper_connections-0.1.6 → hyper_connections-0.1.7}/tests/test_hyper_connections.py +43 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.7}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.7}/.github/workflows/test.yml +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.7}/.gitignore +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.7}/LICENSE +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.7}/README.md +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.7}/hyper-connections.png +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.7}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.7}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.7}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
|
@@ -73,16 +73,18 @@ class Residual(Module):
|
|
|
73
73
|
self,
|
|
74
74
|
*args,
|
|
75
75
|
branch: Module | None = None,
|
|
76
|
+
residual_transform: Module | None = None,
|
|
76
77
|
**kwargs
|
|
77
78
|
):
|
|
78
79
|
super().__init__()
|
|
79
80
|
self.branch = branch
|
|
81
|
+
self.residual_transform = default(residual_transform, nn.Identity())
|
|
80
82
|
|
|
81
83
|
def width_connection(self, residuals):
|
|
82
84
|
return residuals, residuals, dict()
|
|
83
85
|
|
|
84
86
|
def depth_connection(self, branch_output, residuals):
|
|
85
|
-
return branch_output + residuals
|
|
87
|
+
return branch_output + self.residual_transform(residuals)
|
|
86
88
|
|
|
87
89
|
def decorate_branch(self, branch: Callable):
|
|
88
90
|
assert not exists(self.branch), 'branch was already wrapped on init'
|
|
@@ -128,7 +130,8 @@ class HyperConnections(Module):
|
|
|
128
130
|
layer_index = None,
|
|
129
131
|
tanh = True,
|
|
130
132
|
channel_first = False,
|
|
131
|
-
dropout = 0
|
|
133
|
+
dropout = 0.,
|
|
134
|
+
residual_transform: Module | None = None, # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
|
|
132
135
|
):
|
|
133
136
|
"""
|
|
134
137
|
Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
|
|
@@ -168,7 +171,14 @@ class HyperConnections(Module):
|
|
|
168
171
|
|
|
169
172
|
self.channel_first = channel_first
|
|
170
173
|
|
|
174
|
+
# maybe residual transform
|
|
175
|
+
|
|
176
|
+
self.residual_transform = default(residual_transform, nn.Identity())
|
|
177
|
+
|
|
171
178
|
def width_connection(self, residuals):
|
|
179
|
+
|
|
180
|
+
maybe_transformed_residuals = self.residual_transform(residuals)
|
|
181
|
+
|
|
172
182
|
# width connection
|
|
173
183
|
|
|
174
184
|
if self.channel_first:
|
|
@@ -197,7 +207,7 @@ class HyperConnections(Module):
|
|
|
197
207
|
if self.channel_first:
|
|
198
208
|
branch_input = rearrange(branch_input, 'b ... d -> b d ...')
|
|
199
209
|
|
|
200
|
-
return branch_input,
|
|
210
|
+
return branch_input, maybe_transformed_residuals, dict(beta = beta)
|
|
201
211
|
|
|
202
212
|
def depth_connection(self, branch_output, residuals, *, beta):
|
|
203
213
|
# 'depth' connection
|
|
@@ -205,13 +215,15 @@ class HyperConnections(Module):
|
|
|
205
215
|
if self.channel_first:
|
|
206
216
|
branch_output = rearrange(branch_output, 'b d ... -> b ... d')
|
|
207
217
|
|
|
208
|
-
|
|
209
|
-
output = rearrange(
|
|
218
|
+
output = einsum(branch_output, beta, 'b ... d, b ... s -> b ... s d')
|
|
219
|
+
output = rearrange(output, 'b ... s d -> (b s) ... d')
|
|
210
220
|
|
|
211
221
|
if self.channel_first:
|
|
212
222
|
output = rearrange(output, 'b ... d -> b d ...')
|
|
213
223
|
|
|
214
|
-
|
|
224
|
+
residuals = residuals + output
|
|
225
|
+
|
|
226
|
+
return self.dropout(residuals)
|
|
215
227
|
|
|
216
228
|
def decorate_branch(self, branch: Callable):
|
|
217
229
|
assert not exists(self.branch), 'branch was already wrapped on init'
|
|
@@ -136,3 +136,46 @@ def test_multi_input_hyper_connections(disable):
|
|
|
136
136
|
residual = reduce_stream(residual)
|
|
137
137
|
|
|
138
138
|
assert residual.shape == (3, 1024, 512)
|
|
139
|
+
|
|
140
|
+
@pytest.mark.parametrize('disable', (False, True))
|
|
141
|
+
def test_residual_transform(disable):
|
|
142
|
+
|
|
143
|
+
# a single branch layer
|
|
144
|
+
|
|
145
|
+
branch = nn.Sequential(
|
|
146
|
+
nn.Linear(512, 512),
|
|
147
|
+
nn.SiLU(),
|
|
148
|
+
nn.Linear(512, 256)
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
residual_fn = nn.Linear(512, 256)
|
|
152
|
+
|
|
153
|
+
# before
|
|
154
|
+
|
|
155
|
+
residual = torch.randn(2, 1024, 512)
|
|
156
|
+
|
|
157
|
+
before_residual = branch(residual) + residual_fn(residual)
|
|
158
|
+
|
|
159
|
+
# after, say 4 streams in paper
|
|
160
|
+
|
|
161
|
+
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
162
|
+
|
|
163
|
+
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, disable = disable)
|
|
164
|
+
|
|
165
|
+
# 1. wrap your branch function
|
|
166
|
+
|
|
167
|
+
hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch, residual_transform = residual_fn)
|
|
168
|
+
|
|
169
|
+
# 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
|
|
170
|
+
|
|
171
|
+
residual = expand_stream(residual)
|
|
172
|
+
|
|
173
|
+
# 3. forward your residual as usual into the wrapped branch function(s)
|
|
174
|
+
|
|
175
|
+
residual = hyper_conn_branch(residual)
|
|
176
|
+
|
|
177
|
+
# 4. reduce 4 streams with a summation, this has to be done after your for-loop trunk. for transformer, unsure whether to do before or after final norm
|
|
178
|
+
|
|
179
|
+
after_residual = reduce_stream(residual)
|
|
180
|
+
|
|
181
|
+
assert before_residual.shape == after_residual.shape
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|