boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,602 @@
1
+ import itertools
2
+ import pickle
3
+ import random
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from boltz.data import const
10
+ from boltz.data.pad import pad_dim
11
+ from boltz.model.loss.confidence import lddt_dist
12
+ from boltz.model.loss.validation import weighted_minimum_rmsd_single
13
+
14
+
15
def convert_atom_name(name: str) -> tuple[int, int, int, int]:
    """Encode an atom name as a tuple of shifted character codes.

    Surrounding whitespace is stripped, every remaining character ``c`` is
    mapped to ``ord(c) - 32``, and the result is right-padded with zeros up to
    four entries. Names longer than four characters are not truncated, so the
    tuple is longer than four entries in that case.

    Parameters
    ----------
    name : str
        The atom name.

    Returns
    -------
    tuple[int, int, int, int]
        The encoded atom name.

    """
    codes = [ord(char) - 32 for char in name.strip()]
    padding = max(0, 4 - len(codes))
    return tuple(codes + [0] * padding)
33
+
34
+
35
def get_symmetries(path: str) -> dict:
    """Load per-molecule symmetry permutations and atom names from a pickle.

    The file is a pickled mapping from a key (presumably a CCD code — verify
    against the file producer) to a molecule object exposing ``GetProp`` and
    ``GetAtoms`` (presumably an RDKit ``Mol``). Each molecule's hex-encoded
    ``"symmetries"`` property is un-pickled, and every atom's ``"name"``
    property is encoded with ``convert_atom_name``. Molecules for which any of
    this fails are silently skipped (best effort).

    Security note: ``pickle`` executes arbitrary code on load, so only point
    this at trusted, locally generated files.

    Parameters
    ----------
    path : str
        The path to the ligand symmetries pickle.

    Returns
    -------
    dict
        Maps each successfully parsed key to ``(symmetries, atom_names)``.

    """
    with Path(path).open("rb") as handle:
        molecules: dict = pickle.load(handle)  # noqa: S301

    parsed = {}
    for mol_key, mol in molecules.items():
        try:
            raw_symmetries = bytes.fromhex(mol.GetProp("symmetries"))
            permutations = pickle.loads(raw_symmetries)  # noqa: S301
            encoded_names = [
                convert_atom_name(atom.GetProp("name")) for atom in mol.GetAtoms()
            ]
        except Exception:  # noqa: BLE001, PERF203
            # Deliberate best effort: molecules that cannot be parsed are dropped.
            continue
        parsed[mol_key] = (permutations, encoded_names)

    return parsed
68
+
69
+
70
def compute_symmetry_idx_dictionary(data):
    """Flatten atom coordinates and record chain/token start offsets.

    Despite its name, this returns a list rather than a dictionary. As a side
    effect it sets ``chain.start_idx`` to the chain's first atom position in
    the flattened list, and ``token.start_idx`` to the token's first atom
    position relative to its chain.

    Parameters
    ----------
    data
        Object exposing ``chains``; each chain has ``tokens``, each token has
        ``atoms`` whose ``coords`` expose ``x``, ``y`` and ``z``.

    Returns
    -------
    list
        ``[x, y, z]`` for every atom, in chain/token/atom order.

    """
    flat_coords = []
    atom_offset = 0
    for chain in data.chains:
        chain.start_idx = atom_offset
        for token in chain.tokens:
            token.start_idx = atom_offset - chain.start_idx
            flat_coords += [
                [atom.coords.x, atom.coords.y, atom.coords.z] for atom in token.atoms
            ]
            atom_offset += len(token.atoms)
    return flat_coords
83
+
84
+
85
def get_current_idx_list(data):
    """Return flattened atom indices of in-crop tokens of in-crop chains.

    Each selected token contributes ``chain.start_idx + token.start_idx + k``
    for ``k`` in ``range(len(token.atoms))``. The ``start_idx`` attributes are
    presumably those set by ``compute_all_coords_mask`` or
    ``compute_symmetry_idx_dictionary`` — confirm with callers.

    Parameters
    ----------
    data
        Object exposing ``chains``; chains and tokens expose ``in_crop`` and
        ``start_idx``, and tokens expose ``atoms``.

    Returns
    -------
    list
        Atom indices in chain/token order.

    """
    selected = []
    for chain in data.chains:
        if not chain.in_crop:
            continue
        selected.extend(
            chain.start_idx + token.start_idx + offset
            for token in chain.tokens
            if token.in_crop
            for offset in range(len(token.atoms))
        )
    return selected
98
+
99
+
100
def all_different_after_swap(l):
    """Return True when every swap in ``l`` targets a distinct chain.

    Each swap is a sequence whose last element is the target chain index
    (``(start1, end1, start2, end2, i, j)`` as built in
    ``get_chain_symmetries``). A combination is valid only if no two chains
    are mapped to the same target.
    """
    targets = [swap[-1] for swap in l]
    return len(set(targets)) == len(targets)
103
+
104
+
105
def minimum_symmetry_coords(
    coords: torch.Tensor,
    feats: dict,
    index_batch: int,
    **args_rmsd,
):
    """Greedily pick the ground-truth symmetry that minimises RMSD to ``coords``.

    Symmetries are resolved in three greedy stages:

    1. Chain swaps (``feats["chain_symmetries"]``): every candidate is scored
       with ``weighted_minimum_rmsd_single`` and the lowest RMSD is kept.
    2. Residue atom swaps (``feats["amino_acids_symmetries"]``): scored with a
       weighted squared error under the stage-1 alignment weights, without
       re-aligning. The alignment is recomputed once after this stage.
    3. Ligand / non-standard atom swaps (``feats["ligand_symmetries"]``):
       each candidate is re-aligned and kept if it lowers the RMSD.

    Parameters
    ----------
    coords : torch.Tensor
        Predicted atom coordinates, indexed as ``coords[:, atom]``
        (presumably shape ``(1, n_atoms, 3)`` — TODO confirm with callers).
    feats : dict
        Batched features; reads ``all_coords``, ``all_resolved_mask``,
        ``crop_to_all_atom_map``, ``chain_symmetries``,
        ``amino_acids_symmetries``, ``ligand_symmetries``, ``atom_to_token``
        and ``mol_type``.
    index_batch : int
        Index of the batch element to process.
    **args_rmsd
        Extra keyword arguments forwarded to ``weighted_minimum_rmsd_single``.

    Returns
    -------
    best_true_coords : torch.Tensor
        Permuted ground-truth coordinates as returned (aligned) by
        ``weighted_minimum_rmsd_single``.
    best_rmsd : float
        RMSD associated with ``best_true_coords``.
    best_true_resolved_mask : torch.Tensor
        Matching resolved-atom mask with a leading dimension of size 1.

    Raises
    ------
    RuntimeError
        If no chain-swap candidate could be scored.

    """
    all_coords = feats["all_coords"][index_batch].unsqueeze(0).to(coords)
    all_resolved_mask = (
        feats["all_resolved_mask"][index_batch].to(coords).to(torch.bool)
    )
    crop_to_all_atom_map = (
        feats["crop_to_all_atom_map"][index_batch].to(coords).to(torch.long)
    )
    chain_symmetries = feats["chain_symmetries"][index_batch]
    amino_acids_symmetries = feats["amino_acids_symmetries"][index_batch]
    ligand_symmetries = feats["ligand_symmetries"][index_batch]
    # Batch-element slices passed to every RMSD computation.
    atom_to_token = feats["atom_to_token"][index_batch : index_batch + 1]
    mol_type = feats["mol_type"][index_batch : index_batch + 1]

    # Stage 1: check best symmetry on chain swap.
    best_true_coords = None
    best_true_resolved_mask = None
    best_align_weights = None
    best_rmsd = float("inf")
    for swap in chain_symmetries:
        true_all_coords = all_coords.clone()
        true_all_resolved_mask = all_resolved_mask.clone()
        for start1, end1, start2, end2, _, _ in swap:
            true_all_coords[:, start1:end1] = all_coords[:, start2:end2]
            true_all_resolved_mask[start1:end1] = all_resolved_mask[start2:end2]
        true_coords = true_all_coords[:, crop_to_all_atom_map]
        true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map]
        # Pad the cropped ground truth up to the (possibly padded) prediction.
        true_coords = pad_dim(true_coords, 1, coords.shape[1] - true_coords.shape[1])
        true_resolved_mask = pad_dim(
            true_resolved_mask,
            0,
            coords.shape[1] - true_resolved_mask.shape[0],
        )
        try:
            rmsd, aligned_coords, align_weights = weighted_minimum_rmsd_single(
                coords,
                true_coords,
                atom_mask=true_resolved_mask,
                atom_to_token=atom_to_token,
                mol_type=mol_type,
                **args_rmsd,
            )
        except Exception as e:  # noqa: BLE001
            # Best effort: skip this candidate, but never swallow interrupts.
            print(f"Warning: error in rmsd computation inside symmetry code: {e}")
            continue
        rmsd = rmsd.item()

        if rmsd < best_rmsd:
            best_rmsd = rmsd
            best_true_coords = aligned_coords
            best_align_weights = align_weights
            best_true_resolved_mask = true_resolved_mask

    if best_true_coords is None:
        msg = "minimum_symmetry_coords: no chain symmetry candidate could be scored"
        raise RuntimeError(msg)

    # Stage 2: atom symmetries (nucleic acid and protein residues), resolved
    # greedily without recomputing the alignment.
    true_coords = best_true_coords.clone()
    true_resolved_mask = best_true_resolved_mask.clone()
    for symmetric_amino in amino_acids_symmetries:
        for swap in symmetric_amino:
            # Starting from the greedy best, try to swap the atoms.
            new_true_coords = true_coords.clone()
            new_true_resolved_mask = true_resolved_mask.clone()
            for i, j in swap:
                new_true_coords[:, i] = true_coords[:, j]
                new_true_resolved_mask[i] = true_resolved_mask[j]

            # Weighted squared distance; the alignment is not recomputed for efficiency.
            best_mse_loss = torch.sum(
                ((coords - best_true_coords) ** 2).sum(dim=-1)
                * best_align_weights
                * best_true_resolved_mask,
                dim=-1,
            ) / torch.sum(best_align_weights * best_true_resolved_mask, dim=-1)
            new_mse_loss = torch.sum(
                ((coords - new_true_coords) ** 2).sum(dim=-1)
                * best_align_weights
                * new_true_resolved_mask,
                dim=-1,
            ) / torch.sum(best_align_weights * new_true_resolved_mask, dim=-1)

            if best_mse_loss > new_mse_loss:
                best_true_coords = new_true_coords
                best_true_resolved_mask = new_true_resolved_mask

        # Greedily update the best coordinates after each residue.
        true_coords = best_true_coords.clone()
        true_resolved_mask = best_true_resolved_mask.clone()

    # Recompute the alignment once for the residue-corrected ground truth.
    rmsd, true_coords, best_align_weights = weighted_minimum_rmsd_single(
        coords,
        true_coords,
        atom_mask=true_resolved_mask,
        atom_to_token=atom_to_token,
        mol_type=mol_type,
        **args_rmsd,
    )
    best_rmsd = rmsd.item()
    # Keep the re-aligned coordinates as the current best so that the returned
    # coordinates always correspond to ``best_rmsd`` (and so stage 3 does not
    # reset to the stale, pre-realignment coordinates).
    best_true_coords = true_coords

    # Stage 3: atom symmetries (ligand and non-standard), resolved greedily
    # while recomputing the alignment.
    for symmetric_ligand in ligand_symmetries:
        for swap in symmetric_ligand:
            new_true_coords = true_coords.clone()
            new_true_resolved_mask = true_resolved_mask.clone()
            for i, j in swap:
                new_true_coords[:, j] = true_coords[:, i]
                new_true_resolved_mask[j] = true_resolved_mask[i]
            # TODO: if this is too slow, consider not recomputing the alignment.
            rmsd, aligned_coords, _ = weighted_minimum_rmsd_single(
                coords,
                new_true_coords,
                atom_mask=new_true_resolved_mask,
                atom_to_token=atom_to_token,
                mol_type=mol_type,
                **args_rmsd,
            )
            rmsd = rmsd.item()
            if rmsd < best_rmsd:
                best_true_coords = aligned_coords
                best_rmsd = rmsd
                best_true_resolved_mask = new_true_resolved_mask

        true_coords = best_true_coords.clone()
        true_resolved_mask = best_true_resolved_mask.clone()

    return best_true_coords, best_rmsd, best_true_resolved_mask.unsqueeze(0)
237
+
238
+
239
def minimum_lddt_symmetry_coords(
    coords: torch.Tensor,
    feats: dict,
    index_batch: int,
    **args_rmsd,
):
    """Greedily pick the ground-truth symmetry that maximises lDDT to ``coords``.

    1. Every chain swap in ``feats["chain_symmetries"]`` is scored with a
       global lDDT (15 A cutoff) against the predicted distance matrix. The
       best one is kept.
    2. Residue and ligand atom swaps (``amino_acids_symmetries`` +
       ``ligand_symmetries``) are scored greedily with an lDDT restricted to
       pairs involving the swapped atoms.
    3. The result is padded to ``coords.shape[1]`` and aligned once with
       ``weighted_minimum_rmsd_single``.

    Parameters
    ----------
    coords : torch.Tensor
        Predicted atom coordinates, indexed as ``coords[:, atom]``
        (presumably shape ``(1, n_atoms, 3)`` — TODO confirm with callers).
    feats : dict
        Batched features; reads ``all_coords``, ``all_resolved_mask``,
        ``crop_to_all_atom_map``, ``chain_symmetries``,
        ``amino_acids_symmetries``, ``ligand_symmetries``, ``atom_to_token``
        and ``mol_type``.
    index_batch : int
        Index of the batch element to process.
    **args_rmsd
        Extra keyword arguments forwarded to ``weighted_minimum_rmsd_single``.

    Returns
    -------
    true_coords : torch.Tensor
        Permuted, padded and (if the RMSD computation succeeded) aligned
        ground-truth coordinates.
    best_rmsd : float
        RMSD after alignment, or ``1000`` if the RMSD computation failed.
    true_resolved_mask : torch.Tensor
        Matching padded resolved mask with a leading dimension of size 1.

    Raises
    ------
    RuntimeError
        If no chain-swap candidate could be scored.

    """
    all_coords = feats["all_coords"][index_batch].unsqueeze(0).to(coords)
    all_resolved_mask = (
        feats["all_resolved_mask"][index_batch].to(coords).to(torch.bool)
    )
    crop_to_all_atom_map = (
        feats["crop_to_all_atom_map"][index_batch].to(coords).to(torch.long)
    )
    chain_symmetries = feats["chain_symmetries"][index_batch]
    amino_acids_symmetries = feats["amino_acids_symmetries"][index_batch]
    ligand_symmetries = feats["ligand_symmetries"][index_batch]

    # Predicted coordinates of the cropped atoms only (coords may be padded).
    pred_coords = coords[:, : len(crop_to_all_atom_map)]
    dmat_predicted = torch.cdist(pred_coords, pred_coords)

    # Stage 1: check best symmetry on chain swap.
    best_true_coords = None
    best_true_resolved_mask = None
    # -inf (not 0) so that a candidate with lDDT == 0 is still selected.
    best_lddt = float("-inf")
    for swap in chain_symmetries:
        true_all_coords = all_coords.clone()
        true_all_resolved_mask = all_resolved_mask.clone()
        for start1, end1, start2, end2, _, _ in swap:
            true_all_coords[:, start1:end1] = all_coords[:, start2:end2]
            true_all_resolved_mask[start1:end1] = all_resolved_mask[start2:end2]
        true_coords = true_all_coords[:, crop_to_all_atom_map]
        true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map]
        dmat_true = torch.cdist(true_coords, true_coords)
        # Resolved pairs, excluding self-pairs.
        pair_mask = (
            true_resolved_mask[:, None]
            * true_resolved_mask[None, :]
            * (1 - torch.eye(len(true_resolved_mask))).to(true_resolved_mask)
        )

        lddt = lddt_dist(
            dmat_predicted, dmat_true, pair_mask, cutoff=15.0, per_atom=False
        )[0]
        lddt = lddt.item()

        if lddt > best_lddt:
            best_lddt = lddt
            best_true_coords = true_coords
            best_true_resolved_mask = true_resolved_mask

    if best_true_coords is None:
        msg = (
            "minimum_lddt_symmetry_coords: no chain symmetry candidate "
            "could be scored"
        )
        raise RuntimeError(msg)

    # Stage 2: atom symmetries (residues and ligands), resolved greedily
    # without recomputing the alignment.
    true_coords = best_true_coords.clone()
    true_resolved_mask = best_true_resolved_mask.clone()
    for symmetric_amino_or_lig in amino_acids_symmetries + ligand_symmetries:
        for swap in symmetric_amino_or_lig:
            # Starting from the greedy best, try to swap the atoms.
            new_true_coords = true_coords.clone()
            new_true_resolved_mask = true_resolved_mask.clone()
            indices = []
            for i, j in swap:
                new_true_coords[:, i] = true_coords[:, j]
                new_true_resolved_mask[i] = true_resolved_mask[j]
                indices.append(i)

            indices = (
                torch.from_numpy(np.asarray(indices)).to(new_true_coords.device).long()
            )

            # Distances from every atom to the swapped atoms only.
            sub_dmat_pred = torch.cdist(pred_coords, pred_coords[:, indices])
            sub_dmat_true = torch.cdist(true_coords, true_coords[:, indices])
            sub_dmat_new_true = torch.cdist(
                new_true_coords, new_true_coords[:, indices]
            )

            sub_true_pair_lddt = (
                true_resolved_mask[:, None] * true_resolved_mask[None, indices]
            )
            # Exclude self-pairs between the swapped atoms.
            off_diagonal = (1 - torch.eye(len(indices))).to(sub_true_pair_lddt).bool()
            sub_true_pair_lddt[indices] = sub_true_pair_lddt[indices] * off_diagonal

            sub_new_true_pair_lddt = (
                new_true_resolved_mask[:, None] * new_true_resolved_mask[None, indices]
            )
            sub_new_true_pair_lddt[indices] = (
                sub_new_true_pair_lddt[indices] * off_diagonal
            )

            lddt = lddt_dist(
                sub_dmat_pred,
                sub_dmat_true,
                sub_true_pair_lddt,
                cutoff=15.0,
                per_atom=False,
            )[0]
            new_lddt = lddt_dist(
                sub_dmat_pred,
                sub_dmat_new_true,
                sub_new_true_pair_lddt,
                cutoff=15.0,
                per_atom=False,
            )[0]

            if new_lddt > lddt:
                best_true_coords = new_true_coords
                best_true_resolved_mask = new_true_resolved_mask

        # Greedily update the best coordinates after each residue / ligand.
        true_coords = best_true_coords.clone()
        true_resolved_mask = best_true_resolved_mask.clone()

    # Stage 3: pad to the prediction size and recompute the alignment.
    true_coords = pad_dim(true_coords, 1, coords.shape[1] - true_coords.shape[1])
    true_resolved_mask = pad_dim(
        true_resolved_mask,
        0,
        coords.shape[1] - true_resolved_mask.shape[0],
    )

    try:
        rmsd, true_coords, _ = weighted_minimum_rmsd_single(
            coords,
            true_coords,
            atom_mask=true_resolved_mask,
            atom_to_token=feats["atom_to_token"][index_batch : index_batch + 1],
            mol_type=feats["mol_type"][index_batch : index_batch + 1],
            **args_rmsd,
        )
        best_rmsd = rmsd.item()
    except Exception as e:  # noqa: BLE001
        # Best effort: fall back to a large sentinel RMSD and unaligned coords.
        print("Failed proper RMSD computation, returning 1000. Error: ", e)
        best_rmsd = 1000

    return true_coords, best_rmsd, true_resolved_mask.unsqueeze(0)
378
+
379
+
380
def compute_all_coords_mask(structure):
    """Flatten atom coordinates and per-atom crop / resolved flags.

    As a side effect, sets ``chain.start_idx`` to the chain's first atom
    position in the flattened lists, and ``token.start_idx`` to the token's
    first atom position relative to its chain.

    Parameters
    ----------
    structure
        Object exposing ``chains``; each chain has ``tokens``, and each token
        has ``atoms`` (whose ``coords`` expose ``x``, ``y``, ``z``),
        ``in_crop`` and ``is_present``.

    Returns
    -------
    tuple[list, list, list]
        ``(all_coords, all_coords_crop_mask, all_resolved_mask)``, with one
        entry per atom in chain/token/atom order. The two masks repeat the
        owning token's ``in_crop`` / ``is_present`` value for each atom.

    """
    total_count = 0
    all_coords = []
    all_coords_crop_mask = []
    all_resolved_mask = []
    for chain in structure.chains:
        chain.start_idx = total_count
        for token in chain.tokens:
            token.start_idx = total_count - chain.start_idx
            n_atoms = len(token.atoms)
            all_coords.extend(
                [atom.coords.x, atom.coords.y, atom.coords.z] for atom in token.atoms
            )
            all_coords_crop_mask.extend([token.in_crop] * n_atoms)
            all_resolved_mask.extend([token.is_present] * n_atoms)
            total_count += n_atoms
    return all_coords, all_coords_crop_mask, all_resolved_mask
403
+
404
+
405
def get_chain_symmetries(cropped, max_n_symmetries=100):
    """Collect whole-structure atom data and candidate chain swaps.

    Parameters
    ----------
    cropped
        Cropped example exposing ``structure`` (with ``chains`` and ``atoms``
        record arrays) and ``tokens``.
    max_n_symmetries : int, optional
        Maximum number of chain-swap combinations returned. At most
        ``10 * max_n_symmetries`` raw combinations are enumerated (to avoid a
        combinatorial explosion). A random subset is drawn if more than
        ``max_n_symmetries`` survive filtering.

    Returns
    -------
    dict
        ``all_coords`` / ``all_resolved_mask``: per-chain atom coordinates and
        ``is_present`` flags concatenated over all chains, as ``torch.Tensor``
        (default float dtype). ``crop_to_all_atom_map``: for every cropped
        atom, its row in ``all_coords`` (also a float tensor).
        ``chain_symmetries``: list of combinations. Each combination holds one
        ``(start1, end1, start2, end2, i, j)`` swap per in-crop chain, or the
        list is ``[[]]`` when none survive.

    """
    structure = cropped.structure

    # Asym ids with at least one token in the crop (set for O(1) membership,
    # instead of rescanning every token for every chain).
    crop_asym_ids = {token["asym_id"] for token in cropped.tokens}

    all_coords = []
    all_resolved_mask = []
    original_atom_idx = []
    chain_atom_idx = []
    chain_atom_num = []
    chain_in_crop = []
    # asym_id -> chain position; setdefault keeps the first match like list.index.
    asym_id_to_chain = {}
    new_atom_idx = 0

    for chain_pos, chain in enumerate(structure.chains):
        atom_idx, atom_num = chain["atom_idx"], chain["atom_num"]
        atom_slice = slice(atom_idx, atom_idx + atom_num)

        all_coords.append(structure.atoms["coords"][atom_slice])
        all_resolved_mask.append(structure.atoms["is_present"][atom_slice])
        original_atom_idx.append(atom_idx)
        chain_atom_idx.append(new_atom_idx)
        chain_atom_num.append(atom_num)
        chain_in_crop.append(chain["asym_id"] in crop_asym_ids)
        asym_id_to_chain.setdefault(chain["asym_id"], chain_pos)

        new_atom_idx += atom_num

    # Back-map every cropped token's atoms to rows of the concatenated arrays.
    crop_to_all_atom_map = []
    for token in cropped.tokens:
        chain_pos = asym_id_to_chain[token["asym_id"]]
        start = (
            chain_atom_idx[chain_pos] - original_atom_idx[chain_pos] + token["atom_idx"]
        )
        crop_to_all_atom_map.append(np.arange(start, start + token["atom_num"]))

    # Candidate swaps: each in-crop chain may take the atoms of any chain of the
    # same entity with the same atom count (including itself).
    swaps = []
    for i, chain in enumerate(structure.chains):
        if not chain_in_crop[i]:
            continue
        start = chain_atom_idx[i]
        end = start + chain_atom_num[i]
        possible_swaps = []
        for j, chain2 in enumerate(structure.chains):
            start2 = chain_atom_idx[j]
            end2 = start2 + chain_atom_num[j]
            if (
                chain["entity_id"] == chain2["entity_id"]
                and end - start == end2 - start2
            ):
                possible_swaps.append((start, end, start2, end2, i, j))
        swaps.append(possible_swaps)

    combinations = itertools.product(*swaps)
    # To avoid combinatorial explosion, bound the number of combinations considered.
    combinations = list(itertools.islice(combinations, max_n_symmetries * 10))
    # Keep only combinations where every chain gets a different assignment.
    combinations = [c for c in combinations if all_different_after_swap(c)]

    if len(combinations) > max_n_symmetries:
        combinations = random.sample(combinations, max_n_symmetries)

    if len(combinations) == 0:
        combinations.append([])

    return {
        "all_coords": torch.Tensor(np.concatenate(all_coords, axis=0)),
        "all_resolved_mask": torch.Tensor(np.concatenate(all_resolved_mask, axis=0)),
        "crop_to_all_atom_map": torch.Tensor(
            np.concatenate(crop_to_all_atom_map, axis=0)
        ),
        "chain_symmetries": combinations,
    }
500
+
501
+
502
def get_amino_acids_symmetries(cropped):
    """Compute standard residue atom-swap symmetries for the cropped tokens.

    For each token whose residue type has entries in
    ``const.ref_symmetries``, every reference swap ``(i, j)`` is shifted by
    the token's first atom position in the crop. That position is the running
    sum of ``atom_num`` over the preceding tokens.

    Parameters
    ----------
    cropped
        Cropped example exposing ``tokens`` with ``res_type`` and ``atom_num``.

    Returns
    -------
    dict
        ``{"amino_acids_symmetries": swaps}``, where ``swaps`` has one entry
        per token with symmetries, each a list of shifted swap lists.

    """
    residue_swaps_per_token = []
    atom_offset = 0
    for token in cropped.tokens:
        residue_name = const.tokens[token["res_type"]]
        reference = const.ref_symmetries.get(residue_name, [])
        if len(reference) != 0:
            residue_swaps_per_token.append(
                [
                    [(a + atom_offset, b + atom_offset) for a, b in permutation]
                    for permutation in reference
                ]
            )
        atom_offset += token["atom_num"]

    return {"amino_acids_symmetries": residue_swaps_per_token}
520
+
521
+
522
def get_ligand_symmetries(cropped, symmetries):
    """Compute ligand and non-standard residue atom-swap symmetries.

    Cropped tokens are grouped into molecules by ``(asym_id, res_idx)``.
    Molecules whose name is not in ``const.ref_symmetries`` are looked up in
    ``symmetries`` (as produced by ``get_symmetries``). Each reference
    permutation is remapped from CCD atom order to the order of the atoms in
    the structure, and shifted by the molecule's first atom position in the
    crop. Only atoms that actually move are kept.

    Molecules with an atom name missing from the CCD reference are skipped
    (no symmetry correction for them), matching the best-effort behaviour of
    ``get_symmetries``.

    Parameters
    ----------
    cropped
        Cropped example exposing ``structure`` (``chains``, ``residues``,
        ``atoms`` record arrays) and ``tokens``.
    symmetries : dict
        Maps molecule name to ``(permutations, ccd_atom_names)``.

    Returns
    -------
    dict
        ``{"ligand_symmetries": molecule_symmetries}``: one entry per molecule
        with at least one non-trivial swap, each a list of ``(i, j)`` swap lists.

    Raises
    ------
    ValueError
        If a remapped permutation's length differs from the number of the
        molecule's atoms present in the crop.

    """
    structure = cropped.structure

    # Number of atoms of each (asym_id, res_idx) molecule inside the crop.
    added_molecules = {}
    # (name, crop offset of first atom, mol_id, structure atom names) for
    # molecules without a predefined residue symmetry.
    index_mols = []
    atom_count = 0
    for token in cropped.tokens:
        atom_count += token["atom_num"]
        mol_id = (token["asym_id"], token["res_idx"])
        if mol_id in added_molecules:
            added_molecules[mol_id] += token["atom_num"]
            continue
        added_molecules[mol_id] = token["atom_num"]

        # Get the molecule type and atom names.
        residue_idx = token["res_idx"] + structure.chains[token["asym_id"]]["res_idx"]
        residue = structure.residues[residue_idx]
        mol_name = residue["name"]
        atom_idx = residue["atom_idx"]
        mol_atom_names = [
            tuple(m)
            for m in structure.atoms[atom_idx : atom_idx + residue["atom_num"]]["name"]
        ]
        if mol_name not in const.ref_symmetries:
            index_mols.append(
                (mol_name, atom_count - token["atom_num"], mol_id, mol_atom_names)
            )

    molecule_symmetries = []
    for mol_name, start_mol, mol_id, mol_atom_names in index_mols:
        if mol_name not in symmetries:
            continue
        syms_ccd, mol_atom_names_ccd = symmetries[mol_name]

        # CCD atom name -> CCD index (first occurrence wins, as with list.index),
        # avoiding an O(n^2) list.index scan per structure atom.
        ccd_name_to_idx = {}
        for ccd_idx, ccd_name in enumerate(mol_atom_names_ccd):
            ccd_name_to_idx.setdefault(ccd_name, ccd_idx)
        if any(name not in ccd_name_to_idx for name in mol_atom_names):
            # Some structure atoms have no CCD counterpart, so the CCD
            # permutations cannot be mapped; skip instead of crashing.
            continue

        ccd_to_valid_ids = {
            ccd_name_to_idx[name]: i for i, name in enumerate(mol_atom_names)
        }
        ccd_valid_ids = set(ccd_to_valid_ids)

        # Remap every CCD permutation onto structure atom order. A permutation
        # is dropped if it sends a present atom to an absent one.
        syms = []
        for sym_ccd in syms_ccd:
            sym_dict = {}
            for i, j in enumerate(sym_ccd):
                if i not in ccd_valid_ids:
                    continue
                if j not in ccd_valid_ids:
                    break
                sym_dict[ccd_to_valid_ids[i]] = ccd_to_valid_ids[j]
            else:
                syms.append([sym_dict[k] for k in range(len(ccd_valid_ids))])

        swaps = []
        for sym in syms:
            if len(sym) != added_molecules[mol_id]:
                msg = f"Symmetry length mismatch {len(sym)} {added_molecules[mol_id]}"
                raise ValueError(msg)
            # Keep only atoms that move, shifted to crop positions.
            sym_new_idx = [
                (i + start_mol, int(j) + start_mol)
                for i, j in enumerate(sym)
                if i != int(j)
            ]
            if sym_new_idx:
                swaps.append(sym_new_idx)
        if swaps:
            molecule_symmetries.append(swaps)

    return {"ligand_symmetries": molecule_symmetries}
File without changes
File without changes
@@ -0,0 +1,76 @@
1
+ from datetime import datetime
2
+ from typing import Literal
3
+
4
+ from boltz.data.types import Record
5
+ from boltz.data.filter.dynamic.filter import DynamicFilter
6
+
7
+
8
class DateFilter(DynamicFilter):
    """A filter that filters complexes based on their date.

    The date can be the deposition, release, or revision date.
    If the requested date is not available, the next earlier one is used
    (revised -> released -> deposited).

    If no date is available, the complex is rejected.

    """

    def __init__(
        self,
        date: str,
        ref: Literal["deposited", "revised", "released"],
    ) -> None:
        """Initialize the filter.

        Parameters
        ----------
        date : str
            The maximum date (ISO format string) of PDB entries to keep.
        ref : Literal["deposited", "revised", "released"]
            The reference date to use.

        Raises
        ------
        ValueError
            If ``ref`` is not a supported reference, or if ``date`` is not a
            valid ISO format string.

        """
        if ref not in ("deposited", "revised", "released"):
            msg = "Invalid reference date. Must be deposited, revised, or released"
            raise ValueError(msg)

        self.filter_date = datetime.fromisoformat(date)
        self.ref = ref

    def filter(self, record: Record) -> bool:
        """Filter a record based on its date.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            True if the record's reference date is on or before the filter
            date (record kept). False if it is later or no date is available.

        """
        structure = record.structure

        # Fall back to earlier dates when the requested one is missing/empty.
        if self.ref == "deposited":
            date = structure.deposited
        elif self.ref == "released":
            date = structure.released or structure.deposited
        elif self.ref == "revised":
            date = structure.revised or structure.released or structure.deposited
        else:
            msg = f"Invalid reference date: {self.ref}"
            raise ValueError(msg)

        if date is None or date == "":
            return False

        date = datetime.fromisoformat(date)
        return date <= self.filter_date
@@ -0,0 +1,24 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ from boltz.data.types import Record
4
+
5
+
6
class DynamicFilter(ABC):
    """Abstract interface for filters applied to individual data records."""

    @abstractmethod
    def filter(self, record: Record) -> bool:
        """Decide whether ``record`` is kept.

        Parameters
        ----------
        record : Record
            The record being considered.

        Returns
        -------
        bool
            ``True`` if the record passes the filter, ``False`` otherwise.

        Raises
        ------
        NotImplementedError
            Always; concrete subclasses must override this method.

        """
        raise NotImplementedError()