hyper-connections 0.0.17.tar.gz → 0.0.19.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
{hyper-connections-0.0.17 → hyper-connections-0.0.19}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hyper-connections
-Version: 0.0.17
+Version: 0.0.19
 Summary: Hyper-Connections
 Project-URL: Homepage, https://pypi.org/project/hyper-connections/
 Project-URL: Repository, https://github.com/lucidrains/hyper-connections
{hyper-connections-0.0.17 → hyper-connections-0.0.19}/hyper_connections/__init__.py

@@ -1,5 +1,6 @@
 from hyper_connections.hyper_connections import (
     HyperConnections,
     Residual,
-    StreamEmbed
+    StreamEmbed,
+    AttentionPoolReduceStream
 )
{hyper-connections-0.0.17 → hyper-connections-0.0.19}/hyper_connections/hyper_connections.py

@@ -86,7 +86,8 @@ class HyperConnections(Module):
         branch: Module | None = None,
         layer_index = None,
         tanh = True,
-        channel_first = False
+        channel_first = False,
+        dropout = 0.
     ):
         """
         Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -116,6 +117,10 @@ class HyperConnections(Module):
         self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
 
+        # dropouts
+
+        self.dropout = nn.Dropout(dropout)
+
         # channel first option
 
         self.channel_first = channel_first
@@ -184,7 +189,7 @@ class HyperConnections(Module):
         if self.channel_first:
             output = rearrange(output, 'b ... d -> b d ...')
 
-        return output
+        return self.dropout(output)
 
     def decorate_branch(self, branch: Callable):
         assert not exists(self.branch), 'branch was already wrapped on init'
@@ -253,3 +258,36 @@ class StreamEmbed(Module):
         residuals = rearrange(residuals, 'b ... s d -> (b s) ... d', s = self.num_streams)
 
         return residuals
+
+# attention pool - taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
+
+class AttentionPoolReduceStream(Module):
+    def __init__(
+        self,
+        num_streams,
+        dim,
+        channel_first = False
+    ):
+        super().__init__()
+        self.num_streams = num_streams
+        self.channel_first = channel_first
+
+        self.to_attn_logits = nn.Linear(dim, dim, bias = False)
+        self.to_attn_logits.weight.data.copy_(torch.eye(dim))
+
+    def forward(self, residuals):
+
+        if self.channel_first:
+            residuals = rearrange(residuals, '(b s) d ... -> b ... s d', s = self.num_streams)
+        else:
+            residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = self.num_streams)
+
+        attn_logits = self.to_attn_logits(residuals)
+        attn = attn_logits.softmax(dim = -2)
+
+        residuals = reduce(residuals * attn, 'b ... s d -> b ... d', 'sum')
+
+        if self.channel_first:
+            residuals = rearrange(residuals, 'b ... d -> b d ...')
+
+        return residuals
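
The substantive addition in this release is `AttentionPoolReduceStream`, an alternative to plain summing or averaging for collapsing the residual streams back to one: each position softmaxes attention logits over the stream axis (`dim = -2` after the rearrange) and takes the attention-weighted sum of its stream copies. Since `to_attn_logits` is initialized to the identity, the logits start out equal to the residuals themselves, so larger-magnitude streams initially receive more weight. A minimal usage sketch: the class name, constructor arguments, and the `(b s) ... d` input layout come from the diff above, while the concrete shapes are invented for illustration.

```python
import torch
from hyper_connections import AttentionPoolReduceStream

num_streams, dim = 4, 64

# residual streams arrive folded into the batch dimension, matching the
# '(b s) ... d' rearrange pattern in forward: (batch * num_streams, seq, dim)
residuals = torch.randn(2 * num_streams, 128, dim)

pool = AttentionPoolReduceStream(num_streams, dim)

reduced = pool(residuals)
print(reduced.shape)  # torch.Size([2, 128, 64]): streams merged by attention pooling
```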
{hyper-connections-0.0.17 → hyper-connections-0.0.19}/hyper_connections/hyper_connections_with_multi_input_streams.py

@@ -42,6 +42,7 @@ class HyperConnections(Module):
         layer_index = None,
         tanh = True,
         channel_first = False,
+        dropout = 0.,
         num_branch_inputs = 1 # residuals will be linearly combined to multiple inputs, fed through the branch, then linearly combined back out to residuals
     ):
         """
@@ -89,6 +90,10 @@ class HyperConnections(Module):
         self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim, num_branch_inputs))
         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
 
+        # dropout
+
+        self.dropout = nn.Dropout(dropout)
+
         # channel first option
 
         self.channel_first = channel_first
@@ -164,7 +169,7 @@ class HyperConnections(Module):
         if self.channel_first:
             output = rearrange(output, 'b ... d -> b d ...')
 
-        return output
+        return self.dropout(output)
 
     def decorate_branch(self, branch: Callable | tuple[Callable, ...] | list[Callable]):
         assert not exists(self.branches), 'branch was already wrapped on init'
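
The other change is shared by both `HyperConnections` variants: a new `dropout` keyword, with the branch output routed through `nn.Dropout` just before it is returned, i.e. before it is merged back into the residual streams. A hedged sketch of the new argument in use: only the `branch`, `tanh`, `channel_first`, and `dropout` keywords are confirmed by this diff; the positional stream count, the `dim` keyword, and the folded-batch calling convention are assumptions about the rest of the library.

```python
import torch
from torch import nn
from hyper_connections import HyperConnections

hyper_conn = HyperConnections(
    4,                           # assumed: number of residual streams
    dim = 64,                    # assumed keyword from the wider library
    branch = nn.Linear(64, 64),  # branch kwarg appears in the diff above
    dropout = 0.1                # new in 0.0.19: nn.Dropout on the branch output
)

# assumed calling convention: streams folded into the batch, (b * s) n d
residuals = torch.randn(2 * 4, 128, 64)

out = hyper_conn(residuals)  # same shape; branch output is dropped out before merging
```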
{hyper-connections-0.0.17 → hyper-connections-0.0.19}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "hyper-connections"
-version = "0.0.17"
+version = "0.0.19"
 description = "Hyper-Connections"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }