alphagenome-pytorch 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
alphagenome_pytorch/__init__.py
@@ -8,4 +8,5 @@ from alphagenome_pytorch.alphagenome import (
     TransformerTower,
     UpresBlock,
     DownresBlock,
+    BatchRMSNorm
 )
alphagenome_pytorch/alphagenome.py
@@ -4,7 +4,7 @@ from functools import partial
 import torch
 from torch import nn, cat, stack, arange, logspace
 import torch.nn.functional as F
-from torch.nn import Conv1d, Linear, Sequential, Module, ModuleList
+from torch.nn import Conv1d, Linear, Sequential, Module, ModuleList, LayerNorm, RMSNorm
 
 from torch.nn.utils.parametrize import register_parametrization
 
@@ -47,6 +47,49 @@ def default(v, d):
 def softclamp(t, value = 5.):
     return (t / value).tanh() * value
 
+def append_dims(t, ndims):
+    return t.reshape(*t.shape, *((1,) * ndims))
+
+# batch rmsnorm
+
+class BatchRMSNorm(Module):
+    def __init__(
+        self,
+        dim_feat,
+        channel_first = False,
+        momentum = 0.9,
+        eps = 1e-5,
+    ):
+        super().__init__()
+        self.scale = dim_feat ** 0.5
+
+        self.eps = eps
+        self.momentum = 1. - momentum
+        self.gamma = nn.Parameter(torch.zeros(dim_feat))
+        self.channel_first = channel_first
+
+        self.register_buffer('running_var', torch.ones((dim_feat,)))
+
+    def forward(
+        self,
+        x,
+        update_running_var = None
+    ):
+        update_running_var = default(update_running_var, self.training)
+        running_var = self.running_var
+
+        if update_running_var:
+
+            to_reduce = rearrange(x, 'b d ... -> b ... d') if self.channel_first else x
+
+            batch_var = torch.var(to_reduce, dim = tuple(range(x.ndim - 1)))
+
+            running_var.lerp_(batch_var, self.momentum)
+
+        std = running_var.clamp(min = self.eps).sqrt()
+
+        return x * self.scale * (self.gamma + 1.) / std
+
 # convolutional unet related
 
 class WeightStandardConv(Conv1d):
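For orientation, a minimal usage sketch of the BatchRMSNorm added above, relying only on the 0.0.12 top-level export shown in the __init__.py hunk; the tensor sizes are illustrative assumptions, not values from the package:

import torch
from alphagenome_pytorch import BatchRMSNorm

norm = BatchRMSNorm(dim_feat = 512)   # feature-last layout by default (channel_first = False)
x = torch.randn(2, 1024, 512)         # (batch, sequence, features)

norm.train()
out = norm(x)                         # training path folds the batch variance into running_var via lerp_
assert out.shape == x.shape

norm.eval()
out = norm(x)                         # eval path normalizes with the stored running_var only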
@@ -60,7 +103,7 @@ class WeightStandardConv(Conv1d):
     ):
         super().__init__(dim, dim_out, width, *args, **kwargs)
 
-        register_parametrization(self, 'weight', nn.LayerNorm(self.weight.shape, elementwise_affine = False))
+        register_parametrization(self, 'weight', LayerNorm(self.weight.shape, elementwise_affine = False))
 
 class ConvBlock(Module):
     def __init__(
@@ -219,10 +262,10 @@ class NormWrapper(Module):
     ):
         super().__init__()
         self.block = block
-        self.pre_rmsnorm = nn.RMSNorm(dim) # they use an interesting variant of batchnorm, batch-rmsnorm. craft later and make sure it works distributed
+        self.pre_rmsnorm = RMSNorm(dim) # they use an interesting variant of batchnorm, batch-rmsnorm. craft later and make sure it works distributed
 
         self.post_block_dropout = nn.Dropout(dropout)
-        self.post_rmsnorm = nn.RMSNorm(dim) if sandwich else nn.Identity()
+        self.post_rmsnorm = RMSNorm(dim) if sandwich else nn.Identity()
 
     def forward(
         self,
@@ -243,10 +286,12 @@ class Attention(Module):
         dim,
         dim_head = 64,
         heads = 8,
+        kv_heads = 1,
         dim_head_qk = 128,
         dim_head_v = 192,
         dim_pairwise = None,
-        softclamp_value = 5. # they employ attention softclamping
+        softclamp_value = 5., # they employ attention softclamping
+        use_qk_rmsnorm = True
     ):
         super().__init__()
         dim_pairwise = default(dim_pairwise, dim)
@@ -257,8 +302,13 @@ class Attention(Module):
 
         # splitting and merging of attention heads
 
-        self.split_q_heads = Rearrange('b n (h d) -> b h n d', h = heads)
-        self.merge_heads = Rearrange('b h n d -> b n (h d)')
+        assert divisible_by(heads, kv_heads)
+        groups = heads // kv_heads
+
+        self.split_q_heads = Rearrange('b n (g h d) -> b g h n d', h = kv_heads, g = groups)
+        self.split_kv_heads = Rearrange('b n (h d) -> b h n d', h = kv_heads)
+
+        self.merge_heads = Rearrange('b g h n d -> b n (g h d)')
 
         # projections
 
@@ -267,17 +317,18 @@ class Attention(Module):
 
         # they add layernorms to queries, keys, and interestingly enough, values as well. first time i've seen this
 
-        self.q_norm = nn.LayerNorm(dim_head_qk, bias = False)
-        self.k_norm = nn.LayerNorm(dim_head_qk, bias = False)
-        self.v_norm = nn.LayerNorm(dim_head_v, bias = False)
+        norm_klass = RMSNorm if use_qk_rmsnorm else partial(LayerNorm, bias = False)
+        self.q_norm = norm_klass(dim_head_qk)
+        self.k_norm = norm_klass(dim_head_qk)
+        self.v_norm = norm_klass(dim_head_v)
 
         # to attention bias
 
         self.to_attn_bias = Sequential(
-            nn.RMSNorm(dim_pairwise), # replace with BatchRMSNorm once crafted
+            RMSNorm(dim_pairwise), # replace with BatchRMSNorm once crafted
             nn.GELU(),
             LinearNoBias(dim_pairwise, heads),
-            Rearrange('b i j h -> b h i j')
+            Rearrange('b i j (g h) -> b g h i j', g = groups)
         )
         # variables
 
@@ -296,6 +347,7 @@ class Attention(Module):
         # they use multi-query attention, with only 1 key / value head - pretty unconventional, but maybe enough for genomic modeling
 
         q = self.split_q_heads(q)
+        k, v = tuple(self.split_kv_heads(t) for t in (k, v))
 
         q, k, v = self.q_norm(q), self.k_norm(k), self.v_norm(v)
 
@@ -308,7 +360,7 @@ class Attention(Module):
 
         # similarities
 
-        sim = einsum(q, k, 'b h i d, b j d -> b h i j')
+        sim = einsum(q, k, 'b g h i d, b h j d -> b g h i j')
 
         # add attention bias + softclamping
 
@@ -318,7 +370,7 @@ class Attention(Module):
             assert divisible_by(sim.shape[-1], attn_bias.shape[-1])
             expand_factor = sim.shape[-1] // attn_bias.shape[-1]
 
-            attn_bias = repeat(attn_bias, 'b h i j -> b h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
+            attn_bias = repeat(attn_bias, 'b g h i j -> b g h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
 
         sim = softclamp(sim + attn_bias, value = self.softclamp_value)
 
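A quick shape check of the bias upsampling above; the pairwise grid length and sequence length here are assumed toy values, not ones taken from the package:

import torch
from einops import repeat

b, g, h = 1, 8, 1                     # batch, query-head groups, kv heads
pair_len, seq_len = 4, 16             # pairwise grid assumed coarser than the sequence
expand_factor = seq_len // pair_len

attn_bias = torch.randn(b, g, h, pair_len, pair_len)
attn_bias = repeat(attn_bias, 'b g h i j -> b g h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
assert attn_bias.shape == (b, g, h, seq_len, seq_len)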
@@ -328,7 +380,7 @@ class Attention(Module):
 
         # aggregate
 
-        out = einsum(attn, v, 'b h i j, b j d -> b h i d')
+        out = einsum(attn, v, 'b g h i j, b h j d -> b g h i d')
 
         out = self.merge_heads(out)
         return self.to_out(out)
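Together with the split_q_heads/split_kv_heads rearranges above, these two einsums implement grouped-query attention: every group of query heads shares the same kv_heads keys and values. A standalone shape sketch with assumed toy dimensions (attention bias and softclamping omitted):

import torch
from einops import einsum, rearrange

b, n = 1, 16
heads, kv_heads = 8, 1
groups = heads // kv_heads
dim_qk, dim_v = 128, 192

q = rearrange(torch.randn(b, n, heads * dim_qk), 'b n (g h d) -> b g h n d', g = groups, h = kv_heads)
k = rearrange(torch.randn(b, n, kv_heads * dim_qk), 'b n (h d) -> b h n d', h = kv_heads)
v = rearrange(torch.randn(b, n, kv_heads * dim_v), 'b n (h d) -> b h n d', h = kv_heads)

sim = einsum(q, k, 'b g h i d, b h j d -> b g h i j')    # keys carry no group axis, so all groups share them
attn = sim.softmax(dim = -1)
out = einsum(attn, v, 'b g h i j, b h j d -> b g h i d')

assert out.shape == (b, groups, kv_heads, n, dim_v)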
@@ -542,7 +594,7 @@ class TransformerTower(Module):
             pairwise = maybe_pairwise_attn(pairwise) + pairwise
             pairwise = maybe_pairwise_ff(pairwise) + pairwise
 
-            single = attn(single, rotary_emb = rotary_emb, pairwise = None) + single
+            single = attn(single, rotary_emb = rotary_emb, pairwise = pairwise) + single
             single = ff(single) + single
 
         return single, pairwise
alphagenome_pytorch-0.0.12.dist-info/METADATA (was alphagenome_pytorch-0.0.10.dist-info/METADATA)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alphagenome-pytorch
-Version: 0.0.10
+Version: 0.0.12
 Summary: AlphaGenome
 Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/alphagenome
alphagenome_pytorch-0.0.12.dist-info/RECORD (added)
@@ -0,0 +1,6 @@
+alphagenome_pytorch/__init__.py,sha256=O7Mq4mhJwDUEKAJhFTLZRgSJf2uJaccX-Saef1d5nyg,242
+alphagenome_pytorch/alphagenome.py,sha256=21HFdrp5q55BRz4xt2h8nw4EgIJgUUHuqoxfgVfXjWE,19394
+alphagenome_pytorch-0.0.12.dist-info/METADATA,sha256=eODwGA_Ig6C4cbTI1nZDHCgWL4lRgAxY_a31EjyfYH4,3382
+alphagenome_pytorch-0.0.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+alphagenome_pytorch-0.0.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+alphagenome_pytorch-0.0.12.dist-info/RECORD,,
alphagenome_pytorch-0.0.10.dist-info/RECORD (removed)
@@ -1,6 +0,0 @@
-alphagenome_pytorch/__init__.py,sha256=XPNDv0q_c3nkiQo-4ROb_RQsbbKNV8KXmD6X5VnErKI,225
-alphagenome_pytorch/alphagenome.py,sha256=gfctPlDfzjl7CUkWJBX-9omT1MrvQDIgsvKSR0lLCQc,17902
-alphagenome_pytorch-0.0.10.dist-info/METADATA,sha256=SWZoRmpq6HVnddFPVMt108RFL-PywE-mniDLVRYkI3o,3382
-alphagenome_pytorch-0.0.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-alphagenome_pytorch-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-alphagenome_pytorch-0.0.10.dist-info/RECORD,,