alphagenome-pytorch 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -46,7 +46,8 @@ def relative_shift(t):
46
46
  *leading_dims, seq_len, dim = t.shape
47
47
  t = F.pad(t, (1, 0), value = 0.)
48
48
  t = t.reshape(*leading_dims, dim + 1, seq_len)
49
- return t[..., 1:, :].reshape(*leading_dims, seq_len, dim)
49
+ t = t[..., 1:, :].reshape(*leading_dims, seq_len, dim)
50
+ return t[..., :, :seq_len]
50
51
 
51
52
  # rotary, but with attenuation of short relative distance frequencies
52
53
 
@@ -280,12 +281,10 @@ class SingleToPairwise(Module):
280
281
 
281
282
  q_rel_bias, k_rel_bias = self.qk_rel_pos_bias
282
283
 
283
- rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
284
- rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
284
+ rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
285
+ rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
285
286
 
286
- rel_q, rel_k = tuple(t[..., :pool_seq_len, :] for t in (rel_q, rel_k))
287
-
288
- rel_sim = add('b i j d, b j i d -> b i j d', rel_q, rel_k) * 0.5
287
+ rel_sim = add('b h i j, b h j i -> b h i j', rel_q, rel_k) * 0.5
289
288
 
290
289
  sim = sim + rel_sim
291
290
 
@@ -435,14 +434,14 @@ class TransformerTower(Module):
435
434
  maybe_pairwise_ff
436
435
  ) in self.layers:
437
436
 
438
- single = attn(single, rotary_emb = rotary_emb, pairwise = None) + single
439
- single = ff(single) + single
440
-
441
437
  if exists(maybe_single_to_pair):
442
438
  pairwise = maybe_single_to_pair(single, rel_pos_feats) + default(pairwise, 0.)
443
439
  pairwise = maybe_pairwise_attn(pairwise) + pairwise
444
440
  pairwise = maybe_pairwise_ff(pairwise) + pairwise
445
441
 
442
+ single = attn(single, rotary_emb = rotary_emb, pairwise = None) + single
443
+ single = ff(single) + single
444
+
446
445
  return single, pairwise
447
446
 
448
447
  # embedding
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: alphagenome-pytorch
3
- Version: 0.0.6
3
+ Version: 0.0.8
4
4
  Summary: AlphaGenome
5
5
  Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/alphagenome
@@ -0,0 +1,6 @@
1
+ alphagenome_pytorch/__init__.py,sha256=7plC_YRm0UapNCl9hJEhFxXE-ELGKVy-DtuO5GUQxGI,101
2
+ alphagenome_pytorch/alphagenome.py,sha256=UjlOJGTMshcaNqmY0r6IRCgRunu7BTZyhbk2vNi5Mis,13948
3
+ alphagenome_pytorch-0.0.8.dist-info/METADATA,sha256=mhlnCRy7Ovq_Gt9ul_iIqws4zdoOcLjfQTBrlbcNic8,3386
4
+ alphagenome_pytorch-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ alphagenome_pytorch-0.0.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ alphagenome_pytorch-0.0.8.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- alphagenome_pytorch/__init__.py,sha256=7plC_YRm0UapNCl9hJEhFxXE-ELGKVy-DtuO5GUQxGI,101
2
- alphagenome_pytorch/alphagenome.py,sha256=f0XX7VnbOABwT4WOadQQ64VgYS0-jsqBconX8qaa1eM,14004
3
- alphagenome_pytorch-0.0.6.dist-info/METADATA,sha256=kI5ZBXu_SYAQmvXrA6L-QBGS7Bd3Xv2SdPoJ9rnCyjw,3386
4
- alphagenome_pytorch-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- alphagenome_pytorch-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- alphagenome_pytorch-0.0.6.dist-info/RECORD,,