hyper-connections 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -200,7 +200,8 @@ class ManifoldConstrainedHyperConnections(Module):
200
200
  num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
201
201
  depth_residual_fn = add,
202
202
  num_fracs = 1, # https://arxiv.org/abs/2503.14125
203
- sinkhorn_iters = 20
203
+ sinkhorn_iters = 20,
204
+ forward_method_names: tuple[str, ...] = (),
204
205
  ):
205
206
  """
206
207
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -290,6 +291,16 @@ class ManifoldConstrainedHyperConnections(Module):
290
291
 
291
292
  self.depth_residual_fn = depth_residual_fn
292
293
 
294
+ # forwarding method names
295
+
296
+ self.forward_method_names = forward_method_names
297
+
298
+ for forward_method_name in self.forward_method_names:
299
+ assert not hasattr(self, forward_method_name)
300
+
301
+ fn = getattr(self.branch, forward_method_name)
302
+ setattr(self, forward_method_name, fn)
303
+
293
304
  def width_connection(
294
305
  self,
295
306
  residuals
@@ -475,6 +486,8 @@ class ManifoldConstrainedHyperConnections(Module):
475
486
 
476
487
  return add_residual_fn(branch_output)
477
488
 
489
+ mHC = ManifoldConstrainedHyperConnections
490
+
478
491
  ManifoldConstrainedHyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
479
492
  ManifoldConstrainedHyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
480
493
 
@@ -0,0 +1,163 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import Module, ModuleList
4
+
5
+ from einops import rearrange, repeat
6
+ from einops.layers.torch import Rearrange
7
+
8
+ from hyper_connections.manifold_constrained_hyper_connections import mHC
9
+
10
+ # helpers
11
+
12
def pair(t):
    """Return `t` as-is when it is already a tuple, otherwise the 2-tuple `(t, t)`."""
    if isinstance(t, tuple):
        return t

    return (t, t)
14
+
15
+ # classes
16
+
17
class FeedForward(Module):
    """Pre-norm MLP block.

    Applies, in order: LayerNorm over `dim`, a linear map `dim -> hidden_dim`,
    GELU, dropout, a linear map `hidden_dim -> dim`, and a final dropout.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()

        # kept as a positional nn.Sequential so parameter names stay `net.<index>.*`
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run `x` (last dimension `dim`) through the MLP; the output's last dimension is `dim`."""
        out = self.net(x)
        return out
31
+
32
class Attention(Module):
    """Pre-norm multi-head self-attention.

    The input is layer-normalised, projected to queries / keys / values with a
    single bias-free linear layer, attended with scaled dot-product softmax
    attention, and projected back to `dim`.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # the output projection is skipped only for a single head whose width already equals `dim`
        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )

    def forward(self, x):
        """Attend over `x` of shape (batch, seq, dim); the result is passed through `to_out`."""
        normed = self.norm(x)

        # split the fused projection into q, k, v and move heads ahead of the sequence axis
        queries, keys, values = (
            rearrange(chunk, 'b n (h d) -> b h n d', h = self.heads)
            for chunk in self.to_qkv(normed).chunk(3, dim = -1)
        )

        sim = (queries @ keys.transpose(-1, -2)) * self.scale

        weights = self.attend(sim)
        weights = self.dropout(weights)

        # merge the heads back into the feature dimension
        merged = rearrange(weights @ values, 'b h n d -> b n (h d)')
        return self.to_out(merged)
67
+
68
class Transformer(Module):
    """Stack of `depth` (attention, feedforward) layers, each branch wrapped in an mHC hyper-connection.

    The input is expanded into residual streams before the first layer and
    reduced back after the last one, then layer-normalised.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = ModuleList([])

        stream_fns = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams)
        init_hyper_conn, expand_streams, reduce_streams = stream_fns

        self.expand_streams = expand_streams
        self.reduce_streams = reduce_streams

        for _ in range(depth):
            # each branch is built right before its wrapper, attention first, then feedforward
            attn_branch = Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)
            attn_block = init_hyper_conn(dim = dim, branch = attn_branch)

            ff_branch = FeedForward(dim, mlp_dim, dropout = dropout)
            ff_block = init_hyper_conn(dim = dim, branch = ff_branch)

            self.layers.append(ModuleList([attn_block, ff_block]))

    def forward(self, x):
        """Expand `x` into residual streams, apply every layer, reduce the streams and normalise."""
        streams = self.expand_streams(x)

        for layer in self.layers:
            attn, ff = layer
            streams = ff(attn(streams))

        reduced = self.reduce_streams(streams)
        return self.norm(reduced)
93
+
94
class ViT(Module):
    """Vision Transformer classifier built on the hyper-connected `Transformer`.

    Args:
        image_size: int or (height, width) tuple; must be divisible by `patch_size`.
        patch_size: int or (height, width) tuple of a single patch.
        num_classes: size of the output logits.
        dim: token embedding dimension.
        depth: number of transformer layers.
        heads: number of attention heads.
        mlp_dim: hidden dimension of the feedforward blocks.
        pool: 'cls' to read out a learned class token, 'mean' to average all tokens.
        channels: number of image channels.
        dim_head: dimension of each attention head.
        dropout: dropout rate used inside the transformer.
        emb_dropout: dropout rate applied to the embedded tokens.
        num_residual_streams: number of residual streams passed to the `Transformer`.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        num_cls_tokens = 1 if pool == 'cls' else 0

        # cut the image into flattened patches, then embed each patch into `dim`
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # with mean pooling there are zero cls tokens, so `cls_token` is an empty (0, dim) parameter
        self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
        self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))

        self.dropout = nn.Dropout(emb_dropout)

        # forward `num_residual_streams` so the caller's setting is not silently replaced by the Transformer default
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_residual_streams = num_residual_streams)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        """Classify `img` of shape (batch, channels, height, width); returns logits of shape (batch, num_classes)."""
        batch = img.shape[0]
        x = self.to_patch_embedding(img)

        # prepend the class token(s) to every sequence in the batch (a no-op when there are none)
        cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
        x = torch.cat((cls_tokens, x), dim = 1)

        seq = x.shape[1]

        x = x + self.pos_embedding[:seq]
        x = self.dropout(x)

        x = self.transformer(x)

        # read out either the average over all tokens or the first (class) token
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
145
+
146
if __name__ == '__main__':
    # smoke test: build a ViT, push one random image through it and check the logits shape

    config = dict(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1,
        num_residual_streams = 4
    )

    v = ViT(**config)

    img = torch.randn(1, 3, 256, 256)

    preds = v(img) # (1, 1000)
    assert preds.shape == (1, 1000)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.3.7
3
+ Version: 0.3.9
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -3,9 +3,10 @@ hyper_connections/hyper_connections.py,sha256=UHxZhyRwx89GRgmQVt53Gv6JeNhX8UCjjE
3
3
  hyper_connections/hyper_connections_channel_first.py,sha256=_1PM4LRcPpDqfCiHlBMc2nLV08sXM2nuyZGSKTiuqbE,6818
4
4
  hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
5
5
  hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
6
- hyper_connections/manifold_constrained_hyper_connections.py,sha256=gJzc9oHZjhC3S85HiXGQck5qBNcRWGCjVsDfMXwqPxo,16961
6
+ hyper_connections/manifold_constrained_hyper_connections.py,sha256=uF9WALGLeEBdfUm_p8O8ZTmmsk3L44gg-G1GW1SCMO0,17382
7
7
  hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
8
- hyper_connections-0.3.7.dist-info/METADATA,sha256=Pyj7qVEMj6Szb_PD80eYoN3O4fwv8DjsYSzFd8EY_bo,6704
9
- hyper_connections-0.3.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
10
- hyper_connections-0.3.7.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
11
- hyper_connections-0.3.7.dist-info/RECORD,,
8
+ hyper_connections/vit.py,sha256=fTC8hAYkD4qm-KURAj8SJ66C6ZWtsBdHf_kS-4rJZGQ,5049
9
+ hyper_connections-0.3.9.dist-info/METADATA,sha256=mAciMU5pRr1oxP1OvjUFnwJAxqU9RTOKjPs-I7xn1ns,6704
10
+ hyper_connections-0.3.9.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
+ hyper_connections-0.3.9.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
12
+ hyper_connections-0.3.9.dist-info/RECORD,,