hyper-connections: hyper_connections-0.3.11-py3-none-any.whl → hyper_connections-0.3.14-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -59,6 +59,14 @@ def sinkhorn_knopps(log_alpha, iters = 20):
59
59
 
60
60
  return alpha.to(dtype)
61
61
 
62
+ def log_domain_sinkhorn_knopps(alpha, iters = 20):
63
+
64
+ for _ in range(iters):
65
+ alpha = alpha - alpha.logsumexp(dim = -2, keepdim = True)
66
+ alpha = alpha - alpha.logsumexp(dim = -1, keepdim = True)
67
+
68
+ return alpha.exp()
69
+
62
70
  # main functions
63
71
 
64
72
  def get_expand_reduce_stream_functions(
@@ -201,6 +209,7 @@ class ManifoldConstrainedHyperConnections(Module):
201
209
  depth_residual_fn = add,
202
210
  num_fracs = 1, # https://arxiv.org/abs/2503.14125
203
211
  sinkhorn_iters = 20,
212
+ log_domain_sinkhorn = False,
204
213
  forward_method_names: tuple[str, ...] = (),
205
214
  ):
206
215
  """
@@ -272,6 +281,7 @@ class ManifoldConstrainedHyperConnections(Module):
272
281
  # sinkhorn related
273
282
 
274
283
  self.sinkhorn_iters = sinkhorn_iters
284
+ self.log_domain_sinkhorn = log_domain_sinkhorn
275
285
 
276
286
  # dropouts
277
287
 
@@ -348,22 +358,20 @@ class ManifoldConstrainedHyperConnections(Module):
348
358
 
349
359
  alpha = dynamic_alpha + static_alpha
350
360
 
351
- alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
352
-
353
361
  # the alpha is now split and "manifold constrained" with sinkhorn and sigmoid
354
362
 
355
363
  alpha_pre, alpha_residual = alpha[..., :self.num_input_views], alpha[..., self.num_input_views:]
356
364
 
357
365
  alpha_pre = alpha_pre.sigmoid()
358
366
 
359
- alpha_residual = rearrange(alpha_residual, '... (v s1 s2) -> ... v s1 s2', v = self.num_input_views, s1 = streams)
360
-
361
- alpha_residual = sinkhorn_knopps(alpha_residual, self.sinkhorn_iters)
367
+ sinkhorn_fn = sinkhorn_knopps if not self.log_domain_sinkhorn else log_domain_sinkhorn_knopps
362
368
 
363
- alpha_residual = rearrange(alpha_residual, '... v s1 s2 -> ... (v s1 s2)')
369
+ alpha_residual = sinkhorn_fn(alpha_residual, self.sinkhorn_iters)
364
370
 
365
371
  alpha = cat((alpha_pre, alpha_residual), dim = -1)
366
372
 
373
+ alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
374
+
367
375
  # beta for weights from branch output back to residual streams
368
376
 
369
377
  beta = None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.3.11
3
+ Version: 0.3.14
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -3,10 +3,10 @@ hyper_connections/hyper_connections.py,sha256=rqFJj3U0LF3uDKNKNPBpRrmf0oa2BGWVbD
3
3
  hyper_connections/hyper_connections_channel_first.py,sha256=Mh_hzhTi96ZoOPmhSKwUaF4TbHpNqhs83wNe5hNuL7o,6532
4
4
  hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
5
5
  hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
6
- hyper_connections/manifold_constrained_hyper_connections.py,sha256=u_yeGVdPCMvZnPzj3cM_Kfo0XoZRChU4z-qL8oIKnRQ,17376
6
+ hyper_connections/manifold_constrained_hyper_connections.py,sha256=BekY37Gt5us7zZjxP6laDHCbygMww8kS5GatnveNrYw,17595
7
7
  hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
8
8
  hyper_connections/vit.py,sha256=fTC8hAYkD4qm-KURAj8SJ66C6ZWtsBdHf_kS-4rJZGQ,5049
9
- hyper_connections-0.3.11.dist-info/METADATA,sha256=HPZFYGNNJENlhuSEdKAdUUPEBaQ88ikkwViItlZtBHs,6705
10
- hyper_connections-0.3.11.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
- hyper_connections-0.3.11.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
12
- hyper_connections-0.3.11.dist-info/RECORD,,
9
+ hyper_connections-0.3.14.dist-info/METADATA,sha256=xncYFa2ttI2WOVLTy5-jk69dZ29aQa-QfFAgI4qJvHc,6705
10
+ hyper_connections-0.3.14.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
+ hyper_connections-0.3.14.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
12
+ hyper_connections-0.3.14.dist-info/RECORD,,