alphagenome-pytorch 0.0.9__tar.gz → 0.0.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.10}/PKG-INFO +5 -5
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.10}/README.md +4 -4
- alphagenome_pytorch-0.0.10/alphagenome_pytorch/__init__.py +11 -0
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.10}/alphagenome_pytorch/alphagenome.py +82 -19
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.10}/pyproject.toml +1 -1
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.10}/tests/test_alphagenome.py +13 -0
- alphagenome_pytorch-0.0.9/alphagenome_pytorch/__init__.py +0 -5
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.10}/.github/workflows/python-publish.yml +0 -0
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.10}/.github/workflows/test.yml +0 -0
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.10}/.gitignore +0 -0
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.10}/LICENSE +0 -0
- {alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.10}/extended-figure-1.png +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: alphagenome-pytorch
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.10
|
4
4
|
Summary: AlphaGenome
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/alphagenome
|
@@ -58,13 +58,13 @@ $ pip install alphagenome-pytorch
|
|
58
58
|
|
59
59
|
```python
|
60
60
|
import torch
|
61
|
-
from alphagenome_pytorch import
|
61
|
+
from alphagenome_pytorch import AlphaGenome
|
62
62
|
|
63
|
-
|
63
|
+
model = AlphaGenome()
|
64
64
|
|
65
|
-
|
65
|
+
dna = torch.randint(0, 5, (2, 8192))
|
66
66
|
|
67
|
-
|
67
|
+
pred_nucleotide, single, pairwise = model(dna) # (2, 8192, 5), (2, 64, 1536), (2, 4, 4, 1536)
|
68
68
|
```
|
69
69
|
|
70
70
|
## Citations
|
@@ -14,13 +14,13 @@ $ pip install alphagenome-pytorch
|
|
14
14
|
|
15
15
|
```python
|
16
16
|
import torch
|
17
|
-
from alphagenome_pytorch import
|
17
|
+
from alphagenome_pytorch import AlphaGenome
|
18
18
|
|
19
|
-
|
19
|
+
model = AlphaGenome()
|
20
20
|
|
21
|
-
|
21
|
+
dna = torch.randint(0, 5, (2, 8192))
|
22
22
|
|
23
|
-
|
23
|
+
pred_nucleotide, single, pairwise = model(dna) # (2, 8192, 5), (2, 64, 1536), (2, 4, 4, 1536)
|
24
24
|
```
|
25
25
|
|
26
26
|
## Citations
|
@@ -32,6 +32,9 @@ def exists(v):
|
|
32
32
|
def divisible_by(num, den):
|
33
33
|
return (num % den) == 0
|
34
34
|
|
35
|
+
def last(arr):
|
36
|
+
return arr[-1]
|
37
|
+
|
35
38
|
def is_odd(num):
|
36
39
|
return not divisible_by(num, 2)
|
37
40
|
|
@@ -128,15 +131,15 @@ class UpresBlock(Module):
|
|
128
131
|
def forward(
|
129
132
|
self,
|
130
133
|
x,
|
131
|
-
|
134
|
+
skip = None
|
132
135
|
):
|
133
136
|
|
134
137
|
residual = x[:, :-self.pad]
|
135
138
|
out = self.conv(x) + residual
|
136
139
|
|
137
|
-
if exists(
|
140
|
+
if exists(skip):
|
138
141
|
out = repeat(out, 'b c n -> b c (n upsample)', upsample = 2) * self.residual_scale
|
139
|
-
out = out + self.unet_conv(
|
142
|
+
out = out + self.unet_conv(skip)
|
140
143
|
|
141
144
|
return self.conv_out(out) + out
|
142
145
|
|
@@ -384,7 +387,7 @@ class SingleToPairwise(Module):
|
|
384
387
|
rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
|
385
388
|
rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
|
386
389
|
|
387
|
-
rel_sim = add('b h i j, b h j i -> b
|
390
|
+
rel_sim = add('b h i j, b h j i -> b i j h', rel_q, rel_k) * 0.5
|
388
391
|
|
389
392
|
sim = sim + rel_sim
|
390
393
|
|
@@ -460,7 +463,7 @@ class TransformerTower(Module):
|
|
460
463
|
dropout = 0.,
|
461
464
|
ff_expansion_factor = 2.,
|
462
465
|
max_positions = 8192,
|
463
|
-
dim_pairwise =
|
466
|
+
dim_pairwise = 128,
|
464
467
|
pairwise_every_num_single_blocks = 2, # how often to do a pairwise block
|
465
468
|
single_to_pairwise_heads = 32, # they did 32
|
466
469
|
pool_size = 16,
|
@@ -556,8 +559,10 @@ class DNAEmbed(Module):
|
|
556
559
|
super().__init__()
|
557
560
|
assert is_odd(width)
|
558
561
|
self.dim_input = dim_input
|
559
|
-
self.conv =
|
560
|
-
self.pointwise =
|
562
|
+
self.conv = Conv1d(dim_input, dim, width, padding = width // 2)
|
563
|
+
self.pointwise = Conv1d(dim, dim, 1)
|
564
|
+
|
565
|
+
self.pool = Reduce('b d (n pool) -> b d n', 'max', pool = 2)
|
561
566
|
|
562
567
|
def forward(
|
563
568
|
self,
|
@@ -568,38 +573,96 @@ class DNAEmbed(Module):
|
|
568
573
|
|
569
574
|
out = self.conv(x)
|
570
575
|
out = out + self.pointwise(out)
|
571
|
-
|
576
|
+
pooled = self.pool(out) # think they downsample for dna embed block
|
577
|
+
|
578
|
+
return pooled, x
|
572
579
|
|
573
580
|
# classes
|
574
581
|
|
575
582
|
class AlphaGenome(Module):
|
576
583
|
def __init__(
|
577
584
|
self,
|
578
|
-
|
585
|
+
dims: tuple[int, ...] = (
|
586
|
+
768,
|
587
|
+
896,
|
588
|
+
1024,
|
589
|
+
1152,
|
590
|
+
1280,
|
591
|
+
1408,
|
592
|
+
1536
|
593
|
+
),
|
579
594
|
basepairs = 5,
|
580
595
|
dna_embed_width = 15,
|
581
|
-
dim_pairwise = None,
|
582
596
|
transformer_kwargs: dict = dict()
|
583
597
|
):
|
584
598
|
super().__init__()
|
599
|
+
|
585
600
|
assert is_odd(dna_embed_width)
|
586
601
|
|
587
|
-
|
602
|
+
assert len(dims) >= 2
|
603
|
+
first_dim, *_, last_dim = dims
|
604
|
+
|
605
|
+
self.dna_embed = DNAEmbed(first_dim, dim_input = basepairs, width = dna_embed_width)
|
606
|
+
|
607
|
+
dim_with_input = (basepairs, *dims)
|
608
|
+
dim_pairs = zip(dim_with_input[:-1], dim_with_input[1:])
|
609
|
+
|
610
|
+
downs = []
|
611
|
+
ups = []
|
612
|
+
|
613
|
+
for layer_num, (dim_in, dim_out) in enumerate(dim_pairs, start = 1):
|
614
|
+
is_first = layer_num == 1
|
615
|
+
channel_diff = dim_out - dim_in
|
616
|
+
|
617
|
+
assert channel_diff > 0
|
618
|
+
|
619
|
+
if not is_first:
|
620
|
+
down = DownresBlock(dim_in, channels_to_add = channel_diff)
|
621
|
+
downs.append(down)
|
622
|
+
|
623
|
+
up = UpresBlock(dim_out, channels_to_remove = channel_diff)
|
624
|
+
ups.insert(0, up)
|
588
625
|
|
589
|
-
|
590
|
-
|
591
|
-
|
626
|
+
|
627
|
+
self.downs = ModuleList(downs)
|
628
|
+
self.ups = ModuleList(ups)
|
629
|
+
|
630
|
+
self.transformer = TransformerTower(
|
631
|
+
dim = last_dim,
|
592
632
|
**transformer_kwargs
|
593
633
|
)
|
594
634
|
|
595
635
|
def forward(
|
596
636
|
self,
|
597
|
-
seq
|
598
|
-
pairwise
|
637
|
+
seq # Int['b n']
|
599
638
|
):
|
600
639
|
|
601
|
-
|
640
|
+
skips = []
|
641
|
+
|
642
|
+
# embed with one hot and add skip
|
643
|
+
|
644
|
+
x, skip = self.dna_embed(seq)
|
645
|
+
skips.append(skip)
|
646
|
+
|
647
|
+
# downs
|
648
|
+
|
649
|
+
for down in self.downs:
|
650
|
+
skips.append(x)
|
651
|
+
x = down(x)
|
652
|
+
|
653
|
+
x = rearrange(x, 'b d n -> b n d')
|
654
|
+
|
655
|
+
# attention
|
656
|
+
|
657
|
+
single, pairwise = self.transformer(x)
|
658
|
+
|
659
|
+
# ups with skips from down
|
660
|
+
|
661
|
+
x = rearrange(x, 'b n d -> b d n')
|
662
|
+
|
663
|
+
for up in self.ups:
|
664
|
+
x = up(x, skip = skips.pop())
|
602
665
|
|
603
|
-
|
666
|
+
pred = rearrange(x, 'b l n -> b n l') # 1bp resolution
|
604
667
|
|
605
|
-
return
|
668
|
+
return pred, single, pairwise
|
@@ -21,3 +21,16 @@ def test_down_up():
|
|
21
21
|
|
22
22
|
x = torch.randn(1, 64, 8)
|
23
23
|
assert up(down(x), x).shape == x.shape
|
24
|
+
|
25
|
+
def test_alphagenome():
|
26
|
+
from alphagenome_pytorch import AlphaGenome
|
27
|
+
|
28
|
+
model = AlphaGenome()
|
29
|
+
|
30
|
+
dna = torch.randint(0, 5, (2, 8192))
|
31
|
+
|
32
|
+
pred_nucleotide_logits, single, pairwise = model(dna)
|
33
|
+
|
34
|
+
pred = pred_nucleotide_logits.argmax(dim = -1)
|
35
|
+
|
36
|
+
assert pred.shape == dna.shape
|
{alphagenome_pytorch-0.0.9 → alphagenome_pytorch-0.0.10}/.github/workflows/python-publish.yml
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|