alphagenome-pytorch 0.0.7__tar.gz → 0.0.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alphagenome-pytorch
-Version: 0.0.7
+Version: 0.0.9
 Summary: AlphaGenome
 Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/alphagenome
@@ -4,7 +4,9 @@ from functools import partial
 import torch
 from torch import nn, cat, stack, arange, logspace
 import torch.nn.functional as F
-from torch.nn import Linear, Sequential, Module, ModuleList
+from torch.nn import Conv1d, Linear, Sequential, Module, ModuleList
+
+from torch.nn.utils.parametrize import register_parametrization
 
 from einx import add, multiply, greater
 from einops.layers.torch import Rearrange, Reduce
@@ -42,6 +44,104 @@ def default(v, d):
 def softclamp(t, value = 5.):
     return (t / value).tanh() * value
 
+# convolutional unet related
+
+class WeightStandardConv(Conv1d):
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        width,
+        *args,
+        **kwargs
+    ):
+        super().__init__(dim, dim_out, width, *args, **kwargs)
+
+        register_parametrization(self, 'weight', nn.LayerNorm(self.weight.shape, elementwise_affine = False))
+
+class ConvBlock(Module):
+    def __init__(
+        self,
+        dim,
+        width = 5,
+        dim_out = None
+    ):
+        super().__init__()
+        assert is_odd(width)
+        dim_out = default(dim_out, dim)
+
+        conv_klass = Conv1d if width == 1 else WeightStandardConv
+
+        self.conv = conv_klass(dim, dim_out, width, padding = width // 2)
+
+    def forward(self, x):
+
+        x = F.gelu(x)
+        out = self.conv(x)
+        return out
+
+class DownresBlock(Module):
+    def __init__(
+        self,
+        dim,
+        channels_to_add = 128 # this is new as well? instead of doubling channels, they add 128 at a time, and use padding or slicing for the residual
+    ):
+        super().__init__()
+
+        dim_out = dim + channels_to_add
+        self.pad = channels_to_add
+
+        self.conv = ConvBlock(dim, width = 1, dim_out = dim_out)
+        self.conv_out = ConvBlock(dim_out, width = 1)
+
+        self.max_pool = Reduce('b d (n pool) -> b d n', 'max', pool = 2)
+
+    def forward(self, x):
+
+        residual = F.pad(x, (0, 0, 0, self.pad), value = 0.)
+
+        out = self.conv(x) + residual
+
+        out = self.conv_out(out) + out
+
+        return self.max_pool(out)
+
+class UpresBlock(Module):
+    def __init__(
+        self,
+        dim,
+        channels_to_remove = 128,
+        residual_scale_init = .9
+    ):
+        super().__init__()
+
+        dim_out = dim - channels_to_remove
+        self.pad = channels_to_remove
+
+        self.conv = ConvBlock(dim, width = 1, dim_out = dim_out)
+        self.unet_conv = ConvBlock(dim_out, width = 1)
+
+        self.conv_out = ConvBlock(dim_out, width = 1)
+
+        self.residual_scale = nn.Parameter(torch.ones(1,) * residual_scale_init)
+
+    def forward(
+        self,
+        x,
+        unet_skip = None
+    ):
+
+        residual = x[:, :-self.pad]
+        out = self.conv(x) + residual
+
+        if exists(unet_skip):
+            out = repeat(out, 'b c n -> b c (n upsample)', upsample = 2) * self.residual_scale
+            out = out + self.unet_conv(unet_skip)
+
+        return self.conv_out(out) + out
+
+# position related
+
 def relative_shift(t):
     *leading_dims, seq_len, dim = t.shape
     t = F.pad(t, (1, 0), value = 0.)
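Aside on the WeightStandardConv added above: registering a non-affine LayerNorm over the full weight shape through torch.nn.utils.parametrize standardizes the kernel (zero mean, unit variance across all weight elements) every time .weight is read. A minimal standalone sketch of that mechanism, using plain torch rather than code from this package:

import torch
from torch import nn
from torch.nn.utils.parametrize import register_parametrization

conv = nn.Conv1d(4, 8, 5)

# normalize over the whole (out_channels, in_channels, width) weight tensor
register_parametrization(conv, 'weight', nn.LayerNorm(conv.weight.shape, elementwise_affine = False))

w = conv.weight                            # recomputed from the raw parameter on each access
print(w.mean().item(), w.std().item())     # approximately 0 and 1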
@@ -434,14 +534,14 @@ class TransformerTower(Module):
             maybe_pairwise_ff
         ) in self.layers:
 
-            single = attn(single, rotary_emb = rotary_emb, pairwise = None) + single
-            single = ff(single) + single
-
             if exists(maybe_single_to_pair):
                 pairwise = maybe_single_to_pair(single, rel_pos_feats) + default(pairwise, 0.)
                 pairwise = maybe_pairwise_attn(pairwise) + pairwise
                 pairwise = maybe_pairwise_ff(pairwise) + pairwise
 
+            single = attn(single, rotary_emb = rotary_emb, pairwise = None) + single
+            single = ff(single) + single
+
         return single, pairwise
 
 # embedding
@@ -1,6 +1,6 @@
 [project]
 name = "alphagenome-pytorch"
-version = "0.0.7"
+version = "0.0.9"
 description = "AlphaGenome"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1,4 +1,5 @@
 import pytest
+
 import torch
 from alphagenome_pytorch.alphagenome import TransformerTower
 
@@ -12,3 +13,11 @@ def test_attention():
 
     assert single_repr.shape == (2, 512, 768)
     assert pairwise_repr.shape == (2, 512 // 16, 512 // 16, 128)
+
+def test_down_up():
+    from alphagenome_pytorch.alphagenome import DownresBlock, UpresBlock
+    down = DownresBlock(64)
+    up = UpresBlock(64 + 128)
+
+    x = torch.randn(1, 64, 8)
+    assert up(down(x), x).shape == x.shape
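The new test above round-trips shapes through the down/up pair. A step-by-step sketch of why it holds, assuming the block definitions from the diff (DownresBlock(64) pads 64 -> 192 channels and max-pools the length by 2; UpresBlock(192) slices back to 64 channels and upsamples by 2 against the skip):

import torch
from alphagenome_pytorch.alphagenome import DownresBlock, UpresBlock

down = DownresBlock(64)       # channels_to_add = 128, so 64 -> 192 channels
up = UpresBlock(64 + 128)     # channels_to_remove = 128, so 192 -> 64 channels

x = torch.randn(1, 64, 8)     # (batch, channels, length)

pooled = down(x)              # channels padded up, length halved: (1, 192, 4)
assert pooled.shape == (1, 192, 4)

restored = up(pooled, unet_skip = x)   # channels sliced down, length doubled back: (1, 64, 8)
assert restored.shape == x.shape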