hyper-connections 0.3.12__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/manifold_constrained_hyper_connections.py +13 -1
- {hyper_connections-0.3.12.dist-info → hyper_connections-0.3.14.dist-info}/METADATA +1 -1
- {hyper_connections-0.3.12.dist-info → hyper_connections-0.3.14.dist-info}/RECORD +5 -5
- {hyper_connections-0.3.12.dist-info → hyper_connections-0.3.14.dist-info}/WHEEL +0 -0
- {hyper_connections-0.3.12.dist-info → hyper_connections-0.3.14.dist-info}/licenses/LICENSE +0 -0
|
@@ -59,6 +59,14 @@ def sinkhorn_knopps(log_alpha, iters = 20):
|
|
|
59
59
|
|
|
60
60
|
return alpha.to(dtype)
|
|
61
61
|
|
|
62
|
+
def log_domain_sinkhorn_knopps(alpha, iters = 20):
|
|
63
|
+
|
|
64
|
+
for _ in range(iters):
|
|
65
|
+
alpha = alpha - alpha.logsumexp(dim = -2, keepdim = True)
|
|
66
|
+
alpha = alpha - alpha.logsumexp(dim = -1, keepdim = True)
|
|
67
|
+
|
|
68
|
+
return alpha.exp()
|
|
69
|
+
|
|
62
70
|
# main functions
|
|
63
71
|
|
|
64
72
|
def get_expand_reduce_stream_functions(
|
|
@@ -201,6 +209,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
201
209
|
depth_residual_fn = add,
|
|
202
210
|
num_fracs = 1, # https://arxiv.org/abs/2503.14125
|
|
203
211
|
sinkhorn_iters = 20,
|
|
212
|
+
log_domain_sinkhorn = False,
|
|
204
213
|
forward_method_names: tuple[str, ...] = (),
|
|
205
214
|
):
|
|
206
215
|
"""
|
|
@@ -272,6 +281,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
272
281
|
# sinkhorn related
|
|
273
282
|
|
|
274
283
|
self.sinkhorn_iters = sinkhorn_iters
|
|
284
|
+
self.log_domain_sinkhorn = log_domain_sinkhorn
|
|
275
285
|
|
|
276
286
|
# dropouts
|
|
277
287
|
|
|
@@ -354,7 +364,9 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
354
364
|
|
|
355
365
|
alpha_pre = alpha_pre.sigmoid()
|
|
356
366
|
|
|
357
|
-
|
|
367
|
+
sinkhorn_fn = sinkhorn_knopps if not self.log_domain_sinkhorn else log_domain_sinkhorn_knopps
|
|
368
|
+
|
|
369
|
+
alpha_residual = sinkhorn_fn(alpha_residual, self.sinkhorn_iters)
|
|
358
370
|
|
|
359
371
|
alpha = cat((alpha_pre, alpha_residual), dim = -1)
|
|
360
372
|
|
|
@@ -3,10 +3,10 @@ hyper_connections/hyper_connections.py,sha256=rqFJj3U0LF3uDKNKNPBpRrmf0oa2BGWVbD
|
|
|
3
3
|
hyper_connections/hyper_connections_channel_first.py,sha256=Mh_hzhTi96ZoOPmhSKwUaF4TbHpNqhs83wNe5hNuL7o,6532
|
|
4
4
|
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
|
|
5
5
|
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
|
|
6
|
-
hyper_connections/manifold_constrained_hyper_connections.py,sha256=
|
|
6
|
+
hyper_connections/manifold_constrained_hyper_connections.py,sha256=BekY37Gt5us7zZjxP6laDHCbygMww8kS5GatnveNrYw,17595
|
|
7
7
|
hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
|
|
8
8
|
hyper_connections/vit.py,sha256=fTC8hAYkD4qm-KURAj8SJ66C6ZWtsBdHf_kS-4rJZGQ,5049
|
|
9
|
-
hyper_connections-0.3.
|
|
10
|
-
hyper_connections-0.3.
|
|
11
|
-
hyper_connections-0.3.
|
|
12
|
-
hyper_connections-0.3.
|
|
9
|
+
hyper_connections-0.3.14.dist-info/METADATA,sha256=xncYFa2ttI2WOVLTy5-jk69dZ29aQa-QfFAgI4qJvHc,6705
|
|
10
|
+
hyper_connections-0.3.14.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
11
|
+
hyper_connections-0.3.14.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
12
|
+
hyper_connections-0.3.14.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|