boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/data/msa/__init__.py
File without changes
boltz/data/msa/mmseqs2.py
ADDED
@@ -0,0 +1,235 @@
# From https://github.com/sokrypton/ColabFold/blob/main/colabfold/colabfold.py

import logging
import os
import random
import tarfile
import time
from typing import Union

import requests
from tqdm import tqdm

logger = logging.getLogger(__name__)

TQDM_BAR_FORMAT = (
    "{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]"
)


def run_mmseqs2(  # noqa: PLR0912, D103, C901, PLR0915
    x: Union[str, list[str]],
    prefix: str = "tmp",
    use_env: bool = True,
    use_filter: bool = True,
    use_pairing: bool = False,
    pairing_strategy: str = "greedy",
    host_url: str = "https://api.colabfold.com",
) -> list[str]:
    submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"

    # Set the header agent as boltz
    headers = {}
    headers["User-Agent"] = "boltz"

    def submit(seqs, mode, N=101):
        n, query = N, ""
        for seq in seqs:
            query += f">{n}\n{seq}\n"
            n += 1

        error_count = 0
        while True:
            try:
                # https://requests.readthedocs.io/en/latest/user/advanced/#advanced
                # "good practice to set connect timeouts to slightly larger than a multiple of 3"
                res = requests.post(
                    f"{host_url}/{submission_endpoint}",
                    data={"q": query, "mode": mode},
                    timeout=6.02,
                    headers=headers,
                )
            except Exception as e:
                error_count += 1
                logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)")
                logger.warning(f"Error: {e}")
                if error_count > 5:
                    raise Exception("Too many failed attempts for the MSA generation request.")
                time.sleep(5)
            else:
                break

        try:
            out = res.json()
        except ValueError:
            logger.error(f"Server didn't reply with json: {res.text}")
            out = {"status": "ERROR"}
        return out

    def status(ID):
        error_count = 0
        while True:
            try:
                res = requests.get(
                    f"{host_url}/ticket/{ID}", timeout=6.02, headers=headers
                )
            except Exception as e:
                error_count += 1
                logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)")
                logger.warning(f"Error: {e}")
                if error_count > 5:
                    raise Exception("Too many failed attempts for the MSA generation request.")
                time.sleep(5)
            else:
                break
        try:
            out = res.json()
        except ValueError:
            logger.error(f"Server didn't reply with json: {res.text}")
            out = {"status": "ERROR"}
        return out

    def download(ID, path):
        error_count = 0
        while True:
            try:
                res = requests.get(
                    f"{host_url}/result/download/{ID}", timeout=6.02, headers=headers
                )
            except Exception as e:
                error_count += 1
                logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)")
                logger.warning(f"Error: {e}")
                if error_count > 5:
                    raise Exception("Too many failed attempts for the MSA generation request.")
                time.sleep(5)
            else:
                break
        with open(path, "wb") as out:
            out.write(res.content)

    # Process input x
    seqs = [x] if isinstance(x, str) else x

    # Set up the search mode
    if use_filter:
        mode = "env" if use_env else "all"
    else:
        mode = "env-nofilter" if use_env else "nofilter"

    if use_pairing:
        mode = ""
        # greedy is default, complete was the previous behavior
        if pairing_strategy == "greedy":
            mode = "pairgreedy"
        elif pairing_strategy == "complete":
            mode = "paircomplete"
        if use_env:
            mode = mode + "-env"

    # Define the output path
    path = f"{prefix}_{mode}"
    if not os.path.isdir(path):
        os.mkdir(path)

    # Call the MMseqs2 API
    tar_gz_file = f"{path}/out.tar.gz"
    N, REDO = 101, True

    # Deduplicate and keep track of order
    seqs_unique = []
    # TODO: this might be slow for large sets
    [seqs_unique.append(x) for x in seqs if x not in seqs_unique]
    Ms = [N + seqs_unique.index(seq) for seq in seqs]
    # Let's do it!
    if not os.path.isfile(tar_gz_file):
        TIME_ESTIMATE = 150 * len(seqs_unique)
        with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
            while REDO:
                pbar.set_description("SUBMIT")

                # Resubmit job until it goes through
                out = submit(seqs_unique, mode, N)
                while out["status"] in ["UNKNOWN", "RATELIMIT"]:
                    sleep_time = 5 + random.randint(0, 5)
                    logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
                    # Resubmit
                    time.sleep(sleep_time)
                    out = submit(seqs_unique, mode, N)

                if out["status"] == "ERROR":
                    msg = (
                        "MMseqs2 API is giving errors. Please confirm your "
                        "input is a valid protein sequence. If the error persists, "
                        "please try again an hour later."
                    )
                    raise Exception(msg)

                if out["status"] == "MAINTENANCE":
                    msg = (
                        "MMseqs2 API is undergoing maintenance. "
                        "Please try again in a few minutes."
                    )
                    raise Exception(msg)

                # Wait for the job to finish
                ID, TIME = out["id"], 0
                pbar.set_description(out["status"])
                while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
                    t = 5 + random.randint(0, 5)
                    logger.error(f"Sleeping for {t}s. Reason: {out['status']}")
                    time.sleep(t)
                    out = status(ID)
                    pbar.set_description(out["status"])
                    if out["status"] == "RUNNING":
                        TIME += t
                        pbar.update(n=t)

                if out["status"] == "COMPLETE":
                    if TIME < TIME_ESTIMATE:
                        pbar.update(n=(TIME_ESTIMATE - TIME))
                    REDO = False

                if out["status"] == "ERROR":
                    REDO = False
                    msg = (
                        "MMseqs2 API is giving errors. Please confirm your "
                        "input is a valid protein sequence. If the error persists, "
                        "please try again an hour later."
                    )
                    raise Exception(msg)

            # Download results
            download(ID, tar_gz_file)

    # Prep the list of a3m files
    if use_pairing:
        a3m_files = [f"{path}/pair.a3m"]
    else:
        a3m_files = [f"{path}/uniref.a3m"]
        if use_env:
            a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")

    # Extract the a3m files
    if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
        with tarfile.open(tar_gz_file) as tar_gz:
            tar_gz.extractall(path)

    # Gather the a3m lines
    a3m_lines = {}
    for a3m_file in a3m_files:
        update_M, M = True, None
        for line in open(a3m_file, "r"):
            if len(line) > 0:
                if "\x00" in line:
                    line = line.replace("\x00", "")
                    update_M = True
                if line.startswith(">") and update_M:
                    M = int(line[1:].rstrip())
                    update_M = False
                if M not in a3m_lines:
                    a3m_lines[M] = []
                a3m_lines[M].append(line)

    a3m_lines = ["".join(a3m_lines[n]) for n in Ms]
    return a3m_lines
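Reviewer note: run_mmseqs2 returns one A3M-formatted string per input sequence, in input order. A minimal usage sketch follows; the sequences and file names are illustrative, and the call queries the public ColabFold server, so it needs network access:

# Illustrative only, not part of the package.
seqs = [
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",  # hypothetical query sequences
    "MADEEKLPPGWEKRMSRSSGRVYYFNHITNASQWERPSG",
]
a3m_strings = run_mmseqs2(seqs, prefix="example", use_env=True, use_pairing=False)
for i, a3m in enumerate(a3m_strings):
    with open(f"example_{i}.a3m", "w") as f:
        f.write(a3m)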
boltz/data/pad.py
ADDED
@@ -0,0 +1,84 @@
import torch
from torch import Tensor
from torch.nn.functional import pad


def pad_dim(data: Tensor, dim: int, pad_len: float, value: float = 0) -> Tensor:
    """Pad a tensor along a given dimension.

    Parameters
    ----------
    data : Tensor
        The input tensor.
    dim : int
        The dimension to pad.
    pad_len : float
        The padding length.
    value : float, optional
        The value to pad with.

    Returns
    -------
    Tensor
        The padded tensor.

    """
    if pad_len == 0:
        return data

    total_dims = len(data.shape)
    padding = [0] * (2 * (total_dims - dim))
    padding[2 * (total_dims - 1 - dim) + 1] = pad_len
    return pad(data, tuple(padding), value=value)


def pad_to_max(data: list[Tensor], value: float = 0) -> tuple[Tensor, Tensor]:
    """Pad the data in all dimensions to the maximum found.

    Parameters
    ----------
    data : list[Tensor]
        The list of tensors to pad.
    value : float
        The value to use for padding.

    Returns
    -------
    Tensor
        The padded tensor.
    Tensor
        The padding mask.

    """
    if isinstance(data[0], str):
        return data, 0

    # Check if all have the same shape
    if all(d.shape == data[0].shape for d in data):
        return torch.stack(data, dim=0), 0

    # Get the maximum in each dimension
    num_dims = len(data[0].shape)
    max_dims = [max(d.shape[i] for d in data) for i in range(num_dims)]

    # Get the padding lengths
    pad_lengths = []
    for d in data:
        dims = []
        for i in range(num_dims):
            dims.append(0)
            dims.append(max_dims[num_dims - i - 1] - d.shape[num_dims - i - 1])
        pad_lengths.append(dims)

    # Pad the data
    padding = [
        pad(torch.ones_like(d), pad_len, value=0)
        for d, pad_len in zip(data, pad_lengths)
    ]
    data = [pad(d, pad_len, value=value) for d, pad_len in zip(data, pad_lengths)]

    # Stack the data
    padding = torch.stack(padding, dim=0)
    data = torch.stack(data, dim=0)

    return data, padding
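Reviewer note: pad_to_max right-pads each tensor to the per-dimension maximum across the batch and returns a mask of the same shape marking real entries. A small behavioral sketch under assumed shapes:

# Illustrative only, not part of the package.
import torch

a = torch.ones(3, 4)
b = torch.ones(5, 2)
batch, mask = pad_to_max([a, b], value=0)
assert batch.shape == (2, 5, 4)  # each dim padded to the batch maximum
assert mask.shape == (2, 5, 4)   # 1 for real entries, 0 where padding was added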
boltz/data/parse/__init__.py
File without changes
boltz/data/parse/a3m.py
ADDED
@@ -0,0 +1,134 @@
import gzip
from pathlib import Path
from typing import Optional, TextIO

import numpy as np

from boltz.data import const
from boltz.data.types import MSA, MSADeletion, MSAResidue, MSASequence


def _parse_a3m(  # noqa: C901
    lines: TextIO,
    taxonomy: Optional[dict[str, str]],
    max_seqs: Optional[int] = None,
) -> MSA:
    """Process an MSA file.

    Parameters
    ----------
    lines : TextIO
        The lines of the MSA file.
    taxonomy : dict[str, str]
        The taxonomy database, if available.
    max_seqs : int, optional
        The maximum number of sequences.

    Returns
    -------
    MSA
        The MSA object.

    """
    visited = set()
    sequences = []
    deletions = []
    residues = []

    seq_idx = 0
    for line in lines:
        line: str
        line = line.strip()  # noqa: PLW2901
        if not line or line.startswith("#"):
            continue

        # Get taxonomy, if annotated
        if line.startswith(">"):
            header = line.split()[0]
            if taxonomy and header.startswith(">UniRef100"):
                uniref_id = header.split("_")[1]
                taxonomy_id = taxonomy.get(uniref_id)
                if taxonomy_id is None:
                    taxonomy_id = -1
            else:
                taxonomy_id = -1
            continue

        # Skip if duplicate sequence
        str_seq = line.replace("-", "").upper()
        if str_seq not in visited:
            visited.add(str_seq)
        else:
            continue

        # Process sequence
        residue = []
        deletion = []
        count = 0
        res_idx = 0
        for c in line:
            if c != "-" and c.islower():
                count += 1
                continue
            token = const.prot_letter_to_token[c]
            token = const.token_ids[token]
            residue.append(token)
            if count > 0:
                deletion.append((res_idx, count))
                count = 0
            res_idx += 1

        res_start = len(residues)
        res_end = res_start + len(residue)

        del_start = len(deletions)
        del_end = del_start + len(deletion)

        sequences.append((seq_idx, taxonomy_id, res_start, res_end, del_start, del_end))
        residues.extend(residue)
        deletions.extend(deletion)

        seq_idx += 1
        if (max_seqs is not None) and (seq_idx >= max_seqs):
            break

    # Create the MSA object
    msa = MSA(
        residues=np.array(residues, dtype=MSAResidue),
        deletions=np.array(deletions, dtype=MSADeletion),
        sequences=np.array(sequences, dtype=MSASequence),
    )
    return msa


def parse_a3m(
    path: Path,
    taxonomy: Optional[dict[str, str]],
    max_seqs: Optional[int] = None,
) -> MSA:
    """Process an A3M file.

    Parameters
    ----------
    path : Path
        The path to the a3m(.gz) file.
    taxonomy : dict[str, str], optional
        The taxonomy database, if available.
    max_seqs : int, optional
        The maximum number of sequences.

    Returns
    -------
    MSA
        The MSA object.

    """
    # Read the file
    if path.suffix == ".gz":
        with gzip.open(str(path), "rt") as f:
            msa = _parse_a3m(f, taxonomy, max_seqs)
    else:
        with path.open("r") as f:
            msa = _parse_a3m(f, taxonomy, max_seqs)

    return msa
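Reviewer note: parse_a3m handles both plain and gzipped A3M files, deduplicates sequences as it reads, and caps the count at max_seqs. A hedged usage sketch; the path is hypothetical, and taxonomy may be None when no taxonomy database is loaded:

# Illustrative only, not part of the package.
from pathlib import Path

msa = parse_a3m(Path("example.a3m.gz"), taxonomy=None, max_seqs=4096)
print(len(msa.sequences), "unique sequences kept")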
boltz/data/parse/csv.py
ADDED
@@ -0,0 +1,100 @@
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd

from boltz.data import const
from boltz.data.types import MSA, MSADeletion, MSAResidue, MSASequence


def parse_csv(
    path: Path,
    max_seqs: Optional[int] = None,
) -> MSA:
    """Process a CSV MSA file.

    Parameters
    ----------
    path : Path
        The path to the CSV file.
    max_seqs : int, optional
        The maximum number of sequences.

    Returns
    -------
    MSA
        The MSA object.

    """
    # Read the file
    data = pd.read_csv(path)

    # Check the columns
    if tuple(sorted(data.columns)) != ("key", "sequence"):
        msg = "Invalid CSV format, expected columns: ['sequence', 'key']"
        raise ValueError(msg)

    # Initialize the accumulators
    visited = set()
    sequences = []
    deletions = []
    residues = []

    seq_idx = 0
    for line, key in zip(data["sequence"], data["key"]):
        line: str
        line = line.strip()  # noqa: PLW2901
        if not line:
            continue

        # Get taxonomy, if annotated
        taxonomy_id = -1
        if (str(key) != "nan") and (key is not None) and (key != ""):
            taxonomy_id = key

        # Skip if duplicate sequence
        str_seq = line.replace("-", "").upper()
        if str_seq not in visited:
            visited.add(str_seq)
        else:
            continue

        # Process sequence
        residue = []
        deletion = []
        count = 0
        res_idx = 0
        for c in line:
            if c != "-" and c.islower():
                count += 1
                continue
            token = const.prot_letter_to_token[c]
            token = const.token_ids[token]
            residue.append(token)
            if count > 0:
                deletion.append((res_idx, count))
                count = 0
            res_idx += 1

        res_start = len(residues)
        res_end = res_start + len(residue)

        del_start = len(deletions)
        del_end = del_start + len(deletion)

        sequences.append((seq_idx, taxonomy_id, res_start, res_end, del_start, del_end))
        residues.extend(residue)
        deletions.extend(deletion)

        seq_idx += 1
        if (max_seqs is not None) and (seq_idx >= max_seqs):
            break

    # Create the MSA object
    msa = MSA(
        residues=np.array(residues, dtype=MSAResidue),
        deletions=np.array(deletions, dtype=MSADeletion),
        sequences=np.array(sequences, dtype=MSASequence),
    )
    return msa
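Reviewer note: parse_csv expects exactly the columns "sequence" and "key", with key serving as an optional per-sequence pairing/taxonomy label (-1 when empty). A hypothetical input and call:

# Illustrative only, not part of the package. example_msa.csv might contain:
#   sequence,key
#   MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ,9606
#   MKTAYIAKQRQISFVKSHFSRQLEERAGLIEVQ,10090
from pathlib import Path

msa = parse_csv(Path("example_msa.csv"), max_seqs=2048)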
boltz/data/parse/fasta.py
ADDED
@@ -0,0 +1,138 @@
from collections.abc import Mapping
from pathlib import Path

from Bio import SeqIO
from rdkit.Chem.rdchem import Mol

from boltz.data.parse.yaml import parse_boltz_schema
from boltz.data.types import Target


def parse_fasta(  # noqa: C901, PLR0912
    path: Path,
    ccd: Mapping[str, Mol],
    mol_dir: Path,
    boltz2: bool = False,
) -> Target:
    """Parse a FASTA file.

    The name of the FASTA file is used as the name of this job.
    We rely on the FASTA record id to determine the entity type.

    >CHAIN_ID|ENTITY_TYPE|MSA_ID
    SEQUENCE
    >CHAIN_ID|ENTITY_TYPE|MSA_ID
    ...

    Where ENTITY_TYPE is one of protein, rna, dna, ccd, or smiles,
    and CHAIN_ID is the chain identifier, which should be unique.
    The MSA_ID is optional and should only be used on proteins.

    Parameters
    ----------
    path : Path
        Path to the FASTA file.
    ccd : Mapping[str, Mol]
        Dictionary of CCD components.
    mol_dir : Path
        Path to the directory containing the molecules.
    boltz2 : bool
        Whether to parse the input for Boltz2.

    Returns
    -------
    Target
        The parsed target.

    """
    # Read the FASTA file
    with path.open("r") as f:
        records = list(SeqIO.parse(f, "fasta"))

    # Make sure all records have a chain id and entity type
    for seq_record in records:
        if "|" not in seq_record.id:
            msg = f"Invalid record id: {seq_record.id}"
            raise ValueError(msg)

        header = seq_record.id.split("|")
        assert len(header) >= 2, f"Invalid record id: {seq_record.id}"

        chain_id, entity_type = header[:2]
        if entity_type.lower() not in {"protein", "dna", "rna", "ccd", "smiles"}:
            msg = f"Invalid entity type: {entity_type}"
            raise ValueError(msg)
        if chain_id == "":
            msg = "Empty chain id in input fasta!"
            raise ValueError(msg)
        if entity_type == "":
            msg = "Empty entity type in input fasta!"
            raise ValueError(msg)

    # Convert to the YAML schema format
    sequences = []
    for seq_record in records:
        # Get the chain id, entity type, and sequence
        header = seq_record.id.split("|")
        chain_id, entity_type = header[:2]
        if len(header) == 3 and header[2] != "":
            assert (
                entity_type.lower() == "protein"
            ), "MSA_ID is only allowed for proteins"
            msa_id = header[2]
        else:
            msa_id = None

        entity_type = entity_type.upper()
        seq = str(seq_record.seq)

        if entity_type == "PROTEIN":
            molecule = {
                "protein": {
                    "id": chain_id,
                    "sequence": seq,
                    "modifications": [],
                    "msa": msa_id,
                },
            }
        elif entity_type == "RNA":
            molecule = {
                "rna": {
                    "id": chain_id,
                    "sequence": seq,
                    "modifications": [],
                },
            }
        elif entity_type == "DNA":
            molecule = {
                "dna": {
                    "id": chain_id,
                    "sequence": seq,
                    "modifications": [],
                },
            }
        elif entity_type == "CCD":
            molecule = {
                "ligand": {
                    "id": chain_id,
                    "ccd": seq,
                },
            }
        elif entity_type == "SMILES":
            molecule = {
                "ligand": {
                    "id": chain_id,
                    "smiles": seq,
                },
            }

        sequences.append(molecule)

    data = {
        "sequences": sequences,
        "bonds": [],
        "version": 1,
    }

    name = path.stem
    return parse_boltz_schema(name, data, ccd, mol_dir, boltz2)
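Reviewer note: a hypothetical input and call for parse_fasta, following the header grammar in its docstring. The ccd mapping and mol_dir are assumed to come from the Boltz model cache and are not constructed here:

# Illustrative only, not part of the package. example.fasta might contain:
#   >A|protein|
#   MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ
#   >B|smiles|
#   CC(=O)OC1=CC=CC=C1C(=O)O
from pathlib import Path

# ccd is assumed to be a preloaded Mapping[str, Mol]; mol_dir is a placeholder path.
target = parse_fasta(Path("example.fasta"), ccd=ccd, mol_dir=Path("mols"), boltz2=True)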