boltz_vsynthes-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/data/feature/featurizerv2.py
@@ -0,0 +1,2208 @@
+import math
+from typing import Optional
+
+import numba
+import numpy as np
+import numpy.typing as npt
+import rdkit.Chem.Descriptors
+import torch
+from numba import types
+from rdkit.Chem import Mol
+from scipy.spatial.distance import cdist
+from torch import Tensor, from_numpy
+from torch.nn.functional import one_hot
+
+from boltz.data import const
+from boltz.data.mol import (
+    get_amino_acids_symmetries,
+    get_chain_symmetries,
+    get_ligand_symmetries,
+    get_symmetries,
+)
+from boltz.data.pad import pad_dim
+from boltz.data.types import (
+    MSA,
+    MSADeletion,
+    MSAResidue,
+    MSASequence,
+    TemplateInfo,
+    Tokenized,
+)
+from boltz.model.modules.utils import center_random_augmentation
+
+####################################################################################################
+# HELPERS
+####################################################################################################
+
+
+def convert_atom_name(name: str) -> tuple[int, int, int, int]:
+    """Convert an atom name to a standard format.
+
+    Parameters
+    ----------
+    name : str
+        The atom name.
+
+    Returns
+    -------
+    tuple[int, int, int, int]
+        The converted atom name.
+
+    """
+    name = str(name).strip()
+    name = [ord(c) - 32 for c in name]
+    name = name + [0] * (4 - len(name))
+    return tuple(name)
+
+
+def sample_d(
+    min_d: float,
+    max_d: float,
+    n_samples: int,
+    random: np.random.Generator,
+) -> np.ndarray:
+    """Generate samples from a 1/d distribution between min_d and max_d.
+
+    Parameters
+    ----------
+    min_d : float
+        Minimum value of d
+    max_d : float
+        Maximum value of d
+    n_samples : int
+        Number of samples to generate
+    random : numpy.random.Generator
+        Random number generator
+
+    Returns
+    -------
+    numpy.ndarray
+        Array of samples drawn from the distribution
+
+    Notes
+    -----
+    The probability density function is:
+    f(d) = 1/(d * ln(max_d/min_d)) for d in [min_d, max_d]
+
+    The inverse CDF transform is:
+    d = min_d * (max_d/min_d)**u where u ~ Uniform(0,1)
+
+    """
+    # Generate n_samples uniform random numbers in [0, 1]
+    u = random.random(n_samples)
+    # Transform u using the inverse CDF
+    return min_d * (max_d / min_d) ** u
+
+
+def compute_frames_nonpolymer(
+    data: Tokenized,
+    coords,
+    resolved_mask,
+    atom_to_token,
+    frame_data: list,
+    resolved_frame_data: list,
+) -> tuple[list, list]:
+    """Get the frames for non-polymer tokens.
+
+    Parameters
+    ----------
+    data : Tokenized
+        The input data to the model.
+    frame_data : list
+        The frame data.
+    resolved_frame_data : list
+        The resolved frame data.
+
+    Returns
+    -------
+    tuple[list, list]
+        The frame data and resolved frame data.
+
+    """
+    frame_data = np.array(frame_data)
+    resolved_frame_data = np.array(resolved_frame_data)
+    asym_id_token = data.tokens["asym_id"]
+    asym_id_atom = data.tokens["asym_id"][atom_to_token]
+    token_idx = 0
+    atom_idx = 0
+    for id in np.unique(data.tokens["asym_id"]):
+        mask_chain_token = asym_id_token == id
+        mask_chain_atom = asym_id_atom == id
+        num_tokens = mask_chain_token.sum()
+        num_atoms = mask_chain_atom.sum()
+        if (
+            data.tokens[token_idx]["mol_type"] != const.chain_type_ids["NONPOLYMER"]
+            or num_atoms < 3  # noqa: PLR2004
+        ):
+            token_idx += num_tokens
+            atom_idx += num_atoms
+            continue
+        dist_mat = (
+            (
+                coords.reshape(-1, 3)[mask_chain_atom][:, None, :]
+                - coords.reshape(-1, 3)[mask_chain_atom][None, :, :]
+            )
+            ** 2
+        ).sum(-1) ** 0.5
+        resolved_pair = 1 - (
+            resolved_mask[mask_chain_atom][None, :]
+            * resolved_mask[mask_chain_atom][:, None]
+        ).astype(np.float32)
+        resolved_pair[resolved_pair == 1] = math.inf
+        indices = np.argsort(dist_mat + resolved_pair, axis=1)
+        frames = (
+            np.concatenate(
+                [
+                    indices[:, 1:2],
+                    indices[:, 0:1],
+                    indices[:, 2:3],
+                ],
+                axis=1,
+            )
+            + atom_idx
+        )
+        frame_data[token_idx : token_idx + num_atoms, :] = frames
+        resolved_frame_data[token_idx : token_idx + num_atoms] = resolved_mask[
+            frames
+        ].all(axis=1)
+        token_idx += num_tokens
+        atom_idx += num_atoms
+    frames_expanded = coords.reshape(-1, 3)[frame_data]
+
+    mask_collinear = compute_collinear_mask(
+        frames_expanded[:, 1] - frames_expanded[:, 0],
+        frames_expanded[:, 1] - frames_expanded[:, 2],
+    )
+    return frame_data, resolved_frame_data & mask_collinear
+
+
+def compute_collinear_mask(v1, v2):
+    norm1 = np.linalg.norm(v1, axis=1, keepdims=True)
+    norm2 = np.linalg.norm(v2, axis=1, keepdims=True)
+    v1 = v1 / (norm1 + 1e-6)
+    v2 = v2 / (norm2 + 1e-6)
+    mask_angle = np.abs(np.sum(v1 * v2, axis=1)) < 0.9063
+    mask_overlap1 = norm1.reshape(-1) > 1e-2
+    mask_overlap2 = norm2.reshape(-1) > 1e-2
+    return mask_angle & mask_overlap1 & mask_overlap2
+
+
+def dummy_msa(residues: np.ndarray) -> MSA:
+    """Create a dummy MSA for a chain.
+
+    Parameters
+    ----------
+    residues : np.ndarray
+        The residues for the chain.
+
+    Returns
+    -------
+    MSA
+        The dummy MSA.
+
+    """
+    residues = [res["res_type"] for res in residues]
+    deletions = []
+    sequences = [(0, -1, 0, len(residues), 0, 0)]
+    return MSA(
+        residues=np.array(residues, dtype=MSAResidue),
+        deletions=np.array(deletions, dtype=MSADeletion),
+        sequences=np.array(sequences, dtype=MSASequence),
+    )
+
+
+def construct_paired_msa(  # noqa: C901, PLR0915, PLR0912
+    data: Tokenized,
+    random: np.random.Generator,
+    max_seqs: int,
+    max_pairs: int = 8192,
+    max_total: int = 16384,
+    random_subset: bool = False,
+) -> tuple[Tensor, Tensor, Tensor]:
+    """Pair the MSA data.
+
+    Parameters
+    ----------
+    data : Tokenized
+        The input data to the model.
+
+    Returns
+    -------
+    Tensor
+        The MSA data.
+    Tensor
+        The deletion data.
+    Tensor
+        Mask indicating paired sequences.
+
+    """
+    # Get unique chains (ensuring monotonicity in the order)
+    assert np.all(np.diff(data.tokens["asym_id"], n=1) >= 0)
+    chain_ids = np.unique(data.tokens["asym_id"])
+
+    # Get relevant MSA, and create a dummy for chains without
+    msa: dict[int, MSA] = {}
+    for chain_id in chain_ids:
+        # Get input sequence
+        chain = data.structure.chains[chain_id]
+        res_start = chain["res_idx"]
+        res_end = res_start + chain["res_num"]
+        residues = data.structure.residues[res_start:res_end]
+
+        # Check if we have an MSA, and that the
+        # first sequence matches the input sequence
+        if chain_id in data.msa:
+            # Set the MSA
+            msa[chain_id] = data.msa[chain_id]
+
+            # Run length and residue type checks
+            first = data.msa[chain_id].sequences[0]
+            first_start = first["res_start"]
+            first_end = first["res_end"]
+            msa_residues = data.msa[chain_id].residues
+            first_residues = msa_residues[first_start:first_end]
+
+            warning = "Warning: MSA does not match input sequence, creating dummy."
+            if len(residues) == len(first_residues):
+                # If there is a mismatch, check if it is between MET & UNK
+                # If so, replace the first sequence with the input sequence.
+                # Otherwise, replace with a dummy MSA for this chain.
+                mismatches = residues["res_type"] != first_residues["res_type"]
+                if mismatches.sum().item():
+                    idx = np.where(mismatches)[0]
+                    is_met = residues["res_type"][idx] == const.token_ids["MET"]
+                    is_unk = residues["res_type"][idx] == const.token_ids["UNK"]
+                    is_msa_unk = (
+                        first_residues["res_type"][idx] == const.token_ids["UNK"]
+                    )
+                    if (np.all(is_met) and np.all(is_msa_unk)) or np.all(is_unk):
+                        msa_residues[first_start:first_end]["res_type"] = residues[
+                            "res_type"
+                        ]
+                    else:
+                        print(
+                            warning,
+                            "1",
+                            residues["res_type"],
+                            first_residues["res_type"],
+                            data.record.id,
+                        )
+                        msa[chain_id] = dummy_msa(residues)
+            else:
+                print(
+                    warning,
+                    "2",
+                    residues["res_type"],
+                    first_residues["res_type"],
+                    data.record.id,
+                )
+                msa[chain_id] = dummy_msa(residues)
+        else:
+            msa[chain_id] = dummy_msa(residues)
+
+    # Map taxonomies to (chain_id, seq_idx)
+    taxonomy_map: dict[str, list] = {}
+    for chain_id, chain_msa in msa.items():
+        sequences = chain_msa.sequences
+        sequences = sequences[sequences["taxonomy"] != -1]
+        for sequence in sequences:
+            seq_idx = sequence["seq_idx"]
+            taxon = sequence["taxonomy"]
+            taxonomy_map.setdefault(taxon, []).append((chain_id, seq_idx))
+
+    # Remove taxonomies with only one sequence and sort by the
+    # number of chain_id present in each of the taxonomies
+    taxonomy_map = {k: v for k, v in taxonomy_map.items() if len(v) > 1}
+    taxonomy_map = sorted(
+        taxonomy_map.items(),
+        key=lambda x: len({c for c, _ in x[1]}),
+        reverse=True,
+    )
+
+    # Keep track of the sequences available per chain, keeping the original
+    # order of the sequences in the MSA to favor the best matching sequences
+    visited = {(c, s) for c, items in taxonomy_map for s in items}
+    available = {}
+    for c in chain_ids:
+        available[c] = [
+            i for i in range(1, len(msa[c].sequences)) if (c, i) not in visited
+        ]
+
+    # Create sequence pairs
+    is_paired = []
+    pairing = []
+
+    # Start with the first sequence for each chain
+    is_paired.append({c: 1 for c in chain_ids})
+    pairing.append({c: 0 for c in chain_ids})
+
+    # Then add up to 8191 paired rows
+    for _, pairs in taxonomy_map:
+        # Group occurrences by chain_id in case we have multiple
+        # sequences from the same chain and same taxonomy
+        chain_occurences = {}
+        for chain_id, seq_idx in pairs:
+            chain_occurences.setdefault(chain_id, []).append(seq_idx)
+
+        # We create as many pairings as the maximum number of occurrences
+        max_occurences = max(len(v) for v in chain_occurences.values())
+        for i in range(max_occurences):
+            row_pairing = {}
+            row_is_paired = {}
+
+            # Add the chains present in the taxonomy
+            for chain_id, seq_idxs in chain_occurences.items():
+                # Roll over the sequence index to maximize diversity
+                idx = i % len(seq_idxs)
+                seq_idx = seq_idxs[idx]
+
+                # Add the sequence to the pairing
+                row_pairing[chain_id] = seq_idx
+                row_is_paired[chain_id] = 1
+
+            # Add any missing chains
+            for chain_id in chain_ids:
+                if chain_id not in row_pairing:
+                    row_is_paired[chain_id] = 0
+                    if available[chain_id]:
+                        # Add the next available sequence
+                        seq_idx = available[chain_id].pop(0)
+                        row_pairing[chain_id] = seq_idx
+                    else:
+                        # No more sequences available, we place a gap
+                        row_pairing[chain_id] = -1
+
+            pairing.append(row_pairing)
+            is_paired.append(row_is_paired)
+
+            # Break if we have enough pairs
+            if len(pairing) >= max_pairs:
+                break
+
+        # Break if we have enough pairs
+        if len(pairing) >= max_pairs:
+            break
+
+    # Now add up to 16384 unpaired rows total
+    max_left = max(len(v) for v in available.values())
+    for _ in range(min(max_total - len(pairing), max_left)):
+        row_pairing = {}
+        row_is_paired = {}
+        for chain_id in chain_ids:
+            row_is_paired[chain_id] = 0
+            if available[chain_id]:
+                # Add the next available sequence
+                seq_idx = available[chain_id].pop(0)
+                row_pairing[chain_id] = seq_idx
+            else:
+                # No more sequences available, we place a gap
+                row_pairing[chain_id] = -1
+
+        pairing.append(row_pairing)
+        is_paired.append(row_is_paired)
+
+        # Break if we have enough sequences
+        if len(pairing) >= max_total:
+            break
+
+    # Randomly sample a subset of the pairs
+    # ensuring the first row is always present
+    if random_subset:
+        num_seqs = len(pairing)
+        if num_seqs > max_seqs:
+            indices = random.choice(
+                np.arange(1, num_seqs), size=max_seqs - 1, replace=False
+            )  # noqa: NPY002
+            pairing = [pairing[0]] + [pairing[i] for i in indices]
+            is_paired = [is_paired[0]] + [is_paired[i] for i in indices]
+    else:
+        # Deterministic downsample to max_seqs
+        pairing = pairing[:max_seqs]
+        is_paired = is_paired[:max_seqs]
+
+    # Map (chain_id, seq_idx, res_idx) to deletion
+    deletions = {}
+    for chain_id, chain_msa in msa.items():
+        chain_deletions = chain_msa.deletions
+        for sequence in chain_msa.sequences:
+            del_start = sequence["del_start"]
+            del_end = sequence["del_end"]
+            chain_deletions = chain_msa.deletions[del_start:del_end]
+            for deletion_data in chain_deletions:
+                seq_idx = sequence["seq_idx"]
+                res_idx = deletion_data["res_idx"]
+                deletion = deletion_data["deletion"]
+                deletions[(chain_id, seq_idx, res_idx)] = deletion
+
+    # Add all the token MSA data
+    msa_data, del_data, paired_data = prepare_msa_arrays(
+        data.tokens, pairing, is_paired, deletions, msa
+    )
+
+    msa_data = torch.tensor(msa_data, dtype=torch.long)
+    del_data = torch.tensor(del_data, dtype=torch.float)
+    paired_data = torch.tensor(paired_data, dtype=torch.float)
+
+    return msa_data, del_data, paired_data
+
+
+def prepare_msa_arrays(
+    tokens,
+    pairing: list[dict[int, int]],
+    is_paired: list[dict[int, int]],
+    deletions: dict[tuple[int, int, int], int],
+    msa: dict[int, MSA],
+) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
+    """Reshape data to play nicely with numba jit."""
+    token_asym_ids_arr = np.array([t["asym_id"] for t in tokens], dtype=np.int64)
+    token_res_idxs_arr = np.array([t["res_idx"] for t in tokens], dtype=np.int64)
+
+    chain_ids = sorted(msa.keys())
+
+    # chain_ids are not necessarily contiguous (e.g. they might be 0, 24, 25).
+    # This allows us to look up a chain_id by its index in the chain_ids list.
+    chain_id_to_idx = {chain_id: i for i, chain_id in enumerate(chain_ids)}
+    token_asym_ids_idx_arr = np.array(
+        [chain_id_to_idx[asym_id] for asym_id in token_asym_ids_arr], dtype=np.int64
+    )
+
+    pairing_arr = np.zeros((len(pairing), len(chain_ids)), dtype=np.int64)
+    is_paired_arr = np.zeros((len(is_paired), len(chain_ids)), dtype=np.int64)
+
+    for i, row_pairing in enumerate(pairing):
+        for chain_id in chain_ids:
+            pairing_arr[i, chain_id_to_idx[chain_id]] = row_pairing[chain_id]
+
+    for i, row_is_paired in enumerate(is_paired):
+        for chain_id in chain_ids:
+            is_paired_arr[i, chain_id_to_idx[chain_id]] = row_is_paired[chain_id]
+
+    max_seq_len = max(len(msa[chain_id].sequences) for chain_id in chain_ids)
+
+    # we want res_start from sequences
+    msa_sequences = np.full((len(chain_ids), max_seq_len), -1, dtype=np.int64)
+    for chain_id in chain_ids:
+        for i, seq in enumerate(msa[chain_id].sequences):
+            msa_sequences[chain_id_to_idx[chain_id], i] = seq["res_start"]
+
+    max_residues_len = max(len(msa[chain_id].residues) for chain_id in chain_ids)
+    msa_residues = np.full((len(chain_ids), max_residues_len), -1, dtype=np.int64)
+    for chain_id in chain_ids:
+        residues = msa[chain_id].residues.astype(np.int64)
+        idxs = np.arange(len(residues))
+        chain_idx = chain_id_to_idx[chain_id]
+        msa_residues[chain_idx, idxs] = residues
+
+    deletions_dict = numba.typed.Dict.empty(
+        key_type=numba.types.Tuple(
+            [numba.types.int64, numba.types.int64, numba.types.int64]
+        ),
+        value_type=numba.types.int64,
+    )
+    deletions_dict.update(deletions)
+
+    return _prepare_msa_arrays_inner(
+        token_asym_ids_arr,
+        token_res_idxs_arr,
+        token_asym_ids_idx_arr,
+        pairing_arr,
+        is_paired_arr,
+        deletions_dict,
+        msa_sequences,
+        msa_residues,
+        const.token_ids["-"],
+    )
+
+
+deletions_dict_type = types.DictType(types.UniTuple(types.int64, 3), types.int64)
+
+
+@numba.njit(
+    [
+        types.Tuple(
+            (
+                types.int64[:, ::1],  # msa_data
+                types.int64[:, ::1],  # del_data
+                types.int64[:, ::1],  # paired_data
+            )
+        )(
+            types.int64[::1],  # token_asym_ids
+            types.int64[::1],  # token_res_idxs
+            types.int64[::1],  # token_asym_ids_idx
+            types.int64[:, ::1],  # pairing
+            types.int64[:, ::1],  # is_paired
+            deletions_dict_type,  # deletions
+            types.int64[:, ::1],  # msa_sequences
+            types.int64[:, ::1],  # msa_residues
+            types.int64,  # gap_token
+        )
+    ],
+    cache=True,
+)
+def _prepare_msa_arrays_inner(
+    token_asym_ids: npt.NDArray[np.int64],
+    token_res_idxs: npt.NDArray[np.int64],
+    token_asym_ids_idx: npt.NDArray[np.int64],
+    pairing: npt.NDArray[np.int64],
+    is_paired: npt.NDArray[np.int64],
+    deletions: dict[tuple[int, int, int], int],
+    msa_sequences: npt.NDArray[np.int64],
+    msa_residues: npt.NDArray[np.int64],
+    gap_token: int,
+) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
+    n_tokens = len(token_asym_ids)
+    n_pairs = len(pairing)
+    msa_data = np.full((n_tokens, n_pairs), gap_token, dtype=np.int64)
+    paired_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)
+    del_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)
+
+    # Add all the token MSA data
+    for token_idx in range(n_tokens):
+        chain_id_idx = token_asym_ids_idx[token_idx]
+        chain_id = token_asym_ids[token_idx]
+        res_idx = token_res_idxs[token_idx]
+
+        for pair_idx in range(n_pairs):
+            seq_idx = pairing[pair_idx, chain_id_idx]
+            paired_data[token_idx, pair_idx] = is_paired[pair_idx, chain_id_idx]
+
+            # Add residue type
+            if seq_idx != -1:
+                res_start = msa_sequences[chain_id_idx, seq_idx]
+                res_type = msa_residues[chain_id_idx, res_start + res_idx]
+                k = (chain_id, seq_idx, res_idx)
+                if k in deletions:
+                    del_data[token_idx, pair_idx] = deletions[k]
+                msa_data[token_idx, pair_idx] = res_type
+
+    return msa_data, del_data, paired_data
+
+
+####################################################################################################
+# FEATURES
+####################################################################################################
+
+
+def select_subset_from_mask(mask, p, random: np.random.Generator) -> np.ndarray:
+    num_true = np.sum(mask)
+    v = random.geometric(p) + 1
+    k = min(v, num_true)
+
+    true_indices = np.where(mask)[0]
+
+    # Randomly select k indices from the true_indices
+    selected_indices = random.choice(true_indices, size=k, replace=False)
+
+    new_mask = np.zeros_like(mask)
+    new_mask[selected_indices] = 1
+
+    return new_mask
+
+
+def get_range_bin(value: float, range_dict: dict[tuple[float, float], int], default=0):
+    """Get the bin of a value given a range dictionary."""
+    value = float(value)
+    for k, idx in range_dict.items():
+        if k == "other":
+            continue
+        low, high = k
+        if low <= value < high:
+            return idx
+    return default
+
+
+def process_token_features(  # noqa: C901, PLR0915, PLR0912
+    data: Tokenized,
+    random: np.random.Generator,
+    max_tokens: Optional[int] = None,
+    binder_pocket_conditioned_prop: Optional[float] = 0.0,
+    contact_conditioned_prop: Optional[float] = 0.0,
+    binder_pocket_cutoff_min: Optional[float] = 4.0,
+    binder_pocket_cutoff_max: Optional[float] = 20.0,
+    binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
+    only_ligand_binder_pocket: Optional[bool] = False,
+    only_pp_contact: Optional[bool] = False,
+    inference_pocket_constraints: Optional[bool] = False,
+    override_method: Optional[str] = None,
+) -> dict[str, Tensor]:
+    """Get the token features.
+
+    Parameters
+    ----------
+    data : Tokenized
+        The input data to the model.
+    max_tokens : int
+        The maximum number of tokens.
+
+    Returns
+    -------
+    dict[str, Tensor]
+        The token features.
+
+    """
+    # Token data
+    token_data = data.tokens
+    token_bonds = data.bonds
+
+    # Token core features
+    token_index = torch.arange(len(token_data), dtype=torch.long)
+    residue_index = from_numpy(token_data["res_idx"]).long()
+    asym_id = from_numpy(token_data["asym_id"]).long()
+    entity_id = from_numpy(token_data["entity_id"]).long()
+    sym_id = from_numpy(token_data["sym_id"]).long()
+    mol_type = from_numpy(token_data["mol_type"]).long()
+    res_type = from_numpy(token_data["res_type"]).long()
+    res_type = one_hot(res_type, num_classes=const.num_tokens)
+    disto_center = from_numpy(token_data["disto_coords"])
+    modified = from_numpy(token_data["modified"]).long()  # float()
+    cyclic_period = from_numpy(token_data["cyclic_period"].copy())
+    affinity_mask = from_numpy(token_data["affinity_mask"]).float()
+
+    ## Conditioning features ##
+    method = (
+        np.zeros(len(token_data))
+        + const.method_types_ids[
+            (
+                "x-ray diffraction"
+                if override_method is None
+                else override_method.lower()
+            )
+        ]
+    )
+    if data.record is not None:
+        if (
+            override_method is None
+            and data.record.structure.method is not None
+            and data.record.structure.method.lower() in const.method_types_ids
+        ):
+            method = (method * 0) + const.method_types_ids[
+                data.record.structure.method.lower()
+            ]
+
+    method_feature = from_numpy(method).long()
+
+    # Token mask features
+    pad_mask = torch.ones(len(token_data), dtype=torch.float)
+    resolved_mask = from_numpy(token_data["resolved_mask"]).float()
+    disto_mask = from_numpy(token_data["disto_mask"]).float()
+
+    # Token bond features
+    if max_tokens is not None:
+        pad_len = max_tokens - len(token_data)
+        num_tokens = max_tokens if pad_len > 0 else len(token_data)
+    else:
+        num_tokens = len(token_data)
+
+    tok_to_idx = {tok["token_idx"]: idx for idx, tok in enumerate(token_data)}
+    bonds = torch.zeros(num_tokens, num_tokens, dtype=torch.float)
+    bonds_type = torch.zeros(num_tokens, num_tokens, dtype=torch.long)
+    for token_bond in token_bonds:
+        token_1 = tok_to_idx[token_bond["token_1"]]
+        token_2 = tok_to_idx[token_bond["token_2"]]
+        bonds[token_1, token_2] = 1
+        bonds[token_2, token_1] = 1
+        bond_type = token_bond["type"]
+        bonds_type[token_1, token_2] = bond_type
+        bonds_type[token_2, token_1] = bond_type
+
+    bonds = bonds.unsqueeze(-1)
+
+    # Pocket conditioned feature
+    contact_conditioning = (
+        np.zeros((len(token_data), len(token_data)))
+        + const.contact_conditioning_info["UNSELECTED"]
+    )
+    contact_threshold = np.zeros((len(token_data), len(token_data)))
+
+    if inference_pocket_constraints is not None:
+        for binder, contacts, max_distance in inference_pocket_constraints:
+            binder_mask = token_data["asym_id"] == binder
+
+            for idx, token in enumerate(token_data):
+                if (
+                    token["mol_type"] != const.chain_type_ids["NONPOLYMER"]
+                    and (token["asym_id"], token["res_idx"]) in contacts
+                ) or (
+                    token["mol_type"] == const.chain_type_ids["NONPOLYMER"]
+                    and (token["asym_id"], token["atom_idx"]) in contacts
+                ):
+                    contact_conditioning[binder_mask][:, idx] = (
+                        const.contact_conditioning_info["BINDER>POCKET"]
+                    )
+                    contact_conditioning[idx][binder_mask] = (
+                        const.contact_conditioning_info["POCKET>BINDER"]
+                    )
+                    contact_threshold[binder_mask][:, idx] = max_distance
+                    contact_threshold[idx][binder_mask] = max_distance
+
+    if binder_pocket_conditioned_prop > 0.0:
+        # choose as binder a random ligand in the crop; if there are no ligands, select a protein chain
+        binder_asym_ids = np.unique(
+            token_data["asym_id"][
+                token_data["mol_type"] == const.chain_type_ids["NONPOLYMER"]
+            ]
+        )
+
+        if len(binder_asym_ids) == 0:
+            if not only_ligand_binder_pocket:
+                binder_asym_ids = np.unique(token_data["asym_id"])
+
+        while random.random() < binder_pocket_conditioned_prop:
+            if len(binder_asym_ids) == 0:
+                break
+
+            pocket_asym_id = random.choice(binder_asym_ids)
+            binder_asym_ids = binder_asym_ids[binder_asym_ids != pocket_asym_id]
+
+            binder_pocket_cutoff = sample_d(
+                min_d=binder_pocket_cutoff_min,
+                max_d=binder_pocket_cutoff_max,
+                n_samples=1,
+                random=random,
+            )
+
+            binder_mask = token_data["asym_id"] == pocket_asym_id
+
+            binder_coords = []
+            for token in token_data:
+                if token["asym_id"] == pocket_asym_id:
+                    _coords = data.structure.atoms["coords"][
+                        token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+                    ]
+                    _is_present = data.structure.atoms["is_present"][
+                        token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+                    ]
+                    binder_coords.append(_coords[_is_present])
+            binder_coords = np.concatenate(binder_coords, axis=0)
+
+            # find the tokens in the pocket
+            token_dist = np.zeros(len(token_data)) + 1000
+            for i, token in enumerate(token_data):
+                if (
+                    token["mol_type"] != const.chain_type_ids["NONPOLYMER"]
+                    and token["asym_id"] != pocket_asym_id
+                    and token["resolved_mask"] == 1
+                ):
+                    token_coords = data.structure.atoms["coords"][
+                        token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+                    ]
+                    token_is_present = data.structure.atoms["is_present"][
+                        token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+                    ]
+                    token_coords = token_coords[token_is_present]
+
+                    # find chain and apply chain transformation
+                    for chain in data.structure.chains:
+                        if chain["asym_id"] == token["asym_id"]:
+                            break
+
+                    token_dist[i] = np.min(
+                        np.linalg.norm(
+                            token_coords[:, None, :] - binder_coords[None, :, :],
+                            axis=-1,
+                        )
+                    )
+
+            pocket_mask = token_dist < binder_pocket_cutoff
+
+            if np.sum(pocket_mask) > 0:
+                if binder_pocket_sampling_geometric_p > 0.0:
+                    # select a subset of the pocket, according
+                    # to a geometric distribution with one as minimum
+                    pocket_mask = select_subset_from_mask(
+                        pocket_mask,
+                        binder_pocket_sampling_geometric_p,
+                        random,
+                    )
+
+                contact_conditioning[np.ix_(binder_mask, pocket_mask)] = (
+                    const.contact_conditioning_info["BINDER>POCKET"]
+                )
+                contact_conditioning[np.ix_(pocket_mask, binder_mask)] = (
+                    const.contact_conditioning_info["POCKET>BINDER"]
+                )
+                contact_threshold[np.ix_(binder_mask, pocket_mask)] = (
+                    binder_pocket_cutoff
+                )
+                contact_threshold[np.ix_(pocket_mask, binder_mask)] = (
+                    binder_pocket_cutoff
+                )
+
+    # Contact conditioning feature
+    if contact_conditioned_prop > 0.0:
+        while random.random() < contact_conditioned_prop:
+            contact_cutoff = sample_d(
+                min_d=binder_pocket_cutoff_min,
+                max_d=binder_pocket_cutoff_max,
+                n_samples=1,
+                random=random,
+            )
+            if only_pp_contact:
+                chain_asym_ids = np.unique(
+                    token_data["asym_id"][
+                        token_data["mol_type"] == const.chain_type_ids["PROTEIN"]
+                    ]
+                )
+            else:
+                chain_asym_ids = np.unique(token_data["asym_id"])
+
+            if len(chain_asym_ids) > 1:
+                chain_asym_id = random.choice(chain_asym_ids)
+
+                chain_coords = []
+                for token in token_data:
+                    if token["asym_id"] == chain_asym_id:
+                        _coords = data.structure.atoms["coords"][
+                            token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+                        ]
+                        _is_present = data.structure.atoms["is_present"][
+                            token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+                        ]
+                        chain_coords.append(_coords[_is_present])
+                chain_coords = np.concatenate(chain_coords, axis=0)
+
+                # find contacts in other chains
+                possible_other_chains = []
+                for other_chain_id in chain_asym_ids[chain_asym_ids != chain_asym_id]:
+                    for token in token_data:
+                        if token["asym_id"] == other_chain_id:
+                            _coords = data.structure.atoms["coords"][
+                                token["atom_idx"] : token["atom_idx"]
+                                + token["atom_num"]
+                            ]
+                            _is_present = data.structure.atoms["is_present"][
+                                token["atom_idx"] : token["atom_idx"]
+                                + token["atom_num"]
+                            ]
+                            if _is_present.sum() == 0:
+                                continue
+                            token_coords = _coords[_is_present]
+
+                            # check minimum distance
+                            if (
+                                np.min(cdist(chain_coords, token_coords))
+                                < contact_cutoff
+                            ):
+                                possible_other_chains.append(other_chain_id)
+                                break
+
+                if len(possible_other_chains) > 0:
+                    other_chain_id = random.choice(possible_other_chains)
+
+                    pairs = []
+                    for token_1 in token_data:
+                        if token_1["asym_id"] == chain_asym_id:
+                            _coords = data.structure.atoms["coords"][
+                                token_1["atom_idx"] : token_1["atom_idx"]
+                                + token_1["atom_num"]
+                            ]
+                            _is_present = data.structure.atoms["is_present"][
+                                token_1["atom_idx"] : token_1["atom_idx"]
+                                + token_1["atom_num"]
+                            ]
+                            if _is_present.sum() == 0:
+                                continue
+                            token_1_coords = _coords[_is_present]
+
+                            for token_2 in token_data:
+                                if token_2["asym_id"] == other_chain_id:
+                                    _coords = data.structure.atoms["coords"][
+                                        token_2["atom_idx"] : token_2["atom_idx"]
+                                        + token_2["atom_num"]
+                                    ]
+                                    _is_present = data.structure.atoms["is_present"][
+                                        token_2["atom_idx"] : token_2["atom_idx"]
+                                        + token_2["atom_num"]
+                                    ]
+                                    if _is_present.sum() == 0:
+                                        continue
+                                    token_2_coords = _coords[_is_present]
+
+                                    if (
+                                        np.min(cdist(token_1_coords, token_2_coords))
+                                        < contact_cutoff
+                                    ):
+                                        pairs.append(
+                                            (token_1["token_idx"], token_2["token_idx"])
+                                        )
+
+                    assert len(pairs) > 0
+
+                    pair = random.choice(pairs)
+                    token_1_mask = token_data["token_idx"] == pair[0]
+                    token_2_mask = token_data["token_idx"] == pair[1]
+
+                    contact_conditioning[np.ix_(token_1_mask, token_2_mask)] = (
+                        const.contact_conditioning_info["CONTACT"]
+                    )
+                    contact_conditioning[np.ix_(token_2_mask, token_1_mask)] = (
+                        const.contact_conditioning_info["CONTACT"]
+                    )
+
+            elif not only_pp_contact:
+                # only one chain: find contacts within the chain with a minimum residue separation
+                pairs = []
+                for token_1 in token_data:
+                    _coords = data.structure.atoms["coords"][
+                        token_1["atom_idx"] : token_1["atom_idx"] + token_1["atom_num"]
+                    ]
+                    _is_present = data.structure.atoms["is_present"][
+                        token_1["atom_idx"] : token_1["atom_idx"] + token_1["atom_num"]
+                    ]
+                    if _is_present.sum() == 0:
+                        continue
+                    token_1_coords = _coords[_is_present]
+
+                    for token_2 in token_data:
+                        if np.abs(token_1["res_idx"] - token_2["res_idx"]) <= 8:
+                            continue
+
+                        _coords = data.structure.atoms["coords"][
+                            token_2["atom_idx"] : token_2["atom_idx"]
+                            + token_2["atom_num"]
+                        ]
+                        _is_present = data.structure.atoms["is_present"][
+                            token_2["atom_idx"] : token_2["atom_idx"]
+                            + token_2["atom_num"]
+                        ]
+                        if _is_present.sum() == 0:
+                            continue
+                        token_2_coords = _coords[_is_present]
+
+                        if (
+                            np.min(cdist(token_1_coords, token_2_coords))
+                            < contact_cutoff
+                        ):
+                            pairs.append((token_1["token_idx"], token_2["token_idx"]))
+
+                if len(pairs) > 0:
+                    pair = random.choice(pairs)
+                    token_1_mask = token_data["token_idx"] == pair[0]
+                    token_2_mask = token_data["token_idx"] == pair[1]
+
+                    contact_conditioning[np.ix_(token_1_mask, token_2_mask)] = (
+                        const.contact_conditioning_info["CONTACT"]
+                    )
+                    contact_conditioning[np.ix_(token_2_mask, token_1_mask)] = (
+                        const.contact_conditioning_info["CONTACT"]
+                    )
+
+    if np.all(contact_conditioning == const.contact_conditioning_info["UNSELECTED"]):
+        contact_conditioning = (
+            contact_conditioning
+            - const.contact_conditioning_info["UNSELECTED"]
+            + const.contact_conditioning_info["UNSPECIFIED"]
+        )
+    contact_conditioning = from_numpy(contact_conditioning).long()
+    contact_conditioning = one_hot(
+        contact_conditioning, num_classes=len(const.contact_conditioning_info)
+    )
+    contact_threshold = from_numpy(contact_threshold).float()
+
+    # compute cyclic polymer mask
+    cyclic_ids = {}
+    for idx_chain, asym_id_iter in enumerate(data.structure.chains["asym_id"]):
+        for connection in data.structure.bonds:
+            if (
+                idx_chain == connection["chain_1"] == connection["chain_2"]
+                and data.structure.chains[connection["chain_1"]]["res_num"] > 2
+                and connection["res_1"]
+                != connection["res_2"]  # Avoid same residue bonds!
+            ):
+                if (
+                    data.structure.chains[connection["chain_1"]]["res_num"]
+                    == (connection["res_2"] + 1)
+                    and connection["res_1"] == 0
+                ) or (
+                    data.structure.chains[connection["chain_1"]]["res_num"]
+                    == (connection["res_1"] + 1)
+                    and connection["res_2"] == 0
+                ):
+                    cyclic_ids[asym_id_iter] = data.structure.chains[
+                        connection["chain_1"]
+                    ]["res_num"]
+    cyclic = from_numpy(
+        np.array(
+            [
+                (cyclic_ids[asym_id_iter] if asym_id_iter in cyclic_ids else 0)
+                for asym_id_iter in token_data["asym_id"]
+            ]
+        )
+    ).float()
+
+    # cyclic period is either computed from the bonds or given as input flag
+    cyclic_period = torch.maximum(cyclic, cyclic_period)
+
+    # Pad to max tokens if given
+    if max_tokens is not None:
+        pad_len = max_tokens - len(token_data)
+        if pad_len > 0:
+            token_index = pad_dim(token_index, 0, pad_len)
+            residue_index = pad_dim(residue_index, 0, pad_len)
+            asym_id = pad_dim(asym_id, 0, pad_len)
+            entity_id = pad_dim(entity_id, 0, pad_len)
+            sym_id = pad_dim(sym_id, 0, pad_len)
+            mol_type = pad_dim(mol_type, 0, pad_len)
+            res_type = pad_dim(res_type, 0, pad_len)
+            disto_center = pad_dim(disto_center, 0, pad_len)
+            pad_mask = pad_dim(pad_mask, 0, pad_len)
+            resolved_mask = pad_dim(resolved_mask, 0, pad_len)
+            disto_mask = pad_dim(disto_mask, 0, pad_len)
+            contact_conditioning = pad_dim(contact_conditioning, 0, pad_len)
+            contact_conditioning = pad_dim(contact_conditioning, 1, pad_len)
+            contact_threshold = pad_dim(contact_threshold, 0, pad_len)
+            contact_threshold = pad_dim(contact_threshold, 1, pad_len)
+            method_feature = pad_dim(method_feature, 0, pad_len)
+            modified = pad_dim(modified, 0, pad_len)
+            cyclic_period = pad_dim(cyclic_period, 0, pad_len)
+            affinity_mask = pad_dim(affinity_mask, 0, pad_len)
+
+    token_features = {
+        "token_index": token_index,
+        "residue_index": residue_index,
+        "asym_id": asym_id,
+        "entity_id": entity_id,
+        "sym_id": sym_id,
+        "mol_type": mol_type,
+        "res_type": res_type,
+        "disto_center": disto_center,
+        "token_bonds": bonds,
+        "type_bonds": bonds_type,
+        "token_pad_mask": pad_mask,
+        "token_resolved_mask": resolved_mask,
+        "token_disto_mask": disto_mask,
+        "contact_conditioning": contact_conditioning,
+        "contact_threshold": contact_threshold,
+        "method_feature": method_feature,
+        "modified": modified,
+        "cyclic_period": cyclic_period,
+        "affinity_token_mask": affinity_mask,
+    }
+
+    return token_features
+
+
+def process_atom_features(
+    data: Tokenized,
+    random: np.random.Generator,
+    ensemble_features: dict,
+    molecules: dict[str, Mol],
+    atoms_per_window_queries: int = 32,
+    min_dist: float = 2.0,
+    max_dist: float = 22.0,
+    num_bins: int = 64,
+    max_atoms: Optional[int] = None,
+    max_tokens: Optional[int] = None,
+    disto_use_ensemble: Optional[bool] = False,
+    override_bfactor: bool = False,
+    compute_frames: bool = False,
+    override_coords: Optional[Tensor] = None,
+    bfactor_md_correction: bool = False,
+) -> dict[str, Tensor]:
+    """Get the atom features.
+
+    Parameters
+    ----------
+    data : Tokenized
+        The input to the model.
+    max_atoms : int, optional
+        The maximum number of atoms.
+
+    Returns
+    -------
+    dict[str, Tensor]
+        The atom features.
+
+    """
+    # Filter to tokens' atoms
+    atom_data = []
+    atom_name = []
+    atom_element = []
+    atom_charge = []
+    atom_conformer = []
+    atom_chirality = []
+    ref_space_uid = []
+    coord_data = []
+    if compute_frames:
+        frame_data = []
+        resolved_frame_data = []
+    atom_to_token = []
+    token_to_rep_atom = []  # index on cropped atom table
+    r_set_to_rep_atom = []
+    disto_coords_ensemble = []
+    backbone_feat_index = []
+    token_to_center_atom = []
+
+    e_offsets = data.structure.ensemble["atom_coord_idx"]
+    atom_idx = 0
+
+    # Start atom idx in full atom table for structures chosen. Up to num_ensembles points.
+    ensemble_atom_starts = [
+        data.structure.ensemble[idx]["atom_coord_idx"]
+        for idx in ensemble_features["ensemble_ref_idxs"]
+    ]
+
+    # Set unk chirality id
+    unk_chirality = const.chirality_type_ids[const.unk_chirality_type]
+
+    chain_res_ids = {}
+    res_index_to_conf_id = {}
+    for token_id, token in enumerate(data.tokens):
+        # Get the chain residue ids
+        chain_idx, res_id = token["asym_id"], token["res_idx"]
+        chain = data.structure.chains[chain_idx]
+
+        if (chain_idx, res_id) not in chain_res_ids:
+            new_idx = len(chain_res_ids)
+            chain_res_ids[(chain_idx, res_id)] = new_idx
+        else:
+            new_idx = chain_res_ids[(chain_idx, res_id)]
+
+        # Get the molecule and conformer
+        mol = molecules[token["res_name"]]
+        atom_name_to_ref = {a.GetProp("name"): a for a in mol.GetAtoms()}
+
+        # Sample a random conformer
+        if (chain_idx, res_id) not in res_index_to_conf_id:
+            conf_ids = [int(conf.GetId()) for conf in mol.GetConformers()]
+            conf_id = int(random.choice(conf_ids))
+            res_index_to_conf_id[(chain_idx, res_id)] = conf_id
+
+        conf_id = res_index_to_conf_id[(chain_idx, res_id)]
+        conformer = mol.GetConformer(conf_id)
+
+        # Map atoms to token indices
+        ref_space_uid.extend([new_idx] * token["atom_num"])
+        atom_to_token.extend([token_id] * token["atom_num"])
+
+        # Add atom data
+        start = token["atom_idx"]
+        end = token["atom_idx"] + token["atom_num"]
+        token_atoms = data.structure.atoms[start:end]
+
+        # Add atom ref data
+        # element, charge, conformer, chirality
+        token_atom_name = np.array([convert_atom_name(a["name"]) for a in token_atoms])
+        token_atoms_ref = np.array([atom_name_to_ref[a["name"]] for a in token_atoms])
+        token_atoms_element = np.array([a.GetAtomicNum() for a in token_atoms_ref])
+        token_atoms_charge = np.array([a.GetFormalCharge() for a in token_atoms_ref])
+        token_atoms_conformer = np.array(
+            [
+                (
+                    conformer.GetAtomPosition(a.GetIdx()).x,
+                    conformer.GetAtomPosition(a.GetIdx()).y,
+                    conformer.GetAtomPosition(a.GetIdx()).z,
+                )
+                for a in token_atoms_ref
+            ]
+        )
+        token_atoms_chirality = np.array(
+            [
+                const.chirality_type_ids.get(a.GetChiralTag().name, unk_chirality)
+                for a in token_atoms_ref
+            ]
+        )
+
+        # Map token to representative atom
+        token_to_rep_atom.append(atom_idx + token["disto_idx"] - start)
+        token_to_center_atom.append(atom_idx + token["center_idx"] - start)
+        if (chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]) and token[
+            "resolved_mask"
+        ]:
+            r_set_to_rep_atom.append(atom_idx + token["center_idx"] - start)
+
+        if chain["mol_type"] == const.chain_type_ids["PROTEIN"]:
+            backbone_index = [
+                (
+                    const.protein_backbone_atom_index[atom_name] + 1
+                    if atom_name in const.protein_backbone_atom_index
+                    else 0
+                )
+                for atom_name in token_atoms["name"]
+            ]
+        elif (
+            chain["mol_type"] == const.chain_type_ids["DNA"]
+            or chain["mol_type"] == const.chain_type_ids["RNA"]
+        ):
+            backbone_index = [
+                (
+                    const.nucleic_backbone_atom_index[atom_name]
+                    + 1
+                    + len(const.protein_backbone_atom_index)
+                    if atom_name in const.nucleic_backbone_atom_index
+                    else 0
+                )
+                for atom_name in token_atoms["name"]
+            ]
+        else:
+            backbone_index = [0] * token["atom_num"]
+        backbone_feat_index.extend(backbone_index)
+
+        # Get token coordinates across sampled ensembles and apply transforms
+        token_coords = np.array(
+            [
+                data.structure.coords[
+                    ensemble_atom_start + start : ensemble_atom_start + end
+                ]["coords"]
+                for ensemble_atom_start in ensemble_atom_starts
+            ]
+        )
+        coord_data.append(token_coords)
+
+        if compute_frames:
+            # Get frame data
+            res_type = const.tokens[token["res_type"]]
+            res_name = str(token["res_name"])
+
+            if token["atom_num"] < 3 or res_type in ["PAD", "UNK", "-"]:
+                idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
+                mask_frame = False
+            elif (token["mol_type"] == const.chain_type_ids["PROTEIN"]) and (
+                res_name in const.ref_atoms
+            ):
+                idx_frame_a, idx_frame_b, idx_frame_c = (
+                    const.ref_atoms[res_name].index("N"),
+                    const.ref_atoms[res_name].index("CA"),
+                    const.ref_atoms[res_name].index("C"),
+                )
+                mask_frame = (
+                    token_atoms["is_present"][idx_frame_a]
+                    and token_atoms["is_present"][idx_frame_b]
+                    and token_atoms["is_present"][idx_frame_c]
+                )
+            elif (
+                token["mol_type"] == const.chain_type_ids["DNA"]
+                or token["mol_type"] == const.chain_type_ids["RNA"]
+            ) and (res_name in const.ref_atoms):
+                idx_frame_a, idx_frame_b, idx_frame_c = (
+                    const.ref_atoms[res_name].index("C1'"),
+                    const.ref_atoms[res_name].index("C3'"),
+                    const.ref_atoms[res_name].index("C4'"),
+                )
+                mask_frame = (
+                    token_atoms["is_present"][idx_frame_a]
+                    and token_atoms["is_present"][idx_frame_b]
+                    and token_atoms["is_present"][idx_frame_c]
+                )
+            elif token["mol_type"] == const.chain_type_ids["PROTEIN"]:
+                # Try to look for the atom names in the modified residue
+                is_ca = token_atoms["name"] == "CA"
+                idx_frame_a = is_ca.argmax()
+                ca_present = (
+                    token_atoms[idx_frame_a]["is_present"] if is_ca.any() else False
+                )
+
+                is_n = token_atoms["name"] == "N"
+                idx_frame_b = is_n.argmax()
+                n_present = (
+                    token_atoms[idx_frame_b]["is_present"] if is_n.any() else False
+                )
+
+                is_c = token_atoms["name"] == "C"
+                idx_frame_c = is_c.argmax()
+                c_present = (
+                    token_atoms[idx_frame_c]["is_present"] if is_c.any() else False
+                )
+                mask_frame = ca_present and n_present and c_present
+
+            elif (token["mol_type"] == const.chain_type_ids["DNA"]) or (
+                token["mol_type"] == const.chain_type_ids["RNA"]
+            ):
+                # Try to look for the atom names in the modified residue
+                is_c1 = token_atoms["name"] == "C1'"
+                idx_frame_a = is_c1.argmax()
+                c1_present = (
+                    token_atoms[idx_frame_a]["is_present"] if is_c1.any() else False
+                )
+
+                is_c3 = token_atoms["name"] == "C3'"
+                idx_frame_b = is_c3.argmax()
+                c3_present = (
+                    token_atoms[idx_frame_b]["is_present"] if is_c3.any() else False
+                )
+
+                is_c4 = token_atoms["name"] == "C4'"
+                idx_frame_c = is_c4.argmax()
+                c4_present = (
+                    token_atoms[idx_frame_c]["is_present"] if is_c4.any() else False
+                )
+                mask_frame = c1_present and c3_present and c4_present
+            else:
+                idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
+                mask_frame = False
+            frame_data.append(
+                [
+                    idx_frame_a + atom_idx,
+                    idx_frame_b + atom_idx,
+                    idx_frame_c + atom_idx,
+                ]
+            )
+            resolved_frame_data.append(mask_frame)
+
+        # Get distogram coordinates
+        disto_coords_ensemble_tok = data.structure.coords[
+            e_offsets + token["disto_idx"]
+        ]["coords"]
+        disto_coords_ensemble.append(disto_coords_ensemble_tok)
+
+        # Update atom data. This is technically never used again (we rely on coord_data),
+        # but we update for consistency and to make sure the Atom object has valid, transformed coordinates.
+        token_atoms = token_atoms.copy()
+        token_atoms["coords"] = token_coords[
+            0
+        ]  # atom has a copy of first coords in ensemble
+        atom_data.append(token_atoms)
+        atom_name.append(token_atom_name)
+        atom_element.append(token_atoms_element)
+        atom_charge.append(token_atoms_charge)
+        atom_conformer.append(token_atoms_conformer)
+        atom_chirality.append(token_atoms_chirality)
+        atom_idx += len(token_atoms)
+
+    disto_coords_ensemble = np.array(disto_coords_ensemble)  # (N_TOK, N_ENS, 3)
+
+    # Compute ensemble distogram
+    L = len(data.tokens)
+
+    if disto_use_ensemble:
+        # Use all available structures to create distogram
+        idx_list = range(disto_coords_ensemble.shape[1])
+    else:
+        # Only use the sampled structures to create distogram
+        idx_list = ensemble_features["ensemble_ref_idxs"]
+
+    # Create distogram
+    disto_target = torch.zeros(L, L, len(idx_list), num_bins)  # TODO1
+
+    # disto_target = torch.zeros(L, L, num_bins)
+    for i, e_idx in enumerate(idx_list):
+        t_center = torch.Tensor(disto_coords_ensemble[:, e_idx, :])
+        t_dists = torch.cdist(t_center, t_center)
+        boundaries = torch.linspace(min_dist, max_dist, num_bins - 1)
+        distogram = (t_dists.unsqueeze(-1) > boundaries).sum(dim=-1).long()
+        # disto_target += one_hot(distogram, num_classes=num_bins)
+        disto_target[:, :, i, :] = one_hot(distogram, num_classes=num_bins)  # TODO1
+
+    # Normalize distogram
+    # disto_target = disto_target / disto_target.sum(-1)[..., None]  # remove TODO1
+    atom_data = np.concatenate(atom_data)
+    atom_name = np.concatenate(atom_name)
+    atom_element = np.concatenate(atom_element)
+    atom_charge = np.concatenate(atom_charge)
+    atom_conformer = np.concatenate(atom_conformer)
+    atom_chirality = np.concatenate(atom_chirality)
+    coord_data = np.concatenate(coord_data, axis=1)
+    ref_space_uid = np.array(ref_space_uid)
+
+    # Compute features
+    disto_coords_ensemble = from_numpy(disto_coords_ensemble)
+    disto_coords_ensemble = disto_coords_ensemble[
+        :, ensemble_features["ensemble_ref_idxs"]
+    ].permute(1, 0, 2)
+    backbone_feat_index = from_numpy(np.asarray(backbone_feat_index)).long()
+    ref_atom_name_chars = from_numpy(atom_name).long()
+    ref_element = from_numpy(atom_element).long()
+    ref_charge = from_numpy(atom_charge).float()
+    ref_pos = from_numpy(atom_conformer).float()
+    ref_space_uid = from_numpy(ref_space_uid)
+    ref_chirality = from_numpy(atom_chirality).long()
+    coords = from_numpy(coord_data.copy())
+    resolved_mask = from_numpy(atom_data["is_present"])
+    pad_mask = torch.ones(len(atom_data), dtype=torch.float)
+    atom_to_token = torch.tensor(atom_to_token, dtype=torch.long)
+    token_to_rep_atom = torch.tensor(token_to_rep_atom, dtype=torch.long)
+    r_set_to_rep_atom = torch.tensor(r_set_to_rep_atom, dtype=torch.long)
+    token_to_center_atom = torch.tensor(token_to_center_atom, dtype=torch.long)
+    bfactor = from_numpy(atom_data["bfactor"].copy())
+    plddt = from_numpy(atom_data["plddt"].copy())
+    if override_bfactor:
+        bfactor = bfactor * 0.0
+
+    if bfactor_md_correction and data.record.structure.method.lower() == "md":
+        # MD bfactor was computed as RMSF
+        # Convert to b-factor
+        bfactor = 8 * (np.pi**2) * (bfactor**2)
+
+    # We compute frames within ensemble
+    if compute_frames:
+        frames = []
+        frame_resolved_mask = []
+        for i in range(coord_data.shape[0]):
+            frame_data_, resolved_frame_data_ = compute_frames_nonpolymer(
+                data,
+                coord_data[i],
+                atom_data["is_present"],
|
1435
|
+
atom_to_token,
|
1436
|
+
frame_data,
|
1437
|
+
resolved_frame_data,
|
1438
|
+
) # Compute frames for NONPOLYMER tokens
|
1439
|
+
frames.append(frame_data_.copy())
|
1440
|
+
frame_resolved_mask.append(resolved_frame_data_.copy())
|
1441
|
+
frames = from_numpy(np.stack(frames)) # (N_ENS, N_TOK, 3)
|
1442
|
+
frame_resolved_mask = from_numpy(np.stack(frame_resolved_mask))
|
1443
|
+
|
1444
|
+
# Convert to one-hot
|
1445
|
+
backbone_feat_index = one_hot(
|
1446
|
+
backbone_feat_index,
|
1447
|
+
num_classes=1
|
1448
|
+
+ len(const.protein_backbone_atom_index)
|
1449
|
+
+ len(const.nucleic_backbone_atom_index),
|
1450
|
+
)
|
1451
|
+
ref_atom_name_chars = one_hot(ref_atom_name_chars, num_classes=64)
|
1452
|
+
ref_element = one_hot(ref_element, num_classes=const.num_elements)
|
1453
|
+
atom_to_token = one_hot(atom_to_token, num_classes=token_id + 1)
|
1454
|
+
token_to_rep_atom = one_hot(token_to_rep_atom, num_classes=len(atom_data))
|
1455
|
+
r_set_to_rep_atom = one_hot(r_set_to_rep_atom, num_classes=len(atom_data))
|
1456
|
+
token_to_center_atom = one_hot(token_to_center_atom, num_classes=len(atom_data))
|
1457
|
+
|
1458
|
+
# Center the ground truth coordinates
|
1459
|
+
center = (coords * resolved_mask[None, :, None]).sum(dim=1)
|
1460
|
+
center = center / resolved_mask.sum().clamp(min=1)
|
1461
|
+
coords = coords - center[:, None]
|
1462
|
+
|
1463
|
+
if isinstance(override_coords, Tensor):
|
1464
|
+
coords = override_coords.unsqueeze(0)
|
1465
|
+
|
1466
|
+
# Apply random roto-translation to the input conformers
|
1467
|
+
for i in range(torch.max(ref_space_uid)):
|
1468
|
+
included = ref_space_uid == i
|
1469
|
+
if torch.sum(included) > 0 and torch.any(resolved_mask[included]):
|
1470
|
+
ref_pos[included] = center_random_augmentation(
|
1471
|
+
ref_pos[included][None], resolved_mask[included][None], centering=True
|
1472
|
+
)[0]
|
1473
|
+
|
1474
|
+
# Compute padding and apply
|
1475
|
+
if max_atoms is not None:
|
1476
|
+
assert max_atoms % atoms_per_window_queries == 0
|
1477
|
+
pad_len = max_atoms - len(atom_data)
|
1478
|
+
else:
|
1479
|
+
pad_len = (
|
1480
|
+
(len(atom_data) - 1) // atoms_per_window_queries + 1
|
1481
|
+
) * atoms_per_window_queries - len(atom_data)
|
1482
|
+
|
1483
|
+
if pad_len > 0:
|
1484
|
+
pad_mask = pad_dim(pad_mask, 0, pad_len)
|
1485
|
+
ref_pos = pad_dim(ref_pos, 0, pad_len)
|
1486
|
+
resolved_mask = pad_dim(resolved_mask, 0, pad_len)
|
1487
|
+
ref_atom_name_chars = pad_dim(ref_atom_name_chars, 0, pad_len)
|
1488
|
+
ref_element = pad_dim(ref_element, 0, pad_len)
|
1489
|
+
ref_charge = pad_dim(ref_charge, 0, pad_len)
|
1490
|
+
ref_chirality = pad_dim(ref_chirality, 0, pad_len)
|
1491
|
+
backbone_feat_index = pad_dim(backbone_feat_index, 0, pad_len)
|
1492
|
+
ref_space_uid = pad_dim(ref_space_uid, 0, pad_len)
|
1493
|
+
coords = pad_dim(coords, 1, pad_len)
|
1494
|
+
atom_to_token = pad_dim(atom_to_token, 0, pad_len)
|
1495
|
+
token_to_rep_atom = pad_dim(token_to_rep_atom, 1, pad_len)
|
1496
|
+
token_to_center_atom = pad_dim(token_to_center_atom, 1, pad_len)
|
1497
|
+
r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 1, pad_len)
|
1498
|
+
bfactor = pad_dim(bfactor, 0, pad_len)
|
1499
|
+
plddt = pad_dim(plddt, 0, pad_len)
|
1500
|
+
|
1501
|
+
if max_tokens is not None:
|
1502
|
+
pad_len = max_tokens - token_to_rep_atom.shape[0]
|
1503
|
+
if pad_len > 0:
|
1504
|
+
atom_to_token = pad_dim(atom_to_token, 1, pad_len)
|
1505
|
+
token_to_rep_atom = pad_dim(token_to_rep_atom, 0, pad_len)
|
1506
|
+
r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 0, pad_len)
|
1507
|
+
token_to_center_atom = pad_dim(token_to_center_atom, 0, pad_len)
|
1508
|
+
disto_target = pad_dim(pad_dim(disto_target, 0, pad_len), 1, pad_len)
|
1509
|
+
disto_coords_ensemble = pad_dim(disto_coords_ensemble, 1, pad_len)
|
1510
|
+
|
1511
|
+
if compute_frames:
|
1512
|
+
frames = pad_dim(frames, 1, pad_len)
|
1513
|
+
frame_resolved_mask = pad_dim(frame_resolved_mask, 1, pad_len)
|
1514
|
+
|
1515
|
+
atom_features = {
|
1516
|
+
"ref_pos": ref_pos,
|
1517
|
+
"atom_resolved_mask": resolved_mask,
|
1518
|
+
"ref_atom_name_chars": ref_atom_name_chars,
|
1519
|
+
"ref_element": ref_element,
|
1520
|
+
"ref_charge": ref_charge,
|
1521
|
+
"ref_chirality": ref_chirality,
|
1522
|
+
"atom_backbone_feat": backbone_feat_index,
|
1523
|
+
"ref_space_uid": ref_space_uid,
|
1524
|
+
"coords": coords,
|
1525
|
+
"atom_pad_mask": pad_mask,
|
1526
|
+
"atom_to_token": atom_to_token,
|
1527
|
+
"token_to_rep_atom": token_to_rep_atom,
|
1528
|
+
"r_set_to_rep_atom": r_set_to_rep_atom,
|
1529
|
+
"token_to_center_atom": token_to_center_atom,
|
1530
|
+
"disto_target": disto_target,
|
1531
|
+
"disto_coords_ensemble": disto_coords_ensemble,
|
1532
|
+
"bfactor": bfactor,
|
1533
|
+
"plddt": plddt,
|
1534
|
+
}
|
1535
|
+
|
1536
|
+
if compute_frames:
|
1537
|
+
atom_features["frames_idx"] = frames
|
1538
|
+
atom_features["frame_resolved_mask"] = frame_resolved_mask
|
1539
|
+
|
1540
|
+
return atom_features
|
1541
|
+
|
1542
|
+
|
1543
|
+
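
# Illustrative sketch (not part of the package): the distogram loop above reduces to
# binning pairwise distances against `num_bins - 1` boundaries and one-hot encoding
# the resulting bin index. A minimal, self-contained version of that step, relying
# only on the module's existing torch import, could look like this:
def _example_distogram_sketch(
    coords: "torch.Tensor",
    min_dist: float = 2.0,
    max_dist: float = 22.0,
    num_bins: int = 64,
) -> "torch.Tensor":
    """Bin pairwise distances of (N, 3) coordinates into a (N, N, num_bins) one-hot distogram."""
    dists = torch.cdist(coords, coords)
    boundaries = torch.linspace(min_dist, max_dist, num_bins - 1)
    bins = (dists.unsqueeze(-1) > boundaries).sum(dim=-1).long()
    return torch.nn.functional.one_hot(bins, num_classes=num_bins)
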
def process_msa_features(
    data: Tokenized,
    random: np.random.Generator,
    max_seqs_batch: int,
    max_seqs: int,
    max_tokens: Optional[int] = None,
    pad_to_max_seqs: bool = False,
    msa_sampling: bool = False,
    affinity: bool = False,
) -> dict[str, Tensor]:
    """Get the MSA features.

    Parameters
    ----------
    data : Tokenized
        The input to the model.
    random : np.random.Generator
        The random number generator.
    max_seqs_batch : int
        The maximum number of MSA sequences sampled for this batch.
    max_seqs : int
        The maximum number of MSA sequences.
    max_tokens : int
        The maximum number of tokens.
    pad_to_max_seqs : bool
        Whether to pad to the maximum number of sequences.
    msa_sampling : bool
        Whether to sample the MSA.
    affinity : bool
        Whether to return the affinity-specific MSA features.

    Returns
    -------
    dict[str, Tensor]
        The MSA features.

    """
    # Create the paired MSA
    msa, deletion, paired = construct_paired_msa(
        data=data,
        random=random,
        max_seqs=max_seqs_batch,
        random_subset=msa_sampling,
    )
    msa, deletion, paired = (
        msa.transpose(1, 0),
        deletion.transpose(1, 0),
        paired.transpose(1, 0),
    )  # (N_MSA, N_RES, N_AA)

    # Prepare features
    assert torch.all(msa >= 0) and torch.all(msa < const.num_tokens)
    msa_one_hot = torch.nn.functional.one_hot(msa, num_classes=const.num_tokens)
    msa_mask = torch.ones_like(msa)
    profile = msa_one_hot.float().mean(dim=0)
    has_deletion = deletion > 0
    deletion = np.pi / 2 * np.arctan(deletion / 3)
    deletion_mean = deletion.mean(axis=0)

    # Pad in the MSA dimension (dim=0)
    if pad_to_max_seqs:
        pad_len = max_seqs - msa.shape[0]
        if pad_len > 0:
            msa = pad_dim(msa, 0, pad_len, const.token_ids["-"])
            paired = pad_dim(paired, 0, pad_len)
            msa_mask = pad_dim(msa_mask, 0, pad_len)
            has_deletion = pad_dim(has_deletion, 0, pad_len)
            deletion = pad_dim(deletion, 0, pad_len)

    # Pad in the token dimension (dim=1)
    if max_tokens is not None:
        pad_len = max_tokens - msa.shape[1]
        if pad_len > 0:
            msa = pad_dim(msa, 1, pad_len, const.token_ids["-"])
            paired = pad_dim(paired, 1, pad_len)
            msa_mask = pad_dim(msa_mask, 1, pad_len)
            has_deletion = pad_dim(has_deletion, 1, pad_len)
            deletion = pad_dim(deletion, 1, pad_len)
            profile = pad_dim(profile, 0, pad_len)
            deletion_mean = pad_dim(deletion_mean, 0, pad_len)
    if affinity:
        return {
            "deletion_mean_affinity": deletion_mean,
            "profile_affinity": profile,
        }
    else:
        return {
            "msa": msa,
            "msa_paired": paired,
            "deletion_value": deletion,
            "has_deletion": has_deletion,
            "deletion_mean": deletion_mean,
            "profile": profile,
            "msa_mask": msa_mask,
        }

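
# Illustrative sketch (not part of the package): the profile and deletion features
# above are per-column statistics of the MSA. Assuming a (N_MSA, N_RES) integer MSA
# tensor and a matching deletion-count tensor, and using the module's torch/np
# imports, they can be reproduced as follows (torch.arctan is used here in place of
# the np.arctan call above; the squashing constant is the same):
def _example_msa_profile_sketch(
    msa: "torch.Tensor", deletion: "torch.Tensor", num_classes: int
):
    """Return (profile, has_deletion, deletion_value, deletion_mean) for a toy MSA."""
    one_hot_msa = torch.nn.functional.one_hot(msa, num_classes=num_classes)
    profile = one_hot_msa.float().mean(dim=0)  # per-residue residue-type frequencies
    has_deletion = deletion > 0  # boolean deletion indicator
    deletion_value = np.pi / 2 * torch.arctan(deletion / 3)  # squashed deletion counts
    return profile, has_deletion, deletion_value, deletion_value.mean(dim=0)
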
def load_dummy_templates_features(tdim: int, num_tokens: int) -> dict:
    """Load dummy templates for v2."""
    # Allocate features
    res_type = np.zeros((tdim, num_tokens), dtype=np.int64)
    frame_rot = np.zeros((tdim, num_tokens, 3, 3), dtype=np.float32)
    frame_t = np.zeros((tdim, num_tokens, 3), dtype=np.float32)
    cb_coords = np.zeros((tdim, num_tokens, 3), dtype=np.float32)
    ca_coords = np.zeros((tdim, num_tokens, 3), dtype=np.float32)
    frame_mask = np.zeros((tdim, num_tokens), dtype=np.float32)
    cb_mask = np.zeros((tdim, num_tokens), dtype=np.float32)
    template_mask = np.zeros((tdim, num_tokens), dtype=np.float32)
    query_to_template = np.zeros((tdim, num_tokens), dtype=np.int64)
    visibility_ids = np.zeros((tdim, num_tokens), dtype=np.float32)

    # Convert to one-hot
    res_type = torch.from_numpy(res_type)
    res_type = one_hot(res_type, num_classes=const.num_tokens)

    return {
        "template_restype": res_type,
        "template_frame_rot": torch.from_numpy(frame_rot),
        "template_frame_t": torch.from_numpy(frame_t),
        "template_cb": torch.from_numpy(cb_coords),
        "template_ca": torch.from_numpy(ca_coords),
        "template_mask_cb": torch.from_numpy(cb_mask),
        "template_mask_frame": torch.from_numpy(frame_mask),
        "template_mask": torch.from_numpy(template_mask),
        "query_to_template": torch.from_numpy(query_to_template),
        "visibility_ids": torch.from_numpy(visibility_ids),
    }


def compute_template_features(
    query_tokens: Tokenized,
    tmpl_tokens: list[dict],
    num_tokens: int,
) -> dict:
    """Compute the template features."""
    # Allocate features
    res_type = np.zeros((num_tokens,), dtype=np.int64)
    frame_rot = np.zeros((num_tokens, 3, 3), dtype=np.float32)
    frame_t = np.zeros((num_tokens, 3), dtype=np.float32)
    cb_coords = np.zeros((num_tokens, 3), dtype=np.float32)
    ca_coords = np.zeros((num_tokens, 3), dtype=np.float32)
    frame_mask = np.zeros((num_tokens,), dtype=np.float32)
    cb_mask = np.zeros((num_tokens,), dtype=np.float32)
    template_mask = np.zeros((num_tokens,), dtype=np.float32)
    query_to_template = np.zeros((num_tokens,), dtype=np.int64)
    visibility_ids = np.zeros((num_tokens,), dtype=np.float32)

    # Now create features per token
    asym_id_to_pdb_id = {}

    for token_dict in tmpl_tokens:
        idx = token_dict["q_idx"]
        pdb_id = token_dict["pdb_id"]
        token = token_dict["token"]
        query_token = query_tokens.tokens[idx]
        asym_id_to_pdb_id[query_token["asym_id"]] = pdb_id
        res_type[idx] = token["res_type"]
        frame_rot[idx] = token["frame_rot"].reshape(3, 3)
        frame_t[idx] = token["frame_t"]
        cb_coords[idx] = token["disto_coords"]
        ca_coords[idx] = token["center_coords"]
        cb_mask[idx] = token["disto_mask"]
        frame_mask[idx] = token["frame_mask"]
        template_mask[idx] = 1.0

    # Set visibility_id for templated chains
    for asym_id, pdb_id in asym_id_to_pdb_id.items():
        indices = (query_tokens.tokens["asym_id"] == asym_id).nonzero()
        visibility_ids[indices] = pdb_id

    # Set visibility for non-templated chains + oligomerics
    for asym_id in np.unique(query_tokens.structure.chains["asym_id"]):
        if asym_id not in asym_id_to_pdb_id:
            # We hack the chain id to be negative to not overlap with the above
            indices = (query_tokens.tokens["asym_id"] == asym_id).nonzero()
            visibility_ids[indices] = -1 - asym_id

    # Convert to one-hot
    res_type = torch.from_numpy(res_type)
    res_type = one_hot(res_type, num_classes=const.num_tokens)

    return {
        "template_restype": res_type,
        "template_frame_rot": torch.from_numpy(frame_rot),
        "template_frame_t": torch.from_numpy(frame_t),
        "template_cb": torch.from_numpy(cb_coords),
        "template_ca": torch.from_numpy(ca_coords),
        "template_mask_cb": torch.from_numpy(cb_mask),
        "template_mask_frame": torch.from_numpy(frame_mask),
        "template_mask": torch.from_numpy(template_mask),
        "query_to_template": torch.from_numpy(query_to_template),
        "visibility_ids": torch.from_numpy(visibility_ids),
    }


def process_template_features(
    data: Tokenized,
    max_tokens: int,
) -> dict[str, torch.Tensor]:
    """Load the given input data.

    Parameters
    ----------
    data : Tokenized
        The input to the model.
    max_tokens : int
        The maximum number of tokens.

    Returns
    -------
    dict[str, torch.Tensor]
        The loaded template features.

    """
    # Group templates by name
    name_to_templates: dict[str, list[TemplateInfo]] = {}
    for template_info in data.record.templates:
        name_to_templates.setdefault(template_info.name, []).append(template_info)

    # Map chain name to asym_id
    chain_name_to_asym_id = {}
    for chain in data.structure.chains:
        chain_name_to_asym_id[chain["name"]] = chain["asym_id"]

    # Compute the offset
    template_features = []
    for template_id, (template_name, templates) in enumerate(name_to_templates.items()):
        row_tokens = []
        template_structure = data.templates[template_name]
        template_tokens = data.template_tokens[template_name]
        tmpl_chain_name_to_asym_id = {}
        for chain in template_structure.chains:
            tmpl_chain_name_to_asym_id[chain["name"]] = chain["asym_id"]

        for template in templates:
            offset = template.template_st - template.query_st

            # Get query and template tokens to map residues
            query_tokens = data.tokens
            chain_id = chain_name_to_asym_id[template.query_chain]
            q_tokens = query_tokens[query_tokens["asym_id"] == chain_id]
            q_indices = dict(zip(q_tokens["res_idx"], q_tokens["token_idx"]))

            # Get the template tokens at the query residues
            chain_id = tmpl_chain_name_to_asym_id[template.template_chain]
            toks = template_tokens[template_tokens["asym_id"] == chain_id]
            toks = [t for t in toks if t["res_idx"] - offset in q_indices]
            for t in toks:
                q_idx = q_indices[t["res_idx"] - offset]
                row_tokens.append(
                    {
                        "token": t,
                        "pdb_id": template_id,
                        "q_idx": q_idx,
                    }
                )

        # Compute template features for each row
        row_features = compute_template_features(data, row_tokens, max_tokens)
        template_features.append(row_features)

    # Stack each feature
    out = {}
    for k in template_features[0]:
        out[k] = torch.stack([f[k] for f in template_features])

    return out

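
# Illustrative sketch (not part of the package): process_template_features aligns
# template residues to query residues through a constant index offset
# (template_st - query_st). With hypothetical plain lists of residue and token
# indices, the core mapping is just a dictionary lookup:
def _example_template_alignment_sketch(q_res_idxs, q_token_idxs, tmpl_res_idxs, offset):
    """Map template residue indices onto query token indices via an offset."""
    q_indices = dict(zip(q_res_idxs, q_token_idxs))
    return [
        (res_idx, q_indices[res_idx - offset])
        for res_idx in tmpl_res_idxs
        if res_idx - offset in q_indices
    ]
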
def process_symmetry_features(
    cropped: Tokenized, symmetries: dict
) -> dict[str, Tensor]:
    """Get the symmetry features.

    Parameters
    ----------
    cropped : Tokenized
        The input to the model.
    symmetries : dict
        The molecule symmetries.

    Returns
    -------
    dict[str, Tensor]
        The symmetry features.

    """
    features = get_chain_symmetries(cropped)
    features.update(get_amino_acids_symmetries(cropped))
    features.update(get_ligand_symmetries(cropped, symmetries))

    return features


def process_ensemble_features(
    data: Tokenized,
    random: np.random.Generator,
    num_ensembles: int,
    ensemble_sample_replacement: bool,
    fix_single_ensemble: bool,
) -> dict[str, Tensor]:
    """Get the ensemble features.

    Parameters
    ----------
    data : Tokenized
        The input to the model.
    random : np.random.Generator
        The random number generator.
    num_ensembles : int
        The maximum number of ensembles to sample.
    ensemble_sample_replacement : bool
        Whether to sample with replacement.
    fix_single_ensemble : bool
        Whether to always use the first conformer only.

    Returns
    -------
    dict[str, Tensor]
        The ensemble features.

    """
    assert num_ensembles > 0, "Number of conformers sampled must be greater than 0."

    # Number of available conformers in the structure
    # s_ensemble_num = min(len(cropped.structure.ensemble), 24)  # Limit to 24 conformers DEBUG: TODO: remove!
    s_ensemble_num = len(data.structure.ensemble)

    if fix_single_ensemble:
        # Always take the first conformer for train and validation
        assert (
            num_ensembles == 1
        ), "Number of conformers sampled must be 1 with fix_single_ensemble=True."
        ensemble_ref_idxs = np.array([0])
    else:
        if ensemble_sample_replacement:
            # Used in training
            ensemble_ref_idxs = random.integers(0, s_ensemble_num, (num_ensembles,))
        else:
            # Used in validation
            if s_ensemble_num < num_ensembles:
                # Take all available conformers
                ensemble_ref_idxs = np.arange(0, s_ensemble_num)
            else:
                # Sample without replacement
                ensemble_ref_idxs = random.choice(
                    s_ensemble_num, num_ensembles, replace=False
                )

    ensemble_features = {
        "ensemble_ref_idxs": torch.Tensor(ensemble_ref_idxs).long(),
    }

    return ensemble_features

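
# Illustrative sketch (not part of the package): the branches above amount to
# drawing conformer indices with replacement (training) or without replacement
# (validation), capped by the number the structure actually provides. Assuming a
# numpy Generator, the behaviour can be summarised as:
def _example_ensemble_sampling_sketch(
    rng: "np.random.Generator",
    num_available: int,
    num_ensembles: int,
    with_replacement: bool,
) -> "np.ndarray":
    """Return conformer indices following the sampling rules used above."""
    if with_replacement:
        return rng.integers(0, num_available, (num_ensembles,))
    if num_available < num_ensembles:
        return np.arange(0, num_available)
    return rng.choice(num_available, num_ensembles, replace=False)
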
def process_residue_constraint_features(data: Tokenized) -> dict[str, Tensor]:
    residue_constraints = data.residue_constraints
    if residue_constraints is not None:
        rdkit_bounds_constraints = residue_constraints.rdkit_bounds_constraints
        chiral_atom_constraints = residue_constraints.chiral_atom_constraints
        stereo_bond_constraints = residue_constraints.stereo_bond_constraints
        planar_bond_constraints = residue_constraints.planar_bond_constraints
        planar_ring_5_constraints = residue_constraints.planar_ring_5_constraints
        planar_ring_6_constraints = residue_constraints.planar_ring_6_constraints

        rdkit_bounds_index = torch.tensor(
            rdkit_bounds_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
        rdkit_bounds_bond_mask = torch.tensor(
            rdkit_bounds_constraints["is_bond"].copy(), dtype=torch.bool
        )
        rdkit_bounds_angle_mask = torch.tensor(
            rdkit_bounds_constraints["is_angle"].copy(), dtype=torch.bool
        )
        rdkit_upper_bounds = torch.tensor(
            rdkit_bounds_constraints["upper_bound"].copy(), dtype=torch.float
        )
        rdkit_lower_bounds = torch.tensor(
            rdkit_bounds_constraints["lower_bound"].copy(), dtype=torch.float
        )

        chiral_atom_index = torch.tensor(
            chiral_atom_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
        chiral_reference_mask = torch.tensor(
            chiral_atom_constraints["is_reference"].copy(), dtype=torch.bool
        )
        chiral_atom_orientations = torch.tensor(
            chiral_atom_constraints["is_r"].copy(), dtype=torch.bool
        )

        stereo_bond_index = torch.tensor(
            stereo_bond_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
        stereo_reference_mask = torch.tensor(
            stereo_bond_constraints["is_reference"].copy(), dtype=torch.bool
        )
        stereo_bond_orientations = torch.tensor(
            stereo_bond_constraints["is_e"].copy(), dtype=torch.bool
        )

        planar_bond_index = torch.tensor(
            planar_bond_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
        planar_ring_5_index = torch.tensor(
            planar_ring_5_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
        planar_ring_6_index = torch.tensor(
            planar_ring_6_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
    else:
        rdkit_bounds_index = torch.empty((2, 0), dtype=torch.long)
        rdkit_bounds_bond_mask = torch.empty((0,), dtype=torch.bool)
        rdkit_bounds_angle_mask = torch.empty((0,), dtype=torch.bool)
        rdkit_upper_bounds = torch.empty((0,), dtype=torch.float)
        rdkit_lower_bounds = torch.empty((0,), dtype=torch.float)
        chiral_atom_index = torch.empty((4, 0), dtype=torch.long)
        chiral_reference_mask = torch.empty((0,), dtype=torch.bool)
        chiral_atom_orientations = torch.empty((0,), dtype=torch.bool)
        stereo_bond_index = torch.empty((4, 0), dtype=torch.long)
        stereo_reference_mask = torch.empty((0,), dtype=torch.bool)
        stereo_bond_orientations = torch.empty((0,), dtype=torch.bool)
        planar_bond_index = torch.empty((6, 0), dtype=torch.long)
        planar_ring_5_index = torch.empty((5, 0), dtype=torch.long)
        planar_ring_6_index = torch.empty((6, 0), dtype=torch.long)

    return {
        "rdkit_bounds_index": rdkit_bounds_index,
        "rdkit_bounds_bond_mask": rdkit_bounds_bond_mask,
        "rdkit_bounds_angle_mask": rdkit_bounds_angle_mask,
        "rdkit_upper_bounds": rdkit_upper_bounds,
        "rdkit_lower_bounds": rdkit_lower_bounds,
        "chiral_atom_index": chiral_atom_index,
        "chiral_reference_mask": chiral_reference_mask,
        "chiral_atom_orientations": chiral_atom_orientations,
        "stereo_bond_index": stereo_bond_index,
        "stereo_reference_mask": stereo_reference_mask,
        "stereo_bond_orientations": stereo_bond_orientations,
        "planar_bond_index": planar_bond_index,
        "planar_ring_5_index": planar_ring_5_index,
        "planar_ring_6_index": planar_ring_6_index,
    }


def process_chain_feature_constraints(data: Tokenized) -> dict[str, Tensor]:
    structure = data.structure
    if structure.bonds.shape[0] > 0:
        connected_chain_index, connected_atom_index = [], []
        for connection in structure.bonds:
            if connection["chain_1"] == connection["chain_2"]:
                continue
            connected_chain_index.append([connection["chain_1"], connection["chain_2"]])
            connected_atom_index.append([connection["atom_1"], connection["atom_2"]])
        if len(connected_chain_index) > 0:
            connected_chain_index = torch.tensor(
                connected_chain_index, dtype=torch.long
            ).T
            connected_atom_index = torch.tensor(
                connected_atom_index, dtype=torch.long
            ).T
        else:
            connected_chain_index = torch.empty((2, 0), dtype=torch.long)
            connected_atom_index = torch.empty((2, 0), dtype=torch.long)
    else:
        connected_chain_index = torch.empty((2, 0), dtype=torch.long)
        connected_atom_index = torch.empty((2, 0), dtype=torch.long)

    symmetric_chain_index = []
    for i, chain_i in enumerate(structure.chains):
        for j, chain_j in enumerate(structure.chains):
            if j <= i:
                continue
            if chain_i["entity_id"] == chain_j["entity_id"]:
                symmetric_chain_index.append([i, j])
    if len(symmetric_chain_index) > 0:
        symmetric_chain_index = torch.tensor(symmetric_chain_index, dtype=torch.long).T
    else:
        symmetric_chain_index = torch.empty((2, 0), dtype=torch.long)
    return {
        "connected_chain_index": connected_chain_index,
        "connected_atom_index": connected_atom_index,
        "symmetric_chain_index": symmetric_chain_index,
    }

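
# Illustrative sketch (not part of the package): symmetric_chain_index above pairs
# every two chains that share an entity_id. With a hypothetical plain list of entity
# ids per chain, the same pairing is:
def _example_symmetric_chain_pairs_sketch(entity_ids: list[int]) -> "torch.Tensor":
    """Return a (2, N_PAIRS) tensor of chain index pairs with equal entity ids."""
    pairs = [
        [i, j]
        for i in range(len(entity_ids))
        for j in range(i + 1, len(entity_ids))
        if entity_ids[i] == entity_ids[j]
    ]
    if not pairs:
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(pairs, dtype=torch.long).T
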
class Boltz2Featurizer:
    """Boltz2 featurizer."""

    def process(
        self,
        data: Tokenized,
        random: np.random.Generator,
        molecules: dict[str, Mol],
        training: bool,
        max_seqs: int,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        num_ensembles: int = 1,
        ensemble_sample_replacement: bool = False,
        disto_use_ensemble: Optional[bool] = False,
        fix_single_ensemble: Optional[bool] = True,
        max_tokens: Optional[int] = None,
        max_atoms: Optional[int] = None,
        pad_to_max_seqs: bool = False,
        compute_symmetries: bool = False,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        contact_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff_min: Optional[float] = 4.0,
        binder_pocket_cutoff_max: Optional[float] = 20.0,
        binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
        only_ligand_binder_pocket: Optional[bool] = False,
        only_pp_contact: Optional[bool] = False,
        pocket_constraints: Optional[list] = None,
        single_sequence_prop: Optional[float] = 0.0,
        msa_sampling: bool = False,
        override_bfactor: float = False,
        override_method: Optional[str] = None,
        compute_frames: bool = False,
        override_coords: Optional[Tensor] = None,
        bfactor_md_correction: bool = False,
        compute_constraint_features: bool = False,
        inference_pocket_constraints: Optional[list] = None,
        compute_affinity: bool = False,
    ) -> dict[str, Tensor]:
        """Compute features.

        Parameters
        ----------
        data : Tokenized
            The input to the model.
        training : bool
            Whether the model is in training mode.
        max_tokens : int, optional
            The maximum number of tokens.
        max_atoms : int, optional
            The maximum number of atoms.
        max_seqs : int, optional
            The maximum number of sequences.

        Returns
        -------
        dict[str, Tensor]
            The features for model training.

        """
        # Compute random number of sequences
        if training and max_seqs is not None:
            if random.random() > single_sequence_prop:
                max_seqs_batch = random.integers(1, max_seqs + 1)
            else:
                max_seqs_batch = 1
        else:
            max_seqs_batch = max_seqs

        # Compute ensemble features
        ensemble_features = process_ensemble_features(
            data=data,
            random=random,
            num_ensembles=num_ensembles,
            ensemble_sample_replacement=ensemble_sample_replacement,
            fix_single_ensemble=fix_single_ensemble,
        )

        # Compute token features
        token_features = process_token_features(
            data=data,
            random=random,
            max_tokens=max_tokens,
            binder_pocket_conditioned_prop=binder_pocket_conditioned_prop,
            contact_conditioned_prop=contact_conditioned_prop,
            binder_pocket_cutoff_min=binder_pocket_cutoff_min,
            binder_pocket_cutoff_max=binder_pocket_cutoff_max,
            binder_pocket_sampling_geometric_p=binder_pocket_sampling_geometric_p,
            only_ligand_binder_pocket=only_ligand_binder_pocket,
            only_pp_contact=only_pp_contact,
            override_method=override_method,
            inference_pocket_constraints=inference_pocket_constraints,
        )

        # Compute atom features
        atom_features = process_atom_features(
            data=data,
            random=random,
            molecules=molecules,
            ensemble_features=ensemble_features,
            atoms_per_window_queries=atoms_per_window_queries,
            min_dist=min_dist,
            max_dist=max_dist,
            num_bins=num_bins,
            max_atoms=max_atoms,
            max_tokens=max_tokens,
            disto_use_ensemble=disto_use_ensemble,
            override_bfactor=override_bfactor,
            compute_frames=compute_frames,
            override_coords=override_coords,
            bfactor_md_correction=bfactor_md_correction,
        )

        # Compute MSA features
        msa_features = process_msa_features(
            data=data,
            random=random,
            max_seqs_batch=max_seqs_batch,
            max_seqs=max_seqs,
            max_tokens=max_tokens,
            pad_to_max_seqs=pad_to_max_seqs,
            msa_sampling=training and msa_sampling,
        )

        # Compute affinity MSA features
        msa_features_affinity = {}
        if compute_affinity:
            msa_features_affinity = process_msa_features(
                data=data,
                random=random,
                max_seqs_batch=1,
                max_seqs=1,
                max_tokens=max_tokens,
                pad_to_max_seqs=pad_to_max_seqs,
                msa_sampling=training and msa_sampling,
                affinity=True,
            )

        # Compute the affinity ligand molecular weight
        ligand_to_mw = {}
        if compute_affinity:
            ligand_to_mw["affinity_mw"] = data.record.affinity.mw

        # Compute template features
        num_tokens = data.tokens.shape[0] if max_tokens is None else max_tokens
        if data.templates:
            template_features = process_template_features(
                data=data,
                max_tokens=num_tokens,
            )
        else:
            template_features = load_dummy_templates_features(
                tdim=1,
                num_tokens=num_tokens,
            )

        # Compute symmetry features
        symmetry_features = {}
        if compute_symmetries:
            symmetries = get_symmetries(molecules)
            symmetry_features = process_symmetry_features(data, symmetries)

        # Compute residue constraint features
        residue_constraint_features = {}
        if compute_constraint_features:
            residue_constraint_features = process_residue_constraint_features(data)
            chain_constraint_features = process_chain_feature_constraints(data)

        return {
            **token_features,
            **atom_features,
            **msa_features,
            **msa_features_affinity,
            **template_features,
            **symmetry_features,
            **ensemble_features,
            **residue_constraint_features,
            **chain_constraint_features,
            **ligand_to_mw,
        }
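
# Illustrative usage sketch (not part of the package): given an already tokenized
# input and its molecule dictionary (both produced elsewhere in the pipeline), an
# inference-style featurization call would look roughly like the commented example
# below. The `tokenized` and `molecules` objects and the example parameter values
# are assumptions, not constructed or prescribed by this snippet.
#
#   featurizer = Boltz2Featurizer()
#   features = featurizer.process(
#       data=tokenized,
#       random=np.random.default_rng(42),
#       molecules=molecules,
#       training=False,
#       max_seqs=4096,
#       compute_frames=True,
#       compute_constraint_features=True,
#   )
#   # `features` maps names such as "ref_pos", "msa", and "disto_target" to tensors.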