boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1851 @@
|
|
1
|
+
from collections.abc import Mapping
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import Optional
|
5
|
+
|
6
|
+
import click
|
7
|
+
import numpy as np
|
8
|
+
from Bio import Align
|
9
|
+
from chembl_structure_pipeline.exclude_flag import exclude_flag
|
10
|
+
from chembl_structure_pipeline.standardizer import standardize_mol
|
11
|
+
from rdkit import Chem, rdBase
|
12
|
+
from rdkit.Chem import AllChem, HybridizationType
|
13
|
+
from rdkit.Chem.MolStandardize import rdMolStandardize
|
14
|
+
from rdkit.Chem.rdchem import BondStereo, Conformer, Mol
|
15
|
+
from rdkit.Chem.rdDistGeom import GetMoleculeBoundsMatrix
|
16
|
+
from rdkit.Chem.rdMolDescriptors import CalcNumHeavyAtoms
|
17
|
+
from scipy.optimize import linear_sum_assignment
|
18
|
+
|
19
|
+
from boltz.data import const
|
20
|
+
from boltz.data.mol import load_molecules
|
21
|
+
from boltz.data.parse.mmcif import parse_mmcif
|
22
|
+
from boltz.data.types import (
|
23
|
+
AffinityInfo,
|
24
|
+
Atom,
|
25
|
+
AtomV2,
|
26
|
+
Bond,
|
27
|
+
BondV2,
|
28
|
+
Chain,
|
29
|
+
ChainInfo,
|
30
|
+
ChiralAtomConstraint,
|
31
|
+
Connection,
|
32
|
+
Coords,
|
33
|
+
Ensemble,
|
34
|
+
InferenceOptions,
|
35
|
+
Interface,
|
36
|
+
PlanarBondConstraint,
|
37
|
+
PlanarRing5Constraint,
|
38
|
+
PlanarRing6Constraint,
|
39
|
+
RDKitBoundsConstraint,
|
40
|
+
Record,
|
41
|
+
Residue,
|
42
|
+
ResidueConstraints,
|
43
|
+
StereoBondConstraint,
|
44
|
+
Structure,
|
45
|
+
StructureInfo,
|
46
|
+
StructureV2,
|
47
|
+
Target,
|
48
|
+
TemplateInfo,
|
49
|
+
)
|
50
|
+
|
51
|
+
####################################################################################################
|
52
|
+
# DATACLASSES
|
53
|
+
####################################################################################################
|
54
|
+
|
55
|
+
|
56
|
+
@dataclass(frozen=True)
class ParsedAtom:
    """Immutable record describing one atom produced by the parser."""

    # Atom name (read from the reference molecule's "name" property).
    name: str
    # Atomic number of the element.
    element: int
    # Formal charge of the atom.
    charge: int
    # (x, y, z) coordinates of the atom in the parsed structure.
    coords: tuple[float, float, float]
    # (x, y, z) coordinates of the atom in the reference conformer.
    conformer: tuple[float, float, float]
    # Whether the atom is flagged as present.
    is_present: bool
    # Chirality type id (looked up in const.chirality_type_ids).
    chirality: int
|
67
|
+
|
68
|
+
|
69
|
+
@dataclass(frozen=True)
class ParsedBond:
    """Immutable record describing one bond between two parsed atoms."""

    # Index of the first bonded atom.
    atom_1: int
    # Index of the second bonded atom.
    atom_2: int
    # Bond type id (looked up in const.bond_type_ids).
    type: int
|
76
|
+
|
77
|
+
|
78
|
+
@dataclass(frozen=True)
class ParsedRDKitBoundsConstraint:
    """Immutable distance-bounds constraint between a pair of atoms."""

    # The two constrained atom indices.
    atom_idxs: tuple[int, int]
    # True when the two atoms are directly bonded.
    is_bond: bool
    # True when the two atoms are the end points of a three-atom path.
    is_angle: bool
    # Upper distance bound for the pair.
    upper_bound: float
    # Lower distance bound for the pair.
    lower_bound: float
|
87
|
+
|
88
|
+
|
89
|
+
@dataclass(frozen=True)
class ParsedChiralAtomConstraint:
    """Immutable chirality constraint around a single chiral center."""

    # Three neighbour atom indices followed by the chiral center index.
    atom_idxs: tuple[int, int, int, int]
    # True for the primary constraint of a center, False for derived ones.
    is_reference: bool
    # True when the center has "R" orientation.
    is_r: bool
|
96
|
+
|
97
|
+
|
98
|
+
@dataclass(frozen=True)
class ParsedStereoBondConstraint:
    """Immutable E/Z stereo constraint across a double bond."""

    # (neighbour of start, start atom, end atom, neighbour of end).
    atom_idxs: tuple[int, int, int, int]
    # True for the constraint built from the top-ranked neighbours.
    is_check: bool
    # True when the bond stereo is E, False when it is Z.
    is_e: bool
|
105
|
+
|
106
|
+
|
107
|
+
@dataclass(frozen=True)
class ParsedPlanarBondConstraint:
    """Immutable planarity constraint around a double bond."""

    # The six atom indices that are expected to be coplanar.
    atom_idxs: tuple[int, int, int, int, int, int]
|
112
|
+
|
113
|
+
|
114
|
+
@dataclass(frozen=True)
class ParsedPlanarRing5Constraint:
    """A parsed planar 5-membered ring constraint object."""

    # The five ring atom indices that are expected to be coplanar.
    atom_idxs: tuple[int, int, int, int, int]
|
119
|
+
|
120
|
+
|
121
|
+
@dataclass(frozen=True)
class ParsedPlanarRing6Constraint:
    """A parsed planar 6-membered ring constraint object."""

    # The six ring atom indices that are expected to be coplanar.
    atom_idxs: tuple[int, int, int, int, int, int]
|
126
|
+
|
127
|
+
|
128
|
+
@dataclass(frozen=True)
class ParsedResidue:
    """Immutable record describing one parsed residue and its constraints."""

    # Residue name (e.g. a CCD code).
    name: str
    # Residue token id.
    type: int
    # Index of the residue within its chain.
    idx: int
    # Atoms belonging to the residue.
    atoms: list[ParsedAtom]
    # Bonds between atoms of the residue.
    bonds: list[ParsedBond]
    # Original residue index, when one is known.
    orig_idx: Optional[int]
    # Index of the atom used as the residue center.
    atom_center: int
    # Index of the "disto" atom -- presumably the atom used for distogram
    # features; confirm against the featurizer.
    atom_disto: int
    # Whether the residue is a standard residue.
    is_standard: bool
    # Whether the residue is flagged as present.
    is_present: bool
    # Optional geometry constraints; each defaults to None when not computed.
    rdkit_bounds_constraints: Optional[list[ParsedRDKitBoundsConstraint]] = None
    chiral_atom_constraints: Optional[list[ParsedChiralAtomConstraint]] = None
    stereo_bond_constraints: Optional[list[ParsedStereoBondConstraint]] = None
    planar_bond_constraints: Optional[list[ParsedPlanarBondConstraint]] = None
    planar_ring_5_constraints: Optional[list[ParsedPlanarRing5Constraint]] = None
    planar_ring_6_constraints: Optional[list[ParsedPlanarRing6Constraint]] = None
|
148
|
+
|
149
|
+
|
150
|
+
@dataclass(frozen=True)
class ParsedChain:
    """Immutable record describing one parsed chain."""

    # Entity identifier of the chain.
    entity: str
    # Chain type id.
    type: int
    # Residues of the chain, in order.
    residues: list[ParsedResidue]
    # Cyclic period of the chain -- presumably 0 for non-cyclic chains;
    # confirm against the caller.
    cyclic_period: int
    # Raw sequence string, when available.
    sequence: Optional[str] = None
    # Affinity flag for this chain; defaults to False.
    affinity: Optional[bool] = False
    # Value associated with the affinity flag -- presumably a molecular
    # weight, judging by the name; confirm against the caller.
    affinity_mw: Optional[float] = None
|
161
|
+
|
162
|
+
|
163
|
+
@dataclass(frozen=True)
class Alignment:
    """Immutable record of an aligned query/template segment."""

    # Start coordinate of the aligned segment in the query.
    query_st: int
    # End coordinate of the aligned segment in the query.
    query_en: int
    # Start coordinate of the aligned segment in the template.
    template_st: int
    # End coordinate of the aligned segment in the template.
    template_en: int
|
171
|
+
|
172
|
+
|
173
|
+
####################################################################################################
|
174
|
+
# HELPERS
|
175
|
+
####################################################################################################
|
176
|
+
|
177
|
+
|
178
|
+
def convert_atom_name(name: str) -> tuple[int, int, int, int]:
    """Encode an atom name as a tuple of small integers.

    Surrounding whitespace is stripped, every remaining character is mapped
    to ``ord(char) - 32`` and the result is right-padded with zeros up to a
    length of four. Names longer than four characters are not truncated.

    Parameters
    ----------
    name : str
        The atom name.

    Returns
    -------
    Tuple[int, int, int, int]
        The encoded atom name.

    """
    codes = [ord(char) - 32 for char in name.strip()]
    # A negative repeat count yields an empty list, so long names pass through.
    padding = [0] * (4 - len(codes))
    return tuple(codes + padding)
|
196
|
+
|
197
|
+
|
198
|
+
def compute_3d_conformer(mol: Mol, version: str = "v3") -> bool:
    """Generate 3D coordinates using the ETKDG method.

    Taken from `pdbeccdutils.core.component.Component`.

    The molecule is embedded in place (existing conformers are kept). If the
    first embedding attempt fails, a second attempt is made starting from
    random coordinates. A successfully embedded conformer is then relaxed
    with UFF and tagged with the properties ``name="Computed"`` and
    ``coord_generation``.

    Parameters
    ----------
    mol: Mol
        The RDKit molecule to process
    version: str, optional
        The ETKDG version, defaults to v3. Any value other than "v3"
        selects ETKDGv2.

    Returns
    -------
    bool
        Whether computation was successful.

    """
    # Everything that is not explicitly "v3" falls back to ETKDGv2.
    method = "v3" if version == "v3" else "v2"
    options = AllChem.ETKDGv3() if method == "v3" else AllChem.ETKDGv2()

    # Keep conformers that are already attached to the molecule.
    options.clearConfs = False
    conf_id = -1

    try:
        conf_id = AllChem.EmbedMolecule(mol, options)

        if conf_id == -1:
            print(
                f"WARNING: RDKit ETKDG{method} failed to generate a conformer for molecule "
                f"{Chem.MolToSmiles(AllChem.RemoveHs(mol))}, so the program will start with random coordinates. "
                f"Note that the performance of the model under this behaviour was not tested."
            )
            options.useRandomCoords = True
            conf_id = AllChem.EmbedMolecule(mol, options)

        # Only relax a conformer that was actually generated: confId=-1 would
        # otherwise target the molecule's default (pre-existing) conformer.
        if conf_id != -1:
            AllChem.UFFOptimizeMolecule(mol, confId=conf_id, maxIters=1000)

    except RuntimeError:
        pass  # Force field issue here
    except ValueError:
        pass  # sanitization issue here

    if conf_id == -1:
        return False

    conformer = mol.GetConformer(conf_id)
    conformer.SetProp("name", "Computed")
    # Record the method that was really used, not the raw argument.
    conformer.SetProp("coord_generation", f"ETKDG{method}")
    return True
|
253
|
+
|
254
|
+
|
255
|
+
def get_conformer(mol: Mol) -> Conformer:
    """Pick the preferred conformer of a molecule.

    Inspired by `pdbeccdutils.core.component.Component`.

    Conformers whose "name" property is "Computed" are preferred, then those
    named "Ideal"; if neither exists the first available conformer is used.

    Parameters
    ----------
    mol: Mol
        The molecule to process.

    Returns
    -------
    Conformer
        The selected conformer.

    Raises
    ------
    ValueError
        If the molecule has no conformers at all.

    """
    # Named conformers, in priority order: computed first, then ideal.
    for wanted in ("Computed", "Ideal"):
        for candidate in mol.GetConformers():
            try:
                label = candidate.GetProp("name")
            except KeyError:  # noqa: PERF203
                # Conformer carries no "name" property; skip it.
                continue
            if label == wanted:
                return candidate

    # Fallback to boltz2 format: take the first conformer, whatever its name.
    available_ids = [int(candidate.GetId()) for candidate in mol.GetConformers()]
    if not available_ids:
        msg = "Conformer does not exist."
        raise ValueError(msg)

    return mol.GetConformer(available_ids[0])
|
301
|
+
|
302
|
+
|
303
|
+
def compute_geometry_constraints(mol: Mol, idx_map):
    """Build pairwise distance-bound constraints from RDKit's bounds matrix.

    Parameters
    ----------
    mol : Mol
        The molecule to process.
    idx_map : mapping
        Maps atom indices of ``mol`` to output atom indices. Pairs with an
        atom missing from the mapping are skipped.

    Returns
    -------
    list[ParsedRDKitBoundsConstraint]
        One constraint per retained atom pair; empty for molecules with at
        most one atom.

    """
    num_atoms = mol.GetNumAtoms()
    if num_atoms <= 1:
        return []

    # Ensure RingInfo is initialized
    mol.UpdatePropertyCache(strict=False)
    Chem.GetSymmSSSR(mol)  # Compute ring information

    bounds = GetMoleculeBoundsMatrix(
        mol,
        set15bounds=True,
        scaleVDW=True,
        doTriangleSmoothing=True,
        useMacrocycle14config=False,
    )

    # Directly bonded pairs, and end points of every three-atom path.
    bond_pattern = Chem.MolFromSmarts("*~*")
    bonded_pairs = {
        tuple(sorted(match)) for match in mol.GetSubstructMatches(bond_pattern)
    }
    angle_pattern = Chem.MolFromSmarts("*~*~*")
    angle_pairs = {
        tuple(sorted([match[0], match[2]]))
        for match in mol.GetSubstructMatches(angle_pattern)
    }

    constraints = []
    rows, cols = np.triu_indices(num_atoms, k=1)
    for i, j in zip(rows, cols):
        if i not in idx_map or j not in idx_map:
            continue
        pair = tuple(sorted([i, j]))
        constraints.append(
            ParsedRDKitBoundsConstraint(
                atom_idxs=(idx_map[i], idx_map[j]),
                is_bond=pair in bonded_pairs,
                is_angle=pair in angle_pairs,
                # With i < j: upper bounds are read from [i, j], lower from [j, i].
                upper_bound=bounds[i, j],
                lower_bound=bounds[j, i],
            )
        )
    return constraints
|
338
|
+
|
339
|
+
|
340
|
+
def compute_chiral_atom_constraints(mol, idx_map):
    """Build chirality constraints for the assigned chiral centers of ``mol``.

    Parameters
    ----------
    mol : Mol
        The molecule to process. Nothing is produced unless every atom
        carries a "_CIPRank" property.
    idx_map : mapping
        Maps atom indices of ``mol`` to output atom indices. Constraints
        involving an atom missing from the mapping are skipped.

    Returns
    -------
    list[ParsedChiralAtomConstraint]
        The chirality constraints.

    """
    constraints = []
    if not all(atom.HasProp("_CIPRank") for atom in mol.GetAtoms()):
        return constraints

    def _add(raw_idxs, is_reference, is_r):
        # Keep the constraint only when all of its atoms are mapped.
        if all(i in idx_map for i in raw_idxs):
            constraints.append(
                ParsedChiralAtomConstraint(
                    atom_idxs=tuple(idx_map[i] for i in raw_idxs),
                    is_reference=is_reference,
                    is_r=is_r,
                )
            )

    centers = Chem.FindMolChiralCenters(mol, includeUnassigned=False)
    for center_idx, orientation in centers:
        center = mol.GetAtomWithIdx(center_idx)

        # Neighbour indices ordered by decreasing CIP rank.
        ranked = sorted(
            (
                (nbr.GetIdx(), int(nbr.GetProp("_CIPRank")))
                for nbr in center.GetNeighbors()
            ),
            key=lambda item: item[1],
            reverse=True,
        )
        neighbors = tuple(item[0] for item in ranked)
        is_r = orientation == "R"

        if len(neighbors) > 4 or center.GetHybridization() != HybridizationType.SP3:
            continue

        # Reference constraint: the three top-ranked neighbours plus the center.
        _add((*neighbors[:3], center_idx), True, is_r)

        if len(neighbors) != 4:
            continue

        # Derived constraints: drop one of the first three neighbours at a
        # time; the remaining triple is reversed for even skip positions.
        for skip_idx in range(3):
            chiral_set = neighbors[:skip_idx] + neighbors[skip_idx + 1 :]
            if skip_idx % 2 == 0:
                chiral_set = chiral_set[::-1]
            _add(chiral_set + (center_idx,), False, is_r)

    return constraints
|
386
|
+
|
387
|
+
|
388
|
+
def compute_stereo_bond_constraints(mol, idx_map):
    """Build E/Z stereo constraints for the stereo double bonds of ``mol``.

    Parameters
    ----------
    mol : Mol
        The molecule to process. Nothing is produced unless every atom
        carries a "_CIPRank" property.
    idx_map : mapping
        Maps atom indices of ``mol`` to output atom indices. Constraints
        involving an atom missing from the mapping are skipped.

    Returns
    -------
    list[ParsedStereoBondConstraint]
        The stereo bond constraints.

    """
    constraints = []
    if not all(atom.HasProp("_CIPRank") for atom in mol.GetAtoms()):
        return constraints

    def _ranked_neighbors(atom_idx, excluded_idx):
        # Neighbours of atom_idx (minus the bond partner), by decreasing CIP rank.
        ranked = [
            (nbr.GetIdx(), int(nbr.GetProp("_CIPRank")))
            for nbr in mol.GetAtomWithIdx(atom_idx).GetNeighbors()
            if nbr.GetIdx() != excluded_idx
        ]
        ranked.sort(key=lambda item: item[1], reverse=True)
        return [item[0] for item in ranked]

    for bond in mol.GetBonds():
        stereo = bond.GetStereo()
        if stereo not in {BondStereo.STEREOE, BondStereo.STEREOZ}:
            continue

        start_atom_idx = bond.GetBeginAtomIdx()
        end_atom_idx = bond.GetEndAtomIdx()
        start_neighbors = _ranked_neighbors(start_atom_idx, end_atom_idx)
        end_neighbors = _ranked_neighbors(end_atom_idx, start_atom_idx)
        is_e = stereo == BondStereo.STEREOE

        # The top-ranked neighbours always give the "check" constraint; the
        # second-ranked pair is added only when both ends have two neighbours.
        selections = [(0, True)]
        if len(start_neighbors) == 2 and len(end_neighbors) == 2:
            selections.append((1, False))

        for rank, is_check in selections:
            atom_idxs = (
                start_neighbors[rank],
                start_atom_idx,
                end_atom_idx,
                end_neighbors[rank],
            )
            if all(i in idx_map for i in atom_idxs):
                constraints.append(
                    ParsedStereoBondConstraint(
                        atom_idxs=tuple(idx_map[i] for i in atom_idxs),
                        is_check=is_check,
                        is_e=is_e,
                    )
                )

    return constraints
|
449
|
+
|
450
|
+
|
451
|
+
def compute_flatness_constraints(mol, idx_map):
    """Build planarity constraints for double bonds and aromatic rings.

    Parameters
    ----------
    mol : Mol
        The molecule to process.
    idx_map : mapping
        Maps atom indices of ``mol`` to output atom indices. Matches that
        involve an atom missing from the mapping are skipped.

    Returns
    -------
    tuple[list, list, list]
        The planar double bond constraints, the aromatic 5-ring constraints
        and the aromatic 6-ring constraints, in that order.

    """
    # (SMARTS pattern, constraint class) for each kind of planar motif.
    specs = (
        ("[C;X3;^2](*)(*)=[C;X3;^2](*)(*)", ParsedPlanarBondConstraint),
        ("[ar5^2]1[ar5^2][ar5^2][ar5^2][ar5^2]1", ParsedPlanarRing5Constraint),
        (
            "[ar6^2]1[ar6^2][ar6^2][ar6^2][ar6^2][ar6^2]1",
            ParsedPlanarRing6Constraint,
        ),
    )

    collected = []
    for smarts, constraint_cls in specs:
        pattern = Chem.MolFromSmarts(smarts)
        collected.append(
            [
                constraint_cls(atom_idxs=tuple(idx_map[i] for i in match))
                for match in mol.GetSubstructMatches(pattern)
                if all(i in idx_map for i in match)
            ]
        )

    planar_double_bond_constraints, ring_5_constraints, ring_6_constraints = collected
    return (
        planar_double_bond_constraints,
        ring_5_constraints,
        ring_6_constraints,
    )
|
482
|
+
|
483
|
+
|
484
|
+
def get_global_alignment_score(query: str, template: str) -> float:
    """Score the global alignment of a query against a template.

    Uses a pairwise aligner configured with the "blastp" scoring preset in
    global mode and reports the score of the first alignment returned.

    Parameters
    ----------
    query : str
        The query sequence.
    template : str
        The template sequence.

    Returns
    -------
    float
        The global alignment score.

    """
    global_aligner = Align.PairwiseAligner(scoring="blastp")
    global_aligner.mode = "global"
    first_alignment = global_aligner.align(query, template)[0]
    return first_alignment.score
|
504
|
+
|
505
|
+
|
506
|
+
def get_local_alignments(query: str, template: str) -> list[Alignment]:
    """Locally align a query to a template and report the aligned ranges.

    Parameters
    ----------
    query : str
        The query sequence.
    template : str
        The template sequence.

    Returns
    -------
    list[Alignment]
        One entry per alignment returned by the aligner, holding the first
        two coordinates of the query row and of the template row.

    """
    local_aligner = Align.PairwiseAligner(scoring="blastp")
    local_aligner.mode = "local"
    # Very large gap penalties -- presumably to keep the alignments ungapped.
    local_aligner.open_gap_score = -1000
    local_aligner.extend_gap_score = -1000

    return [
        Alignment(
            query_st=int(coords[0][0]),
            query_en=int(coords[0][1]),
            template_st=int(coords[1][0]),
            template_en=int(coords[1][1]),
        )
        for coords in (hit.coordinates for hit in local_aligner.align(query, template))
    ]
|
539
|
+
|
540
|
+
|
541
|
+
def get_template_records_from_search(
    template_id: str,
    chain_ids: list[str],
    sequences: dict[str, str],
    template_chain_ids: list[str],
    template_sequences: dict[str, str],
) -> list[TemplateInfo]:
    """Get template records by searching for the best chain assignment.

    Every query chain is scored against every template chain with a global
    alignment; the assignment maximising the total score is then used to
    compute local alignments, each of which becomes one template record.

    Parameters
    ----------
    template_id : str
        The name stored on every produced record.
    chain_ids : list[str]
        The query chain ids.
    sequences : dict[str, str]
        The query sequences, keyed by chain id.
    template_chain_ids : list[str]
        The template chain ids.
    template_sequences : dict[str, str]
        The template sequences, keyed by template chain id.

    Returns
    -------
    list[TemplateInfo]
        The template records; empty when either chain list is empty.

    """
    # Nothing to assign; also avoids handing a non 2-D matrix to the solver.
    if not chain_ids or not template_chain_ids:
        return []

    # Compute pairwise scores
    score_matrix = []
    for chain_id in chain_ids:
        chain_seq = sequences[chain_id]
        score_matrix.append(
            [
                get_global_alignment_score(
                    chain_seq, template_sequences[template_chain_id]
                )
                for template_chain_id in template_chain_ids
            ]
        )

    # Find optimal mapping
    row_ind, col_ind = linear_sum_assignment(score_matrix, maximize=True)

    # Get alignment records
    template_records = []
    for row_idx, col_idx in zip(row_ind, col_ind):
        chain_id = chain_ids[row_idx]
        template_chain_id = template_chain_ids[col_idx]
        alignments = get_local_alignments(
            sequences[chain_id], template_sequences[template_chain_id]
        )
        template_records.extend(
            TemplateInfo(
                name=template_id,
                query_chain=chain_id,
                query_st=alignment.query_st,
                query_en=alignment.query_en,
                template_chain=template_chain_id,
                template_st=alignment.template_st,
                template_en=alignment.template_en,
            )
            for alignment in alignments
        )

    return template_records
|
586
|
+
|
587
|
+
|
588
|
+
def get_template_records_from_matching(
    template_id: str,
    chain_ids: list[str],
    sequences: dict[str, str],
    template_chain_ids: list[str],
    template_sequences: dict[str, str],
) -> list[TemplateInfo]:
    """Get template records from a given chain-to-template-chain matching.

    ``chain_ids`` and ``template_chain_ids`` are paired position by position;
    every local alignment of a pair yields one template record.
    """
    return [
        TemplateInfo(
            name=template_id,
            query_chain=query_chain_id,
            query_st=hit.query_st,
            query_en=hit.query_en,
            template_chain=matched_chain_id,
            template_st=hit.template_st,
            template_en=hit.template_en,
        )
        for query_chain_id, matched_chain_id in zip(chain_ids, template_chain_ids)
        # Align the sequences
        for hit in get_local_alignments(
            sequences[query_chain_id], template_sequences[matched_chain_id]
        )
    ]
|
616
|
+
|
617
|
+
|
618
|
+
def get_mol(ccd: str, mols: dict, moldir: str) -> Mol:
    """Get mol from CCD code.

    Return the entry stored under ``ccd`` in ``mols`` when there is one;
    otherwise load the molecule from ``moldir`` and return it. ``mols`` itself
    is left untouched.
    """
    known = mols.get(ccd)
    if known is not None:
        return known
    return load_molecules(moldir, [ccd])[ccd]
|
628
|
+
|
629
|
+
|
630
|
+
####################################################################################################
|
631
|
+
# PARSING
|
632
|
+
####################################################################################################
|
633
|
+
|
634
|
+
|
635
|
+
def parse_ccd_residue(
    name: str, ref_mol: Mol, res_idx: int, drop_leaving_atoms: bool = False
) -> Optional[ParsedResidue]:
    """Build a ParsedResidue from a reference molecule.

    Hydrogens are ignored. Molecules with a single heavy atom become a
    one-atom residue without bonds or constraints; for all others the heavy
    atoms, their bonds and the geometry constraints are collected.

    Parameters
    ----------
    name: str
        The name of the molecule to parse.
    ref_mol: Mol
        The reference molecule to parse.
    res_idx : int
        The residue index.
    drop_leaving_atoms : bool, optional
        When True, atoms whose "leaving_atom" property is non-zero are
        skipped.

    Returns
    -------
    ParsedResidue, optional
        The output ParsedResidue, if successful.

    """
    unk_chirality = const.chirality_type_ids[const.unk_chirality_type]

    # Check if this is a single heavy atom CCD residue
    if CalcNumHeavyAtoms(ref_mol) == 1:
        # Remove hydrogens
        ref_mol = AllChem.RemoveHs(ref_mol, sanitize=False)

        lone_atom = ref_mol.GetAtoms()[0]
        lone_chirality = const.chirality_type_ids.get(
            str(lone_atom.GetChiralTag()), unk_chirality
        )
        parsed_atom = ParsedAtom(
            name=lone_atom.GetProp("name"),
            element=lone_atom.GetAtomicNum(),
            charge=lone_atom.GetFormalCharge(),
            coords=(0, 0, 0),
            conformer=(0, 0, 0),
            is_present=True,
            chirality=lone_chirality,
        )
        return ParsedResidue(
            name=name,
            type=const.unk_token_ids["PROTEIN"],
            atoms=[parsed_atom],
            bonds=[],
            idx=res_idx,
            orig_idx=None,
            atom_center=0,  # Placeholder, no center
            atom_disto=0,  # Placeholder, no center
            is_standard=False,
            is_present=True,
        )

    # Get reference conformer coordinates
    conformer = get_conformer(ref_mol)

    # Parse each atom in order of the reference mol
    atoms = []
    idx_map = {}  # Reference atom index -> position in `atoms`; used for bonds
    for ref_idx, ref_atom in enumerate(ref_mol.GetAtoms()):
        element = ref_atom.GetAtomicNum()

        # Ignore Hydrogen atoms
        if element == 1:
            continue

        atom_name = ref_atom.GetProp("name")

        # Drop leaving atoms for non-canonical amino acids.
        if drop_leaving_atoms and int(ref_atom.GetProp('leaving_atom')):
            continue

        charge = ref_atom.GetFormalCharge()
        position = conformer.GetAtomPosition(ref_atom.GetIdx())
        chirality_type = const.chirality_type_ids.get(
            str(ref_atom.GetChiralTag()), unk_chirality
        )

        idx_map[ref_idx] = len(atoms)
        atoms.append(
            ParsedAtom(
                name=atom_name,
                element=element,
                charge=charge,
                # No observed coordinates here; a zero placeholder is stored.
                coords=(0, 0, 0),
                conformer=(position.x, position.y, position.z),
                is_present=True,
                chirality=chirality_type,
            )
        )

    # Load bonds
    unk_bond = const.bond_type_ids[const.unk_bond_type]
    bonds = []
    for ref_bond in ref_mol.GetBonds():
        begin_idx = ref_bond.GetBeginAtomIdx()
        end_idx = ref_bond.GetEndAtomIdx()

        # Skip bonds with atoms ignored
        if begin_idx not in idx_map or end_idx not in idx_map:
            continue

        # Store the lower mapped index first.
        first, second = sorted((idx_map[begin_idx], idx_map[end_idx]))
        bond_type = const.bond_type_ids.get(ref_bond.GetBondType().name, unk_bond)
        bonds.append(ParsedBond(first, second, bond_type))

    rdkit_bounds_constraints = compute_geometry_constraints(ref_mol, idx_map)
    chiral_atom_constraints = compute_chiral_atom_constraints(ref_mol, idx_map)
    stereo_bond_constraints = compute_stereo_bond_constraints(ref_mol, idx_map)
    flatness = compute_flatness_constraints(ref_mol, idx_map)
    planar_bond_constraints, planar_ring_5_constraints, planar_ring_6_constraints = (
        flatness
    )

    return ParsedResidue(
        name=name,
        type=const.unk_token_ids["PROTEIN"],
        atoms=atoms,
        bonds=bonds,
        idx=res_idx,
        atom_center=0,
        atom_disto=0,
        orig_idx=None,
        is_standard=False,
        is_present=True,
        rdkit_bounds_constraints=rdkit_bounds_constraints,
        chiral_atom_constraints=chiral_atom_constraints,
        stereo_bond_constraints=stereo_bond_constraints,
        planar_bond_constraints=planar_bond_constraints,
        planar_ring_5_constraints=planar_ring_5_constraints,
        planar_ring_6_constraints=planar_ring_6_constraints,
    )
|
786
|
+
|
787
|
+
|
788
|
+
def parse_polymer(
    sequence: list[str],
    raw_sequence: str,
    entity: str,
    chain_type: str,
    components: dict[str, Mol],
    cyclic: bool,
    mol_dir: Path,
) -> Optional[ParsedChain]:
    """Build a chain object from a list of residue names.

    Every residue name is resolved to a reference molecule through
    ``get_mol``. Names listed in ``const.tokens`` become standard
    residues whose atoms follow the ordering in ``const.ref_atoms``;
    any other name is delegated to ``parse_ccd_residue``. Atom
    coordinates are set to the origin, only the reference conformer
    positions are stored on each atom.

    Parameters
    ----------
    sequence : list[str]
        The residue names of the polymer, one per position.
    raw_sequence : str
        The sequence as provided by the caller, stored unchanged
        on the returned chain.
    entity : str
        The entity id.
    chain_type : str
        The chain type, stored unchanged on the returned chain.
    components : dict[str, Mol]
        The components dictionary, forwarded to ``get_mol``.
    cyclic : bool
        When true, the cyclic period of the chain is the length
        of ``sequence``; otherwise it is 0.
    mol_dir : Path
        The molecule directory, forwarded to ``get_mol``.

    Returns
    -------
    ParsedChain
        The output chain.

    """
    standard_names = set(const.tokens)
    fallback_chirality = const.chirality_type_ids[const.unk_chirality_type]

    residues = []
    for position, raw_name in enumerate(sequence):
        # MSE is mapped to MET before any lookup
        res_name = "MET" if raw_name == "MSE" else raw_name
        is_standard = res_name in standard_names

        # Both kinds of residue start from the same reference molecule
        mol = get_mol(res_name, components, mol_dir)

        if not is_standard:
            # Non-standard residue: let the CCD parser build it
            residues.append(
                parse_ccd_residue(
                    name=res_name,
                    ref_mol=mol,
                    res_idx=position,
                    drop_leaving_atoms=True,
                )
            )
            continue

        # Standard residue: strip hydrogens and fetch the conformer
        mol = AllChem.RemoveHs(mol, sanitize=False)
        conformer = get_conformer(mol)

        # Keep only the atoms named in the constants, in that fixed order
        atom_by_name = {a.GetProp("name"): a for a in mol.GetAtoms()}
        selected = [atom_by_name[n] for n in const.ref_atoms[res_name]]

        atoms: list[ParsedAtom] = []
        for ref_atom in selected:
            name = ref_atom.GetProp("name")
            position_3d = conformer.GetAtomPosition(ref_atom.GetIdx())
            chirality = const.chirality_type_ids.get(
                str(ref_atom.GetChiralTag()), fallback_chirality
            )
            atoms.append(
                ParsedAtom(
                    name=name,
                    element=ref_atom.GetAtomicNum(),
                    charge=ref_atom.GetFormalCharge(),
                    # Coordinates are zeroed; the atom is still marked present
                    coords=(0, 0, 0),
                    conformer=(position_3d.x, position_3d.y, position_3d.z),
                    is_present=True,
                    chirality=chirality,
                )
            )

        center_id = const.res_to_center_atom_id[res_name]
        disto_id = const.res_to_disto_atom_id[res_name]
        residues.append(
            ParsedResidue(
                name=res_name,
                type=const.token_ids[res_name],
                atoms=atoms,
                bonds=[],
                idx=position,
                atom_center=center_id,
                atom_disto=disto_id,
                is_standard=True,
                is_present=True,
                orig_idx=None,
            )
        )

    # A cyclic chain wraps around after its full length
    cyclic_period = len(sequence) if cyclic else 0

    return ParsedChain(
        entity=entity,
        residues=residues,
        type=chain_type,
        cyclic_period=cyclic_period,
        sequence=raw_sequence,
    )
|
917
|
+
|
918
|
+
|
919
|
+
def token_spec_to_ids(
    chain_name, residue_index_or_atom_name, chain_to_idx, atom_idx_map, chains
):
    """Convert a ``[chain, residue index or atom name]`` spec to indices.

    Parameters
    ----------
    chain_name : str
        The name of the chain the token belongs to.
    residue_index_or_atom_name : int or str
        For a non-polymer chain, the atom name; for any other chain,
        the 1-indexed residue position.
    chain_to_idx : dict
        Mapping from chain name to chain index.
    atom_idx_map : dict
        Mapping keyed by ``(chain_name, residue_idx, atom_name)`` whose
        values are 3-tuples; the third element is used as the atom index.
    chains : dict
        Mapping from chain name to an object exposing a ``type`` attribute.

    Returns
    -------
    tuple
        ``(chain index, atom index)`` for a non-polymer chain, otherwise
        ``(chain index, 0-indexed residue position)``.

    Raises
    ------
    KeyError
        If the chain or the atom is missing from the provided mappings.

    """
    if chains[chain_name].type == const.chain_type_ids["NONPOLYMER"]:
        # Non-polymer chains are indexed by atom name. The lookup is
        # hard-wired to residue 0 of the chain.
        _, _, atom_idx = atom_idx_map[(chain_name, 0, residue_index_or_atom_name)]
        return (chain_to_idx[chain_name], atom_idx)

    # Polymer chains are indexed by residue index (1-indexed in the input).
    # This used to append to an undefined `contacts` list, which raised
    # NameError and returned nothing; return the pair like the other branch.
    return (chain_to_idx[chain_name], residue_index_or_atom_name - 1)
|
930
|
+
|
931
|
+
|
932
|
+
def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912
|
933
|
+
name: str,
|
934
|
+
schema: dict,
|
935
|
+
ccd: Mapping[str, Mol],
|
936
|
+
mol_dir: Optional[Path] = None,
|
937
|
+
boltz_2: bool = False,
|
938
|
+
) -> Target:
|
939
|
+
"""Parse a Boltz input yaml / json.
|
940
|
+
|
941
|
+
The input file should be a dictionary with the following format:
|
942
|
+
|
943
|
+
version: 1
|
944
|
+
sequences:
|
945
|
+
- protein:
|
946
|
+
id: A
|
947
|
+
sequence: "MADQLTEEQIAEFKEAFSLF" # or pdb: "1a2k" or pdb: "path/to/file.pdb"
|
948
|
+
msa: path/to/msa1.a3m
|
949
|
+
- protein:
|
950
|
+
id: [B, C]
|
951
|
+
sequence: "AKLSILPWGHC"
|
952
|
+
msa: path/to/msa2.a3m
|
953
|
+
- rna:
|
954
|
+
id: D
|
955
|
+
sequence: "GCAUAGC"
|
956
|
+
- ligand:
|
957
|
+
id: E
|
958
|
+
smiles: "CC1=CC=CC=C1"
|
959
|
+
constraints:
|
960
|
+
- bond:
|
961
|
+
atom1: [A, 1, CA]
|
962
|
+
atom2: [A, 2, N]
|
963
|
+
- pocket:
|
964
|
+
binder: E
|
965
|
+
contacts: [[B, 1], [B, 2]]
|
966
|
+
max_distance: 6
|
967
|
+
- contact:
|
968
|
+
token1: [A, 1]
|
969
|
+
token2: [B, 1]
|
970
|
+
max_distance: 6
|
971
|
+
templates:
|
972
|
+
- cif: path/to/template.cif
|
973
|
+
properties:
|
974
|
+
- affinity:
|
975
|
+
binder: E
|
976
|
+
|
977
|
+
Parameters
|
978
|
+
----------
|
979
|
+
name : str
|
980
|
+
A name for the input.
|
981
|
+
schema : dict
|
982
|
+
The input schema.
|
983
|
+
components : dict
|
984
|
+
Dictionary of CCD components.
|
985
|
+
mol_dir: Path
|
986
|
+
Path to the directory containing the molecules.
|
987
|
+
boltz2: bool
|
988
|
+
Whether to parse the input for Boltz2.
|
989
|
+
|
990
|
+
Returns
|
991
|
+
-------
|
992
|
+
Target
|
993
|
+
The parsed target.
|
994
|
+
|
995
|
+
"""
|
996
|
+
# Assert version 1
|
997
|
+
version = schema.get("version", 1)
|
998
|
+
if version != 1:
|
999
|
+
msg = f"Invalid version {version} in input!"
|
1000
|
+
raise ValueError(msg)
|
1001
|
+
|
1002
|
+
# Disable rdkit warnings
|
1003
|
+
blocker = rdBase.BlockLogs() # noqa: F841
|
1004
|
+
|
1005
|
+
# First group items that have the same type, sequence and modifications
|
1006
|
+
items_to_group = {}
|
1007
|
+
chain_name_to_entity_type = {}
|
1008
|
+
|
1009
|
+
for item in schema["sequences"]:
|
1010
|
+
# Get entity type
|
1011
|
+
entity_type = next(iter(item.keys())).lower()
|
1012
|
+
if entity_type not in {"protein", "dna", "rna", "ligand"}:
|
1013
|
+
msg = f"Invalid entity type: {entity_type}"
|
1014
|
+
raise ValueError(msg)
|
1015
|
+
|
1016
|
+
# Get sequence or PDB
|
1017
|
+
if entity_type in {"protein", "dna", "rna"}:
|
1018
|
+
if "sequence" in item[entity_type]:
|
1019
|
+
seq = str(item[entity_type]["sequence"])
|
1020
|
+
elif "pdb" in item[entity_type]:
|
1021
|
+
pdb_input = str(item[entity_type]["pdb"])
|
1022
|
+
# Check if it's a PDB code (4 characters) or a file path
|
1023
|
+
if len(pdb_input) == 4 and pdb_input.isalnum():
|
1024
|
+
# It's a PDB code, check cache first
|
1025
|
+
cache_dir = Path(os.environ.get("BOLTZ_CACHE", "~/.boltz")).expanduser()
|
1026
|
+
pdb_cache_dir = cache_dir / "pdb"
|
1027
|
+
pdb_cache_dir.mkdir(parents=True, exist_ok=True)
|
1028
|
+
|
1029
|
+
pdb_cache_file = pdb_cache_dir / f"{pdb_input.lower()}.pdb"
|
1030
|
+
|
1031
|
+
if pdb_cache_file.exists():
|
1032
|
+
# Use cached file
|
1033
|
+
with pdb_cache_file.open("r") as f:
|
1034
|
+
pdb_data = f.read()
|
1035
|
+
else:
|
1036
|
+
# Download and cache
|
1037
|
+
import urllib.request
|
1038
|
+
pdb_url = f"https://files.rcsb.org/download/{pdb_input.lower()}.pdb"
|
1039
|
+
try:
|
1040
|
+
with urllib.request.urlopen(pdb_url) as response:
|
1041
|
+
pdb_data = response.read().decode()
|
1042
|
+
# Cache the downloaded data
|
1043
|
+
with pdb_cache_file.open("w") as f:
|
1044
|
+
f.write(pdb_data)
|
1045
|
+
except Exception as e:
|
1046
|
+
msg = f"Failed to download PDB {pdb_input}: {str(e)}"
|
1047
|
+
raise RuntimeError(msg) from e
|
1048
|
+
else:
|
1049
|
+
# It's a file path
|
1050
|
+
pdb_path = Path(pdb_input)
|
1051
|
+
if not pdb_path.exists():
|
1052
|
+
msg = f"PDB file not found: {pdb_path}"
|
1053
|
+
raise FileNotFoundError(msg)
|
1054
|
+
with pdb_path.open("r") as f:
|
1055
|
+
pdb_data = f.read()
|
1056
|
+
|
1057
|
+
# Parse PDB data
|
1058
|
+
from Bio.PDB import PDBParser
|
1059
|
+
from io import StringIO
|
1060
|
+
parser = PDBParser()
|
1061
|
+
structure = parser.get_structure("protein", StringIO(pdb_data))
|
1062
|
+
|
1063
|
+
# Extract sequence
|
1064
|
+
seq = ""
|
1065
|
+
for model in structure:
|
1066
|
+
for chain in model:
|
1067
|
+
for residue in chain:
|
1068
|
+
if residue.id[0] == " ": # Only standard residues
|
1069
|
+
seq += residue.resname
|
1070
|
+
seq = "".join(seq)
|
1071
|
+
else:
|
1072
|
+
msg = "Protein must have either 'sequence' or 'pdb' field"
|
1073
|
+
raise ValueError(msg)
|
1074
|
+
elif entity_type == "ligand":
|
1075
|
+
assert "smiles" in item[entity_type] or "ccd" in item[entity_type]
|
1076
|
+
assert "smiles" not in item[entity_type] or "ccd" not in item[entity_type]
|
1077
|
+
if "smiles" in item[entity_type]:
|
1078
|
+
seq = str(item[entity_type]["smiles"])
|
1079
|
+
else:
|
1080
|
+
seq = str(item[entity_type]["ccd"])
|
1081
|
+
|
1082
|
+
# Group items by entity
|
1083
|
+
items_to_group.setdefault((entity_type, seq), []).append(item)
|
1084
|
+
|
1085
|
+
# Map chain names to entity types
|
1086
|
+
chain_names = item[entity_type]["id"]
|
1087
|
+
chain_names = [chain_names] if isinstance(chain_names, str) else chain_names
|
1088
|
+
for chain_name in chain_names:
|
1089
|
+
chain_name_to_entity_type[chain_name] = entity_type
|
1090
|
+
|
1091
|
+
# Check if any affinity ligand is present
|
1092
|
+
affinity_ligands = set()
|
1093
|
+
properties = schema.get("properties", [])
|
1094
|
+
if properties and not boltz_2:
|
1095
|
+
msg = "Affinity prediction is only supported for Boltz2!"
|
1096
|
+
raise ValueError(msg)
|
1097
|
+
|
1098
|
+
for prop in properties:
|
1099
|
+
prop_type = next(iter(prop.keys())).lower()
|
1100
|
+
if prop_type == "affinity":
|
1101
|
+
binder = prop["affinity"]["binder"]
|
1102
|
+
if not isinstance(binder, str):
|
1103
|
+
# TODO: support multi residue ligands and ccd's
|
1104
|
+
msg = "Binder must be a single chain."
|
1105
|
+
raise ValueError(msg)
|
1106
|
+
|
1107
|
+
if binder not in chain_name_to_entity_type:
|
1108
|
+
msg = f"Could not find binder with name {binder} in the input!"
|
1109
|
+
raise ValueError(msg)
|
1110
|
+
|
1111
|
+
if chain_name_to_entity_type[binder] != "ligand":
|
1112
|
+
msg = (
|
1113
|
+
f"Chain {binder} is not a ligand! "
|
1114
|
+
"Affinity is currently only supported for ligands."
|
1115
|
+
)
|
1116
|
+
raise ValueError(msg)
|
1117
|
+
|
1118
|
+
affinity_ligands.add(binder)
|
1119
|
+
|
1120
|
+
# Check only one affinity ligand is present
|
1121
|
+
if len(affinity_ligands) > 1:
|
1122
|
+
msg = "Only one affinity ligand is currently supported!"
|
1123
|
+
raise ValueError(msg)
|
1124
|
+
|
1125
|
+
# Go through entities and parse them
|
1126
|
+
extra_mols: dict[str, Mol] = {}
|
1127
|
+
chains: dict[str, ParsedChain] = {}
|
1128
|
+
chain_to_msa: dict[str, str] = {}
|
1129
|
+
entity_to_seq: dict[str, str] = {}
|
1130
|
+
is_msa_custom = False
|
1131
|
+
is_msa_auto = False
|
1132
|
+
ligand_id = 1
|
1133
|
+
for entity_id, items in enumerate(items_to_group.values()):
|
1134
|
+
# Get entity type and sequence
|
1135
|
+
entity_type = next(iter(items[0].keys())).lower()
|
1136
|
+
|
1137
|
+
# Get ids
|
1138
|
+
ids = []
|
1139
|
+
for item in items:
|
1140
|
+
if isinstance(item[entity_type]["id"], str):
|
1141
|
+
ids.append(item[entity_type]["id"])
|
1142
|
+
elif isinstance(item[entity_type]["id"], list):
|
1143
|
+
ids.extend(item[entity_type]["id"])
|
1144
|
+
|
1145
|
+
# Check if any affinity ligand is present
|
1146
|
+
if len(ids) == 1:
|
1147
|
+
affinity = ids[0] in affinity_ligands
|
1148
|
+
elif (len(ids) > 1) and any(x in affinity_ligands for x in ids):
|
1149
|
+
msg = "Cannot compute affinity for a ligand that has multiple copies!"
|
1150
|
+
raise ValueError(msg)
|
1151
|
+
else:
|
1152
|
+
affinity = False
|
1153
|
+
|
1154
|
+
# Ensure all the items share the same msa
|
1155
|
+
msa = -1
|
1156
|
+
if entity_type == "protein":
|
1157
|
+
# Get the msa, default to 0, meaning auto-generated
|
1158
|
+
msa = items[0][entity_type].get("msa", 0)
|
1159
|
+
if (msa is None) or (msa == ""):
|
1160
|
+
msa = 0
|
1161
|
+
|
1162
|
+
# Check if all MSAs are the same within the same entity
|
1163
|
+
for item in items:
|
1164
|
+
item_msa = item[entity_type].get("msa", 0)
|
1165
|
+
if (item_msa is None) or (item_msa == ""):
|
1166
|
+
item_msa = 0
|
1167
|
+
|
1168
|
+
if item_msa != msa:
|
1169
|
+
msg = "All proteins with the same sequence must share the same MSA!"
|
1170
|
+
raise ValueError(msg)
|
1171
|
+
|
1172
|
+
# Set the MSA, warn if passed in single-sequence mode
|
1173
|
+
if msa == "empty":
|
1174
|
+
msa = -1
|
1175
|
+
msg = (
|
1176
|
+
"Found explicit empty MSA for some proteins, will run "
|
1177
|
+
"these in single sequence mode. Keep in mind that the "
|
1178
|
+
"model predictions will be suboptimal without an MSA."
|
1179
|
+
)
|
1180
|
+
click.echo(msg)
|
1181
|
+
|
1182
|
+
if msa not in (0, -1):
|
1183
|
+
is_msa_custom = True
|
1184
|
+
elif msa == 0:
|
1185
|
+
is_msa_auto = True
|
1186
|
+
|
1187
|
+
# Parse a polymer
|
1188
|
+
if entity_type in {"protein", "dna", "rna"}:
|
1189
|
+
# Get token map
|
1190
|
+
if entity_type == "rna":
|
1191
|
+
token_map = const.rna_letter_to_token
|
1192
|
+
elif entity_type == "dna":
|
1193
|
+
token_map = const.dna_letter_to_token
|
1194
|
+
elif entity_type == "protein":
|
1195
|
+
token_map = const.prot_letter_to_token
|
1196
|
+
else:
|
1197
|
+
msg = f"Unknown polymer type: {entity_type}"
|
1198
|
+
raise ValueError(msg)
|
1199
|
+
|
1200
|
+
# Get polymer info
|
1201
|
+
chain_type = const.chain_type_ids[entity_type.upper()]
|
1202
|
+
unk_token = const.unk_token[entity_type.upper()]
|
1203
|
+
|
1204
|
+
# Extract sequence
|
1205
|
+
raw_seq = items[0][entity_type]["sequence"]
|
1206
|
+
entity_to_seq[entity_id] = raw_seq
|
1207
|
+
|
1208
|
+
# Convert sequence to tokens
|
1209
|
+
seq = [token_map.get(c, unk_token) for c in list(raw_seq)]
|
1210
|
+
|
1211
|
+
# Apply modifications
|
1212
|
+
for mod in items[0][entity_type].get("modifications", []):
|
1213
|
+
code = mod["ccd"]
|
1214
|
+
idx = mod["position"] - 1 # 1-indexed
|
1215
|
+
seq[idx] = code
|
1216
|
+
|
1217
|
+
cyclic = items[0][entity_type].get("cyclic", False)
|
1218
|
+
|
1219
|
+
# Parse a polymer
|
1220
|
+
parsed_chain = parse_polymer(
|
1221
|
+
sequence=seq,
|
1222
|
+
raw_sequence=raw_seq,
|
1223
|
+
entity=entity_id,
|
1224
|
+
chain_type=chain_type,
|
1225
|
+
components=ccd,
|
1226
|
+
cyclic=cyclic,
|
1227
|
+
mol_dir=mol_dir,
|
1228
|
+
)
|
1229
|
+
|
1230
|
+
# Parse a non-polymer
|
1231
|
+
elif (entity_type == "ligand") and "ccd" in (items[0][entity_type]):
|
1232
|
+
seq = items[0][entity_type]["ccd"]
|
1233
|
+
|
1234
|
+
if isinstance(seq, str):
|
1235
|
+
seq = [seq]
|
1236
|
+
|
1237
|
+
if affinity and len(seq) > 1:
|
1238
|
+
msg = "Cannot compute affinity for multi residue ligands!"
|
1239
|
+
raise ValueError(msg)
|
1240
|
+
|
1241
|
+
residues = []
|
1242
|
+
affinity_mw = None
|
1243
|
+
for res_idx, code in enumerate(seq):
|
1244
|
+
# Get mol
|
1245
|
+
ref_mol = get_mol(code, ccd, mol_dir)
|
1246
|
+
|
1247
|
+
if affinity:
|
1248
|
+
affinity_mw = AllChem.Descriptors.MolWt(ref_mol)
|
1249
|
+
|
1250
|
+
# Parse residue
|
1251
|
+
residue = parse_ccd_residue(
|
1252
|
+
name=code,
|
1253
|
+
ref_mol=ref_mol,
|
1254
|
+
res_idx=res_idx,
|
1255
|
+
)
|
1256
|
+
residues.append(residue)
|
1257
|
+
|
1258
|
+
# Create multi ligand chain
|
1259
|
+
parsed_chain = ParsedChain(
|
1260
|
+
entity=entity_id,
|
1261
|
+
residues=residues,
|
1262
|
+
type=const.chain_type_ids["NONPOLYMER"],
|
1263
|
+
cyclic_period=0,
|
1264
|
+
sequence=None,
|
1265
|
+
affinity=affinity,
|
1266
|
+
affinity_mw=affinity_mw,
|
1267
|
+
)
|
1268
|
+
|
1269
|
+
assert not items[0][entity_type].get(
|
1270
|
+
"cyclic", False
|
1271
|
+
), "Cyclic flag is not supported for ligands"
|
1272
|
+
|
1273
|
+
elif (entity_type == "ligand") and ("smiles" in items[0][entity_type]):
|
1274
|
+
seq = items[0][entity_type]["smiles"]
|
1275
|
+
|
1276
|
+
if affinity:
|
1277
|
+
seq = standardize(seq)
|
1278
|
+
|
1279
|
+
mol = AllChem.MolFromSmiles(seq)
|
1280
|
+
mol = AllChem.AddHs(mol)
|
1281
|
+
|
1282
|
+
# Set atom names
|
1283
|
+
canonical_order = AllChem.CanonicalRankAtoms(mol)
|
1284
|
+
for atom, can_idx in zip(mol.GetAtoms(), canonical_order):
|
1285
|
+
atom_name = atom.GetSymbol().upper() + str(can_idx + 1)
|
1286
|
+
if len(atom_name) > 4:
|
1287
|
+
msg = (
|
1288
|
+
f"{seq} has an atom with a name longer than "
|
1289
|
+
f"4 characters: {atom_name}."
|
1290
|
+
)
|
1291
|
+
raise ValueError(msg)
|
1292
|
+
atom.SetProp("name", atom_name)
|
1293
|
+
|
1294
|
+
success = compute_3d_conformer(mol)
|
1295
|
+
if not success:
|
1296
|
+
msg = f"Failed to compute 3D conformer for {seq}"
|
1297
|
+
raise ValueError(msg)
|
1298
|
+
|
1299
|
+
mol_no_h = AllChem.RemoveHs(mol, sanitize=False)
|
1300
|
+
affinity_mw = AllChem.Descriptors.MolWt(mol_no_h) if affinity else None
|
1301
|
+
extra_mols[f"LIG{ligand_id}"] = mol_no_h
|
1302
|
+
residue = parse_ccd_residue(
|
1303
|
+
name=f"LIG{ligand_id}",
|
1304
|
+
ref_mol=mol,
|
1305
|
+
res_idx=0,
|
1306
|
+
)
|
1307
|
+
|
1308
|
+
ligand_id += 1
|
1309
|
+
parsed_chain = ParsedChain(
|
1310
|
+
entity=entity_id,
|
1311
|
+
residues=[residue],
|
1312
|
+
type=const.chain_type_ids["NONPOLYMER"],
|
1313
|
+
cyclic_period=0,
|
1314
|
+
sequence=None,
|
1315
|
+
affinity=affinity,
|
1316
|
+
affinity_mw=affinity_mw,
|
1317
|
+
)
|
1318
|
+
|
1319
|
+
assert not items[0][entity_type].get(
|
1320
|
+
"cyclic", False
|
1321
|
+
), "Cyclic flag is not supported for ligands"
|
1322
|
+
|
1323
|
+
else:
|
1324
|
+
msg = f"Invalid entity type: {entity_type}"
|
1325
|
+
raise ValueError(msg)
|
1326
|
+
|
1327
|
+
# Add as many chains as provided ids
|
1328
|
+
for item in items:
|
1329
|
+
ids = item[entity_type]["id"]
|
1330
|
+
if isinstance(ids, str):
|
1331
|
+
ids = [ids]
|
1332
|
+
for chain_name in ids:
|
1333
|
+
chains[chain_name] = parsed_chain
|
1334
|
+
chain_to_msa[chain_name] = msa
|
1335
|
+
|
1336
|
+
# Check if msa is custom or auto
|
1337
|
+
if is_msa_custom and is_msa_auto:
|
1338
|
+
msg = "Cannot mix custom and auto-generated MSAs in the same input!"
|
1339
|
+
raise ValueError(msg)
|
1340
|
+
|
1341
|
+
# If no chains parsed fail
|
1342
|
+
if not chains:
|
1343
|
+
msg = "No chains parsed!"
|
1344
|
+
raise ValueError(msg)
|
1345
|
+
|
1346
|
+
# Create tables
|
1347
|
+
atom_data = []
|
1348
|
+
bond_data = []
|
1349
|
+
res_data = []
|
1350
|
+
chain_data = []
|
1351
|
+
protein_chains = set()
|
1352
|
+
affinity_info = None
|
1353
|
+
|
1354
|
+
rdkit_bounds_constraint_data = []
|
1355
|
+
chiral_atom_constraint_data = []
|
1356
|
+
stereo_bond_constraint_data = []
|
1357
|
+
planar_bond_constraint_data = []
|
1358
|
+
planar_ring_5_constraint_data = []
|
1359
|
+
planar_ring_6_constraint_data = []
|
1360
|
+
|
1361
|
+
# Convert parsed chains to tables
|
1362
|
+
atom_idx = 0
|
1363
|
+
res_idx = 0
|
1364
|
+
asym_id = 0
|
1365
|
+
sym_count = {}
|
1366
|
+
chain_to_idx = {}
|
1367
|
+
|
1368
|
+
# Keep a mapping of (chain_name, residue_idx, atom_name) to atom_idx
|
1369
|
+
atom_idx_map = {}
|
1370
|
+
|
1371
|
+
for asym_id, (chain_name, chain) in enumerate(chains.items()):
|
1372
|
+
# Compute number of atoms and residues
|
1373
|
+
res_num = len(chain.residues)
|
1374
|
+
atom_num = sum(len(res.atoms) for res in chain.residues)
|
1375
|
+
|
1376
|
+
# Save protein chains for later
|
1377
|
+
if chain.type == const.chain_type_ids["PROTEIN"]:
|
1378
|
+
protein_chains.add(chain_name)
|
1379
|
+
|
1380
|
+
# Add affinity info
|
1381
|
+
if chain.affinity and affinity_info is not None:
|
1382
|
+
msg = "Cannot compute affinity for multiple ligands!"
|
1383
|
+
raise ValueError(msg)
|
1384
|
+
|
1385
|
+
if chain.affinity:
|
1386
|
+
affinity_info = AffinityInfo(
|
1387
|
+
chain_id=asym_id,
|
1388
|
+
mw=chain.affinity_mw,
|
1389
|
+
)
|
1390
|
+
|
1391
|
+
# Find all copies of this chain in the assembly
|
1392
|
+
entity_id = int(chain.entity)
|
1393
|
+
sym_id = sym_count.get(entity_id, 0)
|
1394
|
+
chain_data.append(
|
1395
|
+
(
|
1396
|
+
chain_name,
|
1397
|
+
chain.type,
|
1398
|
+
entity_id,
|
1399
|
+
sym_id,
|
1400
|
+
asym_id,
|
1401
|
+
atom_idx,
|
1402
|
+
atom_num,
|
1403
|
+
res_idx,
|
1404
|
+
res_num,
|
1405
|
+
chain.cyclic_period,
|
1406
|
+
)
|
1407
|
+
)
|
1408
|
+
chain_to_idx[chain_name] = asym_id
|
1409
|
+
sym_count[entity_id] = sym_id + 1
|
1410
|
+
|
1411
|
+
# Add residue, atom, bond, data
|
1412
|
+
for res in chain.residues:
|
1413
|
+
atom_center = atom_idx + res.atom_center
|
1414
|
+
atom_disto = atom_idx + res.atom_disto
|
1415
|
+
res_data.append(
|
1416
|
+
(
|
1417
|
+
res.name,
|
1418
|
+
res.type,
|
1419
|
+
res.idx,
|
1420
|
+
atom_idx,
|
1421
|
+
len(res.atoms),
|
1422
|
+
atom_center,
|
1423
|
+
atom_disto,
|
1424
|
+
res.is_standard,
|
1425
|
+
res.is_present,
|
1426
|
+
)
|
1427
|
+
)
|
1428
|
+
|
1429
|
+
if res.rdkit_bounds_constraints is not None:
|
1430
|
+
for constraint in res.rdkit_bounds_constraints:
|
1431
|
+
rdkit_bounds_constraint_data.append( # noqa: PERF401
|
1432
|
+
(
|
1433
|
+
tuple(
|
1434
|
+
c_atom_idx + atom_idx
|
1435
|
+
for c_atom_idx in constraint.atom_idxs
|
1436
|
+
),
|
1437
|
+
constraint.is_bond,
|
1438
|
+
constraint.is_angle,
|
1439
|
+
constraint.upper_bound,
|
1440
|
+
constraint.lower_bound,
|
1441
|
+
)
|
1442
|
+
)
|
1443
|
+
if res.chiral_atom_constraints is not None:
|
1444
|
+
for constraint in res.chiral_atom_constraints:
|
1445
|
+
chiral_atom_constraint_data.append( # noqa: PERF401
|
1446
|
+
(
|
1447
|
+
tuple(
|
1448
|
+
c_atom_idx + atom_idx
|
1449
|
+
for c_atom_idx in constraint.atom_idxs
|
1450
|
+
),
|
1451
|
+
constraint.is_reference,
|
1452
|
+
constraint.is_r,
|
1453
|
+
)
|
1454
|
+
)
|
1455
|
+
if res.stereo_bond_constraints is not None:
|
1456
|
+
for constraint in res.stereo_bond_constraints:
|
1457
|
+
stereo_bond_constraint_data.append( # noqa: PERF401
|
1458
|
+
(
|
1459
|
+
tuple(
|
1460
|
+
c_atom_idx + atom_idx
|
1461
|
+
for c_atom_idx in constraint.atom_idxs
|
1462
|
+
),
|
1463
|
+
constraint.is_check,
|
1464
|
+
constraint.is_e,
|
1465
|
+
)
|
1466
|
+
)
|
1467
|
+
if res.planar_bond_constraints is not None:
|
1468
|
+
for constraint in res.planar_bond_constraints:
|
1469
|
+
planar_bond_constraint_data.append( # noqa: PERF401
|
1470
|
+
(
|
1471
|
+
tuple(
|
1472
|
+
c_atom_idx + atom_idx
|
1473
|
+
for c_atom_idx in constraint.atom_idxs
|
1474
|
+
),
|
1475
|
+
)
|
1476
|
+
)
|
1477
|
+
if res.planar_ring_5_constraints is not None:
|
1478
|
+
for constraint in res.planar_ring_5_constraints:
|
1479
|
+
planar_ring_5_constraint_data.append( # noqa: PERF401
|
1480
|
+
(
|
1481
|
+
tuple(
|
1482
|
+
c_atom_idx + atom_idx
|
1483
|
+
for c_atom_idx in constraint.atom_idxs
|
1484
|
+
),
|
1485
|
+
)
|
1486
|
+
)
|
1487
|
+
if res.planar_ring_6_constraints is not None:
|
1488
|
+
for constraint in res.planar_ring_6_constraints:
|
1489
|
+
planar_ring_6_constraint_data.append( # noqa: PERF401
|
1490
|
+
(
|
1491
|
+
tuple(
|
1492
|
+
c_atom_idx + atom_idx
|
1493
|
+
for c_atom_idx in constraint.atom_idxs
|
1494
|
+
),
|
1495
|
+
)
|
1496
|
+
)
|
1497
|
+
|
1498
|
+
for bond in res.bonds:
|
1499
|
+
atom_1 = atom_idx + bond.atom_1
|
1500
|
+
atom_2 = atom_idx + bond.atom_2
|
1501
|
+
bond_data.append(
|
1502
|
+
(
|
1503
|
+
asym_id,
|
1504
|
+
asym_id,
|
1505
|
+
res_idx,
|
1506
|
+
res_idx,
|
1507
|
+
atom_1,
|
1508
|
+
atom_2,
|
1509
|
+
bond.type,
|
1510
|
+
)
|
1511
|
+
)
|
1512
|
+
|
1513
|
+
for atom in res.atoms:
|
1514
|
+
# Add atom to map
|
1515
|
+
atom_idx_map[(chain_name, res.idx, atom.name)] = (
|
1516
|
+
asym_id,
|
1517
|
+
res_idx,
|
1518
|
+
atom_idx,
|
1519
|
+
)
|
1520
|
+
|
1521
|
+
# Add atom to data
|
1522
|
+
atom_data.append(
|
1523
|
+
(
|
1524
|
+
atom.name,
|
1525
|
+
atom.element,
|
1526
|
+
atom.charge,
|
1527
|
+
atom.coords,
|
1528
|
+
atom.conformer,
|
1529
|
+
atom.is_present,
|
1530
|
+
atom.chirality,
|
1531
|
+
)
|
1532
|
+
)
|
1533
|
+
atom_idx += 1
|
1534
|
+
|
1535
|
+
res_idx += 1
|
1536
|
+
|
1537
|
+
# Parse constraints
|
1538
|
+
connections = []
|
1539
|
+
pocket_constraints = []
|
1540
|
+
contact_constraints = []
|
1541
|
+
constraints = schema.get("constraints", [])
|
1542
|
+
for constraint in constraints:
|
1543
|
+
if "bond" in constraint:
|
1544
|
+
if "atom1" not in constraint["bond"] or "atom2" not in constraint["bond"]:
|
1545
|
+
msg = f"Bond constraint was not properly specified"
|
1546
|
+
raise ValueError(msg)
|
1547
|
+
|
1548
|
+
c1, r1, a1 = tuple(constraint["bond"]["atom1"])
|
1549
|
+
c2, r2, a2 = tuple(constraint["bond"]["atom2"])
|
1550
|
+
c1, r1, a1 = atom_idx_map[(c1, r1 - 1, a1)] # 1-indexed
|
1551
|
+
c2, r2, a2 = atom_idx_map[(c2, r2 - 1, a2)] # 1-indexed
|
1552
|
+
connections.append((c1, c2, r1, r2, a1, a2))
|
1553
|
+
elif "pocket" in constraint:
|
1554
|
+
if (
|
1555
|
+
"binder" not in constraint["pocket"]
|
1556
|
+
or "contacts" not in constraint["pocket"]
|
1557
|
+
):
|
1558
|
+
msg = f"Pocket constraint was not properly specified"
|
1559
|
+
raise ValueError(msg)
|
1560
|
+
|
1561
|
+
if len(pocket_constraints) > 0 and not boltz_2:
|
1562
|
+
msg = f"Only one pocket binders is supported in Boltz-1!"
|
1563
|
+
raise ValueError(msg)
|
1564
|
+
|
1565
|
+
max_distance = constraint["pocket"].get("max_distance", 6.0)
|
1566
|
+
if max_distance != 6.0 and not boltz_2:
|
1567
|
+
msg = f"Max distance != 6.0 is not supported in Boltz-1!"
|
1568
|
+
raise ValueError(msg)
|
1569
|
+
|
1570
|
+
binder = constraint["pocket"]["binder"]
|
1571
|
+
binder = chain_to_idx[binder]
|
1572
|
+
|
1573
|
+
contacts = []
|
1574
|
+
for chain_name, residue_index_or_atom_name in constraint["pocket"][
|
1575
|
+
"contacts"
|
1576
|
+
]:
|
1577
|
+
if chains[chain_name].type == const.chain_type_ids["NONPOLYMER"]:
|
1578
|
+
# Non-polymer chains are indexed by atom name
|
1579
|
+
_, _, atom_idx = atom_idx_map[
|
1580
|
+
(chain_name, 0, residue_index_or_atom_name)
|
1581
|
+
]
|
1582
|
+
contact = (chain_to_idx[chain_name], atom_idx)
|
1583
|
+
else:
|
1584
|
+
# Polymer chains are indexed by residue index
|
1585
|
+
contact = (chain_to_idx[chain_name], residue_index_or_atom_name - 1)
|
1586
|
+
contacts.append(contact)
|
1587
|
+
|
1588
|
+
pocket_constraints.append((binder, contacts, max_distance))
|
1589
|
+
elif "contact" in constraint:
|
1590
|
+
if (
|
1591
|
+
"token1" not in constraint["contact"]
|
1592
|
+
or "token2" not in constraint["contact"]
|
1593
|
+
):
|
1594
|
+
msg = f"Contact constraint was not properly specified"
|
1595
|
+
raise ValueError(msg)
|
1596
|
+
|
1597
|
+
if not boltz_2:
|
1598
|
+
msg = f"Contact constraint is not supported in Boltz-1!"
|
1599
|
+
raise ValueError(msg)
|
1600
|
+
|
1601
|
+
max_distance = constraint["contact"].get("max_distance", 6.0)
|
1602
|
+
|
1603
|
+
chain_name1, residue_index_or_atom_name1 = constraint["contact"]["token1"]
|
1604
|
+
if chains[chain_name1].type == const.chain_type_ids["NONPOLYMER"]:
|
1605
|
+
# Non-polymer chains are indexed by atom name
|
1606
|
+
_, _, atom_idx = atom_idx_map[
|
1607
|
+
(chain_name1, 0, residue_index_or_atom_name1)
|
1608
|
+
]
|
1609
|
+
token1 = (chain_to_idx[chain_name1], atom_idx)
|
1610
|
+
else:
|
1611
|
+
# Polymer chains are indexed by residue index
|
1612
|
+
token1 = (chain_to_idx[chain_name1], residue_index_or_atom_name1 - 1)
|
1613
|
+
|
1614
|
+
pocket_constraints.append((binder, contacts, max_distance))
|
1615
|
+
else:
|
1616
|
+
msg = f"Invalid constraint: {constraint}"
|
1617
|
+
raise ValueError(msg)
|
1618
|
+
|
1619
|
+
# Get protein sequences in this YAML
|
1620
|
+
protein_seqs = {name: chains[name].sequence for name in protein_chains}
|
1621
|
+
|
1622
|
+
# Parse templates
|
1623
|
+
template_schema = schema.get("templates", [])
|
1624
|
+
if template_schema and not boltz_2:
|
1625
|
+
msg = "Templates are not supported in Boltz 1.0!"
|
1626
|
+
raise ValueError(msg)
|
1627
|
+
|
1628
|
+
templates = {}
|
1629
|
+
template_records = []
|
1630
|
+
for template in template_schema:
|
1631
|
+
if "cif" not in template:
|
1632
|
+
msg = "Template was not properly specified, missing CIF path!"
|
1633
|
+
raise ValueError(msg)
|
1634
|
+
|
1635
|
+
path = template["cif"]
|
1636
|
+
template_id = Path(path).stem
|
1637
|
+
chain_ids = template.get("chain_id", None)
|
1638
|
+
template_chain_ids = template.get("template_id", None)
|
1639
|
+
|
1640
|
+
# Check validity of input
|
1641
|
+
matched = False
|
1642
|
+
|
1643
|
+
if chain_ids is not None and not isinstance(chain_ids, list):
|
1644
|
+
chain_ids = [chain_ids]
|
1645
|
+
if template_chain_ids is not None and not isinstance(template_chain_ids, list):
|
1646
|
+
template_chain_ids = [template_chain_ids]
|
1647
|
+
|
1648
|
+
if (
|
1649
|
+
template_chain_ids is not None
|
1650
|
+
and chain_ids is not None
|
1651
|
+
and len(template_chain_ids) != len(chain_ids)
|
1652
|
+
):
|
1653
|
+
matched = True
|
1654
|
+
if len(template_chain_ids) != len(chain_ids):
|
1655
|
+
msg = (
|
1656
|
+
"When providing both the chain_id and template_id, the number of"
|
1657
|
+
"template_ids provided must match the number of chain_ids!"
|
1658
|
+
)
|
1659
|
+
raise ValueError(msg)
|
1660
|
+
|
1661
|
+
# Get relevant chains ids
|
1662
|
+
if chain_ids is None:
|
1663
|
+
chain_ids = list(protein_chains)
|
1664
|
+
|
1665
|
+
for chain_id in chain_ids:
|
1666
|
+
if chain_id not in protein_chains:
|
1667
|
+
msg = (
|
1668
|
+
f"Chain {chain_id} assigned for template"
|
1669
|
+
f"{template_id} is not one of the protein chains!"
|
1670
|
+
)
|
1671
|
+
raise ValueError(msg)
|
1672
|
+
|
1673
|
+
# Get relevant template chain ids
|
1674
|
+
parsed_template = parse_mmcif(
|
1675
|
+
path,
|
1676
|
+
mols=ccd,
|
1677
|
+
moldir=mol_dir,
|
1678
|
+
use_assembly=False,
|
1679
|
+
compute_interfaces=False,
|
1680
|
+
)
|
1681
|
+
template_proteins = {
|
1682
|
+
str(c["name"])
|
1683
|
+
for c in parsed_template.data.chains
|
1684
|
+
if c["mol_type"] == const.chain_type_ids["PROTEIN"]
|
1685
|
+
}
|
1686
|
+
if template_chain_ids is None:
|
1687
|
+
template_chain_ids = list(template_proteins)
|
1688
|
+
|
1689
|
+
for chain_id in template_chain_ids:
|
1690
|
+
if chain_id not in template_proteins:
|
1691
|
+
msg = (
|
1692
|
+
f"Template chain {chain_id} assigned for template"
|
1693
|
+
f"{template_id} is not one of the protein chains!"
|
1694
|
+
)
|
1695
|
+
raise ValueError(msg)
|
1696
|
+
|
1697
|
+
# Compute template records
|
1698
|
+
if matched:
|
1699
|
+
template_records.extend(
|
1700
|
+
get_template_records_from_matching(
|
1701
|
+
template_id=template_id,
|
1702
|
+
chain_ids=chain_ids,
|
1703
|
+
sequences=protein_seqs,
|
1704
|
+
template_chain_ids=template_chain_ids,
|
1705
|
+
template_sequences=parsed_template.sequences,
|
1706
|
+
)
|
1707
|
+
)
|
1708
|
+
else:
|
1709
|
+
template_records.extend(
|
1710
|
+
get_template_records_from_search(
|
1711
|
+
template_id=template_id,
|
1712
|
+
chain_ids=chain_ids,
|
1713
|
+
sequences=protein_seqs,
|
1714
|
+
template_chain_ids=template_chain_ids,
|
1715
|
+
template_sequences=parsed_template.sequences,
|
1716
|
+
)
|
1717
|
+
)
|
1718
|
+
# Save template
|
1719
|
+
templates[template_id] = parsed_template.data
|
1720
|
+
|
1721
|
+
# Convert into datatypes
|
1722
|
+
residues = np.array(res_data, dtype=Residue)
|
1723
|
+
chains = np.array(chain_data, dtype=Chain)
|
1724
|
+
interfaces = np.array([], dtype=Interface)
|
1725
|
+
mask = np.ones(len(chain_data), dtype=bool)
|
1726
|
+
rdkit_bounds_constraints = np.array(
|
1727
|
+
rdkit_bounds_constraint_data, dtype=RDKitBoundsConstraint
|
1728
|
+
)
|
1729
|
+
chiral_atom_constraints = np.array(
|
1730
|
+
chiral_atom_constraint_data, dtype=ChiralAtomConstraint
|
1731
|
+
)
|
1732
|
+
stereo_bond_constraints = np.array(
|
1733
|
+
stereo_bond_constraint_data, dtype=StereoBondConstraint
|
1734
|
+
)
|
1735
|
+
planar_bond_constraints = np.array(
|
1736
|
+
planar_bond_constraint_data, dtype=PlanarBondConstraint
|
1737
|
+
)
|
1738
|
+
planar_ring_5_constraints = np.array(
|
1739
|
+
planar_ring_5_constraint_data, dtype=PlanarRing5Constraint
|
1740
|
+
)
|
1741
|
+
planar_ring_6_constraints = np.array(
|
1742
|
+
planar_ring_6_constraint_data, dtype=PlanarRing6Constraint
|
1743
|
+
)
|
1744
|
+
|
1745
|
+
if boltz_2:
|
1746
|
+
atom_data = [(a[0], a[3], a[5], 0.0, 1.0) for a in atom_data]
|
1747
|
+
connections = [(*c, const.bond_type_ids["COVALENT"]) for c in connections]
|
1748
|
+
bond_data = bond_data + connections
|
1749
|
+
atoms = np.array(atom_data, dtype=AtomV2)
|
1750
|
+
bonds = np.array(bond_data, dtype=BondV2)
|
1751
|
+
coords = [(x,) for x in atoms["coords"]]
|
1752
|
+
coords = np.array(coords, Coords)
|
1753
|
+
ensemble = np.array([(0, len(coords))], dtype=Ensemble)
|
1754
|
+
data = StructureV2(
|
1755
|
+
atoms=atoms,
|
1756
|
+
bonds=bonds,
|
1757
|
+
residues=residues,
|
1758
|
+
chains=chains,
|
1759
|
+
interfaces=interfaces,
|
1760
|
+
mask=mask,
|
1761
|
+
coords=coords,
|
1762
|
+
ensemble=ensemble,
|
1763
|
+
)
|
1764
|
+
else:
|
1765
|
+
bond_data = [(b[4], b[5], b[6]) for b in bond_data]
|
1766
|
+
atom_data = [(convert_atom_name(a[0]), *a[1:]) for a in atom_data]
|
1767
|
+
atoms = np.array(atom_data, dtype=Atom)
|
1768
|
+
bonds = np.array(bond_data, dtype=Bond)
|
1769
|
+
connections = np.array(connections, dtype=Connection)
|
1770
|
+
data = Structure(
|
1771
|
+
atoms=atoms,
|
1772
|
+
bonds=bonds,
|
1773
|
+
residues=residues,
|
1774
|
+
chains=chains,
|
1775
|
+
connections=connections,
|
1776
|
+
interfaces=interfaces,
|
1777
|
+
mask=mask,
|
1778
|
+
)
|
1779
|
+
|
1780
|
+
# Create metadata
|
1781
|
+
struct_info = StructureInfo(num_chains=len(chains))
|
1782
|
+
chain_infos = []
|
1783
|
+
for chain in chains:
|
1784
|
+
chain_info = ChainInfo(
|
1785
|
+
chain_id=int(chain["asym_id"]),
|
1786
|
+
chain_name=chain["name"],
|
1787
|
+
mol_type=int(chain["mol_type"]),
|
1788
|
+
cluster_id=-1,
|
1789
|
+
msa_id=chain_to_msa[chain["name"]],
|
1790
|
+
num_residues=int(chain["res_num"]),
|
1791
|
+
valid=True,
|
1792
|
+
entity_id=int(chain["entity_id"]),
|
1793
|
+
)
|
1794
|
+
chain_infos.append(chain_info)
|
1795
|
+
|
1796
|
+
options = InferenceOptions(pocket_constraints=pocket_constraints)
|
1797
|
+
record = Record(
|
1798
|
+
id=name,
|
1799
|
+
structure=struct_info,
|
1800
|
+
chains=chain_infos,
|
1801
|
+
interfaces=[],
|
1802
|
+
inference_options=options,
|
1803
|
+
templates=template_records,
|
1804
|
+
affinity=affinity_info,
|
1805
|
+
)
|
1806
|
+
|
1807
|
+
residue_constraints = ResidueConstraints(
|
1808
|
+
rdkit_bounds_constraints=rdkit_bounds_constraints,
|
1809
|
+
chiral_atom_constraints=chiral_atom_constraints,
|
1810
|
+
stereo_bond_constraints=stereo_bond_constraints,
|
1811
|
+
planar_bond_constraints=planar_bond_constraints,
|
1812
|
+
planar_ring_5_constraints=planar_ring_5_constraints,
|
1813
|
+
planar_ring_6_constraints=planar_ring_6_constraints,
|
1814
|
+
)
|
1815
|
+
|
1816
|
+
return Target(
|
1817
|
+
record=record,
|
1818
|
+
structure=data,
|
1819
|
+
sequences=entity_to_seq,
|
1820
|
+
residue_constraints=residue_constraints,
|
1821
|
+
templates=templates,
|
1822
|
+
extra_mols=extra_mols,
|
1823
|
+
)
|
1824
|
+
|
1825
|
+
|
1826
|
+
def standardize(smiles: str) -> Optional[str]:
    """Standardize a molecule given as SMILES and return the resulting SMILES.

    The input is parsed without sanitization, screened with ``exclude_flag``,
    reduced to its largest fragment, passed through ``standardize_mol`` and
    written back out with ``Chem.MolToSmiles``. Every failure is reported by
    raising ``ValueError``; the function never returns ``None`` (the
    ``Optional`` annotation is kept only for interface compatibility).

    Note: this variant raises on failure, unlike the original helper in
    mol-finder/data, which was left untouched because other code depends on it.

    Parameters
    ----------
    smiles : str
        The SMILES string of the molecule to standardize.

    Returns
    -------
    Optional[str]
        The SMILES string of the standardized molecule.

    Raises
    ------
    ValueError
        If the input cannot be parsed, if ``exclude_flag`` rejects the
        molecule, or if the standardized SMILES cannot be re-parsed by RDKit.

    """
    fragment_chooser = rdMolStandardize.LargestFragmentChooser()

    mol = Chem.MolFromSmiles(smiles, sanitize=False)

    # MolFromSmiles signals a syntactically invalid SMILES by returning None;
    # fail here with a clear message instead of passing None downstream.
    if mol is None:
        msg = f"Could not parse SMILES: {smiles!r}"
        raise ValueError(msg)

    if exclude_flag(mol, includeRDKitSanitization=False):
        raise ValueError("Molecule is excluded")

    # Keep only the largest fragment of the molecule
    mol = fragment_chooser.choose(mol)

    # Standardize with ChEMBL data curation pipeline. During standardization,
    # the molecule may be broken
    mol = standardize_mol(mol)
    smiles = Chem.MolToSmiles(mol)

    # Check if molecule can be parsed by RDKit (in rare cases, the molecule
    # may be broken during standardization)
    if Chem.MolFromSmiles(smiles) is None:
        raise ValueError("Molecule is broken")

    return smiles
|