openprotein-python 0.8.2__1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openprotein/__init__.py +164 -0
- openprotein/_version.py +48 -0
- openprotein/align/__init__.py +8 -0
- openprotein/align/align.py +395 -0
- openprotein/align/api.py +428 -0
- openprotein/align/future.py +55 -0
- openprotein/align/msa.py +129 -0
- openprotein/align/schemas.py +165 -0
- openprotein/base.py +181 -0
- openprotein/chains.py +88 -0
- openprotein/common/__init__.py +5 -0
- openprotein/common/features.py +7 -0
- openprotein/common/model_metadata.py +33 -0
- openprotein/common/reduction.py +8 -0
- openprotein/config.py +9 -0
- openprotein/csv.py +31 -0
- openprotein/data/__init__.py +9 -0
- openprotein/data/api.py +218 -0
- openprotein/data/assaydataset.py +178 -0
- openprotein/data/data.py +93 -0
- openprotein/data/schemas.py +27 -0
- openprotein/design/__init__.py +16 -0
- openprotein/design/api.py +259 -0
- openprotein/design/design.py +125 -0
- openprotein/design/future.py +146 -0
- openprotein/design/schemas.py +607 -0
- openprotein/embeddings/__init__.py +27 -0
- openprotein/embeddings/api.py +619 -0
- openprotein/embeddings/embeddings.py +151 -0
- openprotein/embeddings/esm.py +33 -0
- openprotein/embeddings/future.py +146 -0
- openprotein/embeddings/models.py +421 -0
- openprotein/embeddings/openprotein.py +21 -0
- openprotein/embeddings/poet.py +446 -0
- openprotein/embeddings/poet2.py +505 -0
- openprotein/embeddings/schemas.py +78 -0
- openprotein/errors.py +76 -0
- openprotein/fasta.py +92 -0
- openprotein/fold/__init__.py +21 -0
- openprotein/fold/alphafold2.py +131 -0
- openprotein/fold/api.py +287 -0
- openprotein/fold/boltz.py +691 -0
- openprotein/fold/esmfold.py +54 -0
- openprotein/fold/fold.py +107 -0
- openprotein/fold/future.py +509 -0
- openprotein/fold/models.py +139 -0
- openprotein/fold/schemas.py +39 -0
- openprotein/jobs/__init__.py +9 -0
- openprotein/jobs/api.py +71 -0
- openprotein/jobs/futures.py +746 -0
- openprotein/jobs/jobs.py +69 -0
- openprotein/jobs/schemas.py +135 -0
- openprotein/models/__init__.py +4 -0
- openprotein/models/base.py +63 -0
- openprotein/models/foundation/rfdiffusion.py +283 -0
- openprotein/models/models.py +33 -0
- openprotein/predictor/__init__.py +25 -0
- openprotein/predictor/api.py +384 -0
- openprotein/predictor/models.py +374 -0
- openprotein/predictor/prediction.py +79 -0
- openprotein/predictor/predictor.py +242 -0
- openprotein/predictor/schemas.py +113 -0
- openprotein/predictor/validate.py +40 -0
- openprotein/prompt/__init__.py +9 -0
- openprotein/prompt/api.py +505 -0
- openprotein/prompt/models.py +142 -0
- openprotein/prompt/prompt.py +130 -0
- openprotein/prompt/schemas.py +49 -0
- openprotein/protein.py +587 -0
- openprotein/svd/__init__.py +9 -0
- openprotein/svd/api.py +206 -0
- openprotein/svd/models.py +288 -0
- openprotein/svd/schemas.py +31 -0
- openprotein/svd/svd.py +134 -0
- openprotein/umap/__init__.py +9 -0
- openprotein/umap/api.py +259 -0
- openprotein/umap/models.py +211 -0
- openprotein/umap/schemas.py +35 -0
- openprotein/umap/umap.py +175 -0
- openprotein/utils/uuid.py +29 -0
- openprotein_python-0.8.2.dist-info/METADATA +176 -0
- openprotein_python-0.8.2.dist-info/RECORD +84 -0
- openprotein_python-0.8.2.dist-info/WHEEL +4 -0
- openprotein_python-0.8.2.dist-info/licenses/LICENSE.txt +30 -0
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""Schemas for OpenProtein predictor system."""
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import Literal
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, ConfigDict
|
|
8
|
+
|
|
9
|
+
from openprotein.common import FeatureType
|
|
10
|
+
from openprotein.jobs import Job, JobStatus, JobType
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Kernel(BaseModel):
    # Kernel configuration; referenced by PredictorArgs.kernel.
    # Kernel type identifier (unconstrained string at the schema level).
    type: str
    # Multitask flag; False unless explicitly provided.
    multitask: bool = False
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Constraints(BaseModel):
    # Constraints attached to a model spec (see ModelSpec.constraints).
    # Optional sequence length; None means the field was not provided.
    sequence_length: int | None = None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class PredictorType(str, Enum):
    # String-valued enum: members compare equal to their plain string values.
    GP = "GP"
    ENSEMBLE = "ENSEMBLE"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Features(BaseModel):
    # Feature specification attached to a model spec (see ModelSpec.features).
    type: FeatureType
    # Optional identifier of the model providing the features.
    model_id: str | None = None
    # Optional reduction name (unconstrained string at the schema level).
    reduction: str | None = None

    # Clear pydantic's protected namespaces so the `model_id` field name does
    # not clash with the reserved `model_` prefix.
    model_config = ConfigDict(protected_namespaces=())
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class PredictorArgs(BaseModel):
    # Base arguments shared by predictor specs; ModelSpec extends this class.
    # Optional kernel configuration.
    kernel: Kernel | None = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# NOTE(review): BaseModel is listed explicitly although PredictorArgs already
# derives from it; the extra base looks redundant.
class ModelSpec(PredictorArgs, BaseModel):
    # Predictor kind (GP or ENSEMBLE).
    type: PredictorType
    # Optional constraints and feature specification.
    constraints: Constraints | None = None
    features: Features | None = None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class Dataset(BaseModel):
    # Dataset reference; used as PredictorMetadata.training_dataset.
    # Identifier of the assay.
    assay_id: str
    # Property names selected from the assay.
    properties: list[str]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class PredictorMetadata(BaseModel):
    """Metadata about the predictor."""

    class CalibrationStats(BaseModel):
        """Calibration stats for this predictor, based on the latest crossvalidation."""

        # All statistics are optional and default to None.
        pearson: float | None = None
        spearman: float | None = None
        ece: float | None = None

    class CalibrationCurvePoint(BaseModel):
        # One (x, y) point of the calibration curve stored in `curve` below.
        x: float
        y: float

    id: str
    name: str
    description: str | None = None
    status: JobStatus
    created_date: datetime
    model_spec: ModelSpec
    # Presumably the member predictor ids when this is an ensemble - confirm.
    ensemble_model_ids: list[str] | None = None
    training_dataset: Dataset
    # Forward reference: TrainGraph is defined later in this module.
    traingraphs: list["TrainGraph"] | None = None
    stats: CalibrationStats | None = None
    curve: list[CalibrationCurvePoint] | None = None

    def is_done(self):
        """Return the result of ``self.status.done()``."""
        return self.status.done()

    # Clear pydantic's protected namespaces so the `model_spec` field name does
    # not clash with the reserved `model_` prefix.
    model_config = ConfigDict(protected_namespaces=())
|
|
80
|
+
|
|
81
|
+
class TrainGraph(BaseModel):
    # Training-curve record; referenced by PredictorMetadata.traingraphs.
    measurement_name: str
    hyperparam_search_step: int
    # Recorded loss values (presumably in training order - confirm).
    losses: list[float]
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class PredictorEnsembleJob(Job):
    # Both fields are re-declared as None-typed with a None default, so on this
    # subclass they can only hold None.
    # NOTE(review): presumably because an ensemble has no backing job - confirm.
    job_id: None = None
    progress_counter: None = None
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class PredictorTrainJob(Job):
    # job_type is narrowed to the single literal JobType.predictor_train.
    job_type: Literal[JobType.predictor_train]
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class PredictJob(Job):
    # job_type is narrowed to the single literal JobType.predictor_predict.
    job_type: Literal[JobType.predictor_predict]
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class PredictSingleSiteJob(Job):
    # job_type is narrowed to the single literal
    # JobType.predictor_predict_single_site.
    job_type: Literal[JobType.predictor_predict_single_site]
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class PredictMultiJob(Job):
    # job_type is narrowed to the single literal JobType.predictor_predict_multi.
    job_type: Literal[JobType.predictor_predict_multi]
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class PredictMultiSingleSiteJob(Job):
    # job_type is narrowed to the single literal
    # JobType.predictor_predict_multi_single_site.
    job_type: Literal[JobType.predictor_predict_multi_single_site]
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class PredictorCVJob(Job):
    # job_type is narrowed to the single literal JobType.predictor_crossvalidate.
    job_type: Literal[JobType.predictor_crossvalidate]
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Predictor validation results represented as futures."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from openprotein.base import APISession
|
|
6
|
+
from openprotein.jobs import Future
|
|
7
|
+
|
|
8
|
+
from . import api
|
|
9
|
+
from .schemas import PredictorCVJob
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CVResultFuture(Future):
    """Future wrapping a predictor cross-validation job and its results."""

    job: PredictorCVJob

    def __init__(self, session: APISession, job: PredictorCVJob):
        super().__init__(session, job)

    @property
    def id(self):
        """Job ID of the wrapped cross-validation job."""
        return self.job.job_id

    def get(self, verbose: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Fetch and decode the cross-validation results of this job.

        Args:
            verbose (bool): accepted for interface compatibility; not used by
                this method.

        Returns:
            tuple of three np.ndarray, exactly as produced by
            ``api.decode_crossvalidate``. Presumably these include the
            prediction means and variances - confirm the order and the third
            array against ``api.decode_crossvalidate``.
        """
        return api.decode_crossvalidate(
            api.predictor_crossvalidate_get(self.session, self.job.job_id)
        )
|
|
@@ -0,0 +1,505 @@
|
|
|
1
|
+
"""Prompt REST API interface for making HTTP calls to the prompt backend."""
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import io
|
|
5
|
+
import zipfile
|
|
6
|
+
from typing import BinaryIO, Sequence, cast
|
|
7
|
+
|
|
8
|
+
from openprotein.base import APISession
|
|
9
|
+
from openprotein.errors import APIError, InvalidParameterError, RawAPIError
|
|
10
|
+
from openprotein.protein import Protein
|
|
11
|
+
|
|
12
|
+
from .schemas import Context, PromptMetadata, QueryMetadata
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def create_prompt(
    session: APISession,
    context: Context | Sequence[Context],
    name: str | None = None,
    description: str | None = None,
) -> PromptMetadata:
    """
    Upload one or more contexts and register them as a prompt.

    Parameters
    ----------
    session : APISession
        The API session.
    context : Context or Sequence[Context]
        A single context or several contexts; each is zipped by
        ``zip_prompt`` and uploaded as ``context-<i>.zip``.
    name : str or None, optional
        Name of the prompt. Omitted from the request when None.
    description : str or None, optional
        Description of the prompt. Omitted from the request when None.

    Returns
    -------
    PromptMetadata
        Metadata of the created prompt.

    Raises
    ------
    InvalidParameterError
        If the API answers with status 400.
    APIError
        If the API answers with status 401 or any other unexpected status.
    """
    # Only send the optional text fields that were actually provided.
    optional_fields = {"name": name, "description": description}
    data = {key: value for key, value in optional_fields.items() if value is not None}

    files = [
        ("context", (f"context-{idx}.zip", archive, "application/zip"))
        for idx, archive in enumerate(zip_prompt(context=context))
    ]
    request_kwargs: dict = {"files": files}
    if len(data) > 0:
        request_kwargs["data"] = data

    response = session.post("v1/prompt/create_prompt", **request_kwargs)

    status = response.status_code
    if status == 200:
        return PromptMetadata.model_validate(response.json())
    if status == 400:
        raise InvalidParameterError(RawAPIError.model_validate(response.json()).detail)
    if status == 401:
        raise APIError(RawAPIError.model_validate(response.json()).detail)
    raise APIError(f"Unexpected response status code: {status}")
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def get_prompt_metadata(session: APISession, prompt_id: str) -> PromptMetadata:
    """
    Fetch the metadata record of a prompt.

    Parameters
    ----------
    session : APISession
        The API session.
    prompt_id : str
        The prompt ID.

    Returns
    -------
    PromptMetadata
        Metadata of the prompt.

    Raises
    ------
    APIError
        If the API answers with 401, 404, or any other non-200 status.
    """
    response = session.get(f"v1/prompt/{prompt_id}")
    status = response.status_code

    if status == 200:
        return PromptMetadata.model_validate(response.json())
    if status in (401, 404):
        # Both statuses carry a structured error body with a `detail` field.
        raise APIError(RawAPIError.model_validate(response.json()).detail)
    raise APIError(f"Unexpected response status code: {status}")
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def get_prompt(session: APISession, prompt_id: str) -> list[list[Protein]]:
    """
    Download a prompt's content and decode it into proteins.

    Parameters
    ----------
    session : APISession
        The API session.
    prompt_id : str
        The prompt ID.

    Returns
    -------
    list of list of Protein
        One inner list of proteins per context, as decoded by ``unzip_prompt``.

    Raises
    ------
    APIError
        If the API answers with 401, 404, or any other non-200 status.
    """
    response = session.get(f"v1/prompt/{prompt_id}/content", stream=True)
    status = response.status_code

    if status == 200:
        # The body is a zip archive; wrap it so it can be read as a file.
        return unzip_prompt(io.BytesIO(response.content))
    if status in (401, 404):
        raise APIError(RawAPIError.model_validate(response.json()).detail)
    raise APIError(f"Unexpected response status code: {status}")
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def list_prompts(session: APISession) -> list[PromptMetadata]:
    """
    Retrieve the metadata of every prompt returned by the API.

    Parameters
    ----------
    session : APISession
        The API session.

    Returns
    -------
    list of PromptMetadata
        One validated metadata object per entry of the JSON response.

    Raises
    ------
    APIError
        If the API answers with 401 or any other non-200 status.
    """
    response = session.get("v1/prompt")
    status = response.status_code

    if status == 200:
        return [PromptMetadata.model_validate(entry) for entry in response.json()]
    if status == 401:
        raise APIError(RawAPIError.model_validate(response.json()).detail)
    raise APIError(f"Unexpected response status code: {status}")
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def zip_prompt(
    context: Context | Sequence[Context],
) -> list[io.BytesIO]:
    """
    Pack prompt contexts into in-memory zip archives ready for upload.

    Parameters
    ----------
    context : Context or Sequence[Context]
        Either one context (a sequence of bytes/str/Protein items) or a
        sequence of such contexts. An empty input is treated as a single
        empty context.

    Returns
    -------
    list of io.BytesIO
        One zip archive per context, each rewound to position 0. Proteins with
        structure are stored as individual ``.cif`` members; runs of
        consecutive sequence-only proteins share one ``.fasta`` member.
    """
    if len(context) == 0:
        context = [[]]
    # A flat list of sequences/proteins is a single context: wrap it.
    if isinstance(context[0], (bytes, str, Protein)):
        context = [cast(Context, context)]
    groups = cast(Sequence[Context], context)

    def _as_named_protein(position: int, item) -> Protein:
        # Work on a copy of caller-owned proteins so naming never mutates them.
        if isinstance(item, Protein):
            protein = copy.copy(item)
        else:
            protein = Protein(name=f"unnamed-{position:06}", sequence=item)
        if protein.name is None:
            protein.name = f"unnamed-{position:06}"
        return protein

    archives = []
    for group in groups:
        proteins: list[Protein] = [
            _as_named_protein(position, item) for position, item in enumerate(group)
        ]

        entries: list[tuple[str, io.BytesIO]] = []
        for protein in proteins:
            # Members are numbered by their position among archive members,
            # not by the protein's position in the context.
            member_index = len(entries)
            if protein.has_structure:
                cif_bytes = protein.make_cif_string().encode()
                entries.append(
                    (f"{member_index:06}.{protein.name}.cif", io.BytesIO(cif_bytes))
                )
                continue
            # Sequence-only protein: append to the open fasta member if the
            # previous member is one, otherwise start a new fasta member.
            if len(entries) == 0 or not entries[-1][0].endswith(".fasta"):
                entries.append((f"{member_index:06}.fasta", io.BytesIO()))
            entries[-1][1].write(protein.make_fasta_bytes())

        archive = io.BytesIO()
        with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
            for member_name, payload in entries:
                zf.writestr(member_name, payload.getvalue())
        archive.seek(0)
        archives.append(archive)

    return archives
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def unzip_prompt(prompt_zip: BinaryIO) -> list[list[Protein]]:
    """
    Decode a prompt zip archive into its context proteins.

    This is the inverse of ``zip_prompt``: every member of the outer archive
    whose name starts with ``context-`` is itself a zip archive holding one
    context group.

    Parameters
    ----------
    prompt_zip : BinaryIO
        Binary stream of the prompt zip file, as downloaded by get_prompt().

    Returns
    -------
    list of list of Protein
        One inner list per context group, in the outer archive's member order.
    """
    with zipfile.ZipFile(prompt_zip, "r") as outer:
        # Copy each nested archive into memory so it can be opened on its own.
        nested_archives = [
            io.BytesIO(outer.read(member))
            for member in outer.namelist()
            if member.startswith("context-")
        ]
    return __parse_prompt(context_files=nested_archives)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def __parse_prompt(
    context_files: Sequence[BinaryIO],
) -> list[list[Protein]]:
    """
    Parse context zip files into proteins.

    Parameters
    ----------
    context_files : Sequence[BinaryIO]
        Sequence of binary zip files, each representing a context group.

    Returns
    -------
    list of list of Protein
        List of context protein lists, where each inner list represents a
        context group. Members are processed in the order they are stored in
        the archive; members that are neither ``.cif`` nor ``.fasta`` are
        ignored.
    """
    import os
    import tempfile

    from openprotein import fasta

    context: list[list[Protein]] = []

    # Process each context file (representing an ensemble)
    for context_file in context_files:
        # Reset the file pointer to the beginning
        context_file.seek(0)
        proteins_in_context: list[Protein] = []

        with zipfile.ZipFile(context_file, "r") as zf:
            # namelist() yields members in archive order; no re-sorting is done.
            for filename in zf.namelist():
                content = zf.read(filename)

                if filename.endswith(".cif"):
                    # The structure parser reads from a path, so write the CIF
                    # to disk. A closed file inside a temporary directory is
                    # used instead of an open NamedTemporaryFile, because an
                    # open named temp file cannot be reopened by name on
                    # Windows.
                    with tempfile.TemporaryDirectory() as tmp_dir:
                        cif_path = os.path.join(tmp_dir, "structure.cif")
                        with open(cif_path, "wb") as cif_file:
                            cif_file.write(content)
                        # extract chain ID (using 'A' as default)
                        protein = Protein.from_filepath(
                            path=cif_path, chain_id="A", verbose=False
                        )
                    # override the name with the filename (without extension)
                    protein.name = filename[:-4]
                    proteins_in_context.append(protein)

                elif filename.endswith(".fasta"):
                    # A fasta member may hold several sequence-only proteins.
                    for name, sequence in fasta.parse_stream(io.BytesIO(content)):
                        proteins_in_context.append(
                            Protein(name=name, sequence=sequence)
                        )

        # Add this group of proteins to the context
        context.append(proteins_in_context)

    return context
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def create_query(
    session: APISession,
    query: bytes | str | Protein,
) -> QueryMetadata:
    """
    Upload a single protein as a query.

    Parameters
    ----------
    session : APISession
        The API session.
    query : bytes or str or Protein
        The query protein. A raw sequence is wrapped in a Protein named
        "query". Proteins with structure are uploaded as mmCIF, all others as
        FASTA.

    Returns
    -------
    QueryMetadata
        Metadata of the created query.

    Raises
    ------
    InvalidParameterError
        If the API answers with status 400.
    APIError
        If the API answers with status 401 or any other unexpected status.
    """
    protein = query if isinstance(query, Protein) else Protein(name="query", sequence=query)

    # Pick the upload format from whether the protein carries a structure.
    if protein.has_structure:
        payload = protein.make_cif_string().encode()
        filename, media_type = "query.cif", "chemical/x-mmcif"
    else:
        payload = protein.make_fasta_bytes()
        filename, media_type = "query.fasta", "text/x-fasta"

    response = session.post(
        "v1/prompt/query",
        files={"query": (filename, io.BytesIO(payload), media_type)},
    )

    status = response.status_code
    if status == 200:
        return QueryMetadata.model_validate(response.json())
    if status == 400:
        raise InvalidParameterError(RawAPIError.model_validate(response.json()).detail)
    if status == 401:
        raise APIError(RawAPIError.model_validate(response.json()).detail)
    raise APIError(f"Unexpected response status code: {status}")
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def get_query_metadata(session: APISession, query_id: str) -> QueryMetadata:
    """
    Fetch the metadata record of a query.

    Parameters
    ----------
    session : APISession
        The API session.
    query_id : str
        The query ID.

    Returns
    -------
    QueryMetadata
        Metadata of the query.

    Raises
    ------
    APIError
        If the API answers with 401, 404, or any other non-200 status.
    """
    response = session.get(f"v1/prompt/query/{query_id}")
    status = response.status_code

    if status == 200:
        return QueryMetadata.model_validate(response.json())
    if status in (401, 404):
        # Both statuses carry a structured error body with a `detail` field.
        raise APIError(RawAPIError.model_validate(response.json()).detail)
    raise APIError(f"Unexpected response status code: {status}")
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
def get_query(session: APISession, query_id: str) -> Protein:
    """
    Get the query content for a given query ID.

    The HTTP status is checked before the body is interpreted, so 401/404
    responses surface the API's error detail instead of a misleading
    "unexpected file" error. The file format is taken from the filename in
    ``Content-Disposition`` or from the ``Content-Type`` media type.

    Parameters
    ----------
    session : APISession
        The API session.
    query_id : str
        The query ID.

    Returns
    -------
    Protein
        The query protein. For FASTA content only the first record is used.

    Raises
    ------
    APIError
        If the API returns an error or the file format is unexpected.
    """
    import os
    import tempfile
    from email.message import Message

    endpoint = f"v1/prompt/query/{query_id}/content"
    response = session.get(endpoint, stream=True)

    # Handle error statuses first: their bodies are JSON, not query files.
    if response.status_code in (401, 404):
        error = RawAPIError.model_validate(response.json())
        raise APIError(error.detail)
    if response.status_code != 200:
        raise APIError(f"Unexpected response status code: {response.status_code}")

    # Content-Disposition usually looks like `attachment; filename="query.cif"`;
    # extract the filename parameter, falling back to the raw header value
    # (which also covers a bare filename and the "query" default).
    disposition = response.headers.get("Content-Disposition", "query")
    header = Message()
    header["Content-Disposition"] = disposition
    filename = header.get_filename() or disposition

    # Drop parameters such as "; charset=utf-8" before comparing media types.
    content_type = response.headers.get("Content-Type", "text/plain")
    media_type = content_type.split(";", 1)[0].strip().lower()

    is_mmcif = filename.endswith(".cif") or media_type == "chemical/x-mmcif"
    is_fasta = filename.endswith(".fasta") or media_type == "text/x-fasta"

    query_protein = None
    if is_mmcif:
        # The structure parser reads from a path. A closed file inside a
        # temporary directory is used instead of an open NamedTemporaryFile,
        # because an open named temp file cannot be reopened by name on
        # Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            cif_path = os.path.join(tmp_dir, "query.cif")
            with open(cif_path, "wb") as cif_file:
                cif_file.write(response.content)
            # extract chain id (using 'A' as default)
            query_protein = Protein.from_filepath(
                path=cif_path, chain_id="A", verbose=False
            )
    elif is_fasta:
        # Process FASTA file - take only the first sequence
        from openprotein import fasta

        for name, sequence in fasta.parse_stream(io.BytesIO(response.content)):
            query_protein = Protein(name=name, sequence=sequence)
            break
    else:
        raise APIError(
            f"Unexpected file returned with filename {filename} and type {media_type}"
        )

    if query_protein is None:
        raise APIError(f"Invalid query file returned from API {response.content[:10]}")

    return query_protein
|