alphagenome-pytorch 0.0.11__tar.gz → 0.0.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alphagenome-pytorch
-Version: 0.0.11
+Version: 0.0.12
 Summary: AlphaGenome
 Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/alphagenome
@@ -8,4 +8,5 @@ from alphagenome_pytorch.alphagenome import (
     TransformerTower,
     UpresBlock,
     DownresBlock,
+    BatchRMSNorm
 )
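With this export added (assuming the hunk above is the package's alphagenome_pytorch/__init__.py, as the import line suggests), the new module should be importable from the top level, e.g.:

    from alphagenome_pytorch import BatchRMSNorm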
@@ -4,7 +4,7 @@ from functools import partial
 import torch
 from torch import nn, cat, stack, arange, logspace
 import torch.nn.functional as F
-from torch.nn import Conv1d, Linear, Sequential, Module, ModuleList
+from torch.nn import Conv1d, Linear, Sequential, Module, ModuleList, LayerNorm, RMSNorm
 
 from torch.nn.utils.parametrize import register_parametrization
 
@@ -47,6 +47,49 @@ def default(v, d):
 def softclamp(t, value = 5.):
     return (t / value).tanh() * value
 
+def append_dims(t, ndims):
+    return t.reshape(*t.shape, *((1,) * ndims))
+
+# batch rmsnorm
+
+class BatchRMSNorm(Module):
+    def __init__(
+        self,
+        dim_feat,
+        channel_first = False,
+        momentum = 0.9,
+        eps = 1e-5,
+    ):
+        super().__init__()
+        self.scale = dim_feat ** 0.5
+
+        self.eps = eps
+        self.momentum = 1. - momentum
+        self.gamma = nn.Parameter(torch.zeros(dim_feat))
+        self.channel_first = channel_first
+
+        self.register_buffer('running_var', torch.ones((dim_feat,)))
+
+    def forward(
+        self,
+        x,
+        update_running_var = None
+    ):
+        update_running_var = default(update_running_var, self.training)
+        running_var = self.running_var
+
+        if update_running_var:
+
+            to_reduce = rearrange(x, 'b d ... -> b ... d') if self.channel_first else x
+
+            batch_var = torch.var(to_reduce, dim = tuple(range(x.ndim - 1)))
+
+            running_var.lerp_(batch_var, self.momentum)
+
+        std = running_var.clamp(min = self.eps).sqrt()
+
+        return x * self.scale * (self.gamma + 1.) / std
+
 # convolutional unet related
 
 class WeightStandardConv(Conv1d):
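For reference, a minimal sketch of how the new BatchRMSNorm could be exercised with the defaults shown above (import path assumed from this diff; shapes are illustrative):

    import torch
    from alphagenome_pytorch.alphagenome import BatchRMSNorm

    norm = BatchRMSNorm(512)       # features on the last dim (channel_first = False)

    x = torch.randn(2, 1024, 512)
    out = norm(x)                  # training mode: running_var is lerped toward the batch variance
    assert out.shape == x.shape

    norm.eval()
    out = norm(x)                  # eval mode: normalizes with the stored running_var only

Note that gamma is initialized to zeros and applied as (gamma + 1.), so a freshly constructed module starts out as a pure rescale by the running standard deviation.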
@@ -60,7 +103,7 @@ class WeightStandardConv(Conv1d):
     ):
         super().__init__(dim, dim_out, width, *args, **kwargs)
 
-        register_parametrization(self, 'weight', nn.LayerNorm(self.weight.shape, elementwise_affine = False))
+        register_parametrization(self, 'weight', LayerNorm(self.weight.shape, elementwise_affine = False))
 
 class ConvBlock(Module):
     def __init__(
@@ -219,10 +262,10 @@ class NormWrapper(Module):
     ):
         super().__init__()
         self.block = block
-        self.pre_rmsnorm = nn.RMSNorm(dim) # they use an interesting variant of batchnorm, batch-rmsnorm. craft later and make sure it works distributed
+        self.pre_rmsnorm = RMSNorm(dim) # they use an interesting variant of batchnorm, batch-rmsnorm. craft later and make sure it works distributed
 
         self.post_block_dropout = nn.Dropout(dropout)
-        self.post_rmsnorm = nn.RMSNorm(dim) if sandwich else nn.Identity()
+        self.post_rmsnorm = RMSNorm(dim) if sandwich else nn.Identity()
 
     def forward(
         self,
@@ -247,7 +290,8 @@ class Attention(Module):
         dim_head_qk = 128,
         dim_head_v = 192,
         dim_pairwise = None,
-        softclamp_value = 5. # they employ attention softclamping
+        softclamp_value = 5., # they employ attention softclamping
+        use_qk_rmsnorm = True
     ):
         super().__init__()
         dim_pairwise = default(dim_pairwise, dim)
@@ -273,14 +317,15 @@ class Attention(Module):
 
         # they add layernorms to queries, keys, and interestingly enough, values as well. first time i've seen this
 
-        self.q_norm = nn.LayerNorm(dim_head_qk, bias = False)
-        self.k_norm = nn.LayerNorm(dim_head_qk, bias = False)
-        self.v_norm = nn.LayerNorm(dim_head_v, bias = False)
+        norm_klass = RMSNorm if use_qk_rmsnorm else partial(LayerNorm, bias = False)
+        self.q_norm = norm_klass(dim_head_qk)
+        self.k_norm = norm_klass(dim_head_qk)
+        self.v_norm = norm_klass(dim_head_v)
 
         # to attention bias
 
         self.to_attn_bias = Sequential(
-            nn.RMSNorm(dim_pairwise), # replace with BatchRMSNorm once crafted
+            RMSNorm(dim_pairwise), # replace with BatchRMSNorm once crafted
             nn.GELU(),
             LinearNoBias(dim_pairwise, heads),
             Rearrange('b i j (g h) -> b g h i j', g = groups)
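A small sketch of the new norm_klass switch in isolation (head and sequence sizes below are illustrative; both branches assume a PyTorch version that provides torch.nn.RMSNorm and the LayerNorm bias argument, consistent with the import change above):

    import torch
    from functools import partial
    from torch.nn import LayerNorm, RMSNorm

    for use_qk_rmsnorm in (True, False):
        norm_klass = RMSNorm if use_qk_rmsnorm else partial(LayerNorm, bias = False)
        q_norm = norm_klass(128)             # dim_head_qk
        q = torch.randn(1, 8, 1024, 128)     # (batch, heads, seq, dim_head_qk)
        assert q_norm(q).shape == q.shape    # either norm acts over the head dim and preserves shape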
@@ -1,6 +1,6 @@
 [project]
 name = "alphagenome-pytorch"
-version = "0.0.11"
+version = "0.0.12"
 description = "AlphaGenome"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -34,3 +34,12 @@ def test_alphagenome():
     pred = pred_nucleotide_logits.argmax(dim = -1)
 
     assert pred.shape == dna.shape
+
+@pytest.mark.parametrize('channel_first', (False, True))
+def test_batchrmsnorm(channel_first):
+    from alphagenome_pytorch.alphagenome import BatchRMSNorm
+
+    rmsnorm = BatchRMSNorm(512, channel_first = channel_first)
+
+    x = torch.randn(1, 512, 512)
+    assert rmsnorm(x).shape == x.shape