alphagenome-pytorch 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,36 @@
+ # This workflow will upload a Python Package using Twine when a release is created
+ # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
+
+ # This workflow uses actions that are not certified by GitHub.
+ # They are provided by a third-party and are governed by
+ # separate terms of service, privacy policy, and support
+ # documentation.
+
+ name: Upload Python Package
+
+ on:
+   release:
+     types: [published]
+
+ jobs:
+   deploy:
+
+     runs-on: ubuntu-latest
+
+     steps:
+     - uses: actions/checkout@v2
+     - name: Set up Python
+       uses: actions/setup-python@v2
+       with:
+         python-version: '3.x'
+     - name: Install dependencies
+       run: |
+         python -m pip install --upgrade pip
+         pip install build
+     - name: Build package
+       run: python -m build
+     - name: Publish package
+       uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+       with:
+         user: __token__
+         password: ${{ secrets.PYPI_API_TOKEN }}
@@ -0,0 +1,21 @@
+ name: Pytest
+ on: [push, pull_request]
+
+ jobs:
+   build:
+
+     runs-on: ubuntu-latest
+
+     steps:
+     - uses: actions/checkout@v4
+     - name: Set up Python 3.10
+       uses: actions/setup-python@v5
+       with:
+         python-version: "3.10"
+     - name: Install dependencies
+       run: |
+         python -m pip install --upgrade pip
+         python -m pip install -e .[test]
+     - name: Test with pytest
+       run: |
+         python -m pytest tests/
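For running the same checks outside CI, a minimal sketch that invokes the suite programmatically (not part of the packaged files; it assumes the project was installed with the `test` extra, as in the workflow step above):

```python
# run the same tests the workflow runs, but from a Python entry point
import pytest

# "tests/" mirrors the workflow's `python -m pytest tests/`; "-q" keeps the output terse
raise SystemExit(pytest.main(["tests/", "-q"]))
```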
@@ -0,0 +1,194 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Abstra
+ # Abstra is an AI-powered process automation framework.
+ # Ignore directories containing user credentials, local state, and settings.
+ # Learn more at https://abstra.io/docs
+ .abstra/
+
+ # Visual Studio Code
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
+ # you could uncomment the following to ignore the entire vscode folder
+ # .vscode/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ # Cursor
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+ # refer to https://docs.cursor.com/context/ignore-files
+ .cursorignore
+ .cursorindexingignore
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Phil Wang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,78 @@
+ Metadata-Version: 2.4
+ Name: alphagenome-pytorch
+ Version: 0.0.1
+ Summary: AlphaGenome
+ Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
+ Project-URL: Repository, https://github.com/lucidrains/alphagenome
+ Author-email: Phil Wang <lucidrains@gmail.com>
+ License: MIT License
+
+ Copyright (c) 2025 Phil Wang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ License-File: LICENSE
+ Keywords: artificial intelligence,attention mechanism,deep learning,genomics,splicing,transformers
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.9
+ Requires-Dist: einops>=0.8.0
+ Requires-Dist: einx>=0.3.0
+ Requires-Dist: torch>=2.4
+ Provides-Extra: examples
+ Provides-Extra: test
+ Requires-Dist: pytest; extra == 'test'
+ Description-Content-Type: text/markdown
+
+ <img src="./extended-figure-1.png" width="450px"></img>
+
+ ## AlphaGenome (wip)
+
+ Implementation of [AlphaGenome](https://deepmind.google/discover/blog/alphagenome-ai-for-better-understanding-the-genome/), DeepMind's updated genomic attention model
+
+ ## Install
+
+ ```bash
+ $ pip install alphagenome-pytorch
+ ```
+
+ ## Usage
+
+ ```python
+ import torch
+ from alphagenome import TransformerTower
+
+ transformer = TransformerTower(dim = 768, dim_pairwise = 128)
+
+ single = torch.randn(2, 512, 768)
+
+ attended_single, attended_pairwise = transformer(single)
+ ```
+
+ ## Citations
+
+ ```bibtex
+ @article{avsec2025alphagenome,
+     title = {AlphaGenome: advancing regulatory variant effect prediction with a unified DNA sequence model},
+     author = {Avsec, {\v{Z}}iga and Latysheva, Natasha and Cheng, Jun and Novati, Guido and Taylor, Kyle R and Ward, Tom and Bycroft, Clare and Nicolaisen, Lauren and Arvaniti, Eirini and Pan, Joshua and Thomas, Raina and Dutordoir, Vincent and Perino, Matteo and De, Soham and Karollus, Alexander and Gayoso, Adam and Sargeant, Toby and Mottram, Anne and Wong, Lai Hong and Drot{\'a}r, Pavol and Kosiorek, Adam and Senior, Andrew and Tanburn, Richard and Applebaum, Taylor and Basu, Souradeep and Hassabis, Demis and Kohli, Pushmeet},
+     year = {2025}
+ }
+ ```
@@ -0,0 +1,34 @@
+ <img src="./extended-figure-1.png" width="450px"></img>
+
+ ## AlphaGenome (wip)
+
+ Implementation of [AlphaGenome](https://deepmind.google/discover/blog/alphagenome-ai-for-better-understanding-the-genome/), DeepMind's updated genomic attention model
+
+ ## Install
+
+ ```bash
+ $ pip install alphagenome-pytorch
+ ```
+
+ ## Usage
+
+ ```python
+ import torch
+ from alphagenome import TransformerTower
+
+ transformer = TransformerTower(dim = 768, dim_pairwise = 128)
+
+ single = torch.randn(2, 512, 768)
+
+ attended_single, attended_pairwise = transformer(single)
+ ```
+
+ ## Citations
+
+ ```bibtex
+ @article{avsec2025alphagenome,
+     title = {AlphaGenome: advancing regulatory variant effect prediction with a unified DNA sequence model},
+     author = {Avsec, {\v{Z}}iga and Latysheva, Natasha and Cheng, Jun and Novati, Guido and Taylor, Kyle R and Ward, Tom and Bycroft, Clare and Nicolaisen, Lauren and Arvaniti, Eirini and Pan, Joshua and Thomas, Raina and Dutordoir, Vincent and Perino, Matteo and De, Soham and Karollus, Alexander and Gayoso, Adam and Sargeant, Toby and Mottram, Anne and Wong, Lai Hong and Drot{\'a}r, Pavol and Kosiorek, Adam and Senior, Andrew and Tanburn, Richard and Applebaum, Taylor and Basu, Souradeep and Hassabis, Demis and Kohli, Pushmeet},
+     year = {2025}
+ }
+ ```
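The usage example in the README stops at `TransformerTower`, which takes continuous embeddings. As a minimal sketch (not part of the packaged files) of the tokenized-DNA path also present in the source, with `DNAEmbed` feeding the transformer and assuming a sequence length divisible by the default pairwise pool size of 16:

```python
import torch
from alphagenome import TransformerTower
from alphagenome.alphagenome import DNAEmbed

dna_embed = DNAEmbed(dim = 768)                     # 5 one-hot base classes by default
transformer = TransformerTower(dim = 768, dim_pairwise = 128)

seq = torch.randint(0, 5, (2, 512))                 # integer encoded DNA, shape (batch, length)

single = dna_embed(seq)                             # (2, 512, 768)
single_repr, pairwise_repr = transformer(single)    # (2, 512, 768), (2, 32, 32, 128)
```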
@@ -0,0 +1,5 @@
+ from alphagenome.alphagenome import (
+     AlphaGenome,
+     Attention,
+     TransformerTower
+ )
@@ -0,0 +1,441 @@
+ from __future__ import annotations
+ from functools import partial
+
+ import torch
+ from torch import nn, cat, stack, arange
+ import torch.nn.functional as F
+ from torch.nn import Linear, Sequential, Module, ModuleList
+
+ import einx
+ from einops.layers.torch import Rearrange, Reduce
+ from einops import rearrange, repeat, einsum
+
+ # ein notation
+
+ # b - batch
+ # h - heads
+ # n - sequence
+ # d - feature dimension
+
+ # constants
+
+ LinearNoBias = partial(Linear, bias = False)
+
+ # functions
+
+ def exists(v):
+     return v is not None
+
+ def divisible_by(num, den):
+     return (num % den) == 0
+
+ def is_odd(num):
+     return not divisible_by(num, 2)
+
+ def is_even(num):
+     return divisible_by(num, 2)
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def softclamp(t, value = 5.):
+     return (t / value).tanh() * value
+
+ # rotary, but with attenuation of short relative distance frequencies
+
+ class RotaryEmbedding(Module):
+     def __init__(
+         self,
+         dim,
+         max_positions = 8192
+     ):
+         super().__init__()
+         num_freqs = dim // 2
+         inv_freq = 1. / (arange(num_freqs).float() + torch.logspace(1, max_positions - num_freqs + 1, num_freqs))
+         self.register_buffer('inv_freq', inv_freq)
+
+     def forward(
+         self,
+         seq_len
+     ):
+         device = self.inv_freq.device
+         t = arange(seq_len, device = device).type_as(self.inv_freq)
+         freqs = einsum(t, self.inv_freq, 'i , j -> i j')
+         return cat((freqs, freqs), dim = -1)
+
+ def rotate_half(x):
+     x1, x2 = x.chunk(2, dim = -1)
+     return torch.cat((-x2, x1), dim = -1)
+
+ def apply_rotary_pos_emb(pos, t):
+     return t * pos.cos() + rotate_half(t) * pos.sin()
+
+ # prenorm and sandwich norm - they use sandwich norm for single rep, prenorm for pairwise rep
+
+ class NormWrapper(Module):
+     def __init__(
+         self,
+         dim,
+         block: Module,
+         dropout = 0.,
+         sandwich = False
+     ):
+         super().__init__()
+         self.block = block
+         self.pre_rmsnorm = nn.RMSNorm(dim) # they use an interesting variant of batchnorm, batch-rmsnorm. craft later and make sure it works distributed
+
+         self.post_block_dropout = nn.Dropout(dropout)
+         self.post_rmsnorm = nn.RMSNorm(dim) if sandwich else nn.Identity()
+
+     def forward(
+         self,
+         x,
+         **kwargs
+     ):
+         x = self.pre_rmsnorm(x)
+         out = self.block(x, **kwargs)
+         out = self.post_block_dropout(out)
+         return self.post_rmsnorm(out)
+
+ # attention
+
+ class Attention(Module):
+     def __init__(
+         self,
+         dim,
+         heads = 8,
+         dim_head_qk = 128,
+         dim_head_v = 192,
+         dim_pairwise = None,
+         softclamp_value = 5. # they employ attention softclamping
+     ):
+         super().__init__()
+         dim_pairwise = default(dim_pairwise, dim)
+
+         self.scale = dim_head_qk ** -0.5
+
+         qkv_proj_dim_out = (dim_head_qk * heads, dim_head_qk, dim_head_v)
+
+         # splitting and merging of attention heads
+
+         self.split_q_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+         self.merge_heads = Rearrange('b h n d -> b n (h d)')
+
+         # projections
+
+         self.to_qkv = LinearNoBias(dim, sum(qkv_proj_dim_out))
+         self.to_out = LinearNoBias(dim_head_v * heads, dim)
+
+         # they add layernorms to queries, keys, and interestingly enough, values as well. first time i've seen this
+
+         self.q_norm = nn.LayerNorm(dim_head_qk, bias = False)
+         self.k_norm = nn.LayerNorm(dim_head_qk, bias = False)
+         self.v_norm = nn.LayerNorm(dim_head_v, bias = False)
+
+         # to attention bias
+
+         self.to_attn_bias = Sequential(
+             nn.RMSNorm(dim_pairwise), # replace with BatchRMSNorm once crafted
+             nn.GELU(),
+             LinearNoBias(dim_pairwise, heads),
+             Rearrange('b i j h -> b h i j')
+         )
+
+         # variables
+
+         self.qkv_dim_splits = qkv_proj_dim_out
+         self.softclamp_value = softclamp_value
+
+     def forward(
+         self,
+         x,
+         pairwise = None, # Float['b i j dp']
+         rotary_emb = None
+     ):
+
+         q, k, v = self.to_qkv(x).split(self.qkv_dim_splits, dim = -1)
+
+         # they use multi-query attention, with only 1 key / value head - pretty unconventional, but maybe enough for genomic modeling
+
+         q = self.split_q_heads(q)
+
+         q, k, v = self.q_norm(q), self.k_norm(k), self.v_norm(v)
+
+         q = q * self.scale
+
+         # maybe rotary
+
+         if exists(rotary_emb):
+             q, k = tuple(apply_rotary_pos_emb(rotary_emb, t) for t in (q, k))
+
+         # similarities
+
+         sim = einsum(q, k, 'b h i d, b j d -> b h i j')
+
+         # add attention bias + softclamping
+
+         if exists(pairwise):
+             attn_bias = self.to_attn_bias(pairwise)
+
+             assert divisible_by(sim.shape[-1], attn_bias.shape[-1])
+             expand_factor = sim.shape[-1] // attn_bias.shape[-1]
+
+             attn_bias = repeat(attn_bias, 'b h i j -> b h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
+
+             sim = softclamp(sim + attn_bias, value = self.softclamp_value)
+
+         # attention
+
+         attn = sim.softmax(dim = -1)
+
+         # aggregate
+
+         out = einsum(attn, v, 'b h i j, b j d -> b h i d')
+
+         out = self.merge_heads(out)
+         return self.to_out(out)
+
+ # single to pairwise
+
+ class SingleToPairwise(Module):
+     def __init__(
+         self,
+         dim,
+         pool_size = 16,
+         dim_pairwise = 128,
+         heads = 32
+     ):
+         super().__init__()
+         self.avg_pool = Reduce('b (n pool) d -> b n d', 'mean', pool = pool_size)
+
+         dim_inner = heads * dim_pairwise
+
+         self.split_heads = Rearrange('b n (h d) -> b n h d', h = heads)
+
+         self.to_outer_sum = Sequential(
+             LinearNoBias(dim, dim_pairwise * 2),
+             nn.GELU()
+         )
+
+         self.to_qk = LinearNoBias(dim, dim_inner * 2)
+         self.qk_to_pairwise = Linear(heads, dim_pairwise)
+
+     def forward(self, single):
+
+         single = self.avg_pool(single)
+
+         q, k = self.to_qk(single).chunk(2, dim = -1)
+         q, k = tuple(self.split_heads(t) for t in (q, k))
+
+         sim = einsum(q, k, 'b i h d, b j h d -> b i j h')
+
+         pairwise_from_sim = self.qk_to_pairwise(sim)
+
+         outer_q, outer_k = self.to_outer_sum(single).chunk(2, dim = -1)
+
+         outer_sum = einx.add('b i d, b j d -> b i j d', outer_q, outer_k)
+
+         return outer_sum + pairwise_from_sim
+
+ # pairwise attention is a single headed attention across rows, they said columns did not help
+
+ class PairwiseRowAttention(Module):
+     def __init__(
+         self,
+         dim
+     ):
+         super().__init__()
+         self.scale = dim ** -0.5
+
+         self.to_qk = LinearNoBias(dim, dim * 2)
+         self.to_v = Linear(dim, dim)
+
+     def forward(
+         self,
+         x
+     ):
+
+         q, k = self.to_qk(x).chunk(2, dim = -1)
+         v = self.to_v(x)
+
+         # similarity
+
+         sim = einsum(q, k, 'b n i d, b n j d -> b n i j') * self.scale
+
+         # attention
+
+         attn = sim.softmax(dim = -1)
+
+         # aggregate
+
+         return einsum(attn, v, 'b n i j, b n j d -> b n i d')
+
+ # feedforward for both single and pairwise
+
+ def FeedForward(
+     dim,
+     *,
+     dropout = 0.,
+     expansion_factor = 2., # they only do expansion factor of 2, no glu
+ ):
+     dim_inner = int(dim * expansion_factor)
+
+     return Sequential(
+         Linear(dim, dim_inner),
+         nn.ReLU(),
+         nn.Dropout(dropout),
+         Linear(dim_inner, dim)
+     )
+
+ # transformer
+
+ class TransformerTower(Module):
+     def __init__(
+         self,
+         dim,
+         *,
+         depth = 8,
+         heads = 8,
+         dim_head_qk = 128,
+         dim_head_v = 192,
+         dropout = 0.,
+         ff_expansion_factor = 2.,
+         max_positions = 8192,
+         dim_pairwise = None,
+         pairwise_every_num_single_blocks = 2, # how often to do a pairwise block
+         single_to_pairwise_heads = 32, # they did 32
+         attn_kwargs: dict = dict(),
+         ff_kwargs: dict = dict()
+     ):
+         super().__init__()
+         dim_pairwise = default(dim_pairwise, dim)
+
+         layers = []
+
+         self.pairwise_every = pairwise_every_num_single_blocks
+
+         self.rotary_emb = RotaryEmbedding(dim_head_qk, max_positions = max_positions)
+
+         for layer_index in range(depth):
+
+             attn = Attention(dim = dim, dim_head_qk = dim_head_qk, dim_head_v = dim_head_v, heads = heads, dim_pairwise = dim_pairwise, **attn_kwargs)
+
+             ff = FeedForward(dim = dim, expansion_factor = ff_expansion_factor, **ff_kwargs)
+
+             attn = NormWrapper(dim = dim, block = attn, dropout = dropout, sandwich = True)
+             ff = NormWrapper(dim = dim, block = ff, dropout = dropout, sandwich = True)
+
+             # maybe pairwise
+
+             single_to_pairwise, pairwise_attn, pairwise_ff = None, None, None
+
+             if divisible_by(layer_index, self.pairwise_every):
+                 single_to_pairwise = SingleToPairwise(dim = dim, dim_pairwise = dim_pairwise, heads = single_to_pairwise_heads)
+                 pairwise_attn = PairwiseRowAttention(dim_pairwise)
+                 pairwise_ff = FeedForward(dim = dim_pairwise, expansion_factor = ff_expansion_factor)
+
+                 single_to_pairwise = NormWrapper(dim = dim, block = single_to_pairwise, dropout = dropout)
+                 pairwise_attn = NormWrapper(dim = dim_pairwise, block = pairwise_attn, dropout = dropout)
+                 pairwise_ff = NormWrapper(dim = dim_pairwise, block = pairwise_ff, dropout = dropout)
+
+             # add to layers
+
+             layers.append(ModuleList([
+                 attn,
+                 ff,
+                 single_to_pairwise,
+                 pairwise_attn,
+                 pairwise_ff
+             ]))
+
+         self.layers = ModuleList(layers)
+
+     def forward(
+         self,
+         single
+     ):
+
+         seq_len = single.shape[1]
+
+         pairwise = None
+
+         rotary_emb = self.rotary_emb(seq_len)
+
+         for (
+             attn,
+             ff,
+             maybe_single_to_pair,
+             maybe_pairwise_attn,
+             maybe_pairwise_ff
+         ) in self.layers:
+
+             single = attn(single, rotary_emb = rotary_emb, pairwise = pairwise) + single
+             single = ff(single) + single
+
+             if exists(maybe_single_to_pair):
+                 pairwise = maybe_single_to_pair(single) + default(pairwise, 0.)
+                 pairwise = maybe_pairwise_attn(pairwise) + pairwise
+                 pairwise = maybe_pairwise_ff(pairwise) + pairwise
+
+         return single, pairwise
+
+ # embedding
+
+ class DNAEmbed(Module):
+     def __init__(
+         self,
+         dim,
+         dim_input = 5, # 5 basepairs
+         width = 15
+     ):
+         super().__init__()
+         assert is_odd(width)
+         self.dim_input = dim_input
+         self.conv = nn.Conv1d(dim_input, dim, width, padding = width // 2)
+         self.pointwise = nn.Conv1d(dim, dim, 1)
+
+     def forward(
+         self,
+         seq # Int['b n']
+     ):
+         onehot = F.one_hot(seq, num_classes = self.dim_input).float()
+         x = rearrange(onehot, 'b n d -> b d n')
+
+         out = self.conv(x)
+         out = out + self.pointwise(out)
+         return rearrange(out, 'b d n -> b n d')
+
+ # classes
+
+ class AlphaGenome(Module):
+     def __init__(
+         self,
+         dim = 768,
+         basepairs = 5,
+         dna_embed_width = 15,
+         dim_pairwise = None,
+         transformer_kwargs: dict = dict()
+     ):
+         super().__init__()
+         assert is_odd(dna_embed_width)
+
+         self.to_dna_embed = DNAEmbed(dim, dim_input = basepairs, width = dna_embed_width)
+
+         self.transformer = TransformerTower(
+             dim = dim,
+             dim_pairwise = dim_pairwise,
+             **transformer_kwargs
+         )
+
+     def forward(
+         self,
+         seq,
+         pairwise
+     ):
+
+         dna_embed = self.to_dna_embed(seq)
+
+         attended = self.transformer(dna_embed)
+
+         return attended
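To make the pairwise bias path in `Attention` above concrete, a small sketch (not part of the packaged files) of calling it directly with a pooled pairwise tensor; the bias is computed at the pooled resolution and broadcast back up by the `repeat` in `forward`. Shapes here are illustrative:

```python
import torch
from alphagenome.alphagenome import Attention

attn = Attention(dim = 768, dim_pairwise = 128)

single = torch.randn(1, 256, 768)        # full resolution single representation
pairwise = torch.randn(1, 16, 16, 128)   # pairwise representation pooled by 16 (256 / 16 = 16)

out = attn(single, pairwise = pairwise)  # bias is expanded 16x along both axes before softclamping
assert out.shape == (1, 256, 768)
```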
@@ -0,0 +1,61 @@
+ [project]
+ name = "alphagenome-pytorch"
+ version = "0.0.1"
+ description = "AlphaGenome"
+ authors = [
+     { name = "Phil Wang", email = "lucidrains@gmail.com" }
+ ]
+ readme = "README.md"
+ requires-python = ">= 3.9"
+ license = { file = "LICENSE" }
+ keywords = [
+     'artificial intelligence',
+     'deep learning',
+     'transformers',
+     'attention mechanism',
+     'genomics',
+     'splicing',
+ ]
+
+ classifiers = [
+     'Development Status :: 4 - Beta',
+     'Intended Audience :: Developers',
+     'Topic :: Scientific/Engineering :: Artificial Intelligence',
+     'License :: OSI Approved :: MIT License',
+     'Programming Language :: Python :: 3.9',
+ ]
+
+ dependencies = [
+     "einx>=0.3.0",
+     "einops>=0.8.0",
+     "torch>=2.4",
+ ]
+
+ [project.urls]
+ Homepage = "https://pypi.org/project/alphagenome-pytorch/"
+ Repository = "https://github.com/lucidrains/alphagenome"
+
+ [project.optional-dependencies]
+ examples = []
+ test = [
+     "pytest"
+ ]
+
+ [tool.pytest.ini_options]
+ pythonpath = [
+     "."
+ ]
+
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [tool.rye]
+ managed = true
+ dev-dependencies = []
+
+ [tool.hatch.metadata]
+ allow-direct-references = true
+
+ [tool.hatch.build.targets.wheel]
+ packages = ["alphagenome"]
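Once the package is installed, the metadata declared in this `[project]` table can be checked at runtime with the standard library; a small sketch (not part of the packaged files):

```python
# introspect the installed distribution (values come from the [project] table above)
from importlib.metadata import metadata, version

print(version("alphagenome-pytorch"))                        # 0.0.1
print(metadata("alphagenome-pytorch")["Requires-Python"])    # >=3.9
```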
@@ -0,0 +1,14 @@
+ import pytest
+ import torch
+ from alphagenome.alphagenome import TransformerTower
+
+ def test_attention():
+
+     transformer = TransformerTower(dim = 768, dim_pairwise = 128)
+
+     single = torch.randn(2, 512, 768)
+
+     single_repr, pairwise_repr = transformer(single)
+
+     assert single_repr.shape == (2, 512, 768)
+     assert pairwise_repr.shape == (2, 512 // 16, 512 // 16, 128)
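A further shape check in the same style, for the single-to-pairwise projection used inside the tower; a sketch (not part of the packaged files), assuming the default pool size of 16:

```python
import torch
from alphagenome.alphagenome import SingleToPairwise

def test_single_to_pairwise():

    to_pairwise = SingleToPairwise(dim = 768, dim_pairwise = 128)

    single = torch.randn(2, 512, 768)

    pairwise = to_pairwise(single)

    assert pairwise.shape == (2, 512 // 16, 512 // 16, 128)
```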