hyper-connections 0.1.15__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.1.15 → hyper_connections-0.2.1}/PKG-INFO +18 -1
- {hyper_connections-0.1.15 → hyper_connections-0.2.1}/README.md +17 -0
- {hyper_connections-0.1.15 → hyper_connections-0.2.1}/hyper_connections/hyper_connections.py +140 -29
- {hyper_connections-0.1.15 → hyper_connections-0.2.1}/hyper_connections/hyper_connections_channel_first.py +3 -1
- {hyper_connections-0.1.15 → hyper_connections-0.2.1}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +3 -1
- {hyper_connections-0.1.15 → hyper_connections-0.2.1}/hyper_connections/hyper_connections_with_multi_input_streams.py +3 -1
- {hyper_connections-0.1.15 → hyper_connections-0.2.1}/pyproject.toml +1 -1
- {hyper_connections-0.1.15 → hyper_connections-0.2.1}/tests/test_hyper_connections.py +6 -2
- {hyper_connections-0.1.15 → hyper_connections-0.2.1}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.1.15 → hyper_connections-0.2.1}/.github/workflows/test.yml +0 -0
- {hyper_connections-0.1.15 → hyper_connections-0.2.1}/.gitignore +0 -0
- {hyper_connections-0.1.15 → hyper_connections-0.2.1}/LICENSE +0 -0
- {hyper_connections-0.1.15 → hyper_connections-0.2.1}/hyper-connections.png +0 -0
- {hyper_connections-0.1.15 → hyper_connections-0.2.1}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.1.15 → hyper_connections-0.2.1}/hyper_connections/residuals.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.1.15
|
|
3
|
+
Version: 0.2.1
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -141,6 +141,12 @@ To compare hyper connections to plain residual without changing the code, just p
|
|
|
141
141
|
get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
142
142
|
```
|
|
143
143
|
|
|
144
|
+
To use the fractionated feature dimensions proposed in [a follow-up paper](https://arxiv.org/abs/2503.14125) by the same authors, just instantiate with `num_fracs` greater than `1`, like so
|
|
145
|
+
|
|
146
|
+
```python
|
|
147
|
+
get_init_and_expand_reduce_stream_functions(1, num_fracs = 4) # also allows you to mix streams and fractions of feature dimension
|
|
148
|
+
```
|
|
149
|
+
|
|
144
150
|
## Citation
|
|
145
151
|
|
|
146
152
|
```bibtex
|
|
@@ -160,3 +166,14 @@ get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
|
160
166
|
url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
|
|
161
167
|
}
|
|
162
168
|
```
|
|
169
|
+
|
|
170
|
+
```bibtex
|
|
171
|
+
@article{Zhu2025FracConnectionsFE,
|
|
172
|
+
title = {Frac-Connections: Fractional Extension of Hyper-Connections},
|
|
173
|
+
author = {Defa Zhu and Hongzhi Huang and Jundong Zhou and Zihao Huang and Yutao Zeng and Banggu Wu and Qiyang Min and Xun Zhou},
|
|
174
|
+
journal = {ArXiv},
|
|
175
|
+
year = {2025},
|
|
176
|
+
volume = {abs/2503.14125},
|
|
177
|
+
url = {https://api.semanticscholar.org/CorpusID:277104144}
|
|
178
|
+
}
|
|
179
|
+
```
|
|
@@ -100,6 +100,12 @@ To compare hyper connections to plain residual without changing the code, just p
|
|
|
100
100
|
get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
101
101
|
```
|
|
102
102
|
|
|
103
|
+
To use the fractionated feature dimensions proposed in [a follow-up paper](https://arxiv.org/abs/2503.14125) by the same authors, just instantiate with `num_fracs` greater than `1`, like so
|
|
104
|
+
|
|
105
|
+
```python
|
|
106
|
+
get_init_and_expand_reduce_stream_functions(1, num_fracs = 4) # also allows you to mix streams and fractions of feature dimension
|
|
107
|
+
```
|
|
108
|
+
|
|
103
109
|
## Citation
|
|
104
110
|
|
|
105
111
|
```bibtex
|
|
@@ -119,3 +125,14 @@ get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
|
119
125
|
url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
|
|
120
126
|
}
|
|
121
127
|
```
|
|
128
|
+
|
|
129
|
+
```bibtex
|
|
130
|
+
@article{Zhu2025FracConnectionsFE,
|
|
131
|
+
title = {Frac-Connections: Fractional Extension of Hyper-Connections},
|
|
132
|
+
author = {Defa Zhu and Hongzhi Huang and Jundong Zhou and Zihao Huang and Yutao Zeng and Banggu Wu and Qiyang Min and Xun Zhou},
|
|
133
|
+
journal = {ArXiv},
|
|
134
|
+
year = {2025},
|
|
135
|
+
volume = {abs/2503.14125},
|
|
136
|
+
url = {https://api.semanticscholar.org/CorpusID:277104144}
|
|
137
|
+
}
|
|
138
|
+
```
|
|
@@ -5,13 +5,13 @@ from functools import partial
|
|
|
5
5
|
from random import randrange
|
|
6
6
|
|
|
7
7
|
import torch
|
|
8
|
-
from torch import nn
|
|
9
|
-
from torch.nn import Module
|
|
8
|
+
from torch import nn, cat
|
|
10
9
|
import torch.nn.functional as F
|
|
10
|
+
from torch.nn import Module, Sequential
|
|
11
11
|
from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
12
12
|
|
|
13
13
|
from einops import rearrange, repeat, reduce, einsum
|
|
14
|
-
from einops.layers.torch import Reduce
|
|
14
|
+
from einops.layers.torch import Rearrange, Reduce
|
|
15
15
|
|
|
16
16
|
"""
|
|
17
17
|
ein notation:
|
|
@@ -19,6 +19,7 @@ b - batch
|
|
|
19
19
|
d - feature dimension
|
|
20
20
|
s - residual streams
|
|
21
21
|
t - residual streams + num branch inputs
|
|
22
|
+
f - number of fractions (division of feature dimension space)
|
|
22
23
|
v - number of views for branch input
|
|
23
24
|
"""
|
|
24
25
|
|
|
@@ -27,6 +28,9 @@ v - number of views for branch input
|
|
|
27
28
|
def exists(v):
|
|
28
29
|
return v is not None
|
|
29
30
|
|
|
31
|
+
def divisible_by(num, den):
|
|
32
|
+
return (num % den) == 0
|
|
33
|
+
|
|
30
34
|
def default(v, d):
|
|
31
35
|
return v if exists(v) else d
|
|
32
36
|
|
|
@@ -38,8 +42,12 @@ def add(x, y):
|
|
|
38
42
|
|
|
39
43
|
# main functions
|
|
40
44
|
|
|
41
|
-
def get_expand_reduce_stream_functions(
|
|
42
|
-
|
|
45
|
+
def get_expand_reduce_stream_functions(
|
|
46
|
+
num_streams,
|
|
47
|
+
add_stream_embed = False,
|
|
48
|
+
dim = None,
|
|
49
|
+
disable = False
|
|
50
|
+
):
|
|
43
51
|
if num_streams == 1 or disable:
|
|
44
52
|
return (nn.Identity(), nn.Identity())
|
|
45
53
|
|
|
@@ -54,11 +62,18 @@ def get_expand_reduce_stream_functions(num_streams, add_stream_embed = False, di
|
|
|
54
62
|
|
|
55
63
|
return expand_fn, reduce_fn
|
|
56
64
|
|
|
57
|
-
def get_init_and_expand_reduce_stream_functions(
|
|
65
|
+
def get_init_and_expand_reduce_stream_functions(
|
|
66
|
+
num_streams,
|
|
67
|
+
num_fracs = 1,
|
|
68
|
+
dim = None,
|
|
69
|
+
add_stream_embed = False,
|
|
70
|
+
disable = None
|
|
71
|
+
):
|
|
72
|
+
disable = default(disable, num_streams == 1 and num_fracs == 1)
|
|
58
73
|
|
|
59
74
|
hyper_conn_klass = HyperConnections if not disable else Residual
|
|
60
75
|
|
|
61
|
-
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
|
|
76
|
+
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs)
|
|
62
77
|
expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, add_stream_embed = add_stream_embed, dim = dim, disable = disable)
|
|
63
78
|
|
|
64
79
|
if exists(dim):
|
|
@@ -93,13 +108,24 @@ class Residual(Module):
|
|
|
93
108
|
self.branch = branch
|
|
94
109
|
self.residual_transform = default(residual_transform, nn.Identity())
|
|
95
110
|
|
|
96
|
-
def width_connection(
|
|
111
|
+
def width_connection(
|
|
112
|
+
self,
|
|
113
|
+
residuals
|
|
114
|
+
):
|
|
97
115
|
return residuals, residuals, dict()
|
|
98
116
|
|
|
99
|
-
def depth_connection(
|
|
117
|
+
def depth_connection(
|
|
118
|
+
self,
|
|
119
|
+
branch_output,
|
|
120
|
+
residuals,
|
|
121
|
+
|
|
122
|
+
):
|
|
100
123
|
return branch_output + self.residual_transform(residuals)
|
|
101
124
|
|
|
102
|
-
def decorate_branch(
|
|
125
|
+
def decorate_branch(
|
|
126
|
+
self,
|
|
127
|
+
branch: Callable
|
|
128
|
+
):
|
|
103
129
|
assert not exists(self.branch), 'branch was already wrapped on init'
|
|
104
130
|
|
|
105
131
|
def forward_and_add_residual(residual, *args, **kwargs):
|
|
@@ -113,7 +139,12 @@ class Residual(Module):
|
|
|
113
139
|
|
|
114
140
|
return forward_and_add_residual
|
|
115
141
|
|
|
116
|
-
def forward(
|
|
142
|
+
def forward(
|
|
143
|
+
self,
|
|
144
|
+
residuals,
|
|
145
|
+
*branch_args,
|
|
146
|
+
**branch_kwargs
|
|
147
|
+
):
|
|
117
148
|
|
|
118
149
|
branch_input, residuals, residual_kwargs = self.width_connection(residuals)
|
|
119
150
|
|
|
@@ -145,9 +176,10 @@ class HyperConnections(Module):
|
|
|
145
176
|
channel_first = False,
|
|
146
177
|
dropout = 0.,
|
|
147
178
|
residual_transform: Module | None = None, # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
|
|
148
|
-
add_branch_out_to_residual = True,
|
|
149
|
-
num_input_views = 1,
|
|
150
|
-
depth_residual_fn = add
|
|
179
|
+
add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
|
|
180
|
+
num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
|
|
181
|
+
depth_residual_fn = add,
|
|
182
|
+
num_fracs = 1 # https://arxiv.org/abs/2503.14125
|
|
151
183
|
):
|
|
152
184
|
"""
|
|
153
185
|
Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
|
|
@@ -160,13 +192,34 @@ class HyperConnections(Module):
|
|
|
160
192
|
|
|
161
193
|
self.act = nn.Tanh() if tanh else nn.Identity()
|
|
162
194
|
|
|
163
|
-
|
|
195
|
+
# frac-connections paper - num_fracs > 1 will be the `m` in their paper https://arxiv.org/abs/2503.14125
|
|
196
|
+
|
|
197
|
+
assert num_fracs >= 1
|
|
198
|
+
|
|
199
|
+
self.num_fracs = num_fracs
|
|
200
|
+
self.has_fracs = num_fracs > 1
|
|
201
|
+
|
|
202
|
+
self.split_fracs = Rearrange('b ... (f d) -> b ... f d', f = num_fracs)
|
|
203
|
+
self.merge_fracs = Rearrange('b ... f d -> b ... (f d)')
|
|
204
|
+
|
|
205
|
+
assert divisible_by(dim, num_fracs), f'feature dimension ({dim}) must be divisible by the `num_fracs` ({num_fracs})'
|
|
206
|
+
|
|
207
|
+
dim //= num_fracs # effective dim handled in dimension is feature dimension divided by num fractions
|
|
208
|
+
|
|
209
|
+
# they used layernorm in paper, but rmsnorm is fine given what we know now
|
|
210
|
+
|
|
211
|
+
self.norm = RMSNorm(dim)
|
|
164
212
|
|
|
165
213
|
assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
|
|
166
214
|
|
|
167
215
|
self.num_residual_streams = num_residual_streams
|
|
168
216
|
init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
|
|
169
217
|
|
|
218
|
+
# handle the parameter dimensions, which may require (num_residuals x num_fractions) - generalizing hyper + frac connections
|
|
219
|
+
|
|
220
|
+
num_residual_streams_fracs = num_residual_streams * num_fracs
|
|
221
|
+
num_input_views_fracs = num_input_views * num_fracs
|
|
222
|
+
|
|
170
223
|
# width num residual streams
|
|
171
224
|
|
|
172
225
|
assert num_input_views >= 1
|
|
@@ -174,12 +227,12 @@ class HyperConnections(Module):
|
|
|
174
227
|
|
|
175
228
|
# width connection
|
|
176
229
|
|
|
177
|
-
init_alpha0 = torch.zeros((
|
|
230
|
+
init_alpha0 = torch.zeros((num_residual_streams_fracs, num_input_views_fracs))
|
|
178
231
|
init_alpha0[init_residual_index, :] = 1.
|
|
179
232
|
|
|
180
|
-
self.static_alpha = nn.Parameter(
|
|
233
|
+
self.static_alpha = nn.Parameter(cat((init_alpha0, torch.eye(num_residual_streams_fracs)), dim = 1))
|
|
181
234
|
|
|
182
|
-
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim,
|
|
235
|
+
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams_fracs + num_input_views_fracs))
|
|
183
236
|
self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
|
|
184
237
|
|
|
185
238
|
# depth connection related (beta)
|
|
@@ -187,8 +240,11 @@ class HyperConnections(Module):
|
|
|
187
240
|
self.add_branch_out_to_residual = add_branch_out_to_residual
|
|
188
241
|
|
|
189
242
|
if add_branch_out_to_residual:
|
|
190
|
-
self.static_beta = nn.Parameter(torch.ones(
|
|
191
|
-
|
|
243
|
+
self.static_beta = nn.Parameter(torch.ones(num_residual_streams_fracs))
|
|
244
|
+
|
|
245
|
+
dynamic_beta_shape = (dim,) if num_fracs == 1 else (dim, num_fracs) # preserve backwards compat
|
|
246
|
+
self.dynamic_beta_fn = nn.Parameter(torch.zeros(dynamic_beta_shape))
|
|
247
|
+
|
|
192
248
|
self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
|
|
193
249
|
|
|
194
250
|
# dropouts
|
|
@@ -209,16 +265,30 @@ class HyperConnections(Module):
|
|
|
209
265
|
|
|
210
266
|
self.depth_residual_fn = depth_residual_fn
|
|
211
267
|
|
|
212
|
-
def width_connection(
|
|
268
|
+
def width_connection(
|
|
269
|
+
self,
|
|
270
|
+
residuals
|
|
271
|
+
):
|
|
272
|
+
streams = self.num_residual_streams
|
|
213
273
|
|
|
214
274
|
maybe_transformed_residuals = self.residual_transform(residuals)
|
|
215
275
|
|
|
216
276
|
# width connection
|
|
217
277
|
|
|
278
|
+
# handle channel first
|
|
279
|
+
|
|
218
280
|
if self.channel_first:
|
|
219
281
|
residuals = rearrange(residuals, 'b d ... -> b ... d')
|
|
220
282
|
|
|
221
|
-
|
|
283
|
+
# split out fractions
|
|
284
|
+
|
|
285
|
+
residuals = self.split_fracs(residuals)
|
|
286
|
+
|
|
287
|
+
# split out streams
|
|
288
|
+
|
|
289
|
+
residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = streams)
|
|
290
|
+
|
|
291
|
+
# norm
|
|
222
292
|
|
|
223
293
|
normed = self.norm(residuals)
|
|
224
294
|
|
|
@@ -226,7 +296,12 @@ class HyperConnections(Module):
|
|
|
226
296
|
|
|
227
297
|
wc_weight = self.act(normed @ self.dynamic_alpha_fn)
|
|
228
298
|
dynamic_alpha = wc_weight * self.dynamic_alpha_scale
|
|
229
|
-
|
|
299
|
+
|
|
300
|
+
static_alpha = rearrange(self.static_alpha, '(f s) d -> f s d', s = streams)
|
|
301
|
+
|
|
302
|
+
alpha = dynamic_alpha + static_alpha
|
|
303
|
+
|
|
304
|
+
alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
|
|
230
305
|
|
|
231
306
|
# beta for weights from branch output back to residual streams
|
|
232
307
|
|
|
@@ -234,10 +309,17 @@ class HyperConnections(Module):
|
|
|
234
309
|
|
|
235
310
|
if self.add_branch_out_to_residual:
|
|
236
311
|
dc_weight = self.act(normed @ self.dynamic_beta_fn)
|
|
312
|
+
|
|
313
|
+
if not self.has_fracs:
|
|
314
|
+
dc_weight = rearrange(dc_weight, '... -> ... 1')
|
|
315
|
+
|
|
237
316
|
dynamic_beta = dc_weight * self.dynamic_beta_scale
|
|
238
|
-
beta = dynamic_beta + self.static_beta
|
|
239
317
|
|
|
240
|
-
|
|
318
|
+
static_beta = rearrange(self.static_beta, '... (s f) -> ... s f', s = streams)
|
|
319
|
+
|
|
320
|
+
beta = dynamic_beta + static_beta
|
|
321
|
+
|
|
322
|
+
mix_h = einsum(alpha, residuals, '... f1 s f2 t, ... f1 s d -> ... f2 t d')
|
|
241
323
|
|
|
242
324
|
if self.num_input_views == 1:
|
|
243
325
|
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
|
|
@@ -248,19 +330,40 @@ class HyperConnections(Module):
|
|
|
248
330
|
if self.channel_first:
|
|
249
331
|
branch_input = rearrange(branch_input, 'b ... d -> b d ...')
|
|
250
332
|
|
|
333
|
+
# maybe merge fractions back
|
|
334
|
+
|
|
335
|
+
branch_input = self.merge_fracs(branch_input)
|
|
336
|
+
|
|
251
337
|
return branch_input, maybe_transformed_residuals, dict(beta = beta)
|
|
252
338
|
|
|
253
|
-
def depth_connection(
|
|
339
|
+
def depth_connection(
|
|
340
|
+
self,
|
|
341
|
+
branch_output,
|
|
342
|
+
residuals,
|
|
343
|
+
*,
|
|
344
|
+
beta
|
|
345
|
+
):
|
|
254
346
|
assert self.add_branch_out_to_residual
|
|
255
347
|
|
|
348
|
+
# maybe split fractions
|
|
349
|
+
|
|
350
|
+
branch_output = self.split_fracs(branch_output)
|
|
351
|
+
|
|
256
352
|
# 'depth' connection
|
|
257
353
|
|
|
258
354
|
if self.channel_first:
|
|
259
355
|
branch_output = rearrange(branch_output, 'b d ... -> b ... d')
|
|
260
356
|
|
|
261
|
-
output = einsum(branch_output, beta, 'b ... d, b ... s -> b ... s d')
|
|
357
|
+
output = einsum(branch_output, beta, 'b ... f1 d, b ... f1 s f2 -> b ... f2 s d')
|
|
358
|
+
|
|
262
359
|
output = rearrange(output, 'b ... s d -> (b s) ... d')
|
|
263
360
|
|
|
361
|
+
# merge back fractions
|
|
362
|
+
|
|
363
|
+
output = self.merge_fracs(output)
|
|
364
|
+
|
|
365
|
+
# channel first
|
|
366
|
+
|
|
264
367
|
if self.channel_first:
|
|
265
368
|
output = rearrange(output, 'b ... d -> b d ...')
|
|
266
369
|
|
|
@@ -268,7 +371,10 @@ class HyperConnections(Module):
|
|
|
268
371
|
|
|
269
372
|
return self.dropout(residuals)
|
|
270
373
|
|
|
271
|
-
def decorate_branch(
|
|
374
|
+
def decorate_branch(
|
|
375
|
+
self,
|
|
376
|
+
branch: Callable
|
|
377
|
+
):
|
|
272
378
|
assert not exists(self.branch), 'branch was already wrapped on init'
|
|
273
379
|
|
|
274
380
|
def forward_and_add_residual(residual, *args, **kwargs):
|
|
@@ -282,7 +388,12 @@ class HyperConnections(Module):
|
|
|
282
388
|
|
|
283
389
|
return forward_and_add_residual
|
|
284
390
|
|
|
285
|
-
def forward(
|
|
391
|
+
def forward(
|
|
392
|
+
self,
|
|
393
|
+
residuals,
|
|
394
|
+
*branch_args,
|
|
395
|
+
**branch_kwargs
|
|
396
|
+
):
|
|
286
397
|
|
|
287
398
|
branch_input, residuals, residual_kwargs = self.width_connection(residuals)
|
|
288
399
|
|
|
@@ -49,7 +49,9 @@ def get_expand_reduce_stream_functions(num_streams, disable = False):
|
|
|
49
49
|
|
|
50
50
|
return expand_fn, reduce_fn
|
|
51
51
|
|
|
52
|
-
def get_init_and_expand_reduce_stream_functions(num_streams, disable =
|
|
52
|
+
def get_init_and_expand_reduce_stream_functions(num_streams, disable = None):
|
|
53
|
+
|
|
54
|
+
disable = default(disable, num_streams == 1)
|
|
53
55
|
|
|
54
56
|
hyper_conn_klass = HyperConnections if not disable else Residual
|
|
55
57
|
|
|
@@ -50,7 +50,9 @@ def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
|
|
|
50
50
|
|
|
51
51
|
return expand_fn, reduce_fn
|
|
52
52
|
|
|
53
|
-
def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable =
|
|
53
|
+
def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = None):
|
|
54
|
+
|
|
55
|
+
disable = default(disable, num_streams == 1)
|
|
54
56
|
|
|
55
57
|
hyper_conn_klass = HyperConnections if not disable else Residual
|
|
56
58
|
|
|
@@ -41,7 +41,9 @@ def get_expand_reduce_stream_functions(num_streams, disable = False):
|
|
|
41
41
|
|
|
42
42
|
return expand_fn, reduce_fn
|
|
43
43
|
|
|
44
|
-
def get_init_and_expand_reduce_stream_functions(num_streams, disable =
|
|
44
|
+
def get_init_and_expand_reduce_stream_functions(num_streams, disable = None):
|
|
45
|
+
|
|
46
|
+
disable = default(disable, num_streams == 1)
|
|
45
47
|
|
|
46
48
|
hyper_conn_klass = HyperConnections if not disable else Residual
|
|
47
49
|
|
|
@@ -3,8 +3,12 @@ import pytest
|
|
|
3
3
|
import torch
|
|
4
4
|
from torch import nn
|
|
5
5
|
|
|
6
|
+
@pytest.mark.parametrize('num_fracs', (1, 4))
|
|
6
7
|
@pytest.mark.parametrize('disable', (False, True))
|
|
7
|
-
def test_readme(
|
|
8
|
+
def test_readme(
|
|
9
|
+
num_fracs,
|
|
10
|
+
disable
|
|
11
|
+
):
|
|
8
12
|
|
|
9
13
|
# a single branch layer
|
|
10
14
|
|
|
@@ -20,7 +24,7 @@ def test_readme(disable):
|
|
|
20
24
|
|
|
21
25
|
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
22
26
|
|
|
23
|
-
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, disable = disable)
|
|
27
|
+
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, num_fracs = num_fracs, disable = disable)
|
|
24
28
|
|
|
25
29
|
# 1. wrap your branch function
|
|
26
30
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|