hyper-connections 0.1.8__tar.gz → 0.1.10__tar.gz
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.1.8 → hyper_connections-0.1.10}/PKG-INFO +1 -1
- {hyper_connections-0.1.8 → hyper_connections-0.1.10}/hyper_connections/hyper_connections.py +36 -10
- {hyper_connections-0.1.8 → hyper_connections-0.1.10}/pyproject.toml +1 -1
- {hyper_connections-0.1.8 → hyper_connections-0.1.10}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.1.8 → hyper_connections-0.1.10}/.github/workflows/test.yml +0 -0
- {hyper_connections-0.1.8 → hyper_connections-0.1.10}/.gitignore +0 -0
- {hyper_connections-0.1.8 → hyper_connections-0.1.10}/LICENSE +0 -0
- {hyper_connections-0.1.8 → hyper_connections-0.1.10}/README.md +0 -0
- {hyper_connections-0.1.8 → hyper_connections-0.1.10}/hyper-connections.png +0 -0
- {hyper_connections-0.1.8 → hyper_connections-0.1.10}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.1.8 → hyper_connections-0.1.10}/hyper_connections/hyper_connections_channel_first.py +0 -0
- {hyper_connections-0.1.8 → hyper_connections-0.1.10}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
- {hyper_connections-0.1.8 → hyper_connections-0.1.10}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
- {hyper_connections-0.1.8 → hyper_connections-0.1.10}/tests/test_hyper_connections.py +0 -0
|
@@ -34,22 +34,31 @@ def identity(t):
|
|
|
34
34
|
|
|
35
35
|
# main functions
|
|
36
36
|
|
|
37
|
-
def get_expand_reduce_stream_functions(num_streams, disable = False):
|
|
37
|
+
def get_expand_reduce_stream_functions(num_streams, add_stream_embed = False, dim = None, disable = False):
|
|
38
38
|
|
|
39
39
|
if num_streams == 1 or disable:
|
|
40
40
|
return (nn.Identity(), nn.Identity())
|
|
41
41
|
|
|
42
|
-
|
|
42
|
+
if add_stream_embed:
|
|
43
|
+
assert exists(dim), '`dim` must be passed into get_init_and_expand_reduce_stream_functions for returning an expansion function with stream embeddings added'
|
|
44
|
+
|
|
45
|
+
expand_fn = StreamEmbed(num_streams, dim, expand_to_streams = True)
|
|
46
|
+
else:
|
|
47
|
+
expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
|
|
48
|
+
|
|
43
49
|
reduce_fn = Reduce(pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
|
|
44
50
|
|
|
45
51
|
return expand_fn, reduce_fn
|
|
46
52
|
|
|
47
|
-
def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
|
|
53
|
+
def get_init_and_expand_reduce_stream_functions(num_streams, dim = None, add_stream_embed = False, disable = False):
|
|
48
54
|
|
|
49
55
|
hyper_conn_klass = HyperConnections if not disable else Residual
|
|
50
56
|
|
|
51
57
|
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
|
|
52
|
-
expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
|
|
58
|
+
expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, add_stream_embed = add_stream_embed, dim = dim, disable = disable)
|
|
59
|
+
|
|
60
|
+
if exists(dim):
|
|
61
|
+
init_hyper_conn_fn = partial(init_hyper_conn_fn, dim = dim)
|
|
53
62
|
|
|
54
63
|
return (init_hyper_conn_fn, *expand_reduce_fns)
|
|
55
64
|
|
|
@@ -132,6 +141,7 @@ class HyperConnections(Module):
|
|
|
132
141
|
channel_first = False,
|
|
133
142
|
dropout = 0.,
|
|
134
143
|
residual_transform: Module | None = None, # to support resnet blocks where dimension in is not equal to dimension out - usually a residual conv
|
|
144
|
+
add_branch_out_to_residual = True # will disable depth connections (weighted residual sum with beta) if set False
|
|
135
145
|
):
|
|
136
146
|
"""
|
|
137
147
|
Appendix J, Algorithm 2 in https://arxiv.org/abs/2409.19606
|
|
@@ -151,7 +161,7 @@ class HyperConnections(Module):
|
|
|
151
161
|
self.num_residual_streams = num_residual_streams
|
|
152
162
|
init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
|
|
153
163
|
|
|
154
|
-
|
|
164
|
+
# width connection
|
|
155
165
|
|
|
156
166
|
init_alpha0 = torch.zeros((num_residual_streams, 1))
|
|
157
167
|
init_alpha0[init_residual_index, 0] = 1.
|
|
@@ -160,8 +170,15 @@ class HyperConnections(Module):
|
|
|
160
170
|
|
|
161
171
|
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
|
|
162
172
|
self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
|
|
163
|
-
|
|
164
|
-
|
|
173
|
+
|
|
174
|
+
# depth connection related (beta)
|
|
175
|
+
|
|
176
|
+
self.add_branch_out_to_residual = add_branch_out_to_residual
|
|
177
|
+
|
|
178
|
+
if add_branch_out_to_residual:
|
|
179
|
+
self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
|
|
180
|
+
self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
|
|
181
|
+
self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
|
|
165
182
|
|
|
166
183
|
# dropouts
|
|
167
184
|
|
|
@@ -196,9 +213,12 @@ class HyperConnections(Module):
|
|
|
196
213
|
|
|
197
214
|
# beta for weights from branch output back to residual streams
|
|
198
215
|
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
216
|
+
beta = None
|
|
217
|
+
|
|
218
|
+
if self.add_branch_out_to_residual:
|
|
219
|
+
dc_weight = self.act(normed @ self.dynamic_beta_fn)
|
|
220
|
+
dynamic_beta = dc_weight * self.dynamic_beta_scale
|
|
221
|
+
beta = dynamic_beta + self.static_beta
|
|
202
222
|
|
|
203
223
|
mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
|
|
204
224
|
|
|
@@ -210,6 +230,8 @@ class HyperConnections(Module):
|
|
|
210
230
|
return branch_input, maybe_transformed_residuals, dict(beta = beta)
|
|
211
231
|
|
|
212
232
|
def depth_connection(self, branch_output, residuals, *, beta):
|
|
233
|
+
assert self.add_branch_out_to_residual
|
|
234
|
+
|
|
213
235
|
# 'depth' connection
|
|
214
236
|
|
|
215
237
|
if self.channel_first:
|
|
@@ -244,6 +266,10 @@ class HyperConnections(Module):
|
|
|
244
266
|
branch_input, residuals, residual_kwargs = self.width_connection(residuals)
|
|
245
267
|
|
|
246
268
|
def add_residual_fn(branch_out):
|
|
269
|
+
|
|
270
|
+
if not self.add_branch_out_to_residual:
|
|
271
|
+
return branch_out
|
|
272
|
+
|
|
247
273
|
(branch_out, *rest), tree_spec = tree_flatten(branch_out)
|
|
248
274
|
|
|
249
275
|
branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|