openprotein-python 0.8.2__1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openprotein/__init__.py +164 -0
- openprotein/_version.py +48 -0
- openprotein/align/__init__.py +8 -0
- openprotein/align/align.py +395 -0
- openprotein/align/api.py +428 -0
- openprotein/align/future.py +55 -0
- openprotein/align/msa.py +129 -0
- openprotein/align/schemas.py +165 -0
- openprotein/base.py +181 -0
- openprotein/chains.py +88 -0
- openprotein/common/__init__.py +5 -0
- openprotein/common/features.py +7 -0
- openprotein/common/model_metadata.py +33 -0
- openprotein/common/reduction.py +8 -0
- openprotein/config.py +9 -0
- openprotein/csv.py +31 -0
- openprotein/data/__init__.py +9 -0
- openprotein/data/api.py +218 -0
- openprotein/data/assaydataset.py +178 -0
- openprotein/data/data.py +93 -0
- openprotein/data/schemas.py +27 -0
- openprotein/design/__init__.py +16 -0
- openprotein/design/api.py +259 -0
- openprotein/design/design.py +125 -0
- openprotein/design/future.py +146 -0
- openprotein/design/schemas.py +607 -0
- openprotein/embeddings/__init__.py +27 -0
- openprotein/embeddings/api.py +619 -0
- openprotein/embeddings/embeddings.py +151 -0
- openprotein/embeddings/esm.py +33 -0
- openprotein/embeddings/future.py +146 -0
- openprotein/embeddings/models.py +421 -0
- openprotein/embeddings/openprotein.py +21 -0
- openprotein/embeddings/poet.py +446 -0
- openprotein/embeddings/poet2.py +505 -0
- openprotein/embeddings/schemas.py +78 -0
- openprotein/errors.py +76 -0
- openprotein/fasta.py +92 -0
- openprotein/fold/__init__.py +21 -0
- openprotein/fold/alphafold2.py +131 -0
- openprotein/fold/api.py +287 -0
- openprotein/fold/boltz.py +691 -0
- openprotein/fold/esmfold.py +54 -0
- openprotein/fold/fold.py +107 -0
- openprotein/fold/future.py +509 -0
- openprotein/fold/models.py +139 -0
- openprotein/fold/schemas.py +39 -0
- openprotein/jobs/__init__.py +9 -0
- openprotein/jobs/api.py +71 -0
- openprotein/jobs/futures.py +746 -0
- openprotein/jobs/jobs.py +69 -0
- openprotein/jobs/schemas.py +135 -0
- openprotein/models/__init__.py +4 -0
- openprotein/models/base.py +63 -0
- openprotein/models/foundation/rfdiffusion.py +283 -0
- openprotein/models/models.py +33 -0
- openprotein/predictor/__init__.py +25 -0
- openprotein/predictor/api.py +384 -0
- openprotein/predictor/models.py +374 -0
- openprotein/predictor/prediction.py +79 -0
- openprotein/predictor/predictor.py +242 -0
- openprotein/predictor/schemas.py +113 -0
- openprotein/predictor/validate.py +40 -0
- openprotein/prompt/__init__.py +9 -0
- openprotein/prompt/api.py +505 -0
- openprotein/prompt/models.py +142 -0
- openprotein/prompt/prompt.py +130 -0
- openprotein/prompt/schemas.py +49 -0
- openprotein/protein.py +587 -0
- openprotein/svd/__init__.py +9 -0
- openprotein/svd/api.py +206 -0
- openprotein/svd/models.py +288 -0
- openprotein/svd/schemas.py +31 -0
- openprotein/svd/svd.py +134 -0
- openprotein/umap/__init__.py +9 -0
- openprotein/umap/api.py +259 -0
- openprotein/umap/models.py +211 -0
- openprotein/umap/schemas.py +35 -0
- openprotein/umap/umap.py +175 -0
- openprotein/utils/uuid.py +29 -0
- openprotein_python-0.8.2.dist-info/METADATA +176 -0
- openprotein_python-0.8.2.dist-info/RECORD +84 -0
- openprotein_python-0.8.2.dist-info/WHEEL +4 -0
- openprotein_python-0.8.2.dist-info/licenses/LICENSE.txt +30 -0
|
@@ -0,0 +1,691 @@
|
|
|
1
|
+
"""Community-based Boltz models for complex structure prediction with ligands/dna/rna."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
import string
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, Field, TypeAdapter, model_validator
|
|
8
|
+
|
|
9
|
+
from openprotein.align import AlignAPI, MSAFuture
|
|
10
|
+
from openprotein.base import APISession
|
|
11
|
+
from openprotein.chains import DNA, RNA, Ligand
|
|
12
|
+
from openprotein.common import ModelMetadata
|
|
13
|
+
from openprotein.protein import Protein
|
|
14
|
+
|
|
15
|
+
from . import api
|
|
16
|
+
from .future import FoldComplexResultFuture
|
|
17
|
+
from .models import FoldModel
|
|
18
|
+
|
|
19
|
+
# A chain ID is either 1-5 uppercase letters A-Z or 1-5 digits (`\d`).
valid_id_pattern = re.compile(r"^[A-Z]{1,5}$|^\d{1,5}$")


def is_valid_id(id_str: str) -> bool:
    """
    Report whether `id_str` is an acceptable chain ID.

    An ID is accepted when it is non-empty, at most five characters long, and
    consists entirely of uppercase letters or entirely of digits.

    Parameters
    ----------
    id_str : str
        Candidate chain ID.

    Returns
    -------
    bool
        True if the whole string matches `valid_id_pattern`, False otherwise
        (including for empty input).
    """
    # Cheap length gate first; the regex then decides on the character classes.
    if id_str and len(id_str) <= 5:
        return valid_id_pattern.fullmatch(id_str) is not None
    return False
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def id_generator(used_ids: list[str] | None = None, max_alpha_len=5, max_numeric=99999):
|
|
32
|
+
"""
|
|
33
|
+
Yields new chain IDs, skipping any in 'used_ids'.
|
|
34
|
+
First A..Z, AA..ZZ, … up to max_alpha_len, then '1','2',… up to max_numeric.
|
|
35
|
+
"""
|
|
36
|
+
used = set(tuple(used_ids or []))
|
|
37
|
+
letters = list(string.ascii_uppercase)
|
|
38
|
+
|
|
39
|
+
# --- Alphabetic IDs ---
|
|
40
|
+
curr_len = 1
|
|
41
|
+
curr_indices = [0] * curr_len # start at 'A'
|
|
42
|
+
|
|
43
|
+
def bump_indices():
|
|
44
|
+
# lexicographically increment curr_indices; return False on overflow
|
|
45
|
+
for i in reversed(range(len(curr_indices))):
|
|
46
|
+
if curr_indices[i] < len(letters) - 1:
|
|
47
|
+
curr_indices[i] += 1
|
|
48
|
+
for j in range(i + 1, len(curr_indices)):
|
|
49
|
+
curr_indices[j] = 0
|
|
50
|
+
return True
|
|
51
|
+
return False
|
|
52
|
+
|
|
53
|
+
while curr_len <= max_alpha_len:
|
|
54
|
+
candidate = "".join(letters[i] for i in curr_indices)
|
|
55
|
+
if candidate not in used:
|
|
56
|
+
used.add(candidate)
|
|
57
|
+
yield candidate
|
|
58
|
+
# bump
|
|
59
|
+
if not bump_indices():
|
|
60
|
+
curr_len += 1
|
|
61
|
+
if curr_len > max_alpha_len:
|
|
62
|
+
break
|
|
63
|
+
curr_indices = [0] * curr_len
|
|
64
|
+
|
|
65
|
+
# --- Numeric IDs ---
|
|
66
|
+
num = 1
|
|
67
|
+
while num <= max_numeric:
|
|
68
|
+
candidate = str(num)
|
|
69
|
+
num += 1
|
|
70
|
+
if candidate not in used:
|
|
71
|
+
used.add(candidate)
|
|
72
|
+
yield candidate
|
|
73
|
+
|
|
74
|
+
# exhausted
|
|
75
|
+
raise RuntimeError("exhausted all possible IDs")
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class BoltzModel(FoldModel):
    """
    Class providing inference endpoints for Boltz structure prediction models.
    """

    model_id: str = "boltz"

    def __init__(
        self,
        session: APISession,
        model_id: str,
        metadata: ModelMetadata | None = None,
    ):
        super().__init__(session, model_id, metadata)

    def fold(
        self,
        proteins: list[Protein] | MSAFuture | None = None,
        dnas: list[DNA] | None = None,
        rnas: list[RNA] | None = None,
        ligands: list[Ligand] | None = None,
        diffusion_samples: int = 1,
        recycling_steps: int = 3,
        sampling_steps: int = 200,
        step_scale: float = 1.638,
        use_potentials: bool = False,
        constraints: list[dict] | None = None,
        **kwargs,
    ) -> FoldComplexResultFuture:
        """
        Request structure prediction with boltz model.

        Chain IDs already set on the inputs are kept. Any chain without one is
        given a fresh ID from `id_generator` in the request payload.

        Parameters
        ----------
        proteins : List[Protein] | MSAFuture | None
            List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer.
        dnas : List[DNA] | None
            List of DNA sequences to include in folded output.
        rnas : List[RNA] | None
            List of RNA sequences to include in folded output.
        ligands : List[Ligand] | None
            List of ligands to include in folded output.
        diffusion_samples : int
            Number of diffusion samples to use
        recycling_steps : int
            Number of recycling steps to use
        sampling_steps : int
            Number of sampling steps to use
        step_scale : float
            Scaling factor for diffusion steps.
        use_potentials : bool
            Whether or not to use potentials.
        constraints : Optional[List[dict]]
            List of constraints. Each entry is validated against `BoltzConstraint`.
        **kwargs
            Extra keyword arguments forwarded unchanged to the fold request.

        Returns
        -------
        FoldComplexResultFuture
            Future for the folding complex result.

        Raises
        ------
        ValueError
            If a protein has no `msa` set, or if no proteins, dnas, rnas or
            ligands were supplied.
        """
        # validate constraints (the raw dicts, not the parsed models, are sent)
        if constraints is not None:
            TypeAdapter(list[BoltzConstraint]).validate_python(constraints)

        def explicit_ids(chain_id) -> list[str]:
            # Normalize a chain_id (str, list of str, or unset) to a list.
            if isinstance(chain_id, str):
                return [chain_id]
            if isinstance(chain_id, list):
                return chain_id
            return []

        # collate the id's used so generated ids never collide with them
        used_ids: list[str] = []
        if isinstance(proteins, list):
            for protein in proteins:
                if isinstance(protein, Protein):
                    used_ids.extend(explicit_ids(protein.chain_id))
        for group in (dnas, rnas, ligands):
            for chain in group or []:
                used_ids.extend(explicit_ids(chain.chain_id))
        id_gen = id_generator(used_ids)

        # build the proteins from msa
        if isinstance(proteins, MSAFuture):
            align_api = getattr(self.session, "align", None)
            assert isinstance(align_api, AlignAPI)
            msa_future = proteins
            proteins = []  # convert back to list of proteins
            seed = align_api.get_seed(job_id=msa_future.job.job_id)
            # The seed holds the query sequences joined by ':'; count repeats
            # so identical sequences become one Protein with several chain ids.
            copies: dict[str, int] = {}
            for seq in seed.split(":"):
                copies[seq] = copies.get(seq, 0) + 1
            for seq, count in copies.items():
                protein = Protein(sequence=seq)
                if count == 1:
                    protein.chain_id = next(id_gen)
                else:
                    protein.chain_id = [next(id_gen) for _ in range(count)]
                protein.msa = msa_future
                proteins.append(protein)

        # build the sequences input
        sequences: list[dict[str, Any]] = []
        for protein in proteins or []:
            # check the msa
            msa = protein.msa
            if msa is None:
                raise ValueError(
                    "Expected all protein sequences to have `.msa` set with an `MSAFuture` or `Protein.single_sequence_mode` for single sequence mode."
                )
            # convert to msa id or null for single sequence mode
            if isinstance(msa, str):
                msa_id = msa
            elif isinstance(msa, MSAFuture):
                msa_id = msa.id
            else:
                msa_id = None
            # add the protein in the expected boltz format
            entry: dict[str, Any] = {
                "id": protein.chain_id or next(id_gen),
                "msa_id": msa_id,
                "sequence": protein.sequence.decode(),
            }
            if protein.cyclic:
                entry["cyclic"] = protein.cyclic
            sequences.append({"protein": entry})

        # dna and rna entries share the same payload shape
        for kind, chains in (("dna", dnas), ("rna", rnas)):
            for chain in chains or []:
                entry = {
                    "id": chain.chain_id or next(id_gen),
                    "sequence": chain.sequence,
                }
                if chain.cyclic:
                    entry["cyclic"] = chain.cyclic
                sequences.append({kind: entry})

        for ligand in ligands or []:
            ligand_: dict = {"id": ligand.chain_id or next(id_gen)}
            if ligand.ccd:
                ligand_["ccd"] = ligand.ccd
            if ligand.smiles:
                ligand_["smiles"] = ligand.smiles
            sequences.append({"ligand": ligand_})

        if not sequences:
            raise ValueError("Expected proteins, dna, rna or ligands")

        return FoldComplexResultFuture.create(
            session=self.session,
            job=api.fold_models_post(
                session=self.session,
                model_id=self.model_id,
                sequences=sequences,
                diffusion_samples=diffusion_samples,
                recycling_steps=recycling_steps,
                sampling_steps=sampling_steps,
                step_scale=step_scale,
                constraints=constraints,
                use_potentials=use_potentials,
                **kwargs,
            ),
            model_id=self.model_id,
            proteins=proteins,
            dnas=dnas,
            rnas=rnas,
            ligands=ligands,
        )
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
class Boltz2Model(BoltzModel, FoldModel):
    """
    Class providing inference endpoints for Boltz-2 structure prediction model which jointly models complex structures and binding affinities.
    """

    model_id = "boltz-2"

    def fold(
        self,
        proteins: list[Protein] | MSAFuture | None = None,
        dnas: list[DNA] | None = None,
        rnas: list[RNA] | None = None,
        ligands: list[Ligand] | None = None,
        diffusion_samples: int = 1,
        recycling_steps: int = 3,
        sampling_steps: int = 200,
        step_scale: float = 1.638,
        use_potentials: bool = False,
        constraints: list[dict] | None = None,
        templates: list[dict] | None = None,
        properties: list[dict] | None = None,
        method: str | None = None,
    ) -> FoldComplexResultFuture:
        """
        Request structure prediction with Boltz-2 model.

        Parameters
        ----------
        proteins : List[Protein] | MSAFuture | None
            List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer.
        dnas : List[DNA] | None
            List of DNA sequences to include in folded output.
        rnas : List[RNA] | None
            List of RNA sequences to include in folded output.
        ligands : List[Ligand] | None
            List of ligands to include in folded output.
        diffusion_samples : int
            Number of diffusion samples to use
        recycling_steps : int
            Number of recycling steps to use
        sampling_steps : int
            Number of sampling steps to use
        step_scale : float
            Scaling factor for diffusion steps.
        use_potentials : bool
            Whether or not to use potentials.
        constraints : list[dict] | None
            List of constraints.
        templates : list[dict] | None
            List of templates to use for structure prediction. Not yet
            supported: any value other than None is rejected.
        properties : list[dict] | None
            List of additional properties to predict. Each entry is validated
            against `BoltzProperty`; an affinity binder must be the chain_id
            of a ligand that has a single (str) chain_id.
        method : str | None
            The experimental method or supervision source used for the prediction. Defaults to None.
            Supported values (case-insensitive) include:
            'MD', 'X-RAY DIFFRACTION', 'ELECTRON MICROSCOPY', 'SOLUTION NMR',
            'SOLID-STATE NMR', 'NEUTRON DIFFRACTION', 'ELECTRON CRYSTALLOGRAPHY',
            'FIBER DIFFRACTION', 'POWDER DIFFRACTION', 'INFRARED SPECTROSCOPY',
            'FLUORESCENCE TRANSFER', 'EPR', 'THEORETICAL MODEL',
            'SOLUTION SCATTERING', 'OTHER', 'AFDB', 'BOLTZ-1'.
            View the documentation on Boltz for upstream details.

        Returns
        -------
        FoldComplexResultFuture
            Future for the folding result.

        Raises
        ------
        ValueError
            If `templates` is given, if `properties` is given while a ligand
            carries a list of chain_ids, or if an affinity binder does not
            match any single-chain ligand.
        """
        if templates is not None:
            raise ValueError("`templates` not yet supported!")

        # validate properties
        if properties is not None:
            requested = TypeAdapter(list[BoltzProperty]).validate_python(properties)

            # Gather the ids of ligands with exactly one (str) chain_id; a
            # ligand with a list of chain_ids is rejected outright.
            ligand_chain_ids: set[str] = set()
            for ligand in ligands or []:
                if isinstance(ligand.chain_id, list):
                    raise ValueError(
                        f"Ligand {ligand} has multiple chain_ids ({ligand.chain_id}); only single (str) chain_id allowed for affinity."
                    )
                if isinstance(ligand.chain_id, str):
                    ligand_chain_ids.add(ligand.chain_id)

            # Every affinity request must point at one of those ligands.
            for prop in requested:
                affinity = getattr(prop, "affinity", None)
                if affinity is None:
                    continue
                binder_id = affinity.binder
                if binder_id not in ligand_chain_ids:
                    raise ValueError(
                        f"Affinity property binder '{binder_id}' does not match any ligand chain_id (must be a ligand with a single chain_id)."
                    )

        # The Boltz-2-only options travel through BoltzModel.fold's **kwargs.
        return super().fold(
            proteins=proteins,
            dnas=dnas,
            rnas=rnas,
            ligands=ligands,
            diffusion_samples=diffusion_samples,
            recycling_steps=recycling_steps,
            sampling_steps=sampling_steps,
            step_scale=step_scale,
            use_potentials=use_potentials,
            constraints=constraints,
            templates=templates,
            properties=properties,
            method=method,
        )
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
class Boltz1xModel(BoltzModel, FoldModel):
    """
    Class providing inference endpoints for Boltz-1x open-source structure prediction model, which adds the use of inference potentials to improve performance.
    """

    model_id = "boltz-1x"

    def fold(
        self,
        proteins: list[Protein] | MSAFuture | None = None,
        dnas: list[DNA] | None = None,
        rnas: list[RNA] | None = None,
        ligands: list[Ligand] | None = None,
        diffusion_samples: int = 1,
        recycling_steps: int = 3,
        sampling_steps: int = 200,
        step_scale: float = 1.638,
        constraints: list[dict] | None = None,
    ) -> FoldComplexResultFuture:
        """
        Request structure prediction with Boltz-1x model. Uses potentials with Boltz-1 model.

        Parameters
        ----------
        proteins : List[Protein] | MSAFuture | None
            List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer.
        dnas : List[DNA] | None
            List of DNA sequences to include in folded output.
        rnas : List[RNA] | None
            List of RNA sequences to include in folded output.
        ligands : List[Ligand] | None
            List of ligands to include in folded output.
        diffusion_samples : int
            Number of diffusion samples to use
        recycling_steps : int
            Number of recycling steps to use
        sampling_steps : int
            Number of sampling steps to use
        step_scale : float
            Scaling factor for diffusion steps.
        constraints : Optional[List[dict]]
            List of constraints.

        Returns
        -------
        FoldComplexResultFuture
            Future for the folding complex result.
        """
        # Caller-supplied options are forwarded untouched ...
        options = dict(
            proteins=proteins,
            dnas=dnas,
            rnas=rnas,
            ligands=ligands,
            diffusion_samples=diffusion_samples,
            recycling_steps=recycling_steps,
            sampling_steps=sampling_steps,
            step_scale=step_scale,
            constraints=constraints,
        )
        # ... while potentials are always switched on for this model.
        return super().fold(use_potentials=True, **options)
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
class Boltz1Model(BoltzModel, FoldModel):
    """
    Class providing inference endpoints for Boltz-1 open-source structure prediction model.
    """

    model_id = "boltz-1"

    def fold(
        self,
        proteins: list[Protein] | MSAFuture | None = None,
        dnas: list[DNA] | None = None,
        rnas: list[RNA] | None = None,
        ligands: list[Ligand] | None = None,
        diffusion_samples: int = 1,
        recycling_steps: int = 3,
        sampling_steps: int = 200,
        step_scale: float = 1.638,
        use_potentials: bool = False,
        constraints: list[dict] | None = None,
    ) -> FoldComplexResultFuture:
        """
        Request structure prediction with Boltz-1 model.

        Parameters
        ----------
        proteins : List[Protein] | MSAFuture | None
            List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer.
        dnas : List[DNA] | None
            List of DNA sequences to include in folded output.
        rnas : List[RNA] | None
            List of RNA sequences to include in folded output.
        ligands : List[Ligand] | None
            List of ligands to include in folded output.
        diffusion_samples : int
            Number of diffusion samples to use
        recycling_steps : int
            Number of recycling steps to use
        sampling_steps : int
            Number of sampling steps to use
        step_scale : float
            Scaling factor for diffusion steps.
        use_potentials : bool
            Whether or not to use potentials.
        constraints : Optional[List[dict]]
            List of constraints.

        Returns
        -------
        FoldComplexResultFuture
            Future for the folding complex result.
        """
        # Every option is handed to BoltzModel.fold unchanged.
        options = dict(
            proteins=proteins,
            dnas=dnas,
            rnas=rnas,
            ligands=ligands,
            diffusion_samples=diffusion_samples,
            recycling_steps=recycling_steps,
            sampling_steps=sampling_steps,
            step_scale=step_scale,
            use_potentials=use_potentials,
            constraints=constraints,
        )
        return super().fold(**options)
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
class BondConstraint(BaseModel):
    """
    Constraint specifying a covalent bond between two atoms.

    Attributes
    ----------
    atom1 : list of (str or int)
        The first atom, specified as [CHAIN_ID, RES_IDX, ATOM_NAME].
    atom2 : list of (str or int)
        The second atom, specified as [CHAIN_ID, RES_IDX, ATOM_NAME].
    """

    # Both atoms are typed as free-form lists of str/int; the
    # [CHAIN_ID, RES_IDX, ATOM_NAME] layout is not enforced by this schema.
    atom1: list[str | int]
    atom2: list[str | int]
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
class PocketConstraint(BaseModel):
    """
    Constraint specifying a ligand pocket.

    Attributes
    ----------
    binder : str
        The chain ID of the binder.
    contacts : list of list of (str or int)
        List of contacts, each specified as [CHAIN_ID, RES_IDX/ATOM_NAME].
    max_distance : float
        Maximum distance in angstroms for the pocket constraint.
    """

    # All three fields are required (no defaults are declared).
    binder: str
    # Each contact is a free-form list of str/int; its layout is not enforced here.
    contacts: list[list[str | int]]
    max_distance: float
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
class ContactConstraint(BaseModel):
    """
    Constraint specifying a contact between two tokens.

    Attributes
    ----------
    token1 : list of (str or int)
        The first token, specified as [CHAIN_ID, RES_IDX/ATOM_NAME].
    token2 : list of (str or int)
        The second token, specified as [CHAIN_ID, RES_IDX/ATOM_NAME].
    max_distance : float
        Maximum distance in angstroms for the contact constraint.
    """

    # Tokens are free-form lists of str/int; the [CHAIN_ID, RES_IDX/ATOM_NAME]
    # layout is not enforced by this schema.
    token1: list[str | int]
    token2: list[str | int]
    max_distance: float
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
class BoltzConstraint(BaseModel):
    """
    Possible constraints for Boltz.

    Exactly one of the three fields must be provided; validation fails
    otherwise.

    Attributes
    ----------
    bond : BondConstraint or None, optional
        Covalent bond constraint.
    pocket : PocketConstraint or None, optional
        Pocket constraint.
    contact : ContactConstraint or None, optional
        Contact constraint.
    """

    bond: BondConstraint | None = None
    pocket: PocketConstraint | None = None
    contact: ContactConstraint | None = None

    # "after" model validators receive the already-built model, so this is
    # written as a plain instance method rather than a `(cls, self)` function.
    @model_validator(mode="after")
    def check_exactly_one(self):
        """Reject constraints that set none, or more than one, of the fields."""
        provided = [self.bond, self.pocket, self.contact]
        if sum(x is not None for x in provided) != 1:
            raise ValueError(
                "Exactly one of 'bond', 'pocket', or 'contact' must be set."
            )
        return self
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
class AffinityProperty(BaseModel):
    """
    Property specifying affinity computation.

    Attributes
    ----------
    binder : str
        The chain ID of the ligand for which to compute affinity.
    """

    # Required; a single chain ID string (not a list of IDs).
    binder: str
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
class BoltzProperty(BaseModel):
    """
    Properties (additionally) requested for computation.

    Attributes
    ----------
    affinity : AffinityProperty
        Affinity property specification.
    """

    # TODO handle more than one property
    # `affinity` is currently the only field and it is required.
    affinity: AffinityProperty
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
class BoltzConfidence(BaseModel):
    """
    Model representing the aggregated confidence scores for a prediction sample.

    Attributes
    ----------
    confidence_score : float
        Aggregated score used to sort the predictions, corresponds to
        0.8 * complex_plddt + 0.2 * iptm (ptm for single chains).
    ptm : float
        Predicted TM score for the complex.
    iptm : float
        Predicted TM score when aggregating at the interfaces.
    ligand_iptm : float
        ipTM but only aggregating at protein-ligand interfaces.
    protein_iptm : float
        ipTM but only aggregating at protein-protein interfaces.
    complex_plddt : float
        Average pLDDT score for the complex.
    complex_iplddt : float
        Average pLDDT score when upweighting interface tokens.
    complex_pde : float
        Average PDE score for the complex.
    complex_ipde : float
        Average PDE score when aggregating at interfaces.
    chains_ptm : dict[str, float]
        Predicted TM score within each chain, keyed by chain index as a string.
    pair_chains_iptm : dict[str, dict[str, float]]
        Predicted (interface) TM score between each pair of chains,
        keyed by chain indices as strings.
    """

    # Complex-level scalar scores; every field is required.
    confidence_score: float
    ptm: float
    iptm: float
    ligand_iptm: float
    protein_iptm: float
    complex_plddt: float
    complex_iplddt: float
    complex_pde: float
    complex_ipde: float
    # Per-chain and per-chain-pair breakdowns.
    chains_ptm: dict[str, float]
    pair_chains_iptm: dict[str, dict[str, float]]
|
|
655
|
+
|
|
656
|
+
|
|
657
|
+
class BoltzAffinity(BaseModel):
    """
    Output schema for Boltz affinity ensemble predictions.

    Attributes
    ----------
    affinity_pred_value : float
        Predicted binding affinity from the ensemble model.
    affinity_probability_binary : float
        Predicted binding likelihood from the ensemble model.
    per_model : dict of str to float
        Dictionary containing predictions from each individual model in the ensemble.
        Keys are of the form 'affinity_pred_valueN' and 'affinity_probability_binaryN',
        where N is the model index (e.g., 1, 2, 3, ...).

    Notes
    -----
    Use the `parse_obj_with_models` class method to construct this object from a raw output
    dictionary, which will automatically separate ensemble-level and per-model predictions.
    """

    affinity_pred_value: float
    affinity_probability_binary: float
    # Catch all other per-model fields
    per_model: dict[str, float] = Field(default_factory=dict)

    @classmethod
    def parse_obj_with_models(cls, obj: dict):
        """
        Build a `BoltzAffinity` from a raw output dictionary.

        Parameters
        ----------
        obj : dict
            Raw output containing 'affinity_pred_value',
            'affinity_probability_binary' and any number of further entries.
            The dictionary is left unmodified.

        Returns
        -------
        BoltzAffinity
            Instance whose `per_model` holds every entry of `obj` other than
            the two ensemble-level fields.

        Raises
        ------
        KeyError
            If either ensemble-level key is missing from `obj`.
        """
        # Work on a shallow copy so the caller's dictionary keeps its keys.
        remaining = dict(obj)
        # Extract fixed fields
        fixed = {
            "affinity_pred_value": remaining.pop("affinity_pred_value"),
            "affinity_probability_binary": remaining.pop(
                "affinity_probability_binary"
            ),
        }
        # Everything else goes into per_model
        return cls(**fixed, per_model=remaining)
|