hyper-connections 0.1.15__tar.gz → 0.2.1__tar.gz

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hyper-connections
-Version: 0.1.15
+Version: 0.2.1
 Summary: Hyper-Connections
 Project-URL: Homepage, https://pypi.org/project/hyper-connections/
 Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -141,6 +141,12 @@ To compare hyper connections to plain residual without changing the code, just p
 get_init_and_expand_reduce_stream_functions(4, disable = True)
 ```
 
+To use the fractionated feature dimensions proposed in [a follow-up paper](https://arxiv.org/abs/2503.14125) by the same authors, just instantiate with `num_fracs` greater than `1`, like so
+
+```python
+get_init_and_expand_reduce_stream_functions(1, num_fracs = 4) # also allows you to mix streams and fractions of the feature dimension
+```
+
 ## Citation
 
 ```bibtex
@@ -160,3 +166,14 @@ get_init_and_expand_reduce_stream_functions(4, disable = True)
     url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
 }
 ```
+
+```bibtex
+@article{Zhu2025FracConnectionsFE,
+    title   = {Frac-Connections: Fractional Extension of Hyper-Connections},
+    author  = {Defa Zhu and Hongzhi Huang and Jundong Zhou and Zihao Huang and Yutao Zeng and Banggu Wu and Qiyang Min and Xun Zhou},
+    journal = {ArXiv},
+    year    = {2025},
+    volume  = {abs/2503.14125},
+    url     = {https://api.semanticscholar.org/CorpusID:277104144}
+}
+```
@@ -100,6 +100,12 @@ To compare hyper connections to plain residual without changing the code, just p
 get_init_and_expand_reduce_stream_functions(4, disable = True)
 ```
 
+To use the fractionated feature dimensions proposed in [a follow-up paper](https://arxiv.org/abs/2503.14125) by the same authors, just instantiate with `num_fracs` greater than `1`, like so
+
+```python
+get_init_and_expand_reduce_stream_functions(1, num_fracs = 4) # also allows you to mix streams and fractions of the feature dimension
+```
+
 ## Citation
 
 ```bibtex
@@ -119,3 +125,14 @@ get_init_and_expand_reduce_stream_functions(4, disable = True)
     url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
 }
 ```
+
+```bibtex
+@article{Zhu2025FracConnectionsFE,
+    title   = {Frac-Connections: Fractional Extension of Hyper-Connections},
+    author  = {Defa Zhu and Hongzhi Huang and Jundong Zhou and Zihao Huang and Yutao Zeng and Banggu Wu and Qiyang Min and Xun Zhou},
+    journal = {ArXiv},
+    year    = {2025},
+    volume  = {abs/2503.14125},
+    url     = {https://api.semanticscholar.org/CorpusID:277104144}
+}
+```
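
Since the README snippet only shows the constructor call, here is a fuller end-to-end sketch of how `num_fracs` composes with the existing stream API. It mirrors the wrapping pattern in the package's test (further down in this diff); the `nn.Linear` branch and all dimensions are illustrative, not from the source.

```python
import torch
from torch import nn
from hyper_connections import get_init_and_expand_reduce_stream_functions

# 2 residual streams, each feature vector split into 4 fractions
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(2, num_fracs = 4)

dim = 512                        # must be divisible by num_fracs
branch = nn.Linear(dim, dim)     # illustrative branch module

hyper_conn = init_hyper_conn(dim = dim, branch = branch)

x = torch.randn(1, 1024, dim)

x = expand_stream(x)             # one copy of the input per stream, folded into batch
x = hyper_conn(x)                # width connection -> branch -> depth connection
x = reduce_stream(x)             # sum the streams back down

assert x.shape == (1, 1024, dim)
```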
@@ -5,13 +5,13 @@ from functools import partial
 from random import randrange
 
 import torch
-from torch import nn
-from torch.nn import Module
+from torch import nn, cat
 import torch.nn.functional as F
+from torch.nn import Module, Sequential
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from einops import rearrange, repeat, reduce, einsum
-from einops.layers.torch import Reduce
+from einops.layers.torch import Rearrange, Reduce
 
 """
 ein notation:
@@ -19,6 +19,7 @@ b - batch
 d - feature dimension
 s - residual streams
 t - residual streams + num branch inputs
+f - number of fractions (division of feature dimension space)
 v - number of views for branch input
 """
 
@@ -27,6 +28,9 @@ v - number of views for branch input
 def exists(v):
     return v is not None
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def default(v, d):
     return v if exists(v) else d
 
@@ -38,8 +42,12 @@ def add(x, y):
 
 # main functions
 
-def get_expand_reduce_stream_functions(num_streams, add_stream_embed = False, dim = None, disable = False):
-
+def get_expand_reduce_stream_functions(
+    num_streams,
+    add_stream_embed = False,
+    dim = None,
+    disable = False
+):
     if num_streams == 1 or disable:
         return (nn.Identity(), nn.Identity())
 
@@ -54,11 +62,18 @@ def get_expand_reduce_stream_functions(num_streams, add_stream_embed = False, di
 
     return expand_fn, reduce_fn
 
-def get_init_and_expand_reduce_stream_functions(num_streams, dim = None, add_stream_embed = False, disable = False):
+def get_init_and_expand_reduce_stream_functions(
+    num_streams,
+    num_fracs = 1,
+    dim = None,
+    add_stream_embed = False,
+    disable = None
+):
+    disable = default(disable, num_streams == 1 and num_fracs == 1)
 
    hyper_conn_klass = HyperConnections if not disable else Residual
 
-    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
+    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs)
    expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, add_stream_embed = add_stream_embed, dim = dim, disable = disable)
 
    if exists(dim):
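
Note the changed semantics in the hunk above: `disable` now defaults to `None` and is resolved, so the plain `Residual` wrapper is only substituted when there is genuinely nothing to mix, i.e. a single stream and a single fraction. A small sketch of the resulting behavior, using the return convention visible in the source above:

```python
from hyper_connections import get_init_and_expand_reduce_stream_functions

# resolved disable = (num_streams == 1 and num_fracs == 1)
init_a, *_ = get_init_and_expand_reduce_stream_functions(1)                 # plain Residual
init_b, *_ = get_init_and_expand_reduce_stream_functions(1, num_fracs = 4)  # HyperConnections (fractions only)
init_c, *_ = get_init_and_expand_reduce_stream_functions(4)                 # HyperConnections (streams only)
```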
@@ -93,13 +108,24 @@ class Residual(Module):
         self.branch = branch
         self.residual_transform = default(residual_transform, nn.Identity())
 
-    def width_connection(self, residuals):
+    def width_connection(
+        self,
+        residuals
+    ):
         return residuals, residuals, dict()
 
-    def depth_connection(self, branch_output, residuals):
+    def depth_connection(
+        self,
+        branch_output,
+        residuals,
+
+    ):
         return branch_output + self.residual_transform(residuals)
 
-    def decorate_branch(self, branch: Callable):
+    def decorate_branch(
+        self,
+        branch: Callable
+    ):
         assert not exists(self.branch), 'branch was already wrapped on init'
 
         def forward_and_add_residual(residual, *args, **kwargs):
@@ -113,7 +139,12 @@ class Residual(Module):
 
         return forward_and_add_residual
 
-    def forward(self, residuals, *branch_args, **branch_kwargs):
+    def forward(
+        self,
+        residuals,
+        *branch_args,
+        **branch_kwargs
+    ):
 
         branch_input, residuals, residual_kwargs = self.width_connection(residuals)
 
@@ -145,9 +176,10 @@ class HyperConnections(Module):
         channel_first = False,
         dropout = 0.,
         residual_transform: Module | None = None, # to support resnet blocks where dimension in is not equal to dimension out - usually a residual conv
-        add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
-        num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
-        depth_residual_fn = add
+        add_branch_out_to_residual = True,   # will disable depth connections (weighted residual sum with beta) if set False
+        num_input_views = 1,                 # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
+        depth_residual_fn = add,
+        num_fracs = 1                        # https://arxiv.org/abs/2503.14125
     ):
         """
         Appendix J, Algorithm 2 in - https://arxiv.org/abs/2409.19606
@@ -160,13 +192,34 @@ class HyperConnections(Module):
 
         self.act = nn.Tanh() if tanh else nn.Identity()
 
-        self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
+        # frac-connections paper - num_fracs > 1 will be the `m` in their paper https://arxiv.org/abs/2503.14125
+
+        assert num_fracs >= 1
+
+        self.num_fracs = num_fracs
+        self.has_fracs = num_fracs > 1
+
+        self.split_fracs = Rearrange('b ... (f d) -> b ... f d', f = num_fracs)
+        self.merge_fracs = Rearrange('b ... f d -> b ... (f d)')
+
+        assert divisible_by(dim, num_fracs), f'feature dimension ({dim}) must be divisible by `num_fracs` ({num_fracs})'
+
+        dim //= num_fracs # the effective dimension handled within is the feature dimension divided by the number of fractions
+
+        # they used layernorm in the paper, but rmsnorm is fine given what we know now
+
+        self.norm = RMSNorm(dim)
 
         assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
 
         self.num_residual_streams = num_residual_streams
         init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
 
+        # handle the parameter dimensions, which may require (num_residuals x num_fractions) - generalizing hyper + frac connections
+
+        num_residual_streams_fracs = num_residual_streams * num_fracs
+        num_input_views_fracs = num_input_views * num_fracs
+
        # width num residual streams
 
        assert num_input_views >= 1
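
The two `Rearrange` layers registered in the hunk above do all of the fraction bookkeeping. A standalone shape check of what they compute, with illustrative dimensions:

```python
import torch
from einops.layers.torch import Rearrange

num_fracs = 4
split_fracs = Rearrange('b ... (f d) -> b ... f d', f = num_fracs)
merge_fracs = Rearrange('b ... f d -> b ... (f d)')

x = torch.randn(2, 1024, 512)                         # (batch, seq, features)
assert split_fracs(x).shape == (2, 1024, 4, 128)      # feature dim carved into 4 fractions
assert torch.equal(merge_fracs(split_fracs(x)), x)    # exact round trip
```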
@@ -174,12 +227,12 @@ class HyperConnections(Module):
 
         # width connection
 
-        init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
+        init_alpha0 = torch.zeros((num_residual_streams_fracs, num_input_views_fracs))
         init_alpha0[init_residual_index, :] = 1.
 
-        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+        self.static_alpha = nn.Parameter(cat((init_alpha0, torch.eye(num_residual_streams_fracs)), dim = 1))
 
-        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
+        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams_fracs + num_input_views_fracs))
         self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
 
         # depth connection related (beta)
@@ -187,8 +240,11 @@ class HyperConnections(Module):
         self.add_branch_out_to_residual = add_branch_out_to_residual
 
         if add_branch_out_to_residual:
-            self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
-            self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
+            self.static_beta = nn.Parameter(torch.ones(num_residual_streams_fracs))
+
+            dynamic_beta_shape = (dim,) if num_fracs == 1 else (dim, num_fracs) # preserve backwards compat
+            self.dynamic_beta_fn = nn.Parameter(torch.zeros(dynamic_beta_shape))
+
             self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
 
         # dropouts
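
The `dynamic_beta_shape` guard keeps the `num_fracs == 1` parameter identical to the 0.1.x releases, and the runtime rearrange further down (`'... -> ... 1'`) restores the missing fraction axis. A sketch of that bridge with illustrative tensors, assuming the normed residuals carry a size-1 fraction axis at that point:

```python
import torch
from einops import rearrange

b, n, s, d = 2, 16, 4, 64
normed = torch.randn(b, n, 1, s, d)                 # num_fracs == 1, so the frac axis is size 1

dynamic_beta_fn = torch.zeros(d)                    # legacy (dim,) parameter shape, unchanged
dc_weight = normed @ dynamic_beta_fn                # matrix-vector product drops d -> (b, n, 1, s)
dc_weight = rearrange(dc_weight, '... -> ... 1')    # add back the fraction axis expected downstream
assert dc_weight.shape == (b, n, 1, s, 1)
```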
@@ -209,16 +265,30 @@ class HyperConnections(Module):
 
         self.depth_residual_fn = depth_residual_fn
 
-    def width_connection(self, residuals):
+    def width_connection(
+        self,
+        residuals
+    ):
+        streams = self.num_residual_streams
 
         maybe_transformed_residuals = self.residual_transform(residuals)
 
         # width connection
 
+        # handle channel first
+
         if self.channel_first:
             residuals = rearrange(residuals, 'b d ... -> b ... d')
 
-        residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = self.num_residual_streams)
+        # split out fractions
+
+        residuals = self.split_fracs(residuals)
+
+        # split out streams
+
+        residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = streams)
+
+        # norm
 
         normed = self.norm(residuals)
 
@@ -226,7 +296,12 @@ class HyperConnections(Module):
 
         wc_weight = self.act(normed @ self.dynamic_alpha_fn)
         dynamic_alpha = wc_weight * self.dynamic_alpha_scale
-        alpha = dynamic_alpha + self.static_alpha
+
+        static_alpha = rearrange(self.static_alpha, '(f s) d -> f s d', s = streams)
+
+        alpha = dynamic_alpha + static_alpha
+
+        alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
 
         # beta for weights from branch output back to residual streams
 
@@ -234,10 +309,17 @@ class HyperConnections(Module):
 
         if self.add_branch_out_to_residual:
             dc_weight = self.act(normed @ self.dynamic_beta_fn)
+
+            if not self.has_fracs:
+                dc_weight = rearrange(dc_weight, '... -> ... 1')
+
             dynamic_beta = dc_weight * self.dynamic_beta_scale
-            beta = dynamic_beta + self.static_beta
 
-        mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
+            static_beta = rearrange(self.static_beta, '... (s f) -> ... s f', s = streams)
+
+            beta = dynamic_beta + static_beta
+
+        mix_h = einsum(alpha, residuals, '... f1 s f2 t, ... f1 s d -> ... f2 t d')
 
         if self.num_input_views == 1:
             branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
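
The extended einsum above is the heart of the generalization: `alpha` now routes every (fraction, stream) pair of the residuals to every (fraction, branch-input-or-stream) pair. A standalone shape check, with all dimensions illustrative:

```python
import torch
from einops import einsum

b, n = 2, 16               # batch, sequence
s, f, v = 4, 2, 1          # streams, fractions, branch input views
t = v + s                  # branch inputs + residual streams

alpha     = torch.randn(b, n, f, s, f, t)   # (batch, seq, fracs1, streams, fracs2, t)
residuals = torch.randn(b, n, f, s, 64)     # 64 = per-fraction feature dimension

mix_h = einsum(alpha, residuals, '... f1 s f2 t, ... f1 s d -> ... f2 t d')
assert mix_h.shape == (b, n, f, t, 64)      # every fraction/stream feeds every fraction/output
```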
@@ -248,19 +330,40 @@ class HyperConnections(Module):
 
         if self.channel_first:
             branch_input = rearrange(branch_input, 'b ... d -> b d ...')
 
+        # maybe merge fractions back
+
+        branch_input = self.merge_fracs(branch_input)
+
         return branch_input, maybe_transformed_residuals, dict(beta = beta)
 
-    def depth_connection(self, branch_output, residuals, *, beta):
+    def depth_connection(
+        self,
+        branch_output,
+        residuals,
+        *,
+        beta
+    ):
         assert self.add_branch_out_to_residual
 
+        # maybe split fractions
+
+        branch_output = self.split_fracs(branch_output)
+
         # 'depth' connection
 
         if self.channel_first:
             branch_output = rearrange(branch_output, 'b d ... -> b ... d')
 
-        output = einsum(branch_output, beta, 'b ... d, b ... s -> b ... s d')
+        output = einsum(branch_output, beta, 'b ... f1 d, b ... f1 s f2 -> b ... f2 s d')
+
         output = rearrange(output, 'b ... s d -> (b s) ... d')
 
+        # merge fractions back
+
+        output = self.merge_fracs(output)
+
+        # channel first
+
         if self.channel_first:
             output = rearrange(output, 'b ... d -> b d ...')
 
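
The depth-connection einsum is generalized the same way, with `beta` carrying a `(streams, fracs)` weighting for each fraction of the branch output; `merge_fracs` then restores the full feature dimension. A matching shape sketch, dimensions illustrative:

```python
import torch
from einops import einsum, rearrange

b, n, s, f, d = 2, 16, 4, 2, 64

branch_output = torch.randn(b, n, f, d)     # branch output with fractions split back out
beta          = torch.randn(b, n, f, s, f)  # (batch, seq, fracs1, streams, fracs2)

output = einsum(branch_output, beta, 'b ... f1 d, b ... f1 s f2 -> b ... f2 s d')
output = rearrange(output, 'b ... s d -> (b s) ... d')   # fold streams back into the batch
assert output.shape == (b * s, n, f, d)                  # merge_fracs then gives (b*s, n, f*d)
```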
@@ -268,7 +371,10 @@ class HyperConnections(Module):
 
         return self.dropout(residuals)
 
-    def decorate_branch(self, branch: Callable):
+    def decorate_branch(
+        self,
+        branch: Callable
+    ):
         assert not exists(self.branch), 'branch was already wrapped on init'
 
         def forward_and_add_residual(residual, *args, **kwargs):
@@ -282,7 +388,12 @@ class HyperConnections(Module):
 
         return forward_and_add_residual
 
-    def forward(self, residuals, *branch_args, **branch_kwargs):
+    def forward(
+        self,
+        residuals,
+        *branch_args,
+        **branch_kwargs
+    ):
 
         branch_input, residuals, residual_kwargs = self.width_connection(residuals)
 
@@ -49,7 +49,9 @@ def get_expand_reduce_stream_functions(num_streams, disable = False):
 
     return expand_fn, reduce_fn
 
-def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
+def get_init_and_expand_reduce_stream_functions(num_streams, disable = None):
+
+    disable = default(disable, num_streams == 1)
 
     hyper_conn_klass = HyperConnections if not disable else Residual
 
@@ -50,7 +50,9 @@ def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
 
     return expand_fn, reduce_fn
 
-def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
+def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = None):
+
+    disable = default(disable, num_streams == 1)
 
     hyper_conn_klass = HyperConnections if not disable else Residual
 
@@ -41,7 +41,9 @@ def get_expand_reduce_stream_functions(num_streams, disable = False):
 
     return expand_fn, reduce_fn
 
-def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
+def get_init_and_expand_reduce_stream_functions(num_streams, disable = None):
+
+    disable = default(disable, num_streams == 1)
 
     hyper_conn_klass = HyperConnections if not disable else Residual
 
@@ -1,6 +1,6 @@
 [project]
 name = "hyper-connections"
-version = "0.1.15"
+version = "0.2.1"
 description = "Hyper-Connections"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -3,8 +3,12 @@ import pytest
 import torch
 from torch import nn
 
+@pytest.mark.parametrize('num_fracs', (1, 4))
 @pytest.mark.parametrize('disable', (False, True))
-def test_readme(disable):
+def test_readme(
+    num_fracs,
+    disable
+):
 
     # a single branch layer
 
@@ -20,7 +24,7 @@ def test_readme(disable):
 
     from hyper_connections import get_init_and_expand_reduce_stream_functions
 
-    init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, disable = disable)
+    init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, num_fracs = num_fracs, disable = disable)
 
     # 1. wrap your branch function
 