boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1851 @@
1
+ from collections.abc import Mapping
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import click
7
+ import numpy as np
8
+ from Bio import Align
9
+ from chembl_structure_pipeline.exclude_flag import exclude_flag
10
+ from chembl_structure_pipeline.standardizer import standardize_mol
11
+ from rdkit import Chem, rdBase
12
+ from rdkit.Chem import AllChem, HybridizationType
13
+ from rdkit.Chem.MolStandardize import rdMolStandardize
14
+ from rdkit.Chem.rdchem import BondStereo, Conformer, Mol
15
+ from rdkit.Chem.rdDistGeom import GetMoleculeBoundsMatrix
16
+ from rdkit.Chem.rdMolDescriptors import CalcNumHeavyAtoms
17
+ from scipy.optimize import linear_sum_assignment
18
+
19
+ from boltz.data import const
20
+ from boltz.data.mol import load_molecules
21
+ from boltz.data.parse.mmcif import parse_mmcif
22
+ from boltz.data.types import (
23
+ AffinityInfo,
24
+ Atom,
25
+ AtomV2,
26
+ Bond,
27
+ BondV2,
28
+ Chain,
29
+ ChainInfo,
30
+ ChiralAtomConstraint,
31
+ Connection,
32
+ Coords,
33
+ Ensemble,
34
+ InferenceOptions,
35
+ Interface,
36
+ PlanarBondConstraint,
37
+ PlanarRing5Constraint,
38
+ PlanarRing6Constraint,
39
+ RDKitBoundsConstraint,
40
+ Record,
41
+ Residue,
42
+ ResidueConstraints,
43
+ StereoBondConstraint,
44
+ Structure,
45
+ StructureInfo,
46
+ StructureV2,
47
+ Target,
48
+ TemplateInfo,
49
+ )
50
+
51
+ ####################################################################################################
52
+ # DATACLASSES
53
+ ####################################################################################################
54
+
55
+
56
@dataclass(eq=True, frozen=True)
class ParsedAtom:
    """Immutable record describing one atom produced by the parsers."""

    name: str  # atom name as read from the reference molecule
    element: int  # atomic number
    charge: int  # formal charge
    coords: tuple[float, float, float]  # structure coordinates (parsers set (0, 0, 0))
    conformer: tuple[float, float, float]  # reference-conformer coordinates
    is_present: bool  # whether the atom is considered resolved
    chirality: int  # id from const.chirality_type_ids
67
+
68
+
69
@dataclass(eq=True, frozen=True)
class ParsedBond:
    """Immutable record of a bond between two parsed atoms."""

    atom_1: int  # index of the first atom within its residue
    atom_2: int  # index of the second atom within its residue
    type: int  # id from const.bond_type_ids
76
+
77
+
78
@dataclass(eq=True, frozen=True)
class ParsedRDKitBoundsConstraint:
    """Immutable distance-bounds constraint for an atom pair, taken from RDKit."""

    atom_idxs: tuple[int, int]  # the constrained atom pair
    is_bond: bool  # pair is directly bonded
    is_angle: bool  # pair are the two ends of a bond angle
    upper_bound: float  # upper distance bound
    lower_bound: float  # lower distance bound
87
+
88
+
89
@dataclass(eq=True, frozen=True)
class ParsedChiralAtomConstraint:
    """Immutable tetrahedral-chirality constraint (three neighbours + centre)."""

    atom_idxs: tuple[int, int, int, int]  # neighbour, neighbour, neighbour, centre
    is_reference: bool  # True for the primary (CIP-ordered) constraint
    is_r: bool  # True when the centre's CIP label is "R"
96
+
97
+
98
@dataclass(eq=True, frozen=True)
class ParsedStereoBondConstraint:
    """Immutable E/Z constraint over a substituent-bond-substituent quadruple."""

    atom_idxs: tuple[int, int, int, int]  # substituent, bond begin, bond end, substituent
    is_check: bool  # True for the top-ranked substituent pair
    is_e: bool  # True for an E double bond, False for Z
105
+
106
+
107
@dataclass(eq=True, frozen=True)
class ParsedPlanarBondConstraint:
    """Immutable planarity constraint for a double bond and its four substituents."""

    # Six atom indices: the two sp2 carbons of the double bond plus the two
    # substituents on each of them.
    atom_idxs: tuple[int, int, int, int, int, int]
112
+
113
+
114
@dataclass(frozen=True)
class ParsedPlanarRing5Constraint:
    """A parsed planarity constraint for a five-membered aromatic ring."""

    # Indices of the five ring atoms, in ring order.
    atom_idxs: tuple[int, int, int, int, int]
119
+
120
+
121
@dataclass(frozen=True)
class ParsedPlanarRing6Constraint:
    """A parsed planarity constraint for a six-membered aromatic ring."""

    # Indices of the six ring atoms, in ring order.
    atom_idxs: tuple[int, int, int, int, int, int]
126
+
127
+
128
@dataclass(eq=True, frozen=True)
class ParsedResidue:
    """Immutable residue record: its atoms, bonds and optional geometry constraints."""

    name: str  # residue / CCD component name
    type: int  # token id
    idx: int  # residue index within the chain
    atoms: list[ParsedAtom]
    bonds: list[ParsedBond]
    orig_idx: Optional[int]  # original index, None when not applicable
    atom_center: int  # index of the residue's centre atom
    atom_disto: int  # index of the atom used for distogram distances
    is_standard: bool  # True for residues in const.tokens
    is_present: bool
    # Ligand-style constraints; left as None when not computed.
    rdkit_bounds_constraints: Optional[list[ParsedRDKitBoundsConstraint]] = None
    chiral_atom_constraints: Optional[list[ParsedChiralAtomConstraint]] = None
    stereo_bond_constraints: Optional[list[ParsedStereoBondConstraint]] = None
    planar_bond_constraints: Optional[list[ParsedPlanarBondConstraint]] = None
    planar_ring_5_constraints: Optional[list[ParsedPlanarRing5Constraint]] = None
    planar_ring_6_constraints: Optional[list[ParsedPlanarRing6Constraint]] = None
148
+
149
+
150
@dataclass(eq=True, frozen=True)
class ParsedChain:
    """Immutable chain record: entity, type, residues and optional metadata."""

    entity: str  # entity id
    type: int  # chain type id
    residues: list[ParsedResidue]
    cyclic_period: int  # 0 for linear chains; the residue count when cyclic
    sequence: Optional[str] = None  # raw input sequence, if any
    affinity: Optional[bool] = False
    # presumably a molecular weight used for affinity — TODO confirm
    affinity_mw: Optional[float] = None
161
+
162
+
163
@dataclass(eq=True, frozen=True)
class Alignment:
    """Immutable pair of matching query / template spans from a local alignment."""

    # Span bounds come from Biopython alignment coordinates (presumably
    # 0-based with an exclusive end — confirm against Biopython docs).
    query_st: int
    query_en: int
    template_st: int
    template_en: int
171
+
172
+
173
+ ####################################################################################################
174
+ # HELPERS
175
+ ####################################################################################################
176
+
177
+
178
def convert_atom_name(name: str) -> tuple[int, int, int, int]:
    """Encode an atom name as a tuple of small integer character codes.

    Surrounding whitespace is stripped, each remaining character ``c`` is
    mapped to ``ord(c) - 32`` (so a space becomes 0), and the result is
    right-padded with zeros up to length 4. Names longer than four characters
    are not truncated, so the tuple is then longer than four.

    Parameters
    ----------
    name : str
        The atom name.

    Returns
    -------
    tuple[int, int, int, int]
        The encoded atom name.

    """
    codes = tuple(ord(char) - 32 for char in name.strip())
    # Multiplying by a non-positive count yields (), so no padding is added
    # once four or more characters are present.
    return codes + (0,) * (4 - len(codes))
196
+
197
+
198
def compute_3d_conformer(mol: Mol, version: str = "v3") -> bool:
    """Embed a 3D conformer with RDKit ETKDG and relax it with UFF.

    Adapted from `pdbeccdutils.core.component.Component`.

    Existing conformers on ``mol`` are kept (``clearConfs = False``). If the
    first embedding attempt fails, a warning is printed and embedding is
    retried from random coordinates. On success the new conformer is tagged
    ``name="Computed"`` and ``coord_generation="ETKDG<version>"``, where
    ``<version>`` is the ETKDG version actually used.

    Parameters
    ----------
    mol: Mol
        The RDKit molecule to process; modified in place.
    version: str, optional
        ``"v3"`` selects ETKDGv3; any other value selects ETKDGv2.
        Defaults to ``"v3"``.

    Returns
    -------
    bool
        True if a conformer was embedded (even if the UFF relaxation failed),
        False otherwise.

    """
    # Unknown versions fall back to v2; track the effective version so the
    # warning text and the conformer metadata describe what really ran.
    etkdg_version = "v3" if version == "v3" else "v2"
    options = AllChem.ETKDGv3() if etkdg_version == "v3" else AllChem.ETKDGv2()

    options.clearConfs = False
    conf_id = -1

    try:
        conf_id = AllChem.EmbedMolecule(mol, options)

        if conf_id == -1:
            print(
                f"WARNING: RDKit ETKDG{etkdg_version} failed to generate a conformer for molecule "
                f"{Chem.MolToSmiles(AllChem.RemoveHs(mol))}, so the program will start with random coordinates. "
                f"Note that the performance of the model under this behaviour was not tested."
            )
            options.useRandomCoords = True
            conf_id = AllChem.EmbedMolecule(mol, options)

        # Only relax a conformer that was actually embedded. With confId=-1,
        # UFF would act on the molecule's default conformer, which (because
        # clearConfs is False) may be a pre-existing one such as ideal
        # reference coordinates.
        if conf_id != -1:
            AllChem.UFFOptimizeMolecule(mol, confId=conf_id, maxIters=1000)

    except RuntimeError:
        pass  # Force field issue here; success is decided by conf_id below.
    except ValueError:
        pass  # Sanitization issue here; success is decided by conf_id below.

    if conf_id == -1:
        return False

    conformer = mol.GetConformer(conf_id)
    conformer.SetProp("name", "Computed")
    conformer.SetProp("coord_generation", f"ETKDG{etkdg_version}")
    return True
253
+
254
+
255
def get_conformer(mol: Mol) -> Conformer:
    """Pick the conformer to use as the reference geometry of ``mol``.

    Inspired by `pdbeccdutils.core.component.Component`.

    Preference order:
    1. the first conformer whose ``name`` property is ``"Computed"``;
    2. otherwise the first whose ``name`` is ``"Ideal"``;
    3. otherwise the first conformer of the molecule (boltz2 format).
    Conformers without a ``name`` property are skipped in steps 1 and 2.

    Parameters
    ----------
    mol: Mol
        The molecule to process.

    Returns
    -------
    Conformer
        The selected conformer.

    Raises
    ------
    ValueError
        If the molecule has no conformers at all.

    """
    for wanted_name in ("Computed", "Ideal"):
        for candidate in mol.GetConformers():
            try:
                if candidate.GetProp("name") == wanted_name:
                    return candidate
            except KeyError:  # noqa: PERF203
                continue

    # Fall back to the first conformer, whatever its tags.
    first_ids = [int(candidate.GetId()) for candidate in mol.GetConformers()]
    if first_ids:
        return mol.GetConformer(first_ids[0])

    raise ValueError("Conformer does not exist.")
301
+
302
+
303
def compute_geometry_constraints(mol: Mol, idx_map):
    """Build pairwise distance-bound constraints from RDKit's bounds matrix.

    For every atom pair ``(i, j)`` with ``i < j`` where both atoms are keys of
    ``idx_map``, one ``ParsedRDKitBoundsConstraint`` is emitted. Its upper
    bound is ``bounds[i, j]`` and its lower bound is ``bounds[j, i]``. It also
    records whether the pair is directly bonded (``is_bond``) and whether it
    forms the two ends of a bond angle (``is_angle``).

    Side effect: updates ``mol``'s property cache and ring information.

    Parameters
    ----------
    mol : Mol
        The molecule, possibly including hydrogens.
    idx_map : dict
        Maps RDKit atom indices to parsed-atom indices. Atoms not present are
        excluded from the constraints.

    Returns
    -------
    list[ParsedRDKitBoundsConstraint]
        The constraints; empty when the molecule has at most one atom.

    """
    n_atoms = mol.GetNumAtoms()
    if n_atoms <= 1:
        return []

    # Ensure RingInfo is initialized before building the bounds matrix.
    mol.UpdatePropertyCache(strict=False)
    Chem.GetSymmSSSR(mol)

    bounds = GetMoleculeBoundsMatrix(
        mol,
        set15bounds=True,
        scaleVDW=True,
        doTriangleSmoothing=True,
        useMacrocycle14config=False,
    )

    # GetSubstructMatches stops after `maxMatches` hits (RDKit default 1000).
    # On large molecules this would silently drop is_bond / is_angle flags.
    # Size the cap from the graph instead: there is one match per bond, and at
    # most C(degree, 2) angle paths per central atom.
    max_bond_matches = max(1000, mol.GetNumBonds())
    max_angle_matches = max(
        1000,
        sum(
            atom.GetDegree() * (atom.GetDegree() - 1) // 2
            for atom in mol.GetAtoms()
        ),
    )
    bonds = {
        tuple(sorted(match))
        for match in mol.GetSubstructMatches(
            Chem.MolFromSmarts("*~*"), maxMatches=max_bond_matches
        )
    }
    angles = {
        tuple(sorted((match[0], match[2])))
        for match in mol.GetSubstructMatches(
            Chem.MolFromSmarts("*~*~*"), maxMatches=max_angle_matches
        )
    }

    constraints = []
    for i, j in zip(*np.triu_indices(n_atoms, k=1)):
        if i in idx_map and j in idx_map:
            pair = tuple(sorted((i, j)))
            constraints.append(
                ParsedRDKitBoundsConstraint(
                    atom_idxs=(idx_map[i], idx_map[j]),
                    is_bond=pair in bonds,
                    is_angle=pair in angles,
                    upper_bound=bounds[i, j],
                    lower_bound=bounds[j, i],
                )
            )
    return constraints
338
+
339
+
340
def compute_chiral_atom_constraints(mol, idx_map):
    """Derive tetrahedral-chirality constraints from RDKit CIP information.

    Nothing is produced unless every atom carries the ``_CIPRank`` property.
    The function visits each assigned chiral centre from
    ``Chem.FindMolChiralCenters(includeUnassigned=False)``. It skips centres
    that are not SP3 or that have more than four neighbours. For the rest,
    the neighbours are ordered by descending ``_CIPRank`` and it emits:

    * one reference constraint (``is_reference=True``) over the first three
      ordered neighbours followed by the centre;
    * when there are exactly four neighbours, three extra constraints
      (``is_reference=False``). Each drops one of the first three ordered
      neighbours; the remaining triple is reversed when the dropped position
      is even.

    A constraint is kept only if all its atoms are keys of ``idx_map``; its
    indices are remapped through ``idx_map``. ``is_r`` is True for label "R".

    Returns
    -------
    list[ParsedChiralAtomConstraint]
    """
    constraints = []
    if not all(atom.HasProp("_CIPRank") for atom in mol.GetAtoms()):
        return constraints

    for center_idx, label in Chem.FindMolChiralCenters(mol, includeUnassigned=False):
        center = mol.GetAtomWithIdx(center_idx)
        ranked = sorted(
            (
                (nbr.GetIdx(), int(nbr.GetProp("_CIPRank")))
                for nbr in center.GetNeighbors()
            ),
            key=lambda pair: pair[1],
            reverse=True,
        )
        neighbors = tuple(idx for idx, _ in ranked)
        is_r = label == "R"

        if len(neighbors) > 4 or center.GetHybridization() != HybridizationType.SP3:
            continue

        candidates = [((*neighbors[:3], center_idx), True)]
        if len(neighbors) == 4:
            for skip in range(3):
                kept = neighbors[:skip] + neighbors[skip + 1 :]
                if skip % 2 == 0:
                    kept = kept[::-1]
                candidates.append(((*kept, center_idx), False))

        for atom_idxs, is_reference in candidates:
            if all(i in idx_map for i in atom_idxs):
                constraints.append(
                    ParsedChiralAtomConstraint(
                        atom_idxs=tuple(idx_map[i] for i in atom_idxs),
                        is_reference=is_reference,
                        is_r=is_r,
                    )
                )
    return constraints
386
+
387
+
388
def compute_stereo_bond_constraints(mol, idx_map):
    """Derive E/Z double-bond constraints from RDKit stereo and CIP ranks.

    Nothing is produced unless every atom carries the ``_CIPRank`` property.
    For each bond whose stereo is ``STEREOE`` or ``STEREOZ``, the substituents
    on each end (excluding the other bond atom) are ordered by descending
    ``_CIPRank``. It emits:

    * a checked constraint (``is_check=True``) over
      (top begin substituent, begin atom, end atom, top end substituent);
    * when both ends have exactly two substituents, an unchecked constraint
      (``is_check=False``) over the second-ranked substituents.

    A constraint is kept only if all its atoms are keys of ``idx_map``; its
    indices are remapped through ``idx_map``. ``is_e`` is True for STEREOE.

    Returns
    -------
    list[ParsedStereoBondConstraint]
    """
    constraints = []
    if not all(atom.HasProp("_CIPRank") for atom in mol.GetAtoms()):
        return constraints

    def _ranked_neighbors(atom_idx, other_idx):
        # Neighbours of `atom_idx` except `other_idx`, highest _CIPRank first.
        scored = [
            (nbr.GetIdx(), int(nbr.GetProp("_CIPRank")))
            for nbr in mol.GetAtomWithIdx(atom_idx).GetNeighbors()
            if nbr.GetIdx() != other_idx
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return [idx for idx, _ in scored]

    stereo_kinds = {BondStereo.STEREOE, BondStereo.STEREOZ}
    for bond in mol.GetBonds():
        stereo = bond.GetStereo()
        if stereo not in stereo_kinds:
            continue

        begin_idx = bond.GetBeginAtomIdx()
        end_idx = bond.GetEndAtomIdx()
        begin_nbrs = _ranked_neighbors(begin_idx, end_idx)
        end_nbrs = _ranked_neighbors(end_idx, begin_idx)
        is_e = stereo == BondStereo.STEREOE

        candidates = [((begin_nbrs[0], begin_idx, end_idx, end_nbrs[0]), True)]
        if len(begin_nbrs) == 2 and len(end_nbrs) == 2:
            candidates.append(
                ((begin_nbrs[1], begin_idx, end_idx, end_nbrs[1]), False)
            )

        for atom_idxs, is_check in candidates:
            if all(i in idx_map for i in atom_idxs):
                constraints.append(
                    ParsedStereoBondConstraint(
                        atom_idxs=tuple(idx_map[i] for i in atom_idxs),
                        is_check=is_check,
                        is_e=is_e,
                    )
                )
    return constraints
449
+
450
+
451
def compute_flatness_constraints(mol, idx_map):
    """Find planarity constraints via SMARTS matching.

    Three patterns are searched:

    * a C=C double bond between two three-connected sp2 carbons, together with
      their substituents (6 atoms);
    * five-membered aromatic rings of sp2 atoms;
    * six-membered aromatic rings of sp2 atoms.

    A match is kept only if all its atoms are keys of ``idx_map``; its indices
    are remapped through ``idx_map``.

    Returns
    -------
    tuple[list, list, list]
        (planar double-bond constraints, 5-ring constraints, 6-ring constraints).
    """
    pattern_table = (
        ("[C;X3;^2](*)(*)=[C;X3;^2](*)(*)", ParsedPlanarBondConstraint),
        ("[ar5^2]1[ar5^2][ar5^2][ar5^2][ar5^2]1", ParsedPlanarRing5Constraint),
        (
            "[ar6^2]1[ar6^2][ar6^2][ar6^2][ar6^2][ar6^2]1",
            ParsedPlanarRing6Constraint,
        ),
    )

    grouped = []
    for smarts, constraint_cls in pattern_table:
        query = Chem.MolFromSmarts(smarts)
        grouped.append(
            [
                constraint_cls(atom_idxs=tuple(idx_map[i] for i in match))
                for match in mol.GetSubstructMatches(query)
                if all(i in idx_map for i in match)
            ]
        )
    return tuple(grouped)
482
+
483
+
484
def get_global_alignment_score(query: str, template: str) -> float:
    """Return the optimal global alignment score of a sequence vs a template.

    Uses Biopython's ``PairwiseAligner`` with the ``"blastp"`` scoring preset
    in global mode. Only the score is needed, so ``aligner.score`` is used.
    It yields the same optimal score as ``aligner.align(...)[0].score``
    without constructing an alignment or its traceback.

    Parameters
    ----------
    query : str
        The query sequence.
    template : str
        The template sequence.

    Returns
    -------
    float
        The global alignment score.

    """
    aligner = Align.PairwiseAligner(scoring="blastp")
    aligner.mode = "global"
    return aligner.score(query, template)
504
+
505
+
506
def get_local_alignments(query: str, template: str) -> list[Alignment]:
    """Locally align a sequence to a template and return the matched spans.

    Uses Biopython's ``PairwiseAligner`` with the ``"blastp"`` scoring preset
    in local mode, with a very large gap penalty (-1000 for open and extend).
    One ``Alignment`` is produced for every alignment the aligner yields.

    Parameters
    ----------
    query : str
        The query sequence.
    template : str
        The template sequence.

    Returns
    -------
    list[Alignment]
        The query / template spans of each local alignment.

    """
    aligner = Align.PairwiseAligner(scoring="blastp")
    aligner.mode = "local"
    aligner.open_gap_score = -1000
    aligner.extend_gap_score = -1000

    def _to_span(hit) -> Alignment:
        # Row 0 holds query coordinates, row 1 template coordinates.
        # NOTE(review): only columns 0-1 are read, i.e. the first aligned
        # segment if a gapped alignment were ever returned.
        coords = hit.coordinates
        return Alignment(
            query_st=int(coords[0][0]),
            query_en=int(coords[0][1]),
            template_st=int(coords[1][0]),
            template_en=int(coords[1][1]),
        )

    return [_to_span(hit) for hit in aligner.align(query, template)]
539
+
540
+
541
def get_template_records_from_search(
    template_id: str,
    chain_ids: list[str],
    sequences: dict[str, str],
    template_chain_ids: list[str],
    template_sequences: dict[str, str],
) -> list[TemplateInfo]:
    """Pair query chains with template chains and build template records.

    Every query chain is scored against every template chain with
    ``get_global_alignment_score``. ``linear_sum_assignment`` (maximising)
    then picks the one-to-one pairing with the highest total score; when the
    list lengths differ, only ``min(len(chain_ids), len(template_chain_ids))``
    pairs are made. Each chosen pair is aligned with ``get_local_alignments``,
    and one ``TemplateInfo`` is emitted per local alignment.

    Parameters
    ----------
    template_id : str
        Name stored on every record.
    chain_ids : list[str]
        Query chain ids (keys of ``sequences``).
    sequences : dict[str, str]
        Query sequences by chain id.
    template_chain_ids : list[str]
        Template chain ids (keys of ``template_sequences``).
    template_sequences : dict[str, str]
        Template sequences by chain id.

    Returns
    -------
    list[TemplateInfo]
        The records; empty when either chain list is empty.

    """
    # linear_sum_assignment([]) raises (np.asarray([]) is 1-D, not a matrix),
    # and with no chains on either side there is nothing to match anyway.
    if not chain_ids or not template_chain_ids:
        return []

    # Compute pairwise scores
    score_matrix = [
        [
            get_global_alignment_score(
                sequences[chain_id], template_sequences[template_chain_id]
            )
            for template_chain_id in template_chain_ids
        ]
        for chain_id in chain_ids
    ]

    # Find optimal mapping
    row_ind, col_ind = linear_sum_assignment(score_matrix, maximize=True)

    # Get alignment records
    template_records = []
    for row_idx, col_idx in zip(row_ind, col_ind):
        chain_id = chain_ids[row_idx]
        template_chain_id = template_chain_ids[col_idx]
        alignments = get_local_alignments(
            sequences[chain_id], template_sequences[template_chain_id]
        )
        for alignment in alignments:
            template_records.append(
                TemplateInfo(
                    name=template_id,
                    query_chain=chain_id,
                    query_st=alignment.query_st,
                    query_en=alignment.query_en,
                    template_chain=template_chain_id,
                    template_st=alignment.template_st,
                    template_en=alignment.template_en,
                )
            )

    return template_records
586
+
587
+
588
def get_template_records_from_matching(
    template_id: str,
    chain_ids: list[str],
    sequences: dict[str, str],
    template_chain_ids: list[str],
    template_sequences: dict[str, str],
) -> list[TemplateInfo]:
    """Build template records for an explicit query-to-template chain pairing.

    ``chain_ids[k]`` is paired with ``template_chain_ids[k]`` (via ``zip``, so
    extra entries in the longer list are ignored). Each pair is aligned with
    ``get_local_alignments`` and one ``TemplateInfo`` is produced per local
    alignment, in pair order.
    """
    return [
        TemplateInfo(
            name=template_id,
            query_chain=chain_id,
            query_st=hit.query_st,
            query_en=hit.query_en,
            template_chain=template_chain_id,
            template_st=hit.template_st,
            template_en=hit.template_en,
        )
        for chain_id, template_chain_id in zip(chain_ids, template_chain_ids)
        for hit in get_local_alignments(
            sequences[chain_id], template_sequences[template_chain_id]
        )
    ]
616
+
617
+
618
def get_mol(ccd: str, mols: dict, moldir: str) -> Mol:
    """Get the RDKit molecule for a CCD code.

    Returns ``mols[ccd]`` when present and not None. Otherwise the molecule is
    loaded from ``moldir`` with ``load_molecules`` and returned. The loaded
    molecule is NOT stored back into ``mols``: the caller's mapping is left
    untouched, so repeated misses for the same code reload it.

    Parameters
    ----------
    ccd : str
        The CCD component code.
    mols : dict
        Already-loaded molecules keyed by CCD code.
    moldir : str
        Directory passed to ``load_molecules`` for codes not in ``mols``.

    Returns
    -------
    Mol
        The molecule for ``ccd``.

    """
    mol = mols.get(ccd)
    if mol is not None:
        return mol
    return load_molecules(moldir, [ccd])[ccd]
628
+
629
+
630
+ ####################################################################################################
631
+ # PARSING
632
+ ####################################################################################################
633
+
634
+
635
def parse_ccd_residue(
    name: str, ref_mol: Mol, res_idx: int, drop_leaving_atoms: bool = False
) -> Optional[ParsedResidue]:
    """Build a ParsedResidue from a CCD reference molecule.

    Molecules with a single heavy atom take a shortcut: hydrogens are removed
    and one atom is emitted with zero coordinates, no bonds and no constraints.

    Otherwise:

    * hydrogens are skipped;
    * if ``drop_leaving_atoms`` is set, atoms whose ``leaving_atom`` property
      is non-zero are skipped as well;
    * each kept atom gets placeholder ``coords=(0, 0, 0)`` and its reference
      position (from ``get_conformer``) as ``conformer``;
    * bonds between kept atoms are recorded with the smaller index first, and
      unknown RDKit bond types map to ``const.unk_bond_type``;
    * RDKit bounds, chiral-atom, stereo-bond and planarity constraints are
      computed over the kept atoms. These helpers may update ``ref_mol``'s
      property cache and ring info.

    The residue type is always the unknown-protein token, and the residue is
    marked non-standard and present.

    Parameters
    ----------
    name : str
        The residue name stored on the result.
    ref_mol : Mol
        The reference molecule.
    res_idx : int
        The residue index.
    drop_leaving_atoms : bool, optional
        Whether to drop leaving atoms (used for non-canonical residues inside
        polymers). Defaults to False.

    Returns
    -------
    ParsedResidue, optional
        The parsed residue (this implementation always returns one).

    """
    unk_chirality = const.chirality_type_ids[const.unk_chirality_type]
    unk_prot_id = const.unk_token_ids["PROTEIN"]

    def _chirality_of(rd_atom):
        return const.chirality_type_ids.get(str(rd_atom.GetChiralTag()), unk_chirality)

    # Single heavy-atom residue: no conformer or constraints needed.
    if CalcNumHeavyAtoms(ref_mol) == 1:
        heavy_mol = AllChem.RemoveHs(ref_mol, sanitize=False)
        lone_atom = heavy_mol.GetAtoms()[0]
        lone_chirality = _chirality_of(lone_atom)
        return ParsedResidue(
            name=name,
            type=unk_prot_id,
            atoms=[
                ParsedAtom(
                    name=lone_atom.GetProp("name"),
                    element=lone_atom.GetAtomicNum(),
                    charge=lone_atom.GetFormalCharge(),
                    coords=(0, 0, 0),
                    conformer=(0, 0, 0),
                    is_present=True,
                    chirality=lone_chirality,
                )
            ],
            bonds=[],
            idx=res_idx,
            orig_idx=None,
            atom_center=0,  # Placeholder, no center
            atom_disto=0,  # Placeholder, no center
            is_standard=False,
            is_present=True,
        )

    conformer = get_conformer(ref_mol)

    # Keep atoms in reference order; idx_map: RDKit index -> parsed index.
    atoms = []
    idx_map = {}
    for rd_idx, rd_atom in enumerate(ref_mol.GetAtoms()):
        if rd_atom.GetAtomicNum() == 1:
            continue

        atom_name = rd_atom.GetProp("name")
        if drop_leaving_atoms and int(rd_atom.GetProp('leaving_atom')):
            continue

        position = conformer.GetAtomPosition(rd_atom.GetIdx())
        idx_map[rd_idx] = len(atoms)
        atoms.append(
            ParsedAtom(
                name=atom_name,
                element=rd_atom.GetAtomicNum(),
                charge=rd_atom.GetFormalCharge(),
                coords=(0, 0, 0),
                conformer=(position.x, position.y, position.z),
                is_present=True,
                chirality=_chirality_of(rd_atom),
            )
        )

    # Bonds between kept atoms only.
    unk_bond = const.bond_type_ids[const.unk_bond_type]
    bonds = []
    for rd_bond in ref_mol.GetBonds():
        begin = rd_bond.GetBeginAtomIdx()
        end = rd_bond.GetEndAtomIdx()
        if begin not in idx_map or end not in idx_map:
            continue
        low, high = sorted((idx_map[begin], idx_map[end]))
        bond_type = const.bond_type_ids.get(rd_bond.GetBondType().name, unk_bond)
        bonds.append(ParsedBond(low, high, bond_type))

    # Order matters: the geometry helper prepares ring info on ref_mol.
    rdkit_bounds_constraints = compute_geometry_constraints(ref_mol, idx_map)
    chiral_atom_constraints = compute_chiral_atom_constraints(ref_mol, idx_map)
    stereo_bond_constraints = compute_stereo_bond_constraints(ref_mol, idx_map)
    (
        planar_bond_constraints,
        planar_ring_5_constraints,
        planar_ring_6_constraints,
    ) = compute_flatness_constraints(ref_mol, idx_map)

    return ParsedResidue(
        name=name,
        type=unk_prot_id,
        atoms=atoms,
        bonds=bonds,
        idx=res_idx,
        atom_center=0,
        atom_disto=0,
        orig_idx=None,
        is_standard=False,
        is_present=True,
        rdkit_bounds_constraints=rdkit_bounds_constraints,
        chiral_atom_constraints=chiral_atom_constraints,
        stereo_bond_constraints=stereo_bond_constraints,
        planar_bond_constraints=planar_bond_constraints,
        planar_ring_5_constraints=planar_ring_5_constraints,
        planar_ring_6_constraints=planar_ring_6_constraints,
    )
786
+
787
+
788
def parse_polymer(
    sequence: list[str],
    raw_sequence: str,
    entity: str,
    chain_type: str,
    components: dict[str, Mol],
    cyclic: bool,
    mol_dir: Path,
) -> Optional[ParsedChain]:
    """Build a ParsedChain from a list of residue names.

    Each residue's reference molecule comes from ``get_mol`` (``components``
    first, then ``mol_dir``). ``"MSE"`` is renamed to ``"MET"``.

    * Names not in ``const.tokens`` are parsed with ``parse_ccd_residue``, with
      leaving atoms dropped.
    * Standard residues use the atoms listed in ``const.ref_atoms[name]``, in
      that order, from the hydrogen-free reference molecule. Each atom gets
      placeholder ``coords=(0, 0, 0)`` and its reference-conformer position.
      Centre and distogram atoms come from ``const.res_to_center_atom_id`` and
      ``const.res_to_disto_atom_id``.

    Parameters
    ----------
    sequence : list[str]
        Residue names of the polymer.
    raw_sequence : str
        The original sequence string, stored on the chain.
    entity : str
        The entity id.
    chain_type : str
        Stored as ``ParsedChain.type`` (annotated ``int`` there — presumably a
        ``const.chain_type_ids`` value; confirm with callers).
    components : dict[str, Mol]
        Preloaded CCD molecules keyed by residue name.
    cyclic : bool
        If True, ``cyclic_period`` is ``len(sequence)``; otherwise 0.
    mol_dir : Path
        Directory used to load molecules missing from ``components``.

    Returns
    -------
    ParsedChain, optional
        The parsed chain (this implementation always returns one).

    Raises
    ------
    KeyError
        If a reference atom named in ``const.ref_atoms`` is absent from the
        CCD molecule.

    """
    standard_names = set(const.tokens)
    unk_chirality = const.chirality_type_ids[const.unk_chirality_type]

    residues = []
    for position, code in enumerate(sequence):
        # Selenomethionine is treated as methionine.
        resname = "MET" if code == "MSE" else code

        if resname not in standard_names:
            residues.append(
                parse_ccd_residue(
                    name=resname,
                    ref_mol=get_mol(resname, components, mol_dir),
                    res_idx=position,
                    drop_leaving_atoms=True,
                )
            )
            continue

        heavy_mol = AllChem.RemoveHs(
            get_mol(resname, components, mol_dir), sanitize=False
        )
        ref_conformer = get_conformer(heavy_mol)

        # Restrict to the reference atoms declared in const, in that order.
        atoms_by_name = {a.GetProp("name"): a for a in heavy_mol.GetAtoms()}
        selected = [atoms_by_name[wanted] for wanted in const.ref_atoms[resname]]

        residue_atoms: list[ParsedAtom] = []
        for rd_atom in selected:
            position_3d = ref_conformer.GetAtomPosition(rd_atom.GetIdx())
            residue_atoms.append(
                ParsedAtom(
                    name=rd_atom.GetProp("name"),
                    element=rd_atom.GetAtomicNum(),
                    charge=rd_atom.GetFormalCharge(),
                    coords=(0, 0, 0),
                    conformer=(position_3d.x, position_3d.y, position_3d.z),
                    is_present=True,
                    chirality=const.chirality_type_ids.get(
                        str(rd_atom.GetChiralTag()), unk_chirality
                    ),
                )
            )

        residues.append(
            ParsedResidue(
                name=resname,
                type=const.token_ids[resname],
                atoms=residue_atoms,
                bonds=[],
                idx=position,
                atom_center=const.res_to_center_atom_id[resname],
                atom_disto=const.res_to_disto_atom_id[resname],
                is_standard=True,
                is_present=True,
                orig_idx=None,
            )
        )

    return ParsedChain(
        entity=entity,
        residues=residues,
        type=chain_type,
        cyclic_period=len(sequence) if cyclic else 0,
        sequence=raw_sequence,
    )
917
+
918
+
919
def token_spec_to_ids(
    chain_name, residue_index_or_atom_name, chain_to_idx, atom_idx_map, chains
):
    """Resolve a schema token spec to ``(asym_id, index)``.

    Parameters
    ----------
    chain_name : str
        Name of the chain the token belongs to.
    residue_index_or_atom_name : int or str
        For polymer chains, the 1-indexed residue number. For non-polymer
        chains, the name of an atom in the chain's first residue.
    chain_to_idx : dict
        Mapping from chain name to its asym id.
    atom_idx_map : dict
        Mapping ``(chain_name, residue_idx, atom_name) -> (asym_id, res_idx,
        atom_idx)``.
    chains : dict
        Mapping from chain name to a parsed chain exposing a ``type`` field.

    Returns
    -------
    tuple[int, int]
        ``(asym_id, atom_idx)`` for non-polymers, or ``(asym_id, residue_idx)``
        with a 0-indexed residue index for polymers.

    """
    if chains[chain_name].type == const.chain_type_ids["NONPOLYMER"]:
        # Non-polymer chains are indexed by atom name (first residue only)
        _, _, atom_idx = atom_idx_map[(chain_name, 0, residue_index_or_atom_name)]
        return (chain_to_idx[chain_name], atom_idx)

    # Polymer chains are indexed by residue index (1-indexed in the schema)
    return (chain_to_idx[chain_name], residue_index_or_atom_name - 1)
930
+
931
+
932
def _read_pdb_text(pdb_input: str) -> str:
    """Return the text of a PDB entry given a local file path or a 4-char PDB ID.

    Existing local files take precedence. Otherwise a 4-character alphanumeric
    identifier is downloaded from RCSB and cached under ``$BOLTZ_CACHE/pdb``
    (default ``~/.boltz/pdb``).

    Raises
    ------
    FileNotFoundError
        If ``pdb_input`` is neither an existing file nor a PDB identifier.
    RuntimeError
        If downloading the entry fails.

    """
    import urllib.request

    local_path = Path(pdb_input).expanduser()
    if local_path.is_file():
        return local_path.read_text()

    is_pdb_code = len(pdb_input) == 4 and pdb_input.isascii() and pdb_input.isalnum()
    if not is_pdb_code:
        msg = f"PDB file not found: {local_path}"
        raise FileNotFoundError(msg)

    pdb_id = pdb_input.lower()
    cache_dir = Path(os.environ.get("BOLTZ_CACHE", "~/.boltz")).expanduser()
    pdb_cache_dir = cache_dir / "pdb"
    pdb_cache_dir.mkdir(parents=True, exist_ok=True)
    pdb_cache_file = pdb_cache_dir / f"{pdb_id}.pdb"
    if pdb_cache_file.exists():
        return pdb_cache_file.read_text()

    pdb_url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
    try:
        # A timeout prevents a stalled connection from hanging the whole run
        with urllib.request.urlopen(pdb_url, timeout=60) as response:
            pdb_data = response.read().decode()
    except Exception as e:  # network failures come in many exception types
        msg = f"Failed to download PDB {pdb_input}: {str(e)}"
        raise RuntimeError(msg) from e

    # Write to a temporary file and rename, so an interrupted write never
    # leaves a truncated entry in the cache.
    tmp_file = pdb_cache_dir / f"{pdb_id}.pdb.tmp"
    tmp_file.write_text(pdb_data)
    tmp_file.replace(pdb_cache_file)
    return pdb_data


def _sequence_from_pdb(pdb_input: str, entity_type: str) -> str:
    """Extract a one-letter polymer sequence from a PDB file or PDB ID.

    Only the first model is used, because multi-model entries would otherwise
    repeat the sequence. Standard residues (``residue.id[0] == " "``) of every
    chain are concatenated in file order. Residue names that are not tokens for
    ``entity_type`` (e.g. DNA residues while parsing a protein) are skipped.

    Raises
    ------
    ValueError
        If the structure has no model or no residue of the requested type.

    """
    from io import StringIO

    from Bio.PDB import PDBParser

    token_map = {
        "protein": const.prot_letter_to_token,
        "dna": const.dna_letter_to_token,
        "rna": const.rna_letter_to_token,
    }[entity_type]

    # Invert the letter -> token map. Keep the first letter in case several
    # letters map to the same token.
    token_to_letter: dict[str, str] = {}
    for letter, token in token_map.items():
        token_to_letter.setdefault(token, letter)

    structure = PDBParser().get_structure(
        "protein", StringIO(_read_pdb_text(pdb_input))
    )
    first_model = next(iter(structure), None)
    if first_model is None:
        msg = f"No models found in PDB input {pdb_input}"
        raise ValueError(msg)

    letters = [
        token_to_letter[residue.resname.strip()]
        for chain in first_model
        for residue in chain
        if residue.id[0] == " " and residue.resname.strip() in token_to_letter
    ]
    if not letters:
        msg = f"No {entity_type} residues found in PDB input {pdb_input}"
        raise ValueError(msg)
    return "".join(letters)


def _offset_atoms(atom_idxs, offset: int) -> tuple:
    """Shift residue-local atom indices by ``offset`` into global indices."""
    return tuple(c_atom_idx + offset for c_atom_idx in atom_idxs)


def _lookup_schema_atom(atom_idx_map: dict, atom_spec) -> tuple:
    """Map a schema ``[chain, residue (1-indexed), atom]`` to ``atom_idx_map`` ids."""
    chain_name, residue_index, atom_name = tuple(atom_spec)
    key = (chain_name, residue_index - 1, atom_name)  # 1-indexed in the schema
    if key not in atom_idx_map:
        msg = (
            f"Could not find atom {atom_name} of residue {residue_index} "
            f"in chain {chain_name}!"
        )
        raise ValueError(msg)
    return atom_idx_map[key]


def _constraint_token_to_ids(
    token_spec, chains: dict, chain_to_idx: dict, atom_idx_map: dict
) -> tuple:
    """Resolve a schema token ``[chain, residue_or_atom]`` to ``(asym_id, index)``.

    Polymer tokens use a 1-indexed residue number and return a 0-indexed
    residue index. Non-polymer tokens name an atom of the chain's first residue
    and return that atom's global index.
    """
    chain_name, residue_index_or_atom_name = token_spec
    if chain_name not in chains:
        msg = f"Could not find chain {chain_name} referenced in a constraint!"
        raise ValueError(msg)

    if chains[chain_name].type == const.chain_type_ids["NONPOLYMER"]:
        # Non-polymer chains are indexed by atom name
        key = (chain_name, 0, residue_index_or_atom_name)
        if key not in atom_idx_map:
            msg = (
                f"Could not find atom {residue_index_or_atom_name} "
                f"in chain {chain_name}!"
            )
            raise ValueError(msg)
        _, _, global_atom_idx = atom_idx_map[key]
        return (chain_to_idx[chain_name], global_atom_idx)

    # Polymer chains are indexed by residue index
    return (chain_to_idx[chain_name], residue_index_or_atom_name - 1)


def parse_boltz_schema(  # noqa: C901, PLR0915, PLR0912
    name: str,
    schema: dict,
    ccd: Mapping[str, Mol],
    mol_dir: Optional[Path] = None,
    boltz_2: bool = False,
) -> Target:
    """Parse a Boltz input yaml / json.

    The input file should be a dictionary with the following format:

        version: 1
        sequences:
            - protein:
                id: A
                sequence: "MADQLTEEQIAEFKEAFSLF"  # or pdb: "1a2k" / "path/to/file.pdb"
                msa: path/to/msa1.a3m
            - protein:
                id: [B, C]
                sequence: "AKLSILPWGHC"
                msa: path/to/msa2.a3m
            - rna:
                id: D
                sequence: "GCAUAGC"
            - ligand:
                id: E
                smiles: "CC1=CC=CC=C1"
        constraints:
            - bond:
                atom1: [A, 1, CA]
                atom2: [A, 2, N]
            - pocket:
                binder: E
                contacts: [[B, 1], [B, 2]]
                max_distance: 6
            - contact:
                token1: [A, 1]
                token2: [B, 1]
                max_distance: 6
        templates:
            - cif: path/to/template.cif
        properties:
            - affinity:
                binder: E

    Parameters
    ----------
    name : str
        A name for the input.
    schema : dict
        The input schema.
    ccd : Mapping[str, Mol]
        Dictionary of CCD components.
    mol_dir : Path, optional
        Path to the directory containing the molecules.
    boltz_2 : bool
        Whether to parse the input for Boltz2.

    Returns
    -------
    Target
        The parsed target.

    Raises
    ------
    ValueError
        If the schema is invalid or uses a feature the selected model
        version does not support.
    FileNotFoundError, RuntimeError
        If a ``pdb`` input cannot be read or downloaded.

    """
    # Assert version 1
    version = schema.get("version", 1)
    if version != 1:
        msg = f"Invalid version {version} in input!"
        raise ValueError(msg)

    # Disable rdkit warnings
    blocker = rdBase.BlockLogs()  # noqa: F841

    # Group items sharing type, sequence, modifications and cyclic flag.
    # Each group becomes one entity.
    items_to_group: dict[tuple, list[dict]] = {}
    chain_name_to_entity_type: dict[str, str] = {}

    for item in schema["sequences"]:
        entity_type = next(iter(item.keys())).lower()
        if entity_type not in {"protein", "dna", "rna", "ligand"}:
            msg = f"Invalid entity type: {entity_type}"
            raise ValueError(msg)
        entity = item[entity_type]

        if entity_type == "ligand":
            has_smiles = "smiles" in entity
            if has_smiles == ("ccd" in entity):
                msg = "Ligands must specify exactly one of 'smiles' or 'ccd'!"
                raise ValueError(msg)
            seq = str(entity["smiles"] if has_smiles else entity["ccd"])
        elif "sequence" in entity:
            seq = str(entity["sequence"])
        elif "pdb" in entity:
            seq = _sequence_from_pdb(str(entity["pdb"]), entity_type)
        else:
            msg = f"Entity '{entity_type}' must have either 'sequence' or 'pdb' field"
            raise ValueError(msg)

        modifications = tuple(
            (mod["position"], str(mod["ccd"]))
            for mod in entity.get("modifications") or []
        )
        group_key = (entity_type, seq, modifications, bool(entity.get("cyclic", False)))
        items_to_group.setdefault(group_key, []).append(item)

        # Map chain names to entity types
        chain_names = entity["id"]
        chain_names = [chain_names] if isinstance(chain_names, str) else chain_names
        for chain_name in chain_names:
            chain_name_to_entity_type[chain_name] = entity_type

    # Check if any affinity ligand is present
    affinity_ligands = set()
    properties = schema.get("properties") or []
    if properties and not boltz_2:
        msg = "Affinity prediction is only supported for Boltz2!"
        raise ValueError(msg)

    for prop in properties:
        prop_type = next(iter(prop.keys())).lower()
        if prop_type != "affinity":
            continue
        binder = prop["affinity"]["binder"]
        if not isinstance(binder, str):
            # TODO: support multi residue ligands and ccd's
            msg = "Binder must be a single chain."
            raise ValueError(msg)
        if binder not in chain_name_to_entity_type:
            msg = f"Could not find binder with name {binder} in the input!"
            raise ValueError(msg)
        if chain_name_to_entity_type[binder] != "ligand":
            msg = (
                f"Chain {binder} is not a ligand! "
                "Affinity is currently only supported for ligands."
            )
            raise ValueError(msg)
        affinity_ligands.add(binder)

    if len(affinity_ligands) > 1:
        msg = "Only one affinity ligand is currently supported!"
        raise ValueError(msg)

    # Go through entities and parse them
    extra_mols: dict[str, Mol] = {}
    chains: dict[str, ParsedChain] = {}
    chain_to_msa: dict[str, str] = {}
    entity_to_seq: dict[int, str] = {}
    is_msa_custom = False
    is_msa_auto = False
    ligand_id = 1
    for entity_id, (group_key, items) in enumerate(items_to_group.items()):
        entity_type, group_seq = group_key[0], group_key[1]
        entity = items[0][entity_type]

        # Collect the chain ids of every item in this entity
        ids = []
        for item in items:
            item_ids = item[entity_type]["id"]
            if isinstance(item_ids, str):
                ids.append(item_ids)
            elif isinstance(item_ids, list):
                ids.extend(item_ids)

        if len(ids) > 1 and any(x in affinity_ligands for x in ids):
            msg = "Cannot compute affinity for a ligand that has multiple copies!"
            raise ValueError(msg)
        affinity = len(ids) == 1 and ids[0] in affinity_ligands

        # Ensure all the items share the same msa (-1 means no MSA)
        msa = -1
        if entity_type == "protein":
            # Default to 0, meaning auto-generated
            msa = entity.get("msa", 0)
            if msa is None or msa == "":
                msa = 0

            for item in items:
                item_msa = item[entity_type].get("msa", 0)
                if item_msa is None or item_msa == "":
                    item_msa = 0
                if item_msa != msa:
                    msg = "All proteins with the same sequence must share the same MSA!"
                    raise ValueError(msg)

            # Set the MSA, warn if passed in single-sequence mode
            if msa == "empty":
                msa = -1
                msg = (
                    "Found explicit empty MSA for some proteins, will run "
                    "these in single sequence mode. Keep in mind that the "
                    "model predictions will be suboptimal without an MSA."
                )
                click.echo(msg)

            if msa not in (0, -1):
                is_msa_custom = True
            elif msa == 0:
                is_msa_auto = True

        if entity_type in {"protein", "dna", "rna"}:
            token_map = {
                "protein": const.prot_letter_to_token,
                "dna": const.dna_letter_to_token,
                "rna": const.rna_letter_to_token,
            }[entity_type]
            chain_type = const.chain_type_ids[entity_type.upper()]
            unk_token = const.unk_token[entity_type.upper()]

            # The grouping key holds the resolved sequence, including
            # sequences extracted from `pdb` inputs.
            raw_seq = group_seq
            entity_to_seq[entity_id] = raw_seq

            # Convert sequence to tokens, then apply (1-indexed) modifications
            seq = [token_map.get(c, unk_token) for c in raw_seq]
            for mod in entity.get("modifications") or []:
                position = mod["position"]
                if not 1 <= position <= len(seq):
                    msg = (
                        f"Modification position {position} is out of range "
                        f"for a sequence of length {len(seq)}!"
                    )
                    raise ValueError(msg)
                seq[position - 1] = mod["ccd"]

            parsed_chain = parse_polymer(
                sequence=seq,
                raw_sequence=raw_seq,
                entity=entity_id,
                chain_type=chain_type,
                components=ccd,
                cyclic=entity.get("cyclic", False),
                mol_dir=mol_dir,
            )

        elif "ccd" in entity:
            if entity.get("cyclic", False):
                msg = "Cyclic flag is not supported for ligands"
                raise ValueError(msg)

            ccd_codes = entity["ccd"]
            if isinstance(ccd_codes, str):
                ccd_codes = [ccd_codes]

            if affinity and len(ccd_codes) > 1:
                msg = "Cannot compute affinity for multi residue ligands!"
                raise ValueError(msg)

            residues = []
            affinity_mw = None
            for ligand_res_idx, code in enumerate(ccd_codes):
                ref_mol = get_mol(code, ccd, mol_dir)
                if affinity:
                    affinity_mw = AllChem.Descriptors.MolWt(ref_mol)
                residues.append(
                    parse_ccd_residue(name=code, ref_mol=ref_mol, res_idx=ligand_res_idx)
                )

            # Create multi ligand chain
            parsed_chain = ParsedChain(
                entity=entity_id,
                residues=residues,
                type=const.chain_type_ids["NONPOLYMER"],
                cyclic_period=0,
                sequence=None,
                affinity=affinity,
                affinity_mw=affinity_mw,
            )

        else:  # SMILES ligand (validated in the grouping pass)
            if entity.get("cyclic", False):
                msg = "Cyclic flag is not supported for ligands"
                raise ValueError(msg)

            smiles = entity["smiles"]
            if affinity:
                smiles = standardize(smiles)

            mol = AllChem.MolFromSmiles(smiles)
            if mol is None:
                msg = f"Invalid SMILES string: {smiles}"
                raise ValueError(msg)
            mol = AllChem.AddHs(mol)

            # Name atoms by element + canonical rank so names are deterministic
            canonical_order = AllChem.CanonicalRankAtoms(mol)
            for atom, can_idx in zip(mol.GetAtoms(), canonical_order):
                atom_name = atom.GetSymbol().upper() + str(can_idx + 1)
                if len(atom_name) > 4:
                    msg = (
                        f"{smiles} has an atom with a name longer than "
                        f"4 characters: {atom_name}."
                    )
                    raise ValueError(msg)
                atom.SetProp("name", atom_name)

            if not compute_3d_conformer(mol):
                msg = f"Failed to compute 3D conformer for {smiles}"
                raise ValueError(msg)

            mol_no_h = AllChem.RemoveHs(mol, sanitize=False)
            affinity_mw = AllChem.Descriptors.MolWt(mol_no_h) if affinity else None
            extra_mols[f"LIG{ligand_id}"] = mol_no_h
            residue = parse_ccd_residue(
                name=f"LIG{ligand_id}",
                ref_mol=mol,
                res_idx=0,
            )
            ligand_id += 1

            parsed_chain = ParsedChain(
                entity=entity_id,
                residues=[residue],
                type=const.chain_type_ids["NONPOLYMER"],
                cyclic_period=0,
                sequence=None,
                affinity=affinity,
                affinity_mw=affinity_mw,
            )

        # Add as many chains as provided ids
        for chain_name in ids:
            if chain_name in chains:
                msg = f"Duplicate chain id {chain_name} in the input!"
                raise ValueError(msg)
            chains[chain_name] = parsed_chain
            chain_to_msa[chain_name] = msa

    if is_msa_custom and is_msa_auto:
        msg = "Cannot mix custom and auto-generated MSAs in the same input!"
        raise ValueError(msg)

    if not chains:
        msg = "No chains parsed!"
        raise ValueError(msg)

    # Create tables
    atom_data = []
    bond_data = []
    res_data = []
    chain_data = []
    protein_chains: list[str] = []  # ordered, so template defaults are deterministic
    affinity_info = None

    rdkit_bounds_constraint_data = []
    chiral_atom_constraint_data = []
    stereo_bond_constraint_data = []
    planar_bond_constraint_data = []
    planar_ring_5_constraint_data = []
    planar_ring_6_constraint_data = []

    # Running global atom / residue counters while flattening the chains
    atom_idx = 0
    res_idx = 0
    sym_count: dict[int, int] = {}
    chain_to_idx: dict[str, int] = {}

    # Keep a mapping of (chain_name, residue_idx, atom_name) to atom_idx
    atom_idx_map = {}

    for asym_id, (chain_name, chain) in enumerate(chains.items()):
        res_num = len(chain.residues)
        atom_num = sum(len(res.atoms) for res in chain.residues)

        if chain.type == const.chain_type_ids["PROTEIN"]:
            protein_chains.append(chain_name)

        if chain.affinity:
            if affinity_info is not None:
                msg = "Cannot compute affinity for multiple ligands!"
                raise ValueError(msg)
            affinity_info = AffinityInfo(chain_id=asym_id, mw=chain.affinity_mw)

        # sym_id counts copies of the same entity
        entity_id = int(chain.entity)
        sym_id = sym_count.get(entity_id, 0)
        chain_data.append(
            (
                chain_name,
                chain.type,
                entity_id,
                sym_id,
                asym_id,
                atom_idx,
                atom_num,
                res_idx,
                res_num,
                chain.cyclic_period,
            )
        )
        chain_to_idx[chain_name] = asym_id
        sym_count[entity_id] = sym_id + 1

        for res in chain.residues:
            res_data.append(
                (
                    res.name,
                    res.type,
                    res.idx,
                    atom_idx,
                    len(res.atoms),
                    atom_idx + res.atom_center,
                    atom_idx + res.atom_disto,
                    res.is_standard,
                    res.is_present,
                )
            )

            # Residue-level constraints use residue-local atom indices
            if res.rdkit_bounds_constraints is not None:
                rdkit_bounds_constraint_data.extend(
                    (
                        _offset_atoms(c.atom_idxs, atom_idx),
                        c.is_bond,
                        c.is_angle,
                        c.upper_bound,
                        c.lower_bound,
                    )
                    for c in res.rdkit_bounds_constraints
                )
            if res.chiral_atom_constraints is not None:
                chiral_atom_constraint_data.extend(
                    (_offset_atoms(c.atom_idxs, atom_idx), c.is_reference, c.is_r)
                    for c in res.chiral_atom_constraints
                )
            if res.stereo_bond_constraints is not None:
                stereo_bond_constraint_data.extend(
                    (_offset_atoms(c.atom_idxs, atom_idx), c.is_check, c.is_e)
                    for c in res.stereo_bond_constraints
                )
            if res.planar_bond_constraints is not None:
                planar_bond_constraint_data.extend(
                    (_offset_atoms(c.atom_idxs, atom_idx),)
                    for c in res.planar_bond_constraints
                )
            if res.planar_ring_5_constraints is not None:
                planar_ring_5_constraint_data.extend(
                    (_offset_atoms(c.atom_idxs, atom_idx),)
                    for c in res.planar_ring_5_constraints
                )
            if res.planar_ring_6_constraints is not None:
                planar_ring_6_constraint_data.extend(
                    (_offset_atoms(c.atom_idxs, atom_idx),)
                    for c in res.planar_ring_6_constraints
                )

            bond_data.extend(
                (
                    asym_id,
                    asym_id,
                    res_idx,
                    res_idx,
                    atom_idx + bond.atom_1,
                    atom_idx + bond.atom_2,
                    bond.type,
                )
                for bond in res.bonds
            )

            for atom in res.atoms:
                atom_idx_map[(chain_name, res.idx, atom.name)] = (
                    asym_id,
                    res_idx,
                    atom_idx,
                )
                atom_data.append(
                    (
                        atom.name,
                        atom.element,
                        atom.charge,
                        atom.coords,
                        atom.conformer,
                        atom.is_present,
                        atom.chirality,
                    )
                )
                atom_idx += 1

            res_idx += 1

    # Parse constraints
    connections = []
    pocket_constraints = []
    contact_constraints = []
    for constraint in schema.get("constraints") or []:
        if "bond" in constraint:
            bond_spec = constraint["bond"]
            if "atom1" not in bond_spec or "atom2" not in bond_spec:
                msg = "Bond constraint was not properly specified"
                raise ValueError(msg)

            c1, r1, a1 = _lookup_schema_atom(atom_idx_map, bond_spec["atom1"])
            c2, r2, a2 = _lookup_schema_atom(atom_idx_map, bond_spec["atom2"])
            connections.append((c1, c2, r1, r2, a1, a2))

        elif "pocket" in constraint:
            pocket_spec = constraint["pocket"]
            if "binder" not in pocket_spec or "contacts" not in pocket_spec:
                msg = "Pocket constraint was not properly specified"
                raise ValueError(msg)

            if pocket_constraints and not boltz_2:
                msg = "Only one pocket binders is supported in Boltz-1!"
                raise ValueError(msg)

            max_distance = pocket_spec.get("max_distance", 6.0)
            if max_distance != 6.0 and not boltz_2:
                msg = "Max distance != 6.0 is not supported in Boltz-1!"
                raise ValueError(msg)

            binder_name = pocket_spec["binder"]
            if binder_name not in chain_to_idx:
                msg = f"Could not find pocket binder {binder_name} in the input!"
                raise ValueError(msg)

            contacts = [
                _constraint_token_to_ids(token, chains, chain_to_idx, atom_idx_map)
                for token in pocket_spec["contacts"]
            ]
            pocket_constraints.append(
                (chain_to_idx[binder_name], contacts, max_distance)
            )

        elif "contact" in constraint:
            contact_spec = constraint["contact"]
            if "token1" not in contact_spec or "token2" not in contact_spec:
                msg = "Contact constraint was not properly specified"
                raise ValueError(msg)

            if not boltz_2:
                msg = "Contact constraint is not supported in Boltz-1!"
                raise ValueError(msg)

            max_distance = contact_spec.get("max_distance", 6.0)
            token1 = _constraint_token_to_ids(
                contact_spec["token1"], chains, chain_to_idx, atom_idx_map
            )
            token2 = _constraint_token_to_ids(
                contact_spec["token2"], chains, chain_to_idx, atom_idx_map
            )
            contact_constraints.append((token1, token2, max_distance))

        else:
            msg = f"Invalid constraint: {constraint}"
            raise ValueError(msg)

    # Get protein sequences in this YAML
    protein_seqs = {chain_name: chains[chain_name].sequence for chain_name in protein_chains}

    # Parse templates
    template_schema = schema.get("templates") or []
    if template_schema and not boltz_2:
        msg = "Templates are not supported in Boltz 1.0!"
        raise ValueError(msg)

    templates = {}
    template_records = []
    for template in template_schema:
        if "cif" not in template:
            msg = "Template was not properly specified, missing CIF path!"
            raise ValueError(msg)

        path = template["cif"]
        template_id = Path(path).stem
        chain_ids = template.get("chain_id", None)
        template_chain_ids = template.get("template_id", None)

        if chain_ids is not None and not isinstance(chain_ids, list):
            chain_ids = [chain_ids]
        if template_chain_ids is not None and not isinstance(template_chain_ids, list):
            template_chain_ids = [template_chain_ids]

        # Explicit chain <-> template chain pairs use matching, otherwise search
        matched = False
        if template_chain_ids is not None and chain_ids is not None:
            if len(template_chain_ids) != len(chain_ids):
                msg = (
                    "When providing both the chain_id and template_id, the number of "
                    "template_ids provided must match the number of chain_ids!"
                )
                raise ValueError(msg)
            matched = len(chain_ids) > 0

        if chain_ids is None:
            chain_ids = list(protein_chains)

        for chain_id in chain_ids:
            if chain_id not in protein_chains:
                msg = (
                    f"Chain {chain_id} assigned for template "
                    f"{template_id} is not one of the protein chains!"
                )
                raise ValueError(msg)

        parsed_template = parse_mmcif(
            path,
            mols=ccd,
            moldir=mol_dir,
            use_assembly=False,
            compute_interfaces=False,
        )
        # dict.fromkeys de-duplicates while keeping file order (deterministic)
        template_proteins = list(
            dict.fromkeys(
                str(c["name"])
                for c in parsed_template.data.chains
                if c["mol_type"] == const.chain_type_ids["PROTEIN"]
            )
        )
        if template_chain_ids is None:
            template_chain_ids = list(template_proteins)

        for chain_id in template_chain_ids:
            if chain_id not in template_proteins:
                msg = (
                    f"Template chain {chain_id} assigned for template "
                    f"{template_id} is not one of the protein chains!"
                )
                raise ValueError(msg)

        get_records = (
            get_template_records_from_matching
            if matched
            else get_template_records_from_search
        )
        template_records.extend(
            get_records(
                template_id=template_id,
                chain_ids=chain_ids,
                sequences=protein_seqs,
                template_chain_ids=template_chain_ids,
                template_sequences=parsed_template.sequences,
            )
        )
        templates[template_id] = parsed_template.data

    # Convert into datatypes
    residues = np.array(res_data, dtype=Residue)
    chain_table = np.array(chain_data, dtype=Chain)
    interfaces = np.array([], dtype=Interface)
    mask = np.ones(len(chain_data), dtype=bool)
    rdkit_bounds_constraints = np.array(
        rdkit_bounds_constraint_data, dtype=RDKitBoundsConstraint
    )
    chiral_atom_constraints = np.array(
        chiral_atom_constraint_data, dtype=ChiralAtomConstraint
    )
    stereo_bond_constraints = np.array(
        stereo_bond_constraint_data, dtype=StereoBondConstraint
    )
    planar_bond_constraints = np.array(
        planar_bond_constraint_data, dtype=PlanarBondConstraint
    )
    planar_ring_5_constraints = np.array(
        planar_ring_5_constraint_data, dtype=PlanarRing5Constraint
    )
    planar_ring_6_constraints = np.array(
        planar_ring_6_constraint_data, dtype=PlanarRing6Constraint
    )

    if boltz_2:
        # Keep (name, coords, is_present) and add placeholder 0.0 / 1.0 fields
        atom_data = [(a[0], a[3], a[5], 0.0, 1.0) for a in atom_data]
        connections = [(*c, const.bond_type_ids["COVALENT"]) for c in connections]
        bond_data = bond_data + connections
        atoms = np.array(atom_data, dtype=AtomV2)
        bonds = np.array(bond_data, dtype=BondV2)
        coords = np.array([(x,) for x in atoms["coords"]], Coords)
        ensemble = np.array([(0, len(coords))], dtype=Ensemble)
        data = StructureV2(
            atoms=atoms,
            bonds=bonds,
            residues=residues,
            chains=chain_table,
            interfaces=interfaces,
            mask=mask,
            coords=coords,
            ensemble=ensemble,
        )
    else:
        bond_data = [(b[4], b[5], b[6]) for b in bond_data]
        atom_data = [(convert_atom_name(a[0]), *a[1:]) for a in atom_data]
        atoms = np.array(atom_data, dtype=Atom)
        bonds = np.array(bond_data, dtype=Bond)
        connection_table = np.array(connections, dtype=Connection)
        data = Structure(
            atoms=atoms,
            bonds=bonds,
            residues=residues,
            chains=chain_table,
            connections=connection_table,
            interfaces=interfaces,
            mask=mask,
        )

    # Create metadata
    struct_info = StructureInfo(num_chains=len(chain_table))
    chain_infos = [
        ChainInfo(
            chain_id=int(chain["asym_id"]),
            chain_name=chain["name"],
            mol_type=int(chain["mol_type"]),
            cluster_id=-1,
            msa_id=chain_to_msa[chain["name"]],
            num_residues=int(chain["res_num"]),
            valid=True,
            entity_id=int(chain["entity_id"]),
        )
        for chain in chain_table
    ]

    if contact_constraints:
        # NOTE(review): assumes InferenceOptions accepts `contact_constraints`
        # as (token1, token2, max_distance) tuples. Fail loudly rather than
        # silently dropping the user's constraints if it does not.
        try:
            options = InferenceOptions(
                pocket_constraints=pocket_constraints,
                contact_constraints=contact_constraints,
            )
        except TypeError as e:
            msg = "Contact constraints are not supported by this version of Boltz!"
            raise ValueError(msg) from e
    else:
        options = InferenceOptions(pocket_constraints=pocket_constraints)

    record = Record(
        id=name,
        structure=struct_info,
        chains=chain_infos,
        interfaces=[],
        inference_options=options,
        templates=template_records,
        affinity=affinity_info,
    )

    residue_constraints = ResidueConstraints(
        rdkit_bounds_constraints=rdkit_bounds_constraints,
        chiral_atom_constraints=chiral_atom_constraints,
        stereo_bond_constraints=stereo_bond_constraints,
        planar_bond_constraints=planar_bond_constraints,
        planar_ring_5_constraints=planar_ring_5_constraints,
        planar_ring_6_constraints=planar_ring_6_constraints,
    )

    return Target(
        record=record,
        structure=data,
        sequences=entity_to_seq,
        residue_constraints=residue_constraints,
        templates=templates,
        extra_mols=extra_mols,
    )
1824
+
1825
+
1826
def standardize(smiles: str) -> Optional[str]:
    """Standardize a SMILES string and return the standardized SMILES.

    The molecule is checked with ``exclude_flag``. Its largest fragment is
    kept and run through ``standardize_mol``. The result is re-parsed to make
    sure standardization did not break it.

    Parameters
    ----------
    smiles : str
        Input SMILES string.

    Returns
    -------
    str
        The standardized SMILES.

    Raises
    ------
    ValueError
        If the SMILES cannot be parsed, the molecule is excluded, or the
        standardized molecule cannot be parsed back by RDKit.

    """
    largest_fragment_chooser = rdMolStandardize.LargestFragmentChooser()

    mol = Chem.MolFromSmiles(smiles, sanitize=False)
    if mol is None:
        # Fail with a clear message instead of passing None downstream
        msg = f"Could not parse SMILES: {smiles}"
        raise ValueError(msg)

    if exclude_flag(mol, includeRDKitSanitization=False):
        raise ValueError("Molecule is excluded")

    # Choose the largest component, then standardize with the ChEMBL data
    # curation pipeline (which may break the molecule in rare cases).
    mol = largest_fragment_chooser.choose(mol)
    mol = standardize_mol(mol)
    standardized = Chem.MolToSmiles(mol)

    # Check the standardized molecule can still be parsed by RDKit
    if Chem.MolFromSmiles(standardized) is None:
        raise ValueError("Molecule is broken")

    return standardized