hyper-connections 0.4.6__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hyper_connections/mHCv2.py CHANGED
@@ -13,6 +13,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
 from einops import rearrange, repeat, reduce, einsum
 from einops.layers.torch import Rearrange, Reduce
 
+from torch_einops_utils import pack_with_inverse
+
 """
 ein notation:
 b - batch
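The torch-einops-utils dependency is new in this release and, as far as this diff shows, is only used for pack_with_inverse in the stream-mixing step further down. A minimal sketch of an equivalent helper, assuming it simply wraps einops.pack and hands back a function that undoes the packing (the real implementation in torch-einops-utils may differ):

from einops import pack, unpack

def pack_with_inverse_sketch(t, pattern):
    # collapse the axes marked '*' in the pattern, e.g. '* c h w'
    packed, packed_shape = pack([t], pattern)

    def inverse(out):
        # restore the collapsed leading axes on the (possibly transformed) tensor
        return unpack(out, packed_shape, pattern)[0]

    return packed, inverse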
@@ -124,8 +126,8 @@ def get_init_and_expand_reduce_stream_functions(
 
     hyper_conn_klass = ManifoldConstrainedHyperConnections if not disable else Residual
 
-    kwargs.pop('add_attn_pool_reduce_stream', None)
     init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs, sinkhorn_iters = sinkhorn_iters, use_triton_sinkhorn = use_triton_sinkhorn, **kwargs)
+
     expand_reduce_fns = get_expand_reduce_stream_functions(
         num_streams,
         add_stream_embed = add_stream_embed,
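With the kwargs.pop removed, extra keyword arguments handed to get_init_and_expand_reduce_stream_functions now flow through **kwargs into the ManifoldConstrainedHyperConnections constructor. A hedged usage sketch mirroring the call pattern vit.py uses later in this diff (the argument values are illustrative):

from hyper_connections.mHCv2 import mHC

init_hyper_conn, expand_streams, reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(
    4,                               # num_streams
    use_triton_sinkhorn = False,     # named parameter of the function above
    mix_streams_before_norm = True   # forwarded via **kwargs to the constructor (new in this release)
)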
@@ -241,6 +243,7 @@ class ManifoldConstrainedHyperConnections(Module):
         forward_method_names: tuple[str, ...] = (),
         num_dynamic_alpha_proposals = 1,
         use_triton_sinkhorn = False,
+        mix_streams_before_norm = False, # whether to mix the residual streams before the norm (that then projects to Hpre, Hpost, Hresidual)
     ):
         """
         Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -263,6 +266,16 @@ class ManifoldConstrainedHyperConnections(Module):
 
         dim //= num_fracs # effective dim handled in dimension is feature dimension divided by num fractions
 
+        # whether to mix the streams before the norm below
+        # this would be equivalent to separable depthwise convs from yesteryears (with a norm in between) - parameter efficient improv
+
+        self.maybe_mix_streams = None
+
+        if mix_streams_before_norm:
+            self.maybe_mix_streams = nn.Conv2d(num_residual_streams, num_residual_streams, 1, bias = False)
+
+            nn.init.dirac_(self.maybe_mix_streams.weight)
+
         # they used layernorm in paper, but rmsnorm is fine given what we know now
 
         self.norm = RMSNorm(dim)
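Why the dirac_ init: a 1x1 Conv2d with a Dirac-initialized weight and no bias is the identity over its channel (here: stream) dimension, so the optional mixing starts as a no-op and only departs from identity as training updates the weights. A quick standalone check, not part of the package:

import torch
from torch import nn

num_streams = 4

mix = nn.Conv2d(num_streams, num_streams, 1, bias = False)
nn.init.dirac_(mix.weight)               # weight[i, j, 0, 0] = 1 if i == j else 0

x = torch.randn(2, num_streams, 8, 16)   # (batch, streams, h, w) - shapes are illustrative
assert torch.allclose(mix(x), x)         # identity mixing at initialization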
@@ -370,6 +383,14 @@ class ManifoldConstrainedHyperConnections(Module):
 
         residuals = self.split_fracs(residuals)
 
+        # maybe mix streams
+
+        if exists(self.maybe_mix_streams):
+
+            residuals, inverse_pack_lead_dims = pack_with_inverse(residuals, '* c h w')
+            residuals = self.maybe_mix_streams(residuals)
+            residuals = inverse_pack_lead_dims(residuals)
+
         # norm
 
         normed = self.norm(residuals)
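The '* c h w' pack pattern together with a Conv2d over num_residual_streams channels implies the streams axis sits third from the end at this point: packing collapses whatever leading axes exist into a single batch axis so the conv sees (N, streams, h, w), and the inverse restores them afterwards. A hedged shape walk-through with illustrative leading axes (the actual shapes inside the module are not visible in this diff):

import torch
from torch import nn
from einops import pack, unpack

leading, streams, h, w = (2, 3), 4, 128, 64   # illustrative sizes

residuals = torch.randn(*leading, streams, h, w)

mix = nn.Conv2d(streams, streams, 1, bias = False)

# collapse all leading axes into one batch axis so Conv2d sees (N, streams, h, w)
packed, packed_shape = pack([residuals], '* c h w')   # (6, 4, 128, 64)
mixed = mix(packed)

# restore the original leading axes
residuals = unpack(mixed, packed_shape, '* c h w')[0]

assert residuals.shape == (*leading, streams, h, w)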
hyper_connections/vit.py CHANGED
@@ -5,7 +5,7 @@ from torch.nn import Module, ModuleList
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 
-from hyper_connections.manifold_constrained_hyper_connections import mHC
+from hyper_connections.mHCv2 import mHC
 
 # helpers
 
@@ -66,12 +66,12 @@ class Attention(Module):
         return self.to_out(out)
 
 class Transformer(Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4, num_dynamic_alpha_proposals = 1):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4, mhc_kwargs = dict()):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
         self.layers = ModuleList([])
 
-        init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams, num_dynamic_alpha_proposals = num_dynamic_alpha_proposals)
+        init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams, **mhc_kwargs)
 
         for _ in range(depth):
             self.layers.append(ModuleList([
@@ -92,7 +92,7 @@ class Transformer(Module):
         return self.norm(x)
 
 class ViT(Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4, num_dynamic_alpha_proposals = 1):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4, mhc_kwargs = dict(num_dynamic_alpha_proposals = 1)):
         super().__init__()
         image_height, image_width = pair(image_size)
         patch_height, patch_width = pair(patch_size)
@@ -117,7 +117,7 @@ class ViT(Module):
 
         self.dropout = nn.Dropout(emb_dropout)
 
-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_residual_streams, num_dynamic_alpha_proposals)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_residual_streams, mhc_kwargs)
 
         self.pool = pool
         self.to_latent = nn.Identity()
@@ -154,7 +154,10 @@ if __name__ == '__main__':
         mlp_dim = 2048,
         dropout = 0.1,
         emb_dropout = 0.1,
-        num_residual_streams = 4
+        num_residual_streams = 4,
+        mhc_kwargs = dict(
+            use_triton_sinkhorn = False
+        )
     )
 
     img = torch.randn(1, 3, 256, 256)
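A hedged usage sketch of the new mhc_kwargs plumbing: hyper-connection options are grouped into a single dict on the ViT, passed through Transformer into mHC.get_init_and_expand_reduce_stream_functions, so the new mix_streams_before_norm flag can also be toggled from the top level. The constructor values below that are not visible in this diff (image_size, patch_size, num_classes, dim, depth, heads) are illustrative:

import torch
from hyper_connections.vit import ViT

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    num_residual_streams = 4,
    mhc_kwargs = dict(
        use_triton_sinkhorn = False,
        mix_streams_before_norm = True   # forwarded down to ManifoldConstrainedHyperConnections
    )
)

img = torch.randn(1, 3, 256, 256)
logits = vit(img)                        # expected shape (1, num_classes)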
hyper_connections-0.4.8.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hyper-connections
-Version: 0.4.6
+Version: 0.4.8
 Summary: Hyper-Connections
 Project-URL: Homepage, https://pypi.org/project/hyper-connections/
 Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: einops>=0.8.1
+Requires-Dist: torch-einops-utils>=0.0.20
 Requires-Dist: torch>=2.5
 Provides-Extra: examples
 Description-Content-Type: text/markdown
hyper_connections-0.4.8.dist-info/RECORD CHANGED
@@ -3,12 +3,12 @@ hyper_connections/hyper_connections.py,sha256=2F-104cGE82KCK0KeC07NSOJNPT-0PCtvX
 hyper_connections/hyper_connections_channel_first.py,sha256=5vAen4WXxNI9K07ndLBQJwdJv-OjoXznta5EIQTaQNA,6512
 hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=yn2AlFB6qCYQeRhhhaMlCM3mxxLEdWCYwU2p9TsMwWI,7835
 hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=iFPw5pgCRHTo16nBJ2PExKSrvTyCh7ba7Py14P1oSPE,11311
-hyper_connections/mHCv2.py,sha256=k-qOt-lnDR-jnwJLTVxlNFMkJZQGT55ExpE1QxUEPco,17503
+hyper_connections/mHCv2.py,sha256=LpMtlrb7Vfi2qq_cqPl9fajA5SxkMTl5QGpmvBJyD1M,18360
 hyper_connections/manifold_constrained_hyper_connections.py,sha256=E4os-6q_SMjJO1JD0EG8rFTCXA7MQoy-aqUlM7KVS5Q,18269
 hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
 hyper_connections/triton_sinkhorn.py,sha256=n2WyQcUemtv5T5Sk2nljnSpV2hEED4I3HaPsIUy4638,5905
-hyper_connections/vit.py,sha256=BOWVfCAIzDQdnTq8OBzNUyiKGGILYZkIQ6mr1GKJVB0,5225
-hyper_connections-0.4.6.dist-info/METADATA,sha256=ZU6BE9Y90LRK2Fg3WXg2Y8dKDg_qaUyyYELsqaPGD6c,6704
-hyper_connections-0.4.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-hyper_connections-0.4.6.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
-hyper_connections-0.4.6.dist-info/RECORD,,
+hyper_connections/vit.py,sha256=dh8AVMUPaUHuWxXJEHoMW_G5nj-EQQjDmgbPwwhiq5g,5215
+hyper_connections-0.4.8.dist-info/METADATA,sha256=vevhBHad-7ffu1KBFcazUqU5C2XVRy1LlZkIxJUNDIs,6746
+hyper_connections-0.4.8.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+hyper_connections-0.4.8.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
+hyper_connections-0.4.8.dist-info/RECORD,,