hyper-connections 0.1.11__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -33,6 +33,9 @@ def default(v, d):
33
33
def identity(t):
    """Hand back the argument untouched (no-op placeholder transform)."""
    unchanged = t
    return unchanged
35
35
 
36
def add(x, y):
    """Default depth residual function: plain sum of branch output and residual.

    Called as ``add(branch_output, residuals)``; any pair of objects supporting
    ``+`` works, and the result is exactly ``x + y``.
    """
    combined = x + y
    return combined
38
+
36
39
  # main functions
37
40
 
38
41
  def get_expand_reduce_stream_functions(num_streams, add_stream_embed = False, dim = None, disable = False):
@@ -143,7 +146,8 @@ class HyperConnections(Module):
143
146
  dropout = 0.,
144
147
  residual_transform: Module | None = None, # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
145
148
  add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
146
- num_input_views = 1 # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
149
+ num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
150
+ depth_residual_fn = add
147
151
  ):
148
152
  """
149
153
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -199,6 +203,12 @@ class HyperConnections(Module):
199
203
 
200
204
  self.residual_transform = default(residual_transform, nn.Identity())
201
205
 
206
+ # maybe custom depth connection residual function
207
+ # this is to prepare for gating the addition of the branch outputs to the residual streams
208
+ # needed for memory lanes a la RMT / LMM
209
+
210
+ self.depth_residual_fn = depth_residual_fn
211
+
202
212
  def width_connection(self, residuals):
203
213
 
204
214
  maybe_transformed_residuals = self.residual_transform(residuals)
@@ -254,7 +264,7 @@ class HyperConnections(Module):
254
264
  if self.channel_first:
255
265
  output = rearrange(output, 'b ... d -> b d ...')
256
266
 
257
- residuals = residuals + output
267
+ residuals = self.depth_residual_fn(output, residuals)
258
268
 
259
269
  return self.dropout(residuals)
260
270
 
@@ -0,0 +1,38 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import Module
4
+
5
+ from einops import rearrange, pack, unpack
6
+
7
class GRUGatedResidual(Module):
    """Merge a branch output into its residual stream with a GRU cell.

    The branch output acts as the GRU input and the residual acts as the
    hidden state; the updated hidden state is returned. Every leading
    dimension is folded into the GRU batch axis and unfolded again afterwards,
    so the result has the leading shape of ``x``.
    """

    def __init__(
        self,
        dim
    ):
        super().__init__()
        # input width and hidden width are both `dim`
        self.gru = nn.GRUCell(dim, dim)

    def forward(self, x, residual):
        # keep the leading (batch-like) shape of x so it can be restored
        leading_shape = x.shape[:-1]

        # nn.GRUCell expects 2-d (N, dim) tensors: collapse all leading dims
        flat_input = x.reshape(-1, x.shape[-1])
        flat_hidden = residual.reshape(-1, residual.shape[-1])

        new_hidden = self.gru(flat_input, flat_hidden)

        return new_hidden.reshape(*leading_shape, new_hidden.shape[-1])
23
+
24
class GatedResidual(Module):
    """Learned per-feature sigmoid gate between a branch output and its residual.

    ``forward`` returns ``x + sigmoid(g) * (residual - x)``, where ``g`` is a
    linear projection of the feature-wise concatenation ``[x, residual]``.
    """

    def __init__(
        self,
        dim
    ):
        super().__init__()
        # projects the concatenated (x, residual) features, 2 * dim wide,
        # to one gate logit per feature
        self.to_learned_mix = nn.Linear(dim * 2, dim)

    def forward(self, x, residual):
        """Gate between ``x`` and ``residual``.

        Args:
            x: branch output whose last dimension has size ``dim``; any number
               of leading dimensions is accepted.
            residual: tensor with the same shape as ``x``.

        Returns:
            Tensor shaped like ``x``. A gate near 0 keeps ``x``; a gate near 1
            keeps ``residual``.
        """
        # concatenate on the feature axis only, so inputs of any rank produce
        # (..., 2 * dim), which is what the linear projection expects
        x_and_residual = torch.cat((x, residual), dim = -1)

        mix = self.to_learned_mix(x_and_residual)

        # lerp: x + sigmoid(mix) * (residual - x)
        out = x.lerp(residual, mix.sigmoid())
        return out
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.1.11
3
+ Version: 0.1.14
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -1,9 +1,10 @@
1
1
  hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
2
- hyper_connections/hyper_connections.py,sha256=mUImPtaTE8Paygs-6vq7l_mlph1CkU__jRcE4TFim_Y,12137
2
+ hyper_connections/hyper_connections.py,sha256=vpipBRUGgYQ2qLBtT4Ws-myJYVdkQDkN3IkpTMkxRxc,12485
3
3
  hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
4
4
  hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
5
5
  hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
6
- hyper_connections-0.1.11.dist-info/METADATA,sha256=Ck3udilJMrT1ABRkqkNhfEdkjQXAKFBtsUAAalEu3No,5231
7
- hyper_connections-0.1.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- hyper_connections-0.1.11.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
9
- hyper_connections-0.1.11.dist-info/RECORD,,
6
+ hyper_connections/residuals.py,sha256=qapN4lt51qNWKa5nX7whN4xcNORxMdr3bdUwIMQPdpQ,853
7
+ hyper_connections-0.1.14.dist-info/METADATA,sha256=7Agg3rGvMYkZEyX3n9yJPO4cNm8-9a33afV4Yc8r7WA,5231
8
+ hyper_connections-0.1.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
9
+ hyper_connections-0.1.14.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
10
+ hyper_connections-0.1.14.dist-info/RECORD,,