openprotein-python 0.8.2__1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84) hide show
  1. openprotein/__init__.py +164 -0
  2. openprotein/_version.py +48 -0
  3. openprotein/align/__init__.py +8 -0
  4. openprotein/align/align.py +395 -0
  5. openprotein/align/api.py +428 -0
  6. openprotein/align/future.py +55 -0
  7. openprotein/align/msa.py +129 -0
  8. openprotein/align/schemas.py +165 -0
  9. openprotein/base.py +181 -0
  10. openprotein/chains.py +88 -0
  11. openprotein/common/__init__.py +5 -0
  12. openprotein/common/features.py +7 -0
  13. openprotein/common/model_metadata.py +33 -0
  14. openprotein/common/reduction.py +8 -0
  15. openprotein/config.py +9 -0
  16. openprotein/csv.py +31 -0
  17. openprotein/data/__init__.py +9 -0
  18. openprotein/data/api.py +218 -0
  19. openprotein/data/assaydataset.py +178 -0
  20. openprotein/data/data.py +93 -0
  21. openprotein/data/schemas.py +27 -0
  22. openprotein/design/__init__.py +16 -0
  23. openprotein/design/api.py +259 -0
  24. openprotein/design/design.py +125 -0
  25. openprotein/design/future.py +146 -0
  26. openprotein/design/schemas.py +607 -0
  27. openprotein/embeddings/__init__.py +27 -0
  28. openprotein/embeddings/api.py +619 -0
  29. openprotein/embeddings/embeddings.py +151 -0
  30. openprotein/embeddings/esm.py +33 -0
  31. openprotein/embeddings/future.py +146 -0
  32. openprotein/embeddings/models.py +421 -0
  33. openprotein/embeddings/openprotein.py +21 -0
  34. openprotein/embeddings/poet.py +446 -0
  35. openprotein/embeddings/poet2.py +505 -0
  36. openprotein/embeddings/schemas.py +78 -0
  37. openprotein/errors.py +76 -0
  38. openprotein/fasta.py +92 -0
  39. openprotein/fold/__init__.py +21 -0
  40. openprotein/fold/alphafold2.py +131 -0
  41. openprotein/fold/api.py +287 -0
  42. openprotein/fold/boltz.py +691 -0
  43. openprotein/fold/esmfold.py +54 -0
  44. openprotein/fold/fold.py +107 -0
  45. openprotein/fold/future.py +509 -0
  46. openprotein/fold/models.py +139 -0
  47. openprotein/fold/schemas.py +39 -0
  48. openprotein/jobs/__init__.py +9 -0
  49. openprotein/jobs/api.py +71 -0
  50. openprotein/jobs/futures.py +746 -0
  51. openprotein/jobs/jobs.py +69 -0
  52. openprotein/jobs/schemas.py +135 -0
  53. openprotein/models/__init__.py +4 -0
  54. openprotein/models/base.py +63 -0
  55. openprotein/models/foundation/rfdiffusion.py +283 -0
  56. openprotein/models/models.py +33 -0
  57. openprotein/predictor/__init__.py +25 -0
  58. openprotein/predictor/api.py +384 -0
  59. openprotein/predictor/models.py +374 -0
  60. openprotein/predictor/prediction.py +79 -0
  61. openprotein/predictor/predictor.py +242 -0
  62. openprotein/predictor/schemas.py +113 -0
  63. openprotein/predictor/validate.py +40 -0
  64. openprotein/prompt/__init__.py +9 -0
  65. openprotein/prompt/api.py +505 -0
  66. openprotein/prompt/models.py +142 -0
  67. openprotein/prompt/prompt.py +130 -0
  68. openprotein/prompt/schemas.py +49 -0
  69. openprotein/protein.py +587 -0
  70. openprotein/svd/__init__.py +9 -0
  71. openprotein/svd/api.py +206 -0
  72. openprotein/svd/models.py +288 -0
  73. openprotein/svd/schemas.py +31 -0
  74. openprotein/svd/svd.py +134 -0
  75. openprotein/umap/__init__.py +9 -0
  76. openprotein/umap/api.py +259 -0
  77. openprotein/umap/models.py +211 -0
  78. openprotein/umap/schemas.py +35 -0
  79. openprotein/umap/umap.py +175 -0
  80. openprotein/utils/uuid.py +29 -0
  81. openprotein_python-0.8.2.dist-info/METADATA +176 -0
  82. openprotein_python-0.8.2.dist-info/RECORD +84 -0
  83. openprotein_python-0.8.2.dist-info/WHEEL +4 -0
  84. openprotein_python-0.8.2.dist-info/licenses/LICENSE.txt +30 -0
@@ -0,0 +1,691 @@
1
+ """Community-based Boltz models for complex structure prediction with ligands/dna/rna."""
2
+
3
+ import re
4
+ import string
5
+ from typing import Any
6
+
7
+ from pydantic import BaseModel, Field, TypeAdapter, model_validator
8
+
9
+ from openprotein.align import AlignAPI, MSAFuture
10
+ from openprotein.base import APISession
11
+ from openprotein.chains import DNA, RNA, Ligand
12
+ from openprotein.common import ModelMetadata
13
+ from openprotein.protein import Protein
14
+
15
+ from . import api
16
+ from .future import FoldComplexResultFuture
17
+ from .models import FoldModel
18
+
19
# Chain IDs are either 1-5 uppercase ASCII letters or 1-5 ASCII digits.
# `[0-9]` is used instead of `\d` because, for `str` patterns, `\d` also matches
# non-ASCII Unicode decimal digits (e.g. Arabic-Indic digits).
valid_id_pattern = re.compile(r"^[A-Z]{1,5}$|^[0-9]{1,5}$")


def is_valid_id(id_str: str) -> bool:
    """
    Check whether ``id_str`` is a valid chain ID.

    A valid ID consists solely of 1-5 uppercase ASCII letters, or solely of
    1-5 ASCII digits. Mixed letter/digit IDs are rejected.

    Parameters
    ----------
    id_str : str
        Candidate chain ID. Falsy values (empty string, ``None``) are rejected.

    Returns
    -------
    bool
        True if ``id_str`` matches the pattern, False otherwise.
    """
    if not id_str:
        return False
    # The pattern itself bounds the length at 5, so no separate length check is needed.
    return valid_id_pattern.fullmatch(id_str) is not None
29
+
30
+
31
+ def id_generator(used_ids: list[str] | None = None, max_alpha_len=5, max_numeric=99999):
32
+ """
33
+ Yields new chain IDs, skipping any in 'used_ids'.
34
+ First A..Z, AA..ZZ, … up to max_alpha_len, then '1','2',… up to max_numeric.
35
+ """
36
+ used = set(tuple(used_ids or []))
37
+ letters = list(string.ascii_uppercase)
38
+
39
+ # --- Alphabetic IDs ---
40
+ curr_len = 1
41
+ curr_indices = [0] * curr_len # start at 'A'
42
+
43
+ def bump_indices():
44
+ # lexicographically increment curr_indices; return False on overflow
45
+ for i in reversed(range(len(curr_indices))):
46
+ if curr_indices[i] < len(letters) - 1:
47
+ curr_indices[i] += 1
48
+ for j in range(i + 1, len(curr_indices)):
49
+ curr_indices[j] = 0
50
+ return True
51
+ return False
52
+
53
+ while curr_len <= max_alpha_len:
54
+ candidate = "".join(letters[i] for i in curr_indices)
55
+ if candidate not in used:
56
+ used.add(candidate)
57
+ yield candidate
58
+ # bump
59
+ if not bump_indices():
60
+ curr_len += 1
61
+ if curr_len > max_alpha_len:
62
+ break
63
+ curr_indices = [0] * curr_len
64
+
65
+ # --- Numeric IDs ---
66
+ num = 1
67
+ while num <= max_numeric:
68
+ candidate = str(num)
69
+ num += 1
70
+ if candidate not in used:
71
+ used.add(candidate)
72
+ yield candidate
73
+
74
+ # exhausted
75
+ raise RuntimeError("exhausted all possible IDs")
76
+
77
+
78
class BoltzModel(FoldModel):
    """
    Class providing inference endpoints for Boltz structure prediction models.
    """

    model_id: str = "boltz"

    def __init__(
        self,
        session: APISession,
        model_id: str,
        metadata: ModelMetadata | None = None,
    ):
        """
        Initialize the model by delegating to the parent class initializer.

        Parameters
        ----------
        session : APISession
            Session passed through to the parent class.
        model_id : str
            Model identifier passed through to the parent class.
        metadata : ModelMetadata | None
            Optional model metadata passed through to the parent class.
        """
        super().__init__(session, model_id, metadata)

    @staticmethod
    def _explicit_chain_ids(chains) -> list[str]:
        """Return every chain ID explicitly set (as a str or list) on ``chains``."""
        collected: list[str] = []
        for chain in chains:
            chain_id = chain.chain_id
            if isinstance(chain_id, str):
                collected.append(chain_id)
            elif isinstance(chain_id, list):
                collected.extend(chain_id)
        return collected

    def fold(
        self,
        proteins: list[Protein] | MSAFuture | None = None,
        dnas: list[DNA] | None = None,
        rnas: list[RNA] | None = None,
        ligands: list[Ligand] | None = None,
        diffusion_samples: int = 1,
        recycling_steps: int = 3,
        sampling_steps: int = 200,
        step_scale: float = 1.638,
        use_potentials: bool = False,
        constraints: list[dict] | None = None,
        **kwargs,
    ) -> FoldComplexResultFuture:
        """
        Request structure prediction with boltz model.

        Parameters
        ----------
        proteins : list[Protein] | MSAFuture | None
            List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer; identical query sequences are merged into one entry carrying one chain ID per copy.
        dnas : list[DNA] | None
            List of DNA sequences to include in folded output.
        rnas : list[RNA] | None
            List of RNA sequences to include in folded output.
        ligands : list[Ligand] | None
            List of ligands to include in folded output.
        diffusion_samples : int
            Number of diffusion samples to use.
        recycling_steps : int
            Number of recycling steps to use.
        sampling_steps : int
            Number of sampling steps to use.
        step_scale : float
            Scaling factor for diffusion steps.
        use_potentials : bool
            Whether or not to use potentials.
        constraints : list[dict] | None
            List of constraints. Each entry is validated against `BoltzConstraint`; the original dicts (not the validated models) are sent with the request.
        **kwargs
            Additional keyword arguments forwarded unchanged to `api.fold_models_post`.

        Returns
        -------
        FoldComplexResultFuture
            Future for the folding complex result.

        Raises
        ------
        ValueError
            If a protein has no `.msa` set, or if no proteins, DNA, RNA or
            ligands were supplied at all.
        """
        # validate constraints (result discarded; validation errors propagate to the caller)
        if constraints is not None:
            TypeAdapter(list[BoltzConstraint]).validate_python(constraints)

        # collate the ids already claimed so generated ids never collide with them
        used_ids: list[str] = []
        if isinstance(proteins, list):
            used_ids.extend(
                self._explicit_chain_ids(
                    protein for protein in proteins if isinstance(protein, Protein)
                )
            )
        for chains in (dnas, rnas, ligands):
            used_ids.extend(self._explicit_chain_ids(chains or []))
        id_gen = id_generator(used_ids)

        # build the proteins from msa
        if isinstance(proteins, MSAFuture):
            align_api = getattr(self.session, "align", None)
            assert isinstance(align_api, AlignAPI)
            msa = proteins  # rename
            proteins = []  # convert back to list of proteins
            seed = align_api.get_seed(job_id=msa.job.job_id)
            # the seed holds the query sequences separated by ':'; count repeats
            # so that identical sequences become one entry with several chain ids
            query_seqs_cardinality: dict[str, int] = {}
            for seq in seed.split(":"):
                query_seqs_cardinality[seq] = query_seqs_cardinality.get(seq, 0) + 1
            for seq, card in query_seqs_cardinality.items():
                protein = Protein(sequence=seq)
                if card == 1:
                    chain_id = next(id_gen)
                else:
                    chain_id = [next(id_gen) for _ in range(card)]
                protein.chain_id = chain_id
                protein.msa = msa
                proteins.append(protein)

        # build the sequences input
        sequences: list[dict[str, Any]] = []
        for protein in proteins or []:
            # check the msa
            msa = protein.msa
            if msa is None:
                raise ValueError(
                    "Expected all protein sequences to have `.msa` set with an `MSAFuture` or `Protein.single_sequence_mode` for single sequence mode."
                )
            # convert to msa id or null for single sequence mode:
            # a str is used as-is, an MSAFuture contributes its id, anything else maps to None
            if isinstance(msa, str):
                msa_id = msa
            elif isinstance(msa, MSAFuture):
                msa_id = msa.id
            else:
                msa_id = None
            # add the protein in the expected boltz format
            protein_entry: dict[str, Any] = {
                "id": protein.chain_id or next(id_gen),
                "msa_id": msa_id,
                "sequence": protein.sequence.decode(),
            }
            if protein.cyclic:
                protein_entry["cyclic"] = protein.cyclic
            sequences.append({"protein": protein_entry})

        # DNA and RNA entries share the same payload shape; DNA is emitted first
        for kind, chains in (("dna", dnas), ("rna", rnas)):
            for chain in chains or []:
                chain_entry: dict[str, Any] = {
                    "id": chain.chain_id or next(id_gen),
                    "sequence": chain.sequence,
                }
                if chain.cyclic:
                    chain_entry["cyclic"] = chain.cyclic
                sequences.append({kind: chain_entry})

        for ligand in ligands or []:
            ligand_entry: dict[str, Any] = {"id": ligand.chain_id or next(id_gen)}
            if ligand.ccd:
                ligand_entry["ccd"] = ligand.ccd
            if ligand.smiles:
                ligand_entry["smiles"] = ligand.smiles
            sequences.append({"ligand": ligand_entry})

        if not sequences:
            raise ValueError("Expected proteins, dna, rna or ligands")

        job = api.fold_models_post(
            session=self.session,
            model_id=self.model_id,
            sequences=sequences,
            diffusion_samples=diffusion_samples,
            recycling_steps=recycling_steps,
            sampling_steps=sampling_steps,
            step_scale=step_scale,
            constraints=constraints,
            use_potentials=use_potentials,
            **kwargs,
        )
        return FoldComplexResultFuture.create(
            session=self.session,
            job=job,
            model_id=self.model_id,
            proteins=proteins,
            dnas=dnas,
            rnas=rnas,
            ligands=ligands,
        )
263
+
264
+
265
class Boltz2Model(BoltzModel, FoldModel):
    """
    Class providing inference endpoints for Boltz-2 structure prediction model which jointly models complex structures and binding affinities.
    """

    model_id = "boltz-2"

    def fold(
        self,
        proteins: list[Protein] | MSAFuture | None = None,
        dnas: list[DNA] | None = None,
        rnas: list[RNA] | None = None,
        ligands: list[Ligand] | None = None,
        diffusion_samples: int = 1,
        recycling_steps: int = 3,
        sampling_steps: int = 200,
        step_scale: float = 1.638,
        use_potentials: bool = False,
        constraints: list[dict] | None = None,
        templates: list[dict] | None = None,
        properties: list[dict] | None = None,
        method: str | None = None,
    ) -> FoldComplexResultFuture:
        """
        Request structure prediction with Boltz-2 model.

        Parameters
        ----------
        proteins : list[Protein] | MSAFuture | None
            List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer.
        dnas : list[DNA] | None
            List of DNA sequences to include in folded output.
        rnas : list[RNA] | None
            List of RNA sequences to include in folded output.
        ligands : list[Ligand] | None
            List of ligands to include in folded output.
        diffusion_samples : int
            Number of diffusion samples to use.
        recycling_steps : int
            Number of recycling steps to use.
        sampling_steps : int
            Number of sampling steps to use.
        step_scale : float
            Scaling factor for diffusion steps.
        use_potentials : bool
            Whether or not to use potentials.
        constraints : list[dict] | None
            List of constraints.
        templates : list[dict] | None
            List of templates to use for structure prediction. Not yet
            supported: any value other than None raises `ValueError`.
        properties : list[dict] | None
            List of additional properties to predict. Each entry should match
            `BoltzProperty`; an affinity binder must be the chain ID of a
            ligand that has a single (str) chain ID.
        method : str | None
            The experimental method or supervision source used for the prediction. Defaults to None.
            Supported values (case-insensitive) include:
            'MD', 'X-RAY DIFFRACTION', 'ELECTRON MICROSCOPY', 'SOLUTION NMR',
            'SOLID-STATE NMR', 'NEUTRON DIFFRACTION', 'ELECTRON CRYSTALLOGRAPHY',
            'FIBER DIFFRACTION', 'POWDER DIFFRACTION', 'INFRARED SPECTROSCOPY',
            'FLUORESCENCE TRANSFER', 'EPR', 'THEORETICAL MODEL',
            'SOLUTION SCATTERING', 'OTHER', 'AFDB', 'BOLTZ-1'.
            View the documentation on Boltz for upstream details.

        Returns
        -------
        FoldComplexResultFuture
            Future for the folding result.

        Raises
        ------
        ValueError
            If `templates` is given, if properties are requested while any
            ligand carries a list of chain IDs, or if an affinity binder does
            not name a ligand chain ID.
        """
        if templates is not None:
            raise ValueError("`templates` not yet supported!")

        if properties is not None:
            # validate properties against the schema before inspecting them
            validated = TypeAdapter(list[BoltzProperty]).validate_python(properties)

            # gather ligand chain ids; a ligand with a list of ids is rejected outright
            single_ligand_ids = set()
            for ligand in ligands or []:
                if isinstance(ligand.chain_id, list):
                    raise ValueError(
                        f"Ligand {ligand} has multiple chain_ids ({ligand.chain_id}); only single (str) chain_id allowed for affinity."
                    )
                if isinstance(ligand.chain_id, str):
                    single_ligand_ids.add(ligand.chain_id)

            # every affinity request must point at one of those ligand ids
            for prop in validated:
                affinity = getattr(prop, "affinity", None)
                if affinity is None:
                    continue
                binder_id = affinity.binder
                if binder_id not in single_ligand_ids:
                    raise ValueError(
                        f"Affinity property binder '{binder_id}' does not match any ligand chain_id (must be a ligand with a single chain_id)."
                    )

        # templates/properties/method travel through the parent's **kwargs
        request = dict(
            proteins=proteins,
            dnas=dnas,
            rnas=rnas,
            ligands=ligands,
            diffusion_samples=diffusion_samples,
            recycling_steps=recycling_steps,
            sampling_steps=sampling_steps,
            step_scale=step_scale,
            use_potentials=use_potentials,
            constraints=constraints,
            templates=templates,
            properties=properties,
            method=method,
        )
        return super().fold(**request)
372
+
373
+
374
class Boltz1xModel(BoltzModel, FoldModel):
    """
    Class providing inference endpoints for Boltz-1x open-source structure prediction model, which adds the use of inference potentials to improve performance.
    """

    model_id = "boltz-1x"

    def fold(
        self,
        proteins: list[Protein] | MSAFuture | None = None,
        dnas: list[DNA] | None = None,
        rnas: list[RNA] | None = None,
        ligands: list[Ligand] | None = None,
        diffusion_samples: int = 1,
        recycling_steps: int = 3,
        sampling_steps: int = 200,
        step_scale: float = 1.638,
        constraints: list[dict] | None = None,
    ) -> FoldComplexResultFuture:
        """
        Request structure prediction with Boltz-1x model. Uses potentials with Boltz-1 model.

        Unlike the parent `fold`, there is no `use_potentials` argument: it is
        always sent as True.

        Parameters
        ----------
        proteins : list[Protein] | MSAFuture | None
            List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer.
        dnas : list[DNA] | None
            List of DNA sequences to include in folded output.
        rnas : list[RNA] | None
            List of RNA sequences to include in folded output.
        ligands : list[Ligand] | None
            List of ligands to include in folded output.
        diffusion_samples : int
            Number of diffusion samples to use.
        recycling_steps : int
            Number of recycling steps to use.
        sampling_steps : int
            Number of sampling steps to use.
        step_scale : float
            Scaling factor for diffusion steps.
        constraints : list[dict] | None
            List of constraints.

        Returns
        -------
        FoldComplexResultFuture
            Future for the folding complex result.
        """
        request = dict(
            proteins=proteins,
            dnas=dnas,
            rnas=rnas,
            ligands=ligands,
            diffusion_samples=diffusion_samples,
            recycling_steps=recycling_steps,
            sampling_steps=sampling_steps,
            step_scale=step_scale,
            # Boltz-1x always runs with potentials enabled
            use_potentials=True,
            constraints=constraints,
        )
        return super().fold(**request)
435
+
436
+
437
class Boltz1Model(BoltzModel, FoldModel):
    """
    Class providing inference endpoints for Boltz-1 open-source structure prediction model.
    """

    model_id = "boltz-1"

    def fold(
        self,
        proteins: list[Protein] | MSAFuture | None = None,
        dnas: list[DNA] | None = None,
        rnas: list[RNA] | None = None,
        ligands: list[Ligand] | None = None,
        diffusion_samples: int = 1,
        recycling_steps: int = 3,
        sampling_steps: int = 200,
        step_scale: float = 1.638,
        use_potentials: bool = False,
        constraints: list[dict] | None = None,
    ) -> FoldComplexResultFuture:
        """
        Request structure prediction with Boltz-1 model.

        All arguments are handed to the parent `fold` unchanged; this override
        only narrows the signature (no extra keyword arguments are accepted).

        Parameters
        ----------
        proteins : list[Protein] | MSAFuture | None
            List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer.
        dnas : list[DNA] | None
            List of DNA sequences to include in folded output.
        rnas : list[RNA] | None
            List of RNA sequences to include in folded output.
        ligands : list[Ligand] | None
            List of ligands to include in folded output.
        diffusion_samples : int
            Number of diffusion samples to use.
        recycling_steps : int
            Number of recycling steps to use.
        sampling_steps : int
            Number of sampling steps to use.
        step_scale : float
            Scaling factor for diffusion steps.
        use_potentials : bool
            Whether or not to use potentials.
        constraints : list[dict] | None
            List of constraints.

        Returns
        -------
        FoldComplexResultFuture
            Future for the folding complex result.
        """
        request = dict(
            proteins=proteins,
            dnas=dnas,
            rnas=rnas,
            ligands=ligands,
            diffusion_samples=diffusion_samples,
            recycling_steps=recycling_steps,
            sampling_steps=sampling_steps,
            step_scale=step_scale,
            use_potentials=use_potentials,
            constraints=constraints,
        )
        return super().fold(**request)
501
+
502
+
503
class BondConstraint(BaseModel):
    """
    Constraint specifying a covalent bond between two atoms.

    Both fields are required. The model only checks that each atom is a list
    whose elements validate as `str` or `int`; the three-element layout
    described below is a convention that is not enforced here.

    Attributes
    ----------
    atom1 : list of (str or int)
        The first atom, specified as [CHAIN_ID, RES_IDX, ATOM_NAME].
    atom2 : list of (str or int)
        The second atom, specified as [CHAIN_ID, RES_IDX, ATOM_NAME].
    """

    # First bonded atom: [CHAIN_ID, RES_IDX, ATOM_NAME].
    atom1: list[str | int]
    # Second bonded atom: [CHAIN_ID, RES_IDX, ATOM_NAME].
    atom2: list[str | int]
517
+
518
+
519
class PocketConstraint(BaseModel):
    """
    Constraint specifying a ligand pocket.

    All three fields are required. The inner layout of each contact is a
    convention; the model only checks that contacts are lists whose elements
    validate as `str` or `int`.

    Attributes
    ----------
    binder : str
        The chain ID of the binder.
    contacts : list of list of (str or int)
        List of contacts, each specified as [CHAIN_ID, RES_IDX/ATOM_NAME].
    max_distance : float
        Maximum distance in angstroms for the pocket constraint.
    """

    # Chain ID of the binder.
    binder: str
    # One inner list per contact: [CHAIN_ID, RES_IDX/ATOM_NAME].
    contacts: list[list[str | int]]
    # Maximum distance, in angstroms.
    max_distance: float
536
+
537
+
538
class ContactConstraint(BaseModel):
    """
    Constraint specifying a contact between two tokens.

    All three fields are required. The two-element token layout described
    below is a convention; the model only checks that each token is a list
    whose elements validate as `str` or `int`.

    Attributes
    ----------
    token1 : list of (str or int)
        The first token, specified as [CHAIN_ID, RES_IDX/ATOM_NAME].
    token2 : list of (str or int)
        The second token, specified as [CHAIN_ID, RES_IDX/ATOM_NAME].
    max_distance : float
        Maximum distance in angstroms for the contact constraint.
    """

    # First token: [CHAIN_ID, RES_IDX/ATOM_NAME].
    token1: list[str | int]
    # Second token: [CHAIN_ID, RES_IDX/ATOM_NAME].
    token2: list[str | int]
    # Maximum distance, in angstroms.
    max_distance: float
555
+
556
+
557
class BoltzConstraint(BaseModel):
    """
    Possible constraints for Boltz.

    A single constraint entry wraps exactly one of the three constraint
    kinds; validation fails if none or more than one is supplied.

    Attributes
    ----------
    bond : BondConstraint or None, optional
        Covalent bond constraint.
    pocket : PocketConstraint or None, optional
        Pocket constraint.
    contact : ContactConstraint or None, optional
        Contact constraint.
    """

    bond: BondConstraint | None = None
    pocket: PocketConstraint | None = None
    contact: ContactConstraint | None = None

    @model_validator(mode="after")
    def check_exactly_one(self) -> "BoltzConstraint":
        """
        Ensure exactly one of `bond`, `pocket` or `contact` is set.

        Declared as an instance method, which is the documented form for
        `mode="after"` model validators (they receive the validated instance).

        Returns
        -------
        BoltzConstraint
            The validated instance, unchanged.

        Raises
        ------
        ValueError
            If zero or more than one constraint kind is set.
        """
        provided = [
            constraint
            for constraint in (self.bond, self.pocket, self.contact)
            if constraint is not None
        ]
        if len(provided) != 1:
            raise ValueError(
                "Exactly one of 'bond', 'pocket', or 'contact' must be set."
            )
        return self
583
+
584
+
585
class AffinityProperty(BaseModel):
    """
    Property specifying affinity computation.

    Attributes
    ----------
    binder : str
        The chain ID of the ligand for which to compute affinity.
    """

    # Required: chain ID of the ligand whose affinity is requested.
    binder: str
596
+
597
+
598
class BoltzProperty(BaseModel):
    """
    Properties (additionally) requested for computation.

    Currently `affinity` is the only property and it is required, so every
    entry must carry an affinity specification.

    Attributes
    ----------
    affinity : AffinityProperty
        Affinity property specification.
    """

    # TODO handle more than one property
    affinity: AffinityProperty
610
+
611
+
612
class BoltzConfidence(BaseModel):
    """
    Model representing the aggregated confidence scores for a prediction sample.

    Every field is required; none has a default value.

    Attributes
    ----------
    confidence_score : float
        Aggregated score used to sort the predictions, corresponds to
        0.8 * complex_plddt + 0.2 * iptm (ptm for single chains).
    ptm : float
        Predicted TM score for the complex.
    iptm : float
        Predicted TM score when aggregating at the interfaces.
    ligand_iptm : float
        ipTM but only aggregating at protein-ligand interfaces.
    protein_iptm : float
        ipTM but only aggregating at protein-protein interfaces.
    complex_plddt : float
        Average pLDDT score for the complex.
    complex_iplddt : float
        Average pLDDT score when upweighting interface tokens.
    complex_pde : float
        Average PDE score for the complex.
    complex_ipde : float
        Average PDE score when aggregating at interfaces.
    chains_ptm : dict[str, float]
        Predicted TM score within each chain, keyed by chain index as a string.
    pair_chains_iptm : dict[str, dict[str, float]]
        Predicted (interface) TM score between each pair of chains,
        keyed by chain indices as strings.
    """

    # --- complex-level scalar scores ---
    confidence_score: float
    ptm: float
    iptm: float
    ligand_iptm: float
    protein_iptm: float
    complex_plddt: float
    complex_iplddt: float
    complex_pde: float
    complex_ipde: float
    # --- per-chain and per-chain-pair scores, keyed by chain index strings ---
    chains_ptm: dict[str, float]
    pair_chains_iptm: dict[str, dict[str, float]]
655
+
656
+
657
class BoltzAffinity(BaseModel):
    """
    Output schema for Boltz affinity ensemble predictions.

    Attributes
    ----------
    affinity_pred_value : float
        Predicted binding affinity from the ensemble model.
    affinity_probability_binary : float
        Predicted binding likelihood from the ensemble model.
    per_model : dict of str to float
        Dictionary containing predictions from each individual model in the ensemble.
        Keys are of the form 'affinity_pred_valueN' and 'affinity_probability_binaryN',
        where N is the model index (e.g., 1, 2, 3, ...).

    Notes
    -----
    Use the `parse_obj_with_models` class method to construct this object from a raw output
    dictionary, which will automatically separate ensemble-level and per-model predictions.
    """

    affinity_pred_value: float
    affinity_probability_binary: float
    # Catch all other per-model fields
    per_model: dict[str, float] = Field(default_factory=dict)

    @classmethod
    def parse_obj_with_models(cls, obj: dict) -> "BoltzAffinity":
        """
        Build a `BoltzAffinity` from a raw output dictionary.

        The two ensemble-level keys are lifted into their own fields and every
        remaining key/value pair is placed in `per_model`.

        Parameters
        ----------
        obj : dict
            Raw output mapping. It is not modified; the split is performed on
            a shallow copy.

        Returns
        -------
        BoltzAffinity
            The parsed affinity result.

        Raises
        ------
        KeyError
            If `affinity_pred_value` or `affinity_probability_binary` is missing.
        """
        # Work on a copy so the caller's dictionary keeps all of its keys.
        remaining = dict(obj)
        # Extract fixed fields
        fixed = {
            "affinity_pred_value": remaining.pop("affinity_pred_value"),
            "affinity_probability_binary": remaining.pop(
                "affinity_probability_binary"
            ),
        }
        # Everything else goes into per_model
        return cls(**fixed, per_model=remaining)