hyper-connections 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections.py +27 -6
- hyper_connections/residuals.py +22 -0
- {hyper_connections-0.1.10.dist-info → hyper_connections-0.1.12.dist-info}/METADATA +1 -1
- {hyper_connections-0.1.10.dist-info → hyper_connections-0.1.12.dist-info}/RECORD +6 -5
- {hyper_connections-0.1.10.dist-info → hyper_connections-0.1.12.dist-info}/WHEEL +0 -0
- {hyper_connections-0.1.10.dist-info → hyper_connections-0.1.12.dist-info}/licenses/LICENSE +0 -0
|
@@ -19,6 +19,7 @@ b - batch
|
|
|
19
19
|
d - feature dimension
|
|
20
20
|
s - residual streams
|
|
21
21
|
t - residual streams + num branch inputs
|
|
22
|
+
v - number of views for branch input
|
|
22
23
|
"""
|
|
23
24
|
|
|
24
25
|
# helper functions
|
|
@@ -32,6 +33,9 @@ def default(v, d):
|
|
|
32
33
|
def identity(t):
|
|
33
34
|
return t
|
|
34
35
|
|
|
36
|
+
def add(x, y):
|
|
37
|
+
return x + y
|
|
38
|
+
|
|
35
39
|
# main functions
|
|
36
40
|
|
|
37
41
|
def get_expand_reduce_stream_functions(num_streams, add_stream_embed = False, dim = None, disable = False):
|
|
@@ -141,7 +145,9 @@ class HyperConnections(Module):
|
|
|
141
145
|
channel_first = False,
|
|
142
146
|
dropout = 0.,
|
|
143
147
|
residual_transform: Module | None = None, # to support resnet blocks where dimension in is not equal to dimension out - usually a residual conv
|
|
144
|
-
add_branch_out_to_residual = True # will disable depth connections (weighted residual sum with beta) if set False
|
|
148
|
+
add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
|
|
149
|
+
num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
|
|
150
|
+
depth_residual_fn = add
|
|
145
151
|
):
|
|
146
152
|
"""
|
|
147
153
|
Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
|
|
@@ -161,14 +167,19 @@ class HyperConnections(Module):
|
|
|
161
167
|
self.num_residual_streams = num_residual_streams
|
|
162
168
|
init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
|
|
163
169
|
|
|
170
|
+
# width num residual streams
|
|
171
|
+
|
|
172
|
+
assert num_input_views >= 1
|
|
173
|
+
self.num_input_views = num_input_views
|
|
174
|
+
|
|
164
175
|
# width connection
|
|
165
176
|
|
|
166
|
-
init_alpha0 = torch.zeros((num_residual_streams, 1))
|
|
167
|
-
init_alpha0[init_residual_index, 0] = 1.
|
|
177
|
+
init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
|
|
178
|
+
init_alpha0[init_residual_index, :] = 1.
|
|
168
179
|
|
|
169
180
|
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
|
|
170
181
|
|
|
171
|
-
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
|
|
182
|
+
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
|
|
172
183
|
self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
|
|
173
184
|
|
|
174
185
|
# depth connection related (beta)
|
|
@@ -192,6 +203,12 @@ class HyperConnections(Module):
|
|
|
192
203
|
|
|
193
204
|
self.residual_transform = default(residual_transform, nn.Identity())
|
|
194
205
|
|
|
206
|
+
# maybe custom depth connection residual function
|
|
207
|
+
# this is to prepare for gating the addition of the branch outputs to the residual streams
|
|
208
|
+
# needed for memory lanes a la RMT / LMM
|
|
209
|
+
|
|
210
|
+
self.depth_residual_fn = depth_residual_fn
|
|
211
|
+
|
|
195
212
|
def width_connection(self, residuals):
|
|
196
213
|
|
|
197
214
|
maybe_transformed_residuals = self.residual_transform(residuals)
|
|
@@ -222,7 +239,11 @@ class HyperConnections(Module):
|
|
|
222
239
|
|
|
223
240
|
mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
|
|
224
241
|
|
|
225
|
-
|
|
242
|
+
if self.num_input_views == 1:
|
|
243
|
+
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
|
|
244
|
+
else:
|
|
245
|
+
branch_input, residuals = mix_h[..., :self.num_input_views, :], mix_h[..., self.num_input_views:, :]
|
|
246
|
+
branch_input = rearrange(branch_input, 'b ... v d -> v b ... d')
|
|
226
247
|
|
|
227
248
|
if self.channel_first:
|
|
228
249
|
branch_input = rearrange(branch_input, 'b ... d -> b d ...')
|
|
@@ -243,7 +264,7 @@ class HyperConnections(Module):
|
|
|
243
264
|
if self.channel_first:
|
|
244
265
|
output = rearrange(output, 'b ... d -> b d ...')
|
|
245
266
|
|
|
246
|
-
residuals = residuals + output
|
|
267
|
+
residuals = self.depth_residual_fn(output, residuals)
|
|
247
268
|
|
|
248
269
|
return self.dropout(residuals)
|
|
249
270
|
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
from torch.nn import Module
|
|
4
|
+
|
|
5
|
+
from einops import rearrange
|
|
6
|
+
|
|
7
|
+
class GRUGatedResidual(Module):
|
|
8
|
+
def __init__(
|
|
9
|
+
self,
|
|
10
|
+
dim
|
|
11
|
+
):
|
|
12
|
+
super().__init__()
|
|
13
|
+
self.gru = nn.GRUCell(dim, dim)
|
|
14
|
+
|
|
15
|
+
def forward(self, x, residual):
|
|
16
|
+
|
|
17
|
+
gated_output = self.gru(
|
|
18
|
+
rearrange(x, 'b n d -> (b n) d'),
|
|
19
|
+
rearrange(residual, 'b n d -> (b n) d')
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
return gated_output.reshape_as(x)
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=vpipBRUGgYQ2qLBtT4Ws-myJYVdkQDkN3IkpTMkxRxc,12485
|
|
3
3
|
hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
|
|
4
4
|
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
|
|
5
5
|
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
|
|
6
|
-
hyper_connections
|
|
7
|
-
hyper_connections-0.1.
|
|
8
|
-
hyper_connections-0.1.
|
|
9
|
-
hyper_connections-0.1.
|
|
6
|
+
hyper_connections/residuals.py,sha256=F8NMh9vHDN8Nyo0nzIm_owgR9gXPZUCNujAuduYbTiU,465
|
|
7
|
+
hyper_connections-0.1.12.dist-info/METADATA,sha256=rRgT4Zhj2xUPSRPH0LKZmfJK4olGOxpxPR5ha9wEMQE,5231
|
|
8
|
+
hyper_connections-0.1.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
9
|
+
hyper_connections-0.1.12.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
10
|
+
hyper_connections-0.1.12.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|