boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/data/mol.py ADDED
@@ -0,0 +1,900 @@
1
+ import itertools
2
+ import pickle
3
+ import random
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ from rdkit.Chem import Mol
9
+ from tqdm import tqdm
10
+
11
+ from boltz.data import const
12
+ from boltz.data.pad import pad_dim
13
+ from boltz.model.loss.confidence import lddt_dist
14
+
15
+
16
def load_molecules(moldir: str, molecules: list[str]) -> dict[str, Mol]:
    """Read the pickled molecule for each requested name.

    Every name is resolved to ``<moldir>/<name>.pkl`` and deserialized with
    ``pickle``. Pickle can execute arbitrary code while loading, so
    ``moldir`` must only ever point at trusted data.

    Parameters
    ----------
    moldir : str
        The path to the molecules directory.
    molecules : list[str]
        The names of the molecules to load.

    Returns
    -------
    dict[str, Mol]
        The loaded molecules, keyed by the requested name.

    Raises
    ------
    ValueError
        If no pickle file exists for one of the requested names.

    """
    result = {}
    for name in molecules:
        pkl_file = Path(moldir, f"{name}.pkl")
        if not pkl_file.exists():
            raise ValueError(f"CCD component {name} not found!")
        with pkl_file.open("rb") as handle:
            result[name] = pickle.load(handle)  # noqa: S301
    return result
40
+
41
+
42
def load_canonicals(moldir: str) -> dict[str, Mol]:
    """Load the molecules listed in ``const.canonical_tokens``.

    Parameters
    ----------
    moldir : str
        The path to the molecules directory.

    Returns
    -------
    dict[str, Mol]
        The loaded molecules, keyed by token name.

    """
    canonical_names = const.canonical_tokens
    return load_molecules(moldir, canonical_names)
57
+
58
+
59
def load_all_molecules(moldir: str) -> dict[str, Mol]:
    """Load every pickled molecule found in a directory.

    All ``*.pkl`` files directly inside ``moldir`` are deserialized with
    ``pickle``; the file stem is used as the molecule name. Pickle can
    execute arbitrary code while loading, so ``moldir`` must only ever point
    at trusted data.

    Parameters
    ----------
    moldir : str
        The path to the molecules directory.

    Returns
    -------
    dict[str, Mol]
        The loaded molecules, keyed by file stem, in sorted path order.

    """
    # The order in which glob yields entries depends on the filesystem; sort
    # so the resulting dictionary has a reproducible key order across runs.
    files = sorted(Path(moldir).glob("*.pkl"))
    loaded_mols = {}
    for path in tqdm(files, total=len(files), desc="Loading molecules", leave=False):
        with path.open("rb") as f:
            loaded_mols[path.stem] = pickle.load(f)  # noqa: S301
    return loaded_mols
82
+
83
+
84
def get_symmetries(mols: dict[str, Mol]) -> dict:  # noqa: PLR0912
    """Create a dictionary for the ligand symmetries.

    Each molecule stores its data as hex-encoded pickles in named
    properties. Optional property groups fall back to empty arrays when
    absent. A molecule for which anything goes wrong while decoding (for
    example a missing ``symmetries`` property) is silently left out of the
    result.

    Parameters
    ----------
    mols : dict[str, Mol]
        The molecules to extract the data from, keyed by name.

    Returns
    -------
    dict
        Maps each successfully decoded molecule name to a 16-tuple:
        ``(symmetries, atom_names, edge_index, lower_bounds, upper_bounds,
        bond_mask, angle_mask, chiral_atom_index, chiral_check_mask,
        chiral_atom_orientations, stereo_bond_index, stereo_check_mask,
        stereo_bond_orientations, aromatic_5_ring_index,
        aromatic_6_ring_index, planar_double_bond_index)``.

    """

    def _read(mol, prop, dtype=None):
        # A property holds the hex string of a pickled object.
        payload = pickle.loads(bytes.fromhex(mol.GetProp(prop)))  # noqa: S301
        return payload if dtype is None else payload.astype(dtype)

    def _read_index(mol, prop, rows):
        # Index arrays default to an empty (rows, 0) int64 array when absent.
        if mol.HasProp(prop):
            return _read(mol, prop, np.int64)
        return np.empty((rows, 0), dtype=np.int64)

    symmetries = {}
    for key, mol in mols.items():
        try:
            sym = _read(mol, "symmetries")

            # Distance-bound data (edges, bounds and masks).
            if mol.HasProp("pb_edge_index"):
                edge_index = _read(mol, "pb_edge_index", np.int64)
                lower_bounds = _read(mol, "pb_lower_bounds")
                upper_bounds = _read(mol, "pb_upper_bounds")
                bond_mask = _read(mol, "pb_bond_mask")
                angle_mask = _read(mol, "pb_angle_mask")
            else:
                edge_index = np.empty((2, 0), dtype=np.int64)
                lower_bounds = np.array([], dtype=np.float32)
                upper_bounds = np.array([], dtype=np.float32)
                bond_mask = np.array([], dtype=np.float32)
                angle_mask = np.array([], dtype=np.float32)

            # Chiral-centre data.
            if mol.HasProp("chiral_atom_index"):
                chiral_atom_index = _read(mol, "chiral_atom_index", np.int64)
                chiral_check_mask = _read(mol, "chiral_check_mask", np.int64)
                chiral_atom_orientations = _read(mol, "chiral_atom_orientations")
            else:
                chiral_atom_index = np.empty((4, 0), dtype=np.int64)
                chiral_check_mask = np.array([], dtype=bool)
                chiral_atom_orientations = np.array([], dtype=bool)

            # Stereo-bond data.
            if mol.HasProp("stereo_bond_index"):
                stereo_bond_index = _read(mol, "stereo_bond_index", np.int64)
                stereo_check_mask = _read(mol, "stereo_check_mask", np.int64)
                stereo_bond_orientations = _read(mol, "stereo_bond_orientations")
            else:
                stereo_bond_index = np.empty((4, 0), dtype=np.int64)
                stereo_check_mask = np.array([], dtype=bool)
                stereo_bond_orientations = np.array([], dtype=bool)

            # Ring and double-bond index arrays.
            aromatic_5_ring_index = _read_index(mol, "aromatic_5_ring_index", 5)
            aromatic_6_ring_index = _read_index(mol, "aromatic_6_ring_index", 6)
            planar_double_bond_index = _read_index(
                mol, "planar_double_bond_index", 6
            )

            atom_names = [atom.GetProp("name") for atom in mol.GetAtoms()]
            symmetries[key] = (
                sym,
                atom_names,
                edge_index,
                lower_bounds,
                upper_bounds,
                bond_mask,
                angle_mask,
                chiral_atom_index,
                chiral_check_mask,
                chiral_atom_orientations,
                stereo_bond_index,
                stereo_check_mask,
                stereo_bond_orientations,
                aromatic_5_ring_index,
                aromatic_6_ring_index,
                planar_double_bond_index,
            )
        except Exception:  # noqa: BLE001, PERF203, S110
            # Best effort: molecules that cannot be decoded are skipped.
            pass

    return symmetries
194
+
195
+
196
def compute_symmetry_idx_dictionary(data):
    """Assign atom offsets to chains and tokens and gather all coordinates.

    Walks ``data.chains`` in order. Each chain gets ``start_idx`` set to the
    number of atoms seen before it; each token gets ``start_idx`` set to its
    atom offset relative to the start of its chain.

    Returns
    -------
    list
        One ``[x, y, z]`` list per atom, in traversal order.

    """
    coords_out = []
    running = 0
    for chain in data.chains:
        chain.start_idx = running
        for token in chain.tokens:
            token.start_idx = running - chain.start_idx
            for atom in token.atoms:
                coords_out.append([atom.coords.x, atom.coords.y, atom.coords.z])
            running += len(token.atoms)
    return coords_out
209
+
210
+
211
def get_current_idx_list(data):
    """List the global atom indices of all tokens that are inside the crop.

    A token contributes only when both it and its chain have ``in_crop``
    set. Its atoms are numbered consecutively starting at
    ``chain.start_idx + token.start_idx``.

    Returns
    -------
    list
        The atom indices, in chain/token traversal order.

    """
    selected = []
    for chain in data.chains:
        if not chain.in_crop:
            continue
        for token in chain.tokens:
            if not token.in_crop:
                continue
            first = chain.start_idx + token.start_idx
            selected.extend([first + k for k in range(len(token.atoms))])
    return selected
224
+
225
+
226
def all_different_after_swap(l):
    """Return True when the last elements of all entries in ``l`` are distinct."""
    distinct_targets = set()
    n_entries = 0
    for swap in l:
        distinct_targets.add(swap[-1])
        n_entries += 1
    return n_entries == len(distinct_targets)
229
+
230
+
231
def minimum_lddt_symmetry_coords(
    coords: torch.Tensor,
    feats: dict,
    index_batch: int,
):
    """Select the symmetry-equivalent ground truth that best fits a prediction.

    Two passes are made, both scored with ``lddt_dist`` (cutoff 15.0):

    1. Every chain-swap combination in ``feats["chain_swaps"]`` is applied
       to the full ground truth and the one with the highest lDDT against
       the predicted distance matrix is kept.
    2. For each entry of ``amino_acids_symmetries + ligand_symmetries``,
       every candidate atom swap is scored on the affected atoms only, and
       the swap with the largest positive lDDT improvement (if any) is
       applied before moving on to the next entry.

    Parameters
    ----------
    coords : torch.Tensor
        Predicted coordinates; atoms are indexed along dimension 1.
        NOTE(review): presumably shaped (1, atoms, 3) -- confirm with caller.
    feats : dict
        Feature dictionary. The keys read here are ``all_coords``,
        ``all_resolved_mask``, ``crop_to_all_atom_map``, ``chain_swaps``,
        ``amino_acids_symmetries`` and ``ligand_symmetries``, each indexed
        by ``index_batch``.
    index_batch : int
        Index of the sample to process inside ``feats``.

    Returns
    -------
    tuple
        ``(true_coords, true_resolved_mask)``: the selected ground-truth
        coordinates, passed through ``pad_dim`` on dimension 1 with
        ``coords.shape[1] - true_coords.shape[1]``, and the matching
        resolved mask, passed through ``pad_dim`` likewise and given a
        leading dimension via ``unsqueeze(0)``.

    """
    # Bring the ground-truth features onto the device/dtype of the prediction.
    all_coords = feats["all_coords"][index_batch].unsqueeze(0).to(coords)
    all_resolved_mask = (
        feats["all_resolved_mask"][index_batch].to(coords).to(torch.bool)
    )
    crop_to_all_atom_map = (
        feats["crop_to_all_atom_map"][index_batch].to(coords).to(torch.long)
    )
    chain_symmetries = feats["chain_swaps"][index_batch]
    amino_acids_symmetries = feats["amino_acids_symmetries"][index_batch]
    ligand_symmetries = feats["ligand_symmetries"][index_batch]

    # Pairwise distances of the prediction, restricted to the first
    # len(crop_to_all_atom_map) atoms.
    dmat_predicted = torch.cdist(
        coords[:, : len(crop_to_all_atom_map)], coords[:, : len(crop_to_all_atom_map)]
    )

    # Check best symmetry on chain swap
    # The unswapped ground truth is the fallback if no candidate qualifies.
    best_true_coords = all_coords[:, crop_to_all_atom_map].clone()
    best_true_resolved_mask = all_resolved_mask[crop_to_all_atom_map].clone()
    best_lddt = -1.0
    for c in chain_symmetries:
        true_all_coords = all_coords.clone()
        true_all_resolved_mask = all_resolved_mask.clone()
        # Copy the atoms of the source chain range into the target chain range.
        for start1, end1, start2, end2, chainidx1, chainidx2 in c:
            true_all_coords[:, start1:end1] = all_coords[:, start2:end2]
            true_all_resolved_mask[start1:end1] = all_resolved_mask[start2:end2]
        true_coords = true_all_coords[:, crop_to_all_atom_map]
        true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map]
        dmat_true = torch.cdist(true_coords, true_coords)
        # Only pairs of distinct, resolved atoms are scored.
        pair_mask = (
            true_resolved_mask[:, None]
            * true_resolved_mask[None, :]
            * (1 - torch.eye(len(true_resolved_mask))).to(true_resolved_mask)
        )

        lddt = lddt_dist(
            dmat_predicted, dmat_true, pair_mask, cutoff=15.0, per_atom=False
        )[0]
        lddt = lddt.item()

        # Candidates with three or fewer resolved atoms are never selected.
        if lddt > best_lddt and torch.sum(true_resolved_mask) > 3:
            best_lddt = lddt
            best_true_coords = true_coords
            best_true_resolved_mask = true_resolved_mask

    # atom symmetries (nucleic acid and protein residues), resolved greedily without recomputing alignment
    true_coords = best_true_coords.clone()
    true_resolved_mask = best_true_resolved_mask.clone()
    for symmetric_amino_or_lig in amino_acids_symmetries + ligand_symmetries:
        best_lddt_improvement = 0.0

        # Gather every atom index that appears as a swap target in this entry.
        indices = set()
        for c in symmetric_amino_or_lig:
            for i, j in c:
                indices.add(i)
        indices = sorted(list(indices))
        indices = torch.from_numpy(np.asarray(indices)).to(true_coords.device).long()
        # Predicted distances from all cropped atoms to the affected atoms.
        pred_coords_subset = coords[:, : len(crop_to_all_atom_map)][:, indices]
        sub_dmat_pred = torch.cdist(
            coords[:, : len(crop_to_all_atom_map)], pred_coords_subset
        )

        for c in symmetric_amino_or_lig:
            # starting from greedy best, try to swap the atoms
            new_true_coords = true_coords.clone()
            new_true_resolved_mask = true_resolved_mask.clone()
            for i, j in c:
                new_true_coords[:, i] = true_coords[:, j]
                new_true_resolved_mask[i] = true_resolved_mask[j]

            true_coords_subset = true_coords[:, indices]
            new_true_coords_subset = new_true_coords[:, indices]

            sub_dmat_true = torch.cdist(true_coords, true_coords_subset)
            sub_dmat_new_true = torch.cdist(new_true_coords, new_true_coords_subset)

            # Pair mask for the current ground truth; the rows belonging to
            # the affected atoms have their self-pairs zeroed out.
            sub_true_pair_lddt = (
                true_resolved_mask[:, None] * true_resolved_mask[None, indices]
            )
            sub_true_pair_lddt[indices] = (
                sub_true_pair_lddt[indices]
                * (1 - torch.eye(len(indices))).to(sub_true_pair_lddt).bool()
            )

            # Same mask construction for the swapped candidate.
            sub_new_true_pair_lddt = (
                new_true_resolved_mask[:, None] * new_true_resolved_mask[None, indices]
            )
            sub_new_true_pair_lddt[indices] = (
                sub_new_true_pair_lddt[indices]
                * (1 - torch.eye(len(indices))).to(sub_true_pair_lddt).bool()
            )

            # Score the current ground truth and the swapped candidate.
            lddt, total = lddt_dist(
                sub_dmat_pred,
                sub_dmat_true,
                sub_true_pair_lddt,
                cutoff=15.0,
                per_atom=False,
            )
            new_lddt, new_total = lddt_dist(
                sub_dmat_pred,
                sub_dmat_new_true,
                sub_new_true_pair_lddt,
                cutoff=15.0,
                per_atom=False,
            )

            lddt_improvement = new_lddt - lddt

            # Keep the candidate only if it strictly beats the best so far.
            if lddt_improvement > best_lddt_improvement:
                best_true_coords = new_true_coords
                best_true_resolved_mask = new_true_resolved_mask
                best_lddt_improvement = lddt_improvement

        # greedily update best coordinates after each amino acid
        true_coords = best_true_coords.clone()
        true_resolved_mask = best_true_resolved_mask.clone()

    # Recomputing alignment
    # Extend the result along the atom dimension up to coords.shape[1].
    true_coords = pad_dim(true_coords, 1, coords.shape[1] - true_coords.shape[1])
    true_resolved_mask = pad_dim(
        true_resolved_mask,
        0,
        coords.shape[1] - true_resolved_mask.shape[0],
    )

    return true_coords, true_resolved_mask.unsqueeze(0)
362
+
363
+
364
def compute_single_distogram_loss(pred, target, mask):
    """Compute the masked cross-entropy between a distogram and its target.

    The cross-entropy is taken over the last dimension of ``pred`` (after a
    log-softmax) against ``target``. It is multiplied by ``mask``, summed
    over the last remaining dimension, divided by the mask total over the
    last two dimensions plus ``1e-5``, summed again, and finally averaged.

    Returns
    -------
    torch.Tensor
        The scalar loss.

    """
    log_probs = torch.nn.functional.log_softmax(pred, dim=-1)
    cross_entropy = -1 * torch.sum(target * log_probs, dim=-1)
    normalizer = 1e-5 + torch.sum(mask, dim=(-1, -2))
    row_totals = torch.sum(cross_entropy * mask, dim=-1)
    per_sample = torch.sum(row_totals / normalizer[..., None], dim=-1)
    return torch.mean(per_sample)
377
+
378
+
379
def minimum_lddt_symmetry_dist(
    pred_distogram: torch.Tensor,
    feats: dict,
    index_batch: int,
):
    """Resolve ligand symmetries against a predicted distogram, in place.

    For each entry of ``feats["ligand_symmetries"][index_batch]`` every
    candidate atom swap is applied to a copy of the distogram target, and
    the swap that reduces the distogram loss (see
    ``compute_single_distogram_loss``) the most is kept. The chosen swap is
    then also applied to the coordinates.

    Parameters
    ----------
    pred_distogram : torch.Tensor
        Predicted distogram; indexed here as ``pred_distogram[:, indices]``.
    feats : dict
        Feature dictionary. Reads ``disto_target``, ``token_disto_mask``,
        ``coords``, ``ligand_symmetries`` and ``atom_to_token`` at
        ``index_batch``; writes ``disto_target`` and ``coords`` back at
        ``index_batch``.
    index_batch : int
        Index of the sample to process inside ``feats``.

    Returns
    -------
    None
        The results are stored in ``feats``.

    """
    # Note: for now only ligand symmetries are resolved

    disto_target = feats["disto_target"][index_batch]
    mask = feats["token_disto_mask"][index_batch]
    # Pairwise mask with the diagonal removed.
    mask = mask[None, :] * mask[:, None]
    mask = mask * (1 - torch.eye(mask.shape[1])).to(disto_target)

    # NOTE(review): if this indexing returns a view, the swaps below already
    # modify feats["coords"] in place before the final write-back -- confirm.
    coords = feats["coords"][index_batch]

    ligand_symmetries = feats["ligand_symmetries"][index_batch]
    # Index of the token each atom belongs to (argmax over the last axis).
    atom_to_token_map = feats["atom_to_token"][index_batch].argmax(dim=-1)

    # atom symmetries, resolved greedily without recomputing alignment
    for symmetric_amino_or_lig in ligand_symmetries:
        best_c, best_disto, best_loss_improvement = None, None, 0.0
        for c in symmetric_amino_or_lig:
            # starting from greedy best, try to swap the atoms
            new_disto_target = disto_target.clone()
            indices = []

            # fix the distogram by replacing first the columns then the rows
            # A snapshot is read from so that swaps do not see each other.
            disto_temp = new_disto_target.clone()
            for i, j in c:
                new_disto_target[:, atom_to_token_map[i]] = disto_temp[
                    :, atom_to_token_map[j]
                ]
                indices.append(atom_to_token_map[i].item())
            disto_temp = new_disto_target.clone()
            for i, j in c:
                new_disto_target[atom_to_token_map[i], :] = disto_temp[
                    atom_to_token_map[j], :
                ]

            indices = (
                torch.from_numpy(np.asarray(indices)).to(disto_target.device).long()
            )

            # Compare the losses on the affected token columns only.
            pred_distogram_subset = pred_distogram[:, indices]
            disto_target_subset = disto_target[:, indices]
            new_disto_target_subset = new_disto_target[:, indices]
            mask_subset = mask[:, indices]

            loss = compute_single_distogram_loss(
                pred_distogram_subset, disto_target_subset, mask_subset
            )
            new_loss = compute_single_distogram_loss(
                pred_distogram_subset, new_disto_target_subset, mask_subset
            )
            # Scale the difference by the number of affected tokens.
            loss_improvement = (loss - new_loss) * len(indices)

            if loss_improvement > best_loss_improvement:
                best_c = c
                best_disto = new_disto_target
                best_loss_improvement = loss_improvement

        # greedily update best coordinates after each ligand
        if best_loss_improvement > 0:
            disto_target = best_disto.clone()
            # Read from a snapshot so that swaps do not see each other.
            old_coords = coords.clone()
            for i, j in best_c:
                coords[:, i] = old_coords[:, j]

    # update features to be used in diffusion and in distogram loss
    feats["disto_target"][index_batch] = disto_target
    feats["coords"][index_batch] = coords
    return
450
+
451
+
452
def compute_all_coords_mask(structure):
    """Gather per-atom coordinates and masks and assign atom offsets.

    Walks ``structure.chains`` in order. Each chain gets ``start_idx`` set
    to the number of atoms seen before it; each token gets ``start_idx`` set
    to its atom offset relative to the start of its chain.

    Returns
    -------
    tuple
        ``(all_coords, all_coords_crop_mask, all_resolved_mask)``: one
        ``[x, y, z]`` list per atom, and two lists of the same length that
        repeat the owning token's ``in_crop`` and ``is_present`` values for
        each of its atoms.

    """
    total_count = 0
    all_coords = []
    all_coords_crop_mask = []
    all_resolved_mask = []
    for chain in structure.chains:
        chain.start_idx = total_count
        for token in chain.tokens:
            n_atoms = len(token.atoms)
            token.start_idx = total_count - chain.start_idx
            all_coords.extend(
                [[atom.coords.x, atom.coords.y, atom.coords.z] for atom in token.atoms]
            )
            # Both masks grow by exactly n_atoms per token, so they always
            # stay the same length; no separate consistency check is needed.
            all_coords_crop_mask.extend([token.in_crop] * n_atoms)
            all_resolved_mask.extend([token.is_present] * n_atoms)
            total_count += n_atoms
    return all_coords, all_coords_crop_mask, all_resolved_mask
475
+
476
+
477
def get_chain_symmetries(cropped, max_n_symmetries=100):
    """Build the chain-level symmetry features for a cropped structure.

    Parameters
    ----------
    cropped
        Object exposing ``structure`` (with ``chains``, ``atoms`` and
        ``bonds``) and ``tokens``, as accessed below.
    max_n_symmetries : int, optional
        Upper bound on the number of chain-swap combinations returned. At
        most ``max_n_symmetries * 10`` combinations are enumerated before
        filtering.

    Returns
    -------
    dict
        With keys ``all_coords``, ``all_resolved_mask``,
        ``crop_to_all_atom_map``, ``chain_symmetries``,
        ``connections_edge_index`` and ``chain_swaps``.

    """
    # get all coordinates and resolved mask
    structure = cropped.structure
    all_coords = []
    all_resolved_mask = []
    original_atom_idx = []
    chain_atom_idx = []
    chain_atom_num = []
    chain_in_crop = []
    chain_asym_id = []
    new_atom_idx = 0

    for chain in structure.chains:
        atom_idx, atom_num = (
            chain["atom_idx"],  # Global index of first atom in the chain
            chain["atom_num"],  # Number of atoms in the chain
        )

        # compute coordinates and resolved mask
        resolved_mask = structure.atoms["is_present"][
            atom_idx : atom_idx + atom_num
        ]  # Whether each atom in the chain is actually resolved

        # ensemble_atom_starts = [structure.ensemble[idx]["atom_coord_idx"] for idx in cropped.ensemble_ref_idxs]
        # coords = np.array(
        #    [structure.coords[ensemble_atom_start + atom_idx: ensemble_atom_start + atom_idx + atom_num]["coords"] for
        #     ensemble_atom_start in ensemble_atom_starts])

        coords = structure.atoms["coords"][atom_idx : atom_idx + atom_num]

        # A chain counts as in-crop when any cropped token shares its asym_id.
        in_crop = False
        for token in cropped.tokens:
            if token["asym_id"] == chain["asym_id"]:
                in_crop = True
                break

        all_coords.append(coords)
        all_resolved_mask.append(resolved_mask)
        original_atom_idx.append(atom_idx)
        # Offset of this chain inside the concatenated arrays built here.
        chain_atom_idx.append(new_atom_idx)
        chain_atom_num.append(atom_num)
        chain_in_crop.append(in_crop)
        chain_asym_id.append(chain["asym_id"])

        new_atom_idx += atom_num

    all_coords = np.concatenate(all_coords, axis=0)
    # Compute backmapping from token to all coords
    crop_to_all_atom_map = []
    for token in cropped.tokens:
        chain_idx = chain_asym_id.index(token["asym_id"])
        # Translate the token's original atom index into the concatenated
        # numbering by shifting with the chain's offset difference.
        start = (
            chain_atom_idx[chain_idx] - original_atom_idx[chain_idx] + token["atom_idx"]
        )
        crop_to_all_atom_map.append(np.arange(start, start + token["atom_num"]))
    crop_to_all_atom_map = np.concatenate(crop_to_all_atom_map, axis=0)

    # Compute the connections edge index for covalent bonds
    # NOTE(review): the inverse map is zero-initialised, so atoms outside the
    # crop map to crop index 0 -- confirm this is intended.
    all_atom_to_crop_map = np.zeros(all_coords.shape[0], dtype=np.int64)
    all_atom_to_crop_map[crop_to_all_atom_map.astype(np.int64)] = np.arange(
        crop_to_all_atom_map.shape[0]
    )
    connections_edge_index = []
    for connection in structure.bonds:
        # Bonds inside a single residue are skipped.
        if (connection["chain_1"] == connection["chain_2"]) and (
            connection["res_1"] == connection["res_2"]
        ):
            continue
        connections_edge_index.append([connection["atom_1"], connection["atom_2"]])
    if len(connections_edge_index) > 0:
        connections_edge_index = np.array(connections_edge_index, dtype=np.int64).T
        connections_edge_index = all_atom_to_crop_map[connections_edge_index]
    else:
        # NOTE(review): no dtype is given here, so the empty array uses
        # NumPy's default float dtype while the branch above yields int64.
        connections_edge_index = np.empty((2, 0))

    # Compute the symmetries between chains
    symmetries = []
    swaps = []
    for i, chain in enumerate(structure.chains):
        start = chain_atom_idx[i]
        end = start + chain_atom_num[i]

        if chain_in_crop[i]:
            # Every chain (including this one) with the same entity_id and
            # the same atom count is a possible swap partner.
            possible_swaps = []
            for j, chain2 in enumerate(structure.chains):
                start2 = chain_atom_idx[j]
                end2 = start2 + chain_atom_num[j]
                if (
                    chain["entity_id"] == chain2["entity_id"]
                    and end - start == end2 - start2
                ):
                    possible_swaps.append((start, end, start2, end2, i, j))
            swaps.append(possible_swaps)

        # Add the chain to an existing group whose first member has the same
        # entity_id and atom count, or start a new group.
        found = False
        for symmetry_idx, symmetry in enumerate(symmetries):
            j = symmetry[0][0]
            chain2 = structure.chains[j]
            start2 = chain_atom_idx[j]
            end2 = start2 + chain_atom_num[j]
            if (
                chain["entity_id"] == chain2["entity_id"]
                and end - start == end2 - start2
            ):
                symmetries[symmetry_idx].append(
                    (i, start, end, chain_in_crop[i], chain["mol_type"])
                )
                found = True
        if not found:
            symmetries.append([(i, start, end, chain_in_crop[i], chain["mol_type"])])

    # One swap choice per in-crop chain.
    combinations = itertools.product(*swaps)
    # to avoid combinatorial explosion, bound the number of combinations even considered
    combinations = list(itertools.islice(combinations, max_n_symmetries * 10))
    # filter for all chains getting a different assignment
    combinations = [c for c in combinations if all_different_after_swap(c)]

    # Randomly subsample when too many valid combinations remain.
    if len(combinations) > max_n_symmetries:
        combinations = random.sample(combinations, max_n_symmetries)

    # Always provide at least one (empty) combination.
    if len(combinations) == 0:
        combinations.append([])

    # Drop groups without any in-crop chain; iterate backwards so that
    # popping does not shift the indices still to be visited.
    for i in range(len(symmetries) - 1, -1, -1):
        if not any(chain[3] for chain in symmetries[i]):
            symmetries.pop(i)

    features = {}
    features["all_coords"] = torch.Tensor(all_coords)  # axis=1 with ensemble

    features["all_resolved_mask"] = torch.Tensor(
        np.concatenate(all_resolved_mask, axis=0)
    )
    features["crop_to_all_atom_map"] = torch.Tensor(crop_to_all_atom_map)
    features["chain_symmetries"] = symmetries
    features["connections_edge_index"] = torch.tensor(connections_edge_index)
    features["chain_swaps"] = combinations

    return features
616
+
617
+
618
def get_amino_acids_symmetries(cropped):
    """Collect the reference atom swaps of the cropped tokens.

    For each token, the swaps registered in ``const.ref_symmetries`` under
    the token's residue name (``const.tokens[token["res_type"]]``) are
    shifted by the number of atoms of all preceding tokens. Tokens without
    registered swaps contribute no entry.

    Returns
    -------
    dict
        ``{"amino_acids_symmetries": swaps}``, where ``swaps`` holds one
        list per token with symmetries, each a list of swaps made of
        ``(i, j)`` index pairs.

    """
    offset = 0
    per_residue = []
    for token in cropped.tokens:
        res_name = const.tokens[token["res_type"]]
        ref_syms = const.ref_symmetries.get(res_name, [])
        if len(ref_syms) > 0:
            per_residue.append(
                [[(a + offset, b + offset) for a, b in sym] for sym in ref_syms]
            )
        offset += token["atom_num"]
    return {"amino_acids_symmetries": per_residue}
636
+
637
+
638
def slice_valid_index(index, ccd_to_valid_id_array, args=None):
    """Remap an index array and drop the columns that contain invalid entries.

    Parameters
    ----------
    index
        Integer array with one column per item; its values index into
        ``ccd_to_valid_id_array``.
    ccd_to_valid_id_array
        Lookup array in which entries without a valid counterpart are NaN.
    args : iterable, optional
        Arrays with one entry per column of ``index``; they are filtered
        with the same column mask.

    Returns
    -------
    The remapped ``index`` restricted to columns without any NaN. When
    ``args`` is given, a pair ``(index, filtered_args)`` is returned
    instead, where ``filtered_args`` is a tuple in the order of ``args``.

    """
    index = ccd_to_valid_id_array[index]
    # A column is kept only if every one of its entries could be remapped.
    valid_index_mask = (~np.isnan(index)).all(axis=0)
    index = index[:, valid_index_mask]
    if args is None:
        return index
    # Materialize as a tuple: a generator would be lazy and single-use.
    args = tuple(arg[valid_index_mask] for arg in args)
    return index, args
646
+
647
+
648
+ def get_ligand_symmetries(cropped, symmetries, return_physical_metrics=False):
649
+ # Compute ligand and non-standard amino-acids symmetries
650
+ structure = cropped.structure
651
+
652
+ added_molecules = {}
653
+ index_mols = []
654
+ atom_count = 0
655
+
656
+ for token in cropped.tokens:
657
+ # check if molecule is already added by identifying it through asym_id and res_idx
658
+ atom_count += token["atom_num"]
659
+ mol_id = (token["asym_id"], token["res_idx"])
660
+ if mol_id in added_molecules:
661
+ added_molecules[mol_id] += token["atom_num"]
662
+ continue
663
+ added_molecules[mol_id] = token["atom_num"]
664
+
665
+ # get the molecule type and indices
666
+ residue_idx = token["res_idx"] + structure.chains[token["asym_id"]]["res_idx"]
667
+ mol_name = structure.residues[residue_idx]["name"]
668
+ atom_idx = structure.residues[residue_idx]["atom_idx"]
669
+ mol_atom_names = structure.atoms[
670
+ atom_idx : atom_idx + structure.residues[residue_idx]["atom_num"]
671
+ ]["name"]
672
+ if mol_name not in const.ref_symmetries:
673
+ index_mols.append(
674
+ (mol_name, atom_count - token["atom_num"], mol_id, mol_atom_names)
675
+ )
676
+
677
+ # for each molecule, get the symmetries
678
+ molecule_symmetries = []
679
+ all_edge_index = []
680
+ all_lower_bounds, all_upper_bounds = [], []
681
+ all_bond_mask, all_angle_mask = [], []
682
+ all_chiral_atom_index, all_chiral_check_mask, all_chiral_atom_orientations = (
683
+ [],
684
+ [],
685
+ [],
686
+ )
687
+ all_stereo_bond_index, all_stereo_check_mask, all_stereo_bond_orientations = (
688
+ [],
689
+ [],
690
+ [],
691
+ )
692
+ (
693
+ all_aromatic_5_ring_index,
694
+ all_aromatic_6_ring_index,
695
+ all_planar_double_bond_index,
696
+ ) = (
697
+ [],
698
+ [],
699
+ [],
700
+ )
701
+ for mol_name, start_mol, mol_id, mol_atom_names in index_mols:
702
+ if not mol_name in symmetries:
703
+ continue
704
+ else:
705
+ swaps = []
706
+ (
707
+ syms_ccd,
708
+ mol_atom_names_ccd,
709
+ edge_index,
710
+ lower_bounds,
711
+ upper_bounds,
712
+ bond_mask,
713
+ angle_mask,
714
+ chiral_atom_index,
715
+ chiral_check_mask,
716
+ chiral_atom_orientations,
717
+ stereo_bond_index,
718
+ stereo_check_mask,
719
+ stereo_bond_orientations,
720
+ aromatic_5_ring_index,
721
+ aromatic_6_ring_index,
722
+ planar_double_bond_index,
723
+ ) = symmetries[mol_name]
724
+ # Get indices of mol_atom_names_ccd that are in mol_atom_names
725
+ ccd_to_valid_ids = {
726
+ mol_atom_names_ccd.index(name): i
727
+ for i, name in enumerate(mol_atom_names)
728
+ }
729
+ ccd_to_valid_id_array = np.array(
730
+ [
731
+ float("nan") if i not in ccd_to_valid_ids else ccd_to_valid_ids[i]
732
+ for i in range(len(mol_atom_names_ccd))
733
+ ]
734
+ )
735
+ ccd_valid_ids = set(ccd_to_valid_ids.keys())
736
+ syms = []
737
+ # Get syms
738
+ for sym_ccd in syms_ccd:
739
+ sym_dict = {}
740
+ bool_add = True
741
+ for i, j in enumerate(sym_ccd):
742
+ if i in ccd_valid_ids:
743
+ if j in ccd_valid_ids:
744
+ i_true = ccd_to_valid_ids[i]
745
+ j_true = ccd_to_valid_ids[j]
746
+ sym_dict[i_true] = j_true
747
+ else:
748
+ bool_add = False
749
+ break
750
+ if bool_add:
751
+ syms.append([sym_dict[i] for i in range(len(ccd_valid_ids))])
752
+ for sym in syms:
753
+ if len(sym) != added_molecules[mol_id]:
754
+ raise Exception(
755
+ f"Symmetry length mismatch {len(sym)} {added_molecules[mol_id]}"
756
+ )
757
+ # assert (
758
+ # len(sym) == added_molecules[mol_id]
759
+ # ), f"Symmetry length mismatch {len(sym)} {added_molecules[mol_id]}"
760
+ sym_new_idx = []
761
+ for i, j in enumerate(sym):
762
+ if i != int(j):
763
+ sym_new_idx.append((i + start_mol, int(j) + start_mol))
764
+ if len(sym_new_idx) > 0:
765
+ swaps.append(sym_new_idx)
766
+
767
+ if len(swaps) > 0:
768
+ molecule_symmetries.append(swaps)
769
+
770
+ if return_physical_metrics:
771
+ edge_index, (lower_bounds, upper_bounds, bond_mask, angle_mask) = (
772
+ slice_valid_index(
773
+ edge_index,
774
+ ccd_to_valid_id_array,
775
+ (lower_bounds, upper_bounds, bond_mask, angle_mask),
776
+ )
777
+ )
778
+ all_edge_index.append(edge_index + start_mol)
779
+ all_lower_bounds.append(lower_bounds)
780
+ all_upper_bounds.append(upper_bounds)
781
+ all_bond_mask.append(bond_mask)
782
+ all_angle_mask.append(angle_mask)
783
+
784
+ chiral_atom_index, (chiral_check_mask, chiral_atom_orientations) = (
785
+ slice_valid_index(
786
+ chiral_atom_index,
787
+ ccd_to_valid_id_array,
788
+ (chiral_check_mask, chiral_atom_orientations),
789
+ )
790
+ )
791
+ all_chiral_atom_index.append(chiral_atom_index + start_mol)
792
+ all_chiral_check_mask.append(chiral_check_mask)
793
+ all_chiral_atom_orientations.append(chiral_atom_orientations)
794
+
795
+ stereo_bond_index, (stereo_check_mask, stereo_bond_orientations) = (
796
+ slice_valid_index(
797
+ stereo_bond_index,
798
+ ccd_to_valid_id_array,
799
+ (stereo_check_mask, stereo_bond_orientations),
800
+ )
801
+ )
802
+ all_stereo_bond_index.append(stereo_bond_index + start_mol)
803
+ all_stereo_check_mask.append(stereo_check_mask)
804
+ all_stereo_bond_orientations.append(stereo_bond_orientations)
805
+
806
+ aromatic_5_ring_index = slice_valid_index(
807
+ aromatic_5_ring_index, ccd_to_valid_id_array
808
+ )
809
+ aromatic_6_ring_index = slice_valid_index(
810
+ aromatic_6_ring_index, ccd_to_valid_id_array
811
+ )
812
+ planar_double_bond_index = slice_valid_index(
813
+ planar_double_bond_index, ccd_to_valid_id_array
814
+ )
815
+ all_aromatic_5_ring_index.append(aromatic_5_ring_index + start_mol)
816
+ all_aromatic_6_ring_index.append(aromatic_6_ring_index + start_mol)
817
+ all_planar_double_bond_index.append(
818
+ planar_double_bond_index + start_mol
819
+ )
820
+
821
+ if return_physical_metrics:
822
+ if len(all_edge_index) > 0:
823
+ all_edge_index = np.concatenate(all_edge_index, axis=1)
824
+ all_lower_bounds = np.concatenate(all_lower_bounds, axis=0)
825
+ all_upper_bounds = np.concatenate(all_upper_bounds, axis=0)
826
+ all_bond_mask = np.concatenate(all_bond_mask, axis=0)
827
+ all_angle_mask = np.concatenate(all_angle_mask, axis=0)
828
+
829
+ all_chiral_atom_index = np.concatenate(all_chiral_atom_index, axis=1)
830
+ all_chiral_check_mask = np.concatenate(all_chiral_check_mask, axis=0)
831
+ all_chiral_atom_orientations = np.concatenate(
832
+ all_chiral_atom_orientations, axis=0
833
+ )
834
+
835
+ all_stereo_bond_index = np.concatenate(all_stereo_bond_index, axis=1)
836
+ all_stereo_check_mask = np.concatenate(all_stereo_check_mask, axis=0)
837
+ all_stereo_bond_orientations = np.concatenate(
838
+ all_stereo_bond_orientations, axis=0
839
+ )
840
+
841
+ all_aromatic_5_ring_index = np.concatenate(
842
+ all_aromatic_5_ring_index, axis=1
843
+ )
844
+ all_aromatic_6_ring_index = np.concatenate(
845
+ all_aromatic_6_ring_index, axis=1
846
+ )
847
+ all_planar_double_bond_index = np.empty(
848
+ (6, 0), dtype=np.int64
849
+ ) # TODO remove np.concatenate(all_planar_double_bond_index, axis=1)
850
+ else:
851
+ all_edge_index = np.empty((2, 0), dtype=np.int64)
852
+ all_lower_bounds = np.array([], dtype=np.float32)
853
+ all_upper_bounds = np.array([], dtype=np.float32)
854
+ all_bond_mask = np.array([], dtype=bool)
855
+ all_angle_mask = np.array([], dtype=bool)
856
+
857
+ all_chiral_atom_index = np.empty((4, 0), dtype=np.int64)
858
+ all_chiral_check_mask = np.array([], dtype=bool)
859
+ all_chiral_atom_orientations = np.array([], dtype=bool)
860
+
861
+ all_stereo_bond_index = np.empty((4, 0), dtype=np.int64)
862
+ all_stereo_check_mask = np.array([], dtype=bool)
863
+ all_stereo_bond_orientations = np.array([], dtype=bool)
864
+
865
+ all_aromatic_5_ring_index = np.empty((5, 0), dtype=np.int64)
866
+ all_aromatic_6_ring_index = np.empty((6, 0), dtype=np.int64)
867
+ all_planar_double_bond_index = np.empty((6, 0), dtype=np.int64)
868
+
869
+ features = {
870
+ "ligand_symmetries": molecule_symmetries,
871
+ "ligand_edge_index": torch.tensor(all_edge_index).long(),
872
+ "ligand_edge_lower_bounds": torch.tensor(all_lower_bounds),
873
+ "ligand_edge_upper_bounds": torch.tensor(all_upper_bounds),
874
+ "ligand_edge_bond_mask": torch.tensor(all_bond_mask),
875
+ "ligand_edge_angle_mask": torch.tensor(all_angle_mask),
876
+ "ligand_chiral_atom_index": torch.tensor(all_chiral_atom_index).long(),
877
+ "ligand_chiral_check_mask": torch.tensor(all_chiral_check_mask),
878
+ "ligand_chiral_atom_orientations": torch.tensor(
879
+ all_chiral_atom_orientations
880
+ ),
881
+ "ligand_stereo_bond_index": torch.tensor(all_stereo_bond_index).long(),
882
+ "ligand_stereo_check_mask": torch.tensor(all_stereo_check_mask),
883
+ "ligand_stereo_bond_orientations": torch.tensor(
884
+ all_stereo_bond_orientations
885
+ ),
886
+ "ligand_aromatic_5_ring_index": torch.tensor(
887
+ all_aromatic_5_ring_index
888
+ ).long(),
889
+ "ligand_aromatic_6_ring_index": torch.tensor(
890
+ all_aromatic_6_ring_index
891
+ ).long(),
892
+ "ligand_planar_double_bond_index": torch.tensor(
893
+ all_planar_double_bond_index
894
+ ).long(),
895
+ }
896
+ else:
897
+ features = {
898
+ "ligand_symmetries": molecule_symmetries,
899
+ }
900
+ return features