hyper-connections 0.3.10__py3-none-any.whl → 0.3.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/manifold_constrained_hyper_connections.py +4 -8
- {hyper_connections-0.3.10.dist-info → hyper_connections-0.3.12.dist-info}/METADATA +1 -1
- {hyper_connections-0.3.10.dist-info → hyper_connections-0.3.12.dist-info}/RECORD +5 -5
- {hyper_connections-0.3.10.dist-info → hyper_connections-0.3.12.dist-info}/WHEEL +0 -0
- {hyper_connections-0.3.10.dist-info → hyper_connections-0.3.12.dist-info}/licenses/LICENSE +0 -0
|
@@ -348,22 +348,18 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
348
348
|
|
|
349
349
|
alpha = dynamic_alpha + static_alpha
|
|
350
350
|
|
|
351
|
-
alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
|
|
352
|
-
|
|
353
351
|
# the alpha is now split and "manifold constrained" with sinkhorn and sigmoid
|
|
354
352
|
|
|
355
353
|
alpha_pre, alpha_residual = alpha[..., :self.num_input_views], alpha[..., self.num_input_views:]
|
|
356
354
|
|
|
357
355
|
alpha_pre = alpha_pre.sigmoid()
|
|
358
356
|
|
|
359
|
-
alpha_residual = rearrange(alpha_residual, '... (v s1 s2) -> ... v s1 s2', v = self.num_input_views, s1 = streams)
|
|
360
|
-
|
|
361
357
|
alpha_residual = sinkhorn_knopps(alpha_residual, self.sinkhorn_iters)
|
|
362
358
|
|
|
363
|
-
alpha_residual = rearrange(alpha_residual, '... v s1 s2 -> ... (v s1 s2)')
|
|
364
|
-
|
|
365
359
|
alpha = cat((alpha_pre, alpha_residual), dim = -1)
|
|
366
360
|
|
|
361
|
+
alpha = self.split_fracs(alpha) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
|
|
362
|
+
|
|
367
363
|
# beta for weights from branch output back to residual streams
|
|
368
364
|
|
|
369
365
|
beta = None
|
|
@@ -371,8 +367,6 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
371
367
|
if self.add_branch_out_to_residual:
|
|
372
368
|
dc_weight = normed @ self.dynamic_beta_fn.float()
|
|
373
369
|
|
|
374
|
-
dc_weight = dc_weight.sigmoid() * 2 # sigmoid * 2 for "H_post", corresponding to dc weight in original paper
|
|
375
|
-
|
|
376
370
|
if not self.has_fracs:
|
|
377
371
|
dc_weight = rearrange(dc_weight, '... -> ... 1')
|
|
378
372
|
|
|
@@ -382,6 +376,8 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
382
376
|
|
|
383
377
|
beta = dynamic_beta + static_beta
|
|
384
378
|
|
|
379
|
+
beta = beta.sigmoid() * 2 # for "H_post" manifold constraint
|
|
380
|
+
|
|
385
381
|
mix_h = einsum(alpha, residuals.float(), '... f1 s f2 t, ... f1 s d -> ... f2 t d')
|
|
386
382
|
|
|
387
383
|
if self.num_input_views == 1:
|
|
@@ -3,10 +3,10 @@ hyper_connections/hyper_connections.py,sha256=rqFJj3U0LF3uDKNKNPBpRrmf0oa2BGWVbD
|
|
|
3
3
|
hyper_connections/hyper_connections_channel_first.py,sha256=Mh_hzhTi96ZoOPmhSKwUaF4TbHpNqhs83wNe5hNuL7o,6532
|
|
4
4
|
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
|
|
5
5
|
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
|
|
6
|
-
hyper_connections/manifold_constrained_hyper_connections.py,sha256=
|
|
6
|
+
hyper_connections/manifold_constrained_hyper_connections.py,sha256=JF_ncPloyNLr303xZMlCiCorCE1Qt_soo1oeOyGYIdc,17168
|
|
7
7
|
hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
|
|
8
8
|
hyper_connections/vit.py,sha256=fTC8hAYkD4qm-KURAj8SJ66C6ZWtsBdHf_kS-4rJZGQ,5049
|
|
9
|
-
hyper_connections-0.3.
|
|
10
|
-
hyper_connections-0.3.
|
|
11
|
-
hyper_connections-0.3.
|
|
12
|
-
hyper_connections-0.3.
|
|
9
|
+
hyper_connections-0.3.12.dist-info/METADATA,sha256=OyTViiZQQ1AmohzVqRhf71HOFPjLEVMGPVY5kY6wHsg,6705
|
|
10
|
+
hyper_connections-0.3.12.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
11
|
+
hyper_connections-0.3.12.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
12
|
+
hyper_connections-0.3.12.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|