hyper-connections 0.0.17.tar.gz → 0.0.19.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.0.17 → hyper_connections-0.0.19}/PKG-INFO +1 -1
- {hyper_connections-0.0.17 → hyper_connections-0.0.19}/hyper_connections/__init__.py +2 -1
- {hyper_connections-0.0.17 → hyper_connections-0.0.19}/hyper_connections/hyper_connections.py +40 -2
- {hyper_connections-0.0.17 → hyper_connections-0.0.19}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +6 -1
- {hyper_connections-0.0.17 → hyper_connections-0.0.19}/pyproject.toml +1 -1
- {hyper_connections-0.0.17 → hyper_connections-0.0.19}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.17 → hyper_connections-0.0.19}/.gitignore +0 -0
- {hyper_connections-0.0.17 → hyper_connections-0.0.19}/LICENSE +0 -0
- {hyper_connections-0.0.17 → hyper_connections-0.0.19}/README.md +0 -0
- {hyper_connections-0.0.17 → hyper_connections-0.0.19}/hyper-connections.png +0 -0
{hyper_connections-0.0.17 → hyper_connections-0.0.19}/hyper_connections/hyper_connections.py
RENAMED

@@ -86,7 +86,8 @@ class HyperConnections(Module):
         branch: Module | None = None,
         layer_index = None,
         tanh = True,
-        channel_first = False
+        channel_first = False,
+        dropout = 0.
     ):
         """
         Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -116,6 +117,10 @@ class HyperConnections(Module):
         self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
 
+        # dropouts
+
+        self.dropout = nn.Dropout(dropout)
+
         # channel first option
 
         self.channel_first = channel_first
@@ -184,7 +189,7 @@ class HyperConnections(Module):
         if self.channel_first:
             output = rearrange(output, 'b ... d -> b d ...')
 
-        return output
+        return self.dropout(output)
 
     def decorate_branch(self, branch: Callable):
         assert not exists(self.branch), 'branch was already wrapped on init'
@@ -253,3 +258,36 @@ class StreamEmbed(Module):
         residuals = rearrange(residuals, 'b ... s d -> (b s) ... d', s = self.num_streams)
 
         return residuals
+
+# attention pool - taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
+
+class AttentionPoolReduceStream(Module):
+    def __init__(
+        self,
+        num_streams,
+        dim,
+        channel_first = False
+    ):
+        super().__init__()
+        self.num_streams = num_streams
+        self.channel_first = channel_first
+
+        self.to_attn_logits = nn.Linear(dim, dim, bias = False)
+        self.to_attn_logits.weight.data.copy_(torch.eye(dim))
+
+    def forward(self, residuals):
+
+        if self.channel_first:
+            residuals = rearrange(residuals, '(b s) d ... -> b ... s d', s = self.num_streams)
+        else:
+            residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = self.num_streams)
+
+        attn_logits = self.to_attn_logits(residuals)
+        attn = attn_logits.softmax(dim = -2)
+
+        residuals = reduce(residuals * attn, 'b ... s d -> b ... d', 'sum')
+
+        if self.channel_first:
+            residuals = rearrange(residuals, 'b ... d -> b d ...')
+
+        return residuals
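For orientation, here is a minimal usage sketch of the new AttentionPoolReduceStream reducer introduced above. The import path, tensor shapes, and argument values are assumptions inferred from the hunk (residual streams folded into the batch as '(b s) ... d'), not taken from the package's documentation.

import torch
# assumed import path: the class is defined in hyper_connections/hyper_connections.py per the hunk above
from hyper_connections.hyper_connections import AttentionPoolReduceStream

num_streams, dim = 4, 512

# hypothetical residuals with the stream axis folded into the batch: '(b s) ... d'
residuals = torch.randn(2 * num_streams, 1024, dim)

# identity-initialized linear produces per-element attention logits; the softmax runs over the stream axis
reduce_stream = AttentionPoolReduceStream(num_streams, dim)

pooled = reduce_stream(residuals)   # (2, 1024, 512): streams collapsed by a softmax-weighted sum

This gives a learned, per-position weighting over the residual streams when collapsing them back to a single stream, as an alternative to a plain reduction.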
{hyper_connections-0.0.17 → hyper_connections-0.0.19}/hyper_connections/hyper_connections_with_multi_branch_inputs.py
RENAMED

@@ -42,6 +42,7 @@ class HyperConnections(Module):
         layer_index = None,
         tanh = True,
         channel_first = False,
+        dropout = 0.,
         num_branch_inputs = 1 # residuals will be linearly combined to multiple inputs, fed through the branch, then linearly combined back out to residuals
     ):
         """
@@ -89,6 +90,10 @@ class HyperConnections(Module):
         self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim, num_branch_inputs))
         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
 
+        # dropout
+
+        self.dropout = nn.Dropout(dropout)
+
         # channel first option
 
         self.channel_first = channel_first
@@ -164,7 +169,7 @@ class HyperConnections(Module):
         if self.channel_first:
             output = rearrange(output, 'b ... d -> b d ...')
 
-        return output
+        return self.dropout(output)
 
     def decorate_branch(self, branch: Callable | tuple[Callable, ...] | list[Callable]):
         assert not exists(self.branches), 'branch was already wrapped on init'
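The multi-branch variant gains the same dropout knob. Below is a hedged construction sketch; the positional number-of-streams argument and the dim keyword are assumptions carried over from earlier releases (they are not part of this diff), so treat the exact signature as illustrative only.

import torch
from torch import nn
# assumed import: prior releases expose HyperConnections from the package root
from hyper_connections import HyperConnections

dim = 512

# hypothetical branch module to wrap
branch = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

hyper_conn = HyperConnections(
    4,              # assumed: number of residual streams, as in earlier releases
    dim = dim,
    branch = branch,
    dropout = 0.1   # new in 0.0.19: the returned output passes through nn.Dropout, per the hunks above
)

In both files the change is the same: the output returned by the hyper connection now passes through the new nn.Dropout module rather than dropout being applied inside the branch.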