hyper-connections 0.3.12__py3-none-any.whl → 0.3.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/manifold_constrained_hyper_connections.py +15 -1
- {hyper_connections-0.3.12.dist-info → hyper_connections-0.3.15.dist-info}/METADATA +1 -1
- {hyper_connections-0.3.12.dist-info → hyper_connections-0.3.15.dist-info}/RECORD +5 -5
- {hyper_connections-0.3.12.dist-info → hyper_connections-0.3.15.dist-info}/WHEEL +0 -0
- {hyper_connections-0.3.12.dist-info → hyper_connections-0.3.15.dist-info}/licenses/LICENSE +0 -0
|
@@ -59,6 +59,16 @@ def sinkhorn_knopps(log_alpha, iters = 20):
|
|
|
59
59
|
|
|
60
60
|
return alpha.to(dtype)
|
|
61
61
|
|
|
62
|
+
def log_domain_sinkhorn_knopps(log_alpha, iters = 20):
    """
    Sinkhorn normalization carried out in log space.

    Alternately subtracts the logsumexp over dim -2 and then over dim -1
    from `log_alpha`, `iters` times, and exponentiates the result. The last
    step of each round normalizes dim -1, so for `iters >= 1` the returned
    tensor sums to 1 along dim -1.

    Args:
        log_alpha: tensor of unnormalized log weights.
            NOTE(review): presumably the last two dims form a square matrix
            per batch element - confirm against the caller.
        iters: number of normalization rounds. With 0 the function simply
            returns `exp(log_alpha)`.

    Returns:
        Tensor with the same shape and dtype as the `log_alpha` passed in.
    """
    dtype = log_alpha.dtype

    # iterate in float32 so repeated logsumexp does not lose precision
    # when the input is half / bfloat16
    log_alpha = log_alpha.float()

    for _ in range(iters):
        log_alpha = log_alpha - log_alpha.logsumexp(dim = -2, keepdim = True)
        log_alpha = log_alpha - log_alpha.logsumexp(dim = -1, keepdim = True)

    # cast back to the caller's dtype, matching `sinkhorn_knopps`, which
    # returns `alpha.to(dtype)`; previously `dtype` was saved but never used
    return log_alpha.exp().to(dtype)
|
|
71
|
+
|
|
62
72
|
# main functions
|
|
63
73
|
|
|
64
74
|
def get_expand_reduce_stream_functions(
|
|
@@ -201,6 +211,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
201
211
|
depth_residual_fn = add,
|
|
202
212
|
num_fracs = 1, # https://arxiv.org/abs/2503.14125
|
|
203
213
|
sinkhorn_iters = 20,
|
|
214
|
+
log_domain_sinkhorn = False,
|
|
204
215
|
forward_method_names: tuple[str, ...] = (),
|
|
205
216
|
):
|
|
206
217
|
"""
|
|
@@ -272,6 +283,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
272
283
|
# sinkhorn related
|
|
273
284
|
|
|
274
285
|
self.sinkhorn_iters = sinkhorn_iters
|
|
286
|
+
self.log_domain_sinkhorn = log_domain_sinkhorn
|
|
275
287
|
|
|
276
288
|
# dropouts
|
|
277
289
|
|
|
@@ -354,7 +366,9 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
354
366
|
|
|
355
367
|
alpha_pre = alpha_pre.sigmoid()
|
|
356
368
|
|
|
357
|
-
|
|
369
|
+
sinkhorn_fn = sinkhorn_knopps if not self.log_domain_sinkhorn else log_domain_sinkhorn_knopps
|
|
370
|
+
|
|
371
|
+
alpha_residual = sinkhorn_fn(alpha_residual, self.sinkhorn_iters)
|
|
358
372
|
|
|
359
373
|
alpha = cat((alpha_pre, alpha_residual), dim = -1)
|
|
360
374
|
|
|
@@ -3,10 +3,10 @@ hyper_connections/hyper_connections.py,sha256=rqFJj3U0LF3uDKNKNPBpRrmf0oa2BGWVbD
|
|
|
3
3
|
hyper_connections/hyper_connections_channel_first.py,sha256=Mh_hzhTi96ZoOPmhSKwUaF4TbHpNqhs83wNe5hNuL7o,6532
|
|
4
4
|
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
|
|
5
5
|
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
|
|
6
|
-
hyper_connections/manifold_constrained_hyper_connections.py,sha256=
|
|
6
|
+
hyper_connections/manifold_constrained_hyper_connections.py,sha256=iS0I6Ha5iAkWdrH-dEYpyLkHRRF8yarFqHJwlTV2qLI,17689
|
|
7
7
|
hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
|
|
8
8
|
hyper_connections/vit.py,sha256=fTC8hAYkD4qm-KURAj8SJ66C6ZWtsBdHf_kS-4rJZGQ,5049
|
|
9
|
-
hyper_connections-0.3.
|
|
10
|
-
hyper_connections-0.3.
|
|
11
|
-
hyper_connections-0.3.
|
|
12
|
-
hyper_connections-0.3.
|
|
9
|
+
hyper_connections-0.3.15.dist-info/METADATA,sha256=PF7nUEiHAtWHgnRQfni_eNwywX0z7_5vKyj_bNm9oi0,6705
|
|
10
|
+
hyper_connections-0.3.15.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
11
|
+
hyper_connections-0.3.15.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
12
|
+
hyper_connections-0.3.15.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|