hyper-connections 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections.py +23 -6
- {hyper_connections-0.1.8.dist-info → hyper_connections-0.1.9.dist-info}/METADATA +1 -1
- {hyper_connections-0.1.8.dist-info → hyper_connections-0.1.9.dist-info}/RECORD +5 -5
- {hyper_connections-0.1.8.dist-info → hyper_connections-0.1.9.dist-info}/WHEEL +0 -0
- {hyper_connections-0.1.8.dist-info → hyper_connections-0.1.9.dist-info}/licenses/LICENSE +0 -0
|
@@ -132,6 +132,7 @@ class HyperConnections(Module):
|
|
|
132
132
|
channel_first = False,
|
|
133
133
|
dropout = 0.,
|
|
134
134
|
residual_transform: Module | None = None, # to support resnet blocks where the input dimension is not equal to the output dimension - usually a residual conv
|
|
135
|
+
add_branch_out_to_residual = True # will disable depth connections (weighted residual sum with beta) if set to False
|
|
135
136
|
):
|
|
136
137
|
"""
|
|
137
138
|
Appendix J, Algorithm 2 in https://arxiv.org/abs/2409.19606
|
|
@@ -151,7 +152,7 @@ class HyperConnections(Module):
|
|
|
151
152
|
self.num_residual_streams = num_residual_streams
|
|
152
153
|
init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
|
|
153
154
|
|
|
154
|
-
|
|
155
|
+
# width connection
|
|
155
156
|
|
|
156
157
|
init_alpha0 = torch.zeros((num_residual_streams, 1))
|
|
157
158
|
init_alpha0[init_residual_index, 0] = 1.
|
|
@@ -160,8 +161,15 @@ class HyperConnections(Module):
|
|
|
160
161
|
|
|
161
162
|
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
|
|
162
163
|
self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
|
|
163
|
-
|
|
164
|
-
|
|
164
|
+
|
|
165
|
+
# depth connection related (beta)
|
|
166
|
+
|
|
167
|
+
self.add_branch_out_to_residual = add_branch_out_to_residual
|
|
168
|
+
|
|
169
|
+
if add_branch_out_to_residual:
|
|
170
|
+
self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
|
|
171
|
+
self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
|
|
172
|
+
self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
|
|
165
173
|
|
|
166
174
|
# dropouts
|
|
167
175
|
|
|
@@ -196,9 +204,12 @@ class HyperConnections(Module):
|
|
|
196
204
|
|
|
197
205
|
# beta for weights from branch output back to residual streams
|
|
198
206
|
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
207
|
+
beta = None
|
|
208
|
+
|
|
209
|
+
if self.add_branch_out_to_residual:
|
|
210
|
+
dc_weight = self.act(normed @ self.dynamic_beta_fn)
|
|
211
|
+
dynamic_beta = dc_weight * self.dynamic_beta_scale
|
|
212
|
+
beta = dynamic_beta + self.static_beta
|
|
202
213
|
|
|
203
214
|
mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
|
|
204
215
|
|
|
@@ -210,6 +221,8 @@ class HyperConnections(Module):
|
|
|
210
221
|
return branch_input, maybe_transformed_residuals, dict(beta = beta)
|
|
211
222
|
|
|
212
223
|
def depth_connection(self, branch_output, residuals, *, beta):
|
|
224
|
+
assert self.add_branch_out_to_residual
|
|
225
|
+
|
|
213
226
|
# 'depth' connection
|
|
214
227
|
|
|
215
228
|
if self.channel_first:
|
|
@@ -244,6 +257,10 @@ class HyperConnections(Module):
|
|
|
244
257
|
branch_input, residuals, residual_kwargs = self.width_connection(residuals)
|
|
245
258
|
|
|
246
259
|
def add_residual_fn(branch_out):
|
|
260
|
+
|
|
261
|
+
if not self.add_branch_out_to_residual:
|
|
262
|
+
return branch_out
|
|
263
|
+
|
|
247
264
|
(branch_out, *rest), tree_spec = tree_flatten(branch_out)
|
|
248
265
|
|
|
249
266
|
branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=F81iJkGMpxgCZPaBTLf0c3CYIE-ROAVgZJWY3NlrsJw,11068
|
|
3
3
|
hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
|
|
4
4
|
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
|
|
5
5
|
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
|
|
6
|
-
hyper_connections-0.1.
|
|
7
|
-
hyper_connections-0.1.
|
|
8
|
-
hyper_connections-0.1.
|
|
9
|
-
hyper_connections-0.1.
|
|
6
|
+
hyper_connections-0.1.9.dist-info/METADATA,sha256=XxicphOwzNfTmBLF4Py89MhTWSpJqVk1EG-DV1gpFvo,5230
|
|
7
|
+
hyper_connections-0.1.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
8
|
+
hyper_connections-0.1.9.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
9
|
+
hyper_connections-0.1.9.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|