hyper-connections 0.1.14__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,13 +5,13 @@ from functools import partial
5
5
  from random import randrange
6
6
 
7
7
  import torch
8
- from torch import nn
9
- from torch.nn import Module
8
+ from torch import nn, cat
10
9
  import torch.nn.functional as F
10
+ from torch.nn import Module, Sequential
11
11
  from torch.utils._pytree import tree_flatten, tree_unflatten
12
12
 
13
13
  from einops import rearrange, repeat, reduce, einsum
14
- from einops.layers.torch import Reduce
14
+ from einops.layers.torch import Rearrange, Reduce
15
15
 
16
16
  """
17
17
  ein notation:
@@ -19,6 +19,7 @@ b - batch
19
19
  d - feature dimension
20
20
  s - residual streams
21
21
  t - residual streams + num branch inputs
22
+ f - number of fractions (division of feature dimension space)
22
23
  v - number of views for branch input
23
24
  """
24
25
 
@@ -27,6 +28,9 @@ v - number of views for branch input
27
28
  def exists(v):
28
29
  return v is not None
29
30
 
31
+ def divisible_by(num, den):
32
+ return (num % den) == 0
33
+
30
34
  def default(v, d):
31
35
  return v if exists(v) else d
32
36
 
@@ -38,8 +42,12 @@ def add(x, y):
38
42
 
39
43
  # main functions
40
44
 
41
- def get_expand_reduce_stream_functions(num_streams, add_stream_embed = False, dim = None, disable = False):
42
-
45
+ def get_expand_reduce_stream_functions(
46
+ num_streams,
47
+ add_stream_embed = False,
48
+ dim = None,
49
+ disable = False
50
+ ):
43
51
  if num_streams == 1 or disable:
44
52
  return (nn.Identity(), nn.Identity())
45
53
 
@@ -54,11 +62,16 @@ def get_expand_reduce_stream_functions(num_streams, add_stream_embed = False, di
54
62
 
55
63
  return expand_fn, reduce_fn
56
64
 
57
- def get_init_and_expand_reduce_stream_functions(num_streams, dim = None, add_stream_embed = False, disable = False):
58
-
65
+ def get_init_and_expand_reduce_stream_functions(
66
+ num_streams,
67
+ num_fracs = 1,
68
+ dim = None,
69
+ add_stream_embed = False,
70
+ disable = False
71
+ ):
59
72
  hyper_conn_klass = HyperConnections if not disable else Residual
60
73
 
61
- init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
74
+ init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs)
62
75
  expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, add_stream_embed = add_stream_embed, dim = dim, disable = disable)
63
76
 
64
77
  if exists(dim):
@@ -93,13 +106,24 @@ class Residual(Module):
93
106
  self.branch = branch
94
107
  self.residual_transform = default(residual_transform, nn.Identity())
95
108
 
96
- def width_connection(self, residuals):
109
+ def width_connection(
110
+ self,
111
+ residuals
112
+ ):
97
113
  return residuals, residuals, dict()
98
114
 
99
- def depth_connection(self, branch_output, residuals):
115
+ def depth_connection(
116
+ self,
117
+ branch_output,
118
+ residuals,
119
+
120
+ ):
100
121
  return branch_output + self.residual_transform(residuals)
101
122
 
102
- def decorate_branch(self, branch: Callable):
123
+ def decorate_branch(
124
+ self,
125
+ branch: Callable
126
+ ):
103
127
  assert not exists(self.branch), 'branch was already wrapped on init'
104
128
 
105
129
  def forward_and_add_residual(residual, *args, **kwargs):
@@ -113,7 +137,12 @@ class Residual(Module):
113
137
 
114
138
  return forward_and_add_residual
115
139
 
116
- def forward(self, residuals, *branch_args, **branch_kwargs):
140
+ def forward(
141
+ self,
142
+ residuals,
143
+ *branch_args,
144
+ **branch_kwargs
145
+ ):
117
146
 
118
147
  branch_input, residuals, residual_kwargs = self.width_connection(residuals)
119
148
 
@@ -145,9 +174,10 @@ class HyperConnections(Module):
145
174
  channel_first = False,
146
175
  dropout = 0.,
147
176
  residual_transform: Module | None = None, # to support resnet blocks where dimension in is not equal to dimension out - usually a residual conv
148
- add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
149
- num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
150
- depth_residual_fn = add
177
+ add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
178
+ num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
179
+ depth_residual_fn = add,
180
+ num_fracs = 1 # https://arxiv.org/abs/2503.14125
151
181
  ):
152
182
  """
153
183
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -160,13 +190,34 @@ class HyperConnections(Module):
160
190
 
161
191
  self.act = nn.Tanh() if tanh else nn.Identity()
162
192
 
163
- self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
193
+ # frac-connections paper - num_fracs > 1 will be the `m` in their paper https://arxiv.org/abs/2503.14125
194
+
195
+ assert num_fracs >= 1
196
+
197
+ self.num_fracs = num_fracs
198
+ self.has_fracs = num_fracs > 1
199
+
200
+ self.split_fracs = Rearrange('b ... (f d) -> b ... f d', f = num_fracs)
201
+ self.merge_fracs = Rearrange('b ... f d -> b ... (f d)')
202
+
203
+ assert divisible_by(dim, num_fracs), f'feature dimension ({dim}) must be divisible by the `num_fracs` ({num_fracs})'
204
+
205
+ dim //= num_fracs # the effective dimension handled here is the feature dimension divided by the number of fractions
206
+
207
+ # they used layernorm in paper, but rmsnorm is fine given what we know now
208
+
209
+ self.norm = RMSNorm(dim)
164
210
 
165
211
  assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
166
212
 
167
213
  self.num_residual_streams = num_residual_streams
168
214
  init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
169
215
 
216
+ # handle the parameter dimensions, which may require (num_residuals x num_fractions) - generalizing hyper + frac connections
217
+
218
+ num_residual_streams_fracs = num_residual_streams * num_fracs
219
+ num_input_views_fracs = num_input_views * num_fracs
220
+
170
221
  # width num residual streams
171
222
 
172
223
  assert num_input_views >= 1
@@ -174,12 +225,12 @@ class HyperConnections(Module):
174
225
 
175
226
  # width connection
176
227
 
177
- init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
228
+ init_alpha0 = torch.zeros((num_residual_streams_fracs, num_input_views_fracs))
178
229
  init_alpha0[init_residual_index, :] = 1.
179
230
 
180
- self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
231
+ self.static_alpha = nn.Parameter(cat((init_alpha0, torch.eye(num_residual_streams_fracs)), dim = 1))
181
232
 
182
- self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
233
+ self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams_fracs + num_input_views_fracs))
183
234
  self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
184
235
 
185
236
  # depth connection related (beta)
@@ -187,8 +238,11 @@ class HyperConnections(Module):
187
238
  self.add_branch_out_to_residual = add_branch_out_to_residual
188
239
 
189
240
  if add_branch_out_to_residual:
190
- self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
191
- self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
241
+ self.static_beta = nn.Parameter(torch.ones(num_residual_streams_fracs))
242
+
243
+ dynamic_beta_shape = (dim,) if num_fracs == 1 else (dim, num_fracs) # preserve backwards compat
244
+ self.dynamic_beta_fn = nn.Parameter(torch.zeros(dynamic_beta_shape))
245
+
192
246
  self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
193
247
 
194
248
  # dropouts
@@ -209,16 +263,30 @@ class HyperConnections(Module):
209
263
 
210
264
  self.depth_residual_fn = depth_residual_fn
211
265
 
212
- def width_connection(self, residuals):
266
+ def width_connection(
267
+ self,
268
+ residuals
269
+ ):
270
+ streams = self.num_residual_streams
213
271
 
214
272
  maybe_transformed_residuals = self.residual_transform(residuals)
215
273
 
216
274
  # width connection
217
275
 
276
+ # handle channel first
277
+
218
278
  if self.channel_first:
219
279
  residuals = rearrange(residuals, 'b d ... -> b ... d')
220
280
 
221
- residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = self.num_residual_streams)
281
+ # split out fractions
282
+
283
+ residuals = self.split_fracs(residuals)
284
+
285
+ # split out streams
286
+
287
+ residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = streams)
288
+
289
+ # norm
222
290
 
223
291
  normed = self.norm(residuals)
224
292
 
@@ -226,7 +294,12 @@ class HyperConnections(Module):
226
294
 
227
295
  wc_weight = self.act(normed @ self.dynamic_alpha_fn)
228
296
  dynamic_alpha = wc_weight * self.dynamic_alpha_scale
229
- alpha = dynamic_alpha + self.static_alpha
297
+
298
+ static_alpha = rearrange(self.static_alpha, '(f s) d -> f s d', s = streams)
299
+
300
+ alpha = dynamic_alpha + static_alpha
301
+
302
+ alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
230
303
 
231
304
  # beta for weights from branch output back to residual streams
232
305
 
@@ -234,10 +307,17 @@ class HyperConnections(Module):
234
307
 
235
308
  if self.add_branch_out_to_residual:
236
309
  dc_weight = self.act(normed @ self.dynamic_beta_fn)
310
+
311
+ if not self.has_fracs:
312
+ dc_weight = rearrange(dc_weight, '... -> ... 1')
313
+
237
314
  dynamic_beta = dc_weight * self.dynamic_beta_scale
238
- beta = dynamic_beta + self.static_beta
239
315
 
240
- mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
316
+ static_beta = rearrange(self.static_beta, '... (s f) -> ... s f', s = streams)
317
+
318
+ beta = dynamic_beta + static_beta
319
+
320
+ mix_h = einsum(alpha, residuals, '... f1 s f2 t, ... f1 s d -> ... f2 t d')
241
321
 
242
322
  if self.num_input_views == 1:
243
323
  branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
@@ -248,19 +328,40 @@ class HyperConnections(Module):
248
328
  if self.channel_first:
249
329
  branch_input = rearrange(branch_input, 'b ... d -> b d ...')
250
330
 
331
+ # maybe merge fractions back
332
+
333
+ branch_input = self.merge_fracs(branch_input)
334
+
251
335
  return branch_input, maybe_transformed_residuals, dict(beta = beta)
252
336
 
253
- def depth_connection(self, branch_output, residuals, *, beta):
337
+ def depth_connection(
338
+ self,
339
+ branch_output,
340
+ residuals,
341
+ *,
342
+ beta
343
+ ):
254
344
  assert self.add_branch_out_to_residual
255
345
 
346
+ # maybe split fractions
347
+
348
+ branch_output = self.split_fracs(branch_output)
349
+
256
350
  # 'depth' connection
257
351
 
258
352
  if self.channel_first:
259
353
  branch_output = rearrange(branch_output, 'b d ... -> b ... d')
260
354
 
261
- output = einsum(branch_output, beta, 'b ... d, b ... s -> b ... s d')
355
+ output = einsum(branch_output, beta, 'b ... f1 d, b ... f1 s f2 -> b ... f2 s d')
356
+
262
357
  output = rearrange(output, 'b ... s d -> (b s) ... d')
263
358
 
359
+ # merge back fractions
360
+
361
+ output = self.merge_fracs(output)
362
+
363
+ # channel first
364
+
264
365
  if self.channel_first:
265
366
  output = rearrange(output, 'b ... d -> b d ...')
266
367
 
@@ -268,7 +369,10 @@ class HyperConnections(Module):
268
369
 
269
370
  return self.dropout(residuals)
270
371
 
271
- def decorate_branch(self, branch: Callable):
372
+ def decorate_branch(
373
+ self,
374
+ branch: Callable
375
+ ):
272
376
  assert not exists(self.branch), 'branch was already wrapped on init'
273
377
 
274
378
  def forward_and_add_residual(residual, *args, **kwargs):
@@ -282,7 +386,12 @@ class HyperConnections(Module):
282
386
 
283
387
  return forward_and_add_residual
284
388
 
285
- def forward(self, residuals, *branch_args, **branch_kwargs):
389
+ def forward(
390
+ self,
391
+ residuals,
392
+ *branch_args,
393
+ **branch_kwargs
394
+ ):
286
395
 
287
396
  branch_input, residuals, residual_kwargs = self.width_connection(residuals)
288
397
 
@@ -13,21 +13,23 @@ class GRUGatedResidual(Module):
13
13
  self.gru = nn.GRUCell(dim, dim)
14
14
 
15
15
  def forward(self, x, residual):
16
- x, ps = pack([x], '* d')
16
+ x, packed_shape = pack([x], '* d')
17
17
  residual, _ = pack([residual], '* d')
18
18
 
19
19
  output = self.gru(x, residual)
20
20
 
21
- output, = unpack(output, ps, '* d')
21
+ output, = unpack(output, packed_shape, '* d')
22
22
  return output
23
23
 
24
24
  class GatedResidual(Module):
25
25
  def __init__(
26
26
  self,
27
- dim
27
+ dim,
28
+ fine_gate = False
28
29
  ):
29
30
  super().__init__()
30
- self.to_learned_mix = nn.Linear(dim * 2, dim)
31
+
32
+ self.to_learned_mix = nn.Linear(dim * 2, dim if fine_gate else 1)
31
33
 
32
34
  def forward(self, x, residual):
33
35
  x_and_residual, _ = pack([x, residual], 'b n *')
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.1.14
3
+ Version: 0.2.0
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -141,6 +141,12 @@ To compare hyper connections to plain residual without changing the code, just p
141
141
  get_init_and_expand_reduce_stream_functions(4, disable = True)
142
142
  ```
143
143
 
144
+ To use the fractionated feature dimensions proposed in [a follow-up paper](https://arxiv.org/abs/2503.14125) by the same authors, just instantiate with `num_fracs` greater than `1`, like so
145
+
146
+ ```python
147
+ get_init_and_expand_reduce_stream_functions(1, num_fracs = 4) # also allows you to mix streams and fractions of feature dimension
148
+ ```
149
+
144
150
  ## Citation
145
151
 
146
152
  ```bibtex
@@ -160,3 +166,14 @@ get_init_and_expand_reduce_stream_functions(4, disable = True)
160
166
  url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
161
167
  }
162
168
  ```
169
+
170
+ ```bibtex
171
+ @article{Zhu2025FracConnectionsFE,
172
+ title = {Frac-Connections: Fractional Extension of Hyper-Connections},
173
+ author = {Defa Zhu and Hongzhi Huang and Jundong Zhou and Zihao Huang and Yutao Zeng and Banggu Wu and Qiyang Min and Xun Zhou},
174
+ journal = {ArXiv},
175
+ year = {2025},
176
+ volume = {abs/2503.14125},
177
+ url = {https://api.semanticscholar.org/CorpusID:277104144}
178
+ }
179
+ ```
@@ -0,0 +1,10 @@
1
+ hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
2
+ hyper_connections/hyper_connections.py,sha256=iMKJBJioaPus6QcR50nwdDrIs96P1liyktC5RE6Drds,14953
3
+ hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
4
+ hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
5
+ hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
6
+ hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
7
+ hyper_connections-0.2.0.dist-info/METADATA,sha256=Ypr8d84gZhK_SB-nFhOm5s3VP_-1if4JB39f_vBf6oc,5966
8
+ hyper_connections-0.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
9
+ hyper_connections-0.2.0.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
10
+ hyper_connections-0.2.0.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
2
- hyper_connections/hyper_connections.py,sha256=vpipBRUGgYQ2qLBtT4Ws-myJYVdkQDkN3IkpTMkxRxc,12485
3
- hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
4
- hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
5
- hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
6
- hyper_connections/residuals.py,sha256=qapN4lt51qNWKa5nX7whN4xcNORxMdr3bdUwIMQPdpQ,853
7
- hyper_connections-0.1.14.dist-info/METADATA,sha256=7Agg3rGvMYkZEyX3n9yJPO4cNm8-9a33afV4Yc8r7WA,5231
8
- hyper_connections-0.1.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
9
- hyper_connections-0.1.14.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
10
- hyper_connections-0.1.14.dist-info/RECORD,,