alphagenome-pytorch 0.0.2.tar.gz → 0.0.4.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {alphagenome_pytorch-0.0.2 → alphagenome_pytorch-0.0.4}/PKG-INFO +1 -1
- {alphagenome_pytorch-0.0.2 → alphagenome_pytorch-0.0.4}/alphagenome_pytorch/alphagenome.py +62 -9
- {alphagenome_pytorch-0.0.2 → alphagenome_pytorch-0.0.4}/pyproject.toml +1 -1
- {alphagenome_pytorch-0.0.2 → alphagenome_pytorch-0.0.4}/.github/workflows/python-publish.yml +0 -0
- {alphagenome_pytorch-0.0.2 → alphagenome_pytorch-0.0.4}/.github/workflows/test.yml +0 -0
- {alphagenome_pytorch-0.0.2 → alphagenome_pytorch-0.0.4}/.gitignore +0 -0
- {alphagenome_pytorch-0.0.2 → alphagenome_pytorch-0.0.4}/LICENSE +0 -0
- {alphagenome_pytorch-0.0.2 → alphagenome_pytorch-0.0.4}/README.md +0 -0
- {alphagenome_pytorch-0.0.2 → alphagenome_pytorch-0.0.4}/alphagenome_pytorch/__init__.py +0 -0
- {alphagenome_pytorch-0.0.2 → alphagenome_pytorch-0.0.4}/extended-figure-1.png +0 -0
- {alphagenome_pytorch-0.0.2 → alphagenome_pytorch-0.0.4}/tests/test_alphagenome.py +0 -0
{alphagenome_pytorch-0.0.2 → alphagenome_pytorch-0.0.4}/alphagenome_pytorch/alphagenome.py
CHANGED

```diff
@@ -2,11 +2,11 @@ from __future__ import annotations
 from functools import partial
 
 import torch
-from torch import nn, cat, stack, arange
+from torch import nn, cat, stack, arange, logspace
 import torch.nn.functional as F
 from torch.nn import Linear, Sequential, Module, ModuleList
 
-import
+from einx import add, multiply, greater
 from einops.layers.torch import Rearrange, Reduce
 from einops import rearrange, repeat, einsum
 
```
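The new `einx` imports are generalized elementwise ops driven by an einops-style pattern string. A minimal sketch of the three ops as they are used later in this diff (toy tensors, illustrative names, not part of the package):

```python
import torch
from einx import add, multiply, greater

q = torch.randn(2, 4, 8)
k = torch.randn(2, 4, 8)

# pairwise 'outer sum': broadcast q over j and k over i -> (2, 4, 4, 8)
outer = add('b i d, b j d -> b i j d', q, k)

widths = torch.tensor([2., 4., 8.])
dists = torch.arange(6).abs()

# boolean comparison with broadcasting: mask[i, j] = widths[j] > dists[i] -> (6, 3)
mask = greater('j, i -> i j', widths, dists)

# elementwise multiply; with no '->' the output pattern is inferred as 'i j'
signed = multiply('i, i j', torch.ones(6), mask.float())
```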
```diff
@@ -41,6 +41,12 @@ def default(v, d):
 def softclamp(t, value = 5.):
     return (t / value).tanh() * value
 
+def relative_shift(t):
+    *leading_dims, seq_len, dim = t.shape
+    t = F.pad(t, (1, 0), value = 0.)
+    t = t.reshape(*leading_dims, dim + 1, seq_len)
+    return t[..., 1:, :].reshape(*leading_dims, seq_len, dim)
+
 # rotary, but with attenuation of short relative distance frequencies
 
 class RotaryEmbedding(Module):
```
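The added `relative_shift` is the Transformer-XL-style realignment trick: scores whose last axis is indexed by relative distance get shifted so that column j of row i lines up with relative position j - i. A quick standalone check with a toy size (the function body is copied from the hunk above):

```python
import torch
import torch.nn.functional as F

def relative_shift(t):
    *leading_dims, seq_len, dim = t.shape
    t = F.pad(t, (1, 0), value = 0.)
    t = t.reshape(*leading_dims, dim + 1, seq_len)
    return t[..., 1:, :].reshape(*leading_dims, seq_len, dim)

n = 3
t = torch.arange(2 * n - 1).float().expand(n, -1)  # every row is [0, 1, 2, 3, 4]

print(relative_shift(t))
# tensor([[2., 3., 4., 0., 0.],
#         [1., 2., 3., 4., 0.],
#         [0., 1., 2., 3., 4.]])
#
# the diagonal now holds the center index n - 1 = 2 (relative distance 0);
# out-of-range entries wrap around and are contracted away downstream
```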
```diff
@@ -51,7 +57,7 @@ class RotaryEmbedding(Module):
     ):
         super().__init__()
         num_freqs = dim // 2
-        inv_freq = 1. / (arange(num_freqs).float() +
+        inv_freq = 1. / (arange(num_freqs).float() + logspace(1, max_positions - num_freqs + 1, num_freqs))
         self.register_buffer('inv_freq', inv_freq)
 
     def forward(
```
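Note that `torch.logspace(start, end, steps)` returns `steps` points spaced evenly on a log scale, running from `10**start` to `10**end` (base 10 by default), not a linear range, so the added term grows very fast. A quick look at the new `inv_freq` construction with toy sizes (illustrative only):

```python
import torch
from torch import arange, logspace

print(logspace(1, 4, 4))  # tensor([   10.,   100.,  1000., 10000.])

dim, max_positions = 8, 16
num_freqs = dim // 2

inv_freq = 1. / (arange(num_freqs).float() + logspace(1, max_positions - num_freqs + 1, num_freqs))
print(inv_freq)  # strictly decreasing rotation frequencies, dominated by the logspace term
```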
```diff
@@ -70,6 +76,26 @@ def rotate_half(x):
 def apply_rotary_pos_emb(pos, t):
     return t * pos.cos() + rotate_half(t) * pos.sin()
 
+# 'central mask features' - relative positions for constituting pairwise rep
+
+class RelativePosFeatures(Module):
+    def forward(self, single):
+
+        _, seq_len, dim = single.shape
+        half_dim = dim // 2
+
+        rel_pos = arange(2 * seq_len - 1) - (seq_len - 1)
+
+        center_widths = (
+            arange(half_dim) +
+            logspace(1, seq_len - half_dim + 1, half_dim + 1)[:-1] # endpoint = False
+        )
+
+        abs_rel_pos, rel_pos_sign = rel_pos.abs(), rel_pos.sign()
+        embeds = greater('j, i -> i j', center_widths, abs_rel_pos).float()
+
+        return cat((embeds, multiply('i, i j', rel_pos_sign, embeds)), dim = -1)
+
 # prenorm and sandwich norm - they use sandwich norm for single rep, prenorm for pairwise rep
 
 class NormWrapper(Module):
```
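A standalone shape sketch of the central-mask featurization (toy sizes, mirroring the class body above): for a length-n single representation of width `dim`, it emits one feature row per relative position in [-(n-1), n-1], i.e. a `(2n - 1, dim)` tensor whose first half is binary center-mask indicators and whose second half is their sign-modulated copies.

```python
import torch
from torch import arange, logspace, cat
from einx import greater, multiply

seq_len, dim = 4, 6
half_dim = dim // 2

rel_pos = arange(2 * seq_len - 1) - (seq_len - 1)  # tensor([-3, -2, -1, 0, 1, 2, 3])

center_widths = arange(half_dim) + logspace(1, seq_len - half_dim + 1, half_dim + 1)[:-1]

embeds = greater('j, i -> i j', center_widths, rel_pos.abs()).float()  # (7, 3)
feats = cat((embeds, multiply('i, i j', rel_pos.sign().float(), embeds)), dim = -1)

print(feats.shape)  # torch.Size([7, 6]) == (2 * seq_len - 1, dim)
```

The output width matching `dim` is what lets the `Linear(dim, heads * dim_pairwise)` projection added to `SingleToPairwise` below consume these features directly.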
```diff
@@ -90,10 +116,11 @@ class NormWrapper(Module):
     def forward(
         self,
         x,
+        *args,
         **kwargs
     ):
         x = self.pre_rmsnorm(x)
-        out = self.block(x, **kwargs)
+        out = self.block(x, *args, **kwargs)
         out = self.post_block_dropout(out)
         return self.post_rmsnorm(out)
 
```
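The `*args` threading exists so a block called through its norm wrapper can receive extra positional inputs, which is how `rel_pos_feats` reaches `SingleToPairwise` further down. A minimal stand-in (hypothetical names, not the package's classes):

```python
import torch
from torch import nn

class TinyNormWrapper(nn.Module):
    def __init__(self, block, dim = 8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.block = block

    def forward(self, x, *args, **kwargs):
        # positional extras are forwarded to the wrapped block unchanged
        return self.block(self.norm(x), *args, **kwargs)

wrapper = TinyNormWrapper(lambda x, extra: x + extra)
out = wrapper(torch.randn(2, 8), torch.ones(2, 8))  # the second argument reaches the block
```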
```diff
@@ -210,17 +237,26 @@ class SingleToPairwise(Module):
 
         dim_inner = heads * dim_pairwise
 
-        self.split_heads = Rearrange('
+        self.split_heads = Rearrange('... (h d) -> ... h d', h = heads)
 
         self.to_outer_sum = Sequential(
+            nn.GELU(),
             LinearNoBias(dim, dim_pairwise * 2),
-            nn.GELU()
         )
 
         self.to_qk = LinearNoBias(dim, dim_inner * 2)
         self.qk_to_pairwise = Linear(heads, dim_pairwise)
 
-    def forward(self, single):
+        # relative position related
+
+        self.to_rel_pos_encoding = Linear(dim, heads * dim_pairwise)
+        self.qk_rel_pos_bias = nn.Parameter(torch.zeros(2, 1, 1, heads, dim_pairwise))
+
+    def forward(
+        self,
+        single,
+        rel_pos_feats = None
+    ):
 
         single = self.avg_pool(single)
 
```
```diff
@@ -229,11 +265,24 @@ class SingleToPairwise(Module):
 
         sim = einsum(q, k, 'b i h d, b j h d -> b i j h')
 
+        if exists(rel_pos_feats):
+            rel_pos_encoding = self.to_rel_pos_encoding(rel_pos_feats)
+            rel_pos_encoding = self.split_heads(rel_pos_encoding)
+
+            q_rel_bias, k_rel_bias = self.qk_rel_pos_bias
+
+            rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b i h d, j h d -> b i j h'))
+            rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b i h d, j h d -> b i j h'))
+
+            rel_sim = einsum(rel_q, rel_k, 'b i p h, b j p h -> b i j h')
+
+            sim = (sim + rel_sim) * 0.5
+
         pairwise_from_sim = self.qk_to_pairwise(sim)
 
         outer_q, outer_k = self.to_outer_sum(single).chunk(2, dim = -1)
 
-        outer_sum =
+        outer_sum = add('b i d, b j d -> b i j d', outer_q, outer_k)
 
         return outer_sum
 
```
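Putting the new `SingleToPairwise` pieces together, a standalone shape trace of the relative-position branch (toy sizes; variable names mirror the hunks, everything else is illustrative):

```python
import torch
from einops import einsum

batch, seq_len, heads, dim_pairwise = 2, 4, 3, 5

q = torch.randn(batch, seq_len, heads, dim_pairwise)
k = torch.randn(batch, seq_len, heads, dim_pairwise)

# one encoding per relative position, already projected and split into heads
rel_pos_encoding = torch.randn(2 * seq_len - 1, heads, dim_pairwise)

# the (2, 1, 1, h, d) parameter unbinds into two biases broadcast over batch and sequence
q_rel_bias, k_rel_bias = torch.zeros(2, 1, 1, heads, dim_pairwise)

rel_q = einsum(q + q_rel_bias, rel_pos_encoding, 'b i h d, j h d -> b i j h')  # (2, 4, 7, 3)
rel_k = einsum(k + k_rel_bias, rel_pos_encoding, 'b i h d, j h d -> b i j h')

# in the package, rel_q / rel_k first pass through relative_shift, which preserves shape here

rel_sim = einsum(rel_q, rel_k, 'b i p h, b j p h -> b i j h')  # contract over the 7 relative positions

sim = torch.randn(batch, seq_len, seq_len, heads)
sim = (sim + rel_sim) * 0.5  # content and relative-position similarities are averaged
```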
```diff
@@ -314,6 +363,8 @@ class TransformerTower(Module):
 
         self.pairwise_every = pairwise_every_num_single_blocks
 
+        self.rel_pos_features = RelativePosFeatures()
+
         self.rotary_emb = RotaryEmbedding(dim_head_qk, max_positions = max_positions)
 
         for layer_index in range(depth):
```
```diff
@@ -360,6 +411,8 @@ class TransformerTower(Module):
 
         pairwise = None
 
+        rel_pos_feats = self.rel_pos_features(single)
+
         rotary_emb = self.rotary_emb(seq_len)
 
         for (
```
```diff
@@ -374,7 +427,7 @@ class TransformerTower(Module):
             single = ff(single) + single
 
             if exists(maybe_single_to_pair):
-                pairwise = maybe_single_to_pair(single) + default(pairwise, 0.)
+                pairwise = maybe_single_to_pair(single, rel_pos_feats) + default(pairwise, 0.)
                 pairwise = maybe_pairwise_attn(pairwise) + pairwise
                 pairwise = maybe_pairwise_ff(pairwise) + pairwise
 
```
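The last three hunks are plumbing: `RelativePosFeatures` is instantiated once, evaluated once per forward pass, and its output threaded into every single-to-pairwise block, whose result folds into the running pairwise representation via `default(pairwise, 0.)`. A tiny sketch of that accumulation idiom (standalone, hypothetical tensors):

```python
import torch

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

pairwise = None

# every `pairwise_every` layers a new contribution is added; the first
# contribution starts the sum from 0. instead of erroring on None
for contribution in (torch.ones(2, 8, 8, 4) for _ in range(3)):
    pairwise = contribution + default(pairwise, 0.)

assert (pairwise == 3.).all()
```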