openprotein-python 0.8.2__1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openprotein/__init__.py +164 -0
- openprotein/_version.py +48 -0
- openprotein/align/__init__.py +8 -0
- openprotein/align/align.py +395 -0
- openprotein/align/api.py +428 -0
- openprotein/align/future.py +55 -0
- openprotein/align/msa.py +129 -0
- openprotein/align/schemas.py +165 -0
- openprotein/base.py +181 -0
- openprotein/chains.py +88 -0
- openprotein/common/__init__.py +5 -0
- openprotein/common/features.py +7 -0
- openprotein/common/model_metadata.py +33 -0
- openprotein/common/reduction.py +8 -0
- openprotein/config.py +9 -0
- openprotein/csv.py +31 -0
- openprotein/data/__init__.py +9 -0
- openprotein/data/api.py +218 -0
- openprotein/data/assaydataset.py +178 -0
- openprotein/data/data.py +93 -0
- openprotein/data/schemas.py +27 -0
- openprotein/design/__init__.py +16 -0
- openprotein/design/api.py +259 -0
- openprotein/design/design.py +125 -0
- openprotein/design/future.py +146 -0
- openprotein/design/schemas.py +607 -0
- openprotein/embeddings/__init__.py +27 -0
- openprotein/embeddings/api.py +619 -0
- openprotein/embeddings/embeddings.py +151 -0
- openprotein/embeddings/esm.py +33 -0
- openprotein/embeddings/future.py +146 -0
- openprotein/embeddings/models.py +421 -0
- openprotein/embeddings/openprotein.py +21 -0
- openprotein/embeddings/poet.py +446 -0
- openprotein/embeddings/poet2.py +505 -0
- openprotein/embeddings/schemas.py +78 -0
- openprotein/errors.py +76 -0
- openprotein/fasta.py +92 -0
- openprotein/fold/__init__.py +21 -0
- openprotein/fold/alphafold2.py +131 -0
- openprotein/fold/api.py +287 -0
- openprotein/fold/boltz.py +691 -0
- openprotein/fold/esmfold.py +54 -0
- openprotein/fold/fold.py +107 -0
- openprotein/fold/future.py +509 -0
- openprotein/fold/models.py +139 -0
- openprotein/fold/schemas.py +39 -0
- openprotein/jobs/__init__.py +9 -0
- openprotein/jobs/api.py +71 -0
- openprotein/jobs/futures.py +746 -0
- openprotein/jobs/jobs.py +69 -0
- openprotein/jobs/schemas.py +135 -0
- openprotein/models/__init__.py +4 -0
- openprotein/models/base.py +63 -0
- openprotein/models/foundation/rfdiffusion.py +283 -0
- openprotein/models/models.py +33 -0
- openprotein/predictor/__init__.py +25 -0
- openprotein/predictor/api.py +384 -0
- openprotein/predictor/models.py +374 -0
- openprotein/predictor/prediction.py +79 -0
- openprotein/predictor/predictor.py +242 -0
- openprotein/predictor/schemas.py +113 -0
- openprotein/predictor/validate.py +40 -0
- openprotein/prompt/__init__.py +9 -0
- openprotein/prompt/api.py +505 -0
- openprotein/prompt/models.py +142 -0
- openprotein/prompt/prompt.py +130 -0
- openprotein/prompt/schemas.py +49 -0
- openprotein/protein.py +587 -0
- openprotein/svd/__init__.py +9 -0
- openprotein/svd/api.py +206 -0
- openprotein/svd/models.py +288 -0
- openprotein/svd/schemas.py +31 -0
- openprotein/svd/svd.py +134 -0
- openprotein/umap/__init__.py +9 -0
- openprotein/umap/api.py +259 -0
- openprotein/umap/models.py +211 -0
- openprotein/umap/schemas.py +35 -0
- openprotein/umap/umap.py +175 -0
- openprotein/utils/uuid.py +29 -0
- openprotein_python-0.8.2.dist-info/METADATA +176 -0
- openprotein_python-0.8.2.dist-info/RECORD +84 -0
- openprotein_python-0.8.2.dist-info/WHEEL +4 -0
- openprotein_python-0.8.2.dist-info/licenses/LICENSE.txt +30 -0
openprotein/fasta.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from typing import Iterator, Sequence, overload
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
@overload
|
|
5
|
+
def parse_stream(
|
|
6
|
+
lines: Iterator[str], comment: str = "#"
|
|
7
|
+
) -> Iterator[tuple[str, str]]: ...
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@overload
|
|
11
|
+
def parse_stream(
|
|
12
|
+
lines: Iterator[bytes], comment: str = "#"
|
|
13
|
+
) -> Iterator[tuple[bytes, bytes]]: ...
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def parse_stream(
|
|
17
|
+
lines: Iterator[str] | Iterator[bytes], comment: str = "#"
|
|
18
|
+
) -> Iterator[tuple[str, str]] | Iterator[tuple[bytes, bytes]]:
|
|
19
|
+
is_bytes: bool | None = None
|
|
20
|
+
name = None
|
|
21
|
+
sequence = []
|
|
22
|
+
|
|
23
|
+
for line in lines:
|
|
24
|
+
if not line:
|
|
25
|
+
continue # skip empty lines
|
|
26
|
+
if is_bytes := isinstance(line, bytes):
|
|
27
|
+
line = line.decode()
|
|
28
|
+
if line.startswith(comment):
|
|
29
|
+
continue
|
|
30
|
+
line = line.strip()
|
|
31
|
+
if line.startswith(">"):
|
|
32
|
+
if name is not None:
|
|
33
|
+
sequence = "".join(sequence)
|
|
34
|
+
if is_bytes:
|
|
35
|
+
name = name.encode()
|
|
36
|
+
sequence = sequence.encode()
|
|
37
|
+
yield name, sequence
|
|
38
|
+
else:
|
|
39
|
+
yield name, sequence
|
|
40
|
+
name = line[1:].strip()
|
|
41
|
+
sequence = []
|
|
42
|
+
else:
|
|
43
|
+
sequence.append(line.strip())
|
|
44
|
+
|
|
45
|
+
if name is not None:
|
|
46
|
+
sequence = "".join(sequence)
|
|
47
|
+
if is_bytes:
|
|
48
|
+
name = name.encode()
|
|
49
|
+
sequence = sequence.encode()
|
|
50
|
+
yield name, sequence
|
|
51
|
+
else:
|
|
52
|
+
yield name, sequence
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def parse(
|
|
56
|
+
f: Sequence[str] | Sequence[bytes], comment: str = "#"
|
|
57
|
+
) -> tuple[list[str], list[str]] | tuple[list[bytes], list[bytes]]:
|
|
58
|
+
is_bytes: bool | None = None
|
|
59
|
+
names = []
|
|
60
|
+
sequences = []
|
|
61
|
+
name = None
|
|
62
|
+
sequence = []
|
|
63
|
+
for line in f:
|
|
64
|
+
if is_bytes := isinstance(line, bytes):
|
|
65
|
+
line = line.decode()
|
|
66
|
+
if line.startswith(comment):
|
|
67
|
+
continue
|
|
68
|
+
line = line.strip()
|
|
69
|
+
if line.startswith(">"):
|
|
70
|
+
# its a new entry
|
|
71
|
+
if name is not None:
|
|
72
|
+
sequence = "".join(sequence)
|
|
73
|
+
if is_bytes:
|
|
74
|
+
name = name.encode()
|
|
75
|
+
sequence = sequence.encode()
|
|
76
|
+
names.append(name)
|
|
77
|
+
sequences.append(sequence)
|
|
78
|
+
# reset the reading
|
|
79
|
+
name = line[1:]
|
|
80
|
+
sequence = []
|
|
81
|
+
else:
|
|
82
|
+
sequence.append(line.upper())
|
|
83
|
+
if name is not None:
|
|
84
|
+
# last entry
|
|
85
|
+
sequence = "".join(sequence)
|
|
86
|
+
if is_bytes:
|
|
87
|
+
name = name.encode()
|
|
88
|
+
sequence = sequence.encode()
|
|
89
|
+
names.append(name)
|
|
90
|
+
sequences.append(sequence)
|
|
91
|
+
|
|
92
|
+
return names, sequences
|
|
openprotein/fold/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Fold module for predicting structures on OpenProtein.
|
|
3
|
+
|
|
4
|
+
isort:skip_file
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .schemas import FoldJob, FoldMetadata
|
|
8
|
+
from .models import FoldModel
|
|
9
|
+
from .esmfold import ESMFoldModel
|
|
10
|
+
from .alphafold2 import AlphaFold2Model
|
|
11
|
+
from .boltz import (
|
|
12
|
+
Boltz1Model,
|
|
13
|
+
Boltz1xModel,
|
|
14
|
+
Boltz2Model,
|
|
15
|
+
BoltzAffinity,
|
|
16
|
+
BoltzConfidence,
|
|
17
|
+
BoltzConstraint,
|
|
18
|
+
BoltzProperty,
|
|
19
|
+
)
|
|
20
|
+
from .future import FoldResultFuture, FoldComplexResultFuture
|
|
21
|
+
from .fold import FoldAPI
|
|
openprotein/fold/alphafold2.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
"""Community-based AlphaFold 2 model running using ColabFold."""
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from collections import Counter
|
|
5
|
+
|
|
6
|
+
from openprotein.align import MSAFuture
|
|
7
|
+
from openprotein.base import APISession
|
|
8
|
+
from openprotein.common import ModelMetadata
|
|
9
|
+
from openprotein.protein import Protein
|
|
10
|
+
|
|
11
|
+
from . import api
|
|
12
|
+
from .future import FoldComplexResultFuture
|
|
13
|
+
from .models import FoldModel
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AlphaFold2Model(FoldModel):
    """
    Class providing inference endpoints for AlphaFold2 structure prediction models, based on the implementation by ColabFold.
    """

    model_id: str = "alphafold2"

    def __init__(
        self,
        session: APISession,
        model_id: str,
        metadata: ModelMetadata | None = None,
    ):
        super().__init__(session=session, model_id=model_id, metadata=metadata)

    def fold(
        self,
        proteins: list[Protein] | MSAFuture | None = None,
        num_recycles: int | None = None,
        num_models: int = 1,
        num_relax: int = 0,
        **kwargs,
    ) -> FoldComplexResultFuture:
        """
        Post sequences to alphafold model.

        Parameters
        ----------
        proteins : list[Protein] | MSAFuture
            Proteins to fold. Each `Protein` must be tagged with an `msa`.
            All proteins must share a single MSA and together make up its
            ':'-separated seed/query exactly, where a protein whose `chain_id`
            is a list counts once per chain. Alternatively, supply an
            `MSAFuture` to use all of its query sequences as a multimer.
        num_recycles : int, optional
            Number of times to recycle models.
        num_models : int
            Number of models to generate.
        num_relax : int
            Number of relaxation steps.
        **kwargs
            ``msa`` is accepted as a deprecated alias for ``proteins``. It
            takes precedence and emits a warning. ``ligands``, ``dnas`` and
            ``rnas`` are not supported: passing any of them emits a warning,
            and they are otherwise ignored.

        Returns
        -------
        FoldComplexResultFuture
            Future for the submitted fold job.

        Raises
        ------
        TypeError
            If no proteins are given, or they are neither a list nor an
            `MSAFuture`.
        ValueError
            If the list is empty, a protein has no MSA, a protein's sequence
            is not in its MSA seed, more than one MSA is referenced, or the
            proteins do not exactly make up the MSA seed.
        """
        if "msa" in kwargs:
            warnings.warn(
                "Inputs to AlphaFold 2 have been updated. 'msa' should be supplied as 'proteins' argument. Support will be dropped in the future.",
                stacklevel=2,
            )
            proteins = kwargs["msa"]
        if "ligands" in kwargs or "dnas" in kwargs or "rnas" in kwargs:
            with warnings.catch_warnings():
                warnings.simplefilter("always")  # Force warning to always show
                warnings.warn(
                    "Alphafold 2 only supports proteins. All other chains will be ignored",
                    stacklevel=2,
                )
        if proteins is None:
            raise TypeError("Expected 'proteins' argument")
        if isinstance(proteins, list):
            msa_id = self._msa_id_from_proteins(proteins)
        elif isinstance(proteins, MSAFuture):
            msa_id = proteins.id
        else:
            raise TypeError("Expected either list of Proteins or MSAFuture")

        return FoldComplexResultFuture.create(
            session=self.session,
            job=api.fold_models_post(
                self.session,
                model_id=self.model_id,
                msa_id=msa_id,
                num_recycles=num_recycles,
                num_models=num_models,
                num_relax=num_relax,
            ),
        )

    def _seed_counts(self, msa_id: str) -> Counter:
        """Fetch the ':'-separated seed of MSA `msa_id` and count each sequence in it."""
        # Local import, presumably to avoid an import cycle -- TODO confirm.
        from openprotein.align import AlignAPI

        align_api = getattr(self.session, "align", None)
        if not isinstance(align_api, AlignAPI):
            # Explicit error rather than `assert`, which is stripped under `python -O`.
            raise RuntimeError(
                "Session does not provide an AlignAPI as 'align'; cannot look up MSA seeds"
            )
        seed = align_api.get_seed(job_id=msa_id)
        return Counter(seed.split(":"))

    def _msa_id_from_proteins(self, proteins: list[Protein]) -> str:
        """Check that `proteins` exactly make up the seed of one MSA and return that MSA's id."""
        if not proteins:
            raise ValueError("Expected at least one protein when using AlphaFold 2")

        msa_to_seed: dict[str, Counter] = {}
        for protein in proteins:
            msa = protein.msa
            if msa is None:
                raise ValueError("Expected msa for protein when using AlphaFold 2")
            msa_id = msa.id if isinstance(msa, MSAFuture) else msa
            if msa_id not in msa_to_seed:
                # A Counter lets us later verify the proteins make up the seed completely.
                msa_to_seed[msa_id] = self._seed_counts(msa_id)
            # Check that this protein is in the seed.
            if protein.sequence.decode() not in msa_to_seed[msa_id]:
                raise ValueError(
                    f"Expected specified msa_id {msa_id} for protein {protein.sequence} to contain the sequence as part of its seed/query"
                )

        # Make sure we only have one msa.
        if len(msa_to_seed) > 1:
            raise ValueError("Expected only 1 unique msa when using AlphaFold 2")
        msa_id, seeds = next(iter(msa_to_seed.items()))

        # Consume the seed counts; a multimeric protein accounts for one copy per chain.
        for protein in proteins:
            sequence = protein.sequence.decode()
            seeds[sequence] -= (
                len(protein.chain_id) if isinstance(protein.chain_id, list) else 1
            )
            # Too many copies of this sequence among the proteins.
            if seeds[sequence] < 0:
                raise ValueError(
                    "List of proteins does not completely make up the MSA seed"
                )
        # No count is negative here, so a non-zero total means seed sequences are missing.
        if seeds.total() != 0:
            raise ValueError(
                "List of proteins does not completely make up the MSA seed"
            )
        return msa_id
|
openprotein/fold/api.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
"""Fold REST API interface for making HTTP calls to our fold backend."""
|
|
2
|
+
|
|
3
|
+
import io
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from pydantic import TypeAdapter
|
|
8
|
+
|
|
9
|
+
from openprotein.base import APISession
|
|
10
|
+
from openprotein.common import ModelMetadata
|
|
11
|
+
from openprotein.errors import HTTPError
|
|
12
|
+
|
|
13
|
+
from .schemas import FoldJob, FoldMetadata
|
|
14
|
+
|
|
15
|
+
# URL path prefix shared by every fold endpoint built in this module.
PATH_PREFIX = "v1/fold"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def fold_models_list_get(session: APISession) -> list[str]:
    """
    Return the identifiers of the fold models exposed by the backend.

    Parameters
    ----------
    session : APISession
        Session used to issue the GET request.

    Returns
    -------
    list of str
        Decoded JSON body of the ``/models`` endpoint, returned without
        further validation.
    """
    return session.get(f"{PATH_PREFIX}/models").json()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def fold_model_get(session: APISession, model_id: str) -> ModelMetadata:
    """
    Fetch the metadata describing a single fold model.

    Parameters
    ----------
    session : APISession
        Session used to issue the GET request.
    model_id : str
        Identifier of the model to look up.

    Returns
    -------
    ModelMetadata
        Metadata built from the JSON body's fields.
    """
    payload = session.get(f"{PATH_PREFIX}/models/{model_id}").json()
    return ModelMetadata(**payload)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def fold_get(session: APISession, job_id: str) -> FoldMetadata:
    """
    Fetch the metadata of a submitted fold job.

    Parameters
    ----------
    session : APISession
        Session used to issue the GET request.
    job_id : str
        Identifier of the fold job.

    Returns
    -------
    FoldMetadata
        Job metadata validated from the JSON response.
    """
    response = session.get(f"{PATH_PREFIX}/{job_id}")
    return FoldMetadata.model_validate(response.json())
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def fold_get_sequences(session: APISession, job_id: str) -> list[bytes]:
    """
    Fetch the sequences that belong to a fold job.

    Parameters
    ----------
    session : APISession
        Session used to issue the GET request.
    job_id : str
        Identifier of the fold job.

    Returns
    -------
    list of bytes
        The JSON list of sequences, validated and coerced to bytes.
    """
    raw = session.get(f"{PATH_PREFIX}/{job_id}/sequences").json()
    return TypeAdapter(list[bytes]).validate_python(raw)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def fold_get_sequence_result(
    session: APISession, job_id: str, sequence: bytes | str
) -> bytes:
    """
    Download the raw result for one sequence of a fold job.

    Parameters
    ----------
    session : APISession
        Session used to issue the GET request.
    job_id : str
        Identifier of the fold job.
    sequence : bytes or str
        Sequence whose result is requested. Bytes are decoded and the text is
        inserted verbatim as the last URL path segment.

    Returns
    -------
    bytes
        Raw response body.
    """
    seq_text = sequence.decode() if isinstance(sequence, bytes) else sequence
    return session.get(f"{PATH_PREFIX}/{job_id}/{seq_text}").content
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def fold_get_complex_result(
    session: APISession, job_id: str, format: Literal["pdb", "mmcif"]
) -> bytes:
    """
    Download the predicted complex structure of a fold job.

    Parameters
    ----------
    session : APISession
        Session used to issue the GET request.
    job_id : str
        Identifier of the fold job.
    format : {'pdb', 'mmcif'}
        Requested structure format, sent as the ``format`` query parameter.

    Returns
    -------
    bytes
        Raw response body.
    """
    query = {"format": format}
    response = session.get(f"{PATH_PREFIX}/{job_id}/complex", params=query)
    return response.content
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def fold_get_complex_extra_result(
    session: APISession,
    job_id: str,
    key: Literal["pae", "pde", "plddt", "confidence", "affinity"],
) -> np.ndarray | list[dict]:
    """
    Get extra result for a complex from the request ID.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    job_id : str
        Job ID to retrieve results from.
    key : {'pae', 'pde', 'plddt', 'confidence', 'affinity'}
        The type of result to retrieve.

    Returns
    -------
    numpy.ndarray or list of dict
        For "pae", "pde" and "plddt", the response body loaded with
        ``numpy.load``. For "confidence" and "affinity", the decoded JSON
        body (expected to be a list of dicts).

    Raises
    ------
    ValueError
        If `key` is not one of the supported values, or if the backend
        answers 400 to an "affinity" request (no affinity for this job).
    HTTPError
        Any other HTTP error from the backend is re-raised unchanged.
    """
    # Validate the key before making any request.
    if key in {"pae", "pde", "plddt"}:
        as_array = True
    elif key in {"confidence", "affinity"}:
        as_array = False
    else:
        raise ValueError(f"Unexpected key: {key}")

    endpoint = PATH_PREFIX + f"/{job_id}/complex/{key}"
    try:
        response = session.get(endpoint)
    except HTTPError as e:
        if e.status_code == 400 and key == "affinity":
            raise ValueError("affinity not found for request") from None
        raise

    if as_array:
        return np.load(io.BytesIO(response.content))
    return response.json()
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def fold_models_post(
    session: APISession,
    model_id: str,
    **kwargs,
) -> FoldJob:
    """
    POST a request for structure prediction.

    Returns a Job object referring to this request
    that can be used to retrieve results later.

    Options are read from ``kwargs``; unrecognized keys are ignored.

    - Numeric/boolean options (``num_recycles``, ``num_models``,
      ``num_relax``, ``use_potentials``, ``diffusion_samples``,
      ``recycling_steps``, ``sampling_steps``, ``step_scale``) are sent
      whenever they are not None, so explicit ``0``/``False`` values reach
      the backend instead of being silently replaced by its defaults.
    - All other options are sent only when truthy (non-empty).

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        Model ID to use for prediction.
    sequences : sequence of bytes or str, optional
        Sequences to request results for. Bytes items are decoded; other
        items (e.g. the boltz form) are forwarded as-is.
    msa_id : str, optional
        MSA ID to use.
    num_recycles : int, optional
        Number of recycles for structure prediction.
    num_models : int, optional
        Number of models to generate.
    num_relax : int, optional
        Number of relaxation steps.
    use_potentials : bool, optional
        Whether to use potentials.
    diffusion_samples : int, optional
        Number of diffusion samples (boltz).
    recycling_steps : int, optional
        Number of recycling steps (boltz).
    sampling_steps : int, optional
        Number of sampling steps (boltz).
    step_scale : float, optional
        Step scale (boltz).
    constraints : dict, optional
        Constraints to apply.
    templates : list, optional
        Templates to use.
    properties : dict, optional
        Additional properties.
    method : optional
        Forwarded unchanged as ``method`` in the request body.

    Returns
    -------
    FoldJob
        Job object referring to this request.
    """
    endpoint = PATH_PREFIX + f"/models/{model_id}"

    body: dict = {}
    if sequences := kwargs.get("sequences"):
        # NOTE we are handling the boltz form here too
        body["sequences"] = [
            s.decode() if isinstance(s, bytes) else s for s in sequences
        ]

    # (key, send_when_falsy): numeric/boolean options must be forwarded even
    # when 0/False, because those are legitimate explicit settings.
    options = (
        ("msa_id", False),
        ("num_recycles", True),
        ("num_models", True),
        ("num_relax", True),
        ("use_potentials", True),
        # boltz
        ("diffusion_samples", True),
        ("recycling_steps", True),
        ("sampling_steps", True),
        ("step_scale", True),
        ("constraints", False),
        ("templates", False),
        ("properties", False),
        ("method", False),
    )
    for key, send_when_falsy in options:
        value = kwargs.get(key)
        if value is None:
            continue
        if value or send_when_falsy:
            body[key] = value

    response = session.post(endpoint, json=body)
    return FoldJob.model_validate(response.json())
|