hyper-connections 0.4.6__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/mHCv2.py +22 -1
- hyper_connections/vit.py +9 -6
- {hyper_connections-0.4.6.dist-info → hyper_connections-0.4.8.dist-info}/METADATA +2 -1
- {hyper_connections-0.4.6.dist-info → hyper_connections-0.4.8.dist-info}/RECORD +6 -6
- {hyper_connections-0.4.6.dist-info → hyper_connections-0.4.8.dist-info}/WHEEL +0 -0
- {hyper_connections-0.4.6.dist-info → hyper_connections-0.4.8.dist-info}/licenses/LICENSE +0 -0
hyper_connections/mHCv2.py
CHANGED
|
@@ -13,6 +13,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
13
13
|
from einops import rearrange, repeat, reduce, einsum
|
|
14
14
|
from einops.layers.torch import Rearrange, Reduce
|
|
15
15
|
|
|
16
|
+
from torch_einops_utils import pack_with_inverse
|
|
17
|
+
|
|
16
18
|
"""
|
|
17
19
|
ein notation:
|
|
18
20
|
b - batch
|
|
@@ -124,8 +126,8 @@ def get_init_and_expand_reduce_stream_functions(
|
|
|
124
126
|
|
|
125
127
|
hyper_conn_klass = ManifoldConstrainedHyperConnections if not disable else Residual
|
|
126
128
|
|
|
127
|
-
kwargs.pop('add_attn_pool_reduce_stream', None)
|
|
128
129
|
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs, sinkhorn_iters = sinkhorn_iters, use_triton_sinkhorn = use_triton_sinkhorn, **kwargs)
|
|
130
|
+
|
|
129
131
|
expand_reduce_fns = get_expand_reduce_stream_functions(
|
|
130
132
|
num_streams,
|
|
131
133
|
add_stream_embed = add_stream_embed,
|
|
@@ -241,6 +243,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
241
243
|
forward_method_names: tuple[str, ...] = (),
|
|
242
244
|
num_dynamic_alpha_proposals = 1,
|
|
243
245
|
use_triton_sinkhorn = False,
|
|
246
|
+
mix_streams_before_norm = False, # whether to mix the residual streams before the norm (that then projects to Hpre, Hpost, Hresidual)
|
|
244
247
|
):
|
|
245
248
|
"""
|
|
246
249
|
Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
|
|
@@ -263,6 +266,16 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
263
266
|
|
|
264
267
|
dim //= num_fracs # effective dim handled in dimension is feature dimension divided by num fractions
|
|
265
268
|
|
|
269
|
+
# whether to mix the streams before the norm below
|
|
270
|
+
# this would be equivalent to separable depthwise convs from yesteryears (with a norm in between) - parameter efficient improv
|
|
271
|
+
|
|
272
|
+
self.maybe_mix_streams = None
|
|
273
|
+
|
|
274
|
+
if mix_streams_before_norm:
|
|
275
|
+
self.maybe_mix_streams = nn.Conv2d(num_residual_streams, num_residual_streams, 1, bias = False)
|
|
276
|
+
|
|
277
|
+
nn.init.dirac_(self.maybe_mix_streams.weight)
|
|
278
|
+
|
|
266
279
|
# they used layernorm in paper, but rmsnorm is fine given what we know now
|
|
267
280
|
|
|
268
281
|
self.norm = RMSNorm(dim)
|
|
@@ -370,6 +383,14 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
370
383
|
|
|
371
384
|
residuals = self.split_fracs(residuals)
|
|
372
385
|
|
|
386
|
+
# maybe mix streams
|
|
387
|
+
|
|
388
|
+
if exists(self.maybe_mix_streams):
|
|
389
|
+
|
|
390
|
+
residuals, inverse_pack_lead_dims = pack_with_inverse(residuals, '* c h w')
|
|
391
|
+
residuals = self.maybe_mix_streams(residuals)
|
|
392
|
+
residuals = inverse_pack_lead_dims(residuals)
|
|
393
|
+
|
|
373
394
|
# norm
|
|
374
395
|
|
|
375
396
|
normed = self.norm(residuals)
|
hyper_connections/vit.py
CHANGED
|
@@ -5,7 +5,7 @@ from torch.nn import Module, ModuleList
|
|
|
5
5
|
from einops import rearrange, repeat
|
|
6
6
|
from einops.layers.torch import Rearrange
|
|
7
7
|
|
|
8
|
-
from hyper_connections.
|
|
8
|
+
from hyper_connections.mHCv2 import mHC
|
|
9
9
|
|
|
10
10
|
# helpers
|
|
11
11
|
|
|
@@ -66,12 +66,12 @@ class Attention(Module):
|
|
|
66
66
|
return self.to_out(out)
|
|
67
67
|
|
|
68
68
|
class Transformer(Module):
|
|
69
|
-
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4,
|
|
69
|
+
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4, mhc_kwargs = dict()):
|
|
70
70
|
super().__init__()
|
|
71
71
|
self.norm = nn.LayerNorm(dim)
|
|
72
72
|
self.layers = ModuleList([])
|
|
73
73
|
|
|
74
|
-
init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams,
|
|
74
|
+
init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams, **mhc_kwargs)
|
|
75
75
|
|
|
76
76
|
for _ in range(depth):
|
|
77
77
|
self.layers.append(ModuleList([
|
|
@@ -92,7 +92,7 @@ class Transformer(Module):
|
|
|
92
92
|
return self.norm(x)
|
|
93
93
|
|
|
94
94
|
class ViT(Module):
|
|
95
|
-
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4, num_dynamic_alpha_proposals = 1):
|
|
95
|
+
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4, mhc_kwargs = dict(num_dynamic_alpha_proposals = 1)):
|
|
96
96
|
super().__init__()
|
|
97
97
|
image_height, image_width = pair(image_size)
|
|
98
98
|
patch_height, patch_width = pair(patch_size)
|
|
@@ -117,7 +117,7 @@ class ViT(Module):
|
|
|
117
117
|
|
|
118
118
|
self.dropout = nn.Dropout(emb_dropout)
|
|
119
119
|
|
|
120
|
-
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_residual_streams,
|
|
120
|
+
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_residual_streams, mhc_kwargs)
|
|
121
121
|
|
|
122
122
|
self.pool = pool
|
|
123
123
|
self.to_latent = nn.Identity()
|
|
@@ -154,7 +154,10 @@ if __name__ == '__main__':
|
|
|
154
154
|
mlp_dim = 2048,
|
|
155
155
|
dropout = 0.1,
|
|
156
156
|
emb_dropout = 0.1,
|
|
157
|
-
num_residual_streams = 4
|
|
157
|
+
num_residual_streams = 4,
|
|
158
|
+
mhc_kwargs = dict(
|
|
159
|
+
use_triton_sinkhorn = False
|
|
160
|
+
)
|
|
158
161
|
)
|
|
159
162
|
|
|
160
163
|
img = torch.randn(1, 3, 256, 256)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.4.6
|
|
3
|
+
Version: 0.4.8
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
|
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
36
|
Requires-Python: >=3.9
|
|
37
37
|
Requires-Dist: einops>=0.8.1
|
|
38
|
+
Requires-Dist: torch-einops-utils>=0.0.20
|
|
38
39
|
Requires-Dist: torch>=2.5
|
|
39
40
|
Provides-Extra: examples
|
|
40
41
|
Description-Content-Type: text/markdown
|
|
@@ -3,12 +3,12 @@ hyper_connections/hyper_connections.py,sha256=2F-104cGE82KCK0KeC07NSOJNPT-0PCtvX
|
|
|
3
3
|
hyper_connections/hyper_connections_channel_first.py,sha256=5vAen4WXxNI9K07ndLBQJwdJv-OjoXznta5EIQTaQNA,6512
|
|
4
4
|
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=yn2AlFB6qCYQeRhhhaMlCM3mxxLEdWCYwU2p9TsMwWI,7835
|
|
5
5
|
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=iFPw5pgCRHTo16nBJ2PExKSrvTyCh7ba7Py14P1oSPE,11311
|
|
6
|
-
hyper_connections/mHCv2.py,sha256=
|
|
6
|
+
hyper_connections/mHCv2.py,sha256=LpMtlrb7Vfi2qq_cqPl9fajA5SxkMTl5QGpmvBJyD1M,18360
|
|
7
7
|
hyper_connections/manifold_constrained_hyper_connections.py,sha256=E4os-6q_SMjJO1JD0EG8rFTCXA7MQoy-aqUlM7KVS5Q,18269
|
|
8
8
|
hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
|
|
9
9
|
hyper_connections/triton_sinkhorn.py,sha256=n2WyQcUemtv5T5Sk2nljnSpV2hEED4I3HaPsIUy4638,5905
|
|
10
|
-
hyper_connections/vit.py,sha256=
|
|
11
|
-
hyper_connections-0.4.
|
|
12
|
-
hyper_connections-0.4.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
13
|
-
hyper_connections-0.4.6.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
14
|
-
hyper_connections-0.4.6.dist-info/RECORD,,
|
|
10
|
+
hyper_connections/vit.py,sha256=dh8AVMUPaUHuWxXJEHoMW_G5nj-EQQjDmgbPwwhiq5g,5215
|
|
11
|
+
hyper_connections-0.4.8.dist-info/METADATA,sha256=vevhBHad-7ffu1KBFcazUqU5C2XVRy1LlZkIxJUNDIs,6746
|
|
12
|
+
hyper_connections-0.4.8.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
13
|
+
hyper_connections-0.4.8.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
14
|
+
hyper_connections-0.4.8.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|