alphagenome-pytorch 0.0.4-py3-none-any.whl → 0.0.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
alphagenome_pytorch/alphagenome.py

@@ -15,6 +15,7 @@ from einops import rearrange, repeat, einsum
 # b - batch
 # h - heads
 # n - sequence
+# p - relative positions
 # d - feature dimension
 
 # constants
@@ -79,9 +80,15 @@ def apply_rotary_pos_emb(pos, t):
 # 'central mask features' - relative positions for constituting pairwise rep
 
 class RelativePosFeatures(Module):
+    def __init__(self, pool_size = 16):
+        super().__init__()
+        self.pool_size = pool_size
+
     def forward(self, single):
 
         _, seq_len, dim = single.shape
+
+        seq_len //= self.pool_size
         half_dim = dim // 2
 
         rel_pos = arange(2 * seq_len - 1) - (seq_len - 1)
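
With this change the relative positions are derived from the pooled sequence length rather than the raw input length. A minimal sketch of the arithmetic, where seq_len = 8192 and pool_size = 16 are illustrative assumptions, not values taken from the diff:

    from torch import arange

    seq_len, pool_size = 8192, 16            # assumed values for illustration
    pooled_len = seq_len // pool_size        # 512, mirroring the new `seq_len //= self.pool_size`
    rel_pos = arange(2 * pooled_len - 1) - (pooled_len - 1)
    print(rel_pos.shape, rel_pos[0].item(), rel_pos[-1].item())   # 1023 positions spanning -511 .. 511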
@@ -271,12 +278,17 @@ class SingleToPairwise(Module):
 
         q_rel_bias, k_rel_bias = self.qk_rel_pos_bias
 
-        rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b i h d, j h d -> b i j h'))
-        rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b i h d, j h d -> b i j h'))
+        rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
+        rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
 
-        rel_sim = einsum(rel_q, rel_k, 'b i p h, b j p h -> b i j h')
+        _, seq, rel_pos_seq, _ = rel_q.shape
+        crop_padding = (rel_pos_seq - seq) // 2
 
-        sim = (sim + rel_sim) * 0.5
+        rel_q, rel_k = tuple(t[..., crop_padding:(crop_padding + seq), :] for t in (rel_q, rel_k))
+
+        rel_sim = add('b i j d, b j i d -> b i j d', rel_q, rel_k) * 0.5
+
+        sim = sim + rel_sim
 
         pairwise_from_sim = self.qk_to_pairwise(sim)
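
The reworked block first crops the relative-position axis of rel_q and rel_k to match the sequence axis, then combines them across the two sequence dimensions. A rough shape-level sketch follows; the tensor sizes and the plain-torch reading of the add(...) pattern are assumptions for illustration (add appears to be an einx-style elementwise op that takes the pattern first):

    import torch

    b, seq, heads = 1, 512, 32
    rel_pos_seq = 2 * seq - 1                      # assumed length of the relative-position axis
    rel_q = torch.randn(b, seq, rel_pos_seq, heads)
    rel_k = torch.randn(b, seq, rel_pos_seq, heads)

    crop_padding = (rel_pos_seq - seq) // 2        # symmetric crop down to seq positions
    rel_q, rel_k = (t[..., crop_padding:(crop_padding + seq), :] for t in (rel_q, rel_k))

    # plain-torch reading of add('b i j d, b j i d -> b i j d', rel_q, rel_k) * 0.5
    rel_sim = (rel_q + rel_k.transpose(1, 2)) * 0.5
    assert rel_sim.shape == (b, seq, seq, heads)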
@@ -353,6 +365,7 @@ class TransformerTower(Module):
         dim_pairwise = None,
         pairwise_every_num_single_blocks = 2, # how often to do a pairwise block
         single_to_pairwise_heads = 32, # they did 32
+        pool_size = 16,
         attn_kwargs: dict = dict(),
         ff_kwargs: dict = dict()
     ):
@@ -363,7 +376,7 @@ class TransformerTower(Module):
 
         self.pairwise_every = pairwise_every_num_single_blocks
 
-        self.rel_pos_features = RelativePosFeatures()
+        self.rel_pos_features = RelativePosFeatures(pool_size)
 
         self.rotary_emb = RotaryEmbedding(dim_head_qk, max_positions = max_positions)
 
@@ -381,7 +394,7 @@ class TransformerTower(Module):
             single_to_pairwise, pairwise_attn, pairwise_ff = None, None, None
 
             if divisible_by(layer_index, self.pairwise_every):
-                single_to_pairwise = SingleToPairwise(dim = dim, dim_pairwise = dim_pairwise, heads = single_to_pairwise_heads)
+                single_to_pairwise = SingleToPairwise(dim = dim, dim_pairwise = dim_pairwise, heads = single_to_pairwise_heads, pool_size = pool_size)
                 pairwise_attn = PairwiseRowAttention(dim_pairwise)
                 pairwise_ff = FeedForward(dim = dim_pairwise, expansion_factor = ff_expansion_factor)
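
For orientation, a hypothetical construction showing where the new keyword lands. Only dim_pairwise, pairwise_every_num_single_blocks, single_to_pairwise_heads and pool_size are visible in the hunks above; dim and all concrete values are assumptions, and the real constructor may take further arguments:

    from alphagenome_pytorch.alphagenome import TransformerTower

    # hypothetical sketch; any argument or value not shown in the hunks above is assumed
    tower = TransformerTower(
        dim = 768,                              # `dim` is referenced inside __init__ in the hunks above
        dim_pairwise = 128,
        pairwise_every_num_single_blocks = 2,
        single_to_pairwise_heads = 32,
        pool_size = 16                          # new in 0.0.5: pooling factor for the relative positions
    )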
alphagenome_pytorch-0.0.4.dist-info/METADATA → alphagenome_pytorch-0.0.5.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alphagenome-pytorch
-Version: 0.0.4
+Version: 0.0.5
 Summary: AlphaGenome
 Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/alphagenome
alphagenome_pytorch-0.0.5.dist-info/RECORD (added)

@@ -0,0 +1,6 @@
+alphagenome_pytorch/__init__.py,sha256=7plC_YRm0UapNCl9hJEhFxXE-ELGKVy-DtuO5GUQxGI,101
+alphagenome_pytorch/alphagenome.py,sha256=5uxmx-x-6KSu-2aUF4AxM-rYSxXuwnr-0meTNhu2vUM,14086
+alphagenome_pytorch-0.0.5.dist-info/METADATA,sha256=0W6pkI9-RvNCOixhC1ius8a4-cwpTx7gsA370FLmnsY,3386
+alphagenome_pytorch-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+alphagenome_pytorch-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+alphagenome_pytorch-0.0.5.dist-info/RECORD,,
alphagenome_pytorch-0.0.4.dist-info/RECORD (removed)

@@ -1,6 +0,0 @@
-alphagenome_pytorch/__init__.py,sha256=7plC_YRm0UapNCl9hJEhFxXE-ELGKVy-DtuO5GUQxGI,101
-alphagenome_pytorch/alphagenome.py,sha256=5CqfuI2TaI8vxlLGj7Zy9W6ipHDKy4T7YzF2uuxito4,13665
-alphagenome_pytorch-0.0.4.dist-info/METADATA,sha256=jbBeJbCYA-fUKJBprrWfNYWzP8C3RTarkWoimODyHyw,3386
-alphagenome_pytorch-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-alphagenome_pytorch-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-alphagenome_pytorch-0.0.4.dist-info/RECORD,,