hyper-connections 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -21,6 +21,7 @@ s - residual streams
21
21
  t - residual streams + num branch inputs
22
22
  f - number of fractions (division of feature dimension space)
23
23
  v - number of views for branch input
24
+ p - proposals
24
25
  """
25
26
 
26
27
  # helper functions
@@ -67,7 +68,7 @@ def log_domain_sinkhorn_knopps(log_alpha, iters = 20):
67
68
  log_alpha = log_alpha - log_alpha.logsumexp(dim = -2, keepdim = True)
68
69
  log_alpha = log_alpha - log_alpha.logsumexp(dim = -1, keepdim = True)
69
70
 
70
- return log_alpha.exp()
71
+ return log_alpha.exp().to(dtype)
71
72
 
72
73
  # main functions
73
74
 
@@ -213,6 +214,7 @@ class ManifoldConstrainedHyperConnections(Module):
213
214
  sinkhorn_iters = 20,
214
215
  log_domain_sinkhorn = False,
215
216
  forward_method_names: tuple[str, ...] = (),
217
+ num_dynamic_alpha_proposals = 1
216
218
  ):
217
219
  """
218
220
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -256,6 +258,11 @@ class ManifoldConstrainedHyperConnections(Module):
256
258
  assert num_input_views >= 1
257
259
  self.num_input_views = num_input_views
258
260
 
261
+ # number of dynamic alpha proposals, for averaging Hres across proposals
262
+
263
+ self.has_dynamic_alpha_proposals = num_dynamic_alpha_proposals > 1
264
+ self.num_dynamic_alpha_proposals = num_dynamic_alpha_proposals
265
+
259
266
  # width connection
260
267
 
261
268
  init_alpha0 = torch.zeros((num_residual_streams_fracs, num_input_views_fracs))
@@ -263,7 +270,7 @@ class ManifoldConstrainedHyperConnections(Module):
263
270
 
264
271
  self.static_alpha = nn.Parameter(cat((init_alpha0, torch.eye(num_residual_streams_fracs)), dim = 1))
265
272
 
266
- self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams_fracs + num_input_views_fracs))
273
+ self.dynamic_alpha_fn = nn.Parameter(torch.zeros(num_dynamic_alpha_proposals, dim, num_residual_streams_fracs + num_input_views_fracs))
267
274
 
268
275
  self.pre_branch_scale = nn.Parameter(torch.ones(1) * 1e-2)
269
276
  self.residual_scale = nn.Parameter(torch.ones(1) * 1e-2)
@@ -346,7 +353,7 @@ class ManifoldConstrainedHyperConnections(Module):
346
353
 
347
354
  normed = normed.float()
348
355
 
349
- wc_weight = normed @ self.dynamic_alpha_fn.float()
356
+ wc_weight = einsum(normed, self.dynamic_alpha_fn.float(), '... d, p d e -> p ... e')
350
357
 
351
358
  pre_branch_scale = repeat(self.pre_branch_scale.float(), '1 -> s', s = self.num_fracs)
352
359
  residual_scale = repeat(self.residual_scale.float(), '1 -> s', s = self.num_fracs * streams)
@@ -372,6 +379,11 @@ class ManifoldConstrainedHyperConnections(Module):
372
379
 
373
380
  alpha = cat((alpha_pre, alpha_residual), dim = -1)
374
381
 
382
+ if self.has_dynamic_alpha_proposals:
383
+ alpha = reduce(alpha, 'p ... -> ...', 'mean')
384
+ else:
385
+ alpha = rearrange(alpha, '1 ... -> ...')
386
+
375
387
  alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
376
388
 
377
389
  # beta for weights from branch output back to residual streams
hyper_connections/vit.py CHANGED
@@ -66,12 +66,12 @@ class Attention(Module):
66
66
  return self.to_out(out)
67
67
 
68
68
  class Transformer(Module):
69
- def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4):
69
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4, num_dynamic_alpha_proposals = 1):
70
70
  super().__init__()
71
71
  self.norm = nn.LayerNorm(dim)
72
72
  self.layers = ModuleList([])
73
73
 
74
- init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams)
74
+ init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams, num_dynamic_alpha_proposals = num_dynamic_alpha_proposals)
75
75
 
76
76
  for _ in range(depth):
77
77
  self.layers.append(ModuleList([
@@ -92,7 +92,7 @@ class Transformer(Module):
92
92
  return self.norm(x)
93
93
 
94
94
  class ViT(Module):
95
- def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4):
95
+ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4, num_dynamic_alpha_proposals = 1):
96
96
  super().__init__()
97
97
  image_height, image_width = pair(image_size)
98
98
  patch_height, patch_width = pair(patch_size)
@@ -117,7 +117,7 @@ class ViT(Module):
117
117
 
118
118
  self.dropout = nn.Dropout(emb_dropout)
119
119
 
120
- self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
120
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_residual_streams, num_dynamic_alpha_proposals)
121
121
 
122
122
  self.pool = pool
123
123
  self.to_latent = nn.Identity()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.3.15
3
+ Version: 0.4.0
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -3,10 +3,10 @@ hyper_connections/hyper_connections.py,sha256=rqFJj3U0LF3uDKNKNPBpRrmf0oa2BGWVbD
3
3
  hyper_connections/hyper_connections_channel_first.py,sha256=Mh_hzhTi96ZoOPmhSKwUaF4TbHpNqhs83wNe5hNuL7o,6532
4
4
  hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
5
5
  hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
6
- hyper_connections/manifold_constrained_hyper_connections.py,sha256=iS0I6Ha5iAkWdrH-dEYpyLkHRRF8yarFqHJwlTV2qLI,17689
6
+ hyper_connections/manifold_constrained_hyper_connections.py,sha256=-gWJbFR7PT-ji1pfPIMUGWdsfXlcUWVkyJvDmfM7BqM,18216
7
7
  hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
8
- hyper_connections/vit.py,sha256=fTC8hAYkD4qm-KURAj8SJ66C6ZWtsBdHf_kS-4rJZGQ,5049
9
- hyper_connections-0.3.15.dist-info/METADATA,sha256=PF7nUEiHAtWHgnRQfni_eNwywX0z7_5vKyj_bNm9oi0,6705
10
- hyper_connections-0.3.15.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
- hyper_connections-0.3.15.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
12
- hyper_connections-0.3.15.dist-info/RECORD,,
8
+ hyper_connections/vit.py,sha256=BOWVfCAIzDQdnTq8OBzNUyiKGGILYZkIQ6mr1GKJVB0,5225
9
+ hyper_connections-0.4.0.dist-info/METADATA,sha256=jEzPsFZ-71ZJ-WbkuhzcogDbvhWk4rKI3oINGQ2nMe0,6704
10
+ hyper_connections-0.4.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
+ hyper_connections-0.4.0.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
12
+ hyper_connections-0.4.0.dist-info/RECORD,,