alphagenome-pytorch 0.0.7-py3-none-any.whl → 0.0.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alphagenome_pytorch/alphagenome.py +104 -4
- {alphagenome_pytorch-0.0.7.dist-info → alphagenome_pytorch-0.0.9.dist-info}/METADATA +1 -1
- alphagenome_pytorch-0.0.9.dist-info/RECORD +6 -0
- alphagenome_pytorch-0.0.7.dist-info/RECORD +0 -6
- {alphagenome_pytorch-0.0.7.dist-info → alphagenome_pytorch-0.0.9.dist-info}/WHEEL +0 -0
- {alphagenome_pytorch-0.0.7.dist-info → alphagenome_pytorch-0.0.9.dist-info}/licenses/LICENSE +0 -0
alphagenome_pytorch/alphagenome.py
CHANGED

@@ -4,7 +4,9 @@ from functools import partial
 import torch
 from torch import nn, cat, stack, arange, logspace
 import torch.nn.functional as F
-from torch.nn import Linear, Sequential, Module, ModuleList
+from torch.nn import Conv1d, Linear, Sequential, Module, ModuleList
+
+from torch.nn.utils.parametrize import register_parametrization
 
 from einx import add, multiply, greater
 from einops.layers.torch import Rearrange, Reduce
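The new `register_parametrization` import is the mechanism behind the weight-standardized convolution added further down: registering a parametrization on `weight` routes every read of that attribute through the given module, here a `LayerNorm` over the full weight shape with no learnable affine parameters. A minimal sketch of the idea (illustrative, not taken from the package; `lin` is a placeholder module):

    import torch
    from torch import nn
    from torch.nn.utils.parametrize import register_parametrization

    lin = nn.Linear(4, 4)

    # normalize over the entire weight shape, no learnable affine params,
    # so every access to `lin.weight` returns a standardized tensor
    register_parametrization(lin, 'weight', nn.LayerNorm(lin.weight.shape, elementwise_affine = False))

    w = lin.weight
    print(w.mean().item(), w.std().item())  # approximately 0 and 1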
@@ -42,6 +44,104 @@ def default(v, d):
 def softclamp(t, value = 5.):
     return (t / value).tanh() * value
 
+# convolutional unet related
+
+class WeightStandardConv(Conv1d):
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        width,
+        *args,
+        **kwargs
+    ):
+        super().__init__(dim, dim_out, width, *args, **kwargs)
+
+        register_parametrization(self, 'weight', nn.LayerNorm(self.weight.shape, elementwise_affine = False))
+
+class ConvBlock(Module):
+    def __init__(
+        self,
+        dim,
+        width = 5,
+        dim_out = None
+    ):
+        super().__init__()
+        assert is_odd(width)
+        dim_out = default(dim_out, dim)
+
+        conv_klass = Conv1d if width == 1 else WeightStandardConv
+
+        self.conv = conv_klass(dim, dim_out, width, padding = width // 2)
+
+    def forward(self, x):
+
+        x = F.gelu(x)
+        out = self.conv(x)
+        return out
+
+class DownresBlock(Module):
+    def __init__(
+        self,
+        dim,
+        channels_to_add = 128 # this is new as well? instead of doubling channels, they add 128 at a time, and use padding or slicing for the residual
+    ):
+        super().__init__()
+
+        dim_out = dim + channels_to_add
+        self.pad = channels_to_add
+
+        self.conv = ConvBlock(dim, width = 1, dim_out = dim_out)
+        self.conv_out = ConvBlock(dim_out, width = 1)
+
+        self.max_pool = Reduce('b d (n pool) -> b d n', 'max', pool = 2)
+
+    def forward(self, x):
+
+        residual = F.pad(x, (0, 0, 0, self.pad), value = 0.)
+
+        out = self.conv(x) + residual
+
+        out = self.conv_out(out) + out
+
+        return self.max_pool(out)
+
+class UpresBlock(Module):
+    def __init__(
+        self,
+        dim,
+        channels_to_remove = 128,
+        residual_scale_init = .9
+    ):
+        super().__init__()
+
+        dim_out = dim - channels_to_remove
+        self.pad = channels_to_remove
+
+        self.conv = ConvBlock(dim, width = 1, dim_out = dim_out)
+        self.unet_conv = ConvBlock(dim_out, width = 1)
+
+        self.conv_out = ConvBlock(dim_out, width = 1)
+
+        self.residual_scale = nn.Parameter(torch.ones(1,) * residual_scale_init)
+
+    def forward(
+        self,
+        x,
+        unet_skip = None
+    ):
+
+        residual = x[:, :-self.pad]
+        out = self.conv(x) + residual
+
+        if exists(unet_skip):
+            out = repeat(out, 'b c n -> b c (n upsample)', upsample = 2) * self.residual_scale
+            out = out + self.unet_conv(unet_skip)
+
+        return self.conv_out(out) + out
+
+# position related
+
 def relative_shift(t):
     *leading_dims, seq_len, dim = t.shape
     t = F.pad(t, (1, 0), value = 0.)
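Taken together, the new blocks form the contracting and expanding halves of a 1D convolutional U-net: DownresBlock grows channels by a fixed amount (rather than doubling), zero-padding the residual to match, and halves the sequence length with a max pool; UpresBlock sheds the same number of channels, slicing the residual, and doubles the length back by nearest-neighbor repetition when given a unet skip. A hedged usage sketch, assuming these classes are importable from alphagenome_pytorch.alphagenome as released and that `repeat` from einops is in scope inside the module:

    import torch
    from alphagenome_pytorch.alphagenome import DownresBlock, UpresBlock

    down = DownresBlock(256)   # 256 -> 256 + 128 = 384 channels, length halved by max pool
    up   = UpresBlock(384)     # 384 -> 384 - 128 = 256 channels, length doubled on upsample

    x = torch.randn(1, 256, 64)          # (batch, channels, length)

    hidden = down(x)                     # (1, 384, 32)

    # the skip must match the post-upsample resolution and reduced channel count,
    # so the pre-downsample activations serve as the unet skip here
    out = up(hidden, unet_skip = x)      # (1, 256, 64)

    assert out.shape == x.shape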
@@ -434,14 +534,14 @@ class TransformerTower(Module):
             maybe_pairwise_ff
         ) in self.layers:
 
-            single = attn(single, rotary_emb = rotary_emb, pairwise = None) + single
-            single = ff(single) + single
-
             if exists(maybe_single_to_pair):
                 pairwise = maybe_single_to_pair(single, rel_pos_feats) + default(pairwise, 0.)
                 pairwise = maybe_pairwise_attn(pairwise) + pairwise
                 pairwise = maybe_pairwise_ff(pairwise) + pairwise
 
+            single = attn(single, rotary_emb = rotary_emb, pairwise = None) + single
+            single = ff(single) + single
+
         return single, pairwise
 
 # embedding
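The net change in this hunk is an ordering swap inside each transformer layer: the pairwise track (single-to-pair projection, pairwise attention, pairwise feedforward) is now updated from the incoming single representation before the single-track attention and feedforward run, rather than after. The residual structure of both tracks is unchanged.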
alphagenome_pytorch-0.0.9.dist-info/RECORD
ADDED

@@ -0,0 +1,6 @@
+alphagenome_pytorch/__init__.py,sha256=7plC_YRm0UapNCl9hJEhFxXE-ELGKVy-DtuO5GUQxGI,101
+alphagenome_pytorch/alphagenome.py,sha256=sRQXd-wvi0iSgGjuzfek7jpAJqJkiCSEtt0tFpAbTGo,16462
+alphagenome_pytorch-0.0.9.dist-info/METADATA,sha256=pqLrVpzTOuFuu6NjvwS6PujEd2BLrUJEB97nxbYTGdc,3386
+alphagenome_pytorch-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+alphagenome_pytorch-0.0.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+alphagenome_pytorch-0.0.9.dist-info/RECORD,,
alphagenome_pytorch-0.0.7.dist-info/RECORD
DELETED

@@ -1,6 +0,0 @@
-alphagenome_pytorch/__init__.py,sha256=7plC_YRm0UapNCl9hJEhFxXE-ELGKVy-DtuO5GUQxGI,101
-alphagenome_pytorch/alphagenome.py,sha256=6rDCYbXm3tsS6i4HpGjzVZG4ImFpAHdzdfQhjL70vPU,13948
-alphagenome_pytorch-0.0.7.dist-info/METADATA,sha256=wUhwmvfijnaqPTaY6q82_vopS6U-sVskX1sANF4ls80,3386
-alphagenome_pytorch-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-alphagenome_pytorch-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-alphagenome_pytorch-0.0.7.dist-info/RECORD,,
{alphagenome_pytorch-0.0.7.dist-info → alphagenome_pytorch-0.0.9.dist-info}/WHEEL
RENAMED
File without changes

{alphagenome_pytorch-0.0.7.dist-info → alphagenome_pytorch-0.0.9.dist-info}/licenses/LICENSE
RENAMED
File without changes