hyper-connections 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections.py +29 -9
- {hyper_connections-0.1.9.dist-info → hyper_connections-0.1.11.dist-info}/METADATA +1 -1
- hyper_connections-0.1.11.dist-info/RECORD +9 -0
- hyper_connections-0.1.9.dist-info/RECORD +0 -9
- {hyper_connections-0.1.9.dist-info → hyper_connections-0.1.11.dist-info}/WHEEL +0 -0
- {hyper_connections-0.1.9.dist-info → hyper_connections-0.1.11.dist-info}/licenses/LICENSE +0 -0
|
@@ -19,6 +19,7 @@ b - batch
|
|
|
19
19
|
d - feature dimension
|
|
20
20
|
s - residual streams
|
|
21
21
|
t - residual streams + num branch inputs
|
|
22
|
+
v - number of views for branch input
|
|
22
23
|
"""
|
|
23
24
|
|
|
24
25
|
# helper functions
|
|
@@ -34,22 +35,31 @@ def identity(t):
|
|
|
34
35
|
|
|
35
36
|
# main functions
|
|
36
37
|
|
|
37
|
-
def get_expand_reduce_stream_functions(num_streams, disable = False):
|
|
38
|
+
def get_expand_reduce_stream_functions(num_streams, add_stream_embed = False, dim = None, disable = False):
|
|
38
39
|
|
|
39
40
|
if num_streams == 1 or disable:
|
|
40
41
|
return (nn.Identity(), nn.Identity())
|
|
41
42
|
|
|
42
|
-
|
|
43
|
+
if add_stream_embed:
|
|
44
|
+
assert exists(dim), '`dim` must be passed into get_init_and_expand_reduce_stream_functions for returning an expansion function with stream embeddings added'
|
|
45
|
+
|
|
46
|
+
expand_fn = StreamEmbed(num_streams, dim, expand_to_streams = True)
|
|
47
|
+
else:
|
|
48
|
+
expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
|
|
49
|
+
|
|
43
50
|
reduce_fn = Reduce(pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
|
|
44
51
|
|
|
45
52
|
return expand_fn, reduce_fn
|
|
46
53
|
|
|
47
|
-
def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
|
|
54
|
+
def get_init_and_expand_reduce_stream_functions(num_streams, dim = None, add_stream_embed = False, disable = False):
|
|
48
55
|
|
|
49
56
|
hyper_conn_klass = HyperConnections if not disable else Residual
|
|
50
57
|
|
|
51
58
|
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
|
|
52
|
-
expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
|
|
59
|
+
expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, add_stream_embed = add_stream_embed, dim = dim, disable = disable)
|
|
60
|
+
|
|
61
|
+
if exists(dim):
|
|
62
|
+
init_hyper_conn_fn = partial(init_hyper_conn_fn, dim = dim)
|
|
53
63
|
|
|
54
64
|
return (init_hyper_conn_fn, *expand_reduce_fns)
|
|
55
65
|
|
|
@@ -132,7 +142,8 @@ class HyperConnections(Module):
|
|
|
132
142
|
channel_first = False,
|
|
133
143
|
dropout = 0.,
|
|
134
144
|
residual_transform: Module | None = None, # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
|
|
135
|
-
add_branch_out_to_residual = True # will disable depth connections (weighted residual sum with beta) if set False
|
|
145
|
+
add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
|
|
146
|
+
num_input_views = 1 # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
|
|
136
147
|
):
|
|
137
148
|
"""
|
|
138
149
|
Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
|
|
@@ -152,14 +163,19 @@ class HyperConnections(Module):
|
|
|
152
163
|
self.num_residual_streams = num_residual_streams
|
|
153
164
|
init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
|
|
154
165
|
|
|
166
|
+
# width num residual streams
|
|
167
|
+
|
|
168
|
+
assert num_input_views >= 1
|
|
169
|
+
self.num_input_views = num_input_views
|
|
170
|
+
|
|
155
171
|
# width connection
|
|
156
172
|
|
|
157
|
-
init_alpha0 = torch.zeros((num_residual_streams,
|
|
158
|
-
init_alpha0[init_residual_index,
|
|
173
|
+
init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
|
|
174
|
+
init_alpha0[init_residual_index, :] = 1.
|
|
159
175
|
|
|
160
176
|
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
|
|
161
177
|
|
|
162
|
-
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams +
|
|
178
|
+
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
|
|
163
179
|
self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
|
|
164
180
|
|
|
165
181
|
# depth connection related (beta)
|
|
@@ -213,7 +229,11 @@ class HyperConnections(Module):
|
|
|
213
229
|
|
|
214
230
|
mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
|
|
215
231
|
|
|
216
|
-
|
|
232
|
+
if self.num_input_views == 1:
|
|
233
|
+
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
|
|
234
|
+
else:
|
|
235
|
+
branch_input, residuals = mix_h[..., :self.num_input_views, :], mix_h[..., self.num_input_views:, :]
|
|
236
|
+
branch_input = rearrange(branch_input, 'b ... v d -> v b ... d')
|
|
217
237
|
|
|
218
238
|
if self.channel_first:
|
|
219
239
|
branch_input = rearrange(branch_input, 'b ... d -> b d ...')
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=mUImPtaTE8Paygs-6vq7l_mlph1CkU__jRcE4TFim_Y,12137
|
|
3
|
+
hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
|
|
4
|
+
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
|
|
5
|
+
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
|
|
6
|
+
hyper_connections-0.1.11.dist-info/METADATA,sha256=Ck3udilJMrT1ABRkqkNhfEdkjQXAKFBtsUAAalEu3No,5231
|
|
7
|
+
hyper_connections-0.1.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
8
|
+
hyper_connections-0.1.11.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
9
|
+
hyper_connections-0.1.11.dist-info/RECORD,,
|
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=F81iJkGMpxgCZPaBTLf0c3CYIE-ROAVgZJWY3NlrsJw,11068
|
|
3
|
-
hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
|
|
4
|
-
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
|
|
5
|
-
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
|
|
6
|
-
hyper_connections-0.1.9.dist-info/METADATA,sha256=XxicphOwzNfTmBLF4Py89MhTWSpJqVk1EG-DV1gpFvo,5230
|
|
7
|
-
hyper_connections-0.1.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
8
|
-
hyper_connections-0.1.9.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
9
|
-
hyper_connections-0.1.9.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|