alphagenome-pytorch 0.0.9__tar.gz → 0.0.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/PKG-INFO +5 -5
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/README.md +4 -4
- alphagenome_pytorch-0.0.11/alphagenome_pytorch/__init__.py +11 -0
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/alphagenome_pytorch/alphagenome.py +96 -26
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/pyproject.toml +1 -1
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/tests/test_alphagenome.py +13 -0
- alphagenome_pytorch-0.0.9/alphagenome_pytorch/__init__.py +0 -5
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/.github/workflows/python-publish.yml +0 -0
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/.github/workflows/test.yml +0 -0
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/.gitignore +0 -0
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/LICENSE +0 -0
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/extended-figure-1.png +0 -0
{alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/PKG-INFO
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alphagenome-pytorch
-Version: 0.0.9
+Version: 0.0.11
 Summary: AlphaGenome
 Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/alphagenome
@@ -58,13 +58,13 @@ $ pip install alphagenome-pytorch
 
 ```python
 import torch
-from alphagenome_pytorch import
+from alphagenome_pytorch import AlphaGenome
 
-
+model = AlphaGenome()
 
-
+dna = torch.randint(0, 5, (2, 8192))
 
-
+pred_nucleotide, single, pairwise = model(dna) # (2, 8192, 5), (2, 64, 1536), (2, 4, 4, 1536)
 ```
 
 ## Citations
{alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/README.md
CHANGED

@@ -14,13 +14,13 @@ $ pip install alphagenome-pytorch
 
 ```python
 import torch
-from alphagenome_pytorch import
+from alphagenome_pytorch import AlphaGenome
 
-
+model = AlphaGenome()
 
-
+dna = torch.randint(0, 5, (2, 8192))
 
-
+pred_nucleotide, single, pairwise = model(dna) # (2, 8192, 5), (2, 64, 1536), (2, 4, 4, 1536)
 ```
 
 ## Citations
{alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/alphagenome_pytorch/alphagenome.py
CHANGED

@@ -32,6 +32,9 @@ def exists(v):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def last(arr):
+    return arr[-1]
+
 def is_odd(num):
     return not divisible_by(num, 2)
 
@@ -128,15 +131,15 @@ class UpresBlock(Module):
     def forward(
         self,
         x,
-
+        skip = None
     ):
 
         residual = x[:, :-self.pad]
         out = self.conv(x) + residual
 
-        if exists(
+        if exists(skip):
             out = repeat(out, 'b c n -> b c (n upsample)', upsample = 2) * self.residual_scale
-            out = out + self.unet_conv(
+            out = out + self.unet_conv(skip)
 
         return self.conv_out(out) + out
 
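To make the completed skip handling above concrete, here is a small illustrative sketch (not code from the package; assumptions are flagged in the comments) of how a skip saved on the down path would be consumed: the low-resolution features are repeated 2x along the sequence axis, rescaled, and a projected skip is added.

```python
import torch
from torch.nn import Conv1d
from einops import repeat

b, c, n = 1, 64, 8
out = torch.randn(b, c, n)          # low-resolution features inside UpresBlock.forward
skip = torch.randn(b, c, n * 2)     # higher-resolution features saved on the down path

residual_scale = 2 ** -0.5          # assumption for illustration; the real self.residual_scale is defined outside this hunk
unet_conv = Conv1d(c, c, 1)         # stand-in for self.unet_conv

# mirror of the hunk: 2x upsample by repetition, rescale, then add the projected skip
up = repeat(out, 'b c n -> b c (n upsample)', upsample = 2) * residual_scale
merged = up + unet_conv(skip)

assert merged.shape == (b, c, n * 2)
```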
@@ -240,6 +243,7 @@ class Attention(Module):
         dim,
         dim_head = 64,
         heads = 8,
+        kv_heads = 1,
         dim_head_qk = 128,
         dim_head_v = 192,
         dim_pairwise = None,
@@ -254,8 +258,13 @@ class Attention(Module):
 
         # splitting and merging of attention heads
 
-
-
+        assert divisible_by(heads, kv_heads)
+        groups = heads // kv_heads
+
+        self.split_q_heads = Rearrange('b n (g h d) -> b g h n d', h = kv_heads, g = groups)
+        self.split_kv_heads = Rearrange('b n (h d) -> b h n d', h = kv_heads)
+
+        self.merge_heads = Rearrange('b g h n d -> b n (g h d)')
 
         # projections
 
@@ -274,7 +283,7 @@ class Attention(Module):
             nn.RMSNorm(dim_pairwise), # replace with BatchRMSNorm once crafted
             nn.GELU(),
             LinearNoBias(dim_pairwise, heads),
-            Rearrange('b i j h -> b h i j')
+            Rearrange('b i j (g h) -> b g h i j', g = groups)
         )
         # variables
 
@@ -293,6 +302,7 @@ class Attention(Module):
         # they use multi-query attention, with only 1 key / value head - pretty unconventional, but maybe enough for genomic modeling
 
         q = self.split_q_heads(q)
+        k, v = tuple(self.split_kv_heads(t) for t in (k, v))
 
         q, k, v = self.q_norm(q), self.k_norm(k), self.v_norm(v)
 
@@ -305,7 +315,7 @@ class Attention(Module):
 
         # similarities
 
-        sim = einsum(q, k, 'b h i d, b j d -> b h i j')
+        sim = einsum(q, k, 'b g h i d, b h j d -> b g h i j')
 
         # add attention bias + softclamping
 
@@ -315,7 +325,7 @@ class Attention(Module):
             assert divisible_by(sim.shape[-1], attn_bias.shape[-1])
             expand_factor = sim.shape[-1] // attn_bias.shape[-1]
 
-            attn_bias = repeat(attn_bias, 'b h i j -> b h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
+            attn_bias = repeat(attn_bias, 'b g h i j -> b g h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
 
             sim = softclamp(sim + attn_bias, value = self.softclamp_value)
 
@@ -325,7 +335,7 @@ class Attention(Module):
 
         # aggregate
 
-        out = einsum(attn, v, 'b h i j, b j d -> b h i d')
+        out = einsum(attn, v, 'b g h i j, b h j d -> b g h i d')
 
         out = self.merge_heads(out)
         return self.to_out(out)
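The Attention hunks above move from single key/value multi-query patterns ('b h i j') to grouped patterns ('b g h i j'). The following is a minimal standalone sketch of the shape bookkeeping those einsum patterns imply; the head dimensions are the defaults shown in the hunks, the tensors are random stand-ins, and none of this is code from the package.

```python
import torch
from einops import rearrange, einsum

# defaults taken from the Attention hunks; batch and sequence sizes are arbitrary
b, n = 2, 16
heads, kv_heads, dim_head_qk, dim_head_v = 8, 1, 128, 192
groups = heads // kv_heads

q = torch.randn(b, n, heads * dim_head_qk)
k = torch.randn(b, n, kv_heads * dim_head_qk)
v = torch.randn(b, n, kv_heads * dim_head_v)

# split query heads into (groups, kv_heads); keys / values keep only kv_heads
q = rearrange(q, 'b n (g h d) -> b g h n d', g = groups, h = kv_heads)
k = rearrange(k, 'b n (h d) -> b h n d', h = kv_heads)
v = rearrange(v, 'b n (h d) -> b h n d', h = kv_heads)

# every query group attends against the shared key / value head(s)
sim = einsum(q, k, 'b g h i d, b h j d -> b g h i j')
attn = sim.softmax(dim = -1)
out = einsum(attn, v, 'b g h i j, b h j d -> b g h i d')

# merging heads recovers (b, n, groups * kv_heads * dim_head_v)
out = rearrange(out, 'b g h n d -> b n (g h d)')
assert out.shape == (b, n, heads * dim_head_v)
```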
@@ -384,7 +394,7 @@ class SingleToPairwise(Module):
         rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
         rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
 
-        rel_sim = add('b h i j, b h j i -> b
+        rel_sim = add('b h i j, b h j i -> b i j h', rel_q, rel_k) * 0.5
 
         sim = sim + rel_sim
 
@@ -460,7 +470,7 @@ class TransformerTower(Module):
         dropout = 0.,
         ff_expansion_factor = 2.,
         max_positions = 8192,
-        dim_pairwise =
+        dim_pairwise = 128,
         pairwise_every_num_single_blocks = 2, # how often to do a pairwise block
         single_to_pairwise_heads = 32, # they did 32
         pool_size = 16,
@@ -539,7 +549,7 @@ class TransformerTower(Module):
             pairwise = maybe_pairwise_attn(pairwise) + pairwise
             pairwise = maybe_pairwise_ff(pairwise) + pairwise
 
-            single = attn(single, rotary_emb = rotary_emb, pairwise =
+            single = attn(single, rotary_emb = rotary_emb, pairwise = pairwise) + single
             single = ff(single) + single
 
         return single, pairwise
@@ -556,8 +566,10 @@ class DNAEmbed(Module):
         super().__init__()
         assert is_odd(width)
         self.dim_input = dim_input
-        self.conv =
-        self.pointwise =
+        self.conv = Conv1d(dim_input, dim, width, padding = width // 2)
+        self.pointwise = Conv1d(dim, dim, 1)
+
+        self.pool = Reduce('b d (n pool) -> b d n', 'max', pool = 2)
 
     def forward(
         self,
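The completed DNAEmbed lines above pair a width-15, 'same'-padded Conv1d with a pointwise residual and a 2x max pool. A quick shape check of what those three modules do to an input follows; it is illustrative only and not taken from the package.

```python
import torch
from torch.nn import Conv1d
from einops.layers.torch import Reduce

dim_input, dim, width = 5, 768, 15        # basepairs, first embedding dim, odd conv width
conv = Conv1d(dim_input, dim, width, padding = width // 2)
pointwise = Conv1d(dim, dim, 1)
pool = Reduce('b d (n pool) -> b d n', 'max', pool = 2)

x = torch.randn(2, dim_input, 8192)       # stand-in for one-hot encoded DNA
out = conv(x)                             # odd width with width // 2 padding keeps the length at 8192
out = out + pointwise(out)
pooled = pool(out)                        # (2, 768, 4096): halved resolution for the first skip level

assert pooled.shape == (2, dim, 4096)
```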
@@ -568,38 +580,96 @@ class DNAEmbed(Module):
 
         out = self.conv(x)
         out = out + self.pointwise(out)
-
+        pooled = self.pool(out) # think they downsample for dna embed block
+
+        return pooled, x
 
 # classes
 
 class AlphaGenome(Module):
     def __init__(
         self,
-
+        dims: tuple[int, ...] = (
+            768,
+            896,
+            1024,
+            1152,
+            1280,
+            1408,
+            1536
+        ),
         basepairs = 5,
         dna_embed_width = 15,
-        dim_pairwise = None,
         transformer_kwargs: dict = dict()
     ):
         super().__init__()
+
         assert is_odd(dna_embed_width)
 
-
+        assert len(dims) >= 2
+        first_dim, *_, last_dim = dims
+
+        self.dna_embed = DNAEmbed(first_dim, dim_input = basepairs, width = dna_embed_width)
+
+        dim_with_input = (basepairs, *dims)
+        dim_pairs = zip(dim_with_input[:-1], dim_with_input[1:])
+
+        downs = []
+        ups = []
+
+        for layer_num, (dim_in, dim_out) in enumerate(dim_pairs, start = 1):
+            is_first = layer_num == 1
+            channel_diff = dim_out - dim_in
+
+            assert channel_diff > 0
+
+            if not is_first:
+                down = DownresBlock(dim_in, channels_to_add = channel_diff)
+                downs.append(down)
 
-
-
-
+            up = UpresBlock(dim_out, channels_to_remove = channel_diff)
+            ups.insert(0, up)
+
+        self.downs = ModuleList(downs)
+        self.ups = ModuleList(ups)
+
+        self.transformer = TransformerTower(
+            dim = last_dim,
             **transformer_kwargs
         )
 
     def forward(
         self,
-        seq
-        pairwise
+        seq # Int['b n']
     ):
 
-
+        skips = []
+
+        # embed with one hot and add skip
+
+        x, skip = self.dna_embed(seq)
+        skips.append(skip)
+
+        # downs
+
+        for down in self.downs:
+            skips.append(x)
+            x = down(x)
+
+        x = rearrange(x, 'b d n -> b n d')
+
+        # attention
+
+        single, pairwise = self.transformer(x)
+
+        # ups with skips from down
+
+        x = rearrange(x, 'b n d -> b d n')
+
+        for up in self.ups:
+            x = up(x, skip = skips.pop())
 
-
+        pred = rearrange(x, 'b l n -> b n l') # 1bp resolution
 
-        return
+        return pred, single, pairwise
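The new AlphaGenome constructor walks adjacent pairs of (basepairs, *dims) to size the paired down and up blocks. Below is a small illustrative calculation of that bookkeeping. It assumes that DNAEmbed and each DownresBlock halve the sequence length, which is an inference from the README shapes in this release ((2, 8192) input giving a (2, 64, 1536) single representation), not something shown directly in the hunks.

```python
dims = (768, 896, 1024, 1152, 1280, 1408, 1536)   # defaults from the hunk above
basepairs, seq_len = 5, 8192

dim_with_input = (basepairs, *dims)
stages = list(zip(dim_with_input[:-1], dim_with_input[1:]))

length = seq_len
for layer_num, (dim_in, dim_out) in enumerate(stages, start = 1):
    channel_diff = dim_out - dim_in    # channel growth; DownresBlock (after the first stage) and the mirrored UpresBlock are sized with it
    length //= 2                       # assumed 2x pooling per stage (DNAEmbed first, then the DownresBlocks)
    print(f'stage {layer_num}: {dim_in} -> {dim_out} channels (+{channel_diff}), length {length}')

# 8192 / 2**7 = 64 positions at width 1536 reach the transformer, matching the
# (2, 64, 1536) single representation in the README; the up path then pops the
# saved skips in reverse to restore 1bp resolution for the (2, 8192, 5) logits.
```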
{alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/tests/test_alphagenome.py
CHANGED

@@ -21,3 +21,16 @@ def test_down_up():
 
     x = torch.randn(1, 64, 8)
     assert up(down(x), x).shape == x.shape
+
+def test_alphagenome():
+    from alphagenome_pytorch import AlphaGenome
+
+    model = AlphaGenome()
+
+    dna = torch.randint(0, 5, (2, 8192))
+
+    pred_nucleotide_logits, single, pairwise = model(dna)
+
+    pred = pred_nucleotide_logits.argmax(dim = -1)
+
+    assert pred.shape == dna.shape
{alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/.github/workflows/python-publish.yml
RENAMED
File without changes

{alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/.github/workflows/test.yml
RENAMED
File without changes

{alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/.gitignore
RENAMED
File without changes

{alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/LICENSE
RENAMED
File without changes

{alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.11}/extended-figure-1.png
RENAMED
File without changes