alphagenome-pytorch 0.0.3-py3-none-any.whl → 0.0.5-py3-none-any.whl

This diff shows the published contents of the alphagenome-pytorch 0.0.3 and 0.0.5 wheels as they appear in the public registry.
--- a/alphagenome_pytorch/alphagenome.py
+++ b/alphagenome_pytorch/alphagenome.py
@@ -2,11 +2,11 @@ from __future__ import annotations
 from functools import partial
 
 import torch
-from torch import nn, cat, stack, arange
+from torch import nn, cat, stack, arange, logspace
 import torch.nn.functional as F
 from torch.nn import Linear, Sequential, Module, ModuleList
 
-import einx
+from einx import add, multiply, greater
 from einops.layers.torch import Rearrange, Reduce
 from einops import rearrange, repeat, einsum
 
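The second import change swaps the bare `import einx` for direct imports of the einx ops this file uses. einx ops take an axis-pattern string and broadcast their arguments by named axes; a minimal sketch of the two patterns that appear below (tensors and sizes here are invented):

```python
# minimal sketch of the einx named-axis ops used in this diff; sizes are invented
import torch
from einx import add, greater

q = torch.randn(2, 4, 8)                      # (b, i, d)
k = torch.randn(2, 6, 8)                      # (b, j, d)

# pairwise 'outer sum': out[b, i, j, d] = q[b, i, d] + k[b, j, d]
outer = add('b i d, b j d -> b i j d', q, k)
assert outer.shape == (2, 4, 6, 8)

widths = torch.tensor([1., 2., 4.])           # (j,)
dists = torch.tensor([0., 3.])                # (i,)

# broadcasted comparison: mask[i, j] = widths[j] > dists[i]
mask = greater('j, i -> i j', widths, dists)
assert mask.shape == (2, 3)
```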
@@ -15,6 +15,7 @@ from einops import rearrange, repeat, einsum
 # b - batch
 # h - heads
 # n - sequence
+# p - relative positions
 # d - feature dimension
 
 # constants
@@ -41,6 +42,12 @@ def default(v, d):
 def softclamp(t, value = 5.):
     return (t / value).tanh() * value
 
+def relative_shift(t):
+    *leading_dims, seq_len, dim = t.shape
+    t = F.pad(t, (1, 0), value = 0.)
+    t = t.reshape(*leading_dims, dim + 1, seq_len)
+    return t[..., 1:, :].reshape(*leading_dims, seq_len, dim)
+
 # rotary, but with attenuation of short relative distance frequencies
 
 class RotaryEmbedding(Module):
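The new `relative_shift` is the standard pad-and-reshape trick (as used in Transformer-XL and Enformer) that turns logits indexed by relative offset into logits indexed by key position without an explicit gather. A toy check of what it does, assuming the input's last axis holds offsets -(n-1)..(n-1):

```python
# toy check of the pad-and-reshape trick: for an (n, 2n - 1) matrix whose
# columns index relative offsets -(n-1)..(n-1), row i / column j of the
# shifted result (for j < n) holds the entry for offset j - i
import torch
import torch.nn.functional as F

def relative_shift(t):
    *leading_dims, seq_len, dim = t.shape
    t = F.pad(t, (1, 0), value = 0.)
    t = t.reshape(*leading_dims, dim + 1, seq_len)
    return t[..., 1:, :].reshape(*leading_dims, seq_len, dim)

n = 3
t = torch.arange(n * (2 * n - 1)).reshape(n, 2 * n - 1).float()
shifted = relative_shift(t)

for i in range(n):
    for j in range(n):
        assert shifted[i, j] == t[i, (j - i) + (n - 1)]
```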
@@ -51,7 +58,7 @@ class RotaryEmbedding(Module):
     ):
         super().__init__()
         num_freqs = dim // 2
-        inv_freq = 1. / (arange(num_freqs).float() + torch.logspace(1, max_positions - num_freqs + 1, num_freqs))
+        inv_freq = 1. / (arange(num_freqs).float() + logspace(1, max_positions - num_freqs + 1, num_freqs))
         self.register_buffer('inv_freq', inv_freq)
 
     def forward(
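Only the import style changes here, but it is worth recalling that `torch.logspace` is base-10: it returns `steps` values spaced evenly in exponent from `10**start` to `10**end`, so this frequency ladder is geometric, not linear.

```python
# torch.logspace(start, end, steps) is base-10 by default
import torch

print(torch.logspace(1, 3, 3))  # tensor([  10.,  100., 1000.])
```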
@@ -70,6 +77,32 @@ def rotate_half(x):
 def apply_rotary_pos_emb(pos, t):
     return t * pos.cos() + rotate_half(t) * pos.sin()
 
+# 'central mask features' - relative positions for constituting pairwise rep
+
+class RelativePosFeatures(Module):
+    def __init__(self, pool_size = 16):
+        super().__init__()
+        self.pool_size = pool_size
+
+    def forward(self, single):
+
+        _, seq_len, dim = single.shape
+
+        seq_len //= self.pool_size
+        half_dim = dim // 2
+
+        rel_pos = arange(2 * seq_len - 1) - (seq_len - 1)
+
+        center_widths = (
+            arange(half_dim) +
+            logspace(1, seq_len - half_dim + 1, half_dim + 1)[:-1] # endpoint = False
+        )
+
+        abs_rel_pos, rel_pos_sign = rel_pos.abs(), rel_pos.sign()
+        embeds = greater('j, i -> i j', center_widths, abs_rel_pos).float()
+
+        return cat((embeds, multiply('i, i j', rel_pos_sign, embeds)), dim = -1)
+
 # prenorm and sandwich norm - they use sandwich norm for single rep, prenorm for pairwise rep
 
 class NormWrapper(Module):
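`RelativePosFeatures` produces the 'central mask' features: for each of the `2L - 1` pooled relative offsets (`L = seq_len // pool_size`) it emits `dim // 2` binary indicators of whether the offset falls inside a ladder of center widths, concatenated with the same indicators multiplied by the offset's sign. A hypothetical shape check, reusing the class defined above with invented sizes:

```python
# hypothetical shape check; assumes RelativePosFeatures (defined above) is in scope
import torch

single = torch.randn(1, 256, 64)                 # (batch, seq, dim), sizes invented
feats = RelativePosFeatures(pool_size = 16)(single)

# pooled length L = 256 // 16 = 16, so 2L - 1 = 31 relative offsets,
# each with dim // 2 = 32 mask bits plus their 32 sign-weighted copies
assert feats.shape == (31, 64)
```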
@@ -90,10 +123,11 @@ class NormWrapper(Module):
     def forward(
         self,
         x,
+        *args,
         **kwargs
     ):
         x = self.pre_rmsnorm(x)
-        out = self.block(x, **kwargs)
+        out = self.block(x, *args, **kwargs)
         out = self.post_block_dropout(out)
         return self.post_rmsnorm(out)
 
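This `*args` pass-through is what lets the new positional `rel_pos_feats` argument reach `SingleToPairwise` below, assuming (as the surrounding code suggests) that the tower's blocks are wrapped in `NormWrapper`.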
@@ -210,7 +244,7 @@ class SingleToPairwise(Module):
 
         dim_inner = heads * dim_pairwise
 
-        self.split_heads = Rearrange('b n (h d) -> b n h d', h = heads)
+        self.split_heads = Rearrange('... (h d) -> ... h d', h = heads)
 
         self.to_outer_sum = Sequential(
             nn.GELU(),
@@ -220,7 +254,16 @@ class SingleToPairwise(Module):
         self.to_qk = LinearNoBias(dim, dim_inner * 2)
         self.qk_to_pairwise = Linear(heads, dim_pairwise)
 
-    def forward(self, single):
+        # relative position related
+
+        self.to_rel_pos_encoding = Linear(dim, heads * dim_pairwise)
+        self.qk_rel_pos_bias = nn.Parameter(torch.zeros(2, 1, 1, heads, dim_pairwise))
+
+    def forward(
+        self,
+        single,
+        rel_pos_feats = None
+    ):
 
         single = self.avg_pool(single)
 
@@ -229,11 +272,29 @@ class SingleToPairwise(Module):
 
         sim = einsum(q, k, 'b i h d, b j h d -> b i j h')
 
+        if exists(rel_pos_feats):
+            rel_pos_encoding = self.to_rel_pos_encoding(rel_pos_feats)
+            rel_pos_encoding = self.split_heads(rel_pos_encoding)
+
+            q_rel_bias, k_rel_bias = self.qk_rel_pos_bias
+
+            rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
+            rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
+
+            _, seq, rel_pos_seq, _ = rel_q.shape
+            crop_padding = (rel_pos_seq - seq) // 2
+
+            rel_q, rel_k = tuple(t[..., crop_padding:(crop_padding + seq), :] for t in (rel_q, rel_k))
+
+            rel_sim = add('b i j d, b j i d -> b i j d', rel_q, rel_k) * 0.5
+
+            sim = sim + rel_sim
+
         pairwise_from_sim = self.qk_to_pairwise(sim)
 
         outer_q, outer_k = self.to_outer_sum(single).chunk(2, dim = -1)
 
-        outer_sum = einx.add('b i d, b j d -> b i j d', outer_q, outer_k)
+        outer_sum = add('b i d, b j d -> b i j d', outer_q, outer_k)
 
         return outer_sum
 
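The new branch scores each (biased) query and key against a learned encoding of every relative offset, realigns the result with `relative_shift`, then center-crops the offset axis back to the sequence length so it lines up with `sim`'s key axis; the `rel_sim` line then averages the query contribution with the transposed key contribution to keep the bias symmetric. A sketch of the shapes with invented sizes, reusing `relative_shift` from earlier in this diff:

```python
# shape trace of the relative-position branch; sizes invented
# (assumes relative_shift, defined earlier in this diff, is in scope)
import torch
from einops import einsum

b, n, h, d = 1, 8, 2, 4                          # batch, pooled seq, heads, dim_pairwise
q = torch.randn(b, n, h, d)
rel_pos_encoding = torch.randn(2 * n - 1, h, d)  # one (h, d) encoding per offset

rel_q = einsum(q, rel_pos_encoding, 'b n h d, p h d -> b n p h')
assert rel_q.shape == (b, n, 2 * n - 1, h)

rel_q = relative_shift(rel_q)                    # same shape, offsets realigned
crop = ((2 * n - 1) - n) // 2                    # crop_padding in the code above
rel_q = rel_q[..., crop:crop + n, :]
assert rel_q.shape == (b, n, n, h)               # now matches sim's (b, i, j, h)
```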
@@ -304,6 +365,7 @@ class TransformerTower(Module):
         dim_pairwise = None,
         pairwise_every_num_single_blocks = 2, # how often to do a pairwise block
         single_to_pairwise_heads = 32, # they did 32
+        pool_size = 16,
         attn_kwargs: dict = dict(),
         ff_kwargs: dict = dict()
     ):
@@ -314,6 +376,8 @@ class TransformerTower(Module):
 
         self.pairwise_every = pairwise_every_num_single_blocks
 
+        self.rel_pos_features = RelativePosFeatures(pool_size)
+
         self.rotary_emb = RotaryEmbedding(dim_head_qk, max_positions = max_positions)
 
         for layer_index in range(depth):
@@ -330,7 +394,7 @@ class TransformerTower(Module):
             single_to_pairwise, pairwise_attn, pairwise_ff = None, None, None
 
             if divisible_by(layer_index, self.pairwise_every):
-                single_to_pairwise = SingleToPairwise(dim = dim, dim_pairwise = dim_pairwise, heads = single_to_pairwise_heads)
+                single_to_pairwise = SingleToPairwise(dim = dim, dim_pairwise = dim_pairwise, heads = single_to_pairwise_heads, pool_size = pool_size)
                 pairwise_attn = PairwiseRowAttention(dim_pairwise)
                 pairwise_ff = FeedForward(dim = dim_pairwise, expansion_factor = ff_expansion_factor)
 
@@ -360,6 +424,8 @@ class TransformerTower(Module):
 
         pairwise = None
 
+        rel_pos_feats = self.rel_pos_features(single)
+
         rotary_emb = self.rotary_emb(seq_len)
 
         for (
@@ -374,7 +440,7 @@ class TransformerTower(Module):
             single = ff(single) + single
 
             if exists(maybe_single_to_pair):
-                pairwise = maybe_single_to_pair(single) + default(pairwise, 0.)
+                pairwise = maybe_single_to_pair(single, rel_pos_feats) + default(pairwise, 0.)
                 pairwise = maybe_pairwise_attn(pairwise) + pairwise
                 pairwise = maybe_pairwise_ff(pairwise) + pairwise
 
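Putting the source changes together: the tower now builds the relative position features once per forward pass and threads them into every `SingleToPairwise` block. A hypothetical usage sketch; the keyword names are the ones visible in this diff, but the values, the `dim`/`depth` arguments, and the import path are assumptions, not confirmed by it:

```python
# hypothetical usage sketch; dim, depth, all values, and the import path
# are assumptions - only the keyword names are visible in this diff
import torch
from alphagenome_pytorch.alphagenome import TransformerTower

tower = TransformerTower(
    dim = 768,
    depth = 8,
    dim_pairwise = 128,
    pairwise_every_num_single_blocks = 2,
    single_to_pairwise_heads = 32,
    pool_size = 16   # new in 0.0.5; must match the RelativePosFeatures pooling
)

single = torch.randn(1, 512, 768)   # (batch, sequence, dim)
out = tower(single)                 # rel pos features are computed once per forward
```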
--- a/alphagenome_pytorch-0.0.3.dist-info/METADATA
+++ b/alphagenome_pytorch-0.0.5.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alphagenome-pytorch
-Version: 0.0.3
+Version: 0.0.5
 Summary: AlphaGenome
 Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/alphagenome
--- /dev/null
+++ b/alphagenome_pytorch-0.0.5.dist-info/RECORD
@@ -0,0 +1,6 @@
+alphagenome_pytorch/__init__.py,sha256=7plC_YRm0UapNCl9hJEhFxXE-ELGKVy-DtuO5GUQxGI,101
+alphagenome_pytorch/alphagenome.py,sha256=5uxmx-x-6KSu-2aUF4AxM-rYSxXuwnr-0meTNhu2vUM,14086
+alphagenome_pytorch-0.0.5.dist-info/METADATA,sha256=0W6pkI9-RvNCOixhC1ius8a4-cwpTx7gsA370FLmnsY,3386
+alphagenome_pytorch-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+alphagenome_pytorch-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+alphagenome_pytorch-0.0.5.dist-info/RECORD,,
--- a/alphagenome_pytorch-0.0.3.dist-info/RECORD
+++ /dev/null
@@ -1,6 +0,0 @@
-alphagenome_pytorch/__init__.py,sha256=7plC_YRm0UapNCl9hJEhFxXE-ELGKVy-DtuO5GUQxGI,101
-alphagenome_pytorch/alphagenome.py,sha256=ynMieMVxKkL-BFr9yyAku4J48P4gaRAc0z0GmThTot0,11818
-alphagenome_pytorch-0.0.3.dist-info/METADATA,sha256=uKEwowc-D-OO5pZ_JU176e1AIB7o_-pKbKJRiyF0WO0,3386
-alphagenome_pytorch-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-alphagenome_pytorch-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-alphagenome_pytorch-0.0.3.dist-info/RECORD,,