hyper-connections 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections.py +18 -6
- {hyper_connections-0.1.6.dist-info → hyper_connections-0.1.7.dist-info}/METADATA +1 -1
- {hyper_connections-0.1.6.dist-info → hyper_connections-0.1.7.dist-info}/RECORD +5 -5
- {hyper_connections-0.1.6.dist-info → hyper_connections-0.1.7.dist-info}/WHEEL +0 -0
- {hyper_connections-0.1.6.dist-info → hyper_connections-0.1.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -73,16 +73,18 @@ class Residual(Module):
|
|
|
73
73
|
self,
|
|
74
74
|
*args,
|
|
75
75
|
branch: Module | None = None,
|
|
76
|
+
residual_transform: Module | None = None,
|
|
76
77
|
**kwargs
|
|
77
78
|
):
|
|
78
79
|
super().__init__()
|
|
79
80
|
self.branch = branch
|
|
81
|
+
self.residual_transform = default(residual_transform, nn.Identity())
|
|
80
82
|
|
|
81
83
|
def width_connection(self, residuals):
|
|
82
84
|
return residuals, residuals, dict()
|
|
83
85
|
|
|
84
86
|
def depth_connection(self, branch_output, residuals):
|
|
85
|
-
return branch_output + residuals
|
|
87
|
+
return branch_output + self.residual_transform(residuals)
|
|
86
88
|
|
|
87
89
|
def decorate_branch(self, branch: Callable):
|
|
88
90
|
assert not exists(self.branch), 'branch was already wrapped on init'
|
|
@@ -128,7 +130,8 @@ class HyperConnections(Module):
|
|
|
128
130
|
layer_index = None,
|
|
129
131
|
tanh = True,
|
|
130
132
|
channel_first = False,
|
|
131
|
-
dropout = 0
|
|
133
|
+
dropout = 0.,
|
|
134
|
+
residual_transform: Module | None = None, # to support resnet blocks where the input dimension is not equal to the output dimension - usually a residual conv
|
|
132
135
|
):
|
|
133
136
|
"""
|
|
134
137
|
Appendix J, Algorithm 2 in https://arxiv.org/abs/2409.19606
|
|
@@ -168,7 +171,14 @@ class HyperConnections(Module):
|
|
|
168
171
|
|
|
169
172
|
self.channel_first = channel_first
|
|
170
173
|
|
|
174
|
+
# maybe residual transform
|
|
175
|
+
|
|
176
|
+
self.residual_transform = default(residual_transform, nn.Identity())
|
|
177
|
+
|
|
171
178
|
def width_connection(self, residuals):
|
|
179
|
+
|
|
180
|
+
maybe_transformed_residuals = self.residual_transform(residuals)
|
|
181
|
+
|
|
172
182
|
# width connection
|
|
173
183
|
|
|
174
184
|
if self.channel_first:
|
|
@@ -197,7 +207,7 @@ class HyperConnections(Module):
|
|
|
197
207
|
if self.channel_first:
|
|
198
208
|
branch_input = rearrange(branch_input, 'b ... d -> b d ...')
|
|
199
209
|
|
|
200
|
-
return branch_input,
|
|
210
|
+
return branch_input, maybe_transformed_residuals, dict(beta = beta)
|
|
201
211
|
|
|
202
212
|
def depth_connection(self, branch_output, residuals, *, beta):
|
|
203
213
|
# 'depth' connection
|
|
@@ -205,13 +215,15 @@ class HyperConnections(Module):
|
|
|
205
215
|
if self.channel_first:
|
|
206
216
|
branch_output = rearrange(branch_output, 'b d ... -> b ... d')
|
|
207
217
|
|
|
208
|
-
|
|
209
|
-
output = rearrange(
|
|
218
|
+
output = einsum(branch_output, beta, 'b ... d, b ... s -> b ... s d')
|
|
219
|
+
output = rearrange(output, 'b ... s d -> (b s) ... d')
|
|
210
220
|
|
|
211
221
|
if self.channel_first:
|
|
212
222
|
output = rearrange(output, 'b ... d -> b d ...')
|
|
213
223
|
|
|
214
|
-
|
|
224
|
+
residuals = residuals + output
|
|
225
|
+
|
|
226
|
+
return self.dropout(residuals)
|
|
215
227
|
|
|
216
228
|
def decorate_branch(self, branch: Callable):
|
|
217
229
|
assert not exists(self.branch), 'branch was already wrapped on init'
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=L2e4DduzPGdH30NhfHuiSiVZTwXRgeZW2MDAZ0Z-TKk,10541
|
|
3
3
|
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
|
|
4
4
|
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
|
|
5
|
-
hyper_connections-0.1.
|
|
6
|
-
hyper_connections-0.1.
|
|
7
|
-
hyper_connections-0.1.
|
|
8
|
-
hyper_connections-0.1.
|
|
5
|
+
hyper_connections-0.1.7.dist-info/METADATA,sha256=YThD719ySS2H6ABQnrNHKmoWI9vaGh5d1H9mbMKehV0,5230
|
|
6
|
+
hyper_connections-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
hyper_connections-0.1.7.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
8
|
+
hyper_connections-0.1.7.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|