alphagenome-pytorch 0.0.8__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alphagenome_pytorch/__init__.py +7 -1
- alphagenome_pytorch/alphagenome.py +180 -17
- {alphagenome_pytorch-0.0.8.dist-info → alphagenome_pytorch-0.0.10.dist-info}/METADATA +5 -5
- alphagenome_pytorch-0.0.10.dist-info/RECORD +6 -0
- alphagenome_pytorch-0.0.8.dist-info/RECORD +0 -6
- {alphagenome_pytorch-0.0.8.dist-info → alphagenome_pytorch-0.0.10.dist-info}/WHEEL +0 -0
- {alphagenome_pytorch-0.0.8.dist-info → alphagenome_pytorch-0.0.10.dist-info}/licenses/LICENSE +0 -0
alphagenome_pytorch/alphagenome.py
CHANGED
@@ -4,7 +4,9 @@ from functools import partial
 import torch
 from torch import nn, cat, stack, arange, logspace
 import torch.nn.functional as F
-from torch.nn import Linear, Sequential, Module, ModuleList
+from torch.nn import Conv1d, Linear, Sequential, Module, ModuleList
+
+from torch.nn.utils.parametrize import register_parametrization
 
 from einx import add, multiply, greater
 from einops.layers.torch import Rearrange, Reduce
@@ -30,6 +32,9 @@ def exists(v):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def last(arr):
+    return arr[-1]
+
 def is_odd(num):
     return not divisible_by(num, 2)
 
@@ -42,6 +47,104 @@ def default(v, d):
 def softclamp(t, value = 5.):
     return (t / value).tanh() * value
 
+# convolutional unet related
+
+class WeightStandardConv(Conv1d):
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        width,
+        *args,
+        **kwargs
+    ):
+        super().__init__(dim, dim_out, width, *args, **kwargs)
+
+        register_parametrization(self, 'weight', nn.LayerNorm(self.weight.shape, elementwise_affine = False))
+
+class ConvBlock(Module):
+    def __init__(
+        self,
+        dim,
+        width = 5,
+        dim_out = None
+    ):
+        super().__init__()
+        assert is_odd(width)
+        dim_out = default(dim_out, dim)
+
+        conv_klass = Conv1d if width == 1 else WeightStandardConv
+
+        self.conv = conv_klass(dim, dim_out, width, padding = width // 2)
+
+    def forward(self, x):
+
+        x = F.gelu(x)
+        out = self.conv(x)
+        return out
+
+class DownresBlock(Module):
+    def __init__(
+        self,
+        dim,
+        channels_to_add = 128 # this is new as well? instead of doubling channels, they add 128 at a time, and use padding or slicing for the residual
+    ):
+        super().__init__()
+
+        dim_out = dim + channels_to_add
+        self.pad = channels_to_add
+
+        self.conv = ConvBlock(dim, width = 1, dim_out = dim_out)
+        self.conv_out = ConvBlock(dim_out, width = 1)
+
+        self.max_pool = Reduce('b d (n pool) -> b d n', 'max', pool = 2)
+
+    def forward(self, x):
+
+        residual = F.pad(x, (0, 0, 0, self.pad), value = 0.)
+
+        out = self.conv(x) + residual
+
+        out = self.conv_out(out) + out
+
+        return self.max_pool(out)
+
+class UpresBlock(Module):
+    def __init__(
+        self,
+        dim,
+        channels_to_remove = 128,
+        residual_scale_init = .9
+    ):
+        super().__init__()
+
+        dim_out = dim - channels_to_remove
+        self.pad = channels_to_remove
+
+        self.conv = ConvBlock(dim, width = 1, dim_out = dim_out)
+        self.unet_conv = ConvBlock(dim_out, width = 1)
+
+        self.conv_out = ConvBlock(dim_out, width = 1)
+
+        self.residual_scale = nn.Parameter(torch.ones(1,) * residual_scale_init)
+
+    def forward(
+        self,
+        x,
+        skip = None
+    ):
+
+        residual = x[:, :-self.pad]
+        out = self.conv(x) + residual
+
+        if exists(skip):
+            out = repeat(out, 'b c n -> b c (n upsample)', upsample = 2) * self.residual_scale
+            out = out + self.unet_conv(skip)
+
+        return self.conv_out(out) + out
+
+# position related
+
 def relative_shift(t):
     *leading_dims, seq_len, dim = t.shape
     t = F.pad(t, (1, 0), value = 0.)
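Two details in the new blocks are worth noting: `WeightStandardConv` standardizes its kernel by registering a non-affine `LayerNorm` over the full weight tensor as a parametrization, and the down/up blocks grow or shrink channels by a fixed amount instead of doubling, zero-padding or slicing the residual to match. Below is a minimal plain-torch sketch of that pad/slice residual bookkeeping on `(batch, channels, length)` tensors; it is illustrative only, not the package's own code.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 768, 64)                        # (batch, channels, length)
channels_to_add = 128

# down: the residual is the input, zero-padded along the channel dim to the new width
residual = F.pad(x, (0, 0, 0, channels_to_add))    # (2, 896, 64)
down = residual                                    # stand-in for conv(x) + residual
down = F.max_pool1d(down, kernel_size = 2)         # (2, 896, 32), halve the length

# up: undo the pooling, then slice the extra channels back off for the residual
up = down.repeat_interleave(2, dim = -1)           # (2, 896, 64)
residual_up = up[:, :-channels_to_add]             # (2, 768, 64)

print(down.shape, residual_up.shape)
```

The channels zero-padded on the way down are the same ones sliced off on the way up, so the residual path stays close to an identity map at every resolution.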
@@ -284,7 +387,7 @@ class SingleToPairwise(Module):
         rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
         rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
 
-        rel_sim = add('b h i j, b h j i -> b
+        rel_sim = add('b h i j, b h j i -> b i j h', rel_q, rel_k) * 0.5
 
         sim = sim + rel_sim
 
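For reference, the repaired `rel_sim` line adds the query-side relative term at position pair (i, j) to the key-side term at (j, i), halves the sum, and moves the head axis last so it lines up with the `b i j h` pairwise tensor. A spelled-out equivalent with plain torch and einops, as a sketch assuming the same `b h i j` layout:

```python
import torch
from einops import rearrange

b, h, n = 2, 8, 16
rel_q = torch.randn(b, h, n, n)
rel_k = torch.randn(b, h, n, n)

# einx.add('b h i j, b h j i -> b i j h', rel_q, rel_k) * 0.5, written out explicitly
rel_sim = (rel_q + rel_k.transpose(-1, -2)) * 0.5   # combine query-side and key-side relative terms
rel_sim = rearrange(rel_sim, 'b h i j -> b i j h')  # heads last, matching the pairwise representation

print(rel_sim.shape)                                # torch.Size([2, 16, 16, 8])
```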
@@ -360,7 +463,7 @@ class TransformerTower(Module):
         dropout = 0.,
         ff_expansion_factor = 2.,
         max_positions = 8192,
-        dim_pairwise =
+        dim_pairwise = 128,
         pairwise_every_num_single_blocks = 2, # how often to do a pairwise block
         single_to_pairwise_heads = 32, # they did 32
         pool_size = 16,
@@ -456,8 +559,10 @@ class DNAEmbed(Module):
         super().__init__()
         assert is_odd(width)
         self.dim_input = dim_input
-        self.conv =
-        self.pointwise =
+        self.conv = Conv1d(dim_input, dim, width, padding = width // 2)
+        self.pointwise = Conv1d(dim, dim, 1)
+
+        self.pool = Reduce('b d (n pool) -> b d n', 'max', pool = 2)
 
     def forward(
         self,
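The new `self.pool` in `DNAEmbed` is an einops `Reduce` layer that max-pools adjacent pairs of positions, halving the sequence length before the tensor enters the down blocks. Roughly equivalent to the following sketch (illustrative, not the package's own forward pass):

```python
import torch
import torch.nn.functional as F
from einops.layers.torch import Reduce

pool = Reduce('b d (n pool) -> b d n', 'max', pool = 2)

x = torch.randn(2, 768, 8192)                      # (batch, dim, length)
print(pool(x).shape)                               # torch.Size([2, 768, 4096])

# same result as functional max pooling over the length dimension
assert torch.equal(pool(x), F.max_pool1d(x, kernel_size = 2))
```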
@@ -468,38 +573,96 @@ class DNAEmbed(Module):
 
         out = self.conv(x)
         out = out + self.pointwise(out)
-
+        pooled = self.pool(out) # think they downsample for dna embed block
+
+        return pooled, x
 
 # classes
 
 class AlphaGenome(Module):
     def __init__(
         self,
-
+        dims: tuple[int, ...] = (
+            768,
+            896,
+            1024,
+            1152,
+            1280,
+            1408,
+            1536
+        ),
         basepairs = 5,
         dna_embed_width = 15,
-        dim_pairwise = None,
         transformer_kwargs: dict = dict()
     ):
         super().__init__()
+
         assert is_odd(dna_embed_width)
 
-
+        assert len(dims) >= 2
+        first_dim, *_, last_dim = dims
+
+        self.dna_embed = DNAEmbed(first_dim, dim_input = basepairs, width = dna_embed_width)
 
-
-
-
+        dim_with_input = (basepairs, *dims)
+        dim_pairs = zip(dim_with_input[:-1], dim_with_input[1:])
+
+        downs = []
+        ups = []
+
+        for layer_num, (dim_in, dim_out) in enumerate(dim_pairs, start = 1):
+            is_first = layer_num == 1
+            channel_diff = dim_out - dim_in
+
+            assert channel_diff > 0
+
+            if not is_first:
+                down = DownresBlock(dim_in, channels_to_add = channel_diff)
+                downs.append(down)
+
+            up = UpresBlock(dim_out, channels_to_remove = channel_diff)
+            ups.insert(0, up)
+
+
+        self.downs = ModuleList(downs)
+        self.ups = ModuleList(ups)
+
+        self.transformer = TransformerTower(
+            dim = last_dim,
             **transformer_kwargs
         )
 
     def forward(
         self,
-        seq
-        pairwise
+        seq # Int['b n']
     ):
 
-
+        skips = []
+
+        # embed with one hot and add skip
+
+        x, skip = self.dna_embed(seq)
+        skips.append(skip)
+
+        # downs
+
+        for down in self.downs:
+            skips.append(x)
+            x = down(x)
+
+        x = rearrange(x, 'b d n -> b n d')
+
+        # attention
+
+        single, pairwise = self.transformer(x)
+
+        # ups with skips from down
+
+        x = rearrange(x, 'b n d -> b d n')
+
+        for up in self.ups:
+            x = up(x, skip = skips.pop())
 
-
+        pred = rearrange(x, 'b l n -> b n l') # 1bp resolution
 
-        return
+        return pred, single, pairwise
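The constructor above derives the whole U-net schedule from `dims`: it prepends the 5 basepair channels, pairs consecutive widths, and uses each difference as the per-stage channel delta, with the first pair handled by `DNAEmbed` and every later pair by a `DownresBlock` plus a mirrored `UpresBlock`. A small standalone sketch of that bookkeeping, using the default values shown in the diff:

```python
dims = (768, 896, 1024, 1152, 1280, 1408, 1536)
basepairs = 5

dim_with_input = (basepairs, *dims)
dim_pairs = list(zip(dim_with_input[:-1], dim_with_input[1:]))

for layer_num, (dim_in, dim_out) in enumerate(dim_pairs, start = 1):
    stage = 'DNAEmbed' if layer_num == 1 else 'DownresBlock'
    print(f'{stage}: {dim_in} -> {dim_out} channels (channel_diff = {dim_out - dim_in})')

# DNAEmbed: 5 -> 768 channels (channel_diff = 763)
# DownresBlock: 768 -> 896 channels (channel_diff = 128)
# ... six DownresBlocks in total, each adding 128 channels and halving the length
```

Since each of the seven stages (the DNA embed plus six down blocks) pools by 2, an 8192 bp input reaches the transformer at length 64 and width 1536, and the final `UpresBlock` removes 763 channels to return to a 5-channel, base-resolution prediction, consistent with the shapes in the README example below.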
{alphagenome_pytorch-0.0.8.dist-info → alphagenome_pytorch-0.0.10.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alphagenome-pytorch
-Version: 0.0.8
+Version: 0.0.10
 Summary: AlphaGenome
 Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/alphagenome
@@ -58,13 +58,13 @@ $ pip install alphagenome-pytorch
 
 ```python
 import torch
-from alphagenome_pytorch import
+from alphagenome_pytorch import AlphaGenome
 
-
+model = AlphaGenome()
 
-
+dna = torch.randint(0, 5, (2, 8192))
 
-
+pred_nucleotide, single, pairwise = model(dna) # (2, 8192, 5), (2, 64, 1536), (2, 4, 4, 1536)
 ```
 
 ## Citations
alphagenome_pytorch-0.0.10.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+alphagenome_pytorch/__init__.py,sha256=XPNDv0q_c3nkiQo-4ROb_RQsbbKNV8KXmD6X5VnErKI,225
+alphagenome_pytorch/alphagenome.py,sha256=gfctPlDfzjl7CUkWJBX-9omT1MrvQDIgsvKSR0lLCQc,17902
+alphagenome_pytorch-0.0.10.dist-info/METADATA,sha256=SWZoRmpq6HVnddFPVMt108RFL-PywE-mniDLVRYkI3o,3382
+alphagenome_pytorch-0.0.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+alphagenome_pytorch-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+alphagenome_pytorch-0.0.10.dist-info/RECORD,,
alphagenome_pytorch-0.0.8.dist-info/RECORD
REMOVED
@@ -1,6 +0,0 @@
-alphagenome_pytorch/__init__.py,sha256=7plC_YRm0UapNCl9hJEhFxXE-ELGKVy-DtuO5GUQxGI,101
-alphagenome_pytorch/alphagenome.py,sha256=UjlOJGTMshcaNqmY0r6IRCgRunu7BTZyhbk2vNi5Mis,13948
-alphagenome_pytorch-0.0.8.dist-info/METADATA,sha256=mhlnCRy7Ovq_Gt9ul_iIqws4zdoOcLjfQTBrlbcNic8,3386
-alphagenome_pytorch-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-alphagenome_pytorch-0.0.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-alphagenome_pytorch-0.0.8.dist-info/RECORD,,
{alphagenome_pytorch-0.0.8.dist-info → alphagenome_pytorch-0.0.10.dist-info}/WHEEL
RENAMED
File without changes

{alphagenome_pytorch-0.0.8.dist-info → alphagenome_pytorch-0.0.10.dist-info}/licenses/LICENSE
RENAMED
File without changes