hyper-connections 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
1
1
  from hyper_connections.hyper_connections import (
2
2
  HyperConnections,
3
3
  Residual,
4
- StreamEmbed
4
+ StreamEmbed,
5
+ AttentionPoolReduceStream
5
6
  )
@@ -12,6 +12,14 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
12
12
 
13
13
  from einops import rearrange, repeat, reduce, einsum
14
14
 
15
+ """
16
+ ein notation:
17
+ b - batch
18
+ d - feature dimension
19
+ s - residual streams
20
+ t - residual streams + num branch inputs
21
+ """
22
+
15
23
  # helper functions
16
24
 
17
25
  def exists(v):
@@ -102,6 +110,8 @@ class HyperConnections(Module):
102
110
 
103
111
  self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
104
112
 
113
+ assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
114
+
105
115
  self.num_residual_streams = num_residual_streams
106
116
  init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
107
117
 
@@ -258,3 +268,36 @@ class StreamEmbed(Module):
258
268
  residuals = rearrange(residuals, 'b ... s d -> (b s) ... d', s = self.num_streams)
259
269
 
260
270
  return residuals
271
+
272
+ # attention pool - taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
273
+
274
class AttentionPoolReduceStream(Module):
    """Reduce `num_streams` residual streams to one via per-feature attention pooling.

    The incoming tensor carries batch and streams merged along its first axis
    as `(b s)` (batch-major, stream-minor), matching the `rearrange` patterns
    in `forward`.

    Args:
        num_streams: number of residual streams `s` folded into the first axis.
        dim: feature dimension `d`.
        channel_first: if True the feature axis is axis 1 (`(b s) d ...`),
            otherwise it is the last axis (`(b s) ... d`). The output uses the
            same convention.
    """

    def __init__(
        self,
        num_streams,
        dim,
        channel_first = False
    ):
        super().__init__()
        self.num_streams = num_streams
        self.channel_first = channel_first

        # maps features to one attention logit per (stream, feature) position
        self.to_attn_logits = nn.Linear(dim, dim, bias = False)

        # identity init: initially the logits equal the input features.
        # nn.init.eye_ writes under no_grad, avoiding the discouraged `.data` mutation
        nn.init.eye_(self.to_attn_logits.weight)

    def forward(self, residuals):
        """Attention-pool over the stream axis.

        Args:
            residuals: `(b * s, d, ...)` if `channel_first` else `(b * s, ..., d)`.

        Returns:
            `(b, d, ...)` if `channel_first` else `(b, ..., d)`.
        """
        # split the merged `(b s)` leading axis and place streams right before features
        if self.channel_first:
            residuals = rearrange(residuals, '(b s) d ... -> b ... s d', s = self.num_streams)
        else:
            residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = self.num_streams)

        attn_logits = self.to_attn_logits(residuals)

        # normalize across streams (axis -2 is `s`), independently for every feature
        attn = attn_logits.softmax(dim = -2)

        # attention-weighted sum over the streams
        residuals = reduce(residuals * attn, 'b ... s d -> b ... d', 'sum')

        # restore the caller's channel-first layout
        if self.channel_first:
            residuals = rearrange(residuals, 'b ... d -> b d ...')

        return residuals
@@ -12,6 +12,16 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
12
12
 
13
13
  from einops import rearrange, repeat, reduce, einsum
14
14
 
15
+ """
16
+ ein notation:
17
+ b - batch
18
+ d - feature dimension
19
+ s - residual streams
20
+ i - branch inputs
21
+ br - branch functions
22
+ t - residual streams + num branch inputs
23
+ """
24
+
15
25
  from hyper_connections.hyper_connections import Residual, StreamEmbed
16
26
 
17
27
  # helper functions
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.18
3
+ Version: 0.0.20
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -0,0 +1,7 @@
1
+ hyper_connections/__init__.py,sha256=wJxbrEXRGmOIjPw8fWP-cUq6CE8bvx95mIlhWifNvYc,135
2
+ hyper_connections/hyper_connections.py,sha256=ElPtieRLvVKaVg2Attx1k6esKq1SY2X4AVZbZmsQAOM,9486
3
+ hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=HbLpt79xcMv_os6brMvDd90t2GOPceliE1YFusR2eJI,7553
4
+ hyper_connections-0.0.20.dist-info/METADATA,sha256=erA-d7KNNdzPY76x8IWKd2trv2WuBO9-C2DtH-SoQ_Y,5076
5
+ hyper_connections-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
+ hyper_connections-0.0.20.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
7
+ hyper_connections-0.0.20.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
2
- hyper_connections/hyper_connections.py,sha256=hZ-O79zKOAJxAzobAQVHGamQUEnZbs5tD2vqonATLUY,8199
3
- hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=sTQ4sh1JLhRz06iP4PvpkLb_BMvIrcWRC-5_JxkbznQ,7396
4
- hyper_connections-0.0.18.dist-info/METADATA,sha256=GzMgZhbW5wLLHTsTbAU2lk1-8cGL9UNBt8iejNr08xA,5076
5
- hyper_connections-0.0.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- hyper_connections-0.0.18.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
7
- hyper_connections-0.0.18.dist-info/RECORD,,