hyper-connections 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/manifold_constrained_hyper_connections.py +14 -6
- {hyper_connections-0.3.11.dist-info → hyper_connections-0.3.14.dist-info}/METADATA +1 -1
- {hyper_connections-0.3.11.dist-info → hyper_connections-0.3.14.dist-info}/RECORD +5 -5
- {hyper_connections-0.3.11.dist-info → hyper_connections-0.3.14.dist-info}/WHEEL +0 -0
- {hyper_connections-0.3.11.dist-info → hyper_connections-0.3.14.dist-info}/licenses/LICENSE +0 -0
|
@@ -59,6 +59,14 @@ def sinkhorn_knopps(log_alpha, iters = 20):
|
|
|
59
59
|
|
|
60
60
|
return alpha.to(dtype)
|
|
61
61
|
|
|
62
|
+
def log_domain_sinkhorn_knopps(alpha, iters = 20):
|
|
63
|
+
|
|
64
|
+
for _ in range(iters):
|
|
65
|
+
alpha = alpha - alpha.logsumexp(dim = -2, keepdim = True)
|
|
66
|
+
alpha = alpha - alpha.logsumexp(dim = -1, keepdim = True)
|
|
67
|
+
|
|
68
|
+
return alpha.exp()
|
|
69
|
+
|
|
62
70
|
# main functions
|
|
63
71
|
|
|
64
72
|
def get_expand_reduce_stream_functions(
|
|
@@ -201,6 +209,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
201
209
|
depth_residual_fn = add,
|
|
202
210
|
num_fracs = 1, # https://arxiv.org/abs/2503.14125
|
|
203
211
|
sinkhorn_iters = 20,
|
|
212
|
+
log_domain_sinkhorn = False,
|
|
204
213
|
forward_method_names: tuple[str, ...] = (),
|
|
205
214
|
):
|
|
206
215
|
"""
|
|
@@ -272,6 +281,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
272
281
|
# sinkhorn related
|
|
273
282
|
|
|
274
283
|
self.sinkhorn_iters = sinkhorn_iters
|
|
284
|
+
self.log_domain_sinkhorn = log_domain_sinkhorn
|
|
275
285
|
|
|
276
286
|
# dropouts
|
|
277
287
|
|
|
@@ -348,22 +358,20 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
348
358
|
|
|
349
359
|
alpha = dynamic_alpha + static_alpha
|
|
350
360
|
|
|
351
|
-
alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
|
|
352
|
-
|
|
353
361
|
# the alpha is now split and "manifold constrained" with sinkhorn and sigmoid
|
|
354
362
|
|
|
355
363
|
alpha_pre, alpha_residual = alpha[..., :self.num_input_views], alpha[..., self.num_input_views:]
|
|
356
364
|
|
|
357
365
|
alpha_pre = alpha_pre.sigmoid()
|
|
358
366
|
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
alpha_residual = sinkhorn_knopps(alpha_residual, self.sinkhorn_iters)
|
|
367
|
+
sinkhorn_fn = sinkhorn_knopps if not self.log_domain_sinkhorn else log_domain_sinkhorn_knopps
|
|
362
368
|
|
|
363
|
-
alpha_residual =
|
|
369
|
+
alpha_residual = sinkhorn_fn(alpha_residual, self.sinkhorn_iters)
|
|
364
370
|
|
|
365
371
|
alpha = cat((alpha_pre, alpha_residual), dim = -1)
|
|
366
372
|
|
|
373
|
+
alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
|
|
374
|
+
|
|
367
375
|
# beta for weights from branch output back to residual streams
|
|
368
376
|
|
|
369
377
|
beta = None
|
|
@@ -3,10 +3,10 @@ hyper_connections/hyper_connections.py,sha256=rqFJj3U0LF3uDKNKNPBpRrmf0oa2BGWVbD
|
|
|
3
3
|
hyper_connections/hyper_connections_channel_first.py,sha256=Mh_hzhTi96ZoOPmhSKwUaF4TbHpNqhs83wNe5hNuL7o,6532
|
|
4
4
|
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
|
|
5
5
|
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
|
|
6
|
-
hyper_connections/manifold_constrained_hyper_connections.py,sha256=
|
|
6
|
+
hyper_connections/manifold_constrained_hyper_connections.py,sha256=BekY37Gt5us7zZjxP6laDHCbygMww8kS5GatnveNrYw,17595
|
|
7
7
|
hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
|
|
8
8
|
hyper_connections/vit.py,sha256=fTC8hAYkD4qm-KURAj8SJ66C6ZWtsBdHf_kS-4rJZGQ,5049
|
|
9
|
-
hyper_connections-0.3.
|
|
10
|
-
hyper_connections-0.3.
|
|
11
|
-
hyper_connections-0.3.
|
|
12
|
-
hyper_connections-0.3.
|
|
9
|
+
hyper_connections-0.3.14.dist-info/METADATA,sha256=xncYFa2ttI2WOVLTy5-jk69dZ29aQa-QfFAgI4qJvHc,6705
|
|
10
|
+
hyper_connections-0.3.14.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
11
|
+
hyper_connections-0.3.14.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
12
|
+
hyper_connections-0.3.14.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|