hyper-connections 0.0.18__tar.gz → 0.0.19__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.18
3
+ Version: 0.0.19
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -1,5 +1,6 @@
1
1
  from hyper_connections.hyper_connections import (
2
2
  HyperConnections,
3
3
  Residual,
4
- StreamEmbed
4
+ StreamEmbed,
5
+ AttentionPoolReduceStream
5
6
  )
@@ -258,3 +258,36 @@ class StreamEmbed(Module):
258
258
  residuals = rearrange(residuals, 'b ... s d -> (b s) ... d', s = self.num_streams)
259
259
 
260
260
  return residuals
261
+
262
+ # attention pooling across residual streams - adapted from Enformer https://www.nature.com/articles/s41592-021-01252-x , which in turn adapted it from earlier work
263
+
264
class AttentionPoolReduceStream(Module):
    """Reduce residual streams packed into the batch axis to a single stream via attention pooling.

    The input stores the streams folded into the leading axis as ``(b s)``, with
    ``s = num_streams`` as the inner factor. A learned linear map produces
    per-feature attention logits. These are softmaxed across the stream axis and
    used as weights in a sum over streams, so the output has leading size ``b``
    instead of ``b * num_streams``.

    Args:
        num_streams: number of residual streams folded into the leading axis.
        dim: size of the feature axis.
        channel_first: if True, the feature axis is axis 1 of the input
            (``(b s) d ...``) instead of the last axis (``(b s) ... d``).
            The output uses the same convention.
    """

    def __init__(
        self,
        num_streams,
        dim,
        channel_first = False
    ):
        super().__init__()
        self.num_streams = num_streams
        self.channel_first = channel_first

        self.to_attn_logits = nn.Linear(dim, dim, bias = False)

        # identity init: with no bias, the logits start out equal to the features themselves,
        # so initially the pooling weights are softmax(features) across the streams.
        # nn.init.eye_ fills the weight in-place under no_grad instead of mutating `.data`
        nn.init.eye_(self.to_attn_logits.weight)

    def forward(self, residuals):
        """Pool over the streams.

        Args:
            residuals: tensor shaped ``(b s) ... d``, or ``(b s) d ...`` when
                ``channel_first`` is set. The leading axis must be divisible by
                ``num_streams``.

        Returns:
            Tensor shaped ``b ... d``, or ``b d ...`` when ``channel_first`` is set.
        """
        # unpack the streams from the leading axis and move them next to the feature axis
        if self.channel_first:
            residuals = rearrange(residuals, '(b s) d ... -> b ... s d', s = self.num_streams)
        else:
            residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = self.num_streams)

        attn_logits = self.to_attn_logits(residuals)

        # normalize over the stream axis (-2), independently for each feature
        attn = attn_logits.softmax(dim = -2)

        # attention-weighted sum across streams
        residuals = reduce(residuals * attn, 'b ... s d -> b ... d', 'sum')

        if self.channel_first:
            residuals = rearrange(residuals, 'b ... d -> b d ...')

        return residuals
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.0.18"
3
+ version = "0.0.19"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }