alphagenome-pytorch 0.0.5__tar.gz → 0.0.7__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: alphagenome-pytorch
3
- Version: 0.0.5
3
+ Version: 0.0.7
4
4
  Summary: AlphaGenome
5
5
  Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/alphagenome
@@ -46,7 +46,8 @@ def relative_shift(t):
46
46
  *leading_dims, seq_len, dim = t.shape
47
47
  t = F.pad(t, (1, 0), value = 0.)
48
48
  t = t.reshape(*leading_dims, dim + 1, seq_len)
49
- return t[..., 1:, :].reshape(*leading_dims, seq_len, dim)
49
+ t = t[..., 1:, :].reshape(*leading_dims, seq_len, dim)
50
+ return t[..., :, :seq_len]
50
51
 
51
52
  # rotary, but with attenuation of short relative distance frequencies
52
53
 
@@ -267,6 +268,8 @@ class SingleToPairwise(Module):
267
268
 
268
269
  single = self.avg_pool(single)
269
270
 
271
+ pool_seq_len = single.shape[1]
272
+
270
273
  q, k = self.to_qk(single).chunk(2, dim = -1)
271
274
  q, k = tuple(self.split_heads(t) for t in (q, k))
272
275
 
@@ -278,15 +281,10 @@ class SingleToPairwise(Module):
278
281
 
279
282
  q_rel_bias, k_rel_bias = self.qk_rel_pos_bias
280
283
 
281
- rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
282
- rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
283
-
284
- _, seq, rel_pos_seq, _ = rel_q.shape
285
- crop_padding = (rel_pos_seq - seq) // 2
286
-
287
- rel_q, rel_k = tuple(t[..., crop_padding:(crop_padding + seq), :] for t in (rel_q, rel_k))
284
+ rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
285
+ rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
288
286
 
289
- rel_sim = add('b i j d, b j i d -> b i j d', rel_q, rel_k) * 0.5
287
+ rel_sim = add('b h i j, b h j i -> b h i j', rel_q, rel_k) * 0.5
290
288
 
291
289
  sim = sim + rel_sim
292
290
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "alphagenome-pytorch"
3
- version = "0.0.5"
3
+ version = "0.0.7"
4
4
  description = "AlphaGenome"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }