hyper-connections 0.3.12__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -59,6 +59,16 @@ def sinkhorn_knopps(log_alpha, iters = 20):
59
59
 
60
60
  return alpha.to(dtype)
61
61
 
62
def log_domain_sinkhorn_knopps(log_alpha, iters = 20):
    """Sinkhorn-Knopp normalization carried out in log space.

    Alternately subtracts the log-sum-exp over dim -2 and then over dim -1
    for `iters` rounds, then exponentiates. The final step normalizes over
    dim -1, so those sums are exactly 1 (up to float error); sums over
    dim -2 approach 1 as `iters` grows. Working with log values avoids the
    overflow/underflow a direct `exp()` followed by division can hit.

    Args:
        log_alpha: tensor of unnormalized log weights; the last two dims are
            the matrix being balanced.
        iters: number of (dim -2, dim -1) normalization rounds.

    Returns:
        The balanced matrix (`exp` of the normalized log values), cast back
        to the input dtype so it is a drop-in replacement for
        `sinkhorn_knopps`, which also returns in the input dtype.
    """
    dtype = log_alpha.dtype

    # accumulate in float32 for numerical stability under half precision
    log_alpha = log_alpha.float()

    for _ in range(iters):
        log_alpha = log_alpha - log_alpha.logsumexp(dim = -2, keepdim = True)
        log_alpha = log_alpha - log_alpha.logsumexp(dim = -1, keepdim = True)

    # previously the captured dtype was never used, so this returned float32
    # regardless of input; restore it to match sinkhorn_knopps
    return log_alpha.exp().to(dtype)
71
+
62
72
  # main functions
63
73
 
64
74
  def get_expand_reduce_stream_functions(
@@ -201,6 +211,7 @@ class ManifoldConstrainedHyperConnections(Module):
201
211
  depth_residual_fn = add,
202
212
  num_fracs = 1, # https://arxiv.org/abs/2503.14125
203
213
  sinkhorn_iters = 20,
214
+ log_domain_sinkhorn = False,
204
215
  forward_method_names: tuple[str, ...] = (),
205
216
  ):
206
217
  """
@@ -272,6 +283,7 @@ class ManifoldConstrainedHyperConnections(Module):
272
283
  # sinkhorn related
273
284
 
274
285
  self.sinkhorn_iters = sinkhorn_iters
286
+ self.log_domain_sinkhorn = log_domain_sinkhorn
275
287
 
276
288
  # dropouts
277
289
 
@@ -354,7 +366,9 @@ class ManifoldConstrainedHyperConnections(Module):
354
366
 
355
367
  alpha_pre = alpha_pre.sigmoid()
356
368
 
357
- alpha_residual = sinkhorn_knopps(alpha_residual, self.sinkhorn_iters)
369
+ sinkhorn_fn = sinkhorn_knopps if not self.log_domain_sinkhorn else log_domain_sinkhorn_knopps
370
+
371
+ alpha_residual = sinkhorn_fn(alpha_residual, self.sinkhorn_iters)
358
372
 
359
373
  alpha = cat((alpha_pre, alpha_residual), dim = -1)
360
374
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.3.12
3
+ Version: 0.3.15
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -3,10 +3,10 @@ hyper_connections/hyper_connections.py,sha256=rqFJj3U0LF3uDKNKNPBpRrmf0oa2BGWVbD
3
3
  hyper_connections/hyper_connections_channel_first.py,sha256=Mh_hzhTi96ZoOPmhSKwUaF4TbHpNqhs83wNe5hNuL7o,6532
4
4
  hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
5
5
  hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
6
- hyper_connections/manifold_constrained_hyper_connections.py,sha256=JF_ncPloyNLr303xZMlCiCorCE1Qt_soo1oeOyGYIdc,17168
6
+ hyper_connections/manifold_constrained_hyper_connections.py,sha256=iS0I6Ha5iAkWdrH-dEYpyLkHRRF8yarFqHJwlTV2qLI,17689
7
7
  hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
8
8
  hyper_connections/vit.py,sha256=fTC8hAYkD4qm-KURAj8SJ66C6ZWtsBdHf_kS-4rJZGQ,5049
9
- hyper_connections-0.3.12.dist-info/METADATA,sha256=OyTViiZQQ1AmohzVqRhf71HOFPjLEVMGPVY5kY6wHsg,6705
10
- hyper_connections-0.3.12.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
- hyper_connections-0.3.12.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
12
- hyper_connections-0.3.12.dist-info/RECORD,,
9
+ hyper_connections-0.3.15.dist-info/METADATA,sha256=PF7nUEiHAtWHgnRQfni_eNwywX0z7_5vKyj_bNm9oi0,6705
10
+ hyper_connections-0.3.15.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
+ hyper_connections-0.3.15.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
12
+ hyper_connections-0.3.15.dist-info/RECORD,,