hyper-connections 0.4.7__tar.gz → 0.4.9__tar.gz

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (19) hide show
  1. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/PKG-INFO +2 -1
  2. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/mHCv2.py +24 -1
  3. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/pyproject.toml +2 -1
  4. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/tests/test_hyper_connections.py +6 -2
  5. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/.github/workflows/python-publish.yml +0 -0
  6. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/.github/workflows/test.yml +0 -0
  7. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/.gitignore +0 -0
  8. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/LICENSE +0 -0
  9. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/README.md +0 -0
  10. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper-connections.png +0 -0
  11. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/__init__.py +0 -0
  12. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/hyper_connections.py +0 -0
  13. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/hyper_connections_channel_first.py +0 -0
  14. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
  15. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
  16. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/manifold_constrained_hyper_connections.py +0 -0
  17. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/residuals.py +0 -0
  18. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/triton_sinkhorn.py +0 -0
  19. {hyper_connections-0.4.7 → hyper_connections-0.4.9}/hyper_connections/vit.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.4.7
3
+ Version: 0.4.9
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
37
  Requires-Dist: einops>=0.8.1
38
+ Requires-Dist: torch-einops-utils>=0.0.20
38
39
  Requires-Dist: torch>=2.5
39
40
  Provides-Extra: examples
40
41
  Description-Content-Type: text/markdown
@@ -13,6 +13,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
13
13
  from einops import rearrange, repeat, reduce, einsum
14
14
  from einops.layers.torch import Rearrange, Reduce
15
15
 
16
+ from torch_einops_utils import pack_with_inverse
17
+
16
18
  """
17
19
  ein notation:
18
20
  b - batch
@@ -241,6 +243,7 @@ class ManifoldConstrainedHyperConnections(Module):
241
243
  forward_method_names: tuple[str, ...] = (),
242
244
  num_dynamic_alpha_proposals = 1,
243
245
  use_triton_sinkhorn = False,
246
+ mix_streams_before_norm = False, # whether to mix the residual streams before the norm (that then projects to Hpre, Hpost, Hresidual)
244
247
  ):
245
248
  """
246
249
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -263,6 +266,16 @@ class ManifoldConstrainedHyperConnections(Module):
263
266
 
264
267
  dim //= num_fracs # effective dim handled in dimension is feature dimension divided by num fractions
265
268
 
269
+ # whether to mix the streams before the norm below
270
+ # this would be equivalent to separable depthwise convs from yesteryears (with a norm in between) - parameter efficient improv
271
+
272
+ self.maybe_mix_streams = None
273
+
274
+ if mix_streams_before_norm:
275
+ self.maybe_mix_streams = nn.Conv2d(num_residual_streams, num_residual_streams, 1, bias = False)
276
+
277
+ nn.init.dirac_(self.maybe_mix_streams.weight)
278
+
266
279
  # they used layernorm in paper, but rmsnorm is fine given what we know now
267
280
 
268
281
  self.norm = RMSNorm(dim)
@@ -370,9 +383,19 @@ class ManifoldConstrainedHyperConnections(Module):
370
383
 
371
384
  residuals = self.split_fracs(residuals)
372
385
 
386
+ # maybe mix streams
387
+
388
+ norm_input = residuals
389
+
390
+ if exists(self.maybe_mix_streams):
391
+
392
+ norm_input, inverse_pack_lead_dims = pack_with_inverse(norm_input, '* c h w')
393
+ norm_input = self.maybe_mix_streams(norm_input)
394
+ norm_input = inverse_pack_lead_dims(norm_input)
395
+
373
396
  # norm
374
397
 
375
- normed = self.norm(residuals)
398
+ normed = self.norm(norm_input)
376
399
 
377
400
  # alpha for weighted sum of residuals going into branch
378
401
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.4.7"
3
+ version = "0.4.9"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -25,6 +25,7 @@ classifiers=[
25
25
  dependencies = [
26
26
  "einops>=0.8.1",
27
27
  "torch>=2.5",
28
+ "torch-einops-utils>=0.0.20"
28
29
  ]
29
30
 
30
31
  [project.urls]
@@ -227,7 +227,9 @@ def test_mhc_vit(
227
227
  dropout = 0.1,
228
228
  emb_dropout = 0.1,
229
229
  num_residual_streams = 4,
230
- num_dynamic_alpha_proposals = num_dynamic_alpha_proposals
230
+ mhc_kwargs = dict(
231
+ num_dynamic_alpha_proposals = num_dynamic_alpha_proposals
232
+ )
231
233
  )
232
234
 
233
235
  img = torch.randn(1, 3, 256, 256)
@@ -238,11 +240,13 @@ def test_mhc_vit(
238
240
  @param('num_fracs', (1, 2))
239
241
  @param('num_streams', (1, 3))
240
242
  @param('disable', (False, True))
243
+ @param('mix_streams_before_norm', (False, True))
241
244
  @param('add_attn_pool_reduce_stream', (False, True))
242
245
  def test_mhcv2(
243
246
  num_fracs,
244
247
  num_streams,
245
248
  disable,
249
+ mix_streams_before_norm,
246
250
  add_attn_pool_reduce_stream
247
251
  ):
248
252
  import torch
@@ -261,7 +265,7 @@ def test_mhcv2(
261
265
 
262
266
  from hyper_connections.mHCv2 import get_init_and_expand_reduce_stream_functions
263
267
 
264
- init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(num_streams, dim = 512, num_fracs = num_fracs, disable = disable, add_attn_pool_reduce_stream = add_attn_pool_reduce_stream)
268
+ init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(num_streams, dim = 512, num_fracs = num_fracs, mix_streams_before_norm = mix_streams_before_norm, disable = disable, add_attn_pool_reduce_stream = add_attn_pool_reduce_stream)
265
269
 
266
270
  # 1. wrap your branch function
267
271