alphagenome-pytorch 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alphagenome_pytorch/__init__.py +1 -0
- alphagenome_pytorch/alphagenome.py +54 -9
- {alphagenome_pytorch-0.0.11.dist-info → alphagenome_pytorch-0.0.12.dist-info}/METADATA +1 -1
- alphagenome_pytorch-0.0.12.dist-info/RECORD +6 -0
- alphagenome_pytorch-0.0.11.dist-info/RECORD +0 -6
- {alphagenome_pytorch-0.0.11.dist-info → alphagenome_pytorch-0.0.12.dist-info}/WHEEL +0 -0
- {alphagenome_pytorch-0.0.11.dist-info → alphagenome_pytorch-0.0.12.dist-info}/licenses/LICENSE +0 -0
alphagenome_pytorch/__init__.py
CHANGED
alphagenome_pytorch/alphagenome.py
CHANGED
@@ -4,7 +4,7 @@ from functools import partial
 import torch
 from torch import nn, cat, stack, arange, logspace
 import torch.nn.functional as F
-from torch.nn import Conv1d, Linear, Sequential, Module, ModuleList
+from torch.nn import Conv1d, Linear, Sequential, Module, ModuleList, LayerNorm, RMSNorm
 
 from torch.nn.utils.parametrize import register_parametrization
 
@@ -47,6 +47,49 @@ def default(v, d):
 def softclamp(t, value = 5.):
     return (t / value).tanh() * value
 
+def append_dims(t, ndims):
+    return t.reshape(*t.shape, *((1,) * ndims))
+
+# batch rmsnorm
+
+class BatchRMSNorm(Module):
+    def __init__(
+        self,
+        dim_feat,
+        channel_first = False,
+        momentum = 0.9,
+        eps = 1e-5,
+    ):
+        super().__init__()
+        self.scale = dim_feat ** 0.5
+
+        self.eps = eps
+        self.momentum = 1. - momentum
+        self.gamma = nn.Parameter(torch.zeros(dim_feat))
+        self.channel_first = channel_first
+
+        self.register_buffer('running_var', torch.ones((dim_feat,)))
+
+    def forward(
+        self,
+        x,
+        update_running_var = None
+    ):
+        update_running_var = default(update_running_var, self.training)
+        running_var = self.running_var
+
+        if update_running_var:
+
+            to_reduce = rearrange(x, 'b d ... -> b ... d') if self.channel_first else x
+
+            batch_var = torch.var(to_reduce, dim = tuple(range(x.ndim - 1)))
+
+            running_var.lerp_(batch_var, self.momentum)
+
+        std = running_var.clamp(min = self.eps).sqrt()
+
+        return x * self.scale * (self.gamma + 1.) / std
+
 # convolutional unet related
 
 class WeightStandardConv(Conv1d):
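For orientation, a minimal usage sketch of the BatchRMSNorm added in the hunk above: in training mode it updates `running_var` with the batch variance via `lerp_`, and it always normalizes by the square root of the running variance, so eval mode reuses the stored statistic. This is an illustrative example, not code from the package; it assumes the class is importable from `alphagenome_pytorch.alphagenome`, as the diff implies.

```python
import torch
from alphagenome_pytorch.alphagenome import BatchRMSNorm  # path implied by the diff above

norm = BatchRMSNorm(dim_feat = 64)

x = torch.randn(2, 100, 64)   # (batch, seq, features) - channel-last, the default

norm.train()
out = norm(x)                 # updates running_var with the batch variance

norm.eval()
out = norm(x)                 # reuses the stored running variance
print(out.shape)              # torch.Size([2, 100, 64])
```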
@@ -60,7 +103,7 @@ class WeightStandardConv(Conv1d):
     ):
         super().__init__(dim, dim_out, width, *args, **kwargs)
 
-        register_parametrization(self, 'weight',
+        register_parametrization(self, 'weight', LayerNorm(self.weight.shape, elementwise_affine = False))
 
 class ConvBlock(Module):
     def __init__(
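The changed line in the hunk above registers a parameter-free LayerNorm over the convolution's entire weight tensor, so the kernel is re-standardized to zero mean and unit variance each time it is accessed (weight standardization). Below is a self-contained sketch of the same mechanism using plain `torch.nn`; the layer sizes are arbitrary examples, not values from the package.

```python
import torch
from torch.nn import Conv1d, LayerNorm
from torch.nn.utils.parametrize import register_parametrization

conv = Conv1d(16, 32, 5)

# normalize the (32, 16, 5) weight over all of its dimensions, with no learnable affine
register_parametrization(conv, 'weight', LayerNorm(conv.weight.shape, elementwise_affine = False))

w = conv.weight                             # standardized on every access
print(w.mean().item(), w.std().item())      # ~0 and ~1

x = torch.randn(1, 16, 128)
print(conv(x).shape)                        # torch.Size([1, 32, 124])
```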
@@ -219,10 +262,10 @@ class NormWrapper(Module):
     ):
         super().__init__()
         self.block = block
-        self.pre_rmsnorm =
+        self.pre_rmsnorm = RMSNorm(dim) # they use an interesting variant of batchnorm, batch-rmsnorm. craft later and make sure it works distributed
 
         self.post_block_dropout = nn.Dropout(dropout)
-        self.post_rmsnorm =
+        self.post_rmsnorm = RMSNorm(dim) if sandwich else nn.Identity()
 
     def forward(
         self,
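The two lines filled in above give the wrapper a pre-block RMSNorm and, when `sandwich` is set, a post-dropout RMSNorm. A minimal sketch of that pre/post ("sandwich") normalization pattern around an arbitrary block; the block and shapes are placeholders rather than the package's forward pass, and `torch.nn.RMSNorm` requires a recent PyTorch, as the import change at the top of this diff already assumes.

```python
import torch
from torch import nn
from torch.nn import RMSNorm

dim, sandwich = 64, True

block = nn.Linear(dim, dim)                        # stand-in for the wrapped block
pre_norm = RMSNorm(dim)
post_norm = RMSNorm(dim) if sandwich else nn.Identity()
dropout = nn.Dropout(0.1)

x = torch.randn(2, 100, dim)
out = post_norm(dropout(block(pre_norm(x))))       # pre-norm -> block -> dropout -> optional post-norm
```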
@@ -247,7 +290,8 @@ class Attention(Module):
         dim_head_qk = 128,
         dim_head_v = 192,
         dim_pairwise = None,
-        softclamp_value = 5
+        softclamp_value = 5., # they employ attention softclamping
+        use_qk_rmsnorm = True
     ):
         super().__init__()
         dim_pairwise = default(dim_pairwise, dim)
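For context on the `softclamp_value` argument introduced above: the `softclamp` helper defined earlier in the file bounds a tensor smoothly to `(-value, value)` with a scaled tanh, and attention softclamping applies it to the similarity logits before the softmax. An illustrative sketch with made-up shapes:

```python
import torch

def softclamp(t, value = 5.):
    # same helper as in alphagenome.py: smooth bound to (-value, value)
    return (t / value).tanh() * value

q = torch.randn(1, 8, 128, 64)    # (batch, heads, seq, dim_head)
k = torch.randn(1, 8, 128, 64)

sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * 64 ** -0.5
sim = softclamp(sim, value = 5.)  # logits can no longer grow without bound
attn = sim.softmax(dim = -1)
```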
@@ -273,14 +317,15 @@ class Attention(Module):
 
         # they add layernorms to queries, keys, and interestingly enough, values as well. first time i've seen this
 
-
-        self.
-        self.
+        norm_klass = RMSNorm if use_qk_rmsnorm else partial(LayerNorm, bias = False)
+        self.q_norm = norm_klass(dim_head_qk)
+        self.k_norm = norm_klass(dim_head_qk)
+        self.v_norm = norm_klass(dim_head_v)
 
         # to attention bias
 
         self.to_attn_bias = Sequential(
-
+            RMSNorm(dim_pairwise), # replace with BatchRMSNorm once crafted
             nn.GELU(),
             LinearNoBias(dim_pairwise, heads),
             Rearrange('b i j (g h) -> b g h i j', g = groups)
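The norms added in the hunk above are sized to the per-head dimensions (`dim_head_qk`, `dim_head_v`), which implies they are applied to queries, keys and values after the heads are split, over the last axis. A sketch of that pattern for queries and keys only; it illustrates the idea and is not the package's actual forward pass:

```python
import torch
from torch.nn import RMSNorm

dim_head_qk = 128

q_norm = RMSNorm(dim_head_qk)
k_norm = RMSNorm(dim_head_qk)

q = torch.randn(1, 8, 256, dim_head_qk)   # (batch, heads, seq, dim_head)
k = torch.randn(1, 8, 256, dim_head_qk)

q, k = q_norm(q), k_norm(k)                # rms-normalized over the head dimension
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * dim_head_qk ** -0.5
```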
alphagenome_pytorch-0.0.12.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+alphagenome_pytorch/__init__.py,sha256=O7Mq4mhJwDUEKAJhFTLZRgSJf2uJaccX-Saef1d5nyg,242
+alphagenome_pytorch/alphagenome.py,sha256=21HFdrp5q55BRz4xt2h8nw4EgIJgUUHuqoxfgVfXjWE,19394
+alphagenome_pytorch-0.0.12.dist-info/METADATA,sha256=eODwGA_Ig6C4cbTI1nZDHCgWL4lRgAxY_a31EjyfYH4,3382
+alphagenome_pytorch-0.0.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+alphagenome_pytorch-0.0.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+alphagenome_pytorch-0.0.12.dist-info/RECORD,,
alphagenome_pytorch-0.0.11.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-alphagenome_pytorch/__init__.py,sha256=XPNDv0q_c3nkiQo-4ROb_RQsbbKNV8KXmD6X5VnErKI,225
-alphagenome_pytorch/alphagenome.py,sha256=7O1bS-_7dlEC_r2rzI_7VvTPsXQaWsIqXd4heocysVA,18206
-alphagenome_pytorch-0.0.11.dist-info/METADATA,sha256=vwetYdPP9P17KBBuDX_q1N1DgRWe_kbqGXtnqdScSvg,3382
-alphagenome_pytorch-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-alphagenome_pytorch-0.0.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-alphagenome_pytorch-0.0.11.dist-info/RECORD,,
{alphagenome_pytorch-0.0.11.dist-info → alphagenome_pytorch-0.0.12.dist-info}/WHEEL
RENAMED
File without changes
{alphagenome_pytorch-0.0.11.dist-info → alphagenome_pytorch-0.0.12.dist-info}/licenses/LICENSE
RENAMED
File without changes