boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1607 @@
1
+ import contextlib
2
+ from collections import defaultdict
3
+ from dataclasses import dataclass, replace
4
+ from typing import Optional
5
+
6
+ import gemmi
7
+ import numpy as np
8
+ from rdkit import Chem, rdBase
9
+ from rdkit.Chem import AllChem, HybridizationType
10
+ from rdkit.Chem.rdchem import BondStereo, Mol
11
+ from rdkit.Chem.rdDistGeom import GetMoleculeBoundsMatrix
12
+ from rdkit.Chem.rdMolDescriptors import CalcNumHeavyAtoms
13
+ from sklearn.neighbors import KDTree
14
+
15
+ from boltz.data import const
16
+ from boltz.data.mol import load_molecules
17
+ from boltz.data.types import (
18
+ AtomV2,
19
+ BondV2,
20
+ Chain,
21
+ ChiralAtomConstraint,
22
+ Coords,
23
+ Ensemble,
24
+ Interface,
25
+ PlanarBondConstraint,
26
+ PlanarRing5Constraint,
27
+ PlanarRing6Constraint,
28
+ RDKitBoundsConstraint,
29
+ Residue,
30
+ ResidueConstraints,
31
+ StereoBondConstraint,
32
+ StructureInfo,
33
+ StructureV2,
34
+ )
35
+
36
+ ####################################################################################################
37
+ # DATACLASSES
38
+ ####################################################################################################
39
+
40
+
41
@dataclass(frozen=True, slots=True)
class ParsedAtom:
    """A parsed atom object.

    Immutable record describing a single atom produced by the parsers in
    this module.
    """

    # Atom name (taken from the reference molecule's "name" property).
    name: str
    # Cartesian (x, y, z) position; the parsers store (0, 0, 0) when the
    # atom has no coordinates in the input structure.
    coords: tuple[float, float, float]
    # Whether coordinates for this atom were found in the input structure.
    is_present: bool
    # Isotropic B-factor read from the structure (0 when the atom is absent).
    bfactor: float
    # Optional pLDDT value; left as None by the parsers visible in this file.
    plddt: Optional[float] = None
50
+
51
+
52
@dataclass(frozen=True, slots=True)
class ParsedBond:
    """A parsed bond object.

    Immutable record of a bond between two atoms of the same residue.
    """

    # Residue-local index of the first atom (parse_ccd_residue stores the
    # smaller of the two indices here).
    atom_1: int
    # Residue-local index of the second atom (the larger index).
    atom_2: int
    # Bond type id, looked up in const.bond_type_ids by the parser.
    type: int
59
+
60
+
61
@dataclass(frozen=True)
class ParsedRDKitBoundsConstraint:
    """A parsed RDKit bounds constraint object.

    Holds the distance bounds between one pair of atoms, as read from the
    RDKit molecule bounds matrix.
    """

    # Residue-local indices of the two atoms of the pair.
    atom_idxs: tuple[int, int]
    # True when the two atoms are directly bonded.
    is_bond: bool
    # True when the two atoms are the end points of a bonded three-atom path.
    is_angle: bool
    # Upper distance bound (bounds-matrix entry above the diagonal).
    upper_bound: float
    # Lower distance bound (bounds-matrix entry below the diagonal).
    lower_bound: float
70
+
71
+
72
@dataclass(frozen=True)
class ParsedChiralAtomConstraint:
    """A parsed chiral atom constraint object.

    Describes the orientation of three neighbours around a chiral center.
    """

    # Residue-local indices: three neighbour atoms followed by the center.
    atom_idxs: tuple[int, int, int, int]
    # True for the constraint built from the three highest-ranked (CIP)
    # neighbours; False for the additional neighbour triples emitted for
    # centers with four neighbours.
    is_reference: bool
    # True when the center was reported as "R" by RDKit.
    is_r: bool
79
+
80
+
81
@dataclass(frozen=True)
class ParsedStereoBondConstraint:
    """A parsed stereo bond constraint object.

    Describes the E/Z configuration around one stereo bond.
    """

    # Residue-local indices: (substituent of start atom, start atom,
    # end atom, substituent of end atom).
    atom_idxs: tuple[int, int, int, int]
    # True for the constraint built from the highest-ranked (CIP)
    # substituents; False for the one built from the second-ranked pair.
    is_check: bool
    # True when the bond stereo is E, False when it is Z.
    is_e: bool
88
+
89
+
90
@dataclass(frozen=True)
class ParsedPlanarBondConstraint:
    """A parsed planar bond constraint object.

    Groups the six atoms around a planar carbon-carbon double bond.
    """

    # Residue-local indices of the six atoms, in SMARTS match order.
    atom_idxs: tuple[int, int, int, int, int, int]
95
+
96
+
97
@dataclass(frozen=True)
class ParsedPlanarRing5Constraint:
    """A parsed planar 5-membered ring constraint object.

    Groups the five atoms of an aromatic five-membered ring.
    """

    # Residue-local indices of the five ring atoms, in SMARTS match order.
    atom_idxs: tuple[int, int, int, int, int]
102
+
103
+
104
@dataclass(frozen=True)
class ParsedPlanarRing6Constraint:
    """A parsed planar 6-membered ring constraint object.

    Groups the six atoms of an aromatic six-membered ring.
    """

    # Residue-local indices of the six ring atoms, in SMARTS match order.
    atom_idxs: tuple[int, int, int, int, int, int]
109
+
110
+
111
@dataclass(frozen=True, slots=True)
class ParsedResidue:
    """A parsed residue object.

    Immutable record describing one residue (standard or CCD ligand /
    modified residue) together with its atoms, bonds and optional
    geometry constraints.
    """

    # Residue name (e.g. the CCD code).
    name: str
    # Token id of the residue (an id from const.token_ids / unk_token_ids).
    type: int
    # Index of the residue as supplied by the caller.
    idx: int
    # Atoms of the residue, in parser order.
    atoms: list[ParsedAtom]
    # Intra-residue bonds; atom indices refer to positions in `atoms`.
    bonds: list[ParsedBond]
    # Original sequence id from the structure, None when the residue is
    # missing from it.
    # NOTE(review): annotated as int, but the parsers in this file store the
    # string "<number><insertion code>" here — confirm and fix the annotation.
    orig_idx: Optional[int]
    # Index (into `atoms`) of the center atom; 0 is used as a placeholder
    # for CCD residues.
    atom_center: int
    # Index (into `atoms`) of the distogram atom; 0 is used as a placeholder
    # for CCD residues.
    atom_disto: int
    # Whether this is a standard residue (False for CCD residues).
    is_standard: bool
    # Whether the residue was found in the input structure.
    is_present: bool
    # Optional constraint lists; None when no constraints were computed.
    rdkit_bounds_constraints: Optional[list[ParsedRDKitBoundsConstraint]] = None
    chiral_atom_constraints: Optional[list[ParsedChiralAtomConstraint]] = None
    stereo_bond_constraints: Optional[list[ParsedStereoBondConstraint]] = None
    planar_bond_constraints: Optional[list[ParsedPlanarBondConstraint]] = None
    planar_ring_5_constraints: Optional[list[ParsedPlanarRing5Constraint]] = None
    planar_ring_6_constraints: Optional[list[ParsedPlanarRing6Constraint]] = None
131
+
132
+
133
@dataclass(frozen=True, slots=True)
class ParsedChain:
    """A parsed chain object.

    Immutable record describing one chain and its parsed residues.
    """

    # Chain name / identifier.
    name: str
    # Name of the entity this chain belongs to.
    entity: str
    # Chain type id — presumably one of the chain type ids in const; confirm.
    type: int
    # Parsed residues of the chain, in order.
    residues: list[ParsedResidue]
    # Optional sequence string for the chain; defaults to None.
    sequence: Optional[str] = None
142
+
143
+
144
@dataclass(frozen=True, slots=True)
class ParsedConnection:
    """A parsed connection object.

    Immutable record describing a connection between two atoms, each
    addressed by chain, residue index and atom.
    """

    # Name of the chain holding the first atom.
    chain_1: str
    # Name of the chain holding the second atom.
    chain_2: str
    # Residue index of the first atom within its chain.
    residue_index_1: int
    # Residue index of the second atom within its chain.
    residue_index_2: int
    # Identifier of the first atom.
    # NOTE(review): named "index" but typed as str — presumably an atom
    # name; verify against the code that builds connections.
    atom_index_1: str
    # Identifier of the second atom (same caveat as atom_index_1).
    atom_index_2: str
154
+
155
+
156
@dataclass(frozen=True, slots=True)
class ParsedStructure:
    """A parsed structure object.

    Immutable bundle of the outputs of structure parsing.
    """

    # The parsed structure tables.
    data: StructureV2
    # Structure-level metadata.
    info: StructureInfo
    # Mapping from str to str holding sequences — presumably keyed by
    # entity or chain name; confirm against the code that fills it.
    sequences: dict[str, str]
    # Optional residue-level constraints; defaults to None.
    residue_constraints: Optional[ResidueConstraints] = None
164
+
165
+
166
+ ####################################################################################################
167
+ # HELPERS
168
+ ####################################################################################################
169
+
170
+
171
def get_mol(ccd: str, mols: dict, moldir: str) -> Mol:
    """Return the molecule for a CCD code, loading and caching it on a miss.

    Parameters
    ----------
    ccd : str
        The CCD code to look up.
    mols : dict
        Cache of already loaded molecules keyed by CCD code. Objects that
        are not plain dicts are also handled, provided they expose ``get``
        and ``set`` methods.
    moldir : str
        Directory handed to ``load_molecules`` when the code is not cached.

    Returns
    -------
    Mol
        The cached or freshly loaded molecule.

    """
    cached = mols.get(ccd)
    if cached is not None:
        return cached

    # Cache miss: load through load_molecules and remember the result
    loaded = load_molecules(moldir, [ccd])[ccd]
    if not isinstance(mols, dict):
        mols.set(ccd, loaded)
    else:
        mols[ccd] = loaded

    return loaded
189
+
190
+
191
def get_dates(block: gemmi.cif.Block) -> tuple[str, str, str]:
    """Get the deposited, released, and last revision dates.

    The deposition date and the revision history are read independently:
    a missing or malformed deposition entry no longer prevents the release
    and revision dates from being returned. Any value that cannot be read
    is reported as an empty string.

    Parameters
    ----------
    block : gemmi.cif.Block
        The block to process.

    Returns
    -------
    str
        The deposited date.
    str
        The released date (first entry of the revision history).
    str
        The last revision date (last entry of the revision history).

    """
    deposited = "_pdbx_database_status.recvd_initial_deposition_date"
    revision = "_pdbx_audit_revision_history.revision_date"
    deposit_date = revision_date = release_date = ""

    # Best-effort lookup of the deposition date
    with contextlib.suppress(Exception):
        deposit_date = block.find([deposited])[0][0]

    # Best-effort lookup of the revision history, guarded separately so a
    # failure above cannot skip it; the table is queried only once
    with contextlib.suppress(Exception):
        revisions = block.find([revision])
        release_date = revisions[0][0]
        revision_date = revisions[-1][0]

    return deposit_date, release_date, revision_date
218
+
219
+
220
def get_resolution(block: gemmi.cif.Block) -> float:
    """Get the resolution from a gemmi structure.

    The candidate tags are tried in order and the first one that yields a
    value convertible to float wins.

    Parameters
    ----------
    block : gemmi.cif.Block
        The block to process.

    Returns
    -------
    float
        The resolution, or 0.0 when none of the tags could be read.

    """
    candidate_keys = (
        "_refine.ls_d_res_high",
        "_em_3d_reconstruction.resolution",
        "_reflns.d_resolution_high",
    )
    for candidate in candidate_keys:
        try:
            value = float(block.find([candidate])[0].str(0))
        except Exception:  # noqa: BLE001, PERF203
            # Tag missing or not numeric: fall through to the next one
            continue
        return value

    return 0.0
244
+
245
+
246
def get_method(block: gemmi.cif.Block) -> str:
    """Get the method from a gemmi structure.

    Parameters
    ----------
    block : gemmi.cif.Block
        The block to process.

    Returns
    -------
    str
        The lower-cased ``_exptl.method`` values joined with commas, or an
        empty string when they cannot be read.

    """
    try:
        rows = block.find(["_exptl.method"])
        return ",".join(row.str(0).lower() for row in rows)
    except Exception:  # noqa: BLE001
        # Best effort: any failure results in an empty method string
        return ""
267
+
268
+
269
def get_experiment_conditions(
    block: gemmi.cif.Block,
) -> tuple[Optional[float], Optional[float]]:
    """Get temperature and pH.

    For each quantity the candidate tags are tried in order and the first
    one that yields a value convertible to float is used. A tag that is
    missing or not numeric is skipped, so the remaining tags still get a
    chance.

    Parameters
    ----------
    block : gemmi.cif.Block
        The block to process.

    Returns
    -------
    tuple[Optional[float], Optional[float]]
        The temperature and pH; each is None when no tag could be read.

    """
    temperature = None
    ph = None

    keys_t = [
        "_exptl_crystal_grow.temp",
        "_pdbx_nmr_exptl_sample_conditions.temperature",
    ]
    for key in keys_t:
        with contextlib.suppress(Exception):
            temperature = float(block.find([key])[0][0])
            break

    keys_ph = ["_exptl_crystal_grow.pH", "_pdbx_nmr_exptl_sample_conditions.pH"]
    for key in keys_ph:
        # The suppress must live inside the loop: wrapping the whole loop
        # would abort on the first failing tag and never try the next one
        with contextlib.suppress(Exception):
            ph = float(block.find([key])[0][0])
            break

    return temperature, ph
303
+
304
+
305
def get_unk_token(dtype: gemmi.PolymerType) -> str:
    """Get the unknown token for a given polymer type.

    Parameters
    ----------
    dtype : gemmi.PolymerType
        The polymer type.

    Returns
    -------
    str
        The unknown token.

    Raises
    ------
    ValueError
        If the polymer type is not PeptideL, Dna or Rna.

    """
    if dtype == gemmi.PolymerType.PeptideL:
        return const.unk_token["PROTEIN"]
    if dtype == gemmi.PolymerType.Dna:
        return const.unk_token["DNA"]
    if dtype == gemmi.PolymerType.Rna:
        return const.unk_token["RNA"]

    msg = f"Unknown polymer type: {dtype}"
    raise ValueError(msg)
330
+
331
+
332
def compute_covalent_ligands(
    connections: list[gemmi.Connection],
    subchain_map: dict[tuple[str, int], str],
    entities: dict[str, gemmi.Entity],
) -> set[str]:
    """Compute the covalent ligands from a list of connections.

    Parameters
    ----------
    connections: list[gemmi.Connection]
        The connections to process. Only those whose type name is
        "Covale" are considered.
    subchain_map: dict[tuple[str, int], str]
        The mapping from chain, residue index to subchain name. The lookup
        key built here is ``(chain name, "<number><insertion code>")``,
        i.e. the residue part is a string.
    entities: dict[str, gemmi.Entity]
        The entities in the structure, looked up by subchain name.

    Returns
    -------
    set
        The covalent ligand subchains.

    """
    ligand_kinds = {"NonPolymer", "Branched"}
    covalent_chain_ids = set()

    for connection in connections:
        if connection.type.name == "Covale":
            partners = (connection.partner1, connection.partner2)

            # Resolve both partners to their subchains first ...
            subchains = [
                subchain_map[
                    (
                        partner.chain_name,
                        str(partner.res_id.seqid.num)
                        + str(partner.res_id.seqid.icode).strip(),
                    )
                ]
                for partner in partners
            ]

            # ... then look up the entity type of each subchain
            kinds = [entities[subchain].entity_type.name for subchain in subchains]

            # Keep the partners that are non-polymer or branched
            covalent_chain_ids.update(
                subchain
                for subchain, kind in zip(subchains, kinds)
                if kind in ligand_kinds
            )

    return covalent_chain_ids
383
+
384
+
385
def compute_interfaces(atom_data: np.ndarray, chain_data: np.ndarray) -> np.ndarray:
    """Compute the chain-chain interfaces from atom and chain tables.

    Two chains form an interface when at least one pair of present atoms,
    one from each chain, lies within ``const.atom_interface_cutoff``.

    Parameters
    ----------
    atom_data : np.ndarray
        The atom data; the "coords" and "is_present" fields are read.
    chain_data : np.ndarray
        The chain data; the "atom_num" field of each chain is read.

    Returns
    -------
    np.ndarray
        The unique (smaller chain index, larger chain index) pairs, as an
        array with dtype ``Interface``.

    """
    # One chain index per atom, chains laid out back to back
    per_atom_chain = np.array(
        [
            chain_idx
            for chain_idx, chain in enumerate(chain_data)
            for _ in range(chain["atom_num"])
        ]
    )

    # Restrict everything to atoms that are present
    present = atom_data["is_present"]
    present_coords = atom_data["coords"][present]
    present_chains = per_atom_chain[present]

    # Radius search around every present atom
    tree = KDTree(present_coords, metric="euclidean")
    neighbourhoods = tree.query_radius(present_coords, const.atom_interface_cutoff)

    # Collect directed (own chain, other chain) contacts
    directed = set()
    for own_chain, neighbour_idx in zip(present_chains, neighbourhoods):
        touched = np.unique(present_chains[neighbour_idx])
        for other_chain in touched[touched != own_chain]:
            directed.add((own_chain, other_chain))

    # Normalise the direction and deduplicate
    undirected = {(int(min(a, b)), int(max(a, b))) for a, b in directed}
    return np.array(list(undirected), dtype=Interface)
430
+
431
+
432
+ ####################################################################################################
433
+ # CONSTRAINTS
434
+ ####################################################################################################
435
+
436
+
437
def compute_geometry_constraints(mol: Mol, idx_map):
    """Build pairwise distance-bound constraints from the RDKit bounds matrix.

    Parameters
    ----------
    mol : Mol
        The reference molecule.
    idx_map : dict
        Mapping from atom index in ``mol`` to the index of the kept atom in
        the parsed residue. Pairs with an unmapped atom are skipped.

    Returns
    -------
    list[ParsedRDKitBoundsConstraint]
        One constraint per mapped atom pair (i < j); empty when the
        molecule has at most one atom.

    """
    num_atoms = mol.GetNumAtoms()
    if num_atoms <= 1:
        return []

    bounds = GetMoleculeBoundsMatrix(
        mol,
        set15bounds=True,
        scaleVDW=True,
        doTriangleSmoothing=True,
        useMacrocycle14config=False,
    )

    # Atom pairs that are bonded, and pairs that are the ends of an angle
    bond_pattern = Chem.MolFromSmarts("*~*")
    angle_pattern = Chem.MolFromSmarts("*~*~*")
    bonded_pairs = {
        tuple(sorted(match)) for match in mol.GetSubstructMatches(bond_pattern)
    }
    angle_pairs = {
        tuple(sorted([match[0], match[2]]))
        for match in mol.GetSubstructMatches(angle_pattern)
    }

    rows, cols = np.triu_indices(num_atoms, k=1)
    constraints = []
    for i, j in zip(rows, cols):
        if i not in idx_map or j not in idx_map:
            continue

        pair = tuple(sorted([i, j]))
        constraints.append(
            ParsedRDKitBoundsConstraint(
                atom_idxs=(idx_map[i], idx_map[j]),
                is_bond=pair in bonded_pairs,
                is_angle=pair in angle_pairs,
                # Upper bounds are read above the diagonal, lower below it
                upper_bound=bounds[i, j],
                lower_bound=bounds[j, i],
            )
        )
    return constraints
468
+
469
+
470
def compute_chiral_atom_constraints(mol, idx_map):
    """Build chirality constraints for the assigned chiral centers of a molecule.

    Nothing is produced unless every atom carries a ``_CIPRank`` property.
    Centers that are not SP3 or have more than four neighbours are skipped.

    Parameters
    ----------
    mol : Mol
        The reference molecule.
    idx_map : dict
        Mapping from atom index in ``mol`` to the index of the kept atom in
        the parsed residue. Constraints touching an unmapped atom are
        dropped.

    Returns
    -------
    list[ParsedChiralAtomConstraint]
        The reference constraint of each center (three highest-ranked
        neighbours plus the center) and, for centers with four neighbours,
        three additional non-reference constraints.

    """
    if not all(atom.HasProp("_CIPRank") for atom in mol.GetAtoms()):
        return []

    def cip_rank(atom):
        return int(atom.GetProp("_CIPRank"))

    constraints = []
    for center_idx, orientation in Chem.FindMolChiralCenters(
        mol, includeUnassigned=False
    ):
        center = mol.GetAtomWithIdx(center_idx)

        # Neighbours ordered from highest to lowest CIP rank
        ranked = sorted(center.GetNeighbors(), key=cip_rank, reverse=True)
        neighbors = tuple(atom.GetIdx() for atom in ranked)

        if len(neighbors) > 4 or center.GetHybridization() != HybridizationType.SP3:
            continue

        is_r = orientation == "R"

        # (atom indices, is_reference) candidates for this center
        candidates = [((*neighbors[:3], center_idx), True)]
        if len(neighbors) == 4:
            for skip_idx in range(3):
                chiral_set = neighbors[:skip_idx] + neighbors[skip_idx + 1 :]
                # Dropping an even-positioned neighbour requires reversing
                # the remaining triple
                if skip_idx % 2 == 0:
                    chiral_set = chiral_set[::-1]
                candidates.append((chiral_set + (center_idx,), False))

        for atom_idxs, is_reference in candidates:
            if all(i in idx_map for i in atom_idxs):
                constraints.append(
                    ParsedChiralAtomConstraint(
                        atom_idxs=tuple(idx_map[i] for i in atom_idxs),
                        is_reference=is_reference,
                        is_r=is_r,
                    )
                )
    return constraints
516
+
517
+
518
def compute_stereo_bond_constraints(mol, idx_map):
    """Build E/Z constraints for the stereo bonds of a molecule.

    Nothing is produced unless every atom carries a ``_CIPRank`` property.
    Only bonds whose stereo is STEREOE or STEREOZ are considered.

    Parameters
    ----------
    mol : Mol
        The reference molecule.
    idx_map : dict
        Mapping from atom index in ``mol`` to the index of the kept atom in
        the parsed residue. Constraints touching an unmapped atom are
        dropped.

    Returns
    -------
    list[ParsedStereoBondConstraint]
        For each stereo bond, a "check" constraint built from the
        highest-ranked substituent on each side and, when both sides have
        exactly two substituents, a second constraint from the other pair.

    """
    if not all(atom.HasProp("_CIPRank") for atom in mol.GetAtoms()):
        return []

    def ranked_substituents(atom_idx, excluded_idx):
        # Neighbours of atom_idx other than the bond partner, ordered from
        # highest to lowest CIP rank
        others = [
            neighbor
            for neighbor in mol.GetAtomWithIdx(atom_idx).GetNeighbors()
            if neighbor.GetIdx() != excluded_idx
        ]
        others.sort(
            key=lambda neighbor: int(neighbor.GetProp("_CIPRank")), reverse=True
        )
        return [neighbor.GetIdx() for neighbor in others]

    constraints = []
    for bond in mol.GetBonds():
        stereo = bond.GetStereo()
        if stereo not in {BondStereo.STEREOE, BondStereo.STEREOZ}:
            continue

        start_atom_idx = bond.GetBeginAtomIdx()
        end_atom_idx = bond.GetEndAtomIdx()
        start_neighbors = ranked_substituents(start_atom_idx, end_atom_idx)
        end_neighbors = ranked_substituents(end_atom_idx, start_atom_idx)
        is_e = stereo == BondStereo.STEREOE

        # (substituent rank, is_check) combinations to emit for this bond
        selections = [(0, True)]
        if len(start_neighbors) == 2 and len(end_neighbors) == 2:
            selections.append((1, False))

        for rank, is_check in selections:
            atom_idxs = (
                start_neighbors[rank],
                start_atom_idx,
                end_atom_idx,
                end_neighbors[rank],
            )
            if all(i in idx_map for i in atom_idxs):
                constraints.append(
                    ParsedStereoBondConstraint(
                        atom_idxs=tuple(idx_map[i] for i in atom_idxs),
                        is_check=is_check,
                        is_e=is_e,
                    )
                )
    return constraints
579
+
580
+
581
def compute_flatness_constraints(mol, idx_map):
    """Build planarity constraints for double bonds and aromatic rings.

    Parameters
    ----------
    mol : Mol
        The reference molecule.
    idx_map : dict
        Mapping from atom index in ``mol`` to the index of the kept atom in
        the parsed residue. Matches touching an unmapped atom are dropped.

    Returns
    -------
    tuple[list, list, list]
        The planar double bond constraints, the aromatic 5-ring
        constraints and the aromatic 6-ring constraints, in that order.

    """

    def collect(smarts, constraint_cls):
        # One constraint per substructure match whose atoms are all mapped
        pattern = Chem.MolFromSmarts(smarts)
        return [
            constraint_cls(atom_idxs=tuple(idx_map[i] for i in match))
            for match in mol.GetSubstructMatches(pattern)
            if all(i in idx_map for i in match)
        ]

    planar_double_bond_constraints = collect(
        "[C;X3;^2](*)(*)=[C;X3;^2](*)(*)", ParsedPlanarBondConstraint
    )
    aromatic_ring_5_constraints = collect(
        "[ar5^2]1[ar5^2][ar5^2][ar5^2][ar5^2]1", ParsedPlanarRing5Constraint
    )
    aromatic_ring_6_constraints = collect(
        "[ar6^2]1[ar6^2][ar6^2][ar6^2][ar6^2][ar6^2]1", ParsedPlanarRing6Constraint
    )

    return (
        planar_double_bond_constraints,
        aromatic_ring_5_constraints,
        aromatic_ring_6_constraints,
    )
612
+
613
+
614
+ ####################################################################################################
615
+ # PARSING
616
+ ####################################################################################################
617
+
618
+
619
def parse_ccd_residue(  # noqa: PLR0915, C901
    name: str,
    ref_mol: Mol,
    res_idx: int,
    gemmi_mol: Optional[gemmi.Residue] = None,
    is_covalent: bool = False,
) -> Optional[ParsedResidue]:
    """Parse a CCD residue (ligand or modified residue) into a ParsedResidue.

    Atoms are taken from the reference molecule, in its order, with
    hydrogens dropped. Coordinates and B-factors are matched by atom name
    against the structure residue, when one is given. Reference molecules
    with a single heavy atom are handled separately and carry no bonds or
    constraints.

    Parameters
    ----------
    name: str
        The name of the molecule to parse.
    ref_mol : Mol
        The reference RDKit molecule for this residue; its atoms are
        expected to carry a "name" property.
    res_idx : int
        The residue index.
    gemmi_mol : Optional[gemmi.Residue]
        The PDB molecule, as a gemmi Residue object, if any.
    is_covalent : bool
        When True, atoms flagged as leaving atoms in the reference
        molecule are skipped if they are absent from the structure.

    Returns
    -------
    ParsedResidue, optional
        The output ParsedResidue. NOTE(review): every path in this function
        returns a ParsedResidue; the Optional return type looks vestigial.

    """
    # Check if we have a PDB structure for this residue,
    # it could be a missing residue from the sequence
    is_present = gemmi_mol is not None

    # Save original index (required for parsing connections);
    # stored as the string "<number><insertion code>"
    if is_present:
        orig_idx = gemmi_mol.seqid
        orig_idx = str(orig_idx.num) + str(orig_idx.icode).strip()
    else:
        orig_idx = None

    # Check if this is a single heavy atom CCD residue
    if CalcNumHeavyAtoms(ref_mol) == 1:
        # Remove hydrogens
        ref_mol = AllChem.RemoveHs(ref_mol, sanitize=False)

        # Default to the origin when the residue is not in the structure;
        # otherwise use the first atom of the structure residue
        pos = (0, 0, 0)
        bfactor = 0
        if is_present:
            pos = (
                gemmi_mol[0].pos.x,
                gemmi_mol[0].pos.y,
                gemmi_mol[0].pos.z,
            )
            bfactor = gemmi_mol[0].b_iso
        ref_atom = ref_mol.GetAtoms()[0]
        atom = ParsedAtom(
            name=ref_atom.GetProp("name"),
            coords=pos,
            is_present=is_present,
            bfactor=bfactor,
        )
        # CCD residues are typed with the unknown-protein token id
        unk_prot_id = const.unk_token_ids["PROTEIN"]
        residue = ParsedResidue(
            name=name,
            type=unk_prot_id,
            atoms=[atom],
            bonds=[],
            idx=res_idx,
            orig_idx=orig_idx,
            atom_center=0,  # Placeholder, no center
            atom_disto=0,  # Placeholder, no center
            is_standard=False,
            is_present=is_present,
        )
        return residue

    # If multi-atom, start by getting the PDB coordinates
    pdb_pos = {}
    bfactor = {}
    if is_present:
        # Match atoms based on names
        for atom in gemmi_mol:
            atom: gemmi.Atom
            pos = (atom.pos.x, atom.pos.y, atom.pos.z)
            pdb_pos[atom.name] = pos
            bfactor[atom.name] = atom.b_iso
    # Parse each atom in order of the reference mol
    atoms = []
    atom_idx = 0
    # Maps reference-molecule atom index -> index in `atoms`; used for the
    # bonds and for all constraint builders below
    idx_map = {}  # Used for bonds later

    for i, atom in enumerate(ref_mol.GetAtoms()):
        # Ignore Hydrogen atoms
        if atom.GetAtomicNum() == 1:
            continue

        # Get the atom name from the reference molecule
        atom_name = atom.GetProp("name")

        # If the atom is a leaving atom, skip if not in the PDB and is_covalent
        if (
            atom.HasProp("leaving_atom")
            and int(atom.GetProp("leaving_atom")) == 1
            and is_covalent
            and (atom_name not in pdb_pos)
        ):
            continue

        # Get PDB coordinates, if any
        coords = pdb_pos.get(atom_name)
        if coords is None:
            atom_is_present = False
            coords = (0, 0, 0)
        else:
            atom_is_present = True

        # Add atom to list
        atoms.append(
            ParsedAtom(
                name=atom_name,
                coords=coords,
                is_present=atom_is_present,
                bfactor=bfactor.get(atom_name, 0),
            )
        )
        idx_map[i] = atom_idx
        atom_idx += 1

    # Load bonds
    bonds = []
    unk_bond = const.bond_type_ids[const.unk_bond_type]
    for bond in ref_mol.GetBonds():
        idx_1 = bond.GetBeginAtomIdx()
        idx_2 = bond.GetEndAtomIdx()

        # Skip bonds with atoms ignored
        if (idx_1 not in idx_map) or (idx_2 not in idx_map):
            continue

        # Translate to residue-local indices, smaller index first
        idx_1 = idx_map[idx_1]
        idx_2 = idx_map[idx_2]
        start = min(idx_1, idx_2)
        end = max(idx_1, idx_2)
        # Unrecognised bond types fall back to the unknown bond id
        bond_type = bond.GetBondType().name
        bond_type = const.bond_type_ids.get(bond_type, unk_bond)
        bonds.append(ParsedBond(start, end, bond_type))

    # Geometry constraints, expressed in residue-local atom indices
    rdkit_bounds_constraints = compute_geometry_constraints(ref_mol, idx_map)
    chiral_atom_constraints = compute_chiral_atom_constraints(ref_mol, idx_map)
    stereo_bond_constraints = compute_stereo_bond_constraints(ref_mol, idx_map)
    planar_bond_constraints, planar_ring_5_constraints, planar_ring_6_constraints = (
        compute_flatness_constraints(ref_mol, idx_map)
    )

    # CCD residues are typed with the unknown-protein token id; center and
    # disto atom indices are left at 0
    unk_prot_id = const.unk_token_ids["PROTEIN"]
    return ParsedResidue(
        name=name,
        type=unk_prot_id,
        atoms=atoms,
        bonds=bonds,
        idx=res_idx,
        atom_center=0,
        atom_disto=0,
        orig_idx=orig_idx,
        is_standard=False,
        is_present=is_present,
        rdkit_bounds_constraints=rdkit_bounds_constraints,
        chiral_atom_constraints=chiral_atom_constraints,
        stereo_bond_constraints=stereo_bond_constraints,
        planar_bond_constraints=planar_bond_constraints,
        planar_ring_5_constraints=planar_ring_5_constraints,
        planar_ring_6_constraints=planar_ring_6_constraints,
    )
792
+
793
+
794
+ def parse_polymer( # noqa: C901, PLR0915, PLR0912
795
+ polymer: gemmi.ResidueSpan,
796
+ polymer_type: gemmi.PolymerType,
797
+ sequence: list[str],
798
+ chain_id: str,
799
+ entity: str,
800
+ mols: dict[str, Mol],
801
+ moldir: str,
802
+ ) -> Optional[ParsedChain]:
803
+ """Process a gemmi Polymer into a chain object.
804
+
805
+ Performs alignment of the full sequence to the polymer
806
+ residues. Loads coordinates and masks for the atoms in
807
+ the polymer, following the ordering in const.atom_order.
808
+
809
+ Parameters
810
+ ----------
811
+ polymer : gemmi.ResidueSpan
812
+ The polymer to process.
813
+ polymer_type : gemmi.PolymerType
814
+ The polymer type.
815
+ sequence : str
816
+ The full sequence of the polymer.
817
+ chain_id : str
818
+ The chain identifier.
819
+ entity : str
820
+ The entity name.
821
+ components : dict[str, Mol]
822
+ The preprocessed PDB components dictionary.
823
+
824
+ Returns
825
+ -------
826
+ ParsedChain, optional
827
+ The output chain, if successful.
828
+
829
+ Raises
830
+ ------
831
+ ValueError
832
+ If the alignment fails.
833
+
834
+ """
835
+ # Ignore microheterogeneities (pick first)
836
+ sequence = [gemmi.Entity.first_mon(item) for item in sequence]
837
+
838
+ # Align full sequence to polymer residues
839
+ # This is a simple way to handle all the different numbering schemes
840
+ result = gemmi.align_sequence_to_polymer(
841
+ sequence,
842
+ polymer,
843
+ polymer_type,
844
+ gemmi.AlignmentScoring(),
845
+ )
846
+
847
+ # Get coordinates and masks
848
+ i = 0
849
+ ref_res = set(const.tokens)
850
+ parsed = []
851
+ for j, match in enumerate(result.match_string):
852
+ # Get residue name from sequence
853
+ res_name = sequence[j]
854
+
855
+ # Check if we have a match in the structure
856
+ res = None
857
+ name_to_atom = {}
858
+
859
+ if match == "|":
860
+ # Get pdb residue
861
+ res = polymer[i]
862
+ name_to_atom = {a.name.upper(): a for a in res}
863
+
864
+ # Double check the match
865
+ if res.name != res_name:
866
+ msg = "Alignment mismatch!"
867
+ raise ValueError(msg)
868
+
869
+ # Increment polymer index
870
+ i += 1
871
+
872
+ # Map MSE to MET, put the selenium atom in the sulphur column
873
+ if res_name == "MSE":
874
+ res_name = "MET"
875
+ if "SE" in name_to_atom:
876
+ name_to_atom["SD"] = name_to_atom["SE"]
877
+
878
+ # Handle non-standard residues
879
+ elif res_name not in ref_res:
880
+ modified_mol = get_mol(res_name, mols, moldir)
881
+ if modified_mol is not None:
882
+ residue = parse_ccd_residue(
883
+ name=res_name,
884
+ ref_mol=modified_mol,
885
+ res_idx=j,
886
+ gemmi_mol=res,
887
+ is_covalent=True,
888
+ )
889
+ parsed.append(residue)
890
+ continue
891
+ else: # noqa: RET507
892
+ res_name = "UNK"
893
+
894
+ # Load regular residues
895
+ ref_mol = get_mol(res_name, mols, moldir)
896
+ ref_mol = AllChem.RemoveHs(ref_mol, sanitize=False)
897
+
898
+ # Only use reference atoms set in constants
899
+ ref_name_to_atom = {a.GetProp("name"): a for a in ref_mol.GetAtoms()}
900
+ ref_atoms = [ref_name_to_atom[a] for a in const.ref_atoms[res_name]]
901
+
902
+ # Iterate, always in the same order
903
+ atoms: list[ParsedAtom] = []
904
+
905
+ for ref_atom in ref_atoms:
906
+ # Get atom name
907
+ atom_name = ref_atom.GetProp("name")
908
+
909
+ # Get coordinates from PDB
910
+ if atom_name in name_to_atom:
911
+ atom: gemmi.Atom = name_to_atom[atom_name]
912
+ atom_is_present = True
913
+ coords = (atom.pos.x, atom.pos.y, atom.pos.z)
914
+ bfactor = atom.b_iso
915
+ else:
916
+ atom_is_present = False
917
+ coords = (0, 0, 0)
918
+ bfactor = 0
919
+
920
+ # Add atom to list
921
+ atoms.append(
922
+ ParsedAtom(
923
+ name=atom_name,
924
+ coords=coords,
925
+ is_present=atom_is_present,
926
+ bfactor=bfactor,
927
+ )
928
+ )
929
+
930
+ # Fix naming errors in arginine residues where NH2 is
931
+ # incorrectly assigned to be closer to CD than NH1
932
+ if (res is not None) and (res_name == "ARG"):
933
+ ref_atoms: list[str] = const.ref_atoms["ARG"]
934
+ cd = atoms[ref_atoms.index("CD")]
935
+ nh1 = atoms[ref_atoms.index("NH1")]
936
+ nh2 = atoms[ref_atoms.index("NH2")]
937
+
938
+ cd_coords = np.array(cd.coords)
939
+ nh1_coords = np.array(nh1.coords)
940
+ nh2_coords = np.array(nh2.coords)
941
+
942
+ if all(atom.is_present for atom in (cd, nh1, nh2)) and (
943
+ np.linalg.norm(nh1_coords - cd_coords)
944
+ > np.linalg.norm(nh2_coords - cd_coords)
945
+ ):
946
+ atoms[ref_atoms.index("NH1")] = replace(nh1, coords=nh2.coords)
947
+ atoms[ref_atoms.index("NH2")] = replace(nh2, coords=nh1.coords)
948
+
949
+ # Add residue to parsed list
950
+ if res is not None:
951
+ orig_idx = res.seqid
952
+ orig_idx = str(orig_idx.num) + str(orig_idx.icode).strip()
953
+ else:
954
+ orig_idx = None
955
+
956
+ atom_center = const.res_to_center_atom_id[res_name]
957
+ atom_disto = const.res_to_disto_atom_id[res_name]
958
+ parsed.append(
959
+ ParsedResidue(
960
+ name=res_name,
961
+ type=const.token_ids[res_name],
962
+ atoms=atoms,
963
+ bonds=[],
964
+ idx=j,
965
+ atom_center=atom_center,
966
+ atom_disto=atom_disto,
967
+ is_standard=True,
968
+ is_present=res is not None,
969
+ orig_idx=orig_idx,
970
+ )
971
+ )
972
+
973
+ # Get polymer class
974
+ if polymer_type == gemmi.PolymerType.PeptideL:
975
+ chain_type = const.chain_type_ids["PROTEIN"]
976
+ elif polymer_type == gemmi.PolymerType.Dna:
977
+ chain_type = const.chain_type_ids["DNA"]
978
+ elif polymer_type == gemmi.PolymerType.Rna:
979
+ chain_type = const.chain_type_ids["RNA"]
980
+
981
+ # Return polymer object
982
+ return ParsedChain(
983
+ name=chain_id,
984
+ entity=entity,
985
+ residues=parsed,
986
+ type=chain_type,
987
+ sequence=gemmi.one_letter_code(sequence),
988
+ )
989
+
990
+
991
def parse_connection(
    connection: gemmi.Connection,
    chains: list[ParsedChain],
    subchain_map: dict[tuple[str, str], str],
) -> ParsedConnection:
    """Parse a single (covalent) connection from a gemmi Connection.

    Both partners of the connection are resolved to positions inside the
    already parsed chains: the subchain name, the index of the residue
    within that chain's ``residues`` list and the index of the atom within
    that residue's ``atoms`` list.

    Parameters
    ----------
    connection : gemmi.Connection
        The connection to parse.
    chains : list[ParsedChain]
        The parsed chains.
    subchain_map : dict[tuple[str, str], str]
        The mapping from (chain name, residue id) to subchain name, where
        the residue id is the sequence number followed by the stripped
        insertion code, e.g. ``"105"`` or ``"105A"``.

    Returns
    -------
    ParsedConnection
        The parsed connection.

    Raises
    ------
    KeyError
        If a partner residue is not present in ``subchain_map``.
    ValueError
        If a partner's chain, residue or atom cannot be found among the
        parsed chains.

    """

    def _residue_label(seqid) -> str:
        # Same "<num><icode>" convention used for ParsedResidue.orig_idx
        return str(seqid.num) + str(seqid.icode).strip()

    def _locate(subchain: str, res_id: str, atom_name: str) -> tuple[int, int]:
        # Explicit loops instead of bare ``next(...)`` so that a failed
        # lookup raises a descriptive error rather than StopIteration.
        chain = next((c for c in chains if c.name == subchain), None)
        if chain is None:
            msg = f"Connection refers to unparsed chain {subchain!r}"
            raise ValueError(msg)

        for res_idx, res in enumerate(chain.residues):
            if res.orig_idx == res_id:
                break
        else:
            msg = f"Residue {res_id!r} not found in chain {subchain!r}"
            raise ValueError(msg)

        for atom_idx, atom in enumerate(res.atoms):
            if atom.name == atom_name:
                return res_idx, atom_idx

        msg = (
            f"Atom {atom_name!r} not found in residue {res_id!r} "
            f"of chain {subchain!r}"
        )
        raise ValueError(msg)

    partner_1 = connection.partner1
    partner_2 = connection.partner2

    # A Connection uses chain names, map them to the correct subchains
    res_1_id = _residue_label(partner_1.res_id.seqid)
    res_2_id = _residue_label(partner_2.res_id.seqid)
    subchain_1 = subchain_map[(partner_1.chain_name, res_1_id)]
    subchain_2 = subchain_map[(partner_2.chain_name, res_2_id)]

    # Resolve residue and atom indices inside the parsed chains
    res_1_idx, atom_index_1 = _locate(subchain_1, res_1_id, partner_1.atom_name)
    res_2_idx, atom_index_2 = _locate(subchain_2, res_2_id, partner_2.atom_name)

    return ParsedConnection(
        chain_1=subchain_1,
        chain_2=subchain_2,
        residue_index_1=res_1_idx,
        residue_index_2=res_2_idx,
        atom_index_1=atom_index_1,
        atom_index_2=atom_index_2,
    )
1064
+
1065
+
1066
def _parse_model_chains(
    model: gemmi.Model,
    entities: dict[str, gemmi.Entity],
    covalent_chain_ids,
    mols: dict[str, Mol],
    moldir: Optional[str],
) -> list[ParsedChain]:
    """Parse every supported subchain of one model, in subchain order.

    Polymers other than PeptideL / Dna / Rna are skipped, as are
    non-polymer chains containing an ``UNL`` residue and chains that
    yield no residues.
    """
    parsed_chains: list[ParsedChain] = []

    for raw_chain in model.subchains():
        # Check chain type
        subchain_id = raw_chain.subchain_id()
        entity = entities[subchain_id]
        entity_type = entity.entity_type.name

        # Parse a polymer
        if entity_type == "Polymer":
            # Skip PeptideD, DnaRnaHybrid, Pna, Other
            if entity.polymer_type.name not in {"PeptideL", "Dna", "Rna"}:
                continue

            # Add polymer if successful
            parsed_polymer = parse_polymer(
                polymer=raw_chain,
                polymer_type=entity.polymer_type,
                sequence=entity.full_sequence,
                chain_id=subchain_id,
                entity=entity.name,
                mols=mols,
                moldir=moldir,
            )
            if parsed_polymer is not None:
                parsed_chains.append(parsed_polymer)

        # Parse a non-polymer
        elif entity_type in {"NonPolymer", "Branched"}:
            # Skip UNL
            if any(lig.name == "UNL" for lig in raw_chain):
                continue

            # Branched entities are always treated as covalent
            is_covalent = (
                entity_type == "Branched" or subchain_id in covalent_chain_ids
            )

            residues = [
                parse_ccd_residue(
                    name=ligand.name,
                    ref_mol=get_mol(ligand.name, mols, moldir),
                    res_idx=lig_idx,
                    gemmi_mol=ligand,
                    is_covalent=is_covalent,
                )
                for lig_idx, ligand in enumerate(raw_chain)
            ]

            if residues:
                parsed_chains.append(
                    ParsedChain(
                        name=subchain_id,
                        entity=entity.name,
                        residues=residues,
                        type=const.chain_type_ids["NONPOLYMER"],
                    )
                )

    return parsed_chains


def _offset_atom_idxs(constraint, atom_offset: int) -> tuple:
    """Shift a constraint's residue-local atom indices by ``atom_offset``."""
    return tuple(c_atom_idx + atom_offset for c_atom_idx in constraint.atom_idxs)


def parse_mmcif(  # noqa: C901, PLR0915, PLR0912
    path: str,
    mols: Optional[dict[str, Mol]] = None,
    moldir: Optional[str] = None,
    use_assembly: bool = True,
    call_compute_interfaces: bool = True,
) -> ParsedStructure:
    """Parse a structure in MMCIF format.

    The first data block of the file is read, waters, hydrogens,
    alternative conformations and empty chains are removed, and every
    supported chain of every model is parsed. The first model provides
    the reference atom / residue / chain tables; the coordinates of all
    models are stored in the coordinates table, one ensemble entry per
    model.

    Parameters
    ----------
    path : str
        Path to the MMCIF file.
    mols : dict[str, Mol], optional
        The preprocessed PDB components dictionary. Forwarded to
        ``get_mol``; an empty dictionary is used when omitted.
    moldir : str, optional
        Forwarded to ``get_mol`` (presumably the directory holding the
        component molecules -- confirm against ``get_mol``).
    use_assembly : bool
        Whether to expand the structure to its first assembly, when the
        file defines any.
    call_compute_interfaces : bool
        Whether to compute chain interfaces. When False an empty
        interface table is returned.

    Returns
    -------
    ParsedStructure
        The parsed structure.

    Raises
    ------
    ValueError
        If no chains could be parsed, or if a later model does not
        contain the same parsed chains as the first model.

    """
    # Disable rdkit warnings (kept alive for the duration of the call)
    blocker = rdBase.BlockLogs()  # noqa: F841

    # set mols
    mols = {} if mols is None else mols

    # Parse MMCIF input file
    block = gemmi.cif.read(str(path))[0]

    # Extract metadata
    deposit_date, release_date, revision_date = get_dates(block)
    resolution = get_resolution(block)
    method = get_method(block)
    temperature, ph = get_experiment_conditions(block)

    # Load structure object
    structure = gemmi.make_structure_from_block(block)

    # Clean up the structure
    structure.merge_chain_parts()
    structure.remove_waters()
    structure.remove_hydrogens()
    structure.remove_alternative_conformations()
    structure.remove_empty_chains()

    # Expand assembly 1
    if use_assembly and structure.assemblies:
        how = gemmi.HowToNameCopiedChain.AddNumber
        assembly_name = structure.assemblies[0].name
        structure.transform_to_assembly(assembly_name, how=how)

    # Parse entities
    # Create mapping from subchain id to entity
    entities: dict[str, gemmi.Entity] = {}
    entity_ids: dict[str, int] = {}
    for entity_id, entity in enumerate(structure.entities):
        if entity.entity_type.name == "Water":
            continue
        for subchain_id in entity.subchains:
            entities[subchain_id] = entity
            entity_ids[subchain_id] = entity_id

    # Create mapping from chain, residue to subchains
    # since a Connection uses the chains and not subchains
    subchain_map = {}
    for chain in structure[0]:
        for residue in chain:
            seq_id = residue.seqid
            seq_id = str(seq_id.num) + str(seq_id.icode).strip()
            subchain_map[(chain.name, seq_id)] = residue.subchain

    # Find covalent ligands
    covalent_chain_ids = compute_covalent_ligands(
        connections=structure.connections,
        subchain_map=subchain_map,
        entities=entities,
    )

    # Parse chains of the first (reference) model
    chains = _parse_model_chains(
        structure[0], entities, covalent_chain_ids, mols, moldir
    )

    # If no chains parsed fail
    if not chains:
        msg = "No chains parsed!"
        raise ValueError(msg)

    # Parse the remaining models, keeping their chains in the same order
    # as the reference model
    ref_chain_names = [ref_chain.name for ref_chain in chains]
    all_ensembles = [chains]
    for model_idx, struct in enumerate(list(structure)[1:], start=1):
        model_chains = {
            parsed.name: parsed
            for parsed in _parse_model_chains(
                struct, entities, covalent_chain_ids, mols, moldir
            )
        }
        if set(model_chains) != set(ref_chain_names):
            msg = (
                f"Model {model_idx} does not contain the same chains "
                "as the first model!"
            )
            raise ValueError(msg)
        all_ensembles.append([model_chains[name] for name in ref_chain_names])

    # Parse covalent connections
    connections: list[ParsedConnection] = []
    for connection in structure.connections:
        # Skip non-covalent connections
        if connection.type.name != "Covale":
            continue
        # Best effort: a connection that cannot be resolved against the
        # parsed chains (e.g. it touches a skipped chain) is dropped.
        try:
            parsed_connection = parse_connection(
                connection=connection,
                chains=chains,
                subchain_map=subchain_map,
            )
        except Exception:  # noqa: S112, BLE001
            continue
        connections.append(parsed_connection)

    # Create tables
    atom_data = []
    bond_data = []
    res_data = []
    chain_data = []
    ensemble_data = []
    coords_data = defaultdict(list)

    rdkit_bounds_constraint_data = []
    chiral_atom_constraint_data = []
    stereo_bond_constraint_data = []
    planar_bond_constraint_data = []
    planar_ring_5_constraint_data = []
    planar_ring_6_constraint_data = []

    # Convert parsed chains to tables
    atom_idx = 0
    res_idx = 0
    sym_count = {}
    chain_to_idx = {}
    res_to_idx = {}
    chain_to_seq = {}

    for asym_id, chain in enumerate(chains):
        # Compute number of atoms and residues
        res_num = len(chain.residues)
        atom_num = sum(len(res.atoms) for res in chain.residues)

        # Get same chain across models in ensemble
        ensemble_chains = [ensemble[asym_id] for ensemble in all_ensembles]
        for ensemble_chain in ensemble_chains:
            assert len(ensemble_chain.residues) == res_num
            assert sum(len(res.atoms) for res in ensemble_chain.residues) == atom_num

        # Find all copies of this chain in the assembly
        entity_id = entity_ids[chain.name]
        sym_id = sym_count.get(entity_id, 0)
        chain_data.append(
            (
                chain.name,
                chain.type,
                entity_id,
                sym_id,
                asym_id,
                atom_idx,
                atom_num,
                res_idx,
                res_num,
                0,  # last Chain field; its meaning is not visible here
            )
        )
        chain_to_idx[chain.name] = asym_id
        sym_count[entity_id] = sym_id + 1
        if chain.sequence is not None:
            chain_to_seq[chain.name] = chain.sequence

        # Add residue, atom, bond, data
        for i, res in enumerate(chain.residues):
            # Get same residue across models in ensemble
            ensemble_residues = [
                ensemble_chain.residues[i] for ensemble_chain in ensemble_chains
            ]
            for ensemble_res in ensemble_residues:
                assert ensemble_res.name == res.name

            # At this point atom_idx is the global index of the residue's
            # first atom; all residue-local indices are shifted by it.
            res_data.append(
                (
                    res.name,
                    res.type,
                    res.idx,
                    atom_idx,
                    len(res.atoms),
                    atom_idx + res.atom_center,
                    atom_idx + res.atom_disto,
                    res.is_standard,
                    res.is_present,
                )
            )
            res_to_idx[(chain.name, i)] = (res_idx, atom_idx)

            if res.rdkit_bounds_constraints is not None:
                rdkit_bounds_constraint_data.extend(
                    (
                        _offset_atom_idxs(constraint, atom_idx),
                        constraint.is_bond,
                        constraint.is_angle,
                        constraint.upper_bound,
                        constraint.lower_bound,
                    )
                    for constraint in res.rdkit_bounds_constraints
                )
            if res.chiral_atom_constraints is not None:
                chiral_atom_constraint_data.extend(
                    (
                        _offset_atom_idxs(constraint, atom_idx),
                        constraint.is_reference,
                        constraint.is_r,
                    )
                    for constraint in res.chiral_atom_constraints
                )
            if res.stereo_bond_constraints is not None:
                stereo_bond_constraint_data.extend(
                    (
                        _offset_atom_idxs(constraint, atom_idx),
                        constraint.is_check,
                        constraint.is_e,
                    )
                    for constraint in res.stereo_bond_constraints
                )
            if res.planar_bond_constraints is not None:
                planar_bond_constraint_data.extend(
                    (_offset_atom_idxs(constraint, atom_idx),)
                    for constraint in res.planar_bond_constraints
                )
            if res.planar_ring_5_constraints is not None:
                planar_ring_5_constraint_data.extend(
                    (_offset_atom_idxs(constraint, atom_idx),)
                    for constraint in res.planar_ring_5_constraints
                )
            if res.planar_ring_6_constraints is not None:
                planar_ring_6_constraint_data.extend(
                    (_offset_atom_idxs(constraint, atom_idx),)
                    for constraint in res.planar_ring_6_constraints
                )

            # Intra-residue bonds: both ends share chain and residue
            bond_data.extend(
                (
                    asym_id,
                    asym_id,
                    res_idx,
                    res_idx,
                    atom_idx + bond.atom_1,
                    atom_idx + bond.atom_2,
                    bond.type,
                )
                for bond in res.bonds
            )

            for a_idx, atom in enumerate(res.atoms):
                # Get same atom across models in ensemble
                for e_idx, ensemble_res in enumerate(ensemble_residues):
                    ensemble_atom = ensemble_res.atoms[a_idx]
                    assert ensemble_atom.name == atom.name
                    assert atom.is_present == ensemble_atom.is_present

                    coords_data[e_idx].append(ensemble_atom.coords)

                atom_data.append(
                    (
                        atom.name,
                        atom.coords,
                        atom.is_present,
                        atom.bfactor,
                        1.0,  # plddt is 1 for real data
                    )
                )
                atom_idx += 1

            res_idx += 1

    # Create coordinates table: models are concatenated, each ensemble
    # entry records (offset of its first atom, number of atoms)
    coords_per_model = []
    for e_idx in range(len(coords_data)):
        ensemble_data.append((e_idx * atom_idx, atom_idx))
        coords_per_model.append(coords_data[e_idx])
    coords_rows = [(xyz,) for model_coords in coords_per_model for xyz in model_coords]

    # Convert connections to tables
    for conn in connections:
        chain_1_idx = chain_to_idx[conn.chain_1]
        chain_2_idx = chain_to_idx[conn.chain_2]
        res_1_idx, atom_1_offset = res_to_idx[(conn.chain_1, conn.residue_index_1)]
        res_2_idx, atom_2_offset = res_to_idx[(conn.chain_2, conn.residue_index_2)]
        bond_data.append(
            (
                chain_1_idx,
                chain_2_idx,
                res_1_idx,
                res_2_idx,
                atom_1_offset + conn.atom_index_1,
                atom_2_offset + conn.atom_index_2,
                const.bond_type_ids["COVALENT"],
            )
        )

    # Convert into datatypes
    atoms = np.array(atom_data, dtype=AtomV2)
    bonds = np.array(bond_data, dtype=BondV2)
    residues = np.array(res_data, dtype=Residue)
    chain_table = np.array(chain_data, dtype=Chain)
    mask = np.ones(len(chain_data), dtype=bool)
    ensemble = np.array(ensemble_data, dtype=Ensemble)
    coords = np.array(coords_rows, dtype=Coords)
    residue_constraints = ResidueConstraints(
        rdkit_bounds_constraints=np.array(
            rdkit_bounds_constraint_data, dtype=RDKitBoundsConstraint
        ),
        chiral_atom_constraints=np.array(
            chiral_atom_constraint_data, dtype=ChiralAtomConstraint
        ),
        stereo_bond_constraints=np.array(
            stereo_bond_constraint_data, dtype=StereoBondConstraint
        ),
        planar_bond_constraints=np.array(
            planar_bond_constraint_data, dtype=PlanarBondConstraint
        ),
        planar_ring_5_constraints=np.array(
            planar_ring_5_constraint_data, dtype=PlanarRing5Constraint
        ),
        planar_ring_6_constraints=np.array(
            planar_ring_6_constraint_data, dtype=PlanarRing6Constraint
        ),
    )

    # Compute interface chains (find chains with a heavy atom within 5A)
    if call_compute_interfaces:
        interfaces = compute_interfaces(atoms, chain_table)
    else:
        interfaces = np.array([], dtype=Interface)

    # Return parsed structure
    info = StructureInfo(
        deposited=deposit_date,
        revised=revision_date,
        released=release_date,
        resolution=resolution,
        method=method,
        num_chains=len(chain_table),
        num_interfaces=len(interfaces),
        temperature=temperature,
        pH=ph,
    )

    data = StructureV2(
        atoms=atoms,
        bonds=bonds,
        residues=residues,
        chains=chain_table,
        interfaces=interfaces,
        mask=mask,
        ensemble=ensemble,
        coords=coords,
    )

    return ParsedStructure(
        data=data,
        info=info,
        sequences=chain_to_seq,
        residue_constraints=residue_constraints,
    )