hyper-connections 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/mHCv2.py +21 -0
- {hyper_connections-0.4.7.dist-info → hyper_connections-0.4.8.dist-info}/METADATA +2 -1
- {hyper_connections-0.4.7.dist-info → hyper_connections-0.4.8.dist-info}/RECORD +5 -5
- {hyper_connections-0.4.7.dist-info → hyper_connections-0.4.8.dist-info}/WHEEL +0 -0
- {hyper_connections-0.4.7.dist-info → hyper_connections-0.4.8.dist-info}/licenses/LICENSE +0 -0
hyper_connections/mHCv2.py
CHANGED
|
@@ -13,6 +13,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
13
13
|
from einops import rearrange, repeat, reduce, einsum
|
|
14
14
|
from einops.layers.torch import Rearrange, Reduce
|
|
15
15
|
|
|
16
|
+
from torch_einops_utils import pack_with_inverse
|
|
17
|
+
|
|
16
18
|
"""
|
|
17
19
|
ein notation:
|
|
18
20
|
b - batch
|
|
@@ -241,6 +243,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
241
243
|
forward_method_names: tuple[str, ...] = (),
|
|
242
244
|
num_dynamic_alpha_proposals = 1,
|
|
243
245
|
use_triton_sinkhorn = False,
|
|
246
|
+
mix_streams_before_norm = False, # whether to mix the residual streams before the norm (that then projects to Hpre, Hpost, Hresidual)
|
|
244
247
|
):
|
|
245
248
|
"""
|
|
246
249
|
Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
|
|
@@ -263,6 +266,16 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
263
266
|
|
|
264
267
|
dim //= num_fracs # effective dim handled in dimension is feature dimension divided by num fractions
|
|
265
268
|
|
|
269
|
+
# whether to mix the streams before the norm below
|
|
270
|
+
# this would be equivalent to separable depthwise convs from yesteryears (with a norm in between) - parameter efficient improv
|
|
271
|
+
|
|
272
|
+
self.maybe_mix_streams = None
|
|
273
|
+
|
|
274
|
+
if mix_streams_before_norm:
|
|
275
|
+
self.maybe_mix_streams = nn.Conv2d(num_residual_streams, num_residual_streams, 1, bias = False)
|
|
276
|
+
|
|
277
|
+
nn.init.dirac_(self.maybe_mix_streams.weight)
|
|
278
|
+
|
|
266
279
|
# they used layernorm in paper, but rmsnorm is fine given what we know now
|
|
267
280
|
|
|
268
281
|
self.norm = RMSNorm(dim)
|
|
@@ -370,6 +383,14 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
370
383
|
|
|
371
384
|
residuals = self.split_fracs(residuals)
|
|
372
385
|
|
|
386
|
+
# maybe mix streams
|
|
387
|
+
|
|
388
|
+
if exists(self.maybe_mix_streams):
|
|
389
|
+
|
|
390
|
+
residuals, inverse_pack_lead_dims = pack_with_inverse(residuals, '* c h w')
|
|
391
|
+
residuals = self.maybe_mix_streams(residuals)
|
|
392
|
+
residuals = inverse_pack_lead_dims(residuals)
|
|
393
|
+
|
|
373
394
|
# norm
|
|
374
395
|
|
|
375
396
|
normed = self.norm(residuals)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.4.7
|
|
3
|
+
Version: 0.4.8
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
|
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
36
|
Requires-Python: >=3.9
|
|
37
37
|
Requires-Dist: einops>=0.8.1
|
|
38
|
+
Requires-Dist: torch-einops-utils>=0.0.20
|
|
38
39
|
Requires-Dist: torch>=2.5
|
|
39
40
|
Provides-Extra: examples
|
|
40
41
|
Description-Content-Type: text/markdown
|
|
@@ -3,12 +3,12 @@ hyper_connections/hyper_connections.py,sha256=2F-104cGE82KCK0KeC07NSOJNPT-0PCtvX
|
|
|
3
3
|
hyper_connections/hyper_connections_channel_first.py,sha256=5vAen4WXxNI9K07ndLBQJwdJv-OjoXznta5EIQTaQNA,6512
|
|
4
4
|
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=yn2AlFB6qCYQeRhhhaMlCM3mxxLEdWCYwU2p9TsMwWI,7835
|
|
5
5
|
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=iFPw5pgCRHTo16nBJ2PExKSrvTyCh7ba7Py14P1oSPE,11311
|
|
6
|
-
hyper_connections/mHCv2.py,sha256=
|
|
6
|
+
hyper_connections/mHCv2.py,sha256=LpMtlrb7Vfi2qq_cqPl9fajA5SxkMTl5QGpmvBJyD1M,18360
|
|
7
7
|
hyper_connections/manifold_constrained_hyper_connections.py,sha256=E4os-6q_SMjJO1JD0EG8rFTCXA7MQoy-aqUlM7KVS5Q,18269
|
|
8
8
|
hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
|
|
9
9
|
hyper_connections/triton_sinkhorn.py,sha256=n2WyQcUemtv5T5Sk2nljnSpV2hEED4I3HaPsIUy4638,5905
|
|
10
10
|
hyper_connections/vit.py,sha256=dh8AVMUPaUHuWxXJEHoMW_G5nj-EQQjDmgbPwwhiq5g,5215
|
|
11
|
-
hyper_connections-0.4.
|
|
12
|
-
hyper_connections-0.4.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
13
|
-
hyper_connections-0.4.7.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
14
|
-
hyper_connections-0.4.7.dist-info/RECORD,,
|
|
11
|
+
hyper_connections-0.4.8.dist-info/METADATA,sha256=vevhBHad-7ffu1KBFcazUqU5C2XVRy1LlZkIxJUNDIs,6746
|
|
12
|
+
hyper_connections-0.4.8.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
13
|
+
hyper_connections-0.4.8.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
14
|
+
hyper_connections-0.4.8.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|