alphagenome-pytorch 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,6 +15,7 @@ from einops import rearrange, repeat, einsum
15
15
  # b - batch
16
16
  # h - heads
17
17
  # n - sequence
18
+ # p - relative positions
18
19
  # d - feature dimension
19
20
 
20
21
  # constants
@@ -79,9 +80,15 @@ def apply_rotary_pos_emb(pos, t):
79
80
  # 'central mask features' - relative positions for constituting pairwise rep
80
81
 
81
82
  class RelativePosFeatures(Module):
83
+ def __init__(self, pool_size = 16):
84
+ super().__init__()
85
+ self.pool_size = pool_size
86
+
82
87
  def forward(self, single):
83
88
 
84
89
  _, seq_len, dim = single.shape
90
+
91
+ seq_len //= self.pool_size
85
92
  half_dim = dim // 2
86
93
 
87
94
  rel_pos = arange(2 * seq_len - 1) - (seq_len - 1)
@@ -260,6 +267,8 @@ class SingleToPairwise(Module):
260
267
 
261
268
  single = self.avg_pool(single)
262
269
 
270
+ pool_seq_len = single.shape[1]
271
+
263
272
  q, k = self.to_qk(single).chunk(2, dim = -1)
264
273
  q, k = tuple(self.split_heads(t) for t in (q, k))
265
274
 
@@ -271,12 +280,14 @@ class SingleToPairwise(Module):
271
280
 
272
281
  q_rel_bias, k_rel_bias = self.qk_rel_pos_bias
273
282
 
274
- rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b i h d, j h d -> b i j h'))
275
- rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b i h d, j h d -> b i j h'))
283
+ rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
284
+ rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
285
+
286
+ rel_q, rel_k = tuple(t[..., :pool_seq_len, :] for t in (rel_q, rel_k))
276
287
 
277
- rel_sim = einsum(rel_q, rel_k, 'b i p h, b j p h -> b i j h')
288
+ rel_sim = add('b i j d, b j i d -> b i j d', rel_q, rel_k) * 0.5
278
289
 
279
- sim = (sim + rel_sim) * 0.5
290
+ sim = sim + rel_sim
280
291
 
281
292
  pairwise_from_sim = self.qk_to_pairwise(sim)
282
293
 
@@ -353,6 +364,7 @@ class TransformerTower(Module):
353
364
  dim_pairwise = None,
354
365
  pairwise_every_num_single_blocks = 2, # how often to do a pairwise block
355
366
  single_to_pairwise_heads = 32, # they did 32
367
+ pool_size = 16,
356
368
  attn_kwargs: dict = dict(),
357
369
  ff_kwargs: dict = dict()
358
370
  ):
@@ -363,7 +375,7 @@ class TransformerTower(Module):
363
375
 
364
376
  self.pairwise_every = pairwise_every_num_single_blocks
365
377
 
366
- self.rel_pos_features = RelativePosFeatures()
378
+ self.rel_pos_features = RelativePosFeatures(pool_size)
367
379
 
368
380
  self.rotary_emb = RotaryEmbedding(dim_head_qk, max_positions = max_positions)
369
381
 
@@ -381,7 +393,7 @@ class TransformerTower(Module):
381
393
  single_to_pairwise, pairwise_attn, pairwise_ff = None, None, None
382
394
 
383
395
  if divisible_by(layer_index, self.pairwise_every):
384
- single_to_pairwise = SingleToPairwise(dim = dim, dim_pairwise = dim_pairwise, heads = single_to_pairwise_heads)
396
+ single_to_pairwise = SingleToPairwise(dim = dim, dim_pairwise = dim_pairwise, heads = single_to_pairwise_heads, pool_size = pool_size)
385
397
  pairwise_attn = PairwiseRowAttention(dim_pairwise)
386
398
  pairwise_ff = FeedForward(dim = dim_pairwise, expansion_factor = ff_expansion_factor)
387
399
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: alphagenome-pytorch
3
- Version: 0.0.4
3
+ Version: 0.0.6
4
4
  Summary: AlphaGenome
5
5
  Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/alphagenome
@@ -0,0 +1,6 @@
1
+ alphagenome_pytorch/__init__.py,sha256=7plC_YRm0UapNCl9hJEhFxXE-ELGKVy-DtuO5GUQxGI,101
2
+ alphagenome_pytorch/alphagenome.py,sha256=f0XX7VnbOABwT4WOadQQ64VgYS0-jsqBconX8qaa1eM,14004
3
+ alphagenome_pytorch-0.0.6.dist-info/METADATA,sha256=kI5ZBXu_SYAQmvXrA6L-QBGS7Bd3Xv2SdPoJ9rnCyjw,3386
4
+ alphagenome_pytorch-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ alphagenome_pytorch-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ alphagenome_pytorch-0.0.6.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- alphagenome_pytorch/__init__.py,sha256=7plC_YRm0UapNCl9hJEhFxXE-ELGKVy-DtuO5GUQxGI,101
2
- alphagenome_pytorch/alphagenome.py,sha256=5CqfuI2TaI8vxlLGj7Zy9W6ipHDKy4T7YzF2uuxito4,13665
3
- alphagenome_pytorch-0.0.4.dist-info/METADATA,sha256=jbBeJbCYA-fUKJBprrWfNYWzP8C3RTarkWoimODyHyw,3386
4
- alphagenome_pytorch-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- alphagenome_pytorch-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- alphagenome_pytorch-0.0.4.dist-info/RECORD,,