alphagenome-pytorch 0.0.2.tar.gz → 0.0.4.tar.gz

This diff compares the contents of two package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alphagenome-pytorch
-Version: 0.0.2
+Version: 0.0.4
 Summary: AlphaGenome
 Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/alphagenome
@@ -2,11 +2,11 @@ from __future__ import annotations
 from functools import partial
 
 import torch
-from torch import nn, cat, stack, arange
+from torch import nn, cat, stack, arange, logspace
 import torch.nn.functional as F
 from torch.nn import Linear, Sequential, Module, ModuleList
 
-import einx
+from einx import add, multiply, greater
 from einops.layers.torch import Rearrange, Reduce
 from einops import rearrange, repeat, einsum
 
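The second change above swaps the bare `import einx` for direct imports of the three einx operations the module actually uses. For readers unfamiliar with einx, here is a minimal sketch (not part of the package) of how these named-axis operations broadcast:

```python
import torch
from einx import add, multiply, greater

q = torch.randn(2, 4, 8)
k = torch.randn(2, 4, 8)

# pairwise "outer sum": out[b, i, j, d] = q[b, i, d] + k[b, j, d]
outer = add('b i d, b j d -> b i j d', q, k)      # shape (2, 4, 4, 8)

# when the output pattern is omitted, einx infers it from the inputs
signed = multiply('i, i j', torch.tensor([1., -1.]), torch.ones(2, 3))  # (2, 3), second row negated

# comparison ops broadcast the same way: mask[i, j] = (j > i)
mask = greater('j, i -> i j', torch.arange(3), torch.arange(5))   # shape (5, 3), boolean
```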
@@ -41,6 +41,12 @@ def default(v, d):
 def softclamp(t, value = 5.):
     return (t / value).tanh() * value
 
+def relative_shift(t):
+    *leading_dims, seq_len, dim = t.shape
+    t = F.pad(t, (1, 0), value = 0.)
+    t = t.reshape(*leading_dims, dim + 1, seq_len)
+    return t[..., 1:, :].reshape(*leading_dims, seq_len, dim)
+
 # rotary, but with attenuation of short relative distance frequencies
 
 class RotaryEmbedding(Module):
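The new `relative_shift` is the familiar pad-reshape-slice trick for relative position scores. A quick numeric check (illustration only, using the function exactly as defined above) shows what the shift achieves on a `(seq_len, 2 * seq_len - 1)` score matrix:

```python
import torch
import torch.nn.functional as F

def relative_shift(t):
    *leading_dims, seq_len, dim = t.shape
    t = F.pad(t, (1, 0), value = 0.)
    t = t.reshape(*leading_dims, dim + 1, seq_len)
    return t[..., 1:, :].reshape(*leading_dims, seq_len, dim)

L = 4
# row i holds scores against relative offsets -(L - 1) .. (L - 1)
t = torch.arange(L * (2 * L - 1)).float().reshape(L, 2 * L - 1)

shifted = relative_shift(t)

# after the shift, column j of row i holds the score for offset (j - i),
# i.e. the score for attending from query position i to key position j
for i in range(L):
    for j in range(L):
        assert shifted[i, j] == t[i, (j - i) + (L - 1)]
```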
@@ -51,7 +57,7 @@ class RotaryEmbedding(Module):
     ):
         super().__init__()
         num_freqs = dim // 2
-        inv_freq = 1. / (arange(num_freqs).float() + torch.logspace(1, max_positions - num_freqs + 1, num_freqs))
+        inv_freq = 1. / (arange(num_freqs).float() + logspace(1, max_positions - num_freqs + 1, num_freqs))
         self.register_buffer('inv_freq', inv_freq)
 
     def forward(
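The only change in this hunk is using the `logspace` now imported from `torch` directly. As a reminder of its semantics, since it drives the frequency schedule (a standalone illustration, not package code):

```python
import torch

# logspace(start, end, steps) is base ** linspace(start, end, steps), base 10 by default
print(torch.logspace(1, 3, 3))   # tensor([  10.,  100., 1000.])

# so each inverse frequency above has the form 1. / (i + 10 ** e_i), with
# i = 0 .. num_freqs - 1 and exponents e_i spaced linearly from 1 up to
# max_positions - num_freqs + 1
```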
@@ -70,6 +76,26 @@ def rotate_half(x):
 def apply_rotary_pos_emb(pos, t):
     return t * pos.cos() + rotate_half(t) * pos.sin()
 
+# 'central mask features' - relative positions for constituting pairwise rep
+
+class RelativePosFeatures(Module):
+    def forward(self, single):
+
+        _, seq_len, dim = single.shape
+        half_dim = dim // 2
+
+        rel_pos = arange(2 * seq_len - 1) - (seq_len - 1)
+
+        center_widths = (
+            arange(half_dim) +
+            logspace(1, seq_len - half_dim + 1, half_dim + 1)[:-1] # endpoint = False
+        )
+
+        abs_rel_pos, rel_pos_sign = rel_pos.abs(), rel_pos.sign()
+        embeds = greater('j, i -> i j', center_widths, abs_rel_pos).float()
+
+        return cat((embeds, multiply('i, i j', rel_pos_sign, embeds)), dim = -1)
+
 # prenorm and sandwich norm - they use sandwich norm for single rep, prenorm for pairwise rep
 
 class NormWrapper(Module):
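The new `RelativePosFeatures` module builds the pairwise "central mask" features: for each relative offset from `-(seq_len - 1)` to `seq_len - 1`, a bank of binary "is the offset within this center width" indicators, concatenated with a sign-modulated copy so the direction of the offset survives. The output width is `2 * half_dim = dim`, matching the single representation, which is what lets it feed `to_rel_pos_encoding` later in this diff. A small standalone run (illustration only):

```python
import torch
from torch import arange, logspace, cat
from einx import greater, multiply

seq_len, half_dim = 4, 3

rel_pos = arange(2 * seq_len - 1) - (seq_len - 1)    # tensor([-3, -2, -1, 0, 1, 2, 3])

# monotonically growing width thresholds, one per feature channel
center_widths = (
    arange(half_dim) +
    logspace(1, seq_len - half_dim + 1, half_dim + 1)[:-1]
)

abs_rel_pos, rel_pos_sign = rel_pos.abs(), rel_pos.sign()

# embeds[i, j] = 1. if |rel_pos[i]| < center_widths[j]
embeds = greater('j, i -> i j', center_widths, abs_rel_pos).float()   # (7, 3)

feats = cat((embeds, multiply('i, i j', rel_pos_sign, embeds)), dim = -1)
print(feats.shape)   # torch.Size([7, 6]) -> (2 * seq_len - 1, 2 * half_dim)
```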
@@ -90,10 +116,11 @@ class NormWrapper(Module):
     def forward(
         self,
         x,
+        *args,
         **kwargs
     ):
         x = self.pre_rmsnorm(x)
-        out = self.block(x, **kwargs)
+        out = self.block(x, *args, **kwargs)
         out = self.post_block_dropout(out)
         return self.post_rmsnorm(out)
 
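The `*args` pass-through looks minor, but it is what lets an extra positional argument travel through the norm wrapper to the wrapped block; presumably it exists because `SingleToPairwise` is wrapped in `NormWrapper` inside the tower and, further down in this diff, is called as `maybe_single_to_pair(single, rel_pos_feats)`. The pattern in miniature (a sketch with a stand-in `LayerNorm`, not the package's exact classes):

```python
import torch
from torch import nn

class Wrapper(nn.Module):
    def __init__(self, dim, block):
        super().__init__()
        self.norm = nn.LayerNorm(dim)   # stands in for the pre-RMSNorm
        self.block = block

    def forward(self, x, *args, **kwargs):
        # extra positional arguments reach the wrapped block untouched
        return self.block(self.norm(x), *args, **kwargs)

w = Wrapper(8, lambda x, bias: x + bias)
out = w(torch.randn(2, 8), torch.ones(8))
print(out.shape)   # torch.Size([2, 8])
```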
@@ -210,17 +237,26 @@ class SingleToPairwise(Module):
 
         dim_inner = heads * dim_pairwise
 
-        self.split_heads = Rearrange('b n (h d) -> b n h d', h = heads)
+        self.split_heads = Rearrange('... (h d) -> ... h d', h = heads)
 
         self.to_outer_sum = Sequential(
+            nn.GELU(),
             LinearNoBias(dim, dim_pairwise * 2),
-            nn.GELU()
         )
 
         self.to_qk = LinearNoBias(dim, dim_inner * 2)
         self.qk_to_pairwise = Linear(heads, dim_pairwise)
 
-    def forward(self, single):
+        # relative position related
+
+        self.to_rel_pos_encoding = Linear(dim, heads * dim_pairwise)
+        self.qk_rel_pos_bias = nn.Parameter(torch.zeros(2, 1, 1, heads, dim_pairwise))
+
+    def forward(
+        self,
+        single,
+        rel_pos_feats = None
+    ):
 
         single = self.avg_pool(single)
 
@@ -229,11 +265,24 @@ class SingleToPairwise(Module):
 
         sim = einsum(q, k, 'b i h d, b j h d -> b i j h')
 
+        if exists(rel_pos_feats):
+            rel_pos_encoding = self.to_rel_pos_encoding(rel_pos_feats)
+            rel_pos_encoding = self.split_heads(rel_pos_encoding)
+
+            q_rel_bias, k_rel_bias = self.qk_rel_pos_bias
+
+            rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b i h d, j h d -> b i j h'))
+            rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b i h d, j h d -> b i j h'))
+
+            rel_sim = einsum(rel_q, rel_k, 'b i p h, b j p h -> b i j h')
+
+            sim = (sim + rel_sim) * 0.5
+
         pairwise_from_sim = self.qk_to_pairwise(sim)
 
         outer_q, outer_k = self.to_outer_sum(single).chunk(2, dim = -1)
 
-        outer_sum = einx.add('b i d, b j d -> b i j d', outer_q, outer_k)
+        outer_sum = add('b i d, b j d -> b i j d', outer_q, outer_k)
 
         return outer_sum
 
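This is the payoff of the two hunks above: Transformer-XL-flavored relative attention adapted to the pairwise module. The relative position features are projected to one encoding per head, learned per-head biases (`q_rel_bias` and `k_rel_bias`, the two slices of the `(2, 1, 1, heads, dim_pairwise)` parameter) are added to the queries and keys before scoring them against those encodings, the scores pass through `relative_shift`, and the two shifted maps are contracted over the offset axis into a similarity that is averaged with the content similarity. A runnable shape check (illustration only; random tensors stand in for the module's projections):

```python
import torch
import torch.nn.functional as F
from einops import einsum

def relative_shift(t):
    *leading_dims, seq_len, dim = t.shape
    t = F.pad(t, (1, 0), value = 0.)
    t = t.reshape(*leading_dims, dim + 1, seq_len)
    return t[..., 1:, :].reshape(*leading_dims, seq_len, dim)

b, n, h, d = 2, 8, 4, 16   # batch, pooled length, heads, dim_pairwise

q = torch.randn(b, n, h, d)
k = torch.randn(b, n, h, d)
rel_pos_encoding = torch.randn(2 * n - 1, h, d)   # one encoding per relative offset

rel_q = relative_shift(einsum(q, rel_pos_encoding, 'b i h d, j h d -> b i j h'))
rel_k = relative_shift(einsum(k, rel_pos_encoding, 'b i h d, j h d -> b i j h'))

# contract over the shifted offset axis p, keeping heads
rel_sim = einsum(rel_q, rel_k, 'b i p h, b j p h -> b i j h')
print(rel_sim.shape)   # torch.Size([2, 8, 8, 4]) -- same (b, i, j, h) layout as sim
```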
@@ -314,6 +363,8 @@ class TransformerTower(Module):
 
         self.pairwise_every = pairwise_every_num_single_blocks
 
+        self.rel_pos_features = RelativePosFeatures()
+
         self.rotary_emb = RotaryEmbedding(dim_head_qk, max_positions = max_positions)
 
         for layer_index in range(depth):
@@ -360,6 +411,8 @@ class TransformerTower(Module):
 
         pairwise = None
 
+        rel_pos_feats = self.rel_pos_features(single)
+
         rotary_emb = self.rotary_emb(seq_len)
 
         for (
@@ -374,7 +427,7 @@ class TransformerTower(Module):
             single = ff(single) + single
 
             if exists(maybe_single_to_pair):
-                pairwise = maybe_single_to_pair(single) + default(pairwise, 0.)
+                pairwise = maybe_single_to_pair(single, rel_pos_feats) + default(pairwise, 0.)
                 pairwise = maybe_pairwise_attn(pairwise) + pairwise
                 pairwise = maybe_pairwise_ff(pairwise) + pairwise
 
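At the tower level the wiring is straightforward: the relative position features are computed once from `single` at the start of `forward` and handed to every single-to-pairwise block. A hedged usage sketch (the constructor arguments, top-level export, and two-tensor return are assumed from context rather than shown in this diff):

```python
import torch
from alphagenome_pytorch import TransformerTower   # assumed import path

tower = TransformerTower(dim = 768, depth = 8)     # hypothetical hyperparameters

single = torch.randn(1, 512, 768)                  # (batch, seq, dim)
single_out, pairwise_out = tower(single)           # pairwise blocks now see rel_pos_feats
```

The final hunk below simply bumps the version in `pyproject.toml` to 0.0.4, mirroring the PKG-INFO change at the top.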
@@ -1,6 +1,6 @@
 [project]
 name = "alphagenome-pytorch"
-version = "0.0.2"
+version = "0.0.4"
 description = "AlphaGenome"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }