alphagenome-pytorch 0.0.5__tar.gz → 0.0.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: alphagenome-pytorch
3
- Version: 0.0.5
3
+ Version: 0.0.6
4
4
  Summary: AlphaGenome
5
5
  Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/alphagenome
@@ -267,6 +267,8 @@ class SingleToPairwise(Module):
267
267
 
268
268
  single = self.avg_pool(single)
269
269
 
270
+ pool_seq_len = single.shape[1]
271
+
270
272
  q, k = self.to_qk(single).chunk(2, dim = -1)
271
273
  q, k = tuple(self.split_heads(t) for t in (q, k))
272
274
 
@@ -281,10 +283,7 @@ class SingleToPairwise(Module):
281
283
  rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
282
284
  rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
283
285
 
284
- _, seq, rel_pos_seq, _ = rel_q.shape
285
- crop_padding = (rel_pos_seq - seq) // 2
286
-
287
- rel_q, rel_k = tuple(t[..., crop_padding:(crop_padding + seq), :] for t in (rel_q, rel_k))
286
+ rel_q, rel_k = tuple(t[..., :pool_seq_len, :] for t in (rel_q, rel_k))
288
287
 
289
288
  rel_sim = add('b i j d, b j i d -> b i j d', rel_q, rel_k) * 0.5
290
289
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "alphagenome-pytorch"
3
- version = "0.0.5"
3
+ version = "0.0.6"
4
4
  description = "AlphaGenome"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }