boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,602 @@
|
|
1
|
+
import itertools
|
2
|
+
import pickle
|
3
|
+
import random
|
4
|
+
from pathlib import Path
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
|
9
|
+
from boltz.data import const
|
10
|
+
from boltz.data.pad import pad_dim
|
11
|
+
from boltz.model.loss.confidence import lddt_dist
|
12
|
+
from boltz.model.loss.validation import weighted_minimum_rmsd_single
|
13
|
+
|
14
|
+
|
15
|
+
def convert_atom_name(name: str) -> tuple[int, int, int, int]:
    """Encode an atom name as a tuple of small integer character codes.

    Surrounding whitespace is stripped, every remaining character is mapped
    to ``ord(char) - 32`` and the result is right-padded with zeros up to a
    length of four. Names longer than four characters are not truncated.

    Parameters
    ----------
    name : str
        The atom name.

    Returns
    -------
    Tuple[int, int, int, int]
        The encoded atom name.

    """
    codes = [ord(char) - 32 for char in name.strip()]
    padding = max(0, 4 - len(codes))
    return tuple(codes + [0] * padding)
|
33
|
+
|
34
|
+
|
35
|
+
def get_symmetries(path: str) -> dict:
    """Load ligand symmetries from a pickled dictionary of molecules.

    The pickle at ``path`` must hold a dict whose values expose ``GetProp``
    and ``GetAtoms``. Each molecule needs a hex-encoded pickled
    ``"symmetries"`` property, and each atom needs a ``"name"`` property.
    Any entry for which reading these fails is skipped silently.

    Security note: ``pickle`` can execute arbitrary code while loading, so
    only trusted files should be passed here.

    Parameters
    ----------
    path : str
        The path to the ligand symmetries.

    Returns
    -------
    dict
        Maps each key of the pickled dict to ``(symmetries, atom_names)``.
        ``atom_names`` is the list of :func:`convert_atom_name` encodings
        in atom order.

    """
    with Path(path).open("rb") as handle:
        molecules: dict = pickle.load(handle)  # noqa: S301

    result = {}
    for mol_key, molecule in molecules.items():
        try:
            raw_symmetries = bytes.fromhex(molecule.GetProp("symmetries"))
            mol_symmetries = pickle.loads(raw_symmetries)  # noqa: S301
            encoded_names = [
                convert_atom_name(atom.GetProp("name"))
                for atom in molecule.GetAtoms()
            ]
            result[mol_key] = (mol_symmetries, encoded_names)
        except Exception:  # noqa: BLE001, PERF203, S110
            # Best-effort: entries without usable symmetry data are skipped.
            continue
    return result
|
68
|
+
|
69
|
+
|
70
|
+
def compute_symmetry_idx_dictionary(data):
    """Flatten atom coordinates and record start offsets on chains and tokens.

    Side effects: ``chain.start_idx`` is set to the offset of the chain's
    first atom in the flattened list. ``token.start_idx`` is set to the
    token's offset relative to its chain.

    Parameters
    ----------
    data
        Object with ``chains``. Each chain has ``tokens``, each token has
        ``atoms``, and each atom's ``coords`` exposes ``x``, ``y`` and ``z``.

    Returns
    -------
    list of [x, y, z]
        Coordinates of every atom, in chain then token order.

    """
    coordinates = []
    for chain in data.chains:
        # One entry is appended per atom, so the list length is the running
        # atom count.
        chain.start_idx = len(coordinates)
        for token in chain.tokens:
            token.start_idx = len(coordinates) - chain.start_idx
            coordinates.extend(
                [atom.coords.x, atom.coords.y, atom.coords.z]
                for atom in token.atoms
            )
    return coordinates
|
83
|
+
|
84
|
+
|
85
|
+
def get_current_idx_list(data):
    """Return flattened atom indices of the tokens that lie inside the crop.

    This relies on ``start_idx`` offsets previously stored on chains and
    tokens (for example by ``compute_symmetry_idx_dictionary``). A token
    contributes only when both it and its chain have ``in_crop`` set.

    Parameters
    ----------
    data
        Object with ``chains``. Each chain and token exposes ``in_crop`` and
        ``start_idx``, and each token exposes ``atoms``.

    Returns
    -------
    list of int
        Global atom indices of all in-crop atoms, in chain/token order.

    """
    indices = []
    for chain in data.chains:
        if not chain.in_crop:
            continue
        for token in chain.tokens:
            if not token.in_crop:
                continue
            first = chain.start_idx + token.start_idx
            indices.extend(range(first, first + len(token.atoms)))
    return indices
|
98
|
+
|
99
|
+
|
100
|
+
def all_different_after_swap(l):
    """Check that the swaps in ``l`` all use distinct last elements.

    Each element of ``l`` is a sequence. For the chain swaps built in
    ``get_chain_symmetries``, its last item is the index of the chain whose
    coordinates are copied in. A combination is valid only if no two swaps
    share that source.

    Parameters
    ----------
    l
        Iterable of indexable swap descriptions.

    Returns
    -------
    bool
        True if every last element is unique.

    """
    destinations = [swap[-1] for swap in l]
    unique_destinations = set(destinations)
    return len(unique_destinations) == len(destinations)
|
103
|
+
|
104
|
+
|
105
|
+
def minimum_symmetry_coords(
    coords: torch.Tensor,
    feats: dict,
    index_batch: int,
    **args_rmsd,
):
    """Find the symmetry-equivalent ground truth that minimises the RMSD.

    Three greedy stages are applied to the reference of batch element
    ``index_batch``:

    1. Every chain permutation in ``feats["chain_symmetries"]`` is tried,
       and the one with the lowest aligned weighted RMSD is kept.
    2. Residue atom swaps from ``feats["amino_acids_symmetries"]`` are
       accepted when they lower the weighted squared error against the
       current alignment. The alignment is not recomputed for each swap;
       it is recomputed once after this stage.
    3. Ligand atom permutations from ``feats["ligand_symmetries"]`` are
       accepted when they lower the RMSD after re-alignment.

    Parameters
    ----------
    coords : torch.Tensor
        Predicted coordinates. Presumably shaped (1, n_atoms, 3), since it is
        compared against ``all_coords[None]`` — TODO confirm.
    feats : dict
        Batched features. Reads ``all_coords``, ``all_resolved_mask``,
        ``crop_to_all_atom_map``, ``chain_symmetries``,
        ``amino_acids_symmetries``, ``ligand_symmetries``, ``atom_to_token``
        and ``mol_type``.
    index_batch : int
        Index of the batch element to process.
    **args_rmsd
        Extra keyword arguments forwarded to ``weighted_minimum_rmsd_single``.

    Returns
    -------
    tuple
        ``(best_true_coords, best_rmsd, best_true_resolved_mask[None])``.
        The returned coordinates are always the aligned coordinates matching
        ``best_rmsd``.

    Raises
    ------
    RuntimeError
        If the RMSD computation failed for every chain permutation.

    """
    all_coords = feats["all_coords"][index_batch].unsqueeze(0).to(coords)
    all_resolved_mask = (
        feats["all_resolved_mask"][index_batch].to(coords).to(torch.bool)
    )
    crop_to_all_atom_map = (
        feats["crop_to_all_atom_map"][index_batch].to(coords).to(torch.long)
    )
    chain_symmetries = feats["chain_symmetries"][index_batch]
    amino_acids_symmetries = feats["amino_acids_symmetries"][index_batch]
    ligand_symmetries = feats["ligand_symmetries"][index_batch]

    # Per-batch-element features shared by every RMSD call below.
    atom_to_token = feats["atom_to_token"][index_batch : index_batch + 1]
    mol_type = feats["mol_type"][index_batch : index_batch + 1]

    # Stage 1: choose the chain permutation with the lowest aligned RMSD.
    best_true_coords = None
    best_true_resolved_mask = None
    best_align_weights = None
    best_rmsd = float("inf")
    for c in chain_symmetries:
        true_all_coords = all_coords.clone()
        true_all_resolved_mask = all_resolved_mask.clone()
        for start1, end1, start2, end2, _chainidx1, _chainidx2 in c:
            true_all_coords[:, start1:end1] = all_coords[:, start2:end2]
            true_all_resolved_mask[start1:end1] = all_resolved_mask[start2:end2]
        true_coords = true_all_coords[:, crop_to_all_atom_map]
        true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map]
        # Pad the cropped reference up to the (possibly padded) prediction length.
        true_coords = pad_dim(true_coords, 1, coords.shape[1] - true_coords.shape[1])
        true_resolved_mask = pad_dim(
            true_resolved_mask,
            0,
            coords.shape[1] - true_resolved_mask.shape[0],
        )
        try:
            rmsd, aligned_coords, align_weights = weighted_minimum_rmsd_single(
                coords,
                true_coords,
                atom_mask=true_resolved_mask,
                atom_to_token=atom_to_token,
                mol_type=mol_type,
                **args_rmsd,
            )
        except Exception:  # noqa: BLE001
            # Best-effort: skip permutations whose alignment fails, but let
            # KeyboardInterrupt/SystemExit propagate (unlike a bare except).
            print("Warning: error in rmsd computation inside symmetry code")
            continue
        rmsd = rmsd.item()

        if rmsd < best_rmsd:
            best_rmsd = rmsd
            best_true_coords = aligned_coords
            best_align_weights = align_weights
            best_true_resolved_mask = true_resolved_mask

    if best_true_coords is None:
        msg = "RMSD computation failed for every chain symmetry candidate"
        raise RuntimeError(msg)

    # Stage 2: residue atom swaps (nucleic acids and protein residues),
    # resolved greedily without recomputing the alignment.
    true_coords = best_true_coords.clone()
    true_resolved_mask = best_true_resolved_mask.clone()
    for symmetric_amino in amino_acids_symmetries:
        for c in symmetric_amino:
            # Starting from the greedy best, try to swap the atoms.
            new_true_coords = true_coords.clone()
            new_true_resolved_mask = true_resolved_mask.clone()
            for i, j in c:
                new_true_coords[:, i] = true_coords[:, j]
                new_true_resolved_mask[i] = true_resolved_mask[j]

            # Weighted mean squared distance under the current alignment.
            best_mse_loss = torch.sum(
                ((coords - best_true_coords) ** 2).sum(dim=-1)
                * best_align_weights
                * best_true_resolved_mask,
                dim=-1,
            ) / torch.sum(best_align_weights * best_true_resolved_mask, dim=-1)
            new_mse_loss = torch.sum(
                ((coords - new_true_coords) ** 2).sum(dim=-1)
                * best_align_weights
                * new_true_resolved_mask,
                dim=-1,
            ) / torch.sum(best_align_weights * new_true_resolved_mask, dim=-1)

            if best_mse_loss > new_mse_loss:
                best_true_coords = new_true_coords
                best_true_resolved_mask = new_true_resolved_mask

        # Greedily update the starting point after each residue.
        true_coords = best_true_coords.clone()
        true_resolved_mask = best_true_resolved_mask.clone()

    # Recompute the alignment once for the residue-swapped reference.
    rmsd, true_coords, best_align_weights = weighted_minimum_rmsd_single(
        coords,
        true_coords,
        atom_mask=true_resolved_mask,
        atom_to_token=atom_to_token,
        mol_type=mol_type,
        **args_rmsd,
    )
    best_rmsd = rmsd.item()
    # Keep the returned coordinates consistent with best_rmsd. Otherwise,
    # when no ligand swap improves the result, the pre-realignment
    # coordinates would be returned together with the post-realignment RMSD.
    best_true_coords = true_coords

    # Stage 3: ligand and non-standard residue permutations, resolved
    # greedily with re-alignment for each candidate.
    for symmetric_ligand in ligand_symmetries:
        for c in symmetric_ligand:
            new_true_coords = true_coords.clone()
            new_true_resolved_mask = true_resolved_mask.clone()
            for i, j in c:
                new_true_coords[:, j] = true_coords[:, i]
                new_true_resolved_mask[j] = true_resolved_mask[i]
            # TODO: if this is too slow, consider not recomputing the alignment.
            rmsd, aligned_coords, _ = weighted_minimum_rmsd_single(
                coords,
                new_true_coords,
                atom_mask=new_true_resolved_mask,
                atom_to_token=atom_to_token,
                mol_type=mol_type,
                **args_rmsd,
            )
            rmsd = rmsd.item()
            if rmsd < best_rmsd:
                best_true_coords = aligned_coords
                best_rmsd = rmsd
                best_true_resolved_mask = new_true_resolved_mask

        true_coords = best_true_coords.clone()
        true_resolved_mask = best_true_resolved_mask.clone()

    return best_true_coords, best_rmsd, best_true_resolved_mask.unsqueeze(0)
|
237
|
+
|
238
|
+
|
239
|
+
def minimum_lddt_symmetry_coords(
    coords: torch.Tensor,
    feats: dict,
    index_batch: int,
    **args_rmsd,
):
    """Find the symmetry-equivalent ground truth that maximises the lDDT.

    The chain permutation with the highest lDDT (cutoff 15.0) against the
    prediction is chosen first. Residue and ligand atom swaps are then
    accepted greedily when they raise the lDDT restricted to the swapped
    atoms. Finally, the chosen reference is padded and aligned to the
    prediction to report an RMSD.

    Parameters
    ----------
    coords : torch.Tensor
        Predicted coordinates. Presumably shaped (1, n_atoms, 3) — TODO confirm.
    feats : dict
        Batched features. Reads ``all_coords``, ``all_resolved_mask``,
        ``crop_to_all_atom_map``, ``chain_symmetries``,
        ``amino_acids_symmetries``, ``ligand_symmetries``, ``atom_to_token``
        and ``mol_type``.
    index_batch : int
        Index of the batch element to process.
    **args_rmsd
        Extra keyword arguments forwarded to ``weighted_minimum_rmsd_single``.

    Returns
    -------
    tuple
        ``(true_coords, best_rmsd, true_resolved_mask[None])``. If the final
        RMSD computation fails, ``best_rmsd`` is the sentinel 1000 and
        ``true_coords`` is the padded but unaligned reference.

    Raises
    ------
    RuntimeError
        If no chain permutation produced a comparable lDDT (e.g. NaN for all).

    """
    all_coords = feats["all_coords"][index_batch].unsqueeze(0).to(coords)
    all_resolved_mask = (
        feats["all_resolved_mask"][index_batch].to(coords).to(torch.bool)
    )
    crop_to_all_atom_map = (
        feats["crop_to_all_atom_map"][index_batch].to(coords).to(torch.long)
    )
    chain_symmetries = feats["chain_symmetries"][index_batch]
    amino_acids_symmetries = feats["amino_acids_symmetries"][index_batch]
    ligand_symmetries = feats["ligand_symmetries"][index_batch]

    # Predicted coordinates restricted to the cropped (unpadded) atoms.
    n_crop_atoms = len(crop_to_all_atom_map)
    pred_crop_coords = coords[:, :n_crop_atoms]
    dmat_predicted = torch.cdist(pred_crop_coords, pred_crop_coords)

    # Stage 1: choose the chain permutation with the highest lDDT. Starting
    # from -inf guarantees a candidate is kept even if every lDDT is 0.
    best_true_coords = None
    best_true_resolved_mask = None
    best_lddt = float("-inf")
    for c in chain_symmetries:
        true_all_coords = all_coords.clone()
        true_all_resolved_mask = all_resolved_mask.clone()
        for start1, end1, start2, end2, _chainidx1, _chainidx2 in c:
            true_all_coords[:, start1:end1] = all_coords[:, start2:end2]
            true_all_resolved_mask[start1:end1] = all_resolved_mask[start2:end2]
        true_coords = true_all_coords[:, crop_to_all_atom_map]
        true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map]
        dmat_true = torch.cdist(true_coords, true_coords)
        # Pairs where both atoms are resolved, excluding self-pairs.
        pair_mask = (
            true_resolved_mask[:, None]
            * true_resolved_mask[None, :]
            * (1 - torch.eye(len(true_resolved_mask))).to(true_resolved_mask)
        )

        lddt = lddt_dist(
            dmat_predicted, dmat_true, pair_mask, cutoff=15.0, per_atom=False
        )[0]
        lddt = lddt.item()

        if lddt > best_lddt:
            best_lddt = lddt
            best_true_coords = true_coords
            best_true_resolved_mask = true_resolved_mask

    if best_true_coords is None:
        msg = "lDDT computation produced no valid chain symmetry candidate"
        raise RuntimeError(msg)

    # Stage 2: residue and ligand atom swaps, resolved greedily on the lDDT
    # restricted to the swapped atoms (no alignment needed for lDDT).
    true_coords = best_true_coords.clone()
    true_resolved_mask = best_true_resolved_mask.clone()
    for symmetric_amino_or_lig in amino_acids_symmetries + ligand_symmetries:
        for c in symmetric_amino_or_lig:
            # Starting from the greedy best, try to swap the atoms.
            new_true_coords = true_coords.clone()
            new_true_resolved_mask = true_resolved_mask.clone()
            indices = []
            for i, j in c:
                new_true_coords[:, i] = true_coords[:, j]
                new_true_resolved_mask[i] = true_resolved_mask[j]
                indices.append(i)

            indices = (
                torch.from_numpy(np.asarray(indices)).to(new_true_coords.device).long()
            )

            pred_coords_subset = pred_crop_coords[:, indices]
            true_coords_subset = true_coords[:, indices]
            new_true_coords_subset = new_true_coords[:, indices]

            # Distances between every cropped atom and the swapped subset.
            sub_dmat_pred = torch.cdist(pred_crop_coords, pred_coords_subset)
            sub_dmat_true = torch.cdist(true_coords, true_coords_subset)
            sub_dmat_new_true = torch.cdist(new_true_coords, new_true_coords_subset)

            # Resolved-pair masks; the (indices, indices) diagonal is cleared
            # so that an atom is never paired with itself.
            sub_true_pair_lddt = (
                true_resolved_mask[:, None] * true_resolved_mask[None, indices]
            )
            off_diagonal = (1 - torch.eye(len(indices))).to(sub_true_pair_lddt).bool()
            sub_true_pair_lddt[indices] = sub_true_pair_lddt[indices] * off_diagonal

            sub_new_true_pair_lddt = (
                new_true_resolved_mask[:, None] * new_true_resolved_mask[None, indices]
            )
            sub_new_true_pair_lddt[indices] = (
                sub_new_true_pair_lddt[indices] * off_diagonal
            )

            lddt = lddt_dist(
                sub_dmat_pred,
                sub_dmat_true,
                sub_true_pair_lddt,
                cutoff=15.0,
                per_atom=False,
            )[0]
            new_lddt = lddt_dist(
                sub_dmat_pred,
                sub_dmat_new_true,
                sub_new_true_pair_lddt,
                cutoff=15.0,
                per_atom=False,
            )[0]

            if new_lddt > lddt:
                best_true_coords = new_true_coords
                best_true_resolved_mask = new_true_resolved_mask

        # Greedily update the starting point after each residue/ligand.
        true_coords = best_true_coords.clone()
        true_resolved_mask = best_true_resolved_mask.clone()

    # Pad to the prediction length and align to report an RMSD.
    true_coords = pad_dim(true_coords, 1, coords.shape[1] - true_coords.shape[1])
    true_resolved_mask = pad_dim(
        true_resolved_mask,
        0,
        coords.shape[1] - true_resolved_mask.shape[0],
    )

    try:
        rmsd, true_coords, _ = weighted_minimum_rmsd_single(
            coords,
            true_coords,
            atom_mask=true_resolved_mask,
            atom_to_token=feats["atom_to_token"][index_batch : index_batch + 1],
            mol_type=feats["mol_type"][index_batch : index_batch + 1],
            **args_rmsd,
        )
        best_rmsd = rmsd.item()
    except Exception as e:  # noqa: BLE001
        print("Failed proper RMSD computation, returning inf. Error: ", e)
        # Sentinel value used in place of an actual infinity.
        best_rmsd = 1000

    return true_coords, best_rmsd, true_resolved_mask.unsqueeze(0)
|
378
|
+
|
379
|
+
|
380
|
+
def compute_all_coords_mask(structure):
    """Flatten coordinates and per-atom crop/resolved flags, recording offsets.

    Side effects: ``chain.start_idx`` is set to the offset of the chain's
    first atom in the flattened lists. ``token.start_idx`` is set to the
    token's offset relative to its chain.

    Parameters
    ----------
    structure
        Object with ``chains``. Each chain has ``tokens``, and each token has
        ``atoms`` (whose ``coords`` expose ``x``/``y``/``z``), ``in_crop``
        and ``is_present``.

    Returns
    -------
    tuple of lists
        ``(all_coords, all_coords_crop_mask, all_resolved_mask)``, with one
        entry per atom. Each atom inherits its token's ``in_crop`` and
        ``is_present`` values.

    """
    total_count = 0
    all_coords = []
    all_coords_crop_mask = []
    all_resolved_mask = []
    for chain in structure.chains:
        chain.start_idx = total_count
        for token in chain.tokens:
            token.start_idx = total_count - chain.start_idx
            n_atoms = len(token.atoms)
            all_coords.extend(
                [atom.coords.x, atom.coords.y, atom.coords.z] for atom in token.atoms
            )
            all_coords_crop_mask.extend([token.in_crop] * n_atoms)
            all_resolved_mask.extend([token.is_present] * n_atoms)
            total_count += n_atoms
    return all_coords, all_coords_crop_mask, all_resolved_mask
|
403
|
+
|
404
|
+
|
405
|
+
def get_chain_symmetries(cropped, max_n_symmetries=100):
    """Compute candidate chain permutations and the full-structure features.

    Two chains are interchangeable when they share ``entity_id`` and atom
    count. For every chain touched by the crop, all interchangeable partners
    are listed, and the Cartesian product of these options gives the
    candidate permutations. Only the first ``max_n_symmetries * 10``
    products are considered. Those whose partners are not all distinct are
    dropped, and at most ``max_n_symmetries`` of the rest are kept (a random
    sample when there are more). An empty permutation is used when none
    remain.

    Parameters
    ----------
    cropped
        Cropped input exposing ``structure`` (with ``chains`` and ``atoms``
        tables) and ``tokens`` (with ``asym_id``, ``atom_idx`` and
        ``atom_num``).
    max_n_symmetries : int, optional
        Upper bound on the number of returned permutations.

    Returns
    -------
    dict
        ``all_coords`` and ``all_resolved_mask`` hold the concatenated per-chain
        tensors. ``crop_to_all_atom_map`` holds, for each cropped atom, its
        index into ``all_coords``, stored as a float tensor. ``chain_symmetries``
        holds the list of permutations, each a sequence of
        ``(start1, end1, start2, end2, i, j)`` tuples.

    """
    structure = cropped.structure
    all_coords = []
    all_resolved_mask = []
    original_atom_idx = []
    chain_atom_idx = []
    chain_atom_num = []
    chain_in_crop = []
    chain_asym_id = []
    new_atom_idx = 0

    # asym_ids with at least one token in the crop, computed once instead of
    # rescanning every token for each chain.
    cropped_asym_ids = {token["asym_id"] for token in cropped.tokens}

    for chain in structure.chains:
        atom_idx, atom_num = chain["atom_idx"], chain["atom_num"]

        # Coordinates and resolved mask of this chain's atoms.
        resolved_mask = structure.atoms["is_present"][atom_idx : atom_idx + atom_num]
        coords = structure.atoms["coords"][atom_idx : atom_idx + atom_num]

        all_coords.append(coords)
        all_resolved_mask.append(resolved_mask)
        original_atom_idx.append(atom_idx)
        chain_atom_idx.append(new_atom_idx)
        chain_atom_num.append(atom_num)
        chain_in_crop.append(chain["asym_id"] in cropped_asym_ids)
        chain_asym_id.append(chain["asym_id"])

        new_atom_idx += atom_num

    # First chain position of each asym_id (same result as list.index).
    asym_to_chain_idx = {}
    for position, asym_id in enumerate(chain_asym_id):
        asym_to_chain_idx.setdefault(asym_id, position)

    # Map every cropped atom back to its index in the concatenated coords.
    crop_to_all_atom_map = []
    for token in cropped.tokens:
        chain_idx = asym_to_chain_idx[token["asym_id"]]
        start = (
            chain_atom_idx[chain_idx] - original_atom_idx[chain_idx] + token["atom_idx"]
        )
        crop_to_all_atom_map.append(np.arange(start, start + token["atom_num"]))

    # Interchangeable partners for each chain present in the crop.
    swaps = []
    for i, chain in enumerate(structure.chains):
        if not chain_in_crop[i]:
            continue
        start = chain_atom_idx[i]
        end = start + chain_atom_num[i]
        possible_swaps = []
        for j, chain2 in enumerate(structure.chains):
            start2 = chain_atom_idx[j]
            end2 = start2 + chain_atom_num[j]
            if (
                chain["entity_id"] == chain2["entity_id"]
                and end - start == end2 - start2
            ):
                possible_swaps.append((start, end, start2, end2, i, j))
        swaps.append(possible_swaps)

    combinations = itertools.product(*swaps)
    # Bound the number of combinations considered to avoid a combinatorial
    # explosion.
    combinations = list(itertools.islice(combinations, max_n_symmetries * 10))
    # Keep only combinations where every chain receives a different partner.
    combinations = [c for c in combinations if all_different_after_swap(c)]

    if len(combinations) > max_n_symmetries:
        combinations = random.sample(combinations, max_n_symmetries)

    if len(combinations) == 0:
        combinations.append([])

    features = {}
    # NOTE(review): per the original author, concatenation would use axis=1
    # if ensemble coordinates were included.
    features["all_coords"] = torch.Tensor(np.concatenate(all_coords, axis=0))
    features["all_resolved_mask"] = torch.Tensor(
        np.concatenate(all_resolved_mask, axis=0)
    )
    features["crop_to_all_atom_map"] = torch.Tensor(
        np.concatenate(crop_to_all_atom_map, axis=0)
    )
    features["chain_symmetries"] = combinations

    return features
|
500
|
+
|
501
|
+
|
502
|
+
def get_amino_acids_symmetries(cropped):
    """Collect atom swaps for standard residues in the crop.

    A token qualifies when its residue name (``const.tokens[res_type]``) has
    an entry in ``const.ref_symmetries``. For each such token, every
    reference swap list of ``(i, j)`` atom pairs is shifted by the token's
    atom offset within the crop.

    Parameters
    ----------
    cropped
        Cropped input exposing ``tokens`` with ``res_type`` and ``atom_num``
        fields.

    Returns
    -------
    dict
        ``{"amino_acids_symmetries": swaps}``. ``swaps`` holds one list of
        shifted swap lists per residue that has symmetries.

    """
    residue_swap_lists = []
    offset = 0
    for token in cropped.tokens:
        residue_name = const.tokens[token["res_type"]]
        reference_swaps = const.ref_symmetries.get(residue_name, [])
        if len(reference_swaps) > 0:
            shifted = [
                [(i + offset, j + offset) for i, j in swap]
                for swap in reference_swaps
            ]
            residue_swap_lists.append(shifted)
        offset += token["atom_num"]

    return {"amino_acids_symmetries": residue_swap_lists}
|
520
|
+
|
521
|
+
|
522
|
+
def get_ligand_symmetries(cropped, symmetries):
    """Compute atom permutations for ligands and non-standard residues.

    Residues whose name is not in ``const.ref_symmetries`` are looked up in
    ``symmetries``, which has the structure produced by ``get_symmetries``:
    name -> ``(permutations, atom_names)``. Each permutation is restricted
    to the atoms actually present in the residue and re-indexed to crop-level
    atom positions. Only atoms that move are kept.

    Parameters
    ----------
    cropped
        Cropped input exposing ``structure`` (``chains``, ``residues``,
        ``atoms`` tables) and ``tokens`` (``asym_id``, ``res_idx``,
        ``atom_num`` fields).
    symmetries : dict
        Mapping from residue name to ``(permutations, ccd_atom_names)``.

    Returns
    -------
    dict
        ``{"ligand_symmetries": molecule_symmetries}``. This holds one list of
        swap lists per molecule with at least one non-trivial permutation.

    Raises
    ------
    ValueError
        If a restricted permutation's length differs from the molecule's
        atom count in the crop. ``list.index`` also raises ValueError when a
        residue atom name is missing from the reference name list.

    """
    structure = cropped.structure

    added_molecules = {}
    index_mols = []
    atom_count = 0
    for token in cropped.tokens:
        # Identify molecules by (asym_id, res_idx) so that a molecule spanning
        # several tokens is registered only once; its atom count accumulates.
        atom_count += token["atom_num"]
        mol_id = (token["asym_id"], token["res_idx"])
        if mol_id in added_molecules:
            added_molecules[mol_id] += token["atom_num"]
            continue
        added_molecules[mol_id] = token["atom_num"]

        # NOTE(review): assumes chains are indexed by asym_id and the token's
        # res_idx is relative to the chain's first residue — confirm.
        residue_idx = token["res_idx"] + structure.chains[token["asym_id"]]["res_idx"]
        residue = structure.residues[residue_idx]
        mol_name = residue["name"]
        atom_idx = residue["atom_idx"]
        mol_atom_names = [
            tuple(m)
            for m in structure.atoms[atom_idx : atom_idx + residue["atom_num"]]["name"]
        ]
        # Residues listed in const.ref_symmetries are handled by
        # get_amino_acids_symmetries.
        if mol_name not in const.ref_symmetries:
            index_mols.append(
                (mol_name, atom_count - token["atom_num"], mol_id, mol_atom_names)
            )

    molecule_symmetries = []
    for mol_name, start_mol, mol_id, mol_atom_names in index_mols:
        if mol_name not in symmetries:
            continue

        syms_ccd, mol_atom_names_ccd = symmetries[mol_name]
        # Map each reference (CCD) atom position to its position in this residue.
        ccd_to_valid_ids = {
            mol_atom_names_ccd.index(name): i for i, name in enumerate(mol_atom_names)
        }
        ccd_valid_ids = set(ccd_to_valid_ids.keys())

        # Restrict each reference permutation to the atoms present here. Drop
        # permutations that send a present atom onto an absent one.
        syms = []
        for sym_ccd in syms_ccd:
            sym_dict = {}
            bool_add = True
            for i, j in enumerate(sym_ccd):
                if i in ccd_valid_ids:
                    if j in ccd_valid_ids:
                        sym_dict[ccd_to_valid_ids[i]] = ccd_to_valid_ids[j]
                    else:
                        bool_add = False
                        break
            if bool_add:
                syms.append([sym_dict[i] for i in range(len(ccd_valid_ids))])

        swaps = []
        for sym in syms:
            if len(sym) != added_molecules[mol_id]:
                msg = f"Symmetry length mismatch {len(sym)} {added_molecules[mol_id]}"
                raise ValueError(msg)
            # Keep only atoms that actually move, shifted to crop-level indices.
            sym_new_idx = [
                (i + start_mol, int(j) + start_mol)
                for i, j in enumerate(sym)
                if i != int(j)
            ]
            if sym_new_idx:
                swaps.append(sym_new_idx)
        if swaps:
            molecule_symmetries.append(swaps)

    return {"ligand_symmetries": molecule_symmetries}
|
File without changes
|
File without changes
|
@@ -0,0 +1,76 @@
|
|
1
|
+
from datetime import datetime
|
2
|
+
from typing import Literal
|
3
|
+
|
4
|
+
from boltz.data.types import Record
|
5
|
+
from boltz.data.filter.dynamic.filter import DynamicFilter
|
6
|
+
|
7
|
+
|
8
|
+
class DateFilter(DynamicFilter):
    """A filter that filters complexes based on their date.

    The date can be the deposition, release, or revision date. If the
    requested date is not available, the previous one in the chain
    revised -> released -> deposited is used.

    If no date is available, the complex is rejected.

    """

    def __init__(
        self,
        date: str,
        ref: Literal["deposited", "revised", "released"],
    ) -> None:
        """Initialize the filter.

        Parameters
        ----------
        date : str
            The maximum date of PDB entries to keep, in ISO format. Entries
            dated on or before it pass the filter.
        ref : Literal["deposited", "revised", "released"]
            The reference date to use.

        Raises
        ------
        ValueError
            If ``ref`` is not an accepted value or ``date`` is not a valid
            ISO format string.

        """
        if ref not in ("deposited", "revised", "released"):
            msg = "Invalid reference date. Must be deposited, revised, or released"
            raise ValueError(msg)

        self.filter_date = datetime.fromisoformat(date)
        self.ref = ref

    def filter(self, record: Record) -> bool:
        """Filter a record based on its date.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            True if the record's reference date is on or before the filter
            date. False otherwise, including when no date is available.

        """
        structure = record.structure

        if self.ref == "deposited":
            date = structure.deposited
        elif self.ref == "released":
            date = structure.released or structure.deposited
        else:  # "revised"; other values are rejected in __init__
            date = structure.revised or structure.released or structure.deposited

        if not date:
            return False

        return datetime.fromisoformat(date) <= self.filter_date
|
@@ -0,0 +1,24 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
|
3
|
+
from boltz.data.types import Record
|
4
|
+
|
5
|
+
|
6
|
+
class DynamicFilter(ABC):
    """Abstract interface for filters applied to data records.

    Concrete filters implement :meth:`filter` and return ``True`` for the
    records that should be kept.
    """

    @abstractmethod
    def filter(self, record: Record) -> bool:
        """Decide whether ``record`` passes this filter.

        Parameters
        ----------
        record : Record
            The object to consider filtering in / out.

        Returns
        -------
        bool
            True if the data passes the filter, False otherwise.

        Raises
        ------
        NotImplementedError
            When this base implementation is invoked directly, e.g. through
            ``super().filter(record)``.

        """
        raise NotImplementedError
|