hyper-connections 0.3.12__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -59,6 +59,14 @@ def sinkhorn_knopps(log_alpha, iters = 20):
59
59
 
60
60
  return alpha.to(dtype)
61
61
 
62
+ def log_domain_sinkhorn_knopps(alpha, iters = 20):
63
+
64
+ for _ in range(iters):
65
+ alpha = alpha - alpha.logsumexp(dim = -2, keepdim = True)
66
+ alpha = alpha - alpha.logsumexp(dim = -1, keepdim = True)
67
+
68
+ return alpha.exp()
69
+
62
70
  # main functions
63
71
 
64
72
  def get_expand_reduce_stream_functions(
@@ -201,6 +209,7 @@ class ManifoldConstrainedHyperConnections(Module):
201
209
  depth_residual_fn = add,
202
210
  num_fracs = 1, # https://arxiv.org/abs/2503.14125
203
211
  sinkhorn_iters = 20,
212
+ log_domain_sinkhorn = False,
204
213
  forward_method_names: tuple[str, ...] = (),
205
214
  ):
206
215
  """
@@ -272,6 +281,7 @@ class ManifoldConstrainedHyperConnections(Module):
272
281
  # sinkhorn related
273
282
 
274
283
  self.sinkhorn_iters = sinkhorn_iters
284
+ self.log_domain_sinkhorn = log_domain_sinkhorn
275
285
 
276
286
  # dropouts
277
287
 
@@ -354,7 +364,9 @@ class ManifoldConstrainedHyperConnections(Module):
354
364
 
355
365
  alpha_pre = alpha_pre.sigmoid()
356
366
 
357
- alpha_residual = sinkhorn_knopps(alpha_residual, self.sinkhorn_iters)
367
+ sinkhorn_fn = sinkhorn_knopps if not self.log_domain_sinkhorn else log_domain_sinkhorn_knopps
368
+
369
+ alpha_residual = sinkhorn_fn(alpha_residual, self.sinkhorn_iters)
358
370
 
359
371
  alpha = cat((alpha_pre, alpha_residual), dim = -1)
360
372
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.3.12
3
+ Version: 0.3.14
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -3,10 +3,10 @@ hyper_connections/hyper_connections.py,sha256=rqFJj3U0LF3uDKNKNPBpRrmf0oa2BGWVbD
3
3
  hyper_connections/hyper_connections_channel_first.py,sha256=Mh_hzhTi96ZoOPmhSKwUaF4TbHpNqhs83wNe5hNuL7o,6532
4
4
  hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
5
5
  hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
6
- hyper_connections/manifold_constrained_hyper_connections.py,sha256=JF_ncPloyNLr303xZMlCiCorCE1Qt_soo1oeOyGYIdc,17168
6
+ hyper_connections/manifold_constrained_hyper_connections.py,sha256=BekY37Gt5us7zZjxP6laDHCbygMww8kS5GatnveNrYw,17595
7
7
  hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
8
8
  hyper_connections/vit.py,sha256=fTC8hAYkD4qm-KURAj8SJ66C6ZWtsBdHf_kS-4rJZGQ,5049
9
- hyper_connections-0.3.12.dist-info/METADATA,sha256=OyTViiZQQ1AmohzVqRhf71HOFPjLEVMGPVY5kY6wHsg,6705
10
- hyper_connections-0.3.12.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
- hyper_connections-0.3.12.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
12
- hyper_connections-0.3.12.dist-info/RECORD,,
9
+ hyper_connections-0.3.14.dist-info/METADATA,sha256=xncYFa2ttI2WOVLTy5-jk69dZ29aQa-QfFAgI4qJvHc,6705
10
+ hyper_connections-0.3.14.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
+ hyper_connections-0.3.14.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
12
+ hyper_connections-0.3.14.dist-info/RECORD,,