hyper-connections 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections.py +12 -2
- hyper_connections/residuals.py +22 -0
- {hyper_connections-0.1.11.dist-info → hyper_connections-0.1.12.dist-info}/METADATA +1 -1
- {hyper_connections-0.1.11.dist-info → hyper_connections-0.1.12.dist-info}/RECORD +6 -5
- {hyper_connections-0.1.11.dist-info → hyper_connections-0.1.12.dist-info}/WHEEL +0 -0
- {hyper_connections-0.1.11.dist-info → hyper_connections-0.1.12.dist-info}/licenses/LICENSE +0 -0
|
@@ -33,6 +33,9 @@ def default(v, d):
|
|
|
33
33
|
def identity(t):
    """Pass ``t`` through untouched and hand back the very same object."""
    passthrough = t
    return passthrough
|
|
35
35
|
|
|
36
|
+
def add(x, y):
    """Combine ``x`` and ``y`` with the ``+`` operator (``x`` on the left).

    This is the default ``depth_residual_fn`` of ``HyperConnections``. There it
    receives the branch output as ``x`` and the residual streams as ``y``.
    """
    combined = x + y
    return combined
|
|
38
|
+
|
|
36
39
|
# main functions
|
|
37
40
|
|
|
38
41
|
def get_expand_reduce_stream_functions(num_streams, add_stream_embed = False, dim = None, disable = False):
|
|
@@ -143,7 +146,8 @@ class HyperConnections(Module):
|
|
|
143
146
|
dropout = 0.,
|
|
144
147
|
residual_transform: Module | None = None, # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
|
|
145
148
|
add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
|
|
146
|
-
num_input_views = 1 # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
|
|
149
|
+
num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
|
|
150
|
+
depth_residual_fn = add
|
|
147
151
|
):
|
|
148
152
|
"""
|
|
149
153
|
Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
|
|
@@ -199,6 +203,12 @@ class HyperConnections(Module):
|
|
|
199
203
|
|
|
200
204
|
self.residual_transform = default(residual_transform, nn.Identity())
|
|
201
205
|
|
|
206
|
+
# maybe custom depth connection residual function
|
|
207
|
+
# this is to prepare for gating the addition of the branch outputs to the residual streams
|
|
208
|
+
# needed for memory lanes a la RMT / LMM
|
|
209
|
+
|
|
210
|
+
self.depth_residual_fn = depth_residual_fn
|
|
211
|
+
|
|
202
212
|
def width_connection(self, residuals):
|
|
203
213
|
|
|
204
214
|
maybe_transformed_residuals = self.residual_transform(residuals)
|
|
@@ -254,7 +264,7 @@ class HyperConnections(Module):
|
|
|
254
264
|
if self.channel_first:
|
|
255
265
|
output = rearrange(output, 'b ... d -> b d ...')
|
|
256
266
|
|
|
257
|
-
residuals = residuals
|
|
267
|
+
residuals = self.depth_residual_fn(output, residuals)
|
|
258
268
|
|
|
259
269
|
return self.dropout(residuals)
|
|
260
270
|
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
from torch.nn import Module
|
|
4
|
+
|
|
5
|
+
from einops import rearrange
|
|
6
|
+
|
|
7
|
+
class GRUGatedResidual(Module):
    """Merge a branch output into a residual with a learned GRU gate.

    ``forward(x, residual)`` returns a tensor shaped like ``x``. ``residual`` is
    fed to an ``nn.GRUCell`` as the hidden state and ``x`` as the input, instead
    of simply adding them. Presumably meant to be plugged in as a custom
    ``depth_residual_fn`` (called as ``fn(branch_output, residuals)``) — confirm
    against callers.

    Args:
        dim: size of the trailing feature dimension of both inputs. It is used
            as both the input size and the hidden size of the GRU cell.
    """

    def __init__(
        self,
        dim
    ):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)

    def forward(self, x, residual):
        """Return the GRU-gated update of ``residual`` by ``x``, shaped like ``x``.

        Args:
            x: tensor whose last dimension has size ``dim``; used as the GRU input.
                Any number of leading dimensions is accepted. Note the feature
                axis must be last — channel-first tensors need rearranging first.
            residual: tensor with exactly the same shape as ``x``; used as the
                GRU hidden state.

        Raises:
            ValueError: if ``x`` and ``residual`` differ in shape. Flattening would
                otherwise silently pair mismatched positions.
        """
        if x.shape != residual.shape:
            raise ValueError(
                f'x and residual must have the same shape, '
                f'got {tuple(x.shape)} and {tuple(residual.shape)}'
            )

        dim = x.shape[-1]

        # nn.GRUCell works on (batch, features): fold every leading dim into the
        # batch so any rank (b n d, b h w d, ...) is supported, not only 3-D
        gated_output = self.gru(
            x.reshape(-1, dim),
            residual.reshape(-1, dim)
        )

        return gated_output.reshape_as(x)
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=vpipBRUGgYQ2qLBtT4Ws-myJYVdkQDkN3IkpTMkxRxc,12485
|
|
3
3
|
hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
|
|
4
4
|
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
|
|
5
5
|
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
|
|
6
|
-
hyper_connections
|
|
7
|
-
hyper_connections-0.1.
|
|
8
|
-
hyper_connections-0.1.
|
|
9
|
-
hyper_connections-0.1.
|
|
6
|
+
hyper_connections/residuals.py,sha256=F8NMh9vHDN8Nyo0nzIm_owgR9gXPZUCNujAuduYbTiU,465
|
|
7
|
+
hyper_connections-0.1.12.dist-info/METADATA,sha256=rRgT4Zhj2xUPSRPH0LKZmfJK4olGOxpxPR5ha9wEMQE,5231
|
|
8
|
+
hyper_connections-0.1.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
9
|
+
hyper_connections-0.1.12.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
10
|
+
hyper_connections-0.1.12.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|