boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/data/feature/featurizer.py
@@ -0,0 +1,1230 @@
import math
import random
from typing import Optional

import numba
import numpy as np
import numpy.typing as npt
import torch
from numba import types
from torch import Tensor, from_numpy
from torch.nn.functional import one_hot

from boltz.data import const
from boltz.data.feature.symmetry import (
    get_amino_acids_symmetries,
    get_chain_symmetries,
    get_ligand_symmetries,
)
from boltz.data.pad import pad_dim
from boltz.data.types import (
    MSA,
    MSADeletion,
    MSAResidue,
    MSASequence,
    Tokenized,
)
from boltz.model.modules.utils import center_random_augmentation

####################################################################################################
# HELPERS
####################################################################################################


def compute_frames_nonpolymer(
    data: Tokenized,
    coords,
    resolved_mask,
    atom_to_token,
    frame_data: list,
    resolved_frame_data: list,
) -> tuple[list, list]:
    """Get the frames for non-polymer tokens.

    Parameters
    ----------
    data : Tokenized
        The tokenized data.
    frame_data : list
        The frame data.
    resolved_frame_data : list
        The resolved frame data.

    Returns
    -------
    tuple[list, list]
        The frame data and resolved frame data.

    """
    frame_data = np.array(frame_data)
    resolved_frame_data = np.array(resolved_frame_data)
    asym_id_token = data.tokens["asym_id"]
    asym_id_atom = data.tokens["asym_id"][atom_to_token]
    token_idx = 0
    atom_idx = 0
    for id in np.unique(data.tokens["asym_id"]):
        mask_chain_token = asym_id_token == id
        mask_chain_atom = asym_id_atom == id
        num_tokens = mask_chain_token.sum()
        num_atoms = mask_chain_atom.sum()
        if (
            data.tokens[token_idx]["mol_type"] != const.chain_type_ids["NONPOLYMER"]
            or num_atoms < 3
        ):
            token_idx += num_tokens
            atom_idx += num_atoms
            continue
        dist_mat = (
            (
                coords.reshape(-1, 3)[mask_chain_atom][:, None, :]
                - coords.reshape(-1, 3)[mask_chain_atom][None, :, :]
            )
            ** 2
        ).sum(-1) ** 0.5
        resolved_pair = 1 - (
            resolved_mask[mask_chain_atom][None, :]
            * resolved_mask[mask_chain_atom][:, None]
        ).astype(np.float32)
        resolved_pair[resolved_pair == 1] = math.inf
        indices = np.argsort(dist_mat + resolved_pair, axis=1)
        frames = (
            np.concatenate(
                [
                    indices[:, 1:2],
                    indices[:, 0:1],
                    indices[:, 2:3],
                ],
                axis=1,
            )
            + atom_idx
        )
        frame_data[token_idx : token_idx + num_atoms, :] = frames
        resolved_frame_data[token_idx : token_idx + num_atoms] = resolved_mask[
            frames
        ].all(axis=1)
        token_idx += num_tokens
        atom_idx += num_atoms
    frames_expanded = coords.reshape(-1, 3)[frame_data]

    mask_collinear = compute_collinear_mask(
        frames_expanded[:, 1] - frames_expanded[:, 0],
        frames_expanded[:, 1] - frames_expanded[:, 2],
    )
    return frame_data, resolved_frame_data & mask_collinear


def compute_collinear_mask(v1, v2):
    norm1 = np.linalg.norm(v1, axis=1, keepdims=True)
    norm2 = np.linalg.norm(v2, axis=1, keepdims=True)
    v1 = v1 / (norm1 + 1e-6)
    v2 = v2 / (norm2 + 1e-6)
    mask_angle = np.abs(np.sum(v1 * v2, axis=1)) < 0.9063
    mask_overlap1 = norm1.reshape(-1) > 1e-2
    mask_overlap2 = norm2.reshape(-1) > 1e-2
    return mask_angle & mask_overlap1 & mask_overlap2


def dummy_msa(residues: np.ndarray) -> MSA:
    """Create a dummy MSA for a chain.

    Parameters
    ----------
    residues : np.ndarray
        The residues for the chain.

    Returns
    -------
    MSA
        The dummy MSA.

    """
    residues = [res["res_type"] for res in residues]
    deletions = []
    sequences = [(0, -1, 0, len(residues), 0, 0)]
    return MSA(
        residues=np.array(residues, dtype=MSAResidue),
        deletions=np.array(deletions, dtype=MSADeletion),
        sequences=np.array(sequences, dtype=MSASequence),
    )


def construct_paired_msa(  # noqa: C901, PLR0915, PLR0912
    data: Tokenized,
    max_seqs: int,
    max_pairs: int = 8192,
    max_total: int = 16384,
    random_subset: bool = False,
) -> tuple[Tensor, Tensor, Tensor]:
    """Pair the MSA data.

    Parameters
    ----------
    data : Input
        The input data.

    Returns
    -------
    Tensor
        The MSA data.
    Tensor
        The deletion data.
    Tensor
        Mask indicating paired sequences.

    """
    # Get unique chains (ensuring monotonicity in the order)
    assert np.all(np.diff(data.tokens["asym_id"], n=1) >= 0)
    chain_ids = np.unique(data.tokens["asym_id"])

    # Get relevant MSA, and create a dummy for chains without
    msa = {k: data.msa[k] for k in chain_ids if k in data.msa}
    for chain_id in chain_ids:
        if chain_id not in msa:
            chain = data.structure.chains[chain_id]
            res_start = chain["res_idx"]
            res_end = res_start + chain["res_num"]
            residues = data.structure.residues[res_start:res_end]
            msa[chain_id] = dummy_msa(residues)

    # Map taxonomies to (chain_id, seq_idx)
    taxonomy_map: dict[str, list] = {}
    for chain_id, chain_msa in msa.items():
        sequences = chain_msa.sequences
        sequences = sequences[sequences["taxonomy"] != -1]
        for sequence in sequences:
            seq_idx = sequence["seq_idx"]
            taxon = sequence["taxonomy"]
            taxonomy_map.setdefault(taxon, []).append((chain_id, seq_idx))

    # Remove taxonomies with only one sequence and sort by the
    # number of chain_id present in each of the taxonomies
    taxonomy_map = {k: v for k, v in taxonomy_map.items() if len(v) > 1}
    taxonomy_map = sorted(
        taxonomy_map.items(),
        key=lambda x: len({c for c, _ in x[1]}),
        reverse=True,
    )

    # Keep track of the sequences available per chain, keeping the original
    # order of the sequences in the MSA to favor the best matching sequences
    visited = {(c, s) for c, items in taxonomy_map for s in items}
    available = {}
    for c in chain_ids:
        available[c] = [
            i for i in range(1, len(msa[c].sequences)) if (c, i) not in visited
        ]

    # Create sequence pairs
    is_paired = []
    pairing = []

    # Start with the first sequence for each chain
    is_paired.append({c: 1 for c in chain_ids})
    pairing.append({c: 0 for c in chain_ids})

    # Then add up to 8191 paired rows
    for _, pairs in taxonomy_map:
        # Group occurences by chain_id in case we have multiple
        # sequences from the same chain and same taxonomy
        chain_occurences = {}
        for chain_id, seq_idx in pairs:
            chain_occurences.setdefault(chain_id, []).append(seq_idx)

        # We create as many pairings as the maximum number of occurences
        max_occurences = max(len(v) for v in chain_occurences.values())
        for i in range(max_occurences):
            row_pairing = {}
            row_is_paired = {}

            # Add the chains present in the taxonomy
            for chain_id, seq_idxs in chain_occurences.items():
                # Roll over the sequence index to maximize diversity
                idx = i % len(seq_idxs)
                seq_idx = seq_idxs[idx]

                # Add the sequence to the pairing
                row_pairing[chain_id] = seq_idx
                row_is_paired[chain_id] = 1

            # Add any missing chains
            for chain_id in chain_ids:
                if chain_id not in row_pairing:
                    row_is_paired[chain_id] = 0
                    if available[chain_id]:
                        # Add the next available sequence
                        seq_idx = available[chain_id].pop(0)
                        row_pairing[chain_id] = seq_idx
                    else:
                        # No more sequences available, we place a gap
                        row_pairing[chain_id] = -1

            pairing.append(row_pairing)
            is_paired.append(row_is_paired)

            # Break if we have enough pairs
            if len(pairing) >= max_pairs:
                break

        # Break if we have enough pairs
        if len(pairing) >= max_pairs:
            break

    # Now add up to 16384 unpaired rows total
    max_left = max(len(v) for v in available.values())
    for _ in range(min(max_total - len(pairing), max_left)):
        row_pairing = {}
        row_is_paired = {}
        for chain_id in chain_ids:
            row_is_paired[chain_id] = 0
            if available[chain_id]:
                # Add the next available sequence
                seq_idx = available[chain_id].pop(0)
                row_pairing[chain_id] = seq_idx
            else:
                # No more sequences available, we place a gap
                row_pairing[chain_id] = -1

        pairing.append(row_pairing)
        is_paired.append(row_is_paired)

        # Break if we have enough sequences
        if len(pairing) >= max_total:
            break

    # Randomly sample a subset of the pairs
    # ensuring the first row is always present
    if random_subset:
        num_seqs = len(pairing)
        if num_seqs > max_seqs:
            indices = np.random.choice(
                list(range(1, num_seqs)), size=max_seqs - 1, replace=False
            )  # noqa: NPY002
            pairing = [pairing[0]] + [pairing[i] for i in indices]
            is_paired = [is_paired[0]] + [is_paired[i] for i in indices]
    else:
        # Deterministic downsample to max_seqs
        pairing = pairing[:max_seqs]
        is_paired = is_paired[:max_seqs]

    # Map (chain_id, seq_idx, res_idx) to deletion
    deletions = {}
    for chain_id, chain_msa in msa.items():
        chain_deletions = chain_msa.deletions
        for sequence in chain_msa.sequences:
            del_start = sequence["del_start"]
            del_end = sequence["del_end"]
            chain_deletions = chain_msa.deletions[del_start:del_end]
            for deletion_data in chain_deletions:
                seq_idx = sequence["seq_idx"]
                res_idx = deletion_data["res_idx"]
                deletion = deletion_data["deletion"]
                deletions[(chain_id, seq_idx, res_idx)] = deletion

    # Add all the token MSA data
    msa_data, del_data, paired_data = prepare_msa_arrays(
        data.tokens, pairing, is_paired, deletions, msa
    )

    msa_data = torch.tensor(msa_data, dtype=torch.long)
    del_data = torch.tensor(del_data, dtype=torch.float)
    paired_data = torch.tensor(paired_data, dtype=torch.float)

    return msa_data, del_data, paired_data


def prepare_msa_arrays(
    tokens,
    pairing: list[dict[int, int]],
    is_paired: list[dict[int, int]],
    deletions: dict[tuple[int, int, int], int],
    msa: dict[int, MSA],
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
    """Reshape data to play nicely with numba jit."""
    token_asym_ids_arr = np.array([t["asym_id"] for t in tokens], dtype=np.int64)
    token_res_idxs_arr = np.array([t["res_idx"] for t in tokens], dtype=np.int64)

    chain_ids = sorted(msa.keys())

    # chain_ids are not necessarily contiguous (e.g. they might be 0, 24, 25).
    # This allows us to look up a chain_id by it's index in the chain_ids list.
    chain_id_to_idx = {chain_id: i for i, chain_id in enumerate(chain_ids)}
    token_asym_ids_idx_arr = np.array(
        [chain_id_to_idx[asym_id] for asym_id in token_asym_ids_arr], dtype=np.int64
    )

    pairing_arr = np.zeros((len(pairing), len(chain_ids)), dtype=np.int64)
    is_paired_arr = np.zeros((len(is_paired), len(chain_ids)), dtype=np.int64)

    for i, row_pairing in enumerate(pairing):
        for chain_id in chain_ids:
            pairing_arr[i, chain_id_to_idx[chain_id]] = row_pairing[chain_id]

    for i, row_is_paired in enumerate(is_paired):
        for chain_id in chain_ids:
            is_paired_arr[i, chain_id_to_idx[chain_id]] = row_is_paired[chain_id]

    max_seq_len = max(len(msa[chain_id].sequences) for chain_id in chain_ids)

    # we want res_start from sequences
    msa_sequences = np.full((len(chain_ids), max_seq_len), -1, dtype=np.int64)
    for chain_id in chain_ids:
        for i, seq in enumerate(msa[chain_id].sequences):
            msa_sequences[chain_id_to_idx[chain_id], i] = seq["res_start"]

    max_residues_len = max(len(msa[chain_id].residues) for chain_id in chain_ids)
    msa_residues = np.full((len(chain_ids), max_residues_len), -1, dtype=np.int64)
    for chain_id in chain_ids:
        residues = msa[chain_id].residues.astype(np.int64)
        idxs = np.arange(len(residues))
        chain_idx = chain_id_to_idx[chain_id]
        msa_residues[chain_idx, idxs] = residues

    deletions_dict = numba.typed.Dict.empty(
        key_type=numba.types.Tuple(
            [numba.types.int64, numba.types.int64, numba.types.int64]
        ),
        value_type=numba.types.int64,
    )
    deletions_dict.update(deletions)

    return _prepare_msa_arrays_inner(
        token_asym_ids_arr,
        token_res_idxs_arr,
        token_asym_ids_idx_arr,
        pairing_arr,
        is_paired_arr,
        deletions_dict,
        msa_sequences,
        msa_residues,
        const.token_ids["-"],
    )


deletions_dict_type = types.DictType(types.UniTuple(types.int64, 3), types.int64)


@numba.njit(
    [
        types.Tuple(
            (
                types.int64[:, ::1],  # msa_data
                types.int64[:, ::1],  # del_data
                types.int64[:, ::1],  # paired_data
            )
        )(
            types.int64[::1],  # token_asym_ids
            types.int64[::1],  # token_res_idxs
            types.int64[::1],  # token_asym_ids_idx
            types.int64[:, ::1],  # pairing
            types.int64[:, ::1],  # is_paired
            deletions_dict_type,  # deletions
            types.int64[:, ::1],  # msa_sequences
            types.int64[:, ::1],  # msa_residues
            types.int64,  # gap_token
        )
    ],
    cache=True,
)
def _prepare_msa_arrays_inner(
    token_asym_ids: npt.NDArray[np.int64],
    token_res_idxs: npt.NDArray[np.int64],
    token_asym_ids_idx: npt.NDArray[np.int64],
    pairing: npt.NDArray[np.int64],
    is_paired: npt.NDArray[np.int64],
    deletions: dict[tuple[int, int, int], int],
    msa_sequences: npt.NDArray[np.int64],
    msa_residues: npt.NDArray[np.int64],
    gap_token: int,
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
    n_tokens = len(token_asym_ids)
    n_pairs = len(pairing)
    msa_data = np.full((n_tokens, n_pairs), gap_token, dtype=np.int64)
    paired_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)
    del_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)

    # Add all the token MSA data
    for token_idx in range(n_tokens):
        chain_id_idx = token_asym_ids_idx[token_idx]
        chain_id = token_asym_ids[token_idx]
        res_idx = token_res_idxs[token_idx]

        for pair_idx in range(n_pairs):
            seq_idx = pairing[pair_idx, chain_id_idx]
            paired_data[token_idx, pair_idx] = is_paired[pair_idx, chain_id_idx]

            # Add residue type
            if seq_idx != -1:
                res_start = msa_sequences[chain_id_idx, seq_idx]
                res_type = msa_residues[chain_id_idx, res_start + res_idx]
                k = (chain_id, seq_idx, res_idx)
                if k in deletions:
                    del_data[token_idx, pair_idx] = deletions[k]
                msa_data[token_idx, pair_idx] = res_type

    return msa_data, del_data, paired_data


####################################################################################################
# FEATURES
####################################################################################################


def select_subset_from_mask(mask, p):
    num_true = np.sum(mask)
    v = np.random.geometric(p) + 1
    k = min(v, num_true)

    true_indices = np.where(mask)[0]

    # Randomly select k indices from the true_indices
    selected_indices = np.random.choice(true_indices, size=k, replace=False)

    new_mask = np.zeros_like(mask)
    new_mask[selected_indices] = 1

    return new_mask


def process_token_features(
    data: Tokenized,
    max_tokens: Optional[int] = None,
    binder_pocket_conditioned_prop: Optional[float] = 0.0,
    binder_pocket_cutoff: Optional[float] = 6.0,
    binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
    only_ligand_binder_pocket: Optional[bool] = False,
    inference_binder: Optional[list[int]] = None,
    inference_pocket: Optional[list[tuple[int, int]]] = None,
) -> dict[str, Tensor]:
    """Get the token features.

    Parameters
    ----------
    data : Tokenized
        The tokenized data.
    max_tokens : int
        The maximum number of tokens.

    Returns
    -------
    dict[str, Tensor]
        The token features.

    """
    # Token data
    token_data = data.tokens
    token_bonds = data.bonds

    # Token core features
    token_index = torch.arange(len(token_data), dtype=torch.long)
    residue_index = from_numpy(token_data["res_idx"].copy()).long()
    asym_id = from_numpy(token_data["asym_id"].copy()).long()
    entity_id = from_numpy(token_data["entity_id"].copy()).long()
    sym_id = from_numpy(token_data["sym_id"].copy()).long()
    mol_type = from_numpy(token_data["mol_type"].copy()).long()
    res_type = from_numpy(token_data["res_type"].copy()).long()
    res_type = one_hot(res_type, num_classes=const.num_tokens)
    disto_center = from_numpy(token_data["disto_coords"].copy())

    # Token mask features
    pad_mask = torch.ones(len(token_data), dtype=torch.float)
    resolved_mask = from_numpy(token_data["resolved_mask"].copy()).float()
    disto_mask = from_numpy(token_data["disto_mask"].copy()).float()
    cyclic_period = from_numpy(token_data["cyclic_period"].copy())

    # Token bond features
    if max_tokens is not None:
        pad_len = max_tokens - len(token_data)
        num_tokens = max_tokens if pad_len > 0 else len(token_data)
    else:
        num_tokens = len(token_data)

    tok_to_idx = {tok["token_idx"]: idx for idx, tok in enumerate(token_data)}
    bonds = torch.zeros(num_tokens, num_tokens, dtype=torch.float)
    for token_bond in token_bonds:
        token_1 = tok_to_idx[token_bond["token_1"]]
        token_2 = tok_to_idx[token_bond["token_2"]]
        bonds[token_1, token_2] = 1
        bonds[token_2, token_1] = 1

    bonds = bonds.unsqueeze(-1)

    # Pocket conditioned feature
    pocket_feature = (
        np.zeros(len(token_data)) + const.pocket_contact_info["UNSPECIFIED"]
    )
    if inference_binder is not None:
        assert inference_pocket is not None
        pocket_residues = set(inference_pocket)
        for idx, token in enumerate(token_data):
            if token["asym_id"] in inference_binder:
                pocket_feature[idx] = const.pocket_contact_info["BINDER"]
            elif (token["asym_id"], token["res_idx"]) in pocket_residues:
                pocket_feature[idx] = const.pocket_contact_info["POCKET"]
            else:
                pocket_feature[idx] = const.pocket_contact_info["UNSELECTED"]
    elif (
        binder_pocket_conditioned_prop > 0.0
        and random.random() < binder_pocket_conditioned_prop
    ):
        # choose as binder a random ligand in the crop, if there are no ligands select a protein chain
        binder_asym_ids = np.unique(
            token_data["asym_id"][
                token_data["mol_type"] == const.chain_type_ids["NONPOLYMER"]
            ]
        )

        if len(binder_asym_ids) == 0:
            if not only_ligand_binder_pocket:
                binder_asym_ids = np.unique(token_data["asym_id"])

        if len(binder_asym_ids) > 0:
            pocket_asym_id = random.choice(binder_asym_ids)
            binder_mask = token_data["asym_id"] == pocket_asym_id

            binder_coords = []
            for token in token_data:
                if token["asym_id"] == pocket_asym_id:
                    binder_coords.append(
                        data.structure.atoms["coords"][
                            token["atom_idx"] : token["atom_idx"] + token["atom_num"]
                        ]
                    )
            binder_coords = np.concatenate(binder_coords, axis=0)

            # find the tokens in the pocket
            token_dist = np.zeros(len(token_data)) + 1000
            for i, token in enumerate(token_data):
                if (
                    token["mol_type"] != const.chain_type_ids["NONPOLYMER"]
                    and token["asym_id"] != pocket_asym_id
                    and token["resolved_mask"] == 1
                ):
                    token_coords = data.structure.atoms["coords"][
                        token["atom_idx"] : token["atom_idx"] + token["atom_num"]
                    ]

                    # find chain and apply chain transformation
                    for chain in data.structure.chains:
                        if chain["asym_id"] == token["asym_id"]:
                            break

                    token_dist[i] = np.min(
                        np.linalg.norm(
                            token_coords[:, None, :] - binder_coords[None, :, :],
                            axis=-1,
                        )
                    )

            pocket_mask = token_dist < binder_pocket_cutoff

            if np.sum(pocket_mask) > 0:
                pocket_feature = (
                    np.zeros(len(token_data)) + const.pocket_contact_info["UNSELECTED"]
                )
                pocket_feature[binder_mask] = const.pocket_contact_info["BINDER"]

                if binder_pocket_sampling_geometric_p > 0.0:
                    # select a subset of the pocket, according
                    # to a geometric distribution with one as minimum
                    pocket_mask = select_subset_from_mask(
                        pocket_mask, binder_pocket_sampling_geometric_p
                    )

                pocket_feature[pocket_mask] = const.pocket_contact_info["POCKET"]
    pocket_feature = from_numpy(pocket_feature).long()
    pocket_feature = one_hot(pocket_feature, num_classes=len(const.pocket_contact_info))

    # Pad to max tokens if given
    if max_tokens is not None:
        pad_len = max_tokens - len(token_data)
        if pad_len > 0:
            token_index = pad_dim(token_index, 0, pad_len)
            residue_index = pad_dim(residue_index, 0, pad_len)
            asym_id = pad_dim(asym_id, 0, pad_len)
            entity_id = pad_dim(entity_id, 0, pad_len)
            sym_id = pad_dim(sym_id, 0, pad_len)
            mol_type = pad_dim(mol_type, 0, pad_len)
            res_type = pad_dim(res_type, 0, pad_len)
            disto_center = pad_dim(disto_center, 0, pad_len)
            pad_mask = pad_dim(pad_mask, 0, pad_len)
            resolved_mask = pad_dim(resolved_mask, 0, pad_len)
            disto_mask = pad_dim(disto_mask, 0, pad_len)
            pocket_feature = pad_dim(pocket_feature, 0, pad_len)

    token_features = {
        "token_index": token_index,
        "residue_index": residue_index,
        "asym_id": asym_id,
        "entity_id": entity_id,
        "sym_id": sym_id,
        "mol_type": mol_type,
        "res_type": res_type,
        "disto_center": disto_center,
        "token_bonds": bonds,
        "token_pad_mask": pad_mask,
        "token_resolved_mask": resolved_mask,
        "token_disto_mask": disto_mask,
        "pocket_feature": pocket_feature,
        "cyclic_period": cyclic_period,
    }
    return token_features


def process_atom_features(
    data: Tokenized,
    atoms_per_window_queries: int = 32,
    min_dist: float = 2.0,
    max_dist: float = 22.0,
    num_bins: int = 64,
    max_atoms: Optional[int] = None,
    max_tokens: Optional[int] = None,
) -> dict[str, Tensor]:
    """Get the atom features.

    Parameters
    ----------
    data : Tokenized
        The tokenized data.
    max_atoms : int, optional
        The maximum number of atoms.

    Returns
    -------
    dict[str, Tensor]
        The atom features.

    """
    # Filter to tokens' atoms
    atom_data = []
    ref_space_uid = []
    coord_data = []
    frame_data = []
    resolved_frame_data = []
    atom_to_token = []
    token_to_rep_atom = []  # index on cropped atom table
    r_set_to_rep_atom = []
    disto_coords = []
    atom_idx = 0

    chain_res_ids = {}
    for token_id, token in enumerate(data.tokens):
        # Get the chain residue ids
        chain_idx, res_id = token["asym_id"], token["res_idx"]
        chain = data.structure.chains[chain_idx]

        if (chain_idx, res_id) not in chain_res_ids:
            new_idx = len(chain_res_ids)
            chain_res_ids[(chain_idx, res_id)] = new_idx
        else:
            new_idx = chain_res_ids[(chain_idx, res_id)]

        # Map atoms to token indices
        ref_space_uid.extend([new_idx] * token["atom_num"])
        atom_to_token.extend([token_id] * token["atom_num"])

        # Add atom data
        start = token["atom_idx"]
        end = token["atom_idx"] + token["atom_num"]
        token_atoms = data.structure.atoms[start:end]

        # Map token to representative atom
        token_to_rep_atom.append(atom_idx + token["disto_idx"] - start)
        if (chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]) and token[
            "resolved_mask"
        ]:
            r_set_to_rep_atom.append(atom_idx + token["center_idx"] - start)

        # Get token coordinates
        token_coords = np.array([token_atoms["coords"]])
        coord_data.append(token_coords)

        # Get frame data
        res_type = const.tokens[token["res_type"]]

        if token["atom_num"] < 3 or res_type in ["PAD", "UNK", "-"]:
            idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
            mask_frame = False
        elif (token["mol_type"] == const.chain_type_ids["PROTEIN"]) and (
            res_type in const.ref_atoms
        ):
            idx_frame_a, idx_frame_b, idx_frame_c = (
                const.ref_atoms[res_type].index("N"),
                const.ref_atoms[res_type].index("CA"),
                const.ref_atoms[res_type].index("C"),
            )
            mask_frame = (
                token_atoms["is_present"][idx_frame_a]
                and token_atoms["is_present"][idx_frame_b]
                and token_atoms["is_present"][idx_frame_c]
            )
        elif (
            token["mol_type"] == const.chain_type_ids["DNA"]
            or token["mol_type"] == const.chain_type_ids["RNA"]
        ) and (res_type in const.ref_atoms):
            idx_frame_a, idx_frame_b, idx_frame_c = (
                const.ref_atoms[res_type].index("C1'"),
                const.ref_atoms[res_type].index("C3'"),
                const.ref_atoms[res_type].index("C4'"),
            )
            mask_frame = (
                token_atoms["is_present"][idx_frame_a]
                and token_atoms["is_present"][idx_frame_b]
                and token_atoms["is_present"][idx_frame_c]
            )
        else:
            idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
            mask_frame = False
        frame_data.append(
            [idx_frame_a + atom_idx, idx_frame_b + atom_idx, idx_frame_c + atom_idx]
        )
        resolved_frame_data.append(mask_frame)

        # Get distogram coordinates
        disto_coords_tok = data.structure.atoms[token["disto_idx"]]["coords"]
        disto_coords.append(disto_coords_tok)

        # Update atom data. This is technically never used again (we rely on coord_data),
        # but we update for consistency and to make sure the Atom object has valid, transformed coordinates.
        token_atoms = token_atoms.copy()
        token_atoms["coords"] = token_coords[0]  # atom has a copy of first coords
        atom_data.append(token_atoms)
        atom_idx += len(token_atoms)

    disto_coords = np.array(disto_coords)

    # Compute distogram
    t_center = torch.Tensor(disto_coords)
    t_dists = torch.cdist(t_center, t_center)
    boundaries = torch.linspace(min_dist, max_dist, num_bins - 1)
    distogram = (t_dists.unsqueeze(-1) > boundaries).sum(dim=-1).long()
    disto_target = one_hot(distogram, num_classes=num_bins)

    atom_data = np.concatenate(atom_data)
    coord_data = np.concatenate(coord_data, axis=1)
    ref_space_uid = np.array(ref_space_uid)

    # Compute features
    ref_atom_name_chars = from_numpy(atom_data["name"]).long()
    ref_element = from_numpy(atom_data["element"]).long()
    ref_charge = from_numpy(atom_data["charge"])
    ref_pos = from_numpy(
        atom_data["conformer"].copy()
    )  # not sure why I need to copy here..
    ref_space_uid = from_numpy(ref_space_uid)
    coords = from_numpy(coord_data.copy())
    resolved_mask = from_numpy(atom_data["is_present"])
    pad_mask = torch.ones(len(atom_data), dtype=torch.float)
    atom_to_token = torch.tensor(atom_to_token, dtype=torch.long)
    token_to_rep_atom = torch.tensor(token_to_rep_atom, dtype=torch.long)
    r_set_to_rep_atom = torch.tensor(r_set_to_rep_atom, dtype=torch.long)
    frame_data, resolved_frame_data = compute_frames_nonpolymer(
        data,
        coord_data,
        atom_data["is_present"],
        atom_to_token,
        frame_data,
        resolved_frame_data,
    )  # Compute frames for NONPOLYMER tokens
    frames = from_numpy(frame_data.copy())
    frame_resolved_mask = from_numpy(resolved_frame_data.copy())
    # Convert to one-hot
    ref_atom_name_chars = one_hot(
        ref_atom_name_chars % num_bins, num_classes=num_bins
    )  # added for lower case letters
    ref_element = one_hot(ref_element, num_classes=const.num_elements)
    atom_to_token = one_hot(atom_to_token, num_classes=token_id + 1)
    token_to_rep_atom = one_hot(token_to_rep_atom, num_classes=len(atom_data))
    r_set_to_rep_atom = one_hot(r_set_to_rep_atom, num_classes=len(atom_data))

    # Center the ground truth coordinates
    center = (coords * resolved_mask[None, :, None]).sum(dim=1)
    center = center / resolved_mask.sum().clamp(min=1)
    coords = coords - center[:, None]

    # Apply random roto-translation to the input atoms
    ref_pos = center_random_augmentation(
        ref_pos[None], resolved_mask[None], centering=False
    )[0]

    # Compute padding and apply
    if max_atoms is not None:
        assert max_atoms % atoms_per_window_queries == 0
        pad_len = max_atoms - len(atom_data)
    else:
        pad_len = (
            (len(atom_data) - 1) // atoms_per_window_queries + 1
        ) * atoms_per_window_queries - len(atom_data)

    if pad_len > 0:
        pad_mask = pad_dim(pad_mask, 0, pad_len)
        ref_pos = pad_dim(ref_pos, 0, pad_len)
        resolved_mask = pad_dim(resolved_mask, 0, pad_len)
        ref_element = pad_dim(ref_element, 0, pad_len)
        ref_charge = pad_dim(ref_charge, 0, pad_len)
        ref_atom_name_chars = pad_dim(ref_atom_name_chars, 0, pad_len)
        ref_space_uid = pad_dim(ref_space_uid, 0, pad_len)
        coords = pad_dim(coords, 1, pad_len)
        atom_to_token = pad_dim(atom_to_token, 0, pad_len)
        token_to_rep_atom = pad_dim(token_to_rep_atom, 1, pad_len)
        r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 1, pad_len)

    if max_tokens is not None:
        pad_len = max_tokens - token_to_rep_atom.shape[0]
        if pad_len > 0:
            atom_to_token = pad_dim(atom_to_token, 1, pad_len)
            token_to_rep_atom = pad_dim(token_to_rep_atom, 0, pad_len)
            r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 0, pad_len)
            disto_target = pad_dim(pad_dim(disto_target, 0, pad_len), 1, pad_len)
            frames = pad_dim(frames, 0, pad_len)
            frame_resolved_mask = pad_dim(frame_resolved_mask, 0, pad_len)

    return {
        "ref_pos": ref_pos,
        "atom_resolved_mask": resolved_mask,
        "ref_element": ref_element,
        "ref_charge": ref_charge,
        "ref_atom_name_chars": ref_atom_name_chars,
        "ref_space_uid": ref_space_uid,
        "coords": coords,
        "atom_pad_mask": pad_mask,
        "atom_to_token": atom_to_token,
        "token_to_rep_atom": token_to_rep_atom,
        "r_set_to_rep_atom": r_set_to_rep_atom,
        "disto_target": disto_target,
        "frames_idx": frames,
        "frame_resolved_mask": frame_resolved_mask,
    }


def process_msa_features(
    data: Tokenized,
    max_seqs_batch: int,
    max_seqs: int,
    max_tokens: Optional[int] = None,
    pad_to_max_seqs: bool = False,
) -> dict[str, Tensor]:
    """Get the MSA features.

    Parameters
    ----------
    data : Tokenized
        The tokenized data.
    max_seqs : int
        The maximum number of MSA sequences.
    max_tokens : int
        The maximum number of tokens.
    pad_to_max_seqs : bool
        Whether to pad to the maximum number of sequences.

    Returns
    -------
    dict[str, Tensor]
        The MSA features.

    """
    # Created paired MSA
    msa, deletion, paired = construct_paired_msa(data, max_seqs_batch)
    msa, deletion, paired = (
        msa.transpose(1, 0),
        deletion.transpose(1, 0),
        paired.transpose(1, 0),
    )  # (N_MSA, N_RES, N_AA)

    # Prepare features
    msa = torch.nn.functional.one_hot(msa, num_classes=const.num_tokens)
    msa_mask = torch.ones_like(msa[:, :, 0])
    profile = msa.float().mean(dim=0)
    has_deletion = deletion > 0
    deletion = np.pi / 2 * np.arctan(deletion / 3)
    deletion_mean = deletion.mean(axis=0)

    # Pad in the MSA dimension (dim=0)
    if pad_to_max_seqs:
        pad_len = max_seqs - msa.shape[0]
        if pad_len > 0:
            msa = pad_dim(msa, 0, pad_len, const.token_ids["-"])
            paired = pad_dim(paired, 0, pad_len)
            msa_mask = pad_dim(msa_mask, 0, pad_len)
            has_deletion = pad_dim(has_deletion, 0, pad_len)
            deletion = pad_dim(deletion, 0, pad_len)

    # Pad in the token dimension (dim=1)
    if max_tokens is not None:
        pad_len = max_tokens - msa.shape[1]
        if pad_len > 0:
            msa = pad_dim(msa, 1, pad_len, const.token_ids["-"])
            paired = pad_dim(paired, 1, pad_len)
            msa_mask = pad_dim(msa_mask, 1, pad_len)
            has_deletion = pad_dim(has_deletion, 1, pad_len)
            deletion = pad_dim(deletion, 1, pad_len)
            profile = pad_dim(profile, 0, pad_len)
            deletion_mean = pad_dim(deletion_mean, 0, pad_len)

    return {
        "msa": msa,
        "msa_paired": paired,
        "deletion_value": deletion,
        "has_deletion": has_deletion,
        "deletion_mean": deletion_mean,
        "profile": profile,
        "msa_mask": msa_mask,
    }


def process_symmetry_features(
    cropped: Tokenized, symmetries: dict
) -> dict[str, Tensor]:
    """Get the symmetry features.

    Parameters
    ----------
    data : Tokenized
        The tokenized data.

    Returns
    -------
    dict[str, Tensor]
        The symmetry features.

    """
    features = get_chain_symmetries(cropped)
    features.update(get_amino_acids_symmetries(cropped))
    features.update(get_ligand_symmetries(cropped, symmetries))

    return features


def process_residue_constraint_features(
    data: Tokenized,
) -> dict[str, Tensor]:
    residue_constraints = data.residue_constraints
    if residue_constraints is not None:
        rdkit_bounds_constraints = residue_constraints.rdkit_bounds_constraints
        chiral_atom_constraints = residue_constraints.chiral_atom_constraints
        stereo_bond_constraints = residue_constraints.stereo_bond_constraints
        planar_bond_constraints = residue_constraints.planar_bond_constraints
        planar_ring_5_constraints = residue_constraints.planar_ring_5_constraints
        planar_ring_6_constraints = residue_constraints.planar_ring_6_constraints

        rdkit_bounds_index = torch.tensor(
            rdkit_bounds_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
        rdkit_bounds_bond_mask = torch.tensor(
            rdkit_bounds_constraints["is_bond"].copy(), dtype=torch.bool
        )
        rdkit_bounds_angle_mask = torch.tensor(
            rdkit_bounds_constraints["is_angle"].copy(), dtype=torch.bool
        )
        rdkit_upper_bounds = torch.tensor(
            rdkit_bounds_constraints["upper_bound"].copy(), dtype=torch.float
        )
        rdkit_lower_bounds = torch.tensor(
            rdkit_bounds_constraints["lower_bound"].copy(), dtype=torch.float
        )

        chiral_atom_index = torch.tensor(
            chiral_atom_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
        chiral_reference_mask = torch.tensor(
            chiral_atom_constraints["is_reference"].copy(), dtype=torch.bool
        )
        chiral_atom_orientations = torch.tensor(
            chiral_atom_constraints["is_r"].copy(), dtype=torch.bool
        )

        stereo_bond_index = torch.tensor(
            stereo_bond_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
        stereo_reference_mask = torch.tensor(
            stereo_bond_constraints["is_reference"].copy(), dtype=torch.bool
        )
        stereo_bond_orientations = torch.tensor(
            stereo_bond_constraints["is_e"].copy(), dtype=torch.bool
        )

        planar_bond_index = torch.tensor(
            planar_bond_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
        planar_ring_5_index = torch.tensor(
            planar_ring_5_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
        planar_ring_6_index = torch.tensor(
            planar_ring_6_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
    else:
        rdkit_bounds_index = torch.empty((2, 0), dtype=torch.long)
        rdkit_bounds_bond_mask = torch.empty((0,), dtype=torch.bool)
        rdkit_bounds_angle_mask = torch.empty((0,), dtype=torch.bool)
        rdkit_upper_bounds = torch.empty((0,), dtype=torch.float)
        rdkit_lower_bounds = torch.empty((0,), dtype=torch.float)
        chiral_atom_index = torch.empty(
            (
                4,
                0,
            ),
            dtype=torch.long,
        )
        chiral_reference_mask = torch.empty((0,), dtype=torch.bool)
        chiral_atom_orientations = torch.empty((0,), dtype=torch.bool)
        stereo_bond_index = torch.empty((4, 0), dtype=torch.long)
        stereo_reference_mask = torch.empty((0,), dtype=torch.bool)
        stereo_bond_orientations = torch.empty((0,), dtype=torch.bool)
        planar_bond_index = torch.empty((6, 0), dtype=torch.long)
        planar_ring_5_index = torch.empty((5, 0), dtype=torch.long)
        planar_ring_6_index = torch.empty((6, 0), dtype=torch.long)

    return {
        "rdkit_bounds_index": rdkit_bounds_index,
        "rdkit_bounds_bond_mask": rdkit_bounds_bond_mask,
        "rdkit_bounds_angle_mask": rdkit_bounds_angle_mask,
        "rdkit_upper_bounds": rdkit_upper_bounds,
        "rdkit_lower_bounds": rdkit_lower_bounds,
        "chiral_atom_index": chiral_atom_index,
        "chiral_reference_mask": chiral_reference_mask,
        "chiral_atom_orientations": chiral_atom_orientations,
        "stereo_bond_index": stereo_bond_index,
        "stereo_reference_mask": stereo_reference_mask,
        "stereo_bond_orientations": stereo_bond_orientations,
        "planar_bond_index": planar_bond_index,
        "planar_ring_5_index": planar_ring_5_index,
        "planar_ring_6_index": planar_ring_6_index,
    }


def process_chain_feature_constraints(
    data: Tokenized,
) -> dict[str, Tensor]:
    structure = data.structure
    if structure.connections.shape[0] > 0:
        connected_chain_index, connected_atom_index = [], []
        for connection in structure.connections:
            connected_chain_index.append([connection["chain_1"], connection["chain_2"]])
            connected_atom_index.append([connection["atom_1"], connection["atom_2"]])
        connected_chain_index = torch.tensor(connected_chain_index, dtype=torch.long).T
        connected_atom_index = torch.tensor(connected_atom_index, dtype=torch.long).T
    else:
        connected_chain_index = torch.empty((2, 0), dtype=torch.long)
        connected_atom_index = torch.empty((2, 0), dtype=torch.long)

    symmetric_chain_index = []
    for i, chain_i in enumerate(structure.chains):
        for j, chain_j in enumerate(structure.chains):
            if j <= i:
                continue
            if chain_i["entity_id"] == chain_j["entity_id"]:
                symmetric_chain_index.append([i, j])
    if len(symmetric_chain_index) > 0:
        symmetric_chain_index = torch.tensor(symmetric_chain_index, dtype=torch.long).T
    else:
        symmetric_chain_index = torch.empty((2, 0), dtype=torch.long)
    return {
        "connected_chain_index": connected_chain_index,
        "connected_atom_index": connected_atom_index,
        "symmetric_chain_index": symmetric_chain_index,
    }


class BoltzFeaturizer:
    """Boltz featurizer."""

    def process(
        self,
        data: Tokenized,
        training: bool,
        max_seqs: int = 4096,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        max_tokens: Optional[int] = None,
        max_atoms: Optional[int] = None,
        pad_to_max_seqs: bool = False,
        compute_symmetries: bool = False,
        symmetries: Optional[dict] = None,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff: Optional[float] = 6.0,
        binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
        only_ligand_binder_pocket: Optional[bool] = False,
        inference_binder: Optional[int] = None,
        inference_pocket: Optional[list[tuple[int, int]]] = None,
        compute_constraint_features: bool = False,
    ) -> dict[str, Tensor]:
        """Compute features.

        Parameters
        ----------
        data : Tokenized
            The tokenized data.
        training : bool
            Whether the model is in training mode.
        max_tokens : int, optional
            The maximum number of tokens.
        max_atoms : int, optional
            The maximum number of atoms.
        max_seqs : int, optional
            The maximum number of sequences.

        Returns
        -------
        dict[str, Tensor]
            The features for model training.

        """
        # Compute random number of sequences
        if training and max_seqs is not None:
            max_seqs_batch = np.random.randint(1, max_seqs + 1)  # noqa: NPY002
        else:
            max_seqs_batch = max_seqs

        # Compute token features
        token_features = process_token_features(
            data,
            max_tokens,
            binder_pocket_conditioned_prop,
            binder_pocket_cutoff,
            binder_pocket_sampling_geometric_p,
            only_ligand_binder_pocket,
            inference_binder=inference_binder,
            inference_pocket=inference_pocket,
        )

        # Compute atom features
        atom_features = process_atom_features(
            data,
            atoms_per_window_queries,
            min_dist,
            max_dist,
            num_bins,
            max_atoms,
            max_tokens,
        )

        # Compute MSA features
        msa_features = process_msa_features(
            data,
            max_seqs_batch,
            max_seqs,
            max_tokens,
            pad_to_max_seqs,
        )

        # Compute symmetry features
        symmetry_features = {}
        if compute_symmetries:
            symmetry_features = process_symmetry_features(data, symmetries)

        # Compute residue constraint features
        residue_constraint_features = {}
        chain_constraint_features = {}
        if compute_constraint_features:
            residue_constraint_features = process_residue_constraint_features(data)
            chain_constraint_features = process_chain_feature_constraints(data)

        return {
            **token_features,
            **atom_features,
            **msa_features,
            **symmetry_features,
            **residue_constraint_features,
            **chain_constraint_features,
        }
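The snippet below is an illustrative usage sketch and is not part of the published file. It assumes a `tokenized` object of type boltz.data.types.Tokenized has already been produced by one of the tokenizers listed in this wheel (e.g. boltz/data/tokenize/boltz.py), which this diff does not show.

# Usage sketch (illustrative only): `tokenized` is an assumed Tokenized instance
# from an upstream tokenizer; only the BoltzFeaturizer call reflects the file above.
from boltz.data.feature.featurizer import BoltzFeaturizer

featurizer = BoltzFeaturizer()
features = featurizer.process(
    tokenized,
    training=False,       # keep max_seqs MSA rows instead of sampling a random count
    max_seqs=4096,
    max_tokens=None,      # no padding to a fixed token count
    max_atoms=None,       # atoms padded only to a multiple of atoms_per_window_queries
    compute_symmetries=False,
)
# `features` is a flat dict of tensors merging token, atom, MSA and any optional
# symmetry/constraint features.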