hyper-connections 0.3.16__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17)
  1. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/PKG-INFO +1 -1
  2. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/hyper_connections/manifold_constrained_hyper_connections.py +14 -2
  3. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/hyper_connections/vit.py +4 -4
  4. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/pyproject.toml +1 -1
  5. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/tests/test_hyper_connections.py +11 -5
  6. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/.github/workflows/python-publish.yml +0 -0
  7. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/.github/workflows/test.yml +0 -0
  8. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/.gitignore +0 -0
  9. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/LICENSE +0 -0
  10. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/README.md +0 -0
  11. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/hyper-connections.png +0 -0
  12. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/hyper_connections/__init__.py +0 -0
  13. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/hyper_connections/hyper_connections.py +0 -0
  14. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/hyper_connections/hyper_connections_channel_first.py +0 -0
  15. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
  16. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
  17. {hyper_connections-0.3.16 → hyper_connections-0.4.0}/hyper_connections/residuals.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: hyper-connections
- Version: 0.3.16
+ Version: 0.4.0
  Summary: Hyper-Connections
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -21,6 +21,7 @@ s - residual streams
  t - residual streams + num branch inputs
  f - number of fractions (division of feature dimension space)
  v - number of views for branch input
+ p - proposals
  """
 
  # helper functions
@@ -213,6 +214,7 @@ class ManifoldConstrainedHyperConnections(Module):
  sinkhorn_iters = 20,
  log_domain_sinkhorn = False,
  forward_method_names: tuple[str, ...] = (),
+ num_dynamic_alpha_proposals = 1
  ):
  """
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -256,6 +258,11 @@ class ManifoldConstrainedHyperConnections(Module):
  assert num_input_views >= 1
  self.num_input_views = num_input_views
 
+ # number of dynamic alpha proposals, for averaging Hres across proposals
+
+ self.has_dynamic_alpha_proposals = num_dynamic_alpha_proposals > 1
+ self.num_dynamic_alpha_proposals = num_dynamic_alpha_proposals
+
  # width connection
 
  init_alpha0 = torch.zeros((num_residual_streams_fracs, num_input_views_fracs))
@@ -263,7 +270,7 @@ class ManifoldConstrainedHyperConnections(Module):
 
  self.static_alpha = nn.Parameter(cat((init_alpha0, torch.eye(num_residual_streams_fracs)), dim = 1))
 
- self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams_fracs + num_input_views_fracs))
+ self.dynamic_alpha_fn = nn.Parameter(torch.zeros(num_dynamic_alpha_proposals, dim, num_residual_streams_fracs + num_input_views_fracs))
 
  self.pre_branch_scale = nn.Parameter(torch.ones(1) * 1e-2)
  self.residual_scale = nn.Parameter(torch.ones(1) * 1e-2)
@@ -346,7 +353,7 @@ class ManifoldConstrainedHyperConnections(Module):
 
  normed = normed.float()
 
- wc_weight = normed @ self.dynamic_alpha_fn.float()
+ wc_weight = einsum(normed, self.dynamic_alpha_fn.float(), '... d, p d e -> p ... e')
 
  pre_branch_scale = repeat(self.pre_branch_scale.float(), '1 -> s', s = self.num_fracs)
  residual_scale = repeat(self.residual_scale.float(), '1 -> s', s = self.num_fracs * streams)
@@ -372,6 +379,11 @@ class ManifoldConstrainedHyperConnections(Module):
 
  alpha = cat((alpha_pre, alpha_residual), dim = -1)
 
+ if self.has_dynamic_alpha_proposals:
+ alpha = reduce(alpha, 'p ... -> ...', 'mean')
+ else:
+ alpha = rearrange(alpha, '1 ... -> ...')
+
  alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
 
  # beta for weights from branch output back to residual streams
@@ -66,12 +66,12 @@ class Attention(Module):
  return self.to_out(out)
 
  class Transformer(Module):
- def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4):
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4, num_dynamic_alpha_proposals = 1):
  super().__init__()
  self.norm = nn.LayerNorm(dim)
  self.layers = ModuleList([])
 
- init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams)
+ init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams, num_dynamic_alpha_proposals = num_dynamic_alpha_proposals)
 
  for _ in range(depth):
  self.layers.append(ModuleList([
@@ -92,7 +92,7 @@ class Transformer(Module):
  return self.norm(x)
 
  class ViT(Module):
- def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4):
+ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4, num_dynamic_alpha_proposals = 1):
  super().__init__()
  image_height, image_width = pair(image_size)
  patch_height, patch_width = pair(patch_size)
@@ -117,7 +117,7 @@ class ViT(Module):
 
  self.dropout = nn.Dropout(emb_dropout)
 
- self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_residual_streams, num_dynamic_alpha_proposals)
 
  self.pool = pool
  self.to_latent = nn.Identity()
@@ -1,6 +1,6 @@
  [project]
  name = "hyper-connections"
- version = "0.3.16"
+ version = "0.4.0"
  description = "Hyper-Connections"
  authors = [
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1,11 +1,12 @@
  import pytest
+ param = pytest.mark.parametrize
 
  import torch
  from torch import nn
 
- @pytest.mark.parametrize('num_fracs', (1, 4))
- @pytest.mark.parametrize('disable', (False, True))
- @pytest.mark.parametrize('manifold_constrained', (False, True))
+ @param('num_fracs', (1, 4))
+ @param('disable', (False, True))
+ @param('manifold_constrained', (False, True))
  def test_readme(
  num_fracs,
  disable,
@@ -208,7 +209,11 @@ def test_mhc_dtype_restoration():
 
  assert residual.dtype == torch.half
 
- def test_mhc_vit():
+ @param('num_dynamic_alpha_proposals', (1, 2))
+ def test_mhc_vit(
+ num_dynamic_alpha_proposals
+ ):
+
  from hyper_connections.vit import ViT
 
  v = ViT(
@@ -221,7 +226,8 @@ def test_mhc_vit():
  mlp_dim = 2048,
  dropout = 0.1,
  emb_dropout = 0.1,
- num_residual_streams = 4
+ num_residual_streams = 4,
+ num_dynamic_alpha_proposals = num_dynamic_alpha_proposals
  )
 
  img = torch.randn(1, 3, 256, 256)