hyper-connections 0.4.7__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,6 +13,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
13
13
  from einops import rearrange, repeat, reduce, einsum
14
14
  from einops.layers.torch import Rearrange, Reduce
15
15
 
16
+ from torch_einops_utils import pack_with_inverse
17
+
16
18
  """
17
19
  ein notation:
18
20
  b - batch
@@ -241,6 +243,7 @@ class ManifoldConstrainedHyperConnections(Module):
241
243
  forward_method_names: tuple[str, ...] = (),
242
244
  num_dynamic_alpha_proposals = 1,
243
245
  use_triton_sinkhorn = False,
246
+ mix_streams_before_norm = False, # whether to mix the residual streams before the norm (that then projects to Hpre, Hpost, Hresidual)
244
247
  ):
245
248
  """
246
249
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -263,6 +266,16 @@ class ManifoldConstrainedHyperConnections(Module):
263
266
 
264
267
  dim //= num_fracs # effective dim handled in dimension is feature dimension divided by num fractions
265
268
 
269
+ # whether to mix the streams before the norm below
270
+ # this would be equivalent to separable depthwise convs from yesteryears (with a norm in between) - parameter efficient improv
271
+
272
+ self.maybe_mix_streams = None
273
+
274
+ if mix_streams_before_norm:
275
+ self.maybe_mix_streams = nn.Conv2d(num_residual_streams, num_residual_streams, 1, bias = False)
276
+
277
+ nn.init.dirac_(self.maybe_mix_streams.weight)
278
+
266
279
  # they used layernorm in paper, but rmsnorm is fine given what we know now
267
280
 
268
281
  self.norm = RMSNorm(dim)
@@ -370,9 +383,19 @@ class ManifoldConstrainedHyperConnections(Module):
370
383
 
371
384
  residuals = self.split_fracs(residuals)
372
385
 
386
+ # maybe mix streams
387
+
388
+ norm_input = residuals
389
+
390
+ if exists(self.maybe_mix_streams):
391
+
392
+ norm_input, inverse_pack_lead_dims = pack_with_inverse(norm_input, '* c h w')
393
+ norm_input = self.maybe_mix_streams(norm_input)
394
+ norm_input = inverse_pack_lead_dims(norm_input)
395
+
373
396
  # norm
374
397
 
375
- normed = self.norm(residuals)
398
+ normed = self.norm(norm_input)
376
399
 
377
400
  # alpha for weighted sum of residuals going into branch
378
401
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.4.7
3
+ Version: 0.4.9
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
37
  Requires-Dist: einops>=0.8.1
38
+ Requires-Dist: torch-einops-utils>=0.0.20
38
39
  Requires-Dist: torch>=2.5
39
40
  Provides-Extra: examples
40
41
  Description-Content-Type: text/markdown
@@ -3,12 +3,12 @@ hyper_connections/hyper_connections.py,sha256=2F-104cGE82KCK0KeC07NSOJNPT-0PCtvX
3
3
  hyper_connections/hyper_connections_channel_first.py,sha256=5vAen4WXxNI9K07ndLBQJwdJv-OjoXznta5EIQTaQNA,6512
4
4
  hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=yn2AlFB6qCYQeRhhhaMlCM3mxxLEdWCYwU2p9TsMwWI,7835
5
5
  hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=iFPw5pgCRHTo16nBJ2PExKSrvTyCh7ba7Py14P1oSPE,11311
6
- hyper_connections/mHCv2.py,sha256=XB2HwxTo7daZvy9fzF8UjHI12ephwgE91h9AH2Ou4WI,17452
6
+ hyper_connections/mHCv2.py,sha256=1kdFEbO1WLFsQT-nZUPRFf-c8CLsOiDrptBTyfuWsWY,18399
7
7
  hyper_connections/manifold_constrained_hyper_connections.py,sha256=E4os-6q_SMjJO1JD0EG8rFTCXA7MQoy-aqUlM7KVS5Q,18269
8
8
  hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
9
9
  hyper_connections/triton_sinkhorn.py,sha256=n2WyQcUemtv5T5Sk2nljnSpV2hEED4I3HaPsIUy4638,5905
10
10
  hyper_connections/vit.py,sha256=dh8AVMUPaUHuWxXJEHoMW_G5nj-EQQjDmgbPwwhiq5g,5215
11
- hyper_connections-0.4.7.dist-info/METADATA,sha256=2ajn-IuCxuUjgnOw5dEBxXKqLJbyQohHSGgKJ2dZFoA,6704
12
- hyper_connections-0.4.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
13
- hyper_connections-0.4.7.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
14
- hyper_connections-0.4.7.dist-info/RECORD,,
11
+ hyper_connections-0.4.9.dist-info/METADATA,sha256=MDfxbjmRiv15To0-HwzrFuROTgH1Yo2tWkl5E1vVRZ0,6746
12
+ hyper_connections-0.4.9.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
13
+ hyper_connections-0.4.9.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
14
+ hyper_connections-0.4.9.dist-info/RECORD,,