hyper-connections 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,6 +19,7 @@ b - batch
19
19
  d - feature dimension
20
20
  s - residual streams
21
21
  t - residual streams + num branch inputs
22
+ v - number of views for branch input
22
23
  """
23
24
 
24
25
  # helper functions
@@ -32,6 +33,9 @@ def default(v, d):
32
33
  def identity(t):
33
34
  return t
34
35
 
36
+ def add(x, y):
37
+ return x + y
38
+
35
39
  # main functions
36
40
 
37
41
  def get_expand_reduce_stream_functions(num_streams, add_stream_embed = False, dim = None, disable = False):
@@ -141,7 +145,9 @@ class HyperConnections(Module):
141
145
  channel_first = False,
142
146
  dropout = 0.,
143
147
  residual_transform: Module | None = None, # to support resnet blocks where dimension in is not equal to dimension out - usually a residual conv
144
- add_branch_out_to_residual = True # will disable depth connections (weighted residual sum with beta) if set False
148
+ add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
149
+ num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
150
+ depth_residual_fn = add
145
151
  ):
146
152
  """
147
153
  Appendix J, Algorithm 2 in - https://arxiv.org/abs/2409.19606
@@ -161,14 +167,19 @@ class HyperConnections(Module):
161
167
  self.num_residual_streams = num_residual_streams
162
168
  init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
163
169
 
170
+ # number of input views for the branch
171
+
172
+ assert num_input_views >= 1
173
+ self.num_input_views = num_input_views
174
+
164
175
  # width connection
165
176
 
166
- init_alpha0 = torch.zeros((num_residual_streams, 1))
167
- init_alpha0[init_residual_index, 0] = 1.
177
+ init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
178
+ init_alpha0[init_residual_index, :] = 1.
168
179
 
169
180
  self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
170
181
 
171
- self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
182
+ self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
172
183
  self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
173
184
 
174
185
  # depth connection related (beta)
@@ -192,6 +203,12 @@ class HyperConnections(Module):
192
203
 
193
204
  self.residual_transform = default(residual_transform, nn.Identity())
194
205
 
206
+ # maybe custom depth connection residual function
207
+ # this is to prepare for gating the addition of the branch outputs to the residual streams
208
+ # needed for memory lanes a la RMT / LMM
209
+
210
+ self.depth_residual_fn = depth_residual_fn
211
+
195
212
  def width_connection(self, residuals):
196
213
 
197
214
  maybe_transformed_residuals = self.residual_transform(residuals)
@@ -222,7 +239,11 @@ class HyperConnections(Module):
222
239
 
223
240
  mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
224
241
 
225
- branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
242
+ if self.num_input_views == 1:
243
+ branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
244
+ else:
245
+ branch_input, residuals = mix_h[..., :self.num_input_views, :], mix_h[..., self.num_input_views:, :]
246
+ branch_input = rearrange(branch_input, 'b ... v d -> v b ... d')
226
247
 
227
248
  if self.channel_first:
228
249
  branch_input = rearrange(branch_input, 'b ... d -> b d ...')
@@ -243,7 +264,7 @@ class HyperConnections(Module):
243
264
  if self.channel_first:
244
265
  output = rearrange(output, 'b ... d -> b d ...')
245
266
 
246
- residuals = residuals + output
267
+ residuals = self.depth_residual_fn(output, residuals)
247
268
 
248
269
  return self.dropout(residuals)
249
270
 
@@ -0,0 +1,22 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import Module
4
+
5
+ from einops import rearrange
6
+
7
+ class GRUGatedResidual(Module):
8
+ def __init__(
9
+ self,
10
+ dim
11
+ ):
12
+ super().__init__()
13
+ self.gru = nn.GRUCell(dim, dim)
14
+
15
+ def forward(self, x, residual):
16
+
17
+ gated_output = self.gru(
18
+ rearrange(x, 'b n d -> (b n) d'),
19
+ rearrange(residual, 'b n d -> (b n) d')
20
+ )
21
+
22
+ return gated_output.reshape_as(x)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.1.10
3
+ Version: 0.1.12
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -1,9 +1,10 @@
1
1
  hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
2
- hyper_connections/hyper_connections.py,sha256=Jk7Ux8fJPz63EkgZgKa7fQqpqCasr6cLZt7Fd06dPoE,11563
2
+ hyper_connections/hyper_connections.py,sha256=vpipBRUGgYQ2qLBtT4Ws-myJYVdkQDkN3IkpTMkxRxc,12485
3
3
  hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
4
4
  hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
5
5
  hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
6
- hyper_connections-0.1.10.dist-info/METADATA,sha256=K2EgcNxhmXRGTOEbVvsVsQl_dKLZ6iw88dzZqD6zaf4,5231
7
- hyper_connections-0.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- hyper_connections-0.1.10.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
9
- hyper_connections-0.1.10.dist-info/RECORD,,
6
+ hyper_connections/residuals.py,sha256=F8NMh9vHDN8Nyo0nzIm_owgR9gXPZUCNujAuduYbTiU,465
7
+ hyper_connections-0.1.12.dist-info/METADATA,sha256=rRgT4Zhj2xUPSRPH0LKZmfJK4olGOxpxPR5ha9wEMQE,5231
8
+ hyper_connections-0.1.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
9
+ hyper_connections-0.1.12.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
10
+ hyper_connections-0.1.12.dist-info/RECORD,,