boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/data/types.py ADDED
@@ -0,0 +1,777 @@
1
+ import json
2
+ from dataclasses import asdict, dataclass
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+
6
+ import numpy as np
7
+ from mashumaro.mixins.dict import DataClassDictMixin
8
+ from rdkit.Chem import Mol
9
+
10
+ ####################################################################################################
11
+ # SERIALIZABLE
12
+ ####################################################################################################
13
+
14
+
15
class NumpySerializable:
    """Mixin that (de)serializes a dataclass of NumPy arrays as an NPZ file.

    Subclasses are expected to be dataclasses: every dataclass field is
    written as one entry of the NPZ archive under the field's name, and
    loading passes the archive entries back to the constructor as keyword
    arguments.
    """

    @classmethod
    def load(cls: "NumpySerializable", path: Path) -> "NumpySerializable":
        """Load the object from an NPZ file.

        Fields that were ``None`` when dumped are restored as ``None``
        (NumPy stores them as 0-d object arrays).

        Parameters
        ----------
        path : Path
            The path to the file.

        Returns
        -------
        Serializable
            The loaded object.

        """
        # allow_pickle=True is needed because ``None`` fields are saved by
        # np.savez as 0-d object arrays, which NumPy pickles.
        # NOTE(review): unpickling can execute arbitrary code -- only load NPZ
        # files from trusted sources.
        values = {}
        # Using the NpzFile as a context manager closes its file handle;
        # indexing reads each array fully into memory, so the arrays remain
        # valid after the ``with`` block.
        with np.load(path, allow_pickle=True) as data:
            for key in data.files:
                value = data[key]
                # Undo the boxing of ``None`` so optional fields round-trip.
                if (
                    value.dtype == object
                    and value.shape == ()
                    and value.item() is None
                ):
                    value = None
                values[key] = value
        return cls(**values)

    def dump(self, path: Path) -> None:
        """Dump the object to an NPZ file.

        Parameters
        ----------
        path : Path
            The path to the file.

        """
        from dataclasses import fields

        # Shallow name -> value mapping. ``asdict`` would deep-copy every
        # array before writing, doubling peak memory for large objects.
        arrays = {f.name: getattr(self, f.name) for f in fields(self)}
        np.savez_compressed(str(path), **arrays)
45
+
46
+
47
class JSONSerializable(DataClassDictMixin):
    """Mixin that (de)serializes a dataclass as a JSON file.

    Conversion between the object and plain dicts is delegated to
    mashumaro's ``DataClassDictMixin`` (``from_dict`` / ``to_dict``).
    """

    @classmethod
    def load(cls: "JSONSerializable", path: Path) -> "JSONSerializable":
        """Load the object from a JSON file.

        Parameters
        ----------
        path : Path
            The path to the file (read as UTF-8).

        Returns
        -------
        Serializable
            The loaded object.

        """
        # JSON text is UTF-8 by specification; do not depend on the
        # platform's locale encoding.
        with path.open("r", encoding="utf-8") as f:
            return cls.from_dict(json.load(f))

    def dump(self, path: Path) -> None:
        """Dump the object to a JSON file.

        Parameters
        ----------
        path : Path
            The path to the file (written as UTF-8).

        """
        with path.open("w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f)
79
+
80
+
81
+ ####################################################################################################
82
+ # STRUCTURE
83
+ ####################################################################################################
84
+
85
# Structured NumPy dtypes for the tables held by ``Structure`` and
# ``StructureV2``. Each is a list of (field_name, dtype) pairs, usable as the
# ``dtype=`` argument of ``np.array`` / ``np.concatenate`` / ``astype``.

# Atom table used by ``Structure``.
Atom = [
    ("name", np.dtype("4i1")),  # four int8 values; NOTE(review): presumably an encoded atom name -- confirm
    ("element", np.dtype("i1")),
    ("charge", np.dtype("i1")),
    ("coords", np.dtype("3f4")),  # three float32 values
    ("conformer", np.dtype("3f4")),
    ("is_present", np.dtype("?")),
    ("chirality", np.dtype("i1")),
]

# Atom table used by ``StructureV2``: the name is a string of up to 4 chars
# and coordinates also live in the separate ``Coords`` table.
AtomV2 = [
    ("name", np.dtype("<U4")),
    ("coords", np.dtype("3f4")),
    ("is_present", np.dtype("?")),
    ("bfactor", np.dtype("f4")),
    ("plddt", np.dtype("f4")),
]

# Bond table used by ``Structure``; atom_1/atom_2 are row indices into the
# atom table (remapped by ``Structure.remove_invalid_chains``).
Bond = [
    ("atom_1", np.dtype("i4")),
    ("atom_2", np.dtype("i4")),
    ("type", np.dtype("i1")),
]

# Bond table used by ``StructureV2``; in addition to the atom indices it
# records the chain and residue index of both endpoints.
BondV2 = [
    ("chain_1", np.dtype("i4")),
    ("chain_2", np.dtype("i4")),
    ("res_1", np.dtype("i4")),
    ("res_2", np.dtype("i4")),
    ("atom_1", np.dtype("i4")),
    ("atom_2", np.dtype("i4")),
    ("type", np.dtype("i1")),
]

# Residue table: (atom_idx, atom_num) select a contiguous slice of the atom
# table; atom_center and atom_disto are absolute atom indices.
Residue = [
    ("name", np.dtype("<U5")),
    ("res_type", np.dtype("i1")),
    ("res_idx", np.dtype("i4")),
    ("atom_idx", np.dtype("i4")),
    ("atom_num", np.dtype("i4")),
    ("atom_center", np.dtype("i4")),
    ("atom_disto", np.dtype("i4")),
    ("is_standard", np.dtype("?")),
    ("is_present", np.dtype("?")),
]

# Chain table: (res_idx, res_num) select a contiguous slice of the residue
# table and (atom_idx, atom_num) a slice of the atom table.
# ``remove_invalid_chains`` sets asym_id to the chain's position and sym_id to
# its 0-based copy number among chains sharing entity_id.
Chain = [
    ("name", np.dtype("<U5")),
    ("mol_type", np.dtype("i1")),
    ("entity_id", np.dtype("i4")),
    ("sym_id", np.dtype("i4")),
    ("asym_id", np.dtype("i4")),
    ("atom_idx", np.dtype("i4")),
    ("atom_num", np.dtype("i4")),
    ("res_idx", np.dtype("i4")),
    ("res_num", np.dtype("i4")),
    ("cyclic_period", np.dtype("i4")),
]

# Inter-residue connection used by ``Structure``: chain, residue and atom
# indices of both endpoints (same layout as ``BondV2`` without ``type``).
Connection = [
    ("chain_1", np.dtype("i4")),
    ("chain_2", np.dtype("i4")),
    ("res_1", np.dtype("i4")),
    ("res_2", np.dtype("i4")),
    ("atom_1", np.dtype("i4")),
    ("atom_2", np.dtype("i4")),
]

# Interface table: a pair of chains (presumably chain indices -- confirm).
Interface = [
    ("chain_1", np.dtype("i4")),
    ("chain_2", np.dtype("i4")),
]

# Coordinate table used by ``StructureV2``: one xyz row per atom per conformer.
Coords = [
    ("coords", np.dtype("3f4")),
]

# Ensemble table used by ``StructureV2``: (atom_coord_idx, atom_num) select a
# slice of the ``Coords`` table (presumably one row per conformer -- confirm).
Ensemble = [
    ("atom_coord_idx", np.dtype("i4")),
    ("atom_num", np.dtype("i4")),
]
166
+
167
+
168
@dataclass(frozen=True)
class Structure(NumpySerializable):
    """Structure datatype.

    The structure is stored as flat, structured NumPy tables that index into
    one another:

    - each chain selects a contiguous slice of ``residues`` via (``res_idx``, ``res_num``);
    - each residue selects a contiguous slice of ``atoms`` via (``atom_idx``, ``atom_num``);
    - ``mask`` holds one boolean per chain, and a falsy entry marks a chain
      that ``remove_invalid_chains`` drops.
    """

    atoms: np.ndarray  # table with dtype ``Atom``
    bonds: np.ndarray  # table with dtype ``Bond``; atom_1/atom_2 are atom indices
    residues: np.ndarray  # table with dtype ``Residue``
    chains: np.ndarray  # table with dtype ``Chain``
    connections: np.ndarray  # table with dtype ``Connection``
    interfaces: np.ndarray  # table with dtype ``Interface``
    mask: np.ndarray  # one boolean per chain

    @classmethod
    def load(cls: "Structure", path: Path) -> "Structure":
        """Load a structure from an NPZ file.

        Unlike the generic loader this does not enable pickle loading, and it
        casts ``connections`` to the ``Connection`` dtype.

        Parameters
        ----------
        path : Path
            The path to the file.

        Returns
        -------
        Structure
            The loaded structure.

        """
        # The NpzFile keeps its file open until closed. Indexing reads each
        # array fully into memory, so the arrays stay valid after the ``with``.
        with np.load(path) as structure:
            return cls(
                atoms=structure["atoms"],
                bonds=structure["bonds"],
                residues=structure["residues"],
                chains=structure["chains"],
                connections=structure["connections"].astype(Connection),
                interfaces=structure["interfaces"],
                mask=structure["mask"],
            )

    def remove_invalid_chains(self) -> "Structure":  # noqa: PLR0915
        """Remove masked chains and re-index the remaining tables.

        Chains whose ``mask`` entry is falsy are dropped together with their
        residues and atoms. Surviving rows are renumbered so that all index
        references are contiguous again:

        - ``asym_id`` becomes the chain's new position;
        - ``sym_id`` becomes the 0-based copy number within its ``entity_id``.

        Bonds and connections are kept only when both endpoint atoms survive.
        Interfaces are not carried over (an empty table is returned), and the
        new mask is all ``True``. If no atoms survive, empty tables are
        returned instead of raising.

        Returns
        -------
        Structure
            The structure with masked chains removed.

        """
        entity_counter = {}
        atom_idx, res_idx, chain_idx = 0, 0, 0
        atoms, residues, chains = [], [], []
        # Old index -> new index for atoms, residues and chains.
        atom_map, res_map, chain_map = {}, {}, {}
        for i, chain in enumerate(self.chains):
            # Skip masked chains
            if not self.mask[i]:
                continue

            # Count copies of each entity to assign sym_id
            entity_id = chain["entity_id"]
            if entity_id not in entity_counter:
                entity_counter[entity_id] = 0
            else:
                entity_counter[entity_id] += 1

            # Update the chain
            new_chain = chain.copy()
            new_chain["atom_idx"] = atom_idx
            new_chain["res_idx"] = res_idx
            new_chain["asym_id"] = chain_idx
            new_chain["sym_id"] = entity_counter[entity_id]
            chains.append(new_chain)
            chain_map[i] = chain_idx
            chain_idx += 1

            # Add the chain residues
            res_start = chain["res_idx"]
            res_end = chain["res_idx"] + chain["res_num"]
            for j, res in enumerate(self.residues[res_start:res_end]):
                # Shift the residue's atom pointers by the same offset as its
                # first atom.
                new_res = res.copy()
                new_res["atom_idx"] = atom_idx
                new_res["atom_center"] = (
                    atom_idx + new_res["atom_center"] - res["atom_idx"]
                )
                new_res["atom_disto"] = (
                    atom_idx + new_res["atom_disto"] - res["atom_idx"]
                )
                residues.append(new_res)
                res_map[res_start + j] = res_idx
                res_idx += 1

                # Update the atoms
                start = res["atom_idx"]
                end = res["atom_idx"] + res["atom_num"]
                atoms.append(self.atoms[start:end])
                atom_map.update({k: atom_idx + k - start for k in range(start, end)})
                atom_idx += res["atom_num"]

        # Concatenate the tables. np.concatenate raises on an empty list,
        # which happens when every chain is masked or has no residues.
        if atoms:
            atoms = np.concatenate(atoms, dtype=Atom)
        else:
            atoms = np.empty(0, dtype=Atom)
        residues = np.array(residues, dtype=Residue)
        chains = np.array(chains, dtype=Chain)

        # Update bonds: keep only bonds whose endpoints both survived
        bonds = []
        for bond in self.bonds:
            atom_1 = bond["atom_1"]
            atom_2 = bond["atom_2"]
            if (atom_1 in atom_map) and (atom_2 in atom_map):
                new_bond = bond.copy()
                new_bond["atom_1"] = atom_map[atom_1]
                new_bond["atom_2"] = atom_map[atom_2]
                bonds.append(new_bond)

        # Update connections: surviving atoms imply surviving residues/chains
        connections = []
        for connection in self.connections:
            chain_1 = connection["chain_1"]
            chain_2 = connection["chain_2"]
            res_1 = connection["res_1"]
            res_2 = connection["res_2"]
            atom_1 = connection["atom_1"]
            atom_2 = connection["atom_2"]
            if (atom_1 in atom_map) and (atom_2 in atom_map):
                new_connection = connection.copy()
                new_connection["chain_1"] = chain_map[chain_1]
                new_connection["chain_2"] = chain_map[chain_2]
                new_connection["res_1"] = res_map[res_1]
                new_connection["res_2"] = res_map[res_2]
                new_connection["atom_1"] = atom_map[atom_1]
                new_connection["atom_2"] = atom_map[atom_2]
                connections.append(new_connection)

        # Create arrays
        bonds = np.array(bonds, dtype=Bond)
        connections = np.array(connections, dtype=Connection)
        interfaces = np.array([], dtype=Interface)
        mask = np.ones(len(chains), dtype=bool)

        return Structure(
            atoms=atoms,
            bonds=bonds,
            residues=residues,
            chains=chains,
            connections=connections,
            interfaces=interfaces,
            mask=mask,
        )
320
+
321
+
322
@dataclass(frozen=True)
class StructureV2(NumpySerializable):
    """Structure datatype (version 2).

    Uses the same chain -> residue -> atom slicing scheme as ``Structure``,
    with ``AtomV2``/``BondV2`` tables (bonds carry chain and residue indices
    for both endpoints), a separate ``Coords`` table, and an ``Ensemble``
    table that indexes into it.
    """

    atoms: np.ndarray  # table with dtype ``AtomV2``
    bonds: np.ndarray  # table with dtype ``BondV2``
    residues: np.ndarray  # table with dtype ``Residue``
    chains: np.ndarray  # table with dtype ``Chain``
    interfaces: np.ndarray  # table with dtype ``Interface``
    mask: np.ndarray  # one boolean per chain; falsy marks a chain to drop
    coords: np.ndarray  # table with dtype ``Coords``
    ensemble: np.ndarray  # table with dtype ``Ensemble``
    pocket: Optional[np.ndarray] = None

    def remove_invalid_chains(self) -> "StructureV2":  # noqa: PLR0915
        """Remove masked chains and re-index the remaining tables.

        Chains whose ``mask`` entry is falsy are dropped together with their
        residues and atoms. Surviving rows are renumbered so that all index
        references are contiguous again:

        - ``asym_id`` becomes the chain's new position;
        - ``sym_id`` becomes the 0-based copy number within its ``entity_id``.

        Bonds are kept only when both endpoint atoms survive. Interfaces are
        not carried over (an empty table is returned), and the new mask is all
        ``True``. ``coords`` is rebuilt from the surviving atoms' ``coords``
        field with a single ensemble entry spanning it. ``pocket`` is not
        passed on, so the result has ``pocket=None``. If no atoms survive,
        empty tables are returned instead of raising.

        Returns
        -------
        StructureV2
            The structure with masked chains removed.

        """
        entity_counter = {}
        atom_idx, res_idx, chain_idx = 0, 0, 0
        atoms, residues, chains = [], [], []
        # Old index -> new index for atoms, residues and chains.
        atom_map, res_map, chain_map = {}, {}, {}
        for i, chain in enumerate(self.chains):
            # Skip masked chains
            if not self.mask[i]:
                continue

            # Count copies of each entity to assign sym_id
            entity_id = chain["entity_id"]
            if entity_id not in entity_counter:
                entity_counter[entity_id] = 0
            else:
                entity_counter[entity_id] += 1

            # Update the chain
            new_chain = chain.copy()
            new_chain["atom_idx"] = atom_idx
            new_chain["res_idx"] = res_idx
            new_chain["asym_id"] = chain_idx
            new_chain["sym_id"] = entity_counter[entity_id]
            chains.append(new_chain)
            chain_map[i] = chain_idx
            chain_idx += 1

            # Add the chain residues
            res_start = chain["res_idx"]
            res_end = chain["res_idx"] + chain["res_num"]
            for j, res in enumerate(self.residues[res_start:res_end]):
                # Shift the residue's atom pointers by the same offset as its
                # first atom.
                new_res = res.copy()
                new_res["atom_idx"] = atom_idx
                new_res["atom_center"] = (
                    atom_idx + new_res["atom_center"] - res["atom_idx"]
                )
                new_res["atom_disto"] = (
                    atom_idx + new_res["atom_disto"] - res["atom_idx"]
                )
                residues.append(new_res)
                res_map[res_start + j] = res_idx
                res_idx += 1

                # Update the atoms
                start = res["atom_idx"]
                end = res["atom_idx"] + res["atom_num"]
                atoms.append(self.atoms[start:end])
                atom_map.update({k: atom_idx + k - start for k in range(start, end)})
                atom_idx += res["atom_num"]

        # Concatenate the tables. np.concatenate raises on an empty list,
        # which happens when every chain is masked or has no residues.
        if atoms:
            atoms = np.concatenate(atoms, dtype=AtomV2)
        else:
            atoms = np.empty(0, dtype=AtomV2)
        residues = np.array(residues, dtype=Residue)
        chains = np.array(chains, dtype=Chain)

        # Update bonds: keep only bonds whose endpoints both survived
        bonds = []
        for bond in self.bonds:
            chain_1 = bond["chain_1"]
            chain_2 = bond["chain_2"]
            res_1 = bond["res_1"]
            res_2 = bond["res_2"]
            atom_1 = bond["atom_1"]
            atom_2 = bond["atom_2"]
            if (atom_1 in atom_map) and (atom_2 in atom_map):
                new_bond = bond.copy()
                new_bond["chain_1"] = chain_map[chain_1]
                new_bond["chain_2"] = chain_map[chain_2]
                new_bond["res_1"] = res_map[res_1]
                new_bond["res_2"] = res_map[res_2]
                new_bond["atom_1"] = atom_map[atom_1]
                new_bond["atom_2"] = atom_map[atom_2]
                bonds.append(new_bond)

        # Create arrays
        bonds = np.array(bonds, dtype=BondV2)
        interfaces = np.array([], dtype=Interface)
        mask = np.ones(len(chains), dtype=bool)

        # Single conformer built from the surviving atoms' coordinates.
        # Filling the field directly avoids a Python tuple per atom.
        coords = np.empty(len(atoms), dtype=Coords)
        coords["coords"] = atoms["coords"]
        ensemble = np.array([(0, len(coords))], dtype=Ensemble)

        # NOTE(review): ``pocket`` is not passed on, so the result always has
        # pocket=None. Confirm this is intended before relying on it.
        return StructureV2(
            atoms=atoms,
            bonds=bonds,
            residues=residues,
            chains=chains,
            interfaces=interfaces,
            mask=mask,
            coords=coords,
            ensemble=ensemble,
        )
442
+
443
+
444
+ ####################################################################################################
445
+ # MSA
446
+ ####################################################################################################
447
+
448
+
449
# Structured NumPy dtypes for the tables held by ``MSA``.

# One row per residue of the aligned sequences.
MSAResidue = [
    ("res_type", np.dtype("i1")),
]

# One row per recorded deletion. NOTE(review): presumably res_idx is the
# position in the sequence and deletion the number of deleted residues there.
MSADeletion = [
    ("res_idx", np.dtype("i2")),
    ("deletion", np.dtype("i2")),
]

# One row per aligned sequence. (res_start, res_end) and (del_start, del_end)
# presumably delimit this sequence's rows in the residue and deletion tables
# -- TODO confirm against the MSA parser.
MSASequence = [
    ("seq_idx", np.dtype("i2")),
    ("taxonomy", np.dtype("i4")),
    ("res_start", np.dtype("i4")),
    ("res_end", np.dtype("i4")),
    ("del_start", np.dtype("i4")),
    ("del_end", np.dtype("i4")),
]


@dataclass(frozen=True)
class MSA(NumpySerializable):
    """MSA datatype.

    A multiple sequence alignment stored as three flat structured arrays and
    (de)serialized as an NPZ archive through ``NumpySerializable``.
    """

    sequences: np.ndarray  # presumably dtype ``MSASequence``
    deletions: np.ndarray  # presumably dtype ``MSADeletion``
    residues: np.ndarray  # presumably dtype ``MSAResidue``
475
+
476
+
477
+ ####################################################################################################
478
+ # RECORD
479
+ ####################################################################################################
480
+
481
+
482
@dataclass(frozen=True)
class StructureInfo:
    """StructureInfo datatype.

    Optional per-structure metadata; every field defaults to ``None`` when
    unknown.
    """

    resolution: Optional[float] = None
    method: Optional[str] = None  # presumably the experimental method
    deposited: Optional[str] = None  # date string; format not fixed here
    released: Optional[str] = None  # date string; format not fixed here
    revised: Optional[str] = None  # date string; format not fixed here
    num_chains: Optional[int] = None
    num_interfaces: Optional[int] = None
    pH: Optional[float] = None
    temperature: Optional[float] = None
495
+
496
+
497
@dataclass(frozen=False)
class ChainInfo:
    """ChainInfo datatype.

    Per-chain metadata stored in a ``Record``. Unlike the other record
    types this dataclass is mutable (``frozen=False``), presumably so that
    fields such as ``valid`` can be updated after construction -- confirm.
    """

    chain_id: int
    chain_name: str
    mol_type: int
    cluster_id: Union[str, int]
    msa_id: Union[str, int]
    num_residues: int
    valid: bool = True
    entity_id: Optional[Union[str, int]] = None
509
+
510
+
511
@dataclass(frozen=True)
class InterfaceInfo:
    """InterfaceInfo datatype.

    A pair of chains forming an interface in a ``Record``
    (presumably identified by ``ChainInfo.chain_id`` -- confirm).
    """

    chain_1: int
    chain_2: int
    valid: bool = True
518
+
519
+
520
@dataclass(frozen=True)
class InferenceOptions:
    """InferenceOptions datatype.

    Optional options attached to a ``Record`` for inference.
    """

    # List of (int, list of (int, int) pairs, float) entries.
    # NOTE(review): presumably (binder chain, contact residues, distance
    # cutoff) -- confirm against the schema parser.
    pocket_constraints: Optional[list[tuple[int, list[tuple[int, int]], float]]] = None
525
+
526
+
527
@dataclass(frozen=True)
class MDInfo:
    """MDInfo datatype.

    Metadata about a molecular-dynamics simulation. All fields except
    ``coarse_grained`` must be passed explicitly (``None`` when unknown).
    """

    forcefield: Optional[list[str]]
    temperature: Optional[float]  # Kelvin
    pH: Optional[float]
    pressure: Optional[float]  # bar
    prod_ensemble: Optional[str]
    solvent: Optional[str]
    ion_concentration: Optional[float]  # mM
    time_step: Optional[float]  # fs
    sample_frame_time: Optional[float]  # ps
    sim_time: Optional[float]  # ns
    coarse_grained: Optional[bool] = False
542
+
543
+
544
@dataclass(frozen=True)
class TemplateInfo:
    """TemplateInfo datatype.

    Describes a structural template for a query chain.

    The (``query_st``, ``query_en``) range of ``query_chain`` presumably
    corresponds to the (``template_st``, ``template_en``) range of
    ``template_chain``. NOTE(review): whether the end indices are inclusive
    is not visible here -- confirm against the template featurizer.
    """

    name: str  # presumably the template identifier -- confirm
    query_chain: str
    query_st: int
    query_en: int
    template_chain: str
    template_st: int
    template_en: int
555
+
556
+
557
@dataclass(frozen=True)
class AffinityInfo:
    """AffinityInfo datatype.

    Identifies the chain targeted by affinity prediction in a ``Record``.
    """

    chain_id: int
    mw: float  # presumably the molecular weight of that chain -- confirm
563
+
564
+
565
@dataclass(frozen=True)
class Record(JSONSerializable):
    """Record datatype.

    Metadata for one target, (de)serialized as JSON through
    ``JSONSerializable``.
    """

    id: str
    structure: StructureInfo
    chains: list[ChainInfo]
    interfaces: list[InterfaceInfo]
    inference_options: Optional[InferenceOptions] = None
    templates: Optional[list[TemplateInfo]] = None
    md: Optional[MDInfo] = None
    affinity: Optional[AffinityInfo] = None
577
+
578
+
579
+ ####################################################################################################
580
+ # RESIDUE CONSTRAINTS
581
+ ####################################################################################################
582
+
583
+
584
# Structured NumPy dtypes for the tables held by ``ResidueConstraints``.
# ``atom_idxs`` holds a fixed number of int32 atom indices per constraint.

# Distance bounds between two atoms; is_bond / is_angle flag the constraint
# kind.
RDKitBoundsConstraint = [
    ("atom_idxs", np.dtype("2i4")),
    ("is_bond", np.dtype("?")),
    ("is_angle", np.dtype("?")),
    ("upper_bound", np.dtype("f4")),
    ("lower_bound", np.dtype("f4")),
]

# Chirality constraint over four atoms; is_r presumably selects R vs S.
ChiralAtomConstraint = [
    ("atom_idxs", np.dtype("4i4")),
    ("is_reference", np.dtype("?")),
    ("is_r", np.dtype("?")),
]

# Double-bond stereo constraint over four atoms; is_e presumably selects E vs Z.
StereoBondConstraint = [
    ("atom_idxs", np.dtype("4i4")),
    ("is_reference", np.dtype("?")),
    ("is_e", np.dtype("?")),
]

# Planarity constraint over six atoms around a bond.
PlanarBondConstraint = [
    ("atom_idxs", np.dtype("6i4")),
]

# Planarity constraint over the five atoms of a 5-membered ring.
PlanarRing5Constraint = [
    ("atom_idxs", np.dtype("5i4")),
]

# Planarity constraint over the six atoms of a 6-membered ring.
PlanarRing6Constraint = [
    ("atom_idxs", np.dtype("6i4")),
]
615
+
616
+
617
@dataclass(frozen=True)
class ResidueConstraints(NumpySerializable):
    """ResidueConstraints datatype.

    Geometric constraints stored as flat structured arrays and
    (de)serialized as an NPZ archive through ``NumpySerializable``.
    """

    rdkit_bounds_constraints: np.ndarray  # presumably dtype ``RDKitBoundsConstraint``
    chiral_atom_constraints: np.ndarray  # presumably dtype ``ChiralAtomConstraint``
    stereo_bond_constraints: np.ndarray  # presumably dtype ``StereoBondConstraint``
    planar_bond_constraints: np.ndarray  # presumably dtype ``PlanarBondConstraint``
    planar_ring_5_constraints: np.ndarray  # presumably dtype ``PlanarRing5Constraint``
    planar_ring_6_constraints: np.ndarray  # presumably dtype ``PlanarRing6Constraint``
627
+
628
+
629
+ ####################################################################################################
630
+ # TARGET
631
+ ####################################################################################################
632
+
633
+
634
@dataclass(frozen=True)
class Target:
    """Target datatype.

    Bundles a ``Record`` with its ``Structure`` and optional auxiliary data.
    """

    record: Record
    structure: Structure
    sequences: Optional[dict[str, str]] = None
    residue_constraints: Optional[ResidueConstraints] = None
    templates: Optional[dict[str, StructureV2]] = None  # keyed by name (presumably template name)
    extra_mols: Optional[dict[str, Mol]] = None  # RDKit molecules keyed by string
644
+
645
+
646
@dataclass(frozen=True)
class Manifest(JSONSerializable):
    """Manifest datatype: a collection of ``Record`` objects."""

    records: list[Record]

    @classmethod
    def load(cls: "Manifest", path: Path) -> "Manifest":
        """Load a manifest from a JSON file.

        Two on-disk layouts are accepted:

        - a JSON object, i.e. the serialized ``Manifest`` (``{"records": [...]}``);
        - a JSON array of serialized ``Record`` objects.

        Parameters
        ----------
        path : Path
            The path to the file (read as UTF-8).

        Returns
        -------
        Manifest
            The loaded manifest.

        Raises
        ------
        TypeError
            If the top-level JSON value is neither an object nor an array.

        """
        # JSON text is UTF-8 by specification; do not depend on the
        # platform's locale encoding.
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)

        if isinstance(data, dict):
            return cls.from_dict(data)
        if isinstance(data, list):
            return cls(records=[Record.from_dict(r) for r in data])

        msg = "Invalid manifest file."
        raise TypeError(msg)
684
+
685
+
686
+ ####################################################################################################
687
+ # INPUT
688
+ ####################################################################################################
689
+
690
+
691
@dataclass(frozen=True, slots=True)
class Input:
    """Input datatype.

    A structure plus its MSAs and optional auxiliary data. ``slots=True``
    requires Python 3.10 or newer.
    """

    structure: Structure
    msa: dict[str, MSA]  # MSAs keyed by string id
    record: Optional[Record] = None
    residue_constraints: Optional[ResidueConstraints] = None
    templates: Optional[dict[str, StructureV2]] = None
    extra_mols: Optional[dict[str, Mol]] = None  # RDKit molecules keyed by string
701
+
702
+
703
+ ####################################################################################################
704
+ # TOKENS
705
+ ####################################################################################################
706
+
707
# Structured NumPy dtypes for tokenized inputs.

# Token table (v1): (atom_idx, atom_num) presumably select the token's atoms
# and res_idx its residue; center/disto fields mirror the residue's
# atom_center/atom_disto -- confirm against the tokenizer.
Token = [
    ("token_idx", np.dtype("i4")),
    ("atom_idx", np.dtype("i4")),
    ("atom_num", np.dtype("i4")),
    ("res_idx", np.dtype("i4")),
    ("res_type", np.dtype("i1")),
    ("sym_id", np.dtype("i4")),
    ("asym_id", np.dtype("i4")),
    ("entity_id", np.dtype("i4")),
    ("mol_type", np.dtype("i1")),
    ("center_idx", np.dtype("i4")),
    ("disto_idx", np.dtype("i4")),
    ("center_coords", np.dtype("3f4")),
    ("disto_coords", np.dtype("3f4")),
    ("resolved_mask", np.dtype("?")),
    ("disto_mask", np.dtype("?")),
    ("cyclic_period", np.dtype("i4")),
]

# Bond between two tokens (v1), presumably by token index.
TokenBond = [
    ("token_1", np.dtype("i4")),
    ("token_2", np.dtype("i4")),
]


# Token table (v2): adds the residue name, a modified flag, a per-token frame
# (9 float32 rotation values + 3 float32 translation values, with a mask) and
# an affinity mask.
TokenV2 = [
    ("token_idx", np.dtype("i4")),
    ("atom_idx", np.dtype("i4")),
    ("atom_num", np.dtype("i4")),
    ("res_idx", np.dtype("i4")),
    ("res_type", np.dtype("i4")),
    ("res_name", np.dtype("<U8")),
    ("sym_id", np.dtype("i4")),
    ("asym_id", np.dtype("i4")),
    ("entity_id", np.dtype("i4")),
    ("mol_type", np.dtype("i4")),  # the total bytes need to be divisible by 4
    ("center_idx", np.dtype("i4")),
    ("disto_idx", np.dtype("i4")),
    ("center_coords", np.dtype("3f4")),
    ("disto_coords", np.dtype("3f4")),
    ("resolved_mask", np.dtype("?")),
    ("disto_mask", np.dtype("?")),
    ("modified", np.dtype("?")),
    ("frame_rot", np.dtype("9f4")),
    ("frame_t", np.dtype("3f4")),
    ("frame_mask", np.dtype("i4")),
    ("cyclic_period", np.dtype("i4")),
    ("affinity_mask", np.dtype("?")),
]

# Bond between two tokens (v2), with a bond type.
TokenBondV2 = [
    ("token_1", np.dtype("i4")),
    ("token_2", np.dtype("i4")),
    ("type", np.dtype("i1")),
]
762
+
763
+
764
@dataclass(frozen=True)
class Tokenized:
    """Tokenized datatype.

    The output of tokenization: token and token-bond tables together with
    the source structure, MSAs and optional auxiliary data.
    """

    tokens: np.ndarray  # presumably dtype ``Token`` or ``TokenV2``
    bonds: np.ndarray  # presumably dtype ``TokenBond`` or ``TokenBondV2``
    structure: Structure
    msa: dict[str, MSA]
    record: Optional[Record] = None
    residue_constraints: Optional[ResidueConstraints] = None
    templates: Optional[dict[str, StructureV2]] = None
    template_tokens: Optional[dict[str, np.ndarray]] = None  # presumably keyed like ``templates``
    template_bonds: Optional[dict[str, np.ndarray]] = None  # presumably keyed like ``templates``
    extra_mols: Optional[dict[str, Mol]] = None