hyper-connections 0.3.11__py3-none-any.whl → 0.3.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/manifold_constrained_hyper_connections.py +2 -6
- {hyper_connections-0.3.11.dist-info → hyper_connections-0.3.12.dist-info}/METADATA +1 -1
- {hyper_connections-0.3.11.dist-info → hyper_connections-0.3.12.dist-info}/RECORD +5 -5
- {hyper_connections-0.3.11.dist-info → hyper_connections-0.3.12.dist-info}/WHEEL +0 -0
- {hyper_connections-0.3.11.dist-info → hyper_connections-0.3.12.dist-info}/licenses/LICENSE +0 -0
|
@@ -348,22 +348,18 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
348
348
|
|
|
349
349
|
alpha = dynamic_alpha + static_alpha
|
|
350
350
|
|
|
351
|
-
alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
|
|
352
|
-
|
|
353
351
|
# the alpha is now split and "manifold constrained" with sinkhorn and sigmoid
|
|
354
352
|
|
|
355
353
|
alpha_pre, alpha_residual = alpha[..., :self.num_input_views], alpha[..., self.num_input_views:]
|
|
356
354
|
|
|
357
355
|
alpha_pre = alpha_pre.sigmoid()
|
|
358
356
|
|
|
359
|
-
alpha_residual = rearrange(alpha_residual, '... (v s1 s2) -> ... v s1 s2', v = self.num_input_views, s1 = streams)
|
|
360
|
-
|
|
361
357
|
alpha_residual = sinkhorn_knopps(alpha_residual, self.sinkhorn_iters)
|
|
362
358
|
|
|
363
|
-
alpha_residual = rearrange(alpha_residual, '... v s1 s2 -> ... (v s1 s2)')
|
|
364
|
-
|
|
365
359
|
alpha = cat((alpha_pre, alpha_residual), dim = -1)
|
|
366
360
|
|
|
361
|
+
alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
|
|
362
|
+
|
|
367
363
|
# beta for weights from branch output back to residual streams
|
|
368
364
|
|
|
369
365
|
beta = None
|
|
@@ -3,10 +3,10 @@ hyper_connections/hyper_connections.py,sha256=rqFJj3U0LF3uDKNKNPBpRrmf0oa2BGWVbD
|
|
|
3
3
|
hyper_connections/hyper_connections_channel_first.py,sha256=Mh_hzhTi96ZoOPmhSKwUaF4TbHpNqhs83wNe5hNuL7o,6532
|
|
4
4
|
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
|
|
5
5
|
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
|
|
6
|
-
hyper_connections/manifold_constrained_hyper_connections.py,sha256=
|
|
6
|
+
hyper_connections/manifold_constrained_hyper_connections.py,sha256=JF_ncPloyNLr303xZMlCiCorCE1Qt_soo1oeOyGYIdc,17168
|
|
7
7
|
hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
|
|
8
8
|
hyper_connections/vit.py,sha256=fTC8hAYkD4qm-KURAj8SJ66C6ZWtsBdHf_kS-4rJZGQ,5049
|
|
9
|
-
hyper_connections-0.3.
|
|
10
|
-
hyper_connections-0.3.
|
|
11
|
-
hyper_connections-0.3.
|
|
12
|
-
hyper_connections-0.3.
|
|
9
|
+
hyper_connections-0.3.12.dist-info/METADATA,sha256=OyTViiZQQ1AmohzVqRhf71HOFPjLEVMGPVY5kY6wHsg,6705
|
|
10
|
+
hyper_connections-0.3.12.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
11
|
+
hyper_connections-0.3.12.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
12
|
+
hyper_connections-0.3.12.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|