hyper-connections 0.3.8-py3-none-any.whl → 0.3.10-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections.py +8 -8
- hyper_connections/hyper_connections_channel_first.py +3 -9
- hyper_connections/manifold_constrained_hyper_connections.py +7 -4
- hyper_connections/vit.py +163 -0
- {hyper_connections-0.3.8.dist-info → hyper_connections-0.3.10.dist-info}/METADATA +1 -1
- hyper_connections-0.3.10.dist-info/RECORD +12 -0
- hyper_connections-0.3.8.dist-info/RECORD +0 -11
- {hyper_connections-0.3.8.dist-info → hyper_connections-0.3.10.dist-info}/WHEEL +0 -0
- {hyper_connections-0.3.8.dist-info → hyper_connections-0.3.10.dist-info}/licenses/LICENSE +0 -0
hyper_connections/hyper_connections.py

@@ -175,7 +175,6 @@ class HyperConnections(Module):
         tanh = True,
         channel_first = False,
         dropout = 0.,
-        residual_transform: Module | None = None, # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
         add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
         num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
         depth_residual_fn = add,
@@ -255,10 +254,6 @@ class HyperConnections(Module):
 
         self.channel_first = channel_first
 
-        # maybe residual transform
-
-        self.residual_transform = default(residual_transform, nn.Identity())
-
         # maybe custom depth connection residual function
         # this is to prepare for gating the addition of the branch outputs to the residual streams
         # needed for memory lanes a la RMT / LMM
@@ -271,8 +266,6 @@ class HyperConnections(Module):
     ):
         streams = self.num_residual_streams
 
-        maybe_transformed_residuals = self.residual_transform(residuals)
-
        # width connection
 
        # handle channel first
@@ -334,7 +327,14 @@ class HyperConnections(Module):
 
         branch_input = self.merge_fracs(branch_input)
 
-
+        # reshape residuals back
+
+        if self.channel_first:
+            residuals = rearrange(residuals, 'b ... f s d -> (b s) (f d) ...')
+        else:
+            residuals = rearrange(residuals, 'b ... f s d -> (b s) ... (f d)')
+
+        return branch_input, residuals, dict(beta = beta)
 
     def depth_connection(
         self,
hyper_connections/hyper_connections_channel_first.py

@@ -84,7 +84,6 @@ class HyperConnections(Module):
         tanh = True,
         channel_first = True,
         dropout = 0.,
-        residual_transform: Module | None = None, # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
     ):
         """
         Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -124,19 +123,12 @@ class HyperConnections(Module):
         self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
 
-
         # dropouts
 
         self.dropout = nn.Dropout(dropout)
 
-        # maybe residual transform
-
-        self.residual_transform = default(residual_transform, nn.Identity())
-
     def width_connection(self, residuals):
 
-        maybe_transformed_residuals = self.residual_transform(residuals)
-
         # width connection
 
         normed = self.norm(residuals)
@@ -161,7 +153,9 @@ class HyperConnections(Module):
 
         branch_input, residuals = mix_h[:, 0, ...], mix_h[:, 1:, ...]
 
-
+        residuals = rearrange(residuals, 'b s d ... -> (b s) d ...')
+
+        return branch_input, residuals, dict(beta = beta)
 
     def depth_connection(self, branch_output, residuals, *, beta):
         # 'depth' connection
hyper_connections/manifold_constrained_hyper_connections.py

@@ -307,7 +307,7 @@ class ManifoldConstrainedHyperConnections(Module):
     ):
         streams = self.num_residual_streams
 
-
+        residuals = self.residual_transform(residuals)
 
         # width connection
 
@@ -397,13 +397,14 @@ class ManifoldConstrainedHyperConnections(Module):
 
         branch_input = self.merge_fracs(branch_input)
 
-
-
+        residuals = rearrange(residuals, 'b ... f s d -> (b s) ... (f d)')
+
+        branch_input, residuals = tuple(t.to(dtype) for t in (branch_input, residuals))
 
         if exists(beta):
             beta = beta.to(dtype)
 
-        return branch_input,
+        return branch_input, residuals, dict(beta = beta)
 
     def depth_connection(
         self,
@@ -486,6 +487,8 @@ class ManifoldConstrainedHyperConnections(Module):
 
         return add_residual_fn(branch_output)
 
+mHC = ManifoldConstrainedHyperConnections
+
 ManifoldConstrainedHyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
 ManifoldConstrainedHyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
 
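For context, the new mHC alias added above is what the new vit.py below imports. A minimal standalone sketch of the same wiring, mirroring how vit.py uses the alias; the nn.Linear branch, the dimension 512, and the tensor sizes here are illustrative placeholders rather than anything defined by the package:

import torch
from torch import nn

from hyper_connections.manifold_constrained_hyper_connections import mHC

# layer factory plus stream expand/reduce helpers for 4 residual streams,
# as in the Transformer class of the added vit.py
init_hyper_conn, expand_streams, reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(4)

# wrap any branch module (placeholder linear layer here) in a hyper connection
block = init_hyper_conn(dim = 512, branch = nn.Linear(512, 512))

x = torch.randn(2, 16, 512)

x = expand_streams(x)   # widen the single residual into 4 streams
x = block(x)            # branch output is folded back into the residual streams
x = reduce_streams(x)   # collapse the streams back to one

assert x.shape == (2, 16, 512)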
hyper_connections/vit.py
ADDED
@@ -0,0 +1,163 @@
+import torch
+from torch import nn
+from torch.nn import Module, ModuleList
+
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+
+from hyper_connections.manifold_constrained_hyper_connections import mHC
+
+# helpers
+
+def pair(t):
+    return t if isinstance(t, tuple) else (t, t)
+
+# classes
+
+class FeedForward(Module):
+    def __init__(self, dim, hidden_dim, dropout = 0.):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+class Attention(Module):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        project_out = not (heads == 1 and dim_head == dim)
+
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+
+        self.norm = nn.LayerNorm(dim)
+
+        self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
+
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        ) if project_out else nn.Identity()
+
+    def forward(self, x):
+        x = self.norm(x)
+
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
+
+        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+        attn = self.attend(dots)
+        attn = self.dropout(attn)
+
+        out = torch.matmul(attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
+class Transformer(Module):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.layers = ModuleList([])
+
+        init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams)
+
+        for _ in range(depth):
+            self.layers.append(ModuleList([
+                init_hyper_conn(dim = dim , branch = Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+                init_hyper_conn(dim = dim, branch = FeedForward(dim, mlp_dim, dropout = dropout))
+            ]))
+
+    def forward(self, x):
+
+        x = self.expand_streams(x)
+
+        for attn, ff in self.layers:
+            x = attn(x)
+            x = ff(x)
+
+        x = self.reduce_streams(x)
+
+        return self.norm(x)
+
+class ViT(Module):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4):
+        super().__init__()
+        image_height, image_width = pair(image_size)
+        patch_height, patch_width = pair(patch_size)
+
+        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
+
+        num_patches = (image_height // patch_height) * (image_width // patch_width)
+        patch_dim = channels * patch_height * patch_width
+
+        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
+        num_cls_tokens = 1 if pool == 'cls' else 0
+
+        self.to_patch_embedding = nn.Sequential(
+            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
+            nn.LayerNorm(patch_dim),
+            nn.Linear(patch_dim, dim),
+            nn.LayerNorm(dim),
+        )
+
+        self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
+
+        self.dropout = nn.Dropout(emb_dropout)
+
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+
+        self.pool = pool
+        self.to_latent = nn.Identity()
+
+        self.mlp_head = nn.Linear(dim, num_classes)
+
+    def forward(self, img):
+        batch = img.shape[0]
+        x = self.to_patch_embedding(img)
+
+        cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
+        x = torch.cat((cls_tokens, x), dim = 1)
+
+        seq = x.shape[1]
+
+        x = x + self.pos_embedding[:seq]
+        x = self.dropout(x)
+
+        x = self.transformer(x)
+
+        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
+
+        x = self.to_latent(x)
+        return self.mlp_head(x)
+
+if __name__ == '__main__':
+    v = ViT(
+        image_size = 256,
+        patch_size = 32,
+        num_classes = 1000,
+        dim = 1024,
+        depth = 6,
+        heads = 16,
+        mlp_dim = 2048,
+        dropout = 0.1,
+        emb_dropout = 0.1,
+        num_residual_streams = 4
+    )
+
+    img = torch.randn(1, 3, 256, 256)
+
+    preds = v(img) # (1, 1000)
+    assert preds.shape == (1, 1000)
hyper_connections-0.3.10.dist-info/RECORD
ADDED

@@ -0,0 +1,12 @@
+hyper_connections/__init__.py,sha256=BAGwi53ozXcnfPJAGur0RHA4vcolF1ORBhbZ9a8SkrE,602
+hyper_connections/hyper_connections.py,sha256=rqFJj3U0LF3uDKNKNPBpRrmf0oa2BGWVbD6S-xdZdLo,14904
+hyper_connections/hyper_connections_channel_first.py,sha256=Mh_hzhTi96ZoOPmhSKwUaF4TbHpNqhs83wNe5hNuL7o,6532
+hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
+hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
+hyper_connections/manifold_constrained_hyper_connections.py,sha256=SkGAWpBHnrOlIcixb0iIGej9StO82O7KXrFjYuSKx7I,17424
+hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
+hyper_connections/vit.py,sha256=fTC8hAYkD4qm-KURAj8SJ66C6ZWtsBdHf_kS-4rJZGQ,5049
+hyper_connections-0.3.10.dist-info/METADATA,sha256=tEYVvFTVY_13gYQbflz-mjWMQEwH4DvPvQk76X9Iq2E,6705
+hyper_connections-0.3.10.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+hyper_connections-0.3.10.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
+hyper_connections-0.3.10.dist-info/RECORD,,
hyper_connections-0.3.8.dist-info/RECORD
REMOVED

@@ -1,11 +0,0 @@
-hyper_connections/__init__.py,sha256=BAGwi53ozXcnfPJAGur0RHA4vcolF1ORBhbZ9a8SkrE,602
-hyper_connections/hyper_connections.py,sha256=UHxZhyRwx89GRgmQVt53Gv6JeNhX8UCjjETlydMZjTk,15021
-hyper_connections/hyper_connections_channel_first.py,sha256=_1PM4LRcPpDqfCiHlBMc2nLV08sXM2nuyZGSKTiuqbE,6818
-hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
-hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
-hyper_connections/manifold_constrained_hyper_connections.py,sha256=rNRh7Hkz-mgZD4GMImIpk3gPHFLsM2_cvVy8I0x2W5U,17339
-hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
-hyper_connections-0.3.8.dist-info/METADATA,sha256=kqntUf_yXJ9fNqxGvQ8XFERHOInCCDAqijtazZfFqes,6704
-hyper_connections-0.3.8.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-hyper_connections-0.3.8.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
-hyper_connections-0.3.8.dist-info/RECORD,,
{hyper_connections-0.3.8.dist-info → hyper_connections-0.3.10.dist-info}/WHEEL
File without changes

{hyper_connections-0.3.8.dist-info → hyper_connections-0.3.10.dist-info}/licenses/LICENSE
File without changes