alphagenome-pytorch 0.0.8__tar.gz → 0.0.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {alphagenome_pytorch-0.0.8 → alphagenome_pytorch-0.0.10}/PKG-INFO +5 -5
- {alphagenome_pytorch-0.0.8 → alphagenome_pytorch-0.0.10}/README.md +4 -4
- alphagenome_pytorch-0.0.10/alphagenome_pytorch/__init__.py +11 -0
- {alphagenome_pytorch-0.0.8 → alphagenome_pytorch-0.0.10}/alphagenome_pytorch/alphagenome.py +180 -17
- {alphagenome_pytorch-0.0.8 → alphagenome_pytorch-0.0.10}/pyproject.toml +1 -1
- alphagenome_pytorch-0.0.10/tests/test_alphagenome.py +36 -0
- alphagenome_pytorch-0.0.8/alphagenome_pytorch/__init__.py +0 -5
- alphagenome_pytorch-0.0.8/tests/test_alphagenome.py +0 -14
- {alphagenome_pytorch-0.0.8 → alphagenome_pytorch-0.0.10}/.github/workflows/python-publish.yml +0 -0
- {alphagenome_pytorch-0.0.8 → alphagenome_pytorch-0.0.10}/.github/workflows/test.yml +0 -0
- {alphagenome_pytorch-0.0.8 → alphagenome_pytorch-0.0.10}/.gitignore +0 -0
- {alphagenome_pytorch-0.0.8 → alphagenome_pytorch-0.0.10}/LICENSE +0 -0
- {alphagenome_pytorch-0.0.8 → alphagenome_pytorch-0.0.10}/extended-figure-1.png +0 -0
{alphagenome_pytorch-0.0.8 → alphagenome_pytorch-0.0.10}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alphagenome-pytorch
-Version: 0.0.8
+Version: 0.0.10
 Summary: AlphaGenome
 Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/alphagenome
@@ -58,13 +58,13 @@ $ pip install alphagenome-pytorch
 
 ```python
 import torch
-from alphagenome_pytorch import
+from alphagenome_pytorch import AlphaGenome
 
-
+model = AlphaGenome()
 
-
+dna = torch.randint(0, 5, (2, 8192))
 
-
+pred_nucleotide, single, pairwise = model(dna) # (2, 8192, 5), (2, 64, 1536), (2, 4, 4, 1536)
 ```
 
 ## Citations
{alphagenome_pytorch-0.0.8 → alphagenome_pytorch-0.0.10}/README.md

@@ -14,13 +14,13 @@ $ pip install alphagenome-pytorch
 
 ```python
 import torch
-from alphagenome_pytorch import
+from alphagenome_pytorch import AlphaGenome
 
-
+model = AlphaGenome()
 
-
+dna = torch.randint(0, 5, (2, 8192))
 
-
+pred_nucleotide, single, pairwise = model(dna) # (2, 8192, 5), (2, 64, 1536), (2, 4, 4, 1536)
 ```
 
 ## Citations
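Judging from the `(2, 8192, 5)` shape and the new `tests/test_alphagenome.py` further down, `pred_nucleotide` holds per-position logits over the five input symbols and can be collapsed back to base indices with an argmax. A minimal sketch of that round trip, mirroring the new test:

```python
import torch
from alphagenome_pytorch import AlphaGenome

model = AlphaGenome()

dna = torch.randint(0, 5, (2, 8192))           # integer-encoded input sequence

pred_nucleotide, single, pairwise = model(dna)

# collapse the 5-way logits to a predicted base index per position,
# as done in tests/test_alphagenome.py
pred_bases = pred_nucleotide.argmax(dim = -1)  # (2, 8192), same shape as the input
```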
{alphagenome_pytorch-0.0.8 → alphagenome_pytorch-0.0.10}/alphagenome_pytorch/alphagenome.py

@@ -4,7 +4,9 @@ from functools import partial
 import torch
 from torch import nn, cat, stack, arange, logspace
 import torch.nn.functional as F
-from torch.nn import Linear, Sequential, Module, ModuleList
+from torch.nn import Conv1d, Linear, Sequential, Module, ModuleList
+
+from torch.nn.utils.parametrize import register_parametrization
 
 from einx import add, multiply, greater
 from einops.layers.torch import Rearrange, Reduce
@@ -30,6 +32,9 @@ def exists(v):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def last(arr):
+    return arr[-1]
+
 def is_odd(num):
     return not divisible_by(num, 2)
 
@@ -42,6 +47,104 @@ def default(v, d):
 def softclamp(t, value = 5.):
     return (t / value).tanh() * value
 
+# convolutional unet related
+
+class WeightStandardConv(Conv1d):
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        width,
+        *args,
+        **kwargs
+    ):
+        super().__init__(dim, dim_out, width, *args, **kwargs)
+
+        register_parametrization(self, 'weight', nn.LayerNorm(self.weight.shape, elementwise_affine = False))
+
+class ConvBlock(Module):
+    def __init__(
+        self,
+        dim,
+        width = 5,
+        dim_out = None
+    ):
+        super().__init__()
+        assert is_odd(width)
+        dim_out = default(dim_out, dim)
+
+        conv_klass = Conv1d if width == 1 else WeightStandardConv
+
+        self.conv = conv_klass(dim, dim_out, width, padding = width // 2)
+
+    def forward(self, x):
+
+        x = F.gelu(x)
+        out = self.conv(x)
+        return out
+
+class DownresBlock(Module):
+    def __init__(
+        self,
+        dim,
+        channels_to_add = 128 # this is new as well? instead of doubling channels, they add 128 at a time, and use padding or slicing for the residual
+    ):
+        super().__init__()
+
+        dim_out = dim + channels_to_add
+        self.pad = channels_to_add
+
+        self.conv = ConvBlock(dim, width = 1, dim_out = dim_out)
+        self.conv_out = ConvBlock(dim_out, width = 1)
+
+        self.max_pool = Reduce('b d (n pool) -> b d n', 'max', pool = 2)
+
+    def forward(self, x):
+
+        residual = F.pad(x, (0, 0, 0, self.pad), value = 0.)
+
+        out = self.conv(x) + residual
+
+        out = self.conv_out(out) + out
+
+        return self.max_pool(out)
+
+class UpresBlock(Module):
+    def __init__(
+        self,
+        dim,
+        channels_to_remove = 128,
+        residual_scale_init = .9
+    ):
+        super().__init__()
+
+        dim_out = dim - channels_to_remove
+        self.pad = channels_to_remove
+
+        self.conv = ConvBlock(dim, width = 1, dim_out = dim_out)
+        self.unet_conv = ConvBlock(dim_out, width = 1)
+
+        self.conv_out = ConvBlock(dim_out, width = 1)
+
+        self.residual_scale = nn.Parameter(torch.ones(1,) * residual_scale_init)
+
+    def forward(
+        self,
+        x,
+        skip = None
+    ):
+
+        residual = x[:, :-self.pad]
+        out = self.conv(x) + residual
+
+        if exists(skip):
+            out = repeat(out, 'b c n -> b c (n upsample)', upsample = 2) * self.residual_scale
+            out = out + self.unet_conv(skip)
+
+        return self.conv_out(out) + out
+
+# position related
+
 def relative_shift(t):
     *leading_dims, seq_len, dim = t.shape
     t = F.pad(t, (1, 0), value = 0.)
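The down/up pair is channel- and length-symmetric: `DownresBlock` pads the residual with `channels_to_add` extra channels and halves the length, while `UpresBlock` slices those channels back off and doubles the length against the skip. A minimal shape check in the spirit of the new `test_down_up` (the tensor sizes here are arbitrary toy values):

```python
import torch
from alphagenome_pytorch.alphagenome import DownresBlock, UpresBlock

down = DownresBlock(64)       # 64 -> 64 + 128 channels, length halved
up = UpresBlock(64 + 128)     # 192 -> 64 channels, length doubled against the skip

x = torch.randn(1, 64, 8)     # (batch, channels, length)

out = up(down(x), skip = x)   # the skip restores the original resolution

assert out.shape == x.shape   # (1, 64, 8)
```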
@@ -284,7 +387,7 @@ class SingleToPairwise(Module):
         rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
         rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
 
-        rel_sim = add('b h i j, b h j i -> b
+        rel_sim = add('b h i j, b h j i -> b i j h', rel_q, rel_k) * 0.5
 
         sim = sim + rel_sim
 
@@ -360,7 +463,7 @@ class TransformerTower(Module):
         dropout = 0.,
         ff_expansion_factor = 2.,
         max_positions = 8192,
-        dim_pairwise =
+        dim_pairwise = 128,
         pairwise_every_num_single_blocks = 2, # how often to do a pairwise block
         single_to_pairwise_heads = 32, # they did 32
         pool_size = 16,
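With the new default of `dim_pairwise = 128` and the `pool_size = 16` shown above, the pairwise representation is produced at 1/16 of the single-representation resolution. The existing `test_attention` exercises exactly this; a sketch of the same call:

```python
import torch
from alphagenome_pytorch.alphagenome import TransformerTower

transformer = TransformerTower(dim = 768, dim_pairwise = 128)

single = torch.randn(2, 512, 768)                 # (batch, length, dim)

single_repr, pairwise_repr = transformer(single)

# pairwise positions are pooled by pool_size = 16
assert single_repr.shape == (2, 512, 768)
assert pairwise_repr.shape == (2, 512 // 16, 512 // 16, 128)
```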
@@ -456,8 +559,10 @@ class DNAEmbed(Module):
         super().__init__()
         assert is_odd(width)
         self.dim_input = dim_input
-        self.conv =
-        self.pointwise =
+        self.conv = Conv1d(dim_input, dim, width, padding = width // 2)
+        self.pointwise = Conv1d(dim, dim, 1)
+
+        self.pool = Reduce('b d (n pool) -> b d n', 'max', pool = 2)
 
     def forward(
         self,
@@ -468,38 +573,96 @@ class DNAEmbed(Module):
 
         out = self.conv(x)
         out = out + self.pointwise(out)
-
+        pooled = self.pool(out) # think they downsample for dna embed block
+
+        return pooled, x
 
 # classes
 
 class AlphaGenome(Module):
     def __init__(
         self,
-
+        dims: tuple[int, ...] = (
+            768,
+            896,
+            1024,
+            1152,
+            1280,
+            1408,
+            1536
+        ),
         basepairs = 5,
         dna_embed_width = 15,
-        dim_pairwise = None,
         transformer_kwargs: dict = dict()
     ):
         super().__init__()
+
         assert is_odd(dna_embed_width)
 
-
+        assert len(dims) >= 2
+        first_dim, *_, last_dim = dims
+
+        self.dna_embed = DNAEmbed(first_dim, dim_input = basepairs, width = dna_embed_width)
 
-
-
-
+        dim_with_input = (basepairs, *dims)
+        dim_pairs = zip(dim_with_input[:-1], dim_with_input[1:])
+
+        downs = []
+        ups = []
+
+        for layer_num, (dim_in, dim_out) in enumerate(dim_pairs, start = 1):
+            is_first = layer_num == 1
+            channel_diff = dim_out - dim_in
+
+            assert channel_diff > 0
+
+            if not is_first:
+                down = DownresBlock(dim_in, channels_to_add = channel_diff)
+                downs.append(down)
+
+            up = UpresBlock(dim_out, channels_to_remove = channel_diff)
+            ups.insert(0, up)
+
+
+        self.downs = ModuleList(downs)
+        self.ups = ModuleList(ups)
+
+        self.transformer = TransformerTower(
+            dim = last_dim,
             **transformer_kwargs
         )
 
     def forward(
         self,
-        seq
-        pairwise
+        seq # Int['b n']
     ):
 
-
+        skips = []
+
+        # embed with one hot and add skip
+
+        x, skip = self.dna_embed(seq)
+        skips.append(skip)
+
+        # downs
+
+        for down in self.downs:
+            skips.append(x)
+            x = down(x)
+
+        x = rearrange(x, 'b d n -> b n d')
+
+        # attention
+
+        single, pairwise = self.transformer(x)
+
+        # ups with skips from down
+
+        x = rearrange(x, 'b n d -> b d n')
+
+        for up in self.ups:
+            x = up(x, skip = skips.pop())
 
-
+        pred = rearrange(x, 'b l n -> b n l') # 1bp resolution
 
-        return
+        return pred, single, pairwise
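With the default seven-entry `dims`, the sequence is pooled by 2 once in `DNAEmbed` and once in each of the six `DownresBlock`s, for a total downsampling of 2^7 = 128 before the transformer; the 8192 bp example in the README therefore reaches the transformer at length 64. A quick sanity check of that arithmetic (shapes inferred from the code above, not asserted by the package itself):

```python
# downsampling bookkeeping for the default AlphaGenome configuration
dims = (768, 896, 1024, 1152, 1280, 1408, 1536)

# DNAEmbed pools once, plus one DownresBlock per dim transition after the first
num_pool_stages = 1 + (len(dims) - 1)        # 7
total_downsample = 2 ** num_pool_stages      # 128

seq_len = 8192
assert seq_len // total_downsample == 64     # matches the (2, 64, 1536) single representation in the README
```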
alphagenome_pytorch-0.0.10/tests/test_alphagenome.py

@@ -0,0 +1,36 @@
+import pytest
+
+import torch
+from alphagenome_pytorch.alphagenome import TransformerTower
+
+def test_attention():
+
+    transformer = TransformerTower(dim = 768, dim_pairwise = 128)
+
+    single = torch.randn(2, 512, 768)
+
+    single_repr, pairwise_repr = transformer(single)
+
+    assert single_repr.shape == (2, 512, 768)
+    assert pairwise_repr.shape == (2, 512 // 16, 512 // 16, 128)
+
+def test_down_up():
+    from alphagenome_pytorch.alphagenome import DownresBlock, UpresBlock
+    down = DownresBlock(64)
+    up = UpresBlock(64 + 128)
+
+    x = torch.randn(1, 64, 8)
+    assert up(down(x), x).shape == x.shape
+
+def test_alphagenome():
+    from alphagenome_pytorch import AlphaGenome
+
+    model = AlphaGenome()
+
+    dna = torch.randint(0, 5, (2, 8192))
+
+    pred_nucleotide_logits, single, pairwise = model(dna)
+
+    pred = pred_nucleotide_logits.argmax(dim = -1)
+
+    assert pred.shape == dna.shape
alphagenome_pytorch-0.0.8/tests/test_alphagenome.py

@@ -1,14 +0,0 @@
-import pytest
-import torch
-from alphagenome_pytorch.alphagenome import TransformerTower
-
-def test_attention():
-
-    transformer = TransformerTower(dim = 768, dim_pairwise = 128)
-
-    single = torch.randn(2, 512, 768)
-
-    single_repr, pairwise_repr = transformer(single)
-
-    assert single_repr.shape == (2, 512, 768)
-    assert pairwise_repr.shape == (2, 512 // 16, 512 // 16, 128)
Renamed from alphagenome_pytorch-0.0.8 to alphagenome_pytorch-0.0.10 without content changes:

- .github/workflows/python-publish.yml
- .github/workflows/test.yml
- .gitignore
- LICENSE
- extended-figure-1.png