alphagenome_pytorch-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,5 @@
+ from alphagenome.alphagenome import (
+     AlphaGenome,
+     Attention,
+     TransformerTower
+ )
@@ -0,0 +1,441 @@
+ from __future__ import annotations
+ from functools import partial
+
+ import torch
+ from torch import nn, cat, stack, arange
+ import torch.nn.functional as F
+ from torch.nn import Linear, Sequential, Module, ModuleList
+
+ import einx
+ from einops.layers.torch import Rearrange, Reduce
+ from einops import rearrange, repeat, einsum
+
+ # ein notation
+
+ # b - batch
+ # h - heads
+ # n - sequence
+ # d - feature dimension
+
+ # constants
+
+ LinearNoBias = partial(Linear, bias = False)
+
+ # functions
+
+ def exists(v):
+     return v is not None
+
+ def divisible_by(num, den):
+     return (num % den) == 0
+
+ def is_odd(num):
+     return not divisible_by(num, 2)
+
+ def is_even(num):
+     return divisible_by(num, 2)
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def softclamp(t, value = 5.):
+     return (t / value).tanh() * value
+
+ # rotary, but with attenuation of short relative distance frequencies
+
+ class RotaryEmbedding(Module):
+     def __init__(
+         self,
+         dim,
+         max_positions = 8192
+     ):
+         super().__init__()
+         num_freqs = dim // 2
+         inv_freq = 1. / (arange(num_freqs).float() + torch.logspace(1, max_positions - num_freqs + 1, num_freqs))
+         self.register_buffer('inv_freq', inv_freq)
+
+     def forward(
+         self,
+         seq_len
+     ):
+         device = self.inv_freq.device
+         t = arange(seq_len, device = device).type_as(self.inv_freq)
+         freqs = einsum(t, self.inv_freq, 'i , j -> i j')
+         return cat((freqs, freqs), dim = -1)
+
+ def rotate_half(x):
+     x1, x2 = x.chunk(2, dim = -1)
+     return torch.cat((-x2, x1), dim = -1)
+
+ def apply_rotary_pos_emb(pos, t):
+     return t * pos.cos() + rotate_half(t) * pos.sin()
+
+ # prenorm and sandwich norm - they use sandwich norm for single rep, prenorm for pairwise rep
+
+ class NormWrapper(Module):
+     def __init__(
+         self,
+         dim,
+         block: Module,
+         dropout = 0.,
+         sandwich = False
+     ):
+         super().__init__()
+         self.block = block
+         self.pre_rmsnorm = nn.RMSNorm(dim) # they use an interesting variant of batchnorm, batch-rmsnorm. craft later and make sure it works distributed
+
+         self.post_block_dropout = nn.Dropout(dropout)
+         self.post_rmsnorm = nn.RMSNorm(dim) if sandwich else nn.Identity()
+
+     def forward(
+         self,
+         x,
+         **kwargs
+     ):
+         x = self.pre_rmsnorm(x)
+         out = self.block(x, **kwargs)
+         out = self.post_block_dropout(out)
+         return self.post_rmsnorm(out)
+
+ # attention
+
+ class Attention(Module):
+     def __init__(
+         self,
+         dim,
+         heads = 8,
+         dim_head_qk = 128,
+         dim_head_v = 192,
+         dim_pairwise = None,
+         softclamp_value = 5. # they employ attention softclamping
+     ):
+         super().__init__()
+         dim_pairwise = default(dim_pairwise, dim)
+
+         self.scale = dim_head_qk ** -0.5
+
+         qkv_proj_dim_out = (dim_head_qk * heads, dim_head_qk, dim_head_v)
+
+         # splitting and merging of attention heads
+
+         self.split_q_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+         self.merge_heads = Rearrange('b h n d -> b n (h d)')
+
+         # projections
+
+         self.to_qkv = LinearNoBias(dim, sum(qkv_proj_dim_out))
+         self.to_out = LinearNoBias(dim_head_v * heads, dim)
+
+         # they add layernorms to queries, keys, and interestingly enough, values as well. first time i've seen this
+
+         self.q_norm = nn.LayerNorm(dim_head_qk, bias = False)
+         self.k_norm = nn.LayerNorm(dim_head_qk, bias = False)
+         self.v_norm = nn.LayerNorm(dim_head_v, bias = False)
+
+         # to attention bias
+
+         self.to_attn_bias = Sequential(
+             nn.RMSNorm(dim_pairwise), # replace with BatchRMSNorm once crafted
+             nn.GELU(),
+             LinearNoBias(dim_pairwise, heads),
+             Rearrange('b i j h -> b h i j')
+         )
+
+         # variables
+
+         self.qkv_dim_splits = qkv_proj_dim_out
+         self.softclamp_value = softclamp_value
+
+     def forward(
+         self,
+         x,
+         pairwise = None, # Float['b i j dp']
+         rotary_emb = None
+     ):
+
+         q, k, v = self.to_qkv(x).split(self.qkv_dim_splits, dim = -1)
+
+         # they use multi-query attention, with only 1 key / value head - pretty unconventional, but maybe enough for genomic modeling
+
+         q = self.split_q_heads(q)
+
+         q, k, v = self.q_norm(q), self.k_norm(k), self.v_norm(v)
+
+         q = q * self.scale
+
+         # maybe rotary
+
+         if exists(rotary_emb):
+             q, k = tuple(apply_rotary_pos_emb(rotary_emb, t) for t in (q, k))
+
+         # similarities
+
+         sim = einsum(q, k, 'b h i d, b j d -> b h i j')
+
+         # add attention bias + softclamping
+
+         if exists(pairwise):
+             attn_bias = self.to_attn_bias(pairwise)
+
+             assert divisible_by(sim.shape[-1], attn_bias.shape[-1])
+             expand_factor = sim.shape[-1] // attn_bias.shape[-1]
+
+             attn_bias = repeat(attn_bias, 'b h i j -> b h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
+
+             sim = softclamp(sim + attn_bias, value = self.softclamp_value)
+
+         # attention
+
+         attn = sim.softmax(dim = -1)
+
+         # aggregate
+
+         out = einsum(attn, v, 'b h i j, b j d -> b h i d')
+
+         out = self.merge_heads(out)
+         return self.to_out(out)
+
+ # single to pairwise
+
+ class SingleToPairwise(Module):
+     def __init__(
+         self,
+         dim,
+         pool_size = 16,
+         dim_pairwise = 128,
+         heads = 32
+     ):
+         super().__init__()
+         self.avg_pool = Reduce('b (n pool) d -> b n d', 'mean', pool = pool_size)
+
+         dim_inner = heads * dim_pairwise
+
+         self.split_heads = Rearrange('b n (h d) -> b n h d', h = heads)
+
+         self.to_outer_sum = Sequential(
+             LinearNoBias(dim, dim_pairwise * 2),
+             nn.GELU()
+         )
+
+         self.to_qk = LinearNoBias(dim, dim_inner * 2)
+         self.qk_to_pairwise = Linear(heads, dim_pairwise)
+
+     def forward(self, single):
+
+         single = self.avg_pool(single)
+
+         q, k = self.to_qk(single).chunk(2, dim = -1)
+         q, k = tuple(self.split_heads(t) for t in (q, k))
+
+         sim = einsum(q, k, 'b i h d, b j h d -> b i j h')
+
+         pairwise_from_sim = self.qk_to_pairwise(sim)
+
+         outer_q, outer_k = self.to_outer_sum(single).chunk(2, dim = -1)
+
+         outer_sum = einx.add('b i d, b j d -> b i j d', outer_q, outer_k)
+
+         # pairwise rep is the outer sum combined with the projected qk similarities
+
+         return outer_sum + pairwise_from_sim
+
+ # pairwise attention is a single headed attention across rows, they said columns did not help
+
+ class PairwiseRowAttention(Module):
+     def __init__(
+         self,
+         dim
+     ):
+         super().__init__()
+         self.scale = dim ** -0.5
+
+         self.to_qk = LinearNoBias(dim, dim * 2)
+         self.to_v = Linear(dim, dim)
+
+     def forward(
+         self,
+         x
+     ):
+
+         q, k = self.to_qk(x).chunk(2, dim = -1)
+         v = self.to_v(x)
+
+         q = q * self.scale
+
+         # similarity
+
+         sim = einsum(q, k, 'b n i d, b n j d -> b n i j')
+
+         # attention
+
+         attn = sim.softmax(dim = -1)
+
+         # aggregate
+
+         return einsum(attn, v, 'b n i j, b n j d -> b n i d')
+
+ # feedforward for both single and pairwise
+
+ def FeedForward(
+     dim,
+     *,
+     dropout = 0.,
+     expansion_factor = 2., # they only do expansion factor of 2, no glu
+ ):
+     dim_inner = int(dim * expansion_factor)
+
+     return Sequential(
+         Linear(dim, dim_inner),
+         nn.ReLU(),
+         nn.Dropout(dropout),
+         Linear(dim_inner, dim)
+     )
+
+ # transformer
+
+ class TransformerTower(Module):
+     def __init__(
+         self,
+         dim,
+         *,
+         depth = 8,
+         heads = 8,
+         dim_head_qk = 128,
+         dim_head_v = 192,
+         dropout = 0.,
+         ff_expansion_factor = 2.,
+         max_positions = 8192,
+         dim_pairwise = None,
+         pairwise_every_num_single_blocks = 2, # how often to do a pairwise block
+         single_to_pairwise_heads = 32, # they did 32
+         attn_kwargs: dict = dict(),
+         ff_kwargs: dict = dict()
+     ):
+         super().__init__()
+         dim_pairwise = default(dim_pairwise, dim)
+
+         layers = []
+
+         self.pairwise_every = pairwise_every_num_single_blocks
+
+         self.rotary_emb = RotaryEmbedding(dim_head_qk, max_positions = max_positions)
+
+         for layer_index in range(depth):
+
+             attn = Attention(dim = dim, dim_head_qk = dim_head_qk, dim_head_v = dim_head_v, heads = heads, dim_pairwise = dim_pairwise, **attn_kwargs)
+
+             ff = FeedForward(dim = dim, expansion_factor = ff_expansion_factor, **ff_kwargs)
+
+             attn = NormWrapper(dim = dim, block = attn, dropout = dropout, sandwich = True)
+             ff = NormWrapper(dim = dim, block = ff, dropout = dropout, sandwich = True)
+
+             # maybe pairwise
+
+             single_to_pairwise, pairwise_attn, pairwise_ff = None, None, None
+
+             if divisible_by(layer_index, self.pairwise_every):
+                 single_to_pairwise = SingleToPairwise(dim = dim, dim_pairwise = dim_pairwise, heads = single_to_pairwise_heads)
+                 pairwise_attn = PairwiseRowAttention(dim_pairwise)
+                 pairwise_ff = FeedForward(dim = dim_pairwise, expansion_factor = ff_expansion_factor)
+
+                 single_to_pairwise = NormWrapper(dim = dim, block = single_to_pairwise, dropout = dropout)
+                 pairwise_attn = NormWrapper(dim = dim_pairwise, block = pairwise_attn, dropout = dropout)
+                 pairwise_ff = NormWrapper(dim = dim_pairwise, block = pairwise_ff, dropout = dropout)
+
+             # add to layers
+
+             layers.append(ModuleList([
+                 attn,
+                 ff,
+                 single_to_pairwise,
+                 pairwise_attn,
+                 pairwise_ff
+             ]))
+
+         self.layers = ModuleList(layers)
+
+     def forward(
+         self,
+         single
+     ):
+
+         seq_len = single.shape[1]
+
+         pairwise = None
+
+         rotary_emb = self.rotary_emb(seq_len)
+
+         for (
+             attn,
+             ff,
+             maybe_single_to_pair,
+             maybe_pairwise_attn,
+             maybe_pairwise_ff
+         ) in self.layers:
+
+             # once the pairwise rep exists, it conditions single attention through the attention bias
+
+             single = attn(single, rotary_emb = rotary_emb, pairwise = pairwise) + single
+             single = ff(single) + single
+
+             if exists(maybe_single_to_pair):
+                 pairwise = maybe_single_to_pair(single) + default(pairwise, 0.)
+                 pairwise = maybe_pairwise_attn(pairwise) + pairwise
+                 pairwise = maybe_pairwise_ff(pairwise) + pairwise
+
+         return single, pairwise
+
+ # embedding
+
+ class DNAEmbed(Module):
+     def __init__(
+         self,
+         dim,
+         dim_input = 5, # 5 basepairs
+         width = 15
+     ):
+         super().__init__()
+         assert is_odd(width)
+         self.dim_input = dim_input
+         self.conv = nn.Conv1d(dim_input, dim, width, padding = width // 2)
+         self.pointwise = nn.Conv1d(dim, dim, 1)
+
+     def forward(
+         self,
+         seq # Int['b n']
+     ):
+         onehot = F.one_hot(seq, num_classes = self.dim_input).float()
+         x = rearrange(onehot, 'b n d -> b d n')
+
+         out = self.conv(x)
+         out = out + self.pointwise(out)
+         return rearrange(out, 'b d n -> b n d')
+
+ # classes
+
+ class AlphaGenome(Module):
+     def __init__(
+         self,
+         dim = 768,
+         basepairs = 5,
+         dna_embed_width = 15,
+         dim_pairwise = None,
+         transformer_kwargs: dict = dict()
+     ):
+         super().__init__()
+         assert is_odd(dna_embed_width)
+
+         self.to_dna_embed = DNAEmbed(dim, dim_input = basepairs, width = dna_embed_width)
+
+         self.transformer = TransformerTower(
+             dim = dim,
+             dim_pairwise = dim_pairwise,
+             **transformer_kwargs
+         )
+
+     def forward(
+         self,
+         seq
+     ):
+
+         dna_embed = self.to_dna_embed(seq)
+
+         single, pairwise = self.transformer(dna_embed)
+
+         return single, pairwise
@@ -0,0 +1,78 @@
+ Metadata-Version: 2.4
+ Name: alphagenome-pytorch
+ Version: 0.0.1
+ Summary: AlphaGenome
+ Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
+ Project-URL: Repository, https://github.com/lucidrains/alphagenome
+ Author-email: Phil Wang <lucidrains@gmail.com>
+ License: MIT License
+
+ Copyright (c) 2025 Phil Wang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ License-File: LICENSE
+ Keywords: artificial intelligence,attention mechanism,deep learning,genomics,splicing,transformers
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.9
+ Requires-Dist: einops>=0.8.0
+ Requires-Dist: einx>=0.3.0
+ Requires-Dist: torch>=2.4
+ Provides-Extra: examples
+ Provides-Extra: test
+ Requires-Dist: pytest; extra == 'test'
+ Description-Content-Type: text/markdown
+
+ <img src="./extended-figure-1.png" width="450px"></img>
+
+ ## AlphaGenome (wip)
+
+ Implementation of [AlphaGenome](https://deepmind.google/discover/blog/alphagenome-ai-for-better-understanding-the-genome/), DeepMind's updated genomic attention model
+
+ ## Install
+
+ ```bash
+ $ pip install alphagenome-pytorch
+ ```
+
+ ## Usage
+
+ ```python
+ import torch
+ from alphagenome import TransformerTower
+
+ transformer = TransformerTower(dim = 768, dim_pairwise = 128)
+
+ single = torch.randn(2, 512, 768)
+
+ attended_single, attended_pairwise = transformer(single) # (2, 512, 768), (2, 32, 32, 128)
+ ```
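+
+ The end to end `AlphaGenome` wrapper above (a `DNAEmbed` stem followed by the `TransformerTower`) can be driven directly from integer base tokens. A minimal sketch, assuming the wiring in this 0.0.1 module: a 5 token base vocabulary (the base-to-index mapping is up to the caller), a sequence length divisible by the default pairwise pool size of 16, and default hyperparameters otherwise.
+
+ ```python
+ import torch
+ from alphagenome import AlphaGenome
+
+ model = AlphaGenome(dim = 768, dim_pairwise = 128)
+
+ # batch of 2 sequences, 512 tokens each, drawn from the 5 basepair vocabulary
+ seq = torch.randint(0, 5, (2, 512))
+
+ single, pairwise = model(seq) # (2, 512, 768), (2, 32, 32, 128)
+ ```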
+
+ ## Citations
+
+ ```bibtex
+ @article{avsec2025alphagenome,
+     title   = {AlphaGenome: advancing regulatory variant effect prediction with a unified DNA sequence model},
+     author  = {Avsec, {\v{Z}}iga and Latysheva, Natasha and Cheng, Jun and Novati, Guido and Taylor, Kyle R and Ward, Tom and Bycroft, Clare and Nicolaisen, Lauren and Arvaniti, Eirini and Pan, Joshua and Thomas, Raina and Dutordoir, Vincent and Perino, Matteo and De, Soham and Karollus, Alexander and Gayoso, Adam and Sargeant, Toby and Mottram, Anne and Wong, Lai Hong and Drot{\'a}r, Pavol and Kosiorek, Adam and Senior, Andrew and Tanburn, Richard and Applebaum, Taylor and Basu, Souradeep and Hassabis, Demis and Kohli, Pushmeet},
+     year    = {2025}
+ }
+ ```
@@ -0,0 +1,6 @@
+ alphagenome/__init__.py,sha256=FjaT3_la9IG9w-PJ0Tk7ZK550O4Zq8SM5jej3nlzE6U,93
+ alphagenome/alphagenome.py,sha256=xIsgMa23nANZyrIlsY4Og7fpibpDED9XVxkeQQXZmYg,11817
+ alphagenome_pytorch-0.0.1.dist-info/METADATA,sha256=19TlrcXyHn1-YbLK8V_VZ3nPima2NRBflOFveDsO9Bg,3378
+ alphagenome_pytorch-0.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ alphagenome_pytorch-0.0.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ alphagenome_pytorch-0.0.1.dist-info/RECORD,,
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: hatchling 1.27.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Phil Wang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.