hyper-connections 0.3.16__tar.gz → 0.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17)
  1. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/PKG-INFO +1 -1
  2. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/manifold_constrained_hyper_connections.py +23 -8
  3. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/vit.py +4 -4
  4. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/pyproject.toml +1 -1
  5. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/tests/test_hyper_connections.py +11 -5
  6. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/.github/workflows/python-publish.yml +0 -0
  7. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/.github/workflows/test.yml +0 -0
  8. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/.gitignore +0 -0
  9. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/LICENSE +0 -0
  10. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/README.md +0 -0
  11. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper-connections.png +0 -0
  12. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/__init__.py +0 -0
  13. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/hyper_connections.py +0 -0
  14. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/hyper_connections_channel_first.py +0 -0
  15. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
  16. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
  17. {hyper_connections-0.3.16 → hyper_connections-0.4.1}/hyper_connections/residuals.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.3.16
3
+ Version: 0.4.1
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -21,6 +21,7 @@ s - residual streams
21
21
  t - residual streams + num branch inputs
22
22
  f - number of fractions (division of feature dimension space)
23
23
  v - number of views for branch input
24
+ p - proposals
24
25
  """
25
26
 
26
27
  # helper functions
@@ -212,7 +213,10 @@ class ManifoldConstrainedHyperConnections(Module):
212
213
  num_fracs = 1, # https://arxiv.org/abs/2503.14125
213
214
  sinkhorn_iters = 20,
214
215
  log_domain_sinkhorn = False,
216
+ residual_constraint_fn: Callable | None = None,
215
217
  forward_method_names: tuple[str, ...] = (),
218
+ num_dynamic_alpha_proposals = 1,
219
+
216
220
  ):
217
221
  """
218
222
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -256,6 +260,11 @@ class ManifoldConstrainedHyperConnections(Module):
256
260
  assert num_input_views >= 1
257
261
  self.num_input_views = num_input_views
258
262
 
263
+ # number of dynamic alpha proposals, for averaging Hres across proposals
264
+
265
+ self.has_dynamic_alpha_proposals = num_dynamic_alpha_proposals > 1
266
+ self.num_dynamic_alpha_proposals = num_dynamic_alpha_proposals
267
+
259
268
  # width connection
260
269
 
261
270
  init_alpha0 = torch.zeros((num_residual_streams_fracs, num_input_views_fracs))
@@ -263,7 +272,7 @@ class ManifoldConstrainedHyperConnections(Module):
263
272
 
264
273
  self.static_alpha = nn.Parameter(cat((init_alpha0, torch.eye(num_residual_streams_fracs)), dim = 1))
265
274
 
266
- self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams_fracs + num_input_views_fracs))
275
+ self.dynamic_alpha_fn = nn.Parameter(torch.zeros(num_dynamic_alpha_proposals, dim, num_residual_streams_fracs + num_input_views_fracs))
267
276
 
268
277
  self.pre_branch_scale = nn.Parameter(torch.ones(1) * 1e-2)
269
278
  self.residual_scale = nn.Parameter(torch.ones(1) * 1e-2)
@@ -280,10 +289,13 @@ class ManifoldConstrainedHyperConnections(Module):
280
289
 
281
290
  self.h_post_scale = nn.Parameter(torch.ones(()) * 1e-2)
282
291
 
283
- # sinkhorn related
292
+ # Hres constraint related
293
+ # by default is sinkhorn
284
294
 
285
- self.sinkhorn_iters = sinkhorn_iters
286
- self.log_domain_sinkhorn = log_domain_sinkhorn
295
+ self.residual_constraint_fn = default(
296
+ residual_constraint_fn,
297
+ partial(sinkhorn_knopps if not log_domain_sinkhorn else log_domain_sinkhorn_knopps, iters = sinkhorn_iters)
298
+ )
287
299
 
288
300
  # dropouts
289
301
 
@@ -346,7 +358,7 @@ class ManifoldConstrainedHyperConnections(Module):
346
358
 
347
359
  normed = normed.float()
348
360
 
349
- wc_weight = normed @ self.dynamic_alpha_fn.float()
361
+ wc_weight = einsum(normed, self.dynamic_alpha_fn.float(), '... d, p d e -> p ... e')
350
362
 
351
363
  pre_branch_scale = repeat(self.pre_branch_scale.float(), '1 -> s', s = self.num_fracs)
352
364
  residual_scale = repeat(self.residual_scale.float(), '1 -> s', s = self.num_fracs * streams)
@@ -366,12 +378,15 @@ class ManifoldConstrainedHyperConnections(Module):
366
378
 
367
379
  alpha_pre = alpha_pre.sigmoid()
368
380
 
369
- sinkhorn_fn = sinkhorn_knopps if not self.log_domain_sinkhorn else log_domain_sinkhorn_knopps
370
-
371
- alpha_residual = sinkhorn_fn(alpha_residual, self.sinkhorn_iters)
381
+ alpha_residual = self.residual_constraint_fn(alpha_residual)
372
382
 
373
383
  alpha = cat((alpha_pre, alpha_residual), dim = -1)
374
384
 
385
+ if self.has_dynamic_alpha_proposals:
386
+ alpha = reduce(alpha, 'p ... -> ...', 'mean')
387
+ else:
388
+ alpha = rearrange(alpha, '1 ... -> ...')
389
+
375
390
  alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
376
391
 
377
392
  # beta for weights from branch output back to residual streams
@@ -66,12 +66,12 @@ class Attention(Module):
66
66
  return self.to_out(out)
67
67
 
68
68
  class Transformer(Module):
69
- def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4):
69
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4, num_dynamic_alpha_proposals = 1):
70
70
  super().__init__()
71
71
  self.norm = nn.LayerNorm(dim)
72
72
  self.layers = ModuleList([])
73
73
 
74
- init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams)
74
+ init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams, num_dynamic_alpha_proposals = num_dynamic_alpha_proposals)
75
75
 
76
76
  for _ in range(depth):
77
77
  self.layers.append(ModuleList([
@@ -92,7 +92,7 @@ class Transformer(Module):
92
92
  return self.norm(x)
93
93
 
94
94
  class ViT(Module):
95
- def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4):
95
+ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4, num_dynamic_alpha_proposals = 1):
96
96
  super().__init__()
97
97
  image_height, image_width = pair(image_size)
98
98
  patch_height, patch_width = pair(patch_size)
@@ -117,7 +117,7 @@ class ViT(Module):
117
117
 
118
118
  self.dropout = nn.Dropout(emb_dropout)
119
119
 
120
- self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
120
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_residual_streams, num_dynamic_alpha_proposals)
121
121
 
122
122
  self.pool = pool
123
123
  self.to_latent = nn.Identity()
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.3.16"
3
+ version = "0.4.1"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1,11 +1,12 @@
1
1
  import pytest
2
+ param = pytest.mark.parametrize
2
3
 
3
4
  import torch
4
5
  from torch import nn
5
6
 
6
- @pytest.mark.parametrize('num_fracs', (1, 4))
7
- @pytest.mark.parametrize('disable', (False, True))
8
- @pytest.mark.parametrize('manifold_constrained', (False, True))
7
+ @param('num_fracs', (1, 4))
8
+ @param('disable', (False, True))
9
+ @param('manifold_constrained', (False, True))
9
10
  def test_readme(
10
11
  num_fracs,
11
12
  disable,
@@ -208,7 +209,11 @@ def test_mhc_dtype_restoration():
208
209
 
209
210
  assert residual.dtype == torch.half
210
211
 
211
- def test_mhc_vit():
212
+ @param('num_dynamic_alpha_proposals', (1, 2))
213
+ def test_mhc_vit(
214
+ num_dynamic_alpha_proposals
215
+ ):
216
+
212
217
  from hyper_connections.vit import ViT
213
218
 
214
219
  v = ViT(
@@ -221,7 +226,8 @@ def test_mhc_vit():
221
226
  mlp_dim = 2048,
222
227
  dropout = 0.1,
223
228
  emb_dropout = 0.1,
224
- num_residual_streams = 4
229
+ num_residual_streams = 4,
230
+ num_dynamic_alpha_proposals = num_dynamic_alpha_proposals
225
231
  )
226
232
 
227
233
  img = torch.randn(1, 3, 256, 256)