alphagenome-pytorch 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alphagenome_pytorch/__init__.py +1 -0
- alphagenome_pytorch/alphagenome.py +68 -16
- {alphagenome_pytorch-0.0.10.dist-info → alphagenome_pytorch-0.0.12.dist-info}/METADATA +1 -1
- alphagenome_pytorch-0.0.12.dist-info/RECORD +6 -0
- alphagenome_pytorch-0.0.10.dist-info/RECORD +0 -6
- {alphagenome_pytorch-0.0.10.dist-info → alphagenome_pytorch-0.0.12.dist-info}/WHEEL +0 -0
- {alphagenome_pytorch-0.0.10.dist-info → alphagenome_pytorch-0.0.12.dist-info}/licenses/LICENSE +0 -0
alphagenome_pytorch/__init__.py
CHANGED
alphagenome_pytorch/alphagenome.py
CHANGED
@@ -4,7 +4,7 @@ from functools import partial
 import torch
 from torch import nn, cat, stack, arange, logspace
 import torch.nn.functional as F
-from torch.nn import Conv1d, Linear, Sequential, Module, ModuleList
+from torch.nn import Conv1d, Linear, Sequential, Module, ModuleList, LayerNorm, RMSNorm
 
 from torch.nn.utils.parametrize import register_parametrization
 
@@ -47,6 +47,49 @@ def default(v, d):
 def softclamp(t, value = 5.):
     return (t / value).tanh() * value
 
+def append_dims(t, ndims):
+    return t.reshape(*t.shape, *((1,) * ndims))
+
+# batch rmsnorm
+
+class BatchRMSNorm(Module):
+    def __init__(
+        self,
+        dim_feat,
+        channel_first = False,
+        momentum = 0.9,
+        eps = 1e-5,
+    ):
+        super().__init__()
+        self.scale = dim_feat ** 0.5
+
+        self.eps = eps
+        self.momentum = 1. - momentum
+        self.gamma = nn.Parameter(torch.zeros(dim_feat))
+        self.channel_first = channel_first
+
+        self.register_buffer('running_var', torch.ones((dim_feat,)))
+
+    def forward(
+        self,
+        x,
+        update_running_var = None
+    ):
+        update_running_var = default(update_running_var, self.training)
+        running_var = self.running_var
+
+        if update_running_var:
+
+            to_reduce = rearrange(x, 'b d ... -> b ... d') if self.channel_first else x
+
+            batch_var = torch.var(to_reduce, dim = tuple(range(x.ndim - 1)))
+
+            running_var.lerp_(batch_var, self.momentum)
+
+        std = running_var.clamp(min = self.eps).sqrt()
+
+        return x * self.scale * (self.gamma + 1.) / std
+
 # convolutional unet related
 
 class WeightStandardConv(Conv1d):
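The BatchRMSNorm added above normalizes with a running variance rather than per-example statistics. Below is a minimal usage sketch, assuming the class is importable from alphagenome_pytorch.alphagenome where the hunk defines it; shapes and values are illustrative only.

    # illustrative only: exercising the BatchRMSNorm added in the hunk above
    import torch
    from alphagenome_pytorch.alphagenome import BatchRMSNorm  # assumes module-level definition as shown above

    norm = BatchRMSNorm(dim_feat = 64)

    x = torch.randn(2, 128, 64)   # (batch, sequence, features), channel_first = False

    norm.train()
    out = norm(x)                 # training mode updates running_var from the batch variance
    assert out.shape == x.shape

    norm.eval()
    out_eval = norm(x)            # eval mode reuses the running variance, no update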
@@ -60,7 +103,7 @@ class WeightStandardConv(Conv1d):
     ):
         super().__init__(dim, dim_out, width, *args, **kwargs)
 
-        register_parametrization(self, 'weight',
+        register_parametrization(self, 'weight', LayerNorm(self.weight.shape, elementwise_affine = False))
 
 class ConvBlock(Module):
     def __init__(
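The WeightStandardConv change swaps in the directly imported LayerNorm as a weight parametrization. A self-contained sketch of the same parametrization trick using plain torch.nn primitives (not the package's class; layer sizes are made up):

    # standalone sketch: weight standardization via register_parametrization + LayerNorm
    import torch
    from torch.nn import Conv1d, LayerNorm
    from torch.nn.utils.parametrize import register_parametrization

    conv = Conv1d(16, 32, 5, padding = 2)

    # normalize the full (out, in, width) weight tensor to zero mean / unit variance,
    # with no learnable affine, mirroring the LayerNorm(self.weight.shape, elementwise_affine = False) line above
    register_parametrization(conv, 'weight', LayerNorm(conv.weight.shape, elementwise_affine = False))

    x = torch.randn(1, 16, 100)
    y = conv(x)   # the standardized weight is recomputed on every forward pass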
@@ -219,10 +262,10 @@ class NormWrapper(Module):
     ):
         super().__init__()
         self.block = block
-        self.pre_rmsnorm =
+        self.pre_rmsnorm = RMSNorm(dim) # they use an interesting variant of batchnorm, batch-rmsnorm. craft later and make sure it works distributed
 
         self.post_block_dropout = nn.Dropout(dropout)
-        self.post_rmsnorm =
+        self.post_rmsnorm = RMSNorm(dim) if sandwich else nn.Identity()
 
     def forward(
         self,
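NormWrapper now pre-normalizes with RMSNorm and, when sandwich is set, post-normalizes as well. A minimal sketch of that pre/post ("sandwich") pattern using torch.nn.RMSNorm as the diff does (PyTorch 2.4 or newer); this is not the package's exact class, and the wrapped block is a stand-in:

    # minimal sketch of the pre/post (sandwich) norm pattern shown above
    import torch
    from torch import nn
    from torch.nn import Module, RMSNorm

    class SandwichNormWrapper(Module):
        def __init__(self, dim, block, dropout = 0., sandwich = False):
            super().__init__()
            self.pre_rmsnorm = RMSNorm(dim)
            self.block = block
            self.post_block_dropout = nn.Dropout(dropout)
            self.post_rmsnorm = RMSNorm(dim) if sandwich else nn.Identity()

        def forward(self, x):
            x = self.pre_rmsnorm(x)          # normalize before the block
            x = self.block(x)
            x = self.post_block_dropout(x)   # dropout on the block output
            return self.post_rmsnorm(x)      # optional second norm ("sandwich")

    wrapper = SandwichNormWrapper(64, nn.Linear(64, 64), sandwich = True)
    out = wrapper(torch.randn(2, 16, 64))    # (batch, sequence, dim) in and out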
@@ -243,10 +286,12 @@ class Attention(Module):
         dim,
         dim_head = 64,
         heads = 8,
+        kv_heads = 1,
         dim_head_qk = 128,
         dim_head_v = 192,
         dim_pairwise = None,
-        softclamp_value = 5
+        softclamp_value = 5., # they employ attention softclamping
+        use_qk_rmsnorm = True
     ):
         super().__init__()
         dim_pairwise = default(dim_pairwise, dim)
@@ -257,8 +302,13 @@ class Attention(Module):
 
         # splitting and merging of attention heads
 
-
-
+        assert divisible_by(heads, kv_heads)
+        groups = heads // kv_heads
+
+        self.split_q_heads = Rearrange('b n (g h d) -> b g h n d', h = kv_heads, g = groups)
+        self.split_kv_heads = Rearrange('b n (h d) -> b h n d', h = kv_heads)
+
+        self.merge_heads = Rearrange('b g h n d -> b n (g h d)')
 
         # projections
 
@@ -267,17 +317,18 @@ class Attention(Module):
 
         # they add layernorms to queries, keys, and interestingly enough, values as well. first time i've seen this
 
-
-        self.
-        self.
+        norm_klass = RMSNorm if use_qk_rmsnorm else partial(LayerNorm, bias = False)
+        self.q_norm = norm_klass(dim_head_qk)
+        self.k_norm = norm_klass(dim_head_qk)
+        self.v_norm = norm_klass(dim_head_v)
 
         # to attention bias
 
         self.to_attn_bias = Sequential(
-
+            RMSNorm(dim_pairwise), # replace with BatchRMSNorm once crafted
             nn.GELU(),
             LinearNoBias(dim_pairwise, heads),
-            Rearrange('b i j h -> b h i j')
+            Rearrange('b i j (g h) -> b g h i j', g = groups)
         )
         # variables
 
@@ -296,6 +347,7 @@ class Attention(Module):
         # they use multi-query attention, with only 1 key / value head - pretty unconventional, but maybe enough for genomic modeling
 
         q = self.split_q_heads(q)
+        k, v = tuple(self.split_kv_heads(t) for t in (k, v))
 
         q, k, v = self.q_norm(q), self.k_norm(k), self.v_norm(v)
 
@@ -308,7 +360,7 @@ class Attention(Module):
 
         # similarities
 
-        sim = einsum(q, k, 'b h i d, b j d -> b h i j')
+        sim = einsum(q, k, 'b g h i d, b h j d -> b g h i j')
 
         # add attention bias + softclamping
 
@@ -318,7 +370,7 @@ class Attention(Module):
             assert divisible_by(sim.shape[-1], attn_bias.shape[-1])
             expand_factor = sim.shape[-1] // attn_bias.shape[-1]
 
-            attn_bias = repeat(attn_bias, 'b h i j -> b h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
+            attn_bias = repeat(attn_bias, 'b g h i j -> b g h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
 
             sim = softclamp(sim + attn_bias, value = self.softclamp_value)
 
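The softclamp applied to the biased logits above is just a scaled tanh, bounding the values to roughly (-softclamp_value, softclamp_value). A quick numeric check, with made-up inputs:

    # quick numeric check of the softclamp used on the attention logits above
    import torch

    def softclamp(t, value = 5.):
        return (t / value).tanh() * value

    t = torch.tensor([-100., -1., 0., 1., 100.])
    print(softclamp(t))   # approximately [-5.0000, -0.9869, 0.0000, 0.9869, 5.0000]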
@@ -328,7 +380,7 @@ class Attention(Module):
 
         # aggregate
 
-        out = einsum(attn, v, 'b h i j, b j d -> b h i d')
+        out = einsum(attn, v, 'b g h i j, b h j d -> b g h i d')
 
         out = self.merge_heads(out)
         return self.to_out(out)
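The new einsum patterns implement grouped-query attention: every group of query heads shares the kv_heads key/value heads. A toy shape check using einops (already used throughout the file); all sizes below are made up:

    # toy shape check of the grouped-query attention pattern from the hunks above
    import torch
    from einops import einsum
    from einops.layers.torch import Rearrange

    b, n, heads, kv_heads, dim_qk, dim_v = 1, 8, 8, 1, 16, 24
    groups = heads // kv_heads

    split_q  = Rearrange('b n (g h d) -> b g h n d', h = kv_heads, g = groups)
    split_kv = Rearrange('b n (h d) -> b h n d', h = kv_heads)

    q = split_q(torch.randn(b, n, heads * dim_qk))
    k = split_kv(torch.randn(b, n, kv_heads * dim_qk))
    v = split_kv(torch.randn(b, n, kv_heads * dim_v))

    sim = einsum(q, k, 'b g h i d, b h j d -> b g h i j')   # each query group attends over the shared kv head
    attn = sim.softmax(dim = -1)
    out = einsum(attn, v, 'b g h i j, b h j d -> b g h i d')

    assert out.shape == (b, groups, kv_heads, n, dim_v)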
@@ -542,7 +594,7 @@ class TransformerTower(Module):
             pairwise = maybe_pairwise_attn(pairwise) + pairwise
             pairwise = maybe_pairwise_ff(pairwise) + pairwise
 
-            single = attn(single, rotary_emb = rotary_emb, pairwise =
+            single = attn(single, rotary_emb = rotary_emb, pairwise = pairwise) + single
             single = ff(single) + single
 
         return single, pairwise
alphagenome_pytorch-0.0.12.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+alphagenome_pytorch/__init__.py,sha256=O7Mq4mhJwDUEKAJhFTLZRgSJf2uJaccX-Saef1d5nyg,242
+alphagenome_pytorch/alphagenome.py,sha256=21HFdrp5q55BRz4xt2h8nw4EgIJgUUHuqoxfgVfXjWE,19394
+alphagenome_pytorch-0.0.12.dist-info/METADATA,sha256=eODwGA_Ig6C4cbTI1nZDHCgWL4lRgAxY_a31EjyfYH4,3382
+alphagenome_pytorch-0.0.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+alphagenome_pytorch-0.0.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+alphagenome_pytorch-0.0.12.dist-info/RECORD,,
alphagenome_pytorch-0.0.10.dist-info/RECORD
REMOVED
@@ -1,6 +0,0 @@
-alphagenome_pytorch/__init__.py,sha256=XPNDv0q_c3nkiQo-4ROb_RQsbbKNV8KXmD6X5VnErKI,225
-alphagenome_pytorch/alphagenome.py,sha256=gfctPlDfzjl7CUkWJBX-9omT1MrvQDIgsvKSR0lLCQc,17902
-alphagenome_pytorch-0.0.10.dist-info/METADATA,sha256=SWZoRmpq6HVnddFPVMt108RFL-PywE-mniDLVRYkI3o,3382
-alphagenome_pytorch-0.0.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-alphagenome_pytorch-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-alphagenome_pytorch-0.0.10.dist-info/RECORD,,
{alphagenome_pytorch-0.0.10.dist-info → alphagenome_pytorch-0.0.12.dist-info}/WHEEL
RENAMED
File without changes

{alphagenome_pytorch-0.0.10.dist-info → alphagenome_pytorch-0.0.12.dist-info}/licenses/LICENSE
RENAMED
File without changes