hyper-connections 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections.py +16 -5
- {hyper_connections-0.1.10.dist-info → hyper_connections-0.1.11.dist-info}/METADATA +1 -1
- {hyper_connections-0.1.10.dist-info → hyper_connections-0.1.11.dist-info}/RECORD +5 -5
- {hyper_connections-0.1.10.dist-info → hyper_connections-0.1.11.dist-info}/WHEEL +0 -0
- {hyper_connections-0.1.10.dist-info → hyper_connections-0.1.11.dist-info}/licenses/LICENSE +0 -0
|
@@ -19,6 +19,7 @@ b - batch
|
|
|
19
19
|
d - feature dimension
|
|
20
20
|
s - residual streams
|
|
21
21
|
t - residual streams + num branch inputs
|
|
22
|
+
v - number of views for branch input
|
|
22
23
|
"""
|
|
23
24
|
|
|
24
25
|
# helper functions
|
|
@@ -141,7 +142,8 @@ class HyperConnections(Module):
|
|
|
141
142
|
channel_first = False,
|
|
142
143
|
dropout = 0.,
|
|
143
144
|
residual_transform: Module | None = None, # to support resnet blocks where dimension in is not equal to dimension out - usually a residual conv
|
|
144
|
-
add_branch_out_to_residual = True # will disable depth connections (weighted residual sum with beta) if set False
|
|
145
|
+
add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
|
|
146
|
+
num_input_views = 1 # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
|
|
145
147
|
):
|
|
146
148
|
"""
|
|
147
149
|
Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
|
|
@@ -161,14 +163,19 @@ class HyperConnections(Module):
|
|
|
161
163
|
self.num_residual_streams = num_residual_streams
|
|
162
164
|
init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
|
|
163
165
|
|
|
166
|
+
# width num residual streams
|
|
167
|
+
|
|
168
|
+
assert num_input_views >= 1
|
|
169
|
+
self.num_input_views = num_input_views
|
|
170
|
+
|
|
164
171
|
# width connection
|
|
165
172
|
|
|
166
|
-
init_alpha0 = torch.zeros((num_residual_streams, 1))
|
|
167
|
-
init_alpha0[init_residual_index, 0] = 1.
|
|
173
|
+
init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
|
|
174
|
+
init_alpha0[init_residual_index, :] = 1.
|
|
168
175
|
|
|
169
176
|
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
|
|
170
177
|
|
|
171
|
-
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
|
|
178
|
+
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
|
|
172
179
|
self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
|
|
173
180
|
|
|
174
181
|
# depth connection related (beta)
|
|
@@ -222,7 +229,11 @@ class HyperConnections(Module):
|
|
|
222
229
|
|
|
223
230
|
mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
|
|
224
231
|
|
|
225
|
-
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
|
|
232
|
+
if self.num_input_views == 1:
|
|
233
|
+
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
|
|
234
|
+
else:
|
|
235
|
+
branch_input, residuals = mix_h[..., :self.num_input_views, :], mix_h[..., self.num_input_views:, :]
|
|
236
|
+
branch_input = rearrange(branch_input, 'b ... v d -> v b ... d')
|
|
226
237
|
|
|
227
238
|
if self.channel_first:
|
|
228
239
|
branch_input = rearrange(branch_input, 'b ... d -> b d ...')
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=mUImPtaTE8Paygs-6vq7l_mlph1CkU__jRcE4TFim_Y,12137
|
|
3
3
|
hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
|
|
4
4
|
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
|
|
5
5
|
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
|
|
6
|
-
hyper_connections-0.1.
|
|
7
|
-
hyper_connections-0.1.
|
|
8
|
-
hyper_connections-0.1.
|
|
9
|
-
hyper_connections-0.1.
|
|
6
|
+
hyper_connections-0.1.11.dist-info/METADATA,sha256=Ck3udilJMrT1ABRkqkNhfEdkjQXAKFBtsUAAalEu3No,5231
|
|
7
|
+
hyper_connections-0.1.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
8
|
+
hyper_connections-0.1.11.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
9
|
+
hyper_connections-0.1.11.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|