alphagenome-pytorch 0.0.3__tar.gz → 0.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/PKG-INFO +1 -1
- {alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/alphagenome_pytorch/alphagenome.py +75 -9
- {alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/pyproject.toml +1 -1
- {alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/.github/workflows/python-publish.yml +0 -0
- {alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/.github/workflows/test.yml +0 -0
- {alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/.gitignore +0 -0
- {alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/LICENSE +0 -0
- {alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/README.md +0 -0
- {alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/alphagenome_pytorch/__init__.py +0 -0
- {alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/extended-figure-1.png +0 -0
- {alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/tests/test_alphagenome.py +0 -0
{alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/alphagenome_pytorch/alphagenome.py

@@ -2,11 +2,11 @@ from __future__ import annotations
 from functools import partial
 
 import torch
-from torch import nn, cat, stack, arange
+from torch import nn, cat, stack, arange, logspace
 import torch.nn.functional as F
 from torch.nn import Linear, Sequential, Module, ModuleList
 
-import
+from einx import add, multiply, greater
 from einops.layers.torch import Rearrange, Reduce
 from einops import rearrange, repeat, einsum
 
@@ -15,6 +15,7 @@ from einops import rearrange, repeat, einsum
 # b - batch
 # h - heads
 # n - sequence
+# p - relative positions
 # d - feature dimension
 
 # constants
@@ -41,6 +42,12 @@ def default(v, d):
 def softclamp(t, value = 5.):
     return (t / value).tanh() * value
 
+def relative_shift(t):
+    *leading_dims, seq_len, dim = t.shape
+    t = F.pad(t, (1, 0), value = 0.)
+    t = t.reshape(*leading_dims, dim + 1, seq_len)
+    return t[..., 1:, :].reshape(*leading_dims, seq_len, dim)
+
 # rotary, but with attenuation of short relative distance frequencies
 
 class RotaryEmbedding(Module):
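
To make the new relative_shift helper concrete, here is a minimal sketch that copies the function from the hunk above and applies it to a small made-up tensor; the pad zeros march one column to the right per row, the familiar pad-and-reshape relative shift.

import torch
import torch.nn.functional as F

def relative_shift(t):
    # same helper as in the hunk above
    *leading_dims, seq_len, dim = t.shape
    t = F.pad(t, (1, 0), value = 0.)
    t = t.reshape(*leading_dims, dim + 1, seq_len)
    return t[..., 1:, :].reshape(*leading_dims, seq_len, dim)

x = torch.arange(9.).reshape(1, 3, 3)
print(relative_shift(x))
# tensor([[[2., 0., 3.],
#          [4., 5., 0.],
#          [6., 7., 8.]]])
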
@@ -51,7 +58,7 @@ class RotaryEmbedding(Module):
     ):
         super().__init__()
         num_freqs = dim // 2
-        inv_freq = 1. / (arange(num_freqs).float() +
+        inv_freq = 1. / (arange(num_freqs).float() + logspace(1, max_positions - num_freqs + 1, num_freqs))
         self.register_buffer('inv_freq', inv_freq)
 
     def forward(
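
The RotaryEmbedding change rebuilds the inverse frequencies with torch.logspace. A toy sketch mirroring the new line, with assumed small values for dim and max_positions, just to show the resulting shape:

from torch import arange, logspace

dim, max_positions = 8, 16           # assumed toy values
num_freqs = dim // 2
inv_freq = 1. / (arange(num_freqs).float() + logspace(1, max_positions - num_freqs + 1, num_freqs))

print(inv_freq.shape)                # torch.Size([4]) - one frequency per rotary pair
print(inv_freq)                      # decreasing, roughly 1e-1 down to 1e-13 for these sizes
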
@@ -70,6 +77,32 @@ def rotate_half(x):
 def apply_rotary_pos_emb(pos, t):
     return t * pos.cos() + rotate_half(t) * pos.sin()
 
+# 'central mask features' - relative positions for constituting pairwise rep
+
+class RelativePosFeatures(Module):
+    def __init__(self, pool_size = 16):
+        super().__init__()
+        self.pool_size = pool_size
+
+    def forward(self, single):
+
+        _, seq_len, dim = single.shape
+
+        seq_len //= self.pool_size
+        half_dim = dim // 2
+
+        rel_pos = arange(2 * seq_len - 1) - (seq_len - 1)
+
+        center_widths = (
+            arange(half_dim) +
+            logspace(1, seq_len - half_dim + 1, half_dim + 1)[:-1] # endpoint = False
+        )
+
+        abs_rel_pos, rel_pos_sign = rel_pos.abs(), rel_pos.sign()
+        embeds = greater('j, i -> i j', center_widths, abs_rel_pos).float()
+
+        return cat((embeds, multiply('i, i j', rel_pos_sign, embeds)), dim = -1)
+
 # prenorm and sandwich norm - they use sandwich norm for single rep, prenorm for pairwise rep
 
 class NormWrapper(Module):
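
A shape check for the new RelativePosFeatures module, as a sketch with assumed sizes (and assuming it is importable from alphagenome_pytorch.alphagenome as the diff suggests): a (batch, length, dim) single representation yields a (2 * pooled_length - 1, dim) matrix of central-mask features, where pooled_length = length // pool_size; the first dim // 2 columns mark whether the absolute relative distance falls inside each center width, and the remaining columns carry the same mask multiplied by the sign of the relative position.

import torch
from alphagenome_pytorch.alphagenome import RelativePosFeatures

rel_pos_feats = RelativePosFeatures(pool_size = 16)(torch.randn(1, 64, 32))

# pooled length is 64 // 16 = 4, so there are 2 * 4 - 1 = 7 relative positions,
# and the feature dimension of the single representation is preserved
assert rel_pos_feats.shape == (7, 32)
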
@@ -90,10 +123,11 @@ class NormWrapper(Module):
     def forward(
         self,
         x,
+        *args,
         **kwargs
     ):
         x = self.pre_rmsnorm(x)
-        out = self.block(x, **kwargs)
+        out = self.block(x, *args, **kwargs)
         out = self.post_block_dropout(out)
         return self.post_rmsnorm(out)
 
@@ -210,7 +244,7 @@ class SingleToPairwise(Module):
 
         dim_inner = heads * dim_pairwise
 
-        self.split_heads = Rearrange('
+        self.split_heads = Rearrange('... (h d) -> ... h d', h = heads)
 
         self.to_outer_sum = Sequential(
             nn.GELU(),
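
The new Rearrange pattern splits a fused heads-times-features last axis into separate head and feature axes, for any number of leading dimensions. A quick check with assumed sizes:

import torch
from einops.layers.torch import Rearrange

split_heads = Rearrange('... (h d) -> ... h d', h = 4)

# a fused last axis of 4 heads * 16 features becomes separate (head, feature) axes
assert split_heads(torch.randn(2, 8, 4 * 16)).shape == (2, 8, 4, 16)
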
@@ -220,7 +254,16 @@ class SingleToPairwise(Module):
         self.to_qk = LinearNoBias(dim, dim_inner * 2)
         self.qk_to_pairwise = Linear(heads, dim_pairwise)
 
-
+        # relative position related
+
+        self.to_rel_pos_encoding = Linear(dim, heads * dim_pairwise)
+        self.qk_rel_pos_bias = nn.Parameter(torch.zeros(2, 1, 1, heads, dim_pairwise))
+
+    def forward(
+        self,
+        single,
+        rel_pos_feats = None
+    ):
 
         single = self.avg_pool(single)
 
@@ -229,11 +272,29 @@ class SingleToPairwise(Module):
 
         sim = einsum(q, k, 'b i h d, b j h d -> b i j h')
 
+        if exists(rel_pos_feats):
+            rel_pos_encoding = self.to_rel_pos_encoding(rel_pos_feats)
+            rel_pos_encoding = self.split_heads(rel_pos_encoding)
+
+            q_rel_bias, k_rel_bias = self.qk_rel_pos_bias
+
+            rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
+            rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
+
+            _, seq, rel_pos_seq, _ = rel_q.shape
+            crop_padding = (rel_pos_seq - seq) // 2
+
+            rel_q, rel_k = tuple(t[..., crop_padding:(crop_padding + seq), :] for t in (rel_q, rel_k))
+
+            rel_sim = add('b i j d, b j i d -> b i j d', rel_q, rel_k) * 0.5
+
+            sim = sim + rel_sim
+
         pairwise_from_sim = self.qk_to_pairwise(sim)
 
         outer_q, outer_k = self.to_outer_sum(single).chunk(2, dim = -1)
 
-        outer_sum =
+        outer_sum = add('b i d, b j d -> b i j d', outer_q, outer_k)
 
         return outer_sum
 
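
In the relative-position branch above, rel_q and rel_k carry an extra axis of p = 2 * pooled_length - 1 relative positions, which is center-cropped back to the sequence length before the query-side and key-side terms are averaged with one of them transposed. A standalone sketch of just that crop-and-symmetrize step, with random tensors standing in for the relative_shift outputs:

import torch
from einx import add

b, n, h = 1, 4, 2                    # assumed batch, pooled length, heads
p = 2 * n - 1                        # number of relative positions

rel_q = torch.randn(b, n, p, h)      # stand-ins for the relative_shift outputs
rel_k = torch.randn(b, n, p, h)

crop_padding = (p - n) // 2          # center-crop the relative axis down to n
rel_q, rel_k = tuple(t[..., crop_padding:(crop_padding + n), :] for t in (rel_q, rel_k))

# einx.add rearranges rel_k to the output pattern, transposing its two position axes,
# so this averages the query-side and key-side relative logits
rel_sim = add('b i j d, b j i d -> b i j d', rel_q, rel_k) * 0.5
assert rel_sim.shape == (b, n, n, h)
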
@@ -304,6 +365,7 @@ class TransformerTower(Module):
         dim_pairwise = None,
         pairwise_every_num_single_blocks = 2,   # how often to do a pairwise block
         single_to_pairwise_heads = 32,          # they did 32
+        pool_size = 16,
         attn_kwargs: dict = dict(),
         ff_kwargs: dict = dict()
     ):
@@ -314,6 +376,8 @@ class TransformerTower(Module):
 
         self.pairwise_every = pairwise_every_num_single_blocks
 
+        self.rel_pos_features = RelativePosFeatures(pool_size)
+
         self.rotary_emb = RotaryEmbedding(dim_head_qk, max_positions = max_positions)
 
         for layer_index in range(depth):
@@ -330,7 +394,7 @@ class TransformerTower(Module):
             single_to_pairwise, pairwise_attn, pairwise_ff = None, None, None
 
             if divisible_by(layer_index, self.pairwise_every):
-                single_to_pairwise = SingleToPairwise(dim = dim, dim_pairwise = dim_pairwise, heads = single_to_pairwise_heads)
+                single_to_pairwise = SingleToPairwise(dim = dim, dim_pairwise = dim_pairwise, heads = single_to_pairwise_heads, pool_size = pool_size)
                 pairwise_attn = PairwiseRowAttention(dim_pairwise)
                 pairwise_ff = FeedForward(dim = dim_pairwise, expansion_factor = ff_expansion_factor)
 
@@ -360,6 +424,8 @@ class TransformerTower(Module):
 
         pairwise = None
 
+        rel_pos_feats = self.rel_pos_features(single)
+
         rotary_emb = self.rotary_emb(seq_len)
 
         for (
@@ -374,7 +440,7 @@ class TransformerTower(Module):
             single = ff(single) + single
 
             if exists(maybe_single_to_pair):
-                pairwise = maybe_single_to_pair(single) + default(pairwise, 0.)
+                pairwise = maybe_single_to_pair(single, rel_pos_feats) + default(pairwise, 0.)
                 pairwise = maybe_pairwise_attn(pairwise) + pairwise
                 pairwise = maybe_pairwise_ff(pairwise) + pairwise
 
{alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/.github/workflows/python-publish.yml
RENAMED
File without changes

{alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/.github/workflows/test.yml
RENAMED
File without changes

{alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/.gitignore
RENAMED
File without changes

{alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/LICENSE
RENAMED
File without changes

{alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/README.md
RENAMED
File without changes

{alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/alphagenome_pytorch/__init__.py
RENAMED
File without changes

{alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/extended-figure-1.png
RENAMED
File without changes

{alphagenome_pytorch-0.0.3 → alphagenome_pytorch-0.0.5}/tests/test_alphagenome.py
RENAMED
File without changes